82 lines
2.3 KiB
Python
82 lines
2.3 KiB
Python
import os
|
|
import tensorflow as tf
|
|
from tensorflow.keras import layers, models, optimizers
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
|
|
def create_model(num_classes, input_shape=(64, 64, 3), dropout_rate=0.2):
    """Build and compile a frozen-MobileNetV2 transfer-learning classifier.

    Architecture: input -> rescale [0, 1] to [-1, 1] -> MobileNetV2 base
    (ImageNet weights, no top, frozen) -> global average pooling -> dropout
    -> softmax over ``num_classes``.

    Args:
        num_classes: Number of output classes (size of the softmax layer).
        input_shape: ``(height, width, channels)`` of the input images.
            Defaults to ``(64, 64, 3)``.
        dropout_rate: Dropout rate applied before the classification head.
            Defaults to ``0.2``.

    Returns:
        A ``tf.keras`` model compiled with Adam, categorical cross-entropy and
        accuracy. It expects float pixel values in ``[0, 1]`` (e.g. images
        loaded with ``rescale=1./255``) and one-hot encoded labels.
    """
    # Use MobileNetV2 as the base model.
    # NOTE(review): Keras publishes ImageNet weights for 96-224 px inputs only and
    # warns for other sizes (falling back to the 224 px weights) -- confirm that a
    # 64 px input is intended.
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet',
    )
    base_model.trainable = False  # Freeze base model for transfer learning

    model = models.Sequential([
        layers.Input(shape=input_shape),
        # The ImageNet MobileNetV2 weights were trained on pixels in [-1, 1]
        # (what mobilenet_v2.preprocess_input produces). Map [0, 1] inputs onto
        # that range inside the model so training and inference agree.
        layers.Rescaling(2.0, offset=-1.0),
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(dropout_rate),
        layers.Dense(num_classes, activation='softmax'),
    ])

    model.compile(
        optimizer=optimizers.Adam(),
        loss='categorical_crossentropy',  # labels are one-hot ('categorical' class_mode)
        metrics=['accuracy'],
    )
    return model
|
|
|
|
# Size of the softmax head for each supported ``model_type``.
_NUM_CLASSES_BY_TYPE = {'suit': 4, 'value': 9}


def _num_classes_for(model_type):
    """Return the number of classes for *model_type*; raise ValueError if unknown."""
    if model_type not in _NUM_CLASSES_BY_TYPE:
        raise ValueError("model_type must be 'suit' or 'value'")
    return _NUM_CLASSES_BY_TYPE[model_type]


def _make_generators(data_dir, target_size, batch_size, validation_split=0.2):
    """Return (augmented training iterator, rescale-only validation iterator) over *data_dir*."""
    # Data augmentation as per ML_SETUP_GUIDE.md -- applied to training images only.
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=15,
        brightness_range=[0.8, 1.2],
        validation_split=validation_split,
    )
    # Validation images are only rescaled. Augmenting them would make validation
    # metrics random and not comparable across epochs. Keras selects each subset
    # deterministically from the sorted file list, so the same validation_split on
    # the same directory keeps the training and validation subsets disjoint.
    validation_datagen = ImageDataGenerator(
        rescale=1. / 255,
        validation_split=validation_split,
    )
    common = dict(target_size=target_size, batch_size=batch_size, class_mode='categorical')
    train_generator = train_datagen.flow_from_directory(data_dir, subset='training', **common)
    validation_generator = validation_datagen.flow_from_directory(
        data_dir, subset='validation', shuffle=False, **common
    )
    return train_generator, validation_generator


def train(model_type, dataset_root, output_path, *, epochs=20, batch_size=32):
    """Train the *model_type* classifier and save it to *output_path*.

    Images are read with ``flow_from_directory`` from
    ``<dataset_root>/<model_type>_model/`` (one sub-directory per class). 20% of
    each class is held out for validation.

    Args:
        model_type: ``'suit'`` (4 classes) or ``'value'`` (9 classes).
        dataset_root: Directory that contains the ``<model_type>_model`` folder.
        output_path: Destination passed to ``model.save``.
        epochs: Number of training epochs (default 20).
        batch_size: Images per batch (default 32).

    Raises:
        ValueError: If *model_type* is unknown, no training images are found, or
            the number of class sub-directories does not match *model_type*.
    """
    num_classes = _num_classes_for(model_type)
    data_dir = os.path.join(dataset_root, f'{model_type}_model')

    # (64, 64) must match create_model()'s (64, 64, 3) input shape.
    train_generator, validation_generator = _make_generators(data_dir, (64, 64), batch_size)

    if train_generator.samples == 0:
        raise ValueError(f"No training images found under {data_dir}")
    # flow_from_directory creates one class per sub-directory. A mismatch with the
    # softmax size would otherwise surface as an opaque shape error inside fit().
    if train_generator.num_classes != num_classes:
        raise ValueError(
            f"{model_type} model expects {num_classes} classes but found "
            f"{train_generator.num_classes} in {data_dir}: "
            f"{sorted(train_generator.class_indices)}"
        )
    # The index -> label mapping is needed to interpret predictions later.
    print(f"Class indices: {train_generator.class_indices}")

    model = create_model(num_classes)

    print(f"Training {model_type} model...")
    model.fit(
        train_generator,
        epochs=epochs,
        # Skip validation rather than crash when the split left no held-out images.
        validation_data=validation_generator if validation_generator.samples else None,
    )

    model.save(output_path)
    print(f"Model saved to {output_path}")
|
|
|
|
def build_parser():
    """Return the command-line parser for this training script."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Train the 'suit' or 'value' MobileNetV2 image classifier."
    )
    parser.add_argument('--type', choices=['suit', 'value'], required=True,
                        help="which classifier to train")
    parser.add_argument('--dataset', default='dataset',
                        help="dataset root containing <type>_model/ with one "
                             "sub-directory per class (default: %(default)s)")
    parser.add_argument('--output', required=True,
                        help="path the trained model is saved to")
    return parser


def main(argv=None):
    """Parse *argv* (``sys.argv[1:]`` when None) and run training."""
    args = build_parser().parse_args(argv)
    train(args.type, args.dataset, args.output)


if __name__ == "__main__":
    main()
|