Split into two files, removed unused code
This commit is contained in:
parent
88a3924637
commit
54db131db2
2 changed files with 56 additions and 113 deletions
150
flow.py
150
flow.py
|
@ -1,128 +1,52 @@
|
||||||
#!/usr/bin/python3
|
#!/usr/bin/python3
|
||||||
|
|
||||||
|
"""My tensorflow keras playground"""
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
|
|
||||||
import numpy as np
|
from graph import plot_training_acc
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
print("Running TensorFlow", tf.__version__)
|
print("Running TensorFlow", tf.__version__)
|
||||||
|
|
||||||
fashion_mnist = keras.datasets.fashion_mnist
|
|
||||||
|
|
||||||
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
|
def model():
|
||||||
|
model = keras.Sequential(
|
||||||
train_images = train_images / 255.0
|
[
|
||||||
test_images = test_images / 255.0
|
keras.layers.Flatten(input_shape=(28, 28)),
|
||||||
|
keras.layers.Dense(128, activation=tf.nn.relu),
|
||||||
class_names = [
|
keras.layers.Dense(10, activation=tf.nn.softmax),
|
||||||
"T-shirt/top",
|
]
|
||||||
"Trouser",
|
|
||||||
"Pullover",
|
|
||||||
"Dress",
|
|
||||||
"Coat",
|
|
||||||
"Sandal",
|
|
||||||
"Shirt",
|
|
||||||
"Sneaker",
|
|
||||||
"Bag",
|
|
||||||
"Ankle boot",
|
|
||||||
]
|
|
||||||
|
|
||||||
model = keras.Sequential(
|
|
||||||
[
|
|
||||||
keras.layers.Flatten(input_shape=(28, 28)),
|
|
||||||
keras.layers.Dense(256, activation=tf.nn.relu),
|
|
||||||
keras.layers.Dense(10, activation=tf.nn.softmax),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
model.compile(
|
|
||||||
optimizer=tf.train.AdamOptimizer(),
|
|
||||||
loss="sparse_categorical_crossentropy",
|
|
||||||
metrics=["accuracy"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_training(history):
|
|
||||||
acc = history.history["acc"]
|
|
||||||
val_acc = history.history["val_acc"]
|
|
||||||
|
|
||||||
epochs = range(1, len(acc) + 1)
|
|
||||||
|
|
||||||
plt.plot(epochs, acc, "bo", label="Training acc")
|
|
||||||
plt.plot(epochs, val_acc, "b", label="Validation acc")
|
|
||||||
plt.title("Training and validation accuracy")
|
|
||||||
plt.xlabel("Epochs")
|
|
||||||
plt.ylabel("Accuracy")
|
|
||||||
plt.legend()
|
|
||||||
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=5)
|
|
||||||
|
|
||||||
history = model.fit(
|
|
||||||
train_images,
|
|
||||||
train_labels,
|
|
||||||
epochs=64,
|
|
||||||
batch_size=512,
|
|
||||||
validation_data=(test_images, test_labels),
|
|
||||||
callbacks=[early_stop],
|
|
||||||
)
|
|
||||||
|
|
||||||
plot_training(history)
|
|
||||||
|
|
||||||
predictions = model.predict(test_images)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_image(i, predictions_array, true_label, img):
|
|
||||||
predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
|
|
||||||
plt.grid(False)
|
|
||||||
plt.xticks([])
|
|
||||||
plt.yticks([])
|
|
||||||
|
|
||||||
plt.imshow(img, cmap=plt.cm.binary)
|
|
||||||
|
|
||||||
predicted_label = np.argmax(predictions_array)
|
|
||||||
if predicted_label == true_label:
|
|
||||||
color = "blue"
|
|
||||||
else:
|
|
||||||
color = "red"
|
|
||||||
|
|
||||||
plt.xlabel(
|
|
||||||
"{} {:2.0f}% ({})".format(
|
|
||||||
class_names[predicted_label],
|
|
||||||
100 * np.max(predictions_array),
|
|
||||||
class_names[true_label],
|
|
||||||
),
|
|
||||||
color=color,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model.compile(
|
||||||
|
optimizer=tf.train.AdamOptimizer(),
|
||||||
|
loss="sparse_categorical_crossentropy",
|
||||||
|
metrics=["accuracy"],
|
||||||
|
)
|
||||||
|
|
||||||
def plot_value_array(i, predictions_array, true_label):
|
return model
|
||||||
predictions_array, true_label = predictions_array[i], true_label[i]
|
|
||||||
plt.grid(False)
|
|
||||||
plt.xticks([])
|
|
||||||
plt.yticks([])
|
|
||||||
thisplot = plt.bar(range(10), predictions_array, color="#777777")
|
|
||||||
plt.ylim([0, 1])
|
|
||||||
predicted_label = np.argmax(predictions_array)
|
|
||||||
|
|
||||||
thisplot[predicted_label].set_color("red")
|
|
||||||
thisplot[true_label].set_color("blue")
|
|
||||||
|
|
||||||
|
|
||||||
num_rows = 5
|
if __name__ == "__main__":
|
||||||
num_cols = 5
|
fashion_mnist = keras.datasets.fashion_mnist
|
||||||
num_images = num_rows * num_cols
|
|
||||||
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
|
|
||||||
for i in range(num_images):
|
|
||||||
image_idx = random.randint(0, len(test_images) - 1)
|
|
||||||
plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
|
|
||||||
plot_image(image_idx, predictions, test_labels, test_images)
|
|
||||||
plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
|
|
||||||
plot_value_array(image_idx, predictions, test_labels)
|
|
||||||
|
|
||||||
plt.show()
|
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
|
||||||
|
|
||||||
|
train_images = train_images / 255.0
|
||||||
|
test_images = test_images / 255.0
|
||||||
|
|
||||||
|
model = model()
|
||||||
|
|
||||||
|
early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=5)
|
||||||
|
|
||||||
|
history = model.fit(
|
||||||
|
train_images,
|
||||||
|
train_labels,
|
||||||
|
epochs=64,
|
||||||
|
batch_size=1024,
|
||||||
|
validation_data=(test_images, test_labels),
|
||||||
|
callbacks=[early_stop],
|
||||||
|
)
|
||||||
|
|
||||||
|
plot_training_acc(history)
|
||||||
|
|
19
graph.py
Normal file
19
graph.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
"""Tensorflow graphs"""
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
def plot_training_acc(history):
|
||||||
|
"""Plot training and validation accuracy"""
|
||||||
|
acc = history.history["acc"]
|
||||||
|
val_acc = history.history["val_acc"]
|
||||||
|
|
||||||
|
epochs = range(1, len(acc) + 1)
|
||||||
|
|
||||||
|
plt.plot(epochs, acc, "bo", label="Training acc")
|
||||||
|
plt.plot(epochs, val_acc, "b", label="Validation acc")
|
||||||
|
plt.title("Training and validation accuracy")
|
||||||
|
plt.xlabel("Epochs")
|
||||||
|
plt.ylabel("Accuracy")
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
plt.show()
|
Loading…
Reference in a new issue