- "print(__doc__)\n\n# Authors: Yann N. Dauphin, Vlad Niculae, Gabriel Synnaeve\n# License: BSD\n\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom scipy.ndimage import convolve\nfrom sklearn import linear_model, datasets, metrics\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.neural_network import BernoulliRBM\nfrom sklearn.pipeline import Pipeline\nfrom sklearn.base import clone\n\n\n# #############################################################################\n# Setting up\n\ndef nudge_dataset(X, Y):\n \"\"\"\n This produces a dataset 5 times bigger than the original one,\n by moving the 8x8 images in X around by 1px to left, right, down, up\n \"\"\"\n direction_vectors = [\n [[0, 1, 0],\n [0, 0, 0],\n [0, 0, 0]],\n\n [[0, 0, 0],\n [1, 0, 0],\n [0, 0, 0]],\n\n [[0, 0, 0],\n [0, 0, 1],\n [0, 0, 0]],\n\n [[0, 0, 0],\n [0, 0, 0],\n [0, 1, 0]]]\n\n def shift(x, w):\n return convolve(x.reshape((8, 8)), mode='constant', weights=w).ravel()\n\n X = np.concatenate([X] +\n [np.apply_along_axis(shift, 1, X, vector)\n for vector in direction_vectors])\n Y = np.concatenate([Y for _ in range(5)], axis=0)\n return X, Y\n\n\n# Load Data\ndigits = datasets.load_digits()\nX = np.asarray(digits.data, 'float32')\nX, Y = nudge_dataset(X, digits.target)\nX = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001) # 0-1 scaling\n\nX_train, X_test, Y_train, Y_test = train_test_split(\n X, Y, test_size=0.2, random_state=0)\n\n# Models we will use\nlogistic = linear_model.LogisticRegression(solver='newton-cg', tol=1,\n multi_class='multinomial')\nrbm = BernoulliRBM(random_state=0, verbose=True)\n\nrbm_features_classifier = Pipeline(\n steps=[('rbm', rbm), ('logistic', logistic)])\n\n# #############################################################################\n# Training\n\n# Hyper-parameters. These were set by cross-validation,\n# using a GridSearchCV. Here we are not performing cross-validation to\n# save time.\nrbm.learning_rate = 0.06\nrbm.n_iter = 20\n# More components tend to give better prediction performance, but larger\n# fitting time\nrbm.n_components = 100\nlogistic.C = 6000\n\n# Training RBM-Logistic Pipeline\nrbm_features_classifier.fit(X_train, Y_train)\n\n# Training the Logistic regression classifier directly on the pixel\nraw_pixel_classifier = clone(logistic)\nraw_pixel_classifier.C = 100.\nraw_pixel_classifier.fit(X_train, Y_train)\n\n# #############################################################################\n# Evaluation\n\nY_pred = rbm_features_classifier.predict(X_test)\nprint(\"Logistic regression using RBM features:\\n%s\\n\" % (\n metrics.classification_report(Y_test, Y_pred)))\n\nY_pred = raw_pixel_classifier.predict(X_test)\nprint(\"Logistic regression using raw pixel features:\\n%s\\n\" % (\n metrics.classification_report(Y_test, Y_pred)))\n\n# #############################################################################\n# Plotting\n\nplt.figure(figsize=(4.2, 4))\nfor i, comp in enumerate(rbm.components_):\n plt.subplot(10, 10, i + 1)\n plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,\n interpolation='nearest')\n plt.xticks(())\n plt.yticks(())\nplt.suptitle('100 components extracted by RBM', fontsize=16)\nplt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)\n\nplt.show()"
0 commit comments