Skip to content

Commit 21b48a8

Browse files
authored
Merge pull request tensorflow#2718 from tensorflow/flaky-wide-deep-test
Make wide_deep_test much less flaky and more thorough
2 parents 205c5e0 + dbd3088 commit 21b48a8

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

official/wide_deep/wide_deep_test.csv

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,13 @@
1818
32,Private,186824,HS-grad,9,Never-married,Machine-op-inspct,Unmarried,,,0,0,40,,<=50K
1919
38,Private,28887,11th,7,Married-civ-spouse,Sales,Husband,,,0,0,50,,<=50K
2020
43,Self-emp-not-inc,292175,Masters,14,Divorced,Exec-managerial,Unmarried,,,0,0,45,,>50K
21+
40,Private,193524,Doctorate,16,Married-civ-spouse,Prof-specialty,Husband,,,0,0,60,,>50K
22+
56,Local-gov,216851,Bachelors,13,Married-civ-spouse,Tech-support,Husband,,,0,0,40,,>50K
23+
54,?,180211,Some-college,10,Married-civ-spouse,?,Husband,,,0,0,60,,>50K
24+
22,State-gov,311512,Some-college,10,Married-civ-spouse,Other-service,Husband,,,0,0,15,,<=50K
25+
31,Private,84154,Some-college,10,Married-civ-spouse,Sales,Husband,,,0,0,38,,>50K
26+
57,Federal-gov,337895,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,,,0,0,40,,>50K
27+
47,Private,51835,Prof-school,15,Married-civ-spouse,Prof-specialty,Wife,,,0,1902,60,,>50K
28+
50,Federal-gov,251585,Bachelors,13,Divorced,Exec-managerial,Not-in-family,,,0,0,55,,>50K
29+
25,Private,289980,HS-grad,9,Never-married,Handlers-cleaners,Not-in-family,,,0,0,35,,<=50K
30+
42,Private,116632,Doctorate,16,Married-civ-spouse,Prof-specialty,Husband,,,0,0,45,,>50K

official/wide_deep/wide_deep_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,23 @@ def build_and_test_estimator(self, model_type):
8585
input_fn=lambda: wide_deep.input_fn(
8686
TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
8787

88-
# Train for 40 steps at batch size 2 and evaluate final loss
88+
# Train for 100 epochs at batch size 3 and evaluate final loss
8989
model.train(
9090
input_fn=lambda: wide_deep.input_fn(
91-
TEST_CSV, num_epochs=None, shuffle=True, batch_size=2),
92-
steps=40)
91+
TEST_CSV, num_epochs=100, shuffle=True, batch_size=3))
9392
final_results = model.evaluate(
9493
input_fn=lambda: wide_deep.input_fn(
9594
TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
9695

9796
print('%s initial results:' % model_type, initial_results)
9897
print('%s final results:' % model_type, final_results)
98+
99+
# Ensure loss has decreased, while accuracy and both AUCs have increased.
99100
self.assertLess(final_results['loss'], initial_results['loss'])
101+
self.assertGreater(final_results['auc'], initial_results['auc'])
102+
self.assertGreater(final_results['auc_precision_recall'],
103+
initial_results['auc_precision_recall'])
104+
self.assertGreater(final_results['accuracy'], initial_results['accuracy'])
100105

101106
def test_wide_deep_estimator_training(self):
102107
self.build_and_test_estimator('wide_deep')

0 commit comments

Comments
 (0)