# AI training — YOLOv8 retraining and model saving
import multiprocessing
import os
import random
import shutil
from utils.path_helper import get_path

import yaml
from PyQt5.QtCore import QThread, pyqtSignal


class Trainer:
    """Build a YOLO dataset from a labelled image folder and retrain YOLOv8."""

    # Image file extensions (lower-case) considered when pairing with labels.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")
    # Fraction of the labelled pairs that goes to the training split.
    _TRAIN_RATIO = 0.8

    def __init__(self):
        self.model = None
        self.is_training = False

    # ------------------------------------------------------------------ #
    def prepare_dataset(self, image_folder: str) -> str:
        """Create a train/val dataset directory and return its data.yaml path.

        Every image in ``image_folder`` that has a same-named ``.txt`` label
        file next to it is copied into ``get_path("ai", "dataset")``, split
        roughly 80/20 between ``train`` and ``val`` after a random shuffle.
        Any existing dataset directory is deleted first.

        Args:
            image_folder: Directory holding the images and their ``.txt``
                label files.

        Returns:
            Path of the generated ``data.yaml``.

        Raises:
            FileNotFoundError: If no image with a matching label file exists.
                This is checked before the previous dataset is removed.
        """
        dataset_dir = get_path("ai", "dataset")

        # Collect the (image, label) pairs first so that an unusable folder
        # does not wipe the previous dataset.
        pairs = []
        for entry in sorted(os.listdir(image_folder)):
            if not entry.lower().endswith(self._IMAGE_EXTENSIONS):
                continue
            stem = os.path.splitext(entry)[0]
            txt_path = os.path.join(image_folder, stem + ".txt")
            if os.path.exists(txt_path):
                pairs.append((os.path.join(image_folder, entry), txt_path))

        if not pairs:
            raise FileNotFoundError(
                f"라벨(.txt)이 있는 이미지를 찾을 수 없습니다: {image_folder}"
            )

        if os.path.exists(dataset_dir):
            shutil.rmtree(dataset_dir)
        for split in ("train", "val"):
            os.makedirs(os.path.join(dataset_dir, "images", split), exist_ok=True)
            os.makedirs(os.path.join(dataset_dir, "labels", split), exist_ok=True)

        random.shuffle(pairs)
        # max(1, ...) keeps the train split non-empty when only one pair
        # exists (int(1 * 0.8) == 0); for two or more pairs the split index
        # is the plain 80 % cut.
        split_idx = max(1, int(len(pairs) * self._TRAIN_RATIO))
        train_pairs = pairs[:split_idx]
        # With a single pair the validation slice would be empty, so the
        # same pair is reused for validation.
        val_pairs = pairs[split_idx:] or train_pairs

        for split, subset in (("train", train_pairs), ("val", val_pairs)):
            images_dir = os.path.join(dataset_dir, "images", split)
            labels_dir = os.path.join(dataset_dir, "labels", split)
            for img, lbl in subset:
                shutil.copy(img, images_dir)
                shutil.copy(lbl, labels_dir)

        yaml_content = {
            "path": dataset_dir,
            "train": "images/train",
            "val": "images/val",
            "nc": 4,
            "names": ["스크래치", "이물", "흑점", "변형"],
        }
        yaml_path = os.path.join(dataset_dir, "data.yaml")
        with open(yaml_path, "w", encoding="utf-8") as fh:
            yaml.dump(yaml_content, fh, allow_unicode=True)
        return yaml_path

    # ------------------------------------------------------------------ #
    def train(
        self,
        image_folder: str,
        epochs: int,
        batch: int,
        save_path: str,
        log_callback=None,
        progress_callback=None,
    ) -> bool:
        """Prepare the dataset, train ``yolov8n.pt`` and copy out ``best.pt``.

        Failures never propagate: anything raised while preparing or
        training (``BaseException``, so ``SystemExit`` is included) is
        reported through ``log_callback`` and turned into a ``False`` return.

        Args:
            image_folder: Folder with images and ``.txt`` labels.
            epochs: Number of training epochs.
            batch: Batch size passed to the model's ``train`` call.
            save_path: Destination file for the resulting ``best.pt``.
            log_callback: Optional callable receiving log strings.
            progress_callback: Optional callable receiving an int percentage.

        Returns:
            ``True`` when training finished and ``best.pt`` was copied to
            ``save_path``; ``False`` when an error occurred or no ``best.pt``
            was found.
        """

        def _log(message):
            if log_callback:
                log_callback(message)

        self.is_training = True
        try:
            _log("데이터셋 준비 중...")
            yaml_path = self.prepare_dataset(image_folder)
            _log(f"데이터셋 준비 완료: {yaml_path}")

            _log("YOLOv8 모델 로드 중...")
            # Lazy import — avoids torch DLL errors at application start-up.
            from ultralytics import YOLO

            self.model = YOLO("yolov8n.pt")
            _log(f"학습 시작 (epoch={epochs}, batch={batch})")

            def _on_epoch_end(trainer):
                ep = trainer.epoch + 1
                try:
                    loss_str = f"{float(trainer.loss):.4f}"
                except Exception:
                    # The loss may not be convertible to a single float.
                    loss_str = "?"
                _log(f"Epoch {ep}/{epochs} loss={loss_str}")
                if progress_callback:
                    progress_callback(int(ep / epochs * 100))

            self.model.add_callback("on_train_epoch_end", _on_epoch_end)
            self.model.train(
                data=yaml_path,
                epochs=epochs,
                batch=batch,
                imgsz=640,
                project=get_path("ai", "runs"),
                name="train",
                exist_ok=True,
                verbose=True,
                workers=0,  # disable DataLoader multiprocessing inside subprocess
                amp=False,  # AMP check also spawns a subprocess on Windows
                plots=False,  # matplotlib can interfere with Qt event loop
            )

            # Copy best.pt to the requested location.
            best_pt = get_path("ai", "runs", "train", "weights", "best.pt")
            if not os.path.exists(best_pt):
                # Without weights there is nothing to hand back to the caller,
                # so this is reported as a failure rather than as completion.
                _log(f"학습 오류: best.pt를 찾을 수 없습니다: {best_pt}")
                return False

            save_dir = os.path.dirname(save_path)
            if save_dir:
                # os.makedirs("") raises, so a bare file name skips this step.
                os.makedirs(save_dir, exist_ok=True)
            shutil.copy(best_pt, save_path)
            _log(f"모델 저장 완료: {save_path}")

            if progress_callback:
                progress_callback(100)
            _log("학습 완료!")
            return True
        except BaseException as e:
            # Deliberately broad: SystemExit raised during training must not
            # escape this method.
            import traceback

            if log_callback:
                try:
                    log_callback(f"학습 오류: {e}")
                    log_callback(traceback.format_exc())
                except Exception:
                    pass
            return False
        finally:
            self.is_training = False

    # ------------------------------------------------------------------ #
    def stop(self):
        """Clear the ``is_training`` flag.

        This only resets the flag; nothing in ``train`` reads it, so a
        running training call is not interrupted by this method.
        """
        self.is_training = False


# ====================================================================== #
# Subprocess entry point — defined at module level so it is picklable.
# ultralytics training can call sys.exit() internally; running it in a
# separate process completely isolates the Qt application from that.
# ====================================================================== #
def _train_subprocess_main(queue, image_folder, epochs, batch, save_path):
    """Entry point for the isolated training subprocess.

    Runs ``Trainer.train`` and reports back through ``queue`` using
    ``(kind, payload)`` tuples: ``("log", str)``, ``("progress", int)`` and a
    final ``("done", bool)``.

    Args:
        queue: Queue shared with the parent process.
        image_folder: Folder with images and label files.
        epochs: Number of training epochs.
        batch: Training batch size.
        save_path: Destination file for the trained weights.
    """

    def _send(kind, payload):
        # Best effort: a broken queue must not crash the training process.
        try:
            queue.put((kind, payload))
        except Exception:
            pass

    def _log(msg):
        _send("log", msg)

    def _progress(pct):
        try:
            value = int(pct)
        except (TypeError, ValueError):
            return
        _send("progress", value)

    success = False
    try:
        trainer = Trainer()
        outcome = trainer.train(
            image_folder=image_folder,
            epochs=epochs,
            batch=batch,
            save_path=save_path,
            log_callback=_log,
            progress_callback=_progress,
        )
        # train() handles its own exceptions, so reaching this line does not
        # prove success. An explicit False return is treated as failure; any
        # other value (including None) counts as success.
        success = outcome is not False
    except BaseException as e:
        # Deliberately broad so that SystemExit is reported as well.
        import traceback

        _send("log", f"학습 오류: {e}")
        _send("log", traceback.format_exc())
    _send("done", success)


# ====================================================================== #
class TrainWorker(QThread):
    """QThread that runs training in a child process and relays its messages.

    Signals:
        log_signal(str): One log line from the training process.
        progress_signal(int): Progress percentage.
        finished_signal(bool): Emitted once when ``run`` ends; carries the
            value of the child's ``"done"`` message, or ``False`` if none
            was received.
    """

    log_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    finished_signal = pyqtSignal(bool)

    def __init__(self, trainer, image_folder, epochs, batch, save_path):
        super().__init__()
        self.trainer = trainer
        self.image_folder = image_folder
        self.epochs = epochs
        self.batch = batch
        self.save_path = save_path
        self._proc = None  # training subprocess handle

    def _dispatch(self, msg_type, msg_data):
        """Forward one queue message to the matching signal.

        Returns:
            The success flag when the message is ``"done"``, otherwise
            ``None``. Unknown message types are ignored.
        """
        if msg_type == "log":
            self.log_signal.emit(str(msg_data))
        elif msg_type == "progress":
            self.progress_signal.emit(int(msg_data))
        elif msg_type == "done":
            return bool(msg_data)
        return None

    def _drain(self, q, success):
        """Read whatever is left in ``q`` without blocking.

        Returns:
            ``success``, replaced by the payload of any ``"done"`` message
            found while draining.
        """
        while True:
            try:
                msg_type, msg_data = q.get_nowait()
            except Exception:
                return success
            outcome = self._dispatch(msg_type, msg_data)
            if outcome is not None:
                success = outcome

    def run(self):
        """Start the training subprocess and pump its messages into signals."""
        # Spawn an isolated subprocess so that any sys.exit() call inside
        # ultralytics does not reach PyQt5's QThread handler and trigger
        # QApplication.exit() in the main process.
        ctx = multiprocessing.get_context("spawn")
        q = ctx.Queue()
        self._proc = ctx.Process(
            target=_train_subprocess_main,
            args=(q, self.image_folder, self.epochs, self.batch, self.save_path),
            daemon=True,
        )

        success = False
        try:
            self._proc.start()
            while True:
                # Sampled before the blocking read: a child that exits during
                # the wait gets one more polling round before the final drain.
                proc_alive = self._proc.is_alive()
                try:
                    msg_type, msg_data = q.get(timeout=0.3)
                except Exception:
                    # Normally queue.Empty. Kept broad so that a queue error
                    # cannot raise out of the thread.
                    if proc_alive:
                        continue
                    # The subprocess is gone (crash or terminate()); give the
                    # queue one last chance to deliver remaining messages.
                    success = self._drain(q, success)
                    break

                outcome = self._dispatch(msg_type, msg_data)
                if outcome is not None:
                    success = outcome
                    break

            self._proc.join(timeout=30)
            if self._proc.is_alive():
                self._proc.terminate()
                self._proc.join(timeout=5)
        finally:
            # Emit even if something above raised, so listeners are notified.
            self.finished_signal.emit(success)

    def stop_subprocess(self):
        """Call from the main thread to forcefully stop training."""
        proc = self._proc
        if proc and proc.is_alive():
            proc.terminate()