Skip to content

Commit 4fa74c7

Browse files
committed
TEST: Add DtypeOverrideMixin to test __init__ and to_filename dtype args
1 parent 696840f commit 4fa74c7

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

nibabel/cifti2/tests/test_cifti2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313

1414
from nibabel.tests.test_dataobj_images import TestDataobjAPI as _TDA
15-
from nibabel.tests.test_image_api import SerializeMixin
15+
from nibabel.tests.test_image_api import SerializeMixin, DtypeOverrideMixin
1616

1717

1818
def compare_xml_leaf(str1, str2):
@@ -415,7 +415,7 @@ def test_underscoring():
415415
assert ci.cifti2._underscore(camel) == underscored
416416

417417

418-
class TestCifti2ImageAPI(_TDA, SerializeMixin):
418+
class TestCifti2ImageAPI(_TDA, SerializeMixin, DtypeOverrideMixin):
419419
""" Basic validation for Cifti2Image instances
420420
"""
421421
# A callable returning an image from ``image_maker(data, header)``
@@ -426,6 +426,8 @@ class TestCifti2ImageAPI(_TDA, SerializeMixin):
426426
ni_header_maker = Nifti2Header
427427
example_shapes = ((2,), (2, 3), (2, 3, 4))
428428
standard_extension = '.nii'
429+
storable_dtypes = (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32,
430+
np.int64, np.uint64, np.float32, np.float64)
429431

430432
def make_imaker(self, arr, header=None, ni_header=None):
431433
for idx, sz in enumerate(arr.shape):

nibabel/tests/test_image_api.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
from .test_parrec import EXAMPLE_IMAGES as PARREC_EXAMPLE_IMAGES
5656
from .test_brikhead import EXAMPLE_IMAGES as AFNI_EXAMPLE_IMAGES
5757

58+
from nibabel.arraywriters import WriterError
59+
5860

5961
def maybe_deprecated(meth_name):
6062
return pytest.deprecated_call() if meth_name == 'get_data' else nullcontext()
@@ -181,7 +183,7 @@ def validate_get_data_deprecated(self, imaker, params):
181183
assert_array_equal(np.asanyarray(img.dataobj), data)
182184

183185

184-
class GetSetDtypeMixin(object):
186+
class GetSetDtypeMixin:
185187
""" Adds dtype tests
186188
187189
Add this one if your image has ``get_data_dtype`` and ``set_data_dtype``.
@@ -666,6 +668,46 @@ def prox_imaker():
666668
yield make_prox_imaker(arr.copy(), aff, hdr), params
667669

668670

671+
class DtypeOverrideMixin(GetSetDtypeMixin):
672+
""" Test images that can accept ``dtype`` arguments to ``__init__`` and
673+
``to_file_map``
674+
"""
675+
676+
def validate_init_dtype_override(self, imaker, params):
677+
img = imaker()
678+
klass = img.__class__
679+
for dtype in self.storable_dtypes:
680+
if hasattr(img, 'affine'):
681+
new_img = klass(img.dataobj, img.affine, header=img.header, dtype=dtype)
682+
else: # XXX This is for CIFTI-2, these validators might need refactoring
683+
new_img = klass(img.dataobj, header=img.header, dtype=dtype)
684+
assert new_img.get_data_dtype() == dtype
685+
686+
if self.has_scaling and self.can_save:
687+
with np.errstate(invalid='ignore'):
688+
rt_img = bytesio_round_trip(new_img)
689+
assert rt_img.get_data_dtype() == dtype
690+
691+
def validate_to_file_dtype_override(self, imaker, params):
692+
if not self.can_save:
693+
raise unittest.SkipTest
694+
img = imaker()
695+
orig_dtype = img.get_data_dtype()
696+
fname = 'image' + self.standard_extension
697+
with InTemporaryDirectory():
698+
for dtype in self.storable_dtypes:
699+
try:
700+
img.to_filename(fname, dtype=dtype)
701+
except WriterError:
702+
# It's possible to try to save to a dtype that requires
703+
# scaling, and images without scale factors will fail.
704+
# We're not testing that here.
705+
continue
706+
rt_img = img.__class__.from_filename(fname)
707+
assert rt_img.get_data_dtype() == dtype
708+
assert img.get_data_dtype() == orig_dtype
709+
710+
669711
class ImageHeaderAPI(MakeImageAPI):
670712
""" When ``self.image_maker`` is an image class, make header from class
671713
"""
@@ -674,7 +716,12 @@ def header_maker(self):
674716
return self.image_maker.header_class()
675717

676718

677-
class TestAnalyzeAPI(ImageHeaderAPI):
719+
class TestSpatialImageAPI(ImageHeaderAPI):
720+
klass = image_maker = SpatialImage
721+
can_save = False
722+
723+
724+
class TestAnalyzeAPI(TestSpatialImageAPI, DtypeOverrideMixin):
678725
""" General image validation API instantiated for Analyze images
679726
"""
680727
klass = image_maker = AnalyzeImage
@@ -685,11 +732,6 @@ class TestAnalyzeAPI(ImageHeaderAPI):
685732
storable_dtypes = (np.uint8, np.int16, np.int32, np.float32, np.float64)
686733

687734

688-
class TestSpatialImageAPI(TestAnalyzeAPI):
689-
klass = image_maker = SpatialImage
690-
can_save = False
691-
692-
693735
class TestSpm99AnalyzeAPI(TestAnalyzeAPI):
694736
# SPM-type analyze need scipy for mat file IO
695737
klass = image_maker = Spm99AnalyzeImage

0 commit comments

Comments
 (0)