Skip to content

Commit d9c61b8

Browse files
ariejdljhseu
authored andcommitted
Simplify _linear_model_fn
* reduce line count using getattr() on layers 'weighted_sum...' rather than two explicit calls * create a variable that has a reference to the correct function as suggested * use a better name for the logit function * use regular if statement * be more explicit in if statement * fix whitespace * whitespace fix
1 parent 6a689e9 commit d9c61b8

File tree

1 file changed

+9
-14
lines changed
  • tensorflow/contrib/learn/python/learn/estimators

1 file changed

+9
-14
lines changed

tensorflow/contrib/learn/python/learn/estimators/linear.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -149,21 +149,16 @@ def _linear_model_fn(features, labels, mode, params, config=None):
149149
values=tuple(six.itervalues(features)),
150150
partitioner=partitioner) as scope:
151151
if joint_weights:
152-
logits, _, _ = (
153-
layers.joint_weighted_sum_from_feature_columns(
154-
columns_to_tensors=features,
155-
feature_columns=feature_columns,
156-
num_outputs=head.logits_dimension,
157-
weight_collections=[parent_scope],
158-
scope=scope))
152+
layer_fn = layers.joint_weighted_sum_from_feature_columns
159153
else:
160-
logits, _, _ = (
161-
layers.weighted_sum_from_feature_columns(
162-
columns_to_tensors=features,
163-
feature_columns=feature_columns,
164-
num_outputs=head.logits_dimension,
165-
weight_collections=[parent_scope],
166-
scope=scope))
154+
layer_fn = layers.weighted_sum_from_feature_columns
155+
156+
logits, _, _ = layer_fn(
157+
columns_to_tensors=features,
158+
feature_columns=feature_columns,
159+
num_outputs=head.logits_dimension,
160+
weight_collections=[parent_scope],
161+
scope=scope)
167162

168163
def _train_op_fn(loss):
169164
global_step = contrib_variables.get_global_step()

0 commit comments

Comments
 (0)