diff --git a/flow.py b/flow.py index 2e774c1..6a119fd 100755 --- a/flow.py +++ b/flow.py @@ -1,128 +1,52 @@ #!/usr/bin/python3 +"""My tensorflow keras playground""" + import tensorflow as tf from tensorflow import keras -import numpy as np -import matplotlib.pyplot as plt - -import random +from graph import plot_training_acc print("Running TensorFlow", tf.__version__) -fashion_mnist = keras.datasets.fashion_mnist -(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() - -train_images = train_images / 255.0 -test_images = test_images / 255.0 - -class_names = [ - "T-shirt/top", - "Trouser", - "Pullover", - "Dress", - "Coat", - "Sandal", - "Shirt", - "Sneaker", - "Bag", - "Ankle boot", -] - -model = keras.Sequential( - [ - keras.layers.Flatten(input_shape=(28, 28)), - keras.layers.Dense(256, activation=tf.nn.relu), - keras.layers.Dense(10, activation=tf.nn.softmax), - ] -) - -model.compile( - optimizer=tf.train.AdamOptimizer(), - loss="sparse_categorical_crossentropy", - metrics=["accuracy"], -) - - -def plot_training(history): - acc = history.history["acc"] - val_acc = history.history["val_acc"] - - epochs = range(1, len(acc) + 1) - - plt.plot(epochs, acc, "bo", label="Training acc") - plt.plot(epochs, val_acc, "b", label="Validation acc") - plt.title("Training and validation accuracy") - plt.xlabel("Epochs") - plt.ylabel("Accuracy") - plt.legend() - - plt.show() - - -early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=5) - -history = model.fit( - train_images, - train_labels, - epochs=64, - batch_size=512, - validation_data=(test_images, test_labels), - callbacks=[early_stop], -) - -plot_training(history) - -predictions = model.predict(test_images) - - -def plot_image(i, predictions_array, true_label, img): - predictions_array, true_label, img = predictions_array[i], true_label[i], img[i] - plt.grid(False) - plt.xticks([]) - plt.yticks([]) - - plt.imshow(img, cmap=plt.cm.binary) - - predicted_label = np.argmax(predictions_array) - if predicted_label == true_label: - color = "blue" - else: - color = "red" - - plt.xlabel( - "{} {:2.0f}% ({})".format( - class_names[predicted_label], - 100 * np.max(predictions_array), - class_names[true_label], - ), - color=color, +def model(): + model = keras.Sequential( + [ + keras.layers.Flatten(input_shape=(28, 28)), + keras.layers.Dense(128, activation=tf.nn.relu), + keras.layers.Dense(10, activation=tf.nn.softmax), + ] ) + model.compile( + optimizer=tf.train.AdamOptimizer(), + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) -def plot_value_array(i, predictions_array, true_label): - predictions_array, true_label = predictions_array[i], true_label[i] - plt.grid(False) - plt.xticks([]) - plt.yticks([]) - thisplot = plt.bar(range(10), predictions_array, color="#777777") - plt.ylim([0, 1]) - predicted_label = np.argmax(predictions_array) - - thisplot[predicted_label].set_color("red") - thisplot[true_label].set_color("blue") + return model -num_rows = 5 -num_cols = 5 -num_images = num_rows * num_cols -plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows)) -for i in range(num_images): - image_idx = random.randint(0, len(test_images) - 1) - plt.subplot(num_rows, 2 * num_cols, 2 * i + 1) - plot_image(image_idx, predictions, test_labels, test_images) - plt.subplot(num_rows, 2 * num_cols, 2 * i + 2) - plot_value_array(image_idx, predictions, test_labels) +if __name__ == "__main__": + fashion_mnist = keras.datasets.fashion_mnist -plt.show() + (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() + + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + model = model() + + early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=5) + + history = model.fit( + train_images, + train_labels, + epochs=64, + batch_size=1024, + validation_data=(test_images, test_labels), + callbacks=[early_stop], + ) + + plot_training_acc(history) diff --git a/graph.py b/graph.py new file mode 100644 index 0000000..485b978 --- /dev/null +++ b/graph.py @@ -0,0 +1,19 @@ +"""Tensorflow graphs""" + +import matplotlib.pyplot as plt + +def plot_training_acc(history): + """Plot training and validation accuracy""" + acc = history.history["acc"] + val_acc = history.history["val_acc"] + + epochs = range(1, len(acc) + 1) + + plt.plot(epochs, acc, "bo", label="Training acc") + plt.plot(epochs, val_acc, "b", label="Validation acc") + plt.title("Training and validation accuracy") + plt.xlabel("Epochs") + plt.ylabel("Accuracy") + plt.legend() + + plt.show()