diff --git a/AGENTS.md b/AGENTS.md
index 8e9f32e..750c3c8 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -23,6 +23,7 @@ Mobile-first React web app for tracking Schaffhauser (Jass) card game rounds usi
 ### Development Commands
 - `npm run build` - Build production version
 - `npm run preview` - Preview production build
+- `npm run typecheck` - Run TypeScript type checking
 
 ### Development Server
 - A development server is always guaranteed to be running on localhost:5173
diff --git a/DETECTION_IMPROVEMENT_PLAN.md b/DETECTION_IMPROVEMENT_PLAN.md
new file mode 100644
index 0000000..2486d64
--- /dev/null
+++ b/DETECTION_IMPROVEMENT_PLAN.md
@@ -0,0 +1,44 @@
+# Card Object Detection Reliability Improvement Plan
+
+This document outlines the strategy to improve the reliability of the card detection pipeline, moving from a heuristic-based approach to a robust computer vision pipeline.
+
+## Current Limitations
+- **Localization**: Relies on brightness thresholds, which are highly sensitive to lighting conditions and shadows.
+- **Geometry**: Crops raw bounding boxes without correcting for perspective (table angle), forcing the ML models to handle distortion.
+- **Stability**: Live detection is susceptible to frame-by-frame jitter (flickering).
+
+## Proposed Improvements
+
+### 1. Robust Localization (The "Where" Problem)
+Transition from brightness-based search to shape and edge detection:
+- **Edge-based Detection**: Implement **Canny Edge Detection** and **Contour Approximation** to identify rectangular shapes regardless of absolute brightness.
+- **Color Space Shift**: Move from RGB to **HSV (Hue, Saturation, Value)** or **LAB** color spaces to decouple lighting (Value/Lightness) from color information.
+- **End-to-End Detection (Long-term)**: Evaluate lightweight object detection models (e.g., **YOLOv8-nano** or **SSD MobileNet**) to replace manual region finding.
+
+### 2. Perspective Correction (The "Geometry" Problem)
+Eliminate image skew to provide standardized input to classifiers:
+- **Four-Point Transform (Warping)**: Identify the four corners of the detected card contour and apply a **Perspective Transform (Homography)** to "flatten" the card into a normalized top-down rectangle.
+- **Standardized Input**: Ensure the ML models always receive a centered, non-distorted crop, reducing the reliance on massive geometric data augmentation.
+
+### 3. Enhanced Classification (The "What" Problem)
+Improve the precision of identity recognition:
+- **Unified Multi-Head Model**: Combine Suit and Value models into a single network with two output heads to reduce latency and exploit shared features.
+- **Advanced Data Augmentation**: Expand the training set with:
+  - **Motion Blur**: Simulating handheld camera movement.
+  - **Perspective Distortions**: To handle imperfect warping.
+  - **Lighting Variations**: Simulating varied environmental lighting.
+- **Confidence Calibration**: Implement a minimum confidence threshold to avoid false positives in noisy environments.
+
+### 4. Temporal Stability (The "Flicker" Problem)
+Prevent identity jumping in live mode:
+- **Object Tracking**: Implement a **Centroid Tracker** or **Kalman Filter** to maintain card identity across frames instead of detecting from scratch every time.
+- **Temporal Smoothing**: Use a "Voting" mechanism where a card's identity is only confirmed if the model is consistent over a sliding window of 5-10 frames.
+
+## Implementation Roadmap
+
+| Phase | Focus | Key Change | Expected Impact |
+| :--- | :--- | :--- | :--- |
+| **Phase 1** | **Stability** | Edge detection + Temporal smoothing | Reduced flickering and lighting sensitivity. |
+| **Phase 2** | **Geometry** | Perspective Warping (Flattening) | Significant boost in classification accuracy. |
+| **Phase 3** | **Intelligence** | Unified Model + Expanded Dataset | Higher precision and lower inference latency. |
+| **Phase 4** | **Architecture** | Full Object Detection Model (YOLO) | Industry-standard reliability and speed. |
diff --git a/package.json b/package.json
index 57bb273..f8c7756 100644
--- a/package.json
+++ b/package.json
@@ -4,7 +4,8 @@
   "scripts": {
     "dev": "vite --host 0.0.0.0",
     "build": "vite build",
-    "preview": "vite preview"
+    "preview": "vite preview",
+    "typecheck": "tsc --noEmit"
   },
   "dependencies": {
     "@tensorflow/tfjs": "^4.19.0",
@@ -13,9 +14,11 @@
     "react-router-dom": "^7.15.0"
   },
   "devDependencies": {
+    "@types/node": "^25.6.2",
     "@types/react": "^18.3.5",
     "@types/react-dom": "^18.3.0",
     "@vitejs/plugin-react": "^4.3.0",
+    "typescript": "^6.0.3",
     "vite": "^5.4.0"
   }
 }
diff --git a/src/components/Detection/Detection.tsx b/src/components/Detection/Detection.tsx
index 7c60648..8b093d8 100644
--- a/src/components/Detection/Detection.tsx
+++ b/src/components/Detection/Detection.tsx
@@ -1,6 +1,8 @@
 import React, { useRef, useEffect } from 'react';
 import { Card } from '../../types';
 import { cardModelService } from '../../services/CardModelService';
+import { CentroidTracker } from '../../utils/Tracker';
+import { ImageProcessing } from '../../utils/ImageProcessing';
 
 interface DetectionProps {
   videoRef: React.RefObject<HTMLVideoElement>;
@@ -10,12 +12,14 @@ interface DetectionProps {
   onLiveCardsDetected?: (cards: Card[]) => void;
 }
 
 const Detection: React.FC<DetectionProps> = ({ videoRef, canvasRef, onCardsDetected, live, onLiveCardsDetected }) => {
   const isDetectingRef = useRef(false);
   const requestRef = useRef();
+  const trackerRef = useRef(new CentroidTracker());
+  const classificationHistoryRef = useRef<Map<number, { suits: Card['suit'][]; values: number[] }>>(new Map());
 
   // Expose detection method for external calls
   const detectCards = async () => {
     if (!videoRef.current || !canvasRef.current || isDetectingRef.current) return;
 
     isDetectingRef.current = true;
@@ -89,14 +93,20 @@ const Detection: React.FC<DetectionProps> = ({ videoRef, canvasRef, onCardsDetec
 
   // Enhanced card detection using image processing specialized for Jass cards
   const processImageForCards = async (canvas: HTMLCanvasElement, ctx: CanvasRenderingContext2D): Promise<Card[]> => {
-    // Card detection logic optimized for Swiss/German-style cards
     const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
-    const cardRegions = findCardRegions(imageData, canvas.width, canvas.height);
+
+    // Replace brightness thresholding with robust edge-based localization
+    const edges = ImageProcessing.detectEdges(imageData);
+    const cardPolygons = ImageProcessing.findRectangularRegions(edges, canvas.width, canvas.height);
+
     const detectedCards: Card[] = [];
 
-    for (let i = 0; i < cardRegions.length; i++) {
-      const region = cardRegions[i];
-      const cardCrop = createCrop(ctx, canvas, region);
+    for (let i = 0; i < cardPolygons.length; i++) {
+      const polygon = cardPolygons[i];
+
+      // Perspective Warping: Get 4 corners and flatten image
+      const corners = ImageProcessing.findCorners(polygon.points);
+      const cardCrop = ImageProcessing.warpPerspective(canvas, corners, 128, 192);
 
       let suit: 'Schellen' | 'Schilten' | 'Eicheln' | 'Rosen';
       let value: number;
@@ -109,23 +119,90 @@ const Detection: React.FC<DetectionProps> = ({ videoRef, canvasRef, onCardsDetec
         value = parseInt(valRes.label);
         confidence = (suitRes.confidence + valRes.confidence) / 2;
       } else {
-        suit = detectCardSuit(ctx, canvas, region);
-        value = detectCardValue(ctx, canvas, region);
+        suit = detectCardSuit(ctx, canvas, polygon.bbox);
+        value = detectCardValue(ctx, canvas, polygon.bbox);
       }
 
       detectedCards.push({
         id: `card-${i}`,
         suit,
         value,
-        x: region.x,
-        y: region.y,
-        width: region.width,
-        height: region.height,
+        x: polygon.bbox.x,
+        y: polygon.bbox.y,
+        width: polygon.bbox.width,
+        height: polygon.bbox.height,
         confidence
       });
     }
 
-    return detectedCards;
+    // Apply tracking and temporal smoothing so identities do not flicker between frames
+    const trackedObjects = trackerRef.current.update(cardPolygons.map(p => p.bbox));
+    const histories = classificationHistoryRef.current;
+    const finalCards: Card[] = [];
+
+    for (const obj of trackedObjects) {
+      // Objects with age > 0 were not seen in this frame; the tracker only keeps
+      // them alive for re-identification, so they must not claim a detection.
+      if (obj.age > 0) continue;
+
+      // A matched object carries this frame's bounding box unchanged, so compare
+      // exactly instead of using a tolerance that could pick a neighbouring card.
+      const detection = detectedCards.find(c =>
+        c.x === obj.bbox.x && c.y === obj.bbox.y &&
+        c.width === obj.bbox.width && c.height === obj.bbox.height
+      );
+      if (!detection) continue;
+
+      // Update history for temporal voting
+      let history = histories.get(obj.id);
+      if (!history) {
+        history = { suits: [], values: [] };
+        histories.set(obj.id, history);
+      }
+      history.suits.push(detection.suit);
+      history.values.push(detection.value);
+      if (history.suits.length > 10) {
+        // Sliding window of the 10 most recent classifications
+        history.suits.shift();
+        history.values.shift();
+      }
+
+      // Vote for the most common identity within the window
+      finalCards.push({
+        id: `card-${obj.id}`,
+        suit: getMostCommon(history.suits) ?? detection.suit,
+        value: getMostCommon(history.values) ?? detection.value,
+        x: obj.bbox.x,
+        y: obj.bbox.y,
+        width: obj.bbox.width,
+        height: obj.bbox.height,
+        confidence: detection.confidence
+      });
+    }
+
+    // Drop voting history for ids the tracker has deregistered
+    const liveIds = new Set(trackedObjects.map(obj => obj.id));
+    for (const id of histories.keys()) {
+      if (!liveIds.has(id)) histories.delete(id);
+    }
+
+    return finalCards;
+  };
+
+  // Most frequent element (on a tie, the one that reached the top count first); undefined for an empty list
+  const getMostCommon = <T extends string | number>(items: readonly T[]): T | undefined => {
+    const counts = new Map<T, number>();
+    let best: T | undefined;
+    let bestCount = 0;
+    for (const item of items) {
+      const count = (counts.get(item) ?? 0) + 1;
+      counts.set(item, count);
+      if (count > bestCount) {
+        bestCount = count;
+        best = item;
+      }
+    }
+    return best;
   };
 
   const createCrop = (ctx: CanvasRenderingContext2D, canvas: HTMLCanvasElement, region: {x: number, y: number, width: number, height: number}): HTMLCanvasElement => {
diff --git a/src/utils/ImageProcessing.ts b/src/utils/ImageProcessing.ts
new file mode 100644
index 0000000..fc2c320
--- /dev/null
+++ b/src/utils/ImageProcessing.ts
@@ -0,0 +1,270 @@
+export interface Rect {
+  x: number;
+  y: number;
+  width: number;
+  height: number;
+}
+
+export interface Point {
+  x: number;
+  y: number;
+}
+
+export interface Polygon {
+  points: Point[];
+  bbox: Rect;
+}
+
+/**
+ * Stateless image-processing helpers used to localise cards in a camera frame.
+ */
+export class ImageProcessing {
+  /**
+   * Sobel gradient magnitude of one channel of an interleaved pixel buffer.
+   *
+   * @param data Pixel buffer; only the first sample of every pixel is read.
+   * @param width Image width in pixels.
+   * @param height Image height in pixels.
+   * @param channels Samples per pixel: 4 for RGBA data (default), 1 for a grayscale plane.
+   * @returns Edge strength per pixel; the one-pixel border is left at 0.
+   */
+  static Sobel(data: ArrayLike<number>, width: number, height: number, channels: number = 4): Float32Array {
+    const output = new Float32Array(width * height);
+    const gx = [
+      -1, 0, 1,
+      -2, 0, 2,
+      -1, 0, 1
+    ];
+    const gy = [
+      -1, -2, -1,
+      0, 0, 0,
+      1, 2, 1
+    ];
+
+    for (let y = 1; y < height - 1; y++) {
+      for (let x = 1; x < width - 1; x++) {
+        let sumX = 0;
+        let sumY = 0;
+
+        for (let ky = -1; ky <= 1; ky++) {
+          for (let kx = -1; kx <= 1; kx++) {
+            const pixel = data[((y + ky) * width + (x + kx)) * channels];
+            const k = (ky + 1) * 3 + (kx + 1);
+            sumX += pixel * gx[k];
+            sumY += pixel * gy[k];
+          }
+        }
+        output[y * width + x] = Math.sqrt(sumX * sumX + sumY * sumY);
+      }
+    }
+    return output;
+  }
+
+  /**
+   * Groups connected above-threshold edge pixels (4-connectivity) and keeps the
+   * groups whose bounding box is larger than the minimum card size. No shape
+   * test is applied beyond that size filter.
+   *
+   * @param edges Edge strength per pixel, e.g. from detectEdges().
+   * @param width Image width in pixels.
+   * @param height Image height in pixels.
+   * @param threshold Edge strength a pixel must exceed to belong to a region.
+   * @param minWidth Bounding-box width a region must exceed to be kept.
+   * @param minHeight Bounding-box height a region must exceed to be kept.
+   */
+  static findRectangularRegions(
+    edges: Float32Array,
+    width: number,
+    height: number,
+    threshold: number = 50,
+    minWidth: number = 50,
+    minHeight: number = 80
+  ): Polygon[] {
+    const regions: Polygon[] = [];
+    const visited = new Uint8Array(width * height);
+
+    for (let y = 0; y < height; y++) {
+      for (let x = 0; x < width; x++) {
+        const idx = y * width + x;
+        if (visited[idx] || !(edges[idx] > threshold)) continue;
+
+        const region = this.expandRegion(edges, visited, x, y, width, height, threshold);
+        if (region.bbox.width > minWidth && region.bbox.height > minHeight) {
+          regions.push(region);
+        }
+      }
+    }
+    return regions;
+  }
+
+  /**
+   * Flood-fills the connected above-threshold region containing the start pixel,
+   * marking every pixel it reaches in `visited`.
+   */
+  private static expandRegion(
+    edges: Float32Array,
+    visited: Uint8Array,
+    startX: number,
+    startY: number,
+    width: number,
+    height: number,
+    threshold: number
+  ): Polygon {
+    let minX = startX, maxX = startX;
+    let minY = startY, maxY = startY;
+    const points: Point[] = [];
+
+    // Flat pixel indices keep the stack free of per-pixel tuple allocations.
+    const start = startY * width + startX;
+    const stack: number[] = [start];
+    visited[start] = 1;
+
+    for (let idx = stack.pop(); idx !== undefined; idx = stack.pop()) {
+      const cx = idx % width;
+      const cy = (idx - cx) / width;
+
+      minX = Math.min(minX, cx);
+      maxX = Math.max(maxX, cx);
+      minY = Math.min(minY, cy);
+      maxY = Math.max(maxY, cy);
+      points.push({ x: cx, y: cy });
+
+      // -1 marks a neighbour that would fall outside the image.
+      const neighbours = [
+        cx + 1 < width ? idx + 1 : -1,
+        cx > 0 ? idx - 1 : -1,
+        cy + 1 < height ? idx + width : -1,
+        cy > 0 ? idx - width : -1
+      ];
+      for (const n of neighbours) {
+        if (n >= 0 && !visited[n] && edges[n] > threshold) {
+          visited[n] = 1;
+          stack.push(n);
+        }
+      }
+    }
+
+    // min/max are inclusive pixel coordinates, hence the +1.
+    return {
+      points,
+      bbox: { x: minX, y: minY, width: maxX - minX + 1, height: maxY - minY + 1 }
+    };
+  }
+
+  /**
+   * Picks four extreme points of a point set as card corners, in image
+   * coordinates (y grows downwards).
+   *
+   * @returns [topLeft, topRight, bottomRight, bottomLeft], or [] when fewer
+   * than four points are given.
+   */
+  static findCorners(points: Point[]): Point[] {
+    if (points.length < 4) return [];
+
+    let topLeft = points[0];
+    let topRight = points[0];
+    let bottomRight = points[0];
+    let bottomLeft = points[0];
+
+    let minSum = Infinity, maxSum = -Infinity;
+    let minDiff = Infinity, maxDiff = -Infinity;
+
+    for (const p of points) {
+      const sum = p.x + p.y;   // smallest at top-left, largest at bottom-right
+      const diff = p.x - p.y;  // smallest at bottom-left, largest at top-right
+
+      if (sum < minSum) { minSum = sum; topLeft = p; }
+      if (sum > maxSum) { maxSum = sum; bottomRight = p; }
+      if (diff < minDiff) { minDiff = diff; bottomLeft = p; }
+      if (diff > maxDiff) { maxDiff = diff; topRight = p; }
+    }
+
+    return [topLeft, topRight, bottomRight, bottomLeft];
+  }
+
+  /**
+   * Resamples the quadrilateral `srcCorners` of `sourceCanvas` into a new
+   * destWidth x destHeight canvas. Destination pixels are mapped into the quad
+   * by bilinear interpolation of the corners and sampled nearest-neighbour;
+   * this approximates, but is not, a projective (homography) warp.
+   *
+   * @param srcCorners Corners ordered [topLeft, topRight, bottomRight, bottomLeft].
+   * @returns The new canvas; it is left blank when fewer than four corners are
+   * given, a 2D context is unavailable, or the target size is not positive.
+   */
+  static warpPerspective(
+    sourceCanvas: HTMLCanvasElement,
+    srcCorners: Point[],
+    destWidth: number,
+    destHeight: number
+  ): HTMLCanvasElement {
+    const destCanvas = document.createElement('canvas');
+    destCanvas.width = destWidth;
+    destCanvas.height = destHeight;
+
+    const destCtx = destCanvas.getContext('2d');
+    const srcCtx = sourceCanvas.getContext('2d');
+    if (!destCtx || !srcCtx) return destCanvas;
+    // findCorners() returns [] for degenerate input; do not index into it.
+    if (srcCorners.length < 4 || destWidth <= 0 || destHeight <= 0) return destCanvas;
+
+    const [tl, tr, br, bl] = srcCorners;
+
+    // Read back only the part of the source the quad can touch (padded by one
+    // pixel against rounding) instead of the whole frame for every card.
+    const left = Math.max(0, Math.floor(Math.min(tl.x, tr.x, br.x, bl.x)) - 1);
+    const top = Math.max(0, Math.floor(Math.min(tl.y, tr.y, br.y, bl.y)) - 1);
+    const right = Math.min(sourceCanvas.width - 1, Math.floor(Math.max(tl.x, tr.x, br.x, bl.x)) + 1);
+    const bottom = Math.min(sourceCanvas.height - 1, Math.floor(Math.max(tl.y, tr.y, br.y, bl.y)) + 1);
+    const srcWidth = right - left + 1;
+    const srcHeight = bottom - top + 1;
+    if (srcWidth <= 0 || srcHeight <= 0) return destCanvas;
+
+    const srcData = srcCtx.getImageData(left, top, srcWidth, srcHeight).data;
+    const destData = destCtx.createImageData(destWidth, destHeight);
+
+    for (let y = 0; y < destHeight; y++) {
+      // Divide by size - 1 so the last row/column lands exactly on the far corners.
+      const v = destHeight > 1 ? y / (destHeight - 1) : 0;
+      for (let x = 0; x < destWidth; x++) {
+        const u = destWidth > 1 ? x / (destWidth - 1) : 0;
+
+        const srcX = (1 - u) * ((1 - v) * tl.x + v * bl.x) + u * ((1 - v) * tr.x + v * br.x);
+        const srcY = (1 - u) * ((1 - v) * tl.y + v * bl.y) + u * ((1 - v) * tr.y + v * br.y);
+
+        const sx = Math.floor(srcX) - left;
+        const sy = Math.floor(srcY) - top;
+        if (sx < 0 || sx >= srcWidth || sy < 0 || sy >= srcHeight) continue;
+
+        const srcIdx = (sy * srcWidth + sx) * 4;
+        const destIdx = (y * destWidth + x) * 4;
+        destData.data[destIdx] = srcData[srcIdx];
+        destData.data[destIdx + 1] = srcData[srcIdx + 1];
+        destData.data[destIdx + 2] = srcData[srcIdx + 2];
+        destData.data[destIdx + 3] = srcData[srcIdx + 3];
+      }
+    }
+
+    destCtx.putImageData(destData, 0, 0);
+    return destCanvas;
+  }
+
+  /**
+   * Converts RGBA pixels to one luminance byte per pixel (Rec. 601 weights).
+   */
+  static toGrayscale(imageData: ImageData): Uint8ClampedArray {
+    const { data, width, height } = imageData;
+    const gray = new Uint8ClampedArray(width * height);
+    for (let p = 0, i = 0; p < gray.length; p++, i += 4) {
+      gray[p] = 0.299 * data[i] + 0.587 * data[i + 1] + 0.114 * data[i + 2];
+    }
+    return gray;
+  }
+
+  /**
+   * Edge map of an RGBA frame: grayscale conversion followed by Sobel().
+   */
+  static detectEdges(imageData: ImageData): Float32Array {
+    return this.Sobel(this.toGrayscale(imageData), imageData.width, imageData.height, 1);
+  }
+}
diff --git a/src/utils/Tracker.ts b/src/utils/Tracker.ts
new file mode 100644
index 0000000..cecf3c3
--- /dev/null
+++ b/src/utils/Tracker.ts
@@ -0,0 +1,98 @@
+export interface BoundingBox {
+  x: number;
+  y: number;
+  width: number;
+  height: number;
+}
+
+export interface TrackedObject {
+  id: number;
+  bbox: BoundingBox;
+  centroid: { x: number; y: number };
+  /** Consecutive frames without a match; 0 means seen in the latest frame. */
+  age: number;
+  /** Frames in which the object was matched, counting the one that created it. */
+  hits: number;
+  suit?: string;
+  value?: number;
+  confidence?: number;
+}
+
+/**
+ * Assigns stable ids to bounding boxes across frames by matching centroids.
+ */
+export class CentroidTracker {
+  private nextObjectId = 0;
+  private readonly objects = new Map<number, TrackedObject>();
+
+  /**
+   * @param maxDisappeared Frames an unmatched object is kept before it is dropped.
+   * @param maxDistance Centroid distance at or beyond which boxes are never matched.
+   */
+  constructor(
+    private readonly maxDisappeared: number = 5,
+    private readonly maxDistance: number = 100
+  ) {}
+
+  private getCentroid(bbox: BoundingBox) {
+    return {
+      x: bbox.x + bbox.width / 2,
+      y: bbox.y + bbox.height / 2,
+    };
+  }
+
+  private distance(p1: { x: number; y: number }, p2: { x: number; y: number }) {
+    return Math.hypot(p1.x - p2.x, p1.y - p2.y);
+  }
+
+  /**
+   * Feeds one frame of detections into the tracker.
+   *
+   * @param rects Bounding boxes detected in the current frame (may be empty).
+   * @returns Every object still tracked, including ones not matched in this
+   * frame (age > 0) that are being kept for re-identification.
+   */
+  update(rects: BoundingBox[]): TrackedObject[] {
+    const centroids = rects.map(r => this.getCentroid(r));
+    const existing = Array.from(this.objects.values());
+
+    // Collect every object/detection pair within range and take the closest
+    // pairs first, so the result does not depend on object insertion order.
+    const candidates: { obj: TrackedObject; input: number; dist: number }[] = [];
+    for (const obj of existing) {
+      centroids.forEach((centroid, input) => {
+        const dist = this.distance(obj.centroid, centroid);
+        if (dist < this.maxDistance) candidates.push({ obj, input, dist });
+      });
+    }
+    candidates.sort((a, b) => a.dist - b.dist);
+
+    const usedInput = new Set<number>();
+    const matched = new Set<number>();
+    for (const { obj, input } of candidates) {
+      if (matched.has(obj.id) || usedInput.has(input)) continue;
+      obj.bbox = rects[input];
+      obj.centroid = centroids[input];
+      obj.age = 0;
+      obj.hits++;
+      matched.add(obj.id);
+      usedInput.add(input);
+    }
+
+    // Age the objects that found no detection and drop the ones gone too long
+    for (const obj of existing) {
+      if (matched.has(obj.id)) continue;
+      obj.age++;
+      if (obj.age > this.maxDisappeared) this.objects.delete(obj.id);
+    }
+
+    // Register detections that no existing object claimed
+    rects.forEach((rect, i) => {
+      if (usedInput.has(i)) return;
+      const id = this.nextObjectId++;
+      this.objects.set(id, { id, bbox: rect, centroid: centroids[i], age: 0, hits: 1 });
+    });
+
+    return Array.from(this.objects.values());
+  }
+}
diff --git a/tsconfig.json b/tsconfig.json
new file mode 100644
index 0000000..2121127
--- /dev/null
+++ b/tsconfig.json
@@ -0,0 +1,28 @@
+{
+  "compilerOptions": {
+    "target": "ESNext",
+    "useDefineForClassFields": true,
+    "lib": ["DOM", "DOM.Iterable", "ESNext"],
+    "allowJs": true,
+    "skipLibCheck": true,
+    "esModuleInterop": true,
+    "allowSyntheticDefaultImports": true,
+    "ignoreDeprecations": "6.0",
+
+
+
+    "strict": true,
+    "forceConsistentCasingInFileNames": true,
+    "module": "ESNext",
+    "moduleResolution": "Bundler",
+    "resolveJsonModule": true,
+    "jsx": "react-jsx",
+    "noEmit": true,
+    "baseUrl": ".",
+    "paths": {
+      "@/*": ["./src/*"]
+    }
+  },
+  "include": ["src"]
+}
+