Files
ant-vision-inspector/logic/pattern_matcher.py
2026-06-10 16:18:41 +09:00

223 lines
9.0 KiB
Python

# Python PatMax 대체 구현
# ORB 특징점 매칭 (위치·회전 불변) + 엣지 NCC fallback (특징점 부족 시)
import os
import pickle
import cv2
import numpy as np
from typing import Optional
_PATTERNS_PATH = os.path.join("assets", "patterns.pkl")
SCORE_THRESHOLD = 60.0
_MAX_IMG_SIZE = 1200  # max long-side pixels before ORB processing (limits time/memory)
_GOOD_MATCH_REF = 20  # this many good matches or more -> score of 100


class PatternMatcher:
    """Register product patterns and score incoming images against them.

    Each pattern is stored either as ORB descriptors (position/rotation
    invariant matching) or, when too few keypoints are found, as a Canny
    edge template that is scored with rotated normalized cross-correlation.
    """

    def __init__(self, threshold: float = SCORE_THRESHOLD):
        # {product_id: {"method": "orb"|"ncc", "des": ndarray|None, ...}}
        self._patterns: dict = {}
        self._threshold = threshold

    # -- training ----------------------------------------------------------- #
    def train(self, image: np.ndarray, product_id: int, product_info: dict,
              roi=None):
        """Register ``image`` (optionally cropped to ``roi``) under ``product_id``.

        Args:
            image: BGR or grayscale image.
            product_id: key the pattern is stored under; an existing pattern
                with the same id is replaced.
            product_info: metadata stored with the pattern; its ``name``,
                ``model`` and ``type`` entries are used for the log line.
            roi: ``(x, y, w, h)`` in pixels, or None for the whole image.
                The region is intersected with the image bounds.

        ORB keypoints are detected automatically. With fewer than 10
        keypoints the pattern falls back to edge-based NCC.
        """
        gray = _to_gray(image)
        if roi is not None:
            x, y, rw, rh = roi
            # Intersect [x, x+rw) x [y, y+rh) with the image. The far edge is
            # computed from the *unclamped* origin so a negative x/y shrinks the
            # window instead of shifting it, and it is kept >= the near edge so
            # Python's negative indexing can never wrap the slice around.
            x0, y0 = max(0, x), max(0, y)
            x1 = max(x0, min(x + rw, gray.shape[1]))
            y1 = max(y0, min(y + rh, gray.shape[0]))
            gray = gray[y0:y1, x0:x1].copy()
        gray = _resize_if_large(gray)

        kp, des = cv2.ORB_create(nfeatures=1000).detectAndCompute(gray, None)
        if des is not None and len(kp) >= 10:
            method = "orb"
            n_kp = len(kp)
            edges = None
        else:
            method = "ncc"
            n_kp = 0
            des = None
            edges = _to_edges(gray)

        name = (f"{product_info.get('name')} "
                f"{product_info.get('model')} {product_info.get('type')}")
        suffix = f" ORB 특징점 {n_kp}" if method == "orb" else " 엣지 NCC (특징점 부족)"
        print(f"[PatternMatcher] 등록: id={product_id} {name}{suffix}")
        self._patterns[product_id] = {
            "method": method,
            "gray": gray,    # downscaled grayscale (visualization / NCC fallback)
            "des": des,      # ORB descriptors (np.ndarray) or None
            "n_kp": n_kp,
            "edges": edges,  # Canny edges or None
            "info": product_info,
            "roi": roi,
        }

    # -- matching ----------------------------------------------------------- #
    def match_all(self, image: np.ndarray) -> dict:
        """Score ``image`` against every registered pattern.

        Returns:
            ``{product_id: score}`` with each score in 0..100.
        """
        gray = _resize_if_large(_to_gray(image))
        ncc_edges = None  # edges of the search image, computed only if an NCC pattern needs them
        scores = {}
        for pid, data in self._patterns.items():
            if data.get("method") == "orb" and data.get("des") is not None:
                scores[pid] = _orb_score(gray, data["des"])
                continue
            if ncc_edges is None:
                ncc_edges = _to_edges(gray)
            # Do not use `data.get("edges") or ...`: truth-testing a
            # multi-element ndarray raises ValueError.
            tmpl = data.get("edges")
            if tmpl is None:
                tmpl = _to_edges(data["gray"])
            scores[pid] = _best_rotated_score(ncc_edges, tmpl)
        return scores

    # -- save / load -------------------------------------------------------- #
    def save(self, path: str = _PATTERNS_PATH):
        """Pickle all patterns and the threshold to ``path``.

        The data is written to ``path + ".tmp"`` and then renamed over
        ``path``, so an interrupted save leaves the previous file intact.
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump({"patterns": self._patterns, "threshold": self._threshold}, f)
        os.replace(tmp_path, path)
        print(f"[PatternMatcher] 저장: {len(self._patterns)}개 → {path}")

    def load(self, path: str = _PATTERNS_PATH) -> bool:
        """Load patterns previously written by :meth:`save`.

        Returns:
            True on success. False if ``path`` does not exist or cannot be
            loaded; in that case the currently registered patterns and
            threshold are left unchanged.
        """
        if not os.path.exists(path):
            return False
        try:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file - only load pattern files from a trusted location.
            with open(path, "rb") as f:
                data = pickle.load(f)
            patterns = data.get("patterns", {})
            threshold = data.get("threshold", SCORE_THRESHOLD)
            # Backward compatibility: entries without a "method" key are
            # treated as NCC. NCC entries without edges get them rebuilt.
            for pat in patterns.values():
                if "method" not in pat:
                    pat["method"] = "ncc"
                if pat["method"] == "ncc" and pat.get("edges") is None:
                    pat["edges"] = _to_edges(pat["gray"])
        except Exception as e:
            # Best-effort load at an I/O boundary: report and keep current state.
            print(f"[PatternMatcher] 로드 실패: {e}")
            return False
        self._patterns = patterns
        self._threshold = threshold
        print(f"[PatternMatcher] 로드: {len(self._patterns)}개 ← {path}")
        return True

    # -- queries ------------------------------------------------------------ #
    def has_pattern(self, product_id: int) -> bool:
        """Return True if a pattern is registered for ``product_id``."""
        return product_id in self._patterns

    def remove_pattern(self, product_id: int):
        """Drop the pattern for ``product_id``; a missing id is ignored."""
        self._patterns.pop(product_id, None)

    def get_product_info(self, product_id: int) -> Optional[dict]:
        """Return the ``product_info`` stored with the pattern, or None."""
        data = self._patterns.get(product_id)
        return data["info"] if data else None

    def get_pattern_summary(self, product_id: int) -> str:
        """Describe how the pattern was registered (method and keypoint count)."""
        data = self._patterns.get(product_id)
        if not data:
            return ""
        if data.get("method") == "orb":
            return f"ORB 특징점 {data.get('n_kp', '?')}개 검출됨"
        return "엣지 NCC 방식 (특징점 부족)"

    @property
    def registered_ids(self) -> list:
        """Ids of all registered patterns, in insertion order."""
        return list(self._patterns.keys())

    @property
    def score_threshold(self) -> float:
        """Score threshold configured at construction or restored by :meth:`load`."""
        return self._threshold
# ── 공통 유틸 ──────────────────────────────────────────────────────────── #
def _to_gray(image: np.ndarray) -> np.ndarray:
if len(image.shape) == 3:
return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
return image.copy()
def _resize_if_large(image: np.ndarray) -> np.ndarray:
"""긴 변이 _MAX_IMG_SIZE 초과 시 비율 유지 축소."""
h, w = image.shape[:2]
if max(h, w) <= _MAX_IMG_SIZE:
return image
scale = _MAX_IMG_SIZE / max(h, w)
return cv2.resize(image, (int(w * scale), int(h * scale)),
interpolation=cv2.INTER_AREA)
# ── ORB 특징점 매칭 ────────────────────────────────────────────────────── #
def _orb_score(image: np.ndarray, template_des: np.ndarray) -> float:
    """Score how well ``image`` matches a pattern's ORB descriptors (0..100).

    ORB keypoints (up to 1000) are detected in ``image``. Each template
    descriptor is paired with its two nearest image descriptors by Hamming
    distance. A pair counts as a good match when it passes Lowe's ratio test
    (best < 0.75 * second best). ``_GOOD_MATCH_REF`` or more good matches
    give 100; fewer scale linearly.
    """
    keypoints, image_des = cv2.ORB_create(nfeatures=1000).detectAndCompute(image, None)
    if image_des is None or len(keypoints) < 2:
        return 0.0
    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
    try:
        knn_pairs = matcher.knnMatch(template_des, image_des, k=2)
    except cv2.error:
        # any OpenCV matching failure is treated as "no match"
        return 0.0
    n_good = sum(
        1 for pair in knn_pairs
        if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance
    )
    return min(n_good / _GOOD_MATCH_REF * 100.0, 100.0)
# ── 엣지 NCC fallback (특징점 없는 매끄러운 제품용) ───────────────────── #
_ROTATION_ANGLES = list(range(-15, 16, 5))
def _to_edges(image: np.ndarray) -> np.ndarray:
    """Return the Canny edge map (thresholds 50/150) of ``image`` after a 3x3 Gaussian blur."""
    smoothed = cv2.GaussianBlur(image, (3, 3), 0)
    return cv2.Canny(smoothed, 50, 150)
def _rotate_template(image: np.ndarray, angle_deg: float) -> np.ndarray:
if abs(angle_deg) < 0.1:
return image
h, w = image.shape[:2]
cx, cy = w / 2.0, h / 2.0
M = cv2.getRotationMatrix2D((cx, cy), angle_deg, 1.0)
cos_a, sin_a = abs(M[0, 0]), abs(M[0, 1])
new_w = int(h * sin_a + w * cos_a)
new_h = int(h * cos_a + w * sin_a)
M[0, 2] += (new_w - w) / 2.0
M[1, 2] += (new_h - h) / 2.0
return cv2.warpAffine(image, M, (new_w, new_h))
def _best_rotated_score(search: np.ndarray, template: np.ndarray) -> float:
    """Return the highest NCC score of ``template`` in ``search`` over ``_ROTATION_ANGLES``."""
    scores = [
        _ncc_score(search, _rotate_template(template, float(angle)))
        for angle in _ROTATION_ANGLES
    ]
    return max(scores)
def _ncc_score(image: np.ndarray, template: np.ndarray) -> float:
h, w = image.shape[:2]
th, tw = template.shape[:2]
if th > h or tw > w:
scale = 0.8 * min(h / th, w / tw)
template = cv2.resize(template,
(max(1, int(tw * scale)), max(1, int(th * scale))),
interpolation=cv2.INTER_AREA)
th, tw = template.shape[:2]
if th > h or tw > w:
return 0.0
result = cv2.matchTemplate(image, template, cv2.TM_CCOEFF_NORMED)
_, max_val, _, _ = cv2.minMaxLoc(result)
return float(max(0.0, max_val)) * 100.0