FashionFlow/flow.py

53 lines
1.2 KiB
Python
Raw Normal View History

2018-12-27 15:47:45 +00:00
#!/usr/bin/python3
"""My tensorflow keras playground"""
2018-12-27 15:47:45 +00:00
import tensorflow as tf
from tensorflow import keras
from graph import plot_training_acc
2018-12-27 15:47:45 +00:00
2018-12-29 22:55:23 +00:00
# Report the TensorFlow version in use; this runs whenever the module is loaded.
print("Running TensorFlow", tf.__version__)
2018-12-27 15:47:45 +00:00
def model():
    """Build and compile a small dense classifier for 28x28 inputs.

    The network flattens each 28x28 input, feeds it through a 128-unit
    ReLU layer and ends in a 10-way softmax layer.

    Returns:
        A compiled ``keras.Sequential`` model that uses the Adam optimizer,
        sparse categorical cross-entropy loss and the accuracy metric.
    """
    # Use a distinct local name so the function's own name is not shadowed.
    network = keras.Sequential(
        [
            keras.layers.Flatten(input_shape=(28, 28)),
            keras.layers.Dense(128, activation=tf.nn.relu),
            keras.layers.Dense(10, activation=tf.nn.softmax),
        ]
    )
    network.compile(
        # Keras resolves the "adam" identifier itself, so this works on both
        # TF 1.x and TF 2.x; tf.train.AdamOptimizer is not available in the
        # top-level TF 2.x namespace.
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return network
2018-12-27 21:29:24 +00:00
if __name__ == "__main__":
    fashion_mnist = keras.datasets.fashion_mnist

    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

    # Scale the inputs by 1/255 (presumably 8-bit pixel intensities, which
    # would map them into [0, 1] -- confirm against the dataset docs).
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    # Bind the built network to its own name; assigning it to `model` would
    # replace the factory function with its return value.
    classifier = model()

    # Stop once validation loss has not improved for 5 consecutive epochs.
    early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=5)

    # NOTE(review): the test split doubles as validation data here, so early
    # stopping is tuned on the same data used for evaluation; consider a
    # separate held-out validation split.
    history = classifier.fit(
        train_images,
        train_labels,
        epochs=64,
        batch_size=1024,
        validation_data=(test_images, test_labels),
        callbacks=[early_stop],
    )

    plot_training_acc(history)