16
16
import base64
17
17
import sys
18
18
import warnings
19
- from typing import Type
19
+ from copy import copy
20
+ from typing import Type , cast
20
21
21
22
import numpy as np
22
23
27
28
from ..nifti1 import data_type_codes , intent_codes , xform_codes
28
29
from .util import KIND2FMT , array_index_order_codes , gifti_encoding_codes , gifti_endian_codes
29
30
31
+ GIFTI_DTYPES = (
32
+ data_type_codes ['NIFTI_TYPE_UINT8' ],
33
+ data_type_codes ['NIFTI_TYPE_INT32' ],
34
+ data_type_codes ['NIFTI_TYPE_FLOAT32' ],
35
+ )
36
+
30
37
31
38
class _GiftiMDList (list ):
32
39
"""List view of GiftiMetaData object that will translate most operations"""
@@ -462,11 +469,7 @@ def __init__(
462
469
if datatype is None :
463
470
if self .data is None :
464
471
datatype = 'none'
465
- elif self .data .dtype in (
466
- np .dtype ('uint8' ),
467
- np .dtype ('int32' ),
468
- np .dtype ('float32' ),
469
- ):
472
+ elif data_type_codes [self .data .dtype ] in GIFTI_DTYPES :
470
473
datatype = self .data .dtype
471
474
else :
472
475
raise ValueError (
@@ -848,20 +851,45 @@ def _to_xml_element(self):
848
851
GIFTI .append (dar ._to_xml_element ())
849
852
return GIFTI
850
853
851
- def to_xml (self , enc = 'utf-8' ) -> bytes :
854
+ def to_xml (self , enc = 'utf-8' , * , mode = 'strict' ) -> bytes :
852
855
"""Return XML corresponding to image content"""
856
+ if mode == 'strict' :
857
+ if any (arr .datatype not in GIFTI_DTYPES for arr in self .darrays ):
858
+ raise ValueError (
859
+ 'GiftiImage contains data arrays with invalid data types; '
860
+ 'use mode="compat" to automatically cast to conforming types'
861
+ )
862
+ elif mode == 'compat' :
863
+ darrays = []
864
+ for arr in self .darrays :
865
+ if arr .datatype not in GIFTI_DTYPES :
866
+ arr = copy (arr )
867
+ # TODO: Better typing for recoders
868
+ dtype = cast (np .dtype , data_type_codes .dtype [arr .datatype ])
869
+ if np .issubdtype (dtype , np .floating ):
870
+ arr .datatype = data_type_codes ['float32' ]
871
+ elif np .issubdtype (dtype , np .integer ):
872
+ arr .datatype = data_type_codes ['int32' ]
873
+ else :
874
+ raise ValueError (f'Cannot convert { dtype } to float32/int32' )
875
+ darrays .append (arr )
876
+ gii = copy (self )
877
+ gii .darrays = darrays
878
+ return gii .to_xml (enc = enc , mode = 'strict' )
879
+ elif mode != 'force' :
880
+ raise TypeError (f'Unknown mode { mode } ' )
853
881
header = b"""<?xml version="1.0" encoding="UTF-8"?>
854
882
<!DOCTYPE GIFTI SYSTEM "http://www.nitrc.org/frs/download.php/115/gifti.dtd">
855
883
"""
856
884
return header + super ().to_xml (enc )
857
885
858
886
# Avoid the indirection of going through to_file_map
859
- def to_bytes (self , enc = 'utf-8' ):
860
- return self .to_xml (enc = enc )
887
+ def to_bytes (self , enc = 'utf-8' , * , mode = 'strict' ):
888
+ return self .to_xml (enc = enc , mode = mode )
861
889
862
890
to_bytes .__doc__ = SerializableImage .to_bytes .__doc__
863
891
864
- def to_file_map (self , file_map = None , enc = 'utf-8' ):
892
+ def to_file_map (self , file_map = None , enc = 'utf-8' , * , mode = 'strict' ):
865
893
"""Save the current image to the specified file_map
866
894
867
895
Parameters
@@ -877,7 +905,7 @@ def to_file_map(self, file_map=None, enc='utf-8'):
877
905
if file_map is None :
878
906
file_map = self .file_map
879
907
with file_map ['image' ].get_prepare_fileobj ('wb' ) as f :
880
- f .write (self .to_xml (enc = enc ))
908
+ f .write (self .to_xml (enc = enc , mode = mode ))
881
909
882
910
@classmethod
883
911
def from_file_map (klass , file_map , buffer_size = 35000000 , mmap = True ):
0 commit comments