@@ -400,6 +400,8 @@ def test_stratified_shuffle_split_iter_no_indices():
400400def test_shuffle_split_even ():
401401 # Test that in StratifiedShuffleSplit, indices are drawn with an
402402 # equal chance
403+ n_folds = 5
404+ n_iter = 1000
403405
404406 def assert_counts_are_ok (idx_counts , p ):
405407 # Here we test that the distribution of the counts
@@ -412,26 +414,30 @@ def assert_counts_are_ok(idx_counts, p):
412414 "An index is not drawn with chance corresponding "
413415 "to even draws" )
414416
415- for n_labels in (6 , 22 ):
416- labels = np .array ((n_labels // 2 ) * [0 , 1 ])
417- n_folds = 5
418- splits = cval .StratifiedShuffleSplit (labels , n_iter = 1000 ,
417+ for n_samples in (6 , 22 ):
418+ labels = np .array ((n_samples // 2 ) * [0 , 1 ])
419+ splits = cval .StratifiedShuffleSplit (labels , n_iter = n_iter ,
419420 test_size = 1. / n_folds , random_state = 0 )
420421
421- train_counts = [0 ] * len (labels )
422- test_counts = [0 ] * len (labels )
422+ train_counts = [0 ] * n_samples
423+ test_counts = [0 ] * n_samples
424+ n_splits = 0
423425 for train , test in splits :
426+ n_splits += 1
424427 for counter , ids in [(train_counts , train ), (test_counts , test )]:
425428 for id in ids :
426429 counter [id ] += 1
430+ assert_equal (n_splits , n_iter )
431+
432+ assert_equal (len (train ), splits .n_train )
433+ assert_equal (len (test ), splits .n_test )
427434
428- n_splits = len (splits )
429435 label_counts = np .unique (labels )
430436 assert_equal (splits .test_size , 1.0 / n_folds )
431437 assert_equal (splits .n_train + splits .n_test , len (labels ))
432438 assert_equal (len (label_counts ), 2 )
433- ex_test_p = ( 1. * splits .n_test ) / n_labels
434- ex_train_p = 1.0 - ex_test_p
439+ ex_test_p = float ( splits .n_test ) / n_samples
440+ ex_train_p = float ( splits . n_train ) / n_samples
435441
436442 assert_counts_are_ok (train_counts , ex_train_p )
437443 assert_counts_are_ok (test_counts , ex_test_p )
0 commit comments