Skip to content

Commit 10e9d9a

Browse files
author
captain
committed
update cnn_model code
1 parent f4e51e9 commit 10e9d9a

File tree

5 files changed

+409
-91
lines changed

5 files changed

+409
-91
lines changed

CNN_model.ipynb

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 22,
6+
"metadata": {
7+
"collapsed": true
8+
},
9+
"outputs": [],
10+
"source": [
11+
"# 全家桶 使用 Pytorch 必备 具体功能且看下文\n",
12+
"import torch\n",
13+
"import torch.nn as nn\n",
14+
"import torch.nn.functional as F\n",
15+
"import torch.optim as optim\n",
16+
"from torch.autograd import Variable"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": 62,
22+
"metadata": {},
23+
"outputs": [],
24+
"source": [
25+
class CNN(nn.Module):
    """Kim-style text CNN: parallel 3/4/5-gram convolutions over word
    embeddings, max-pooled over time, concatenated, then projected through
    two tanh layers (200-dim hidden -> output_dimesion).

    Args:
        output_dimesion: size of the final projection layer.
        vocab_size: number of rows in the embedding table.
        dropout_rate: dropout probability applied after the hidden layer.
        emb_dim: embedding dimension (also the conv kernel width).
        max_len: fixed document length in tokens.
        n_filters: number of filters per n-gram window size.
        init_W: optional pretrained embedding matrix; currently only the
            random-init path (init_W is None) is implemented.
    """

    def __init__(self, output_dimesion, vocab_size, dropout_rate, emb_dim, max_len, n_filters, init_W=None):
        super(CNN, self).__init__()

        self.max_len = max_len
        max_features = vocab_size
        vanila_dimension = 200                   # nodes in the second-to-last layer
        projection_dimension = output_dimesion   # nodes in the output layer
        self.qual_conv_set = {}

        '''Embedding Layer'''
        if init_W is None:
            # Random embedding init; loading pretrained init_W is not handled yet.
            self.embedding = nn.Embedding(max_features, emb_dim)

        # One branch per n-gram window size (3, 4, 5). Each kernel spans the
        # full embedding width, so conv output is (N, n_filters, max_len-ws+1, 1)
        # and max-pooling over the remaining time axis leaves (N, n_filters, 1, 1).
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, n_filters, kernel_size=(3, emb_dim)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(max_len - 3 + 1, 1))
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(1, n_filters, kernel_size=(4, emb_dim)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(max_len - 4 + 1, 1))
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(1, n_filters, kernel_size=(5, emb_dim)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(max_len - 5 + 1, 1))
        )

        '''Dropout Layer'''
        # Fix: in_features is one pooled scalar per filter per window size,
        # i.e. 3 * n_filters — the previous hard-coded 300 only worked for
        # n_filters == 100.
        self.layer = nn.Linear(3 * n_filters, vanila_dimension)
        self.dropout = nn.Dropout(dropout_rate)

        '''Projection Layer & Output Layer'''
        self.output_layer = nn.Linear(vanila_dimension, projection_dimension)

    def forward(self, input):
        # input: LongTensor of token ids, shape (max_len,) or (batch, max_len).
        embeds = self.embedding(input)
        # Conv2d expects (N, C, H, W); add batch/channel dims as needed.
        embeds = embeds.view(-1, 1, self.max_len, embeds.size(-1))
        # Fix: attributes are conv1/conv2/conv3 (forward used conv_1/conv_2/conv_3).
        x = self.conv1(embeds)
        y = self.conv2(embeds)
        z = self.conv3(embeds)
        # Fix: 'x,view(-1)' typo -> 'x.view(...)'; flatten per sample and
        # concatenate the three pooled feature maps along the feature axis.
        flatten = torch.cat((x.view(x.size(0), -1),
                             y.view(y.size(0), -1),
                             z.view(z.size(0), -1)), dim=1)

        # torch.tanh replaces the deprecated F.tanh.
        out = torch.tanh(self.layer(flatten))
        out = self.dropout(out)
        out = torch.tanh(self.output_layer(out))
        # Fix: forward previously fell off the end and returned None.
        return out


cnn = CNN(50, 8000, 0.5, 50, 150, 100)
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": 63,
88+
"metadata": {},
89+
"outputs": [
90+
{
91+
"data": {
92+
"text/plain": [
93+
"CNN(\n",
94+
" (embedding): Embedding(8000, 50)\n",
95+
" (conv1): Sequential(\n",
96+
" (0): Conv2d (1, 100, kernel_size=(3, 50), stride=(1, 1))\n",
97+
" (1): ReLU()\n",
98+
" (2): MaxPool2d(kernel_size=(148, 1), stride=(148, 1), dilation=(1, 1))\n",
99+
" )\n",
100+
" (conv2): Sequential(\n",
101+
" (0): Conv2d (1, 100, kernel_size=(4, 50), stride=(1, 1))\n",
102+
" (1): ReLU()\n",
103+
" (2): MaxPool2d(kernel_size=(147, 1), stride=(147, 1), dilation=(1, 1))\n",
104+
" )\n",
105+
" (conv3): Sequential(\n",
106+
" (0): Conv2d (1, 100, kernel_size=(5, 50), stride=(1, 1))\n",
107+
" (1): ReLU()\n",
108+
" (2): MaxPool2d(kernel_size=(146, 1), stride=(146, 1), dilation=(1, 1))\n",
109+
" )\n",
110+
" (layer): Linear(in_features=300, out_features=200)\n",
111+
" (dropout): Dropout(p=0.5)\n",
112+
" (output_layer): Linear(in_features=200, out_features=50)\n",
113+
")"
114+
]
115+
},
116+
"execution_count": 63,
117+
"metadata": {},
118+
"output_type": "execute_result"
119+
}
120+
],
121+
"source": [
122+
"cnn"
123+
]
124+
},
125+
{
126+
"cell_type": "code",
127+
"execution_count": 13,
128+
"metadata": {},
129+
"outputs": [
130+
{
131+
"name": "stdout",
132+
"output_type": "stream",
133+
"text": [
134+
"Load preprocessed rating data - ./data/preprocessed/ml-1m//ratings.all\n",
135+
"Load preprocessed document data - ./data/preprocessed/ml-1m//document.all\n"
136+
]
137+
}
138+
],
139+
"source": [
140+
"from data_manager import Data_Factory\n",
141+
"import pprint\n",
142+
"data_factory = Data_Factory()\n",
143+
"\n",
144+
"R, D_all = data_factory.load(\"./data/preprocessed/ml-1m/\")"
145+
]
146+
},
147+
{
148+
"cell_type": "code",
149+
"execution_count": 59,
150+
"metadata": {},
151+
"outputs": [
152+
{
153+
"name": "stdout",
154+
"output_type": "stream",
155+
"text": [
156+
"3544\n",
157+
"95\n",
158+
"[2497, 7513, 6630, 4814, 1994, 2754, 3900, 5018, 4346, 7235, 2533, 2610, 2633, 4156, 249, 2161, 1127, 146, 6530, 5018, 337, 6530, 6985, 6530, 4157, 3071, 6530, 3900, 1500, 4316, 7833, 5018, 5150, 7102, 6530, 6476, 6530, 1394, 4450, 6751, 1238, 7824, 6530, 740, 3773, 7062, 5917, 2514, 1171, 3782, 5251, 2992, 2353, 1496, 7819, 6530, 2101, 1496, 7446, 5832, 1052, 4109, 1865, 7355, 7769, 1496, 3590, 2271, 7458, 5529, 6087, 475, 6530, 2063, 1908, 2497, 2754, 3379, 4161, 5526, 6474, 2535, 7934, 3782, 6530, 5150, 807, 1354, 172, 4156, 355, 3417, 249, 2168, 1649]\n"
159+
]
160+
},
161+
{
162+
"ename": "AttributeError",
163+
"evalue": "'list' object has no attribute 'shape'",
164+
"output_type": "error",
165+
"traceback": [
166+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
167+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
168+
"\u001b[0;32m<ipython-input-59-2917a9b68dc8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCNN_X\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCNN_X\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCNN_X\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
169+
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'shape'"
170+
]
171+
}
172+
],
173+
"source": [
174+
# Inspect the tokenized document sequences produced by Data_Factory.load().
CNN_X = D_all['X_sequence']
sample = CNN_X[3]
print(len(CNN_X))
print(len(sample))
print(sample)
print(CNN_X)
179+
]
180+
},
181+
{
182+
"cell_type": "code",
183+
"execution_count": 9,
184+
"metadata": {},
185+
"outputs": [
186+
{
187+
"data": {
188+
"text/plain": [
189+
"8000"
190+
]
191+
},
192+
"execution_count": 9,
193+
"metadata": {},
194+
"output_type": "execute_result"
195+
}
196+
],
197+
"source": [
198+
"len(D_all['X_vocab'])"
199+
]
200+
},
201+
{
202+
"cell_type": "code",
203+
"execution_count": 48,
204+
"metadata": {},
205+
"outputs": [
206+
{
207+
"name": "stdout",
208+
"output_type": "stream",
209+
"text": [
210+
"Variable containing:\n",
211+
" 1.1375 0.6195 0.1585 1.0799 0.1302\n",
212+
"-0.5405 -0.9589 -1.3669 1.2314 1.9734\n",
213+
"-0.4789 0.5938 0.1744 -0.0176 -0.0497\n",
214+
"-0.2310 -1.1388 0.7172 -0.4343 0.7839\n",
215+
" 0.5238 0.7899 -0.5901 1.0298 0.3844\n",
216+
"-1.4921 1.8542 -1.1308 0.7227 -1.6314\n",
217+
"-0.9999 0.4745 0.3701 0.2189 0.4824\n",
218+
" 0.0339 1.6608 0.5456 -2.0539 0.0004\n",
219+
" 0.0580 0.9189 1.2705 1.6964 -0.6851\n",
220+
"-0.4247 -1.4672 0.5220 0.0431 -0.2025\n",
221+
" 1.0033 -1.0548 1.1176 0.5650 -1.4660\n",
222+
"-0.8414 1.8125 1.8854 -1.6015 -0.6787\n",
223+
"-0.8838 0.0412 -0.6423 1.7509 -1.9570\n",
224+
" 0.5814 -1.5999 0.6436 1.4211 -1.3188\n",
225+
"-0.4954 -0.6092 -1.6808 -1.0020 0.1801\n",
226+
"-0.9836 -0.0847 -1.2562 -0.1226 -0.2108\n",
227+
"-1.3440 -0.1142 -1.2649 0.2782 -1.4181\n",
228+
"-0.0528 0.0718 -0.6514 1.1687 -1.0889\n",
229+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
230+
" 0.0339 1.6608 0.5456 -2.0539 0.0004\n",
231+
"-0.7024 -0.6674 -1.9162 -0.1312 1.1091\n",
232+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
233+
" 0.2818 0.5606 -0.3546 -0.6588 -0.7651\n",
234+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
235+
"-0.7891 -1.7500 0.1098 0.8820 0.5139\n",
236+
" 1.2017 0.5298 -0.7179 -1.1478 -1.6993\n",
237+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
238+
"-0.9999 0.4745 0.3701 0.2189 0.4824\n",
239+
" 0.0888 -0.0128 1.5520 1.2025 0.6651\n",
240+
" 0.6077 0.5434 -1.5032 1.5325 1.8256\n",
241+
"-0.4112 -1.2229 -0.2878 0.6258 1.1456\n",
242+
" 0.0339 1.6608 0.5456 -2.0539 0.0004\n",
243+
" 0.1364 -0.6930 -2.3371 1.6786 0.5617\n",
244+
" 1.0285 -1.7050 -0.4896 -1.0000 0.9725\n",
245+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
246+
"-0.3655 0.4535 -0.4016 0.2056 -1.3832\n",
247+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
248+
" 0.1841 -0.2055 1.9259 1.5805 0.2368\n",
249+
"-0.4701 -0.7426 -1.1546 -1.3005 -1.7871\n",
250+
" 0.2266 -0.5332 -1.2338 0.4280 2.4386\n",
251+
"-0.6303 0.5834 0.5205 -0.8387 0.4257\n",
252+
" 0.5106 -0.4741 0.8534 -0.0879 0.0737\n",
253+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
254+
"-0.6402 1.6770 1.0849 0.3854 -1.0779\n",
255+
" 0.7510 -2.0220 -0.0449 -1.5944 -1.0741\n",
256+
" 0.0605 0.4658 -0.6328 -0.2047 0.2944\n",
257+
"-0.4521 0.4285 0.3141 0.3153 0.7379\n",
258+
"-0.2910 0.7501 1.2844 0.8987 -1.4570\n",
259+
" 0.2266 0.4233 -0.7622 0.6053 0.9736\n",
260+
"-0.5485 -0.0073 0.7028 0.4528 -1.2437\n",
261+
" 0.3651 -0.7326 1.1882 0.6137 -1.1131\n",
262+
"-0.2044 -1.9507 -0.3135 -1.3187 0.6094\n",
263+
"-1.1596 1.4216 -0.5054 -0.3568 -0.5185\n",
264+
"-0.4572 0.0472 -1.4310 0.5741 -1.2894\n",
265+
"-0.7071 1.8620 1.1305 -1.1232 1.5237\n",
266+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
267+
"-2.5396 1.3397 1.0959 -0.7480 0.0679\n",
268+
"-0.4572 0.0472 -1.4310 0.5741 -1.2894\n",
269+
" 1.6656 1.1903 -0.3698 0.2036 0.0240\n",
270+
" 0.7796 -1.7166 0.5709 0.0085 -1.1771\n",
271+
" 1.8073 0.3372 -0.1976 0.8187 0.1685\n",
272+
" 0.8279 -0.1674 -0.9651 -0.1265 0.0651\n",
273+
"-0.2603 -1.1816 0.3361 0.2628 0.8348\n",
274+
" 0.7354 0.1170 -0.9391 -2.4669 -1.3682\n",
275+
" 0.1860 -0.7448 -1.6378 -0.0045 1.5380\n",
276+
"-0.4572 0.0472 -1.4310 0.5741 -1.2894\n",
277+
" 2.1237 1.0455 -0.5948 0.0934 -1.6559\n",
278+
"-0.1634 -0.5910 0.2927 -0.0937 0.7996\n",
279+
" 2.6495 0.5423 -1.1649 -2.0393 0.2268\n",
280+
"-0.4307 -1.1426 -0.9575 -0.3125 -0.0436\n",
281+
" 0.2849 0.1704 -0.2270 0.0564 0.3925\n",
282+
" 2.3563 -0.5101 1.8536 0.4569 0.2821\n",
283+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
284+
" 1.6585 0.6344 -0.0001 -0.8202 0.1913\n",
285+
"-0.4837 -0.7519 -0.7759 -0.4802 -0.6648\n",
286+
" 1.1375 0.6195 0.1585 1.0799 0.1302\n",
287+
"-1.4921 1.8542 -1.1308 0.7227 -1.6314\n",
288+
"-1.1122 1.3342 -0.7807 -0.3339 -1.1619\n",
289+
" 0.1296 0.1896 1.2773 0.1513 -0.0704\n",
290+
" 0.9736 0.8593 0.3178 -2.2234 0.3245\n",
291+
"-0.7505 -0.9183 -1.8172 -0.0884 1.0104\n",
292+
"-0.8394 0.9989 -0.3466 -0.7640 -0.3779\n",
293+
"-0.6200 0.5447 -0.6092 -0.0782 -1.1962\n",
294+
"-0.5485 -0.0073 0.7028 0.4528 -1.2437\n",
295+
" 1.7251 1.5146 -0.6547 -0.2933 -1.5057\n",
296+
" 0.1364 -0.6930 -2.3371 1.6786 0.5617\n",
297+
"-1.1328 1.9744 -0.6251 0.9932 0.2207\n",
298+
"-1.6040 -0.5013 0.0782 1.1310 -0.4072\n",
299+
" 0.0398 -0.3110 0.3703 0.6808 -0.5264\n",
300+
" 0.5814 -1.5999 0.6436 1.4211 -1.3188\n",
301+
"-0.1824 -0.4074 -0.1582 -0.4725 1.2616\n",
302+
"-0.3176 -1.0342 0.9127 1.4634 -0.1190\n",
303+
"-0.4954 -0.6092 -1.6808 -1.0020 0.1801\n",
304+
" 0.4479 -0.5058 -2.0886 -2.4117 0.6307\n",
305+
" 1.4570 0.4706 -1.3763 -0.6453 0.4371\n",
306+
"[torch.FloatTensor of size 95x5]\n",
307+
"\n"
308+
]
309+
}
310+
],
311+
"source": [
312+
# Sanity-check: embed one tokenized document with a small toy embedding table.
embeds = nn.Embedding(8000, 5)
test = CNN_X[3]
token_ids = [int(token) for token in test]
tensor = torch.LongTensor(token_ids)
me_embed = embeds(Variable(tensor))
print(me_embed)
317+
]
318+
},
319+
{
320+
"cell_type": "code",
321+
"execution_count": 49,
322+
"metadata": {},
323+
"outputs": [
324+
{
325+
"data": {
326+
"text/plain": [
327+
"<map at 0x7fd13d822ac8>"
328+
]
329+
},
330+
"execution_count": 49,
331+
"metadata": {},
332+
"output_type": "execute_result"
333+
}
334+
],
335+
"source": [
336+
# Fix: in Python 3, map() is lazy — the bare expression displayed only
# '<map at 0x...>' (as the recorded output shows). Materialize it to see
# the actual token ids.
list(map(int, test))
337+
]
338+
},
339+
{
340+
"cell_type": "code",
341+
"execution_count": null,
342+
"metadata": {
343+
"collapsed": true
344+
},
345+
"outputs": [],
346+
"source": []
347+
}
348+
],
349+
"metadata": {
350+
"kernelspec": {
351+
"display_name": "Python 3",
352+
"language": "python",
353+
"name": "python3"
354+
},
355+
"language_info": {
356+
"codemirror_mode": {
357+
"name": "ipython",
358+
"version": 3
359+
},
360+
"file_extension": ".py",
361+
"mimetype": "text/x-python",
362+
"name": "python",
363+
"nbconvert_exporter": "python",
364+
"pygments_lexer": "ipython3",
365+
"version": "3.6.3"
366+
}
367+
},
368+
"nbformat": 4,
369+
"nbformat_minor": 2
370+
}

0 commit comments

Comments
 (0)