Update ML setup guide and add model scripts
This commit is contained in:
parent
073d395ae6
commit
ff1561a704
4 changed files with 121 additions and 3 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -1,3 +1,7 @@
|
|||
/venv
|
||||
/dataset
|
||||
*.h5
|
||||
|
||||
# Dependencies
|
||||
node_modules/
|
||||
.pnp/
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ The system uses a **Hybrid Pipeline** to balance performance on mobile devices w
|
|||
2. **Classification**: Crops each detected card and passes it through two specialized TensorFlow.js models to determine the **Suit** and the **Value**.
|
||||
|
||||
## 2. Data Collection & Labeling
|
||||
Because Jass cards have unique iconography, we use a custom dataset created from internet samples.
|
||||
Because Jass cards have unique iconography, we use a custom dataset.
|
||||
|
||||
### Labeling Strategy: Folder-Based Annotation
|
||||
Instead of using bounding-box tools, we use the directory structure as labels.
|
||||
|
|
@ -29,7 +29,7 @@ dataset/
|
|||
├── 6/
|
||||
├── 7/
|
||||
...
|
||||
└── 13/
|
||||
└── 14/
|
||||
```
|
||||
|
||||
### Process:
|
||||
|
|
|
|||
32
convert_model.sh
Executable file
32
convert_model.sh
Executable file
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Check if tensorflowjs is installed
|
||||
if ! command -v tensorflowjs_converter &> /dev/null
|
||||
then
|
||||
echo "tensorflowjs_converter could not be found. Please install it using: pip install tensorflowjs"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
MODEL_TYPE=$1
|
||||
INPUT_MODEL=$2
|
||||
OUTPUT_DIR=$3
|
||||
|
||||
if [ -z "$MODEL_TYPE" ] || [ -z "$INPUT_MODEL" ] || [ -z "$OUTPUT_DIR" ]; then
|
||||
echo "Usage: ./convert_model.sh [suit|value] [path_to_h5_model] [output_directory]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Converting $MODEL_TYPE model from $INPUT_MODEL to $OUTPUT_DIR..."
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Perform conversion
|
||||
tensorflowjs_converter --input_format=keras "$INPUT_MODEL" "$OUTPUT_DIR"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Conversion successful. Model saved to $OUTPUT_DIR"
|
||||
else
|
||||
echo "Conversion failed."
|
||||
exit 1
|
||||
fi
|
||||
82
train_model.py
Normal file
82
train_model.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
import os
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import layers, models, optimizers
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
def create_model(num_classes):
|
||||
# Use MobileNetV2 as the base model
|
||||
base_model = tf.keras.applications.MobileNetV2(
|
||||
input_shape=(64, 64, 3),
|
||||
include_top=False,
|
||||
weights='imagenet'
|
||||
)
|
||||
base_model.trainable = False # Freeze base model for transfer learning
|
||||
|
||||
model = models.Sequential([
|
||||
base_model,
|
||||
layers.GlobalAveragePooling2D(),
|
||||
layers.Dropout(0.2),
|
||||
layers.Dense(num_classes, activation='softmax')
|
||||
])
|
||||
|
||||
model.compile(
|
||||
optimizer=optimizers.Adam(),
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy']
|
||||
)
|
||||
return model
|
||||
|
||||
def train(model_type, dataset_root, output_path):
|
||||
# Set parameters based on model type
|
||||
if model_type == 'suit':
|
||||
num_classes = 4
|
||||
elif model_type == 'value':
|
||||
num_classes = 9
|
||||
else:
|
||||
raise ValueError("model_type must be 'suit' or 'value'")
|
||||
|
||||
# Data augmentation as per ML_SETUP_GUIDE.md
|
||||
datagen = ImageDataGenerator(
|
||||
rescale=1./255,
|
||||
rotation_range=15,
|
||||
brightness_range=[0.8, 1.2],
|
||||
validation_split=0.2
|
||||
)
|
||||
|
||||
train_generator = datagen.flow_from_directory(
|
||||
os.path.join(dataset_root, f'{model_type}_model'),
|
||||
target_size=(64, 64),
|
||||
batch_size=32,
|
||||
class_mode='categorical',
|
||||
subset='training'
|
||||
)
|
||||
|
||||
validation_generator = datagen.flow_from_directory(
|
||||
os.path.join(dataset_root, f'{model_type}_model'),
|
||||
target_size=(64, 64),
|
||||
batch_size=32,
|
||||
class_mode='categorical',
|
||||
subset='validation'
|
||||
)
|
||||
|
||||
model = create_model(num_classes)
|
||||
|
||||
print(f"Training {model_type} model...")
|
||||
model.fit(
|
||||
train_generator,
|
||||
epochs=20,
|
||||
validation_data=validation_generator
|
||||
)
|
||||
|
||||
model.save(output_path)
|
||||
print(f"Model saved to {output_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--type', choices=['suit', 'value'], required=True)
|
||||
parser.add_argument('--dataset', default='dataset')
|
||||
parser.add_argument('--output', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
train(args.type, args.dataset, args.output)
|
||||
Loading…
Add table
Add a link
Reference in a new issue