Initial commit
This commit is contained in:
68
training/train.py
Normal file
68
training/train.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import tensorflow as tf
|
||||
from keras import layers, models
|
||||
from keras.preprocessing import image
|
||||
import pathlib
|
||||
import numpy as np
|
||||
|
||||
IMG_SIZE = 64
|
||||
BATCH_SIZE = 8
|
||||
NUM_CLASSES = 4
|
||||
EPOCHS = 10
|
||||
|
||||
data_dir = pathlib.Path("data")
|
||||
|
||||
train_ds = tf.keras.utils.image_dataset_from_directory(
|
||||
data_dir,
|
||||
labels='inferred',
|
||||
label_mode='categorical',
|
||||
color_mode='grayscale',
|
||||
batch_size=BATCH_SIZE,
|
||||
image_size=(IMG_SIZE, IMG_SIZE),
|
||||
validation_split=0.2,
|
||||
subset="training",
|
||||
seed=123
|
||||
)
|
||||
|
||||
val_ds = tf.keras.utils.image_dataset_from_directory(
|
||||
data_dir,
|
||||
labels='inferred',
|
||||
label_mode='categorical',
|
||||
color_mode='grayscale',
|
||||
batch_size=BATCH_SIZE,
|
||||
image_size=(IMG_SIZE, IMG_SIZE),
|
||||
validation_split=0.2,
|
||||
subset="validation",
|
||||
seed=123
|
||||
)
|
||||
|
||||
AUTOTUNE = tf.data.AUTOTUNE
|
||||
train_ds = train_ds.cache().shuffle(100).prefetch(buffer_size=AUTOTUNE)
|
||||
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
|
||||
|
||||
model = models.Sequential([
|
||||
layers.Rescaling(1/255, input_shape=(IMG_SIZE, IMG_SIZE, 1)),
|
||||
layers.Conv2D(32, (3,3), activation='relu'),
|
||||
layers.MaxPooling2D(2,2),
|
||||
|
||||
layers.Conv2D(64, (3,3), activation='relu'),
|
||||
layers.MaxPooling2D(2,2),
|
||||
|
||||
layers.Flatten(),
|
||||
layers.Dense(64, activation='relu'),
|
||||
layers.Dense(NUM_CLASSES, activation='softmax')
|
||||
])
|
||||
|
||||
model.compile(
|
||||
optimizer='adam',
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy']
|
||||
)
|
||||
|
||||
model.fit(
|
||||
train_ds,
|
||||
validation_data=val_ds,
|
||||
epochs=EPOCHS
|
||||
)
|
||||
|
||||
model.save("shape_model.keras")
|
||||
print("Saved model")
|
||||
Reference in New Issue
Block a user