"""Train the 'suit' or 'value' image classifier on a frozen MobileNetV2 backbone.

Command-line usage:
    --type {suit,value}   which classifier to train (required)
    --dataset DIR         dataset root (default: 'dataset')
    --output PATH         where the trained model is saved (required)
    --epochs N            training epochs (default: 20)
    --batch-size N        images per batch (default: 32)
"""
import argparse
import os

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Spatial size every image is resized to. Shared by the model input and the
# data generators so the two cannot drift apart.
IMG_SIZE = (64, 64)

# Fraction of each class's images held out for validation.
VALIDATION_SPLIT = 0.2

# Expected number of output classes per model type. flow_from_directory infers
# classes from sub-directory names, so train() verifies the dataset agrees.
NUM_CLASSES = {'suit': 4, 'value': 9}

# MobileNetV2's ImageNet weights expect pixels scaled to [-1, 1]; this Keras
# helper performs exactly that scaling. A plain 1/255 rescale (to [0, 1]) would
# feed the frozen backbone inputs it was never trained on.
_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input


def create_model(num_classes):
    """Build and compile a transfer-learning classifier on MobileNetV2.

    The ImageNet-pretrained backbone is frozen, so only the new
    pooling/dropout/dense head is trained.

    Args:
        num_classes: Number of units in the final softmax layer.

    Returns:
        A compiled ``tf.keras`` Sequential model taking ``(*IMG_SIZE, 3)``
        images preprocessed with ``mobilenet_v2.preprocess_input``.
    """
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(*IMG_SIZE, 3),
        include_top=False,
        weights='imagenet',
    )
    base_model.trainable = False  # Freeze base model for transfer learning.

    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(0.2),
        layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(
        optimizer=optimizers.Adam(),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def train(model_type, dataset_root, output_path, *, epochs=20, batch_size=32):
    """Train a classifier for ``model_type`` and save it to ``output_path``.

    Images are read from ``<dataset_root>/<model_type>_model/<class_name>/``.
    ``VALIDATION_SPLIT`` of each class is held out for validation.

    Args:
        model_type: Either ``'suit'`` or ``'value'``.
        dataset_root: Directory containing the ``<model_type>_model`` folder.
        output_path: Destination passed to ``model.save``.
        epochs: Number of training epochs.
        batch_size: Images per batch for both training and validation.

    Raises:
        ValueError: If ``model_type`` is unknown, or the dataset folder does not
            contain the expected number of class sub-directories.
    """
    if model_type not in NUM_CLASSES:
        raise ValueError("model_type must be 'suit' or 'value'")
    num_classes = NUM_CLASSES[model_type]
    data_dir = os.path.join(dataset_root, f'{model_type}_model')

    # Data augmentation as per ML_SETUP_GUIDE.md -- training images only.
    train_datagen = ImageDataGenerator(
        preprocessing_function=_preprocess,
        rotation_range=15,
        brightness_range=[0.8, 1.2],
        validation_split=VALIDATION_SPLIT,
    )
    # Validation images get the same preprocessing but no random augmentation,
    # so validation metrics reflect unperturbed data. With the same
    # validation_split on the same directory, Keras selects the same held-out
    # files as the training generator excludes.
    val_datagen = ImageDataGenerator(
        preprocessing_function=_preprocess,
        validation_split=VALIDATION_SPLIT,
    )

    flow_kwargs = dict(
        target_size=IMG_SIZE,
        batch_size=batch_size,
        class_mode='categorical',
    )
    train_generator = train_datagen.flow_from_directory(
        data_dir, subset='training', **flow_kwargs
    )
    validation_generator = val_datagen.flow_from_directory(
        data_dir, subset='validation', **flow_kwargs
    )

    # Fail early with a readable message instead of a shape error inside fit().
    if train_generator.num_classes != num_classes:
        raise ValueError(
            f"Expected {num_classes} class folders for '{model_type}' in "
            f"{data_dir}, found {train_generator.num_classes}: "
            f"{sorted(train_generator.class_indices)}"
        )

    # Inference code needs this mapping to turn softmax indices into labels.
    print(f"Class indices: {train_generator.class_indices}")

    model = create_model(num_classes)

    print(f"Training {model_type} model...")
    model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
    )

    model.save(output_path)
    print(f"Model saved to {output_path}")


def main():
    """Parse command-line arguments and run ``train``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--type', choices=['suit', 'value'], required=True)
    parser.add_argument('--dataset', default='dataset')
    parser.add_argument('--output', required=True)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch-size', type=int, default=32)
    args = parser.parse_args()

    train(
        args.type,
        args.dataset,
        args.output,
        epochs=args.epochs,
        batch_size=args.batch_size,
    )


if __name__ == "__main__":
    main()