From 6d57fa4650899f055bba2760e9a9304d89b1bb22 Mon Sep 17 00:00:00 2001 From: Michael Mandl Date: Thu, 27 Dec 2018 22:29:24 +0100 Subject: [PATCH] Plot image classification results --- flow.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/flow.py b/flow.py index 62a8ca3..ff80b37 100755 --- a/flow.py +++ b/flow.py @@ -49,3 +49,56 @@ model.fit(train_images, train_labels, epochs=5) test_loss, test_acc = model.evaluate(test_images, test_labels) print("Test accuracy:", test_acc) + +predictions = model.predict(test_images) + + +def plot_image(i, predictions_array, true_label, img): + predictions_array, true_label, img = predictions_array[i], true_label[i], img[i] + plt.grid(False) + plt.xticks([]) + plt.yticks([]) + + plt.imshow(img, cmap=plt.cm.binary) + + predicted_label = np.argmax(predictions_array) + if predicted_label == true_label: + color = "blue" + else: + color = "red" + + plt.xlabel( + "{} {:2.0f}% ({})".format( + class_names[predicted_label], + 100 * np.max(predictions_array), + class_names[true_label], + ), + color=color, + ) + + +def plot_value_array(i, predictions_array, true_label): + predictions_array, true_label = predictions_array[i], true_label[i] + plt.grid(False) + plt.xticks([]) + plt.yticks([]) + thisplot = plt.bar(range(10), predictions_array, color="#777777") + plt.ylim([0, 1]) + predicted_label = np.argmax(predictions_array) + + thisplot[predicted_label].set_color("red") + thisplot[true_label].set_color("blue") + + +num_rows = 5 +num_cols = 5 +num_images = num_rows * num_cols +plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows)) +for i in range(num_images): + image_idx = random.randint(0, len(test_images) - 1) + plt.subplot(num_rows, 2 * num_cols, 2 * i + 1) + plot_image(image_idx, predictions, test_labels, test_images) + plt.subplot(num_rows, 2 * num_cols, 2 * i + 2) + plot_value_array(image_idx, predictions, test_labels) + +plt.show()