Skip to content

Commit 8929710

Browse files
Add classification_signature_fn_with_probabilities and make it the default for LinearClassifier.
Change: 133795652
1 parent a3ca3c9 commit 8929710

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

tensorflow/contrib/learn/python/learn/estimators/linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,8 @@ def default_input_fn(unused_estimator, examples):
571571
input_fn=input_fn or default_input_fn,
572572
input_feature_key=input_feature_key,
573573
use_deprecated_input_fn=use_deprecated_input_fn,
574-
signature_fn=signature_fn or export.classification_signature_fn,
574+
signature_fn=(
575+
signature_fn or export.classification_signature_fn_with_prob),
575576
prediction_key=_PROBABILITIES,
576577
default_batch_size=default_batch_size,
577578
exports_to_keep=exports_to_keep)

tensorflow/contrib/learn/python/learn/utils/export.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,23 @@ def classification_signature_fn(examples, unused_features, predictions):
119119
return default_signature, {}
120120

121121

122+
def classification_signature_fn_with_prob(
123+
examples, unused_features, predictions):
124+
"""Classification signature from given examples and predicted probabilities.
125+
126+
Args:
127+
examples: `Tensor`.
128+
unused_features: `dict` of `Tensor`s.
129+
predictions: `Tensor` of predicted probabilities.
130+
131+
Returns:
132+
Tuple of default classification signature and empty named signatures.
133+
"""
134+
default_signature = exporter.classification_signature(
135+
examples, scores_tensor=predictions)
136+
return default_signature, {}
137+
138+
122139
def regression_signature_fn(examples, unused_features, predictions):
123140
"""Creates regression signature from given examples and predictions.
124141

0 commit comments

Comments
 (0)