Skip to content

Commit f2c1084

Browse files
committed
ENH: Enforce GIFTI compatibility at write
1 parent 89d20b2 commit f2c1084

File tree

2 files changed

+40
-12
lines changed

2 files changed

+40
-12
lines changed

nibabel/gifti/gifti.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import base64
1717
import sys
1818
import warnings
19-
from typing import Type
19+
from copy import copy
20+
from typing import Type, cast
2021

2122
import numpy as np
2223

@@ -27,6 +28,12 @@
2728
from ..nifti1 import data_type_codes, intent_codes, xform_codes
2829
from .util import KIND2FMT, array_index_order_codes, gifti_encoding_codes, gifti_endian_codes
2930

31+
GIFTI_DTYPES = (
32+
data_type_codes['NIFTI_TYPE_UINT8'],
33+
data_type_codes['NIFTI_TYPE_INT32'],
34+
data_type_codes['NIFTI_TYPE_FLOAT32'],
35+
)
36+
3037

3138
class _GiftiMDList(list):
3239
"""List view of GiftiMetaData object that will translate most operations"""
@@ -462,11 +469,7 @@ def __init__(
462469
if datatype is None:
463470
if self.data is None:
464471
datatype = 'none'
465-
elif self.data.dtype in (
466-
np.dtype('uint8'),
467-
np.dtype('int32'),
468-
np.dtype('float32'),
469-
):
472+
elif data_type_codes[self.data.dtype] in GIFTI_DTYPES:
470473
datatype = self.data.dtype
471474
else:
472475
raise ValueError(
@@ -848,20 +851,45 @@ def _to_xml_element(self):
848851
GIFTI.append(dar._to_xml_element())
849852
return GIFTI
850853

851-
def to_xml(self, enc='utf-8') -> bytes:
854+
def to_xml(self, enc='utf-8', *, mode='strict') -> bytes:
852855
"""Return XML corresponding to image content"""
856+
if mode == 'strict':
857+
if any(arr.datatype not in GIFTI_DTYPES for arr in self.darrays):
858+
raise ValueError(
859+
'GiftiImage contains data arrays with invalid data types; '
860+
'use mode="compat" to automatically cast to conforming types'
861+
)
862+
elif mode == 'compat':
863+
darrays = []
864+
for arr in self.darrays:
865+
if arr.datatype not in GIFTI_DTYPES:
866+
arr = copy(arr)
867+
# TODO: Better typing for recoders
868+
dtype = cast(np.dtype, data_type_codes.dtype[arr.datatype])
869+
if np.issubdtype(dtype, np.floating):
870+
arr.datatype = data_type_codes['float32']
871+
elif np.issubdtype(dtype, np.integer):
872+
arr.datatype = data_type_codes['int32']
873+
else:
874+
raise ValueError(f'Cannot convert {dtype} to float32/int32')
875+
darrays.append(arr)
876+
gii = copy(self)
877+
gii.darrays = darrays
878+
return gii.to_xml(enc=enc, mode='strict')
879+
elif mode != 'force':
880+
raise TypeError(f'Unknown mode {mode}')
853881
header = b"""<?xml version="1.0" encoding="UTF-8"?>
854882
<!DOCTYPE GIFTI SYSTEM "http://www.nitrc.org/frs/download.php/115/gifti.dtd">
855883
"""
856884
return header + super().to_xml(enc)
857885

858886
# Avoid the indirection of going through to_file_map
859-
def to_bytes(self, enc='utf-8'):
860-
return self.to_xml(enc=enc)
887+
def to_bytes(self, enc='utf-8', *, mode='strict'):
888+
return self.to_xml(enc=enc, mode=mode)
861889

862890
to_bytes.__doc__ = SerializableImage.to_bytes.__doc__
863891

864-
def to_file_map(self, file_map=None, enc='utf-8'):
892+
def to_file_map(self, file_map=None, enc='utf-8', *, mode='strict'):
865893
"""Save the current image to the specified file_map
866894
867895
Parameters
@@ -877,7 +905,7 @@ def to_file_map(self, file_map=None, enc='utf-8'):
877905
if file_map is None:
878906
file_map = self.file_map
879907
with file_map['image'].get_prepare_fileobj('wb') as f:
880-
f.write(self.to_xml(enc=enc))
908+
f.write(self.to_xml(enc=enc, mode=mode))
881909

882910
@classmethod
883911
def from_file_map(klass, file_map, buffer_size=35000000, mmap=True):

nibabel/gifti/tests/test_gifti.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def test_darray_dtype_coercion_failures():
505505
datatype=darray_dtype,
506506
)
507507
gii = GiftiImage(darrays=[da])
508-
gii_copy = GiftiImage.from_bytes(gii.to_bytes())
508+
gii_copy = GiftiImage.from_bytes(gii.to_bytes(mode='force'))
509509
da_copy = gii_copy.darrays[0]
510510
assert np.dtype(da_copy.data.dtype) == np.dtype(darray_dtype)
511511
assert_array_equal(da_copy.data, da.data)

0 commit comments

Comments
 (0)