Skip to content

Commit ba49b38

Browse files
committed
Support for masked arrays
1 parent 90f0bbf commit ba49b38

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

sklearn/manifold/mds.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@ def _smacof_single(dissimilarities, metric=True, n_components=2, init=None,
7272
n_samples = dissimilarities.shape[0]
7373
random_state = check_random_state(random_state)
7474

75-
sim_flat = ((1 - np.tri(n_samples)) * dissimilarities).ravel()
76-
sim_flat_w = sim_flat[sim_flat != 0]
75+
if not metric:
76+
sim_flat = ((1 - np.tri(n_samples)) * dissimilarities).ravel()
77+
sim_flat_w = sim_flat[sim_flat != 0]
7778
if init is None:
7879
# Randomly choose initial configuration
7980
X = random_state.rand(n_samples * n_components)
@@ -88,6 +89,7 @@ def _smacof_single(dissimilarities, metric=True, n_components=2, init=None,
8889

8990
old_stress = None
9091
ir = IsotonicRegression()
92+
masked = hasattr(dissimilarities,'mask')
9193
for it in range(max_iter):
9294
# Compute distance and monotonic regression
9395
dis = euclidean_distances(X)
@@ -108,14 +110,21 @@ def _smacof_single(dissimilarities, metric=True, n_components=2, init=None,
108110
(disparities ** 2).sum())
109111

110112
# Compute stress
111-
stress = ((dis.ravel() - disparities.ravel()) ** 2).sum() / 2
113+
if masked:
114+
stress = ((dis - disparities).compressed() ** 2).sum() / 2
115+
else:
116+
stress = ((dis.ravel() - disparities.ravel()) ** 2).sum() / 2
112117

113118
# Update X using the Guttman transform
114119
dis[dis == 0] = 1e-5
115120
ratio = disparities / dis
121+
if masked:
122+
ratio[ratio.mask] = 1
116123
B = - ratio
117124
B[np.arange(len(B)), np.arange(len(B))] += ratio.sum(axis=1)
118-
X = 1. / n_samples * np.dot(B, X)
125+
X = 1. / n_samples * np.ma.dot(B, X)
126+
if hasattr(X,'mask'):
127+
X = np.ma.getdata(X)
119128

120129
dis = np.sqrt((X ** 2).sum(axis=1)).sum()
121130
if verbose >= 2:
@@ -231,7 +240,7 @@ def smacof(dissimilarities, metric=True, n_components=2, init=None, n_init=8,
231240
hypothesis" Kruskal, J. Psychometrika, 29, (1964)
232241
"""
233242

234-
dissimilarities = check_array(dissimilarities)
243+
dissimilarities = check_array(dissimilarities, accept_masked=metric)
235244
random_state = check_random_state(random_state)
236245

237246
if hasattr(init, '__array__'):
@@ -402,7 +411,8 @@ def fit_transform(self, X, y=None, init=None):
402411
algorithm. By default, the algorithm is initialized with a randomly
403412
chosen array.
404413
"""
405-
X = check_array(X)
414+
X = check_array(X,
415+
accept_masked=(self.dissimilarity == "precomputed" and self.metric))
406416
if X.shape[0] == X.shape[1] and self.dissimilarity != "precomputed":
407417
warnings.warn("The MDS API has changed. ``fit`` now constructs an"
408418
" dissimilarity matrix from data. To use a custom "

sklearn/utils/validation.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,10 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
276276
return spmatrix
277277

278278

279-
def check_array(array, accept_sparse=False, dtype="numeric", order=None,
280-
copy=False, force_all_finite=True, ensure_2d=True,
281-
allow_nd=False, ensure_min_samples=1, ensure_min_features=1,
282-
warn_on_dtype=False, estimator=None):
279+
def check_array(array, accept_sparse=False, accept_masked=False,
280+
dtype="numeric", order=None, copy=False, force_all_finite=True,
281+
ensure_2d=True, allow_nd=False, ensure_min_samples=1,
282+
ensure_min_features=1, warn_on_dtype=False, estimator=None):
283283
"""Input validation on an array, list, sparse matrix or similar.
284284
285285
By default, the input is converted to an at least 2D numpy array.
@@ -353,6 +353,13 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
353353
The converted and validated X.
354354
355355
"""
356+
357+
# accept masked check
358+
masked = hasattr(array,'mask')
359+
if not accept_masked and masked:
360+
raise TypeError('Masked arrays are not supported.')
361+
mask = False if not masked else array.mask
362+
356363
# accept_sparse 'None' deprecation check
357364
if accept_sparse is None:
358365
warnings.warn(
@@ -399,7 +406,8 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
399406
array = _ensure_sparse_format(array, accept_sparse, dtype, copy,
400407
force_all_finite)
401408
else:
402-
array = np.array(array, dtype=dtype, order=order, copy=copy)
409+
array = np.ma.array(array, dtype=dtype, order=order,
410+
copy=copy, mask=mask)
403411

404412
if ensure_2d:
405413
if array.ndim == 1:
@@ -408,9 +416,10 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
408416
"Reshape your data either using array.reshape(-1, 1) if "
409417
"your data has a single feature or array.reshape(1, -1) "
410418
"if it contains a single sample.".format(array))
411-
array = np.atleast_2d(array)
419+
array = np.ma.atleast_2d(array)
412420
# To ensure that array flags are maintained
413-
array = np.array(array, dtype=dtype, order=order, copy=copy)
421+
array = np.ma.array(array, dtype=dtype, order=order,
422+
copy=copy, mask=mask)
414423

415424
# make sure we actually converted to numeric:
416425
if dtype_numeric and array.dtype.kind == "O":
@@ -442,10 +451,12 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
442451
msg = ("Data with input dtype %s was converted to %s%s."
443452
% (dtype_orig, array.dtype, context))
444453
warnings.warn(msg, DataConversionWarning)
454+
if not masked:
455+
array = np.ma.getdata(array)
445456
return array
446457

447458

448-
def check_X_y(X, y, accept_sparse=False, dtype="numeric", order=None,
459+
def check_X_y(X, y, accept_sparse=False, accept_masked=True, dtype="numeric", order=None,
449460
copy=False, force_all_finite=True, ensure_2d=True,
450461
allow_nd=False, multi_output=False, ensure_min_samples=1,
451462
ensure_min_features=1, y_numeric=False,
@@ -537,7 +548,7 @@ def check_X_y(X, y, accept_sparse=False, dtype="numeric", order=None,
537548
y_converted : object
538549
The converted and validated y.
539550
"""
540-
X = check_array(X, accept_sparse, dtype, order, copy, force_all_finite,
551+
X = check_array(X, accept_sparse, accept_masked, dtype, order, copy, force_all_finite,
541552
ensure_2d, allow_nd, ensure_min_samples,
542553
ensure_min_features, warn_on_dtype, estimator)
543554
if multi_output:

0 commit comments

Comments
 (0)