diff --git a/.gitignore b/.gitignore
index aa41af3..a9f84f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+/venv
+/dataset
+*.h5
+
 # Dependencies
 node_modules/
 .pnp/
@@ -79,4 +83,4 @@ yarn-cache/
 *.sublime-*
 .sass-cache
 parcel-cache.json
-.pnp.*
\ No newline at end of file
+.pnp.*
diff --git a/ML_SETUP_GUIDE.md b/ML_SETUP_GUIDE.md
index 436bb61..235bb0b 100644
--- a/ML_SETUP_GUIDE.md
+++ b/ML_SETUP_GUIDE.md
@@ -12,7 +12,7 @@ The system uses a **Hybrid Pipeline** to balance performance on mobile devices w
 2. **Classification**: Crops each detected card and passes it through two specialized TensorFlow.js models to determine the **Suit** and the **Value**.
 
 ## 2. Data Collection & Labeling
-Because Jass cards have unique iconography, we use a custom dataset created from internet samples.
+Because Jass cards have unique iconography, we use a custom dataset.
 
 ### Labeling Strategy: Folder-Based Annotation
 Instead of using bounding-box tools, we use the directory structure as labels.
@@ -29,7 +29,7 @@ dataset/
 ├── 6/
 ├── 7/
 ...
-└── 13/
+└── 14/
 ```
 
 ### Process:
diff --git a/convert_model.sh b/convert_model.sh
new file mode 100755
index 0000000..28c4a34
--- /dev/null
+++ b/convert_model.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Check if tensorflowjs is installed
+if ! command -v tensorflowjs_converter &> /dev/null
+then
+    echo "tensorflowjs_converter could not be found. Please install it using: pip install tensorflowjs"
+    exit 1
+fi
+
+MODEL_TYPE=$1
+INPUT_MODEL=$2
+OUTPUT_DIR=$3
+
+if [ -z "$MODEL_TYPE" ] || [ -z "$INPUT_MODEL" ] || [ -z "$OUTPUT_DIR" ]; then
+    echo "Usage: ./convert_model.sh [suit|value] [path_to_h5_model] [output_directory]"
+    exit 1
+fi
+
+echo "Converting $MODEL_TYPE model from $INPUT_MODEL to $OUTPUT_DIR..."
+
+# Create output directory if it doesn't exist
+mkdir -p "$OUTPUT_DIR"
+
+# Perform conversion
+tensorflowjs_converter --input_format=keras "$INPUT_MODEL" "$OUTPUT_DIR"
+
+if [ $? -eq 0 ]; then
+    echo "Conversion successful. Model saved to $OUTPUT_DIR"
+else
+    echo "Conversion failed."
+    exit 1
+fi
diff --git a/train_model.py b/train_model.py
new file mode 100644
index 0000000..e6a3c16
--- /dev/null
+++ b/train_model.py
@@ -0,0 +1,82 @@
+import os
+import tensorflow as tf
+from tensorflow.keras import layers, models, optimizers
+from tensorflow.keras.preprocessing.image import ImageDataGenerator
+
+def create_model(num_classes):
+    # Use MobileNetV2 as the base model
+    base_model = tf.keras.applications.MobileNetV2(
+        input_shape=(64, 64, 3),
+        include_top=False,
+        weights='imagenet'
+    )
+    base_model.trainable = False  # Freeze base model for transfer learning
+
+    model = models.Sequential([
+        base_model,
+        layers.GlobalAveragePooling2D(),
+        layers.Dropout(0.2),
+        layers.Dense(num_classes, activation='softmax')
+    ])
+
+    model.compile(
+        optimizer=optimizers.Adam(),
+        loss='categorical_crossentropy',
+        metrics=['accuracy']
+    )
+    return model
+
+def train(model_type, dataset_root, output_path):
+    # Set parameters based on model type
+    if model_type == 'suit':
+        num_classes = 4
+    elif model_type == 'value':
+        num_classes = 9
+    else:
+        raise ValueError("model_type must be 'suit' or 'value'")
+
+    # Data augmentation as per ML_SETUP_GUIDE.md
+    datagen = ImageDataGenerator(
+        rescale=1./255,
+        rotation_range=15,
+        brightness_range=[0.8, 1.2],
+        validation_split=0.2
+    )
+
+    train_generator = datagen.flow_from_directory(
+        os.path.join(dataset_root, f'{model_type}_model'),
+        target_size=(64, 64),
+        batch_size=32,
+        class_mode='categorical',
+        subset='training'
+    )
+
+    validation_generator = datagen.flow_from_directory(
+        os.path.join(dataset_root, f'{model_type}_model'),
+        target_size=(64, 64),
+        batch_size=32,
+        class_mode='categorical',
+        subset='validation'
+    )
+
+    model = create_model(num_classes)
+
+    print(f"Training {model_type} model...")
+    model.fit(
+        train_generator,
+        epochs=20,
+        validation_data=validation_generator
+    )
+
+    model.save(output_path)
+    print(f"Model saved to {output_path}")
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--type', choices=['suit', 'value'], required=True)
+    parser.add_argument('--dataset', default='dataset')
+    parser.add_argument('--output', required=True)
+    args = parser.parse_args()
+
+    train(args.type, args.dataset, args.output)
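
For reference, a minimal sketch of how the two new scripts are meant to chain together, assuming the folder layout from ML_SETUP_GUIDE.md (`dataset/suit_model/`, `dataset/value_model/`) and that `tensorflow` and `tensorflowjs` are installed in the local `venv`; the `.h5` filenames and TF.js output directories below are illustrative, not fixed by this change.

```bash
# Train the two classifiers (output .h5 paths are hypothetical examples)
python train_model.py --type suit --dataset dataset --output suit_model.h5
python train_model.py --type value --dataset dataset --output value_model.h5

# Convert the Keras .h5 models to TensorFlow.js format (output dirs are hypothetical)
./convert_model.sh suit suit_model.h5 models/suit
./convert_model.sh value value_model.h5 models/value
```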