Compare commits
2 commits
073d395ae6
...
a084777e64
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a084777e64 | ||
|
|
ff1561a704 |
16 changed files with 172 additions and 50 deletions
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -1,3 +1,7 @@
|
||||||
|
/venv
|
||||||
|
/dataset
|
||||||
|
*.h5
|
||||||
|
|
||||||
# Dependencies
|
# Dependencies
|
||||||
node_modules/
|
node_modules/
|
||||||
.pnp/
|
.pnp/
|
||||||
|
|
@ -79,4 +83,4 @@ yarn-cache/
|
||||||
*.sublime-*
|
*.sublime-*
|
||||||
.sass-cache
|
.sass-cache
|
||||||
parcel-cache.json
|
parcel-cache.json
|
||||||
.pnp.*
|
.pnp.*
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ The system uses a **Hybrid Pipeline** to balance performance on mobile devices w
|
||||||
2. **Classification**: Crops each detected card and passes it through two specialized TensorFlow.js models to determine the **Suit** and the **Value**.
|
2. **Classification**: Crops each detected card and passes it through two specialized TensorFlow.js models to determine the **Suit** and the **Value**.
|
||||||
|
|
||||||
## 2. Data Collection & Labeling
|
## 2. Data Collection & Labeling
|
||||||
Because Jass cards have unique iconography, we use a custom dataset created from internet samples.
|
Because Jass cards have unique iconography, we use a custom dataset.
|
||||||
|
|
||||||
### Labeling Strategy: Folder-Based Annotation
|
### Labeling Strategy: Folder-Based Annotation
|
||||||
Instead of using bounding-box tools, we use the directory structure as labels.
|
Instead of using bounding-box tools, we use the directory structure as labels.
|
||||||
|
|
@ -29,7 +29,7 @@ dataset/
|
||||||
├── 6/
|
├── 6/
|
||||||
├── 7/
|
├── 7/
|
||||||
...
|
...
|
||||||
└── 13/
|
└── 14/
|
||||||
```
|
```
|
||||||
|
|
||||||
### Process:
|
### Process:
|
||||||
|
|
|
||||||
32
convert_model.sh
Executable file
32
convert_model.sh
Executable file
|
|
@ -0,0 +1,32 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Check if tensorflowjs is installed
|
||||||
|
if ! command -v tensorflowjs_converter &> /dev/null
|
||||||
|
then
|
||||||
|
echo "tensorflowjs_converter could not be found. Please install it using: pip install tensorflowjs"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
MODEL_TYPE=$1
|
||||||
|
INPUT_MODEL=$2
|
||||||
|
OUTPUT_DIR=$3
|
||||||
|
|
||||||
|
if [ -z "$MODEL_TYPE" ] || [ -z "$INPUT_MODEL" ] || [ -z "$OUTPUT_DIR" ]; then
|
||||||
|
echo "Usage: ./convert_model.sh [suit|value] [path_to_h5_model] [output_directory]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Converting $MODEL_TYPE model from $INPUT_MODEL to $OUTPUT_DIR..."
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
|
# Perform conversion
|
||||||
|
tensorflowjs_converter --input_format=keras "$INPUT_MODEL" "$OUTPUT_DIR"
|
||||||
|
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
echo "Conversion successful. Model saved to $OUTPUT_DIR"
|
||||||
|
else
|
||||||
|
echo "Conversion failed."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
BIN
public/models/suit_model/group1-shard1of3.bin
Normal file
BIN
public/models/suit_model/group1-shard1of3.bin
Normal file
Binary file not shown.
BIN
public/models/suit_model/group1-shard2of3.bin
Normal file
BIN
public/models/suit_model/group1-shard2of3.bin
Normal file
Binary file not shown.
BIN
public/models/suit_model/group1-shard3of3.bin
Normal file
BIN
public/models/suit_model/group1-shard3of3.bin
Normal file
Binary file not shown.
1
public/models/suit_model/model.json
Normal file
1
public/models/suit_model/model.json
Normal file
File diff suppressed because one or more lines are too long
BIN
public/models/value_model/group1-shard1of3.bin
Normal file
BIN
public/models/value_model/group1-shard1of3.bin
Normal file
Binary file not shown.
BIN
public/models/value_model/group1-shard2of3.bin
Normal file
BIN
public/models/value_model/group1-shard2of3.bin
Normal file
Binary file not shown.
BIN
public/models/value_model/group1-shard3of3.bin
Normal file
BIN
public/models/value_model/group1-shard3of3.bin
Normal file
Binary file not shown.
1
public/models/value_model/model.json
Normal file
1
public/models/value_model/model.json
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -18,15 +18,6 @@ const App = () => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
const App = () => {
|
|
||||||
return (
|
|
||||||
<GameStateProvider>
|
|
||||||
<AppContent />
|
|
||||||
</GameStateProvider>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const AppContent = () => {
|
const AppContent = () => {
|
||||||
const { gameState } = useGameStateContext();
|
const { gameState } = useGameStateContext();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -96,39 +96,45 @@ const Detection: React.FC<DetectionProps> = ({ videoRef, canvasRef, onCardsDetec
|
||||||
|
|
||||||
|
|
||||||
// Enhanced card region detection specialized for Jass cards
|
// Enhanced card region detection specialized for Jass cards
|
||||||
const findCardRegions = (imageData: ImageData, width: number, height: number): {x: number, y: number, width: number, height: number}[] => {
|
const findCardRegions = (imageData: ImageData, width: number, height: number): {x: number, y: number, width: number, height: number}[] => {
|
||||||
const regions = [];
|
const regions = [];
|
||||||
const step = 12; // Adjusted step for better detection
|
const step = 24; // Increased step to reduce noise and overlapping detections
|
||||||
|
|
||||||
// Look for card-colored regions: typically white/gray cards with identifiable suit symbols
|
for (let y = 0; y < height; y += step) {
|
||||||
for (let y = 0; y < height; y += step) {
|
for (let x = 0; x < width; x += step) {
|
||||||
for (let x = 0; x < width; x += step) {
|
const i = (y * width + x) * 4;
|
||||||
const i = (y * width + x) * 4;
|
const r = imageData.data[i];
|
||||||
const r = imageData.data[i];
|
const g = imageData.data[i + 1];
|
||||||
const g = imageData.data[i + 1];
|
const b = imageData.data[i + 2];
|
||||||
const b = imageData.data[i + 2];
|
const brightness = (r + g + b) / 3;
|
||||||
const brightness = (r + g + b) / 3;
|
|
||||||
|
if (brightness > 200 && brightness < 255) {
|
||||||
// Card background brightness range (light cards)
|
const region = getCardRegionWithShapeAnalysis(imageData, width, height, x, y, step);
|
||||||
if (brightness > 180 && brightness < 250) {
|
if (region && region.width > 60 && region.height > 80) {
|
||||||
const region = getCardRegionWithShapeAnalysis(imageData, width, height, x, y, step);
|
const aspectRatio = region.width / region.height;
|
||||||
if (region && region.width > 30 && region.height > 40) {
|
if (aspectRatio > 0.6 && aspectRatio < 1.4) {
|
||||||
// Check that this is a proper rectangular card area by testing aspect ratio
|
regions.push(region);
|
||||||
const widthDiff = region.width;
|
}
|
||||||
const heightDiff = region.height;
|
|
||||||
const aspectRatio = widthDiff / heightDiff;
|
|
||||||
|
|
||||||
// Typical playing card aspect ratio (roughly 0.7 to 1.3)
|
|
||||||
if (aspectRatio > 0.7 && aspectRatio < 1.3) {
|
|
||||||
regions.push(region);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
// Remove overlapping regions to avoid "dozens of nonsense cards"
|
||||||
return regions;
|
const uniqueRegions = [];
|
||||||
};
|
regions.sort((a, b) => (a.width * a.height) - (b.width * b.height));
|
||||||
|
|
||||||
|
for (const region of regions) {
|
||||||
|
const isOverlapping = uniqueRegions.some(u =>
|
||||||
|
Math.abs(u.x - region.x) < 30 && Math.abs(u.y - region.y) < 30
|
||||||
|
);
|
||||||
|
if (!isOverlapping) {
|
||||||
|
uniqueRegions.push(region);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return uniqueRegions;
|
||||||
|
};
|
||||||
|
|
||||||
// Improved card region extraction with shape analysis for more accurate detection
|
// Improved card region extraction with shape analysis for more accurate detection
|
||||||
const getCardRegionWithShapeAnalysis = (imageData: ImageData, width: number, height: number, x: number, y: number, step: number): {x: number, y: number, width: number, height: number} | null => {
|
const getCardRegionWithShapeAnalysis = (imageData: ImageData, width: number, height: number, x: number, y: number, step: number): {x: number, y: number, width: number, height: number} | null => {
|
||||||
|
|
|
||||||
|
|
@ -77,14 +77,19 @@ const ResultsScreen: React.FC = () => {
|
||||||
|
|
||||||
<div className="cards-summary">
|
<div className="cards-summary">
|
||||||
<h3>Detected Cards</h3>
|
<h3>Detected Cards</h3>
|
||||||
<div className="cards-grid">
|
<div className="cards-grid">
|
||||||
{gameState.detectedCards.map(card => (
|
{gameState.detectedCards.map(card => (
|
||||||
<div key={card.id} className="card-preview">
|
<div key={card.id} className="card-preview">
|
||||||
<div className="card-suit">♠</div>
|
<div className="card-suit">
|
||||||
<div className="card-value">{card.value || '?'}pts</div>
|
{card.suit === 'Schellen' && '🔔'}
|
||||||
</div>
|
{card.suit === 'Schilten' && '🛡️'}
|
||||||
))}
|
{card.suit === 'Eicheln' && '🌰'}
|
||||||
</div>
|
{card.suit === 'Rosen' && '🌹'}
|
||||||
|
</div>
|
||||||
|
<div className="card-value">{card.value || '?'}pts</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ class CardModelService {
|
||||||
private readonly VALUE_MODEL_PATH = '/models/value_model/model.json';
|
private readonly VALUE_MODEL_PATH = '/models/value_model/model.json';
|
||||||
|
|
||||||
private readonly SUIT_LABELS = ['Schellen', 'Schilten', 'Eicheln', 'Rosen'];
|
private readonly SUIT_LABELS = ['Schellen', 'Schilten', 'Eicheln', 'Rosen'];
|
||||||
private readonly VALUE_LABELS = ['6', '7', '8', '9', '10', '11', '12', '13'];
|
private readonly VALUE_LABELS = ['6', '7', '8', '9', '10', '11', '12', '13', '14'];
|
||||||
|
|
||||||
async init(): Promise<void> {
|
async init(): Promise<void> {
|
||||||
try {
|
try {
|
||||||
|
|
|
||||||
82
train_model.py
Normal file
82
train_model.py
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
import os
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras import layers, models, optimizers
|
||||||
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
|
def create_model(num_classes):
|
||||||
|
# Use MobileNetV2 as the base model
|
||||||
|
base_model = tf.keras.applications.MobileNetV2(
|
||||||
|
input_shape=(64, 64, 3),
|
||||||
|
include_top=False,
|
||||||
|
weights='imagenet'
|
||||||
|
)
|
||||||
|
base_model.trainable = False # Freeze base model for transfer learning
|
||||||
|
|
||||||
|
model = models.Sequential([
|
||||||
|
base_model,
|
||||||
|
layers.GlobalAveragePooling2D(),
|
||||||
|
layers.Dropout(0.2),
|
||||||
|
layers.Dense(num_classes, activation='softmax')
|
||||||
|
])
|
||||||
|
|
||||||
|
model.compile(
|
||||||
|
optimizer=optimizers.Adam(),
|
||||||
|
loss='categorical_crossentropy',
|
||||||
|
metrics=['accuracy']
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def train(model_type, dataset_root, output_path):
|
||||||
|
# Set parameters based on model type
|
||||||
|
if model_type == 'suit':
|
||||||
|
num_classes = 4
|
||||||
|
elif model_type == 'value':
|
||||||
|
num_classes = 9
|
||||||
|
else:
|
||||||
|
raise ValueError("model_type must be 'suit' or 'value'")
|
||||||
|
|
||||||
|
# Data augmentation as per ML_SETUP_GUIDE.md
|
||||||
|
datagen = ImageDataGenerator(
|
||||||
|
rescale=1./255,
|
||||||
|
rotation_range=15,
|
||||||
|
brightness_range=[0.8, 1.2],
|
||||||
|
validation_split=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
train_generator = datagen.flow_from_directory(
|
||||||
|
os.path.join(dataset_root, f'{model_type}_model'),
|
||||||
|
target_size=(64, 64),
|
||||||
|
batch_size=32,
|
||||||
|
class_mode='categorical',
|
||||||
|
subset='training'
|
||||||
|
)
|
||||||
|
|
||||||
|
validation_generator = datagen.flow_from_directory(
|
||||||
|
os.path.join(dataset_root, f'{model_type}_model'),
|
||||||
|
target_size=(64, 64),
|
||||||
|
batch_size=32,
|
||||||
|
class_mode='categorical',
|
||||||
|
subset='validation'
|
||||||
|
)
|
||||||
|
|
||||||
|
model = create_model(num_classes)
|
||||||
|
|
||||||
|
print(f"Training {model_type} model...")
|
||||||
|
model.fit(
|
||||||
|
train_generator,
|
||||||
|
epochs=20,
|
||||||
|
validation_data=validation_generator
|
||||||
|
)
|
||||||
|
|
||||||
|
model.save(output_path)
|
||||||
|
print(f"Model saved to {output_path}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--type', choices=['suit', 'value'], required=True)
|
||||||
|
parser.add_argument('--dataset', default='dataset')
|
||||||
|
parser.add_argument('--output', required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
train(args.type, args.dataset, args.output)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue