+{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Utils","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyMYHshcegvtTBQwBygO/eoj"},"kernelspec":{"name":"python3","display_name":"Python 3"}},"cells":[{"cell_type":"code","metadata":{"id":"_ImvYera1GfS","colab_type":"code","colab":{},"executionInfo":{"status":"ok","timestamp":1595963420315,"user_tz":-330,"elapsed":2543,"user":{"displayName":"Agrover112","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiMJACGAX3kCfRjB2hgzdG8w9zL1lAAKbPPMz0qLA=s64","userId":"09574164879083471944"}}},"source":["import tensorflow as tf\n","import numpy as np\n","import matplotlib.pyplot as plt\n","def load_data():\n"," (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n"," x_train = np.reshape(x_train, (x_train.shape[0], 784))/255.\n"," x_test = np.reshape(x_test, (x_test.shape[0], 784))/255.\n"," y_train = tf.keras.utils.to_categorical(y_train)\n"," y_test = tf.keras.utils.to_categorical(y_test)\n"," return (x_train, y_train), (x_test, y_test)\n","\n","def plot_random_examples(x, y, p=None):\n"," indices = np.random.choice(range(0, x.shape[0]), 10)\n"," y = np.argmax(y, axis=1)\n"," if p is None:\n"," p = y\n"," plt.figure(figsize=(10, 5))\n"," for i, index in enumerate(indices):\n"," plt.subplot(2, 5, i+1)\n"," plt.imshow(x[index].reshape((28, 28)), cmap='binary')\n"," plt.xticks([])\n"," plt.yticks([])\n"," if y[index] == p[index]:\n"," col = 'g'\n"," else:\n"," col = 'r'\n"," plt.xlabel(str(p[index]), color=col)\n"," return plt\n","\n","def plot_results(history):\n"," history = history.history\n"," plt.figure(figsize=(12, 4))\n"," epochs = len(history['val_loss'])\n"," plt.subplot(1, 2, 1)\n"," plt.plot(range(epochs), history['val_loss'], label='Val Loss')\n"," plt.plot(range(epochs), history['loss'], label='Train Loss')\n"," plt.xticks(list(range(epochs)))\n"," plt.xlabel('Epochs')\n"," plt.ylabel('Loss')\n"," plt.legend()\n"," plt.subplot(1, 2, 2)\n"," plt.plot(range(epochs), history['val_accuracy'], label='Val Acc')\n"," plt.plot(range(epochs), history['accuracy'], label='Acc')\n"," plt.xticks(list(range(epochs)))\n"," plt.xlabel('Epochs')\n"," plt.ylabel('Accuracy')\n"," plt.legend()\n"," return plt"],"execution_count":1,"outputs":[]},{"cell_type":"code","metadata":{"id":"iRwkOk0p1SPt","colab_type":"code","colab":{}},"source":[""],"execution_count":null,"outputs":[]}]}
0 commit comments