Files
ant-vision-inspector/gui/pages/retrain_page.py
2026-06-10 16:18:41 +09:00

1194 lines
46 KiB
Python

# 재학습 페이지 — 이미지 로드·라벨링 UI + 학습 제어
import json
import os
import shutil
from datetime import datetime
import cv2
import numpy as np
from PyQt5.QtCore import Qt, QPoint, QRect, pyqtSignal
from PyQt5.QtGui import QPixmap, QPainter, QPen, QColor, QFont, QImage, QCursor
from PyQt5.QtWidgets import (
QWidget, QHBoxLayout, QVBoxLayout, QGroupBox,
QPushButton, QLabel, QListWidget, QListWidgetItem,
QSpinBox, QLineEdit, QProgressBar, QTextEdit,
QFormLayout, QFileDialog, QSizePolicy, QScrollArea, QMessageBox,
)
from ai.trainer import Trainer, TrainWorker
from paths import resolve_path, to_project_relative
from logger import log_train, log_action
# Defect class name -> accent colour of its selection button on this page.
# Dict insertion order is the button order; the first entry is the default class.
_CLASS_COLORS = {
    "스크래치": "#854F0B",
    "이물": "#185FA5",
    "흑점": "#3C3489",
    "변형": "#A32D2D",
}
_CLASS_NAMES = [*_CLASS_COLORS]

# Shared stylesheet applied to every QGroupBox section of the page.
_GRP = "".join((
    "QGroupBox {",
    " background:#222222; border:1px solid #333333; border-radius:6px;",
    " margin-top:14px; padding:12px 10px 10px 10px;",
    "}",
    "QGroupBox::title {",
    " color:#aaaaaa; subcontrol-origin:margin; left:10px; padding:0 4px;",
    "}",
))
# ============================================================ #
# 라벨링 캔버스
# ============================================================ #
class LabelingCanvas(QWidget):
    """Interactive canvas for drawing and editing bounding-box labels on an image.

    Boxes are kept in ``self.boxes`` as dicts with the keys ``class_id``,
    ``class_name`` and ``rect`` (a QRect in image pixel coordinates).

    Interaction:
        * left-drag on empty space draws a new box of the current class;
        * left-drag inside a box moves it, on one of its handles resizes it;
        * mouse wheel zooms around the cursor, double-click fits the image;
        * Space + left-drag or middle-drag pans;
        * Delete/Backspace removes the selected box; Ctrl+Z undoes the last
          add / delete / clear / class change / move / resize; Escape clears
          the selection and cancels an in-progress move or resize.

    Signals:
        box_added(dict): a new box was drawn (the box dict itself).
        boxes_changed(): boxes were edited, deleted or restored.
        selection_changed(int): selected box index, -1 for none.
        zoom_changed(int): zoom level as an integer percent (100 == 100%).
    """

    box_added = pyqtSignal(dict)
    boxes_changed = pyqtSignal()
    selection_changed = pyqtSignal(int)
    zoom_changed = pyqtSignal(int)  # integer percent (100 == 100%)

    # Class name -> class id written into the YOLO label lines.
    CLASS_MAP = {
        "스크래치": 0, "이물": 1, "흑점": 2, "변형": 3,
    }
    # Class name -> box outline colour on the canvas.
    CLASS_COLORS = {
        "스크래치": "#FF4444", "이물": "#44AAFF",
        "흑점": "#AA44FF", "변형": "#FF8800",
    }

    _HANDLE_SIZE = 8     # side of a resize-handle square, screen pixels
    _MIN_BOX_SIZE = 10   # minimum box width/height, image pixels
    _HISTORY_LIMIT = 50  # maximum number of undo snapshots kept
    _MIN_ZOOM = 0.5      # wheel-zoom limits; widened to include the fit scale
    _MAX_ZOOM = 5.0
    _ZOOM_STEP = 1.15    # zoom factor per wheel notch

    # Handle index: 0=TL, 1=T, 2=TR, 3=R, 4=BR, 5=B, 6=BL, 7=L
    _HANDLE_CURSORS = [
        Qt.SizeFDiagCursor,  # 0 TL
        Qt.SizeVerCursor,    # 1 T
        Qt.SizeBDiagCursor,  # 2 TR
        Qt.SizeHorCursor,    # 3 R
        Qt.SizeFDiagCursor,  # 4 BR
        Qt.SizeVerCursor,    # 5 B
        Qt.SizeBDiagCursor,  # 6 BL
        Qt.SizeHorCursor,    # 7 L
    ]

    def __init__(self, parent=None):
        super().__init__(parent)
        # Image & box state
        self.image: np.ndarray = None
        self.boxes: list = []
        self.history: list = []  # undo stack of box-list snapshots (Ctrl+Z)

        # Class used for newly drawn boxes
        self.current_class_id: int = 0
        self.current_class_name: str = "스크래치"

        # Selection & interaction
        self.selected_index: int = -1
        self.drag_mode: str = "none"  # none / new_box / move / resize / pan
        self.resize_handle: int = -1
        self.current_rect: QRect = None

        # Zoom / pan
        self.scale: float = 1.0
        self.offset_x: float = 0.0
        self.offset_y: float = 0.0
        self.space_pressed: bool = False
        self._need_fit: bool = False

        # Temporary drag state
        self._drag_start_img: QPoint = None
        self._drag_start_screen: QPoint = None
        self._move_orig_rect: QRect = None
        self._resize_orig_rect: QRect = None
        self._pan_start: QPoint = None
        self._drag_snapshot: list = None  # boxes before the current move/resize

        # Render caches: the converted image and its last downscaled copy.
        # Rebuilt when the image object or the zoom level changes.
        self._pixmap: QPixmap = None
        self._pixmap_src: np.ndarray = None
        self._scaled_pixmap: QPixmap = None

        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.setStyleSheet("background:#111111;")
        self.setFocusPolicy(Qt.StrongFocus)
        self.setMouseTracking(True)
        self.setCursor(Qt.CrossCursor)

    # ── Coordinate conversion ─────────────────────────────────────── #
    def screen_to_image(self, pos: QPoint):
        """Map a widget position to image pixel coordinates, clamped to the image.

        Returns None when no image is loaded or the scale is zero.
        """
        if self.image is None or self.scale == 0:
            return None
        img_h, img_w = self.image.shape[:2]
        x = (pos.x() - self.offset_x) / self.scale
        y = (pos.y() - self.offset_y) / self.scale
        return QPoint(int(max(0.0, min(x, img_w - 1))),
                      int(max(0.0, min(y, img_h - 1))))

    def image_to_screen_rect(self, rect: QRect) -> QRect:
        """Map a rectangle in image pixels to widget coordinates at the current zoom/offset."""
        return QRect(
            int(rect.x() * self.scale + self.offset_x),
            int(rect.y() * self.scale + self.offset_y),
            int(rect.width() * self.scale),
            int(rect.height() * self.scale),
        )

    # ── Handles / hit testing ─────────────────────────────────────── #
    def _handle_rects(self, sr: QRect) -> list:
        """Return the 8 handle squares of screen rect `sr`, in _HANDLE_CURSORS order."""
        hs = self._HANDLE_SIZE
        hh = hs // 2
        cx = sr.center().x()
        cy = sr.center().y()
        l, t, r, b = sr.left(), sr.top(), sr.right(), sr.bottom()
        return [
            QRect(l - hh, t - hh, hs, hs),
            QRect(cx - hh, t - hh, hs, hs),
            QRect(r - hh, t - hh, hs, hs),
            QRect(r - hh, cy - hh, hs, hs),
            QRect(r - hh, b - hh, hs, hs),
            QRect(cx - hh, b - hh, hs, hs),
            QRect(l - hh, b - hh, hs, hs),
            QRect(l - hh, cy - hh, hs, hs),
        ]

    def get_handle_at(self, pos: QPoint, sr: QRect) -> int:
        """Index of the handle of `sr` under `pos` (2 px tolerance), or -1."""
        for i, hr in enumerate(self._handle_rects(sr)):
            if hr.adjusted(-2, -2, 2, 2).contains(pos):
                return i
        return -1

    def get_box_at(self, pos: QPoint) -> int:
        """Index of the topmost (last drawn) box containing `pos`, or -1."""
        for i in range(len(self.boxes) - 1, -1, -1):
            if self.image_to_screen_rect(self.boxes[i]["rect"]).contains(pos):
                return i
        return -1

    # ── History / copying ─────────────────────────────────────────── #
    def _copy_boxes(self) -> list:
        """Deep-enough copy of self.boxes (new dicts and new QRects)."""
        return [
            {
                "class_id": b["class_id"],
                "class_name": b["class_name"],
                "rect": QRect(b["rect"]),
            }
            for b in self.boxes
        ]

    def _push_history(self, snapshot: list):
        """Append a snapshot to the undo stack, dropping the oldest beyond the limit."""
        self.history.append(snapshot)
        if len(self.history) > self._HISTORY_LIMIT:
            del self.history[0]

    def _save_history(self):
        """Snapshot the current boxes so the next change can be undone."""
        self._push_history(self._copy_boxes())

    def _commit_drag_history(self):
        """Push the pre-drag snapshot onto the undo stack if the move/resize changed a box."""
        snapshot, self._drag_snapshot = self._drag_snapshot, None
        if snapshot is None:
            return
        before = [b["rect"] for b in snapshot]
        after = [b["rect"] for b in self.boxes]
        if before != after:
            self._push_history(snapshot)

    # ── Cursor ────────────────────────────────────────────────────── #
    def _update_cursor(self, pos: QPoint):
        """Pick the cursor shape for hovering at `pos` (pan, handle, box or crosshair)."""
        if self.space_pressed:
            self.setCursor(Qt.OpenHandCursor)
            return
        if 0 <= self.selected_index < len(self.boxes):
            sr = self.image_to_screen_rect(self.boxes[self.selected_index]["rect"])
            h = self.get_handle_at(pos, sr)
            if h >= 0:
                self.setCursor(self._HANDLE_CURSORS[h])
                return
            if sr.contains(pos):
                self.setCursor(Qt.SizeAllCursor)
                return
        if self.get_box_at(pos) >= 0:
            self.setCursor(Qt.SizeAllCursor)
            return
        self.setCursor(Qt.CrossCursor)

    # ── Public API ────────────────────────────────────────────────── #
    def set_image(self, img: np.ndarray):
        """Show a new image, discarding all boxes, undo history and drag state.

        `img` is a numpy image array as returned by cv2 (BGR; grayscale and
        BGRA arrays are also rendered).
        """
        self.image = img
        self.boxes = []
        self.history = []
        self.selected_index = -1
        self.drag_mode = "none"
        self.current_rect = None
        self._drag_snapshot = None
        self._pixmap = None
        self._pixmap_src = None
        self._scaled_pixmap = None
        self._need_fit = True
        self.fit_to_window()
        self.selection_changed.emit(-1)

    def set_class(self, class_id: int, class_name: str):
        """Set the class assigned to boxes drawn from now on."""
        self.current_class_id = class_id
        self.current_class_name = class_name

    def change_selected_class(self, class_id: int, class_name: str):
        """Reassign the selected box to another class (undoable); no-op without a selection."""
        if 0 <= self.selected_index < len(self.boxes):
            self._save_history()
            self.boxes[self.selected_index]["class_id"] = class_id
            self.boxes[self.selected_index]["class_name"] = class_name
            self.boxes_changed.emit()
            self.update()

    def delete_selected_box(self, index: int = -1):
        """Delete the box at `index` (default: the selected box); undoable with Ctrl+Z."""
        if index < 0:
            index = self.selected_index
        if not (0 <= index < len(self.boxes)):
            return
        self._save_history()
        self.boxes.pop(index)
        if self.selected_index == index:
            self.selected_index = -1
        elif self.selected_index > index:
            # The selected box moved up one slot; keep it selected.
            self.selected_index -= 1
        if self.selected_index >= len(self.boxes):
            self.selected_index = -1
        self.selection_changed.emit(self.selected_index)
        self.boxes_changed.emit()
        self.update()

    def clear_boxes(self):
        """Remove every box (undoable when there was anything to remove)."""
        if self.boxes:
            self._save_history()
        self.boxes = []
        self.selected_index = -1
        self.selection_changed.emit(-1)
        self.boxes_changed.emit()
        self.update()

    def _fit_scale(self) -> float:
        """Scale at which the whole image fits the widget (current scale if not computable)."""
        if self.image is None or self.width() <= 0 or self.height() <= 0:
            return self.scale
        img_h, img_w = self.image.shape[:2]
        return min(self.width() / img_w, self.height() / img_h)

    def fit_to_window(self):
        """Scale the image to fit the widget and centre it."""
        if self.image is None or self.width() == 0 or self.height() == 0:
            return
        img_h, img_w = self.image.shape[:2]
        self.scale = self._fit_scale()
        self.offset_x = (self.width() - img_w * self.scale) / 2
        self.offset_y = (self.height() - img_h * self.scale) / 2
        self._need_fit = False
        self.zoom_changed.emit(int(self.scale * 100))
        self.update()

    def get_yolo_labels(self) -> list:
        """Return one "class cx cy w h" line per box, normalised to the image size."""
        if self.image is None:
            return []
        img_h, img_w = self.image.shape[:2]
        result = []
        for box in self.boxes:
            r = box["rect"]
            cx = (r.x() + r.width() / 2) / img_w
            cy = (r.y() + r.height() / 2) / img_h
            nw = r.width() / img_w
            nh = r.height() / img_h
            result.append(
                f"{box['class_id']} {cx:.6f} {cy:.6f} {nw:.6f} {nh:.6f}"
            )
        return result

    # ── Mouse events ──────────────────────────────────────────────── #
    def _start_box_drag(self, index: int, pos: QPoint):
        """Begin resizing box `index` if `pos` is on one of its handles, else moving it."""
        rect = self.boxes[index]["rect"]
        handle = self.get_handle_at(pos, self.image_to_screen_rect(rect))
        self._drag_start_screen = pos
        # Snapshot taken now, pushed to history on release only if something changed.
        self._drag_snapshot = self._copy_boxes()
        if handle >= 0:
            self.drag_mode = "resize"
            self.resize_handle = handle
            self._resize_orig_rect = QRect(rect)
        else:
            self.drag_mode = "move"
            self._move_orig_rect = QRect(rect)

    def mousePressEvent(self, e):
        if self.image is None:
            return
        pos = e.pos()
        self.setFocus()
        # Pan: Space + left click, or middle button
        if (e.button() == Qt.MiddleButton
                or (e.button() == Qt.LeftButton and self.space_pressed)):
            self.drag_mode = "pan"
            self._pan_start = pos
            self.setCursor(Qt.ClosedHandCursor)
            return
        if e.button() != Qt.LeftButton:
            return
        # The selected box's handles and body take priority over other boxes.
        if 0 <= self.selected_index < len(self.boxes):
            sr = self.image_to_screen_rect(self.boxes[self.selected_index]["rect"])
            if self.get_handle_at(pos, sr) >= 0 or sr.contains(pos):
                self._start_box_drag(self.selected_index, pos)
                return
        # Otherwise hit-test the other boxes (topmost first).
        hit = self.get_box_at(pos)
        if hit >= 0:
            if hit != self.selected_index:
                self.selected_index = hit
                self.selection_changed.emit(hit)
            self._start_box_drag(hit, pos)
            self.update()
            return
        # Empty space: clear the selection and start drawing a new box.
        if self.selected_index >= 0:
            self.selected_index = -1
            self.selection_changed.emit(-1)
        img_pt = self.screen_to_image(pos)
        if img_pt is not None:
            self._drag_start_img = img_pt
            self.drag_mode = "new_box"
            self.current_rect = None
        self.update()

    def mouseMoveEvent(self, e):
        pos = e.pos()
        if self.drag_mode == "pan":
            self.offset_x += pos.x() - self._pan_start.x()
            self.offset_y += pos.y() - self._pan_start.y()
            self._pan_start = pos
            self.update()
            return
        if self.drag_mode == "new_box":
            ip = self.screen_to_image(pos)
            if ip is not None and self._drag_start_img is not None:
                self.current_rect = QRect(self._drag_start_img, ip).normalized()
                self.update()
            return
        if self.drag_mode == "move":
            self._do_move(pos)
            self.update()
            return
        if self.drag_mode == "resize":
            self._do_resize(pos)
            self.update()
            return
        self._update_cursor(pos)

    def mouseReleaseEvent(self, e):
        pos = e.pos()
        if self.drag_mode == "pan" and e.button() in (Qt.LeftButton, Qt.MiddleButton):
            self.drag_mode = "none"
            self._update_cursor(pos)
            return
        if e.button() != Qt.LeftButton:
            return
        if self.drag_mode == "new_box":
            rect = self.current_rect
            # Tiny drags are treated as clicks, not boxes.
            if (rect is not None
                    and rect.width() >= self._MIN_BOX_SIZE
                    and rect.height() >= self._MIN_BOX_SIZE):
                self._save_history()
                box = {
                    "class_id": self.current_class_id,
                    "class_name": self.current_class_name,
                    "rect": rect,
                }
                self.boxes.append(box)
                self.selected_index = len(self.boxes) - 1
                self.selection_changed.emit(self.selected_index)
                self.box_added.emit(box)
            self.current_rect = None
            self._drag_start_img = None
        elif self.drag_mode in ("move", "resize"):
            self._commit_drag_history()
            self.boxes_changed.emit()
        self.drag_mode = "none"
        self._update_cursor(pos)
        self.update()

    def mouseDoubleClickEvent(self, e):
        if e.button() == Qt.LeftButton:
            self.fit_to_window()

    def wheelEvent(self, e):
        """Zoom by one step around the cursor position."""
        if self.image is None:
            return
        delta = e.angleDelta().y()
        if delta == 0:
            # Horizontal-only scroll (tilt wheel / touchpad): not a zoom gesture.
            return
        fit = self._fit_scale()
        # The limits always include the fit scale and the current scale, so a
        # wheel step never moves the zoom in the opposite direction (e.g. for
        # huge images whose fit scale is already below _MIN_ZOOM).
        if delta > 0:
            upper = max(self._MAX_ZOOM, fit, self.scale)
            new_scale = min(self.scale * self._ZOOM_STEP, upper)
        else:
            lower = min(self._MIN_ZOOM, fit, self.scale)
            new_scale = max(self.scale / self._ZOOM_STEP, lower)
        if new_scale == self.scale:
            return
        pos = e.pos()
        # Keep the image point under the cursor fixed while zooming.
        img_x = (pos.x() - self.offset_x) / self.scale
        img_y = (pos.y() - self.offset_y) / self.scale
        self.scale = new_scale
        self.offset_x = pos.x() - img_x * self.scale
        self.offset_y = pos.y() - img_y * self.scale
        self.zoom_changed.emit(int(self.scale * 100))
        self.update()

    # ── Keyboard events ───────────────────────────────────────────── #
    def keyPressEvent(self, e):
        key = e.key()
        if key == Qt.Key_Space and not e.isAutoRepeat():
            self.space_pressed = True
            self.setCursor(Qt.OpenHandCursor)
            return
        if key in (Qt.Key_Delete, Qt.Key_Backspace):
            if self.selected_index >= 0:
                self.delete_selected_box(self.selected_index)
            return
        if key == Qt.Key_Z and (e.modifiers() & Qt.ControlModifier):
            if self.history:
                self.boxes = self.history.pop()
                self.selected_index = min(self.selected_index, len(self.boxes) - 1)
                if self.selected_index < 0:
                    self.selected_index = -1
                self.selection_changed.emit(self.selected_index)
                self.boxes_changed.emit()
                self.update()
            return
        if key == Qt.Key_Escape:
            if self.drag_mode in ("move", "resize") and self._drag_snapshot is not None:
                # Cancel the drag: put the box back where it started.
                self.boxes = self._drag_snapshot
                self.boxes_changed.emit()
            self._drag_snapshot = None
            self.selected_index = -1
            self.drag_mode = "none"
            self.current_rect = None
            self._drag_start_img = None
            self.selection_changed.emit(-1)
            self._update_cursor(self.mapFromGlobal(QCursor.pos()))
            self.update()
            return
        super().keyPressEvent(e)

    def keyReleaseEvent(self, e):
        if e.key() == Qt.Key_Space and not e.isAutoRepeat():
            self.space_pressed = False
            if self.drag_mode == "pan":
                self.drag_mode = "none"
            self._update_cursor(self.mapFromGlobal(QCursor.pos()))
        super().keyReleaseEvent(e)

    # ── Painting ──────────────────────────────────────────────────── #
    def _source_pixmap(self) -> QPixmap:
        """QPixmap of self.image, rebuilt only when self.image is a different object.

        Note: in-place edits of the same array are not picked up; replace the
        image via set_image() instead.
        """
        if self._pixmap is None or self._pixmap_src is not self.image:
            img = self.image
            if img.ndim == 2 or img.shape[2] == 1:
                rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            elif img.shape[2] == 4:
                rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
            else:
                rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            rgb = np.ascontiguousarray(rgb)
            ih, iw = rgb.shape[:2]
            qimg = QImage(rgb.data, iw, ih, 3 * iw, QImage.Format_RGB888)
            # fromImage copies the pixels, so `rgb` may be released afterwards.
            self._pixmap = QPixmap.fromImage(qimg)
            self._pixmap_src = img
            self._scaled_pixmap = None
        return self._pixmap

    def _draw_image(self, painter: QPainter):
        """Draw the image at the current zoom and offset."""
        pix = self._source_pixmap()
        sw = max(1, int(pix.width() * self.scale))
        sh = max(1, int(pix.height() * self.scale))
        ox, oy = int(self.offset_x), int(self.offset_y)
        if self.scale <= 1.0:
            # Downscaling: keep Qt's smooth scaling, but only recompute it
            # when the zoom level changes (panning reuses the cached copy).
            cached = self._scaled_pixmap
            if cached is None or cached.width() != sw or cached.height() != sh:
                cached = pix.scaled(sw, sh, Qt.IgnoreAspectRatio, Qt.SmoothTransformation)
                self._scaled_pixmap = cached
            painter.drawPixmap(ox, oy, cached)
        else:
            # Upscaling: let the painter scale into the target rectangle, so
            # only the visible part is rendered instead of allocating a pixmap
            # up to _MAX_ZOOM² times the image area on every repaint.
            painter.save()
            painter.setRenderHint(QPainter.SmoothPixmapTransform)
            painter.drawPixmap(QRect(ox, oy, sw, sh), pix)
            painter.restore()

    def _draw_boxes(self, painter: QPainter):
        """Draw every box with its class tag; the selected box gets a tint and 8 handles."""
        painter.setFont(QFont("Arial", 10, QFont.Bold))
        for i, box in enumerate(self.boxes):
            color = QColor(self.CLASS_COLORS.get(box["class_name"], "#ffffff"))
            sr = self.image_to_screen_rect(box["rect"])
            if i == self.selected_index:
                fill = QColor(color)
                fill.setAlpha(40)
                painter.fillRect(sr, fill)
                painter.setPen(QPen(QColor("#ffffff"), 3))
                painter.drawRect(sr)
                # 8 resize handles
                painter.setPen(QPen(QColor("#555555"), 1))
                for hr in self._handle_rects(sr):
                    painter.fillRect(hr, QColor("#ffffff"))
                    painter.drawRect(hr)
                lbl_bg = QColor("#ffffff")
                lbl_txt = QColor("#222222")
            else:
                painter.setPen(QPen(color, 2))
                painter.drawRect(sr)
                lbl_bg = color
                lbl_txt = QColor("#ffffff")
            # Class-name tag above the box (inside its top edge if there is no room).
            text = box["class_name"]
            lbl_w = max(len(text) * 11, 60)
            lbl_y = sr.y() - 18 if sr.y() >= 18 else sr.y()
            lbl_r = QRect(sr.x(), lbl_y, lbl_w, 18)
            painter.fillRect(lbl_r, lbl_bg)
            painter.setPen(lbl_txt)
            painter.drawText(lbl_r.x() + 3, lbl_r.y() + 13, text)

    def paintEvent(self, _e):
        painter = QPainter(self)
        painter.fillRect(self.rect(), QColor("#111111"))
        if self.image is None:
            painter.setPen(QColor("#555555"))
            painter.setFont(QFont("Arial", 14))
            painter.drawText(self.rect(), Qt.AlignCenter, "이미지를 선택하세요")
            painter.end()
            return
        self._draw_image(painter)
        self._draw_boxes(painter)
        # Box currently being drawn (dashed white outline)
        if self.drag_mode == "new_box" and self.current_rect is not None:
            sr = self.image_to_screen_rect(self.current_rect)
            painter.setPen(QPen(QColor("#ffffff"), 2, Qt.DashLine))
            painter.setBrush(Qt.NoBrush)
            painter.drawRect(sr)
        painter.end()

    def resizeEvent(self, e):
        super().resizeEvent(e)
        # A fit requested before the widget had a size is performed now.
        if self._need_fit and self.width() > 0 and self.height() > 0:
            self.fit_to_window()
        else:
            self.update()

    # ── Move / resize helpers ─────────────────────────────────────── #
    def _do_move(self, screen_pos: QPoint):
        """Translate the selected box by the drag delta, keeping it inside the image."""
        if self.image is None or self.selected_index < 0:
            return
        img_h, img_w = self.image.shape[:2]
        dx = (screen_pos.x() - self._drag_start_screen.x()) / self.scale
        dy = (screen_pos.y() - self._drag_start_screen.y()) / self.scale
        orig = self._move_orig_rect
        nx = max(0, min(int(orig.x() + dx), img_w - orig.width()))
        ny = max(0, min(int(orig.y() + dy), img_h - orig.height()))
        self.boxes[self.selected_index]["rect"] = QRect(nx, ny, orig.width(), orig.height())

    def _do_resize(self, screen_pos: QPoint):
        """Move the edges owned by the dragged handle, enforcing the minimum box size."""
        if self.image is None or self.selected_index < 0:
            return
        img_h, img_w = self.image.shape[:2]
        dx = (screen_pos.x() - self._drag_start_screen.x()) / self.scale
        dy = (screen_pos.y() - self._drag_start_screen.y()) / self.scale
        orig = self._resize_orig_rect
        min_size = self._MIN_BOX_SIZE
        x1, y1 = orig.x(), orig.y()
        x2, y2 = orig.x() + orig.width(), orig.y() + orig.height()
        h = self.resize_handle
        if h in (0, 6, 7):
            x1 = int(orig.x() + dx)
        if h in (2, 3, 4):
            x2 = int(orig.x() + orig.width() + dx)
        if h in (0, 1, 2):
            y1 = int(orig.y() + dy)
        if h in (4, 5, 6):
            y2 = int(orig.y() + orig.height() + dy)
        # Never let the dragged edge cross the opposite one.
        if x2 - x1 < min_size:
            if h in (0, 6, 7):
                x1 = x2 - min_size
            else:
                x2 = x1 + min_size
        if y2 - y1 < min_size:
            if h in (0, 1, 2):
                y1 = y2 - min_size
            else:
                y2 = y1 + min_size
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(img_w, x2)
        y2 = min(img_h, y2)
        self.boxes[self.selected_index]["rect"] = QRect(x1, y1, x2 - x1, y2 - y1)
# ============================================================ #
# 재학습 페이지
# ============================================================ #
class RetrainPage(QWidget):
    """Retraining page: image-folder browsing, box labeling and training control.

    Left column: image folder and file list, defect-class picker, epoch/batch
    and model-path settings. Right column: the LabelingCanvas with its box
    list, training start/stop with progress and log, and the "모델 저장"
    action that copies the trained weights and updates config.json.
    """

    # Image extensions listed in the file list and used to pair label files.
    _IMG_EXTS = frozenset({".jpg", ".jpeg", ".png", ".bmp"})

    def __init__(self, parent=None):
        super().__init__(parent)
        self._img_dir = ""
        self._img_files = []
        self._cur_path = ""          # image whose boxes the canvas currently shows
        self._trainer = Trainer()
        self._worker = None
        self._user_stopped = False   # True after the user pressed the stop button
        self._build_ui()

    # ================================================================ #
    # Layout
    # ================================================================ #
    def _build_ui(self):
        root = QHBoxLayout(self)
        root.setContentsMargins(0, 0, 0, 0)
        root.setSpacing(0)
        root.addWidget(self._build_left(), stretch=2)
        root.addWidget(self._build_right(), stretch=3)

    def _build_left(self) -> QScrollArea:
        """Scrollable left column: image load, class selection, dataset settings."""
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setFrameShape(QScrollArea.NoFrame)
        scroll.setStyleSheet("background:#1a1a1a;")
        inner = QWidget()
        inner.setStyleSheet("background:#1a1a1a;")
        lay = QVBoxLayout(inner)
        lay.setContentsMargins(12, 12, 8, 12)
        lay.setSpacing(0)
        lay.addWidget(self._build_img_load_section())
        lay.addWidget(self._build_class_section())
        lay.addWidget(self._build_dataset_section())
        lay.addStretch()
        scroll.setWidget(inner)
        return scroll

    # ── Section 1: image loading ──────────────────────────────────── #
    def _build_img_load_section(self) -> QGroupBox:
        g = QGroupBox("이미지 로드")
        g.setStyleSheet(_GRP)
        lay = QVBoxLayout(g)
        lay.setSpacing(8)
        btn = QPushButton("이미지 폴더 선택")
        btn.setFixedHeight(56)
        btn.setStyleSheet(_btn_style("#1a3a5c"))
        btn.clicked.connect(self._on_select_folder)
        self._folder_lbl = QLabel("폴더를 선택하세요")
        self._folder_lbl.setStyleSheet("color:#777777; font-size:12px;")
        self._folder_lbl.setWordWrap(True)
        self._img_list = QListWidget()
        self._img_list.setMinimumHeight(160)
        self._img_list.setStyleSheet("""
            QListWidget {
                background:#1a1a1a; border:1px solid #333333;
                border-radius:4px; font-size:13px; color:#cccccc;
            }
            QListWidget::item { padding:4px 8px; border-bottom:1px solid #2a2a2a; }
            QListWidget::item:selected { background:#185FA5; color:#ffffff; }
        """)
        self._img_list.currentRowChanged.connect(self._on_img_selected)
        lay.addWidget(btn)
        lay.addWidget(self._folder_lbl)
        lay.addWidget(self._img_list, stretch=1)
        return g

    # ── Section 2: defect class selection ─────────────────────────── #
    def _build_class_section(self) -> QGroupBox:
        g = QGroupBox("불량 클래스 선택")
        g.setStyleSheet(_GRP)
        lay = QVBoxLayout(g)
        lay.setSpacing(6)
        self._class_btns: dict[str, QPushButton] = {}
        for cls_name, color in _CLASS_COLORS.items():
            btn = QPushButton(cls_name)
            btn.setFixedHeight(56)
            btn.setCheckable(True)
            btn.setStyleSheet(_cls_btn_style(color, checked=False))
            # Default argument binds the current class name (avoids late binding).
            btn.clicked.connect(lambda _, c=cls_name: self._on_class_select(c))
            self._class_btns[cls_name] = btn
            lay.addWidget(btn)
        first = _CLASS_NAMES[0]
        self._class_btns[first].setChecked(True)
        self._class_btns[first].setStyleSheet(_cls_btn_style(_CLASS_COLORS[first], checked=True))
        self._active_cls = first
        return g

    # ── Section 3: dataset settings ───────────────────────────────── #
    def _build_dataset_section(self) -> QGroupBox:
        g = QGroupBox("데이터셋 설정")
        g.setStyleSheet(_GRP)
        lay = QVBoxLayout(g)
        lay.setSpacing(8)
        form = QFormLayout()
        form.setHorizontalSpacing(12)
        form.setVerticalSpacing(8)
        self._epoch_spin = QSpinBox()
        self._epoch_spin.setRange(1, 300)
        self._epoch_spin.setValue(100)
        self._epoch_spin.setStyleSheet(_spinbox_style())
        self._batch_spin = QSpinBox()
        self._batch_spin.setRange(1, 64)
        self._batch_spin.setValue(16)
        self._batch_spin.setStyleSheet(_spinbox_style())
        form.addRow("Epoch", self._epoch_spin)
        form.addRow("Batch", self._batch_spin)
        lay.addLayout(form)
        path_lbl = QLabel("모델 저장 경로")
        path_lbl.setStyleSheet("color:#888888; font-size:13px;")
        path_row = QHBoxLayout()
        self._save_path_edit = QLineEdit("ai/models/best.pt")
        self._save_path_edit.setFixedHeight(44)
        self._save_path_edit.setStyleSheet(
            "background:#2a2a2a; color:#ffffff; border:1px solid #555555;"
            "border-radius:4px; padding:0 8px; font-size:13px;"
        )
        btn_browse = QPushButton("찾기")
        btn_browse.setFixedSize(64, 44)
        btn_browse.setStyleSheet(_btn_style("#333333", font_size=13))
        btn_browse.clicked.connect(self._on_browse_save)
        path_row.addWidget(self._save_path_edit, stretch=1)
        path_row.addWidget(btn_browse)
        lay.addWidget(path_lbl)
        lay.addLayout(path_row)
        return g

    # ── Right column ──────────────────────────────────────────────── #
    def _build_right(self) -> QWidget:
        w = QWidget()
        w.setStyleSheet("background:#1a1a1a;")
        lay = QVBoxLayout(w)
        lay.setContentsMargins(8, 12, 12, 12)
        lay.setSpacing(0)
        lay.addWidget(self._build_labeling_section(), stretch=55)
        lay.addWidget(self._build_train_section(), stretch=25)
        lay.addWidget(self._build_save_section(), stretch=20)
        return w

    def _build_labeling_section(self) -> QGroupBox:
        g = QGroupBox("이미지 표시 / 라벨링")
        g.setStyleSheet(_GRP)
        outer = QVBoxLayout(g)
        outer.setSpacing(6)
        # ── Toolbar (zoom readout + reset button) ──
        toolbar = QHBoxLayout()
        self._zoom_lbl = QLabel("100%")
        self._zoom_lbl.setStyleSheet(
            "color:#aaaaaa; font-size:12px; min-width:55px;"
        )
        hint = QLabel("더블클릭: fit | 휠: 줌 | Space+드래그: 패닝 | Del: 박스삭제 | Ctrl+Z: 실행취소")
        hint.setStyleSheet("color:#555555; font-size:11px;")
        btn_fit = QPushButton("초기화")
        btn_fit.setFixedHeight(28)
        btn_fit.setStyleSheet(_btn_style("#333333", font_size=12))
        # self._canvas is created below; the lambda looks it up at click time.
        btn_fit.clicked.connect(lambda: self._canvas.fit_to_window())
        toolbar.addWidget(self._zoom_lbl)
        toolbar.addWidget(hint)
        toolbar.addStretch()
        toolbar.addWidget(btn_fit)
        outer.addLayout(toolbar)
        # ── Canvas + box list ──
        main_row = QHBoxLayout()
        self._canvas = LabelingCanvas()
        self._canvas.box_added.connect(lambda _: self._refresh_box_list())
        self._canvas.boxes_changed.connect(self._refresh_box_list)
        self._canvas.selection_changed.connect(self._on_canvas_selection_changed)
        self._canvas.zoom_changed.connect(lambda pct: self._zoom_lbl.setText(f"{pct}%"))
        main_row.addWidget(self._canvas, stretch=3)
        side = QVBoxLayout()
        side.setSpacing(6)
        box_lbl = QLabel("박스 목록")
        box_lbl.setStyleSheet("color:#888888; font-size:13px;")
        self._box_list = QListWidget()
        self._box_list.setStyleSheet("""
            QListWidget {
                background:#1a1a1a; border:1px solid #333333;
                border-radius:4px; font-size:12px; color:#cccccc;
            }
            QListWidget::item { padding:3px 6px; }
            QListWidget::item:selected { background:#3C3489; }
        """)
        self._box_list.currentRowChanged.connect(self._on_box_list_select)
        btn_del = QPushButton("박스 삭제")
        btn_del.setFixedHeight(44)
        btn_del.setStyleSheet(_btn_style("#5c1a1a", font_size=13))
        btn_del.clicked.connect(self._on_del_box)
        side.addWidget(box_lbl)
        side.addWidget(self._box_list, stretch=1)
        side.addWidget(btn_del)
        main_row.addLayout(side, stretch=1)
        outer.addLayout(main_row, stretch=1)
        btn_save = QPushButton("라벨 저장 (YOLO .txt)")
        btn_save.setFixedHeight(50)
        btn_save.setStyleSheet(_btn_style("#2e5c2e", font_size=14))
        btn_save.clicked.connect(self._on_label_save)
        outer.addWidget(btn_save)
        return g

    def _build_train_section(self) -> QGroupBox:
        g = QGroupBox("학습 제어")
        g.setStyleSheet(_GRP)
        lay = QVBoxLayout(g)
        lay.setSpacing(6)
        btn_row = QHBoxLayout()
        self._start_btn = QPushButton("학습 시작")
        self._start_btn.setFixedHeight(70)
        self._start_btn.setStyleSheet(_btn_style("#1D9E75", font_size=17, bold=True))
        self._start_btn.clicked.connect(self._on_train_start)
        self._stop_btn = QPushButton("학습 중지")
        self._stop_btn.setFixedHeight(70)
        self._stop_btn.setEnabled(False)
        self._stop_btn.setStyleSheet(_btn_style("#A32D2D", font_size=17, bold=True))
        self._stop_btn.clicked.connect(self._on_train_stop)
        btn_row.addWidget(self._start_btn)
        btn_row.addWidget(self._stop_btn)
        self._progress_bar = QProgressBar()
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)
        self._progress_bar.setFixedHeight(22)
        self._progress_bar.setStyleSheet("""
            QProgressBar {
                background:#2a2a2a; border:1px solid #555555;
                border-radius:4px; text-align:center; color:#ffffff; font-size:13px;
            }
            QProgressBar::chunk { background:#1D9E75; border-radius:4px; }
        """)
        self._status_lbl = QLabel("대기 중")
        self._status_lbl.setStyleSheet("color:#888888; font-size:14px;")
        self._status_lbl.setAlignment(Qt.AlignCenter)
        self._log_box = QTextEdit()
        self._log_box.setReadOnly(True)
        self._log_box.setFixedHeight(80)
        self._log_box.setStyleSheet(
            "background:#111111; color:#aaaaaa; border:1px solid #333333;"
            "border-radius:4px; font-size:12px; font-family:Consolas,monospace;"
        )
        lay.addLayout(btn_row)
        lay.addWidget(self._progress_bar)
        lay.addWidget(self._status_lbl)
        lay.addWidget(self._log_box)
        return g

    def _build_save_section(self) -> QGroupBox:
        g = QGroupBox("모델 저장")
        g.setStyleSheet(_GRP)
        lay = QVBoxLayout(g)
        lay.setSpacing(8)
        self._save_btn = QPushButton("모델 저장")
        self._save_btn.setFixedHeight(56)
        self._save_btn.setEnabled(False)
        self._save_btn.setStyleSheet(_btn_style("#1a4d1a", font_size=15, bold=True))
        self._save_btn.clicked.connect(self._on_model_save)
        self._save_result_lbl = QLabel("")
        self._save_result_lbl.setAlignment(Qt.AlignCenter)
        self._save_result_lbl.setStyleSheet("color:#aaaaaa; font-size:13px;")
        lay.addWidget(self._save_btn)
        lay.addWidget(self._save_result_lbl)
        return g

    # ================================================================ #
    # Slots
    # ================================================================ #
    # NOTE: PyQt5 (>= 5.5) aborts the application on an unhandled exception
    # raised inside a slot, so file-system errors are caught and reported.
    def _on_select_folder(self):
        """Let the user pick an image folder and list the images it contains."""
        folder = QFileDialog.getExistingDirectory(self, "이미지 폴더 선택", "")
        if not folder:
            return
        try:
            entries = os.listdir(folder)
        except OSError as err:
            QMessageBox.warning(self, "경고", f"폴더를 읽을 수 없습니다:\n{err}")
            return
        self._img_dir = folder
        self._folder_lbl.setText(folder)
        self._img_files = sorted(
            f for f in entries
            if os.path.splitext(f)[1].lower() in self._IMG_EXTS
        )
        self._img_list.clear()
        for fname in self._img_files:
            item = QListWidgetItem(fname)
            # type(item.sizeHint()) is QSize; this avoids importing it for one call.
            item.setSizeHint(type(item.sizeHint())(0, 44))
            self._img_list.addItem(item)
        print(f"[재학습] 폴더: {folder} ({len(self._img_files)}개)")

    def _on_img_selected(self, row: int):
        """Load the image at `row` into the canvas together with its existing labels."""
        if row < 0 or row >= len(self._img_files):
            return
        path = os.path.join(self._img_dir, self._img_files[row])
        try:
            # np.fromfile + imdecode also handles non-ASCII paths on Windows.
            img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_COLOR)
        except (OSError, cv2.error):
            img = None
        if img is None:
            print(f"[재학습] 이미지 로드 실패: {path}")
            return
        # Switch the current path only after a successful load. On failure the
        # canvas still shows the previous image, and saving must target it.
        self._cur_path = path
        self._canvas.set_image(img)
        class_id = LabelingCanvas.CLASS_MAP.get(self._active_cls, 0)
        self._canvas.set_class(class_id, self._active_cls)
        txt_path = os.path.splitext(path)[0] + ".txt"
        if os.path.exists(txt_path):
            self._load_yolo_labels(txt_path, img.shape)
        self._refresh_box_list()

    def _load_yolo_labels(self, txt_path: str, img_shape):
        """Append the boxes of a YOLO label file to the canvas.

        Each valid line is "class cx cy w h" with coordinates normalised to
        the image. Malformed lines are skipped individually. Boxes are clipped
        to the image, and boxes that end up empty are dropped.
        """
        img_h, img_w = img_shape[:2]
        id_to_name = {v: k for k, v in LabelingCanvas.CLASS_MAP.items()}
        try:
            # utf-8-sig also accepts files saved with a BOM by other tools.
            with open(txt_path, "r", encoding="utf-8-sig") as f:
                lines = f.readlines()
        except (OSError, UnicodeDecodeError) as err:
            print(f"[재학습] 라벨 로드 실패: {err}")
            return
        for line in lines:
            parts = line.split()
            if len(parts) != 5:
                continue
            try:
                cid = int(parts[0])
                cx, cy, nw, nh = map(float, parts[1:])
            except ValueError as err:
                print(f"[재학습] 라벨 로드 실패: {err}")
                continue
            # round() rather than int(): truncation made every save/load cycle
            # shift and shrink boxes by up to one pixel.
            x1 = max(0, min(round((cx - nw / 2) * img_w), img_w))
            y1 = max(0, min(round((cy - nh / 2) * img_h), img_h))
            x2 = max(0, min(round((cx + nw / 2) * img_w), img_w))
            y2 = max(0, min(round((cy + nh / 2) * img_h), img_h))
            if x2 <= x1 or y2 <= y1:
                continue
            self._canvas.boxes.append({
                "class_id": cid,
                "class_name": id_to_name.get(cid, "스크래치"),
                "rect": QRect(x1, y1, x2 - x1, y2 - y1),
            })
        self._canvas.update()

    def _on_class_select(self, cls_name: str):
        """Make `cls_name` the active class and apply it to the selected box, if any."""
        for name, btn in self._class_btns.items():
            is_active = (name == cls_name)
            btn.setChecked(is_active)
            btn.setStyleSheet(_cls_btn_style(_CLASS_COLORS[name], checked=is_active))
        self._active_cls = cls_name
        class_id = LabelingCanvas.CLASS_MAP.get(cls_name, 0)
        self._canvas.set_class(class_id, cls_name)
        if self._canvas.selected_index >= 0:
            self._canvas.change_selected_class(class_id, cls_name)

    # ── Canvas <-> box list synchronisation ───────────────────────── #
    def _on_canvas_selection_changed(self, index: int):
        """Canvas selection changed: highlight the list row without re-emitting."""
        self._box_list.blockSignals(True)
        self._box_list.setCurrentRow(index)
        self._box_list.blockSignals(False)

    def _on_box_list_select(self, row: int):
        """List row clicked: select the matching box on the canvas."""
        if row != self._canvas.selected_index:
            self._canvas.selected_index = row
            self._canvas.update()

    def _on_del_box(self):
        """Delete the selected box (falls back to the list's current row)."""
        idx = self._canvas.selected_index
        if idx < 0:
            idx = self._box_list.currentRow()
        if idx >= 0:
            self._canvas.delete_selected_box(idx)
            # The list refreshes through the canvas' boxes_changed signal.

    def _on_label_save(self):
        """Write the canvas boxes next to the current image as a YOLO .txt file."""
        if not self._cur_path:
            QMessageBox.warning(self, "경고", "이미지를 먼저 선택하세요.")
            return
        labels = self._canvas.get_yolo_labels()
        txt_path = os.path.splitext(self._cur_path)[0] + ".txt"
        if not labels:
            # Nothing to write unless a label file already exists. In that case
            # offer to empty it, so that deleting every box can be persisted.
            if not os.path.exists(txt_path):
                QMessageBox.warning(self, "경고", "박스를 먼저 그려주세요.")
                return
            answer = QMessageBox.question(
                self, "확인",
                f"박스가 없습니다. 기존 라벨 파일을 비우시겠습니까?\n{txt_path}",
            )
            if answer != QMessageBox.Yes:
                return
        try:
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write("\n".join(labels))
            QMessageBox.information(self, "저장 완료", f"라벨 저장 완료:\n{txt_path}")
            self._log_append(f"라벨 저장: {os.path.basename(txt_path)}")
        except Exception as err:
            QMessageBox.critical(self, "저장 실패", str(err))

    def _on_browse_save(self):
        path, _ = QFileDialog.getSaveFileName(
            self, "모델 저장 경로", "ai/models/best.pt", "PyTorch 모델 (*.pt)"
        )
        if path:
            self._save_path_edit.setText(path)

    def _on_train_start(self):
        """Validate the dataset folder, then launch a TrainWorker thread."""
        if not self._img_dir:
            QMessageBox.warning(self, "경고", "이미지 폴더를 먼저 선택하세요.")
            return
        try:
            entries = os.listdir(self._img_dir)
        except OSError as err:
            QMessageBox.warning(self, "경고", f"폴더를 읽을 수 없습니다:\n{err}")
            return
        # Count only .txt files that belong to an image (stray files such as
        # classes.txt are not labels).
        img_stems = {
            os.path.splitext(f)[0] for f in entries
            if os.path.splitext(f)[1].lower() in self._IMG_EXTS
        }
        label_files = [
            f for f in entries
            if f.lower().endswith(".txt") and os.path.splitext(f)[0] in img_stems
        ]
        if not label_files:
            QMessageBox.warning(
                self, "경고", "라벨 파일이 없습니다. 먼저 라벨링해주세요."
            )
            return
        self._user_stopped = False
        self._progress_bar.setValue(0)
        self._progress_bar.setFormat("0%")
        self._log_box.clear()
        self._status_lbl.setText("학습 중...")
        self._status_lbl.setStyleSheet("color:#1D9E75; font-size:14px; font-weight:bold;")
        self._start_btn.setEnabled(False)
        self._stop_btn.setEnabled(True)
        self._save_btn.setEnabled(False)
        save_path = self._save_path_edit.text().strip() or "ai/models/best.pt"
        epochs = self._epoch_spin.value()
        batch = self._batch_spin.value()
        log_action(
            f"[재학습] 학습 시작 | epochs={epochs} | batch={batch} | 저장={save_path}"
        )
        log_train(
            f"학습 시작 | 데이터셋={self._img_dir} | epochs={epochs} | "
            f"batch={batch} | 저장={save_path} | 라벨파일={len(label_files)}"
        )
        self._worker = TrainWorker(
            self._trainer,
            self._img_dir,
            epochs,
            batch,
            save_path,
        )
        self._worker.log_signal.connect(self._on_log)
        self._worker.progress_signal.connect(self._on_progress)
        self._worker.finished_signal.connect(self._on_finished)
        self._worker.start()

    def _on_train_stop(self):
        """Stop a running training job at the user's request."""
        if self._worker and self._worker.isRunning():
            self._user_stopped = True
            self._worker.stop_subprocess()
            self._worker.terminate()
            # terminate() is asynchronous; Qt recommends wait() afterwards.
            self._worker.wait(3000)
        self._trainer.stop()
        self._log_append("학습 중지됨")
        log_action("[재학습] 학습 중지 (사용자)")
        log_train("학습 중지됨 (사용자 중지)")
        self._status_lbl.setText("중지됨")
        self._status_lbl.setStyleSheet("color:#A32D2D; font-size:14px; font-weight:bold;")
        self._start_btn.setEnabled(True)
        self._stop_btn.setEnabled(False)

    def _on_log(self, message: str):
        """Append a worker log line to the log box."""
        self._log_append(message)

    def _on_progress(self, value: int):
        self._progress_bar.setValue(value)
        self._progress_bar.setFormat(f"{value}%")

    def _on_finished(self, success: bool):
        """Handle the worker's completion signal."""
        self._start_btn.setEnabled(True)
        self._stop_btn.setEnabled(False)
        if success:
            self._save_btn.setEnabled(True)
            self._status_lbl.setText("학습 완료")
            self._status_lbl.setStyleSheet("color:#22cc55; font-size:14px; font-weight:bold;")
            log_train("학습 완료")
            QMessageBox.information(self, "완료", "학습 완료!")
        elif self._user_stopped:
            # A failure reported after a user stop is expected; the status
            # already reads "중지됨", so don't show a failure dialog.
            return
        else:
            self._status_lbl.setText("학습 실패")
            self._status_lbl.setStyleSheet("color:#A32D2D; font-size:14px; font-weight:bold;")
            log_train("학습 실패")
            QMessageBox.critical(self, "실패", "학습 실패. 로그를 확인해주세요.")

    def _on_model_save(self):
        """Copy the trained best.pt to the chosen path and record it in config.json."""
        dest_input = self._save_path_edit.text().strip() or "ai/models/best.pt"
        dest = resolve_path(dest_input)
        src = resolve_path("ai/runs/train/weights/best.pt")
        if not os.path.exists(src):
            QMessageBox.warning(self, "경고", f"학습 결과 파일을 찾을 수 없습니다:\n{src}")
            return
        try:
            dest_dir = os.path.dirname(dest)
            if dest_dir:  # makedirs("") raises for a bare file name
                os.makedirs(dest_dir, exist_ok=True)
            # shutil.copy raises SameFileError when dest is src itself.
            if not (os.path.exists(dest) and os.path.samefile(src, dest)):
                shutil.copy(src, dest)
            saved_path = to_project_relative(dest)
            self._update_config_model_path(saved_path)
            self._save_result_lbl.setText(f"저장 완료: {saved_path}")
            self._save_result_lbl.setStyleSheet("color:#22cc55; font-size:13px;")
            log_action(f"[재학습] 모델 저장 완료 → {saved_path}")
            log_train(f"모델 저장 완료 | {saved_path}")
            QMessageBox.information(self, "완료", f"모델 저장 완료\n{dest}")
        except Exception as e:
            log_train(f"모델 저장 실패 | {e}")
            QMessageBox.critical(self, "저장 실패", str(e))

    # ================================================================ #
    # Helpers
    # ================================================================ #
    @staticmethod
    def _update_config_model_path(saved_path: str):
        """Set ai.model_path in config.json to `saved_path` (no-op without config.json).

        The new content goes to a temporary file that is then swapped in with
        os.replace, so an interrupted write cannot leave a truncated config.
        """
        config_path = resolve_path("config.json")
        if not os.path.exists(config_path):
            return
        with open(config_path, "r", encoding="utf-8") as fh:
            cfg = json.load(fh)
        cfg.setdefault("ai", {})["model_path"] = saved_path
        tmp_path = config_path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as fh:
                json.dump(cfg, fh, ensure_ascii=False, indent=2)
            os.replace(tmp_path, config_path)
        finally:
            # Only left behind if writing or replacing failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _refresh_box_list(self):
        """Rebuild the box list from the canvas without triggering selection slots."""
        self._box_list.blockSignals(True)
        self._box_list.clear()
        for box in self._canvas.boxes:
            r = box["rect"]
            color = LabelingCanvas.CLASS_COLORS.get(box["class_name"], "#ffffff")
            item = QListWidgetItem(
                f"[{box['class_name']}] x:{r.x()} y:{r.y()} w:{r.width()} h:{r.height()}"
            )
            item.setForeground(QColor(color))
            self._box_list.addItem(item)
        self._box_list.setCurrentRow(self._canvas.selected_index)
        self._box_list.blockSignals(False)

    def _log_append(self, text: str):
        """Append a timestamped line to the log box and scroll to the bottom."""
        ts = datetime.now().strftime("%H:%M:%S")
        self._log_box.append(f"[{ts}] {text}")
        sb = self._log_box.verticalScrollBar()
        sb.setValue(sb.maximum())
# ============================================================ #
# 스타일 헬퍼
# ============================================================ #
def _btn_style(bg: str, font_size: int = 14, bold: bool = False) -> str:
    """Build the stylesheet for a flat dark-theme QPushButton.

    Args:
        bg: base background colour as "#rrggbb"; the hover and pressed states
            use lighter/darker variants of it.
        font_size: label font size in px.
        bold: render the label in bold.
    """
    hover_bg = _lighten(bg)
    pressed_bg = _darken(bg)
    weight = "normal"
    if bold:
        weight = "bold"
    rules = [
        f"QPushButton {{"
        f" background:{bg}; color:#ffffff; border:none; border-radius:4px;"
        f" font-size:{font_size}px; font-weight:{weight}; min-height:28px;"
        f"}}",
        f"QPushButton:hover {{ background:{hover_bg}; }}",
        f"QPushButton:pressed {{ background:{pressed_bg}; }}",
        f"QPushButton:disabled {{ background:#3a3a3a; color:#666666; }}",
    ]
    return "".join(rules)
def _cls_btn_style(color: str, checked: bool) -> str:
    """Build the stylesheet for a checkable defect-class button.

    A checked button is filled and outlined with the class colour; an
    unchecked one is neutral grey with a thin grey border.
    """
    if checked:
        bg, border = color, f"border:2px solid {color};"
    else:
        bg, border = "#333333", "border:1px solid #555555;"
    base_rule = (
        f"QPushButton {{"
        f" background:{bg}; color:#ffffff; {border} border-radius:4px;"
        f" font-size:15px; font-weight:bold; min-height:56px;"
        f"}}"
    )
    hover_rule = f"QPushButton:hover {{ background:{_lighten(bg)}; }}"
    return base_rule + hover_rule
def _spinbox_style() -> str:
    """Return the shared dark-theme stylesheet for QSpinBox widgets."""
    declarations = (
        " background:#2a2a2a; color:#ffffff; border:1px solid #555555;",
        " border-radius:4px; padding:4px 8px; font-size:14px; min-height:38px;",
    )
    return "QSpinBox {" + "".join(declarations) + "}"
def _shift_color(hex_color: str, delta: int) -> str:
    """Shift every RGB channel of `hex_color` by `delta`, clamping to 0..255.

    Accepts "#rrggbb" (any digits after the first six, e.g. an alpha pair,
    are ignored) and the "#rgb" shorthand. Input that cannot be parsed is
    returned unchanged.
    """
    try:
        digits = hex_color[1:]
        if len(digits) == 3 and hex_color.startswith("#"):
            digits = "".join(ch * 2 for ch in digits)  # "#abc" -> "#aabbcc"
        channels = [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
    except (TypeError, ValueError):
        return hex_color
    return "#" + "".join(f"{max(0, min(c + delta, 255)):02x}" for c in channels)


def _lighten(hex_color: str, amount: int = 30) -> str:
    """Return `hex_color` with each channel raised by `amount` (default 30)."""
    return _shift_color(hex_color, amount)


def _darken(hex_color: str, amount: int = 30) -> str:
    """Return `hex_color` with each channel lowered by `amount` (default 30)."""
    return _shift_color(hex_color, -amount)