Compare commits

...

2 commits

Author SHA1 Message Date
10x Developer
a084777e64 implement ML model conversion, deploy models, and fix card detection noise 2026-05-09 21:37:21 +02:00
10x Developer
ff1561a704 Update ML setup guide and add model scripts 2026-05-09 01:38:35 +02:00
16 changed files with 172 additions and 50 deletions

4
.gitignore vendored
View file

@@ -1,3 +1,7 @@
/venv
/dataset
*.h5
# Dependencies
node_modules/
.pnp/

View file

@@ -12,7 +12,7 @@ The system uses a **Hybrid Pipeline** to balance performance on mobile devices w
2. **Classification**: Crops each detected card and passes it through two specialized TensorFlow.js models to determine the **Suit** and the **Value**.
## 2. Data Collection & Labeling
Because Jass cards have unique iconography, we use a custom dataset created from internet samples.
Because Jass cards have unique iconography, we use a custom dataset.
### Labeling Strategy: Folder-Based Annotation
Instead of using bounding-box tools, we use the directory structure as labels.
@@ -29,7 +29,7 @@ dataset/
├── 6/
├── 7/
...
└── 13/
└── 14/
```
### Process:

32
convert_model.sh Executable file
View file

@@ -0,0 +1,32 @@
#!/bin/bash
# convert_model.sh — convert a trained Keras .h5 model into TensorFlow.js format.
#
# Usage: ./convert_model.sh [suit|value] [path_to_h5_model] [output_directory]
# Exits non-zero on missing tooling, bad arguments, a missing input file,
# or a failed conversion.

# Ensure the converter CLI is available before doing anything else.
if ! command -v tensorflowjs_converter &> /dev/null
then
    echo "tensorflowjs_converter could not be found. Please install it using: pip install tensorflowjs"
    exit 1
fi

MODEL_TYPE=$1
INPUT_MODEL=$2
OUTPUT_DIR=$3

# All three positional arguments are required.
if [ -z "$MODEL_TYPE" ] || [ -z "$INPUT_MODEL" ] || [ -z "$OUTPUT_DIR" ]; then
    echo "Usage: ./convert_model.sh [suit|value] [path_to_h5_model] [output_directory]"
    exit 1
fi

# The usage string promises suit|value — enforce it instead of silently
# accepting any string for MODEL_TYPE.
if [ "$MODEL_TYPE" != "suit" ] && [ "$MODEL_TYPE" != "value" ]; then
    echo "Invalid model type '$MODEL_TYPE'. Expected 'suit' or 'value'."
    echo "Usage: ./convert_model.sh [suit|value] [path_to_h5_model] [output_directory]"
    exit 1
fi

# Fail early with a clear message if the input model file does not exist,
# rather than letting the converter produce a confusing stack trace.
if [ ! -f "$INPUT_MODEL" ]; then
    echo "Input model '$INPUT_MODEL' does not exist."
    exit 1
fi

echo "Converting $MODEL_TYPE model from $INPUT_MODEL to $OUTPUT_DIR..."

# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"

# Test the command's exit status directly instead of inspecting $? afterwards.
if tensorflowjs_converter --input_format=keras "$INPUT_MODEL" "$OUTPUT_DIR"; then
    echo "Conversion successful. Model saved to $OUTPUT_DIR"
else
    echo "Conversion failed."
    exit 1
fi

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View file

@@ -18,15 +18,6 @@ const App = () => {
);
};
const App = () => {
return (
<GameStateProvider>
<AppContent />
</GameStateProvider>
);
};
const AppContent = () => {
const { gameState } = useGameStateContext();

View file

@@ -98,9 +98,8 @@ const Detection: React.FC<DetectionProps> = ({ videoRef, canvasRef, onCardsDetec
// Enhanced card region detection specialized for Jass cards
const findCardRegions = (imageData: ImageData, width: number, height: number): {x: number, y: number, width: number, height: number}[] => {
const regions = [];
const step = 12; // Adjusted step for better detection
const step = 24; // Increased step to reduce noise and overlapping detections
// Look for card-colored regions: typically white/gray cards with identifiable suit symbols
for (let y = 0; y < height; y += step) {
for (let x = 0; x < width; x += step) {
const i = (y * width + x) * 4;
@@ -109,17 +108,11 @@ const Detection: React.FC<DetectionProps> = ({ videoRef, canvasRef, onCardsDetec
const b = imageData.data[i + 2];
const brightness = (r + g + b) / 3;
// Card background brightness range (light cards)
if (brightness > 180 && brightness < 250) {
if (brightness > 200 && brightness < 255) {
const region = getCardRegionWithShapeAnalysis(imageData, width, height, x, y, step);
if (region && region.width > 30 && region.height > 40) {
// Check that this is a proper rectangular card area by testing aspect ratio
const widthDiff = region.width;
const heightDiff = region.height;
const aspectRatio = widthDiff / heightDiff;
// Typical playing card aspect ratio (roughly 0.7 to 1.3)
if (aspectRatio > 0.7 && aspectRatio < 1.3) {
if (region && region.width > 60 && region.height > 80) {
const aspectRatio = region.width / region.height;
if (aspectRatio > 0.6 && aspectRatio < 1.4) {
regions.push(region);
}
}
@@ -127,7 +120,20 @@ const Detection: React.FC<DetectionProps> = ({ videoRef, canvasRef, onCardsDetec
}
}
return regions;
// Remove overlapping regions to avoid "dozens of nonsense cards"
const uniqueRegions = [];
regions.sort((a, b) => (a.width * a.height) - (b.width * b.height));
for (const region of regions) {
const isOverlapping = uniqueRegions.some(u =>
Math.abs(u.x - region.x) < 30 && Math.abs(u.y - region.y) < 30
);
if (!isOverlapping) {
uniqueRegions.push(region);
}
}
return uniqueRegions;
};
// Improved card region extraction with shape analysis for more accurate detection

View file

@@ -80,7 +80,12 @@ const ResultsScreen: React.FC = () => {
<div className="cards-grid">
{gameState.detectedCards.map(card => (
<div key={card.id} className="card-preview">
<div className="card-suit"></div>
<div className="card-suit">
{card.suit === 'Schellen' && '🔔'}
{card.suit === 'Schilten' && '🛡️'}
{card.suit === 'Eicheln' && '🌰'}
{card.suit === 'Rosen' && '🌹'}
</div>
<div className="card-value">{card.value || '?'}pts</div>
</div>
))}

View file

@@ -13,7 +13,7 @@ class CardModelService {
private readonly VALUE_MODEL_PATH = '/models/value_model/model.json';
private readonly SUIT_LABELS = ['Schellen', 'Schilten', 'Eicheln', 'Rosen'];
private readonly VALUE_LABELS = ['6', '7', '8', '9', '10', '11', '12', '13'];
private readonly VALUE_LABELS = ['6', '7', '8', '9', '10', '11', '12', '13', '14'];
async init(): Promise<void> {
try {

82
train_model.py Normal file
View file

@@ -0,0 +1,82 @@
import os
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Class orderings MUST match the label arrays used at inference time in the
# web app's CardModelService (SUIT_LABELS / VALUE_LABELS). Without an explicit
# `classes=` argument, flow_from_directory sorts the class folder names
# alphanumerically ('10' < '11' < ... < '14' < '6', and
# 'Eicheln' < 'Rosen' < 'Schellen' < 'Schilten'), which would silently
# scramble the prediction-index -> label mapping in the browser.
SUIT_CLASSES = ['Schellen', 'Schilten', 'Eicheln', 'Rosen']
VALUE_CLASSES = [str(v) for v in range(6, 15)]  # '6' .. '14' (9 classes)


def create_model(num_classes):
    """Build a small transfer-learning classifier on a frozen MobileNetV2 base.

    Args:
        num_classes: Number of output classes (softmax width).

    Returns:
        A compiled tf.keras Sequential model expecting (64, 64, 3) inputs.
    """
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(64, 64, 3),
        include_top=False,
        weights='imagenet'
    )
    base_model.trainable = False  # Freeze base model for transfer learning

    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(0.2),
        layers.Dense(num_classes, activation='softmax')
    ])

    model.compile(
        optimizer=optimizers.Adam(),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model


def train(model_type, dataset_root, output_path):
    """Train a suit or value classifier from a folder-labeled dataset.

    Args:
        model_type: 'suit' or 'value' — selects class list and dataset subdir.
        dataset_root: Root directory containing '{model_type}_model/<class>/' folders.
        output_path: Where the trained Keras .h5 model is saved.

    Raises:
        ValueError: If model_type is neither 'suit' nor 'value'.
    """
    # Pin the class ordering so trained-model output indices line up with the
    # label arrays in the TS inference service; derive num_classes from it.
    if model_type == 'suit':
        classes = SUIT_CLASSES
    elif model_type == 'value':
        classes = VALUE_CLASSES
    else:
        raise ValueError("model_type must be 'suit' or 'value'")
    num_classes = len(classes)

    # Data augmentation as per ML_SETUP_GUIDE.md
    datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=15,
        brightness_range=[0.8, 1.2],
        validation_split=0.2
    )

    data_dir = os.path.join(dataset_root, f'{model_type}_model')
    train_generator = datagen.flow_from_directory(
        data_dir,
        target_size=(64, 64),
        batch_size=32,
        class_mode='categorical',
        classes=classes,  # explicit ordering — see note above
        subset='training'
    )
    validation_generator = datagen.flow_from_directory(
        data_dir,
        target_size=(64, 64),
        batch_size=32,
        class_mode='categorical',
        classes=classes,  # must match the training generator's ordering
        subset='validation',
        shuffle=False  # deterministic validation pass
    )

    model = create_model(num_classes)

    print(f"Training {model_type} model...")
    model.fit(
        train_generator,
        epochs=20,
        validation_data=validation_generator
    )

    model.save(output_path)
    print(f"Model saved to {output_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--type', choices=['suit', 'value'], required=True)
    parser.add_argument('--dataset', default='dataset')
    parser.add_argument('--output', required=True)
    args = parser.parse_args()
    train(args.type, args.dataset, args.output)