From 88a3924637221d0dac8cfe535b05bd6572a51559 Mon Sep 17 00:00:00 2001 From: Michael Mandl Date: Sat, 29 Dec 2018 23:55:23 +0100 Subject: [PATCH] Graph out training results --- flow.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/flow.py b/flow.py index ff80b37..2e774c1 100755 --- a/flow.py +++ b/flow.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt import random -print(tf.__version__) +print("Running TensorFlow", tf.__version__) fashion_mnist = keras.datasets.fashion_mnist @@ -33,7 +33,7 @@ class_names = [ model = keras.Sequential( [ keras.layers.Flatten(input_shape=(28, 28)), - keras.layers.Dense(128, activation=tf.nn.relu), + keras.layers.Dense(256, activation=tf.nn.relu), keras.layers.Dense(10, activation=tf.nn.softmax), ] ) @@ -44,11 +44,35 @@ model.compile( metrics=["accuracy"], ) -model.fit(train_images, train_labels, epochs=5) -test_loss, test_acc = model.evaluate(test_images, test_labels) +def plot_training(history): + acc = history.history["acc"] + val_acc = history.history["val_acc"] -print("Test accuracy:", test_acc) + epochs = range(1, len(acc) + 1) + + plt.plot(epochs, acc, "bo", label="Training acc") + plt.plot(epochs, val_acc, "b", label="Validation acc") + plt.title("Training and validation accuracy") + plt.xlabel("Epochs") + plt.ylabel("Accuracy") + plt.legend() + + plt.show() + + +early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=5) + +history = model.fit( + train_images, + train_labels, + epochs=64, + batch_size=512, + validation_data=(test_images, test_labels), + callbacks=[early_stop], +) + +plot_training(history) predictions = model.predict(test_images)