66import shutil
77import tempfile
88
9- from numpy .testing import assert_equal , assert_array_equal
9+ from numpy .testing import assert_equal
10+ from numpy .testing import assert_array_equal
11+ from numpy .testing import assert_array_almost_equal
1012from nose .tools import raises
1113
1214from sklearn .datasets import (load_svmlight_file , load_svmlight_files ,
@@ -22,10 +24,10 @@ def test_load_svmlight_file():
2224 X , y = load_svmlight_file (datafile )
2325
2426 # test X's shape
25- assert_equal (X .indptr .shape [0 ], 4 )
26- assert_equal (X .shape [0 ], 3 )
27+ assert_equal (X .indptr .shape [0 ], 5 )
28+ assert_equal (X .shape [0 ], 4 )
2729 assert_equal (X .shape [1 ], 21 )
28- assert_equal (y .shape [0 ], 3 )
30+ assert_equal (y .shape [0 ], 4 )
2931
3032 # test X's non-zero values
3133 for i , j , val in ((0 , 2 , 2.5 ), (0 , 10 , - 5.2 ), (0 , 15 , 1.5 ),
@@ -46,7 +48,7 @@ def test_load_svmlight_file():
4648 assert_equal (X [0 , 2 ], 5 )
4749
4850 # test y
49- assert_array_equal (y , [1 , 2 , 3 ])
51+ assert_array_equal (y , [1 , 2 , 3 , 4 ])
5052
5153
5254def test_load_svmlight_file_fd ():
@@ -86,8 +88,8 @@ def test_load_svmlight_file_n_features():
8688 X , y = load_svmlight_file (datafile , n_features = 20 )
8789
8890 # test X'shape
89- assert_equal (X .indptr .shape [0 ], 4 )
90- assert_equal (X .shape [0 ], 3 )
91+ assert_equal (X .indptr .shape [0 ], 5 )
92+ assert_equal (X .shape [0 ], 4 )
9193 assert_equal (X .shape [1 ], 20 )
9294
9395 # test X's non-zero values
@@ -168,9 +170,20 @@ def test_dump():
168170
169171 for X in (Xs , Xd ):
170172 for zero_based in (True , False ):
171- f = BytesIO ()
172- dump_svmlight_file (X , y , f , zero_based = zero_based )
173- f .seek (0 )
174- X2 , y2 = load_svmlight_file (f , zero_based = zero_based )
175- assert_array_equal (Xd , X2 .toarray ())
176- assert_array_equal (y , y2 )
173+ for dtype in [np .float32 , np .float64 ]:
174+ f = BytesIO ()
175+ dump_svmlight_file (X .astype (dtype ), y , f ,
176+ zero_based = zero_based )
177+ f .seek (0 )
178+ X2 , y2 = load_svmlight_file (f , dtype = dtype ,
179+ zero_based = zero_based )
180+ assert_equal (X2 .dtype , dtype )
181+ if dtype == np .float32 :
182+ assert_array_almost_equal (
183+ # allow a rounding error at the last decimal place
184+ Xd .astype (dtype ), X2 .toarray (), 4 )
185+ else :
186+ assert_array_almost_equal (
187+ # allow a rounding error at the last decimal place
188+ Xd .astype (dtype ), X2 .toarray (), 15 )
189+ assert_array_equal (y , y2 )
0 commit comments