Skip to content

Commit 8182ea3

Browse files
author
浅梦
authored
Refactor input module of linear part
1 parent 1404f0d commit 8182ea3

26 files changed

+209
-129
lines changed

.github/ISSUE_TEMPLATE/bug_report.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Steps to reproduce the behavior:
2020
**Operating environment(运行环境):**
2121
- python version [e.g. 3.4, 3.6]
2222
- tensorflow version [e.g. 1.4.0, 1.12.0]
23-
- deepctr version [e.g. 0.2.3,]
23+
- deepctr version [e.g. 0.5.2,]
2424

2525
**Additional context**
2626
Add any other context about the problem here.

.github/ISSUE_TEMPLATE/question.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ labels: question
66
assignees: ''
77

88
---
9+
Please refer to the [FAQ](https://deepctr-doc.readthedocs.io/en/latest/FAQ.html) in the docs and search the [related issues](https://github.com/shenweichen/DeepCTR/issues) before asking your question.
910

1011
**Describe the question(问题描述)**
1112
A clear and concise description of what the question is.
@@ -16,4 +17,4 @@ Add any other context about the problem here.
1617
**Operating environment(运行环境):**
1718
- python version [e.g. 3.6]
1819
- tensorflow version [e.g. 1.4.0,]
19-
- deepctr version [e.g. 0.3.2,]
20+
- deepctr version [e.g. 0.5.2,]

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ env:
1717
#Not Support- TF_VERSION=1.7.1
1818
#Not Support- TF_VERSION=1.8.0
1919
#- TF_VERSION=1.8.0
20-
- TF_VERSION=1.10.0 #- TF_VERSION=1.10.1
20+
#- TF_VERSION=1.10.0 >50 mins limit #- TF_VERSION=1.10.1
2121
# - TF_VERSION=1.11.0
2222
#- TF_VERSION=1.5.1 #- TF_VERSION=1.5.0
2323
- TF_VERSION=1.6.0

deepctr/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
from . import models
33
from .utils import check_version
44

5-
__version__ = '0.5.1'
5+
__version__ = '0.5.2'
66
check_version(__version__)

deepctr/inputs.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from itertools import chain
1111

1212
from tensorflow.python.keras.initializers import RandomNormal
13-
from tensorflow.python.keras.layers import Concatenate, Dense, Embedding, Input, add,Flatten
13+
from tensorflow.python.keras.layers import Embedding, Input, Flatten
1414
from tensorflow.python.keras.regularizers import l2
1515

1616
from .layers.sequence import SequencePoolingLayer
17-
from .layers.utils import Hash,concat_fun
17+
from .layers.utils import Hash,concat_fun,Linear
1818

1919

2020
class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype','embedding_name','embedding'])):
@@ -45,12 +45,14 @@ def __new__(cls, name, dimension, maxlen, combiner="mean", use_hash=False, dtype
4545

4646
def get_fixlen_feature_names(feature_columns):
4747
features = build_input_features(feature_columns, include_varlen=False,include_fixlen=True)
48-
return features.keys()
48+
return list(features.keys())
4949

5050
def get_varlen_feature_names(feature_columns):
5151
features = build_input_features(feature_columns, include_varlen=True,include_fixlen=False)
52-
return features.keys()
52+
return list(features.keys())
5353

54+
def get_inputs_list(inputs):
55+
return list(chain(*list(map(lambda x: x.values(), filter(lambda x: x is not None, inputs)))))
5456

5557
def build_input_features(feature_columns, include_varlen=True, mask_zero=True, prefix='',include_fixlen=True):
5658
input_features = OrderedDict()
@@ -61,7 +63,7 @@ def build_input_features(feature_columns, include_varlen=True, mask_zero=True, p
6163
shape=(1,), name=prefix+fc.name, dtype=fc.dtype)
6264
elif isinstance(fc,DenseFeat):
6365
input_features[fc.name] = Input(
64-
shape=(1,), name=prefix + fc.name, dtype=fc.dtype)
66+
shape=(fc.dimension,), name=prefix + fc.name, dtype=fc.dtype)
6567
if include_varlen:
6668
for fc in feature_columns:
6769
if isinstance(fc,VarLenSparseFeat):
@@ -138,8 +140,7 @@ def get_embedding_vec_list(embedding_dict, input_dict, sparse_feature_columns, r
138140
return embedding_vec_list
139141

140142

141-
def get_inputs_list(inputs):
142-
return list(chain(*list(map(lambda x: x.values(), filter(lambda x: x is not None, inputs)))))
143+
143144

144145
def create_embedding_matrix(feature_columns,l2_reg,init_std,seed,embedding_size, prefix="",seq_mask_zero=True):
145146
sparse_feature_columns = list(
@@ -155,24 +156,24 @@ def get_linear_logit(features, feature_columns, units=1, l2_reg=0, init_std=0.00
155156
linear_emb_list = [input_from_feature_columns(features,feature_columns,1,l2_reg,init_std,seed,prefix=prefix+str(i))[0] for i in range(units)]
156157
_, dense_input_list = input_from_feature_columns(features,feature_columns,1,l2_reg,init_std,seed,prefix=prefix)
157158

158-
if len(linear_emb_list[0]) > 1:
159-
linear_term = concat_fun([add(linear_emb) for linear_emb in linear_emb_list])
160-
elif len(linear_emb_list[0]) == 1:
161-
linear_term = concat_fun([linear_emb[0] for linear_emb in linear_emb_list])
162-
else:
163-
linear_term = None
164-
165-
if len(dense_input_list) > 0:
166-
dense_input__ = dense_input_list[0] if len(
167-
dense_input_list) == 1 else Concatenate()(dense_input_list)
168-
linear_dense_logit = Dense(
169-
units, activation=None, use_bias=False, kernel_regularizer=l2(l2_reg))(dense_input__)
170-
if linear_term is not None:
171-
linear_term = add([linear_dense_logit, linear_term])
159+
linear_logit_list = []
160+
for i in range(units):
161+
162+
if len(linear_emb_list[0])>0 and len(dense_input_list) >0:
163+
sparse_input = concat_fun(linear_emb_list[i])
164+
dense_input = concat_fun(dense_input_list)
165+
linear_logit = Linear(l2_reg,mode=2)([sparse_input,dense_input])
166+
elif len(linear_emb_list[0])>0:
167+
sparse_input = concat_fun(linear_emb_list[i])
168+
linear_logit = Linear(l2_reg,mode=0)(sparse_input)
169+
elif len(dense_input_list) >0:
170+
dense_input = concat_fun(dense_input_list)
171+
linear_logit = Linear(l2_reg,mode=1)(dense_input)
172172
else:
173-
linear_term = linear_dense_logit
173+
raise NotImplementedError
174+
linear_logit_list.append(linear_logit)
174175

175-
return linear_term
176+
return concat_fun(linear_logit_list)
176177

177178
def embedding_lookup(sparse_embedding_dict,sparse_input_dict,sparse_feature_columns,return_feat_list=(), mask_feat_list=()):
178179
embedding_vec_list = []

deepctr/layers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
1010
KMaxPooling, SequencePoolingLayer,
1111
Transformer, DynamicGRU)
12-
from .utils import NoMask, Hash
12+
from .utils import NoMask, Hash,Linear
1313

1414
custom_objects = {'tf': tf,
1515
'InnerProductLayer': InnerProductLayer,
@@ -34,6 +34,7 @@
3434
'KMaxPooling': KMaxPooling,
3535
'FGCNNLayer': FGCNNLayer,
3636
'Hash': Hash,
37+
'Linear':Linear,
3738
'DynamicGRU': DynamicGRU,
3839
'SENETLayer':SENETLayer,
3940
'BilinearInteraction':BilinearInteraction,

deepctr/layers/utils.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def compute_mask(self, inputs, mask):
2323
return None
2424

2525

26-
2726
class Hash(tf.keras.layers.Layer):
2827
"""
2928
hash the input to [0,num_buckets)
@@ -43,7 +42,7 @@ def call(self, x, mask=None, **kwargs):
4342
if x.dtype != tf.string:
4443
x = tf.as_string(x, )
4544
hash_x = tf.string_to_hash_bucket_fast(x, self.num_buckets if not self.mask_zero else self.num_buckets - 1,
46-
name=None)#weak hash
45+
name=None) # weak hash
4746
if self.mask_zero:
4847
mask_1 = tf.cast(tf.not_equal(x, "0"), 'int64')
4948
mask_2 = tf.cast(tf.not_equal(x, "0.0"), 'int64')
@@ -60,6 +59,54 @@ def get_config(self, ):
6059
return dict(list(base_config.items()) + list(config.items()))
6160

6261

62+
class Linear(tf.keras.layers.Layer):
63+
64+
def __init__(self, l2_reg=0.0, mode=0, **kwargs):
65+
66+
self.l2_reg = l2_reg
67+
# self.l2_reg = tf.contrib.layers.l2_regularizer(float(l2_reg_linear))
68+
self.mode = mode
69+
super(Linear, self).__init__(**kwargs)
70+
71+
def build(self, input_shape):
72+
73+
self.bias = self.add_weight(name='linear_bias',
74+
shape=(1,),
75+
initializer=tf.keras.initializers.Zeros(),
76+
trainable=True)
77+
78+
self.dense = tf.keras.layers.Dense(units=1, activation=None, use_bias=False,
79+
kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg))
80+
81+
super(Linear, self).build(input_shape) # Be sure to call this somewhere!
82+
83+
def call(self, inputs , **kwargs):
84+
85+
if self.mode == 0:
86+
sparse_input = inputs
87+
linear_logit = tf.reduce_sum(sparse_input, axis=-1, keep_dims=True)
88+
elif self.mode == 1:
89+
dense_input = inputs
90+
linear_logit = self.dense(dense_input)
91+
92+
else:
93+
sparse_input, dense_input = inputs
94+
95+
linear_logit = tf.reduce_sum(sparse_input, axis=-1, keep_dims=False) + self.dense(dense_input)
96+
97+
linear_bias_logit = linear_logit + self.bias
98+
99+
return linear_bias_logit
100+
101+
def compute_output_shape(self, input_shape):
102+
return (None, 1)
103+
104+
def get_config(self, ):
105+
config = {'mode': self.mode, 'l2_reg': self.l2_reg}
106+
base_config = super(Linear, self).get_config()
107+
return dict(list(base_config.items()) + list(config.items()))
108+
109+
63110
def concat_fun(inputs, axis=-1):
64111
if len(inputs) == 1:
65112
return inputs[0]

docs/source/Examples.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,6 @@ There are 2 additional steps to use DeepCTR with sequence feature input.
218218
- embedding : default `True`. If `False`, the feature will not be embedded into a dense vector.
219219

220220

221-
Now multi-value input is avaliable for `AFM,AutoInt,DCN,DeepFM,FNN,NFM,PNN,xDeepFM,CCPM,FGCNN`,for `DIN,DIEN,DSIN` please read the example in [run_din.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_din.py),[run_dien.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_dien.py) and [run_dsin.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_dsin.py)
222-
223221
This example shows how to use ``DeepFM`` with sequence(multi-value) feature. You can get the demo data
224222
[movielens_sample.txt](https://github.com/shenweichen/DeepCTR/tree/master/examples/movielens_sample.txt) and run the following codes.
225223

docs/source/FAQ.md

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from tensorflow.python.keras.models import save_model,load_model
1717
model = DeepFM()
1818
save_model(model, 'DeepFM.h5')# save_model, same as before
1919

20-
from deepctr.utils import custom_objects
20+
from deepctr.layers import custom_objects
2121
model = load_model('DeepFM.h5',custom_objects)# load_model,just add a parameter
2222
```
2323
## 2. Set learning rate and use earlystopping
@@ -30,7 +30,7 @@ import deepctr
3030
from tensorflow.python.keras.optimizers import Adam,Adagrad
3131
from tensorflow.python.keras.callbacks import EarlyStopping
3232

33-
model = deepctr.models.DeepFM({"sparse": sparse_feature_dict, "dense": dense_feature_list})
33+
model = deepctr.models.DeepFM(linear_feature_columns,dnn_feature_columns)
3434
model.compile(Adagrad('0.0808'),'binary_crossentropy',metrics=['binary_crossentropy'])
3535

3636
es = EarlyStopping(monitor='val_binary_crossentropy')
@@ -47,36 +47,70 @@ Then,use the following code,the `attentional_weights[:,i,0]` is the `feature_int
4747
```python
4848
import itertools
4949
import deepctr
50+
from deepctr.models import AFM
51+
from deepctr.inputs import get_fixlen_feature_names,get_varlen_feature_names
5052
from tensorflow.python.keras.models import Model
5153
from tensorflow.python.keras.layers import Lambda
5254

53-
feature_dim_dict = {"sparse": sparse_feature_dict, "dense": dense_feature_list}
54-
model = deepctr.models.AFM(feature_dim_dict)
55+
model = AFM(linear_feature_columns,dnn_feature_columns)
5556
model.fit(model_input,target)
5657

5758
afmlayer = model.layers[-3]
5859
afm_weight_model = Model(model.input,outputs=Lambda(lambda x:afmlayer.normalized_att_score)(model.input))
5960
attentional_weights = afm_weight_model.predict(model_input,batch_size=4096)
60-
feature_interactions = list(itertools.combinations(list(feature_dim_dict['sparse'].keys()) + feature_dim_dict['dense'] ,2))
61+
62+
fixlen_names = get_fixlen_feature_names( dnn_feature_columns)
63+
varlen_names = get_varlen_feature_names(dnn_feature_columns)
64+
feature_interactions = list(itertools.combinations(fixlen_names+varlen_names ,2))
65+
```
66+
## 4. How to extract the embedding vectors in deepfm?
67+
```python
68+
feature_columns = [SparseFeat('user_id',120,),SparseFeat('item_id',60,),SparseFeat('cate_id',60,)]
69+
70+
def get_embedding_weights(dnn_feature_columns,model):
71+
embedding_dict = {}
72+
for fc in dnn_feature_columns:
73+
if hasattr(fc,'embedding_name'):
74+
if fc.embedding_name is not None:
75+
name = fc.embedding_name
76+
else:
77+
name = fc.name
78+
embedding_dict[name] = model.get_layer("sparse_emb_"+name).get_weights()[0]
79+
return embedding_dict
80+
81+
embedding_dict = get_embedding_weights(feature_columns,model)
82+
83+
user_id_emb = embedding_dict['user_id']
84+
item_id_emb = embedding_dict['item_id']
6185
```
6286

63-
## 4. Does the models support multi-value input?
64-
---------------------------------------------------
65-
Now multi-value input is avaliable for `AFM,AutoInt,DCN,DeepFM,FNN,NFM,PNN,xDeepFM`,you can read the example [here](./Examples.html#multi-value-input-movielens).
87+
## 5. How to add a long dense feature vector as an input to the model?
88+
```python
89+
from deepctr.models import DeepFM
90+
from deepctr.inputs import DenseFeat,SparseFeat,get_fixlen_feature_names
91+
import numpy as np
6692

67-
For `DIN` please read the code example in [run_din.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_din.py
68-
).
93+
feature_columns = [SparseFeat('user_id',120,),SparseFeat('item_id',60,),DenseFeat("pic_vec",5)]
94+
fixlen_feature_names = get_fixlen_feature_names(feature_columns)
6995

70-
For `DIEN` please read the code example in [run_dien.py](https://github.com/shenweichen/DeepCTR/blob/master/examples/run_dien.py
71-
).
96+
user_id = np.array([[1],[0],[1]])
97+
item_id = np.array([[30],[20],[10]])
98+
pic_vec = np.array([[0.1,0.5,0.4,0.3,0.2],[0.1,0.5,0.4,0.3,0.2],[0.1,0.5,0.4,0.3,0.2]])
99+
label = np.array([1,0,1])
72100

73-
You can also use layers in [sequence](./deepctr.layers.sequence.html)to build your own models !
101+
input_dict = {'user_id':user_id,'item_id':item_id,'pic_vec':pic_vec}
102+
model_input = [input_dict[name] for name in fixlen_feature_names]
74103

75-
## 5. How to add a long feature vector as a feature to the model?
76-
please refer [this](https://github.com/shenweichen/DeepCTR/issues/42)
104+
model = DeepFM(feature_columns,feature_columns[:-1])
105+
model.compile('adagrad','binary_crossentropy')
106+
model.fit(model_input,label)
107+
```
77108

78109
## 6. How to run the demo with GPU ?
79-
please refer [this](https://github.com/shenweichen/DeepCTR/issues/40)
110+
Just install deepctr with
111+
```bash
112+
$ pip install deepctr[gpu]
113+
```
80114

81115
## 7. Could not find a version that satisfies the requirement deepctr (from versions)
82116
Please install with `pip3 install` instead of `pip install`.

docs/source/History.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# History
2+
- 07/21/2019 : [v0.5.2](https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.2) released. Refactor `Linear` Layer.
23
- 07/10/2019 : [v0.5.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.1) released. Add [FiBiNET](./Features.html#fibinet-feature-importance-and-bilinear-feature-interaction-network).
34
- 06/30/2019 : [v0.5.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.5.0) released. Refactor inputs module.
45
- 05/19/2019 : [v0.4.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.4.1) released. Add [DSIN](./Features.html#dsin-deep-session-interest-network).

0 commit comments

Comments
 (0)