@@ -276,10 +276,10 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
276276 return spmatrix
277277
278278
279- def check_array (array , accept_sparse = False , dtype = "numeric" , order = None ,
280- copy = False , force_all_finite = True , ensure_2d = True ,
281- allow_nd = False , ensure_min_samples = 1 , ensure_min_features = 1 ,
282- warn_on_dtype = False , estimator = None ):
279+ def check_array (array , accept_sparse = False , accept_masked = False ,
280+ dtype = "numeric" , order = None , copy = False , force_all_finite = True ,
281+ ensure_2d = True , allow_nd = False , ensure_min_samples = 1 ,
282+ ensure_min_features = 1 , warn_on_dtype = False , estimator = None ):
283283 """Input validation on an array, list, sparse matrix or similar.
284284
285285 By default, the input is converted to an at least 2D numpy array.
@@ -353,6 +353,13 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
353353 The converted and validated X.
354354
355355 """
356+
357+ # accept masked check
358+ masked = hasattr (array ,'mask' )
359+ if not accept_masked and masked :
360+ raise TypeError ('Masked arrays are not supported.' )
361+ mask = False if not masked else array .mask
362+
356363 # accept_sparse 'None' deprecation check
357364 if accept_sparse is None :
358365 warnings .warn (
@@ -399,7 +406,8 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
399406 array = _ensure_sparse_format (array , accept_sparse , dtype , copy ,
400407 force_all_finite )
401408 else :
402- array = np .array (array , dtype = dtype , order = order , copy = copy )
409+ array = np .ma .array (array , dtype = dtype , order = order ,
410+ copy = copy , mask = mask )
403411
404412 if ensure_2d :
405413 if array .ndim == 1 :
@@ -408,9 +416,10 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
408416 "Reshape your data either using array.reshape(-1, 1) if "
409417 "your data has a single feature or array.reshape(1, -1) "
410418 "if it contains a single sample." .format (array ))
411- array = np .atleast_2d (array )
419+ array = np .ma . atleast_2d (array )
412420 # To ensure that array flags are maintained
413- array = np .array (array , dtype = dtype , order = order , copy = copy )
421+ array = np .ma .array (array , dtype = dtype , order = order ,
422+ copy = copy , mask = mask )
414423
415424 # make sure we actually converted to numeric:
416425 if dtype_numeric and array .dtype .kind == "O" :
@@ -442,10 +451,12 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
442451 msg = ("Data with input dtype %s was converted to %s%s."
443452 % (dtype_orig , array .dtype , context ))
444453 warnings .warn (msg , DataConversionWarning )
454+ if not masked :
455+ array = np .ma .getdata (array )
445456 return array
446457
447458
448- def check_X_y (X , y , accept_sparse = False , dtype = "numeric" , order = None ,
459+ def check_X_y (X , y , accept_sparse = False , accept_masked = True , dtype = "numeric" , order = None ,
449460 copy = False , force_all_finite = True , ensure_2d = True ,
450461 allow_nd = False , multi_output = False , ensure_min_samples = 1 ,
451462 ensure_min_features = 1 , y_numeric = False ,
@@ -537,7 +548,7 @@ def check_X_y(X, y, accept_sparse=False, dtype="numeric", order=None,
537548 y_converted : object
538549 The converted and validated y.
539550 """
540- X = check_array (X , accept_sparse , dtype , order , copy , force_all_finite ,
551+ X = check_array (X , accept_sparse , accept_masked , dtype , order , copy , force_all_finite ,
541552 ensure_2d , allow_nd , ensure_min_samples ,
542553 ensure_min_features , warn_on_dtype , estimator )
543554 if multi_output :
0 commit comments