Skip to content

Commit 285d9cc

Browse files
committed
Merge pull request scikit-learn#938 from ogrisel/svmlight-double-precision
MRG: preserve double precision values in svmlight serializer
2 parents 1864bca + 6b34e14 commit 285d9cc

File tree

4 files changed

+42
-15
lines changed

4 files changed

+42
-15
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ Changelog
2828
- SVMlight file format loader now detects compressed (gzip/bzip2) files and
2929
decompresses them on the fly.
3030

31+
- SVMlight file format serializer now preserves double precision floating
32+
point values, by `Olivier Grisel`_.
33+
3134
- A common testing framework for all estimators was added.
3235

3336
- Decision trees and forests of randomized trees now support multi-output

sklearn/datasets/svmlight_format.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,20 @@ def _dump_svmlight(X, y, f, zero_based):
195195
is_sp = int(hasattr(X, "tocsr"))
196196

197197
one_based = not zero_based
198+
if X.dtype == np.float64:
199+
value_pattern = u"%d:%0.16e"
200+
else:
201+
value_pattern = u"%d:%f"
202+
203+
if y.dtype.kind == 'i':
204+
line_pattern = u"%d %s\n"
205+
else:
206+
line_pattern = u"%f %s\n"
207+
198208
for i in xrange(X.shape[0]):
199-
s = u" ".join([u"%d:%f" % (j + one_based, X[i, j])
209+
s = u" ".join([value_pattern % (j + one_based, X[i, j])
200210
for j in X[i].nonzero()[is_sp]])
201-
f.write((u"%f %s\n" % (y[i], s)).encode('ascii'))
211+
f.write((line_pattern % (y[i], s)).encode('ascii'))
202212

203213

204214
def dump_svmlight_file(X, y, f, zero_based=True):

sklearn/datasets/tests/data/svmlight_classification.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
2.0 6:1.0 13:-3
55
# another comment
66
3.0 21:27
7+
4.0 2:1.234567890123456e10 # double precision value

sklearn/datasets/tests/test_svmlight_format.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import shutil
77
import tempfile
88

9-
from numpy.testing import assert_equal, assert_array_equal
9+
from numpy.testing import assert_equal
10+
from numpy.testing import assert_array_equal
11+
from numpy.testing import assert_array_almost_equal
1012
from nose.tools import raises
1113

1214
from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
@@ -22,10 +24,10 @@ def test_load_svmlight_file():
2224
X, y = load_svmlight_file(datafile)
2325

2426
# test X's shape
25-
assert_equal(X.indptr.shape[0], 4)
26-
assert_equal(X.shape[0], 3)
27+
assert_equal(X.indptr.shape[0], 5)
28+
assert_equal(X.shape[0], 4)
2729
assert_equal(X.shape[1], 21)
28-
assert_equal(y.shape[0], 3)
30+
assert_equal(y.shape[0], 4)
2931

3032
# test X's non-zero values
3133
for i, j, val in ((0, 2, 2.5), (0, 10, -5.2), (0, 15, 1.5),
@@ -46,7 +48,7 @@ def test_load_svmlight_file():
4648
assert_equal(X[0, 2], 5)
4749

4850
# test y
49-
assert_array_equal(y, [1, 2, 3])
51+
assert_array_equal(y, [1, 2, 3, 4])
5052

5153

5254
def test_load_svmlight_file_fd():
@@ -86,8 +88,8 @@ def test_load_svmlight_file_n_features():
8688
X, y = load_svmlight_file(datafile, n_features=20)
8789

8890
# test X's shape
89-
assert_equal(X.indptr.shape[0], 4)
90-
assert_equal(X.shape[0], 3)
91+
assert_equal(X.indptr.shape[0], 5)
92+
assert_equal(X.shape[0], 4)
9193
assert_equal(X.shape[1], 20)
9294

9395
# test X's non-zero values
@@ -168,9 +170,20 @@ def test_dump():
168170

169171
for X in (Xs, Xd):
170172
for zero_based in (True, False):
171-
f = BytesIO()
172-
dump_svmlight_file(X, y, f, zero_based=zero_based)
173-
f.seek(0)
174-
X2, y2 = load_svmlight_file(f, zero_based=zero_based)
175-
assert_array_equal(Xd, X2.toarray())
176-
assert_array_equal(y, y2)
173+
for dtype in [np.float32, np.float64]:
174+
f = BytesIO()
175+
dump_svmlight_file(X.astype(dtype), y, f,
176+
zero_based=zero_based)
177+
f.seek(0)
178+
X2, y2 = load_svmlight_file(f, dtype=dtype,
179+
zero_based=zero_based)
180+
assert_equal(X2.dtype, dtype)
181+
if dtype == np.float32:
182+
assert_array_almost_equal(
183+
# allow a rounding error at the last decimal place
184+
Xd.astype(dtype), X2.toarray(), 4)
185+
else:
186+
assert_array_almost_equal(
187+
# allow a rounding error at the last decimal place
188+
Xd.astype(dtype), X2.toarray(), 15)
189+
assert_array_equal(y, y2)

0 commit comments

Comments
 (0)