|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 22, |
| 6 | + "metadata": { |
| 7 | + "collapsed": true |
| 8 | + }, |
| 9 | + "outputs": [], |
| 10 | + "source": [ |
| 11 | + "# 全家桶 使用 Pytorch 必备 具体功能且看下文\n", |
| 12 | + "import torch\n", |
| 13 | + "import torch.nn as nn\n", |
| 14 | + "import torch.nn.functional as F\n", |
| 15 | + "import torch.optim as optim\n", |
| 16 | + "from torch.autograd import Variable" |
| 17 | + ] |
| 18 | + }, |
| 19 | + { |
| 20 | + "cell_type": "code", |
| 21 | + "execution_count": 62, |
| 22 | + "metadata": {}, |
| 23 | + "outputs": [], |
| 24 | + "source": [ |
| 25 | + "class CNN(nn.Module):\n", |
| 26 | + " def __init__(self, output_dimesion, vocab_size, dropout_rate, emb_dim, max_len, n_filters, init_W=None):\n", |
| 27 | + " # number_filters\n", |
| 28 | + " super(CNN, self).__init__()\n", |
| 29 | + "\n", |
| 30 | + " self.max_len = max_len\n", |
| 31 | + " max_features = vocab_size\n", |
| 32 | + " vanila_dimension = 200 #倒数第二层的节点数\n", |
| 33 | + " projection_dimension = output_dimesion #输出层的节点数\n", |
| 34 | + " self.qual_conv_set = {} \n", |
| 35 | + "\n", |
| 36 | + " '''Embedding Layer'''\n", |
| 37 | + " if init_W is None:\n", |
| 38 | + " # 先尝试使用embedding随机赋值\n", |
| 39 | + " self.embedding = nn.Embedding(max_features, emb_dim)\n", |
| 40 | + "\n", |
| 41 | + " self.conv1 = nn.Sequential(\n", |
| 42 | + " # 卷积层的激活函数\n", |
| 43 | + " nn.Conv2d(1, n_filters, kernel_size=(3, emb_dim)),\n", |
| 44 | + " nn.ReLU(),\n", |
| 45 | + " nn.MaxPool2d(kernel_size=(max_len - 3 + 1, 1))\n", |
| 46 | + " )\n", |
| 47 | + " self.conv2 = nn.Sequential(\n", |
| 48 | + " nn.Conv2d(1, n_filters, kernel_size=(4, emb_dim)),\n", |
| 49 | + " nn.ReLU(),\n", |
| 50 | + " nn.MaxPool2d(kernel_size=(max_len - 4 + 1, 1))\n", |
| 51 | + " )\n", |
| 52 | + " self.conv3 = nn.Sequential(\n", |
| 53 | + " nn.Conv2d(1, n_filters, kernel_size=(5, emb_dim)),\n", |
| 54 | + " nn.ReLU(),\n", |
| 55 | + " nn.MaxPool2d(kernel_size=(max_len - 5 + 1, 1))\n", |
| 56 | + " )\n", |
| 57 | + " \n", |
| 58 | + " '''Dropout Layer'''\n", |
| 59 | + " #layer = Dense(vanila_dimension, activation='tanh')(flatten_layer)\n", |
| 60 | + " #layer = Dropout(dropout_rate)(layer)\n", |
| 61 | + " self.layer = nn.Linear(300, vanila_dimension)\n", |
| 62 | + " self.dropout = nn.Dropout(dropout_rate)\n", |
| 63 | + "\n", |
| 64 | + " '''Projection Layer & Output Layer'''\n", |
| 65 | + " #output_layer = Dense(projection_dimension, activation='tanh')(layer)\n", |
| 66 | + " self.output_layer = nn.Linear(vanila_dimension, projection_dimension)\n", |
| 67 | + "\n", |
| 68 | + " \n", |
| 69 | + "\n", |
| 70 | + " def forward(self, input):\n", |
| 71 | + " embeds = self.embedding(input)\n", |
| 72 | + " # concatenate the tensors\n", |
| 73 | + " x = self.conv_1(embeds)\n", |
| 74 | + " y = self.conv_2(embeds)\n", |
| 75 | + " z = self.conv_3(embeds)\n", |
| 76 | + " flatten = torch.cat((x,view(-1), y.view(-1), z.view(-1)))\n", |
| 77 | + " \n", |
| 78 | + " out = F.tanh(self.layer(flatten))\n", |
| 79 | + " out = self.dropout(out)\n", |
| 80 | + " out = F.tanh(self.output_layer(out)) \n", |
| 81 | + " \n", |
| 82 | + "cnn = CNN(50, 8000, 0.5, 50, 150, 100)" |
| 83 | + ] |
| 84 | + }, |
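|  | +  {
|  | +   "cell_type": "code",
|  | +   "execution_count": null,
|  | +   "metadata": {},
|  | +   "outputs": [],
|  | +   "source": [
|  | +    "# Minimal sanity-check sketch (not part of the original notebook): push a dummy\n",
|  | +    "# index sequence of length max_len=150 through the network and confirm the output\n",
|  | +    "# has projection_dimension=50 entries. Assumes the unbatched, fixed-length input\n",
|  | +    "# convention used by forward() above.\n",
|  | +    "dummy_input = Variable(torch.zeros(150).long())\n",
|  | +    "dummy_output = cnn(dummy_input)\n",
|  | +    "print(dummy_output.size())  # expected: torch.Size([50])"
|  | +   ]
|  | +  },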
| 85 | + { |
| 86 | + "cell_type": "code", |
| 87 | + "execution_count": 63, |
| 88 | + "metadata": {}, |
| 89 | + "outputs": [ |
| 90 | + { |
| 91 | + "data": { |
| 92 | + "text/plain": [ |
| 93 | + "CNN(\n", |
| 94 | + " (embedding): Embedding(8000, 50)\n", |
| 95 | + " (conv1): Sequential(\n", |
| 96 | + " (0): Conv2d (1, 100, kernel_size=(3, 50), stride=(1, 1))\n", |
| 97 | + " (1): ReLU()\n", |
| 98 | + " (2): MaxPool2d(kernel_size=(148, 1), stride=(148, 1), dilation=(1, 1))\n", |
| 99 | + " )\n", |
| 100 | + " (conv2): Sequential(\n", |
| 101 | + " (0): Conv2d (1, 100, kernel_size=(4, 50), stride=(1, 1))\n", |
| 102 | + " (1): ReLU()\n", |
| 103 | + " (2): MaxPool2d(kernel_size=(147, 1), stride=(147, 1), dilation=(1, 1))\n", |
| 104 | + " )\n", |
| 105 | + " (conv3): Sequential(\n", |
| 106 | + " (0): Conv2d (1, 100, kernel_size=(5, 50), stride=(1, 1))\n", |
| 107 | + " (1): ReLU()\n", |
| 108 | + " (2): MaxPool2d(kernel_size=(146, 1), stride=(146, 1), dilation=(1, 1))\n", |
| 109 | + " )\n", |
| 110 | + " (layer): Linear(in_features=300, out_features=200)\n", |
| 111 | + " (dropout): Dropout(p=0.5)\n", |
| 112 | + " (output_layer): Linear(in_features=200, out_features=50)\n", |
| 113 | + ")" |
| 114 | + ] |
| 115 | + }, |
| 116 | + "execution_count": 63, |
| 117 | + "metadata": {}, |
| 118 | + "output_type": "execute_result" |
| 119 | + } |
| 120 | + ], |
| 121 | + "source": [ |
| 122 | + "cnn" |
| 123 | + ] |
| 124 | + }, |
| 125 | + { |
| 126 | + "cell_type": "code", |
| 127 | + "execution_count": 13, |
| 128 | + "metadata": {}, |
| 129 | + "outputs": [ |
| 130 | + { |
| 131 | + "name": "stdout", |
| 132 | + "output_type": "stream", |
| 133 | + "text": [ |
| 134 | + "Load preprocessed rating data - ./data/preprocessed/ml-1m//ratings.all\n", |
| 135 | + "Load preprocessed document data - ./data/preprocessed/ml-1m//document.all\n" |
| 136 | + ] |
| 137 | + } |
| 138 | + ], |
| 139 | + "source": [ |
| 140 | + "from data_manager import Data_Factory\n", |
| 141 | + "import pprint\n", |
| 142 | + "data_factory = Data_Factory()\n", |
| 143 | + "\n", |
| 144 | + "R, D_all = data_factory.load(\"./data/preprocessed/ml-1m/\")" |
| 145 | + ] |
| 146 | + }, |
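|  | +  {
|  | +   "cell_type": "code",
|  | +   "execution_count": null,
|  | +   "metadata": {},
|  | +   "outputs": [],
|  | +   "source": [
|  | +    "# Exploratory sketch (not in the original notebook): peek at what Data_Factory.load()\n",
|  | +    "# returned. R holds the preprocessed rating data and D_all the document data; only\n",
|  | +    "# the 'X_sequence' and 'X_vocab' entries are used in the cells below.\n",
|  | +    "print(type(R), type(D_all))\n",
|  | +    "print(sorted(D_all.keys()))"
|  | +   ]
|  | +  },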
| 147 | + { |
| 148 | + "cell_type": "code", |
| 149 | + "execution_count": 59, |
| 150 | + "metadata": {}, |
| 151 | + "outputs": [ |
| 152 | + { |
| 153 | + "name": "stdout", |
| 154 | + "output_type": "stream", |
| 155 | + "text": [ |
| 156 | + "3544\n", |
| 157 | + "95\n", |
| 158 | + "[2497, 7513, 6630, 4814, 1994, 2754, 3900, 5018, 4346, 7235, 2533, 2610, 2633, 4156, 249, 2161, 1127, 146, 6530, 5018, 337, 6530, 6985, 6530, 4157, 3071, 6530, 3900, 1500, 4316, 7833, 5018, 5150, 7102, 6530, 6476, 6530, 1394, 4450, 6751, 1238, 7824, 6530, 740, 3773, 7062, 5917, 2514, 1171, 3782, 5251, 2992, 2353, 1496, 7819, 6530, 2101, 1496, 7446, 5832, 1052, 4109, 1865, 7355, 7769, 1496, 3590, 2271, 7458, 5529, 6087, 475, 6530, 2063, 1908, 2497, 2754, 3379, 4161, 5526, 6474, 2535, 7934, 3782, 6530, 5150, 807, 1354, 172, 4156, 355, 3417, 249, 2168, 1649]\n" |
| 159 | + ] |
| 160 | + }, |
| 161 | + { |
| 162 | + "ename": "AttributeError", |
| 163 | + "evalue": "'list' object has no attribute 'shape'", |
| 164 | + "output_type": "error", |
| 165 | + "traceback": [ |
| 166 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 167 | + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", |
| 168 | + "\u001b[0;32m<ipython-input-59-2917a9b68dc8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCNN_X\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCNN_X\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCNN_X\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", |
| 169 | + "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'shape'" |
| 170 | + ] |
| 171 | + } |
| 172 | + ], |
| 173 | + "source": [ |
| 174 | + "CNN_X = D_all['X_sequence']\n", |
| 175 | + "print(len(CNN_X))\n", |
| 176 | + "print(len(CNN_X[3]))\n", |
| 177 | + "print(CNN_X[3])\n", |
| 178 | + "print(CNN_X)" |
| 179 | + ] |
| 180 | + }, |
| 181 | + { |
| 182 | + "cell_type": "code", |
| 183 | + "execution_count": 9, |
| 184 | + "metadata": {}, |
| 185 | + "outputs": [ |
| 186 | + { |
| 187 | + "data": { |
| 188 | + "text/plain": [ |
| 189 | + "8000" |
| 190 | + ] |
| 191 | + }, |
| 192 | + "execution_count": 9, |
| 193 | + "metadata": {}, |
| 194 | + "output_type": "execute_result" |
| 195 | + } |
| 196 | + ], |
| 197 | + "source": [ |
| 198 | + "len(D_all['X_vocab'])" |
| 199 | + ] |
| 200 | + }, |
| 201 | + { |
| 202 | + "cell_type": "code", |
| 203 | + "execution_count": 48, |
| 204 | + "metadata": {}, |
| 205 | + "outputs": [ |
| 206 | + { |
| 207 | + "name": "stdout", |
| 208 | + "output_type": "stream", |
| 209 | + "text": [ |
| 210 | + "Variable containing:\n", |
| 211 | + " 1.1375 0.6195 0.1585 1.0799 0.1302\n", |
| 212 | + "-0.5405 -0.9589 -1.3669 1.2314 1.9734\n", |
| 213 | + "-0.4789 0.5938 0.1744 -0.0176 -0.0497\n", |
| 214 | + "-0.2310 -1.1388 0.7172 -0.4343 0.7839\n", |
| 215 | + " 0.5238 0.7899 -0.5901 1.0298 0.3844\n", |
| 216 | + "-1.4921 1.8542 -1.1308 0.7227 -1.6314\n", |
| 217 | + "-0.9999 0.4745 0.3701 0.2189 0.4824\n", |
| 218 | + " 0.0339 1.6608 0.5456 -2.0539 0.0004\n", |
| 219 | + " 0.0580 0.9189 1.2705 1.6964 -0.6851\n", |
| 220 | + "-0.4247 -1.4672 0.5220 0.0431 -0.2025\n", |
| 221 | + " 1.0033 -1.0548 1.1176 0.5650 -1.4660\n", |
| 222 | + "-0.8414 1.8125 1.8854 -1.6015 -0.6787\n", |
| 223 | + "-0.8838 0.0412 -0.6423 1.7509 -1.9570\n", |
| 224 | + " 0.5814 -1.5999 0.6436 1.4211 -1.3188\n", |
| 225 | + "-0.4954 -0.6092 -1.6808 -1.0020 0.1801\n", |
| 226 | + "-0.9836 -0.0847 -1.2562 -0.1226 -0.2108\n", |
| 227 | + "-1.3440 -0.1142 -1.2649 0.2782 -1.4181\n", |
| 228 | + "-0.0528 0.0718 -0.6514 1.1687 -1.0889\n", |
| 229 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 230 | + " 0.0339 1.6608 0.5456 -2.0539 0.0004\n", |
| 231 | + "-0.7024 -0.6674 -1.9162 -0.1312 1.1091\n", |
| 232 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 233 | + " 0.2818 0.5606 -0.3546 -0.6588 -0.7651\n", |
| 234 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 235 | + "-0.7891 -1.7500 0.1098 0.8820 0.5139\n", |
| 236 | + " 1.2017 0.5298 -0.7179 -1.1478 -1.6993\n", |
| 237 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 238 | + "-0.9999 0.4745 0.3701 0.2189 0.4824\n", |
| 239 | + " 0.0888 -0.0128 1.5520 1.2025 0.6651\n", |
| 240 | + " 0.6077 0.5434 -1.5032 1.5325 1.8256\n", |
| 241 | + "-0.4112 -1.2229 -0.2878 0.6258 1.1456\n", |
| 242 | + " 0.0339 1.6608 0.5456 -2.0539 0.0004\n", |
| 243 | + " 0.1364 -0.6930 -2.3371 1.6786 0.5617\n", |
| 244 | + " 1.0285 -1.7050 -0.4896 -1.0000 0.9725\n", |
| 245 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 246 | + "-0.3655 0.4535 -0.4016 0.2056 -1.3832\n", |
| 247 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 248 | + " 0.1841 -0.2055 1.9259 1.5805 0.2368\n", |
| 249 | + "-0.4701 -0.7426 -1.1546 -1.3005 -1.7871\n", |
| 250 | + " 0.2266 -0.5332 -1.2338 0.4280 2.4386\n", |
| 251 | + "-0.6303 0.5834 0.5205 -0.8387 0.4257\n", |
| 252 | + " 0.5106 -0.4741 0.8534 -0.0879 0.0737\n", |
| 253 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 254 | + "-0.6402 1.6770 1.0849 0.3854 -1.0779\n", |
| 255 | + " 0.7510 -2.0220 -0.0449 -1.5944 -1.0741\n", |
| 256 | + " 0.0605 0.4658 -0.6328 -0.2047 0.2944\n", |
| 257 | + "-0.4521 0.4285 0.3141 0.3153 0.7379\n", |
| 258 | + "-0.2910 0.7501 1.2844 0.8987 -1.4570\n", |
| 259 | + " 0.2266 0.4233 -0.7622 0.6053 0.9736\n", |
| 260 | + "-0.5485 -0.0073 0.7028 0.4528 -1.2437\n", |
| 261 | + " 0.3651 -0.7326 1.1882 0.6137 -1.1131\n", |
| 262 | + "-0.2044 -1.9507 -0.3135 -1.3187 0.6094\n", |
| 263 | + "-1.1596 1.4216 -0.5054 -0.3568 -0.5185\n", |
| 264 | + "-0.4572 0.0472 -1.4310 0.5741 -1.2894\n", |
| 265 | + "-0.7071 1.8620 1.1305 -1.1232 1.5237\n", |
| 266 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 267 | + "-2.5396 1.3397 1.0959 -0.7480 0.0679\n", |
| 268 | + "-0.4572 0.0472 -1.4310 0.5741 -1.2894\n", |
| 269 | + " 1.6656 1.1903 -0.3698 0.2036 0.0240\n", |
| 270 | + " 0.7796 -1.7166 0.5709 0.0085 -1.1771\n", |
| 271 | + " 1.8073 0.3372 -0.1976 0.8187 0.1685\n", |
| 272 | + " 0.8279 -0.1674 -0.9651 -0.1265 0.0651\n", |
| 273 | + "-0.2603 -1.1816 0.3361 0.2628 0.8348\n", |
| 274 | + " 0.7354 0.1170 -0.9391 -2.4669 -1.3682\n", |
| 275 | + " 0.1860 -0.7448 -1.6378 -0.0045 1.5380\n", |
| 276 | + "-0.4572 0.0472 -1.4310 0.5741 -1.2894\n", |
| 277 | + " 2.1237 1.0455 -0.5948 0.0934 -1.6559\n", |
| 278 | + "-0.1634 -0.5910 0.2927 -0.0937 0.7996\n", |
| 279 | + " 2.6495 0.5423 -1.1649 -2.0393 0.2268\n", |
| 280 | + "-0.4307 -1.1426 -0.9575 -0.3125 -0.0436\n", |
| 281 | + " 0.2849 0.1704 -0.2270 0.0564 0.3925\n", |
| 282 | + " 2.3563 -0.5101 1.8536 0.4569 0.2821\n", |
| 283 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 284 | + " 1.6585 0.6344 -0.0001 -0.8202 0.1913\n", |
| 285 | + "-0.4837 -0.7519 -0.7759 -0.4802 -0.6648\n", |
| 286 | + " 1.1375 0.6195 0.1585 1.0799 0.1302\n", |
| 287 | + "-1.4921 1.8542 -1.1308 0.7227 -1.6314\n", |
| 288 | + "-1.1122 1.3342 -0.7807 -0.3339 -1.1619\n", |
| 289 | + " 0.1296 0.1896 1.2773 0.1513 -0.0704\n", |
| 290 | + " 0.9736 0.8593 0.3178 -2.2234 0.3245\n", |
| 291 | + "-0.7505 -0.9183 -1.8172 -0.0884 1.0104\n", |
| 292 | + "-0.8394 0.9989 -0.3466 -0.7640 -0.3779\n", |
| 293 | + "-0.6200 0.5447 -0.6092 -0.0782 -1.1962\n", |
| 294 | + "-0.5485 -0.0073 0.7028 0.4528 -1.2437\n", |
| 295 | + " 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n", |
| 296 | + " 0.1364 -0.6930 -2.3371 1.6786 0.5617\n", |
| 297 | + "-1.1328 1.9744 -0.6251 0.9932 0.2207\n", |
| 298 | + "-1.6040 -0.5013 0.0782 1.1310 -0.4072\n", |
| 299 | + " 0.0398 -0.3110 0.3703 0.6808 -0.5264\n", |
| 300 | + " 0.5814 -1.5999 0.6436 1.4211 -1.3188\n", |
| 301 | + "-0.1824 -0.4074 -0.1582 -0.4725 1.2616\n", |
| 302 | + "-0.3176 -1.0342 0.9127 1.4634 -0.1190\n", |
| 303 | + "-0.4954 -0.6092 -1.6808 -1.0020 0.1801\n", |
| 304 | + " 0.4479 -0.5058 -2.0886 -2.4117 0.6307\n", |
| 305 | + " 1.4570 0.4706 -1.3763 -0.6453 0.4371\n", |
| 306 | + "[torch.FloatTensor of size 95x5]\n", |
| 307 | + "\n" |
| 308 | + ] |
| 309 | + } |
| 310 | + ], |
| 311 | + "source": [ |
| 312 | + "embeds = nn.Embedding(8000, 5)\n", |
| 313 | + "test = CNN_X[3]\n", |
| 314 | + "tensor = torch.LongTensor(list(map(int, test)))\n", |
| 315 | + "me_embed = embeds(Variable(tensor))\n", |
| 316 | + "print(me_embed)" |
| 317 | + ] |
| 318 | + }, |
| 319 | + { |
| 320 | + "cell_type": "code", |
| 321 | + "execution_count": 49, |
| 322 | + "metadata": {}, |
| 323 | + "outputs": [ |
| 324 | + { |
| 325 | + "data": { |
| 326 | + "text/plain": [ |
| 327 | + "<map at 0x7fd13d822ac8>" |
| 328 | + ] |
| 329 | + }, |
| 330 | + "execution_count": 49, |
| 331 | + "metadata": {}, |
| 332 | + "output_type": "execute_result" |
| 333 | + } |
| 334 | + ], |
| 335 | + "source": [ |
| 336 | + "map(int, test)" |
| 337 | + ] |
| 338 | + }, |
| 339 | + { |
| 340 | + "cell_type": "code", |
| 341 | + "execution_count": null, |
| 342 | + "metadata": { |
| 343 | + "collapsed": true |
| 344 | + }, |
| 345 | + "outputs": [], |
| 346 | + "source": [] |
| 347 | + } |
| 348 | + ], |
| 349 | + "metadata": { |
| 350 | + "kernelspec": { |
| 351 | + "display_name": "Python 3", |
| 352 | + "language": "python", |
| 353 | + "name": "python3" |
| 354 | + }, |
| 355 | + "language_info": { |
| 356 | + "codemirror_mode": { |
| 357 | + "name": "ipython", |
| 358 | + "version": 3 |
| 359 | + }, |
| 360 | + "file_extension": ".py", |
| 361 | + "mimetype": "text/x-python", |
| 362 | + "name": "python", |
| 363 | + "nbconvert_exporter": "python", |
| 364 | + "pygments_lexer": "ipython3", |
| 365 | + "version": "3.6.3" |
| 366 | + } |
| 367 | + }, |
| 368 | + "nbformat": 4, |
| 369 | + "nbformat_minor": 2 |
| 370 | +} |