From c4c9347e3ec6ee5b52f3470120fee68867277222 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Sat, 25 Mar 2023 11:49:28 -0600 Subject: [PATCH 0001/1125] MAINT: Avoid specific PyVista versions (#11595) --- environment.yml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index bbd8deb3c1a..960cbf8e775 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: - imageio-ffmpeg>=0.4.1 - vtk>=9.2 - traitlets -- pyvista>=0.32,!=0.35.2,<0.38.0 +- pyvista>=0.32,!=0.35.2,!=0.38.0,!=0.38.1,!=0.38.2,!=0.38.3,!=0.38.4,!=0.38.5 - pyvistaqt>=0.4 - qdarkstyle - darkdetect diff --git a/requirements.txt b/requirements.txt index 7aba1653451..1c2eab9d672 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,7 @@ xlrd imageio>=2.6.1 imageio-ffmpeg>=0.4.1 traitlets -pyvista>=0.32,!=0.35.2,<0.38.0 +pyvista>=0.32,!=0.35.2,!=0.38.0,!=0.38.1,!=0.38.2,!=0.38.3,!=0.38.4,!=0.38.5 pyvistaqt>=0.4 mffpy>=0.5.7 ipywidgets From f319107095d0ca09cc453ddc57815d9f7fd0ccfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Barth=C3=A9lemy?= Date: Sat, 25 Mar 2023 22:44:00 +0100 Subject: [PATCH 0002/1125] [ENH] Add forward IIR filtering (#11078) Co-authored-by: Eric Larson --- doc/changes/latest.inc | 1 + doc/changes/names.inc | 2 + mne/decoding/transformer.py | 3 +- mne/filter.py | 80 +++++++++++++------ mne/tests/test_filter.py | 53 +++++++++--- mne/utils/docs.py | 15 ++-- mne/viz/_figure.py | 4 +- .../preprocessing/25_background_filtering.py | 9 +++ 8 files changed, 122 insertions(+), 45 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 2a4860065b4..c74f97f3eca 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -28,6 +28,7 @@ Enhancements - Changed suggested type for ``ch_groups``` in `mne.viz.plot_sensors` from array to list of list(s) (arrays are still supported). (:gh:`11465` by `Hyonyoung Shin`_) - Add support for UCL/FIL OPM data using :func:`mne.io.read_raw_fil` (:gh:`11366` by :newcontrib:`George O'Neill` and `Robert Seymour`_) - Forward argument ``axes`` from `mne.viz.plot_sensors` to `mne.channels.DigMontage.plot` (:gh:`11470` by :newcontrib:`Jan Ebert` and `Mathieu Scheltienne`_) +- Add forward IIR filtering, using parameters ``method='iir', phase='forward'`` (:gh:`11078` by :newcontrib:`Quentin Barthélemy`) - Added ability to read stimulus durations from SNIRF files when using :func:`mne.io.read_raw_snirf` (:gh:`11397` by `Robert Luke`_) - Add :meth:`mne.Info.save` to save an :class:`mne.Info` object to a fif file (:gh:`11401` by `Alex Rockhill`_) - Improved error message when downloads are corrupted for :func:`mne.datasets.sample.data_path` and related functions (:gh:`11407` by `Eric Larson`_) diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 72dbdfd4e36..29d4697582c 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -398,6 +398,8 @@ .. _Qianliang Li: https://www.dtu.dk/english/service/phonebook/person?id=126774 +.. _Quentin Barthélemy: https://github.com/qbarthelemy + .. _Quentin Bertrand: https://github.com/QB3 .. _Qunxi Dong: https://github.com/dongqunxi diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index 4bf73cc5f5d..63999640afd 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -450,8 +450,7 @@ class FilterEstimator(TransformerMixin): Number of jobs to run in parallel. Can be 'cuda' if ``cupy`` is installed properly and method='fir'. 
method : str - 'fir' will use overlap-add FIR filtering, 'iir' will use IIR - forward-backward filtering (via filtfilt). + 'fir' will use overlap-add FIR filtering, 'iir' will use IIR filtering. iir_params : dict | None Dictionary of parameters to use for IIR filtering. See mne.filter.construct_iir_filter for details. If iir_params diff --git a/mne/filter.py b/mne/filter.py index e70af37fdfc..5c934ae52ec 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -420,20 +420,29 @@ def _check_coefficients(system): 'coefficients.') -def _filtfilt(x, iir_params, picks, n_jobs, copy): - """Call filtfilt.""" +def _iir_filter(x, iir_params, picks, n_jobs, copy, phase='zero'): + """Call filtfilt or lfilter.""" # set up array for filtering, reshape to 2D, operate on last axis - from scipy.signal import filtfilt, sosfiltfilt - padlen = min(iir_params['padlen'], x.shape[-1] - 1) + from scipy.signal import filtfilt, sosfiltfilt, lfilter, sosfilt x, orig_shape, picks = _prep_for_filtering(x, copy, picks) - if 'sos' in iir_params: - fun = partial(sosfiltfilt, sos=iir_params['sos'], padlen=padlen, - axis=-1) - _check_coefficients(iir_params['sos']) + if phase in ('zero', 'zero-double'): + padlen = min(iir_params['padlen'], x.shape[-1] - 1) + if 'sos' in iir_params: + fun = partial(sosfiltfilt, sos=iir_params['sos'], padlen=padlen, + axis=-1) + _check_coefficients(iir_params['sos']) + else: + fun = partial(filtfilt, b=iir_params['b'], a=iir_params['a'], + padlen=padlen, axis=-1) + _check_coefficients((iir_params['b'], iir_params['a'])) else: - fun = partial(filtfilt, b=iir_params['b'], a=iir_params['a'], - padlen=padlen, axis=-1) - _check_coefficients((iir_params['b'], iir_params['a'])) + if 'sos' in iir_params: + fun = partial(sosfilt, sos=iir_params['sos'], axis=-1) + _check_coefficients(iir_params['sos']) + else: + fun = partial(lfilter, b=iir_params['b'], a=iir_params['a'], + axis=-1) + _check_coefficients((iir_params['b'], iir_params['a'])) parallel, p_fun, n_jobs = parallel_func(fun, n_jobs) if n_jobs == 1: for p in picks: @@ -508,7 +517,8 @@ def estimate_ringing_samples(system, max_try=100000): @verbose def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, - btype=None, return_copy=True, verbose=None): + btype=None, return_copy=True, *, phase='zero', + verbose=None): """Use IIR parameters to get filtering coefficients. This function works like a wrapper for iirdesign and iirfilter in @@ -563,6 +573,7 @@ def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, ``iir_params`` will be set inplace (if they weren't already). Otherwise, a new ``iir_params`` instance will be created and returned with these entries. 
+ %(phase)s %(verbose)s Returns @@ -659,12 +670,17 @@ def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, Wp = f_pass / (float(sfreq) / 2) # IT will de designed ftype_nice = _ftype_dict.get(ftype, ftype) + _validate_type(phase, str, 'phase') + _check_option('phase', phase, ('zero', 'zero-double', 'forward')) + if phase in ('zero-double', 'zero'): + ptype = 'zero-phase (two-pass forward and reverse) non-causal' + else: + ptype = 'non-linear phase (one-pass forward) causal' logger.info('') logger.info('IIR filter parameters') logger.info('---------------------') - logger.info('%s %s zero-phase (two-pass forward and reverse) ' - 'non-causal filter:' % (ftype_nice, btype)) - # SciPy designs for -3dB but we do forward-backward, so this is -6dB + logger.info(f'{ftype_nice} {btype} {ptype} filter:') + # SciPy designs forward for -3dB, so forward-backward is -6dB if 'order' in iir_params: kwargs = dict(N=iir_params['order'], Wn=Wp, btype=btype, ftype=ftype, output=output) @@ -672,8 +688,12 @@ def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, if key in iir_params: kwargs[key] = iir_params[key] system = iirfilter(**kwargs) - logger.info('- Filter order %d (effective, after forward-backward)' - % (2 * iir_params['order'] * len(Wp),)) + if phase in ('zero', 'zero-double'): + ptype, pmul = '(effective, after forward-backward)', 2 + else: + ptype, pmul = '(forward)', 1 + logger.info('- Filter order %d %s' + % (pmul * iir_params['order'] * len(Wp), ptype)) else: # use gpass / gstop design Ws = np.asanyarray(f_stop) / (float(sfreq) / 2) @@ -694,8 +714,10 @@ def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, cutoffs = sosfreqz(system, worN=Wp * np.pi)[1] else: cutoffs = freqz(system[0], system[1], worN=Wp * np.pi)[1] + cutoffs = 20 * np.log10(np.abs(cutoffs)) # 2 * 20 here because we do forward-backward filtering - cutoffs = 40 * np.log10(np.abs(cutoffs)) + if phase in ('zero', 'zero-double'): + cutoffs *= 2 cutoffs = ', '.join(['%0.2f' % (c,) for c in cutoffs]) logger.info('- Cutoff%s at %s Hz: %s dB' % (_pl(f_pass), edge_freqs, cutoffs)) @@ -816,7 +838,7 @@ def filter_data(data, sfreq, l_freq, h_freq, picks=None, filter_length='auto', data = _overlap_add_filter(data, filt, None, phase, picks, n_jobs, copy, pad) else: - data = _filtfilt(data, filt, picks, n_jobs, copy) + data = _iir_filter(data, filt, picks, n_jobs, copy, phase) return data @@ -977,7 +999,8 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', data, sfreq, None, h_freq, None, h_trans_bandwidth, filter_length, method, phase, fir_window, fir_design) if method == 'iir': - out = construct_iir_filter(iir_params, f_p, f_s, sfreq, 'lowpass') + out = construct_iir_filter(iir_params, f_p, f_s, sfreq, 'lowpass', + phase=phase) else: # 'fir' freq = [0, f_p, f_s] gain = [1, 1, 0] @@ -992,7 +1015,7 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', filter_length, method, phase, fir_window, fir_design) if method == 'iir': out = construct_iir_filter(iir_params, pass_, stop, sfreq, - 'highpass') + 'highpass', phase=phase) else: # 'fir' freq = [stop, pass_, sfreq / 2.] 
gain = [0, 1, 1] @@ -1010,7 +1033,8 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', fir_window, fir_design) if method == 'iir': out = construct_iir_filter(iir_params, [f_p1, f_p2], - [f_s1, f_s2], sfreq, 'bandpass') + [f_s1, f_s2], sfreq, 'bandpass', + phase=phase) else: # 'fir' freq = [f_s1, f_p1, f_p2, f_s2] gain = [0, 1, 1, 0] @@ -1041,7 +1065,7 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', 'with FIR filtering') out = construct_iir_filter(iir_params, [f_p1[0], f_p2[0]], [f_s1[0], f_s2[0]], sfreq, - 'bandstop') + 'bandstop', phase=phase) else: # 'fir' freq = np.r_[f_p1, f_s1, f_s2, f_p2] gain = np.r_[np.ones_like(f_p1), np.zeros_like(f_s1), @@ -1634,7 +1658,8 @@ def detrend(x, order=1, axis=-1): 'blackman': dict(name='Blackman', ripple=0.0017, attenuation=74), } _known_fir_windows = tuple(sorted(_fir_window_dict.keys())) -_known_phases = ('linear', 'zero', 'zero-double', 'minimum') +_known_phases_fir = ('linear', 'zero', 'zero-double', 'minimum') +_known_phases_iir = ('zero', 'zero-double', 'forward') _known_fir_designs = ('firwin', 'firwin2') _fir_design_dict = { 'firwin': 'Windowed time-domain', @@ -1676,7 +1701,12 @@ def _triage_filter_params(x, sfreq, l_freq, h_freq, fir_design, bands='scalar', reverse=False): """Validate and automate filter parameter selection.""" _validate_type(phase, 'str', 'phase') - _check_option('phase', phase, _known_phases) + if method == 'fir': + _check_option('phase', phase, _known_phases_fir, + extra='when FIR filtering') + else: + _check_option('phase', phase, _known_phases_iir, + extra='when IIR filtering') _validate_type(fir_window, 'str', 'fir_window') _check_option('fir_window', fir_window, _known_fir_windows) _validate_type(fir_design, 'str', 'fir_design') diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index eb821f8ff4f..406f85c5c39 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -146,9 +146,11 @@ def test_iir_stability(): pytest.raises(RuntimeError, filter_data, sig, sfreq, 0.6, None, method='iir', iir_params=dict(ftype='butter', order=8, output='ba')) - # This one should work just fine + # These ones should work just fine filter_data(sig, sfreq, 0.6, None, method='iir', iir_params=dict(ftype='butter', order=8, output='sos')) + filter_data(sig, sfreq, 0.6, None, method='iir', phase='forward', + iir_params=dict(ftype='butter', order=8, output='sos')) # bad system type pytest.raises(ValueError, filter_data, sig, sfreq, 0.6, None, method='iir', iir_params=dict(ftype='butter', order=8, output='foo')) @@ -194,6 +196,31 @@ def test_iir_stability(): assert_allclose(x_sos[100:-100], x_ba[100:-100]) +def test_iir_phase(): + """Test IIR filter phase.""" + sig, sfreq, ind_one = np.zeros(101), 10, 50 + sig[ind_one] = 1 + iir_params = dict(ftype='butter', order=2, output='sos') + + # forward IIR + sig_f = filter_data(sig, sfreq, 0.6, None, method='iir', phase='forward', + iir_params=iir_params) + # test if output is zero before peak + assert_allclose(sig_f[:ind_one], np.zeros(ind_one)) + # test if power is lower after filtering + assert np.linalg.norm(sig) > np.linalg.norm(sig_f) + + # forward-backward IIR + sig_fb = filter_data(sig, sfreq, 0.6, None, method='iir', phase='zero', + iir_params=iir_params) + # test if filtered signal is symmetric + assert_allclose(sig_fb, sig_fb[::-1], rtol=1e-5, atol=1e-6) + # test if peak is not shifted + assert np.argmax(sig_fb) == ind_one + # test if power is lower after bilateral filtering + assert np.linalg.norm(sig_f) > 
np.linalg.norm(sig_fb) + + line_freqs = tuple(range(60, 241, 60)) @@ -615,11 +642,12 @@ def test_detrend(): assert_array_almost_equal(detrend(x, 0), np.zeros_like(x)) +@pytest.mark.parametrize('phase', ('zero', 'zero-double', 'forward')) @pytest.mark.parametrize('output', ('ba', 'sos')) @pytest.mark.parametrize('ftype', ('butter', 'bessel', 'ellip')) @pytest.mark.parametrize('btype', ('lowpass', 'bandpass')) @pytest.mark.parametrize('order', (1, 4)) -def test_reporting_iir(ftype, btype, order, output): +def test_reporting_iir(phase, ftype, btype, order, output): """Test IIR filter reporting.""" fs = 1000. l_freq = 1. if btype == 'bandpass' else None @@ -632,31 +660,34 @@ def test_reporting_iir(ftype, btype, order, output): else: pass_tol = 0.2 with catch_logging() as log: - x = create_filter(None, fs, l_freq, 40., method='iir', + x = create_filter(None, fs, l_freq, 40., method='iir', phase=phase, iir_params=iir_params, verbose=True) order_eff = order * (1 + (btype == 'bandpass')) if output == 'ba': assert len(x['b']) == order_eff + 1 + order_mult = 1. if phase == 'forward' else 2. log = log.getvalue() keys = [ 'IIR', - 'zero-phase', - 'two-pass forward and reverse', - 'non-causal', btype, ftype, - 'Filter order %d' % (order_eff * 2,), + 'Filter order %d' % (order_eff * order_mult,), 'Cutoff ' if btype == 'lowpass' else 'Cutoffs ', ] + if phase == 'forward': + keys += ['non-linear phase', 'one-pass forward', 'causal'] + else: + keys += ['zero-phase', 'two-pass forward and reverse', 'non-causal'] dB_decade = -27.74 if ftype == 'ellip': - dB_cutoff = -6.0 + dB_cutoff = -3.0 elif order == 1 or ftype == 'butter': - dB_cutoff = -6.02 + dB_cutoff = -3.01 else: assert ftype == 'bessel' assert order == 4 - dB_cutoff = -15.16 + dB_cutoff = -7.58 + dB_cutoff *= order_mult if btype == 'lowpass': keys += ['%0.2f dB' % (dB_cutoff,)] for key in keys: @@ -684,7 +715,7 @@ def test_reporting_iir(ftype, btype, order, output): else: passes += [idx_0p1, idx_1] - edge_val = 10 ** (dB_cutoff / 40.) + edge_val = 10 ** (dB_cutoff / (order_mult * 20.)) assert_allclose(h[edges], edge_val, atol=0.01) assert_allclose(h[passes], 1., atol=pass_tol) if ftype == 'butter' and btype == 'lowpass': diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 5007ce180e8..226afda56cc 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2540,14 +2540,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict['phase'] = """ phase : str - Phase of the filter, only used if ``method='fir'``. - Symmetric linear-phase FIR filters are constructed, and if ``phase='zero'`` - (default), the delay of this filter is compensated for, making it - non-causal. If ``phase='zero-double'``, + Phase of the filter. + When ``method='fir'``, symmetric linear-phase FIR filters are constructed, + and if ``phase='zero'`` (default), the delay of this filter is compensated + for, making it non-causal. If ``phase='zero-double'``, then this filter is applied twice, once forward, and once backward (also making it non-causal). If ``'minimum'``, then a minimum-phase filter - will be constricted and applied, which is causal but has weaker stop-band + will be constructed and applied, which is causal but has weaker stop-band suppression. + When ``method='iir'``, ``phase='zero'`` (default) or + ``phase='zero-double'`` constructs and applies IIR filter twice, once + forward, and once backward (making it non-causal) using filtfilt. + If ``phase='forward'``, it constructs and applies forward IIR filter using + lfilter. .. 
versionadded:: 0.13 """ diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index fd0f7a10a2a..87474d65b0b 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -314,7 +314,7 @@ def _load_data(self, start=None, stop=None): def _apply_filter(self, data, start, stop, picks): """Filter (with same defaults as raw.filter()).""" - from ..filter import _overlap_add_filter, _filtfilt + from ..filter import _overlap_add_filter, _iir_filter starts, stops = self.mne.filter_bounds mask = (starts < stop) & (stops > start) starts = np.maximum(starts[mask], start) - start @@ -328,7 +328,7 @@ def _apply_filter(self, data, start, stop, picks): this_data = _overlap_add_filter( this_data, self.mne.filter_coefs, copy=False) else: # IIR - this_data = _filtfilt( + this_data = _iir_filter( this_data, self.mne.filter_coefs, None, 1, False) data[_picks, _start:_stop] = this_data diff --git a/tutorials/preprocessing/25_background_filtering.py b/tutorials/preprocessing/25_background_filtering.py index 314857e5651..998e8919d7a 100644 --- a/tutorials/preprocessing/25_background_filtering.py +++ b/tutorials/preprocessing/25_background_filtering.py @@ -593,6 +593,15 @@ def plot_signal(x, offset): plot_filter(filt, sfreq, freq, gain, 'Chebychev-1 order=8, ripple=6 dB', compensate=True, **kwargs) +# %% +# Similarly to FIR filters, we can define causal IIR filters. + +filt = mne.filter.create_filter(x, sfreq, l_freq=None, h_freq=f_p, + method='iir', phase='forward', + iir_params=iir_params, verbose=True) +plot_filter(filt, sfreq, freq, gain, 'Chebychev-1 order=8, ripple=6 dB', + compensate=False, **kwargs) + # %% # Applying IIR filters # -------------------- From 4b419e03e36bae756af00f90c2c1a3cd6d7c20e1 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Sun, 26 Mar 2023 13:20:56 -0500 Subject: [PATCH 0003/1125] add contralateral referencing example (#11596) --- .../contralateral_referencing.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 examples/preprocessing/contralateral_referencing.py diff --git a/examples/preprocessing/contralateral_referencing.py b/examples/preprocessing/contralateral_referencing.py new file mode 100644 index 00000000000..ad31d94f742 --- /dev/null +++ b/examples/preprocessing/contralateral_referencing.py @@ -0,0 +1,71 @@ +""" +.. _ex-contralateral-referencing: + +======================================= +Using contralateral referencing for EEG +======================================= + +Instead of using a single reference electrode for all channels, some +researchers reference the EEG electrodes in each hemisphere to an electrode in +the contralateral hemisphere (often an electrode over the mastoid bone; this is +common in sleep research for example). Here we demonstrate how to set a +contralateral EEG reference. +""" + +import numpy as np +import mne + +ssvep_folder = mne.datasets.ssvep.data_path() +ssvep_data_raw_path = (ssvep_folder / 'sub-02' / 'ses-01' / 'eeg' / + 'sub-02_ses-01_task-ssvep_eeg.vhdr') +raw = mne.io.read_raw(ssvep_data_raw_path, preload=True) +_ = raw.set_montage('easycap-M1') + +# %% +# The electrodes TP9 and TP10 are near the mastoids so we'll use them as our +# contralateral reference channels. Then we'll create our hemisphere groups. 
+ +raw.rename_channels({ + 'TP9': 'M1', + 'TP10': 'M2' +}) + +# this splits electrodes into 3 groups; left, midline, and right +ch_indices = mne.channels.make_1020_channel_selections(raw.info) + +# convert indices to names +orig_names = np.array(raw.ch_names) +ch_names = {key: orig_names[idxs].tolist() for key, idxs in ch_indices.items()} + +# remove the ref channels from the lists of to-be-rereferenced channels +ch_names['Left'].remove('M1') +ch_names['Right'].remove('M2') + +# %% +# Finally we do the referencing. For the midline channels we'll reference them +# to the mean of the two mastoid channels; the left and right hemispheres we'll +# reference to the single contralateral mastoid channel. + +# midline referencing to mean of mastoids: +mastoids = ['M1', 'M2'] +rereferenced_midline_chs = (raw.copy() + .pick(mastoids + ch_names['Midline']) + .set_eeg_reference(mastoids) + .drop_channels(mastoids) + ) + +# contralateral referencing (alters channels in `raw` in-place): +for ref, hemi in dict(M2=ch_names['Left'], M1=ch_names['Right']).items(): + mne.set_bipolar_reference( + raw, anode=hemi, cathode=[ref] * len(hemi), copy=False + ) +# strip off '-M1' and '-M2' suffixes added to each bipolar-referenced channel +raw.rename_channels(lambda ch_name: ch_name.split('-')[0]) + +# replace unreferenced midline with rereferenced midline +_ = (raw.drop_channels(ch_names['Midline']) + .add_channels([rereferenced_midline_chs])) + +# %% +# Make sure the channel locations still look right: +fig = raw.plot_sensors(show_names=True, sphere='eeglab') From 84259a673caafc3dfb37e07e9d73a3593c5c6acf Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Mon, 27 Mar 2023 10:05:31 +0200 Subject: [PATCH 0004/1125] MAINT: IOError is an alias of OSError (#11600) --- mne/_freesurfer.py | 8 ++++---- mne/annotations.py | 4 ++-- mne/bem.py | 2 +- mne/commands/tests/test_commands.py | 2 +- mne/coreg.py | 20 ++++++++++---------- mne/dipole.py | 4 ++-- mne/fixes.py | 2 +- mne/forward/forward.py | 4 ++-- mne/gui/_coreg.py | 4 ++-- mne/io/artemis123/artemis123.py | 2 +- mne/io/artemis123/utils.py | 2 +- mne/io/brainvision/brainvision.py | 2 +- mne/io/brainvision/tests/test_brainvision.py | 2 +- mne/io/ctf/ctf.py | 2 +- mne/io/ctf/res4.py | 2 +- mne/io/ctf/tests/test_ctf.py | 2 +- mne/io/curry/tests/test_curry.py | 2 +- mne/io/eeglab/eeglab.py | 2 +- mne/io/fiff/raw.py | 2 +- mne/io/fiff/tests/test_raw_fiff.py | 4 ++-- mne/io/kit/kit.py | 6 +++--- mne/io/tests/test_raw.py | 2 +- mne/io/write.py | 2 +- mne/label.py | 6 +++--- mne/morph.py | 2 +- mne/report/report.py | 2 +- mne/source_estimate.py | 10 +++++----- mne/source_space.py | 6 +++--- mne/surface.py | 6 +++--- mne/tests/test_annotations.py | 4 ++-- mne/tests/test_bem.py | 2 +- mne/tests/test_chpi.py | 2 +- mne/tests/test_coreg.py | 4 ++-- mne/tests/test_epochs.py | 4 ++-- mne/tests/test_label.py | 8 ++++---- mne/tests/test_morph.py | 4 ++-- mne/tests/test_source_space.py | 2 +- mne/tests/test_transforms.py | 2 +- mne/time_frequency/tests/test_tfr.py | 2 +- mne/transforms.py | 4 ++-- mne/utils/check.py | 6 +++--- mne/utils/config.py | 2 +- mne/utils/tests/test_check.py | 2 +- mne/viz/backends/_utils.py | 2 +- mne/viz/misc.py | 8 ++++---- mne/viz/tests/test_misc.py | 2 +- 46 files changed, 88 insertions(+), 88 deletions(-) diff --git a/mne/_freesurfer.py b/mne/_freesurfer.py index 64442166d5c..cf0211ec3e6 100644 --- a/mne/_freesurfer.py +++ b/mne/_freesurfer.py @@ -134,7 +134,7 @@ def 
_get_mgz_header(fname): fname = _check_fname(fname, overwrite='read', must_exist=True, name='MRI image') if fname.suffix != ".mgz": - raise IOError('Filename must end with .mgz') + raise OSError('Filename must end with .mgz') header_dtd = [('version', '>i4'), ('dims', '>i4', (4,)), ('type', '>i4'), ('dof', '>i4'), ('goodRASFlag', '>i2'), ('delta', '>f4', (3,)), ('Mdc', '>f4', (3, 3)), @@ -569,7 +569,7 @@ def read_talxfm(subject, subjects_dir=None, verbose=None): if not path.is_file(): path = subjects_dir / subject / "mri" / "T1.mgz" if not path.is_file(): - raise IOError('mri not found: %s' % path) + raise OSError('mri not found: %s' % path) _, _, mri_ras_t, _, _ = _read_mri_info(path) mri_mni_t = combine_transforms(mri_ras_t, ras_mni_t, 'mri', 'mni_tal') return mri_mni_t @@ -591,7 +591,7 @@ def _check_mri(mri, subject, subjects_dir): if op.basename(mri) == mri: err = (f'Ambiguous filename - found {mri!r} in current folder.\n' 'If this is correct prefix name with relative or absolute path') - raise IOError(err) + raise OSError(err) return mri @@ -746,7 +746,7 @@ def _get_head_surface(surf, subject, subjects_dir, bem=None, verbose=None): return read_bem_surfaces(fname, on_defects='warn')[0] else: return _read_mri_surface(fname) - raise IOError('No head surface found for subject ' + raise OSError('No head surface found for subject ' f'{subject} after trying:\n' + '\n'.join(try_fnames)) diff --git a/mne/annotations.py b/mne/annotations.py index cf45d951c04..1d6ee2c5768 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1134,10 +1134,10 @@ def read_annotations(fname, sfreq='auto', uint16_codec=None): elif name.startswith('events_') and fname.endswith('mat'): annotations = _read_brainstorm_annotations(fname) else: - raise IOError('Unknown annotation file format "%s"' % fname) + raise OSError('Unknown annotation file format "%s"' % fname) if annotations is None: - raise IOError('No annotation data found in file "%s"' % fname) + raise OSError('No annotation data found in file "%s"' % fname) return annotations diff --git a/mne/bem.py b/mne/bem.py index 6a60ce065f5..66c0800a4d9 100644 --- a/mne/bem.py +++ b/mne/bem.py @@ -2247,7 +2247,7 @@ def _ensure_bem_surfaces(bem, extra_allow=(), name='bem'): def _check_file(fname, overwrite): """Prevent overwrites.""" if op.isfile(fname) and not overwrite: - raise IOError(f'File {fname} exists, use --overwrite to overwrite it') + raise OSError(f'File {fname} exists, use --overwrite to overwrite it') _tri_levels = dict( diff --git a/mne/commands/tests/test_commands.py b/mne/commands/tests/test_commands.py index e405778b32b..dd786d3773a 100644 --- a/mne/commands/tests/test_commands.py +++ b/mne/commands/tests/test_commands.py @@ -165,7 +165,7 @@ def test_make_scalp_surfaces(tmp_path, monkeypatch): mne_make_scalp_surfaces.run() assert op.isfile(dense_fname) assert op.isfile(medium_fname) - with pytest.raises(IOError, match='overwrite'): + with pytest.raises(OSError, match='overwrite'): mne_make_scalp_surfaces.run() # actually check the outputs head_py = read_bem_surfaces(dense_fname) diff --git a/mne/coreg.py b/mne/coreg.py index 3419392a013..52616d3c113 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -182,24 +182,24 @@ def create_default_subject(fs_home=None, update=False, subjects_dir=None, # make sure freesurfer files exist fs_src = os.path.join(fs_home, 'subjects', 'fsaverage') if not os.path.exists(fs_src): - raise IOError('fsaverage not found at %r. Is fs_home specified ' + raise OSError('fsaverage not found at %r. 
Is fs_home specified ' 'correctly?' % fs_src) for name in ('label', 'mri', 'surf'): dirname = os.path.join(fs_src, name) if not os.path.isdir(dirname): - raise IOError("Freesurfer fsaverage seems to be incomplete: No " + raise OSError("Freesurfer fsaverage seems to be incomplete: No " "directory named %s found in %s" % (name, fs_src)) # make sure destination does not already exist dest = os.path.join(subjects_dir, 'fsaverage') if dest == fs_src: - raise IOError( + raise OSError( "Your subjects_dir points to the freesurfer subjects_dir (%r). " "The default subject can not be created in the freesurfer " "installation directory; please specify a different " "subjects_dir." % subjects_dir) elif (not update) and os.path.exists(dest): - raise IOError( + raise OSError( "Can not create fsaverage because %r already exists in " "subjects_dir %r. Delete or rename the existing fsaverage " "subject folder." % ('fsaverage', subjects_dir)) @@ -527,7 +527,7 @@ def _find_mri_paths(subject, skip_fiducials, subjects_dir): subject : str Name of the mri subject. skip_fiducials : bool - Do not scale the MRI fiducials. If False, an IOError will be raised + Do not scale the MRI fiducials. If False, an OSError will be raised if no fiducials file can be found. subjects_dir : None | path-like Override the SUBJECTS_DIR environment variable @@ -591,7 +591,7 @@ def _find_mri_paths(subject, skip_fiducials, subjects_dir): paths['fid'] = _find_fiducials_files(subject, subjects_dir) # check that we found at least one if len(paths['fid']) == 0: - raise IOError("No fiducials file found for %s. The fiducials " + raise OSError("No fiducials file found for %s. The fiducials " "file should be named " "{subject}/bem/{subject}-fiducials.fif. In " "order to scale an MRI without fiducials set " @@ -740,7 +740,7 @@ def read_mri_cfg(subject, subjects_dir=None): fname = subjects_dir / subject / "MRI scaling parameters.cfg" if not fname.exists(): - raise IOError("%r does not seem to be a scaled mri subject: %r does " + raise OSError("%r does not seem to be a scaled mri subject: %r does " "not exist." % (subject, fname)) logger.info("Reading MRI cfg file %s" % fname) @@ -864,7 +864,7 @@ def scale_bem(subject_to, bem_name, subject_from=None, scale=None, name=bem_name) if os.path.exists(dst): - raise IOError("File already exists: %s" % dst) + raise OSError("File already exists: %s" % dst) surfs = read_bem_surfaces(src, on_defects=on_defects) for surf in surfs: @@ -949,7 +949,7 @@ def scale_mri(subject_from, subject_to, scale, overwrite=False, subjects_dir : None | path-like Override the ``SUBJECTS_DIR`` environment variable. skip_fiducials : bool - Do not scale the MRI fiducials. If False (default), an IOError will be + Do not scale the MRI fiducials. If False (default), an OSError will be raised if no fiducials file can be found. labels : bool Also scale all labels (default True). 
@@ -987,7 +987,7 @@ def scale_mri(subject_from, subject_to, scale, overwrite=False, subjects_dir=subjects_dir) if os.path.exists(dest): if not overwrite: - raise IOError("Subject directory for %s already exists: %r" + raise OSError("Subject directory for %s already exists: %r" % (subject_to, dest)) shutil.rmtree(dest) diff --git a/mne/dipole.py b/mne/dipole.py index 415261063fd..94e984ad617 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -542,7 +542,7 @@ def _read_dipole_text(fname): del line data = np.atleast_2d(np.array(data, float)) if def_line is None: - raise IOError('Dipole text file is missing field definition ' + raise OSError('Dipole text file is missing field definition ' 'comment, cannot parse %s' % (fname,)) # actually parse the fields def_line = def_line.lstrip('%').lstrip('#').strip() @@ -575,7 +575,7 @@ def _read_dipole_text(fname): if len(ignored_fields) > 0: warn('Ignoring extra fields in dipole file: %s' % (ignored_fields,)) if len(fields) != data.shape[1]: - raise IOError('More data fields (%s) found than data columns (%s): %s' + raise OSError('More data fields (%s) found than data columns (%s): %s' % (len(fields), data.shape[1], fields)) logger.info("%d dipole(s) found" % len(data)) diff --git a/mne/fixes.py b/mne/fixes.py index 09bc6a58947..72d7dd1fbb1 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -145,7 +145,7 @@ def _read_volume_info(fobj): 'zras', 'cras']: pair = fobj.readline().decode('utf-8').split('=') if pair[0].strip() != key or len(pair) != 2: - raise IOError('Error parsing volume info.') + raise OSError('Error parsing volume info.') if key in ('valid', 'filename'): volume_info[key] = pair[1].strip() elif key == 'volume': diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 866b319c81b..7f78eeaf08d 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -1880,7 +1880,7 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, try: write_trans(trans, trans_data) except Exception: - raise IOError('trans was a dict, but could not be ' + raise OSError('trans was a dict, but could not be ' 'written to disk as a transform file') elif isinstance(trans, (str, Path, PathLike)): _check_fname(trans, "read", must_exist=True, name="trans") @@ -1894,7 +1894,7 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, try: write_trans(mri, mri_data) except Exception: - raise IOError('mri was a dict, but could not be ' + raise OSError('mri was a dict, but could not be ' 'written to disk as a transform file') elif isinstance(mri, (str, Path, PathLike)): _check_fname(mri, "read", must_exist=True, name="mri") diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index 61232f5ed7d..d5126ba2fee 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -406,7 +406,7 @@ def _set_info_file(self, fname): info_file = _check_fname( fname, overwrite='read', must_exist=True, need_dir=False) valid = True - except IOError: + except OSError: valid = False if valid: style = dict(border="initial") @@ -1151,7 +1151,7 @@ def _add_head_surface(self): self._renderer, surface, self._subject, self._subjects_dir, bem, self._coord_frame, self._to_cf_t, alpha=self._head_opacity) - except IOError: + except OSError: head_actor, head_surf, _ = _plot_head_surface( self._renderer, "head", self._subject, self._subjects_dir, bem, self._coord_frame, self._to_cf_t, diff --git a/mne/io/artemis123/artemis123.py b/mne/io/artemis123/artemis123.py index 2a84588772e..e959158d9c4 100644 --- a/mne/io/artemis123/artemis123.py +++ 
b/mne/io/artemis123/artemis123.py @@ -107,7 +107,7 @@ def _get_artemis123_info(fname, pos_fname=None): elif sectionFlag == 2: values = line.strip().split('\t') if len(values) != 7: - raise IOError('Error parsing line \n\t:%s\n' % line + + raise OSError('Error parsing line \n\t:%s\n' % line + 'from file %s' % header) tmp = dict() for k, v in zip(chan_keys, values): diff --git a/mne/io/artemis123/utils.py b/mne/io/artemis123/utils.py index de7e98c9113..a2448b5fdfc 100644 --- a/mne/io/artemis123/utils.py +++ b/mne/io/artemis123/utils.py @@ -13,7 +13,7 @@ def _load_mne_locs(fname=None): fname = op.join(resource_dir, 'Artemis123_mneLoc.csv') if not op.exists(fname): - raise IOError('MNE locs file "%s" does not exist' % (fname)) + raise OSError('MNE locs file "%s" does not exist' % (fname)) logger.info('Loading mne loc file {}'.format(fname)) locs = dict() diff --git a/mne/io/brainvision/brainvision.py b/mne/io/brainvision/brainvision.py index 99a3cbcce29..d536153e5ac 100644 --- a/mne/io/brainvision/brainvision.py +++ b/mne/io/brainvision/brainvision.py @@ -469,7 +469,7 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): ext = op.splitext(hdr_fname)[-1] ahdr_format = (ext == '.ahdr') if ext not in ('.vhdr', '.ahdr'): - raise IOError("The header file must be given to read the data, " + raise OSError("The header file must be given to read the data, " "not a file with extension '%s'." % ext) settings, cfg, cinfostr, info = _aux_hdr_info(hdr_fname) diff --git a/mne/io/brainvision/tests/test_brainvision.py b/mne/io/brainvision/tests/test_brainvision.py index 64477bbea89..b43287eabe3 100644 --- a/mne/io/brainvision/tests/test_brainvision.py +++ b/mne/io/brainvision/tests/test_brainvision.py @@ -498,7 +498,7 @@ def test_brainvision_data_software_filters_latin1_global_units(): def test_brainvision_data(): """Test reading raw Brain Vision files.""" - pytest.raises(IOError, read_raw_brainvision, vmrk_path) + pytest.raises(OSError, read_raw_brainvision, vmrk_path) pytest.raises(ValueError, read_raw_brainvision, vhdr_path, preload=True, scale="foo") diff --git a/mne/io/ctf/ctf.py b/mne/io/ctf/ctf.py index 4802af31c69..d06e48d8bb5 100644 --- a/mne/io/ctf/ctf.py +++ b/mne/io/ctf/ctf.py @@ -147,7 +147,7 @@ def __init__(self, directory, system_clock='truncate', preload=False, raw_extras.append(sample_info) first_samps = [0] * len(last_samps) if len(fnames) == 0: - raise IOError( + raise OSError( f'Could not find any data, could not find the following ' f'file(s): {missing_names}, and the following file(s) had no ' f'valid samples: {no_samps}') diff --git a/mne/io/ctf/res4.py b/mne/io/ctf/res4.py index be70a54b86d..8da208e18a5 100644 --- a/mne/io/ctf/res4.py +++ b/mne/io/ctf/res4.py @@ -19,7 +19,7 @@ def _make_ctf_name(directory, extra, raise_error=True): found = True if not op.isfile(fname): if raise_error: - raise IOError('Standard file %s not found' % fname) + raise OSError('Standard file %s not found' % fname) found = False return fname, found diff --git a/mne/io/ctf/tests/test_ctf.py b/mne/io/ctf/tests/test_ctf.py index 42de95134d4..1e699007714 100644 --- a/mne/io/ctf/tests/test_ctf.py +++ b/mne/io/ctf/tests/test_ctf.py @@ -416,7 +416,7 @@ def test_missing_res4(tmp_path): tmp_path / ctf_fname_continuous) read_raw_ctf(use_ds) os.remove(use_ds / (ctf_fname_continuous[:-2] + 'meg4')) - with pytest.raises(IOError, match='could not find the following'): + with pytest.raises(OSError, match='could not find the following'): read_raw_ctf(use_ds) diff --git a/mne/io/curry/tests/test_curry.py 
b/mne/io/curry/tests/test_curry.py index de759e14a5d..2f9c8c4d141 100644 --- a/mne/io/curry/tests/test_curry.py +++ b/mne/io/curry/tests/test_curry.py @@ -296,7 +296,7 @@ def test_check_missing_files(): """Test checking for missing curry files (smoke test).""" invalid_fname = "/invalid/path/name.xy" - with pytest.raises(IOError, match="file type .*? must end with"): + with pytest.raises(OSError, match="file type .*? must end with"): _read_events_curry(invalid_fname) with pytest.raises(FileNotFoundError, match='does not exist'): diff --git a/mne/io/eeglab/eeglab.py b/mne/io/eeglab/eeglab.py index f3cc8606c8b..821cef8fae1 100644 --- a/mne/io/eeglab/eeglab.py +++ b/mne/io/eeglab/eeglab.py @@ -42,7 +42,7 @@ def _check_eeglab_fname(fname, dataname): 'Old data format .dat detected. Please update your EEGLAB ' 'version and resave the data in .fdt format') elif fmt != '.fdt': - raise IOError('Expected .fdt file format. Found %s format' % fmt) + raise OSError('Expected .fdt file format. Found %s format' % fmt) basedir = op.dirname(fname) data_fname = op.join(basedir, dataname) diff --git a/mne/io/fiff/raw.py b/mne/io/fiff/raw.py index b7e1a02b631..1a45af0c38e 100644 --- a/mne/io/fiff/raw.py +++ b/mne/io/fiff/raw.py @@ -436,7 +436,7 @@ def _get_fname_rep(fname): def _check_entry(first, nent): """Sanity check entries.""" if first >= nent: - raise IOError('Could not read data, perhaps this is a corrupt file') + raise OSError('Could not read data, perhaps this is a corrupt file') @fill_doc diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 39b8cc61c58..b274602cf23 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -220,7 +220,7 @@ def test_output_formats(tmp_path): for ii, (fmt, tol) in enumerate(zip(formats, tols)): # Let's test the overwriting error throwing while we're at it if ii > 0: - pytest.raises(IOError, raw.save, temp_file, fmt=fmt) + pytest.raises(OSError, raw.save, temp_file, fmt=fmt) raw.save(temp_file, fmt=fmt, overwrite=True) raw2 = read_raw_fif(temp_file) raw2_data = raw2[:, :][0] @@ -1468,7 +1468,7 @@ def test_save(tmp_path): raw.save(temp_fname) raw.load_data() # can't overwrite file without overwrite=True - with pytest.raises(IOError, match='file exists'): + with pytest.raises(OSError, match='file exists'): raw.save(fif_fname) # test abspath support and annotations diff --git a/mne/io/kit/kit.py b/mne/io/kit/kit.py index dd978b38b85..0a433f6203c 100644 --- a/mne/io/kit/kit.py +++ b/mne/io/kit/kit.py @@ -573,7 +573,7 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, elif channel_type == KIT.CHANNEL_NULL: channels.append({'type': channel_type}) else: - raise IOError("Unknown KIT channel type: %i" % channel_type) + raise OSError("Unknown KIT channel type: %i" % channel_type) exg_gains = np.array(exg_gains) # @@ -639,7 +639,7 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, else: sqd['n_samples'] = sqd['frame_length'] * sqd['n_epochs'] else: - raise IOError("Invalid acquisition type: %i. Your file is neither " + raise OSError("Invalid acquisition type: %i. Your file is neither " "continuous nor epoched data." 
% (acq_type,)) # @@ -707,7 +707,7 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, # precompute conversion factor for reading data if unsupported_format: if sysid not in LEGACY_AMP_PARAMS: - raise IOError("Legacy parameters for system ID %i unavailable" % + raise OSError("Legacy parameters for system ID %i unavailable" % (sysid,)) adc_range, adc_stored = LEGACY_AMP_PARAMS[sysid] is_meg = np.array([ch['type'] in KIT.CHANNELS_MEG for ch in channels]) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index 76329ce74ef..586dbdcdd56 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -291,7 +291,7 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, # Test saving with not correct extension out_fname_h5 = op.join(tempdir, 'test_raw.h5') - with pytest.raises(IOError, match='raw must end with .fif or .fif.gz'): + with pytest.raises(OSError, match='raw must end with .fif or .fif.gz'): raw.save(out_fname_h5) raw3 = read_raw_fif(out_fname) diff --git a/mne/io/write.py b/mne/io/write.py index edd240b8570..fb832cec9ca 100644 --- a/mne/io/write.py +++ b/mne/io/write.py @@ -368,7 +368,7 @@ def check_fiff_length(fid, close=True): if fid.tell() > 2147483648: # 2 ** 31, FIFF uses signed 32-bit locations if close: fid.close() - raise IOError('FIFF file exceeded 2GB limit, please split file, reduce' + raise OSError('FIFF file exceeded 2GB limit, please split file, reduce' ' split_size (if possible), or save to a different ' 'format') diff --git a/mne/label.py b/mne/label.py index 69e9af93f52..4ce7ae9b056 100644 --- a/mne/label.py +++ b/mne/label.py @@ -1954,7 +1954,7 @@ def _read_annot_cands(dir_name, raise_error=True): if not op.isdir(dir_name): if not raise_error: return list() - raise IOError('Directory for annotation does not exist: %s', + raise OSError('Directory for annotation does not exist: %s', dir_name) cands = os.listdir(dir_name) cands = sorted(set(c.replace('lh.', '').replace('rh.', '').replace( @@ -1990,10 +1990,10 @@ def _read_annot(fname): dir_name = op.split(fname)[0] cands = _read_annot_cands(dir_name) if len(cands) == 0: - raise IOError('No such file %s, no candidate parcellations ' + raise OSError('No such file %s, no candidate parcellations ' 'found in directory' % fname) else: - raise IOError('No such file %s, candidate parcellations in ' + raise OSError('No such file %s, candidate parcellations in ' 'that directory:\n%s' % (fname, '\n'.join(cands))) with open(fname, "rb") as fid: n_verts = np.fromfile(fid, '>i4', 1)[0] diff --git a/mne/morph.py b/mne/morph.py index aa4185ad45f..a3628d91652 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -198,7 +198,7 @@ def compute_source_morph(src, subject_from=None, subject_to='fsaverage', # let's KISS and use `brain.mgz`, too mri_path_to = op.join(subjects_dir, subject_to, mri_subpath) if not op.isfile(mri_path_to): - raise IOError('cannot read file: %s' % mri_path_to) + raise OSError('cannot read file: %s' % mri_path_to) logger.info(' Loading %s as "to" volume' % mri_path_to) with warnings.catch_warnings(): mri_to = nib.load(mri_path_to) diff --git a/mne/report/report.py b/mne/report/report.py index 574688e63a2..36bc7717161 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -460,7 +460,7 @@ def _iterate_trans_views(function, alpha, **kwargs): return _itv( function, fig, surfaces={'head-dense': alpha}, **kwargs ) - except IOError: + except OSError: return _itv(function, fig, surfaces={'head': alpha}, **kwargs) finally: backend._close_3d_figure(fig) diff --git 
a/mne/source_estimate.py b/mne/source_estimate.py index 75ba2ac54f9..1e3ed65bee2 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -261,7 +261,7 @@ def read_source_estimate(fname, subject=None): err = ("Invalid .stc filename: %r; needs to end with " "hemisphere tag ('...-lh.stc' or '...-rh.stc')" % fname) - raise IOError(err) + raise OSError(err) elif fname.endswith('.w'): ftype = 'w' if fname.endswith(('-lh.w', '-rh.w')): @@ -270,7 +270,7 @@ def read_source_estimate(fname, subject=None): err = ("Invalid .w filename: %r; needs to end with " "hemisphere tag ('...-lh.w' or '...-rh.w')" % fname) - raise IOError(err) + raise OSError(err) elif fname.endswith('.h5'): ftype = 'h5' fname = fname[:-3] @@ -292,9 +292,9 @@ def read_source_estimate(fname, subject=None): ftype = 'h5' fname += '-stc' elif any(stc_exist) or any(w_exist): - raise IOError("Hemisphere missing for %r" % fname_arg) + raise OSError("Hemisphere missing for %r" % fname_arg) else: - raise IOError("SourceEstimate File(s) not found for: %r" + raise OSError("SourceEstimate File(s) not found for: %r" % fname_arg) # read the files @@ -307,7 +307,7 @@ def read_source_estimate(fname, subject=None): kwargs['tmin'] = 0.0 kwargs['tstep'] = 0.0 else: - raise IOError('Volume source estimate must end with .stc or .w') + raise OSError('Volume source estimate must end with .stc or .w') kwargs['vertices'] = [kwargs['vertices']] elif ftype == 'surface': # stc file with surface source spaces lh = _read_stc(fname + '-lh.stc') diff --git a/mne/source_space.py b/mne/source_space.py index 40a7d98bdbb..37598d8c42a 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -1410,7 +1410,7 @@ def setup_source_space(subject, spacing='oct6', surface='white', ] for surf, hemi in zip(surfs, ['LH', 'RH']): if surf is not None and not op.isfile(surf): - raise IOError('Could not find the %s surface %s' + raise OSError('Could not find the %s surface %s' % (hemi, surf)) logger.info('Setting up the source space with the following parameters:\n') @@ -1675,7 +1675,7 @@ def setup_volume_source_space(subject=None, pos=5.0, mri=None, surf_extra = 'dict()' else: if not op.isfile(surface): - raise IOError('surface file "%s" not found' % surface) + raise OSError('surface file "%s" not found' % surface) surf_extra = surface logger.info('Boundary surface file : %s', surf_extra) else: @@ -2381,7 +2381,7 @@ def _ensure_src(src, kind=None, extra='', verbose=None): if _path_like(src): src = str(src) if not op.isfile(src): - raise IOError('Source space file "%s" not found' % src) + raise OSError('Source space file "%s" not found' % src) logger.info('Reading %s...' % src) src = read_source_spaces(src, verbose=False) if not isinstance(src, SourceSpaces): diff --git a/mne/surface.py b/mne/surface.py index ea106eae7f0..8d6ad3ba436 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -95,7 +95,7 @@ def _get_head_surface(subject, source, subjects_dir, on_defects, # let's do a more sophisticated search path = op.join(subjects_dir, subject, 'bem') if not op.isdir(path): - raise IOError('Subject bem directory "%s" does not exist.' + raise OSError('Subject bem directory "%s" does not exist.' % path) files = sorted(glob(op.join(path, '%s*%s.fif' % (subject, this_source)))) @@ -114,7 +114,7 @@ def _get_head_surface(subject, source, subjects_dir, on_defects, if surf is None: if raise_error: - raise IOError('No file matching "%s*%s" and containing a head ' + raise OSError('No file matching "%s*%s" and containing a head ' 'surface found.' 
% (subject, this_source)) else: return surf @@ -1621,7 +1621,7 @@ def read_tri(fname_in, swap=False, verbose=None): elif n_items in [4, 7]: inds = range(1, 4) else: - raise IOError('Unrecognized format of data.') + raise OSError('Unrecognized format of data.') rr = np.array([np.array([float(v) for v in line.split()])[inds] for line in lines[1:n_nodes + 1]]) tris = np.array([np.array([int(v) for v in line.split()])[inds] diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index c8d6175a427..09f645204dd 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -66,7 +66,7 @@ def test_basics(): raw = read_raw_fif(fif_fname) assert raw.annotations is not None assert len(raw.annotations.onset) == 0 - pytest.raises(IOError, read_annotations, fif_fname) + pytest.raises(OSError, read_annotations, fif_fname) onset = np.array(range(10)) duration = np.ones(10) description = np.repeat('test', 10) @@ -234,7 +234,7 @@ def test_crop(tmp_path): assert_array_equal(annot_read.description, raw.annotations.description) annot = Annotations((), (), ()) annot.save(fname, overwrite=True) - pytest.raises(IOError, read_annotations, fif_fname) # none in old raw + pytest.raises(OSError, read_annotations, fif_fname) # none in old raw annot = read_annotations(fname) assert isinstance(annot, Annotations) assert len(annot) == 0 diff --git a/mne/tests/test_bem.py b/mne/tests/test_bem.py index f42e48f6afd..7cfe18d1bc3 100644 --- a/mne/tests/test_bem.py +++ b/mne/tests/test_bem.py @@ -88,7 +88,7 @@ def test_io_bem(tmp_path, ext): surf = read_bem_surfaces(fname_bem_3, patch_stats=True) surf = read_bem_surfaces(fname_bem_3, patch_stats=False) write_bem_surfaces(temp_bem, surf[0]) - with pytest.raises(IOError, match='exists'): + with pytest.raises(OSError, match='exists'): write_bem_surfaces(temp_bem, surf[0]) write_bem_surfaces(temp_bem, surf[0], overwrite=True) if ext == 'h5': diff --git a/mne/tests/test_chpi.py b/mne/tests/test_chpi.py index 558bc098fe5..e70c8c11d19 100644 --- a/mne/tests/test_chpi.py +++ b/mne/tests/test_chpi.py @@ -120,7 +120,7 @@ def test_read_write_head_pos(tmp_path): pytest.raises(ValueError, write_head_pos, temp_name, 'foo') # not array pytest.raises(ValueError, write_head_pos, temp_name, head_pos_read[:, :9]) pytest.raises(TypeError, read_head_pos, 0) - pytest.raises(IOError, read_head_pos, "101") + pytest.raises(OSError, read_head_pos, "101") @testing.requires_testing_data diff --git a/mne/tests/test_coreg.py b/mne/tests/test_coreg.py index be99c6172e1..c8f0ca72c62 100644 --- a/mne/tests/test_coreg.py +++ b/mne/tests/test_coreg.py @@ -211,7 +211,7 @@ def test_scale_mri_xfm(tmp_path, few_surfaces, subjects_dir_tmp_few): if subject_from == 'fsaverage': overwrite = skip_fiducials = False else: - with pytest.raises(IOError, match='No fiducials file'): + with pytest.raises(OSError, match='No fiducials file'): scale_mri( subject_from, subject_to, @@ -219,7 +219,7 @@ def test_scale_mri_xfm(tmp_path, few_surfaces, subjects_dir_tmp_few): subjects_dir=subjects_dir_tmp_few, ) skip_fiducials = True - with pytest.raises(IOError, match='already exists'): + with pytest.raises(OSError, match='already exists'): scale_mri( subject_from, subject_to, scale, diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index cbbbe435684..52445dc64ec 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -3250,7 +3250,7 @@ def test_save_overwrite(tmp_path): # scenario 2: overwrite=False and there is a file to overwrite # fname1 exists because of scenario 
1 above - with pytest.raises(IOError, match='Destination file exists.'): + with pytest.raises(OSError, match='Destination file exists.'): epochs.save(fname1, overwrite=False) # scenario 3: overwrite=True and there isn't a file to overwrite @@ -3260,7 +3260,7 @@ def test_save_overwrite(tmp_path): epochs.save(fname2, overwrite=True) # check that the file got written assert fname2.is_file() - with pytest.raises(IOError, match='exists'): + with pytest.raises(OSError, match='exists'): epochs.save(fname2) # scenario 4: overwrite=True and there is a file to overwrite diff --git a/mne/tests/test_label.py b/mne/tests/test_label.py index 0059ef3b80f..92f90c69504 100644 --- a/mne/tests/test_label.py +++ b/mne/tests/test_label.py @@ -359,7 +359,7 @@ def test_annot_io(tmp_path): shutil.copy(surf_src / "rh.white", surf_dir) # read original labels - with pytest.raises(IOError, match='\nPALS_B12_Lobes$'): + with pytest.raises(OSError, match='\nPALS_B12_Lobes$'): read_labels_from_annot(subject, 'PALS_B12_Lobesey', subjects_dir=tmp_path) labels = read_labels_from_annot(subject, 'PALS_B12_Lobes', @@ -455,9 +455,9 @@ def test_read_labels_from_annot(tmp_path): subjects_dir=subjects_dir) pytest.raises(ValueError, read_labels_from_annot, 'sample', annot_fname='bla.annot', subjects_dir=subjects_dir) - with pytest.raises(IOError, match='does not exist'): + with pytest.raises(OSError, match='does not exist'): _read_annot_cands('foo') - with pytest.raises(IOError, match='no candidate'): + with pytest.raises(OSError, match='no candidate'): _read_annot(str(tmp_path)) # read labels using hemi specification @@ -977,7 +977,7 @@ def test_label_center_of_mass(): restrict_vertices='foo') pytest.raises(TypeError, label.center_of_mass, subjects_dir=subjects_dir, surf=1) - pytest.raises(IOError, label.center_of_mass, subjects_dir=subjects_dir, + pytest.raises(OSError, label.center_of_mass, subjects_dir=subjects_dir, surf='foo') diff --git a/mne/tests/test_morph.py b/mne/tests/test_morph.py index 9e1bc13735d..00f1bd7d1fb 100644 --- a/mne/tests/test_morph.py +++ b/mne/tests/test_morph.py @@ -353,7 +353,7 @@ def test_volume_source_morph_basic(tmp_path): **kwargs) # check wrong subject_to - with pytest.raises(IOError, match='cannot read file'): + with pytest.raises(OSError, match='cannot read file'): compute_source_morph(fwd['src'], 'sample', '42', subjects_dir=subjects_dir) @@ -364,7 +364,7 @@ def test_volume_source_morph_basic(tmp_path): source_morph_vol_r = read_source_morph(tmp_path / 'vol-morph.h5') # check for invalid file name handling () - with pytest.raises(IOError, match='not found'): + with pytest.raises(OSError, match='not found'): read_source_morph(tmp_path / '42') # check morph diff --git a/mne/tests/test_source_space.py b/mne/tests/test_source_space.py index 96d7da009d7..030474fbf51 100644 --- a/mne/tests/test_source_space.py +++ b/mne/tests/test_source_space.py @@ -333,7 +333,7 @@ def test_volume_source_space(tmp_path): del src_new src_new = read_source_spaces(temp_name) _compare_source_spaces(src, src_new, mode='approx') - with pytest.raises(IOError, match='surface file.*not exist'): + with pytest.raises(OSError, match='surface file.*not exist'): setup_volume_source_space( 'sample', surface='foo', mri=fname_mri, subjects_dir=subjects_dir) bem['surfs'][-1]['coord_frame'] = FIFF.FIFFV_COORD_HEAD diff --git a/mne/tests/test_transforms.py b/mne/tests/test_transforms.py index 05a427305fd..c01894a51e0 100644 --- a/mne/tests/test_transforms.py +++ b/mne/tests/test_transforms.py @@ -87,7 +87,7 @@ def 
test_io_trans(tmp_path): assert trans0 == trans1 # check reading non -trans.fif files - pytest.raises(IOError, read_trans, fname_eve) + pytest.raises(OSError, read_trans, fname_eve) # check warning on bad filenames fname2 = tmp_path / 'trans-test-bad-name.fif' diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index aa287079950..ef2d6c587ee 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -493,7 +493,7 @@ def test_io(tmp_path): assert_equal(tfr.comment, tfr2.comment) assert_equal(tfr.nave, tfr2.nave) - pytest.raises(IOError, tfr.save, fname) + pytest.raises(OSError, tfr.save, fname) tfr.comment = None # test old meas_date diff --git a/mne/transforms.py b/mne/transforms.py index d6c4c63ce5f..0874b0d7854 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -462,7 +462,7 @@ def _get_trans(trans, fro='mri', to='head', allow_none=True): ) trans = Path(trans) if not trans.is_file(): - raise IOError(f'trans file "{trans}" not found') + raise OSError(f'trans file "{trans}" not found') if trans.suffix in ['.fif', '.gz']: fro_to_t = read_trans(trans) else: @@ -561,7 +561,7 @@ def read_trans(fname, return_all=False, verbose=None): if not return_all: break if len(trans) == 0: - raise IOError('This does not seem to be a -trans.fif file.') + raise OSError('This does not seem to be a -trans.fif file.') return trans if return_all else trans[0] diff --git a/mne/utils/check.py b/mne/utils/check.py index 52f096450c7..ffe0651dc81 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -56,7 +56,7 @@ def check_fname(fname, filetype, endings, endings_err=()): if len(endings_err) > 0 and not fname.endswith(endings_err): print_endings = ' or '.join([', '.join(endings_err[:-1]), endings_err[-1]]) - raise IOError('The filename (%s) for file type %s must end with %s' + raise OSError('The filename (%s) for file type %s must end with %s' % (fname, filetype, print_endings)) print_endings = ' or '.join([', '.join(endings[:-1]), endings[-1]]) if not fname.endswith(endings): @@ -234,13 +234,13 @@ def _check_fname( if must_exist: if need_dir: if not fname.is_dir(): - raise IOError( + raise OSError( f"Need a directory for {name} but found a file " f"at {fname}" ) else: if not fname.is_file(): - raise IOError( + raise OSError( f"Need a file for {name} but found a directory " f"at {fname}" ) diff --git a/mne/utils/config.py b/mne/utils/config.py index 099f8333b94..1ee8417a1f3 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -43,7 +43,7 @@ def set_cache_dir(cache_dir): temporary file storage. 
""" if cache_dir is not None and not op.exists(cache_dir): - raise IOError('Directory %s does not exist' % cache_dir) + raise OSError('Directory %s does not exist' % cache_dir) set_config('MNE_CACHE_DIR', cache_dir, set_env=False) diff --git a/mne/utils/tests/test_check.py b/mne/utils/tests/test_check.py index 3ddb7f8495f..8f28ee7799a 100644 --- a/mne/utils/tests/test_check.py +++ b/mne/utils/tests/test_check.py @@ -48,7 +48,7 @@ def test_check(tmp_path): os.chmod(fname, orig_perms) os.remove(fname) assert not fname.is_file() - pytest.raises(IOError, check_fname, 'foo', 'tets-dip.x', (), ('.fif',)) + pytest.raises(OSError, check_fname, 'foo', 'tets-dip.x', (), ('.fif',)) pytest.raises(ValueError, _check_subject, None, None) pytest.raises(TypeError, _check_subject, None, 1) pytest.raises(TypeError, _check_subject, 1, None) diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index 1b3ec6c1270..c9521690fa0 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -329,7 +329,7 @@ def _qt_get_stylesheet(theme): else: try: file = open(theme, 'r') - except IOError: + except OSError: warn('Requested theme file not found, will use light instead: ' f'{repr(theme)}') else: diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 266d575d972..36482a2a9b0 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -570,7 +570,7 @@ def plot_bem(subject, subjects_dir=None, orientation='coronal', bem_path = subjects_dir / subject / "bem" if not bem_path.is_dir(): - raise IOError(f'Subject bem directory "{bem_path}" does not exist') + raise OSError(f'Subject bem directory "{bem_path}" does not exist') surfaces = _get_bem_plotting_surfaces(bem_path) if brain_surfaces is not None: @@ -584,7 +584,7 @@ def plot_bem(subject, subjects_dir=None, orientation='coronal', if surf_fname.exists(): surfaces.append((surf_fname, '#00DD00')) else: - raise IOError("Surface %s does not exist." % surf_fname) + raise OSError("Surface %s does not exist." % surf_fname) if isinstance(src, (str, Path, os.PathLike)): src = Path(src) @@ -592,7 +592,7 @@ def plot_bem(subject, subjects_dir=None, orientation='coronal', # convert to Path until get_subjects_dir returns a Path object src_ = Path(subjects_dir) / subject / "bem" / src if not src_.exists(): - raise IOError(f"{src} does not exist") + raise OSError(f"{src} does not exist") src = src_ src = read_source_spaces(src) elif src is not None and not isinstance(src, SourceSpaces): @@ -602,7 +602,7 @@ def plot_bem(subject, subjects_dir=None, orientation='coronal', ) if len(surfaces) == 0: - raise IOError('No surface files found. Surface files must end with ' + raise OSError('No surface files found. 
Surface files must end with ' 'inner_skull.surf, outer_skull.surf or outer_skin.surf') # Plot the contours diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py index 2ea3b6f8e99..156e3153277 100644 --- a/mne/viz/tests/test_misc.py +++ b/mne/viz/tests/test_misc.py @@ -136,7 +136,7 @@ def test_plot_cov(): def test_plot_bem(): """Test plotting of BEM contours.""" pytest.importorskip('nibabel') - with pytest.raises(IOError, match='MRI file .* not found'): + with pytest.raises(OSError, match='MRI file .* not found'): plot_bem(subject='bad-subject', subjects_dir=subjects_dir) with pytest.raises(ValueError, match="Invalid value for the 'orientation"): plot_bem(subject='sample', subjects_dir=subjects_dir, From bde717c01d9d4eaa21233685b9ea4017b54a5a0c Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Mon, 27 Mar 2023 18:27:35 +0200 Subject: [PATCH 0005/1125] MAINT: object is an implicit base for all classes (#11601) --- mne/_ola.py | 4 ++-- mne/annotations.py | 2 +- mne/channels/channels.py | 4 ++-- mne/channels/layout.py | 2 +- mne/channels/montage.py | 2 +- mne/coreg.py | 2 +- mne/decoding/mixin.py | 4 ++-- mne/epochs.py | 2 +- mne/event.py | 2 +- mne/filter.py | 2 +- mne/fixes.py | 2 +- mne/gui/__init__.py | 2 +- mne/gui/tests/test_coreg.py | 2 +- mne/inverse_sparse/mxne_optim.py | 4 ++-- mne/io/base.py | 4 ++-- mne/io/meas_info.py | 4 ++-- mne/io/open.py | 2 +- mne/io/proj.py | 2 +- mne/io/tag.py | 2 +- mne/label.py | 2 +- mne/report/report.py | 2 +- mne/simulation/raw.py | 4 ++-- mne/simulation/source.py | 2 +- mne/surface.py | 6 +++--- mne/time_frequency/csd.py | 2 +- mne/transforms.py | 4 ++-- mne/utils/_bunch.py | 2 +- mne/utils/_logging.py | 4 ++-- mne/utils/_testing.py | 4 ++-- mne/utils/check.py | 4 ++-- mne/utils/mixin.py | 8 ++++---- mne/utils/numerics.py | 4 ++-- mne/utils/progressbar.py | 4 ++-- mne/viz/_3d_overlay.py | 4 ++-- mne/viz/_brain/_brain.py | 2 +- mne/viz/_brain/_linkviewer.py | 2 +- mne/viz/_brain/_scraper.py | 2 +- mne/viz/_brain/callback.py | 10 +++++----- mne/viz/_brain/surface.py | 2 +- mne/viz/_brain/tests/test_brain.py | 4 ++-- mne/viz/backends/_notebook.py | 4 ++-- mne/viz/backends/_pyvista.py | 2 +- mne/viz/topomap.py | 2 +- mne/viz/utils.py | 6 +++--- 44 files changed, 71 insertions(+), 71 deletions(-) diff --git a/mne/_ola.py b/mne/_ola.py index d1ab34c235f..68c24a79278 100644 --- a/mne/_ola.py +++ b/mne/_ola.py @@ -11,7 +11,7 @@ ############################################################################### # Class for interpolation between adjacent points -class _Interp2(object): +class _Interp2: r"""Interpolate between two points. Parameters @@ -416,7 +416,7 @@ def _check_cola(win, nperseg, step, window_name, tol=1e-10): return const -class _Storer(object): +class _Storer: """Store data in chunks.""" def __init__(self, *outs, picks=None): diff --git a/mne/annotations.py b/mne/annotations.py index 1d6ee2c5768..848a4ba4dc7 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -87,7 +87,7 @@ def _ndarray_ch_names(ch_names): @fill_doc -class Annotations(object): +class Annotations: """Annotation object for annotating segments of raw data. .. 
note:: diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 12e797a5c10..7b2afb2ffbb 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -586,7 +586,7 @@ def set_meas_date(self, meas_date): return self -class UpdateChannelsMixin(object): +class UpdateChannelsMixin: """Mixin class for Raw, Evoked, Epochs, Spectrum, AverageTFR.""" @verbose @@ -980,7 +980,7 @@ def add_reference_channels(self, ref_channels): return add_reference_channels(self, ref_channels, copy=False) -class InterpolationMixin(object): +class InterpolationMixin: """Mixin class for Raw, Evoked, Epochs.""" @verbose diff --git a/mne/channels/layout.py b/mne/channels/layout.py index 87149f458a4..e59bb80a2a1 100644 --- a/mne/channels/layout.py +++ b/mne/channels/layout.py @@ -25,7 +25,7 @@ from .channels import _get_ch_info -class Layout(object): +class Layout: """Sensor layouts. Layouts are typically loaded from a file using diff --git a/mne/channels/montage.py b/mne/channels/montage.py index 3bc153a0d14..e28ded3f3d7 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -276,7 +276,7 @@ def make_dig_montage(ch_pos=None, nasion=None, lpa=None, rpa=None, return DigMontage(dig=dig, ch_names=ch_names) -class DigMontage(object): +class DigMontage: """Montage for digitized electrode and headshape position data. .. warning:: Montages are typically created using one of the helper diff --git a/mne/coreg.py b/mne/coreg.py index 52616d3c113..3ea3576e1d1 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -1278,7 +1278,7 @@ def _read_surface(filename, *, on_defects): @fill_doc -class Coregistration(object): +class Coregistration: """Class for MRI<->head coregistration. Parameters diff --git a/mne/decoding/mixin.py b/mne/decoding/mixin.py index b2c491b9118..d38e9e4aff4 100644 --- a/mne/decoding/mixin.py +++ b/mne/decoding/mixin.py @@ -1,6 +1,6 @@ -class TransformerMixin(object): +class TransformerMixin: """Mixin class for all transformers in scikit-learn.""" def fit_transform(self, X, y=None, **fit_params): @@ -33,7 +33,7 @@ def fit_transform(self, X, y=None, **fit_params): return self.fit(X, y, **fit_params).transform(X) -class EstimatorMixin(object): +class EstimatorMixin: """Mixin class for estimators.""" def get_params(self, deep=True): diff --git a/mne/epochs.py b/mne/epochs.py index b2448bd95ba..319bb4ecdb4 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3211,7 +3211,7 @@ def read_epochs(fname, proj=True, preload=True, verbose=None): return EpochsFIF(fname, proj, preload, verbose) -class _RawContainer(object): +class _RawContainer: """Helper for a raw data container.""" def __init__(self, fid, data_tag, event_samps, epoch_shape, diff --git a/mne/event.py b/mne/event.py index 104897014d9..1478b4ae105 100644 --- a/mne/event.py +++ b/mne/event.py @@ -954,7 +954,7 @@ def concatenate_events(events, first_samps, last_samps): @fill_doc -class AcqParserFIF(object): +class AcqParserFIF: """Parser for Elekta data acquisition settings. This class parses parameters (e.g. 
events and averaging categories) that diff --git a/mne/filter.py b/mne/filter.py index 5c934ae52ec..5a3a25b5bec 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -1880,7 +1880,7 @@ def float_array(c): fir_window, fir_design) -class FilterMixin(object): +class FilterMixin: """Object for Epoch/Evoked filtering.""" @verbose diff --git a/mne/fixes.py b/mne/fixes.py index 72d7dd1fbb1..5ff7c07a66f 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -214,7 +214,7 @@ def is_regressor(estimator): } -class BaseEstimator(object): +class BaseEstimator: """Base class for all estimators in scikit-learn. Notes diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py index f92817bf637..c86b413b634 100644 --- a/mne/gui/__init__.py +++ b/mne/gui/__init__.py @@ -249,7 +249,7 @@ def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None, return gui -class _GUIScraper(object): +class _GUIScraper: """Scrape GUI outputs.""" def __repr__(self): diff --git a/mne/gui/tests/test_coreg.py b/mne/gui/tests/test_coreg.py index 44374d3c048..5c705c6dcb5 100644 --- a/mne/gui/tests/test_coreg.py +++ b/mne/gui/tests/test_coreg.py @@ -55,7 +55,7 @@ pytest.importorskip('nibabel') -class TstVTKPicker(object): +class TstVTKPicker: """Class to test cell picking.""" def __init__(self, mesh, cell_id, event_pos): diff --git a/mne/inverse_sparse/mxne_optim.py b/mne/inverse_sparse/mxne_optim.py index 587f1744400..bff7a909781 100644 --- a/mne/inverse_sparse/mxne_optim.py +++ b/mne/inverse_sparse/mxne_optim.py @@ -662,7 +662,7 @@ def safe_max_abs_diff(A, ia, B, ib): return np.max(np.abs(A - B)) -class _Phi(object): +class _Phi: """Have phi stft as callable w/o using a lambda that does not pickle.""" def __init__(self, wsize, tstep, n_coefs, n_times): # noqa: D102 @@ -705,7 +705,7 @@ def norm(self, z, ord=2): return norm -class _PhiT(object): +class _PhiT: """Have phi.T istft as callable w/o using a lambda that does not pickle.""" def __init__(self, tstep, n_freqs, n_steps, n_times): # noqa: D102 diff --git a/mne/io/base.py b/mne/io/base.py index b97290a1a7e..0dcc1b9eee4 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -2146,7 +2146,7 @@ def _get_scaling(ch_type, target_unit): return scaling -class _ReadSegmentFileProtector(object): +class _ReadSegmentFileProtector: """Ensure only _filenames, _raw_extras, and _read_segment_file are used.""" def __init__(self, raw): @@ -2160,7 +2160,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): self, data, idx, fi, start, stop, cals, mult) -class _RawShell(object): +class _RawShell: """Create a temporary raw object.""" def __init__(self): # noqa: D102 diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index 20b795ab478..493542cce4d 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -147,7 +147,7 @@ def _unique_channel_names(ch_names, max_length=None, verbose=None): return ch_names -class MontageMixin(object): +class MontageMixin: """Mixin for Montage getting and setting.""" @fill_doc @@ -240,7 +240,7 @@ def set_montage(self, montage, match_case=True, match_alias=False, return self -class ContainsMixin(object): +class ContainsMixin: """Mixin class for Raw, Evoked, Epochs and Info.""" def __contains__(self, ch_type): diff --git a/mne/io/open.py b/mne/io/open.py index 7680c1344b5..d9aa7f3d53f 100644 --- a/mne/io/open.py +++ b/mne/io/open.py @@ -16,7 +16,7 @@ from ..utils import logger, verbose, _file_like, warn -class _NoCloseRead(object): +class _NoCloseRead: """Create a wrapper that will not close when used as a context manager.""" def __init__(self, 
fid): diff --git a/mne/io/proj.py b/mne/io/proj.py index 12db504f372..7774a8edb02 100644 --- a/mne/io/proj.py +++ b/mne/io/proj.py @@ -141,7 +141,7 @@ def plot_topomap( show=show) -class ProjMixin(object): +class ProjMixin: """Mixin class for Raw, Evoked, Epochs. Notes diff --git a/mne/io/tag.py b/mne/io/tag.py index 69504a5e49a..6d4b5df2ee4 100644 --- a/mne/io/tag.py +++ b/mne/io/tag.py @@ -17,7 +17,7 @@ ############################################################################## # HELPERS -class Tag(object): +class Tag: """Tag in FIF tree structure. Parameters diff --git a/mne/label.py b/mne/label.py index 4ce7ae9b056..0a1fa9710a0 100644 --- a/mne/label.py +++ b/mne/label.py @@ -872,7 +872,7 @@ def _get_label_src(label, src): return hemi_src -class BiHemiLabel(object): +class BiHemiLabel: """A freesurfer/MNE label with vertices in both hemispheres. Parameters diff --git a/mne/report/report.py b/mne/report/report.py index 36bc7717161..ae62ec10d87 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -3877,7 +3877,7 @@ def _recursive_search(path, pattern): _FA_FILE_CODE = '' # noqa: E501 -class _ReportScraper(object): +class _ReportScraper: """Scrape Report outputs. Only works properly if conf.py is configured properly and the file diff --git a/mne/simulation/raw.py b/mne/simulation/raw.py index 19875910d38..1d2b8bcce0d 100644 --- a/mne/simulation/raw.py +++ b/mne/simulation/raw.py @@ -586,7 +586,7 @@ def add_chpi(raw, head_pos=None, interp='cos2', n_jobs=None, verbose=None): return raw -class _HPIForwards(object): +class _HPIForwards: def __init__(self, offsets, dev_head_ts, megcoils, hpi_rrs, hpi_nns): self.offsets = offsets @@ -656,7 +656,7 @@ def _stc_data_event(stc_counted, head_idx, sfreq, src=None, verts=None): return stc_data, stim_data, verts_ -class _SimForwards(object): +class _SimForwards: def __init__(self, dev_head_ts, offsets, info, trans, src, bem, mindist, n_jobs, meeg_picks, forward=None, use_cps=True): diff --git a/mne/simulation/source.py b/mne/simulation/source.py index df7dc736a19..425985ebf10 100644 --- a/mne/simulation/source.py +++ b/mne/simulation/source.py @@ -316,7 +316,7 @@ def simulate_stc(src, labels, stc_data, tmin, tstep, value_fun=None, return stc -class SourceSimulator(object): +class SourceSimulator: """Class to generate simulated Source Estimates. Parameters diff --git a/mne/surface.py b/mne/surface.py index 8d6ad3ba436..8a448f69feb 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -456,7 +456,7 @@ def _normalize_vectors(rr): return size -class _CDist(object): +class _CDist: """Wrapper for cdist that uses a Tree-like pattern.""" def __init__(self, xhs): @@ -513,7 +513,7 @@ def _safe_query(rr, func, reduce=False, **kwargs): return out -class _DistanceQuery(object): +class _DistanceQuery: """Wrapper for fast distance queries.""" def __init__(self, xhs, method='BallTree', allow_kdtree=False): @@ -596,7 +596,7 @@ def _polydata_to_surface(pd, normals=True): return out -class _CheckInside(object): +class _CheckInside: """Efficiently check if points are inside a surface.""" @verbose diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index b7cc09c5e5e..d8680df0047 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -65,7 +65,7 @@ def pick_channels_csd(csd, include=[], exclude=[], ordered=False, copy=True): return csd -class CrossSpectralDensity(object): +class CrossSpectralDensity: """Cross-spectral density. 
Given a list of time series, the CSD matrix denotes for each pair of time diff --git a/mne/transforms.py b/mne/transforms.py index 0874b0d7854..357c299a04b 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -928,7 +928,7 @@ def _compute_sph_harm(order, az, pol): # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -class _TPSWarp(object): +class _TPSWarp: """Transform points using thin-plate spline (TPS) warping. Notes @@ -1005,7 +1005,7 @@ def _tps(distsq): ############################################################################### # Spherical harmonic approximation + TPS warp -class _SphericalSurfaceWarp(object): +class _SphericalSurfaceWarp: """Warp surfaces via spherical harmonic smoothing and thin-plate splines. Notes diff --git a/mne/utils/_bunch.py b/mne/utils/_bunch.py index 103a9574461..13c6c6f1e02 100644 --- a/mne/utils/_bunch.py +++ b/mne/utils/_bunch.py @@ -52,7 +52,7 @@ def __setattr__(self, attr, val): # noqa: D105 super().__setattr__(attr, val) -class _Named(object): +class _Named: """Provide shared methods for giving named-representation subclasses.""" def __new__(cls, name, val): # noqa: D102,D105 diff --git a/mne/utils/_logging.py b/mne/utils/_logging.py index 80d07e0e285..33f33d72c8b 100644 --- a/mne/utils/_logging.py +++ b/mne/utils/_logging.py @@ -295,7 +295,7 @@ def getvalue(self, close=True): return out -class catch_logging(object): +class catch_logging: """Store logging. This will remove all other logging handlers, and return the handler to @@ -334,7 +334,7 @@ def _record_warnings(): yield w -class WrapStdOut(object): +class WrapStdOut: """Dynamically wrap to sys.stdout. This makes packages that monkey-patch sys.stdout (e.g.doctest, diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index acd6c1764f6..1417b9c5c13 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -158,7 +158,7 @@ def run_command_if_main(): local_vars['run']() -class ArgvSetter(object): +class ArgvSetter: """Temporarily set sys.argv.""" def __init__(self, args=(), disable_stdout=True, @@ -182,7 +182,7 @@ def __exit__(self, *args): # noqa: D105 sys.stderr = self.orig_stderr -class SilenceStdout(object): +class SilenceStdout: """Silence stdout.""" def __init__(self, close=True): diff --git a/mne/utils/check.py b/mne/utils/check.py index ffe0651dc81..4ed7ed38566 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -469,7 +469,7 @@ def _is_numeric(n): return isinstance(n, numbers.Number) -class _IntLike(object): +class _IntLike: @classmethod def __instancecheck__(cls, other): try: @@ -484,7 +484,7 @@ def __instancecheck__(cls, other): path_like = (str, Path, os.PathLike) -class _Callable(object): +class _Callable: @classmethod def __instancecheck__(cls, other): return callable(other) diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index 8eb14085d38..4828129b64e 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -20,7 +20,7 @@ logger.propagate = False # don't propagate (in case of multiple imports) -class SizeMixin(object): +class SizeMixin: """Estimate MNE object sizes.""" def __eq__(self, other): @@ -72,7 +72,7 @@ def __hash__(self): raise RuntimeError('Hashing unknown object type: %s' % type(self)) -class GetEpochsMixin(object): +class GetEpochsMixin: """Class to add epoch selection and metadata to certain classes.""" def __getitem__(self, item): @@ -448,7 +448,7 @@ def _check_decim(info, decim, offset, check_filter=True): return decim, offset, new_sfreq -class TimeMixin(object): +class 
TimeMixin: """Class to handle operations on time for MNE objects.""" @property @@ -709,7 +709,7 @@ def _prepare_read_metadata(metadata): return metadata -class _FakeNoPandas(object): # noqa: D101 +class _FakeNoPandas: # noqa: D101 def __enter__(self): # noqa: D105 def _check(strict=True): diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index a702eba6b37..a6c4a7fa734 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -833,7 +833,7 @@ def object_diff(a, b, pre='', *, allclose=False): return out -class _PCA(object): +class _PCA: """Principal component analysis (PCA).""" # Adapted from sklearn and stripped down to just use linalg.svd @@ -1055,7 +1055,7 @@ def _stamp_to_dt(utc_stamp): timedelta(seconds=stamp[0], microseconds=stamp[1])) -class _ReuseCycle(object): +class _ReuseCycle: """Cycle over a variable, preferring to reuse earlier indices. Requires the values in ``x`` to be hashable and unique. This holds diff --git a/mne/utils/progressbar.py b/mne/utils/progressbar.py index dd6c73991be..20a14e3b169 100644 --- a/mne/utils/progressbar.py +++ b/mne/utils/progressbar.py @@ -19,7 +19,7 @@ from ._logging import logger -class ProgressBar(object): +class ProgressBar: """Generate a command-line progressbar. Parameters @@ -181,7 +181,7 @@ def run(self): time.sleep(1. / 30.) # 30 Hz refresh is plenty -class _PBSubsetUpdater(object): +class _PBSubsetUpdater: def __init__(self, pb, idx): self.mmap = pb._mmap diff --git a/mne/viz/_3d_overlay.py b/mne/viz/_3d_overlay.py index 819d9a6a30b..12c36c4ec73 100644 --- a/mne/viz/_3d_overlay.py +++ b/mne/viz/_3d_overlay.py @@ -13,7 +13,7 @@ from ..utils import logger -class _Overlay(object): +class _Overlay: def __init__(self, scalars, colormap, rng, opacity, name): self._scalars = scalars self._colormap = colormap @@ -54,7 +54,7 @@ def _norm(self, rng): return (self._scalars - rng[0]) / factor -class _LayeredMesh(object): +class _LayeredMesh: def __init__(self, renderer, vertices, triangles, normals): self._renderer = renderer self._vertices = vertices diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index f72c9364b42..78d28470fc2 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -51,7 +51,7 @@ @fill_doc -class Brain(object): +class Brain: """Class for visualizing a brain. .. 
warning:: diff --git a/mne/viz/_brain/_linkviewer.py b/mne/viz/_brain/_linkviewer.py index a2d7c34b51d..3be8b118384 100644 --- a/mne/viz/_brain/_linkviewer.py +++ b/mne/viz/_brain/_linkviewer.py @@ -7,7 +7,7 @@ from ...utils import warn -class _LinkViewer(object): +class _LinkViewer: """Class to link multiple Brain objects.""" def __init__(self, brains, time=True, camera=False, colorbar=True, diff --git a/mne/viz/_brain/_scraper.py b/mne/viz/_brain/_scraper.py index feb0b726631..08defc1b894 100644 --- a/mne/viz/_brain/_scraper.py +++ b/mne/viz/_brain/_scraper.py @@ -4,7 +4,7 @@ from ._brain import Brain -class _BrainScraper(object): +class _BrainScraper: """Scrape Brain objects.""" def __repr__(self): diff --git a/mne/viz/_brain/callback.py b/mne/viz/_brain/callback.py index 3e180e91677..831569e4aaf 100644 --- a/mne/viz/_brain/callback.py +++ b/mne/viz/_brain/callback.py @@ -9,7 +9,7 @@ from ...utils import logger -class TimeCallBack(object): +class TimeCallBack: """Callback to update the time.""" def __init__(self, brain=None, callback=None): @@ -36,7 +36,7 @@ def __call__(self, value, update_widget=False, time_as_index=True): self.widget.set_value(int(value)) -class UpdateColorbarScale(object): +class UpdateColorbarScale: """Class to update the values of the colorbar sliders.""" def __init__(self, brain, factor): @@ -53,7 +53,7 @@ def __call__(self): self.widgets[key].set_value(self.brain._data[key]) -class UpdateLUT(object): +class UpdateLUT: """Update the LUT.""" def __init__(self, brain=None): @@ -74,7 +74,7 @@ def __call__(self, fmin=None, fmid=None, fmax=None): widget.set_value(value) -class ShowView(object): +class ShowView: """Class that selects the correct view.""" def __init__(self, brain=None, data=None): @@ -100,7 +100,7 @@ def __call__(self, value, update_widget=False): self.widget.set_value(value) -class SmartCallBack(object): +class SmartCallBack: """Class to manage smart slider. It stores it's own slider representation for efficiency diff --git a/mne/viz/_brain/surface.py b/mne/viz/_brain/surface.py index d3820da5a96..4362b64baed 100644 --- a/mne/viz/_brain/surface.py +++ b/mne/viz/_brain/surface.py @@ -15,7 +15,7 @@ _read_patch) -class _Surface(object): +class _Surface: """Container for a brain surface. 
It is used for storing vertices, faces and morphometric data diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 93cc72d644a..230b99a98d9 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -52,7 +52,7 @@ pytest.importorskip("nibabel") -class _Collection(object): +class _Collection: def __init__(self, actors): self._actors = actors @@ -63,7 +63,7 @@ def GetItemAsObject(self, ii): return self._actors[ii] -class TstVTKPicker(object): +class TstVTKPicker: """Class to test cell picking.""" def __init__(self, mesh, cell_id, hemi, brain): diff --git a/mne/viz/backends/_notebook.py b/mne/viz/backends/_notebook.py index f2d830d1b62..c239aa9e42c 100644 --- a/mne/viz/backends/_notebook.py +++ b/mne/viz/backends/_notebook.py @@ -304,7 +304,7 @@ def _set_enabled(self, value): # modified from: # https://gist.github.com/elkhadiy/284900b3ea8a13ed7b777ab93a691719 -class _FilePicker(object): +class _FilePicker: def __init__(self, rows=20, directory_only=False, ignore_dotfiles=True): self._callback = None self._directory_only = directory_only @@ -562,7 +562,7 @@ def _click(self, value): self._buttons[value].click() -class _BoxLayout(object): +class _BoxLayout: def _handle_scroll(self, scroll=None): kwargs = _BASE_KWARGS.copy() diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index aa4493430cc..b4fe3b69b7f 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -136,7 +136,7 @@ def _is_active(self): return hasattr(self.plotter, 'ren_win') -class _Projection(object): +class _Projection: """Class storing projection information. Attributes diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 3221ff03e1d..be10dbe9502 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -626,7 +626,7 @@ def _get_extra_points(pos, extrapolate, origin, radii): return new_pos, mask_pos, tri -class _GridData(object): +class _GridData: """Unstructured (x,y) data interpolator. This class allows optimized interpolation by computing parameters diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 57930e771ee..6a6c353b321 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -654,7 +654,7 @@ def _key_press(event): plt.close(event.canvas.figure) -class ClickableImage(object): +class ClickableImage: """Display an image so you can click on it and store x/y positions. Takes as input an image array (can be any array that works with imshow, @@ -1371,7 +1371,7 @@ def _prepare_joint_axes(n_maps, figsize=None): return fig, main_ax, map_ax, cbar_ax -class DraggableColorbar(object): +class DraggableColorbar: """Enable interactive colorbar. See http://www.ster.kuleuven.be/~pieterd/python/html/plotting/interactive_colorbar.html @@ -1689,7 +1689,7 @@ def _connection_line(x, fig, sourceax, targetax, y=1., clip_on=False) -class DraggableLine(object): +class DraggableLine: """Custom matplotlib line for moving around by drag and drop. 
Parameters From 267967920ac384cd8698f65670bfb9c2e8720a49 Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Mon, 27 Mar 2023 18:29:37 +0200 Subject: [PATCH 0006/1125] MAINT: In Python 3, do not prefix literals with `u` (#11604) --- mne/datasets/brainstorm/bst_phantom_elekta.py | 2 +- mne/datasets/hf_sef/hf_sef.py | 2 +- mne/decoding/transformer.py | 2 +- mne/dipole.py | 2 +- mne/io/meas_info.py | 6 +++--- mne/io/open.py | 2 +- mne/io/proc_history.py | 2 +- mne/io/tests/test_meas_info.py | 2 +- mne/morph.py | 4 ++-- mne/preprocessing/artifact_detection.py | 12 ++++++------ mne/preprocessing/ica.py | 2 +- mne/preprocessing/maxwell.py | 2 +- mne/simulation/raw.py | 2 +- mne/tests/test_chpi.py | 2 +- mne/viz/_brain/colormap.py | 2 +- mne/viz/misc.py | 6 +++--- tutorials/stats-sensor-space/10_background_stats.py | 2 +- 17 files changed, 27 insertions(+), 27 deletions(-) diff --git a/mne/datasets/brainstorm/bst_phantom_elekta.py b/mne/datasets/brainstorm/bst_phantom_elekta.py index 40f92661085..abfa5a68aca 100644 --- a/mne/datasets/brainstorm/bst_phantom_elekta.py +++ b/mne/datasets/brainstorm/bst_phantom_elekta.py @@ -5,7 +5,7 @@ from ..utils import (_get_version, _version_doc, _data_path_doc_accept, _download_mne_dataset) -_description = u""" +_description = """ URL: http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomElekta """ diff --git a/mne/datasets/hf_sef/hf_sef.py b/mne/datasets/hf_sef/hf_sef.py index 63d97df4fdf..aa4cffa33d1 100644 --- a/mne/datasets/hf_sef/hf_sef.py +++ b/mne/datasets/hf_sef/hf_sef.py @@ -14,7 +14,7 @@ @verbose def data_path(dataset='evoked', path=None, force_update=False, update_path=True, *, verbose=None): - u"""Get path to local copy of the high frequency SEF dataset. + """Get path to local copy of the high frequency SEF dataset. Gets a local copy of the high frequency SEF MEG dataset :footcite:`NurminenEtAl2017`. diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index 63999640afd..5e7734c292a 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -67,7 +67,7 @@ def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): @fill_doc class Scaler(TransformerMixin, BaseEstimator): - u"""Standardize channel data. + """Standardize channel data. This class scales data for each channel. It differs from scikit-learn classes (e.g., :class:`sklearn.preprocessing.StandardScaler`) in that diff --git a/mne/dipole.py b/mne/dipole.py index 94e984ad617..c374fb2dca2 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -44,7 +44,7 @@ @fill_doc class Dipole(TimeMixin): - u"""Dipole class for sequential dipole fits. + """Dipole class for sequential dipole fits. .. 
note:: This class should usually not be instantiated directly via diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index 493542cce4d..b5e8f844c62 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -72,7 +72,7 @@ def _get_valid_units(): 'micro', 'milli', 'centi', 'deci', 'deca', 'hecto', 'kilo', 'mega', 'giga', 'tera', 'peta', 'exa', 'zetta', 'yotta'] - valid_prefix_symbols = ['y', 'z', 'a', 'f', 'p', 'n', u'µ', 'm', 'c', 'd', + valid_prefix_symbols = ['y', 'z', 'a', 'f', 'p', 'n', 'µ', 'm', 'c', 'd', 'da', 'h', 'k', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'] valid_unit_names = ['metre', 'kilogram', 'second', 'ampere', 'kelvin', 'mole', 'candela', 'radian', 'steradian', 'hertz', @@ -81,8 +81,8 @@ def _get_valid_units(): 'degree Celsius', 'lumen', 'lux', 'becquerel', 'gray', 'sievert', 'katal'] valid_unit_symbols = ['m', 'kg', 's', 'A', 'K', 'mol', 'cd', 'rad', 'sr', - 'Hz', 'N', 'Pa', 'J', 'W', 'C', 'V', 'F', u'Ω', 'S', - 'Wb', 'T', 'H', u'°C', 'lm', 'lx', 'Bq', 'Gy', 'Sv', + 'Hz', 'N', 'Pa', 'J', 'W', 'C', 'V', 'F', 'Ω', 'S', + 'Wb', 'T', 'H', '°C', 'lm', 'lx', 'Bq', 'Gy', 'Sv', 'kat'] # Valid units are all possible combinations of either prefix name or prefix diff --git a/mne/io/open.py b/mne/io/open.py index d9aa7f3d53f..e95a3e957c4 100644 --- a/mne/io/open.py +++ b/mne/io/open.py @@ -310,7 +310,7 @@ def _show_tree(fid, tree, indent, level, read_limit, max_str, tag_id): '/'.join(this_type) + ' (' + str(size) + 'b %s)' % type_ + postpend] - out[-1] = out[-1].replace('\n', u'¶') + out[-1] = out[-1].replace('\n', '¶') counter = 0 good = True if tag_id in kinds: diff --git a/mne/io/proc_history.py b/mne/io/proc_history.py index 306b33ff13f..21b1018ff34 100644 --- a/mne/io/proc_history.py +++ b/mne/io/proc_history.py @@ -228,7 +228,7 @@ def _read_maxfilter_record(fid, tree): tag = read_tag(fid, pos) chs = _safe_name_list(tag.data, 'read', 'proj_items_chs') # This list can null chars in the last entry, e.g.: - # [..., u'MEG2642', u'MEG2643', u'MEG2641\x00 ... \x00'] + # [..., 'MEG2642', 'MEG2643', 'MEG2641\x00 ... 
\x00'] chs[-1] = chs[-1].split('\x00')[0] sss_ctc['proj_items_chs'] = chs diff --git a/mne/io/tests/test_meas_info.py b/mne/io/tests/test_meas_info.py index b4fc87783bb..67b81b82f7e 100644 --- a/mne/io/tests/test_meas_info.py +++ b/mne/io/tests/test_meas_info.py @@ -248,7 +248,7 @@ def test_read_write_info(tmp_path): assert (len(info['chs']) == len(info2['chs'])) assert_array_equal(t1, t2) # proc_history (e.g., GH#1875) - creator = u'é' + creator = 'é' info = read_info(chpi_fname) info['proc_history'][0]['creator'] = creator info['hpi_meas'][0]['creator'] = creator diff --git a/mne/morph.py b/mne/morph.py index a3628d91652..269134803e9 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -632,8 +632,8 @@ def _morph_vols(self, vols, mesg, subselect=True): return img_to def __repr__(self): # noqa: D105 - s = u"%s" % self.kind - s += u", %s -> %s" % (self.subject_from, self.subject_to) + s = "%s" % self.kind + s += ", %s -> %s" % (self.subject_from, self.subject_to) if self.kind == 'volume': s += ", zooms : {}".format(self.zooms) s += ", niter_affine : {}".format(self.niter_affine) diff --git a/mne/preprocessing/artifact_detection.py b/mne/preprocessing/artifact_detection.py index 67f302fc381..3cdf09d84cd 100644 --- a/mne/preprocessing/artifact_detection.py +++ b/mne/preprocessing/artifact_detection.py @@ -181,8 +181,8 @@ def annotate_movement(raw, pos, rotation_velocity_limit=None, onsets, offsets = _mask_to_onsets_offsets(bad_mask) onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot - logger.info(u'Omitting %5.1f%% (%3d segments): ' - u'ω >= %5.1f°/s (max: %0.1f°/s)' + logger.info('Omitting %5.1f%% (%3d segments): ' + 'ω >= %5.1f°/s (max: %0.1f°/s)' % (bad_pct, len(onsets), rotation_velocity_limit, np.rad2deg(r.max()))) annot += _annotations_from_mask( @@ -197,8 +197,8 @@ def annotate_movement(raw, pos, rotation_velocity_limit=None, onsets, offsets = _mask_to_onsets_offsets(bad_mask) onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot - logger.info(u'Omitting %5.1f%% (%3d segments): ' - u'v >= %5.4fm/s (max: %5.4fm/s)' + logger.info('Omitting %5.1f%% (%3d segments): ' + 'v >= %5.4fm/s (max: %5.4fm/s)' % (bad_pct, len(onsets), translation_velocity_limit, v.max())) annot += _annotations_from_mask( @@ -242,8 +242,8 @@ def annotate_movement(raw, pos, rotation_velocity_limit=None, onsets, offsets = _mask_to_onsets_offsets(bad_mask) onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot - logger.info(u'Omitting %5.1f%% (%3d segments): ' - u'disp >= %5.4fm (max: %5.4fm)' + logger.info('Omitting %5.1f%% (%3d segments): ' + 'disp >= %5.4fm (max: %5.4fm)' % (bad_pct, len(onsets), mean_distance_limit, disp.max())) annot += _annotations_from_mask( hp_ts, bad_mask, 'BAD_mov_dist', orig_time=orig_time) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index c2242c5e5ce..06e16383c15 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -149,7 +149,7 @@ def _check_for_unsupported_ica_channels(picks, info, allow_ref_meg=False): @fill_doc class ICA(ContainsMixin): - u"""Data decomposition using Independent Component Analysis (ICA). + """Data decomposition using Independent Component Analysis (ICA). This object estimates independent components from :class:`mne.io.Raw`, :class:`mne.Epochs`, or :class:`mne.Evoked` objects. 
Components can diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index 321b59220ac..0251a372a53 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -2000,7 +2000,7 @@ def _regularize_in(int_order, ext_order, S_decomp, mag_or_fine, def _compute_sphere_activation_in(degrees): - u"""Compute the "in" power from random currents in a sphere. + """Compute the "in" power from random currents in a sphere. Parameters ---------- diff --git a/mne/simulation/raw.py b/mne/simulation/raw.py index 1d2b8bcce0d..7abd33bd1ee 100644 --- a/mne/simulation/raw.py +++ b/mne/simulation/raw.py @@ -129,7 +129,7 @@ def _check_head_pos(head_pos, info, first_samp, times=None): def simulate_raw(info, stc=None, trans=None, src=None, bem=None, head_pos=None, mindist=1.0, interp='cos2', n_jobs=None, use_cps=True, forward=None, first_samp=0, max_iter=10000, verbose=None): - u"""Simulate raw data. + """Simulate raw data. Head movements can optionally be simulated using the ``head_pos`` parameter. diff --git a/mne/tests/test_chpi.py b/mne/tests/test_chpi.py index e70c8c11d19..851a5e19938 100644 --- a/mne/tests/test_chpi.py +++ b/mne/tests/test_chpi.py @@ -438,7 +438,7 @@ def test_simulate_calculate_head_pos_chpi(): # Read info dict from raw FIF file info = read_info(raw_fname) # Tune the info structure - chpi_channel = u'STI201' + chpi_channel = 'STI201' ncoil = len(info['hpi_results'][0]['order']) coil_freq = 10 + np.arange(ncoil) * 5 hpi_subsystem = {'event_channel': chpi_channel, diff --git a/mne/viz/_brain/colormap.py b/mne/viz/_brain/colormap.py index 93c1ad2a050..785748cfbfb 100644 --- a/mne/viz/_brain/colormap.py +++ b/mne/viz/_brain/colormap.py @@ -64,7 +64,7 @@ def get_fill_colors(cols, n_fill): def calculate_lut(lut_table, alpha, fmin, fmid, fmax, center=None, transparent=True): - u"""Transparent color map calculation. + """Transparent color map calculation. A colormap may be sequential or divergent. When the colormap is divergent indicate this by providing a value for 'center'. 
The diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 36482a2a9b0..2a14ba50f04 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -184,7 +184,7 @@ def plot_cov(cov, info, exclude=(), colorbar=True, proj=False, show_svd=True, axes[0, k].text(this_rank - 1, axes[0, k].get_ylim()[1], 'rank ≈ %d' % (this_rank,), ha='right', va='top', color='r', alpha=0.5, zorder=4) - axes[0, k].set(ylabel=u'Noise σ (%s)' % unit, yscale='log', + axes[0, k].set(ylabel='Noise σ (%s)' % unit, yscale='log', xlabel='Eigenvalue index', title=name, xlim=[0, len(s) - 1]) tight_layout(fig=fig_svd) @@ -1357,9 +1357,9 @@ def plot_csd(csd, info=None, mode='csd', colorbar=True, cmap=None, if colorbar: cb = plt.colorbar(im, ax=[a for ax_ in axes for a in ax_]) if mode == 'csd': - label = u'CSD' + label = 'CSD' if ch_type in units: - label += u' (%s)' % units[ch_type] + label += ' (%s)' % units[ch_type] cb.set_label(label) elif mode == 'coh': cb.set_label('Coherence') diff --git a/tutorials/stats-sensor-space/10_background_stats.py b/tutorials/stats-sensor-space/10_background_stats.py index 76cf1c22060..ae790fc7eae 100644 --- a/tutorials/stats-sensor-space/10_background_stats.py +++ b/tutorials/stats-sensor-space/10_background_stats.py @@ -263,7 +263,7 @@ def plot_t_p(t, p, title, mcc, axes=None): fig, ax = plt.subplots(figsize=(4, 3)) ax.scatter(N, p_type_I, 3) ax.set(xlim=N[[0, -1]], ylim=[0, 1], xlabel=r'$N_{\mathrm{test}}$', - ylabel=u'Probability of at least\none type I error') + ylabel='Probability of at least\none type I error') ax.grid(True) fig.tight_layout() fig.show() From 62af4ac642b21bc0b996c8c86fdeab417806f6d7 Mon Sep 17 00:00:00 2001 From: Scott Huberty <52462026+scott-huberty@users.noreply.github.com> Date: Mon, 27 Mar 2023 13:31:36 -0400 Subject: [PATCH 0007/1125] ENH: Read eyetracking data (Eyelink) (Fork of #10855 ) (#11152) Co-authored-by: dominikwelke --- doc/_includes/data_formats.rst | 2 + doc/changes/latest.inc | 1 + doc/conf.py | 3 +- doc/datasets.rst | 3 +- doc/overview/datasets_index.rst | 13 + doc/preprocessing.rst | 13 + doc/reading_raw_data.rst | 1 + mne/channels/channels.py | 18 +- mne/datasets/__init__.py | 3 +- mne/datasets/config.py | 15 +- mne/datasets/eyelink/__init__.py | 3 + mne/datasets/eyelink/eyelink.py | 26 + mne/datasets/utils.py | 3 +- mne/defaults.py | 19 +- mne/io/__init__.py | 1 + mne/io/constants.py | 15 +- mne/io/eyelink/__init__.py | 7 + mne/io/eyelink/eyelink.py | 882 ++++++++++++++++++ mne/io/eyelink/tests/__init__.py | 0 mne/io/eyelink/tests/test_eyelink.py | 147 +++ mne/io/meas_info.py | 3 +- mne/io/pick.py | 64 +- mne/io/tests/test_constants.py | 6 +- mne/preprocessing/__init__.py | 1 + mne/preprocessing/eyetracking/__init__.py | 7 + mne/preprocessing/eyetracking/eyetracking.py | 146 +++ mne/preprocessing/tests/test_ica.py | 10 +- mne/simulation/tests/test_raw.py | 1 + mne/utils/docs.py | 5 + mne/viz/_mpl_figure.py | 5 +- mne/viz/tests/test_raw.py | 1 + tools/circleci_download.sh | 3 + tutorials/io/70_reading_eyetracking_data.py | 172 ++++ .../preprocessing/90_eyetracking_data.py | 107 +++ 34 files changed, 1665 insertions(+), 41 deletions(-) create mode 100644 mne/datasets/eyelink/__init__.py create mode 100644 mne/datasets/eyelink/eyelink.py create mode 100644 mne/io/eyelink/__init__.py create mode 100644 mne/io/eyelink/eyelink.py create mode 100644 mne/io/eyelink/tests/__init__.py create mode 100644 mne/io/eyelink/tests/test_eyelink.py create mode 100644 mne/preprocessing/eyetracking/__init__.py create mode 100644 mne/preprocessing/eyetracking/eyetracking.py 
create mode 100644 tutorials/io/70_reading_eyetracking_data.py create mode 100644 tutorials/preprocessing/90_eyetracking_data.py diff --git a/doc/_includes/data_formats.rst b/doc/_includes/data_formats.rst index 641810c6b63..63dbfcdc98b 100644 --- a/doc/_includes/data_formats.rst +++ b/doc/_includes/data_formats.rst @@ -75,6 +75,8 @@ EEG :ref:`Persyst ` .lay :func:`mn NIRS :ref:`NIRx ` directory :func:`mne.io.read_raw_nirx` NIRS :ref:`BOXY ` directory :func:`mne.io.read_raw_boxy` + +EYETRACK SR eyelink ASCII files .asc :func:`mne.io.read_raw_eyelink` ============ ============================================ ========= =================================== More details are provided in the tutorials in the :ref:`tut-data-formats` diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index c74f97f3eca..2d731c83364 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -40,6 +40,7 @@ Enhancements - Add automatic projection of sEEG contact onto the inflated surface for :meth:`mne.viz.Brain.add_sensors` (:gh:`11436` by `Alex Rockhill`_) - Allow an image with intracranial electrode contacts (e.g. computed tomography) to be used without the freesurfer recon-all surfaces to locate contacts so that it doesn't have to be downsampled to freesurfer dimensions (for microelectrodes) and show an example :ref:`ex-ieeg-micro` with :func:`mne.transforms.apply_volume_registration_points` added to aid this transform (:gh:`11567` by `Alex Rockhill`_) - Use new :meth:`dipy.workflows.align.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_) +- Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_) Bugs ~~~~ diff --git a/doc/conf.py b/doc/conf.py index cd11edf8789..1b4f7ad2ed0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -267,7 +267,7 @@ # Undocumented (on purpose) 'RawKIT', 'RawEximia', 'RawEGI', 'RawEEGLAB', 'RawEDF', 'RawCTF', 'RawBTi', 'RawBrainVision', 'RawCurry', 'RawNIRX', 'RawGDF', 'RawSNIRF', 'RawBOXY', - 'RawPersyst', 'RawNihon', 'RawNedf', 'RawHitachi', 'RawFIL', + 'RawPersyst', 'RawNihon', 'RawNedf', 'RawHitachi', 'RawFIL', 'RawEyelink', # sklearn subclasses 'mapping', 'to', 'any', # unlinkable @@ -1231,6 +1231,7 @@ def reset_warnings(gallery_conf, fname): f'{tu}/{si}/plot_creating_data_structures.html': f'{tu}/{si}/10_array_objs.html', # noqa E501 f'{tu}/{si}/plot_point_spread.html': f'{tu}/{si}/70_point_spread.html', f'{tu}/{si}/plot_dics.html': f'{tu}/{si}/80_dics.html', + f'{tu}/{tf}/plot_eyetracking.html': f'{tu}/preprocessing/90_eyetracking_data.html', # noqa E501 f'{ex}/{co}/mne_inverse_label_connectivity.html': f'{mne_conn}/{ex}/mne_inverse_label_connectivity.html', # noqa E501 f'{ex}/{co}/cwt_sensor_connectivity.html': f'{mne_conn}/{ex}/cwt_sensor_connectivity.html', # noqa E501 f'{ex}/{co}/mixed_source_space_connectivity.html': f'{mne_conn}/{ex}/mixed_source_space_connectivity.html', # noqa E501 diff --git a/doc/datasets.rst b/doc/datasets.rst index 8f8e98d4d82..c3d94c49006 100644 --- a/doc/datasets.rst +++ b/doc/datasets.rst @@ -44,4 +44,5 @@ Datasets refmeg_noise.data_path ssvep.data_path erp_core.data_path - epilepsy_ecog.data_path \ No newline at end of file + epilepsy_ecog.data_path + eyelink.data_path \ No newline at end of file diff --git a/doc/overview/datasets_index.rst b/doc/overview/datasets_index.rst index 23827d978d7..b2d0715e8e9 100644 --- a/doc/overview/datasets_index.rst +++ b/doc/overview/datasets_index.rst 
@@ -475,6 +475,19 @@ standard. * :ref:`tut-ssvep` +EYELINK +======= +:func:`mne.datasets.eyelink.data_path` + +A small example dataset in SR research's proprietary .asc format. +1 participant fixated on the screen while short light flashes appeared. +Monocular recording of gaze position and pupil size, 1000 Hz sampling +frequency. + +.. topic:: Examples + + * :ref:`tut-eyetrack` + References ========== diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 98403661a6d..c92167a04fe 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -141,6 +141,19 @@ Projections: make_montage_volume warp_montage +:py:mod:`mne.preprocessing.eyetracking`: + +.. currentmodule:: mne.preprocessing.eyetracking + +.. automodule:: mne.preprocessing.eyetracking + :no-members: + :no-inherited-members: + +.. autosummary:: + :toctree: generated/ + + set_channel_types_eyetrack + EEG referencing: .. currentmodule:: mne diff --git a/doc/reading_raw_data.rst b/doc/reading_raw_data.rst index ad04c0ca91a..c9316ffa9b0 100644 --- a/doc/reading_raw_data.rst +++ b/doc/reading_raw_data.rst @@ -20,6 +20,7 @@ Reading raw data read_raw_ctf read_raw_curry read_raw_edf + read_raw_eyelink read_raw_bdf read_raw_gdf read_raw_kit diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 7b2afb2ffbb..c3c86d20a34 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -203,7 +203,8 @@ def equalize_channels(instances, copy=True, verbose=None): FIFF.FIFF_UNIT_MOL: 'M', FIFF.FIFF_UNIT_NONE: 'NA', FIFF.FIFF_UNIT_CEL: 'C', - FIFF.FIFF_UNIT_S: 'S'} + FIFF.FIFF_UNIT_S: 'S', + FIFF.FIFF_UNIT_PX: 'px'} def _check_set(ch, projs, ch_type): @@ -331,7 +332,8 @@ def set_channel_types(self, mapping, verbose=None): ecg, eeg, emg, eog, exci, ias, misc, resp, seeg, dbs, stim, syst, ecog, hbo, hbr, fnirs_cw_amplitude, fnirs_fd_ac_amplitude, - fnirs_fd_phase, fnirs_od, temperature, gsr + fnirs_fd_phase, fnirs_od, eyetrack_pos, eyetrack_pupil, + temperature, gsr .. versionadded:: 0.9.0 """ @@ -379,6 +381,10 @@ def set_channel_types(self, mapping, verbose=None): coil_type = FIFF.FIFFV_COIL_FNIRS_FD_PHASE elif ch_type == 'fnirs_od': coil_type = FIFF.FIFFV_COIL_FNIRS_OD + elif ch_type == 'eyetrack_pos': + coil_type = FIFF.FIFFV_COIL_EYETRACK_POS + elif ch_type == 'eyetrack_pupil': + coil_type = FIFF.FIFFV_COIL_EYETRACK_PUPIL else: coil_type = FIFF.FIFFV_COIL_NONE self.info['chs'][c_ind]['coil_type'] = coil_type @@ -595,7 +601,7 @@ def pick_types(self, meg=False, eeg=False, stim=False, eog=False, resp=False, chpi=False, exci=False, ias=False, syst=False, seeg=False, dipole=False, gof=False, bio=False, ecog=False, fnirs=False, csd=False, dbs=False, - temperature=False, gsr=False, + temperature=False, gsr=False, eyetrack=False, include=(), exclude='bads', selection=None, verbose=None): """Pick some channels by type and names. 
@@ -621,9 +627,9 @@ def pick_types(self, meg=False, eeg=False, stim=False, eog=False, self.info, meg=meg, eeg=eeg, stim=stim, eog=eog, ecg=ecg, emg=emg, ref_meg=ref_meg, misc=misc, resp=resp, chpi=chpi, exci=exci, ias=ias, syst=syst, seeg=seeg, dipole=dipole, gof=gof, bio=bio, - ecog=ecog, fnirs=fnirs, csd=csd, dbs=dbs, include=include, - exclude=exclude, selection=selection, temperature=temperature, - gsr=gsr) + ecog=ecog, fnirs=fnirs, csd=csd, dbs=dbs, temperature=temperature, + gsr=gsr, eyetrack=eyetrack, include=include, exclude=exclude, + selection=selection) self._pick_drop_channels(idx) diff --git a/mne/datasets/__init__.py b/mne/datasets/__init__.py index 96219aaf621..ec24f450fd0 100644 --- a/mne/datasets/__init__.py +++ b/mne/datasets/__init__.py @@ -26,6 +26,7 @@ from . import ssvep from . import erp_core from . import epilepsy_ecog +from . import eyelink from . import ucl_opm_auditory from ._fetch import fetch_dataset from .utils import (_download_all_example_data, fetch_hcp_mmp_parcellation, @@ -42,5 +43,5 @@ 'sleep_physionet', 'somato', 'spm_face', 'ssvep', 'testing', 'visual_92_categories', 'limo', 'erp_core', 'epilepsy_ecog', 'fetch_dataset', 'fetch_phantom', 'has_dataset', 'refmeg_noise', - 'fnirs_motor' + 'fnirs_motor', 'eyelink' ] diff --git a/mne/datasets/config.py b/mne/datasets/config.py index c9431a9837e..dc851e9bd2f 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # respective repos, and make a new release of the dataset on GitHub. Then # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓ ↓↓↓ -RELEASES = dict(testing='0.142', misc='0.24') +RELEASES = dict(testing='0.144', misc='0.26') TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}' MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}' @@ -111,7 +111,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS['testing'] = dict( archive_name=f'{TESTING_VERSIONED}.tar.gz', - hash='md5:44b857ddb34aefd752e4f5b19d625dee', + hash='md5:fb546f44dba3310945225ed8fdab4a91', url=('https://codeload.github.com/mne-tools/mne-testing-data/' f'tar.gz/{RELEASES["testing"]}'), # In case we ever have to resort to osf.io again... @@ -123,7 +123,7 @@ ) MNE_DATASETS['misc'] = dict( archive_name=f'{MISC_VERSIONED}.tar.gz', # 'mne-misc-data', - hash='md5:eb017a919939511932bd683f26f97490', + hash='md5:868b484fadd73b1d1a3535b7194a0d03', url=('https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/' f'{RELEASES["misc"]}'), folder_name='MNE-misc-data', @@ -335,3 +335,12 @@ folder_name='foo', config_key='MNE_DATASETS_FAKE_PATH' ) + +# eyelink dataset +MNE_DATASETS['eyelink'] = dict( + archive_name='eyelink_example_data.zip', + hash='md5:081950c05f35267458d9c751e178f161', + url=('https://osf.io/r5ndq/download?version=1'), + folder_name='eyelink-example-data', + config_key='MNE_DATASETS_EYELINK_PATH' +) diff --git a/mne/datasets/eyelink/__init__.py b/mne/datasets/eyelink/__init__.py new file mode 100644 index 00000000000..85931aba72d --- /dev/null +++ b/mne/datasets/eyelink/__init__.py @@ -0,0 +1,3 @@ +"""Eyelink test dataset.""" + +from .eyelink import data_path, get_version diff --git a/mne/datasets/eyelink/eyelink.py b/mne/datasets/eyelink/eyelink.py new file mode 100644 index 00000000000..a08e338ab33 --- /dev/null +++ b/mne/datasets/eyelink/eyelink.py @@ -0,0 +1,26 @@ +# Authors: Dominik Welke +# License: BSD Style. 
+ +from ...utils import verbose +from ..utils import (_data_path_doc, _get_version, _version_doc, + _download_mne_dataset) + + +@verbose +def data_path(path=None, force_update=False, update_path=True, + download=True, *, verbose=None): # noqa: D103 + return _download_mne_dataset( + name='eyelink', processor='unzip', path=path, + force_update=force_update, update_path=update_path, + download=download) + + +data_path.__doc__ = _data_path_doc.format(name='eyelink', + conf='MNE_DATASETS_EYELINK_PATH') + + +def get_version(): # noqa: D103 + return _get_version('eyelink') + + +get_version.__doc__ = _version_doc.format(name='eyelink') diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index e03d179cfc6..50a894bfd7b 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -304,7 +304,7 @@ def _download_all_example_data(verbose=True): kiloword, phantom_4dbti, sleep_physionet, limo, fnirs_motor, refmeg_noise, fetch_infant_template, fetch_fsaverage, ssvep, erp_core, epilepsy_ecog, - fetch_phantom, ucl_opm_auditory) + fetch_phantom, eyelink, ucl_opm_auditory) sample_path = sample.data_path() testing.data_path() misc.data_path() @@ -327,6 +327,7 @@ def _download_all_example_data(verbose=True): brainstorm.bst_resting.data_path(accept=True) phantom_path = brainstorm.bst_phantom_elekta.data_path(accept=True) fetch_phantom('otaniemi', subjects_dir=phantom_path) + eyelink.data_path() brainstorm.bst_phantom_ctf.data_path(accept=True) eegbci.load_data(1, [6, 10, 14], update_path=True) for subj in range(4): diff --git a/mne/defaults.py b/mne/defaults.py index 362eba0d67f..16b3b843406 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -13,26 +13,29 @@ dipole='k', gof='k', bio='k', ecog='k', hbo='#AA3377', hbr='b', fnirs_cw_amplitude='k', fnirs_fd_ac_amplitude='k', fnirs_fd_phase='k', fnirs_od='k', csd='k', whitened='k', - gsr='#666633', temperature='#663333'), + gsr='#666633', temperature='#663333', + eyegaze='k', pupil='k'), si_units=dict(mag='T', grad='T/m', eeg='V', eog='V', ecg='V', emg='V', misc='AU', seeg='V', dbs='V', dipole='Am', gof='GOF', bio='V', ecog='V', hbo='M', hbr='M', ref_meg='T', fnirs_cw_amplitude='V', fnirs_fd_ac_amplitude='V', fnirs_fd_phase='rad', fnirs_od='V', csd='V/m²', - whitened='Z', gsr='S', temperature='C'), + whitened='Z', gsr='S', temperature='C', + eyegaze='AU', pupil='AU'), units=dict(mag='fT', grad='fT/cm', eeg='µV', eog='µV', ecg='µV', emg='µV', misc='AU', seeg='mV', dbs='µV', dipole='nAm', gof='GOF', bio='µV', ecog='µV', hbo='µM', hbr='µM', ref_meg='fT', fnirs_cw_amplitude='V', fnirs_fd_ac_amplitude='V', fnirs_fd_phase='rad', fnirs_od='V', csd='mV/m²', - whitened='Z', gsr='S', temperature='C'), + whitened='Z', gsr='S', temperature='C', + eyegaze='AU', pupil='AU'), # scalings for the units scalings=dict(mag=1e15, grad=1e13, eeg=1e6, eog=1e6, emg=1e6, ecg=1e6, misc=1.0, seeg=1e3, dbs=1e6, ecog=1e6, dipole=1e9, gof=1.0, bio=1e6, hbo=1e6, hbr=1e6, ref_meg=1e15, fnirs_cw_amplitude=1.0, fnirs_fd_ac_amplitude=1.0, fnirs_fd_phase=1., fnirs_od=1.0, csd=1e3, whitened=1., - gsr=1., temperature=1.), + gsr=1., temperature=1., eyegaze=1., pupil=1.), # rough guess for a good plot scalings_plot_raw=dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4, emg=1e-3, ref_meg=1e-12, misc='auto', @@ -42,13 +45,15 @@ fnirs_fd_ac_amplitude=2e-2, fnirs_fd_phase=2e-1, fnirs_od=2e-2, csd=200e-4, dipole=1e-7, gof=1e2, - gsr=1., temperature=0.1), + gsr=1., temperature=0.1, + eyegaze=3e-1, pupil=1e3), scalings_cov_rank=dict(mag=1e12, grad=1e11, eeg=1e5, # ~100x scalings seeg=1e1, 
dbs=1e4, ecog=1e4, hbo=1e4, hbr=1e4), ylim=dict(mag=(-600., 600.), grad=(-200., 200.), eeg=(-200., 200.), misc=(-5., 5.), seeg=(-20., 20.), dbs=(-200., 200.), dipole=(-100., 100.), gof=(0., 1.), bio=(-500., 500.), - ecog=(-200., 200.), hbo=(0, 20), hbr=(0, 20), csd=(-50., 50.)), + ecog=(-200., 200.), hbo=(0, 20), hbr=(0, 20), csd=(-50., 50.), + eyegaze=(0., 5000.), pupil=(0., 5000.)), titles=dict(mag='Magnetometers', grad='Gradiometers', eeg='EEG', eog='EOG', ecg='ECG', emg='EMG', misc='misc', seeg='sEEG', dbs='DBS', bio='BIO', dipole='Dipole', ecog='ECoG', hbo='Oxyhemoglobin', @@ -60,6 +65,8 @@ gof='Goodness of fit', csd='Current source density', stim='Stimulus', gsr='Galvanic skin response', temperature='Temperature', + eyegaze='Eye-tracking (Gaze position)', + pupil='Eye-tracking (Pupil size)', ), mask_params=dict(marker='o', markerfacecolor='w', diff --git a/mne/io/__init__.py b/mne/io/__init__.py index 6ed6b898566..0abb704873b 100644 --- a/mne/io/__init__.py +++ b/mne/io/__init__.py @@ -60,6 +60,7 @@ read_evoked_fieldtrip) from .nihon import read_raw_nihon from ._read_raw import read_raw +from .eyelink import read_raw_eyelink # for backward compatibility diff --git a/mne/io/constants.py b/mne/io/constants.py index 1159c85283d..f2847644f07 100644 --- a/mne/io/constants.py +++ b/mne/io/constants.py @@ -204,6 +204,8 @@ FIFF.FIFFV_FNIRS_CH = 1100 # Functional near-infrared spectroscopy FIFF.FIFFV_TEMPERATURE_CH = 1200 # Functional near-infrared spectroscopy FIFF.FIFFV_GALVANIC_CH = 1300 # Galvanic skin response +FIFF.FIFFV_EYETRACK_CH = 1400 # Eye-tracking + _ch_kind_named = {key: key for key in ( FIFF.FIFFV_BIO_CH, FIFF.FIFFV_MEG_CH, @@ -227,6 +229,7 @@ FIFF.FIFFV_FNIRS_CH, FIFF.FIFFV_GALVANIC_CH, FIFF.FIFFV_TEMPERATURE_CH, + FIFF.FIFFV_EYETRACK_CH )} # @@ -854,6 +857,8 @@ FIFF.FIFF_UNIT_AM = 202 # Am FIFF.FIFF_UNIT_AM_M2 = 203 # Am/m^2 FIFF.FIFF_UNIT_AM_M3 = 204 # Am/m^3 + +FIFF.FIFF_UNIT_PX = 210 # Pixel _ch_unit_named = {key: key for key in( FIFF.FIFF_UNIT_NONE, FIFF.FIFF_UNIT_UNITLESS, FIFF.FIFF_UNIT_M, FIFF.FIFF_UNIT_KG, FIFF.FIFF_UNIT_SEC, FIFF.FIFF_UNIT_A, FIFF.FIFF_UNIT_K, @@ -865,6 +870,7 @@ FIFF.FIFF_UNIT_CEL, FIFF.FIFF_UNIT_LM, FIFF.FIFF_UNIT_LX, FIFF.FIFF_UNIT_V_M2, FIFF.FIFF_UNIT_T_M, FIFF.FIFF_UNIT_AM, FIFF.FIFF_UNIT_AM_M2, FIFF.FIFF_UNIT_AM_M3, + FIFF.FIFF_UNIT_PX, )} # # Multipliers @@ -916,6 +922,11 @@ FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE = 304 # fNIRS frequency domain AC amplitude FIFF.FIFFV_COIL_FNIRS_FD_PHASE = 305 # fNIRS frequency domain phase FIFF.FIFFV_COIL_FNIRS_RAW = FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE # old alias +FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE = 306 # fNIRS time-domain gated amplitude +FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE = 307 # fNIRS time-domain moments amplitude + +FIFF.FIFFV_COIL_EYETRACK_POS = 400 # Eye-tracking gaze position +FIFF.FIFFV_COIL_EYETRACK_PUPIL = 401 # Eye-tracking pupil size FIFF.FIFFV_COIL_MCG_42 = 1000 # For testing the MCG software @@ -1002,7 +1013,9 @@ FIFF.FIFFV_COIL_DIPOLE, FIFF.FIFFV_COIL_FNIRS_HBO, FIFF.FIFFV_COIL_FNIRS_HBR, FIFF.FIFFV_COIL_FNIRS_RAW, FIFF.FIFFV_COIL_FNIRS_OD, FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE, - FIFF.FIFFV_COIL_FNIRS_FD_PHASE, FIFF.FIFFV_COIL_MCG_42, + FIFF.FIFFV_COIL_FNIRS_FD_PHASE, FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE, + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE, FIFF.FIFFV_COIL_MCG_42, + FIFF.FIFFV_COIL_EYETRACK_POS, FIFF.FIFFV_COIL_EYETRACK_PUPIL, FIFF.FIFFV_COIL_POINT_MAGNETOMETER, FIFF.FIFFV_COIL_AXIAL_GRAD_5CM, FIFF.FIFFV_COIL_VV_PLANAR_W, FIFF.FIFFV_COIL_VV_PLANAR_T1, 
FIFF.FIFFV_COIL_VV_PLANAR_T2, FIFF.FIFFV_COIL_VV_PLANAR_T3,
diff --git a/mne/io/eyelink/__init__.py b/mne/io/eyelink/__init__.py
new file mode 100644
index 00000000000..77ee7ebc9ef
--- /dev/null
+++ b/mne/io/eyelink/__init__.py
@@ -0,0 +1,7 @@
+"""Module for loading Eye-Tracker data."""
+
+# Author: Dominik Welke
+#
+# License: BSD-3-Clause
+
+from .eyelink import read_raw_eyelink
diff --git a/mne/io/eyelink/eyelink.py b/mne/io/eyelink/eyelink.py
new file mode 100644
index 00000000000..a85796d5b77
--- /dev/null
+++ b/mne/io/eyelink/eyelink.py
@@ -0,0 +1,882 @@
+# Authors: Dominik Welke
+#          Scott Huberty
+#          Christian O'Reilly
+#
+# License: BSD-3-Clause
+
+from datetime import datetime, timezone, timedelta
+from pathlib import Path
+
+import numpy as np
+from ..constants import FIFF
+from ..base import BaseRaw
+from ..meas_info import create_info
+from ...annotations import Annotations
+from ...utils import logger, verbose, fill_doc, _check_pandas_installed
+
+EYELINK_COLS = {'timestamp': ('time',),
+                'pos': {'left': ('xpos_left', 'ypos_left', 'pupil_left'),
+                        'right': ('xpos_right', 'ypos_right', 'pupil_right')},
+                'velocity': {'left': ('xvel_left', 'yvel_left'),
+                             'right': ('xvel_right', 'yvel_right')},
+                'resolution': ('xres', 'yres'),
+                'input': ('DIN',),
+                'flags': ('flags',),
+                'remote': ('x_head', 'y_head',
+                           'distance'),
+                'remote_flags': ('head_flags',),
+                'block_num': ('block',),
+                'eye_event': ('eye', 'time', 'end_time', 'duration'),
+                'fixation': ('fix_avg_x', 'fix_avg_y',
+                             'fix_avg_pupil_size'),
+                'saccade': ('sacc_start_x', 'sacc_start_y',
+                            'sacc_end_x', 'sacc_end_y',
+                            'sacc_visual_angle', 'peak_velocity')}
+
+
+def _isfloat(token):
+    """Boolean test for whether a string can be cast to float.
+
+    Parameters
+    ----------
+    token : str
+        Single element from the tokens list.
+    """
+    if isinstance(token, str):
+        try:
+            float(token)
+            return True
+        except ValueError:
+            return False
+    else:
+        raise ValueError('input should be a string,'
+                         f' but {token} is of type {type(token)}')
+
+
+def _convert_types(tokens):
+    """Convert the type of each token in the list.
+
+    The tokens input is a list of string elements.
+    POSIX timestamp strings can be integers, eye gaze position and
+    pupil size can be floats. The flags token ("...") remains a string.
+    Missing eye/head-target data (indicated by '.' or 'MISSING_DATA')
+    are replaced by np.nan.
+
+    Parameters
+    ----------
+    tokens : list
+        List of string elements.
+
+    Returns
+    -------
+    Tokens list with elements of various types.
+    """
+    return [int(token) if token.isdigit()  # execute this before _isfloat()
+            else float(token) if _isfloat(token)
+            else np.nan if token in ('.', 'MISSING_DATA')
+            else token  # remains as string
+            for token in tokens]
+
+
+def _parse_line(line):
+    """Parse a tab-delimited string from an Eyelink ASCII file.
+
+    Takes a tab-delimited string from an Eyelink file,
+    splits it into a list of tokens, and converts the type
+    of each token in the list.
+    """
+    if len(line):
+        tokens = line.split()
+        return _convert_types(tokens)
+    else:
+        raise ValueError('line is empty, nothing to parse')
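[Editorial aside, not part of the patch] To make the tokenizing helpers above concrete, here is a minimal sketch of what `_parse_line` returns for a hypothetical Eyelink sample line; the timestamp and gaze values are invented:

    # A monocular sample line: <timestamp> <x> <y> <pupil> <flags>
    line = '9511881\t850.7\t543.2\t1910.0\t...'
    tokens = _parse_line(line)
    # -> [9511881, 850.7, 543.2, 1910.0, '...']
    # The integer timestamp stays an int, gaze coordinates and pupil size
    # become floats, and the flags token '...' stays a string; a lone '.'
    # or 'MISSING_DATA' would become np.nan instead.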
+ + Parameters + ---------- + line : string + single line from Eyelink asc file + + Returns + ------- + bool : + True if any of the following strings that are + known to indicate a system message are in the line + + Notes + ----- + Examples of eyelink system messages: + - ;Sess:22Aug22;Tria:1;Tri2:False;ESNT:182BFE4C2F4; + - ;NTPT:182BFE55C96;SMSG:__NTP_CLOCK_SYNC__;DIFF:-1; + - !V APLAYSTART 0 1 library/audio + - !MODE RECORD CR 500 2 1 R + """ + return any(['!V' in line, + '!MODE' in line, + ';' in line]) + + +def _get_sfreq(rec_info): + """Get sampling frequency from Eyelink ASCII file. + + Parameters + ---------- + rec_info : list + the first list in self._event_lines['SAMPLES']. + The sfreq occurs after RATE: i.e. [..., RATE, 1000, ...]. + + Returns + ------- + sfreq : int | float + """ + for i, token in enumerate(rec_info): + if token == 'RATE': + # sfreq is the first token after RATE + return rec_info[i + 1] + + +def _sort_by_time(df, col='time'): + df.sort_values(col, ascending=True, inplace=True) + df.reset_index(drop=True, inplace=True) + + +def _convert_times(df, first_samp, col='time'): + """Set initial time to 0, converts from ms to seconds in place. + + Parameters + ---------- + df pandas.DataFrame: + One of the dataframes in the self.dataframes dict. + + first_samp int: + timestamp of the first sample of the recording. This should + be the first sample of the first recording block. + col str (default 'time'): + column name to sort pandas.DataFrame by + + Notes + ----- + Each sample in an Eyelink file has a posix timestamp string. + Subtracts the "first" sample's timestamp from each timestamp. + The "first" sample is inferred to be the first sample of + the first recording block, i.e. the first "START" line. + """ + _sort_by_time(df, col) + for col in df.columns: + if col.endswith('time'): # 'time' and 'end_time' cols + df[col] -= first_samp + df[col] /= 1000 + if col in ['duration', 'offset']: + df[col] /= 1000 + + +def _fill_times(df, sfreq, time_col='time',): + """Fill missing timestamps if there are multiple recording blocks. + + Parameters + ---------- + df : pandas.DataFrame: + dataframe of the eyetracking data samples, BEFORE + _convert_times() is applied to the dataframe + + sfreq : int | float: + sampling frequency of the data + + time_col : str (default 'time'): + name of column with the timestamps (e.g. 9511881, 9511882, ...) + + Returns + ------- + %(df_return)s + + Notes + ----- + After _parse_recording_blocks, Files with multiple recording blocks will + have missing timestamps for the duration of the period between the blocks. + This would cause the occular annotations (i.e. blinks) to not line up with + the signal. + """ + pd = _check_pandas_installed() + + first, last = df[time_col].iloc[[0, -1]] + step = 1000 / sfreq + df[time_col] = df[time_col].astype(float) + new_times = pd.DataFrame(np.arange(first, last + step / 2, step), + columns=[time_col]) + return pd.merge_asof(new_times, df, on=time_col, direction='nearest', + tolerance=step / 10) + + +def _find_overlaps(df, max_time=0.05): + """Merge left/right eye events with onset/offset diffs less than max_time. + + df : pandas.DataFrame + Pandas DataFrame with occular events (fixations, saccades, blinks) + max_time : float (default 0.05) + Time in seconds. 
Defaults to .05 (50 ms).
+
+    Returns
+    -------
+    DataFrame : %(df_return)s
+        :class:`pandas.DataFrame` specifying overlapped eye events, if any.
+
+    Notes
+    -----
+    The idea is to cumulatively sum the boolean values for rows with onset and
+    offset differences (against the previous row) that are greater than
+    max_time. If onset and offset diffs are less than max_time then no_overlap
+    will become False. Alternatively, if either the onset or offset diff is
+    greater than max_time, no_overlap becomes True. Cumulatively summing over
+    these boolean values will leave rows with no_overlap == False unchanged
+    and hence with the same group number.
+    """
+    pd = _check_pandas_installed()
+
+    df = df.copy()
+    df["overlap_start"] = df.sort_values("time")["time"]\
+        .diff()\
+        .lt(max_time)
+
+    df["overlap_end"] = (df["end_time"]
+                         .diff().abs()
+                         .lt(max_time))
+
+    df["no_overlap"] = ~(df["overlap_end"]
+                         & df["overlap_start"])
+    df["group"] = df["no_overlap"].cumsum()
+
+    # now use groupby on 'group'. If one left and one right eye in group
+    # the new start/end times are the mean of the two eyes
+    ovrlp = pd.concat([pd.DataFrame(g[1].drop(columns="eye").mean()).T
+                       if (len(g[1]) == 2) and (len(g[1].eye.unique()) == 2)
+                       else g[1]  # not an overlap, return group unchanged
+                       for g in df.groupby("group")]
+                      )
+    # overlapped events get a "both" value in the "eye" col
+    if "eye" in ovrlp.columns:
+        ovrlp["eye"] = ovrlp["eye"].fillna("both")
+    else:
+        ovrlp["eye"] = "both"
+    tmp_cols = ["overlap_start", "overlap_end", "no_overlap", "group"]
+    return ovrlp.drop(columns=tmp_cols).reset_index(drop=True)
+
+
+@fill_doc
+def read_raw_eyelink(fname, preload=False, verbose=None,
+                     create_annotations=True, apply_offsets=False,
+                     find_overlaps=False, overlap_threshold=0.05,
+                     gap_description='bad_rec_gap'):
+    """Reader for an Eyelink .asc file.
+
+    Parameters
+    ----------
+    fname : str
+        Path to the eyelink file (.asc).
+    %(preload)s
+    %(verbose)s
+    create_annotations : bool | list (default True)
+        Whether to create mne.Annotations from ocular events
+        (blinks, fixations, saccades) and experiment messages. If a list, must
+        contain one or more of ['fixations', 'saccades', 'blinks', 'messages'].
+        If True, creates mne.Annotations for both ocular events and experiment
+        messages.
+    apply_offsets : bool (default False)
+        Adjusts the onset time of the mne.Annotations created from Eyelink
+        experiment messages, if offset values exist in
+        self.dataframes['messages'].
+    find_overlaps : bool (default False)
+        Combine left and right eye :class:`mne.Annotations` (blinks, fixations,
+        saccades) if their start times and their stop times are both not
+        separated by more than overlap_threshold.
+    overlap_threshold : float (default 0.05)
+        Time in seconds. Threshold of allowable time-gap between the start and
+        stop times of the left and right eyes. If the gap is larger than the
+        threshold, the :class:`mne.Annotations` will be kept separate
+        (i.e. "blink_L", "blink_R"). If the gap is smaller than the threshold,
+        the :class:`mne.Annotations` will be merged (i.e. "blink_both").
+    gap_description : str (default 'bad_rec_gap')
+        If there are multiple recording blocks in the file, the description of
+        the annotation that will span across the gap period between the
+        blocks. Uses 'bad_rec_gap' by default so that these time periods will
+        be considered bad by MNE and excluded from operations like epoching.
+
+    Returns
+    -------
+    raw : instance of RawEyelink
+        A Raw object containing eyetracker data.
+
+    See Also
+    --------
+    mne.io.Raw : Documentation of attributes and methods.
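+
+    Examples
+    --------
+    Reading a file and keeping only blink annotations might look like this
+    (``'sub-01.asc'`` is a placeholder file name used for illustration)::
+
+        raw = read_raw_eyelink('sub-01.asc', create_annotations=['blinks'])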
+    """
+    extension = Path(fname).suffix
+    if extension != '.asc':
+        raise ValueError('This reader can only read eyelink .asc files.'
+                         f' Got extension {extension} instead. Consult the'
+                         ' Eyelink manual for converting Eyelink data format'
+                         ' (.edf) files to the .asc format.')
+
+    return RawEyelink(fname, preload=preload, verbose=verbose,
+                      create_annotations=create_annotations,
+                      apply_offsets=apply_offsets,
+                      find_overlaps=find_overlaps,
+                      overlap_threshold=overlap_threshold,
+                      gap_desc=gap_description)
+
+
+@fill_doc
+class RawEyelink(BaseRaw):
+    """Raw object from an Eyelink .asc file.
+
+    Parameters
+    ----------
+    fname : str
+        Path to the data file (.asc).
+    create_annotations : bool | list (default True)
+        Whether to create mne.Annotations from ocular events
+        (blinks, fixations, saccades) and experiment messages. If a list, must
+        contain one or more of ['fixations', 'saccades', 'blinks', 'messages'].
+        If True, creates mne.Annotations for both ocular events and experiment
+        messages.
+    apply_offsets : bool (default False)
+        Adjusts the onset time of the mne.Annotations created from Eyelink
+        experiment messages, if offset values exist in
+        raw.dataframes['messages'].
+    find_overlaps : bool (default False)
+        Combine left and right eye :class:`mne.Annotations` (blinks, fixations,
+        saccades) if their start times and their stop times are both not
+        separated by more than overlap_threshold.
+    overlap_threshold : float (default 0.05)
+        Time in seconds. Threshold of allowable time-gap between the start and
+        stop times of the left and right eyes. If the gap is larger than the
+        threshold, the :class:`mne.Annotations` will be kept separate
+        (i.e. "blink_L", "blink_R"). If the gap is smaller than the threshold,
+        the :class:`mne.Annotations` will be merged (i.e. "blink_both").
+    gap_desc : str (default 'bad_rec_gap')
+        If there are multiple recording blocks in the file, the description of
+        the annotation that will span across the gap period between the
+        blocks. Uses 'bad_rec_gap' by default so that these time periods will
+        be considered bad by MNE and excluded from operations like epoching.
+    %(preload)s
+    %(verbose)s
+
+    Attributes
+    ----------
+    fname : pathlib.Path
+        Path to the Eyelink file.
+    dataframes : dict
+        Dictionary of pandas DataFrames. One for eyetracking samples,
+        and one for each type of eyelink event (blinks, messages, etc.).
+    _sample_lines : list
+        List of lists, each list is one sample containing eyetracking
+        X/Y and pupil channel data (+ other channels, if they exist).
+    _event_lines : dict
+        Each key contains a list of lists for one event type that occurred
+        during the recording period. Events range from ocular events
+        (blinks, saccades, fixations) to messages from the stimulus
+        presentation software, or info from a response controller.
+    _system_lines : list
+        List of tab-delimited strings. Each string is a system message
+        that in most cases isn't needed. System messages are written for
+        Eyelink's DataViewer application.
+    _tracking_mode : str
+        Whether a single eye was tracked ('monocular'), or both
+        ('binocular').
+    _gap_desc : str
+        The description to be used for annotations returned by
+        _make_gap_annots.
+
+    See Also
+    --------
+    mne.io.Raw : Documentation of attributes and methods.
+ """ + + @verbose + def __init__(self, fname, preload=False, verbose=None, + create_annotations=True, + apply_offsets=False, find_overlaps=False, + overlap_threshold=0.05, + gap_desc='bad_rec_gap'): + + logger.info('Loading {}'.format(fname)) + + self.fname = Path(fname) + self._sample_lines = None + self._event_lines = None + self._system_lines = None + self._tracking_mode = None # assigned in self._infer_col_names + self._meas_date = None + self._rec_info = None + self._gap_desc = gap_desc + self.dataframes = {} + + self._get_recording_datetime() # sets self._meas_date + self._parse_recording_blocks() # sets sample, event, & system lines + + sfreq = _get_sfreq(self._event_lines['SAMPLES'][0]) + col_names, ch_names = self._infer_col_names() + self._create_dataframes(col_names, sfreq, find_overlaps=find_overlaps, + threshold=overlap_threshold) + info = self._create_info(ch_names, sfreq) + eye_ch_data = self.dataframes['samples'][ch_names] + eye_ch_data = eye_ch_data.to_numpy().T + + # create mne object + super(RawEyelink, self).__init__(info, preload=eye_ch_data, + filenames=[self.fname], + verbose=verbose) + # set meas_date + self.set_meas_date(self._meas_date) + + # Make Annotations + gap_annots = None + if len(self.dataframes['recording_blocks']) > 1: + gap_annots = self._make_gap_annots() + eye_annots = None + if create_annotations: + eye_annots = self._make_eyelink_annots(self.dataframes, + create_annotations, + apply_offsets) + if gap_annots and eye_annots: # set both + self.set_annotations(gap_annots + eye_annots) + elif gap_annots: + self.set_annotations(gap_annots) + elif eye_annots: + self.set_annotations(eye_annots) + else: + logger.info('Not creating any annotations') + + def _parse_recording_blocks(self): + """Parse Eyelink ASCII file. + + Eyelink samples occur within START and END blocks. + samples lines start with a posix-like string, + and contain eyetracking sample info. Event Lines + start with an upper case string and contain info + about occular events (i.e. blink/saccade), or experiment + messages sent by the stimulus presentation software. + """ + with self.fname.open() as file: + block_num = 1 + self._sample_lines = [] + self._event_lines = {'START': [], 'END': [], 'SAMPLES': [], + 'EVENTS': [], 'ESACC': [], 'EBLINK': [], + 'EFIX': [], 'MSG': [], 'INPUT': [], + 'BUTTON': [], 'PUPIL': []} + self._system_lines = [] + + is_recording_block = False + for line in file: + if line.startswith('START'): # start of recording block + is_recording_block = True + if is_recording_block: + if _is_sys_msg(line): + self._system_lines.append(line) + continue # system messages don't need to be parsed. + tokens = _parse_line(line) + tokens.append(block_num) # add current block number + if isinstance(tokens[0], (int, float)): # Samples + self._sample_lines.append(tokens) + elif tokens[0] in self._event_lines.keys(): + event_key, event_info = tokens[0], tokens[1:] + self._event_lines[event_key].append(event_info) + if tokens[0] == 'END': # end of recording block + is_recording_block = False + block_num += 1 + if not self._event_lines['START']: + raise ValueError('Could not determine the start of the' + ' recording. 
When converting to ASCII, START' + ' events should not be suppressed.') + if not self._sample_lines: # no samples parsed + raise ValueError(f"Couldn't find any samples in {self.fname}") + self._validate_data() + + def _validate_data(self): + """Check the incoming data for some known problems that can occur.""" + self._rec_info = self._event_lines['SAMPLES'][0] + pupil_info = self._event_lines['PUPIL'][0] + n_blocks = len(self._event_lines['START']) + sfreq = int(_get_sfreq(self._rec_info)) + first_samp = self._event_lines['START'][0][0] + if ('LEFT' in self._rec_info) and ('RIGHT' in self._rec_info): + self._tracking_mode = 'binocular' + else: + self._tracking_mode = 'monocular' + # Detect the datatypes that are in file. + if 'GAZE' in self._rec_info: + logger.info('Pixel coordinate data detected.') + logger.warning('Pass `scalings=dict(eyegaze=1e3)` when using plot' + ' method to make traces more legible.') + elif 'HREF' in self._rec_info: + logger.info('Head-referenced eye angle data detected.') + elif 'PUPIL' in self._rec_info: + logger.warning('Raw eyegaze coordinates detected. Analyze with' + ' caution.') + if 'AREA' in pupil_info: + logger.info('Pupil-size area reported.') + elif 'DIAMETER' in pupil_info: + logger.info('Pupil-size diameter reported.') + # Check sampling frequency. + if sfreq == 2000 and isinstance(first_samp, int): + raise ValueError(f'The sampling rate is {sfreq}Hz but the' + ' timestamps were not output as float values.' + ' Check the settings in the EDF2ASC application.') + elif sfreq != 2000 and isinstance(first_samp, float): + raise ValueError('For recordings with a sampling rate less than' + ' 2000Hz, timestamps should not be output to the' + ' ASCII file as float values. Check the' + ' settings in the EDF2ASC application. Got a' + f' sampling rate of {sfreq}Hz.') + # If more than 1 recording period, make sure sfreq didn't change. + if n_blocks > 1: + err_msg = 'The sampling frequency changed during the recording.'\ + ' This file cannot be read into MNE.' + for block_info in self._event_lines['SAMPLES'][1:]: + block_sfreq = int(_get_sfreq(block_info)) + if block_sfreq != sfreq: + raise ValueError(err_msg + + f' Got both {sfreq} and {block_sfreq} Hz.' + ) + if self._tracking_mode == 'monocular': + assert self._rec_info[1] in ['LEFT', 'RIGHT'] + eye = self._rec_info[1] + blocks_list = self._event_lines['SAMPLES'] + eye_per_block = [block_info[1] for block_info in blocks_list] + if not all([this_eye == eye for this_eye in eye_per_block]): + logger.warning('The eye being tracked changed during the' + ' recording. The channel names will reflect' + ' the eye that was tracked at the start of' + ' the recording.') + + def _get_recording_datetime(self): + """Create a datetime object from the datetime in ASCII file.""" + # create a timezone object for UTC + tz = timezone(timedelta(hours=0)) + in_header = False + with self.fname.open() as file: + for line in file: + # header lines are at top of file and start with ** + if line.startswith('**'): + in_header = True + if in_header: + if line.startswith('** DATE:'): + dt_str = line.replace('** DATE:', '').strip() + fmt = "%a %b %d %H:%M:%S %Y" + try: + # Eyelink measdate timestamps are timezone naive. + # Force datetime to be in UTC. + # Even though dt is probably in local time zone. + dt_naive = datetime.strptime(dt_str, fmt) + dt_aware = dt_naive.replace(tzinfo=tz) + self._meas_date = dt_aware + except Exception: + msg = ('Extraction of measurement date failed.' + ' Please report this as a github issue.' 
+ ' The date is being set to None') + logger.warning(msg) + break + + def _href_to_radian(self, opposite, f=15_000): + """Convert HREF eyegaze samples to radians. + + Parameters + ---------- + opposite : int + The x or y coordinate in an HREF gaze sample. + f : int (default 15_000) + distance of plane from the eye. + + Returns + ------- + x or y coordinate in radians + + Notes + ----- + See section 4.4.2.2 in the Eyelink 1000 Plus User Manual + (version 1.0.19) for a detailed description of HREF data. + """ + return np.arcsin(opposite / f) + + def _infer_col_names(self): + """Build column and channel names for data from Eyelink ASCII file. + + Returns the expected column names for the sample lines and event + lines, to be passed into pd.DataFrame. Sample and event lines in + eyelink files have a fixed order of columns, but the columns that + are present can vary. The order that col_names is built below should + NOT change. + """ + col_names = {} + # initiate the column names for the sample lines + col_names['sample'] = list(EYELINK_COLS['timestamp']) + + # and for the eye message lines + col_names['blink'] = list(EYELINK_COLS['eye_event']) + col_names['fixation'] = list(EYELINK_COLS['eye_event'] + + EYELINK_COLS['fixation']) + col_names['saccade'] = list(EYELINK_COLS['eye_event'] + + EYELINK_COLS['saccade']) + + # Recording was either binocular or monocular + # If monocular, find out which eye was tracked and append to ch_name + if self._tracking_mode == 'monocular': + assert self._rec_info[1] in ['LEFT', 'RIGHT'] + eye = self._rec_info[1].lower() + ch_names = list(EYELINK_COLS['pos'][eye]) + elif self._tracking_mode == 'binocular': + ch_names = list(EYELINK_COLS['pos']['left'] + + EYELINK_COLS['pos']['right']) + col_names['sample'].extend(ch_names) + + # The order of these if statements should not be changed. + if 'VEL' in self._rec_info: # If velocity data are reported + if self._tracking_mode == 'monocular': + ch_names.extend(EYELINK_COLS['velocity'][eye]) + col_names['sample'].extend(EYELINK_COLS['velocity'][eye]) + elif self._tracking_mode == 'binocular': + ch_names.extend(EYELINK_COLS['velocity']['left'] + + EYELINK_COLS['velocity']['right']) + col_names['sample'].extend(EYELINK_COLS['velocity']['left'] + + EYELINK_COLS['velocity']['right']) + # if resolution data are reported + if 'RES' in self._rec_info: + ch_names.extend(EYELINK_COLS['resolution']) + col_names['sample'].extend(EYELINK_COLS['resolution']) + col_names['fixation'].extend(EYELINK_COLS['resolution']) + col_names['saccade'].extend(EYELINK_COLS['resolution']) + # if digital input port values are reported + if 'INPUT' in self._rec_info: + ch_names.extend(EYELINK_COLS['input']) + col_names['sample'].extend(EYELINK_COLS['input']) + + # add flags column + col_names['sample'].extend(EYELINK_COLS['flags']) + + # if head target info was reported, add its cols after flags col. + if 'HTARGET' in self._rec_info: + ch_names.extend(EYELINK_COLS['remote']) + col_names['sample'].extend(EYELINK_COLS['remote'] + + EYELINK_COLS['remote_flags']) + + # finally add a column for recording block number + # FYI this column does not exist in the asc file.. + # but it is added during _parse_recording_blocks + for col in col_names.values(): + col.extend(EYELINK_COLS['block_num']) + + return col_names, ch_names + + def _create_dataframes(self, col_names, sfreq, find_overlaps=False, + threshold=0.05): + """Create pandas.DataFrame for Eyelink samples and events. 
+ + Creates a pandas DataFrame for self._sample_lines and for each + non-empty key in self._event_lines. + """ + pd = _check_pandas_installed() + + # First sample should be the first line of the first recording block + first_samp = self._event_lines['START'][0][0] + + # dataframe for samples + self.dataframes['samples'] = pd.DataFrame(self._sample_lines, + columns=col_names['sample']) + if 'HREF' in self._rec_info: + pos_names = (EYELINK_COLS['pos']['left'][:-1] + + EYELINK_COLS['pos']['right'][:-1]) + for col in self.dataframes['samples'].columns: + if col not in pos_names: # 'xpos_left' ... 'ypos_right' + continue + series = self._href_to_radian(self.dataframes['samples'][col]) + self.dataframes['samples'][col] = series + + n_block = len(self._event_lines['START']) + if n_block > 1: + logger.info(f'There are {n_block} recording blocks in this' + ' file. Times between blocks will be annotated with' + f' {self._gap_desc}.') + # if there is more than 1 recording block we must account for + # the missing timestamps and samples bt the blocks + self.dataframes['samples'] = _fill_times(self.dataframes + ['samples'], + sfreq=sfreq) + _convert_times(self.dataframes['samples'], first_samp) + + # dataframe for each type of occular event + for event, columns, label in zip(['EFIX', 'ESACC', 'EBLINK'], + [col_names['fixation'], + col_names['saccade'], + col_names['blink']], + ['fixations', + 'saccades', + 'blinks'] + ): + if self._event_lines[event]: # an empty list returns False + self.dataframes[label] = pd.DataFrame(self._event_lines[event], + columns=columns) + _convert_times(self.dataframes[label], first_samp) + + if find_overlaps is True: + if self._tracking_mode == 'monocular': + raise ValueError('find_overlaps is only valid with' + ' binocular recordings, this file is' + f' {self._tracking_mode}') + df = _find_overlaps(self.dataframes[label], + max_time=threshold) + self.dataframes[label] = df + + else: + logger.info(f'No {label} were found in this file. 
' + f'Not returning any info on {label}.') + + # make dataframe for experiment messages + if self._event_lines['MSG']: + msgs = [] + for tokens in self._event_lines['MSG']: + timestamp = tokens[0] + block = tokens[-1] + # if offset token exists, it will be the 1st index + # and is an int or float + if isinstance(tokens[1], (int, float)): + offset = tokens[1] + msg = ' '.join(str(x) for x in tokens[2:-1]) + else: + # there is no offset token + offset = np.nan + msg = ' '.join(str(x) for x in tokens[1:-1]) + msgs.append([timestamp, offset, msg, block]) + + cols = ['time', 'offset', 'event_msg', 'block'] + self.dataframes['messages'] = (pd.DataFrame(msgs, + columns=cols)) + _convert_times(self.dataframes['messages'], first_samp) + + # make dataframe for recording block start, end times + assert (len(self._event_lines['START']) + == len(self._event_lines['END']) + ) + blocks = [[bgn[0], end[0], bgn[-1]] # start, end, block_num + for bgn, end in zip(self._event_lines['START'], + self._event_lines['END']) + ] + cols = ['time', 'end_time', 'block'] + self.dataframes['recording_blocks'] = pd.DataFrame(blocks, + columns=cols) + _convert_times(self.dataframes['recording_blocks'], first_samp) + + # make dataframe for digital input port + if self._event_lines['INPUT']: + cols = ['time', 'DIN', 'block'] + self.dataframes['DINS'] = pd.DataFrame(self._event_lines['INPUT'], + columns=cols) + _convert_times(self.dataframes['DINS'], first_samp) + + # TODO: Make dataframes for other eyelink events (Buttons) + + def _create_info(self, ch_names, sfreq): + """Create info object for RawEyelink.""" + # assign channel type from ch_name + pos_names = (EYELINK_COLS['pos']['left'][:-1] + + EYELINK_COLS['pos']['right'][:-1]) + pupil_names = (EYELINK_COLS['pos']['left'][-1] + + EYELINK_COLS['pos']['right'][-1]) + ch_types = ['eyegaze' if ch in pos_names + else 'pupil' if ch in pupil_names + else 'stim' if ch == 'DIN' + else 'misc' + for ch in ch_names] + info = create_info(ch_names, + sfreq, + ch_types) + # set correct loc for eyepos and pupil channels + for ch_dict in info['chs']: + # loc index 3 can indicate left or right eye + if ch_dict['ch_name'].endswith('left'): # [x,y,pupil]_left + ch_dict['loc'][3] = -1 # left eye + elif ch_dict['ch_name'].endswith('right'): # [x,y,pupil]_right + ch_dict['loc'][3] = 1 # right eye + else: + logger.debug(f"leaving index 3 of loc array as" + f" {ch_dict['loc'][3]} for {ch_dict['ch_name']}") + # loc index 4 can indicate x/y coord + if ch_dict['ch_name'].startswith('x'): + ch_dict['loc'][4] = -1 # x-coord + elif ch_dict['ch_name'].startswith('y'): + ch_dict['loc'][4] = 1 # y-coord + else: + logger.debug(f"leaving index 4 of loc array as" + f" {ch_dict['loc'][4]} for {ch_dict['ch_name']}") + if 'HREF' in self._rec_info: + if ch_dict['ch_name'].startswith(('xpos', 'ypos')): + ch_dict['unit'] = FIFF.FIFF_UNIT_RAD + return info + + def _make_gap_annots(self, key='recording_blocks'): + """Create Annotations for gap periods between recording blocks.""" + df = self.dataframes[key] + gap_desc = self._gap_desc + onsets = df['end_time'].iloc[:-1] + diffs = df['time'].shift(-1) - df['end_time'] + durations = diffs.iloc[:-1] + descriptions = [gap_desc] * len(onsets) + return Annotations(onset=onsets, + duration=durations, + description=descriptions) + + def _make_eyelink_annots(self, df_dict, create_annots, apply_offsets): + """Create Annotations for each df in self.dataframes.""" + valid_descs = ['blinks', 'saccades', 'fixations', 'messages'] + msg = ("create_annotations must be True or a 
list containing one or" + f" more of {valid_descs}.") + wrong_type = (msg + f' Got a {type(create_annots)} instead.') + if create_annots is True: + descs = valid_descs + else: + assert isinstance(create_annots, list), wrong_type + for desc in create_annots: + assert desc in valid_descs, msg + f" Got '{desc}' instead" + descs = create_annots + + annots = None + for key, df in df_dict.items(): + eye_annot_cond = ((key in ['blinks', 'fixations', 'saccades']) + and (key in descs)) + if eye_annot_cond: + onsets = df['time'] + durations = df['duration'] + # Create annotations for both eyes + descriptions = f'{key[:-1]}_' + df['eye'] # i.e "blink_r" + this_annot = Annotations(onset=onsets, + duration=durations, + description=descriptions) + elif (key in ['messages']) and (key in descs): + if apply_offsets: + if df['offset'].isnull().all(): + logger.warning('There are no offsets for the messages' + f' in {self.fname}. Not applying any' + ' offset') + # If df['offset] is all NaNs, time is not changed + onsets = df['time'] + df['offset'].fillna(0) + else: + onsets = df['time'] + durations = [0] * onsets + descriptions = df['event_msg'] + this_annot = Annotations(onset=onsets, + duration=durations, + description=descriptions) + else: + continue # TODO make df and annotations for Buttons + if not annots: + annots = this_annot + elif annots: + annots += this_annot + if not annots: + logger.warning(f'Annotations for {descs} were requested but' + ' none could be made.') + return + return annots diff --git a/mne/io/eyelink/tests/__init__.py b/mne/io/eyelink/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mne/io/eyelink/tests/test_eyelink.py b/mne/io/eyelink/tests/test_eyelink.py new file mode 100644 index 00000000000..0aa5e4d4e0b --- /dev/null +++ b/mne/io/eyelink/tests/test_eyelink.py @@ -0,0 +1,147 @@ +import pytest + +import numpy as np + +from mne.datasets.testing import data_path, requires_testing_data +from mne.io import read_raw_eyelink +from mne.io.constants import FIFF +from mne.io.pick import _DATA_CH_TYPES_SPLIT +from mne.utils import _check_pandas_installed, requires_pandas + +testing_path = data_path(download=False) +fname = testing_path / 'eyetrack' / 'test_eyelink.asc' +fname_href = testing_path / 'eyetrack' / 'test_eyelink_HREF.asc' + + +def test_eyetrack_not_data_ch(): + """Eyetrack channels are not data channels.""" + msg = 'eyetrack channels are not data channels. Refer to MNE definition'\ + ' of data channels in the glossary section of the documentation.' 
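+    # _DATA_CH_TYPES_SPLIT feeds the "data channel" logic in pick_types and
+    # related helpers, so eye-tracking types must never be added to it.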
+    assert 'eyegaze' not in _DATA_CH_TYPES_SPLIT, msg
+    assert 'pupil' not in _DATA_CH_TYPES_SPLIT, msg
+
+
+@requires_testing_data
+@requires_pandas
+@pytest.mark.parametrize('fname, create_annotations, find_overlaps',
+                         [(fname, False, False),
+                          (fname, True, False),
+                          (fname, True, True),
+                          (fname, ['fixations', 'saccades', 'blinks'], True)])
+def test_eyelink(fname, create_annotations, find_overlaps):
+    """Test reading eyelink asc files."""
+    raw = read_raw_eyelink(fname, create_annotations=create_annotations,
+                           find_overlaps=find_overlaps)
+
+    # First, tests that shouldn't change based on function arguments
+    assert raw.info['sfreq'] == 500  # True for this file
+    assert raw.info['meas_date'].month == 3
+    assert raw.info['meas_date'].day == 10
+    assert raw.info['meas_date'].year == 2022
+
+    assert len(raw.info['ch_names']) == 6
+    assert raw.info['chs'][0]['kind'] == FIFF.FIFFV_EYETRACK_CH
+    assert raw.info['chs'][0]['coil_type'] == FIFF.FIFFV_COIL_EYETRACK_POS
+    assert raw.info['chs'][2]['coil_type'] == FIFF.FIFFV_COIL_EYETRACK_PUPIL
+
+    # x_left
+    assert all(raw.info['chs'][0]['loc'][3:5] == [-1, -1])
+    # pupil_left
+    assert raw.info['chs'][2]['loc'][3] == -1
+    assert np.isnan(raw.info['chs'][2]['loc'][4])
+    # y_right
+    assert all(raw.info['chs'][4]['loc'][3:5] == [1, 1])
+    assert 'RawEyelink' in repr(raw)
+
+    # Test some annotation values for accuracy.
+    if create_annotations is True and find_overlaps:
+        orig = raw.info['meas_date']
+        df = raw.annotations.to_data_frame()
+        # Convert annot onset datetimes to seconds, relative to orig_time
+        df['time_in_sec'] = df['onset'].apply(lambda x: x.timestamp()
+                                              - orig.timestamp())
+        # There is a blink in this data at 8.9 seconds
+        cond = (df['time_in_sec'] > 8.899) & (df['time_in_sec'] < 8.95)
+        assert df[cond]['description'].values[0].startswith('blink')
+    if find_overlaps is True:
+        df = raw.annotations.to_data_frame()
+        # these should both be True so long as _find_overlaps is not
+        # majorly refactored.
+        assert 'blink_L' in df['description'].unique()
+        assert 'blink_both' in df['description'].unique()
+    if isinstance(create_annotations, list) and find_overlaps:
+        # the last pytest parametrize condition should hit this
+        df = raw.annotations.to_data_frame()
+        # Rows 0, 1, 2 should be 'fixation_both', 'saccade_both', 'blink_both'
+        for i, label in zip([0, 1, 2], ['fixation', 'saccade', 'blink']):
+            assert df['description'].iloc[i] == f'{label}_both'
+
+
+@requires_testing_data
+@requires_pandas
+@pytest.mark.parametrize('fname_href',
+                         [(fname_href)])
+def test_radian(fname_href):
+    """Test converting HREF position data to radians."""
+    raw = read_raw_eyelink(fname_href, create_annotations=['blinks'])
+    # Test channel types
+    assert raw.get_channel_types() == ['eyegaze', 'eyegaze', 'pupil']
+
+    # Test that eyegaze channels have a radian unit
+    assert raw.info['chs'][0]['unit'] == FIFF.FIFF_UNIT_RAD
+    assert raw.info['chs'][1]['unit'] == FIFF.FIFF_UNIT_RAD
+
+    # Data in radians should range between -1 and 1
+    # Test first channel (xpos_right)
+    assert raw.get_data()[0].min() > -1
+    assert raw.get_data()[0].max() < 1
+
+
+@requires_testing_data
+@requires_pandas
+@pytest.mark.parametrize('fname', [(fname)])
+def test_fill_times(fname):
+    """Test use of pd.merge_asof in _fill_times.
+
+    We are merging on floating point values. pd.merge_asof is used so that any
+    differences in floating point precision between df['samples']['times'] and
+    the times generated with np.arange don't prevent the time columns from
+    merging correctly,
+    i.e. 1560687.0 and 1560687.000001 should merge.
+    """
+    from ..eyelink import _fill_times
+
+    raw = read_raw_eyelink(fname, create_annotations=False)
+    sfreq = raw.info['sfreq']
+    # just take first 1000 points for testing
+    df = raw.dataframes['samples'].iloc[:1000].reset_index(drop=True)
+    # even during blinks, pupil val is 0, so there should be no nans
+    # in this column
+    assert not df['pupil_left'].isna().sum()
+    nan_count = df['pupil_left'].isna().sum()  # i.e. 0
+    df_merged = _fill_times(df, sfreq)
+    # If times don't merge correctly, there will be additional rows
+    # in df_merged with all nan values
+    assert df_merged['pupil_left'].isna().sum() == nan_count  # i.e. 0
+
+
+@requires_pandas
+def test_find_overlaps():
+    """Test finding overlapping ocular events between the left and right eyes.
+
+    In the simulated blink df below, the first two rows
+    will be considered an overlap because the diff() of both the 'time' and
+    'end_time' values is <.05 (50 ms). The 3rd and 4th rows will not be
+    considered an overlap because the diff() of the 'time' values is > .05
+    (4.20 - 4.14 = .06). The 5th and 6th rows will not be considered an
+    overlap because they are both left eye events.
+    """
+    from ..eyelink import _find_overlaps
+    pd = _check_pandas_installed()
+    blink_df = pd.DataFrame({'eye': ['L', 'R', 'L', 'R', 'L', 'L'],
+                             'time': [.01, .04, 4.14, 4.20, 6.50, 6.504],
+                             'end_time': [.05, .08, 4.18, 4.22, 6.60, 6.604]})
+    overlap_df = _find_overlaps(blink_df)
+    assert len(overlap_df['eye'].unique()) == 3  # ['both', 'left', 'right']
+    assert len(overlap_df) == 5  # ['both', 'L', 'R', 'L', 'L']
+    assert overlap_df['eye'].iloc[0] == 'both'
diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py
index b5e8f844c62..f8c9eba13cc 100644
--- a/mne/io/meas_info.py
+++ b/mne/io/meas_info.py
@@ -2431,7 +2431,8 @@ def create_info(ch_names, sfreq, ch_types='misc', verbose=None):
         :term:`data channel `. Currently supported fields are 'ecg',
         'bio', 'stim', 'eog', 'misc', 'seeg', 'dbs', 'ecog', 'mag', 'eeg',
-        'ref_meg', 'grad', 'emg', 'hbr'
-        or 'hbo'. If str, then all channels are assumed to be of the same type.
+        'ref_meg', 'grad', 'emg', 'hbr',
+        'eyetrack' or 'hbo'.
+        If str, then all channels are assumed to be of the same type.
%(verbose)s Returns diff --git a/mne/io/pick.py b/mne/io/pick.py index 87511710143..d71971155d1 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -99,6 +99,10 @@ def get_channel_type_constants(include_defaults=False): unit=FIFF.FIFF_UNIT_CEL), gsr=dict(kind=FIFF.FIFFV_GALVANIC_CH, unit=FIFF.FIFF_UNIT_S), + eyegaze=dict(kind=FIFF.FIFFV_EYETRACK_CH, + coil_type=FIFF.FIFFV_COIL_EYETRACK_POS), + pupil=dict(kind=FIFF.FIFFV_EYETRACK_CH, + coil_type=FIFF.FIFFV_COIL_EYETRACK_PUPIL) ) if include_defaults: coil_none = dict(coil_type=FIFF.FIFFV_COIL_NONE) @@ -115,6 +119,8 @@ def get_channel_type_constants(include_defaults=False): emg=coil_none, bio=coil_none, fnirs_od=unit_none, + pupil=unit_none, + eyegaze=dict(unit=FIFF.FIFF_UNIT_PX), ) for key, value in defaults.items(): base[key].update(value) @@ -153,6 +159,7 @@ def get_channel_type_constants(include_defaults=False): FIFF.FIFFV_FNIRS_CH: 'fnirs', FIFF.FIFFV_TEMPERATURE_CH: 'temperature', FIFF.FIFFV_GALVANIC_CH: 'gsr', + FIFF.FIFFV_EYETRACK_CH: 'eyetrack', } # How to reduce our categories in channel_type (originally) _second_rules = { @@ -172,7 +179,10 @@ def get_channel_type_constants(include_defaults=False): FIFF.FIFFV_COIL_EEG_BIPOLAR: 'eeg', FIFF.FIFFV_COIL_NONE: 'eeg', # MNE-C backward compat FIFF.FIFFV_COIL_EEG_CSD: 'csd', - }) + }), + 'eyetrack': ('coil_type', {FIFF.FIFFV_COIL_EYETRACK_POS: 'eyegaze', + FIFF.FIFFV_COIL_EYETRACK_PUPIL: 'pupil' + }) } @@ -194,7 +204,7 @@ def channel_type(info, idx): {'grad', 'mag', 'eeg', 'csd', 'stim', 'eog', 'emg', 'ecg', 'ref_meg', 'resp', 'exci', 'ias', 'syst', 'misc', 'seeg', 'dbs', 'bio', 'chpi', 'dipole', 'gof', 'ecog', 'hbo', 'hbr', - 'temperature', 'gsr'} + 'temperature', 'gsr', 'eyetrack'} """ # This is faster than the original _channel_type_old now in test_pick.py # because it uses (at most!) two dict lookups plus one conditional @@ -350,6 +360,21 @@ def _triage_fnirs_pick(ch, fnirs, warned): return False +def _triage_eyetrack_pick(ch, eyetrack): + """Triage an eyetrack pick type.""" + if eyetrack is False: + return False + elif eyetrack is True: + return True + elif ch['coil_type'] == FIFF.FIFFV_COIL_EYETRACK_PUPIL and \ + 'pupil' in eyetrack: + return True + elif ch['coil_type'] == FIFF.FIFFV_COIL_EYETRACK_POS and \ + 'eyegaze' in eyetrack: + return True + return False + + def _check_meg_type(meg, allow_auto=False): """Ensure a valid meg type.""" if isinstance(meg, str): @@ -380,7 +405,7 @@ def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, chpi=False, exci=False, ias=False, syst=False, seeg=False, dipole=False, gof=False, bio=False, ecog=False, fnirs=False, csd=False, dbs=False, temperature=False, gsr=False, - include=(), exclude='bads', selection=None): + eyetrack=False, include=(), exclude='bads', selection=None): """Pick channels by type and names. 
Parameters @@ -412,14 +437,16 @@ def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, temperature, gsr): if not isinstance(param, bool): w = ('Parameters for all channel types (with the exception of ' - '"meg", "ref_meg" and "fnirs") must be of type bool, not {}.') + '"meg", "ref_meg", "fnirs", and "eyetrack") must be of type ' + 'bool, not {}.') raise ValueError(w.format(type(param))) param_dict = dict(eeg=eeg, stim=stim, eog=eog, ecg=ecg, emg=emg, misc=misc, resp=resp, chpi=chpi, exci=exci, ias=ias, syst=syst, seeg=seeg, dbs=dbs, dipole=dipole, gof=gof, bio=bio, ecog=ecog, csd=csd, - temperature=temperature, gsr=gsr) + temperature=temperature, gsr=gsr, eyetrack=eyetrack) + # avoid triage if possible if isinstance(meg, bool): for key in ('grad', 'mag'): @@ -433,12 +460,14 @@ def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, try: pick[k] = param_dict[ch_type] except KeyError: # not so simple - assert ch_type in ( - 'grad', 'mag', 'ref_meg') + _FNIRS_CH_TYPES_SPLIT + assert ch_type in ('grad', 'mag', 'ref_meg') + \ + _FNIRS_CH_TYPES_SPLIT + _EYETRACK_CH_TYPES_SPLIT if ch_type in ('grad', 'mag'): pick[k] = _triage_meg_pick(info['chs'][k], meg) elif ch_type == 'ref_meg': pick[k] = _triage_meg_pick(info['chs'][k], ref_meg) + elif ch_type in ('eyegaze', 'pupil'): + pick[k] = _triage_eyetrack_pick(info['chs'][k], eyetrack) else: # ch_type in ('hbo', 'hbr') pick[k] = _triage_fnirs_pick(info['chs'][k], fnirs, warned) @@ -730,10 +759,11 @@ def channel_indices_by_type(info, picks=None): channel indices. """ idx_by_type = {key: list() for key in _PICK_TYPES_KEYS if - key not in ('meg', 'fnirs')} + key not in ('meg', 'fnirs', 'eyetrack')} idx_by_type.update(mag=list(), grad=list(), hbo=list(), hbr=list(), fnirs_cw_amplitude=list(), fnirs_fd_ac_amplitude=list(), - fnirs_fd_phase=list(), fnirs_od=list()) + fnirs_fd_phase=list(), fnirs_od=list(), + eyegaze=list(), pupil=list()) picks = _picks_to_idx(info, picks, none='all', exclude=(), allow_empty=True) for k in picks: @@ -823,8 +853,10 @@ def _contains_ch_type(info, ch_type): meg_extras = list(_MEG_CH_TYPES_SPLIT) fnirs_extras = list(_FNIRS_CH_TYPES_SPLIT) + et_extras = list(_EYETRACK_CH_TYPES_SPLIT) valid_channel_types = sorted([key for key in _PICK_TYPES_KEYS - if key != 'meg'] + meg_extras + fnirs_extras) + if key != 'meg'] + + meg_extras + fnirs_extras + et_extras) _check_option('ch_type', ch_type, valid_channel_types) if info is None: raise ValueError('Cannot check for channels of type "%s" because info ' @@ -925,22 +957,26 @@ def _check_excludes_includes(chs, info=None, allow_bads=False): meg=True, eeg=True, csd=True, stim=False, eog=False, ecg=False, emg=False, misc=False, resp=False, chpi=False, exci=False, ias=False, syst=False, seeg=True, dipole=False, gof=False, bio=False, ecog=True, fnirs=True, - dbs=True, temperature=False, gsr=False) + dbs=True, temperature=False, gsr=False, eyetrack=True) _PICK_TYPES_KEYS = tuple(list(_PICK_TYPES_DATA_DICT) + ['ref_meg']) _MEG_CH_TYPES_SPLIT = ('mag', 'grad', 'planar1', 'planar2') _FNIRS_CH_TYPES_SPLIT = ('hbo', 'hbr', 'fnirs_cw_amplitude', 'fnirs_fd_ac_amplitude', 'fnirs_fd_phase', 'fnirs_od') +_EYETRACK_CH_TYPES_SPLIT = ('eyegaze', 'pupil') _DATA_CH_TYPES_ORDER_DEFAULT = ( 'mag', 'grad', 'eeg', 'csd', 'eog', 'ecg', 'resp', 'emg', 'ref_meg', 'misc', 'stim', 'chpi', 'exci', 'ias', 'syst', 'seeg', 'bio', 'ecog', 'dbs', 'temperature', 'gsr', 'gof', 'dipole', -) + _FNIRS_CH_TYPES_SPLIT + ('whitened',) +) + _FNIRS_CH_TYPES_SPLIT + _EYETRACK_CH_TYPES_SPLIT + 
('whitened',) + # Valid data types, ordered for consistency, used in viz/evoked. _VALID_CHANNEL_TYPES = ( 'eeg', 'grad', 'mag', 'seeg', 'eog', 'ecg', 'resp', 'emg', 'dipole', 'gof', - 'bio', 'ecog', 'dbs') + _FNIRS_CH_TYPES_SPLIT + ('misc', 'csd') + 'bio', 'ecog', 'dbs' +) + _FNIRS_CH_TYPES_SPLIT + _EYETRACK_CH_TYPES_SPLIT + ('misc', 'csd') _DATA_CH_TYPES_SPLIT = ( - 'mag', 'grad', 'eeg', 'csd', 'seeg', 'ecog', 'dbs') + _FNIRS_CH_TYPES_SPLIT + 'mag', 'grad', 'eeg', 'csd', 'seeg', 'ecog', 'dbs' +) + _FNIRS_CH_TYPES_SPLIT # Electrode types (e.g., can be average-referenced together or separately) _ELECTRODE_CH_TYPES = ('eeg', 'ecog', 'seeg', 'dbs') diff --git a/mne/io/tests/test_constants.py b/mne/io/tests/test_constants.py index 1f0cd473992..2f05d73b19a 100644 --- a/mne/io/tests/test_constants.py +++ b/mne/io/tests/test_constants.py @@ -21,7 +21,7 @@ # https://github.com/mne-tools/fiff-constants/commits/master REPO = 'mne-tools' -COMMIT = '6d9ca9ce7fb44c63d429c2986a953500743dfb22' +COMMIT = 'e27f68cbf74dbfc5193ad429cc77900a59475181' # These are oddities that we won't address: iod_dups = (355, 359) # these are in both MEGIN and MNE files @@ -55,6 +55,10 @@ 303, # fNIRS optical density 304, # fNIRS frequency domain AC amplitude 305, # fNIRS frequency domain phase + 306, # fNIRS time domain gated amplitude + 307, # fNIRS time domain moments amplitude + 400, # Eye-tracking gaze position + 401, # Eye-tracking pupil size 1000, # For testing the MCG software 2001, # Generic axial gradiometer 3011, # VV prototype wirewound planar sensor diff --git a/mne/preprocessing/__init__.py b/mne/preprocessing/__init__.py index 6bb8111efaa..b9c308dddaa 100644 --- a/mne/preprocessing/__init__.py +++ b/mne/preprocessing/__init__.py @@ -33,3 +33,4 @@ from .interpolate import equalize_bads, interpolate_bridged_electrodes from . import ieeg from ._css import cortical_signal_suppression +from . import eyetracking diff --git a/mne/preprocessing/eyetracking/__init__.py b/mne/preprocessing/eyetracking/__init__.py new file mode 100644 index 00000000000..7c7f5f42765 --- /dev/null +++ b/mne/preprocessing/eyetracking/__init__.py @@ -0,0 +1,7 @@ +"""Eye tracking specific preprocessing functions.""" + +# Authors: Dominik Welke +# +# License: BSD-3-Clause + +from .eyetracking import set_channel_types_eyetrack diff --git a/mne/preprocessing/eyetracking/eyetracking.py b/mne/preprocessing/eyetracking/eyetracking.py new file mode 100644 index 00000000000..346a130564f --- /dev/null +++ b/mne/preprocessing/eyetracking/eyetracking.py @@ -0,0 +1,146 @@ +# Authors: Dominik Welke +# +# License: BSD-3-Clause + + +import numpy as np + +from ...io.constants import FIFF + + +# specific function to set eyetrack channels +def set_channel_types_eyetrack(inst, mapping): + """Define sensor type for eyetrack channels. + + This function can set all eye tracking specific information: + channel type, unit, eye (and x/y component; only for gaze channels) + + Supported channel types: + ``'eyegaze'`` and ``'pupil'`` + + Supported units: + ``'au'``, ``'px'``, ``'deg'``, ``'rad'`` (for eyegaze) + ``'au'``, ``'mm'``, ``'m'`` (for pupil) + + Parameters + ---------- + inst : instance of Raw, Epochs, or Evoked + The data instance. + mapping : dict + A dictionary mapping a channel to a list/tuple including + channel type, unit, eye, [and x/y component] (all as str), e.g., + ``{'l_x': ('eyegaze', 'deg', 'left', 'x')}`` or + ``{'r_pupil': ('pupil', 'au', 'right')}``. 
+ + Returns + ------- + inst : instance of Raw | Epochs | Evoked + The instance, modified in place. + + Notes + ----- + ``inst.set_channel_types()`` to ``'eyegaze'`` or ``'pupil'`` + works as well, but cannot correctly set unit, eye and x/y component. + + Data will be stored in SI units: + if your data comes in ``deg`` (visual angle) it will be converted to + ``rad``, if it is in ``mm`` it will be converted to ``m``. + """ + ch_names = inst.info['ch_names'] + + # allowed + valid_types = ['eyegaze', 'pupil'] # ch_type + valid_units = {'px': ['px', 'pixel'], + 'rad': ['rad', 'radian', 'radians'], + 'deg': ['deg', 'degree', 'degrees'], + 'm': ['m', 'meter', 'meters'], + 'mm': ['mm', 'millimeter', 'millimeters'], + 'au': [None, 'none', 'au', 'arbitrary']} + valid_units['all'] = [item for sublist in valid_units.values() + for item in sublist] + valid_eye = {'l': ['left', 'l'], + 'r': ['right', 'r']} + valid_eye['all'] = [item for sublist in valid_eye.values() + for item in sublist] + valid_xy = {'x': ['x', 'h', 'horizontal'], + 'y': ['y', 'v', 'vertical']} + valid_xy['all'] = [item for sublist in valid_xy.values() + for item in sublist] + + # loop over channels + for ch_name, ch_desc in mapping.items(): + if ch_name not in ch_names: + raise ValueError("This channel name (%s) doesn't exist in " + "info." % ch_name) + c_ind = ch_names.index(ch_name) + + # set ch_type and unit + ch_type = ch_desc[0].lower() + if ch_type not in valid_types: + raise ValueError( + "ch_type must be one of {}. " + "Got '{}' instead.".format(valid_types, ch_type)) + if ch_type == 'eyegaze': + coil_type = FIFF.FIFFV_COIL_EYETRACK_POS + elif ch_type == 'pupil': + coil_type = FIFF.FIFFV_COIL_EYETRACK_PUPIL + inst.info['chs'][c_ind]['coil_type'] = coil_type + inst.info['chs'][c_ind]['kind'] = FIFF.FIFFV_EYETRACK_CH + + ch_unit = None if (ch_desc[1] is None) else ch_desc[1].lower() + if ch_unit not in valid_units['all']: + raise ValueError( + "unit must be one of {}. Got '{}' instead.".format( + valid_units['all'], ch_unit)) + if ch_unit in valid_units['px']: + unit_new = FIFF.FIFF_UNIT_PX + elif ch_unit in valid_units['rad']: + unit_new = FIFF.FIFF_UNIT_RAD + elif ch_unit in valid_units['deg']: # convert deg to rad (SI) + inst = inst.apply_function(_convert_deg_to_rad, picks=ch_name) + unit_new = FIFF.FIFF_UNIT_RAD + elif ch_unit in valid_units['m']: + unit_new = FIFF.FIFF_UNIT_M + elif ch_unit in valid_units['mm']: # convert mm to m (SI) + inst = inst.apply_function(_convert_mm_to_m, picks=ch_name) + unit_new = FIFF.FIFF_UNIT_M + elif ch_unit in valid_units['au']: + unit_new = FIFF.FIFF_UNIT_NONE + inst.info['chs'][c_ind]['unit'] = unit_new + + # set eye (and x/y-component) + loc = np.array([np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, + np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]) + + ch_eye = ch_desc[2].lower() + if ch_eye not in valid_eye['all']: + raise ValueError( + "eye must be one of {}. Got '{}' instead.".format( + valid_eye['all'], ch_eye)) + if ch_eye in valid_eye['l']: + loc[3] = -1 + elif ch_eye in valid_eye['r']: + loc[3] = 1 + + if ch_type == 'eyegaze': + ch_xy = ch_desc[3].lower() + if ch_xy not in valid_xy['all']: + raise ValueError( + "x/y must be one of {}. Got '{}' instead.".format( + valid_xy['all'], ch_xy)) + if ch_xy in valid_xy['x']: + loc[4] = -1 + elif ch_xy in valid_xy['y']: + loc[4] = 1 + + inst.info['chs'][c_ind]['loc'] = loc + + return inst + + +def _convert_mm_to_m(array): + return array * .001 + + +def _convert_deg_to_rad(array): + return array * np.pi / 180. 
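A minimal usage sketch for the function above (illustrative only: the channel
names and the zero-valued data are hypothetical, not part of this patch):

    import numpy as np
    import mne
    from mne.preprocessing.eyetracking import set_channel_types_eyetrack

    # Fake 1-second monocular recording: x/y gaze in degrees, pupil in mm.
    info = mne.create_info(['l_x', 'l_y', 'l_pupil'], 500., 'misc')
    raw = mne.io.RawArray(np.zeros((3, 500)), info)
    set_channel_types_eyetrack(raw, {
        'l_x': ('eyegaze', 'deg', 'left', 'x'),  # deg data stored as rad (SI)
        'l_y': ('eyegaze', 'deg', 'left', 'y'),
        'l_pupil': ('pupil', 'mm', 'left'),      # mm data stored as m (SI)
    })
    print(raw.get_channel_types())  # ['eyegaze', 'eyegaze', 'pupil']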
diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index e144eac0566..63b657f2132 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -1227,22 +1227,24 @@ def test_bad_channels(method, allow_ref_meg): allow_ref_meg=allow_ref_meg) for inst in [raw, epochs]: for ch in chs_bad: + picks_dict = {('eyetrack' if ch in ('eyegaze', 'pupil') + else str(ch)): True} if allow_ref_meg: # Test case for only bad channels picks_bad1 = pick_types(inst.info, meg=False, ref_meg=False, - **{str(ch): True}) + **picks_dict) # Test case for good and bad channels picks_bad2 = pick_types(inst.info, meg=True, ref_meg=True, - **{str(ch): True}) + **picks_dict) else: # Test case for only bad channels picks_bad1 = pick_types(inst.info, meg=False, - **{str(ch): True}) + **picks_dict) # Test case for good and bad channels picks_bad2 = pick_types(inst.info, meg=True, - **{str(ch): True}) + **picks_dict) with pytest.raises(ValueError, match='Invalid channel type'): ica.fit(inst, picks=picks_bad1) diff --git a/mne/simulation/tests/test_raw.py b/mne/simulation/tests/test_raw.py index 31169a403bf..49fdc82f881 100644 --- a/mne/simulation/tests/test_raw.py +++ b/mne/simulation/tests/test_raw.py @@ -325,6 +325,7 @@ def test_degenerate(raw_data): @pytest.mark.slowtest def test_simulate_raw_bem(raw_data): """Test simulation of raw data with BEM.""" + pytest.importorskip('nibabel') raw, src_ss, stc, trans, sphere = raw_data src = setup_source_space('sample', 'oct1', subjects_dir=subjects_dir) for s in src: diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 226afda56cc..96398425ad4 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2661,6 +2661,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Temperature channels. gsr : bool Galvanic skin response channels. +eyetrack : bool | str + Eyetracking channels. If True include all eyetracking channels. If False + (default) include none. If string it can be 'eyegaze' (to include + eye position channels) or 'pupil' (to include pupil-size + channels). include : list of str List of additional channels to include. If empty do not include any. diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 7238d4490a6..ab5f4e76c67 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -52,7 +52,8 @@ from ..fixes import _close_event from ..annotations import _sync_onset from ..io.pick import (_DATA_CH_TYPES_ORDER_DEFAULT, _DATA_CH_TYPES_SPLIT, - _FNIRS_CH_TYPES_SPLIT, _VALID_CHANNEL_TYPES) + _FNIRS_CH_TYPES_SPLIT, _EYETRACK_CH_TYPES_SPLIT, + _VALID_CHANNEL_TYPES) from ..utils import Bunch, _click_ch_name, logger from . 
import plot_sensors
 from ._figure import BrowserBase
 
@@ -2191,6 +2192,8 @@ def _split_picks_by_type(inst, picks, units, scalings, titles):
                 pick_kwargs['meg'] = ch_type
             elif ch_type in _FNIRS_CH_TYPES_SPLIT:
                 pick_kwargs['fnirs'] = ch_type
+            elif ch_type in _EYETRACK_CH_TYPES_SPLIT:
+                pick_kwargs['eyetrack'] = ch_type
             else:
                 pick_kwargs[ch_type] = True
             these_picks = pick_types(inst.info, **pick_kwargs)
diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py
index 21ed14dd7fa..859db3ce646 100644
--- a/mne/viz/tests/test_raw.py
+++ b/mne/viz/tests/test_raw.py
@@ -961,6 +961,7 @@ def test_plotting_order_consistency():
     pick_data_set = set(_PICK_TYPES_DATA_DICT)
     pick_data_set.remove('meg')
     pick_data_set.remove('fnirs')
+    pick_data_set.remove('eyetrack')
     missing = pick_data_set.difference(set(_DATA_CH_TYPES_ORDER_DEFAULT))
     assert missing == set()
 
diff --git a/tools/circleci_download.sh b/tools/circleci_download.sh
index a8eef47fd42..421f6f63ec1 100755
--- a/tools/circleci_download.sh
+++ b/tools/circleci_download.sh
@@ -107,6 +107,9 @@ else
       if [[ $(cat $FNAME | grep -x ".*datasets.*erp_core.*" | wc -l) -gt 0 ]]; then
         python -c "import mne; print(mne.datasets.erp_core.data_path(update_path=True))";
       fi;
+      if [[ $(cat $FNAME | grep -x ".*datasets.*eyelink.*" | wc -l) -gt 0 ]]; then
+        python -c "import mne; print(mne.datasets.eyelink.data_path(update_path=True))";
+      fi;
       if [[ $(cat $FNAME | grep -x ".*datasets.*ucl_opm_auditory.*" | wc -l) -gt 0 ]]; then
         python -c "import mne; print(mne.datasets.ucl_opm_auditory.data_path(update_path=True))";
       fi;
diff --git a/tutorials/io/70_reading_eyetracking_data.py b/tutorials/io/70_reading_eyetracking_data.py
new file mode 100644
index 00000000000..f84477255b6
--- /dev/null
+++ b/tutorials/io/70_reading_eyetracking_data.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+r"""
+.. _tut-importing-eyetracking-data:
+
+=======================================
+Importing Data from Eyetracking devices
+=======================================
+
+Eyetracking devices record a person's point of gaze, usually in relation to a
+screen. Typically, gaze position (also referred to as eye or pupil position)
+and pupil size are recorded as separate channels. This section describes how to
+read data from supported eyetracking manufacturers.
+
+MNE-Python provides functions for reading eyetracking data. When possible,
+MNE-Python will internally convert and store eyetracking data in SI units
+(for example, radians for position data and meters for pupil size).
+
+.. note:: If you have eye tracking data in a format that MNE does not support
+          yet, you can try reading it using other tools and create an MNE
+          object from a numpy array. Then you can use
+          :func:`mne.preprocessing.eyetracking.set_channel_types_eyetrack`
+          to assign the correct eyetrack channel types.
+
+.. seealso:: Some MNE functions may not be available to eyetracking and other
+             physiological data, because MNE does not consider them to be data
+             channels. See the :doc:`glossary ` for more
+             information.
+
+.. _import-eyelink_asc:
+
+SR Research (Eyelink) (.asc)
+============================
+
+.. note:: MNE-Python currently only supports reading Eyelink eyetracking data
+          stored in the ASCII (.asc) format.
+
+Eyelink recordings are stored in the Eyelink Data Format (EDF; .edf), a binary
+format that is relatively complex to support. To make the data in EDF files
+accessible, Eyelink provides the application EDF2ASC, which converts EDF
+files to a plain text ASCII format (.asc).
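+For example, from a shell the conversion is typically a single call (the exact
+options depend on your EDF2ASC version; ``myrecording.edf`` is a placeholder
+file name)::
+
+    edf2asc myrecording.edf
+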
+These files can be imported into MNE using :func:`mne.io.read_raw_eyelink`.
+
+.. note:: The Eyelink Data Format (EDF) should not be confused
+          with the European Data Format, the common EEG data format that also
+          uses the .edf extension.
+
+Supported measurement types from Eyelink files include eye position, pupil
+size, saccadic velocity, resolution, and head position (for recordings
+collected in remote mode). Eyelink files often report ocular events (blinks,
+saccades, and fixations); MNE will store these events as `mne.Annotations`.
+For more information on the various measurement types that can be present in
+Eyelink files, read below.
+
+Eye Position Data
+-----------------
+
+Eyelink samples can report eye position data in pixels, units of visual
+degrees, or as raw pupil coordinates. Samples are written as (x, y) coordinate
+pairs (or two pairs for binocular data). The type of position data present in
+an ASCII file will be detected automatically by MNE. The three types of
+position data are explained below.
+
+Gaze
+^^^^
+Gaze position data report the estimated (x, y) pixel coordinates of the
+participant's gaze on the stimulus screen, compensating for head position
+changes and distance from the screen. This datatype may be preferable if you
+are interested in knowing where on the stimulus screen the participant was
+looking. The default (0, 0) location for Eyelink systems is at the top left of
+the screen.
+
+This may be best demonstrated with an example. In the file plotted below,
+eyetracking data were recorded while the participant read text on a display.
+In this file, as the participant read each line from left to right, the
+x-coordinate increased. When the participant moved their gaze down to read a
+new line, the y-coordinate *increased*, which is why the ``ypos_right`` channel
+in the plot below increases over time (for example, at about 4 seconds and
+at about 8 seconds).
+"""
+
+# %%
+from mne.io import read_raw_eyelink
+from mne.datasets import misc
+
+# %%
+fpath = misc.data_path() / 'eyetracking' / 'eyelink'
+raw = read_raw_eyelink(fpath / 'px_textpage_ws.asc',
+                       create_annotations=['blinks'])
+custom_scalings = dict(eyegaze=1e3)
+raw.pick_types(eyetrack=True).plot(scalings=custom_scalings)
+
+
+# %%
+# .. important:: The (0, 0) pixel coordinates are at the top-left of the
+#                trackable area of the screen. Gaze towards lower areas of
+#                the screen will yield a relatively higher y-coordinate.
+#
+# Note that we passed a custom `dict` to the ``'scalings'`` argument of
+# `mne.io.Raw.plot`. This is because MNE's default plot scalings for eye
+# position data are calibrated for HREF data, which are stored in radians
+# (read below).
+
+
+# %%
+# Head-Referenced Eye Angle (HREF)
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# HREF position data measure eye rotation angles relative to the head. They do
+# not take into account changes in subject head position and angle, or distance
+# from the stimulus screen. This datatype might be preferable for analyses
+# that are interested in eye movement velocities and amplitudes, or for
+# simultaneous EEG/MEG and eyetracking recordings where eye position data are
+# used to identify EOG artifacts.
+#
+# HREF coordinates are stored in the ASCII file as integer values, with 260 or
+# more units per visual degree; however, MNE will convert and store these
+# coordinates in radians.
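+# As a rough worked example: the reader converts HREF samples via
+# ``rad = arcsin(href / 15000)``, so an (illustrative) HREF value of 3900
+# units corresponds to ``arcsin(3900 / 15000)``, about 0.263 rad, or roughly
+# 15 degrees of visual angle.
+#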
# The (0, 0) point of HREF data is arbitrary, as the
# relationship between the screen position and the coordinates changes as the
# subject's head moves.
#
# Below is the same text-reading recording that we plotted above, except that
# a new ASCII file was generated, this time using HREF eye position data.


# %%
fpath = misc.data_path() / 'eyetracking' / 'eyelink'
raw = read_raw_eyelink(fpath / 'HREF_textpage_ws.asc',
                       create_annotations=['blinks'])
raw.pick_types(eyetrack=True).plot()

# %%
# Pupil Position
# ^^^^^^^^^^^^^^
#
# Pupil position data contain (x, y) coordinate pairs from the eye camera.
# They have not been converted to pixels (gaze) or eye angles (HREF). Most use
# cases do not require this data type, and caution should be taken when
# analyzing raw pupil positions. Note that when plotting data from a
# ``Raw`` object containing raw pupil position data, the plot scalings
# will likely be incorrect. You can pass custom scalings into the ``scalings``
# parameter of `mne.io.Raw.plot` so that the signals are legible when plotting.

# %%
# .. warning:: If a calibration was not performed prior to data collection, the
#              EyeLink system cannot convert raw pupil position data to pixels
#              (gaze) or eye angle (HREF).

# %%
# Pupil Size Data
# ---------------
# Pupil size is measured by the EyeLink system at up to 500 samples per second.
# It may be reported as pupil *area* or pupil *diameter* (i.e. the diameter
# of a circle/ellipse model fit to the pupil area).
# Which of these datatypes you get is determined by your recording settings
# and/or your EDF2ASC settings. Pupil size data are not calibrated and are
# reported in arbitrary units. Typical pupil *area* data range between 800 and
# 2000 units, with a precision of 1 unit, while pupil *diameter* data range
# between 1800 and 3000 units.
#
# Velocity, resolution, and head position data
# --------------------------------------------
# Eyelink files can include data on saccadic velocity, resolution, and head
# position for each sample in the file. MNE will read in these data if they are
# present in the file, but will label their channel types as ``'misc'``.
#
# .. warning:: Eyelink's EDF2ASC API allows for modification of the data
#              and format that is converted to ASCII. However, MNE-Python
#              assumes a specific structure, which the default parameters of
#              EDF2ASC follow. ASCII files should be tab-delimited, and both
#              Samples and Events should be output. If the data were recorded
#              at 2000 Hz, timestamps should be floating point numbers. Manual
#              modification of ASCII conversion via EDF2ASC is not recommended.
diff --git a/tutorials/preprocessing/90_eyetracking_data.py b/tutorials/preprocessing/90_eyetracking_data.py
new file mode 100644
index 00000000000..8c8a2e3b755
--- /dev/null
+++ b/tutorials/preprocessing/90_eyetracking_data.py
@@ -0,0 +1,107 @@
+# -*- coding: utf-8 -*-
+"""
.. _tut-eyetrack:

===========================================
Working with eye tracker data in MNE-Python
===========================================

In this tutorial we will load some eye tracker data and plot the average
pupil response to light flashes (i.e. the pupillary light reflex).

""" # noqa: E501
# Authors: Dominik Welke
#          Scott Huberty
#
# License: BSD-3-Clause

# %%
# Data loading
# ------------
#
# First we will load an eye tracker recording from SR Research's proprietary
# ``'.asc'`` file format.
#
# By default, Eyelink files will output ocular events (blinks,
# saccades, fixations) and experiment messages. MNE will store these events
# as `mne.Annotations`. If we are only interested in certain event types from
# the Eyelink file, we can select them using the ``'create_annotations'``
# argument of `mne.io.read_raw_eyelink`. Here, we will only create annotations
# for blinks and experiment messages.
#
# The info structure tells us we loaded a monocular recording with 2
# ``'eyegaze'`` channels (X/Y), 1 ``'pupil'`` channel, and 1 ``'stim'``
# channel.

from mne import Epochs, find_events
from mne.io import read_raw_eyelink
from mne.datasets.eyelink import data_path

eyelink_fname = data_path() / 'mono_multi-block_multi-DINS.asc'

raw = read_raw_eyelink(eyelink_fname,
                       create_annotations=['blinks', 'messages'])
raw.crop(tmin=0, tmax=146)

# %%
# Get stimulus events from DIN channel
# ------------------------------------
#
# Eyelink eye trackers have a DIN port that can be used to feed in stimulus
# or response timings. :func:`mne.io.read_raw_eyelink` loads this data as a
# ``'stim'`` channel. Alternatively, the onset of stimulus events could be sent
# to the eyetracker as ``messages`` - these can be read in as
# `mne.Annotations`.
#
# In the example data, the DIN channel contains the onset of light flashes on
# the screen. We now extract these events to visualize the pupil response.

events = find_events(raw, 'DIN',
                     shortest_event=1,
                     min_duration=.02,
                     uint_cast=True)
event_dict = {'flash': 3}
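
# %%
# .. note:: Had the stimulus onsets been sent to the tracker as messages
#           instead of DIN pulses, we could have built the events from the
#           resulting annotations. A minimal sketch (``'flash'`` is a
#           hypothetical message description here, not one present in this
#           file):
#
#           .. code-block:: python
#
#               from mne import events_from_annotations
#
#               events, event_dict = events_from_annotations(
#                   raw, event_id={'flash': 3})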

# %%
# Plot raw data
# -------------
#
# As the following plot shows, we now have a raw object with the eye tracker
# data, eyeblink annotations and stimulus events (from the DIN channel).
#
# The plot also shows us that there is some noise in the data (not all of it
# categorized as blinks). Also, notice that we have passed a custom `dict` into
# the scalings argument of ``raw.plot``. This is necessary to make the eyegaze
# channel traces legible when plotting, since the file contains pixel position
# data (as opposed to eye angles, which are reported in radians).

raw.plot(events=events, event_id={'Flash': 3}, event_color='g',
         start=25, duration=45, scalings=dict(eyegaze=1e3))


# %%
# Plot average pupil response
# ---------------------------
#
# We now visualize the pupillary light reflex. To do so, we select only the
# pupil channel and plot the evoked response to the light flashes.
#
# As we see, there is a prominent decrease in pupil size following the
# stimulation. The noise starting about 2.5 s after stimulus onset stems from
# eyeblinks and artifacts in some of the 16 trials.

epochs = Epochs(raw, events, tmin=-0.3, tmax=5,
                event_id=event_dict, preload=True)
epochs.pick_types(eyetrack='pupil')
epochs.average().plot()

# %%
# It is important to note that pupil size data are reported by Eyelink (and
# stored internally by MNE) in arbitrary units (AU). While it can often be
# preferable to convert pupil size data to millimeters, this requires
# information that is not always present in the file. MNE does not currently
# provide methods to convert pupil size data.
# See :ref:`tut-importing-eyetracking-data` for more information on pupil size
# data.
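
# %%
# Because the units are arbitrary, pupil size is often analyzed as relative
# change from a pre-stimulus baseline rather than in raw units. Below is a
# minimal sketch of that approach using plain NumPy (an illustration of the
# idea, not an MNE-provided method):

import numpy as np

# pupil size per epoch and time point, shape (n_epochs, n_times), in AU
pupil = epochs.get_data()[:, 0, :]
# mean pupil size in the pre-stimulus period of each epoch
baseline = pupil[:, epochs.times < 0].mean(axis=1, keepdims=True)
# express each trial as percent change relative to its own baseline
pct_change = 100 * (pupil - baseline) / baseline
print(f'Peak constriction: {pct_change.mean(axis=0).min():.1f}%')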
From 623895d0fd8894196ca638ad10ddccc98d488b4c Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Mon, 27 Mar 2023 20:50:00 +0200 Subject: [PATCH 0008/1125] MAINT: `coding: utf-8` is implicit in Python 3 (#11599) --- doc/conf.py | 2 -- doc/sphinxext/flow_diagram.py | 2 -- doc/sphinxext/gen_commands.py | 2 -- doc/sphinxext/gen_names.py | 2 -- doc/sphinxext/gh_substitutions.py | 2 -- doc/sphinxext/newcontrib_substitutions.py | 2 -- examples/datasets/brainstorm_data.py | 1 - examples/datasets/hf_sef_data.py | 1 - examples/datasets/limo_data.py | 1 - examples/datasets/opm_data.py | 1 - examples/datasets/spm_faces_dataset_sgskip.py | 1 - examples/decoding/decoding_csp_eeg.py | 1 - examples/decoding/decoding_csp_timefreq.py | 1 - examples/decoding/decoding_rsa_sgskip.py | 1 - examples/decoding/decoding_spatio_temporal_source.py | 1 - examples/decoding/decoding_spoc_CMC.py | 1 - examples/decoding/decoding_time_generalization_conditions.py | 1 - examples/decoding/decoding_unsupervised_spatial_filter.py | 1 - examples/decoding/decoding_xdawn_eeg.py | 1 - examples/decoding/ems_filtering.py | 1 - examples/decoding/linear_model_patterns.py | 1 - examples/decoding/receptive_field_mtrf.py | 1 - examples/decoding/ssd_spatial_filters.py | 1 - examples/forward/forward_sensitivity_maps.py | 1 - examples/forward/left_cerebellum_volume_source.py | 1 - examples/forward/source_space_morphing.py | 1 - examples/inverse/compute_mne_inverse_epochs_in_label.py | 1 - examples/inverse/compute_mne_inverse_raw_in_label.py | 1 - examples/inverse/compute_mne_inverse_volume.py | 1 - examples/inverse/custom_inverse_solver.py | 1 - examples/inverse/dics_epochs.py | 1 - examples/inverse/dics_source_power.py | 1 - examples/inverse/evoked_ers_source_power.py | 1 - examples/inverse/gamma_map_inverse.py | 1 - examples/inverse/label_activation_from_stc.py | 1 - examples/inverse/label_from_stc.py | 1 - examples/inverse/label_source_activations.py | 1 - examples/inverse/mixed_norm_inverse.py | 1 - examples/inverse/mixed_source_space_inverse.py | 1 - examples/inverse/mne_cov_power.py | 1 - examples/inverse/morph_surface_stc.py | 1 - examples/inverse/morph_volume_stc.py | 1 - examples/inverse/multi_dipole_model.py | 1 - examples/inverse/multidict_reweighted_tfmxne.py | 1 - examples/inverse/psf_ctf_label_leakage.py | 1 - examples/inverse/psf_ctf_vertices.py | 1 - examples/inverse/psf_ctf_vertices_lcmv.py | 1 - examples/inverse/psf_volume.py | 1 - examples/inverse/rap_music.py | 1 - examples/inverse/read_inverse.py | 1 - examples/inverse/read_stc.py | 1 - examples/inverse/resolution_metrics.py | 1 - examples/inverse/resolution_metrics_eegmeg.py | 1 - examples/inverse/snr_estimate.py | 1 - examples/inverse/source_space_snr.py | 1 - examples/inverse/time_frequency_mixed_norm_inverse.py | 1 - examples/inverse/vector_mne_solution.py | 1 - examples/io/elekta_epochs.py | 1 - examples/io/read_neo_format.py | 1 - examples/io/read_noise_covariance_matrix.py | 1 - examples/io/read_xdf.py | 1 - examples/preprocessing/css.py | 1 - examples/preprocessing/define_target_events.py | 1 - examples/preprocessing/eeg_bridging.py | 1 - examples/preprocessing/eeg_csd.py | 1 - examples/preprocessing/eog_artifact_histogram.py | 1 - examples/preprocessing/eog_regression.py | 1 - examples/preprocessing/find_ref_artifacts.py | 1 - examples/preprocessing/fnirs_artifact_removal.py | 1 - examples/preprocessing/ica_comparison.py | 1 - examples/preprocessing/interpolate_bad_channels.py | 1 - 
examples/preprocessing/locate_ieeg_micro.py | 1 - examples/preprocessing/movement_compensation.py | 1 - examples/preprocessing/movement_detection.py | 1 - examples/preprocessing/muscle_detection.py | 1 - examples/preprocessing/muscle_ica.py | 1 - examples/preprocessing/otp.py | 1 - examples/preprocessing/shift_evoked.py | 1 - examples/preprocessing/virtual_evoked.py | 1 - examples/preprocessing/xdawn_denoising.py | 1 - examples/simulation/simulate_evoked_data.py | 1 - examples/simulation/simulate_raw_data.py | 1 - examples/simulation/simulated_raw_data_using_subject_anatomy.py | 1 - examples/simulation/source_simulator.py | 1 - examples/stats/cluster_stats_evoked.py | 1 - examples/stats/fdr_stats_evoked.py | 1 - examples/stats/linear_regression_raw.py | 1 - examples/stats/sensor_permutation_test.py | 1 - examples/stats/sensor_regression.py | 1 - examples/time_frequency/compute_csd.py | 1 - examples/time_frequency/compute_source_psd_epochs.py | 1 - examples/time_frequency/source_label_time_frequency.py | 1 - examples/time_frequency/source_power_spectrum.py | 1 - examples/time_frequency/source_power_spectrum_opm.py | 1 - examples/time_frequency/source_space_time_frequency.py | 1 - examples/time_frequency/temporal_whitening.py | 1 - examples/time_frequency/time_frequency_erds.py | 1 - examples/time_frequency/time_frequency_global_field_power.py | 1 - examples/time_frequency/time_frequency_simulated.py | 1 - examples/visualization/3d_to_2d.py | 1 - examples/visualization/brain.py | 1 - examples/visualization/channel_epochs_image.py | 1 - examples/visualization/eeg_on_scalp.py | 1 - examples/visualization/evoked_arrowmap.py | 1 - examples/visualization/evoked_topomap.py | 1 - examples/visualization/evoked_whitening.py | 1 - examples/visualization/meg_sensors.py | 1 - examples/visualization/mne_helmet.py | 1 - examples/visualization/montage_sgskip.py | 1 - examples/visualization/parcellation.py | 1 - examples/visualization/publication_figure.py | 1 - examples/visualization/roi_erpimage_by_rt.py | 1 - examples/visualization/sensor_noise_level.py | 1 - examples/visualization/ssp_projs_sensitivity_map.py | 1 - examples/visualization/topo_compare_conditions.py | 1 - examples/visualization/topo_customized.py | 1 - examples/visualization/xhemi.py | 1 - logo/generate_mne_logos.py | 1 - mne/_freesurfer.py | 1 - mne/_ola.py | 1 - mne/beamformer/resolution_matrix.py | 1 - mne/beamformer/tests/test_resolution_matrix.py | 1 - mne/chpi.py | 1 - mne/commands/tests/test_commands.py | 1 - mne/conftest.py | 1 - mne/coreg.py | 1 - mne/datasets/_fsaverage/base.py | 1 - mne/datasets/_infant/base.py | 1 - mne/datasets/_phantom/base.py | 1 - mne/datasets/hf_sef/hf_sef.py | 1 - mne/datasets/sleep_physionet/_utils.py | 1 - mne/datasets/sleep_physionet/age.py | 1 - mne/datasets/sleep_physionet/temazepam.py | 1 - mne/decoding/csp.py | 1 - mne/decoding/mixin.py | 2 -- mne/decoding/receptive_field.py | 1 - mne/decoding/time_delaying_ridge.py | 1 - mne/decoding/transformer.py | 1 - mne/dipole.py | 1 - mne/epochs.py | 2 -- mne/evoked.py | 1 - mne/export/_brainvision.py | 1 - mne/export/_edf.py | 1 - mne/export/_eeglab.py | 1 - mne/export/_egimff.py | 1 - mne/export/_export.py | 1 - mne/export/tests/test_export.py | 1 - mne/forward/_compute_forward.py | 1 - mne/forward/_field_interpolation.py | 1 - mne/gui/_core.py | 1 - mne/gui/_ieeg_locate.py | 1 - mne/gui/tests/test_core.py | 1 - mne/gui/tests/test_ieeg_locate.py | 1 - mne/io/_digitization.py | 1 - mne/io/artemis123/tests/test_artemis123.py | 1 - mne/io/base.py | 1 - 
mne/io/brainvision/brainvision.py | 1 - mne/io/brainvision/tests/test_brainvision.py | 1 - mne/io/cnt/tests/test_cnt.py | 1 - mne/io/curry/tests/test_curry.py | 1 - mne/io/edf/edf.py | 1 - mne/io/edf/tests/test_edf.py | 1 - mne/io/egi/events.py | 1 - mne/io/egi/general.py | 1 - mne/io/egi/tests/test_egi.py | 1 - mne/io/eximia/tests/test_eximia.py | 1 - mne/io/fiff/tests/test_raw_fiff.py | 1 - mne/io/hitachi/tests/test_hitachi.py | 1 - mne/io/meas_info.py | 1 - mne/io/nedf/nedf.py | 1 - mne/io/nedf/tests/test_nedf.py | 1 - mne/io/nicolet/tests/test_nicolet.py | 1 - mne/io/nihon/tests/test_nihon.py | 1 - mne/io/nirx/tests/test_nirx.py | 1 - mne/io/open.py | 1 - mne/io/persyst/tests/test_persyst.py | 1 - mne/io/pick.py | 1 - mne/io/proc_history.py | 1 - mne/io/snirf/tests/test_snirf.py | 1 - mne/io/tests/test_meas_info.py | 1 - mne/io/tests/test_raw.py | 1 - mne/io/tests/test_show_fiff.py | 1 - mne/io/tests/test_utils.py | 1 - mne/io/tests/test_write.py | 1 - mne/io/utils.py | 1 - mne/io/what.py | 1 - mne/minimum_norm/inverse.py | 1 - mne/minimum_norm/resolution_matrix.py | 1 - mne/minimum_norm/spatial_resolution.py | 1 - mne/minimum_norm/tests/test_resolution_matrix.py | 1 - mne/minimum_norm/tests/test_resolution_metrics.py | 1 - mne/minimum_norm/tests/test_snr.py | 1 - mne/preprocessing/_fine_cal.py | 1 - mne/preprocessing/ica.py | 1 - mne/preprocessing/ieeg/tests/test_projection.py | 1 - mne/preprocessing/ieeg/tests/test_volume.py | 1 - mne/preprocessing/maxwell.py | 1 - mne/preprocessing/otp.py | 1 - mne/preprocessing/realign.py | 1 - mne/preprocessing/tests/test_csd.py | 1 - mne/rank.py | 1 - mne/report/tests/test_report.py | 1 - mne/simulation/raw.py | 1 - mne/stats/_adjacency.py | 2 -- mne/stats/cluster_level.py | 1 - mne/stats/tests/test_adjacency.py | 2 -- mne/tests/test_docstring_parameters.py | 1 - mne/tests/test_epochs.py | 1 - mne/tests/test_event.py | 1 - mne/tests/test_morph.py | 1 - mne/tests/test_parallel.py | 1 - mne/tests/test_source_estimate.py | 1 - mne/tests/test_source_space.py | 2 -- mne/tests/test_transforms.py | 1 - mne/time_frequency/csd.py | 1 - mne/time_frequency/spectrum.py | 1 - mne/time_frequency/tests/test_multitaper.py | 1 - mne/transforms.py | 1 - mne/utils/_bunch.py | 1 - mne/utils/_logging.py | 1 - mne/utils/_testing.py | 1 - mne/utils/check.py | 1 - mne/utils/config.py | 1 - mne/utils/dataframe.py | 1 - mne/utils/docs.py | 1 - mne/utils/fetching.py | 1 - mne/utils/linalg.py | 1 - mne/utils/misc.py | 1 - mne/utils/mixin.py | 1 - mne/utils/numerics.py | 1 - mne/utils/progressbar.py | 1 - mne/utils/tests/test_bunch.py | 1 - mne/utils/tests/test_progressbar.py | 1 - mne/utils/tests/test_testing.py | 1 - mne/viz/_3d.py | 1 - mne/viz/_3d_overlay.py | 1 - mne/viz/_brain/tests/test_brain.py | 1 - mne/viz/_brain/tests/test_notebook.py | 1 - mne/viz/_dipole.py | 1 - mne/viz/_figure.py | 1 - mne/viz/_mpl_figure.py | 1 - mne/viz/backends/_utils.py | 1 - mne/viz/backends/tests/test_abstract.py | 1 - mne/viz/evoked.py | 1 - mne/viz/misc.py | 1 - mne/viz/utils.py | 1 - tutorials/clinical/10_ieeg_localize.py | 1 - tutorials/clinical/20_seeg.py | 1 - tutorials/clinical/30_ecog.py | 1 - tutorials/clinical/60_sleep.py | 1 - tutorials/epochs/10_epochs_overview.py | 1 - tutorials/epochs/15_baseline_regression.py | 1 - tutorials/epochs/20_visualize_epochs.py | 1 - tutorials/epochs/30_epochs_metadata.py | 1 - tutorials/epochs/40_autogenerate_metadata.py | 1 - tutorials/epochs/50_epochs_to_data_frame.py | 1 - tutorials/epochs/60_make_fixed_length_epochs.py | 1 - 
tutorials/evoked/10_evoked_overview.py | 1 - tutorials/evoked/20_visualize_evoked.py | 1 - tutorials/evoked/30_eeg_erp.py | 1 - tutorials/evoked/40_whitened.py | 1 - tutorials/forward/10_background_freesurfer.py | 1 - tutorials/forward/20_source_alignment.py | 1 - tutorials/forward/25_automated_coreg.py | 1 - tutorials/forward/30_forward.py | 1 - tutorials/forward/35_eeg_no_mri.py | 1 - tutorials/forward/50_background_freesurfer_mne.py | 1 - tutorials/forward/80_fix_bem_in_blender.py | 1 - tutorials/forward/90_compute_covariance.py | 1 - tutorials/intro/10_overview.py | 1 - tutorials/intro/15_inplace.py | 1 - tutorials/intro/20_events_from_raw.py | 1 - tutorials/intro/30_info.py | 1 - tutorials/intro/40_sensor_locations.py | 1 - tutorials/intro/50_configure_mne.py | 1 - tutorials/intro/70_report.py | 1 - tutorials/inverse/10_stc_class.py | 1 - tutorials/inverse/20_dipole_fit.py | 1 - tutorials/inverse/30_mne_dspm_loreta.py | 1 - tutorials/inverse/35_dipole_orientations.py | 1 - tutorials/inverse/40_mne_fixed_free.py | 1 - tutorials/inverse/50_beamformer_lcmv.py | 1 - tutorials/inverse/60_visualize_stc.py | 1 - tutorials/inverse/70_eeg_mri_coords.py | 1 - tutorials/inverse/80_brainstorm_phantom_elekta.py | 1 - tutorials/inverse/85_brainstorm_phantom_ctf.py | 1 - tutorials/inverse/90_phantom_4DBTi.py | 1 - tutorials/io/10_reading_meg_data.py | 1 - tutorials/io/20_reading_eeg_data.py | 1 - tutorials/io/30_reading_fnirs_data.py | 1 - tutorials/io/60_ctf_bst_auditory.py | 1 - tutorials/machine-learning/30_strf.py | 1 - tutorials/machine-learning/50_decoding.py | 1 - tutorials/preprocessing/10_preprocessing_overview.py | 1 - tutorials/preprocessing/15_handling_bad_channels.py | 1 - tutorials/preprocessing/20_rejecting_bad_data.py | 1 - tutorials/preprocessing/25_background_filtering.py | 1 - tutorials/preprocessing/30_filtering_resampling.py | 1 - tutorials/preprocessing/35_artifact_correction_regression.py | 1 - tutorials/preprocessing/40_artifact_correction_ica.py | 1 - tutorials/preprocessing/45_projectors_background.py | 1 - tutorials/preprocessing/50_artifact_correction_ssp.py | 1 - tutorials/preprocessing/55_setting_eeg_reference.py | 1 - tutorials/preprocessing/59_head_positions.py | 1 - tutorials/preprocessing/60_maxwell_filtering_sss.py | 1 - tutorials/preprocessing/70_fnirs_processing.py | 1 - tutorials/preprocessing/80_opm_processing.py | 1 - tutorials/raw/10_raw_overview.py | 1 - tutorials/raw/20_event_arrays.py | 1 - tutorials/raw/30_annotate_raw.py | 1 - tutorials/raw/40_visualize_raw.py | 1 - tutorials/simulation/10_array_objs.py | 1 - tutorials/simulation/70_point_spread.py | 1 - tutorials/simulation/80_dics.py | 1 - tutorials/stats-sensor-space/10_background_stats.py | 1 - tutorials/stats-sensor-space/20_erp_stats.py | 1 - tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py | 1 - tutorials/stats-sensor-space/50_cluster_between_time_freq.py | 1 - tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py | 1 - tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py | 1 - tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py | 1 - tutorials/stats-source-space/30_cluster_ftest_spatiotemporal.py | 1 - .../stats-source-space/60_cluster_rmANOVA_spatiotemporal.py | 1 - tutorials/time-freq/10_spectrum_class.py | 1 - tutorials/time-freq/20_sensors_time_frequency.py | 1 - tutorials/time-freq/50_ssvep.py | 1 - 326 files changed, 337 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 1b4f7ad2ed0..2e6680a6b05 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ 
-1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full diff --git a/doc/sphinxext/flow_diagram.py b/doc/sphinxext/flow_diagram.py index 55bbf4aa8d0..9adb8636e2f 100644 --- a/doc/sphinxext/flow_diagram.py +++ b/doc/sphinxext/flow_diagram.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import os from os import path as op diff --git a/doc/sphinxext/gen_commands.py b/doc/sphinxext/gen_commands.py index e0169a44b77..e5b2ed391b6 100644 --- a/doc/sphinxext/gen_commands.py +++ b/doc/sphinxext/gen_commands.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import glob from importlib import import_module import os diff --git a/doc/sphinxext/gen_names.py b/doc/sphinxext/gen_names.py index 01785598430..92c155b8f52 100644 --- a/doc/sphinxext/gen_names.py +++ b/doc/sphinxext/gen_names.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import os from os import path as op diff --git a/doc/sphinxext/gh_substitutions.py b/doc/sphinxext/gh_substitutions.py index 2c1cbf1f76c..f0c6a05c5ba 100644 --- a/doc/sphinxext/gh_substitutions.py +++ b/doc/sphinxext/gh_substitutions.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from docutils.nodes import reference from docutils.parsers.rst.roles import set_classes diff --git a/doc/sphinxext/newcontrib_substitutions.py b/doc/sphinxext/newcontrib_substitutions.py index 559a14bafa4..68595e74bdb 100644 --- a/doc/sphinxext/newcontrib_substitutions.py +++ b/doc/sphinxext/newcontrib_substitutions.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from docutils.nodes import reference, strong, target diff --git a/examples/datasets/brainstorm_data.py b/examples/datasets/brainstorm_data.py index 61bf9087476..949f2511a88 100644 --- a/examples/datasets/brainstorm_data.py +++ b/examples/datasets/brainstorm_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-brainstorm-raw: diff --git a/examples/datasets/hf_sef_data.py b/examples/datasets/hf_sef_data.py index fee1630277c..9857d22d09d 100644 --- a/examples/datasets/hf_sef_data.py +++ b/examples/datasets/hf_sef_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-hf-sef-data: diff --git a/examples/datasets/limo_data.py b/examples/datasets/limo_data.py index c897bc69cf8..d5670f62ffe 100644 --- a/examples/datasets/limo_data.py +++ b/examples/datasets/limo_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-limo-data: diff --git a/examples/datasets/opm_data.py b/examples/datasets/opm_data.py index b849c5b098c..ec6daab1037 100644 --- a/examples/datasets/opm_data.py +++ b/examples/datasets/opm_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-opm-somatosensory: diff --git a/examples/datasets/spm_faces_dataset_sgskip.py b/examples/datasets/spm_faces_dataset_sgskip.py index ce538d4b382..875cc2eb5d5 100644 --- a/examples/datasets/spm_faces_dataset_sgskip.py +++ b/examples/datasets/spm_faces_dataset_sgskip.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-spm-faces: diff --git a/examples/decoding/decoding_csp_eeg.py b/examples/decoding/decoding_csp_eeg.py index dcb91e66ad4..beef85bbdc0 100644 --- a/examples/decoding/decoding_csp_eeg.py +++ b/examples/decoding/decoding_csp_eeg.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-decoding-csp-eeg: diff --git a/examples/decoding/decoding_csp_timefreq.py b/examples/decoding/decoding_csp_timefreq.py index a5cb84d6f4f..6407646910b 100644 --- a/examples/decoding/decoding_csp_timefreq.py +++ b/examples/decoding/decoding_csp_timefreq.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-decoding-csp-eeg-timefreq: diff --git a/examples/decoding/decoding_rsa_sgskip.py b/examples/decoding/decoding_rsa_sgskip.py index cd53cf5382e..ba1be187372 100644 --- a/examples/decoding/decoding_rsa_sgskip.py +++ b/examples/decoding/decoding_rsa_sgskip.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-rsa-noplot: diff --git a/examples/decoding/decoding_spatio_temporal_source.py b/examples/decoding/decoding_spatio_temporal_source.py index 01187ea1e47..476b4d170c6 100644 --- a/examples/decoding/decoding_spatio_temporal_source.py +++ b/examples/decoding/decoding_spatio_temporal_source.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-dec-st-source: diff --git a/examples/decoding/decoding_spoc_CMC.py b/examples/decoding/decoding_spoc_CMC.py index 78a6918f22a..f1fb8c86400 100644 --- a/examples/decoding/decoding_spoc_CMC.py +++ b/examples/decoding/decoding_spoc_CMC.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-spoc-cmc: diff --git a/examples/decoding/decoding_time_generalization_conditions.py b/examples/decoding/decoding_time_generalization_conditions.py index 01f6d1da00d..d39797e6561 100644 --- a/examples/decoding/decoding_time_generalization_conditions.py +++ b/examples/decoding/decoding_time_generalization_conditions.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-linear-sensor-decoding: diff --git a/examples/decoding/decoding_unsupervised_spatial_filter.py b/examples/decoding/decoding_unsupervised_spatial_filter.py index a7514625842..a3fab432ace 100644 --- a/examples/decoding/decoding_unsupervised_spatial_filter.py +++ b/examples/decoding/decoding_unsupervised_spatial_filter.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-ica-pca-decoding: diff --git a/examples/decoding/decoding_xdawn_eeg.py b/examples/decoding/decoding_xdawn_eeg.py index 484082b6085..9ec65f54976 100644 --- a/examples/decoding/decoding_xdawn_eeg.py +++ b/examples/decoding/decoding_xdawn_eeg.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-xdawn-decoding: diff --git a/examples/decoding/ems_filtering.py b/examples/decoding/ems_filtering.py index 8f40837b9d4..8807bf57079 100644 --- a/examples/decoding/ems_filtering.py +++ b/examples/decoding/ems_filtering.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-ems-filtering: diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py index 3bf30e11161..f708503214b 100644 --- a/examples/decoding/linear_model_patterns.py +++ b/examples/decoding/linear_model_patterns.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-linear-patterns: diff --git a/examples/decoding/receptive_field_mtrf.py b/examples/decoding/receptive_field_mtrf.py index d4aa6a9b4df..4e948613dbb 100644 --- a/examples/decoding/receptive_field_mtrf.py +++ b/examples/decoding/receptive_field_mtrf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-receptive-field-mtrf: diff --git a/examples/decoding/ssd_spatial_filters.py b/examples/decoding/ssd_spatial_filters.py index eb780192456..723667c1864 100644 --- a/examples/decoding/ssd_spatial_filters.py +++ b/examples/decoding/ssd_spatial_filters.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-ssd-spatial-filters: diff --git a/examples/forward/forward_sensitivity_maps.py b/examples/forward/forward_sensitivity_maps.py index e2ad5acae4b..e17e8e38c12 100644 --- a/examples/forward/forward_sensitivity_maps.py +++ b/examples/forward/forward_sensitivity_maps.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-sensitivity-maps: diff --git a/examples/forward/left_cerebellum_volume_source.py b/examples/forward/left_cerebellum_volume_source.py index f2b9353b755..c8327100f10 100644 --- a/examples/forward/left_cerebellum_volume_source.py +++ b/examples/forward/left_cerebellum_volume_source.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-cerebellum-source-space: diff --git a/examples/forward/source_space_morphing.py b/examples/forward/source_space_morphing.py index 712562ee358..77688705e97 100644 --- a/examples/forward/source_space_morphing.py +++ b/examples/forward/source_space_morphing.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-source-space-morphing: diff --git a/examples/inverse/compute_mne_inverse_epochs_in_label.py b/examples/inverse/compute_mne_inverse_epochs_in_label.py index 4a5129d9ca8..e78b37c17fe 100644 --- a/examples/inverse/compute_mne_inverse_epochs_in_label.py +++ b/examples/inverse/compute_mne_inverse_epochs_in_label.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-dSPM-epochs: diff --git a/examples/inverse/compute_mne_inverse_raw_in_label.py b/examples/inverse/compute_mne_inverse_raw_in_label.py index c38199035fb..1d473f2db1f 100644 --- a/examples/inverse/compute_mne_inverse_raw_in_label.py +++ b/examples/inverse/compute_mne_inverse_raw_in_label.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _example-sLORETA: diff --git a/examples/inverse/compute_mne_inverse_volume.py b/examples/inverse/compute_mne_inverse_volume.py index 02c5af5744c..215977ca393 100644 --- a/examples/inverse/compute_mne_inverse_volume.py +++ b/examples/inverse/compute_mne_inverse_volume.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-inverse-volume: diff --git a/examples/inverse/custom_inverse_solver.py b/examples/inverse/custom_inverse_solver.py index f5ad92b3daf..760ef4408e5 100644 --- a/examples/inverse/custom_inverse_solver.py +++ b/examples/inverse/custom_inverse_solver.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-custom-inverse: diff --git a/examples/inverse/dics_epochs.py b/examples/inverse/dics_epochs.py index 30039c86ad3..8aba68b9e44 100644 --- a/examples/inverse/dics_epochs.py +++ b/examples/inverse/dics_epochs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-inverse-dics-epochs: diff --git a/examples/inverse/dics_source_power.py b/examples/inverse/dics_source_power.py index d9ae604b5ba..8a3ee2c1cf6 100644 --- a/examples/inverse/dics_source_power.py +++ b/examples/inverse/dics_source_power.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-inverse-source-power: diff --git a/examples/inverse/evoked_ers_source_power.py b/examples/inverse/evoked_ers_source_power.py index 5fa07d3ab52..0ded1fc7aff 100644 --- a/examples/inverse/evoked_ers_source_power.py +++ b/examples/inverse/evoked_ers_source_power.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-source-loc-methods: diff --git a/examples/inverse/gamma_map_inverse.py b/examples/inverse/gamma_map_inverse.py index dd5b8343fb2..20a205c3322 100644 --- a/examples/inverse/gamma_map_inverse.py +++ b/examples/inverse/gamma_map_inverse.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-gamma-map: diff --git a/examples/inverse/label_activation_from_stc.py b/examples/inverse/label_activation_from_stc.py index f67821e81f6..20368b68183 100644 --- a/examples/inverse/label_activation_from_stc.py +++ b/examples/inverse/label_activation_from_stc.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-label-time-course: diff --git a/examples/inverse/label_from_stc.py b/examples/inverse/label_from_stc.py index 88b45faabd7..3d3abae2a16 100644 --- a/examples/inverse/label_from_stc.py +++ b/examples/inverse/label_from_stc.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-functional-label: diff --git a/examples/inverse/label_source_activations.py b/examples/inverse/label_source_activations.py index 9818cb07d3b..30a55970d81 100644 --- a/examples/inverse/label_source_activations.py +++ b/examples/inverse/label_source_activations.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-label-time-series: diff --git a/examples/inverse/mixed_norm_inverse.py b/examples/inverse/mixed_norm_inverse.py index ed2a425025f..56b64e744a1 100644 --- a/examples/inverse/mixed_norm_inverse.py +++ b/examples/inverse/mixed_norm_inverse.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-mixed-norm-inverse: diff --git a/examples/inverse/mixed_source_space_inverse.py b/examples/inverse/mixed_source_space_inverse.py index 485fac0d26d..f732178ea9f 100644 --- a/examples/inverse/mixed_source_space_inverse.py +++ b/examples/inverse/mixed_source_space_inverse.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-mixed-source-space-inverse: diff --git a/examples/inverse/mne_cov_power.py b/examples/inverse/mne_cov_power.py index b7c5137e3b7..91fc47bc577 100644 --- a/examples/inverse/mne_cov_power.py +++ b/examples/inverse/mne_cov_power.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-cov-power: diff --git a/examples/inverse/morph_surface_stc.py b/examples/inverse/morph_surface_stc.py index c918a6e6f3b..80a35c87ed8 100644 --- a/examples/inverse/morph_surface_stc.py +++ b/examples/inverse/morph_surface_stc.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-morph-surface: diff --git a/examples/inverse/morph_volume_stc.py b/examples/inverse/morph_volume_stc.py index 5be18ba6218..1494b7b30c8 100644 --- a/examples/inverse/morph_volume_stc.py +++ b/examples/inverse/morph_volume_stc.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-morph-volume: diff --git a/examples/inverse/multi_dipole_model.py b/examples/inverse/multi_dipole_model.py index fe8f6acecf3..2dbe6362157 100644 --- a/examples/inverse/multi_dipole_model.py +++ b/examples/inverse/multi_dipole_model.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-multi-dipole: diff --git a/examples/inverse/multidict_reweighted_tfmxne.py b/examples/inverse/multidict_reweighted_tfmxne.py index 1cb111834b6..58aa0fefb09 100644 --- a/examples/inverse/multidict_reweighted_tfmxne.py +++ b/examples/inverse/multidict_reweighted_tfmxne.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-iterative-reweighted-tf-mxne: diff --git a/examples/inverse/psf_ctf_label_leakage.py b/examples/inverse/psf_ctf_label_leakage.py index ce174faee29..5975584c391 100644 --- a/examples/inverse/psf_ctf_label_leakage.py +++ b/examples/inverse/psf_ctf_label_leakage.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-source-leakage: diff --git a/examples/inverse/psf_ctf_vertices.py b/examples/inverse/psf_ctf_vertices.py index f6616106e6b..a365991ffa1 100644 --- a/examples/inverse/psf_ctf_vertices.py +++ b/examples/inverse/psf_ctf_vertices.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-psd-ctf: diff --git a/examples/inverse/psf_ctf_vertices_lcmv.py b/examples/inverse/psf_ctf_vertices_lcmv.py index fc8f740029e..de774c2149e 100644 --- a/examples/inverse/psf_ctf_vertices_lcmv.py +++ b/examples/inverse/psf_ctf_vertices_lcmv.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-psf-ctf-lcmv: diff --git a/examples/inverse/psf_volume.py b/examples/inverse/psf_volume.py index 042efb7be4a..7cfd0675cd8 100644 --- a/examples/inverse/psf_volume.py +++ b/examples/inverse/psf_volume.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-psd-vol: diff --git a/examples/inverse/rap_music.py b/examples/inverse/rap_music.py index fc5751afe09..937351b96dd 100644 --- a/examples/inverse/rap_music.py +++ b/examples/inverse/rap_music.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-rap-music: diff --git a/examples/inverse/read_inverse.py b/examples/inverse/read_inverse.py index 730fec92668..fd604b08f35 100644 --- a/examples/inverse/read_inverse.py +++ b/examples/inverse/read_inverse.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-read-inverse: diff --git a/examples/inverse/read_stc.py b/examples/inverse/read_stc.py index b421671ce73..3ae91bfc799 100644 --- a/examples/inverse/read_stc.py +++ b/examples/inverse/read_stc.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-read-stc: diff --git a/examples/inverse/resolution_metrics.py b/examples/inverse/resolution_metrics.py index aa795ccc972..10d3e03944c 100644 --- a/examples/inverse/resolution_metrics.py +++ b/examples/inverse/resolution_metrics.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-res-metrics: diff --git a/examples/inverse/resolution_metrics_eegmeg.py b/examples/inverse/resolution_metrics_eegmeg.py index 3c3b67ca926..06268178058 100644 --- a/examples/inverse/resolution_metrics_eegmeg.py +++ b/examples/inverse/resolution_metrics_eegmeg.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-res-metrics-meeg: diff --git a/examples/inverse/snr_estimate.py b/examples/inverse/snr_estimate.py index ccf6385e14e..956f3cbe643 100644 --- a/examples/inverse/snr_estimate.py +++ b/examples/inverse/snr_estimate.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-snr-estimate: diff --git a/examples/inverse/source_space_snr.py b/examples/inverse/source_space_snr.py index cdcdbb6351c..0dd14e71722 100644 --- a/examples/inverse/source_space_snr.py +++ b/examples/inverse/source_space_snr.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-source-space-snr: diff --git a/examples/inverse/time_frequency_mixed_norm_inverse.py b/examples/inverse/time_frequency_mixed_norm_inverse.py index db55fa60b8f..27968bd8971 100644 --- a/examples/inverse/time_frequency_mixed_norm_inverse.py +++ b/examples/inverse/time_frequency_mixed_norm_inverse.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-tfr-mixed-norm: diff --git a/examples/inverse/vector_mne_solution.py b/examples/inverse/vector_mne_solution.py index baa0f65e335..caba3a46201 100644 --- a/examples/inverse/vector_mne_solution.py +++ b/examples/inverse/vector_mne_solution.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-vector-mne-solution: diff --git a/examples/io/elekta_epochs.py b/examples/io/elekta_epochs.py index 6a3fff183ad..8c24902d209 100644 --- a/examples/io/elekta_epochs.py +++ b/examples/io/elekta_epochs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-io-ave-fiff: diff --git a/examples/io/read_neo_format.py b/examples/io/read_neo_format.py index d000a49a7b1..43b8a98f876 100644 --- a/examples/io/read_neo_format.py +++ b/examples/io/read_neo_format.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-read-neo: diff --git a/examples/io/read_noise_covariance_matrix.py b/examples/io/read_noise_covariance_matrix.py index 1cf604f521f..57b0d314e25 100644 --- a/examples/io/read_noise_covariance_matrix.py +++ b/examples/io/read_noise_covariance_matrix.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-read-noise-cov: diff --git a/examples/io/read_xdf.py b/examples/io/read_xdf.py index 475dd3c1ee0..d65784d85ad 100644 --- a/examples/io/read_xdf.py +++ b/examples/io/read_xdf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-read-xdf: diff --git a/examples/preprocessing/css.py b/examples/preprocessing/css.py index 47447b06736..2631dc54d23 100644 --- a/examples/preprocessing/css.py +++ b/examples/preprocessing/css.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-css: diff --git a/examples/preprocessing/define_target_events.py b/examples/preprocessing/define_target_events.py index 356bc9e89ff..f35b16743d9 100644 --- a/examples/preprocessing/define_target_events.py +++ b/examples/preprocessing/define_target_events.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-tag-events: diff --git a/examples/preprocessing/eeg_bridging.py b/examples/preprocessing/eeg_bridging.py index f25d526e42d..31e8b06b08d 100644 --- a/examples/preprocessing/eeg_bridging.py +++ b/examples/preprocessing/eeg_bridging.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-eeg-bridging: diff --git a/examples/preprocessing/eeg_csd.py b/examples/preprocessing/eeg_csd.py index d4176de1ceb..24f33b91e53 100644 --- a/examples/preprocessing/eeg_csd.py +++ b/examples/preprocessing/eeg_csd.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-eeg-csd: diff --git a/examples/preprocessing/eog_artifact_histogram.py b/examples/preprocessing/eog_artifact_histogram.py index d14eee57802..a6a3e895b3c 100644 --- a/examples/preprocessing/eog_artifact_histogram.py +++ b/examples/preprocessing/eog_artifact_histogram.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-eog: diff --git a/examples/preprocessing/eog_regression.py b/examples/preprocessing/eog_regression.py index 62c52c396ac..1d7f6879b9a 100644 --- a/examples/preprocessing/eog_regression.py +++ b/examples/preprocessing/eog_regression.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ======================================= Reduce EOG artifacts through regression diff --git a/examples/preprocessing/find_ref_artifacts.py b/examples/preprocessing/find_ref_artifacts.py index 969e714f684..f3781a0c1cc 100644 --- a/examples/preprocessing/find_ref_artifacts.py +++ b/examples/preprocessing/find_ref_artifacts.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-megnoise_processing: diff --git a/examples/preprocessing/fnirs_artifact_removal.py b/examples/preprocessing/fnirs_artifact_removal.py index c6882d38a09..b7236b76636 100644 --- a/examples/preprocessing/fnirs_artifact_removal.py +++ b/examples/preprocessing/fnirs_artifact_removal.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-fnirs-artifacts: diff --git a/examples/preprocessing/ica_comparison.py b/examples/preprocessing/ica_comparison.py index f9fb7d75764..7c4a8aa733c 100644 --- a/examples/preprocessing/ica_comparison.py +++ b/examples/preprocessing/ica_comparison.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-ica-comp: diff --git a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index b2893152e1f..635dffcbfba 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-interpolate-bad-channels: diff --git a/examples/preprocessing/locate_ieeg_micro.py b/examples/preprocessing/locate_ieeg_micro.py index 6433d2f6829..0ab653e93f6 100644 --- a/examples/preprocessing/locate_ieeg_micro.py +++ b/examples/preprocessing/locate_ieeg_micro.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-ieeg-micro: diff --git a/examples/preprocessing/movement_compensation.py b/examples/preprocessing/movement_compensation.py index 51913e30ec0..3a31648c4a5 100644 --- a/examples/preprocessing/movement_compensation.py +++ b/examples/preprocessing/movement_compensation.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-movement-comp: diff --git a/examples/preprocessing/movement_detection.py b/examples/preprocessing/movement_detection.py index d84548792df..ac90f45f587 100644 --- a/examples/preprocessing/movement_detection.py +++ b/examples/preprocessing/movement_detection.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-movement-detect: diff --git a/examples/preprocessing/muscle_detection.py b/examples/preprocessing/muscle_detection.py index d2369b87cd1..223f93743d9 100644 --- a/examples/preprocessing/muscle_detection.py +++ b/examples/preprocessing/muscle_detection.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-muscle-artifacts: diff --git a/examples/preprocessing/muscle_ica.py b/examples/preprocessing/muscle_ica.py index 14960a761b9..8abc96f5d6a 100644 --- a/examples/preprocessing/muscle_ica.py +++ b/examples/preprocessing/muscle_ica.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-muscle-ica: diff --git a/examples/preprocessing/otp.py b/examples/preprocessing/otp.py index e2bf81f7480..520d66166ac 100644 --- a/examples/preprocessing/otp.py +++ b/examples/preprocessing/otp.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-otp: diff --git a/examples/preprocessing/shift_evoked.py b/examples/preprocessing/shift_evoked.py index ba938d88993..3bbe0386416 100644 --- a/examples/preprocessing/shift_evoked.py +++ b/examples/preprocessing/shift_evoked.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-shift-evoked: diff --git a/examples/preprocessing/virtual_evoked.py b/examples/preprocessing/virtual_evoked.py index 32c93a4929a..b947226b40b 100644 --- a/examples/preprocessing/virtual_evoked.py +++ b/examples/preprocessing/virtual_evoked.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-virtual-evoked: diff --git a/examples/preprocessing/xdawn_denoising.py b/examples/preprocessing/xdawn_denoising.py index 10699c41998..aa7c0f48e08 100644 --- a/examples/preprocessing/xdawn_denoising.py +++ b/examples/preprocessing/xdawn_denoising.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-xdawn-denoising: diff --git a/examples/simulation/simulate_evoked_data.py b/examples/simulation/simulate_evoked_data.py index 037b1dcbbc7..0d4cff6a6c3 100644 --- a/examples/simulation/simulate_evoked_data.py +++ b/examples/simulation/simulate_evoked_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-sim-evoked: diff --git a/examples/simulation/simulate_raw_data.py b/examples/simulation/simulate_raw_data.py index 641f0171707..6c308792c97 100644 --- a/examples/simulation/simulate_raw_data.py +++ b/examples/simulation/simulate_raw_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-sim-raw: diff --git a/examples/simulation/simulated_raw_data_using_subject_anatomy.py b/examples/simulation/simulated_raw_data_using_subject_anatomy.py index 393fb66d0b1..b78db66c965 100644 --- a/examples/simulation/simulated_raw_data_using_subject_anatomy.py +++ b/examples/simulation/simulated_raw_data_using_subject_anatomy.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-sim-raw-sub: diff --git a/examples/simulation/source_simulator.py b/examples/simulation/source_simulator.py index 9a1ced9d7ff..93a348e46ca 100644 --- a/examples/simulation/source_simulator.py +++ b/examples/simulation/source_simulator.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-sim-source: diff --git a/examples/stats/cluster_stats_evoked.py b/examples/stats/cluster_stats_evoked.py index e50a5018058..cf2f9d59c18 100644 --- a/examples/stats/cluster_stats_evoked.py +++ b/examples/stats/cluster_stats_evoked.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-cluster-evoked: diff --git a/examples/stats/fdr_stats_evoked.py b/examples/stats/fdr_stats_evoked.py index 5007042fccc..b90ab6f9ccd 100644 --- a/examples/stats/fdr_stats_evoked.py +++ b/examples/stats/fdr_stats_evoked.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-fdr-evoked: diff --git a/examples/stats/linear_regression_raw.py b/examples/stats/linear_regression_raw.py index 53b288fc5ee..54aef70c8e2 100644 --- a/examples/stats/linear_regression_raw.py +++ b/examples/stats/linear_regression_raw.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-linear-regression-raw: diff --git a/examples/stats/sensor_permutation_test.py b/examples/stats/sensor_permutation_test.py index 0cbb0e855e7..654c9b7153c 100644 --- a/examples/stats/sensor_permutation_test.py +++ b/examples/stats/sensor_permutation_test.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-perm-test: diff --git a/examples/stats/sensor_regression.py b/examples/stats/sensor_regression.py index 5bb68a61f77..9a1e42ae7f8 100644 --- a/examples/stats/sensor_regression.py +++ b/examples/stats/sensor_regression.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-sensor-regression: diff --git a/examples/time_frequency/compute_csd.py b/examples/time_frequency/compute_csd.py index 6c1b0fc2d48..e9a962bb733 100644 --- a/examples/time_frequency/compute_csd.py +++ b/examples/time_frequency/compute_csd.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-csd-matrix: diff --git a/examples/time_frequency/compute_source_psd_epochs.py b/examples/time_frequency/compute_source_psd_epochs.py index d93b5e77bd0..1ca42643f49 100644 --- a/examples/time_frequency/compute_source_psd_epochs.py +++ b/examples/time_frequency/compute_source_psd_epochs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-psd-inverse: diff --git a/examples/time_frequency/source_label_time_frequency.py b/examples/time_frequency/source_label_time_frequency.py index 736abed8499..721c2fc4d2d 100644 --- a/examples/time_frequency/source_label_time_frequency.py +++ b/examples/time_frequency/source_label_time_frequency.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-source-space-power-phase-locking: diff --git a/examples/time_frequency/source_power_spectrum.py b/examples/time_frequency/source_power_spectrum.py index 8c32e8e78df..4b6d582d50b 100644 --- a/examples/time_frequency/source_power_spectrum.py +++ b/examples/time_frequency/source_power_spectrum.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-label-psd: diff --git a/examples/time_frequency/source_power_spectrum_opm.py b/examples/time_frequency/source_power_spectrum_opm.py index 813a08244e1..462f79c8eb9 100644 --- a/examples/time_frequency/source_power_spectrum_opm.py +++ b/examples/time_frequency/source_power_spectrum_opm.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-opm-resting-state: diff --git a/examples/time_frequency/source_space_time_frequency.py b/examples/time_frequency/source_space_time_frequency.py index 7573dabbb2f..a0a5f944439 100644 --- a/examples/time_frequency/source_space_time_frequency.py +++ b/examples/time_frequency/source_space_time_frequency.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-source-space-tfr: diff --git a/examples/time_frequency/temporal_whitening.py b/examples/time_frequency/temporal_whitening.py index 0b85e1695fa..068abad7337 100644 --- a/examples/time_frequency/temporal_whitening.py +++ b/examples/time_frequency/temporal_whitening.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-temporal-whitening: diff --git a/examples/time_frequency/time_frequency_erds.py b/examples/time_frequency/time_frequency_erds.py index 69bd0bba98f..d55122c232b 100644 --- a/examples/time_frequency/time_frequency_erds.py +++ b/examples/time_frequency/time_frequency_erds.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-tfr-erds: diff --git a/examples/time_frequency/time_frequency_global_field_power.py b/examples/time_frequency/time_frequency_global_field_power.py index ff3031d4623..a9af92cdde9 100644 --- a/examples/time_frequency/time_frequency_global_field_power.py +++ b/examples/time_frequency/time_frequency_global_field_power.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-time-freq-global-field-power: diff --git a/examples/time_frequency/time_frequency_simulated.py b/examples/time_frequency/time_frequency_simulated.py index 48c74e7607c..c84803d7d2f 100644 --- a/examples/time_frequency/time_frequency_simulated.py +++ b/examples/time_frequency/time_frequency_simulated.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-tfr-comparison: diff --git a/examples/visualization/3d_to_2d.py b/examples/visualization/3d_to_2d.py index ebe592e9dd4..bb692533baa 100644 --- a/examples/visualization/3d_to_2d.py +++ b/examples/visualization/3d_to_2d.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-electrode-pos-2d: diff --git a/examples/visualization/brain.py b/examples/visualization/brain.py index 71e53e1f8c8..5b31bc7b106 100644 --- a/examples/visualization/brain.py +++ b/examples/visualization/brain.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_ex-brain: diff --git a/examples/visualization/channel_epochs_image.py b/examples/visualization/channel_epochs_image.py index ecbe0a789dc..bb52c11c44b 100644 --- a/examples/visualization/channel_epochs_image.py +++ b/examples/visualization/channel_epochs_image.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-epochs-image: diff --git a/examples/visualization/eeg_on_scalp.py b/examples/visualization/eeg_on_scalp.py index 96fff94a523..7ad5438b9dc 100644 --- a/examples/visualization/eeg_on_scalp.py +++ b/examples/visualization/eeg_on_scalp.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-eeg-on-scalp: diff --git a/examples/visualization/evoked_arrowmap.py b/examples/visualization/evoked_arrowmap.py index df0c1123481..7ce3f1df093 100644 --- a/examples/visualization/evoked_arrowmap.py +++ b/examples/visualization/evoked_arrowmap.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-arrowmap: diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index 7b576ab2376..abeb527757e 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-evoked-topomap: diff --git a/examples/visualization/evoked_whitening.py b/examples/visualization/evoked_whitening.py index 73d88013de4..7a5f7552cc1 100644 --- a/examples/visualization/evoked_whitening.py +++ b/examples/visualization/evoked_whitening.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-evoked-whitening: diff --git a/examples/visualization/meg_sensors.py b/examples/visualization/meg_sensors.py index 28cea87b215..9d5ccd6411c 100644 --- a/examples/visualization/meg_sensors.py +++ b/examples/visualization/meg_sensors.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-plot-meg-sensors: diff --git a/examples/visualization/mne_helmet.py b/examples/visualization/mne_helmet.py index 9653a705981..c6c155bcfd3 100644 --- a/examples/visualization/mne_helmet.py +++ b/examples/visualization/mne_helmet.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-mne-helmet: diff --git a/examples/visualization/montage_sgskip.py b/examples/visualization/montage_sgskip.py index 17bfc686e36..96ab574499e 100644 --- a/examples/visualization/montage_sgskip.py +++ b/examples/visualization/montage_sgskip.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _plot_montage: diff --git a/examples/visualization/parcellation.py b/examples/visualization/parcellation.py index 63f1dd9d177..7118a2594b5 100644 --- a/examples/visualization/parcellation.py +++ b/examples/visualization/parcellation.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-parcellation: diff --git a/examples/visualization/publication_figure.py b/examples/visualization/publication_figure.py index fc9385c385f..f86cc44075d 100644 --- a/examples/visualization/publication_figure.py +++ b/examples/visualization/publication_figure.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-publication-figure: diff --git a/examples/visualization/roi_erpimage_by_rt.py b/examples/visualization/roi_erpimage_by_rt.py index 26a07ff07a8..e803b3cb14b 100644 --- a/examples/visualization/roi_erpimage_by_rt.py +++ b/examples/visualization/roi_erpimage_by_rt.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_roi-erp: diff --git a/examples/visualization/sensor_noise_level.py b/examples/visualization/sensor_noise_level.py index 17b4dbd354a..55b220ba1c0 100644 --- a/examples/visualization/sensor_noise_level.py +++ b/examples/visualization/sensor_noise_level.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-noise-level: diff --git a/examples/visualization/ssp_projs_sensitivity_map.py b/examples/visualization/ssp_projs_sensitivity_map.py index a4b3e4bc341..2c8259d7a24 100644 --- a/examples/visualization/ssp_projs_sensitivity_map.py +++ b/examples/visualization/ssp_projs_sensitivity_map.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-ssp-proj: diff --git a/examples/visualization/topo_compare_conditions.py b/examples/visualization/topo_compare_conditions.py index eb5c9d13b0b..6687ba37576 100644 --- a/examples/visualization/topo_compare_conditions.py +++ b/examples/visualization/topo_compare_conditions.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-topo-compare: diff --git a/examples/visualization/topo_customized.py b/examples/visualization/topo_customized.py index 6e864802ba5..cc284431246 100644 --- a/examples/visualization/topo_customized.py +++ b/examples/visualization/topo_customized.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-topo-custom: diff --git a/examples/visualization/xhemi.py b/examples/visualization/xhemi.py index c6bf6e5a961..bb5a4971d4d 100644 --- a/examples/visualization/xhemi.py +++ b/examples/visualization/xhemi.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-xhemi: diff --git a/logo/generate_mne_logos.py b/logo/generate_mne_logos.py index f9d62ddc581..072710182be 100644 --- a/logo/generate_mne_logos.py +++ b/logo/generate_mne_logos.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ =============================================================================== Script 'mne logo' diff --git a/mne/_freesurfer.py b/mne/_freesurfer.py index cf0211ec3e6..d92ac40e807 100644 --- a/mne/_freesurfer.py +++ b/mne/_freesurfer.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Freesurfer handling functions.""" # Authors: Alex Rockhill # Eric Larson diff --git a/mne/_ola.py b/mne/_ola.py index 68c24a79278..a4ecad26a66 100644 --- a/mne/_ola.py +++ b/mne/_ola.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # License: BSD-3-Clause diff --git a/mne/beamformer/resolution_matrix.py b/mne/beamformer/resolution_matrix.py index b0e7c450ed6..5294de5a621 100644 --- a/mne/beamformer/resolution_matrix.py +++ b/mne/beamformer/resolution_matrix.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Compute resolution matrix for beamformers.""" # Authors: olaf.hauk@mrc-cbu.cam.ac.uk # diff --git a/mne/beamformer/tests/test_resolution_matrix.py b/mne/beamformer/tests/test_resolution_matrix.py index 09ef7fe1118..d033eaf6b67 100755 --- a/mne/beamformer/tests/test_resolution_matrix.py +++ b/mne/beamformer/tests/test_resolution_matrix.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Olaf Hauk # # License: BSD-3-Clause diff --git a/mne/chpi.py b/mne/chpi.py index 096c4e6f2bc..cdfc9b558ae 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Functions for fitting head positions with (c)HPI coils.""" # Next, ``compute_head_pos`` can be used to: diff --git a/mne/commands/tests/test_commands.py b/mne/commands/tests/test_commands.py index dd786d3773a..995edae59b9 100644 --- a/mne/commands/tests/test_commands.py +++ b/mne/commands/tests/test_commands.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import glob import os from os import path as op 
diff --git a/mne/conftest.py b/mne/conftest.py index 0c08dab9a03..5da6406bfc3 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Eric Larson # # License: BSD-3-Clause diff --git a/mne/coreg.py b/mne/coreg.py index 3ea3576e1d1..db0b3645633 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Coregistration between different coordinate frames.""" # Authors: Christian Brodbeck diff --git a/mne/datasets/_fsaverage/base.py b/mne/datasets/_fsaverage/base.py index b22d50ae7ab..d4a8f3d82c0 100644 --- a/mne/datasets/_fsaverage/base.py +++ b/mne/datasets/_fsaverage/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # License: BSD Style. diff --git a/mne/datasets/_infant/base.py b/mne/datasets/_infant/base.py index fe9032dc44f..c327c4835e0 100644 --- a/mne/datasets/_infant/base.py +++ b/mne/datasets/_infant/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # License: BSD Style. diff --git a/mne/datasets/_phantom/base.py b/mne/datasets/_phantom/base.py index d420ca777e3..8785e3018ec 100644 --- a/mne/datasets/_phantom/base.py +++ b/mne/datasets/_phantom/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # License: BSD Style. diff --git a/mne/datasets/hf_sef/hf_sef.py b/mne/datasets/hf_sef/hf_sef.py index aa4cffa33d1..401c3636017 100644 --- a/mne/datasets/hf_sef/hf_sef.py +++ b/mne/datasets/hf_sef/hf_sef.py @@ -1,5 +1,4 @@ #!/usr/bin/env python2 -# -*- coding: utf-8 -*- # Authors: Jussi Nurminen # License: BSD Style. diff --git a/mne/datasets/sleep_physionet/_utils.py b/mne/datasets/sleep_physionet/_utils.py index 85e6088bfcb..0c2c0632857 100644 --- a/mne/datasets/sleep_physionet/_utils.py +++ b/mne/datasets/sleep_physionet/_utils.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Joan Massich # diff --git a/mne/datasets/sleep_physionet/age.py b/mne/datasets/sleep_physionet/age.py index 2f544035a24..4a0d8456639 100644 --- a/mne/datasets/sleep_physionet/age.py +++ b/mne/datasets/sleep_physionet/age.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Joan Massich # diff --git a/mne/datasets/sleep_physionet/temazepam.py b/mne/datasets/sleep_physionet/temazepam.py index f4981a7cc25..a18f126ab5f 100644 --- a/mne/datasets/sleep_physionet/temazepam.py +++ b/mne/datasets/sleep_physionet/temazepam.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Joan Massich # diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index c7362a0b5bf..6e3ed67c163 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Romain Trachel # Alexandre Gramfort # Alexandre Barachant diff --git a/mne/decoding/mixin.py b/mne/decoding/mixin.py index d38e9e4aff4..c000ae4b74d 100644 --- a/mne/decoding/mixin.py +++ b/mne/decoding/mixin.py @@ -1,5 +1,3 @@ - - class TransformerMixin: """Mixin class for all transformers in scikit-learn.""" diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index d344e93e668..cf6e6dd35bc 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Chris Holdgraf # Eric Larson diff --git a/mne/decoding/time_delaying_ridge.py b/mne/decoding/time_delaying_ridge.py index a1e2f426b18..2d3d13f1300 100644 --- a/mne/decoding/time_delaying_ridge.py +++ b/mne/decoding/time_delaying_ridge.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- 
"""TimeDelayingRidge class.""" # Authors: Eric Larson # Ross Maddox diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index 5e7734c292a..b6faf66cf97 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Mainak Jas # Alexandre Gramfort # Romain Trachel diff --git a/mne/dipole.py b/mne/dipole.py index c374fb2dca2..65fe90a39a3 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Single-dipole functions and classes.""" # Authors: Alexandre Gramfort diff --git a/mne/epochs.py b/mne/epochs.py index 319bb4ecdb4..ae0f6736564 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """Tools for working with epoched data.""" # Authors: Alexandre Gramfort diff --git a/mne/evoked.py b/mne/evoked.py index 1db3de6bcd7..07101f081d7 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Matti Hämäläinen # Denis Engemann diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index 319b2314864..91e0c08b94d 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: MNE Developers # # License: BSD-3-Clause diff --git a/mne/export/_edf.py b/mne/export/_edf.py index 752e8c81caa..8a6b1370470 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: MNE Developers # # License: BSD-3-Clause diff --git a/mne/export/_eeglab.py b/mne/export/_eeglab.py index 93556b056c5..00d566c13fe 100644 --- a/mne/export/_eeglab.py +++ b/mne/export/_eeglab.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: MNE Developers # # License: BSD-3-Clause diff --git a/mne/export/_egimff.py b/mne/export/_egimff.py index 6ad4ead73e7..65418d35d6c 100644 --- a/mne/export/_egimff.py +++ b/mne/export/_egimff.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: MNE Developers # # License: BSD-3-Clause diff --git a/mne/export/_export.py b/mne/export/_export.py index 1d2cc44a141..c26927d1755 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: MNE Developers # # License: BSD-3-Clause diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index aaf92c92cb2..69679e5a7cd 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Test exporting functions.""" # Authors: MNE Developers # diff --git a/mne/forward/_compute_forward.py b/mne/forward/_compute_forward.py index 7ceec0b22ad..9b4ee7dba1c 100644 --- a/mne/forward/_compute_forward.py +++ b/mne/forward/_compute_forward.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Matti Hämäläinen # Alexandre Gramfort # Martin Luessi diff --git a/mne/forward/_field_interpolation.py b/mne/forward/_field_interpolation.py index 2d2c0a6e615..fdc21ab8e9c 100644 --- a/mne/forward/_field_interpolation.py +++ b/mne/forward/_field_interpolation.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Matti Hämäläinen # Alexandre Gramfort # Eric Larson diff --git a/mne/gui/_core.py b/mne/gui/_core.py index 7929c838737..b40f16621b3 100644 --- a/mne/gui/_core.py +++ b/mne/gui/_core.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Shared GUI classes and functions.""" # Authors: Alex Rockhill diff --git a/mne/gui/_ieeg_locate.py b/mne/gui/_ieeg_locate.py index 59f3ea715de..a23590d7317 100644 --- 
a/mne/gui/_ieeg_locate.py +++ b/mne/gui/_ieeg_locate.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Intracranial elecrode localization GUI for finding contact locations.""" # Authors: Alex Rockhill diff --git a/mne/gui/tests/test_core.py b/mne/gui/tests/test_core.py index 7a5040903bc..013bca7eed5 100644 --- a/mne/gui/tests/test_core.py +++ b/mne/gui/tests/test_core.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alex Rockhill # # License: BSD-3-clause diff --git a/mne/gui/tests/test_ieeg_locate.py b/mne/gui/tests/test_ieeg_locate.py index 2463d171d54..7fb2c544066 100644 --- a/mne/gui/tests/test_ieeg_locate.py +++ b/mne/gui/tests/test_ieeg_locate.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alex Rockhill # # License: BSD-3-clause diff --git a/mne/io/_digitization.py b/mne/io/_digitization.py index 37fb471ac68..30a07c19b46 100644 --- a/mne/io/_digitization.py +++ b/mne/io/_digitization.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Matti Hämäläinen # Teon Brooks diff --git a/mne/io/artemis123/tests/test_artemis123.py b/mne/io/artemis123/tests/test_artemis123.py index aa38fac5b53..21b7204b775 100644 --- a/mne/io/artemis123/tests/test_artemis123.py +++ b/mne/io/artemis123/tests/test_artemis123.py @@ -1,4 +1,3 @@ - # Author: Luke Bloy # # License: BSD-3-Clause diff --git a/mne/io/base.py b/mne/io/base.py index 0dcc1b9eee4..05645de8cf3 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Matti Hämäläinen # Martin Luessi diff --git a/mne/io/brainvision/brainvision.py b/mne/io/brainvision/brainvision.py index d536153e5ac..892f189fca2 100644 --- a/mne/io/brainvision/brainvision.py +++ b/mne/io/brainvision/brainvision.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Conversion tool from BrainVision EEG to FIF.""" # Authors: Teon Brooks # Christian Brodbeck diff --git a/mne/io/brainvision/tests/test_brainvision.py b/mne/io/brainvision/tests/test_brainvision.py index b43287eabe3..c9c375d086d 100644 --- a/mne/io/brainvision/tests/test_brainvision.py +++ b/mne/io/brainvision/tests/test_brainvision.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Test reading of BrainVision format.""" # Author: Teon Brooks # Stefan Appelhoff diff --git a/mne/io/cnt/tests/test_cnt.py b/mne/io/cnt/tests/test_cnt.py index db0d1d695aa..ac37c7fe38e 100644 --- a/mne/io/cnt/tests/test_cnt.py +++ b/mne/io/cnt/tests/test_cnt.py @@ -1,4 +1,3 @@ - # Author: Jaakko Leppakangas # Joan Massich # diff --git a/mne/io/curry/tests/test_curry.py b/mne/io/curry/tests/test_curry.py index 2f9c8c4d141..2c53a0032a1 100644 --- a/mne/io/curry/tests/test_curry.py +++ b/mne/io/curry/tests/test_curry.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Authors: Dirk Gütlin # diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index b3a7b38da36..c307216bdcc 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Reading tools from EDF, EDF+, BDF, and GDF.""" # Authors: Teon Brooks diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index bae740ba7ef..597cae5eee1 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Teon Brooks # Martin Billinger # Alan Leggitt diff --git a/mne/io/egi/events.py b/mne/io/egi/events.py index 9e4967d115a..196a6ea717a 100644 --- a/mne/io/egi/events.py +++ b/mne/io/egi/events.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # License: BSD-3-Clause diff --git a/mne/io/egi/general.py 
b/mne/io/egi/general.py index b8212ed3d56..a1de880efc6 100644 --- a/mne/io/egi/general.py +++ b/mne/io/egi/general.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # License: BSD-3-Clause diff --git a/mne/io/egi/tests/test_egi.py b/mne/io/egi/tests/test_egi.py index 2f543321ce5..45b0ca1109e 100644 --- a/mne/io/egi/tests/test_egi.py +++ b/mne/io/egi/tests/test_egi.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Denis A. Engemann # simplified BSD-3 license diff --git a/mne/io/eximia/tests/test_eximia.py b/mne/io/eximia/tests/test_eximia.py index ea627d10eb0..73f06cd9106 100644 --- a/mne/io/eximia/tests/test_eximia.py +++ b/mne/io/eximia/tests/test_eximia.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Federico Raimondo # simplified BSD-3 license diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index b274602cf23..4e913cd587b 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Alexandre Gramfort # Denis Engemann # diff --git a/mne/io/hitachi/tests/test_hitachi.py b/mne/io/hitachi/tests/test_hitachi.py index 802b46063bc..d04218b1eb0 100644 --- a/mne/io/hitachi/tests/test_hitachi.py +++ b/mne/io/hitachi/tests/test_hitachi.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # # License: BSD-3-Clause diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index f8c9eba13cc..bc22187b9a3 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Matti Hämäläinen # Teon Brooks diff --git a/mne/io/nedf/nedf.py b/mne/io/nedf/nedf.py index 55c11c2c244..78ae0106b4e 100644 --- a/mne/io/nedf/nedf.py +++ b/mne/io/nedf/nedf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Import NeuroElectrics DataFormat (NEDF) files.""" from copy import deepcopy diff --git a/mne/io/nedf/tests/test_nedf.py b/mne/io/nedf/tests/test_nedf.py index 9f00cda7a04..404dd7af342 100644 --- a/mne/io/nedf/tests/test_nedf.py +++ b/mne/io/nedf/tests/test_nedf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Test reading of NEDF format.""" # Author: Tristan Stenner # diff --git a/mne/io/nicolet/tests/test_nicolet.py b/mne/io/nicolet/tests/test_nicolet.py index 7e13bc5a497..670597e7b75 100644 --- a/mne/io/nicolet/tests/test_nicolet.py +++ b/mne/io/nicolet/tests/test_nicolet.py @@ -1,4 +1,3 @@ - # Author: Jaakko Leppakangas # # License: BSD-3-Clause diff --git a/mne/io/nihon/tests/test_nihon.py b/mne/io/nihon/tests/test_nihon.py index 1f461fa639f..9a497b7cfd2 100644 --- a/mne/io/nihon/tests/test_nihon.py +++ b/mne/io/nihon/tests/test_nihon.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Federico Raimondo # simplified BSD-3 license diff --git a/mne/io/nirx/tests/test_nirx.py b/mne/io/nirx/tests/test_nirx.py index 79b348b1ab5..de7a78b0cca 100644 --- a/mne/io/nirx/tests/test_nirx.py +++ b/mne/io/nirx/tests/test_nirx.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Robert Luke # Eric Larson # simplified BSD-3 license diff --git a/mne/io/open.py b/mne/io/open.py index e95a3e957c4..d2c94accd53 100644 --- a/mne/io/open.py +++ b/mne/io/open.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Matti Hämäläinen # diff --git a/mne/io/persyst/tests/test_persyst.py b/mne/io/persyst/tests/test_persyst.py index c907fae42df..4d11c728398 100644 --- a/mne/io/persyst/tests/test_persyst.py +++ b/mne/io/persyst/tests/test_persyst.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Adam Li # # License: 
BSD-3-Clause diff --git a/mne/io/pick.py b/mne/io/pick.py index d71971155d1..dc914c53ffd 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Matti Hämäläinen # Martin Luessi diff --git a/mne/io/proc_history.py b/mne/io/proc_history.py index 21b1018ff34..7209f9f7ca7 100644 --- a/mne/io/proc_history.py +++ b/mne/io/proc_history.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Denis A. Engemann # Eric Larson # License: Simplified BSD diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index fc4572f3e3f..b9475d69583 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Robert Luke # simplified BSD-3 license diff --git a/mne/io/tests/test_meas_info.py b/mne/io/tests/test_meas_info.py index 67b81b82f7e..e16c8b5f41b 100644 --- a/mne/io/tests/test_meas_info.py +++ b/mne/io/tests/test_meas_info.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Authors: MNE Developers # Stefan Appelhoff # diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index 586dbdcdd56..4b16c18a6a0 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Generic tests that all raw classes should run.""" # Authors: MNE Developers # Stefan Appelhoff diff --git a/mne/io/tests/test_show_fiff.py b/mne/io/tests/test_show_fiff.py index f25c6c04cac..52beb9cdbed 100644 --- a/mne/io/tests/test_show_fiff.py +++ b/mne/io/tests/test_show_fiff.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Eric Larson # # License: BSD-3-Clause diff --git a/mne/io/tests/test_utils.py b/mne/io/tests/test_utils.py index 6bfe5ae933d..601a9df4e9c 100644 --- a/mne/io/tests/test_utils.py +++ b/mne/io/tests/test_utils.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Run tests for the utilities.""" # Author: Stefan Appelhoff # diff --git a/mne/io/tests/test_write.py b/mne/io/tests/test_write.py index 2a67566c61f..a86e47d175c 100644 --- a/mne/io/tests/test_write.py +++ b/mne/io/tests/test_write.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Run tests for writing.""" # Author: Eric Larson # diff --git a/mne/io/utils.py b/mne/io/utils.py index 8520490ef9e..f9d01d9bae4 100644 --- a/mne/io/utils.py +++ b/mne/io/utils.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Matti Hämäläinen # Martin Luessi diff --git a/mne/io/what.py b/mne/io/what.py index 0d4f5d2297a..fda10db46b0 100644 --- a/mne/io/what.py +++ b/mne/io/what.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # # License: BSD-3-Clause diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index 468878fab3b..42d58d0173a 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Matti Hämäläinen # Teon Brooks diff --git a/mne/minimum_norm/resolution_matrix.py b/mne/minimum_norm/resolution_matrix.py index 5457164613b..2b013f39e96 100644 --- a/mne/minimum_norm/resolution_matrix.py +++ b/mne/minimum_norm/resolution_matrix.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Compute resolution matrix for linear estimators.""" # Authors: olaf.hauk@mrc-cbu.cam.ac.uk # diff --git a/mne/minimum_norm/spatial_resolution.py b/mne/minimum_norm/spatial_resolution.py index 9d8752c3ef5..bf6e4173797 100644 --- a/mne/minimum_norm/spatial_resolution.py +++ b/mne/minimum_norm/spatial_resolution.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # 
Authors: Olaf Hauk # # License: BSD-3-Clause diff --git a/mne/minimum_norm/tests/test_resolution_matrix.py b/mne/minimum_norm/tests/test_resolution_matrix.py index e8792161b4f..a1f39d493a6 100644 --- a/mne/minimum_norm/tests/test_resolution_matrix.py +++ b/mne/minimum_norm/tests/test_resolution_matrix.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Olaf Hauk # Alexandre Gramfort # Eric Larson diff --git a/mne/minimum_norm/tests/test_resolution_metrics.py b/mne/minimum_norm/tests/test_resolution_metrics.py index bfa437ea1ba..3d198daab57 100644 --- a/mne/minimum_norm/tests/test_resolution_metrics.py +++ b/mne/minimum_norm/tests/test_resolution_metrics.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Olaf Hauk # Daniel McCloy # diff --git a/mne/minimum_norm/tests/test_snr.py b/mne/minimum_norm/tests/test_snr.py index 060d608eeea..ad9d83d2cf6 100644 --- a/mne/minimum_norm/tests/test_snr.py +++ b/mne/minimum_norm/tests/test_snr.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # Matti Hämäläinen # diff --git a/mne/preprocessing/_fine_cal.py b/mne/preprocessing/_fine_cal.py index 109be5f854f..fb0ed474938 100644 --- a/mne/preprocessing/_fine_cal.py +++ b/mne/preprocessing/_fine_cal.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # License: BSD-3-Clause diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 06e16383c15..d6432755b97 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Authors: Denis A. Engemann # Alexandre Gramfort diff --git a/mne/preprocessing/ieeg/tests/test_projection.py b/mne/preprocessing/ieeg/tests/test_projection.py index 7dd7fd21d03..fbc8570782f 100644 --- a/mne/preprocessing/ieeg/tests/test_projection.py +++ b/mne/preprocessing/ieeg/tests/test_projection.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Test the ieeg projection functions.""" # Authors: Alex Rockhill # diff --git a/mne/preprocessing/ieeg/tests/test_volume.py b/mne/preprocessing/ieeg/tests/test_volume.py index e20c372742a..d4f50610ea4 100644 --- a/mne/preprocessing/ieeg/tests/test_volume.py +++ b/mne/preprocessing/ieeg/tests/test_volume.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Test ieeg volume functions.""" # Authors: Alex Rockhill # diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index 0251a372a53..d270d716e5c 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Mark Wronkiewicz # Eric Larson # Jussi Nurminen diff --git a/mne/preprocessing/otp.py b/mne/preprocessing/otp.py index b1fa43a16ce..6f7be39a387 100644 --- a/mne/preprocessing/otp.py +++ b/mne/preprocessing/otp.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Samu Taulu # Eric Larson diff --git a/mne/preprocessing/realign.py b/mne/preprocessing/realign.py index be510bb5521..a9faa763d71 100644 --- a/mne/preprocessing/realign.py +++ b/mne/preprocessing/realign.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # License: BSD-3-Clause diff --git a/mne/preprocessing/tests/test_csd.py b/mne/preprocessing/tests/test_csd.py index 400df01c905..d5699c8b9fa 100644 --- a/mne/preprocessing/tests/test_csd.py +++ b/mne/preprocessing/tests/test_csd.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Test the current source density and related functions. For each supported file format, implement a test. 
diff --git a/mne/rank.py b/mne/rank.py index cdb07e280bc..20fb43ea90a 100644 --- a/mne/rank.py +++ b/mne/rank.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Some utility functions for rank estimation.""" # Authors: Alexandre Gramfort # diff --git a/mne/report/tests/test_report.py b/mne/report/tests/test_report.py index 4228715a7dc..a71c8fdb9f7 100644 --- a/mne/report/tests/test_report.py +++ b/mne/report/tests/test_report.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Mainak Jas # Teon Brooks # diff --git a/mne/simulation/raw.py b/mne/simulation/raw.py index 7abd33bd1ee..99a5f87e71f 100644 --- a/mne/simulation/raw.py +++ b/mne/simulation/raw.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Mark Wronkiewicz # Yousra Bekhti # Eric Larson diff --git a/mne/stats/_adjacency.py b/mne/stats/_adjacency.py index a59450505d6..d3f54525e4b 100644 --- a/mne/stats/_adjacency.py +++ b/mne/stats/_adjacency.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - # Authors: Eric Larson # Stefan Appelhoff # diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py index 1b90b15ee79..1a6c12c092c 100644 --- a/mne/stats/cluster_level.py +++ b/mne/stats/cluster_level.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # Authors: Thorsten Kranz # Alexandre Gramfort diff --git a/mne/stats/tests/test_adjacency.py b/mne/stats/tests/test_adjacency.py index 5cf0b04e6b5..837c1a60bbf 100644 --- a/mne/stats/tests/test_adjacency.py +++ b/mne/stats/tests/test_adjacency.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - # Authors: Eric Larson # # License: Simplified BSD diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py index c04b6ed862e..f92752559b2 100644 --- a/mne/tests/test_docstring_parameters.py +++ b/mne/tests/test_docstring_parameters.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Eric Larson # # License: BSD-3-Clause diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 52445dc64ec..3f2fce63bb3 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Alexandre Gramfort # Denis Engemann # Stefan Appelhoff diff --git a/mne/tests/test_event.py b/mne/tests/test_event.py index 63500362c1b..cb40c670849 100644 --- a/mne/tests/test_event.py +++ b/mne/tests/test_event.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Alexandre Gramfort # Eric Larson # diff --git a/mne/tests/test_morph.py b/mne/tests/test_morph.py index 00f1bd7d1fb..011936c7997 100644 --- a/mne/tests/test_morph.py +++ b/mne/tests/test_morph.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Tommy Clausner # # License: BSD-3-Clause diff --git a/mne/tests/test_parallel.py b/mne/tests/test_parallel.py index aabaa32f223..8a5e1b35f56 100644 --- a/mne/tests/test_parallel.py +++ b/mne/tests/test_parallel.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Eric Larson # # License: BSD-3-Clause diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index 7b26c416c94..ec6da53cf56 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # License: BSD-3-Clause diff --git a/mne/tests/test_source_space.py b/mne/tests/test_source_space.py index 030474fbf51..364e250284a 100644 --- a/mne/tests/test_source_space.py +++ b/mne/tests/test_source_space.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alexandre Gramfort # Eric Larson # @@ -940,7 +939,6 @@ def test_get_decimated_surfaces(src, n, nv): # Unfortunately 
the C code bombs when trying to add source space distances, # possibly due to incomplete "faking" of a smaller surface on our part here. """ -# -*- coding: utf-8 -*- import os import numpy as np diff --git a/mne/tests/test_transforms.py b/mne/tests/test_transforms.py index c01894a51e0..ef2760b6728 100644 --- a/mne/tests/test_transforms.py +++ b/mne/tests/test_transforms.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Alexandre Gramfort # # License: BSD-3-Clause diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index d8680df0047..54b162e6eaf 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Marijn van Vliet # Susanna Aro # Roman Goj diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index b2172e67b79..d4bbf98be3b 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Container classes for spectral data.""" # Authors: Dan McCloy diff --git a/mne/time_frequency/tests/test_multitaper.py b/mne/time_frequency/tests/test_multitaper.py index 592c07adb7c..ec0941b1826 100644 --- a/mne/time_frequency/tests/test_multitaper.py +++ b/mne/time_frequency/tests/test_multitaper.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import numpy as np import pytest from numpy.testing import assert_array_almost_equal diff --git a/mne/transforms.py b/mne/transforms.py index 357c299a04b..c896270797c 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Helpers for various transformations.""" # Authors: Alexandre Gramfort diff --git a/mne/utils/_bunch.py b/mne/utils/_bunch.py index 13c6c6f1e02..2f7dad6c1b0 100644 --- a/mne/utils/_bunch.py +++ b/mne/utils/_bunch.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Bunch-related classes.""" # Authors: Alexandre Gramfort # Eric Larson diff --git a/mne/utils/_logging.py b/mne/utils/_logging.py index 33f33d72c8b..62546e48842 100644 --- a/mne/utils/_logging.py +++ b/mne/utils/_logging.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Some utility functions.""" # Authors: Alexandre Gramfort # diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 1417b9c5c13..365983debf8 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Testing functions.""" # Authors: Alexandre Gramfort # diff --git a/mne/utils/check.py b/mne/utils/check.py index 4ed7ed38566..cb4459e9e26 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """The check functions.""" # Authors: Alexandre Gramfort # diff --git a/mne/utils/config.py b/mne/utils/config.py index 1ee8417a1f3..8dd03dc22ac 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """The config functions.""" # Authors: Eric Larson # diff --git a/mne/utils/dataframe.py b/mne/utils/dataframe.py index 69e05a5f451..a2bd1ea814f 100644 --- a/mne/utils/dataframe.py +++ b/mne/utils/dataframe.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """inst.to_data_frame() helper functions.""" # Authors: Daniel McCloy # diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 96398425ad4..a447eda863e 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """The documentation functions.""" # Authors: Eric Larson # diff --git a/mne/utils/fetching.py b/mne/utils/fetching.py index 0cc59551f51..db8be08a806 100644 --- a/mne/utils/fetching.py +++ 
b/mne/utils/fetching.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """File downloading functions.""" # Authors: Alexandre Gramfort # diff --git a/mne/utils/linalg.py b/mne/utils/linalg.py index 137774e77b9..78d4c8a68c9 100644 --- a/mne/utils/linalg.py +++ b/mne/utils/linalg.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Utility functions to speed up linear algebraic operations. In general, things like np.dot and linalg.svd should be used directly diff --git a/mne/utils/misc.py b/mne/utils/misc.py index e8b88f91e68..723a8564eeb 100644 --- a/mne/utils/misc.py +++ b/mne/utils/misc.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Some miscellaneous utility functions.""" # Authors: Alexandre Gramfort # diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index 4828129b64e..9cc0a735b10 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Some utility functions.""" # Authors: Alexandre Gramfort # diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index a6c4a7fa734..6e18b487bdd 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Some utility functions.""" # Authors: Alexandre Gramfort # Clemens Brunner diff --git a/mne/utils/progressbar.py b/mne/utils/progressbar.py index 20a14e3b169..8c6be37c70a 100644 --- a/mne/utils/progressbar.py +++ b/mne/utils/progressbar.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Some utility functions.""" # Authors: Alexandre Gramfort # diff --git a/mne/utils/tests/test_bunch.py b/mne/utils/tests/test_bunch.py index 3a6e71d6325..e69bd3d1a67 100644 --- a/mne/utils/tests/test_bunch.py +++ b/mne/utils/tests/test_bunch.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Clemens Brunner # Eric Larson # diff --git a/mne/utils/tests/test_progressbar.py b/mne/utils/tests/test_progressbar.py index 06f78b11a91..64f039d4725 100644 --- a/mne/utils/tests/test_progressbar.py +++ b/mne/utils/tests/test_progressbar.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # # License: BSD-3-Clause diff --git a/mne/utils/tests/test_testing.py b/mne/utils/tests/test_testing.py index 720f2c7efe4..9e65037208a 100644 --- a/mne/utils/tests/test_testing.py +++ b/mne/utils/tests/test_testing.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Eric Larson # # License: BSD-3-Clause diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 72d583f0995..3ed77c1c5e7 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Functions to make 3D plots with M/EEG data.""" # Authors: Alexandre Gramfort diff --git a/mne/viz/_3d_overlay.py b/mne/viz/_3d_overlay.py index 12c36c4ec73..03511eafb33 100644 --- a/mne/viz/_3d_overlay.py +++ b/mne/viz/_3d_overlay.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Classes to handle overlapping surfaces.""" # Authors: Guillaume Favelier diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 230b99a98d9..2864d7ecbe3 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Authors: Alexandre Gramfort # Eric Larson diff --git a/mne/viz/_brain/tests/test_notebook.py b/mne/viz/_brain/tests/test_notebook.py index 7560f4fc8b7..95029cc0459 100644 --- a/mne/viz/_brain/tests/test_notebook.py +++ b/mne/viz/_brain/tests/test_notebook.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Authors: Guillaume Favelier # Eric Larson diff --git a/mne/viz/_dipole.py b/mne/viz/_dipole.py index 40d90cade5b..cb5bdb4622b 100644 --- 
a/mne/viz/_dipole.py +++ b/mne/viz/_dipole.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Dipole viz specific functions.""" # Authors: Eric Larson diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index 87474d65b0b..c4d1bd3e0fc 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Base classes and functions for 2D browser backends.""" # Authors: Daniel McCloy diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index ab5f4e76c67..eb4a22d7671 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Figure classes for MNE-Python's 2D plots. Class Hierarchy diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index c9521690fa0..be4a16e50ab 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Authors: Alexandre Gramfort # Eric Larson diff --git a/mne/viz/backends/tests/test_abstract.py b/mne/viz/backends/tests/test_abstract.py index 967084f2f0c..987f794e7ae 100644 --- a/mne/viz/backends/tests/test_abstract.py +++ b/mne/viz/backends/tests/test_abstract.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Authors: Alex Rockhill # # License: Simplified BSD diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 8c9cc4564b4..bf67802626b 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Functions to plot evoked M/EEG data (besides topographies).""" # Authors: Alexandre Gramfort diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 2a14ba50f04..08f7403df9a 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Functions to make simple plots with M/EEG data.""" # Authors: Alexandre Gramfort diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 6a6c353b321..aff4e8a7343 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Utility functions for plotting M/EEG data.""" # Authors: Alexandre Gramfort diff --git a/tutorials/clinical/10_ieeg_localize.py b/tutorials/clinical/10_ieeg_localize.py index 2bd21834574..16f793a5e4b 100644 --- a/tutorials/clinical/10_ieeg_localize.py +++ b/tutorials/clinical/10_ieeg_localize.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-ieeg-localize: diff --git a/tutorials/clinical/20_seeg.py b/tutorials/clinical/20_seeg.py index 14f803d83bc..dcaafed58f8 100644 --- a/tutorials/clinical/20_seeg.py +++ b/tutorials/clinical/20_seeg.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-working-with-seeg: diff --git a/tutorials/clinical/30_ecog.py b/tutorials/clinical/30_ecog.py index 714474a81c0..bebe7b6e67d 100644 --- a/tutorials/clinical/30_ecog.py +++ b/tutorials/clinical/30_ecog.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-working-with-ecog: diff --git a/tutorials/clinical/60_sleep.py b/tutorials/clinical/60_sleep.py index 2877ab8546a..05f01645200 100644 --- a/tutorials/clinical/60_sleep.py +++ b/tutorials/clinical/60_sleep.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-sleep-stage-classif: diff --git a/tutorials/epochs/10_epochs_overview.py b/tutorials/epochs/10_epochs_overview.py index eefd617be0a..92e8f03ebac 100644 --- a/tutorials/epochs/10_epochs_overview.py +++ b/tutorials/epochs/10_epochs_overview.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_tut-epochs-class: diff --git a/tutorials/epochs/15_baseline_regression.py b/tutorials/epochs/15_baseline_regression.py index cc2aaa5e45f..afb3d286c7a 100644 --- a/tutorials/epochs/15_baseline_regression.py +++ b/tutorials/epochs/15_baseline_regression.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _ex-baseline-regression: diff --git a/tutorials/epochs/20_visualize_epochs.py b/tutorials/epochs/20_visualize_epochs.py index 32046990f7c..1fe1459b294 100644 --- a/tutorials/epochs/20_visualize_epochs.py +++ b/tutorials/epochs/20_visualize_epochs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-visualize-epochs: diff --git a/tutorials/epochs/30_epochs_metadata.py b/tutorials/epochs/30_epochs_metadata.py index d3ef76a0f6e..4f1c0c55638 100644 --- a/tutorials/epochs/30_epochs_metadata.py +++ b/tutorials/epochs/30_epochs_metadata.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-epochs-metadata: diff --git a/tutorials/epochs/40_autogenerate_metadata.py b/tutorials/epochs/40_autogenerate_metadata.py index 9e07769b6b1..d37ab5ba997 100644 --- a/tutorials/epochs/40_autogenerate_metadata.py +++ b/tutorials/epochs/40_autogenerate_metadata.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-autogenerate-metadata: diff --git a/tutorials/epochs/50_epochs_to_data_frame.py b/tutorials/epochs/50_epochs_to_data_frame.py index 0562e58bc83..9253c9ed9fd 100644 --- a/tutorials/epochs/50_epochs_to_data_frame.py +++ b/tutorials/epochs/50_epochs_to_data_frame.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-epochs-dataframe: diff --git a/tutorials/epochs/60_make_fixed_length_epochs.py b/tutorials/epochs/60_make_fixed_length_epochs.py index 34a90a8d75c..fdffdfaa3ca 100644 --- a/tutorials/epochs/60_make_fixed_length_epochs.py +++ b/tutorials/epochs/60_make_fixed_length_epochs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-fixed-length-epochs: diff --git a/tutorials/evoked/10_evoked_overview.py b/tutorials/evoked/10_evoked_overview.py index 50e29d849be..35c2bdb7f66 100644 --- a/tutorials/evoked/10_evoked_overview.py +++ b/tutorials/evoked/10_evoked_overview.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-evoked-class: diff --git a/tutorials/evoked/20_visualize_evoked.py b/tutorials/evoked/20_visualize_evoked.py index da08320e975..14864713c4e 100644 --- a/tutorials/evoked/20_visualize_evoked.py +++ b/tutorials/evoked/20_visualize_evoked.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-visualize-evoked: diff --git a/tutorials/evoked/30_eeg_erp.py b/tutorials/evoked/30_eeg_erp.py index 397908a2e81..0271ebe0037 100644 --- a/tutorials/evoked/30_eeg_erp.py +++ b/tutorials/evoked/30_eeg_erp.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-erp: diff --git a/tutorials/evoked/40_whitened.py b/tutorials/evoked/40_whitened.py index a214ace7dd1..eb701528934 100644 --- a/tutorials/evoked/40_whitened.py +++ b/tutorials/evoked/40_whitened.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-whitened: diff --git a/tutorials/forward/10_background_freesurfer.py b/tutorials/forward/10_background_freesurfer.py index bd002a27082..d21a0e3e7a3 100644 --- a/tutorials/forward/10_background_freesurfer.py +++ b/tutorials/forward/10_background_freesurfer.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_tut-freesurfer-reconstruction: diff --git a/tutorials/forward/20_source_alignment.py b/tutorials/forward/20_source_alignment.py index 908312fe953..582e5a3c084 100644 --- a/tutorials/forward/20_source_alignment.py +++ b/tutorials/forward/20_source_alignment.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-source-alignment: diff --git a/tutorials/forward/25_automated_coreg.py b/tutorials/forward/25_automated_coreg.py index 4d726adba13..e2653cfcd8e 100644 --- a/tutorials/forward/25_automated_coreg.py +++ b/tutorials/forward/25_automated_coreg.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-auto-coreg: diff --git a/tutorials/forward/30_forward.py b/tutorials/forward/30_forward.py index 6e22dd575d0..3fb5310f55f 100644 --- a/tutorials/forward/30_forward.py +++ b/tutorials/forward/30_forward.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-forward: diff --git a/tutorials/forward/35_eeg_no_mri.py b/tutorials/forward/35_eeg_no_mri.py index 5719f1e47db..be8163b5cf8 100644 --- a/tutorials/forward/35_eeg_no_mri.py +++ b/tutorials/forward/35_eeg_no_mri.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-eeg-fsaverage-source-modeling: diff --git a/tutorials/forward/50_background_freesurfer_mne.py b/tutorials/forward/50_background_freesurfer_mne.py index b4423a5fde0..a204272b57f 100644 --- a/tutorials/forward/50_background_freesurfer_mne.py +++ b/tutorials/forward/50_background_freesurfer_mne.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-freesurfer-mne: diff --git a/tutorials/forward/80_fix_bem_in_blender.py b/tutorials/forward/80_fix_bem_in_blender.py index 65fa89de72b..3e57bd12bbf 100644 --- a/tutorials/forward/80_fix_bem_in_blender.py +++ b/tutorials/forward/80_fix_bem_in_blender.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-fix-meshes: diff --git a/tutorials/forward/90_compute_covariance.py b/tutorials/forward/90_compute_covariance.py index b113a97a787..2a538c9f14b 100644 --- a/tutorials/forward/90_compute_covariance.py +++ b/tutorials/forward/90_compute_covariance.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-compute-covariance: diff --git a/tutorials/intro/10_overview.py b/tutorials/intro/10_overview.py index 11f4bfc04ca..3a7be707d78 100644 --- a/tutorials/intro/10_overview.py +++ b/tutorials/intro/10_overview.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-overview: diff --git a/tutorials/intro/15_inplace.py b/tutorials/intro/15_inplace.py index 22c52574dbb..52d79d95f11 100644 --- a/tutorials/intro/15_inplace.py +++ b/tutorials/intro/15_inplace.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-inplace: diff --git a/tutorials/intro/20_events_from_raw.py b/tutorials/intro/20_events_from_raw.py index 31e94ecb085..7b43830917f 100644 --- a/tutorials/intro/20_events_from_raw.py +++ b/tutorials/intro/20_events_from_raw.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-events-vs-annotations: diff --git a/tutorials/intro/30_info.py b/tutorials/intro/30_info.py index 683c8c228ed..0fa4326a692 100644 --- a/tutorials/intro/30_info.py +++ b/tutorials/intro/30_info.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-info-class: diff --git a/tutorials/intro/40_sensor_locations.py b/tutorials/intro/40_sensor_locations.py index c3812efec7e..55d60e8a06c 100644 --- a/tutorials/intro/40_sensor_locations.py +++ b/tutorials/intro/40_sensor_locations.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_tut-sensor-locations: diff --git a/tutorials/intro/50_configure_mne.py b/tutorials/intro/50_configure_mne.py index 31e6947888d..e12a6a97c69 100644 --- a/tutorials/intro/50_configure_mne.py +++ b/tutorials/intro/50_configure_mne.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-configure-mne: diff --git a/tutorials/intro/70_report.py b/tutorials/intro/70_report.py index c399fcaca3e..7b8ee2fc1f6 100644 --- a/tutorials/intro/70_report.py +++ b/tutorials/intro/70_report.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-report: diff --git a/tutorials/inverse/10_stc_class.py b/tutorials/inverse/10_stc_class.py index 672efa2c539..3f659427f65 100644 --- a/tutorials/inverse/10_stc_class.py +++ b/tutorials/inverse/10_stc_class.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-source-estimate-class: diff --git a/tutorials/inverse/20_dipole_fit.py b/tutorials/inverse/20_dipole_fit.py index 6a00a3a68ba..9d64cada306 100644 --- a/tutorials/inverse/20_dipole_fit.py +++ b/tutorials/inverse/20_dipole_fit.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-ecd-dipole: diff --git a/tutorials/inverse/30_mne_dspm_loreta.py b/tutorials/inverse/30_mne_dspm_loreta.py index c9438b2e756..7e70dcf9497 100644 --- a/tutorials/inverse/30_mne_dspm_loreta.py +++ b/tutorials/inverse/30_mne_dspm_loreta.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-inverse-methods: diff --git a/tutorials/inverse/35_dipole_orientations.py b/tutorials/inverse/35_dipole_orientations.py index 86c3f142e17..efe229bb5b8 100644 --- a/tutorials/inverse/35_dipole_orientations.py +++ b/tutorials/inverse/35_dipole_orientations.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-dipole-orientations: diff --git a/tutorials/inverse/40_mne_fixed_free.py b/tutorials/inverse/40_mne_fixed_free.py index 40877002f83..5a7b0f113e3 100644 --- a/tutorials/inverse/40_mne_fixed_free.py +++ b/tutorials/inverse/40_mne_fixed_free.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-mne-fixed-free: diff --git a/tutorials/inverse/50_beamformer_lcmv.py b/tutorials/inverse/50_beamformer_lcmv.py index 6b2dd21abfd..26886848117 100644 --- a/tutorials/inverse/50_beamformer_lcmv.py +++ b/tutorials/inverse/50_beamformer_lcmv.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-lcmv-beamformer: diff --git a/tutorials/inverse/60_visualize_stc.py b/tutorials/inverse/60_visualize_stc.py index 478a03d7343..e30e48b9ed5 100644 --- a/tutorials/inverse/60_visualize_stc.py +++ b/tutorials/inverse/60_visualize_stc.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-viz-stcs: diff --git a/tutorials/inverse/70_eeg_mri_coords.py b/tutorials/inverse/70_eeg_mri_coords.py index 6a091c08162..3cd4a3dd924 100644 --- a/tutorials/inverse/70_eeg_mri_coords.py +++ b/tutorials/inverse/70_eeg_mri_coords.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-eeg-mri-coords: diff --git a/tutorials/inverse/80_brainstorm_phantom_elekta.py b/tutorials/inverse/80_brainstorm_phantom_elekta.py index 1ca67ed2e00..4cad68431eb 100644 --- a/tutorials/inverse/80_brainstorm_phantom_elekta.py +++ b/tutorials/inverse/80_brainstorm_phantom_elekta.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-brainstorm-elekta-phantom: diff --git a/tutorials/inverse/85_brainstorm_phantom_ctf.py b/tutorials/inverse/85_brainstorm_phantom_ctf.py index fbb1fde6c80..362878bc9a4 100644 --- a/tutorials/inverse/85_brainstorm_phantom_ctf.py +++ b/tutorials/inverse/85_brainstorm_phantom_ctf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_plot_brainstorm_phantom_ctf: diff --git a/tutorials/inverse/90_phantom_4DBTi.py b/tutorials/inverse/90_phantom_4DBTi.py index c379ffe99fe..1fa9db12331 100644 --- a/tutorials/inverse/90_phantom_4DBTi.py +++ b/tutorials/inverse/90_phantom_4DBTi.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-phantom-4Dbti: diff --git a/tutorials/io/10_reading_meg_data.py b/tutorials/io/10_reading_meg_data.py index 10f306949c0..18fd458a45e 100644 --- a/tutorials/io/10_reading_meg_data.py +++ b/tutorials/io/10_reading_meg_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- r""" .. _tut-imorting-meg-data: diff --git a/tutorials/io/20_reading_eeg_data.py b/tutorials/io/20_reading_eeg_data.py index dbb2a69010b..dd373424667 100644 --- a/tutorials/io/20_reading_eeg_data.py +++ b/tutorials/io/20_reading_eeg_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- r""" .. _tut-imorting-eeg-data: diff --git a/tutorials/io/30_reading_fnirs_data.py b/tutorials/io/30_reading_fnirs_data.py index 0362fb6c8e3..31036726ed7 100644 --- a/tutorials/io/30_reading_fnirs_data.py +++ b/tutorials/io/30_reading_fnirs_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- r""" .. _tut-importing-fnirs-data: diff --git a/tutorials/io/60_ctf_bst_auditory.py b/tutorials/io/60_ctf_bst_auditory.py index 62b9619581e..abef8b4394d 100644 --- a/tutorials/io/60_ctf_bst_auditory.py +++ b/tutorials/io/60_ctf_bst_auditory.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-brainstorm-auditory: diff --git a/tutorials/machine-learning/30_strf.py b/tutorials/machine-learning/30_strf.py index 6620ede3278..e5ca218ffa3 100644 --- a/tutorials/machine-learning/30_strf.py +++ b/tutorials/machine-learning/30_strf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-strf: diff --git a/tutorials/machine-learning/50_decoding.py b/tutorials/machine-learning/50_decoding.py index 3fb0036a21f..1b61045e7d1 100644 --- a/tutorials/machine-learning/50_decoding.py +++ b/tutorials/machine-learning/50_decoding.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- r""" .. _tut-mvpa: diff --git a/tutorials/preprocessing/10_preprocessing_overview.py b/tutorials/preprocessing/10_preprocessing_overview.py index 84dcf0e95eb..c07a05cc3e9 100644 --- a/tutorials/preprocessing/10_preprocessing_overview.py +++ b/tutorials/preprocessing/10_preprocessing_overview.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-artifact-overview: diff --git a/tutorials/preprocessing/15_handling_bad_channels.py b/tutorials/preprocessing/15_handling_bad_channels.py index db1ff0dda4a..3d907aba39d 100644 --- a/tutorials/preprocessing/15_handling_bad_channels.py +++ b/tutorials/preprocessing/15_handling_bad_channels.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-bad-channels: diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index aec2532ff7f..a8080536992 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-reject-data-spans: diff --git a/tutorials/preprocessing/25_background_filtering.py b/tutorials/preprocessing/25_background_filtering.py index 998e8919d7a..a72423ab061 100644 --- a/tutorials/preprocessing/25_background_filtering.py +++ b/tutorials/preprocessing/25_background_filtering.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- r""" .. 
_disc-filtering: diff --git a/tutorials/preprocessing/30_filtering_resampling.py b/tutorials/preprocessing/30_filtering_resampling.py index 4adff1129bf..758bdad1cdf 100644 --- a/tutorials/preprocessing/30_filtering_resampling.py +++ b/tutorials/preprocessing/30_filtering_resampling.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-filter-resample: diff --git a/tutorials/preprocessing/35_artifact_correction_regression.py b/tutorials/preprocessing/35_artifact_correction_regression.py index 9a81f74c53b..b9761c8c104 100644 --- a/tutorials/preprocessing/35_artifact_correction_regression.py +++ b/tutorials/preprocessing/35_artifact_correction_regression.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-artifact-regression: diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index d6511baba9c..51e353dcae7 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-artifact-ica: diff --git a/tutorials/preprocessing/45_projectors_background.py b/tutorials/preprocessing/45_projectors_background.py index 100734828f4..cfbde2d4ed7 100644 --- a/tutorials/preprocessing/45_projectors_background.py +++ b/tutorials/preprocessing/45_projectors_background.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-projectors-background: diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index b54e37b261b..19e4c3f9d91 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-artifact-ssp: diff --git a/tutorials/preprocessing/55_setting_eeg_reference.py b/tutorials/preprocessing/55_setting_eeg_reference.py index db728b72052..2a19a66a63b 100644 --- a/tutorials/preprocessing/55_setting_eeg_reference.py +++ b/tutorials/preprocessing/55_setting_eeg_reference.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-set-eeg-ref: diff --git a/tutorials/preprocessing/59_head_positions.py b/tutorials/preprocessing/59_head_positions.py index fa66228ccd6..7f4531e4f39 100644 --- a/tutorials/preprocessing/59_head_positions.py +++ b/tutorials/preprocessing/59_head_positions.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-head-pos: diff --git a/tutorials/preprocessing/60_maxwell_filtering_sss.py b/tutorials/preprocessing/60_maxwell_filtering_sss.py index 53586f90c26..0d17a0c78fc 100644 --- a/tutorials/preprocessing/60_maxwell_filtering_sss.py +++ b/tutorials/preprocessing/60_maxwell_filtering_sss.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-artifact-sss: diff --git a/tutorials/preprocessing/70_fnirs_processing.py b/tutorials/preprocessing/70_fnirs_processing.py index 3cecd7647f8..922c76f7086 100644 --- a/tutorials/preprocessing/70_fnirs_processing.py +++ b/tutorials/preprocessing/70_fnirs_processing.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-fnirs-processing: diff --git a/tutorials/preprocessing/80_opm_processing.py b/tutorials/preprocessing/80_opm_processing.py index 93c5d1eaadf..ac9d80bc692 100644 --- a/tutorials/preprocessing/80_opm_processing.py +++ b/tutorials/preprocessing/80_opm_processing.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_tut-opm-processing: diff --git a/tutorials/raw/10_raw_overview.py b/tutorials/raw/10_raw_overview.py index 728217a3128..512ca999b3f 100644 --- a/tutorials/raw/10_raw_overview.py +++ b/tutorials/raw/10_raw_overview.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-raw-class: diff --git a/tutorials/raw/20_event_arrays.py b/tutorials/raw/20_event_arrays.py index 16ac2e3a59b..e6b483bc7d6 100644 --- a/tutorials/raw/20_event_arrays.py +++ b/tutorials/raw/20_event_arrays.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-event-arrays: diff --git a/tutorials/raw/30_annotate_raw.py b/tutorials/raw/30_annotate_raw.py index b790c03b95f..8ca0da314b5 100644 --- a/tutorials/raw/30_annotate_raw.py +++ b/tutorials/raw/30_annotate_raw.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-annotate-raw: diff --git a/tutorials/raw/40_visualize_raw.py b/tutorials/raw/40_visualize_raw.py index 6fb07017fae..7b42b2629d2 100644 --- a/tutorials/raw/40_visualize_raw.py +++ b/tutorials/raw/40_visualize_raw.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-visualize-raw: diff --git a/tutorials/simulation/10_array_objs.py b/tutorials/simulation/10_array_objs.py index 739a90fac42..fc1ef80f121 100644 --- a/tutorials/simulation/10_array_objs.py +++ b/tutorials/simulation/10_array_objs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-creating-data-structures: diff --git a/tutorials/simulation/70_point_spread.py b/tutorials/simulation/70_point_spread.py index b0ed4372cda..e31652bec9c 100644 --- a/tutorials/simulation/70_point_spread.py +++ b/tutorials/simulation/70_point_spread.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-point-spread: diff --git a/tutorials/simulation/80_dics.py b/tutorials/simulation/80_dics.py index c4c90418b30..71fd32a0210 100644 --- a/tutorials/simulation/80_dics.py +++ b/tutorials/simulation/80_dics.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-dics: diff --git a/tutorials/stats-sensor-space/10_background_stats.py b/tutorials/stats-sensor-space/10_background_stats.py index ae790fc7eae..9360f915a9e 100644 --- a/tutorials/stats-sensor-space/10_background_stats.py +++ b/tutorials/stats-sensor-space/10_background_stats.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _disc-stats: diff --git a/tutorials/stats-sensor-space/20_erp_stats.py b/tutorials/stats-sensor-space/20_erp_stats.py index 9fdd95a1bbe..7cfc7fd14ce 100644 --- a/tutorials/stats-sensor-space/20_erp_stats.py +++ b/tutorials/stats-sensor-space/20_erp_stats.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-erp-stats: diff --git a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py index 63ad85b9d3e..08fda2a59b4 100644 --- a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py +++ b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-cluster-one-samp-tfr: diff --git a/tutorials/stats-sensor-space/50_cluster_between_time_freq.py b/tutorials/stats-sensor-space/50_cluster_between_time_freq.py index dc610f719b0..7c911270171 100644 --- a/tutorials/stats-sensor-space/50_cluster_between_time_freq.py +++ b/tutorials/stats-sensor-space/50_cluster_between_time_freq.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_tut-cluster-tfr: diff --git a/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py b/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py index 20985c99577..098f30e9ca6 100644 --- a/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py +++ b/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-timefreq-twoway-anova: diff --git a/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py b/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py index 79d7341b153..2bee00db86c 100644 --- a/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py +++ b/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-cluster-spatiotemporal-sensor: diff --git a/tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py b/tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py index 59d07bc1a62..a21b16d3aef 100644 --- a/tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py +++ b/tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-cluster-one-sample-spatiotemporal: diff --git a/tutorials/stats-source-space/30_cluster_ftest_spatiotemporal.py b/tutorials/stats-source-space/30_cluster_ftest_spatiotemporal.py index d64f73da454..d3c5fb9c78c 100644 --- a/tutorials/stats-source-space/30_cluster_ftest_spatiotemporal.py +++ b/tutorials/stats-source-space/30_cluster_ftest_spatiotemporal.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-cluster-spatiotemporal-source: diff --git a/tutorials/stats-source-space/60_cluster_rmANOVA_spatiotemporal.py b/tutorials/stats-source-space/60_cluster_rmANOVA_spatiotemporal.py index a4c8d7dc8cd..5ee29687439 100644 --- a/tutorials/stats-source-space/60_cluster_rmANOVA_spatiotemporal.py +++ b/tutorials/stats-source-space/60_cluster_rmANOVA_spatiotemporal.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-cluster-rm-anova-spatiotemporal: diff --git a/tutorials/time-freq/10_spectrum_class.py b/tutorials/time-freq/10_spectrum_class.py index 52c56b45668..bb234fc5be8 100644 --- a/tutorials/time-freq/10_spectrum_class.py +++ b/tutorials/time-freq/10_spectrum_class.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # noqa: E501 """ .. _tut-spectrum-class: diff --git a/tutorials/time-freq/20_sensors_time_frequency.py b/tutorials/time-freq/20_sensors_time_frequency.py index de92a9fb4e0..47d6598e114 100644 --- a/tutorials/time-freq/20_sensors_time_frequency.py +++ b/tutorials/time-freq/20_sensors_time_frequency.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. _tut-sensors-time-freq: diff --git a/tutorials/time-freq/50_ssvep.py b/tutorials/time-freq/50_ssvep.py index 10d09b89d6b..8ce663a01b3 100644 --- a/tutorials/time-freq/50_ssvep.py +++ b/tutorials/time-freq/50_ssvep.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ .. 
_tut-ssvep: From 6de2197a8c59f5b132c4769ff66706f1720c9d56 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 28 Mar 2023 09:53:28 -0600 Subject: [PATCH 0009/1125] ENH: Add mne-bids-pipeline to mne sys_info (#11606) --- mne/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/utils/config.py b/mne/utils/config.py index 8dd03dc22ac..5056fcfd18a 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -567,7 +567,7 @@ def sys_info(fid=None, show_paths=False, *, dependencies='user', unicode=True): '', '# Ecosystem (optional)', 'mne-bids', 'mne-nirs', 'mne-features', 'mne-connectivity', - 'mne-icalabel', + 'mne-icalabel', 'mne-bids-pipeline', '' ) if dependencies == 'developer': From 3b4930abd32c00dc71da6c4a3e56a850a34ba185 Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Tue, 28 Mar 2023 15:03:55 -0700 Subject: [PATCH 0010/1125] [BUG, MRG] Fix topomap extra plot generated, add util to check a range (#11607) --- doc/changes/latest.inc | 1 + mne/utils/__init__.py | 2 +- mne/utils/check.py | 31 +++++++++++++++++++++++++++++++ mne/utils/tests/test_check.py | 12 +++++++++++- mne/viz/topomap.py | 3 +-- 5 files changed, 45 insertions(+), 4 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 2d731c83364..dba8d159b15 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -64,6 +64,7 @@ Bugs - Fix :func:`mne.io.read_raw` for file names containing multiple dots (:gh:`11521` by `Clemens Brunner`_) - Fix bug in :func:`mne.export.export_raw` when exporting to EDF with a physical range set smaller than the data range (:gh:`11569` by `Mathieu Scheltienne`_) - Fix bug in :func:`mne.concatenate_raws` where two raws could not be merged if the order of the bad channel lists did not match (:gh:`11502` by `Moritz Gerster`_) +- Fix bug where :meth:`mne.Evoked.plot_topomap` opened an extra figure (:gh:`11607` by `Alex Rockhill`_) API changes diff --git a/mne/utils/__init__.py b/mne/utils/__init__.py index 0d04c882783..2ceef298cae 100644 --- a/mne/utils/__init__.py +++ b/mne/utils/__init__.py @@ -11,7 +11,7 @@ _check_pandas_index_arguments, _check_event_id, _check_ch_locs, _check_compensation_grade, _check_if_nan, _is_numeric, _ensure_int, _check_preload, - _validate_type, _check_info_inv, + _validate_type, _check_range, _check_info_inv, _check_channels_spatial_filter, _check_one_ch_type, _check_rank, _check_option, _check_depth, _check_combine, _path_like, _check_src_normal, _check_stc_units, diff --git a/mne/utils/check.py b/mne/utils/check.py index cb4459e9e26..780458264c9 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -551,6 +551,37 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, f"got {type(item)} instead.") +def _check_range(val, min_val, max_val, name, min_inclusive=True, + max_inclusive=True): + """Check that item is within range. + + Parameters + ---------- + val : int | float + The value to be checked. + min_val : int | float + The minimum value allowed. + max_val : int | float + The maximum value allowed. + name : str + The name of the value. + min_inclusive : bool + Whether ``val`` is allowed to be ``min_val``. + max_inclusive : bool + Whether ``val`` is allowed to be ``max_val``. 
+ """ + below_min = val < min_val if min_inclusive else val <= min_val + above_max = val > max_val if max_inclusive else val >= max_val + if below_min or above_max: + error_str = f'The value of {name} must be between {min_val} ' + if min_inclusive: + error_str += 'inclusive ' + error_str += f'and {max_val}' + if max_inclusive: + error_str += 'inclusive ' + raise ValueError(error_str) + + def _path_like(item): """Validate that `item` is `path-like`. diff --git a/mne/utils/tests/test_check.py b/mne/utils/tests/test_check.py index 8f28ee7799a..44caa61ba10 100644 --- a/mne/utils/tests/test_check.py +++ b/mne/utils/tests/test_check.py @@ -18,7 +18,8 @@ from mne.utils import (check_random_state, _check_fname, check_fname, _suggest, _check_subject, _check_info_inv, _check_option, Bunch, check_version, _path_like, _validate_type, _on_missing, - _safe_input, _check_ch_locs, _check_sphere) + _safe_input, _check_ch_locs, _check_sphere, + _check_range) data_path = testing.data_path(download=False) base_dir = data_path / "MEG" / "sample" @@ -184,6 +185,15 @@ def test_validate_type(): _validate_type(False, 'int-like') +def test_check_range(): + """Test _check_range.""" + _check_range(10, 1, 100, 'value') + with pytest.raises(ValueError, match='must be between'): + _check_range(0, 1, 10, 'value') + with pytest.raises(ValueError, match='must be between'): + _check_range(1, 1, 10, 'value', False, False) + + @testing.requires_testing_data def test_suggest(): """Test suggestions.""" diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index be10dbe9502..0db592fe14c 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -908,7 +908,6 @@ def _plot_topomap( border=_BORDER_DEFAULT, res=64, cmap=None, vmin=None, vmax=None, cnorm=None, show=True, onselect=None): from matplotlib.colors import Normalize - import matplotlib.pyplot as plt from matplotlib.widgets import RectangleSelector data = np.asarray(data) logger.debug(f'Plotting topomap for {ch_type} data shape {data.shape}') @@ -1050,7 +1049,7 @@ def _plot_topomap( verticalalignment='center', size='x-small') if not axes.figure.get_constrained_layout(): - plt.subplots_adjust(top=.95) + axes.figure.subplots_adjust(top=.95) if onselect is not None: lim = axes.dataLim From 686857c55288ed91cc13cdc43bf100ab1f6c475d Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Wed, 29 Mar 2023 23:27:47 -0700 Subject: [PATCH 0011/1125] [BUG, MRG] Don't modify info in place for transform points (#11612) --- doc/changes/latest.inc | 1 + mne/transforms.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index dba8d159b15..ed0734fe08d 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -65,6 +65,7 @@ Bugs - Fix bug in :func:`mne.export.export_raw` when exporting to EDF with a physical range set smaller than the data range (:gh:`11569` by `Mathieu Scheltienne`_) - Fix bug in :func:`mne.concatenate_raws` where two raws could not be merged if the order of the bad channel lists did not match (:gh:`11502` by `Moritz Gerster`_) - Fix bug where :meth:`mne.Evoked.plot_topomap` opened an extra figure (:gh:`11607` by `Alex Rockhill`_) +- Fix bug where :func:`mne.transforms.apply_volume_registration_points` modified info in place (:gh:`11612` by `Alex Rockhill`_) API changes diff --git a/mne/transforms.py b/mne/transforms.py index c896270797c..39ab647f479 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -1911,6 +1911,7 @@ def apply_volume_registration_points(info, trans, moving, static, 
reg_affine, montage2 = make_dig_montage(**montage_kwargs) trans2 = compute_native_head_t(montage2) - info.set_montage(montage2) # converts to head coordinates + info2 = info.copy() + info2.set_montage(montage2) # converts to head coordinates - return info, trans2 + return info2, trans2 From bfc8a3c471d271e3d31b7696fc0f9465a262223d Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 4 Apr 2023 10:19:46 -0600 Subject: [PATCH 0012/1125] ENH: Speed up code a bit (#11614) --- doc/changes/latest.inc | 1 + doc/changes/names.inc | 2 +- mne/annotations.py | 10 +++++++--- mne/time_frequency/spectrum.py | 7 +++---- mne/utils/docs.py | 5 ++--- mne/viz/_brain/tests/test_brain.py | 14 ++++++++++++-- mne/viz/backends/_utils.py | 9 +++++++-- 7 files changed, 33 insertions(+), 15 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index ed0734fe08d..e6c34f0c4ce 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -40,6 +40,7 @@ Enhancements - Add automatic projection of sEEG contact onto the inflated surface for :meth:`mne.viz.Brain.add_sensors` (:gh:`11436` by `Alex Rockhill`_) - Allow an image with intracranial electrode contacts (e.g. computed tomography) to be used without the freesurfer recon-all surfaces to locate contacts so that it doesn't have to be downsampled to freesurfer dimensions (for microelectrodes) and show an example :ref:`ex-ieeg-micro` with :func:`mne.transforms.apply_volume_registration_points` added to aid this transform (:gh:`11567` by `Alex Rockhill`_) - Use new :meth:`dipy.workflows.align.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_) +- Improve performance of raw data browsing with many annotations (:gh:`11614` by `Eric Larson`_) - Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_) Bugs diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 29d4697582c..e1e7acafdf0 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -172,7 +172,7 @@ .. _George O'Neill: https://georgeoneill.github.io -.. _Guillaume Dumas: http://www.extrospection.eu +.. _Guillaume Dumas: https://mila.quebec/en/person/guillaume-dumas .. 
_Guillaume Favelier: https://github.com/GuillaumeFavelier diff --git a/mne/annotations.py b/mne/annotations.py index 848a4ba4dc7..24cc4069971 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -312,16 +312,20 @@ def __iadd__(self, other): def __iter__(self): """Iterate over the annotations.""" + # Figure this out once ahead of time for consistency and speed (for + # thousands of annotations) + with_ch_names = self._any_ch_names() for idx in range(len(self.onset)): - yield self.__getitem__(idx) + yield self.__getitem__(idx, with_ch_names=with_ch_names) - def __getitem__(self, key): + def __getitem__(self, key, *, with_ch_names=None): """Propagate indexing and slicing to the underlying numpy structure.""" if isinstance(key, int_like): out_keys = ('onset', 'duration', 'description', 'orig_time') out_vals = (self.onset[key], self.duration[key], self.description[key], self.orig_time) - if self._any_ch_names(): + if with_ch_names or (with_ch_names is None and + self._any_ch_names()): out_keys += ('ch_names',) out_vals += (self.ch_names[key],) return OrderedDict(zip(out_keys, out_vals)) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index d4bbf98be3b..0a4e05da6df 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -739,10 +739,9 @@ def to_data_frame(self, picks=None, index=None, copy=True, index : str | list of str | None Kind of index to use for the DataFrame. If ``None``, a sequential integer index (:class:`pandas.RangeIndex`) will be used. If a - :class:`str`, a :class:`pandas.Index`, :class:`pandas.Int64Index`, - or :class:`pandas.Float64Index` will be used (see Notes). If a list - of two or more string values, a :class:`pandas.MultiIndex` will be - used. Defaults to ``None``. + :class:`str`, a :class:`pandas.Index` will be used (see Notes). If + a list of two or more string values, a :class:`pandas.MultiIndex` + will be used. Defaults to ``None``. %(copy_df)s %(long_format_df_spe)s %(verbose)s diff --git a/mne/utils/docs.py b/mne/utils/docs.py index a447eda863e..f5817ae3b74 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1658,12 +1658,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): index : {} | None Kind of index to use for the DataFrame. If ``None``, a sequential integer index (:class:`pandas.RangeIndex`) will be used. If ``'time'``, a - :class:`pandas.Float64Index`, :class:`pandas.Int64Index`, {}or - :class:`pandas.TimedeltaIndex` will be used + ``pandas.Index``{} or :class:`pandas.TimedeltaIndex` will be used (depending on the value of ``time_format``). {} """ -datetime = ':class:`pandas.DatetimeIndex`, ' +datetime = ', :class:`pandas.DatetimeIndex`,' multiindex = ('If a list of two or more string values, a ' ':class:`pandas.MultiIndex` will be created. 
') raw = ("'time'", datetime, '') diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 2864d7ecbe3..9c30437e649 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -886,7 +886,12 @@ def test_brain_traces(renderer_interactive_pyvistaqt, hemi, src, tmp_path, # interpolation='linear', time_viewer=True) # """, 1) - gallery_conf = dict(src_dir=str(tmp_path), compress_images=[]) + gallery_conf = dict( + src_dir=str(tmp_path), + compress_images=[], + image_srcset=[], + matplotlib_animations=False, + ) scraper = _BrainScraper() rst = scraper(block, block_vars, gallery_conf) assert brain.plotter is None # closed @@ -918,7 +923,12 @@ def test_brain_scraper(renderer_interactive_pyvistaqt, brain_gc, tmp_path): block_vars = dict(image_path_iterator=iter(fnames), example_globals=dict(brain=brain)) block = ('code', '', 1) - gallery_conf = dict(src_dir=str(tmp_path), compress_images=[]) + gallery_conf = dict( + src_dir=str(tmp_path), + compress_images=[], + image_srcset=[], + matplotlib_animations=False, + ) scraper = _BrainScraper() rst = scraper(block, block_vars, gallery_conf) assert brain.plotter is None # closed diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index be4a16e50ab..4fec863e424 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -95,6 +95,9 @@ def _qt_disable_paint(widget): widget.paintEvent = paintEvent +_QT_ICON_KEYS = dict(app=None) + + def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False): """Get QApplication-instance for MNE-Python. @@ -159,10 +162,12 @@ def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False): if enable_icon or splash: icons_path = _qt_init_icons() - if enable_icon: + if enable_icon and app.windowIcon().cacheKey() != _QT_ICON_KEYS['app']: # Set icon kind = 'bigsur_' if platform.mac_ver()[0] >= '10.16' else 'default_' - app.setWindowIcon(QIcon(f"{icons_path}/mne_{kind}icon.png")) + icon = QIcon(f"{icons_path}/mne_{kind}icon.png") + app.setWindowIcon(icon) + _QT_ICON_KEYS['app'] = app.windowIcon().cacheKey() out = app if splash: From a458622f4b1b7daa7685ad1d79fe132b22082ed9 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 5 Apr 2023 21:59:48 -0400 Subject: [PATCH 0013/1125] MAINT: Use git rather than zipball (#11620) --- azure-pipelines.yml | 2 +- mne/viz/backends/_utils.py | 16 +++++++++++----- requirements.txt | 2 +- tools/azure_dependencies.sh | 6 +++--- tools/circleci_dependencies.sh | 2 +- tools/github_actions_dependencies.sh | 8 ++++---- tools/setup_xvfb.sh | 2 +- 7 files changed, 22 insertions(+), 16 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index bedd0082506..ad0c3b6b627 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -116,7 +116,7 @@ stages: - bash: | set -e python -m pip install --progress-bar off --upgrade pip setuptools wheel codecov - python -m pip install --progress-bar off mne-qt-browser[opengl] pyvista scikit-learn pytest-error-for-skips python-picard "PySide6!=6.3.0,!=6.4.0,!=6.4.0.1" qtpy + python -m pip install --progress-bar off mne-qt-browser[opengl] pyvista scikit-learn pytest-error-for-skips python-picard "PySide6!=6.3.0,!=6.4.0,!=6.4.0.1,!=6.5.0" qtpy python -m pip uninstall -yq mne python -m pip install --progress-bar off --upgrade -e .[test] displayName: 'Install dependencies with pip' diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index 4fec863e424..3f75a5b7f37 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py 
@@ -144,15 +144,21 @@ def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False): # First we need to check to make sure the display is valid, otherwise # Qt might segfault on us - if not _display_is_valid(): + app = QApplication.instance() + if not (app or _display_is_valid()): raise RuntimeError('Cannot connect to a valid display') if pg_app: from pyqtgraph import mkQApp - app = mkQApp(app_name) - else: - app = QApplication.instance() or QApplication(sys.argv or [app_name]) - app.setApplicationName(app_name) + old_argv = sys.argv + try: + sys.argv = [] + app = mkQApp(app_name) + finally: + sys.argv = old_argv + elif not app: + app = QApplication([app_name]) + app.setApplicationName(app_name) app.setOrganizationName(organization_name) try: app.setAttribute(Qt.AA_UseHighDpiPixmaps) # works on PyQt5 and PySide2 diff --git a/requirements.txt b/requirements.txt index 1c2eab9d672..eeaa71cb9cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ h5io packaging pymatreader qtpy -PySide6!=6.3.0,!=6.4.0,!=6.4.0.1 # incompat with Matplotlib 3.6.1 and qtpy +PySide6!=6.3.0,!=6.4.0,!=6.4.0.1,!=6.5.0 # incompat with Matplotlib 3.6.1 and qtpy pyobjc-framework-Cocoa>=5.2.0; platform_system=="Darwin" sip scikit-learn diff --git a/tools/azure_dependencies.sh b/tools/azure_dependencies.sh index 42c836645e2..869dccabd09 100755 --- a/tools/azure_dependencies.sh +++ b/tools/azure_dependencies.sh @@ -6,7 +6,7 @@ if [ "${TEST_MODE}" == "pip" ]; then python -m pip install --upgrade --only-binary ":all:" numpy scipy vtk python -m pip install --upgrade --only-binary="numba,llvmlite" -r requirements.txt # This can be removed once PyVistaQt 0.6 is out (including https://github.com/pyvista/pyvistaqt/pull/127) - python -m pip install --upgrade https://github.com/pyvista/pyvistaqt/zipball/main + python -m pip install --upgrade git+https://github.com/pyvista/pyvistaqt elif [ "${TEST_MODE}" == "pip-pre" ]; then python -m pip install --progress-bar off --upgrade pip setuptools wheel python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" python-dateutil pytz joblib threadpoolctl six cycler kiwisolver pyparsing patsy @@ -19,8 +19,8 @@ elif [ "${TEST_MODE}" == "pip-pre" ]; then python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -f "https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com/" h5py python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -i "https://test.pypi.org/simple" openmeeg python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps vtk - python -m pip install --progress-bar off https://github.com/pyvista/pyvista/zipball/main - python -m pip install --progress-bar off https://github.com/pyvista/pyvistaqt/zipball/main + python -m pip install --progress-bar off git+https://github.com/pyvista/pyvista + python -m pip install --progress-bar off git+https://github.com/pyvista/pyvistaqt python -m pip install --progress-bar off --upgrade --pre imageio-ffmpeg xlrd mffpy python-picard patsy pillow EXTRA_ARGS="--pre" ./tools/check_qt_import.sh PyQt6 diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh index 11d10273d1e..760066eb20e 100755 --- a/tools/circleci_dependencies.sh +++ b/tools/circleci_dependencies.sh @@ -11,7 +11,7 @@ else # standard doc build echo "Installing doc build dependencies" python -m pip install --upgrade --progress-bar off PyQt6 python -m pip install --upgrade --progress-bar off 
--only-binary "numpy,scipy,matplotlib,pandas,statsmodels" -r requirements.txt -r requirements_testing.txt -r requirements_doc.txt - python -m pip install --upgrade --progress-bar off https://github.com/mne-tools/mne-qt-browser/zipball/main https://github.com/sphinx-gallery/sphinx-gallery/zipball/master + python -m pip install --upgrade --progress-bar off git+https://github.com/mne-tools/mne-qt-browser git+https://github.com/sphinx-gallery/sphinx-gallery # deal with comparisons and escapes (https://app.circleci.com/pipelines/github/mne-tools/mne-python/9686/workflows/3fd32b47-3254-4812-8b9a-8bab0d646d18/jobs/32934) python -m pip install --upgrade --progress-bar off quantities fi diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index 29d758e6b57..02df34f11cd 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -26,15 +26,15 @@ else pip install $STD_ARGS --pre --only-binary ":all:" pillow # We don't install Numba here because it forces an old NumPy version echo "nilearn and openmeeg" - pip install $STD_ARGS --pre https://github.com/nilearn/nilearn/zipball/main + pip install $STD_ARGS --pre git+https://github.com/nilearn/nilearn pip install $STD_ARGS --pre --only-binary ":all:" -i "https://test.pypi.org/simple" openmeeg echo "VTK" pip install $STD_ARGS --pre --only-binary ":all:" vtk python -c "import vtk" echo "PyVista" - pip install --progress-bar off https://github.com/pyvista/pyvista/zipball/main + pip install --progress-bar off git+https://github.com/pyvista/pyvista echo "pyvistaqt" - pip install --progress-bar off https://github.com/pyvista/pyvistaqt/zipball/main + pip install --progress-bar off git+https://github.com/pyvista/pyvistaqt echo "imageio-ffmpeg, xlrd, mffpy, python-picard" pip install --progress-bar off --pre imageio-ffmpeg xlrd mffpy python-picard patsy if [ "$OSTYPE" == "darwin"* ]; then @@ -42,7 +42,7 @@ else pip install --progress-bar off pyobjc-framework-Cocoa>=5.2.0 fi echo "mne-qt-browser" - pip install --progress-bar off https://github.com/mne-tools/mne-qt-browser/zipball/main + pip install --progress-bar off git+https://github.com/mne-tools/mne-qt-browser EXTRA_ARGS="--pre" fi # for compat_minimal and compat_old, we don't want to --upgrade diff --git a/tools/setup_xvfb.sh b/tools/setup_xvfb.sh index 040541eaee2..a5c55d0819b 100755 --- a/tools/setup_xvfb.sh +++ b/tools/setup_xvfb.sh @@ -11,5 +11,5 @@ done # This also includes the libraries necessary for PyQt5/PyQt6 sudo apt update -sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 +sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 libxcb-cursor0 /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset From b25f2505727eb7a79cd4955f044f7b027c57ab43 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 10 Apr 2023 11:18:20 -0600 Subject: [PATCH 0014/1125] API: One cycle of backward compat (#11621) --- mne/utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/utils/__init__.py b/mne/utils/__init__.py index 2ceef298cae..b67e4ba2fd3 100644 --- 
a/mne/utils/__init__.py +++ b/mne/utils/__init__.py @@ -23,7 +23,8 @@ _check_edflib_installed, _to_rgb, _soft_import, _check_dict_keys, _check_pymatreader_installed, _import_h5py, _import_h5io_funcs, _import_nibabel, - _import_pymatreader_funcs, _check_head_radius) + _import_pymatreader_funcs, _check_head_radius, + has_nibabel) from .config import (set_config, get_config, get_config_path, set_cache_dir, set_memmap_min_size, get_subjects_dir, _get_stim_channel, sys_info, _get_extra_data_path, _get_root_dir, From 5d1aa110fa0d889a98171131609e59747fc52d65 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 11 Apr 2023 11:43:59 -0600 Subject: [PATCH 0015/1125] MAINT: Add token [ci skip] (#11622) --- .github/workflows/circle_artifacts.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/circle_artifacts.yml b/.github/workflows/circle_artifacts.yml index c444a1e9bae..96a4264627c 100644 --- a/.github/workflows/circle_artifacts.yml +++ b/.github/workflows/circle_artifacts.yml @@ -9,6 +9,7 @@ jobs: uses: larsoner/circleci-artifacts-redirector-action@master with: repo-token: ${{ secrets.GITHUB_TOKEN }} + api-token: ${{ secrets.CIRCLECI_TOKEN }} artifact-path: 0/dev/index.html circleci-jobs: build_docs,build_docs_main job-title: Check the rendered docs here! From d0cf4db4ba1f79cf31009b0c060ed64cc6b6de57 Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Thu, 13 Apr 2023 07:11:54 -0700 Subject: [PATCH 0016/1125] API: Port ieeg gui over to mne-gui-addons and add tfr gui example (#11616) Co-authored-by: Eric Larson --- azure-pipelines.yml | 6 +-- doc/changes/1.0.inc | 14 +++--- doc/changes/1.1.inc | 2 +- doc/changes/latest.inc | 7 +-- doc/conf.py | 1 + examples/inverse/evoked_ers_source_power.py | 7 ++- mne/conftest.py | 31 +++++++------ mne/gui/__init__.py | 50 ++++++++++++++------- mne/gui/tests/test_ieeg_locate.py | 4 ++ mne/time_frequency/tests/test_spectrum.py | 9 ++-- mne/utils/check.py | 1 + requirements.txt | 1 + tools/azure_dependencies.sh | 2 +- tutorials/clinical/10_ieeg_localize.py | 5 ++- 14 files changed, 88 insertions(+), 52 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index ad0c3b6b627..186c081f036 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -115,7 +115,7 @@ stages: displayName: 'Get Python' - bash: | set -e - python -m pip install --progress-bar off --upgrade pip setuptools wheel codecov + python -m pip install --progress-bar off --upgrade pip setuptools wheel python -m pip install --progress-bar off mne-qt-browser[opengl] pyvista scikit-learn pytest-error-for-skips python-picard "PySide6!=6.3.0,!=6.4.0,!=6.4.0.1,!=6.5.0" qtpy python -m pip uninstall -yq mne python -m pip install --progress-bar off --upgrade -e .[test] @@ -177,7 +177,7 @@ stages: python -m pip install --progress-bar off --upgrade pip setuptools wheel python -m pip install --progress-bar off --upgrade --pre --only-binary=\"numpy,scipy,matplotlib,vtk\" numpy scipy matplotlib vtk python -c "import vtk" - python -m pip install --progress-bar off --upgrade -r requirements.txt -r requirements_testing.txt -r requirements_testing_extra.txt codecov + python -m pip install --progress-bar off --upgrade -r requirements.txt -r requirements_testing.txt -r requirements_testing_extra.txt python -m pip install -e . displayName: 'Install dependencies with pip' - bash: | @@ -312,7 +312,7 @@ stages: displayName: Remove old MNE - script: pip install -e . 
displayName: 'Install MNE-Python dev' - - script: pip install --progress-bar off -e .[test] codecov + - script: pip install --progress-bar off -e .[test] condition: eq(variables['TEST_MODE'], 'conda') displayName: Install testing requirements - script: mne sys_info -pd diff --git a/doc/changes/1.0.inc b/doc/changes/1.0.inc index 0bfca8f5d97..cd4cb91eafd 100644 --- a/doc/changes/1.0.inc +++ b/doc/changes/1.0.inc @@ -38,9 +38,9 @@ Enhancements - :func:`mne.time_frequency.tfr_array_multitaper` can now return results for ``output='phase'`` instead of an error (:gh:`10281` by `Mikołaj Magnuski`_) -- Add show local maxima toggling button to :func:`mne.gui.locate_ieeg` (:gh:`9952` by `Alex Rockhill`_) +- Add show local maxima toggling button to ``mne.gui.locate_ieeg`` (:gh:`9952` by `Alex Rockhill`_) -- Show boundaries in :func:`mne.gui.locate_ieeg` (:gh:`10379` by `Eric Larson`_) +- Show boundaries in ``mne.gui.locate_ieeg`` (:gh:`10379` by `Eric Larson`_) - Add argument ``cval`` to :func:`mne.transforms.apply_volume_registration` to set interpolation values outside the image domain (:gh:`10379` by `Eric Larson`_) @@ -84,7 +84,7 @@ Enhancements - :meth:`mne.Epochs.plot_drop_log` now also includes the absolute number of epochs dropped in the title (:gh:`10186` by `Richard Höchenberger`_) -- Add a button to show the maximum intensity projection in :func:`mne.gui.locate_ieeg` (:gh:`10185` by `Alex Rockhill`_) +- Add a button to show the maximum intensity projection in ``mne.gui.locate_ieeg`` (:gh:`10185` by `Alex Rockhill`_) - Annotations from a :class:`~mne.io.Raw` object are now preserved by the :class:`~mne.Epochs` constructor and are supported when saving Epochs (:gh:`9969` and :gh:`10019` by `Adam Li`_) @@ -102,9 +102,9 @@ Enhancements - :class:`mne.coreg.Coregistration` gained a new attribute, ``fiducials``, allowing for convenient retrieval of the MRI fiducial points (:gh:`10243`, by `Richard Höchenberger`_) -- Added plotting points to represent contacts on the max intensity projection plot for :func:`mne.gui.locate_ieeg` (:gh:`10212` by `Alex Rockhill`_) +- Added plotting points to represent contacts on the max intensity projection plot for ``mne.gui.locate_ieeg`` (:gh:`10212` by `Alex Rockhill`_) -- Add lines in 3D and on the maximum intensity projection when more than two electrode contacts are selected to aid in identifying that contact for :func:`mne.gui.locate_ieeg` (:gh:`10212` by `Alex Rockhill`_) +- Add lines in 3D and on the maximum intensity projection when more than two electrode contacts are selected to aid in identifying that contact for ``mne.gui.locate_ieeg`` (:gh:`10212` by `Alex Rockhill`_) - Add a ``block`` parameter to :class:`mne.viz.Brain` and the UI of :class:`mne.coreg.Coregistration` to prevent the windows from closing immediately when running in a non-interactive Python session (:gh:`10222` by `Guillaume Favelier`_) @@ -122,7 +122,7 @@ Enhancements - Added :meth:`mne.viz.Brain.add_dipole` and :meth:`mne.viz.Brain.add_forward` to plot dipoles on a brain as well as :meth:`mne.viz.Brain.remove_dipole` and :meth:`mne.viz.Brain.remove_forward` (:gh:`10373` by `Alex Rockhill`_) -- Made anterior/posterior slice scrolling in :func:`mne.gui.locate_ieeg` possible for users without page up and page down buttons by allowing angle bracket buttons to be used (:gh:`10384` by `Alex Rockhill`_) +- Made anterior/posterior slice scrolling in ``mne.gui.locate_ieeg`` possible for users without page up and page down buttons by allowing angle bracket buttons to be used (:gh:`10384` by 
`Alex Rockhill`_) - Add support for ``theme='auto'`` for automatic dark-mode support in :meth:`raw.plot() ` and related functions and methods when using the ``'qt'`` backend (:gh:`10417` by `Eric Larson`_) @@ -246,7 +246,7 @@ Bugs - Fix bug with blank 3D rendering with MESA software rendering (:gh:`10400` by `Eric Larson`_) -- Fix a bug in :func:`mne.gui.locate_ieeg` where 2D lines on slice plots failed to update and were shown when not in maximum projection mode (:gh:`10335`, by `Alex Rockhill`_) +- Fix a bug in ``mne.gui.locate_ieeg`` where 2D lines on slice plots failed to update and were shown when not in maximum projection mode (:gh:`10335`, by `Alex Rockhill`_) - Fix misleading color scale in :ref:`tut-cluster-tfr` for the plotting of cluster F-statistics (:gh:`10393` by `Alex Rockhill`_) diff --git a/doc/changes/1.1.inc b/doc/changes/1.1.inc index fcb2be19823..2616a66c637 100644 --- a/doc/changes/1.1.inc +++ b/doc/changes/1.1.inc @@ -201,7 +201,7 @@ Bugs - Add a video on how to operate the coregistration GUI in :ref:`tut-source-alignment` (:gh:`10802` by `Alex Rockhill`_) -- Add ``show`` and ``block`` arguments to :func:`mne.gui.coregistration` and :func:`mne.gui.locate_ieeg` to pop up the GUIs and halt execution of subsequent code respectively (:gh:`10802` by `Alex Rockhill`_) +- Add ``show`` and ``block`` arguments to :func:`mne.gui.coregistration` and ``mne.gui.locate_ieeg`` to pop up the GUIs and halt execution of subsequent code respectively (:gh:`10802` by `Alex Rockhill`_) - Correctly report the number of available projections when printing measurement info in a Jupyter notebook (:gh:`10471` by `Clemens Brunner`_) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index e6c34f0c4ce..5392610e775 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -51,12 +51,12 @@ Bugs - Fix bug where channel names were not properly sanitized in :func:`mne.write_evokeds` and related functions (:gh:`11399` by `Eric Larson`_) - Fix bug where splash screen would not always disappear (:gh:`11398` by `Eric Larson`_) - Fix bug where having a different combination of volumes loaded into ``freeview`` caused different affines to be returned by :func:`mne.read_lta` for the same Linear Transform Array (LTA) (:gh:`11402` by `Alex Rockhill`_) -- Fix how :class:`mne.channels.DigMontage` is set when using :func:`mne.gui.locate_ieeg` so that :func:`mne.Info.get_montage` works and does not return ``None`` (:gh:`11421` by `Alex Rockhill`_) +- Fix how :class:`mne.channels.DigMontage` is set when using ``mne.gui.locate_ieeg`` so that :func:`mne.Info.get_montage` works and does not return ``None`` (:gh:`11421` by `Alex Rockhill`_) - Fix :func:`mne.io.read_raw_edf` when reading EDF data with different sampling rates and a mix of data channels when using ``infer_types=True`` (:gh:`11427` by `Alex Gramfort`_) - Fix how :class:`mne.channels.DigMontage` is set when using :func:`mne.preprocessing.ieeg.project_sensors_onto_brain` so that :func:`mne.Info.get_montage` works and does not return ``None`` (:gh:`11436` by `Alex Rockhill`_) - Fix configuration folder discovery on Windows, which would fail in certain edge cases; and produce a helpful error message if discovery still fails (:gh:`11441` by `Richard Höchenberger`_) - Make :class:`~mne.decoding.SlidingEstimator` and :class:`~mne.decoding.GeneralizingEstimator` respect the ``verbose`` argument. Now with ``verbose=False``, the progress bar is not shown during fitting, scoring, etc. 
(:gh:`11450` by `Mikołaj Magnuski`_) -- Fix bug with :func:`mne.gui.locate_ieeg` where Freesurfer ``?h.pial.T1`` was not recognized and suppress excess logging (:gh:`11489` by `Alex Rockhill`_) +- Fix bug with ``mne.gui.locate_ieeg`` where Freesurfer ``?h.pial.T1`` was not recognized and suppress excess logging (:gh:`11489` by `Alex Rockhill`_) - All functions accepting paths can now correctly handle :class:`~pathlib.Path` as input. Historically, we expected strings (instead of "proper" path objects), and only added :class:`~pathlib.Path` support in a few select places, leading to inconsistent behavior. (:gh:`11473` and :gh:`11499` by `Mathieu Scheltienne`_) - Fix visualization dialog compatibility with matplotlib 3.7 (:gh:`11409` by `Daniel McCloy`_ and `Eric Larson`_) - Expand tilde (user directory) in config keys (:gh:`11537` by `Clemens Brunner`_) @@ -71,6 +71,7 @@ Bugs API changes ~~~~~~~~~~~ +- Deprecate ``mne.gui.locate_ieeg`` in favor of :func:`mne-gui-addons:mne_gui_addons.locate_ieeg` (:gh:`11616` by `Alex Rockhill`_) - Deprecate arguments ``kind`` and ``path`` from :func:`mne.channels.read_layout` in favor of a common argument ``fname`` (:gh:`11500` by `Mathieu Scheltienne`_) -- Change ``aligned_ct`` positional argument in :func:`mne.gui.locate_ieeg` to ``base_image`` to reflect that this can now be used with unaligned images (:gh:`11567` by `Alex Rockhill`_) +- Change ``aligned_ct`` positional argument in ``mne.gui.locate_ieeg`` to ``base_image`` to reflect that this can now be used with unaligned images (:gh:`11567` by `Alex Rockhill`_) - ``mne.warp_montage_volume`` was deprecated in favor of :func:`mne.preprocessing.ieeg.warp_montage` (acts directly on points instead of using an intermediate volume) and :func:`mne.preprocessing.ieeg.make_montage_volume` (which makes a volume of ieeg contact locations which can still be useful) (:gh:`11572` by `Alex Rockhill`_) diff --git a/doc/conf.py b/doc/conf.py index 2e6680a6b05..d4fc98611cf 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -146,6 +146,7 @@ 'surfer': ('https://pysurfer.github.io/', None), 'mne_bids': ('https://mne.tools/mne-bids/stable', None), 'mne-connectivity': ('https://mne.tools/mne-connectivity/stable', None), + 'mne-gui-addons': ('https://mne.tools/mne-gui-addons', None), 'pandas': ('https://pandas.pydata.org/pandas-docs/stable', None), 'seaborn': ('https://seaborn.pydata.org/', None), 'statsmodels': ('https://www.statsmodels.org/dev', None), diff --git a/examples/inverse/evoked_ers_source_power.py b/examples/inverse/evoked_ers_source_power.py index 0ded1fc7aff..b3ccaab5e04 100644 --- a/examples/inverse/evoked_ers_source_power.py +++ b/examples/inverse/evoked_ers_source_power.py @@ -79,9 +79,10 @@ # Weighted averaging is already in the addition of covariance objects. common_cov = baseline_cov + active_cov -mne.viz.plot_cov(baseline_cov, epochs.info) +baseline_cov.plot(epochs.info) # %% + # Compute some source estimates # ----------------------------- # Here we will use DICS, LCMV beamformer, and dSPM. @@ -152,3 +153,7 @@ def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method='dSPM'): brain_dspm = stc_dspm.plot( hemi='rh', subjects_dir=subjects_dir, subject=subject, time_label='dSPM source power in the 12-30 Hz frequency band') + +# %% +# For more advanced usage, see +# :ref:`mne-gui-addons:sphx_glr_auto_examples_evoked_ers_source_power.py`. 
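A brief aside on the covariance arithmetic this example relies on (a minimal sketch, not part of the patch above; it assumes ``baseline_cov`` and ``active_cov`` are the two ``mne.Covariance`` objects computed earlier in the script): adding two Covariance objects averages their matrices weighted by each object's degrees of freedom, ``nfree``.

    # Sketch only -- `baseline_cov` and `active_cov` come from the example.
    common_cov = baseline_cov + active_cov
    # Conceptually this computes
    #   (baseline_cov.nfree * baseline_cov.data + active_cov.nfree * active_cov.data)
    #   / (baseline_cov.nfree + active_cov.nfree)
    # and the result's nfree is the sum of the two.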
diff --git a/mne/conftest.py b/mne/conftest.py index 5da6406bfc3..05d433cacfc 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -13,6 +13,7 @@ import sys import warnings import pytest +from pytest import StashKey from unittest import mock import numpy as np @@ -228,13 +229,6 @@ def __init__(self, exception_handler=None, signals=None): cbook.CallbackRegistry = CallbackRegistryReraise -@pytest.fixture(scope='session') -def ci_macos(): - """Determine if running on MacOS CI.""" - return (os.getenv('CI', 'false').lower() == 'true' and - sys.platform == 'darwin') - - @pytest.fixture(scope='session') def azure_windows(): """Determine if running on Azure Windows.""" @@ -242,13 +236,6 @@ def azure_windows(): sys.platform.startswith('win')) -@pytest.fixture() -def check_gui_ci(ci_macos, azure_windows): - """Skip tests that are not reliable on CIs.""" - if azure_windows or ci_macos: - pytest.skip('Skipping GUI tests on MacOS CIs and Azure Windows') - - @pytest.fixture(scope='function') def raw_orig(): """Get raw data without any change to it from mne.io.tests.data.""" @@ -981,6 +968,22 @@ def qt_windows_closed(request): return if 'allow_unclosed_pyside2' in marks and API_NAME.lower() == 'pyside2': return + # Don't check when the test fails + report = request.node.stash[_phase_report_key] + if ("call" not in report) or report["call"].failed: + return widgets = app.topLevelWidgets() n_after = len(widgets) assert n_before == n_after, widgets[-4:] + + +# https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures # noqa: E501 +_phase_report_key = StashKey() + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + """Stash the status of each item.""" + outcome = yield + rep = outcome.get_result() + item.stash.setdefault(_phase_report_key, {})[rep.when] = rep diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py index c86b413b634..7dffe749732 100644 --- a/mne/gui/__init__.py +++ b/mne/gui/__init__.py @@ -4,7 +4,7 @@ # # License: BSD-3-Clause -from ..utils import verbose, get_config, warn +from ..utils import verbose, get_config, warn, deprecated @verbose @@ -200,6 +200,8 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, ) +@deprecated('Use the :mod:`mne-gui-addons:mne_gui_addons` package instead; ' + 'it will be removed in version 1.5.0') @verbose def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None, groups=None, show=True, block=False, verbose=None): @@ -234,19 +236,26 @@ def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None, gui : instance of IntracranialElectrodeLocator The graphical user interface (GUI) window.
""" - from ..viz.backends._utils import _qt_app_exec - from ._ieeg_locate import IntracranialElectrodeLocator - from qtpy.QtWidgets import QApplication - # get application - app = QApplication.instance() - if app is None: - app = QApplication(["Intracranial Electrode Locator"]) - gui = IntracranialElectrodeLocator( - info, trans, base_image, subject=subject, subjects_dir=subjects_dir, - groups=groups, show=show, verbose=verbose) - if block: - _qt_app_exec(app) - return gui + try: + import mne_gui_addons as mne_gui + except ImportError: + from ..viz.backends._utils import _qt_app_exec + from ._ieeg_locate import IntracranialElectrodeLocator + from qtpy.QtWidgets import QApplication + mne_gui = None + # get application + app = QApplication.instance() + if app is None: + app = QApplication(["Intracranial Electrode Locator"]) + gui = IntracranialElectrodeLocator( + info, trans, base_image, subject=subject, subjects_dir=subjects_dir, + groups=groups, show=show, verbose=verbose) + if block: + _qt_app_exec(app) + return mne_gui.locate_ieeg( + info=info, trans=trans, base_image=base_image, + subject=subject, subjects_dir=subjects_dir, + groups=groups, show=show, block=block) if mne_gui else gui class _GUIScraper: @@ -258,11 +267,20 @@ def __repr__(self): def __call__(self, block, block_vars, gallery_conf): from ._ieeg_locate import IntracranialElectrodeLocator from ._coreg import CoregistrationUI + gui_classes = ( + IntracranialElectrodeLocator, + CoregistrationUI, + ) + try: + from mne_gui_addons._ieeg_locate import IntracranialElectrodeLocator # noqa: E501 + except Exception: + pass + else: + gui_classes = gui_classes + (IntracranialElectrodeLocator,) from sphinx_gallery.scrapers import figure_rst from qtpy import QtGui for gui in block_vars['example_globals'].values(): - if (isinstance(gui, (IntracranialElectrodeLocator, - CoregistrationUI)) and + if (isinstance(gui, gui_classes) and not getattr(gui, '_scraped', False) and gallery_conf['builder_name'] == 'html'): gui._scraped = True # monkey-patch but it's easy enough diff --git a/mne/gui/tests/test_ieeg_locate.py b/mne/gui/tests/test_ieeg_locate.py index 7fb2c544066..6c086e73260 100644 --- a/mne/gui/tests/test_ieeg_locate.py +++ b/mne/gui/tests/test_ieeg_locate.py @@ -20,6 +20,10 @@ raw_path = sample_dir / "sample_audvis_trunc_raw.fif" fname_trans = sample_dir / "sample_audvis_trunc-trans.fif" +# Module-level ignore +pytestmark = pytest.mark.filterwarnings( + 'ignore:.*locate_ieeg.*deprecated.*:FutureWarning') + @pytest.fixture def _fake_CT_coords(skull_size=5, contact_size=2): diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 2478182a869..28a4244b789 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -10,7 +10,6 @@ from mne import Annotations from mne.time_frequency import read_spectrum from mne.time_frequency.multitaper import _psd_from_mt -from mne.utils import requires_h5py, requires_pandas def test_spectrum_errors(raw): @@ -63,10 +62,10 @@ def _get_inst(inst, request, evoked): return evoked if inst == 'evoked' else request.getfixturevalue(inst) -@requires_h5py @pytest.mark.parametrize('inst', ('raw', 'epochs', 'evoked')) def test_spectrum_io(inst, tmp_path, request, evoked): """Test save/load of spectrum objects.""" + pytest.importorskip('h5py') fname = tmp_path / f'{inst}-spectrum.h5' inst = _get_inst(inst, request, evoked) orig = inst.compute_psd() @@ -138,7 +137,8 @@ def _agg_helper(df, weights, group_cols): return Series(_df) 
-@requires_pandas +# TODO: Fix this warning +@pytest.mark.filterwarnings("ignore:.*columns to operate on.*:FutureWarning") @pytest.mark.parametrize('long_format', (False, True)) @pytest.mark.parametrize('method, output', [ ('welch', 'complex'), @@ -147,6 +147,7 @@ def _agg_helper(df, weights, group_cols): ]) def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output): """Test converting complex multitaper spectra to data frame.""" + pytest.importorskip('pandas') from pandas.testing import assert_frame_equal from mne.utils.dataframe import _inplace @@ -192,10 +193,10 @@ def _fun(x): assert_frame_equal(agg_df, orig_df, check_categorical=False) -@requires_pandas @pytest.mark.parametrize('inst', ('raw', 'epochs', 'evoked')) def test_spectrum_to_data_frame(inst, request, evoked): """Test the to_data_frame method for Spectrum.""" + pytest.importorskip('pandas') from pandas.testing import assert_frame_equal # setup diff --git a/mne/utils/check.py b/mne/utils/check.py index 780458264c9..6a66fb20edc 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -352,6 +352,7 @@ def indent(x): mne_features='mne-features', mne_qt_browser='mne-qt-browser', mne_connectivity='mne-connectivity', + mne_gui_addons='mne-gui-addons', pyvista='pyvistaqt').get(name, name) try: diff --git a/requirements.txt b/requirements.txt index eeaa71cb9cc..de69e9e4574 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,3 +43,4 @@ mne-qt-browser darkdetect qdarkstyle threadpoolctl +mne-gui-addons diff --git a/tools/azure_dependencies.sh b/tools/azure_dependencies.sh index 869dccabd09..0e704b5f60c 100755 --- a/tools/azure_dependencies.sh +++ b/tools/azure_dependencies.sh @@ -28,4 +28,4 @@ else echo "Unknown run type ${TEST_MODE}" exit 1 fi -python -m pip install $EXTRA_ARGS .[test,hdf5] codecov +python -m pip install $EXTRA_ARGS .[test,hdf5] diff --git a/tutorials/clinical/10_ieeg_localize.py b/tutorials/clinical/10_ieeg_localize.py index 16f793a5e4b..17b39284b5b 100644 --- a/tutorials/clinical/10_ieeg_localize.py +++ b/tutorials/clinical/10_ieeg_localize.py @@ -41,6 +41,7 @@ from dipy.align import resample import mne +import mne_gui_addons as mne_gui from mne.datasets import fetch_fsaverage # paths to mne datasets: sample sEEG and FreeSurfer's fsaverage subject, @@ -345,7 +346,7 @@ def plot_overlay(image, compare, title, thresh=None): # you may want to add `block=True` to halt execution until you have interacted # with the GUI to find the channel positions, that way the raw object can # be used later in the script (e.g. 
saved with channel positions) -mne.gui.locate_ieeg(raw.info, subj_trans, CT_aligned, +mne_gui.locate_ieeg(raw.info, subj_trans, CT_aligned, subject='sample_seeg', subjects_dir=misc_path / 'seeg') # The `raw` object is modified to contain the channel locations @@ -370,7 +371,7 @@ def plot_overlay(image, compare, title, thresh=None): # use estimated `trans` which was used when the locations were found previously subj_trans_ecog = mne.coreg.estimate_head_mri_t( 'sample_ecog', misc_path / 'ecog') -mne.gui.locate_ieeg(raw_ecog.info, subj_trans_ecog, CT_aligned_ecog, +mne_gui.locate_ieeg(raw_ecog.info, subj_trans_ecog, CT_aligned_ecog, subject='sample_ecog', subjects_dir=misc_path / 'ecog') From 4a467838cdd3b013547a813c1772f3a481546714 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Thu, 13 Apr 2023 18:16:30 +0200 Subject: [PATCH 0017/1125] Display SVG figures correctly in Report (#11623) --- doc/changes/latest.inc | 2 +- mne/html_templates/report/image.html.jinja | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 5392610e775..67b9e61f0ff 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -67,7 +67,7 @@ Bugs - Fix bug in :func:`mne.concatenate_raws` where two raws could not be merged if the order of the bad channel lists did not match (:gh:`11502` by `Moritz Gerster`_) - Fix bug where :meth:`mne.Evoked.plot_topomap` opened an extra figure (:gh:`11607` by `Alex Rockhill`_) - Fix bug where :func:`mne.transforms.apply_volume_registration_points` modified info in place (:gh:`11612` by `Alex Rockhill`_) - +- In :class:`~mne.Report`, custom figures now show up correctly when ``image_format='svg'`` is requested (:gh:`11623` by `Richard Höchenberger`_) API changes ~~~~~~~~~~~ diff --git a/mne/html_templates/report/image.html.jinja b/mne/html_templates/report/image.html.jinja index 6f80960ccc3..06a6855ace5 100644 --- a/mne/html_templates/report/image.html.jinja +++ b/mne/html_templates/report/image.html.jinja @@ -3,7 +3,7 @@
{% if image_format == 'svg' %}
- {{ img }} + {{ img|safe }}
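+{# clarifying comment, not in the original patch: the `safe` filter marks the inline SVG string as already-escaped, so Jinja autoescaping leaves its markup intact instead of rendering it as text #}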
{% else %} {{ title }} Date: Thu, 13 Apr 2023 16:37:52 -0500 Subject: [PATCH 0018/1125] make test compatible with future pandas (#11625) --- mne/time_frequency/tests/test_spectrum.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 28a4244b789..14b68890753 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -137,8 +137,6 @@ def _agg_helper(df, weights, group_cols): return Series(_df) -# TODO: Fix this warning -@pytest.mark.filterwarnings("ignore:.*columns to operate on.*:FutureWarning") @pytest.mark.parametrize('long_format', (False, True)) @pytest.mark.parametrize('method, output', [ ('welch', 'complex'), @@ -176,8 +174,8 @@ def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output): # sorting at the agg step *sigh* _inplace(orig_df, 'sort_values', by=grouping_cols, ignore_index=True) # aggregate - gb = df.drop(columns=drop_cols).groupby( - grouping_cols, as_index=False, observed=False) + df = df.drop(columns=drop_cols) + gb = df.groupby(grouping_cols, as_index=False, observed=False) if method == 'welch': if output == 'complex': def _fun(x): @@ -186,6 +184,7 @@ def _fun(x): _fun = np.nanmean agg_df = gb.aggregate(_fun) else: + gb = gb[df.columns] # https://github.com/pandas-dev/pandas/pull/52477 agg_df = gb.apply(_agg_helper, spectrum._mt_weights, grouping_cols) # even with check_categorical=False, we know that the *data* matches; # what may differ is the order of the "levels" in the *metadata* for the From fc981bd93b7a485480d0d074fe29fbfb23443609 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 14 Apr 2023 10:59:52 -0400 Subject: [PATCH 0019/1125] ENH: Allow gradient compensated data in maxwell_filter (#10554) --- doc/changes/latest.inc | 1 + mne/epochs.py | 11 ++++-- mne/io/base.py | 37 ++++++++---------- mne/preprocessing/maxwell.py | 51 ++++++++++++++++++------- mne/preprocessing/tests/test_maxwell.py | 48 ++++++++++++++++------- 5 files changed, 97 insertions(+), 51 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 67b9e61f0ff..b1b3f840a7b 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -41,6 +41,7 @@ Enhancements - Allow an image with intracranial electrode contacts (e.g. 
computed tomography) to be used without the freesurfer recon-all surfaces to locate contacts so that it doesn't have to be downsampled to freesurfer dimensions (for microelectrodes) and show an example :ref:`ex-ieeg-micro` with :func:`mne.transforms.apply_volume_registration_points` added to aid this transform (:gh:`11567` by `Alex Rockhill`_) - Use new :meth:`dipy.workflows.align.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_) - Improve performance of raw data browsing with many annotations (:gh:`11614` by `Eric Larson`_) +- Add support for :func:`mne.preprocessing.maxwell_filter` with gradient-compensated CTF data, e.g., for tSSS-only mode (:gh:`10554` by `Eric Larson`_) - Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_) Bugs diff --git a/mne/epochs.py b/mne/epochs.py index ae0f6736564..099d4651009 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3671,7 +3671,8 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, _check_usable, _col_norm_pinv, _get_n_moments, _get_mf_picks_fix_mags, _prep_mf_coils, _check_destination, - _remove_meg_projs, _get_coil_scale) + _remove_meg_projs_comps, + _get_coil_scale, _get_sensor_operator) if head_pos is None: raise TypeError('head_pos must be provided and cannot be None') from .chpi import head_pos_to_trans_rot_t @@ -3684,7 +3685,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, head_pos = head_pos_to_trans_rot_t(head_pos) trn, rot, t = head_pos del head_pos - _check_usable(epochs) + _check_usable(epochs, ignore_ref) origin = _check_origin(origin, epochs.info, 'head') recon_trans = _check_destination(destination, epochs.info, True) @@ -3697,6 +3698,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, _get_mf_picks_fix_mags(info_to, int_order, ext_order, ignore_ref) coil_scale, mag_scale = _get_coil_scale( meg_picks, mag_picks, grad_picks, mag_scale, info_to) + mult = _get_sensor_operator(epochs, meg_picks) n_channels, n_times = len(epochs.ch_names), len(epochs.times) other_picks = np.setdiff1d(np.arange(n_channels), meg_picks) data = np.zeros((n_channels, n_times)) @@ -3761,6 +3763,9 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, # (We would need to include external here for regularization to work) exp['ext_order'] = 0 S_recon = _trans_sss_basis(exp, all_coils_recon, recon_trans) + if mult is not None: + S_decomp = mult @ S_decomp + S_recon = mult @ S_recon exp['ext_order'] = ext_order # We could determine regularization on basis of destination basis # matrix, restricted to good channels, as regularizing individual @@ -3779,7 +3784,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, evoked = epochs._evoked_from_epoch_data(data, info_to, picks, n_events=count, kind='average', comment=epochs._name) - _remove_meg_projs(evoked) # remove MEG projectors, they won't apply now + _remove_meg_projs_comps(evoked, ignore_ref) logger.info('Created Evoked dataset from %s epochs' % (count,)) return (evoked, mapping) if return_mapping else evoked diff --git a/mne/io/base.py b/mne/io/base.py index 05645de8cf3..96e4e0f2549 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -281,8 +281,8 @@ def _dtype(self): return self._dtype_ @verbose - def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, - projector=None, verbose=None): + def _read_segment(self, start=0, 
stop=None, sel=None, data_buffer=None, *, + verbose=None): """Read a chunk of raw data. Parameters @@ -344,26 +344,22 @@ def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, # set up cals and mult (cals, compensation, and projector) n_out = len(np.arange(len(self.ch_names))[idx]) - cals = self._cals.ravel()[np.newaxis, :] - if projector is not None: - assert projector.shape[0] == projector.shape[1] == cals.shape[1] - if self._comp is not None: + cals = self._cals.ravel() + projector, comp = self._projector, self._comp + if comp is not None: + mult = comp if projector is not None: - mult = self._comp * cals - mult = np.dot(projector[idx], mult) - else: - mult = self._comp[idx] * cals - elif projector is not None: - mult = projector[idx] * cals + mult = projector @ mult else: - mult = None - del projector + mult = projector + del projector, comp if mult is None: - cals = cals.T[idx] + cals = cals[idx, np.newaxis] assert cals.shape == (n_out, 1) need_idx = idx # sufficient just to read the given channels else: + mult = mult[idx] * cals cals = None # shouldn't be used assert mult.shape == (n_out, len(self.ch_names)) # read all necessary for proj @@ -504,8 +500,7 @@ def _preload_data(self, preload): data_buffer = None logger.info('Reading %d ... %d = %9.3f ... %9.3f secs...' % (0, len(self.times) - 1, 0., self.times[-1])) - self._data = self._read_segment( - data_buffer=data_buffer, projector=self._projector) + self._data = self._read_segment(data_buffer=data_buffer) assert len(self._data) == self.info['nchan'] self.preload = True self._comp = None # no longer needed @@ -752,8 +747,7 @@ def _getitem(self, item, return_times=True): if self.preload: data = self._data[sel, start:stop] else: - data = self._read_segment(start=start, stop=stop, sel=sel, - projector=self._projector) + data = self._read_segment(start=start, stop=stop, sel=sel) if return_times: # Rather than compute the entire thing just compute the subset @@ -1669,7 +1663,7 @@ def append(self, raws, preload=None): nsamp = c_ns[-1] if not self.preload: - this_data = self._read_segment(projector=self._projector) + this_data = self._read_segment() else: this_data = self._data @@ -1681,8 +1675,7 @@ def append(self, raws, preload=None): if not raws[ri].preload: # read the data directly into the buffer data_buffer = _data[:, c_ns[ri]:c_ns[ri + 1]] - raws[ri]._read_segment(data_buffer=data_buffer, - projector=self._projector) + raws[ri]._read_segment(data_buffer=data_buffer) else: _data[:, c_ns[ri]:c_ns[ri + 1]] = raws[ri]._data self._data = _data diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index d270d716e5c..b6dba1fc21a 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -23,6 +23,7 @@ quat_to_rot, rot_to_quat) from ..forward import _concatenate_coils, _prep_meg_channels, _create_meg_coils from ..surface import _normalize_vectors +from ..io.compensator import make_compensator from ..io.constants import FIFF, FWD from ..io.meas_info import _simplify_info, Info from ..io.proc_history import _read_ctc @@ -376,7 +377,7 @@ def _prep_maxwell_filter( # triage inputs ASAP to avoid late-thrown errors _validate_type(raw, BaseRaw, 'raw') - _check_usable(raw) + _check_usable(raw, ignore_ref) _check_regularize(regularize) st_correlation = float(st_correlation) if st_correlation <= 0. or st_correlation > 1.: @@ -478,7 +479,6 @@ def _prep_maxwell_filter( exp['extended_proj'] = extended_proj del extended_proj # Reconstruct data from internal space only (Eq. 
38), and rescale S_recon - S_recon /= coil_scale if recon_trans is not None: # warn if we have translated too far diff = 1000 * (info['dev_head_t']['trans'][:3, 3] - @@ -520,13 +520,20 @@ def _prep_maxwell_filter( np.zeros(3)]) else: this_pos_quat = None + + # Figure out our linear operator + mult = _get_sensor_operator(raw, meg_picks) + if mult is not None: + S_recon = mult @ S_recon + S_recon /= coil_scale + _get_this_decomp_trans = partial( _get_decomp, all_coils=all_coils, cal=calibration, regularize=regularize, exp=exp, ignore_ref=ignore_ref, coil_scale=coil_scale, grad_picks=grad_picks, mag_picks=mag_picks, good_mask=good_mask, mag_or_fine=mag_or_fine, bad_condition=bad_condition, - mag_scale=mag_scale) + mag_scale=mag_scale, mult=mult) update_kwargs.update( nchan=good_mask.sum(), st_only=st_only, recon_trans=recon_trans) params = dict( @@ -536,7 +543,7 @@ def _prep_maxwell_filter( this_pos_quat=this_pos_quat, meg_picks=meg_picks, good_mask=good_mask, grad_picks=grad_picks, head_pos=head_pos, info=info, _get_this_decomp_trans=_get_this_decomp_trans, - S_recon=S_recon, update_kwargs=update_kwargs) + S_recon=S_recon, update_kwargs=update_kwargs, ignore_ref=ignore_ref) return params @@ -544,7 +551,7 @@ def _run_maxwell_filter( raw, skip_by_annotation, st_duration, st_correlation, st_only, st_when, ctc, coil_scale, this_pos_quat, meg_picks, good_mask, grad_picks, head_pos, info, _get_this_decomp_trans, S_recon, - update_kwargs, + update_kwargs, *, ignore_ref=False, reconstruct='in', copy=True): # Eventually find_bad_channels_maxwell could be sped up by moving this # outside the loop (e.g., in the prep function) but regularization depends @@ -564,7 +571,7 @@ def _run_maxwell_filter( del raw if not st_only: # remove MEG projectors, they won't apply now - _remove_meg_projs(raw_sss) + _remove_meg_projs_comps(raw_sss, ignore_ref) # Figure out which segments of data we can use onsets, ends = _annotations_starts_stops( raw_sss, skip_by_annotation, invert=True) @@ -745,7 +752,19 @@ def _get_coil_scale(meg_picks, mag_picks, grad_picks, mag_scale, info): return coil_scale, mag_scale -def _remove_meg_projs(inst): +def _get_sensor_operator(raw, meg_picks): + comp = raw.compensation_grade + if comp not in (0, None): + mult = make_compensator(raw.info, 0, comp) + logger.info(f' Accounting for compensation grade {comp}') + assert mult.shape[0] == mult.shape[1] == len(raw.ch_names) + mult = mult[np.ix_(meg_picks, meg_picks)] + else: + mult = None + return mult + + +def _remove_meg_projs_comps(inst, ignore_ref): """Remove inplace existing MEG projectors (assumes inactive).""" meg_picks = pick_types(inst.info, meg=True, exclude=[]) meg_channels = [inst.ch_names[pi] for pi in meg_picks] @@ -754,6 +773,10 @@ def _remove_meg_projs(inst): if not any(c in meg_channels for c in proj['data']['col_names']): non_meg_proj.append(proj) inst.add_proj(non_meg_proj, remove_existing=True, verbose=False) + if ignore_ref and inst.info['comps']: + assert inst.compensation_grade in (None, 0) + with inst.info._unlock(): + inst.info['comps'] = [] def _check_destination(destination, info, head_frame): @@ -959,9 +982,9 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq): return pos -def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref, +def _get_decomp(trans, *, all_coils, cal, regularize, exp, ignore_ref, coil_scale, grad_picks, mag_picks, good_mask, mag_or_fine, - bad_condition, t, mag_scale): + bad_condition, t, mag_scale, mult): """Get a decomposition matrix and pseudoinverse matrices.""" from scipy 
import linalg
    #
@@ -970,6 +993,8 @@ def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref,
     S_decomp_full = _get_s_decomp(
         exp, all_coils, trans, coil_scale, cal, ignore_ref, grad_picks,
         mag_picks, mag_scale)
+    if mult is not None:
+        S_decomp_full = mult @ S_decomp_full
     S_decomp = S_decomp_full[good_mask]
     #
     # Extended SSS basis (eSSS)
@@ -1143,16 +1168,16 @@ def _check_regularize(regularize):
         raise ValueError('regularize must be None or "in"')
 
 
-def _check_usable(inst):
+def _check_usable(inst, ignore_ref):
     """Ensure our data are clean."""
     if inst.proj:
         raise RuntimeError('Projectors cannot be applied to data during '
                            'Maxwell filtering.')
     current_comp = inst.compensation_grade
-    if current_comp not in (0, None):
+    if current_comp not in (0, None) and ignore_ref:
         raise RuntimeError('Maxwell filter cannot be done on compensated '
-                           'channels, but data have been compensated with '
-                           'grade %s.' % current_comp)
+                           'channels (data have been compensated with '
+                           f'grade {current_comp}) when ignore_ref=True')
 
 
 def _col_norm_pinv(x):
diff --git a/mne/preprocessing/tests/test_maxwell.py index 06d4c193c70..233879e173a 100644
--- a/mne/preprocessing/tests/test_maxwell.py
+++ b/mne/preprocessing/tests/test_maxwell.py
@@ -303,27 +303,49 @@ def test_other_systems():
     _assert_shielding(raw_sss_auto, power, 0.7)
 
     # CTF
-    raw_ctf = read_crop(fname_ctf_raw)
-    assert raw_ctf.compensation_grade == 3
-    with pytest.raises(RuntimeError, match='compensated'):
-        maxwell_filter(raw_ctf)
-    raw_ctf.apply_gradient_compensation(0)
+    raw_ctf_3 = read_crop(fname_ctf_raw)
+    assert raw_ctf_3.compensation_grade == 3
+    raw_ctf_0 = raw_ctf_3.copy().apply_gradient_compensation(0)
+    assert raw_ctf_0.compensation_grade == 0
+    # 3rd-order gradient compensation works really well (better than MF here)
+    _assert_shielding(raw_ctf_3, raw_ctf_0, 20, 21)
+    origin = (0., 0., 0.04)
+    raw_sss_3 = maxwell_filter(raw_ctf_3, origin=origin, verbose=True)
+    _assert_n_free(raw_sss_3, 70)
+    _assert_shielding(raw_sss_3, raw_ctf_3, 0.12, 0.14)
+    _assert_shielding(raw_sss_3, raw_ctf_0, 2.63, 2.66)
+    assert raw_sss_3.compensation_grade == 3
+    raw_sss_3.apply_gradient_compensation(0)
+    assert raw_sss_3.compensation_grade == 0
+    _assert_shielding(raw_sss_3, raw_ctf_3, 0.15, 0.17)
+    _assert_shielding(raw_sss_3, raw_ctf_0, 3.18, 3.20)
     with pytest.raises(ValueError, match='digitization points'):
-        maxwell_filter(raw_ctf)
-    raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04))
-    _assert_n_free(raw_sss, 68)
-    _assert_shielding(raw_sss, raw_ctf, 1.8)
+        maxwell_filter(raw_ctf_0)
+    raw_sss_0 = maxwell_filter(raw_ctf_0, origin=origin, verbose=True)
+    _assert_n_free(raw_sss_0, 68)
+    _assert_shielding(raw_sss_0, raw_ctf_3, 0.07, 0.09)
+    _assert_shielding(raw_sss_0, raw_ctf_0, 1.8, 1.9)
+    raw_sss_0.apply_gradient_compensation(3)
+    _assert_shielding(raw_sss_0, raw_ctf_3, 0.07, 0.09)
+    _assert_shielding(raw_sss_0, raw_ctf_0, 1.63, 1.67)
+    with pytest.raises(RuntimeError, match='ignore_ref'):
+        maxwell_filter(raw_ctf_3, ignore_ref=True)
+    # ignoring ref outperforms including it in maxwell filtering
     with catch_logging() as log:
-        raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04),
+        raw_sss = maxwell_filter(raw_ctf_0, origin=origin,
                                  ignore_ref=True, verbose=True)
     assert ', 12/15 out' in log.getvalue()  # homogeneous fields removed
     _assert_n_free(raw_sss, 70)
-    _assert_shielding(raw_sss, raw_ctf, 12)
+    _assert_shielding(raw_sss, raw_ctf_0, 12, 13)
+
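    # Editor's aside (hypothetical user-facing sketch, not part of this test):
    # the pattern exercised above is roughly
    #
    #     raw = mne.io.read_raw_ctf(fname).load_data()
    #     raw.apply_gradient_compensation(0)  # undo CTF 3rd-order comp
    #     raw_sss = mne.preprocessing.maxwell_filter(
    #         raw, origin=(0., 0., 0.04), ignore_ref=True)
    #
    # Grade-3 data may now also be passed directly (the compensator is folded
    # into the SSS bases), but ignore_ref=True still requires grade 0, as the
    # pytest.raises check above asserts.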
# if ignore_ref=True, we remove compensators because they will not + # work the way people expect (it puts noise back in the data!) + with pytest.raises(ValueError, match='Desired compensation.*not found'): + raw_sss.copy().apply_gradient_compensation(3) + raw_sss_auto = maxwell_filter(raw_ctf_0, origin=origin, ignore_ref=True, mag_scale='auto') assert_allclose(raw_sss._data, raw_sss_auto._data) with catch_logging() as log: - maxwell_filter(raw_ctf, origin=(0., 0., 0.04), regularize=None, + maxwell_filter(raw_ctf_0, origin=origin, regularize=None, ignore_ref=True, verbose=True) assert '80/80 in, 12/15 out' in log.getvalue() # homogeneous fields From 1c301417dd3f80ea77289f2fd60a4fe585b9db8b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 14 Apr 2023 16:14:26 -0400 Subject: [PATCH 0020/1125] MAINT: Use VTK prerelease wheels in pre jobs (#11629) --- examples/preprocessing/locate_ieeg_micro.py | 3 ++- mne/tests/test_surface.py | 18 ++++++++++++------ requirements.txt | 1 - requirements_doc.txt | 9 +++++---- tools/azure_dependencies.sh | 2 +- tools/github_actions_dependencies.sh | 2 +- tutorials/clinical/10_ieeg_localize.py | 1 - 7 files changed, 21 insertions(+), 15 deletions(-) diff --git a/examples/preprocessing/locate_ieeg_micro.py b/examples/preprocessing/locate_ieeg_micro.py index 0ab653e93f6..20142af7177 100644 --- a/examples/preprocessing/locate_ieeg_micro.py +++ b/examples/preprocessing/locate_ieeg_micro.py @@ -19,6 +19,7 @@ import numpy as np import nibabel as nib import mne +import mne_gui_addons # path to sample sEEG misc_path = mne.datasets.misc.data_path() @@ -55,7 +56,7 @@ # launch the viewer with only the CT (note, we won't be able to use # the MR in this case to help determine which brain area the contact is # in), and use the user interface to find the locations of the contacts -gui = mne.gui.locate_ieeg(raw.info, head_ct_t, CT_orig) +gui = mne_gui_addons.locate_ieeg(raw.info, head_ct_t, CT_orig) # we'll programmatically mark all the contacts on one electrode shaft for i, pos in enumerate([(-52.66, -40.84, -26.99), (-55.47, -38.03, -27.92), diff --git a/mne/tests/test_surface.py b/mne/tests/test_surface.py index d20221bbd94..024e47bce34 100644 --- a/mne/tests/test_surface.py +++ b/mne/tests/test_surface.py @@ -164,7 +164,8 @@ def test_read_curv(): assert np.logical_or(bin_curv == 0, bin_curv == 1).all() -def test_decimate_surface_vtk(): +@pytest.mark.parametrize('n_tri', (4, 3, 2)) +def test_decimate_surface_vtk(n_tri): """Test triangular surface decimation.""" pytest.importorskip('pyvista') points = np.array([[-0.00686118, -0.10369860, 0.02615170], @@ -172,14 +173,17 @@ def test_decimate_surface_vtk(): [-0.00686208, -0.10368247, 0.02588313], [-0.00713987, -0.10368724, 0.02587745]]) tris = np.array([[0, 1, 2], [1, 2, 3], [0, 3, 1], [1, 2, 0]]) - for n_tri in [4, 3, 2]: # quadric decimation creates even numbered output. 
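(Editor's aside: a minimal runnable sketch of what the new expectations below
encode. ``n_triangles`` is a target that VTK's quadric decimation may
undershoot; the input data here are illustrative.)

    import numpy as np
    from mne import decimate_surface

    points = np.random.default_rng(0).random((4, 3))
    tris = np.array([[0, 1, 2], [1, 2, 3], [0, 3, 1], [1, 2, 0]])
    _, new_tris = decimate_surface(points, tris, n_triangles=2)
    print(len(new_tris))  # 2 or 1: the target is an upper bound, not exact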
- _, this_tris = decimate_surface(points, tris, n_tri) - assert len(this_tris) == n_tri if not n_tri % 2 else 2 + _, this_tris = decimate_surface(points, tris, n_tri) + want = (n_tri, n_tri - 1) + if n_tri == 3: + want = want + (1,) + assert len(this_tris) in want with pytest.raises(ValueError, match='exceeds number of original'): decimate_surface(points, tris, len(tris) + 1) nirvana = 5 tris = np.array([[0, 1, 2], [1, 2, 3], [0, 3, 1], [1, 2, nirvana]]) - pytest.raises(ValueError, decimate_surface, points, tris, n_tri) + with pytest.raises(ValueError, match='undefined points'): + decimate_surface(points, tris, n_tri) @requires_freesurfer('mris_sphere') @@ -238,7 +242,9 @@ def test_marching_cubes(dtype, value, smooth): # verts and faces are rather large so use checksum rtol = 1e-2 if smooth else 1e-9 assert_allclose(verts.sum(axis=0), [14700, 14700, 14700], rtol=rtol) - assert_allclose(triangles.sum(axis=0), [363402, 360865, 350588]) + tri_sum = triangles.sum(axis=0).tolist() + # old VTK (9.2.6), new VTK + assert tri_sum in [[363402, 360865, 350588], [364089, 359867, 350408]] # test fill holes data[24:27, 24:27, 24:27] = 0 verts, triangles = _marching_cubes(data, level, smooth=smooth, diff --git a/requirements.txt b/requirements.txt index de69e9e4574..eeaa71cb9cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,4 +43,3 @@ mne-qt-browser darkdetect qdarkstyle threadpoolctl -mne-gui-addons diff --git a/requirements_doc.txt b/requirements_doc.txt index 140be6651a5..180d1699780 100644 --- a/requirements_doc.txt +++ b/requirements_doc.txt @@ -1,8 +1,8 @@ # requirements for building docs sphinx!=4.1.0,<6 -https://github.com/numpy/numpydoc/archive/main.zip +git+https://github.com/numpy/numpydoc.git@main pydata_sphinx_theme==0.13.1 -https://github.com/sphinx-gallery/sphinx-gallery/archive/master.zip +git+https://github.com/sphinx-gallery/sphinx-gallery@master sphinxcontrib-bibtex>=2.5 memory_profiler neo @@ -10,9 +10,10 @@ seaborn!=0.11.2 sphinx_copybutton sphinx-design sphinxcontrib-youtube -https://github.com/mne-tools/mne-bids/archive/main.zip +git+https://github.com/mne-tools/mne-bids@main pyxdf -https://github.com/mne-tools/mne-connectivity/archive/main.zip +git+https://github.com/mne-tools/mne-connectivity.git@main +git+https://github.com/mne-tools/mne-gui-addons.git@main pygments>=2.13 pytest graphviz diff --git a/tools/azure_dependencies.sh b/tools/azure_dependencies.sh index 0e704b5f60c..4b8647552b8 100755 --- a/tools/azure_dependencies.sh +++ b/tools/azure_dependencies.sh @@ -18,7 +18,7 @@ elif [ "${TEST_MODE}" == "pip-pre" ]; then python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -i "https://pypi.anaconda.org/scipy-wheels-nightly/simple" scipy statsmodels pandas scikit-learn dipy matplotlib python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -f "https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com/" h5py python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -i "https://test.pypi.org/simple" openmeeg - python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps vtk + python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -i "https://wheels.vtk.org/" vtk python -m pip install --progress-bar off git+https://github.com/pyvista/pyvista python -m pip install --progress-bar off git+https://github.com/pyvista/pyvistaqt python -m pip install 
--progress-bar off --upgrade --pre imageio-ffmpeg xlrd mffpy python-picard patsy pillow diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index 02df34f11cd..3fa77f91e90 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -29,7 +29,7 @@ else pip install $STD_ARGS --pre git+https://github.com/nilearn/nilearn pip install $STD_ARGS --pre --only-binary ":all:" -i "https://test.pypi.org/simple" openmeeg echo "VTK" - pip install $STD_ARGS --pre --only-binary ":all:" vtk + pip install $STD_ARGS --pre --only-binary ":all:" -i "https://wheels.vtk.org/" vtk python -c "import vtk" echo "PyVista" pip install --progress-bar off git+https://github.com/pyvista/pyvista diff --git a/tutorials/clinical/10_ieeg_localize.py b/tutorials/clinical/10_ieeg_localize.py index 17b39284b5b..b256c903553 100644 --- a/tutorials/clinical/10_ieeg_localize.py +++ b/tutorials/clinical/10_ieeg_localize.py @@ -26,7 +26,6 @@ use this module in your analysis to support the addition of new projects to MNE. """ - # Authors: Alex Rockhill # Eric Larson # From 5024b32782f7bf8e4dd944e57a3269c1b92d07d5 Mon Sep 17 00:00:00 2001 From: Zvi Baratz Date: Sat, 15 Apr 2023 01:35:22 +0300 Subject: [PATCH 0021/1125] Fix index name in to_data_frame()'s docstring (#11457) Co-authored-by: Daniel McCloy --- doc/changes/latest.inc | 1 + doc/changes/names.inc | 2 ++ mne/time_frequency/spectrum.py | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index b1b3f840a7b..84fb0e35d6b 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -56,6 +56,7 @@ Bugs - Fix :func:`mne.io.read_raw_edf` when reading EDF data with different sampling rates and a mix of data channels when using ``infer_types=True`` (:gh:`11427` by `Alex Gramfort`_) - Fix how :class:`mne.channels.DigMontage` is set when using :func:`mne.preprocessing.ieeg.project_sensors_onto_brain` so that :func:`mne.Info.get_montage` works and does not return ``None`` (:gh:`11436` by `Alex Rockhill`_) - Fix configuration folder discovery on Windows, which would fail in certain edge cases; and produce a helpful error message if discovery still fails (:gh:`11441` by `Richard Höchenberger`_) +- Fix :meth:`mne.time_frequency.Spectrum.to_data_frame`'s docstring to reflect the correct name for the appended frequencies column (:gh:`11457` by :newcontrib:`Zvi Baratz`) - Make :class:`~mne.decoding.SlidingEstimator` and :class:`~mne.decoding.GeneralizingEstimator` respect the ``verbose`` argument. Now with ``verbose=False``, the progress bar is not shown during fitting, scoring, etc. (:gh:`11450` by `Mikołaj Magnuski`_) - Fix bug with ``mne.gui.locate_ieeg`` where Freesurfer ``?h.pial.T1`` was not recognized and suppress excess logging (:gh:`11489` by `Alex Rockhill`_) - All functions accepting paths can now correctly handle :class:`~pathlib.Path` as input. Historically, we expected strings (instead of "proper" path objects), and only added :class:`~pathlib.Path` support in a few select places, leading to inconsistent behavior. (:gh:`11473` and :gh:`11499` by `Mathieu Scheltienne`_) diff --git a/doc/changes/names.inc b/doc/changes/names.inc index e1e7acafdf0..7508543cc39 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -543,3 +543,5 @@ .. _Yu-Han Luo: https://github.com/yh-luo .. _Zhi Zhang: https://github.com/tczhangzhi/ + +.. 
_Zvi Baratz: https://github.com/ZviBaratz diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 0a4e05da6df..0daf25690c1 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -730,7 +730,7 @@ def to_data_frame(self, picks=None, index=None, copy=True, """Export data in tabular structure as a pandas DataFrame. Channels are converted to columns in the DataFrame. By default, - an additional column "frequency" is added, unless ``index='freq'`` + an additional column "freq" is added, unless ``index='freq'`` (in which case frequency values form the DataFrame's index). Parameters From f88f22d52b3b5e2cba3f790965ac14d2fe4a13e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 17 Apr 2023 12:24:20 +0200 Subject: [PATCH 0022/1125] MRG: Allow retrieval of channel names via make_1020_channel_selections() (#11632) --- doc/changes/latest.inc | 1 + .../contralateral_referencing.py | 9 ++-- mne/channels/channels.py | 45 ++++++++++++------- mne/channels/tests/test_channels.py | 5 +++ 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 84fb0e35d6b..9dc28d470f4 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -43,6 +43,7 @@ Enhancements - Improve performance of raw data browsing with many annotations (:gh:`11614` by `Eric Larson`_) - Add support for :func:`mne.preprocessing.maxwell_filter` with gradient-compensated CTF data, e.g., for tSSS-only mode (:gh:`10554` by `Eric Larson`_) - Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_) +- :func:`mne.channels.make_1020_channel_selections` gained a new parameter, ``return_ch_names``, to allow for easy retrieval of EEG channel names corresponding to the left, right, and midline portions of the montage (:gh:`11632` by `Richard Höchenberger`_) Bugs ~~~~ diff --git a/examples/preprocessing/contralateral_referencing.py b/examples/preprocessing/contralateral_referencing.py index ad31d94f742..2c04ccc7c8f 100644 --- a/examples/preprocessing/contralateral_referencing.py +++ b/examples/preprocessing/contralateral_referencing.py @@ -12,7 +12,6 @@ contralateral EEG reference. """ -import numpy as np import mne ssvep_folder = mne.datasets.ssvep.data_path() @@ -31,11 +30,9 @@ }) # this splits electrodes into 3 groups; left, midline, and right -ch_indices = mne.channels.make_1020_channel_selections(raw.info) - -# convert indices to names -orig_names = np.array(raw.ch_names) -ch_names = {key: orig_names[idxs].tolist() for key, idxs in ch_indices.items()} +ch_names = mne.channels.make_1020_channel_selections( + raw.info, return_ch_names=True +) # remove the ref channels from the lists of to-be-rereferenced channels ch_names['Left'].remove('M1') diff --git a/mne/channels/channels.py b/mne/channels/channels.py index c3c86d20a34..d0ae5f01673 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -1837,31 +1837,39 @@ def _get_ch_info(info): @fill_doc -def make_1020_channel_selections(info, midline="z"): - """Return dict mapping from ROI names to lists of picks for 10/20 setups. - - This passes through all channel names, and uses a simple heuristic to - separate channel names into three Region of Interest-based selections: - Left, Midline and Right. 
The heuristic is that channels ending on any of
-    the characters in ``midline`` are filed under that heading, otherwise those
-    ending in odd numbers under "Left", those in even numbers under "Right".
-    Other channels are ignored. This is appropriate for 10/20 files, but not
-    for other channel naming conventions.
-    If an info object is provided, lists are sorted from posterior to anterior.
+def make_1020_channel_selections(info, midline="z", *, return_ch_names=False):
+    """Map hemisphere names to corresponding EEG channel names or indices.
+
+    This function uses a simple heuristic to separate channel names into three
+    Region of Interest-based selections: ``Left``, ``Midline`` and ``Right``.
+
+    The heuristic is that channel names ending in odd numbers are filed under
+    ``Left``; those ending in even numbers under ``Right``; and those ending
+    in the character(s) specified in ``midline`` under ``Midline``. Other
+    channels are ignored.
+
+    This is appropriate for 10/20, 10/10, 10/05, … sensor arrangements, but
+    not for other naming conventions.
 
     Parameters
     ----------
-    %(info_not_none)s If possible, the channel lists will be sorted
-        posterior-to-anterior; otherwise they default to the order specified in
-        ``info["ch_names"]``.
+    %(info_not_none)s If channel locations are present, the channel lists will
+        be sorted from posterior to anterior; otherwise, the order specified
+        in ``info["ch_names"]`` will be kept.
     midline : str
         Names ending in any of these characters are stored under the
-        ``Midline`` key. Defaults to 'z'. Note that capitalization is ignored.
+        ``Midline`` key. Defaults to ``'z'``. Capitalization is ignored.
+    return_ch_names : bool
+        Whether to return channel names instead of channel indices.
+
+        .. versionadded:: 1.4.0
 
     Returns
     -------
     selections : dict
-        A dictionary mapping from ROI names to lists of picks (integers).
+        A dictionary mapping from region of interest name to a list of channel
+        indices (if ``return_ch_names=False``) or to a list of channel names
+        (if ``return_ch_names=True``).
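    An editorial usage sketch (the channel set and the grouping shown are
    illustrative assumptions, not doctest output):

        import mne

        info = mne.create_info(['F3', 'Fz', 'F4', 'C3', 'Cz', 'C4'],
                               1000., 'eeg')
        sels = mne.channels.make_1020_channel_selections(
            info, return_ch_names=True)
        # expected by the odd/even/midline heuristic:
        # {'Left': ['F3', 'C3'], 'Midline': ['Fz', 'Cz'],
        #  'Right': ['F4', 'C4']}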
""" _validate_type(info, "info") @@ -1891,6 +1899,11 @@ def make_1020_channel_selections(info, midline="z"): selections = {selection: np.array(picks)[pos[picks, 1].argsort()] for selection, picks in selections.items()} + # convert channel indices to names if requested + if return_ch_names: + for selection, ch_indices in selections.items(): + selections[selection] = [info.ch_names[idx] for idx in ch_indices] + return selections diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index 7634100afbe..585ce9b43cc 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -384,6 +384,11 @@ def test_1020_selection(): for channel, roi in zip(fz_c3_c4, ("Midline", "Left", "Right")): assert channel in sels[roi] + # ensure returning channel names works as expected + sels_names = make_1020_channel_selections(raw.info, return_ch_names=True) + for selection, ch_names in sels_names.items(): + assert ch_names == [raw.ch_names[idx] for idx in sels[selection]] + @testing.requires_testing_data def test_find_ch_adjacency(): From 8fc3d07c8ab32101cba41723f3f8d825a08166df Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Apr 2023 19:38:06 -0400 Subject: [PATCH 0023/1125] MAINT: Report download time and size (#11635) --- .circleci/config.yml | 4 + mne/datasets/_fetch.py | 21 ++-- mne/datasets/config.py | 6 +- mne/datasets/eegbci/eegbci.py | 27 +++-- mne/datasets/limo/limo.py | 29 +++-- mne/datasets/sleep_physionet/_utils.py | 37 +++--- mne/datasets/sleep_physionet/age.py | 27 +++-- mne/datasets/sleep_physionet/temazepam.py | 27 +++-- .../sleep_physionet/tests/test_physionet.py | 48 ++++---- mne/datasets/tests/test_datasets.py | 2 +- mne/datasets/utils.py | 106 ++++++++++++------ mne/utils/config.py | 1 + tools/circleci_download.sh | 1 + 13 files changed, 218 insertions(+), 118 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index ab70c684e4d..9cbb54338d4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -230,6 +230,10 @@ jobs: - data-cache-ucl-opm-auditory - run: name: Get data + # This limit could be increased, but this is helpful for finding slow ones + # (even ~2GB datasets should be downloadable in this time from good + # providers) + no_output_timeout: 10m command: | ./tools/circleci_download.sh - run: diff --git a/mne/datasets/_fetch.py b/mne/datasets/_fetch.py index 87cd1664534..578c1cf82ed 100644 --- a/mne/datasets/_fetch.py +++ b/mne/datasets/_fetch.py @@ -2,12 +2,12 @@ # # License: BSD Style. -import logging import sys import os import os.path as op from pathlib import Path from shutil import rmtree +import time from .. import __version__ as mne_version from ..utils import logger, warn, _safe_input @@ -17,7 +17,8 @@ TESTING_VERSIONED, MISC_VERSIONED, ) -from .utils import _dataset_version, _do_path_update, _get_path +from .utils import (_dataset_version, _do_path_update, _get_path, + _log_time_size, _downloader_params) from ..fixes import _compare_version @@ -130,6 +131,7 @@ def fetch_dataset( pass a list of dicts. 
""" # noqa E501 import pooch + t0 = time.time() if auth is not None: if len(auth) != 2: @@ -220,13 +222,9 @@ def fetch_dataset( "You must agree to the license to use this " "dataset" ) # downloader & processors - download_params = dict(progressbar=logger.level <= logging.INFO) + download_params = _downloader_params(auth=auth, token=token) if name == "fake": download_params["progressbar"] = False - if auth is not None: - download_params["auth"] = auth - if token is not None: - download_params["headers"] = {"Authorization": f"token {token}"} downloader = pooch.HTTPDownloader(**download_params) # make mappings from archive names to urls and to checksums @@ -241,8 +239,9 @@ def fetch_dataset( registry[archive_name] = dataset_hash # create the download manager + use_path = final_path if processor is None else Path(path) fetcher = pooch.create( - path=str(final_path) if processor is None else path, + path=str(use_path), base_url="", # Full URLs are given in the `urls` dict. version=None, # Data versioning is decoupled from MNE-Python version. urls=urls, @@ -252,6 +251,7 @@ def fetch_dataset( # use our logger level for pooch's logger too pooch.get_logger().setLevel(logger.getEffectiveLevel()) + sz = 0 for idx in range(len(names)): # fetch and unpack the data @@ -268,9 +268,11 @@ def fetch_dataset( 'the dataset to be downloaded again.') from None else: raise + fname = use_path / archive_name + sz += fname.stat().st_size # after unpacking, remove the archive file if processor is not None: - os.remove(op.join(path, archive_name)) + fname.unlink() # remove version number from "misc" and "testing" datasets folder names if name == "misc": @@ -299,4 +301,5 @@ def fetch_dataset( name=name, current=data_version, newest=mne_version ) ) + _log_time_size(t0, sz) return (final_path, data_version) if return_version else final_path diff --git a/mne/datasets/config.py b/mne/datasets/config.py index dc851e9bd2f..e84d63b41c4 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -320,8 +320,10 @@ MNE_DATASETS['hf_sef_evoked'] = dict( archive_name='hf_sef_evoked.tar.gz', hash='md5:13d34cb5db584e00868677d8fb0aab2b', - url=('https://zenodo.org/record/3523071/files/' - 'hf_sef_evoked.tar.gz'), + # Zenodo can be slow, so we use the OSF mirror + # url=('https://zenodo.org/record/3523071/files/' + # 'hf_sef_evoked.tar.gz'), + url='https://osf.io/25f8d/download?version=2', folder_name='hf_sef', config_key='MNE_DATASETS_HF_SEF_PATH', ) diff --git a/mne/datasets/eegbci/eegbci.py b/mne/datasets/eegbci/eegbci.py index d976425dd7a..e89ae089fcc 100644 --- a/mne/datasets/eegbci/eegbci.py +++ b/mne/datasets/eegbci/eegbci.py @@ -7,9 +7,11 @@ import re from os import path as op from pathlib import Path +import time -from ...utils import _url_to_local_path, verbose -from ..utils import _do_path_update, _get_path +from ...utils import _url_to_local_path, verbose, logger +from ..utils import (_do_path_update, _get_path, _log_time_size, + _downloader_params) # TODO: remove try/except when our min version is py 3.9 try: @@ -79,16 +81,17 @@ def data_path(url, path=None, force_update=False, update_path=None, *, destinations = [destination] # Fetch the file + downloader = pooch.HTTPDownloader(**_downloader_params()) if not op.isfile(destination) or force_update: if op.isfile(destination): os.remove(destination) if not op.isdir(op.dirname(destination)): os.makedirs(op.dirname(destination)) pooch.retrieve( - # URL to one of Pooch's test files url=url, path=destination, - fname=fname + 
downloader=downloader, + fname=fname, ) # Offer to update the path @@ -162,6 +165,7 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, .. footbibliography:: """ # noqa: E501 import pooch + t0 = time.time() if not hasattr(runs, '__iter__'): runs = [runs] @@ -195,14 +199,23 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, # fetch the file(s) data_paths = [] + sz = 0 for run in runs: file_part = f'S{subject:03d}/S{subject:03d}R{run:02d}.edf' - destination = op.join(base_path, file_part) - if force_update and op.isfile(destination): - os.remove(destination) + destination = Path(base_path, file_part) + if destination.exists(): + if force_update: + destination.unlink() + else: + continue + if sz == 0: # log once + logger.info('Downloading EEGBCI data') data_paths.append(fetcher.fetch(file_part)) # update path in config if desired _do_path_update(path, update_path, config_key, name) + sz += destination.stat().st_size + if sz > 0: + _log_time_size(t0, sz) return data_paths diff --git a/mne/datasets/limo/limo.py b/mne/datasets/limo/limo.py index 143a9dd1162..e0f1d0f9fa9 100644 --- a/mne/datasets/limo/limo.py +++ b/mne/datasets/limo/limo.py @@ -2,16 +2,18 @@ # # License: BSD-3-Clause -import os import os.path as op +from pathlib import Path +import time import numpy as np from ...channels import make_standard_montage from ...epochs import EpochsArray from ...io.meas_info import create_info -from ...utils import _check_pandas_installed, verbose -from ..utils import _get_path, _do_path_update, logger +from ...utils import _check_pandas_installed, verbose, logger +from ..utils import (_get_path, _do_path_update, _log_time_size, + _downloader_params) # root url for LIMO files @@ -67,8 +69,9 @@ def data_path(subject, path=None, force_update=False, update_path=None, *, .. footbibliography:: """ # noqa: E501 import pooch + t0 = time.time() - downloader = pooch.HTTPDownloader(progressbar=True) # use tqdm + downloader = pooch.HTTPDownloader(**_downloader_params()) # local storage patch config_key = 'MNE_DATASETS_LIMO_PATH' @@ -168,14 +171,23 @@ def data_path(subject, path=None, force_update=False, update_path=None, *, # use our logger level for pooch's logger too pooch.get_logger().setLevel(logger.getEffectiveLevel()) # fetch the data + sz = 0 for fname in ('LIMO.mat', 'Yr.mat'): - destination = op.join(subject_path, fname) - if force_update and op.isfile(destination): - os.remove(destination) + destination = Path(subject_path, fname) + if destination.exists(): + if force_update: + destination.unlink() + else: + continue + if sz == 0: # log once + logger.info('Downloading LIMO data') # fetch the remote file (if local file missing or has hash mismatch) fetcher.fetch(fname=fname, downloader=downloader) + sz += destination.stat().st_size # update path in config if desired _do_path_update(path, update_path, config_key, name) + if sz > 0: + _log_time_size(t0, sz) return base_path @@ -282,7 +294,8 @@ def load_data(subject, path=None, force_update=False, update_path=None, metadata = pd.DataFrame(metadata) # -- 6) Create custom epochs array - epochs = EpochsArray(data, info, events, tmin, event_id, metadata=metadata) + epochs = EpochsArray(data, info, events, tmin, event_id, metadata=metadata, + verbose=False) epochs.info['bads'] = missing_chans # missing channels are marked as bad. 
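    # Editor's aside (hypothetical downstream use; the condition labels are
    # assumed from the LIMO Face A/B design, not guaranteed here):
    #
    #     epochs = mne.datasets.limo.load_data(subject=1)
    #     evoked_a = epochs['Face/A'].average()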
return epochs diff --git a/mne/datasets/sleep_physionet/_utils.py b/mne/datasets/sleep_physionet/_utils.py index 0c2c0632857..50f992e7803 100644 --- a/mne/datasets/sleep_physionet/_utils.py +++ b/mne/datasets/sleep_physionet/_utils.py @@ -10,7 +10,7 @@ from ...utils import (verbose, _TempDir, _check_pandas_installed, _on_missing) -from ..utils import _get_path +from ..utils import _get_path, _downloader_params AGE_SLEEP_RECORDS = op.join(op.dirname(__file__), 'age_records.csv') TEMAZEPAM_SLEEP_RECORDS = op.join(op.dirname(__file__), @@ -30,18 +30,21 @@ def _fetch_one(fname, hashsum, path, force_update, base_url): # Fetch the file url = base_url + '/' + fname destination = op.join(path, fname) - if not op.isfile(destination) or force_update: - if op.isfile(destination): - os.remove(destination) - if not op.isdir(op.dirname(destination)): - os.makedirs(op.dirname(destination)) - pooch.retrieve( - url=url, - known_hash=f"sha1:{hashsum}", - path=path, - fname=fname - ) - return destination + if op.isfile(destination) and not force_update: + return destination, False + if op.isfile(destination): + os.remove(destination) + if not op.isdir(op.dirname(destination)): + os.makedirs(op.dirname(destination)) + downloader = pooch.HTTPDownloader(**_downloader_params()) + pooch.retrieve( + url=url, + known_hash=f"sha1:{hashsum}", + path=path, + downloader=downloader, + fname=fname + ) + return destination, True @verbose @@ -87,11 +90,13 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): # Download subjects info. subjects_fname = op.join(tmp, 'ST-subjects.xls') + downloader = pooch.HTTPDownloader(**_downloader_params()) pooch.retrieve( url=TEMAZEPAM_RECORDS_URL, known_hash=f"sha1:{TEMAZEPAM_RECORDS_URL_SHA1}", path=tmp, - fname=op.basename(subjects_fname) + downloader=downloader, + fname=op.basename(subjects_fname), ) # Load and Massage the checksums. @@ -146,11 +151,13 @@ def _update_sleep_age_records(fname=AGE_SLEEP_RECORDS): # Download subjects info. subjects_fname = op.join(tmp, 'SC-subjects.xls') + downloader = pooch.HTTPDownloader(**_downloader_params()) pooch.retrieve( url=AGE_RECORDS_URL, known_hash=f"sha1:{AGE_RECORDS_URL_SHA1}", path=tmp, - fname=op.basename(subjects_fname) + downloader=downloader, + fname=op.basename(subjects_fname), ) # Load and Massage the checksums. diff --git a/mne/datasets/sleep_physionet/age.py b/mne/datasets/sleep_physionet/age.py index 4a0d8456639..106d39d4e32 100644 --- a/mne/datasets/sleep_physionet/age.py +++ b/mne/datasets/sleep_physionet/age.py @@ -3,9 +3,13 @@ # # License: BSD Style. +import os +import time + import numpy as np from ...utils import verbose +from ..utils import _log_time_size from ._utils import _fetch_one, _data_path, _on_missing, AGE_SLEEP_RECORDS from ._utils import _check_subjects @@ -79,6 +83,7 @@ def fetch_data(subjects, recording=(1, 2), path=None, force_update=False, ---------- .. 
footbibliography:: """ # noqa: E501 + t0 = time.time() records = np.loadtxt(AGE_SLEEP_RECORDS, skiprows=1, delimiter=',', @@ -107,15 +112,23 @@ def fetch_data(subjects, recording=(1, 2), path=None, force_update=False, _on_missing(on_missing, msg) fnames = [] + sz = 0 for subject in subjects: for idx in np.where(psg_records['subject'] == subject)[0]: if psg_records['record'][idx] in recording: - psg_fname = _fetch_one(psg_records['fname'][idx].decode(), - psg_records['sha'][idx].decode(), - *params) - hyp_fname = _fetch_one(hyp_records['fname'][idx].decode(), - hyp_records['sha'][idx].decode(), - *params) + psg_fname, pdl = _fetch_one( + psg_records['fname'][idx].decode(), + psg_records['sha'][idx].decode(), + *params) + hyp_fname, hdl = _fetch_one( + hyp_records['fname'][idx].decode(), + hyp_records['sha'][idx].decode(), + *params) fnames.append([psg_fname, hyp_fname]) - + if pdl: + sz += os.path.getsize(psg_fname) + if hdl: + sz += os.path.getsize(hyp_fname) + if sz > 0: + _log_time_size(t0, sz) return fnames diff --git a/mne/datasets/sleep_physionet/temazepam.py b/mne/datasets/sleep_physionet/temazepam.py index a18f126ab5f..841dbe67a7f 100644 --- a/mne/datasets/sleep_physionet/temazepam.py +++ b/mne/datasets/sleep_physionet/temazepam.py @@ -3,9 +3,13 @@ # # License: BSD Style. +import os +import time + import numpy as np from ...utils import verbose +from ..utils import _log_time_size from ._utils import _fetch_one, _data_path, TEMAZEPAM_SLEEP_RECORDS from ._utils import _check_subjects @@ -67,6 +71,7 @@ def fetch_data(subjects, path=None, force_update=False, base_url=BASE_URL, *, ---------- .. footbibliography:: """ + t0 = time.time() records = np.loadtxt(TEMAZEPAM_SLEEP_RECORDS, skiprows=1, delimiter=',', @@ -83,15 +88,23 @@ def fetch_data(subjects, path=None, force_update=False, base_url=BASE_URL, *, params = [path, force_update, base_url] fnames = [] + sz = 0 for subject in subjects: # all the subjects are present at this point for idx in np.where(records['subject'] == subject)[0]: if records['record'][idx] == b'Placebo': - psg_fname = _fetch_one(records['psg fname'][idx].decode(), - records['psg sha'][idx].decode(), - *params) - hyp_fname = _fetch_one(records['hyp fname'][idx].decode(), - records['hyp sha'][idx].decode(), - *params) + psg_fname, pdl = _fetch_one( + records['psg fname'][idx].decode(), + records['psg sha'][idx].decode(), + *params) + hyp_fname, hdl = _fetch_one( + records['hyp fname'][idx].decode(), + records['hyp sha'][idx].decode(), + *params) fnames.append([psg_fname, hyp_fname]) - + if pdl: + sz += os.path.getsize(psg_fname) + if hdl: + sz += os.path.getsize(hyp_fname) + if sz > 0: + _log_time_size(t0, sz) return fnames diff --git a/mne/datasets/sleep_physionet/tests/test_physionet.py b/mne/datasets/sleep_physionet/tests/test_physionet.py index 3f754b863ac..549963cb73f 100644 --- a/mne/datasets/sleep_physionet/tests/test_physionet.py +++ b/mne/datasets/sleep_physionet/tests/test_physionet.py @@ -3,11 +3,9 @@ # # License: BSD Style. 
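(Editor's aside: the public entry point these tests exercise is, schematically:

    from mne.datasets.sleep_physionet import age

    paths = age.fetch_data(subjects=[0], recording=[1])
    # e.g. [[<data_path>/SC4001E0-PSG.edf, <data_path>/SC4001EC-Hypnogram.edf]]

with the file names taken from the test expectations below; <data_path> is the
configured dataset directory.)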
-import os.path as op -import numpy as np +from pathlib import Path import pytest -from numpy.testing import assert_array_equal import pooch from mne.utils import requires_good_network @@ -32,14 +30,15 @@ def __init__(self): def __call__(self, *args, **kwargs): self.call_args_list.append((args, kwargs)) + Path(kwargs['path'], kwargs['fname']).write_text('test') @property def call_count(self): return len(self.call_args_list) -def _keep_basename_only(path_structure): - return np.vectorize(op.basename)(np.array(path_structure)) +def _keep_basename_only(paths): + return [Path(p).name for p in paths] def _get_expected_url(http://github.com/name): @@ -49,7 +48,7 @@ def _get_expected_url(http://github.com/name): def _get_expected_path(base, name): - return op.join(base, name) + return Path(base, name) def _check_mocked_function_calls(mocked_func, call_fname_hash_pairs, @@ -62,8 +61,8 @@ def _check_mocked_function_calls(mocked_func, call_fname_hash_pairs, for idx, current in enumerate(call_fname_hash_pairs): _, call_kwargs = mocked_func.call_args_list[idx] hash_type, hash = call_kwargs['known_hash'].split(':') - assert call_kwargs['url'] == _get_expected_url(http://github.com/current['name']) - assert op.join(call_kwargs['path'], call_kwargs['fname']) == \ + assert call_kwargs['url'] == _get_expected_url(http://github.com/current['name']), idx + assert Path(call_kwargs['path'], call_kwargs['fname']) == \ _get_expected_path(base_path, current['name']) assert hash == current['hash'] assert hash_type == 'sha1' @@ -130,26 +129,24 @@ def test_sleep_physionet_age(physionet_tmpdir, monkeypatch, download_is_error): monkeypatch.setattr(pooch, 'retrieve', my_func) paths = age.fetch_data(subjects=[0], recording=[1], path=physionet_tmpdir) - assert_array_equal(_keep_basename_only(paths), - [['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf']]) + assert _keep_basename_only(paths[0]) == \ + ['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf'] paths = age.fetch_data(subjects=[0, 1], recording=[1], path=physionet_tmpdir) - assert_array_equal(_keep_basename_only(paths), - [['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf'], - ['SC4011E0-PSG.edf', 'SC4011EH-Hypnogram.edf']]) + assert _keep_basename_only(paths[0]) == \ + ['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf'] + assert _keep_basename_only(paths[1]) == \ + ['SC4011E0-PSG.edf', 'SC4011EH-Hypnogram.edf'] paths = age.fetch_data(subjects=[0], recording=[1, 2], path=physionet_tmpdir) - assert_array_equal(_keep_basename_only(paths), - [['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf'], - ['SC4002E0-PSG.edf', 'SC4002EC-Hypnogram.edf']]) + assert _keep_basename_only(paths[0]) == \ + ['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf'] + assert _keep_basename_only(paths[1]) == \ + ['SC4002E0-PSG.edf', 'SC4002EC-Hypnogram.edf'] EXPECTED_CALLS = ( - {'name': 'SC4001E0-PSG.edf', - 'hash': 'adabd3b01fc7bb75c523a974f38ee3ae4e57b40f'}, - {'name': 'SC4001EC-Hypnogram.edf', - 'hash': '21c998eadc8b1e3ea6727d3585186b8f76e7e70b'}, {'name': 'SC4001E0-PSG.edf', 'hash': 'adabd3b01fc7bb75c523a974f38ee3ae4e57b40f'}, {'name': 'SC4001EC-Hypnogram.edf', @@ -158,14 +155,11 @@ def test_sleep_physionet_age(physionet_tmpdir, monkeypatch, download_is_error): 'hash': '4d17451f7847355bcab17584de05e7e1df58c660'}, {'name': 'SC4011EH-Hypnogram.edf', 'hash': 'd582a3cbe2db481a362af890bc5a2f5ca7c878dc'}, - {'name': 'SC4001E0-PSG.edf', - 'hash': 'adabd3b01fc7bb75c523a974f38ee3ae4e57b40f'}, - {'name': 'SC4001EC-Hypnogram.edf', - 'hash': '21c998eadc8b1e3ea6727d3585186b8f76e7e70b'}, {'name': 
'SC4002E0-PSG.edf', 'hash': 'c6b6d7a8605cc7e7602b6028ee77f6fbf5f7581d'}, {'name': 'SC4002EC-Hypnogram.edf', - 'hash': '386230188a3552b1fc90bba0fb7476ceaca174b6'}) + 'hash': '386230188a3552b1fc90bba0fb7476ceaca174b6'}, + ) base_path = age.data_path(path=physionet_tmpdir) _check_mocked_function_calls(my_func, EXPECTED_CALLS, base_path) @@ -192,8 +186,8 @@ def test_sleep_physionet_temazepam(physionet_tmpdir, monkeypatch): monkeypatch.setattr(pooch, 'retrieve', my_func) paths = temazepam.fetch_data(subjects=[0], path=physionet_tmpdir) - assert_array_equal(_keep_basename_only(paths), - [['ST7011J0-PSG.edf', 'ST7011JP-Hypnogram.edf']]) + assert _keep_basename_only(paths[0]) == \ + ['ST7011J0-PSG.edf', 'ST7011JP-Hypnogram.edf'] EXPECTED_CALLS = ( {'name': 'ST7011J0-PSG.edf', diff --git a/mne/datasets/tests/test_datasets.py b/mne/datasets/tests/test_datasets.py index 8709b934326..46c1ecd229f 100644 --- a/mne/datasets/tests/test_datasets.py +++ b/mne/datasets/tests/test_datasets.py @@ -189,7 +189,7 @@ def test_fetch_parcellations(tmp_path): _zip_fnames = ['foo/foo.txt', 'foo/bar.txt', 'foo/baz.txt'] -def _fake_zip_fetch(url, path, fname, known_hash): +def _fake_zip_fetch(url, path, fname, *args, **kwargs): fname = op.join(path, fname) with zipfile.ZipFile(fname, 'w') as zipf: with zipf.open('foo/', 'w'): diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 50a894bfd7b..1fba832abb0 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -11,12 +11,14 @@ from collections import OrderedDict import importlib import inspect +import logging import os import os.path as op from pathlib import Path import sys -import zipfile +import time import tempfile +import zipfile import numpy as np @@ -299,49 +301,52 @@ def _download_all_example_data(verbose=True): # # verbose=True by default so we get nice status messages. # Consider adding datasets from here to CircleCI for PR-auto-build - from . 
import (sample, testing, misc, spm_face, somato, brainstorm, - eegbci, multimodal, opm, hf_sef, mtrf, fieldtrip_cmc, - kiloword, phantom_4dbti, sleep_physionet, limo, - fnirs_motor, refmeg_noise, fetch_infant_template, - fetch_fsaverage, ssvep, erp_core, epilepsy_ecog, - fetch_phantom, eyelink, ucl_opm_auditory) - sample_path = sample.data_path() - testing.data_path() - misc.data_path() - spm_face.data_path() - somato.data_path() - hf_sef.data_path() - multimodal.data_path() - fnirs_motor.data_path() - opm.data_path() - mtrf.data_path() - fieldtrip_cmc.data_path() - kiloword.data_path() - phantom_4dbti.data_path() - refmeg_noise.data_path() - ssvep.data_path() - epilepsy_ecog.data_path() - ucl_opm_auditory.data_path() - brainstorm.bst_raw.data_path(accept=True) - brainstorm.bst_auditory.data_path(accept=True) - brainstorm.bst_resting.data_path(accept=True) - phantom_path = brainstorm.bst_phantom_elekta.data_path(accept=True) - fetch_phantom('otaniemi', subjects_dir=phantom_path) - eyelink.data_path() - brainstorm.bst_phantom_ctf.data_path(accept=True) + paths = dict() + for kind in ('sample testing misc spm_face somato hf_sef multimodal ' + 'fnirs_motor opm mtrf fieldtrip_cmc kiloword phantom_4dbti ' + 'refmeg_noise ssvep epilepsy_ecog ucl_opm_auditory eyelink ' + 'erp_core brainstorm.bst_raw brainstorm.bst_auditory ' + 'brainstorm.bst_resting brainstorm.bst_phantom_ctf ' + 'brainstorm.bst_phantom_elekta' + ).split(): + mod = importlib.import_module(f'mne.datasets.{kind}') + data_path_func = getattr(mod, 'data_path') + kwargs = dict() + if 'accept' in inspect.getfullargspec(data_path_func).args: + kwargs['accept'] = True + paths[kind] = data_path_func(**kwargs) + logger.info(f'[done {kind}]') + + # Now for the exceptions: + from . import ( + eegbci, sleep_physionet, limo, fetch_fsaverage, fetch_infant_template, + fetch_hcp_mmp_parcellation, fetch_phantom) eegbci.load_data(1, [6, 10, 14], update_path=True) for subj in range(4): eegbci.load_data(subj + 1, runs=[3], update_path=True) + logger.info('[done eegbci]') + sleep_physionet.age.fetch_data(subjects=[0, 1], recording=[1]) + logger.info('[done sleep_physionet]') + # If the user has SUBJECTS_DIR, respect it, if not, set it to the EEG one # (probably on CircleCI, or otherwise advanced user) fetch_fsaverage(None) + logger.info('[done fsaverage]') + fetch_infant_template('6mo') + logger.info('[done infant_template]') + fetch_hcp_mmp_parcellation( - subjects_dir=sample_path / 'subjects', accept=True) - limo.load_data(subject=1, update_path=True) + subjects_dir=paths['sample'] / 'subjects', accept=True) + logger.info('[done hcp_mmp_parcellation]') + + fetch_phantom( + 'otaniemi', subjects_dir=paths['brainstorm.bst_phantom_elekta']) + logger.info('[done phantom]') - erp_core.data_path() + limo.load_data(subject=1, update_path=True) + logger.info('[done limo]') @verbose @@ -371,6 +376,7 @@ def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): rh='https://osf.io/4kxny/download') hashes = dict(lh='9e4d8d6b90242b7e4b0145353436ef77', rh='dd6464db8e7762d969fc1d8087cd211b') + downloader = pooch.HTTPDownloader(**_downloader_params()) for hemi in ('lh', 'rh'): fname = f'{hemi}.aparc_sub.annot' fpath = destination / fname @@ -379,6 +385,7 @@ def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): url=urls[hemi], known_hash=f"md5:{hashes[hemi]}", path=destination, + downloader=downloader, fname=fname, ) @@ -430,6 +437,7 @@ def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, if answer.lower() != 'y': raise 
RuntimeError('You must agree to the license to use this ' 'dataset') + downloader = pooch.HTTPDownloader(**_downloader_params()) for hemi, fpath in zip(('lh', 'rh'), fnames): if not op.isfile(fpath): fname = fpath.name @@ -437,6 +445,7 @@ def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, url=urls[hemi], known_hash=f"md5:{hashes[hemi]}", path=destination, + downloader=downloader, fname=fname, ) @@ -553,6 +562,7 @@ def _manifest_check_download(manifest_path, destination, url, hash_): logger.info('%d file%s missing from %s in %s' % (len(need), _pl(need), manifest_path, destination)) if len(need) > 0: + downloader = pooch.HTTPDownloader(**_downloader_params()) with tempfile.TemporaryDirectory() as path: logger.info('Downloading missing files remotely') @@ -561,7 +571,8 @@ def _manifest_check_download(manifest_path, destination, url, hash_): url=url, known_hash=f"md5:{hash_}", path=path, - fname=op.basename(fname_path) + downloader=downloader, + fname=op.basename(fname_path), ) logger.info('Extracting missing file%s' % (_pl(need),)) @@ -575,3 +586,28 @@ def _manifest_check_download(manifest_path, destination, url, hash_): ff.extract(name, path=destination) logger.info('Successfully extracted %d file%s' % (len(need), _pl(need))) + + +def _log_time_size(t0, sz): + t = time.time() - t0 + fmt = '%Ss' + if t > 60: + fmt = f'%Mm{fmt}' + if t > 3600: + fmt = f'%Hh{fmt}' + sz = sz / 1048576 # 1024 ** 2 + t = time.strftime(fmt, time.gmtime(t)) + logger.info(f'Download complete in {t} ({sz:.1f} MB)') + + +def _downloader_params(*, auth=None, token=None): + params = dict() + params['progressbar'] = ( + logger.level <= logging.INFO and + get_config('MNE_TQDM', 'tqdm.auto') != 'off' + ) + if auth is not None: + params["auth"] = auth + if token is not None: + params["headers"] = {"Authorization": f"token {token}"} + return params diff --git a/mne/utils/config.py b/mne/utils/config.py index 5056fcfd18a..09a89fe9a0f 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -104,6 +104,7 @@ def set_memmap_min_size(memmap_min_size): 'MNE_DATASETS_BRAINSTORM_PATH', 'MNE_DATASETS_EEGBCI_PATH', 'MNE_DATASETS_EPILEPSY_ECOG_PATH', + 'MNE_DATASETS_EYELINK_PATH', 'MNE_DATASETS_HF_SEF_PATH', 'MNE_DATASETS_MEGSIM_PATH', 'MNE_DATASETS_MISC_PATH', diff --git a/tools/circleci_download.sh b/tools/circleci_download.sh index 421f6f63ec1..cb622cb1860 100755 --- a/tools/circleci_download.sh +++ b/tools/circleci_download.sh @@ -1,6 +1,7 @@ #!/bin/bash -e set -o pipefail +export MNE_TQDM=off if [ "$CIRCLE_BRANCH" == "main" ] || [[ $(cat gitlog.txt) == *"[circle full]"* ]]; then echo "Doing a full dev build"; From 4eb9ef04c50710dba840ceaac6dc4f69ffefdd31 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 19 Apr 2023 15:03:56 -0400 Subject: [PATCH 0024/1125] BUG: Fix bug with paths (#11639) --- mne/conftest.py | 41 +++++++++++++++++++ mne/datasets/eegbci/eegbci.py | 5 ++- mne/datasets/eegbci/tests/test_eegbci.py | 14 +++++++ .../sleep_physionet/tests/test_physionet.py | 33 ++------------- .../40_artifact_correction_ica.py | 1 + 5 files changed, 63 insertions(+), 31 deletions(-) create mode 100644 mne/datasets/eegbci/tests/test_eegbci.py diff --git a/mne/conftest.py b/mne/conftest.py index 05d433cacfc..a4b261db704 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -713,6 +713,47 @@ def download_is_error(monkeypatch): """Prevent downloading by raising an error when it's attempted.""" import pooch monkeypatch.setattr(pooch, 'retrieve', _fail) + yield + + +@pytest.fixture() +def fake_retrieve(monkeypatch, 
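    # Editor's aside: any test requesting this fixture now fails loudly if it
    # reaches the network, e.g. (hypothetical test, not part of this file):
    #
    #     def test_no_download(download_is_error, tmp_path):
    #         with pytest.raises(AssertionError, match='should not download'):
    #             mne.datasets.eegbci.load_data(1, [3], path=tmp_path)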
download_is_error): + """Monkeypatch pooch.retrieve to avoid downloading (just touch files).""" + import pooch + my_func = _FakeFetch() + monkeypatch.setattr(pooch, 'retrieve', my_func) + monkeypatch.setattr(pooch, 'create', my_func) + yield my_func + + +class _FakeFetch: + + def __init__(self): + self.call_args_list = list() + + @property + def call_count(self): + return len(self.call_args_list) + + # Wrapper for pooch.retrieve(...) and pooch.create(...) + def __call__(self, *args, **kwargs): + assert 'path' in kwargs + if 'fname' in kwargs: # pooch.retrieve(...) + self.call_args_list.append((args, kwargs)) + path = Path(kwargs['path'], kwargs['fname']) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text('test') + return path + else: # pooch.create(...) has been called + self.path = kwargs['path'] + return self + + # Wrappers for Pooch instances (e.g., in eegbci we pooch.create) + def fetch(self, fname): + self(path=self.path, fname=fname) + + def load_registry(self, registry): + assert Path(registry).exists(), registry # We can't use monkeypatch because its scope (function-level) conflicts with diff --git a/mne/datasets/eegbci/eegbci.py b/mne/datasets/eegbci/eegbci.py index e89ae089fcc..fd2b0a71e24 100644 --- a/mne/datasets/eegbci/eegbci.py +++ b/mne/datasets/eegbci/eegbci.py @@ -203,6 +203,7 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, for run in runs: file_part = f'S{subject:03d}/S{subject:03d}R{run:02d}.edf' destination = Path(base_path, file_part) + data_paths.append(destination) if destination.exists(): if force_update: destination.unlink() @@ -210,10 +211,10 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, continue if sz == 0: # log once logger.info('Downloading EEGBCI data') - data_paths.append(fetcher.fetch(file_part)) + fetcher.fetch(file_part) # update path in config if desired - _do_path_update(path, update_path, config_key, name) sz += destination.stat().st_size + _do_path_update(path, update_path, config_key, name) if sz > 0: _log_time_size(t0, sz) return data_paths diff --git a/mne/datasets/eegbci/tests/test_eegbci.py b/mne/datasets/eegbci/tests/test_eegbci.py new file mode 100644 index 00000000000..e60988ff36c --- /dev/null +++ b/mne/datasets/eegbci/tests/test_eegbci.py @@ -0,0 +1,14 @@ +# Authors: Eric Larson +# +# License: BSD Style. 
+
+from mne.datasets import eegbci
+
+
+def test_eegbci_download(tmp_path, fake_retrieve):
+    """Test EEGBCI URL handling."""
+    for subj in range(4):
+        fnames = eegbci.load_data(
+            subj + 1, runs=[3], path=tmp_path, update_path=False)
+        assert len(fnames) == 1, subj
+    assert fake_retrieve.call_count == 4
diff --git a/mne/datasets/sleep_physionet/tests/test_physionet.py index 549963cb73f..ad400505d73 100644
--- a/mne/datasets/sleep_physionet/tests/test_physionet.py
+++ b/mne/datasets/sleep_physionet/tests/test_physionet.py
@@ -6,7 +6,6 @@
 from pathlib import Path
 
 import pytest
-import pooch
 
 from mne.utils import requires_good_network
 from mne.utils import requires_pandas, requires_version
@@ -23,20 +22,6 @@ def physionet_tmpdir(tmp_path_factory):
     return str(tmp_path_factory.mktemp('physionet_files'))
 
 
-class _FakeFetch:
-
-    def __init__(self):
-        self.call_args_list = list()
-
-    def __call__(self, *args, **kwargs):
-        self.call_args_list.append((args, kwargs))
-        Path(kwargs['path'], kwargs['fname']).write_text('test')
-
-    @property
-    def call_count(self):
-        return len(self.call_args_list)
-
-
 def _keep_basename_only(paths):
     return [Path(p).name for p in paths]
@@ -119,15 +104,8 @@ def test_sleep_physionet_age_missing_recordings(physionet_tmpdir, subject,
     assert paths == []
 
 
-def test_sleep_physionet_age(physionet_tmpdir, monkeypatch, download_is_error):
+def test_sleep_physionet_age(physionet_tmpdir, fake_retrieve):
     """Test Sleep Physionet URL handling."""
-    # check download_is_error patching
-    with pytest.raises(AssertionError, match='Test should not download'):
-        age.fetch_data(subjects=[0], recording=[1], path=physionet_tmpdir)
-    # then patch
-    my_func = _FakeFetch()
-    monkeypatch.setattr(pooch, 'retrieve', my_func)
-
     paths = age.fetch_data(subjects=[0], recording=[1], path=physionet_tmpdir)
     assert _keep_basename_only(paths[0]) == \
         ['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf']
@@ -161,7 +139,7 @@ def test_sleep_physionet_age(physionet_tmpdir, monkeypatch, download_is_error):
         'hash': '386230188a3552b1fc90bba0fb7476ceaca174b6'},
     )
     base_path = age.data_path(path=physionet_tmpdir)
-    _check_mocked_function_calls(my_func, EXPECTED_CALLS, base_path)
+    _check_mocked_function_calls(fake_retrieve, EXPECTED_CALLS, base_path)
 
 
 @pytest.mark.xfail(strict=False)
@@ -180,11 +158,8 @@ def test_run_update_temazepam_records(tmp_path):
         data, pd.read_csv(TEMAZEPAM_SLEEP_RECORDS))
 
 
-def test_sleep_physionet_temazepam(physionet_tmpdir, monkeypatch):
+def test_sleep_physionet_temazepam(physionet_tmpdir, fake_retrieve):
     """Test Sleep Physionet URL handling."""
-    my_func = _FakeFetch()
-    monkeypatch.setattr(pooch, 'retrieve', my_func)
-
     paths = temazepam.fetch_data(subjects=[0], path=physionet_tmpdir)
     assert _keep_basename_only(paths[0]) == \
         ['ST7011J0-PSG.edf', 'ST7011JP-Hypnogram.edf']
@@ -195,7 +170,7 @@ def test_sleep_physionet_temazepam(physionet_tmpdir, monkeypatch):
         {'name': 'ST7011J0-PSG.edf',
          'hash': 'ff28e5e01296cefed49ae0c27cfb3ebc42e710bf'},
         {'name': 'ST7011JP-Hypnogram.edf',
          'hash': 'ff28e5e01296cefed49ae0c27cfb3ebc42e710bf'})
     base_path = temazepam.data_path(path=physionet_tmpdir)
-    _check_mocked_function_calls(my_func, EXPECTED_CALLS, base_path)
+    _check_mocked_function_calls(fake_retrieve, EXPECTED_CALLS, base_path)
 
     with pytest.raises(
             ValueError, match='This dataset contains subjects 0 to 21'):
diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py index 51e353dcae7..b4aae956300 100644
--- a/tutorials/preprocessing/40_artifact_correction_ica.py
+++ 
b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -29,6 +29,7 @@ sample_data_raw_file = os.path.join(sample_data_folder, 'MEG', 'sample', 'sample_audvis_filt-0-40_raw.fif') raw = mne.io.read_raw_fif(sample_data_raw_file) + # Here we'll crop to 60 seconds and drop gradiometer channels for speed raw.crop(tmax=60.).pick_types(meg='mag', eeg=True, stim=True, eog=True) raw.load_data() From 8325f8082afd1d0d836585e77e01fcf600fd348c Mon Sep 17 00:00:00 2001 From: Clemens Brunner Date: Sat, 22 Apr 2023 02:00:37 +0200 Subject: [PATCH 0025/1125] Add pre-commit (#11541) Co-authored-by: Eric Larson Co-authored-by: Daniel McCloy --- .github/workflows/codespell_and_flake.yml | 45 --------- .github/workflows/precommit.yml | 14 +++ .gitignore | 2 +- .pre-commit-config.yaml | 36 +++++++ MANIFEST.in | 1 + Makefile | 95 ++----------------- azure-pipelines.yml | 19 +--- doc/install/contributing.rst | 16 ++-- examples/simulation/simulate_evoked_data.py | 2 +- ...imulated_raw_data_using_subject_anatomy.py | 3 +- examples/visualization/topo_customized.py | 3 +- .../tests/test_resolution_matrix.py | 4 +- mne/chpi.py | 4 +- mne/io/base.py | 5 +- mne/utils/check.py | 1 + pyproject.toml | 50 ++++++++++ requirements_testing.txt | 5 +- setup.cfg | 42 -------- 18 files changed, 138 insertions(+), 209 deletions(-) delete mode 100644 .github/workflows/codespell_and_flake.yml create mode 100644 .github/workflows/precommit.yml create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml delete mode 100644 setup.cfg diff --git a/.github/workflows/codespell_and_flake.yml b/.github/workflows/codespell_and_flake.yml deleted file mode 100644 index e191caa25d1..00000000000 --- a/.github/workflows/codespell_and_flake.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: 'codespell_and_flake' -# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#concurrency -# https://docs.github.com/en/developers/webhooks-and-events/events/github-event-types#pullrequestevent -# workflow name, PR number (empty on push), push ref (empty on PR) -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - style: - name: 'codespell and flake' - runs-on: ubuntu-20.04 - env: - CODESPELL_DIRS: 'mne/ doc/ tutorials/ examples/' - CODESPELL_SKIPS: 'doc/_build,doc/auto_*,*.fif,*.eve,*.gz,*.tgz,*.zip,*.mat,*.stc,*.label,*.w,*.bz2,*.annot,*.sulc,*.log,*.local-copy,*.orig_avg,*.inflated_avg,*.gii,*.pyc,*.doctree,*.pickle,*.inv,*.png,*.edf,*.touch,*.thickness,*.nofix,*.volume,*.defect_borders,*.mgh,lh.*,rh.*,COR-*,FreeSurferColorLUT.txt,*.examples,.xdebug_mris_calc,bad.segments,BadChannels,*.hist,empty_file,*.orig,*.js,*.map,*.ipynb,searchindex.dat,install_mne_c.rst,plot_*.rst,*.rst.txt,c_EULA.rst*,*.html,gdf_encodes.txt,*.svg,references.bib,*.css,*.edf,*.bdf,*.vhdr' - - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.9' - architecture: 'x64' - - run: | - python -m pip install --upgrade pip setuptools wheel - python -m pip install flake8 - name: 'Install dependencies' - - uses: rbialon/flake8-annotations@v1 - name: 'Setup flake8 annotations' - - run: make flake - name: 'Run flake8' - - uses: codespell-project/actions-codespell@v1.0 - with: - path: ${{ env.CODESPELL_DIRS }} - skip: ${{ env.CODESPELL_SKIPS }} - builtin: 'clear,rare,informal,names' - ignore_words_file: 'ignore_words.txt' - uri_ignore_words_list: 'bu' - name: 'Run 
codespell' diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml new file mode 100644 index 00000000000..4638064b646 --- /dev/null +++ b/.github/workflows/precommit.yml @@ -0,0 +1,14 @@ +name: Pre-commit + +on: [push, pull_request] + +jobs: + style: + name: Pre-commit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + - uses: pre-commit/action@v3.0.0 diff --git a/.gitignore b/.gitignore index 40c64c7bb65..c73ee6d5257 100644 --- a/.gitignore +++ b/.gitignore @@ -92,5 +92,5 @@ cover venv/ *.json .hypothesis/ - +.ruff_cache/ .ipynb_checkpoints/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..4814c23d8eb --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,36 @@ +default_language_version: + python: python3.11 + +repos: +# - repo: https://github.com/psf/black +# rev: 23.1.0 +# hooks: +# - id: black +# args: [--quiet] + +# Ruff mne +- repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.262 + hooks: + - id: ruff + name: ruff mne + files: ^mne/ + +# Ruff tutorials and examples +- repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.262 + hooks: + - id: ruff + name: ruff tutorials and examples + # D103: missing docstring in public function + # D400: docstring first line must end with period + args: ["--ignore=D103,D400"] + files: ^tutorials/|^examples/ + +# Codespell +- repo: https://github.com/codespell-project/codespell + rev: v2.2.3 + hooks: + - id: codespell + files: ^mne/|^doc/|^examples/|^tutorials/ + types_or: [python, bib, rst, inc] diff --git a/MANIFEST.in b/MANIFEST.in index 6c1aa9ff47f..e00b86e3e79 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -57,6 +57,7 @@ exclude tools exclude Makefile exclude .coveragerc exclude *.yml +exclude *.yaml exclude ignore_words.txt exclude .mailmap exclude codemeta.json diff --git a/Makefile b/Makefile index a162617cd0a..c0e47ada7fb 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,6 @@ PYTHON ?= python PYTESTS ?= py.test -CTAGS ?= ctags CODESPELL_SKIPS ?= "doc/_build,doc/auto_*,*.fif,*.eve,*.gz,*.tgz,*.zip,*.mat,*.stc,*.label,*.w,*.bz2,*.annot,*.sulc,*.log,*.local-copy,*.orig_avg,*.inflated_avg,*.gii,*.pyc,*.doctree,*.pickle,*.inv,*.png,*.edf,*.touch,*.thickness,*.nofix,*.volume,*.defect_borders,*.mgh,lh.*,rh.*,COR-*,FreeSurferColorLUT.txt,*.examples,.xdebug_mris_calc,bad.segments,BadChannels,*.hist,empty_file,*.orig,*.js,*.map,*.ipynb,searchindex.dat,install_mne_c.rst,plot_*.rst,*.rst.txt,c_EULA.rst*,*.html,gdf_encodes.txt,*.svg,references.bib,*.css,*.edf,*.bdf,*.vhdr" CODESPELL_DIRS ?= mne/ doc/ tutorials/ examples/ all: clean inplace test test-doc @@ -25,13 +24,6 @@ clean-cache: clean: clean-build clean-pyc clean-so clean-ctags clean-cache -in: inplace # just a shortcut -inplace: - $(PYTHON) setup.py build_ext -i - -wheel: - $(PYTHON) setup.py sdist bdist_wheel - wheel_quiet: $(PYTHON) setup.py -q sdist bdist_wheel @@ -43,22 +35,6 @@ testing_data: pytest: test -test: in - rm -f .coverage - $(PYTESTS) -m 'not ultraslowtest' mne - -test-verbose: in - rm -f .coverage - $(PYTESTS) -m 'not ultraslowtest' mne --verbose - -test-fast: in - rm -f .coverage - $(PYTESTS) -m 'not slowtest' mne - -test-full: in - rm -f .coverage - $(PYTESTS) mne - test-no-network: in sudo unshare -n -- sh -c 'MNE_SKIP_NETWORK_TESTS=1 py.test mne' @@ -66,56 +42,20 @@ test-no-testing-data: in @MNE_SKIP_TESTING_DATASET_TESTS=true \ $(PYTESTS) mne -test-no-sample-with-coverage: in testing_data - rm 
-rf coverage .coverage - $(PYTESTS) --cov=mne --cov-report html:coverage - test-doc: sample_data testing_data $(PYTESTS) --doctest-modules --doctest-ignore-import-errors --doctest-glob='*.rst' ./doc/ --ignore=./doc/auto_examples --ignore=./doc/auto_tutorials --ignore=./doc/_build --ignore=./doc/conf.py --ignore=doc/sphinxext --fulltrace -test-coverage: testing_data - rm -rf coverage .coverage - $(PYTESTS) --cov=mne --cov-report html:coverage -# what's the difference with test-no-sample-with-coverage? - -test-mem: in testing_data - ulimit -v 1097152 && $(PYTESTS) mne - -trailing-spaces: - find . -name "*.py" | xargs perl -pi -e 's/[ \t]*$$//' +pre-commit: + @pre-commit run -a -ctags: - # make tags for symbol based navigation in emacs and vim - # Install with: sudo apt-get install exuberant-ctags - $(CTAGS) -R * - -upload-pipy: - python setup.py sdist bdist_egg register upload - -flake: - @if command -v flake8 > /dev/null; then \ - echo "Running flake8"; \ - flake8 --count; \ - else \ - echo "flake8 not found, please install it!"; \ - exit 1; \ - fi; - @echo "flake8 passed" +# Aliases for stuff we used to support or users might think of +ruff: pre-commit +flake: pre-commit +pep: pre-commit codespell: # running manually @codespell --builtin clear,rare,informal,names,usage -w -i 3 -q 3 -S $(CODESPELL_SKIPS) --ignore-words=ignore_words.txt --uri-ignore-words-list=bu $(CODESPELL_DIRS) -codespell-error: # running on travis - @codespell --builtin clear,rare,informal,names,usage -i 0 -q 7 -S $(CODESPELL_SKIPS) --ignore-words=ignore_words.txt --uri-ignore-words-list=bu $(CODESPELL_DIRS) - -pydocstyle: - @echo "Running pydocstyle" - @pydocstyle mne - -docstring: - @echo "Running docstring tests" - @$(PYTESTS) --doctest-modules mne/tests/test_docstring_parameters.py - check-manifest: check-manifest -q --ignore .circleci/config.yml,doc,logo,mne/io/*/tests/data*,mne/io/tests/data,mne/preprocessing/tests/data,.DS_Store,mne/_version.py @@ -125,26 +65,3 @@ check-readme: clean wheel_quiet nesting: @echo "Running import nesting tests" @$(PYTESTS) mne/tests/test_import_nesting.py - -pep: - @$(MAKE) -k flake pydocstyle docstring codespell-error check-manifest nesting check-readme - -manpages: - @echo "I: generating manpages" - set -e; mkdir -p _build/manpages && \ - cd bin && for f in mne*; do \ - descr=$$(grep -h -e "^ *'''" -e 'DESCRIP =' $$f -h | sed -e "s,.*' *\([^'][^']*\)'.*,\1,g" | head -n 1); \ - PYTHONPATH=../ \ - help2man -n "$$descr" --no-discard-stderr --no-info --version-string "$(uver)" ./$$f \ - >| ../_build/manpages/$$f.1; \ - done - -build-doc-dev: - cd doc; make clean - cd doc; DISPLAY=:1.0 xvfb-run -n 1 -s "-screen 0 1280x1024x24 -noreset -ac +extension GLX +render" make html_dev - -build-doc-stable: - cd doc; make clean - cd doc; DISPLAY=:1.0 xvfb-run -n 1 -s "-screen 0 1280x1024x24 -noreset -ac +extension GLX +render" make html_stable - -docstyle: pydocstyle diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 186c081f036..e27e056aa3f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -39,7 +39,7 @@ stages: pool: vmImage: 'ubuntu-latest' variables: - PYTHON_VERSION: '3.9' + PYTHON_VERSION: '3.11' PYTHON_ARCH: 'x64' steps: - bash: echo $(COMMIT_MSG) @@ -50,23 +50,14 @@ stages: addToPath: true displayName: 'Get Python' - bash: | - set -e + set -eo pipefail python -m pip install --progress-bar off --upgrade pip setuptools wheel python -m pip install --progress-bar off -r requirements_base.txt -r requirements_hdf5.txt -r requirements_testing.txt + pre-commit install 
--install-hooks displayName: Install dependencies - bash: | - make flake - displayName: make flake - - bash: | - make codespell-error - displayName: make codespell - - bash: | - make pydocstyle - displayName: make pydocstyle - condition: always() - - bash: | - make docstring - displayName: make docstring + make pre-commit + displayName: make ruff condition: always() - bash: | make nesting diff --git a/doc/install/contributing.rst b/doc/install/contributing.rst index d1419ef80f3..f7b278fe7d7 100644 --- a/doc/install/contributing.rst +++ b/doc/install/contributing.rst @@ -694,11 +694,16 @@ Adhere to standard Python style guidelines All contributions to MNE-Python are checked against style guidelines described in `PEP 8`_. We also check for common coding errors (such as variables that are defined but never used). We allow very few exceptions to these guidelines, and -use tools such as pep8_, pyflakes_, and flake8_ to check code style +use tools such as ruff_ to check code style automatically. From the :file:`mne-python` root directory, you can check for -style violations by running:: +style violations by first installing our pre-commit hook:: - $ make flake + $ pip install pre-commit + $ pre-commit install --install-hooks + +Then running:: + + $ make ruff # alias for `pre-commit run -a` in the shell. Several text editors or IDEs also have Python style checking, which can highlight style errors while you code (and train you to make those @@ -748,7 +753,7 @@ but complete docstrings are appropriate when private functions/methods are relatively complex. To run some basic tests on documentation, you can use:: $ pytest mne/tests/test_docstring_parameters.py - $ make docstyle + $ make ruff Cross-reference everywhere @@ -1097,8 +1102,7 @@ it can serve as a useful example of what to expect from the PR review process. .. linting .. _PEP 8: https://www.python.org/dev/peps/pep-0008/ -.. _pyflakes: https://pypi.org/project/pyflakes -.. _Flake8: http://flake8.pycqa.org/ +.. _ruff: https://beta.ruff.rs/docs .. misc diff --git a/examples/simulation/simulate_evoked_data.py b/examples/simulation/simulate_evoked_data.py index 0d4cff6a6c3..b906d2df265 100644 --- a/examples/simulation/simulate_evoked_data.py +++ b/examples/simulation/simulate_evoked_data.py @@ -55,7 +55,7 @@ def data_fun(times): - """Function to generate random source time courses""" + """Generate random source time courses.""" return (50e-9 * np.sin(30. * times) * np.exp(- (times - 0.15 + 0.05 * rng.randn(1)) ** 2 / 0.01)) diff --git a/examples/simulation/simulated_raw_data_using_subject_anatomy.py b/examples/simulation/simulated_raw_data_using_subject_anatomy.py index b78db66c965..0edb33e7d0f 100644 --- a/examples/simulation/simulated_raw_data_using_subject_anatomy.py +++ b/examples/simulation/simulated_raw_data_using_subject_anatomy.py @@ -124,8 +124,7 @@ def data_fun(times, latency, duration): - """Function to generate source time courses for evoked responses, - parametrized by latency and duration.""" + """Generate source time courses for evoked responses.""" f = 15 # oscillating frequency, beta band [Hz] sigma = 0.375 * duration sinusoid = np.sin(2 * np.pi * f * (times - latency)) diff --git a/examples/visualization/topo_customized.py b/examples/visualization/topo_customized.py index cc284431246..e9106a1e8d2 100644 --- a/examples/visualization/topo_customized.py +++ b/examples/visualization/topo_customized.py @@ -47,7 +47,8 @@ def my_callback(ax, ch_idx): - """ + """Handle axes callback. 
+ This block of code is executed once you click on one of the channel axes in the plot. To work with the viz internals, this function should only take two parameters, the axis and the channel or data index. diff --git a/mne/beamformer/tests/test_resolution_matrix.py b/mne/beamformer/tests/test_resolution_matrix.py index d033eaf6b67..6d6730e3b9e 100755 --- a/mne/beamformer/tests/test_resolution_matrix.py +++ b/mne/beamformer/tests/test_resolution_matrix.py @@ -85,9 +85,9 @@ def test_resolution_matrix_lcmv(): # Some rows are off by about 0.1 - not yet clear why corr = [] - for (f, l) in zip(resmat_fwd, resmat_lcmv): + for (f, lf) in zip(resmat_fwd, resmat_lcmv): - corr.append(np.corrcoef(f, l)[0, 1]) + corr.append(np.corrcoef(f, lf)[0, 1]) # all row correlations should at least be above ~0.8 assert_allclose(corr, 1., atol=0.2) diff --git a/mne/chpi.py b/mne/chpi.py index cdfc9b558ae..b57477deb29 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -1212,8 +1212,8 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', # check if data has sufficiently changed if last['sin_fit'] is not None: # first iteration corrs = np.array( - [np.corrcoef(s, l)[0, 1] - for s, l in zip(sin_fit, last['sin_fit'])]) + [np.corrcoef(s, lst)[0, 1] + for s, lst in zip(sin_fit, last['sin_fit'])]) corrs *= corrs # check to see if we need to continue if fit_time - last['coil_fit_time'] <= t_step_max - 1e-7 and \ diff --git a/mne/io/base.py b/mne/io/base.py index 96e4e0f2549..df198af4033 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -570,7 +570,10 @@ def time_as_index(self, times, use_rounding=False, origin=None): @property def _raw_lengths(self): - return [l - f + 1 for f, l in zip(self._first_samps, self._last_samps)] + return [ + last - first + 1 + for first, last in zip(self._first_samps, self._last_samps) + ] @property def annotations(self): # noqa: D401 diff --git a/mne/utils/check.py b/mne/utils/check.py index 6a66fb20edc..e351184680f 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -1122,6 +1122,7 @@ def _to_rgb(*args, name='color', alpha=False): @deprecated('has_nibabel is deprecated and will be removed in 1.5') def has_nibabel(): + """Check if nibabel is installed.""" return check_version('nibabel') # pragma: no cover diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000000..b8b664ab193 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[tool.codespell] +ignore-words = "ignore_words.txt" +uri-ignore-words-list = "bu" +builtin = "clear,rare,informal,names,usage" +skip = "doc/references.bib" + +[tool.ruff] +select = ["E", "F", "W", "D"] +exclude = ["__init__.py", "constants.py", "fixes.py", "resources.py"] +ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D413", # Missing blank line after last section +] + +[tool.ruff.pydocstyle] +convention = "numpy" +ignore-decorators = [ + "property", + "setter", + "mne.utils.copy_function_doc_to_method_doc", + "mne.utils.copy_doc", + "mne.utils.deprecated" +] + +[tool.ruff.per-file-ignores] +"tutorials/time-freq/10_spectrum_class.py" = [ + "E501" # line too long +] +"mne/datasets/*/*.py" = [ + "D103", # Missing docstring in public function +] +"mne/utils/tests/test_docs.py" = [ + "D101", # Missing docstring in public class + "D410", # Missing blank line after section + "D411", # Missing blank line before section + "D414", # Section has no content +] +"examples/*/*.py" = [ + "D205", # 1 blank line required between summary line and description +] + 
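+# The ruff settings above are consumed both by a local ``ruff mne/`` run
+# (assuming ruff is installed) and by the pre-commit hook configured in
+# .pre-commit-config.yaml; ``make ruff`` runs the full ``pre-commit run -a``.
+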
+[tool.pytest.ini_options] +addopts = """--durations=20 --doctest-modules -ra --cov-report= --tb=short \ + --doctest-ignore-import-errors --junit-xml=junit-results.xml \ + --ignore=doc --ignore=logo --ignore=examples --ignore=tutorials \ + --ignore=mne/gui/_*.py --ignore=mne/icons --ignore=tools \ + --ignore=mne/report/js_and_css \ + --color=yes --capture=sys""" +junit_family = "xunit2" diff --git a/requirements_testing.txt b/requirements_testing.txt index 3344f4b409e..c8ff7b5c5fb 100644 --- a/requirements_testing.txt +++ b/requirements_testing.txt @@ -3,11 +3,10 @@ pytest pytest-cov pytest-timeout pytest-harvest -flake8 -flake8-array-spacing +ruff numpydoc codespell -pydocstyle check-manifest twine wheel +pre-commit diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 396a8906cc0..00000000000 --- a/setup.cfg +++ /dev/null @@ -1,42 +0,0 @@ -[aliases] -release = egg_info -RDb '' -# Make sure the sphinx docs are built each time we do a dist. -# bdist = build_sphinx bdist -# sdist = build_sphinx sdist -# Make sure a zip file is created each time we build the sphinx docs -# build_sphinx = generate_help build_sphinx zip_help -# Make sure the docs are uploaded when we do an upload -# upload = upload upload_help - -[egg_info] -# tag_build = .dev - -[bdist_rpm] -doc_files = doc - -[flake8] -exclude = __init__.py,constants.py,fixes.py,resources.py,*doc/auto_*,*doc/_build*,build/* -ignore = W503,W504,I100,I101,I201,N806,E201,E202,E221,E222,E241 -# We add A for the array-spacing plugin, and ignore the E ones it covers above -select = A,E,F,W,C -# 10_spectrum_class.py has a wide rST table -per-file-ignores = - tutorials/time-freq/10_spectrum_class.py:E501 - -[tool:pytest] -addopts = - --durations=20 --doctest-modules -ra --cov-report= --tb=short - --doctest-ignore-import-errors --junit-xml=junit-results.xml - --ignore=doc --ignore=logo --ignore=examples --ignore=tutorials - --ignore=mne/gui/_*.py --ignore=mne/icons --ignore=tools - --ignore=mne/report/js_and_css - --color=yes --capture=sys -junit_family = xunit2 - -[pydocstyle] -convention = pep257 -match_dir = ^(?!\.|doc|tutorials|examples|logo|icons).*$ -match = (?!tests/__init__\.py|fixes).*\.py -add-ignore = D100,D104,D107,D413 -add-select = D214,D215,D404,D405,D406,D407,D408,D409,D410,D411 -ignore-decorators = ^(copy_.*_doc_to_|on_trait_change|cached_property|deprecated|property|.*setter).* From 295b7c7ee90ea341dab2cc497f7ec6d5eceb2770 Mon Sep 17 00:00:00 2001 From: George O'Neill Date: Sat, 22 Apr 2023 23:26:08 +0100 Subject: [PATCH 0026/1125] ENH: Add support for Harmonic Field correction (#11536) Co-authored-by: Eric Larson --- doc/changes/latest.inc | 1 + doc/preprocessing.rst | 1 + doc/references.bib | 23 +++ mne/datasets/config.py | 4 +- mne/forward/_make_forward.py | 3 +- mne/io/pick.py | 6 +- mne/io/tests/test_pick.py | 16 ++ mne/preprocessing/__init__.py | 1 + mne/preprocessing/hfc.py | 100 +++++++++++++ mne/preprocessing/maxwell.py | 6 +- mne/preprocessing/ssp.py | 6 +- mne/preprocessing/tests/test_hfc.py | 149 +++++++++++++++++++ mne/utils/docs.py | 5 + tutorials/preprocessing/80_opm_processing.py | 78 ++++++++-- 14 files changed, 375 insertions(+), 24 deletions(-) create mode 100644 mne/preprocessing/hfc.py create mode 100644 mne/preprocessing/tests/test_hfc.py diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 9dc28d470f4..12c192e2b07 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -27,6 +27,7 @@ Enhancements - Adjusted the algorithm used in :class:`mne.decoding.SSD` to support 
non-full rank data (:gh:`11458` by :newcontrib:`Thomas Binns`)
 - Changed suggested type for ``ch_groups``` in `mne.viz.plot_sensors` from array to list of list(s) (arrays are still supported). (:gh:`11465` by `Hyonyoung Shin`_)
 - Add support for UCL/FIL OPM data using :func:`mne.io.read_raw_fil` (:gh:`11366` by :newcontrib:`George O'Neill` and `Robert Seymour`_)
+- Add harmonic field correction (HFC) for OPM sensors in :func:`mne.preprocessing.compute_proj_hfc` (:gh:`11536` by :newcontrib:`George O'Neill` and `Eric Larson`_)
 - Forward argument ``axes`` from `mne.viz.plot_sensors` to `mne.channels.DigMontage.plot` (:gh:`11470` by :newcontrib:`Jan Ebert` and `Mathieu Scheltienne`_)
 - Add forward IIR filtering, using parameters ``method='iir', phase='forward'`` (:gh:`11078` by :newcontrib:`Quentin Barthélemy`)
 - Added ability to read stimulus durations from SNIRF files when using :func:`mne.io.read_raw_snirf` (:gh:`11397` by `Robert Luke`_)
diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst
index c92167a04fe..0ed960be4b9 100644
--- a/doc/preprocessing.rst
+++ b/doc/preprocessing.rst
@@ -83,6 +83,7 @@ Projections:
    compute_maxwell_basis
    compute_proj_ecg
    compute_proj_eog
+   compute_proj_hfc
    cortical_signal_suppression
    create_ecg_epochs
    create_eog_epochs
diff --git a/doc/references.bib b/doc/references.bib
index 87a033a97f5..95d78aa1ce6 100644
--- a/doc/references.bib
+++ b/doc/references.bib
@@ -2406,3 +2406,26 @@ @article{SeymourEtAl2022
   year = {2022},
   pages = {118834}
 }
+
+@article{TierneyEtAl2021,
+  title = {Modelling optically pumped magnetometer interference in MEG as a spatially homogeneous magnetic field},
+  volume = {244},
+  issn = {1053-8119},
+  doi = {10.1016/j.neuroimage.2021.118484},
+  language = {en},
+  journal = {NeuroImage},
+  author = {Tierney, Tim M. and Alexander, Nicholas and Mellor, Stephanie and Holmes, Niall and Seymour, Robert and O'Neill, George C. and Maguire, Eleanor A. and Barnes, Gareth R.},
+  year = {2021},
+  pages = {118484}
+}
+
+@article{TierneyEtAl2022,
+  title = {Spherical harmonic based noise rejection and neuronal sampling with multi-axis OPMs},
+  journal = {NeuroImage},
+  volume = {258},
+  pages = {119338},
+  year = {2022},
+  issn = {1053-8119},
+  doi = {10.1016/j.neuroimage.2022.119338},
+  author = {Tierney, Tim M. and Mellor, Stephanie and O'Neill, George C. and Timms, Ryan C. and Barnes, Gareth R.},
+}
\ No newline at end of file
diff --git a/mne/datasets/config.py b/mne/datasets/config.py
index e84d63b41c4..da2ed677566 100644
--- a/mne/datasets/config.py
+++ b/mne/datasets/config.py
@@ -87,7 +87,7 @@
 # respective repos, and make a new release of the dataset on GitHub. Then
 # update the checksum in the MNE_DATASETS dict below, and change version
 # here: ↓↓↓↓↓ ↓↓↓
-RELEASES = dict(testing='0.144', misc='0.26')
+RELEASES = dict(testing='0.145', misc='0.26')
 TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}'
 MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}'
@@ -111,7 +111,7 @@
 # Testing and misc are at the top as they're updated most often
 MNE_DATASETS['testing'] = dict(
     archive_name=f'{TESTING_VERSIONED}.tar.gz',
-    hash='md5:fb546f44dba3310945225ed8fdab4a91',
+    hash='md5:2036f7d7616129c624b757fbb019be24',
     url=('https://codeload.github.com/mne-tools/mne-testing-data/'
          f'tar.gz/{RELEASES["testing"]}'),
 # In case we ever have to resort to osf.io again...
diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 3bd54fca55c..34be8d023cd 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -35,7 +35,8 @@ _FWD_ORDER) -_accuracy_dict = dict(normal=FWD.COIL_ACCURACY_NORMAL, +_accuracy_dict = dict(point=FWD.COIL_ACCURACY_POINT, + normal=FWD.COIL_ACCURACY_NORMAL, accurate=FWD.COIL_ACCURACY_ACCURATE) _extra_coil_def_fname = None diff --git a/mne/io/pick.py b/mne/io/pick.py index dc914c53ffd..292e9f8b772 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -1146,8 +1146,12 @@ def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind, bad_type = pick break else: + # bad_type is None but this could still be empty + bad_type = list(picks) # triage MEG and FNIRS, which are complicated due to non-bool entries extra_picks = set() + if 'ref_meg' not in picks and not with_ref_meg: + kwargs['ref_meg'] = False if len(meg) > 0 and not kwargs.get('meg', False): # easiest just to iterate for use_meg in meg: @@ -1172,7 +1176,7 @@ def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind, raise ValueError( 'picks (%s) could not be interpreted as ' 'channel names (no channel "%s"), channel types (no ' - 'type "%s"), or a generic type (just "all" or "data")' + 'type "%s" present), or a generic type (just "all" or "data")' % (repr(orig_picks) + extra_repr, str(bad_names), bad_type)) picks = np.array([], int) elif sum(any_found) > 1: diff --git a/mne/io/tests/test_pick.py b/mne/io/tests/test_pick.py index 5ae95424b19..5508a263905 100644 --- a/mne/io/tests/test_pick.py +++ b/mne/io/tests/test_pick.py @@ -550,6 +550,22 @@ def test_picks_to_idx(): assert_array_equal(np.arange(len(info['ch_names'])), _picks_to_idx(info, 'all')) assert_array_equal([0], _picks_to_idx(info, 'data')) + # MEG reference sensors + info_ref = read_info(ctf_fname) + picks_meg = pick_types(info_ref, meg=True, ref_meg=False) + assert len(picks_meg) == 275 + picks_ref = pick_types(info_ref, meg=False, ref_meg=True) + assert len(picks_ref) == 29 + picks_meg_ref = np.sort(np.concatenate([picks_meg, picks_ref])) + assert len(picks_meg_ref) == 275 + 29 + assert_array_equal( + picks_meg_ref, pick_types(info_ref, meg=True, ref_meg=True)) + assert_array_equal( + picks_meg, _picks_to_idx(info_ref, 'meg', with_ref_meg=False)) + assert_array_equal( # explicit trumps implicit + picks_ref, _picks_to_idx(info_ref, 'ref_meg', with_ref_meg=False)) + assert_array_equal( + picks_meg_ref, _picks_to_idx(info_ref, 'meg', with_ref_meg=True)) def test_pick_channels_cov(): diff --git a/mne/preprocessing/__init__.py b/mne/preprocessing/__init__.py index b9c308dddaa..fc7c5f15abc 100644 --- a/mne/preprocessing/__init__.py +++ b/mne/preprocessing/__init__.py @@ -33,4 +33,5 @@ from .interpolate import equalize_bads, interpolate_bridged_electrodes from . import ieeg from ._css import cortical_signal_suppression +from .hfc import compute_proj_hfc from . 
import eyetracking
diff --git a/mne/preprocessing/hfc.py b/mne/preprocessing/hfc.py
new file mode 100644
index 00000000000..87a5f2713f7
--- /dev/null
+++ b/mne/preprocessing/hfc.py
@@ -0,0 +1,100 @@
+# Authors: George O'Neill
+#
+# License: BSD-3-Clause
+
+import numpy as np
+
+from .maxwell import _prep_mf_coils, _sss_basis
+from ..io.pick import _picks_to_idx, pick_info
+from ..io.proj import Projection
+from ..utils import verbose
+
+
+@verbose
+def compute_proj_hfc(info, order=1, picks='meg', exclude='bads',
+                     *, accuracy='accurate', verbose=None):
+    """Generate projectors to perform homogeneous/harmonic correction to data.
+
+    Remove environmental fields from magnetometer data by assuming they are
+    explained as a homogeneous :footcite:`TierneyEtAl2021` or harmonic field
+    :footcite:`TierneyEtAl2022`. Useful for arrays of OPMs.
+
+    Parameters
+    ----------
+    %(info)s
+    order : int
+        The order of the spherical harmonic basis set to use. Set to 1 to use
+        only the homogeneous field component (default), 2 to add gradients, 3
+        to add quadrature terms, etc.
+    picks : str | array_like | slice | None
+        Channels to include. Default of ``'meg'`` (same as None) will select
+        all non-reference MEG channels. Use ``('meg', 'ref_meg')`` to include
+        reference sensors as well.
+    exclude : list | 'bads'
+        List of channels to exclude from HFC, only used when picking
+        based on types (e.g., exclude="bads" when picks="meg").
+        Specify ``'bads'`` (the default) to exclude all channels marked as bad.
+    accuracy : str
+        Can be ``"point"``, ``"normal"`` or ``"accurate"`` (default); defines
+        which level of coil definition accuracy is used to generate the model.
+    %(verbose)s
+
+    Returns
+    -------
+    %(projs)s
+
+    See Also
+    --------
+    mne.io.Raw.add_proj
+    mne.io.Raw.apply_proj
+
+    Notes
+    -----
+    To apply the projectors to a dataset, use
+    ``inst.add_proj(projs).apply_proj()``.
+
+    .. versionadded:: 1.4
+
+    References
+    ----------
+    .. footbibliography::
+    """
+    picks = _picks_to_idx(
+        info, picks, none='meg', exclude=exclude, with_ref_meg=False)
+    info = pick_info(info, picks)
+    del picks
+    exp = dict(origin=(0., 0., 0.), int_order=0, ext_order=order)
+    coils = _prep_mf_coils(info, ignore_ref=False, accuracy=accuracy)
+    n_chs = len(coils[5])
+    if n_chs != info['nchan']:
+        raise ValueError(
+            f'Only {n_chs}/{info["nchan"]} picks could be interpreted '
+            'as MEG channels.')
+    S = _sss_basis(exp, coils)
+    del coils
+    bad_chans = [
+        info['ch_names'][pick]
+        for pick in np.where((~np.isfinite(S)).any(axis=1))[0]
+    ]
+    if bad_chans:
+        raise ValueError(
+            "The following channel(s) generate non-finite projectors:\n"
+            f" {bad_chans}\nPlease exclude them from picks!")
+    S /= np.linalg.norm(S, axis=0)
+    labels = _label_basis(order)
+    assert len(labels) == S.shape[1]
+    projs = []
+    for label, vec in zip(labels, S.T):
+        proj_data = dict(col_names=info['ch_names'], row_names=None,
+                         data=vec[np.newaxis, :], ncol=info['nchan'], nrow=1)
+        projs.append(Projection(active=False, data=proj_data, desc=label))
+    return projs
+
+
+def _label_basis(order):
+    """Give basis vectors names for Projection() class."""
+    return [
+        f"HFC: l={L} m={m}"
+        for L in np.arange(1, order + 1)
+        for m in np.arange(-1 * L, L + 1)
+    ]
diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py
index b6dba1fc21a..ff31cb4fc60 100644
--- a/mne/preprocessing/maxwell.py
+++ b/mne/preprocessing/maxwell.py
@@ -806,10 +806,12 @@ def _check_destination(destination, info, head_frame):
 
 
 @verbose
-def _prep_mf_coils(info, ignore_ref=True, verbose=None):
+def _prep_mf_coils(info, ignore_ref=True, *, accuracy='accurate',
+                   verbose=None):
     """Get all coil integration information loaded and sorted."""
     meg_sensors = _prep_meg_channels(
-        info, head_frame=False, ignore_ref=ignore_ref, verbose=False)
+        info, head_frame=False, ignore_ref=ignore_ref, accuracy=accuracy,
+        verbose=False)
     coils = meg_sensors['defs']
     mag_mask = _get_mag_mask(coils)
diff --git a/mne/preprocessing/ssp.py b/mne/preprocessing/ssp.py
index 9059794bd6d..76c25cf8750 100644
--- a/mne/preprocessing/ssp.py
+++ b/mne/preprocessing/ssp.py
@@ -228,8 +228,7 @@ def compute_proj_ecg(raw, raw_event=None, tmin=-0.2, tmax=0.4,
 
     Returns
     -------
-    proj : list
-        Computed SSP projectors.
+    %(projs)s
     ecg_events : ndarray
         Detected ECG events.
     drop_log : list
@@ -339,8 +338,7 @@ def compute_proj_eog(raw, raw_event=None, tmin=-0.2, tmax=0.2,
 
     Returns
     -------
-    proj: list
-        Computed SSP projectors.
+    %(projs)s
     eog_events: ndarray
         Detected EOG events.
    drop_log : list
diff --git a/mne/preprocessing/tests/test_hfc.py b/mne/preprocessing/tests/test_hfc.py
new file mode 100644
index 00000000000..04400d72332
--- /dev/null
+++ b/mne/preprocessing/tests/test_hfc.py
@@ -0,0 +1,149 @@
+# Authors: George O'Neill
+#
+# License: BSD-3-Clause
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from numpy.testing import assert_allclose
+from scipy.io import loadmat
+
+from mne.datasets import testing
+from mne.io import read_raw_fil, read_info
+from mne.preprocessing.hfc import compute_proj_hfc
+from mne.io.pick import pick_types, pick_info, pick_channels
+
+fil_path = testing.data_path(download=False) / 'FIL'
+fname_root = "sub-noise_ses-001_task-noise220622_run-001"
+
+io_dir = Path(__file__).parent.parent.parent / "io"
+ctf_fname = io_dir / "tests" / "data" / "test_ctf_raw.fif"
+fif_fname = io_dir / "tests" / "data" / "test_raw.fif"
+
+# The below channels in the test data do not have positions, set to bad
+bads = ['G2-DS-Y', 'G2-DS-Z', 'G2-DT-Y', 'G2-DT-Z', 'G2-MW-Y', 'G2-MW-Z']
+
+# TODO: Ignore this warning in all these tests until we deal with this properly
+pytestmark = pytest.mark.filterwarnings(
+    'ignore:No fiducials.*problems later!:RuntimeWarning',
+)
+
+
+def _unpack_mat(matin):
+    """Extract relevant entries from an unstructured readmat."""
+    data = matin['data']
+    grad = data[0][0]['grad']
+    label = list()
+    coil_label = list()
+    for ii in range(len(data[0][0]['label'])):
+        label.append(str(data[0][0]['label'][ii][0][0]))
+    for ii in range(len(grad[0][0]['label'])):
+        coil_label.append(str(grad[0][0]['label'][ii][0][0]))
+
+    matout = {'label': label,
+              'trial': data['trial'][0][0][0][0],
+              'coil_label': coil_label,
+              'coil_pos': grad[0][0]['coilpos'],
+              'coil_ori': grad[0][0]['coilori']}
+    return matout
+
+
+def _angle_between_each(A):
+    """Measure the angle between each row vector in a matrix."""
+    assert A.ndim == 2
+    A = A / np.linalg.norm(A, axis=1, keepdims=True)
+    d = (A @ A.T).ravel()
+    np.clip(d, -1, 1, out=d)
+    ang = np.abs(np.arccos(d))
+    return ang
+
+
+@testing.requires_testing_data
+@pytest.mark.parametrize('order', [1, 2, 3])
+def test_correction(order):
+    """Apply HFC and compare to previously computed solutions."""
+    binname = fil_path / "sub-noise_ses-001_task-noise220622_run-001_meg.bin"
+    raw = read_raw_fil(binname)
+    raw.load_data()
+    raw.info['bads'].extend([b for b in bads])
+    projs = compute_proj_hfc(raw.info, order=order, accuracy="point")
+    raw.add_proj(projs).apply_proj()
+
+    mat = _unpack_mat(loadmat(fil_path / f"{fname_root}_hfc_l{order}.mat"))
+
+    proj_list = projs[0]['data']['col_names']
+    picks = pick_channels(raw.ch_names, proj_list, ordered=True)
+    mat_list = mat["coil_label"]
+    mat_inds = pick_channels(mat_list, proj_list, ordered=True)
+
+    want = mat['trial'][mat_inds]
+    got = raw.copy().add_proj(projs).apply_proj()[picks, 0:300][0] * 1e15
+    assert_allclose(got, want, rtol=1e-7)
+
+    # Now with default accuracy: not super close with tol but corr is good
+    projs = compute_proj_hfc(raw.info, order=order)
+    got = raw.copy().add_proj(projs).apply_proj()[picks, 0:300][0] * 1e15
+    corr = np.corrcoef(got.ravel(), want.ravel())[0, 1]
+    assert 0.999999 < corr <= 1.
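+
+# A hand-derived check on basis sizes (not asserted by these tests): for
+# ``order=L``, ``_label_basis`` in ``hfc.py`` yields
+# ``sum(2 * l + 1 for l in range(1, L + 1)) == L * (L + 2)`` labels, i.e.
+# 3, 8, and 15 projectors for the three orders parametrized above.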
+
+
+@testing.requires_testing_data
+def test_l1_basis_orientations():
+    """Test that the angles between basis components match orientations."""
+    binname = fil_path / "sub-noise_ses-001_task-noise220622_run-001_meg.bin"
+    raw = read_raw_fil(binname)
+    raw.info['bads'].extend([b for b in bads])
+    projs = compute_proj_hfc(raw.info, accuracy='point')
+    basis = np.hstack([p['data']['data'].T for p in projs])
+    picks = pick_types(raw.info, meg='mag')
+    assert len(picks) == 68
+    assert basis.shape == (len(picks), 3)
+    ang_model = _angle_between_each(basis)
+    n_ang = len(picks) ** 2
+    assert ang_model.shape == (n_ang,)
+
+    chs = pick_info(raw.info, picks)['chs']
+    ori_sens = np.array([ch['loc'][-3:] for ch in chs])
+    # match the normalization that our projectors get
+    ori_sens /= np.linalg.norm(ori_sens, axis=0, keepdims=True)
+    assert ori_sens.shape == (len(picks), 3)
+    ang_sens = _angle_between_each(ori_sens)
+    assert ang_sens.shape == (n_ang,)
+
+    assert_allclose(ang_sens, ang_model, atol=1e-7)
+
+
+def test_ref_degenerate():
+    """Test reference channel handling and degenerate conditions."""
+    info = read_info(ctf_fname)
+    # exclude ref by default
+    projs = compute_proj_hfc(info)
+    meg_names = [
+        info['ch_names'][pick]
+        for pick in pick_types(info, meg=True, ref_meg=False, exclude=[])
+    ]
+    assert len(projs) == 3
+    assert projs[0]['desc'] == 'HFC: l=1 m=-1'
+    assert projs[1]['desc'] == 'HFC: l=1 m=0'
+    assert projs[2]['desc'] == 'HFC: l=1 m=1'
+    assert projs[0]['data']['col_names'] == meg_names
+    meg_ref_names = [
+        info['ch_names'][pick]
+        for pick in pick_types(info, meg=True, ref_meg=True, exclude=[])
+    ]
+    projs = compute_proj_hfc(info, picks=('meg', 'ref_meg'))
+    assert projs[0]['data']['col_names'] == meg_ref_names
+
+    # Degenerate
+    info = read_info(fif_fname)
+    compute_proj_hfc(info)  # smoke test
+    with pytest.raises(ValueError, match='Only.*could be interpreted as MEG'):
+        compute_proj_hfc(info, picks=[0, 330])  # one MEG, one EEG
+    info['chs'][0]['loc'][:] = np.nan  # first MEG proj
+    with pytest.raises(ValueError, match='non-finite projectors'):
+        compute_proj_hfc(info)
+    info_eeg = pick_info(info, pick_types(info, meg=False, eeg=True))
+    with pytest.raises(ValueError, match=r'picks \(\'meg\'\) could not be'):
+        compute_proj_hfc(info_eeg)
diff --git a/mne/utils/docs.py b/mne/utils/docs.py
index f5817ae3b74..0e66fed2e2f 100644
--- a/mne/utils/docs.py
+++ b/mne/utils/docs.py
@@ -2866,6 +2866,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
     must be set to ``False`` (the default in this case).
 """
 
+docdict['projs'] = """
+projs : list of Projection
+    List of computed projection vectors.
+""" + docdict['projs_report'] = """ projs : bool | None Whether to add SSP projector plots if projectors are present in diff --git a/tutorials/preprocessing/80_opm_processing.py b/tutorials/preprocessing/80_opm_processing.py index ac9d80bc692..aa8d3cb8278 100644 --- a/tutorials/preprocessing/80_opm_processing.py +++ b/tutorials/preprocessing/80_opm_processing.py @@ -29,12 +29,12 @@ import mne opm_data_folder = mne.datasets.ucl_opm_auditory.data_path() -opm_file = (opm_data_folder / 'sub-001' / 'ses-001' / 'meg' / - 'sub-001_ses-001_task-aef_run-001_meg.bin') +opm_file = (opm_data_folder / 'sub-002' / 'ses-001' / 'meg' / + 'sub-002_ses-001_task-aef_run-001_meg.bin') # For now we are going to assume the device and head coordinate frames are # identical (even though this is incorrect), so we pass verbose='error' for now raw = mne.io.read_raw_fil(opm_file, verbose='error') -raw.crop(120, 240).load_data() # crop for speed +raw.crop(120, 210).load_data() # crop for speed # %% # Examining raw data @@ -74,6 +74,7 @@ # To do this in our current dataset, we require a bit of housekeeping. # There are a set of channels beginning with the name "Flux" which do not # contain any evironmental data, these need to be set to as bad channels. +# Another channel -- G2-17-TAN -- will also be set to bad. # # For now we are only interested in removing artefacts seen below 5 Hz, so we # initially low-pass filter the good reference channels in this dataset prior @@ -88,6 +89,7 @@ # set flux channels to bad bad_picks = mne.pick_channels_regexp(raw.ch_names, regexp='Flux.') raw.info['bads'].extend([raw.ch_names[ii] for ii in bad_picks]) +raw.info['bads'].extend(['G2-17-TAN']) # compute the PSD for later using 1 Hz resolution psd_kwargs = dict(fmax=20, n_fft=int(round(raw.info['sfreq']))) @@ -108,6 +110,37 @@ ax.grid(True, ls=':') ax.set(title='After reference regression', **set_kwargs) +# compute the psd of the regressed data +psd_post_reg = raw.compute_psd(**psd_kwargs) + +# %% +# Denoising: Regressing via homogeneous field correction +# ------------------------------------------------------ +# +# Regression of a reference channel is a start, but in this instance assumes +# the relatiship between the references and a given sensor on the head as +# constant. However this becomes less accurate when the reference is not moving +# but the subject is. An alternative method, Homogeneous Field Correction (HFC) +# only requires that the sensors on the helmet stationary relative to each +# other. Which in a well-designed rigid helmet is the case. + + +# include gradients by setting order to 2, set to 1 for homgenous components +projs = mne.preprocessing.compute_proj_hfc(raw.info, order=2) +raw.add_proj(projs).apply_proj(verbose='error') + +# plot +data_ds, _ = raw[picks[::5], :stop] +data_ds = data_ds[:, ::step] * amp_scale + +fig, ax = plt.subplots(constrained_layout=True) +ax.plot(time_ds, data_ds.T - np.mean(data_ds, axis=1), **plot_kwargs) +ax.grid(True, ls=':') +ax.set(title='After HFC', **set_kwargs) + +# compute the psd of the regressed data +psd_post_hfc = raw.compute_psd(**psd_kwargs) + # %% # Comparing denoising methods # --------------------------- @@ -118,18 +151,34 @@ # after processing. We will use metric called the shielding factor to summarise # the values. Positive shielding factors indicate a reduction in power, whilst # negative means in increase. +# +# We see that reference regression does a good job in reducing low frequency +# drift up to ~2 Hz, with 20 dB of shielding. 
But rapidly drops off due to +# low pass filtering the reference signal at 5 Hz. We also can see that this +# method is also introducing additional interference at 3 Hz. +# +# HFC improves on the low frequency shielding (up to 32 dB). Also this method +# is not frequency-specific so we observe broadband interference reduction. -# psd_pre was computed above before regression -psd_post = raw.compute_psd(**psd_kwargs) -shielding = 10 * np.log10(psd_pre[:] / psd_post[:]) +shielding = 10 * np.log10(psd_pre[:] / psd_post_reg[:]) fig, ax = plt.subplots(constrained_layout=True) -ax.plot(psd_post.freqs, shielding.T, **plot_kwargs) +ax.plot(psd_post_reg.freqs, shielding.T, **plot_kwargs) ax.grid(True, ls=':') -ax.set(xticks=psd_post.freqs) +ax.set(xticks=psd_post_reg.freqs) ax.set(xlim=(0, 20), title='Reference regression shielding', xlabel='Frequency (Hz)', ylabel='Shielding (dB)') + +shielding = 10 * np.log10(psd_pre[:] / psd_post_hfc[:]) + +fig, ax = plt.subplots(constrained_layout=True) +ax.plot(psd_post_hfc.freqs, shielding.T, **plot_kwargs) +ax.grid(True, ls=':') +ax.set(xticks=psd_post_hfc.freqs) +ax.set(xlim=(0, 20), title='Reference regression & HFC shielding', + xlabel='Frequency (Hz)', ylabel='Shielding (dB)') + # %% # Filtering nuisance signals # -------------------------- @@ -141,14 +190,14 @@ # to the neural signals we are interested in). # # We are going to remove the 50 Hz mains signal with a notch filter, -# followed by a bandpass filter between 1 and 48 Hz. From here it becomes clear +# followed by a bandpass filter between 2 and 40 Hz. From here it becomes clear # that the variance in our signal has been reduced from 100s of pT to 10s of # pT instead. # notch -raw.notch_filter(np.arange(50, 251, 50)) +raw.notch_filter(np.arange(50, 251, 50), notch_widths=4) # bandpass -raw.filter(1, 48, picks='meg') +raw.filter(2, 40, picks='meg') # plot data_ds, _ = raw[picks[::5], :stop] data_ds = data_ds[:, ::step] * amp_scale @@ -158,7 +207,7 @@ ax.grid(True) set_kwargs = dict(ylim=(-500, 500), xlim=time_ds[[0, -1]], xlabel='Time (s)', ylabel='Amplitude (pT)') -ax.set(title='After regression and filtering', **set_kwargs) +ax.set(title='After regression, HFC and filtering', **set_kwargs) # %% # Generating an evoked response @@ -167,10 +216,11 @@ # With the data preprocessed, it is now possible to see an auditory evoked # response at the sensor level. 
-# sphinx_gallery_thumbnail_number = 5 +# sphinx_gallery_thumbnail_number = 7 events = mne.find_events(raw, min_duration=0.1) -epochs = mne.Epochs(raw, events, tmin=-0.1, tmax=0.4, baseline=(-0.1, 0.)) +epochs = mne.Epochs(raw, events, tmin=-0.1, tmax=0.4, + baseline=(-0.1, 0.), verbose='error') evoked = epochs.average() evoked.plot() From e4dd2286e397d3c40a92fefddaddc6273e683e1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 24 Apr 2023 20:25:58 +0200 Subject: [PATCH 0027/1125] MRG: Rename "Discourse" link in top navigation to "Forum" [ci skip] (#11649) --- doc/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index d4fc98611cf..41a0d3d6272 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -617,7 +617,7 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): dict(name='Twitter', url='https://twitter.com/mne_python', icon='fa-brands fa-square-twitter'), - dict(name='Discourse', + dict(name='Forum', url='https://mne.discourse.group/', icon='fa-brands fa-discourse'), dict(name='Discord', From d5556313874d1d68e50dcab685d80fb6f73a7913 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 25 Apr 2023 00:33:49 -0400 Subject: [PATCH 0028/1125] MAINT: Unify GH Actions pytest (#11644) --- .github/workflows/compat_minimal.yml | 77 ------------ .github/workflows/compat_old.yml | 68 ---------- .github/workflows/linux_conda.yml | 125 ------------------- .github/workflows/linux_pip.yml | 76 ----------- .github/workflows/macos_conda.yml | 76 ----------- .github/workflows/precommit.yml | 14 --- .github/workflows/tests.yml | 117 +++++++++++++++++ README.rst | 2 +- azure-pipelines.yml | 5 +- mne/bem.py | 107 ++-------------- mne/conftest.py | 2 + mne/cov.py | 6 +- mne/epochs.py | 8 +- mne/evoked.py | 22 ++-- mne/filter.py | 58 ++++++--- mne/fixes.py | 12 ++ mne/forward/forward.py | 6 +- mne/inverse_sparse/tests/test_mxne_optim.py | 10 +- mne/io/_digitization.py | 2 +- mne/io/brainvision/tests/test_brainvision.py | 119 ++++++++---------- mne/io/cnt/cnt.py | 29 ++--- mne/io/ctf/info.py | 3 +- mne/io/ctf/res4.py | 4 +- mne/io/ctf_comp.py | 2 +- mne/io/egi/egimff.py | 1 + mne/io/egi/general.py | 8 +- mne/io/fiff/raw.py | 6 +- mne/io/fiff/tests/test_raw_fiff.py | 2 +- mne/io/meas_info.py | 72 +++++------ mne/io/open.py | 6 +- mne/io/proc_history.py | 10 +- mne/io/proj.py | 12 +- mne/io/snirf/_snirf.py | 2 +- mne/io/tag.py | 48 ++++--- mne/io/tests/test_raw.py | 5 +- mne/io/tree.py | 2 +- mne/minimum_norm/inverse.py | 11 +- mne/preprocessing/_csd.py | 2 +- mne/preprocessing/_fine_cal.py | 2 +- mne/source_estimate.py | 20 +-- mne/source_space.py | 20 +-- mne/stats/cluster_level.py | 2 +- mne/stats/regression.py | 4 +- mne/surface.py | 4 +- mne/utils/_testing.py | 3 +- mne/utils/docs.py | 2 +- mne/utils/numerics.py | 2 +- mne/viz/_brain/tests/test_brain.py | 16 +-- mne/viz/_brain/tests/test_notebook.py | 17 ++- mne/viz/utils.py | 2 +- tools/get_minimal_commands.sh | 32 +++-- tools/github_actions_dependencies.sh | 26 ++-- tools/github_actions_env_vars.sh | 30 +++++ tools/github_actions_install.sh | 7 +- tools/github_actions_locale.sh | 5 - tools/github_actions_test.sh | 15 ++- 56 files changed, 520 insertions(+), 826 deletions(-) delete mode 100644 .github/workflows/compat_minimal.yml delete mode 100644 .github/workflows/compat_old.yml delete mode 100644 .github/workflows/linux_conda.yml delete mode 100644 .github/workflows/linux_pip.yml delete mode 100644 .github/workflows/macos_conda.yml delete mode 
100644 .github/workflows/precommit.yml create mode 100644 .github/workflows/tests.yml create mode 100755 tools/github_actions_env_vars.sh delete mode 100755 tools/github_actions_locale.sh diff --git a/.github/workflows/compat_minimal.yml b/.github/workflows/compat_minimal.yml deleted file mode 100644 index a0027f46397..00000000000 --- a/.github/workflows/compat_minimal.yml +++ /dev/null @@ -1,77 +0,0 @@ -name: 'compat / minimal' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - # Minimal (runs with and without testing data) - job: - name: 'minimal 3.8' - runs-on: ubuntu-20.04 - defaults: - run: - shell: bash - env: - CONDA_DEPENDENCIES: 'numpy scipy matplotlib' - DEPS: 'minimal' - DISPLAY: ':99.0' - MNE_DONTWRITE_HOME: true - MNE_FORCE_SERIAL: true - MNE_LOGGING_LEVEL: 'warning' - MNE_SKIP_NETWORK_TEST: 1 - OPENBLAS_NUM_THREADS: '1' - PYTHONUNBUFFERED: '1' - PYTHON_VERSION: '3.8' - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - run: ./tools/setup_xvfb.sh - name: 'Setup xvfb' - - uses: conda-incubator/setup-miniconda@v2 - with: - python-version: ${{ env.PYTHON_VERSION }} - miniforge-version: latest - miniforge-variant: Mambaforge - use-mamba: true - - shell: bash -el {0} - run: | - ./tools/github_actions_dependencies.sh - source tools/get_minimal_commands.sh - name: 'Install dependencies' - - shell: bash -el {0} - run: ./tools/github_actions_install.sh - name: 'Install MNE' - - shell: bash -el {0} - run: ./tools/github_actions_infos.sh - name: 'Show infos' - - shell: bash -el {0} - run: ./tools/get_testing_version.sh - name: 'Get testing version' - - shell: bash -el {0} - run: MNE_SKIP_TESTING_DATASET_TESTS=true pytest -m "not (ultraslowtest or pgtest)" --tb=short --cov=mne --cov-report xml -vv -rfE mne/ - name: Run tests with no testing data - - uses: actions/cache@v3 - with: - key: ${{ env.TESTING_VERSION }} - path: ~/mne_data - name: 'Cache testing data' - - shell: bash -el {0} - run: ./tools/github_actions_download.sh - name: 'Download testing data' - - shell: bash -el {0} - run: ./tools/github_actions_locale.sh - name: 'Print locale' - - shell: bash -el {0} - run: ./tools/github_actions_test.sh - name: 'Run tests' - - uses: codecov/codecov-action@v3 - if: success() - name: 'Upload coverage to CodeCov' diff --git a/.github/workflows/compat_old.yml b/.github/workflows/compat_old.yml deleted file mode 100644 index 36e47774231..00000000000 --- a/.github/workflows/compat_old.yml +++ /dev/null @@ -1,68 +0,0 @@ -name: 'compat / old' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - job: - name: 'old 3.8' - runs-on: ubuntu-20.04 - defaults: - run: - shell: bash - env: - CONDA_DEPENDENCIES: 'numpy=1.20.2 scipy=1.6.3 matplotlib=3.4 pandas=1.2.4 scikit-learn=0.24.2' - DISPLAY: ':99.0' - MNE_LOGGING_LEVEL: 'warning' - OPENBLAS_NUM_THREADS: '1' - PYTHONUNBUFFERED: '1' - PYTHON_VERSION: '3.8' - MNE_IGNORE_WARNINGS_IN_TESTS: 'true' - steps: - - uses: actions/checkout@v3 - - run: ./tools/setup_xvfb.sh - name: 'Setup xvfb' - - uses: conda-incubator/setup-miniconda@v2 - with: - python-version: ${{ env.PYTHON_VERSION }} - miniforge-version: latest - miniforge-variant: Mambaforge - use-mamba: true - - shell: bash -el {0} - run: | - ./tools/github_actions_dependencies.sh - source 
tools/get_minimal_commands.sh - name: 'Install dependencies' - - shell: bash -el {0} - run: ./tools/github_actions_install.sh - name: 'Install MNE' - - shell: bash -el {0} - run: ./tools/github_actions_infos.sh - name: 'Show infos' - - shell: bash -el {0} - run: ./tools/get_testing_version.sh - name: 'Get testing version' - - uses: actions/cache@v3 - with: - key: ${{ env.TESTING_VERSION }} - path: ~/mne_data - name: 'Cache testing data' - - shell: bash -el {0} - run: ./tools/github_actions_download.sh - name: 'Download testing data' - - shell: bash -el {0} - run: ./tools/github_actions_locale.sh - name: 'Print locale' - - shell: bash -el {0} - run: ./tools/github_actions_test.sh - name: 'Run tests' - - uses: codecov/codecov-action@v3 - if: success() - name: 'Upload coverage to CodeCov' diff --git a/.github/workflows/linux_conda.yml b/.github/workflows/linux_conda.yml deleted file mode 100644 index 9822254ecee..00000000000 --- a/.github/workflows/linux_conda.yml +++ /dev/null @@ -1,125 +0,0 @@ -name: 'linux / conda' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - py310: - runs-on: ubuntu-20.04 - name: 'linux conda 3.10' - defaults: - run: - shell: bash - env: - CONDA_ENV: 'environment.yml' - DISPLAY: ':99.0' - MNE_LOGGING_LEVEL: 'warning' - MKL_NUM_THREADS: '1' - PYTHONUNBUFFERED: '1' - PYTHON_VERSION: '3.10' - steps: - - uses: actions/checkout@v3 - - run: ./tools/setup_xvfb.sh - name: 'Setup xvfb' - - uses: conda-incubator/setup-miniconda@v2 - with: - activate-environment: 'mne' - python-version: ${{ env.PYTHON_VERSION }} - environment-file: ${{ env.CONDA_ENV }} - # No mamba for this one job (use conda itself!) 
- - shell: bash -el {0} - run: | - ./tools/github_actions_dependencies.sh - source tools/get_minimal_commands.sh - name: 'Install dependencies' - - shell: bash -el {0} - run: mne_surf2bem --version - name: 'Check minimal commands' - - shell: bash -el {0} - run: ./tools/github_actions_install.sh - name: 'Install MNE' - - shell: bash -el {0} - run: | - QT_QPA_PLATFORM=xcb LIBGL_DEBUG=verbose LD_DEBUG=libs python -c "import pyvistaqt; pyvistaqt.BackgroundPlotter(show=True)" - name: 'Check Qt GL' - - shell: bash -el {0} - run: ./tools/github_actions_infos.sh - name: 'Show infos' - - shell: bash -el {0} - run: ./tools/get_testing_version.sh - name: 'Get testing version' - - uses: actions/cache@v3 - with: - key: ${{ env.TESTING_VERSION }} - path: ~/mne_data - name: 'Cache testing data' - - shell: bash -el {0} - run: ./tools/github_actions_download.sh - name: 'Download testing data' - - shell: bash -el {0} - run: ./tools/github_actions_locale.sh - name: 'Print locale' - - shell: bash -el {0} - run: ./tools/github_actions_test.sh - name: 'Run tests' - - uses: codecov/codecov-action@v3 - if: success() - name: 'Upload coverage to CodeCov' - - notebook: - timeout-minutes: 90 - runs-on: ubuntu-20.04 - defaults: - run: - shell: bash - env: - CONDA_ENV: 'environment.yml' - PYTHON_VERSION: '3.10' - steps: - - uses: actions/checkout@v3 - - uses: conda-incubator/setup-miniconda@v2 - with: - activate-environment: 'mne' - python-version: ${{ env.PYTHON_VERSION }} - environment-file: ${{ env.CONDA_ENV }} - miniforge-version: latest - miniforge-variant: Mambaforge - use-mamba: true - - shell: bash -el {0} - run: | - # TODO: As of 2023/02/28, notebook tests need a pinned mesalib - mamba install -c conda-forge "vtk>=9.2=*osmesa*" "mesalib=21.2.5" - mamba list - name: 'Install OSMesa VTK variant' - - shell: bash -el {0} - run: | - pip uninstall -yq mne - pip install --progress-bar off -ve .[test] - name: 'Install dependencies' - - shell: bash -el {0} - run: ./tools/github_actions_infos.sh - name: 'Show infos' - - shell: bash -el {0} - run: ./tools/get_testing_version.sh - name: 'Get testing version' - - uses: actions/cache@v3 - with: - key: ${{ env.TESTING_VERSION }} - path: ~/mne_data - name: 'Cache testing data' - - shell: bash -el {0} - run: ./tools/github_actions_download.sh - name: 'Download testing data' - - shell: bash -el {0} - run: pytest --tb=short -m "not pgtest" --cov=mne --cov-report=xml --cov-report=html -vv mne/viz - name: 'Run viz tests' - - uses: codecov/codecov-action@v3 - if: success() - name: 'Upload coverage to CodeCov' diff --git a/.github/workflows/linux_pip.yml b/.github/workflows/linux_pip.yml deleted file mode 100644 index fff47e31508..00000000000 --- a/.github/workflows/linux_pip.yml +++ /dev/null @@ -1,76 +0,0 @@ -name: 'linux / pip-pre' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -permissions: - contents: read - -jobs: - # PIP-pre + non-default stim channel + log level info - job: - name: 'linux pip 3.10' - runs-on: ubuntu-20.04 - defaults: - run: - shell: bash - env: - DISPLAY: ':99.0' - MNE_LOGGING_LEVEL: 'info' - MNE_STIM_CHANNEL: 'STI101' - OPENBLAS_NUM_THREADS: '1' - PYTHONUNBUFFERED: '1' - PYTHON_VERSION: '3.10' - steps: - - uses: actions/checkout@v3 - - run: ./tools/setup_xvfb.sh - name: 'Setup xvfb' - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - name: 'Setup python' - - shell: bash -el 
{0} - run: | - ./tools/github_actions_dependencies.sh - source tools/get_minimal_commands.sh - name: 'Install dependencies' - - shell: bash -el {0} - run: mne_surf2bem --version - name: 'Check minimal commands' - - shell: bash -el {0} - run: ./tools/github_actions_install.sh - name: 'Install MNE' - - shell: bash -el {0} - run: ./tools/github_actions_infos.sh - name: 'Show infos' - - shell: bash -el {0} - run: | - ./tools/check_qt_import.sh PyQt6 - python -c "import matplotlib; matplotlib.use('QtAgg'); import matplotlib.pyplot as plt; plt.figure()" - - shell: bash -el {0} - run: ./tools/get_testing_version.sh - name: 'Get testing version' - - uses: actions/cache@v3 - with: - key: ${{ env.TESTING_VERSION }} - path: ~/mne_data - name: 'Cache testing data' - - shell: bash -el {0} - run: ./tools/github_actions_download.sh - name: 'Download testing data' - - shell: bash -el {0} - run: ./tools/github_actions_locale.sh - name: 'Print locale' - - shell: bash -el {0} - run: ./tools/github_actions_test.sh - name: 'Run tests' - - uses: codecov/codecov-action@v3 - if: success() - name: 'Upload coverage to CodeCov' diff --git a/.github/workflows/macos_conda.yml b/.github/workflows/macos_conda.yml deleted file mode 100644 index 3befcc0b32b..00000000000 --- a/.github/workflows/macos_conda.yml +++ /dev/null @@ -1,76 +0,0 @@ -name: 'macos / conda' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - job: - name: 'macos 3.8' - runs-on: macos-latest - defaults: - run: - shell: bash - env: - PYTHON_VERSION: '3.8' - MNE_LOGGING_LEVEL: 'warning' - MNE_3D_OPTION_SMOOTH_SHADING: 'true' - OPENBLAS_NUM_THREADS: '1' - PYTHONUNBUFFERED: '1' - CONDA_ENV: 'environment.yml' - CI_OS_NAME: 'osx' - steps: - - uses: actions/checkout@v3 - - uses: conda-incubator/setup-miniconda@v2 - with: - activate-environment: 'mne' - python-version: ${{ env.PYTHON_VERSION }} - environment-file: ${{ env.CONDA_ENV }} - miniforge-version: latest - miniforge-variant: Mambaforge - use-mamba: true - - shell: bash -el {0} - run: | - ./tools/github_actions_dependencies.sh - name: 'Install dependencies' - # https://github.com/mne-tools/mne-python/issues/10805 - # https://github.com/mne-tools/mne-python/runs/7042965701?check_suite_focus=true - #- shell: bash -el {0} - # run: | - # source tools/get_minimal_commands.sh - # name: 'Install minimal commands' - #- shell: bash -el {0} - # run: mne_surf2bem --version - # name: 'Check minimal commands' - - shell: bash -el {0} - run: ./tools/github_actions_install.sh - name: 'Install MNE' - - shell: bash -el {0} - run: ./tools/github_actions_infos.sh - name: 'Show infos' - - shell: bash -el {0} - run: ./tools/get_testing_version.sh - name: 'Get testing version' - - uses: actions/cache@v3 - with: - key: ${{ env.TESTING_VERSION }} - path: ~/mne_data - name: 'Cache testing data' - - shell: bash -el {0} - run: ./tools/github_actions_download.sh - name: 'Download testing data' - - shell: bash -el {0} - run: ./tools/github_actions_locale.sh - name: 'Print locale' - - shell: bash -el {0} - run: ./tools/github_actions_test.sh - name: 'Run tests' - - uses: codecov/codecov-action@v3 - if: success() - name: 'Upload coverage to CodeCov' diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml deleted file mode 100644 index 4638064b646..00000000000 --- a/.github/workflows/precommit.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Pre-commit - -on: [push, 
pull_request] - -jobs: - style: - name: Pre-commit - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.11' - - uses: pre-commit/action@v3.0.0 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000000..3e8a3195c7a --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,117 @@ +name: 'Tests' +concurrency: + group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} + cancel-in-progress: true +on: + push: + branches: + - '*' + pull_request: + branches: + - '*' + +permissions: + contents: read + +jobs: + style: + name: Style + runs-on: ubuntu-latest + timeout-minutes: 3 + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + - uses: pre-commit/action@v3.0.0 + + pytest: + name: '${{ matrix.os }} / ${{ matrix.kind }} / ${{ matrix.python }}' + needs: style + timeout-minutes: 70 + continue-on-error: true + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + env: + PYTHON_VERSION: '${{ matrix.python }}' + MKL_NUM_THREADS: '1' + OPENBLAS_NUM_THREADS: '1' + PYTHONUNBUFFERED: '1' + MNE_CI_KIND: '${{ matrix.kind }}' + CI_OS_NAME: '${{ matrix.os }}' + strategy: + matrix: + include: + - os: ubuntu-latest + python: '3.10' + kind: conda + - os: ubuntu-latest + python: '3.10' + kind: notebook + - os: ubuntu-latest + python: '3.11' + kind: pip-pre + - os: macos-latest + python: '3.8' + kind: mamba + - os: ubuntu-latest + python: '3.8' + kind: minimal + - os: ubuntu-20.04 + python: '3.8' + kind: old + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - run: ./tools/github_actions_env_vars.sh + # Xvfb/OpenGL + - uses: pyvista/setup-headless-display-action@main + with: + qt: true + pyvista: false + if: matrix.kind != 'notebook' + # Python (if pip) + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python }} + if: startswith(matrix.kind, 'pip') + # Python (if conda) + - uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: ${{ env.CONDA_ACTIVATE_ENV }} + python-version: ${{ env.PYTHON_VERSION }} + environment-file: ${{ env.CONDA_ENV }} + miniforge-version: latest + miniforge-variant: Mambaforge + use-mamba: ${{ matrix.kind != 'conda' }} + if: ${{ !startswith(matrix.kind, 'pip') }} + - name: 'Install OSMesa VTK variant' + run: | + # TODO: As of 2023/02/28, notebook tests need a pinned mesalib + mamba install -c conda-forge "vtk>=9.2=*osmesa*" "mesalib=21.2.5" + mamba list + if: matrix.kind == 'notebook' + - run: ./tools/github_actions_dependencies.sh + # Minimal commands on Linux (macOS stalls) + - run: ./tools/get_minimal_commands.sh + if: ${{ startswith(matrix.os, 'ubuntu') }} + - run: ./tools/github_actions_install.sh + - run: ./tools/github_actions_infos.sh + # Check Qt on non-notebook + - run: ./tools/check_qt_import.sh $MNE_QT_BACKEND + if: ${{ env.MNE_QT_BACKEND != '' }} + - name: Run tests with no testing data + run: MNE_SKIP_TESTING_DATASET_TESTS=true pytest -m "not (ultraslowtest or pgtest)" --tb=short --cov=mne --cov-report xml -vv -rfE mne/ + if: matrix.kind == 'minimal' + - run: ./tools/get_testing_version.sh + - uses: actions/cache@v3 + with: + key: ${{ env.TESTING_VERSION }} + path: ~/mne_data + - run: ./tools/github_actions_download.sh + - run: ./tools/github_actions_test.sh + - uses: codecov/codecov-action@v3 + if: success() diff --git a/README.rst b/README.rst index d94f19ad8a0..ed37a718154 100644 --- a/README.rst +++ 
b/README.rst @@ -105,7 +105,7 @@ For full functionality, some functions require: - Numba >= 0.53.1 - NiBabel >= 3.2.1 -- OpenMEEG >= 2.5.5 +- OpenMEEG >= 2.5.6 - Pandas >= 1.2.4 - Picard >= 0.3 - CuPy >= 9.0.0 (for NVIDIA CUDA acceleration) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index e27e056aa3f..3a68d3226af 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -366,5 +366,6 @@ stages: sed -i 's/.. graphviz::/.. graphviz/g' tutorials/preprocessing/40_artifact_correction_ica.py sed -i '/sphinx\.ext\.graphviz/d' doc/conf.py displayName: Skip graph that we cannot render - - bash: make -C doc html_dev-noplot - displayName: 'Build doc' + # TODO: Reenable this once we can get it to work! + # - bash: make -C doc html_dev-noplot + # displayName: 'Build doc' diff --git a/mne/bem.py b/mne/bem.py index 66c0800a4d9..505c0fca79d 100644 --- a/mne/bem.py +++ b/mne/bem.py @@ -17,10 +17,10 @@ import os.path as op from pathlib import Path import shutil -import tempfile import numpy as np +from .fixes import _compare_version from .io.constants import FIFF, FWD from .io._digitization import _dig_kind_dict, _dig_kind_rev, _dig_kind_ints from .io.write import (start_and_end_file, start_block, write_float, write_int, @@ -313,6 +313,8 @@ def _import_openmeeg(what='compute a BEM solution using OpenMEEG'): raise ImportError( f'The OpenMEEG module must be installed to {what}, but ' f'"import openmeeg" resulted in: {exc}') from None + if not _compare_version(om.__version__, '>=', '2.5.6'): + raise ImportError(f'OpenMEEG 2.5.6+ is required, got {om.__version__}') return om @@ -328,96 +330,7 @@ def _make_openmeeg_geometry(bem, mri_head_t=None): meshes.append((points, faces)) conductivity = bem['sigma'][::-1] - # We should be able to do this: - # - # geom = om.make_nested_geometry(meshes, conductivity) - # - # But OpenMEEG's NumPy support is iffy. So let's use file IO for now :( - - def _write_tris(fname, mesh): - from .surface import complete_surface_info - mesh = dict(rr=mesh[0], tris=mesh[1]) - complete_surface_info(mesh, copy=False, do_neighbor_tri=False) - with open(fname, 'w') as fid: - fid.write(f'- {len(mesh["rr"])}\n') - for r, n in zip(mesh['rr'], mesh['nn']): - fid.write(f'{r[0]:.8f} {r[1]:.8f} {r[2]:.8f} ' - f'{n[0]:.8f} {n[1]:.8f} {n[2]:.8f}\n') - n_tri = len(mesh['tris']) - fid.write(f'- {n_tri} {n_tri} {n_tri}\n') - for t in mesh['tris']: - fid.write(f'{t[0]} {t[1]} {t[2]}\n') - - assert len(conductivity) in (1, 3) - # on Windows, the dir can't be cleaned up, presumably because OpenMEEG - # does not let go of the file pointer (?). This is not great but hopefully - # writing files is temporary, and/or we can fix the file pointer bug - # in OpenMEEG soon. - tmp_dir = tempfile.TemporaryDirectory(prefix='openmeeg-io-') - tmp_path = Path(tmp_dir.name) - # In 3.10+ we could use this as a context manager as there is a - # ignore_cleanup_errors arg, but before this there is not. 
-    # so let's just try/finally
-    try:
-        tmp_path = Path(tmp_path)
-        # write geom_file and three .tri files
-        geom_file = tmp_path / 'tmp.geom'
-        names = ['inner_skull', 'outer_skull', 'outer_skin']
-        lines = [
-            '# Domain Description 1.1',
-            '',
-            f'Interfaces {len(conductivity)}',
-            '',
-            f'Interface Cortex: "{names[0]}.tri"',
-        ]
-        if len(conductivity) == 3:
-            lines.extend([
-                f'Interface Skull: "{names[1]}.tri"',
-                f'Interface Head: "{names[2]}.tri"',
-            ])
-        lines.extend([
-            '',
-            f'Domains {len(conductivity) + 1}',
-            '',
-            'Domain Brain: -Cortex',
-        ])
-        if len(conductivity) == 1:
-            lines.extend([
-                'Domain Air: Cortex',
-            ])
-        else:
-            lines.extend([
-                'Domain Skull: Cortex -Skull',
-                'Domain Scalp: Skull -Head',
-                'Domain Air: Head',
-            ])
-        with open(geom_file, 'w') as fid:
-            fid.write('\n'.join(lines))
-        for mesh, name in zip(meshes, names):
-            _write_tris(tmp_path / f'{name}.tri', mesh)
-        # write cond_file
-        cond_file = tmp_path / 'tmp.cond'
-        lines = [
-            '# Properties Description 1.0 (Conductivities)',
-            '',
-            f'Brain {conductivity[0]}',
-        ]
-        if len(conductivity) == 3:
-            lines.extend([
-                f'Skull {conductivity[1]}',
-                f'Scalp {conductivity[2]}',
-            ])
-        lines.append('Air 0.0')
-        with open(cond_file, 'w') as fid:
-            fid.write('\n'.join(lines))
-        geom = om.Geometry(str(geom_file), str(cond_file))
-    finally:
-        try:
-            tmp_dir.cleanup()
-        except Exception:
-            pass  # ignore any cleanup errors (esp. on Windows)
-
-    return geom
+    return om.make_nested_geometry(meshes, conductivity)
 
 
 def _fwd_bem_openmeeg_solution(bem):
@@ -1463,24 +1376,24 @@ def _read_bem_surface(fid, this, def_coord_frame, s_id=None):
     if tag is None:
         res['id'] = FIFF.FIFFV_BEM_SURF_ID_UNKNOWN
     else:
-        res['id'] = int(tag.data)
+        res['id'] = int(tag.data.item())
 
     if s_id is not None and res['id'] != s_id:
         return None
 
     tag = find_tag(fid, this, FIFF.FIFF_BEM_SIGMA)
-    res['sigma'] = 1.0 if tag is None else float(tag.data)
+    res['sigma'] = 1.0 if tag is None else float(tag.data.item())
 
     tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NNODE)
     if tag is None:
         raise ValueError('Number of vertices not found')
 
-    res['np'] = int(tag.data)
+    res['np'] = int(tag.data.item())
 
     tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NTRI)
     if tag is None:
         raise ValueError('Number of triangles not found')
-    res['ntri'] = int(tag.data)
+    res['ntri'] = int(tag.data.item())
 
     tag = find_tag(fid, this, FIFF.FIFF_MNE_COORD_FRAME)
     if tag is None:
@@ -1488,9 +1401,9 @@ def _read_bem_surface(fid, this, def_coord_frame, s_id=None):
         if tag is None:
             res['coord_frame'] = def_coord_frame
         else:
-            res['coord_frame'] = tag.data
+            res['coord_frame'] = int(tag.data.item())
     else:
-        res['coord_frame'] = tag.data
+        res['coord_frame'] = int(tag.data.item())
 
     # Vertices, normals, and triangles
     tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NODES)
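`_compare_version` is MNE's internal helper; a self-contained sketch of the same import-time gate using `packaging` (an assumption, the helper's exact semantics may differ):

from importlib import import_module

from packaging.version import Version


def import_openmeeg(min_version='2.5.6'):
    """Import openmeeg, insisting on a minimum version."""
    try:
        om = import_module('openmeeg')
    except ImportError as exc:
        raise ImportError(f'OpenMEEG is required, got: {exc}') from None
    if Version(om.__version__) < Version(min_version):
        raise ImportError(
            f'OpenMEEG {min_version}+ is required, got {om.__version__}')
    return om

With the 2.5.6 floor in place, the temporary-file round trip can be dropped in favor of the direct `om.make_nested_geometry(meshes, conductivity)` call that the deleted comment block already anticipated.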
diff --git a/mne/conftest.py b/mne/conftest.py
index a4b261db704..9a64066b852 100644
--- a/mne/conftest.py
+++ b/mne/conftest.py
@@ -140,6 +140,8 @@ def pytest_configure(config):
     ignore:Implementing implicit namespace packages.*:DeprecationWarning
     ignore:Deprecated call to `pkg_resources.*:DeprecationWarning
     ignore:pkg_resources is deprecated as an API.*:DeprecationWarning
+    # h5py
+    ignore:`product` is deprecated as of NumPy.*:DeprecationWarning
     """.format(first_kind)  # noqa: E501
     for warning_line in warning_lines.split('\n'):
         warning_line = warning_line.strip()
diff --git a/mne/cov.py b/mne/cov.py
index 59dfb01edaa..c612e4ae7f1 100644
--- a/mne/cov.py
+++ b/mne/cov.py
@@ -1936,20 +1936,20 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None):
         for p in range(len(covs)):
             tag = find_tag(fid, covs[p], FIFF.FIFF_MNE_COV_KIND)
-            if tag is not None and int(tag.data) == cov_kind:
+            if tag is not None and int(tag.data.item()) == cov_kind:
                 this = covs[p]
 
                 #   Find all the necessary data
                 tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_DIM)
                 if tag is None:
                     raise ValueError('Covariance matrix dimension not found')
-                dim = int(tag.data)
+                dim = int(tag.data.item())
 
                 tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_NFREE)
                 if tag is None:
                     nfree = -1
                 else:
-                    nfree = int(tag.data)
+                    nfree = int(tag.data.item())
 
                 tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_METHOD)
                 if tag is None:
diff --git a/mne/epochs.py b/mne/epochs.py
index 099d4651009..d12e4cb167b 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -3090,10 +3090,10 @@ def _read_one_epoch_file(f, tree, preload):
             pos = my_epochs['directory'][k].pos
             if kind == FIFF.FIFF_FIRST_SAMPLE:
                 tag = read_tag(fid, pos)
-                first = int(tag.data)
+                first = int(tag.data.item())
             elif kind == FIFF.FIFF_LAST_SAMPLE:
                 tag = read_tag(fid, pos)
-                last = int(tag.data)
+                last = int(tag.data.item())
             elif kind == FIFF.FIFF_EPOCH:
                 # delay reading until later
                 fid.seek(pos, 0)
@@ -3103,11 +3103,11 @@ def _read_one_epoch_file(f, tree, preload):
             elif kind in [FIFF.FIFF_MNE_BASELINE_MIN, 304]:
                 # Constant 304 was used before v0.11
                 tag = read_tag(fid, pos)
-                bmin = float(tag.data)
+                bmin = float(tag.data.item())
             elif kind in [FIFF.FIFF_MNE_BASELINE_MAX, 305]:
                 # Constant 305 was used before v0.11
                 tag = read_tag(fid, pos)
-                bmax = float(tag.data)
+                bmax = float(tag.data.item())
             elif kind == FIFF.FIFF_MNE_EPOCHS_SELECTION:
                 tag = read_tag(fid, pos)
                 selection = np.array(tag.data)
diff --git a/mne/evoked.py b/mne/evoked.py
index 07101f081d7..29c6e5ca1cb 100644
--- a/mne/evoked.py
+++ b/mne/evoked.py
@@ -976,7 +976,7 @@ def _get_entries(fid, evoked_node, allow_maxshield=False):
             pos = my_aspect['directory'][k].pos
             if my_kind == FIFF.FIFF_ASPECT_KIND:
                 tag = read_tag(fid, pos)
-                aspect_kinds.append(int(tag.data))
+                aspect_kinds.append(int(tag.data.item()))
     comments = np.atleast_1d(comments)
     aspect_kinds = np.atleast_1d(aspect_kinds)
     if len(comments) != len(aspect_kinds) or len(comments) == 0:
@@ -1281,31 +1281,31 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False):
                     comment = tag.data
                 elif my_kind == FIFF.FIFF_FIRST_SAMPLE:
                     tag = read_tag(fid, pos)
-                    first = int(tag.data)
+                    first = int(tag.data.item())
                 elif my_kind == FIFF.FIFF_LAST_SAMPLE:
                     tag = read_tag(fid, pos)
-                    last = int(tag.data)
+                    last = int(tag.data.item())
                 elif my_kind == FIFF.FIFF_NCHAN:
                     tag = read_tag(fid, pos)
-                    nchan = int(tag.data)
+                    nchan = int(tag.data.item())
                 elif my_kind == FIFF.FIFF_SFREQ:
                     tag = read_tag(fid, pos)
-                    sfreq = float(tag.data)
+                    sfreq = float(tag.data.item())
                 elif my_kind == FIFF.FIFF_CH_INFO:
                     tag = read_tag(fid, pos)
                     chs.append(tag.data)
                 elif my_kind == FIFF.FIFF_FIRST_TIME:
                     tag = read_tag(fid, pos)
-                    first_time = float(tag.data)
+                    first_time = float(tag.data.item())
                 elif my_kind == FIFF.FIFF_NO_SAMPLES:
                     tag = read_tag(fid, pos)
-                    nsamp = int(tag.data)
+                    nsamp = int(tag.data.item())
                 elif my_kind == FIFF.FIFF_MNE_BASELINE_MIN:
                     tag = read_tag(fid, pos)
-                    bmin = float(tag.data)
+                    bmin = float(tag.data.item())
                 elif my_kind == FIFF.FIFF_MNE_BASELINE_MAX:
                     tag = read_tag(fid, pos)
-                    bmax = float(tag.data)
+                    bmax = float(tag.data.item())
 
     if comment is None:
         comment = 'No comment'
@@ -1344,10 +1344,10 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False):
                 comment = tag.data
             elif kind == FIFF.FIFF_ASPECT_KIND:
                 tag = read_tag(fid, pos)
-                aspect_kind = int(tag.data)
+                aspect_kind = int(tag.data.item())
             elif kind == FIFF.FIFF_NAVE:
                 tag = read_tag(fid, pos)
-                nave = int(tag.data)
+                nave = int(tag.data.item())
             elif kind == FIFF.FIFF_EPOCH:
                 tag = read_tag(fid, pos)
                 epoch.append(tag)
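All of these `int(tag.data)` to `int(tag.data.item())` changes address the same NumPy 1.25 deprecation: implicit conversion of 1-element arrays of `ndim > 0` to Python scalars. A minimal, self-contained sketch of the failing and preferred patterns (values are made up):

import numpy as np

tag_data = np.array([42], dtype=np.int32)  # 1-element array, ndim == 1
# int(tag_data) still works but emits DeprecationWarning on NumPy >= 1.25;
# .item() extracts the scalar explicitly and is future-proof:
value = int(tag_data.item())
assert value == 42 and isinstance(value, int)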
diff --git a/mne/filter.py b/mne/filter.py
index 5a3a25b5bec..2fc0d10b2c4 100644
--- a/mne/filter.py
+++ b/mne/filter.py
@@ -682,7 +682,9 @@ def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None,
     logger.info(f'{ftype_nice} {btype} {ptype} filter:')
     # SciPy designs forward for -3dB, so forward-backward is -6dB
     if 'order' in iir_params:
-        kwargs = dict(N=iir_params['order'], Wn=Wp, btype=btype,
+        singleton = btype in ('low', 'lowpass', 'high', 'highpass')
+        use_Wp = Wp.item() if singleton else Wp
+        kwargs = dict(N=iir_params['order'], Wn=use_Wp, btype=btype,
                       ftype=ftype, output=output)
         for key in ('rp', 'rs'):
             if key in iir_params:
@@ -993,6 +995,7 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto',
         freq = [0, sfreq / 2.]
         gain = [1., 1.]
     if l_freq is None and h_freq is not None:
+        h_freq = h_freq.item()
        logger.info('Setting up low-pass filter at %0.2g Hz' % (h_freq,))
         data, sfreq, _, f_p, _, f_s, filter_length, phase, fir_window, \
             fir_design = _triage_filter_params(
@@ -1008,6 +1011,7 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto',
             freq += [sfreq / 2.]
             gain += [0]
     elif l_freq is not None and h_freq is None:
+        l_freq = l_freq.item()
         logger.info('Setting up high-pass filter at %0.2g Hz' % (l_freq,))
         data, sfreq, pass_, _, stop, _, filter_length, phase, fir_window, \
             fir_design = _triage_filter_params(
@@ -1024,6 +1028,7 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto',
             gain = [0] + gain
     elif l_freq is not None and h_freq is not None:
         if (l_freq < h_freq).any():
+            l_freq, h_freq = l_freq.item(), h_freq.item()
             logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz'
                         % (l_freq, h_freq))
             data, sfreq, f_p1, f_p2, f_s1, f_s2, filter_length, phase, \
@@ -1051,6 +1056,7 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto',
                 raise ValueError('l_freq and h_freq must be the same length')
             msg = 'Setting up band-stop filter'
             if len(l_freq) == 1:
+                l_freq, h_freq = l_freq.item(), h_freq.item()
                 msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq)
             logger.info(msg)
             # Note: order of outputs is intentionally switched here!
@@ -1785,12 +1791,22 @@ def float_array(c):
                              'string, got "%s"' % l_trans_bandwidth)
             l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.),
                                            l_freq)
-        msg = ('- Lower transition bandwidth: %0.2f Hz'
-               % (l_trans_bandwidth))
-        if dB_cutoff:
-            logger.info('- Lower passband edge: %0.2f' % (l_freq,))
-            msg += ' (%s cutoff frequency: %0.2f Hz)' % (
-                dB_cutoff, l_freq - l_trans_bandwidth / 2.)
+        l_trans_rep = np.array(l_trans_bandwidth, float)
+        if l_trans_rep.size == 1:
+            l_trans_rep = f'{l_trans_rep.item():0.2f}'
+        with np.printoptions(precision=2, floatmode='fixed'):
+            msg = f'- Lower transition bandwidth: {l_trans_rep} Hz'
+            if dB_cutoff:
+                l_freq_rep = np.array(l_freq, float)
+                if l_freq_rep.size == 1:
+                    l_freq_rep = f'{l_freq_rep.item():0.2f}'
+                cutoff_rep = np.array(
+                    l_freq - l_trans_bandwidth / 2., float)
+                if cutoff_rep.size == 1:
+                    cutoff_rep = f'{cutoff_rep.item():0.2f}'
+                # Could be an array
+                logger.info(f'- Lower passband edge: {l_freq_rep}')
+                msg += f' ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)'
         logger.info(msg)
         l_trans_bandwidth = cast(l_trans_bandwidth)
         if np.any(l_trans_bandwidth <= 0):
@@ -1812,12 +1828,21 @@ def float_array(c):
                              'string, got "%s"' % h_trans_bandwidth)
             h_trans_bandwidth = np.minimum(np.maximum(0.25 * h_freq, 2.),
                                            sfreq / 2. - h_freq)
-        msg = ('- Upper transition bandwidth: %0.2f Hz'
-               % (h_trans_bandwidth))
-        if dB_cutoff:
-            logger.info('- Upper passband edge: %0.2f Hz' % (h_freq,))
-            msg += ' (%s cutoff frequency: %0.2f Hz)' % (
-                dB_cutoff, h_freq + h_trans_bandwidth / 2.)
+        h_trans_rep = np.array(h_trans_bandwidth, float)
+        if h_trans_rep.size == 1:
+            h_trans_rep = f'{h_trans_rep.item():0.2f}'
+        with np.printoptions(precision=2, floatmode='fixed'):
+            msg = f'- Upper transition bandwidth: {h_trans_rep} Hz'
+            if dB_cutoff:
+                h_freq_rep = np.array(h_freq, float)
+                if h_freq_rep.size == 1:
+                    h_freq_rep = f'{h_freq_rep.item():0.2f}'
+                cutoff_rep = np.array(
+                    h_freq + h_trans_bandwidth / 2., float)
+                if cutoff_rep.size == 1:
+                    cutoff_rep = f'{cutoff_rep.item():0.2f}'
+                logger.info(f'- Upper passband edge: {h_freq_rep} Hz')
+                msg += f' ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)'
         logger.info(msg)
         h_trans_bandwidth = cast(h_trans_bandwidth)
         if np.any(h_trans_bandwidth <= 0):
@@ -1834,8 +1859,11 @@ def float_array(c):
 
     if isinstance(filter_length, str) and filter_length.lower() == 'auto':
         filter_length = filter_length.lower()
-        h_check = h_trans_bandwidth if h_freq is not None else np.inf
-        l_check = l_trans_bandwidth if l_freq is not None else np.inf
+        h_check = l_check = np.inf
+        if h_freq is not None:
+            h_check = min(np.atleast_1d(h_trans_bandwidth))
+        if l_freq is not None:
+            l_check = min(np.atleast_1d(l_trans_bandwidth))
         mult_fact = 2. if fir_design == 'firwin2' else 1.
         filter_length = '%ss' % (_length_factors[fir_window] * mult_fact /
                                  float(min(h_check, l_check)),)
diff --git a/mne/fixes.py b/mne/fixes.py
index 5ff7c07a66f..b59439b3b88 100644
--- a/mne/fixes.py
+++ b/mne/fixes.py
@@ -12,6 +12,7 @@
 #          Lars Buitinck
 # License: BSD
 
+from contextlib import contextmanager
 import inspect
 from math import log
 from pprint import pprint
@@ -923,3 +924,14 @@ def pinv(a, rtol=None):
     u = u[:, :rank]
     u /= s[:rank]
     return (u @ vh[:rank]).conj().T
+
+
+###############################################################################
+# h5py uses np.product which is deprecated in NumPy 1.25
+
+@contextmanager
+def _numpy_h5py_dep():
+    # h5io uses np.product
+    with warnings.catch_warnings(record=True):
+        warnings.filterwarnings(
+            'ignore', '`product` is deprecated.*', DeprecationWarning)
+        yield
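`_numpy_h5py_dep` scopes the suppression to the `with` block only, rather than mutating global warning filters. A self-contained sketch of the same pattern (the names here are illustrative, not MNE's):

import warnings
from contextlib import contextmanager


@contextmanager
def ignore_product_deprecation():
    # Silence one specific DeprecationWarning only inside the with block.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore', message='`product` is deprecated.*',
            category=DeprecationWarning)
        yield


with ignore_product_deprecation():
    pass  # call third-party code that still uses np.product here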
diff --git a/mne/forward/forward.py b/mne/forward/forward.py
index 7f78eeaf08d..0d111d107ed 100644
--- a/mne/forward/forward.py
+++ b/mne/forward/forward.py
@@ -280,7 +280,7 @@ def _get_tag_int(fid, node, name, id_):
     if tag is None:
         fid.close()
         raise ValueError(name + ' tag not found')
-    return int(tag.data)
+    return int(tag.data.item())
 
 
 def _read_one(fid, node):
@@ -418,7 +418,9 @@ def _read_forward_meas_info(tree, fid):
     if tag is None:
         tag = find_tag(fid, parent_mri, 236)
         # Constant 236 used before v0.11
-    info['custom_ref_applied'] = int(tag.data) if tag is not None else False
+    info['custom_ref_applied'] = (
+        int(tag.data.item()) if tag is not None else False
+    )
     info._unlocked = False
     return info
 
diff --git a/mne/inverse_sparse/tests/test_mxne_optim.py b/mne/inverse_sparse/tests/test_mxne_optim.py
index 4081ee0e5f4..c3288528400 100644
--- a/mne/inverse_sparse/tests/test_mxne_optim.py
+++ b/mne/inverse_sparse/tests/test_mxne_optim.py
@@ -142,7 +142,7 @@ def test_norm_epsilon():
     n_coefs = n_steps * n_freqs
     phi = _Phi(wsize, tstep, n_coefs, n_times)
 
-    Y = np.zeros(n_steps * n_freqs)
+    Y = np.zeros((n_steps * n_freqs).item())
     l1_ratio = 0.03
     assert_allclose(norm_epsilon(Y, l1_ratio, phi), 0.)
 
@@ -152,7 +152,7 @@ def test_norm_epsilon():
     l1_ratio = 1.
     assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))
     # dummy value without random:
-    Y = np.arange(n_steps * n_freqs).reshape(-1, )
+    Y = np.arange((n_steps * n_freqs).item())
     l1_ratio = 0.0
     assert_allclose(norm_epsilon(Y, l1_ratio, phi) ** 2,
                     stft_norm2(Y.reshape(-1, n_freqs[0], n_steps[0])))
@@ -166,13 +166,13 @@ def test_norm_epsilon():
 
     # scaling w_time and w_space by the same amount should divide
    # epsilon norm by the same amount
-    Y = np.arange(n_coefs) + 1
+    Y = np.arange(n_coefs.item()) + 1
     mult = 2.
     assert_allclose(
         norm_epsilon(Y, l1_ratio, phi, w_space=1,
-                     w_time=np.ones(n_coefs)) / mult,
+                     w_time=np.ones(n_coefs.item())) / mult,
         norm_epsilon(Y, l1_ratio, phi, w_space=mult,
-                     w_time=mult * np.ones(n_coefs)))
+                     w_time=mult * np.ones(n_coefs.item())))
 
 
 @pytest.mark.slowtest  # slow-ish on Travis OSX
diff --git a/mne/io/_digitization.py b/mne/io/_digitization.py
index 30a07c19b46..a57e0eb78eb 100644
--- a/mne/io/_digitization.py
+++ b/mne/io/_digitization.py
@@ -186,7 +186,7 @@ def _read_dig_fif(fid, meas_info):
             dig.append(tag.data)
         elif kind == FIFF.FIFF_MNE_COORD_FRAME:
             tag = read_tag(fid, pos)
-            coord_frame = _coord_frame_named.get(int(tag.data))
+            coord_frame = _coord_frame_named.get(int(tag.data.item()))
     for d in dig:
         d['coord_frame'] = coord_frame
     return _format_dig_points(dig)
diff --git a/mne/io/brainvision/tests/test_brainvision.py b/mne/io/brainvision/tests/test_brainvision.py
index c9c375d086d..c78baa89027 100644
--- a/mne/io/brainvision/tests/test_brainvision.py
+++ b/mne/io/brainvision/tests/test_brainvision.py
@@ -9,8 +9,7 @@ from pathlib import Path
 
 import numpy as np
-from numpy.testing import (assert_array_almost_equal, assert_array_equal,
-                           assert_allclose, assert_equal)
+from numpy.testing import assert_array_equal, assert_allclose
 import pytest
 
 import datetime
@@ -194,7 +193,7 @@ def test_vhdr_codepage_ansi(tmp_path):
     raw = read_raw_brainvision(ansi_vhdr_path)
     data_new, times_new = raw[:]
 
-    assert_equal(raw_init.ch_names, raw.ch_names)
+    assert raw_init.ch_names == raw.ch_names
     assert_allclose(data_new, data_expected, atol=1e-15)
     assert_allclose(times_new, times_expected, atol=1e-15)
 
@@ -340,6 +339,7 @@ def test_ch_names_comma(tmp_path):
     assert "F4,foo" in raw.ch_names
 
 
+@pytest.mark.filterwarnings('ignore:.*different.*:RuntimeWarning')
 def test_brainvision_data_highpass_filters():
     """Test reading raw Brain Vision files with amplifier filter settings."""
     # Homogeneous highpass in seconds (default measurement unit)
@@ -347,8 +347,8 @@ def test_brainvision_data_highpass_filters():
         read_raw_brainvision, vhdr_fname=vhdr_highpass_path, eog=eog
     )
 
-    assert_equal(raw.info['highpass'], 1. / (2 * np.pi * 10))
-    assert_equal(raw.info['lowpass'], 250.)
+    assert raw.info['highpass'] == 1. / (2 * np.pi * 10)
+    assert raw.info['lowpass'] == 250.
 
     # Heterogeneous highpass in seconds (default measurement unit)
     with pytest.warns(RuntimeWarning, match='different .*pass filters') as w:
@@ -356,25 +356,20 @@ def test_brainvision_data_highpass_filters():
         raw = _test_raw_reader(
             read_raw_brainvision, vhdr_fname=vhdr_mixed_highpass_path,
             eog=eog)
 
-    lowpass_warning = ['different lowpass filters' in str(ww.message)
-                       for ww in w]
-    highpass_warning = ['different highpass filters' in str(ww.message)
-                        for ww in w]
+    w = [str(ww.message) for ww in w]
+    assert not any('different lowpass filters' in ww for ww in w), w
+    assert all('different highpass filters' in ww for ww in w), w
 
-    expected_warnings = zip(lowpass_warning, highpass_warning)
-
-    assert (all(any([lp, hp]) for lp, hp in expected_warnings))
-
-    assert_equal(raw.info['highpass'], 1. / (2 * np.pi * 10))
-    assert_equal(raw.info['lowpass'], 250.)
+    assert raw.info['highpass'] == 1. / (2 * np.pi * 10)
+    assert raw.info['lowpass'] == 250.
 
     # Homogeneous highpass in Hertz
     raw = _test_raw_reader(
         read_raw_brainvision, vhdr_fname=vhdr_highpass_hz_path,
         eog=eog)
 
-    assert_equal(raw.info['highpass'], 10.)
-    assert_equal(raw.info['lowpass'], 250.)
+    assert raw.info['highpass'] == 10.
+    assert raw.info['lowpass'] == 250.
 
     # Heterogeneous highpass in Hertz
     with pytest.warns(RuntimeWarning, match='different .*pass filters') as w:
@@ -382,19 +377,13 @@ def test_brainvision_data_highpass_filters():
         raw = _test_raw_reader(
             read_raw_brainvision, vhdr_fname=vhdr_mixed_highpass_hz_path,
             eog=eog)
 
-    trigger_warning = ['will be dropped' in str(ww.message)
-                       for ww in w]
-    lowpass_warning = ['different lowpass filters' in str(ww.message)
-                       for ww in w]
-    highpass_warning = ['different highpass filters' in str(ww.message)
-                        for ww in w]
-
-    expected_warnings = zip(trigger_warning, lowpass_warning, highpass_warning)
-
-    assert (all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings))
+    w = [str(ww.message) for ww in w]
+    assert not any('will be dropped' in ww for ww in w), w
+    assert not any('different lowpass filters' in ww for ww in w), w
+    assert all('different highpass filters' in ww for ww in w), w
 
-    assert_equal(raw.info['highpass'], 5.)
-    assert_equal(raw.info['lowpass'], 250.)
+    assert raw.info['highpass'] == 5.
+    assert raw.info['lowpass'] == 250.
 
 
 def test_brainvision_data_lowpass_filters():
@@ -404,8 +393,8 @@ def test_brainvision_data_lowpass_filters():
         read_raw_brainvision, vhdr_fname=vhdr_lowpass_path, eog=eog
     )
 
-    assert_equal(raw.info['highpass'], 1. / (2 * np.pi * 10))
-    assert_equal(raw.info['lowpass'], 250.)
+    assert raw.info['highpass'] == 1. / (2 * np.pi * 10)
+    assert raw.info['lowpass'] == 250.
 
     # Heterogeneous lowpass in Hertz (default measurement unit)
     with pytest.warns(RuntimeWarning) as w:  # event parsing
@@ -422,16 +411,16 @@ def test_brainvision_data_lowpass_filters():
 
     assert (all(any([lp, hp]) for lp, hp in expected_warnings))
 
-    assert_equal(raw.info['highpass'], 1. / (2 * np.pi * 10))
-    assert_equal(raw.info['lowpass'], 250.)
+    assert raw.info['highpass'] == 1. / (2 * np.pi * 10)
+    assert raw.info['lowpass'] == 250.
 
     # Homogeneous lowpass in seconds
     raw = _test_raw_reader(
         read_raw_brainvision, vhdr_fname=vhdr_lowpass_s_path, eog=eog
     )
 
-    assert_equal(raw.info['highpass'], 1. / (2 * np.pi * 10))
-    assert_equal(raw.info['lowpass'], 1. / (2 * np.pi * 0.004))
+    assert raw.info['highpass'] == 1. / (2 * np.pi * 10)
+    assert raw.info['lowpass'] == 1. / (2 * np.pi * 0.004)
 
     # Heterogeneous lowpass in seconds
     with pytest.warns(RuntimeWarning) as w:  # filter settings
@@ -448,8 +437,8 @@ def test_brainvision_data_lowpass_filters():
 
     assert (all(any([lp, hp]) for lp, hp in expected_warnings))
 
-    assert_equal(raw.info['highpass'], 1. / (2 * np.pi * 10))
-    assert_equal(raw.info['lowpass'], 1. / (2 * np.pi * 0.004))
+    assert raw.info['highpass'] == 1. / (2 * np.pi * 10)
+    assert raw.info['lowpass'] == 1. / (2 * np.pi * 0.004)
 
 
 def test_brainvision_data_partially_disabled_hw_filters():
@@ -471,8 +460,8 @@ def test_brainvision_data_partially_disabled_hw_filters():
 
     assert (all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings))
 
-    assert_equal(raw.info['highpass'], 0.)
-    assert_equal(raw.info['lowpass'], 500.)
+    assert raw.info['highpass'] == 0.
+    assert raw.info['lowpass'] == 500.
 
 
 def test_brainvision_data_software_filters_latin1_global_units():
@@ -482,8 +471,8 @@ def test_brainvision_data_software_filters_latin1_global_units():
         read_raw_brainvision, vhdr_fname=vhdr_old_path,
         eog=("VEOGo", "VEOGu", "HEOGli", "HEOGre"), misc=("A2",))
 
-    assert_equal(raw.info['highpass'], 1. / (2 * np.pi * 0.9))
-    assert_equal(raw.info['lowpass'], 50.)
+    assert raw.info['highpass'] == 1. / (2 * np.pi * 0.9)
+    assert raw.info['lowpass'] == 50.
 
     # test sensor name with spaces (#9299)
     with pytest.warns(RuntimeWarning, match='software filter'):
@@ -491,8 +480,8 @@ def test_brainvision_data_software_filters_latin1_global_units():
         raw = _test_raw_reader(
             read_raw_brainvision, vhdr_fname=vhdr_old_longname_path,
             eog=("VEOGo", "VEOGu", "HEOGli", "HEOGre"), misc=("A2",))
 
-    assert_equal(raw.info['highpass'], 1. / (2 * np.pi * 0.9))
-    assert_equal(raw.info['lowpass'], 50.)
+    assert raw.info['highpass'] == 1. / (2 * np.pi * 0.9)
+    assert raw.info['lowpass'] == 50.
 
 
 def test_brainvision_data():
@@ -507,8 +496,8 @@ def test_brainvision_data():
 
     assert ('RawBrainVision' in repr(raw_py))
 
-    assert_equal(raw_py.info['highpass'], 0.)
-    assert_equal(raw_py.info['lowpass'], 250.)
+    assert raw_py.info['highpass'] == 0.
+    assert raw_py.info['lowpass'] == 250.
 
     picks = pick_types(raw_py.info, meg=False, eeg=True, exclude='bads')
     data_py, times_py = raw_py[picks]
@@ -518,24 +507,24 @@ def test_brainvision_data():
     picks = pick_types(raw_py.info, meg=False, eeg=True, exclude='bads')
     data_bin, times_bin = raw_bin[picks]
 
-    assert_array_almost_equal(data_py, data_bin)
-    assert_array_almost_equal(times_py, times_bin)
+    assert_allclose(data_py, data_bin)
+    assert_allclose(times_py, times_bin)
 
     # Make sure EOG channels are marked correctly
     for ch in raw_py.info['chs']:
         if ch['ch_name'] in eog:
-            assert_equal(ch['kind'], FIFF.FIFFV_EOG_CH)
+            assert ch['kind'] == FIFF.FIFFV_EOG_CH
         elif ch['ch_name'] == 'STI 014':
-            assert_equal(ch['kind'], FIFF.FIFFV_STIM_CH)
+            assert ch['kind'] == FIFF.FIFFV_STIM_CH
         elif ch['ch_name'] in ('CP5', 'CP6'):
-            assert_equal(ch['kind'], FIFF.FIFFV_MISC_CH)
-            assert_equal(ch['unit'], FIFF.FIFF_UNIT_NONE)
+            assert ch['kind'] == FIFF.FIFFV_MISC_CH
+            assert ch['unit'] == FIFF.FIFF_UNIT_NONE
         elif ch['ch_name'] == 'ReRef':
-            assert_equal(ch['kind'], FIFF.FIFFV_MISC_CH)
-            assert_equal(ch['unit'], FIFF.FIFF_UNIT_CEL)
+            assert ch['kind'] == FIFF.FIFFV_MISC_CH
+            assert ch['unit'] == FIFF.FIFF_UNIT_CEL
         elif ch['ch_name'] in raw_py.info['ch_names']:
-            assert_equal(ch['kind'], FIFF.FIFFV_EEG_CH)
-            assert_equal(ch['unit'], FIFF.FIFF_UNIT_V)
+            assert ch['kind'] == FIFF.FIFFV_EEG_CH
+            assert ch['unit'] == FIFF.FIFF_UNIT_V
         else:
             raise RuntimeError("Unknown Channel: %s" % ch['ch_name'])
 
@@ -546,20 +535,20 @@ def test_brainvision_data():
     raw_units = _test_raw_reader(
         read_raw_brainvision, vhdr_fname=vhdr_units_path, eog=eog, misc='auto'
     )
-    assert_equal(raw_units.info['chs'][0]['ch_name'], 'FP1')
-    assert_equal(raw_units.info['chs'][0]['kind'], FIFF.FIFFV_EEG_CH)
+    assert raw_units.info['chs'][0]['ch_name'] == 'FP1'
+    assert raw_units.info['chs'][0]['kind'] == FIFF.FIFFV_EEG_CH
     data_units, _ = raw_units[0]
-    assert_array_almost_equal(data_py[0, :], data_units.squeeze())
+    assert_allclose(data_py[0, :], data_units.squeeze())
 
-    assert_equal(raw_units.info['chs'][1]['ch_name'], 'FP2')
-    assert_equal(raw_units.info['chs'][1]['kind'], FIFF.FIFFV_EEG_CH)
+    assert raw_units.info['chs'][1]['ch_name'] == 'FP2'
+    assert raw_units.info['chs'][1]['kind'] == FIFF.FIFFV_EEG_CH
     data_units, _ = raw_units[1]
-    assert_array_almost_equal(data_py[1, :], data_units.squeeze())
+    assert_allclose(data_py[1, :], data_units.squeeze())
 
-    assert_equal(raw_units.info['chs'][2]['ch_name'], 'F3')
-    assert_equal(raw_units.info['chs'][2]['kind'], FIFF.FIFFV_EEG_CH)
+    assert raw_units.info['chs'][2]['ch_name'] == 'F3'
+    assert raw_units.info['chs'][2]['kind'] == FIFF.FIFFV_EEG_CH
     data_units, _ = raw_units[2]
-    assert_array_almost_equal(data_py[2, :], data_units.squeeze())
+    assert_allclose(data_py[2, :], data_units.squeeze())
 
 
 def test_brainvision_vectorized_data():
@@ -600,7 +589,7 @@ def test_brainvision_vectorized_data():
         [-7.35999985e-06, -7.18000031e-06],
     ])
 
-    assert_array_almost_equal(raw._data[:, :2], first_two_samples_all_chs)
+    assert_allclose(raw._data[:, :2], first_two_samples_all_chs)
 
 
 def test_coodinates_extraction():
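The rewritten assertions above collect recorded warnings into plain strings and assert on them directly, which produces readable failure messages (the list itself is printed on failure). A runnable sketch of the pattern, with a hypothetical warning text:

import warnings

import pytest


def emit():
    warnings.warn('different highpass filters', RuntimeWarning)


def test_warning_messages():
    with pytest.warns(RuntimeWarning, match='different .*pass filters') as w:
        emit()
    msgs = [str(ww.message) for ww in w]
    assert not any('different lowpass filters' in m for m in msgs), msgs
    assert all('different highpass filters' in m for m in msgs), msgs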
diff --git a/mne/io/cnt/cnt.py b/mne/io/cnt/cnt.py
index 5a4bdffc4a5..f7ee029184b 100644
--- a/mne/io/cnt/cnt.py
+++ b/mne/io/cnt/cnt.py
@@ -203,32 +203,32 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc, data_format, date_format):
         meas_date = _session_date_2_meas_date(session_date, date_format)
 
     fid.seek(370)
-    n_channels = np.fromfile(fid, dtype='<u2', count=1)[0]
+    n_channels = np.fromfile(fid, dtype='<u2', count=1).item()
     fid.seek(376)
-    sfreq = np.fromfile(fid, dtype='<u2', count=1)[0]
+    sfreq = np.fromfile(fid, dtype='<u2', count=1).item()
     if eog == 'header':
         fid.seek(402)
         eog = [idx for idx in np.fromfile(fid, dtype='i2', count=2)
               if idx >= 0]
     fid.seek(438)
-    lowpass_toggle = np.fromfile(fid, 'i1', count=1)[0]
-    highpass_toggle = np.fromfile(fid, 'i1', count=1)[0]
+    lowpass_toggle = np.fromfile(fid, 'i1', count=1).item()
+    highpass_toggle = np.fromfile(fid, 'i1', count=1).item()
 
     # Header has a field for number of samples, but it does not seem to be
     # too reliable. That's why we have option for setting n_bytes manually.
     fid.seek(864)
-    n_samples = np.fromfile(fid, dtype='<i4', count=1)[0]
+    n_samples = np.fromfile(fid, dtype='<i4', count=1).item()
     if cnt_info['channel_offset'] > 1:
         cnt_info['channel_offset'] //= n_bytes
     else:
@@ -266,7 +267,7 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc, data_format, date_format):
             ch_name = read_str(fid, 10)
             ch_names.append(ch_name)
             fid.seek(data_offset + 75 * ch_idx + 4)
-            if np.fromfile(fid, dtype='u1', count=1)[0]:
+            if np.fromfile(fid, dtype='u1', count=1).item():
                 bads.append(ch_name)
             fid.seek(data_offset + 75 * ch_idx + 19)
             xy = np.fromfile(fid, dtype='f4', count=2)
@@ -274,11 +275,11 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc, data_format, date_format):
             pos.append(xy)
             fid.seek(data_offset + 75 * ch_idx + 47)
             # Baselines are subtracted before scaling the data.
-            baselines.append(np.fromfile(fid, dtype='i2', count=1)[0])
+            baselines.append(np.fromfile(fid, dtype='i2', count=1).item())
             fid.seek(data_offset + 75 * ch_idx + 59)
-            sensitivity = np.fromfile(fid, dtype='f4', count=1)[0]
+            sensitivity = np.fromfile(fid, dtype='f4', count=1).item()
             fid.seek(data_offset + 75 * ch_idx + 71)
-            cal = np.fromfile(fid, dtype='f4', count=1)[0]
+            cal = np.fromfile(fid, dtype='f4', count=1).item()
             cals.append(cal * sensitivity * 1e-6 / 204.8)
 
     info = _empty_info(sfreq)
diff --git a/mne/io/ctf/info.py b/mne/io/ctf/info.py
index 587ca8bd85f..0afbe3e2836 100644
--- a/mne/io/ctf/info.py
+++ b/mne/io/ctf/info.py
@@ -298,8 +298,7 @@ def _conv_comp(comp, first, last, chs):
     col_names = np.array(col_names)[mask].tolist()
     n_col = len(col_names)
     n_row = len(row_names)
-    ccomp = dict(ctfkind=np.array([comp[first]['coeff_type']]),
-                 save_calibrated=False)
+    ccomp = dict(ctfkind=comp[first]['coeff_type'], save_calibrated=False)
     _add_kind(ccomp)
 
     data = np.empty((n_row, n_col))
diff --git a/mne/io/ctf/res4.py b/mne/io/ctf/res4.py
index 8da208e18a5..7b0a4e2b9e6 100644
--- a/mne/io/ctf/res4.py
+++ b/mne/io/ctf/res4.py
@@ -88,8 +88,8 @@ def _read_comp_coeff(fid, d):
         d['comp'].append(comp)
         comp['sensor_name'] = \
             comps['sensor_name'][k].split(b'\x00')[0].decode('utf-8')
-        comp['coeff_type'] = comps['coeff_type'][k]
-        comp['ncoeff'] = comps['ncoeff'][k]
+        comp['coeff_type'] = comps['coeff_type'][k].item()
+        comp['ncoeff'] = comps['ncoeff'][k].item()
         comp['sensors'] = [s.split(b'\x00')[0].decode('utf-8')
                            for s in comps['sensors'][k][:comp['ncoeff']]]
         comp['coeffs'] = comps['coeffs'][k][:comp['ncoeff']]
diff --git a/mne/io/ctf_comp.py b/mne/io/ctf_comp.py
index 61fca9117f7..04198e45c58 100644
--- a/mne/io/ctf_comp.py
+++ b/mne/io/ctf_comp.py
@@ -116,7 +116,7 @@ def _read_ctf_comp(fid, node, chs, ch_names_mapping):
             raise Exception('Compensation type not found')
 
         # Get the compensation kind and map it to a simple number
-        one = dict(ctfkind=tag.data)
+        one = dict(ctfkind=tag.data.item())
         del tag
         _add_kind(one)
         for p in range(node['nent']):
diff --git a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py
index aee9310e86b..1c745d8b10e 100644
--- a/mne/io/egi/egimff.py
+++ b/mne/io/egi/egimff.py
@@ -519,6 +519,7 @@ def __init__(self, input_fname, eog=None, misc=None,
             ref_idx = np.flatnonzero(np.in1d(mon.ch_names, REFERENCE_NAMES))
             if len(ref_idx):
+                ref_idx = ref_idx.item()
                 ref_coords = info['chs'][int(ref_idx)]['loc'][:3]
                 for chan in info['chs']:
                     is_eeg = chan['kind'] == FIFF.FIFFV_EEG_CH
diff --git a/mne/io/egi/general.py b/mne/io/egi/general.py
index a1de880efc6..c364e0eb9c7 100644
--- a/mne/io/egi/general.py
+++ b/mne/io/egi/general.py
@@ -150,12 +150,12 @@ def _get_signalfname(filepath):
 
 def _block_r(fid):
     """Read meta data."""
-    if np.fromfile(fid, dtype=np.dtype('i4'), count=1)[0] != 1:  # not metadata
+    if np.fromfile(fid, dtype=np.dtype('i4'), count=1).item() != 1:  # not meta
         return None
-    header_size = np.fromfile(fid, dtype=np.dtype('i4'), count=1)[0]
-    block_size = np.fromfile(fid, dtype=np.dtype('i4'), count=1)[0]
+    header_size = np.fromfile(fid, dtype=np.dtype('i4'), count=1).item()
+    block_size = np.fromfile(fid, dtype=np.dtype('i4'), count=1).item()
     hl = int(block_size / 4)
-    nc = np.fromfile(fid, dtype=np.dtype('i4'), count=1)[0]
+    nc = np.fromfile(fid, dtype=np.dtype('i4'), count=1).item()
     nsamples = int(hl / nc)
     np.fromfile(fid, dtype=np.dtype('i4'), count=nc)  # sigoffset
     sigfreq = np.fromfile(fid, dtype=np.dtype('i4'), count=nc)
diff --git a/mne/io/fiff/raw.py b/mne/io/fiff/raw.py
index 1a45af0c38e..b71154a0208 100644
--- a/mne/io/fiff/raw.py
+++ b/mne/io/fiff/raw.py
@@ -187,7 +187,7 @@ def _read_raw_file(self, fname, allow_maxshield, preload,
             # Get first sample tag if it is there
             if directory[first].kind == FIFF.FIFF_FIRST_SAMPLE:
                 tag = read_tag(fid, directory[first].pos)
-                first_samp = int(tag.data)
+                first_samp = int(tag.data.item())
                 first += 1
                 _check_entry(first, nent)
 
@@ -195,7 +195,7 @@ def _read_raw_file(self, fname, allow_maxshield, preload,
             if directory[first].kind == FIFF.FIFF_DATA_SKIP:
                 # This first skip can be applied only after we know the bufsize
                 tag = read_tag(fid, directory[first].pos)
-                first_skip = int(tag.data)
+                first_skip = int(tag.data.item())
                 first += 1
                 _check_entry(first, nent)
 
@@ -220,7 +220,7 @@ def _read_raw_file(self, fname, allow_maxshield, preload,
                 #  an re-clicked the button
                 if ent.kind == FIFF.FIFF_DATA_SKIP:
                     tag = read_tag(fid, ent.pos)
-                    nskip = int(tag.data)
+                    nskip = int(tag.data.item())
                 elif ent.kind == FIFF.FIFF_DATA_BUFFER:
                     #   Figure out the number of samples in this buffer
                     if ent.type == FIFF.FIFFT_DAU_PACK16:
diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py
index 4e913cd587b..25e2ad1792a 100644
--- a/mne/io/fiff/tests/test_raw_fiff.py
+++ b/mne/io/fiff/tests/test_raw_fiff.py
@@ -1873,7 +1873,7 @@ def test_corrupted(tmp_path):
     with open(skip_fname, 'rb') as fid:
         tag = read_tag_info(fid)
         tag = read_tag(fid)
-        dirpos = int(tag.data)
+        dirpos = int(tag.data.item())
         assert dirpos == 12641532
         fid.seek(0)
         data = fid.read(dirpos)
diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py
index bc22187b9a3..9bc63b38052 100644
--- a/mne/io/meas_info.py
+++ b/mne/io/meas_info.py
@@ -23,7 +23,7 @@ from .open import fiff_open
 from .tree import dir_tree_find
 from .tag import (read_tag, find_tag, _ch_coord_dict, _update_ch_info_named,
-                  _rename_list)
+                  _rename_list, _int_item, _float_item)
 from .proj import (_read_proj, _write_proj, _uniquify_projs, _normalize_proj,
                    _proj_equal, Projection)
 from .ctf_comp import _read_ctf_comp, write_ctf_comp
@@ -1449,21 +1449,21 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
         pos = meas_info['directory'][k].pos
         if kind == FIFF.FIFF_NCHAN:
             tag = read_tag(fid, pos)
-            nchan = int(tag.data)
+            nchan = int(tag.data.item())
         elif kind == FIFF.FIFF_SFREQ:
             tag = read_tag(fid, pos)
-            sfreq = float(tag.data)
+            sfreq = float(tag.data.item())
        elif kind == FIFF.FIFF_CH_INFO:
             tag = read_tag(fid, pos)
             chs.append(tag.data)
         elif kind == FIFF.FIFF_LOWPASS:
             tag = read_tag(fid, pos)
-            if not np.isnan(tag.data):
-                lowpass = float(tag.data)
+            if not np.isnan(tag.data.item()):
+                lowpass = float(tag.data.item())
         elif kind == FIFF.FIFF_HIGHPASS:
             tag = read_tag(fid, pos)
             if not np.isnan(tag.data):
-                highpass = float(tag.data)
+                highpass = float(tag.data.item())
         elif kind == FIFF.FIFF_MEAS_DATE:
             tag = read_tag(fid, pos)
             meas_date = tuple(tag.data)
@@ -1503,19 +1503,19 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
             proj_name = tag.data
         elif kind == FIFF.FIFF_LINE_FREQ:
             tag = read_tag(fid, pos)
-            line_freq = float(tag.data)
+            line_freq = float(tag.data.item())
         elif kind == FIFF.FIFF_GANTRY_ANGLE:
             tag = read_tag(fid, pos)
-            gantry_angle = float(tag.data)
+            gantry_angle = float(tag.data.item())
         elif kind in [FIFF.FIFF_MNE_CUSTOM_REF, 236]:  # 236 used before v0.11
             tag = read_tag(fid, pos)
-            custom_ref_applied = int(tag.data)
+            custom_ref_applied = int(tag.data.item())
         elif kind == FIFF.FIFF_XPLOTTER_LAYOUT:
             tag = read_tag(fid, pos)
             xplotter_layout = str(tag.data)
         elif kind == FIFF.FIFF_MNE_KIT_SYSTEM_ID:
             tag = read_tag(fid, pos)
-            kit_system_id = int(tag.data)
+            kit_system_id = int(tag.data.item())
 
     ch_names_mapping = _read_extended_ch_info(chs, meas_info, fid)
 
     # Check that we have everything we need
@@ -1622,11 +1622,11 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
                 elif kind == FIFF.FIFF_HPI_FIT_GOODNESS:
                     hr['goodness'] = read_tag(fid, pos).data
                 elif kind == FIFF.FIFF_HPI_FIT_GOOD_LIMIT:
-                    hr['good_limit'] = float(read_tag(fid, pos).data)
+                    hr['good_limit'] = float(read_tag(fid, pos).data.item())
                 elif kind == FIFF.FIFF_HPI_FIT_DIST_LIMIT:
-                    hr['dist_limit'] = float(read_tag(fid, pos).data)
+                    hr['dist_limit'] = float(read_tag(fid, pos).data.item())
                 elif kind == FIFF.FIFF_HPI_FIT_ACCEPT:
-                    hr['accept'] = int(read_tag(fid, pos).data)
+                    hr['accept'] = int(read_tag(fid, pos).data.item())
                 elif kind == FIFF.FIFF_COORD_TRANS:
                     hr['coord_trans'] = read_tag(fid, pos).data
             hrs.append(hr)
@@ -1643,17 +1643,17 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
             if kind == FIFF.FIFF_CREATOR:
                 hm['creator'] = str(read_tag(fid, pos).data)
             elif kind == FIFF.FIFF_SFREQ:
-                hm['sfreq'] = float(read_tag(fid, pos).data)
+                hm['sfreq'] = float(read_tag(fid, pos).data.item())
             elif kind == FIFF.FIFF_NCHAN:
-                hm['nchan'] = int(read_tag(fid, pos).data)
+                hm['nchan'] = int(read_tag(fid, pos).data.item())
             elif kind == FIFF.FIFF_NAVE:
-                hm['nave'] = int(read_tag(fid, pos).data)
+                hm['nave'] = int(read_tag(fid, pos).data.item())
             elif kind == FIFF.FIFF_HPI_NCOIL:
-                hm['ncoil'] = int(read_tag(fid, pos).data)
+                hm['ncoil'] = int(read_tag(fid, pos).data.item())
             elif kind == FIFF.FIFF_FIRST_SAMPLE:
-                hm['first_samp'] = int(read_tag(fid, pos).data)
+                hm['first_samp'] = int(read_tag(fid, pos).data.item())
             elif kind == FIFF.FIFF_LAST_SAMPLE:
-                hm['last_samp'] = int(read_tag(fid, pos).data)
+                hm['last_samp'] = int(read_tag(fid, pos).data.item())
         hpi_coils = dir_tree_find(hpi_meas, FIFF.FIFFB_HPI_COIL)
         hcs = []
         for hpi_coil in hpi_coils:
@@ -1662,7 +1662,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
                 kind = hpi_coil['directory'][k].kind
                 pos = hpi_coil['directory'][k].pos
                 if kind == FIFF.FIFF_HPI_COIL_NO:
-                    hc['number'] = int(read_tag(fid, pos).data)
+                    hc['number'] = int(read_tag(fid, pos).data.item())
                 elif kind == FIFF.FIFF_EPOCH:
                     hc['epoch'] = read_tag(fid, pos).data
                     hc['epoch'].flags.writeable = False
@@ -1673,7 +1673,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
                     hc['corr_coeff'] = read_tag(fid, pos).data
                     hc['corr_coeff'].flags.writeable = False
                 elif kind == FIFF.FIFF_HPI_COIL_FREQ:
-                    hc['coil_freq'] = float(read_tag(fid, pos).data)
+                    hc['coil_freq'] = float(read_tag(fid, pos).data.item())
             hcs.append(hc)
         hm['hpi_coils'] = hcs
         hms.append(hm)
@@ -1690,7 +1690,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
             pos = subject_info['directory'][k].pos
             if kind == FIFF.FIFF_SUBJ_ID:
                 tag = read_tag(fid, pos)
-                si['id'] = int(tag.data)
+                si['id'] = int(tag.data.item())
             elif kind == FIFF.FIFF_SUBJ_HIS_ID:
                 tag = read_tag(fid, pos)
                 si['his_id'] = str(tag.data)
@@ -1715,10 +1715,10 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
                 si['birthday'] = tag.data
             elif kind == FIFF.FIFF_SUBJ_SEX:
                 tag = read_tag(fid, pos)
-                si['sex'] = int(tag.data)
+                si['sex'] = int(tag.data.item())
             elif kind == FIFF.FIFF_SUBJ_HAND:
                 tag = read_tag(fid, pos)
-                si['hand'] = int(tag.data)
+                si['hand'] = int(tag.data.item())
             elif kind == FIFF.FIFF_SUBJ_WEIGHT:
                 tag = read_tag(fid, pos)
                 si['weight'] = tag.data
@@ -1761,10 +1761,10 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
             pos = helium_info['directory'][k].pos
             if kind == FIFF.FIFF_HE_LEVEL_RAW:
                 tag = read_tag(fid, pos)
-                hi['he_level_raw'] = float(tag.data)
+                hi['he_level_raw'] = float(tag.data.item())
             elif kind == FIFF.FIFF_HELIUM_LEVEL:
                 tag = read_tag(fid, pos)
-                hi['helium_level'] = float(tag.data)
+                hi['helium_level'] = float(tag.data.item())
             elif kind == FIFF.FIFF_ORIG_FILE_GUID:
                 tag = read_tag(fid, pos)
                 hi['orig_file_guid'] = str(tag.data)
@@ -1784,7 +1784,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
             pos = hpi_subsystem['directory'][k].pos
             if kind == FIFF.FIFF_HPI_NCOIL:
                 tag = read_tag(fid, pos)
-                hs['ncoil'] = int(tag.data)
+                hs['ncoil'] = int(tag.data.item())
             elif kind == FIFF.FIFF_EVENT_CHANNEL:
                 tag = read_tag(fid, pos)
                 hs['event_channel'] = str(tag.data)
@@ -2794,17 +2794,17 @@ def _bad_chans_comp(info, ch_names):
     kind=int, ident=int, r=lambda x: x, coord_frame=int)
 # key -> const, cast, write
 _CH_INFO_MAP = OrderedDict(
-    scanno=(FIFF.FIFF_CH_SCAN_NO, int, write_int),
-    logno=(FIFF.FIFF_CH_LOGICAL_NO, int, write_int),
-    kind=(FIFF.FIFF_CH_KIND, int, write_int),
-    range=(FIFF.FIFF_CH_RANGE, float, write_float),
-    cal=(FIFF.FIFF_CH_CAL, float, write_float),
-    coil_type=(FIFF.FIFF_CH_COIL_TYPE, int, write_int),
+    scanno=(FIFF.FIFF_CH_SCAN_NO, _int_item, write_int),
+    logno=(FIFF.FIFF_CH_LOGICAL_NO, _int_item, write_int),
+    kind=(FIFF.FIFF_CH_KIND, _int_item, write_int),
+    range=(FIFF.FIFF_CH_RANGE, _float_item, write_float),
+    cal=(FIFF.FIFF_CH_CAL, _float_item, write_float),
+    coil_type=(FIFF.FIFF_CH_COIL_TYPE, _int_item, write_int),
     loc=(FIFF.FIFF_CH_LOC, lambda x: x, write_float),
-    unit=(FIFF.FIFF_CH_UNIT, int, write_int),
-    unit_mul=(FIFF.FIFF_CH_UNIT_MUL, int, write_int),
+    unit=(FIFF.FIFF_CH_UNIT, _int_item, write_int),
+    unit_mul=(FIFF.FIFF_CH_UNIT_MUL, _int_item, write_int),
     ch_name=(FIFF.FIFF_CH_DACQ_NAME, str, write_string),
-    coord_frame=(FIFF.FIFF_CH_COORD_FRAME, int, write_int),
+    coord_frame=(FIFF.FIFF_CH_COORD_FRAME, _int_item, write_int),
 )
 # key -> cast
 _CH_CAST = OrderedDict((key, val[1]) for key, val in _CH_INFO_MAP.items())
diff --git a/mne/io/open.py b/mne/io/open.py
index d2c94accd53..e3b83c31fb1 100644
--- a/mne/io/open.py
+++ b/mne/io/open.py
@@ -62,7 +62,7 @@ def _get_next_fname(fid, fname, tree):
         for ent in nodes['directory']:
             if ent.kind == FIFF.FIFF_REF_ROLE:
                 tag = read_tag(fid, ent.pos)
-                role = int(tag.data)
+                role = int(tag.data.item())
                 if role != FIFF.FIFFV_ROLE_NEXT_FILE:
                     next_fname = None
                     break
@@ -74,7 +74,7 @@ def _get_next_fname(fid, fname, tree):
             # we construct the name from the current name.
             if next_fname is not None:
                 continue
-            next_num = read_tag(fid, ent.pos).data
+            next_num = read_tag(fid, ent.pos).data.item()
             path, base = op.split(fname)
             idx = base.find('.')
             idx2 = base.rfind('-')
@@ -157,7 +157,7 @@ def _fiff_open(fname, fid, preload):
     # Read or create the directory tree
     logger.debug('    Creating tag directory for %s...'
                  % fname)
 
-    dirpos = int(tag.data)
+    dirpos = int(tag.data.item())
     read_slow = True
     if dirpos > 0:
         dir_tag = read_tag(fid, dirpos)
diff --git a/mne/io/proc_history.py b/mne/io/proc_history.py
index 7209f9f7ca7..290730a2aeb 100644
--- a/mne/io/proc_history.py
+++ b/mne/io/proc_history.py
@@ -10,7 +10,7 @@
                        write_string, write_float_matrix, write_int_matrix,
                        write_float_sparse, write_id, write_name_list_sanitized,
                        _safe_name_list)
-from .tag import find_tag
+from .tag import find_tag, _int_item, _float_item
 from .constants import FIFF
 from ..fixes import _csc_matrix_cast
 from ..utils import warn, _check_fname
@@ -137,15 +137,15 @@ def _write_proc_history(fid, info):
 _sss_info_writers = (write_int, write_int, write_float, write_int,
                      write_int, write_int, write_int, write_int,
                      write_float, write_float)
-_sss_info_casters = (int, int, np.array, int,
-                     int, int, np.array, int,
-                     float, float)
+_sss_info_casters = (_int_item, _int_item, np.array, _int_item,
+                     _int_item, _int_item, np.array, _int_item,
+                     _float_item, _float_item)
 
 _max_st_keys = ('job', 'subspcorr', 'buflen')
 _max_st_ids = (FIFF.FIFF_SSS_JOB, FIFF.FIFF_SSS_ST_CORR,
                FIFF.FIFF_SSS_ST_LENGTH)
 _max_st_writers = (write_int, write_float, write_float)
-_max_st_casters = (int, float, float)
+_max_st_casters = (_int_item, _float_item, _float_item)
 
 _sss_ctc_keys = ('block_id', 'date', 'creator', 'decoupler')
 _sss_ctc_ids = (FIFF.FIFF_BLOCK_ID,
diff --git a/mne/io/proj.py b/mne/io/proj.py
index 7774a8edb02..e874c47ebd5 100644
--- a/mne/io/proj.py
+++ b/mne/io/proj.py
@@ -461,7 +461,7 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None):
     #   global_nchan = None
     #   tag = find_tag(fid, nodes[0], FIFF.FIFF_NCHAN)
     #   if tag is not None:
-    #       global_nchan = int(tag.data)
+    #       global_nchan = int(tag.data.item())
 
     items = dir_tree_find(nodes[0], FIFF.FIFFB_PROJ_ITEM)
     for item in items:
@@ -471,7 +471,7 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None):
         # sometimes
         #   tag = find_tag(fid, item, FIFF.FIFF_NCHAN)
         #   if tag is not None:
-        #       nchan = int(tag.data)
+        #       nchan = int(tag.data.item())
         #   else:
         #       nchan = global_nchan
 
@@ -487,13 +487,13 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None):
 
         tag = find_tag(fid, item, FIFF.FIFF_PROJ_ITEM_KIND)
         if tag is not None:
-            kind = int(tag.data)
+            kind = int(tag.data.item())
         else:
             raise ValueError('Projection item kind missing')
 
         tag = find_tag(fid, item, FIFF.FIFF_PROJ_ITEM_NVEC)
         if tag is not None:
-            nvec = int(tag.data)
+            nvec = int(tag.data.item())
         else:
             raise ValueError('Number of projection vectors not specified')
 
@@ -511,13 +511,13 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None):
 
         tag = find_tag(fid, item, FIFF.FIFF_MNE_PROJ_ITEM_ACTIVE)
         if tag is not None:
-            active = bool(tag.data)
+            active = bool(tag.data.item())
         else:
             active = False
 
         tag = find_tag(fid, item, FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR)
         if tag is not None:
-            explained_var = float(tag.data)
+            explained_var = float(tag.data.item())
         else:
             explained_var = None
diff --git a/mne/io/snirf/_snirf.py b/mne/io/snirf/_snirf.py
index 93d024a4a75..1b1b20fd531 100644
--- a/mne/io/snirf/_snirf.py
+++ b/mne/io/snirf/_snirf.py
@@ -473,7 +473,7 @@ def _extract_sampling_rate(dat):
     fs_diff = np.around(np.diff(time_data), decimals=4)
     if len(np.unique(fs_diff)) == 1:
         # Uniformly sampled data
-        sampling_rate = 1. / np.unique(fs_diff)
+        sampling_rate = 1. / np.unique(fs_diff).item()
     else:
         warn("MNE does not currently support reading "
              "SNIRF files with non-uniform sampled data.")
diff --git a/mne/io/tag.py b/mne/io/tag.py
index 6d4b5df2ee4..21077701192 100644
--- a/mne/io/tag.py
+++ b/mne/io/tag.py
@@ -181,7 +181,7 @@ def _read_matrix(fid, tag, shape, rlims, matrix_coding):
         # Find dimensions and return to the beginning of tag data
         pos = fid.tell()
         fid.seek(tag.size - 4, 1)
-        ndim = int(np.frombuffer(fid.read(4), dtype='>i4'))
+        ndim = int(np.frombuffer(fid.read(4), dtype='>i4').item())
         fid.seek(-(ndim + 1) * 4, 1)
         dims = np.frombuffer(fid.read(4 * ndim), dtype='>i4')[::-1]
         #
@@ -211,7 +211,7 @@ def _read_matrix(fid, tag, shape, rlims, matrix_coding):
         # Find dimensions and return to the beginning of tag data
         pos = fid.tell()
         fid.seek(tag.size - 4, 1)
-        ndim = int(np.frombuffer(fid.read(4), dtype='>i4'))
+        ndim = int(np.frombuffer(fid.read(4), dtype='>i4').item())
         fid.seek(-(ndim + 2) * 4, 1)
         dims = np.frombuffer(fid.read(4 * (ndim + 1)), dtype='>i4')
         if ndim != 2:
@@ -296,17 +296,17 @@ def _read_complex_double(fid, tag, shape, rlims):
 def _read_id_struct(fid, tag, shape, rlims):
     """Read ID struct tag."""
     return dict(
-        version=int(np.frombuffer(fid.read(4), dtype=">i4")),
+        version=int(np.frombuffer(fid.read(4), dtype=">i4").item()),
         machid=np.frombuffer(fid.read(8), dtype=">i4"),
-        secs=int(np.frombuffer(fid.read(4), dtype=">i4")),
-        usecs=int(np.frombuffer(fid.read(4), dtype=">i4")))
+        secs=int(np.frombuffer(fid.read(4), dtype=">i4").item()),
+        usecs=int(np.frombuffer(fid.read(4), dtype=">i4").item()))
 
 
 def _read_dig_point_struct(fid, tag, shape, rlims):
     """Read dig point struct tag."""
-    kind = int(np.frombuffer(fid.read(4), dtype=">i4"))
+    kind = int(np.frombuffer(fid.read(4), dtype=">i4").item())
     kind = _dig_kind_named.get(kind, kind)
-    ident = int(np.frombuffer(fid.read(4), dtype=">i4"))
+    ident = int(np.frombuffer(fid.read(4), dtype=">i4").item())
     if kind == FIFF.FIFFV_POINT_CARDINAL:
         ident = _dig_cardinal_named.get(ident, ident)
     return dict(
@@ -318,8 +318,8 @@ def _read_dig_point_struct(fid, tag, shape, rlims):
 def _read_coord_trans_struct(fid, tag, shape, rlims):
     """Read coord trans struct tag."""
     from ..transforms import Transform
-    fro = int(np.frombuffer(fid.read(4), dtype=">i4"))
-    to = int(np.frombuffer(fid.read(4), dtype=">i4"))
+    fro = int(np.frombuffer(fid.read(4), dtype=">i4").item())
+    to = int(np.frombuffer(fid.read(4), dtype=">i4").item())
     rot = np.frombuffer(fid.read(36), dtype=">f4").reshape(3, 3)
     move = np.frombuffer(fid.read(12), dtype=">f4")
     trans = np.r_[np.c_[rot, move],
@@ -343,17 +343,17 @@ def _read_coord_trans_struct(fid, tag, shape, rlims):
 def _read_ch_info_struct(fid, tag, shape, rlims):
     """Read channel info struct tag."""
     d = dict(
-        scanno=int(np.frombuffer(fid.read(4), dtype=">i4")),
-        logno=int(np.frombuffer(fid.read(4), dtype=">i4")),
-        kind=int(np.frombuffer(fid.read(4), dtype=">i4")),
-        range=float(np.frombuffer(fid.read(4), dtype=">f4")),
-        cal=float(np.frombuffer(fid.read(4), dtype=">f4")),
-        coil_type=int(np.frombuffer(fid.read(4), dtype=">i4")),
+        scanno=int(np.frombuffer(fid.read(4), dtype=">i4").item()),
+        logno=int(np.frombuffer(fid.read(4), dtype=">i4").item()),
+        kind=int(np.frombuffer(fid.read(4), dtype=">i4").item()),
+        range=float(np.frombuffer(fid.read(4), dtype=">f4").item()),
+        cal=float(np.frombuffer(fid.read(4), dtype=">f4").item()),
+        coil_type=int(np.frombuffer(fid.read(4), dtype=">i4").item()),
         # deal with really old OSX Anaconda bug by casting to float64
         loc=np.frombuffer(fid.read(48), dtype=">f4").astype(np.float64),
         # unit and exponent
-        unit=int(np.frombuffer(fid.read(4), dtype=">i4")),
-        unit_mul=int(np.frombuffer(fid.read(4), dtype=">i4")),
+        unit=int(np.frombuffer(fid.read(4), dtype=">i4").item()),
+        unit_mul=int(np.frombuffer(fid.read(4), dtype=">i4").item()),
     )
     # channel name
     ch_name = np.frombuffer(fid.read(16), dtype=">c")
@@ -374,8 +374,8 @@ def _update_ch_info_named(d):
 
 def _read_old_pack(fid, tag, shape, rlims):
     """Read old pack tag."""
-    offset = float(np.frombuffer(fid.read(4), dtype=">f4"))
-    scale = float(np.frombuffer(fid.read(4), dtype=">f4"))
+    offset = float(np.frombuffer(fid.read(4), dtype=">f4").item())
+    scale = float(np.frombuffer(fid.read(4), dtype=">f4").item())
     data = np.frombuffer(fid.read(tag.size - 8), dtype=">i2")
     data = data * scale  # to float64
     data += offset
@@ -389,7 +389,7 @@ def _read_dir_entry_struct(fid, tag, shape, rlims):
 
 def _read_julian(fid, tag, shape, rlims):
     """Read julian tag."""
-    return _julian_to_cal(int(np.frombuffer(fid.read(4), dtype=">i4")))
+    return _julian_to_cal(int(np.frombuffer(fid.read(4), dtype=">i4").item()))
 
 
 # Read types call dict
@@ -515,3 +515,11 @@ def has_tag(node, kind):
 
 def _rename_list(bads, ch_names_mapping):
     return [ch_names_mapping.get(bad, bad) for bad in bads]
+
+
+def _int_item(x):
+    return int(x.item())
+
+
+def _float_item(x):
+    return float(x.item())
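`_int_item` and `_float_item` bundle the `.item()` step into reusable casters so that tables like `_CH_INFO_MAP` above can keep a single callable per field. A self-contained sketch of the round trip they support (the byte layout here is illustrative):

import numpy as np


def _int_item(x):
    return int(x.item())


buf = np.array([3], dtype='>i4').tobytes()   # one big-endian int32
kind = _int_item(np.frombuffer(buf, dtype='>i4'))
assert kind == 3 and isinstance(kind, int)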
diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py
index 4b16c18a6a0..694cd46c941 100644
--- a/mne/io/tests/test_raw.py
+++ b/mne/io/tests/test_raw.py
@@ -20,6 +20,7 @@ import mne
 from mne import concatenate_raws, create_info, Annotations, pick_types
 from mne.datasets import testing
+from mne.fixes import _numpy_h5py_dep
 from mne.io import read_raw_fif, RawArray, BaseRaw, Info, _writing_info_hdf5
 from mne.io._digitization import _dig_kind_dict
 from mne.io.base import _get_scaling
@@ -373,9 +374,9 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True,
     if check_version('h5io'):
         read_hdf5, write_hdf5 = _import_h5io_funcs()
         fname_h5 = op.join(tempdir, 'info.h5')
-        with _writing_info_hdf5(raw.info):
+        with _writing_info_hdf5(raw.info), _numpy_h5py_dep():
             write_hdf5(fname_h5, raw.info)
-        new_info = Info(read_hdf5(fname_h5))
+            new_info = Info(read_hdf5(fname_h5))
         assert object_diff(new_info, raw.info) == ''
 
     # Make sure that changing directory does not break anything
diff --git a/mne/io/tree.py b/mne/io/tree.py
index 16293df4152..b4ed4ee1c7b 100644
--- a/mne/io/tree.py
+++ b/mne/io/tree.py
@@ -54,7 +54,7 @@ def make_dir_tree(fid, directory, start=0, indent=0, verbose=None):
 
     if directory[start].kind == FIFF_BLOCK_START:
         tag = read_tag(fid, directory[start].pos)
-        block = tag.data
+        block = tag.data.item()
     else:
         block = 0
 
diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py
index 42d58d0173a..174e3c46f28 100644
--- a/mne/minimum_norm/inverse.py
+++ b/mne/minimum_norm/inverse.py
@@ -182,19 +182,19 @@ def read_inverse_operator(fname, *, verbose=None):
             raise Exception('Modalities not found')
 
         inv = dict()
-        inv['methods'] = int(tag.data)
+        inv['methods'] = int(tag.data.item())
 
         tag = find_tag(fid, invs, FIFF.FIFF_MNE_SOURCE_ORIENTATION)
         if tag is None:
             raise Exception('Source orientation constraints not found')
 
-        inv['source_ori'] = int(tag.data)
+        inv['source_ori'] = int(tag.data.item())
 
         tag = find_tag(fid, invs, FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS)
         if tag is None:
             raise Exception('Number of sources not found')
 
-        inv['nsource'] = int(tag.data)
+        inv['nsource'] = int(tag.data.item())
         inv['nchan'] = 0
         #
         #   Coordinate frame
@@ -212,7 +212,10 @@ def read_inverse_operator(fname, *, verbose=None):
         unit_dict = {FIFF.FIFF_UNIT_AM: 'Am',
                      FIFF.FIFF_UNIT_AM_M2: 'Am/m^2',
                      FIFF.FIFF_UNIT_AM_M3: 'Am/m^3'}
-        inv['units'] = unit_dict.get(int(getattr(tag, 'data', -1)), None)
+        inv['units'] = unit_dict.get(
+            int(
+                getattr(tag, 'data', np.array([-1])).item()
+            ), None)
 
         #
         #   The actual source orientation vectors
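The `np.array([-1])` default above keeps both branches of `getattr` the same shape, so `.item()` is valid whether or not the tag carried data. A sketch of the idea (the `Tag` class here is hypothetical):

import numpy as np


class Tag:
    pass  # may or may not have a `.data` array


unit_dict = {1: 'Am'}
tag = Tag()
units = unit_dict.get(int(getattr(tag, 'data', np.array([-1])).item()), None)
assert units is None  # a missing tag falls back cleanly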
diff --git a/mne/preprocessing/_csd.py b/mne/preprocessing/_csd.py
index 0649aee544e..3b5df493270 100644
--- a/mne/preprocessing/_csd.py
+++ b/mne/preprocessing/_csd.py
@@ -290,7 +290,7 @@ def compute_bridged_electrodes(inst, lm_cutoff=16, epoch_threshold=0.5,
     kde = gaussian_kde(ed_flat[ed_flat < lm_cutoff])
     with np.errstate(invalid='ignore'):
         local_minimum = float(minimize_scalar(
-            lambda x: kde(x) if x < lm_cutoff and x > 0 else np.inf).x)
+            lambda x: kde(x) if x < lm_cutoff and x > 0 else np.inf).x.item())
     logger.info(f'Local minimum {local_minimum} found')
 
     # find electrodes that are below the cutoff local minimum on
diff --git a/mne/preprocessing/_fine_cal.py b/mne/preprocessing/_fine_cal.py
index fb0ed474938..1292d536bc0 100644
--- a/mne/preprocessing/_fine_cal.py
+++ b/mne/preprocessing/_fine_cal.py
@@ -124,7 +124,7 @@ def compute_fine_calibration(raw, n_imbalance=3, t_window=10., ext_order=2,
         _, calibration, _ = _prep_fine_cal(info, calibration)
         for pi, pick in enumerate(mag_picks):
             idx = calibration['ch_names'].index(info['ch_names'][pick])
-            cals[pick] = calibration['imb_cals'][idx]
+            cals[pick] = calibration['imb_cals'][idx].item()
             zs[pi] = calibration['locs'][idx][-3:]
     elif len(mag_picks) > 0:
         cal_list = list()
diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 1e3ed65bee2..cb2cd5b8c76 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -50,19 +50,18 @@ def _read_stc(filename):
     num_bytes = 4
 
     # read tmin in ms
-    stc['tmin'] = float(np.frombuffer(buf, dtype=">f4", count=1,
-                                      offset=offset))
-    stc['tmin'] /= 1000.0
+    stc['tmin'] = float(
+        np.frombuffer(buf, dtype=">f4", count=1, offset=offset).item()) / 1000.
     offset += num_bytes
 
     # read sampling rate in ms
-    stc['tstep'] = float(np.frombuffer(buf, dtype=">f4", count=1,
-                                       offset=offset))
-    stc['tstep'] /= 1000.0
+    stc['tstep'] = float(
+        np.frombuffer(buf, dtype=">f4", count=1, offset=offset).item()) / 1000.
     offset += num_bytes
 
     # read number of vertices/sources
-    vertices_n = int(np.frombuffer(buf, dtype=">u4", count=1, offset=offset))
+    vertices_n = int(
+        np.frombuffer(buf, dtype=">u4", count=1, offset=offset).item())
     offset += num_bytes
 
     # read the source vector
@@ -71,7 +70,8 @@ def _read_stc(filename):
     offset += num_bytes * vertices_n
 
     # read the number of timepts
-    data_n = int(np.frombuffer(buf, dtype=">u4", count=1, offset=offset))
+    data_n = int(
+        np.frombuffer(buf, dtype=">u4", count=1, offset=offset).item())
     offset += num_bytes
 
     if (vertices_n and  # vertices_n can be 0 (empty stc)
@@ -157,7 +157,7 @@ def _read_w(filename):
         # read the vertices and data
         for i in range(vertices_n):
             vertices[i] = _read_3(fid)
-            data[i] = np.fromfile(fid, dtype='>f4', count=1)[0]
+            data[i] = np.fromfile(fid, dtype='>f4', count=1).item()
 
         w = dict()
         w['vertices'] = vertices
@@ -2207,7 +2207,7 @@ def save(self, fname, ftype='stc', *, overwrite=False, verbose=None):
             if not fname.endswith(('-vl.w', '-vol.w')):
                 fname += '-vl.w'
             fname = str(_check_fname(fname, overwrite=overwrite))
-            _write_w(fname, vertices=self.vertices[0], data=self.data)
+            _write_w(fname, vertices=self.vertices[0], data=self.data[:, 0])
         elif ftype == 'h5':
             super().save(fname, 'h5', overwrite=overwrite)
         logger.info('[done]')
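`_read_stc` walks a binary header by hand: read one big-endian scalar, then advance a manual byte offset. A self-contained sketch of the idiom with a fabricated two-field header:

import numpy as np

buf = np.array([1.5, 2.5], dtype='>f4').tobytes()
offset = 0
tmin = float(np.frombuffer(buf, dtype='>f4', count=1, offset=offset).item())
offset += 4  # advance past the 4-byte float just read
tstep = float(np.frombuffer(buf, dtype='>f4', count=1, offset=offset).item())
assert (tmin, tstep) == (1.5, 2.5)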
not None): @@ -1299,7 +1299,7 @@ def _write_one_source_space(fid, this, verbose=None): dists = sparse.triu(dists, format=dists.format) write_float_sparse_rcs(fid, FIFF.FIFF_MNE_SOURCE_SPACE_DIST, dists) write_float_matrix(fid, FIFF.FIFF_MNE_SOURCE_SPACE_DIST_LIMIT, - this['dist_limit']) + np.array(this['dist_limit'], float)) # Segmentation data if this['type'] == 'vol' and ('seg_name' in this): diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py index 1a6c12c092c..8495171e179 100644 --- a/mne/stats/cluster_level.py +++ b/mne/stats/cluster_level.py @@ -880,7 +880,7 @@ def _permutation_cluster_test(X, threshold, n_permutations, tail, stat_fun, t_obs_buffer[pos: pos + buffer_size] =\ stat_fun(*[x[:, pos: pos + buffer_size] for x in X]) - if not np.alltrue(t_obs == t_obs_buffer): + if not np.all(t_obs == t_obs_buffer): warn('Provided stat_fun does not treat variables independently. ' 'Setting buffer_size to None.') buffer_size = None diff --git a/mne/stats/regression.py b/mne/stats/regression.py index 9cd206d0cb6..1e2e7abac6a 100644 --- a/mne/stats/regression.py +++ b/mne/stats/regression.py @@ -82,7 +82,7 @@ def linear_regression(inst, design_matrix, names=None): raise ValueError('Input must be epochs or iterable of source ' 'estimates') logger.info(msg + ', (%s targets, %s regressors)' % - (np.product(data.shape[1:]), len(names))) + (np.prod(data.shape[1:]), len(names))) lm_params = _fit_lm(data, design_matrix, names) lm = namedtuple('lm', 'beta stderr t_val p_val mlog10_p_val') lm_fits = {} @@ -103,7 +103,7 @@ def _fit_lm(data, design_matrix, names): """Aux function.""" from scipy import stats, linalg n_samples = len(data) - n_features = np.product(data.shape[1:]) + n_features = np.prod(data.shape[1:]) if design_matrix.ndim != 2: raise ValueError('Design matrix must be a 2d array') n_rows, n_predictors = design_matrix.shape diff --git a/mne/surface.py b/mne/surface.py index 8a448f69feb..dbc9884dc08 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -897,10 +897,10 @@ def _read_patch(fname): # and PyCortex (BSD) patch = dict() with open(fname, 'r') as fid: - ver = np.fromfile(fid, dtype='>i4', count=1)[0] + ver = np.fromfile(fid, dtype='>i4', count=1).item() if ver != -1: raise RuntimeError(f'incorrect version # {ver} (not -1) found') - npts = np.fromfile(fid, dtype='>i4', count=1)[0] + npts = np.fromfile(fid, dtype='>i4', count=1).item() dtype = np.dtype( [('vertno', '>i4'), ('x', '>f'), ('y', '>f'), ('z', '>f')]) recs = np.fromfile(fid, dtype=dtype, count=npts) diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 365983debf8..64979f1fd62 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -113,7 +113,8 @@ def requires_openmeeg_mark(): """Mark pytest tests that require OpenMEEG.""" import pytest return pytest.mark.skipif( - not check_version('openmeeg', '2.5.5'), reason='Requires OpenMEEG') + not check_version( + 'openmeeg', '2.5.6'), reason='Requires OpenMEEG >= 2.5.6') def requires_freesurfer(arg): diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 0e66fed2e2f..b07f06d68a6 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -4531,7 +4531,7 @@ def %(name)s(%(signature)s):\n def deprecated_alias(dep_name, func, removed_in=None): """Inject a deprecated alias into the namespace.""" if removed_in is None: - from .._version import __version__ + from .. 
import __version__ removed_in = __version__.split('.')[:2] removed_in[1] = str(int(removed_in[1]) + 1) removed_in = '.'.join(removed_in) diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 6e18b487bdd..cab320eceac 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -824,7 +824,7 @@ def object_diff(a, b, pre='', *, allclose=False): if c.nnz > 0: out += pre + (' sparse matrix a and b differ on %s ' 'elements' % c.nnz) - elif hasattr(a, '__getstate__'): + elif hasattr(a, '__getstate__') and a.__getstate__() is not None: out += object_diff(a.__getstate__(), b.__getstate__(), pre, allclose=allclose) else: diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 9c30437e649..309a1ea465c 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -490,7 +490,7 @@ def __init__(self): @testing.requires_testing_data -@pytest.mark.skipif(os.getenv('CI_OS_NAME', '') == 'osx', +@pytest.mark.skipif(os.getenv('CI_OS_NAME', '').startswith('macos'), reason='Unreliable/segfault on macOS CI') @pytest.mark.parametrize('hemi', ('lh', 'rh')) def test_single_hemi(hemi, renderer_interactive_pyvistaqt, brain_gc): @@ -578,16 +578,17 @@ def tiny(tmp_path): sz = brain.plotter.size() sz = (sz.width(), sz.height()) sz_ren = brain.plotter.renderer.GetSize() - ratio = np.median(np.array(sz_ren) / np.array(sz)) + ratio = np.round(np.median(np.array(sz_ren) / np.array(sz))).astype(int) return brain, ratio @pytest.mark.filterwarnings('ignore:.*constrained_layout not applied.*:') def test_brain_screenshot(renderer_interactive_pyvistaqt, tmp_path, brain_gc): """Test time viewer screenshot.""" - # XXX disable for sprint because it's too unreliable - if sys.platform == 'darwin' and os.getenv('GITHUB_ACTIONS', '') == 'true': - pytest.skip('Test is unreliable on GitHub Actions macOS') + # This is broken on Conda + GHA for some reason + if os.getenv('CONDA_PREFIX', '') != '' and \ + os.getenv('GITHUB_ACTIONS', '') == 'true': + pytest.skip('Test is unreliable on GitHub Actions conda runs') tiny_brain, ratio = tiny(tmp_path) img_nv = tiny_brain.screenshot(time_viewer=False) want = (_TINY_SIZE[1] * ratio, _TINY_SIZE[0] * ratio, 3) @@ -901,9 +902,10 @@ def test_brain_traces(renderer_interactive_pyvistaqt, hemi, src, tmp_path, assert fname.stem in rst assert fname.is_file() img = image.imread(fname) - assert img.shape[1] == screenshot.shape[1] # same width + assert_allclose(img.shape[1], screenshot.shape[1], atol=1) # width assert img.shape[0] > screenshot.shape[0] # larger height - assert img.shape[:2] == screenshot_all.shape[:2] + assert_allclose(img.shape[1], screenshot_all.shape[1], atol=1) + assert_allclose(img.shape[0], screenshot_all.shape[0], atol=1) # TODO: don't skip on Windows, see diff --git a/mne/viz/_brain/tests/test_notebook.py b/mne/viz/_brain/tests/test_notebook.py index 95029cc0459..749f9f6075e 100644 --- a/mne/viz/_brain/tests/test_notebook.py +++ b/mne/viz/_brain/tests/test_notebook.py @@ -40,9 +40,11 @@ def test_notebook_alignment(renderer_notebook, brain_gc, nbexec): @testing.requires_testing_data def test_notebook_interactive(renderer_notebook, brain_gc, nbexec): """Test interactive modes.""" - import tempfile from contextlib import contextmanager + import os from pathlib import Path + import tempfile + import time import pytest from numpy.testing import assert_allclose from ipywidgets import Button @@ -85,20 +87,29 @@ def interactive(on): movie_path = tmp_path / "test.gif" screenshot_path = tmp_path / "test.png" 
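The ``hasattr(a, '__getstate__')`` guard tightened in ``mne/utils/numerics.py`` above reflects a CPython change: from Python 3.11 on, every object inherits a default ``__getstate__``, so the attribute check alone no longer implies there is state to diff. A minimal sketch of the behavior that appears to motivate it (assuming CPython >= 3.11):

    class Plain:
        pass

    p = Plain()
    print(hasattr(p, '__getstate__'))  # True on 3.11+, False on <= 3.10
    print(p.__getstate__())            # None -- no instance state yet
    p.attr = 1
    print(p.__getstate__())            # {'attr': 1} -- now safe to recurse into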
brain._renderer.actions['movie_field'].value = str(movie_path) + assert not movie_path.is_file() brain._renderer.actions['screenshot_field'].value = \ str(screenshot_path) + assert not screenshot_path.is_file() total_number_of_buttons = sum( '_field' not in k for k in brain._renderer.actions.keys()) assert 'play' in brain._renderer.actions # play is not a button widget, it does not have a click() method number_of_buttons = 1 - for action in brain._renderer.actions.values(): + button_names = list() + for name, action in brain._renderer.actions.items(): widget = action._action if isinstance(widget, Button): widget.click() + button_names.append(name) number_of_buttons += 1 assert number_of_buttons == total_number_of_buttons - assert movie_path.is_file() + time.sleep(0.5) + assert 'movie' in button_names, button_names + # TODO: this fails on GHA for some reason, need to figure it out + if os.getenv('GITHUB_ACTIONS', '') != 'true': + assert movie_path.is_file() + assert 'screenshot' in button_names, button_names assert screenshot_path.is_file() img_nv = brain.screenshot() assert img_nv.shape == (300, 300, 3), img_nv.shape diff --git a/mne/viz/utils.py b/mne/viz/utils.py index aff4e8a7343..9bac4490619 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1778,7 +1778,7 @@ def log_fix(tval): closer_idx = np.argmin(np.abs(xlims - temp_ticks)) further_idx = np.argmax(np.abs(xlims - temp_ticks)) start_stop = [temp_ticks[closer_idx], xlims[further_idx]] - step = np.sign(np.diff(start_stop)) * np.max(np.abs(temp_ticks)) + step = np.sign(np.diff(start_stop)).item() * np.max(np.abs(temp_ticks)) tts = np.arange(*start_stop, step) xticks = np.array(sorted(xticks + [tts[0], tts[-1]])) axes.set_xticks(xticks) diff --git a/tools/get_minimal_commands.sh b/tools/get_minimal_commands.sh index 797fcff5568..2b96eaf8cf6 100755 --- a/tools/get_minimal_commands.sh +++ b/tools/get_minimal_commands.sh @@ -1,6 +1,6 @@ -#!/bin/bash -e +#!/bin/bash -set -o pipefail +set -eo pipefail if [ "${DEPS}" == "minimal" ]; then return 0 2>/dev/null || exit "0" @@ -10,18 +10,20 @@ pushd ~ > /dev/null export MNE_ROOT="${PWD}/minimal_cmds" export PATH=${MNE_ROOT}/bin:$PATH if [ "${GITHUB_ACTIONS}" == "true" ]; then + echo "Setting MNE_ROOT for GHA" echo "MNE_ROOT=${MNE_ROOT}" >> $GITHUB_ENV; echo "${MNE_ROOT}/bin" >> $GITHUB_PATH; -fi; -if [ "${AZURE_CI}" == "true" ]; then +elif [ "${AZURE_CI}" == "true" ]; then + echo "Setting MNE_ROOT for Azure" echo "##vso[task.setvariable variable=MNE_ROOT]${MNE_ROOT}" echo "##vso[task.setvariable variable=PATH]${PATH}"; -fi; -if [ "${CIRCLECI}" == "true" ]; then +elif [ "${CIRCLECI}" == "true" ]; then + echo "Setting MNE_ROOT for CircleCI" echo "export MNE_ROOT=${MNE_ROOT}" >> "$BASH_ENV"; echo "export PATH=${MNE_ROOT}/bin:$PATH" >> "$BASH_ENV"; fi; -if [ "${CI_OS_NAME}" != "osx" ]; then +if [[ "${CI_OS_NAME}" != "macos"* ]]; then + echo "Getting files for Linux..." if [ ! -d "${PWD}/minimal_cmds" ]; then curl -L https://osf.io/g7dzs/download?version=5 | tar xz else @@ -46,6 +48,7 @@ if [ "${CI_OS_NAME}" != "osx" ]; then echo "export FREESURFER_HOME=${FREESURFER_HOME}" >> "$BASH_ENV"; fi; else + echo "Getting files for macOS Intel..." if [ ! 
-d "${PWD}/minimal_cmds" ]; then curl -L https://osf.io/rjcz4/download?version=2 | tar xz else @@ -53,20 +56,27 @@ else fi; export DYLD_LIBRARY_PATH=${MNE_ROOT}/lib:$DYLD_LIBRARY_PATH if [ "${GITHUB_ACTIONS}" == "true" ]; then + echo "Setting variables for GHA" echo "DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}" >> "$GITHUB_ENV"; set -x wget https://github.com/XQuartz/XQuartz/releases/download/XQuartz-2.7.11/XQuartz-2.7.11.dmg sudo hdiutil attach XQuartz-2.7.11.dmg sudo installer -package /Volumes/XQuartz-2.7.11/XQuartz.pkg -target / sudo ln -s /opt/X11 /usr/X11 - fi; - if [ "${AZURE_CI}" == "true" ]; then + elif [ "${AZURE_CI}" == "true" ]; then + echo "Setting variables for Azure" echo "##vso[task.setvariable variable=DYLD_LIBRARY_PATH]${DYLD_LIBRARY_PATH}" - fi; - if [ "${CIRCLECI}" == "true" ]; then + elif [ "${CIRCLECI}" == "true" ]; then + echo "Setting variables for CircleCI" echo "export DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}" >> "$BASH_ENV"; fi; fi popd > /dev/null +set -x +which mne_process_raw mne_process_raw --version +which mne_surf2bem +mne_surf2bem --version +which mri_average mri_average --version +set +x diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index 3fa77f91e90..0391ef59df0 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -3,11 +3,14 @@ STD_ARGS="--progress-bar off --upgrade" EXTRA_ARGS="" if [ ! -z "$CONDA_ENV" ]; then + echo "Uninstalling MNE for CONDA_ENV=${CONDA_ENV}" pip uninstall -yq mne elif [ ! -z "$CONDA_DEPENDENCIES" ]; then + echo "Using Mamba to install CONDA_DEPENDENCIES=${CONDA_DEPENDENCIES}" mamba install -y $CONDA_DEPENDENCIES else - # Changes here should also go in the interactive_test CircleCI job + echo "Install pip-pre dependencies" + test "${MNE_CI_KIND}" == "pip-pre" python -m pip install $STD_ARGS pip setuptools wheel echo "Numpy" pip uninstall -yq numpy @@ -16,20 +19,17 @@ else pip install $STD_ARGS --pre --only-binary ":all:" python-dateutil pytz joblib threadpoolctl six echo "PyQt6" # Broken as of 2022/09/20 - # pip install $STD_ARGS --pre --only-binary ":all:" --no-deps --extra-index-url https://www.riverbankcomputing.com/pypi/simple PyQt6 PyQt6-sip PyQt6-Qt6 - pip install $STD_ARGS --pre --only-binary ":all:" --no-deps PyQt6 PyQt6-sip PyQt6-Qt6 + # pip install $STD_ARGS --pre --only-binary ":all:" --no-deps --extra-index-url https://www.riverbankcomputing.com/pypi/simple PyQt6 + pip install $STD_ARGS --pre --only-binary ":all:" PyQt6 echo "NumPy/SciPy/pandas etc." 
- # Wait for https://github.com/scipy/scipy/issues/17811 - pip install $STD_ARGS --pre --only-binary ":all:" --no-deps --default-timeout=60 numpy - pip install $STD_ARGS --pre --only-binary ":all:" --no-deps --default-timeout=60 -i "https://pypi.anaconda.org/scipy-wheels-nightly/simple" scipy scikit-learn dipy pandas statsmodels matplotlib - pip install $STD_ARGS --pre --only-binary ":all:" --no-deps -f "https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com/" h5py - pip install $STD_ARGS --pre --only-binary ":all:" pillow - # We don't install Numba here because it forces an old NumPy version + pip install $STD_ARGS --pre --only-binary ":all:" --default-timeout=60 --extra-index-url "https://pypi.anaconda.org/scipy-wheels-nightly/simple" numpy scipy scikit-learn dipy pandas matplotlib pillow statsmodels + pip install $STD_ARGS --pre --only-binary ":all:" -f "https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com/" h5py + # No Numba because it forces an old NumPy version echo "nilearn and openmeeg" pip install $STD_ARGS --pre git+https://github.com/nilearn/nilearn - pip install $STD_ARGS --pre --only-binary ":all:" -i "https://test.pypi.org/simple" openmeeg + pip install $STD_ARGS --pre --only-binary ":all:" --extra-index-url "https://test.pypi.org/simple" openmeeg echo "VTK" - pip install $STD_ARGS --pre --only-binary ":all:" -i "https://wheels.vtk.org/" vtk + pip install $STD_ARGS --pre --only-binary ":all:" --extra-index-url "https://wheels.vtk.org/" vtk python -c "import vtk" echo "PyVista" pip install --progress-bar off git+https://github.com/pyvista/pyvista @@ -37,10 +37,6 @@ else pip install --progress-bar off git+https://github.com/pyvista/pyvistaqt echo "imageio-ffmpeg, xlrd, mffpy, python-picard" pip install --progress-bar off --pre imageio-ffmpeg xlrd mffpy python-picard patsy - if [ "$OSTYPE" == "darwin"* ]; then - echo "pyobjc-framework-Cocoa" - pip install --progress-bar off pyobjc-framework-Cocoa>=5.2.0 - fi echo "mne-qt-browser" pip install --progress-bar off git+https://github.com/mne-tools/mne-qt-browser EXTRA_ARGS="--pre" diff --git a/tools/github_actions_env_vars.sh b/tools/github_actions_env_vars.sh new file mode 100755 index 00000000000..cf1dbdb45a5 --- /dev/null +++ b/tools/github_actions_env_vars.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -eo pipefail -x + +# old and minimal use conda +if [[ "$MNE_CI_KIND" == "old" ]]; then + echo "Setting conda env vars for old" + echo "CONDA_ACTIVATE_ENV=true" >> $GITHUB_ENV + echo "CONDA_DEPENDENCIES=numpy=1.20.2 scipy=1.6.3 matplotlib=3.4 pandas=1.2.4 scikit-learn=0.24.2" >> $GITHUB_ENV + echo "MNE_IGNORE_WARNINGS_IN_TESTS=true" >> $GITHUB_ENV + echo "MNE_SKIP_NETWORK_TESTS=1" >> $GITHUB_ENV +elif [[ "$MNE_CI_KIND" == "minimal" ]]; then + echo "Setting conda env vars for minimal" + echo "CONDA_ACTIVATE_ENV=true" >> $GITHUB_ENV + echo "CONDA_DEPENDENCIES=numpy scipy matplotlib" >> $GITHUB_ENV +elif [[ "$MNE_CI_KIND" == "notebook" ]]; then + echo "CONDA_ENV=environment.yml" >> $GITHUB_ENV + echo "CONDA_ACTIVATE_ENV=mne" >> $GITHUB_ENV + # TODO: This should work but breaks stuff... 
+ # echo "MNE_3D_BACKEND=notebook" >> $GITHUB_ENV +elif [[ "$MNE_CI_KIND" != "pip"* ]]; then # conda, mamba (use warning level for completeness) + echo "Setting conda env vars for $MNE_CI_KIND" + echo "CONDA_ENV=environment.yml" >> $GITHUB_ENV + echo "CONDA_ACTIVATE_ENV=mne" >> $GITHUB_ENV + echo "MNE_QT_BACKEND=PyQt5" >> $GITHUB_ENV + echo "MNE_LOGGING_LEVEL=warning" >> $GITHUB_ENV +else # pip-like + echo "Setting pip env vars for $MNE_CI_KIND" + echo "MNE_QT_BACKEND=PyQt6" >> $GITHUB_ENV +fi +set +x diff --git a/tools/github_actions_install.sh b/tools/github_actions_install.sh index 899f5921591..f52c193d773 100755 --- a/tools/github_actions_install.sh +++ b/tools/github_actions_install.sh @@ -1,4 +1,5 @@ -#!/bin/bash -ef +#!/bin/bash -python setup.py build -python setup.py install +set -eo pipefail + +pip install -ve . diff --git a/tools/github_actions_locale.sh b/tools/github_actions_locale.sh deleted file mode 100755 index c81b229d663..00000000000 --- a/tools/github_actions_locale.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash -ef - -echo "Print locale " -locale -echo "Other stuff" diff --git a/tools/github_actions_test.sh b/tools/github_actions_test.sh index a6b9a94e096..512af0f4047 100755 --- a/tools/github_actions_test.sh +++ b/tools/github_actions_test.sh @@ -1,10 +1,17 @@ -#!/bin/bash -ef +#!/bin/bash -USE_DIRS="mne/" -if [ "${CI_OS_NAME}" != "osx" ]; then +set -eo pipefail + +if [[ "${CI_OS_NAME}" != "macos"* ]]; then CONDITION="not (ultraslowtest or pgtest)" else CONDITION="not (slowtest or pgtest)" fi -echo 'pytest -m "${CONDITION}" --tb=short --cov=mne --cov-report xml -vv ${USE_DIRS}' +if [ "${MNE_CI_KIND}" == "notebook" ]; then + USE_DIRS=mne/viz/ +else + USE_DIRS="mne/" +fi +set -x pytest -m "${CONDITION}" --tb=short --cov=mne --cov-report xml -vv ${USE_DIRS} +set +x From cc4006845d3c5367f16705557d652d9bbb82e87e Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Tue, 25 Apr 2023 06:22:56 -0500 Subject: [PATCH 0029/1125] use py3.10 in precommit config (#11648) Co-authored-by: Clemens Brunner Co-authored-by: Eric Larson --- .pre-commit-config.yaml | 5 ++--- requirements_testing.txt | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4814c23d8eb..7ef76755eae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,3 @@ -default_language_version: - python: python3.11 - repos: # - repo: https://github.com/psf/black # rev: 23.1.0 @@ -32,5 +29,7 @@ repos: rev: v2.2.3 hooks: - id: codespell + additional_dependencies: + - tomli files: ^mne/|^doc/|^examples/|^tutorials/ types_or: [python, bib, rst, inc] diff --git a/requirements_testing.txt b/requirements_testing.txt index c8ff7b5c5fb..fa6c7b86b3f 100644 --- a/requirements_testing.txt +++ b/requirements_testing.txt @@ -7,6 +7,7 @@ ruff numpydoc codespell check-manifest +tomli; python_version<'3.11' twine wheel pre-commit From 57f5ce300f8e67ffcc0d9a19ab724693105b74b6 Mon Sep 17 00:00:00 2001 From: Mikolaj Magnuski Date: Tue, 25 Apr 2023 13:24:01 +0200 Subject: [PATCH 0030/1125] FIX: missing channels/fiducials can be np.nan (#11634) --- mne/channels/montage.py | 3 +- mne/channels/tests/test_montage.py | 46 ++++++++++++++++++++++++++++++ mne/io/_digitization.py | 3 +- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/mne/channels/montage.py b/mne/channels/montage.py index e28ded3f3d7..6a2a84241ea 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -1656,7 +1656,8 @@ def compute_native_head_t(montage, *, 
on_missing='warn', verbose=None):
     else:
         fid_keys = ('nasion', 'lpa', 'rpa')
         for key in fid_keys:
-            if fid_coords[key] is None:
+            this_coord = fid_coords[key]
+            if this_coord is None or np.any(np.isnan(this_coord)):
                 msg = (
                     f'Fiducial point {key} not found, assuming identity '
                     f'{_verbose_frames[coord_frame]} to head transformation')
diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py
index 88297cae040..5fe16a2294d 100644
--- a/mne/channels/tests/test_montage.py
+++ b/mne/channels/tests/test_montage.py
@@ -848,6 +848,52 @@ def test_set_dig_montage():
                         [6., 7., 8., 42., 42., 42.]])
 
 
+def test_set_dig_montage_with_nan_positions():
+    """Test that fiducials are not NaN.
+
+    Test that setting a montage with some NaN positions does not produce
+    NaN fiducials.
+    """
+    def _ensure_fid_not_nan(info, ch_pos):
+        montage_kwargs = dict(ch_pos=dict(), coord_frame='head')
+        for ch_idx, ch in enumerate(info.ch_names):
+            montage_kwargs['ch_pos'][ch] = ch_pos[ch_idx]
+
+        new_montage = make_dig_montage(**montage_kwargs)
+        info = info.copy()
+        info.set_montage(new_montage)
+
+        recovered_montage = info.get_montage()
+        fid_coords, coord_frame = _get_fid_coords(
+            recovered_montage.dig, raise_error=False)
+
+        for fid_coord in fid_coords.values():
+            if fid_coord is not None:
+                assert not np.isnan(fid_coord).any()
+
+        return fid_coords, coord_frame
+
+    channels = list('ABCDEF')
+    info = create_info(channels, 1000, ch_types='seeg')
+
+    # if all positions are NaN, the fiducials should not be NaN, but None
+    ch_pos = [info['chs'][ch_idx]['loc'][:3]
+              for ch_idx in range(len(channels))]
+    fid_coords, coord_frame = _ensure_fid_not_nan(info, ch_pos)
+    for fid_coord in fid_coords.values():
+        assert fid_coord is None
+    assert coord_frame is None
+
+    # if some positions are not NaN, the fiducials should be a non-NaN array
+    ch_pos[0] = np.array([1., 1.5, 1.])
+    ch_pos[1] = np.array([2., 1.5, 1.5])
+    ch_pos[2] = np.array([1.25, 1., 1.25])
+    fid_coords, coord_frame = _ensure_fid_not_nan(info, ch_pos)
+    for fid_coord in fid_coords.values():
+        assert isinstance(fid_coord, np.ndarray)
+    assert coord_frame == FIFF.FIFFV_COORD_HEAD
+
+
 @testing.requires_testing_data
 def test_fif_dig_montage(tmp_path):
     """Test FIF dig montage support."""
diff --git a/mne/io/_digitization.py b/mne/io/_digitization.py
index a57e0eb78eb..c6baf9f507b 100644
--- a/mne/io/_digitization.py
+++ b/mne/io/_digitization.py
@@ -254,7 +254,8 @@ def _ensure_fiducials_head(dig):
     if radius is None:
         radius = [
             np.linalg.norm(d['r']) for d in dig
-            if d['coord_frame'] == FIFF.FIFFV_COORD_HEAD]
+            if d['coord_frame'] == FIFF.FIFFV_COORD_HEAD
+            and not np.isnan(d['r']).any()]
         if not radius:
             return  # can't complete, no head points
         radius = np.mean(radius)

From 909e45821bd004b6072a6426e4e8ea88e51b72ad Mon Sep 17 00:00:00 2001
From: Mathieu Scheltienne
Date: Tue, 25 Apr 2023 20:58:42 +0200
Subject: [PATCH 0031/1125] Fix bug with ica.plot_components when axes is
 provided (#11654)

---
 doc/changes/latest.inc    |  1 +
 mne/viz/tests/test_ica.py | 22 ++++++++++++++++++----
 mne/viz/topomap.py        | 24 ++++++++++++++++++++----
 3 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc
index 12c192e2b07..5cefd8aa8a6 100644
--- a/doc/changes/latest.inc
+++ b/doc/changes/latest.inc
@@ -72,6 +72,7 @@ Bugs
 - Fix bug where :meth:`mne.Evoked.plot_topomap` opened an extra figure (:gh:`11607` by `Alex Rockhill`_)
 - Fix bug where :func:`mne.transforms.apply_volume_registration_points` modified info in place
(:gh:`11612` by `Alex Rockhill`_) - In :class:`~mne.Report`, custom figures now show up correctly when ``image_format='svg'`` is requested (:gh:`11623` by `Richard Höchenberger`_) +- Fix bug where providing ``axes`` in `mne.preprocessing.ICA.plot_components` would fail (:gh:`11654` by `Mathieu Scheltienne`_) API changes ~~~~~~~~~~~ diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index 07f4bbbd5d3..a5724c2e40e 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -6,10 +6,10 @@ import sys from pathlib import Path +import matplotlib.pyplot as plt import numpy as np -from numpy.testing import assert_equal, assert_array_equal import pytest -import matplotlib.pyplot as plt +from numpy.testing import assert_equal, assert_array_equal from mne import (read_events, Epochs, read_cov, pick_types, Annotations, make_fixed_length_events) @@ -59,9 +59,11 @@ def test_plot_ica_components(): res = 8 fast_test = {"res": res, "contours": 0, "sensors": False} raw = _get_raw() - ica = ICA(noise_cov=read_cov(cov_fname), n_components=2) + ica = ICA(noise_cov=read_cov(cov_fname), n_components=8) ica_picks = _get_picks(raw) - with pytest.warns(RuntimeWarning, match='projection'): + with pytest.warns( + RuntimeWarning, match="(projection)|(unstable mixing matrix)" + ): ica.fit(raw, picks=ica_picks) for components in [0, [0], [0, 1], [0, 1] * 2, None]: @@ -106,6 +108,18 @@ def test_plot_ica_components(): title = topomap_ax.get_title() assert (lbl.split(' ')[0] == title.split(' ')[0]) + # test provided axes + _, ax = plt.subplots(1, 1) + ica.plot_components(axes=ax, picks=0, **fast_test) + _, ax = plt.subplots(2, 1) + ica.plot_components(axes=ax, picks=[0, 1], **fast_test) + _, ax = plt.subplots(2, 2) + ica.plot_components(axes=ax, picks=[0, 1, 2, 3], **fast_test) + _, ax = plt.subplots(3, 2) + ica.plot_components( + axes=ax, picks=[0, 1, 2, 3, 4, 5], nrows=2, ncols=2, **fast_test + ) + ica.info = None with pytest.raises(RuntimeError, match='fit the ICA'): ica.plot_components(1, ch_type='mag') diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 0db592fe14c..e5b7e001aa6 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -1136,7 +1136,7 @@ def plot_ica_components( Defaults to True, which plots one standard deviation above/below. If set to float allows to control how many standard deviations are plotted. For example 2.5 will plot 2.5 standard deviation above/below. - reject : 'auto' | dict | None + reject : ``'auto'`` | dict | None Allows to specify rejection parameters used to drop epochs (or segments if continuous signal is passed as inst). If None, no rejection is applied. The default is 'auto', @@ -1167,7 +1167,12 @@ def plot_ica_components( .. versionadded:: 1.3 %(colorbar_topomap)s %(cbar_fmt_topomap)s - %(axes_evoked_plot_topomap)s + axes : Axes | array of Axes | None + The subplot(s) to plot to. Either a single Axes or an iterable of Axes + if more than one subplot is needed. The number of subplots must match + the number of selected components. If None, new figures will be created + with the number of subplots per figure controlled by ``nrows`` and + ``ncols``. title : str | None The title of the generated figure. If ``None`` (default) and ``axes=None``, a default title of "ICA Components" will be used. @@ -1199,6 +1204,8 @@ def plot_ica_components( topomap (this option is only available when the ``inst`` argument is supplied). 
""" # noqa E501 + from matplotlib.pyplot import Axes + from ..io import BaseRaw from ..epochs import BaseEpochs @@ -1226,7 +1233,12 @@ def plot_ica_components( figs = [] cut_points = range(max_subplots, n_components, max_subplots) pick_groups = np.split(range(n_components), cut_points) - for _picks in pick_groups: + for k, _picks in enumerate(pick_groups): + _axes = axes.flatten() if isinstance(axes, np.ndarray) else axes + try: # either an iterable, 1D numpy array or others + _axes = _axes[k * max_subplots: (k + 1) * max_subplots] + except TypeError: # None or Axes + _axes = axes fig = plot_ica_components( ica, picks=_picks, ch_type=ch_type, inst=inst, plot_std=plot_std, reject=reject, sensors=sensors, @@ -1234,7 +1246,7 @@ def plot_ica_components( sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, size=size, cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, - cbar_fmt=cbar_fmt, axes=axes, title=title, nrows=nrows, + cbar_fmt=cbar_fmt, axes=_axes, title=title, nrows=nrows, ncols=ncols, show=show, image_args=image_args, psd_args=psd_args, verbose=verbose) figs.append(fig) @@ -1260,6 +1272,10 @@ def plot_ica_components( if not user_passed_axes: fig, axes, _, _ = _prepare_trellis(len(data), ncols=ncols, nrows=nrows) fig.suptitle(title) + else: + axes = axes.flatten() if isinstance(axes, np.ndarray) else axes + axes = [axes] if isinstance(axes, Axes) else axes + fig = axes[0].get_figure() subplot_titles = list() for ii, data_, ax in zip(picks, data, axes): From 52506f4bed1af37e42644ed261735ef0a27b7556 Mon Sep 17 00:00:00 2001 From: Moritz Gerster <45031224+moritz-gerster@users.noreply.github.com> Date: Thu, 27 Apr 2023 18:05:55 +0200 Subject: [PATCH 0032/1125] FIX: mne.concatenate_raws(raws) wrongly concatenates raws [ci skip] (#11640) Co-authored-by: Daniel McCloy --- doc/sensor_space.rst | 1 + mne/__init__.py | 2 +- mne/io/__init__.py | 2 +- mne/io/base.py | 22 ++++++++++++++ mne/io/fiff/tests/test_raw_fiff.py | 48 +++++++++++++++++++++++++----- mne/io/meas_info.py | 4 +++ 6 files changed, 70 insertions(+), 9 deletions(-) diff --git a/doc/sensor_space.rst b/doc/sensor_space.rst index a1c72b3aa59..b4bbda60053 100644 --- a/doc/sensor_space.rst +++ b/doc/sensor_space.rst @@ -11,6 +11,7 @@ Sensor Space Data concatenate_raws equalize_channels grand_average + match_channel_orders pick_channels pick_channels_cov pick_channels_forward diff --git a/mne/__init__.py b/mne/__init__.py index fc799f9d59a..27a2846887e 100644 --- a/mne/__init__.py +++ b/mne/__init__.py @@ -35,7 +35,7 @@ pick_types_forward, pick_channels_cov, pick_channels_evoked, pick_info, channel_type, channel_indices_by_type) -from .io.base import concatenate_raws +from .io.base import concatenate_raws, match_channel_orders from .io.meas_info import create_info, Info from .io.proj import Projection from .io.kit import read_epochs_kit diff --git a/mne/io/__init__.py b/mne/io/__init__.py index 0abb704873b..8af9495499d 100644 --- a/mne/io/__init__.py +++ b/mne/io/__init__.py @@ -66,6 +66,6 @@ # for backward compatibility from .fiff import Raw from .fiff import Raw as RawFIF -from .base import concatenate_raws +from .base import concatenate_raws, match_channel_orders from .reference import (set_eeg_reference, set_bipolar_reference, add_reference_channels) diff --git a/mne/io/base.py b/mne/io/base.py index df198af4033..867c58656ce 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -2570,6 +2570,28 @@ def concatenate_raws(raws, preload=None, events_list=None, *, return raws[0], events +@fill_doc 
+def match_channel_orders(raws, copy=True): + """Ensure consistent channel order across raws. + + Parameters + ---------- + raws : list + List of :class:`~mne.io.Raw` instances to order. + %(copy_df)s + + Returns + ------- + list of Raw + List of Raws with matched channel orders. + """ + raws = deepcopy(raws) if copy else raws + ch_order = raws[0].ch_names + for raw in raws[1:]: + raw.reorder_channels(ch_order) + return raws + + def _check_maxshield(allow_maxshield): """Warn or error about MaxShield.""" msg = ('This file contains raw Internal Active ' diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 25e2ad1792a..febd699d2cd 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -21,7 +21,8 @@ from mne.datasets import testing from mne.filter import filter_data from mne.io.constants import FIFF -from mne.io import RawArray, concatenate_raws, read_raw_fif, base +from mne.io import (RawArray, concatenate_raws, read_raw_fif, + match_channel_orders, base) from mne.io.open import read_tag, read_tag_info from mne.io.tag import _read_tag_header from mne.io.tests.test_raw import _test_concat, _test_raw_reader @@ -389,7 +390,7 @@ def _create_toy_data(n_channels=3, sfreq=250, seed=None): def test_concatenate_raws_bads_order(): - """Test concatenation of raw instances.""" + """Test concatenation of raws when the order of *bad* channels varies.""" raw0 = _create_toy_data() raw1 = _create_toy_data() @@ -410,25 +411,58 @@ def test_concatenate_raws_bads_order(): # Bad channel mismatch raises raw2 = raw1.copy() raw2.info["bads"] = ["0", "2"] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="bads.*must match"): concatenate_raws([raw0, raw2]) # Type mismatch raises epochs1 = make_fixed_length_epochs(raw1) - with pytest.raises(ValueError): - concatenate_raws([raw0, epochs1]) + with pytest.raises(ValueError, match="type.*must match"): + concatenate_raws([raw0, epochs1.load_data()]) # Sample rate mismatch raw3 = _create_toy_data(sfreq=500) - with pytest.raises(ValueError): + raw3.info["bads"] = ["0", "1"] + with pytest.raises(ValueError, match="info.*must match"): concatenate_raws([raw0, raw3]) # Number of channels mismatch raw4 = _create_toy_data(n_channels=4) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="nchan.*must match"): concatenate_raws([raw0, raw4]) +def test_concatenate_raws_order(): + """Test concatenation of raws when the order of *good* channels varies.""" + raw0 = _create_toy_data(n_channels=2) + raw0._data[0] = np.zeros_like(raw0._data[0]) # set one channel zero + + # Create copy and concatenate raws + raw1 = raw0.copy() + raw_concat = concatenate_raws([raw0.copy(), raw1]) + assert raw0.ch_names == raw1.ch_names == raw_concat.ch_names == ["0", "1"] + ch0 = raw_concat.get_data(picks=["0"]) + assert np.all(ch0 == 0) + + # Change the order of the channels and concatenate again + raw1.reorder_channels(["1", "0"]) + assert raw1.ch_names == ["1", "0"] + raws = [raw0.copy(), raw1] + with pytest.raises(ValueError, match="Channel order must match."): + # Fails now due to wrong order of channels + raw_concat = concatenate_raws(raws) + + with pytest.raises(ValueError, match="Channel order must match."): + # still fails, because raws is copied and not changed in place + match_channel_orders(raws, copy=True) + raw_concat = concatenate_raws(raws) + + # Now passes because all raws have the same order + match_channel_orders(raws, copy=False) + raw_concat = concatenate_raws(raws) + ch0 
= raw_concat.get_data(picks=["0"]) + assert np.all(ch0 == 0) + + @testing.requires_testing_data @pytest.mark.parametrize('mod', ( 'meg', diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index 9bc63b38052..3e6eb62c4a6 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -2906,6 +2906,10 @@ def _ensure_infos_match(info1, info2, name, *, on_mismatch='raise'): raise ValueError(f'{name}.info[\'sfreq\'] must match') if set(info1['ch_names']) != set(info2['ch_names']): raise ValueError(f'{name}.info[\'ch_names\'] must match') + if info1['ch_names'] != info2['ch_names']: + msg = (f'{name}.info[\'ch_names\']: Channel order must match. Use ' + '"mne.match_channel_orders()" to sort channels.') + raise ValueError(msg) if len(info2['projs']) != len(info1['projs']): raise ValueError(f'SSP projectors in {name} must be the same') if any(not _proj_equal(p1, p2) for p1, p2 in From cecbf0fb1dc919254bb6095db491b2d7c5f68003 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 27 Apr 2023 13:01:21 -0400 Subject: [PATCH 0033/1125] MAINT: Simplify doc/conf.py (#11657) --- .circleci/config.yml | 14 ++- .github/workflows/circle_artifacts.yml | 2 +- azure-pipelines.yml | 34 ------- doc/Makefile | 120 ++++++------------------- doc/conf.py | 75 +++++++--------- doc/install/contributing.rst | 16 +--- doc/sphinxext/gen_commands.py | 29 +++--- environment.yml | 1 + requirements_testing_extra.txt | 3 +- tools/circleci_download.sh | 14 ++- tools/github_actions_dependencies.sh | 5 +- 11 files changed, 93 insertions(+), 220 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 9cbb54338d4..554f27fc55b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -247,7 +247,7 @@ jobs: - run: name: make test-doc command: | - if [[ $(cat gitlog.txt) == *"[circle front]"* ]] || [[ $(cat build.txt) == "html_dev-memory" ]] || [[ $(cat build.txt) == "html_stable-memory" ]]; then + if [[ $(cat gitlog.txt) == *"[circle front]"* ]] || [[ $(cat build.txt) == "html-memory" ]] ; then make test-doc; mkdir -p doc/_build/test-results/test-doc; cp junit-results.xml doc/_build/test-results/test-doc/junit.xml; @@ -276,7 +276,7 @@ jobs: - run: name: Reduce artifact upload time command: | - if grep -q html_dev-pattern-memory build.txt || grep -q html_dev-noplot build.txt; then + if grep -q html-pattern-memory build.txt || grep -q html-noplot build.txt; then zip -rm doc/_build/html/_downloads.zip doc/_build/html/_downloads fi for NAME in generated auto_tutorials auto_examples; do @@ -299,15 +299,11 @@ jobs: # Save the HTML - store_artifacts: path: doc/_build/html/ - destination: dev - - store_artifacts: - path: doc/_build/html_stable/ - destination: stable + destination: html - persist_to_workspace: root: doc/_build paths: - html - - html_stable # Keep these separate, maybe better in terms of size limitations (?) - save_cache: @@ -465,7 +461,7 @@ jobs: - run: name: Check docs command: | - if [ ! -f /tmp/build/html/index.html ] && [ ! -f /tmp/build/html_stable/index.html ]; then + if [ ! 
-f /tmp/build/html/index.html ] ; then echo "No files found to upload (build: ${CIRCLE_BRANCH})."; circleci-agent step halt; fi; @@ -498,7 +494,7 @@ jobs: else echo "Deploying stable docs for ${CIRCLE_BRANCH}."; rm -Rf stable; - cp -a /tmp/build/html_stable stable; + cp -a /tmp/build/html stable; git add -A; git commit -m "CircleCI update of stable docs (${CIRCLE_BUILD_NUM})."; fi; diff --git a/.github/workflows/circle_artifacts.yml b/.github/workflows/circle_artifacts.yml index 96a4264627c..9a98c32d2f5 100644 --- a/.github/workflows/circle_artifacts.yml +++ b/.github/workflows/circle_artifacts.yml @@ -10,6 +10,6 @@ jobs: with: repo-token: ${{ secrets.GITHUB_TOKEN }} api-token: ${{ secrets.CIRCLECI_TOKEN }} - artifact-path: 0/dev/index.html + artifact-path: 0/html/index.html circleci-jobs: build_docs,build_docs_main job-title: Check the rendered docs here! diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 3a68d3226af..b050cc191c1 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -335,37 +335,3 @@ stages: codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov' - - - job: SphinxWindows - pool: - vmImage: 'windows-latest' - variables: - AZURE_CI_WINDOWS: 'true' - steps: - - bash: | - set -e - git clone --depth 1 https://github.com/pyvista/gl-ci-helpers.git - powershell gl-ci-helpers/appveyor/install_opengl.ps1 - displayName: Install OpenGL - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.10' - - bash: | - set -eo pipefail - PYTHONUTF8=1 pip install --progress-bar off -r requirements.txt -r requirements_doc.txt - displayName: Install documentation dependencies - - script: pip install -e . - displayName: Install dev MNE - - script: mne sys_info -pd - displayName: Print config and test access to commands - - script: python -c "import numpy; numpy.show_config()" - displayName: Print NumPy config - - bash: | - set -eo pipefail - sed -i 's/.. graphviz::/.. graphviz/g' doc/install/contributing.rst - sed -i 's/.. graphviz::/.. graphviz/g' tutorials/preprocessing/40_artifact_correction_ica.py - sed -i '/sphinx\.ext\.graphviz/d' doc/conf.py - displayName: Skip graph that we cannot render - # TODO: Reenable this once we can get it to work! - # - bash: make -C doc html_dev-noplot - # displayName: 'Build doc' diff --git a/doc/Makefile b/doc/Makefile index 452b392759c..a8b50b9908b 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -4,125 +4,59 @@ # You can set these variables from the command line. SPHINXOPTS = -nWT --keep-going SPHINXBUILD = sphinx-build -PAPER = MPROF = SG_STAMP_STARTS=true mprof run -E --python sphinx # Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d _build/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +ALLSPHINXOPTS = -d _build/doctrees $(SPHINXOPTS) . 
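To recap the previous commit's helper: ``match_channel_orders`` reorders every raw to the channel order of the first one, which is what ``concatenate_raws`` now enforces. A minimal usage sketch (the file names are hypothetical):

    import mne

    fnames = ('run_01_raw.fif', 'run_02_raw.fif')  # hypothetical runs
    raws = [mne.io.read_raw_fif(f, preload=True) for f in fnames]
    raws = mne.match_channel_orders(raws, copy=True)  # match raws[0]'s order
    raw = mne.concatenate_raws(raws)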
-.PHONY: help clean html dirhtml pickle json htmlhelp qthelp latex changes linkcheck doctest +.PHONY: help clean html html-noplot html-pattern linkcheck linkcheck-grep doctest # make with no arguments will build the first target by default, i.e., build standalone HTML files -first_target: html_dev-noplot +first_target: html-noplot help: @echo "Please use \`make ' where is one of" - @echo " html_stable to make standalone HTML files (stable version)" - @echo " html_dev to make standalone HTML files (dev version)" - @echo " html_dev-pattern to make standalone HTML files for one example dir (dev version)" - @echo " *-noplot to make standalone HTML files without plotting" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " changes to make an overview of all changed/added/deprecated items" + @echo " html to make standalone HTML files" + @echo " html-memory to make standalone HTML files while monitoring memory usage" + @echo " html-pattern to make standalone HTML files for a specific filename pattern" + @echo " html-front to make standalone HTML files with only the frontpage examples" + @echo " html-noplot to make standalone HTML files without plotting" + @echo " clean to clean HTML files" @echo " linkcheck to check all external links for integrity" + @echo " linkcheck-grep to grep the linkcheck resut" @echo " doctest to run all doctests embedded in the documentation (if enabled)" + @echo " view to view the built HTML" clean: -rm -rf _build auto_examples auto_tutorials generated *.stc *.fif *.nii.gz -html_stable: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) _build/html_stable +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) _build/html @echo - @echo "Build finished. The HTML pages are in _build/html_stable." - -html_stable-memory: - $(MPROF) -b html $(ALLSPHINXOPTS) _build/html_stable - @echo - @echo "Build finished. The HTML pages are in _build/html_stable." - -html_dev: - BUILD_DEV_HTML=1 $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) _build/html - @echo - @echo "Build finished. The HTML pages are in _build/html" - -html_dev-memory: - BUILD_DEV_HTML=1 $(MPROF) -b html $(ALLSPHINXOPTS) _build/html - @echo - @echo "Build finished. The HTML pages are in _build/html" + @echo "Build finished. The HTML pages are in _build/html." -html_dev-pattern: - BUILD_DEV_HTML=1 $(SPHINXBUILD) -D sphinx_gallery_conf.filename_pattern=$(PATTERN) -D sphinx_gallery_conf.run_stale_examples=True -b html $(ALLSPHINXOPTS) _build/html +html-memory: + $(MPROF) -b html $(ALLSPHINXOPTS) _build/html @echo - @echo "Build finished. The HTML pages are in _build/html" + @echo "Build finished. The HTML pages are in _build/html." -html_dev-pattern-memory: - BUILD_DEV_HTML=1 $(MPROF) -D sphinx_gallery_conf.filename_pattern=$(PATTERN) -D sphinx_gallery_conf.run_stale_examples=True -b html $(ALLSPHINXOPTS) _build/html +html-pattern: + $(SPHINXBUILD) -D sphinx_gallery_conf.filename_pattern=$(PATTERN) -D sphinx_gallery_conf.run_stale_examples=True -b html $(ALLSPHINXOPTS) _build/html @echo @echo "Build finished. The HTML pages are in _build/html" -html_dev-noplot: - BUILD_DEV_HTML=1 $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) _build/html - @echo - @echo "Build finished. The HTML pages are in _build/html." 
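Usage of the renamed targets mirrors what the updated ``contributing.rst`` later in this commit documents, e.g. (the example name is illustrative, borrowed from the ``html-front`` recipe):

    $ make html-noplot
    $ make html-pattern PATTERN=30_strf.py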
- -html_dev-debug: - BUILD_DEV_HTML=1 $(SPHINXBUILD) -PD plot_gallery=0 -b html $(ALLSPHINXOPTS) _build/html - html-noplot: - $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) _build/html_stable - @echo - @echo "Build finished. The HTML pages are in _build/html_stable." - -html_dev-front: - @PATTERN="\(30_mne_dspm_loreta.py\|50_decoding.py\|30_strf.py\|20_cluster_1samp_spatiotemporal.py\|20_visualize_evoked.py\)" make html_dev-pattern; - -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) _build/dirhtml - @echo - @echo "Build finished. The HTML pages are in _build/dirhtml." - -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) _build/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) _build/json + $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) _build/html @echo - @echo "Build finished; now you can process the JSON files." + @echo "Build finished. The HTML pages are in _build/html." -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) _build/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in _build/htmlhelp." +html-front: + @PATTERN="\(30_mne_dspm_loreta.py\|50_decoding.py\|30_strf.py\|20_cluster_1samp_spatiotemporal.py\|20_visualize_evoked.py\)" make html-pattern -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) _build/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in _build/qthelp, like this:" - @echo "# qcollectiongenerator _build/qthelp/MNE.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile _build/qthelp/MNE.qhc" - -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) _build/latex - @echo - @echo "Build finished; the LaTeX files are in _build/latex." - @echo "Run \`make all-pdf' or \`make all-ps' in that directory to" \ - "run these through (pdf)latex." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) _build/changes - @echo - @echo "The overview file is in _build/changes." +# Aliases for old methods +html_dev-pattern-memory: html-pattern +html_dev-noplot: html-noplot +html_dev-front: html-front linkcheck: @$(SPHINXBUILD) -b linkcheck -D nitpicky=0 -D plot_gallery=0 -D exclude_patterns="cited.rst,whats_new.rst,configure_git.rst" -d _build/doctrees . _build/linkcheck diff --git a/doc/conf.py b/doc/conf.py index 41a0d3d6272..d26ef8d269b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -32,6 +32,7 @@ os.environ['MNE_BROWSER_OVERVIEW_MODE'] = 'hidden' os.environ['MNE_BROWSER_THEME'] = 'light' os.environ['MNE_3D_OPTION_THEME'] = 'light' +sphinx_logger = sphinx.util.logging.getLogger('mne') # -- Path setup -------------------------------------------------------------- @@ -62,6 +63,8 @@ # # The full version, including alpha/beta/rc tags. release = mne.__version__ +sphinx_logger.info( + f'Building documentation for MNE {release} ({mne.__file__})') # The short X.Y version. 
version = '.'.join(release.split('.')[:2]) @@ -338,7 +341,7 @@ def __call__(self, gallery_conf, fname, when): gc.collect() when = f'mne/conf.py:Resetter.__call__:{when}:{fname}' # Support stuff like - # MNE_SKIP_INSTANCE_ASSERTIONS="Brain,Plotter,BackgroundPlotter,vtkPolyData,_Renderer" make html_dev-memory # noqa: E501 + # MNE_SKIP_INSTANCE_ASSERTIONS="Brain,Plotter,BackgroundPlotter,vtkPolyData,_Renderer" make html-memory # noqa: E501 # to just test MNEQtBrowser skips = os.getenv('MNE_SKIP_INSTANCE_ASSERTIONS', '').lower() prefix = '' @@ -378,33 +381,22 @@ def __call__(self, gallery_conf, fname, when): gallery_dirs = ['auto_tutorials', 'auto_examples'] os.environ['_MNE_BUILDING_DOC'] = 'true' scrapers = ('matplotlib',) -try: - mne.viz.set_3d_backend(mne.viz.get_3d_backend()) -except Exception: - report_scraper = None -else: - backend = mne.viz.get_3d_backend() - if backend in ('notebook', 'pyvistaqt'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - import pyvista - pyvista.OFF_SCREEN = False - pyvista.BUILDING_GALLERY = True - scrapers += ( - mne.gui._GUIScraper(), - mne.viz._brain._BrainScraper(), - 'pyvista', - ) - report_scraper = mne.report._ReportScraper() - scrapers += (report_scraper,) - del backend -try: - import mne_qt_browser - _min_ver = _compare_version(mne_qt_browser.__version__, '>=', '0.2') - if mne.viz.get_browser_backend() == 'qt' and _min_ver: - scrapers += (mne.viz._scraper._MNEQtBrowserScraper(),) -except ImportError: - pass +mne.viz.set_3d_backend('pyvistaqt') +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + import pyvista +pyvista.OFF_SCREEN = False +pyvista.BUILDING_GALLERY = True + +report_scraper = mne.report._ReportScraper() +scrapers = ( + 'matplotlib', + mne.gui._GUIScraper(), + mne.viz._brain._BrainScraper(), + 'pyvista', + report_scraper, + mne.viz._scraper._MNEQtBrowserScraper(), +) compress_images = ('images', 'thumbnails') # let's make things easier on Windows users @@ -690,7 +682,6 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): xxl = '6' # variables to pass to HTML templating engine html_context = { - 'build_dev_html': bool(int(os.environ.get('BUILD_DEV_HTML', False))), 'default_mode': 'auto', 'pygment_light_style': 'tango', 'pygment_dark_style': 'native', @@ -1073,15 +1064,15 @@ def reset_warnings(gallery_conf, fname): icon_class[icon] = ('fa-brands',) if icon in brand_icons else ('fa-solid',) icon_class[icon] += ('fa-fw',) if icon in fixed_width_icons else () -prolog = '' +rst_prolog = '' for icon, classes in icon_class.items(): - prolog += f''' + rst_prolog += f''' .. |{icon}| raw:: html ''' -prolog += ''' +rst_prolog += ''' .. |ensp| unicode:: U+2002 .. EN SPACE ''' @@ -1097,7 +1088,7 @@ def reset_warnings(gallery_conf, fname): if line.strip().startswith('Requires-Python'): min_py = line.split(':')[1] min_py = min_py.lstrip(' =<>') -prolog += f'\n.. |min_python_version| replace:: {min_py}\n' +rst_prolog += f'\n.. 
|min_python_version| replace:: {min_py}\n' # -- website redirects -------------------------------------------------------- @@ -1285,7 +1276,7 @@ def make_redirects(app, exception): assert os.path.isfile(to_path), (fname, to_path) with open(fr_path, 'w') as fid: fid.write(TEMPLATE.format(to=to_fname)) - logger.info( + sphinx_logger.info( f'Added {len(fnames):3d} HTML plot_* redirects for {out_dir}') # custom redirects for fr, to in custom_redirects.items(): @@ -1313,7 +1304,7 @@ def make_redirects(app, exception): os.makedirs(os.path.dirname(fr_path), exist_ok=True) with open(fr_path, 'w') as fid: fid.write(TEMPLATE.format(to=to)) - logger.info( + sphinx_logger.info( f'Added {len(custom_redirects):3d} HTML custom redirects') @@ -1327,11 +1318,11 @@ def make_version(app, exception): try: stdout, _ = run_subprocess(['git', 'rev-parse', 'HEAD'], verbose=False) except Exception as exc: - logger.warning(f'Failed to write _version.txt: {exc}') + sphinx_logger.warning(f'Failed to write _version.txt: {exc}') return with open(os.path.join(app.outdir, '_version.txt'), 'w') as fid: fid.write(stdout) - logger.info(f'Added "{stdout.rstrip()}" > _version.txt') + sphinx_logger.info(f'Added "{stdout.rstrip()}" > _version.txt') # -- Connect our handlers to the main Sphinx app --------------------------- @@ -1339,13 +1330,7 @@ def make_version(app, exception): def setup(app): """Set up the Sphinx app.""" app.connect('autodoc-process-docstring', append_attr_meth_examples) - if report_scraper is not None: - report_scraper.app = app - app.config.rst_prolog = prolog - app.connect('builder-inited', report_scraper.copyfiles) - sphinx_logger = sphinx.util.logging.getLogger('mne') - sphinx_logger.info( - f'Building documentation for MNE {release} ({mne.__file__})') - sphinx_logger.info(f'Building with scrapers={scrapers}') + report_scraper.app = app + app.connect('builder-inited', report_scraper.copyfiles) app.connect('build-finished', make_redirects) app.connect('build-finished', make_version) diff --git a/doc/install/contributing.rst b/doc/install/contributing.rst index f7b278fe7d7..19e5a126874 100644 --- a/doc/install/contributing.rst +++ b/doc/install/contributing.rst @@ -909,28 +909,16 @@ You can build the documentation locally using `GNU Make`_ with :file:`doc/Makefile`. From within the :file:`doc` directory, you can test formatting and linking by running:: - $ make html_dev-noplot + $ make html-noplot This will build the documentation *except* it will format (but not execute) the tutorial and example files. If you have created or modified an example or tutorial, you should instead run -:samp:`PATTERN={} make html_dev-pattern` to render +:samp:`make html-pattern PATTERN={}` to render all the documentation and additionally execute just your example or tutorial (so you can make sure it runs successfully and generates the output / figures you expect). -.. note:: - If you are using a *Windows command shell*, to use the pattern approach, - use the following two lines: - - .. code-block:: doscon - - > set PATTERN= - > make html_dev-pattern - - If you are on Windows but using the `git BASH`_ shell, use the same two - commands but replace ``set`` with ``export``. - After either of these commands completes, ``make show`` will open the locally-rendered documentation site in your browser. 
If you see many warnings that seem unrelated to your contributions, it might be that your output folder
diff --git a/doc/sphinxext/gen_commands.py b/doc/sphinxext/gen_commands.py
index e5b2ed391b6..0339160b2bb 100644
--- a/doc/sphinxext/gen_commands.py
+++ b/doc/sphinxext/gen_commands.py
@@ -1,7 +1,7 @@
 import glob
 from importlib import import_module
 import os
-from os import path as op
+from pathlib import Path
 
 from mne.utils import _replace_md5, ArgvSetter
 
@@ -47,17 +47,20 @@ def setup_module():
 
 
 def generate_commands_rst(app=None):
-    from sphinx.util import status_iterator
-    out_dir = op.abspath(op.join(op.dirname(__file__), '..', 'generated'))
-    if not op.isdir(out_dir):
-        os.mkdir(out_dir)
-    out_fname = op.join(out_dir, 'commands.rst.new')
-
-    command_path = op.abspath(
-        op.join(os.path.dirname(__file__), '..', '..', 'mne', 'commands'))
-    fnames = sorted([
-        op.basename(fname)
-        for fname in glob.glob(op.join(command_path, 'mne_*.py'))])
+    try:
+        from sphinx.util.display import status_iterator
+    except Exception:
+        from sphinx.util import status_iterator
+    root = Path(__file__).parent.parent.parent.absolute()
+    out_dir = (root / 'doc' / 'generated').absolute()
+    out_dir.mkdir(exist_ok=True)
+    out_fname = out_dir / 'commands.rst.new'
+
+    command_path = root / 'mne' / 'commands'
+    fnames = sorted(
+        Path(fname).name
+        for fname in glob.glob(str(command_path / 'mne_*.py')))
+    assert len(fnames)
     iterator = status_iterator(
         fnames, 'generating MNE command help ... ', length=len(fnames))
     with open(out_fname, 'w', encoding='utf8') as f:
@@ -97,7 +100,7 @@ def generate_commands_rst(app=None):
             cmd_name_space = cmd_name.replace('mne_', 'mne ')
             f.write(command_rst.format(
                 cmd_name_space, '=' * len(cmd_name_space), output))
-    _replace_md5(out_fname)
+    _replace_md5(str(out_fname))
 
 
 # This is useful for testing/iterating to see what the result looks like
diff --git a/environment.yml b/environment.yml
index 960cbf8e775..37ae4e6582e 100644
--- a/environment.yml
+++ b/environment.yml
@@ -27,6 +27,7 @@ dependencies:
 - numexpr
 - imageio
 - spyder-kernels>=1.10.0
+- imageio>=2.6.1
 - imageio-ffmpeg>=0.4.1
 - vtk>=9.2
 - traitlets
diff --git a/requirements_testing_extra.txt b/requirements_testing_extra.txt
index 6154bae3f11..09126e5f412 100644
--- a/requirements_testing_extra.txt
+++ b/requirements_testing_extra.txt
@@ -5,4 +5,5 @@ sphinx-gallery
 eeglabio
 EDFlib-Python
 pybv
-imageio-ffmpeg
+imageio>=2.6.1
+imageio-ffmpeg>=0.4.1
diff --git a/tools/circleci_download.sh b/tools/circleci_download.sh
index cb622cb1860..6a411ef8ff1 100755
--- a/tools/circleci_download.sh
+++ b/tools/circleci_download.sh
@@ -3,13 +3,9 @@ set -o pipefail
 
 export MNE_TQDM=off
 
-if [ "$CIRCLE_BRANCH" == "main" ] || [[ $(cat gitlog.txt) == *"[circle full]"* ]]; then
-  echo "Doing a full dev build";
-  echo html_dev-memory > build.txt;
-  python -c "import mne; mne.datasets._download_all_example_data()";
-elif [ "$CIRCLE_BRANCH" == "maint/1.3" ]; then
-  echo "Doing a full stable build";
-  echo html_stable-memory > build.txt;
+if [ "$CIRCLE_BRANCH" == "main" ] || [[ $(cat gitlog.txt) == *"[circle full]"* ]] || [[ "$CIRCLE_BRANCH" == "maint/"* ]]; then
+  echo "Doing a full build";
+  echo html-memory > build.txt;
   python -c "import mne; mne.datasets._download_all_example_data()";
 else
   echo "Doing a partial build";
@@ -119,9 +115,9 @@ else
       echo PATTERN="$PATTERN";
       if [[ $PATTERN ]]; then
         PATTERN="\(${PATTERN::-2}\)";
-        echo html_dev-pattern-memory > build.txt;
+        echo html-pattern-memory > build.txt;
       else
-        echo html_dev-noplot > build.txt;
+        echo
html-noplot > build.txt; fi; fi; echo "$PATTERN" > pattern.txt; diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index 0391ef59df0..ab67af8794a 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -22,7 +22,10 @@ else # pip install $STD_ARGS --pre --only-binary ":all:" --no-deps --extra-index-url https://www.riverbankcomputing.com/pypi/simple PyQt6 pip install $STD_ARGS --pre --only-binary ":all:" PyQt6 echo "NumPy/SciPy/pandas etc." - pip install $STD_ARGS --pre --only-binary ":all:" --default-timeout=60 --extra-index-url "https://pypi.anaconda.org/scipy-wheels-nightly/simple" numpy scipy scikit-learn dipy pandas matplotlib pillow statsmodels + pip install $STD_ARGS --pre --only-binary ":all:" --default-timeout=60 --extra-index-url "https://pypi.anaconda.org/scipy-wheels-nightly/simple" numpy + # SciPy<->sklearn problematic, see https://github.com/scipy/scipy/issues/18377 + pip install $STD_ARGS --pre --only-binary ":all:" scipy + pip install $STD_ARGS --pre --only-binary ":all:" --default-timeout=60 --extra-index-url "https://pypi.anaconda.org/scipy-wheels-nightly/simple" scikit-learn dipy pandas matplotlib pillow statsmodels pip install $STD_ARGS --pre --only-binary ":all:" -f "https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com/" h5py # No Numba because it forces an old NumPy version echo "nilearn and openmeeg" From 8c7f0c61f37a150712251cc32200152b6b4bd37a Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Thu, 27 Apr 2023 12:47:51 -0700 Subject: [PATCH 0034/1125] [MAINT, MRG] Move over tutorials (#11646) Co-authored-by: Eric Larson --- doc/Makefile | 8 +- doc/_static/style.css | 12 + doc/conf.py | 3 +- examples/preprocessing/README.txt | 23 + examples/preprocessing/locate_ieeg_micro.py | 94 ---- tutorials/clinical/10_ieeg_localize.py | 544 -------------------- tutorials/clinical/README.txt | 35 +- 7 files changed, 78 insertions(+), 641 deletions(-) delete mode 100644 examples/preprocessing/locate_ieeg_micro.py delete mode 100644 tutorials/clinical/10_ieeg_localize.py diff --git a/doc/Makefile b/doc/Makefile index a8b50b9908b..f32527537d5 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -45,6 +45,11 @@ html-pattern: @echo @echo "Build finished. The HTML pages are in _build/html" +html-pattern-memory: + $(MPROF) -D sphinx_gallery_conf.filename_pattern=$(PATTERN) -D sphinx_gallery_conf.run_stale_examples=True -b html $(ALLSPHINXOPTS) _build/html + @echo + @echo "Build finished. 
The HTML pages are in _build/html" + html-noplot: $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) _build/html @echo @@ -54,7 +59,8 @@ html-front: @PATTERN="\(30_mne_dspm_loreta.py\|50_decoding.py\|30_strf.py\|20_cluster_1samp_spatiotemporal.py\|20_visualize_evoked.py\)" make html-pattern # Aliases for old methods -html_dev-pattern-memory: html-pattern +html_dev-pattern-memory: html-pattern-memory +html_dev-pattern: html-pattern html_dev-noplot: html-noplot html_dev-front: html-front diff --git a/doc/_static/style.css b/doc/_static/style.css index 1d51176c383..81df2a5637e 100644 --- a/doc/_static/style.css +++ b/doc/_static/style.css @@ -151,6 +151,18 @@ iframe.sg_report { display: none; } +/* Make our external thumbnails (e.g., mne-gui-addons) act like standard SG ones */ +.sphx-glr-thumbcontainer a.external { + bottom: 0; + display: block; + left: 0; + box-sizing: border-box; + padding: 150px 10px 0; + position: absolute; + right: 0; + top: 0; +} + /* ***************************************************** sphinx-design fixes */ p.btn a { color: unset; diff --git a/doc/conf.py b/doc/conf.py index d26ef8d269b..86b8967634d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -493,7 +493,8 @@ def __call__(self, gallery_conf, fname, when): '.*fit_transform()|.*get_params()|.*predict()|' '.*predict_proba()|.*set_params()|.*transform()|' # I/O, also related to mixins - '.*.remove.*|.*.write.*)') + '.*.remove.*|.*.write.*)'), + 'copyfile_regex': r'.*index\.rst', # allow custom index.rst files } # Files were renamed from plot_* with: # find . -type f -name 'plot_*.py' -exec sh -c 'x="{}"; xn=`basename "${x}"`; git mv "$x" `dirname "${x}"`/${xn:5}' \; # noqa diff --git a/examples/preprocessing/README.txt b/examples/preprocessing/README.txt index ff3ee6ed6c2..bd197e073b5 100644 --- a/examples/preprocessing/README.txt +++ b/examples/preprocessing/README.txt @@ -4,3 +4,26 @@ Preprocessing Examples related to data preprocessing (artifact detection / rejection etc.) +.. raw:: html + +
+    <div class="sphx-glr-thumbnails">
+
+.. raw:: html
+
+    <div class="sphx-glr-thumbcontainer" tooltip="Locating micro-scale intracranial electrode contacts">
+
+.. only:: html
+
+  .. image:: https://mne.tools/mne-gui-addons/_images/sphx_glr_locate_ieeg_micro_001.png
+    :alt:
+
+  :ref:`ex-ieeg-micro`
+
+.. raw:: html
+
+      <div class="sphx-glr-thumbnail-title">Locating micro-scale intracranial electrode contacts</div>
+    </div>
+
+.. raw:: html
+
+    </div>
\ No newline at end of file diff --git a/examples/preprocessing/locate_ieeg_micro.py b/examples/preprocessing/locate_ieeg_micro.py deleted file mode 100644 index 20142af7177..00000000000 --- a/examples/preprocessing/locate_ieeg_micro.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -.. _ex-ieeg-micro: - -==================================================== -Locating micro-scale intracranial electrode contacts -==================================================== - -When intracranial electrode contacts are very small, sometimes -the computed tomography (CT) scan is higher resolution than the -magnetic resonance (MR) image and so you want to find the contacts -on the CT without downsampling to the MR resolution. This example -shows how to do this. -""" - -# Authors: Alex Rockhill -# -# License: BSD-3-Clause - -import numpy as np -import nibabel as nib -import mne -import mne_gui_addons - -# path to sample sEEG -misc_path = mne.datasets.misc.data_path() -subjects_dir = misc_path / 'seeg' - -# GUI requires pyvista backend -mne.viz.set_3d_backend('pyvistaqt') - -# we need three things: -# 1) The electrophysiology file which contains the channels names -# that we would like to associate with positions in the brain -# 2) The CT where the electrode contacts show up with high intensity -# 3) The MR where the brain is best visible (low contrast in CT) -raw = mne.io.read_raw(misc_path / 'seeg' / 'sample_seeg_ieeg.fif') -CT_orig = nib.load(misc_path / 'seeg' / 'sample_seeg_CT.mgz') -T1 = nib.load(misc_path / 'seeg' / 'sample_seeg' / 'mri' / 'T1.mgz') - -# we'll also need a head-CT surface RAS transform, this can be faked with an -# identify matrix but we'll find the fiducials on the CT in freeview (be sure -# to find them in surface RAS (TkReg RAS in freeview) and not scanner RAS -# (RAS in freeview)) (also be sure to note left is generally on the right in -# freeview) and reproduce them here: -montage = mne.channels.make_dig_montage( - nasion=[-28.97, -5.88, -76.40], lpa=[-96.35, -16.26, 17.63], - rpa=[31.28, -52.95, -0.69], coord_frame='mri') -raw.set_montage(montage, on_missing='ignore') # haven't located yet! 
-head_ct_t = mne.transforms.invert_transform( - mne.channels.compute_native_head_t(montage)) - -# note: coord_frame = 'mri' is a bit of a misnormer, it is a reference to -# the surface RAS coordinate frame, here it is of the CT - - -# launch the viewer with only the CT (note, we won't be able to use -# the MR in this case to help determine which brain area the contact is -# in), and use the user interface to find the locations of the contacts -gui = mne_gui_addons.locate_ieeg(raw.info, head_ct_t, CT_orig) - -# we'll programmatically mark all the contacts on one electrode shaft -for i, pos in enumerate([(-52.66, -40.84, -26.99), (-55.47, -38.03, -27.92), - (-57.68, -36.27, -28.85), (-59.89, -33.81, -29.32), - (-62.57, -31.35, -30.37), (-65.13, -29.07, -31.30), - (-67.57, -26.26, -31.88)]): - gui.set_RAS(pos) - gui.mark_channel(f'LENT {i + 1}') - -# finally, the coordinates will be in "head" (unless the trans was faked -# as the identity, in which case they will be in surface RAS of the CT already) -# so we need to convert them to scanner RAS of the CT, apply the alignment so -# that they are in scanner RAS of the MRI and from there to surface RAS -# of the MRI for viewing using freesurfer recon-all surfaces--fortunately -# that is done for us in `mne.transforms.apply_volume_registration_points` - -# note that since we didn't fake the head->CT surface RAS transform, we -# could apply the head->mri transform directly but that relies of the -# fiducial points being marked exactly the same on the CT as on the MRI-- -# the error from this is not precise enough for intracranial electrophysiology, -# better is to rely on the precision of the CT-MR image registration - -reg_affine = np.array([ # CT-MR registration - [0.99270756, -0.03243313, 0.11610254, -133.094156], - [0.04374389, 0.99439665, -0.09623816, -97.58320673], - [-0.11233068, 0.10061512, 0.98856381, -84.45551601], - [0., 0., 0., 1.]]) - -raw.info, head_mri_t = mne.transforms.apply_volume_registration_points( - raw.info, head_ct_t, CT_orig, T1, reg_affine) - -brain = mne.viz.Brain(subject='sample_seeg', subjects_dir=subjects_dir, - alpha=0.5) -brain.add_sensors(raw.info, head_mri_t) -brain.show_view(azimuth=120, elevation=100) diff --git a/tutorials/clinical/10_ieeg_localize.py b/tutorials/clinical/10_ieeg_localize.py deleted file mode 100644 index b256c903553..00000000000 --- a/tutorials/clinical/10_ieeg_localize.py +++ /dev/null @@ -1,544 +0,0 @@ -""" -.. _tut-ieeg-localize: - -======================================== -Locating intracranial electrode contacts -======================================== - -Analysis of intracranial electrophysiology recordings typically involves -finding the position of each contact relative to brain structures. In a -typical setup, the brain and the electrode locations will be in two places -and will have to be aligned; the brain is best visualized by a -pre-implantation magnetic resonance (MR) image whereas the electrode contact -locations are best visualized in a post-implantation computed tomography (CT) -image. The CT image has greater intensity than the background at each of the -electrode contacts and for the skull. Using the skull, the CT can be aligned -to MR-space. This accomplishes our goal of obtaining contact locations in -MR-space (which is where the brain structures are best determined using the -:ref:`tut-freesurfer-reconstruction`). Contact locations in MR-space can also -be warped to a template space such as ``fsaverage`` for group comparisons. 
-Please note that this tutorial requires ``nibabel``, ``nilearn`` and ``dipy`` -which can be installed using ``pip`` as well as 3D plotting -(see :ref:`manual-install`). - -Support for intracranial electrophysiology analysis in MNE was added after -the original publication, so please cite :footcite:`RockhillEtAl2022` if you -use this module in your analysis to support the addition of new projects to -MNE. -""" -# Authors: Alex Rockhill -# Eric Larson -# -# License: BSD-3-Clause - -# %% -import numpy as np -import matplotlib.pyplot as plt - -import nibabel as nib -import nilearn.plotting -from dipy.align import resample - -import mne -import mne_gui_addons as mne_gui -from mne.datasets import fetch_fsaverage - -# paths to mne datasets: sample sEEG and FreeSurfer's fsaverage subject, -# which is in MNI space -misc_path = mne.datasets.misc.data_path() -sample_path = mne.datasets.sample.data_path() -subjects_dir = sample_path / 'subjects' - -# use mne-python's fsaverage data -fetch_fsaverage(subjects_dir=subjects_dir, verbose=True) # downloads if needed - -# GUI requires pyvista backend -mne.viz.set_3d_backend('pyvistaqt') - -############################################################################### -# Aligning the T1 to ACPC -# ======================= -# -# For intracranial electrophysiology recordings, the Brain Imaging Data -# Structure (BIDS) standard requires that coordinates be aligned to the -# anterior commissure and posterior commissure (ACPC-aligned). Therefore, it is -# recommended that you do this alignment before finding the positions of the -# channels in your recording. Doing this will make the "mri" (aka surface RAS) -# coordinate frame an ACPC coordinate frame. This can be done using -# Freesurfer's freeview: -# -# .. code-block:: console -# -# $ freeview $MISC_PATH/seeg/sample_seeg_T1.mgz -# -# And then interact with the graphical user interface: -# -# First, it is recommended to change the cursor style to long, this can be done -# through the menu options like so: -# -# :menuselection:`Freeview --> Preferences --> General --> Cursor style -# --> Long` -# -# Then, the image needs to be aligned to ACPC to look like the image below. -# This can be done by pulling up the transform popup from the menu like so: -# -# :menuselection:`Tools --> Transform Volume` -# -# .. note:: -# Be sure to set the text entry box labeled RAS (not TkReg RAS) to -# ``0 0 0`` before beginning the transform. -# -# Then translate the image until the crosshairs meet on the AC and -# run through the PC as shown in the plot. The eyes should be in -# the ACPC plane and the image should be rotated until they are symmetrical, -# and the crosshairs should transect the midline of the brain. -# Be sure to use both the rotate and the translate menus and save the volume -# after you're finished using ``Save Volume As`` in the transform popup -# :footcite:`HamiltonEtAl2017`. - -T1 = nib.load(misc_path / 'seeg' / 'sample_seeg' / 'mri' / 'T1.mgz') -viewer = T1.orthoview() -viewer.set_position(0, 9.9, 5.8) -viewer.figs[0].axes[0].annotate( - 'PC', (107, 108), xytext=(10, 75), color='white', - horizontalalignment='center', - arrowprops=dict(facecolor='white', lw=0.5, width=2, headwidth=5)) -viewer.figs[0].axes[0].annotate( - 'AC', (137, 108), xytext=(246, 75), color='white', - horizontalalignment='center', - arrowprops=dict(facecolor='white', lw=0.5, width=2, headwidth=5)) - -# %% -# Freesurfer recon-all -# ==================== -# -# The first step is the most time consuming; the freesurfer reconstruction. 
-# This process segments out the brain from the rest of the MR image and -# determines which voxels correspond to each brain area based on a template -# deformation. This process takes approximately 8 hours so plan accordingly. -# The example dataset contains the data from completed reconstruction so -# we will proceed using that. -# -# .. code-block:: console -# -# $ export SUBJECT=sample_seeg -# $ export SUBJECTS_DIR=$MY_DATA_DIRECTORY -# $ recon-all -subjid $SUBJECT -sd $SUBJECTS_DIR \ -# -i $MISC_PATH/seeg/sample_seeg_T1.mgz -all -deface -# -# .. note:: -# You may need to include an additional ``-cw256`` flag which can be added -# to the end of the recon-all command if your MR scan is not -# ``256 × 256 × 256`` voxels. -# -# .. note:: -# Using the ``-deface`` flag will create a defaced, anonymized T1 image -# located at ``$MY_DATA_DIRECTORY/$SUBJECT/mri/orig_defaced.mgz``, -# which is helpful for when you publish your data. You can also use -# :func:`mne_bids.write_anat` and pass ``deface=True``. - - -# %% -# Aligning the CT to the MR -# ========================= -# -# Let's load our T1 and CT images and visualize them. You can hardly -# see the CT, it's so misaligned that all you can see is part of the -# stereotactic frame that is anteriolateral to the skull in the middle plot. -# Clearly, we need to align the CT to the T1 image. - -def plot_overlay(image, compare, title, thresh=None): - """Define a helper function for comparing plots.""" - image = nib.orientations.apply_orientation( - np.asarray(image.dataobj), nib.orientations.axcodes2ornt( - nib.orientations.aff2axcodes(image.affine))).astype(np.float32) - compare = nib.orientations.apply_orientation( - np.asarray(compare.dataobj), nib.orientations.axcodes2ornt( - nib.orientations.aff2axcodes(compare.affine))).astype(np.float32) - if thresh is not None: - compare[compare < np.quantile(compare, thresh)] = np.nan - fig, axes = plt.subplots(1, 3, figsize=(12, 4)) - fig.suptitle(title) - for i, ax in enumerate(axes): - ax.imshow(np.take(image, [image.shape[i] // 2], axis=i).squeeze().T, - cmap='gray') - ax.imshow(np.take(compare, [compare.shape[i] // 2], - axis=i).squeeze().T, cmap='gist_heat', alpha=0.5) - ax.invert_yaxis() - ax.axis('off') - fig.tight_layout() - - -CT_orig = nib.load(misc_path / 'seeg' / 'sample_seeg_CT.mgz') - -# resample to T1's definition of world coordinates -CT_resampled = resample(moving=np.asarray(CT_orig.dataobj), - static=np.asarray(T1.dataobj), - moving_affine=CT_orig.affine, - static_affine=T1.affine) -plot_overlay(T1, CT_resampled, 'Unaligned CT Overlaid on T1', thresh=0.95) -del CT_resampled - -# %% -# Now we need to align our CT image to the T1 image. -# -# We want this to be a rigid transformation (just rotation + translation), -# so we don't do a full affine registration (that includes shear) here. 
-# This takes a while (~10 minutes) to execute so we skip actually running it -# here:: -# -# reg_affine, _ = mne.transforms.compute_volume_registration( -# CT_orig, T1, pipeline='rigids') -# -# Instead we just hard-code the resulting 4x4 matrix: - -reg_affine = np.array([ - [0.99270756, -0.03243313, 0.11610254, -133.094156], - [0.04374389, 0.99439665, -0.09623816, -97.58320673], - [-0.11233068, 0.10061512, 0.98856381, -84.45551601], - [0., 0., 0., 1.]]) - -# use a cval='1%' here to make the values outside the domain of the CT -# the same as the background level during interpolation -CT_aligned = mne.transforms.apply_volume_registration( - CT_orig, T1, reg_affine, cval='1%') -plot_overlay(T1, CT_aligned, 'Aligned CT Overlaid on T1', thresh=0.95) -del CT_orig - -# %% -# .. note:: -# Alignment failures sometimes occur which requires manual pre-alignment. -# Freesurfer's ``freeview`` can be used to to align manually -# -# .. code-block:: console -# -# $ freeview $MISC_PATH/seeg/sample_seeg/mri/T1.mgz \ -# $MISC_PATH/seeg/sample_seeg_CT.mgz:colormap=heat:opacity=0.6 -# -# - Navigate to the upper toolbar, go to -# :menuselection:`Tools --> Transform Volume` -# - Use the rotation and translation slide bars to align the CT -# to the MR (be sure to have the CT selected in the upper left menu) -# - Save the linear transform array (lta) file using the ``Save Reg...`` -# button -# -# Since we really require as much precision as possible for the -# alignment, we should rerun the algorithm starting with the manual -# alignment. This time, we just want to skip to the most exact rigid -# alignment, without smoothing, since the manual alignment is already -# very close. -# -# .. code-block:: python -# -# # load transform -# manual_reg_affine_vox = mne.read_lta(op.join( # the path used above -# misc_path, 'seeg', 'sample_seeg_CT_aligned_manual.mgz.lta')) -# # convert from vox->vox to ras->ras -# manual_reg_affine = \ -# CT_orig.affine @ np.linalg.inv(manual_reg_affine_vox) \ -# @ np.linalg.inv(CT_orig.affine) -# reg_affine, _ = mne.transforms.compute_volume_registration( -# CT_orig, T1, pipeline=['rigid'], -# starting_affine=manual_reg_affine) -# CT_aligned = mne.transforms.apply_volume_registration( -# CT_orig, T1, reg_affine, cval='1%') -# -# The rest of the tutorial can then be completed using ``CT_aligned`` -# from this point on. - -# %% -# We can now see how the CT image looks properly aligned to the T1 image. -# -# .. note:: -# The hyperintense skull is actually aligned to the hypointensity between -# the brain and the scalp. The brighter area surrounding the skull in the -# MR is actually subcutaneous fat. 
- -# make low intensity parts of the CT transparent for easier visualization -CT_data = CT_aligned.get_fdata().copy() -CT_data[CT_data < np.quantile(CT_data, 0.95)] = np.nan -T1_data = np.asarray(T1.dataobj) - -fig, axes = plt.subplots(1, 3, figsize=(12, 6)) -for ax in axes: - ax.axis('off') -axes[0].imshow(T1_data[T1.shape[0] // 2], cmap='gray') -axes[0].set_title('MR') -axes[1].imshow(np.asarray(CT_aligned.dataobj)[CT_aligned.shape[0] // 2], - cmap='gray') -axes[1].set_title('CT') -axes[2].imshow(T1_data[T1.shape[0] // 2], cmap='gray') -axes[2].imshow(CT_data[CT_aligned.shape[0] // 2], cmap='gist_heat', alpha=0.5) -for ax in (axes[0], axes[2]): - ax.annotate('Subcutaneous fat', (110, 52), xytext=(100, 30), - color='white', horizontalalignment='center', - arrowprops=dict(facecolor='white')) -for ax in axes: - ax.annotate('Skull (dark in MR, bright in CT)', (40, 175), - xytext=(120, 246), horizontalalignment='center', - color='white', arrowprops=dict(facecolor='white')) -axes[2].set_title('CT aligned to MR') -fig.tight_layout() -del CT_data, T1 - -# %% -# Now we need to estimate the "head" coordinate transform. -# -# MNE stores digitization montages in a coordinate frame called "head" -# defined by fiducial points (origin is halfway between the LPA and RPA -# see :ref:`tut-source-alignment`). For sEEG, it is convenient to get an -# estimate of the location of the fiducial points for the subject -# using the Talairach transform (see :func:`mne.coreg.get_mni_fiducials`) -# to use to define the coordinate frame so that we don't have to manually -# identify their location. - -# estimate head->mri transform -subj_trans = mne.coreg.estimate_head_mri_t( - 'sample_seeg', misc_path / 'seeg') - -# %% -# Marking the Location of Each Electrode Contact -# ============================================== -# -# Now, the CT and the MR are in the same space, so when you are looking at a -# point in CT space, it is the same point in MR space. So now everything is -# ready to determine the location of each electrode contact in the -# individual subject's anatomical space (T1-space). To do this, we can use the -# MNE intracranial electrode location graphical user interface. -# -# .. note:: The most useful coordinate frame for intracranial electrodes is -# generally the ``surface RAS`` coordinate frame because that is -# the coordinate frame that all the surface and image files that -# Freesurfer outputs are in, see :ref:`tut-freesurfer-mne`. These are -# useful for finding the brain structures nearby each contact and -# plotting the results. -# -# See the following video on how to operate the GUI or follow the steps below: -# -# .. youtube:: 8JWDJhXq0VY -# -# - Click in each image to navigate to each electrode contact -# - Select the contact name in the right panel -# - Press the "Mark" button or the "m" key to associate that -# position with that contact -# - Repeat until each contact is marked, they will both appear as circles -# in the plots and be colored in the sidebar when marked -# -# .. note:: The channel locations are saved to the ``raw`` object every time -# a location is marked or removed so there is no "Save" button. -# -# .. note:: Using the scroll or +/- arrow keys you can zoom in and out, -# and the up/down, left/right and page up/page down keys allow -# you to move one slice in any direction. This information is -# available in the help menu, accessible by pressing the "h" key. -# -# .. note:: If "Snap to Center" is on, this will use the radius so be -# sure to set it properly. 
- -# sphinx_gallery_thumbnail_number = 5 - -# load electrophysiology data to find channel locations for -# (the channels are already located in the example) - -raw = mne.io.read_raw(misc_path / 'seeg' / 'sample_seeg_ieeg.fif') - -# you may want to add `block=True` to halt execution until you have interacted -# with the GUI to find the channel positions, that way the raw object can -# be used later in the script (e.g. saved with channel positions) -mne_gui.locate_ieeg(raw.info, subj_trans, CT_aligned, - subject='sample_seeg', - subjects_dir=misc_path / 'seeg') -# The `raw` object is modified to contain the channel locations - -# %% -# Let's do a quick sidebar and show what this looks like for ECoG as well. - -T1_ecog = nib.load(misc_path / 'ecog' / 'sample_ecog' / 'mri' / 'T1.mgz') -CT_orig_ecog = nib.load(misc_path / 'ecog' / 'sample_ecog_CT.mgz') - -# pre-computed affine from `mne.transforms.compute_volume_registration` -reg_affine = np.array([ - [0.99982382, -0.00414586, -0.01830679, 0.15413965], - [0.00549597, 0.99721885, 0.07432601, -1.54316131], - [0.01794773, -0.07441352, 0.99706595, -1.84162514], - [0., 0., 0., 1.]]) -# align CT -CT_aligned_ecog = mne.transforms.apply_volume_registration( - CT_orig_ecog, T1_ecog, reg_affine, cval='1%') - -raw_ecog = mne.io.read_raw(misc_path / 'ecog' / 'sample_ecog_ieeg.fif') -# use estimated `trans` which was used when the locations were found previously -subj_trans_ecog = mne.coreg.estimate_head_mri_t( - 'sample_ecog', misc_path / 'ecog') -mne_gui.locate_ieeg(raw_ecog.info, subj_trans_ecog, CT_aligned_ecog, - subject='sample_ecog', - subjects_dir=misc_path / 'ecog') - -# %% -# For ECoG, we typically want to account for "brain shift" or shrinking of the -# brain away from the skull/dura due to changes in pressure during the -# craniotomy -# Note: this requires the BEM surfaces to have been computed e.g. using -# :ref:`mne watershed_bem` or :ref:`mne flash_bem`. -# First, let's plot the localized sensor positions without modification. - -# plot projected sensors -brain_kwargs = dict(cortex='low_contrast', alpha=0.2, background='white') -brain = mne.viz.Brain('sample_ecog', subjects_dir=misc_path / 'ecog', - title='Before Projection', **brain_kwargs) -brain.add_sensors(raw_ecog.info, trans=subj_trans_ecog) -view_kwargs = dict(azimuth=60, elevation=100, distance=350, - focalpoint=(0, 0, -15)) -brain.show_view(**view_kwargs) - -# %% -# Now, let's project the sensors to the brain surface and re-plot them. - -# project sensors to the brain surface -raw_ecog.info = mne.preprocessing.ieeg.project_sensors_onto_brain( - raw_ecog.info, subj_trans_ecog, 'sample_ecog', - subjects_dir=misc_path / 'ecog') - -# plot projected sensors -brain = mne.viz.Brain('sample_ecog', subjects_dir=misc_path / 'ecog', - title='After Projection', **brain_kwargs) -brain.add_sensors(raw_ecog.info, trans=subj_trans_ecog) -brain.show_view(**view_kwargs) - -# %% -# Let's plot the electrode contact locations on the subject's brain. -# -# MNE stores digitization montages in a coordinate frame called "head" -# defined by fiducial points (origin is halfway between the LPA and RPA -# see :ref:`tut-source-alignment`). For sEEG, it is convenient to get an -# estimate of the location of the fiducial points for the subject -# using the Talairach transform (see :func:`mne.coreg.get_mni_fiducials`) -# to use to define the coordinate frame so that we don't have to manually -# identify their location. 
The estimated head->mri ``trans`` was used -# when the electrode contacts were localized so we need to use it again here. - -# plot the alignment -brain = mne.viz.Brain('sample_seeg', subjects_dir=misc_path / 'seeg', - **brain_kwargs) -brain.add_sensors(raw.info, trans=subj_trans) -brain.show_view(**view_kwargs) - -# %% -# Warping to a Common Atlas -# ========================= -# -# Electrode contact locations are often compared across subjects in a template -# space such as ``fsaverage`` or ``cvs_avg35_inMNI152``. To transform electrode -# contact locations to that space, we need to determine a function that maps -# from the subject's brain to the template brain. We will use the symmetric -# diffeomorphic registration (SDR) implemented by ``Dipy`` to do this. -# -# Before we can make a function to account for individual differences in the -# shape and size of brain areas, we need to fix the alignment of the brains. -# The plot below shows that they are not yet aligned. - -# load the subject's brain and the Freesurfer "fsaverage" template brain -subject_brain = nib.load( - misc_path / 'seeg' / 'sample_seeg' / 'mri' / 'brain.mgz') -template_brain = nib.load( - subjects_dir / 'fsaverage' / 'mri' / 'brain.mgz') - -plot_overlay(template_brain, subject_brain, - 'Alignment with fsaverage before Affine Registration') - -# %% -# Now, we'll register the affine of the subject's brain to the template brain. -# This aligns the two brains, preparing the subject's brain to be warped -# to the template. -# -# .. warning:: Here we use custom ``zooms`` just for speed (this downsamples -# the image resolution), in general we recommend using -# ``zooms=None`` (default) for highest accuracy! - -zooms = dict(translation=10, rigid=10, affine=10, sdr=5) -reg_affine, sdr_morph = mne.transforms.compute_volume_registration( - subject_brain, template_brain, zooms=zooms, verbose=True) -subject_brain_sdr = mne.transforms.apply_volume_registration( - subject_brain, template_brain, reg_affine, sdr_morph) - -# apply the transform to the subject brain to plot it -plot_overlay(template_brain, subject_brain_sdr, - 'Alignment with fsaverage after SDR Registration') - -# %% -# Finally, we'll apply the registrations to the electrode contact coordinates. -# The brain image is warped to the template but the goal was to warp the -# positions of the electrode contacts. To do that, we'll make an image that is -# a lookup table of the electrode contacts. In this image, the background will -# be ``0`` s all the bright voxels near the location of the first contact will -# be ``1`` s, the second ``2`` s and so on. This image can then be warped by -# the SDR transform. We can finally recover a position by averaging the -# positions of all the voxels that had the contact's lookup number in -# the warped image. 
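To make that recovery step concrete, a minimal standalone sketch (here
``elec_data`` and ``affine`` are hypothetical stand-ins for the warped lookup
volume's data and its voxel-to-RAS affine, not objects from the code above)::

    import numpy as np
    import mne

    elec_data = np.zeros((5, 5, 5))
    elec_data[2, 2, 2] = elec_data[2, 2, 3] = 1  # fake voxels for contact 1
    affine = np.eye(4)  # hypothetical voxel -> RAS affine
    vox = np.array(np.where(elec_data == 1), float).mean(axis=1)
    pos = mne.transforms.apply_trans(affine, vox)  # mean voxel -> RAS position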
- -# first we need our montage but it needs to be converted to "mri" coordinates -# using our ``subj_trans`` -montage = raw.get_montage() -montage.apply_trans(subj_trans) - -# warp the montage -montage_warped = mne.preprocessing.ieeg.warp_montage( - montage, subject_brain, template_brain, reg_affine, sdr_morph) - -# visualize using an image of the electrode contacts to see their sizes -elec_image = mne.preprocessing.ieeg.make_montage_volume( - montage, CT_aligned, thresh=0.25) - -# warp image using transforms -warped_elec_image = mne.transforms.apply_volume_registration( - elec_image, template_brain, reg_affine, sdr_morph, - interpolation='nearest') - -fig, axes = plt.subplots(2, 1, figsize=(8, 8)) -nilearn.plotting.plot_glass_brain(elec_image, axes=axes[0], cmap='Dark2') -fig.text(0.1, 0.65, 'Subject T1', rotation='vertical') -nilearn.plotting.plot_glass_brain(warped_elec_image, axes=axes[1], - cmap='Dark2') -fig.text(0.1, 0.25, 'fsaverage', rotation='vertical') -fig.suptitle('Electrodes warped to fsaverage') - -del CT_aligned, subject_brain, template_brain - -# %% -# We can now plot the result. You can compare this to the plot in -# :ref:`tut-working-with-seeg` to see the difference between this morph, which -# is more complex, and the less-complex, linear Talairach transformation. -# By accounting for the shape of this particular subject's brain using the -# SDR to warp the positions of the electrode contacts, the position in the -# template brain is able to be more accurately estimated. -# -# .. note:: The accuracy of warping to the template has been degraded by -# using ``zooms`` to downsample the image before registration -# which makes some of the contacts inaccurately appear outside -# the brain. - -# first we need to add fiducials so that we can define the "head" coordinate -# frame in terms of them (with the origin at the center between LPA and RPA) -montage_warped.add_estimated_fiducials('fsaverage', subjects_dir) - -# compute the head<->mri ``trans`` now using the fiducials -template_trans = mne.channels.compute_native_head_t(montage_warped) - -# now we can set the montage and, because there are fiducials in the montage, -# the montage will be properly transformed to "head" coordinates when we do -# (this step uses ``template_trans`` but it is recomputed behind the scenes) -raw.set_montage(montage_warped) - -# plot the resulting alignment -brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, **brain_kwargs) -brain.add_sensors(raw.info, trans=template_trans) -brain.show_view(**view_kwargs) - -# %% -# This pipeline was developed based on previous work -# :footcite:`HamiltonEtAl2017`. -# -# References -# ========== -# -# .. footbibliography:: diff --git a/tutorials/clinical/README.txt b/tutorials/clinical/README.txt index 43e5f701fa2..b0c11e7b344 100644 --- a/tutorials/clinical/README.txt +++ b/tutorials/clinical/README.txt @@ -1,4 +1,37 @@ Clinical applications --------------------- -These tutorials illustrate clinical uses of MNE-Python. +These tutorials illustrate some clinical use cases. + +MNE-GUI-addons examples +^^^^^^^^^^^^^^^^^^^^^^^ +The :mod:`mne-gui-addons:mne_gui_addons` package supports some clinical use cases: + +.. raw:: html + +
+    <div class="sphx-glr-thumbnails">
+
+.. raw:: html
+
+    <div class="sphx-glr-thumbcontainer" tooltip="Locating intracranial electrode contacts">
+
+.. only:: html
+
+    .. image:: https://mne.tools/mne-gui-addons/_images/sphx_glr_ieeg_locate_005.png
+      :alt:
+
+    :ref:`mne-gui-addons:tut-ieeg-localize`
+
+.. raw:: html
+
+      <div class="sphx-glr-thumbnail-title">Locating intracranial electrode contacts</div>
+    </div>
+
+.. raw:: html
+
+    </div>
+ +MNE-Python examples +^^^^^^^^^^^^^^^^^^^ + +MNE-Python also supports some clinical use cases directly: From 42b29db564564155a77352bae4303068f7d78935 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 28 Apr 2023 06:46:45 -0500 Subject: [PATCH 0035/1125] clarify relationships among FWHM, sigma, and n_cycles in Morlet function docstrings (#11658) --- mne/utils/docs.py | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index b07f06d68a6..4b200154c0b 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1435,16 +1435,34 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ docdict['fwhm_morlet_notes'] = r""" -In wavelet analysis, the oscillation that is defined by ``n_cycles`` is tapered -by a Gaussian taper, i.e., the edges of the wavelet are dampened. This means -that reporting the number of cycles is not necessarily helpful for -understanding the amount of temporal smoothing that has been applied (see -:footcite:`Cohen2019`). Instead, the full width at half-maximum (FWHM) of the -wavelet can be reported. - -The FWHM of the wavelet at a specific frequency is defined as: -:math:`\mathrm{FWHM} = \frac{\mathtt{n\_cycles} \times \sqrt{2 \ln{2}}}{\pi \times \mathtt{freq}}` -(cf. eq. 4 in :footcite:`Cohen2019`). +Convolution of a signal with a Morlet wavelet will impose temporal smoothing +that is determined by the duration of the wavelet. In MNE-Python, the duration +of the wavelet is determined by the ``sigma`` parameter, which gives the +standard deviation of the wavelet's Gaussian envelope (our wavelets extend to +±5 standard deviations to ensure values very close to zero at the endpoints). +Some authors (e.g., :footcite:`Cohen2019`) recommend specifying and reporting +wavelet duration in terms of the full-width half-maximum (FWHM) of the +wavelet's Gaussian envelope. The FWHM is related to ``sigma`` by the following +identity: :math:`\mathrm{FWHM} = \sigma \times 2 \sqrt{2 \ln{2}}` (or the +equivalent in Python code: ``fwhm = sigma * 2 * np.sqrt(2 * np.log(2))``). +If ``sigma`` is not provided, it is computed from ``n_cycles`` as +:math:`\frac{\mathtt{n\_cycles}}{2 \pi f}` where :math:`f` is the frequency of +the wavelet oscillation (given by ``freqs``). Thus when ``sigma=None`` the FWHM +will be given by + +.. math:: + + \mathrm{FWHM} = \frac{\mathtt{n\_cycles} \times \sqrt{2 \ln{2}}}{\pi \times f} + +(cf. eq. 4 in :footcite:`Cohen2019`). To create wavelets with a chosen FWHM, +one can compute:: + + n_cycles = desired_fwhm * np.pi * np.array(freqs) / np.sqrt(2 * np.log(2)) + +to get an array of values for ``n_cycles`` that yield the desired FWHM at each +frequency in ``freqs``. If you want different FWHM values at each frequency, +do the same computation with ``desired_fwhm`` as an array of the same shape as +``freqs``. 
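A quick numeric check of the identities above (a standalone sketch; the
frequencies and FWHM are made up for illustration)::

    import numpy as np

    freqs = np.array([10., 20.])  # Hz
    desired_fwhm = 0.25  # seconds
    n_cycles = desired_fwhm * np.pi * freqs / np.sqrt(2 * np.log(2))
    sigma = n_cycles / (2 * np.pi * freqs)  # implied Gaussian SD, in seconds
    fwhm = sigma * 2 * np.sqrt(2 * np.log(2))
    print(np.allclose(fwhm, desired_fwhm))  # True: the round trip is consistent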
""" # noqa E501 # %% From 32000d4983ba904a5660ec7c31818d881e1cfa61 Mon Sep 17 00:00:00 2001 From: Florin Pop Date: Fri, 28 Apr 2023 15:06:52 +0200 Subject: [PATCH 0036/1125] FIX: Set UseHighDpiPixmaps only for PyQt5 and PySide2 (#11662) --- doc/changes/latest.inc | 1 + doc/changes/names.inc | 2 ++ mne/utils/check.py | 11 ++++++----- mne/viz/backends/_utils.py | 10 ++++++---- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 5cefd8aa8a6..bc442f5cd3d 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -48,6 +48,7 @@ Enhancements Bugs ~~~~ +- Improving compatibility with Qt6 by removing the use of deprecated ``AA_UseHighDpiPixmaps`` attribute for this Qt version (:gh:`11662` by :newcontrib:`Florin Pop`) - Fix :func:`mne.time_frequency.psd_array_multitaper` docstring where argument ``bandwidth`` incorrectly reported argument as half-bandwidth and gave wrong explanation of default value (:gh:`11479` by :newcontrib: `Tom Stone`_) - Fix bug where installation of a package depending on ``mne`` will error when done in an environment where ``setuptools`` is not present (:gh:`11454` by :newcontrib: `Arne Pelzer`_) - Fix bug where :func:`mne.preprocessing.regress_artifact` and :class:`mne.preprocessing.EOGRegression` incorrectly tracked ``picks`` (:gh:`11366` by `Eric Larson`_) diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 7508543cc39..8130fefeed6 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -164,6 +164,8 @@ .. _Felix Raimundo: https://github.com/gamazeps +.. _Florin Pop: https://github.com/florin-pop + .. _Frederik Weber: https://github.com/Frederik-D-Weber .. _Fu-Te Wong: https://github.com/zuxfoucault diff --git a/mne/utils/check.py b/mne/utils/check.py index e351184680f..95c598cf908 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -859,7 +859,7 @@ def _check_stc_units(stc, threshold=1e-7): # 100 nAm threshold for warning % (1e9 * max_cur)) -def _check_qt_version(*, return_api=False): +def _check_qt_version(*, return_api=False, check_usable_display=True): """Check if Qt is installed.""" from ..viz.backends._utils import _init_mne_qtapp try: @@ -874,10 +874,11 @@ def _check_qt_version(*, return_api=False): # Having Qt installed is not enough -- sometimes the app is unusable # for example because there is no usable display (e.g., on a server), # so we have to try instantiating one to actually know. 
- try: - _init_mne_qtapp() - except Exception: - api = version = None + if check_usable_display: + try: + _init_mne_qtapp() + except Exception: + api = version = None if return_api: return version, api else: diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index 3f75a5b7f37..c47d75b26cc 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -124,6 +124,8 @@ def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False): from qtpy.QtCore import Qt from qtpy.QtGui import QIcon, QPixmap, QGuiApplication from qtpy.QtWidgets import QApplication, QSplashScreen + from ...fixes import _compare_version + from ...utils import _check_qt_version app_name = 'MNE-Python' organization_name = 'MNE' @@ -160,10 +162,10 @@ def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False): app = QApplication([app_name]) app.setApplicationName(app_name) app.setOrganizationName(organization_name) - try: - app.setAttribute(Qt.AA_UseHighDpiPixmaps) # works on PyQt5 and PySide2 - except AttributeError: - pass # not required on PyQt6 and PySide6 anyway + qt_version = _check_qt_version(check_usable_display=False) + # HiDPI is enabled by default in Qt6, requires to be explicitly set for Qt5 + if _compare_version(qt_version, '<', '6.0'): + app.setAttribute(Qt.AA_UseHighDpiPixmaps) if enable_icon or splash: icons_path = _qt_init_icons() From 18b80903da84483325584d8dffe7c9f847f12b81 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 28 Apr 2023 10:33:47 -0400 Subject: [PATCH 0037/1125] DOC: Fix table formatting (#11663) --- .circleci/config.yml | 6 +++--- doc/_static/style.css | 19 +++++++++++++++---- .../10_preprocessing_overview.py | 1 - 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 554f27fc55b..067b971dfc7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -19,9 +19,9 @@ jobs: default: "false" docker: - image: cimg/base:current-22.04 - # medium 2 vCPUs, 4GB mem; medium+ 3vCPUs 6GB mem; large 4 vCPUs 8GB mem - # https://circleci.com/docs/configuration-reference#resourceclass - resource_class: medium+ + # large 4 vCPUs 15GB mem + # https://discuss.circleci.com/t/changes-to-remote-docker-reporting-pricing/47759 + resource_class: large steps: - restore_cache: keys: diff --git a/doc/_static/style.css b/doc/_static/style.css index 81df2a5637e..b2648b80d6c 100644 --- a/doc/_static/style.css +++ b/doc/_static/style.css @@ -9,7 +9,6 @@ --mne-color-discord: #5865F2; --mne-color-twitter: #55acee; --mne-color-primary: #007bff; - --mne-color-primary-text: #fff; --mne-color-primary-highlight: #0063cc; /* font weight */ --mne-font-weight-semibold: 600; @@ -29,7 +28,6 @@ html[data-theme="light"] { --mne-color-heading: #003e80; /* pydata-sphinx-theme overrides */ --pst-color-primary: var(--mne-color-primary); - --pst-color-primary-text: var(--mne-color-primary-text); --pst-color-primary-highlight: var(--mne-color-primary-highlight); --pst-color-info: var(--pst-color-primary); --pst-color-border: #ccc; @@ -57,7 +55,6 @@ html[data-theme="dark"] { --mne-color-heading: #b8cbe0; /* pydata-sphinx-theme overrides */ --pst-color-primary: var(--mne-color-primary); - --pst-color-primary-text: var(--mne-color-primary-text); --pst-color-primary-highlight: var(--mne-color-primary-highlight); --pst-color-info: var(--pst-color-primary); --pst-color-border: #333; @@ -67,7 +64,7 @@ html[data-theme="dark"] { --sg-download-a-background-color: var(--pst-color-primary); --sg-download-a-background-image: unset; 
--sg-download-a-border-color: var(--pst-color-border); - --sg-download-a-color: #fff; + --sg-download-a-color: #000; --sg-download-a-hover-background-color: var(--pst-color-primary-highlight); --sg-download-a-hover-box-shadow-1: none; --sg-download-a-hover-box-shadow-2: none; @@ -163,6 +160,20 @@ iframe.sg_report { top: 0; } +/* TODO: Either pydata-sphinx-theme (for using Bootstrap) or sphinx-gallery (for adding table formatting) should fix this */ +.table-striped-columns>:not(caption)>tr>:nth-child(2n),.table-striped>tbody>tr:nth-of-type(odd)>* { + --bs-table-accent-bg: var(--bs-table-striped-bg); + color: var(--pst-color-text-base); +} +.table-hover>tbody>tr:hover>* { + --bs-table-accent-bg: var(--bs-table-hover-bg); + color: var(--pst-color-text-base); +} +.rendered_html table { + color: var(--pst-color-text-base); +} + + /* ***************************************************** sphinx-design fixes */ p.btn a { color: unset; diff --git a/tutorials/preprocessing/10_preprocessing_overview.py b/tutorials/preprocessing/10_preprocessing_overview.py index c07a05cc3e9..a679a6267f6 100644 --- a/tutorials/preprocessing/10_preprocessing_overview.py +++ b/tutorials/preprocessing/10_preprocessing_overview.py @@ -130,7 +130,6 @@ # around 20 seconds, so in this case a cutoff of 0.1 Hz would probably suppress # most of the drift). # -# # Power line noise # ~~~~~~~~~~~~~~~~ # From e9fd6e78283f4c0655dfe80e1e86053181543f8f Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Sun, 30 Apr 2023 18:46:10 -0500 Subject: [PATCH 0038/1125] BUG: Fix bug with annotation rename (#11666) Co-authored-by: Timur Sokhin --- doc/changes/latest.inc | 1 + doc/changes/names.inc | 2 ++ mne/annotations.py | 7 ++----- mne/tests/test_annotations.py | 5 +++++ 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index bc442f5cd3d..a3129a443da 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -51,6 +51,7 @@ Bugs - Improving compatibility with Qt6 by removing the use of deprecated ``AA_UseHighDpiPixmaps`` attribute for this Qt version (:gh:`11662` by :newcontrib:`Florin Pop`) - Fix :func:`mne.time_frequency.psd_array_multitaper` docstring where argument ``bandwidth`` incorrectly reported argument as half-bandwidth and gave wrong explanation of default value (:gh:`11479` by :newcontrib: `Tom Stone`_) - Fix bug where installation of a package depending on ``mne`` will error when done in an environment where ``setuptools`` is not present (:gh:`11454` by :newcontrib: `Arne Pelzer`_) +- Fix bug in :meth:`mne.Annotations.rename` where replacements were not done correctly (:gh:`11666` by :newcontrib:`Timur Sokhin`_ and `Eric Larson`_) - Fix bug where :func:`mne.preprocessing.regress_artifact` and :class:`mne.preprocessing.EOGRegression` incorrectly tracked ``picks`` (:gh:`11366` by `Eric Larson`_) - Fix bug where channel names were not properly sanitized in :func:`mne.write_evokeds` and related functions (:gh:`11399` by `Eric Larson`_) - Fix bug where splash screen would not always disappear (:gh:`11398` by `Eric Larson`_) diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 8130fefeed6..1894afa0db8 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -516,6 +516,8 @@ .. _Timothy Gates: https://au.linkedin.com/in/tim-gates-0528a4199 +.. _Timur Sokhin: https://github.com/Qwinpin + .. _Tod Flak: https://github.com/todflak .. 
_Tom Ma: https://github.com/myd7349 diff --git a/mne/annotations.py b/mne/annotations.py index 24cc4069971..00de96d32e4 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -674,11 +674,8 @@ def rename(self, mapping, verbose=None): _validate_type(mapping, dict) _check_dict_keys(mapping, self.description, valid_key_source="data", key_description="Annotation description(s)") - - for old, new in mapping.items(): - self.description = [d.replace(old, new) for d in self.description] - - self.description = np.array(self.description) + self.description = np.array( + [str(mapping.get(d, d)) for d in self.description]) return self diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 09f645204dd..d1a311bc9ae 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1416,6 +1416,11 @@ def test_annotation_rename(): with pytest.raises(TypeError, match="dict, got instead"): a.rename({"wrong"}) + a = np.array([[0, 0, 11], [1000, 0, 1], [1230, 0, 111]]) + a = mne.annotations_from_events(a, 256) + a.rename({'1': 'A', '11': 'B', '111': 'C'}) + assert_array_equal(a.description, ['B', 'A', 'C']) + def test_annotation_duration_setting(): """Test annotation duration setting works.""" From ccdbbdcc501042045c30f44454ed1d704c817550 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 1 May 2023 18:36:34 +0200 Subject: [PATCH 0039/1125] MRG: Allow silencing of "The unit for channel(s) ... has changed" warnings (#11668) --- doc/changes/latest.inc | 1 + mne/channels/channels.py | 19 +++++++++++++++---- mne/channels/tests/test_channels.py | 12 +++++++++++- mne/channels/tests/test_interpolation.py | 3 +-- mne/export/tests/test_export.py | 3 +-- mne/preprocessing/tests/test_ecg.py | 4 +--- mne/tests/test_epochs.py | 5 +++-- mne/viz/tests/test_topo.py | 6 ++---- 8 files changed, 35 insertions(+), 18 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index a3129a443da..ee3cb5e1311 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -45,6 +45,7 @@ Enhancements - Add support for :func:`mne.preprocessing.maxwell_filter` with gradient-compensated CTF data, e.g., for tSSS-only mode (:gh:`10554` by `Eric Larson`_) - Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_) - :func:`mne.channels.make_1020_channel_selections` gained a new parameter, ``return_ch_names``, to allow for easy retrieval of EEG channel names corresponding to the left, right, and midline portions of the montage (:gh:`11632` by `Richard Höchenberger`_) +- Methods for setting the sensor types of channels (e.g., for raw data, :meth:`mne.io.Raw.set_channel_types`) gained a new parameter, ``on_unit_change``, to control behavior (raise an exception, emit a warning, or do nothing) in case the measurement unit is adjusted automatically (:gh:`11668` by `Richard Höchenberger`_) Bugs ~~~~ diff --git a/mne/channels/channels.py b/mne/channels/channels.py index d0ae5f01673..433130c309b 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -308,14 +308,19 @@ def _set_channel_positions(self, pos, names): raise ValueError(msg) @verbose - def set_channel_types(self, mapping, verbose=None): - """Define the sensor type of channels. + def set_channel_types(self, mapping, *, on_unit_change='warn', verbose=None): + """Specify the sensor types of channels. 
Parameters ---------- mapping : dict - A dictionary mapping a channel to a sensor type (str), e.g., + A dictionary mapping channel names to sensor types, e.g., ``{'EEG061': 'eog'}``. + on_unit_change : 'raise' | 'warn' | 'ignore' + What to do if the measurement unit of a channel is changed + automatically to match the new sensor type. + + .. versionadded:: 1.4 %(verbose)s Returns @@ -388,9 +393,15 @@ def set_channel_types(self, mapping, verbose=None): else: coil_type = FIFF.FIFFV_COIL_NONE self.info['chs'][c_ind]['coil_type'] = coil_type + msg = "The unit for channel(s) {0} has changed from {1} to {2}." for this_change, names in unit_changes.items(): - warn(msg.format(", ".join(sorted(names)), *this_change)) + _on_missing( + on_missing=on_unit_change, + msg=msg.format(", ".join(sorted(names)), *this_change), + name='on_unit_change', + ) + return self @verbose diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index 585ce9b43cc..c811c682ea1 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -139,8 +139,18 @@ def test_set_channel_types(): with pytest.raises(RuntimeError, match='type .* in projector "PCA-v1"'): raw2.set_channel_types(mapping) # has prj raw2.add_proj([], remove_existing=True) + + # Should raise + with pytest.raises(ValueError, match='unit for channel.* has changed'): + raw2.copy().set_channel_types(mapping, on_unit_change='raise') + + # Should warn with pytest.warns(RuntimeWarning, match='unit for channel.* has changed'): - raw2 = raw2.set_channel_types(mapping) + raw2.copy().set_channel_types(mapping) + + # Shouldn't warn + raw2.set_channel_types(mapping, on_unit_change='ignore') + info = raw2.info assert info['chs'][371]['ch_name'] == 'EEG 057' assert info['chs'][371]['kind'] == FIFF.FIFFV_DBS_CH diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 092e7ae87c5..f6c71d1ff00 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -182,8 +182,7 @@ def test_interpolation_meg(): # before MEG channels raw.crop(0, 0.1).load_data().pick_channels(epochs_meg.ch_names) raw.info.normalize_proj() - with pytest.warns(RuntimeWarning, match='unit .* changed from .* to .*'): - raw.set_channel_types({raw.ch_names[0]: 'stim'}) + raw.set_channel_types({raw.ch_names[0]: 'stim'}, on_unit_change='ignore') raw.info['bads'] = [raw.ch_names[1]] raw.load_data() raw.interpolate_bads(mode='fast') diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 69679e5a7cd..27e29ab343f 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -260,8 +260,7 @@ def test_rawarray_edf(tmp_path): # test that warning is raised if there are non-voltage based channels raw = RawArray(data, info) - with pytest.warns(RuntimeWarning, match='The unit'): - raw.set_channel_types({'9': 'hbr'}) + raw.set_channel_types({'9': 'hbr'}, on_unit_change='ignore') with pytest.warns(RuntimeWarning, match='Non-voltage channels'): raw.export(temp_fname, overwrite=True) diff --git a/mne/preprocessing/tests/test_ecg.py b/mne/preprocessing/tests/test_ecg.py index 92f8e361514..35be10511a4 100644 --- a/mne/preprocessing/tests/test_ecg.py +++ b/mne/preprocessing/tests/test_ecg.py @@ -1,6 +1,5 @@ from pathlib import Path -import pytest import numpy as np from mne.io import read_raw_fif @@ -76,8 +75,7 @@ def test_find_ecg(): # test with user provided ecg channel raw.del_proj() assert 'MEG 2641' in raw.ch_names - with 
pytest.warns(RuntimeWarning, match='unit for channel'): - raw.set_channel_types({'MEG 2641': 'ecg'}) + raw.set_channel_types({'MEG 2641': 'ecg'}, on_unit_change='ignore') create_ecg_epochs(raw) raw.pick_types(meg=True) # remove ECG diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 3f2fce63bb3..65fa099d283 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -1505,8 +1505,9 @@ def test_evoked_io_from_epochs(tmp_path): picks = pick_types(raw.info, meg=True, eeg=True) epochs = Epochs(raw, events[:4], event_id, -0.2, tmax, picks=picks, baseline=(0.1, 0.2), decim=5) - with pytest.warns(RuntimeWarning, match='unit for.*changed from'): - epochs.set_channel_types({epochs.ch_names[0]: 'syst'}) + epochs.set_channel_types( + {epochs.ch_names[0]: 'syst'}, on_unit_change='ignore' + ) evokeds = list() for picks in (None, 'all'): evoked = epochs.average(picks) diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 5d39c494b67..997dfa002fc 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -111,15 +111,13 @@ def return_inds(d): # to test function kwarg to zorder arg of evoked.plot # test sEEG (gh:8733) evoked.del_proj().pick_types('mag') # avoid overlapping positions error mapping = {ch_name: 'seeg' for ch_name in evoked.ch_names} - with pytest.warns(RuntimeWarning, match='The unit .* has changed from .*'): - evoked.set_channel_types(mapping) + evoked.set_channel_types(mapping, on_unit_change='ignore') evoked.plot_joint() # test DBS (gh:8739) evoked = _get_epochs().average().pick_types('mag') mapping = {ch_name: 'dbs' for ch_name in evoked.ch_names} - with pytest.warns(RuntimeWarning, match='The unit for'): - evoked.set_channel_types(mapping) + evoked.set_channel_types(mapping, on_unit_change='ignore') evoked.plot_joint() plt.close('all') From daec3da7c7ecfc8a2db53f438d0c64f4703d78d0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 1 May 2023 13:01:03 -0500 Subject: [PATCH 0040/1125] API: Deprecate ordered=False and legacy some functions (#11665) --- .circleci/config.yml | 37 +----- doc/changes/latest.inc | 11 +- doc/conf.py | 3 +- examples/inverse/multi_dipole_model.py | 8 +- .../time_frequency_mixed_norm_inverse.py | 1 - examples/preprocessing/eeg_bridging.py | 4 +- mne/beamformer/_rap_music.py | 8 +- mne/beamformer/tests/test_lcmv.py | 9 +- mne/channels/channels.py | 12 +- mne/channels/tests/test_channels.py | 6 +- mne/chpi.py | 4 +- mne/cov.py | 17 +-- mne/epochs.py | 3 +- mne/event.py | 3 +- mne/forward/forward.py | 9 +- mne/inverse_sparse/mxne_inverse.py | 4 +- mne/io/base.py | 2 +- mne/io/pick.py | 108 ++++++++++-------- mne/io/tests/test_meas_info.py | 2 +- mne/io/tests/test_pick.py | 59 ++++++---- mne/preprocessing/tests/test_ica.py | 12 +- mne/rank.py | 3 +- mne/tests/test_chpi.py | 3 +- mne/tests/test_docstring_parameters.py | 1 - mne/tests/test_event.py | 2 +- mne/time_frequency/csd.py | 11 +- mne/utils/check.py | 2 +- mne/utils/docs.py | 10 ++ mne/viz/raw.py | 4 +- mne/viz/tests/test_3d.py | 4 +- mne/viz/tests/test_epochs.py | 2 +- mne/viz/tests/test_ica.py | 11 +- mne/viz/tests/test_topo.py | 6 +- mne/viz/tests/test_topomap.py | 2 +- mne/viz/topomap.py | 5 +- mne/viz/utils.py | 8 +- tools/azure_dependencies.sh | 7 +- tools/circleci_bash_env.sh | 27 +++++ tools/circleci_dependencies.sh | 18 +-- tutorials/simulation/80_dics.py | 7 +- .../40_cluster_1samp_time_freq.py | 3 +- 41 files changed, 238 insertions(+), 220 deletions(-) create mode 100755 tools/circleci_bash_env.sh diff --git a/.circleci/config.yml 
b/.circleci/config.yml index 067b971dfc7..84731e7c6f0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -88,30 +88,8 @@ jobs: - run: name: Set BASH_ENV - command: | - set -e - set -o pipefail - ./tools/setup_xvfb.sh - sudo apt install -qq graphviz optipng python3.10-venv python3-venv libxft2 ffmpeg - python3.10 -m venv ~/python_env - echo "set -e" >> $BASH_ENV - echo "export OPENBLAS_NUM_THREADS=4" >> $BASH_ENV - echo "export XDG_RUNTIME_DIR=/tmp/runtime-circleci" >> $BASH_ENV - echo "export MNE_FULL_DATE=true" >> $BASH_ENV - source tools/get_minimal_commands.sh - echo "export MNE_3D_BACKEND=pyvistaqt" >> $BASH_ENV - echo "export MNE_3D_OPTION_MULTI_SAMPLES=1" >> $BASH_ENV - echo "export MNE_BROWSER_BACKEND=qt" >> $BASH_ENV - echo "export MNE_BROWSER_PRECOMPUTE=false" >> $BASH_ENV - echo "export PATH=~/.local/bin/:$PATH" >> $BASH_ENV - echo "export DISPLAY=:99" >> $BASH_ENV - echo "source ~/python_env/bin/activate" >> $BASH_ENV - mkdir -p ~/.local/bin - ln -s ~/python_env/bin/python ~/.local/bin/python - echo "BASH_ENV:" - cat $BASH_ENV - mkdir -p ~/mne_data - touch pattern.txt + command: ./tools/circleci_bash_env.sh + - run: name: check neuromag2ft command: | @@ -398,18 +376,12 @@ jobs: type: string default: "false" docker: - - image: circleci/python:3.9.2-buster + - image: cimg/base:current-22.04 steps: - restore_cache: keys: - source-cache - checkout - - run: - name: Set BASH_ENV - command: | - set -e - echo "set -e" >> $BASH_ENV - echo "export PATH=~/.local/bin/:$PATH" >> $BASH_ENV - run: name: Check-skip command: | @@ -418,6 +390,9 @@ jobs: echo "Skip detected, exiting job ${CIRCLE_JOB}." circleci-agent step halt; fi + - run: + name: Set BASH_ENV + command: ./tools/circleci_bash_env.sh - restore_cache: keys: - pip-cache diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index ee3cb5e1311..851e3b04587 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -40,7 +40,7 @@ Enhancements - Add a video to :ref:`tut-freesurfer-mne` of a brain inflating from the pial surface to aid in understanding the inflated brain (:gh:`11440` by `Alex Rockhill`_) - Add automatic projection of sEEG contact onto the inflated surface for :meth:`mne.viz.Brain.add_sensors` (:gh:`11436` by `Alex Rockhill`_) - Allow an image with intracranial electrode contacts (e.g. 
computed tomography) to be used without the freesurfer recon-all surfaces to locate contacts so that it doesn't have to be downsampled to freesurfer dimensions (for microelectrodes) and show an example :ref:`ex-ieeg-micro` with :func:`mne.transforms.apply_volume_registration_points` added to aid this transform (:gh:`11567` by `Alex Rockhill`_) -- Use new :meth:`dipy.workflows.align.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_) +- Use new :meth:`dipy.align.imwarp.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_) - Improve performance of raw data browsing with many annotations (:gh:`11614` by `Eric Larson`_) - Add support for :func:`mne.preprocessing.maxwell_filter` with gradient-compensated CTF data, e.g., for tSSS-only mode (:gh:`10554` by `Eric Larson`_) - Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_) @@ -50,9 +50,9 @@ Enhancements Bugs ~~~~ - Improving compatibility with Qt6 by removing the use of deprecated ``AA_UseHighDpiPixmaps`` attribute for this Qt version (:gh:`11662` by :newcontrib:`Florin Pop`) -- Fix :func:`mne.time_frequency.psd_array_multitaper` docstring where argument ``bandwidth`` incorrectly reported argument as half-bandwidth and gave wrong explanation of default value (:gh:`11479` by :newcontrib: `Tom Stone`_) -- Fix bug where installation of a package depending on ``mne`` will error when done in an environment where ``setuptools`` is not present (:gh:`11454` by :newcontrib: `Arne Pelzer`_) -- Fix bug in :meth:`mne.Annotations.rename` where replacements were not done correctly (:gh:`11666` by :newcontrib:`Timur Sokhin`_ and `Eric Larson`_) +- Fix :func:`mne.time_frequency.psd_array_multitaper` docstring where argument ``bandwidth`` incorrectly reported argument as half-bandwidth and gave wrong explanation of default value (:gh:`11479` by :newcontrib:`Tom Stone`) +- Fix bug where installation of a package depending on ``mne`` will error when done in an environment where ``setuptools`` is not present (:gh:`11454` by :newcontrib:`Arne Pelzer`) +- Fix bug in :meth:`mne.Annotations.rename` where replacements were not done correctly (:gh:`11666` by :newcontrib:`Timur Sokhin` and `Eric Larson`_) - Fix bug where :func:`mne.preprocessing.regress_artifact` and :class:`mne.preprocessing.EOGRegression` incorrectly tracked ``picks`` (:gh:`11366` by `Eric Larson`_) - Fix bug where channel names were not properly sanitized in :func:`mne.write_evokeds` and related functions (:gh:`11399` by `Eric Larson`_) - Fix bug where splash screen would not always disappear (:gh:`11398` by `Eric Larson`_) @@ -83,3 +83,6 @@ API changes - Deprecate arguments ``kind`` and ``path`` from :func:`mne.channels.read_layout` in favor of a common argument ``fname`` (:gh:`11500` by `Mathieu Scheltienne`_) - Change ``aligned_ct`` positional argument in ``mne.gui.locate_ieeg`` to ``base_image`` to reflect that this can now be used with unaligned images (:gh:`11567` by `Alex Rockhill`_) - ``mne.warp_montage_volume`` was deprecated in favor of :func:`mne.preprocessing.ieeg.warp_montage` (acts directly on points instead of using an intermediate volume) and :func:`mne.preprocessing.ieeg.make_montage_volume` (which makes a volume of ieeg contact locations which can still be useful) (:gh:`11572` by `Alex Rockhill`_) +- Deprecate ``mne.pick_channels_evoked`` in favor of 
``evoked.copy().pick(...)`` (:gh:`11665` by `Eric Larson`_) +- Set instance methods ``inst.pick_types`` and ``inst.pick_channels`` as legacy in favor of ``inst.pick(...)`` (:gh:`11665` by `Eric Larson`_) +- The default of ``inst.pick_channels(..., ordered=False)`` will change to ``ordered=True`` in 1.5 to avoid silent bugs (:gh:`11665` by `Eric Larson`_) diff --git a/doc/conf.py b/doc/conf.py index 86b8967634d..4897d6f1527 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -160,8 +160,7 @@ 'picard': ('https://pierreablin.github.io/picard/', None), 'qdarkstyle': ('https://qdarkstylesheet.readthedocs.io/en/latest', None), 'eeglabio': ('https://eeglabio.readthedocs.io/en/latest', None), - 'dipy': ('https://dipy.org/documentation/latest/', - 'https://dipy.org/documentation/latest/objects.inv/'), + 'dipy': ('https://dipy.org/documentation/latest/', None), 'pooch': ('https://www.fatiando.org/pooch/latest/', None), 'pybv': ('https://pybv.readthedocs.io/en/latest/', None), 'pyqtgraph': ('https://pyqtgraph.readthedocs.io/en/latest/', None), diff --git a/examples/inverse/multi_dipole_model.py b/examples/inverse/multi_dipole_model.py index 2dbe6362157..afed2d738df 100644 --- a/examples/inverse/multi_dipole_model.py +++ b/examples/inverse/multi_dipole_model.py @@ -78,13 +78,13 @@ # sensors on the right side of the helmet. picks_left = read_vectorview_selection('Left', info=info) evoked_fit_left = evoked_left.copy().crop(0.08, 0.08) -evoked_fit_left.pick_channels(picks_left) -cov_fit_left = cov.copy().pick_channels(picks_left) +evoked_fit_left.pick_channels(picks_left, ordered=False) +cov_fit_left = cov.copy().pick_channels(picks_left, ordered=False) picks_right = read_vectorview_selection('Right', info=info) evoked_fit_right = evoked_right.copy().crop(0.08, 0.08) -evoked_fit_right.pick_channels(picks_right) -cov_fit_right = cov.copy().pick_channels(picks_right) +evoked_fit_right.pick_channels(picks_right, ordered=False) +cov_fit_right = cov.copy().pick_channels(picks_right, ordered=False) # Any SSS projections that are active on this data need to be re-normalized # after picking channels. diff --git a/examples/inverse/time_frequency_mixed_norm_inverse.py b/examples/inverse/time_frequency_mixed_norm_inverse.py index 27968bd8971..2271c58f24c 100644 --- a/examples/inverse/time_frequency_mixed_norm_inverse.py +++ b/examples/inverse/time_frequency_mixed_norm_inverse.py @@ -50,7 +50,6 @@ # Handling average file condition = 'Left visual' evoked = mne.read_evokeds(ave_fname, condition=condition, baseline=(None, 0)) -evoked = mne.pick_channels_evoked(evoked) # We make the window slightly larger than what you'll eventually be interested # in ([-0.05, 0.3]) to avoid edge effects. evoked.crop(tmin=-0.1, tmax=0.4) diff --git a/examples/preprocessing/eeg_bridging.py b/examples/preprocessing/eeg_bridging.py index 31e8b06b08d..fa94e752c71 100644 --- a/examples/preprocessing/eeg_bridging.py +++ b/examples/preprocessing/eeg_bridging.py @@ -160,14 +160,14 @@ # pairs, meaning that it is unlikely that all four of these electrodes are # bridged. 
-raw = raw_data[6].copy().pick_channels(['F2', 'F4', 'FC2', 'FC4']) +raw = raw_data[6].copy().pick_channels(['FC2', 'FC4', 'F2', 'F4']) raw.add_channels([mne.io.RawArray( raw.get_data(ch1) - raw.get_data(ch2), mne.create_info([f'{ch1}-{ch2}'], raw.info['sfreq'], 'eeg'), raw.first_samp) for ch1, ch2 in [('F2', 'F4'), ('FC2', 'FC4')]]) raw.plot(duration=20, scalings=dict(eeg=2e-4)) -raw = raw_data[1].copy().pick_channels(['F2', 'F4', 'FC2', 'FC4']) +raw = raw_data[1].copy().pick_channels(['FC2', 'FC4', 'F2', 'F4']) raw.add_channels([mne.io.RawArray( raw.get_data(ch1) - raw.get_data(ch2), mne.create_info([f'{ch1}-{ch2}'], raw.info['sfreq'], 'eeg'), diff --git a/mne/beamformer/_rap_music.py b/mne/beamformer/_rap_music.py index 827b085385f..3b59fa90c46 100644 --- a/mne/beamformer/_rap_music.py +++ b/mne/beamformer/_rap_music.py @@ -8,7 +8,7 @@ import numpy as np from ..forward import is_fixed_orient, convert_forward_solution -from ..io.pick import pick_channels_evoked, pick_info, pick_channels_forward +from ..io.pick import pick_info, pick_channels_forward from ..inverse_sparse.mxne_inverse import _make_dipoles_sparse from ..minimum_norm.inverse import _log_exp_var from ..utils import logger, verbose, _check_info_inv, fill_doc @@ -274,11 +274,7 @@ def rap_music(evoked, forward, noise_cov, n_dipoles=5, return_residual=False, picks) if return_residual: - residual = evoked.copy() - selection = [info['ch_names'][p] for p in picks] - - residual = pick_channels_evoked(residual, - include=selection) + residual = evoked.copy().pick([info['ch_names'][p] for p in picks]) residual.data -= explained_data active_projs = [p for p in residual.info['projs'] if p['active']] for p in active_projs: diff --git a/mne/beamformer/tests/test_lcmv.py b/mne/beamformer/tests/test_lcmv.py index 9e564dbc432..7f8e654c9bf 100644 --- a/mne/beamformer/tests/test_lcmv.py +++ b/mne/beamformer/tests/test_lcmv.py @@ -81,7 +81,7 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, bads = [raw.ch_names[pick] for pick in bad_picks] assert not any(pick in picks for pick in bad_picks) picks = np.concatenate([picks, bad_picks]) - raw.pick_channels([raw.ch_names[ii] for ii in picks]) + raw.pick_channels([raw.ch_names[ii] for ii in picks], ordered=True) del picks raw.info['bads'] = bads # add more bads @@ -429,7 +429,7 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): # (avoid "grad data rank (13) did not match the noise rank (None)") data_cov_grad = pick_channels_cov( data_cov, [ch_name for ch_name in epochs.info['ch_names'] - if ch_name.endswith(('2', '3'))]) + if ch_name.endswith(('2', '3'))], ordered=False) assert len(data_cov_grad['names']) > 4 make_lcmv(epochs.info, forward_fixed, data_cov_grad, reg=0.01, noise_cov=noise_cov) @@ -499,8 +499,9 @@ def test_lcmv_cov(weight_norm, pick_ori): filters = make_lcmv(evoked.info, forward, data_cov, noise_cov=noise_cov, weight_norm=weight_norm, pick_ori=pick_ori) for cov in (data_cov, noise_cov): - this_cov = pick_channels_cov(cov, evoked.ch_names) - this_evoked = evoked.copy().pick_channels(this_cov['names']) + this_cov = pick_channels_cov(cov, evoked.ch_names, ordered=False) + this_evoked = evoked.copy().pick_channels( + this_cov['names'], ordered=True) this_cov['projs'] = this_evoked.info['projs'] assert this_evoked.ch_names == this_cov['names'] stc = apply_lcmv_cov(this_cov, filters) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 433130c309b..211b0275441 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -25,7 +25,7 @@ from 
..utils import (verbose, logger, warn, _check_preload, _validate_type, fill_doc, _check_option, _get_stim_channel, _check_fname, _check_dict_keys, - _on_missing) + _on_missing, legacy) from ..io.constants import FIFF from ..io.meas_info import (anonymize_info, Info, MontageMixin, create_info, _rename_comps) @@ -607,6 +607,7 @@ class UpdateChannelsMixin: """Mixin class for Raw, Evoked, Epochs, Spectrum, AverageTFR.""" @verbose + @legacy(alt='inst.pick(...)') def pick_types(self, meg=False, eeg=False, stim=False, eog=False, ecg=False, emg=False, ref_meg='auto', *, misc=False, resp=False, chpi=False, exci=False, ias=False, syst=False, @@ -660,18 +661,15 @@ def pick_types(self, meg=False, eeg=False, stim=False, eog=False, return self @verbose - def pick_channels(self, ch_names, ordered=False, *, verbose=None): + @legacy(alt='inst.pick(...)') + def pick_channels(self, ch_names, ordered=None, *, verbose=None): """Pick some channels. Parameters ---------- ch_names : list The list of channels to select. - ordered : bool - If True (default False), ensure that the order of the channels in - the modified instance matches the order of ``ch_names``. - - .. versionadded:: 0.20.0 + %(ordered)s %(verbose)s .. versionadded:: 1.1 diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index c811c682ea1..04f07d84ec3 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -172,7 +172,8 @@ def test_set_channel_types(): assert info['chs'][375]['kind'] == FIFF.FIFFV_SEEG_CH assert info['chs'][375]['unit'] == FIFF.FIFF_UNIT_V assert info['chs'][375]['coil_type'] == FIFF.FIFFV_COIL_EEG - for idx in pick_channels(raw.ch_names, ['MEG 2441', 'MEG 2443']): + for idx in pick_channels(raw.ch_names, ['MEG 2441', 'MEG 2443'], + ordered=False): assert info['chs'][idx]['kind'] == FIFF.FIFFV_EEG_CH assert info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_V assert info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_EEG @@ -480,9 +481,8 @@ def test_pick_channels(): assert len(raw.ch_names) == 3 # selected correctly 3 channels and ignored 'meg', and emit warning - with pytest.warns(RuntimeWarning, match='not present in the info'): + with pytest.raises(ValueError, match='not present in the info'): raw.pick(['MEG 0113', "meg", 'MEG 0112', 'MEG 0111']) - assert len(raw.ch_names) == 3 names_len = len(raw.ch_names) raw.pick(['all']) # selected correctly all channels diff --git a/mne/chpi.py b/mne/chpi.py index b57477deb29..648ad6ca78a 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -346,8 +346,8 @@ def get_chpi_info(info, on_missing='raise', verbose=None): hpi_pick = None # there is no pick! if hpi_sub is not None: if 'event_channel' in hpi_sub: - hpi_pick = pick_channels(info['ch_names'], - [hpi_sub['event_channel']]) + hpi_pick = pick_channels( + info['ch_names'], [hpi_sub['event_channel']], ordered=False) hpi_pick = hpi_pick[0] if len(hpi_pick) > 0 else None # grab codes indicating a coil is active hpi_on = [coil['event_bits'][0] for coil in hpi_sub['hpi_coils']] diff --git a/mne/cov.py b/mne/cov.py index c612e4ae7f1..43c993c6c91 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -336,16 +336,16 @@ def plot_topomap( extrapolate=extrapolate, sphere=sphere, border=border, time_format='') - def pick_channels(self, ch_names, ordered=False): + @verbose + def pick_channels(self, ch_names, ordered=None, *, verbose=None): """Pick channels from this covariance matrix. Parameters ---------- ch_names : list of str List of channels to keep. All other channels are dropped. 
- ordered : bool - If True (default False), ensure that the order of the channels - matches the order of ``ch_names``. + %(ordered)s + %(verbose)s Returns ------- @@ -1472,7 +1472,8 @@ def prepare_noise_cov(noise_cov, info, ch_names=None, rank=None, raise RuntimeError('Not all channels present in noise covariance:\n%s' % missing) C = noise_cov._get_square()[np.ix_(noise_cov_idx, noise_cov_idx)] - info = pick_info(info, pick_channels(info['ch_names'], ch_names)) + info = pick_info( + info, pick_channels(info['ch_names'], ch_names, ordered=False)) projs = info['projs'] + noise_cov['projs'] noise_cov = Covariance( data=C, names=ch_names, bads=list(noise_cov['bads']), @@ -1665,7 +1666,8 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', # This actually removes bad channels from the cov, which is not backward # compatible, so let's leave all channels in - cov_good = pick_channels_cov(cov, include=info_ch_names, exclude=exclude) + cov_good = pick_channels_cov( + cov, include=info_ch_names, exclude=exclude, ordered=False) ch_names = cov_good.ch_names # Now get the indices for each channel type in the cov @@ -1723,7 +1725,8 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', C[np.ix_(idx, idx)] = this_C # Put data back in correct locations - idx = pick_channels(cov.ch_names, info_ch_names, exclude=exclude) + idx = pick_channels( + cov.ch_names, info_ch_names, exclude=exclude, ordered=False) cov['data'][np.ix_(idx, idx)] = C return cov diff --git a/mne/epochs.py b/mne/epochs.py index d12e4cb167b..8a9e83d22d9 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -889,7 +889,8 @@ def subtract_evoked(self, evoked=None): evoked = self.average(picks) # find the indices of the channels to use - picks = pick_channels(evoked.ch_names, include=self.ch_names) + picks = pick_channels( + evoked.ch_names, include=self.ch_names, ordered=False) # make sure the omitted channels are not data channels if len(picks) < len(self.ch_names): diff --git a/mne/event.py b/mne/event.py index 1478b4ae105..68f943c3b49 100644 --- a/mne/event.py +++ b/mne/event.py @@ -422,7 +422,8 @@ def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0, # pull stim channel from config if necessary stim_channel = _get_stim_channel(stim_channel, raw.info) - picks = pick_channels(raw.info['ch_names'], include=stim_channel) + picks = pick_channels( + raw.info['ch_names'], include=stim_channel, ordered=False) if len(picks) == 0: raise ValueError('No stim channel found to extract event triggers.') data, _ = raw[picks, :] diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 0d111d107ed..17ed07f8ac4 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -475,7 +475,8 @@ def _merge_fwds(fwds, *, verbose=None): @verbose -def read_forward_solution(fname, include=(), exclude=(), verbose=None): +def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, + verbose=None): """Read a forward solution a.k.a. lead field. Parameters @@ -487,6 +488,7 @@ def read_forward_solution(fname, include=(), exclude=(), verbose=None): are included. exclude : list, optional List of names of channels to exclude. If empty include all channels. 
+ %(ordered)s %(verbose)s Returns @@ -665,7 +667,7 @@ def read_forward_solution(fname, include=(), exclude=(), verbose=None): @verbose def convert_forward_solution(fwd, surf_ori=False, force_fixed=False, - copy=True, use_cps=True, verbose=None): + copy=True, use_cps=True, *, verbose=None): """Convert forward solution between different source orientations. Parameters @@ -1401,7 +1403,8 @@ def _stc_src_sel(src, stc, on_missing='raise', def _fill_measurement_info(info, fwd, sfreq, data): """Fill the measurement info of a Raw or Evoked object.""" - sel = pick_channels(info['ch_names'], fwd['sol']['row_names']) + sel = pick_channels( + info['ch_names'], fwd['sol']['row_names'], ordered=False) info = pick_info(info, sel) info['bads'] = [] diff --git a/mne/inverse_sparse/mxne_inverse.py b/mne/inverse_sparse/mxne_inverse.py index ef17fc736b0..ac2cbc5f488 100644 --- a/mne/inverse_sparse/mxne_inverse.py +++ b/mne/inverse_sparse/mxne_inverse.py @@ -9,7 +9,6 @@ from ..minimum_norm.inverse import (combine_xyz, _prepare_forward, _check_reference, _log_exp_var) from ..forward import is_fixed_orient -from ..io.pick import pick_channels_evoked from ..io.proj import deactivate_proj from ..utils import (logger, verbose, _check_depth, _check_option, sum_squared, _validate_type, check_random_state, warn) @@ -82,8 +81,7 @@ def _reapply_source_weighting(X, source_weighting, active_set): def _compute_residual(forward, evoked, X, active_set, info): # OK, picking based on row_names is safe sel = [forward['sol']['row_names'].index(c) for c in info['ch_names']] - residual = evoked.copy() - residual = pick_channels_evoked(residual, include=info['ch_names']) + residual = evoked.copy().pick(info['ch_names']) r_tmp = residual.copy() r_tmp.data = np.dot(forward['sol']['data'][sel, :][:, active_set], X) diff --git a/mne/io/base.py b/mne/io/base.py index 867c58656ce..e85ebe005d6 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1795,7 +1795,7 @@ def add_events(self, events, stim_channel=None, replace=False): if events.ndim != 2 or events.shape[1] != 3: raise ValueError('events must be shape (n_events, 3)') stim_channel = _get_stim_channel(stim_channel, self.info) - pick = pick_channels(self.ch_names, stim_channel) + pick = pick_channels(self.ch_names, stim_channel, ordered=False) if len(pick) == 0: raise ValueError('Channel %s not found' % stim_channel) pick = pick[0] diff --git a/mne/io/pick.py b/mne/io/pick.py index 292e9f8b772..e5156974c2c 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -11,7 +11,7 @@ from .constants import FIFF from ..utils import (logger, verbose, _validate_type, fill_doc, _ensure_int, - _check_option, warn) + _check_option, warn, deprecated) def get_channel_type_constants(include_defaults=False): @@ -220,7 +220,8 @@ def channel_type(info, idx): return first_kind -def pick_channels(ch_names, include, exclude=[], ordered=False): +@verbose +def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None): """Pick channels by names. Returns the indices of ``ch_names`` in ``include`` but not in ``exclude``. @@ -238,12 +239,8 @@ def pick_channels(ch_names, include, exclude=[], ordered=False): exclude : list of str List of channels to exclude (if empty do not exclude any channel). Defaults to []. - ordered : bool - If true (default False), treat ``include`` as an ordered list - rather than a set, and any channels from ``include`` are missing - in ``ch_names`` an error will be raised. - - .. 
versionadded:: 0.18
+    %(ordered)s
+    %(verbose)s
 
     Returns
     -------
@@ -256,34 +253,47 @@
     """
     if len(np.unique(ch_names)) != len(ch_names):
         raise RuntimeError('ch_names is not a unique list, picking is unsafe')
+    _validate_type(ordered, (bool, None), 'ordered')
     _check_excludes_includes(include)
     _check_excludes_includes(exclude)
-    if not ordered:
-        if not isinstance(include, set):
-            include = set(include)
-        if not isinstance(exclude, set):
-            exclude = set(exclude)
-        sel = []
-        for k, name in enumerate(ch_names):
-            if (len(include) == 0 or name in include) and name not in exclude:
-                sel.append(k)
-    else:
-        if not isinstance(include, list):
-            include = list(include)
-        if len(include) == 0:
-            include = list(ch_names)
-        if not isinstance(exclude, list):
-            exclude = list(exclude)
-        sel, missing = list(), list()
-        for name in include:
-            if name in ch_names:
-                if name not in exclude:
-                    sel.append(ch_names.index(name))
-            else:
-                missing.append(name)
-        if len(missing):
+    if not isinstance(include, list):
+        include = list(include)
+    if len(include) == 0:
+        include = list(ch_names)
+    if not isinstance(exclude, list):
+        exclude = list(exclude)
+    sel, missing = list(), list()
+    for name in include:
+        if name in ch_names:
+            if name not in exclude:
+                sel.append(ch_names.index(name))
+        else:
+            missing.append(name)
+    dep_msg = (
+        'The default for pick_channels will change from ordered=False to '
+        'ordered=True in 1.5'
+    )
+    if len(missing):
+        if ordered is None:
+            warn(
+                f'{dep_msg} and this will result in an error because the '
+                f'following channel names are missing:\n{missing}\n'
+                'Either fix your included names or explicitly pass '
+                'ordered=False.', FutureWarning)
+        elif ordered:
             raise ValueError('Missing channels from ch_names required by '
-                             'include:\n%s' % (missing,))
+                             'include:\n%s' % (missing,))
+    if not ordered:
+        out_sel = np.unique(sel)
+        if ordered is None and not np.array_equal(out_sel, sel):
+            warn(
+                f'{dep_msg} and this will result in a change of behavior '
+                'because the resulting channel order will not match. Either '
+                'use a channel order that matches your instance or '
+                'pass ordered=False.',
+                FutureWarning,
+            )
+        sel = out_sel
     return np.array(sel, int)
 
 
@@ -486,7 +496,8 @@
     if len(myinclude) == 0:
         sel = np.array([], int)
     else:
-        sel = pick_channels(info['ch_names'], myinclude, exclude)
+        sel = pick_channels(
+            info['ch_names'], myinclude, exclude, ordered=False)
 
     return sel
 
@@ -571,6 +582,8 @@ def _has_kit_refs(info, picks):
     return False
 
 
+@deprecated('pick_channels_evoked is deprecated and will be removed in 1.5, '
+            'use evoked.copy().pick(...) instead.')
 def pick_channels_evoked(orig, include=[], exclude='bads'):
     """Pick channels from evoked data.
 
@@ -615,8 +628,8 @@
 
 
 @verbose
-def pick_channels_forward(orig, include=[], exclude=[], ordered=False,
-                          copy=True, verbose=None):
+def pick_channels_forward(orig, include=[], exclude=[], ordered=None,
+                          copy=True, *, verbose=None):
     """Pick channels from forward operator.
 
     Parameters
     ----------
     orig : dict
         A forward solution.
     include : list of str
         List of channels to include (if empty, include all available).
         Defaults to [].
     exclude : list of str | 'bads'
         Channels to exclude (if empty, do not exclude any). Defaults to [].
         If 'bads', then exclude bad channels in orig.
-    ordered : bool
-        If true (default False), treat ``include`` as an ordered list
-        rather than a set.
-
-        .. 
versionadded:: 0.18 + %(ordered)s copy : bool If True (default), make a copy. @@ -773,8 +782,9 @@ def channel_indices_by_type(info, picks=None): return idx_by_type -def pick_channels_cov(orig, include=[], exclude='bads', ordered=False, - copy=True): +@verbose +def pick_channels_cov(orig, include=[], exclude='bads', ordered=None, + copy=True, *, verbose=None): """Pick channels from covariance matrix. Parameters @@ -785,16 +795,13 @@ def pick_channels_cov(orig, include=[], exclude='bads', ordered=False, List of channels to include (if empty, include all available). exclude : list of str, (optional) | 'bads' Channels to exclude (if empty, do not exclude any). Defaults to 'bads'. - ordered : bool - If True (default False), ensure that the order of the channels in the - modified instance matches the order of ``include``. - - .. versionadded:: 0.20.0 + %(ordered)s copy : bool If True (the default), return a copy of the covariance matrix with the modified channels. If False, channels are modified in-place. .. versionadded:: 0.20.0 + %(verbose)s Returns ------- @@ -1188,8 +1195,9 @@ def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind, picked_ch_type_or_generic = not len(picks_name) if len(bad_names) > 0 and not picked_ch_type_or_generic: - warn(f'Channel(s) {bad_names} could not be picked, because ' - 'they are not present in the info instance.') + raise ValueError( + f'Channel(s) {bad_names} could not be picked, because ' + 'they are not present in the info instance.') if return_kind: return picks, picked_ch_type_or_generic diff --git a/mne/io/tests/test_meas_info.py b/mne/io/tests/test_meas_info.py index e16c8b5f41b..4af6d0ebe78 100644 --- a/mne/io/tests/test_meas_info.py +++ b/mne/io/tests/test_meas_info.py @@ -903,7 +903,7 @@ def test_channel_name_limit(tmp_path, monkeypatch, fname): data=proj, active=False, desc='test', kind=0, explained_var=0.) 
raw.add_proj(proj, remove_existing=True) raw.info.normalize_proj() - raw.pick_channels(data_names + ref_names).crop(0, 2) + raw.pick_channels(data_names + ref_names, ordered=False).crop(0, 2) long_names = ['123456789abcdefg' + name for name in raw.ch_names] fname = tmp_path / 'test-raw.fif' with catch_logging() as log: diff --git a/mne/io/tests/test_pick.py b/mne/io/tests/test_pick.py index 5508a263905..1632455b50e 100644 --- a/mne/io/tests/test_pick.py +++ b/mne/io/tests/test_pick.py @@ -3,7 +3,7 @@ import pytest import numpy as np -from numpy.testing import assert_array_equal, assert_equal +from numpy.testing import assert_array_equal from mne import (pick_channels_regexp, pick_types, Epochs, read_forward_solution, rename_channels, @@ -232,7 +232,7 @@ def test_pick_seeg_ecog(): assert_indexing(info, picks_by_type) assert_array_equal(pick_types(info, meg=False, seeg=True), [4, 5, 7]) for i, t in enumerate(types): - assert_equal(channel_type(info, i), types[i]) + assert channel_type(info, i) == types[i] raw = RawArray(np.zeros((len(names), 10)), info) events = np.array([[1, 0, 0], [2, 0, 0]]) epochs = Epochs(raw, events=events, event_id={'event': 0}, @@ -244,7 +244,7 @@ def test_pick_seeg_ecog(): assert lt == rt # Deal with constant debacle raw = read_raw_fif(io_dir / "tests" / "data" / "test_chpi_raw_sss.fif") - assert_equal(len(pick_types(raw.info, meg=False, seeg=True, ecog=True)), 0) + assert len(pick_types(raw.info, meg=False, seeg=True, ecog=True)) == 0 def test_pick_dbs(): @@ -329,8 +329,8 @@ def test_pick_ref(): def _check_fwd_n_chan_consistent(fwd, n_expected): n_ok = len(fwd['info']['ch_names']) n_sol = fwd['sol']['data'].shape[0] - assert_equal(n_expected, n_sol) - assert_equal(n_expected, n_ok) + assert n_expected == n_sol + assert n_expected == n_ok @testing.requires_testing_data @@ -367,8 +367,8 @@ def test_pick_forward_seeg_ecog(): counts['ecog'] += 1 # repick & check fwd_seeg = pick_types_forward(fwd, meg=False, seeg=True) - assert_equal(fwd_seeg['sol']['row_names'], [seeg_name]) - assert_equal(fwd_seeg['info']['ch_names'], [seeg_name]) + assert fwd_seeg['sol']['row_names'] == [seeg_name] + assert fwd_seeg['info']['ch_names'] == [seeg_name] # should work fine fwd_ = pick_types_forward(fwd, meg=True) _check_fwd_n_chan_consistent(fwd_, counts['meg']) @@ -393,15 +393,15 @@ def test_picks_by_channels(): raw = RawArray(test_data, info) pick_list = _picks_by_type(raw.info) - assert_equal(len(pick_list), 3) - assert_equal(pick_list[0][0], 'mag') + assert len(pick_list) == 3 + assert pick_list[0][0] == 'mag' pick_list2 = _picks_by_type(raw.info, meg_combined=False) - assert_equal(len(pick_list), len(pick_list2)) - assert_equal(pick_list2[0][0], 'mag') + assert len(pick_list) == len(pick_list2) + assert pick_list2[0][0] == 'mag' pick_list2 = _picks_by_type(raw.info, meg_combined=True) - assert_equal(len(pick_list), len(pick_list2) + 1) - assert_equal(pick_list2[0][0], 'meg') + assert len(pick_list) == len(pick_list2) + 1 + assert pick_list2[0][0] == 'meg' test_data = rng.random_sample((4, 2000)) ch_names = ['MEG %03d' % i for i in [1, 2, 3, 4]] @@ -410,19 +410,20 @@ def test_picks_by_channels(): info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) raw = RawArray(test_data, info) # This acts as a set, not an order - assert_array_equal(pick_channels(info['ch_names'], ['MEG 002', 'MEG 001']), - [0, 1]) + assert_array_equal( + pick_channels(info['ch_names'], ['MEG 002', 'MEG 001'], ordered=False), + [0, 1]) # Make sure checks for list input work. 
pytest.raises(ValueError, pick_channels, ch_names, 'MEG 001') pytest.raises(ValueError, pick_channels, ch_names, ['MEG 001'], 'hi') pick_list = _picks_by_type(raw.info) - assert_equal(len(pick_list), 1) - assert_equal(pick_list[0][0], 'mag') + assert len(pick_list) == 1 + assert pick_list[0][0] == 'mag' pick_list2 = _picks_by_type(raw.info, meg_combined=True) - assert_equal(len(pick_list), len(pick_list2)) - assert_equal(pick_list2[0][0], 'mag') + assert len(pick_list) == len(pick_list2) + assert pick_list2[0][0] == 'mag' # pick_types type check with pytest.raises(ValueError, match='must be of type'): @@ -430,8 +431,18 @@ def test_picks_by_channels(): # duplicate check names = ['MEG 002', 'MEG 002'] - assert len(pick_channels(raw.info['ch_names'], names)) == 1 - assert len(raw.copy().pick_channels(names)[0][0]) == 1 + assert len(pick_channels(raw.info['ch_names'], names, ordered=False)) == 1 + with pytest.warns(FutureWarning, match='ordered=False'): + assert len(raw.copy().pick_channels(names)[0][0]) == 1 + + # missing ch_name + bad_names = names + ['BAD'] + with pytest.raises(ValueError, match='Missing channels'): + pick_channels(raw.info['ch_names'], bad_names, ordered=True) + with pytest.raises(ValueError, match='Missing channels'): + raw.copy().pick_channels(bad_names, ordered=True) + with pytest.raises(ValueError, match='could not be picked'): + raw.copy().pick(bad_names) def test_clean_info_bads(): @@ -463,8 +474,8 @@ def test_clean_info_bads(): # simulate the call to pick_info excluding the bad meg channels info_meg = pick_info(raw.info, picks_meg) - assert_equal(info_eeg['bads'], eeg_bad_ch) - assert_equal(info_meg['bads'], meg_bad_ch) + assert info_eeg['bads'] == eeg_bad_ch + assert info_meg['bads'] == meg_bad_ch info = pick_info(raw.info, picks_meg) info._check_consistency() @@ -584,7 +595,7 @@ def test_pick_channels_cov(): assert_array_equal(cov_copy['data'], [2., 1.]) # Test picking in-place - pick_channels_cov(cov, ['CH2', 'CH1'], copy=False) + pick_channels_cov(cov, ['CH2', 'CH1'], copy=False, ordered=False) assert cov.ch_names == ['CH1', 'CH2'] assert_array_equal(cov['data'], [1., 2.]) diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index 63b657f2132..668cc9bb813 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -576,7 +576,7 @@ def short_raw_epochs(): """Get small data.""" raw = read_raw_fif(raw_fname).crop(0, 5).load_data() raw.pick_channels(set(raw.ch_names[::10]) | set( - ['EOG 061', 'MEG 1531', 'MEG 1441', 'MEG 0121'])) + ['EOG 061', 'MEG 1531', 'MEG 1441', 'MEG 0121']), ordered=False) assert 'eog' in raw raw.del_proj() # avoid warnings raw.set_annotations(Annotations([0.5], [0.5], ['BAD'])) @@ -724,8 +724,7 @@ def test_ica_additional(method, tmp_path, short_raw_epochs): assert_equal(ica_sorted.labels_, dict(blink=[3], think=[2])) # epochs extraction from raw fit - with pytest.warns(RuntimeWarning, match='could not be picked'), \ - pytest.raises(RuntimeError, match="match fitted data"): + with pytest.raises(ValueError, match='not present in the info'): ica.get_sources(epochs) # test filtering @@ -930,10 +929,9 @@ def f(x, y): assert_array_equal(raw_data, raw[:][0]) raw.drop_channels(raw.ch_names[:2]) - with pytest.raises(RuntimeError, match='match fitted'): - with pytest.warns(RuntimeWarning, match='longer'): - ica.find_bads_eog(raw) - with pytest.raises(RuntimeError, match='match fitted'): + with pytest.raises(ValueError, match='not present in the info'): + ica.find_bads_eog(raw) + with 
pytest.raises(ValueError, match='not present in the info'): with pytest.warns(RuntimeWarning, match='longer'): ica.find_bads_ecg(raw, threshold='auto') diff --git a/mne/rank.py b/mne/rank.py index 20fb43ea90a..104e28388b3 100644 --- a/mne/rank.py +++ b/mne/rank.py @@ -325,7 +325,8 @@ def compute_rank(inst, rank=None, scalings=None, info=None, tol='auto', info = info.copy() info['bads'] = [] inst = pick_channels_cov( - inst, set(inst['names']) & set(info['ch_names']), exclude=[]) + inst, set(inst['names']) & set(info['ch_names']), exclude=[], + ordered=False) if info['ch_names'] != inst['names']: info = pick_info(info, [info['ch_names'].index(name) for name in inst['names']]) diff --git a/mne/tests/test_chpi.py b/mne/tests/test_chpi.py index 851a5e19938..46247fadd91 100644 --- a/mne/tests/test_chpi.py +++ b/mne/tests/test_chpi.py @@ -313,7 +313,8 @@ def test_calculate_chpi_positions_vv(): raw_bad.crop(0, 1.) picks = np.concatenate([np.arange(306, len(raw_bad.ch_names)), pick_types(raw_bad.info, meg=True)[::16]]) - raw_bad.pick_channels([raw_bad.ch_names[pick] for pick in picks]) + raw_bad.pick_channels( + [raw_bad.ch_names[pick] for pick in picks], ordered=False) with pytest.warns(RuntimeWarning, match='Discrepancy'): with catch_logging() as log_file: _calculate_chpi_positions(raw_bad, t_step_min=1., verbose=True) diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py index f92752559b2..ddfec15686e 100644 --- a/mne/tests/test_docstring_parameters.py +++ b/mne/tests/test_docstring_parameters.py @@ -245,7 +245,6 @@ def test_tabs(): mesh_edges next_fast_len parallel_func -pick_channels_evoked plot_epochs_psd plot_epochs_psd_topomap plot_raw_psd_topo diff --git a/mne/tests/test_event.py b/mne/tests/test_event.py index cb40c670849..9eb93f2bd40 100644 --- a/mne/tests/test_event.py +++ b/mne/tests/test_event.py @@ -570,7 +570,7 @@ def test_acqparser_averaging(): ev_ref_mag = ev_ref.copy() ev_ref_mag.pick_channels(['MEG0111']) ev_ref_grad = ev_ref.copy() - ev_ref_grad.pick_channels(['MEG2643', 'MEG1622']) + ev_ref_grad.pick_channels(['MEG2643', 'MEG1622'], ordered=False) assert_allclose(ev_mag.data, ev_ref_mag.data, rtol=0, atol=1e-15) # tol = 1 fT # Elekta put these in a different order diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index 54b162e6eaf..4d1e076645a 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -20,7 +20,9 @@ from ..parallel import parallel_func -def pick_channels_csd(csd, include=[], exclude=[], ordered=False, copy=True): +@verbose +def pick_channels_csd(csd, include=[], exclude=[], ordered=None, copy=True, *, + verbose=None): """Pick channels from cross-spectral density matrix. Parameters @@ -31,16 +33,13 @@ def pick_channels_csd(csd, include=[], exclude=[], ordered=False, copy=True): List of channels to include (if empty, include all available). exclude : list of str Channels to exclude (if empty, do not exclude any). - ordered : bool - If True (default False), ensure that the order of the channels in the - modified instance matches the order of ``include``. - - .. versionadded:: 0.20.0 + %(ordered)s copy : bool If True (the default), return a copy of the CSD matrix with the modified channels. If False, channels are modified in-place. .. 
versionadded:: 0.20.0 + %(verbose)s Returns ------- diff --git a/mne/utils/check.py b/mne/utils/check.py index 95c598cf908..b5308227547 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -313,7 +313,7 @@ def _check_compensation_grade(info1, info2, name1, if t_info['comps']: with t_info._unlock(): t_info['comps'] = [] - picks = pick_channels(t_info['ch_names'], ch_names) + picks = pick_channels(t_info['ch_names'], ch_names, ordered=False) pick_info(t_info, picks, copy=False) # "or 0" here aliases None -> 0, as they are equivalent grade1 = get_current_comp(info1) or 0 diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 4b200154c0b..7a0d2968a65 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2474,6 +2474,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.22 """ +docdict['ordered'] = """ +ordered : bool + If True (default False), ensure that the order of the channels in + the modified instance matches the order of ``ch_names``. + + .. versionadded:: 0.20.0 + .. versionchanged:: 1.5 + The default changed from False in 1.4 to True in 1.5. +""" + docdict['origin_maxwell'] = """ origin : array-like, shape (3,) | str Origin of internal and external multipolar moment space in meters. diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 75599bd718a..813c9e80500 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -497,12 +497,12 @@ def _setup_channel_selections(raw, kind, order): # get stim channel (if any) stim_ch = _get_stim_channel(None, raw.info, raise_error=False) stim_ch = stim_ch if len(stim_ch) else [''] - stim_ch = pick_channels(raw.ch_names, stim_ch) + stim_ch = pick_channels(raw.ch_names, stim_ch, ordered=False) # loop over regions keys = np.concatenate([_SELECTIONS, _EEG_SELECTIONS]) for key in keys: channels = read_vectorview_selection(key, info=raw.info) - picks = pick_channels(raw.ch_names, channels) + picks = pick_channels(raw.ch_names, channels, ordered=False) picks = np.intersect1d(picks, order) if not len(picks): continue # omit empty selections diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 27b00571a31..afc3b6d660c 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -16,7 +16,7 @@ from matplotlib.colors import Colormap from matplotlib.figure import Figure -from mne import (make_field_map, pick_channels_evoked, read_evokeds, +from mne import (make_field_map, read_evokeds, read_trans, read_dipole, SourceEstimate, make_sphere_model, use_coil_def, pick_types, setup_volume_source_space, read_forward_solution, @@ -151,7 +151,7 @@ def test_plot_evoked_field(renderer): """Test plotting evoked field.""" evoked = read_evokeds(evoked_fname, condition='Left Auditory', baseline=(-0.2, 0.0)) - evoked = pick_channels_evoked(evoked, evoked.ch_names[::10]) # speed + evoked.pick(evoked.ch_names[::10]) # speed for t, n_contours in zip(['meg', None], [21, 0]): with pytest.warns(RuntimeWarning, match='projection'): maps = make_field_map(evoked, trans_fname, subject='sample', diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 22674c88706..3a20d25d9c0 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -373,7 +373,7 @@ def test_plot_epochs_ctf(raw_ctf, browser_backend): """Test of basic CTF plotting.""" raw_ctf.pick_channels(['UDIO001', 'UPPT001', 'SCLK01-177', 'BG1-4304', 'MLC11-4304', 'MLC11-4304', - 'EEG058', 'UADC007-4302']) + 'EEG058', 'UADC007-4302'], ordered=False) evts = make_fixed_length_events(raw_ctf) epochs = Epochs(raw_ctf, 
evts, preload=True) epochs.plot() diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index a5724c2e40e..291c381a6d0 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -210,7 +210,8 @@ def test_plot_ica_properties(): # Test merging grads. pick_names = raw.ch_names[:15:2] + raw.ch_names[1:15:2] - raw = _get_raw(preload=True).pick_channels(pick_names).crop(0, 5) + raw = _get_raw(preload=True).pick_channels(pick_names, ordered=False) + raw.crop(0, 5) raw.info.normalize_proj() ica = ICA(random_state=0, max_iter=1) with pytest.warns(UserWarning, match='did not converge'): @@ -220,7 +221,7 @@ def test_plot_ica_properties(): # Test handling of zeros ica = ICA(random_state=0, max_iter=1) - epochs.pick_channels(pick_names) + epochs.pick_channels(pick_names, ordered=False) with pytest.warns(UserWarning, match='did not converge'): ica.fit(epochs) epochs._data[0] = 0 @@ -314,13 +315,11 @@ def test_plot_ica_sources(raw_orig, browser_backend, monkeypatch): # test error handling raw_ = raw.copy().load_data() raw_.drop_channels('MEG 0113') - with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"), \ - pytest.warns(RuntimeWarning, match='could not be picked'): + with pytest.raises(ValueError, match="could not be picked"): ica.plot_sources(inst=raw_) epochs_ = epochs.copy().load_data() epochs_.drop_channels('MEG 0113') - with pytest.raises(RuntimeError, match="Epochs don't match fitted data"), \ - pytest.warns(RuntimeWarning, match='could not be picked'): + with pytest.raises(ValueError, match="could not be picked"): ica.plot_sources(inst=epochs_) del raw_ del epochs_ diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 997dfa002fc..dce8f4b37da 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -14,8 +14,7 @@ import matplotlib import matplotlib.pyplot as plt -from mne import (read_events, Epochs, pick_channels_evoked, read_cov, - compute_proj_evoked) +from mne import read_events, Epochs, read_cov, compute_proj_evoked from mne.channels import read_layout from mne.io import read_raw_fif from mne.time_frequency.tfr import AverageTFR @@ -159,8 +158,7 @@ def test_plot_topo(): evoked_delayed_ssp = _get_epochs_delayed_ssp().average() ch_names = evoked_delayed_ssp.ch_names[:3] # make it faster - picked_evoked_delayed_ssp = pick_channels_evoked(evoked_delayed_ssp, - ch_names) + picked_evoked_delayed_ssp = evoked_delayed_ssp.pick(ch_names) fig = plot_evoked_topo(picked_evoked_delayed_ssp, layout, proj='interactive') func = _get_presser(fig) diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index 8d79f70bc00..f142f49d4d2 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -142,7 +142,7 @@ def test_plot_projs_topomap(): eeg_proj = make_eeg_average_ref_proj(info) info_meg = pick_info(info, pick_types(info, meg=True, eeg=False)) - with pytest.raises(ValueError, match='No channel names in info match p'): + with pytest.raises(ValueError, match='Missing channels'): plot_projs_topomap([eeg_proj], info_meg) diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index e5b7e001aa6..d56b994e45b 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -369,10 +369,7 @@ def _plot_projs_topomap( types.append(list(these_ch_types)[0]) data = proj['data']['data'].ravel() info_names = _clean_names(info['ch_names'], remove_whitespace=True) - picks = pick_channels(info_names, ch_names) - if len(picks) == 0: - raise ValueError( - f'No channel names in info match projector {proj}') + 
picks = pick_channels(info_names, ch_names, ordered=True) use_info = pick_info(info, picks) data_picks, pos, merge_channels, names, ch_type, this_sphere, \ clip_origin = _prepare_topomap_plot( diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 9bac4490619..a34f24e69e5 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1066,7 +1066,8 @@ def plot_sensors(info, kind='topomap', ch_type=None, title=None, for selection in _SELECTIONS + _EEG_SELECTIONS: channels = pick_channels( info['ch_names'], - read_vectorview_selection(selection, info=info)) + read_vectorview_selection(selection, info=info), + ordered=False) ch_groups.append(channels) color_vals = np.ones((len(ch_groups), 4)) for idx, ch_group in enumerate(ch_groups): @@ -1862,7 +1863,8 @@ def _setup_plot_projector(info, noise_cov, proj=True, use_noise_cov=True, set(noise_cov['bads'])) # Actually compute the whitener only using the difference whiten_names = cov_names - bad_names - whiten_picks = pick_channels(info['ch_names'], whiten_names) + whiten_picks = pick_channels( + info['ch_names'], whiten_names, ordered=True) whiten_info = pick_info(info, whiten_picks) rank = _triage_rank_sss(whiten_info, [noise_cov])[1][0] whitener, whitened_ch_names = compute_whitener( @@ -1952,7 +1954,7 @@ def _triage_rank_sss(info, covs, rank=None, scalings=None): break if rank.get(ch_type) is None: ch_names = [info['ch_names'][pick] for pick in this_picks] - this_C = pick_channels_cov(cov, ch_names) + this_C = pick_channels_cov(cov, ch_names, ordered=False) this_estimated_rank = compute_rank( this_C, scalings=scalings, info=info_proj)[ch_type] this_rank[ch_type] = this_estimated_rank diff --git a/tools/azure_dependencies.sh b/tools/azure_dependencies.sh index 4b8647552b8..c68247cfd01 100755 --- a/tools/azure_dependencies.sh +++ b/tools/azure_dependencies.sh @@ -13,9 +13,10 @@ elif [ "${TEST_MODE}" == "pip-pre" ]; then # Broken as of 2022/09/20 # python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps --extra-index-url https://www.riverbankcomputing.com/pypi/simple PyQt6 PyQt6-sip PyQt6-Qt6 python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps PyQt6 PyQt6-sip PyQt6-Qt6 - # Wait for https://github.com/scipy/scipy/issues/17811 - python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps numpy - python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -i "https://pypi.anaconda.org/scipy-wheels-nightly/simple" scipy statsmodels pandas scikit-learn dipy matplotlib + python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -i "https://pypi.anaconda.org/scipy-wheels-nightly/simple" numpy + # # SciPy<->sklearn problematic, see https://github.com/scipy/scipy/issues/18377 + python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps scipy + python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -i "https://pypi.anaconda.org/scipy-wheels-nightly/simple" statsmodels pandas scikit-learn dipy matplotlib python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -f "https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com/" h5py python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps -i "https://test.pypi.org/simple" openmeeg python -m pip install --progress-bar off --upgrade --pre --only-binary ":all:" --no-deps 
-i "https://wheels.vtk.org/" vtk diff --git a/tools/circleci_bash_env.sh b/tools/circleci_bash_env.sh new file mode 100755 index 00000000000..9ba25eee5cd --- /dev/null +++ b/tools/circleci_bash_env.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +set -e +set -o pipefail + +./tools/setup_xvfb.sh +sudo apt install -qq graphviz optipng python3.10-venv python3-venv libxft2 ffmpeg +python3.10 -m venv ~/python_env +echo "set -e" >> $BASH_ENV +echo "set -o pipefail" >> $BASH_ENV +echo "export OPENBLAS_NUM_THREADS=4" >> $BASH_ENV +echo "export XDG_RUNTIME_DIR=/tmp/runtime-circleci" >> $BASH_ENV +echo "export MNE_FULL_DATE=true" >> $BASH_ENV +source tools/get_minimal_commands.sh +echo "export MNE_3D_BACKEND=pyvistaqt" >> $BASH_ENV +echo "export MNE_3D_OPTION_MULTI_SAMPLES=1" >> $BASH_ENV +echo "export MNE_BROWSER_BACKEND=qt" >> $BASH_ENV +echo "export MNE_BROWSER_PRECOMPUTE=false" >> $BASH_ENV +echo "export PATH=~/.local/bin/:$PATH" >> $BASH_ENV +echo "export DISPLAY=:99" >> $BASH_ENV +echo "source ~/python_env/bin/activate" >> $BASH_ENV +mkdir -p ~/.local/bin +ln -s ~/python_env/bin/python ~/.local/bin/python +echo "BASH_ENV:" +cat $BASH_ENV +mkdir -p ~/mne_data +touch pattern.txt diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh index 760066eb20e..677bd5ced22 100755 --- a/tools/circleci_dependencies.sh +++ b/tools/circleci_dependencies.sh @@ -1,18 +1,4 @@ #!/bin/bash -ef -echo "Installing setuptools and sphinx" -python -m pip install --upgrade "pip!=20.3.0" -python -m pip install --upgrade --progress-bar off setuptools wheel -python -m pip install --upgrade --progress-bar off --pre sphinx -if [[ "$CIRCLE_JOB" == "linkcheck"* ]]; then - echo "Installing minimal linkcheck dependencies" - python -m pip install --progress-bar off pillow pytest -r requirements_base.txt -r requirements_doc.txt -else # standard doc build - echo "Installing doc build dependencies" - python -m pip install --upgrade --progress-bar off PyQt6 - python -m pip install --upgrade --progress-bar off --only-binary "numpy,scipy,matplotlib,pandas,statsmodels" -r requirements.txt -r requirements_testing.txt -r requirements_doc.txt - python -m pip install --upgrade --progress-bar off git+https://github.com/mne-tools/mne-qt-browser git+https://github.com/sphinx-gallery/sphinx-gallery - # deal with comparisons and escapes (https://app.circleci.com/pipelines/github/mne-tools/mne-python/9686/workflows/3fd32b47-3254-4812-8b9a-8bab0d646d18/jobs/32934) - python -m pip install --upgrade --progress-bar off quantities -fi -python -m pip install -e . +python -m pip install --upgrade "pip!=20.3.0" setuptools wheel +python -m pip install --upgrade --progress-bar off --only-binary "numpy,scipy,matplotlib,pandas,statsmodels" -r requirements.txt -r requirements_testing.txt -r requirements_doc.txt PyQt6 git+https://github.com/mne-tools/mne-qt-browser -e . diff --git a/tutorials/simulation/80_dics.py b/tutorials/simulation/80_dics.py index 71fd32a0210..53ca16936d6 100644 --- a/tutorials/simulation/80_dics.py +++ b/tutorials/simulation/80_dics.py @@ -187,8 +187,11 @@ def coh_signal_gen(): # Plot some of the channels of the simulated data that are situated above one # of our simulated sources. 
-picks = mne.pick_channels(epochs.ch_names, - mne.read_vectorview_selection('Left-frontal')) +picks = mne.pick_channels( + epochs.ch_names, + mne.read_vectorview_selection('Left-frontal'), + ordered=False, +) epochs.plot(picks=picks) # %% diff --git a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py index 08fda2a59b4..6976632b00c 100644 --- a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py +++ b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py @@ -68,7 +68,8 @@ reject=dict(grad=4000e-13, eog=150e-6)) # just use right temporal sensors for speed -epochs.pick_channels(mne.read_vectorview_selection('Right-temporal')) +epochs.pick_channels( + mne.read_vectorview_selection('Right-temporal'), ordered=False) evoked = epochs.average() # Factor to down-sample the temporal dimension of the TFR computed by From 295a1bbcb8eec67773506464b56230e4ef7c8e41 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 1 May 2023 14:18:31 -0500 Subject: [PATCH 0041/1125] MAINT: Linkcheck (#11670) --- doc/_includes/institutional-partners.rst | 2 +- doc/changes/names.inc | 46 +++++++++--------- doc/conf.py | 61 +++++++++++++----------- doc/references.bib | 6 +-- mne/channels/montage.py | 7 ++- mne/io/cnt/cnt.py | 4 +- 6 files changed, 64 insertions(+), 62 deletions(-) diff --git a/doc/_includes/institutional-partners.rst b/doc/_includes/institutional-partners.rst index 083be66859d..fc2b0ee05df 100644 --- a/doc/_includes/institutional-partners.rst +++ b/doc/_includes/institutional-partners.rst @@ -32,7 +32,7 @@ Former partners - `Aarhus Universitet `_ - `Berkeley Institute for Data Science `_ - `Boston University `_ -- `Commissariat à l’énergie atomique et aux énergies alternatives `_ +- `Commissariat à l’énergie atomique et aux énergies alternatives `_ - `Forschungszentrum Jülich `_ - `Institut du Cerveau et de la Moelle épinière `_ - `Institut national de la santé et de la recherche médicale `_ diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 1894afa0db8..386d74eb809 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -1,6 +1,6 @@ -.. _Abram Hindle: http://softwareprocess.es +.. _Abram Hindle: https://softwareprocess.es -.. _Adam Li: http://github.com/adam2392 +.. _Adam Li: https://github.com/adam2392 .. _Adeline Fecker: https://github.com/adelinefecker @@ -16,13 +16,13 @@ .. _Alex Ciok: https://github.com/alexCiok -.. _Alex Gramfort: http://alexandre.gramfort.net +.. _Alex Gramfort: https://alexandre.gramfort.net .. _Alex Rockhill: https://github.com/alexrockhill/ .. _Alexander Rudiuk: https://github.com/ARudiuk -.. _Alexandre Barachant: http://alexandre.barachant.org +.. _Alexandre Barachant: https://alexandre.barachant.org .. _Andrea Brovelli: https://andrea-brovelli.net @@ -78,7 +78,7 @@ .. _Chris Holdgraf: https://chrisholdgraf.com -.. _Chris Mullins: http://crmullins.com +.. _Chris Mullins: https://crmullins.com .. _Christian Brodbeck: https://github.com/christianbrodbeck @@ -100,7 +100,7 @@ .. _Daniel Hasegan: https://daniel.hasegan.com -.. _Daniel McCloy: http://dan.mccloy.info +.. _Daniel McCloy: https://dan.mccloy.info .. _Daniel Strohmeier: https://github.com/joewalter @@ -114,7 +114,7 @@ .. _Demetres Kostas: https://github.com/kostasde -.. _Denis Engemann: http://denis-engemann.de +.. _Denis Engemann: https://denis-engemann.de .. _Dinara Issagaliyeva: https://github.com/dissagaliyeva @@ -132,13 +132,13 @@ .. _Eduard Ort: https://github.com/eort -.. 
_Emily Stephen: http://github.com/emilyps14 +.. _Emily Stephen: https://github.com/emilyps14 .. _Enrico Varano: https://github.com/enricovara/ .. _Enzo Altamiranda: https://www.linkedin.com/in/enzoalt -.. _Eric Larson: http://larsoner.com +.. _Eric Larson: https://larsoner.com .. _Erica Peterson: https://github.com/nordme @@ -184,7 +184,7 @@ .. _Hakimeh Pourakbari: https://github.com/Hpakbari -.. _Hari Bharadwaj: http://www.haribharadwaj.com +.. _Hari Bharadwaj: https://github.com/haribharadwaj .. _Henrich Kolkhorst: https://github.com/hekolk @@ -280,7 +280,7 @@ .. _Lau Møller Andersen: https://github.com/ualsbombe -.. _Laura Gwilliams: http://lauragwilliams.github.io +.. _Laura Gwilliams: https://lauragwilliams.github.io .. _Leonardo Barbosa: https://github.com/noreun @@ -348,7 +348,7 @@ .. _Mikołaj Magnuski: https://github.com/mmagnuski -.. _Milan Rybář: http://milanrybar.cz +.. _Milan Rybář: https://milanrybar.cz .. _Mingjian He: https://github.com/mh105 @@ -362,9 +362,9 @@ .. _Naveen Srinivasan: https://github.com/naveensrinivasan -.. _Nick Foti: http://nfoti.github.io +.. _Nick Foti: https://nfoti.github.io -.. _Nick Ward: http://www.ucl.ac.uk/ion/departments/sobell/Research/NWard +.. _Nick Ward: https://www.ucl.ac.uk/ion/departments/sobell/Research/NWard .. _Nicolas Barascud: https://github.com/nbara @@ -376,7 +376,7 @@ .. _Okba Bekhelifi: https://github.com/okbalefthanded -.. _Olaf Hauk: http://www.neuroscience.cam.ac.uk/directory/profile.php?olafhauk +.. _Olaf Hauk: https://www.neuroscience.cam.ac.uk/directory/profile.php?olafhauk .. _Oleh Kozynets: https://github.com/OlehKSS @@ -436,7 +436,7 @@ .. _Romain Trachel: https://fr.linkedin.com/in/trachelr -.. _Roman Goj: http://romanmne.blogspot.co.uk +.. _Roman Goj: https://romanmne.blogspot.co.uk .. _Ross Maddox: https://www.urmc.rochester.edu/labs/maddox-lab.aspx @@ -450,13 +450,13 @@ .. _Santeri Ruuskanen: https://github.com/ruuskas -.. _Sara Sommariva: http://www.dima.unige.it/~sommariva/ +.. _Sara Sommariva: https://www.dima.unige.it/~sommariva/ -.. _Sawradip Saha: http://sawradip.github.io/ +.. _Sawradip Saha: https://sawradip.github.io/ .. _Scott Huberty: https://orcid.org/0000-0003-2637-031X -.. _Sebastiaan Mathot: http://www.cogsci.nl/smathot +.. _Sebastiaan Mathot: https://www.cogsci.nl/smathot .. _Sebastian Castano: https://github.com/jscastanoc @@ -476,7 +476,7 @@ .. _Simon Kern: https://github.com/skjerns -.. _Simon Kornblith: http://simonster.com +.. _Simon Kornblith: https://simonster.com .. _Sondre Foslien: https://github.com/sondrfos @@ -484,7 +484,7 @@ .. _Stanislas Chambon: https://github.com/Slasnista -.. _Stefan Appelhoff: http://stefanappelhoff.com +.. _Stefan Appelhoff: https://stefanappelhoff.com .. _Stefan Repplinger: https://github.com/stfnrpplngr @@ -502,7 +502,7 @@ .. _T. Wang: https://github.com/twang5 -.. _Tal Linzen: http://tallinzen.net/ +.. _Tal Linzen: https://tallinzen.net/ .. _Teon Brooks: https://teonbrooks.com @@ -526,7 +526,7 @@ .. _Tommy Clausner: https://github.com/TommyClausner -.. _Toomas Erik Anijärv: http://www.toomaserikanijarv.com/ +.. _Toomas Erik Anijärv: https://www.toomaserikanijarv.com/ .. 
_Tristan Stenner: https://github.com/tstenner/ diff --git a/doc/conf.py b/doc/conf.py index 4897d6f1527..e0748d37053 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -160,7 +160,8 @@ 'picard': ('https://pierreablin.github.io/picard/', None), 'qdarkstyle': ('https://qdarkstylesheet.readthedocs.io/en/latest', None), 'eeglabio': ('https://eeglabio.readthedocs.io/en/latest', None), - 'dipy': ('https://dipy.org/documentation/latest/', None), + 'dipy': ('https://dipy.org/documentation/1.7.0/', + 'https://dipy.org/documentation/1.7.0/objects.inv/'), 'pooch': ('https://www.fatiando.org/pooch/latest/', None), 'pybv': ('https://pybv.readthedocs.io/en/latest/', None), 'pyqtgraph': ('https://pyqtgraph.readthedocs.io/en/latest/', None), @@ -521,37 +522,41 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # -- Other extension configuration ------------------------------------------- -linkcheck_request_headers = dict(user_agent='Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36') # noqa: E501 +user_agent = 'Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Mobile Safari/537.36' # noqa: E501 +# Can eventually add linkcheck_request_headers if needed linkcheck_ignore = [ # will be compiled to regex - r'https://datashare.is.ed.ac.uk/handle/10283/2189/?show=full', # noqa Max retries exceeded with url: /handle/10283/2189?show=full (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1123)'))) - 'https://doi.org/10.1002/mds.870120629', # Read timed out. - 'https://doi.org/10.1088/0031-9155/32/1/004', # noqa Read timed out. (read timeout=15) - 'https://doi.org/10.1088/0031-9155/40/3/001', # noqa Read timed out. (read timeout=15) - 'https://doi.org/10.1088/0031-9155/51/7/008', # noqa Read timed out. (read timeout=15) - 'https://doi.org/10.1088/0031-9155/57/7/1937', # noqa Read timed out. (read timeout=15) - 'https://doi.org/10.1088/0967-3334/22/4/305', # noqa Read timed out. (read timeout=15) - 'https://doi.org/10.1088/1741-2552/aacfe4', # noqa Read timed out. 
(read timeout=15) - 'https://doi.org/10.1093/sleep/18.7.557', # noqa 403 Client Error: Forbidden for url: https://academic.oup.com/sleep/article-lookup/doi/10.1093/sleep/18.7.557 - 'https://doi.org/10.1162/089976699300016719', # noqa 403 Client Error: Forbidden for url: https://direct.mit.edu/neco/article/11/2/417-441/6242 - 'https://doi.org/10.1162/jocn.1993.5.2.162', # noqa 403 Client Error: Forbidden for url: https://direct.mit.edu/jocn/article/5/2/162-176/3095 - 'https://doi.org/10.1162/neco.1995.7.6.1129', # noqa 403 Client Error: Forbidden for url: https://direct.mit.edu/neco/article/7/6/1129-1159/5909 - 'https://doi.org/10.1162/jocn_a_00405', # noqa 403 Client Error: Forbidden for url: https://direct.mit.edu/jocn/article/25/9/1477-1492/27980 - 'https://doi.org/10.1167/15.6.4', # noqa 403 Client Error: Forbidden for url: https://jov.arvojournals.org/article.aspx?doi=10.1167/15.6.4 - 'https://doi.org/10.7488/ds/1556', # noqa Max retries exceeded with url: /handle/10283/2189 (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1122)'))) - 'https://imaging.mrc-cbu.cam.ac.uk/imaging/MniTalairach', # noqa Max retries exceeded with url: /imaging/MniTalairach (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1122)'))) - 'https://www.nyu.edu/', # noqa Max retries exceeded with url: / (Caused by SSLError(SSLError(1, '[SSL: DH_KEY_TOO_SMALL] dh key too small (_ssl.c:1122)'))) - 'https://docs.python.org/3/library/.*', # noqa ('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer')) - 'https://hal.archives-ouvertes.fr/hal-01848442.*', # noqa Sometimes: 503 Server Error: Service Unavailable for url: https://hal.archives-ouvertes.fr/hal-01848442/ - 'http://www.cs.ucl.ac.uk/staff/d.barber/brml.*', # noqa Sometimes: Read timed out - 'https://compumedicsneuroscan.com/scan-acquire-configuration-files.*', # noqa SSL certificate error as of 2021/09/28 - 'https://chrisholdgraf.com/', # noqa Max retries exceeded sometimes - 'https://www.dtu.dk/english/service/phonebook/person.*', # noqa Too slow - 'https://speakerdeck.com/dengemann/eeg-sensor-covariance-using-cross-validation', # noqa Too slow - 'https://doi.org/10.1002/hbm.10024', # noqa Too slow sometimes - 'https://www.researchgate.net/', # noqa As of 2022/05/31 we get "403 Forbidden" errors, might have to do with https://stackoverflow.com/questions/72347165 but not worth the effort to fix + # 403 Client Error: Forbidden + "https://doi.org/10.1002/", # onlinelibrary.wiley.com/doi/10.1002/hbm + "https://doi.org/10.1021/", # pubs.acs.org/doi/abs + "https://doi.org/10.1073/", # pnas.org + "https://doi.org/10.1093/", # academic.oup.com/sleep/ + "https://doi.org/10.1098/", # royalsocietypublishing.org + "https://doi.org/10.1111/", # onlinelibrary.wiley.com/doi/10.1111/psyp + "https://doi.org/10.1126/", # www.science.org + "https://doi.org/10.1137/", # epubs.siam.org + "https://doi.org/10.1161/", # www.ahajournals.org + "https://doi.org/10.1162/", # direct.mit.edu/neco/article/ + "https://doi.org/10.1167/", # jov.arvojournals.org + "https://doi.org/10.1177/", # 
journals.sagepub.com + "https://doi.org/10.1063/", # pubs.aip.org/aip/jap + "https://www.researchgate.net/profile/", + # 503 Server error + "https://hal.archives-ouvertes.fr/hal-01848442", + # Read timed out + "http://www.cs.ucl.ac.uk/staff/d.barber/brml", + "https://www.cea.fr/", + # Max retries exceeded + "https://doi.org/10.7488/ds/1556", + "https://datashare.is.ed.ac.uk/handle/10283", + "https://imaging.mrc-cbu.cam.ac.uk/imaging/MniTalairach", + "https://www.nyu.edu/", + # Too slow + "https://speakerdeck.com/dengemann/", + "https://www.dtu.dk/english/service/phonebook/person", ] linkcheck_anchors = False # saves a bit of time linkcheck_timeout = 15 # some can be quite slow +linkcheck_retries = 3 # autodoc / autosummary autosummary_generate = True diff --git a/doc/references.bib b/doc/references.bib index 95d78aa1ce6..9415b165484 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -2411,21 +2411,19 @@ @article{TierneyEtAl2021 title = {Modelling optically pumped magnetometer interference in MEG as a spatially homogeneous magnetic field}, volume = {244}, issn = {1053-8119}, - doi = {j.neuroimage.2021.118484}, + doi = {10.1016/j.neuroimage.2021.118484}, language = {en}, journal = {NeuroImage}, author = {Tierney, Tim M. and Alexander, Nicholas and Mellor, Stephanie and Holmes, Niall and Seymour, Robert and O'Neill, George C. and Maguire, Eleanor A. and Barnes, Gareth R.}, year = {2021}, - pages = {118834} } @article{TierneyEtAl2022, title = {Spherical harmonic based noise rejection and neuronal sampling with multi-axis OPMs}, journal = {NeuroImage}, volume = {258}, - pages = {119338}, year = {2022}, issn = {1053-8119}, - doi = {j.neuroimage.2022.119338}, + doi = {10.1016/j.neuroimage.2022.119338}, author = {Tierney, Tim M. and Mellor, Stephanie and O'Neill, George C. and Timms, Ryan C. and Barnes, Gareth R.}, } \ No newline at end of file diff --git a/mne/channels/montage.py b/mne/channels/montage.py index 6a2a84241ea..178557bc520 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -691,10 +691,9 @@ def read_dig_dat(fname): .. Warning:: This function was implemented based on ``*.dat`` files available from - `Compumedics <https://compumedicsneuroscan.com/scan-acquire-configuration-files>`__ and might not work as expected with novel - files. If it does not read your files correctly please contact the - mne-python developers. + `Compumedics <https://compumedicsneuroscan.com/scan-acquire-configuration-files>`__ and might not work + as expected with novel files. If it does not read your files correctly + please contact the MNE-Python developers.
Parameters ---------- diff --git a/mne/io/cnt/cnt.py b/mne/io/cnt/cnt.py index f7ee029184b..f7fdef8710c 100644 --- a/mne/io/cnt/cnt.py +++ b/mne/io/cnt/cnt.py @@ -113,8 +113,8 @@ def read_raw_cnt(input_fname, eog=(), misc=(), ecg=(), Montages can be created/imported with: - Standard montages with :func:`mne.channels.make_standard_montage` - - Montages for `Compumedics systems <https://compumedicsneuroscan.com/scan-acquire-configuration-files>`_ with + - Montages for `Compumedics systems + <https://compumedicsneuroscan.com/scan-acquire-configuration-files>`__ with :func:`mne.channels.read_dig_dat` - Other reader functions are listed under *See Also* at :class:`mne.channels.DigMontage` From 263114e32d4fe4951f8f622f7ca74c30bfba37ae Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Mon, 1 May 2023 15:26:00 -0500 Subject: [PATCH 0042/1125] ENH: Fix CNT annotations to include bad spans (#11631) --- doc/changes/latest.inc | 3 +- doc/changes/names.inc | 2 ++ mne/datasets/config.py | 4 +-- mne/io/cnt/cnt.py | 62 ++++++++++++++++++++++++++++++++++-- mne/io/cnt/tests/test_cnt.py | 9 ++++++ 5 files changed, 75 insertions(+), 5 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 851e3b04587..a0383e025f1 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -23,6 +23,7 @@ Current (1.4.dev0) Enhancements ~~~~~~~~~~~~ +- Add functionality for reading CNT spans/annotations marked bad to :func:`mne.io.read_raw_cnt` (:gh:`11631` by :newcontrib:`Jacob Woessner`) - Add ``:unit:`` Sphinx directive to enable use of uniform non-breaking spaces throughout the documentation (:gh:`11469` by :newcontrib:`Sawradip Saha`) - Adjusted the algorithm used in :class:`mne.decoding.SSD` to support non-full rank data (:gh:`11458` by :newcontrib:`Thomas Binns`) - Changed suggested type for ``ch_groups`` in `mne.viz.plot_sensors` from array to list of list(s) (arrays are still supported).
(:gh:`11465` by `Hyonyoung Shin`_) @@ -53,6 +54,7 @@ Bugs - Fix :func:`mne.time_frequency.psd_array_multitaper` docstring where argument ``bandwidth`` incorrectly reported argument as half-bandwidth and gave wrong explanation of default value (:gh:`11479` by :newcontrib:`Tom Stone`) - Fix bug where installation of a package depending on ``mne`` will error when done in an environment where ``setuptools`` is not present (:gh:`11454` by :newcontrib:`Arne Pelzer`) - Fix bug in :meth:`mne.Annotations.rename` where replacements were not done correctly (:gh:`11666` by :newcontrib:`Timur Sokhin` and `Eric Larson`_) +- Fix :meth:`mne.time_frequency.Spectrum.to_data_frame`'s docstring to reflect the correct name for the appended frequencies column (:gh:`11457` by :newcontrib:`Zvi Baratz`) - Fix bug where :func:`mne.preprocessing.regress_artifact` and :class:`mne.preprocessing.EOGRegression` incorrectly tracked ``picks`` (:gh:`11366` by `Eric Larson`_) - Fix bug where channel names were not properly sanitized in :func:`mne.write_evokeds` and related functions (:gh:`11399` by `Eric Larson`_) - Fix bug where splash screen would not always disappear (:gh:`11398` by `Eric Larson`_) @@ -61,7 +63,6 @@ Bugs - Fix :func:`mne.io.read_raw_edf` when reading EDF data with different sampling rates and a mix of data channels when using ``infer_types=True`` (:gh:`11427` by `Alex Gramfort`_) - Fix how :class:`mne.channels.DigMontage` is set when using :func:`mne.preprocessing.ieeg.project_sensors_onto_brain` so that :func:`mne.Info.get_montage` works and does not return ``None`` (:gh:`11436` by `Alex Rockhill`_) - Fix configuration folder discovery on Windows, which would fail in certain edge cases; and produce a helpful error message if discovery still fails (:gh:`11441` by `Richard Höchenberger`_) -- Fix :meth:`mne.time_frequency.Spectrum.to_data_frame`'s docstring to reflect the correct name for the appended frequencies column (:gh:`11457` by :newcontrib:`Zvi Baratz`) - Make :class:`~mne.decoding.SlidingEstimator` and :class:`~mne.decoding.GeneralizingEstimator` respect the ``verbose`` argument. Now with ``verbose=False``, the progress bar is not shown during fitting, scoring, etc. (:gh:`11450` by `Mikołaj Magnuski`_) - Fix bug with ``mne.gui.locate_ieeg`` where Freesurfer ``?h.pial.T1`` was not recognized and suppress excess logging (:gh:`11489` by `Alex Rockhill`_) - All functions accepting paths can now correctly handle :class:`~pathlib.Path` as input. Historically, we expected strings (instead of "proper" path objects), and only added :class:`~pathlib.Path` support in a few select places, leading to inconsistent behavior. (:gh:`11473` and :gh:`11499` by `Mathieu Scheltienne`_) diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 386d74eb809..164949340d1 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -204,6 +204,8 @@ .. _Jack Zhang: https://github.com/jackz314 +.. _Jacob Woessner: https://github.com/withmywoessner + .. _Jair Montoya Martinez: https://github.com/jmontoyam .. _Jan Ebert: https://www.jan-ebert.com/ diff --git a/mne/datasets/config.py b/mne/datasets/config.py index da2ed677566..ec45dbbf91b 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # respective repos, and make a new release of the dataset on GitHub. 
Then # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓ ↓↓↓ -RELEASES = dict(testing='0.145', misc='0.26') +RELEASES = dict(testing='0.146', misc='0.26') TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}' MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}' @@ -111,7 +111,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS['testing'] = dict( archive_name=f'{TESTING_VERSIONED}.tar.gz', - hash='md5:2036f7d7616129c624b757fbb019be24', + hash='md5:a2e86fe404f4321408b22f38711d11b7', url=('https://codeload.github.com/mne-tools/mne-testing-data/' f'tar.gz/{RELEASES["testing"]}'), # In case we ever have to resort to osf.io again... diff --git a/mne/io/cnt/cnt.py b/mne/io/cnt/cnt.py index f7fdef8710c..44876435678 100644 --- a/mne/io/cnt/cnt.py +++ b/mne/io/cnt/cnt.py @@ -46,6 +46,17 @@ def _read_annotations_cnt(fname, data_format='int16'): SETUP_NCHANNELS_OFFSET = 370 SETUP_RATE_OFFSET = 376 + def _accept_reject_function(keypad_accept): + accept_list = [] + for code in keypad_accept: + if 'xd0' in str(code): + accept_list.append('good') + elif 'xc0' in str(code): + accept_list.append('bad') + else: + accept_list.append('NA') + return np.array(accept_list) + def _translating_function(offset, n_channels, event_type, data_format=data_format): n_bytes = 2 if data_format == 'int16' else 4 @@ -53,7 +64,47 @@ def _translating_function(offset, n_channels, event_type, offset *= n_bytes * n_channels event_time = offset - 900 - (75 * n_channels) event_time //= n_channels * n_bytes - return event_time - 1 + event_time = event_time - 1 + # Prevent negative event times + np.clip(event_time, 0, None, out=event_time) + return event_time + + def _update_bad_span_onset(accept_reject, onset, duration, description): + accept_reject = accept_reject.tolist() + onset = onset.tolist() + duration = duration.tolist() + description = description.tolist() + # If there are no bad spans, return original parameters + if 'bad' not in accept_reject: + return np.array(onset), np.array(duration), np.array(description) + # Create lists of bad and good span markers and onset + bad_good_span_markers = [i for i in accept_reject + if i in ['bad', 'good']] + bad_good_onset = [onset[i] for i, value in enumerate(accept_reject) + if value in ['bad', 'good']] + # Calculate duration of bad span + first_bad_index = bad_good_span_markers.index('bad') + duration_list = [bad_good_onset[i + 1] - bad_good_onset[i] + for i in range(first_bad_index, + len(bad_good_span_markers), 2)] + # Add bad event marker duration and description + duration_list_index = 0 + for i in range(len(onset)): + if accept_reject[i] == 'bad': + duration[i] = duration_list[duration_list_index] + description[i] = 'BAD_' + description[i] + duration_list_index += 1 + # Remove good span markers + final_onset, final_duration, final_description = [], [], [] + for i in range(len(accept_reject)): + if accept_reject[i] != 'good': + final_onset.append(onset[i]) + final_duration.append(duration[i]) + final_description.append(description[i]) + return ( + np.array(final_onset), + np.array(final_duration), + np.array(final_description)) with open(fname, 'rb') as fid: fid.seek(SETUP_NCHANNELS_OFFSET) @@ -86,9 +137,16 @@ def _translating_function(offset, n_channels, event_type, data_format=data_format) duration = np.array([getattr(e, 'Latency', 0.) 
for e in my_events], dtype=float) + accept_reject = _accept_reject_function( + np.array([e.KeyPad_Accept for e in my_events])) description = np.array([str(e.StimType) for e in my_events]) - return Annotations(onset=onset / sfreq, + + onset, duration, description = _update_bad_span_onset(accept_reject, + onset / sfreq, + duration, + description) + return Annotations(onset=onset, duration=duration, description=description, orig_time=None) diff --git a/mne/io/cnt/tests/test_cnt.py b/mne/io/cnt/tests/test_cnt.py index ac37c7fe38e..f4af393f06d 100644 --- a/mne/io/cnt/tests/test_cnt.py +++ b/mne/io/cnt/tests/test_cnt.py @@ -15,6 +15,7 @@ data_path = testing.data_path(download=False) fname = data_path / "CNT" / "scan41_short.cnt" +fname_bad_spans = data_path / "CNT" / "test_CNT_events_mne_JWoess_clipped.cnt" @testing.requires_testing_data @@ -50,3 +51,11 @@ def test_compare_events_and_annotations(): assert len(annot) == 6 assert_array_equal(annot.onset[:-1], events[:, 0] / raw.info['sfreq']) assert 'STI 014' not in raw.info['ch_names'] + + +@testing.requires_testing_data +def test_bad_spans(): + """Test reading raw cnt files with bad spans.""" + annot = read_annotations(fname_bad_spans) + temp = '\t'.join(annot.description) + assert 'BAD' in temp From 97bbc37a7375e835f3db7d2d99f8168f7338abae Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 28 Apr 2023 11:12:13 -0400 Subject: [PATCH 0043/1125] MAINT: Use black --- .github/workflows/tests.yml | 1 + .pre-commit-config.yaml | 10 +++++----- azure-pipelines.yml | 2 +- ignore_words.txt | 2 ++ mne/channels/tests/test_montage.py | 12 ++++++------ mne/chpi.py | 11 ++++++----- mne/conftest.py | 6 +++--- mne/coreg.py | 14 ++++++++++---- mne/io/kit/tests/test_kit.py | 19 ++++++++++++------- mne/io/tests/test_raw.py | 12 +++++++++--- mne/io/tests/test_reference.py | 3 ++- mne/source_space.py | 9 +++++++-- mne/tests/test_annotations.py | 4 ++-- mne/tests/test_docstring_parameters.py | 5 ++++- mne/tests/test_source_estimate.py | 2 +- mne/tests/test_source_space.py | 2 +- mne/transforms.py | 2 +- mne/utils/tests/test_check.py | 4 +++- mne/viz/backends/_notebook.py | 10 ++++++++-- mne/viz/backends/_qt.py | 12 +++++++++--- pyproject.toml | 4 +++- requirements_testing.txt | 1 + .../forward/50_background_freesurfer_mne.py | 2 +- 23 files changed, 98 insertions(+), 51 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3e8a3195c7a..d535e037c0a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,6 +23,7 @@ jobs: - uses: actions/setup-python@v4 with: python-version: '3.11' + - uses: psf/black@stable - uses: pre-commit/action@v3.0.0 pytest: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7ef76755eae..59a60c19015 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,9 @@ repos: -# - repo: https://github.com/psf/black -# rev: 23.1.0 -# hooks: -# - id: black -# args: [--quiet] +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + args: [--quiet] # Ruff mne - repo: https://github.com/charliermarsh/ruff-pre-commit diff --git a/azure-pipelines.yml b/azure-pipelines.yml index b050cc191c1..f0665efc164 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -57,7 +57,7 @@ stages: displayName: Install dependencies - bash: | make pre-commit - displayName: make ruff + displayName: make pre-commit condition: always() - bash: | make nesting diff --git a/ignore_words.txt b/ignore_words.txt index 8dde5403c07..c09662e1a1a 100644 --- 
a/ignore_words.txt +++ b/ignore_words.txt @@ -14,6 +14,7 @@ nd cas thes ba +bu ist od fo @@ -33,3 +34,4 @@ recuse ro nam shs +pres diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index 5fe16a2294d..f78e6bb3f2d 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -223,12 +223,12 @@ def test_documented(): pytest.param( partial(read_custom_montage, head_size=None, coord_frame='mri'), - ('// MatLab Sphere coordinates [degrees] Cartesian coordinates\n' # noqa: E501 - '// Label Theta Phi Radius X Y Z off sphere surface\n' # noqa: E501 - 'E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n' # noqa: E501 - 'E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000\n' # noqa: E501 - 'E31 90.000 -11.000 1.000 0.0000 0.9816 -0.1908 0.00000000000000000\n' # noqa: E501 - 'E61 158.000 -17.200 1.000 -0.8857 0.3579 -0.2957 -0.00000000000000022'), # noqa: E501 + "// MatLab Sphere coordinates [degrees] Cartesian coordinates\n" # noqa: E501 + "// Label Theta Phi Radius X Y Z off sphere surface\n" # noqa: E501 + "E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n" # noqa: E501 + "E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000\n" # noqa: E501 + "E31 90.000 -11.000 1.000 0.0000 0.9816 -0.1908 0.00000000000000000\n" # noqa: E501 + "E61 158.000 -17.200 1.000 -0.8857 0.3579 -0.2957 -0.00000000000000022", # noqa: E501 make_dig_montage( ch_pos={ 'E1': [0.7677, 0.5934, -0.2419], diff --git a/mne/chpi.py b/mne/chpi.py index 648ad6ca78a..9d80fa6efde 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -847,11 +847,12 @@ def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, # 1. Check number of good ones # if len(use_idx) < 3: - msg = (_time_prefix(fit_time) + '%s/%s good HPI fits, cannot ' - 'determine the transformation (%s GOF)!' - % (len(use_idx), n_coils, - ', '.join('%0.2f' % g for g in g_coils))) - warn(msg) + gofs = ', '.join(f"{g:0.2f}" for g in g_coils) + warn( + f"{_time_prefix(fit_time)}{len(use_idx)}/{n_coils} " + "good HPI fits, cannot determine the transformation " + f"({gofs} GOF)!" 
+ ) continue # diff --git a/mne/conftest.py b/mne/conftest.py index 9a64066b852..72e95b6e788 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -101,8 +101,8 @@ def pytest_configure(config): first_kind = 'error' else: first_kind = 'always' - warning_lines = r""" - {0}:: + warning_lines = f" {first_kind}::" + warning_lines += r""" # matplotlib->traitlets (notebook) ignore:Passing unrecognized arguments to super.*:DeprecationWarning # notebook tests @@ -142,7 +142,7 @@ def pytest_configure(config): ignore:pkg_resources is deprecated as an API.*:DeprecationWarning # h5py ignore:`product` is deprecated as of NumPy.*:DeprecationWarning - """.format(first_kind) # noqa: E501 + """ # noqa: E501 for warning_line in warning_lines.split('\n'): warning_line = warning_line.strip() if warning_line and not warning_line.startswith('#'): diff --git a/mne/coreg.py b/mne/coreg.py index db0b3645633..3e21f3ff917 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -23,11 +23,17 @@ from .io._digitization import _get_data_as_dict_from_dig # keep get_mni_fiducials for backward compat (no burden to keep in this # namespace, too) -from ._freesurfer import (_read_mri_info, get_mni_fiducials, # noqa: F401 - estimate_head_mri_t) # noqa: F401 +from ._freesurfer import ( + _read_mri_info, + get_mni_fiducials, + estimate_head_mri_t, # noqa: F401 +) from .label import read_label, Label -from .source_space import (add_source_space_distances, read_source_spaces, # noqa: E501,F401 - write_source_spaces) +from .source_space import ( + add_source_space_distances, + read_source_spaces, # noqa: F401 + write_source_spaces, +) from .surface import (read_surface, write_surface, _normalize_vectors, complete_surface_info, decimate_surface, _DistanceQuery) diff --git a/mne/io/kit/tests/test_kit.py b/mne/io/kit/tests/test_kit.py index d3746012328..696d10a83da 100644 --- a/mne/io/kit/tests/test_kit.py +++ b/mne/io/kit/tests/test_kit.py @@ -69,8 +69,9 @@ def test_data(tmp_path): # check functionality raw_mrk = read_raw_kit(sqd_path, [mrk2_path, mrk3_path], elp_txt_path, hsp_txt_path) - assert raw_mrk.info['description'] == \ - 'NYU 160ch System since Jan24 2009 (34) V2R004 EQ1160C' + assert ( + raw_mrk.info['description'] == 'NYU 160ch System since Jan24 2009 (34) V2R004 EQ1160C' # noqa: E501 + ) raw_py = _test_raw_reader(read_raw_kit, input_fname=sqd_path, mrk=mrk_path, elp=elp_txt_path, hsp=hsp_txt_path, stim=list(range(167, 159, -1)), slope='+', @@ -123,8 +124,9 @@ def test_data(tmp_path): # KIT-UMD data _test_raw_reader(read_raw_kit, input_fname=sqd_umd_path, test_rank='less') raw = read_raw_kit(sqd_umd_path) - assert raw.info['description'] == \ - 'University of Maryland/Kanazawa Institute of Technology/160-channel MEG System (53) V2R004 PQ1160R' # noqa: E501 + assert ( + raw.info['description'] == 'University of Maryland/Kanazawa Institute of Technology/160-channel MEG System (53) V2R004 PQ1160R' # noqa: E501 + ) assert_equal(raw.info['kit_system_id'], KIT.SYSTEM_UMD_2014_12) # check number/kind of channels assert_equal(len(raw.info['chs']), 193) @@ -135,8 +137,9 @@ def test_data(tmp_path): # KIT Academia Sinica raw = read_raw_kit(sqd_as_path, slope='+') - assert raw.info['description'] == \ - 'Academia Sinica/Institute of Linguistics//Magnetoencephalograph System (261) V2R004 PQ1160R-N2' # noqa: E501 + assert ( + raw.info['description'] == 'Academia Sinica/Institute of Linguistics//Magnetoencephalograph System (261) V2R004 PQ1160R-N2' # noqa: E501 + ) assert_equal(raw.info['kit_system_id'], KIT.SYSTEM_AS_2008) 
assert_equal(raw.info['chs'][100]['ch_name'], 'MEG 101') assert_equal(raw.info['chs'][100]['kind'], FIFF.FIFFV_MEG_CH) @@ -374,7 +377,9 @@ def test_berlin(): """Test data from Berlin.""" # gh-8535 raw = read_raw_kit(berlin_path) - assert raw.info['description'] == 'Physikalisch Technische Bundesanstalt, Berlin/128-channel MEG System (124) V2R004 PQ1128R-N2' # noqa: E501 + assert ( + raw.info['description'] == 'Physikalisch Technische Bundesanstalt, Berlin/128-channel MEG System (124) V2R004 PQ1128R-N2' # noqa: E501 + ) assert raw.info['kit_system_id'] == 124 assert raw.info['highpass'] == 0. assert raw.info['lowpass'] == 200. diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index 694cd46c941..4c728df90ef 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -739,9 +739,15 @@ def test_describe_print(): assert re.match( r'', s[0]) is not None, s[0] - assert s[1] == " ch name type unit min Q1 median Q3 max" # noqa - assert s[2] == " 0 MEG 0113 GRAD fT/cm -221.80 -38.57 -9.64 19.29 414.67" # noqa - assert s[-1] == "375 EOG 061 EOG µV -231.41 271.28 277.16 285.66 334.69" # noqa + assert ( + s[1] == " ch name type unit min Q1 median Q3 max" # noqa: E501 + ) + assert ( + s[2] == " 0 MEG 0113 GRAD fT/cm -221.80 -38.57 -9.64 19.29 414.67" # noqa: E501 + ) + assert ( + s[-1] == "375 EOG 061 EOG µV -231.41 271.28 277.16 285.66 334.69" # noqa: E501 + ) @requires_pandas diff --git a/mne/io/tests/test_reference.py b/mne/io/tests/test_reference.py index 8ab37fb5879..0cfb2a5349e 100644 --- a/mne/io/tests/test_reference.py +++ b/mne/io/tests/test_reference.py @@ -329,7 +329,8 @@ def test_set_eeg_reference_rest(): # load('leadfield.mat', 'G'); # dat_ref = ft_preproc_rereference(dat, 'all', 'rest', true, G); # sprintf('%g ', dat_ref(:, 171)); - want = np.array('-3.3265e-05 -3.2419e-05 -3.18758e-05 -3.24079e-05 -3.39801e-05 -3.40573e-05 -3.24163e-05 -3.26896e-05 -3.33814e-05 -3.54734e-05 -3.51289e-05 -3.53229e-05 -3.51532e-05 -3.53149e-05 -3.4505e-05 -3.03462e-05 -2.81848e-05 -3.08895e-05 -3.27158e-05 -3.4605e-05 -3.47728e-05 -3.2459e-05 -3.06552e-05 -2.53255e-05 -2.69671e-05 -2.83425e-05 -3.12836e-05 -3.30965e-05 -3.34099e-05 -3.32766e-05 -3.32256e-05 -3.36385e-05 -3.20796e-05 -2.7108e-05 -2.47054e-05 -2.49589e-05 -2.7382e-05 -3.09774e-05 -3.12003e-05 -3.1246e-05 -3.07572e-05 -2.64942e-05 -2.25505e-05 -2.67194e-05 -2.86e-05 -2.94903e-05 -2.96249e-05 -2.92653e-05 -2.86472e-05 -2.81016e-05 -2.69737e-05 -2.48076e-05 -3.00473e-05 -2.73404e-05 -2.60153e-05 -2.41608e-05 -2.61937e-05 -2.5539e-05 -2.47104e-05 -2.35194e-05'.split(' '), float) # noqa: E501 + data_array = "-3.3265e-05 -3.2419e-05 -3.18758e-05 -3.24079e-05 -3.39801e-05 -3.40573e-05 -3.24163e-05 -3.26896e-05 -3.33814e-05 -3.54734e-05 -3.51289e-05 -3.53229e-05 -3.51532e-05 -3.53149e-05 -3.4505e-05 -3.03462e-05 -2.81848e-05 -3.08895e-05 -3.27158e-05 -3.4605e-05 -3.47728e-05 -3.2459e-05 -3.06552e-05 -2.53255e-05 -2.69671e-05 -2.83425e-05 -3.12836e-05 -3.30965e-05 -3.34099e-05 -3.32766e-05 -3.32256e-05 -3.36385e-05 -3.20796e-05 -2.7108e-05 -2.47054e-05 -2.49589e-05 -2.7382e-05 -3.09774e-05 -3.12003e-05 -3.1246e-05 -3.07572e-05 -2.64942e-05 -2.25505e-05 -2.67194e-05 -2.86e-05 -2.94903e-05 -2.96249e-05 -2.92653e-05 -2.86472e-05 -2.81016e-05 -2.69737e-05 -2.48076e-05 -3.00473e-05 -2.73404e-05 -2.60153e-05 -2.41608e-05 -2.61937e-05 -2.5539e-05 -2.47104e-05 -2.35194e-05" # noqa: E501 + want = np.array(data_array.split(" "), float) norm = np.linalg.norm(want) idx = np.argmin(np.abs(evoked.times - 0.083)) assert idx == 170 diff 
--git a/mne/source_space.py b/mne/source_space.py index 8c7e8899ea1..6eb6c000537 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -31,8 +31,13 @@ complete_surface_info, _compute_nearest, fast_cross_3d, _CheckInside) # keep get_mni_fiducials here just for easy backward compat -from ._freesurfer import (_get_mri_info_data, _get_atlas_values, # noqa: F401 - read_freesurfer_lut, get_mni_fiducials, _check_mri) +from ._freesurfer import ( + _get_mri_info_data, + _get_atlas_values, + read_freesurfer_lut, + get_mni_fiducials, # noqa: F401 + _check_mri, +) from .utils import (get_subjects_dir, check_fname, logger, verbose, fill_doc, _ensure_int, _get_call_line, warn, object_size, sizeof_fmt, _check_fname, _path_like, _check_sphere, _import_nibabel, diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index d1a311bc9ae..d1e35cebc0d 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -987,7 +987,7 @@ def test_io_annotation_txt(dummy_annotation_txt_file, tmp_path_factory, pytest.param(None, None, id='None'), pytest.param(42, 42.0, id='Scalar'), pytest.param(3.14, 3.14, id='Float'), - pytest.param((3, 140000), 3.14, id='Scalar touple'), + pytest.param((3, 140000), 3.14, id="Scalar tuple"), pytest.param('2002-12-03 19:01:11.720100', 1038942071.7201, id='valid iso8601 string'), pytest.param('2002-12-03T19:01:11.720100', None, @@ -1355,7 +1355,7 @@ def test_annotation_ch_names(): assert raw_2.annotations.ch_names[1] == tuple(raw.ch_names[4:5]) for ch_drop in raw_2.annotations.ch_names: assert all(name in raw_2.ch_names for name in ch_drop) - with pytest.raises(ValueError, match='channel name in annotations missin'): + with pytest.raises(ValueError, match='channel name in annotations miss'): raw_2.set_annotations(annot) with pytest.warns(RuntimeWarning, match='channel name in annotations mis'): raw_2.set_annotations(annot, on_missing='warn') diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py index ddfec15686e..7a3f59783bc 100644 --- a/mne/tests/test_docstring_parameters.py +++ b/mne/tests/test_docstring_parameters.py @@ -130,7 +130,10 @@ def check_parameters_match(func, cls=None): msg = str(exc) # E ValueError: no signature found for builtin type # - if inspect.isclass(callable_) and 'no signature found for buil' in msg: + if ( + inspect.isclass(callable_) and + "no signature found for builtin type" in msg + ): pass else: raise diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index ec6da53cf56..02e174556c1 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -1898,5 +1898,5 @@ def test_label_extraction_subject(kind): with pytest.raises(ValueError, match=r'label\.sub.*not match.* stc\.'): extract_label_time_course(stc, labels_fs, src) stc.subject = None - with pytest.raises(ValueError, match=r'label\.sub.*not match.* sourc'): + with pytest.raises(ValueError, match=r"label\.sub.*not match.* sour"): extract_label_time_course(stc, labels_fs, src) diff --git a/mne/tests/test_source_space.py b/mne/tests/test_source_space.py index 364e250284a..83ad939d7ef 100644 --- a/mne/tests/test_source_space.py +++ b/mne/tests/test_source_space.py @@ -479,7 +479,7 @@ def test_setup_source_space(tmp_path): setup_source_space('sample', spacing='7emm', add_dist=False, subjects_dir=subjects_dir) with pytest.raises(ValueError, match='must be a string with values'): - setup_source_space('sample', spacing='alls', + setup_source_space("sample", 
spacing="ally", add_dist=False, subjects_dir=subjects_dir) # ico 5 (fsaverage) - write to temp file diff --git a/mne/transforms.py b/mne/transforms.py index 39ab647f479..1514b2ad2d3 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -1822,7 +1822,7 @@ def apply_volume_registration(moving, static, reg_affine, sdr_morph=None, moving.shape, moving_affine) reg_data = affine_map.transform(moving, interpolation=interpolation) if sdr_morph is not None: - logger.info('Appling SDR warp ...') + logger.info("Applying SDR warp ...") reg_data = sdr_morph.transform( reg_data, interpolation=interpolation, image_world2grid=np.linalg.inv(static_affine), diff --git a/mne/utils/tests/test_check.py b/mne/utils/tests/test_check.py index 44caa61ba10..5763649dd5d 100644 --- a/mne/utils/tests/test_check.py +++ b/mne/utils/tests/test_check.py @@ -204,7 +204,9 @@ def test_suggest(): sug = _suggest('Left-cerebellum', names) assert sug == " Did you mean 'Left-Cerebellum-Cortex'?" sug = _suggest('Cerebellum-Cortex', names) - assert sug == " Did you mean one of ['Left-Cerebellum-Cortex', 'Right-Cerebellum-Cortex', 'Left-Cerebral-Cortex']?" # noqa: E501 + assert ( + sug == " Did you mean one of ['Left-Cerebellum-Cortex', 'Right-Cerebellum-Cortex', 'Left-Cerebral-Cortex']?" # noqa: E501 + ) def test_on_missing(): diff --git a/mne/viz/backends/_notebook.py b/mne/viz/backends/_notebook.py index c239aa9e42c..187c02e23c9 100644 --- a/mne/viz/backends/_notebook.py +++ b/mne/viz/backends/_notebook.py @@ -34,8 +34,14 @@ _AbstractWidgetList, _AbstractAction, _AbstractDialog, _AbstractKeyPress) from ._pyvista import _PyVistaRenderer, Plotter -from ._pyvista import (_close_3d_figure, _check_3d_figure, _close_all, # noqa: F401,E501 analysis:ignore - _set_3d_view, _set_3d_title, _take_3d_screenshot) # noqa: F401,E501 analysis:ignore +from ._pyvista import ( + _close_3d_figure, # noqa: F401 + _check_3d_figure, # noqa: F401 + _close_all, # noqa: F401 + _set_3d_view, # noqa: F401 + _set_3d_title, # noqa: F401 + _take_3d_screenshot, # noqa: F401 +) from ._utils import _notebook_vtk_works diff --git a/mne/viz/backends/_qt.py b/mne/viz/backends/_qt.py index fa8b3b9b9be..d058a505c34 100644 --- a/mne/viz/backends/_qt.py +++ b/mne/viz/backends/_qt.py @@ -32,9 +32,15 @@ QSpinBox, QStyle, QStyleOptionSlider) from ._pyvista import _PyVistaRenderer -from ._pyvista import (_close_3d_figure, _check_3d_figure, _close_all, # noqa: F401,E501 analysis:ignore - _set_3d_view, _set_3d_title, _take_3d_screenshot, # noqa: F401,E501 analysis:ignore - _is_mesa) # noqa: F401,E501 analysis:ignore +from ._pyvista import ( + _close_3d_figure, # noqa: F401 + _check_3d_figure, # noqa: F401 + _close_all, # noqa: F401 + _set_3d_view, # noqa: F401 + _set_3d_title, # noqa: F401 + _take_3d_screenshot, # noqa: F401 + _is_mesa, # noqa: F401 +) from ._abstract import (_AbstractAppWindow, _AbstractHBoxLayout, _AbstractVBoxLayout, _AbstractGridLayout, _AbstractWidget, _AbstractCanvas, diff --git a/pyproject.toml b/pyproject.toml index b8b664ab193..8e7f703106d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,5 @@ [tool.codespell] ignore-words = "ignore_words.txt" -uri-ignore-words-list = "bu" builtin = "clear,rare,informal,names,usage" skip = "doc/references.bib" @@ -48,3 +47,6 @@ addopts = """--durations=20 --doctest-modules -ra --cov-report= --tb=short \ --ignore=mne/report/js_and_css \ --color=yes --capture=sys""" junit_family = "xunit2" + +[tool.black] +exclude = "(dist/)|(build/)|(.*\\.ipynb)" diff --git a/requirements_testing.txt 
b/requirements_testing.txt index fa6c7b86b3f..aad9e7ea206 100644 --- a/requirements_testing.txt +++ b/requirements_testing.txt @@ -11,3 +11,4 @@ tomli; python_version<'3.11' twine wheel pre-commit +black diff --git a/tutorials/forward/50_background_freesurfer_mne.py b/tutorials/forward/50_background_freesurfer_mne.py index a204272b57f..4d67e3e19b3 100644 --- a/tutorials/forward/50_background_freesurfer_mne.py +++ b/tutorials/forward/50_background_freesurfer_mne.py @@ -128,7 +128,7 @@ def imshow_mri(data, img, vox, xyz, suptitle): # Figure out the title based on the code of this axis ori_slice = dict(P='Coronal', A='Coronal', I='Axial', S='Axial', - L='Sagittal', R='Saggital') + L='Sagittal', R='Sagittal') ori_names = dict(P='posterior', A='anterior', I='inferior', S='superior', L='left', R='right') From e81ec528a42ac687f3d961ed5cf8e25f236925b0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 2 May 2023 13:22:53 -0400 Subject: [PATCH 0044/1125] MAINT: Run black on codebase --- doc/conf.py | 1958 ++++++----- doc/sphinxext/flow_diagram.py | 184 +- doc/sphinxext/gen_commands.py | 52 +- doc/sphinxext/gen_names.py | 21 +- doc/sphinxext/gh_substitutions.py | 8 +- doc/sphinxext/mne_substitutions.py | 51 +- doc/sphinxext/newcontrib_substitutions.py | 13 +- doc/sphinxext/unit_role.py | 10 +- examples/datasets/brainstorm_data.py | 34 +- examples/datasets/hf_sef_data.py | 7 +- examples/datasets/limo_data.py | 95 +- examples/datasets/opm_data.py | 95 +- examples/datasets/spm_faces_dataset_sgskip.py | 67 +- examples/decoding/decoding_csp_eeg.py | 57 +- examples/decoding/decoding_csp_timefreq.py | 103 +- examples/decoding/decoding_rsa_sgskip.py | 76 +- .../decoding_spatio_temporal_source.py | 95 +- examples/decoding/decoding_spoc_CMC.py | 27 +- ...decoding_time_generalization_conditions.py | 72 +- .../decoding_unsupervised_spatial_filter.py | 53 +- examples/decoding/decoding_xdawn_eeg.py | 68 +- examples/decoding/ems_filtering.py | 57 +- examples/decoding/linear_model_patterns.py | 33 +- examples/decoding/receptive_field_mtrf.py | 130 +- examples/decoding/ssd_spatial_filters.py | 90 +- examples/forward/forward_sensitivity_maps.py | 70 +- .../forward/left_cerebellum_volume_source.py | 32 +- examples/forward/source_space_morphing.py | 42 +- .../compute_mne_inverse_epochs_in_label.py | 83 +- .../compute_mne_inverse_raw_in_label.py | 24 +- .../inverse/compute_mne_inverse_volume.py | 23 +- examples/inverse/custom_inverse_solver.py | 51 +- examples/inverse/dics_epochs.py | 67 +- examples/inverse/dics_source_power.py | 36 +- examples/inverse/evoked_ers_source_power.py | 116 +- examples/inverse/gamma_map_inverse.py | 75 +- examples/inverse/label_activation_from_stc.py | 42 +- examples/inverse/label_from_stc.py | 65 +- examples/inverse/label_source_activations.py | 76 +- examples/inverse/mixed_norm_inverse.py | 118 +- .../inverse/mixed_source_space_inverse.py | 143 +- examples/inverse/mne_cov_power.py | 95 +- examples/inverse/morph_surface_stc.py | 55 +- examples/inverse/morph_volume_stc.py | 35 +- examples/inverse/multi_dipole_model.py | 81 +- .../inverse/multidict_reweighted_tfmxne.py | 62 +- examples/inverse/psf_ctf_label_leakage.py | 85 +- examples/inverse/psf_ctf_vertices.py | 69 +- examples/inverse/psf_ctf_vertices_lcmv.py | 147 +- examples/inverse/psf_volume.py | 62 +- examples/inverse/rap_music.py | 30 +- examples/inverse/read_inverse.py | 37 +- examples/inverse/read_stc.py | 13 +- examples/inverse/resolution_metrics.py | 127 +- examples/inverse/resolution_metrics_eegmeg.py | 137 +- 
examples/inverse/snr_estimate.py | 6 +- examples/inverse/source_space_snr.py | 38 +- .../time_frequency_mixed_norm_inverse.py | 120 +- examples/inverse/vector_mne_solution.py | 60 +- examples/io/elekta_epochs.py | 34 +- examples/io/read_neo_format.py | 10 +- examples/io/read_noise_covariance_matrix.py | 6 +- examples/io/read_xdf.py | 6 +- .../contralateral_referencing.py | 46 +- examples/preprocessing/css.py | 67 +- .../preprocessing/define_target_events.py | 49 +- examples/preprocessing/eeg_bridging.py | 191 +- examples/preprocessing/eeg_csd.py | 39 +- .../preprocessing/eog_artifact_histogram.py | 15 +- examples/preprocessing/eog_regression.py | 31 +- examples/preprocessing/find_ref_artifacts.py | 24 +- .../preprocessing/fnirs_artifact_removal.py | 21 +- examples/preprocessing/ica_comparison.py | 27 +- .../preprocessing/interpolate_bad_channels.py | 16 +- .../preprocessing/movement_compensation.py | 26 +- examples/preprocessing/movement_detection.py | 48 +- examples/preprocessing/muscle_detection.py | 14 +- examples/preprocessing/muscle_ica.py | 31 +- examples/preprocessing/otp.py | 43 +- examples/preprocessing/shift_evoked.py | 38 +- examples/preprocessing/virtual_evoked.py | 26 +- examples/preprocessing/xdawn_denoising.py | 34 +- examples/simulation/plot_stc_metrics.py | 165 +- examples/simulation/simulate_evoked_data.py | 58 +- examples/simulation/simulate_raw_data.py | 44 +- ...imulated_raw_data_using_subject_anatomy.py | 104 +- examples/simulation/source_simulator.py | 25 +- examples/stats/cluster_stats_evoked.py | 52 +- examples/stats/fdr_stats_evoked.py | 55 +- examples/stats/linear_regression_raw.py | 38 +- examples/stats/sensor_permutation_test.py | 45 +- examples/stats/sensor_regression.py | 14 +- examples/time_frequency/compute_csd.py | 32 +- .../compute_source_psd_epochs.py | 76 +- .../source_label_time_frequency.py | 91 +- .../time_frequency/source_power_spectrum.py | 43 +- .../source_power_spectrum_opm.py | 162 +- .../source_space_time_frequency.py | 47 +- examples/time_frequency/temporal_whitening.py | 26 +- .../time_frequency/time_frequency_erds.py | 107 +- .../time_frequency_global_field_power.py | 73 +- .../time_frequency_simulated.py | 170 +- examples/visualization/3d_to_2d.py | 27 +- examples/visualization/brain.py | 47 +- .../visualization/channel_epochs_image.py | 53 +- examples/visualization/eeg_on_scalp.py | 21 +- examples/visualization/evoked_arrowmap.py | 29 +- examples/visualization/evoked_topomap.py | 67 +- examples/visualization/evoked_whitening.py | 47 +- examples/visualization/meg_sensors.py | 57 +- examples/visualization/mne_helmet.py | 49 +- examples/visualization/montage_sgskip.py | 29 +- examples/visualization/parcellation.py | 55 +- examples/visualization/publication_figure.py | 130 +- examples/visualization/roi_erpimage_by_rt.py | 76 +- examples/visualization/sensor_noise_level.py | 7 +- .../ssp_projs_sensitivity_map.py | 20 +- .../visualization/topo_compare_conditions.py | 22 +- examples/visualization/topo_customized.py | 31 +- examples/visualization/xhemi.py | 29 +- logo/generate_mne_logos.py | 173 +- mne/__init__.py | 261 +- mne/__main__.py | 2 +- mne/_freesurfer.py | 375 +- mne/_ola.py | 286 +- mne/annotations.py | 623 ++-- mne/baseline.py | 89 +- mne/beamformer/__init__.py | 18 +- mne/beamformer/_compute_beamformer.py | 352 +- mne/beamformer/_dics.py | 268 +- mne/beamformer/_lcmv.py | 208 +- mne/beamformer/_rap_music.py | 59 +- mne/beamformer/resolution_matrix.py | 15 +- mne/beamformer/tests/test_dics.py | 765 ++-- 
mne/beamformer/tests/test_external.py | 72 +- mne/beamformer/tests/test_lcmv.py | 977 ++++-- mne/beamformer/tests/test_rap_music.py | 153 +- .../tests/test_resolution_matrix.py | 43 +- mne/bem.py | 1586 +++++---- mne/channels/__init__.py | 100 +- mne/channels/_dig_montage_utils.py | 66 +- mne/channels/_standard_montage_utils.py | 227 +- mne/channels/channels.py | 1373 +++++--- mne/channels/interpolation.py | 105 +- mne/channels/layout.py | 438 ++- mne/channels/montage.py | 764 ++-- mne/channels/tests/test_channels.py | 441 +-- mne/channels/tests/test_interpolation.py | 211 +- mne/channels/tests/test_layout.py | 233 +- mne/channels/tests/test_montage.py | 1793 ++++++---- mne/channels/tests/test_standard_montage.py | 266 +- mne/chpi.py | 973 +++--- mne/commands/mne_anonymize.py | 67 +- mne/commands/mne_browse_raw.py | 189 +- mne/commands/mne_bti2fiff.py | 97 +- mne/commands/mne_clean_eog_ecg.py | 183 +- mne/commands/mne_compare_fiff.py | 3 +- mne/commands/mne_compute_proj_ecg.py | 331 +- mne/commands/mne_compute_proj_eog.py | 306 +- mne/commands/mne_coreg.py | 142 +- mne/commands/mne_flash_bem.py | 161 +- mne/commands/mne_freeview_bem_surfaces.py | 70 +- mne/commands/mne_kit2fiff.py | 83 +- mne/commands/mne_make_scalp_surfaces.py | 72 +- mne/commands/mne_maxfilter.py | 254 +- mne/commands/mne_prepare_bem_model.py | 36 +- mne/commands/mne_report.py | 114 +- mne/commands/mne_setup_forward_model.py | 135 +- mne/commands/mne_setup_source_space.py | 175 +- mne/commands/mne_show_fiff.py | 12 +- mne/commands/mne_show_info.py | 7 +- mne/commands/mne_surf2bem.py | 22 +- mne/commands/mne_sys_info.py | 36 +- mne/commands/mne_watershed_bem.py | 117 +- mne/commands/mne_what.py | 3 +- mne/commands/tests/test_commands.py | 400 ++- mne/commands/utils.py | 35 +- mne/conftest.py | 581 +-- mne/coreg.py | 1201 ++++--- mne/cov.py | 1337 ++++--- mne/cuda.py | 150 +- mne/datasets/__init__.py | 48 +- mne/datasets/_fake/_fake.py | 25 +- mne/datasets/_fetch.py | 40 +- mne/datasets/_fsaverage/base.py | 22 +- mne/datasets/_infant/base.py | 25 +- mne/datasets/_phantom/base.py | 18 +- mne/datasets/brainstorm/__init__.py | 3 +- mne/datasets/brainstorm/bst_auditory.py | 42 +- mne/datasets/brainstorm/bst_phantom_ctf.py | 42 +- mne/datasets/brainstorm/bst_phantom_elekta.py | 43 +- mne/datasets/brainstorm/bst_raw.py | 54 +- mne/datasets/brainstorm/bst_resting.py | 42 +- mne/datasets/config.py | 355 +- mne/datasets/eegbci/eegbci.py | 62 +- mne/datasets/eegbci/tests/test_eegbci.py | 3 +- mne/datasets/epilepsy_ecog/_data.py | 25 +- mne/datasets/erp_core/erp_core.py | 27 +- mne/datasets/eyelink/eyelink.py | 27 +- mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py | 25 +- mne/datasets/fnirs_motor/fnirs_motor.py | 27 +- mne/datasets/hf_sef/hf_sef.py | 48 +- mne/datasets/kiloword/kiloword.py | 21 +- mne/datasets/limo/limo.py | 287 +- mne/datasets/misc/_misc.py | 22 +- mne/datasets/mtrf/mtrf.py | 23 +- mne/datasets/multimodal/multimodal.py | 27 +- mne/datasets/opm/opm.py | 25 +- mne/datasets/phantom_4dbti/phantom_4dbti.py | 25 +- mne/datasets/refmeg_noise/refmeg_noise.py | 25 +- mne/datasets/sample/sample.py | 27 +- mne/datasets/sleep_physionet/_utils.py | 168 +- mne/datasets/sleep_physionet/age.py | 71 +- mne/datasets/sleep_physionet/temazepam.py | 51 +- .../sleep_physionet/tests/test_physionet.py | 196 +- mne/datasets/somato/somato.py | 27 +- mne/datasets/spm_face/spm_data.py | 40 +- mne/datasets/ssvep/ssvep.py | 25 +- mne/datasets/testing/__init__.py | 9 +- mne/datasets/testing/_testing.py | 52 +- mne/datasets/tests/test_datasets.py | 
255 +- .../ucl_opm_auditory/ucl_opm_auditory.py | 24 +- mne/datasets/utils.py | 623 ++-- .../visual_92_categories.py | 25 +- mne/decoding/__init__.py | 11 +- mne/decoding/base.py | 169 +- mne/decoding/csp.py | 349 +- mne/decoding/ems.py | 49 +- mne/decoding/mixin.py | 21 +- mne/decoding/receptive_field.py | 168 +- mne/decoding/search_light.py | 131 +- mne/decoding/ssd.py | 188 +- mne/decoding/tests/test_base.py | 175 +- mne/decoding/tests/test_csp.py | 209 +- mne/decoding/tests/test_ems.py | 39 +- mne/decoding/tests/test_receptive_field.py | 399 ++- mne/decoding/tests/test_search_light.py | 107 +- mne/decoding/tests/test_ssd.py | 335 +- mne/decoding/tests/test_time_frequency.py | 5 +- mne/decoding/tests/test_transformer.py | 184 +- mne/decoding/time_delaying_ridge.py | 136 +- mne/decoding/time_frequency.py | 37 +- mne/decoding/transformer.py | 261 +- mne/defaults.py | 394 ++- mne/dipole.py | 1022 ++++-- mne/epochs.py | 2158 ++++++++---- mne/event.py | 632 ++-- mne/evoked.py | 1103 ++++-- mne/export/_brainvision.py | 1 + mne/export/_edf.py | 192 +- mne/export/_eeglab.py | 55 +- mne/export/_egimff.py | 99 +- mne/export/_export.py | 89 +- mne/export/tests/test_export.py | 376 +- mne/filter.py | 1715 ++++++--- mne/fixes.py | 282 +- mne/forward/__init__.py | 64 +- mne/forward/_compute_forward.py | 293 +- mne/forward/_field_interpolation.py | 347 +- mne/forward/_lead_dots.py | 247 +- mne/forward/_make_forward.py | 661 ++-- mne/forward/forward.py | 1396 ++++---- mne/forward/tests/test_field_interpolation.py | 253 +- mne/forward/tests/test_forward.py | 393 ++- mne/forward/tests/test_make_forward.py | 711 ++-- mne/gui/__init__.py | 199 +- mne/gui/_core.py | 398 ++- mne/gui/_coreg.py | 744 ++-- mne/gui/_ieeg_locate.py | 501 ++- mne/gui/tests/test_core.py | 37 +- mne/gui/tests/test_coreg.py | 198 +- mne/gui/tests/test_gui_api.py | 254 +- mne/gui/tests/test_ieeg_locate.py | 186 +- mne/html_templates/_templates.py | 19 +- mne/inverse_sparse/__init__.py | 3 +- mne/inverse_sparse/_gamma_map.py | 133 +- mne/inverse_sparse/mxne_debiasing.py | 18 +- mne/inverse_sparse/mxne_inverse.py | 592 +++- mne/inverse_sparse/mxne_optim.py | 772 ++-- mne/inverse_sparse/tests/test_gamma_map.py | 171 +- mne/inverse_sparse/tests/test_mxne_inverse.py | 516 ++- mne/inverse_sparse/tests/test_mxne_optim.py | 366 +- mne/io/__init__.py | 21 +- mne/io/_digitization.py | 372 +- mne/io/_read_raw.py | 51 +- mne/io/array/array.py | 67 +- mne/io/array/tests/test_array.py | 107 +- mne/io/artemis123/artemis123.py | 455 +-- mne/io/artemis123/tests/test_artemis123.py | 88 +- mne/io/artemis123/utils.py | 70 +- mne/io/base.py | 1522 +++++--- mne/io/besa/besa.py | 190 +- mne/io/besa/tests/test_besa.py | 60 +- mne/io/boxy/boxy.py | 182 +- mne/io/boxy/tests/test_boxy.py | 138 +- mne/io/brainvision/brainvision.py | 619 ++-- mne/io/brainvision/tests/test_brainvision.py | 744 ++-- mne/io/bti/bti.py | 1466 ++++---- mne/io/bti/constants.py | 134 +- mne/io/bti/read.py | 39 +- mne/io/bti/tests/test_bti.py | 411 ++- mne/io/cnt/_utils.py | 70 +- mne/io/cnt/cnt.py | 357 +- mne/io/cnt/tests/test_cnt.py | 29 +- mne/io/compensator.py | 69 +- mne/io/constants.py | 1628 +++++---- mne/io/ctf/constants.py | 2 +- mne/io/ctf/ctf.py | 208 +- mne/io/ctf/eeg.py | 73 +- mne/io/ctf/hc.py | 58 +- mne/io/ctf/info.py | 518 +-- mne/io/ctf/markers.py | 67 +- mne/io/ctf/res4.py | 231 +- mne/io/ctf/tests/test_ctf.py | 667 ++-- mne/io/ctf/trans.py | 108 +- mne/io/ctf_comp.py | 92 +- mne/io/curry/curry.py | 387 +- mne/io/curry/tests/test_curry.py | 421 ++- mne/io/diff.py | 6 
+- mne/io/edf/edf.py | 1139 +++--- mne/io/edf/tests/test_edf.py | 655 ++-- mne/io/edf/tests/test_gdf.py | 91 +- mne/io/eeglab/_eeglab.py | 8 +- mne/io/eeglab/eeglab.py | 398 ++- mne/io/eeglab/tests/test_eeglab.py | 600 ++-- mne/io/egi/egi.py | 279 +- mne/io/egi/egimff.py | 697 ++-- mne/io/egi/events.py | 57 +- mne/io/egi/general.py | 126 +- mne/io/egi/tests/test_egi.py | 429 ++- mne/io/eximia/eximia.py | 57 +- mne/io/eximia/tests/test_eximia.py | 38 +- mne/io/eyelink/eyelink.py | 669 ++-- mne/io/eyelink/tests/test_eyelink.py | 116 +- mne/io/fieldtrip/__init__.py | 3 +- mne/io/fieldtrip/fieldtrip.py | 53 +- mne/io/fieldtrip/tests/helpers.py | 166 +- mne/io/fieldtrip/tests/test_fieldtrip.py | 185 +- mne/io/fieldtrip/utils.py | 299 +- mne/io/fiff/raw.py | 274 +- mne/io/fiff/tests/test_raw_fiff.py | 1166 ++++--- mne/io/fil/__init__.py | 2 +- mne/io/fil/fil.py | 226 +- mne/io/fil/sensors.py | 5 +- mne/io/fil/tests/test_fil.py | 37 +- mne/io/hitachi/hitachi.py | 221 +- mne/io/hitachi/tests/test_hitachi.py | 288 +- mne/io/kit/constants.py | 98 +- mne/io/kit/coreg.py | 128 +- mne/io/kit/kit.py | 624 ++-- mne/io/kit/tests/test_coreg.py | 4 +- mne/io/kit/tests/test_kit.py | 357 +- mne/io/matrix.py | 95 +- mne/io/meas_info.py | 2101 ++++++----- mne/io/nedf/nedf.py | 108 +- mne/io/nedf/tests/test_nedf.py | 92 +- mne/io/nicolet/nicolet.py | 134 +- mne/io/nicolet/tests/test_nicolet.py | 18 +- mne/io/nihon/nihon.py | 367 +- mne/io/nihon/tests/test_nihon.py | 55 +- mne/io/nirx/_localized_abbr.py | 108 +- mne/io/nirx/nirx.py | 406 ++- mne/io/nirx/tests/test_nirx.py | 641 ++-- mne/io/open.py | 159 +- mne/io/persyst/persyst.py | 207 +- mne/io/persyst/tests/test_persyst.py | 97 +- mne/io/pick.py | 1102 +++--- mne/io/proc_history.py | 246 +- mne/io/proj.py | 608 ++-- mne/io/reference.py | 377 +- mne/io/snirf/_snirf.py | 470 +-- mne/io/snirf/tests/test_snirf.py | 369 +- mne/io/tag.py | 223 +- mne/io/tests/__init__.py | 2 +- mne/io/tests/test_apply_function.py | 19 +- mne/io/tests/test_compensator.py | 36 +- mne/io/tests/test_constants.py | 384 +- mne/io/tests/test_meas_info.py | 926 ++--- mne/io/tests/test_pick.py | 602 ++-- mne/io/tests/test_proc_history.py | 36 +- mne/io/tests/test_raw.py | 583 ++-- mne/io/tests/test_read_raw.py | 61 +- mne/io/tests/test_reference.py | 603 ++-- mne/io/tests/test_show_fiff.py | 21 +- mne/io/tests/test_utils.py | 13 +- mne/io/tests/test_what.py | 37 +- mne/io/tests/test_write.py | 8 +- mne/io/tree.py | 82 +- mne/io/utils.py | 102 +- mne/io/what.py | 43 +- mne/io/write.py | 318 +- mne/label.py | 1169 ++++--- mne/minimum_norm/__init__.py | 36 +- mne/minimum_norm/_eloreta.py | 106 +- mne/minimum_norm/inverse.py | 1456 +++++--- mne/minimum_norm/resolution_matrix.py | 213 +- mne/minimum_norm/spatial_resolution.py | 119 +- mne/minimum_norm/tests/test_inverse.py | 1405 +++++--- .../tests/test_resolution_matrix.py | 207 +- .../tests/test_resolution_metrics.py | 135 +- mne/minimum_norm/tests/test_snr.py | 24 +- mne/minimum_norm/tests/test_time_frequency.py | 277 +- mne/minimum_norm/time_frequency.py | 558 ++- mne/misc.py | 33 +- mne/morph.py | 845 +++-- mne/morph_map.py | 96 +- mne/parallel.py | 64 +- mne/preprocessing/__init__.py | 34 +- mne/preprocessing/_csd.py | 151 +- mne/preprocessing/_css.py | 18 +- mne/preprocessing/_fine_cal.py | 296 +- mne/preprocessing/_peak_finder.py | 21 +- mne/preprocessing/_regress.py | 192 +- mne/preprocessing/annotate_amplitude.py | 104 +- mne/preprocessing/annotate_nan.py | 7 +- mne/preprocessing/artifact_detection.py | 274 +- mne/preprocessing/bads.py 
| 1 + mne/preprocessing/ctps_.py | 18 +- mne/preprocessing/ecg.py | 295 +- mne/preprocessing/eog.py | 214 +- mne/preprocessing/eyetracking/eyetracking.py | 115 +- mne/preprocessing/hfc.py | 30 +- mne/preprocessing/ica.py | 1807 ++++++---- mne/preprocessing/ieeg/_projection.py | 127 +- mne/preprocessing/ieeg/_volume.py | 161 +- .../ieeg/tests/test_projection.py | 147 +- mne/preprocessing/ieeg/tests/test_volume.py | 104 +- mne/preprocessing/infomax_.py | 92 +- mne/preprocessing/interpolate.py | 70 +- mne/preprocessing/maxfilter.py | 124 +- mne/preprocessing/maxwell.py | 1781 ++++++---- mne/preprocessing/nirs/__init__.py | 17 +- mne/preprocessing/nirs/_beer_lambert_law.py | 64 +- mne/preprocessing/nirs/_optical_density.py | 6 +- .../nirs/_scalp_coupling_index.py | 26 +- mne/preprocessing/nirs/_tddr.py | 13 +- mne/preprocessing/nirs/nirs.py | 182 +- .../nirs/tests/test_beer_lambert_law.py | 75 +- mne/preprocessing/nirs/tests/test_nirs.py | 341 +- .../nirs/tests/test_optical_density.py | 25 +- .../nirs/tests/test_scalp_coupling_index.py | 33 +- ...temporal_derivative_distribution_repair.py | 18 +- mne/preprocessing/otp.py | 34 +- mne/preprocessing/realign.py | 49 +- mne/preprocessing/ssp.py | 431 ++- mne/preprocessing/stim.py | 56 +- .../tests/test_annotate_amplitude.py | 288 +- mne/preprocessing/tests/test_annotate_nan.py | 14 +- .../tests/test_artifact_detection.py | 138 +- mne/preprocessing/tests/test_csd.py | 131 +- mne/preprocessing/tests/test_css.py | 31 +- mne/preprocessing/tests/test_ctps.py | 54 +- mne/preprocessing/tests/test_ecg.py | 78 +- .../tests/test_eeglab_infomax.py | 67 +- mne/preprocessing/tests/test_eog.py | 6 +- mne/preprocessing/tests/test_fine_cal.py | 71 +- mne/preprocessing/tests/test_hfc.py | 78 +- mne/preprocessing/tests/test_ica.py | 1084 +++--- mne/preprocessing/tests/test_infomax.py | 17 +- mne/preprocessing/tests/test_interpolate.py | 103 +- mne/preprocessing/tests/test_maxwell.py | 1551 +++++---- mne/preprocessing/tests/test_otp.py | 59 +- mne/preprocessing/tests/test_peak_finder.py | 12 +- mne/preprocessing/tests/test_realign.py | 77 +- mne/preprocessing/tests/test_regress.py | 93 +- mne/preprocessing/tests/test_ssp.py | 250 +- mne/preprocessing/tests/test_stim.py | 82 +- mne/preprocessing/tests/test_xdawn.py | 189 +- mne/preprocessing/xdawn.py | 189 +- mne/proj.py | 264 +- mne/rank.py | 294 +- .../bootstrap-icons/gen_css_for_mne.py | 40 +- mne/report/report.py | 2308 +++++++----- mne/report/tests/test_report.py | 765 ++-- mne/simulation/_metrics.py | 4 +- mne/simulation/evoked.py | 64 +- mne/simulation/metrics/__init__.py | 22 +- mne/simulation/metrics/metrics.py | 97 +- mne/simulation/metrics/tests/test_metrics.py | 185 +- mne/simulation/raw.py | 498 +-- mne/simulation/source.py | 206 +- mne/simulation/tests/test_evoked.py | 141 +- mne/simulation/tests/test_metrics.py | 24 +- mne/simulation/tests/test_raw.py | 373 +- mne/simulation/tests/test_source.py | 319 +- mne/source_estimate.py | 1943 +++++++---- mne/source_space.py | 2008 ++++++----- mne/stats/__init__.py | 24 +- mne/stats/_adjacency.py | 38 +- mne/stats/cluster_level.py | 661 ++-- mne/stats/multi_comp.py | 10 +- mne/stats/parametric.py | 90 +- mne/stats/permutations.py | 42 +- mne/stats/regression.py | 183 +- mne/stats/tests/test_adjacency.py | 28 +- mne/stats/tests/test_cluster_level.py | 645 ++-- mne/stats/tests/test_multi_comp.py | 9 +- mne/stats/tests/test_parametric.py | 144 +- mne/stats/tests/test_permutations.py | 40 +- mne/stats/tests/test_regression.py | 82 +- mne/surface.py | 1111 +++--- 
mne/tests/test_annotations.py | 1131 +++--- mne/tests/test_bem.py | 566 +-- mne/tests/test_chpi.py | 573 +-- mne/tests/test_coreg.py | 378 +- mne/tests/test_cov.py | 749 ++-- mne/tests/test_defaults.py | 33 +- mne/tests/test_dipole.py | 361 +- mne/tests/test_docstring_parameters.py | 267 +- mne/tests/test_epochs.py | 3101 ++++++++++------- mne/tests/test_event.py | 537 +-- mne/tests/test_evoked.py | 421 +-- mne/tests/test_filter.py | 894 +++-- mne/tests/test_freesurfer.py | 186 +- mne/tests/test_import_nesting.py | 5 +- mne/tests/test_label.py | 884 +++-- mne/tests/test_line_endings.py | 77 +- mne/tests/test_morph.py | 855 +++-- mne/tests/test_morph_map.py | 27 +- mne/tests/test_ola.py | 63 +- mne/tests/test_parallel.py | 20 +- mne/tests/test_proj.py | 331 +- mne/tests/test_rank.py | 237 +- mne/tests/test_read_vectorview_selection.py | 43 +- mne/tests/test_source_estimate.py | 1303 +++---- mne/tests/test_source_space.py | 913 ++--- mne/tests/test_surface.py | 436 ++- mne/tests/test_transforms.py | 370 +- mne/time_frequency/__init__.py | 33 +- mne/time_frequency/_stft.py | 73 +- mne/time_frequency/_stockwell.py | 107 +- mne/time_frequency/ar.py | 8 +- mne/time_frequency/csd.py | 584 +++- mne/time_frequency/multitaper.py | 157 +- mne/time_frequency/psd.py | 89 +- mne/time_frequency/spectrum.py | 791 +++-- mne/time_frequency/tests/test_ar.py | 21 +- mne/time_frequency/tests/test_csd.py | 406 ++- mne/time_frequency/tests/test_multitaper.py | 20 +- mne/time_frequency/tests/test_psd.py | 103 +- mne/time_frequency/tests/test_spectrum.py | 239 +- mne/time_frequency/tests/test_stft.py | 29 +- mne/time_frequency/tests/test_stockwell.py | 82 +- mne/time_frequency/tests/test_tfr.py | 1083 +++--- mne/time_frequency/tfr.py | 1514 +++++--- mne/transforms.py | 937 +++-- mne/utils/__init__.py | 283 +- mne/utils/_bunch.py | 9 +- mne/utils/_logging.py | 142 +- mne/utils/_testing.py | 198 +- mne/utils/check.py | 660 ++-- mne/utils/config.py | 485 +-- mne/utils/dataframe.py | 64 +- mne/utils/docs.py | 2059 +++++++---- mne/utils/fetching.py | 5 +- mne/utils/linalg.py | 62 +- mne/utils/misc.py | 124 +- mne/utils/mixin.py | 223 +- mne/utils/numerics.py | 326 +- mne/utils/progressbar.py | 64 +- mne/utils/spectrum.py | 36 +- mne/utils/tests/test_bunch.py | 4 +- mne/utils/tests/test_check.py | 254 +- mne/utils/tests/test_config.py | 91 +- mne/utils/tests/test_docs.py | 117 +- mne/utils/tests/test_linalg.py | 40 +- mne/utils/tests/test_logging.py | 134 +- mne/utils/tests/test_misc.py | 88 +- mne/utils/tests/test_numerics.py | 394 ++- mne/utils/tests/test_progressbar.py | 71 +- mne/utils/tests/test_testing.py | 29 +- mne/viz/_3d.py | 3016 ++++++++++------ mne/viz/_3d_overlay.py | 24 +- mne/viz/__init__.py | 111 +- mne/viz/_brain/__init__.py | 2 +- mne/viz/_brain/_brain.py | 2057 ++++++----- mne/viz/_brain/_linkviewer.py | 18 +- mne/viz/_brain/_scraper.py | 64 +- mne/viz/_brain/callback.py | 16 +- mne/viz/_brain/colormap.py | 94 +- mne/viz/_brain/surface.py | 69 +- mne/viz/_brain/tests/test_brain.py | 860 +++-- mne/viz/_brain/tests/test_notebook.py | 75 +- mne/viz/_brain/view.py | 57 +- mne/viz/_dipole.py | 170 +- mne/viz/_figure.py | 350 +- mne/viz/_mpl_figure.py | 1348 ++++--- mne/viz/_proj.py | 152 +- mne/viz/_scraper.py | 31 +- mne/viz/backends/_abstract.py | 377 +- mne/viz/backends/_notebook.py | 684 ++-- mne/viz/backends/_pyvista.py | 699 ++-- mne/viz/backends/_qt.py | 573 +-- mne/viz/backends/_utils.py | 166 +- mne/viz/backends/renderer.py | 88 +- mne/viz/backends/tests/_utils.py | 4 +- 
mne/viz/backends/tests/test_abstract.py | 68 +- mne/viz/backends/tests/test_renderer.py | 145 +- mne/viz/backends/tests/test_utils.py | 59 +- mne/viz/circle.py | 195 +- mne/viz/conftest.py | 20 +- mne/viz/epochs.py | 684 ++-- mne/viz/evoked.py | 1975 +++++++---- mne/viz/ica.py | 885 +++-- mne/viz/misc.py | 866 +++-- mne/viz/montage.py | 35 +- mne/viz/raw.py | 364 +- mne/viz/tests/test_3d.py | 1058 +++--- mne/viz/tests/test_3d_mpl.py | 131 +- mne/viz/tests/test_circle.py | 25 +- mne/viz/tests/test_epochs.py | 308 +- mne/viz/tests/test_evoked.py | 507 +-- mne/viz/tests/test_figure.py | 4 +- mne/viz/tests/test_ica.py | 295 +- mne/viz/tests/test_misc.py | 295 +- mne/viz/tests/test_montage.py | 63 +- mne/viz/tests/test_proj.py | 44 +- mne/viz/tests/test_raw.py | 630 ++-- mne/viz/tests/test_scraper.py | 17 +- mne/viz/tests/test_topo.py | 311 +- mne/viz/tests/test_topomap.py | 601 ++-- mne/viz/tests/test_utils.py | 93 +- mne/viz/topo.py | 843 +++-- mne/viz/topomap.py | 2563 +++++++++----- mne/viz/utils.py | 1448 +++++--- setup.py | 200 +- tools/check_mne_location.py | 5 +- tools/generate_codemeta.py | 136 +- tutorials/clinical/20_seeg.py | 126 +- tutorials/clinical/30_ecog.py | 114 +- tutorials/clinical/60_sleep.py | 132 +- tutorials/epochs/10_epochs_overview.py | 86 +- tutorials/epochs/15_baseline_regression.py | 112 +- tutorials/epochs/20_visualize_epochs.py | 80 +- tutorials/epochs/30_epochs_metadata.py | 35 +- tutorials/epochs/40_autogenerate_metadata.py | 232 +- tutorials/epochs/50_epochs_to_data_frame.py | 91 +- .../epochs/60_make_fixed_length_epochs.py | 15 +- tutorials/evoked/10_evoked_overview.py | 55 +- tutorials/evoked/20_visualize_evoked.py | 97 +- tutorials/evoked/30_eeg_erp.py | 150 +- tutorials/evoked/40_whitened.py | 31 +- tutorials/forward/10_background_freesurfer.py | 9 +- tutorials/forward/20_source_alignment.py | 173 +- tutorials/forward/25_automated_coreg.py | 29 +- tutorials/forward/30_forward.py | 101 +- tutorials/forward/35_eeg_no_mri.py | 65 +- .../forward/50_background_freesurfer_mne.py | 268 +- tutorials/forward/80_fix_bem_in_blender.py | 38 +- tutorials/forward/90_compute_covariance.py | 50 +- tutorials/intro/10_overview.py | 120 +- tutorials/intro/15_inplace.py | 16 +- tutorials/intro/20_events_from_raw.py | 40 +- tutorials/intro/30_info.py | 22 +- tutorials/intro/40_sensor_locations.py | 52 +- tutorials/intro/50_configure_mne.py | 50 +- tutorials/intro/70_report.py | 289 +- tutorials/inverse/10_stc_class.py | 45 +- tutorials/inverse/20_dipole_fit.py | 93 +- tutorials/inverse/30_mne_dspm_loreta.py | 86 +- tutorials/inverse/35_dipole_orientations.py | 148 +- tutorials/inverse/40_mne_fixed_free.py | 69 +- tutorials/inverse/50_beamformer_lcmv.py | 118 +- tutorials/inverse/60_visualize_stc.py | 122 +- tutorials/inverse/70_eeg_mri_coords.py | 49 +- .../inverse/80_brainstorm_phantom_elekta.py | 92 +- .../inverse/85_brainstorm_phantom_ctf.py | 46 +- tutorials/inverse/90_phantom_4DBTi.py | 39 +- tutorials/io/30_reading_fnirs_data.py | 59 +- tutorials/io/60_ctf_bst_auditory.py | 195 +- tutorials/io/70_reading_eyetracking_data.py | 10 +- tutorials/machine-learning/30_strf.py | 135 +- tutorials/machine-learning/50_decoding.py | 147 +- .../10_preprocessing_overview.py | 28 +- .../preprocessing/15_handling_bad_channels.py | 39 +- .../preprocessing/20_rejecting_bad_data.py | 92 +- .../preprocessing/25_background_filtering.py | 338 +- .../preprocessing/30_filtering_resampling.py | 69 +- .../35_artifact_correction_regression.py | 50 +- .../40_artifact_correction_ica.py | 90 +- 
.../preprocessing/45_projectors_background.py | 96 +- .../50_artifact_correction_ssp.py | 116 +- .../preprocessing/55_setting_eeg_reference.py | 38 +- tutorials/preprocessing/59_head_positions.py | 12 +- .../preprocessing/60_maxwell_filtering_sss.py | 105 +- .../preprocessing/70_fnirs_processing.py | 191 +- tutorials/preprocessing/80_opm_processing.py | 74 +- .../preprocessing/90_eyetracking_data.py | 27 +- tutorials/raw/10_raw_overview.py | 64 +- tutorials/raw/20_event_arrays.py | 42 +- tutorials/raw/30_annotate_raw.py | 51 +- tutorials/raw/40_visualize_raw.py | 11 +- tutorials/simulation/10_array_objs.py | 89 +- tutorials/simulation/70_point_spread.py | 84 +- tutorials/simulation/80_dics.py | 146 +- .../stats-sensor-space/10_background_stats.py | 156 +- tutorials/stats-sensor-space/20_erp_stats.py | 55 +- .../40_cluster_1samp_time_freq.py | 124 +- .../50_cluster_between_time_freq.py | 129 +- .../70_cluster_rmANOVA_time_freq.py | 146 +- .../75_cluster_ftest_spatiotemporal.py | 167 +- .../20_cluster_1samp_spatiotemporal.py | 101 +- .../30_cluster_ftest_spatiotemporal.py | 66 +- .../60_cluster_rmANOVA_spatiotemporal.py | 137 +- tutorials/time-freq/10_spectrum_class.py | 33 +- .../time-freq/20_sensors_time_frequency.py | 122 +- tutorials/time-freq/50_ssvep.py | 383 +- 707 files changed, 107021 insertions(+), 70350 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index e0748d37053..8f904e45d85 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -22,17 +22,21 @@ import mne from mne.fixes import _compare_version from mne.tests.test_docstring_parameters import error_ignores -from mne.utils import (linkcode_resolve, # noqa, analysis:ignore - _assert_no_instances, sizeof_fmt, run_subprocess) +from mne.utils import ( + linkcode_resolve, # noqa, analysis:ignore + _assert_no_instances, + sizeof_fmt, + run_subprocess, +) from mne.viz import Brain # noqa -matplotlib.use('agg') +matplotlib.use("agg") faulthandler.enable() -os.environ['_MNE_BROWSER_NO_BLOCK'] = 'true' -os.environ['MNE_BROWSER_OVERVIEW_MODE'] = 'hidden' -os.environ['MNE_BROWSER_THEME'] = 'light' -os.environ['MNE_3D_OPTION_THEME'] = 'light' -sphinx_logger = sphinx.util.logging.getLogger('mne') +os.environ["_MNE_BROWSER_NO_BLOCK"] = "true" +os.environ["MNE_BROWSER_OVERVIEW_MODE"] = "hidden" +os.environ["MNE_BROWSER_THEME"] = "light" +os.environ["MNE_3D_OPTION_THEME"] = "light" +sphinx_logger = sphinx.util.logging.getLogger("mne") # -- Path setup -------------------------------------------------------------- @@ -40,22 +44,23 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. curdir = os.path.dirname(__file__) -sys.path.append(os.path.abspath(os.path.join(curdir, '..', 'mne'))) -sys.path.append(os.path.abspath(os.path.join(curdir, 'sphinxext'))) +sys.path.append(os.path.abspath(os.path.join(curdir, "..", "mne"))) +sys.path.append(os.path.abspath(os.path.join(curdir, "sphinxext"))) # -- Project information ----------------------------------------------------- -project = 'MNE' +project = "MNE" td = datetime.now(tz=timezone.utc) # We need to triage which date type we use so that incremental builds work # (Sphinx looks at variable changes and rewrites all files if some change) copyright = ( f'2012–{td.year}, MNE Developers. Last updated \n' # noqa: E501 - '') # noqa: E501 -if os.getenv('MNE_FULL_DATE', 'false').lower() != 'true': - copyright = f'2012–{td.year}, MNE Developers. Last updated locally.' 
+ '' +) # noqa: E501 +if os.getenv("MNE_FULL_DATE", "false").lower() != "true": + copyright = f"2012–{td.year}, MNE Developers. Last updated locally." # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -63,71 +68,70 @@ # # The full version, including alpha/beta/rc tags. release = mne.__version__ -sphinx_logger.info( - f'Building documentation for MNE {release} ({mne.__file__})') +sphinx_logger.info(f"Building documentation for MNE {release} ({mne.__file__})") # The short X.Y version. -version = '.'.join(release.split('.')[:2]) +version = ".".join(release.split(".")[:2]) # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '2.0' +needs_sphinx = "2.0" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.coverage', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.linkcode', - 'sphinx.ext.mathjax', - 'sphinx.ext.todo', - 'sphinx.ext.graphviz', - 'numpydoc', - 'sphinx_gallery.gen_gallery', - 'gen_commands', - 'gh_substitutions', - 'mne_substitutions', - 'newcontrib_substitutions', - 'gen_names', - 'matplotlib.sphinxext.plot_directive', - 'sphinxcontrib.bibtex', - 'sphinx_copybutton', - 'sphinx_design', - 'sphinxcontrib.youtube', - 'unit_role', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.coverage", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", + "sphinx.ext.mathjax", + "sphinx.ext.todo", + "sphinx.ext.graphviz", + "numpydoc", + "sphinx_gallery.gen_gallery", + "gen_commands", + "gh_substitutions", + "mne_substitutions", + "newcontrib_substitutions", + "gen_names", + "matplotlib.sphinxext.plot_directive", + "sphinxcontrib.bibtex", + "sphinx_copybutton", + "sphinx_design", + "sphinxcontrib.youtube", + "unit_role", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_includes'] +exclude_patterns = ["_includes"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The main toctree document. -master_doc = 'index' +master_doc = "index" # List of documents that shouldn't be included in the build. unused_docs = [] # List of directories, relative to source directory, that shouldn't be searched # for source files. -exclude_trees = ['_build'] +exclude_trees = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. default_role = "py:obj" # A list of ignored prefixes for module index sorting. -modindex_common_prefix = ['mne.'] +modindex_common_prefix = ["mne."] # -- Sphinx-Copybutton configuration ----------------------------------------- copybutton_prompt_text = r">>> |\.\.\. 
|\$ " @@ -136,36 +140,38 @@ # -- Intersphinx configuration ----------------------------------------------- intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable', None), - 'scipy': ('https://docs.scipy.org/doc/scipy', None), - 'matplotlib': ('https://matplotlib.org/stable', None), - 'sklearn': ('https://scikit-learn.org/stable', None), - 'numba': ('https://numba.readthedocs.io/en/latest', None), - 'joblib': ('https://joblib.readthedocs.io/en/latest', None), - 'nibabel': ('https://nipy.org/nibabel', None), - 'nilearn': ('http://nilearn.github.io/stable', None), - 'nitime': ('https://nipy.org/nitime/', None), - 'surfer': ('https://pysurfer.github.io/', None), - 'mne_bids': ('https://mne.tools/mne-bids/stable', None), - 'mne-connectivity': ('https://mne.tools/mne-connectivity/stable', None), - 'mne-gui-addons': ('https://mne.tools/mne-gui-addons', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/stable', None), - 'seaborn': ('https://seaborn.pydata.org/', None), - 'statsmodels': ('https://www.statsmodels.org/dev', None), - 'patsy': ('https://patsy.readthedocs.io/en/latest', None), - 'pyvista': ('https://docs.pyvista.org/', None), - 'imageio': ('https://imageio.readthedocs.io/en/latest', None), - 'mne_realtime': ('https://mne.tools/mne-realtime', None), - 'picard': ('https://pierreablin.github.io/picard/', None), - 'qdarkstyle': ('https://qdarkstylesheet.readthedocs.io/en/latest', None), - 'eeglabio': ('https://eeglabio.readthedocs.io/en/latest', None), - 'dipy': ('https://dipy.org/documentation/1.7.0/', - 'https://dipy.org/documentation/1.7.0/objects.inv/'), - 'pooch': ('https://www.fatiando.org/pooch/latest/', None), - 'pybv': ('https://pybv.readthedocs.io/en/latest/', None), - 'pyqtgraph': ('https://pyqtgraph.readthedocs.io/en/latest/', None), - 'openmeeg': ('https://openmeeg.github.io/', None), + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), + "scipy": ("https://docs.scipy.org/doc/scipy", None), + "matplotlib": ("https://matplotlib.org/stable", None), + "sklearn": ("https://scikit-learn.org/stable", None), + "numba": ("https://numba.readthedocs.io/en/latest", None), + "joblib": ("https://joblib.readthedocs.io/en/latest", None), + "nibabel": ("https://nipy.org/nibabel", None), + "nilearn": ("http://nilearn.github.io/stable", None), + "nitime": ("https://nipy.org/nitime/", None), + "surfer": ("https://pysurfer.github.io/", None), + "mne_bids": ("https://mne.tools/mne-bids/stable", None), + "mne-connectivity": ("https://mne.tools/mne-connectivity/stable", None), + "mne-gui-addons": ("https://mne.tools/mne-gui-addons", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), + "seaborn": ("https://seaborn.pydata.org/", None), + "statsmodels": ("https://www.statsmodels.org/dev", None), + "patsy": ("https://patsy.readthedocs.io/en/latest", None), + "pyvista": ("https://docs.pyvista.org/", None), + "imageio": ("https://imageio.readthedocs.io/en/latest", None), + "mne_realtime": 
("https://mne.tools/mne-realtime", None), + "picard": ("https://pierreablin.github.io/picard/", None), + "qdarkstyle": ("https://qdarkstylesheet.readthedocs.io/en/latest", None), + "eeglabio": ("https://eeglabio.readthedocs.io/en/latest", None), + "dipy": ( + "https://dipy.org/documentation/1.7.0/", + "https://dipy.org/documentation/1.7.0/objects.inv/", + ), + "pooch": ("https://www.fatiando.org/pooch/latest/", None), + "pybv": ("https://pybv.readthedocs.io/en/latest/", None), + "pyqtgraph": ("https://pyqtgraph.readthedocs.io/en/latest/", None), + "openmeeg": ("https://openmeeg.github.io/", None), } @@ -175,127 +181,251 @@ docscrape.ClassDoc.extra_public_methods = mne.utils._doc_special_members numpydoc_class_members_toctree = False numpydoc_show_inherited_class_members = { - 'mne.SourceSpaces': False, - 'mne.Forward': False, + "mne.SourceSpaces": False, + "mne.Forward": False, } numpydoc_attributes_as_param_list = True numpydoc_xref_param_type = True numpydoc_xref_aliases = { # Python - 'file-like': ':term:`file-like `', - 'iterator': ':term:`iterator `', - 'path-like': ':term:`path-like`', - 'array-like': ':term:`array_like `', - 'Path': ':class:`python:pathlib.Path`', - 'bool': ':class:`python:bool`', + "file-like": ":term:`file-like `", + "iterator": ":term:`iterator `", + "path-like": ":term:`path-like`", + "array-like": ":term:`array_like `", + "Path": ":class:`python:pathlib.Path`", + "bool": ":class:`python:bool`", # Matplotlib - 'colormap': ':doc:`colormap `', - 'color': ':doc:`color `', - 'Axes': 'matplotlib.axes.Axes', - 'Figure': 'matplotlib.figure.Figure', - 'Axes3D': 'mpl_toolkits.mplot3d.axes3d.Axes3D', - 'ColorbarBase': 'matplotlib.colorbar.ColorbarBase', + "colormap": ":doc:`colormap `", + "color": ":doc:`color `", + "Axes": "matplotlib.axes.Axes", + "Figure": "matplotlib.figure.Figure", + "Axes3D": "mpl_toolkits.mplot3d.axes3d.Axes3D", + "ColorbarBase": "matplotlib.colorbar.ColorbarBase", # sklearn - 'LeaveOneOut': 'sklearn.model_selection.LeaveOneOut', + "LeaveOneOut": "sklearn.model_selection.LeaveOneOut", # joblib - 'joblib.Parallel': 'joblib.Parallel', + "joblib.Parallel": "joblib.Parallel", # nibabel - 'Nifti1Image': 'nibabel.nifti1.Nifti1Image', - 'Nifti2Image': 'nibabel.nifti2.Nifti2Image', - 'SpatialImage': 'nibabel.spatialimages.SpatialImage', + "Nifti1Image": "nibabel.nifti1.Nifti1Image", + "Nifti2Image": "nibabel.nifti2.Nifti2Image", + "SpatialImage": "nibabel.spatialimages.SpatialImage", # MNE - 'Label': 'mne.Label', 'Forward': 'mne.Forward', 'Evoked': 'mne.Evoked', - 'Info': 'mne.Info', 'SourceSpaces': 'mne.SourceSpaces', - 'Epochs': 'mne.Epochs', 'Layout': 'mne.channels.Layout', - 'EvokedArray': 'mne.EvokedArray', 'BiHemiLabel': 'mne.BiHemiLabel', - 'AverageTFR': 'mne.time_frequency.AverageTFR', - 'EpochsTFR': 'mne.time_frequency.EpochsTFR', - 'Raw': 'mne.io.Raw', 'ICA': 'mne.preprocessing.ICA', - 'Covariance': 'mne.Covariance', 'Annotations': 'mne.Annotations', - 'DigMontage': 'mne.channels.DigMontage', - 'VectorSourceEstimate': 'mne.VectorSourceEstimate', - 'VolSourceEstimate': 'mne.VolSourceEstimate', - 'VolVectorSourceEstimate': 'mne.VolVectorSourceEstimate', - 'MixedSourceEstimate': 'mne.MixedSourceEstimate', - 'MixedVectorSourceEstimate': 'mne.MixedVectorSourceEstimate', - 'SourceEstimate': 'mne.SourceEstimate', 'Projection': 'mne.Projection', - 'ConductorModel': 'mne.bem.ConductorModel', - 'Dipole': 'mne.Dipole', 'DipoleFixed': 'mne.DipoleFixed', - 
'InverseOperator': 'mne.minimum_norm.InverseOperator', - 'CrossSpectralDensity': 'mne.time_frequency.CrossSpectralDensity', - 'SourceMorph': 'mne.SourceMorph', - 'Xdawn': 'mne.preprocessing.Xdawn', - 'Report': 'mne.Report', - 'TimeDelayingRidge': 'mne.decoding.TimeDelayingRidge', - 'Vectorizer': 'mne.decoding.Vectorizer', - 'UnsupervisedSpatialFilter': 'mne.decoding.UnsupervisedSpatialFilter', - 'TemporalFilter': 'mne.decoding.TemporalFilter', - 'SSD': 'mne.decoding.SSD', - 'Scaler': 'mne.decoding.Scaler', 'SPoC': 'mne.decoding.SPoC', - 'PSDEstimator': 'mne.decoding.PSDEstimator', - 'LinearModel': 'mne.decoding.LinearModel', - 'FilterEstimator': 'mne.decoding.FilterEstimator', - 'EMS': 'mne.decoding.EMS', 'CSP': 'mne.decoding.CSP', - 'Beamformer': 'mne.beamformer.Beamformer', - 'Transform': 'mne.transforms.Transform', - 'Coregistration': 'mne.coreg.Coregistration', - 'Figure3D': 'mne.viz.Figure3D', - 'EOGRegression': 'mne.preprocessing.EOGRegression', - 'Spectrum': 'mne.time_frequency.Spectrum', - 'EpochsSpectrum': 'mne.time_frequency.EpochsSpectrum', + "Label": "mne.Label", + "Forward": "mne.Forward", + "Evoked": "mne.Evoked", + "Info": "mne.Info", + "SourceSpaces": "mne.SourceSpaces", + "Epochs": "mne.Epochs", + "Layout": "mne.channels.Layout", + "EvokedArray": "mne.EvokedArray", + "BiHemiLabel": "mne.BiHemiLabel", + "AverageTFR": "mne.time_frequency.AverageTFR", + "EpochsTFR": "mne.time_frequency.EpochsTFR", + "Raw": "mne.io.Raw", + "ICA": "mne.preprocessing.ICA", + "Covariance": "mne.Covariance", + "Annotations": "mne.Annotations", + "DigMontage": "mne.channels.DigMontage", + "VectorSourceEstimate": "mne.VectorSourceEstimate", + "VolSourceEstimate": "mne.VolSourceEstimate", + "VolVectorSourceEstimate": "mne.VolVectorSourceEstimate", + "MixedSourceEstimate": "mne.MixedSourceEstimate", + "MixedVectorSourceEstimate": "mne.MixedVectorSourceEstimate", + "SourceEstimate": "mne.SourceEstimate", + "Projection": "mne.Projection", + "ConductorModel": "mne.bem.ConductorModel", + "Dipole": "mne.Dipole", + "DipoleFixed": "mne.DipoleFixed", + "InverseOperator": "mne.minimum_norm.InverseOperator", + "CrossSpectralDensity": "mne.time_frequency.CrossSpectralDensity", + "SourceMorph": "mne.SourceMorph", + "Xdawn": "mne.preprocessing.Xdawn", + "Report": "mne.Report", + "TimeDelayingRidge": "mne.decoding.TimeDelayingRidge", + "Vectorizer": "mne.decoding.Vectorizer", + "UnsupervisedSpatialFilter": "mne.decoding.UnsupervisedSpatialFilter", + "TemporalFilter": "mne.decoding.TemporalFilter", + "SSD": "mne.decoding.SSD", + "Scaler": "mne.decoding.Scaler", + "SPoC": "mne.decoding.SPoC", + "PSDEstimator": "mne.decoding.PSDEstimator", + "LinearModel": "mne.decoding.LinearModel", + "FilterEstimator": "mne.decoding.FilterEstimator", + "EMS": "mne.decoding.EMS", + "CSP": "mne.decoding.CSP", + "Beamformer": "mne.beamformer.Beamformer", + "Transform": "mne.transforms.Transform", + "Coregistration": "mne.coreg.Coregistration", + "Figure3D": "mne.viz.Figure3D", + "EOGRegression": "mne.preprocessing.EOGRegression", + "Spectrum": "mne.time_frequency.Spectrum", + "EpochsSpectrum": "mne.time_frequency.EpochsSpectrum", # dipy - 'dipy.align.AffineMap': 'dipy.align.imaffine.AffineMap', - 'dipy.align.DiffeomorphicMap': 'dipy.align.imwarp.DiffeomorphicMap', + "dipy.align.AffineMap": "dipy.align.imaffine.AffineMap", + "dipy.align.DiffeomorphicMap": "dipy.align.imwarp.DiffeomorphicMap", } numpydoc_xref_ignore = { # words - 'instance', 'instances', 'of', 'default', 'shape', 'or', - 'with', 'length', 'pair', 'matplotlib', 
'optional', 'kwargs', 'in', - 'dtype', 'object', + "instance", + "instances", + "of", + "default", + "shape", + "or", + "with", + "length", + "pair", + "matplotlib", + "optional", + "kwargs", + "in", + "dtype", + "object", # shapes - 'n_vertices', 'n_faces', 'n_channels', 'm', 'n', 'n_events', 'n_colors', - 'n_times', 'obj', 'n_chan', 'n_epochs', 'n_picks', 'n_ch_groups', - 'n_dipoles', 'n_ica_components', 'n_pos', 'n_node_names', 'n_tapers', - 'n_signals', 'n_step', 'n_freqs', 'wsize', 'Tx', 'M', 'N', 'p', 'q', 'r', - 'n_observations', 'n_regressors', 'n_cols', 'n_frequencies', 'n_tests', - 'n_samples', 'n_permutations', 'nchan', 'n_points', 'n_features', - 'n_parts', 'n_features_new', 'n_components', 'n_labels', 'n_events_in', - 'n_splits', 'n_scores', 'n_outputs', 'n_trials', 'n_estimators', 'n_tasks', - 'nd_features', 'n_classes', 'n_targets', 'n_slices', 'n_hpi', 'n_fids', - 'n_elp', 'n_pts', 'n_tris', 'n_nodes', 'n_nonzero', 'n_events_out', - 'n_segments', 'n_orient_inv', 'n_orient_fwd', 'n_orient', 'n_dipoles_lcmv', - 'n_dipoles_fwd', 'n_picks_ref', 'n_coords', 'n_meg', 'n_good_meg', - 'n_moments', 'n_patterns', 'n_new_events', + "n_vertices", + "n_faces", + "n_channels", + "m", + "n", + "n_events", + "n_colors", + "n_times", + "obj", + "n_chan", + "n_epochs", + "n_picks", + "n_ch_groups", + "n_dipoles", + "n_ica_components", + "n_pos", + "n_node_names", + "n_tapers", + "n_signals", + "n_step", + "n_freqs", + "wsize", + "Tx", + "M", + "N", + "p", + "q", + "r", + "n_observations", + "n_regressors", + "n_cols", + "n_frequencies", + "n_tests", + "n_samples", + "n_permutations", + "nchan", + "n_points", + "n_features", + "n_parts", + "n_features_new", + "n_components", + "n_labels", + "n_events_in", + "n_splits", + "n_scores", + "n_outputs", + "n_trials", + "n_estimators", + "n_tasks", + "nd_features", + "n_classes", + "n_targets", + "n_slices", + "n_hpi", + "n_fids", + "n_elp", + "n_pts", + "n_tris", + "n_nodes", + "n_nonzero", + "n_events_out", + "n_segments", + "n_orient_inv", + "n_orient_fwd", + "n_orient", + "n_dipoles_lcmv", + "n_dipoles_fwd", + "n_picks_ref", + "n_coords", + "n_meg", + "n_good_meg", + "n_moments", + "n_patterns", + "n_new_events", # Undocumented (on purpose) - 'RawKIT', 'RawEximia', 'RawEGI', 'RawEEGLAB', 'RawEDF', 'RawCTF', 'RawBTi', - 'RawBrainVision', 'RawCurry', 'RawNIRX', 'RawGDF', 'RawSNIRF', 'RawBOXY', - 'RawPersyst', 'RawNihon', 'RawNedf', 'RawHitachi', 'RawFIL', 'RawEyelink', + "RawKIT", + "RawEximia", + "RawEGI", + "RawEEGLAB", + "RawEDF", + "RawCTF", + "RawBTi", + "RawBrainVision", + "RawCurry", + "RawNIRX", + "RawGDF", + "RawSNIRF", + "RawBOXY", + "RawPersyst", + "RawNihon", + "RawNedf", + "RawHitachi", + "RawFIL", + "RawEyelink", # sklearn subclasses - 'mapping', 'to', 'any', + "mapping", + "to", + "any", # unlinkable - 'CoregistrationUI', - 'IntracranialElectrodeLocator', - 'mne_qt_browser.figure.MNEQtBrowser', + "CoregistrationUI", + "IntracranialElectrodeLocator", + "mne_qt_browser.figure.MNEQtBrowser", } numpydoc_validate = True -numpydoc_validation_checks = {'all'} | set(error_ignores) +numpydoc_validation_checks = {"all"} | set(error_ignores) numpydoc_validation_exclude = { # set of regex # dict subclasses - r'\.clear', r'\.get$', r'\.copy$', r'\.fromkeys', r'\.items', r'\.keys', - r'\.pop', r'\.popitem', r'\.setdefault', r'\.update', r'\.values', + r"\.clear", + r"\.get$", + r"\.copy$", + r"\.fromkeys", + r"\.items", + r"\.keys", + r"\.pop", + r"\.popitem", + r"\.setdefault", + r"\.update", + r"\.values", # list subclasses - r'\.append', 
r'\.count', r'\.extend', r'\.index', r'\.insert', r'\.remove', - r'\.sort', + r"\.append", + r"\.count", + r"\.extend", + r"\.index", + r"\.insert", + r"\.remove", + r"\.sort", # we currently don't document these properly (probably okay) - r'\.__getitem__', r'\.__contains__', r'\.__hash__', r'\.__mul__', - r'\.__sub__', r'\.__add__', r'\.__iter__', r'\.__div__', r'\.__neg__', + r"\.__getitem__", + r"\.__contains__", + r"\.__hash__", + r"\.__mul__", + r"\.__sub__", + r"\.__add__", + r"\.__iter__", + r"\.__div__", + r"\.__neg__", # copied from sklearn - r'mne\.utils\.deprecated', + r"mne\.utils\.deprecated", } # -- Sphinx-gallery configuration -------------------------------------------- + class Resetter(object): """Simple class to make the str(obj) static for Sphinx build env hash.""" @@ -303,10 +433,11 @@ def __init__(self): self.t0 = time.time() def __repr__(self): - return f'<{self.__class__.__name__}>' + return f"<{self.__class__.__name__}>" def __call__(self, gallery_conf, fname, when): import matplotlib.pyplot as plt + try: from pyvista import Plotter # noqa except ImportError: @@ -324,45 +455,46 @@ def __call__(self, gallery_conf, fname, when): except ImportError: MNEQtBrowser = None from mne.viz.backends.renderer import backend + _Renderer = backend._Renderer if backend is not None else None reset_warnings(gallery_conf, fname) # in case users have interactive mode turned on in matplotlibrc, # turn it off here (otherwise the build can be very slow) plt.ioff() - plt.rcParams['animation.embed_limit'] = 30. - plt.rcParams['figure.raise_window'] = False + plt.rcParams["animation.embed_limit"] = 30.0 + plt.rcParams["figure.raise_window"] = False # neo holds on to an exception, which in turn holds a stack frame, # which will keep alive the global vars during SG execution try: import neo + neo.io.stimfitio.STFIO_ERR = None except Exception: pass gc.collect() - when = f'mne/conf.py:Resetter.__call__:{when}:{fname}' + when = f"mne/conf.py:Resetter.__call__:{when}:{fname}" # Support stuff like # MNE_SKIP_INSTANCE_ASSERTIONS="Brain,Plotter,BackgroundPlotter,vtkPolyData,_Renderer" make html-memory # noqa: E501 # to just test MNEQtBrowser - skips = os.getenv('MNE_SKIP_INSTANCE_ASSERTIONS', '').lower() - prefix = '' - if skips not in ('true', '1', 'all'): - prefix = 'Clean ' - skips = skips.split(',') - if 'brain' not in skips: + skips = os.getenv("MNE_SKIP_INSTANCE_ASSERTIONS", "").lower() + prefix = "" + if skips not in ("true", "1", "all"): + prefix = "Clean " + skips = skips.split(",") + if "brain" not in skips: _assert_no_instances(Brain, when) # calls gc.collect() - if Plotter is not None and 'plotter' not in skips: + if Plotter is not None and "plotter" not in skips: _assert_no_instances(Plotter, when) - if BackgroundPlotter is not None and \ - 'backgroundplotter' not in skips: + if BackgroundPlotter is not None and "backgroundplotter" not in skips: _assert_no_instances(BackgroundPlotter, when) - if vtkPolyData is not None and 'vtkpolydata' not in skips: + if vtkPolyData is not None and "vtkpolydata" not in skips: _assert_no_instances(vtkPolyData, when) - if '_renderer' not in skips: + if "_renderer" not in skips: _assert_no_instances(_Renderer, when) - if MNEQtBrowser is not None and \ - 'mneqtbrowser' not in skips: + if MNEQtBrowser is not None and "mneqtbrowser" not in skips: # Ensure any manual fig.close() events get properly handled from mne_qt_browser._pg_figure import QApplication + inst = QApplication.instance() if inst is not None: for _ in range(2): @@ -370,18 +502,19 @@ def 
__call__(self, gallery_conf, fname, when): _assert_no_instances(MNEQtBrowser, when) # This will overwrite some Sphinx printing but it's useful # for memory timestamps - if os.getenv('SG_STAMP_STARTS', '').lower() == 'true': + if os.getenv("SG_STAMP_STARTS", "").lower() == "true": import psutil + process = psutil.Process(os.getpid()) mem = sizeof_fmt(process.memory_info().rss) - print(f'{prefix}{time.time() - self.t0:6.1f} s : {mem}'.ljust(22)) + print(f"{prefix}{time.time() - self.t0:6.1f} s : {mem}".ljust(22)) -examples_dirs = ['../tutorials', '../examples'] -gallery_dirs = ['auto_tutorials', 'auto_examples'] -os.environ['_MNE_BUILDING_DOC'] = 'true' -scrapers = ('matplotlib',) -mne.viz.set_3d_backend('pyvistaqt') +examples_dirs = ["../tutorials", "../examples"] +gallery_dirs = ["auto_tutorials", "auto_examples"] +os.environ["_MNE_BUILDING_DOC"] = "true" +scrapers = ("matplotlib",) +mne.viz.set_3d_backend("pyvistaqt") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) import pyvista @@ -390,111 +523,118 @@ def __call__(self, gallery_conf, fname, when): report_scraper = mne.report._ReportScraper() scrapers = ( - 'matplotlib', + "matplotlib", mne.gui._GUIScraper(), mne.viz._brain._BrainScraper(), - 'pyvista', + "pyvista", report_scraper, mne.viz._scraper._MNEQtBrowserScraper(), ) -compress_images = ('images', 'thumbnails') +compress_images = ("images", "thumbnails") # let's make things easier on Windows users # (on Linux and macOS it's easy enough to require this) -if sys.platform.startswith('win'): +if sys.platform.startswith("win"): try: - subprocess.check_call(['optipng', '--version']) + subprocess.check_call(["optipng", "--version"]) except Exception: compress_images = () sphinx_gallery_conf = { - 'doc_module': ('mne',), - 'reference_url': dict(mne=None), - 'examples_dirs': examples_dirs, - 'subsection_order': ExplicitOrder(['../examples/io/', - '../examples/simulation/', - '../examples/preprocessing/', - '../examples/visualization/', - '../examples/time_frequency/', - '../examples/stats/', - '../examples/decoding/', - '../examples/connectivity/', - '../examples/forward/', - '../examples/inverse/', - '../examples/realtime/', - '../examples/datasets/', - '../tutorials/intro/', - '../tutorials/io/', - '../tutorials/raw/', - '../tutorials/preprocessing/', - '../tutorials/epochs/', - '../tutorials/evoked/', - '../tutorials/time-freq/', - '../tutorials/forward/', - '../tutorials/inverse/', - '../tutorials/stats-sensor-space/', - '../tutorials/stats-source-space/', - '../tutorials/machine-learning/', - '../tutorials/clinical/', - '../tutorials/simulation/', - '../tutorials/sample-datasets/', - '../tutorials/misc/']), - 'gallery_dirs': gallery_dirs, - 'default_thumb_file': os.path.join('_static', 'mne_helmet.png'), - 'backreferences_dir': 'generated', - 'plot_gallery': 'True', # Avoid annoying Unicode/bool default warning - 'thumbnail_size': (160, 112), - 'remove_config_comments': True, - 'min_reported_time': 1., - 'abort_on_example_error': False, - 'reset_modules': ('matplotlib', Resetter()), # called w/each script - 'reset_modules_order': 'both', - 'image_scrapers': scrapers, - 'show_memory': not sys.platform.startswith(('win', 'darwin')), - 'line_numbers': False, # messes with style - 'within_subsection_order': FileNameSortKey, - 'capture_repr': ('_repr_html_',), - 'junit': os.path.join('..', 'test-results', 'sphinx-gallery', 'junit.xml'), - 'matplotlib_animations': True, - 'compress_images': compress_images, - 'filename_pattern': 
'^((?!sgskip).)*$', - 'exclude_implicit_doc': { - r'mne\.io\.read_raw_fif', r'mne\.io\.Raw', r'mne\.Epochs', - r'mne.datasets.*', + "doc_module": ("mne",), + "reference_url": dict(mne=None), + "examples_dirs": examples_dirs, + "subsection_order": ExplicitOrder( + [ + "../examples/io/", + "../examples/simulation/", + "../examples/preprocessing/", + "../examples/visualization/", + "../examples/time_frequency/", + "../examples/stats/", + "../examples/decoding/", + "../examples/connectivity/", + "../examples/forward/", + "../examples/inverse/", + "../examples/realtime/", + "../examples/datasets/", + "../tutorials/intro/", + "../tutorials/io/", + "../tutorials/raw/", + "../tutorials/preprocessing/", + "../tutorials/epochs/", + "../tutorials/evoked/", + "../tutorials/time-freq/", + "../tutorials/forward/", + "../tutorials/inverse/", + "../tutorials/stats-sensor-space/", + "../tutorials/stats-source-space/", + "../tutorials/machine-learning/", + "../tutorials/clinical/", + "../tutorials/simulation/", + "../tutorials/sample-datasets/", + "../tutorials/misc/", + ] + ), + "gallery_dirs": gallery_dirs, + "default_thumb_file": os.path.join("_static", "mne_helmet.png"), + "backreferences_dir": "generated", + "plot_gallery": "True", # Avoid annoying Unicode/bool default warning + "thumbnail_size": (160, 112), + "remove_config_comments": True, + "min_reported_time": 1.0, + "abort_on_example_error": False, + "reset_modules": ("matplotlib", Resetter()), # called w/each script + "reset_modules_order": "both", + "image_scrapers": scrapers, + "show_memory": not sys.platform.startswith(("win", "darwin")), + "line_numbers": False, # messes with style + "within_subsection_order": FileNameSortKey, + "capture_repr": ("_repr_html_",), + "junit": os.path.join("..", "test-results", "sphinx-gallery", "junit.xml"), + "matplotlib_animations": True, + "compress_images": compress_images, + "filename_pattern": "^((?!sgskip).)*$", + "exclude_implicit_doc": { + r"mne\.io\.read_raw_fif", + r"mne\.io\.Raw", + r"mne\.Epochs", + r"mne.datasets.*", }, - 'show_api_usage': False, # disable for now until graph warning fixed - 'api_usage_ignore': ( - '(' - '.*__.*__|' # built-ins - '.*Base.*|.*Array.*|mne.Vector.*|mne.Mixed.*|mne.Vol.*|' # inherited - 'mne.coreg.Coregistration.*|' # GUI + "show_api_usage": False, # disable for now until graph warning fixed + "api_usage_ignore": ( + "(" + ".*__.*__|" # built-ins + ".*Base.*|.*Array.*|mne.Vector.*|mne.Mixed.*|mne.Vol.*|" # inherited + "mne.coreg.Coregistration.*|" # GUI # common - '.*utils.*|.*verbose()|.*copy()|.*update()|.*save()|' - '.*get_data()|' + ".*utils.*|.*verbose()|.*copy()|.*update()|.*save()|" + ".*get_data()|" # mixins - '.*add_channels()|.*add_reference_channels()|' - '.*anonymize()|.*apply_baseline()|.*apply_function()|' - '.*apply_hilbert()|.*as_type()|.*decimate()|' - '.*drop()|.*drop_channels()|.*drop_log_stats()|' - '.*export()|.*get_channel_types()|' - '.*get_montage()|.*interpolate_bads()|.*next()|' - '.*pick()|.*pick_channels()|.*pick_types()|' - '.*plot_sensors()|.*rename_channels()|' - '.*reorder_channels()|.*savgol_filter()|' - '.*set_eeg_reference()|.*set_channel_types()|' - '.*set_meas_date()|.*set_montage()|.*shift_time()|' - '.*time_as_index()|.*to_data_frame()|' + ".*add_channels()|.*add_reference_channels()|" + ".*anonymize()|.*apply_baseline()|.*apply_function()|" + ".*apply_hilbert()|.*as_type()|.*decimate()|" + ".*drop()|.*drop_channels()|.*drop_log_stats()|" + ".*export()|.*get_channel_types()|" + ".*get_montage()|.*interpolate_bads()|.*next()|" + 
".*pick()|.*pick_channels()|.*pick_types()|" + ".*plot_sensors()|.*rename_channels()|" + ".*reorder_channels()|.*savgol_filter()|" + ".*set_eeg_reference()|.*set_channel_types()|" + ".*set_meas_date()|.*set_montage()|.*shift_time()|" + ".*time_as_index()|.*to_data_frame()|" # dictionary inherited - '.*clear()|.*fromkeys()|.*get()|.*items()|' - '.*keys()|.*pop()|.*popitem()|.*setdefault()|' - '.*values()|' + ".*clear()|.*fromkeys()|.*get()|.*items()|" + ".*keys()|.*pop()|.*popitem()|.*setdefault()|" + ".*values()|" # sklearn inherited - '.*apply()|.*decision_function()|.*fit()|' - '.*fit_transform()|.*get_params()|.*predict()|' - '.*predict_proba()|.*set_params()|.*transform()|' + ".*apply()|.*decision_function()|.*fit()|" + ".*fit_transform()|.*get_params()|.*predict()|" + ".*predict_proba()|.*set_params()|.*transform()|" # I/O, also related to mixins - '.*.remove.*|.*.write.*)'), - 'copyfile_regex': r'.*index\.rst', # allow custom index.rst files + ".*.remove.*|.*.write.*)" + ), + "copyfile_regex": r".*index\.rst", # allow custom index.rst files } # Files were renamed from plot_* with: # find . -type f -name 'plot_*.py' -exec sh -c 'x="{}"; xn=`basename "${x}"`; git mv "$x" `dirname "${x}"`/${xn:5}' \; # noqa @@ -506,9 +646,12 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # does not respect the autodoc templates that would otherwise insert # the .. include:: lines, so we need to do it. # Eventually this could perhaps live in SG. - if what in ('attribute', 'method'): - size = os.path.getsize(os.path.join( - os.path.dirname(__file__), 'generated', '%s.examples' % (name,))) + if what in ("attribute", "method"): + size = os.path.getsize( + os.path.join( + os.path.dirname(__file__), "generated", "%s.examples" % (name,) + ) + ) if size > 0: lines += """ .. _sphx_glr_backreferences_{1}: @@ -517,12 +660,16 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): .. minigallery:: {1} -""".format(name.split('.')[-1], name).split('\n') +""".format( + name.split(".")[-1], name + ).split( + "\n" + ) # -- Other extension configuration ------------------------------------------- -user_agent = 'Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Mobile Safari/537.36' # noqa: E501 +user_agent = "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Mobile Safari/537.36" # noqa: E501 # Can eventually add linkcheck_request_headers if needed linkcheck_ignore = [ # will be compiled to regex # 403 Client Error: Forbidden @@ -560,12 +707,12 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # autodoc / autosummary autosummary_generate = True -autodoc_default_options = {'inherited-members': None} +autodoc_default_options = {"inherited-members": None} # sphinxcontrib-bibtex -bibtex_bibfiles = ['./references.bib'] -bibtex_style = 'unsrt' -bibtex_footbibliography_header = '' +bibtex_bibfiles = ["./references.bib"] +bibtex_style = "unsrt" +bibtex_footbibliography_header = "" # -- Nitpicky ---------------------------------------------------------------- @@ -575,7 +722,10 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): ("py:class", "None. 
Remove all items from D."), ("py:class", "a set-like object providing a view on D's items"), ("py:class", "a set-like object providing a view on D's keys"), - ("py:class", "v, remove specified key and return the corresponding value."), # noqa: E501 + ( + "py:class", + "v, remove specified key and return the corresponding value.", + ), # noqa: E501 ("py:class", "None. Update D from dict/iterable E and F."), ("py:class", "an object providing a view on D's values"), ("py:class", "a shallow copy of D"), @@ -584,11 +734,14 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): ("py:class", "mne.utils._logging._FuncT"), ] nitpick_ignore_regex = [ - ('py:.*', r"mne\.io\.BaseRaw.*"), - ('py:.*', r"mne\.BaseEpochs.*"), - ('py:obj', "(filename|metadata|proj|times|tmax|tmin|annotations|ch_names|compensation_grade|filenames|first_samp|first_time|last_samp|n_times|proj|times|tmax|tmin)"), # noqa: E501 + ("py:.*", r"mne\.io\.BaseRaw.*"), + ("py:.*", r"mne\.BaseEpochs.*"), + ( + "py:obj", + "(filename|metadata|proj|times|tmax|tmin|annotations|ch_names|compensation_grade|filenames|first_samp|first_time|last_samp|n_times|proj|times|tmax|tmin)", + ), # noqa: E501 ] -suppress_warnings = ['image.nonlocal_uri'] # we intentionally link outside +suppress_warnings = ["image.nonlocal_uri"] # we intentionally link outside # -- Options for HTML output ------------------------------------------------- @@ -596,46 +749,56 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'pydata_sphinx_theme' +html_theme = "pydata_sphinx_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. 
-switcher_version_match = 'dev' if release.endswith('dev0') else version +switcher_version_match = "dev" if release.endswith("dev0") else version html_theme_options = { - 'icon_links': [ - dict(name='GitHub', - url='https://github.com/mne-tools/mne-python', - icon='fa-brands fa-square-github'), - dict(name='Mastodon', - url='https://fosstodon.org/@mne', - icon='fa-brands fa-mastodon', - attributes=dict(rel='me')), - dict(name='Twitter', - url='https://twitter.com/mne_python', - icon='fa-brands fa-square-twitter'), - dict(name='Forum', - url='https://mne.discourse.group/', - icon='fa-brands fa-discourse'), - dict(name='Discord', - url='https://discord.gg/rKfvxTuATa', - icon='fa-brands fa-discord') + "icon_links": [ + dict( + name="GitHub", + url="https://github.com/mne-tools/mne-python", + icon="fa-brands fa-square-github", + ), + dict( + name="Mastodon", + url="https://fosstodon.org/@mne", + icon="fa-brands fa-mastodon", + attributes=dict(rel="me"), + ), + dict( + name="Twitter", + url="https://twitter.com/mne_python", + icon="fa-brands fa-square-twitter", + ), + dict( + name="Forum", + url="https://mne.discourse.group/", + icon="fa-brands fa-discourse", + ), + dict( + name="Discord", + url="https://discord.gg/rKfvxTuATa", + icon="fa-brands fa-discord", + ), ], - 'icon_links_label': 'External Links', # for screen reader - 'use_edit_page_button': False, - 'navigation_with_keys': False, - 'show_toc_level': 1, - 'navbar_end': ['theme-switcher', 'version-switcher', 'navbar-icon-links'], - 'footer_start': ['copyright'], - 'footer_end': [], - 'secondary_sidebar_items': ['page-toc'], - 'analytics': dict(google_analytics_id='G-5TBCPCRB6X'), - 'switcher': { - 'json_url': 'https://mne.tools/dev/_static/versions.json', - 'version_match': switcher_version_match, + "icon_links_label": "External Links", # for screen reader + "use_edit_page_button": False, + "navigation_with_keys": False, + "show_toc_level": 1, + "navbar_end": ["theme-switcher", "version-switcher", "navbar-icon-links"], + "footer_start": ["copyright"], + "footer_end": [], + "secondary_sidebar_items": ["page-toc"], + "analytics": dict(google_analytics_id="G-5TBCPCRB6X"), + "switcher": { + "json_url": "https://mne.tools/dev/_static/versions.json", + "version_match": switcher_version_match, }, - 'pygment_light_style': 'default', - 'pygment_dark_style': 'github-dark', + "pygment_light_style": "default", + "pygment_dark_style": "github-dark", } # The name of an image file (relative to this directory) to place at the top @@ -651,24 +814,24 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_css_files = [ - 'style.css', + "style.css", ] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. html_extra_path = [ - 'contributing.html', - 'documentation.html', - 'getting_started.html', - 'install_mne_python.html', + "contributing.html", + "documentation.html", + "getting_started.html", + "install_mne_python.html", ] # Custom sidebar templates, maps document names to template names. 
html_sidebars = { - 'index': ['sidebar-quicklinks.html'], + "index": ["sidebar-quicklinks.html"], } # If true, links to the reST sources are added to the pages. @@ -679,262 +842,346 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): html_show_sphinx = False # accommodate different logo shapes (width values in rem) -xs = '2' -sm = '2.5' -md = '3' -lg = '4.5' -xl = '5' -xxl = '6' +xs = "2" +sm = "2.5" +md = "3" +lg = "4.5" +xl = "5" +xxl = "6" # variables to pass to HTML templating engine html_context = { - 'default_mode': 'auto', - 'pygment_light_style': 'tango', - 'pygment_dark_style': 'native', - 'funders': [ - dict(img='nih.svg', size='3', title='National Institutes of Health'), - dict(img='nsf.png', size='3.5', - title='US National Science Foundation'), - dict(img='erc.svg', size='3.5', title='European Research Council', - klass='only-light'), - dict(img='erc-dark.svg', size='3.5', title='European Research Council', - klass='only-dark'), - dict(img='doe.svg', size='3', title='US Department of Energy'), - dict(img='anr.svg', size='3.5', - title='Agence Nationale de la Recherche'), - dict(img='cds.png', size='2.25', - title='Paris-Saclay Center for Data Science'), - dict(img='google.svg', size='2.25', title='Google'), - dict(img='amazon.svg', size='2.5', title='Amazon'), - dict(img='czi.svg', size='2.5', title='Chan Zuckerberg Initiative'), + "default_mode": "auto", + "pygment_light_style": "tango", + "pygment_dark_style": "native", + "funders": [ + dict(img="nih.svg", size="3", title="National Institutes of Health"), + dict(img="nsf.png", size="3.5", title="US National Science Foundation"), + dict( + img="erc.svg", + size="3.5", + title="European Research Council", + klass="only-light", + ), + dict( + img="erc-dark.svg", + size="3.5", + title="European Research Council", + klass="only-dark", + ), + dict(img="doe.svg", size="3", title="US Department of Energy"), + dict(img="anr.svg", size="3.5", title="Agence Nationale de la Recherche"), + dict(img="cds.png", size="2.25", title="Paris-Saclay Center for Data Science"), + dict(img="google.svg", size="2.25", title="Google"), + dict(img="amazon.svg", size="2.5", title="Amazon"), + dict(img="czi.svg", size="2.5", title="Chan Zuckerberg Initiative"), ], - 'institutions': [ - dict(name='Massachusetts General Hospital', - img='MGH.svg', - url='https://www.massgeneral.org/', - size=sm), - dict(name='Athinoula A. 
Martinos Center for Biomedical Imaging', - img='Martinos.png', - url='https://martinos.org/', - size=md), - dict(name='Harvard Medical School', - img='Harvard.png', - url='https://hms.harvard.edu/', - size=sm), - dict(name='Massachusetts Institute of Technology', - img='MIT.svg', - url='https://web.mit.edu/', - size=md), - dict(name='New York University', - img='NYU.svg', - url='https://www.nyu.edu/', - size=xs, - klass='only-light'), - dict(name='New York University', - img='NYU-dark.svg', - url='https://www.nyu.edu/', - size=xs, - klass='only-dark'), - dict(name='Commissariat à l´énergie atomique et aux énergies alternatives', # noqa E501 - img='CEA.png', - url='http://www.cea.fr/', - size=md), - dict(name='Aalto-yliopiston perustieteiden korkeakoulu', - img='Aalto.svg', - url='https://sci.aalto.fi/', - size=md, - klass='only-light'), - dict(name='Aalto-yliopiston perustieteiden korkeakoulu', - img='Aalto-dark.svg', - url='https://sci.aalto.fi/', - size=md, - klass='only-dark'), - dict(name='Télécom ParisTech', - img='Telecom_Paris_Tech.svg', - url='https://www.telecom-paris.fr/', - size=md), - dict(name='University of Washington', - img='Washington.svg', - url='https://www.washington.edu/', - size=md, - klass='only-light'), - dict(name='University of Washington', - img='Washington-dark.svg', - url='https://www.washington.edu/', - size=md, - klass='only-dark'), - dict(name='Institut du Cerveau et de la Moelle épinière', - img='ICM.jpg', - url='https://icm-institute.org/', - size=md), - dict(name='Boston University', - img='BU.svg', - url='https://www.bu.edu/', - size=lg), - dict(name='Institut national de la santé et de la recherche médicale', - img='Inserm.svg', - url='https://www.inserm.fr/', - size=xl, - klass='only-light'), - dict(name='Institut national de la santé et de la recherche médicale', - img='Inserm-dark.svg', - url='https://www.inserm.fr/', - size=xl, - klass='only-dark'), - dict(name='Forschungszentrum Jülich', - img='Julich.svg', - url='https://www.fz-juelich.de/', - size=xl, - klass='only-light'), - dict(name='Forschungszentrum Jülich', - img='Julich-dark.svg', - url='https://www.fz-juelich.de/', - size=xl, - klass='only-dark'), - dict(name='Technische Universität Ilmenau', - img='Ilmenau.svg', - url='https://www.tu-ilmenau.de/', - size=xxl, - klass='only-light'), - dict(name='Technische Universität Ilmenau', - img='Ilmenau-dark.svg', - url='https://www.tu-ilmenau.de/', - size=xxl, - klass='only-dark'), - dict(name='Berkeley Institute for Data Science', - img='BIDS.svg', - url='https://bids.berkeley.edu/', - size=lg, - klass='only-light'), - dict(name='Berkeley Institute for Data Science', - img='BIDS-dark.svg', - url='https://bids.berkeley.edu/', - size=lg, - klass='only-dark'), - dict(name='Institut national de recherche en informatique et en automatique', # noqa E501 - img='inria.png', - url='https://www.inria.fr/', - size=xl), - dict(name='Aarhus Universitet', - img='Aarhus.svg', - url='https://www.au.dk/', - size=xl, - klass='only-light'), - dict(name='Aarhus Universitet', - img='Aarhus-dark.svg', - url='https://www.au.dk/', - size=xl, - klass='only-dark'), - dict(name='Karl-Franzens-Universität Graz', - img='Graz.svg', - url='https://www.uni-graz.at/', - size=md), - dict(name='SWPS Uniwersytet Humanistycznospołeczny', - img='SWPS.svg', - 
url='https://www.swps.pl/', - size=xl, - klass='only-light'), - dict(name='SWPS Uniwersytet Humanistycznospołeczny', - img='SWPS-dark.svg', - url='https://www.swps.pl/', - size=xl, - klass='only-dark'), - dict(name='Max-Planck-Institut für Bildungsforschung', - img='MPIB.svg', - url='https://www.mpib-berlin.mpg.de/', - size=xxl, - klass='only-light'), - dict(name='Max-Planck-Institut für Bildungsforschung', - img='MPIB-dark.svg', - url='https://www.mpib-berlin.mpg.de/', - size=xxl, - klass='only-dark'), - dict(name='Macquarie University', - img='Macquarie.svg', - url='https://www.mq.edu.au/', - size=lg, - klass='only-light'), - dict(name='Macquarie University', - img='Macquarie-dark.svg', - url='https://www.mq.edu.au/', - size=lg, - klass='only-dark'), - dict(name='Children’s Hospital of Philadelphia Research Institute', - img='CHOP.svg', - url='https://www.research.chop.edu/imaging', - size=xxl, - klass='only-light'), - dict(name='Children’s Hospital of Philadelphia Research Institute', - img='CHOP-dark.svg', - url='https://www.research.chop.edu/imaging', - size=xxl, - klass='only-dark'), - dict(name='Donders Institute for Brain, Cognition and Behaviour at Radboud University', # noqa E501 - img='Donders.png', - url='https://www.ru.nl/donders/', - size=xl), + "institutions": [ + dict( + name="Massachusetts General Hospital", + img="MGH.svg", + url="https://www.massgeneral.org/", + size=sm, + ), + dict( + name="Athinoula A. Martinos Center for Biomedical Imaging", + img="Martinos.png", + url="https://martinos.org/", + size=md, + ), + dict( + name="Harvard Medical School", + img="Harvard.png", + url="https://hms.harvard.edu/", + size=sm, + ), + dict( + name="Massachusetts Institute of Technology", + img="MIT.svg", + url="https://web.mit.edu/", + size=md, + ), + dict( + name="New York University", + img="NYU.svg", + url="https://www.nyu.edu/", + size=xs, + klass="only-light", + ), + dict( + name="New York University", + img="NYU-dark.svg", + url="https://www.nyu.edu/", + size=xs, + klass="only-dark", + ), + dict( + name="Commissariat à l´énergie atomique et aux énergies alternatives", # noqa E501 + img="CEA.png", + url="http://www.cea.fr/", + size=md, + ), + dict( + name="Aalto-yliopiston perustieteiden korkeakoulu", + img="Aalto.svg", + url="https://sci.aalto.fi/", + size=md, + klass="only-light", + ), + dict( + name="Aalto-yliopiston perustieteiden korkeakoulu", + img="Aalto-dark.svg", + url="https://sci.aalto.fi/", + size=md, + klass="only-dark", + ), + dict( + name="Télécom ParisTech", + img="Telecom_Paris_Tech.svg", + url="https://www.telecom-paris.fr/", + size=md, + ), + dict( + name="University of Washington", + img="Washington.svg", + url="https://www.washington.edu/", + size=md, + klass="only-light", + ), + dict( + name="University of Washington", + img="Washington-dark.svg", + url="https://www.washington.edu/", + size=md, + klass="only-dark", + ), + dict( + name="Institut du Cerveau et de la Moelle épinière", + img="ICM.jpg", + url="https://icm-institute.org/", + size=md, + ), + dict( + name="Boston University", img="BU.svg", url="https://www.bu.edu/", size=lg + ), + dict( + name="Institut national de la santé et de la recherche médicale", + img="Inserm.svg", + url="https://www.inserm.fr/", + size=xl, + klass="only-light", + ), + dict( + name="Institut national de la 
santé et de la recherche médicale", + img="Inserm-dark.svg", + url="https://www.inserm.fr/", + size=xl, + klass="only-dark", + ), + dict( + name="Forschungszentrum Jülich", + img="Julich.svg", + url="https://www.fz-juelich.de/", + size=xl, + klass="only-light", + ), + dict( + name="Forschungszentrum Jülich", + img="Julich-dark.svg", + url="https://www.fz-juelich.de/", + size=xl, + klass="only-dark", + ), + dict( + name="Technische Universität Ilmenau", + img="Ilmenau.svg", + url="https://www.tu-ilmenau.de/", + size=xxl, + klass="only-light", + ), + dict( + name="Technische Universität Ilmenau", + img="Ilmenau-dark.svg", + url="https://www.tu-ilmenau.de/", + size=xxl, + klass="only-dark", + ), + dict( + name="Berkeley Institute for Data Science", + img="BIDS.svg", + url="https://bids.berkeley.edu/", + size=lg, + klass="only-light", + ), + dict( + name="Berkeley Institute for Data Science", + img="BIDS-dark.svg", + url="https://bids.berkeley.edu/", + size=lg, + klass="only-dark", + ), + dict( + name="Institut national de recherche en informatique et en automatique", # noqa E501 + img="inria.png", + url="https://www.inria.fr/", + size=xl, + ), + dict( + name="Aarhus Universitet", + img="Aarhus.svg", + url="https://www.au.dk/", + size=xl, + klass="only-light", + ), + dict( + name="Aarhus Universitet", + img="Aarhus-dark.svg", + url="https://www.au.dk/", + size=xl, + klass="only-dark", + ), + dict( + name="Karl-Franzens-Universität Graz", + img="Graz.svg", + url="https://www.uni-graz.at/", + size=md, + ), + dict( + name="SWPS Uniwersytet Humanistycznospołeczny", + img="SWPS.svg", + url="https://www.swps.pl/", + size=xl, + klass="only-light", + ), + dict( + name="SWPS Uniwersytet Humanistycznospołeczny", + img="SWPS-dark.svg", + url="https://www.swps.pl/", + size=xl, + klass="only-dark", + ), + dict( + name="Max-Planck-Institut für Bildungsforschung", + img="MPIB.svg", + url="https://www.mpib-berlin.mpg.de/", + size=xxl, + klass="only-light", + ), + dict( + name="Max-Planck-Institut für Bildungsforschung", + img="MPIB-dark.svg", + url="https://www.mpib-berlin.mpg.de/", + size=xxl, + klass="only-dark", + ), + dict( + name="Macquarie University", + img="Macquarie.svg", + url="https://www.mq.edu.au/", + size=lg, + klass="only-light", + ), + dict( + name="Macquarie University", + img="Macquarie-dark.svg", + url="https://www.mq.edu.au/", + size=lg, + klass="only-dark", + ), + dict( + name="Children’s Hospital of Philadelphia Research Institute", + img="CHOP.svg", + url="https://www.research.chop.edu/imaging", + size=xxl, + klass="only-light", + ), + dict( + name="Children’s Hospital of Philadelphia Research Institute", + img="CHOP-dark.svg", + url="https://www.research.chop.edu/imaging", + size=xxl, + klass="only-dark", + ), + dict( + name="Donders Institute for Brain, Cognition and Behaviour at Radboud University", # noqa E501 + img="Donders.png", + url="https://www.ru.nl/donders/", + size=xl, + ), ], # \u00AD is an optional hyphen (not rendered unless needed) # If these are changed, the Makefile should be updated, too - 'carousel': [ - dict(title='Source Estimation', - text='Distributed, sparse, mixed-norm, beam\u00ADformers, dipole fitting, and more.', # noqa E501 - url='auto_tutorials/inverse/index.html', - img='sphx_glr_30_mne_dspm_loreta_008.gif', - alt='dSPM'), - dict(title='Machine Learning', - text='Advanced 
decoding models including time general\u00ADiza\u00ADtion.', # noqa E501 - url='auto_tutorials/machine-learning/50_decoding.html', - img='sphx_glr_50_decoding_006.png', - alt='Decoding'), - dict(title='Encoding Models', - text='Receptive field estima\u00ADtion with optional smooth\u00ADness priors.', # noqa E501 - url='auto_tutorials/machine-learning/30_strf.html', - img='sphx_glr_30_strf_001.png', - alt='STRF'), - dict(title='Statistics', - text='Parametric and non-parametric, permutation tests and clustering.', # noqa E501 - url='auto_tutorials/stats-source-space/index.html', - img='sphx_glr_20_cluster_1samp_spatiotemporal_001.png', - alt='Clusters'), - dict(title='Connectivity', - text='All-to-all spectral and effective connec\u00ADtivity measures.', # noqa E501 - url='https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_label_connectivity.html', # noqa E501 - img='https://mne.tools/mne-connectivity/stable/_images/sphx_glr_mne_inverse_label_connectivity_001.png', # noqa E501 - alt='Connectivity'), - dict(title='Data Visualization', - text='Explore your data from multiple perspectives.', - url='auto_tutorials/evoked/20_visualize_evoked.html', - img='sphx_glr_20_visualize_evoked_010.png', - alt='Visualization'), - ] + "carousel": [ + dict( + title="Source Estimation", + text="Distributed, sparse, mixed-norm, beam\u00ADformers, dipole fitting, and more.", # noqa E501 + url="auto_tutorials/inverse/index.html", + img="sphx_glr_30_mne_dspm_loreta_008.gif", + alt="dSPM", + ), + dict( + title="Machine Learning", + text="Advanced decoding models including time general\u00ADiza\u00ADtion.", # noqa E501 + url="auto_tutorials/machine-learning/50_decoding.html", + img="sphx_glr_50_decoding_006.png", + alt="Decoding", + ), + dict( + title="Encoding Models", + text="Receptive field estima\u00ADtion with optional smooth\u00ADness priors.", # noqa E501 + url="auto_tutorials/machine-learning/30_strf.html", + img="sphx_glr_30_strf_001.png", + alt="STRF", + ), + dict( + title="Statistics", + text="Parametric and non-parametric, permutation tests and clustering.", # noqa E501 + url="auto_tutorials/stats-source-space/index.html", + img="sphx_glr_20_cluster_1samp_spatiotemporal_001.png", + alt="Clusters", + ), + dict( + title="Connectivity", + text="All-to-all spectral and effective connec\u00ADtivity measures.", # noqa E501 + url="https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_label_connectivity.html", # noqa E501 + img="https://mne.tools/mne-connectivity/stable/_images/sphx_glr_mne_inverse_label_connectivity_001.png", # noqa E501 + alt="Connectivity", + ), + dict( + title="Data Visualization", + text="Explore your data from multiple perspectives.", + url="auto_tutorials/evoked/20_visualize_evoked.html", + img="sphx_glr_20_visualize_evoked_010.png", + alt="Visualization", + ), + ], } # Output file base name for HTML help builder. 
-htmlhelp_basename = 'mne-doc' +htmlhelp_basename = "mne-doc" # -- Options for plot_directive ---------------------------------------------- # Adapted from SciPy plot_include_source = True -plot_formats = [('png', 96)] +plot_formats = [("png", 96)] plot_html_show_formats = False plot_html_show_source_link = False font_size = 13 * 72 / 96.0 # 13 px plot_rcparams = { - 'font.size': font_size, - 'axes.titlesize': font_size, - 'axes.labelsize': font_size, - 'xtick.labelsize': font_size, - 'ytick.labelsize': font_size, - 'legend.fontsize': font_size, - 'figure.figsize': (6, 5), - 'figure.subplot.bottom': 0.2, - 'figure.subplot.left': 0.2, - 'figure.subplot.right': 0.9, - 'figure.subplot.top': 0.85, - 'figure.subplot.wspace': 0.4, - 'text.usetex': False, + "font.size": font_size, + "axes.titlesize": font_size, + "axes.labelsize": font_size, + "xtick.labelsize": font_size, + "ytick.labelsize": font_size, + "legend.fontsize": font_size, + "figure.figsize": (6, 5), + "figure.subplot.bottom": 0.2, + "figure.subplot.left": 0.2, + "figure.subplot.right": 0.9, + "figure.subplot.top": 0.85, + "figure.subplot.wspace": 0.4, + "text.usetex": False, } @@ -951,13 +1198,14 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -latex_toplevel_sectioning = 'part' +latex_toplevel_sectioning = "part" _np_print_defaults = np.get_printoptions() # -- Warnings management ----------------------------------------------------- + def reset_warnings(gallery_conf, fname): """Ensure we are future compatible and ignore silly warnings.""" # In principle, our examples should produce no warnings. @@ -968,78 +1216,84 @@ def reset_warnings(gallery_conf, fname): # remove tweaks from other module imports or example runs warnings.resetwarnings() # restrict - warnings.filterwarnings('error') + warnings.filterwarnings("error") # allow these, but show them - warnings.filterwarnings('always', '.*non-standard config type: "foo".*') - warnings.filterwarnings('always', '.*config type: "MNEE_USE_CUUDAA".*') - warnings.filterwarnings('always', '.*cannot make axes width small.*') - warnings.filterwarnings('always', '.*Axes that are not compatible.*') - warnings.filterwarnings('always', '.*FastICA did not converge.*') + warnings.filterwarnings("always", '.*non-standard config type: "foo".*') + warnings.filterwarnings("always", '.*config type: "MNEE_USE_CUUDAA".*') + warnings.filterwarnings("always", ".*cannot make axes width small.*") + warnings.filterwarnings("always", ".*Axes that are not compatible.*") + warnings.filterwarnings("always", ".*FastICA did not converge.*") # ECoG BIDS spec violations: - warnings.filterwarnings('always', '.*Fiducial point nasion not found.*') - warnings.filterwarnings('always', '.*DigMontage is only a subset of.*') + warnings.filterwarnings("always", ".*Fiducial point nasion not found.*") + warnings.filterwarnings("always", ".*DigMontage is only a subset of.*") warnings.filterwarnings( # xhemi morph (should probably update sample) - 'always', '.*does not exist, creating it and saving it.*') + "always", ".*does not exist, creating it and saving it.*" + ) # internal warnings - warnings.filterwarnings('default', module='sphinx') + warnings.filterwarnings("default", module="sphinx") # allow these warnings, but don't show them for key in ( - 'The module matplotlib.tight_layout is deprecated', # nilearn - 'invalid version and will not be supported', # pyxdf - 'distutils Version classes are deprecated', # 
seaborn and neo - '`np.object` is a deprecated alias for the builtin `object`', # pyxdf + "The module matplotlib.tight_layout is deprecated", # nilearn + "invalid version and will not be supported", # pyxdf + "distutils Version classes are deprecated", # seaborn and neo + "`np.object` is a deprecated alias for the builtin `object`", # pyxdf # nilearn, should be fixed in > 0.9.1 - 'In future, it will be an error for \'np.bool_\' scalars to', + "In future, it will be an error for 'np.bool_' scalars to", # sklearn hasn't updated to SciPy's sym_pos dep - 'The \'sym_pos\' keyword is deprecated', + "The 'sym_pos' keyword is deprecated", # numba - '`np.MachAr` is deprecated', + "`np.MachAr` is deprecated", # joblib hasn't updated to avoid distutils - 'distutils package is deprecated', + "distutils package is deprecated", # jupyter - 'Jupyter is migrating its paths to use standard', - r'Widget\..* is deprecated\.', + "Jupyter is migrating its paths to use standard", + r"Widget\..* is deprecated\.", # PyQt6 - 'Enum value .* is marked as deprecated', + "Enum value .* is marked as deprecated", # matplotlib PDF output - 'The py23 module has been deprecated', + "The py23 module has been deprecated", # pkg_resources - 'Implementing implicit namespace packages', - 'Deprecated call to `pkg_resources', + "Implementing implicit namespace packages", + "Deprecated call to `pkg_resources", # nilearn - 'pkg_resources is deprecated as an API', - r'The .* was deprecated in Matplotlib 3\.7', + "pkg_resources is deprecated as an API", + r"The .* was deprecated in Matplotlib 3\.7", ): warnings.filterwarnings( # deal with other modules having bad imports - 'ignore', message=".*%s.*" % key, category=DeprecationWarning) - warnings.filterwarnings( - 'ignore', message=( - 'Matplotlib is currently using agg, which is a non-GUI backend.*' + "ignore", message=".*%s.*" % key, category=DeprecationWarning ) + warnings.filterwarnings( + "ignore", + message=("Matplotlib is currently using agg, which is a non-GUI backend.*"), ) # matplotlib 3.6 in nilearn and pyvista - warnings.filterwarnings( - 'ignore', message='.*cmap function will be deprecated.*') + warnings.filterwarnings("ignore", message=".*cmap function will be deprecated.*") # xarray/netcdf4 warnings.filterwarnings( - 'ignore', message=r'numpy\.ndarray size changed, may indicate.*', - category=RuntimeWarning) + "ignore", + message=r"numpy\.ndarray size changed, may indicate.*", + category=RuntimeWarning, + ) # qdarkstyle warnings.filterwarnings( - 'ignore', message=r'.*Setting theme=.*6 in qdarkstyle.*', - category=RuntimeWarning) + "ignore", + message=r".*Setting theme=.*6 in qdarkstyle.*", + category=RuntimeWarning, + ) # pandas, via seaborn (examples/time_frequency/time_frequency_erds.py) warnings.filterwarnings( - 'ignore', message=r'iteritems is deprecated.*Use \.items instead\.', - category=FutureWarning) + "ignore", + message=r"iteritems is deprecated.*Use \.items instead\.", + category=FutureWarning, + ) # pandas in 50_epochs_to_data_frame.py warnings.filterwarnings( - 'ignore', message=r'invalid value encountered in cast', - category=RuntimeWarning) + "ignore", message=r"invalid value encountered in cast", category=RuntimeWarning + ) # xarray _SixMetaPathImporter (?) 
     warnings.filterwarnings(
-        'ignore', message=r'falling back to find_module',
-        category=ImportWarning)
+        "ignore", message=r"falling back to find_module", category=ImportWarning
+    )

     # In case we use np.set_printoptions in any tutorials, we only
     # want it to affect those:
@@ -1051,49 +1305,70 @@ def reset_warnings(gallery_conf, fname):

 # -- Fontawesome support -----------------------------------------------------

-brand_icons = ('apple', 'linux', 'windows', 'discourse', 'python')
+brand_icons = ("apple", "linux", "windows", "discourse", "python")
 fixed_width_icons = (
     # homepage:
-    'book', 'code-branch', 'newspaper', 'circle-question', 'quote-left',
+    "book",
+    "code-branch",
+    "newspaper",
+    "circle-question",
+    "quote-left",
     # contrib guide:
-    'bug-slash', 'comment', 'computer-mouse', 'hand-sparkles', 'pencil',
-    'text-slash', 'universal-access', 'wand-magic-sparkles',
-    'discourse', 'python',
+    "bug-slash",
+    "comment",
+    "computer-mouse",
+    "hand-sparkles",
+    "pencil",
+    "text-slash",
+    "universal-access",
+    "wand-magic-sparkles",
+    "discourse",
+    "python",
 )
 other_icons = (
-    'hand-paper', 'question', 'rocket', 'server', 'code', 'desktop',
-    'terminal', 'cloud-arrow-down', 'wrench', 'hourglass-half'
+    "hand-paper",
+    "question",
+    "rocket",
+    "server",
+    "code",
+    "desktop",
+    "terminal",
+    "cloud-arrow-down",
+    "wrench",
+    "hourglass-half",
 )
 icon_class = dict()
 for icon in brand_icons + fixed_width_icons + other_icons:
-    icon_class[icon] = ('fa-brands',) if icon in brand_icons else ('fa-solid',)
-    icon_class[icon] += ('fa-fw',) if icon in fixed_width_icons else ()
+    icon_class[icon] = ("fa-brands",) if icon in brand_icons else ("fa-solid",)
+    icon_class[icon] += ("fa-fw",) if icon in fixed_width_icons else ()

-rst_prolog = ''
+rst_prolog = ""
 for icon, classes in icon_class.items():
-    rst_prolog += f'''
+    rst_prolog += f"""
 .. |{icon}| raw:: html

     <i class="{' '.join(classes)} fa-{icon}"></i>
-'''
+"""

-rst_prolog += '''
+rst_prolog += """
 .. |ensp| unicode:: U+2002 .. EN SPACE
-'''
+"""


 # -- Dependency info ----------------------------------------------------------

 try:
     from importlib.metadata import metadata  # new in Python 3.8
-    min_py = metadata('mne')['Requires-Python']
+
+    min_py = metadata("mne")["Requires-Python"]
 except ModuleNotFoundError:
     from pkg_resources import get_distribution
-    info = get_distribution('mne').get_metadata_lines('PKG-INFO')
+
+    info = get_distribution("mne").get_metadata_lines("PKG-INFO")
     for line in info:
-        if line.strip().startswith('Requires-Python'):
-            min_py = line.split(':')[1]
-min_py = min_py.lstrip(' =<>')
-rst_prolog += f'\n.. |min_python_version| replace:: {min_py}\n'
+        if line.strip().startswith("Requires-Python"):
+            min_py = line.split(":")[1]
+min_py = min_py.lstrip(" =<>")
+rst_prolog += f"\n.. |min_python_version| replace:: {min_py}\n"


 # -- website redirects --------------------------------------------------------

@@ -1101,141 +1376,214 @@ def reset_warnings(gallery_conf, fname):
 # since we don't need to add redirects for examples added after this date.
needed_plot_redirects = { # tutorials - '10_epochs_overview.py', '10_evoked_overview.py', '10_overview.py', - '10_preprocessing_overview.py', '10_raw_overview.py', - '10_reading_meg_data.py', '15_handling_bad_channels.py', - '20_event_arrays.py', '20_events_from_raw.py', '20_reading_eeg_data.py', - '20_rejecting_bad_data.py', '20_visualize_epochs.py', - '20_visualize_evoked.py', '30_annotate_raw.py', '30_epochs_metadata.py', - '30_filtering_resampling.py', '30_info.py', '30_reading_fnirs_data.py', - '35_artifact_correction_regression.py', '40_artifact_correction_ica.py', - '40_autogenerate_metadata.py', '40_sensor_locations.py', - '40_visualize_raw.py', '45_projectors_background.py', - '50_artifact_correction_ssp.py', '50_configure_mne.py', - '50_epochs_to_data_frame.py', '55_setting_eeg_reference.py', - '59_head_positions.py', '60_make_fixed_length_epochs.py', - '60_maxwell_filtering_sss.py', '70_fnirs_processing.py', + "10_epochs_overview.py", + "10_evoked_overview.py", + "10_overview.py", + "10_preprocessing_overview.py", + "10_raw_overview.py", + "10_reading_meg_data.py", + "15_handling_bad_channels.py", + "20_event_arrays.py", + "20_events_from_raw.py", + "20_reading_eeg_data.py", + "20_rejecting_bad_data.py", + "20_visualize_epochs.py", + "20_visualize_evoked.py", + "30_annotate_raw.py", + "30_epochs_metadata.py", + "30_filtering_resampling.py", + "30_info.py", + "30_reading_fnirs_data.py", + "35_artifact_correction_regression.py", + "40_artifact_correction_ica.py", + "40_autogenerate_metadata.py", + "40_sensor_locations.py", + "40_visualize_raw.py", + "45_projectors_background.py", + "50_artifact_correction_ssp.py", + "50_configure_mne.py", + "50_epochs_to_data_frame.py", + "55_setting_eeg_reference.py", + "59_head_positions.py", + "60_make_fixed_length_epochs.py", + "60_maxwell_filtering_sss.py", + "70_fnirs_processing.py", # examples - '3d_to_2d.py', 'brainstorm_data.py', 'channel_epochs_image.py', - 'cluster_stats_evoked.py', 'compute_csd.py', - 'compute_mne_inverse_epochs_in_label.py', - 'compute_mne_inverse_raw_in_label.py', 'compute_mne_inverse_volume.py', - 'compute_source_psd_epochs.py', 'covariance_whitening_dspm.py', - 'custom_inverse_solver.py', - 'decoding_csp_eeg.py', 'decoding_csp_timefreq.py', - 'decoding_spatio_temporal_source.py', 'decoding_spoc_CMC.py', - 'decoding_time_generalization_conditions.py', - 'decoding_unsupervised_spatial_filter.py', 'decoding_xdawn_eeg.py', - 'define_target_events.py', 'dics_source_power.py', 'eeg_csd.py', - 'eeg_on_scalp.py', 'eeglab_head_sphere.py', 'elekta_epochs.py', - 'ems_filtering.py', 'eog_artifact_histogram.py', 'evoked_arrowmap.py', - 'evoked_ers_source_power.py', 'evoked_topomap.py', 'evoked_whitening.py', - 'fdr_stats_evoked.py', 'find_ref_artifacts.py', - 'fnirs_artifact_removal.py', 'forward_sensitivity_maps.py', - 'gamma_map_inverse.py', 'hf_sef_data.py', 'ica_comparison.py', - 'interpolate_bad_channels.py', 'label_activation_from_stc.py', - 'label_from_stc.py', 'label_source_activations.py', - 'left_cerebellum_volume_source.py', 'limo_data.py', - 'linear_model_patterns.py', 'linear_regression_raw.py', - 'meg_sensors.py', 'mixed_norm_inverse.py', - 'mixed_source_space_inverse.py', - 'mne_cov_power.py', 'mne_helmet.py', 'mne_inverse_coherence_epochs.py', - 'mne_inverse_envelope_correlation.py', - 'mne_inverse_envelope_correlation_volume.py', - 'mne_inverse_psi_visual.py', - 'morph_surface_stc.py', 'morph_volume_stc.py', 'movement_compensation.py', - 'movement_detection.py', 'multidict_reweighted_tfmxne.py', - 
'muscle_detection.py', 'opm_data.py', 'otp.py', 'parcellation.py', - 'psf_ctf_label_leakage.py', 'psf_ctf_vertices.py', - 'psf_ctf_vertices_lcmv.py', 'publication_figure.py', 'rap_music.py', - 'read_inverse.py', 'read_neo_format.py', 'read_noise_covariance_matrix.py', - 'read_stc.py', 'receptive_field_mtrf.py', 'resolution_metrics.py', - 'resolution_metrics_eegmeg.py', 'roi_erpimage_by_rt.py', - 'sensor_noise_level.py', - 'sensor_permutation_test.py', 'sensor_regression.py', - 'shift_evoked.py', 'simulate_evoked_data.py', 'simulate_raw_data.py', - 'simulated_raw_data_using_subject_anatomy.py', 'snr_estimate.py', - 'source_label_time_frequency.py', 'source_power_spectrum.py', - 'source_power_spectrum_opm.py', 'source_simulator.py', - 'source_space_morphing.py', 'source_space_snr.py', - 'source_space_time_frequency.py', 'ssd_spatial_filters.py', - 'ssp_projs_sensitivity_map.py', 'temporal_whitening.py', - 'time_frequency_erds.py', 'time_frequency_global_field_power.py', - 'time_frequency_mixed_norm_inverse.py', 'time_frequency_simulated.py', - 'topo_compare_conditions.py', 'topo_customized.py', - 'vector_mne_solution.py', 'virtual_evoked.py', 'xdawn_denoising.py', - 'xhemi.py', + "3d_to_2d.py", + "brainstorm_data.py", + "channel_epochs_image.py", + "cluster_stats_evoked.py", + "compute_csd.py", + "compute_mne_inverse_epochs_in_label.py", + "compute_mne_inverse_raw_in_label.py", + "compute_mne_inverse_volume.py", + "compute_source_psd_epochs.py", + "covariance_whitening_dspm.py", + "custom_inverse_solver.py", + "decoding_csp_eeg.py", + "decoding_csp_timefreq.py", + "decoding_spatio_temporal_source.py", + "decoding_spoc_CMC.py", + "decoding_time_generalization_conditions.py", + "decoding_unsupervised_spatial_filter.py", + "decoding_xdawn_eeg.py", + "define_target_events.py", + "dics_source_power.py", + "eeg_csd.py", + "eeg_on_scalp.py", + "eeglab_head_sphere.py", + "elekta_epochs.py", + "ems_filtering.py", + "eog_artifact_histogram.py", + "evoked_arrowmap.py", + "evoked_ers_source_power.py", + "evoked_topomap.py", + "evoked_whitening.py", + "fdr_stats_evoked.py", + "find_ref_artifacts.py", + "fnirs_artifact_removal.py", + "forward_sensitivity_maps.py", + "gamma_map_inverse.py", + "hf_sef_data.py", + "ica_comparison.py", + "interpolate_bad_channels.py", + "label_activation_from_stc.py", + "label_from_stc.py", + "label_source_activations.py", + "left_cerebellum_volume_source.py", + "limo_data.py", + "linear_model_patterns.py", + "linear_regression_raw.py", + "meg_sensors.py", + "mixed_norm_inverse.py", + "mixed_source_space_inverse.py", + "mne_cov_power.py", + "mne_helmet.py", + "mne_inverse_coherence_epochs.py", + "mne_inverse_envelope_correlation.py", + "mne_inverse_envelope_correlation_volume.py", + "mne_inverse_psi_visual.py", + "morph_surface_stc.py", + "morph_volume_stc.py", + "movement_compensation.py", + "movement_detection.py", + "multidict_reweighted_tfmxne.py", + "muscle_detection.py", + "opm_data.py", + "otp.py", + "parcellation.py", + "psf_ctf_label_leakage.py", + "psf_ctf_vertices.py", + "psf_ctf_vertices_lcmv.py", + "publication_figure.py", + "rap_music.py", + "read_inverse.py", + "read_neo_format.py", + "read_noise_covariance_matrix.py", + "read_stc.py", + "receptive_field_mtrf.py", + "resolution_metrics.py", + "resolution_metrics_eegmeg.py", + "roi_erpimage_by_rt.py", + "sensor_noise_level.py", + "sensor_permutation_test.py", + "sensor_regression.py", + "shift_evoked.py", + "simulate_evoked_data.py", + "simulate_raw_data.py", + "simulated_raw_data_using_subject_anatomy.py", + 
"snr_estimate.py", + "source_label_time_frequency.py", + "source_power_spectrum.py", + "source_power_spectrum_opm.py", + "source_simulator.py", + "source_space_morphing.py", + "source_space_snr.py", + "source_space_time_frequency.py", + "ssd_spatial_filters.py", + "ssp_projs_sensitivity_map.py", + "temporal_whitening.py", + "time_frequency_erds.py", + "time_frequency_global_field_power.py", + "time_frequency_mixed_norm_inverse.py", + "time_frequency_simulated.py", + "topo_compare_conditions.py", + "topo_customized.py", + "vector_mne_solution.py", + "virtual_evoked.py", + "xdawn_denoising.py", + "xhemi.py", } -ex = 'auto_examples' -co = 'connectivity' -mne_conn = 'https://mne.tools/mne-connectivity/stable' -tu = 'auto_tutorials' -di = 'discussions' -sm = 'source-modeling' -fw = 'forward' -nv = 'inverse' -sn = 'stats-sensor-space' -sr = 'stats-source-space' -sd = 'sample-datasets' -ml = 'machine-learning' -tf = 'time-freq' -si = 'simulation' +ex = "auto_examples" +co = "connectivity" +mne_conn = "https://mne.tools/mne-connectivity/stable" +tu = "auto_tutorials" +di = "discussions" +sm = "source-modeling" +fw = "forward" +nv = "inverse" +sn = "stats-sensor-space" +sr = "stats-source-space" +sd = "sample-datasets" +ml = "machine-learning" +tf = "time-freq" +si = "simulation" custom_redirects = { # Custom redirects (one HTML path to another, relative to outdir) # can be added here as fr->to key->value mappings - f'{tu}/evoked/plot_eeg_erp.html': f'{tu}/evoked/30_eeg_erp.html', - f'{tu}/evoked/plot_whitened.html': f'{tu}/evoked/40_whitened.html', - f'{tu}/misc/plot_modifying_data_inplace.html': f'{tu}/intro/15_inplace.html', # noqa E501 - f'{tu}/misc/plot_report.html': f'{tu}/intro/70_report.html', - f'{tu}/misc/plot_seeg.html': f'{tu}/clinical/20_seeg.html', - f'{tu}/misc/plot_ecog.html': f'{tu}/clinical/30_ecog.html', - f'{tu}/{ml}/plot_receptive_field.html': f'{tu}/{ml}/30_strf.html', - f'{tu}/{ml}/plot_sensors_decoding.html': f'{tu}/{ml}/50_decoding.html', - f'{tu}/{sm}/plot_background_freesurfer.html': f'{tu}/{fw}/10_background_freesurfer.html', # noqa E501 - f'{tu}/{sm}/plot_source_alignment.html': f'{tu}/{fw}/20_source_alignment.html', # noqa E501 - f'{tu}/{sm}/plot_forward.html': f'{tu}/{fw}/30_forward.html', - f'{tu}/{sm}/plot_eeg_no_mri.html': f'{tu}/{fw}/35_eeg_no_mri.html', - f'{tu}/{sm}/plot_background_freesurfer_mne.html': f'{tu}/{fw}/50_background_freesurfer_mne.html', # noqa E501 - f'{tu}/{sm}/plot_fix_bem_in_blender.html': f'{tu}/{fw}/80_fix_bem_in_blender.html', # noqa E501 - f'{tu}/{sm}/plot_compute_covariance.html': f'{tu}/{fw}/90_compute_covariance.html', # noqa E501 - f'{tu}/{sm}/plot_object_source_estimate.html': f'{tu}/{nv}/10_stc_class.html', # noqa E501 - f'{tu}/{sm}/plot_dipole_fit.html': f'{tu}/{nv}/20_dipole_fit.html', - f'{tu}/{sm}/plot_mne_dspm_source_localization.html': f'{tu}/{nv}/30_mne_dspm_loreta.html', # noqa E501 - f'{tu}/{sm}/plot_dipole_orientations.html': f'{tu}/{nv}/35_dipole_orientations.html', # noqa E501 - f'{tu}/{sm}/plot_mne_solutions.html': f'{tu}/{nv}/40_mne_fixed_free.html', - f'{tu}/{sm}/plot_beamformer_lcmv.html': f'{tu}/{nv}/50_beamformer_lcmv.html', # noqa E501 - f'{tu}/{sm}/plot_visualize_stc.html': f'{tu}/{nv}/60_visualize_stc.html', - f'{tu}/{sm}/plot_eeg_mri_coords.html': f'{tu}/{nv}/70_eeg_mri_coords.html', - f'{tu}/{sd}/plot_brainstorm_phantom_elekta.html': f'{tu}/{nv}/80_brainstorm_phantom_elekta.html', # noqa E501 - f'{tu}/{sd}/plot_brainstorm_phantom_ctf.html': f'{tu}/{nv}/85_brainstorm_phantom_ctf.html', # noqa 
E501 - f'{tu}/{sd}/plot_phantom_4DBTi.html': f'{tu}/{nv}/90_phantom_4DBTi.html', - f'{tu}/{sd}/plot_brainstorm_auditory.html': f'{tu}/io/60_ctf_bst_auditory.html', # noqa E501 - f'{tu}/{sd}/plot_sleep.html': f'{tu}/clinical/60_sleep.html', - f'{tu}/{di}/plot_background_filtering.html': f'{tu}/preprocessing/25_background_filtering.html', # noqa E501 - f'{tu}/{di}/plot_background_statistics.html': f'{tu}/{sn}/10_background_stats.html', # noqa E501 - f'{tu}/{sn}/plot_stats_cluster_erp.html': f'{tu}/{sn}/20_erp_stats.html', - f'{tu}/{sn}/plot_stats_cluster_1samp_test_time_frequency.html': f'{tu}/{sn}/40_cluster_1samp_time_freq.html', # noqa E501 - f'{tu}/{sn}/plot_stats_cluster_time_frequency.html': f'{tu}/{sn}/50_cluster_between_time_freq.html', # noqa E501 - f'{tu}/{sn}/plot_stats_spatio_temporal_cluster_sensors.html': f'{tu}/{sn}/75_cluster_ftest_spatiotemporal.html', # noqa E501 - f'{tu}/{sr}/plot_stats_cluster_spatio_temporal.html': f'{tu}/{sr}/20_cluster_1samp_spatiotemporal.html', # noqa E501 - f'{tu}/{sr}/plot_stats_cluster_spatio_temporal_2samp.html': f'{tu}/{sr}/30_cluster_ftest_spatiotemporal.html', # noqa E501 - f'{tu}/{sr}/plot_stats_cluster_spatio_temporal_repeated_measures_anova.html': f'{tu}/{sr}/60_cluster_rmANOVA_spatiotemporal.html', # noqa E501 - f'{tu}/{sr}/plot_stats_cluster_time_frequency_repeated_measures_anova.html': f'{tu}/{sn}/70_cluster_rmANOVA_time_freq.html', # noqa E501 - f'{tu}/{tf}/plot_sensors_time_frequency.html': f'{tu}/{tf}/20_sensors_time_frequency.html', # noqa E501 - f'{tu}/{tf}/plot_ssvep.html': f'{tu}/{tf}/50_ssvep.html', - f'{tu}/{si}/plot_creating_data_structures.html': f'{tu}/{si}/10_array_objs.html', # noqa E501 - f'{tu}/{si}/plot_point_spread.html': f'{tu}/{si}/70_point_spread.html', - f'{tu}/{si}/plot_dics.html': f'{tu}/{si}/80_dics.html', - f'{tu}/{tf}/plot_eyetracking.html': f'{tu}/preprocessing/90_eyetracking_data.html', # noqa E501 - f'{ex}/{co}/mne_inverse_label_connectivity.html': f'{mne_conn}/{ex}/mne_inverse_label_connectivity.html', # noqa E501 - f'{ex}/{co}/cwt_sensor_connectivity.html': f'{mne_conn}/{ex}/cwt_sensor_connectivity.html', # noqa E501 - f'{ex}/{co}/mixed_source_space_connectivity.html': f'{mne_conn}/{ex}/mixed_source_space_connectivity.html', # noqa E501 - f'{ex}/{co}/mne_inverse_coherence_epochs.html': f'{mne_conn}/{ex}/mne_inverse_coherence_epochs.html', # noqa E501 - f'{ex}/{co}/mne_inverse_connectivity_spectrum.html': f'{mne_conn}/{ex}/mne_inverse_connectivity_spectrum.html', # noqa E501 - f'{ex}/{co}/mne_inverse_envelope_correlation_volume.html': f'{mne_conn}/{ex}/mne_inverse_envelope_correlation_volume.html', # noqa E501 - f'{ex}/{co}/mne_inverse_envelope_correlation.html': f'{mne_conn}/{ex}/mne_inverse_envelope_correlation.html', # noqa E501 - f'{ex}/{co}/mne_inverse_psi_visual.html': f'{mne_conn}/{ex}/mne_inverse_psi_visual.html', # noqa E501 - f'{ex}/{co}/sensor_connectivity.html': f'{mne_conn}/{ex}/sensor_connectivity.html', # noqa E501 + f"{tu}/evoked/plot_eeg_erp.html": f"{tu}/evoked/30_eeg_erp.html", + f"{tu}/evoked/plot_whitened.html": f"{tu}/evoked/40_whitened.html", + f"{tu}/misc/plot_modifying_data_inplace.html": f"{tu}/intro/15_inplace.html", # noqa E501 + f"{tu}/misc/plot_report.html": f"{tu}/intro/70_report.html", + f"{tu}/misc/plot_seeg.html": f"{tu}/clinical/20_seeg.html", + f"{tu}/misc/plot_ecog.html": f"{tu}/clinical/30_ecog.html", + f"{tu}/{ml}/plot_receptive_field.html": f"{tu}/{ml}/30_strf.html", + f"{tu}/{ml}/plot_sensors_decoding.html": f"{tu}/{ml}/50_decoding.html", + 
f"{tu}/{sm}/plot_background_freesurfer.html": f"{tu}/{fw}/10_background_freesurfer.html", # noqa E501 + f"{tu}/{sm}/plot_source_alignment.html": f"{tu}/{fw}/20_source_alignment.html", # noqa E501 + f"{tu}/{sm}/plot_forward.html": f"{tu}/{fw}/30_forward.html", + f"{tu}/{sm}/plot_eeg_no_mri.html": f"{tu}/{fw}/35_eeg_no_mri.html", + f"{tu}/{sm}/plot_background_freesurfer_mne.html": f"{tu}/{fw}/50_background_freesurfer_mne.html", # noqa E501 + f"{tu}/{sm}/plot_fix_bem_in_blender.html": f"{tu}/{fw}/80_fix_bem_in_blender.html", # noqa E501 + f"{tu}/{sm}/plot_compute_covariance.html": f"{tu}/{fw}/90_compute_covariance.html", # noqa E501 + f"{tu}/{sm}/plot_object_source_estimate.html": f"{tu}/{nv}/10_stc_class.html", # noqa E501 + f"{tu}/{sm}/plot_dipole_fit.html": f"{tu}/{nv}/20_dipole_fit.html", + f"{tu}/{sm}/plot_mne_dspm_source_localization.html": f"{tu}/{nv}/30_mne_dspm_loreta.html", # noqa E501 + f"{tu}/{sm}/plot_dipole_orientations.html": f"{tu}/{nv}/35_dipole_orientations.html", # noqa E501 + f"{tu}/{sm}/plot_mne_solutions.html": f"{tu}/{nv}/40_mne_fixed_free.html", + f"{tu}/{sm}/plot_beamformer_lcmv.html": f"{tu}/{nv}/50_beamformer_lcmv.html", # noqa E501 + f"{tu}/{sm}/plot_visualize_stc.html": f"{tu}/{nv}/60_visualize_stc.html", + f"{tu}/{sm}/plot_eeg_mri_coords.html": f"{tu}/{nv}/70_eeg_mri_coords.html", + f"{tu}/{sd}/plot_brainstorm_phantom_elekta.html": f"{tu}/{nv}/80_brainstorm_phantom_elekta.html", # noqa E501 + f"{tu}/{sd}/plot_brainstorm_phantom_ctf.html": f"{tu}/{nv}/85_brainstorm_phantom_ctf.html", # noqa E501 + f"{tu}/{sd}/plot_phantom_4DBTi.html": f"{tu}/{nv}/90_phantom_4DBTi.html", + f"{tu}/{sd}/plot_brainstorm_auditory.html": f"{tu}/io/60_ctf_bst_auditory.html", # noqa E501 + f"{tu}/{sd}/plot_sleep.html": f"{tu}/clinical/60_sleep.html", + f"{tu}/{di}/plot_background_filtering.html": f"{tu}/preprocessing/25_background_filtering.html", # noqa E501 + f"{tu}/{di}/plot_background_statistics.html": f"{tu}/{sn}/10_background_stats.html", # noqa E501 + f"{tu}/{sn}/plot_stats_cluster_erp.html": f"{tu}/{sn}/20_erp_stats.html", + f"{tu}/{sn}/plot_stats_cluster_1samp_test_time_frequency.html": f"{tu}/{sn}/40_cluster_1samp_time_freq.html", # noqa E501 + f"{tu}/{sn}/plot_stats_cluster_time_frequency.html": f"{tu}/{sn}/50_cluster_between_time_freq.html", # noqa E501 + f"{tu}/{sn}/plot_stats_spatio_temporal_cluster_sensors.html": f"{tu}/{sn}/75_cluster_ftest_spatiotemporal.html", # noqa E501 + f"{tu}/{sr}/plot_stats_cluster_spatio_temporal.html": f"{tu}/{sr}/20_cluster_1samp_spatiotemporal.html", # noqa E501 + f"{tu}/{sr}/plot_stats_cluster_spatio_temporal_2samp.html": f"{tu}/{sr}/30_cluster_ftest_spatiotemporal.html", # noqa E501 + f"{tu}/{sr}/plot_stats_cluster_spatio_temporal_repeated_measures_anova.html": f"{tu}/{sr}/60_cluster_rmANOVA_spatiotemporal.html", # noqa E501 + f"{tu}/{sr}/plot_stats_cluster_time_frequency_repeated_measures_anova.html": f"{tu}/{sn}/70_cluster_rmANOVA_time_freq.html", # noqa E501 + f"{tu}/{tf}/plot_sensors_time_frequency.html": f"{tu}/{tf}/20_sensors_time_frequency.html", # noqa E501 + f"{tu}/{tf}/plot_ssvep.html": f"{tu}/{tf}/50_ssvep.html", + f"{tu}/{si}/plot_creating_data_structures.html": f"{tu}/{si}/10_array_objs.html", # noqa E501 + f"{tu}/{si}/plot_point_spread.html": f"{tu}/{si}/70_point_spread.html", + f"{tu}/{si}/plot_dics.html": f"{tu}/{si}/80_dics.html", + f"{tu}/{tf}/plot_eyetracking.html": f"{tu}/preprocessing/90_eyetracking_data.html", # noqa E501 + f"{ex}/{co}/mne_inverse_label_connectivity.html": 
f"{mne_conn}/{ex}/mne_inverse_label_connectivity.html", # noqa E501 + f"{ex}/{co}/cwt_sensor_connectivity.html": f"{mne_conn}/{ex}/cwt_sensor_connectivity.html", # noqa E501 + f"{ex}/{co}/mixed_source_space_connectivity.html": f"{mne_conn}/{ex}/mixed_source_space_connectivity.html", # noqa E501 + f"{ex}/{co}/mne_inverse_coherence_epochs.html": f"{mne_conn}/{ex}/mne_inverse_coherence_epochs.html", # noqa E501 + f"{ex}/{co}/mne_inverse_connectivity_spectrum.html": f"{mne_conn}/{ex}/mne_inverse_connectivity_spectrum.html", # noqa E501 + f"{ex}/{co}/mne_inverse_envelope_correlation_volume.html": f"{mne_conn}/{ex}/mne_inverse_envelope_correlation_volume.html", # noqa E501 + f"{ex}/{co}/mne_inverse_envelope_correlation.html": f"{mne_conn}/{ex}/mne_inverse_envelope_correlation.html", # noqa E501 + f"{ex}/{co}/mne_inverse_psi_visual.html": f"{mne_conn}/{ex}/mne_inverse_psi_visual.html", # noqa E501 + f"{ex}/{co}/sensor_connectivity.html": f"{mne_conn}/{ex}/sensor_connectivity.html", # noqa E501 } @@ -1243,11 +1591,12 @@ def make_redirects(app, exception): """Make HTML redirects.""" # https://www.sphinx-doc.org/en/master/extdev/appapi.html # Adapted from sphinxcontrib/redirects (BSD-2-Clause) - if not (isinstance(app.builder, - sphinx.builders.html.StandaloneHTMLBuilder) and - exception is None): + if not ( + isinstance(app.builder, sphinx.builders.html.StandaloneHTMLBuilder) + and exception is None + ): return - logger = sphinx.util.logging.getLogger('mne') + logger = sphinx.util.logging.getLogger("mne") TEMPLATE = """\ @@ -1263,79 +1612,88 @@ def make_redirects(app, exception): If you are not redirected automatically, follow this link. """ # noqa: E501 - sphinx_gallery_conf = app.config['sphinx_gallery_conf'] - for src_dir, out_dir in zip(sphinx_gallery_conf['examples_dirs'], - sphinx_gallery_conf['gallery_dirs']): + sphinx_gallery_conf = app.config["sphinx_gallery_conf"] + for src_dir, out_dir in zip( + sphinx_gallery_conf["examples_dirs"], sphinx_gallery_conf["gallery_dirs"] + ): root = os.path.abspath(os.path.join(app.srcdir, src_dir)) - fnames = [os.path.join(os.path.relpath(dirpath, root), fname) - for dirpath, _, fnames in os.walk(root) - for fname in fnames - if fname in needed_plot_redirects] + fnames = [ + os.path.join(os.path.relpath(dirpath, root), fname) + for dirpath, _, fnames in os.walk(root) + for fname in fnames + if fname in needed_plot_redirects + ] # plot_ redirects for fname in fnames: dirname = os.path.join(app.outdir, out_dir, os.path.dirname(fname)) - to_fname = os.path.splitext(os.path.basename(fname))[0] + '.html' - fr_fname = f'plot_{to_fname}' + to_fname = os.path.splitext(os.path.basename(fname))[0] + ".html" + fr_fname = f"plot_{to_fname}" to_path = os.path.join(dirname, to_fname) fr_path = os.path.join(dirname, fr_fname) assert os.path.isfile(to_path), (fname, to_path) - with open(fr_path, 'w') as fid: + with open(fr_path, "w") as fid: fid.write(TEMPLATE.format(to=to_fname)) sphinx_logger.info( - f'Added {len(fnames):3d} HTML plot_* redirects for {out_dir}') + f"Added {len(fnames):3d} HTML plot_* redirects for {out_dir}" + ) # custom redirects for fr, to in custom_redirects.items(): - if not to.startswith('http'): + if not to.startswith("http"): assert os.path.isfile(os.path.join(app.outdir, to)), to # handle links to sibling folders - path_parts = to.split('/') + path_parts = to.split("/") assert tu in path_parts, path_parts # need to refactor otherwise - path_parts = ['..'] + path_parts[(path_parts.index(tu) + 1):] + path_parts = [".."] + 
             to = os.path.join(*path_parts)
-        assert to.endswith('html'), to
+        assert to.endswith("html"), to
         fr_path = os.path.join(app.outdir, fr)
-        assert fr_path.endswith('html'), fr_path
+        assert fr_path.endswith("html"), fr_path
         # allow overwrite if existing file is just a redirect
         if os.path.isfile(fr_path):
-            with open(fr_path, 'r') as fid:
+            with open(fr_path, "r") as fid:
                 for _ in range(8):
                     next(fid)
                 line = fid.readline()
-                assert 'Page Redirection' in line, line
+                assert "Page Redirection" in line, line
         # handle folders that no longer exist
-        if fr_path.split('/')[-2] in (
-                'misc', 'discussions', 'source-modeling', 'sample-datasets',
-                'connectivity'):
+        if fr_path.split("/")[-2] in (
+            "misc",
+            "discussions",
+            "source-modeling",
+            "sample-datasets",
+            "connectivity",
+        ):
             os.makedirs(os.path.dirname(fr_path), exist_ok=True)
-        with open(fr_path, 'w') as fid:
+        with open(fr_path, "w") as fid:
             fid.write(TEMPLATE.format(to=to))
-    sphinx_logger.info(
-        f'Added {len(custom_redirects):3d} HTML custom redirects')
+    sphinx_logger.info(f"Added {len(custom_redirects):3d} HTML custom redirects")


 def make_version(app, exception):
     """Make a text file with the git version."""
-    if not (isinstance(app.builder,
-                       sphinx.builders.html.StandaloneHTMLBuilder) and
-            exception is None):
+    if not (
+        isinstance(app.builder, sphinx.builders.html.StandaloneHTMLBuilder)
+        and exception is None
+    ):
         return
-    logger = sphinx.util.logging.getLogger('mne')
+    logger = sphinx.util.logging.getLogger("mne")
     try:
-        stdout, _ = run_subprocess(['git', 'rev-parse', 'HEAD'], verbose=False)
+        stdout, _ = run_subprocess(["git", "rev-parse", "HEAD"], verbose=False)
     except Exception as exc:
-        sphinx_logger.warning(f'Failed to write _version.txt: {exc}')
+        sphinx_logger.warning(f"Failed to write _version.txt: {exc}")
         return
-    with open(os.path.join(app.outdir, '_version.txt'), 'w') as fid:
+    with open(os.path.join(app.outdir, "_version.txt"), "w") as fid:
         fid.write(stdout)
     sphinx_logger.info(f'Added "{stdout.rstrip()}" > _version.txt')


 # -- Connect our handlers to the main Sphinx app ---------------------------

+
 def setup(app):
     """Set up the Sphinx app."""
-    app.connect('autodoc-process-docstring', append_attr_meth_examples)
+    app.connect("autodoc-process-docstring", append_attr_meth_examples)
     report_scraper.app = app
-    app.connect('builder-inited', report_scraper.copyfiles)
-    app.connect('build-finished', make_redirects)
-    app.connect('build-finished', make_version)
+    app.connect("builder-inited", report_scraper.copyfiles)
+    app.connect("build-finished", make_redirects)
+    app.connect("build-finished", make_version)
diff --git a/doc/sphinxext/flow_diagram.py b/doc/sphinxext/flow_diagram.py
index 9adb8636e2f..d6a941d7869 100644
--- a/doc/sphinxext/flow_diagram.py
+++ b/doc/sphinxext/flow_diagram.py
@@ -1,14 +1,14 @@
 import os
 from os import path as op

-title = 'mne-python flow diagram'
+title = "mne-python flow diagram"

-font_face = 'Arial'
+font_face = "Arial"
 node_size = 12
 node_small_size = 9
 edge_size = 9
-sensor_color = '#7bbeca'
-source_color = '#ff6347'
+sensor_color = "#7bbeca"
+source_color = "#ff6347"

 legend = """
 <<FONT POINT-SIZE="%s">
@@ -17,62 +17,74 @@
 Sensor (M/EEG) space</TD></TR>
 <TR><TD BGCOLOR="%s">    </TD><TD ALIGN="left">
 Source (brain) space</TD></TR>
-</TABLE></FONT>>""" % (edge_size, sensor_color, source_color)
-legend = ''.join(legend.split('\n'))
+</TABLE></FONT>>""" % (
+    edge_size,
+    sensor_color,
+    source_color,
+)
+legend = "".join(legend.split("\n"))

 nodes = dict(
-    T1='T1',
-    flashes='Flash5/30',
-    trans='Head-MRI trans',
-    recon='Freesurfer surfaces',
-    bem='BEM',
-    src='Source space\nmne.SourceSpaces',
-    cov='Noise covariance\nmne.Covariance',
-    fwd='Forward solution\nmne.forward.Forward',
-    inv='Inverse operator\nmne.minimum_norm.InverseOperator',
-    stc='Source estimate\nmne.SourceEstimate',
-    raw='Raw data\nmne.io.Raw',
-    epo='Epoched data\nmne.Epochs',
-    evo='Averaged data\nmne.Evoked',
-    pre='Preprocessed data\nmne.io.Raw',
+    T1="T1",
+    flashes="Flash5/30",
+    trans="Head-MRI trans",
+    recon="Freesurfer surfaces",
+    bem="BEM",
+    src="Source space\nmne.SourceSpaces",
+    cov="Noise covariance\nmne.Covariance",
+    fwd="Forward solution\nmne.forward.Forward",
+    inv="Inverse operator\nmne.minimum_norm.InverseOperator",
+    stc="Source estimate\nmne.SourceEstimate",
+    raw="Raw data\nmne.io.Raw",
+    epo="Epoched data\nmne.Epochs",
+    evo="Averaged data\nmne.Evoked",
+    pre="Preprocessed data\nmne.io.Raw",
     legend=legend,
 )

-sensor_space = ('raw', 'pre', 'epo', 'evo', 'cov')
-source_space = ('src', 'stc', 'bem', 'flashes', 'recon', 'T1')
+sensor_space = ("raw", "pre", "epo", "evo", "cov")
+source_space = ("src", "stc", "bem", "flashes", "recon", "T1")
 edges = (
-    ('T1', 'recon'),
-    ('flashes', 'bem'),
-    ('recon', 'bem'),
-    ('recon', 'src', 'mne.setup_source_space'),
-    ('src', 'fwd'),
-    ('bem', 'fwd'),
-    ('trans', 'fwd', 'mne.make_forward_solution'),
-    ('fwd', 'inv'),
-    ('cov', 'inv', 'mne.make_inverse_operator'),
-    ('inv', 'stc'),
-    ('evo', 'stc', 'mne.minimum_norm.apply_inverse'),
-    ('raw', 'pre', 'raw.filter\n'
-                   'mne.preprocessing.ICA\n'
-                   'mne.preprocessing.compute_proj_eog\n'
-                   'mne.preprocessing.compute_proj_ecg\n'
-                   '...'),
-    ('pre', 'epo', 'mne.Epochs'),
-    ('epo', 'evo', 'epochs.average'),
-    ('epo', 'cov', 'mne.compute_covariance'),
+    ("T1", "recon"),
+    ("flashes", "bem"),
+    ("recon", "bem"),
+    ("recon", "src", "mne.setup_source_space"),
+    ("src", "fwd"),
+    ("bem", "fwd"),
+    ("trans", "fwd", "mne.make_forward_solution"),
+    ("fwd", "inv"),
+    ("cov", "inv", "mne.make_inverse_operator"),
+    ("inv", "stc"),
+    ("evo", "stc", "mne.minimum_norm.apply_inverse"),
+    (
+        "raw",
+        "pre",
+        "raw.filter\n"
+        "mne.preprocessing.ICA\n"
+        "mne.preprocessing.compute_proj_eog\n"
+        "mne.preprocessing.compute_proj_ecg\n"
+        "...",
+    ),
+    ("pre", "epo", "mne.Epochs"),
+    ("epo", "evo", "epochs.average"),
+    ("epo", "cov", "mne.compute_covariance"),
 )

 subgraphs = (
-    [('T1', 'flashes', 'recon', 'bem', 'src'),
-     ('<<FONT POINT-SIZE="%s">'
-      'Freesurfer / MNE-C</FONT>>' % node_small_size)],
+    [
+        ("T1", "flashes", "recon", "bem", "src"),
+        (
+            '<<FONT POINT-SIZE="%s">'
+            "Freesurfer / MNE-C</FONT>>" % node_small_size
+        ),
+    ],
 )


 def setup(app):
-    app.connect('builder-inited', generate_flow_diagram)
-    app.add_config_value('make_flow_diagram', True, 'html')
+    app.connect("builder-inited", generate_flow_diagram)
+    app.add_config_value("make_flow_diagram", True, "html")


 def setup_module():
@@ -81,84 +93,88 @@ def setup_module():


 def generate_flow_diagram(app):
-    out_dir = op.join(app.builder.outdir, '_static')
+    out_dir = op.join(app.builder.outdir, "_static")
     if not op.isdir(out_dir):
         os.makedirs(out_dir)
-    out_fname = op.join(out_dir, 'mne-python_flow.svg')
-    make_flow_diagram = app is None or \
-        bool(app.builder.config.make_flow_diagram)
+    out_fname = op.join(out_dir, "mne-python_flow.svg")
+    make_flow_diagram = app is None or bool(app.builder.config.make_flow_diagram)
     if not make_flow_diagram:
-        print('Skipping flow diagram, webpage will have a missing image')
+        print("Skipping flow diagram, webpage will have a missing image")
         return

     import pygraphviz as pgv
+
     g = pgv.AGraph(name=title, directed=True)

     for key, label in nodes.items():
-        label = label.split('\n')
+        label = label.split("\n")
         if len(label) > 1:
-            label[0] = ('<<FONT POINT-SIZE="%s">' % node_size
-                        + label[0] + '</FONT>')
+            label[0] = '<<FONT POINT-SIZE="%s">' % node_size + label[0] + "</FONT>"
             for li in range(1, len(label)):
-                label[li] = ('<FONT POINT-SIZE="%s"><i>' % node_small_size
-                             + label[li] + '</i></FONT>')
-            label[-1] = label[-1] + '>'
-            label = '<BR/>'.join(label)
+                label[li] = (
+                    '<FONT POINT-SIZE="%s"><i>' % node_small_size
+                    + label[li]
+                    + "</i></FONT>"
+                )
+            label[-1] = label[-1] + ">"
+            label = "<BR/>".join(label)
         else:
             label = label[0]
-        g.add_node(key, shape='plaintext', label=label)
+        g.add_node(key, shape="plaintext", label=label)

     # Create and customize nodes and edges
     for edge in edges:
         g.add_edge(*edge[:2])
         e = g.get_edge(*edge[:2])
         if len(edge) > 2:
-            e.attr['label'] = ('<' +
-                               '<BR ALIGN="LEFT"/>'.join(edge[2].split('\n')) +
-                               '<BR ALIGN="LEFT"/>>')
-            e.attr['fontsize'] = edge_size
+            e.attr["label"] = (
+                "<"
+                + '<BR ALIGN="LEFT"/>'.join(edge[2].split("\n"))
+                + '<BR ALIGN="LEFT"/>>'
+            )
+            e.attr["fontsize"] = edge_size

     # Change colors
-    for these_nodes, color in zip((sensor_space, source_space),
-                                  (sensor_color, source_color)):
+    for these_nodes, color in zip(
+        (sensor_space, source_space), (sensor_color, source_color)
+    ):
         for node in these_nodes:
-            g.get_node(node).attr['fillcolor'] = color
-            g.get_node(node).attr['style'] = 'filled'
+            g.get_node(node).attr["fillcolor"] = color
+            g.get_node(node).attr["style"] = "filled"

     # Create subgraphs
     for si, subgraph in enumerate(subgraphs):
-        g.add_subgraph(subgraph[0], 'cluster%s' % si,
-                       label=subgraph[1], color='black')
+        g.add_subgraph(subgraph[0], "cluster%s" % si, label=subgraph[1], color="black")

     # Format (sub)graphs
     for gr in g.subgraphs() + [g]:
         for x in [gr.node_attr, gr.edge_attr]:
-            x['fontname'] = font_face
-    g.node_attr['shape'] = 'box'
+            x["fontname"] = font_face
+    g.node_attr["shape"] = "box"

     # A couple of special ones
-    for ni, node in enumerate(('fwd', 'inv', 'trans')):
+    for ni, node in enumerate(("fwd", "inv", "trans")):
         node = g.get_node(node)
-        node.attr['gradientangle'] = 270
+        node.attr["gradientangle"] = 270
         colors = (source_color, sensor_color)
         colors = colors if ni == 0 else colors[::-1]
-        node.attr['fillcolor'] = ':'.join(colors)
-        node.attr['style'] = 'filled'
+        node.attr["fillcolor"] = ":".join(colors)
+        node.attr["style"] = "filled"
     del node

-    g.get_node('legend').attr.update(shape='plaintext', margin=0, rank='sink')
+    g.get_node("legend").attr.update(shape="plaintext", margin=0, rank="sink")
     # put legend in same rank/level as inverse
-    leg = g.add_subgraph(['legend', 'inv'], name='legendy')
-    leg.graph_attr['rank'] = 'same'
+    leg = g.add_subgraph(["legend", "inv"], name="legendy")
+    leg.graph_attr["rank"] = "same"

-    g.layout('dot')
-    g.draw(out_fname, format='svg')
+    g.layout("dot")
+    g.draw(out_fname, format="svg")
     return g


 # This is useful for testing/iterating to see what the result looks like
-if __name__ == '__main__':
+if __name__ == "__main__":
     from mne.io.constants import Bunch
-    out_dir = op.abspath(op.join(op.dirname(__file__), '..', '_build', 'html'))
-    app = Bunch(builder=Bunch(outdir=out_dir,
-                              config=Bunch(make_flow_diagram=True)))
+
+    out_dir = op.abspath(op.join(op.dirname(__file__), "..", "_build", "html"))
+    app = Bunch(builder=Bunch(outdir=out_dir, config=Bunch(make_flow_diagram=True)))
     g = generate_flow_diagram(app)
diff --git a/doc/sphinxext/gen_commands.py b/doc/sphinxext/gen_commands.py
index 0339160b2bb..0ca15319d36 100644
--- a/doc/sphinxext/gen_commands.py
+++ b/doc/sphinxext/gen_commands.py
@@ -7,7 +7,7 @@


 def setup(app):
-    app.connect('builder-inited', generate_commands_rst)
+    app.connect("builder-inited", generate_commands_rst)


 def setup_module():
@@ -52,23 +52,24 @@ def generate_commands_rst(app=None):
     except Exception:
         from sphinx.util import status_iterator

     root = Path(__file__).parent.parent.parent.absolute()
-    out_dir = (root / 'doc' / 'generated').absolute()
+    out_dir = (root / "doc" / "generated").absolute()
     out_dir.mkdir(exist_ok=True)
-    out_fname = out_dir / 'commands.rst.new'
+    out_fname = out_dir / "commands.rst.new"

-    command_path = root / 'mne' / 'commands'
+    command_path = root / "mne" / "commands"
     fnames = sorted(
-        Path(fname).name
-        for fname in glob.glob(str(command_path / 'mne_*.py')))
+        Path(fname).name for fname in glob.glob(str(command_path / "mne_*.py"))
+    )
     assert len(fnames)
     iterator = status_iterator(
-        fnames, 'generating MNE command help ...
', length=len(fnames)) - with open(out_fname, 'w', encoding='utf8') as f: + fnames, "generating MNE command help ... ", length=len(fnames) + ) + with open(out_fname, "w", encoding="utf8") as f: f.write(header) for fname in iterator: cmd_name = fname[:-3] - module = import_module('.' + cmd_name, 'mne.commands') - with ArgvSetter(('mne', cmd_name, '--help')) as out: + module = import_module("." + cmd_name, "mne.commands") + with ArgvSetter(("mne", cmd_name, "--help")) as out: try: module.run() except SystemExit: # this is how these terminate @@ -80,29 +81,30 @@ def generate_commands_rst(app=None): # Add header marking for idx in (1, 0): - output.insert(idx, '-' * len(output[0])) + output.insert(idx, "-" * len(output[0])) # Add code styling for the "Usage: " line for li, line in enumerate(output): - if line.startswith('Usage: mne '): - output[li] = 'Usage: ``%s``' % line[7:] + if line.startswith("Usage: mne "): + output[li] = "Usage: ``%s``" % line[7:] break # Turn "Options:" into field list - if 'Options:' in output: - ii = output.index('Options:') - output[ii] = 'Options' - output.insert(ii + 1, '-------') - output.insert(ii + 2, '') - output.insert(ii + 3, '.. rst-class:: field-list cmd-list') - output.insert(ii + 4, '') - output = '\n'.join(output) - cmd_name_space = cmd_name.replace('mne_', 'mne ') - f.write(command_rst.format( - cmd_name_space, '=' * len(cmd_name_space), output)) + if "Options:" in output: + ii = output.index("Options:") + output[ii] = "Options" + output.insert(ii + 1, "-------") + output.insert(ii + 2, "") + output.insert(ii + 3, ".. rst-class:: field-list cmd-list") + output.insert(ii + 4, "") + output = "\n".join(output) + cmd_name_space = cmd_name.replace("mne_", "mne ") + f.write( + command_rst.format(cmd_name_space, "=" * len(cmd_name_space), output) + ) _replace_md5(str(out_fname)) # This is useful for testing/iterating to see what the result looks like -if __name__ == '__main__': +if __name__ == "__main__": generate_commands_rst() diff --git a/doc/sphinxext/gen_names.py b/doc/sphinxext/gen_names.py index 92c155b8f52..c5cc7f9f9ea 100644 --- a/doc/sphinxext/gen_names.py +++ b/doc/sphinxext/gen_names.py @@ -3,7 +3,7 @@ def setup(app): - app.connect('builder-inited', generate_name_links_rst) + app.connect("builder-inited", generate_name_links_rst) def setup_module(): @@ -12,17 +12,18 @@ def setup_module(): def generate_name_links_rst(app=None): - if 'linkcheck' not in str(app.builder).lower(): + if "linkcheck" not in str(app.builder).lower(): return - out_dir = op.abspath(op.join(op.dirname(__file__), '..', 'generated')) + out_dir = op.abspath(op.join(op.dirname(__file__), "..", "generated")) if not op.isdir(out_dir): os.mkdir(out_dir) - out_fname = op.join(out_dir, '_names.rst') + out_fname = op.join(out_dir, "_names.rst") names_path = op.abspath( - op.join(os.path.dirname(__file__), '..', 'changes', 'names.inc')) - with open(out_fname, 'w', encoding='utf8') as fout: - fout.write(':orphan:\n\n') - with open(names_path, 'r') as fin: + op.join(os.path.dirname(__file__), "..", "changes", "names.inc") + ) + with open(out_fname, "w", encoding="utf8") as fout: + fout.write(":orphan:\n\n") + with open(names_path, "r") as fin: for line in fin: - if line.startswith('.. _'): - fout.write(f'- {line[4:]}') + if line.startswith(".. 
_"): + fout.write(f"- {line[4:]}") diff --git a/doc/sphinxext/gh_substitutions.py b/doc/sphinxext/gh_substitutions.py index f0c6a05c5ba..4463425867d 100644 --- a/doc/sphinxext/gh_substitutions.py +++ b/doc/sphinxext/gh_substitutions.py @@ -15,14 +15,14 @@ def gh_role(name, rawtext, text, lineno, inliner, options={}, content=[]): # direct link mode slug = text else: - slug = 'issues/' + text - text = '#' + text - ref = 'https://github.com/mne-tools/mne-python/' + slug + slug = "issues/" + text + text = "#" + text + ref = "https://github.com/mne-tools/mne-python/" + slug set_classes(options) node = reference(rawtext, text, refuri=ref, **options) return [node], [] def setup(app): - app.add_role('gh', gh_role) + app.add_role("gh", gh_role) return diff --git a/doc/sphinxext/mne_substitutions.py b/doc/sphinxext/mne_substitutions.py index a9309baaf42..a1b8627edf9 100644 --- a/doc/sphinxext/mne_substitutions.py +++ b/doc/sphinxext/mne_substitutions.py @@ -3,46 +3,57 @@ from docutils.statemachine import StringList from mne.defaults import DEFAULTS -from mne.io.pick import (_PICK_TYPES_DATA_DICT, _DATA_CH_TYPES_SPLIT, - _DATA_CH_TYPES_ORDER_DEFAULT) +from mne.io.pick import ( + _PICK_TYPES_DATA_DICT, + _DATA_CH_TYPES_SPLIT, + _DATA_CH_TYPES_ORDER_DEFAULT, +) class MNESubstitution(Directive): # noqa: D101 - has_content = False required_arguments = 1 final_argument_whitespace = True def run(self, **kwargs): # noqa: D102 env = self.state.document.settings.env - if self.arguments[0] == 'data channels list': + if self.arguments[0] == "data channels list": keys = list() for key in _DATA_CH_TYPES_ORDER_DEFAULT: if key in _DATA_CH_TYPES_SPLIT: keys.append(key) - elif key not in ('meg', 'fnirs') and \ - _PICK_TYPES_DATA_DICT.get(key, False): + elif key not in ("meg", "fnirs") and _PICK_TYPES_DATA_DICT.get( + key, False + ): keys.append(key) - rst = '- ' + '\n- '.join( - '``%r``: **%s** (scaled by %g to plot in *%s*)' - % (key, DEFAULTS['titles'][key], DEFAULTS['scalings'][key], - DEFAULTS['units'][key]) - for key in keys) + rst = "- " + "\n- ".join( + "``%r``: **%s** (scaled by %g to plot in *%s*)" + % ( + key, + DEFAULTS["titles"][key], + DEFAULTS["scalings"][key], + DEFAULTS["units"][key], + ) + for key in keys + ) else: raise self.error( - 'MNE directive unknown in %s: %r' - % (env.doc2path(env.docname, base=None), - self.arguments[0],)) + "MNE directive unknown in %s: %r" + % ( + env.doc2path(env.docname, base=None), + self.arguments[0], + ) + ) node = nodes.compound(rst) # General(Body), Element content = StringList( - rst.split('\n'), parent=self.content.parent, - parent_offset=self.content.parent_offset) + rst.split("\n"), + parent=self.content.parent, + parent_offset=self.content.parent_offset, + ) self.state.nested_parse(content, self.content_offset, node) return [node] def setup(app): # noqa: D103 - app.add_directive('mne', MNESubstitution) - return {'version': '0.1', - 'parallel_read_safe': True, - 'parallel_write_safe': True} + app.add_directive("mne", MNESubstitution) + return {"version": "0.1", "parallel_read_safe": True, "parallel_write_safe": True} diff --git a/doc/sphinxext/newcontrib_substitutions.py b/doc/sphinxext/newcontrib_substitutions.py index 68595e74bdb..8c31e8ca0e2 100644 --- a/doc/sphinxext/newcontrib_substitutions.py +++ b/doc/sphinxext/newcontrib_substitutions.py @@ -1,18 +1,17 @@ from docutils.nodes import reference, strong, target -def newcontrib_role(name, rawtext, text, lineno, inliner, options={}, - content=[]): +def newcontrib_role(name, rawtext, text, 
lineno, inliner, options={}, content=[]): """Create a role to highlight new contributors in changelog entries.""" - newcontrib = f'new contributor {text}' - alias_text = f' <{text}_>' - rawtext = f'`{newcontrib}{alias_text}`_' + newcontrib = f"new contributor {text}" + alias_text = f" <{text}_>" + rawtext = f"`{newcontrib}{alias_text}`_" refname = text.lower() strong_node = strong(rawtext, newcontrib) target_node = target(alias_text, refname=refname, names=[newcontrib]) target_node.indirect_reference_name = text options.update(refname=refname, name=newcontrib) - ref_node = reference('', '', strong_node, **options) + ref_node = reference("", "", strong_node, **options) ref_node[0].rawsource = rawtext inliner.document.note_indirect_target(target_node) inliner.document.note_refname(ref_node) @@ -20,5 +19,5 @@ def newcontrib_role(name, rawtext, text, lineno, inliner, options={}, def setup(app): - app.add_role('newcontrib', newcontrib_role) + app.add_role("newcontrib", newcontrib_role) return diff --git a/doc/sphinxext/unit_role.py b/doc/sphinxext/unit_role.py index d912786b474..83b82c223e4 100644 --- a/doc/sphinxext/unit_role.py +++ b/doc/sphinxext/unit_role.py @@ -6,8 +6,10 @@ def unit_role(name, rawtext, text, lineno, inliner, options={}, content=[]): def pass_error_to_sphinx(rawtext, text, lineno, inliner): msg = inliner.reporter.error( - 'The :unit: role requires a space-separated number and unit; ' - f'got {text}', line=lineno) + "The :unit: role requires a space-separated number and unit; " + f"got {text}", + line=lineno, + ) prb = inliner.problematic(rawtext, rawtext, msg) return [prb], [msg] @@ -20,10 +22,10 @@ def pass_error_to_sphinx(rawtext, text, lineno, inliner): except ValueError: return pass_error_to_sphinx(rawtext, text, lineno, inliner) # input is well-formatted: proceed - node = nodes.Text('\u202F'.join(parts)) + node = nodes.Text("\u202F".join(parts)) return [node], [] def setup(app): - app.add_role('unit', unit_role) + app.add_role("unit", unit_role) return diff --git a/examples/datasets/brainstorm_data.py b/examples/datasets/brainstorm_data.py index 949f2511a88..df08d0d383a 100644 --- a/examples/datasets/brainstorm_data.py +++ b/examples/datasets/brainstorm_data.py @@ -30,29 +30,38 @@ data_path = bst_raw.data_path() -raw_path = (data_path / 'MEG' / 'bst_raw' / - 'subj001_somatosensory_20111109_01_AUX-f.ds') +raw_path = data_path / "MEG" / "bst_raw" / "subj001_somatosensory_20111109_01_AUX-f.ds" # Here we crop to half the length to save memory raw = read_raw_ctf(raw_path).crop(0, 120).load_data() raw.plot() # set EOG channel -raw.set_channel_types({'EEG058': 'eog'}) -raw.set_eeg_reference('average', projection=True) +raw.set_channel_types({"EEG058": "eog"}) +raw.set_eeg_reference("average", projection=True) # show power line interference and remove it raw.compute_psd(tmax=60).plot(average=False) -raw.notch_filter(np.arange(60, 181, 60), fir_design='firwin') +raw.notch_filter(np.arange(60, 181, 60), fir_design="firwin") -events = mne.find_events(raw, stim_channel='UPPT001') +events = mne.find_events(raw, stim_channel="UPPT001") # pick MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, exclude="bads" +) # Compute epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject, preload=False) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), 
+ reject=reject, + preload=False, +) # compute evoked evoked = epochs.average() @@ -68,11 +77,10 @@ evoked.shift_time(-0.004) # plot the result -evoked.plot(time_unit='s') +evoked.plot(time_unit="s") # show topomaps -evoked.plot_topomap(times=np.array([0.016, 0.030, 0.060, 0.070]), - time_unit='s') +evoked.plot_topomap(times=np.array([0.016, 0.030, 0.060, 0.070]), time_unit="s") # %% # References diff --git a/examples/datasets/hf_sef_data.py b/examples/datasets/hf_sef_data.py index 9857d22d09d..36ea2cbc2bb 100644 --- a/examples/datasets/hf_sef_data.py +++ b/examples/datasets/hf_sef_data.py @@ -18,8 +18,7 @@ import os from mne.datasets import hf_sef -fname_evoked = os.path.join(hf_sef.data_path(), - 'MEG/subject_b/hf_sef_15min-ave.fif') +fname_evoked = os.path.join(hf_sef.data_path(), "MEG/subject_b/hf_sef_15min-ave.fif") print(__doc__) @@ -34,7 +33,7 @@ # %% # Compare high-pass filtered and unfiltered data on a single channel -ch = 'MEG0443' +ch = "MEG0443" pick = evoked.ch_names.index(ch) -edi = {'HF': evoked_hp, 'Regular': evoked} +edi = {"HF": evoked_hp, "Regular": evoked} mne.viz.plot_compare_evokeds(edi, picks=pick) diff --git a/examples/datasets/limo_data.py b/examples/datasets/limo_data.py index d5670f62ffe..4285411dd6c 100644 --- a/examples/datasets/limo_data.py +++ b/examples/datasets/limo_data.py @@ -112,7 +112,7 @@ # metadata. # We want include all columns in the summary table -epochs_summary = limo_epochs.metadata.describe(include='all').round(3) +epochs_summary = limo_epochs.metadata.describe(include="all").round(3) print(epochs_summary) # %% @@ -137,13 +137,13 @@ ts_args = dict(xlim=(-0.25, 0.5)) # plot evoked response for face A -limo_epochs['Face/A'].average().plot_joint(times=[0.15], - title='Evoked response: Face A', - ts_args=ts_args) +limo_epochs["Face/A"].average().plot_joint( + times=[0.15], title="Evoked response: Face A", ts_args=ts_args +) # and face B -limo_epochs['Face/B'].average().plot_joint(times=[0.15], - title='Evoked response: Face B', - ts_args=ts_args) +limo_epochs["Face/B"].average().plot_joint( + times=[0.15], title="Evoked response: Face B", ts_args=ts_args +) # %% # We can also compute the difference wave contrasting Face A and Face B. @@ -151,12 +151,12 @@ # differences among these face-stimuli. # Face A minus Face B -difference_wave = combine_evoked([limo_epochs['Face/A'].average(), - limo_epochs['Face/B'].average()], - weights=[1, -1]) +difference_wave = combine_evoked( + [limo_epochs["Face/A"].average(), limo_epochs["Face/B"].average()], weights=[1, -1] +) # plot difference wave -difference_wave.plot_joint(times=[0.15], title='Difference Face A - Face B') +difference_wave.plot_joint(times=[0.15], title="Difference Face A - Face B") # %% # As expected, no clear pattern appears when contrasting @@ -167,11 +167,10 @@ # Create a dictionary containing the evoked responses conditions = ["Face/A", "Face/B"] -evokeds = {condition: limo_epochs[condition].average() - for condition in conditions} +evokeds = {condition: limo_epochs[condition].average() for condition in conditions} # concentrate analysis an occipital electrodes (e.g. B11) -pick = evokeds["Face/A"].ch_names.index('B11') +pick = evokeds["Face/A"].ch_names.index("B11") # compare evoked responses plot_compare_evokeds(evokeds, picks=pick, ylim=dict(eeg=(-15, 7.5))) @@ -188,26 +187,30 @@ # one could expect that faces with high phase-coherence should evoke stronger # activation patterns along occipital electrodes. 
-phase_coh = limo_epochs.metadata['phase-coherence'] +phase_coh = limo_epochs.metadata["phase-coherence"] # get levels of phase coherence levels = sorted(phase_coh.unique()) # create labels for levels of phase coherence (i.e., 0 - 85%) -labels = ["{0:.2f}".format(i) for i in np.arange(0., 0.90, 0.05)] +labels = ["{0:.2f}".format(i) for i in np.arange(0.0, 0.90, 0.05)] # create dict of evokeds for each level of phase-coherence -evokeds = {label: limo_epochs[phase_coh == level].average() - for level, label in zip(levels, labels)} +evokeds = { + label: limo_epochs[phase_coh == level].average() + for level, label in zip(levels, labels) +} # pick channel to plot -electrodes = ['C22', 'B11'] +electrodes = ["C22", "B11"] # create figures for electrode in electrodes: fig, ax = plt.subplots(figsize=(8, 4)) - plot_compare_evokeds(evokeds, - axes=ax, - ylim=dict(eeg=(-20, 15)), - picks=electrode, - cmap=("Phase coherence", "magma")) + plot_compare_evokeds( + evokeds, + axes=ax, + ylim=dict(eeg=(-20, 15)), + picks=electrode, + cmap=("Phase coherence", "magma"), + ) # %% # As shown above, there are some considerable differences between the @@ -225,7 +228,7 @@ # present in the data: limo_epochs.interpolate_bads(reset_bads=True) -limo_epochs.drop_channels(['EXG1', 'EXG2', 'EXG3', 'EXG4']) +limo_epochs.drop_channels(["EXG1", "EXG2", "EXG3", "EXG4"]) # %% # Define predictor variables and design matrix @@ -238,21 +241,19 @@ # ``limo_epochs.metadata``: phase-coherence and Face A vs. Face B. # name of predictors + intercept -predictor_vars = ['face a - face b', 'phase-coherence', 'intercept'] +predictor_vars = ["face a - face b", "phase-coherence", "intercept"] # create design matrix -design = limo_epochs.metadata[['phase-coherence', 'face']].copy() -design['face a - face b'] = np.where(design['face'] == 'A', 1, -1) -design['intercept'] = 1 +design = limo_epochs.metadata[["phase-coherence", "face"]].copy() +design["face a - face b"] = np.where(design["face"] == "A", 1, -1) +design["intercept"] = 1 design = design[predictor_vars] # %% # Now we can set up the linear model to be used in the analysis using # MNE-Python's func:`~mne.stats.linear_regression` function. -reg = linear_regression(limo_epochs, - design_matrix=design, - names=predictor_vars) +reg = linear_regression(limo_epochs, design_matrix=design, names=predictor_vars) # %% # Extract regression coefficients @@ -262,8 +263,8 @@ # which is a dictionary of evoked objects containing # multiple inferential measures for each predictor in the design matrix. -print('predictors are:', list(reg)) -print('fields are:', [field for field in getattr(reg['intercept'], '_fields')]) +print("predictors are:", list(reg)) +print("fields are:", [field for field in getattr(reg["intercept"], "_fields")]) # %% # Plot model results @@ -279,25 +280,23 @@ # the activity measured at occipital electrodes around 200 to 250 ms following # stimulus onset. -reg['phase-coherence'].beta.plot_joint(ts_args=ts_args, - title='Effect of Phase-coherence', - times=[0.23]) +reg["phase-coherence"].beta.plot_joint( + ts_args=ts_args, title="Effect of Phase-coherence", times=[0.23] +) # %% # We can also plot the corresponding T values. # use unit=False and scale=1 to keep values at their original # scale (i.e., avoid conversion to micro-volt). 
-ts_args = dict(xlim=(-0.25, 0.5), - unit=False) -topomap_args = dict(scalings=dict(eeg=1), - average=0.05) +ts_args = dict(xlim=(-0.25, 0.5), unit=False) +topomap_args = dict(scalings=dict(eeg=1), average=0.05) # sphinx_gallery_thumbnail_number = 9 -fig = reg['phase-coherence'].t_val.plot_joint(ts_args=ts_args, - topomap_args=topomap_args, - times=[0.23]) -fig.axes[0].set_ylabel('T-value') +fig = reg["phase-coherence"].t_val.plot_joint( + ts_args=ts_args, topomap_args=topomap_args, times=[0.23] +) +fig.axes[0].set_ylabel("T-value") # %% # Conversely, there appears to be no (or very small) systematic effects when @@ -305,9 +304,9 @@ # difference wave approach presented above. ts_args = dict(xlim=(-0.25, 0.5)) -reg['face a - face b'].beta.plot_joint(ts_args=ts_args, - title='Effect of Face A vs. Face B', - times=[0.23]) +reg["face a - face b"].beta.plot_joint( + ts_args=ts_args, title="Effect of Face A vs. Face B", times=[0.23] +) # %% # References diff --git a/examples/datasets/opm_data.py b/examples/datasets/opm_data.py index ec6daab1037..184ea216866 100644 --- a/examples/datasets/opm_data.py +++ b/examples/datasets/opm_data.py @@ -16,13 +16,12 @@ import mne data_path = mne.datasets.opm.data_path() -subject = 'OPM_sample' -subjects_dir = data_path / 'subjects' -raw_fname = data_path / 'MEG' / 'OPM' / 'OPM_SEF_raw.fif' -bem_fname = (subjects_dir / subject / 'bem' / - f'{subject}-5120-5120-5120-bem-sol.fif') -fwd_fname = data_path / 'MEG' / 'OPM' / 'OPM_sample-fwd.fif' -coil_def_fname = data_path / 'MEG' / 'OPM' / 'coil_def.dat' +subject = "OPM_sample" +subjects_dir = data_path / "subjects" +raw_fname = data_path / "MEG" / "OPM" / "OPM_SEF_raw.fif" +bem_fname = subjects_dir / subject / "bem" / f"{subject}-5120-5120-5120-bem-sol.fif" +fwd_fname = data_path / "MEG" / "OPM" / "OPM_sample-fwd.fif" +coil_def_fname = data_path / "MEG" / "OPM" / "coil_def.dat" # %% # Prepare data for localization @@ -30,8 +29,8 @@ # First we filter and epoch the data: raw = mne.io.read_raw_fif(raw_fname, preload=True) -raw.filter(None, 90, h_trans_bandwidth=10.) -raw.notch_filter(50., notch_widths=1) +raw.filter(None, 90, h_trans_bandwidth=10.0) +raw.notch_filter(50.0, notch_widths=1) # Set epoch rejection threshold a bit larger than for SQUIDs @@ -40,16 +39,26 @@ # Find median nerve stimulator trigger event_id = dict(Median=257) -events = mne.find_events(raw, stim_channel='STI101', mask=257, mask_type='and') +events = mne.find_events(raw, stim_channel="STI101", mask=257, mask_type="and") picks = mne.pick_types(raw.info, meg=True, eeg=False) # We use verbose='error' to suppress warning about decimation causing aliasing, # ideally we would low-pass and then decimate instead -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, verbose='error', - reject=reject, picks=picks, proj=False, decim=10, - preload=True) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + verbose="error", + reject=reject, + picks=picks, + proj=False, + decim=10, + preload=True, +) evoked = epochs.average() evoked.plot() -cov = mne.compute_covariance(epochs, tmax=0.) +cov = mne.compute_covariance(epochs, tmax=0.0) del epochs, raw # %% @@ -63,7 +72,7 @@ # but should be fine for these analyses. 
bem = mne.read_bem_solution(bem_fname) -trans = mne.transforms.Transform('head', 'mri') # identity transformation +trans = mne.transforms.Transform("head", "mri") # identity transformation # To compute the forward solution, we must # provide our temporary/custom coil definitions, which can be done as:: @@ -78,12 +87,18 @@ mne.convert_forward_solution(fwd, force_fixed=True, copy=False) with mne.use_coil_def(coil_def_fname): - fig = mne.viz.plot_alignment(evoked.info, trans=trans, subject=subject, - subjects_dir=subjects_dir, - surfaces=('head', 'pial'), bem=bem) - -mne.viz.set_3d_view(figure=fig, azimuth=45, elevation=60, distance=0.4, - focalpoint=(0.02, 0, 0.04)) + fig = mne.viz.plot_alignment( + evoked.info, + trans=trans, + subject=subject, + subjects_dir=subjects_dir, + surfaces=("head", "pial"), + bem=bem, + ) + +mne.viz.set_3d_view( + figure=fig, azimuth=45, elevation=60, distance=0.4, focalpoint=(0.02, 0, 0.04) +) # %% # Perform dipole fitting @@ -91,15 +106,17 @@ # Fit dipoles on a subset of time points with mne.use_coil_def(coil_def_fname): - dip_opm, _ = mne.fit_dipole(evoked.copy().crop(0.040, 0.080), - cov, bem, trans, verbose=True) + dip_opm, _ = mne.fit_dipole( + evoked.copy().crop(0.040, 0.080), cov, bem, trans, verbose=True + ) idx = np.argmax(dip_opm.gof) -print('Best dipole at t=%0.1f ms with %0.1f%% GOF' - % (1000 * dip_opm.times[idx], dip_opm.gof[idx])) +print( + "Best dipole at t=%0.1f ms with %0.1f%% GOF" + % (1000 * dip_opm.times[idx], dip_opm.gof[idx]) +) # Plot N20m dipole as an example -dip_opm.plot_locations(trans, subject, subjects_dir, - mode='orthoview', idx=idx) +dip_opm.plot_locations(trans, subject, subjects_dir, mode="orthoview", idx=idx) # %% # Perform minimum-norm localization @@ -109,18 +126,24 @@ # areas we are sensitive to might be a good idea. inverse_operator = mne.minimum_norm.make_inverse_operator( - evoked.info, fwd, cov, loose=0., depth=None) + evoked.info, fwd, cov, loose=0.0, depth=None +) del fwd, cov method = "MNE" -snr = 3. -lambda2 = 1. / snr ** 2 +snr = 3.0 +lambda2 = 1.0 / snr**2 stc = mne.minimum_norm.apply_inverse( - evoked, inverse_operator, lambda2, method=method, - pick_ori=None, verbose=True) + evoked, inverse_operator, lambda2, method=method, pick_ori=None, verbose=True +) # Plot source estimate at time of best dipole fit -brain = stc.plot(hemi='rh', views='lat', subjects_dir=subjects_dir, - initial_time=dip_opm.times[idx], - clim=dict(kind='percent', lims=[99, 99.9, 99.99]), - size=(400, 300), background='w') +brain = stc.plot( + hemi="rh", + views="lat", + subjects_dir=subjects_dir, + initial_time=dip_opm.times[idx], + clim=dict(kind="percent", lims=[99, 99.9, 99.99]), + size=(400, 300), + background="w", +) diff --git a/examples/datasets/spm_faces_dataset_sgskip.py b/examples/datasets/spm_faces_dataset_sgskip.py index 875cc2eb5d5..8806059b395 100644 --- a/examples/datasets/spm_faces_dataset_sgskip.py +++ b/examples/datasets/spm_faces_dataset_sgskip.py @@ -35,26 +35,26 @@ print(__doc__) data_path = spm_face.data_path() -subjects_dir = data_path / 'subjects' -spm_path = data_path / 'MEG' / 'spm' +subjects_dir = data_path / "subjects" +spm_path = data_path / "MEG" / "spm" # %% # Load and filter data, set up epochs -raw_fname = spm_path / 'SPM_CTF_MEG_example_faces%d_3D.ds' +raw_fname = spm_path / "SPM_CTF_MEG_example_faces%d_3D.ds" raw = io.read_raw_ctf(raw_fname % 1, preload=True) # Take first run # Here to save memory and time we'll downsample heavily -- this is not # advised for real data as it can effectively jitter events! 
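+# The jitter arises because each event onset must be rounded onto the
+# coarser post-resampling time grid. A usually safer pattern (a sketch) is
+# to low-pass filter and then decimate when epoching instead::
+#
+#     raw.filter(None, 40.0)  # anti-alias well below the new Nyquist
+#     epochs = mne.Epochs(raw, events, ..., decim=4)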
-raw.resample(120., npad='auto')
+raw.resample(120.0, npad="auto")

-picks = mne.pick_types(raw.info, meg=True, exclude='bads')
-raw.filter(1, 30, method='fir', fir_design='firwin')
+picks = mne.pick_types(raw.info, meg=True, exclude="bads")
+raw.filter(1, 30, method="fir", fir_design="firwin")

-events = mne.find_events(raw, stim_channel='UPPT001')
+events = mne.find_events(raw, stim_channel="UPPT001")

 # plot the events to get an idea of the paradigm
-mne.viz.plot_events(events, raw.info['sfreq'])
+mne.viz.plot_events(events, raw.info["sfreq"])

 event_ids = {"faces": 1, "scrambled": 2}

@@ -62,16 +62,25 @@
 baseline = None  # no baseline as high-pass is applied
 reject = dict(mag=5e-12)

-epochs = mne.Epochs(raw, events, event_ids, tmin, tmax, picks=picks,
-                    baseline=baseline, preload=True, reject=reject)
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_ids,
+    tmin,
+    tmax,
+    picks=picks,
+    baseline=baseline,
+    preload=True,
+    reject=reject,
+)

 # Fit ICA, find and remove major artifacts
-ica = ICA(n_components=0.95, max_iter='auto', random_state=0)
+ica = ICA(n_components=0.95, max_iter="auto", random_state=0)
 ica.fit(raw, decim=1, reject=reject)

 # compute correlation scores, get bad indices sorted by score
-eog_epochs = create_eog_epochs(raw, ch_name='MRT31-2908', reject=reject)
-eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name='MRT31-2908')
+eog_epochs = create_eog_epochs(raw, ch_name="MRT31-2908", reject=reject)
+eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name="MRT31-2908")
 ica.plot_scores(eog_scores, eog_inds)  # see scores the selection is based on
 ica.plot_components(eog_inds)  # view topographic sensitivity of components
 ica.exclude += eog_inds[:1]  # we saw the 2nd EOG component looked too dipolar
@@ -90,18 +99,18 @@
 plt.show()

 # estimate noise covariance
-noise_cov = mne.compute_covariance(epochs, tmax=0, method='shrunk',
-                                   rank=None)
+noise_cov = mne.compute_covariance(epochs, tmax=0, method="shrunk", rank=None)

 # %%
 # Visualize fields on MEG helmet

 # The transformation here was aligned using the dig-montage. It's included in
 # the spm_faces dataset and is named SPM_dig_montage.fif.
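+# (:func:`mne.make_field_map` below interpolates the sensor data onto the
+# helmet and scalp surfaces so the evoked topography can be rendered in 3D.)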
-trans_fname = spm_path / 'SPM_CTF_MEG_example_faces1_3D_raw-trans.fif' +trans_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D_raw-trans.fif" -maps = mne.make_field_map(evoked[0], trans_fname, subject='spm', - subjects_dir=subjects_dir, n_jobs=None) +maps = mne.make_field_map( + evoked[0], trans_fname, subject="spm", subjects_dir=subjects_dir, n_jobs=None +) evoked[0].plot_field(maps, time=0.170) @@ -113,25 +122,31 @@ # %% # Compute forward model -src = subjects_dir / 'spm' / 'bem' / 'spm-oct-6-src.fif' -bem = subjects_dir / 'spm' / 'bem' / 'spm-5120-5120-5120-bem-sol.fif' +src = subjects_dir / "spm" / "bem" / "spm-oct-6-src.fif" +bem = subjects_dir / "spm" / "bem" / "spm-5120-5120-5120-bem-sol.fif" forward = mne.make_forward_solution(contrast.info, trans_fname, src, bem) # %% # Compute inverse solution snr = 3.0 -lambda2 = 1.0 / snr ** 2 -method = 'dSPM' +lambda2 = 1.0 / snr**2 +method = "dSPM" -inverse_operator = make_inverse_operator(contrast.info, forward, noise_cov, - loose=0.2, depth=0.8) +inverse_operator = make_inverse_operator( + contrast.info, forward, noise_cov, loose=0.2, depth=0.8 +) # Compute inverse solution on contrast stc = apply_inverse(contrast, inverse_operator, lambda2, method, pick_ori=None) # stc.save('spm_%s_dSPM_inverse' % contrast.comment) # Plot contrast in 3D with mne.viz.Brain if available -brain = stc.plot(hemi='both', subjects_dir=subjects_dir, initial_time=0.170, - views=['ven'], clim={'kind': 'value', 'lims': [3., 6., 9.]}) +brain = stc.plot( + hemi="both", + subjects_dir=subjects_dir, + initial_time=0.170, + views=["ven"], + clim={"kind": "value", "lims": [3.0, 6.0, 9.0]}, +) # brain.save_image('dSPM_map.png') diff --git a/examples/decoding/decoding_csp_eeg.py b/examples/decoding/decoding_csp_eeg.py index beef85bbdc0..1ee2f0ce87d 100644 --- a/examples/decoding/decoding_csp_eeg.py +++ b/examples/decoding/decoding_csp_eeg.py @@ -40,7 +40,7 @@ # avoid classification of evoked responses by using epochs that start 1s after # cue onset. -tmin, tmax = -1., 4. +tmin, tmax = -1.0, 4.0 event_id = dict(hands=2, feet=3) subject = 1 runs = [6, 10, 14] # motor imagery: hands vs feet @@ -48,22 +48,30 @@ raw_fnames = eegbci.load_data(subject, runs) raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames]) eegbci.standardize(raw) # set channel names -montage = make_standard_montage('standard_1005') +montage = make_standard_montage("standard_1005") raw.set_montage(montage) # Apply band-pass filter -raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge') +raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge") events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3)) -picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, - exclude='bads') +picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads") # Read epochs (train will be done only between 1 and 2s) # Testing will be done with a running classifier -epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks, - baseline=None, preload=True) -epochs_train = epochs.copy().crop(tmin=1., tmax=2.) 
+epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + proj=True, + picks=picks, + baseline=None, + preload=True, +) +epochs_train = epochs.copy().crop(tmin=1.0, tmax=2.0) labels = epochs.events[:, -1] - 2 # %% @@ -81,25 +89,26 @@ csp = CSP(n_components=4, reg=None, log=True, norm_trace=False) # Use scikit-learn Pipeline with cross_val_score function -clf = Pipeline([('CSP', csp), ('LDA', lda)]) +clf = Pipeline([("CSP", csp), ("LDA", lda)]) scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=None) # Printing the results class_balance = np.mean(labels == labels[0]) -class_balance = max(class_balance, 1. - class_balance) -print("Classification accuracy: %f / Chance level: %f" % (np.mean(scores), - class_balance)) +class_balance = max(class_balance, 1.0 - class_balance) +print( + "Classification accuracy: %f / Chance level: %f" % (np.mean(scores), class_balance) +) # plot CSP patterns estimated on full data for visualization csp.fit_transform(epochs_data, labels) -csp.plot_patterns(epochs.info, ch_type='eeg', units='Patterns (AU)', size=1.5) +csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5) # %% # Look at performance over time -sfreq = raw.info['sfreq'] -w_length = int(sfreq * 0.5) # running classifier: window length +sfreq = raw.info["sfreq"] +w_length = int(sfreq * 0.5) # running classifier: window length w_step = int(sfreq * 0.1) # running classifier: window step size w_start = np.arange(0, epochs_data.shape[2] - w_length, w_step) @@ -117,21 +126,21 @@ # running classifier: test classifier on sliding window score_this_window = [] for n in w_start: - X_test = csp.transform(epochs_data[test_idx][:, :, n:(n + w_length)]) + X_test = csp.transform(epochs_data[test_idx][:, :, n : (n + w_length)]) score_this_window.append(lda.score(X_test, y_test)) scores_windows.append(score_this_window) # Plot scores over time -w_times = (w_start + w_length / 2.) 
/ sfreq + epochs.tmin +w_times = (w_start + w_length / 2.0) / sfreq + epochs.tmin plt.figure() -plt.plot(w_times, np.mean(scores_windows, 0), label='Score') -plt.axvline(0, linestyle='--', color='k', label='Onset') -plt.axhline(0.5, linestyle='-', color='k', label='Chance') -plt.xlabel('time (s)') -plt.ylabel('classification accuracy') -plt.title('Classification score over time') -plt.legend(loc='lower right') +plt.plot(w_times, np.mean(scores_windows, 0), label="Score") +plt.axvline(0, linestyle="--", color="k", label="Onset") +plt.axhline(0.5, linestyle="-", color="k", label="Chance") +plt.xlabel("time (s)") +plt.ylabel("classification accuracy") +plt.title("Classification score over time") +plt.legend(loc="lower right") plt.show() ############################################################################## diff --git a/examples/decoding/decoding_csp_timefreq.py b/examples/decoding/decoding_csp_timefreq.py index 6407646910b..3b048587ec1 100644 --- a/examples/decoding/decoding_csp_timefreq.py +++ b/examples/decoding/decoding_csp_timefreq.py @@ -44,22 +44,24 @@ raw = concatenate_raws([read_raw_edf(f) for f in raw_fnames]) # Extract information from the raw file -sfreq = raw.info['sfreq'] +sfreq = raw.info["sfreq"] events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3)) -raw.pick_types(meg=False, eeg=True, stim=False, eog=False, exclude='bads') +raw.pick_types(meg=False, eeg=True, stim=False, eog=False, exclude="bads") raw.load_data() # Assemble the classifier using scikit-learn pipeline -clf = make_pipeline(CSP(n_components=4, reg=None, log=True, norm_trace=False), - LinearDiscriminantAnalysis()) +clf = make_pipeline( + CSP(n_components=4, reg=None, log=True, norm_trace=False), + LinearDiscriminantAnalysis(), +) n_splits = 3 # for cross-validation, 5 is better, here we use 3 for speed cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) # Classification & time-frequency parameters -tmin, tmax = -.200, 2.000 -n_cycles = 10. # how many complete cycles: used to define window size -min_freq = 8. -max_freq = 20. +tmin, tmax = -0.200, 2.000 +n_cycles = 10.0 # how many complete cycles: used to define window size +min_freq = 8.0 +max_freq = 20.0 n_freqs = 6 # how many frequency bins to use # Assemble list of frequency range tuples @@ -67,7 +69,7 @@ freq_ranges = list(zip(freqs[:-1], freqs[1:])) # make freqs list of tuples # Infer window spacing from the max freq and number of cycles to avoid gaps -window_spacing = (n_cycles / np.max(freqs) / 2.) +window_spacing = n_cycles / np.max(freqs) / 2.0 centered_w_times = np.arange(tmin, tmax, window_spacing)[1:] n_windows = len(centered_w_times) @@ -82,39 +84,50 @@ # Loop through each frequency range of interest for freq, (fmin, fmax) in enumerate(freq_ranges): - # Infer window size based on the frequency being used - w_size = n_cycles / ((fmax + fmin) / 2.) 
# in seconds + w_size = n_cycles / ((fmax + fmin) / 2.0) # in seconds # Apply band-pass filter to isolate the specified frequencies - raw_filter = raw.copy().filter(fmin, fmax, fir_design='firwin', - skip_by_annotation='edge') + raw_filter = raw.copy().filter( + fmin, fmax, fir_design="firwin", skip_by_annotation="edge" + ) # Extract epochs from filtered data, padded by window size - epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size, - proj=False, baseline=None, preload=True) + epochs = Epochs( + raw_filter, + events, + event_id, + tmin - w_size, + tmax + w_size, + proj=False, + baseline=None, + preload=True, + ) epochs.drop_bad() y = le.fit_transform(epochs.events[:, 2]) X = epochs.get_data() # Save mean scores over folds for each frequency and time window - freq_scores[freq] = np.mean(cross_val_score( - estimator=clf, X=X, y=y, scoring='roc_auc', cv=cv), axis=0) + freq_scores[freq] = np.mean( + cross_val_score(estimator=clf, X=X, y=y, scoring="roc_auc", cv=cv), axis=0 + ) # %% # Plot frequency results -plt.bar(freqs[:-1], freq_scores, width=np.diff(freqs)[0], - align='edge', edgecolor='black') +plt.bar( + freqs[:-1], freq_scores, width=np.diff(freqs)[0], align="edge", edgecolor="black" +) plt.xticks(freqs) plt.ylim([0, 1]) -plt.axhline(len(epochs['feet']) / len(epochs), color='k', linestyle='--', - label='chance level') +plt.axhline( + len(epochs["feet"]) / len(epochs), color="k", linestyle="--", label="chance level" +) plt.legend() -plt.xlabel('Frequency (Hz)') -plt.ylabel('Decoding Scores') -plt.title('Frequency Decoding Scores') +plt.xlabel("Frequency (Hz)") +plt.ylabel("Decoding Scores") +plt.title("Frequency Decoding Scores") # %% # Loop through frequencies and time, apply classifier and save scores @@ -124,41 +137,53 @@ # Loop through each frequency range of interest for freq, (fmin, fmax) in enumerate(freq_ranges): - # Infer window size based on the frequency being used - w_size = n_cycles / ((fmax + fmin) / 2.) # in seconds + w_size = n_cycles / ((fmax + fmin) / 2.0) # in seconds # Apply band-pass filter to isolate the specified frequencies - raw_filter = raw.copy().filter(fmin, fmax, fir_design='firwin', - skip_by_annotation='edge') + raw_filter = raw.copy().filter( + fmin, fmax, fir_design="firwin", skip_by_annotation="edge" + ) # Extract epochs from filtered data, padded by window size - epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size, - proj=False, baseline=None, preload=True) + epochs = Epochs( + raw_filter, + events, + event_id, + tmin - w_size, + tmax + w_size, + proj=False, + baseline=None, + preload=True, + ) epochs.drop_bad() y = le.fit_transform(epochs.events[:, 2]) # Roll covariance, csp and lda over time for t, w_time in enumerate(centered_w_times): - # Center the min and max of the window - w_tmin = w_time - w_size / 2. - w_tmax = w_time + w_size / 2. 
+ w_tmin = w_time - w_size / 2.0 + w_tmax = w_time + w_size / 2.0 # Crop data into time-window of interest X = epochs.copy().crop(w_tmin, w_tmax).get_data() # Save mean scores over folds for each frequency and time window - tf_scores[freq, t] = np.mean(cross_val_score( - estimator=clf, X=X, y=y, scoring='roc_auc', cv=cv), axis=0) + tf_scores[freq, t] = np.mean( + cross_val_score(estimator=clf, X=X, y=y, scoring="roc_auc", cv=cv), axis=0 + ) # %% # Plot time-frequency results # Set up time frequency object -av_tfr = AverageTFR(create_info(['freq'], sfreq), tf_scores[np.newaxis, :], - centered_w_times, freqs[1:], 1) +av_tfr = AverageTFR( + create_info(["freq"], sfreq), + tf_scores[np.newaxis, :], + centered_w_times, + freqs[1:], + 1, +) chance = np.mean(y) # set chance level to white in the plot -av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores", - cmap=plt.cm.Reds) +av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds) diff --git a/examples/decoding/decoding_rsa_sgskip.py b/examples/decoding/decoding_rsa_sgskip.py index ba1be187372..7cc6dbfbb01 100644 --- a/examples/decoding/decoding_rsa_sgskip.py +++ b/examples/decoding/decoding_rsa_sgskip.py @@ -50,7 +50,7 @@ data_path = visual_92_categories.data_path() # Define stimulus - trigger mapping -fname = data_path / 'visual_stimuli.csv' +fname = data_path / "visual_stimuli.csv" conds = read_csv(fname) print(conds.head(5)) @@ -64,38 +64,48 @@ conditions = [] for c in conds.values: cond_tags = list(c[:2]) - cond_tags += [('not-' if i == 0 else '') + conds.columns[k] - for k, i in enumerate(c[2:], 2)] - conditions.append('/'.join(map(str, cond_tags))) + cond_tags += [ + ("not-" if i == 0 else "") + conds.columns[k] for k, i in enumerate(c[2:], 2) + ] + conditions.append("/".join(map(str, cond_tags))) print(conditions[:10]) ############################################################################## # Let's make the event_id dictionary event_id = dict(zip(conditions, conds.trigger + 1)) -event_id['0/human bodypart/human/not-face/animal/natural'] +event_id["0/human bodypart/human/not-face/animal/natural"] ############################################################################## # Read MEG data n_runs = 4 # 4 for full data (use less to speed up computations) -fnames = [data_path / f'sample_subject_{b}_tsss_mc.fif' for b in range(n_runs)] -raws = [read_raw_fif(fname, verbose='error', on_split_missing='ignore') - for fname in fnames] # ignore filename warnings +fnames = [data_path / f"sample_subject_{b}_tsss_mc.fif" for b in range(n_runs)] +raws = [ + read_raw_fif(fname, verbose="error", on_split_missing="ignore") for fname in fnames +] # ignore filename warnings raw = concatenate_raws(raws) -events = mne.find_events(raw, min_duration=.002) +events = mne.find_events(raw, min_duration=0.002) events = events[events[:, 2] <= max_trigger] ############################################################################## # Epoch data picks = mne.pick_types(raw.info, meg=True) -epochs = mne.Epochs(raw, events=events, event_id=event_id, baseline=None, - picks=picks, tmin=-.1, tmax=.500, preload=True) +epochs = mne.Epochs( + raw, + events=events, + event_id=event_id, + baseline=None, + picks=picks, + tmin=-0.1, + tmax=0.500, + preload=True, +) ############################################################################## # Let's plot some conditions -epochs['face'].average().plot() -epochs['not-face'].average().plot() +epochs["face"].average().plot() +epochs["not-face"].average().plot() 
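+
+# %%
+# (A short aside:) the two selections above work because the ``/``-separated
+# condition tags built earlier support partial matching, pooling every
+# condition whose name contains the requested tag:
+
+print(len(epochs["face"]), len(epochs["not-face"]))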
############################################################################## # Representational Similarity Analysis (RSA) is a neuroimaging-specific @@ -112,9 +122,9 @@ # Classify using the average signal in the window 50ms to 300ms # to focus the classifier on the time interval with best SNR. -clf = make_pipeline(StandardScaler(), - LogisticRegression(C=1, solver='liblinear', - multi_class='auto')) +clf = make_pipeline( + StandardScaler(), LogisticRegression(C=1, solver="liblinear", multi_class="auto") +) X = epochs.copy().crop(0.05, 0.3).get_data().mean(axis=2) y = epochs.events[:, 2] @@ -139,15 +149,15 @@ ############################################################################## # Plot -labels = [''] * 5 + ['face'] + [''] * 11 + ['bodypart'] + [''] * 6 +labels = [""] * 5 + ["face"] + [""] * 11 + ["bodypart"] + [""] * 6 fig, ax = plt.subplots(1) -im = ax.matshow(confusion, cmap='RdBu_r', clim=[0.3, 0.7]) +im = ax.matshow(confusion, cmap="RdBu_r", clim=[0.3, 0.7]) ax.set_yticks(range(len(classes))) ax.set_yticklabels(labels) ax.set_xticks(range(len(classes))) -ax.set_xticklabels(labels, rotation=40, ha='left') -ax.axhline(11.5, color='k') -ax.axvline(11.5, color='k') +ax.set_xticklabels(labels, rotation=40, ha="left") +ax.axhline(11.5, color="k") +ax.axvline(11.5, color="k") plt.colorbar(im) plt.tight_layout() plt.show() @@ -157,19 +167,25 @@ # summarized with dimensionality reduction using multi-dimensional scaling [1]. # See how the face samples cluster together. fig, ax = plt.subplots(1) -mds = MDS(2, random_state=0, dissimilarity='precomputed') +mds = MDS(2, random_state=0, dissimilarity="precomputed") chance = 0.5 summary = mds.fit_transform(chance - confusion) -cmap = plt.colormaps['rainbow'] -colors = ['r', 'b'] -names = list(conds['condition'].values) +cmap = plt.colormaps["rainbow"] +colors = ["r", "b"] +names = list(conds["condition"].values) for color, name in zip(colors, set(names)): sel = np.where([this_name == name for this_name in names])[0] - size = 500 if name == 'human face' else 100 - ax.scatter(summary[sel, 0], summary[sel, 1], s=size, - facecolors=color, label=name, edgecolors='k') -ax.axis('off') -ax.legend(loc='lower right', scatterpoints=1, ncol=2) + size = 500 if name == "human face" else 100 + ax.scatter( + summary[sel, 0], + summary[sel, 1], + s=size, + facecolors=color, + label=name, + edgecolors="k", + ) +ax.axis("off") +ax.legend(loc="lower right", scatterpoints=1, ncol=2) plt.tight_layout() plt.show() diff --git a/examples/decoding/decoding_spatio_temporal_source.py b/examples/decoding/decoding_spatio_temporal_source.py index 476b4d170c6..ad96720f640 100644 --- a/examples/decoding/decoding_spatio_temporal_source.py +++ b/examples/decoding/decoding_spatio_temporal_source.py @@ -31,42 +31,51 @@ import mne from mne.minimum_norm import apply_inverse_epochs, read_inverse_operator -from mne.decoding import (cross_val_multiscore, LinearModel, SlidingEstimator, - get_coef) +from mne.decoding import cross_val_multiscore, LinearModel, SlidingEstimator, get_coef print(__doc__) data_path = mne.datasets.sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-oct-6-fwd.fif' -fname_evoked = meg_path / 'sample_audvis-ave.fif' -subjects_dir = data_path / 'subjects' +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-oct-6-fwd.fif" +fname_evoked = meg_path / "sample_audvis-ave.fif" +subjects_dir = data_path / "subjects" # %% # Set parameters -raw_fname = meg_path / 
'sample_audvis_filt-0-40_raw.fif'
-event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
-fname_cov = meg_path / 'sample_audvis-cov.fif'
-fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif'
+raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif"
+event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
+fname_cov = meg_path / "sample_audvis-cov.fif"
+fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif"

 tmin, tmax = -0.2, 0.8
 event_id = dict(aud_r=2, vis_r=4)  # load contra-lateral conditions

 # Setup for reading the raw data
 raw = mne.io.read_raw_fif(raw_fname, preload=True)
-raw.filter(None, 10., fir_design='firwin')
+raw.filter(None, 10.0, fir_design="firwin")
 events = mne.read_events(event_fname)

 # Set up pick list: MEG - bad channels (modify to your needs)
-raw.info['bads'] += ['MEG 2443']  # mark bads
-picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=True, eog=True,
-                       exclude='bads')
+raw.info["bads"] += ["MEG 2443"]  # mark bads
+picks = mne.pick_types(
+    raw.info, meg=True, eeg=False, stim=True, eog=True, exclude="bads"
+)

 # Read epochs
-epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
-                    picks=picks, baseline=(None, 0), preload=True,
-                    reject=dict(grad=4000e-13, eog=150e-6),
-                    decim=5)  # decimate to save memory and increase speed
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_id,
+    tmin,
+    tmax,
+    proj=True,
+    picks=picks,
+    baseline=(None, 0),
+    preload=True,
+    reject=dict(grad=4000e-13, eog=150e-6),
+    decim=5,
+)  # decimate to save memory and increase speed

 # %%
 # Compute inverse solution
@@ -74,9 +83,14 @@
 noise_cov = mne.read_cov(fname_cov)
 inverse_operator = read_inverse_operator(fname_inv)

-stcs = apply_inverse_epochs(epochs, inverse_operator,
-                            lambda2=1.0 / snr ** 2, verbose=False,
-                            method="dSPM", pick_ori="normal")
+stcs = apply_inverse_epochs(
+    epochs,
+    inverse_operator,
+    lambda2=1.0 / snr**2,
+    verbose=False,
+    method="dSPM",
+    pick_ori="normal",
+)

 # %%
 # Decoding in sensor space using a logistic regression
@@ -86,19 +100,21 @@

 y = epochs.events[:, 2]

 # prepare a series of classifiers applied at each time sample
-clf = make_pipeline(StandardScaler(),  # z-score normalization
-                    SelectKBest(f_classif, k=500),  # select features for speed
-                    LinearModel(LogisticRegression(C=1, solver='liblinear')))
-time_decod = SlidingEstimator(clf, scoring='roc_auc')
+clf = make_pipeline(
+    StandardScaler(),  # z-score normalization
+    SelectKBest(f_classif, k=500),  # select features for speed
+    LinearModel(LogisticRegression(C=1, solver="liblinear")),
+)
+time_decod = SlidingEstimator(clf, scoring="roc_auc")

 # Run cross-validated decoding analyses:
 scores = cross_val_multiscore(time_decod, X, y, cv=5, n_jobs=None)

 # Plot average decoding scores of 5 splits
 fig, ax = plt.subplots(1)
-ax.plot(epochs.times, scores.mean(0), label='score')
-ax.axhline(.5, color='k', linestyle='--', label='chance')
-ax.axvline(0, color='k')
+ax.plot(epochs.times, scores.mean(0), label="score")
+ax.axhline(0.5, color="k", linestyle="--", label="chance")
+ax.axvline(0, color="k")
 plt.legend()

 # %%
@@ -109,13 +125,22 @@
 time_decod.fit(X, y)

 # Retrieve patterns after inverting the z-score normalization step:
-patterns = get_coef(time_decod, 'patterns_', inverse_transform=True)
+patterns = get_coef(time_decod, "patterns_", inverse_transform=True)

 stc = stcs[0]  # for convenience, lookup parameters from first stc
 vertices = [stc.lh_vertno, np.array([], int)]  # empty array for right hemi
-stc_feat = mne.SourceEstimate(np.abs(patterns), vertices=vertices,
-
tmin=stc.tmin, tstep=stc.tstep, subject='sample') - -brain = stc_feat.plot(views=['lat'], transparent=True, - initial_time=0.1, time_unit='s', - subjects_dir=subjects_dir) +stc_feat = mne.SourceEstimate( + np.abs(patterns), + vertices=vertices, + tmin=stc.tmin, + tstep=stc.tstep, + subject="sample", +) + +brain = stc_feat.plot( + views=["lat"], + transparent=True, + initial_time=0.1, + time_unit="s", + subjects_dir=subjects_dir, +) diff --git a/examples/decoding/decoding_spoc_CMC.py b/examples/decoding/decoding_spoc_CMC.py index f1fb8c86400..81acb0b9cc4 100644 --- a/examples/decoding/decoding_spoc_CMC.py +++ b/examples/decoding/decoding_spoc_CMC.py @@ -35,32 +35,31 @@ from sklearn.model_selection import KFold, cross_val_predict # Define parameters -fname = data_path() / 'SubjectCMC.ds' +fname = data_path() / "SubjectCMC.ds" raw = mne.io.read_raw_ctf(fname) -raw.crop(50., 200.) # crop for memory purposes +raw.crop(50.0, 200.0) # crop for memory purposes # Filter muscular activity to only keep high frequencies -emg = raw.copy().pick_channels(['EMGlft']).load_data() -emg.filter(20., None) +emg = raw.copy().pick_channels(["EMGlft"]).load_data() +emg.filter(20.0, None) # Filter MEG data to focus on beta band raw.pick_types(meg=True, ref_meg=True, eeg=False, eog=False).load_data() -raw.filter(15., 30.) +raw.filter(15.0, 30.0) # Build epochs as sliding windows over the continuous raw file events = mne.make_fixed_length_events(raw, id=1, duration=0.75) # Epoch length is 1.5 second -meg_epochs = Epochs(raw, events, tmin=0., tmax=1.5, baseline=None, - detrend=1, decim=12) -emg_epochs = Epochs(emg, events, tmin=0., tmax=1.5, baseline=None) +meg_epochs = Epochs(raw, events, tmin=0.0, tmax=1.5, baseline=None, detrend=1, decim=12) +emg_epochs = Epochs(emg, events, tmin=0.0, tmax=1.5, baseline=None) # Prepare classification X = meg_epochs.get_data() y = emg_epochs.get_data().var(axis=2)[:, 0] # target is EMG power # Classification pipeline with SPoC spatial filtering and Ridge Regression -spoc = SPoC(n_components=2, log=True, reg='oas', rank='full') +spoc = SPoC(n_components=2, log=True, reg="oas", rank="full") clf = make_pipeline(spoc, Ridge()) # Define a two fold cross-validation cv = KFold(n_splits=2, shuffle=False) @@ -71,11 +70,11 @@ # Plot the True EMG power and the EMG power predicted from MEG data fig, ax = plt.subplots(1, 1, figsize=[10, 4]) times = raw.times[meg_epochs.events[:, 0] - raw.first_samp] -ax.plot(times, y_preds, color='b', label='Predicted EMG') -ax.plot(times, y, color='r', label='True EMG') -ax.set_xlabel('Time (s)') -ax.set_ylabel('EMG Power') -ax.set_title('SPoC MEG Predictions') +ax.plot(times, y_preds, color="b", label="Predicted EMG") +ax.plot(times, y, color="r", label="True EMG") +ax.set_xlabel("Time (s)") +ax.set_ylabel("EMG Power") +ax.set_title("SPoC MEG Predictions") plt.legend() mne.viz.tight_layout() plt.show() diff --git a/examples/decoding/decoding_time_generalization_conditions.py b/examples/decoding/decoding_time_generalization_conditions.py index d39797e6561..08ca0d9c0c3 100644 --- a/examples/decoding/decoding_time_generalization_conditions.py +++ b/examples/decoding/decoding_time_generalization_conditions.py @@ -34,56 +34,78 @@ # Preprocess data data_path = sample.data_path() # Load and filter data, set up epochs -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -events_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / 
"sample_audvis_filt-0-40_raw.fif" +events_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" raw = mne.io.read_raw_fif(raw_fname, preload=True) -picks = mne.pick_types(raw.info, meg=True, exclude='bads') # Pick MEG channels -raw.filter(1., 30., fir_design='firwin') # Band pass filtering signals +picks = mne.pick_types(raw.info, meg=True, exclude="bads") # Pick MEG channels +raw.filter(1.0, 30.0, fir_design="firwin") # Band pass filtering signals events = mne.read_events(events_fname) -event_id = {'Auditory/Left': 1, 'Auditory/Right': 2, - 'Visual/Left': 3, 'Visual/Right': 4} +event_id = { + "Auditory/Left": 1, + "Auditory/Right": 2, + "Visual/Left": 3, + "Visual/Right": 4, +} tmin = -0.050 tmax = 0.400 # decimate to make the example faster to run, but then use verbose='error' in # the Epochs constructor to suppress warning about decimation causing aliasing decim = 2 -epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax, - proj=True, picks=picks, baseline=None, preload=True, - reject=dict(mag=5e-12), decim=decim, verbose='error') +epochs = mne.Epochs( + raw, + events, + event_id=event_id, + tmin=tmin, + tmax=tmax, + proj=True, + picks=picks, + baseline=None, + preload=True, + reject=dict(mag=5e-12), + decim=decim, + verbose="error", +) # %% # We will train the classifier on all left visual vs auditory trials # and test on all right visual vs auditory trials. clf = make_pipeline( StandardScaler(), - LogisticRegression(solver='liblinear') # liblinear is faster than lbfgs + LogisticRegression(solver="liblinear"), # liblinear is faster than lbfgs ) -time_gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=None, - verbose=True) +time_gen = GeneralizingEstimator(clf, scoring="roc_auc", n_jobs=None, verbose=True) # Fit classifiers on the epochs where the stimulus was presented to the left. # Note that the experimental condition y indicates auditory or visual -time_gen.fit(X=epochs['Left'].get_data(), - y=epochs['Left'].events[:, 2] > 2) +time_gen.fit(X=epochs["Left"].get_data(), y=epochs["Left"].events[:, 2] > 2) # %% # Score on the epochs where the stimulus was presented to the right. 
-scores = time_gen.score(X=epochs['Right'].get_data(),
-                        y=epochs['Right'].events[:, 2] > 2)
+scores = time_gen.score(
+    X=epochs["Right"].get_data(), y=epochs["Right"].events[:, 2] > 2
+)

 # %%
 # Plot

 fig, ax = plt.subplots(constrained_layout=True)
-im = ax.matshow(scores, vmin=0, vmax=1., cmap='RdBu_r', origin='lower',
-                extent=epochs.times[[0, -1, 0, -1]])
-ax.axhline(0., color='k')
-ax.axvline(0., color='k')
-ax.xaxis.set_ticks_position('bottom')
-ax.set_xlabel('Condition: "Right"\nTesting Time (s)',)
+im = ax.matshow(
+    scores,
+    vmin=0,
+    vmax=1.0,
+    cmap="RdBu_r",
+    origin="lower",
+    extent=epochs.times[[0, -1, 0, -1]],
+)
+ax.axhline(0.0, color="k")
+ax.axvline(0.0, color="k")
+ax.xaxis.set_ticks_position("bottom")
+ax.set_xlabel(
+    'Condition: "Right"\nTesting Time (s)',
+)
 ax.set_ylabel('Condition: "Left"\nTraining Time (s)')
-ax.set_title('Generalization across time and condition', fontweight='bold')
-fig.colorbar(im, ax=ax, label='Performance (ROC AUC)')
+ax.set_title("Generalization across time and condition", fontweight="bold")
+fig.colorbar(im, ax=ax, label="Performance (ROC AUC)")
 plt.show()

##############################################################################
diff --git a/examples/decoding/decoding_unsupervised_spatial_filter.py b/examples/decoding/decoding_unsupervised_spatial_filter.py
index a3fab432ace..d215203ac3c 100644
--- a/examples/decoding/decoding_unsupervised_spatial_filter.py
+++ b/examples/decoding/decoding_unsupervised_spatial_filter.py
@@ -32,22 +32,32 @@
 data_path = sample.data_path()

 # Load and filter data, set up epochs
-meg_path = data_path / 'MEG' / 'sample'
-raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif'
-event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
+meg_path = data_path / "MEG" / "sample"
+raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif"
+event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
 tmin, tmax = -0.1, 0.3
 event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

 raw = mne.io.read_raw_fif(raw_fname, preload=True)
-raw.filter(1, 20, fir_design='firwin')
+raw.filter(1, 20, fir_design="firwin")
 events = mne.read_events(event_fname)

-picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
-                       exclude='bads')
-
-epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
-                    picks=picks, baseline=None, preload=True,
-                    verbose=False)
+picks = mne.pick_types(
+    raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads"
+)
+
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_id,
+    tmin,
+    tmax,
+    proj=False,
+    picks=picks,
+    baseline=None,
+    preload=True,
+    verbose=False,
+)

 X = epochs.get_data()

@@ -55,19 +65,22 @@
 # Transform data with PCA computed on the average, i.e., the evoked response
 pca = UnsupervisedSpatialFilter(PCA(30), average=False)
 pca_data = pca.fit_transform(X)
-ev = mne.EvokedArray(np.mean(pca_data, axis=0),
-                     mne.create_info(30, epochs.info['sfreq'],
-                                     ch_types='eeg'), tmin=tmin)
-ev.plot(show=False, window_title="PCA", time_unit='s')
+ev = mne.EvokedArray(
+    np.mean(pca_data, axis=0),
+    mne.create_info(30, epochs.info["sfreq"], ch_types="eeg"),
+    tmin=tmin,
+)
+ev.plot(show=False, window_title="PCA", time_unit="s")

##############################################################################
 # Transform data with ICA computed on the raw epochs (no averaging)
-ica = UnsupervisedSpatialFilter(
-    FastICA(30, whiten='unit-variance'), average=False)
+ica = UnsupervisedSpatialFilter(FastICA(30, whiten="unit-variance"), average=False)
 ica_data = ica.fit_transform(X)
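+
+# (Both transforms return arrays of shape (n_epochs, 30, n_times); averaging
+# over the epochs axis below yields the "virtual evoked" that gets plotted.)
+print(ica_data.shape)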
-ev1 = mne.EvokedArray(np.mean(ica_data, axis=0), - mne.create_info(30, epochs.info['sfreq'], - ch_types='eeg'), tmin=tmin) -ev1.plot(show=False, window_title='ICA', time_unit='s') +ev1 = mne.EvokedArray( + np.mean(ica_data, axis=0), + mne.create_info(30, epochs.info["sfreq"], ch_types="eeg"), + tmin=tmin, +) +ev1.plot(show=False, window_title="ICA", time_unit="s") plt.show() diff --git a/examples/decoding/decoding_xdawn_eeg.py b/examples/decoding/decoding_xdawn_eeg.py index 9ec65f54976..3bdff716228 100644 --- a/examples/decoding/decoding_xdawn_eeg.py +++ b/examples/decoding/decoding_xdawn_eeg.py @@ -37,32 +37,45 @@ # %% # Set parameters and read data -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin, tmax = -0.1, 0.3 -event_id = {'Auditory/Left': 1, 'Auditory/Right': 2, - 'Visual/Left': 3, 'Visual/Right': 4} +event_id = { + "Auditory/Left": 1, + "Auditory/Right": 2, + "Visual/Left": 3, + "Visual/Right": 4, +} n_filter = 3 # Setup for reading the raw data raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(1, 20, fir_design='firwin') +raw.filter(1, 20, fir_design="firwin") events = read_events(event_fname) -picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, - exclude='bads') - -epochs = Epochs(raw, events, event_id, tmin, tmax, proj=False, - picks=picks, baseline=None, preload=True, - verbose=False) +picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads") + +epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + proj=False, + picks=picks, + baseline=None, + preload=True, + verbose=False, +) # Create classification pipeline -clf = make_pipeline(Xdawn(n_components=n_filter), - Vectorizer(), - MinMaxScaler(), - LogisticRegression(penalty='l1', solver='liblinear', - multi_class='auto')) +clf = make_pipeline( + Xdawn(n_components=n_filter), + Vectorizer(), + MinMaxScaler(), + LogisticRegression(penalty="l1", solver="liblinear", multi_class="auto"), +) # Get the labels labels = epochs.events[:, -1] @@ -77,7 +90,7 @@ preds[test] = clf.predict(epochs[test]) # Classification report -target_names = ['aud_l', 'aud_r', 'vis_l', 'vis_r'] +target_names = ["aud_l", "aud_r", "vis_l", "vis_r"] report = classification_report(labels, preds, target_names=target_names) print(report) @@ -87,21 +100,22 @@ # Plot confusion matrix fig, ax = plt.subplots(1) -im = ax.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues) -ax.set(title='Normalized Confusion matrix') +im = ax.imshow(cm_normalized, interpolation="nearest", cmap=plt.cm.Blues) +ax.set(title="Normalized Confusion matrix") fig.colorbar(im) tick_marks = np.arange(len(target_names)) plt.xticks(tick_marks, target_names, rotation=45) plt.yticks(tick_marks, target_names) fig.tight_layout() -ax.set(ylabel='True label', xlabel='Predicted label') +ax.set(ylabel="True label", xlabel="Predicted label") # %% # The ``patterns_`` attribute of a fitted Xdawn instance (here from the last # cross-validation fold) can be used for visualization. 
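+# (Whereas the spatial *filters* form the backward model used for decoding,
+# the *patterns* form the forward model: each pattern is the scalp topography
+# that a component projects to, which is what makes it physiologically
+# interpretable; cf. Haufe et al. 2014.)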
-fig, axes = plt.subplots(nrows=len(event_id), ncols=n_filter,
-                         figsize=(n_filter, len(event_id) * 2))
+fig, axes = plt.subplots(
+    nrows=len(event_id), ncols=n_filter, figsize=(n_filter, len(event_id) * 2)
+)
 fitted_xdawn = clf.steps[0][1]
 info = create_info(epochs.ch_names, 1, epochs.get_channel_types())
 info.set_montage(epochs.get_montage())
@@ -110,8 +124,12 @@
     pattern_evoked = EvokedArray(cur_patterns[:n_filter].T, info, tmin=0)
     pattern_evoked.plot_topomap(
         times=np.arange(n_filter),
-        time_format='Component %d' if ii == 0 else '', colorbar=False,
-        show_names=False, axes=axes[ii], show=False)
+        time_format="Component %d" if ii == 0 else "",
+        colorbar=False,
+        show_names=False,
+        axes=axes[ii],
+        show=False,
+    )
     axes[ii, 0].set(ylabel=cur_class)
 fig.tight_layout(h_pad=1.0, w_pad=1.0, pad=0.1)

diff --git a/examples/decoding/ems_filtering.py b/examples/decoding/ems_filtering.py
index 8807bf57079..34b3bcf8489 100644
--- a/examples/decoding/ems_filtering.py
+++ b/examples/decoding/ems_filtering.py
@@ -39,24 +39,33 @@
 data_path = sample.data_path()

 # Preprocess the data
-meg_path = data_path / 'MEG' / 'sample'
-raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif'
-event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
-event_ids = {'AudL': 1, 'VisL': 3}
+meg_path = data_path / "MEG" / "sample"
+raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif"
+event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
+event_ids = {"AudL": 1, "VisL": 3}

 # Read data and create epochs
 raw = io.read_raw_fif(raw_fname, preload=True)
-raw.filter(0.5, 45, fir_design='firwin')
+raw.filter(0.5, 45, fir_design="firwin")
 events = mne.read_events(event_fname)

-picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True,
-                       exclude='bads')
-
-epochs = mne.Epochs(raw, events, event_ids, tmin=-0.2, tmax=0.5, picks=picks,
-                    baseline=None, reject=dict(grad=4000e-13, eog=150e-6),
-                    preload=True)
+picks = mne.pick_types(
+    raw.info, meg="grad", eeg=False, stim=False, eog=True, exclude="bads"
+)
+
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_ids,
+    tmin=-0.2,
+    tmax=0.5,
+    picks=picks,
+    baseline=None,
+    reject=dict(grad=4000e-13, eog=150e-6),
+    preload=True,
+)
 epochs.drop_bad()
-epochs.pick_types(meg='grad')
+epochs.pick_types(meg="grad")

 # Set up the data to use it in a scikit-learn way:
 X = epochs.get_data()  # The MEG data
@@ -98,23 +107,27 @@

 # Plot individual trials
 plt.figure()
-plt.title('single trial surrogates')
-plt.imshow(X_transform[y.argsort()], origin='lower', aspect='auto',
-           extent=[epochs.times[0], epochs.times[-1], 1, len(X_transform)],
-           cmap='RdBu_r')
-plt.xlabel('Time (ms)')
-plt.ylabel('Trials (reordered by condition)')
+plt.title("single trial surrogates")
+plt.imshow(
+    X_transform[y.argsort()],
+    origin="lower",
+    aspect="auto",
+    extent=[epochs.times[0], epochs.times[-1], 1, len(X_transform)],
+    cmap="RdBu_r",
+)
+plt.xlabel("Time (ms)")
+plt.ylabel("Trials (reordered by condition)")

 # Plot average response
 plt.figure()
-plt.title('Average EMS signal')
+plt.title("Average EMS signal")
 mappings = [(key, value) for key, value in event_ids.items()]
 for key, value in mappings:
     ems_ave = X_transform[y == value]
     plt.plot(epochs.times, ems_ave.mean(0), label=key)
-plt.xlabel('Time (ms)')
-plt.ylabel('a.u.')
-plt.legend(loc='best')
+plt.xlabel("Time (ms)")
+plt.ylabel("a.u.")
+plt.legend(loc="best")
 plt.show()

 # Visualize spatial filters across time
diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py
index 
f708503214b..1786df4a4b8 100644 --- a/examples/decoding/linear_model_patterns.py +++ b/examples/decoding/linear_model_patterns.py @@ -37,23 +37,24 @@ print(__doc__) data_path = sample.data_path() -sample_path = data_path / 'MEG' / 'sample' +sample_path = data_path / "MEG" / "sample" # %% # Set parameters -raw_fname = sample_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = sample_path / 'sample_audvis_filt-0-40_raw-eve.fif' +raw_fname = sample_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = sample_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin, tmax = -0.1, 0.4 event_id = dict(aud_l=1, vis_l=3) # Setup for reading the raw data raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(.5, 25, fir_design='firwin') +raw.filter(0.5, 25, fir_design="firwin") events = mne.read_events(event_fname) # Read epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, - decim=2, baseline=None, preload=True) +epochs = mne.Epochs( + raw, events, event_id, tmin, tmax, proj=True, decim=2, baseline=None, preload=True +) del raw labels = epochs.events[:, -1] @@ -66,7 +67,7 @@ # Decoding in sensor space using a LogisticRegression classifier # -------------------------------------------------------------- -clf = LogisticRegression(solver='liblinear') # liblinear is faster than lbfgs +clf = LogisticRegression(solver="liblinear") # liblinear is faster than lbfgs scaler = StandardScaler() # create a linear model with LogisticRegression @@ -77,7 +78,7 @@ model.fit(X, labels) # Extract and plot spatial filters and spatial patterns -for name, coef in (('patterns', model.patterns_), ('filters', model.filters_)): +for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)): # We fitted the linear model onto Z-scored data. To make the filters # interpretable, we must reverse this normalization step coef = scaler.inverse_transform([coef])[0] @@ -89,7 +90,7 @@ # Plot evoked = EvokedArray(coef, meg_epochs.info, tmin=epochs.tmin) fig = evoked.plot_topomap() - fig.suptitle(f'MEG {name}') + fig.suptitle(f"MEG {name}") # %% # Let's do the same on EEG data using a scikit-learn pipeline @@ -99,22 +100,22 @@ # Define a unique pipeline to sequentially: clf = make_pipeline( - Vectorizer(), # 1) vectorize across time and channels - StandardScaler(), # 2) normalize features across trials - LinearModel( # 3) fits a logistic regression - LogisticRegression(solver='liblinear') - ) + Vectorizer(), # 1) vectorize across time and channels + StandardScaler(), # 2) normalize features across trials + LinearModel( # 3) fits a logistic regression + LogisticRegression(solver="liblinear") + ), ) clf.fit(X, y) # Extract and plot patterns and filters -for name in ('patterns_', 'filters_'): +for name in ("patterns_", "filters_"): # The `inverse_transform` parameter will call this method on any estimator # contained in the pipeline, in reverse order. 
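+    # Concretely, this undoes the ``StandardScaler`` z-scoring and the
+    # ``Vectorizer`` reshaping, so ``coef`` comes back in channels-by-times
+    # space rather than as one flat feature vector.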
coef = get_coef(clf, name, inverse_transform=True) evoked = EvokedArray(coef, epochs.info, tmin=epochs.tmin) fig = evoked.plot_topomap() - fig.suptitle(f'EEG {name[:-1]}') + fig.suptitle(f"EEG {name[:-1]}") # %% # References diff --git a/examples/decoding/receptive_field_mtrf.py b/examples/decoding/receptive_field_mtrf.py index 4e948613dbb..0773811f1f3 100644 --- a/examples/decoding/receptive_field_mtrf.py +++ b/examples/decoding/receptive_field_mtrf.py @@ -52,25 +52,25 @@ path = mne.datasets.mtrf.data_path() decim = 2 -data = loadmat(join(path, 'speech_data.mat')) -raw = data['EEG'].T -speech = data['envelope'].T -sfreq = float(data['Fs']) +data = loadmat(join(path, "speech_data.mat")) +raw = data["EEG"].T +speech = data["envelope"].T +sfreq = float(data["Fs"]) sfreq /= decim -speech = mne.filter.resample(speech, down=decim, npad='auto') -raw = mne.filter.resample(raw, down=decim, npad='auto') +speech = mne.filter.resample(speech, down=decim, npad="auto") +raw = mne.filter.resample(raw, down=decim, npad="auto") # Read in channel positions and create our MNE objects from the raw data -montage = mne.channels.make_standard_montage('biosemi128') -info = mne.create_info(montage.ch_names, sfreq, 'eeg').set_montage(montage) +montage = mne.channels.make_standard_montage("biosemi128") +info = mne.create_info(montage.ch_names, sfreq, "eeg").set_montage(montage) raw = mne.io.RawArray(raw, info) n_channels = len(raw.ch_names) # Plot a sample of brain and stimulus activity fig, ax = plt.subplots() -lns = ax.plot(scale(raw[:, :800][0].T), color='k', alpha=.1) -ln1 = ax.plot(scale(speech[0, :800]), color='r', lw=2) -ax.legend([lns[0], ln1[0]], ['EEG', 'Speech Envelope'], frameon=False) +lns = ax.plot(scale(raw[:, :800][0].T), color="k", alpha=0.1) +ln1 = ax.plot(scale(speech[0, :800]), color="r", lw=2) +ax.legend([lns[0], ln1[0]], ["EEG", "Speech Envelope"], frameon=False) ax.set(title="Sample activity", xlabel="Time (s)") mne.viz.tight_layout() @@ -83,11 +83,12 @@ # us to make predictions about the response to new stimuli. # Define the delays that we will use in the receptive field -tmin, tmax = -.2, .4 +tmin, tmax = -0.2, 0.4 # Initialize the model -rf = ReceptiveField(tmin, tmax, sfreq, feature_names=['envelope'], - estimator=1., scoring='corrcoef') +rf = ReceptiveField( + tmin, tmax, sfreq, feature_names=["envelope"], estimator=1.0, scoring="corrcoef" +) # We'll have (tmax - tmin) * sfreq delays # and an extra 2 delays since we are inclusive on the beginning / end index n_delays = int((tmax - tmin) * sfreq) + 2 @@ -104,7 +105,7 @@ coefs = np.zeros((n_splits, n_channels, n_delays)) scores = np.zeros((n_splits, n_channels)) for ii, (train, test) in enumerate(cv.split(speech)): - print('split %s / %s' % (ii + 1, n_splits)) + print("split %s / %s" % (ii + 1, n_splits)) rf.fit(speech[train], Y[train]) scores[ii] = rf.score(speech[test], Y[test]) # coef_ is shape (n_outputs, n_features, n_delays). we only have 1 feature @@ -119,7 +120,7 @@ fig, ax = plt.subplots() ix_chs = np.arange(n_channels) ax.plot(ix_chs, mean_scores) -ax.axhline(0, ls='--', color='r') +ax.axhline(0, ls="--", color="r") ax.set(title="Mean prediction score", xlabel="Channel", ylabel="Score ($r$)") mne.viz.tight_layout() @@ -135,20 +136,33 @@ time_plot = 0.180 # For highlighting a specific time. 
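+
+# (Assuming ``mean_coefs`` is the across-fold average of the ``coefs`` filled
+# in above, it has one row per channel and one column per delay:)
+print(mean_coefs.shape)  # (n_channels, n_delays)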
fig, ax = plt.subplots(figsize=(4, 8)) max_coef = mean_coefs.max() -ax.pcolormesh(times, ix_chs, mean_coefs, cmap='RdBu_r', - vmin=-max_coef, vmax=max_coef, shading='gouraud') -ax.axvline(time_plot, ls='--', color='k', lw=2) -ax.set(xlabel='Delay (s)', ylabel='Channel', title="Mean Model\nCoefficients", - xlim=times[[0, -1]], ylim=[len(ix_chs) - 1, 0], - xticks=np.arange(tmin, tmax + .2, .2)) +ax.pcolormesh( + times, + ix_chs, + mean_coefs, + cmap="RdBu_r", + vmin=-max_coef, + vmax=max_coef, + shading="gouraud", +) +ax.axvline(time_plot, ls="--", color="k", lw=2) +ax.set( + xlabel="Delay (s)", + ylabel="Channel", + title="Mean Model\nCoefficients", + xlim=times[[0, -1]], + ylim=[len(ix_chs) - 1, 0], + xticks=np.arange(tmin, tmax + 0.2, 0.2), +) plt.setp(ax.get_xticklabels(), rotation=45) mne.viz.tight_layout() # Make a topographic map of coefficients for a given delay (see Fig 2C) ix_plot = np.argmin(np.abs(time_plot - times)) fig, ax = plt.subplots() -mne.viz.plot_topomap(mean_coefs[:, ix_plot], pos=info, axes=ax, show=False, - vlim=(-max_coef, max_coef)) +mne.viz.plot_topomap( + mean_coefs[:, ix_plot], pos=info, axes=ax, show=False, vlim=(-max_coef, max_coef) +) ax.set(title="Topomap of model coefficients\nfor delay %s" % time_plot) mne.viz.tight_layout() @@ -174,15 +188,22 @@ # positive lags would index how a unit change in the amplitude of the EEG would # affect later stimulus activity (obviously this should have an amplitude of # zero). -tmin, tmax = -.2, 0. +tmin, tmax = -0.2, 0.0 # Initialize the model. Here the features are the EEG data. We also specify # ``patterns=True`` to compute inverse-transformed coefficients during model # fitting (cf. next section and :footcite:`HaufeEtAl2014`). # We'll use a ridge regression estimator with an alpha value similar to # Crosse et al. -sr = ReceptiveField(tmin, tmax, sfreq, feature_names=raw.ch_names, - estimator=1e4, scoring='corrcoef', patterns=True) +sr = ReceptiveField( + tmin, + tmax, + sfreq, + feature_names=raw.ch_names, + estimator=1e4, + scoring="corrcoef", + patterns=True, +) # We'll have (tmax - tmin) * sfreq delays # and an extra 2 delays since we are inclusive on the beginning / end index n_delays = int((tmax - tmin) * sfreq) + 2 @@ -195,7 +216,7 @@ patterns = coefs.copy() scores = np.zeros((n_splits,)) for ii, (train, test) in enumerate(cv.split(speech)): - print('split %s / %s' % (ii + 1, n_splits)) + print("split %s / %s" % (ii + 1, n_splits)) sr.fit(Y[train], speech[train]) scores[ii] = sr.score(Y[test], speech[test])[0] # coef_ is shape (n_outputs, n_features, n_delays). We have 128 features @@ -218,14 +239,15 @@ # stimulus envelopes side by side. 
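+# (``sr.valid_samples_`` below masks out the edge samples whose delay window
+# is incomplete, keeping prediction and envelope aligned.)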
y_pred = sr.predict(Y[test]) -time = np.linspace(0, 2., 5 * int(sfreq)) +time = np.linspace(0, 2.0, 5 * int(sfreq)) fig, ax = plt.subplots(figsize=(8, 4)) -ax.plot(time, speech[test][sr.valid_samples_][:int(5 * sfreq)], - color='grey', lw=2, ls='--') -ax.plot(time, y_pred[sr.valid_samples_][:int(5 * sfreq)], color='r', lw=2) -ax.legend([lns[0], ln1[0]], ['Envelope', 'Reconstruction'], frameon=False) +ax.plot( + time, speech[test][sr.valid_samples_][: int(5 * sfreq)], color="grey", lw=2, ls="--" +) +ax.plot(time, y_pred[sr.valid_samples_][: int(5 * sfreq)], color="r", lw=2) +ax.legend([lns[0], ln1[0]], ["Envelope", "Reconstruction"], frameon=False) ax.set(title="Stimulus reconstruction") -ax.set_xlabel('Time (s)') +ax.set_xlabel("Time (s)") mne.viz.tight_layout() # %% @@ -243,21 +265,33 @@ # interpretation as their value (and sign) directly relates to the stimulus # signal's strength (and effect direction). -time_plot = (-.140, -.125) # To average between two timepoints. -ix_plot = np.arange(np.argmin(np.abs(time_plot[0] - times)), - np.argmin(np.abs(time_plot[1] - times))) +time_plot = (-0.140, -0.125) # To average between two timepoints. +ix_plot = np.arange( + np.argmin(np.abs(time_plot[0] - times)), np.argmin(np.abs(time_plot[1] - times)) +) fig, ax = plt.subplots(1, 2) -mne.viz.plot_topomap(np.mean(mean_coefs[:, ix_plot], axis=1), - pos=info, axes=ax[0], show=False, - vlim=(-max_coef, max_coef)) -ax[0].set(title="Model coefficients\nbetween delays %s and %s" - % (time_plot[0], time_plot[1])) - -mne.viz.plot_topomap(np.mean(mean_patterns[:, ix_plot], axis=1), - pos=info, axes=ax[1], - show=False, vlim=(-max_patterns, max_patterns)) -ax[1].set(title="Inverse-transformed coefficients\nbetween delays %s and %s" - % (time_plot[0], time_plot[1])) +mne.viz.plot_topomap( + np.mean(mean_coefs[:, ix_plot], axis=1), + pos=info, + axes=ax[0], + show=False, + vlim=(-max_coef, max_coef), +) +ax[0].set( + title="Model coefficients\nbetween delays %s and %s" % (time_plot[0], time_plot[1]) +) + +mne.viz.plot_topomap( + np.mean(mean_patterns[:, ix_plot], axis=1), + pos=info, + axes=ax[1], + show=False, + vlim=(-max_patterns, max_patterns), +) +ax[1].set( + title="Inverse-transformed coefficients\nbetween delays %s and %s" + % (time_plot[0], time_plot[1]) +) mne.viz.tight_layout() # %% diff --git a/examples/decoding/ssd_spatial_filters.py b/examples/decoding/ssd_spatial_filters.py index 723667c1864..6be80b9667c 100644 --- a/examples/decoding/ssd_spatial_filters.py +++ b/examples/decoding/ssd_spatial_filters.py @@ -28,11 +28,11 @@ # %% # Define parameters -fname = data_path() / 'SubjectCMC.ds' +fname = data_path() / "SubjectCMC.ds" # Prepare data raw = mne.io.read_raw_ctf(fname) -raw.crop(50., 110.).load_data() # crop for memory purposes +raw.crop(50.0, 110.0).load_data() # crop for memory purposes raw.resample(sfreq=250) raw.pick_types(meg=True, eeg=False, ref_meg=False) @@ -41,13 +41,23 @@ freqs_noise = 8, 13 -ssd = SSD(info=raw.info, - reg='oas', - sort_by_spectral_ratio=False, # False for purpose of example. - filt_params_signal=dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=1, h_trans_bandwidth=1), - filt_params_noise=dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=1, h_trans_bandwidth=1)) +ssd = SSD( + info=raw.info, + reg="oas", + sort_by_spectral_ratio=False, # False for purpose of example. 
+ filt_params_signal=dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ), + filt_params_noise=dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ), +) ssd.fit(X=raw.get_data()) @@ -58,9 +68,8 @@ # (W^{-1}) or by multiplying the noise cov with the filters Eq. (22) (C_n W)^t. # We rely on the inversion approach here. -pattern = mne.EvokedArray(data=ssd.patterns_[:4].T, - info=ssd.info) -pattern.plot_topomap(units=dict(mag='A.U.'), time_format='') +pattern = mne.EvokedArray(data=ssd.patterns_[:4].T, info=ssd.info) +pattern.plot_topomap(units=dict(mag="A.U."), time_format="") # The topographies suggest that we picked up a parietal alpha generator. @@ -69,7 +78,8 @@ # Get psd of SSD-filtered signals. psd, freqs = mne.time_frequency.psd_array_welch( - ssd_sources, sfreq=raw.info['sfreq'], n_fft=4096) + ssd_sources, sfreq=raw.info["sfreq"], n_fft=4096 +) # Get spec_ratio information (already sorted). # Note that this is not necessary if sort_by_spectral_ratio=True (default). @@ -77,12 +87,12 @@ # Plot spectral ratio (see Eq. 24 in Nikulin 2011). fig, ax = plt.subplots(1) -ax.plot(spec_ratio, color='black') -ax.plot(spec_ratio[sorter], color='orange', label='sorted eigenvalues') +ax.plot(spec_ratio, color="black") +ax.plot(spec_ratio[sorter], color="orange", label="sorted eigenvalues") ax.set_xlabel("Eigenvalue Index") ax.set_ylabel(r"Spectral Ratio $\frac{P_f}{P_{sf}}$") ax.legend() -ax.axhline(1, linestyle='--') +ax.axhline(1, linestyle="--") # We can see that the initial sorting based on the eigenvalues # was already quite good. However, when using few components only @@ -96,12 +106,12 @@ # for highlighting the freq. band of interest bandfilt = (freqs_sig[0] <= freqs) & (freqs <= freqs_sig[1]) fig, ax = plt.subplots(1) -ax.loglog(freqs[below50], psd[0, below50], label='max SNR') -ax.loglog(freqs[below50], psd[-1, below50], label='min SNR') -ax.loglog(freqs[below50], psd[:, below50].mean(axis=0), label='mean') -ax.fill_between(freqs[bandfilt], 0, 10000, color='green', alpha=0.15) -ax.set_xlabel('log(frequency)') -ax.set_ylabel('log(power)') +ax.loglog(freqs[below50], psd[0, below50], label="max SNR") +ax.loglog(freqs[below50], psd[-1, below50], label="min SNR") +ax.loglog(freqs[below50], psd[:, below50].mean(axis=0), label="mean") +ax.fill_between(freqs[bandfilt], 0, 10000, color="green", alpha=0.15) +ax.set_xlabel("log(frequency)") +ax.set_ylabel("log(power)") ax.legend() # We can clearly see that the selected component enjoys an SNR that is @@ -117,25 +127,29 @@ events = mne.make_fixed_length_events(raw, id=1, duration=5.0, overlap=0.0) # Epoch length is 5 seconds. -epochs = Epochs(raw, events, tmin=0., tmax=5, - baseline=None, preload=True) - -ssd_epochs = SSD(info=epochs.info, - reg='oas', - filt_params_signal=dict(l_freq=freqs_sig[0], - h_freq=freqs_sig[1], - l_trans_bandwidth=1, - h_trans_bandwidth=1), - filt_params_noise=dict(l_freq=freqs_noise[0], - h_freq=freqs_noise[1], - l_trans_bandwidth=1, - h_trans_bandwidth=1)) +epochs = Epochs(raw, events, tmin=0.0, tmax=5, baseline=None, preload=True) + +ssd_epochs = SSD( + info=epochs.info, + reg="oas", + filt_params_signal=dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ), + filt_params_noise=dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ), +) ssd_epochs.fit(X=epochs.get_data()) # Plot topographies. 
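+# (Here ``sort_by_spectral_ratio`` was left at its default; assuming that
+# default is True, the first four patterns below are already those with the
+# highest spectral ratio.)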
-pattern_epochs = mne.EvokedArray(data=ssd_epochs.patterns_[:4].T, - info=ssd_epochs.info) -pattern_epochs.plot_topomap(units=dict(mag='A.U.'), time_format='') +pattern_epochs = mne.EvokedArray(data=ssd_epochs.patterns_[:4].T, info=ssd_epochs.info) +pattern_epochs.plot_topomap(units=dict(mag="A.U."), time_format="") # %% # References # ---------- diff --git a/examples/forward/forward_sensitivity_maps.py b/examples/forward/forward_sensitivity_maps.py index e17e8e38c12..dca41bb9b12 100644 --- a/examples/forward/forward_sensitivity_maps.py +++ b/examples/forward/forward_sensitivity_maps.py @@ -28,80 +28,84 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -subjects_dir = data_path / 'subjects' +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +subjects_dir = data_path / "subjects" # Read the forward solutions with surface orientation fwd = mne.read_forward_solution(fwd_fname) mne.convert_forward_solution(fwd, surf_ori=True, copy=False) -leadfield = fwd['sol']['data'] +leadfield = fwd["sol"]["data"] print("Leadfield size : %d x %d" % leadfield.shape) # %% # Compute sensitivity maps -grad_map = mne.sensitivity_map(fwd, ch_type='grad', mode='fixed') -mag_map = mne.sensitivity_map(fwd, ch_type='mag', mode='fixed') -eeg_map = mne.sensitivity_map(fwd, ch_type='eeg', mode='fixed') +grad_map = mne.sensitivity_map(fwd, ch_type="grad", mode="fixed") +mag_map = mne.sensitivity_map(fwd, ch_type="mag", mode="fixed") +eeg_map = mne.sensitivity_map(fwd, ch_type="eeg", mode="fixed") # %% # Show gain matrix a.k.a. leadfield matrix with sensitivity map -picks_meg = mne.pick_types(fwd['info'], meg=True, eeg=False) -picks_eeg = mne.pick_types(fwd['info'], meg=False, eeg=True) +picks_meg = mne.pick_types(fwd["info"], meg=True, eeg=False) +picks_eeg = mne.pick_types(fwd["info"], meg=False, eeg=True) fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True) -fig.suptitle('Lead field matrix (500 dipoles only)', fontsize=14) -for ax, picks, ch_type in zip(axes, [picks_meg, picks_eeg], ['meg', 'eeg']): - im = ax.imshow(leadfield[picks, :500], origin='lower', aspect='auto', - cmap='RdBu_r') +fig.suptitle("Lead field matrix (500 dipoles only)", fontsize=14) +for ax, picks, ch_type in zip(axes, [picks_meg, picks_eeg], ["meg", "eeg"]): + im = ax.imshow(leadfield[picks, :500], origin="lower", aspect="auto", cmap="RdBu_r") ax.set_title(ch_type.upper()) - ax.set_xlabel('sources') - ax.set_ylabel('sensors') + ax.set_xlabel("sources") + ax.set_ylabel("sensors") fig.colorbar(im, ax=ax) fig_2, ax = plt.subplots() -ax.hist([grad_map.data.ravel(), mag_map.data.ravel(), eeg_map.data.ravel()], - bins=20, label=['Gradiometers', 'Magnetometers', 'EEG'], - color=['c', 'b', 'k']) +ax.hist( + [grad_map.data.ravel(), mag_map.data.ravel(), eeg_map.data.ravel()], + bins=20, + label=["Gradiometers", "Magnetometers", "EEG"], + color=["c", "b", "k"], +) fig_2.legend() -ax.set(title='Normal orientation sensitivity', - xlabel='sensitivity', ylabel='count') +ax.set(title="Normal orientation sensitivity", xlabel="sensitivity", ylabel="count") # sphinx_gallery_thumbnail_number = 3 brain_sens = grad_map.plot( - subjects_dir=subjects_dir, clim=dict(lims=[0, 50, 100]), figure=1) -brain_sens.add_text(0.1, 0.9, 'Gradiometer sensitivity', 'title', font_size=16) + subjects_dir=subjects_dir, clim=dict(lims=[0, 50, 100]), figure=1 +) +brain_sens.add_text(0.1, 0.9, "Gradiometer sensitivity", "title", 
font_size=16)

# %%
# Compare sensitivity map with distribution of source depths

# source space with vertices
-src = fwd['src']
+src = fwd["src"]

# Compute minimum Euclidean distances between vertices and MEG sensors
-depths = compute_distance_to_sensors(src=src, info=fwd['info'],
-                                     picks=picks_meg).min(axis=1)
+depths = compute_distance_to_sensors(src=src, info=fwd["info"], picks=picks_meg).min(
+    axis=1
+)
maxdep = depths.max()  # for scaling

-vertices = [src[0]['vertno'], src[1]['vertno']]
+vertices = [src[0]["vertno"], src[1]["vertno"]]

-depths_map = SourceEstimate(data=depths, vertices=vertices, tmin=0.,
-                            tstep=1.)
+depths_map = SourceEstimate(data=depths, vertices=vertices, tmin=0.0, tstep=1.0)

brain_dep = depths_map.plot(
-    subject='sample', subjects_dir=subjects_dir,
-    clim=dict(kind='value', lims=[0, maxdep / 2., maxdep]), figure=2)
-brain_dep.add_text(0.1, 0.9, 'Source depth (m)', 'title', font_size=16)
+    subject="sample",
+    subjects_dir=subjects_dir,
+    clim=dict(kind="value", lims=[0, maxdep / 2.0, maxdep]),
+    figure=2,
+)
+brain_dep.add_text(0.1, 0.9, "Source depth (m)", "title", font_size=16)

# %%
# Sensitivity is likely to co-vary with the distance between sources to
# sensors. To determine the strength of this relationship, we can compute the
# correlation between source depth and sensitivity values.

corr = np.corrcoef(depths, grad_map.data[:, 0])[0, 1]
-print('Correlation between source depth and gradiomter sensitivity values: %f.'
-      % corr)
+print("Correlation between source depth and gradiometer sensitivity values: %f." % corr)

# %%
# Gradiometer sensitiviy is highest close to the sensors, and decreases rapidly
diff --git a/examples/forward/left_cerebellum_volume_source.py b/examples/forward/left_cerebellum_volume_source.py
index c8327100f10..e74b71c6c4f 100644
--- a/examples/forward/left_cerebellum_volume_source.py
+++ b/examples/forward/left_cerebellum_volume_source.py
@@ -23,9 +23,9 @@
print(__doc__)

data_path = sample.data_path()
-subjects_dir = data_path / 'subjects'
-subject = 'sample'
-aseg_fname = subjects_dir / 'sample' / 'mri' / 'aseg.mgz'
+subjects_dir = data_path / "subjects"
+subject = "sample"
+aseg_fname = subjects_dir / "sample" / "mri" / "aseg.mgz"

# %%
# Setup the source spaces
@@ -35,11 +35,16 @@
lh_surf = surf[0]

# setup a volume source space of the left cerebellum cortex
-volume_label = 'Left-Cerebellum-Cortex'
+volume_label = "Left-Cerebellum-Cortex"
sphere = (0, 0, 0, 0.12)
lh_cereb = setup_volume_source_space(
-    subject, mri=aseg_fname, sphere=sphere, volume_label=volume_label,
-    subjects_dir=subjects_dir, sphere_units='m')
+    subject,
+    mri=aseg_fname,
+    sphere=sphere,
+    volume_label=volume_label,
+    subjects_dir=subjects_dir,
+    sphere_units="m",
+)

# Combine the source spaces
src = surf + lh_cereb
@@ -47,11 +52,16 @@

# %%
# Plot the positions of each source space

-fig = mne.viz.plot_alignment(subject=subject, subjects_dir=subjects_dir,
-                             surfaces='white', coord_frame='mri',
-                             src=src)
-mne.viz.set_3d_view(fig, azimuth=180, elevation=90,
-                    distance=0.30, focalpoint=(-0.03, -0.01, 0.03))
+fig = mne.viz.plot_alignment(
+    subject=subject,
+    subjects_dir=subjects_dir,
+    surfaces="white",
+    coord_frame="mri",
+    src=src,
+)
+mne.viz.set_3d_view(
+    fig, azimuth=180, elevation=90, distance=0.30, focalpoint=(-0.03, -0.01, 0.03)
+)

# %%
# You can export source positions to a NIfTI file::
diff --git a/examples/forward/source_space_morphing.py b/examples/forward/source_space_morphing.py
index 77688705e97..5085e629615 100644
--- 
a/examples/forward/source_space_morphing.py +++ b/examples/forward/source_space_morphing.py @@ -24,38 +24,38 @@ import mne data_path = mne.datasets.sample.data_path() -subjects_dir = data_path / 'subjects' -fname_trans = ( - data_path / 'MEG' / 'sample' / 'sample_audvis_raw-trans.fif') -fname_bem = ( - subjects_dir / 'sample' / 'bem' / 'sample-5120-bem-sol.fif') -fname_src_fs = ( - subjects_dir / 'fsaverage' / 'bem' / 'fsaverage-ico-5-src.fif') -raw_fname = data_path / 'MEG' / 'sample' / 'sample_audvis_raw.fif' +subjects_dir = data_path / "subjects" +fname_trans = data_path / "MEG" / "sample" / "sample_audvis_raw-trans.fif" +fname_bem = subjects_dir / "sample" / "bem" / "sample-5120-bem-sol.fif" +fname_src_fs = subjects_dir / "fsaverage" / "bem" / "fsaverage-ico-5-src.fif" +raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif" # Get relevant channel information info = mne.io.read_info(raw_fname) -info = mne.pick_info(info, mne.pick_types(info, meg=True, eeg=False, - exclude=[])) +info = mne.pick_info(info, mne.pick_types(info, meg=True, eeg=False, exclude=[])) # Morph fsaverage's source space to sample src_fs = mne.read_source_spaces(fname_src_fs) -src_morph = mne.morph_source_spaces(src_fs, subject_to='sample', - subjects_dir=subjects_dir) +src_morph = mne.morph_source_spaces( + src_fs, subject_to="sample", subjects_dir=subjects_dir +) # Compute the forward with our morphed source space -fwd = mne.make_forward_solution(info, trans=fname_trans, - src=src_morph, bem=fname_bem) -mag_map = mne.sensitivity_map(fwd, ch_type='mag') +fwd = mne.make_forward_solution(info, trans=fname_trans, src=src_morph, bem=fname_bem) +mag_map = mne.sensitivity_map(fwd, ch_type="mag") # Return this SourceEstimate (on sample's surfaces) to fsaverage's surfaces mag_map_fs = mag_map.to_original_src(src_fs, subjects_dir=subjects_dir) # Plot the result, which tracks the sulcal-gyral folding # outliers may occur, we'll place the cutoff at 99 percent. -kwargs = dict(clim=dict(kind='percent', lims=[0, 50, 99]), - # no smoothing, let's see the dipoles on the cortex. - smoothing_steps=1, hemi='rh', views=['lat']) +kwargs = dict( + clim=dict(kind="percent", lims=[0, 50, 99]), + # no smoothing, let's see the dipoles on the cortex. + smoothing_steps=1, + hemi="rh", + views=["lat"], +) # Now note that the dipoles on fsaverage are almost equidistant while # morphing will distribute the dipoles unevenly across the given subject's @@ -63,7 +63,9 @@ # Our testing code suggests a correlation of higher than 0.99. 
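# A rough way to reproduce that check (a sketch only; it assumes the
# morphed and remapped estimates keep matching vertex order, which this
# example does not itself verify):
#
#   import numpy as np
#   corr = np.corrcoef(mag_map.data.ravel(), mag_map_fs.data.ravel())[0, 1]
#   print(f"correlation: {corr:.3f}")  # expect > 0.99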
brain_subject = mag_map.plot( # plot forward in subject source space (morphed) - time_label='Morphed', subjects_dir=subjects_dir, **kwargs) + time_label="Morphed", subjects_dir=subjects_dir, **kwargs +) brain_fs = mag_map_fs.plot( # plot forward in original source space (remapped) - time_label='Remapped', subjects_dir=subjects_dir, **kwargs) + time_label="Remapped", subjects_dir=subjects_dir, **kwargs +) diff --git a/examples/inverse/compute_mne_inverse_epochs_in_label.py b/examples/inverse/compute_mne_inverse_epochs_in_label.py index e78b37c17fe..e779444f6cf 100644 --- a/examples/inverse/compute_mne_inverse_epochs_in_label.py +++ b/examples/inverse/compute_mne_inverse_epochs_in_label.py @@ -25,18 +25,18 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' -fname_raw = meg_path / 'sample_audvis_filt-0-40_raw.fif' -fname_event = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' -label_name = 'Aud-lh' -fname_label = meg_path / 'labels' / f'{label_name}.label' +meg_path = data_path / "MEG" / "sample" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_raw = meg_path / "sample_audvis_filt-0-40_raw.fif" +fname_event = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" +label_name = "Aud-lh" +fname_label = meg_path / "labels" / f"{label_name}.label" event_id, tmin, tmax = 1, -0.2, 0.5 # Using the same inverse operator when inspecting single trials Vs. evoked snr = 3.0 # Standard assumption for average data but using it for single trial -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) @@ -50,15 +50,23 @@ include = [] # Add a bad channel -raw.info['bads'] += ['EEG 053'] # bads + 1 more +raw.info["bads"] += ["EEG 053"] # bads + 1 more # pick MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - include=include, exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, include=include, exclude="bads" +) # Read epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13, - eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) # Get evoked data (averaging across trials in sensor space) evoked = epochs.average() @@ -66,21 +74,27 @@ # Compute inverse solution and stcs for each epoch # Use the same inverse operator as with evoked data (i.e., set nave) # If you use a different nave, dSPM just scales by a factor sqrt(nave) -stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, label, - pick_ori="normal", nave=evoked.nave) +stcs = apply_inverse_epochs( + epochs, + inverse_operator, + lambda2, + method, + label, + pick_ori="normal", + nave=evoked.nave, +) # Mean across trials but not across vertices in label mean_stc = sum(stcs) / len(stcs) # compute sign flip to avoid signal cancellation when averaging signed values -flip = mne.label_sign_flip(label, inverse_operator['src']) +flip = mne.label_sign_flip(label, inverse_operator["src"]) label_mean = np.mean(mean_stc.data, axis=0) label_mean_flip = np.mean(flip[:, np.newaxis] * mean_stc.data, axis=0) # Get inverse solution by inverting evoked data -stc_evoked = apply_inverse(evoked, inverse_operator, lambda2, method, - pick_ori="normal") +stc_evoked = apply_inverse(evoked, inverse_operator, lambda2, 
method, pick_ori="normal") # apply_inverse() does whole brain, so sub-select label of interest stc_evoked_label = stc_evoked.in_label(label) @@ -94,13 +108,12 @@ times = 1e3 * stcs[0].times # times in ms plt.figure() -h0 = plt.plot(times, mean_stc.data.T, 'k') -h1, = plt.plot(times, label_mean, 'r', linewidth=3) -h2, = plt.plot(times, label_mean_flip, 'g', linewidth=3) -plt.legend((h0[0], h1, h2), ('all dipoles in label', 'mean', - 'mean with sign flip')) -plt.xlabel('time (ms)') -plt.ylabel('dSPM value') +h0 = plt.plot(times, mean_stc.data.T, "k") +(h1,) = plt.plot(times, label_mean, "r", linewidth=3) +(h2,) = plt.plot(times, label_mean_flip, "g", linewidth=3) +plt.legend((h0[0], h1, h2), ("all dipoles in label", "mean", "mean with sign flip")) +plt.xlabel("time (ms)") +plt.ylabel("dSPM value") plt.show() # %% @@ -110,19 +123,21 @@ # Single trial plt.figure() for k, stc_trial in enumerate(stcs): - plt.plot(times, np.mean(stc_trial.data, axis=0).T, 'k--', - label='Single Trials' if k == 0 else '_nolegend_', - alpha=0.5) + plt.plot( + times, + np.mean(stc_trial.data, axis=0).T, + "k--", + label="Single Trials" if k == 0 else "_nolegend_", + alpha=0.5, + ) # Single trial inverse then average.. making linewidth large to not be masked -plt.plot(times, label_mean, 'b', linewidth=6, - label='dSPM first, then average') +plt.plot(times, label_mean, "b", linewidth=6, label="dSPM first, then average") # Evoked and then inverse -plt.plot(times, label_mean_evoked, 'r', linewidth=2, - label='Average first, then dSPM') +plt.plot(times, label_mean_evoked, "r", linewidth=2, label="Average first, then dSPM") -plt.xlabel('time (ms)') -plt.ylabel('dSPM value') +plt.xlabel("time (ms)") +plt.ylabel("dSPM value") plt.legend() plt.show() diff --git a/examples/inverse/compute_mne_inverse_raw_in_label.py b/examples/inverse/compute_mne_inverse_raw_in_label.py index 1d473f2db1f..5c15563f76a 100644 --- a/examples/inverse/compute_mne_inverse_raw_in_label.py +++ b/examples/inverse/compute_mne_inverse_raw_in_label.py @@ -25,14 +25,13 @@ print(__doc__) data_path = sample.data_path() -fname_inv = ( - data_path / 'MEG' / 'sample' / 'sample_audvis-meg-oct-6-meg-inv.fif') -fname_raw = data_path / 'MEG' / 'sample' / 'sample_audvis_raw.fif' -label_name = 'Aud-lh' -fname_label = data_path / 'MEG' / 'sample' / 'labels' / f'{label_name}.label' +fname_inv = data_path / "MEG" / "sample" / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_raw = data_path / "MEG" / "sample" / "sample_audvis_raw.fif" +label_name = "Aud-lh" +fname_label = data_path / "MEG" / "sample" / "labels" / f"{label_name}.label" snr = 1.0 # use smaller SNR for raw data -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "sLORETA" # use sLORETA method (could also be MNE or dSPM) # Load data @@ -40,19 +39,20 @@ inverse_operator = read_inverse_operator(fname_inv) label = mne.read_label(fname_label) -raw.set_eeg_reference('average', projection=True) # set average reference. +raw.set_eeg_reference("average", projection=True) # set average reference. 
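# For reference, the lambda2 above follows the usual 1 / SNR**2 rule of
# thumb, e.g. (illustrative values only):
#
#   for assumed_snr in (3.0, 1.0):
#       print(f"SNR {assumed_snr:g} -> lambda2 = {1.0 / assumed_snr**2:.3f}")
#
# so assuming SNR = 1 for unaveraged raw data regularizes nine times more
# heavily than the SNR = 3 convention for evoked averages.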
start, stop = raw.time_as_index([0, 15]) # read the first 15s of data # Compute inverse solution -stc = apply_inverse_raw(raw, inverse_operator, lambda2, method, label, - start, stop, pick_ori=None) +stc = apply_inverse_raw( + raw, inverse_operator, lambda2, method, label, start, stop, pick_ori=None +) # Save result in stc files -stc.save('mne_%s_raw_inverse_%s' % (method, label_name), overwrite=True) +stc.save("mne_%s_raw_inverse_%s" % (method, label_name), overwrite=True) # %% # View activation time-series plt.plot(1e3 * stc.times, stc.data[::100, :].T) -plt.xlabel('time (ms)') -plt.ylabel('%s value' % method) +plt.xlabel("time (ms)") +plt.ylabel("%s value" % method) plt.show() diff --git a/examples/inverse/compute_mne_inverse_volume.py b/examples/inverse/compute_mne_inverse_volume.py index 215977ca393..7b5193a081b 100644 --- a/examples/inverse/compute_mne_inverse_volume.py +++ b/examples/inverse/compute_mne_inverse_volume.py @@ -24,33 +24,36 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname_inv = meg_path / 'sample_audvis-meg-vol-7-meg-inv.fif' -fname_evoked = meg_path / 'sample_audvis-ave.fif' +meg_path = data_path / "MEG" / "sample" +fname_inv = meg_path / "sample_audvis-meg-vol-7-meg-inv.fif" +fname_evoked = meg_path / "sample_audvis-ave.fif" snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) # Load data evoked = read_evokeds(fname_evoked, condition=0, baseline=(None, 0)) inverse_operator = read_inverse_operator(fname_inv) -src = inverse_operator['src'] +src = inverse_operator["src"] # Compute inverse solution stc = apply_inverse(evoked, inverse_operator, lambda2, method) stc.crop(0.0, 0.2) # Export result as a 4D nifti object -img = stc.as_volume(src, - mri_resolution=False) # set True for full MRI resolution +img = stc.as_volume(src, mri_resolution=False) # set True for full MRI resolution # Save it as a nifti file # nib.save(img, 'mne_%s_inverse.nii.gz' % method) -t1_fname = data_path / 'subjects' / 'sample' / 'mri' / 'T1.mgz' +t1_fname = data_path / "subjects" / "sample" / "mri" / "T1.mgz" # %% # Plot with nilearn: -plot_stat_map(index_img(img, 61), str(t1_fname), threshold=8., - title='%s (t=%.1f s.)' % (method, stc.times[61])) +plot_stat_map( + index_img(img, 61), + str(t1_fname), + threshold=8.0, + title="%s (t=%.1f s.)" % (method, stc.times[61]), +) diff --git a/examples/inverse/custom_inverse_solver.py b/examples/inverse/custom_inverse_solver.py index 760ef4408e5..3324b1198bb 100644 --- a/examples/inverse/custom_inverse_solver.py +++ b/examples/inverse/custom_inverse_solver.py @@ -29,12 +29,12 @@ data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ave_fname = meg_path / 'sample_audvis-ave.fif' -cov_fname = meg_path / 'sample_audvis-shrunk-cov.fif' -subjects_dir = data_path / 'subjects' -condition = 'Left Auditory' +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +ave_fname = meg_path / "sample_audvis-ave.fif" +cov_fname = meg_path / "sample_audvis-shrunk-cov.fif" +subjects_dir = data_path / "subjects" +condition = "Left Auditory" # Read noise covariance matrix noise_cov = mne.read_cov(cov_fname) @@ -50,6 +50,7 @@ # %% # Auxiliary function to run the solver + def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8): """Call a custom solver on evoked data. 
@@ -93,19 +94,30 @@ def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8): The source estimates. """ # Import the necessary private functions - from mne.inverse_sparse.mxne_inverse import \ - (_prepare_gain, is_fixed_orient, - _reapply_source_weighting, _make_sparse_stc) + from mne.inverse_sparse.mxne_inverse import ( + _prepare_gain, + is_fixed_orient, + _reapply_source_weighting, + _make_sparse_stc, + ) all_ch_names = evoked.ch_names # Handle depth weighting and whitening (here is no weights) forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain( - forward, evoked.info, noise_cov, pca=False, depth=depth, - loose=loose, weights=None, weights_min=None, rank=None) + forward, + evoked.info, + noise_cov, + pca=False, + depth=depth, + loose=loose, + weights=None, + weights_min=None, + rank=None, + ) # Select channels of interest - sel = [all_ch_names.index(name) for name in gain_info['ch_names']] + sel = [all_ch_names.index(name) for name in gain_info["ch_names"]] M = evoked.data[sel] # Whiten data @@ -115,8 +127,9 @@ def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8): X, active_set = solver(M, gain, n_orient) X = _reapply_source_weighting(X, source_weighting, active_set) - stc = _make_sparse_stc(X, active_set, forward, tmin=evoked.times[0], - tstep=1. / evoked.info['sfreq']) + stc = _make_sparse_stc( + X, active_set, forward, tmin=evoked.times[0], tstep=1.0 / evoked.info["sfreq"] + ) return stc @@ -124,6 +137,7 @@ def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8): # %% # Define your solver + def solver(M, G, n_orient): """Run L2 penalized regression and keep 10 strongest locations. @@ -155,11 +169,11 @@ def solver(M, G, n_orient): K /= np.linalg.norm(K, axis=1)[:, None] X = np.dot(K, M) - indices = np.argsort(np.sum(X ** 2, axis=1))[-10:] + indices = np.argsort(np.sum(X**2, axis=1))[-10:] active_set = np.zeros(G.shape[1], dtype=bool) for idx in indices: idx -= idx % n_orient - active_set[idx:idx + n_orient] = True + active_set[idx : idx + n_orient] = True X = X[active_set] return X, active_set @@ -168,10 +182,9 @@ def solver(M, G, n_orient): # Apply your custom solver # loose, depth = 0.2, 0.8 # corresponds to loose orientation -loose, depth = 1., 0. # corresponds to free orientation +loose, depth = 1.0, 0.0 # corresponds to free orientation stc = apply_solver(solver, evoked, forward, noise_cov, loose, depth) # %% # View in 2D and 3D ("glass" brain like 3D plot) -plot_sparse_source_estimates(forward['src'], stc, bgcolor=(1, 1, 1), - opacity=0.1) +plot_sparse_source_estimates(forward["src"], stc, bgcolor=(1, 1, 1), opacity=0.1) diff --git a/examples/inverse/dics_epochs.py b/examples/inverse/dics_epochs.py index 8aba68b9e44..dc8a0b7e14c 100644 --- a/examples/inverse/dics_epochs.py +++ b/examples/inverse/dics_epochs.py @@ -28,13 +28,13 @@ # Organize the data that we will use for this example. 
data_path = somato.data_path() -subject = '01' -task = 'somato' -raw_fname = (data_path / f'sub-{subject}' / 'meg' / - f'sub-{subject}_task-{task}_meg.fif') -fname_fwd = (data_path / 'derivatives' / f'sub-{subject}' / - f'sub-{subject}_task-{task}-fwd.fif') -subjects_dir = data_path / 'derivatives' / 'freesurfer' / 'subjects' +subject = "01" +task = "somato" +raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif" +fname_fwd = ( + data_path / "derivatives" / f"sub-{subject}" / f"sub-{subject}_task-{task}-fwd.fif" +) +subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects" # %% # First, we load the data and compute for each epoch the time-frequency @@ -43,11 +43,19 @@ # Load raw data and make epochs. raw = mne.io.read_raw_fif(raw_fname) events = mne.find_events(raw) -epochs = mne.Epochs(raw, events, event_id=1, tmin=-1, tmax=2.5, - reject=dict(grad=5000e-13, # unit: T / m (gradiometers) - mag=5e-12, # unit: T (magnetometers) - eog=250e-6, # unit: V (EOG channels) - ), preload=True) +epochs = mne.Epochs( + raw, + events, + event_id=1, + tmin=-1, + tmax=2.5, + reject=dict( + grad=5000e-13, # unit: T / m (gradiometers) + mag=5e-12, # unit: T (magnetometers) + eog=250e-6, # unit: V (EOG channels) + ), + preload=True, +) epochs = epochs[:10] # just for speed of execution for the tutorial # We are mostly interested in the beta band since it has been shown to be @@ -58,8 +66,9 @@ # decomposition for each epoch. We must pass ``output='complex'`` if we wish to # use this TFR later with a DICS beamformer. We also pass ``average=False`` to # compute the TFR for each individual epoch. -epochs_tfr = tfr_morlet(epochs, freqs, n_cycles=5, return_itc=False, - output='complex', average=False) +epochs_tfr = tfr_morlet( + epochs, freqs, n_cycles=5, return_itc=False, output="complex", average=False +) # crop either side to use a buffer to remove edge artifact epochs_tfr.crop(tmin=-0.5, tmax=2) @@ -78,15 +87,21 @@ fwd = mne.read_forward_solution(fname_fwd) # compute scalar DICS beamfomer -filters = make_dics(epochs.info, fwd, csd, noise_csd=baseline_csd, - pick_ori='max-power', reduce_rank=True, real_filter=True) +filters = make_dics( + epochs.info, + fwd, + csd, + noise_csd=baseline_csd, + pick_ori="max-power", + reduce_rank=True, + real_filter=True, +) # project the TFR for each epoch to source space -epochs_stcs = apply_dics_tfr_epochs( - epochs_tfr, filters, return_generator=True) +epochs_stcs = apply_dics_tfr_epochs(epochs_tfr, filters, return_generator=True) # average across frequencies and epochs -data = np.zeros((fwd['nsource'], epochs_tfr.times.size)) +data = np.zeros((fwd["nsource"], epochs_tfr.times.size)) for epoch_stcs in epochs_stcs: for stc in epoch_stcs: data += (stc.data * np.conj(stc.data)).real @@ -104,13 +119,17 @@ fmax = 4500 brain = stc.plot( subjects_dir=subjects_dir, - hemi='both', - views='dorsal', + hemi="both", + views="dorsal", initial_time=0.55, brain_kwargs=dict(show=False), - add_data_kwargs=dict(fmin=fmax / 10, fmid=fmax / 2, fmax=fmax, - scale_factor=0.0001, - colorbar_kwargs=dict(label_font_size=10)) + add_data_kwargs=dict( + fmin=fmax / 10, + fmid=fmax / 2, + fmax=fmax, + scale_factor=0.0001, + colorbar_kwargs=dict(label_font_size=10), + ), ) # You can save a movie like the one on our documentation website with: diff --git a/examples/inverse/dics_source_power.py b/examples/inverse/dics_source_power.py index 8a3ee2c1cf6..68925202b17 100644 --- a/examples/inverse/dics_source_power.py +++ b/examples/inverse/dics_source_power.py @@ 
-31,10 +31,9 @@ # %% # Reading the raw data and creating epochs: data_path = somato.data_path() -subject = '01' -task = 'somato' -raw_fname = (data_path / f'sub-{subject}' / 'meg' / - f'sub-{subject}_task-{task}_meg.fif') +subject = "01" +task = "somato" +raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif" # Use a shorter segment of raw just for speed here raw = mne.io.read_raw_fif(raw_fname) @@ -47,10 +46,11 @@ del raw # Paths to forward operator and FreeSurfer subject directory -fname_fwd = (data_path / 'derivatives' / f'sub-{subject}' / - f'sub-{subject}_task-{task}-fwd.fif') +fname_fwd = ( + data_path / "derivatives" / f"sub-{subject}" / f"sub-{subject}_task-{task}-fwd.fif" +) -subjects_dir = data_path / 'derivatives' / 'freesurfer' / 'subjects' +subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects" # %% # We are interested in the beta band. Define a range of frequencies, using a @@ -79,8 +79,15 @@ # Computing DICS spatial filters using the CSD that was computed on the entire # timecourse. fwd = mne.read_forward_solution(fname_fwd) -filters = make_dics(info, fwd, csd, noise_csd=csd_baseline, - pick_ori='max-power', reduce_rank=True, real_filter=True) +filters = make_dics( + info, + fwd, + csd, + noise_csd=csd_baseline, + pick_ori="max-power", + reduce_rank=True, + real_filter=True, +) del fwd # %% @@ -92,9 +99,14 @@ # %% # Visualizing source power during ERS activity relative to the baseline power. stc = beta_source_power / baseline_source_power -message = 'DICS source power in the 12-30 Hz frequency band' -brain = stc.plot(hemi='both', views='axial', subjects_dir=subjects_dir, - subject=subject, time_label=message) +message = "DICS source power in the 12-30 Hz frequency band" +brain = stc.plot( + hemi="both", + views="axial", + subjects_dir=subjects_dir, + subject=subject, + time_label=message, +) # %% # References diff --git a/examples/inverse/evoked_ers_source_power.py b/examples/inverse/evoked_ers_source_power.py index b3ccaab5e04..272b0518293 100644 --- a/examples/inverse/evoked_ers_source_power.py +++ b/examples/inverse/evoked_ers_source_power.py @@ -22,19 +22,22 @@ from mne.cov import compute_covariance from mne.datasets import somato from mne.time_frequency import csd_morlet -from mne.beamformer import (make_dics, apply_dics_csd, make_lcmv, - apply_lcmv_cov) -from mne.minimum_norm import (make_inverse_operator, apply_inverse_cov) +from mne.beamformer import make_dics, apply_dics_csd, make_lcmv, apply_lcmv_cov +from mne.minimum_norm import make_inverse_operator, apply_inverse_cov print(__doc__) # %% # Reading the raw data and creating epochs: data_path = somato.data_path() -subject = '01' -task = 'somato' -raw_fname = (data_path / 'sub-{}'.format(subject) / 'meg' / - 'sub-{}_task-{}_meg.fif'.format(subject, task)) +subject = "01" +task = "somato" +raw_fname = ( + data_path + / "sub-{}".format(subject) + / "meg" + / "sub-{}_task-{}_meg.fif".format(subject, task) +) # crop to 5 minutes to save memory raw = mne.io.read_raw_fif(raw_fname).crop(0, 300) @@ -44,17 +47,22 @@ # The DICS beamformer currently only supports a single sensor type. # We'll use the gradiometers in this example. 
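# (Any single sensor type would work here; a hypothetical magnetometer-only
# run could use picks = mne.pick_types(raw.info, meg="mag", exclude="bads").
# It is mixing sensor types within one DICS filter that is unsupported.)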
-picks = mne.pick_types(raw.info, meg='grad', exclude='bads') +picks = mne.pick_types(raw.info, meg="grad", exclude="bads") # Read epochs events = mne.find_events(raw) -epochs = mne.Epochs(raw, events, event_id=1, tmin=-1.5, tmax=2, picks=picks, - preload=True, decim=3) +epochs = mne.Epochs( + raw, events, event_id=1, tmin=-1.5, tmax=2, picks=picks, preload=True, decim=3 +) # Read forward operator and point to freesurfer subject directory -fname_fwd = (data_path / 'derivatives' / 'sub-{}'.format(subject) / - 'sub-{}_task-{}-fwd.fif'.format(subject, task)) -subjects_dir = data_path / 'derivatives' / 'freesurfer' / 'subjects' +fname_fwd = ( + data_path + / "derivatives" + / "sub-{}".format(subject) + / "sub-{}_task-{}-fwd.fif".format(subject, task) +) +subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects" fwd = mne.read_forward_solution(fname_fwd) @@ -68,14 +76,25 @@ # combination with an advanced covariance estimator like "shrunk", the rank # will be correctly preserved. -rank = mne.compute_rank(epochs, tol=1e-6, tol_kind='relative') +rank = mne.compute_rank(epochs, tol=1e-6, tol_kind="relative") active_win = (0.5, 1.5) baseline_win = (-1, 0) -baseline_cov = compute_covariance(epochs, tmin=baseline_win[0], - tmax=baseline_win[1], method='shrunk', - rank=rank, verbose=True) -active_cov = compute_covariance(epochs, tmin=active_win[0], tmax=active_win[1], - method='shrunk', rank=rank, verbose=True) +baseline_cov = compute_covariance( + epochs, + tmin=baseline_win[0], + tmax=baseline_win[1], + method="shrunk", + rank=rank, + verbose=True, +) +active_cov = compute_covariance( + epochs, + tmin=active_win[0], + tmax=active_win[1], + method="shrunk", + rank=rank, + verbose=True, +) # Weighted averaging is already in the addition of covariance objects. 
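# Conceptually, the result is a mean of the two matrices weighted by the
# amount of data behind each covariance. A toy sketch with made-up numbers:
#
#   import numpy as np
#   C_base, n_base = np.eye(2), 100
#   C_act, n_act = 2 * np.eye(2), 300
#   C_common = (n_base * C_base + n_act * C_act) / (n_base + n_act)
#   # -> 1.75 * identity, pulled toward the better-sampled active period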
common_cov = baseline_cov + active_cov @@ -93,12 +112,21 @@ def _gen_dics(active_win, baseline_win, epochs): freqs = np.logspace(np.log10(12), np.log10(30), 9) csd = csd_morlet(epochs, freqs, tmin=-1, tmax=1.5, decim=20) - csd_baseline = csd_morlet(epochs, freqs, tmin=baseline_win[0], - tmax=baseline_win[1], decim=20) - csd_ers = csd_morlet(epochs, freqs, tmin=active_win[0], tmax=active_win[1], - decim=20) - filters = make_dics(epochs.info, fwd, csd.mean(), pick_ori='max-power', - reduce_rank=True, real_filter=True, rank=rank) + csd_baseline = csd_morlet( + epochs, freqs, tmin=baseline_win[0], tmax=baseline_win[1], decim=20 + ) + csd_ers = csd_morlet( + epochs, freqs, tmin=active_win[0], tmax=active_win[1], decim=20 + ) + filters = make_dics( + epochs.info, + fwd, + csd.mean(), + pick_ori="max-power", + reduce_rank=True, + real_filter=True, + rank=rank, + ) stc_base, freqs = apply_dics_csd(csd_baseline.mean(), filters) stc_act, freqs = apply_dics_csd(csd_ers.mean(), filters) stc_act /= stc_base @@ -107,8 +135,9 @@ def _gen_dics(active_win, baseline_win, epochs): # generate lcmv source estimate def _gen_lcmv(active_cov, baseline_cov, common_cov): - filters = make_lcmv(epochs.info, fwd, common_cov, reg=0.05, - noise_cov=None, pick_ori='max-power') + filters = make_lcmv( + epochs.info, fwd, common_cov, reg=0.05, noise_cov=None, pick_ori="max-power" + ) stc_base = apply_lcmv_cov(baseline_cov, filters) stc_act = apply_lcmv_cov(active_cov, filters) stc_act /= stc_base @@ -116,12 +145,14 @@ def _gen_lcmv(active_cov, baseline_cov, common_cov): # generate mne/dSPM source estimate -def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method='dSPM'): +def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method="dSPM"): inverse_operator = make_inverse_operator(info, fwd, common_cov) - stc_act = apply_inverse_cov(active_cov, info, inverse_operator, - method=method, verbose=True) - stc_base = apply_inverse_cov(baseline_cov, info, inverse_operator, - method=method, verbose=True) + stc_act = apply_inverse_cov( + active_cov, info, inverse_operator, method=method, verbose=True + ) + stc_base = apply_inverse_cov( + baseline_cov, info, inverse_operator, method=method, verbose=True + ) stc_act /= stc_base return stc_act @@ -137,22 +168,31 @@ def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method='dSPM'): # DICS: brain_dics = stc_dics.plot( - hemi='rh', subjects_dir=subjects_dir, subject=subject, - time_label='DICS source power in the 12-30 Hz frequency band') + hemi="rh", + subjects_dir=subjects_dir, + subject=subject, + time_label="DICS source power in the 12-30 Hz frequency band", +) # %% # LCMV: brain_lcmv = stc_lcmv.plot( - hemi='rh', subjects_dir=subjects_dir, subject=subject, - time_label='LCMV source power in the 12-30 Hz frequency band') + hemi="rh", + subjects_dir=subjects_dir, + subject=subject, + time_label="LCMV source power in the 12-30 Hz frequency band", +) # %% # dSPM: brain_dspm = stc_dspm.plot( - hemi='rh', subjects_dir=subjects_dir, subject=subject, - time_label='dSPM source power in the 12-30 Hz frequency band') + hemi="rh", + subjects_dir=subjects_dir, + subject=subject, + time_label="dSPM source power in the 12-30 Hz frequency band", +) # %% # For more advanced usage, see diff --git a/examples/inverse/gamma_map_inverse.py b/examples/inverse/gamma_map_inverse.py index 20a205c3322..f3ff529a331 100644 --- a/examples/inverse/gamma_map_inverse.py +++ b/examples/inverse/gamma_map_inverse.py @@ -19,22 +19,24 @@ import mne from mne.datasets import sample from 
mne.inverse_sparse import gamma_map, make_stc_from_dipoles -from mne.viz import (plot_sparse_source_estimates, - plot_dipole_locations, plot_dipole_amplitudes) +from mne.viz import ( + plot_sparse_source_estimates, + plot_dipole_locations, + plot_dipole_amplitudes, +) print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -evoked_fname = meg_path / 'sample_audvis-ave.fif' -cov_fname = meg_path / 'sample_audvis-cov.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +evoked_fname = meg_path / "sample_audvis-ave.fif" +cov_fname = meg_path / "sample_audvis-cov.fif" # Read the evoked response and crop it -condition = 'Left visual' -evoked = mne.read_evokeds(evoked_fname, condition=condition, - baseline=(None, 0)) +condition = "Left visual" +evoked = mne.read_evokeds(evoked_fname, condition=condition, baseline=(None, 0)) evoked.crop(tmin=-50e-3, tmax=300e-3) # Read the forward solution @@ -47,8 +49,14 @@ # Run the Gamma-MAP method with dipole output alpha = 0.5 dipoles, residual = gamma_map( - evoked, forward, cov, alpha, xyz_same_gamma=True, return_residual=True, - return_as_dipoles=True) + evoked, + forward, + cov, + alpha, + xyz_same_gamma=True, + return_residual=True, + return_as_dipoles=True, +) # %% # Plot dipole activations @@ -56,9 +64,14 @@ # Plot dipole location of the strongest dipole with MRI slices idx = np.argmax([np.max(np.abs(dip.amplitude)) for dip in dipoles]) -plot_dipole_locations(dipoles[idx], forward['mri_head_t'], 'sample', - subjects_dir=subjects_dir, mode='orthoview', - idx='amplitude') +plot_dipole_locations( + dipoles[idx], + forward["mri_head_t"], + "sample", + subjects_dir=subjects_dir, + mode="orthoview", + idx="amplitude", +) # # Plot dipole locations of all dipoles with MRI slices # for dip in dipoles: @@ -69,17 +82,22 @@ # %% # Show the evoked response and the residual for gradiometers ylim = dict(grad=[-120, 120]) -evoked.pick_types(meg='grad', exclude='bads') -evoked.plot(titles=dict(grad='Evoked Response Gradiometers'), ylim=ylim, - proj=True, time_unit='s') - -residual.pick_types(meg='grad', exclude='bads') -residual.plot(titles=dict(grad='Residuals Gradiometers'), ylim=ylim, - proj=True, time_unit='s') +evoked.pick_types(meg="grad", exclude="bads") +evoked.plot( + titles=dict(grad="Evoked Response Gradiometers"), + ylim=ylim, + proj=True, + time_unit="s", +) + +residual.pick_types(meg="grad", exclude="bads") +residual.plot( + titles=dict(grad="Residuals Gradiometers"), ylim=ylim, proj=True, time_unit="s" +) # %% # Generate stc from dipoles -stc = make_stc_from_dipoles(dipoles, forward['src']) +stc = make_stc_from_dipoles(dipoles, forward["src"]) # %% # View in 2D and 3D ("glass" brain like 3D plot) @@ -88,9 +106,14 @@ scale_factors = 0.5 * (1 + scale_factors / np.max(scale_factors)) plot_sparse_source_estimates( - forward['src'], stc, bgcolor=(1, 1, 1), - modes=['sphere'], opacity=0.1, scale_factors=(scale_factors, None), - fig_name="Gamma-MAP") + forward["src"], + stc, + bgcolor=(1, 1, 1), + modes=["sphere"], + opacity=0.1, + scale_factors=(scale_factors, None), + fig_name="Gamma-MAP", +) # %% # References diff --git a/examples/inverse/label_activation_from_stc.py b/examples/inverse/label_activation_from_stc.py index 20368b68183..358de19bff2 100644 --- a/examples/inverse/label_activation_from_stc.py +++ 
b/examples/inverse/label_activation_from_stc.py @@ -24,15 +24,15 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" # load the stc -stc = mne.read_source_estimate(meg_path / 'sample_audvis-meg') +stc = mne.read_source_estimate(meg_path / "sample_audvis-meg") # load the labels -aud_lh = mne.read_label(meg_path / 'labels' / 'Aud-lh.label') -aud_rh = mne.read_label(meg_path / 'labels' / 'Aud-rh.label') +aud_lh = mne.read_label(meg_path / "labels" / "Aud-lh.label") +aud_rh = mne.read_label(meg_path / "labels" / "Aud-rh.label") # extract the time course for different labels from the stc stc_lh = stc.in_label(aud_lh) @@ -40,25 +40,27 @@ stc_bh = stc.in_label(aud_lh + aud_rh) # calculate center of mass and transform to mni coordinates -vtx, _, t_lh = stc_lh.center_of_mass('sample', subjects_dir=subjects_dir) -mni_lh = mne.vertex_to_mni(vtx, 0, 'sample', subjects_dir=subjects_dir)[0] -vtx, _, t_rh = stc_rh.center_of_mass('sample', subjects_dir=subjects_dir) -mni_rh = mne.vertex_to_mni(vtx, 1, 'sample', subjects_dir=subjects_dir)[0] +vtx, _, t_lh = stc_lh.center_of_mass("sample", subjects_dir=subjects_dir) +mni_lh = mne.vertex_to_mni(vtx, 0, "sample", subjects_dir=subjects_dir)[0] +vtx, _, t_rh = stc_rh.center_of_mass("sample", subjects_dir=subjects_dir) +mni_rh = mne.vertex_to_mni(vtx, 1, "sample", subjects_dir=subjects_dir)[0] # plot the activation plt.figure() -plt.axes([.1, .275, .85, .625]) -hl = plt.plot(stc.times, stc_lh.data.mean(0), 'b')[0] -hr = plt.plot(stc.times, stc_rh.data.mean(0), 'g')[0] -hb = plt.plot(stc.times, stc_bh.data.mean(0), 'r')[0] -plt.xlabel('Time (s)') -plt.ylabel('Source amplitude (dSPM)') +plt.axes([0.1, 0.275, 0.85, 0.625]) +hl = plt.plot(stc.times, stc_lh.data.mean(0), "b")[0] +hr = plt.plot(stc.times, stc_rh.data.mean(0), "g")[0] +hb = plt.plot(stc.times, stc_bh.data.mean(0), "r")[0] +plt.xlabel("Time (s)") +plt.ylabel("Source amplitude (dSPM)") plt.xlim(stc.times[0], stc.times[-1]) # add a legend including center-of-mass mni coordinates to the plot -labels = ['LH: center of mass = %s' % mni_lh.round(2), - 'RH: center of mass = %s' % mni_rh.round(2), - 'Combined LH & RH'] -plt.figlegend([hl, hr, hb], labels, loc='lower center') -plt.suptitle('Average activation in auditory cortex labels', fontsize=20) +labels = [ + "LH: center of mass = %s" % mni_lh.round(2), + "RH: center of mass = %s" % mni_rh.round(2), + "Combined LH & RH", +] +plt.figlegend([hl, hr, hb], labels, loc="lower center") +plt.suptitle("Average activation in auditory cortex labels", fontsize=20) plt.show() diff --git a/examples/inverse/label_from_stc.py b/examples/inverse/label_from_stc.py index 3d3abae2a16..39469e8c68b 100644 --- a/examples/inverse/label_from_stc.py +++ b/examples/inverse/label_from_stc.py @@ -28,29 +28,27 @@ print(__doc__) data_path = sample.data_path() -fname_inv = ( - data_path / 'MEG' / 'sample' / 'sample_audvis-meg-oct-6-meg-inv.fif') -fname_evoked = data_path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' -subjects_dir = data_path / 'subjects' -subject = 'sample' +fname_inv = data_path / "MEG" / "sample" / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_evoked = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" +subjects_dir = data_path / "subjects" +subject = "sample" snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) # Compute a label/ROI based on 
the peak power between 80 and 120 ms. # The label bankssts-lh is used for the comparison. -aparc_label_name = 'bankssts-lh' +aparc_label_name = "bankssts-lh" tmin, tmax = 0.080, 0.120 # Load data evoked = mne.read_evokeds(fname_evoked, condition=0, baseline=(None, 0)) inverse_operator = read_inverse_operator(fname_inv) -src = inverse_operator['src'] # get the source space +src = inverse_operator["src"] # get the source space # Compute inverse solution -stc = apply_inverse(evoked, inverse_operator, lambda2, method, - pick_ori='normal') +stc = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori="normal") # Make an STC in the time interval of interest and take the mean stc_mean = stc.copy().crop(tmin, tmax).mean() @@ -58,33 +56,38 @@ # use the stc_mean to generate a functional label # region growing is halted at 60% of the peak value within the # anatomical label / ROI specified by aparc_label_name -label = mne.read_labels_from_annot(subject, parc='aparc', - subjects_dir=subjects_dir, - regexp=aparc_label_name)[0] +label = mne.read_labels_from_annot( + subject, parc="aparc", subjects_dir=subjects_dir, regexp=aparc_label_name +)[0] stc_mean_label = stc_mean.in_label(label) data = np.abs(stc_mean_label.data) -stc_mean_label.data[data < 0.6 * np.max(data)] = 0. +stc_mean_label.data[data < 0.6 * np.max(data)] = 0.0 # 8.5% of original source space vertices were omitted during forward # calculation, suppress the warning here with verbose='error' -func_labels, _ = mne.stc_to_label(stc_mean_label, src=src, smooth=True, - subjects_dir=subjects_dir, connected=True, - verbose='error') +func_labels, _ = mne.stc_to_label( + stc_mean_label, + src=src, + smooth=True, + subjects_dir=subjects_dir, + connected=True, + verbose="error", +) # take first as func_labels are ordered based on maximum values in stc func_label = func_labels[0] # load the anatomical ROI for comparison -anat_label = mne.read_labels_from_annot(subject, parc='aparc', - subjects_dir=subjects_dir, - regexp=aparc_label_name)[0] +anat_label = mne.read_labels_from_annot( + subject, parc="aparc", subjects_dir=subjects_dir, regexp=aparc_label_name +)[0] # extract the anatomical time course for each label stc_anat_label = stc.in_label(anat_label) -pca_anat = stc.extract_label_time_course(anat_label, src, mode='pca_flip')[0] +pca_anat = stc.extract_label_time_course(anat_label, src, mode="pca_flip")[0] stc_func_label = stc.in_label(func_label) -pca_func = stc.extract_label_time_course(func_label, src, mode='pca_flip')[0] +pca_func = stc.extract_label_time_course(func_label, src, mode="pca_flip")[0] # flip the pca so that the max power between tmin and tmax is positive pca_anat *= np.sign(pca_anat[np.argmax(np.abs(pca_anat))]) @@ -93,18 +96,20 @@ # %% # plot the time courses.... 
plt.figure() -plt.plot(1e3 * stc_anat_label.times, pca_anat, 'k', - label='Anatomical %s' % aparc_label_name) -plt.plot(1e3 * stc_func_label.times, pca_func, 'b', - label='Functional %s' % aparc_label_name) +plt.plot( + 1e3 * stc_anat_label.times, pca_anat, "k", label="Anatomical %s" % aparc_label_name +) +plt.plot( + 1e3 * stc_func_label.times, pca_func, "b", label="Functional %s" % aparc_label_name +) plt.legend() plt.show() # %% # plot brain in 3D with mne.viz.Brain if available -brain = stc_mean.plot(hemi='lh', subjects_dir=subjects_dir) -brain.show_view('lateral') +brain = stc_mean.plot(hemi="lh", subjects_dir=subjects_dir) +brain.show_view("lateral") # show both labels -brain.add_label(anat_label, borders=True, color='k') -brain.add_label(func_label, borders=True, color='b') +brain.add_label(anat_label, borders=True, color="k") +brain.add_label(func_label, borders=True, color="b") diff --git a/examples/inverse/label_source_activations.py b/examples/inverse/label_source_activations.py index 30a55970d81..599fff4c2f8 100644 --- a/examples/inverse/label_source_activations.py +++ b/examples/inverse/label_source_activations.py @@ -28,32 +28,31 @@ print(__doc__) data_path = sample.data_path() -label = 'Aud-lh' -meg_path = data_path / 'MEG' / 'sample' -label_fname = meg_path / 'labels' / f'{label}.label' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' -fname_evoked = meg_path / 'sample_audvis-ave.fif' +label = "Aud-lh" +meg_path = data_path / "MEG" / "sample" +label_fname = meg_path / "labels" / f"{label}.label" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_evoked = meg_path / "sample_audvis-ave.fif" snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) # Load data evoked = mne.read_evokeds(fname_evoked, condition=0, baseline=(None, 0)) inverse_operator = read_inverse_operator(fname_inv) -src = inverse_operator['src'] +src = inverse_operator["src"] # %% # Compute inverse solution # ------------------------ pick_ori = "normal" # Get signed values to see the effect of sign flip -stc = apply_inverse(evoked, inverse_operator, lambda2, method, - pick_ori=pick_ori) +stc = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori=pick_ori) label = mne.read_label(label_fname) stc_label = stc.in_label(label) -modes = ('mean', 'mean_flip', 'pca_flip') +modes = ("mean", "mean_flip", "pca_flip") tcs = dict() for mode in modes: tcs[mode] = stc.extract_label_time_course(label, src, mode=mode) @@ -65,17 +64,23 @@ fig, ax = plt.subplots(1) t = 1e3 * stc_label.times -ax.plot(t, stc_label.data.T, 'k', linewidth=0.5, alpha=0.5) -pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), - path_effects.Normal()] +ax.plot(t, stc_label.data.T, "k", linewidth=0.5, alpha=0.5) +pe = [ + path_effects.Stroke(linewidth=5, foreground="w", alpha=0.5), + path_effects.Normal(), +] for mode, tc in tcs.items(): ax.plot(t, tc[0], linewidth=3, label=str(mode), path_effects=pe) xlim = t[[0, -1]] ylim = [-27, 22] -ax.legend(loc='upper right') -ax.set(xlabel='Time (ms)', ylabel='Source amplitude', - title='Activations in Label %r' % (label.name), - xlim=xlim, ylim=ylim) +ax.legend(loc="upper right") +ax.set( + xlabel="Time (ms)", + ylabel="Source amplitude", + title="Activations in Label %r" % (label.name), + xlim=xlim, + ylim=ylim, +) mne.viz.tight_layout() # %% @@ -84,21 +89,32 @@ # It's also possible to compute label time courses for a # :class:`mne.VectorSourceEstimate`, but only with ``mode='mean'``. 
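# This restriction is natural: a vector estimate stores one time course per
# orientation component, i.e. data of shape (n_vertices, 3, n_times), and
# the flip- and PCA-based modes are only defined for scalar signed data.
# 'mean' simply averages over vertices within each component, roughly:
#
#   tc = stc_vec.in_label(label).data.mean(axis=0)  # -> shape (3, n_times)
#
# where ``stc_vec`` is computed just below.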
-pick_ori = 'vector' -stc_vec = apply_inverse(evoked, inverse_operator, lambda2, method, - pick_ori=pick_ori) +pick_ori = "vector" +stc_vec = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori=pick_ori) data = stc_vec.extract_label_time_course(label, src) fig, ax = plt.subplots(1) stc_vec_label = stc_vec.in_label(label) -colors = ['#EE6677', '#228833', '#4477AA'] -for ii, name in enumerate('XYZ'): +colors = ["#EE6677", "#228833", "#4477AA"] +for ii, name in enumerate("XYZ"): color = colors[ii] - ax.plot(t, stc_vec_label.data[:, ii].T, color=color, lw=0.5, alpha=0.5, - zorder=5 - ii) - ax.plot(t, data[0, ii], lw=3, color=color, label='+' + name, zorder=8 - ii, - path_effects=pe) -ax.legend(loc='upper right') -ax.set(xlabel='Time (ms)', ylabel='Source amplitude', - title='Mean vector activations in Label %r' % (label.name,), - xlim=xlim, ylim=ylim) + ax.plot( + t, stc_vec_label.data[:, ii].T, color=color, lw=0.5, alpha=0.5, zorder=5 - ii + ) + ax.plot( + t, + data[0, ii], + lw=3, + color=color, + label="+" + name, + zorder=8 - ii, + path_effects=pe, + ) +ax.legend(loc="upper right") +ax.set( + xlabel="Time (ms)", + ylabel="Source amplitude", + title="Mean vector activations in Label %r" % (label.name,), + xlim=xlim, + ylim=ylim, +) mne.viz.tight_layout() diff --git a/examples/inverse/mixed_norm_inverse.py b/examples/inverse/mixed_norm_inverse.py index 56b64e744a1..ce8a1e74a69 100644 --- a/examples/inverse/mixed_norm_inverse.py +++ b/examples/inverse/mixed_norm_inverse.py @@ -25,22 +25,25 @@ from mne.datasets import sample from mne.inverse_sparse import mixed_norm, make_stc_from_dipoles from mne.minimum_norm import make_inverse_operator, apply_inverse -from mne.viz import (plot_sparse_source_estimates, - plot_dipole_locations, plot_dipole_amplitudes) +from mne.viz import ( + plot_sparse_source_estimates, + plot_dipole_locations, + plot_dipole_amplitudes, +) print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ave_fname = meg_path / 'sample_audvis-ave.fif' -cov_fname = meg_path / 'sample_audvis-shrunk-cov.fif' -subjects_dir = data_path / 'subjects' +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +ave_fname = meg_path / "sample_audvis-ave.fif" +cov_fname = meg_path / "sample_audvis-shrunk-cov.fif" +subjects_dir = data_path / "subjects" # Read noise covariance matrix cov = mne.read_cov(cov_fname) # Handling average file -condition = 'Left Auditory' +condition = "Left Auditory" evoked = mne.read_evokeds(ave_fname, condition=condition, baseline=(None, 0)) evoked.crop(tmin=0, tmax=0.3) # Handling forward solution @@ -54,18 +57,30 @@ # if n_mxne_iter > 1 dSPM weighting can be avoided. # Compute dSPM solution to be used as weights in MxNE -inverse_operator = make_inverse_operator(evoked.info, forward, cov, - depth=depth, fixed=True, - use_cps=True) -stc_dspm = apply_inverse(evoked, inverse_operator, lambda2=1. 
/ 9., - method='dSPM') +inverse_operator = make_inverse_operator( + evoked.info, forward, cov, depth=depth, fixed=True, use_cps=True +) +stc_dspm = apply_inverse(evoked, inverse_operator, lambda2=1.0 / 9.0, method="dSPM") # Compute (ir)MxNE inverse solution with dipole output dipoles, residual = mixed_norm( - evoked, forward, cov, alpha, loose=loose, depth=depth, maxit=3000, - tol=1e-4, active_set_size=10, debias=False, weights=stc_dspm, - weights_min=8., n_mxne_iter=n_mxne_iter, return_residual=True, - return_as_dipoles=True, verbose=True, random_state=0, + evoked, + forward, + cov, + alpha, + loose=loose, + depth=depth, + maxit=3000, + tol=1e-4, + active_set_size=10, + debias=False, + weights=stc_dspm, + weights_min=8.0, + n_mxne_iter=n_mxne_iter, + return_residual=True, + return_as_dipoles=True, + verbose=True, + random_state=0, # for this dataset we know we should use a high alpha, so avoid some # of the slower (lower) alpha values sure_alpha_grid=np.linspace(100, 40, 10), @@ -74,8 +89,7 @@ t = 0.083 tidx = evoked.time_as_index(t) for di, dip in enumerate(dipoles, 1): - print(f'Dipole #{di} GOF at {1000 * t:0.1f} ms: ' - f'{float(dip.gof[tidx]):0.1f}%') + print(f"Dipole #{di} GOF at {1000 * t:0.1f} ms: " f"{float(dip.gof[tidx]):0.1f}%") # %% # Plot dipole activations @@ -83,48 +97,70 @@ # Plot dipole location of the strongest dipole with MRI slices idx = np.argmax([np.max(np.abs(dip.amplitude)) for dip in dipoles]) -plot_dipole_locations(dipoles[idx], forward['mri_head_t'], 'sample', - subjects_dir=subjects_dir, mode='orthoview', - idx='amplitude') +plot_dipole_locations( + dipoles[idx], + forward["mri_head_t"], + "sample", + subjects_dir=subjects_dir, + mode="orthoview", + idx="amplitude", +) # Plot dipole locations of all dipoles with MRI slices for dip in dipoles: - plot_dipole_locations(dip, forward['mri_head_t'], 'sample', - subjects_dir=subjects_dir, mode='orthoview', - idx='amplitude') + plot_dipole_locations( + dip, + forward["mri_head_t"], + "sample", + subjects_dir=subjects_dir, + mode="orthoview", + idx="amplitude", + ) # %% # Plot residual ylim = dict(eeg=[-10, 10], grad=[-400, 400], mag=[-600, 600]) -evoked.pick_types(meg=True, eeg=True, exclude='bads') -evoked.plot(ylim=ylim, proj=True, time_unit='s') -residual.pick_types(meg=True, eeg=True, exclude='bads') -residual.plot(ylim=ylim, proj=True, time_unit='s') +evoked.pick_types(meg=True, eeg=True, exclude="bads") +evoked.plot(ylim=ylim, proj=True, time_unit="s") +residual.pick_types(meg=True, eeg=True, exclude="bads") +residual.plot(ylim=ylim, proj=True, time_unit="s") # %% # Generate stc from dipoles -stc = make_stc_from_dipoles(dipoles, forward['src']) +stc = make_stc_from_dipoles(dipoles, forward["src"]) # %% # View in 2D and 3D ("glass" brain like 3D plot) solver = "MxNE" if n_mxne_iter == 1 else "irMxNE" -plot_sparse_source_estimates(forward['src'], stc, bgcolor=(1, 1, 1), - fig_name="%s (cond %s)" % (solver, condition), - opacity=0.1) +plot_sparse_source_estimates( + forward["src"], + stc, + bgcolor=(1, 1, 1), + fig_name="%s (cond %s)" % (solver, condition), + opacity=0.1, +) # %% # Morph onto fsaverage brain and view -morph = mne.compute_source_morph(stc, subject_from='sample', - subject_to='fsaverage', spacing=None, - sparse=True, subjects_dir=subjects_dir) +morph = mne.compute_source_morph( + stc, + subject_from="sample", + subject_to="fsaverage", + spacing=None, + sparse=True, + subjects_dir=subjects_dir, +) stc_fsaverage = morph.apply(stc) -src_fsaverage_fname = ( - subjects_dir / 'fsaverage' / 'bem' / 
'fsaverage-ico-5-src.fif') +src_fsaverage_fname = subjects_dir / "fsaverage" / "bem" / "fsaverage-ico-5-src.fif" src_fsaverage = mne.read_source_spaces(src_fsaverage_fname) -plot_sparse_source_estimates(src_fsaverage, stc_fsaverage, bgcolor=(1, 1, 1), - fig_name="Morphed %s (cond %s)" % (solver, - condition), opacity=0.1) +plot_sparse_source_estimates( + src_fsaverage, + stc_fsaverage, + bgcolor=(1, 1, 1), + fig_name="Morphed %s (cond %s)" % (solver, condition), + opacity=0.1, +) # %% # References diff --git a/examples/inverse/mixed_source_space_inverse.py b/examples/inverse/mixed_source_space_inverse.py index f732178ea9f..9baac7da379 100644 --- a/examples/inverse/mixed_source_space_inverse.py +++ b/examples/inverse/mixed_source_space_inverse.py @@ -23,22 +23,22 @@ # Set dir data_path = mne.datasets.sample.data_path() -subject = 'sample' -data_dir = data_path / 'MEG' / subject -subjects_dir = data_path / 'subjects' -bem_dir = subjects_dir / subject / 'bem' +subject = "sample" +data_dir = data_path / "MEG" / subject +subjects_dir = data_path / "subjects" +bem_dir = subjects_dir / subject / "bem" # Set file names -fname_mixed_src = bem_dir / f'{subject}-oct-6-mixed-src.fif' -fname_aseg = subjects_dir / subject / 'mri' / 'aseg.mgz' +fname_mixed_src = bem_dir / f"{subject}-oct-6-mixed-src.fif" +fname_aseg = subjects_dir / subject / "mri" / "aseg.mgz" -fname_model = bem_dir / f'{subject}-5120-bem.fif' -fname_bem = bem_dir / f'{subject}-5120-bem-sol.fif' +fname_model = bem_dir / f"{subject}-5120-bem.fif" +fname_bem = bem_dir / f"{subject}-5120-bem-sol.fif" -fname_evoked = data_dir / f'{subject}_audvis-ave.fif' -fname_trans = data_dir / f'{subject}_audvis_raw-trans.fif' -fname_fwd = data_dir / f'{subject}_audvis-meg-oct-6-mixed-fwd.fif' -fname_cov = data_dir / f'{subject}_audvis-shrunk-cov.fif' +fname_evoked = data_dir / f"{subject}_audvis-ave.fif" +fname_trans = data_dir / f"{subject}_audvis_raw-trans.fif" +fname_fwd = data_dir / f"{subject}_audvis-meg-oct-6-mixed-fwd.fif" +fname_cov = data_dir / f"{subject}_audvis-shrunk-cov.fif" # %% # Set up our source space @@ -46,19 +46,22 @@ # List substructures we are interested in. We select only the # sub structures we want to include in the source space: -labels_vol = ['Left-Amygdala', - 'Left-Thalamus-Proper', - 'Left-Cerebellum-Cortex', - 'Brain-Stem', - 'Right-Amygdala', - 'Right-Thalamus-Proper', - 'Right-Cerebellum-Cortex'] +labels_vol = [ + "Left-Amygdala", + "Left-Thalamus-Proper", + "Left-Cerebellum-Cortex", + "Brain-Stem", + "Right-Amygdala", + "Right-Thalamus-Proper", + "Right-Cerebellum-Cortex", +] # %% # Get a surface-based source space, here with few source points for speed # in this demonstration, in general you should use oct6 spacing! 
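# (Each 'oct' step is one further octahedron subdivision, roughly
# quadrupling the source count: 'oct5' yields 1026 sources per hemisphere
# at ~9.9 mm spacing, versus 4098 sources at ~4.9 mm for 'oct6'.)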
-src = mne.setup_source_space(subject, spacing='oct5', - add_dist=False, subjects_dir=subjects_dir) +src = mne.setup_source_space( + subject, spacing="oct5", add_dist=False, subjects_dir=subjects_dir +) # %% # Now we create a mixed src space by adding the volume regions specified in the @@ -67,15 +70,22 @@ # we recommend something smaller like 5.0 in actual analyses): vol_src = mne.setup_volume_source_space( - subject, mri=fname_aseg, pos=10.0, bem=fname_model, - volume_label=labels_vol, subjects_dir=subjects_dir, + subject, + mri=fname_aseg, + pos=10.0, + bem=fname_model, + volume_label=labels_vol, + subjects_dir=subjects_dir, add_interpolator=False, # just for speed, usually this should be True - verbose=True) + verbose=True, +) # Generate the mixed source space src += vol_src -print(f"The source space contains {len(src)} spaces and " - f"{sum(s['nuse'] for s in src)} vertices") +print( + f"The source space contains {len(src)} spaces and " + f"{sum(s['nuse'] for s in src)} vertices" +) # %% # View the source space @@ -90,47 +100,54 @@ # # We can also export source positions to NIfTI file and visualize it again: -nii_fname = bem_dir / f'{subject}-mixed-src.nii' +nii_fname = bem_dir / f"{subject}-mixed-src.nii" src.export_volume(nii_fname, mri_resolution=True, overwrite=True) -plotting.plot_img(str(nii_fname), cmap='nipy_spectral') +plotting.plot_img(str(nii_fname), cmap="nipy_spectral") # %% # Compute the fwd matrix # ---------------------- fwd = mne.make_forward_solution( - fname_evoked, fname_trans, src, fname_bem, + fname_evoked, + fname_trans, + src, + fname_bem, mindist=5.0, # ignore sources<=5mm from innerskull - meg=True, eeg=False, n_jobs=None) + meg=True, + eeg=False, + n_jobs=None, +) del src # save memory -leadfield = fwd['sol']['data'] +leadfield = fwd["sol"]["data"] print("Leadfield size : %d sensors x %d dipoles" % leadfield.shape) -print(f"The fwd source space contains {len(fwd['src'])} spaces and " - f"{sum(s['nuse'] for s in fwd['src'])} vertices") +print( + f"The fwd source space contains {len(fwd['src'])} spaces and " + f"{sum(s['nuse'] for s in fwd['src'])} vertices" +) # Load data -condition = 'Left Auditory' -evoked = mne.read_evokeds(fname_evoked, condition=condition, - baseline=(None, 0)) +condition = "Left Auditory" +evoked = mne.read_evokeds(fname_evoked, condition=condition, baseline=(None, 0)) noise_cov = mne.read_cov(fname_cov) # %% # Compute inverse solution # ------------------------ -snr = 3.0 # use smaller SNR for raw data -inv_method = 'dSPM' # sLORETA, MNE, dSPM -parc = 'aparc' # the parcellation to use, e.g., 'aparc' 'aparc.a2009s' -loose = dict(surface=0.2, volume=1.) 
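# %%
# A small worked example (an added sketch; ``snr_try`` stands for the assumed
# amplitude SNR): the ``lambda2`` regularization below follows the usual
# convention lambda2 = 1 / snr**2, so snr=3 gives lambda2 = 1/9 ~= 0.111:

for snr_try in (1.0, 3.0):
    print(f"snr={snr_try:g} -> lambda2={1.0 / snr_try ** 2:.3f}")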
+snr = 3.0 # use smaller SNR for raw data +inv_method = "dSPM" # sLORETA, MNE, dSPM +parc = "aparc" # the parcellation to use, e.g., 'aparc' 'aparc.a2009s' +loose = dict(surface=0.2, volume=1.0) -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 inverse_operator = make_inverse_operator( - evoked.info, fwd, noise_cov, depth=None, loose=loose, verbose=True) + evoked.info, fwd, noise_cov, depth=None, loose=loose, verbose=True +) del fwd -stc = apply_inverse(evoked, inverse_operator, lambda2, inv_method, - pick_ori=None) -src = inverse_operator['src'] +stc = apply_inverse(evoked, inverse_operator, lambda2, inv_method, pick_ori=None) +src = inverse_operator["src"] # %% # Plot the mixed source estimate @@ -138,24 +155,30 @@ # sphinx_gallery_thumbnail_number = 3 initial_time = 0.1 -stc_vec = apply_inverse(evoked, inverse_operator, lambda2, inv_method, - pick_ori='vector') +stc_vec = apply_inverse( + evoked, inverse_operator, lambda2, inv_method, pick_ori="vector" +) brain = stc_vec.plot( - hemi='both', src=inverse_operator['src'], views='coronal', - initial_time=initial_time, subjects_dir=subjects_dir, - brain_kwargs=dict(silhouette=True), smoothing_steps=7) + hemi="both", + src=inverse_operator["src"], + views="coronal", + initial_time=initial_time, + subjects_dir=subjects_dir, + brain_kwargs=dict(silhouette=True), + smoothing_steps=7, +) # %% # Plot the surface # ---------------- -brain = stc.surface().plot(initial_time=initial_time, - subjects_dir=subjects_dir, smoothing_steps=7) +brain = stc.surface().plot( + initial_time=initial_time, subjects_dir=subjects_dir, smoothing_steps=7 +) # %% # Plot the volume # --------------- -fig = stc.volume().plot(initial_time=initial_time, src=src, - subjects_dir=subjects_dir) +fig = stc.volume().plot(initial_time=initial_time, src=src, subjects_dir=subjects_dir) # %% # Process labels @@ -164,16 +187,16 @@ # and each sub structure contained in the src space # Get labels for FreeSurfer 'aparc' cortical parcellation with 34 labels/hemi -labels_parc = mne.read_labels_from_annot( - subject, parc=parc, subjects_dir=subjects_dir) +labels_parc = mne.read_labels_from_annot(subject, parc=parc, subjects_dir=subjects_dir) label_ts = mne.extract_label_time_course( - [stc], labels_parc, src, mode='mean', allow_empty=True) + [stc], labels_parc, src, mode="mean", allow_empty=True +) # plot the times series of 2 labels fig, axes = plt.subplots(1) -axes.plot(1e3 * stc.times, label_ts[0][0, :], 'k', label='bankssts-lh') -axes.plot(1e3 * stc.times, label_ts[0][-1, :].T, 'r', label='Brain-stem') -axes.set(xlabel='Time (ms)', ylabel='MNE current (nAm)') +axes.plot(1e3 * stc.times, label_ts[0][0, :], "k", label="bankssts-lh") +axes.plot(1e3 * stc.times, label_ts[0][-1, :].T, "r", label="Brain-stem") +axes.set(xlabel="Time (ms)", ylabel="MNE current (nAm)") axes.legend() mne.viz.tight_layout() diff --git a/examples/inverse/mne_cov_power.py b/examples/inverse/mne_cov_power.py index 91fc47bc577..592664a72ef 100644 --- a/examples/inverse/mne_cov_power.py +++ b/examples/inverse/mne_cov_power.py @@ -30,9 +30,9 @@ from mne.minimum_norm import make_inverse_operator, apply_inverse_cov data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" raw = mne.io.read_raw_fif(raw_fname) # %% @@ -41,28 +41,39 @@ # First we compute an empty-room 
covariance, which captures noise from the # sensors and environment. -raw_empty_room_fname = data_path / 'MEG' / 'sample' / 'ernoise_raw.fif' +raw_empty_room_fname = data_path / "MEG" / "sample" / "ernoise_raw.fif" raw_empty_room = mne.io.read_raw_fif(raw_empty_room_fname) raw_empty_room.crop(0, 30) # cropped just for speed -raw_empty_room.info['bads'] = ['MEG 2443'] -raw_empty_room.add_proj(raw.info['projs']) -noise_cov = mne.compute_raw_covariance(raw_empty_room, method='shrunk') +raw_empty_room.info["bads"] = ["MEG 2443"] +raw_empty_room.add_proj(raw.info["projs"]) +noise_cov = mne.compute_raw_covariance(raw_empty_room, method="shrunk") del raw_empty_room # %% # Epoch the data # -------------- -raw.pick(['meg', 'stim', 'eog']).load_data().filter(4, 12) -raw.info['bads'] = ['MEG 2443'] -events = mne.find_events(raw, stim_channel='STI 014') +raw.pick(["meg", "stim", "eog"]).load_data().filter(4, 12) +raw.info["bads"] = ["MEG 2443"] +events = mne.find_events(raw, stim_channel="STI 014") event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4) tmin, tmax = -0.2, 0.5 baseline = (None, 0) # means from the first instant to t = 0 reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, - proj=True, picks=('meg', 'eog'), baseline=None, - reject=reject, preload=True, decim=5, verbose='error') +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + proj=True, + picks=("meg", "eog"), + baseline=None, + reject=reject, + preload=True, + decim=5, + verbose="error", +) del raw # %% @@ -78,9 +89,11 @@ # to noise sources). base_cov = mne.compute_covariance( - epochs, tmin=-0.2, tmax=0, method='shrunk', verbose=True) + epochs, tmin=-0.2, tmax=0, method="shrunk", verbose=True +) data_cov = mne.compute_covariance( - epochs, tmin=0., tmax=0.2, method='shrunk', verbose=True) + epochs, tmin=0.0, tmax=0.2, method="shrunk", verbose=True +) fig_noise_cov = mne.viz.plot_cov(noise_cov, epochs.info, show_svd=False) fig_base_cov = mne.viz.plot_cov(base_cov, epochs.info, show_svd=False) @@ -91,16 +104,18 @@ # baseline and data covariances, followed by the data covariance whitened # by the baseline covariance: -evoked = epochs.average().pick('meg') -evoked.drop_channels(evoked.info['bads']) -evoked.plot(time_unit='s') -evoked.plot_topomap(times=np.linspace(0.05, 0.15, 5), ch_type='mag') +evoked = epochs.average().pick("meg") +evoked.drop_channels(evoked.info["bads"]) +evoked.plot(time_unit="s") +evoked.plot_topomap(times=np.linspace(0.05, 0.15, 5), ch_type="mag") -loop = {'Noise': (noise_cov, dict()), - 'Data': (data_cov, dict()), - 'Whitened data': (data_cov, dict(noise_cov=noise_cov))} +loop = { + "Noise": (noise_cov, dict()), + "Data": (data_cov, dict()), + "Whitened data": (data_cov, dict(noise_cov=noise_cov)), +} for title, (_cov, _kw) in loop.items(): - fig = _cov.plot_topomap(evoked.info, 'grad', **_kw) + fig = _cov.plot_topomap(evoked.info, "grad", **_kw) fig.suptitle(title) # %% @@ -109,20 +124,31 @@ # Finally, we can construct an inverse using the empty-room noise covariance: # Read the forward solution and compute the inverse operator -fname_fwd = meg_path / 'sample_audvis-meg-oct-6-fwd.fif' +fname_fwd = meg_path / "sample_audvis-meg-oct-6-fwd.fif" fwd = mne.read_forward_solution(fname_fwd) # make an MEG inverse operator info = evoked.info -inverse_operator = make_inverse_operator(info, fwd, noise_cov, - loose=0.2, depth=0.8) +inverse_operator = make_inverse_operator(info, fwd, noise_cov, loose=0.2, depth=0.8) # %% # Project our data and baseline 
covariance to source space: -stc_data = apply_inverse_cov(data_cov, evoked.info, inverse_operator, - nave=len(epochs), method='dSPM', verbose=True) -stc_base = apply_inverse_cov(base_cov, evoked.info, inverse_operator, - nave=len(epochs), method='dSPM', verbose=True) +stc_data = apply_inverse_cov( + data_cov, + evoked.info, + inverse_operator, + nave=len(epochs), + method="dSPM", + verbose=True, +) +stc_base = apply_inverse_cov( + base_cov, + evoked.info, + inverse_operator, + nave=len(epochs), + method="dSPM", + verbose=True, +) # %% # And visualize power is relative to the baseline: @@ -130,6 +156,9 @@ # sphinx_gallery_thumbnail_number = 9 stc_data /= stc_base -brain = stc_data.plot(subject='sample', subjects_dir=subjects_dir, - clim=dict(kind='percent', lims=(50, 90, 98)), - smoothing_steps=7) +brain = stc_data.plot( + subject="sample", + subjects_dir=subjects_dir, + clim=dict(kind="percent", lims=(50, 90, 98)), + smoothing_steps=7, +) diff --git a/examples/inverse/morph_surface_stc.py b/examples/inverse/morph_surface_stc.py index 80a35c87ed8..0417a8d807a 100644 --- a/examples/inverse/morph_surface_stc.py +++ b/examples/inverse/morph_surface_stc.py @@ -37,19 +37,18 @@ # Setup paths data_path = sample.data_path() -sample_dir = data_path / 'MEG' / 'sample' -subjects_dir = data_path / 'subjects' -fname_src = subjects_dir / 'sample' / 'bem' / 'sample-oct-6-src.fif' -fname_fwd = sample_dir / 'sample_audvis-meg-oct-6-fwd.fif' -fname_fsaverage_src = (subjects_dir / 'fsaverage' / 'bem' / - 'fsaverage-ico-5-src.fif') -fname_stc = sample_dir / 'sample_audvis-meg' +sample_dir = data_path / "MEG" / "sample" +subjects_dir = data_path / "subjects" +fname_src = subjects_dir / "sample" / "bem" / "sample-oct-6-src.fif" +fname_fwd = sample_dir / "sample_audvis-meg-oct-6-fwd.fif" +fname_fsaverage_src = subjects_dir / "fsaverage" / "bem" / "fsaverage-ico-5-src.fif" +fname_stc = sample_dir / "sample_audvis-meg" # %% # Load example data # Read stc from file -stc = mne.read_source_estimate(fname_stc, subject='sample') +stc = mne.read_source_estimate(fname_stc, subject="sample") # %% # Setting up SourceMorph for SourceEstimate @@ -66,7 +65,7 @@ src_orig = mne.read_source_spaces(fname_src) print(src_orig) # n_used=4098, 4098 fwd = mne.read_forward_solution(fname_fwd) -print(fwd['src']) # n_used=3732, 3766 +print(fwd["src"]) # n_used=3732, 3766 print([len(v) for v in stc.vertices]) # %% @@ -86,10 +85,14 @@ # Initialize SourceMorph for SourceEstimate src_to = mne.read_source_spaces(fname_fsaverage_src) -print(src_to[0]['vertno']) # special, np.arange(10242) -morph = mne.compute_source_morph(stc, subject_from='sample', - subject_to='fsaverage', src_to=src_to, - subjects_dir=subjects_dir) +print(src_to[0]["vertno"]) # special, np.arange(10242) +morph = mne.compute_source_morph( + stc, + subject_from="sample", + subject_to="fsaverage", + src_to=src_to, + subjects_dir=subjects_dir, +) # %% # Apply morph to (Vector) SourceEstimate @@ -106,25 +109,28 @@ # Define plotting parameters surfer_kwargs = dict( - hemi='lh', subjects_dir=subjects_dir, - clim=dict(kind='value', lims=[8, 12, 15]), views='lateral', - initial_time=0.09, time_unit='s', size=(800, 800), - smoothing_steps=5) + hemi="lh", + subjects_dir=subjects_dir, + clim=dict(kind="value", lims=[8, 12, 15]), + views="lateral", + initial_time=0.09, + time_unit="s", + size=(800, 800), + smoothing_steps=5, +) # As spherical surface -brain = stc_fsaverage.plot(surface='sphere', **surfer_kwargs) +brain = stc_fsaverage.plot(surface="sphere", **surfer_kwargs) # Add title 
-brain.add_text(0.1, 0.9, 'Morphed to fsaverage (spherical)', 'title', - font_size=16) +brain.add_text(0.1, 0.9, "Morphed to fsaverage (spherical)", "title", font_size=16) # %% # As inflated surface -brain_inf = stc_fsaverage.plot(surface='inflated', **surfer_kwargs) +brain_inf = stc_fsaverage.plot(surface="inflated", **surfer_kwargs) # Add title -brain_inf.add_text(0.1, 0.9, 'Morphed to fsaverage (inflated)', 'title', - font_size=16) +brain_inf.add_text(0.1, 0.9, "Morphed to fsaverage (inflated)", "title", font_size=16) # %% # Reading and writing SourceMorph from and to disk @@ -153,8 +159,7 @@ # easily chained into a handy one-liner. Taking this together the shortest # possible way to morph data directly would be: -stc_fsaverage = mne.compute_source_morph(stc, - subjects_dir=subjects_dir).apply(stc) +stc_fsaverage = mne.compute_source_morph(stc, subjects_dir=subjects_dir).apply(stc) # %% # For more examples, check out :ref:`examples using SourceMorph.apply diff --git a/examples/inverse/morph_volume_stc.py b/examples/inverse/morph_volume_stc.py index 1494b7b30c8..adf20db7905 100644 --- a/examples/inverse/morph_volume_stc.py +++ b/examples/inverse/morph_volume_stc.py @@ -38,16 +38,15 @@ # %% # Setup paths sample_dir_raw = sample.data_path() -sample_dir = os.path.join(sample_dir_raw, 'MEG', 'sample') -subjects_dir = os.path.join(sample_dir_raw, 'subjects') +sample_dir = os.path.join(sample_dir_raw, "MEG", "sample") +subjects_dir = os.path.join(sample_dir_raw, "subjects") -fname_evoked = os.path.join(sample_dir, 'sample_audvis-ave.fif') -fname_inv = os.path.join(sample_dir, 'sample_audvis-meg-vol-7-meg-inv.fif') +fname_evoked = os.path.join(sample_dir, "sample_audvis-ave.fif") +fname_inv = os.path.join(sample_dir, "sample_audvis-meg-vol-7-meg-inv.fif") -fname_t1_fsaverage = os.path.join(subjects_dir, 'fsaverage', 'mri', - 'brain.mgz') +fname_t1_fsaverage = os.path.join(subjects_dir, "fsaverage", "mri", "brain.mgz") fetch_fsaverage(subjects_dir) # ensure fsaverage src exists -fname_src_fsaverage = subjects_dir + '/fsaverage/bem/fsaverage-vol-5-src.fif' +fname_src_fsaverage = subjects_dir + "/fsaverage/bem/fsaverage-vol-5-src.fif" # %% # Compute example data. For reference see :ref:`ex-inverse-volume`. @@ -57,7 +56,7 @@ inverse_operator = read_inverse_operator(fname_inv) # Apply inverse operator -stc = apply_inverse(evoked, inverse_operator, 1.0 / 3.0 ** 2, "dSPM") +stc = apply_inverse(evoked, inverse_operator, 1.0 / 3.0**2, "dSPM") # To save time stc.crop(0.09, 0.09) @@ -84,9 +83,14 @@ src_fs = mne.read_source_spaces(fname_src_fsaverage) morph = mne.compute_source_morph( - inverse_operator['src'], subject_from='sample', subjects_dir=subjects_dir, - niter_affine=[10, 10, 5], niter_sdr=[10, 10, 5], # just for speed - src_to=src_fs, verbose=True) + inverse_operator["src"], + subject_from="sample", + subjects_dir=subjects_dir, + niter_affine=[10, 10, 5], + niter_sdr=[10, 10, 5], # just for speed + src_to=src_fs, + verbose=True, +) # %% # Apply morph to VolSourceEstimate @@ -119,7 +123,7 @@ # :meth:`morph.apply(..., output='nifti1') `. 
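# %%
# A minimal sketch (using ``morph`` and the cropped ``stc`` from above; the
# output filename is illustrative) of what the NIfTI output gives you: an
# ordinary 4D ``nibabel`` image -- three voxel axes plus time -- that can be
# saved like any other volume:

import nibabel as nib

img = morph.apply(stc, output="nifti1")
print(img.shape)  # (x, y, z, n_times)
nib.save(img, "stc_fsaverage_vol.nii.gz")  # hypothetical output path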
# Create mri-resolution volume of results -img_fsaverage = morph.apply(stc, mri_resolution=2, output='nifti1') +img_fsaverage = morph.apply(stc, mri_resolution=2, output="nifti1") # %% # Plot results @@ -129,10 +133,9 @@ t1_fsaverage = nib.load(fname_t1_fsaverage) # Plot glass brain (change to plot_anat to display an overlaid anatomical T1) -display = plot_glass_brain(t1_fsaverage, - title='subject results to fsaverage', - draw_cross=False, - annotate=True) +display = plot_glass_brain( + t1_fsaverage, title="subject results to fsaverage", draw_cross=False, annotate=True +) # Add functional data as overlay display.add_overlay(img_fsaverage, alpha=0.75) diff --git a/examples/inverse/multi_dipole_model.py b/examples/inverse/multi_dipole_model.py index afed2d738df..40bbf60c919 100644 --- a/examples/inverse/multi_dipole_model.py +++ b/examples/inverse/multi_dipole_model.py @@ -33,17 +33,16 @@ import mne from mne.datasets import sample from mne.channels import read_vectorview_selection -from mne.minimum_norm import (make_inverse_operator, apply_inverse, - apply_inverse_epochs) +from mne.minimum_norm import make_inverse_operator, apply_inverse, apply_inverse_epochs import matplotlib.pyplot as plt import numpy as np data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -cov_fname = meg_path / 'sample_audvis-shrunk-cov.fif' -bem_dir = data_path / 'subjects' / 'sample' / 'bem' -bem_fname = bem_dir / 'sample-5120-5120-5120-bem-sol.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +cov_fname = meg_path / "sample_audvis-shrunk-cov.fif" +bem_dir = data_path / "subjects" / "sample" / "bem" +bem_fname = bem_dir / "sample-5120-5120-5120-bem-sol.fif" ############################################################################### # Read the MEG data from the audvis experiment. Make epochs and evokeds for the @@ -55,13 +54,19 @@ # Create epochs for auditory events events = mne.find_events(raw) event_id = dict(right=1, left=2) -epochs = mne.Epochs(raw, events, event_id, - tmin=-0.1, tmax=0.3, baseline=(None, 0), - reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin=-0.1, + tmax=0.3, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) # Create evokeds for left and right auditory stimulation -evoked_left = epochs['left'].average() -evoked_right = epochs['right'].average() +evoked_left = epochs["left"].average() +evoked_right = epochs["right"].average() ############################################################################### # Guided dipole modeling, meaning fitting dipoles to a manually selected subset @@ -76,12 +81,12 @@ # Fit two dipoles at t=80ms. The first dipole is fitted using only the sensors # on the left side of the helmet. The second dipole is fitted using only the # sensors on the right side of the helmet. 
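# %%
# A brief sketch of what a "selection" is here (``info`` from above is
# assumed): just a list of channel names taken from the Neuromag VectorView
# layout, which you can inspect before using it to pick channels:

names = read_vectorview_selection("Left", info=info)
print(len(names), names[:3])  # channel names on the left side of the helmet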
-picks_left = read_vectorview_selection('Left', info=info) +picks_left = read_vectorview_selection("Left", info=info) evoked_fit_left = evoked_left.copy().crop(0.08, 0.08) evoked_fit_left.pick_channels(picks_left, ordered=False) cov_fit_left = cov.copy().pick_channels(picks_left, ordered=False) -picks_right = read_vectorview_selection('Right', info=info) +picks_right = read_vectorview_selection("Right", info=info) evoked_fit_right = evoked_right.copy().crop(0.08, 0.08) evoked_fit_right.pick_channels(picks_right, ordered=False) cov_fit_right = cov.copy().pick_channels(picks_right, ordered=False) @@ -90,8 +95,8 @@ # after picking channels. evoked_fit_left.info.normalize_proj() evoked_fit_right.info.normalize_proj() -cov_fit_left['projs'] = evoked_fit_left.info['projs'] -cov_fit_right['projs'] = evoked_fit_right.info['projs'] +cov_fit_left["projs"] = evoked_fit_left.info["projs"] +cov_fit_right["projs"] = evoked_fit_right.info["projs"] # Fit the dipoles with the subset of sensors. dip_left, _ = mne.fit_dipole(evoked_fit_left, cov_fit_left, bem) @@ -107,27 +112,25 @@ # Apply MNE inverse inv = make_inverse_operator(info, fwd, cov, fixed=True, depth=0) -stc_left = apply_inverse(evoked_left, inv, method='MNE', lambda2=1E-6) -stc_right = apply_inverse(evoked_right, inv, method='MNE', lambda2=1E-6) +stc_left = apply_inverse(evoked_left, inv, method="MNE", lambda2=1e-6) +stc_right = apply_inverse(evoked_right, inv, method="MNE", lambda2=1e-6) # Plot the timecourses of the resulting source estimate fig, axes = plt.subplots(nrows=2, sharex=True, sharey=True) axes[0].plot(stc_left.times, stc_left.data.T) -axes[0].set_title('Left auditory stimulation') -axes[0].legend(['Dipole 1', 'Dipole 2']) +axes[0].set_title("Left auditory stimulation") +axes[0].legend(["Dipole 1", "Dipole 2"]) axes[1].plot(stc_right.times, stc_right.data.T) -axes[1].set_title('Right auditory stimulation') -axes[1].set_xlabel('Time (s)') -fig.supylabel('Dipole amplitude') +axes[1].set_title("Right auditory stimulation") +axes[1].set_xlabel("Time (s)") +fig.supylabel("Dipole amplitude") ############################################################################### # We can also fit the timecourses to single epochs. Here, we do it for each # experimental condition separately. 
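# %%
# Before running it, here is a self-contained sketch (synthetic numbers, not
# the real data) of the reduction applied further below: each epoch's source
# timecourse, shaped (n_dipoles, n_times), is collapsed to one peak amplitude
# per dipole:

rng = np.random.default_rng(0)
fake_epochs = [rng.standard_normal((2, 50)) for _ in range(5)]  # 5 epochs
amplitudes_demo = np.array([np.abs(d).max(axis=1) for d in fake_epochs])
print(amplitudes_demo.shape)  # (n_epochs, n_dipoles) == (5, 2)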
-stcs_left = apply_inverse_epochs(epochs['left'], inv, lambda2=1E-6, - method='MNE') -stcs_right = apply_inverse_epochs(epochs['right'], inv, lambda2=1E-6, - method='MNE') +stcs_left = apply_inverse_epochs(epochs["left"], inv, lambda2=1e-6, method="MNE") +stcs_right = apply_inverse_epochs(epochs["right"], inv, lambda2=1e-6, method="MNE") ############################################################################### # To summarize and visualize the single-epoch dipole amplitudes, we will create @@ -151,17 +154,17 @@ mean_right = np.mean(amplitudes_right, axis=0) fig, ax = plt.subplots(figsize=(8, 4)) -ax.scatter(np.arange(n), amplitudes[:, 0], label='Dipole 1') -ax.scatter(np.arange(n), amplitudes[:, 1], label='Dipole 2') +ax.scatter(np.arange(n), amplitudes[:, 0], label="Dipole 1") +ax.scatter(np.arange(n), amplitudes[:, 1], label="Dipole 2") transition_point = n_left - 0.5 -ax.plot([0, transition_point], [mean_left[0], mean_left[0]], color='C0') -ax.plot([0, transition_point], [mean_left[1], mean_left[1]], color='C1') -ax.plot([transition_point, n], [mean_right[0], mean_right[0]], color='C0') -ax.plot([transition_point, n], [mean_right[1], mean_right[1]], color='C1') -ax.axvline(transition_point, color='black') -ax.set_xlabel('Epochs') -ax.set_ylabel('Dipole amplitude') +ax.plot([0, transition_point], [mean_left[0], mean_left[0]], color="C0") +ax.plot([0, transition_point], [mean_left[1], mean_left[1]], color="C1") +ax.plot([transition_point, n], [mean_right[0], mean_right[0]], color="C0") +ax.plot([transition_point, n], [mean_right[1], mean_right[1]], color="C1") +ax.axvline(transition_point, color="black") +ax.set_xlabel("Epochs") +ax.set_ylabel("Dipole amplitude") ax.legend() -fig.suptitle('Single epoch dipole amplitudes') -fig.text(0.30, 0.9, 'Left auditory stimulation', ha='center') -fig.text(0.70, 0.9, 'Right auditory stimulation', ha='center') +fig.suptitle("Single epoch dipole amplitudes") +fig.text(0.30, 0.9, "Left auditory stimulation", ha="center") +fig.text(0.70, 0.9, "Right auditory stimulation", ha="center") diff --git a/examples/inverse/multidict_reweighted_tfmxne.py b/examples/inverse/multidict_reweighted_tfmxne.py index 58aa0fefb09..f12903d1754 100644 --- a/examples/inverse/multidict_reweighted_tfmxne.py +++ b/examples/inverse/multidict_reweighted_tfmxne.py @@ -38,28 +38,29 @@ # Load somatosensory MEG data data_path = somato.data_path() -subject = '01' -task = 'somato' -raw_fname = (data_path / f'sub-{subject}' / 'meg' / - f'sub-{subject}_task-{task}_meg.fif') -fwd_fname = (data_path / 'derivatives' / f'sub-{subject}' / - f'sub-{subject}_task-{task}-fwd.fif') +subject = "01" +task = "somato" +raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif" +fwd_fname = ( + data_path / "derivatives" / f"sub-{subject}" / f"sub-{subject}_task-{task}-fwd.fif" +) # Read evoked raw = mne.io.read_raw_fif(raw_fname) raw.pick_types(meg=True, eog=True, stim=True) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") reject = dict(grad=4000e-13, eog=350e-6) event_id, tmin, tmax = dict(unknown=1), -0.5, 0.5 -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, reject=reject, - baseline=(None, 0)) +epochs = mne.Epochs( + raw, events, event_id, tmin, tmax, reject=reject, baseline=(None, 0) +) evoked = epochs.average() evoked.crop(tmin=0.0, tmax=0.2) # Compute noise covariance matrix -cov = mne.compute_covariance(epochs, rank='info', tmax=0.) 
+cov = mne.compute_covariance(epochs, rank="info", tmax=0.0) del epochs, raw # Handling forward solution @@ -69,7 +70,7 @@ # Run iterative reweighted multidict TF-MxNE solver alpha, l1_ratio = 20, 0.05 -loose, depth = 0.9, 1. +loose, depth = 0.9, 1.0 # Use a multiscale time-frequency dictionary wsize, tstep = [4, 16], [2, 4] @@ -77,27 +78,42 @@ n_tfmxne_iter = 10 # Compute TF-MxNE inverse solution with dipole output dipoles, residual = tf_mixed_norm( - evoked, forward, cov, alpha=alpha, l1_ratio=l1_ratio, - n_tfmxne_iter=n_tfmxne_iter, loose=loose, - depth=depth, tol=1e-3, - wsize=wsize, tstep=tstep, return_as_dipoles=True, - return_residual=True) + evoked, + forward, + cov, + alpha=alpha, + l1_ratio=l1_ratio, + n_tfmxne_iter=n_tfmxne_iter, + loose=loose, + depth=depth, + tol=1e-3, + wsize=wsize, + tstep=tstep, + return_as_dipoles=True, + return_residual=True, +) # %% # Generate stc from dipoles -stc = make_stc_from_dipoles(dipoles, forward['src']) +stc = make_stc_from_dipoles(dipoles, forward["src"]) plot_sparse_source_estimates( - forward['src'], stc, bgcolor=(1, 1, 1), opacity=0.1, - fig_name=f"irTF-MxNE (cond {evoked.comment})") + forward["src"], + stc, + bgcolor=(1, 1, 1), + opacity=0.1, + fig_name=f"irTF-MxNE (cond {evoked.comment})", +) # %% # Show the evoked response and the residual for gradiometers ylim = dict(grad=[-300, 300]) -evoked.copy().pick_types(meg='grad').plot( - titles=dict(grad='Evoked Response: Gradiometers'), ylim=ylim) -residual.copy().pick_types(meg='grad').plot( - titles=dict(grad='Residuals: Gradiometers'), ylim=ylim) +evoked.copy().pick_types(meg="grad").plot( + titles=dict(grad="Evoked Response: Gradiometers"), ylim=ylim +) +residual.copy().pick_types(meg="grad").plot( + titles=dict(grad="Residuals: Gradiometers"), ylim=ylim +) # %% # References diff --git a/examples/inverse/psf_ctf_label_leakage.py b/examples/inverse/psf_ctf_label_leakage.py index 5975584c391..d74663d369a 100644 --- a/examples/inverse/psf_ctf_label_leakage.py +++ b/examples/inverse/psf_ctf_label_leakage.py @@ -25,9 +25,11 @@ import mne from mne.datasets import sample -from mne.minimum_norm import (read_inverse_operator, - make_inverse_resolution_matrix, - get_point_spread) +from mne.minimum_norm import ( + read_inverse_operator, + make_inverse_resolution_matrix, + get_point_spread, +) from mne.viz import circular_layout from mne_connectivity.viz import plot_connectivity_circle @@ -43,20 +45,20 @@ # resolution matrices for different methods. data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-fixed-inv.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-fixed-inv.fif" forward = mne.read_forward_solution(fname_fwd) # Convert forward solution to fixed source orientations -mne.convert_forward_solution( - forward, surf_ori=True, force_fixed=True, copy=False) +mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True, copy=False) inverse_operator = read_inverse_operator(fname_inv) # Compute resolution matrices for MNE -rm_mne = make_inverse_resolution_matrix(forward, inverse_operator, - method='MNE', lambda2=1. 
/ 3.**2) -src = inverse_operator['src'] +rm_mne = make_inverse_resolution_matrix( + forward, inverse_operator, method="MNE", lambda2=1.0 / 3.0**2 +) +src = inverse_operator["src"] del forward, inverse_operator # save memory # %% @@ -64,13 +66,12 @@ # -------------------------------------------------- # # Get labels for FreeSurfer 'aparc' cortical parcellation with 34 labels/hemi -labels = mne.read_labels_from_annot('sample', parc='aparc', - subjects_dir=subjects_dir) +labels = mne.read_labels_from_annot("sample", parc="aparc", subjects_dir=subjects_dir) n_labels = len(labels) label_colors = [label.color for label in labels] # First, we reorder the labels based on their location in the left hemi label_names = [label.name for label in labels] -lh_labels = [name for name in label_names if name.endswith('lh')] +lh_labels = [name for name in label_names if name.endswith("lh")] # Get the y-location of the label label_ypos = list() @@ -83,7 +84,7 @@ lh_labels = [label for (yp, label) in sorted(zip(label_ypos, lh_labels))] # For the right hemi -rh_labels = [label[:-2] + 'rh' for label in lh_labels] +rh_labels = [label[:-2] + "rh" for label in lh_labels] # %% # Compute point-spread function summaries (PCA) for all labels @@ -97,8 +98,8 @@ # spatial extents of labels. n_comp = 5 stcs_psf_mne, pca_vars_mne = get_point_spread( - rm_mne, src, labels, mode='pca', n_comp=n_comp, norm=None, - return_pca_vars=True) + rm_mne, src, labels, mode="pca", n_comp=n_comp, norm=None, return_pca_vars=True +) n_verts = rm_mne.shape[0] del rm_mne @@ -109,7 +110,7 @@ with np.printoptions(precision=1): for [name, var] in zip(label_names, pca_vars_mne): - print(f'{name}: {var.sum():.1f}% {var}') + print(f"{name}: {var.sum():.1f}% {var}") # %% # The output shows the summed variance explained by the first five principal @@ -132,15 +133,23 @@ # Save the plot order and create a circular layout node_order = lh_labels[::-1] + rh_labels # mirror label order across hemis -node_angles = circular_layout(label_names, node_order, start_pos=90, - group_boundaries=[0, len(label_names) / 2]) +node_angles = circular_layout( + label_names, node_order, start_pos=90, group_boundaries=[0, len(label_names) / 2] +) # Plot the graph using node colors from the FreeSurfer parcellation. We only # show the 200 strongest connections. fig, ax = plt.subplots( - figsize=(8, 8), facecolor='black', subplot_kw=dict(projection='polar')) -plot_connectivity_circle(leakage_mne, label_names, n_lines=200, - node_angles=node_angles, node_colors=label_colors, - title='MNE Leakage', ax=ax) + figsize=(8, 8), facecolor="black", subplot_kw=dict(projection="polar") +) +plot_connectivity_circle( + leakage_mne, + label_names, + n_lines=200, + node_angles=node_angles, + node_colors=label_colors, + title="MNE Leakage", + ax=ax, +) # %% # Most leakage occurs for neighbouring regions, but also for deeper regions @@ -175,20 +184,26 @@ # %% # Point-spread function for the lateral occipital label in the left hemisphere -brain_lh = stc_lh.plot(subjects_dir=subjects_dir, subject='sample', - hemi='both', views='caudal', - clim=dict(kind='value', - pos_lims=(0, max_val / 2., max_val))) -brain_lh.add_text(0.1, 0.9, label_names[idx[0]], 'title', font_size=16) +brain_lh = stc_lh.plot( + subjects_dir=subjects_dir, + subject="sample", + hemi="both", + views="caudal", + clim=dict(kind="value", pos_lims=(0, max_val / 2.0, max_val)), +) +brain_lh.add_text(0.1, 0.9, label_names[idx[0]], "title", font_size=16) # %% # and in the right hemisphere. 
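# %%
# A short sketch (one plausible convention; ``stc_lh``/``stc_rh`` are the PSF
# summaries plotted in this section) of a shared color limit, so both
# hemispheres can be compared on the same scale:

shared_max = np.max([np.abs(stc_lh.data).max(), np.abs(stc_rh.data).max()])
print(f"shared color limit: {shared_max:.2f}")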
-brain_rh = stc_rh.plot(subjects_dir=subjects_dir, subject='sample', - hemi='both', views='caudal', - clim=dict(kind='value', - pos_lims=(0, max_val / 2., max_val))) -brain_rh.add_text(0.1, 0.9, label_names[idx[1]], 'title', font_size=16) +brain_rh = stc_rh.plot( + subjects_dir=subjects_dir, + subject="sample", + hemi="both", + views="caudal", + clim=dict(kind="value", pos_lims=(0, max_val / 2.0, max_val)), +) +brain_rh.add_text(0.1, 0.9, label_names[idx[1]], "title", font_size=16) # %% # Both summary PSFs are confined to their respective hemispheres, indicating diff --git a/examples/inverse/psf_ctf_vertices.py b/examples/inverse/psf_ctf_vertices.py index a365991ffa1..0ec01a865dc 100644 --- a/examples/inverse/psf_ctf_vertices.py +++ b/examples/inverse/psf_ctf_vertices.py @@ -16,23 +16,25 @@ import mne from mne.datasets import sample -from mne.minimum_norm import (make_inverse_resolution_matrix, get_cross_talk, - get_point_spread) +from mne.minimum_norm import ( + make_inverse_resolution_matrix, + get_cross_talk, + get_point_spread, +) print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" # read forward solution forward = mne.read_forward_solution(fname_fwd) # forward operator with fixed source orientations -mne.convert_forward_solution(forward, surf_ori=True, - force_fixed=True, copy=False) +mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True, copy=False) # noise covariance matrix noise_cov = mne.read_cov(fname_cov) @@ -43,23 +45,24 @@ # make inverse operator from forward solution # free source orientation inverse_operator = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward, noise_cov=noise_cov, loose=0., - depth=None) + info=evoked.info, forward=forward, noise_cov=noise_cov, loose=0.0, depth=None +) # regularisation parameter snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 # compute resolution matrix for sLORETA -rm_lor = make_inverse_resolution_matrix(forward, inverse_operator, - method='sLORETA', lambda2=lambda2) +rm_lor = make_inverse_resolution_matrix( + forward, inverse_operator, method="sLORETA", lambda2=lambda2 +) # get PSF and CTF for sLORETA at one vertex sources = [1000] -stc_psf = get_point_spread(rm_lor, forward['src'], sources, norm=True) +stc_psf = get_point_spread(rm_lor, forward["src"], sources, norm=True) -stc_ctf = get_cross_talk(rm_lor, forward['src'], sources, norm=True) +stc_ctf = get_cross_talk(rm_lor, forward["src"], sources, norm=True) del rm_lor ############################################################################## @@ -68,37 +71,41 @@ # PSF: # Which vertex corresponds to selected source -vertno_lh = forward['src'][0]['vertno'] +vertno_lh = forward["src"][0]["vertno"] verttrue = [vertno_lh[sources[0]]] # just one vertex # find vertices with maxima in PSF and CTF vert_max_psf = vertno_lh[stc_psf.data.argmax()] vert_max_ctf = vertno_lh[stc_ctf.data.argmax()] -brain_psf = stc_psf.plot('sample', 'inflated', 'lh', subjects_dir=subjects_dir) -brain_psf.show_view('ventral') -brain_psf.add_text(0.1, 0.9, 'sLORETA PSF', 'title', font_size=16) 
+brain_psf = stc_psf.plot("sample", "inflated", "lh", subjects_dir=subjects_dir) +brain_psf.show_view("ventral") +brain_psf.add_text(0.1, 0.9, "sLORETA PSF", "title", font_size=16) # True source location for PSF -brain_psf.add_foci(verttrue, coords_as_verts=True, scale_factor=1., hemi='lh', - color='green') +brain_psf.add_foci( + verttrue, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="green" +) # Maximum of PSF -brain_psf.add_foci(vert_max_psf, coords_as_verts=True, scale_factor=1., - hemi='lh', color='black') +brain_psf.add_foci( + vert_max_psf, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="black" +) # %% # CTF: -brain_ctf = stc_ctf.plot('sample', 'inflated', 'lh', subjects_dir=subjects_dir) -brain_ctf.add_text(0.1, 0.9, 'sLORETA CTF', 'title', font_size=16) -brain_ctf.show_view('ventral') -brain_ctf.add_foci(verttrue, coords_as_verts=True, scale_factor=1., hemi='lh', - color='green') +brain_ctf = stc_ctf.plot("sample", "inflated", "lh", subjects_dir=subjects_dir) +brain_ctf.add_text(0.1, 0.9, "sLORETA CTF", "title", font_size=16) +brain_ctf.show_view("ventral") +brain_ctf.add_foci( + verttrue, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="green" +) # Maximum of CTF -brain_ctf.add_foci(vert_max_ctf, coords_as_verts=True, scale_factor=1., - hemi='lh', color='black') +brain_ctf.add_foci( + vert_max_ctf, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="black" +) # %% diff --git a/examples/inverse/psf_ctf_vertices_lcmv.py b/examples/inverse/psf_ctf_vertices_lcmv.py index de774c2149e..7f3d2a4207e 100644 --- a/examples/inverse/psf_ctf_vertices_lcmv.py +++ b/examples/inverse/psf_ctf_vertices_lcmv.py @@ -23,75 +23,91 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" # Read raw data raw = mne.io.read_raw_fif(raw_fname) # only pick good EEG/MEG sensors -raw.info['bads'] += ['EEG 053'] # bads + 1 more -picks = mne.pick_types(raw.info, meg=True, eeg=True, exclude='bads') +raw.info["bads"] += ["EEG 053"] # bads + 1 more +picks = mne.pick_types(raw.info, meg=True, eeg=True, exclude="bads") # Find events events = mne.find_events(raw) # event_id = {'aud/l': 1, 'aud/r': 2, 'vis/l': 3, 'vis/r': 4} -event_id = {'vis/l': 3, 'vis/r': 4} - -tmin, tmax = -.2, .25 # epoch duration -epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax, - picks=picks, baseline=(-.2, 0.), preload=True) +event_id = {"vis/l": 3, "vis/r": 4} + +tmin, tmax = -0.2, 0.25 # epoch duration +epochs = mne.Epochs( + raw, + events, + event_id=event_id, + tmin=tmin, + tmax=tmax, + picks=picks, + baseline=(-0.2, 0.0), + preload=True, +) del raw # covariance matrix for pre-stimulus interval -tmin, tmax = -.2, 0. 
-cov_pre = mne.compute_covariance(epochs, tmin=tmin, tmax=tmax, - method='empirical') +tmin, tmax = -0.2, 0.0 +cov_pre = mne.compute_covariance(epochs, tmin=tmin, tmax=tmax, method="empirical") # covariance matrix for post-stimulus interval (around main evoked responses) -tmin, tmax = 0.05, .25 -cov_post = mne.compute_covariance(epochs, tmin=tmin, tmax=tmax, - method='empirical') +tmin, tmax = 0.05, 0.25 +cov_post = mne.compute_covariance(epochs, tmin=tmin, tmax=tmax, method="empirical") info = epochs.info del epochs # read forward solution forward = mne.read_forward_solution(fname_fwd) # use forward operator with fixed source orientations -mne.convert_forward_solution(forward, surf_ori=True, - force_fixed=True, copy=False) +mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True, copy=False) # read noise covariance matrix noise_cov = mne.read_cov(fname_cov) # regularize noise covariance (we used 'empirical' above) -noise_cov = mne.cov.regularize(noise_cov, info, mag=0.1, grad=0.1, - eeg=0.1, rank='info') +noise_cov = mne.cov.regularize(noise_cov, info, mag=0.1, grad=0.1, eeg=0.1, rank="info") ############################################################################## # Compute LCMV filters with different data covariance matrices # ------------------------------------------------------------ # compute LCMV beamformer filters for pre-stimulus interval -filters_pre = make_lcmv(info, forward, cov_pre, reg=0.05, - noise_cov=noise_cov, - pick_ori=None, rank=None, - weight_norm=None, - reduce_rank=False, - verbose=False) +filters_pre = make_lcmv( + info, + forward, + cov_pre, + reg=0.05, + noise_cov=noise_cov, + pick_ori=None, + rank=None, + weight_norm=None, + reduce_rank=False, + verbose=False, +) # compute LCMV beamformer filters for post-stimulus interval -filters_post = make_lcmv(info, forward, cov_post, reg=0.05, - noise_cov=noise_cov, - pick_ori=None, rank=None, - weight_norm=None, - reduce_rank=False, - verbose=False) +filters_post = make_lcmv( + info, + forward, + cov_post, + reg=0.05, + noise_cov=noise_cov, + pick_ori=None, + rank=None, + weight_norm=None, + reduce_rank=False, + verbose=False, +) ############################################################################## # Compute resolution matrices for the two LCMV beamformers @@ -99,14 +115,14 @@ # compute cross-talk functions (CTFs) for one target vertex sources = [3000] -verttrue = [forward['src'][0]['vertno'][sources[0]]] # pick one vertex +verttrue = [forward["src"][0]["vertno"][sources[0]]] # pick one vertex rm_pre = make_lcmv_resolution_matrix(filters_pre, forward, info) -stc_pre = get_cross_talk(rm_pre, forward['src'], sources, norm=True) +stc_pre = get_cross_talk(rm_pre, forward["src"], sources, norm=True) del rm_pre ############################################################################## rm_post = make_lcmv_resolution_matrix(filters_post, forward, info) -stc_post = get_cross_talk(rm_post, forward['src'], sources, norm=True) +stc_post = get_cross_talk(rm_post, forward["src"], sources, norm=True) del rm_post ############################################################################## @@ -114,28 +130,51 @@ # --------- # Pre: -brain_pre = stc_pre.plot('sample', 'inflated', 'lh', subjects_dir=subjects_dir, - figure=1, clim=dict(kind='value', lims=(0, .2, .4))) - -brain_pre.add_text(0.1, 0.9, 'LCMV beamformer with pre-stimulus\ndata ' - 'covariance matrix', 'title', font_size=16) +brain_pre = stc_pre.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=1, + 
clim=dict(kind="value", lims=(0, 0.2, 0.4)), +) + +brain_pre.add_text( + 0.1, + 0.9, + "LCMV beamformer with pre-stimulus\ndata " "covariance matrix", + "title", + font_size=16, +) # mark true source location for CTFs -brain_pre.add_foci(verttrue, coords_as_verts=True, scale_factor=1., hemi='lh', - color='green') +brain_pre.add_foci( + verttrue, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="green" +) # %% # Post: -brain_post = stc_post.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, - figure=2, clim=dict(kind='value', lims=(0, .2, .4))) - -brain_post.add_text(0.1, 0.9, 'LCMV beamformer with post-stimulus\ndata ' - 'covariance matrix', 'title', font_size=16) - -brain_post.add_foci(verttrue, coords_as_verts=True, scale_factor=1., - hemi='lh', color='green') +brain_post = stc_post.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=2, + clim=dict(kind="value", lims=(0, 0.2, 0.4)), +) + +brain_post.add_text( + 0.1, + 0.9, + "LCMV beamformer with post-stimulus\ndata " "covariance matrix", + "title", + font_size=16, +) + +brain_post.add_foci( + verttrue, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="green" +) # %% # The pre-stimulus beamformer's CTF has lower values in parietal regions diff --git a/examples/inverse/psf_volume.py b/examples/inverse/psf_volume.py index 7cfd0675cd8..f2e465c1b20 100644 --- a/examples/inverse/psf_volume.py +++ b/examples/inverse/psf_volume.py @@ -24,13 +24,12 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' -fname_trans = meg_path / 'sample_audvis_raw-trans.fif' -fname_bem = ( - subjects_dir / 'sample' / 'bem' / 'sample-5120-bem-sol.fif') +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" +fname_trans = meg_path / "sample_audvis_raw-trans.fif" +fname_bem = subjects_dir / "sample" / "bem" / "sample-5120-bem-sol.fif" # %% # For the volume, create a coarse source space for speed (don't do this in @@ -42,27 +41,29 @@ # create a coarse source space src_vol = mne.setup_volume_source_space( # this is a very course resolution! - 'sample', pos=15., subjects_dir=subjects_dir, - add_interpolator=False) # usually you want True, this is just for speed + "sample", pos=15.0, subjects_dir=subjects_dir, add_interpolator=False +) # usually you want True, this is just for speed # compute the forward forward_vol = mne.make_forward_solution( # MEG-only for speed - evoked.info, fname_trans, src_vol, fname_bem, eeg=False) + evoked.info, fname_trans, src_vol, fname_bem, eeg=False +) del src_vol # %% # Now make an inverse operator and compute the PSF at a source. inverse_operator_vol = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward_vol, noise_cov=noise_cov) + info=evoked.info, forward=forward_vol, noise_cov=noise_cov +) # compute resolution matrix for sLORETA rm_lor_vol = make_inverse_resolution_matrix( - forward_vol, inverse_operator_vol, method='sLORETA', lambda2=1. / 9.) 
+ forward_vol, inverse_operator_vol, method="sLORETA", lambda2=1.0 / 9.0 +) # get PSF and CTF for sLORETA at one vertex sources_vol = [100] -stc_psf_vol = get_point_spread( - rm_lor_vol, forward_vol['src'], sources_vol, norm=True) +stc_psf_vol = get_point_spread(rm_lor_vol, forward_vol["src"], sources_vol, norm=True) del rm_lor_vol ############################################################################## @@ -71,23 +72,30 @@ # PSF: # Which vertex corresponds to selected source -src_vol = forward_vol['src'] -verttrue_vol = src_vol[0]['vertno'][sources_vol] +src_vol = forward_vol["src"] +verttrue_vol = src_vol[0]["vertno"][sources_vol] # find vertex with maximum in PSF -max_vert_idx, _ = np.unravel_index( - stc_psf_vol.data.argmax(), stc_psf_vol.data.shape) -vert_max_ctf_vol = src_vol[0]['vertno'][[max_vert_idx]] +max_vert_idx, _ = np.unravel_index(stc_psf_vol.data.argmax(), stc_psf_vol.data.shape) +vert_max_ctf_vol = src_vol[0]["vertno"][[max_vert_idx]] # plot them brain_psf_vol = stc_psf_vol.plot_3d( - 'sample', src=forward_vol['src'], views='ven', subjects_dir=subjects_dir, - volume_options=dict(alpha=0.5)) -brain_psf_vol.add_text( - 0.1, 0.9, 'Volumetric sLORETA PSF', 'title', font_size=16) + "sample", + src=forward_vol["src"], + views="ven", + subjects_dir=subjects_dir, + volume_options=dict(alpha=0.5), +) +brain_psf_vol.add_text(0.1, 0.9, "Volumetric sLORETA PSF", "title", font_size=16) brain_psf_vol.add_foci( - verttrue_vol, coords_as_verts=True, - scale_factor=1, hemi='vol', color='green') + verttrue_vol, coords_as_verts=True, scale_factor=1, hemi="vol", color="green" +) brain_psf_vol.add_foci( - vert_max_ctf_vol, coords_as_verts=True, - scale_factor=1.25, hemi='vol', color='black', alpha=0.3) + vert_max_ctf_vol, + coords_as_verts=True, + scale_factor=1.25, + hemi="vol", + color="black", + alpha=0.3, +) diff --git a/examples/inverse/rap_music.py b/examples/inverse/rap_music.py index 937351b96dd..787c6d3b8c7 100644 --- a/examples/inverse/rap_music.py +++ b/examples/inverse/rap_music.py @@ -24,16 +24,15 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -evoked_fname = meg_path / 'sample_audvis-ave.fif' -cov_fname = meg_path / 'sample_audvis-cov.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +evoked_fname = meg_path / "sample_audvis-ave.fif" +cov_fname = meg_path / "sample_audvis-cov.fif" # Read the evoked response and crop it -condition = 'Right Auditory' -evoked = mne.read_evokeds(evoked_fname, condition=condition, - baseline=(None, 0)) +condition = "Right Auditory" +evoked = mne.read_evokeds(evoked_fname, condition=condition, baseline=(None, 0)) # select N100 evoked.crop(tmin=0.05, tmax=0.15) @@ -45,17 +44,16 @@ # Read noise covariance matrix noise_cov = mne.read_cov(cov_fname) -dipoles, residual = rap_music(evoked, forward, noise_cov, n_dipoles=2, - return_residual=True, verbose=True) -trans = forward['mri_head_t'] -plot_dipole_locations(dipoles, trans, 'sample', subjects_dir=subjects_dir) +dipoles, residual = rap_music( + evoked, forward, noise_cov, n_dipoles=2, return_residual=True, verbose=True +) +trans = forward["mri_head_t"] +plot_dipole_locations(dipoles, trans, "sample", subjects_dir=subjects_dir) plot_dipole_amplitudes(dipoles) # Plot the evoked data and the residual. 
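# %%
# Before plotting, a quick sanity check (a sketch using the objects above,
# not part of the original example): the fraction of variance the two fitted
# dipoles explain, computed from the evoked data and the residual (note:
# raw, not whitened, so channel types with different units are mixed):

explained = 1.0 - (residual.data ** 2).sum() / (evoked.data ** 2).sum()
print(f"Variance explained by 2 dipoles: {100 * explained:.1f}%")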
-evoked.plot(ylim=dict(grad=[-300, 300], mag=[-800, 800], eeg=[-6, 8]), - time_unit='s') -residual.plot(ylim=dict(grad=[-300, 300], mag=[-800, 800], eeg=[-6, 8]), - time_unit='s') +evoked.plot(ylim=dict(grad=[-300, 300], mag=[-800, 800], eeg=[-6, 8]), time_unit="s") +residual.plot(ylim=dict(grad=[-300, 300], mag=[-800, 800], eeg=[-6, 8]), time_unit="s") # %% # References diff --git a/examples/inverse/read_inverse.py b/examples/inverse/read_inverse.py index fd604b08f35..a0fe1774252 100644 --- a/examples/inverse/read_inverse.py +++ b/examples/inverse/read_inverse.py @@ -21,30 +21,35 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_trans = meg_path / 'sample_audvis_raw-trans.fif' -inv_fname = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_trans = meg_path / "sample_audvis_raw-trans.fif" +inv_fname = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" inv = read_inverse_operator(inv_fname) -print("Method: %s" % inv['methods']) -print("fMRI prior: %s" % inv['fmri_prior']) -print("Number of sources: %s" % inv['nsource']) -print("Number of channels: %s" % inv['nchan']) +print("Method: %s" % inv["methods"]) +print("fMRI prior: %s" % inv["fmri_prior"]) +print("Number of sources: %s" % inv["nsource"]) +print("Number of channels: %s" % inv["nchan"]) -src = inv['src'] # get the source space +src = inv["src"] # get the source space # Get access to the triangulation of the cortex -print("Number of vertices on the left hemisphere: %d" % len(src[0]['rr'])) -print("Number of triangles on left hemisphere: %d" % len(src[0]['use_tris'])) -print("Number of vertices on the right hemisphere: %d" % len(src[1]['rr'])) -print("Number of triangles on right hemisphere: %d" % len(src[1]['use_tris'])) +print("Number of vertices on the left hemisphere: %d" % len(src[0]["rr"])) +print("Number of triangles on left hemisphere: %d" % len(src[0]["use_tris"])) +print("Number of vertices on the right hemisphere: %d" % len(src[1]["rr"])) +print("Number of triangles on right hemisphere: %d" % len(src[1]["use_tris"])) # %% # Show the 3D source space -fig = mne.viz.plot_alignment(subject='sample', subjects_dir=subjects_dir, - trans=fname_trans, surfaces='white', src=src) -set_3d_view(fig, focalpoint=(0., 0., 0.06)) +fig = mne.viz.plot_alignment( + subject="sample", + subjects_dir=subjects_dir, + trans=fname_trans, + surfaces="white", + src=src, +) +set_3d_view(fig, focalpoint=(0.0, 0.0, 0.06)) diff --git a/examples/inverse/read_stc.py b/examples/inverse/read_stc.py index 3ae91bfc799..d98ba170400 100644 --- a/examples/inverse/read_stc.py +++ b/examples/inverse/read_stc.py @@ -22,17 +22,18 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname = meg_path / 'sample_audvis-meg' +meg_path = data_path / "MEG" / "sample" +fname = meg_path / "sample_audvis-meg" stc = mne.read_source_estimate(fname) n_vertices, n_samples = stc.data.shape -print("stc data size: %s (nb of vertices) x %s (nb of samples)" - % (n_vertices, n_samples)) +print( + "stc data size: %s (nb of vertices) x %s (nb of samples)" % (n_vertices, n_samples) +) # View source activations plt.plot(stc.times, stc.data[::100, :].T) -plt.xlabel('time (ms)') -plt.ylabel('Source amplitude') +plt.xlabel("time (ms)") +plt.ylabel("Source amplitude") plt.show() diff --git a/examples/inverse/resolution_metrics.py b/examples/inverse/resolution_metrics.py index 
10d3e03944c..e3e98827bea 100644 --- a/examples/inverse/resolution_metrics.py +++ b/examples/inverse/resolution_metrics.py @@ -25,17 +25,16 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" # read forward solution forward = mne.read_forward_solution(fname_fwd) # forward operator with fixed source orientations -mne.convert_forward_solution(forward, surf_ori=True, - force_fixed=True, copy=False) +mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True, copy=False) # noise covariance matrix noise_cov = mne.read_cov(fname_cov) @@ -46,12 +45,12 @@ # make inverse operator from forward solution # free source orientation inverse_operator = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward, noise_cov=noise_cov, loose=0., - depth=None) + info=evoked.info, forward=forward, noise_cov=noise_cov, loose=0.0, depth=None +) # regularisation parameter snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 # %% # MNE @@ -59,12 +58,15 @@ # Compute resolution matrices, peak localisation error (PLE) for point spread # functions (PSFs), spatial deviation (SD) for PSFs: -rm_mne = make_inverse_resolution_matrix(forward, inverse_operator, - method='MNE', lambda2=lambda2) -ple_mne_psf = resolution_metrics(rm_mne, inverse_operator['src'], - function='psf', metric='peak_err') -sd_mne_psf = resolution_metrics(rm_mne, inverse_operator['src'], - function='psf', metric='sd_ext') +rm_mne = make_inverse_resolution_matrix( + forward, inverse_operator, method="MNE", lambda2=lambda2 +) +ple_mne_psf = resolution_metrics( + rm_mne, inverse_operator["src"], function="psf", metric="peak_err" +) +sd_mne_psf = resolution_metrics( + rm_mne, inverse_operator["src"], function="psf", metric="sd_ext" +) del rm_mne # %% @@ -72,39 +74,57 @@ # ---- # Do the same for dSPM: -rm_dspm = make_inverse_resolution_matrix(forward, inverse_operator, - method='dSPM', lambda2=lambda2) -ple_dspm_psf = resolution_metrics(rm_dspm, inverse_operator['src'], - function='psf', metric='peak_err') -sd_dspm_psf = resolution_metrics(rm_dspm, inverse_operator['src'], - function='psf', metric='sd_ext') +rm_dspm = make_inverse_resolution_matrix( + forward, inverse_operator, method="dSPM", lambda2=lambda2 +) +ple_dspm_psf = resolution_metrics( + rm_dspm, inverse_operator["src"], function="psf", metric="peak_err" +) +sd_dspm_psf = resolution_metrics( + rm_dspm, inverse_operator["src"], function="psf", metric="sd_ext" +) del rm_dspm, forward # %% # Visualize results # ----------------- # Visualise peak localisation error (PLE) across the whole cortex for MNE PSF: -brain_ple_mne = ple_mne_psf.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=1, - clim=dict(kind='value', lims=(0, 2, 4))) -brain_ple_mne.add_text(0.1, 0.9, 'PLE MNE', 'title', font_size=16) +brain_ple_mne = ple_mne_psf.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=1, + clim=dict(kind="value", lims=(0, 2, 4)), +) +brain_ple_mne.add_text(0.1, 0.9, "PLE MNE", "title", font_size=16) # %% # And dSPM: -brain_ple_dspm = ple_dspm_psf.plot('sample', 
'inflated', 'lh', - subjects_dir=subjects_dir, figure=2, - clim=dict(kind='value', lims=(0, 2, 4))) -brain_ple_dspm.add_text(0.1, 0.9, 'PLE dSPM', 'title', font_size=16) +brain_ple_dspm = ple_dspm_psf.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=2, + clim=dict(kind="value", lims=(0, 2, 4)), +) +brain_ple_dspm.add_text(0.1, 0.9, "PLE dSPM", "title", font_size=16) # %% # Subtract the two distributions and plot this difference diff_ple = ple_mne_psf - ple_dspm_psf -brain_ple_diff = diff_ple.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=3, - clim=dict(kind='value', pos_lims=(0., 1., 2.))) -brain_ple_diff.add_text(0.1, 0.9, 'PLE MNE-dSPM', 'title', font_size=16) +brain_ple_diff = diff_ple.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=3, + clim=dict(kind="value", pos_lims=(0.0, 1.0, 2.0)), +) +brain_ple_diff.add_text(0.1, 0.9, "PLE MNE-dSPM", "title", font_size=16) # %% # These plots show that dSPM has generally lower peak localization error (red @@ -114,28 +134,43 @@ # Next we'll visualise spatial deviation (SD) across the whole cortex for MNE # PSF: -brain_sd_mne = sd_mne_psf.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=4, - clim=dict(kind='value', lims=(0, 2, 4))) -brain_sd_mne.add_text(0.1, 0.9, 'SD MNE', 'title', font_size=16) +brain_sd_mne = sd_mne_psf.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=4, + clim=dict(kind="value", lims=(0, 2, 4)), +) +brain_sd_mne.add_text(0.1, 0.9, "SD MNE", "title", font_size=16) # %% # And dSPM: -brain_sd_dspm = sd_dspm_psf.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=5, - clim=dict(kind='value', lims=(0, 2, 4))) -brain_sd_dspm.add_text(0.1, 0.9, 'SD dSPM', 'title', font_size=16) +brain_sd_dspm = sd_dspm_psf.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=5, + clim=dict(kind="value", lims=(0, 2, 4)), +) +brain_sd_dspm.add_text(0.1, 0.9, "SD dSPM", "title", font_size=16) # %% # Subtract the two distributions and plot this difference: diff_sd = sd_mne_psf - sd_dspm_psf -brain_sd_diff = diff_sd.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=6, - clim=dict(kind='value', pos_lims=(0., 1., 2.))) -brain_sd_diff.add_text(0.1, 0.9, 'SD MNE-dSPM', 'title', font_size=16) +brain_sd_diff = diff_sd.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=6, + clim=dict(kind="value", pos_lims=(0.0, 1.0, 2.0)), +) +brain_sd_diff.add_text(0.1, 0.9, "SD MNE-dSPM", "title", font_size=16) # %% # These plots show that dSPM has generally higher spatial deviation than MNE diff --git a/examples/inverse/resolution_metrics_eegmeg.py b/examples/inverse/resolution_metrics_eegmeg.py index 06268178058..d570cb42baa 100644 --- a/examples/inverse/resolution_metrics_eegmeg.py +++ b/examples/inverse/resolution_metrics_eegmeg.py @@ -27,17 +27,18 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects/' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd_emeg = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects/" +meg_path = data_path / "MEG" / "sample" +fname_fwd_emeg = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" # read forward solution with EEG and MEG forward_emeg = 
mne.read_forward_solution(fname_fwd_emeg) # forward operator with fixed source orientations -forward_emeg = mne.convert_forward_solution(forward_emeg, surf_ori=True, - force_fixed=True) +forward_emeg = mne.convert_forward_solution( + forward_emeg, surf_ori=True, force_fixed=True +) # create a forward solution with MEG only forward_meg = mne.pick_types_forward(forward_emeg, meg=True, eeg=False) @@ -50,16 +51,16 @@ # make inverse operator from forward solution for MEG and EEGMEG inv_emeg = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward_emeg, noise_cov=noise_cov, loose=0., - depth=None) + info=evoked.info, forward=forward_emeg, noise_cov=noise_cov, loose=0.0, depth=None +) inv_meg = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward_meg, noise_cov=noise_cov, loose=0., - depth=None) + info=evoked.info, forward=forward_meg, noise_cov=noise_cov, loose=0.0, depth=None +) # regularisation parameter snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 # %% # EEGMEG @@ -67,12 +68,15 @@ # Compute resolution matrices, localization error, and spatial deviations # for MNE: -rm_emeg = make_inverse_resolution_matrix(forward_emeg, inv_emeg, - method='MNE', lambda2=lambda2) -ple_psf_emeg = resolution_metrics(rm_emeg, inv_emeg['src'], - function='psf', metric='peak_err') -sd_psf_emeg = resolution_metrics(rm_emeg, inv_emeg['src'], - function='psf', metric='sd_ext') +rm_emeg = make_inverse_resolution_matrix( + forward_emeg, inv_emeg, method="MNE", lambda2=lambda2 +) +ple_psf_emeg = resolution_metrics( + rm_emeg, inv_emeg["src"], function="psf", metric="peak_err" +) +sd_psf_emeg = resolution_metrics( + rm_emeg, inv_emeg["src"], function="psf", metric="sd_ext" +) del rm_emeg # %% @@ -80,12 +84,13 @@ # --- # Do the same for MEG: -rm_meg = make_inverse_resolution_matrix(forward_meg, inv_meg, - method='MNE', lambda2=lambda2) -ple_psf_meg = resolution_metrics(rm_meg, inv_meg['src'], - function='psf', metric='peak_err') -sd_psf_meg = resolution_metrics(rm_meg, inv_meg['src'], - function='psf', metric='sd_ext') +rm_meg = make_inverse_resolution_matrix( + forward_meg, inv_meg, method="MNE", lambda2=lambda2 +) +ple_psf_meg = resolution_metrics( + rm_meg, inv_meg["src"], function="psf", metric="peak_err" +) +sd_psf_meg = resolution_metrics(rm_meg, inv_meg["src"], function="psf", metric="sd_ext") del rm_meg # %% @@ -93,64 +98,94 @@ # ------------- # Look at peak localisation error (PLE) across the whole cortex for PSF: -brain_ple_emeg = ple_psf_emeg.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=1, - clim=dict(kind='value', lims=(0, 2, 4))) +brain_ple_emeg = ple_psf_emeg.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=1, + clim=dict(kind="value", lims=(0, 2, 4)), +) -brain_ple_emeg.add_text(0.1, 0.9, 'PLE PSF EMEG', 'title', font_size=16) +brain_ple_emeg.add_text(0.1, 0.9, "PLE PSF EMEG", "title", font_size=16) # %% # For MEG only: -brain_ple_meg = ple_psf_meg.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=2, - clim=dict(kind='value', lims=(0, 2, 4))) +brain_ple_meg = ple_psf_meg.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=2, + clim=dict(kind="value", lims=(0, 2, 4)), +) -brain_ple_meg.add_text(0.1, 0.9, 'PLE PSF MEG', 'title', font_size=16) +brain_ple_meg.add_text(0.1, 0.9, "PLE PSF MEG", "title", font_size=16) # %% # Subtract the two distributions and plot this difference: diff_ple = ple_psf_emeg - ple_psf_meg -brain_ple_diff = diff_ple.plot('sample', 
'inflated', 'lh', - subjects_dir=subjects_dir, figure=3, - clim=dict(kind='value', pos_lims=(0., .5, 1.)), - smoothing_steps=20) +brain_ple_diff = diff_ple.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=3, + clim=dict(kind="value", pos_lims=(0.0, 0.5, 1.0)), + smoothing_steps=20, +) -brain_ple_diff.add_text(0.1, 0.9, 'PLE EMEG-MEG', 'title', font_size=16) +brain_ple_diff.add_text(0.1, 0.9, "PLE EMEG-MEG", "title", font_size=16) # %% # These plots show that with respect to peak localization error, adding EEG to # MEG does not bring much benefit. Next let's visualise spatial deviation (SD) # across the whole cortex for PSF: -brain_sd_emeg = sd_psf_emeg.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=4, - clim=dict(kind='value', lims=(0, 2, 4))) +brain_sd_emeg = sd_psf_emeg.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=4, + clim=dict(kind="value", lims=(0, 2, 4)), +) -brain_sd_emeg.add_text(0.1, 0.9, 'SD PSF EMEG', 'title', font_size=16) +brain_sd_emeg.add_text(0.1, 0.9, "SD PSF EMEG", "title", font_size=16) # %% # For MEG only: -brain_sd_meg = sd_psf_meg.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=5, - clim=dict(kind='value', lims=(0, 2, 4))) +brain_sd_meg = sd_psf_meg.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=5, + clim=dict(kind="value", lims=(0, 2, 4)), +) -brain_sd_meg.add_text(0.1, 0.9, 'SD PSF MEG', 'title', font_size=16) +brain_sd_meg.add_text(0.1, 0.9, "SD PSF MEG", "title", font_size=16) # %% # Subtract the two distributions and plot this difference: diff_sd = sd_psf_emeg - sd_psf_meg -brain_sd_diff = diff_sd.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=6, - clim=dict(kind='value', pos_lims=(0., .5, 1.)), - smoothing_steps=20) - -brain_sd_diff.add_text(0.1, 0.9, 'SD EMEG-MEG', 'title', font_size=16) +brain_sd_diff = diff_sd.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=6, + clim=dict(kind="value", pos_lims=(0.0, 0.5, 1.0)), + smoothing_steps=20, +) + +brain_sd_diff.add_text(0.1, 0.9, "SD EMEG-MEG", "title", font_size=16) # %% # Adding EEG to MEG decreases the spatial extent of point-spread diff --git a/examples/inverse/snr_estimate.py b/examples/inverse/snr_estimate.py index 956f3cbe643..4a88a9d13c4 100644 --- a/examples/inverse/snr_estimate.py +++ b/examples/inverse/snr_estimate.py @@ -21,9 +21,9 @@ print(__doc__) -data_dir = data_path() / 'MEG' / 'sample' -fname_inv = data_dir / 'sample_audvis-meg-oct-6-meg-inv.fif' -fname_evoked = data_dir / 'sample_audvis-ave.fif' +data_dir = data_path() / "MEG" / "sample" +fname_inv = data_dir / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_evoked = data_dir / "sample_audvis-ave.fif" inv = read_inverse_operator(fname_inv) evoked = read_evokeds(fname_evoked, baseline=(None, 0))[0] diff --git a/examples/inverse/source_space_snr.py b/examples/inverse/source_space_snr.py index 0dd14e71722..c5599a5d331 100644 --- a/examples/inverse/source_space_snr.py +++ b/examples/inverse/source_space_snr.py @@ -26,15 +26,14 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' +subjects_dir = data_path / "subjects" # Read data -meg_path = data_path / 'MEG' / 'sample' -fname_evoked = meg_path / 'sample_audvis-ave.fif' -evoked = mne.read_evokeds(fname_evoked, condition='Left Auditory', - baseline=(None, 0)) -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' +meg_path = data_path 
/ "MEG" / "sample" +fname_evoked = meg_path / "sample_audvis-ave.fif" +evoked = mne.read_evokeds(fname_evoked, condition="Left Auditory", baseline=(None, 0)) +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" fwd = mne.read_forward_solution(fname_fwd) cov = mne.read_cov(fname_cov) @@ -43,8 +42,8 @@ # Calculate MNE: snr = 3.0 -lambda2 = 1.0 / snr ** 2 -stc = apply_inverse(evoked, inv_op, lambda2, 'MNE', verbose=True) +lambda2 = 1.0 / snr**2 +stc = apply_inverse(evoked, inv_op, lambda2, "MNE", verbose=True) # Calculate SNR in source space: snr_stc = stc.estimate_snr(evoked.info, fwd, cov) @@ -54,17 +53,23 @@ fig, ax = plt.subplots() ax.plot(evoked.times, ave) -ax.set(xlabel='Time (s)', ylabel='SNR MEG-EEG') +ax.set(xlabel="Time (s)", ylabel="SNR MEG-EEG") fig.tight_layout() # Find time point of maximum SNR maxidx = np.argmax(ave) # Plot SNR on source space at the time point of maximum SNR: -kwargs = dict(initial_time=evoked.times[maxidx], hemi='split', - views=['lat', 'med'], subjects_dir=subjects_dir, size=(600, 600), - clim=dict(kind='value', lims=(-100, -70, -40)), - transparent=True, colormap='viridis') +kwargs = dict( + initial_time=evoked.times[maxidx], + hemi="split", + views=["lat", "med"], + subjects_dir=subjects_dir, + size=(600, 600), + clim=dict(kind="value", lims=(-100, -70, -40)), + transparent=True, + colormap="viridis", +) brain = snr_stc.plot(**kwargs) # %% @@ -73,9 +78,8 @@ # Next we do the same for EEG and plot the result on the cortex: evoked_eeg = evoked.copy().pick_types(eeg=True, meg=False) -inv_op_eeg = make_inverse_operator(evoked_eeg.info, fwd, cov, fixed=True, - verbose=True) -stc_eeg = apply_inverse(evoked_eeg, inv_op_eeg, lambda2, 'MNE', verbose=True) +inv_op_eeg = make_inverse_operator(evoked_eeg.info, fwd, cov, fixed=True, verbose=True) +stc_eeg = apply_inverse(evoked_eeg, inv_op_eeg, lambda2, "MNE", verbose=True) snr_stc_eeg = stc_eeg.estimate_snr(evoked_eeg.info, fwd, cov) brain = snr_stc_eeg.plot(**kwargs) diff --git a/examples/inverse/time_frequency_mixed_norm_inverse.py b/examples/inverse/time_frequency_mixed_norm_inverse.py index 2271c58f24c..d69d4769058 100644 --- a/examples/inverse/time_frequency_mixed_norm_inverse.py +++ b/examples/inverse/time_frequency_mixed_norm_inverse.py @@ -32,23 +32,26 @@ from mne.datasets import sample from mne.minimum_norm import make_inverse_operator, apply_inverse from mne.inverse_sparse import tf_mixed_norm, make_stc_from_dipoles -from mne.viz import (plot_sparse_source_estimates, - plot_dipole_locations, plot_dipole_amplitudes) +from mne.viz import ( + plot_sparse_source_estimates, + plot_dipole_locations, + plot_dipole_amplitudes, +) print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ave_fname = meg_path / 'sample_audvis-no-filter-ave.fif' -cov_fname = meg_path / 'sample_audvis-shrunk-cov.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +ave_fname = meg_path / "sample_audvis-no-filter-ave.fif" +cov_fname = meg_path / "sample_audvis-shrunk-cov.fif" # Read noise covariance matrix cov = mne.read_cov(cov_fname) # Handling average file -condition = 'Left visual' +condition = "Left visual" evoked = mne.read_evokeds(ave_fname, condition=condition, baseline=(None, 0)) # We make the window slightly larger than what you'll eventually be 
interested # in ([-0.05, 0.3]) to avoid edge effects. @@ -61,7 +64,7 @@ # Run solver # alpha parameter is between 0 and 100 (100 gives 0 active source) -alpha = 40. # general regularization parameter +alpha = 40.0 # general regularization parameter # l1_ratio parameter between 0 and 1 promotes temporal smoothness # (0 means no temporal regularization) l1_ratio = 0.03 # temporal regularization parameter @@ -69,17 +72,31 @@ loose, depth = 0.2, 0.9 # loose orientation & depth weighting # Compute dSPM solution to be used as weights in MxNE -inverse_operator = make_inverse_operator(evoked.info, forward, cov, - loose=loose, depth=depth) -stc_dspm = apply_inverse(evoked, inverse_operator, lambda2=1. / 9., - method='dSPM') +inverse_operator = make_inverse_operator( + evoked.info, forward, cov, loose=loose, depth=depth +) +stc_dspm = apply_inverse(evoked, inverse_operator, lambda2=1.0 / 9.0, method="dSPM") # Compute TF-MxNE inverse solution with dipole output dipoles, residual = tf_mixed_norm( - evoked, forward, cov, alpha=alpha, l1_ratio=l1_ratio, loose=loose, - depth=depth, maxit=200, tol=1e-6, weights=stc_dspm, weights_min=8., - debias=True, wsize=16, tstep=4, window=0.05, return_as_dipoles=True, - return_residual=True) + evoked, + forward, + cov, + alpha=alpha, + l1_ratio=l1_ratio, + loose=loose, + depth=depth, + maxit=200, + tol=1e-6, + weights=stc_dspm, + weights_min=8.0, + debias=True, + wsize=16, + tstep=4, + window=0.05, + return_as_dipoles=True, + return_residual=True, +) # Crop to remove edges for dip in dipoles: @@ -94,9 +111,14 @@ # %% # Plot location of the strongest dipole with MRI slices idx = np.argmax([np.max(np.abs(dip.amplitude)) for dip in dipoles]) -plot_dipole_locations(dipoles[idx], forward['mri_head_t'], 'sample', - subjects_dir=subjects_dir, mode='orthoview', - idx='amplitude') +plot_dipole_locations( + dipoles[idx], + forward["mri_head_t"], + "sample", + subjects_dir=subjects_dir, + mode="orthoview", + idx="amplitude", +) # # Plot dipole locations of all dipoles with MRI slices: # for dip in dipoles: @@ -107,31 +129,51 @@ # %% # Show the evoked response and the residual for gradiometers ylim = dict(grad=[-120, 120]) -evoked.pick_types(meg='grad', exclude='bads') -evoked.plot(titles=dict(grad='Evoked Response: Gradiometers'), ylim=ylim, - proj=True, time_unit='s') - -residual.pick_types(meg='grad', exclude='bads') -residual.plot(titles=dict(grad='Residuals: Gradiometers'), ylim=ylim, - proj=True, time_unit='s') +evoked.pick_types(meg="grad", exclude="bads") +evoked.plot( + titles=dict(grad="Evoked Response: Gradiometers"), + ylim=ylim, + proj=True, + time_unit="s", +) + +residual.pick_types(meg="grad", exclude="bads") +residual.plot( + titles=dict(grad="Residuals: Gradiometers"), ylim=ylim, proj=True, time_unit="s" +) # %% # Generate stc from dipoles -stc = make_stc_from_dipoles(dipoles, forward['src']) +stc = make_stc_from_dipoles(dipoles, forward["src"]) # %% # View in 2D and 3D ("glass" brain like 3D plot) -plot_sparse_source_estimates(forward['src'], stc, bgcolor=(1, 1, 1), - opacity=0.1, fig_name="TF-MxNE (cond %s)" - % condition, modes=['sphere'], scale_factors=[1.]) - -time_label = 'TF-MxNE time=%0.2f ms' -clim = dict(kind='value', lims=[10e-9, 15e-9, 20e-9]) -brain = stc.plot('sample', 'inflated', 'rh', views='medial', - clim=clim, time_label=time_label, smoothing_steps=5, - subjects_dir=subjects_dir, initial_time=150, time_unit='ms') -brain.add_label("V1", color="yellow", scalar_thresh=.5, borders=True) -brain.add_label("V2", color="red", scalar_thresh=.5, 
borders=True) +plot_sparse_source_estimates( + forward["src"], + stc, + bgcolor=(1, 1, 1), + opacity=0.1, + fig_name="TF-MxNE (cond %s)" % condition, + modes=["sphere"], + scale_factors=[1.0], +) + +time_label = "TF-MxNE time=%0.2f ms" +clim = dict(kind="value", lims=[10e-9, 15e-9, 20e-9]) +brain = stc.plot( + "sample", + "inflated", + "rh", + views="medial", + clim=clim, + time_label=time_label, + smoothing_steps=5, + subjects_dir=subjects_dir, + initial_time=150, + time_unit="ms", +) +brain.add_label("V1", color="yellow", scalar_thresh=0.5, borders=True) +brain.add_label("V2", color="red", scalar_thresh=0.5, borders=True) # %% # References diff --git a/examples/inverse/vector_mne_solution.py b/examples/inverse/vector_mne_solution.py index caba3a46201..2733f40acd1 100644 --- a/examples/inverse/vector_mne_solution.py +++ b/examples/inverse/vector_mne_solution.py @@ -33,34 +33,37 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' +subjects_dir = data_path / "subjects" smoothing_steps = 7 # Read evoked data -meg_path = data_path / 'MEG' / 'sample' -fname_evoked = meg_path / 'sample_audvis-ave.fif' +meg_path = data_path / "MEG" / "sample" +fname_evoked = meg_path / "sample_audvis-ave.fif" evoked = mne.read_evokeds(fname_evoked, condition=0, baseline=(None, 0)) # Read inverse solution -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" inv = read_inverse_operator(fname_inv) # Apply inverse solution, set pick_ori='vector' to obtain a # :class:`mne.VectorSourceEstimate` object snr = 3.0 -lambda2 = 1.0 / snr ** 2 -stc = apply_inverse(evoked, inv, lambda2, 'dSPM', pick_ori='vector') +lambda2 = 1.0 / snr**2 +stc = apply_inverse(evoked, inv, lambda2, "dSPM", pick_ori="vector") # Use peak getter to move visualization to the time point of the peak magnitude -_, peak_time = stc.magnitude().get_peak(hemi='lh') +_, peak_time = stc.magnitude().get_peak(hemi="lh") # %% # Plot the source estimate: # sphinx_gallery_thumbnail_number = 2 brain = stc.plot( - initial_time=peak_time, hemi='lh', subjects_dir=subjects_dir, - smoothing_steps=smoothing_steps) + initial_time=peak_time, + hemi="lh", + subjects_dir=subjects_dir, + smoothing_steps=smoothing_steps, +) # You can save a brain movie with: # brain.save_movie(time_dilation=20, tmin=0.05, tmax=0.16, framerate=10, @@ -69,32 +72,43 @@ # %% # Plot the activation in the direction of maximal power for this data: -stc_max, directions = stc.project('pca', src=inv['src']) +stc_max, directions = stc.project("pca", src=inv["src"]) # These directions must by design be close to the normals because this # inverse was computed with loose=0.2 -print('Absolute cosine similarity between source normals and directions: ' - f'{np.abs(np.sum(directions * inv["source_nn"][2::3], axis=-1)).mean()}') +print( + "Absolute cosine similarity between source normals and directions: " + f'{np.abs(np.sum(directions * inv["source_nn"][2::3], axis=-1)).mean()}' +) brain_max = stc_max.plot( - initial_time=peak_time, hemi='lh', subjects_dir=subjects_dir, - time_label='Max power', smoothing_steps=smoothing_steps) + initial_time=peak_time, + hemi="lh", + subjects_dir=subjects_dir, + time_label="Max power", + smoothing_steps=smoothing_steps, +) # %% # The normal is very similar: -brain_normal = stc.project('normal', inv['src'])[0].plot( - initial_time=peak_time, hemi='lh', subjects_dir=subjects_dir, - time_label='Normal', smoothing_steps=smoothing_steps) +brain_normal = stc.project("normal", 
inv["src"])[0].plot( + initial_time=peak_time, + hemi="lh", + subjects_dir=subjects_dir, + time_label="Normal", + smoothing_steps=smoothing_steps, +) # %% # You can also do this with a fixed-orientation inverse. It looks a lot like # the result above because the ``loose=0.2`` orientation constraint keeps # sources close to fixed orientation: -fname_inv_fixed = ( - meg_path / 'sample_audvis-meg-oct-6-meg-fixed-inv.fif') +fname_inv_fixed = meg_path / "sample_audvis-meg-oct-6-meg-fixed-inv.fif" inv_fixed = read_inverse_operator(fname_inv_fixed) -stc_fixed = apply_inverse( - evoked, inv_fixed, lambda2, 'dSPM', pick_ori='vector') +stc_fixed = apply_inverse(evoked, inv_fixed, lambda2, "dSPM", pick_ori="vector") brain_fixed = stc_fixed.plot( - initial_time=peak_time, hemi='lh', subjects_dir=subjects_dir, - smoothing_steps=smoothing_steps) + initial_time=peak_time, + hemi="lh", + subjects_dir=subjects_dir, + smoothing_steps=smoothing_steps, +) diff --git a/examples/io/elekta_epochs.py b/examples/io/elekta_epochs.py index 8c24902d209..125a1e2c028 100644 --- a/examples/io/elekta_epochs.py +++ b/examples/io/elekta_epochs.py @@ -20,7 +20,7 @@ import os from mne.datasets import multimodal -fname_raw = os.path.join(multimodal.data_path(), 'multimodal_raw.fif') +fname_raw = os.path.join(multimodal.data_path(), "multimodal_raw.fif") print(__doc__) @@ -35,9 +35,9 @@ # %% # Extract epochs corresponding to a category -cond = raw.acqparser.get_condition(raw, 'Auditory right') +cond = raw.acqparser.get_condition(raw, "Auditory right") epochs = mne.Epochs(raw, **cond) -epochs.average().plot_topo(background_color='w') +epochs.average().plot_topo(background_color="w") # %% # Get epochs from all conditions, average @@ -45,10 +45,11 @@ for cat in raw.acqparser.categories: cond = raw.acqparser.get_condition(raw, cat) # copy (supported) rejection parameters from DACQ settings - epochs = mne.Epochs(raw, reject=raw.acqparser.reject, - flat=raw.acqparser.flat, **cond) + epochs = mne.Epochs( + raw, reject=raw.acqparser.reject, flat=raw.acqparser.flat, **cond + ) evoked = epochs.average() - evoked.comment = cat['comment'] + evoked.comment = cat["comment"] evokeds.append(evoked) # save all averages to an evoked fiff file # fname_out = 'multimodal-ave.fif' @@ -57,16 +58,15 @@ # %% # Make a new averaging category newcat = dict() -newcat['comment'] = 'Visual lower left, longer epochs' -newcat['event'] = 3 # reference event -newcat['start'] = -.2 # epoch start rel. to ref. event (in seconds) -newcat['end'] = .7 # epoch end -newcat['reqevent'] = 0 # additional required event; 0 if none -newcat['reqwithin'] = .5 # ...required within .5 s (before or after) -newcat['reqwhen'] = 2 # ...required before (1) or after (2) ref. event -newcat['index'] = 9 # can be set freely +newcat["comment"] = "Visual lower left, longer epochs" +newcat["event"] = 3 # reference event +newcat["start"] = -0.2 # epoch start rel. to ref. event (in seconds) +newcat["end"] = 0.7 # epoch end +newcat["reqevent"] = 0 # additional required event; 0 if none +newcat["reqwithin"] = 0.5 # ...required within .5 s (before or after) +newcat["reqwhen"] = 2 # ...required before (1) or after (2) ref. 
event +newcat["index"] = 9 # can be set freely cond = raw.acqparser.get_condition(raw, newcat) -epochs = mne.Epochs(raw, reject=raw.acqparser.reject, - flat=raw.acqparser.flat, **cond) -epochs.average().plot(time_unit='s') +epochs = mne.Epochs(raw, reject=raw.acqparser.reject, flat=raw.acqparser.flat, **cond) +epochs.average().plot(time_unit="s") diff --git a/examples/io/read_neo_format.py b/examples/io/read_neo_format.py index 43b8a98f876..7847e23dcfa 100644 --- a/examples/io/read_neo_format.py +++ b/examples/io/read_neo_format.py @@ -22,15 +22,15 @@ # demonstrate the steps to using NEO data. For actual data and different file # formats, consult the NEO documentation. -reader = neo.io.ExampleIO('fakedata.nof') +reader = neo.io.ExampleIO("fakedata.nof") block = reader.read(lazy=False)[0] # get the first block -segment = block.segments[0] # get data from first (and only) segment +segment = block.segments[0] # get data from first (and only) segment signals = segment.analogsignals[0] # get first (multichannel) signal -data = signals.rescale('V').magnitude.T +data = signals.rescale("V").magnitude.T sfreq = signals.sampling_rate.magnitude -ch_names = [f'Neo {(idx + 1):02}' for idx in range(signals.shape[1])] -ch_types = ['eeg'] * len(ch_names) # if not specified, type 'misc' is assumed +ch_names = [f"Neo {(idx + 1):02}" for idx in range(signals.shape[1])] +ch_types = ["eeg"] * len(ch_names) # if not specified, type 'misc' is assumed info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) raw = mne.io.RawArray(data, info) diff --git a/examples/io/read_noise_covariance_matrix.py b/examples/io/read_noise_covariance_matrix.py index 57b0d314e25..ba9e126a4ea 100644 --- a/examples/io/read_noise_covariance_matrix.py +++ b/examples/io/read_noise_covariance_matrix.py @@ -17,8 +17,8 @@ from mne.datasets import sample data_path = sample.data_path() -fname_cov = data_path / 'MEG' / 'sample' / 'sample_audvis-cov.fif' -fname_evo = data_path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' +fname_cov = data_path / "MEG" / "sample" / "sample_audvis-cov.fif" +fname_evo = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" cov = mne.read_cov(fname_cov) print(cov) @@ -27,4 +27,4 @@ # %% # Plot covariance -cov.plot(ev_info, exclude='bads', show_svd=False) +cov.plot(ev_info, exclude="bads", show_svd=False) diff --git a/examples/io/read_xdf.py b/examples/io/read_xdf.py index d65784d85ad..1edc8faf2e6 100644 --- a/examples/io/read_xdf.py +++ b/examples/io/read_xdf.py @@ -22,15 +22,13 @@ import mne from mne.datasets import misc -fname = ( - misc.data_path() / 'xdf' / - 'sub-P001_ses-S004_task-Default_run-001_eeg_a2.xdf') +fname = misc.data_path() / "xdf" / "sub-P001_ses-S004_task-Default_run-001_eeg_a2.xdf" streams, header = pyxdf.load_xdf(fname) data = streams[0]["time_series"].T assert data.shape[0] == 5 # four raw EEG plus one stim channel data[:4:2] -= data[1:4:2] # subtract (rereference) to get two bipolar EEG data = data[::2] # subselect -data[:2] *= (1e-6 / 50 / 2) # uV -> V and preamp gain +data[:2] *= 1e-6 / 50 / 2 # uV -> V and preamp gain sfreq = float(streams[0]["info"]["nominal_srate"][0]) info = mne.create_info(3, sfreq, ["eeg", "eeg", "stim"]) raw = mne.io.RawArray(data, info) diff --git a/examples/preprocessing/contralateral_referencing.py b/examples/preprocessing/contralateral_referencing.py index 2c04ccc7c8f..c3aff2afe16 100644 --- a/examples/preprocessing/contralateral_referencing.py +++ b/examples/preprocessing/contralateral_referencing.py @@ -15,28 +15,24 @@ import mne ssvep_folder 
= mne.datasets.ssvep.data_path() -ssvep_data_raw_path = (ssvep_folder / 'sub-02' / 'ses-01' / 'eeg' / - 'sub-02_ses-01_task-ssvep_eeg.vhdr') +ssvep_data_raw_path = ( + ssvep_folder / "sub-02" / "ses-01" / "eeg" / "sub-02_ses-01_task-ssvep_eeg.vhdr" +) raw = mne.io.read_raw(ssvep_data_raw_path, preload=True) -_ = raw.set_montage('easycap-M1') +_ = raw.set_montage("easycap-M1") # %% # The electrodes TP9 and TP10 are near the mastoids so we'll use them as our # contralateral reference channels. Then we'll create our hemisphere groups. -raw.rename_channels({ - 'TP9': 'M1', - 'TP10': 'M2' -}) +raw.rename_channels({"TP9": "M1", "TP10": "M2"}) # this splits electrodes into 3 groups; left, midline, and right -ch_names = mne.channels.make_1020_channel_selections( - raw.info, return_ch_names=True -) +ch_names = mne.channels.make_1020_channel_selections(raw.info, return_ch_names=True) # remove the ref channels from the lists of to-be-rereferenced channels -ch_names['Left'].remove('M1') -ch_names['Right'].remove('M2') +ch_names["Left"].remove("M1") +ch_names["Right"].remove("M2") # %% # Finally we do the referencing. For the midline channels we'll reference them @@ -44,25 +40,23 @@ # reference to the single contralateral mastoid channel. # midline referencing to mean of mastoids: -mastoids = ['M1', 'M2'] -rereferenced_midline_chs = (raw.copy() - .pick(mastoids + ch_names['Midline']) - .set_eeg_reference(mastoids) - .drop_channels(mastoids) - ) +mastoids = ["M1", "M2"] +rereferenced_midline_chs = ( + raw.copy() + .pick(mastoids + ch_names["Midline"]) + .set_eeg_reference(mastoids) + .drop_channels(mastoids) +) # contralateral referencing (alters channels in `raw` in-place): -for ref, hemi in dict(M2=ch_names['Left'], M1=ch_names['Right']).items(): - mne.set_bipolar_reference( - raw, anode=hemi, cathode=[ref] * len(hemi), copy=False - ) +for ref, hemi in dict(M2=ch_names["Left"], M1=ch_names["Right"]).items(): + mne.set_bipolar_reference(raw, anode=hemi, cathode=[ref] * len(hemi), copy=False) # strip off '-M1' and '-M2' suffixes added to each bipolar-referenced channel -raw.rename_channels(lambda ch_name: ch_name.split('-')[0]) +raw.rename_channels(lambda ch_name: ch_name.split("-")[0]) # replace unreferenced midline with rereferenced midline -_ = (raw.drop_channels(ch_names['Midline']) - .add_channels([rereferenced_midline_chs])) +_ = raw.drop_channels(ch_names["Midline"]).add_channels([rereferenced_midline_chs]) # %% # Make sure the channel locations still look right: -fig = raw.plot_sensors(show_names=True, sphere='eeglab') +fig = raw.plot_sensors(show_names=True, sphere="eeglab") diff --git a/examples/preprocessing/css.py b/examples/preprocessing/css.py index 2631dc54d23..73e86c1b389 100644 --- a/examples/preprocessing/css.py +++ b/examples/preprocessing/css.py @@ -27,26 +27,25 @@ ############################################################################### # Load sample subject data data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ave_fname = meg_path / 'sample_audvis-no-filter-ave.fif' -cov_fname = meg_path / 'sample_audvis-cov.fif' -trans_fname = meg_path / 'sample_audvis_raw-trans.fif' -bem_fname = subjects_dir / 'sample' / 'bem' / '/sample-5120-bem-sol.fif' - -raw = mne.io.read_raw_fif(meg_path / 'sample_audvis_raw.fif') +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" 
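# [Editor's note] The ``bem_fname`` line a few lines below joins a path
# component with a leading slash ("/sample-5120-bem-sol.fif"). pathlib treats
# an absolute component as a reset and discards everything joined before it
# (at least on POSIX), so the resulting path would not live under
# ``subjects_dir``. A minimal, self-contained demonstration of the gotcha:
from pathlib import Path

assert Path("/a") / "b" / "/c" == Path("/c")  # the absolute component wins
# The intended file is presumably
# subjects_dir / "sample" / "bem" / "sample-5120-bem-sol.fif"
# (without the leading slash).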
+ave_fname = meg_path / "sample_audvis-no-filter-ave.fif" +cov_fname = meg_path / "sample_audvis-cov.fif" +trans_fname = meg_path / "sample_audvis_raw-trans.fif" +bem_fname = subjects_dir / "sample" / "bem" / "/sample-5120-bem-sol.fif" + +raw = mne.io.read_raw_fif(meg_path / "sample_audvis_raw.fif") fwd = mne.read_forward_solution(fwd_fname) fwd = mne.convert_forward_solution(fwd, force_fixed=True, surf_ori=True) -fwd = mne.pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info['bads']) +fwd = mne.pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info["bads"]) cov = mne.read_cov(cov_fname) ############################################################################### # Find patches (labels) to activate -all_labels = mne.read_labels_from_annot(subject='sample', - subjects_dir=subjects_dir) +all_labels = mne.read_labels_from_annot(subject="sample", subjects_dir=subjects_dir) labels = [] -for select_label in ['parahippocampal-lh', 'postcentral-rh']: +for select_label in ["parahippocampal-lh", "postcentral-rh"]: labels.append([lab for lab in all_labels if lab.name in select_label][0]) hiplab, postcenlab = labels @@ -64,32 +63,38 @@ def subcortical_waveform(times): return 10e-9 * np.cos(times * 2 * np.pi * 239) -times = np.linspace(0, 0.5, int(0.5 * raw.info['sfreq'])) -stc = simulate_sparse_stc(fwd['src'], n_dipoles=2, times=times, - location='center', subjects_dir=subjects_dir, - labels=[postcenlab, hiplab], - data_fun=cortical_waveform) -stc.data[np.where(np.isin(stc.vertices[0], hiplab.vertices))[0], :] = \ - subcortical_waveform(times) +times = np.linspace(0, 0.5, int(0.5 * raw.info["sfreq"])) +stc = simulate_sparse_stc( + fwd["src"], + n_dipoles=2, + times=times, + location="center", + subjects_dir=subjects_dir, + labels=[postcenlab, hiplab], + data_fun=cortical_waveform, +) +stc.data[ + np.where(np.isin(stc.vertices[0], hiplab.vertices))[0], : +] = subcortical_waveform(times) evoked = simulate_evoked(fwd, stc, raw.info, cov, nave=15) ############################################################################### # Process with CSS and plot PSD of EEG data before and after processing -evoked_subcortical = mne.preprocessing.cortical_signal_suppression(evoked, - n_proj=6) +evoked_subcortical = mne.preprocessing.cortical_signal_suppression(evoked, n_proj=6) chs = mne.pick_types(evoked.info, meg=False, eeg=True) -psd = np.mean(np.abs(np.fft.rfft(evoked.data))**2, axis=0) -psd_proc = np.mean(np.abs(np.fft.rfft(evoked_subcortical.data))**2, axis=0) -freq = np.arange(0, stop=int(evoked.info['sfreq'] / 2), - step=evoked.info['sfreq'] / (2 * len(psd))) +psd = np.mean(np.abs(np.fft.rfft(evoked.data)) ** 2, axis=0) +psd_proc = np.mean(np.abs(np.fft.rfft(evoked_subcortical.data)) ** 2, axis=0) +freq = np.arange( + 0, stop=int(evoked.info["sfreq"] / 2), step=evoked.info["sfreq"] / (2 * len(psd)) +) fig, ax = plt.subplots() -ax.plot(freq, psd, label='raw') -ax.plot(freq, psd_proc, label='processed') -ax.text(.2, .7, 'cortical', transform=ax.transAxes) -ax.text(.8, .25, 'subcortical', transform=ax.transAxes) -ax.set(ylabel='EEG Power spectral density', xlabel='Frequency (Hz)') +ax.plot(freq, psd, label="raw") +ax.plot(freq, psd_proc, label="processed") +ax.text(0.2, 0.7, "cortical", transform=ax.transAxes) +ax.text(0.8, 0.25, "subcortical", transform=ax.transAxes) +ax.set(ylabel="EEG Power spectral density", xlabel="Frequency (Hz)") ax.legend() # References diff --git a/examples/preprocessing/define_target_events.py b/examples/preprocessing/define_target_events.py index 
f35b16743d9..51e0fdbb960 100644 --- a/examples/preprocessing/define_target_events.py +++ b/examples/preprocessing/define_target_events.py @@ -33,9 +33,9 @@ # %% # Set parameters -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) @@ -43,25 +43,33 @@ # Set up pick list: EEG + STI 014 - bad channels (modify to your needs) include = [] # or stim channels ['STI 014'] -raw.info['bads'] += ['EEG 053'] # bads +raw.info["bads"] += ["EEG 053"] # bads # pick MEG channels -picks = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False, eog=True, - include=include, exclude='bads') +picks = mne.pick_types( + raw.info, + meg="mag", + eeg=False, + stim=False, + eog=True, + include=include, + exclude="bads", +) # %% # Find stimulus event followed by quick button presses reference_id = 5 # presentation of a smiley face target_id = 32 # button press -sfreq = raw.info['sfreq'] # sampling rate +sfreq = raw.info["sfreq"] # sampling rate tmin = 0.1 # trials leading to very early responses will be rejected tmax = 0.59 # ignore face stimuli followed by button press later than 590 ms new_id = 42 # the new event id for a hit. If None, reference_id is used. fill_na = 99 # the fill value for misses -events_, lag = define_target_events(events, reference_id, target_id, - sfreq, tmin, tmax, new_id, fill_na) +events_, lag = define_target_events( + events, reference_id, target_id, sfreq, tmin, tmax, new_id, fill_na +) print(events_) # The 99 indicates missing or too late button presses @@ -77,9 +85,16 @@ tmax_ = 0.4 event_id = dict(early=new_id, late=fill_na) -epochs = mne.Epochs(raw, events_, event_id, tmin_, - tmax_, picks=picks, baseline=(None, 0), - reject=dict(mag=4e-12)) +epochs = mne.Epochs( + raw, + events_, + event_id, + tmin_, + tmax_, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12), +) # average epochs and get an Evoked dataset. @@ -89,11 +104,11 @@ # View evoked response times = 1e3 * epochs.times # time in milliseconds -title = 'Evoked response followed by %s button press' +title = "Evoked response followed by %s button press" fig, axes = plt.subplots(2, 1) -early.plot(axes=axes[0], time_unit='s') -axes[0].set(title=title % 'late', ylabel='Evoked field (fT)') -late.plot(axes=axes[1], time_unit='s') -axes[1].set(title=title % 'early', ylabel='Evoked field (fT)') +early.plot(axes=axes[0], time_unit="s") +axes[0].set(title=title % "late", ylabel="Evoked field (fT)") +late.plot(axes=axes[1], time_unit="s") +axes[1].set(title=title % "early", ylabel="Evoked field (fT)") plt.show() diff --git a/examples/preprocessing/eeg_bridging.py b/examples/preprocessing/eeg_bridging.py index fa94e752c71..09319a2cdea 100644 --- a/examples/preprocessing/eeg_bridging.py +++ b/examples/preprocessing/eeg_bridging.py @@ -60,11 +60,11 @@ # bridging so using the last segment of the data will # give the most conservative estimate. 
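# [Editor's note] "Electrical distance" here is, conceptually, the variance of
# the difference signal between two channels: a bridged pair records nearly
# identical voltages, so the variance of their difference collapses toward
# zero. A minimal sketch of that idea (``data`` is an assumed
# (n_channels, n_times) array in volts; the example itself relies on
# mne.preprocessing.compute_bridged_electrodes, which computes this per epoch
# and takes the median across epochs):
import itertools

import numpy as np


def electrical_distances(data):
    """Pairwise variance of channel differences, in uV^2 (upper triangle)."""
    n_channels = data.shape[0]
    ed = np.full((n_channels, n_channels), np.nan)
    for ii, jj in itertools.combinations(range(n_channels), 2):
        ed[ii, jj] = np.var(1e6 * (data[ii] - data[jj]))  # V -> uV, then var
    return ed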
-montage = mne.channels.make_standard_montage('standard_1005') +montage = mne.channels.make_standard_montage("standard_1005") ed_data = dict() # electrical distance/bridging data raw_data = dict() # store infos for electrode positions for sub in range(1, 11): - print(f'Computing electrode bridges for subject {sub}') + print(f"Computing electrode bridges for subject {sub}") raw_fname = mne.datasets.eegbci.load_data(subject=sub, runs=(1,))[0] raw = mne.io.read_raw(raw_fname, preload=True, verbose=False) mne.datasets.eegbci.standardize(raw) # set channel names @@ -89,7 +89,7 @@ bridged_idx, ed_matrix = ed_data[6] fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) -fig.suptitle('Subject 6 Electrical Distance Matrix') +fig.suptitle("Subject 6 Electrical Distance Matrix") # take median across epochs, only use upper triangular, lower is NaNs ed_plot = np.zeros(ed_matrix.shape[1:]) * np.nan @@ -98,17 +98,17 @@ ed_plot[idx0, idx1] = np.nanmedian(ed_matrix[:, idx0, idx1]) # plot full distribution color range -im1 = ax1.imshow(ed_plot, aspect='auto') +im1 = ax1.imshow(ed_plot, aspect="auto") cax1 = fig.colorbar(im1, ax=ax1) -cax1.set_label(r'Electrical Distance ($\mu$$V^2$)') +cax1.set_label(r"Electrical Distance ($\mu$$V^2$)") # plot zoomed in colors -im2 = ax2.imshow(ed_plot, aspect='auto', vmax=5) +im2 = ax2.imshow(ed_plot, aspect="auto", vmax=5) cax2 = fig.colorbar(im2, ax=ax2) -cax2.set_label(r'Electrical Distance ($\mu$$V^2$)') +cax2.set_label(r"Electrical Distance ($\mu$$V^2$)") for ax in (ax1, ax2): - ax.set_xlabel('Channel Index') - ax.set_ylabel('Channel Index') + ax.set_xlabel("Channel Index") + ax.set_ylabel("Channel Index") fig.tight_layout() @@ -125,10 +125,10 @@ # without bridged electrodes do not have a peak near zero. fig, ax = plt.subplots(figsize=(5, 5)) -fig.suptitle('Subject 6 Electrical Distance Matrix Distribution') +fig.suptitle("Subject 6 Electrical Distance Matrix Distribution") ax.hist(ed_matrix[~np.isnan(ed_matrix)], bins=np.linspace(0, 500, 51)) -ax.set_xlabel(r'Electrical Distance ($\mu$$V^2$)') -ax.set_ylabel('Count (channel pairs for all epochs)') +ax.set_xlabel(r"Electrical Distance ($\mu$$V^2$)") +ax.set_ylabel("Count (channel pairs for all epochs)") # %% # Plot Electrical Distances on a Topomap @@ -145,8 +145,12 @@ # may have inserted the gel syringe tip in too far). mne.viz.plot_bridged_electrodes( - raw_data[6].info, bridged_idx, ed_matrix, - title='Subject 6 Bridged Electrodes', topomap_args=dict(vlim=(None, 5))) + raw_data[6].info, + bridged_idx, + ed_matrix, + title="Subject 6 Bridged Electrodes", + topomap_args=dict(vlim=(None, 5)), +) # %% # Plot the Raw Voltage Time Series for Bridged Electrodes @@ -160,18 +164,30 @@ # pairs, meaning that it is unlikely that all four of these electrodes are # bridged. 
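# [Editor's note] A quick numeric version of the visual check built below: for
# a truly bridged pair, the variance of the difference trace is tiny. The
# ~5 uV^2 cutoff used here is an editorial rule of thumb from the bridging
# literature, not an MNE constant:
pair = raw_data[6].get_data(picks=["FC2", "FC4"])
pair_var = np.var(1e6 * (pair[0] - pair[1]))
print(f"var(FC2 - FC4) = {pair_var:.2f} uV^2 (bridging candidate if < ~5)")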
-raw = raw_data[6].copy().pick_channels(['FC2', 'FC4', 'F2', 'F4']) -raw.add_channels([mne.io.RawArray( - raw.get_data(ch1) - raw.get_data(ch2), - mne.create_info([f'{ch1}-{ch2}'], raw.info['sfreq'], 'eeg'), - raw.first_samp) for ch1, ch2 in [('F2', 'F4'), ('FC2', 'FC4')]]) +raw = raw_data[6].copy().pick_channels(["FC2", "FC4", "F2", "F4"]) +raw.add_channels( + [ + mne.io.RawArray( + raw.get_data(ch1) - raw.get_data(ch2), + mne.create_info([f"{ch1}-{ch2}"], raw.info["sfreq"], "eeg"), + raw.first_samp, + ) + for ch1, ch2 in [("F2", "F4"), ("FC2", "FC4")] + ] +) raw.plot(duration=20, scalings=dict(eeg=2e-4)) -raw = raw_data[1].copy().pick_channels(['FC2', 'FC4', 'F2', 'F4']) -raw.add_channels([mne.io.RawArray( - raw.get_data(ch1) - raw.get_data(ch2), - mne.create_info([f'{ch1}-{ch2}'], raw.info['sfreq'], 'eeg'), - raw.first_samp) for ch1, ch2 in [('F2', 'F4'), ('FC2', 'FC4')]]) +raw = raw_data[1].copy().pick_channels(["FC2", "FC4", "F2", "F4"]) +raw.add_channels( + [ + mne.io.RawArray( + raw.get_data(ch1) - raw.get_data(ch2), + mne.create_info([f"{ch1}-{ch2}"], raw.info["sfreq"], "eeg"), + raw.first_samp, + ) + for ch1, ch2 in [("F2", "F4"), ("FC2", "FC4")] + ] +) raw.plot(duration=20, scalings=dict(eeg=2e-4)) # %% @@ -193,23 +209,25 @@ # distance from the sensors to the brain). fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) -fig.suptitle('Electrical Distance Distribution for EEGBCI Subjects') +fig.suptitle("Electrical Distance Distribution for EEGBCI Subjects") for ax in (ax1, ax2): - ax.set_ylabel('Count') - ax.set_xlabel(r'Electrical Distance ($\mu$$V^2$)') + ax.set_ylabel("Count") + ax.set_xlabel(r"Electrical Distance ($\mu$$V^2$)") for sub, (bridged_idx, ed_matrix) in ed_data.items(): # ed_matrix is upper triangular so exclude bottom half of NaNs - hist, edges = np.histogram(ed_matrix[~np.isnan(ed_matrix)].flatten(), - bins=np.linspace(0, 1000, 101)) + hist, edges = np.histogram( + ed_matrix[~np.isnan(ed_matrix)].flatten(), bins=np.linspace(0, 1000, 101) + ) centers = (edges[1:] + edges[:-1]) / 2 ax1.plot(centers, hist) - hist, edges = np.histogram(ed_matrix[~np.isnan(ed_matrix)].flatten(), - bins=np.linspace(0, 30, 21)) + hist, edges = np.histogram( + ed_matrix[~np.isnan(ed_matrix)].flatten(), bins=np.linspace(0, 30, 21) + ) centers = (edges[1:] + edges[:-1]) / 2 - ax2.plot(centers, hist, label=f'Sub {sub} #={len(bridged_idx)}') + ax2.plot(centers, hist, label=f"Sub {sub} #={len(bridged_idx)}") -ax1.axvspan(0, 30, color='r', alpha=0.5) +ax1.axvspan(0, 30, color="r", alpha=0.5) ax2.legend(loc=(1.04, 0)) fig.subplots_adjust(right=0.725, bottom=0.15, wspace=0.4) @@ -223,9 +241,12 @@ for sub, (bridged_idx, ed_matrix) in ed_data.items(): mne.viz.plot_bridged_electrodes( - raw_data[sub].info, bridged_idx, ed_matrix, - title=f'Subject {sub} Bridged Electrodes', - topomap_args=dict(vlim=(None, 5))) + raw_data[sub].info, + bridged_idx, + ed_matrix, + title=f"Subject {sub} Bridged Electrodes", + topomap_args=dict(vlim=(None, 5)), + ) # %% # For subjects with many bridged channels like Subject 6 shown in the example @@ -242,7 +263,8 @@ # use subject 2, only one bridged electrode pair bridged_idx = ed_data[2][0] raw = mne.preprocessing.interpolate_bridged_electrodes( - raw_data[2].copy(), bridged_idx=bridged_idx) + raw_data[2].copy(), bridged_idx=bridged_idx +) # %% # Let's make sure that our virtual channel aided the interpolation. 
We can do @@ -274,41 +296,73 @@ bridged_data[0] += 1e-7 * rng.normal(size=raw.times.size) bridged_data[1] += 1e-7 * rng.normal(size=raw.times.size) # add back simulated data -raw_sim = raw_sim.add_channels([mne.io.RawArray( - bridged_data, mne.create_info([ch0, ch1], raw.info['sfreq'], 'eeg'), - raw.first_samp)]) +raw_sim = raw_sim.add_channels( + [ + mne.io.RawArray( + bridged_data, + mne.create_info([ch0, ch1], raw.info["sfreq"], "eeg"), + raw.first_samp, + ) + ] +) raw_sim.set_montage(montage) # add back channel positions # use virtual channel method raw_virtual = mne.preprocessing.interpolate_bridged_electrodes( - raw_sim.copy(), bridged_idx=bridged_idx_simulated) + raw_sim.copy(), bridged_idx=bridged_idx_simulated +) data_virtual = raw_virtual.get_data(picks=(idx0, idx1)) # set bads to be bridged electrodes to interpolate without a virtual channel raw_comp = raw_sim.copy() -raw_comp.info['bads'] = [raw_sim.ch_names[idx0], raw_sim.ch_names[idx1]] +raw_comp.info["bads"] = [raw_sim.ch_names[idx0], raw_sim.ch_names[idx1]] raw_comp.interpolate_bads() data_comp = raw_comp.get_data(picks=(idx0, idx1)) # compute variance of residuals -print('Variance of residual (interpolated data - original data)\n\n' - 'With adding virtual channel: {}\n' - 'Compared to interpolation only using other channels: {}' - ''.format(np.mean(np.var(data_virtual - data_orig, axis=1)), - np.mean(np.var(data_comp - data_orig, axis=1)))) +print( + "Variance of residual (interpolated data - original data)\n\n" + "With adding virtual channel: {}\n" + "Compared to interpolation only using other channels: {}" + "".format( + np.mean(np.var(data_virtual - data_orig, axis=1)), + np.mean(np.var(data_comp - data_orig, axis=1)), + ) +) # plot results raw = raw.pick_channels([ch0, ch1]) -raw = raw.add_channels([mne.io.RawArray( - np.concatenate([data_virtual, data_virtual - data_orig]), - mne.create_info([f'{ch0} virtual', f'{ch1} virtual', - f'{ch0} virtual diff', f'{ch1} virtual diff'], - raw.info['sfreq'], 'eeg'), raw.first_samp)]) -raw = raw.add_channels([mne.io.RawArray( - np.concatenate([data_comp, data_comp - data_orig]), - mne.create_info([f'{ch0} comp', f'{ch1} comp', - f'{ch0} comp diff', f'{ch1} comp diff'], - raw.info['sfreq'], 'eeg'), raw.first_samp)]) +raw = raw.add_channels( + [ + mne.io.RawArray( + np.concatenate([data_virtual, data_virtual - data_orig]), + mne.create_info( + [ + f"{ch0} virtual", + f"{ch1} virtual", + f"{ch0} virtual diff", + f"{ch1} virtual diff", + ], + raw.info["sfreq"], + "eeg", + ), + raw.first_samp, + ) + ] +) +raw = raw.add_channels( + [ + mne.io.RawArray( + np.concatenate([data_comp, data_comp - data_orig]), + mne.create_info( + [f"{ch0} comp", f"{ch1} comp", f"{ch0} comp diff", f"{ch1} comp diff"], + raw.info["sfreq"], + "eeg", + ), + raw.first_samp, + ) + ] +) raw.plot(scalings=dict(eeg=7e-5)) # %% @@ -332,17 +386,26 @@ raw = raw_data[1] # typically impedances < 25 kOhm are acceptable for active systems and # impedances < 5 kOhm are desirable for a passive system -impedances = rng.random((len(raw.ch_names,))) * 30 +impedances = ( + rng.random( + ( + len( + raw.ch_names, + ) + ) + ) + * 30 +) impedances[10] = 80 # set a few bad impendances impedances[25] = 99 -cmap = LinearSegmentedColormap.from_list(name='impedance_cmap', - colors=['g', 'y', 'r'], N=256) +cmap = LinearSegmentedColormap.from_list( + name="impedance_cmap", colors=["g", "y", "r"], N=256 +) fig, ax = plt.subplots(figsize=(5, 5)) -im, cn = mne.viz.plot_topomap(impedances, raw.info, axes=ax, - cmap=cmap, vlim=(25, 75)) 
-ax.set_title('Electrode Impendances') +im, cn = mne.viz.plot_topomap(impedances, raw.info, axes=ax, cmap=cmap, vlim=(25, 75)) +ax.set_title("Electrode Impendances") cax = fig.colorbar(im, ax=ax) -cax.set_label(r'Impedance (k$\Omega$)') +cax.set_label(r"Impedance (k$\Omega$)") # %% # Summary diff --git a/examples/preprocessing/eeg_csd.py b/examples/preprocessing/eeg_csd.py index 24f33b91e53..7bb19415eaa 100644 --- a/examples/preprocessing/eeg_csd.py +++ b/examples/preprocessing/eeg_csd.py @@ -32,10 +32,11 @@ # %% # Load sample subject data -meg_path = data_path / 'MEG' / 'sample' -raw = mne.io.read_raw_fif(meg_path / 'sample_audvis_raw.fif') -raw = raw.pick_types(meg=False, eeg=True, eog=True, ecg=True, stim=True, - exclude=raw.info['bads']).load_data() +meg_path = data_path / "MEG" / "sample" +raw = mne.io.read_raw_fif(meg_path / "sample_audvis_raw.fif") +raw = raw.pick_types( + meg=False, eeg=True, eog=True, ecg=True, stim=True, exclude=raw.info["bads"] +).load_data() events = mne.find_events(raw) raw.set_eeg_reference(projection=True).apply_proj() @@ -56,19 +57,24 @@ # CSD can also be computed on Evoked (averaged) data. # Here we epoch and average the data so we can demonstrate that. -event_id = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3, - 'visual/right': 4, 'smiley': 5, 'button': 32} -epochs = mne.Epochs(raw, events, event_id=event_id, tmin=-0.2, tmax=.5, - preload=True) -evoked = epochs['auditory'].average() +event_id = { + "auditory/left": 1, + "auditory/right": 2, + "visual/left": 3, + "visual/right": 4, + "smiley": 5, + "button": 32, +} +epochs = mne.Epochs(raw, events, event_id=event_id, tmin=-0.2, tmax=0.5, preload=True) +evoked = epochs["auditory"].average() # %% # First let's look at how CSD affects scalp topography: -times = np.array([-0.1, 0., 0.05, 0.1, 0.15]) +times = np.array([-0.1, 0.0, 0.05, 0.1, 0.15]) evoked_csd = mne.preprocessing.compute_current_source_density(evoked) -evoked.plot_joint(title='Average Reference', show=False) -evoked_csd.plot_joint(title='Current Source Density') +evoked.plot_joint(title="Average Reference", show=False) +evoked_csd.plot_joint(title="Current Source Density") # %% # CSD has parameters ``stiffness`` and ``lambda2`` affecting smoothing and @@ -80,11 +86,12 @@ for i, lambda2 in enumerate([0, 1e-7, 1e-5, 1e-3]): for j, m in enumerate([5, 4, 3, 2]): this_evoked_csd = mne.preprocessing.compute_current_source_density( - evoked, stiffness=m, lambda2=lambda2) + evoked, stiffness=m, lambda2=lambda2 + ) this_evoked_csd.plot_topomap( - 0.1, axes=ax[i, j], contours=4, time_unit='s', - colorbar=False, show=False) - ax[i, j].set_title('stiffness=%i\nλ²=%s' % (m, lambda2)) + 0.1, axes=ax[i, j], contours=4, time_unit="s", colorbar=False, show=False + ) + ax[i, j].set_title("stiffness=%i\nλ²=%s" % (m, lambda2)) # %% # References diff --git a/examples/preprocessing/eog_artifact_histogram.py b/examples/preprocessing/eog_artifact_histogram.py index a6a3e895b3c..6953e8a8ed3 100644 --- a/examples/preprocessing/eog_artifact_histogram.py +++ b/examples/preprocessing/eog_artifact_histogram.py @@ -28,24 +28,24 @@ # %% # Set parameters -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" # Setup for reading the raw data raw = io.read_raw_fif(raw_fname, preload=True) -events = mne.find_events(raw, 'STI 014') +events = mne.find_events(raw, "STI 014") eog_event_id = 512 eog_events = 
mne.preprocessing.find_eog_events(raw, eog_event_id) -raw.add_events(eog_events, 'STI 014') +raw.add_events(eog_events, "STI 014") # Read epochs picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=True, eog=False) tmin, tmax = -0.2, 0.5 -event_ids = {'AudL': 1, 'AudR': 2, 'VisL': 3, 'VisR': 4} +event_ids = {"AudL": 1, "AudR": 2, "VisL": 3, "VisR": 4} epochs = mne.Epochs(raw, events, event_ids, tmin, tmax, picks=picks) # Get the stim channel data -pick_ch = mne.pick_channels(epochs.ch_names, ['STI 014'])[0] +pick_ch = mne.pick_channels(epochs.ch_names, ["STI 014"])[0] data = epochs.get_data()[:, pick_ch, :] data = np.sum((data.astype(int) & eog_event_id) == eog_event_id, axis=0) @@ -53,6 +53,5 @@ # Plot EOG artifact distribution fig, ax = plt.subplots() ax.stem(1e3 * epochs.times, data) -ax.set(xlabel='Times (ms)', - ylabel='Blink counts (from %s trials)' % len(epochs)) +ax.set(xlabel="Times (ms)", ylabel="Blink counts (from %s trials)" % len(epochs)) fig.tight_layout() diff --git a/examples/preprocessing/eog_regression.py b/examples/preprocessing/eog_regression.py index 1d7f6879b9a..6c88cb01d9a 100644 --- a/examples/preprocessing/eog_regression.py +++ b/examples/preprocessing/eog_regression.py @@ -30,14 +30,14 @@ print(__doc__) data_path = sample.data_path() -raw_fname = data_path / 'MEG' / 'sample' / 'sample_audvis_filt-0-40_raw.fif' +raw_fname = data_path / "MEG" / "sample" / "sample_audvis_filt-0-40_raw.fif" # Read raw data raw = mne.io.read_raw_fif(raw_fname, preload=True) -events = mne.find_events(raw, 'STI 014') +events = mne.find_events(raw, "STI 014") # Highpass filter to eliminate slow drifts -raw.filter(0.3, None, picks='all') +raw.filter(0.3, None, picks="all") # %% # Perform regression and remove EOG @@ -57,21 +57,22 @@ # is best visualized by extracting epochs and plotting the evoked potential. 
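# [Editor's note] ``raw_clean`` below is produced by the regression step in
# the unchanged code between these hunks. In recent MNE versions that step
# looks roughly like the following sketch (details such as referencing and
# plotting of the regression weights are omitted here):
from mne.preprocessing import EOGRegression

model = EOGRegression(picks="eeg", picks_artifact="eog").fit(raw)
raw_clean = model.apply(raw)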
tmin, tmax = -0.1, 0.5 -event_id = {'visual/left': 3, 'visual/right': 4} -evoked_before = mne.Epochs(raw, events, event_id, tmin, tmax, - baseline=(tmin, 0)).average() -evoked_after = mne.Epochs(raw_clean, events, event_id, tmin, tmax, - baseline=(tmin, 0)).average() +event_id = {"visual/left": 3, "visual/right": 4} +evoked_before = mne.Epochs( + raw, events, event_id, tmin, tmax, baseline=(tmin, 0) +).average() +evoked_after = mne.Epochs( + raw_clean, events, event_id, tmin, tmax, baseline=(tmin, 0) +).average() # Create epochs after EOG correction -epochs_after = mne.Epochs(raw_clean, events, event_id, tmin, tmax, - baseline=(tmin, 0)) +epochs_after = mne.Epochs(raw_clean, events, event_id, tmin, tmax, baseline=(tmin, 0)) evoked_after = epochs_after.average() -fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(10, 7), - sharex=True, sharey='row') +fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(10, 7), sharex=True, sharey="row") evoked_before.plot(axes=ax[:, 0], spatial_colors=True) evoked_after.plot(axes=ax[:, 1], spatial_colors=True) -fig.subplots_adjust(top=0.905, bottom=0.09, left=0.08, right=0.975, - hspace=0.325, wspace=0.145) -fig.suptitle('Before --> After') +fig.subplots_adjust( + top=0.905, bottom=0.09, left=0.08, right=0.975, hspace=0.325, wspace=0.145 +) +fig.suptitle("Before --> After") diff --git a/examples/preprocessing/find_ref_artifacts.py b/examples/preprocessing/find_ref_artifacts.py index f3781a0c1cc..8a08658a174 100644 --- a/examples/preprocessing/find_ref_artifacts.py +++ b/examples/preprocessing/find_ref_artifacts.py @@ -45,7 +45,7 @@ # %% # Read raw data, cropping to 5 minutes to save memory -raw_fname = data_path / 'sample_reference_MEG_noise-raw.fif' +raw_fname = data_path / "sample_reference_MEG_noise-raw.fif" raw = io.read_raw_fif(raw_fname).crop(300, 600).load_data() # %% @@ -53,11 +53,17 @@ # been applied to these data, much of the noise in the reference channels # (bottom of the plot) can still be seen in the standard channels. select_picks = np.concatenate( - (mne.pick_types(raw.info, meg=True)[-32:], - mne.pick_types(raw.info, meg=False, ref_meg=True))) + ( + mne.pick_types(raw.info, meg=True)[-32:], + mne.pick_types(raw.info, meg=False, ref_meg=True), + ) +) plot_kwargs = dict( - duration=100, order=select_picks, n_channels=len(select_picks), - scalings={"mag": 8e-13, "ref_meg": 2e-11}) + duration=100, + order=select_picks, + n_channels=len(select_picks), + scalings={"mag": 8e-13, "ref_meg": 2e-11}, +) raw.plot(**plot_kwargs) # %% @@ -68,12 +74,11 @@ # Run the "together" algorithm. raw_tog = raw.copy() ica_kwargs = dict( - method='picard', + method="picard", fit_params=dict(tol=1e-4), # use a high tol here for speed ) all_picks = mne.pick_types(raw_tog.info, meg=True, ref_meg=True) -ica_tog = ICA(n_components=60, max_iter='auto', allow_ref_meg=True, - **ica_kwargs) +ica_tog = ICA(n_components=60, max_iter="auto", allow_ref_meg=True, **ica_kwargs) ica_tog.fit(raw_tog, picks=all_picks) # low threshold (2.0) here because of cropped data, entire recording can use # a higher threshold (2.5) @@ -100,8 +105,7 @@ # Do ICA only on the reference channels. ref_picks = mne.pick_types(raw_sep.info, meg=False, ref_meg=True) -ica_ref = ICA(n_components=2, max_iter='auto', allow_ref_meg=True, - **ica_kwargs) +ica_ref = ICA(n_components=2, max_iter="auto", allow_ref_meg=True, **ica_kwargs) ica_ref.fit(raw_sep, picks=ref_picks) # Do ICA on both reference and standard channels. 
Here, we can just reuse diff --git a/examples/preprocessing/fnirs_artifact_removal.py b/examples/preprocessing/fnirs_artifact_removal.py index b7236b76636..d669d6ce09c 100644 --- a/examples/preprocessing/fnirs_artifact_removal.py +++ b/examples/preprocessing/fnirs_artifact_removal.py @@ -18,8 +18,10 @@ import os import mne -from mne.preprocessing.nirs import (optical_density, - temporal_derivative_distribution_repair) +from mne.preprocessing.nirs import ( + optical_density, + temporal_derivative_distribution_repair, +) # %% # Import data @@ -31,12 +33,13 @@ # and plot these signals. fnirs_data_folder = mne.datasets.fnirs_motor.data_path() -fnirs_cw_amplitude_dir = os.path.join(fnirs_data_folder, 'Participant-1') +fnirs_cw_amplitude_dir = os.path.join(fnirs_data_folder, "Participant-1") raw_intensity = mne.io.read_raw_nirx(fnirs_cw_amplitude_dir, verbose=True) raw_intensity.load_data().resample(3, npad="auto") raw_od = optical_density(raw_intensity) -new_annotations = mne.Annotations([31, 187, 317], [8, 8, 8], - ["Movement", "Movement", "Movement"]) +new_annotations = mne.Annotations( + [31, 187, 317], [8, 8, 8], ["Movement", "Movement", "Movement"] +) raw_od.set_annotations(new_annotations) raw_od.plot(n_channels=15, duration=400, show_scrollbars=False) @@ -61,10 +64,10 @@ corrupted_data = raw_od.get_data() corrupted_data[:, 298:302] = corrupted_data[:, 298:302] - 0.06 corrupted_data[:, 450:750] = corrupted_data[:, 450:750] + 0.03 -corrupted_od = mne.io.RawArray(corrupted_data, raw_od.info, - first_samp=raw_od.first_samp) -new_annotations.append([95, 145, 245], [10, 10, 10], - ["Spike", "Baseline", "Baseline"]) +corrupted_od = mne.io.RawArray( + corrupted_data, raw_od.info, first_samp=raw_od.first_samp +) +new_annotations.append([95, 145, 245], [10, 10, 10], ["Spike", "Baseline", "Baseline"]) corrupted_od.set_annotations(new_annotations) corrupted_od.plot(n_channels=15, duration=400, show_scrollbars=False) diff --git a/examples/preprocessing/ica_comparison.py b/examples/preprocessing/ica_comparison.py index 7c4a8aa733c..6aa601dd5fa 100644 --- a/examples/preprocessing/ica_comparison.py +++ b/examples/preprocessing/ica_comparison.py @@ -31,13 +31,13 @@ # - 1-30 Hz band-pass filter data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" -raw = mne.io.read_raw_fif(raw_fname).crop(0, 60).pick('meg').load_data() +raw = mne.io.read_raw_fif(raw_fname).crop(0, 60).pick("meg").load_data() reject = dict(mag=5e-12, grad=4000e-13) -raw.filter(1, 30, fir_design='firwin') +raw.filter(1, 30, fir_design="firwin") # %% @@ -45,27 +45,32 @@ def run_ica(method, fit_params=None): - ica = ICA(n_components=20, method=method, fit_params=fit_params, - max_iter='auto', random_state=0) + ica = ICA( + n_components=20, + method=method, + fit_params=fit_params, + max_iter="auto", + random_state=0, + ) t0 = time() ica.fit(raw, reject=reject) fit_time = time() - t0 - title = ('ICA decomposition using %s (took %.1fs)' % (method, fit_time)) + title = "ICA decomposition using %s (took %.1fs)" % (method, fit_time) ica.plot_components(title=title) # %% # FastICA -run_ica('fastica') +run_ica("fastica") # %% # Picard -run_ica('picard') +run_ica("picard") # %% # Infomax -run_ica('infomax') +run_ica("infomax") # %% # Extended Infomax -run_ica('infomax', fit_params=dict(extended=True)) +run_ica("infomax", fit_params=dict(extended=True)) diff --git 
a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index 635dffcbfba..7040e24299e 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -28,24 +28,24 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname = meg_path / 'sample_audvis-ave.fif' -evoked = mne.read_evokeds(fname, condition='Left Auditory', - baseline=(None, 0)) +meg_path = data_path / "MEG" / "sample" +fname = meg_path / "sample_audvis-ave.fif" +evoked = mne.read_evokeds(fname, condition="Left Auditory", baseline=(None, 0)) # plot with bads -evoked.plot(exclude=[], picks=('grad', 'eeg')) +evoked.plot(exclude=[], picks=("grad", "eeg")) # %% # Compute interpolation (also works with Raw and Epochs objects) evoked_interp = evoked.copy().interpolate_bads(reset_bads=False) -evoked_interp.plot(exclude=[], picks=('grad', 'eeg')) +evoked_interp.plot(exclude=[], picks=("grad", "eeg")) # %% # You can also use minimum-norm for EEG as well as MEG evoked_interp_mne = evoked.copy().interpolate_bads( - reset_bads=False, method=dict(eeg='MNE'), verbose=True) -evoked_interp_mne.plot(exclude=[], picks=('grad', 'eeg')) + reset_bads=False, method=dict(eeg="MNE"), verbose=True +) +evoked_interp_mne.plot(exclude=[], picks=("grad", "eeg")) # %% # References diff --git a/examples/preprocessing/movement_compensation.py b/examples/preprocessing/movement_compensation.py index 3a31648c4a5..97d183533a8 100644 --- a/examples/preprocessing/movement_compensation.py +++ b/examples/preprocessing/movement_compensation.py @@ -24,11 +24,11 @@ print(__doc__) -data_path = mne.datasets.misc.data_path(verbose=True) / 'movement' +data_path = mne.datasets.misc.data_path(verbose=True) / "movement" -head_pos = mne.chpi.read_head_pos(data_path / 'simulated_quats.pos') -raw = mne.io.read_raw_fif(data_path / 'simulated_movement_raw.fif') -raw_stat = mne.io.read_raw_fif(data_path / 'simulated_stationary_raw.fif') +head_pos = mne.chpi.read_head_pos(data_path / "simulated_quats.pos") +raw = mne.io.read_raw_fif(data_path / "simulated_movement_raw.fif") +raw_stat = mne.io.read_raw_fif(data_path / "simulated_stationary_raw.fif") # %% # Visualize the "subject" head movements. By providing the measurement @@ -37,29 +37,31 @@ # be shown in blue, and the destination (if given) shown in red. mne.viz.plot_head_positions( - head_pos, mode='traces', destination=raw.info['dev_head_t'], info=raw.info) + head_pos, mode="traces", destination=raw.info["dev_head_t"], info=raw.info +) # %% # This can also be visualized using a quiver. mne.viz.plot_head_positions( - head_pos, mode='field', destination=raw.info['dev_head_t'], info=raw.info) + head_pos, mode="field", destination=raw.info["dev_head_t"], info=raw.info +) # %% # Process our simulated raw data (taking into account head movements). # extract our resulting events -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") events[:, 2] = 1 raw.plot(events=events) -topo_kwargs = dict(times=[0, 0.1, 0.2], ch_type='mag', vlim=(-500, 500)) +topo_kwargs = dict(times=[0, 0.1, 0.2], ch_type="mag", vlim=(-500, 500)) # %% # First, take the average of stationary data (bilateral auditory patterns). 
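# [Editor's note] The shared ``topo_kwargs`` above pin the color scale to
# +/-500 fT (``vlim`` is interpreted in fT for ``ch_type="mag"``), so the
# topomaps of the four conditions below are directly comparable.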
evoked_stat = mne.Epochs(raw_stat, events, 1, -0.2, 0.8).average() fig = evoked_stat.plot_topomap(**topo_kwargs) -fig.suptitle('Stationary') +fig.suptitle("Stationary") # %% # Second, take a naive average, which averages across epochs that have been @@ -68,18 +70,18 @@ epochs = mne.Epochs(raw, events, 1, -0.2, 0.8) evoked = epochs.average() fig = evoked.plot_topomap(**topo_kwargs) -fig.suptitle('Moving: naive average') +fig.suptitle("Moving: naive average") # %% # Third, use raw movement compensation (restores pattern). raw_sss = maxwell_filter(raw, head_pos=head_pos) evoked_raw_mc = mne.Epochs(raw_sss, events, 1, -0.2, 0.8).average() fig = evoked_raw_mc.plot_topomap(**topo_kwargs) -fig.suptitle('Moving: movement compensated (raw)') +fig.suptitle("Moving: movement compensated (raw)") # %% # Fourth, use evoked movement compensation. For these data, which contain # very large rotations, it does not as cleanly restore the pattern. evoked_evo_mc = mne.epochs.average_movements(epochs, head_pos=head_pos) fig = evoked_evo_mc.plot_topomap(**topo_kwargs) -fig.suptitle('Moving: movement compensated (evoked)') +fig.suptitle("Moving: movement compensated (evoked)") diff --git a/examples/preprocessing/movement_detection.py b/examples/preprocessing/movement_detection.py index ac90f45f587..2984c53fea2 100644 --- a/examples/preprocessing/movement_detection.py +++ b/examples/preprocessing/movement_detection.py @@ -29,16 +29,17 @@ # Load data data_path = bst_auditory.data_path() -data_path_MEG = data_path / 'MEG' -subject = 'bst_auditory' -subjects_dir = data_path / 'subjects' -trans_fname = data_path / 'MEG' / 'bst_auditory' / 'bst_auditory-trans.fif' -raw_fname1 = data_path_MEG / 'bst_auditory' / 'S01_AEF_20131218_01.ds' -raw_fname2 = data_path_MEG / 'bst_auditory' / 'S01_AEF_20131218_02.ds' +data_path_MEG = data_path / "MEG" +subject = "bst_auditory" +subjects_dir = data_path / "subjects" +trans_fname = data_path / "MEG" / "bst_auditory" / "bst_auditory-trans.fif" +raw_fname1 = data_path_MEG / "bst_auditory" / "S01_AEF_20131218_01.ds" +raw_fname2 = data_path_MEG / "bst_auditory" / "S01_AEF_20131218_02.ds" # read and concatenate two files, ignoring device<->head mismatch raw = read_raw_ctf(raw_fname1, preload=False) mne.io.concatenate_raws( - [raw, read_raw_ctf(raw_fname2, preload=False)], on_mismatch='ignore') + [raw, read_raw_ctf(raw_fname2, preload=False)], on_mismatch="ignore" +) raw.crop(350, 410).load_data() raw.resample(100, npad="auto") @@ -49,15 +50,18 @@ # Get cHPI time series and compute average chpi_locs = mne.chpi.extract_chpi_locs_ctf(raw) head_pos = mne.chpi.compute_head_pos(raw.info, chpi_locs) -original_head_dev_t = mne.transforms.invert_transform( - raw.info['dev_head_t']) +original_head_dev_t = mne.transforms.invert_transform(raw.info["dev_head_t"]) average_head_dev_t = mne.transforms.invert_transform( - compute_average_dev_head_t(raw, head_pos)) + compute_average_dev_head_t(raw, head_pos) +) fig = mne.viz.plot_head_positions(head_pos) -for ax, val, val_ori in zip(fig.axes[::2], average_head_dev_t['trans'][:3, 3], - original_head_dev_t['trans'][:3, 3]): - ax.axhline(1000 * val, color='r') - ax.axhline(1000 * val_ori, color='g') +for ax, val, val_ori in zip( + fig.axes[::2], + average_head_dev_t["trans"][:3, 3], + original_head_dev_t["trans"][:3, 3], +): + ax.axhline(1000 * val, color="r") + ax.axhline(1000 * val_ori, color="g") # The green horizontal lines represent the original head position, whereas the # red lines are the new head position averaged over all the time points. 
@@ -66,9 +70,10 @@ # Plot raw data with annotated movement # ------------------------------------------------------------------ -mean_distance_limit = .0015 # in meters +mean_distance_limit = 0.0015 # in meters annotation_movement, hpi_disp = annotate_movement( - raw, head_pos, mean_distance_limit=mean_distance_limit) + raw, head_pos, mean_distance_limit=mean_distance_limit +) raw.set_annotations(annotation_movement) raw.plot(n_channels=100, duration=20) @@ -76,7 +81,12 @@ # After checking the annotated movement artifacts, calculate the new transform # and plot it: new_dev_head_t = compute_average_dev_head_t(raw, head_pos) -raw.info['dev_head_t'] = new_dev_head_t -fig = mne.viz.plot_alignment(raw.info, show_axes=True, subject=subject, - trans=trans_fname, subjects_dir=subjects_dir) +raw.info["dev_head_t"] = new_dev_head_t +fig = mne.viz.plot_alignment( + raw.info, + show_axes=True, + subject=subject, + trans=trans_fname, + subjects_dir=subjects_dir, +) mne.viz.set_3d_view(fig, azimuth=90, elevation=60) diff --git a/examples/preprocessing/muscle_detection.py b/examples/preprocessing/muscle_detection.py index 223f93743d9..37bd021d853 100644 --- a/examples/preprocessing/muscle_detection.py +++ b/examples/preprocessing/muscle_detection.py @@ -39,7 +39,7 @@ # Load data data_path = bst_auditory.data_path() -raw_fname = data_path / 'MEG' / 'bst_auditory' / 'S01_AEF_20131218_01.ds' +raw_fname = data_path / "MEG" / "bst_auditory" / "S01_AEF_20131218_01.ds" raw = read_raw_ctf(raw_fname, preload=False) @@ -64,8 +64,12 @@ # Choose one channel type, if there are axial gradiometers and magnetometers, # select magnetometers as they are more sensitive to muscle activity. annot_muscle, scores_muscle = annotate_muscle_zscore( - raw, ch_type="mag", threshold=threshold_muscle, min_length_good=0.2, - filter_freq=[110, 140]) + raw, + ch_type="mag", + threshold=threshold_muscle, + min_length_good=0.2, + filter_freq=[110, 140], +) # %% # Plot muscle z-scores across recording @@ -73,8 +77,8 @@ fig, ax = plt.subplots() ax.plot(raw.times, scores_muscle) -ax.axhline(y=threshold_muscle, color='r') -ax.set(xlabel='time, (s)', ylabel='zscore', title='Muscle activity') +ax.axhline(y=threshold_muscle, color="r") +ax.set(xlabel="time, (s)", ylabel="zscore", title="Muscle activity") # %% # View the annotations # -------------------------------------------------------------------------- diff --git a/examples/preprocessing/muscle_ica.py b/examples/preprocessing/muscle_ica.py index 8abc96f5d6a..8f50615c66a 100644 --- a/examples/preprocessing/muscle_ica.py +++ b/examples/preprocessing/muscle_ica.py @@ -22,7 +22,7 @@ import mne data_path = mne.datasets.sample.data_path() -raw_fname = data_path / 'MEG' / 'sample' / 'sample_audvis_raw.fif' +raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif" raw = mne.io.read_raw_fif(raw_fname) raw.crop(tmin=100, tmax=130) # take 30 seconds for speed @@ -33,12 +33,13 @@ # ICA works best with a highpass filter applied raw.load_data() -raw.filter(l_freq=1., h_freq=None) +raw.filter(l_freq=1.0, h_freq=None) # %% # Run ICA ica = mne.preprocessing.ICA( - n_components=15, method='picard', max_iter='auto', random_state=97) + n_components=15, method="picard", max_iter="auto", random_state=97 +) ica.fit(raw) # %% @@ -85,8 +86,10 @@ # and ensure that it gets the same components we did manually. 
muscle_idx_auto, scores = ica.find_bads_muscle(raw) ica.plot_scores(scores, exclude=muscle_idx_auto) -print(f'Manually found muscle artifact ICA components: {muscle_idx}\n' - f'Automatically found muscle artifact ICA components: {muscle_idx_auto}') +print( + f"Manually found muscle artifact ICA components: {muscle_idx}\n" + f"Automatically found muscle artifact ICA components: {muscle_idx_auto}" +) # %% # Let's now replicate this on the EEGBCI dataset @@ -94,24 +97,28 @@ for sub in (1, 2): raw = mne.io.read_raw_edf( - mne.datasets.eegbci.load_data(subject=sub, runs=(1,))[0], preload=True) + mne.datasets.eegbci.load_data(subject=sub, runs=(1,))[0], preload=True + ) mne.datasets.eegbci.standardize(raw) # set channel names - montage = mne.channels.make_standard_montage('standard_1005') + montage = mne.channels.make_standard_montage("standard_1005") raw.set_montage(montage) - raw.filter(l_freq=1., h_freq=None) + raw.filter(l_freq=1.0, h_freq=None) # Run ICA ica = mne.preprocessing.ICA( - n_components=15, method='picard', max_iter='auto', random_state=97) + n_components=15, method="picard", max_iter="auto", random_state=97 + ) ica.fit(raw) ica.plot_sources(raw) muscle_idx_auto, scores = ica.find_bads_muscle(raw) ica.plot_properties(raw, picks=muscle_idx_auto, log_scale=True) ica.plot_scores(scores, exclude=muscle_idx_auto) - print(f'Manually found muscle artifact ICA components: {muscle_idx}\n' - 'Automatically found muscle artifact ICA components: ' - f'{muscle_idx_auto}') + print( + f"Manually found muscle artifact ICA components: {muscle_idx}\n" + "Automatically found muscle artifact ICA components: " + f"{muscle_idx_auto}" + ) # %% # References diff --git a/examples/preprocessing/otp.py b/examples/preprocessing/otp.py index 520d66166ac..a05eaf5c6ce 100644 --- a/examples/preprocessing/otp.py +++ b/examples/preprocessing/otp.py @@ -32,17 +32,17 @@ dipole_number = 1 data_path = bst_phantom_elekta.data_path() -raw = read_raw_fif(data_path / 'kojak_all_200nAm_pp_no_chpi_no_ms_raw.fif') -raw.crop(40., 50.).load_data() +raw = read_raw_fif(data_path / "kojak_all_200nAm_pp_no_chpi_no_ms_raw.fif") +raw.crop(40.0, 50.0).load_data() order = list(range(160, 170)) -raw.copy().filter(0., 40.).plot(order=order, n_channels=10) +raw.copy().filter(0.0, 40.0).plot(order=order, n_channels=10) # %% # Now we can clean the data with OTP, lowpass, and plot. The flux jumps have # been suppressed alongside the random sensor noise. raw_clean = mne.preprocessing.oversampled_temporal_projection(raw) -raw_clean.filter(0., 40.) +raw_clean.filter(0.0, 40.0) raw_clean.plot(order=order, n_channels=10) @@ -52,19 +52,26 @@ # for more information. 
Here we use a version that does single-trial # localization across the 17 trials that are in our 10-second window: + def compute_bias(raw): - events = find_events(raw, 'STI201', verbose=False) + events = find_events(raw, "STI201", verbose=False) events = events[1:] # first one has an artifact tmin, tmax = -0.2, 0.1 - epochs = mne.Epochs(raw, events, dipole_number, tmin, tmax, - baseline=(None, -0.01), preload=True, verbose=False) - sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=None, - verbose=False) - cov = mne.compute_covariance(epochs, tmax=0, method='oas', - rank=None, verbose=False) + epochs = mne.Epochs( + raw, + events, + dipole_number, + tmin, + tmax, + baseline=(None, -0.01), + preload=True, + verbose=False, + ) + sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), head_radius=None, verbose=False) + cov = mne.compute_covariance(epochs, tmax=0, method="oas", rank=None, verbose=False) idx = epochs.time_as_index(0.036)[0] data = epochs.get_data()[:, :, idx].T - evoked = mne.EvokedArray(data, epochs.info, tmin=0.) + evoked = mne.EvokedArray(data, epochs.info, tmin=0.0) dip = fit_dipole(evoked, cov, sphere, n_jobs=None, verbose=False)[0] actual_pos = mne.dipole.get_phantom_dipoles()[0][dipole_number - 1] misses = 1000 * np.linalg.norm(dip.pos - actual_pos, axis=-1) @@ -72,11 +79,15 @@ def compute_bias(raw): bias = compute_bias(raw) -print('Raw bias: %0.1fmm (worst: %0.1fmm)' - % (np.mean(bias), np.max(bias))) +print("Raw bias: %0.1fmm (worst: %0.1fmm)" % (np.mean(bias), np.max(bias))) bias_clean = compute_bias(raw_clean) -print('OTP bias: %0.1fmm (worst: %0.1fmm)' - % (np.mean(bias_clean), np.max(bias_clean),)) +print( + "OTP bias: %0.1fmm (worst: %0.1fmm)" + % ( + np.mean(bias_clean), + np.max(bias_clean), + ) +) # %% # References diff --git a/examples/preprocessing/shift_evoked.py b/examples/preprocessing/shift_evoked.py index 3bbe0386416..7b05a1b4714 100644 --- a/examples/preprocessing/shift_evoked.py +++ b/examples/preprocessing/shift_evoked.py @@ -20,32 +20,46 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname = meg_path / 'sample_audvis-ave.fif' +meg_path = data_path / "MEG" / "sample" +fname = meg_path / "sample_audvis-ave.fif" # Reading evoked data -condition = 'Left Auditory' -evoked = mne.read_evokeds(fname, condition=condition, baseline=(None, 0), - proj=True) +condition = "Left Auditory" +evoked = mne.read_evokeds(fname, condition=condition, baseline=(None, 0), proj=True) -ch_names = evoked.info['ch_names'] +ch_names = evoked.info["ch_names"] picks = mne.pick_channels(ch_names=ch_names, include=["MEG 2332"]) # Create subplots f, (ax1, ax2, ax3) = plt.subplots(3) -evoked.plot(exclude=[], picks=picks, axes=ax1, - titles=dict(grad='Before time shifting'), time_unit='s') +evoked.plot( + exclude=[], + picks=picks, + axes=ax1, + titles=dict(grad="Before time shifting"), + time_unit="s", +) # Apply relative time-shift of 500 ms evoked.shift_time(0.5, relative=True) -evoked.plot(exclude=[], picks=picks, axes=ax2, - titles=dict(grad='Relative shift: 500 ms'), time_unit='s') +evoked.plot( + exclude=[], + picks=picks, + axes=ax2, + titles=dict(grad="Relative shift: 500 ms"), + time_unit="s", +) # Apply absolute time-shift of 500 ms evoked.shift_time(0.5, relative=False) -evoked.plot(exclude=[], picks=picks, axes=ax3, - titles=dict(grad='Absolute shift: 500 ms'), time_unit='s') +evoked.plot( + exclude=[], + picks=picks, + axes=ax3, + titles=dict(grad="Absolute shift: 500 ms"), + time_unit="s", +) tight_layout() diff --git 
a/examples/preprocessing/virtual_evoked.py b/examples/preprocessing/virtual_evoked.py index b947226b40b..096165910da 100644 --- a/examples/preprocessing/virtual_evoked.py +++ b/examples/preprocessing/virtual_evoked.py @@ -27,35 +27,35 @@ # read the evoked data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname = meg_path / 'sample_audvis-ave.fif' -evoked = mne.read_evokeds(fname, condition='Left Auditory', baseline=(None, 0)) +meg_path = data_path / "MEG" / "sample" +fname = meg_path / "sample_audvis-ave.fif" +evoked = mne.read_evokeds(fname, condition="Left Auditory", baseline=(None, 0)) # %% # First, let's remap gradiometers to magnetometers, and plot # the original and remapped topomaps of the magnetometers. # go from grad + mag to mag and plot original mag -virt_evoked = evoked.as_type('mag') -fig = evoked.plot_topomap(ch_type='mag') -fig.suptitle('mag (original)') +virt_evoked = evoked.as_type("mag") +fig = evoked.plot_topomap(ch_type="mag") +fig.suptitle("mag (original)") # %% # plot interpolated grad + mag -fig = virt_evoked.plot_topomap(ch_type='mag') -fig.suptitle('mag (interpolated from mag + grad)') +fig = virt_evoked.plot_topomap(ch_type="mag") +fig.suptitle("mag (interpolated from mag + grad)") # %% # Now, we remap magnetometers to gradiometers, and plot # the original and remapped topomaps of the gradiometers. # go from grad + mag to grad and plot original grad -virt_evoked = evoked.as_type('grad') -fig = evoked.plot_topomap(ch_type='grad') -fig.suptitle('grad (original)') +virt_evoked = evoked.as_type("grad") +fig = evoked.plot_topomap(ch_type="grad") +fig.suptitle("grad (original)") # %% # plot interpolated grad + mag -fig = virt_evoked.plot_topomap(ch_type='grad') -fig.suptitle('grad (interpolated from mag + grad)') +fig = virt_evoked.plot_topomap(ch_type="grad") +fig.suptitle("grad (interpolated from mag + grad)") diff --git a/examples/preprocessing/xdawn_denoising.py b/examples/preprocessing/xdawn_denoising.py index aa7c0f48e08..b6eed43d142 100644 --- a/examples/preprocessing/xdawn_denoising.py +++ b/examples/preprocessing/xdawn_denoising.py @@ -25,7 +25,7 @@ # %% -from mne import (io, compute_raw_covariance, read_events, pick_types, Epochs) +from mne import io, compute_raw_covariance, read_events, pick_types, Epochs from mne.datasets import sample from mne.preprocessing import Xdawn from mne.viz import plot_epochs_image @@ -36,27 +36,35 @@ # %% # Set parameters and read data -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin, tmax = -0.1, 0.3 event_id = dict(vis_r=4) # Setup for reading the raw data raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(1, 20, fir_design='firwin') # replace baselining with high-pass +raw.filter(1, 20, fir_design="firwin") # replace baselining with high-pass events = read_events(event_fname) -raw.info['bads'] = ['MEG 2443'] # set bad channels -picks = pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False, - exclude='bads') +raw.info["bads"] = ["MEG 2443"] # set bad channels +picks = pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False, exclude="bads") # Epoching -epochs = Epochs(raw, events, event_id, tmin, tmax, proj=False, - picks=picks, baseline=None, preload=True, - verbose=False) +epochs = Epochs( + raw, 
events, + event_id, + tmin, + tmax, + proj=False, + picks=picks, + baseline=None, + preload=True, + verbose=False, +) # Plot image epoch before xdawn -plot_epochs_image(epochs['vis_r'], picks=[230], vmin=-500, vmax=500) +plot_epochs_image(epochs["vis_r"], picks=[230], vmin=-500, vmax=500) # %% # Now, we estimate a set of xDAWN filters for the epochs (which contain only @@ -78,7 +86,7 @@ epochs_denoised = xd.apply(epochs) # Plot image epoch after Xdawn -plot_epochs_image(epochs_denoised['vis_r'], picks=[230], vmin=-500, vmax=500) +plot_epochs_image(epochs_denoised["vis_r"], picks=[230], vmin=-500, vmax=500) # %% # References diff --git a/examples/simulation/plot_stc_metrics.py b/examples/simulation/plot_stc_metrics.py index 2e53c6bcd02..20912c12cc1 100644 --- a/examples/simulation/plot_stc_metrics.py +++ b/examples/simulation/plot_stc_metrics.py @@ -20,31 +20,36 @@ import mne from mne.datasets import sample from mne.minimum_norm import make_inverse_operator, apply_inverse -from mne.simulation.metrics import (region_localization_error, - f1_score, precision_score, - recall_score, cosine_score, - peak_position_error, - spatial_deviation_error) +from mne.simulation.metrics import ( + region_localization_error, + f1_score, + precision_score, + recall_score, + cosine_score, + peak_position_error, + spatial_deviation_error, +) random_state = 42 # set random state to make this example deterministic # Import sample data data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -subject = 'sample' -evoked_fname = data_path / 'MEG' / subject / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects" +subject = "sample" +evoked_fname = data_path / "MEG" / subject / "sample_audvis-ave.fif" info = mne.io.read_info(evoked_fname) -tstep = 1. / info['sfreq'] +tstep = 1.0 / info["sfreq"] # Import forward operator and source space -fwd_fname = data_path / 'MEG' / subject / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +fwd_fname = data_path / "MEG" / subject / "sample_audvis-meg-eeg-oct-6-fwd.fif" fwd = mne.read_forward_solution(fwd_fname) -src = fwd['src'] +src = fwd["src"] # To select source, we use the caudal middle frontal to grow # a region of interest. selected_label = mne.read_labels_from_annot( - subject, regexp='caudalmiddlefrontal-lh', subjects_dir=subjects_dir)[0] + subject, regexp="caudalmiddlefrontal-lh", subjects_dir=subjects_dir +)[0] ############################################################################### @@ -61,22 +66,32 @@ # WHERE? # Region -location = 'center' # Use the center of the label as a seed. -extent = 20. # Extent in mm of the region. +location = "center" # Use the center of the label as a seed. +extent = 20.0 # Extent in mm of the region. label_region = mne.label.select_sources( - subject, selected_label, location=location, extent=extent, - subjects_dir=subjects_dir, random_state=random_state) + subject, + selected_label, + location=location, + extent=extent, + subjects_dir=subjects_dir, + random_state=random_state, +) # Dipole location = 1915 # Use the index of the vertex as a seed -extent = 0. # One dipole source +extent = 0.0 # One dipole source label_dipole = mne.label.select_sources( - subject, selected_label, location=location, extent=extent, - subjects_dir=subjects_dir, random_state=random_state) + subject, + selected_label, + location=location, + extent=extent, + subjects_dir=subjects_dir, + random_state=random_state, +) # WHAT? # Define the time course of the activity -source_time_series = np.sin(2. * np.pi * 18. 
* np.arange(100) * tstep) * 10e-9 +source_time_series = np.sin(2.0 * np.pi * 18.0 * np.arange(100) * tstep) * 10e-9 # WHEN? # Define when the activity occurs using events. @@ -107,20 +122,20 @@ # noise obtained from the noise covariance from the sample data. # Region -raw_region = mne.simulation.simulate_raw(info, source_simulator_region, - forward=fwd) +raw_region = mne.simulation.simulate_raw(info, source_simulator_region, forward=fwd) raw_region = raw_region.pick_types(meg=False, eeg=True, stim=True) cov = mne.make_ad_hoc_cov(raw_region.info) -mne.simulation.add_noise(raw_region, cov, iir_filter=[0.2, -0.2, 0.04], - random_state=random_state) +mne.simulation.add_noise( + raw_region, cov, iir_filter=[0.2, -0.2, 0.04], random_state=random_state +) # Dipole -raw_dipole = mne.simulation.simulate_raw(info, source_simulator_dipole, - forward=fwd) +raw_dipole = mne.simulation.simulate_raw(info, source_simulator_dipole, forward=fwd) raw_dipole = raw_dipole.pick_types(meg=False, eeg=True, stim=True) cov = mne.make_ad_hoc_cov(raw_dipole.info) -mne.simulation.add_noise(raw_dipole, cov, iir_filter=[0.2, -0.2, 0.04], - random_state=random_state) +mne.simulation.add_noise( + raw_dipole, cov, iir_filter=[0.2, -0.2, 0.04], random_state=random_state +) ############################################################################### # Compute evoked from raw data @@ -149,14 +164,14 @@ # same number of time samples. # Region -stc_true_region = \ - source_simulator_region.get_stc(start_sample=0, - stop_sample=len(source_time_series)) +stc_true_region = source_simulator_region.get_stc( + start_sample=0, stop_sample=len(source_time_series) +) # Dipole -stc_true_dipole = \ - source_simulator_dipole.get_stc(start_sample=0, - stop_sample=len(source_time_series)) +stc_true_dipole = source_simulator_dipole.get_stc( + start_sample=0, stop_sample=len(source_time_series) +) ############################################################################### # Reconstruct simulated sources @@ -166,27 +181,29 @@ # Region snr = 30.0 -inv_method = 'sLORETA' -lambda2 = 1.0 / snr ** 2 +inv_method = "sLORETA" +lambda2 = 1.0 / snr**2 -inverse_operator = make_inverse_operator(evoked_region.info, fwd, cov, - loose='auto', depth=0.8, - fixed=True) +inverse_operator = make_inverse_operator( + evoked_region.info, fwd, cov, loose="auto", depth=0.8, fixed=True +) -stc_est_region = apply_inverse(evoked_region, inverse_operator, lambda2, - inv_method, pick_ori=None) +stc_est_region = apply_inverse( + evoked_region, inverse_operator, lambda2, inv_method, pick_ori=None +) # Dipole snr = 3.0 -inv_method = 'sLORETA' -lambda2 = 1.0 / snr ** 2 +inv_method = "sLORETA" +lambda2 = 1.0 / snr**2 -inverse_operator = make_inverse_operator(evoked_dipole.info, fwd, cov, - loose='auto', depth=0.8, - fixed=True) +inverse_operator = make_inverse_operator( + evoked_dipole.info, fwd, cov, loose="auto", depth=0.8, fixed=True +) -stc_est_dipole = apply_inverse(evoked_dipole, inverse_operator, lambda2, - inv_method, pick_ori=None) +stc_est_dipole = apply_inverse( + evoked_dipole, inverse_operator, lambda2, inv_method, pick_ori=None +) ############################################################################### # Compute performance scores for different source amplitude thresholds @@ -201,32 +218,34 @@ # # create a set of scorers -scorers = {'RLE': partial(region_localization_error, src=src), - 'Precision': precision_score, 'Recall': recall_score, - 'F1 score': f1_score} +scorers = { + "RLE": partial(region_localization_error, src=src), + "Precision": 
precision_score, + "Recall": recall_score, + "F1 score": f1_score, +} # compute results region_results = {} for name, scorer in scorers.items(): - region_results[name] = [scorer(stc_true_region, stc_est_region, - threshold=f'{thx}%', per_sample=False) - for thx in thresholds] + region_results[name] = [ + scorer(stc_true_region, stc_est_region, threshold=f"{thx}%", per_sample=False) + for thx in thresholds + ] # Plot the results -f, ((ax1, ax2), (ax3, ax4)) = plt.subplots( - 2, 2, sharex='col', constrained_layout=True) +f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex="col", constrained_layout=True) for ax, (title, results) in zip([ax1, ax2, ax3, ax4], region_results.items()): - ax.plot(thresholds, results, '.-') - ax.set(title=title, ylabel='score', xlabel='Threshold', - xticks=thresholds) + ax.plot(thresholds, results, ".-") + ax.set(title=title, ylabel="score", xlabel="Threshold", xticks=thresholds) -f.suptitle('Performance scores per threshold') # Add Super title -ax1.ticklabel_format(axis='y', style='sci', scilimits=(0, 1)) # tweak RLE +f.suptitle("Performance scores per threshold") # Add Super title +ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 1)) # tweak RLE # Cosine score with respect to time f, ax1 = plt.subplots(constrained_layout=True) ax1.plot(stc_true_region.times, cosine_score(stc_true_region, stc_est_region)) -ax1.set(title='Cosine score', xlabel='Time', ylabel='Score') +ax1.set(title="Cosine score", xlabel="Time", ylabel="Score") ############################################################################### @@ -236,22 +255,28 @@ # create a set of scorers scorers = { - 'Peak Position Error': peak_position_error, - 'Spatial Deviation Error': spatial_deviation_error, + "Peak Position Error": peak_position_error, + "Spatial Deviation Error": spatial_deviation_error, } # compute results dipole_results = {} for name, scorer in scorers.items(): - dipole_results[name] = [scorer(stc_true_dipole, stc_est_dipole, src=src, - threshold=f'{thx}%', per_sample=False) - for thx in thresholds] + dipole_results[name] = [ + scorer( + stc_true_dipole, + stc_est_dipole, + src=src, + threshold=f"{thx}%", + per_sample=False, + ) + for thx in thresholds + ] # Plot the results for name, results in dipole_results.items(): f, ax1 = plt.subplots(constrained_layout=True) - ax1.plot(thresholds, 100 * np.array(results), '.-') - ax1.set(title=name, ylabel='Error (cm)', xlabel='Threshold', - xticks=thresholds) + ax1.plot(thresholds, 100 * np.array(results), ".-") + ax1.set(title=name, ylabel="Error (cm)", xlabel="Threshold", xticks=thresholds) diff --git a/examples/simulation/simulate_evoked_data.py b/examples/simulation/simulate_evoked_data.py index b906d2df265..0a8d69a66ed 100644 --- a/examples/simulation/simulate_evoked_data.py +++ b/examples/simulation/simulate_evoked_data.py @@ -28,55 +28,65 @@ # %% # Load real data as templates data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw = mne.io.read_raw_fif(meg_path / 'sample_audvis_raw.fif') -proj = mne.read_proj(meg_path / 'sample_audvis_ecg-proj.fif') +meg_path = data_path / "MEG" / "sample" +raw = mne.io.read_raw_fif(meg_path / "sample_audvis_raw.fif") +proj = mne.read_proj(meg_path / "sample_audvis_ecg-proj.fif") raw.add_proj(proj) -raw.info['bads'] = ['MEG 2443', 'EEG 053'] # mark bad channels +raw.info["bads"] = ["MEG 2443", "EEG 053"] # mark bad channels -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ave_fname = meg_path / 'sample_audvis-no-filter-ave.fif' -cov_fname = meg_path / 
'sample_audvis-cov.fif' +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +ave_fname = meg_path / "sample_audvis-no-filter-ave.fif" +cov_fname = meg_path / "sample_audvis-cov.fif" fwd = mne.read_forward_solution(fwd_fname) -fwd = mne.pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info['bads']) +fwd = mne.pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info["bads"]) cov = mne.read_cov(cov_fname) info = mne.io.read_info(ave_fname) -label_names = ['Aud-lh', 'Aud-rh'] -labels = [mne.read_label(meg_path / 'labels' / f'{ln}.label') - for ln in label_names] +label_names = ["Aud-lh", "Aud-rh"] +labels = [mne.read_label(meg_path / "labels" / f"{ln}.label") for ln in label_names] # %% # Generate source time courses from 2 dipoles and the corresponding evoked data -times = np.arange(300, dtype=np.float64) / raw.info['sfreq'] - 0.1 +times = np.arange(300, dtype=np.float64) / raw.info["sfreq"] - 0.1 rng = np.random.RandomState(42) def data_fun(times): """Generate random source time courses.""" - return (50e-9 * np.sin(30. * times) * - np.exp(- (times - 0.15 + 0.05 * rng.randn(1)) ** 2 / 0.01)) - - -stc = simulate_sparse_stc(fwd['src'], n_dipoles=2, times=times, - random_state=42, labels=labels, data_fun=data_fun) + return ( + 50e-9 + * np.sin(30.0 * times) + * np.exp(-((times - 0.15 + 0.05 * rng.randn(1)) ** 2) / 0.01) + ) + + +stc = simulate_sparse_stc( + fwd["src"], + n_dipoles=2, + times=times, + random_state=42, + labels=labels, + data_fun=data_fun, +) # %% # Generate noisy evoked data -picks = mne.pick_types(raw.info, meg=True, exclude='bads') +picks = mne.pick_types(raw.info, meg=True, exclude="bads") iir_filter = fit_iir_model_raw(raw, order=5, picks=picks, tmin=60, tmax=180)[1] nave = 100 # simulate average of 100 epochs -evoked = simulate_evoked(fwd, stc, info, cov, nave=nave, use_cps=True, - iir_filter=iir_filter) +evoked = simulate_evoked( + fwd, stc, info, cov, nave=nave, use_cps=True, iir_filter=iir_filter +) # %% # Plot -plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1), - opacity=0.5, high_resolution=True) +plot_sparse_source_estimates( + fwd["src"], stc, bgcolor=(1, 1, 1), opacity=0.5, high_resolution=True +) plt.figure() plt.psd(evoked.data[0]) -evoked.plot(time_unit='s') +evoked.plot(time_unit="s") diff --git a/examples/simulation/simulate_raw_data.py b/examples/simulation/simulate_raw_data.py index 6c308792c97..902429717c2 100644 --- a/examples/simulation/simulate_raw_data.py +++ b/examples/simulation/simulate_raw_data.py @@ -22,15 +22,20 @@ import mne from mne import find_events, Epochs, compute_covariance, make_ad_hoc_cov from mne.datasets import sample -from mne.simulation import (simulate_sparse_stc, simulate_raw, - add_noise, add_ecg, add_eog) +from mne.simulation import ( + simulate_sparse_stc, + simulate_raw, + add_noise, + add_ecg, + add_eog, +) print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" # Load real data as the template raw = mne.io.read_raw_fif(raw_fname) @@ -39,7 +44,7 @@ ############################################################################## # Generate dipole time series n_dipoles = 4 # number of dipoles to create -epoch_duration = 2. 
# duration of each epoch/event +epoch_duration = 2.0 # duration of each epoch/event n = 0 # harmonic number rng = np.random.RandomState(0) # random state (make reproducible) @@ -49,24 +54,26 @@ def data_fun(times): global n n_samp = len(times) window = np.zeros(n_samp) - start, stop = [int(ii * float(n_samp) / (2 * n_dipoles)) - for ii in (2 * n, 2 * n + 1)] - window[start:stop] = 1. + start, stop = [ + int(ii * float(n_samp) / (2 * n_dipoles)) for ii in (2 * n, 2 * n + 1) + ] + window[start:stop] = 1.0 n += 1 - data = 25e-9 * np.sin(2. * np.pi * 10. * n * times) + data = 25e-9 * np.sin(2.0 * np.pi * 10.0 * n * times) data *= window return data -times = raw.times[:int(raw.info['sfreq'] * epoch_duration)] +times = raw.times[: int(raw.info["sfreq"] * epoch_duration)] fwd = mne.read_forward_solution(fwd_fname) -src = fwd['src'] -stc = simulate_sparse_stc(src, n_dipoles=n_dipoles, times=times, - data_fun=data_fun, random_state=rng) +src = fwd["src"] +stc = simulate_sparse_stc( + src, n_dipoles=n_dipoles, times=times, data_fun=data_fun, random_state=rng +) # look at our source data fig, ax = plt.subplots(1) ax.plot(times, 1e9 * stc.data.T) -ax.set(ylabel='Amplitude (nAm)', xlabel='Time (s)') +ax.set(ylabel="Amplitude (nAm)", xlabel="Time (s)") mne.viz.utils.plt_show() ############################################################################## @@ -82,7 +89,8 @@ def data_fun(times): # Plot evoked data events = find_events(raw_sim) # only 1 pos, so event number == 1 epochs = Epochs(raw_sim, events, 1, tmin=-0.2, tmax=epoch_duration) -cov = compute_covariance(epochs, tmax=0., method='empirical', - verbose='error') # quick calc +cov = compute_covariance( + epochs, tmax=0.0, method="empirical", verbose="error" +) # quick calc evoked = epochs.average() -evoked.plot_white(cov, time_unit='s') +evoked.plot_white(cov, time_unit="s") diff --git a/examples/simulation/simulated_raw_data_using_subject_anatomy.py b/examples/simulation/simulated_raw_data_using_subject_anatomy.py index 0edb33e7d0f..af13d124383 100644 --- a/examples/simulation/simulated_raw_data_using_subject_anatomy.py +++ b/examples/simulation/simulated_raw_data_using_subject_anatomy.py @@ -34,24 +34,24 @@ # to be given to functions. data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -subject = 'sample' -meg_path = data_path / 'MEG' / subject +subjects_dir = data_path / "subjects" +subject = "sample" +meg_path = data_path / "MEG" / subject # %% # First, we get an info structure from the sample subject. -fname_info = meg_path / 'sample_audvis_raw.fif' +fname_info = meg_path / "sample_audvis_raw.fif" info = mne.io.read_info(fname_info) -tstep = 1 / info['sfreq'] +tstep = 1 / info["sfreq"] # %% # To simulate sources, we also need a source space. It can be obtained from the # forward solution of the sample subject. -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" fwd = mne.read_forward_solution(fwd_fname) -src = fwd['src'] +src = fwd["src"] # %% # To simulate raw data, we need to define when the activity occurs using events @@ -60,16 +60,22 @@ # Here, both are loaded from the sample dataset, but they can also be specified # by the user. -fname_event = meg_path / 'sample_audvis_raw-eve.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' +fname_event = meg_path / "sample_audvis_raw-eve.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" events = mne.read_events(fname_event) noise_cov = mne.read_cov(fname_cov) # Standard sample event IDs. 
These values will correspond to the third column # in the events matrix. -event_id = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3, - 'visual/right': 4, 'smiley': 5, 'button': 32} +event_id = { + "auditory/left": 1, + "auditory/right": 2, + "visual/left": 3, + "visual/right": 4, + "smiley": 5, + "button": 32, +} # Take only a few events for speed @@ -92,26 +98,28 @@ # times more) than the ipsilateral. activations = { - 'auditory/left': - [('G_temp_sup-G_T_transv-lh', 30), # label, activation (nAm) - ('G_temp_sup-G_T_transv-rh', 60)], - 'auditory/right': - [('G_temp_sup-G_T_transv-lh', 60), - ('G_temp_sup-G_T_transv-rh', 30)], - 'visual/left': - [('S_calcarine-lh', 30), - ('S_calcarine-rh', 60)], - 'visual/right': - [('S_calcarine-lh', 60), - ('S_calcarine-rh', 30)], + "auditory/left": [ + ("G_temp_sup-G_T_transv-lh", 30), # label, activation (nAm) + ("G_temp_sup-G_T_transv-rh", 60), + ], + "auditory/right": [ + ("G_temp_sup-G_T_transv-lh", 60), + ("G_temp_sup-G_T_transv-rh", 30), + ], + "visual/left": [("S_calcarine-lh", 30), ("S_calcarine-rh", 60)], + "visual/right": [("S_calcarine-lh", 60), ("S_calcarine-rh", 30)], } -annot = 'aparc.a2009s' +annot = "aparc.a2009s" # Load the 4 necessary label names. -label_names = sorted(set(activation[0] - for activation_list in activations.values() - for activation in activation_list)) +label_names = sorted( + set( + activation[0] + for activation_list in activations.values() + for activation in activation_list + ) +) region_names = list(activations.keys()) # %% @@ -128,8 +136,9 @@ def data_fun(times, latency, duration): f = 15 # oscillating frequency, beta band [Hz] sigma = 0.375 * duration sinusoid = np.sin(2 * np.pi * f * (times - latency)) - gf = np.exp(- (times - latency - (sigma / 4.) * rng.rand(1)) ** 2 / - (2 * (sigma ** 2))) + gf = np.exp( + -((times - latency - (sigma / 4.0) * rng.rand(1)) ** 2) / (2 * (sigma**2)) + ) return 1e-9 * sinusoid * gf @@ -152,7 +161,7 @@ def data_fun(times, latency, duration): # event, the second is not used. The third one is the event id, which is # different for each of the 4 areas. -times = np.arange(150, dtype=np.float64) / info['sfreq'] +times = np.arange(150, dtype=np.float64) / info["sfreq"] duration = 0.03 rng = np.random.RandomState(7) source_simulator = mne.simulation.SourceSimulator(src, tstep=tstep) @@ -161,20 +170,17 @@ def data_fun(times, latency, duration): events_tmp = events[np.where(events[:, 2] == region_id)[0], :] for i in range(2): label_name = activations[region_name][i][0] - label_tmp = mne.read_labels_from_annot(subject, annot, - subjects_dir=subjects_dir, - regexp=label_name, - verbose=False) + label_tmp = mne.read_labels_from_annot( + subject, annot, subjects_dir=subjects_dir, regexp=label_name, verbose=False + ) label_tmp = label_tmp[0] amplitude_tmp = activations[region_name][i][1] - if region_name.split('/')[1][0] == label_tmp.hemi[0]: + if region_name.split("/")[1][0] == label_tmp.hemi[0]: latency_tmp = 0.115 else: latency_tmp = 0.1 wf_tmp = data_fun(times, latency_tmp, duration) - source_simulator.add_data(label_tmp, - amplitude_tmp * wf_tmp, - events_tmp) + source_simulator.add_data(label_tmp, amplitude_tmp * wf_tmp, events_tmp) # To obtain a SourceEstimate object, we need to use `get_stc()` method of # SourceSimulator class. @@ -203,17 +209,16 @@ def data_fun(times, latency, duration): mne.simulation.add_ecg(raw_sim, random_state=0) # Plot original and simulated raw data. 
-raw_sim.plot(title='Simulated raw data') +raw_sim.plot(title="Simulated raw data") # %% # Extract epochs and compute evoked responses # -------------------------------------------- # -epochs = mne.Epochs(raw_sim, events, event_id, tmin=-0.2, tmax=0.3, - baseline=(None, 0)) -evoked_aud_left = epochs['auditory/left'].average() -evoked_vis_right = epochs['visual/right'].average() +epochs = mne.Epochs(raw_sim, events, event_id, tmin=-0.2, tmax=0.3, baseline=(None, 0)) +evoked_aud_left = epochs["auditory/left"].average() +evoked_vis_right = epochs["visual/right"].average() # Visualize the evoked data evoked_aud_left.plot(spatial_colors=True) @@ -229,16 +234,15 @@ def data_fun(times, latency, duration): # As expected, when high activations appear in primary auditory areas, primary # visual areas will have low activations and vice versa. -method, lambda2 = 'dSPM', 1. / 9. +method, lambda2 = "dSPM", 1.0 / 9.0 inv = mne.minimum_norm.make_inverse_operator(epochs.info, fwd, noise_cov) -stc_aud = mne.minimum_norm.apply_inverse( - evoked_aud_left, inv, lambda2, method) -stc_vis = mne.minimum_norm.apply_inverse( - evoked_vis_right, inv, lambda2, method) +stc_aud = mne.minimum_norm.apply_inverse(evoked_aud_left, inv, lambda2, method) +stc_vis = mne.minimum_norm.apply_inverse(evoked_vis_right, inv, lambda2, method) stc_diff = stc_aud - stc_vis -brain = stc_diff.plot(subjects_dir=subjects_dir, initial_time=0.1, - hemi='split', views=['lat', 'med']) +brain = stc_diff.plot( + subjects_dir=subjects_dir, initial_time=0.1, hemi="split", views=["lat", "med"] +) # %% # References diff --git a/examples/simulation/source_simulator.py b/examples/simulation/source_simulator.py index 93a348e46ca..69cb803c134 100644 --- a/examples/simulation/source_simulator.py +++ b/examples/simulation/source_simulator.py @@ -29,38 +29,39 @@ class to generate source estimates and raw data. It is meant to be a brief # This will download the data if it is not already on your machine. We also set # the subjects directory so we don't need to give it to functions. data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -subject = 'sample' +subjects_dir = data_path / "subjects" +subject = "sample" # %% # First, we get an info structure from the test subject. -evoked_fname = data_path / 'MEG' / subject / 'sample_audvis-ave.fif' +evoked_fname = data_path / "MEG" / subject / "sample_audvis-ave.fif" info = mne.io.read_info(evoked_fname) -tstep = 1. / info['sfreq'] +tstep = 1.0 / info["sfreq"] # %% # To simulate sources, we also need a source space. It can be obtained from the # forward solution of the sample subject. -fwd_fname = data_path / 'MEG' / subject / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +fwd_fname = data_path / "MEG" / subject / "sample_audvis-meg-eeg-oct-6-fwd.fif" fwd = mne.read_forward_solution(fwd_fname) -src = fwd['src'] +src = fwd["src"] # %% # To select a region to activate, we use the caudal middle frontal to grow # a region of interest. selected_label = mne.read_labels_from_annot( - subject, regexp='caudalmiddlefrontal-lh', subjects_dir=subjects_dir)[0] -location = 'center'  # Use the center of the region as a seed. -extent = 10.  # Extent in mm of the region. 
label = mne.label.select_sources( - subject, selected_label, location=location, extent=extent, - subjects_dir=subjects_dir) + subject, selected_label, location=location, extent=extent, subjects_dir=subjects_dir +) # %% # Define the time course of the activity for each source of the region to # activate. Here we use a sine wave at 18 Hz with a peak amplitude # of 10 nAm. -source_time_series = np.sin(2. * np.pi * 18. * np.arange(100) * tstep) * 10e-9 +source_time_series = np.sin(2.0 * np.pi * 18.0 * np.arange(100) * tstep) * 10e-9 # %% # Define when the activity occurs using events. The first column is the sample diff --git a/examples/stats/cluster_stats_evoked.py b/examples/stats/cluster_stats_evoked.py index cf2f9d59c18..1e21cdb7617 100644 --- a/examples/stats/cluster_stats_evoked.py +++ b/examples/stats/cluster_stats_evoked.py @@ -28,9 +28,9 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin = -0.2 tmax = 0.5 @@ -38,22 +38,23 @@ raw = io.read_raw_fif(raw_fname) events = mne.read_events(event_fname) -channel = 'MEG 1332' # include only this channel in analysis +channel = "MEG 1332" # include only this channel in analysis include = [channel] # %% # Read epochs for the channel of interest -picks = mne.pick_types(raw.info, meg=False, eog=True, include=include, - exclude='bads') +picks = mne.pick_types(raw.info, meg=False, eog=True, include=include, exclude="bads") event_id = 1 reject = dict(grad=4000e-13, eog=150e-6) -epochs1 = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject) +epochs1 = mne.Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), reject=reject +) condition1 = epochs1.get_data() # as 3D matrix event_id = 2 -epochs2 = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject) +epochs2 = mne.Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), reject=reject +) condition2 = epochs2.get_data() # as 3D matrix condition1 = condition1[:, 0, :] # take only one channel to get a 2D array @@ -62,31 +63,36 @@ # %% # Compute statistic threshold = 6.0 -T_obs, clusters, cluster_p_values, H0 = \ - permutation_cluster_test([condition1, condition2], n_permutations=1000, - threshold=threshold, tail=1, n_jobs=None, - out_type='mask') +T_obs, clusters, cluster_p_values, H0 = permutation_cluster_test( + [condition1, condition2], + n_permutations=1000, + threshold=threshold, + tail=1, + n_jobs=None, + out_type="mask", +) # %% # Plot times = epochs1.times fig, (ax, ax2) = plt.subplots(2, 1, figsize=(8, 4)) -ax.set_title('Channel : ' + channel) -ax.plot(times, condition1.mean(axis=0) - condition2.mean(axis=0), - label="ERF Contrast (Event 1 - Event 2)") +ax.set_title("Channel : " + channel) +ax.plot( + times, + condition1.mean(axis=0) - condition2.mean(axis=0), + label="ERF Contrast (Event 1 - Event 2)", +) ax.set_ylabel("MEG (T / m)") ax.legend() for i_c, c in enumerate(clusters): c = c[0] if cluster_p_values[i_c] <= 0.05: - h = ax2.axvspan(times[c.start], times[c.stop - 1], - color='r', alpha=0.3) + h = ax2.axvspan(times[c.start], times[c.stop - 1], color="r", alpha=0.3) else: - ax2.axvspan(times[c.start], times[c.stop - 1], color=(0.3, 0.3, 
0.3), - alpha=0.3) + ax2.axvspan(times[c.start], times[c.stop - 1], color=(0.3, 0.3, 0.3), alpha=0.3) -hf = plt.plot(times, T_obs, 'g') -ax2.legend((h, ), ('cluster p-value < 0.05', )) +hf = plt.plot(times, T_obs, "g") +ax2.legend((h,), ("cluster p-value < 0.05",)) ax2.set_xlabel("time (ms)") ax2.set_ylabel("f-values") diff --git a/examples/stats/fdr_stats_evoked.py b/examples/stats/fdr_stats_evoked.py index b90ab6f9ccd..94239f887df 100644 --- a/examples/stats/fdr_stats_evoked.py +++ b/examples/stats/fdr_stats_evoked.py @@ -30,26 +30,26 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" event_id, tmin, tmax = 1, -0.2, 0.5 # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) events = mne.read_events(event_fname)[:30] -channel = 'MEG 1332' # include only this channel in analysis +channel = "MEG 1332" # include only this channel in analysis include = [channel] # %% # Read epochs for the channel of interest -picks = mne.pick_types(raw.info, meg=False, eog=True, include=include, - exclude='bads') +picks = mne.pick_types(raw.info, meg=False, eog=True, include=include, exclude="bads") event_id = 1 reject = dict(grad=4000e-13, eog=150e-6) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject) +epochs = mne.Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), reject=reject +) X = epochs.get_data() # as 3D matrix X = X[:, 0, :] # take only one channel to get a 2D array @@ -64,22 +64,43 @@ reject_bonferroni, pval_bonferroni = bonferroni_correction(pval, alpha=alpha) threshold_bonferroni = stats.t.ppf(1.0 - alpha / n_tests, n_samples - 1) -reject_fdr, pval_fdr = fdr_correction(pval, alpha=alpha, method='indep') +reject_fdr, pval_fdr = fdr_correction(pval, alpha=alpha, method="indep") threshold_fdr = np.min(np.abs(T)[reject_fdr]) # %% # Plot times = 1e3 * epochs.times -plt.close('all') -plt.plot(times, T, 'k', label='T-stat') +plt.close("all") +plt.plot(times, T, "k", label="T-stat") xmin, xmax = plt.xlim() -plt.hlines(threshold_uncorrected, xmin, xmax, linestyle='--', colors='k', - label='p=0.05 (uncorrected)', linewidth=2) -plt.hlines(threshold_bonferroni, xmin, xmax, linestyle='--', colors='r', - label='p=0.05 (Bonferroni)', linewidth=2) -plt.hlines(threshold_fdr, xmin, xmax, linestyle='--', colors='b', - label='p=0.05 (FDR)', linewidth=2) +plt.hlines( + threshold_uncorrected, + xmin, + xmax, + linestyle="--", + colors="k", + label="p=0.05 (uncorrected)", + linewidth=2, +) +plt.hlines( + threshold_bonferroni, + xmin, + xmax, + linestyle="--", + colors="r", + label="p=0.05 (Bonferroni)", + linewidth=2, +) +plt.hlines( + threshold_fdr, + xmin, + xmax, + linestyle="--", + colors="b", + label="p=0.05 (FDR)", + linewidth=2, +) plt.legend() plt.xlabel("Time (ms)") plt.ylabel("T-stat") diff --git a/examples/stats/linear_regression_raw.py b/examples/stats/linear_regression_raw.py index 54aef70c8e2..11a21a80305 100644 --- a/examples/stats/linear_regression_raw.py +++ b/examples/stats/linear_regression_raw.py @@ -33,37 +33,47 @@ # Load and preprocess data data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +meg_path = 
data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" raw = mne.io.read_raw_fif(raw_fname) -raw.pick_types(meg='grad', stim=True, eeg=False).load_data() -raw.filter(1, None, fir_design='firwin') # high-pass +raw.pick_types(meg="grad", stim=True, eeg=False).load_data() +raw.filter(1, None, fir_design="firwin") # high-pass # Set up events events = mne.find_events(raw) -event_id = {'Aud/L': 1, 'Aud/R': 2} -tmin, tmax = -.1, .5 +event_id = {"Aud/L": 1, "Aud/R": 2} +tmin, tmax = -0.1, 0.5 # regular epoching picks = mne.pick_types(raw.info, meg=True) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, reject=None, - baseline=None, preload=True, verbose=False) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + reject=None, + baseline=None, + preload=True, + verbose=False, +) # rERF -evokeds = linear_regression_raw(raw, events=events, event_id=event_id, - reject=None, tmin=tmin, tmax=tmax) +evokeds = linear_regression_raw( + raw, events=events, event_id=event_id, reject=None, tmin=tmin, tmax=tmax +) # linear_regression_raw returns a dict of evokeds # select conditions similarly to mne.Epochs objects # plot both results, and their difference cond = "Aud/L" fig, (ax1, ax2, ax3) = plt.subplots(3, 1) -params = dict(spatial_colors=True, show=False, ylim=dict(grad=(-200, 200)), - time_unit='s') +params = dict( + spatial_colors=True, show=False, ylim=dict(grad=(-200, 200)), time_unit="s" +) epochs[cond].average().plot(axes=ax1, **params) evokeds[cond].plot(axes=ax2, **params) -contrast = mne.combine_evoked([evokeds[cond], epochs[cond].average()], - weights=[1, -1]) +contrast = mne.combine_evoked([evokeds[cond], epochs[cond].average()], weights=[1, -1]) contrast.plot(axes=ax3, **params) ax1.set_title("Traditional averaging") ax2.set_title("rERF") diff --git a/examples/stats/sensor_permutation_test.py b/examples/stats/sensor_permutation_test.py index 654c9b7153c..7d54df71357 100644 --- a/examples/stats/sensor_permutation_test.py +++ b/examples/stats/sensor_permutation_test.py @@ -28,9 +28,9 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" event_id = 1 tmin = -0.2 tmax = 0.5 @@ -40,10 +40,19 @@ events = mne.read_events(event_fname) # pick MEG Gradiometers -picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True, - exclude='bads') -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6)) +picks = mne.pick_types( + raw.info, meg="grad", eeg=False, stim=False, eog=True, exclude="bads" +) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13, eog=150e-6), +) data = epochs.get_data() times = epochs.times @@ -62,15 +71,23 @@ # %% # View location of significantly active sensors -evoked = mne.EvokedArray(-np.log10(p_values)[:, np.newaxis], - epochs.info, tmin=0.) 
+evoked = mne.EvokedArray(-np.log10(p_values)[:, np.newaxis], epochs.info, tmin=0.0) # Extract mask and indices of active sensors in the layout stats_picks = mne.pick_channels(evoked.ch_names, significant_sensors_names) mask = p_values[:, np.newaxis] <= 0.05 -evoked.plot_topomap(ch_type='grad', times=[0], scalings=1, - time_format=None, cmap='Reds', vlim=(0., np.max), - units='-log10(p)', cbar_fmt='-%0.1f', mask=mask, - size=3, show_names=lambda x: x[4:] + ' ' * 20, - time_unit='s') +evoked.plot_topomap( + ch_type="grad", + times=[0], + scalings=1, + time_format=None, + cmap="Reds", + vlim=(0.0, np.max), + units="-log10(p)", + cbar_fmt="-%0.1f", + mask=mask, + size=3, + show_names=lambda x: x[4:] + " " * 20, + time_unit="s", +) diff --git a/examples/stats/sensor_regression.py b/examples/stats/sensor_regression.py index 9a1e42ae7f8..2b17927b28b 100644 --- a/examples/stats/sensor_regression.py +++ b/examples/stats/sensor_regression.py @@ -38,7 +38,7 @@ from mne.datasets import kiloword # Load the data -path = kiloword.data_path() / 'kword_metadata-epo.fif' +path = kiloword.data_path() / "kword_metadata-epo.fif" epochs = mne.read_epochs(path) print(epochs.metadata.head()) @@ -54,8 +54,9 @@ colors = {str(val): val for val in df[name].unique()} epochs.metadata = df.assign(Intercept=1) # Add an intercept for later evokeds = {val: epochs[name + " == " + val].average() for val in colors} -plot_compare_evokeds(evokeds, colors=colors, split_legend=True, - cmap=(name + " Percentile", "viridis")) +plot_compare_evokeds( + evokeds, colors=colors, split_legend=True, cmap=(name + " Percentile", "viridis") +) ############################################################################## # We observe that there appears to be a monotonic dependence of EEG on @@ -66,8 +67,9 @@ names = ["Intercept", name] res = linear_regression(epochs, epochs.metadata[names], names=names) for cond in names: - res[cond].beta.plot_joint(title=cond, ts_args=dict(time_unit='s'), - topomap_args=dict(time_unit='s')) + res[cond].beta.plot_joint( + title=cond, ts_args=dict(time_unit="s"), topomap_args=dict(time_unit="s") + ) ############################################################################## # Because the :func:`~mne.stats.linear_regression` function also estimates @@ -81,4 +83,4 @@ # by dark contour lines. reject_H0, fdr_pvals = fdr_correction(res["Concreteness"].p_val.data) evoked = res["Concreteness"].beta -evoked.plot_image(mask=reject_H0, time_unit='s') +evoked.plot_image(mask=reject_H0, time_unit="s") diff --git a/examples/time_frequency/compute_csd.py b/examples/time_frequency/compute_csd.py index e9a962bb733..0de5482be1e 100644 --- a/examples/time_frequency/compute_csd.py +++ b/examples/time_frequency/compute_csd.py @@ -35,9 +35,9 @@ # %% # Loading the sample dataset. data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname_raw = meg_path / 'sample_audvis_raw.fif' -fname_event = meg_path / 'sample_audvis_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +fname_raw = meg_path / "sample_audvis_raw.fif" +fname_event = meg_path / "sample_audvis_raw-eve.fif" raw = mne.io.read_raw_fif(fname_raw) events = mne.read_events(fname_event) @@ -47,12 +47,20 @@ # measurement units, and thus the scalings, differ across sensors. In this # example, for speed and clarity, we select a single channel type: # gradiometers. 
-picks = mne.pick_types(raw.info, meg='grad') +picks = mne.pick_types(raw.info, meg="grad") # Make some epochs, based on events with trigger code 1 -epochs = mne.Epochs(raw, events, event_id=1, tmin=-0.2, tmax=1, - picks=picks, baseline=(None, 0), - reject=dict(grad=4000e-13), preload=True) +epochs = mne.Epochs( + raw, + events, + event_id=1, + tmin=-0.2, + tmax=1, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13), + preload=True, +) # %% # Computing CSD matrices using short-term Fourier transform and (adaptive) @@ -85,9 +93,11 @@ # created figures; in this case, each returned list has only one figure # so we use a Python trick of including a comma after our variable name # to assign the figure (not the list) to our ``fig`` variable: -plot_dict = {'Short-time Fourier transform': csd_fft, - 'Adaptive multitapers': csd_mt, - 'Morlet wavelet transform': csd_wav} +plot_dict = { + "Short-time Fourier transform": csd_fft, + "Adaptive multitapers": csd_mt, + "Morlet wavelet transform": csd_wav, +} for title, csd in plot_dict.items(): - fig, = csd.mean().plot() + (fig,) = csd.mean().plot() fig.suptitle(title) diff --git a/examples/time_frequency/compute_source_psd_epochs.py b/examples/time_frequency/compute_source_psd_epochs.py index 1ca42643f49..745fc69717e 100644 --- a/examples/time_frequency/compute_source_psd_epochs.py +++ b/examples/time_frequency/compute_source_psd_epochs.py @@ -24,17 +24,17 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' -fname_raw = meg_path / 'sample_audvis_raw.fif' -fname_event = meg_path / 'sample_audvis_raw-eve.fif' -label_name = 'Aud-lh' -fname_label = meg_path / 'labels' / f'{label_name}.label' -subjects_dir = data_path / 'subjects' +meg_path = data_path / "MEG" / "sample" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_raw = meg_path / "sample_audvis_raw.fif" +fname_event = meg_path / "sample_audvis_raw-eve.fif" +label_name = "Aud-lh" +fname_label = meg_path / "labels" / f"{label_name}.label" +subjects_dir = data_path / "subjects" event_id, tmin, tmax = 1, -0.2, 0.5 snr = 1.0 # use smaller SNR for raw data -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) # Load data @@ -45,19 +45,27 @@ # Set up pick list include = [] -raw.info['bads'] += ['EEG 053'] # bads + 1 more +raw.info["bads"] += ["EEG 053"] # bads + 1 more # pick MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - include=include, exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, include=include, exclude="bads" +) # Read epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13, - eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) # define frequencies of interest -fmin, fmax = 0., 70. -bandwidth = 4. # bandwidth of the windows in Hz +fmin, fmax = 0.0, 70.0 +bandwidth = 4.0 # bandwidth of the windows in Hz # %% # Compute source space PSD in label @@ -68,14 +76,21 @@ # keep everything in memory. 
n_epochs_use = 10 -stcs = compute_source_psd_epochs(epochs[:n_epochs_use], inverse_operator, - lambda2=lambda2, - method=method, fmin=fmin, fmax=fmax, - bandwidth=bandwidth, label=label, - return_generator=True, verbose=True) +stcs = compute_source_psd_epochs( + epochs[:n_epochs_use], + inverse_operator, + lambda2=lambda2, + method=method, + fmin=fmin, + fmax=fmax, + bandwidth=bandwidth, + label=label, + return_generator=True, + verbose=True, +) # compute average PSD over the first 10 epochs -psd_avg = 0. +psd_avg = 0.0 for i, stc in enumerate(stcs): psd_avg += stc.data psd_avg /= n_epochs_use @@ -85,16 +100,21 @@ # %% # Visualize the 10 Hz PSD: -brain = stc.plot(initial_time=10., hemi='lh', views='lat', # 10 HZ - clim=dict(kind='value', lims=(20, 40, 60)), - smoothing_steps=3, subjects_dir=subjects_dir) -brain.add_label(label, borders=True, color='k') +brain = stc.plot( + initial_time=10.0, + hemi="lh", + views="lat", # 10 HZ + clim=dict(kind="value", lims=(20, 40, 60)), + smoothing_steps=3, + subjects_dir=subjects_dir, +) +brain.add_label(label, borders=True, color="k") # %% # Visualize the entire spectrum: fig, ax = plt.subplots() ax.plot(freqs, psd_avg.mean(axis=0)) -ax.set_xlabel('Freq (Hz)') +ax.set_xlabel("Freq (Hz)") ax.set_xlim(stc.times[[0, -1]]) -ax.set_ylabel('Power Spectral Density') +ax.set_ylabel("Power Spectral Density") diff --git a/examples/time_frequency/source_label_time_frequency.py b/examples/time_frequency/source_label_time_frequency.py index 721c2fc4d2d..da3af06e4dc 100644 --- a/examples/time_frequency/source_label_time_frequency.py +++ b/examples/time_frequency/source_label_time_frequency.py @@ -32,50 +32,66 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' -label_name = 'Aud-rh' -fname_label = meg_path / 'labels' / f'{label_name}.label' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" +label_name = "Aud-rh" +fname_label = meg_path / "labels" / f"{label_name}.label" tmin, tmax, event_id = -0.2, 0.5, 2 # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") inverse_operator = read_inverse_operator(fname_inv) include = [] -raw.info['bads'] += ['MEG 2443', 'EEG 053'] # bads + 2 more +raw.info["bads"] += ["MEG 2443", "EEG 053"] # bads + 2 more # Picks MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True, - stim=False, include=include, exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, eog=True, stim=False, include=include, exclude="bads" +) reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6) # Load epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject, - preload=True) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=reject, + preload=True, +) # Compute a source estimate per frequency band including and excluding the # evoked response freqs = np.arange(7, 30, 2) # define frequencies of interest label = mne.read_label(fname_label) -n_cycles = freqs / 3. 
# different number of cycle per frequency +n_cycles = freqs / 3.0 # different number of cycle per frequency # subtract the evoked response in order to exclude evoked activity epochs_induced = epochs.copy().subtract_evoked() -plt.close('all') +plt.close("all") -for ii, (this_epochs, title) in enumerate(zip([epochs, epochs_induced], - ['evoked + induced', - 'induced only'])): +for ii, (this_epochs, title) in enumerate( + zip([epochs, epochs_induced], ["evoked + induced", "induced only"]) +): # compute the source space power and the inter-trial coherence power, itc = source_induced_power( - this_epochs, inverse_operator, freqs, label, baseline=(-0.1, 0), - baseline_mode='percent', n_cycles=n_cycles, n_jobs=None) + this_epochs, + inverse_operator, + freqs, + label, + baseline=(-0.1, 0), + baseline_mode="percent", + n_cycles=n_cycles, + n_jobs=None, + ) power = np.mean(power, axis=0) # average over sources itc = np.mean(itc, axis=0) # average over sources @@ -85,22 +101,33 @@ # View time-frequency plots plt.subplots_adjust(0.1, 0.08, 0.96, 0.94, 0.2, 0.43) plt.subplot(2, 2, 2 * ii + 1) - plt.imshow(20 * power, - extent=[times[0], times[-1], freqs[0], freqs[-1]], - aspect='auto', origin='lower', vmin=0., vmax=30., cmap='RdBu_r') - plt.xlabel('Time (s)') - plt.ylabel('Frequency (Hz)') - plt.title('Power (%s)' % title) + plt.imshow( + 20 * power, + extent=[times[0], times[-1], freqs[0], freqs[-1]], + aspect="auto", + origin="lower", + vmin=0.0, + vmax=30.0, + cmap="RdBu_r", + ) + plt.xlabel("Time (s)") + plt.ylabel("Frequency (Hz)") + plt.title("Power (%s)" % title) plt.colorbar() plt.subplot(2, 2, 2 * ii + 2) - plt.imshow(itc, - extent=[times[0], times[-1], freqs[0], freqs[-1]], - aspect='auto', origin='lower', vmin=0, vmax=0.7, - cmap='RdBu_r') - plt.xlabel('Time (s)') - plt.ylabel('Frequency (Hz)') - plt.title('ITC (%s)' % title) + plt.imshow( + itc, + extent=[times[0], times[-1], freqs[0], freqs[-1]], + aspect="auto", + origin="lower", + vmin=0, + vmax=0.7, + cmap="RdBu_r", + ) + plt.xlabel("Time (s)") + plt.ylabel("Frequency (Hz)") + plt.title("ITC (%s)" % title) plt.colorbar() plt.show() diff --git a/examples/time_frequency/source_power_spectrum.py b/examples/time_frequency/source_power_spectrum.py index 4b6d582d50b..a2aab813930 100644 --- a/examples/time_frequency/source_power_spectrum.py +++ b/examples/time_frequency/source_power_spectrum.py @@ -26,37 +26,48 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' -fname_label = meg_path / 'labels' / 'Aud-lh.label' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_label = meg_path / "labels" / "Aud-lh.label" # Setup for reading the raw data raw = io.read_raw_fif(raw_fname, verbose=False) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") inverse_operator = read_inverse_operator(fname_inv) -raw.info['bads'] = ['MEG 2443', 'EEG 053'] +raw.info["bads"] = ["MEG 2443", "EEG 053"] # picks MEG gradiometers -picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True, - stim=False, exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, eog=True, stim=False, exclude="bads" +) tmin, tmax = 0, 120 # use the first 120s of data fmin, fmax = 4, 100 # look at frequencies between 4 and 100Hz n_fft = 2048 # the FFT size 
(n_fft). Ideally a power of 2 label = mne.read_label(fname_label) -stc = compute_source_psd(raw, inverse_operator, lambda2=1. / 9., method="dSPM", - tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - pick_ori="normal", n_fft=n_fft, label=label, - dB=True) +stc = compute_source_psd( + raw, + inverse_operator, + lambda2=1.0 / 9.0, + method="dSPM", + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + pick_ori="normal", + n_fft=n_fft, + label=label, + dB=True, +) -stc.save('psd_dSPM', overwrite=True) +stc.save("psd_dSPM", overwrite=True) # %% # View PSD of sources in label plt.plot(stc.times, stc.data.T) -plt.xlabel('Frequency (Hz)') -plt.ylabel('PSD (dB)') -plt.title('Source Power Spectrum (PSD)') +plt.xlabel("Frequency (Hz)") +plt.ylabel("PSD (dB)") +plt.title("Source Power Spectrum (PSD)") plt.show() diff --git a/examples/time_frequency/source_power_spectrum_opm.py b/examples/time_frequency/source_power_spectrum_opm.py index 462f79c8eb9..0e5cf5d34c8 100644 --- a/examples/time_frequency/source_power_spectrum_opm.py +++ b/examples/time_frequency/source_power_spectrum_opm.py @@ -35,19 +35,19 @@ print(__doc__) data_path = mne.datasets.opm.data_path() -subject = 'OPM_sample' - -subjects_dir = data_path / 'subjects' -bem_dir = subjects_dir / subject / 'bem' -bem_fname = bem_dir / f'{subject}-5120-5120-5120-bem-sol.fif' -src_fname = bem_dir / f'{subject}-oct6-src.fif' -vv_fname = data_path / 'MEG' / 'SQUID' / 'SQUID_resting_state.fif' -vv_erm_fname = data_path / 'MEG' / 'SQUID' / 'SQUID_empty_room.fif' -vv_trans_fname = data_path / 'MEG' / 'SQUID' / 'SQUID-trans.fif' -opm_fname = data_path / 'MEG' / 'OPM' / 'OPM_resting_state_raw.fif' -opm_erm_fname = data_path / 'MEG' / 'OPM' / 'OPM_empty_room_raw.fif' -opm_trans = mne.transforms.Transform('head', 'mri') # use identity transform -opm_coil_def_fname = data_path / 'MEG' / 'OPM' / 'coil_def.dat' +subject = "OPM_sample" + +subjects_dir = data_path / "subjects" +bem_dir = subjects_dir / subject / "bem" +bem_fname = bem_dir / f"{subject}-5120-5120-5120-bem-sol.fif" +src_fname = bem_dir / f"{subject}-oct6-src.fif" +vv_fname = data_path / "MEG" / "SQUID" / "SQUID_resting_state.fif" +vv_erm_fname = data_path / "MEG" / "SQUID" / "SQUID_empty_room.fif" +vv_trans_fname = data_path / "MEG" / "SQUID" / "SQUID-trans.fif" +opm_fname = data_path / "MEG" / "OPM" / "OPM_resting_state_raw.fif" +opm_erm_fname = data_path / "MEG" / "OPM" / "OPM_empty_room_raw.fif" +opm_trans = mne.transforms.Transform("head", "mri") # use identity transform +opm_coil_def_fname = data_path / "MEG" / "OPM" / "coil_def.dat" ############################################################################## # Load data, resample. We will store the raw objects in dicts with entries @@ -55,28 +55,28 @@ raws = dict() raw_erms = dict() -new_sfreq = 60. 
# Nyquist frequency (30 Hz) < line noise freq (50 Hz) -raws['vv'] = mne.io.read_raw_fif(vv_fname, verbose='error') # ignore naming -raws['vv'].load_data().resample(new_sfreq) -raws['vv'].info['bads'] = ['MEG2233', 'MEG1842'] -raw_erms['vv'] = mne.io.read_raw_fif(vv_erm_fname, verbose='error') -raw_erms['vv'].load_data().resample(new_sfreq) -raw_erms['vv'].info['bads'] = ['MEG2233', 'MEG1842'] - -raws['opm'] = mne.io.read_raw_fif(opm_fname) -raws['opm'].load_data().resample(new_sfreq) -raw_erms['opm'] = mne.io.read_raw_fif(opm_erm_fname) -raw_erms['opm'].load_data().resample(new_sfreq) +new_sfreq = 60.0 # Nyquist frequency (30 Hz) < line noise freq (50 Hz) +raws["vv"] = mne.io.read_raw_fif(vv_fname, verbose="error") # ignore naming +raws["vv"].load_data().resample(new_sfreq) +raws["vv"].info["bads"] = ["MEG2233", "MEG1842"] +raw_erms["vv"] = mne.io.read_raw_fif(vv_erm_fname, verbose="error") +raw_erms["vv"].load_data().resample(new_sfreq) +raw_erms["vv"].info["bads"] = ["MEG2233", "MEG1842"] + +raws["opm"] = mne.io.read_raw_fif(opm_fname) +raws["opm"].load_data().resample(new_sfreq) +raw_erms["opm"] = mne.io.read_raw_fif(opm_erm_fname) +raw_erms["opm"].load_data().resample(new_sfreq) # Make sure our assumptions later hold -assert raws['opm'].info['sfreq'] == raws['vv'].info['sfreq'] +assert raws["opm"].info["sfreq"] == raws["vv"].info["sfreq"] ############################################################################## # Explore data -titles = dict(vv='VectorView', opm='OPM') -kinds = ('vv', 'opm') +titles = dict(vv="VectorView", opm="OPM") +kinds = ("vv", "opm") n_fft = next_fast_len(int(round(4 * new_sfreq))) -print('Using n_fft=%d (%0.1f s)' % (n_fft, n_fft / raws['vv'].info['sfreq'])) +print("Using n_fft=%d (%0.1f s)" % (n_fft, n_fft / raws["vv"].info["sfreq"])) for kind in kinds: fig = raws[kind].compute_psd(n_fft=n_fft, proj=True).plot() fig.suptitle(titles[kind]) @@ -87,37 +87,48 @@ # --------------------- # Here we use a reduced size source space (oct5) just for speed -src = mne.setup_source_space( - subject, 'oct5', add_dist=False, subjects_dir=subjects_dir) +src = mne.setup_source_space(subject, "oct5", add_dist=False, subjects_dir=subjects_dir) # This line removes source-to-source distances that we will not need. # We only do it here to save a bit of memory, in general this is not required. 
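# %%
# A short sketch of the FFT-length rule used above: take a ~4 s window and
# round it up to a fast (5-smooth) length. ``next_fast_len`` is assumed to
# be scipy's here; the example imports its own copy.
from scipy.fft import next_fast_len

sfreq = 60.0  # the resampled rate used above
n_fft = next_fast_len(int(round(4 * sfreq)))
print(n_fft, n_fft / sfreq)  # 240 samples, i.e. exactly 4.0 s at 60 Hz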
-del src[0]['dist'], src[1]['dist'] +del src[0]["dist"], src[1]["dist"] bem = mne.read_bem_solution(bem_fname) # For speed, let's just use a 1-layer BEM -bem = mne.make_bem_solution(bem['surfs'][-1:]) +bem = mne.make_bem_solution(bem["surfs"][-1:]) fwd = dict() # check alignment and generate forward for VectorView -kwargs = dict(azimuth=0, elevation=90, distance=0.6, focalpoint=(0., 0., 0.)) +kwargs = dict(azimuth=0, elevation=90, distance=0.6, focalpoint=(0.0, 0.0, 0.0)) fig = mne.viz.plot_alignment( - raws['vv'].info, trans=vv_trans_fname, subject=subject, - subjects_dir=subjects_dir, dig=True, coord_frame='mri', - surfaces=('head', 'white')) + raws["vv"].info, + trans=vv_trans_fname, + subject=subject, + subjects_dir=subjects_dir, + dig=True, + coord_frame="mri", + surfaces=("head", "white"), +) mne.viz.set_3d_view(figure=fig, **kwargs) -fwd['vv'] = mne.make_forward_solution( - raws['vv'].info, vv_trans_fname, src, bem, eeg=False, verbose=True) +fwd["vv"] = mne.make_forward_solution( + raws["vv"].info, vv_trans_fname, src, bem, eeg=False, verbose=True +) ############################################################################## # And for OPM: with mne.use_coil_def(opm_coil_def_fname): fig = mne.viz.plot_alignment( - raws['opm'].info, trans=opm_trans, subject=subject, - subjects_dir=subjects_dir, dig=False, coord_frame='mri', - surfaces=('head', 'white')) + raws["opm"].info, + trans=opm_trans, + subject=subject, + subjects_dir=subjects_dir, + dig=False, + coord_frame="mri", + surfaces=("head", "white"), + ) mne.viz.set_3d_view(figure=fig, **kwargs) - fwd['opm'] = mne.make_forward_solution( - raws['opm'].info, opm_trans, src, bem, eeg=False, verbose=True) + fwd["opm"] = mne.make_forward_solution( + raws["opm"].info, opm_trans, src, bem, eeg=False, verbose=True + ) del src, bem @@ -131,24 +142,29 @@ topos = dict(vv=dict(), opm=dict()) stcs = dict(vv=dict(), opm=dict()) -snr = 3. -lambda2 = 1. 
/ snr ** 2 +snr = 3.0 +lambda2 = 1.0 / snr**2 for kind in kinds: noise_cov = mne.compute_raw_covariance(raw_erms[kind]) inverse_operator = mne.minimum_norm.make_inverse_operator( - raws[kind].info, forward=fwd[kind], noise_cov=noise_cov, verbose=True) + raws[kind].info, forward=fwd[kind], noise_cov=noise_cov, verbose=True + ) stc_psd, sensor_psd = mne.minimum_norm.compute_source_psd( - raws[kind], inverse_operator, lambda2=lambda2, - n_fft=n_fft, dB=False, return_sensor=True, verbose=True) + raws[kind], + inverse_operator, + lambda2=lambda2, + n_fft=n_fft, + dB=False, + return_sensor=True, + verbose=True, + ) topo_norm = sensor_psd.data.sum(axis=1, keepdims=True) stc_norm = stc_psd.sum() # same operation on MNE object, sum across freqs # Normalize each source point by the total power across freqs for band, limits in freq_bands.items(): data = sensor_psd.copy().crop(*limits).data.sum(axis=1, keepdims=True) - topos[kind][band] = mne.EvokedArray( - 100 * data / topo_norm, sensor_psd.info) - stcs[kind][band] = \ - 100 * stc_psd.copy().crop(*limits).sum() / stc_norm.data + topos[kind][band] = mne.EvokedArray(100 * data / topo_norm, sensor_psd.info) + stcs[kind][band] = 100 * stc_psd.copy().crop(*limits).sum() / stc_norm.data del inverse_operator del fwd, raws, raw_erms @@ -161,22 +177,42 @@ # Alpha # ----- + def plot_band(kind, band): """Plot activity within a frequency band on the subject's brain.""" - title = "%s %s\n(%d-%d Hz)" % ((titles[kind], band,) + freq_bands[band]) + title = "%s %s\n(%d-%d Hz)" % ( + ( + titles[kind], + band, + ) + + freq_bands[band] + ) topos[kind][band].plot_topomap( - times=0., scalings=1., cbar_fmt='%0.1f', vlim=(0, None), - cmap='inferno', time_format=title) + times=0.0, + scalings=1.0, + cbar_fmt="%0.1f", + vlim=(0, None), + cmap="inferno", + time_format=title, + ) brain = stcs[kind][band].plot( - subject=subject, subjects_dir=subjects_dir, views='cau', hemi='both', - time_label=title, title=title, colormap='inferno', - time_viewer=False, show_traces=False, - clim=dict(kind='percent', lims=(70, 85, 99)), smoothing_steps=10) + subject=subject, + subjects_dir=subjects_dir, + views="cau", + hemi="both", + time_label=title, + title=title, + colormap="inferno", + time_viewer=False, + show_traces=False, + clim=dict(kind="percent", lims=(70, 85, 99)), + smoothing_steps=10, + ) brain.show_view(azimuth=0, elevation=0, roll=0) return fig, brain -fig_alpha, brain_alpha = plot_band('vv', 'alpha') +fig_alpha, brain_alpha = plot_band("vv", "alpha") # %% # Beta @@ -184,13 +220,13 @@ def plot_band(kind, band): # Here we also show OPM data, which shows a profile similar to the VectorView # data beneath the sensors. 
VectorView first: -fig_beta, brain_beta = plot_band('vv', 'beta') +fig_beta, brain_beta = plot_band("vv", "beta") # %% # Then OPM: # sphinx_gallery_thumbnail_number = 10 -fig_beta_opm, brain_beta_opm = plot_band('opm', 'beta') +fig_beta_opm, brain_beta_opm = plot_band("opm", "beta") # %% # References diff --git a/examples/time_frequency/source_space_time_frequency.py b/examples/time_frequency/source_space_time_frequency.py index a0a5f944439..61c3959c232 100644 --- a/examples/time_frequency/source_space_time_frequency.py +++ b/examples/time_frequency/source_space_time_frequency.py @@ -28,46 +28,57 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" tmin, tmax, event_id = -0.2, 0.5, 1 # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") inverse_operator = read_inverse_operator(fname_inv) include = [] -raw.info['bads'] += ['MEG 2443', 'EEG 053'] # bads + 2 more +raw.info["bads"] += ["MEG 2443", "EEG 053"] # bads + 2 more # picks MEG gradiometers -picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True, - stim=False, include=include, exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, eog=True, stim=False, include=include, exclude="bads" +) # Load condition 1 event_id = 1 events = events[:10] # take 10 events to keep the computation time low # Use linear detrend to reduce any edge artifacts -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6), - preload=True, detrend=1) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13, eog=150e-6), + preload=True, + detrend=1, +) # Compute a source estimate per frequency band bands = dict(alpha=[9, 11], beta=[18, 22]) -stcs = source_band_induced_power(epochs, inverse_operator, bands, n_cycles=2, - use_fft=False, n_jobs=None) +stcs = source_band_induced_power( + epochs, inverse_operator, bands, n_cycles=2, use_fft=False, n_jobs=None +) for b, stc in stcs.items(): - stc.save('induced_power_%s' % b, overwrite=True) + stc.save("induced_power_%s" % b, overwrite=True) # %% # plot mean power -plt.plot(stcs['alpha'].times, stcs['alpha'].data.mean(axis=0), label='Alpha') -plt.plot(stcs['beta'].times, stcs['beta'].data.mean(axis=0), label='Beta') -plt.xlabel('Time (ms)') -plt.ylabel('Power') +plt.plot(stcs["alpha"].times, stcs["alpha"].data.mean(axis=0), label="Alpha") +plt.plot(stcs["beta"].times, stcs["beta"].data.mean(axis=0), label="Beta") +plt.xlabel("Time (ms)") +plt.ylabel("Power") plt.legend() -plt.title('Mean source induced power') +plt.title("Mean source induced power") plt.show() diff --git a/examples/time_frequency/temporal_whitening.py b/examples/time_frequency/temporal_whitening.py index 068abad7337..de70216461b 100644 --- a/examples/time_frequency/temporal_whitening.py +++ b/examples/time_frequency/temporal_whitening.py @@ -26,17 +26,17 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -proj_fname = meg_path / 'sample_audvis_ecg-proj.fif' +meg_path = data_path / 
"MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +proj_fname = meg_path / "sample_audvis_ecg-proj.fif" raw = mne.io.read_raw_fif(raw_fname) proj = mne.read_proj(proj_fname) raw.add_proj(proj) -raw.info['bads'] = ['MEG 2443', 'EEG 053'] # mark bad channels +raw.info["bads"] = ["MEG 2443", "EEG 053"] # mark bad channels # Set up pick list: Gradiometers - bad channels -picks = mne.pick_types(raw.info, meg='grad', exclude='bads') +picks = mne.pick_types(raw.info, meg="grad", exclude="bads") order = 5 # define model order picks = picks[:1] @@ -45,21 +45,21 @@ b, a = fit_iir_model_raw(raw, order=order, picks=picks, tmin=60, tmax=180) d, times = raw[0, 10000:20000] # look at one channel from now on d = d.ravel() # make flat vector -innovation = signal.convolve(d, a, 'valid') +innovation = signal.convolve(d, a, "valid") d_ = signal.lfilter(b, a, innovation) # regenerate the signal d_ = np.r_[d_[0] * np.ones(order), d_] # dummy samples to keep signal length # %% # Plot the different time series and PSDs -plt.close('all') +plt.close("all") plt.figure() -plt.plot(d[:100], label='signal') -plt.plot(d_[:100], label='regenerated signal') +plt.plot(d[:100], label="signal") +plt.plot(d_[:100], label="regenerated signal") plt.legend() plt.figure() -plt.psd(d, Fs=raw.info['sfreq'], NFFT=2048) -plt.psd(innovation, Fs=raw.info['sfreq'], NFFT=2048) -plt.psd(d_, Fs=raw.info['sfreq'], NFFT=2048, linestyle='--') -plt.legend(('Signal', 'Innovation', 'Regenerated signal')) +plt.psd(d, Fs=raw.info["sfreq"], NFFT=2048) +plt.psd(innovation, Fs=raw.info["sfreq"], NFFT=2048) +plt.psd(d_, Fs=raw.info["sfreq"], NFFT=2048, linestyle="--") +plt.legend(("Signal", "Innovation", "Regenerated signal")) plt.show() diff --git a/examples/time_frequency/time_frequency_erds.py b/examples/time_frequency/time_frequency_erds.py index d55122c232b..72b5f36d172 100644 --- a/examples/time_frequency/time_frequency_erds.py +++ b/examples/time_frequency/time_frequency_erds.py @@ -52,7 +52,7 @@ fnames = eegbci.load_data(subject=1, runs=(6, 10, 14)) raw = concatenate_raws([read_raw_edf(f, preload=True) for f in fnames]) -raw.rename_channels(lambda x: x.strip('.')) # remove dots from channel names +raw.rename_channels(lambda x: x.strip(".")) # remove dots from channel names events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3)) @@ -61,8 +61,16 @@ tmin, tmax = -1, 4 event_ids = dict(hands=2, feet=3) # map event IDs to tasks -epochs = mne.Epochs(raw, events, event_ids, tmin - 0.5, tmax + 0.5, - picks=('C3', 'Cz', 'C4'), baseline=None, preload=True) +epochs = mne.Epochs( + raw, + events, + event_ids, + tmin - 0.5, + tmax + 0.5, + picks=("C3", "Cz", "C4"), + baseline=None, + preload=True, +) # %% # .. _cnorm-example: @@ -80,20 +88,29 @@ baseline = (-1, 0) # baseline interval (in s) cnorm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax) # min, center & max ERDS -kwargs = dict(n_permutations=100, step_down_p=0.05, seed=1, - buffer_size=None, out_type='mask') # for cluster test +kwargs = dict( + n_permutations=100, step_down_p=0.05, seed=1, buffer_size=None, out_type="mask" +) # for cluster test # %% # Finally, we perform time/frequency decomposition over all epochs. 
-tfr = tfr_multitaper(epochs, freqs=freqs, n_cycles=freqs, use_fft=True, - return_itc=False, average=False, decim=2) +tfr = tfr_multitaper( + epochs, + freqs=freqs, + n_cycles=freqs, + use_fft=True, + return_itc=False, + average=False, + decim=2, +) tfr.crop(tmin, tmax).apply_baseline(baseline, mode="percent") for event in event_ids: # select desired epochs for visualization tfr_ev = tfr[event] - fig, axes = plt.subplots(1, 4, figsize=(12, 4), - gridspec_kw={"width_ratios": [10, 10, 10, 1]}) + fig, axes = plt.subplots( + 1, 4, figsize=(12, 4), gridspec_kw={"width_ratios": [10, 10, 10, 1]} + ) for ch, ax in enumerate(axes[:-1]): # for each channel # positive clusters _, c1, p1, _ = pcluster_test(tfr_ev.data[:, ch], tail=1, **kwargs) @@ -108,9 +125,16 @@ mask = c[..., p <= 0.05].any(axis=-1) # plot TFR (ERDS map with masking) - tfr_ev.average().plot([ch], cmap="RdBu", cnorm=cnorm, axes=ax, - colorbar=False, show=False, mask=mask, - mask_style="mask") + tfr_ev.average().plot( + [ch], + cmap="RdBu", + cnorm=cnorm, + axes=ax, + colorbar=False, + show=False, + mask=mask, + mask_style="mask", + ) ax.set_title(epochs.ch_names[ch], fontsize=10) ax.axvline(0, linewidth=1, color="black", linestyle=":") # event @@ -139,33 +163,28 @@ df = tfr.to_data_frame(time_format=None, long_format=True) # Map to frequency bands: -freq_bounds = {'_': 0, - 'delta': 3, - 'theta': 7, - 'alpha': 13, - 'beta': 35, - 'gamma': 140} -df['band'] = pd.cut(df['freq'], list(freq_bounds.values()), - labels=list(freq_bounds)[1:]) +freq_bounds = {"_": 0, "delta": 3, "theta": 7, "alpha": 13, "beta": 35, "gamma": 140} +df["band"] = pd.cut( + df["freq"], list(freq_bounds.values()), labels=list(freq_bounds)[1:] +) # Filter to retain only relevant frequency bands: -freq_bands_of_interest = ['delta', 'theta', 'alpha', 'beta'] +freq_bands_of_interest = ["delta", "theta", "alpha", "beta"] df = df[df.band.isin(freq_bands_of_interest)] -df['band'] = df['band'].cat.remove_unused_categories() +df["band"] = df["band"].cat.remove_unused_categories() # Order channels for plotting: -df['channel'] = df['channel'].cat.reorder_categories(('C3', 'Cz', 'C4'), - ordered=True) +df["channel"] = df["channel"].cat.reorder_categories(("C3", "Cz", "C4"), ordered=True) -g = sns.FacetGrid(df, row='band', col='channel', margin_titles=True) -g.map(sns.lineplot, 'time', 'value', 'condition', n_boot=10) -axline_kw = dict(color='black', linestyle='dashed', linewidth=0.5, alpha=0.5) +g = sns.FacetGrid(df, row="band", col="channel", margin_titles=True) +g.map(sns.lineplot, "time", "value", "condition", n_boot=10) +axline_kw = dict(color="black", linestyle="dashed", linewidth=0.5, alpha=0.5) g.map(plt.axhline, y=0, **axline_kw) g.map(plt.axvline, x=0, **axline_kw) g.set(ylim=(None, 1.5)) g.set_axis_labels("Time (s)", "ERDS (%)") g.set_titles(col_template="{col_name}", row_template="{row_name}") -g.add_legend(ncol=2, loc='lower center') +g.add_legend(ncol=2, loc="lower center") g.fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.08) # %% @@ -174,17 +193,27 @@ # Here, we use seaborn to plot the average ERDS in the motor imagery interval # as a function of frequency band and imagery condition: -df_mean = (df.query('time > 1') - .groupby(['condition', 'epoch', 'band', 'channel'])[['value']] - .mean() - .reset_index()) - -g = sns.FacetGrid(df_mean, col='condition', col_order=['hands', 'feet'], - margin_titles=True) -g = (g.map(sns.violinplot, 'channel', 'value', 'band', n_boot=10, - palette='deep', order=['C3', 'Cz', 'C4'], - hue_order=freq_bands_of_interest, - 
linewidth=0.5).add_legend(ncol=4, loc='lower center')) +df_mean = ( + df.query("time > 1") + .groupby(["condition", "epoch", "band", "channel"])[["value"]] + .mean() + .reset_index() +) + +g = sns.FacetGrid( + df_mean, col="condition", col_order=["hands", "feet"], margin_titles=True +) +g = g.map( + sns.violinplot, + "channel", + "value", + "band", + n_boot=10, + palette="deep", + order=["C3", "Cz", "C4"], + hue_order=freq_bands_of_interest, + linewidth=0.5, +).add_legend(ncol=4, loc="lower center") g.map(plt.axhline, **axline_kw) g.set_axis_labels("", "ERDS (%)") diff --git a/examples/time_frequency/time_frequency_global_field_power.py b/examples/time_frequency/time_frequency_global_field_power.py index a9af92cdde9..df816162f1c 100644 --- a/examples/time_frequency/time_frequency_global_field_power.py +++ b/examples/time_frequency/time_frequency_global_field_power.py @@ -54,47 +54,52 @@ # %% # Set parameters data_path = somato.data_path() -subject = '01' -task = 'somato' -raw_fname = (data_path / f'sub-{subject}' / 'meg' / - f'sub-{subject}_task-{task}_meg.fif') +subject = "01" +task = "somato" +raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif" # let's explore some frequency bands -iter_freqs = [ - ('Theta', 4, 7), - ('Alpha', 8, 12), - ('Beta', 13, 25), - ('Gamma', 30, 45) -] +iter_freqs = [("Theta", 4, 7), ("Alpha", 8, 12), ("Beta", 13, 25), ("Gamma", 30, 45)] # %% # We create average power time courses for each frequency band # set epoching parameters -event_id, tmin, tmax = 1, -1., 3. +event_id, tmin, tmax = 1, -1.0, 3.0 baseline = None # get the header to extract events raw = mne.io.read_raw_fif(raw_fname) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") frequency_map = list() for band, fmin, fmax in iter_freqs: # (re)load the data to save memory raw = mne.io.read_raw_fif(raw_fname) - raw.pick_types(meg='grad', eog=True) # we just look at gradiometers + raw.pick_types(meg="grad", eog=True) # we just look at gradiometers raw.load_data() # bandpass filter - raw.filter(fmin, fmax, n_jobs=None, # use more jobs to speed up. - l_trans_bandwidth=1, # make sure filter params are the same - h_trans_bandwidth=1) # in each band and skip "auto" option. + raw.filter( + fmin, + fmax, + n_jobs=None, # use more jobs to speed up. + l_trans_bandwidth=1, # make sure filter params are the same + h_trans_bandwidth=1, + ) # in each band and skip "auto" option. 
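# %%
# Why pin the transition bandwidths above: with the "auto" option MNE sizes
# them per filter (per the filtering docs, roughly
# ``min(max(0.25 * f, 2.0), f)`` on the low edge), so every band would get
# a slightly different filter. A sketch of what "auto" would have chosen
# for these bands (sampling rate assumed):
def auto_trans_bw(l_freq, h_freq, sfreq=600.0):  # somato rate, assumed
    l_bw = min(max(0.25 * l_freq, 2.0), l_freq)
    h_bw = min(max(0.25 * h_freq, 2.0), sfreq / 2.0 - h_freq)
    return l_bw, h_bw

for band, lo, hi in [("Theta", 4, 7), ("Alpha", 8, 12), ("Beta", 13, 25), ("Gamma", 30, 45)]:
    print(band, auto_trans_bw(lo, hi))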
# epoch - epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=baseline, - reject=dict(grad=4000e-13, eog=350e-6), - preload=True) + epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + baseline=baseline, + reject=dict(grad=4000e-13, eog=350e-6), + preload=True, + ) # remove evoked response epochs.subtract_evoked() @@ -115,30 +120,34 @@ # Helper function for plotting spread def stat_fun(x): """Return sum of squares.""" - return np.sum(x ** 2, axis=0) + return np.sum(x**2, axis=0) # Plot fig, axes = plt.subplots(4, 1, figsize=(10, 7), sharex=True, sharey=True) -colors = plt.colormaps['winter_r'](np.linspace(0, 1, 4)) +colors = plt.colormaps["winter_r"](np.linspace(0, 1, 4)) for ((freq_name, fmin, fmax), average), color, ax in zip( - frequency_map, colors, axes.ravel()[::-1]): + frequency_map, colors, axes.ravel()[::-1] +): times = average.times * 1e3 - gfp = np.sum(average.data ** 2, axis=0) + gfp = np.sum(average.data**2, axis=0) gfp = mne.baseline.rescale(gfp, times, baseline=(None, 0)) ax.plot(times, gfp, label=freq_name, color=color, linewidth=2.5) - ax.axhline(0, linestyle='--', color='grey', linewidth=2) - ci_low, ci_up = bootstrap_confidence_interval(average.data, random_state=0, - stat_fun=stat_fun) + ax.axhline(0, linestyle="--", color="grey", linewidth=2) + ci_low, ci_up = bootstrap_confidence_interval( + average.data, random_state=0, stat_fun=stat_fun + ) ci_low = rescale(ci_low, average.times, baseline=(None, 0)) ci_up = rescale(ci_up, average.times, baseline=(None, 0)) ax.fill_between(times, gfp + ci_up, gfp - ci_low, color=color, alpha=0.3) ax.grid(True) - ax.set_ylabel('GFP') - ax.annotate('%s (%d-%dHz)' % (freq_name, fmin, fmax), - xy=(0.95, 0.8), - horizontalalignment='right', - xycoords='axes fraction') + ax.set_ylabel("GFP") + ax.annotate( + "%s (%d-%dHz)" % (freq_name, fmin, fmax), + xy=(0.95, 0.8), + horizontalalignment="right", + xycoords="axes fraction", + ) ax.set_xlim(-1000, 3000) -axes.ravel()[-1].set_xlabel('Time [ms]') +axes.ravel()[-1].set_xlabel("Time [ms]") diff --git a/examples/time_frequency/time_frequency_simulated.py b/examples/time_frequency/time_frequency_simulated.py index c84803d7d2f..bf8b1dba6ca 100644 --- a/examples/time_frequency/time_frequency_simulated.py +++ b/examples/time_frequency/time_frequency_simulated.py @@ -26,8 +26,13 @@ from mne import create_info, Epochs from mne.baseline import rescale from mne.io import RawArray -from mne.time_frequency import (tfr_multitaper, tfr_stockwell, tfr_morlet, - tfr_array_morlet, AverageTFR) +from mne.time_frequency import ( + tfr_multitaper, + tfr_stockwell, + tfr_morlet, + tfr_array_morlet, + AverageTFR, +) from mne.viz import centers_to_edges print(__doc__) @@ -39,8 +44,8 @@ # We'll simulate data with a known spectro-temporal structure. sfreq = 1000.0 -ch_names = ['SIM0001', 'SIM0002'] -ch_types = ['grad', 'grad'] +ch_names = ["SIM0001", "SIM0002"] +ch_types = ["grad", "grad"] info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) n_times = 1024 # Just over 1 second epochs @@ -51,8 +56,8 @@ # Add a 50 Hz sinusoidal burst to the noise and ramp it. t = np.arange(n_times, dtype=np.float64) / sfreq -signal = np.sin(np.pi * 2. * 50. * t) # 50 Hz sinusoid signal -signal[np.logical_or(t < 0.45, t > 0.55)] = 0. 
# Hard windowing +signal = np.sin(np.pi * 2.0 * 50.0 * t) # 50 Hz sinusoid signal +signal[np.logical_or(t < 0.45, t > 0.55)] = 0.0 # Hard windowing on_time = np.logical_and(t >= 0.45, t <= 0.55) signal[on_time] *= np.hanning(on_time.sum()) # Ramping data[:, 100:-100] += np.tile(signal, n_epochs) # add signal @@ -60,8 +65,15 @@ raw = RawArray(data, info) events = np.zeros((n_epochs, 3), dtype=int) events[:, 0] = np.arange(n_epochs) * n_times -epochs = Epochs(raw, events, dict(sin50hz=0), tmin=0, tmax=n_times / sfreq, - reject=dict(grad=4000), baseline=None) +epochs = Epochs( + raw, + events, + dict(sin50hz=0), + tmin=0, + tmax=n_times / sfreq, + reject=dict(grad=4000), + baseline=None, +) epochs.average().plot() @@ -85,23 +97,39 @@ # properties, and thus a different TFR. You can trade time resolution or # frequency resolution or both in order to get a reduction in variance. -freqs = np.arange(5., 100., 3.) -vmin, vmax = -3., 3. # Define our color limits. +freqs = np.arange(5.0, 100.0, 3.0) +vmin, vmax = -3.0, 3.0 # Define our color limits. fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) for n_cycles, time_bandwidth, ax, title in zip( - [freqs / 2, freqs, freqs / 2], # number of cycles - [2.0, 4.0, 8.0], # time bandwidth - axs, - ['Sim: Least smoothing, most variance', - 'Sim: Less frequency smoothing,\nmore time smoothing', - 'Sim: Less time smoothing,\nmore frequency smoothing']): - power = tfr_multitaper(epochs, freqs=freqs, n_cycles=n_cycles, - time_bandwidth=time_bandwidth, return_itc=False) + [freqs / 2, freqs, freqs / 2], # number of cycles + [2.0, 4.0, 8.0], # time bandwidth + axs, + [ + "Sim: Least smoothing, most variance", + "Sim: Less frequency smoothing,\nmore time smoothing", + "Sim: Less time smoothing,\nmore frequency smoothing", + ], +): + power = tfr_multitaper( + epochs, + freqs=freqs, + n_cycles=n_cycles, + time_bandwidth=time_bandwidth, + return_itc=False, + ) ax.set_title(title) # Plot results. Baseline correct based on first 100 ms. - power.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax, - axes=ax, show=False, colorbar=False) + power.plot( + [0], + baseline=(0.0, 0.1), + mode="mean", + vmin=vmin, + vmax=vmax, + axes=ax, + show=False, + colorbar=False, + ) plt.tight_layout() ############################################################################## @@ -119,9 +147,10 @@ fmin, fmax = freqs[[0, -1]] for width, ax in zip((0.2, 0.7, 3.0), axs): power = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, width=width) - power.plot([0], baseline=(0., 0.1), mode='mean', axes=ax, show=False, - colorbar=False) - ax.set_title('Sim: Using S transform, width = {:0.1f}'.format(width)) + power.plot( + [0], baseline=(0.0, 0.1), mode="mean", axes=ax, show=False, colorbar=False + ) + ax.set_title("Sim: Using S transform, width = {:0.1f}".format(width)) plt.tight_layout() # %% @@ -134,14 +163,21 @@ # number of cycles to include in the window. fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) -all_n_cycles = [1, 3, freqs / 2.] 
+all_n_cycles = [1, 3, freqs / 2.0] for n_cycles, ax in zip(all_n_cycles, axs): - power = tfr_morlet(epochs, freqs=freqs, - n_cycles=n_cycles, return_itc=False) - power.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax, - axes=ax, show=False, colorbar=False) - n_cycles = 'scaled by freqs' if not isinstance(n_cycles, int) else n_cycles - ax.set_title(f'Sim: Using Morlet wavelet, n_cycles = {n_cycles}') + power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False) + power.plot( + [0], + baseline=(0.0, 0.1), + mode="mean", + vmin=vmin, + vmax=vmax, + axes=ax, + show=False, + colorbar=False, + ) + n_cycles = "scaled by freqs" if not isinstance(n_cycles, int) else n_cycles + ax.set_title(f"Sim: Using Morlet wavelet, n_cycles = {n_cycles}") plt.tight_layout() # %% @@ -154,10 +190,9 @@ # the width of this filter is recommended to be about 2 Hz. fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) -bandwidths = [1., 2., 4.] +bandwidths = [1.0, 2.0, 4.0] for bandwidth, ax in zip(bandwidths, axs): - data = np.zeros((len(ch_names), freqs.size, epochs.times.size), - dtype=complex) + data = np.zeros((len(ch_names), freqs.size, epochs.times.size), dtype=complex) for idx, freq in enumerate(freqs): # Filter raw data and re-epoch to avoid the filter being longer than # the epoch data for low frequencies and short epochs, such as here. @@ -167,24 +202,37 @@ # these are all very similar because the filters are almost the same. # In practice, using the default is usually a wise choice. raw_filter.filter( - l_freq=freq - bandwidth / 2, h_freq=freq + bandwidth / 2, + l_freq=freq - bandwidth / 2, + h_freq=freq + bandwidth / 2, # no negative values for large bandwidth and low freq l_trans_bandwidth=min([4 * bandwidth, freq - bandwidth]), - h_trans_bandwidth=4 * bandwidth) + h_trans_bandwidth=4 * bandwidth, + ) raw_filter.apply_hilbert() - epochs_hilb = Epochs(raw_filter, events, tmin=0, tmax=n_times / sfreq, - baseline=(0, 0.1)) + epochs_hilb = Epochs( + raw_filter, events, tmin=0, tmax=n_times / sfreq, baseline=(0, 0.1) + ) tfr_data = epochs_hilb.get_data() tfr_data = tfr_data * tfr_data.conj() # compute power tfr_data = np.mean(tfr_data, axis=0) # average over epochs data[:, idx] = tfr_data power = AverageTFR(info, data, epochs.times, freqs, nave=n_epochs) - power.plot([0], baseline=(0., 0.1), mode='mean', vmin=-0.1, vmax=0.1, - axes=ax, show=False, colorbar=False) - n_cycles = 'scaled by freqs' if not isinstance(n_cycles, int) else n_cycles - ax.set_title('Sim: Using narrow bandpass filter Hilbert,\n' - f'bandwidth = {bandwidth}, ' - f'transition bandwidth = {4 * bandwidth}') + power.plot( + [0], + baseline=(0.0, 0.1), + mode="mean", + vmin=-0.1, + vmax=0.1, + axes=ax, + show=False, + colorbar=False, + ) + n_cycles = "scaled by freqs" if not isinstance(n_cycles, int) else n_cycles + ax.set_title( + "Sim: Using narrow bandpass filter Hilbert,\n" + f"bandwidth = {bandwidth}, " + f"transition bandwidth = {4 * bandwidth}" + ) plt.tight_layout() # %% @@ -195,13 +243,21 @@ # We can do this by using ``average=False``. In this case, an instance of # :class:`mne.time_frequency.EpochsTFR` is returned. -n_cycles = freqs / 2. 
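# %%
# The ``n_cycles`` trade-off explored above has a closed form: MNE's Morlet
# wavelets use a temporal standard deviation of
# sigma_t = n_cycles / (2 * pi * f), and the spectral width scales as
# f / n_cycles, so adding cycles sharpens frequency resolution at the cost
# of temporal resolution. A quick numeric sketch at 50 Hz:
import numpy as np

freq = 50.0  # the simulated oscillation frequency
for n_cycles in (1.0, 3.0, freq / 2.0):
    sigma_t = n_cycles / (2.0 * np.pi * freq)
    sigma_f = 1.0 / (2.0 * np.pi * sigma_t)  # = freq / n_cycles
    print(f"n_cycles={n_cycles:g}: sigma_t={sigma_t * 1e3:.1f} ms, sigma_f={sigma_f:.1f} Hz")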
-power = tfr_morlet(epochs, freqs=freqs, - n_cycles=n_cycles, return_itc=False, average=False) +n_cycles = freqs / 2.0 +power = tfr_morlet( + epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False +) print(type(power)) avgpower = power.average() -avgpower.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax, - title='Using Morlet wavelets and EpochsTFR', show=False) +avgpower.plot( + [0], + baseline=(0.0, 0.1), + mode="mean", + vmin=vmin, + vmax=vmax, + title="Using Morlet wavelets and EpochsTFR", + show=False, +) # %% # Operating on arrays @@ -212,16 +268,20 @@ # ``(n_epochs, n_channels, n_times)``. They will also return a numpy array # of shape ``(n_epochs, n_channels, n_freqs, n_times)``. -power = tfr_array_morlet(epochs.get_data(), sfreq=epochs.info['sfreq'], - freqs=freqs, n_cycles=n_cycles, - output='avg_power') +power = tfr_array_morlet( + epochs.get_data(), + sfreq=epochs.info["sfreq"], + freqs=freqs, + n_cycles=n_cycles, + output="avg_power", +) # Baseline the output -rescale(power, epochs.times, (0., 0.1), mode='mean', copy=False) +rescale(power, epochs.times, (0.0, 0.1), mode="mean", copy=False) fig, ax = plt.subplots() x, y = centers_to_edges(epochs.times * 1000, freqs) -mesh = ax.pcolormesh(x, y, power[0], cmap='RdBu_r', vmin=vmin, vmax=vmax) -ax.set_title('TFR calculated on a numpy array') -ax.set(ylim=freqs[[0, -1]], xlabel='Time (ms)') +mesh = ax.pcolormesh(x, y, power[0], cmap="RdBu_r", vmin=vmin, vmax=vmax) +ax.set_title("TFR calculated on a numpy array") +ax.set(ylim=freqs[[0, -1]], xlabel="Time (ms)") fig.colorbar(mesh) plt.tight_layout() diff --git a/examples/visualization/3d_to_2d.py b/examples/visualization/3d_to_2d.py index bb692533baa..9eecc33f196 100644 --- a/examples/visualization/3d_to_2d.py +++ b/examples/visualization/3d_to_2d.py @@ -34,12 +34,12 @@ from mne.viz import plot_alignment, set_3d_view, snapshot_brain_montage misc_path = mne.datasets.misc.data_path() -subjects_dir = misc_path / 'ecog' -ecog_data_fname = subjects_dir / 'sample_ecog_ieeg.fif' +subjects_dir = misc_path / "ecog" +ecog_data_fname = subjects_dir / "sample_ecog_ieeg.fif" # We've already clicked and exported -layout_path = Path(dirname(mne.__file__)) / 'data' / 'image' -layout_name = 'custom_layout.lout' +layout_path = Path(dirname(mne.__file__)) / "data" / "image" +layout_name = "custom_layout.lout" # %% # Load data @@ -49,14 +49,14 @@ # a 2D snapshot. raw = read_raw_fif(ecog_data_fname) -raw.pick_channels([f'G{i}' for i in range(1, 257)]) # pick just one grid +raw.pick_channels([f"G{i}" for i in range(1, 257)]) # pick just one grid # Since we loaded in the ecog data from FIF, the coordinates # are in 'head' space, but we actually want them in 'mri' space. # So we will apply the head->mri transform that was used when # generating the dataset (the estimated head->mri transform). montage = raw.get_montage() -trans = mne.coreg.estimate_head_mri_t('sample_ecog', subjects_dir) +trans = mne.coreg.estimate_head_mri_t("sample_ecog", subjects_dir) montage.apply_trans(trans) # %% @@ -68,8 +68,13 @@ # with the electrode positions on that image. We use this in conjunction with # :func:`mne.viz.plot_alignment`, which visualizes electrode positions. 
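# %%
# What ``montage.apply_trans(trans)`` above does, conceptually: apply a 4x4
# affine head->MRI transform to each 3D point (rotation plus translation in
# homogeneous coordinates). A minimal sketch with an identity transform:
import numpy as np

def apply_affine(trans, pts):
    pts = np.atleast_2d(pts)  # (n_points, 3)
    return pts @ trans[:3, :3].T + trans[:3, 3]

print(apply_affine(np.eye(4), [0.01, 0.02, 0.03]))  # identity: unchanged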
-fig = plot_alignment(raw.info, trans=trans, subject='sample_ecog', - subjects_dir=subjects_dir, surfaces=dict(pial=0.9)) +fig = plot_alignment( + raw.info, + trans=trans, + subject="sample_ecog", + subjects_dir=subjects_dir, + surfaces=dict(pial=0.9), +) set_3d_view(figure=fig, azimuth=20, elevation=80) xy, im = snapshot_brain_montage(fig, montage) @@ -84,9 +89,9 @@ # This allows us to use matplotlib to create arbitrary 2d scatterplots fig2, ax = plt.subplots(figsize=(10, 10)) ax.imshow(im) -cmap = ax.scatter(*xy_pts.T, c=beta_power, s=100, cmap='coolwarm') +cmap = ax.scatter(*xy_pts.T, c=beta_power, s=100, cmap="coolwarm") cbar = fig2.colorbar(cmap) -cbar.ax.set_ylabel('Beta Power') +cbar.ax.set_ylabel("Beta Power") ax.set_axis_off() # fig2.savefig('./brain.png', bbox_inches='tight') # For ClickableImage @@ -126,6 +131,6 @@ y = (1 - lt.pos[:, 1]) * float(im.shape[0]) # Flip the y-position fig, ax = plt.subplots() ax.imshow(im) -ax.scatter(x, y, s=80, color='r') +ax.scatter(x, y, s=80, color="r") fig.tight_layout() ax.set_axis_off() diff --git a/examples/visualization/brain.py b/examples/visualization/brain.py index 5b31bc7b106..35a7ac77bfd 100644 --- a/examples/visualization/brain.py +++ b/examples/visualization/brain.py @@ -29,8 +29,8 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -sample_dir = data_path / 'MEG' / 'sample' +subjects_dir = data_path / "subjects" +sample_dir = data_path / "MEG" / "sample" # %% # Add source information @@ -38,16 +38,21 @@ # # Plot source information. -brain_kwargs = dict(alpha=0.1, background='white', cortex='low_contrast') -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) +brain_kwargs = dict(alpha=0.1, background="white", cortex="low_contrast") +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) -stc = mne.read_source_estimate(sample_dir / 'sample_audvis-meg') +stc = mne.read_source_estimate(sample_dir / "sample_audvis-meg") stc.crop(0.09, 0.1) -kwargs = dict(fmin=stc.data.min(), fmax=stc.data.max(), alpha=0.25, - smoothing_steps='nearest', time=stc.times) -brain.add_data(stc.lh_data, hemi='lh', vertices=stc.lh_vertno, **kwargs) -brain.add_data(stc.rh_data, hemi='rh', vertices=stc.rh_vertno, **kwargs) +kwargs = dict( + fmin=stc.data.min(), + fmax=stc.data.max(), + alpha=0.25, + smoothing_steps="nearest", + time=stc.times, +) +brain.add_data(stc.lh_data, hemi="lh", vertices=stc.lh_vertno, **kwargs) +brain.add_data(stc.rh_data, hemi="rh", vertices=stc.rh_vertno, **kwargs) # %% # Modify the view of the brain @@ -55,7 +60,7 @@ # # You can adjust the view of the brain using ``show_view`` method. -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) brain.show_view(azimuth=190, elevation=70, distance=350, focalpoint=(0, 0, 20)) # %% @@ -73,8 +78,8 @@ # .. note:: The MNE sample dataset contains only a subselection of the # Freesurfer labels created during the ``recon-all``. -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) -brain.add_label('BA44', hemi='lh', color='green', borders=True) +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) +brain.add_label("BA44", hemi="lh", color="green", borders=True) brain.show_view(azimuth=190, elevation=70, distance=350, focalpoint=(0, 0, 20)) # %% @@ -83,7 +88,7 @@ # # Add a head image using the ``add_head`` method. 
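# %%
# A note on the layout-to-pixel mapping in the 3D-to-2D example above:
# layout positions are normalized to [0, 1] with the origin at the bottom
# left, whereas image arrays index rows from the top, hence the y flip.
# A self-contained sketch (image shape illustrative):
import numpy as np

pos = np.array([[0.0, 0.0], [1.0, 1.0]])  # bottom-left, top-right corners
n_rows, n_cols = 480, 640  # assumed snapshot size
x = pos[:, 0] * n_cols
y = (1 - pos[:, 1]) * n_rows  # flip so y grows downward
print(np.c_[x, y])  # -> [[0, 480], [640, 0]]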
-brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) brain.add_head(alpha=0.5) # %% @@ -93,9 +98,9 @@ # To put into context the data that generated the source time course, # the sensor positions can be displayed as well. -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) -evoked = mne.read_evokeds(sample_dir / 'sample_audvis-ave.fif')[0] -trans = mne.read_trans(sample_dir / 'sample_audvis_raw-trans.fif') +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) +evoked = mne.read_evokeds(sample_dir / "sample_audvis-ave.fif")[0] +trans = mne.read_trans(sample_dir / "sample_audvis_raw-trans.fif") brain.add_sensors(evoked.info, trans) brain.show_view(distance=500) # move back to show sensors @@ -106,9 +111,9 @@ # Dipole modeling as in :ref:`tut-dipole-orientations` can be plotted on the # brain as well. -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) -dip = mne.read_dipole(sample_dir / 'sample_audvis_set1.dip') -cmap = plt.colormaps['YlOrRd'] +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) +dip = mne.read_dipole(sample_dir / "sample_audvis_set1.dip") +cmap = plt.colormaps["YlOrRd"] colors = [cmap(gof / dip.gof.max()) for gof in dip.gof] brain.add_dipole(dip, trans, colors=colors, scales=list(dip.amplitude * 1e8)) brain.show_view(azimuth=-20, elevation=60, distance=300) @@ -123,8 +128,8 @@ fig, ax = plt.subplots() ax.imshow(img) -ax.axis('off') +ax.axis("off") cax = fig.add_axes([0.9, 0.1, 0.05, 0.8]) norm = Normalize(vmin=0, vmax=dip.gof.max()) fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), cax=cax) -fig.suptitle('Dipole Fits Scaled by Amplitude and Colored by GOF') +fig.suptitle("Dipole Fits Scaled by Amplitude and Colored by GOF") diff --git a/examples/visualization/channel_epochs_image.py b/examples/visualization/channel_epochs_image.py index bb52c11c44b..618330ec44d 100644 --- a/examples/visualization/channel_epochs_image.py +++ b/examples/visualization/channel_epochs_image.py @@ -33,9 +33,9 @@ # %% # Set parameters -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" event_id, tmin, tmax = 1, -0.2, 0.4 # Setup for reading the raw data @@ -43,12 +43,21 @@ events = mne.read_events(event_fname) # Set up pick list: EEG + MEG - bad channels (modify to your needs) -raw.info['bads'] = ['MEG 2443', 'EEG 053'] +raw.info["bads"] = ["MEG 2443", "EEG 053"] # Create epochs, here for gradiometers + EOG only for simplicity -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, - picks=('grad', 'eog'), baseline=(None, 0), preload=True, - reject=dict(grad=4000e-13, eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + proj=True, + picks=("grad", "eog"), + baseline=(None, 0), + preload=True, + reject=dict(grad=4000e-13, eog=150e-6), +) # %% # Show event-related fields images @@ -56,26 +65,36 @@ # and order with spectral reordering # If you don't have scikit-learn installed set order_func to None from sklearn.manifold import spectral_embedding # noqa -from sklearn.metrics.pairwise import rbf_kernel # noqa +from sklearn.metrics.pairwise import rbf_kernel # noqa def order_func(times, data): this_data = 
data[:, (times > 0.0) & (times < 0.350)] - this_data /= np.sqrt(np.sum(this_data ** 2, axis=1))[:, np.newaxis] - return np.argsort(spectral_embedding(rbf_kernel(this_data, gamma=1.), - n_components=1, random_state=0).ravel()) + this_data /= np.sqrt(np.sum(this_data**2, axis=1))[:, np.newaxis] + return np.argsort( + spectral_embedding( + rbf_kernel(this_data, gamma=1.0), n_components=1, random_state=0 + ).ravel() + ) good_pick = 97 # channel with a clear evoked response bad_pick = 98 # channel with no evoked response # We'll also plot a sample time onset for each trial -plt_times = np.linspace(0, .2, len(epochs)) - -plt.close('all') -mne.viz.plot_epochs_image(epochs, [good_pick, bad_pick], sigma=.5, - order=order_func, vmin=-250, vmax=250, - overlay_times=plt_times, show=True) +plt_times = np.linspace(0, 0.2, len(epochs)) + +plt.close("all") +mne.viz.plot_epochs_image( + epochs, + [good_pick, bad_pick], + sigma=0.5, + order=order_func, + vmin=-250, + vmax=250, + overlay_times=plt_times, + show=True, +) # %% # References diff --git a/examples/visualization/eeg_on_scalp.py b/examples/visualization/eeg_on_scalp.py index 7ad5438b9dc..f27bc63ecdd 100644 --- a/examples/visualization/eeg_on_scalp.py +++ b/examples/visualization/eeg_on_scalp.py @@ -19,15 +19,22 @@ print(__doc__) data_path = mne.datasets.sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -trans = mne.read_trans(meg_path / 'sample_audvis_raw-trans.fif') -raw = mne.io.read_raw_fif(meg_path / 'sample_audvis_raw.fif') +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +trans = mne.read_trans(meg_path / "sample_audvis_raw-trans.fif") +raw = mne.io.read_raw_fif(meg_path / "sample_audvis_raw.fif") # Plot electrode locations on scalp -fig = plot_alignment(raw.info, trans, subject='sample', dig=False, - eeg=['original', 'projected'], meg=[], - coord_frame='head', subjects_dir=subjects_dir) +fig = plot_alignment( + raw.info, + trans, + subject="sample", + dig=False, + eeg=["original", "projected"], + meg=[], + coord_frame="head", + subjects_dir=subjects_dir, +) # Set viewing angle set_3d_view(figure=fig, azimuth=135, elevation=80) diff --git a/examples/visualization/evoked_arrowmap.py b/examples/visualization/evoked_arrowmap.py index 7ce3f1df093..294be182c7c 100644 --- a/examples/visualization/evoked_arrowmap.py +++ b/examples/visualization/evoked_arrowmap.py @@ -33,13 +33,13 @@ print(__doc__) path = sample.data_path() -fname = path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' +fname = path / "MEG" / "sample" / "sample_audvis-ave.fif" # load evoked data -condition = 'Left Auditory' +condition = "Left Auditory" evoked = read_evokeds(fname, condition=condition, baseline=(None, 0)) -evoked_mag = evoked.copy().pick_types(meg='mag') -evoked_grad = evoked.copy().pick_types(meg='grad') +evoked_mag = evoked.copy().pick_types(meg="mag") +evoked_grad = evoked.copy().pick_types(meg="grad") # %% # Plot magnetometer data as an arrowmap along with the topoplot at the time @@ -57,8 +57,11 @@ # %% # Plot gradiometer data as an arrowmap along with the topoplot at the time # of the maximum sensor space activity: -plot_arrowmap(evoked_grad.data[:, max_time_idx], info_from=evoked_grad.info, - info_to=evoked_mag.info) +plot_arrowmap( + evoked_grad.data[:, max_time_idx], + info_from=evoked_grad.info, + info_to=evoked_mag.info, +) # %% # Since Vectorview 102 system perform sparse spatial sampling of the magnetic @@ -68,10 +71,14 @@ # Plot gradiometer data as an arrowmap along with the 
topoplot at the time # of the maximum sensor space activity: path = bst_raw.data_path() -raw_fname = (path / 'MEG' / 'bst_raw' / - 'subj001_somatosensory_20111109_01_AUX-f.ds') +raw_fname = path / "MEG" / "bst_raw" / "subj001_somatosensory_20111109_01_AUX-f.ds" raw_ctf = mne.io.read_raw_ctf(raw_fname) raw_ctf_info = mne.pick_info( - raw_ctf.info, mne.pick_types(raw_ctf.info, meg=True, ref_meg=False)) -plot_arrowmap(evoked_grad.data[:, max_time_idx], info_from=evoked_grad.info, - info_to=raw_ctf_info, scale=6e-10) + raw_ctf.info, mne.pick_types(raw_ctf.info, meg=True, ref_meg=False) +) +plot_arrowmap( + evoked_grad.data[:, max_time_idx], + info_from=evoked_grad.info, + info_to=raw_ctf_info, + scale=6e-10, +) diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index abeb527757e..20bb9611497 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -29,11 +29,11 @@ print(__doc__) path = sample.data_path() -fname = path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' +fname = path / "MEG" / "sample" / "sample_audvis-ave.fif" # load evoked corresponding to a specific condition # from the fif file and subtract baseline -condition = 'Left Auditory' +condition = "Left Auditory" evoked = read_evokeds(fname, condition=condition, baseline=(None, 0)) # %% @@ -45,28 +45,28 @@ # topographies will be shown. We select timepoints from 50 to 150 ms with a # step of 20ms and plot magnetometer data: times = np.arange(0.05, 0.151, 0.02) -evoked.plot_topomap(times, ch_type='mag') +evoked.plot_topomap(times, ch_type="mag") # %% # If times is set to None at most 10 regularly spaced topographies will be # shown: -evoked.plot_topomap(ch_type='mag') +evoked.plot_topomap(ch_type="mag") # %% # We can use ``nrows`` and ``ncols`` parameter to create multiline plots # with more timepoints. all_times = np.arange(-0.2, 0.5, 0.03) -evoked.plot_topomap(all_times, ch_type='mag', ncols=8, nrows='auto') +evoked.plot_topomap(all_times, ch_type="mag", ncols=8, nrows="auto") # %% # Instead of showing topographies at specific time points we can compute # averages of 50 ms bins centered on these time points to reduce the noise in # the topographies: -evoked.plot_topomap(times, ch_type='mag', average=0.05) +evoked.plot_topomap(times, ch_type="mag", average=0.05) # %% # We can plot gradiometer data (plots the RMS for each pair of gradiometers) -evoked.plot_topomap(times, ch_type='grad') +evoked.plot_topomap(times, ch_type="grad") # %% # Additional :func:`~mne.viz.plot_topomap` options @@ -79,8 +79,7 @@ # * ``res`` - to control the resolution of the topographies (lower resolution # means faster plotting) # * ``contours`` to define how many contour lines should be plotted -evoked.plot_topomap(times, ch_type='mag', cmap='Spectral_r', res=32, - contours=4) +evoked.plot_topomap(times, ch_type="mag", cmap="Spectral_r", res=32, contours=4) # %% # If you look at the edges of the head circle of a single topomap you'll see @@ -94,17 +93,24 @@ # The default value ``extrapolate='auto'`` will use ``'local'`` for MEG sensors # and ``'head'`` otherwise. 
Here we show each option: -extrapolations = ['local', 'head', 'box'] +extrapolations = ["local", "head", "box"] fig, axes = plt.subplots(figsize=(7.5, 4.5), nrows=2, ncols=3) # Here we look at EEG channels, and use a custom head sphere to get all the # sensors to be well within the drawn head surface -for axes_row, ch_type in zip(axes, ('mag', 'eeg')): +for axes_row, ch_type in zip(axes, ("mag", "eeg")): for ax, extr in zip(axes_row, extrapolations): - evoked.plot_topomap(0.1, ch_type=ch_type, size=2, extrapolate=extr, - axes=ax, show=False, colorbar=False, - sphere=(0., 0., 0., 0.09)) - ax.set_title('%s %s' % (ch_type.upper(), extr), fontsize=14) + evoked.plot_topomap( + 0.1, + ch_type=ch_type, + size=2, + extrapolate=extr, + axes=ax, + show=False, + colorbar=False, + sphere=(0.0, 0.0, 0.0, 0.09), + ) + ax.set_title("%s %s" % (ch_type.upper(), extr), fontsize=14) fig.tight_layout() # %% @@ -114,10 +120,11 @@ # Now we plot magnetometer data as topomap at a single time point: 100 ms # post-stimulus, add channel labels, title and adjust plot margins: -fig = evoked.plot_topomap(0.1, ch_type='mag', show_names=True, colorbar=False, - size=6, res=128) +fig = evoked.plot_topomap( + 0.1, ch_type="mag", show_names=True, colorbar=False, size=6, res=128 +) fig.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.88) -fig.suptitle('Auditory response') +fig.suptitle("Auditory response") # %% # We can also highlight specific channels by adding a mask, to e.g. mark @@ -128,8 +135,8 @@ # Select times and plot times = (0.09, 0.1, 0.11) -mask_params = dict(markersize=10, markerfacecolor='y') -evoked.plot_topomap(times, ch_type='mag', mask=mask, mask_params=mask_params) +mask_params = dict(markersize=10, markerfacecolor="y") +evoked.plot_topomap(times, ch_type="mag", mask=mask, mask_params=mask_params) # %% # Or by manually picking the channels to highlight at different times: @@ -137,16 +144,17 @@ times = (0.09, 0.1, 0.11) _times = ((np.abs(evoked.times - t)).argmin() for t in times) significant_channels = [ - ('MEG 0231', 'MEG 1611', 'MEG 1621', 'MEG 1631', 'MEG 1811'), - ('MEG 2411', 'MEG 2421'), - ('MEG 1621')] + ("MEG 0231", "MEG 1611", "MEG 1621", "MEG 1631", "MEG 1811"), + ("MEG 2411", "MEG 2421"), + ("MEG 1621"), +] _channels = [np.in1d(evoked.ch_names, ch) for ch in significant_channels] -mask = np.zeros(evoked.data.shape, dtype='bool') +mask = np.zeros(evoked.data.shape, dtype="bool") for _chs, _time in zip(_channels, _times): mask[_chs, _time] = True -evoked.plot_topomap(times, ch_type='mag', mask=mask, mask_params=mask_params) +evoked.plot_topomap(times, ch_type="mag", mask=mask, mask_params=mask_params) # %% # Interpolating topomaps @@ -162,18 +170,18 @@ # The default cubic interpolation is the smoothest and is great for # publications. -evoked.plot_topomap(times, ch_type='eeg', image_interp='cubic') +evoked.plot_topomap(times, ch_type="eeg", image_interp="cubic") # %% # The linear interpolation might be helpful in some cases. -evoked.plot_topomap(times, ch_type='eeg', image_interp='linear') +evoked.plot_topomap(times, ch_type="eeg", image_interp="linear") # %% # The nearest (Voronoi, no interpolation) interpolation is especially helpful # for debugging and seeing the values assigned to the topomap unaltered. 
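# %%
# The three ``image_interp`` modes behave much like 2D scattered-data
# interpolation with the matching ``scipy.interpolate.griddata`` methods
# (an analogy only -- the topomap code has its own interpolator):
import numpy as np
from scipy.interpolate import griddata

rng = np.random.default_rng(0)
pts = rng.uniform(-1, 1, (32, 2))  # stand-in 2D sensor positions
vals = np.sin(pts[:, 0]) * np.cos(pts[:, 1])  # stand-in field values
grid = np.mgrid[-1:1:64j, -1:1:64j].reshape(2, -1).T
for method in ("cubic", "linear", "nearest"):
    img = griddata(pts, vals, grid, method=method)  # NaN outside the hull
    print(method, float(np.nanmin(img)), float(np.nanmax(img)))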
-evoked.plot_topomap(times, ch_type='eeg', image_interp='nearest', contours=0) +evoked.plot_topomap(times, ch_type="eeg", image_interp="nearest", contours=0) # %% # Animating the topomap @@ -184,5 +192,4 @@ # sphinx_gallery_thumbnail_number = 9 times = np.arange(0.05, 0.151, 0.01) -fig, anim = evoked.animate_topomap( - times=times, ch_type='mag', frame_rate=2, blit=False) +fig, anim = evoked.animate_topomap(times=times, ch_type="mag", frame_rate=2, blit=False) diff --git a/examples/visualization/evoked_whitening.py b/examples/visualization/evoked_whitening.py index 7a5f7552cc1..1d1575a83b6 100644 --- a/examples/visualization/evoked_whitening.py +++ b/examples/visualization/evoked_whitening.py @@ -35,21 +35,30 @@ # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(1, 40, fir_design='firwin') -raw.info['bads'] += ['MEG 2443'] # bads + 1 more +raw.filter(1, 40, fir_design="firwin") +raw.info["bads"] += ["MEG 2443"] # bads + 1 more events = mne.read_events(event_fname) # let's look at rare events, button presses event_id, tmin, tmax = 2, -0.2, 0.5 reject = dict(mag=4e-12, grad=4000e-13, eeg=80e-6) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=('meg', 'eeg'), - baseline=None, reject=reject, preload=True) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=("meg", "eeg"), + baseline=None, + reject=reject, + preload=True, +) # Uncomment next line to use fewer samples and study regularization effects # epochs = epochs[:20] # For your data, use as many samples as you can! @@ -57,24 +66,32 @@ # %% # Compute covariance using automated regularization method_params = dict(diagonal_fixed=dict(mag=0.01, grad=0.01, eeg=0.01)) -noise_covs = compute_covariance(epochs, tmin=None, tmax=0, method='auto', - return_estimators=True, n_jobs=None, - projs=None, rank=None, - method_params=method_params, verbose=True) +noise_covs = compute_covariance( + epochs, + tmin=None, + tmax=0, + method="auto", + return_estimators=True, + n_jobs=None, + projs=None, + rank=None, + method_params=method_params, + verbose=True, +) # With "return_estimator=True" all estimated covariances sorted # by log-likelihood are returned. -print('Covariance estimates sorted from best to worst') +print("Covariance estimates sorted from best to worst") for c in noise_covs: - print("%s : %s" % (c['method'], c['loglik'])) + print("%s : %s" % (c["method"], c["loglik"])) # %% # Show the evoked data: evoked = epochs.average() -evoked.plot(time_unit='s') # plot evoked response +evoked.plot(time_unit="s") # plot evoked response # %% # We can then show whitening for our various noise covariance estimates. @@ -85,4 +102,4 @@ # # For the Global field power we expect a value of 1. 
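# Whitening rescales each channel by the estimated noise covariance so that
# pure noise has unit variance; a whitened GFP near 1 outside the evoked
# response is therefore a sign of a well-estimated covariance.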
-evoked.plot_white(noise_covs, time_unit='s') +evoked.plot_white(noise_covs, time_unit="s") diff --git a/examples/visualization/meg_sensors.py b/examples/visualization/meg_sensors.py index 9d5ccd6411c..3685ee68543 100644 --- a/examples/visualization/meg_sensors.py +++ b/examples/visualization/meg_sensors.py @@ -17,8 +17,13 @@ import mne from mne.datasets import sample, spm_face, testing -from mne.io import (read_raw_artemis123, read_raw_bti, read_raw_ctf, - read_raw_fif, read_raw_kit) +from mne.io import ( + read_raw_artemis123, + read_raw_bti, + read_raw_ctf, + read_raw_fif, + read_raw_kit, +) from mne.viz import plot_alignment, set_3d_title print(__doc__) @@ -29,48 +34,52 @@ # Neuromag # -------- -kwargs = dict(eeg=False, coord_frame='meg', show_axes=True, verbose=True) +kwargs = dict(eeg=False, coord_frame="meg", show_axes=True, verbose=True) -raw = read_raw_fif( - sample.data_path() / 'MEG' / 'sample' / 'sample_audvis_raw.fif') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors'), **kwargs) -set_3d_title(figure=fig, title='Neuromag') +raw = read_raw_fif(sample.data_path() / "MEG" / "sample" / "sample_audvis_raw.fif") +fig = plot_alignment(raw.info, meg=("helmet", "sensors"), **kwargs) +set_3d_title(figure=fig, title="Neuromag") # %% # CTF # --- raw = read_raw_ctf( - spm_face.data_path() / 'MEG' / 'spm' / 'SPM_CTF_MEG_example_faces1_3D.ds') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors', 'ref'), **kwargs) -set_3d_title(figure=fig, title='CTF 275') + spm_face.data_path() / "MEG" / "spm" / "SPM_CTF_MEG_example_faces1_3D.ds" +) +fig = plot_alignment(raw.info, meg=("helmet", "sensors", "ref"), **kwargs) +set_3d_title(figure=fig, title="CTF 275") # %% # BTi # --- -bti_path = root_path / 'io' / 'bti' / 'tests' / 'data' -raw = read_raw_bti(bti_path / 'test_pdf_linux', - bti_path / 'test_config_linux', - bti_path / 'test_hs_linux') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors', 'ref'), **kwargs) -set_3d_title(figure=fig, title='Magnes 3600wh') +bti_path = root_path / "io" / "bti" / "tests" / "data" +raw = read_raw_bti( + bti_path / "test_pdf_linux", + bti_path / "test_config_linux", + bti_path / "test_hs_linux", +) +fig = plot_alignment(raw.info, meg=("helmet", "sensors", "ref"), **kwargs) +set_3d_title(figure=fig, title="Magnes 3600wh") # %% # KIT # --- -kit_path = root_path / 'io' / 'kit' / 'tests' / 'data' -raw = read_raw_kit(kit_path / 'test.sqd') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors'), **kwargs) -set_3d_title(figure=fig, title='KIT') +kit_path = root_path / "io" / "kit" / "tests" / "data" +raw = read_raw_kit(kit_path / "test.sqd") +fig = plot_alignment(raw.info, meg=("helmet", "sensors"), **kwargs) +set_3d_title(figure=fig, title="KIT") # %% # Artemis123 # ---------- raw = read_raw_artemis123( - testing.data_path() / 'ARTEMIS123' / - 'Artemis_Data_2017-04-14-10h-38m-59s_Phantom_1k_HPI_1s.bin') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors', 'ref'), **kwargs) -set_3d_title(figure=fig, title='Artemis123') + testing.data_path() + / "ARTEMIS123" + / "Artemis_Data_2017-04-14-10h-38m-59s_Phantom_1k_HPI_1s.bin" +) +fig = plot_alignment(raw.info, meg=("helmet", "sensors", "ref"), **kwargs) +set_3d_title(figure=fig, title="Artemis123") diff --git a/examples/visualization/mne_helmet.py b/examples/visualization/mne_helmet.py index c6c155bcfd3..1085cbfc044 100644 --- a/examples/visualization/mne_helmet.py +++ b/examples/visualization/mne_helmet.py @@ -13,23 +13,44 @@ import mne sample_path = mne.datasets.sample.data_path() -subjects_dir = 
sample_path / 'subjects' -fname_evoked = sample_path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' -fname_inv = (sample_path / 'MEG' / 'sample' / - 'sample_audvis-meg-oct-6-meg-inv.fif') -fname_trans = sample_path / 'MEG' / 'sample' / 'sample_audvis_raw-trans.fif' +subjects_dir = sample_path / "subjects" +fname_evoked = sample_path / "MEG" / "sample" / "sample_audvis-ave.fif" +fname_inv = sample_path / "MEG" / "sample" / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_trans = sample_path / "MEG" / "sample" / "sample_audvis_raw-trans.fif" inv = mne.minimum_norm.read_inverse_operator(fname_inv) -evoked = mne.read_evokeds(fname_evoked, baseline=(None, 0), - proj=True, verbose=False, condition='Left Auditory') -maps = mne.make_field_map(evoked, trans=fname_trans, ch_type='meg', - subject='sample', subjects_dir=subjects_dir) +evoked = mne.read_evokeds( + fname_evoked, + baseline=(None, 0), + proj=True, + verbose=False, + condition="Left Auditory", +) +maps = mne.make_field_map( + evoked, + trans=fname_trans, + ch_type="meg", + subject="sample", + subjects_dir=subjects_dir, +) time = 0.083 fig = mne.viz.create_3d_figure((256, 256)) mne.viz.plot_alignment( - evoked.info, subject='sample', subjects_dir=subjects_dir, fig=fig, - trans=fname_trans, meg='sensors', eeg=False, surfaces='pial', - coord_frame='mri') + evoked.info, + subject="sample", + subjects_dir=subjects_dir, + fig=fig, + trans=fname_trans, + meg="sensors", + eeg=False, + surfaces="pial", + coord_frame="mri", +) evoked.plot_field(maps, time=time, fig=fig, time_label=None, vmax=5e-13) mne.viz.set_3d_view( - fig, azimuth=40, elevation=87, focalpoint=(0., -0.01, 0.04), roll=-25, - distance=0.55) + fig, + azimuth=40, + elevation=87, + focalpoint=(0.0, -0.01, 0.04), + roll=-25, + distance=0.55, +) diff --git a/examples/visualization/montage_sgskip.py b/examples/visualization/montage_sgskip.py index 96ab574499e..521e4e87a16 100644 --- a/examples/visualization/montage_sgskip.py +++ b/examples/visualization/montage_sgskip.py @@ -28,15 +28,18 @@ for current_montage in get_builtin_montages(): montage = mne.channels.make_standard_montage(current_montage) - info = mne.create_info( - ch_names=montage.ch_names, sfreq=100., ch_types='eeg') + info = mne.create_info(ch_names=montage.ch_names, sfreq=100.0, ch_types="eeg") info.set_montage(montage) - sphere = mne.make_sphere_model(r0='auto', head_radius='auto', info=info) + sphere = mne.make_sphere_model(r0="auto", head_radius="auto", info=info) fig = mne.viz.plot_alignment( # Plot options - show_axes=True, dig='fiducials', surfaces='head', + show_axes=True, + dig="fiducials", + surfaces="head", trans=mne.Transform("head", "mri", trans=np.eye(4)), # identity - bem=sphere, info=info) + bem=sphere, + info=info, + ) set_3d_view(figure=fig, azimuth=135, elevation=80) set_3d_title(figure=fig, title=current_montage) @@ -49,15 +52,19 @@ for current_montage in get_builtin_montages(): montage = mne.channels.make_standard_montage(current_montage) # Create dummy info - info = mne.create_info( - ch_names=montage.ch_names, sfreq=100., ch_types='eeg') + info = mne.create_info(ch_names=montage.ch_names, sfreq=100.0, ch_types="eeg") info.set_montage(montage) fig = mne.viz.plot_alignment( # Plot options - show_axes=True, dig='fiducials', surfaces='head', mri_fiducials=True, - subject='fsaverage', subjects_dir=subjects_dir, info=info, - coord_frame='mri', - trans='fsaverage', # transform from head coords to fsaverage's MRI + show_axes=True, + dig="fiducials", + surfaces="head", + mri_fiducials=True, + subject="fsaverage", + 
subjects_dir=subjects_dir, + info=info, + coord_frame="mri", + trans="fsaverage", # transform from head coords to fsaverage's MRI ) set_3d_view(figure=fig, azimuth=135, elevation=80) set_3d_title(figure=fig, title=current_montage) diff --git a/examples/visualization/parcellation.py b/examples/visualization/parcellation.py index 7118a2594b5..9e416c97c48 100644 --- a/examples/visualization/parcellation.py +++ b/examples/visualization/parcellation.py @@ -26,37 +26,58 @@ # %% import mne + Brain = mne.viz.get_brain_class() -subjects_dir = mne.datasets.sample.data_path() / 'subjects' -mne.datasets.fetch_hcp_mmp_parcellation(subjects_dir=subjects_dir, - verbose=True) +subjects_dir = mne.datasets.sample.data_path() / "subjects" +mne.datasets.fetch_hcp_mmp_parcellation(subjects_dir=subjects_dir, verbose=True) -mne.datasets.fetch_aparc_sub_parcellation(subjects_dir=subjects_dir, - verbose=True) +mne.datasets.fetch_aparc_sub_parcellation(subjects_dir=subjects_dir, verbose=True) labels = mne.read_labels_from_annot( - 'fsaverage', 'HCPMMP1', 'lh', subjects_dir=subjects_dir) - -brain = Brain('fsaverage', 'lh', 'inflated', subjects_dir=subjects_dir, - cortex='low_contrast', background='white', size=(800, 600)) -brain.add_annotation('HCPMMP1') -aud_label = [label for label in labels if label.name == 'L_A1_ROI-lh'][0] + "fsaverage", "HCPMMP1", "lh", subjects_dir=subjects_dir +) + +brain = Brain( + "fsaverage", + "lh", + "inflated", + subjects_dir=subjects_dir, + cortex="low_contrast", + background="white", + size=(800, 600), +) +brain.add_annotation("HCPMMP1") +aud_label = [label for label in labels if label.name == "L_A1_ROI-lh"][0] brain.add_label(aud_label, borders=False) # %% # We can also plot a combined set of labels (23 per hemisphere). -brain = Brain('fsaverage', 'lh', 'inflated', subjects_dir=subjects_dir, - cortex='low_contrast', background='white', size=(800, 600)) -brain.add_annotation('HCPMMP1_combined') +brain = Brain( + "fsaverage", + "lh", + "inflated", + subjects_dir=subjects_dir, + cortex="low_contrast", + background="white", + size=(800, 600), +) +brain.add_annotation("HCPMMP1_combined") # %% # We can add another custom parcellation -brain = Brain('fsaverage', 'lh', 'inflated', subjects_dir=subjects_dir, - cortex='low_contrast', background='white', size=(800, 600)) -brain.add_annotation('aparc_sub') +brain = Brain( + "fsaverage", + "lh", + "inflated", + subjects_dir=subjects_dir, + cortex="low_contrast", + background="white", + size=(800, 600), +) +brain.add_annotation("aparc_sub") # %% # References diff --git a/examples/visualization/publication_figure.py b/examples/visualization/publication_figure.py index f86cc44075d..f753c72a2c8 100644 --- a/examples/visualization/publication_figure.py +++ b/examples/visualization/publication_figure.py @@ -22,8 +22,7 @@ import numpy as np import matplotlib.pyplot as plt -from mpl_toolkits.axes_grid1 import (make_axes_locatable, ImageGrid, - inset_locator) +from mpl_toolkits.axes_grid1 import make_axes_locatable, ImageGrid, inset_locator import mne @@ -36,12 +35,12 @@ # start by loading some :ref:`example data `. 
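# (``data_path()`` downloads the dataset on first use and returns its local
# path, so the call below can take a while the first time it runs.)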
data_path = mne.datasets.sample.data_path() -subjects_dir = data_path / 'subjects' -fname_stc = data_path / 'MEG' / 'sample' / 'sample_audvis-meg-eeg-lh.stc' -fname_evoked = data_path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects" +fname_stc = data_path / "MEG" / "sample" / "sample_audvis-meg-eeg-lh.stc" +fname_evoked = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" -evoked = mne.read_evokeds(fname_evoked, 'Left Auditory') -evoked.pick_types(meg='grad').apply_baseline((None, 0.)) +evoked = mne.read_evokeds(fname_evoked, "Left Auditory") +evoked.pick_types(meg="grad").apply_baseline((None, 0.0)) max_t = evoked.get_peak()[1] stc = mne.read_source_estimate(fname_stc) @@ -51,9 +50,16 @@ evoked.plot() -stc.plot(views='lat', hemi='split', size=(800, 400), subject='sample', - subjects_dir=subjects_dir, initial_time=max_t, - time_viewer=False, show_traces=False) +stc.plot( + views="lat", + hemi="split", + size=(800, 400), + subject="sample", + subjects_dir=subjects_dir, + initial_time=max_t, + time_viewer=False, + show_traces=False, +) # %% # To make a publication-ready figure, first we'll re-plot the brain on a white @@ -61,14 +67,24 @@ # While we're at it, let's change the colormap, set custom colormap limits and # remove the default colorbar (so we can add a smaller, vertical one later): -colormap = 'viridis' -clim = dict(kind='value', lims=[4, 8, 12]) +colormap = "viridis" +clim = dict(kind="value", lims=[4, 8, 12]) # Plot the STC, get the brain image, crop it: -brain = stc.plot(views='lat', hemi='split', size=(800, 400), subject='sample', - subjects_dir=subjects_dir, initial_time=max_t, background='w', - colorbar=False, clim=clim, colormap=colormap, - time_viewer=False, show_traces=False) +brain = stc.plot( + views="lat", + hemi="split", + size=(800, 400), + subject="sample", + subjects_dir=subjects_dir, + initial_time=max_t, + background="w", + colorbar=False, + clim=clim, + colormap=colormap, + time_viewer=False, + show_traces=False, +) screenshot = brain.screenshot() brain.close() @@ -87,10 +103,11 @@ # before/after results fig = plt.figure(figsize=(4, 4)) axes = ImageGrid(fig, 111, nrows_ncols=(2, 1), axes_pad=0.5) -for ax, image, title in zip(axes, [screenshot, cropped_screenshot], - ['Before', 'After']): +for ax, image, title in zip( + axes, [screenshot, cropped_screenshot], ["Before", "After"] +): ax.imshow(image) - ax.set_title('{} cropping'.format(title)) + ax.set_title("{} cropping".format(title)) # %% # A lot of figure settings can be adjusted after the figure is created, but @@ -99,14 +116,16 @@ # script generates several figures that you want to all have the same style: # Tweak the figure style -plt.rcParams.update({ - 'ytick.labelsize': 'small', - 'xtick.labelsize': 'small', - 'axes.labelsize': 'small', - 'axes.titlesize': 'medium', - 'grid.color': '0.75', - 'grid.linestyle': ':', -}) +plt.rcParams.update( + { + "ytick.labelsize": "small", + "xtick.labelsize": "small", + "axes.labelsize": "small", + "axes.titlesize": "medium", + "grid.color": "0.75", + "grid.linestyle": ":", + } +) # %% # Now let's create our custom figure. There are lots of ways to do this step. 
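# (Below we use ``plt.subplots`` with a ``gridspec_kw`` height ratio; an
# alternative using ``subplot2grid`` is kept as commented-out code.)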
@@ -119,8 +138,9 @@ # sphinx_gallery_thumbnail_number = 4 # figsize unit is inches -fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(4.5, 3.), - gridspec_kw=dict(height_ratios=[3, 4])) +fig, axes = plt.subplots( + nrows=2, ncols=1, figsize=(4.5, 3.0), gridspec_kw=dict(height_ratios=[3, 4]) +) # alternate way #1: using subplot2grid # fig = plt.figure(figsize=(4.5, 3.)) @@ -138,42 +158,55 @@ # plot the evoked in the desired subplot, and add a line at peak activation evoked.plot(axes=axes[evoked_idx]) -peak_line = axes[evoked_idx].axvline(max_t, color='#66CCEE', ls='--') +peak_line = axes[evoked_idx].axvline(max_t, color="#66CCEE", ls="--") # custom legend axes[evoked_idx].legend( - [axes[evoked_idx].lines[0], peak_line], ['MEG data', 'Peak time'], - frameon=True, columnspacing=0.1, labelspacing=0.1, - fontsize=8, fancybox=True, handlelength=1.8) + [axes[evoked_idx].lines[0], peak_line], + ["MEG data", "Peak time"], + frameon=True, + columnspacing=0.1, + labelspacing=0.1, + fontsize=8, + fancybox=True, + handlelength=1.8, +) # remove the "N_ave" annotation for text in list(axes[evoked_idx].texts): text.remove() # Remove spines and add grid axes[evoked_idx].grid(True) axes[evoked_idx].set_axisbelow(True) -for key in ('top', 'right'): +for key in ("top", "right"): axes[evoked_idx].spines[key].set(visible=False) # Tweak the ticks and limits axes[evoked_idx].set( - yticks=np.arange(-200, 201, 100), xticks=np.arange(-0.2, 0.51, 0.1)) -axes[evoked_idx].set( - ylim=[-225, 225], xlim=[-0.2, 0.5]) + yticks=np.arange(-200, 201, 100), xticks=np.arange(-0.2, 0.51, 0.1) +) +axes[evoked_idx].set(ylim=[-225, 225], xlim=[-0.2, 0.5]) # now add the brain to the lower axes axes[brain_idx].imshow(cropped_screenshot) -axes[brain_idx].axis('off') +axes[brain_idx].axis("off") # add a vertical colorbar with the same properties as the 3D one divider = make_axes_locatable(axes[brain_idx]) -cax = divider.append_axes('right', size='5%', pad=0.2) -cbar = mne.viz.plot_brain_colorbar(cax, clim, colormap, label='Activation (F)') +cax = divider.append_axes("right", size="5%", pad=0.2) +cbar = mne.viz.plot_brain_colorbar(cax, clim, colormap, label="Activation (F)") # tweak margins and spacing -fig.subplots_adjust( - left=0.15, right=0.9, bottom=0.01, top=0.9, wspace=0.1, hspace=0.5) +fig.subplots_adjust(left=0.15, right=0.9, bottom=0.01, top=0.9, wspace=0.1, hspace=0.5) # add subplot labels -for ax, label in zip(axes, 'AB'): - ax.text(0.03, ax.get_position().ymax, label, transform=fig.transFigure, - fontsize=12, fontweight='bold', va='top', ha='left') +for ax, label in zip(axes, "AB"): + ax.text( + 0.03, + ax.get_position().ymax, + label, + transform=fig.transFigure, + fontsize=12, + fontweight="bold", + va="top", + ha="left", + ) # %% # Custom timecourse with montage inset @@ -206,10 +239,9 @@ to_plot = [f"EEG {i:03}" for i in range(1, 5)] # get the data for plotting in a short time interval from 10 to 20 seconds -start = int(raw.info['sfreq'] * 10) -stop = int(raw.info['sfreq'] * 20) -data, times = raw.get_data(picks=to_plot, - start=start, stop=stop, return_times=True) +start = int(raw.info["sfreq"] * 10) +stop = int(raw.info["sfreq"] * 20) +data, times = raw.get_data(picks=to_plot, start=start, stop=stop, return_times=True) # Scale the data from the MNE internal unit V to µV data *= 1e6 diff --git a/examples/visualization/roi_erpimage_by_rt.py b/examples/visualization/roi_erpimage_by_rt.py index e803b3cb14b..a8d2bae8d58 100644 --- a/examples/visualization/roi_erpimage_by_rt.py +++ 
b/examples/visualization/roi_erpimage_by_rt.py @@ -31,24 +31,48 @@ # %% # Load EEGLAB example data (a small EEG dataset) data_path = mne.datasets.testing.data_path() -fname = data_path / 'EEGLAB' / 'test_raw.set' +fname = data_path / "EEGLAB" / "test_raw.set" event_id = {"rt": 1, "square": 2} # must be specified for str events raw = mne.io.read_raw_eeglab(fname) mapping = { - 'EEG 000': 'Fpz', 'EEG 001': 'EOG1', 'EEG 002': 'F3', 'EEG 003': 'Fz', - 'EEG 004': 'F4', 'EEG 005': 'EOG2', 'EEG 006': 'FC5', 'EEG 007': 'FC1', - 'EEG 008': 'FC2', 'EEG 009': 'FC6', 'EEG 010': 'T7', 'EEG 011': 'C3', - 'EEG 012': 'C4', 'EEG 013': 'Cz', 'EEG 014': 'T8', 'EEG 015': 'CP5', - 'EEG 016': 'CP1', 'EEG 017': 'CP2', 'EEG 018': 'CP6', 'EEG 019': 'P7', - 'EEG 020': 'P3', 'EEG 021': 'Pz', 'EEG 022': 'P4', 'EEG 023': 'P8', - 'EEG 024': 'PO7', 'EEG 025': 'PO3', 'EEG 026': 'POz', 'EEG 027': 'PO4', - 'EEG 028': 'PO8', 'EEG 029': 'O1', 'EEG 030': 'Oz', 'EEG 031': 'O2' + "EEG 000": "Fpz", + "EEG 001": "EOG1", + "EEG 002": "F3", + "EEG 003": "Fz", + "EEG 004": "F4", + "EEG 005": "EOG2", + "EEG 006": "FC5", + "EEG 007": "FC1", + "EEG 008": "FC2", + "EEG 009": "FC6", + "EEG 010": "T7", + "EEG 011": "C3", + "EEG 012": "C4", + "EEG 013": "Cz", + "EEG 014": "T8", + "EEG 015": "CP5", + "EEG 016": "CP1", + "EEG 017": "CP2", + "EEG 018": "CP6", + "EEG 019": "P7", + "EEG 020": "P3", + "EEG 021": "Pz", + "EEG 022": "P4", + "EEG 023": "P8", + "EEG 024": "PO7", + "EEG 025": "PO3", + "EEG 026": "POz", + "EEG 027": "PO4", + "EEG 028": "PO8", + "EEG 029": "O1", + "EEG 030": "Oz", + "EEG 031": "O2", } raw.rename_channels(mapping) -raw.set_channel_types({"EOG1": 'eog', "EOG2": 'eog'}) -raw.set_montage('standard_1020') +raw.set_channel_types({"EOG1": "eog", "EOG2": "eog"}) +raw.set_montage("standard_1020") events = mne.events_from_annotations(raw, event_id)[0] @@ -61,11 +85,11 @@ tmax = 0.7 sfreq = raw.info["sfreq"] reference_id, target_id = 2, 1 -new_events, rts = define_target_events(events, reference_id, target_id, sfreq, - tmin=0., tmax=tmax, new_id=2) +new_events, rts = define_target_events( + events, reference_id, target_id, sfreq, tmin=0.0, tmax=tmax, new_id=2 +) -epochs = mne.Epochs(raw, events=new_events, tmax=tmax + 0.1, - event_id={"square": 2}) +epochs = mne.Epochs(raw, events=new_events, tmax=tmax + 0.1, event_id={"square": 2}) # %% # Plot using :term:`global field power` @@ -76,13 +100,23 @@ selections = make_1020_channel_selections(epochs.info, midline="12z") # The actual plots (GFP) -epochs.plot_image(group_by=selections, order=order, sigma=1.5, - overlay_times=rts / 1000., combine='gfp', - ts_args=dict(vlines=[0, rts.mean() / 1000.])) +epochs.plot_image( + group_by=selections, + order=order, + sigma=1.5, + overlay_times=rts / 1000.0, + combine="gfp", + ts_args=dict(vlines=[0, rts.mean() / 1000.0]), +) # %% # Plot using median -epochs.plot_image(group_by=selections, order=order, sigma=1.5, - overlay_times=rts / 1000., combine='median', - ts_args=dict(vlines=[0, rts.mean() / 1000.])) +epochs.plot_image( + group_by=selections, + order=order, + sigma=1.5, + overlay_times=rts / 1000.0, + combine="median", + ts_args=dict(vlines=[0, rts.mean() / 1000.0]), +) diff --git a/examples/visualization/sensor_noise_level.py b/examples/visualization/sensor_noise_level.py index 55b220ba1c0..ca5c70d0233 100644 --- a/examples/visualization/sensor_noise_level.py +++ b/examples/visualization/sensor_noise_level.py @@ -19,13 +19,14 @@ data_path = mne.datasets.sample.data_path() raw_erm = mne.io.read_raw_fif( - data_path / 'MEG' / 'sample' / 
'ernoise_raw.fif', preload=True + data_path / "MEG" / "sample" / "ernoise_raw.fif", preload=True ) # %% # We can plot the absolute noise levels: -raw_erm.compute_psd(tmax=10).plot(average=True, spatial_colors=False, - dB=False, xscale='log') +raw_erm.compute_psd(tmax=10).plot( + average=True, spatial_colors=False, dB=False, xscale="log" +) # %% # References # ---------- diff --git a/examples/visualization/ssp_projs_sensitivity_map.py b/examples/visualization/ssp_projs_sensitivity_map.py index 2c8259d7a24..d51c498e423 100644 --- a/examples/visualization/ssp_projs_sensitivity_map.py +++ b/examples/visualization/ssp_projs_sensitivity_map.py @@ -24,10 +24,10 @@ data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ecg_fname = meg_path / 'sample_audvis_ecg-proj.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +ecg_fname = meg_path / "sample_audvis_ecg-proj.fif" fwd = read_forward_solution(fname) @@ -36,7 +36,7 @@ projs = projs[::2] # Compute sensitivity map -ssp_ecg_map = sensitivity_map(fwd, ch_type='grad', projs=projs, mode='angle') +ssp_ecg_map = sensitivity_map(fwd, ch_type="grad", projs=projs, mode="angle") # %% # Show sensitivity map @@ -44,6 +44,10 @@ plt.hist(ssp_ecg_map.data.ravel()) plt.show() -args = dict(clim=dict(kind='value', lims=(0.2, 0.6, 1.)), smoothing_steps=7, - hemi='rh', subjects_dir=subjects_dir) -ssp_ecg_map.plot(subject='sample', time_label='ECG SSP sensitivity', **args) +args = dict( + clim=dict(kind="value", lims=(0.2, 0.6, 1.0)), + smoothing_steps=7, + hemi="rh", + subjects_dir=subjects_dir, +) +ssp_ecg_map.plot(subject="sample", time_label="ECG SSP sensitivity", **args) diff --git a/examples/visualization/topo_compare_conditions.py b/examples/visualization/topo_compare_conditions.py index 6687ba37576..742565fc1fd 100644 --- a/examples/visualization/topo_compare_conditions.py +++ b/examples/visualization/topo_compare_conditions.py @@ -31,9 +31,9 @@ # %% # Set parameters -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin = -0.2 tmax = 0.5 @@ -45,20 +45,20 @@ reject = dict(grad=4000e-13, mag=4e-12) # Create epochs including different events -event_id = {'audio/left': 1, 'audio/right': 2, - 'visual/left': 3, 'visual/right': 4} -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, - picks='meg', baseline=(None, 0), reject=reject) +event_id = {"audio/left": 1, "audio/right": 2, "visual/left": 3, "visual/right": 4} +epochs = mne.Epochs( + raw, events, event_id, tmin, tmax, picks="meg", baseline=(None, 0), reject=reject +) # Generate list of evoked objects from conditions names -evokeds = [epochs[name].average() for name in ('left', 'right')] +evokeds = [epochs[name].average() for name in ("left", "right")] # %% # Show topography for two different conditions -colors = 'blue', 'red' -title = 'MNE sample data\nleft vs right (A/V combined)' +colors = "blue", "red" +title = "MNE sample data\nleft vs right (A/V combined)" -plot_evoked_topo(evokeds, color=colors, title=title, background_color='w') +plot_evoked_topo(evokeds, color=colors, title=title, background_color="w") plt.show() diff --git 
a/examples/visualization/topo_customized.py b/examples/visualization/topo_customized.py index e9106a1e8d2..02c0435b25f 100644 --- a/examples/visualization/topo_customized.py +++ b/examples/visualization/topo_customized.py @@ -30,18 +30,17 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(1, 20, fir_design='firwin') +raw.filter(1, 20, fir_design="firwin") picks = mne.pick_types(raw.info, meg=True, exclude=[]) tmin, tmax = 0, 120 # use the first 120s of data fmin, fmax = 2, 20 # look at frequencies between 2 and 20Hz n_fft = 2048 # the FFT size (n_fft). Ideally a power of 2 -spectrum = raw.compute_psd( - picks=picks, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax) +spectrum = raw.compute_psd(picks=picks, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax) psds, freqs = spectrum.get_data(exclude=(), return_freqs=True) psds = 20 * np.log10(psds) # scale to dB @@ -53,17 +52,19 @@ def my_callback(ax, ch_idx): in the plot. To work with the viz internals, this function should only take two parameters, the axis and the channel or data index. """ - ax.plot(freqs, psds[ch_idx], color='red') - ax.set_xlabel('Frequency (Hz)') - ax.set_ylabel('Power (dB)') + ax.plot(freqs, psds[ch_idx], color="red") + ax.set_xlabel("Frequency (Hz)") + ax.set_ylabel("Power (dB)") -for ax, idx in iter_topography(raw.info, - fig_facecolor='white', - axis_facecolor='white', - axis_spinecolor='white', - on_pick=my_callback): - ax.plot(psds[idx], color='red') +for ax, idx in iter_topography( + raw.info, + fig_facecolor="white", + axis_facecolor="white", + axis_spinecolor="white", + on_pick=my_callback, +): + ax.plot(psds[idx], color="red") -plt.gcf().suptitle('Power spectral densities') +plt.gcf().suptitle("Power spectral densities") plt.show() diff --git a/examples/visualization/xhemi.py b/examples/visualization/xhemi.py index bb5a4971d4d..693d702629c 100644 --- a/examples/visualization/xhemi.py +++ b/examples/visualization/xhemi.py @@ -19,26 +19,31 @@ import mne data_dir = mne.datasets.sample.data_path() -subjects_dir = data_dir / 'subjects' -stc_path = data_dir / 'MEG' / 'sample' / 'sample_audvis-meg-eeg' -stc = mne.read_source_estimate(stc_path, 'sample') +subjects_dir = data_dir / "subjects" +stc_path = data_dir / "MEG" / "sample" / "sample_audvis-meg-eeg" +stc = mne.read_source_estimate(stc_path, "sample") # First, morph the data to fsaverage_sym, for which we have left_right # registrations: -stc = mne.compute_source_morph(stc, 'sample', 'fsaverage_sym', smooth=5, - warn=False, - subjects_dir=subjects_dir).apply(stc) +stc = mne.compute_source_morph( + stc, "sample", "fsaverage_sym", smooth=5, warn=False, subjects_dir=subjects_dir +).apply(stc) # Compute a morph-matrix mapping the right to the left hemisphere, # and vice-versa. 
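# (With ``xhemi=True`` the morph maps each vertex of ``fsaverage_sym`` onto
# its mirror-image location in the opposite hemisphere, so applying it swaps
# the data across the midline.)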
-morph = mne.compute_source_morph(stc, 'fsaverage_sym', 'fsaverage_sym', - spacing=stc.vertices, warn=False, - subjects_dir=subjects_dir, xhemi=True, - verbose='error') # creating morph map +morph = mne.compute_source_morph( + stc, + "fsaverage_sym", + "fsaverage_sym", + spacing=stc.vertices, + warn=False, + subjects_dir=subjects_dir, + xhemi=True, + verbose="error", +) # creating morph map stc_xhemi = morph.apply(stc) # Now we can subtract them and plot the result: diff = stc - stc_xhemi -diff.plot(hemi='lh', subjects_dir=subjects_dir, initial_time=0.07, - size=(800, 600)) +diff.plot(hemi="lh", subjects_dir=subjects_dir, initial_time=0.07, size=(800, 600)) diff --git a/logo/generate_mne_logos.py b/logo/generate_mne_logos.py index 072710182be..34b77788750 100644 --- a/logo/generate_mne_logos.py +++ b/logo/generate_mne_logos.py @@ -23,18 +23,24 @@ dpi = 300 center_fudge = np.array([15, 30]) # compensate for font bounding box padding tagline_scale_fudge = 0.97 # to get justification right -tagline_offset_fudge = np.array([0, -100.]) +tagline_offset_fudge = np.array([0, -100.0]) # font, etc -rcp = {'font.sans-serif': ['Primetime'], 'font.style': 'normal', - 'font.weight': 'black', 'font.variant': 'normal', 'figure.dpi': dpi, - 'savefig.dpi': dpi, 'contour.negative_linestyle': 'solid'} +rcp = { + "font.sans-serif": ["Primetime"], + "font.style": "normal", + "font.weight": "black", + "font.variant": "normal", + "figure.dpi": dpi, + "savefig.dpi": dpi, + "contour.negative_linestyle": "solid", +} plt.rcdefaults() rcParams.update(rcp) # initialize figure (no axes, margins, etc) fig = plt.figure(1, figsize=(5, 2.25), frameon=False, dpi=dpi) -ax = plt.Axes(fig, [0., 0., 1., 1.]) +ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() fig.add_axes(ax) @@ -44,10 +50,12 @@ y = np.arange(-3.0, 3.0, delta) X, Y = np.meshgrid(x, y) xy = np.array([X, Y]).transpose(1, 2, 0) -Z1 = multivariate_normal.pdf(xy, mean=[-5.0, 0.9], - cov=np.array([[8.0, 1.0], [1.0, 7.0]]) ** 2) -Z2 = multivariate_normal.pdf(xy, mean=[2.6, -2.5], - cov=np.array([[15.0, 2.5], [2.5, 2.5]]) ** 2) +Z1 = multivariate_normal.pdf( + xy, mean=[-5.0, 0.9], cov=np.array([[8.0, 1.0], [1.0, 7.0]]) ** 2 +) +Z2 = multivariate_normal.pdf( + xy, mean=[2.6, -2.5], cov=np.array([[15.0, 2.5], [2.5, 2.5]]) ** 2 +) Z = Z2 - 0.7 * Z1 # color map: field gradient (yellow-red-gray-blue-cyan) @@ -56,36 +64,46 @@ # 'blue': ((0, 0, 0), (0.4, 0, 0), (0.5, 0.5, 0.5), (0.6, 1, 1), (1, 1, 1)), # noqa # 'green': ((0, 1, 1), (0.4, 0, 0), (0.5, 0.5, 0.5), (0.6, 0, 0), (1, 1, 1)), # noqa # } -yrtbc = {'red': ((0.0, 1.0, 1.0), (0.5, 1.0, 0.0), (1.0, 0.0, 0.0)), - 'blue': ((0.0, 0.0, 0.0), (0.5, 0.0, 1.0), (1.0, 1.0, 1.0)), - 'green': ((0.0, 1.0, 1.0), (0.5, 0.0, 0.0), (1.0, 1.0, 1.0)), - 'alpha': ((0.0, 1.0, 1.0), (0.4, 0.8, 0.8), (0.5, 0.2, 0.2), - (0.6, 0.8, 0.8), (1.0, 1.0, 1.0))} +yrtbc = { + "red": ((0.0, 1.0, 1.0), (0.5, 1.0, 0.0), (1.0, 0.0, 0.0)), + "blue": ((0.0, 0.0, 0.0), (0.5, 0.0, 1.0), (1.0, 1.0, 1.0)), + "green": ((0.0, 1.0, 1.0), (0.5, 0.0, 0.0), (1.0, 1.0, 1.0)), + "alpha": ( + (0.0, 1.0, 1.0), + (0.4, 0.8, 0.8), + (0.5, 0.2, 0.2), + (0.6, 0.8, 0.8), + (1.0, 1.0, 1.0), + ), +} # color map: field lines (red | blue) -redbl = {'red': ((0., 1., 1.), (0.5, 1., 0.), (1., 0., 0.)), - 'blue': ((0., 0., 0.), (0.5, 0., 1.), (1., 1., 1.)), - 'green': ((0., 0., 0.), (1., 0., 0.)), - 'alpha': ((0., 0.4, 0.4), (1., 0.4, 0.4))} -mne_field_grad_cols = LinearSegmentedColormap('mne_grad', yrtbc) -mne_field_line_cols = LinearSegmentedColormap('mne_line', redbl) 
+redbl = { + "red": ((0.0, 1.0, 1.0), (0.5, 1.0, 0.0), (1.0, 0.0, 0.0)), + "blue": ((0.0, 0.0, 0.0), (0.5, 0.0, 1.0), (1.0, 1.0, 1.0)), + "green": ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0)), + "alpha": ((0.0, 0.4, 0.4), (1.0, 0.4, 0.4)), +} +mne_field_grad_cols = LinearSegmentedColormap("mne_grad", yrtbc) +mne_field_line_cols = LinearSegmentedColormap("mne_line", redbl) # plot gradient and contour lines -im = ax.imshow(Z, cmap=mne_field_grad_cols, aspect='equal', zorder=1) +im = ax.imshow(Z, cmap=mne_field_grad_cols, aspect="equal", zorder=1) cs = ax.contour(Z, 9, cmap=mne_field_line_cols, linewidths=1, zorder=1) xlim, ylim = ax.get_xbound(), ax.get_ybound() plot_dims = np.r_[np.diff(xlim), np.diff(ylim)] rect = Rectangle( - [xlim[0], ylim[0]], plot_dims[0], plot_dims[1], facecolor='w', zorder=0.5) + [xlim[0], ylim[0]], plot_dims[0], plot_dims[1], facecolor="w", zorder=0.5 +) # create MNE clipping mask -mne_path = TextPath((0, 0), 'MNE') +mne_path = TextPath((0, 0), "MNE") dims = mne_path.vertices.max(0) - mne_path.vertices.min(0) -vert = mne_path.vertices - dims / 2. +vert = mne_path.vertices - dims / 2.0 mult = (plot_dims / dims).min() mult = [mult, -mult] # y axis is inverted (origin at top left) -offset = plot_dims / 2. - center_fudge +offset = plot_dims / 2.0 - center_fudge mne_clip = Path(offset + vert * mult, mne_path.codes) -ax.add_patch(PathPatch(mne_clip, color='w', zorder=0, linewidth=0)) +ax.add_patch(PathPatch(mne_clip, color="w", zorder=0, linewidth=0)) # apply clipping mask to field gradient and lines im.set_clip_path(mne_clip, transform=im.get_transform()) ax.add_patch(rect) @@ -96,64 +114,78 @@ mne_corners = mne_clip.get_extents().corners() # add tagline -rcParams.update({'font.sans-serif': ['Cooper Hewitt'], 'font.weight': '300'}) -tag_path = TextPath((0, 0), 'MEG + EEG ANALYSIS & VISUALIZATION') +rcParams.update({"font.sans-serif": ["Cooper Hewitt"], "font.weight": "300"}) +tag_path = TextPath((0, 0), "MEG + EEG ANALYSIS & VISUALIZATION") dims = tag_path.vertices.max(0) - tag_path.vertices.min(0) -vert = tag_path.vertices - dims / 2. 
+vert = tag_path.vertices - dims / 2.0 mult = tagline_scale_fudge * (plot_dims / dims).min() mult = [mult, -mult] # y axis is inverted -offset = mne_corners[-1] - np.array([mne_clip.get_extents().size[0] / 2., - -dims[1]]) - tagline_offset_fudge +offset = ( + mne_corners[-1] + - np.array([mne_clip.get_extents().size[0] / 2.0, -dims[1]]) + - tagline_offset_fudge +) tag_clip = Path(offset + vert * mult, tag_path.codes) -tag_patch = PathPatch(tag_clip, facecolor='k', edgecolor='none', zorder=10) +tag_patch = PathPatch(tag_clip, facecolor="k", edgecolor="none", zorder=10) ax.add_patch(tag_patch) yl = ax.get_ylim() -yy = np.max([tag_clip.vertices.max(0)[-1], - tag_clip.vertices.min(0)[-1]]) +yy = np.max([tag_clip.vertices.max(0)[-1], tag_clip.vertices.min(0)[-1]]) ax.set_ylim(np.ceil(yy), yl[-1]) # only save actual image extent plus a bit of padding plt.draw() -static_dir = op.join(op.dirname(__file__), '..', 'doc', '_static') +static_dir = op.join(op.dirname(__file__), "..", "doc", "_static") assert op.isdir(static_dir) -plt.savefig(op.join(static_dir, 'mne_logo.svg'), transparent=True) -tag_patch.set_facecolor('w') -rect.set_facecolor('0.5') -plt.savefig(op.join(static_dir, 'mne_logo_dark.svg'), transparent=True) -tag_patch.set_facecolor('k') -rect.set_facecolor('w') +plt.savefig(op.join(static_dir, "mne_logo.svg"), transparent=True) +tag_patch.set_facecolor("w") +rect.set_facecolor("0.5") +plt.savefig(op.join(static_dir, "mne_logo_dark.svg"), transparent=True) +tag_patch.set_facecolor("k") +rect.set_facecolor("w") # modify to make the splash screen -data_dir = op.join(op.dirname(__file__), '..', 'mne', 'icons') -ax.patches[-1].set_facecolor('w') +data_dir = op.join(op.dirname(__file__), "..", "mne", "icons") +ax.patches[-1].set_facecolor("w") for coll in list(ax.collections): coll.remove() -bounds = np.array([ - [mne_path.vertices[:, ii].min(), mne_path.vertices[:, ii].max()] - for ii in range(2)]) -bounds *= (plot_dims / dims) +bounds = np.array( + [ + [mne_path.vertices[:, ii].min(), mne_path.vertices[:, ii].max()] + for ii in range(2) + ] +) +bounds *= plot_dims / dims xy = np.mean(bounds, axis=1) - [100, 0] r = np.diff(bounds, axis=1).max() * 1.2 w, h = r, r * (2 / 3) box_xy = [xy[0] - w * 0.5, xy[1] - h * (2 / 5)] ax.set_ylim(box_xy[1] + h * 1.001, box_xy[1] - h * 0.001) patch = FancyBboxPatch( - box_xy, w, h, clip_on=False, zorder=-1, fc='k', ec='none', alpha=0.75, - boxstyle="round,rounding_size=200.0", mutation_aspect=1) + box_xy, + w, + h, + clip_on=False, + zorder=-1, + fc="k", + ec="none", + alpha=0.75, + boxstyle="round,rounding_size=200.0", + mutation_aspect=1, +) ax.add_patch(patch) fig.set_size_inches((512 / dpi, 512 * (h / w) / dpi)) -plt.savefig(op.join(data_dir, 'mne_splash.png'), transparent=True) +plt.savefig(op.join(data_dir, "mne_splash.png"), transparent=True) patch.remove() # modify to make an icon ax.patches.pop(-1) # no tag line for our icon -patch = Ellipse(xy, r, r, clip_on=False, zorder=-1, fc='k') +patch = Ellipse(xy, r, r, clip_on=False, zorder=-1, fc="k") ax.add_patch(patch) ax.set_ylim(xy[1] + r / 1.9, xy[1] - r / 1.9) fig.set_size_inches((256 / dpi, 256 / dpi)) # Qt does not support clip paths in SVG rendering so we have to use PNG here # then use "optipng -o7" on it afterward (14% reduction in file size) -plt.savefig(op.join(data_dir, 'mne_default_icon.png'), transparent=True) +plt.savefig(op.join(data_dir, "mne_default_icon.png"), transparent=True) plt.close() # 188x45 image @@ -162,31 +194,33 @@ h_px = 45 center_fudge = np.array([60, 0]) scale_fudge = 2.1 
-rcParams.update({'font.sans-serif': ['Primetime'], 'font.weight': 'black'}) -x = np.linspace(-1., 1., w_px // 2) -y = np.linspace(-1., 1., h_px // 2) +rcParams.update({"font.sans-serif": ["Primetime"], "font.weight": "black"}) +x = np.linspace(-1.0, 1.0, w_px // 2) +y = np.linspace(-1.0, 1.0, h_px // 2) X, Y = np.meshgrid(x, y) # initialize figure (no axes, margins, etc) -fig = plt.figure(1, figsize=(w_px / dpi, h_px / dpi), facecolor='k', - frameon=False, dpi=dpi) -ax = plt.Axes(fig, [0., 0., 1., 1.]) +fig = plt.figure( + 1, figsize=(w_px / dpi, h_px / dpi), facecolor="k", frameon=False, dpi=dpi +) +ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() fig.add_axes(ax) # plot rainbow -ax.imshow(X, cmap=mne_field_grad_cols, aspect='equal', zorder=1) -ax.imshow(np.ones_like(X) * 0.5, cmap='Greys', aspect='equal', zorder=0, - clim=[0, 1]) +ax.imshow(X, cmap=mne_field_grad_cols, aspect="equal", zorder=1) +ax.imshow(np.ones_like(X) * 0.5, cmap="Greys", aspect="equal", zorder=0, clim=[0, 1]) plot_dims = np.r_[np.diff(ax.get_xbound()), np.diff(ax.get_ybound())] # MNE text in white -mne_path = TextPath((0, 0), 'MNE') +mne_path = TextPath((0, 0), "MNE") dims = mne_path.vertices.max(0) - mne_path.vertices.min(0) -vert = mne_path.vertices - dims / 2. +vert = mne_path.vertices - dims / 2.0 mult = scale_fudge * (plot_dims / dims).min() mult = [mult, -mult] # y axis is inverted (origin at top left) -offset = np.array([scale_fudge, 1.]) * \ - np.array([-dims[0], plot_dims[-1]]) / 2. - center_fudge +offset = ( + np.array([scale_fudge, 1.0]) * np.array([-dims[0], plot_dims[-1]]) / 2.0 + - center_fudge +) mne_clip = Path(offset + vert * mult, mne_path.codes) -mne_patch = PathPatch(mne_clip, facecolor='0.5', edgecolor='none', zorder=10) +mne_patch = PathPatch(mne_clip, facecolor="0.5", edgecolor="none", zorder=10) ax.add_patch(mne_patch) # adjust xlim and ylim mne_corners = mne_clip.get_extents().corners() @@ -194,11 +228,10 @@ xmax, ymax = np.max(mne_corners, axis=0) xl = ax.get_xlim() yl = ax.get_ylim() -xpad = np.abs(np.diff([xmin, xl[1]])) / 20. -ypad = np.abs(np.diff([ymax, ymin])) / 20. 
+xpad = np.abs(np.diff([xmin, xl[1]])) / 20.0 +ypad = np.abs(np.diff([ymax, ymin])) / 20.0 ax.set_xlim(xmin - xpad, xl[1] + xpad) ax.set_ylim(ymax + ypad, ymin - ypad) plt.draw() -plt.savefig(op.join(static_dir, 'mne_logo_small.svg'), - dpi=dpi, transparent=True) +plt.savefig(op.join(static_dir, "mne_logo_small.svg"), dpi=dpi, transparent=True) plt.close() diff --git a/mne/__init__.py b/mne/__init__.py index 27a2846887e..4457f310986 100644 --- a/mne/__init__.py +++ b/mne/__init__.py @@ -18,95 +18,214 @@ try: from importlib.metadata import version + __version__ = version("mne") except Exception: try: from ._version import __version__ except ImportError: - __version__ = '0.0.0' + __version__ = "0.0.0" # have to import verbose first since it's needed by many things -from .utils import (set_log_level, set_log_file, verbose, set_config, - get_config, get_config_path, set_cache_dir, - set_memmap_min_size, grand_average, sys_info, open_docs, - use_log_level) -from .io.pick import (pick_types, pick_channels, - pick_channels_regexp, pick_channels_forward, - pick_types_forward, pick_channels_cov, - pick_channels_evoked, pick_info, - channel_type, channel_indices_by_type) +from .utils import ( + set_log_level, + set_log_file, + verbose, + set_config, + get_config, + get_config_path, + set_cache_dir, + set_memmap_min_size, + grand_average, + sys_info, + open_docs, + use_log_level, +) +from .io.pick import ( + pick_types, + pick_channels, + pick_channels_regexp, + pick_channels_forward, + pick_types_forward, + pick_channels_cov, + pick_channels_evoked, + pick_info, + channel_type, + channel_indices_by_type, +) from .io.base import concatenate_raws, match_channel_orders from .io.meas_info import create_info, Info from .io.proj import Projection from .io.kit import read_epochs_kit from .io.eeglab import read_epochs_eeglab -from .io.reference import (set_eeg_reference, set_bipolar_reference, - add_reference_channels) +from .io.reference import ( + set_eeg_reference, + set_bipolar_reference, + add_reference_channels, +) from .io.what import what -from .bem import (make_sphere_model, make_bem_model, make_bem_solution, - read_bem_surfaces, write_bem_surfaces, write_head_bem, - read_bem_solution, write_bem_solution) -from .cov import (read_cov, write_cov, Covariance, compute_raw_covariance, - compute_covariance, whiten_evoked, make_ad_hoc_cov) -from .event import (read_events, write_events, find_events, merge_events, - pick_events, make_fixed_length_events, concatenate_events, - find_stim_steps, AcqParserFIF, count_events) -from ._freesurfer import (head_to_mni, head_to_mri, read_talxfm, - get_volume_labels_from_aseg, read_freesurfer_lut, - vertex_to_mni, read_lta) -from .forward import (read_forward_solution, apply_forward, apply_forward_raw, - average_forward_solutions, Forward, - write_forward_solution, make_forward_solution, - convert_forward_solution, make_field_map, - make_forward_dipole, use_coil_def) -from .source_estimate import (read_source_estimate, - SourceEstimate, VectorSourceEstimate, - VolSourceEstimate, VolVectorSourceEstimate, - MixedSourceEstimate, MixedVectorSourceEstimate, - grade_to_tris, - spatial_src_adjacency, - spatial_tris_adjacency, - spatial_dist_adjacency, - spatial_inter_hemi_adjacency, - spatio_temporal_src_adjacency, - spatio_temporal_tris_adjacency, - spatio_temporal_dist_adjacency, - extract_label_time_course, stc_near_sensors) -from .surface import (read_surface, write_surface, decimate_surface, read_tri, - get_head_surf, get_meg_helmet_surf, dig_mri_distances, - 
warp_montage_volume, get_montage_volume_labels) +from .bem import ( + make_sphere_model, + make_bem_model, + make_bem_solution, + read_bem_surfaces, + write_bem_surfaces, + write_head_bem, + read_bem_solution, + write_bem_solution, +) +from .cov import ( + read_cov, + write_cov, + Covariance, + compute_raw_covariance, + compute_covariance, + whiten_evoked, + make_ad_hoc_cov, +) +from .event import ( + read_events, + write_events, + find_events, + merge_events, + pick_events, + make_fixed_length_events, + concatenate_events, + find_stim_steps, + AcqParserFIF, + count_events, +) +from ._freesurfer import ( + head_to_mni, + head_to_mri, + read_talxfm, + get_volume_labels_from_aseg, + read_freesurfer_lut, + vertex_to_mni, + read_lta, +) +from .forward import ( + read_forward_solution, + apply_forward, + apply_forward_raw, + average_forward_solutions, + Forward, + write_forward_solution, + make_forward_solution, + convert_forward_solution, + make_field_map, + make_forward_dipole, + use_coil_def, +) +from .source_estimate import ( + read_source_estimate, + SourceEstimate, + VectorSourceEstimate, + VolSourceEstimate, + VolVectorSourceEstimate, + MixedSourceEstimate, + MixedVectorSourceEstimate, + grade_to_tris, + spatial_src_adjacency, + spatial_tris_adjacency, + spatial_dist_adjacency, + spatial_inter_hemi_adjacency, + spatio_temporal_src_adjacency, + spatio_temporal_tris_adjacency, + spatio_temporal_dist_adjacency, + extract_label_time_course, + stc_near_sensors, +) +from .surface import ( + read_surface, + write_surface, + decimate_surface, + read_tri, + get_head_surf, + get_meg_helmet_surf, + dig_mri_distances, + warp_montage_volume, + get_montage_volume_labels, +) from .morph_map import read_morph_map -from .morph import (SourceMorph, read_source_morph, grade_to_vertices, - compute_source_morph) -from .source_space import (read_source_spaces, - write_source_spaces, setup_source_space, - setup_volume_source_space, SourceSpaces, - add_source_space_distances, morph_source_spaces, - get_volume_labels_from_src) -from .annotations import (Annotations, read_annotations, annotations_from_events, - events_from_annotations) -from .epochs import (BaseEpochs, Epochs, EpochsArray, read_epochs, - concatenate_epochs, make_fixed_length_epochs) -from .evoked import (Evoked, EvokedArray, read_evokeds, write_evokeds, - combine_evoked) -from .label import (read_label, label_sign_flip, - write_label, stc_to_label, grow_labels, Label, split_label, - BiHemiLabel, read_labels_from_annot, write_labels_to_annot, - random_parcellation, morph_labels, labels_to_stc) +from .morph import ( + SourceMorph, + read_source_morph, + grade_to_vertices, + compute_source_morph, +) +from .source_space import ( + read_source_spaces, + write_source_spaces, + setup_source_space, + setup_volume_source_space, + SourceSpaces, + add_source_space_distances, + morph_source_spaces, + get_volume_labels_from_src, +) +from .annotations import ( + Annotations, + read_annotations, + annotations_from_events, + events_from_annotations, +) +from .epochs import ( + BaseEpochs, + Epochs, + EpochsArray, + read_epochs, + concatenate_epochs, + make_fixed_length_epochs, +) +from .evoked import Evoked, EvokedArray, read_evokeds, write_evokeds, combine_evoked +from .label import ( + read_label, + label_sign_flip, + write_label, + stc_to_label, + grow_labels, + Label, + split_label, + BiHemiLabel, + read_labels_from_annot, + write_labels_to_annot, + random_parcellation, + morph_labels, + labels_to_stc, +) from .misc import parse_config, 
read_reject_parameters -from .coreg import (create_default_subject, scale_bem, scale_mri, scale_labels, - scale_source_space) -from .transforms import (read_trans, write_trans, - transform_surface_to, Transform) -from .proj import (read_proj, write_proj, compute_proj_epochs, - compute_proj_evoked, compute_proj_raw, sensitivity_map) +from .coreg import ( + create_default_subject, + scale_bem, + scale_mri, + scale_labels, + scale_source_space, +) +from .transforms import read_trans, write_trans, transform_surface_to, Transform +from .proj import ( + read_proj, + write_proj, + compute_proj_epochs, + compute_proj_evoked, + compute_proj_raw, + sensitivity_map, +) from .dipole import read_dipole, Dipole, DipoleFixed, fit_dipole -from .channels import (equalize_channels, rename_channels, find_layout, - read_vectorview_selection) +from .channels import ( + equalize_channels, + rename_channels, + find_layout, + read_vectorview_selection, +) from .report import Report, open_report -from .io import (read_epochs_fieldtrip, read_evoked_besa, - read_evoked_fieldtrip, read_evokeds_mff) +from .io import ( + read_epochs_fieldtrip, + read_evoked_besa, + read_evoked_fieldtrip, + read_evokeds_mff, +) from .rank import compute_rank from . import beamformer diff --git a/mne/__main__.py b/mne/__main__.py index 414754c1885..5a3bfa5abb6 100644 --- a/mne/__main__.py +++ b/mne/__main__.py @@ -3,5 +3,5 @@ from .commands.utils import main -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mne/_freesurfer.py b/mne/_freesurfer.py index d92ac40e807..9b774dafc0d 100644 --- a/mne/_freesurfer.py +++ b/mne/_freesurfer.py @@ -12,31 +12,45 @@ from .bem import _bem_find_surface, read_bem_surfaces from .io.constants import FIFF from .io.meas_info import read_fiducials -from .transforms import (apply_trans, invert_transform, combine_transforms, - _ensure_trans, read_ras_mni_t, Transform) +from .transforms import ( + apply_trans, + invert_transform, + combine_transforms, + _ensure_trans, + read_ras_mni_t, + Transform, +) from .surface import read_surface, _read_mri_surface -from .utils import (verbose, _validate_type, _check_fname, _check_option, - get_subjects_dir, _import_nibabel, logger) +from .utils import ( + verbose, + _validate_type, + _check_fname, + _check_option, + get_subjects_dir, + _import_nibabel, + logger, +) def _check_subject_dir(subject, subjects_dir): """Check that the Freesurfer subject directory is as expected.""" subjects_dir = Path(get_subjects_dir(subjects_dir, raise_error=True)) - for img_name in ('T1', 'brain', 'aseg'): + for img_name in ("T1", "brain", "aseg"): if not (subjects_dir / subject / "mri" / f"{img_name}.mgz").is_file(): - raise ValueError('Freesurfer recon-all subject folder ' - 'is incorrect or improperly formatted, ' - f'got {subjects_dir / subject}') + raise ValueError( + "Freesurfer recon-all subject folder " + "is incorrect or improperly formatted, " + f"got {subjects_dir / subject}" + ) return subjects_dir / subject def _get_aseg(aseg, subject, subjects_dir): """Check that the anatomical segmentation file exists and load it.""" - nib = _import_nibabel('load aseg') + nib = _import_nibabel("load aseg") subjects_dir = Path(get_subjects_dir(subjects_dir, raise_error=True)) - if not aseg.endswith('aseg'): - raise RuntimeError( - f'`aseg` file path must end with "aseg", got {aseg}') + if not aseg.endswith("aseg"): + raise RuntimeError(f'`aseg` file path must end with "aseg", got {aseg}') aseg = _check_fname( subjects_dir / subject / "mri" / (aseg + ".mgz"), 
overwrite="read", @@ -47,7 +61,7 @@ def _get_aseg(aseg, subject, subjects_dir): return aseg, aseg_data -def _reorient_image(img, axcodes='RAS'): +def _reorient_image(img, axcodes="RAS"): """Reorient an image to a given orientation. Parameters @@ -69,11 +83,12 @@ def _reorient_image(img, axcodes='RAS'): ----- .. versionadded:: 0.24 """ - nib = _import_nibabel('reorient MRI image') + nib = _import_nibabel("reorient MRI image") orig_data = np.array(img.dataobj).astype(np.float32) # reorient data to RAS ornt = nib.orientations.axcodes2ornt( - nib.orientations.aff2axcodes(img.affine)).astype(int) + nib.orientations.aff2axcodes(img.affine) + ).astype(int) ras_ornt = nib.orientations.axcodes2ornt(axcodes) ornt_trans = nib.orientations.ornt_transform(ornt, ras_ornt) img_data = nib.orientations.apply_orientation(orig_data, ornt_trans) @@ -105,7 +120,7 @@ def _mri_orientation(orientation): .. versionadded:: 0.21 .. versionchanged:: 0.24 """ - _check_option('orientation', orientation, ('coronal', 'axial', 'sagittal')) + _check_option("orientation", orientation, ("coronal", "axial", "sagittal")) axis = dict(coronal=1, axial=2, sagittal=0)[orientation] x, y = sorted(set([0, 1, 2]).difference(set([axis]))) return axis, x, y @@ -114,72 +129,81 @@ def _mri_orientation(orientation): def _get_mri_info_data(mri, data): # Read the segmentation data using nibabel if data: - _import_nibabel('load MRI atlas data') + _import_nibabel("load MRI atlas data") out = dict() - _, out['vox_mri_t'], out['mri_ras_t'], dims, _, mgz = _read_mri_info( - mri, return_img=True) + _, out["vox_mri_t"], out["mri_ras_t"], dims, _, mgz = _read_mri_info( + mri, return_img=True + ) out.update( - mri_width=dims[0], mri_height=dims[1], - mri_depth=dims[1], mri_volume_name=mri) + mri_width=dims[0], mri_height=dims[1], mri_depth=dims[1], mri_volume_name=mri + ) if data: assert mgz is not None - out['mri_vox_t'] = invert_transform(out['vox_mri_t']) - out['data'] = np.asarray(mgz.dataobj) + out["mri_vox_t"] = invert_transform(out["vox_mri_t"]) + out["data"] = np.asarray(mgz.dataobj) return out def _get_mgz_header(fname): """Adapted from nibabel to quickly extract header info.""" - fname = _check_fname(fname, overwrite='read', must_exist=True, - name='MRI image') + fname = _check_fname(fname, overwrite="read", must_exist=True, name="MRI image") if fname.suffix != ".mgz": - raise OSError('Filename must end with .mgz') - header_dtd = [('version', '>i4'), ('dims', '>i4', (4,)), - ('type', '>i4'), ('dof', '>i4'), ('goodRASFlag', '>i2'), - ('delta', '>f4', (3,)), ('Mdc', '>f4', (3, 3)), - ('Pxyz_c', '>f4', (3,))] + raise OSError("Filename must end with .mgz") + header_dtd = [ + ("version", ">i4"), + ("dims", ">i4", (4,)), + ("type", ">i4"), + ("dof", ">i4"), + ("goodRASFlag", ">i2"), + ("delta", ">f4", (3,)), + ("Mdc", ">f4", (3, 3)), + ("Pxyz_c", ">f4", (3,)), + ] header_dtype = np.dtype(header_dtd) - with GzipFile(fname, 'rb') as fid: + with GzipFile(fname, "rb") as fid: hdr_str = fid.read(header_dtype.itemsize) - header = np.ndarray(shape=(), dtype=header_dtype, - buffer=hdr_str) + header = np.ndarray(shape=(), dtype=header_dtype, buffer=hdr_str) # dims - dims = header['dims'].astype(int) + dims = header["dims"].astype(int) dims = dims[:3] if len(dims) == 4 else dims # vox2ras_tkr - delta = header['delta'] + delta = header["delta"] ds = np.array(delta, float) ns = np.array(dims * ds) / 2.0 - v2rtkr = np.array([[-ds[0], 0, 0, ns[0]], - [0, 0, ds[2], -ns[2]], - [0, -ds[1], 0, ns[1]], - [0, 0, 0, 1]], dtype=np.float32) + v2rtkr = np.array( + [ + 
[-ds[0], 0, 0, ns[0]], + [0, 0, ds[2], -ns[2]], + [0, -ds[1], 0, ns[1]], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) # ras2vox d = np.diag(delta) pcrs_c = dims / 2.0 - Mdc = header['Mdc'].T - pxyz_0 = header['Pxyz_c'] - np.dot(Mdc, np.dot(d, pcrs_c)) + Mdc = header["Mdc"].T + pxyz_0 = header["Pxyz_c"] - np.dot(Mdc, np.dot(d, pcrs_c)) M = np.eye(4, 4) M[0:3, 0:3] = np.dot(Mdc, d) M[0:3, 3] = pxyz_0.T - header = dict(dims=dims, vox2ras_tkr=v2rtkr, vox2ras=M, - zooms=header['delta']) + header = dict(dims=dims, vox2ras_tkr=v2rtkr, vox2ras=M, zooms=header["delta"]) return header def _get_atlas_values(vol_info, rr): # Transform MRI coordinates (where our surfaces live) to voxels - rr_vox = apply_trans(vol_info['mri_vox_t'], rr) - good = ((rr_vox >= -.5) & - (rr_vox < np.array(vol_info['data'].shape, int) - 0.5)).all(-1) + rr_vox = apply_trans(vol_info["mri_vox_t"], rr) + good = ( + (rr_vox >= -0.5) & (rr_vox < np.array(vol_info["data"].shape, int) - 0.5) + ).all(-1) idx = np.round(rr_vox[good].T).astype(np.int64) values = np.full(rr.shape[0], np.nan) - values[good] = vol_info['data'][tuple(idx)] + values[good] = vol_info["data"][tuple(idx)] return values -def get_volume_labels_from_aseg(mgz_fname, return_colors=False, - atlas_ids=None): +def get_volume_labels_from_aseg(mgz_fname, return_colors=False, atlas_ids=None): """Return a list of names and colors of segmented volumes. Parameters @@ -214,7 +238,7 @@ def get_volume_labels_from_aseg(mgz_fname, return_colors=False, .. versionadded:: 0.9.0 """ - nib = _import_nibabel('load MRI atlas data') + nib = _import_nibabel("load MRI atlas data") mgz_fname = _check_fname( mgz_fname, overwrite="read", must_exist=True, name="mgz_fname" ) @@ -224,12 +248,13 @@ def get_volume_labels_from_aseg(mgz_fname, return_colors=False, if atlas_ids is None: atlas_ids, colors = read_freesurfer_lut() elif return_colors: - raise ValueError('return_colors must be False if atlas_ids are ' - 'provided') + raise ValueError("return_colors must be False if atlas_ids are " "provided") # restrict to the ones in the MRI, sorted by label name keep = np.in1d(list(atlas_ids.values()), want) - keys = sorted((key for ki, key in enumerate(atlas_ids.keys()) if keep[ki]), - key=lambda x: atlas_ids[x]) + keys = sorted( + (key for ki, key in enumerate(atlas_ids.keys()) if keep[ki]), + key=lambda x: atlas_ids[x], + ) if return_colors: colors = [colors[k] for k in keys] out = keys, colors @@ -243,8 +268,16 @@ def get_volume_labels_from_aseg(mgz_fname, return_colors=False, @verbose -def head_to_mri(pos, subject, mri_head_t, subjects_dir=None, *, - kind='mri', unscale=False, verbose=None): +def head_to_mri( + pos, + subject, + mri_head_t, + subjects_dir=None, + *, + kind="mri", + unscale=False, + verbose=None, +): """Convert pos from head coordinate system to MRI ones. Parameters @@ -279,23 +312,24 @@ def head_to_mri(pos, subject, mri_head_t, subjects_dir=None, *, This function requires nibabel. 
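    For example, a call like ``head_to_mri(pos, 'sample', mri_head_t,
    subjects_dir=subjects_dir)`` returns ``pos`` expressed in the subject's
    MRI coordinate frame, in millimeters.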
""" from .coreg import read_mri_cfg - _validate_type(kind, str, 'kind') - _check_option('kind', kind, ('ras', 'mri')) + + _validate_type(kind, str, "kind") + _check_option("kind", kind, ("ras", "mri")) subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) t1_fname = subjects_dir / subject / "mri" / "T1.mgz" - head_mri_t = _ensure_trans(mri_head_t, 'head', 'mri') - if kind == 'ras': + head_mri_t = _ensure_trans(mri_head_t, "head", "mri") + if kind == "ras": _, _, mri_ras_t, _, _ = _read_mri_info(t1_fname) - head_ras_t = combine_transforms(head_mri_t, mri_ras_t, 'head', 'ras') + head_ras_t = combine_transforms(head_mri_t, mri_ras_t, "head", "ras") head_dest_t = head_ras_t else: - assert kind == 'mri' + assert kind == "mri" head_dest_t = head_mri_t pos_dest = apply_trans(head_dest_t, pos) # unscale if requested if unscale: params = read_mri_cfg(subject, subjects_dir) - pos_dest /= params['scale'] + pos_dest /= params["scale"] pos_dest *= 1e3 # mm return pos_dest @@ -303,6 +337,7 @@ def head_to_mri(pos, subject, mri_head_t, subjects_dir=None, *, ############################################################################## # Surface to MNI conversion + @verbose def vertex_to_mni(vertices, hemis, subject, subjects_dir=None, verbose=None): """Convert the array of vertices for a hemisphere to MNI coordinates. @@ -332,33 +367,30 @@ def vertex_to_mni(vertices, hemis, subject, subjects_dir=None, verbose=None): hemis = [hemis] * len(vertices) if not len(hemis) == len(vertices): - raise ValueError('hemi and vertices must match in length') + raise ValueError("hemi and vertices must match in length") subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) - surfs = [ - subjects_dir / subject / "surf" / f"{h}.white" - for h in ["lh", "rh"] - ] + surfs = [subjects_dir / subject / "surf" / f"{h}.white" for h in ["lh", "rh"]] # read surface locations in MRI space rr = [read_surface(s)[0] for s in surfs] # take point locations in MRI space and convert to MNI coordinates xfm = read_talxfm(subject, subjects_dir) - xfm['trans'][:3, 3] *= 1000. # m->mm + xfm["trans"][:3, 3] *= 1000.0 # m->mm data = np.array([rr[h][v, :] for h, v in zip(hemis, vertices)]) if singleton: data = data[0] - return apply_trans(xfm['trans'], data) + return apply_trans(xfm["trans"], data) ############################################################################## # Volume to MNI conversion + @verbose -def head_to_mni(pos, subject, mri_head_t, subjects_dir=None, - verbose=None): +def head_to_mni(pos, subject, mri_head_t, subjects_dir=None, verbose=None): """Convert pos from head coordinate system to MNI ones. Parameters @@ -384,9 +416,12 @@ def head_to_mni(pos, subject, mri_head_t, subjects_dir=None, # before we go from head to MRI (surface RAS) head_mni_t = combine_transforms( - _ensure_trans(mri_head_t, 'head', 'mri'), - read_talxfm(subject, subjects_dir), 'head', 'mni_tal') - return apply_trans(head_mni_t, pos) * 1000. + _ensure_trans(mri_head_t, "head", "mri"), + read_talxfm(subject, subjects_dir), + "head", + "mni_tal", + ) + return apply_trans(head_mni_t, pos) * 1000.0 @verbose @@ -424,20 +459,17 @@ def get_mni_fiducials(subject, subjects_dir=None, verbose=None): # transformation, and/or project the points onto the head surface # (if available). fname_fids_fs = ( - Path(__file__).parent - / "data" - / "fsaverage" - / "fsaverage-fiducials.fif" + Path(__file__).parent / "data" / "fsaverage" / "fsaverage-fiducials.fif" ) # Read fsaverage fiducials file and subject Talairach. 
fids, coord_frame = read_fiducials(fname_fids_fs) assert coord_frame == FIFF.FIFFV_COORD_MRI - if subject == 'fsaverage': + if subject == "fsaverage": return fids # special short-circuit for fsaverage mni_mri_t = invert_transform(read_talxfm(subject, subjects_dir)) for f in fids: - f['r'] = apply_trans(mni_mri_t, f['r']) + f["r"] = apply_trans(mni_mri_t, f["r"]) return fids @@ -463,34 +495,40 @@ def estimate_head_mri_t(subject, subjects_dir=None, verbose=None): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) lpa, nasion, rpa = get_mni_fiducials(subject, subjects_dir) - montage = make_dig_montage(lpa=lpa['r'], nasion=nasion['r'], rpa=rpa['r'], - coord_frame='mri') + montage = make_dig_montage( + lpa=lpa["r"], nasion=nasion["r"], rpa=rpa["r"], coord_frame="mri" + ) return invert_transform(compute_native_head_t(montage)) def _ensure_image_in_surface_RAS(image, subject, subjects_dir): """Check if the image is in Freesurfer surface RAS space.""" - nib = _import_nibabel('load a volume image') + nib = _import_nibabel("load a volume image") if not isinstance(image, nib.spatialimages.SpatialImage): image = nib.load(image) image = nib.MGHImage(image.dataobj.astype(np.float32), image.affine) - fs_img = nib.load(op.join(subjects_dir, subject, 'mri', 'brain.mgz')) + fs_img = nib.load(op.join(subjects_dir, subject, "mri", "brain.mgz")) if not np.allclose(image.affine, fs_img.affine, atol=1e-6): - raise RuntimeError('The `image` is not aligned to Freesurfer ' - 'surface RAS space. This space is required as ' - 'it is the space where the anatomical ' - 'segmentation and reconstructed surfaces are') + raise RuntimeError( + "The `image` is not aligned to Freesurfer " + "surface RAS space. This space is required as " + "it is the space where the anatomical " + "segmentation and reconstructed surfaces are" + ) return image # returns MGH image for header def _get_affine_from_lta_info(lines): """Get the vox2ras affine from lta file info.""" - volume_data = np.loadtxt( - [line.split('=')[1] for line in lines]) + volume_data = np.loadtxt([line.split("=")[1] for line in lines]) # get the size of the volume (number of voxels), slice resolution. # the matrix of directional cosines and the ras at the center of the bore - dims, deltas, dir_cos, center_ras = \ - volume_data[0], volume_data[1], volume_data[2:5], volume_data[5] + dims, deltas, dir_cos, center_ras = ( + volume_data[0], + volume_data[1], + volume_data[2:5], + volume_data[5], + ) dir_cos_delta = dir_cos.T * deltas vol_center = (dir_cos_delta @ dims[:3]) / 2 affine = np.eye(4) @@ -514,11 +552,11 @@ def read_lta(fname, verbose=None): affine : ndarray The affine transformation described by the lta file. """ - _check_fname(fname, 'read', must_exist=True) - with open(fname, 'r') as fid: + _check_fname(fname, "read", must_exist=True) + with open(fname, "r") as fid: lines = fid.readlines() # 0 is linear vox2vox, 1 is linear ras2ras - trans_type = int(lines[0].split('=')[1].strip()[0]) + trans_type = int(lines[0].split("=")[1].strip()[0]) assert trans_type in (0, 1) affine = np.loadtxt(lines[5:9]) if trans_type == 1: @@ -556,7 +594,7 @@ def read_talxfm(subject, subjects_dir=None, verbose=None): subjects_dir = get_subjects_dir(subjects_dir) # Setup the RAS to MNI transform ras_mni_t = read_ras_mni_t(subject, subjects_dir) - ras_mni_t['trans'][:3, 3] /= 1000. # mm->m + ras_mni_t["trans"][:3, 3] /= 1000.0 # mm->m # We want to get from Freesurfer surface RAS ('mri') to MNI ('mni_tal'). 
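As a cross-check on the unit handling in ``read_talxfm`` and ``vertex_to_mni`` above, a small hedged sketch (the ``sample`` subject and surface path are assumptions):

    # Hypothetical sketch: read_talxfm() returns the FreeSurfer surface-RAS
    # ('mri') -> MNI ('mni_tal') transform with translations stored in meters,
    # so millimeter surface coordinates are scaled in and back out.
    import mne
    from mne.transforms import apply_trans

    subjects_dir = mne.datasets.sample.data_path() / "subjects"
    xfm = mne.read_talxfm("sample", subjects_dir=subjects_dir)
    rr, _ = mne.read_surface(subjects_dir / "sample" / "surf" / "lh.white")  # mm
    vtx_mni = apply_trans(xfm["trans"], rr[:1] / 1000.0) * 1000.0  # mm -> m -> mm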
# This file only gives us RAS (non-zero origin) ('ras') to MNI ('mni_tal'). @@ -568,33 +606,36 @@ def read_talxfm(subject, subjects_dir=None, verbose=None): if not path.is_file(): path = subjects_dir / subject / "mri" / "T1.mgz" if not path.is_file(): - raise OSError('mri not found: %s' % path) + raise OSError("mri not found: %s" % path) _, _, mri_ras_t, _, _ = _read_mri_info(path) - mri_mni_t = combine_transforms(mri_ras_t, ras_mni_t, 'mri', 'mni_tal') + mri_mni_t = combine_transforms(mri_ras_t, ras_mni_t, "mri", "mni_tal") return mri_mni_t def _check_mri(mri, subject, subjects_dir): """Check whether an mri exists in the Freesurfer subject directory.""" - _validate_type(mri, 'path-like', 'mri') + _validate_type(mri, "path-like", "mri") if op.isfile(mri) and op.basename(mri) != mri: return mri if not op.isfile(mri): if subject is None: raise FileNotFoundError( - f'MRI file {mri!r} not found and no subject provided') + f"MRI file {mri!r} not found and no subject provided" + ) subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) - mri = op.join(subjects_dir, subject, 'mri', mri) + mri = op.join(subjects_dir, subject, "mri", mri) if not op.isfile(mri): - raise FileNotFoundError(f'MRI file {mri!r} not found') + raise FileNotFoundError(f"MRI file {mri!r} not found") if op.basename(mri) == mri: - err = (f'Ambiguous filename - found {mri!r} in current folder.\n' - 'If this is correct prefix name with relative or absolute path') + err = ( + f"Ambiguous filename - found {mri!r} in current folder.\n" + "If this is correct prefix name with relative or absolute path" + ) raise OSError(err) return mri -def _read_mri_info(path, units='m', return_img=False, use_nibabel=False): +def _read_mri_info(path, units="m", return_img=False, use_nibabel=False): # This is equivalent but 100x slower, so only use nibabel if we need to # (later): if use_nibabel: @@ -606,29 +647,28 @@ def _read_mri_info(path, units='m', return_img=False, use_nibabel=False): zooms = hdr.get_zooms()[:3] else: hdr = _get_mgz_header(path) - n_orig = hdr['vox2ras'] - t_orig = hdr['vox2ras_tkr'] - dims = hdr['dims'] - zooms = hdr['zooms'] + n_orig = hdr["vox2ras"] + t_orig = hdr["vox2ras_tkr"] + dims = hdr["dims"] + zooms = hdr["zooms"] # extract the MRI_VOXEL to RAS (non-zero origin) transform - vox_ras_t = Transform('mri_voxel', 'ras', n_orig) + vox_ras_t = Transform("mri_voxel", "ras", n_orig) # extract the MRI_VOXEL to MRI transform - vox_mri_t = Transform('mri_voxel', 'mri', t_orig) + vox_mri_t = Transform("mri_voxel", "mri", t_orig) # construct the MRI to RAS (non-zero origin) transform - mri_ras_t = combine_transforms( - invert_transform(vox_mri_t), vox_ras_t, 'mri', 'ras') + mri_ras_t = combine_transforms(invert_transform(vox_mri_t), vox_ras_t, "mri", "ras") - assert units in ('m', 'mm') - if units == 'm': + assert units in ("m", "mm") + if units == "m": conv = np.array([[1e-3, 1e-3, 1e-3, 1]]).T # scaling and translation terms - vox_ras_t['trans'] *= conv - vox_mri_t['trans'] *= conv + vox_ras_t["trans"] *= conv + vox_mri_t["trans"] *= conv # just the translation term - mri_ras_t['trans'][:, 3:4] *= conv + mri_ras_t["trans"][:, 3:4] *= conv out = (vox_ras_t, vox_mri_t, mri_ras_t, dims, zooms) if return_img: @@ -653,8 +693,8 @@ def read_freesurfer_lut(fname=None): Mapping from label names to colors. 
""" lut = _get_lut(fname) - names, ids = lut['name'], lut['id'] - colors = np.array([lut['R'], lut['G'], lut['B'], lut['A']], float).T + names, ids = lut["name"], lut["id"] + colors = np.array([lut["R"], lut["G"], lut["B"], lut["A"]], float).T atlas_ids = dict(zip(names, ids)) colors = dict(zip(names, colors)) return atlas_ids, colors @@ -664,22 +704,28 @@ def _get_lut(fname=None): """Get a FreeSurfer LUT.""" if fname is None: fname = Path(__file__).parent / "data" / "FreeSurferColorLUT.txt" - _check_fname(fname, 'read', must_exist=True) - dtype = [('id', ' 0 + assert len(lut["name"]) > 0 return lut @@ -709,44 +755,50 @@ def _get_head_surface(surf, subject, subjects_dir, bem=None, verbose=None): ----- .. versionadded: 0.24 """ - _check_option( - 'surf', surf, ('auto', 'head', 'outer_skin', 'head-dense', 'seghead')) - if surf in ('auto', 'head', 'outer_skin'): + _check_option("surf", surf, ("auto", "head", "outer_skin", "head-dense", "seghead")) + if surf in ("auto", "head", "outer_skin"): if bem is not None: try: - return _bem_find_surface(bem, 'head') + return _bem_find_surface(bem, "head") except RuntimeError: - logger.info('Could not find the surface for ' - 'head in the provided BEM model, ' - 'looking in the subject directory.') + logger.info( + "Could not find the surface for " + "head in the provided BEM model, " + "looking in the subject directory." + ) if subject is None: - if surf == 'auto': + if surf == "auto": return - raise ValueError('To plot the head surface, the BEM/sphere' - ' model must contain a head surface ' - 'or "subject" must be provided (got ' - 'None)') - subject_dir = op.join( - get_subjects_dir(subjects_dir, raise_error=True), subject) - if surf in ('head-dense', 'seghead'): - try_fnames = [op.join(subject_dir, 'bem', f'{subject}-head-dense.fif'), - op.join(subject_dir, 'surf', 'lh.seghead')] + raise ValueError( + "To plot the head surface, the BEM/sphere" + " model must contain a head surface " + 'or "subject" must be provided (got ' + "None)" + ) + subject_dir = op.join(get_subjects_dir(subjects_dir, raise_error=True), subject) + if surf in ("head-dense", "seghead"): + try_fnames = [ + op.join(subject_dir, "bem", f"{subject}-head-dense.fif"), + op.join(subject_dir, "surf", "lh.seghead"), + ] else: try_fnames = [ - op.join(subject_dir, 'bem', 'outer_skin.surf'), - op.join(subject_dir, 'bem', 'flash', 'outer_skin.surf'), - op.join(subject_dir, 'bem', f'{subject}-head-sparse.fif'), - op.join(subject_dir, 'bem', f'{subject}-head.fif'), + op.join(subject_dir, "bem", "outer_skin.surf"), + op.join(subject_dir, "bem", "flash", "outer_skin.surf"), + op.join(subject_dir, "bem", f"{subject}-head-sparse.fif"), + op.join(subject_dir, "bem", f"{subject}-head.fif"), ] for fname in try_fnames: if op.exists(fname): - logger.info(f'Using {op.basename(fname)} for head surface.') - if op.splitext(fname)[-1] == '.fif': - return read_bem_surfaces(fname, on_defects='warn')[0] + logger.info(f"Using {op.basename(fname)} for head surface.") + if op.splitext(fname)[-1] == ".fif": + return read_bem_surfaces(fname, on_defects="warn")[0] else: return _read_mri_surface(fname) - raise OSError('No head surface found for subject ' - f'{subject} after trying:\n' + '\n'.join(try_fnames)) + raise OSError( + "No head surface found for subject " + f"{subject} after trying:\n" + "\n".join(try_fnames) + ) @verbose @@ -776,29 +828,32 @@ def _get_skull_surface(surf, subject, subjects_dir, bem=None, verbose=None): """ if bem is not None: try: - return _bem_find_surface(bem, surf + '_skull') + return 
_bem_find_surface(bem, surf + "_skull") except RuntimeError: - logger.info('Could not find the surface for ' - 'skull in the provided BEM model, ' - 'looking in the subject directory.') + logger.info( + "Could not find the surface for " + "skull in the provided BEM model, " + "looking in the subject directory." + ) subjects_dir = Path(get_subjects_dir(subjects_dir, raise_error=True)) fname = _check_fname( subjects_dir / subject / "bem" / (surf + "_skull.surf"), overwrite="read", must_exist=True, - name=f"{surf} skull surface" + name=f"{surf} skull surface", ) return _read_mri_surface(fname) def _estimate_talxfm_rigid(subject, subjects_dir): from .coreg import fit_matched_points, _trans_from_params + xfm = read_talxfm(subject, subjects_dir) # XYZ+origin + halfway pts_tal = np.concatenate([np.eye(4)[:, :3], np.eye(3) * 0.5]) pts_subj = apply_trans(invert_transform(xfm), pts_tal) # we fit with scaling enabled, but then discard it (we just need # the rigid-body components) - params = fit_matched_points(pts_subj, pts_tal, scale=3, out='params') + params = fit_matched_points(pts_subj, pts_tal, scale=3, out="params") rigid = _trans_from_params((True, True, False), params[:6]) return rigid diff --git a/mne/_ola.py b/mne/_ola.py index a4ecad26a66..df92f771bf6 100644 --- a/mne/_ola.py +++ b/mne/_ola.py @@ -10,6 +10,7 @@ ############################################################################### # Class for interpolation between adjacent points + class _Interp2: r"""Interpolate between two points. @@ -41,56 +42,62 @@ class _Interp2: """ - def __init__(self, control_points, values, interp='hann'): + def __init__(self, control_points, values, interp="hann"): # set up interpolation self.control_points = np.array(control_points, int).ravel() - if not np.array_equal(np.unique(self.control_points), - self.control_points): - raise ValueError('Control points must be sorted and unique') + if not np.array_equal(np.unique(self.control_points), self.control_points): + raise ValueError("Control points must be sorted and unique") if len(self.control_points) == 0: - raise ValueError('Must be at least one control point') + raise ValueError("Must be at least one control point") if not (self.control_points >= 0).all(): - raise ValueError('All control points must be positive (got %s)' - % (self.control_points[:3],)) + raise ValueError( + "All control points must be positive (got %s)" + % (self.control_points[:3],) + ) if isinstance(values, np.ndarray): values = [values] if isinstance(values, (list, tuple)): for v in values: if not (v is None or isinstance(v, np.ndarray)): - raise TypeError('All entries in "values" must be ndarray ' - 'or None, got %s' % (type(v),)) + raise TypeError( + 'All entries in "values" must be ndarray ' + "or None, got %s" % (type(v),) + ) if v is not None and v.shape[0] != len(self.control_points): - raise ValueError('Values, if provided, must be the same ' - 'length as the number of control points ' - '(%s), got %s' - % (len(self.control_points), v.shape[0])) + raise ValueError( + "Values, if provided, must be the same " + "length as the number of control points " + "(%s), got %s" % (len(self.control_points), v.shape[0]) + ) use_values = values def val(pt): idx = np.where(control_points == pt)[0][0] return [v[idx] if v is not None else None for v in use_values] + values = val self.values = values self.n_last = None self._position = 0 # start at zero self._left_idx = 0 self._left = self._right = self._use_interp = None - known_types = ('cos2', 'linear', 'zero', 'hann') + known_types = 
("cos2", "linear", "zero", "hann") if interp not in known_types: - raise ValueError('interp must be one of %s, got "%s"' - % (known_types, interp)) + raise ValueError( + 'interp must be one of %s, got "%s"' % (known_types, interp) + ) self._interp = interp def feed_generator(self, n_pts): """Feed data and get interpolators as a generator.""" self.n_last = 0 - n_pts = _ensure_int(n_pts, 'n_pts') + n_pts = _ensure_int(n_pts, "n_pts") original_position = self._position stop = self._position + n_pts - logger.debug('Feed %s (%s-%s)' % (n_pts, self._position, stop)) + logger.debug("Feed %s (%s-%s)" % (n_pts, self._position, stop)) used = np.zeros(n_pts, bool) if self._left is None: # first one - logger.debug(' Eval @ %s (%s)' % (0, self.control_points[0])) + logger.debug(" Eval @ %s (%s)" % (0, self.control_points[0])) self._left = self.values(self.control_points[0]) if len(self.control_points) == 1: self._right = self._left @@ -98,9 +105,8 @@ def feed_generator(self, n_pts): # Left zero-order hold condition if self._position < self.control_points[self._left_idx]: - n_use = min(self.control_points[self._left_idx] - self._position, - n_pts) - logger.debug(' Left ZOH %s' % n_use) + n_use = min(self.control_points[self._left_idx] - self._position, n_pts) + logger.debug(" Left ZOH %s" % n_use) this_sl = slice(None, n_use) assert used[this_sl].size == n_use assert not used[this_sl].any() @@ -125,35 +131,36 @@ def feed_generator(self, n_pts): self._left_idx += 1 self._use_interp = None # need to recreate it eval_pt = self.control_points[self._left_idx + 1] - logger.debug(' Eval @ %s (%s)' - % (self._left_idx + 1, eval_pt)) + logger.debug(" Eval @ %s (%s)" % (self._left_idx + 1, eval_pt)) self._right = self.values(eval_pt) assert self._right is not None left_point = self.control_points[self._left_idx] right_point = self.control_points[self._left_idx + 1] if self._use_interp is None: interp_span = right_point - left_point - if self._interp == 'zero': + if self._interp == "zero": self._use_interp = None - elif self._interp == 'linear': - self._use_interp = np.linspace(1., 0., interp_span, - endpoint=False) + elif self._interp == "linear": + self._use_interp = np.linspace( + 1.0, 0.0, interp_span, endpoint=False + ) else: # self._interp in ('cos2', 'hann'): self._use_interp = np.cos( - np.linspace(0, np.pi / 2., interp_span, - endpoint=False)) + np.linspace(0, np.pi / 2.0, interp_span, endpoint=False) + ) self._use_interp *= self._use_interp n_use = min(stop, right_point) - self._position if n_use > 0: - logger.debug(' Interp %s %s (%s-%s)' % (self._interp, n_use, - left_point, right_point)) + logger.debug( + " Interp %s %s (%s-%s)" + % (self._interp, n_use, left_point, right_point) + ) interp_start = self._position - left_point assert interp_start >= 0 if self._use_interp is None: this_interp = None else: - this_interp = \ - self._use_interp[interp_start:interp_start + n_use] + this_interp = self._use_interp[interp_start : interp_start + n_use] assert this_interp.size == n_use this_sl = slice(n_used, n_used + n_use) assert used[this_sl].size == n_use @@ -167,7 +174,7 @@ def feed_generator(self, n_pts): if self.control_points[self._left_idx] <= self._position: n_use = stop - self._position if n_use > 0: - logger.debug(' Right ZOH %s' % n_use) + logger.debug(" Right ZOH %s" % n_use) this_sl = slice(n_pts - n_use, None) assert not used[this_sl].any() used[this_sl] = True @@ -187,16 +194,18 @@ def feed(self, n_pts): out_arrays = None for o in self.feed_generator(n_pts): if out_arrays is None: - out_arrays = 
[np.empty(v.shape + (n_pts,)) - if v is not None else None for v in o[1]] + out_arrays = [ + np.empty(v.shape + (n_pts,)) if v is not None else None + for v in o[1] + ] for ai, arr in enumerate(out_arrays): if arr is not None: if o[3] is None: arr[..., o[0]] = o[1][ai][..., np.newaxis] else: - arr[..., o[0]] = ( - o[1][ai][..., np.newaxis] * o[3] + - o[2][ai][..., np.newaxis] * (1. - o[3])) + arr[..., o[0]] = o[1][ai][..., np.newaxis] * o[3] + o[2][ai][ + ..., np.newaxis + ] * (1.0 - o[3]) assert out_arrays is not None return out_arrays @@ -208,12 +217,12 @@ def feed(self, n_pts): def _check_store(store): if isinstance(store, np.ndarray): store = [store] - if isinstance(store, (list, tuple)) and all(isinstance(s, np.ndarray) - for s in store): + if isinstance(store, (list, tuple)) and all( + isinstance(s, np.ndarray) for s in store + ): store = _Storer(*store) if not callable(store): - raise TypeError('store must be callable, got type %s' - % (type(store),)) + raise TypeError("store must be callable, got type %s" % (type(store),)) return store @@ -261,28 +270,40 @@ class _COLA: """ @verbose - def __init__(self, process, store, n_total, n_samples, n_overlap, - sfreq, window='hann', tol=1e-10, *, verbose=None): + def __init__( + self, + process, + store, + n_total, + n_samples, + n_overlap, + sfreq, + window="hann", + tol=1e-10, + *, + verbose=None + ): from scipy.signal import get_window - n_samples = _ensure_int(n_samples, 'n_samples') - n_overlap = _ensure_int(n_overlap, 'n_overlap') - n_total = _ensure_int(n_total, 'n_total') + + n_samples = _ensure_int(n_samples, "n_samples") + n_overlap = _ensure_int(n_overlap, "n_overlap") + n_total = _ensure_int(n_total, "n_total") if n_samples <= 0: - raise ValueError('n_samples must be > 0, got %s' % (n_samples,)) + raise ValueError("n_samples must be > 0, got %s" % (n_samples,)) if n_overlap < 0: - raise ValueError('n_overlap must be >= 0, got %s' % (n_overlap,)) + raise ValueError("n_overlap must be >= 0, got %s" % (n_overlap,)) if n_total < 0: - raise ValueError('n_total must be >= 0, got %s' % (n_total,)) + raise ValueError("n_total must be >= 0, got %s" % (n_total,)) self._n_samples = int(n_samples) self._n_overlap = int(n_overlap) del n_samples, n_overlap if n_total < self._n_samples: - raise ValueError('Number of samples per window (%d) must be at ' - 'most the total number of samples (%s)' - % (self._n_samples, n_total)) + raise ValueError( + "Number of samples per window (%d) must be at " + "most the total number of samples (%s)" % (self._n_samples, n_total) + ) if not callable(process): - raise TypeError('process must be callable, got type %s' - % (type(process),)) + raise TypeError("process must be callable, got type %s" % (type(process),)) self._process = process self._step = self._n_samples - self._n_overlap self._store = _check_store(store) @@ -290,25 +311,36 @@ def __init__(self, process, store, n_total, n_samples, n_overlap, self._in_buffers = self._out_buffers = None # Create our window boundaries - window_name = window if isinstance(window, str) else 'custom' - self._window = get_window(window, self._n_samples, - fftbins=(self._n_samples - 1) % 2) - self._window /= _check_cola(self._window, self._n_samples, self._step, - window_name, tol=tol) + window_name = window if isinstance(window, str) else "custom" + self._window = get_window( + window, self._n_samples, fftbins=(self._n_samples - 1) % 2 + ) + self._window /= _check_cola( + self._window, self._n_samples, self._step, window_name, tol=tol + ) self.starts = np.arange(0, n_total 
- self._n_samples + 1, self._step)
         self.stops = self.starts + self._n_samples
         delta = n_total - self.stops[-1]
         self.stops[-1] = n_total
         sfreq = float(sfreq)
-        pl = 's' if len(self.starts) != 1 else ''
-        logger.info('    Processing %4d data chunk%s of (at least) %0.1f s '
-                    'with %0.1f s overlap and %s windowing'
-                    % (len(self.starts), pl, self._n_samples / sfreq,
-                       self._n_overlap / sfreq, window_name))
+        pl = "s" if len(self.starts) != 1 else ""
+        logger.info(
+            "    Processing %4d data chunk%s of (at least) %0.1f s "
+            "with %0.1f s overlap and %s windowing"
+            % (
+                len(self.starts),
+                pl,
+                self._n_samples / sfreq,
+                self._n_overlap / sfreq,
+                window_name,
+            )
+        )
         del window, window_name
         if delta > 0:
-            logger.info('    The final %0.3f s will be lumped into the '
-                        'final window' % (delta / sfreq,))
+            logger.info(
+                "    The final %0.3f s will be lumped into the "
+                "final window" % (delta / sfreq,)
+            )

     @property
     def _in_offset(self):
@@ -322,65 +354,79 @@ def feed(self, *datas, verbose=None, **kwargs):
         if self._in_buffers is None:
             self._in_buffers = [None] * len(datas)
         if len(datas) != len(self._in_buffers):
-            raise ValueError('Got %d array(s), needed %d'
-                             % (len(datas), len(self._in_buffers)))
+            raise ValueError(
+                "Got %d array(s), needed %d" % (len(datas), len(self._in_buffers))
+            )
         for di, data in enumerate(datas):
             if not isinstance(data, np.ndarray) or data.ndim < 1:
-                raise TypeError('data entry %d must be an 2D ndarray, got %s'
-                                % (di, type(data),))
+                raise TypeError(
+                    "data entry %d must be a 2D ndarray, got %s"
+                    % (
+                        di,
+                        type(data),
+                    )
+                )
             if self._in_buffers[di] is None:
                 # In practice, users can give large chunks, so we use
                 # dynamic allocation of the in buffer. We could save some
                 # memory allocation by only ever processing max_len at once,
                 # but this would increase code complexity.
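The window bookkeeping set up in ``__init__`` above relies on the constant overlap-add (COLA) property that ``_check_cola`` later verifies; a quick standalone cross-check with SciPy (window length and overlap are illustrative values):

    # Hypothetical sketch: a periodic Hann window with 50% overlap satisfies
    # the COLA constraint, i.e. the overlapped windows sum to a constant.
    from scipy.signal import check_COLA, get_window

    n_samples, n_overlap = 256, 128
    win = get_window("hann", n_samples, fftbins=bool((n_samples - 1) % 2))
    assert check_COLA(win, n_samples, n_overlap)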
-                self._in_buffers[di] = np.empty(
-                    data.shape[:-1] + (0,), data.dtype)
-            if data.shape[:-1] != self._in_buffers[di].shape[:-1] or \
-                    self._in_buffers[di].dtype != data.dtype:
-                raise TypeError('data must dtype %s and shape[:-1]==%s, '
-                                'got dtype %s shape[:-1]=%s'
-                                % (self._in_buffers[di].dtype,
-                                   self._in_buffers[di].shape[:-1],
-                                   data.dtype, data.shape[:-1]))
-            logger.debug('    + Appending %d->%d'
-                         % (self._in_offset, self._in_offset + data.shape[-1]))
-            self._in_buffers[di] = np.concatenate(
-                [self._in_buffers[di], data], -1)
+                self._in_buffers[di] = np.empty(data.shape[:-1] + (0,), data.dtype)
+            if (
+                data.shape[:-1] != self._in_buffers[di].shape[:-1]
+                or self._in_buffers[di].dtype != data.dtype
+            ):
+                raise TypeError(
+                    "data must have dtype %s and shape[:-1]==%s, "
+                    "got dtype %s shape[:-1]=%s"
+                    % (
+                        self._in_buffers[di].dtype,
+                        self._in_buffers[di].shape[:-1],
+                        data.dtype,
+                        data.shape[:-1],
+                    )
+                )
+            logger.debug(
+                "    + Appending %d->%d"
+                % (self._in_offset, self._in_offset + data.shape[-1])
+            )
+            self._in_buffers[di] = np.concatenate([self._in_buffers[di], data], -1)
             if self._in_offset > self.stops[-1]:
-                raise ValueError('data (shape %s) exceeded expected total '
-                                 'buffer size (%s > %s)'
-                                 % (data.shape, self._in_offset,
-                                    self.stops[-1]))
+                raise ValueError(
+                    "data (shape %s) exceeded expected total "
+                    "buffer size (%s > %s)"
+                    % (data.shape, self._in_offset, self.stops[-1])
+                )

         # Check to see if we can process the next chunk and dump outputs
-        while self._idx < len(self.starts) and \
-                self._in_offset >= self.stops[self._idx]:
+        while self._idx < len(self.starts) and self._in_offset >= self.stops[self._idx]:
             start, stop = self.starts[self._idx], self.stops[self._idx]
             this_len = stop - start
             this_window = self._window.copy()
             if self._idx == len(self.starts) - 1:
                 this_window = np.pad(
-                    self._window, (0, this_len - len(this_window)), 'constant')
+                    self._window, (0, this_len - len(this_window)), "constant"
+                )
             for offset in range(self._step, len(this_window), self._step):
                 n_use = len(this_window) - offset
                 this_window[offset:] += self._window[:n_use]
             if self._idx == 0:
-                for offset in range(self._n_samples - self._step, 0,
-                                    -self._step):
+                for offset in range(self._n_samples - self._step, 0, -self._step):
                     this_window[:offset] += self._window[-offset:]
-            logger.debug('    * Processing %d->%d' % (start, stop))
-            this_proc = [in_[..., :this_len].copy()
-                         for in_ in self._in_buffers]
-            if not all(proc.shape[-1] == this_len == this_window.size
-                       for proc in this_proc):
-                raise RuntimeError('internal indexing error')
+            logger.debug("    * Processing %d->%d" % (start, stop))
+            this_proc = [in_[..., :this_len].copy() for in_ in self._in_buffers]
+            if not all(
+                proc.shape[-1] == this_len == this_window.size for proc in this_proc
+            ):
+                raise RuntimeError("internal indexing error")
             outs = self._process(*this_proc, **kwargs)
             if self._out_buffers is None:
                 max_len = np.max(self.stops - self.starts)
-                self._out_buffers = [np.zeros(o.shape[:-1] + (max_len,),
-                                              o.dtype) for o in outs]
+                self._out_buffers = [
+                    np.zeros(o.shape[:-1] + (max_len,), o.dtype) for o in outs
+                ]
             for oi, out in enumerate(outs):
                 out *= this_window
-                self._out_buffers[oi][..., :stop - start] += out
+                self._out_buffers[oi][..., : stop - start] += out
             self._idx += 1
             if self._idx < len(self.starts):
                 next_start = self.starts[self._idx]
@@ -389,29 +435,29 @@ def feed(self, *datas, verbose=None, **kwargs):
             delta = next_start - self.starts[self._idx - 1]
         for di in range(len(self._in_buffers)):
             self._in_buffers[di] = 
self._in_buffers[di][..., delta:] - logger.debug(' - Shifting input/output buffers by %d samples' - % (delta,)) + logger.debug(" - Shifting input/output buffers by %d samples" % (delta,)) self._store(*[o[..., :delta] for o in self._out_buffers]) for ob in self._out_buffers: ob[..., :-delta] = ob[..., delta:] - ob[..., -delta:] = 0. + ob[..., -delta:] = 0.0 def _check_cola(win, nperseg, step, window_name, tol=1e-10): """Check whether the Constant OverLap Add (COLA) constraint is met.""" # adapted from SciPy - binsums = np.sum([win[ii * step:(ii + 1) * step] - for ii in range(nperseg // step)], axis=0) + binsums = np.sum( + [win[ii * step : (ii + 1) * step] for ii in range(nperseg // step)], axis=0 + ) if nperseg % step != 0: - binsums[:nperseg % step] += win[-(nperseg % step):] + binsums[: nperseg % step] += win[-(nperseg % step) :] const = np.median(binsums) deviation = np.max(np.abs(binsums - const)) if deviation > tol: - raise ValueError('segment length %d with step %d for %s window ' - 'type does not provide a constant output ' - '(%g%% deviation)' - % (nperseg, step, window_name, - 100 * deviation / const)) + raise ValueError( + "segment length %d with step %d for %s window " + "type does not provide a constant output " + "(%g%% deviation)" % (nperseg, step, window_name, 100 * deviation / const) + ) return const @@ -421,16 +467,16 @@ class _Storer: def __init__(self, *outs, picks=None): for oi, out in enumerate(outs): if not isinstance(out, np.ndarray) or out.ndim < 1: - raise TypeError('outs[oi] must be >= 1D ndarray, got %s' - % (out,)) + raise TypeError("outs[oi] must be >= 1D ndarray, got %s" % (out,)) self.outs = outs self.idx = 0 self.picks = picks def __call__(self, *outs): - if (len(outs) != len(self.outs) or - not all(out.shape[-1] == outs[0].shape[-1] for out in outs)): - raise ValueError('Bad outs') + if len(outs) != len(self.outs) or not all( + out.shape[-1] == outs[0].shape[-1] for out in outs + ): + raise ValueError("Bad outs") idx = (Ellipsis,) if self.picks is not None: idx += (self.picks,) diff --git a/mne/annotations.py b/mne/annotations.py index 00de96d32e4..9df381bfc87 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -16,15 +16,38 @@ from textwrap import shorten import numpy as np -from .utils import (_pl, check_fname, _validate_type, verbose, warn, logger, - _check_pandas_installed, _mask_to_onsets_offsets, - _DefaultEventParser, _check_dt, _stamp_to_dt, _dt_to_stamp, - _check_fname, int_like, _check_option, fill_doc, - _on_missing, _is_numeric, _check_dict_keys) - -from .io.write import (start_block, end_block, write_float, - write_name_list_sanitized, _safe_name_list, - write_double, start_file, write_string) +from .utils import ( + _pl, + check_fname, + _validate_type, + verbose, + warn, + logger, + _check_pandas_installed, + _mask_to_onsets_offsets, + _DefaultEventParser, + _check_dt, + _stamp_to_dt, + _dt_to_stamp, + _check_fname, + int_like, + _check_option, + fill_doc, + _on_missing, + _is_numeric, + _check_dict_keys, +) + +from .io.write import ( + start_block, + end_block, + write_float, + write_name_list_sanitized, + _safe_name_list, + write_double, + start_file, + write_string, +) from .io.constants import FIFF from .io.open import fiff_open from .io.tree import dir_tree_find @@ -38,41 +61,46 @@ def _check_o_d_s_c(onset, duration, description, ch_names): onset = np.atleast_1d(np.array(onset, dtype=float)) if onset.ndim != 1: - raise ValueError('Onset must be a one dimensional array, got %s ' - '(shape %s).' 
-                         % (onset.ndim, onset.shape))
+        raise ValueError(
+            "Onset must be a one dimensional array, got %s "
+            "(shape %s)." % (onset.ndim, onset.shape)
+        )
     duration = np.array(duration, dtype=float)
     if duration.ndim == 0 or duration.shape == (1,):
         duration = np.repeat(duration, len(onset))
     if duration.ndim != 1:
-        raise ValueError('Duration must be a one dimensional array, '
-                         'got %d.' % (duration.ndim,))
+        raise ValueError(
+            "Duration must be a one dimensional array, " "got %d." % (duration.ndim,)
+        )
     description = np.array(description, dtype=str)
     if description.ndim == 0 or description.shape == (1,):
         description = np.repeat(description, len(onset))
     if description.ndim != 1:
-        raise ValueError('Description must be a one dimensional array, '
-                         'got %d.' % (description.ndim,))
-    _safe_name_list(description, 'write', 'description')
+        raise ValueError(
+            "Description must be a one dimensional array, "
+            "got %d." % (description.ndim,)
+        )
+    _safe_name_list(description, "write", "description")

     # ch_names: convert to ndarray of tuples
-    _validate_type(ch_names, (None, tuple, list, np.ndarray), 'ch_names')
+    _validate_type(ch_names, (None, tuple, list, np.ndarray), "ch_names")
     if ch_names is None:
         ch_names = [()] * len(onset)
     ch_names = list(ch_names)
     for ai, ch in enumerate(ch_names):
-        _validate_type(ch, (list, tuple, np.ndarray), f'ch_names[{ai}]')
+        _validate_type(ch, (list, tuple, np.ndarray), f"ch_names[{ai}]")
         ch_names[ai] = tuple(ch)
         for ci, name in enumerate(ch_names[ai]):
-            _validate_type(name, str, f'ch_names[{ai}][{ci}]')
+            _validate_type(name, str, f"ch_names[{ai}][{ci}]")
     ch_names = _ndarray_ch_names(ch_names)

     if not (len(onset) == len(duration) == len(description) == len(ch_names)):
         raise ValueError(
-            'Onset, duration, description, and ch_names must be '
-            f'equal in sizes, got {len(onset)}, {len(duration)}, '
-            f'{len(description)}, and {len(ch_names)}.')
+            "Onset, duration, description, and ch_names must be "
+            f"equal in sizes, got {len(onset)}, {len(duration)}, "
+            f"{len(description)}, and {len(ch_names)}."
+        )

     return onset, duration, description, ch_names

@@ -247,11 +275,13 @@ class Annotations:
         :meth:`Raw.save() <mne.io.Raw.save>` notes for details.
""" # noqa: E501 - def __init__(self, onset, duration, description, - orig_time=None, ch_names=None): # noqa: D102 + def __init__( + self, onset, duration, description, orig_time=None, ch_names=None + ): # noqa: D102 self._orig_time = _handle_meas_date(orig_time) - self.onset, self.duration, self.description, self.ch_names = \ - _check_o_d_s_c(onset, duration, description, ch_names) + self.onset, self.duration, self.description, self.ch_names = _check_o_d_s_c( + onset, duration, description, ch_names + ) self._sort() # ensure we're sorted @property @@ -263,21 +293,27 @@ def __eq__(self, other): """Compare to another Annotations instance.""" if not isinstance(other, Annotations): return False - return (np.array_equal(self.onset, other.onset) and - np.array_equal(self.duration, other.duration) and - np.array_equal(self.description, other.description) and - np.array_equal(self.ch_names, other.ch_names) and - self.orig_time == other.orig_time) + return ( + np.array_equal(self.onset, other.onset) + and np.array_equal(self.duration, other.duration) + and np.array_equal(self.description, other.description) + and np.array_equal(self.ch_names, other.ch_names) + and self.orig_time == other.orig_time + ) def __repr__(self): """Show the representation.""" counter = Counter(self.description) - kinds = ', '.join(['%s (%s)' % k for k in sorted(counter.items())]) - kinds = (': ' if len(kinds) > 0 else '') + kinds - ch_specific = ', channel-specific' if self._any_ch_names() else '' - s = ('Annotations | %s segment%s%s%s' % - (len(self.onset), _pl(len(self.onset)), ch_specific, kinds)) - return '<' + shorten(s, width=77, placeholder=' ...') + '>' + kinds = ", ".join(["%s (%s)" % k for k in sorted(counter.items())]) + kinds = (": " if len(kinds) > 0 else "") + kinds + ch_specific = ", channel-specific" if self._any_ch_names() else "" + s = "Annotations | %s segment%s%s%s" % ( + len(self.onset), + _pl(len(self.onset)), + ch_specific, + kinds, + ) + return "<" + shorten(s, width=77, placeholder=" ...") + ">" def __len__(self): """Return the number of annotations. 
@@ -303,12 +339,14 @@ def __iadd__(self, other): if len(self) == 0: self._orig_time = other.orig_time if self.orig_time != other.orig_time: - raise ValueError("orig_time should be the same to " - "add/concatenate 2 annotations " - "(got %s != %s)" % (self.orig_time, - other.orig_time)) - return self.append(other.onset, other.duration, other.description, - other.ch_names) + raise ValueError( + "orig_time should be the same to " + "add/concatenate 2 annotations " + "(got %s != %s)" % (self.orig_time, other.orig_time) + ) + return self.append( + other.onset, other.duration, other.description, other.ch_names + ) def __iter__(self): """Iterate over the annotations.""" @@ -321,21 +359,26 @@ def __iter__(self): def __getitem__(self, key, *, with_ch_names=None): """Propagate indexing and slicing to the underlying numpy structure.""" if isinstance(key, int_like): - out_keys = ('onset', 'duration', 'description', 'orig_time') - out_vals = (self.onset[key], self.duration[key], - self.description[key], self.orig_time) - if with_ch_names or (with_ch_names is None and - self._any_ch_names()): - out_keys += ('ch_names',) + out_keys = ("onset", "duration", "description", "orig_time") + out_vals = ( + self.onset[key], + self.duration[key], + self.description[key], + self.orig_time, + ) + if with_ch_names or (with_ch_names is None and self._any_ch_names()): + out_keys += ("ch_names",) out_vals += (self.ch_names[key],) return OrderedDict(zip(out_keys, out_vals)) else: key = list(key) if isinstance(key, tuple) else key - return Annotations(onset=self.onset[key], - duration=self.duration[key], - description=self.description[key], - orig_time=self.orig_time, - ch_names=self.ch_names[key]) + return Annotations( + onset=self.onset[key], + duration=self.duration[key], + description=self.description[key], + orig_time=self.orig_time, + ch_names=self.ch_names[key], + ) @fill_doc def append(self, onset, duration, description, ch_names=None): @@ -367,7 +410,8 @@ def append(self, onset, duration, description, ch_names=None): `list.extend `__. 
""" # noqa: E501 onset, duration, description, ch_names = _check_o_d_s_c( - onset, duration, description, ch_names) + onset, duration, description, ch_names + ) self.onset = np.append(self.onset, onset) self.duration = np.append(self.duration, duration) self.description = np.append(self.description, description) @@ -415,8 +459,7 @@ def to_data_frame(self): dt = _handle_meas_date(0) dt = dt.replace(tzinfo=None) onsets_dt = [dt + timedelta(seconds=o) for o in self.onset] - df = dict(onset=onsets_dt, duration=self.duration, - description=self.description) + df = dict(onset=onsets_dt, duration=self.duration, description=self.description) if self._any_ch_names(): df.update(ch_names=self.ch_names) df = pd.DataFrame(df) @@ -428,7 +471,7 @@ def _any_ch_names(self): def _prune_ch_names(self, info, on_missing): # this prunes channel names and if a given channel-specific annotation # no longer has any channels left, it gets dropped - keep = set(info['ch_names']) + keep = set(info["ch_names"]) ch_names = self.ch_names warned = False drop_idx = list() @@ -439,8 +482,10 @@ def _prune_ch_names(self, info, on_missing): if name not in keep: if not warned: _on_missing( - on_missing, 'At least one channel name in ' - f'annotations missing from info: {name}') + on_missing, + "At least one channel name in " + f"annotations missing from info: {name}", + ) warned = True else: names.append(name) @@ -477,9 +522,18 @@ def save(self, fname, *, overwrite=False, verbose=None): whereas :file:`.txt` files store onset as seconds since start of the recording (e.g., ``45.95597082905339``). """ - check_fname(fname, 'annotations', ('-annot.fif', '-annot.fif.gz', - '_annot.fif', '_annot.fif.gz', - '.txt', '.csv')) + check_fname( + fname, + "annotations", + ( + "-annot.fif", + "-annot.fif.gz", + "_annot.fif", + "_annot.fif.gz", + ".txt", + ".csv", + ), + ) fname = _check_fname(fname, overwrite=overwrite) if fname.suffix == ".txt": _write_annotations_txt(fname, self) @@ -501,8 +555,9 @@ def _sort(self): self.ch_names = self.ch_names[order] @verbose - def crop(self, tmin=None, tmax=None, emit_warning=False, - use_orig_time=True, verbose=None): + def crop( + self, tmin=None, tmax=None, emit_warning=False, use_orig_time=True, verbose=None + ): """Remove all annotation that are outside of [tmin, tmax]. The method operates inplace. @@ -535,39 +590,42 @@ def crop(self, tmin=None, tmax=None, emit_warning=False, if tmin is None: tmin = timedelta(seconds=self.onset.min()) + offset if tmax is None: - tmax = timedelta( - seconds=(self.onset + self.duration).max()) + offset - for key, val in [('tmin', tmin), ('tmax', tmax)]: - _validate_type(val, ('numeric', _datetime), key, - 'numeric, datetime, or None') + tmax = timedelta(seconds=(self.onset + self.duration).max()) + offset + for key, val in [("tmin", tmin), ("tmax", tmax)]: + _validate_type( + val, ("numeric", _datetime), key, "numeric, datetime, or None" + ) absolute_tmin = _handle_meas_date(tmin) absolute_tmax = _handle_meas_date(tmax) del tmin, tmax if absolute_tmin > absolute_tmax: - raise ValueError('tmax should be greater than or equal to tmin ' - '(%s < %s).' % (absolute_tmin, absolute_tmax)) - logger.debug('Cropping annotations %s - %s' % (absolute_tmin, - absolute_tmax)) + raise ValueError( + "tmax should be greater than or equal to tmin " + "(%s < %s)." 
% (absolute_tmin, absolute_tmax) + ) + logger.debug("Cropping annotations %s - %s" % (absolute_tmin, absolute_tmax)) onsets, durations, descriptions, ch_names = [], [], [], [] out_of_bounds, clip_left_elem, clip_right_elem = [], [], [] - for idx, (onset, duration, description, ch) in enumerate(zip( - self.onset, self.duration, self.description, self.ch_names)): + for idx, (onset, duration, description, ch) in enumerate( + zip(self.onset, self.duration, self.description, self.ch_names) + ): # if duration is NaN behave like a zero if np.isnan(duration): - duration = 0. + duration = 0.0 # convert to absolute times absolute_onset = timedelta(seconds=onset) + offset absolute_offset = absolute_onset + timedelta(seconds=duration) out_of_bounds.append( - absolute_onset > absolute_tmax or - absolute_offset < absolute_tmin) + absolute_onset > absolute_tmax or absolute_offset < absolute_tmin + ) if out_of_bounds[-1]: clip_left_elem.append(False) clip_right_elem.append(False) logger.debug( - f' [{idx}] Dropping ' - f'({absolute_onset} - {absolute_offset}: {description})') + f" [{idx}] Dropping " + f"({absolute_onset} - {absolute_offset}: {description})" + ) else: # clip the left side clip_left_elem.append(absolute_onset < absolute_tmin) @@ -577,19 +635,18 @@ def crop(self, tmin=None, tmax=None, emit_warning=False, if clip_right_elem[-1]: absolute_offset = absolute_tmax if clip_left_elem[-1] or clip_right_elem[-1]: - durations.append( - (absolute_offset - absolute_onset).total_seconds()) + durations.append((absolute_offset - absolute_onset).total_seconds()) else: durations.append(duration) - onsets.append( - (absolute_onset - offset).total_seconds()) + onsets.append((absolute_onset - offset).total_seconds()) logger.debug( - f' [{idx}] Keeping ' - f'({absolute_onset} - {absolute_offset} -> ' - f'{onset} - {onset + duration})') + f" [{idx}] Keeping " + f"({absolute_onset} - {absolute_offset} -> " + f"{onset} - {onset + duration})" + ) descriptions.append(description) ch_names.append(ch) - logger.debug(f'Cropping complete (kept {len(onsets)})') + logger.debug(f"Cropping complete (kept {len(onsets)})") self.onset = np.array(onsets, float) self.duration = np.array(durations, float) assert (self.duration >= 0).all() @@ -599,13 +656,16 @@ def crop(self, tmin=None, tmax=None, emit_warning=False, if emit_warning: omitted = np.array(out_of_bounds).sum() if omitted > 0: - warn('Omitted %s annotation(s) that were outside data' - ' range.' % omitted) - limited = (np.array(clip_left_elem) | - np.array(clip_right_elem)).sum() + warn( + "Omitted %s annotation(s) that were outside data" + " range." % omitted + ) + limited = (np.array(clip_left_elem) | np.array(clip_right_elem)).sum() if limited > 0: - warn('Limited %s annotation(s) that were expanding outside the' - ' data range.' % limited) + warn( + "Limited %s annotation(s) that were expanding outside the" + " data range." 
% limited + ) return self @@ -634,9 +694,12 @@ def set_durations(self, mapping, verbose=None): _validate_type(mapping, (int, float, dict)) if isinstance(mapping, dict): - _check_dict_keys(mapping, self.description, - valid_key_source="data", - key_description="Annotation description(s)") + _check_dict_keys( + mapping, + self.description, + valid_key_source="data", + key_description="Annotation description(s)", + ) for stim in mapping: map_idx = [desc == stim for desc in self.description] self.duration[map_idx] = mapping[stim] @@ -645,9 +708,11 @@ def set_durations(self, mapping, verbose=None): self.duration = np.ones(self.description.shape) * mapping else: - raise ValueError("Setting durations requires the mapping of " - "descriptions to times to be provided as a dict. " - f"Instead {type(mapping)} was provided.") + raise ValueError( + "Setting durations requires the mapping of " + "descriptions to times to be provided as a dict. " + f"Instead {type(mapping)} was provided." + ) return self @@ -672,10 +737,13 @@ def rename(self, mapping, verbose=None): .. versionadded:: 0.24.0 """ _validate_type(mapping, dict) - _check_dict_keys(mapping, self.description, valid_key_source="data", - key_description="Annotation description(s)") - self.description = np.array( - [str(mapping.get(d, d)) for d in self.description]) + _check_dict_keys( + mapping, + self.description, + valid_key_source="data", + key_description="Annotation description(s)", + ) + self.description = np.array([str(mapping.get(d, d)) for d in self.description]) return self @@ -687,8 +755,7 @@ def annotations(self): # noqa: D102 return self._annotations @verbose - def set_annotations(self, annotations, on_missing='raise', *, - verbose=None): + def set_annotations(self, annotations, on_missing="raise", *, verbose=None): """Setter for Epoch annotations from Raw. This method does not handle offsetting the times based @@ -728,16 +795,18 @@ def set_annotations(self, annotations, on_missing='raise', *, .. versionadded:: 1.0 """ - _validate_type(annotations, (Annotations, None), 'annotations') + _validate_type(annotations, (Annotations, None), "annotations") if annotations is None: self._annotations = None else: - if getattr(self, '_unsafe_annot_add', False): - warn('Adding annotations to Epochs created (and saved to ' - 'disk) before 1.0 will yield incorrect results if ' - 'decimation or resampling was performed on the instance, ' - 'we recommend regenerating the Epochs and re-saving them ' - 'to disk') + if getattr(self, "_unsafe_annot_add", False): + warn( + "Adding annotations to Epochs created (and saved to " + "disk) before 1.0 will yield incorrect results if " + "decimation or resampling was performed on the instance, " + "we recommend regenerating the Epochs and re-saving them " + "to disk" + ) new_annotations = annotations.copy() new_annotations._prune_ch_names(self.info, on_missing) self._annotations = new_annotations @@ -766,8 +835,9 @@ def get_annotations_per_epoch(self): # when each epoch and annotation starts/stops # no need to account for first_samp here... epoch_tzeros = self.events[:, 0] / self._raw_sfreq - epoch_starts, epoch_stops = np.atleast_2d( - epoch_tzeros) + np.atleast_2d(self.times[[0, -1]]).T + epoch_starts, epoch_stops = ( + np.atleast_2d(epoch_tzeros) + np.atleast_2d(self.times[[0, -1]]).T + ) # ... 
because first_samp isn't accounted for here either annot_starts = self._annotations.onset annot_stops = annot_starts + self._annotations.duration @@ -779,33 +849,40 @@ def get_annotations_per_epoch(self): # we care about is presence/absence of overlap). annot_straddles_epoch_start = np.logical_and( np.atleast_2d(epoch_starts) >= np.atleast_2d(annot_starts).T, - np.atleast_2d(epoch_starts) < np.atleast_2d(annot_stops).T) + np.atleast_2d(epoch_starts) < np.atleast_2d(annot_stops).T, + ) annot_straddles_epoch_end = np.logical_and( np.atleast_2d(epoch_stops) > np.atleast_2d(annot_starts).T, - np.atleast_2d(epoch_stops) <= np.atleast_2d(annot_stops).T) + np.atleast_2d(epoch_stops) <= np.atleast_2d(annot_stops).T, + ) # this captures the only remaining case we care about: annotations # fully contained within an epoch (or exactly coextensive with it). annot_fully_within_epoch = np.logical_and( np.atleast_2d(epoch_starts) <= np.atleast_2d(annot_starts).T, - np.atleast_2d(epoch_stops) >= np.atleast_2d(annot_stops).T) + np.atleast_2d(epoch_stops) >= np.atleast_2d(annot_stops).T, + ) # combine all cases to get array of shape (n_annotations, n_epochs). # Nonzero entries indicate overlap between the corresponding # annotation (row index) and epoch (column index). - all_cases = (annot_straddles_epoch_start + - annot_straddles_epoch_end + - annot_fully_within_epoch) + all_cases = ( + annot_straddles_epoch_start + + annot_straddles_epoch_end + + annot_fully_within_epoch + ) # for each Epoch-Annotation overlap occurrence: for annot_ix, epo_ix in zip(*np.nonzero(all_cases)): this_annot = self._annotations[annot_ix] this_tzero = epoch_tzeros[epo_ix] # adjust annotation onset to be relative to epoch tzero... - annot = (this_annot['onset'] - this_tzero, - this_annot['duration'], - this_annot['description']) + annot = ( + this_annot["onset"] - this_tzero, + this_annot["duration"], + this_annot["description"], + ) # ...then add it to the correct sublist of `epoch_annot_list` epoch_annot_list[epo_ix].append(annot) return epoch_annot_list @@ -841,8 +918,10 @@ def add_annotations_to_metadata(self, overwrite=False): # check if annotations exist if self.annotations is None: - warn(f'There were no Annotations stored in {self}, so ' - 'metadata was not modified.') + warn( + f"There were no Annotations stored in {self}, so " + "metadata was not modified." + ) return self # get existing metadata DataFrame or instantiate an empty one @@ -852,12 +931,17 @@ def add_annotations_to_metadata(self, overwrite=False): data = np.empty((len(self.events), 0)) metadata = pd.DataFrame(data=data) - if any(name in metadata.columns for name in - ['annot_onset', 'annot_duration', 'annot_description']) and \ - not overwrite: + if ( + any( + name in metadata.columns + for name in ["annot_onset", "annot_duration", "annot_description"] + ) + and not overwrite + ): raise RuntimeError( - 'Metadata for Epochs already contains columns ' - '"annot_onset", "annot_duration", or "annot_description".') + "Metadata for Epochs already contains columns " + '"annot_onset", "annot_duration", or "annot_description".' + ) # get the Epoch annotations, then convert to separate lists for # onsets, durations, and descriptions @@ -875,17 +959,18 @@ def add_annotations_to_metadata(self, overwrite=False): # Create a new Annotations column that is instantiated as an empty # list per Epoch. 
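To make the epoch/annotation overlap logic above concrete, a self-contained sketch with synthetic data (all names and values are illustrative, not part of the patch):

    # Hypothetical sketch: a channel-specific annotation overlapping the first
    # epoch is reported by get_annotations_per_epoch() with its onset shifted
    # to be relative to that epoch's time zero.
    import numpy as np
    import mne

    info = mne.create_info(["EEG 001", "EEG 002"], sfreq=100.0, ch_types="eeg")
    raw = mne.io.RawArray(np.zeros((2, 1000)), info)  # 10 s of silence
    raw.set_annotations(
        mne.Annotations(onset=[1.0], duration=[0.5],
                        description=["blink"], ch_names=[["EEG 001"]])
    )
    events = mne.make_fixed_length_events(raw, duration=2.0)
    epochs = mne.Epochs(raw, events, tmin=0.0, tmax=1.99, baseline=None, preload=True)
    print(epochs.get_annotations_per_epoch()[0])  # [(1.0, 0.5, 'blink')]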
- metadata['annot_onset'] = pd.Series(onset) - metadata['annot_duration'] = pd.Series(duration) - metadata['annot_description'] = pd.Series(description) + metadata["annot_onset"] = pd.Series(onset) + metadata["annot_duration"] = pd.Series(duration) + metadata["annot_description"] = pd.Series(description) # reset the metadata self.metadata = metadata return self -def _combine_annotations(one, two, one_n_samples, one_first_samp, - two_first_samp, sfreq): +def _combine_annotations( + one, two, one_n_samples, one_first_samp, two_first_samp, sfreq +): """Combine a tuple of annotations.""" assert one is not None assert two is not None @@ -909,7 +994,7 @@ def _handle_meas_date(meas_date): time. """ if isinstance(meas_date, str): - ACCEPTED_ISO8601 = '%Y-%m-%d %H:%M:%S.%f' + ACCEPTED_ISO8601 = "%Y-%m-%d %H:%M:%S.%f" try: meas_date = datetime.strptime(meas_date, ACCEPTED_ISO8601) except ValueError: @@ -937,13 +1022,12 @@ def _handle_meas_date(meas_date): def _sync_onset(raw, onset, inverse=False): """Adjust onsets in relation to raw data.""" offset = (-1 if inverse else 1) * raw._first_time - assert raw.info['meas_date'] == raw.annotations.orig_time + assert raw.info["meas_date"] == raw.annotations.orig_time annot_start = onset - offset return annot_start -def _annotations_starts_stops(raw, kinds, name='skip_by_annotation', - invert=False): +def _annotations_starts_stops(raw, kinds, name="skip_by_annotation", invert=False): """Get starts and stops from given kinds. onsets and ends are inclusive. @@ -953,14 +1037,16 @@ def _annotations_starts_stops(raw, kinds, name='skip_by_annotation', kinds = [kinds] else: for kind in kinds: - _validate_type(kind, 'str', "All entries") + _validate_type(kind, "str", "All entries") if len(raw.annotations) == 0: onsets, ends = np.array([], int), np.array([], int) else: - idxs = [idx for idx, desc in enumerate(raw.annotations.description) - if any(desc.upper().startswith(kind.upper()) - for kind in kinds)] + idxs = [ + idx + for idx, desc in enumerate(raw.annotations.description) + if any(desc.upper().startswith(kind.upper()) for kind in kinds) + ] # onsets are already sorted onsets = raw.annotations.onset[idxs] onsets = _sync_onset(raw, onsets) @@ -975,7 +1061,7 @@ def _annotations_starts_stops(raw, kinds, name='skip_by_annotation', for onset, end in zip(onsets, ends): mask[onset:end] = True mask = ~mask - extras = (onsets == ends) + extras = onsets == ends extra_onsets, extra_ends = onsets[extras], ends[extras] onsets, ends = _mask_to_onsets_offsets(mask) # Keep ones where things were exactly equal @@ -992,25 +1078,28 @@ def _write_annotations(fid, annotations): """Write annotations.""" start_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) write_float(fid, FIFF.FIFF_MNE_BASELINE_MIN, annotations.onset) - write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX, - annotations.duration + annotations.onset) + write_float( + fid, FIFF.FIFF_MNE_BASELINE_MAX, annotations.duration + annotations.onset + ) write_name_list_sanitized( - fid, FIFF.FIFF_COMMENT, annotations.description, name='description') + fid, FIFF.FIFF_COMMENT, annotations.description, name="description" + ) if annotations.orig_time is not None: - write_double(fid, FIFF.FIFF_MEAS_DATE, - _dt_to_stamp(annotations.orig_time)) + write_double(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(annotations.orig_time)) if annotations._any_ch_names(): - write_string(fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, - json.dumps(tuple(annotations.ch_names))) + write_string( + fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, json.dumps(tuple(annotations.ch_names)) + ) 
end_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) def _write_annotations_csv(fname, annot): annot = annot.to_data_frame() - if 'ch_names' in annot: - annot['ch_names'] = [ - _safe_name_list(ch, 'write', name=f'annot["ch_names"][{ci}') - for ci, ch in enumerate(annot['ch_names'])] + if "ch_names" in annot: + annot["ch_names"] = [ + _safe_name_list(ch, "write", name=f'annot["ch_names"][{ci}') + for ci, ch in enumerate(annot["ch_names"]) + ] annot.to_csv(fname, index=False) @@ -1022,21 +1111,24 @@ def _write_annotations_txt(fname, annot): content += "# onset, duration, description" data = [annot.onset, annot.duration, annot.description] if annot._any_ch_names(): - content += ', ch_names' - data.append([ - _safe_name_list(ch, 'write', f'annot.ch_names[{ci}]') - for ci, ch in enumerate(annot.ch_names)]) - content += '\n' + content += ", ch_names" + data.append( + [ + _safe_name_list(ch, "write", f"annot.ch_names[{ci}]") + for ci, ch in enumerate(annot.ch_names) + ] + ) + content += "\n" data = np.array(data, dtype=str).T assert data.ndim == 2 assert data.shape[0] == len(annot.onset) assert data.shape[1] in (3, 4) - with open(fname, 'wb') as fid: + with open(fname, "wb") as fid: fid.write(content.encode()) - np.savetxt(fid, data, delimiter=',', fmt="%s") + np.savetxt(fid, data, delimiter=",", fmt="%s") -def read_annotations(fname, sfreq='auto', uint16_codec=None): +def read_annotations(fname, sfreq="auto", uint16_codec=None): r"""Read annotations from a file. This function reads a ``.fif``, ``.fif.gz``, ``.vmrk``, ``.amrk``, @@ -1093,46 +1185,49 @@ def read_annotations(fname, sfreq='auto', uint16_codec=None): ) ) name = op.basename(fname) - if name.endswith(('fif', 'fif.gz')): + if name.endswith(("fif", "fif.gz")): # Read FiF files ff, tree, _ = fiff_open(fname, preload=False) with ff as fid: annotations = _read_annotations_fif(fid, tree) - elif name.endswith('txt'): + elif name.endswith("txt"): orig_time = _read_annotations_txt_parse_header(fname) onset, duration, description, ch_names = _read_annotations_txt(fname) - annotations = Annotations(onset=onset, duration=duration, - description=description, orig_time=orig_time, - ch_names=ch_names) + annotations = Annotations( + onset=onset, + duration=duration, + description=description, + orig_time=orig_time, + ch_names=ch_names, + ) - elif name.endswith(('vmrk', 'amrk')): + elif name.endswith(("vmrk", "amrk")): annotations = _read_annotations_brainvision(fname, sfreq=sfreq) - elif name.endswith('csv'): + elif name.endswith("csv"): annotations = _read_annotations_csv(fname) - elif name.endswith('cnt'): + elif name.endswith("cnt"): annotations = _read_annotations_cnt(fname) - elif name.endswith('ds'): + elif name.endswith("ds"): annotations = _read_annotations_ctf(fname) - elif name.endswith('cef'): + elif name.endswith("cef"): annotations = _read_annotations_curry(fname, sfreq=sfreq) - elif name.endswith('set'): - annotations = _read_annotations_eeglab(fname, - uint16_codec=uint16_codec) + elif name.endswith("set"): + annotations = _read_annotations_eeglab(fname, uint16_codec=uint16_codec) - elif name.endswith(('edf', 'bdf', 'gdf')): + elif name.endswith(("edf", "bdf", "gdf")): onset, duration, description = _read_annotations_edf(fname) onset = np.array(onset, dtype=float) duration = np.array(duration, dtype=float) - annotations = Annotations(onset=onset, duration=duration, - description=description, - orig_time=None) + annotations = Annotations( + onset=onset, duration=duration, description=description, orig_time=None + ) - elif 
name.startswith('events_') and fname.endswith('mat'): + elif name.startswith("events_") and fname.endswith("mat"): annotations = _read_brainstorm_annotations(fname) else: raise OSError('Unknown annotation file format "%s"' % fname) @@ -1157,23 +1252,27 @@ def _read_annotations_csv(fname): """ pd = _check_pandas_installed(strict=True) df = pd.read_csv(fname, keep_default_na=False) - orig_time = df['onset'].values[0] + orig_time = df["onset"].values[0] try: float(orig_time) - warn('It looks like you have provided annotation onsets as floats. ' - 'These will be interpreted as MILLISECONDS. If that is not what ' - 'you want, save your CSV as a TXT file; the TXT reader accepts ' - 'onsets in seconds.') + warn( + "It looks like you have provided annotation onsets as floats. " + "These will be interpreted as MILLISECONDS. If that is not what " + "you want, save your CSV as a TXT file; the TXT reader accepts " + "onsets in seconds." + ) except ValueError: pass - onset_dt = pd.to_datetime(df['onset']) + onset_dt = pd.to_datetime(df["onset"]) onset = (onset_dt - onset_dt[0]).dt.total_seconds() - duration = df['duration'].values.astype(float) - description = df['description'].values + duration = df["duration"].values.astype(float) + description = df["description"].values ch_names = None - if 'ch_names' in df.columns: - ch_names = [_safe_name_list(val, 'read', 'annotation channel name') - for val in df['ch_names'].values] + if "ch_names" in df.columns: + ch_names = [ + _safe_name_list(val, "read", "annotation channel name") + for val in df["ch_names"].values + ] return Annotations(onset, duration, description, orig_time, ch_names) @@ -1205,33 +1304,34 @@ def get_duration_from_times(t): annot_data = io.loadmat(fname) onsets, durations, descriptions = (list(), list(), list()) - for label, _, _, _, times, _, _ in annot_data['events'][0]: + for label, _, _, _, times, _, _ in annot_data["events"][0]: onsets.append(times[0]) durations.append(get_duration_from_times(times)) n_annot = len(times[0]) descriptions += [str(label[0])] * n_annot - return Annotations(onset=np.concatenate(onsets), - duration=np.concatenate(durations), - description=descriptions, - orig_time=orig_time) + return Annotations( + onset=np.concatenate(onsets), + duration=np.concatenate(durations), + description=descriptions, + orig_time=orig_time, + ) def _is_iso8601(candidate_str): - ISO8601 = r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}\.\d{6}$' + ISO8601 = r"^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}\.\d{6}$" return re.compile(ISO8601).match(candidate_str) is not None def _read_annotations_txt_parse_header(fname): def is_orig_time(x): - return x.startswith('# orig_time :') + return x.startswith("# orig_time :") with open(fname) as fid: - header = list(takewhile(lambda x: x.startswith('#'), fid)) + header = list(takewhile(lambda x: x.startswith("#"), fid)) orig_values = [h[13:].strip() for h in header if is_orig_time(h)] - orig_values = [_handle_meas_date(orig) for orig in orig_values - if _is_iso8601(orig)] + orig_values = [_handle_meas_date(orig) for orig in orig_values if _is_iso8601(orig)] return None if not orig_values else orig_values[0] @@ -1239,13 +1339,12 @@ def is_orig_time(x): def _read_annotations_txt(fname): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") - out = np.loadtxt(fname, delimiter=',', - dtype=np.bytes_, unpack=True) + out = np.loadtxt(fname, delimiter=",", dtype=np.bytes_, unpack=True) ch_names = None if len(out) == 0: onset, duration, desc = [], [], [] else: - _check_option('text header', 
len(out), (3, 4))
+    _check_option("text header", len(out), (3, 4))
     if len(out) == 3:
         onset, duration, desc = out
     else:
@@ -1256,8 +1355,9 @@ def _read_annotations_txt(fname):
     desc = [str(d.decode()).strip() for d in np.atleast_1d(desc)]
     if ch_names is not None:
         ch_names = [
-            _safe_name_list(ch.decode().strip(), 'read', f'ch_names[{ci}]')
-            for ci, ch in enumerate(ch_names)]
+            _safe_name_list(ch.decode().strip(), "read", f"ch_names[{ci}]")
+            for ci, ch in enumerate(ch_names)
+        ]
     return onset, duration, desc, ch_names
 
 
@@ -1270,7 +1370,7 @@ def _read_annotations_fif(fid, tree):
         annot_data = annot_data[0]
         orig_time = ch_names = None
         onset, duration, description = list(), list(), list()
-        for ent in annot_data['directory']:
+        for ent in annot_data["directory"]:
             kind = ent.kind
             pos = ent.pos
             tag = read_tag(fid, pos)
@@ -1281,7 +1381,7 @@ def _read_annotations_fif(fid, tree):
                 duration = tag.data
                 duration = list() if duration is None else duration - onset
             elif kind == FIFF.FIFF_COMMENT:
-                description = _safe_name_list(tag.data, 'read', 'description')
+                description = _safe_name_list(tag.data, "read", "description")
             elif kind == FIFF.FIFF_MEAS_DATE:
                 orig_time = tag.data
                 try:
@@ -1291,14 +1391,13 @@ def _read_annotations_fif(fid, tree):
             elif kind == FIFF.FIFF_MNE_EPOCHS_DROP_LOG:
                 ch_names = tuple(tuple(x) for x in json.loads(tag.data))
     assert len(onset) == len(duration) == len(description)
-    annotations = Annotations(onset, duration, description,
-                              orig_time, ch_names)
+    annotations = Annotations(onset, duration, description, orig_time, ch_names)
     return annotations
 
 
 def _select_annotations_based_on_description(descriptions, event_id, regexp):
     """Get a collection of descriptions and returns index of selected."""
-    regexp_comp = re.compile('.*' if regexp is None else regexp)
+    regexp_comp = re.compile(".*" if regexp is None else regexp)
 
     event_id_ = dict()
     dropped = []
@@ -1323,11 +1422,10 @@ def _select_annotations_based_on_description(descriptions, event_id, regexp):
         else:
             dropped.append(desc)
 
-    event_sel = [ii for ii, kk in enumerate(descriptions)
-                 if kk in event_id_]
+    event_sel = [ii for ii, kk in enumerate(descriptions) if kk in event_id_]
 
     if len(event_sel) == 0 and regexp is not None:
-        raise ValueError('Could not find any of the events you specified.')
+        raise ValueError("Could not find any of the events you specified.")
 
     return event_sel, event_id_
 
@@ -1345,7 +1443,7 @@ def _select_events_based_on_id(events, event_desc):
     event_sel = [ii for ii, e in enumerate(events) if e[2] in event_desc_]
 
     if len(event_sel) == 0:
-        raise ValueError('Could not find any of the events you specified.')
+        raise ValueError("Could not find any of the events you specified.")
 
     return event_sel, event_desc_
 
@@ -1358,20 +1456,23 @@ def _check_event_id(event_id, raw):
 
     if event_id is None:
         return _DefaultEventParser()
-    elif event_id == 'auto':
+    elif event_id == "auto":
         if isinstance(raw, RawBrainVision):
             return _BVEventParser()
-        elif (isinstance(raw, (RawFIF, RawArray)) and
-              _check_bv_annot(raw.annotations.description)):
-            logger.info('Non-RawBrainVision raw using branvision markers')
+        elif isinstance(raw, (RawFIF, RawArray)) and _check_bv_annot(
+            raw.annotations.description
+        ):
+            logger.info("Non-RawBrainVision raw using brainvision markers")
             return _BVEventParser()
         else:
             return _DefaultEventParser()
     elif callable(event_id) or isinstance(event_id, dict):
         return event_id
     else:
-        raise ValueError('Invalid type for event_id (should be None, str, '
-                         'dict or callable). 
Got {}'.format(type(event_id))) + raise ValueError( + "Invalid type for event_id (should be None, str, " + "dict or callable). Got {}".format(type(event_id)) + ) def _check_event_description(event_desc, events): @@ -1381,28 +1482,34 @@ def _check_event_description(event_desc, events): if isinstance(event_desc, dict): for val in event_desc.values(): - _validate_type(val, (str, None), 'Event names') + _validate_type(val, (str, None), "Event names") elif isinstance(event_desc, Iterable): event_desc = np.asarray(event_desc) if event_desc.ndim != 1: - raise ValueError('event_desc must be 1D, got shape {}'.format( - event_desc.shape)) + raise ValueError( + "event_desc must be 1D, got shape {}".format(event_desc.shape) + ) event_desc = dict(zip(event_desc, map(str, event_desc))) elif callable(event_desc): pass else: - raise ValueError('Invalid type for event_desc (should be None, list, ' - '1darray, dict or callable). Got {}'.format( - type(event_desc))) + raise ValueError( + "Invalid type for event_desc (should be None, list, " + "1darray, dict or callable). Got {}".format(type(event_desc)) + ) return event_desc @verbose -def events_from_annotations(raw, event_id="auto", - regexp=r'^(?![Bb][Aa][Dd]|[Ee][Dd][Gg][Ee]).*$', - use_rounding=True, chunk_duration=None, - verbose=None): +def events_from_annotations( + raw, + event_id="auto", + regexp=r"^(?![Bb][Aa][Dd]|[Ee][Dd][Gg][Ee]).*$", + use_rounding=True, + chunk_duration=None, + verbose=None, +): """Get :term:`events` and ``event_id`` from an Annotations object. Parameters @@ -1473,11 +1580,13 @@ def events_from_annotations(raw, event_id="auto", event_id = _check_event_id(event_id, raw) event_sel, event_id_ = _select_annotations_based_on_description( - annotations.description, event_id=event_id, regexp=regexp) + annotations.description, event_id=event_id, regexp=regexp + ) if chunk_duration is None: - inds = raw.time_as_index(annotations.onset, use_rounding=use_rounding, - origin=annotations.orig_time) + inds = raw.time_as_index( + annotations.onset, use_rounding=use_rounding, origin=annotations.orig_time + ) if annotations.orig_time is not None: inds += raw.first_samp values = [event_id_[kk] for kk in annotations.description[event_sel]] @@ -1485,33 +1594,36 @@ def events_from_annotations(raw, event_id="auto", else: inds = values = np.array([]).astype(int) for annot in annotations[event_sel]: - annot_offset = annot['onset'] + annot['duration'] - _onsets = np.arange(start=annot['onset'], stop=annot_offset, - step=chunk_duration) + annot_offset = annot["onset"] + annot["duration"] + _onsets = np.arange( + start=annot["onset"], stop=annot_offset, step=chunk_duration + ) good_events = annot_offset - _onsets >= chunk_duration if good_events.any(): _onsets = _onsets[good_events] - _inds = raw.time_as_index(_onsets, - use_rounding=use_rounding, - origin=annotations.orig_time) + _inds = raw.time_as_index( + _onsets, use_rounding=use_rounding, origin=annotations.orig_time + ) _inds += raw.first_samp inds = np.append(inds, _inds) - _values = np.full(shape=len(_inds), - fill_value=event_id_[annot['description']], - dtype=int) + _values = np.full( + shape=len(_inds), + fill_value=event_id_[annot["description"]], + dtype=int, + ) values = np.append(values, _values) events = np.c_[inds, np.zeros(len(inds)), values].astype(int) - logger.info('Used Annotations descriptions: %s' % - (list(event_id_.keys()),)) + logger.info("Used Annotations descriptions: %s" % (list(event_id_.keys()),)) return events, event_id_ @verbose -def annotations_from_events(events, 
sfreq, event_desc=None, first_samp=0, - orig_time=None, verbose=None): +def annotations_from_events( + events, sfreq, event_desc=None, first_samp=0, orig_time=None, verbose=None +): """Convert an event array to an Annotations object. Parameters @@ -1569,10 +1681,9 @@ def annotations_from_events(events, sfreq, event_desc=None, first_samp=0, durations = np.zeros(len(events_sel)) # dummy durations # Create annotations - annots = Annotations(onset=onsets, - duration=durations, - description=descriptions, - orig_time=orig_time) + annots = Annotations( + onset=onsets, duration=durations, description=descriptions, orig_time=orig_time + ) return annots @@ -1581,5 +1692,5 @@ def _adjust_onset_meas_date(annot, raw): """Adjust the annotation onsets based on raw meas_date.""" # If there is a non-None meas date, then the onset should take into # account the first_samp / first_time. - if raw.info['meas_date'] is not None: + if raw.info["meas_date"] is not None: annot.onset += raw.first_time diff --git a/mne/baseline.py b/mne/baseline.py index 10b868b46f9..21aebdde807 100644 --- a/mne/baseline.py +++ b/mne/baseline.py @@ -9,20 +9,22 @@ from .utils import logger, verbose, _check_option -def _log_rescale(baseline, mode='mean'): +def _log_rescale(baseline, mode="mean"): """Log the rescaling method.""" if baseline is not None: - _check_option('mode', mode, ['logratio', 'ratio', 'zscore', 'mean', - 'percent', 'zlogratio']) - msg = 'Applying baseline correction (mode: %s)' % mode + _check_option( + "mode", + mode, + ["logratio", "ratio", "zscore", "mean", "percent", "zlogratio"], + ) + msg = "Applying baseline correction (mode: %s)" % mode else: - msg = 'No baseline correction applied' + msg = "No baseline correction applied" return msg @verbose -def rescale(data, times, baseline, mode='mean', copy=True, picks=None, - verbose=None): +def rescale(data, times, baseline, mode="mean", copy=True, picks=None, verbose=None): """Rescale (baseline correct) data. 
Parameters @@ -73,44 +75,60 @@ def rescale(data, times, baseline, mode='mean', copy=True, picks=None, else: imin = np.where(times >= bmin)[0] if len(imin) == 0: - raise ValueError('bmin is too large (%s), it exceeds the largest ' - 'time value' % (bmin,)) + raise ValueError( + "bmin is too large (%s), it exceeds the largest " "time value" % (bmin,) + ) imin = int(imin[0]) if bmax is None: imax = len(times) else: imax = np.where(times <= bmax)[0] if len(imax) == 0: - raise ValueError('bmax is too small (%s), it is smaller than the ' - 'smallest time value' % (bmax,)) + raise ValueError( + "bmax is too small (%s), it is smaller than the " + "smallest time value" % (bmax,) + ) imax = int(imax[-1]) + 1 if imin >= imax: - raise ValueError('Bad rescaling slice (%s:%s) from time values %s, %s' - % (imin, imax, bmin, bmax)) + raise ValueError( + "Bad rescaling slice (%s:%s) from time values %s, %s" + % (imin, imax, bmin, bmax) + ) # technically this is inefficient when `picks` is given, but assuming # that we generally pick most channels for rescaling, it's not so bad mean = np.mean(data[..., imin:imax], axis=-1, keepdims=True) - if mode == 'mean': + if mode == "mean": + def fun(d, m): d -= m - elif mode == 'ratio': + + elif mode == "ratio": + def fun(d, m): d /= m - elif mode == 'logratio': + + elif mode == "logratio": + def fun(d, m): d /= m np.log10(d, out=d) - elif mode == 'percent': + + elif mode == "percent": + def fun(d, m): d -= m d /= m - elif mode == 'zscore': + + elif mode == "zscore": + def fun(d, m): d -= m d /= np.std(d[..., imin:imax], axis=-1, keepdims=True) - elif mode == 'zlogratio': + + elif mode == "zlogratio": + def fun(d, m): d /= m np.log10(d, out=d) @@ -124,7 +142,7 @@ def fun(d, m): return data -def _check_baseline(baseline, times, sfreq, on_baseline_outside_data='raise'): +def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"): """Check if the baseline is valid, and adjust it if requested. ``None`` values inside the baseline parameter will be replaced with @@ -158,16 +176,20 @@ def _check_baseline(baseline, times, sfreq, on_baseline_outside_data='raise'): return None if not isinstance(baseline, tuple) or len(baseline) != 2: - raise ValueError(f'`baseline={baseline}` is an invalid argument, must ' - f'be a tuple of length 2 or None') + raise ValueError( + f"`baseline={baseline}` is an invalid argument, must " + f"be a tuple of length 2 or None" + ) tmin, tmax = times[0], times[-1] - tstep = 1. / float(sfreq) + tstep = 1.0 / float(sfreq) # check default value of baseline and `tmin=0` if baseline == (None, 0) and tmin == 0: - raise ValueError('Baseline interval is only one sample. Use ' - '`baseline=(0, 0)` if this is desired.') + raise ValueError( + "Baseline interval is only one sample. Use " + "`baseline=(0, 0)` if this is desired." + ) baseline_tmin, baseline_tmax = baseline @@ -182,17 +204,20 @@ def _check_baseline(baseline, times, sfreq, on_baseline_outside_data='raise'): if baseline_tmin > baseline_tmax: raise ValueError( "Baseline min (%s) must be less than baseline max (%s)" - % (baseline_tmin, baseline_tmax)) + % (baseline_tmin, baseline_tmax) + ) if (baseline_tmin < tmin - tstep) or (baseline_tmax > tmax + tstep): - msg = (f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s " - f"is outside of epochs data [{tmin}, {tmax}] s. Epochs were " - f"probably cropped.") - if on_baseline_outside_data == 'raise': + msg = ( + f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s " + f"is outside of epochs data [{tmin}, {tmax}] s. 
Epochs were " + f"probably cropped." + ) + if on_baseline_outside_data == "raise": raise ValueError(msg) - elif on_baseline_outside_data == 'info': + elif on_baseline_outside_data == "info": logger.info(msg) - elif on_baseline_outside_data == 'adjust': + elif on_baseline_outside_data == "adjust": if baseline_tmin < tmin - tstep: baseline_tmin = tmin if baseline_tmax > tmax + tstep: diff --git a/mne/beamformer/__init__.py b/mne/beamformer/__init__.py index b82add2a7cc..a1e233686c4 100644 --- a/mne/beamformer/__init__.py +++ b/mne/beamformer/__init__.py @@ -1,9 +1,19 @@ """Beamformers for source localization.""" -from ._lcmv import (make_lcmv, apply_lcmv, apply_lcmv_epochs, apply_lcmv_raw, - apply_lcmv_cov) -from ._dics import (make_dics, apply_dics, apply_dics_epochs, - apply_dics_tfr_epochs, apply_dics_csd) +from ._lcmv import ( + make_lcmv, + apply_lcmv, + apply_lcmv_epochs, + apply_lcmv_raw, + apply_lcmv_cov, +) +from ._dics import ( + make_dics, + apply_dics, + apply_dics_epochs, + apply_dics_tfr_epochs, + apply_dics_csd, +) from ._rap_music import rap_music from ._compute_beamformer import Beamformer, read_beamformer from .resolution_matrix import make_lcmv_resolution_matrix diff --git a/mne/beamformer/_compute_beamformer.py b/mne/beamformer/_compute_beamformer.py index bfb547e9712..adc0c14ce40 100644 --- a/mne/beamformer/_compute_beamformer.py +++ b/mne/beamformer/_compute_beamformer.py @@ -15,78 +15,124 @@ from ..io.proj import make_projector, Projection from ..minimum_norm.inverse import _get_vertno, _prepare_forward from ..source_space import label_src_vertno_sel -from ..utils import (verbose, check_fname, _reg_pinv, _check_option, logger, - _pl, _check_src_normal, _sym_mat_pow, warn, - _import_h5io_funcs) +from ..utils import ( + verbose, + check_fname, + _reg_pinv, + _check_option, + logger, + _pl, + _check_src_normal, + _sym_mat_pow, + warn, + _import_h5io_funcs, +) from ..time_frequency.csd import CrossSpectralDensity def _check_proj_match(proj, filters): """Check whether SSP projections in data and spatial filter match.""" - proj_data, _, _ = make_projector(proj, filters['ch_names']) - if not np.allclose(proj_data, filters['proj'], - atol=np.finfo(float).eps, rtol=1e-13): - raise ValueError('The SSP projections present in the data ' - 'do not match the projections used when ' - 'calculating the spatial filter.') + proj_data, _, _ = make_projector(proj, filters["ch_names"]) + if not np.allclose( + proj_data, filters["proj"], atol=np.finfo(float).eps, rtol=1e-13 + ): + raise ValueError( + "The SSP projections present in the data " + "do not match the projections used when " + "calculating the spatial filter." + ) def _check_src_type(filters): """Check whether src_type is in filters and set custom warning.""" - if 'src_type' not in filters: - filters['src_type'] = None - warn_text = ('The spatial filter does not contain src_type and a robust ' - 'guess of src_type is not possible without src. Consider ' - 'recomputing the filter.') + if "src_type" not in filters: + filters["src_type"] = None + warn_text = ( + "The spatial filter does not contain src_type and a robust " + "guess of src_type is not possible without src. Consider " + "recomputing the filter." 
+ ) return filters, warn_text -def _prepare_beamformer_input(info, forward, label=None, pick_ori=None, - noise_cov=None, rank=None, pca=False, loose=None, - combine_xyz='fro', exp=None, limit=None, - allow_fixed_depth=True, limit_depth_chs=False): +def _prepare_beamformer_input( + info, + forward, + label=None, + pick_ori=None, + noise_cov=None, + rank=None, + pca=False, + loose=None, + combine_xyz="fro", + exp=None, + limit=None, + allow_fixed_depth=True, + limit_depth_chs=False, +): """Input preparation common for LCMV, DICS, and RAP-MUSIC.""" - _check_option('pick_ori', pick_ori, - ('normal', 'max-power', 'vector', None)) + _check_option("pick_ori", pick_ori, ("normal", "max-power", "vector", None)) # Restrict forward solution to selected vertices if label is not None: - _, src_sel = label_src_vertno_sel(label, forward['src']) + _, src_sel = label_src_vertno_sel(label, forward["src"]) forward = _restrict_forward_to_src_sel(forward, src_sel) if loose is None: - loose = 0. if is_fixed_orient(forward) else 1. + loose = 0.0 if is_fixed_orient(forward) else 1.0 # TODO: Deduplicate with _check_one_ch_type, should not be necessary # (DICS hits this code path, LCMV does not) if noise_cov is None: - noise_cov = make_ad_hoc_cov(info, std=1.) - forward, info_picked, gain, _, orient_prior, _, trace_GRGT, noise_cov, \ - whitener = _prepare_forward( - forward, info, noise_cov, 'auto', loose, rank=rank, pca=pca, - use_cps=True, exp=exp, limit_depth_chs=limit_depth_chs, - combine_xyz=combine_xyz, limit=limit, - allow_fixed_depth=allow_fixed_depth) + noise_cov = make_ad_hoc_cov(info, std=1.0) + ( + forward, + info_picked, + gain, + _, + orient_prior, + _, + trace_GRGT, + noise_cov, + whitener, + ) = _prepare_forward( + forward, + info, + noise_cov, + "auto", + loose, + rank=rank, + pca=pca, + use_cps=True, + exp=exp, + limit_depth_chs=limit_depth_chs, + combine_xyz=combine_xyz, + limit=limit, + allow_fixed_depth=allow_fixed_depth, + ) is_free_ori = not is_fixed_orient(forward) # could have been changed - nn = forward['source_nn'] + nn = forward["source_nn"] if is_free_ori: # take Z coordinate nn = nn[2::3] nn = nn.copy() - vertno = _get_vertno(forward['src']) - if forward['surf_ori']: + vertno = _get_vertno(forward["src"]) + if forward["surf_ori"]: nn[...] = [0, 0, 1] # align to local +Z coordinate if pick_ori is not None and not is_free_ori: raise ValueError( - 'Normal or max-power orientation (got %r) can only be picked when ' - 'a forward operator with free orientation is used.' % (pick_ori,)) - if pick_ori == 'normal' and not forward['surf_ori']: - raise ValueError('Normal orientation can only be picked when a ' - 'forward operator oriented in surface coordinates is ' - 'used.') - _check_src_normal(pick_ori, forward['src']) + "Normal or max-power orientation (got %r) can only be picked when " + "a forward operator with free orientation is used." % (pick_ori,) + ) + if pick_ori == "normal" and not forward["surf_ori"]: + raise ValueError( + "Normal orientation can only be picked when a " + "forward operator oriented in surface coordinates is " + "used." 
+ ) + _check_src_normal(pick_ori, forward["src"]) del forward, info # Undo the scaling that MNE prefers - scale = np.sqrt((noise_cov['eig'] > 0).sum() / trace_GRGT) + scale = np.sqrt((noise_cov["eig"] > 0).sum() / trace_GRGT) gain /= scale if orient_prior is not None: orient_std = np.sqrt(orient_prior) @@ -94,10 +140,8 @@ def _prepare_beamformer_input(info, forward, label=None, pick_ori=None, orient_std = np.ones(gain.shape[1]) # Get the projector - proj, _, _ = make_projector( - info_picked['projs'], info_picked['ch_names']) - return (is_free_ori, info_picked, proj, vertno, gain, whitener, nn, - orient_std) + proj, _, _ = make_projector(info_picked["projs"], info_picked["ch_names"]) + return (is_free_ori, info_picked, proj, vertno, gain, whitener, nn, orient_std) def _reduce_leadfield_rank(G): @@ -115,12 +159,12 @@ def _reduce_leadfield_rank(G): def _sym_inv_sm(x, reduce_rank, inversion, sk): """Symmetric inversion with single- or matrix-style inversion.""" if x.shape[1:] == (1, 1): - with np.errstate(divide='ignore', invalid='ignore'): - x_inv = 1. / x - x_inv[~np.isfinite(x_inv)] = 1. + with np.errstate(divide="ignore", invalid="ignore"): + x_inv = 1.0 / x + x_inv[~np.isfinite(x_inv)] = 1.0 else: assert x.shape[1:] == (3, 3) - if inversion == 'matrix': + if inversion == "matrix": x_inv = _sym_mat_pow(x, -1, reduce_rank=reduce_rank) # Reapply source covariance after inversion x_inv *= sk[:, :, np.newaxis] @@ -128,22 +172,33 @@ def _sym_inv_sm(x, reduce_rank, inversion, sk): else: # Invert for each dipole separately using plain division diags = np.diagonal(x, axis1=1, axis2=2) - assert not reduce_rank # guaranteed earlier - with np.errstate(divide='ignore'): - diags = 1. / diags + assert not reduce_rank # guaranteed earlier + with np.errstate(divide="ignore"): + diags = 1.0 / diags # set the diagonal of each 3x3 x_inv = np.zeros_like(x) for k in range(x.shape[0]): this = diags[k] # Reapply source covariance after inversion - this *= (sk[k] * sk[k]) + this *= sk[k] * sk[k] x_inv[k].flat[::4] = this return x_inv -def _compute_beamformer(G, Cm, reg, n_orient, weight_norm, pick_ori, - reduce_rank, rank, inversion, nn, orient_std, - whitener): +def _compute_beamformer( + G, + Cm, + reg, + n_orient, + weight_norm, + pick_ori, + reduce_rank, + rank, + inversion, + nn, + orient_std, + whitener, +): """Compute a spatial beamformer filter (LCMV or DICS). For more detailed information on the parameters, see the docstrings of @@ -181,22 +236,26 @@ def _compute_beamformer(G, Cm, reg, n_orient, weight_norm, pick_ori, W : ndarray, shape (n_dipoles, n_channels) The beamformer filter weights. """ - _check_option('weight_norm', weight_norm, - ['unit-noise-gain-invariant', 'unit-noise-gain', - 'nai', None]) + _check_option( + "weight_norm", + weight_norm, + ["unit-noise-gain-invariant", "unit-noise-gain", "nai", None], + ) # Whiten the data covariance Cm = whitener @ Cm @ whitener.T.conj() # Restore to properly Hermitian as large whitening coefs can have bad # rounding error - Cm[:] = (Cm + Cm.T.conj()) / 2. 
+ Cm[:] = (Cm + Cm.T.conj()) / 2.0 assert Cm.shape == (G.shape[0],) * 2 s, _ = np.linalg.eigh(Cm) if not (s >= -s.max() * 1e-7).all(): # This shouldn't ever happen, but just in case - warn('data covariance does not appear to be positive semidefinite, ' - 'results will likely be incorrect') + warn( + "data covariance does not appear to be positive semidefinite, " + "results will likely be incorrect" + ) # Tikhonov regularization using reg parameter to control for # trade-off between spatial resolution and noise sensitivity # eq. 25 in Gross and Ioannides, 1999 Phys. Med. Biol. 44 2081 @@ -206,8 +265,9 @@ def _compute_beamformer(G, Cm, reg, n_orient, weight_norm, pick_ori, n_sources = G.shape[1] // n_orient assert nn.shape == (n_sources, 3) - logger.info('Computing beamformer filters for %d source%s' - % (n_sources, _pl(n_sources))) + logger.info( + "Computing beamformer filters for %d source%s" % (n_sources, _pl(n_sources)) + ) n_channels = G.shape[0] assert n_orient in (3, 1) Gk = np.reshape(G.T, (n_sources, n_orient, n_channels)).transpose(0, 2, 1) @@ -215,29 +275,37 @@ def _compute_beamformer(G, Cm, reg, n_orient, weight_norm, pick_ori, sk = np.reshape(orient_std, (n_sources, n_orient)) del G, orient_std - _check_option('reduce_rank', reduce_rank, (True, False)) + _check_option("reduce_rank", reduce_rank, (True, False)) # inversion of the denominator - _check_option('inversion', inversion, ('matrix', 'single')) - if inversion == 'single' and n_orient > 1 and pick_ori == 'vector' and \ - weight_norm == 'unit-noise-gain-invariant': + _check_option("inversion", inversion, ("matrix", "single")) + if ( + inversion == "single" + and n_orient > 1 + and pick_ori == "vector" + and weight_norm == "unit-noise-gain-invariant" + ): raise ValueError( 'Cannot use pick_ori="vector" with inversion="single" and ' - 'weight_norm="unit-noise-gain-invariant"') - if reduce_rank and inversion == 'single': - raise ValueError('reduce_rank cannot be used with inversion="single"; ' - 'consider using inversion="matrix" if you have a ' - 'rank-deficient forward model (i.e., from a sphere ' - 'model with MEG channels), otherwise consider using ' - 'reduce_rank=False') + 'weight_norm="unit-noise-gain-invariant"' + ) + if reduce_rank and inversion == "single": + raise ValueError( + 'reduce_rank cannot be used with inversion="single"; ' + 'consider using inversion="matrix" if you have a ' + "rank-deficient forward model (i.e., from a sphere " + "model with MEG channels), otherwise consider using " + "reduce_rank=False" + ) if n_orient > 1: _, Gk_s, _ = np.linalg.svd(Gk, full_matrices=False) assert Gk_s.shape == (n_sources, n_orient) if not reduce_rank and (Gk_s[:, 0] > 1e6 * Gk_s[:, 2]).any(): raise ValueError( - 'Singular matrix detected when estimating spatial filters. ' - 'Consider reducing the rank of the forward operator by using ' - 'reduce_rank=True.') + "Singular matrix detected when estimating spatial filters. " + "Consider reducing the rank of the forward operator by using " + "reduce_rank=True." + ) del Gk_s # @@ -254,7 +322,7 @@ def _compute_bf_terms(Gk, Cm_inv): # # 2. 
Reorient lead field in direction of max power or normal # - if pick_ori == 'max-power': + if pick_ori == "max-power": assert n_orient == 3 _, bf_denom = _compute_bf_terms(Gk, Cm_inv) if weight_norm is None: @@ -265,7 +333,8 @@ def _compute_bf_terms(Gk, Cm_inv): ori_numer = bf_denom # Cm_inv should be Hermitian so no need for .T.conj() ori_denom = np.matmul( - np.matmul(Gk.swapaxes(-2, -1).conj(), Cm_inv @ Cm_inv), Gk) + np.matmul(Gk.swapaxes(-2, -1).conj(), Cm_inv @ Cm_inv), Gk + ) ori_denom_inv = _sym_inv_sm(ori_denom, reduce_rank, inversion, sk) ori_pick = np.matmul(ori_denom_inv, ori_numer) assert ori_pick.shape == (n_sources, n_orient, n_orient) @@ -280,7 +349,7 @@ def _compute_bf_terms(Gk, Cm_inv): # set the (otherwise arbitrary) sign to match the normal signs = np.sign(np.sum(max_power_ori * nn, axis=1, keepdims=True)) - signs[signs == 0] = 1. + signs[signs == 0] = 1.0 max_power_ori *= signs # Compute the lead field for the optimal orientation, @@ -289,7 +358,7 @@ def _compute_bf_terms(Gk, Cm_inv): n_orient = 1 else: max_power_ori = None - if pick_ori == 'normal': + if pick_ori == "normal": Gk = Gk[..., 2:3] n_orient = 1 @@ -338,16 +407,17 @@ def _compute_bf_terms(Gk, Cm_inv): # # Sekihara 2008 says to use sqrt(diag(W_ug @ W_ug.T)), which is not # rotation invariant: - if weight_norm in ('unit-noise-gain', 'nai'): + if weight_norm in ("unit-noise-gain", "nai"): noise_norm = np.matmul(W, W.swapaxes(-2, -1).conj()).real noise_norm = np.reshape( # np.diag operation over last two axes - noise_norm, (n_sources, -1, 1))[:, ::n_orient + 1] + noise_norm, (n_sources, -1, 1) + )[:, :: n_orient + 1] np.sqrt(noise_norm, out=noise_norm) noise_norm[noise_norm == 0] = np.inf assert noise_norm.shape == (n_sources, n_orient, 1) W /= noise_norm else: - assert weight_norm == 'unit-noise-gain-invariant' + assert weight_norm == "unit-noise-gain-invariant" # Here we use sqrtm. The shortcut: # # use = W @@ -357,9 +427,9 @@ def _compute_bf_terms(Gk, Cm_inv): use = bf_numer inner = np.matmul(use, use.swapaxes(-2, -1).conj()) W = np.matmul(_sym_mat_pow(inner, -0.5), use) - noise_norm = 1. + noise_norm = 1.0 - if weight_norm == 'nai': + if weight_norm == "nai": # Estimate noise level based on covariance matrix, taking the # first eigenvalue that falls outside the signal subspace or the # loading factor used during regularization, whichever is largest. @@ -368,10 +438,11 @@ def _compute_bf_terms(Gk, Cm_inv): # Use the loading factor as noise ceiling. if loading_factor == 0: raise RuntimeError( - 'Cannot compute noise subspace with a full-rank ' - 'covariance matrix and no regularization. Try ' - 'manually specifying the rank of the covariance ' - 'matrix or using regularization.') + "Cannot compute noise subspace with a full-rank " + "covariance matrix and no regularization. Try " + "manually specifying the rank of the covariance " + "matrix or using regularization." 
+                )
             noise = loading_factor
         else:
             noise, _ = np.linalg.eigh(Cm)
@@ -380,7 +451,7 @@ def _compute_bf_terms(Gk, Cm_inv):
         W /= np.sqrt(noise)
 
     W = W.reshape(n_sources * n_orient, n_channels)
-    logger.info('Filter computation complete')
+    logger.info("Filter computation complete")
     return W, max_power_ori
 
 
@@ -402,8 +473,9 @@ def _compute_power(Cm, W, n_orient):
 
     n_sources = W.shape[0] // n_orient
    Wk = W.reshape(n_sources, n_orient, W.shape[1])
-    source_power = np.trace((Wk @ Cm @ Wk.conj().transpose(0, 2, 1)).real,
-                            axis1=1, axis2=2)
+    source_power = np.trace(
+        (Wk @ Cm @ Wk.conj().transpose(0, 2, 1)).real, axis1=1, axis2=2
+    )
 
     return source_power
 
@@ -427,23 +499,27 @@ def copy(self):
         return deepcopy(self)
 
     def __repr__(self):  # noqa: D105
-        n_verts = sum(len(v) for v in self['vertices'])
-        n_channels = len(self['ch_names'])
-        if self['subject'] is None:
-            subject = 'unknown'
+        n_verts = sum(len(v) for v in self["vertices"])
+        n_channels = len(self["ch_names"])
+        if self["subject"] is None:
+            subject = "unknown"
         else:
-            subject = '"%s"' % (self['subject'],)
[... diff truncated in the source: the remainder of this hunk (the "out = (..." repr string), the rest of the mne/beamformer/_compute_beamformer.py diff, and the header of the mne/beamformer/_dics.py diff are missing; the patch resumes mid-hunk inside make_dics ...]
         if n_freqs > 1:
-            logger.info('    computing DICS spatial filter at '
-                        f'{round(freq, 2)} Hz ({i + 1}/{n_freqs})')
+            logger.info(
+                "    computing DICS spatial filter at "
+                f"{round(freq, 2)} Hz ({i + 1}/{n_freqs})"
+            )
 
         Cm = csd.get_data(index=i)
 
@@ -228,29 +268,51 @@ def make_dics(info, forward, csd, reg=0.05, noise_csd=None, label=None,
         # compute spatial filter
         n_orient = 3 if is_free_ori else 1
         W, max_power_ori = _compute_beamformer(
-            G, Cm, reg, n_orient, weight_norm, pick_ori, reduce_rank,
-            rank=csd_int_rank[i], inversion=inversion, nn=nn,
-            orient_std=orient_std, whitener=whitener)
+            G,
+            Cm,
+            reg,
+            n_orient,
+            weight_norm,
+            pick_ori,
+            reduce_rank,
+            rank=csd_int_rank[i],
+            inversion=inversion,
+            nn=nn,
+            orient_std=orient_std,
+            whitener=whitener,
+        )
         Ws.append(W)
         max_oris.append(max_power_ori)
 
     Ws = np.array(Ws)
-    if pick_ori == 'max-power':
+    if pick_ori == "max-power":
         max_oris = np.array(max_oris)
     else:
         max_oris = None
 
-    src_type = _get_src_type(forward['src'], vertices)
+    src_type = _get_src_type(forward["src"], vertices)
     subject = _subject_from_forward(forward)
-    is_free_ori = is_free_ori if pick_ori in [None, 'vector'] else False
+    is_free_ori = is_free_ori if pick_ori in [None, "vector"] else False
     n_sources = np.sum([len(v) for v in vertices])
 
     filters = Beamformer(
-        kind='DICS', weights=Ws, csd=csd, ch_names=ch_names, proj=proj,
-        vertices=vertices, n_sources=n_sources, subject=subject,
-        pick_ori=pick_ori, inversion=inversion, weight_norm=weight_norm,
-        src_type=src_type, source_nn=forward['source_nn'].copy(),
-        is_free_ori=is_free_ori, whitener=whitener, max_power_ori=max_oris)
+        kind="DICS",
+        weights=Ws,
+        csd=csd,
+        ch_names=ch_names,
+        proj=proj,
+        vertices=vertices,
+        n_sources=n_sources,
+        subject=subject,
+        pick_ori=pick_ori,
+        inversion=inversion,
+        weight_norm=weight_norm,
+        src_type=src_type,
+        source_nn=forward["source_nn"].copy(),
+        is_free_ori=is_free_ori,
+        whitener=whitener,
+        max_power_ori=max_oris,
+    )
 
     return filters
 
@@ -263,7 +325,7 @@ def _prepare_noise_csd(csd, noise_csd, real_filter):
             noise_csd = noise_csd.mean()
         noise_csd = noise_csd.get_data(as_cov=True)
         if real_filter:
-            noise_csd['data'] = noise_csd['data'].real
+            noise_csd["data"] = noise_csd["data"].real
     return csd, noise_csd
 
 
@@ -275,10 +337,10 @@ def _apply_dics(data, filters, info, tmin, tfr=False):
     else:
         one_epoch = False
 
-    Ws = filters['weights']
+    Ws = filters["weights"]
     one_freq = len(Ws) == 1
 
-    subject = filters['subject']
+    subject = 
filters["subject"]
 
     # compatibility with 0.16, add src_type as None if not present:
     filters, warn_text = _check_src_type(filters)
@@ -288,35 +350,41 @@ def _apply_dics(data, filters, info, tmin, tfr=False):
 
     # Apply SSPs
     if not tfr:  # save computation, only compute once
-        M_w = _proj_whiten_data(M, info['projs'], filters)
+        M_w = _proj_whiten_data(M, info["projs"], filters)
 
     stcs = []
     for j, W in enumerate(Ws):
-
         if tfr:  # must compute for each frequency
-            M_w = _proj_whiten_data(M[:, j], info['projs'], filters)
+            M_w = _proj_whiten_data(M[:, j], info["projs"], filters)
 
         # project to source space using beamformer weights
         sol = np.dot(W, M_w)
 
-        if filters['is_free_ori'] and filters['pick_ori'] != 'vector':
-            logger.info('combining the current components...')
+        if filters["is_free_ori"] and filters["pick_ori"] != "vector":
+            logger.info("combining the current components...")
             sol = combine_xyz(sol)
 
-        tstep = 1.0 / info['sfreq']
-
-        stcs.append(_make_stc(sol, vertices=filters['vertices'],
-                              src_type=filters['src_type'], tmin=tmin,
-                              tstep=tstep, subject=subject,
-                              vector=(filters['pick_ori'] == 'vector'),
-                              source_nn=filters['source_nn'],
-                              warn_text=warn_text))
+        tstep = 1.0 / info["sfreq"]
+
+        stcs.append(
+            _make_stc(
+                sol,
+                vertices=filters["vertices"],
+                src_type=filters["src_type"],
+                tmin=tmin,
+                tstep=tstep,
+                subject=subject,
+                vector=(filters["pick_ori"] == "vector"),
+                source_nn=filters["source_nn"],
+                warn_text=warn_text,
+            )
+        )
 
     if one_freq:
         yield stcs[0]
     else:
         yield stcs
 
-    logger.info('[done]')
+    logger.info("[done]")
 
 
 @verbose
@@ -413,12 +481,12 @@ def apply_dics_epochs(epochs, filters, return_generator=False, verbose=None):
     """
     _check_reference(epochs)
 
-    if len(filters['weights']) > 1:
+    if len(filters["weights"]) > 1:
         raise ValueError(
-            'This function only works on DICS beamformer weights that have '
-            'been computed for a single frequency. When calling make_dics(), '
-            'make sure to use a CSD object with only a single frequency (or '
-            'frequency-bin) defined.'
+            "This function only works on DICS beamformer weights that have "
+            "been computed for a single frequency. When calling make_dics(), "
+            "make sure to use a CSD object with only a single frequency (or "
+            "frequency-bin) defined."
         )
 
     info = epochs.info
@@ -436,8 +504,7 @@ def apply_dics_epochs(epochs, filters, return_generator=False, verbose=None):
 
 
 @verbose
-def apply_dics_tfr_epochs(epochs_tfr, filters, return_generator=False,
-                          verbose=None):
+def apply_dics_tfr_epochs(epochs_tfr, filters, return_generator=False, verbose=None):
     """Apply Dynamic Imaging of Coherent Sources (DICS) beamformer weights.
 
     Apply Dynamic Imaging of Coherent Sources (DICS) beamformer weights
@@ -466,22 +533,23 @@ def apply_dics_tfr_epochs(epochs_tfr, filters, return_generator=False,
     apply_dics
     apply_dics_epochs
    apply_dics_csd
-    """  # noqa E501
+    """  # noqa E501
     _validate_type(epochs_tfr, EpochsTFR)
     _check_tfr_complex(epochs_tfr)
-    if filters['pick_ori'] == 'vector':
-        warn('Using a vector solution to compute power will lead to '
-             'inaccurate directions (only in the first quadrent) '
-             'because power is a strictly positive (squared) metric. '
-             'Using singular value decomposition (SVD) to determine '
-             'the direction is not yet supported in MNE.')
+    if filters["pick_ori"] == "vector":
+        warn(
+            "Using a vector solution to compute power will lead to "
+            "inaccurate directions (only in the first quadrant) "
+            "because power is a strictly positive (squared) metric. 
" + "Using singular value decomposition (SVD) to determine " + "the direction is not yet supported in MNE." + ) sel = _check_channels_spatial_filter(epochs_tfr.ch_names, filters) data = epochs_tfr.data[:, sel, :, :] - stcs = _apply_dics(data, filters, epochs_tfr.info, - epochs_tfr.tmin, tfr=True) + stcs = _apply_dics(data, filters, epochs_tfr.info, epochs_tfr.tmin, tfr=True) if not return_generator: stcs = [[stc for stc in tfr_stcs] for tfr_stcs in stcs] return stcs @@ -531,12 +599,12 @@ def apply_dics_csd(csd, filters, verbose=None): ---------- .. footbibliography:: """ # noqa: E501 - ch_names = filters['ch_names'] - vertices = filters['vertices'] - n_orient = 3 if filters['is_free_ori'] else 1 - subject = filters['subject'] - whitener = filters['whitener'] - n_sources = filters['n_sources'] + ch_names = filters["ch_names"] + vertices = filters["vertices"] + n_orient = 3 if filters["is_free_ori"] else 1 + subject = filters["subject"] + whitener = filters["whitener"] + n_sources = filters["n_sources"] # If CSD is summed over multiple frequencies, take the average frequency frequencies = [np.mean(dfreq) for dfreq in csd.frequencies] @@ -547,27 +615,37 @@ def apply_dics_csd(csd, filters, verbose=None): # Ensure the CSD is in the same order as the weights csd_picks = [csd.ch_names.index(ch) for ch in ch_names] - logger.info('Computing DICS source power...') + logger.info("Computing DICS source power...") for i, freq in enumerate(frequencies): if n_freqs > 1: - logger.info(' applying DICS spatial filter at ' - f'{round(freq, 2)} Hz ({i + 1}/{n_freqs})') + logger.info( + " applying DICS spatial filter at " + f"{round(freq, 2)} Hz ({i + 1}/{n_freqs})" + ) Cm = csd.get_data(index=i) Cm = Cm[csd_picks, :][:, csd_picks] - W = filters['weights'][i] + W = filters["weights"][i] # Whiten the CSD Cm = np.dot(whitener, np.dot(Cm, whitener.conj().T)) source_power[:, i] = _compute_power(Cm, W, n_orient) - logger.info('[done]') + logger.info("[done]") # compatibility with 0.16, add src_type as None if not present: filters, warn_text = _check_src_type(filters) - return (_make_stc(source_power, vertices=vertices, - src_type=filters['src_type'], tmin=0., tstep=1., - subject=subject, warn_text=warn_text), - frequencies) + return ( + _make_stc( + source_power, + vertices=vertices, + src_type=filters["src_type"], + tmin=0.0, + tstep=1.0, + subject=subject, + warn_text=warn_text, + ), + frequencies, + ) diff --git a/mne/beamformer/_lcmv.py b/mne/beamformer/_lcmv.py index 61c45a8ec66..3e67890da65 100644 --- a/mne/beamformer/_lcmv.py +++ b/mne/beamformer/_lcmv.py @@ -13,18 +13,39 @@ from ..forward import _subject_from_forward from ..minimum_norm.inverse import combine_xyz, _check_reference, _check_depth from ..source_estimate import _make_stc, _get_src_type -from ..utils import (logger, verbose, _check_channels_spatial_filter, - _check_one_ch_type, _check_info_inv) +from ..utils import ( + logger, + verbose, + _check_channels_spatial_filter, + _check_one_ch_type, + _check_info_inv, +) from ._compute_beamformer import ( - _prepare_beamformer_input, _compute_power, - _compute_beamformer, _check_src_type, Beamformer, _proj_whiten_data) + _prepare_beamformer_input, + _compute_power, + _compute_beamformer, + _check_src_type, + Beamformer, + _proj_whiten_data, +) @verbose -def make_lcmv(info, forward, data_cov, reg=0.05, noise_cov=None, label=None, - pick_ori=None, rank='info', - weight_norm='unit-noise-gain-invariant', - reduce_rank=False, depth=None, inversion='matrix', verbose=None): +def make_lcmv( + info, + forward, 
+ data_cov, + reg=0.05, + noise_cov=None, + label=None, + pick_ori=None, + rank="info", + weight_norm="unit-noise-gain-invariant", + reduce_rank=False, + depth=None, + inversion="matrix", + verbose=None, +): """Compute LCMV spatial filter. Parameters @@ -144,7 +165,8 @@ def make_lcmv(info, forward, data_cov, reg=0.05, noise_cov=None, label=None, # check number of sensor types present in the data and ensure a noise cov info = _simplify_info(info) noise_cov, _, allow_mismatch = _check_one_ch_type( - 'lcmv', info, forward, data_cov, noise_cov) + "lcmv", info, forward, data_cov, noise_cov + ) # XXX we need this extra picking step (can't just rely on minimum norm's # because there can be a mismatch. Should probably add an extra arg to # _prepare_beamformer_input at some point (later) @@ -153,58 +175,97 @@ def make_lcmv(info, forward, data_cov, reg=0.05, noise_cov=None, label=None, data_rank = compute_rank(data_cov, rank=rank, info=info) noise_rank = compute_rank(noise_cov, rank=rank, info=info) for key in data_rank: - if (key not in noise_rank or data_rank[key] != noise_rank[key]) and \ - not allow_mismatch: - raise ValueError('%s data rank (%s) did not match the noise ' - 'rank (%s)' - % (key, data_rank[key], - noise_rank.get(key, None))) + if ( + key not in noise_rank or data_rank[key] != noise_rank[key] + ) and not allow_mismatch: + raise ValueError( + "%s data rank (%s) did not match the noise " + "rank (%s)" % (key, data_rank[key], noise_rank.get(key, None)) + ) del noise_rank rank = data_rank - logger.info('Making LCMV beamformer with rank %s' % (rank,)) + logger.info("Making LCMV beamformer with rank %s" % (rank,)) del data_rank - depth = _check_depth(depth, 'depth_sparse') - if inversion == 'single': - depth['combine_xyz'] = False - - is_free_ori, info, proj, vertno, G, whitener, nn, orient_std = \ - _prepare_beamformer_input( - info, forward, label, pick_ori, noise_cov=noise_cov, rank=rank, - pca=False, **depth) - ch_names = list(info['ch_names']) + depth = _check_depth(depth, "depth_sparse") + if inversion == "single": + depth["combine_xyz"] = False + + ( + is_free_ori, + info, + proj, + vertno, + G, + whitener, + nn, + orient_std, + ) = _prepare_beamformer_input( + info, + forward, + label, + pick_ori, + noise_cov=noise_cov, + rank=rank, + pca=False, + **depth + ) + ch_names = list(info["ch_names"]) data_cov = pick_channels_cov(data_cov, include=ch_names) Cm = data_cov._get_square() - if 'estimator' in data_cov: - del data_cov['estimator'] + if "estimator" in data_cov: + del data_cov["estimator"] rank_int = sum(rank.values()) del rank # compute spatial filter n_orient = 3 if is_free_ori else 1 W, max_power_ori = _compute_beamformer( - G, Cm, reg, n_orient, weight_norm, pick_ori, reduce_rank, rank_int, - inversion=inversion, nn=nn, orient_std=orient_std, - whitener=whitener) + G, + Cm, + reg, + n_orient, + weight_norm, + pick_ori, + reduce_rank, + rank_int, + inversion=inversion, + nn=nn, + orient_std=orient_std, + whitener=whitener, + ) # get src type to store with filters for _make_stc - src_type = _get_src_type(forward['src'], vertno) + src_type = _get_src_type(forward["src"], vertno) # get subject to store with filters subject_from = _subject_from_forward(forward) # Is the computed beamformer a scalar or vector beamformer? 
- is_free_ori = is_free_ori if pick_ori in [None, 'vector'] else False - is_ssp = bool(info['projs']) + is_free_ori = is_free_ori if pick_ori in [None, "vector"] else False + is_ssp = bool(info["projs"]) filters = Beamformer( - kind='LCMV', weights=W, data_cov=data_cov, noise_cov=noise_cov, - whitener=whitener, weight_norm=weight_norm, pick_ori=pick_ori, - ch_names=ch_names, proj=proj, is_ssp=is_ssp, vertices=vertno, - is_free_ori=is_free_ori, n_sources=forward['nsource'], - src_type=src_type, source_nn=forward['source_nn'].copy(), - subject=subject_from, rank=rank_int, max_power_ori=max_power_ori, - inversion=inversion) + kind="LCMV", + weights=W, + data_cov=data_cov, + noise_cov=noise_cov, + whitener=whitener, + weight_norm=weight_norm, + pick_ori=pick_ori, + ch_names=ch_names, + proj=proj, + is_ssp=is_ssp, + vertices=vertno, + is_free_ori=is_free_ori, + n_sources=forward["nsource"], + src_type=src_type, + source_nn=forward["source_nn"].copy(), + subject=subject_from, + rank=rank_int, + max_power_ori=max_power_ori, + inversion=inversion, + ) return filters @@ -217,45 +278,51 @@ def _apply_lcmv(data, filters, info, tmin): else: return_single = False - W = filters['weights'] + W = filters["weights"] for i, M in enumerate(data): - if len(M) != len(filters['ch_names']): - raise ValueError('data and picks must have the same length') + if len(M) != len(filters["ch_names"]): + raise ValueError("data and picks must have the same length") if not return_single: logger.info("Processing epoch : %d" % (i + 1)) - M = _proj_whiten_data(M, info['projs'], filters) + M = _proj_whiten_data(M, info["projs"], filters) # project to source space using beamformer weights vector = False - if filters['is_free_ori']: + if filters["is_free_ori"]: sol = np.dot(W, M) - if filters['pick_ori'] == 'vector': + if filters["pick_ori"] == "vector": vector = True else: - logger.info('combining the current components...') + logger.info("combining the current components...") sol = combine_xyz(sol) else: # Linear inverse: do computation here or delayed - if (M.shape[0] < W.shape[0] and - filters['pick_ori'] != 'max-power'): + if M.shape[0] < W.shape[0] and filters["pick_ori"] != "max-power": sol = (W, M) else: sol = np.dot(W, M) - tstep = 1.0 / info['sfreq'] + tstep = 1.0 / info["sfreq"] # compatibility with 0.16, add src_type as None if not present: filters, warn_text = _check_src_type(filters) - yield _make_stc(sol, vertices=filters['vertices'], tmin=tmin, - tstep=tstep, subject=filters['subject'], - vector=vector, source_nn=filters['source_nn'], - src_type=filters['src_type'], warn_text=warn_text) + yield _make_stc( + sol, + vertices=filters["vertices"], + tmin=tmin, + tstep=tstep, + subject=filters["subject"], + vector=vector, + source_nn=filters["source_nn"], + src_type=filters["src_type"], + warn_text=warn_text, + ) - logger.info('[done]') + logger.info("[done]") @verbose @@ -296,15 +363,13 @@ def apply_lcmv(evoked, filters, *, verbose=None): sel = _check_channels_spatial_filter(evoked.ch_names, filters) data = data[sel] - stc = _apply_lcmv(data=data, filters=filters, info=info, - tmin=tmin) + stc = _apply_lcmv(data=data, filters=filters, info=info, tmin=tmin) return next(stc) @verbose -def apply_lcmv_epochs(epochs, filters, *, return_generator=False, - verbose=None): +def apply_lcmv_epochs(epochs, filters, *, return_generator=False, verbose=None): """Apply Linearly Constrained Minimum Variance (LCMV) beamformer weights. 
Apply Linearly Constrained Minimum Variance (LCMV) beamformer weights @@ -338,8 +403,7 @@ def apply_lcmv_epochs(epochs, filters, *, return_generator=False, sel = _check_channels_spatial_filter(epochs.ch_names, filters) data = epochs.get_data()[:, sel, :] - stcs = _apply_lcmv(data=data, filters=filters, info=info, - tmin=tmin) + stcs = _apply_lcmv(data=data, filters=filters, info=info, tmin=tmin) if not return_generator: stcs = [s for s in stcs] @@ -418,17 +482,23 @@ def apply_lcmv_cov(data_cov, filters, verbose=None): sel_names = [data_cov.ch_names[ii] for ii in sel] data_cov = pick_channels_cov(data_cov, sel_names) - n_orient = filters['weights'].shape[0] // filters['n_sources'] + n_orient = filters["weights"].shape[0] // filters["n_sources"] # Need to project and whiten along both dimensions - data = _proj_whiten_data(data_cov['data'].T, data_cov['projs'], filters) - data = _proj_whiten_data(data.T, data_cov['projs'], filters) + data = _proj_whiten_data(data_cov["data"].T, data_cov["projs"], filters) + data = _proj_whiten_data(data.T, data_cov["projs"], filters) del data_cov - source_power = _compute_power(data, filters['weights'], n_orient) + source_power = _compute_power(data, filters["weights"], n_orient) # compatibility with 0.16, add src_type as None if not present: filters, warn_text = _check_src_type(filters) - return _make_stc(source_power, vertices=filters['vertices'], - src_type=filters['src_type'], tmin=0., tstep=1., - subject=filters['subject'], - source_nn=filters['source_nn'], warn_text=warn_text) + return _make_stc( + source_power, + vertices=filters["vertices"], + src_type=filters["src_type"], + tmin=0.0, + tstep=1.0, + subject=filters["subject"], + source_nn=filters["source_nn"], + warn_text=warn_text, + ) diff --git a/mne/beamformer/_rap_music.py b/mne/beamformer/_rap_music.py index 3b59fa90c46..d58de523b2a 100644 --- a/mne/beamformer/_rap_music.py +++ b/mne/beamformer/_rap_music.py @@ -17,8 +17,7 @@ @fill_doc -def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, - picks=None): +def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, picks=None): """RAP-MUSIC for evoked data. Parameters @@ -47,15 +46,17 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, Computed only if return_explained_data is True. 
""" from scipy import linalg + info = pick_info(info, picks) del picks # things are much simpler if we avoid surface orientation - align = forward['source_nn'].copy() - if forward['surf_ori'] and not is_fixed_orient(forward): + align = forward["source_nn"].copy() + if forward["surf_ori"] and not is_fixed_orient(forward): forward = convert_forward_solution(forward, surf_ori=False) is_free_ori, info, _, _, G, whitener, _, _ = _prepare_beamformer_input( - info, forward, noise_cov=noise_cov, rank=None) - forward = pick_channels_forward(forward, info['ch_names'], ordered=True) + info, forward, noise_cov=noise_cov, rank=None + ) + forward = pick_channels_forward(forward, info["ch_names"], ordered=True) del info # whiten the data (leadfield already whitened) @@ -67,7 +68,7 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, n_orient = 3 if is_free_ori else 1 G.shape = (G.shape[0], -1, n_orient) - gain = forward['sol']['data'].copy() + gain = forward["sol"]["data"].copy() gain.shape = G.shape n_channels = G.shape[0] A = np.empty((n_channels, n_dipoles)) @@ -80,7 +81,7 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, idxs = list() for k in range(n_dipoles): - subcorr_max = -1. + subcorr_max = -1.0 source_idx, source_ori, source_pos = 0, [0, 0, 0], [0, 0, 0] for i_source in range(G.shape[1]): Gk = G_proj[:, i_source] @@ -89,13 +90,13 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, subcorr_max = subcorr source_idx = i_source source_ori = ori - source_pos = forward['source_rr'][i_source] + source_pos = forward["source_rr"][i_source] if n_orient == 3 and align is not None: - surf_normal = forward['source_nn'][3 * i_source + 2] + surf_normal = forward["source_nn"][3 * i_source + 2] # make sure ori is aligned to the surface orientation - source_ori *= np.sign(source_ori @ surf_normal) or 1. + source_ori *= np.sign(source_ori @ surf_normal) or 1.0 if n_orient == 1: - source_ori = forward['source_nn'][i_source] + source_ori = forward["source_nn"][i_source] idxs.append(source_idx) if n_orient == 3: @@ -110,8 +111,8 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, if n_orient == 3: logger.info("ori = %s %s %s" % tuple(oris[k])) - projection = _compute_proj(A[:, :k + 1]) - G_proj = np.einsum('ab,bso->aso', projection, G) + projection = _compute_proj(A[:, : k + 1]) + G_proj = np.einsum("ab,bso->aso", projection, G) phi_sig_proj = np.dot(projection, phi_sig) del G, G_proj @@ -126,8 +127,7 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, if n_orient == 3: gain_dip = (oris * gain_active).sum(-1) idxs = np.array(idxs) - active_set = np.array( - [[3 * idxs, 3 * idxs + 1, 3 * idxs + 2]]).T.ravel() + active_set = np.array([[3 * idxs, 3 * idxs + 1, 3 * idxs + 2]]).T.ravel() else: gain_dip = gain_active[:, :, 0] active_set = idxs @@ -137,15 +137,15 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, explained_data = gain_dip @ sol M_estimate = whitener @ explained_data _log_exp_var(M, M_estimate) - tstep = np.median(np.diff(times)) if len(times) > 1 else 1. 
+    tstep = np.median(np.diff(times)) if len(times) > 1 else 1.0
     dipoles = _make_dipoles_sparse(
-        X, active_set, forward, times[0], tstep, M,
-        gain_active, active_is_idx=True)
+        X, active_set, forward, times[0], tstep, M, gain_active, active_is_idx=True
+    )
     for dipole, ori in zip(dipoles, oris):
         signs = np.sign((dipole.ori * ori).sum(-1, keepdims=True))
         dipole.ori *= signs
         dipole.amplitude *= signs[:, 0]
-    logger.info('[done]')
+    logger.info("[done]")
     return dipoles, explained_data


@@ -185,6 +185,7 @@ def _make_dipoles(times, poss, oris, sol, gof):
 def _compute_subcorr(G, phi_sig):
     """Compute the subspace correlation."""
     from scipy import linalg
+
     Ug, Sg, Vg = linalg.svd(G, full_matrices=False)
     # Now we look at the actual rank of the forward fields
     # in G and handle the fact that it might be rank deficient
@@ -204,13 +205,15 @@ def _compute_subcorr(G, phi_sig):
 def _compute_proj(A):
     """Compute the orthogonal projection operation for a manifold vector A."""
     from scipy import linalg
+
     U, _, _ = linalg.svd(A, full_matrices=False)
     return np.identity(A.shape[0]) - np.dot(U, U.T.conjugate())


 @verbose
-def rap_music(evoked, forward, noise_cov, n_dipoles=5, return_residual=False,
-              verbose=None):
+def rap_music(
+    evoked, forward, noise_cov, n_dipoles=5, return_residual=False, verbose=None
+):
     """RAP-MUSIC source localization method.

     Compute Recursively Applied and Projected MUltiple SIgnal Classification
@@ -269,16 +272,16 @@ def rap_music(evoked, forward, noise_cov, n_dipoles=5, return_residual=False,
     data = data[picks]

-    dipoles, explained_data = _apply_rap_music(data, info, times, forward,
-                                               noise_cov, n_dipoles,
-                                               picks)
+    dipoles, explained_data = _apply_rap_music(
+        data, info, times, forward, noise_cov, n_dipoles, picks
+    )

     if return_residual:
-        residual = evoked.copy().pick([info['ch_names'][p] for p in picks])
+        residual = evoked.copy().pick([info["ch_names"][p] for p in picks])
         residual.data -= explained_data
-        active_projs = [p for p in residual.info['projs'] if p['active']]
+        active_projs = [p for p in residual.info["projs"] if p["active"]]
         for p in active_projs:
-            p['active'] = False
+            p["active"] = False
         residual.add_proj(active_projs, remove_existing=True)
         residual.apply_proj()
     return dipoles, residual
diff --git a/mne/beamformer/resolution_matrix.py b/mne/beamformer/resolution_matrix.py
index 5294de5a621..278ae65692a 100644
--- a/mne/beamformer/resolution_matrix.py
+++ b/mne/beamformer/resolution_matrix.py
@@ -33,8 +33,8 @@ def make_lcmv_resolution_matrix(filters, forward, info):
         for free dipole orientation versus factor 1 for scalar beamformers).
     """
     # don't include bad channels from noise covariance matrix
-    bads_filt = filters['noise_cov']['bads']
-    ch_names = filters['noise_cov']['names']
+    bads_filt = filters["noise_cov"]["bads"]
+    ch_names = filters["noise_cov"]["names"]

     # good channels
     ch_names = [c for c in ch_names if (c not in bads_filt)]
@@ -43,7 +43,7 @@ def make_lcmv_resolution_matrix(filters, forward, info):
     forward = pick_channels_forward(forward, ch_names, ordered=True)

     # get leadfield matrix from forward solution
-    leadfield = forward['sol']['data']
+    leadfield = forward["sol"]["data"]

     # get the filter weights for beamformer as matrix
     filtmat = _get_matrix_from_lcmv(filters, forward, info)
@@ -53,7 +53,7 @@ def make_lcmv_resolution_matrix(filters, forward, info):

     shape = resmat.shape

-    logger.info('Dimensions of LCMV resolution matrix: %d by %d.' % shape)
+    logger.info("Dimensions of LCMV resolution matrix: %d by %d." % shape)

     return resmat

@@ -67,16 +67,15 @@ def _get_matrix_from_lcmv(filters, forward, info, verbose=None):
         Inverse matrix associated with LCMV beamformer filters.
     """
     # number of channels for identity matrix
-    info = pick_info(
-        info, pick_channels(info['ch_names'], filters['ch_names']))
-    n_chs = len(info['ch_names'])
+    info = pick_info(info, pick_channels(info["ch_names"], filters["ch_names"]))
+    n_chs = len(info["ch_names"])

     # create identity matrix as input for inverse operator
     # set elements to zero for non-selected channels
     id_mat = np.eye(n_chs)

     # convert identity matrix to evoked data type (pretending it's an epoch)
-    evo_ident = EvokedArray(id_mat, info=info, tmin=0.)
+    evo_ident = EvokedArray(id_mat, info=info, tmin=0.0)

     # apply beamformer to identity matrix
     stc_lcmv = apply_lcmv(evo_ident, filters, verbose=verbose)
diff --git a/mne/beamformer/tests/test_dics.py b/mne/beamformer/tests/test_dics.py
index 74d273a0b66..6bc18d81e3e 100644
--- a/mne/beamformer/tests/test_dics.py
+++ b/mne/beamformer/tests/test_dics.py
@@ -6,15 +6,20 @@
 import copy as cp

 import pytest
-from numpy.testing import (assert_array_equal, assert_allclose,
-                           assert_array_less)
+from numpy.testing import assert_array_equal, assert_allclose, assert_array_less
 import numpy as np

 import mne
 from mne import pick_types
-from mne.beamformer import (make_dics, apply_dics, apply_dics_epochs,
-                            apply_dics_tfr_epochs, apply_dics_csd,
-                            read_beamformer, Beamformer)
+from mne.beamformer import (
+    make_dics,
+    apply_dics,
+    apply_dics_epochs,
+    apply_dics_tfr_epochs,
+    apply_dics_csd,
+    read_beamformer,
+    Beamformer,
+)
 from mne.beamformer._compute_beamformer import _prepare_beamformer_input
 from mne.beamformer._dics import _prepare_noise_csd
 from mne.beamformer.tests.test_lcmv import _assert_weight_norm
@@ -24,47 +29,40 @@
 from mne.io.pick import pick_info
 from mne.proj import compute_proj_evoked, make_projector
 from mne.surface import _compute_nearest
-from mne.time_frequency import (CrossSpectralDensity, csd_morlet, EpochsTFR,
-                                csd_tfr)
+from mne.time_frequency import CrossSpectralDensity, csd_morlet, EpochsTFR, csd_tfr
 from mne.time_frequency.csd import _sym_mat_to_vector
 from mne.transforms import invert_transform, apply_trans
 from mne.utils import object_diff, requires_version, catch_logging

 data_path = testing.data_path(download=False)
 fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif"
-fname_fwd = (
-    data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif"
-)
-fname_fwd_vol = (
-    data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif"
-)
+fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif"
+fname_fwd_vol = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif"
 fname_event = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw-eve.fif"
 subjects_dir = data_path / "subjects"


-@pytest.fixture(scope='module', params=[testing._pytest_param()])
+@pytest.fixture(scope="module", params=[testing._pytest_param()])
 def _load_forward():
     """Load forward models."""
     fwd_free = mne.read_forward_solution(fname_fwd)
     fwd_free = mne.pick_types_forward(fwd_free, meg=True, eeg=False)
     fwd_free = mne.convert_forward_solution(fwd_free, surf_ori=False)
-    fwd_surf = mne.convert_forward_solution(fwd_free, surf_ori=True,
-                                            use_cps=False)
-    fwd_fixed = mne.convert_forward_solution(fwd_free, force_fixed=True,
-                                             use_cps=False)
+    fwd_surf = mne.convert_forward_solution(fwd_free, surf_ori=True, use_cps=False)
+    fwd_fixed = mne.convert_forward_solution(fwd_free, force_fixed=True, use_cps=False)
     fwd_vol = mne.read_forward_solution(fname_fwd_vol)
     return fwd_free, fwd_surf, fwd_fixed, fwd_vol


 def _simulate_data(fwd, idx):  # Somewhere on the frontal lobe by default
     """Simulate an oscillator on the cortex."""
-    pytest.importorskip('nibabel')
-    source_vertno = fwd['src'][0]['vertno'][idx]
+    pytest.importorskip("nibabel")
+    source_vertno = fwd["src"][0]["vertno"][idx]

-    sfreq = 50.  # Hz.
+    sfreq = 50.0  # Hz.
     times = np.arange(10 * sfreq) / sfreq  # 10 seconds of data
     signal = np.sin(20 * 2 * np.pi * times)  # 20 Hz oscillator
-    signal[:len(times) // 2] *= 2  # Make signal louder at the beginning
+    signal[: len(times) // 2] *= 2  # Make signal louder at the beginning
     signal *= 1e-9  # Scale to be in the ballpark of MEG data

     # Construct a SourceEstimate object that describes the signal at the
@@ -74,16 +72,16 @@ def _simulate_data(fwd, idx):  # Somewhere on the frontal lobe by default
         vertices=[[source_vertno], []],
         tmin=0,
         tstep=1 / sfreq,
-        subject='sample',
+        subject="sample",
     )

     # Create an info object that holds information about the sensors
-    info = mne.create_info(fwd['info']['ch_names'], sfreq, ch_types='grad')
+    info = mne.create_info(fwd["info"]["ch_names"], sfreq, ch_types="grad")
     with info._unlock():
-        info.update(fwd['info'])  # Merge in sensor position information
+        info.update(fwd["info"])  # Merge in sensor position information
     # heavily decimate sensors to make it much faster
-    info = mne.pick_info(info, np.arange(info['nchan'])[::5])
-    fwd = mne.pick_channels_forward(fwd, info['ch_names'])
+    info = mne.pick_info(info, np.arange(info["nchan"])[::5])
+    fwd = mne.pick_channels_forward(fwd, info["ch_names"])

     # Run the simulated signal through the forward model, obtaining
     # simulated sensor data.
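The simulation pattern this fixture relies on — place a 20 Hz sinusoid on a single cortical vertex, then project it through the forward model to the sensors — can be sketched in a few lines. This is an illustrative sketch, not part of the patch; it assumes `fwd` is a loaded forward solution and reuses the fixture's own names:

    import numpy as np
    import mne

    sfreq = 50.0  # Hz
    times = np.arange(10 * sfreq) / sfreq  # 10 s of data
    signal = 1e-9 * np.sin(20 * 2 * np.pi * times)  # 20 Hz source, ~nAm scale
    vertno = fwd["src"][0]["vertno"][0]  # one vertex on the left hemisphere
    stc = mne.SourceEstimate(
        signal[np.newaxis, :], vertices=[[vertno], []],
        tmin=0, tstep=1 / sfreq, subject="sample",
    )
    info = mne.create_info(fwd["info"]["ch_names"], sfreq, ch_types="grad")
    with info._unlock():
        info.update(fwd["info"])  # copy sensor locations from the forward
    raw_sim = mne.simulation.simulate_raw(info, stc, forward=fwd)

Passing a precomputed forward via `forward=` skips the BEM computation, which is what keeps this fixture fast.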
@@ -95,31 +93,39 @@ def _simulate_data(fwd, idx): # Somewhere on the frontal lobe by default raw._data += noise # Define a single epoch (weird baseline but shouldn't matter) - epochs = mne.Epochs(raw, [[0, 0, 1]], event_id=1, tmin=0, - tmax=raw.times[-1], baseline=(0., 0.), preload=True) + epochs = mne.Epochs( + raw, + [[0, 0, 1]], + event_id=1, + tmin=0, + tmax=raw.times[-1], + baseline=(0.0, 0.0), + preload=True, + ) evoked = epochs.average() # Compute the cross-spectral density matrix csd = csd_morlet(epochs, frequencies=[10, 20], n_cycles=[5, 10], decim=5) - labels = mne.read_labels_from_annot( - 'sample', hemi='lh', subjects_dir=subjects_dir) - label = [ - label for label in labels if np.in1d(source_vertno, label.vertices)[0]] + labels = mne.read_labels_from_annot("sample", hemi="lh", subjects_dir=subjects_dir) + label = [label for label in labels if np.in1d(source_vertno, label.vertices)[0]] assert len(label) == 1 label = label[0] - vertices = np.intersect1d(label.vertices, fwd['src'][0]['vertno']) + vertices = np.intersect1d(label.vertices, fwd["src"][0]["vertno"]) source_ind = vertices.tolist().index(source_vertno) assert vertices[source_ind] == source_vertno return epochs, evoked, csd, source_vertno, label, vertices, source_ind -idx_param = pytest.mark.parametrize('idx', [ - 0, - pytest.param(100, marks=pytest.mark.slowtest), - 200, - pytest.param(233, marks=pytest.mark.slowtest), -]) +idx_param = pytest.mark.parametrize( + "idx", + [ + 0, + pytest.param(100, marks=pytest.mark.slowtest), + 200, + pytest.param(233, marks=pytest.mark.slowtest), + ], +) def _rand_csd(rng, info): @@ -130,7 +136,7 @@ def _rand_csd(rng, info): data = data @ data.conj().T data *= scales data *= scales[:, np.newaxis] - data.flat[::n + 1] = scales + data.flat[:: n + 1] = scales return data @@ -141,67 +147,74 @@ def _make_rand_csd(info, csd): s, u = np.linalg.eigh(csd.get_data(csd.frequencies[0])) mask = np.abs(s) >= s[-1] * 1e-7 rank = mask.sum() - assert rank == len(data) == len(info['ch_names']) + assert rank == len(data) == len(info["ch_names"]) noise_csd = CrossSpectralDensity( - _sym_mat_to_vector(data), info['ch_names'], 0., csd.n_fft) + _sym_mat_to_vector(data), info["ch_names"], 0.0, csd.n_fft + ) return noise_csd, rank @pytest.mark.slowtest @testing.requires_testing_data -@requires_version('h5io') +@requires_version("h5io") @idx_param -@pytest.mark.parametrize('whiten', [ - pytest.param(False, marks=pytest.mark.slowtest), - True, -]) +@pytest.mark.parametrize( + "whiten", + [ + pytest.param(False, marks=pytest.mark.slowtest), + True, + ], +) def test_make_dics(tmp_path, _load_forward, idx, whiten): """Test making DICS beamformer filters.""" # We only test proper handling of parameters here. Testing the results is # done in test_apply_dics_timeseries and test_apply_dics_csd. 
fwd_free, fwd_surf, fwd_fixed, fwd_vol = _load_forward - epochs, _, csd, _, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) - with pytest.raises(ValueError, match='several sensor types'): + epochs, _, csd, _, label, vertices, source_ind = _simulate_data(fwd_fixed, idx) + with pytest.raises(ValueError, match="several sensor types"): make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None) if whiten: noise_csd, rank = _make_rand_csd(epochs.info, csd) - assert rank == len(epochs.info['ch_names']) == 62 + assert rank == len(epochs.info["ch_names"]) == 62 else: noise_csd = None - epochs.pick_types(meg='grad') + epochs.pick_types(meg="grad") with pytest.raises(ValueError, match="Invalid value for the 'pick_ori'"): - make_dics(epochs.info, fwd_fixed, csd, pick_ori="notexistent", - noise_csd=noise_csd) - with pytest.raises(ValueError, match='rank, if str'): - make_dics(epochs.info, fwd_fixed, csd, rank='foo', noise_csd=noise_csd) - with pytest.raises(TypeError, match='rank must be'): - make_dics(epochs.info, fwd_fixed, csd, rank=1., noise_csd=noise_csd) + make_dics( + epochs.info, fwd_fixed, csd, pick_ori="notexistent", noise_csd=noise_csd + ) + with pytest.raises(ValueError, match="rank, if str"): + make_dics(epochs.info, fwd_fixed, csd, rank="foo", noise_csd=noise_csd) + with pytest.raises(TypeError, match="rank must be"): + make_dics(epochs.info, fwd_fixed, csd, rank=1.0, noise_csd=noise_csd) # Test if fixed forward operator is detected when picking normal # orientation - with pytest.raises(ValueError, match='forward operator with free ori'): - make_dics(epochs.info, fwd_fixed, csd, pick_ori="normal", - noise_csd=noise_csd) + with pytest.raises(ValueError, match="forward operator with free ori"): + make_dics(epochs.info, fwd_fixed, csd, pick_ori="normal", noise_csd=noise_csd) # Test if non-surface oriented forward operator is detected when picking # normal orientation - with pytest.raises(ValueError, match='oriented in surface coordinates'): - make_dics(epochs.info, fwd_free, csd, pick_ori="normal", - noise_csd=noise_csd) + with pytest.raises(ValueError, match="oriented in surface coordinates"): + make_dics(epochs.info, fwd_free, csd, pick_ori="normal", noise_csd=noise_csd) # Test if volume forward operator is detected when picking normal # orientation - with pytest.raises(ValueError, match='oriented in surface coordinates'): - make_dics(epochs.info, fwd_vol, csd, pick_ori="normal", - noise_csd=noise_csd) + with pytest.raises(ValueError, match="oriented in surface coordinates"): + make_dics(epochs.info, fwd_vol, csd, pick_ori="normal", noise_csd=noise_csd) # Test invalid combinations of parameters - with pytest.raises(ValueError, match='reduce_rank cannot be used with'): - make_dics(epochs.info, fwd_free, csd, inversion='single', - reduce_rank=True, noise_csd=noise_csd) + with pytest.raises(ValueError, match="reduce_rank cannot be used with"): + make_dics( + epochs.info, + fwd_free, + csd, + inversion="single", + reduce_rank=True, + noise_csd=noise_csd, + ) # TODO: Restore this? 
# with pytest.raises(ValueError, match='not stable with depth'): # make_dics(epochs.info, fwd_free, csd, weight_norm='unit-noise-gain', @@ -209,83 +222,136 @@ def test_make_dics(tmp_path, _load_forward, idx, whiten): # Sanity checks on the returned filters n_freq = len(csd.frequencies) - vertices = np.intersect1d(label.vertices, fwd_free['src'][0]['vertno']) + vertices = np.intersect1d(label.vertices, fwd_free["src"][0]["vertno"]) n_verts = len(vertices) n_orient = 3 n_channels = len(epochs.ch_names) # Test return values - weight_norm = 'unit-noise-gain' - inversion = 'single' - filters = make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None, - weight_norm=weight_norm, depth=None, real_filter=False, - noise_csd=noise_csd, inversion=inversion) - assert filters['weights'].shape == (n_freq, n_verts * n_orient, n_channels) - assert np.iscomplexobj(filters['weights']) - assert filters['csd'].ch_names == epochs.ch_names - assert isinstance(filters['csd'], CrossSpectralDensity) - assert filters['ch_names'] == epochs.ch_names - assert_array_equal(filters['proj'], np.eye(n_channels)) - assert_array_equal(filters['vertices'][0], vertices) - assert_array_equal(filters['vertices'][1], []) # Label was on the LH - assert filters['subject'] == fwd_free['src']._subject - assert filters['pick_ori'] is None - assert filters['is_free_ori'] - assert filters['inversion'] == inversion - assert filters['weight_norm'] == weight_norm - assert 'DICS' in repr(filters) + weight_norm = "unit-noise-gain" + inversion = "single" + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori=None, + weight_norm=weight_norm, + depth=None, + real_filter=False, + noise_csd=noise_csd, + inversion=inversion, + ) + assert filters["weights"].shape == (n_freq, n_verts * n_orient, n_channels) + assert np.iscomplexobj(filters["weights"]) + assert filters["csd"].ch_names == epochs.ch_names + assert isinstance(filters["csd"], CrossSpectralDensity) + assert filters["ch_names"] == epochs.ch_names + assert_array_equal(filters["proj"], np.eye(n_channels)) + assert_array_equal(filters["vertices"][0], vertices) + assert_array_equal(filters["vertices"][1], []) # Label was on the LH + assert filters["subject"] == fwd_free["src"]._subject + assert filters["pick_ori"] is None + assert filters["is_free_ori"] + assert filters["inversion"] == inversion + assert filters["weight_norm"] == weight_norm + assert "DICS" in repr(filters) assert 'subject "sample"' in repr(filters) assert str(len(vertices)) in repr(filters) assert str(n_channels) in repr(filters) - assert 'rank' not in repr(filters) + assert "rank" not in repr(filters) _, noise_cov = _prepare_noise_csd(csd, noise_csd, real_filter=False) _, _, _, _, G, _, _, _ = _prepare_beamformer_input( - epochs.info, fwd_surf, label, 'vector', combine_xyz=False, exp=None, - noise_cov=noise_cov) + epochs.info, + fwd_surf, + label, + "vector", + combine_xyz=False, + exp=None, + noise_cov=noise_cov, + ) G.shape = (n_channels, n_verts, n_orient) G = G.transpose(1, 2, 0).conj() # verts, orient, ch _assert_weight_norm(filters, G) - inversion = 'matrix' - filters = make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None, - weight_norm=weight_norm, depth=None, - noise_csd=noise_csd, inversion=inversion) + inversion = "matrix" + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori=None, + weight_norm=weight_norm, + depth=None, + noise_csd=noise_csd, + inversion=inversion, + ) _assert_weight_norm(filters, G) - weight_norm = 
'unit-noise-gain-invariant' - inversion = 'single' - filters = make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None, - weight_norm=weight_norm, depth=None, - noise_csd=noise_csd, inversion=inversion) + weight_norm = "unit-noise-gain-invariant" + inversion = "single" + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori=None, + weight_norm=weight_norm, + depth=None, + noise_csd=noise_csd, + inversion=inversion, + ) _assert_weight_norm(filters, G) # Test picking orientations. Also test weight norming under these different # conditions. - weight_norm = 'unit-noise-gain' - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - pick_ori='normal', weight_norm=weight_norm, - depth=None, noise_csd=noise_csd, inversion=inversion) + weight_norm = "unit-noise-gain" + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="normal", + weight_norm=weight_norm, + depth=None, + noise_csd=noise_csd, + inversion=inversion, + ) n_orient = 1 - assert filters['weights'].shape == (n_freq, n_verts * n_orient, n_channels) - assert not filters['is_free_ori'] + assert filters["weights"].shape == (n_freq, n_verts * n_orient, n_channels) + assert not filters["is_free_ori"] _assert_weight_norm(filters, G) - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - pick_ori='max-power', weight_norm=weight_norm, - depth=None, noise_csd=noise_csd, inversion=inversion) + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="max-power", + weight_norm=weight_norm, + depth=None, + noise_csd=noise_csd, + inversion=inversion, + ) n_orient = 1 - assert filters['weights'].shape == (n_freq, n_verts * n_orient, n_channels) - assert not filters['is_free_ori'] + assert filters["weights"].shape == (n_freq, n_verts * n_orient, n_channels) + assert not filters["is_free_ori"] _assert_weight_norm(filters, G) # From here on, only work on a single frequency csd = csd[0] # Test using a real-valued filter - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - pick_ori='normal', real_filter=True, - noise_csd=noise_csd) - assert not np.iscomplexobj(filters['weights']) + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="normal", + real_filter=True, + noise_csd=noise_csd, + ) + assert not np.iscomplexobj(filters["weights"]) # Test forward normalization. When inversion='single', the power of a # unit-noise CSD should be 1, even without weight normalization. 
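Stripped of its assertions, the construction this test sweeps over is a single `make_dics` call followed by `apply_dics_csd`. A condensed sketch, using one arbitrary parameter combination from the sweep and assuming the fixture objects above (`epochs`, `fwd_surf`, `csd`, `noise_csd`):

    filters = make_dics(
        epochs.info,
        fwd_surf,
        csd,
        pick_ori="max-power",           # or None, "normal", "vector"
        inversion="single",             # or "matrix"
        weight_norm="unit-noise-gain",  # or "nai", "unit-noise-gain-invariant", None
        real_filter=False,              # True discards the imaginary part of the weights
        noise_csd=noise_csd,            # optional noise CSD used for whitening
    )
    source_power, freqs = apply_dics_csd(csd, filters)  # power per source and frequency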
@@ -294,105 +360,151 @@ def test_make_dics(tmp_path, _load_forward, idx, whiten): inds = np.triu_indices(csd.n_channels) # Using [:, :] syntax for in-place broadcasting csd_noise._data[:, :] = np.eye(csd.n_channels)[inds][:, np.newaxis] - filters = make_dics(epochs.info, fwd_surf, csd_noise, label=label, - weight_norm=None, depth=1., noise_csd=noise_csd, - inversion='single') - w = filters['weights'][0][:3] - assert_allclose(np.diag(w.dot(w.conjugate().T)), 1.0, rtol=1e-6, - atol=0) + filters = make_dics( + epochs.info, + fwd_surf, + csd_noise, + label=label, + weight_norm=None, + depth=1.0, + noise_csd=noise_csd, + inversion="single", + ) + w = filters["weights"][0][:3] + assert_allclose(np.diag(w.dot(w.conjugate().T)), 1.0, rtol=1e-6, atol=0) # Test turning off both forward and weight normalization - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - weight_norm=None, depth=None, noise_csd=noise_csd) - w = filters['weights'][0][:3] - assert not np.allclose(np.diag(w.dot(w.conjugate().T)), 1.0, - rtol=1e-2, atol=0) + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + weight_norm=None, + depth=None, + noise_csd=noise_csd, + ) + w = filters["weights"][0][:3] + assert not np.allclose(np.diag(w.dot(w.conjugate().T)), 1.0, rtol=1e-2, atol=0) # Test neural-activity-index weight normalization. It should be a scaled # version of the unit-noise-gain beamformer. filters_nai = make_dics( - epochs.info, fwd_surf, csd, label=label, pick_ori='max-power', - weight_norm='nai', depth=None, noise_csd=noise_csd) - w_nai = filters_nai['weights'][0] + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="max-power", + weight_norm="nai", + depth=None, + noise_csd=noise_csd, + ) + w_nai = filters_nai["weights"][0] filters_ung = make_dics( - epochs.info, fwd_surf, csd, label=label, pick_ori='max-power', - weight_norm='unit-noise-gain', depth=None, noise_csd=noise_csd) - w_ung = filters_ung['weights'][0] - assert_allclose(np.corrcoef(np.abs(w_nai).ravel(), - np.abs(w_ung).ravel()), 1, atol=1e-7) + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="max-power", + weight_norm="unit-noise-gain", + depth=None, + noise_csd=noise_csd, + ) + w_ung = filters_ung["weights"][0] + assert_allclose( + np.corrcoef(np.abs(w_nai).ravel(), np.abs(w_ung).ravel()), 1, atol=1e-7 + ) # Test whether spatial filter contains src_type - assert 'src_type' in filters + assert "src_type" in filters fname = tmp_path / "filters-dics.h5" filters.save(fname) filters_read = read_beamformer(fname) assert isinstance(filters, Beamformer) assert isinstance(filters_read, Beamformer) - for key in ['tmin', 'tmax']: # deal with strictness of object_diff - setattr(filters['csd'], key, np.float64(getattr(filters['csd'], key))) - assert object_diff(filters, filters_read) == '' + for key in ["tmin", "tmax"]: # deal with strictness of object_diff + setattr(filters["csd"], key, np.float64(getattr(filters["csd"], key))) + assert object_diff(filters, filters_read) == "" def _fwd_dist(power, fwd, vertices, source_ind, tidx=1): idx = np.argmax(power.data[:, tidx]) - rr_got = fwd['src'][0]['rr'][vertices[idx]] - rr_want = fwd['src'][0]['rr'][vertices[source_ind]] + rr_got = fwd["src"][0]["rr"][vertices[idx]] + rr_want = fwd["src"][0]["rr"][vertices[source_ind]] return np.linalg.norm(rr_got - rr_want) @idx_param -@pytest.mark.parametrize('inversion, weight_norm', [ - ('single', None), - ('matrix', 'unit-noise-gain'), -]) +@pytest.mark.parametrize( + "inversion, weight_norm", + [ + ("single", None), + ("matrix", 
"unit-noise-gain"), + ], +) def test_apply_dics_csd(_load_forward, idx, inversion, weight_norm): """Test applying a DICS beamformer to a CSD matrix.""" fwd_free, fwd_surf, fwd_fixed, _ = _load_forward - epochs, _, csd, source_vertno, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) + epochs, _, csd, source_vertno, label, vertices, source_ind = _simulate_data( + fwd_fixed, idx + ) reg = 1 # Lots of regularization for our toy dataset - with pytest.raises(ValueError, match='several sensor types'): + with pytest.raises(ValueError, match="several sensor types"): make_dics(epochs.info, fwd_free, csd) - epochs.pick_types(meg='grad') + epochs.pick_types(meg="grad") # Try different types of forward models - assert label.hemi == 'lh' + assert label.hemi == "lh" for fwd in [fwd_free, fwd_surf, fwd_fixed]: - filters = make_dics(epochs.info, fwd, csd, label=label, reg=reg, - inversion=inversion, weight_norm=weight_norm) + filters = make_dics( + epochs.info, + fwd, + csd, + label=label, + reg=reg, + inversion=inversion, + weight_norm=weight_norm, + ) power, f = apply_dics_csd(csd, filters) assert f == [10, 20] # Did we find the true source at 20 Hz? dist = _fwd_dist(power, fwd_free, vertices, source_ind) - assert dist == 0. + assert dist == 0.0 # Is the signal stronger at 20 Hz than 10? assert power.data[source_ind, 1] > power.data[source_ind, 0] -@pytest.mark.parametrize('pick_ori', [None, 'normal', 'max-power', 'vector']) -@pytest.mark.parametrize('inversion', ['single', 'matrix']) +@pytest.mark.parametrize("pick_ori", [None, "normal", "max-power", "vector"]) +@pytest.mark.parametrize("inversion", ["single", "matrix"]) @idx_param def test_apply_dics_ori_inv(_load_forward, pick_ori, inversion, idx): """Test picking different orientations and inversion modes.""" fwd_free, fwd_surf, fwd_fixed, fwd_vol = _load_forward - epochs, _, csd, source_vertno, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) - epochs.pick_types(meg='grad') - - reg_ = 5 if inversion == 'matrix' else 1 - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - reg=reg_, pick_ori=pick_ori, - inversion=inversion, depth=None, - weight_norm='unit-noise-gain') + epochs, _, csd, source_vertno, label, vertices, source_ind = _simulate_data( + fwd_fixed, idx + ) + epochs.pick_types(meg="grad") + + reg_ = 5 if inversion == "matrix" else 1 + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + reg=reg_, + pick_ori=pick_ori, + inversion=inversion, + depth=None, + weight_norm="unit-noise-gain", + ) power, f = apply_dics_csd(csd, filters) assert f == [10, 20] dist = _fwd_dist(power, fwd_surf, vertices, source_ind) # This is 0. for unit-noise-gain-invariant: - assert dist <= (0.02 if inversion == 'matrix' else 0.) + assert dist <= (0.02 if inversion == "matrix" else 0.0) assert power.data[source_ind, 1] > power.data[source_ind, 0] # Test unit-noise-gain weighting @@ -400,40 +512,55 @@ def test_apply_dics_ori_inv(_load_forward, pick_ori, inversion, idx): inds = np.triu_indices(csd.n_channels) csd_noise._data[...] = np.eye(csd.n_channels)[inds][:, np.newaxis] noise_power, f = apply_dics_csd(csd_noise, filters) - want_norm = 3 if pick_ori in (None, 'vector') else 1 + want_norm = 3 if pick_ori in (None, "vector") else 1 assert_allclose(noise_power.data, want_norm, atol=1e-7) # Test filter with forward normalization instead of weight # normalization - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - reg=reg_, pick_ori=pick_ori, - inversion=inversion, weight_norm=None, - depth=1.) 
+ filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + reg=reg_, + pick_ori=pick_ori, + inversion=inversion, + weight_norm=None, + depth=1.0, + ) power, f = apply_dics_csd(csd, filters) assert f == [10, 20] dist = _fwd_dist(power, fwd_surf, vertices, source_ind) mat_tol = {0: 0.055, 100: 0.20, 200: 0.015, 233: 0.035}[idx] - max_ = (mat_tol if inversion == 'matrix' else 0.) + max_ = mat_tol if inversion == "matrix" else 0.0 assert 0 <= dist <= max_ assert power.data[source_ind, 1] > power.data[source_ind, 0] def _nearest_vol_ind(fwd_vol, fwd, vertices, source_ind): return _compute_nearest( - fwd_vol['source_rr'], - fwd['src'][0]['rr'][vertices][source_ind][np.newaxis])[0] + fwd_vol["source_rr"], fwd["src"][0]["rr"][vertices][source_ind][np.newaxis] + )[0] @idx_param def test_real(_load_forward, idx): """Test using a real-valued filter.""" fwd_free, fwd_surf, fwd_fixed, fwd_vol = _load_forward - epochs, _, csd, source_vertno, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) - epochs.pick_types(meg='grad') + epochs, _, csd, source_vertno, label, vertices, source_ind = _simulate_data( + fwd_fixed, idx + ) + epochs.pick_types(meg="grad") reg = 1 # Lots of regularization for our toy dataset - filters_real = make_dics(epochs.info, fwd_surf, csd, label=label, reg=reg, - real_filter=True, inversion='single') + filters_real = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + reg=reg, + real_filter=True, + inversion="single", + ) # Also test here that no warnings are thrown - implemented to check whether # src should not be None warning occurs: power, f = apply_dics_csd(csd, filters_real) @@ -444,9 +571,16 @@ def test_real(_load_forward, idx): assert power.data[source_ind, 1] > power.data[source_ind, 0] # Test rank reduction - filters_real = make_dics(epochs.info, fwd_surf, csd, label=label, reg=5, - pick_ori='max-power', inversion='matrix', - reduce_rank=True) + filters_real = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + reg=5, + pick_ori="max-power", + inversion="matrix", + reduce_rank=True, + ) power, f = apply_dics_csd(csd, filters_real) assert f == [10, 20] dist = _fwd_dist(power, fwd_surf, vertices, source_ind) @@ -454,57 +588,58 @@ def test_real(_load_forward, idx): assert power.data[source_ind, 1] > power.data[source_ind, 0] # Test computing source power on a volume source space - filters_vol = make_dics(epochs.info, fwd_vol, csd, reg=reg, - inversion='single') + filters_vol = make_dics(epochs.info, fwd_vol, csd, reg=reg, inversion="single") power, f = apply_dics_csd(csd, filters_vol) vol_source_ind = _nearest_vol_ind(fwd_vol, fwd_surf, vertices, source_ind) assert f == [10, 20] - dist = _fwd_dist( - power, fwd_vol, fwd_vol['src'][0]['vertno'], vol_source_ind) + dist = _fwd_dist(power, fwd_vol, fwd_vol["src"][0]["vertno"], vol_source_ind) vol_tols = {100: 0.008, 200: 0.008} - assert dist <= vol_tols.get(idx, 0.) 
+ assert dist <= vol_tols.get(idx, 0.0) assert power.data[vol_source_ind, 1] > power.data[vol_source_ind, 0] # check whether a filters object without src_type throws expected warning - del filters_vol['src_type'] # emulate 0.16 behaviour to cause warning - with pytest.warns(RuntimeWarning, match='spatial filter does not contain ' - 'src_type'): + del filters_vol["src_type"] # emulate 0.16 behaviour to cause warning + with pytest.warns( + RuntimeWarning, match="spatial filter does not contain " "src_type" + ): apply_dics_csd(csd, filters_vol) -@pytest.mark.filterwarnings("ignore:The use of several sensor types with the" - ":RuntimeWarning") +@pytest.mark.filterwarnings( + "ignore:The use of several sensor types with the" ":RuntimeWarning" +) @idx_param def test_apply_dics_timeseries(_load_forward, idx): """Test DICS applied to timeseries data.""" fwd_free, fwd_surf, fwd_fixed, fwd_vol = _load_forward - epochs, evoked, csd, source_vertno, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) + epochs, evoked, csd, source_vertno, label, vertices, source_ind = _simulate_data( + fwd_fixed, idx + ) reg = 5 # Lots of regularization for our toy dataset - with pytest.raises(ValueError, match='several sensor types'): + with pytest.raises(ValueError, match="several sensor types"): make_dics(evoked.info, fwd_surf, csd) - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") - multiple_filters = make_dics(evoked.info, fwd_surf, csd, label=label, - reg=reg) + multiple_filters = make_dics(evoked.info, fwd_surf, csd, label=label, reg=reg) # Sanity checks on the resulting STC after applying DICS on evoked stcs = apply_dics(evoked, multiple_filters) assert isinstance(stcs, list) - assert len(stcs) == len(multiple_filters['weights']) - assert_array_equal(stcs[0].vertices[0], multiple_filters['vertices'][0]) - assert_array_equal(stcs[0].vertices[1], multiple_filters['vertices'][1]) + assert len(stcs) == len(multiple_filters["weights"]) + assert_array_equal(stcs[0].vertices[0], multiple_filters["vertices"][0]) + assert_array_equal(stcs[0].vertices[1], multiple_filters["vertices"][1]) assert_allclose(stcs[0].times, evoked.times) # Applying filters for multiple frequencies on epoch data should fail - with pytest.raises(ValueError, match='computed for a single frequency'): + with pytest.raises(ValueError, match="computed for a single frequency"): apply_dics_epochs(epochs, multiple_filters) # From now on, only apply filters with a single frequency (20 Hz). csd20 = csd.pick_frequency(20) - filters = make_dics(evoked.info, fwd_surf, csd20, label=label, reg=reg, - inversion='single') + filters = make_dics( + evoked.info, fwd_surf, csd20, label=label, reg=reg, inversion="single" + ) # Sanity checks on the resulting STC after applying DICS on epochs. # Also test here that no warnings are thrown - implemented to check whether @@ -513,8 +648,8 @@ def test_apply_dics_timeseries(_load_forward, idx): assert isinstance(stcs, list) assert len(stcs) == 1 - assert_array_equal(stcs[0].vertices[0], filters['vertices'][0]) - assert_array_equal(stcs[0].vertices[1], filters['vertices'][1]) + assert_array_equal(stcs[0].vertices[0], filters["vertices"][0]) + assert_array_equal(stcs[0].vertices[1], filters["vertices"][1]) assert_allclose(stcs[0].times, epochs.times) # Did we find the source? 
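The single-frequency restriction checked above dictates the usual time-series workflow: reduce the CSD to one frequency, build the filter, then apply it to evoked or epochs data. A sketch reusing the fixture names (illustrative, not part of the patch):

    csd20 = csd.pick_frequency(20)  # epochs/evoked need a single-frequency filter
    filters = make_dics(evoked.info, fwd_surf, csd20, reg=5, inversion="single")
    stc_evoked = apply_dics(evoked, filters)  # one SourceEstimate
    stcs_epochs = apply_dics_epochs(epochs, filters, return_generator=True)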
@@ -524,14 +659,14 @@ def test_apply_dics_timeseries(_load_forward, idx): # Apply filters to evoked stc = apply_dics(evoked, filters) - stc = (stc ** 2).mean() + stc = (stc**2).mean() dist = _fwd_dist(stc, fwd_surf, vertices, source_ind, tidx=0) assert dist == 0 # Test if wrong channel selection is detected in application of filter evoked_ch = cp.deepcopy(evoked) evoked_ch.pick_channels(evoked_ch.ch_names[:-1]) - with pytest.raises(ValueError, match='MEG 2633 which is not present'): + with pytest.raises(ValueError, match="MEG 2633 which is not present"): apply_dics(evoked_ch, filters) # Test whether projections are applied, by adding a custom projection @@ -542,13 +677,13 @@ def test_apply_dics_timeseries(_load_forward, idx): proj_matrix = make_projector(p, evoked_proj.ch_names)[0] evoked_proj.add_proj(p) filters_proj = make_dics(evoked_proj.info, fwd_surf, csd20, label=label) - assert_array_equal(filters_proj['proj'], proj_matrix) + assert_array_equal(filters_proj["proj"], proj_matrix) stc_proj = apply_dics(evoked_proj, filters_proj) assert np.any(np.not_equal(stc_noproj.data, stc_proj.data)) # Test detecting incompatible projections - filters_proj['proj'] = filters_proj['proj'][:-1, :-1] - with pytest.raises(ValueError, match='operands could not be broadcast'): + filters_proj["proj"] = filters_proj["proj"][:-1, :-1] + with pytest.raises(ValueError, match="operands could not be broadcast"): apply_dics(evoked_proj, filters_proj) # Test returning a generator @@ -557,30 +692,28 @@ def test_apply_dics_timeseries(_load_forward, idx): assert_array_equal(stcs[0].data, next(stcs_gen).data) # Test computing timecourses on a volume source space - filters_vol = make_dics(evoked.info, fwd_vol, csd20, reg=reg, - inversion='single') + filters_vol = make_dics(evoked.info, fwd_vol, csd20, reg=reg, inversion="single") stc = apply_dics(evoked, filters_vol) - stc = (stc ** 2).mean() + stc = (stc**2).mean() assert stc.data.shape[1] == 1 vol_source_ind = _nearest_vol_ind(fwd_vol, fwd_surf, vertices, source_ind) - dist = _fwd_dist(stc, fwd_vol, fwd_vol['src'][0]['vertno'], vol_source_ind, - tidx=0) + dist = _fwd_dist(stc, fwd_vol, fwd_vol["src"][0]["vertno"], vol_source_ind, tidx=0) vol_tols = {100: 0.008, 200: 0.015} - vol_tol = vol_tols.get(idx, 0.) + vol_tol = vol_tols.get(idx, 0.0) assert dist <= vol_tol # check whether a filters object without src_type throws expected warning - del filters_vol['src_type'] # emulate 0.16 behaviour to cause warning - with pytest.warns(RuntimeWarning, match='filter does not contain src_typ'): + del filters_vol["src_type"] # emulate 0.16 behaviour to cause warning + with pytest.warns(RuntimeWarning, match="filter does not contain src_typ"): apply_dics_epochs(epochs, filters_vol) @testing.requires_testing_data -@pytest.mark.parametrize('return_generator', (True, False)) +@pytest.mark.parametrize("return_generator", (True, False)) def test_apply_dics_tfr(return_generator): """Test DICS applied to time-frequency objects.""" info = read_info(fname_raw) - info = pick_info(info, pick_types(info, meg='grad')) + info = pick_info(info, pick_types(info, meg="grad")) forward = mne.read_forward_solution(fname_fwd) rng = np.random.default_rng(11) @@ -589,7 +722,7 @@ def test_apply_dics_tfr(return_generator): n_chans = len(info.ch_names) freqs = [8, 9] n_times = 300 - times = np.arange(n_times) / info['sfreq'] + times = np.arange(n_times) / info["sfreq"] data = rng.random((n_epochs, n_chans, len(freqs), n_times)) data *= 1e-6 data = data + data * 1j # add imag. 
component to simulate phase @@ -606,18 +739,23 @@ def test_apply_dics_tfr(return_generator): assert_allclose(stcs[0][0].times, times) assert len(stcs) == len(epochs_tfr) # check same number of epochs assert all([len(s) == len(freqs) for s in stcs]) # check nested freqs - assert all([s.data.shape == (forward['nsource'], n_times) - for these_stcs in stcs for s in these_stcs]) + assert all( + [ + s.data.shape == (forward["nsource"], n_times) + for these_stcs in stcs + for s in these_stcs + ] + ) # Compute power from the source space TFR. This should yield the same # result as the apply_dics_csd function. - source_power = np.zeros((forward['nsource'], len(freqs))) + source_power = np.zeros((forward["nsource"], len(freqs))) for stcs_epoch in stcs: for i, stc_freq in enumerate(stcs_epoch): power = (stc_freq.data * np.conj(stc_freq.data)).real power = power.mean(axis=-1) # mean over time # Scaling by sampling frequency for compatibility with Matlab - power /= epochs_tfr.info['sfreq'] + power /= epochs_tfr.info["sfreq"] source_power[:, i] += power.T source_power /= n_epochs @@ -628,86 +766,111 @@ def test_apply_dics_tfr(return_generator): # Test that real-value only data fails, due to non-linearity of computing # power, it is recommended to transform to source-space first before # converting to power. - with pytest.raises(RuntimeError, - match='Time-frequency data must be complex'): + with pytest.raises(RuntimeError, match="Time-frequency data must be complex"): epochs_tfr_real = epochs_tfr.copy() epochs_tfr_real.data = epochs_tfr_real.data.real stcs = apply_dics_tfr_epochs(epochs_tfr_real, filters) filters_vector = filters.copy() - filters_vector['pick_ori'] = 'vector' - with pytest.warns(match='vector solution'): + filters_vector["pick_ori"] = "vector" + with pytest.warns(match="vector solution"): apply_dics_tfr_epochs(epochs_tfr, filters_vector) def _cov_as_csd(cov, info): rng = np.random.RandomState(0) - assert cov['data'].ndim == 2 - assert len(cov['data']) == len(cov['names']) + assert cov["data"].ndim == 2 + assert len(cov["data"]) == len(cov["names"]) # we need to make this have at least some complex structure - data = cov['data'] + 1e-1 * _rand_csd(rng, info) + data = cov["data"] + 1e-1 * _rand_csd(rng, info) assert data.dtype == np.complex128 - return CrossSpectralDensity(_sym_mat_to_vector(data), cov['names'], 0., 16) + return CrossSpectralDensity(_sym_mat_to_vector(data), cov["names"], 0.0, 16) # Just test free ori here (assume fixed is same as LCMV if these are) # Changes here should be synced with test_lcmv.py @pytest.mark.slowtest @pytest.mark.parametrize( - 'reg, pick_ori, weight_norm, use_cov, depth, lower, upper, real_filter', [ - (0.05, 'vector', 'unit-noise-gain-invariant', - False, None, 26, 28, True), - (0.05, 'vector', 'unit-noise-gain', False, None, 13, 15, True), - (0.05, 'vector', 'nai', False, None, 13, 15, True), - (0.05, None, 'unit-noise-gain-invariant', False, None, 26, 28, False), - (0.05, None, 'unit-noise-gain-invariant', True, None, 40, 42, False), - (0.05, None, 'unit-noise-gain-invariant', True, None, 40, 42, True), - (0.05, None, 'unit-noise-gain', False, None, 13, 14, False), - (0.05, None, 'unit-noise-gain', True, None, 35, 37, False), - (0.05, None, 'nai', True, None, 35, 37, False), + "reg, pick_ori, weight_norm, use_cov, depth, lower, upper, real_filter", + [ + (0.05, "vector", "unit-noise-gain-invariant", False, None, 26, 28, True), + (0.05, "vector", "unit-noise-gain", False, None, 13, 15, True), + (0.05, "vector", "nai", False, None, 13, 15, True), + 
(0.05, None, "unit-noise-gain-invariant", False, None, 26, 28, False), + (0.05, None, "unit-noise-gain-invariant", True, None, 40, 42, False), + (0.05, None, "unit-noise-gain-invariant", True, None, 40, 42, True), + (0.05, None, "unit-noise-gain", False, None, 13, 14, False), + (0.05, None, "unit-noise-gain", True, None, 35, 37, False), + (0.05, None, "nai", True, None, 35, 37, False), (0.05, None, None, True, None, 12, 14, False), (0.05, None, None, True, 0.8, 39, 43, False), - (0.05, 'max-power', 'unit-noise-gain-invariant', False, None, 17, 20, - False), - (0.05, 'max-power', 'unit-noise-gain', False, None, 17, 20, False), - (0.05, 'max-power', 'unit-noise-gain', False, None, 17, 20, True), - (0.05, 'max-power', 'nai', True, None, 21, 24, False), - (0.05, 'max-power', None, True, None, 7, 10, False), - (0.05, 'max-power', None, True, 0.8, 15, 18, False), + (0.05, "max-power", "unit-noise-gain-invariant", False, None, 17, 20, False), + (0.05, "max-power", "unit-noise-gain", False, None, 17, 20, False), + (0.05, "max-power", "unit-noise-gain", False, None, 17, 20, True), + (0.05, "max-power", "nai", True, None, 21, 24, False), + (0.05, "max-power", None, True, None, 7, 10, False), + (0.05, "max-power", None, True, 0.8, 15, 18, False), # skip most no-reg tests, assume others are equal to LCMV if these are (0.00, None, None, True, None, 21, 32, False), - (0.00, 'max-power', None, True, None, 13, 19, False), - ]) -def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm, - use_cov, depth, lower, upper, real_filter): + (0.00, "max-power", None, True, None, 13, 19, False), + ], +) +def test_localization_bias_free( + bias_params_free, + reg, + pick_ori, + weight_norm, + use_cov, + depth, + lower, + upper, + real_filter, +): """Test localization bias for free-orientation DICS.""" evoked, fwd, noise_cov, data_cov, want = bias_params_free noise_csd = _cov_as_csd(noise_cov, evoked.info) data_csd = _cov_as_csd(data_cov, evoked.info) del noise_cov, data_cov if not use_cov: - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") noise_csd = None filters = make_dics( - evoked.info, fwd, data_csd, reg, noise_csd, pick_ori=pick_ori, - weight_norm=weight_norm, depth=depth, real_filter=real_filter) + evoked.info, + fwd, + data_csd, + reg, + noise_csd, + pick_ori=pick_ori, + weight_norm=weight_norm, + depth=depth, + real_filter=real_filter, + ) loc = apply_dics(evoked, filters).data - loc = np.linalg.norm(loc, axis=1) if pick_ori == 'vector' else np.abs(loc) + loc = np.linalg.norm(loc, axis=1) if pick_ori == "vector" else np.abs(loc) # Compute the percentage of sources for which there is no loc bias: perc = (want == np.argmax(loc, axis=0)).mean() * 100 assert lower <= perc <= upper @pytest.mark.parametrize( - 'weight_norm, lower, upper, lower_ori, upper_ori, real_filter', [ - ('unit-noise-gain-invariant', 57, 58, 0.60, 0.61, False), - ('unit-noise-gain', 57, 58, 0.60, 0.61, False), - ('unit-noise-gain', 57, 58, 0.60, 0.61, True), + "weight_norm, lower, upper, lower_ori, upper_ori, real_filter", + [ + ("unit-noise-gain-invariant", 57, 58, 0.60, 0.61, False), + ("unit-noise-gain", 57, 58, 0.60, 0.61, False), + ("unit-noise-gain", 57, 58, 0.60, 0.61, True), (None, 27, 28, 0.56, 0.57, False), - ]) -def test_orientation_max_power(bias_params_fixed, bias_params_free, - weight_norm, lower, upper, lower_ori, upper_ori, - real_filter): + ], +) +def test_orientation_max_power( + bias_params_fixed, + bias_params_free, + weight_norm, + lower, + upper, + lower_ori, + upper_ori, + real_filter, 
+): """Test orientation selection for bias for max-power DICS.""" # we simulate data for the fixed orientation forward and beamform using # the free orientation forward, and check the orientation match at the end @@ -716,11 +879,19 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, data_csd = _cov_as_csd(data_cov, evoked.info) del data_cov, noise_cov fwd = bias_params_free[1] - filters = make_dics(evoked.info, fwd, data_csd, 0.05, noise_csd, - pick_ori='max-power', weight_norm=weight_norm, - depth=None, real_filter=real_filter) + filters = make_dics( + evoked.info, + fwd, + data_csd, + 0.05, + noise_csd, + pick_ori="max-power", + weight_norm=weight_norm, + depth=None, + real_filter=real_filter, + ) loc = np.abs(apply_dics(evoked, filters).data) - ori = filters['max_power_ori'][0] + ori = filters["max_power_ori"][0] assert ori.shape == (246, 3) loc = np.abs(loc) # Compute the percentage of sources for which there is no loc bias: @@ -730,11 +901,10 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, assert lower <= perc <= upper # Compute the dot products of our forward normals and # assert we get some hopefully reasonable agreement - assert fwd['coord_frame'] == FIFF.FIFFV_COORD_HEAD - nn = np.concatenate( - [s['nn'][v] for s, v in zip(fwd['src'], filters['vertices'])]) + assert fwd["coord_frame"] == FIFF.FIFFV_COORD_HEAD + nn = np.concatenate([s["nn"][v] for s, v in zip(fwd["src"], filters["vertices"])]) nn = nn[want] - nn = apply_trans(invert_transform(fwd['mri_head_t']), nn, move=False) + nn = apply_trans(invert_transform(fwd["mri_head_t"]), nn, move=False) assert_allclose(np.linalg.norm(nn, axis=1), 1, atol=1e-6) assert_allclose(np.linalg.norm(ori, axis=1), 1, atol=1e-12) dots = np.abs((nn[mask] * ori[mask]).sum(-1)) @@ -746,40 +916,46 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, @testing.requires_testing_data @idx_param -@pytest.mark.parametrize('whiten', (False, True)) +@pytest.mark.parametrize("whiten", (False, True)) def test_make_dics_rank(_load_forward, idx, whiten): """Test making DICS beamformer filters with rank param.""" _, fwd_surf, fwd_fixed, _ = _load_forward epochs, _, csd, _, label, _, _ = _simulate_data(fwd_fixed, idx) if whiten: noise_csd, want_rank = _make_rand_csd(epochs.info, csd) - kind = 'mag + grad' + kind = "mag + grad" else: noise_csd = None - epochs.pick_types(meg='grad') + epochs.pick_types(meg="grad") want_rank = len(epochs.ch_names) assert want_rank == 41 - kind = 'grad' + kind = "grad" with catch_logging() as log: filters = make_dics( - epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, - verbose=True) + epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, verbose=True + ) log = log.getvalue() - assert f'Estimated rank ({kind}): {want_rank}' in log, log + assert f"Estimated rank ({kind}): {want_rank}" in log, log stc, _ = apply_dics_csd(csd, filters) other_rank = want_rank - 1 # shouldn't make a huge difference use_rank = dict(meg=other_rank) if not whiten: # XXX it's a bug that our rank functions don't treat "meg" # properly here... 
- use_rank['grad'] = use_rank.pop('meg') + use_rank["grad"] = use_rank.pop("meg") with catch_logging() as log: filters_2 = make_dics( - epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, - rank=use_rank, verbose=True) + epochs.info, + fwd_surf, + csd, + label=label, + noise_csd=noise_csd, + rank=use_rank, + verbose=True, + ) log = log.getvalue() - assert f'Computing rank from covariance with rank={use_rank}' in log, log + assert f"Computing rank from covariance with rank={use_rank}" in log, log stc_2, _ = apply_dics_csd(csd, filters_2) corr = np.corrcoef(stc_2.data.ravel(), stc.data.ravel())[0, 1] assert 0.8 < corr < 0.999999 @@ -787,10 +963,15 @@ def test_make_dics_rank(_load_forward, idx, whiten): # degenerate conditions if whiten: # make rank deficient - data = noise_csd.get_data(0.) + data = noise_csd.get_data(0.0) data[0] = data[:0] = 0 noise_csd._data[:, 0] = _sym_mat_to_vector(data) - with pytest.raises(ValueError, match='meg data rank.*the noise rank'): + with pytest.raises(ValueError, match="meg data rank.*the noise rank"): filters = make_dics( - epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, - verbose=True) + epochs.info, + fwd_surf, + csd, + label=label, + noise_csd=noise_csd, + verbose=True, + ) diff --git a/mne/beamformer/tests/test_external.py b/mne/beamformer/tests/test_external.py index a20cb3b3e79..6195f572bae 100644 --- a/mne/beamformer/tests/test_external.py +++ b/mne/beamformer/tests/test_external.py @@ -17,19 +17,15 @@ ft_data_path = data_path / "fieldtrip" / "beamformer" fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) -fname_fwd_vol = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" +fname_fwd_vol = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" fname_event = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw-eve.fif" fname_label = data_path / "MEG" / "sample" / "labels" / "Aud-lh.label" reject = dict(grad=4000e-13, mag=4e-12) -@pytest.fixture(scope='function', params=[testing._pytest_param()]) +@pytest.fixture(scope="function", params=[testing._pytest_param()]) def _get_bf_data(save_fieldtrip=False): raw, epochs, evoked, data_cov, _, _, _, _, _, fwd = _get_data(proj=False) @@ -38,28 +34,29 @@ def _get_bf_data(save_fieldtrip=False): raw.save(ft_data_path / "raw.fif", overwrite=True) # src (tris are not available in fwd['src'] once imported into MATLAB) - src = fwd['src'].copy() + src = fwd["src"].copy() mne.write_source_spaces( - ft_data_path / "src.fif", src, verbose='error', overwrite=True + ft_data_path / "src.fif", src, verbose="error", overwrite=True ) # pick gradiometers only: - epochs.pick_types(meg='grad') - evoked.pick_types(meg='grad') + epochs.pick_types(meg="grad") + evoked.pick_types(meg="grad") # compute covariance matrix (ignore false alarm about no baseline) - data_cov = mne.compute_covariance(epochs, tmin=0.04, tmax=0.145, - method='empirical', verbose='error') + data_cov = mne.compute_covariance( + epochs, tmin=0.04, tmax=0.145, method="empirical", verbose="error" + ) if save_fieldtrip is True: # if the covariance matrix and epochs need resaving: # data covariance: cov_savepath = ft_data_path / "sample_cov.mat" - sample_cov = {'sample_cov': data_cov['data']} + sample_cov = {"sample_cov": 
data_cov["data"]} savemat(cov_savepath, sample_cov) # evoked data: ev_savepath = ft_data_path / "sample_evoked.mat" - data_ev = {'sample_evoked': evoked.data} + data_ev = {"sample_evoked": evoked.data} savemat(ev_savepath, data_ev) return evoked, data_cov, fwd @@ -67,23 +64,33 @@ def _get_bf_data(save_fieldtrip=False): # beamformer types to be tested: unit-gain (vector and scalar) and # unit-noise-gain (time series and power output [apply_lcmv_cov]) -@requires_version('pymatreader') -@pytest.mark.parametrize('bf_type, weight_norm, pick_ori, pwr', [ - ['ug_scal', None, 'max-power', False], - ['ung', 'unit-noise-gain', 'max-power', False], - ['ung_pow', 'unit-noise-gain', 'max-power', True], - ['ug_vec', None, 'vector', False], - ['ung_vec', 'unit-noise-gain', 'vector', False], -]) +@requires_version("pymatreader") +@pytest.mark.parametrize( + "bf_type, weight_norm, pick_ori, pwr", + [ + ["ug_scal", None, "max-power", False], + ["ung", "unit-noise-gain", "max-power", False], + ["ung_pow", "unit-noise-gain", "max-power", True], + ["ug_vec", None, "vector", False], + ["ung_vec", "unit-noise-gain", "vector", False], + ], +) def test_lcmv_fieldtrip(_get_bf_data, bf_type, weight_norm, pick_ori, pwr): """Test LCMV vs fieldtrip output.""" from pymatreader import read_mat + evoked, data_cov, fwd = _get_bf_data # run the MNE-Python beamformer - filters = make_lcmv(evoked.info, fwd, data_cov=data_cov, - noise_cov=None, pick_ori=pick_ori, reg=0.05, - weight_norm=weight_norm) + filters = make_lcmv( + evoked.info, + fwd, + data_cov=data_cov, + noise_cov=None, + pick_ori=pick_ori, + reg=0.05, + weight_norm=weight_norm, + ) if pwr: stc_mne = apply_lcmv_cov(data_cov, filters) else: @@ -91,18 +98,21 @@ def test_lcmv_fieldtrip(_get_bf_data, bf_type, weight_norm, pick_ori, pwr): # load the FieldTrip output ft_fname = ft_data_path / ("ft_source_" + bf_type + "-vol.mat") - stc_ft_data = read_mat(ft_fname)['stc'] + stc_ft_data = read_mat(ft_fname)["stc"] if stc_ft_data.ndim == 1: stc_ft_data.shape = (stc_ft_data.size, 1) if stc_mne.data.ndim == 2: signs = np.sign((stc_mne.data * stc_ft_data).sum(-1, keepdims=True)) if pwr: - assert_array_equal(signs, 1.) 
+ assert_array_equal(signs, 1.0) stc_mne.data *= signs assert stc_ft_data.shape == stc_mne.data.shape - if pick_ori == 'vector': + if pick_ori == "vector": # compare norms first - assert_allclose(np.linalg.norm(stc_mne.data, axis=1), - np.linalg.norm(stc_ft_data, axis=1), rtol=1e-6) + assert_allclose( + np.linalg.norm(stc_mne.data, axis=1), + np.linalg.norm(stc_ft_data, axis=1), + rtol=1e-6, + ) assert_allclose(stc_mne.data, stc_ft_data, rtol=1e-6) diff --git a/mne/beamformer/tests/test_lcmv.py b/mne/beamformer/tests/test_lcmv.py index 7f8e654c9bf..ae7a64f844e 100644 --- a/mne/beamformer/tests/test_lcmv.py +++ b/mne/beamformer/tests/test_lcmv.py @@ -5,17 +5,35 @@ import numpy as np from scipy import linalg from scipy.spatial.distance import cdist -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_allclose, assert_array_less) +from numpy.testing import ( + assert_array_almost_equal, + assert_array_equal, + assert_allclose, + assert_array_less, +) import mne from mne.transforms import apply_trans, invert_transform -from mne import (convert_forward_solution, read_forward_solution, compute_rank, - VolVectorSourceEstimate, VolSourceEstimate, EvokedArray, - pick_channels_cov, read_vectorview_selection) -from mne.beamformer import (make_lcmv, apply_lcmv, apply_lcmv_epochs, - apply_lcmv_raw, Beamformer, - read_beamformer, apply_lcmv_cov, make_dics) +from mne import ( + convert_forward_solution, + read_forward_solution, + compute_rank, + VolVectorSourceEstimate, + VolSourceEstimate, + EvokedArray, + pick_channels_cov, + read_vectorview_selection, +) +from mne.beamformer import ( + make_lcmv, + apply_lcmv, + apply_lcmv_epochs, + apply_lcmv_raw, + Beamformer, + read_beamformer, + apply_lcmv_cov, + make_dics, +) from mne.beamformer._compute_beamformer import _prepare_beamformer_input from mne.datasets import testing from mne.io.compensator import set_current_comp @@ -23,19 +41,14 @@ from mne.minimum_norm import make_inverse_operator, apply_inverse from mne.minimum_norm.tests.test_inverse import _assert_free_ori_match from mne.simulation import simulate_evoked -from mne.utils import (object_diff, requires_version, catch_logging, - _record_warnings) +from mne.utils import object_diff, requires_version, catch_logging, _record_warnings data_path = testing.data_path(download=False) fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) -fname_fwd_vol = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" +fname_fwd_vol = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" fname_event = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw-eve.fif" fname_label = data_path / "MEG" / "sample" / "labels" / "Aud-lh.label" ctf_fname = data_path / "CTF" / "somMDYO-18av.ds" @@ -49,18 +62,25 @@ def _read_forward_solution_meg(*args, **kwargs): return mne.pick_types_forward(fwd, meg=True, eeg=False) -def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, - epochs_preload=True, data_cov=True, proj=True): +def _get_data( + tmin=-0.1, + tmax=0.15, + all_forward=True, + epochs=True, + epochs_preload=True, + data_cov=True, + proj=True, +): """Read in data used in tests.""" label = mne.read_label(fname_label) events = mne.read_events(fname_event) raw = 
mne.io.read_raw_fif(fname_raw, preload=True) forward = mne.read_forward_solution(fname_fwd) if all_forward: - forward_surf_ori = _read_forward_solution_meg( - fname_fwd, surf_ori=True) + forward_surf_ori = _read_forward_solution_meg(fname_fwd, surf_ori=True) forward_fixed = _read_forward_solution_meg( - fname_fwd, force_fixed=True, surf_ori=True, use_cps=False) + fname_fwd, force_fixed=True, surf_ori=True, use_cps=False + ) forward_vol = _read_forward_solution_meg(fname_fwd_vol) else: forward_surf_ori = None @@ -70,11 +90,10 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, event_id, tmin, tmax = 1, tmin, tmax # Setup for reading the raw data - raw.info['bads'] = ['MEG 2443', 'EEG 053'] # 2 bad channels + raw.info["bads"] = ["MEG 2443", "EEG 053"] # 2 bad channels # Set up pick list: MEG - bad channels - left_temporal_channels = read_vectorview_selection('Left-temporal') - picks = mne.pick_types(raw.info, meg=True, - selection=left_temporal_channels) + left_temporal_channels = read_vectorview_selection("Left-temporal") + picks = mne.pick_types(raw.info, meg=True, selection=left_temporal_channels) picks = picks[::2] # decimate for speed # add a couple channels we will consider bad bad_picks = [100, 101] @@ -84,7 +103,7 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, raw.pick_channels([raw.ch_names[ii] for ii in picks], ordered=True) del picks - raw.info['bads'] = bads # add more bads + raw.info["bads"] = bads # add more bads if proj: raw.info.normalize_proj() # avoid projection warnings else: @@ -93,8 +112,16 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, if epochs: # Read epochs epochs = mne.Epochs( - raw, events, event_id, tmin, tmax, proj=True, - baseline=(None, 0), preload=epochs_preload, reject=reject) + raw, + events, + event_id, + tmin, + tmax, + proj=True, + baseline=(None, 0), + preload=epochs_preload, + reject=reject, + ) if epochs_preload: epochs.resample(200, npad=0) epochs.crop(0, None) @@ -106,17 +133,29 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, info = raw.info noise_cov = mne.read_cov(fname_cov) - noise_cov['projs'] = [] # avoid warning - noise_cov = mne.cov.regularize(noise_cov, info, mag=0.05, grad=0.05, - eeg=0.1, proj=True, rank=None) + noise_cov["projs"] = [] # avoid warning + noise_cov = mne.cov.regularize( + noise_cov, info, mag=0.05, grad=0.05, eeg=0.1, proj=True, rank=None + ) if data_cov: data_cov = mne.compute_covariance( - epochs, tmin=0.04, tmax=0.145, verbose='error') # baseline warning + epochs, tmin=0.04, tmax=0.145, verbose="error" + ) # baseline warning else: data_cov = None - return raw, epochs, evoked, data_cov, noise_cov, label, forward,\ - forward_surf_ori, forward_fixed, forward_vol + return ( + raw, + epochs, + evoked, + data_cov, + noise_cov, + label, + forward, + forward_surf_ori, + forward_fixed, + forward_vol, + ) @pytest.mark.slowtest @@ -127,40 +166,43 @@ def test_lcmv_vector(): # For speed and for rank-deficiency calculation simplicity, # just use grads - info = mne.pick_info(info, mne.pick_types(info, meg='grad', exclude=())) + info = mne.pick_info(info, mne.pick_types(info, meg="grad", exclude=())) with info._unlock(): info.update(bads=[], projs=[]) forward = mne.read_forward_solution(fname_fwd) - forward = mne.pick_channels_forward(forward, info['ch_names']) - vertices = [s['vertno'][::200] for s in forward['src']] + forward = mne.pick_channels_forward(forward, info["ch_names"]) + vertices = [s["vertno"][::200] for s in forward["src"]] n_vertices = 
sum(len(v) for v in vertices) assert n_vertices == 4 amplitude = 100e-9 - stc = mne.SourceEstimate(amplitude * np.eye(n_vertices), vertices, - 0, 1. / info['sfreq']) - forward_sim = mne.convert_forward_solution(forward, force_fixed=True, - use_cps=True, copy=True) + stc = mne.SourceEstimate( + amplitude * np.eye(n_vertices), vertices, 0, 1.0 / info["sfreq"] + ) + forward_sim = mne.convert_forward_solution( + forward, force_fixed=True, use_cps=True, copy=True + ) forward_sim = mne.forward.restrict_forward_to_stc(forward_sim, stc) noise_cov = mne.make_ad_hoc_cov(info) - noise_cov.update(data=np.diag(noise_cov['data']), diag=False) + noise_cov.update(data=np.diag(noise_cov["data"]), diag=False) evoked = simulate_evoked(forward_sim, stc, info, noise_cov, nave=1) - source_nn = forward_sim['source_nn'] - source_rr = forward_sim['source_rr'] + source_nn = forward_sim["source_nn"] + source_rr = forward_sim["source_rr"] # Figure out our indices - mask = np.concatenate([np.in1d(s['vertno'], v) - for s, v in zip(forward['src'], vertices)]) + mask = np.concatenate( + [np.in1d(s["vertno"], v) for s, v in zip(forward["src"], vertices)] + ) mapping = np.where(mask)[0] - assert_array_equal(source_rr, forward['source_rr'][mapping]) + assert_array_equal(source_rr, forward["source_rr"][mapping]) # Don't check NN because we didn't rotate to surf ori del forward_sim # Let's do minimum norm as a sanity check (dipole_fit is slower) - inv = make_inverse_operator(info, forward, noise_cov, loose=1.) - stc_vector_mne = apply_inverse(evoked, inv, pick_ori='vector') + inv = make_inverse_operator(info, forward, noise_cov, loose=1.0) + stc_vector_mne = apply_inverse(evoked, inv, pick_ori="vector") mne_ori = stc_vector_mne.data[mapping, :, np.arange(n_vertices)] mne_ori /= np.linalg.norm(mne_ori, axis=-1)[:, np.newaxis] mne_angles = np.rad2deg(np.arccos(np.sum(mne_ori * source_nn, axis=-1))) @@ -169,28 +211,34 @@ def test_lcmv_vector(): # Now let's do LCMV data_cov = mne.make_ad_hoc_cov(info) # just a stub for later with pytest.raises(ValueError, match="pick_ori"): - make_lcmv(info, forward, data_cov, 0.05, noise_cov, pick_ori='bad') + make_lcmv(info, forward, data_cov, 0.05, noise_cov, pick_ori="bad") lcmv_ori = list() for ti in range(n_vertices): this_evoked = evoked.copy().crop(evoked.times[ti], evoked.times[ti]) - data_cov['diag'] = False - data_cov['data'] = (np.outer(this_evoked.data, this_evoked.data) + - noise_cov['data']) - vals = linalg.svdvals(data_cov['data']) + data_cov["diag"] = False + data_cov["data"] = ( + np.outer(this_evoked.data, this_evoked.data) + noise_cov["data"] + ) + vals = linalg.svdvals(data_cov["data"]) assert vals[0] / vals[-1] < 1e5 # not rank deficient with catch_logging() as log: - filters = make_lcmv(info, forward, data_cov, 0.05, noise_cov, - verbose=True) + filters = make_lcmv(info, forward, data_cov, 0.05, noise_cov, verbose=True) log = log.getvalue() - assert '498 sources' in log + assert "498 sources" in log with catch_logging() as log: - filters_vector = make_lcmv(info, forward, data_cov, 0.05, - noise_cov, pick_ori='vector', - verbose=True) + filters_vector = make_lcmv( + info, + forward, + data_cov, + 0.05, + noise_cov, + pick_ori="vector", + verbose=True, + ) log = log.getvalue() - assert '498 sources' in log + assert "498 sources" in log stc = apply_lcmv(this_evoked, filters) stc_vector = apply_lcmv(this_evoked, filters_vector) assert isinstance(stc, mne.SourceEstimate) @@ -199,7 +247,7 @@ def test_lcmv_vector(): # Check the orientation by pooling across some neighbors, as LCMV 
can # have some "holes" at the points of interest - idx = np.where(cdist(forward['source_rr'], source_rr[[ti]]) < 0.02)[0] + idx = np.where(cdist(forward["source_rr"], source_rr[[ti]]) < 0.02)[0] lcmv_ori.append(np.mean(stc_vector.data[idx, :, 0], axis=0)) lcmv_ori[-1] /= np.linalg.norm(lcmv_ori[-1]) @@ -208,27 +256,39 @@ def test_lcmv_vector(): @pytest.mark.slowtest -@requires_version('h5io') +@requires_version("h5io") @testing.requires_testing_data -@pytest.mark.parametrize('reg, proj, kind', [ - (0.01, True, 'volume'), - (0., False, 'volume'), - (0.01, False, 'surface'), - (0., True, 'surface'), -]) +@pytest.mark.parametrize( + "reg, proj, kind", + [ + (0.01, True, "volume"), + (0.0, False, "volume"), + (0.01, False, "surface"), + (0.0, True, "surface"), + ], +) def test_make_lcmv_bem(tmp_path, reg, proj, kind): """Test LCMV with evoked data and single trials.""" - raw, epochs, evoked, data_cov, noise_cov, label, forward,\ - forward_surf_ori, forward_fixed, forward_vol = _get_data(proj=proj) - - if kind == 'surface': + ( + raw, + epochs, + evoked, + data_cov, + noise_cov, + label, + forward, + forward_surf_ori, + forward_fixed, + forward_vol, + ) = _get_data(proj=proj) + + if kind == "surface": fwd = forward else: fwd = forward_vol - assert kind == 'volume' + assert kind == "volume" - filters = make_lcmv(evoked.info, fwd, data_cov, reg=reg, - noise_cov=noise_cov) + filters = make_lcmv(evoked.info, fwd, data_cov, reg=reg, noise_cov=noise_cov) stc = apply_lcmv(evoked, filters) stc.crop(0.02, None) @@ -240,11 +300,17 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): assert 0.08 < tmax < 0.15, tmax assert 0.9 < np.max(max_stc) < 3.5, np.max(max_stc) - if kind == 'surface': + if kind == "surface": # Test picking normal orientation (surface source space only). 
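    # pick_ori="normal" projects the beamformer output onto the cortical
    # normal, which is only defined for a surface source space with a
    # surface-oriented forward; fixed and volume forwards raise ValueError,
    # as checked further down.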
- filters = make_lcmv(evoked.info, forward_surf_ori, data_cov, - reg=reg, noise_cov=noise_cov, - pick_ori='normal', weight_norm=None) + filters = make_lcmv( + evoked.info, + forward_surf_ori, + data_cov, + reg=reg, + noise_cov=noise_cov, + pick_ori="normal", + weight_norm=None, + ) stc_normal = apply_lcmv(evoked, filters) stc_normal.crop(0.02, None) @@ -264,8 +330,9 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): assert (np.abs(stc_normal.data) <= stc.data).all() # Test picking source orientation maximizing output source power - filters = make_lcmv(evoked.info, fwd, data_cov, reg=reg, - noise_cov=noise_cov, pick_ori='max-power') + filters = make_lcmv( + evoked.info, fwd, data_cov, reg=reg, noise_cov=noise_cov, pick_ori="max-power" + ) stc_max_power = apply_lcmv(evoked, filters) stc_max_power.crop(0.02, None) stc_pow = np.sum(np.abs(stc_max_power.data), axis=1) @@ -275,85 +342,125 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): lower = 0.08 if proj else 0.04 assert lower < tmax < 0.15, tmax - assert 0.8 < np.max(max_stc) < 3., np.max(max_stc) + assert 0.8 < np.max(max_stc) < 3.0, np.max(max_stc) stc_max_power.data[:, :] = np.abs(stc_max_power.data) - if kind == 'surface': + if kind == "surface": # Maximum output source power orientation results should be # similar to free orientation results in areas with channel # coverage label = mne.read_label(fname_label) - mean_stc = stc.extract_label_time_course( - label, fwd['src'], mode='mean') - mean_stc_max_pow = \ - stc_max_power.extract_label_time_course( - label, fwd['src'], mode='mean') + mean_stc = stc.extract_label_time_course(label, fwd["src"], mode="mean") + mean_stc_max_pow = stc_max_power.extract_label_time_course( + label, fwd["src"], mode="mean" + ) assert_array_less(np.abs(mean_stc - mean_stc_max_pow), 1.0) # Test if spatial filter contains src_type - assert filters['src_type'] == kind + assert filters["src_type"] == kind # __repr__ assert len(evoked.ch_names) == 22 - assert len(evoked.info['projs']) == (3 if proj else 0) - assert len(evoked.info['bads']) == 2 + assert len(evoked.info["projs"]) == (3 if proj else 0) + assert len(evoked.info["bads"]) == 2 rank = 17 if proj else 20 - assert 'LCMV' in repr(filters) - assert 'unknown subject' not in repr(filters) + assert "LCMV" in repr(filters) + assert "unknown subject" not in repr(filters) assert f'{fwd["nsource"]} vert' in repr(filters) - assert '20 ch' in repr(filters) - assert 'rank %s' % rank in repr(filters) + assert "20 ch" in repr(filters) + assert "rank %s" % rank in repr(filters) # I/O fname = tmp_path / "filters.h5" - with pytest.warns(RuntimeWarning, match='-lcmv.h5'): + with pytest.warns(RuntimeWarning, match="-lcmv.h5"): filters.save(fname) filters_read = read_beamformer(fname) assert isinstance(filters, Beamformer) assert isinstance(filters_read, Beamformer) # deal with object_diff strictness - filters_read['rank'] = int(filters_read['rank']) - filters['rank'] = int(filters['rank']) - assert object_diff(filters, filters_read) == '' + filters_read["rank"] = int(filters_read["rank"]) + filters["rank"] = int(filters["rank"]) + assert object_diff(filters, filters_read) == "" - if kind != 'surface': + if kind != "surface": return # Test if fixed forward operator is detected when picking normal or # max-power orientation - pytest.raises(ValueError, make_lcmv, evoked.info, forward_fixed, data_cov, - reg=0.01, noise_cov=noise_cov, pick_ori='normal') - pytest.raises(ValueError, make_lcmv, evoked.info, forward_fixed, data_cov, - reg=0.01, noise_cov=noise_cov, 
pick_ori='max-power') + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward_fixed, + data_cov, + reg=0.01, + noise_cov=noise_cov, + pick_ori="normal", + ) + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward_fixed, + data_cov, + reg=0.01, + noise_cov=noise_cov, + pick_ori="max-power", + ) # Test if non-surface oriented forward operator is detected when picking # normal orientation - pytest.raises(ValueError, make_lcmv, evoked.info, forward, data_cov, - reg=0.01, noise_cov=noise_cov, pick_ori='normal') + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward, + data_cov, + reg=0.01, + noise_cov=noise_cov, + pick_ori="normal", + ) # Test if volume forward operator is detected when picking normal # orientation - pytest.raises(ValueError, make_lcmv, evoked.info, forward_vol, data_cov, - reg=0.01, noise_cov=noise_cov, pick_ori='normal') + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward_vol, + data_cov, + reg=0.01, + noise_cov=noise_cov, + pick_ori="normal", + ) # Test if missing of noise covariance matrix is detected when more than # one channel type is present in the data - pytest.raises(ValueError, make_lcmv, evoked.info, forward_vol, - data_cov=data_cov, reg=0.01, noise_cov=None, - pick_ori='max-power') + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward_vol, + data_cov=data_cov, + reg=0.01, + noise_cov=None, + pick_ori="max-power", + ) # Test if wrong channel selection is detected in application of filter evoked_ch = deepcopy(evoked) evoked_ch.pick_channels(evoked_ch.ch_names[1:]) - filters = make_lcmv(evoked.info, forward_vol, data_cov, reg=0.01, - noise_cov=noise_cov) + filters = make_lcmv( + evoked.info, forward_vol, data_cov, reg=0.01, noise_cov=noise_cov + ) # Test if discrepancies in channel selection of data and fwd model are # handled correctly in apply_lcmv # make filter with data where first channel was removed - filters = make_lcmv(evoked_ch.info, forward_vol, data_cov, reg=0.01, - noise_cov=noise_cov) + filters = make_lcmv( + evoked_ch.info, forward_vol, data_cov, reg=0.01, noise_cov=noise_cov + ) # applying that filter to the full data set should automatically exclude # this channel from the data # also test here that no warnings are thrown - implemented to check whether @@ -368,34 +475,36 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): # Test if non-matching SSP projection is detected in application of filter if proj: raw_proj = raw.copy().del_proj() - with pytest.raises(ValueError, match='do not match the projections'): + with pytest.raises(ValueError, match="do not match the projections"): apply_lcmv_raw(raw_proj, filters) # Test apply_lcmv_raw use_raw = raw.copy().crop(0, 1) stc = apply_lcmv_raw(use_raw, filters) assert_allclose(stc.times, use_raw.times) - assert_array_equal(stc.vertices[0], forward_vol['src'][0]['vertno']) + assert_array_equal(stc.vertices[0], forward_vol["src"][0]["vertno"]) # Test if spatial filter contains src_type - assert 'src_type' in filters + assert "src_type" in filters # check whether a filters object without src_type throws expected warning - del filters['src_type'] # emulate 0.16 behaviour to cause warning - with pytest.warns(RuntimeWarning, match='spatial filter does not contain ' - 'src_type'): + del filters["src_type"] # emulate 0.16 behaviour to cause warning + with pytest.warns( + RuntimeWarning, match="spatial filter does not contain " "src_type" + ): apply_lcmv(evoked, filters) # Now test single trial using fixed orientation forward solution # so we can 
compare it to the evoked solution - filters = make_lcmv(epochs.info, forward_fixed, data_cov, reg=0.01, - noise_cov=noise_cov) + filters = make_lcmv( + epochs.info, forward_fixed, data_cov, reg=0.01, noise_cov=noise_cov + ) stcs = apply_lcmv_epochs(epochs, filters) stcs_ = apply_lcmv_epochs(epochs, filters, return_generator=True) assert_array_equal(stcs[0].data, next(stcs_).data) epochs.drop_bad() - assert (len(epochs.events) == len(stcs)) + assert len(epochs.events) == len(stcs) # average the single trial estimates stc_avg = np.zeros_like(stcs[0].data) @@ -404,15 +513,17 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): stc_avg /= len(stcs) # compare it to the solution using evoked with fixed orientation - filters = make_lcmv(evoked.info, forward_fixed, data_cov, reg=0.01, - noise_cov=noise_cov) + filters = make_lcmv( + evoked.info, forward_fixed, data_cov, reg=0.01, noise_cov=noise_cov + ) stc_fixed = apply_lcmv(evoked, filters) assert_array_almost_equal(stc_avg, stc_fixed.data) # use a label so we have few source vertices and delayed computation is # not used - filters = make_lcmv(epochs.info, forward_fixed, data_cov, reg=0.01, - noise_cov=noise_cov, label=label) + filters = make_lcmv( + epochs.info, forward_fixed, data_cov, reg=0.01, noise_cov=noise_cov, label=label + ) stcs_label = apply_lcmv_epochs(epochs, filters) assert_array_almost_equal(stcs_label[0].data, stcs[0].in_label(label).data) @@ -420,54 +531,78 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): # Test condition where the filters weights are zero. There should not be # any divide-by-zero errors zero_cov = data_cov.copy() - zero_cov['data'][:] = 0 - filters = make_lcmv(epochs.info, forward_fixed, zero_cov, reg=0.01, - noise_cov=noise_cov) - assert_array_equal(filters['weights'], 0) + zero_cov["data"][:] = 0 + filters = make_lcmv( + epochs.info, forward_fixed, zero_cov, reg=0.01, noise_cov=noise_cov + ) + assert_array_equal(filters["weights"], 0) # Test condition where one channel type is picked # (avoid "grad data rank (13) did not match the noise rank (None)") data_cov_grad = pick_channels_cov( - data_cov, [ch_name for ch_name in epochs.info['ch_names'] - if ch_name.endswith(('2', '3'))], ordered=False) - assert len(data_cov_grad['names']) > 4 - make_lcmv(epochs.info, forward_fixed, data_cov_grad, reg=0.01, - noise_cov=noise_cov) + data_cov, + [ + ch_name + for ch_name in epochs.info["ch_names"] + if ch_name.endswith(("2", "3")) + ], + ordered=False, + ) + assert len(data_cov_grad["names"]) > 4 + make_lcmv(epochs.info, forward_fixed, data_cov_grad, reg=0.01, noise_cov=noise_cov) @testing.requires_testing_data @pytest.mark.slowtest -@pytest.mark.parametrize('weight_norm, pick_ori', [ - ('unit-noise-gain', 'max-power'), - ('unit-noise-gain', 'vector'), - ('unit-noise-gain', None), - ('nai', 'vector'), - (None, 'max-power'), -]) +@pytest.mark.parametrize( + "weight_norm, pick_ori", + [ + ("unit-noise-gain", "max-power"), + ("unit-noise-gain", "vector"), + ("unit-noise-gain", None), + ("nai", "vector"), + (None, "max-power"), + ], +) def test_make_lcmv_sphere(pick_ori, weight_norm): """Test LCMV with sphere head model.""" # unit-noise gain beamformer and orientation # selection and rank reduction of the leadfield _, _, evoked, data_cov, noise_cov, _, _, _, _, _ = _get_data(proj=True) - assert 'eeg' not in evoked - assert 'meg' in evoked - sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=0.080) + assert "eeg" not in evoked + assert "meg" in evoked + sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), 
head_radius=0.080) src = mne.setup_volume_source_space( - pos=25., sphere=sphere, mindist=5.0, exclude=2.0) + pos=25.0, sphere=sphere, mindist=5.0, exclude=2.0 + ) fwd_sphere = mne.make_forward_solution(evoked.info, None, src, sphere) # Test that we get an error if not reducing rank - with pytest.raises(ValueError, match='Singular matrix detected'): - with pytest.warns(RuntimeWarning, match='positive semidefinite'): + with pytest.raises(ValueError, match="Singular matrix detected"): + with pytest.warns(RuntimeWarning, match="positive semidefinite"): make_lcmv( - evoked.info, fwd_sphere, data_cov, reg=0.1, - noise_cov=noise_cov, weight_norm=weight_norm, - pick_ori=pick_ori, reduce_rank=False, rank='full') + evoked.info, + fwd_sphere, + data_cov, + reg=0.1, + noise_cov=noise_cov, + weight_norm=weight_norm, + pick_ori=pick_ori, + reduce_rank=False, + rank="full", + ) # Now let's reduce it - filters = make_lcmv(evoked.info, fwd_sphere, data_cov, reg=0.1, - noise_cov=noise_cov, weight_norm=weight_norm, - pick_ori=pick_ori, reduce_rank=True) + filters = make_lcmv( + evoked.info, + fwd_sphere, + data_cov, + reg=0.1, + noise_cov=noise_cov, + weight_norm=weight_norm, + pick_ori=pick_ori, + reduce_rank=True, + ) stc_sphere = apply_lcmv(evoked, filters) if isinstance(stc_sphere, VolVectorSourceEstimate): stc_sphere = stc_sphere.magnitude() @@ -489,21 +624,36 @@ def test_make_lcmv_sphere(pick_ori, weight_norm): @testing.requires_testing_data -@pytest.mark.parametrize('weight_norm', (None, 'unit-noise-gain')) -@pytest.mark.parametrize('pick_ori', ('max-power', 'normal')) +@pytest.mark.parametrize("weight_norm", (None, "unit-noise-gain")) +@pytest.mark.parametrize("pick_ori", ("max-power", "normal")) def test_lcmv_cov(weight_norm, pick_ori): """Test LCMV source power computation.""" - raw, epochs, evoked, data_cov, noise_cov, label, forward,\ - forward_surf_ori, forward_fixed, forward_vol = _get_data() + ( + raw, + epochs, + evoked, + data_cov, + noise_cov, + label, + forward, + forward_surf_ori, + forward_fixed, + forward_vol, + ) = _get_data() convert_forward_solution(forward, surf_ori=True, copy=False) - filters = make_lcmv(evoked.info, forward, data_cov, noise_cov=noise_cov, - weight_norm=weight_norm, pick_ori=pick_ori) + filters = make_lcmv( + evoked.info, + forward, + data_cov, + noise_cov=noise_cov, + weight_norm=weight_norm, + pick_ori=pick_ori, + ) for cov in (data_cov, noise_cov): this_cov = pick_channels_cov(cov, evoked.ch_names, ordered=False) - this_evoked = evoked.copy().pick_channels( - this_cov['names'], ordered=True) - this_cov['projs'] = this_evoked.info['projs'] - assert this_evoked.ch_names == this_cov['names'] + this_evoked = evoked.copy().pick_channels(this_cov["names"], ordered=True) + this_cov["projs"] = this_evoked.info["projs"] + assert this_evoked.ch_names == this_cov["names"] stc = apply_lcmv_cov(this_cov, filters) assert stc.data.min() > 0 assert stc.shape == (498, 1) @@ -530,29 +680,35 @@ def test_lcmv_ctf_comp(): evoked = epochs.average() data_cov = mne.compute_covariance(epochs) - fwd = mne.make_forward_solution(evoked.info, None, - mne.setup_volume_source_space(pos=30.0), - mne.make_sphere_model()) - with pytest.raises(ValueError, match='reduce_rank'): + fwd = mne.make_forward_solution( + evoked.info, + None, + mne.setup_volume_source_space(pos=30.0), + mne.make_sphere_model(), + ) + with pytest.raises(ValueError, match="reduce_rank"): make_lcmv(evoked.info, fwd, data_cov) filters = make_lcmv(evoked.info, fwd, data_cov, reduce_rank=True) - assert 'weights' in filters + 
assert "weights" in filters # test whether different compensations throw error info_comp = evoked.info.copy() set_current_comp(info_comp, 1) - with pytest.raises(RuntimeError, match='Compensation grade .* not match'): + with pytest.raises(RuntimeError, match="Compensation grade .* not match"): make_lcmv(info_comp, fwd, data_cov) @pytest.mark.slowtest @testing.requires_testing_data -@pytest.mark.parametrize('proj, weight_norm', [ - (True, 'unit-noise-gain'), - (False, 'unit-noise-gain'), - (True, None), - (True, 'nai'), -]) +@pytest.mark.parametrize( + "proj, weight_norm", + [ + (True, "unit-noise-gain"), + (False, "unit-noise-gain"), + (True, None), + (True, "nai"), + ], +) def test_lcmv_reg_proj(proj, weight_norm): """Test LCMV with and without proj.""" raw = mne.io.read_raw_fif(fname_raw, preload=True) @@ -560,58 +716,70 @@ def test_lcmv_reg_proj(proj, weight_norm): raw.pick_types(meg=True) assert len(raw.ch_names) == 305 epochs = mne.Epochs(raw, events, None, preload=True, proj=proj) - with pytest.warns(RuntimeWarning, match='Too few samples'): + with pytest.warns(RuntimeWarning, match="Too few samples"): noise_cov = mne.compute_covariance(epochs, tmax=0) data_cov = mne.compute_covariance(epochs, tmin=0.04, tmax=0.15) forward = mne.read_forward_solution(fname_fwd) - filters = make_lcmv(epochs.info, forward, data_cov, reg=0.05, - noise_cov=noise_cov, pick_ori='max-power', - weight_norm='nai', rank=None, verbose=True) + filters = make_lcmv( + epochs.info, + forward, + data_cov, + reg=0.05, + noise_cov=noise_cov, + pick_ori="max-power", + weight_norm="nai", + rank=None, + verbose=True, + ) want_rank = 302 # 305 good channels - 3 MEG projs - assert filters['rank'] == want_rank + assert filters["rank"] == want_rank # And also with and without noise_cov - with pytest.raises(ValueError, match='several sensor types'): - make_lcmv(epochs.info, forward, data_cov, reg=0.05, - noise_cov=None) - epochs.pick_types(meg='grad') + with pytest.raises(ValueError, match="several sensor types"): + make_lcmv(epochs.info, forward, data_cov, reg=0.05, noise_cov=None) + epochs.pick_types(meg="grad") kwargs = dict(reg=0.05, pick_ori=None, weight_norm=weight_norm) - filters_cov = make_lcmv(epochs.info, forward, data_cov, - noise_cov=noise_cov, **kwargs) - filters_nocov = make_lcmv(epochs.info, forward, data_cov, - noise_cov=None, **kwargs) + filters_cov = make_lcmv( + epochs.info, forward, data_cov, noise_cov=noise_cov, **kwargs + ) + filters_nocov = make_lcmv(epochs.info, forward, data_cov, noise_cov=None, **kwargs) ad_hoc = mne.make_ad_hoc_cov(epochs.info) - filters_adhoc = make_lcmv(epochs.info, forward, data_cov, - noise_cov=ad_hoc, **kwargs) + filters_adhoc = make_lcmv( + epochs.info, forward, data_cov, noise_cov=ad_hoc, **kwargs + ) evoked = epochs.average() stc_cov = apply_lcmv(evoked, filters_cov) stc_nocov = apply_lcmv(evoked, filters_nocov) stc_adhoc = apply_lcmv(evoked, filters_adhoc) # Compare adhoc and nocov: scale difference is necessitated by using std=1. - if weight_norm == 'unit-noise-gain': - scale = np.sqrt(ad_hoc['data'][0]) + if weight_norm == "unit-noise-gain": + scale = np.sqrt(ad_hoc["data"][0]) else: - scale = 1. 
+ scale = 1.0 assert_allclose(stc_nocov.data, stc_adhoc.data * scale) - a = np.dot(filters_nocov['weights'], filters_nocov['whitener']) - b = np.dot(filters_adhoc['weights'], filters_adhoc['whitener']) * scale + a = np.dot(filters_nocov["weights"], filters_nocov["whitener"]) + b = np.dot(filters_adhoc["weights"], filters_adhoc["whitener"]) * scale atol = np.mean(np.sqrt(a * a)) * 1e-7 assert_allclose(a, b, atol=atol, rtol=1e-7) # Compare adhoc and cov: locs might not be equivalent, but the same # general profile should persist, so look at the std and be lenient: - if weight_norm == 'unit-noise-gain': + if weight_norm == "unit-noise-gain": adhoc_scale = 0.12 else: - adhoc_scale = 1. + adhoc_scale = 1.0 assert_allclose( np.linalg.norm(stc_adhoc.data, axis=0) * adhoc_scale, - np.linalg.norm(stc_cov.data, axis=0), rtol=0.3) + np.linalg.norm(stc_cov.data, axis=0), + rtol=0.3, + ) assert_allclose( np.linalg.norm(stc_nocov.data, axis=0) / scale * adhoc_scale, - np.linalg.norm(stc_cov.data, axis=0), rtol=0.3) + np.linalg.norm(stc_cov.data, axis=0), + rtol=0.3, + ) - if weight_norm == 'nai': + if weight_norm == "nai": # NAI is always normalized by noise-level (based on eigenvalues) for stc in (stc_nocov, stc_cov): assert_allclose(stc.data.std(), 0.584, rtol=0.2) @@ -621,34 +789,47 @@ def test_lcmv_reg_proj(proj, weight_norm): for stc in (stc_nocov, stc_cov): assert_allclose(stc.data.std(), 2.8e-8, rtol=0.1) else: - assert weight_norm == 'unit-noise-gain' + assert weight_norm == "unit-noise-gain" # Channel scalings depend on presence of noise_cov assert_allclose(stc_nocov.data.std(), 7.8e-13, rtol=0.1) assert_allclose(stc_cov.data.std(), 0.187, rtol=0.2) -@pytest.mark.parametrize('reg, weight_norm, use_cov, depth, lower, upper', [ - (0.05, 'unit-noise-gain', True, None, 97, 98), - (0.05, 'nai', True, None, 96, 98), - (0.05, 'nai', True, 0.8, 96, 98), - (0.05, None, True, None, 74, 76), - (0.05, None, True, 0.8, 90, 93), # depth improves weight_norm=None - (0.05, 'unit-noise-gain', False, None, 83, 86), - (0.05, 'unit-noise-gain', False, 0.8, 83, 86), # depth same for wn != None - # no reg - (0.00, 'unit-noise-gain', True, None, 35, 99), # TODO: Still not stable -]) -def test_localization_bias_fixed(bias_params_fixed, reg, weight_norm, use_cov, - depth, lower, upper): +@pytest.mark.parametrize( + "reg, weight_norm, use_cov, depth, lower, upper", + [ + (0.05, "unit-noise-gain", True, None, 97, 98), + (0.05, "nai", True, None, 96, 98), + (0.05, "nai", True, 0.8, 96, 98), + (0.05, None, True, None, 74, 76), + (0.05, None, True, 0.8, 90, 93), # depth improves weight_norm=None + (0.05, "unit-noise-gain", False, None, 83, 86), + (0.05, "unit-noise-gain", False, 0.8, 83, 86), # depth same for wn != None + # no reg + (0.00, "unit-noise-gain", True, None, 35, 99), # TODO: Still not stable + ], +) +def test_localization_bias_fixed( + bias_params_fixed, reg, weight_norm, use_cov, depth, lower, upper +): """Test localization bias for fixed-orientation LCMV.""" evoked, fwd, noise_cov, data_cov, want = bias_params_fixed if not use_cov: - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") noise_cov = None - assert data_cov['data'].shape[0] == len(data_cov['names']) - loc = apply_lcmv(evoked, make_lcmv(evoked.info, fwd, data_cov, reg, - noise_cov, depth=depth, - weight_norm=weight_norm)).data + assert data_cov["data"].shape[0] == len(data_cov["names"]) + loc = apply_lcmv( + evoked, + make_lcmv( + evoked.info, + fwd, + data_cov, + reg, + noise_cov, + depth=depth, + weight_norm=weight_norm, + ), + ).data 
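    # Fixed-orientation filter output is signed, so take magnitudes before
    # locating each simulated source's peak below.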
loc = np.abs(loc) # Compute the percentage of sources for which there is no loc bias: perc = (want == np.argmax(loc, axis=0)).mean() * 100 @@ -657,51 +838,117 @@ def test_localization_bias_fixed(bias_params_fixed, reg, weight_norm, use_cov, # Changes here should be synced with test_dics.py @pytest.mark.parametrize( - 'reg, pick_ori, weight_norm, use_cov, depth, lower, upper, ' - 'lower_ori, upper_ori', [ - (0.05, 'vector', 'unit-noise-gain-invariant', False, None, 26, 28, 0.82, 0.84), # noqa: E501 - (0.05, 'vector', 'unit-noise-gain-invariant', True, None, 40, 42, 0.96, 0.98), # noqa: E501 - (0.05, 'vector', 'unit-noise-gain', False, None, 13, 14, 0.79, 0.81), - (0.05, 'vector', 'unit-noise-gain', True, None, 35, 37, 0.98, 0.99), - (0.05, 'vector', 'nai', True, None, 35, 37, 0.98, 0.99), - (0.05, 'vector', None, True, None, 12, 14, 0.97, 0.98), - (0.05, 'vector', None, True, 0.8, 39, 43, 0.97, 0.98), - (0.05, 'max-power', 'unit-noise-gain-invariant', False, None, 17, 20, 0, 0), # noqa: E501 - (0.05, 'max-power', 'unit-noise-gain', False, None, 17, 20, 0, 0), - (0.05, 'max-power', 'nai', True, None, 21, 24, 0, 0), - (0.05, 'max-power', None, True, None, 7, 10, 0, 0), - (0.05, 'max-power', None, True, 0.8, 15, 18, 0, 0), + "reg, pick_ori, weight_norm, use_cov, depth, lower, upper, " "lower_ori, upper_ori", + [ + ( + 0.05, + "vector", + "unit-noise-gain-invariant", + False, + None, + 26, + 28, + 0.82, + 0.84, + ), # noqa: E501 + ( + 0.05, + "vector", + "unit-noise-gain-invariant", + True, + None, + 40, + 42, + 0.96, + 0.98, + ), # noqa: E501 + (0.05, "vector", "unit-noise-gain", False, None, 13, 14, 0.79, 0.81), + (0.05, "vector", "unit-noise-gain", True, None, 35, 37, 0.98, 0.99), + (0.05, "vector", "nai", True, None, 35, 37, 0.98, 0.99), + (0.05, "vector", None, True, None, 12, 14, 0.97, 0.98), + (0.05, "vector", None, True, 0.8, 39, 43, 0.97, 0.98), + ( + 0.05, + "max-power", + "unit-noise-gain-invariant", + False, + None, + 17, + 20, + 0, + 0, + ), # noqa: E501 + (0.05, "max-power", "unit-noise-gain", False, None, 17, 20, 0, 0), + (0.05, "max-power", "nai", True, None, 21, 24, 0, 0), + (0.05, "max-power", None, True, None, 7, 10, 0, 0), + (0.05, "max-power", None, True, 0.8, 15, 18, 0, 0), (0.05, None, None, True, 0.8, 40, 42, 0, 0), # no reg - (0.00, 'vector', None, True, None, 23, 24, 0.96, 0.97), - (0.00, 'vector', 'unit-noise-gain-invariant', True, None, 52, 54, 0.95, 0.96), # noqa: E501 - (0.00, 'vector', 'unit-noise-gain', True, None, 44, 48, 0.97, 0.99), - (0.00, 'vector', 'nai', True, None, 44, 48, 0.97, 0.99), - (0.00, 'max-power', None, True, None, 14, 15, 0, 0), - (0.00, 'max-power', 'unit-noise-gain-invariant', True, None, 35, 37, 0, 0), # noqa: E501 - (0.00, 'max-power', 'unit-noise-gain', True, None, 35, 37, 0, 0), - (0.00, 'max-power', 'nai', True, None, 35, 37, 0, 0), - ]) -def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm, - use_cov, depth, lower, upper, - lower_ori, upper_ori): + (0.00, "vector", None, True, None, 23, 24, 0.96, 0.97), + ( + 0.00, + "vector", + "unit-noise-gain-invariant", + True, + None, + 52, + 54, + 0.95, + 0.96, + ), # noqa: E501 + (0.00, "vector", "unit-noise-gain", True, None, 44, 48, 0.97, 0.99), + (0.00, "vector", "nai", True, None, 44, 48, 0.97, 0.99), + (0.00, "max-power", None, True, None, 14, 15, 0, 0), + ( + 0.00, + "max-power", + "unit-noise-gain-invariant", + True, + None, + 35, + 37, + 0, + 0, + ), # noqa: E501 + (0.00, "max-power", "unit-noise-gain", True, None, 35, 37, 0, 0), + (0.00, "max-power", "nai", True, 
None, 35, 37, 0, 0), + ], +) +def test_localization_bias_free( + bias_params_free, + reg, + pick_ori, + weight_norm, + use_cov, + depth, + lower, + upper, + lower_ori, + upper_ori, +): """Test localization bias for free-orientation LCMV.""" evoked, fwd, noise_cov, data_cov, want = bias_params_free if not use_cov: - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") noise_cov = None with _record_warnings(): # rank deficiency of data_cov - filters = make_lcmv(evoked.info, fwd, data_cov, reg, - noise_cov, pick_ori=pick_ori, - weight_norm=weight_norm, - depth=depth) + filters = make_lcmv( + evoked.info, + fwd, + data_cov, + reg, + noise_cov, + pick_ori=pick_ori, + weight_norm=weight_norm, + depth=depth, + ) loc = apply_lcmv(evoked, filters).data - if pick_ori == 'vector': + if pick_ori == "vector": ori = loc.copy() / np.linalg.norm(loc, axis=1, keepdims=True) else: # doesn't make sense for pooled (None) or max-power (can't be all 3) ori = None - loc = np.linalg.norm(loc, axis=1) if pick_ori == 'vector' else np.abs(loc) + loc = np.linalg.norm(loc, axis=1) if pick_ori == "vector" else np.abs(loc) # Compute the percentage of sources for which there is no loc bias: max_idx = np.argmax(loc, axis=0) perc = (want == max_idx).mean() * 100 @@ -712,35 +959,52 @@ def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm, # Changes here should be synced with the ones above, but these have meaningful # orientation values @pytest.mark.parametrize( - 'reg, weight_norm, use_cov, depth, lower, upper, lower_ori, upper_ori', [ - (0.05, 'unit-noise-gain-invariant', False, None, 38, 40, 0.54, 0.55), - (0.05, 'unit-noise-gain', False, None, 38, 40, 0.54, 0.55), - (0.05, 'nai', True, None, 56, 57, 0.59, 0.61), + "reg, weight_norm, use_cov, depth, lower, upper, lower_ori, upper_ori", + [ + (0.05, "unit-noise-gain-invariant", False, None, 38, 40, 0.54, 0.55), + (0.05, "unit-noise-gain", False, None, 38, 40, 0.54, 0.55), + (0.05, "nai", True, None, 56, 57, 0.59, 0.61), (0.05, None, True, None, 27, 28, 0.56, 0.57), (0.05, None, True, 0.8, 42, 43, 0.56, 0.57), # no reg (0.00, None, True, None, 50, 51, 0.58, 0.59), - (0.00, 'unit-noise-gain-invariant', True, None, 73, 75, 0.59, 0.61), - (0.00, 'unit-noise-gain', True, None, 73, 75, 0.59, 0.61), - (0.00, 'nai', True, None, 73, 75, 0.59, 0.61), - ]) -def test_orientation_max_power(bias_params_fixed, bias_params_free, - reg, weight_norm, use_cov, depth, lower, upper, - lower_ori, upper_ori): + (0.00, "unit-noise-gain-invariant", True, None, 73, 75, 0.59, 0.61), + (0.00, "unit-noise-gain", True, None, 73, 75, 0.59, 0.61), + (0.00, "nai", True, None, 73, 75, 0.59, 0.61), + ], +) +def test_orientation_max_power( + bias_params_fixed, + bias_params_free, + reg, + weight_norm, + use_cov, + depth, + lower, + upper, + lower_ori, + upper_ori, +): """Test orientation selection for bias for max-power LCMV.""" # we simulate data for the fixed orientation forward and beamform using # the free orientation forward, and check the orientation match at the end evoked, _, noise_cov, data_cov, want = bias_params_fixed fwd = bias_params_free[1] if not use_cov: - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") noise_cov = None - filters = make_lcmv(evoked.info, fwd, data_cov, reg, - noise_cov, pick_ori='max-power', - weight_norm=weight_norm, - depth=depth) + filters = make_lcmv( + evoked.info, + fwd, + data_cov, + reg, + noise_cov, + pick_ori="max-power", + weight_norm=weight_norm, + depth=depth, + ) loc = apply_lcmv(evoked, filters).data - ori = 
filters['max_power_ori'] + ori = filters["max_power_ori"] assert ori.shape == (246, 3) loc = np.abs(loc) # Compute the percentage of sources for which there is no loc bias: @@ -749,11 +1013,10 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, perc = mask.mean() * 100 assert lower <= perc <= upper # Compute the dot products of our forward normals and - assert fwd['coord_frame'] == FIFF.FIFFV_COORD_HEAD - nn = np.concatenate( - [s['nn'][v] for s, v in zip(fwd['src'], filters['vertices'])]) + assert fwd["coord_frame"] == FIFF.FIFFV_COORD_HEAD + nn = np.concatenate([s["nn"][v] for s, v in zip(fwd["src"], filters["vertices"])]) nn = nn[want] - nn = apply_trans(invert_transform(fwd['mri_head_t']), nn, move=False) + nn = apply_trans(invert_transform(fwd["mri_head_t"]), nn, move=False) assert_allclose(np.linalg.norm(nn, axis=1), 1, atol=1e-6) assert_allclose(np.linalg.norm(ori, axis=1), 1, atol=1e-12) dots = np.abs((nn[mask] * ori[mask]).sum(-1)) @@ -763,21 +1026,44 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, assert lower_ori < got < upper_ori -@pytest.mark.parametrize('weight_norm, pick_ori', [ - pytest.param('nai', 'max-power', marks=pytest.mark.slowtest), - ('unit-noise-gain', 'vector'), - ('unit-noise-gain', 'max-power'), - pytest.param('unit-noise-gain', None, marks=pytest.mark.slowtest), -]) +@pytest.mark.parametrize( + "weight_norm, pick_ori", + [ + pytest.param("nai", "max-power", marks=pytest.mark.slowtest), + ("unit-noise-gain", "vector"), + ("unit-noise-gain", "max-power"), + pytest.param("unit-noise-gain", None, marks=pytest.mark.slowtest), + ], +) def test_depth_does_not_matter(bias_params_free, weight_norm, pick_ori): """Test that depth weighting does not matter for normalized filters.""" evoked, fwd, noise_cov, data_cov, _ = bias_params_free - data = apply_lcmv(evoked, make_lcmv( - evoked.info, fwd, data_cov, 0.05, noise_cov, pick_ori=pick_ori, - weight_norm=weight_norm, depth=0.)).data - data_depth = apply_lcmv(evoked, make_lcmv( - evoked.info, fwd, data_cov, 0.05, noise_cov, pick_ori=pick_ori, - weight_norm=weight_norm, depth=1.)).data + data = apply_lcmv( + evoked, + make_lcmv( + evoked.info, + fwd, + data_cov, + 0.05, + noise_cov, + pick_ori=pick_ori, + weight_norm=weight_norm, + depth=0.0, + ), + ).data + data_depth = apply_lcmv( + evoked, + make_lcmv( + evoked.info, + fwd, + data_cov, + 0.05, + noise_cov, + pick_ori=pick_ori, + weight_norm=weight_norm, + depth=1.0, + ), + ).data assert data.shape == data_depth.shape for d1, d2 in zip(data, data_depth): # Sign flips can change when nearly orthogonal to the normal direction @@ -793,59 +1079,78 @@ def test_lcmv_maxfiltered(): raw_sss = mne.preprocessing.maxwell_filter(raw) events = mne.find_events(raw_sss) del raw - raw_sss.pick_types(meg='mag') + raw_sss.pick_types(meg="mag") assert len(raw_sss.ch_names) == 102 epochs = mne.Epochs(raw_sss, events) data_cov = mne.compute_covariance(epochs, tmin=0) fwd = mne.read_forward_solution(fname_fwd) rank = compute_rank(data_cov, info=epochs.info) - assert rank == {'mag': 71} - for use_rank in ('info', rank, 'full', None): + assert rank == {"mag": 71} + for use_rank in ("info", rank, "full", None): make_lcmv(epochs.info, fwd, data_cov, rank=use_rank) # To reduce test time, only test combinations that should matter rather than # all of them @testing.requires_testing_data -@pytest.mark.parametrize('pick_ori, weight_norm, reg, inversion', [ - ('vector', 'unit-noise-gain-invariant', 0.05, 'matrix'), - ('vector', 'unit-noise-gain-invariant', 0.05, 
'single'), - ('vector', 'unit-noise-gain', 0.05, 'matrix'), - ('vector', 'unit-noise-gain', 0.05, 'single'), - ('vector', 'unit-noise-gain', 0.0, 'matrix'), - ('vector', 'unit-noise-gain', 0.0, 'single'), - ('vector', 'nai', 0.05, 'matrix'), - ('max-power', 'unit-noise-gain', 0.05, 'matrix'), - ('max-power', 'unit-noise-gain', 0.0, 'single'), - ('max-power', 'unit-noise-gain', 0.05, 'single'), - ('max-power', 'unit-noise-gain-invariant', 0.05, 'matrix'), - ('normal', 'unit-noise-gain', 0.05, 'matrix'), - ('normal', 'nai', 0.0, 'matrix'), -]) +@pytest.mark.parametrize( + "pick_ori, weight_norm, reg, inversion", + [ + ("vector", "unit-noise-gain-invariant", 0.05, "matrix"), + ("vector", "unit-noise-gain-invariant", 0.05, "single"), + ("vector", "unit-noise-gain", 0.05, "matrix"), + ("vector", "unit-noise-gain", 0.05, "single"), + ("vector", "unit-noise-gain", 0.0, "matrix"), + ("vector", "unit-noise-gain", 0.0, "single"), + ("vector", "nai", 0.05, "matrix"), + ("max-power", "unit-noise-gain", 0.05, "matrix"), + ("max-power", "unit-noise-gain", 0.0, "single"), + ("max-power", "unit-noise-gain", 0.05, "single"), + ("max-power", "unit-noise-gain-invariant", 0.05, "matrix"), + ("normal", "unit-noise-gain", 0.05, "matrix"), + ("normal", "nai", 0.0, "matrix"), + ], +) def test_unit_noise_gain_formula(pick_ori, weight_norm, reg, inversion): """Test unit-noise-gain filter against formula.""" raw = mne.io.read_raw_fif(fname_raw, preload=True) events = mne.find_events(raw) - raw.pick_types(meg='mag') + raw.pick_types(meg="mag") assert len(raw.ch_names) == 102 epochs = mne.Epochs(raw, events, None, preload=True) data_cov = mne.compute_covariance(epochs, tmin=0.04, tmax=0.15) # for now, avoid whitening to make life easier - noise_cov = mne.make_ad_hoc_cov(epochs.info, std=dict(grad=1., mag=1.)) + noise_cov = mne.make_ad_hoc_cov(epochs.info, std=dict(grad=1.0, mag=1.0)) forward = mne.read_forward_solution(fname_fwd) convert_forward_solution(forward, surf_ori=True, copy=False) rank = None - kwargs = dict(reg=reg, noise_cov=noise_cov, pick_ori=pick_ori, - weight_norm=weight_norm, rank=rank, inversion=inversion) - if inversion == 'single' and pick_ori == 'vector' and \ - weight_norm == 'unit-noise-gain-invariant': - with pytest.raises(ValueError, match='Cannot use'): + kwargs = dict( + reg=reg, + noise_cov=noise_cov, + pick_ori=pick_ori, + weight_norm=weight_norm, + rank=rank, + inversion=inversion, + ) + if ( + inversion == "single" + and pick_ori == "vector" + and weight_norm == "unit-noise-gain-invariant" + ): + with pytest.raises(ValueError, match="Cannot use"): make_lcmv(epochs.info, forward, data_cov, **kwargs) return filters = make_lcmv(epochs.info, forward, data_cov, **kwargs) _, _, _, _, G, _, _, _ = _prepare_beamformer_input( - epochs.info, forward, None, 'vector', noise_cov=noise_cov, rank=rank, - pca=False, exp=None) + epochs.info, + forward, + None, + "vector", + noise_cov=noise_cov, + rank=rank, + pca=False, + exp=None, + ) n_channels, n_sources = G.shape n_sources //= 3 G.shape = (n_channels, n_sources, 3) @@ -855,26 +1160,26 @@ def test_unit_noise_gain_formula(pick_ori, weight_norm, reg, inversion): def _assert_weight_norm(filters, G): """Check the result of the chosen weight normalization strategy.""" - weights, max_power_ori = filters['weights'], filters['max_power_ori'] + weights, max_power_ori = filters["weights"], filters["max_power_ori"] # Make the dimensions of the weight matrix equal for both DICS (which # defines weights for multiple frequencies) and LCMV (which does not). 
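    # A leading frequency axis is added for LCMV, and the trailing singleton
    # axis on max_power_ori lets it broadcast against the per-source
    # (n_orient, n_channels) leadfield blocks below.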
- if filters['kind'] == 'LCMV': + if filters["kind"] == "LCMV": weights = weights[np.newaxis] if max_power_ori is not None: max_power_ori = max_power_ori[np.newaxis] if max_power_ori is not None: max_power_ori = max_power_ori[..., np.newaxis] - weight_norm = filters['weight_norm'] - inversion = filters['inversion'] + weight_norm = filters["weight_norm"] + inversion = filters["inversion"] n_channels = weights.shape[2] - if inversion == 'matrix': + if inversion == "matrix": # Dipoles are grouped in groups with size n_orient - n_sources = filters['n_sources'] - n_orient = 3 if filters['is_free_ori'] else 1 - elif inversion == 'single': + n_sources = filters["n_sources"] + n_orient = 3 if filters["is_free_ori"] else 1 + elif inversion == "single": # Every dipole is treated as a unique source n_sources = weights.shape[1] n_orient = 1 @@ -884,13 +1189,13 @@ def _assert_weight_norm(filters, G): # Compute leadfield in the direction chosen during the computation of # the beamformer. - if filters['pick_ori'] == 'max-power': + if filters["pick_ori"] == "max-power": use_G = np.sum(G * max_power_ori[wi], axis=1, keepdims=True) - elif filters['pick_ori'] == 'normal': + elif filters["pick_ori"] == "normal": use_G = G[:, -1:] else: use_G = G - if inversion == 'single': + if inversion == "single": # Every dipole is treated as a unique source use_G = use_G.reshape(n_sources, 1, n_channels) assert w.shape == use_G.shape == (n_sources, n_orient, n_channels) @@ -898,32 +1203,32 @@ def _assert_weight_norm(filters, G): # Test weight normalization scheme got = np.matmul(w, w.conj().swapaxes(-2, -1)) desired = np.repeat(np.eye(n_orient)[np.newaxis], w.shape[0], axis=0) - if n_orient == 3 and weight_norm in ('unit-noise-gain', 'nai'): + if n_orient == 3 and weight_norm in ("unit-noise-gain", "nai"): # only the diagonal is correct! 
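            # unit-noise-gain and NAI only normalize each orientation's weight
            # vector, leaving the cross-orientation terms of w @ w.conj().T
            # nonzero, so the full identity comparison must fail while the
            # diagonal check below passes.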
assert not np.allclose(got, desired, atol=1e-7) - got = got.reshape(n_sources, -1)[:, ::n_orient + 1] + got = got.reshape(n_sources, -1)[:, :: n_orient + 1] desired = np.ones_like(got) - if weight_norm == 'nai': # additional scale factor, should be fixed + if weight_norm == "nai": # additional scale factor, should be fixed atol = 1e-7 * got.flat[0] desired *= got.flat[0] else: atol = 1e-7 - assert_allclose(got, desired, atol=atol, err_msg='w @ w.conj().T = I') + assert_allclose(got, desired, atol=atol, err_msg="w @ w.conj().T = I") # Check that the result here is a diagonal matrix for Sekihara - if n_orient > 1 and weight_norm != 'unit-noise-gain-invariant': + if n_orient > 1 and weight_norm != "unit-noise-gain-invariant": got = w @ use_G.swapaxes(-2, -1) diags = np.diagonal(got, 0, -2, -1) want = np.apply_along_axis(np.diagflat, 1, diags) atol = np.mean(diags).real * 1e-12 - assert_allclose(got, want, atol=atol, err_msg='G.T @ w = θI') + assert_allclose(got, want, atol=atol, err_msg="G.T @ w = θI") def test_api(): """Test LCMV/DICS API equivalence.""" lcmv_names = list(signature(make_lcmv).parameters) dics_names = list(signature(make_dics).parameters) - dics_names[dics_names.index('csd')] = 'data_cov' - dics_names[dics_names.index('noise_csd')] = 'noise_cov' - dics_names.pop(dics_names.index('real_filter')) # not a thing for LCMV + dics_names[dics_names.index("csd")] = "data_cov" + dics_names[dics_names.index("noise_csd")] = "noise_cov" + dics_names.pop(dics_names.index("real_filter")) # not a thing for LCMV assert lcmv_names == dics_names diff --git a/mne/beamformer/tests/test_rap_music.py b/mne/beamformer/tests/test_rap_music.py index 6595b792dcb..68abae4d435 100644 --- a/mne/beamformer/tests/test_rap_music.py +++ b/mne/beamformer/tests/test_rap_music.py @@ -19,18 +19,16 @@ data_path = testing.data_path(download=False) fname_ave = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" def _get_data(ch_decim=1): """Read in data used in tests.""" # Read evoked evoked = mne.read_evokeds(fname_ave, 0, baseline=(None, 0)) - evoked.info['bads'] = ['MEG 2443'] + evoked.info["bads"] = ["MEG 2443"] with evoked.info._unlock(): - evoked.info['lowpass'] = 16 # fake for decim + evoked.info["lowpass"] = 16 # fake for decim evoked.decimate(12) evoked.crop(0.0, 0.3) picks = mne.pick_types(evoked.info, meg=True, eeg=False) @@ -39,8 +37,8 @@ def _get_data(ch_decim=1): evoked.info.normalize_proj() noise_cov = mne.read_cov(fname_cov) - noise_cov['projs'] = [] - noise_cov = regularize(noise_cov, evoked.info, rank='full', proj=False) + noise_cov["projs"] = [] + noise_cov = regularize(noise_cov, evoked.info, rank="full", proj=False) return evoked, noise_cov @@ -51,66 +49,69 @@ def simu_data(evoked, forward, noise_cov, n_dipoles, times, nave=1): """ # Generate the two dipoles data mu, sigma = 0.1, 0.005 - s1 = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-(times - mu) ** 2 / - (2 * sigma ** 2)) + s1 = ( + 1 + / (sigma * np.sqrt(2 * np.pi)) + * np.exp(-((times - mu) ** 2) / (2 * sigma**2)) + ) mu, sigma = 0.075, 0.008 - s2 = -1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-(times - mu) ** 2 / - (2 * sigma ** 2)) + s2 = ( + -1 + / (sigma * np.sqrt(2 * np.pi)) + * np.exp(-((times - mu) ** 2) / (2 * sigma**2)) + ) data = np.array([s1, s2]) * 1e-9 - src = forward['src'] + 
src = forward["src"] rng = np.random.RandomState(42) - rndi = rng.randint(len(src[0]['vertno'])) - lh_vertno = src[0]['vertno'][[rndi]] + rndi = rng.randint(len(src[0]["vertno"])) + lh_vertno = src[0]["vertno"][[rndi]] - rndi = rng.randint(len(src[1]['vertno'])) - rh_vertno = src[1]['vertno'][[rndi]] + rndi = rng.randint(len(src[1]["vertno"])) + rh_vertno = src[1]["vertno"][[rndi]] vertices = [lh_vertno, rh_vertno] - tmin, tstep = times.min(), 1 / evoked.info['sfreq'] + tmin, tstep = times.min(), 1 / evoked.info["sfreq"] stc = mne.SourceEstimate(data, vertices=vertices, tmin=tmin, tstep=tstep) - sim_evoked = mne.simulation.simulate_evoked(forward, stc, evoked.info, - noise_cov, nave=nave, - random_state=rng) + sim_evoked = mne.simulation.simulate_evoked( + forward, stc, evoked.info, noise_cov, nave=nave, random_state=rng + ) return sim_evoked, stc def _check_dipoles(dipoles, fwd, stc, evoked, residual=None): - src = fwd['src'] - pos1 = fwd['source_rr'][np.where(src[0]['vertno'] == - stc.vertices[0])] - pos2 = fwd['source_rr'][np.where(src[1]['vertno'] == - stc.vertices[1])[0] + - len(src[0]['vertno'])] + src = fwd["src"] + pos1 = fwd["source_rr"][np.where(src[0]["vertno"] == stc.vertices[0])] + pos2 = fwd["source_rr"][ + np.where(src[1]["vertno"] == stc.vertices[1])[0] + len(src[0]["vertno"]) + ] # Check the position of the two dipoles - assert (dipoles[0].pos[0] in np.array([pos1, pos2])) - assert (dipoles[1].pos[0] in np.array([pos1, pos2])) + assert dipoles[0].pos[0] in np.array([pos1, pos2]) + assert dipoles[1].pos[0] in np.array([pos1, pos2]) - ori1 = fwd['source_nn'][np.where(src[0]['vertno'] == - stc.vertices[0])[0]][0] - ori2 = fwd['source_nn'][np.where(src[1]['vertno'] == - stc.vertices[1])[0] + - len(src[0]['vertno'])][0] + ori1 = fwd["source_nn"][np.where(src[0]["vertno"] == stc.vertices[0])[0]][0] + ori2 = fwd["source_nn"][ + np.where(src[1]["vertno"] == stc.vertices[1])[0] + len(src[0]["vertno"]) + ][0] # Check the orientation of the dipoles - assert (np.max(np.abs(np.dot(dipoles[0].ori[0], - np.array([ori1, ori2]).T))) > 0.99) + assert np.max(np.abs(np.dot(dipoles[0].ori[0], np.array([ori1, ori2]).T))) > 0.99 - assert (np.max(np.abs(np.dot(dipoles[1].ori[0], - np.array([ori1, ori2]).T))) > 0.99) + assert np.max(np.abs(np.dot(dipoles[1].ori[0], np.array([ori1, ori2]).T))) > 0.99 if residual is not None: - picks_grad = mne.pick_types(residual.info, meg='grad') - picks_mag = mne.pick_types(residual.info, meg='mag') + picks_grad = mne.pick_types(residual.info, meg="grad") + picks_mag = mne.pick_types(residual.info, meg="mag") rel_tol = 0.02 for picks in [picks_grad, picks_mag]: - assert (linalg.norm(residual.data[picks], ord='fro') < - rel_tol * linalg.norm(evoked.data[picks], ord='fro')) + assert linalg.norm(residual.data[picks], ord="fro") < rel_tol * linalg.norm( + evoked.data[picks], ord="fro" + ) @testing.requires_testing_data @@ -120,37 +121,48 @@ def test_rap_music_simulated(): forward = mne.read_forward_solution(fname_fwd) forward = mne.pick_channels_forward(forward, evoked.ch_names) forward_surf_ori = mne.convert_forward_solution(forward, surf_ori=True) - forward_fixed = mne.convert_forward_solution(forward, force_fixed=True, - surf_ori=True, use_cps=True) + forward_fixed = mne.convert_forward_solution( + forward, force_fixed=True, surf_ori=True, use_cps=True + ) n_dipoles = 2 - sim_evoked, stc = simu_data(evoked, forward_fixed, noise_cov, - n_dipoles, evoked.times, nave=evoked.nave) + sim_evoked, stc = simu_data( + evoked, forward_fixed, noise_cov, n_dipoles, 
evoked.times, nave=evoked.nave + ) # Check dipoles for fixed ori with catch_logging() as log: - dipoles = rap_music(sim_evoked, forward_fixed, noise_cov, - n_dipoles=n_dipoles, verbose=True) + dipoles = rap_music( + sim_evoked, forward_fixed, noise_cov, n_dipoles=n_dipoles, verbose=True + ) assert_var_exp_log(log.getvalue(), 89, 91) _check_dipoles(dipoles, forward_fixed, stc, sim_evoked) assert 97 < dipoles[0].gof.max() < 100 assert 91 < dipoles[1].gof.max() < 93 - assert dipoles[0].gof.min() >= 0. + assert dipoles[0].gof.min() >= 0.0 nave = 100000 # add a tiny amount of noise to the simulated evokeds - sim_evoked, stc = simu_data(evoked, forward_fixed, noise_cov, - n_dipoles, evoked.times, nave=nave) - dipoles, residual = rap_music(sim_evoked, forward_fixed, noise_cov, - n_dipoles=n_dipoles, return_residual=True) + sim_evoked, stc = simu_data( + evoked, forward_fixed, noise_cov, n_dipoles, evoked.times, nave=nave + ) + dipoles, residual = rap_music( + sim_evoked, forward_fixed, noise_cov, n_dipoles=n_dipoles, return_residual=True + ) _check_dipoles(dipoles, forward_fixed, stc, sim_evoked, residual) # Check dipoles for free ori - dipoles, residual = rap_music(sim_evoked, forward, noise_cov, - n_dipoles=n_dipoles, return_residual=True) + dipoles, residual = rap_music( + sim_evoked, forward, noise_cov, n_dipoles=n_dipoles, return_residual=True + ) _check_dipoles(dipoles, forward_fixed, stc, sim_evoked, residual) # Check dipoles for free surface ori - dipoles, residual = rap_music(sim_evoked, forward_surf_ori, noise_cov, - n_dipoles=n_dipoles, return_residual=True) + dipoles, residual = rap_music( + sim_evoked, + forward_surf_ori, + noise_cov, + n_dipoles=n_dipoles, + return_residual=True, + ) _check_dipoles(dipoles, forward_fixed, stc, sim_evoked, residual) @@ -159,17 +171,19 @@ def test_rap_music_simulated(): def test_rap_music_sphere(): """Test RAP-MUSIC with real data, sphere model, MEG only.""" evoked, noise_cov = _get_data(ch_decim=8) - sphere = mne.make_sphere_model(r0=(0., 0., 0.04)) - src = mne.setup_volume_source_space(subject=None, pos=10., - sphere=(0.0, 0.0, 40, 65.0), - mindist=5.0, exclude=0.0, - sphere_units='mm') - forward = mne.make_forward_solution(evoked.info, trans=None, src=src, - bem=sphere) + sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.04)) + src = mne.setup_volume_source_space( + subject=None, + pos=10.0, + sphere=(0.0, 0.0, 40, 65.0), + mindist=5.0, + exclude=0.0, + sphere_units="mm", + ) + forward = mne.make_forward_solution(evoked.info, trans=None, src=src, bem=sphere) with catch_logging() as log: - dipoles = rap_music(evoked, forward, noise_cov, n_dipoles=2, - verbose=True) + dipoles = rap_music(evoked, forward, noise_cov, n_dipoles=2, verbose=True) assert_var_exp_log(log.getvalue(), 47, 49) # Test that there is one dipole on each hemisphere pos = np.array([dip.pos[0] for dip in dipoles]) @@ -177,11 +191,11 @@ def test_rap_music_sphere(): assert (pos[:, 0] < 0).sum() == 1 assert (pos[:, 0] > 0).sum() == 1 # Check the amplitude scale - assert (1e-10 < dipoles[0].amplitude[0] < 1e-7) + assert 1e-10 < dipoles[0].amplitude[0] < 1e-7 # Check the orientation dip_fit = mne.fit_dipole(evoked, noise_cov, sphere)[0] - assert (np.max(np.abs(np.dot(dip_fit.ori, dipoles[0].ori[0]))) > 0.99) - assert (np.max(np.abs(np.dot(dip_fit.ori, dipoles[1].ori[0]))) > 0.99) + assert np.max(np.abs(np.dot(dip_fit.ori, dipoles[0].ori[0]))) > 0.99 + assert np.max(np.abs(np.dot(dip_fit.ori, dipoles[1].ori[0]))) > 0.99 idx = dip_fit.gof.argmax() dist = np.linalg.norm(dipoles[0].pos[idx] - 
dip_fit.pos[idx]) assert 0.004 <= dist < 0.007 @@ -191,8 +205,7 @@ def test_rap_music_sphere(): @testing.requires_testing_data def test_rap_music_picks(): """Test RAP-MUSIC with picking.""" - evoked = mne.read_evokeds(fname_ave, condition='Right Auditory', - baseline=(None, 0)) + evoked = mne.read_evokeds(fname_ave, condition="Right Auditory", baseline=(None, 0)) evoked.crop(tmin=0.05, tmax=0.15) # select N100 evoked.pick_types(meg=True, eeg=False) forward = mne.read_forward_solution(fname_fwd) diff --git a/mne/beamformer/tests/test_resolution_matrix.py b/mne/beamformer/tests/test_resolution_matrix.py index 6d6730e3b9e..6e574bf89f5 100755 --- a/mne/beamformer/tests/test_resolution_matrix.py +++ b/mne/beamformer/tests/test_resolution_matrix.py @@ -19,16 +19,11 @@ data_path = testing.data_path(download=False) subjects_dir = data_path / "subjects" fname_inv = ( - data_path - / "MEG" - / "sample" - / "sample_audvis_trunc-meg-eeg-oct-6-meg-inv.fif" + data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-meg-inv.fif" ) fname_evoked = data_path / "MEG" / "sample" / "sample_audvis_trunc-ave.fif" fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" @@ -39,11 +34,10 @@ def test_resolution_matrix_lcmv(): forward = mne.read_forward_solution(fname_fwd) # remove bad channels - forward = mne.pick_channels_forward(forward, exclude='bads') + forward = mne.pick_channels_forward(forward, exclude="bads") # forward operator with fixed source orientations - forward_fxd = mne.convert_forward_solution(forward, surf_ori=True, - force_fixed=True) + forward_fxd = mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True) # evoked info info = mne.io.read_info(fname_evoked) @@ -59,12 +53,18 @@ def test_resolution_matrix_lcmv(): # compute beamformer filters # reg=0. 
to make sure noise_cov and data_cov are as similar as possible - filters = make_lcmv(info, forward_fxd, data_cov, reg=0., - noise_cov=noise_cov, - pick_ori=None, rank=None, - weight_norm=None, - reduce_rank=False, - verbose=False) + filters = make_lcmv( + info, + forward_fxd, + data_cov, + reg=0.0, + noise_cov=noise_cov, + pick_ori=None, + rank=None, + weight_norm=None, + reduce_rank=False, + verbose=False, + ) # Compute resolution matrix for beamformer resmat_lcmv = make_lcmv_resolution_matrix(filters, forward_fxd, info) @@ -73,9 +73,9 @@ def test_resolution_matrix_lcmv(): # transpose of leadfield # create filters with transposed whitened leadfield as weights - forward_fxd = mne.pick_channels_forward(forward_fxd, info['ch_names']) + forward_fxd = mne.pick_channels_forward(forward_fxd, info["ch_names"]) filters_lfd = deepcopy(filters) - filters_lfd['weights'][:] = forward_fxd['sol']['data'].T + filters_lfd["weights"][:] = forward_fxd["sol"]["data"].T # compute resolution matrix for filters with transposed leadfield resmat_fwd = make_lcmv_resolution_matrix(filters_lfd, forward_fxd, info) @@ -85,12 +85,11 @@ def test_resolution_matrix_lcmv(): # Some rows are off by about 0.1 - not yet clear why corr = [] - for (f, lf) in zip(resmat_fwd, resmat_lcmv): - + for f, lf in zip(resmat_fwd, resmat_lcmv): corr.append(np.corrcoef(f, lf)[0, 1]) # all row correlations should at least be above ~0.8 - assert_allclose(corr, 1., atol=0.2) + assert_allclose(corr, 1.0, atol=0.2) # Maximum row correlation should at least be close to 1 - assert_allclose(np.max(corr), 1., atol=0.01) + assert_allclose(np.max(corr), 1.0, atol=0.01) diff --git a/mne/bem.py b/mne/bem.py index 505c0fca79d..b9a2bb2b96e 100644 --- a/mne/bem.py +++ b/mne/bem.py @@ -23,22 +23,54 @@ from .fixes import _compare_version from .io.constants import FIFF, FWD from .io._digitization import _dig_kind_dict, _dig_kind_rev, _dig_kind_ints -from .io.write import (start_and_end_file, start_block, write_float, write_int, - write_float_matrix, write_int_matrix, end_block, - write_string) +from .io.write import ( + start_and_end_file, + start_block, + write_float, + write_int, + write_float_matrix, + write_int_matrix, + end_block, + write_string, +) from .io.tag import find_tag from .io.tree import dir_tree_find from .io.open import fiff_open -from .surface import (read_surface, write_surface, complete_surface_info, - _compute_nearest, _get_ico_surface, read_tri, - _fast_cross_nd_sum, _get_solids, _complete_sphere_surf, - decimate_surface, transform_surface_to) +from .surface import ( + read_surface, + write_surface, + complete_surface_info, + _compute_nearest, + _get_ico_surface, + read_tri, + _fast_cross_nd_sum, + _get_solids, + _complete_sphere_surf, + decimate_surface, + transform_surface_to, +) from .transforms import _ensure_trans, apply_trans, Transform -from .utils import (verbose, logger, run_subprocess, get_subjects_dir, warn, - _pl, _validate_type, _TempDir, _check_freesurfer_home, - _check_fname, _check_option, path_like, _import_nibabel, - _on_missing, _import_h5io_funcs, _ensure_int, - _path_like, _verbose_safe_false, _check_head_radius) +from .utils import ( + verbose, + logger, + run_subprocess, + get_subjects_dir, + warn, + _pl, + _validate_type, + _TempDir, + _check_freesurfer_home, + _check_fname, + _check_option, + path_like, + _import_nibabel, + _on_missing, + _import_h5io_funcs, + _ensure_int, + _path_like, + _verbose_safe_false, + _check_head_radius, +) # ############################################################################ 
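The next hunk restyles the ``ConductorModel`` repr and its ``radius`` property. For orientation, this is the object the sphere-model constructors return; a minimal sketch of how it surfaces to users (values are illustrative, not taken from this patch):

    import mne

    # Default 4-layer EEG/MEG sphere; the repr reports origin, layer count
    # and outer radius, per the __repr__ reworked below.
    sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.04), head_radius=0.09)
    print(repr(sphere))
    print(sphere.radius)  # outer-layer radius in meters; None for MEG-only models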
@@ -56,20 +88,22 @@ class ConductorModel(dict):
     """BEM or sphere model."""
 
     def __repr__(self):  # noqa: D105
-        if self['is_sphere']:
-            center = ', '.join('%0.1f' % (x * 1000.) for x in self['r0'])
+        if self["is_sphere"]:
+            center = ", ".join("%0.1f" % (x * 1000.0) for x in self["r0"])
             rad = self.radius
             if rad is None:  # no radius / MEG only
-                extra = 'Sphere (no layers): r0=[%s] mm' % center
+                extra = "Sphere (no layers): r0=[%s] mm" % center
             else:
-                extra = ('Sphere (%s layer%s): r0=[%s] R=%1.f mm'
-                         % (len(self['layers']) - 1, _pl(self['layers']),
-                            center, rad * 1000.))
+                extra = "Sphere (%s layer%s): r0=[%s] R=%1.f mm" % (
+                    len(self["layers"]) - 1,
+                    _pl(self["layers"]),
+                    center,
+                    rad * 1000.0,
+                )
         else:
-            extra = ('BEM (%s layer%s)' % (len(self['surfs']),
-                                           _pl(self['surfs'])))
-        extra += " solver=%s" % self['solver']
-        return '<ConductorModel | %s>' % extra
+            extra = "BEM (%s layer%s)" % (len(self["surfs"]), _pl(self["surfs"]))
+        extra += " solver=%s" % self["solver"]
+        return "<ConductorModel | %s>" % extra
 
     def copy(self):
         """Return copy of ConductorModel instance."""
@@ -78,9 +112,9 @@ def copy(self):
     @property
     def radius(self):
         """Sphere radius if an EEG sphere model."""
-        if not self['is_sphere']:
-            raise RuntimeError('radius undefined for BEM')
-        return None if len(self['layers']) == 0 else self['layers'][-1]['rad']
+        if not self["is_sphere"]:
+            raise RuntimeError("radius undefined for BEM")
+        return None if len(self["layers"]) == 0 else self["layers"][-1]["rad"]
 
 
 def _calc_beta(rk, rk_norm, rk1, rk1_norm):
@@ -108,9 +142,9 @@ def _lin_pot_coeff(fros, tri_rr, tri_nn, tri_area):
     l2 = np.linalg.norm(v2, axis=1)
     l3 = np.linalg.norm(v3, axis=1)
     ss = l1 * l2 * l3
-    ss += np.einsum('ij,ij,i->i', v1, v2, l3)
-    ss += np.einsum('ij,ij,i->i', v1, v3, l2)
-    ss += np.einsum('ij,ij,i->i', v2, v3, l1)
+    ss += np.einsum("ij,ij,i->i", v1, v2, l3)
+    ss += np.einsum("ij,ij,i->i", v1, v3, l2)
+    ss += np.einsum("ij,ij,i->i", v2, v3, l1)
     solids = np.arctan2(triples, ss)
 
     # We *could* subselect the good points from v1, v2, v3, triples, solids,
@@ -119,14 +153,16 @@ def _lin_pot_coeff(fros, tri_rr, tri_nn, tri_area):
     # solution. These three lines ensure we don't get invalid values in
     # _calc_beta.
     bad_mask = np.abs(solids) < np.pi / 1e6
-    l1[bad_mask] = 1.
-    l2[bad_mask] = 1.
-    l3[bad_mask] = 1.
+    l1[bad_mask] = 1.0
+    l2[bad_mask] = 1.0
+    l3[bad_mask] = 1.0
 
     # Calculate the magic vector vec_omega
-    beta = [_calc_beta(v1, l1, v2, l2)[:, np.newaxis],
-            _calc_beta(v2, l2, v3, l3)[:, np.newaxis],
-            _calc_beta(v3, l3, v1, l1)[:, np.newaxis]]
+    beta = [
+        _calc_beta(v1, l1, v2, l2)[:, np.newaxis],
+        _calc_beta(v2, l2, v3, l3)[:, np.newaxis],
+        _calc_beta(v3, l3, v1, l1)[:, np.newaxis],
+    ]
     vec_omega = (beta[2] - beta[0]) * v1
     vec_omega += (beta[0] - beta[1]) * v2
     vec_omega += (beta[1] - beta[2]) * v3
@@ -140,26 +176,27 @@ def _lin_pot_coeff(fros, tri_rr, tri_nn, tri_area):
     for k in range(3):
         diff = yys[idx[k - 1]] - yys[idx[k + 1]]
         zdots = _fast_cross_nd_sum(yys[idx[k + 1]], yys[idx[k - 1]], tri_nn)
-        omega[:, k] = -n2 * (area2 * zdots * 2. * solids -
-                             triples * (diff * vec_omega).sum(axis=-1))
+        omega[:, k] = -n2 * (
+            area2 * zdots * 2.0 * solids - triples * (diff * vec_omega).sum(axis=-1)
+        )
 
     # omit the bad points from the solution
-    omega[bad_mask] = 0.
+    omega[bad_mask] = 0.0
 
     return omega
 
 
 def _correct_auto_elements(surf, mat):
     """Improve auto-element approximation."""
     pi2 = 2.0 * np.pi
-    tris_flat = surf['tris'].ravel()
+    tris_flat = surf["tris"].ravel()
     misses = pi2 - mat.sum(axis=1)
     for j, miss in enumerate(misses):
        # How much is missing?
- n_memb = len(surf['neighbor_tri'][j]) + n_memb = len(surf["neighbor_tri"][j]) assert n_memb > 0 # should be guaranteed by our surface checks # The node itself receives one half mat[j, j] = miss / 2.0 # The rest is divided evenly among the member nodes... - miss /= (4.0 * n_memb) + miss /= 4.0 * n_memb members = np.where(j == tris_flat)[0] mods = members % 3 offsets = np.array([[1, 2], [-1, 1], [-1, -2]]) @@ -174,27 +211,34 @@ def _correct_auto_elements(surf, mat): def _fwd_bem_lin_pot_coeff(surfs): """Calculate the coefficients for linear collocation approach.""" # taken from fwd_bem_linear_collocation.c - nps = [surf['np'] for surf in surfs] + nps = [surf["np"] for surf in surfs] np_tot = sum(nps) coeff = np.zeros((np_tot, np_tot)) offsets = np.cumsum(np.concatenate(([0], nps))) for si_1, surf1 in enumerate(surfs): rr_ord = np.arange(nps[si_1]) for si_2, surf2 in enumerate(surfs): - logger.info(" %s (%d) -> %s (%d) ..." % - (_bem_surf_name[surf1['id']], nps[si_1], - _bem_surf_name[surf2['id']], nps[si_2])) - tri_rr = surf2['rr'][surf2['tris']] - tri_nn = surf2['tri_nn'] - tri_area = surf2['tri_area'] - submat = coeff[offsets[si_1]:offsets[si_1 + 1], - offsets[si_2]:offsets[si_2 + 1]] # view - for k in range(surf2['ntri']): - tri = surf2['tris'][k] + logger.info( + " %s (%d) -> %s (%d) ..." + % ( + _bem_surf_name[surf1["id"]], + nps[si_1], + _bem_surf_name[surf2["id"]], + nps[si_2], + ) + ) + tri_rr = surf2["rr"][surf2["tris"]] + tri_nn = surf2["tri_nn"] + tri_area = surf2["tri_area"] + submat = coeff[ + offsets[si_1] : offsets[si_1 + 1], offsets[si_2] : offsets[si_2 + 1] + ] # view + for k in range(surf2["ntri"]): + tri = surf2["tris"][k] if si_1 == si_2: - skip_idx = ((rr_ord == tri[0]) | - (rr_ord == tri[1]) | - (rr_ord == tri[2])) + skip_idx = ( + (rr_ord == tri[0]) | (rr_ord == tri[1]) | (rr_ord == tri[2]) + ) else: skip_idx = list() # No contribution from a triangle that @@ -202,9 +246,13 @@ def _fwd_bem_lin_pot_coeff(surfs): # if sidx1 == sidx2 and (tri == j).any(): # continue # Otherwise do the hard job - coeffs = _lin_pot_coeff(fros=surf1['rr'], tri_rr=tri_rr[k], - tri_nn=tri_nn[k], tri_area=tri_area[k]) - coeffs[skip_idx] = 0. 
+ coeffs = _lin_pot_coeff( + fros=surf1["rr"], + tri_rr=tri_rr[k], + tri_nn=tri_nn[k], + tri_area=tri_area[k], + ) + coeffs[skip_idx] = 0.0 submat[:, tri] -= coeffs if si_1 == si_2: _correct_auto_elements(surf1, submat) @@ -246,11 +294,11 @@ def _fwd_bem_ip_modify_solution(solution, ip_solution, ip_mult, n_tri): n_last = n_tri[-1] mult = (1.0 + ip_mult) / ip_mult - logger.info(' Combining...') + logger.info(" Combining...") offsets = np.cumsum(np.concatenate(([0], n_tri))) for si in range(len(n_tri)): # Pick the correct submatrix (right column) and multiply - sub = solution[offsets[si]:offsets[si + 1], np.sum(n_tri[:-1]):] + sub = solution[offsets[si] : offsets[si + 1], np.sum(n_tri[:-1]) :] # Multiply sub -= 2 * np.dot(sub, ip_solution) @@ -258,63 +306,64 @@ def _fwd_bem_ip_modify_solution(solution, ip_solution, ip_mult, n_tri): sub[-n_last:, -n_last:] += mult * ip_solution # Final scaling - logger.info(' Scaling...') + logger.info(" Scaling...") solution *= ip_mult return -def _check_complete_surface(surf, copy=False, incomplete='raise', extra=''): - surf = complete_surface_info( - surf, copy=copy, verbose=_verbose_safe_false()) - fewer = np.where([len(t) < 3 for t in surf['neighbor_tri']])[0] +def _check_complete_surface(surf, copy=False, incomplete="raise", extra=""): + surf = complete_surface_info(surf, copy=copy, verbose=_verbose_safe_false()) + fewer = np.where([len(t) < 3 for t in surf["neighbor_tri"]])[0] if len(fewer) > 0: fewer = list(fewer) - fewer = (fewer[:80] + ['...']) if len(fewer) > 80 else fewer - fewer = ', '.join(str(f) for f in fewer) - msg = ('Surface {} has topological defects: {:.0f} / {:.0f} vertices ' - 'have fewer than three neighboring triangles [{}]{}' - .format(_bem_surf_name[surf['id']], len(fewer), len(surf['rr']), - fewer, extra)) - _on_missing(on_missing=incomplete, msg=msg, name='on_defects') + fewer = (fewer[:80] + ["..."]) if len(fewer) > 80 else fewer + fewer = ", ".join(str(f) for f in fewer) + msg = ( + "Surface {} has topological defects: {:.0f} / {:.0f} vertices " + "have fewer than three neighboring triangles [{}]{}".format( + _bem_surf_name[surf["id"]], len(fewer), len(surf["rr"]), fewer, extra + ) + ) + _on_missing(on_missing=incomplete, msg=msg, name="on_defects") return surf def _fwd_bem_linear_collocation_solution(bem): """Compute the linear collocation potential solution.""" # first, add surface geometries - logger.info('Computing the linear collocation solution...') - logger.info(' Matrix coefficients...') - coeff = _fwd_bem_lin_pot_coeff(bem['surfs']) - bem['nsol'] = len(coeff) + logger.info("Computing the linear collocation solution...") + logger.info(" Matrix coefficients...") + coeff = _fwd_bem_lin_pot_coeff(bem["surfs"]) + bem["nsol"] = len(coeff) logger.info(" Inverting the coefficient matrix...") - nps = [surf['np'] for surf in bem['surfs']] - bem['solution'] = _fwd_bem_multi_solution(coeff, bem['gamma'], nps) - if len(bem['surfs']) == 3: - ip_mult = bem['sigma'][1] / bem['sigma'][2] + nps = [surf["np"] for surf in bem["surfs"]] + bem["solution"] = _fwd_bem_multi_solution(coeff, bem["gamma"], nps) + if len(bem["surfs"]) == 3: + ip_mult = bem["sigma"][1] / bem["sigma"][2] if ip_mult <= FWD.BEM_IP_APPROACH_LIMIT: - logger.info('IP approach required...') - logger.info(' Matrix coefficients (homog)...') - coeff = _fwd_bem_lin_pot_coeff([bem['surfs'][-1]]) - logger.info(' Inverting the coefficient matrix (homog)...') - ip_solution = _fwd_bem_homog_solution(coeff, - [bem['surfs'][-1]['np']]) - logger.info(' Modify the original solution to 
incorporate ' - 'IP approach...') - _fwd_bem_ip_modify_solution(bem['solution'], ip_solution, ip_mult, - nps) - bem['bem_method'] = FIFF.FIFFV_BEM_APPROX_LINEAR - bem['solver'] = 'mne' - - -def _import_openmeeg(what='compute a BEM solution using OpenMEEG'): + logger.info("IP approach required...") + logger.info(" Matrix coefficients (homog)...") + coeff = _fwd_bem_lin_pot_coeff([bem["surfs"][-1]]) + logger.info(" Inverting the coefficient matrix (homog)...") + ip_solution = _fwd_bem_homog_solution(coeff, [bem["surfs"][-1]["np"]]) + logger.info( + " Modify the original solution to incorporate " "IP approach..." + ) + _fwd_bem_ip_modify_solution(bem["solution"], ip_solution, ip_mult, nps) + bem["bem_method"] = FIFF.FIFFV_BEM_APPROX_LINEAR + bem["solver"] = "mne" + + +def _import_openmeeg(what="compute a BEM solution using OpenMEEG"): try: import openmeeg as om except Exception as exc: raise ImportError( - f'The OpenMEEG module must be installed to {what}, but ' - f'"import openmeeg" resulted in: {exc}') from None - if not _compare_version(om.__version__, '>=', '2.5.6'): - raise ImportError(f'OpenMEEG 2.5.6+ is required, got {om.__version__}') + f"The OpenMEEG module must be installed to {what}, but " + f'"import openmeeg" resulted in: {exc}' + ) from None + if not _compare_version(om.__version__, ">=", "2.5.6"): + raise ImportError(f"OpenMEEG 2.5.6+ is required, got {om.__version__}") return om @@ -322,37 +371,37 @@ def _make_openmeeg_geometry(bem, mri_head_t=None): # OpenMEEG om = _import_openmeeg() meshes = [] - for surf in bem['surfs'][::-1]: + for surf in bem["surfs"][::-1]: if mri_head_t is not None: surf = transform_surface_to(surf, "head", mri_head_t, copy=True) - points, faces = surf['rr'], surf['tris'] + points, faces = surf["rr"], surf["tris"] faces = faces[:, [1, 0, 2]] # swap faces meshes.append((points, faces)) - conductivity = bem['sigma'][::-1] + conductivity = bem["sigma"][::-1] return om.make_nested_geometry(meshes, conductivity) def _fwd_bem_openmeeg_solution(bem): om = _import_openmeeg() - logger.info('Creating BEM solution using OpenMEEG') - logger.info('Computing the openmeeg head matrix solution...') - logger.info(' Matrix coefficients...') + logger.info("Creating BEM solution using OpenMEEG") + logger.info("Computing the openmeeg head matrix solution...") + logger.info(" Matrix coefficients...") geom = _make_openmeeg_geometry(bem) hm = om.HeadMat(geom) - bem['nsol'] = hm.nlin() + bem["nsol"] = hm.nlin() logger.info(" Inverting the coefficient matrix...") hm.invert() # invert inplace - bem['solution'] = hm.array_flat() - bem['bem_method'] = FIFF.FIFFV_BEM_APPROX_LINEAR - bem['solver'] = 'openmeeg' + bem["solution"] = hm.array_flat() + bem["bem_method"] = FIFF.FIFFV_BEM_APPROX_LINEAR + bem["solver"] = "openmeeg" @verbose -def make_bem_solution(surfs, *, solver='mne', verbose=None): +def make_bem_solution(surfs, *, solver="mne", verbose=None): """Create a BEM solution using the linear collocation approach. Parameters @@ -383,76 +432,83 @@ def make_bem_solution(surfs, *, solver='mne', verbose=None): ----- .. 
versionadded:: 0.10.0 """ - _validate_type(solver, str, 'solver') - _check_option('method', solver.lower(), ('mne', 'openmeeg')) + _validate_type(solver, str, "solver") + _check_option("method", solver.lower(), ("mne", "openmeeg")) bem = _ensure_bem_surfaces(surfs) _add_gamma_multipliers(bem) - if len(bem['surfs']) == 3: - logger.info('Three-layer model surfaces loaded.') - elif len(bem['surfs']) == 1: - logger.info('Homogeneous model surface loaded.') + if len(bem["surfs"]) == 3: + logger.info("Three-layer model surfaces loaded.") + elif len(bem["surfs"]) == 1: + logger.info("Homogeneous model surface loaded.") else: - raise RuntimeError('Only 1- or 3-layer BEM computations supported') - _check_bem_size(bem['surfs']) - for surf in bem['surfs']: + raise RuntimeError("Only 1- or 3-layer BEM computations supported") + _check_bem_size(bem["surfs"]) + for surf in bem["surfs"]: _check_complete_surface(surf) - if solver.lower() == 'openmeeg': + if solver.lower() == "openmeeg": _fwd_bem_openmeeg_solution(bem) else: - assert solver.lower() == 'mne' + assert solver.lower() == "mne" _fwd_bem_linear_collocation_solution(bem) logger.info("Solution ready.") - logger.info('BEM geometry computations complete.') + logger.info("BEM geometry computations complete.") return bem # ############################################################################ # Make BEM model + def _ico_downsample(surf, dest_grade): """Downsample the surface if isomorphic to a subdivided icosahedron.""" - n_tri = len(surf['tris']) - bad_msg = ("Cannot decimate to requested ico grade %d. The provided " - "BEM surface has %d triangles, which cannot be isomorphic with " - "a subdivided icosahedron. Consider manually decimating the " - "surface to a suitable density and then use ico=None in " - "make_bem_model." % (dest_grade, n_tri)) + n_tri = len(surf["tris"]) + bad_msg = ( + "Cannot decimate to requested ico grade %d. The provided " + "BEM surface has %d triangles, which cannot be isomorphic with " + "a subdivided icosahedron. Consider manually decimating the " + "surface to a suitable density and then use ico=None in " + "make_bem_model." % (dest_grade, n_tri) + ) if n_tri % 20 != 0: raise RuntimeError(bad_msg) n_tri = n_tri // 20 found = int(round(np.log(n_tri) / np.log(4))) - if n_tri != 4 ** found: + if n_tri != 4**found: raise RuntimeError(bad_msg) del n_tri if dest_grade > found: - raise RuntimeError('For this surface, decimation grade should be %d ' - 'or less, not %s.' % (found, dest_grade)) + raise RuntimeError( + "For this surface, decimation grade should be %d " + "or less, not %s." 
% (found, dest_grade) + ) source = _get_ico_surface(found) dest = _get_ico_surface(dest_grade, patch_stats=True) - del dest['tri_cent'] - del dest['tri_nn'] - del dest['neighbor_tri'] - del dest['tri_area'] - if not np.array_equal(source['tris'], surf['tris']): - raise RuntimeError('The source surface has a matching number of ' - 'triangles but ordering is wrong') - logger.info('Going from %dth to %dth subdivision of an icosahedron ' - '(n_tri: %d -> %d)' % (found, dest_grade, len(surf['tris']), - len(dest['tris']))) + del dest["tri_cent"] + del dest["tri_nn"] + del dest["neighbor_tri"] + del dest["tri_area"] + if not np.array_equal(source["tris"], surf["tris"]): + raise RuntimeError( + "The source surface has a matching number of " + "triangles but ordering is wrong" + ) + logger.info( + "Going from %dth to %dth subdivision of an icosahedron " + "(n_tri: %d -> %d)" % (found, dest_grade, len(surf["tris"]), len(dest["tris"])) + ) # Find the mapping - dest['rr'] = surf['rr'][_get_ico_map(source, dest)] + dest["rr"] = surf["rr"][_get_ico_map(source, dest)] return dest def _get_ico_map(fro, to): """Get a mapping between ico surfaces.""" - nearest, dists = _compute_nearest(fro['rr'], to['rr'], return_dists=True) + nearest, dists = _compute_nearest(fro["rr"], to["rr"], return_dists=True) n_bads = (dists > 5e-3).sum() if n_bads > 0: - raise RuntimeError('No matching vertex for %d destination vertices' - % (n_bads)) + raise RuntimeError("No matching vertex for %d destination vertices" % (n_bads)) return nearest @@ -461,32 +517,36 @@ def _order_surfaces(surfs): if len(surfs) != 3: return surfs # we have three surfaces - surf_order = [FIFF.FIFFV_BEM_SURF_ID_HEAD, - FIFF.FIFFV_BEM_SURF_ID_SKULL, - FIFF.FIFFV_BEM_SURF_ID_BRAIN] - ids = np.array([surf['id'] for surf in surfs]) + surf_order = [ + FIFF.FIFFV_BEM_SURF_ID_HEAD, + FIFF.FIFFV_BEM_SURF_ID_SKULL, + FIFF.FIFFV_BEM_SURF_ID_BRAIN, + ] + ids = np.array([surf["id"] for surf in surfs]) if set(ids) != set(surf_order): - raise RuntimeError('bad surface ids: %s' % ids) + raise RuntimeError("bad surface ids: %s" % ids) order = [np.where(ids == id_)[0][0] for id_ in surf_order] surfs = [surfs[idx] for idx in order] return surfs -def _assert_complete_surface(surf, incomplete='raise'): +def _assert_complete_surface(surf, incomplete="raise"): """Check the sum of solid angles as seen from inside.""" # from surface_checks.c # Center of mass.... 
- cm = surf['rr'].mean(axis=0) - logger.info('%s CM is %6.2f %6.2f %6.2f mm' % - (_bem_surf_name[surf['id']], - 1000 * cm[0], 1000 * cm[1], 1000 * cm[2])) - tot_angle = _get_solids(surf['rr'][surf['tris']], cm[np.newaxis, :])[0] + cm = surf["rr"].mean(axis=0) + logger.info( + "%s CM is %6.2f %6.2f %6.2f mm" + % (_bem_surf_name[surf["id"]], 1000 * cm[0], 1000 * cm[1], 1000 * cm[2]) + ) + tot_angle = _get_solids(surf["rr"][surf["tris"]], cm[np.newaxis, :])[0] prop = tot_angle / (2 * np.pi) if np.abs(prop - 1.0) > 1e-5: - msg = (f'Surface {_bem_surf_name[surf["id"]]} is not complete (sum of ' - f'solid angles yielded {prop}, should be 1.)') - _on_missing( - incomplete, msg, name='incomplete', error_klass=RuntimeError) + msg = ( + f'Surface {_bem_surf_name[surf["id"]]} is not complete (sum of ' + f"solid angles yielded {prop}, should be 1.)" + ) + _on_missing(incomplete, msg, name="incomplete", error_klass=RuntimeError) def _assert_inside(fro, to): @@ -494,15 +554,15 @@ def _assert_inside(fro, to): # this is "is_inside" in surface_checks.c fro_name = _bem_surf_name[fro["id"]] to_name = _bem_surf_name[to["id"]] - logger.info( - f'Checking that surface {fro_name} is inside surface {to_name} ...') - tot_angle = _get_solids(to['rr'][to['tris']], fro['rr']) + logger.info(f"Checking that surface {fro_name} is inside surface {to_name} ...") + tot_angle = _get_solids(to["rr"][to["tris"]], fro["rr"]) if (np.abs(tot_angle / (2 * np.pi) - 1.0) > 1e-5).any(): raise RuntimeError( - f'Surface {fro_name} is not completely inside surface {to_name}') + f"Surface {fro_name} is not completely inside surface {to_name}" + ) -def _check_surfaces(surfs, incomplete='raise'): +def _check_surfaces(surfs, incomplete="raise"): """Check that the surfaces are complete and non-intersecting.""" for surf in surfs: _assert_complete_surface(surf, incomplete=incomplete) @@ -513,36 +573,40 @@ def _check_surfaces(surfs, incomplete='raise'): def _check_surface_size(surf): """Check that the coordinate limits are reasonable.""" - sizes = surf['rr'].max(axis=0) - surf['rr'].min(axis=0) + sizes = surf["rr"].max(axis=0) - surf["rr"].min(axis=0) if (sizes < 0.05).any(): raise RuntimeError( f'Dimensions of the surface {_bem_surf_name[surf["id"]]} seem too ' - f'small ({1000 * sizes.min():9.5f}). Maybe the unit of measure' - ' is meters instead of mm') + f"small ({1000 * sizes.min():9.5f}). 
Maybe the unit of measure" + " is meters instead of mm" + ) def _check_thicknesses(surfs): """Compute how close we are.""" for surf_1, surf_2 in zip(surfs[:-1], surfs[1:]): - min_dist = _compute_nearest(surf_1['rr'], surf_2['rr'], - return_dists=True)[1] + min_dist = _compute_nearest(surf_1["rr"], surf_2["rr"], return_dists=True)[1] min_dist = min_dist.min() - fro = _bem_surf_name[surf_1['id']] - to = _bem_surf_name[surf_2['id']] - logger.info(f'Checking distance between {fro} and {to} surfaces...') - logger.info(f'Minimum distance between the {fro} and {to} surfaces is ' - f'approximately {1000 * min_dist:6.1f} mm') - - -def _surfaces_to_bem(surfs, ids, sigmas, ico=None, rescale=True, - incomplete='raise', extra=''): + fro = _bem_surf_name[surf_1["id"]] + to = _bem_surf_name[surf_2["id"]] + logger.info(f"Checking distance between {fro} and {to} surfaces...") + logger.info( + f"Minimum distance between the {fro} and {to} surfaces is " + f"approximately {1000 * min_dist:6.1f} mm" + ) + + +def _surfaces_to_bem( + surfs, ids, sigmas, ico=None, rescale=True, incomplete="raise", extra="" +): """Convert surfaces to a BEM.""" # equivalent of mne_surf2bem # surfs can be strings (filenames) or surface dicts - if len(surfs) not in (1, 3) or not (len(surfs) == len(ids) == - len(sigmas)): - raise ValueError('surfs, ids, and sigmas must all have the same ' - 'number of elements (1 or 3)') + if len(surfs) not in (1, 3) or not (len(surfs) == len(ids) == len(sigmas)): + raise ValueError( + "surfs, ids, and sigmas must all have the same " + "number of elements (1 or 3)" + ) for si, surf in enumerate(surfs): if isinstance(surf, (str, Path, os.PathLike)): surfs[si] = surf = read_surface(surf, return_dict=True)[-1] @@ -552,19 +616,18 @@ def _surfaces_to_bem(surfs, ids, sigmas, ico=None, rescale=True, surfs[si] = _ico_downsample(surf, ico) for surf, id_ in zip(surfs, ids): # Do topology checks (but don't save data) to fail early - surf['id'] = id_ - _check_complete_surface(surf, copy=True, incomplete=incomplete, - extra=extra) - surf['coord_frame'] = surf.get('coord_frame', FIFF.FIFFV_COORD_MRI) - surf.update(np=len(surf['rr']), ntri=len(surf['tris'])) + surf["id"] = id_ + _check_complete_surface(surf, copy=True, incomplete=incomplete, extra=extra) + surf["coord_frame"] = surf.get("coord_frame", FIFF.FIFFV_COORD_MRI) + surf.update(np=len(surf["rr"]), ntri=len(surf["tris"])) if rescale: - surf['rr'] /= 1000. # convert to meters + surf["rr"] /= 1000.0 # convert to meters # Shifting surfaces is not implemented here... # Order the surfaces for the benefit of the topology checks for surf, sigma in zip(surfs, sigmas): - surf['sigma'] = sigma + surf["sigma"] = sigma surfs = _order_surfaces(surfs) # Check topology as best we can @@ -572,13 +635,14 @@ def _surfaces_to_bem(surfs, ids, sigmas, ico=None, rescale=True, for surf in surfs: _check_surface_size(surf) _check_thicknesses(surfs) - logger.info('Surfaces passed the basic topology checks.') + logger.info("Surfaces passed the basic topology checks.") return surfs @verbose -def make_bem_model(subject, ico=4, conductivity=(0.3, 0.006, 0.3), - subjects_dir=None, verbose=None): +def make_bem_model( + subject, ico=4, conductivity=(0.3, 0.006, 0.3), subjects_dir=None, verbose=None +): """Create a BEM model for a subject. .. 
note:: To get a single layer bem corresponding to the --homog flag in @@ -619,8 +683,7 @@ def make_bem_model(subject, ico=4, conductivity=(0.3, 0.006, 0.3), """ conductivity = np.array(conductivity, float) if conductivity.ndim != 1 or conductivity.size not in (1, 3): - raise ValueError('conductivity must be 1D array-like with 1 or 3 ' - 'elements') + raise ValueError("conductivity must be 1D array-like with 1 or 3 " "elements") subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) subject_dir = subjects_dir / subject bem_dir = subject_dir / "bem" @@ -628,27 +691,30 @@ def make_bem_model(subject, ico=4, conductivity=(0.3, 0.006, 0.3), outer_skull = bem_dir / "outer_skull.surf" outer_skin = bem_dir / "outer_skin.surf" surfaces = [inner_skull, outer_skull, outer_skin] - ids = [FIFF.FIFFV_BEM_SURF_ID_BRAIN, - FIFF.FIFFV_BEM_SURF_ID_SKULL, - FIFF.FIFFV_BEM_SURF_ID_HEAD] - logger.info('Creating the BEM geometry...') + ids = [ + FIFF.FIFFV_BEM_SURF_ID_BRAIN, + FIFF.FIFFV_BEM_SURF_ID_SKULL, + FIFF.FIFFV_BEM_SURF_ID_HEAD, + ] + logger.info("Creating the BEM geometry...") if len(conductivity) == 1: surfaces = surfaces[:1] ids = ids[:1] surfaces = _surfaces_to_bem(surfaces, ids, conductivity, ico) _check_bem_size(surfaces) - logger.info('Complete.\n') + logger.info("Complete.\n") return surfaces # ############################################################################ # Compute EEG sphere model + def _fwd_eeg_get_multi_sphere_model_coeffs(m, n_terms): """Get the model depended weighting factor for n.""" - nlayer = len(m['layers']) + nlayer = len(m["layers"]) if nlayer in (0, 1): - return 1. + return 1.0 # Initialize the arrays c1 = np.zeros(nlayer - 1) @@ -656,9 +722,9 @@ def _fwd_eeg_get_multi_sphere_model_coeffs(m, n_terms): cr = np.zeros(nlayer - 1) cr_mult = np.zeros(nlayer - 1) for k in range(nlayer - 1): - c1[k] = m['layers'][k]['sigma'] / m['layers'][k + 1]['sigma'] + c1[k] = m["layers"][k]["sigma"] / m["layers"][k + 1]["sigma"] c2[k] = c1[k] - 1.0 - cr_mult[k] = m['layers'][k]['rel_rad'] + cr_mult[k] = m["layers"][k]["rel_rad"] cr[k] = cr_mult[k] cr_mult[k] *= cr_mult[k] @@ -672,8 +738,13 @@ def _fwd_eeg_get_multi_sphere_model_coeffs(m, n_terms): M = np.eye(2) n1 = n + 1.0 for k in range(nlayer - 2, -1, -1): - M = np.dot([[n + n1 * c1[k], n1 * c2[k] / cr[k]], - [n * c2[k] * cr[k], n1 + n * c1[k]]], M) + M = np.dot( + [ + [n + n1 * c1[k], n1 * c2[k] / cr[k]], + [n * c2[k] * cr[k], n1 + n * c1[k]], + ], + M, + ) num = n * (2.0 * n + 1.0) ** (nlayer - 1) coeffs[n - 1] = num / (n * M[1, 1] + n1 * M[1, 0]) return coeffs @@ -682,15 +753,15 @@ def _fwd_eeg_get_multi_sphere_model_coeffs(m, n_terms): def _compose_linear_fitting_data(mu, u): """Get the linear fitting data.""" from scipy import linalg - k1 = np.arange(1, u['nterms']) + + k1 = np.arange(1, u["nterms"]) mu1ns = mu[0] ** k1 # data to be fitted - y = u['w'][:-1] * (u['fn'][1:] - mu1ns * u['fn'][0]) + y = u["w"][:-1] * (u["fn"][1:] - mu1ns * u["fn"][0]) # model matrix - M = u['w'][:-1, np.newaxis] * (mu[1:] ** k1[:, np.newaxis] - - mu1ns[:, np.newaxis]) + M = u["w"][:-1, np.newaxis] * (mu[1:] ** k1[:, np.newaxis] - mu1ns[:, np.newaxis]) uu, sing, vv = linalg.svd(M, full_matrices=False) - ncomp = u['nfit'] - 1 + ncomp = u["nfit"] - 1 uu, sing, vv = uu[:, :ncomp], sing[:ncomp], vv[:ncomp] return y, uu, sing, vv @@ -704,9 +775,9 @@ def _compute_linear_parameters(mu, u): resi = y - np.dot(uu, vec) vec /= sing - lambda_ = np.zeros(u['nfit']) + lambda_ = np.zeros(u["nfit"]) lambda_[1:] = np.dot(vec, vv) - lambda_[0] = u['fn'][0] - 
np.sum(lambda_[1:]) + lambda_[0] = u["fn"][0] - np.sum(lambda_[1:]) rv = np.dot(resi, resi) / np.dot(y, y) return rv, lambda_ @@ -725,27 +796,28 @@ def _one_step(mu, u): def _fwd_eeg_fit_berg_scherg(m, nterms, nfit): """Fit the Berg-Scherg equivalent spherical model dipole parameters.""" from scipy.optimize import fmin_cobyla + assert nfit >= 2 u = dict(nfit=nfit, nterms=nterms) # (1) Calculate the coefficients of the true expansion - u['fn'] = _fwd_eeg_get_multi_sphere_model_coeffs(m, nterms + 1) + u["fn"] = _fwd_eeg_get_multi_sphere_model_coeffs(m, nterms + 1) # (2) Calculate the weighting - f = (min([layer['rad'] for layer in m['layers']]) / - max([layer['rad'] for layer in m['layers']])) + f = min([layer["rad"] for layer in m["layers"]]) / max( + [layer["rad"] for layer in m["layers"]] + ) # correct weighting k = np.arange(1, nterms + 1) - u['w'] = np.sqrt((2.0 * k + 1) * (3.0 * k + 1.0) / - k) * np.power(f, (k - 1.0)) - u['w'][-1] = 0 + u["w"] = np.sqrt((2.0 * k + 1) * (3.0 * k + 1.0) / k) * np.power(f, (k - 1.0)) + u["w"][-1] = 0 # Do the nonlinear minimization, constraining mu to the interval [-1, +1] mu_0 = np.zeros(3) fun = partial(_one_step, u=u) catol = 1e-6 - max_ = 1. - 2 * catol + max_ = 1.0 - 2 * catol def cons(x): return max_ - np.abs(x) @@ -757,17 +829,22 @@ def cons(x): order = np.argsort(mu)[::-1] mu, lambda_ = mu[order], lambda_[order] # sort: largest mu first - m['mu'] = mu + m["mu"] = mu # This division takes into account the actual conductivities - m['lambda'] = lambda_ / m['layers'][-1]['sigma'] - m['nfit'] = nfit + m["lambda"] = lambda_ / m["layers"][-1]["sigma"] + m["nfit"] = nfit return rv @verbose -def make_sphere_model(r0=(0., 0., 0.04), head_radius=0.09, info=None, - relative_radii=(0.90, 0.92, 0.97, 1.0), - sigmas=(0.33, 1.0, 0.004, 0.33), verbose=None): +def make_sphere_model( + r0=(0.0, 0.0, 0.04), + head_radius=0.09, + info=None, + relative_radii=(0.90, 0.92, 0.97, 1.0), + sigmas=(0.33, 1.0, 0.004, 0.33), + verbose=None, +): """Create a spherical model for forward solution calculation. Parameters @@ -809,33 +886,37 @@ def make_sphere_model(r0=(0., 0., 0.04), head_radius=0.09, info=None, .. 
versionadded:: 0.9.0 """ - for name in ('r0', 'head_radius'): + for name in ("r0", "head_radius"): param = locals()[name] if isinstance(param, str): - if param != 'auto': - raise ValueError('%s, if str, must be "auto" not "%s"' - % (name, param)) + if param != "auto": + raise ValueError('%s, if str, must be "auto" not "%s"' % (name, param)) relative_radii = np.array(relative_radii, float).ravel() sigmas = np.array(sigmas, float).ravel() if len(relative_radii) != len(sigmas): - raise ValueError('relative_radii length (%s) must match that of ' - 'sigmas (%s)' % (len(relative_radii), - len(sigmas))) + raise ValueError( + "relative_radii length (%s) must match that of " + "sigmas (%s)" % (len(relative_radii), len(sigmas)) + ) if len(sigmas) <= 1 and head_radius is not None: - raise ValueError('at least 2 sigmas must be supplied if ' - 'head_radius is not None, got %s' % (len(sigmas),)) - if (isinstance(r0, str) and r0 == 'auto') or \ - (isinstance(head_radius, str) and head_radius == 'auto'): + raise ValueError( + "at least 2 sigmas must be supplied if " + "head_radius is not None, got %s" % (len(sigmas),) + ) + if (isinstance(r0, str) and r0 == "auto") or ( + isinstance(head_radius, str) and head_radius == "auto" + ): if info is None: - raise ValueError('Info must not be None for auto mode') - head_radius_fit, r0_fit = fit_sphere_to_headshape(info, units='m')[:2] + raise ValueError("Info must not be None for auto mode") + head_radius_fit, r0_fit = fit_sphere_to_headshape(info, units="m")[:2] if isinstance(r0, str): r0 = r0_fit if isinstance(head_radius, str): head_radius = head_radius_fit - sphere = ConductorModel(is_sphere=True, r0=np.array(r0), - coord_frame=FIFF.FIFFV_COORD_HEAD) - sphere['layers'] = list() + sphere = ConductorModel( + is_sphere=True, r0=np.array(r0), coord_frame=FIFF.FIFFV_COORD_HEAD + ) + sphere["layers"] = list() if head_radius is not None: # Eventually these could be configurable... relative_radii = np.array(relative_radii, float) @@ -846,15 +927,15 @@ def make_sphere_model(r0=(0., 0., 0.04), head_radius=0.09, info=None, for rel_rad, sig in zip(relative_radii, sigmas): # sort layers by (relative) radius, and scale radii layer = dict(rad=rel_rad, sigma=sig) - layer['rel_rad'] = layer['rad'] = rel_rad - sphere['layers'].append(layer) + layer["rel_rad"] = layer["rad"] = rel_rad + sphere["layers"].append(layer) # scale the radii - R = sphere['layers'][-1]['rad'] - rR = sphere['layers'][-1]['rel_rad'] - for layer in sphere['layers']: - layer['rad'] /= R - layer['rel_rad'] /= rR + R = sphere["layers"][-1]["rad"] + rR = sphere["layers"][-1]["rel_rad"] + for layer in sphere["layers"]: + layer["rad"] /= R + layer["rel_rad"] /= rR # # Setup the EEG sphere model calculations @@ -862,25 +943,32 @@ def make_sphere_model(r0=(0., 0., 0.04), head_radius=0.09, info=None, # Scale the relative radii for k in range(len(relative_radii)): - sphere['layers'][k]['rad'] = (head_radius * - sphere['layers'][k]['rel_rad']) + sphere["layers"][k]["rad"] = head_radius * sphere["layers"][k]["rel_rad"] rv = _fwd_eeg_fit_berg_scherg(sphere, 200, 3) - logger.info('\nEquiv. model fitting -> RV = %g %%' % (100 * rv)) + logger.info("\nEquiv. 
model fitting -> RV = %g %%" % (100 * rv)) for k in range(3): - logger.info('mu%d = %g lambda%d = %g' - % (k + 1, sphere['mu'][k], k + 1, - sphere['layers'][-1]['sigma'] * - sphere['lambda'][k])) - logger.info('Set up EEG sphere model with scalp radius %7.1f mm\n' - % (1000 * head_radius,)) + logger.info( + "mu%d = %g lambda%d = %g" + % ( + k + 1, + sphere["mu"][k], + k + 1, + sphere["layers"][-1]["sigma"] * sphere["lambda"][k], + ) + ) + logger.info( + "Set up EEG sphere model with scalp radius %7.1f mm\n" + % (1000 * head_radius,) + ) return sphere # ############################################################################# # Sphere fitting + @verbose -def fit_sphere_to_headshape(info, dig_kinds='auto', units='m', verbose=None): +def fit_sphere_to_headshape(info, dig_kinds="auto", units="m", verbose=None): """Fit a sphere to the headshape points to determine head center. Parameters @@ -907,11 +995,10 @@ def fit_sphere_to_headshape(info, dig_kinds='auto', units='m', verbose=None): This function excludes any points that are low and frontal (``z < 0 and y > 0``) to improve the fit. """ - if not isinstance(units, str) or units not in ('m', 'mm'): + if not isinstance(units, str) or units not in ("m", "mm"): raise ValueError('units must be a "m" or "mm"') - radius, origin_head, origin_device = _fit_sphere_to_headshape( - info, dig_kinds) - if units == 'mm': + radius, origin_head, origin_device = _fit_sphere_to_headshape(info, dig_kinds) + if units == "mm": radius *= 1e3 origin_head *= 1e3 origin_device *= 1e3 @@ -919,8 +1006,7 @@ def fit_sphere_to_headshape(info, dig_kinds='auto', units='m', verbose=None): @verbose -def get_fitting_dig(info, dig_kinds='auto', exclude_frontal=True, - verbose=None): +def get_fitting_dig(info, dig_kinds="auto", exclude_frontal=True, verbose=None): """Get digitization points suitable for sphere fitting. Parameters @@ -946,17 +1032,18 @@ def get_fitting_dig(info, dig_kinds='auto', exclude_frontal=True, .. 
versionadded:: 0.14 """ _validate_type(info, "info") - if info['dig'] is None: - raise RuntimeError('Cannot fit headshape without digitization ' - ', info["dig"] is None') + if info["dig"] is None: + raise RuntimeError( + "Cannot fit headshape without digitization " ', info["dig"] is None' + ) if isinstance(dig_kinds, str): - if dig_kinds == 'auto': + if dig_kinds == "auto": # try "extra" first try: - return get_fitting_dig(info, 'extra') + return get_fitting_dig(info, "extra") except ValueError: pass - return get_fitting_dig(info, ('extra', 'eeg')) + return get_fitting_dig(info, ("extra", "eeg")) else: dig_kinds = (dig_kinds,) # convert string args to ints (first make dig_kinds mutable in case tuple) @@ -964,19 +1051,21 @@ def get_fitting_dig(info, dig_kinds='auto', exclude_frontal=True, for di, d in enumerate(dig_kinds): dig_kinds[di] = _dig_kind_dict.get(d, d) if dig_kinds[di] not in _dig_kind_ints: - raise ValueError('dig_kinds[#%d] (%s) must be one of %s' - % (di, d, sorted(list(_dig_kind_dict.keys())))) + raise ValueError( + "dig_kinds[#%d] (%s) must be one of %s" + % (di, d, sorted(list(_dig_kind_dict.keys()))) + ) # get head digization points of the specified kind(s) - dig = [p for p in info['dig'] if p['kind'] in dig_kinds] + dig = [p for p in info["dig"] if p["kind"] in dig_kinds] if len(dig) == 0: - raise ValueError( - f'No digitization points found for dig_kinds={dig_kinds}') - if any(p['coord_frame'] != FIFF.FIFFV_COORD_HEAD for p in dig): + raise ValueError(f"No digitization points found for dig_kinds={dig_kinds}") + if any(p["coord_frame"] != FIFF.FIFFV_COORD_HEAD for p in dig): raise RuntimeError( - f'Digitization points dig_kinds={dig_kinds} not in head ' - 'coordinates, contact mne-python developers') - hsp = [p['r'] for p in dig] + f"Digitization points dig_kinds={dig_kinds} not in head " + "coordinates, contact mne-python developers" + ) + hsp = [p["r"] for p in dig] del dig # exclude some frontal points (nose etc.) 
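The hunk above ends at the "# exclude some frontal points (nose etc.)" comment, applying the rule stated in the fit_sphere_to_headshape notes: points that are both low (z < 0) and frontal (y > 0) are dropped before fitting so the nose does not bias the sphere. A minimal sketch of that mask, with a hypothetical helper name and made-up points (not from the patch):

    import numpy as np

    def drop_low_frontal(hsp):
        # keep only points that are NOT both low (z < 0) and frontal (y > 0),
        # mirroring the exclusion rule described in the docstrings above
        hsp = np.asarray(hsp, float)
        keep = ~((hsp[:, 2] < 0) & (hsp[:, 1] > 0))
        return hsp[keep]

    pts = np.array([
        [0.00, 0.08, -0.02],  # nose-like: frontal and low -> excluded
        [0.00, -0.05, 0.02],  # occipital: kept
        [0.00, 0.00, 0.09],   # vertex: kept
    ])
    print(drop_low_frontal(pts).shape)  # (2, 3)
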
@@ -985,14 +1074,16 @@ def get_fitting_dig(info, dig_kinds='auto', exclude_frontal=True, hsp = np.array(hsp) if len(hsp) <= 10: - kinds_str = ', '.join(['"%s"' % _dig_kind_rev[d] - for d in sorted(dig_kinds)]) - msg = ('Only %s head digitization points of the specified kind%s (%s,)' - % (len(hsp), _pl(dig_kinds), kinds_str)) + kinds_str = ", ".join(['"%s"' % _dig_kind_rev[d] for d in sorted(dig_kinds)]) + msg = "Only %s head digitization points of the specified kind%s (%s,)" % ( + len(hsp), + _pl(dig_kinds), + kinds_str, + ) if len(hsp) < 4: - raise ValueError(msg + ', at least 4 required') + raise ValueError(msg + ", at least 4 required") else: - warn(msg + ', fitting may be inaccurate') + warn(msg + ", fitting may be inaccurate") return hsp @@ -1002,33 +1093,39 @@ def _fit_sphere_to_headshape(info, dig_kinds, verbose=None): hsp = get_fitting_dig(info, dig_kinds) radius, origin_head = _fit_sphere(np.array(hsp), disp=False) # compute origin in device coordinates - dev_head_t = info['dev_head_t'] + dev_head_t = info["dev_head_t"] if dev_head_t is None: - dev_head_t = Transform('meg', 'head') - head_to_dev = _ensure_trans(dev_head_t, 'head', 'meg') + dev_head_t = Transform("meg", "head") + head_to_dev = _ensure_trans(dev_head_t, "head", "meg") origin_device = apply_trans(head_to_dev, origin_head) - logger.info('Fitted sphere radius:'.ljust(30) + '%0.1f mm' - % (radius * 1e3,)) + logger.info("Fitted sphere radius:".ljust(30) + "%0.1f mm" % (radius * 1e3,)) _check_head_radius(radius) # > 2 cm away from head center in X or Y is strange if np.linalg.norm(origin_head[:2]) > 0.02: - warn('(X, Y) fit (%0.1f, %0.1f) more than 20 mm from ' - 'head frame origin' % tuple(1e3 * origin_head[:2])) - logger.info('Origin head coordinates:'.ljust(30) + - '%0.1f %0.1f %0.1f mm' % tuple(1e3 * origin_head)) - logger.info('Origin device coordinates:'.ljust(30) + - '%0.1f %0.1f %0.1f mm' % tuple(1e3 * origin_device)) + warn( + "(X, Y) fit (%0.1f, %0.1f) more than 20 mm from " + "head frame origin" % tuple(1e3 * origin_head[:2]) + ) + logger.info( + "Origin head coordinates:".ljust(30) + + "%0.1f %0.1f %0.1f mm" % tuple(1e3 * origin_head) + ) + logger.info( + "Origin device coordinates:".ljust(30) + + "%0.1f %0.1f %0.1f mm" % tuple(1e3 * origin_device) + ) return radius, origin_head, origin_device -def _fit_sphere(points, disp='auto'): +def _fit_sphere(points, disp="auto"): """Fit a sphere to an arbitrary set of points.""" from scipy.optimize import fmin_cobyla - if isinstance(disp, str) and disp == 'auto': + + if isinstance(disp, str) and disp == "auto": disp = True if logger.level <= 20 else False # initial guess for center and radius - radii = (np.max(points, axis=1) - np.min(points, axis=1)) / 2. 
+ radii = (np.max(points, axis=1) - np.min(points, axis=1)) / 2.0 radius_init = radii.mean() center_init = np.median(points, axis=0) @@ -1043,38 +1140,46 @@ def cost_fun(center_rad): def constraint(center_rad): return center_rad[3] # radius must be >= 0 - x_opt = fmin_cobyla(cost_fun, x0, constraint, rhobeg=radius_init, - rhoend=radius_init * 1e-6, disp=disp) + x_opt = fmin_cobyla( + cost_fun, + x0, + constraint, + rhobeg=radius_init, + rhoend=radius_init * 1e-6, + disp=disp, + ) origin, radius = x_opt[:3], x_opt[3] return radius, origin -def _check_origin(origin, info, coord_frame='head', disp=False): +def _check_origin(origin, info, coord_frame="head", disp=False): """Check or auto-determine the origin.""" if isinstance(origin, str): - if origin != 'auto': - raise ValueError('origin must be a numerical array, or "auto", ' - 'not %s' % (origin,)) - if coord_frame == 'head': + if origin != "auto": + raise ValueError( + 'origin must be a numerical array, or "auto", ' "not %s" % (origin,) + ) + if coord_frame == "head": R, origin = fit_sphere_to_headshape( - info, verbose=_verbose_safe_false(), units='m')[:2] - logger.info(' Automatic origin fit: head of radius %0.1f mm' - % (R * 1000.,)) + info, verbose=_verbose_safe_false(), units="m" + )[:2] + logger.info( + " Automatic origin fit: head of radius %0.1f mm" % (R * 1000.0,) + ) del R else: - origin = (0., 0., 0.) + origin = (0.0, 0.0, 0.0) origin = np.array(origin, float) if origin.shape != (3,): - raise ValueError('origin must be a 3-element array') + raise ValueError("origin must be a 3-element array") if disp: - origin_str = ', '.join(['%0.1f' % (o * 1000) for o in origin]) - msg = (' Using origin %s mm in the %s frame' - % (origin_str, coord_frame)) - if coord_frame == 'meg' and info['dev_head_t'] is not None: - o_dev = apply_trans(info['dev_head_t'], origin) - origin_str = ', '.join('%0.1f' % (o * 1000,) for o in o_dev) - msg += ' (%s mm in the head frame)' % (origin_str,) + origin_str = ", ".join(["%0.1f" % (o * 1000) for o in origin]) + msg = " Using origin %s mm in the %s frame" % (origin_str, coord_frame) + if coord_frame == "meg" and info["dev_head_t"] is not None: + o_dev = apply_trans(info["dev_head_t"], origin) + origin_str = ", ".join("%0.1f" % (o * 1000,) for o in o_dev) + msg += " (%s mm in the head frame)" % (origin_str,) logger.info(msg) return origin @@ -1082,11 +1187,22 @@ def _check_origin(origin, info, coord_frame='head', disp=False): # ############################################################################ # Create BEM surfaces + @verbose -def make_watershed_bem(subject, subjects_dir=None, overwrite=False, - volume='T1', atlas=False, gcaatlas=False, preflood=None, - show=False, copy=True, T1=None, brainmask='ws.mgz', - verbose=None): +def make_watershed_bem( + subject, + subjects_dir=None, + overwrite=False, + volume="T1", + atlas=False, + gcaatlas=False, + preflood=None, + show=False, + copy=True, + T1=None, + brainmask="ws.mgz", + verbose=None, +): """Create BEM surfaces using the FreeSurfer watershed algorithm. Parameters @@ -1141,78 +1257,97 @@ def make_watershed_bem(subject, subjects_dir=None, overwrite=False, .. versionadded:: 0.10 """ from .viz.misc import plot_bem + env, mri_dir, bem_dir = _prepare_env(subject, subjects_dir) tempdir = _TempDir() # fsl and Freesurfer create some random junk in CWD - run_subprocess_env = partial(run_subprocess, env=env, - cwd=tempdir) + run_subprocess_env = partial(run_subprocess, env=env, cwd=tempdir) - subjects_dir = env['SUBJECTS_DIR'] # Set by _prepare_env() above. 
+ subjects_dir = env["SUBJECTS_DIR"] # Set by _prepare_env() above. subject_dir = op.join(subjects_dir, subject) - ws_dir = op.join(bem_dir, 'watershed') + ws_dir = op.join(bem_dir, "watershed") T1_dir = op.join(mri_dir, volume) T1_mgz = T1_dir - if not T1_dir.endswith('.mgz'): - T1_mgz += '.mgz' + if not T1_dir.endswith(".mgz"): + T1_mgz += ".mgz" if not op.isdir(bem_dir): os.makedirs(bem_dir) - _check_fname(T1_mgz, overwrite='read', must_exist=True, name='MRI data') + _check_fname(T1_mgz, overwrite="read", must_exist=True, name="MRI data") if op.isdir(ws_dir): if not overwrite: - raise RuntimeError('%s already exists. Use the --overwrite option' - ' to recreate it.' % ws_dir) + raise RuntimeError( + "%s already exists. Use the --overwrite option" + " to recreate it." % ws_dir + ) else: shutil.rmtree(ws_dir) # put together the command - cmd = ['mri_watershed'] + cmd = ["mri_watershed"] if preflood: cmd += ["-h", "%s" % int(preflood)] if T1 is None: T1 = gcaatlas if T1: - cmd += ['-T1'] + cmd += ["-T1"] if gcaatlas: - fname = op.join(env['FREESURFER_HOME'], 'average', - 'RB_all_withskull_*.gca') + fname = op.join(env["FREESURFER_HOME"], "average", "RB_all_withskull_*.gca") fname = sorted(glob.glob(fname))[::-1][0] - logger.info('Using GCA atlas: %s' % (fname,)) - cmd += ['-atlas', '-brain_atlas', fname, - subject_dir + '/mri/transforms/talairach_with_skull.lta'] + logger.info("Using GCA atlas: %s" % (fname,)) + cmd += [ + "-atlas", + "-brain_atlas", + fname, + subject_dir + "/mri/transforms/talairach_with_skull.lta", + ] elif atlas: - cmd += ['-atlas'] + cmd += ["-atlas"] if op.exists(T1_mgz): - cmd += ['-useSRAS', '-surf', op.join(ws_dir, subject), T1_mgz, - op.join(ws_dir, brainmask)] + cmd += [ + "-useSRAS", + "-surf", + op.join(ws_dir, subject), + T1_mgz, + op.join(ws_dir, brainmask), + ] else: - cmd += ['-useSRAS', '-surf', op.join(ws_dir, subject), T1_dir, - op.join(ws_dir, brainmask)] + cmd += [ + "-useSRAS", + "-surf", + op.join(ws_dir, subject), + T1_dir, + op.join(ws_dir, brainmask), + ] # report and run - logger.info('\nRunning mri_watershed for BEM segmentation with the ' - 'following parameters:\n\nResults dir = %s\nCommand = %s\n' - % (ws_dir, ' '.join(cmd))) + logger.info( + "\nRunning mri_watershed for BEM segmentation with the " + "following parameters:\n\nResults dir = %s\nCommand = %s\n" + % (ws_dir, " ".join(cmd)) + ) os.makedirs(op.join(ws_dir)) run_subprocess_env(cmd) del tempdir # clean up directory if op.isfile(T1_mgz): new_info = _extract_volume_info(T1_mgz) if not new_info: - warn('nibabel is not available or the volume info is invalid.' - 'Volume info not updated in the written surface.') - surfs = ['brain', 'inner_skull', 'outer_skull', 'outer_skin'] + warn( + "nibabel is not available or the volume info is invalid." + "Volume info not updated in the written surface." 
+ ) + surfs = ["brain", "inner_skull", "outer_skull", "outer_skin"] for s in surfs: - surf_ws_out = op.join(ws_dir, '%s_%s_surface' % (subject, s)) + surf_ws_out = op.join(ws_dir, "%s_%s_surface" % (subject, s)) - rr, tris, volume_info = read_surface(surf_ws_out, - read_metadata=True) + rr, tris, volume_info = read_surface(surf_ws_out, read_metadata=True) # replace volume info, 'head' stays volume_info.update(new_info) - write_surface(surf_ws_out, rr, tris, volume_info=volume_info, - overwrite=True) + write_surface( + surf_ws_out, rr, tris, volume_info=volume_info, overwrite=True + ) # Create symbolic links - surf_out = op.join(bem_dir, '%s.surf' % s) + surf_out = op.join(bem_dir, "%s.surf" % s) if not overwrite and op.exists(surf_out): skip_symlink = True else: @@ -1222,48 +1357,60 @@ def make_watershed_bem(subject, subjects_dir=None, overwrite=False, skip_symlink = False if skip_symlink: - logger.info("Unable to create all symbolic links to .surf files " - "in bem folder. Use --overwrite option to recreate " - "them.") - dest = op.join(bem_dir, 'watershed') + logger.info( + "Unable to create all symbolic links to .surf files " + "in bem folder. Use --overwrite option to recreate " + "them." + ) + dest = op.join(bem_dir, "watershed") else: logger.info("Symbolic links to .surf files created in bem folder") dest = bem_dir - logger.info("\nThank you for waiting.\nThe BEM triangulations for this " - "subject are now available at:\n%s." % dest) + logger.info( + "\nThank you for waiting.\nThe BEM triangulations for this " + "subject are now available at:\n%s." % dest + ) # Write a head file for coregistration - fname_head = op.join(bem_dir, subject + '-head.fif') + fname_head = op.join(bem_dir, subject + "-head.fif") if op.isfile(fname_head): os.remove(fname_head) - surf = _surfaces_to_bem([op.join(ws_dir, subject + '_outer_skin_surface')], - [FIFF.FIFFV_BEM_SURF_ID_HEAD], sigmas=[1]) + surf = _surfaces_to_bem( + [op.join(ws_dir, subject + "_outer_skin_surface")], + [FIFF.FIFFV_BEM_SURF_ID_HEAD], + sigmas=[1], + ) write_bem_surfaces(fname_head, surf) # Show computed BEM surfaces if show: - plot_bem(subject=subject, subjects_dir=subjects_dir, - orientation='coronal', slices=None, show=True) + plot_bem( + subject=subject, + subjects_dir=subjects_dir, + orientation="coronal", + slices=None, + show=True, + ) - logger.info('Created %s\n\nComplete.' % (fname_head,)) + logger.info("Created %s\n\nComplete." 
% (fname_head,)) def _extract_volume_info(mgz): """Extract volume info from a mgz file.""" nib = _import_nibabel() header = nib.load(mgz).header - version = header['version'] + version = header["version"] vol_info = dict() if version == 1: - version = '%s # volume info valid' % version - vol_info['valid'] = version - vol_info['filename'] = mgz - vol_info['volume'] = header['dims'][:3] - vol_info['voxelsize'] = header['delta'] - vol_info['xras'], vol_info['yras'], vol_info['zras'] = header['Mdc'] - vol_info['cras'] = header['Pxyz_c'] + version = "%s # volume info valid" % version + vol_info["valid"] = version + vol_info["filename"] = mgz + vol_info["volume"] = header["dims"][:3] + vol_info["voxelsize"] = header["delta"] + vol_info["xras"], vol_info["yras"], vol_info["zras"] = header["Mdc"] + vol_info["cras"] = header["Pxyz_c"] return vol_info @@ -1271,9 +1418,11 @@ def _extract_volume_info(mgz): # ############################################################################ # Read + @verbose -def read_bem_surfaces(fname, patch_stats=False, s_id=None, on_defects='raise', - verbose=None): +def read_bem_surfaces( + fname, patch_stats=False, s_id=None, on_defects="raise", verbose=None +): """Read the BEM surfaces from a FIF file. Parameters @@ -1302,16 +1451,16 @@ def read_bem_surfaces(fname, patch_stats=False, s_id=None, on_defects='raise', write_bem_surfaces, write_bem_solution, make_bem_model """ # Open the file, create directory - _validate_type(s_id, ('int-like', None), 's_id') - fname = _check_fname(fname, 'read', True, 'fname') + _validate_type(s_id, ("int-like", None), "s_id") + fname = _check_fname(fname, "read", True, "fname") if fname.suffix == ".h5": surf = _read_bem_surfaces_h5(fname, s_id) else: surf = _read_bem_surfaces_fif(fname, s_id) if s_id is not None and len(surf) != 1: - raise ValueError('surface with id %d not found' % s_id) + raise ValueError("surface with id %d not found" % s_id) for this in surf: - if patch_stats or this['nn'] is None: + if patch_stats or this["nn"] is None: _check_complete_surface(this, incomplete=on_defects) return surf[0] if s_id is not None else surf @@ -1320,12 +1469,12 @@ def _read_bem_surfaces_h5(fname, s_id): read_hdf5, _ = _import_h5io_funcs() bem = read_hdf5(fname) try: - [s['id'] for s in bem['surfs']] + [s["id"] for s in bem["surfs"]] except Exception: # not our format - raise ValueError('BEM data not found') - surf = bem['surfs'] + raise ValueError("BEM data not found") + surf = bem["surfs"] if s_id is not None: - surf = [s for s in surf if s['id'] == s_id] + surf = [s for s in surf if s["id"] == s_id] return surf @@ -1337,32 +1486,33 @@ def _read_bem_surfaces_fif(fname, s_id): # Find BEM bem = dir_tree_find(tree, FIFF.FIFFB_BEM) if bem is None or len(bem) == 0: - raise ValueError('BEM data not found') + raise ValueError("BEM data not found") bem = bem[0] # Locate all surfaces bemsurf = dir_tree_find(bem, FIFF.FIFFB_BEM_SURF) if bemsurf is None: - raise ValueError('BEM surface data not found') + raise ValueError("BEM surface data not found") - logger.info(' %d BEM surfaces found' % len(bemsurf)) + logger.info(" %d BEM surfaces found" % len(bemsurf)) # Coordinate frame possibly at the top level tag = find_tag(fid, bem, FIFF.FIFF_BEM_COORD_FRAME) if tag is not None: coord_frame = tag.data # Read all surfaces if s_id is not None: - surf = [_read_bem_surface(fid, bsurf, coord_frame, s_id) - for bsurf in bemsurf] + surf = [ + _read_bem_surface(fid, bsurf, coord_frame, s_id) for bsurf in bemsurf + ] surf = [s for s in surf if s is not None] else: 
surf = list() for bsurf in bemsurf: - logger.info(' Reading a surface...') + logger.info(" Reading a surface...") this = _read_bem_surface(fid, bsurf, coord_frame) surf.append(this) - logger.info('[done]') - logger.info(' %d BEM surfaces read' % len(surf)) + logger.info("[done]") + logger.info(" %d BEM surfaces read" % len(surf)) return surf @@ -1374,63 +1524,63 @@ def _read_bem_surface(fid, this, def_coord_frame, s_id=None): tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_ID) if tag is None: - res['id'] = FIFF.FIFFV_BEM_SURF_ID_UNKNOWN + res["id"] = FIFF.FIFFV_BEM_SURF_ID_UNKNOWN else: - res['id'] = int(tag.data.item()) + res["id"] = int(tag.data.item()) - if s_id is not None and res['id'] != s_id: + if s_id is not None and res["id"] != s_id: return None tag = find_tag(fid, this, FIFF.FIFF_BEM_SIGMA) - res['sigma'] = 1.0 if tag is None else float(tag.data.item()) + res["sigma"] = 1.0 if tag is None else float(tag.data.item()) tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NNODE) if tag is None: - raise ValueError('Number of vertices not found') + raise ValueError("Number of vertices not found") - res['np'] = int(tag.data.item()) + res["np"] = int(tag.data.item()) tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NTRI) if tag is None: - raise ValueError('Number of triangles not found') - res['ntri'] = int(tag.data.item()) + raise ValueError("Number of triangles not found") + res["ntri"] = int(tag.data.item()) tag = find_tag(fid, this, FIFF.FIFF_MNE_COORD_FRAME) if tag is None: tag = find_tag(fid, this, FIFF.FIFF_BEM_COORD_FRAME) if tag is None: - res['coord_frame'] = def_coord_frame + res["coord_frame"] = def_coord_frame else: - res['coord_frame'] = int(tag.data.item()) + res["coord_frame"] = int(tag.data.item()) else: - res['coord_frame'] = int(tag.data.item()) + res["coord_frame"] = int(tag.data.item()) # Vertices, normals, and triangles tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NODES) if tag is None: - raise ValueError('Vertex data not found') + raise ValueError("Vertex data not found") - res['rr'] = tag.data.astype(np.float64) - if res['rr'].shape[0] != res['np']: - raise ValueError('Vertex information is incorrect') + res["rr"] = tag.data.astype(np.float64) + if res["rr"].shape[0] != res["np"]: + raise ValueError("Vertex information is incorrect") tag = find_tag(fid, this, FIFF.FIFF_MNE_SOURCE_SPACE_NORMALS) if tag is None: tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NORMALS) if tag is None: - res['nn'] = None + res["nn"] = None else: - res['nn'] = tag.data.astype(np.float64) - if res['nn'].shape[0] != res['np']: - raise ValueError('Vertex normal information is incorrect') + res["nn"] = tag.data.astype(np.float64) + if res["nn"].shape[0] != res["np"]: + raise ValueError("Vertex normal information is incorrect") tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_TRIANGLES) if tag is None: - raise ValueError('Triangulation not found') + raise ValueError("Triangulation not found") - res['tris'] = tag.data - 1 # index start at 0 in Python - if res['tris'].shape[0] != res['ntri']: - raise ValueError('Triangulation information is incorrect') + res["tris"] = tag.data - 1 # index start at 0 in Python + if res["tris"].shape[0] != res["ntri"]: + raise ValueError("Triangulation information is incorrect") return res @@ -1457,97 +1607,104 @@ def read_bem_solution(fname, *, verbose=None): make_bem_solution write_bem_solution """ - fname = _check_fname(fname, 'read', True, 'fname') + fname = _check_fname(fname, "read", True, "fname") # mirrors fwd_bem_load_surfaces from fwd_bem_model.c if fname.suffix == ".h5": 
read_hdf5, _ = _import_h5io_funcs() - logger.info('Loading surfaces and solution...') + logger.info("Loading surfaces and solution...") bem = read_hdf5(fname) - if 'solver' not in bem: - bem['solver'] = 'mne' + if "solver" not in bem: + bem["solver"] = "mne" else: bem = _read_bem_solution_fif(fname) - if len(bem['surfs']) == 3: - logger.info('Three-layer model surfaces loaded.') - needed = np.array([FIFF.FIFFV_BEM_SURF_ID_HEAD, - FIFF.FIFFV_BEM_SURF_ID_SKULL, - FIFF.FIFFV_BEM_SURF_ID_BRAIN]) - if not all(x['id'] in needed for x in bem['surfs']): - raise RuntimeError('Could not find necessary BEM surfaces') + if len(bem["surfs"]) == 3: + logger.info("Three-layer model surfaces loaded.") + needed = np.array( + [ + FIFF.FIFFV_BEM_SURF_ID_HEAD, + FIFF.FIFFV_BEM_SURF_ID_SKULL, + FIFF.FIFFV_BEM_SURF_ID_BRAIN, + ] + ) + if not all(x["id"] in needed for x in bem["surfs"]): + raise RuntimeError("Could not find necessary BEM surfaces") # reorder surfaces as necessary (shouldn't need to?) reorder = [None] * 3 - for x in bem['surfs']: - reorder[np.where(x['id'] == needed)[0][0]] = x - bem['surfs'] = reorder - elif len(bem['surfs']) == 1: - if not bem['surfs'][0]['id'] == FIFF.FIFFV_BEM_SURF_ID_BRAIN: - raise RuntimeError('BEM Surfaces not found') - logger.info('Homogeneous model surface loaded.') - - assert set(bem.keys()) == set( - ('surfs', 'solution', 'bem_method', 'solver')) + for x in bem["surfs"]: + reorder[np.where(x["id"] == needed)[0][0]] = x + bem["surfs"] = reorder + elif len(bem["surfs"]) == 1: + if not bem["surfs"][0]["id"] == FIFF.FIFFV_BEM_SURF_ID_BRAIN: + raise RuntimeError("BEM Surfaces not found") + logger.info("Homogeneous model surface loaded.") + + assert set(bem.keys()) == set(("surfs", "solution", "bem_method", "solver")) bem = ConductorModel(bem) - bem['is_sphere'] = False + bem["is_sphere"] = False # sanity checks and conversions _check_option( - 'BEM approximation method', bem['bem_method'], - (FIFF.FIFFV_BEM_APPROX_LINEAR,)) # CONSTANT not supported + "BEM approximation method", bem["bem_method"], (FIFF.FIFFV_BEM_APPROX_LINEAR,) + ) # CONSTANT not supported dim = 0 - solver = bem.get('solver', 'mne') - _check_option('BEM solver', solver, ('mne', 'openmeeg')) - for si, surf in enumerate(bem['surfs']): - assert bem['bem_method'] == FIFF.FIFFV_BEM_APPROX_LINEAR - dim += surf['np'] - if solver == 'openmeeg' and si != 0: - dim += surf['ntri'] - dims = bem['solution'].shape + solver = bem.get("solver", "mne") + _check_option("BEM solver", solver, ("mne", "openmeeg")) + for si, surf in enumerate(bem["surfs"]): + assert bem["bem_method"] == FIFF.FIFFV_BEM_APPROX_LINEAR + dim += surf["np"] + if solver == "openmeeg" and si != 0: + dim += surf["ntri"] + dims = bem["solution"].shape if solver == "openmeeg": sz = (dim * (dim + 1)) // 2 if len(dims) != 1 or dims[0] != sz: raise RuntimeError( - 'For the given BEM surfaces, OpenMEEG should produce a ' - f'solution matrix of shape ({sz},) but got {dims}') - bem['nsol'] = dim + "For the given BEM surfaces, OpenMEEG should produce a " + f"solution matrix of shape ({sz},) but got {dims}" + ) + bem["nsol"] = dim else: if len(dims) != 2 and solver != "openmeeg": - raise RuntimeError('Expected a two-dimensional solution matrix ' - 'instead of a %d dimensional one' % dims[0]) + raise RuntimeError( + "Expected a two-dimensional solution matrix " + "instead of a %d dimensional one" % dims[0] + ) if dims[0] != dim or dims[1] != dim: - raise RuntimeError('Expected a %d x %d solution matrix instead of ' - 'a %d x %d one' % (dim, dim, dims[1], dims[0])) 
-        bem['nsol'] = bem['solution'].shape[0]
+            raise RuntimeError(
+                "Expected a %d x %d solution matrix instead of "
+                "a %d x %d one" % (dim, dim, dims[1], dims[0])
+            )
+        bem["nsol"] = bem["solution"].shape[0]

     # Gamma factors and multipliers
     _add_gamma_multipliers(bem)
-    extra = f'made by {solver}' if solver != 'mne' else ''
-    logger.info(f'Loaded linear collocation BEM solution{extra} from {fname}')
+    extra = f"made by {solver}" if solver != "mne" else ""
+    logger.info(f"Loaded linear collocation BEM solution{extra} from {fname}")
     return bem


 def _read_bem_solution_fif(fname):
-    logger.info('Loading surfaces...')
-    surfs = read_bem_surfaces(
-        fname, patch_stats=True, verbose=_verbose_safe_false())
+    logger.info("Loading surfaces...")
+    surfs = read_bem_surfaces(fname, patch_stats=True, verbose=_verbose_safe_false())
     # convert from surfaces to solution
-    logger.info('\nLoading the solution matrix...\n')
-    solver = 'mne'
+    logger.info("\nLoading the solution matrix...\n")
+    solver = "mne"
     f, tree, _ = fiff_open(fname)
     with f as fid:
         # Find the BEM data
         nodes = dir_tree_find(tree, FIFF.FIFFB_BEM)
         if len(nodes) == 0:
-            raise RuntimeError('No BEM data in %s' % fname)
+            raise RuntimeError("No BEM data in %s" % fname)
         bem_node = nodes[0]
         # Approximation method
         tag = find_tag(f, bem_node, FIFF.FIFF_DESCRIPTION)
         if tag is not None:
             tag = json.loads(tag.data)
-            solver = tag['solver']
+            solver = tag["solver"]
         tag = find_tag(f, bem_node, FIFF.FIFF_BEM_APPROX)
         if tag is None:
-            raise RuntimeError('No BEM solution found in %s' % fname)
+            raise RuntimeError("No BEM solution found in %s" % fname)
         method = tag.data[0]
         tag = find_tag(fid, bem_node, FIFF.FIFF_BEM_POT_SOLUTION)
         sol = tag.data
@@ -1557,73 +1714,77 @@ def _read_bem_solution_fif(fname):

 def _add_gamma_multipliers(bem):
     """Add gamma and multipliers in-place."""
-    bem['sigma'] = np.array([surf['sigma'] for surf in bem['surfs']])
+    bem["sigma"] = np.array([surf["sigma"] for surf in bem["surfs"]])
     # Dirty trick for the zero conductivity outside
-    sigma = np.r_[0.0, bem['sigma']]
-    bem['source_mult'] = 2.0 / (sigma[1:] + sigma[:-1])
-    bem['field_mult'] = sigma[1:] - sigma[:-1]
+    sigma = np.r_[0.0, bem["sigma"]]
+    bem["source_mult"] = 2.0 / (sigma[1:] + sigma[:-1])
+    bem["field_mult"] = sigma[1:] - sigma[:-1]
     # make sure subsequent "zip"s work correctly
-    assert len(bem['surfs']) == len(bem['field_mult'])
-    bem['gamma'] = ((sigma[1:] - sigma[:-1])[np.newaxis, :] /
-                    (sigma[1:] + sigma[:-1])[:, np.newaxis])
+    assert len(bem["surfs"]) == len(bem["field_mult"])
+    bem["gamma"] = (sigma[1:] - sigma[:-1])[np.newaxis, :] / (sigma[1:] + sigma[:-1])[
+        :, np.newaxis
+    ]


 # In our BEM code we do not model the CSF, so we assign the innermost surface
 # the id BRAIN. Our 4-layer sphere model includes CSF (at least by default),
 # so when searching for and referring to surfaces we need to keep track of this.
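To make the arithmetic above concrete, here is a minimal standalone sketch of the multipliers that `_add_gamma_multipliers` attaches to the conductor model. The three conductivity values are illustrative defaults in S/m, not values taken from this patch, and the surface order (outermost first: head, skull, brain) follows the reordering performed by the reader above:

    import numpy as np

    # Illustrative three-layer conductivities (S/m), outermost surface first,
    # mirroring bem["sigma"] after the head/skull/brain reordering above.
    sigma_surfs = np.array([0.3, 0.006, 0.3])  # scalp, skull, brain

    # The "dirty trick": prepend the zero conductivity of the air outside.
    sigma = np.r_[0.0, sigma_surfs]

    # One multiplier per surface, exactly as in _add_gamma_multipliers().
    source_mult = 2.0 / (sigma[1:] + sigma[:-1])  # [6.667, 6.536, 6.536]
    field_mult = sigma[1:] - sigma[:-1]           # [0.3, -0.294, 0.294]

    # Gamma matrix of conductivity jumps normalized by pairwise sums.
    gamma = (sigma[1:] - sigma[:-1])[np.newaxis, :] / (
        sigma[1:] + sigma[:-1]
    )[:, np.newaxis]
    assert gamma.shape == (3, 3)

As in the function itself, each surface gets exactly one source and one field multiplier, which the later assembly code zips against bem["surfs"].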
-_sm_surf_dict = OrderedDict([ - ('brain', FIFF.FIFFV_BEM_SURF_ID_BRAIN), - ('inner_skull', FIFF.FIFFV_BEM_SURF_ID_CSF), - ('outer_skull', FIFF.FIFFV_BEM_SURF_ID_SKULL), - ('head', FIFF.FIFFV_BEM_SURF_ID_HEAD), -]) +_sm_surf_dict = OrderedDict( + [ + ("brain", FIFF.FIFFV_BEM_SURF_ID_BRAIN), + ("inner_skull", FIFF.FIFFV_BEM_SURF_ID_CSF), + ("outer_skull", FIFF.FIFFV_BEM_SURF_ID_SKULL), + ("head", FIFF.FIFFV_BEM_SURF_ID_HEAD), + ] +) _bem_surf_dict = { - 'inner_skull': FIFF.FIFFV_BEM_SURF_ID_BRAIN, - 'outer_skull': FIFF.FIFFV_BEM_SURF_ID_SKULL, - 'head': FIFF.FIFFV_BEM_SURF_ID_HEAD, + "inner_skull": FIFF.FIFFV_BEM_SURF_ID_BRAIN, + "outer_skull": FIFF.FIFFV_BEM_SURF_ID_SKULL, + "head": FIFF.FIFFV_BEM_SURF_ID_HEAD, } _bem_surf_name = { - FIFF.FIFFV_BEM_SURF_ID_BRAIN: 'inner skull', - FIFF.FIFFV_BEM_SURF_ID_SKULL: 'outer skull', - FIFF.FIFFV_BEM_SURF_ID_HEAD: 'outer skin ', - FIFF.FIFFV_BEM_SURF_ID_UNKNOWN: 'unknown ', + FIFF.FIFFV_BEM_SURF_ID_BRAIN: "inner skull", + FIFF.FIFFV_BEM_SURF_ID_SKULL: "outer skull", + FIFF.FIFFV_BEM_SURF_ID_HEAD: "outer skin ", + FIFF.FIFFV_BEM_SURF_ID_UNKNOWN: "unknown ", } _sm_surf_name = { - FIFF.FIFFV_BEM_SURF_ID_BRAIN: 'brain', - FIFF.FIFFV_BEM_SURF_ID_CSF: 'csf', - FIFF.FIFFV_BEM_SURF_ID_SKULL: 'outer skull', - FIFF.FIFFV_BEM_SURF_ID_HEAD: 'outer skin ', - FIFF.FIFFV_BEM_SURF_ID_UNKNOWN: 'unknown ', + FIFF.FIFFV_BEM_SURF_ID_BRAIN: "brain", + FIFF.FIFFV_BEM_SURF_ID_CSF: "csf", + FIFF.FIFFV_BEM_SURF_ID_SKULL: "outer skull", + FIFF.FIFFV_BEM_SURF_ID_HEAD: "outer skin ", + FIFF.FIFFV_BEM_SURF_ID_UNKNOWN: "unknown ", } def _bem_find_surface(bem, id_): """Find surface from already-loaded conductor model.""" - if bem['is_sphere']: + if bem["is_sphere"]: _surf_dict = _sm_surf_dict _name_dict = _sm_surf_name - kind = 'Sphere model' - tri = 'boundary' + kind = "Sphere model" + tri = "boundary" else: _surf_dict = _bem_surf_dict _name_dict = _bem_surf_name - kind = 'BEM' - tri = 'triangulation' + kind = "BEM" + tri = "triangulation" if isinstance(id_, str): name = id_ id_ = _surf_dict[id_] else: name = _name_dict[id_] - kind = 'Sphere model' if bem['is_sphere'] else 'BEM' - idx = np.where(np.array([s['id'] for s in bem['surfs']]) == id_)[0] + kind = "Sphere model" if bem["is_sphere"] else "BEM" + idx = np.where(np.array([s["id"] for s in bem["surfs"]]) == id_)[0] if len(idx) != 1: - raise RuntimeError(f'{kind} does not have the {name} {tri}') - return bem['surfs'][idx[0]] + raise RuntimeError(f"{kind} does not have the {name} {tri}") + return bem["surfs"][idx[0]] # ############################################################################ # Write + @verbose def write_bem_surfaces(fname, surfs, overwrite=False, *, verbose=None): """Write BEM surfaces to a FIF file. 
@@ -1639,7 +1800,7 @@ def write_bem_surfaces(fname, surfs, overwrite=False, *, verbose=None): """ if isinstance(surfs, dict): surfs = [surfs] - fname = _check_fname(fname, overwrite=overwrite, name='fname') + fname = _check_fname(fname, overwrite=overwrite, name="fname") if fname.suffix == ".h5": _, write_hdf5 = _import_h5io_funcs() @@ -1647,14 +1808,15 @@ def write_bem_surfaces(fname, surfs, overwrite=False, *, verbose=None): else: with start_and_end_file(fname) as fid: start_block(fid, FIFF.FIFFB_BEM) - write_int(fid, FIFF.FIFF_BEM_COORD_FRAME, surfs[0]['coord_frame']) + write_int(fid, FIFF.FIFF_BEM_COORD_FRAME, surfs[0]["coord_frame"]) _write_bem_surfaces_block(fid, surfs) end_block(fid, FIFF.FIFFB_BEM) @verbose -def write_head_bem(fname, rr, tris, on_defects='raise', overwrite=False, - *, verbose=None): +def write_head_bem( + fname, rr, tris, on_defects="raise", overwrite=False, *, verbose=None +): """Write a head surface to a FIF file. Parameters @@ -1670,9 +1832,13 @@ def write_head_bem(fname, rr, tris, on_defects='raise', overwrite=False, %(overwrite)s %(verbose)s """ - surf = _surfaces_to_bem([dict(rr=rr, tris=tris)], - [FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], rescale=False, - incomplete=on_defects) + surf = _surfaces_to_bem( + [dict(rr=rr, tris=tris)], + [FIFF.FIFFV_BEM_SURF_ID_HEAD], + [1], + rescale=False, + incomplete=on_defects, + ) write_bem_surfaces(fname, surf, overwrite=overwrite) @@ -1680,17 +1846,16 @@ def _write_bem_surfaces_block(fid, surfs): """Write bem surfaces to open file handle.""" for surf in surfs: start_block(fid, FIFF.FIFFB_BEM_SURF) - write_float(fid, FIFF.FIFF_BEM_SIGMA, surf['sigma']) - write_int(fid, FIFF.FIFF_BEM_SURF_ID, surf['id']) - write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, surf['coord_frame']) - write_int(fid, FIFF.FIFF_BEM_SURF_NNODE, surf['np']) - write_int(fid, FIFF.FIFF_BEM_SURF_NTRI, surf['ntri']) - write_float_matrix(fid, FIFF.FIFF_BEM_SURF_NODES, surf['rr']) + write_float(fid, FIFF.FIFF_BEM_SIGMA, surf["sigma"]) + write_int(fid, FIFF.FIFF_BEM_SURF_ID, surf["id"]) + write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, surf["coord_frame"]) + write_int(fid, FIFF.FIFF_BEM_SURF_NNODE, surf["np"]) + write_int(fid, FIFF.FIFF_BEM_SURF_NTRI, surf["ntri"]) + write_float_matrix(fid, FIFF.FIFF_BEM_SURF_NODES, surf["rr"]) # index start at 0 in Python - write_int_matrix(fid, FIFF.FIFF_BEM_SURF_TRIANGLES, - surf['tris'] + 1) - if 'nn' in surf and surf['nn'] is not None and len(surf['nn']) > 0: - write_float_matrix(fid, FIFF.FIFF_BEM_SURF_NORMALS, surf['nn']) + write_int_matrix(fid, FIFF.FIFF_BEM_SURF_TRIANGLES, surf["tris"] + 1) + if "nn" in surf and surf["nn"] is not None and len(surf["nn"]) > 0: + write_float_matrix(fid, FIFF.FIFF_BEM_SURF_NORMALS, surf["nn"]) end_block(fid, FIFF.FIFFB_BEM_SURF) @@ -1711,42 +1876,40 @@ def write_bem_solution(fname, bem, overwrite=False, *, verbose=None): -------- read_bem_solution """ - fname = _check_fname(fname, overwrite=overwrite, name='fname') + fname = _check_fname(fname, overwrite=overwrite, name="fname") if fname.suffix == ".h5": _, write_hdf5 = _import_h5io_funcs() - bem = {k: bem[k] for k in ('surfs', 'solution', 'bem_method')} + bem = {k: bem[k] for k in ("surfs", "solution", "bem_method")} write_hdf5(fname, bem, overwrite=True) else: _write_bem_solution_fif(fname, bem) def _write_bem_solution_fif(fname, bem): - _check_bem_size(bem['surfs']) + _check_bem_size(bem["surfs"]) with start_and_end_file(fname) as fid: start_block(fid, FIFF.FIFFB_BEM) # Coordinate frame (mainly for backward compatibility) - write_int(fid, 
FIFF.FIFF_BEM_COORD_FRAME, - bem['surfs'][0]['coord_frame']) - solver = bem.get('solver', 'mne') - if solver != 'mne': - write_string( - fid, FIFF.FIFF_DESCRIPTION, json.dumps(dict(solver=solver))) + write_int(fid, FIFF.FIFF_BEM_COORD_FRAME, bem["surfs"][0]["coord_frame"]) + solver = bem.get("solver", "mne") + if solver != "mne": + write_string(fid, FIFF.FIFF_DESCRIPTION, json.dumps(dict(solver=solver))) # Surfaces - _write_bem_surfaces_block(fid, bem['surfs']) + _write_bem_surfaces_block(fid, bem["surfs"]) # The potential solution - if 'solution' in bem: + if "solution" in bem: _check_option( - 'bem_method', bem['bem_method'], - (FIFF.FIFFV_BEM_APPROX_LINEAR,)) + "bem_method", bem["bem_method"], (FIFF.FIFFV_BEM_APPROX_LINEAR,) + ) write_int(fid, FIFF.FIFF_BEM_APPROX, FIFF.FIFFV_BEM_APPROX_LINEAR) - write_float_matrix(fid, FIFF.FIFF_BEM_POT_SOLUTION, - bem['solution']) + write_float_matrix(fid, FIFF.FIFF_BEM_POT_SOLUTION, bem["solution"]) end_block(fid, FIFF.FIFFB_BEM) # ############################################################################# # Create 3-Layers BEM model from Flash MRI images + def _prepare_env(subject, subjects_dir): """Prepare an env object for subprocess calls.""" env = os.environ.copy() @@ -1758,18 +1921,19 @@ def _prepare_env(subject, subjects_dir): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) subject_dir = subjects_dir / subject if not subject_dir.is_dir(): - raise RuntimeError('Could not find the subject data directory "%s"' - % (subject_dir,)) - env.update(SUBJECT=subject, SUBJECTS_DIR=str(subjects_dir), - FREESURFER_HOME=fs_home) + raise RuntimeError( + 'Could not find the subject data directory "%s"' % (subject_dir,) + ) + env.update(SUBJECT=subject, SUBJECTS_DIR=str(subjects_dir), FREESURFER_HOME=fs_home) mri_dir = subject_dir / "mri" bem_dir = subject_dir / "bem" return env, mri_dir, bem_dir def _write_echos(mri_dir, flash_echos, angle): - nib = _import_nibabel('write echoes') + nib = _import_nibabel("write echoes") from nibabel.spatialimages import SpatialImage + if _path_like(flash_echos): flash_echos = nib.load(flash_echos) if isinstance(flash_echos, SpatialImage): @@ -1780,8 +1944,7 @@ def _write_echos(mri_dir, flash_echos, angle): data = data[..., np.newaxis] for echo_idx in range(data.shape[3]): this_echo_img = flash_echos.__class__( - data[..., echo_idx], affine=affine, - header=deepcopy(flash_echos.header) + data[..., echo_idx], affine=affine, header=deepcopy(flash_echos.header) ) flash_echo_imgs.append(this_echo_img) flash_echos = flash_echo_imgs @@ -1789,13 +1952,13 @@ def _write_echos(mri_dir, flash_echos, angle): for idx, flash_echo in enumerate(flash_echos, 1): if _path_like(flash_echo): flash_echo = nib.load(flash_echo) - nib.save(flash_echo, - op.join(mri_dir, 'flash', f'mef{angle}_{idx:03d}.mgz')) + nib.save(flash_echo, op.join(mri_dir, "flash", f"mef{angle}_{idx:03d}.mgz")) @verbose -def convert_flash_mris(subject, flash30=True, unwarp=False, - subjects_dir=None, flash5=True, verbose=None): +def convert_flash_mris( + subject, flash30=True, unwarp=False, subjects_dir=None, flash5=True, verbose=None +): """Synthesize the flash 5 files for use with make_flash_bem. 
This function aims to produce a synthesized flash 5 MRI from @@ -1843,32 +2006,30 @@ def convert_flash_mris(subject, flash30=True, unwarp=False, """ # noqa: E501 env, mri_dir = _prepare_env(subject, subjects_dir)[:2] tempdir = _TempDir() # fsl and Freesurfer create some random junk in CWD - run_subprocess_env = partial(run_subprocess, env=env, - cwd=tempdir) + run_subprocess_env = partial(run_subprocess, env=env, cwd=tempdir) mri_dir = Path(mri_dir) # Step 1a : Data conversion to mgz format flash_dir = mri_dir / "flash" - pm_dir = flash_dir / 'parameter_maps' + pm_dir = flash_dir / "parameter_maps" pm_dir.mkdir(parents=True, exist_ok=True) echos_done = 0 if not isinstance(flash5, bool): - _write_echos(mri_dir, flash5, angle='05') + _write_echos(mri_dir, flash5, angle="05") if not isinstance(flash30, bool): - _write_echos(mri_dir, flash30, angle='30') + _write_echos(mri_dir, flash30, angle="30") # Step 1b : Run grad_unwarp on converted files template = op.join(flash_dir, "mef*_*.mgz") files = sorted(glob.glob(template)) if len(files) == 0: - raise ValueError('No suitable source files found (%s)' % template) + raise ValueError("No suitable source files found (%s)" % template) if unwarp: logger.info("\n---- Unwarp mgz data sets ----") for infile in files: outfile = infile.replace(".mgz", "u.mgz") - cmd = ['grad_unwarp', '-i', infile, '-o', outfile, '-unwarp', - 'true'] + cmd = ["grad_unwarp", "-i", infile, "-o", outfile, "-unwarp", "true"] run_subprocess_env(cmd) # Clear parameter maps if some of the data were reconverted if echos_done > 0 and pm_dir.exists(): @@ -1882,20 +2043,24 @@ def convert_flash_mris(subject, flash30=True, unwarp=False, if unwarp: files = sorted(glob.glob(op.join(flash_dir, "mef05_*u.mgz"))) if len(os.listdir(pm_dir)) == 0: - cmd = (['mri_ms_fitparms'] + files + [str(pm_dir)]) + cmd = ["mri_ms_fitparms"] + files + [str(pm_dir)] run_subprocess_env(cmd) else: logger.info("Parameter maps were already computed") # Step 3 : Synthesize the flash 5 images logger.info("\n---- Synthesizing flash 5 images ----") - if not (pm_dir / 'flash5.mgz').exists(): - cmd = ['mri_synthesize', '20', '5', '5', - (pm_dir / 'T1.mgz'), - (pm_dir / 'PD.mgz'), - (pm_dir / 'flash5.mgz') - ] + if not (pm_dir / "flash5.mgz").exists(): + cmd = [ + "mri_synthesize", + "20", + "5", + "5", + (pm_dir / "T1.mgz"), + (pm_dir / "PD.mgz"), + (pm_dir / "flash5.mgz"), + ] run_subprocess_env(cmd) - (pm_dir / 'flash5_reg.mgz').unlink() + (pm_dir / "flash5_reg.mgz").unlink() else: logger.info("Synthesized flash 5 volume is already there") else: @@ -1903,18 +2068,27 @@ def convert_flash_mris(subject, flash30=True, unwarp=False, template = "mef05_*u.mgz" if unwarp else "mef05_*.mgz" files = sorted(flash_dir.glob(template)) if len(files) == 0: - raise ValueError('No suitable source files found (%s)' % template) - cmd = (['mri_average', '-noconform'] + files + [pm_dir / 'flash5.mgz']) + raise ValueError("No suitable source files found (%s)" % template) + cmd = ["mri_average", "-noconform"] + files + [pm_dir / "flash5.mgz"] run_subprocess_env(cmd) - (pm_dir / 'flash5_reg.mgz').unlink(missing_ok=True) + (pm_dir / "flash5_reg.mgz").unlink(missing_ok=True) del tempdir # finally done running subprocesses - assert (pm_dir / 'flash5.mgz').exists() - return pm_dir / 'flash5.mgz' + assert (pm_dir / "flash5.mgz").exists() + return pm_dir / "flash5.mgz" @verbose -def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, - copy=True, *, flash5_img=None, register=True, verbose=None): +def make_flash_bem( + subject, 
+ overwrite=False, + show=True, + subjects_dir=None, + copy=True, + *, + flash5_img=None, + register=True, + verbose=None, +): """Create 3-Layer BEM model from prepared flash MRI images. Parameters @@ -1963,46 +2137,53 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, env, mri_dir, bem_dir = _prepare_env(subject, subjects_dir) tempdir = _TempDir() # fsl and Freesurfer create some random junk in CWD - run_subprocess_env = partial(run_subprocess, env=env, - cwd=tempdir) + run_subprocess_env = partial(run_subprocess, env=env, cwd=tempdir) mri_dir = Path(mri_dir) bem_dir = Path(bem_dir) - subjects_dir = env['SUBJECTS_DIR'] - flash_path = (mri_dir / 'flash' / 'parameter_maps').resolve() + subjects_dir = env["SUBJECTS_DIR"] + flash_path = (mri_dir / "flash" / "parameter_maps").resolve() flash_path.mkdir(exist_ok=True, parents=True) - logger.info('\nProcessing the flash MRI data to produce BEM meshes with ' - 'the following parameters:\n' - 'SUBJECTS_DIR = %s\n' - 'SUBJECT = %s\n' - 'Result dir = %s\n' % (subjects_dir, subject, - bem_dir / 'flash')) + logger.info( + "\nProcessing the flash MRI data to produce BEM meshes with " + "the following parameters:\n" + "SUBJECTS_DIR = %s\n" + "SUBJECT = %s\n" + "Result dir = %s\n" % (subjects_dir, subject, bem_dir / "flash") + ) # Step 4 : Register with MPRAGE - flash5 = flash_path / 'flash5.mgz' + flash5 = flash_path / "flash5.mgz" if _path_like(flash5_img): logger.info(f"Copying flash 5 image {flash5_img} to {flash5}") - cmd = ['mri_convert', Path(flash5_img).resolve(), flash5] + cmd = ["mri_convert", Path(flash5_img).resolve(), flash5] run_subprocess_env(cmd) elif flash5_img is None: if not flash5.exists(): - raise ValueError(f'Flash 5 image cannot be found at {flash5}.') + raise ValueError(f"Flash 5 image cannot be found at {flash5}.") else: logger.info(f"Writing flash 5 image at {flash5}") - nib = _import_nibabel('write an MRI image') + nib = _import_nibabel("write an MRI image") nib.save(flash5_img, flash5) if register: logger.info("\n---- Registering flash 5 with T1 MPRAGE ----") - flash5_reg = flash_path / 'flash5_reg.mgz' + flash5_reg = flash_path / "flash5_reg.mgz" if not flash5_reg.exists(): - if (mri_dir / 'T1.mgz').exists(): - ref_volume = mri_dir / 'T1.mgz' + if (mri_dir / "T1.mgz").exists(): + ref_volume = mri_dir / "T1.mgz" else: - ref_volume = mri_dir / 'T1' - cmd = ['fsl_rigid_register', '-r', str(ref_volume), '-i', - str(flash5), '-o', str(flash5_reg)] + ref_volume = mri_dir / "T1" + cmd = [ + "fsl_rigid_register", + "-r", + str(ref_volume), + "-i", + str(flash5), + "-o", + str(flash5_reg), + ] run_subprocess_env(cmd) else: logger.info("Registered flash 5 image is already there") @@ -2011,62 +2192,61 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, # Step 5a : Convert flash5 into COR logger.info("\n---- Converting flash5 volume into COR format ----") - flash5_dir = mri_dir / 'flash5' + flash5_dir = mri_dir / "flash5" shutil.rmtree(flash5_dir, ignore_errors=True) flash5_dir.mkdir(exist_ok=True, parents=True) - cmd = ['mri_convert', flash5_reg, flash5_dir] + cmd = ["mri_convert", flash5_reg, flash5_dir] run_subprocess_env(cmd) # Step 5b and c : Convert the mgz volumes into COR convert_T1 = False - T1_dir = mri_dir / 'T1' - if not T1_dir.is_dir() or next(T1_dir.glob('COR*')) is None: + T1_dir = mri_dir / "T1" + if not T1_dir.is_dir() or next(T1_dir.glob("COR*")) is None: convert_T1 = True convert_brain = False - brain_dir = mri_dir / 'brain' - if not brain_dir.is_dir() or 
next(brain_dir.glob('COR*')) is None: + brain_dir = mri_dir / "brain" + if not brain_dir.is_dir() or next(brain_dir.glob("COR*")) is None: convert_brain = True logger.info("\n---- Converting T1 volume into COR format ----") if convert_T1: - T1_fname = mri_dir / 'T1.mgz' + T1_fname = mri_dir / "T1.mgz" if not T1_fname.is_file(): raise RuntimeError("Both T1 mgz and T1 COR volumes missing.") T1_dir.mkdir(exist_ok=True, parents=True) - cmd = ['mri_convert', T1_fname, T1_dir] + cmd = ["mri_convert", T1_fname, T1_dir] run_subprocess_env(cmd) else: logger.info("T1 volume is already in COR format") logger.info("\n---- Converting brain volume into COR format ----") if convert_brain: - brain_fname = mri_dir / 'brain.mgz' + brain_fname = mri_dir / "brain.mgz" if not brain_fname.is_file(): raise RuntimeError("Both brain mgz and brain COR volumes missing.") brain_dir.mkdir(exist_ok=True, parents=True) - cmd = ['mri_convert', brain_fname, brain_dir] + cmd = ["mri_convert", brain_fname, brain_dir] run_subprocess_env(cmd) else: logger.info("Brain volume is already in COR format") # Finally ready to go logger.info("\n---- Creating the BEM surfaces ----") - cmd = ['mri_make_bem_surfaces', subject] + cmd = ["mri_make_bem_surfaces", subject] run_subprocess_env(cmd) del tempdir # ran our last subprocess; clean up directory logger.info("\n---- Converting the tri files into surf files ----") - flash_bem_dir = bem_dir / 'flash' + flash_bem_dir = bem_dir / "flash" flash_bem_dir.mkdir(exist_ok=True, parents=True) - surfs = ['inner_skull', 'outer_skull', 'outer_skin'] + surfs = ["inner_skull", "outer_skull", "outer_skin"] for surf in surfs: - out_fname = flash_bem_dir / (surf + '.tri') - shutil.move(bem_dir / (surf + '.tri'), out_fname) + out_fname = flash_bem_dir / (surf + ".tri") + shutil.move(bem_dir / (surf + ".tri"), out_fname) nodes, tris = read_tri(out_fname, swap=True) # Do not write volume info here because the tris are already in # standard Freesurfer coords - write_surface(op.splitext(out_fname)[0] + '.surf', nodes, tris, - overwrite=True) + write_surface(op.splitext(out_fname)[0] + ".surf", nodes, tris, overwrite=True) # Cleanup section logger.info("\n---- Cleaning up ----") - (bem_dir / 'inner_skull_tmp.tri').unlink() + (bem_dir / "inner_skull_tmp.tri").unlink() if convert_T1: shutil.rmtree(T1_dir) logger.info("Deleted the T1 COR volume") @@ -2079,7 +2259,7 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, logger.info("\n---- Creating symbolic links ----") # os.chdir(bem_dir) for surf in surfs: - surf = bem_dir / (surf + '.surf') + surf = bem_dir / (surf + ".surf") if not overwrite and surf.exists(): skip_symlink = True else: @@ -2088,28 +2268,38 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, _symlink(flash_bem_dir / surf.name, surf, copy) skip_symlink = False if skip_symlink: - logger.info("Unable to create all symbolic links to .surf files " - "in bem folder. Use --overwrite option to recreate them.") - dest = bem_dir / 'flash' + logger.info( + "Unable to create all symbolic links to .surf files " + "in bem folder. Use --overwrite option to recreate them." + ) + dest = bem_dir / "flash" else: logger.info("Symbolic links to .surf files created in bem folder") dest = bem_dir - logger.info("\nThank you for waiting.\nThe BEM triangulations for this " - "subject are now available at:\n%s.\nWe hope the BEM meshes " - "created will facilitate your MEG and EEG data analyses." 
- % dest) + logger.info( + "\nThank you for waiting.\nThe BEM triangulations for this " + "subject are now available at:\n%s.\nWe hope the BEM meshes " + "created will facilitate your MEG and EEG data analyses." % dest + ) # Show computed BEM surfaces if show: - plot_bem(subject=subject, subjects_dir=subjects_dir, - orientation='coronal', slices=None, show=True) + plot_bem( + subject=subject, + subjects_dir=subjects_dir, + orientation="coronal", + slices=None, + show=True, + ) def _check_bem_size(surfs): """Check bem surface sizes.""" - if len(surfs) > 1 and surfs[0]['np'] > 10000: - warn('The bem surfaces have %s data points. 5120 (ico grade=4) ' - 'should be enough. Dense 3-layer bems may not save properly.' % - surfs[0]['np']) + if len(surfs) > 1 and surfs[0]["np"] > 10000: + warn( + "The bem surfaces have %s data points. 5120 (ico grade=4) " + "should be enough. Dense 3-layer bems may not save properly." + % surfs[0]["np"] + ) def _symlink(src, dest, copy=False): @@ -2119,40 +2309,41 @@ def _symlink(src, dest, copy=False): try: os.symlink(src_link, dest) except OSError: - warn('Could not create symbolic link %s. Check that your ' - 'partition handles symbolic links. The file will be copied ' - 'instead.' % dest) + warn( + "Could not create symbolic link %s. Check that your " + "partition handles symbolic links. The file will be copied " + "instead." % dest + ) copy = True if copy: shutil.copy(src, dest) -def _ensure_bem_surfaces(bem, extra_allow=(), name='bem'): +def _ensure_bem_surfaces(bem, extra_allow=(), name="bem"): # by default only allow path-like and list, but handle None and # ConductorModel properly if need be. Always return a ConductorModel # even though it's incomplete (and might have is_sphere=True). assert all(extra in (None, ConductorModel) for extra in extra_allow) - allowed = ('path-like', list) + extra_allow + allowed = ("path-like", list) + extra_allow _validate_type(bem, allowed, name) if isinstance(bem, path_like): # Load the surfaces - logger.info(f'Loading BEM surfaces from {str(bem)}...') + logger.info(f"Loading BEM surfaces from {str(bem)}...") bem = read_bem_surfaces(bem) bem = ConductorModel(is_sphere=False, surfs=bem) elif isinstance(bem, list): for ii, this_surf in enumerate(bem): - _validate_type(this_surf, dict, f'{name}[{ii}]') + _validate_type(this_surf, dict, f"{name}[{ii}]") if isinstance(bem, list): bem = ConductorModel(is_sphere=False, surfs=bem) # add surfaces in the spherical case - if isinstance(bem, ConductorModel) and bem['is_sphere']: + if isinstance(bem, ConductorModel) and bem["is_sphere"]: bem = bem.copy() - bem['surfs'] = [] - if len(bem['layers']) == 4: + bem["surfs"] = [] + if len(bem["layers"]) == 4: for idx, id_ in enumerate(_sm_surf_dict.values()): - bem['surfs'].append(_complete_sphere_surf( - bem, idx, 4, complete=False)) - bem['surfs'][-1]['id'] = id_ + bem["surfs"].append(_complete_sphere_surf(bem, idx, 4, complete=False)) + bem["surfs"][-1]["id"] = id_ return bem @@ -2160,7 +2351,7 @@ def _ensure_bem_surfaces(bem, extra_allow=(), name='bem'): def _check_file(fname, overwrite): """Prevent overwrites.""" if op.isfile(fname) and not overwrite: - raise OSError(f'File {fname} exists, use --overwrite to overwrite it') + raise OSError(f"File {fname} exists, use --overwrite to overwrite it") _tri_levels = dict( @@ -2170,9 +2361,17 @@ def _check_file(fname, overwrite): @verbose -def make_scalp_surfaces(subject, subjects_dir=None, force=True, - overwrite=False, no_decimate=False, *, - threshold=20, mri='T1.mgz', verbose=None): +def 
make_scalp_surfaces( + subject, + subjects_dir=None, + force=True, + overwrite=False, + no_decimate=False, + *, + threshold=20, + mri="T1.mgz", + verbose=None, +): """Create surfaces of the scalp and neck. The scalp surfaces are required for using the MNE coregistration GUI, and @@ -2204,22 +2403,24 @@ def make_scalp_surfaces(subject, subjects_dir=None, force=True, %(verbose)s """ subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) - incomplete = 'warn' if force else 'raise' + incomplete = "warn" if force else "raise" subj_path = subjects_dir / subject if not subj_path.exists(): - raise RuntimeError('%s does not exist. Please check your subject ' - 'directory path.' % subj_path) + raise RuntimeError( + "%s does not exist. Please check your subject " + "directory path." % subj_path + ) # Backward compat for old FreeSurfer (?) - _validate_type(mri, str, 'mri') - if mri == 'T1.mgz': + _validate_type(mri, str, "mri") + if mri == "T1.mgz": mri = mri if (subj_path / "mri" / mri).exists() else "T1" - logger.info('1. Creating a dense scalp tessellation with mkheadsurf...') + logger.info("1. Creating a dense scalp tessellation with mkheadsurf...") def check_seghead(surf_path=subj_path / "surf"): surf = None - for k in ['lh.seghead', 'lh.smseghead']: + for k in ["lh.seghead", "lh.smseghead"]: this_surf = surf_path / k if this_surf.exists(): surf = this_surf @@ -2227,61 +2428,79 @@ def check_seghead(surf_path=subj_path / "surf"): return surf my_seghead = check_seghead() - threshold = _ensure_int(threshold, 'threshold') + threshold = _ensure_int(threshold, "threshold") if my_seghead is None: this_env = deepcopy(os.environ) - this_env['SUBJECTS_DIR'] = str(subjects_dir) - this_env['SUBJECT'] = subject - this_env['subjdir'] = str(subj_path) - if 'FREESURFER_HOME' not in this_env: + this_env["SUBJECTS_DIR"] = str(subjects_dir) + this_env["SUBJECT"] = subject + this_env["subjdir"] = str(subj_path) + if "FREESURFER_HOME" not in this_env: raise RuntimeError( - 'The FreeSurfer environment needs to be set up to use ' - 'make_scalp_surfaces to create the outer skin surface ' - 'lh.seghead') - run_subprocess([ - 'mkheadsurf', '-subjid', subject, '-srcvol', mri, - '-thresh1', str(threshold), - '-thresh2', str(threshold)], env=this_env) + "The FreeSurfer environment needs to be set up to use " + "make_scalp_surfaces to create the outer skin surface " + "lh.seghead" + ) + run_subprocess( + [ + "mkheadsurf", + "-subjid", + subject, + "-srcvol", + mri, + "-thresh1", + str(threshold), + "-thresh2", + str(threshold), + ], + env=this_env, + ) surf = check_seghead() if surf is None: - raise RuntimeError('mkheadsurf did not produce the standard output ' - 'file.') + raise RuntimeError("mkheadsurf did not produce the standard output " "file.") bem_dir = subjects_dir / subject / "bem" if not bem_dir.is_dir(): os.mkdir(bem_dir) fname_template = bem_dir / ("%s-head-{}.fif" % subject) - dense_fname = str(fname_template).format('dense') - logger.info('2. Creating %s ...' % dense_fname) + dense_fname = str(fname_template).format("dense") + logger.info("2. Creating %s ..." % dense_fname) _check_file(dense_fname, overwrite) # Helpful message if we get a topology error - msg = ('\n\nConsider using pymeshfix directly to fix the mesh, or --force ' - 'to ignore the problem.') + msg = ( + "\n\nConsider using pymeshfix directly to fix the mesh, or --force " + "to ignore the problem." 
+ ) surf = _surfaces_to_bem( - [surf], [FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], - incomplete=incomplete, extra=msg)[0] + [surf], [FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], incomplete=incomplete, extra=msg + )[0] write_bem_surfaces(dense_fname, surf, overwrite=overwrite) - if os.getenv('_MNE_TESTING_SCALP', 'false') == 'true': - tris = [len(surf['tris'])] # don't actually decimate + if os.getenv("_MNE_TESTING_SCALP", "false") == "true": + tris = [len(surf["tris"])] # don't actually decimate for ii, (level, n_tri) in enumerate(_tri_levels.items(), 3): if no_decimate: break - logger.info(f'{ii}. Creating {level} tessellation...') - logger.info(f'{ii}.1 Decimating the dense tessellation ' - f'({len(surf["tris"])} -> {n_tri} triangles)...') - points, tris = decimate_surface(points=surf['rr'], - triangles=surf['tris'], - n_triangles=n_tri) + logger.info(f"{ii}. Creating {level} tessellation...") + logger.info( + f"{ii}.1 Decimating the dense tessellation " + f'({len(surf["tris"])} -> {n_tri} triangles)...' + ) + points, tris = decimate_surface( + points=surf["rr"], triangles=surf["tris"], n_triangles=n_tri + ) dec_fname = str(fname_template).format(level) - logger.info('%i.2 Creating %s' % (ii, dec_fname)) + logger.info("%i.2 Creating %s" % (ii, dec_fname)) _check_file(dec_fname, overwrite) dec_surf = _surfaces_to_bem( [dict(rr=points, tris=tris)], - [FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], rescale=False, - incomplete=incomplete, extra=msg) + [FIFF.FIFFV_BEM_SURF_ID_HEAD], + [1], + rescale=False, + incomplete=incomplete, + extra=msg, + ) write_bem_surfaces(dec_fname, dec_surf, overwrite=overwrite) - logger.info('[done]') + logger.info("[done]") @verbose @@ -2318,28 +2537,23 @@ def distance_to_bem(pos, bem, trans=None, verbose=None): distance = np.zeros((n,)) logger.info( - 'Computing distance to inner skull surface for ' + - f'{n} position{_pl(n)}...' + "Computing distance to inner skull surface for " + f"{n} position{_pl(n)}..." 
) - if bem['is_sphere']: - center = bem['r0'] + if bem["is_sphere"]: + center = bem["r0"] if trans: center = apply_trans(trans, center, move=True) - radius = bem['layers'][0]['rad'] + radius = bem["layers"][0]["rad"] - distance = np.abs(radius - np.linalg.norm( - pos - center, axis=1 - )) + distance = np.abs(radius - np.linalg.norm(pos - center, axis=1)) else: # is BEM - surface_points = bem['surfs'][0]['rr'] + surface_points = bem["surfs"][0]["rr"] if trans: - surface_points = apply_trans( - trans, surface_points, move=True - ) + surface_points = apply_trans(trans, surface_points, move=True) _, distance = _compute_nearest(surface_points, pos, return_dists=True) diff --git a/mne/channels/__init__.py b/mne/channels/__init__.py index c5701c7b2b1..cfd48fc2449 100644 --- a/mne/channels/__init__.py +++ b/mne/channels/__init__.py @@ -4,41 +4,81 @@ """ from ..defaults import HEAD_SIZE_DEFAULT -from .layout import (Layout, make_eeg_layout, make_grid_layout, read_layout, - find_layout, generate_2d_layout) -from .montage import (DigMontage, - get_builtin_montages, make_dig_montage, read_dig_dat, - read_dig_egi, read_dig_captrak, read_dig_fif, - read_dig_polhemus_isotrak, read_polhemus_fastscan, - compute_dev_head_t, make_standard_montage, - read_custom_montage, read_dig_hpts, read_dig_localite, - compute_native_head_t) -from .channels import (equalize_channels, rename_channels, fix_mag_coil_types, - read_ch_adjacency, _get_ch_type, find_ch_adjacency, - make_1020_channel_selections, combine_channels, - read_vectorview_selection, _SELECTIONS, _EEG_SELECTIONS, - _divide_to_regions, get_builtin_ch_adjacencies) +from .layout import ( + Layout, + make_eeg_layout, + make_grid_layout, + read_layout, + find_layout, + generate_2d_layout, +) +from .montage import ( + DigMontage, + get_builtin_montages, + make_dig_montage, + read_dig_dat, + read_dig_egi, + read_dig_captrak, + read_dig_fif, + read_dig_polhemus_isotrak, + read_polhemus_fastscan, + compute_dev_head_t, + make_standard_montage, + read_custom_montage, + read_dig_hpts, + read_dig_localite, + compute_native_head_t, +) +from .channels import ( + equalize_channels, + rename_channels, + fix_mag_coil_types, + read_ch_adjacency, + _get_ch_type, + find_ch_adjacency, + make_1020_channel_selections, + combine_channels, + read_vectorview_selection, + _SELECTIONS, + _EEG_SELECTIONS, + _divide_to_regions, + get_builtin_ch_adjacencies, +) __all__ = [ # Data Structures - 'DigMontage', 'Layout', - + "DigMontage", + "Layout", # Factory Methods - 'make_dig_montage', 'make_eeg_layout', 'make_grid_layout', - 'make_standard_montage', - + "make_dig_montage", + "make_eeg_layout", + "make_grid_layout", + "make_standard_montage", # Readers - 'read_ch_adjacency', 'read_dig_captrak', 'read_dig_dat', - 'read_dig_egi', 'read_dig_fif', 'read_dig_localite', - 'read_dig_polhemus_isotrak', 'read_layout', - 'read_polhemus_fastscan', 'read_custom_montage', 'read_dig_hpts', - + "read_ch_adjacency", + "read_dig_captrak", + "read_dig_dat", + "read_dig_egi", + "read_dig_fif", + "read_dig_localite", + "read_dig_polhemus_isotrak", + "read_layout", + "read_polhemus_fastscan", + "read_custom_montage", + "read_dig_hpts", # Helpers - 'rename_channels', 'make_1020_channel_selections', - '_get_ch_type', 'equalize_channels', 'find_ch_adjacency', 'find_layout', - 'fix_mag_coil_types', 'generate_2d_layout', 'get_builtin_montages', - 'combine_channels', 'read_vectorview_selection', - + "rename_channels", + "make_1020_channel_selections", + "_get_ch_type", + "equalize_channels", + "find_ch_adjacency", 
+ "find_layout", + "fix_mag_coil_types", + "generate_2d_layout", + "get_builtin_montages", + "combine_channels", + "read_vectorview_selection", # Other - 'compute_dev_head_t', 'compute_native_head_t', + "compute_dev_head_t", + "compute_native_head_t", ] diff --git a/mne/channels/_dig_montage_utils.py b/mne/channels/_dig_montage_utils.py index a60418e84d4..d03cbc1fcbe 100644 --- a/mne/channels/_dig_montage_utils.py +++ b/mne/channels/_dig_montage_utils.py @@ -20,71 +20,76 @@ def _read_dig_montage_egi( - fname, - _scaling, - _all_data_kwargs_are_none, + fname, + _scaling, + _all_data_kwargs_are_none, ): - if not _all_data_kwargs_are_none: - raise ValueError('hsp, hpi, elp, point_names, fif must all be ' - 'None if egi is not None') - _check_fname(fname, overwrite='read', must_exist=True) + raise ValueError( + "hsp, hpi, elp, point_names, fif must all be " "None if egi is not None" + ) + _check_fname(fname, overwrite="read", must_exist=True) root = ElementTree.parse(fname).getroot() - ns = root.tag[root.tag.index('{'):root.tag.index('}') + 1] - sensors = root.find('%ssensorLayout/%ssensors' % (ns, ns)) + ns = root.tag[root.tag.index("{") : root.tag.index("}") + 1] + sensors = root.find("%ssensorLayout/%ssensors" % (ns, ns)) fids = dict() dig_ch_pos = dict() - fid_name_map = {'Nasion': 'nasion', - 'Right periauricular point': 'rpa', - 'Left periauricular point': 'lpa'} + fid_name_map = { + "Nasion": "nasion", + "Right periauricular point": "rpa", + "Left periauricular point": "lpa", + } for s in sensors: name, number, kind = s[0].text, int(s[1].text), int(s[2].text) - coordinates = np.array([float(s[3].text), float(s[4].text), - float(s[5].text)]) + coordinates = np.array([float(s[3].text), float(s[4].text), float(s[5].text)]) coordinates *= _scaling # EEG Channels if kind == 0: - dig_ch_pos['EEG %03d' % number] = coordinates + dig_ch_pos["EEG %03d" % number] = coordinates # Reference elif kind == 1: - dig_ch_pos['EEG %03d' % - (len(dig_ch_pos.keys()) + 1)] = coordinates + dig_ch_pos["EEG %03d" % (len(dig_ch_pos.keys()) + 1)] = coordinates # Fiducials elif kind == 2: fid_name = fid_name_map[name] fids[fid_name] = coordinates # Unknown else: - warn('Unknown sensor type %s detected. Skipping sensor...' - 'Proceed with caution!' % kind) + warn( + "Unknown sensor type %s detected. Skipping sensor..." + "Proceed with caution!" 
% kind + ) return Bunch( # EGI stuff - nasion=fids['nasion'], lpa=fids['lpa'], rpa=fids['rpa'], - ch_pos=dig_ch_pos, coord_frame='unknown', + nasion=fids["nasion"], + lpa=fids["lpa"], + rpa=fids["rpa"], + ch_pos=dig_ch_pos, + coord_frame="unknown", ) def _parse_brainvision_dig_montage(fname, scale): - FID_NAME_MAP = {'Nasion': 'nasion', 'RPA': 'rpa', 'LPA': 'lpa'} + FID_NAME_MAP = {"Nasion": "nasion", "RPA": "rpa", "LPA": "lpa"} root = ElementTree.parse(fname).getroot() - sensors = root.find('CapTrakElectrodeList') + sensors = root.find("CapTrakElectrodeList") fids, dig_ch_pos = dict(), dict() for s in sensors: - name = s.find('Name').text + name = s.find("Name").text is_fid = name in FID_NAME_MAP - coordinates = scale * np.array([float(s.find('X').text), - float(s.find('Y').text), - float(s.find('Z').text)]) + coordinates = scale * np.array( + [float(s.find("X").text), float(s.find("Y").text), float(s.find("Z").text)] + ) # Fiducials if is_fid: @@ -95,6 +100,9 @@ def _parse_brainvision_dig_montage(fname, scale): return dict( # BVCT stuff - nasion=fids['nasion'], lpa=fids['lpa'], rpa=fids['rpa'], - ch_pos=dig_ch_pos, coord_frame='unknown' + nasion=fids["nasion"], + lpa=fids["lpa"], + rpa=fids["rpa"], + ch_pos=dig_ch_pos, + coord_frame="unknown", ) diff --git a/mne/channels/_standard_montage_utils.py b/mne/channels/_standard_montage_utils.py index b83252c0dc1..c136b107924 100644 --- a/mne/channels/_standard_montage_utils.py +++ b/mne/channels/_standard_montage_utils.py @@ -17,28 +17,32 @@ from ..utils import warn, _pl from . import __file__ as _CHANNELS_INIT_FILE -MONTAGE_PATH = op.join(op.dirname(_CHANNELS_INIT_FILE), 'data', 'montages') +MONTAGE_PATH = op.join(op.dirname(_CHANNELS_INIT_FILE), "data", "montages") -_str = 'U100' +_str = "U100" # In standard_1020, T9=LPA, T10=RPA, Nasion is the same as Iz with a # sign-flipped Y value + def _egi_256(head_size): - fname = op.join(MONTAGE_PATH, 'EGI_256.csd') + fname = op.join(MONTAGE_PATH, "EGI_256.csd") montage = _read_csd(fname, head_size) ch_pos = montage._get_ch_pos() # For this cap, the Nasion is the frontmost electrode, # LPA/RPA we approximate by putting 75% of the way (toward the front) # between the two electrodes that are halfway down the ear holes - nasion = ch_pos['E31'] - lpa = 0.75 * ch_pos['E67'] + 0.25 * ch_pos['E94'] - rpa = 0.75 * ch_pos['E219'] + 0.25 * ch_pos['E190'] + nasion = ch_pos["E31"] + lpa = 0.75 * ch_pos["E67"] + 0.25 * ch_pos["E94"] + rpa = 0.75 * ch_pos["E219"] + 0.25 * ch_pos["E190"] fids_montage = make_dig_montage( - coord_frame='unknown', nasion=nasion, lpa=lpa, rpa=rpa, + coord_frame="unknown", + nasion=nasion, + lpa=lpa, + rpa=rpa, ) montage += fids_montage # add fiducials to montage @@ -63,119 +67,116 @@ def _str_names(ch_names): def _safe_np_loadtxt(fname, **kwargs): out = np.genfromtxt(fname, **kwargs) - ch_names = _str_names(out['f0']) - others = tuple(out['f%d' % ii] for ii in range(1, len(out.dtype.fields))) + ch_names = _str_names(out["f0"]) + others = tuple(out["f%d" % ii] for ii in range(1, len(out.dtype.fields))) return (ch_names,) + others def _biosemi(basename, head_size): fname = op.join(MONTAGE_PATH, basename) - fid_names = ('Nz', 'LPA', 'RPA') + fid_names = ("Nz", "LPA", "RPA") return _read_theta_phi_in_degrees(fname, head_size, fid_names) -def _mgh_or_standard(basename, head_size, coord_frame='unknown'): - fid_names = ('Nz', 'LPA', 'RPA') +def _mgh_or_standard(basename, head_size, coord_frame="unknown"): + fid_names = ("Nz", "LPA", "RPA") fname = op.join(MONTAGE_PATH, basename) ch_names_, 
pos = [], [] with open(fname) as fid: # Ignore units as we will scale later using the norms anyway for line in fid: - if 'Positions\n' in line: + if "Positions\n" in line: break pos = [] for line in fid: - if 'Labels\n' in line: + if "Labels\n" in line: break pos.append(list(map(float, line.split()))) for line in fid: - if not line or not set(line) - {' '}: + if not line or not set(line) - {" "}: break - ch_names_.append(line.strip(' ').strip('\n')) + ch_names_.append(line.strip(" ").strip("\n")) - pos = np.array(pos) / 1000. + pos = np.array(pos) / 1000.0 ch_pos = _check_dupes_odict(ch_names_, pos) nasion, lpa, rpa = [ch_pos.pop(n) for n in fid_names] if head_size is None: - scale = 1. + scale = 1.0 else: scale = head_size / np.median(np.linalg.norm(pos, axis=1)) for value in ch_pos.values(): value *= scale # if we are in MRI/MNI coordinates, we need to replace nasion, LPA, and RPA # with those of fsaverage for ``trans='fsaverage'`` to work - if coord_frame == 'mri': - lpa, nasion, rpa = [ - x['r'].copy() for x in get_mni_fiducials('fsaverage')] + if coord_frame == "mri": + lpa, nasion, rpa = [x["r"].copy() for x in get_mni_fiducials("fsaverage")] nasion *= scale lpa *= scale rpa *= scale - return make_dig_montage(ch_pos=ch_pos, coord_frame=coord_frame, - nasion=nasion, lpa=lpa, rpa=rpa) + return make_dig_montage( + ch_pos=ch_pos, coord_frame=coord_frame, nasion=nasion, lpa=lpa, rpa=rpa + ) standard_montage_look_up_table = { - 'EGI_256': _egi_256, - - 'easycap-M1': partial(_easycap, basename='easycap-M1.txt'), - 'easycap-M10': partial(_easycap, basename='easycap-M10.txt'), - - 'GSN-HydroCel-128': partial(_hydrocel, basename='GSN-HydroCel-128.sfp'), - 'GSN-HydroCel-129': partial(_hydrocel, basename='GSN-HydroCel-129.sfp'), - 'GSN-HydroCel-256': partial(_hydrocel, basename='GSN-HydroCel-256.sfp'), - 'GSN-HydroCel-257': partial(_hydrocel, basename='GSN-HydroCel-257.sfp'), - 'GSN-HydroCel-32': partial(_hydrocel, basename='GSN-HydroCel-32.sfp'), - 'GSN-HydroCel-64_1.0': partial(_hydrocel, - basename='GSN-HydroCel-64_1.0.sfp'), - 'GSN-HydroCel-65_1.0': partial(_hydrocel, - basename='GSN-HydroCel-65_1.0.sfp'), - - 'biosemi128': partial(_biosemi, basename='biosemi128.txt'), - 'biosemi16': partial(_biosemi, basename='biosemi16.txt'), - 'biosemi160': partial(_biosemi, basename='biosemi160.txt'), - 'biosemi256': partial(_biosemi, basename='biosemi256.txt'), - 'biosemi32': partial(_biosemi, basename='biosemi32.txt'), - 'biosemi64': partial(_biosemi, basename='biosemi64.txt'), - - 'mgh60': partial(_mgh_or_standard, basename='mgh60.elc', - coord_frame='mri'), - 'mgh70': partial(_mgh_or_standard, basename='mgh70.elc', - coord_frame='mri'), - 'standard_1005': partial(_mgh_or_standard, - basename='standard_1005.elc', coord_frame='mri'), - 'standard_1020': partial(_mgh_or_standard, - basename='standard_1020.elc', coord_frame='mri'), - 'standard_alphabetic': partial(_mgh_or_standard, - basename='standard_alphabetic.elc', - coord_frame='mri'), - 'standard_postfixed': partial(_mgh_or_standard, - basename='standard_postfixed.elc', - coord_frame='mri'), - 'standard_prefixed': partial(_mgh_or_standard, - basename='standard_prefixed.elc', - coord_frame='mri'), - 'standard_primed': partial(_mgh_or_standard, - basename='standard_primed.elc', - coord_frame='mri'), - 'artinis-octamon': partial(_mgh_or_standard, coord_frame='mri', - basename='artinis-octamon.elc'), - 'artinis-brite23': partial(_mgh_or_standard, coord_frame='mri', - basename='artinis-brite23.elc'), - 'brainproducts-RNP-BA-128': partial( - _easycap, 
basename='brainproducts-RNP-BA-128.txt') + "EGI_256": _egi_256, + "easycap-M1": partial(_easycap, basename="easycap-M1.txt"), + "easycap-M10": partial(_easycap, basename="easycap-M10.txt"), + "GSN-HydroCel-128": partial(_hydrocel, basename="GSN-HydroCel-128.sfp"), + "GSN-HydroCel-129": partial(_hydrocel, basename="GSN-HydroCel-129.sfp"), + "GSN-HydroCel-256": partial(_hydrocel, basename="GSN-HydroCel-256.sfp"), + "GSN-HydroCel-257": partial(_hydrocel, basename="GSN-HydroCel-257.sfp"), + "GSN-HydroCel-32": partial(_hydrocel, basename="GSN-HydroCel-32.sfp"), + "GSN-HydroCel-64_1.0": partial(_hydrocel, basename="GSN-HydroCel-64_1.0.sfp"), + "GSN-HydroCel-65_1.0": partial(_hydrocel, basename="GSN-HydroCel-65_1.0.sfp"), + "biosemi128": partial(_biosemi, basename="biosemi128.txt"), + "biosemi16": partial(_biosemi, basename="biosemi16.txt"), + "biosemi160": partial(_biosemi, basename="biosemi160.txt"), + "biosemi256": partial(_biosemi, basename="biosemi256.txt"), + "biosemi32": partial(_biosemi, basename="biosemi32.txt"), + "biosemi64": partial(_biosemi, basename="biosemi64.txt"), + "mgh60": partial(_mgh_or_standard, basename="mgh60.elc", coord_frame="mri"), + "mgh70": partial(_mgh_or_standard, basename="mgh70.elc", coord_frame="mri"), + "standard_1005": partial( + _mgh_or_standard, basename="standard_1005.elc", coord_frame="mri" + ), + "standard_1020": partial( + _mgh_or_standard, basename="standard_1020.elc", coord_frame="mri" + ), + "standard_alphabetic": partial( + _mgh_or_standard, basename="standard_alphabetic.elc", coord_frame="mri" + ), + "standard_postfixed": partial( + _mgh_or_standard, basename="standard_postfixed.elc", coord_frame="mri" + ), + "standard_prefixed": partial( + _mgh_or_standard, basename="standard_prefixed.elc", coord_frame="mri" + ), + "standard_primed": partial( + _mgh_or_standard, basename="standard_primed.elc", coord_frame="mri" + ), + "artinis-octamon": partial( + _mgh_or_standard, coord_frame="mri", basename="artinis-octamon.elc" + ), + "artinis-brite23": partial( + _mgh_or_standard, coord_frame="mri", basename="artinis-brite23.elc" + ), + "brainproducts-RNP-BA-128": partial( + _easycap, basename="brainproducts-RNP-BA-128.txt" + ), } def _read_sfp(fname, head_size): """Read .sfp BESA/EGI files.""" # fname has been already checked - fid_names = ('FidNz', 'FidT9', 'FidT10') - options = dict(dtype=(_str, 'f4', 'f4', 'f4')) + fid_names = ("FidNz", "FidT9", "FidT10") + options = dict(dtype=(_str, "f4", "f4", "f4")) ch_names, xs, ys, zs = _safe_np_loadtxt(fname, **options) # deal with "headshape" - mask = np.array([ch_name == 'headshape' for ch_name in ch_names], bool) + mask = np.array([ch_name == "headshape" for ch_name in ch_names], bool) hsp = np.stack([xs[mask], ys[mask], zs[mask]], axis=-1) mask = ~mask pos = np.stack([xs[mask], ys[mask], zs[mask]], axis=-1) @@ -193,14 +194,16 @@ def _read_sfp(fname, head_size): lpa = lpa * scale if lpa is not None else None rpa = rpa * scale if rpa is not None else None - return make_dig_montage(ch_pos=ch_pos, coord_frame='unknown', - nasion=nasion, rpa=rpa, lpa=lpa, hsp=hsp) + return make_dig_montage( + ch_pos=ch_pos, coord_frame="unknown", nasion=nasion, rpa=rpa, lpa=lpa, hsp=hsp + ) def _read_csd(fname, head_size): # Label, Theta, Phi, Radius, X, Y, Z, off sphere surface - options = dict(comments='//', - dtype=(_str, 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4')) + options = dict( + comments="//", dtype=(_str, "f4", "f4", "f4", "f4", "f4", "f4", "f4") + ) ch_names, _, _, _, xs, ys, zs, _ = _safe_np_loadtxt(fname, **options) pos = 
np.stack([xs, ys, zs], axis=-1) @@ -213,16 +216,15 @@ def _read_csd(fname, head_size): def _check_dupes_odict(ch_names, pos): """Warn if there are duplicates, then turn to ordered dict.""" ch_names = list(ch_names) - dups = OrderedDict((ch_name, ch_names.count(ch_name)) - for ch_name in ch_names) - dups = OrderedDict((ch_name, count) for ch_name, count in dups.items() - if count > 1) + dups = OrderedDict((ch_name, ch_names.count(ch_name)) for ch_name in ch_names) + dups = OrderedDict((ch_name, count) for ch_name, count in dups.items() if count > 1) n = len(dups) if n: - dups = ', '.join( - f'{ch_name} ({count})' for ch_name, count in dups.items()) - warn(f'Duplicate channel position{_pl(n)} found, the last will be ' - f'used for {dups}') + dups = ", ".join(f"{ch_name} ({count})" for ch_name, count in dups.items()) + warn( + f"Duplicate channel position{_pl(n)} found, the last will be " + f"used for {dups}" + ) return OrderedDict(zip(ch_names, pos)) @@ -242,30 +244,30 @@ def _read_elc(fname, head_size): montage : instance of DigMontage The montage in [m]. """ - fid_names = ('Nz', 'LPA', 'RPA') + fid_names = ("Nz", "LPA", "RPA") ch_names_, pos = [], [] with open(fname) as fid: # _read_elc does require to detect the units. (see _mgh_or_standard) for line in fid: - if 'UnitPosition' in line: + if "UnitPosition" in line: units = line.split()[1] - scale = dict(m=1., mm=1e-3)[units] + scale = dict(m=1.0, mm=1e-3)[units] break else: - raise RuntimeError('Could not detect units in file %s' % fname) + raise RuntimeError("Could not detect units in file %s" % fname) for line in fid: - if 'Positions\n' in line: + if "Positions\n" in line: break pos = [] for line in fid: - if 'Labels\n' in line: + if "Labels\n" in line: break pos.append(list(map(float, line.split()))) for line in fid: - if not line or not set(line) - {' '}: + if not line or not set(line) - {" "}: break - ch_names_.append(line.strip(' ').strip('\n')) + ch_names_.append(line.strip(" ").strip("\n")) pos = np.array(pos) * scale if head_size is not None: @@ -274,14 +276,15 @@ def _read_elc(fname, head_size): ch_pos = _check_dupes_odict(ch_names_, pos) nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] - return make_dig_montage(ch_pos=ch_pos, coord_frame='unknown', - nasion=nasion, lpa=lpa, rpa=rpa) + return make_dig_montage( + ch_pos=ch_pos, coord_frame="unknown", nasion=nasion, lpa=lpa, rpa=rpa + ) -def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, - add_fiducials=False): - ch_names, theta, phi = _safe_np_loadtxt(fname, skip_header=1, - dtype=(_str, 'i4', 'i4')) +def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, add_fiducials=False): + ch_names, theta, phi = _safe_np_loadtxt( + fname, skip_header=1, dtype=(_str, "i4", "i4") + ) if add_fiducials: # Add fiducials based on 10/20 spherical coordinate definitions # http://chgd.umich.edu/wp-content/uploads/2014/06/ @@ -290,7 +293,7 @@ def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, # https://www.easycap.de/wp-content/uploads/2018/02/ # Easycap-Equidistant-Layouts.pdf assert fid_names is None - fid_names = ['Nasion', 'LPA', 'RPA'] + fid_names = ["Nasion", "LPA", "RPA"] ch_names.extend(fid_names) theta = np.append(theta, [115, -115, 115]) phi = np.append(phi, [90, 0, 0]) @@ -303,23 +306,23 @@ def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, if fid_names is not None: nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] - return make_dig_montage(ch_pos=ch_pos, coord_frame='unknown', - nasion=nasion, lpa=lpa, rpa=rpa) + 
return make_dig_montage( + ch_pos=ch_pos, coord_frame="unknown", nasion=nasion, lpa=lpa, rpa=rpa + ) def _read_elp_besa(fname, head_size): # This .elp is not the same as polhemus elp. see _read_isotrak_elp_points - dtype = np.dtype('S8, S8, f8, f8, f8') + dtype = np.dtype("S8, S8, f8, f8, f8") data = np.loadtxt(fname, dtype=dtype) - ch_names = data['f1'].astype(str).tolist() - az = data['f2'] - horiz = data['f3'] - radius = np.abs(az / 180.) - az = np.deg2rad(np.array([h if a >= 0. else 180 + h - for h, a in zip(horiz, az)])) + ch_names = data["f1"].astype(str).tolist() + az = data["f2"] + horiz = data["f3"] + radius = np.abs(az / 180.0) + az = np.deg2rad(np.array([h if a >= 0.0 else 180 + h for h, a in zip(horiz, az)])) pol = radius * np.pi - rad = data['f4'] / 100 + rad = data["f4"] / 100 pos = _sph_to_cart(np.array([rad, az, pol]).T) if head_size is not None: @@ -327,7 +330,7 @@ def _read_elp_besa(fname, head_size): ch_pos = _check_dupes_odict(ch_names, pos) - fid_names = ('Nz', 'LPA', 'RPA') + fid_names = ("Nz", "LPA", "RPA") # No one grants that the fid names actually exist. nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 211b0275441..f599313304c 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -22,18 +22,41 @@ import numpy as np from ..defaults import HEAD_SIZE_DEFAULT, _handle_default -from ..utils import (verbose, logger, warn, - _check_preload, _validate_type, fill_doc, _check_option, - _get_stim_channel, _check_fname, _check_dict_keys, - _on_missing, legacy) +from ..utils import ( + verbose, + logger, + warn, + _check_preload, + _validate_type, + fill_doc, + _check_option, + _get_stim_channel, + _check_fname, + _check_dict_keys, + _on_missing, + legacy, +) from ..io.constants import FIFF -from ..io.meas_info import (anonymize_info, Info, MontageMixin, create_info, - _rename_comps) -from ..io.pick import (channel_type, pick_info, pick_types, _picks_by_type, - _check_excludes_includes, _contains_ch_type, - channel_indices_by_type, pick_channels, _picks_to_idx, - get_channel_type_constants, - _pick_data_channels) +from ..io.meas_info import ( + anonymize_info, + Info, + MontageMixin, + create_info, + _rename_comps, +) +from ..io.pick import ( + channel_type, + pick_info, + pick_types, + _picks_by_type, + _check_excludes_includes, + _contains_ch_type, + channel_indices_by_type, + pick_channels, + _picks_to_idx, + get_channel_type_constants, + _pick_data_channels, +) from ..io.tag import _rename_list from ..io.write import DATE_NONE from ..io.proj import setup_proj @@ -42,39 +65,40 @@ def _get_meg_system(info): """Educated guess for the helmet type based on channels.""" have_helmet = True - for ch in info['chs']: - if ch['kind'] == FIFF.FIFFV_MEG_CH: + for ch in info["chs"]: + if ch["kind"] == FIFF.FIFFV_MEG_CH: # Only take first 16 bits, as higher bits store CTF grad comp order - coil_type = ch['coil_type'] & 0xFFFF - nmag = np.sum( - [c['kind'] == FIFF.FIFFV_MEG_CH for c in info['chs']]) + coil_type = ch["coil_type"] & 0xFFFF + nmag = np.sum([c["kind"] == FIFF.FIFFV_MEG_CH for c in info["chs"]]) if coil_type == FIFF.FIFFV_COIL_NM_122: - system = '122m' + system = "122m" break elif coil_type // 1000 == 3: # All Vectorview coils are 30xx - system = '306m' + system = "306m" break - elif (coil_type == FIFF.FIFFV_COIL_MAGNES_MAG or - coil_type == FIFF.FIFFV_COIL_MAGNES_GRAD): - system = 'Magnes_3600wh' if nmag > 150 else 'Magnes_2500wh' + elif ( + coil_type == 
FIFF.FIFFV_COIL_MAGNES_MAG + or coil_type == FIFF.FIFFV_COIL_MAGNES_GRAD + ): + system = "Magnes_3600wh" if nmag > 150 else "Magnes_2500wh" break elif coil_type == FIFF.FIFFV_COIL_CTF_GRAD: - system = 'CTF_275' + system = "CTF_275" break elif coil_type == FIFF.FIFFV_COIL_KIT_GRAD: - system = 'KIT' + system = "KIT" # Our helmet does not match very well, so let's just create it have_helmet = False break elif coil_type == FIFF.FIFFV_COIL_BABY_GRAD: - system = 'BabySQUID' + system = "BabySQUID" break elif coil_type == FIFF.FIFFV_COIL_ARTEMIS123_GRAD: - system = 'ARTEMIS123' + system = "ARTEMIS123" have_helmet = False break else: - system = 'unknown' + system = "unknown" have_helmet = False return system, have_helmet @@ -86,11 +110,24 @@ def _get_ch_type(inst, ch_type, allow_ref_meg=False): then grads, then ... to plot. """ if ch_type is None: - allowed_types = ['mag', 'grad', 'planar1', 'planar2', 'eeg', 'csd', - 'fnirs_cw_amplitude', 'fnirs_fd_ac_amplitude', - 'fnirs_fd_phase', 'fnirs_od', 'hbo', 'hbr', - 'ecog', 'seeg', 'dbs'] - allowed_types += ['ref_meg'] if allow_ref_meg else [] + allowed_types = [ + "mag", + "grad", + "planar1", + "planar2", + "eeg", + "csd", + "fnirs_cw_amplitude", + "fnirs_fd_ac_amplitude", + "fnirs_fd_phase", + "fnirs_od", + "hbo", + "hbr", + "ecog", + "seeg", + "dbs", + ] + allowed_types += ["ref_meg"] if allow_ref_meg else [] for type_ in allowed_types: if isinstance(inst, Info): if _contains_ch_type(inst, type_): @@ -100,7 +137,7 @@ def _get_ch_type(inst, ch_type, allow_ref_meg=False): ch_type = type_ break else: - raise RuntimeError('No plottable channel types found') + raise RuntimeError("No plottable channel types found") return ch_type @@ -147,16 +184,26 @@ def equalize_channels(instances, copy=True, verbose=None): # Instances need to have a `ch_names` attribute and a `pick_channels` # method that supports `ordered=True`. 
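    # A usage sketch for ``equalize_channels`` (assuming ``raw`` and
    # ``epochs`` are already-loaded instances sharing most channel names;
    # neither name appears in this patch):
    #
    #     raw_eq, epochs_eq = equalize_channels([raw, epochs], copy=True)
    #     assert raw_eq.ch_names == epochs_eq.ch_names
    #
    # Channels present in only one instance are dropped and the survivors
    # are re-ordered to match the first instance, per the logic below.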
- allowed_types = (BaseRaw, BaseEpochs, Evoked, _BaseTFR, Forward, - Covariance, CrossSpectralDensity, Info) - allowed_types_str = ("Raw, Epochs, Evoked, TFR, Forward, Covariance, " - "CrossSpectralDensity or Info") + allowed_types = ( + BaseRaw, + BaseEpochs, + Evoked, + _BaseTFR, + Forward, + Covariance, + CrossSpectralDensity, + Info, + ) + allowed_types_str = ( + "Raw, Epochs, Evoked, TFR, Forward, Covariance, " "CrossSpectralDensity or Info" + ) for inst in instances: - _validate_type(inst, allowed_types, "Instances to be modified", - allowed_types_str) + _validate_type( + inst, allowed_types, "Instances to be modified", allowed_types_str + ) chan_template = instances[0].ch_names - logger.info('Identifying common channels ...') + logger.info("Identifying common channels ...") channels = [set(inst.ch_names) for inst in instances] common_channels = set(chan_template).intersection(*channels) all_channels = set(chan_template).union(*channels) @@ -173,8 +220,9 @@ def equalize_channels(instances, copy=True, verbose=None): # Only perform picking when needed if inst.ch_names != common_channels: if isinstance(inst, Info): - sel = pick_channels(inst.ch_names, common_channels, exclude=[], - ordered=True) + sel = pick_channels( + inst.ch_names, common_channels, exclude=[], ordered=True + ) inst = pick_info(inst, sel, copy=copy, verbose=False) else: if copy: @@ -185,47 +233,59 @@ def equalize_channels(instances, copy=True, verbose=None): equalized_instances.append(inst) if dropped: - logger.info('Dropped the following channels:\n%s' % dropped) + logger.info("Dropped the following channels:\n%s" % dropped) elif reordered: - logger.info('Channels have been re-ordered.') + logger.info("Channels have been re-ordered.") return equalized_instances channel_type_constants = get_channel_type_constants(include_defaults=True) -_human2fiff = {k: v.get('kind', FIFF.FIFFV_COIL_NONE) for k, v in - channel_type_constants.items()} -_human2unit = {k: v.get('unit', FIFF.FIFF_UNIT_NONE) for k, v in - channel_type_constants.items()} -_unit2human = {FIFF.FIFF_UNIT_V: 'V', - FIFF.FIFF_UNIT_T: 'T', - FIFF.FIFF_UNIT_T_M: 'T/m', - FIFF.FIFF_UNIT_MOL: 'M', - FIFF.FIFF_UNIT_NONE: 'NA', - FIFF.FIFF_UNIT_CEL: 'C', - FIFF.FIFF_UNIT_S: 'S', - FIFF.FIFF_UNIT_PX: 'px'} +_human2fiff = { + k: v.get("kind", FIFF.FIFFV_COIL_NONE) for k, v in channel_type_constants.items() +} +_human2unit = { + k: v.get("unit", FIFF.FIFF_UNIT_NONE) for k, v in channel_type_constants.items() +} +_unit2human = { + FIFF.FIFF_UNIT_V: "V", + FIFF.FIFF_UNIT_T: "T", + FIFF.FIFF_UNIT_T_M: "T/m", + FIFF.FIFF_UNIT_MOL: "M", + FIFF.FIFF_UNIT_NONE: "NA", + FIFF.FIFF_UNIT_CEL: "C", + FIFF.FIFF_UNIT_S: "S", + FIFF.FIFF_UNIT_PX: "px", +} def _check_set(ch, projs, ch_type): """Ensure type change is compatible with projectors.""" new_kind = _human2fiff[ch_type] - if ch['kind'] != new_kind: + if ch["kind"] != new_kind: for proj in projs: - if ch['ch_name'] in proj['data']['col_names']: - raise RuntimeError('Cannot change channel type for channel %s ' - 'in projector "%s"' - % (ch['ch_name'], proj['desc'])) - ch['kind'] = new_kind + if ch["ch_name"] in proj["data"]["col_names"]: + raise RuntimeError( + "Cannot change channel type for channel %s " + 'in projector "%s"' % (ch["ch_name"], proj["desc"]) + ) + ch["kind"] = new_kind class SetChannelsMixin(MontageMixin): """Mixin class for Raw, Evoked, Epochs.""" @verbose - def set_eeg_reference(self, ref_channels='average', projection=False, - ch_type='auto', forward=None, *, joint=False, - verbose=None): + def 
set_eeg_reference( + self, + ref_channels="average", + projection=False, + ch_type="auto", + forward=None, + *, + joint=False, + verbose=None, + ): """Specify which reference to use for EEG data. Use this function to explicitly specify the desired reference for EEG. @@ -251,9 +311,16 @@ def set_eeg_reference(self, ref_channels='average', projection=False, %(set_eeg_reference_see_also_notes)s """ from ..io.reference import set_eeg_reference - return set_eeg_reference(self, ref_channels=ref_channels, copy=False, - projection=projection, ch_type=ch_type, - forward=forward, joint=joint)[0] + + return set_eeg_reference( + self, + ref_channels=ref_channels, + copy=False, + projection=projection, + ch_type=ch_type, + forward=forward, + joint=joint, + )[0] def _get_channel_positions(self, picks=None): """Get channel locations from info. @@ -268,12 +335,13 @@ def _get_channel_positions(self, picks=None): .. versionadded:: 0.9.0 """ picks = _picks_to_idx(self.info, picks) - chs = self.info['chs'] - pos = np.array([chs[k]['loc'][:3] for k in picks]) + chs = self.info["chs"] + pos = np.array([chs[k]["loc"][:3] for k in picks]) n_zero = np.sum(np.sum(np.abs(pos), axis=1) == 0) if n_zero > 1: # XXX some systems have origin (0, 0, 0) - raise ValueError('Could not extract channel positions for ' - '{} channels'.format(n_zero)) + raise ValueError( + "Could not extract channel positions for " "{} channels".format(n_zero) + ) return pos def _set_channel_positions(self, pos, names): @@ -291,24 +359,25 @@ def _set_channel_positions(self, pos, names): .. versionadded:: 0.9.0 """ if len(pos) != len(names): - raise ValueError('Number of channel positions not equal to ' - 'the number of names given.') + raise ValueError( + "Number of channel positions not equal to " "the number of names given." + ) pos = np.asarray(pos, dtype=np.float64) if pos.shape[-1] != 3 or pos.ndim != 2: - msg = ('Channel positions must have the shape (n_points, 3) ' - 'not %s.' % (pos.shape,)) + msg = "Channel positions must have the shape (n_points, 3) " "not %s." % ( + pos.shape, + ) raise ValueError(msg) for name, p in zip(names, pos): if name in self.ch_names: idx = self.ch_names.index(name) - self.info['chs'][idx]['loc'][:3] = p + self.info["chs"][idx]["loc"][:3] = p else: - msg = ('%s was not found in the info. Cannot be updated.' - % name) + msg = "%s was not found in the info. Cannot be updated." % name raise ValueError(msg) @verbose - def set_channel_types(self, mapping, *, on_unit_change='warn', verbose=None): + def set_channel_types(self, mapping, *, on_unit_change="warn", verbose=None): """Specify the sensor types of channels. Parameters @@ -342,64 +411,66 @@ def set_channel_types(self, mapping, *, on_unit_change='warn', verbose=None): .. versionadded:: 0.9.0 """ - ch_names = self.info['ch_names'] + ch_names = self.info["ch_names"] # first check and assemble clean mappings of index and name unit_changes = dict() for ch_name, ch_type in mapping.items(): if ch_name not in ch_names: - raise ValueError("This channel name (%s) doesn't exist in " - "info." % ch_name) + raise ValueError( + "This channel name (%s) doesn't exist in " "info." % ch_name + ) c_ind = ch_names.index(ch_name) if ch_type not in _human2fiff: - raise ValueError('This function cannot change to this ' - 'channel type: %s. Accepted channel types ' - 'are %s.' - % (ch_type, - ", ".join(sorted(_human2unit.keys())))) + raise ValueError( + "This function cannot change to this " + "channel type: %s. Accepted channel types " + "are %s." 
% (ch_type, ", ".join(sorted(_human2unit.keys()))) + ) # Set sensor type - _check_set(self.info['chs'][c_ind], self.info['projs'], ch_type) - unit_old = self.info['chs'][c_ind]['unit'] + _check_set(self.info["chs"][c_ind], self.info["projs"], ch_type) + unit_old = self.info["chs"][c_ind]["unit"] unit_new = _human2unit[ch_type] if unit_old not in _unit2human: - raise ValueError("Channel '%s' has unknown unit (%s). Please " - "fix the measurement info of your data." - % (ch_name, unit_old)) + raise ValueError( + "Channel '%s' has unknown unit (%s). Please " + "fix the measurement info of your data." % (ch_name, unit_old) + ) if unit_old != _human2unit[ch_type]: this_change = (_unit2human[unit_old], _unit2human[unit_new]) if this_change not in unit_changes: unit_changes[this_change] = list() unit_changes[this_change].append(ch_name) - self.info['chs'][c_ind]['unit'] = _human2unit[ch_type] - if ch_type in ['eeg', 'seeg', 'ecog', 'dbs']: + self.info["chs"][c_ind]["unit"] = _human2unit[ch_type] + if ch_type in ["eeg", "seeg", "ecog", "dbs"]: coil_type = FIFF.FIFFV_COIL_EEG - elif ch_type == 'hbo': + elif ch_type == "hbo": coil_type = FIFF.FIFFV_COIL_FNIRS_HBO - elif ch_type == 'hbr': + elif ch_type == "hbr": coil_type = FIFF.FIFFV_COIL_FNIRS_HBR - elif ch_type == 'fnirs_cw_amplitude': + elif ch_type == "fnirs_cw_amplitude": coil_type = FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE - elif ch_type == 'fnirs_fd_ac_amplitude': + elif ch_type == "fnirs_fd_ac_amplitude": coil_type = FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE - elif ch_type == 'fnirs_fd_phase': + elif ch_type == "fnirs_fd_phase": coil_type = FIFF.FIFFV_COIL_FNIRS_FD_PHASE - elif ch_type == 'fnirs_od': + elif ch_type == "fnirs_od": coil_type = FIFF.FIFFV_COIL_FNIRS_OD - elif ch_type == 'eyetrack_pos': + elif ch_type == "eyetrack_pos": coil_type = FIFF.FIFFV_COIL_EYETRACK_POS - elif ch_type == 'eyetrack_pupil': + elif ch_type == "eyetrack_pupil": coil_type = FIFF.FIFFV_COIL_EYETRACK_PUPIL else: coil_type = FIFF.FIFFV_COIL_NONE - self.info['chs'][c_ind]['coil_type'] = coil_type + self.info["chs"][c_ind]["coil_type"] = coil_type msg = "The unit for channel(s) {0} has changed from {1} to {2}." 
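         # Editor's usage sketch (hypothetical channel name, not part of this
         # patch): retyping a channel updates its FIFF unit and coil type in
         # one call, and any unit change is reported via ``msg`` below, e.g.
         #     raw.set_channel_types({"EEG 057": "ecg"}, on_unit_change="warn")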
for this_change, names in unit_changes.items(): _on_missing( on_missing=on_unit_change, msg=msg.format(", ".join(sorted(names)), *this_change), - name='on_unit_change', + name="on_unit_change", ) return self @@ -427,13 +498,13 @@ def rename_channels(self, mapping, allow_duplicates=False, verbose=None): """ from ..io import BaseRaw - ch_names_orig = list(self.info['ch_names']) + ch_names_orig = list(self.info["ch_names"]) rename_channels(self.info, mapping, allow_duplicates) # Update self._orig_units for Raw if isinstance(self, BaseRaw): # whatever mapping was provided, now we can just use a dict - mapping = dict(zip(ch_names_orig, self.info['ch_names'])) + mapping = dict(zip(ch_names_orig, self.info["ch_names"])) for old_name, new_name in mapping.items(): if old_name in self._orig_units: self._orig_units[new_name] = self._orig_units.pop(old_name) @@ -444,10 +515,20 @@ def rename_channels(self, mapping, allow_duplicates=False, verbose=None): return self @verbose - def plot_sensors(self, kind='topomap', ch_type=None, title=None, - show_names=False, ch_groups=None, to_sphere=True, - axes=None, block=False, show=True, sphere=None, - verbose=None): + def plot_sensors( + self, + kind="topomap", + ch_type=None, + title=None, + show_names=False, + ch_groups=None, + to_sphere=True, + axes=None, + block=False, + show=True, + sphere=None, + verbose=None, + ): """Plot sensor positions. Parameters @@ -518,10 +599,21 @@ def plot_sensors(self, kind='topomap', ch_type=None, title=None, .. versionadded:: 0.12.0 """ from ..viz.utils import plot_sensors - return plot_sensors(self.info, kind=kind, ch_type=ch_type, title=title, - show_names=show_names, ch_groups=ch_groups, - to_sphere=to_sphere, axes=axes, block=block, - show=show, sphere=sphere, verbose=verbose) + + return plot_sensors( + self.info, + kind=kind, + ch_type=ch_type, + title=title, + show_names=show_names, + ch_groups=ch_groups, + to_sphere=to_sphere, + axes=axes, + block=block, + show=show, + sphere=sphere, + verbose=verbose, + ) @verbose def anonymize(self, daysback=None, keep_his=False, verbose=None): @@ -544,9 +636,8 @@ def anonymize(self, daysback=None, keep_his=False, verbose=None): .. versionadded:: 0.13.0 """ - anonymize_info(self.info, daysback=daysback, keep_his=keep_his, - verbose=verbose) - self.set_meas_date(self.info['meas_date']) # unify annot update + anonymize_info(self.info, daysback=daysback, keep_his=keep_his, verbose=verbose) + self.set_meas_date(self.info["meas_date"]) # unify annot update return self def set_meas_date(self, meas_date): @@ -580,25 +671,26 @@ def set_meas_date(self, meas_date): .. 
versionadded:: 0.20 """ from ..annotations import _handle_meas_date + meas_date = _handle_meas_date(meas_date) with self.info._unlock(): - self.info['meas_date'] = meas_date + self.info["meas_date"] = meas_date # clear file_id and meas_id if needed if meas_date is None: - for key in ('file_id', 'meas_id'): + for key in ("file_id", "meas_id"): value = self.info.get(key) if value is not None: - assert 'msecs' not in value - value['secs'] = DATE_NONE[0] - value['usecs'] = DATE_NONE[1] + assert "msecs" not in value + value["secs"] = DATE_NONE[0] + value["usecs"] = DATE_NONE[1] # The following copy is needed for a test CTF dataset # otherwise value['machid'][:] = 0 would suffice - _tmp = value['machid'].copy() + _tmp = value["machid"].copy() _tmp[:] = 0 - value['machid'] = _tmp + value["machid"] = _tmp - if hasattr(self, 'annotations'): + if hasattr(self, "annotations"): self.annotations._orig_time = meas_date return self @@ -607,14 +699,39 @@ class UpdateChannelsMixin: """Mixin class for Raw, Evoked, Epochs, Spectrum, AverageTFR.""" @verbose - @legacy(alt='inst.pick(...)') - def pick_types(self, meg=False, eeg=False, stim=False, eog=False, - ecg=False, emg=False, ref_meg='auto', *, misc=False, - resp=False, chpi=False, exci=False, ias=False, syst=False, - seeg=False, dipole=False, gof=False, bio=False, - ecog=False, fnirs=False, csd=False, dbs=False, - temperature=False, gsr=False, eyetrack=False, - include=(), exclude='bads', selection=None, verbose=None): + @legacy(alt="inst.pick(...)") + def pick_types( + self, + meg=False, + eeg=False, + stim=False, + eog=False, + ecg=False, + emg=False, + ref_meg="auto", + *, + misc=False, + resp=False, + chpi=False, + exci=False, + ias=False, + syst=False, + seeg=False, + dipole=False, + gof=False, + bio=False, + ecog=False, + fnirs=False, + csd=False, + dbs=False, + temperature=False, + gsr=False, + eyetrack=False, + include=(), + exclude="bads", + selection=None, + verbose=None, + ): """Pick some channels by type and names. Parameters @@ -636,24 +753,47 @@ def pick_types(self, meg=False, eeg=False, stim=False, eog=False, .. 
versionadded:: 0.9.0 """ idx = pick_types( - self.info, meg=meg, eeg=eeg, stim=stim, eog=eog, ecg=ecg, emg=emg, - ref_meg=ref_meg, misc=misc, resp=resp, chpi=chpi, exci=exci, - ias=ias, syst=syst, seeg=seeg, dipole=dipole, gof=gof, bio=bio, - ecog=ecog, fnirs=fnirs, csd=csd, dbs=dbs, temperature=temperature, - gsr=gsr, eyetrack=eyetrack, include=include, exclude=exclude, - selection=selection) + self.info, + meg=meg, + eeg=eeg, + stim=stim, + eog=eog, + ecg=ecg, + emg=emg, + ref_meg=ref_meg, + misc=misc, + resp=resp, + chpi=chpi, + exci=exci, + ias=ias, + syst=syst, + seeg=seeg, + dipole=dipole, + gof=gof, + bio=bio, + ecog=ecog, + fnirs=fnirs, + csd=csd, + dbs=dbs, + temperature=temperature, + gsr=gsr, + eyetrack=eyetrack, + include=include, + exclude=exclude, + selection=selection, + ) self._pick_drop_channels(idx) # remove dropped channel types from reject and flat - if getattr(self, 'reject', None) is not None: + if getattr(self, "reject", None) is not None: # use list(self.reject) to avoid RuntimeError for changing # dictionary size during iteration for ch_type in list(self.reject): if ch_type not in self: del self.reject[ch_type] - if getattr(self, 'flat', None) is not None: + if getattr(self, "flat", None) is not None: for ch_type in list(self.flat): if ch_type not in self: del self.flat[ch_type] @@ -661,7 +801,7 @@ def pick_types(self, meg=False, eeg=False, stim=False, eog=False, return self @verbose - @legacy(alt='inst.pick(...)') + @legacy(alt="inst.pick(...)") def pick_channels(self, ch_names, ordered=None, *, verbose=None): """Pick some channels. @@ -693,7 +833,7 @@ def pick_channels(self, ch_names, ordered=None, *, verbose=None): .. versionadded:: 0.9.0 """ - picks = pick_channels(self.info['ch_names'], ch_names, ordered=ordered) + picks = pick_channels(self.info["ch_names"], ch_names, ordered=ordered) return self._pick_drop_channels(picks) @verbose @@ -715,8 +855,7 @@ def pick(self, picks, exclude=(), *, verbose=None): inst : instance of Raw, Epochs, or Evoked The modified instance. """ - picks = _picks_to_idx(self.info, picks, 'all', exclude, - allow_empty=False) + picks = _picks_to_idx(self.info, picks, "all", exclude, allow_empty=False) return self._pick_drop_channels(picks) def reorder_channels(self, ch_names): @@ -750,12 +889,12 @@ def reorder_channels(self, ch_names): for ch_name in ch_names: ii = self.ch_names.index(ch_name) if ii in idx: - raise ValueError('Channel name repeated: %s' % (ch_name,)) + raise ValueError("Channel name repeated: %s" % (ch_name,)) idx.append(ii) return self._pick_drop_channels(idx) @fill_doc - def drop_channels(self, ch_names, on_missing='raise'): + def drop_channels(self, ch_names, on_missing="raise"): """Drop channel(s). Parameters @@ -785,20 +924,23 @@ def drop_channels(self, ch_names, on_missing='raise'): try: all_str = all([isinstance(ch, str) for ch in ch_names]) except TypeError: - raise ValueError("'ch_names' must be iterable, got " - "type {} ({}).".format(type(ch_names), ch_names)) + raise ValueError( + "'ch_names' must be iterable, got " + "type {} ({}).".format(type(ch_names), ch_names) + ) if not all_str: - raise ValueError("Each element in 'ch_names' must be str, got " - "{}.".format([type(ch) for ch in ch_names])) + raise ValueError( + "Each element in 'ch_names' must be str, got " + "{}.".format([type(ch) for ch in ch_names]) + ) missing = [ch for ch in ch_names if ch not in self.ch_names] if len(missing) > 0: msg = "Channel(s) {0} not found, nothing dropped." 
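         # Editor's usage sketch (hypothetical names, not part of this patch):
         # with on_missing="warn", unknown names only trigger the warning built
         # from ``msg`` below instead of raising, e.g.
         #     raw.drop_channels(["EEG 001", "NOT A CHANNEL"], on_missing="warn")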
_on_missing(on_missing, msg.format(", ".join(missing))) - bad_idx = [self.ch_names.index(ch) for ch in ch_names - if ch in self.ch_names] + bad_idx = [self.ch_names.index(ch) for ch in ch_names if ch in self.ch_names] idx = np.setdiff1d(np.arange(len(self.ch_names)), bad_idx) return self._pick_drop_channels(idx) @@ -809,45 +951,45 @@ def _pick_drop_channels(self, idx, *, verbose=None): from ..time_frequency import AverageTFR, EpochsTFR from ..time_frequency.spectrum import BaseSpectrum - msg = 'adding, dropping, or reordering channels' + msg = "adding, dropping, or reordering channels" if isinstance(self, BaseRaw): if self._projector is not None: - _check_preload(self, f'{msg} after calling .apply_proj()') + _check_preload(self, f"{msg} after calling .apply_proj()") else: _check_preload(self, msg) - if getattr(self, 'picks', None) is not None: + if getattr(self, "picks", None) is not None: self.picks = self.picks[idx] - if getattr(self, '_read_picks', None) is not None: + if getattr(self, "_read_picks", None) is not None: self._read_picks = [r[idx] for r in self._read_picks] - if hasattr(self, '_cals'): + if hasattr(self, "_cals"): self._cals = self._cals[idx] pick_info(self.info, idx, copy=False) - for key in ('_comp', '_projector'): + for key in ("_comp", "_projector"): mat = getattr(self, key, None) if mat is not None: setattr(self, key, mat[idx][:, idx]) if isinstance(self, BaseSpectrum): - axis = self._dims.index('channel') + axis = self._dims.index("channel") elif isinstance(self, (AverageTFR, EpochsTFR)): axis = -3 else: # All others (Evoked, Epochs, Raw) have chs axis=-2 axis = -2 - if hasattr(self, '_data'): # skip non-preloaded Raw + if hasattr(self, "_data"): # skip non-preloaded Raw self._data = self._data.take(idx, axis=axis) else: assert isinstance(self, BaseRaw) and not self.preload if isinstance(self, BaseRaw): - self.annotations._prune_ch_names(self.info, on_missing='ignore') + self.annotations._prune_ch_names(self.info, on_missing="ignore") self._orig_units = { - k: v for k, v in self._orig_units.items() - if k in self.ch_names} + k: v for k, v in self._orig_units.items() if k in self.ch_names + } self._pick_projs() return self @@ -855,14 +997,14 @@ def _pick_drop_channels(self, idx, *, verbose=None): def _pick_projs(self): """Keep only projectors which apply to at least 1 data channel.""" drop_idx = [] - for idx, proj in enumerate(self.info['projs']): - if not set(self.info['ch_names']) & set(proj['data']['col_names']): + for idx, proj in enumerate(self.info["projs"]): + if not set(self.info["ch_names"]) & set(proj["data"]["col_names"]): drop_idx.append(idx) for idx in drop_idx: logger.info(f"Removing projector {self.info['projs'][idx]}") - if drop_idx and hasattr(self, 'del_proj'): + if drop_idx and hasattr(self, "del_proj"): self.del_proj(drop_idx) return self @@ -900,7 +1042,7 @@ def add_channels(self, add_list, force_update_info=False): from ..io import BaseRaw, _merge_info from ..epochs import BaseEpochs - _validate_type(add_list, (list, tuple), 'Input') + _validate_type(add_list, (list, tuple), "Input") # Object-specific checks for inst in add_list + [self]: @@ -915,7 +1057,7 @@ def add_channels(self, add_list, force_update_info=False): con_axis = 0 comp_class = type(self) for inst in add_list: - _validate_type(inst, comp_class, 'All input') + _validate_type(inst, comp_class, "All input") data = [inst._data for inst in [self] + add_list] # Make sure that all dimensions other than channel axis are the same @@ -924,8 +1066,9 @@ def add_channels(self, add_list, 
force_update_info=False): for shape in shapes: if not ((shapes[0] - shape) == 0).all(): raise ValueError( - 'All data dimensions except channels must match, got ' - f'{shapes[0]} != {shape}') + "All data dimensions except channels must match, got " + f"{shapes[0]} != {shape}" + ) del shapes # Create final data / info objects @@ -933,43 +1076,50 @@ def add_channels(self, add_list, force_update_info=False): new_info = _merge_info(infos, force_update_to_first=force_update_info) # Now update the attributes - if isinstance(self._data, np.memmap) and con_axis == 0 and \ - sys.platform != 'darwin': # resizing not available--no mremap + if ( + isinstance(self._data, np.memmap) + and con_axis == 0 + and sys.platform != "darwin" + ): # resizing not available--no mremap # Use a resize and fill in other ones out_shape = (sum(d.shape[0] for d in data),) + data[0].shape[1:] n_bytes = np.prod(out_shape) * self._data.dtype.itemsize self._data.flush() self._data.base.resize(n_bytes) - self._data = np.memmap(self._data.filename, mode='r+', - dtype=self._data.dtype, shape=out_shape) + self._data = np.memmap( + self._data.filename, mode="r+", dtype=self._data.dtype, shape=out_shape + ) assert self._data.shape == out_shape assert self._data.nbytes == n_bytes offset = len(data[0]) for d in data[1:]: this_len = len(d) - self._data[offset:offset + this_len] = d + self._data[offset : offset + this_len] = d offset += this_len else: self._data = np.concatenate(data, axis=con_axis) self.info = new_info if isinstance(self, BaseRaw): - self._cals = np.concatenate([getattr(inst, '_cals') - for inst in [self] + add_list]) + self._cals = np.concatenate( + [getattr(inst, "_cals") for inst in [self] + add_list] + ) # We should never use these since data are preloaded, let's just # set it to something large and likely to break (2 ** 31 - 1) - extra_idx = [2147483647] * sum(info['nchan'] for info in infos[1:]) - assert all(len(r) == infos[0]['nchan'] for r in self._read_picks) + extra_idx = [2147483647] * sum(info["nchan"] for info in infos[1:]) + assert all(len(r) == infos[0]["nchan"] for r in self._read_picks) self._read_picks = [ - np.concatenate([r, extra_idx]) for r in self._read_picks] - assert all(len(r) == self.info['nchan'] for r in self._read_picks) + np.concatenate([r, extra_idx]) for r in self._read_picks + ] + assert all(len(r) == self.info["nchan"] for r in self._read_picks) for other in add_list: self._orig_units.update(other._orig_units) elif isinstance(self, BaseEpochs): self.picks = np.arange(self._data.shape[1]) - if hasattr(self, '_projector'): + if hasattr(self, "_projector"): activate = False if self._do_delayed_proj else self.proj - self._projector, self.info = setup_proj(self.info, False, - activate=activate) + self._projector, self.info = setup_proj( + self.info, False, activate=activate + ) return self @@ -999,9 +1149,15 @@ class InterpolationMixin: """Mixin class for Raw, Evoked, Epochs.""" @verbose - def interpolate_bads(self, reset_bads=True, mode='accurate', - origin='auto', method=None, exclude=(), - verbose=None): + def interpolate_bads( + self, + reset_bads=True, + mode="accurate", + origin="auto", + method=None, + exclude=(), + verbose=None, + ): """Interpolate bad MEG and EEG channels. Operates in place. @@ -1052,34 +1208,37 @@ def interpolate_bads(self, reset_bads=True, mode='accurate', .. 
versionadded:: 0.9.0 """ from ..bem import _check_origin - from .interpolation import _interpolate_bads_eeg,\ - _interpolate_bads_meeg, _interpolate_bads_nirs + from .interpolation import ( + _interpolate_bads_eeg, + _interpolate_bads_meeg, + _interpolate_bads_nirs, + ) _check_preload(self, "interpolation") - method = _handle_default('interpolation_method', method) + method = _handle_default("interpolation_method", method) for key in method: - _check_option('method[key]', key, ('meg', 'eeg', 'fnirs')) - _check_option("method['eeg']", method['eeg'], ('spline', 'MNE')) - _check_option("method['meg']", method['meg'], ('MNE',)) - _check_option("method['fnirs']", method['fnirs'], ('nearest',)) + _check_option("method[key]", key, ("meg", "eeg", "fnirs")) + _check_option("method['eeg']", method["eeg"], ("spline", "MNE")) + _check_option("method['meg']", method["meg"], ("MNE",)) + _check_option("method['fnirs']", method["fnirs"], ("nearest",)) - if len(self.info['bads']) == 0: - warn('No bad channels to interpolate. Doing nothing...') + if len(self.info["bads"]) == 0: + warn("No bad channels to interpolate. Doing nothing...") return self - logger.info('Interpolating bad channels') + logger.info("Interpolating bad channels") origin = _check_origin(origin, self.info) - if method['eeg'] == 'spline': + if method["eeg"] == "spline": _interpolate_bads_eeg(self, origin=origin, exclude=exclude) eeg_mne = False else: eeg_mne = True - _interpolate_bads_meeg(self, mode=mode, origin=origin, eeg=eeg_mne, - exclude=exclude) + _interpolate_bads_meeg( + self, mode=mode, origin=origin, eeg=eeg_mne, exclude=exclude + ) _interpolate_bads_nirs(self, exclude=exclude) if reset_bads is True: - self.info['bads'] = \ - [ch for ch in self.info['bads'] if ch in exclude] + self.info["bads"] = [ch for ch in self.info["bads"] if ch in exclude] return self @@ -1094,27 +1253,30 @@ def rename_channels(info, mapping, allow_duplicates=False, verbose=None): %(mapping_rename_channels_duplicates)s %(verbose)s """ - _validate_type(info, Info, 'info') + _validate_type(info, Info, "info") info._check_consistency() - bads = list(info['bads']) # make our own local copies - ch_names = list(info['ch_names']) + bads = list(info["bads"]) # make our own local copies + ch_names = list(info["ch_names"]) # first check and assemble clean mappings of index and name if isinstance(mapping, dict): - _check_dict_keys(mapping, ch_names, key_description="channel name(s)", - valid_key_source="info") - new_names = [(ch_names.index(ch_name), new_name) - for ch_name, new_name in mapping.items()] + _check_dict_keys( + mapping, + ch_names, + key_description="channel name(s)", + valid_key_source="info", + ) + new_names = [ + (ch_names.index(ch_name), new_name) for ch_name, new_name in mapping.items() + ] elif callable(mapping): - new_names = [(ci, mapping(ch_name)) - for ci, ch_name in enumerate(ch_names)] + new_names = [(ci, mapping(ch_name)) for ci, ch_name in enumerate(ch_names)] else: - raise ValueError('mapping must be callable or dict, not %s' - % (type(mapping),)) + raise ValueError("mapping must be callable or dict, not %s" % (type(mapping),)) # check we got all strings out of the mapping for new_name in new_names: - _validate_type(new_name[1], 'str', 'New channel mappings') + _validate_type(new_name[1], "str", "New channel mappings") # do the remapping locally for c_ind, new_name in new_names: @@ -1125,20 +1287,21 @@ def rename_channels(info, mapping, allow_duplicates=False, verbose=None): # check that all the channel names are unique if len(ch_names) 
!= len(np.unique(ch_names)) and not allow_duplicates: - raise ValueError('New channel names are not unique, renaming failed') + raise ValueError("New channel names are not unique, renaming failed") # do the remapping in info - info['bads'] = bads + info["bads"] = bads ch_names_mapping = dict() - for ch, ch_name in zip(info['chs'], ch_names): - ch_names_mapping[ch['ch_name']] = ch_name - ch['ch_name'] = ch_name + for ch, ch_name in zip(info["chs"], ch_names): + ch_names_mapping[ch["ch_name"]] = ch_name + ch["ch_name"] = ch_name # .get b/c fwd info omits it - _rename_comps(info.get('comps', []), ch_names_mapping) - if 'projs' in info: # fwd might omit it - for proj in info['projs']: - proj['data']['col_names'][:] = \ - _rename_list(proj['data']['col_names'], ch_names_mapping) + _rename_comps(info.get("comps", []), ch_names_mapping) + if "projs" in info: # fwd might omit it + for proj in info["projs"]: + proj["data"]["col_names"][:] = _rename_list( + proj["data"]["col_names"], ch_names_mapping + ) info._update_redundant() info._check_consistency() @@ -1160,244 +1323,277 @@ class _BuiltinChannelAdjacency: _ft_neighbor_url_t = string.Template( - 'https://github.com/fieldtrip/fieldtrip/raw/master/' - 'template/neighbours/$fname' + "https://github.com/fieldtrip/fieldtrip/raw/master/" "template/neighbours/$fname" ) _BUILTIN_CHANNEL_ADJACENCIES = [ _BuiltinChannelAdjacency( - name='biosemi16', - description='Biosemi 16-electrode cap', - fname='biosemi16_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='biosemi16_neighb.mat'), + name="biosemi16", + description="Biosemi 16-electrode cap", + fname="biosemi16_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="biosemi16_neighb.mat"), ), _BuiltinChannelAdjacency( - name='biosemi32', - description='Biosemi 32-electrode cap', - fname='biosemi32_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='biosemi32_neighb.mat'), + name="biosemi32", + description="Biosemi 32-electrode cap", + fname="biosemi32_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="biosemi32_neighb.mat"), ), _BuiltinChannelAdjacency( - name='biosemi64', - description='Biosemi 64-electrode cap', - fname='biosemi64_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='biosemi64_neighb.mat'), + name="biosemi64", + description="Biosemi 64-electrode cap", + fname="biosemi64_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="biosemi64_neighb.mat"), ), _BuiltinChannelAdjacency( - name='bti148', - description='BTI 148-channel system', - fname='bti148_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='bti148_neighb.mat'), + name="bti148", + description="BTI 148-channel system", + fname="bti148_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="bti148_neighb.mat"), ), _BuiltinChannelAdjacency( - name='bti248', - description='BTI 248-channel system', - fname='bti248_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='bti248_neighb.mat'), + name="bti248", + description="BTI 248-channel system", + fname="bti248_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="bti248_neighb.mat"), ), _BuiltinChannelAdjacency( - name='bti248grad', - description='BTI 248 gradiometer system', - fname='bti248grad_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='bti248grad_neighb.mat'), # noqa: E501 + name="bti248grad", + description="BTI 248 gradiometer system", + fname="bti248grad_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="bti248grad_neighb.mat" + ), # 
noqa: E501 ), _BuiltinChannelAdjacency( - name='ctf64', - description='CTF 64 axial gradiometer', - fname='ctf64_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ctf64_neighb.mat'), + name="ctf64", + description="CTF 64 axial gradiometer", + fname="ctf64_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="ctf64_neighb.mat"), ), _BuiltinChannelAdjacency( - name='ctf151', - description='CTF 151 axial gradiometer', - fname='ctf151_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ctf151_neighb.mat'), + name="ctf151", + description="CTF 151 axial gradiometer", + fname="ctf151_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="ctf151_neighb.mat"), ), _BuiltinChannelAdjacency( - name='ctf275', - description='CTF 275 axial gradiometer', - fname='ctf275_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ctf275_neighb.mat'), + name="ctf275", + description="CTF 275 axial gradiometer", + fname="ctf275_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="ctf275_neighb.mat"), ), _BuiltinChannelAdjacency( - name='easycap32ch-avg', - description='', - fname='easycap32ch-avg_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycap32ch-avg_neighb.mat'), # noqa: E501 + name="easycap32ch-avg", + description="", + fname="easycap32ch-avg_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycap32ch-avg_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycap64ch-avg', - description='', - fname='easycap64ch-avg_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycap64ch-avg_neighb.mat'), # noqa: E501 + name="easycap64ch-avg", + description="", + fname="easycap64ch-avg_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycap64ch-avg_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycap128ch-avg', - description='', - fname='easycap128ch-avg_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycap128ch-avg_neighb.mat'), # noqa: E501 + name="easycap128ch-avg", + description="", + fname="easycap128ch-avg_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycap128ch-avg_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycapM1', - description='Easycap M1', - fname='easycapM1_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycapM1_neighb.mat'), + name="easycapM1", + description="Easycap M1", + fname="easycapM1_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="easycapM1_neighb.mat"), ), _BuiltinChannelAdjacency( - name='easycapM11', - description='Easycap M11', - fname='easycapM11_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycapM11_neighb.mat'), # noqa: E501 + name="easycapM11", + description="Easycap M11", + fname="easycapM11_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycapM11_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycapM14', - description='Easycap M14', - fname='easycapM14_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycapM14_neighb.mat'), # noqa: E501 + name="easycapM14", + description="Easycap M14", + fname="easycapM14_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycapM14_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycapM15', - description='Easycap M15', - fname='easycapM15_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycapM15_neighb.mat'), # noqa: E501 + name="easycapM15", + description="Easycap M15", + 
fname="easycapM15_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycapM15_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='KIT-157', - description='', - fname='KIT-157_neighb.mat', + name="KIT-157", + description="", + fname="KIT-157_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-208', - description='', - fname='KIT-208_neighb.mat', + name="KIT-208", + description="", + fname="KIT-208_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-NYU-2019', - description='', - fname='KIT-NYU-2019_neighb.mat', + name="KIT-NYU-2019", + description="", + fname="KIT-NYU-2019_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-UMD-1', - description='', - fname='KIT-UMD-1_neighb.mat', + name="KIT-UMD-1", + description="", + fname="KIT-UMD-1_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-UMD-2', - description='', - fname='KIT-UMD-2_neighb.mat', + name="KIT-UMD-2", + description="", + fname="KIT-UMD-2_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-UMD-3', - description='', - fname='KIT-UMD-3_neighb.mat', + name="KIT-UMD-3", + description="", + fname="KIT-UMD-3_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-UMD-4', - description='', - fname='KIT-UMD-4_neighb.mat', + name="KIT-UMD-4", + description="", + fname="KIT-UMD-4_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='neuromag306mag', - description='Neuromag306, only magnetometers', - fname='neuromag306mag_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='neuromag306mag_neighb.mat'), # noqa: E501 + name="neuromag306mag", + description="Neuromag306, only magnetometers", + fname="neuromag306mag_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="neuromag306mag_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='neuromag306planar', - description='Neuromag306, only planar gradiometers', - fname='neuromag306planar_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='neuromag306planar_neighb.mat'), # noqa: E501 + name="neuromag306planar", + description="Neuromag306, only planar gradiometers", + fname="neuromag306planar_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="neuromag306planar_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='neuromag122cmb', - description='Neuromag122, only combined planar gradiometers', - fname='neuromag122cmb_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='neuromag122cmb_neighb.mat'), # noqa: E501 + name="neuromag122cmb", + description="Neuromag122, only combined planar gradiometers", + fname="neuromag122cmb_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="neuromag122cmb_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='neuromag306cmb', - description='Neuromag306, only combined planar gradiometers', - fname='neuromag306cmb_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='neuromag306cmb_neighb.mat'), # noqa: E501 + name="neuromag306cmb", + description="Neuromag306, only combined planar gradiometers", + fname="neuromag306cmb_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="neuromag306cmb_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='ecog256', - description='ECOG 256channels, average referenced', - fname='ecog256_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ecog256_neighb.mat'), # noqa: E501 + name="ecog256", + description="ECOG 256channels, average 
referenced", + fname="ecog256_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="ecog256_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='ecog256bipolar', - description='ECOG 256channels, bipolar referenced', - fname='ecog256bipolar_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ecog256bipolar_neighb.mat'), # noqa: E501 + name="ecog256bipolar", + description="ECOG 256channels, bipolar referenced", + fname="ecog256bipolar_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="ecog256bipolar_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='eeg1010_neighb', - description='', - fname='eeg1010_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='eeg1010_neighb.mat'), + name="eeg1010_neighb", + description="", + fname="eeg1010_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="eeg1010_neighb.mat"), ), _BuiltinChannelAdjacency( - name='elec1005', - description='Standard 10-05 system', - fname='elec1005_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='elec1005_neighb.mat'), + name="elec1005", + description="Standard 10-05 system", + fname="elec1005_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="elec1005_neighb.mat"), ), _BuiltinChannelAdjacency( - name='elec1010', - description='Standard 10-10 system', - fname='elec1010_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='elec1010_neighb.mat'), + name="elec1010", + description="Standard 10-10 system", + fname="elec1010_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="elec1010_neighb.mat"), ), _BuiltinChannelAdjacency( - name='elec1020', - description='Standard 10-20 system', - fname='elec1020_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='elec1020_neighb.mat'), + name="elec1020", + description="Standard 10-20 system", + fname="elec1020_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="elec1020_neighb.mat"), ), _BuiltinChannelAdjacency( - name='itab28', - description='ITAB 28-channel system', - fname='itab28_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='itab28_neighb.mat'), + name="itab28", + description="ITAB 28-channel system", + fname="itab28_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="itab28_neighb.mat"), ), _BuiltinChannelAdjacency( - name='itab153', - description='ITAB 153-channel system', - fname='itab153_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='itab153_neighb.mat'), + name="itab153", + description="ITAB 153-channel system", + fname="itab153_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="itab153_neighb.mat"), ), _BuiltinChannelAdjacency( - name='language29ch-avg', - description='MPI for Psycholinguistic: Averaged 29-channel cap', - fname='language29ch-avg_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='language29ch-avg_neighb.mat'), # noqa: E501 + name="language29ch-avg", + description="MPI for Psycholinguistic: Averaged 29-channel cap", + fname="language29ch-avg_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="language29ch-avg_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='mpi_59_channels', - description='MPI for Psycholinguistic: 59-channel cap', - fname='mpi_59_channels_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='mpi_59_channels_neighb.mat'), # noqa: E501 + name="mpi_59_channels", + description="MPI for Psycholinguistic: 59-channel cap", + fname="mpi_59_channels_neighb.mat", + 
source_url=_ft_neighbor_url_t.substitute( + fname="mpi_59_channels_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='yokogawa160', - description='', - fname='yokogawa160_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='yokogawa160_neighb.mat'), # noqa: E501 + name="yokogawa160", + description="", + fname="yokogawa160_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="yokogawa160_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='yokogawa440', - description='', - fname='yokogawa440_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='yokogawa440_neighb.mat'), # noqa: E501 + name="yokogawa440", + description="", + fname="yokogawa440_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="yokogawa440_neighb.mat" + ), # noqa: E501 ), ] @@ -1433,13 +1629,10 @@ def get_builtin_ch_adjacencies(*, descriptions=False): if descriptions: return sorted( [(m.name, m.description) for m in _BUILTIN_CHANNEL_ADJACENCIES], - key=lambda x: x[0].casefold() # only sort based on name + key=lambda x: x[0].casefold(), # only sort based on name ) else: - return sorted( - [m.name for m in _BUILTIN_CHANNEL_ADJACENCIES], - key=str.casefold - ) + return sorted([m.name for m in _BUILTIN_CHANNEL_ADJACENCIES], key=str.casefold) @fill_doc @@ -1488,6 +1681,7 @@ def read_ch_adjacency(fname, picks=None): to pass to the eventual function. """ from scipy.io import loadmat + if op.isabs(fname): fname = str( _check_fname( @@ -1499,20 +1693,19 @@ def read_ch_adjacency(fname, picks=None): else: # built-in FieldTrip neighbors ch_adj_name = fname del fname - if ch_adj_name.endswith('_neighb.mat'): # backward-compat - ch_adj_name = ch_adj_name.replace('_neighb.mat', '') + if ch_adj_name.endswith("_neighb.mat"): # backward-compat + ch_adj_name = ch_adj_name.replace("_neighb.mat", "") if ch_adj_name not in get_builtin_ch_adjacencies(): raise ValueError( - f'No built-in channel adjacency matrix found with name: ' - f'{ch_adj_name}. Valid names are: ' + f"No built-in channel adjacency matrix found with name: " + f"{ch_adj_name}. 
Valid names are: " f'{", ".join(get_builtin_ch_adjacencies())}' ) - ch_adj = [a for a in _BUILTIN_CHANNEL_ADJACENCIES - if a.name == ch_adj_name][0] + ch_adj = [a for a in _BUILTIN_CHANNEL_ADJACENCIES if a.name == ch_adj_name][0] fname = ch_adj.fname - templates_dir = Path(__file__).resolve().parent / 'data' / 'neighbors' + templates_dir = Path(__file__).resolve().parent / "data" / "neighbors" fname = str( _check_fname( # only needed to convert to a string fname=templates_dir / fname, @@ -1521,11 +1714,10 @@ def read_ch_adjacency(fname, picks=None): ) ) - nb = loadmat(fname)['neighbours'] - ch_names = _recursive_flatten(nb['label'], str) + nb = loadmat(fname)["neighbours"] + ch_names = _recursive_flatten(nb["label"], str) picks = _picks_to_idx(len(ch_names), picks) - neighbors = [_recursive_flatten(c, str) for c in - nb['neighblabel'].flatten()] + neighbors = [_recursive_flatten(c, str) for c in nb["neighblabel"].flatten()] assert len(ch_names) == len(neighbors) adjacency = _ch_neighbor_adjacency(ch_names, neighbors) # picking before constructing matrix is buggy @@ -1534,8 +1726,8 @@ def read_ch_adjacency(fname, picks=None): # make sure MEG channel names contain space after "MEG" for idx, ch_name in enumerate(ch_names): - if ch_name.startswith('MEG') and not ch_name[3] == ' ': - ch_name = ch_name.replace('MEG', 'MEG ') + if ch_name.startswith("MEG") and not ch_name[3] == " ": + ch_name = ch_name.replace("MEG", "MEG ") ch_names[idx] = ch_name return adjacency, ch_names @@ -1559,19 +1751,19 @@ def _ch_neighbor_adjacency(ch_names, neighbors): The adjacency matrix. """ from scipy import sparse + if len(ch_names) != len(neighbors): - raise ValueError('`ch_names` and `neighbors` must ' - 'have the same length') + raise ValueError("`ch_names` and `neighbors` must " "have the same length") set_neighbors = {c for d in neighbors for c in d} rest = set_neighbors - set(ch_names) if len(rest) > 0: - raise ValueError('Some of your neighbors are not present in the ' - 'list of channel names') + raise ValueError( + "Some of your neighbors are not present in the " "list of channel names" + ) for neigh in neighbors: - if (not isinstance(neigh, list) and - not all(isinstance(c, str) for c in neigh)): - raise ValueError('`neighbors` must be a list of lists of str') + if not isinstance(neigh, list) and not all(isinstance(c, str) for c in neigh): + raise ValueError("`neighbors` must be a list of lists of str") ch_adjacency = np.eye(len(ch_names), dtype=bool) for ii, neigbs in enumerate(neighbors): @@ -1634,49 +1826,64 @@ def find_ch_adjacency(info, ch_type): if ch_type is None: picks = channel_indices_by_type(info) if sum([len(p) != 0 for p in picks.values()]) != 1: - raise ValueError('info must contain only one channel type if ' - 'ch_type is None.') + raise ValueError( + "info must contain only one channel type if " "ch_type is None." 
+        )
         ch_type = channel_type(info, 0)
     else:
-        _check_option('ch_type', ch_type, ['mag', 'grad', 'eeg'])
-    (has_vv_mag, has_vv_grad, is_old_vv, has_4D_mag, ctf_other_types,
-     has_CTF_grad, n_kit_grads, has_any_meg, has_eeg_coils,
-     has_eeg_coils_and_meg, has_eeg_coils_only,
-     has_neuromag_122_grad, has_csd_coils) = _get_ch_info(info)
+        _check_option("ch_type", ch_type, ["mag", "grad", "eeg"])
+    (
+        has_vv_mag,
+        has_vv_grad,
+        is_old_vv,
+        has_4D_mag,
+        ctf_other_types,
+        has_CTF_grad,
+        n_kit_grads,
+        has_any_meg,
+        has_eeg_coils,
+        has_eeg_coils_and_meg,
+        has_eeg_coils_only,
+        has_neuromag_122_grad,
+        has_csd_coils,
+    ) = _get_ch_info(info)
     conn_name = None
-    if has_vv_mag and ch_type == 'mag':
-        conn_name = 'neuromag306mag'
-    elif has_vv_grad and ch_type == 'grad':
-        conn_name = 'neuromag306planar'
+    if has_vv_mag and ch_type == "mag":
+        conn_name = "neuromag306mag"
+    elif has_vv_grad and ch_type == "grad":
+        conn_name = "neuromag306planar"
     elif has_4D_mag:
-        if 'MEG 248' in info['ch_names']:
-            idx = info['ch_names'].index('MEG 248')
-            grad = info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_MAGNES_GRAD
-            mag = info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_MAGNES_MAG
-            if ch_type == 'grad' and grad:
-                conn_name = 'bti248grad'
-            elif ch_type == 'mag' and mag:
-                conn_name = 'bti248'
-        elif 'MEG 148' in info['ch_names'] and ch_type == 'mag':
-            idx = info['ch_names'].index('MEG 148')
-            if info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_MAGNES_MAG:
-                conn_name = 'bti148'
-    elif has_CTF_grad and ch_type == 'mag':
-        if info['nchan'] < 100:
-            conn_name = 'ctf64'
-        elif info['nchan'] > 200:
-            conn_name = 'ctf275'
+        if "MEG 248" in info["ch_names"]:
+            idx = info["ch_names"].index("MEG 248")
+            grad = info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_MAGNES_GRAD
+            mag = info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_MAGNES_MAG
+            if ch_type == "grad" and grad:
+                conn_name = "bti248grad"
+            elif ch_type == "mag" and mag:
+                conn_name = "bti248"
+        elif "MEG 148" in info["ch_names"] and ch_type == "mag":
+            idx = info["ch_names"].index("MEG 148")
+            if info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_MAGNES_MAG:
+                conn_name = "bti148"
+    elif has_CTF_grad and ch_type == "mag":
+        if info["nchan"] < 100:
+            conn_name = "ctf64"
+        elif info["nchan"] > 200:
+            conn_name = "ctf275"
         else:
-            conn_name = 'ctf151'
+            conn_name = "ctf151"
     elif n_kit_grads > 0:
         from ..io.kit.constants import KIT_NEIGHBORS
-        conn_name = KIT_NEIGHBORS.get(info['kit_system_id'])
+
+        conn_name = KIT_NEIGHBORS.get(info["kit_system_id"])
 
     if conn_name is not None:
-        logger.info(f'Reading adjacency matrix for {conn_name}.')
+        logger.info(f"Reading adjacency matrix for {conn_name}.")
         return read_ch_adjacency(conn_name)
-    logger.info('Could not find a adjacency matrix for the data. '
-                'Computing adjacency based on Delaunay triangulations.')
+    logger.info(
+        "Could not find an adjacency matrix for the data. "
+        "Computing adjacency based on Delaunay triangulations."
+    )
    return _compute_ch_adjacency(info, ch_type)
 
 
@@ -1702,21 +1909,24 @@ def _compute_ch_adjacency(info, ch_type):
     from scipy.spatial import Delaunay
     from .. 
import spatial_tris_adjacency from ..channels.layout import _find_topomap_coords, _pair_grad_sensors - combine_grads = (ch_type == 'grad' - and any([coil_type in [ch['coil_type'] - for ch in info['chs']] - for coil_type in - [FIFF.FIFFV_COIL_VV_PLANAR_T1, - FIFF.FIFFV_COIL_NM_122]])) + + combine_grads = ch_type == "grad" and any( + [ + coil_type in [ch["coil_type"] for ch in info["chs"]] + for coil_type in [FIFF.FIFFV_COIL_VV_PLANAR_T1, FIFF.FIFFV_COIL_NM_122] + ] + ) picks = dict(_picks_by_type(info, exclude=[]))[ch_type] - ch_names = [info['ch_names'][pick] for pick in picks] + ch_names = [info["ch_names"][pick] for pick in picks] if combine_grads: pairs = _pair_grad_sensors(info, topomap_coords=False, exclude=[]) if len(pairs) != len(picks): - raise RuntimeError('Cannot find a pair for some of the ' - 'gradiometers. Cannot compute adjacency ' - 'matrix.') + raise RuntimeError( + "Cannot find a pair for some of the " + "gradiometers. Cannot compute adjacency " + "matrix." + ) # only for one of the pair xy = _find_topomap_coords(info, picks[::2], sphere=HEAD_SIZE_DEFAULT) else: @@ -1774,26 +1984,26 @@ def fix_mag_coil_types(info, use_cal=False): old_mag_inds = _get_T1T2_mag_inds(info, use_cal) for ii in old_mag_inds: - info['chs'][ii]['coil_type'] = FIFF.FIFFV_COIL_VV_MAG_T3 - logger.info('%d of %d magnetometer types replaced with T3.' % - (len(old_mag_inds), - len(pick_types(info, meg='mag', exclude=[])))) + info["chs"][ii]["coil_type"] = FIFF.FIFFV_COIL_VV_MAG_T3 + logger.info( + "%d of %d magnetometer types replaced with T3." + % (len(old_mag_inds), len(pick_types(info, meg="mag", exclude=[]))) + ) info._check_consistency() def _get_T1T2_mag_inds(info, use_cal=False): """Find T1/T2 magnetometer coil types.""" - picks = pick_types(info, meg='mag', exclude=[]) + picks = pick_types(info, meg="mag", exclude=[]) old_mag_inds = [] # From email exchanges, systems with the larger T2 coil only use the cal # value of 2.09e-11. Newer T3 magnetometers use 4.13e-11 or 1.33e-10 # (Triux). So we can use a simple check for > 3e-11. 
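     # Editor's worked example (from the calibration values quoted above, with
     # use_cal=True): a channel typed T1/T2 whose cal is 2.09e-11 really is an
     # old T2 coil and is skipped, whereas one whose cal is 1.33e-10 (Triux)
     # exceeds 3e-11, so it is collected below and later replaced with T3.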
for ii in picks: - ch = info['chs'][ii] - if ch['coil_type'] in (FIFF.FIFFV_COIL_VV_MAG_T1, - FIFF.FIFFV_COIL_VV_MAG_T2): + ch = info["chs"][ii] + if ch["coil_type"] in (FIFF.FIFFV_COIL_VV_MAG_T1, FIFF.FIFFV_COIL_VV_MAG_T2): if use_cal: - if ch['cal'] > 3e-11: + if ch["cal"] > 3e-11: old_mag_inds.append(ii) else: old_mag_inds.append(ii) @@ -1802,47 +2012,72 @@ def _get_T1T2_mag_inds(info, use_cal=False): def _get_ch_info(info): """Get channel info for inferring acquisition device.""" - chs = info['chs'] + chs = info["chs"] # Only take first 16 bits, as higher bits store CTF comp order - coil_types = {ch['coil_type'] & 0xFFFF for ch in chs} - channel_types = {ch['kind'] for ch in chs} - - has_vv_mag = any(k in coil_types for k in - [FIFF.FIFFV_COIL_VV_MAG_T1, FIFF.FIFFV_COIL_VV_MAG_T2, - FIFF.FIFFV_COIL_VV_MAG_T3]) - has_vv_grad = any(k in coil_types for k in [FIFF.FIFFV_COIL_VV_PLANAR_T1, - FIFF.FIFFV_COIL_VV_PLANAR_T2, - FIFF.FIFFV_COIL_VV_PLANAR_T3]) - has_neuromag_122_grad = any(k in coil_types - for k in [FIFF.FIFFV_COIL_NM_122]) - - is_old_vv = ' ' in chs[0]['ch_name'] + coil_types = {ch["coil_type"] & 0xFFFF for ch in chs} + channel_types = {ch["kind"] for ch in chs} + + has_vv_mag = any( + k in coil_types + for k in [ + FIFF.FIFFV_COIL_VV_MAG_T1, + FIFF.FIFFV_COIL_VV_MAG_T2, + FIFF.FIFFV_COIL_VV_MAG_T3, + ] + ) + has_vv_grad = any( + k in coil_types + for k in [ + FIFF.FIFFV_COIL_VV_PLANAR_T1, + FIFF.FIFFV_COIL_VV_PLANAR_T2, + FIFF.FIFFV_COIL_VV_PLANAR_T3, + ] + ) + has_neuromag_122_grad = any(k in coil_types for k in [FIFF.FIFFV_COIL_NM_122]) + + is_old_vv = " " in chs[0]["ch_name"] has_4D_mag = FIFF.FIFFV_COIL_MAGNES_MAG in coil_types - ctf_other_types = (FIFF.FIFFV_COIL_CTF_REF_MAG, - FIFF.FIFFV_COIL_CTF_REF_GRAD, - FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD) - has_CTF_grad = (FIFF.FIFFV_COIL_CTF_GRAD in coil_types or - (FIFF.FIFFV_MEG_CH in channel_types and - any(k in ctf_other_types for k in coil_types))) + ctf_other_types = ( + FIFF.FIFFV_COIL_CTF_REF_MAG, + FIFF.FIFFV_COIL_CTF_REF_GRAD, + FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD, + ) + has_CTF_grad = FIFF.FIFFV_COIL_CTF_GRAD in coil_types or ( + FIFF.FIFFV_MEG_CH in channel_types + and any(k in ctf_other_types for k in coil_types) + ) # hack due to MNE-C bug in IO of CTF # only take first 16 bits, as higher bits store CTF comp order - n_kit_grads = sum(ch['coil_type'] & 0xFFFF == FIFF.FIFFV_COIL_KIT_GRAD - for ch in chs) - - has_any_meg = any([has_vv_mag, has_vv_grad, has_4D_mag, has_CTF_grad, - n_kit_grads]) - has_eeg_coils = (FIFF.FIFFV_COIL_EEG in coil_types and - FIFF.FIFFV_EEG_CH in channel_types) + n_kit_grads = sum( + ch["coil_type"] & 0xFFFF == FIFF.FIFFV_COIL_KIT_GRAD for ch in chs + ) + + has_any_meg = any([has_vv_mag, has_vv_grad, has_4D_mag, has_CTF_grad, n_kit_grads]) + has_eeg_coils = ( + FIFF.FIFFV_COIL_EEG in coil_types and FIFF.FIFFV_EEG_CH in channel_types + ) has_eeg_coils_and_meg = has_eeg_coils and has_any_meg has_eeg_coils_only = has_eeg_coils and not has_any_meg - has_csd_coils = (FIFF.FIFFV_COIL_EEG_CSD in coil_types and - FIFF.FIFFV_EEG_CH in channel_types) - - return (has_vv_mag, has_vv_grad, is_old_vv, has_4D_mag, ctf_other_types, - has_CTF_grad, n_kit_grads, has_any_meg, has_eeg_coils, - has_eeg_coils_and_meg, has_eeg_coils_only, has_neuromag_122_grad, - has_csd_coils) + has_csd_coils = ( + FIFF.FIFFV_COIL_EEG_CSD in coil_types and FIFF.FIFFV_EEG_CH in channel_types + ) + + return ( + has_vv_mag, + has_vv_grad, + is_old_vv, + has_4D_mag, + ctf_other_types, + has_CTF_grad, + n_kit_grads, + has_any_meg, + 
has_eeg_coils, + has_eeg_coils_and_meg, + has_eeg_coils_only, + has_neuromag_122_grad, + has_csd_coils, + ) @fill_doc @@ -1884,6 +2119,7 @@ def make_1020_channel_selections(info, midline="z", *, return_ch_names=False): try: from .layout import find_layout + layout = find_layout(info) pos = layout.pos ch_names = layout.names @@ -1905,8 +2141,10 @@ def make_1020_channel_selections(info, midline="z", *, return_ch_names=False): if pos is not None: # sort channels from front to center # (y-coordinate of the position info in the layout) - selections = {selection: np.array(picks)[pos[picks, 1].argsort()] - for selection, picks in selections.items()} + selections = { + selection: np.array(picks)[pos[picks, 1].argsort()] + for selection, picks in selections.items() + } # convert channel indices to names if requested if return_ch_names: @@ -1917,8 +2155,9 @@ def make_1020_channel_selections(info, midline="z", *, return_ch_names=False): @verbose -def combine_channels(inst, groups, method='mean', keep_stim=False, - drop_bad=False, verbose=None): +def combine_channels( + inst, groups, method="mean", keep_stim=False, drop_bad=False, verbose=None +): """Combine channels based on specified channel grouping. Parameters @@ -1969,8 +2208,8 @@ def combine_channels(inst, groups, method='mean', keep_stim=False, from .. import BaseEpochs, EpochsArray, Evoked, EvokedArray ch_axis = 1 if isinstance(inst, BaseEpochs) else 0 - ch_idx = list(range(inst.info['nchan'])) - ch_names = inst.info['ch_names'] + ch_idx = list(range(inst.info["nchan"])) + ch_names = inst.info["ch_names"] ch_types = inst.get_channel_types() inst_data = inst.data if isinstance(inst, Evoked) else inst.get_data() groups = OrderedDict(deepcopy(groups)) @@ -1978,99 +2217,121 @@ def combine_channels(inst, groups, method='mean', keep_stim=False, # Convert string values of ``method`` into callables # XXX Possibly de-duplicate with _make_combine_callable of mne/viz/utils.py if isinstance(method, str): - method_dict = {key: partial(getattr(np, key), axis=ch_axis) - for key in ('mean', 'median', 'std')} + method_dict = { + key: partial(getattr(np, key), axis=ch_axis) + for key in ("mean", "median", "std") + } try: method = method_dict[method] except KeyError: - raise ValueError('"method" must be a callable, or one of "mean", ' - f'"median", or "std"; got "{method}".') + raise ValueError( + '"method" must be a callable, or one of "mean", ' + f'"median", or "std"; got "{method}".' 
+ ) # Instantiate channel info and data new_ch_names, new_ch_types, new_data = [], [], [] if not isinstance(keep_stim, bool): - raise TypeError('"keep_stim" must be of type bool, not ' - f'{type(keep_stim)}.') + raise TypeError('"keep_stim" must be of type bool, not ' f"{type(keep_stim)}.") if keep_stim: stim_ch_idx = list(pick_types(inst.info, meg=False, stim=True)) if stim_ch_idx: new_ch_names = [ch_names[idx] for idx in stim_ch_idx] new_ch_types = [ch_types[idx] for idx in stim_ch_idx] - new_data = [np.take(inst_data, idx, axis=ch_axis) - for idx in stim_ch_idx] + new_data = [np.take(inst_data, idx, axis=ch_axis) for idx in stim_ch_idx] else: - warn('Could not find stimulus channels.') + warn("Could not find stimulus channels.") # Get indices of bad channels ch_idx_bad = [] if not isinstance(drop_bad, bool): - raise TypeError('"drop_bad" must be of type bool, not ' - f'{type(drop_bad)}.') - if drop_bad and inst.info['bads']: - ch_idx_bad = pick_channels(ch_names, inst.info['bads']) + raise TypeError('"drop_bad" must be of type bool, not ' f"{type(drop_bad)}.") + if drop_bad and inst.info["bads"]: + ch_idx_bad = pick_channels(ch_names, inst.info["bads"]) # Check correctness of combinations for this_group, this_picks in groups.items(): # Check if channel indices are out of bounds if not all(idx in ch_idx for idx in this_picks): - raise ValueError('Some channel indices are out of bounds.') + raise ValueError("Some channel indices are out of bounds.") # Check if heterogeneous sensor type combinations this_ch_type = np.array(ch_types)[this_picks] if len(set(this_ch_type)) > 1: - types = ', '.join(set(this_ch_type)) - raise ValueError('Cannot combine sensors of different types; ' - f'"{this_group}" contains types {types}.') + types = ", ".join(set(this_ch_type)) + raise ValueError( + "Cannot combine sensors of different types; " + f'"{this_group}" contains types {types}.' + ) # Remove bad channels these_bads = [idx for idx in this_picks if idx in ch_idx_bad] this_picks = [idx for idx in this_picks if idx not in ch_idx_bad] if these_bads: - logger.info('Dropped the following channels in group ' - f'{this_group}: {these_bads}') + logger.info( + "Dropped the following channels in group " f"{this_group}: {these_bads}" + ) # Check if combining less than 2 channel if len(set(this_picks)) < 2: - warn(f'Less than 2 channels in group "{this_group}" when ' - f'combining by method "{method}".') + warn( + f'Less than 2 channels in group "{this_group}" when ' + f'combining by method "{method}".' 
+ ) # If all good create more detailed dict without bad channels groups[this_group] = dict(picks=this_picks, ch_type=this_ch_type[0]) # Combine channels and add them to the new instance for this_group, this_group_dict in groups.items(): new_ch_names.append(this_group) - new_ch_types.append(this_group_dict['ch_type']) - this_picks = this_group_dict['picks'] + new_ch_types.append(this_group_dict["ch_type"]) + this_picks = this_group_dict["picks"] this_data = np.take(inst_data, this_picks, axis=ch_axis) new_data.append(method(this_data)) new_data = np.swapaxes(new_data, 0, ch_axis) - info = create_info(sfreq=inst.info['sfreq'], ch_names=new_ch_names, - ch_types=new_ch_types) + info = create_info( + sfreq=inst.info["sfreq"], ch_names=new_ch_names, ch_types=new_ch_types + ) # create new instances and make sure to copy important attributes if isinstance(inst, BaseRaw): combined_inst = RawArray(new_data, info, first_samp=inst.first_samp) elif isinstance(inst, BaseEpochs): - combined_inst = EpochsArray(new_data, info, events=inst.events, - tmin=inst.times[0], baseline=inst.baseline) + combined_inst = EpochsArray( + new_data, + info, + events=inst.events, + tmin=inst.times[0], + baseline=inst.baseline, + ) if inst.metadata is not None: combined_inst.metadata = inst.metadata.copy() elif isinstance(inst, Evoked): - combined_inst = EvokedArray(new_data, info, tmin=inst.times[0], - baseline=inst.baseline) + combined_inst = EvokedArray( + new_data, info, tmin=inst.times[0], baseline=inst.baseline + ) return combined_inst # NeuroMag channel groupings -_SELECTIONS = ['Vertex', 'Left-temporal', 'Right-temporal', 'Left-parietal', - 'Right-parietal', 'Left-occipital', 'Right-occipital', - 'Left-frontal', 'Right-frontal'] -_EEG_SELECTIONS = ['EEG 1-32', 'EEG 33-64', 'EEG 65-96', 'EEG 97-128'] +_SELECTIONS = [ + "Vertex", + "Left-temporal", + "Right-temporal", + "Left-parietal", + "Right-parietal", + "Left-occipital", + "Right-occipital", + "Left-frontal", + "Right-frontal", +] +_EEG_SELECTIONS = ["EEG 1-32", "EEG 33-64", "EEG 65-96", "EEG 97-128"] def _divide_to_regions(info, add_stim=True): """Divide channels to regions by positions.""" from scipy.stats import zscore + picks = _pick_data_channels(info, exclude=[]) chs_in_lobe = len(picks) // 4 - pos = np.array([ch['loc'][:3] for ch in info['chs']]) + pos = np.array([ch["loc"][:3] for ch in info["chs"]]) x, y, z = pos.T frontal = picks[np.argsort(y[picks])[-chs_in_lobe:]] @@ -2090,14 +2351,14 @@ def _divide_to_regions(info, add_stim=True): # Because of the way the sides are divided, there may be outliers in the # temporal lobes. Here we switch the sides for these outliers. For other # lobes it is not a big problem because of the vicinity of the lobes. 
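For reference, a minimal usage sketch of the public ``combine_channels`` API whose reformatted definition appears above; the call signature is unchanged by this commit. The FIF file name is hypothetical, and the groups simply bundle same-type (EEG) channels by index, since mixing sensor types in one group raises a ValueError:

    import mne
    from mne.channels import combine_channels

    raw = mne.io.read_raw_fif("sample_raw.fif", preload=True)  # hypothetical file
    # Pick EEG channels so each group contains a single sensor type
    picks = mne.pick_types(raw.info, meg=False, eeg=True)
    # Map new virtual channel names to lists of channel indices
    groups = dict(frontal=list(picks[:3]), parietal=list(picks[3:6]))
    # Returns a new Raw-like instance with one combined channel per group
    combined = combine_channels(raw, groups, method="mean", drop_bad=True)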
- with np.errstate(invalid='ignore'): # invalid division, greater compare + with np.errstate(invalid="ignore"): # invalid division, greater compare zs = np.abs(zscore(x[rt])) - outliers = np.array(rt)[np.where(zs > 2.)[0]] + outliers = np.array(rt)[np.where(zs > 2.0)[0]] rt = list(np.setdiff1d(rt, outliers)) - with np.errstate(invalid='ignore'): # invalid division, greater compare + with np.errstate(invalid="ignore"): # invalid division, greater compare zs = np.abs(zscore(x[lt])) - outliers = np.append(outliers, (np.array(lt)[np.where(zs > 2.)[0]])) + outliers = np.append(outliers, (np.array(lt)[np.where(zs > 2.0)[0]])) lt = list(np.setdiff1d(lt, outliers)) l_mean = np.mean(x[lt]) @@ -2112,11 +2373,19 @@ def _divide_to_regions(info, add_stim=True): stim_ch = _get_stim_channel(None, info, raise_error=False) if len(stim_ch) > 0: for region in [lf, rf, lo, ro, lp, rp, lt, rt]: - region.append(info['ch_names'].index(stim_ch[0])) - return OrderedDict([('Left-frontal', lf), ('Right-frontal', rf), - ('Left-parietal', lp), ('Right-parietal', rp), - ('Left-occipital', lo), ('Right-occipital', ro), - ('Left-temporal', lt), ('Right-temporal', rt)]) + region.append(info["ch_names"].index(stim_ch[0])) + return OrderedDict( + [ + ("Left-frontal", lf), + ("Right-frontal", rf), + ("Left-parietal", lp), + ("Right-parietal", rp), + ("Left-occipital", lo), + ("Right-occipital", ro), + ("Left-temporal", lt), + ("Right-temporal", rt), + ] + ) def _divide_side(lobe, x): @@ -2165,42 +2434,44 @@ def read_vectorview_selection(name, fname=None, info=None, verbose=None): name = [name] if isinstance(info, Info): picks = pick_types(info, meg=True, exclude=()) - if len(picks) > 0 and ' ' not in info['ch_names'][picks[0]]: - spacing = 'new' + if len(picks) > 0 and " " not in info["ch_names"][picks[0]]: + spacing = "new" else: - spacing = 'old' + spacing = "old" elif info is not None: - raise TypeError('info must be an instance of Info or None, not %s' - % (type(info),)) + raise TypeError( + "info must be an instance of Info or None, not %s" % (type(info),) + ) else: # info is None - spacing = 'old' + spacing = "old" # use built-in selections by default if fname is None: - fname = op.join(op.dirname(__file__), '..', 'data', 'mne_analyze.sel') + fname = op.join(op.dirname(__file__), "..", "data", "mne_analyze.sel") fname = str(_check_fname(fname, must_exist=True, overwrite="read")) # use this to make sure we find at least one match for each name name_found = {n: False for n in name} - with open(fname, 'r') as fid: + with open(fname, "r") as fid: sel = [] for line in fid: line = line.strip() # skip blank lines and comments - if len(line) == 0 or line[0] == '#': + if len(line) == 0 or line[0] == "#": continue # get the name of the selection in the file - pos = line.find(':') + pos = line.find(":") if pos < 0: - logger.info('":" delimiter not found in selections file, ' - 'skipping line') + logger.info( + '":" delimiter not found in selections file, ' "skipping line" + ) continue sel_name_file = line[:pos] # search for substring match with name provided for n in name: if sel_name_file.find(n) >= 0: - sel.extend(line[pos + 1:].split('|')) + sel.extend(line[pos + 1 :].split("|")) name_found[n] = True break @@ -2212,6 +2483,6 @@ def read_vectorview_selection(name, fname=None, info=None, verbose=None): # make the selection a sorted list with unique elements sel = list(set(sel)) sel.sort() - if spacing == 'new': # "new" or "old" by now, "old" is default - sel = [s.replace('MEG ', 'MEG') for s in sel] + if spacing == "new": # 
"new" or "old" by now, "old" is default + sel = [s.replace("MEG ", "MEG") for s in sel] return sel diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py index d8c0a2be78a..f9dc0319992 100644 --- a/mne/channels/interpolation.py +++ b/mne/channels/interpolation.py @@ -26,9 +26,10 @@ def _calc_h(cosang, stiffness=4, n_legendre_terms=50): n_legendre_terms : int number of Legendre terms to evaluate. """ - factors = [(2 * n + 1) / - (n ** (stiffness - 1) * (n + 1) ** (stiffness - 1) * 4 * np.pi) - for n in range(1, n_legendre_terms + 1)] + factors = [ + (2 * n + 1) / (n ** (stiffness - 1) * (n + 1) ** (stiffness - 1) * 4 * np.pi) + for n in range(1, n_legendre_terms + 1) + ] return legval(cosang, [0] + factors) @@ -50,9 +51,10 @@ def _calc_g(cosang, stiffness=4, n_legendre_terms=50): G : np.ndrarray of float, shape(n_channels, n_channels) The G matrix. """ - factors = [(2 * n + 1) / (n ** stiffness * (n + 1) ** stiffness * - 4 * np.pi) - for n in range(1, n_legendre_terms + 1)] + factors = [ + (2 * n + 1) / (n**stiffness * (n + 1) ** stiffness * 4 * np.pi) + for n in range(1, n_legendre_terms + 1) + ] return legval(cosang, [0] + factors) @@ -83,6 +85,7 @@ def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5): Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7. """ from scipy import linalg + pos_from = pos_from.copy() pos_to = pos_to.copy() n_from = pos_from.shape[0] @@ -101,10 +104,14 @@ def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5): assert G_to_from.shape == (n_to, n_from) if alpha is not None: - G_from.flat[::len(G_from) + 1] += alpha - - C = np.vstack([np.hstack([G_from, np.ones((n_from, 1))]), - np.hstack([np.ones((1, n_from)), [[0]]])]) + G_from.flat[:: len(G_from) + 1] += alpha + + C = np.vstack( + [ + np.hstack([G_from, np.ones((n_from, 1))]), + np.hstack([np.ones((1, n_from)), [[0]]]), + ] + ) C_inv = linalg.pinv(C) interpolation = np.hstack([G_to_from, np.ones((n_to, 1))]) @ C_inv[:, :-1] @@ -117,9 +124,11 @@ def _do_interp_dots(inst, interpolation, goods_idx, bads_idx): from ..io.base import BaseRaw from ..epochs import BaseEpochs from ..evoked import Evoked - _validate_type(inst, (BaseRaw, BaseEpochs, Evoked), 'inst') + + _validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst") inst._data[..., bads_idx, :] = np.matmul( - interpolation, inst._data[..., goods_idx, :]) + interpolation, inst._data[..., goods_idx, :] + ) @verbose @@ -131,7 +140,7 @@ def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None): picks = pick_types(inst.info, meg=False, eeg=True, exclude=exclude) inst.info._check_consistency() - bads_idx[picks] = [inst.ch_names[ch] in inst.info['bads'] for ch in picks] + bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks] if len(picks) == 0 or bads_idx.sum() == 0: return @@ -148,30 +157,43 @@ def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None): # test spherical fit distance = np.linalg.norm(pos - origin, axis=-1) distance = np.mean(distance / np.mean(distance)) - if np.abs(1. - distance) > 0.1: - warn('Your spherical fit is poor, interpolation results are ' - 'likely to be inaccurate.') + if np.abs(1.0 - distance) > 0.1: + warn( + "Your spherical fit is poor, interpolation results are " + "likely to be inaccurate." 
+ ) pos_good = pos[goods_idx_pos] - origin pos_bad = pos[bads_idx_pos] - origin - logger.info('Computing interpolation matrix from {} sensor ' - 'positions'.format(len(pos_good))) + logger.info( + "Computing interpolation matrix from {} sensor " + "positions".format(len(pos_good)) + ) interpolation = _make_interpolation_matrix(pos_good, pos_bad) - logger.info('Interpolating {} sensors'.format(len(pos_bad))) + logger.info("Interpolating {} sensors".format(len(pos_bad))) _do_interp_dots(inst, interpolation, goods_idx, bads_idx) -def _interpolate_bads_meg(inst, mode='accurate', origin=(0., 0., 0.04), - verbose=None, ref_meg=False): +def _interpolate_bads_meg( + inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False +): return _interpolate_bads_meeg( - inst, mode, origin, ref_meg=ref_meg, eeg=False, verbose=verbose) + inst, mode, origin, ref_meg=ref_meg, eeg=False, verbose=verbose + ) @verbose -def _interpolate_bads_meeg(inst, mode='accurate', origin=(0., 0., 0.04), - meg=True, eeg=True, ref_meg=False, - exclude=(), verbose=None): +def _interpolate_bads_meeg( + inst, + mode="accurate", + origin=(0.0, 0.0, 0.04), + meg=True, + eeg=True, + ref_meg=False, + exclude=(), + verbose=None, +): bools = dict(meg=meg, eeg=eeg) info = _simplify_info(inst.info) for ch_type, do in bools.items(): @@ -180,15 +202,14 @@ def _interpolate_bads_meeg(inst, mode='accurate', origin=(0., 0., 0.04), kw = dict(meg=False, eeg=False) kw[ch_type] = True picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **kw) - picks_good = pick_types(info, ref_meg=ref_meg, exclude='bads', **kw) - use_ch_names = [inst.info['ch_names'][p] for p in picks_type] - bads_type = [ch for ch in inst.info['bads'] if ch in use_ch_names] + picks_good = pick_types(info, ref_meg=ref_meg, exclude="bads", **kw) + use_ch_names = [inst.info["ch_names"][p] for p in picks_type] + bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names] if len(bads_type) == 0 or len(picks_type) == 0: continue # select the bad channels to be interpolated - picks_bad = pick_channels(inst.info['ch_names'], bads_type, - exclude=[]) - if ch_type == 'eeg': + picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[]) + if ch_type == "eeg": picks_to = picks_type bad_sel = np.in1d(picks_type, picks_bad) else: @@ -196,14 +217,13 @@ def _interpolate_bads_meeg(inst, mode='accurate', origin=(0., 0., 0.04), bad_sel = slice(None) info_from = pick_info(inst.info, picks_good) info_to = pick_info(inst.info, picks_to) - mapping = _map_meg_or_eeg_channels( - info_from, info_to, mode=mode, origin=origin) + mapping = _map_meg_or_eeg_channels(info_from, info_to, mode=mode, origin=origin) mapping = mapping[bad_sel] _do_interp_dots(inst, mapping, picks_good, picks_bad) @verbose -def _interpolate_bads_nirs(inst, method='nearest', exclude=(), verbose=None): +def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None): from scipy.spatial.distance import pdist, squareform from mne.preprocessing.nirs import _validate_nirs_info @@ -212,21 +232,20 @@ def _interpolate_bads_nirs(inst, method='nearest', exclude=(), verbose=None): # Returns pick of all nirs and ensures channels are correctly ordered picks_nirs = _validate_nirs_info(inst.info) - nirs_ch_names = [inst.info['ch_names'][p] for p in picks_nirs] + nirs_ch_names = [inst.info["ch_names"][p] for p in picks_nirs] nirs_ch_names = [ch for ch in nirs_ch_names if ch not in exclude] - bads_nirs = [ch for ch in inst.info['bads'] if ch in nirs_ch_names] + bads_nirs = [ch for ch in 
inst.info["bads"] if ch in nirs_ch_names] if len(bads_nirs) == 0: return - picks_bad = pick_channels(inst.info['ch_names'], bads_nirs, exclude=[]) + picks_bad = pick_channels(inst.info["ch_names"], bads_nirs, exclude=[]) bads_mask = [p in picks_bad for p in picks_nirs] - chs = [inst.info['chs'][i] for i in picks_nirs] - locs3d = np.array([ch['loc'][:3] for ch in chs]) - - _check_option('fnirs_method', method, ['nearest']) + chs = [inst.info["chs"][i] for i in picks_nirs] + locs3d = np.array([ch["loc"][:3] for ch in chs]) - if method == 'nearest': + _check_option("fnirs_method", method, ["nearest"]) + if method == "nearest": dist = pdist(locs3d) dist = squareform(dist) @@ -240,6 +259,6 @@ def _interpolate_bads_nirs(inst, method='nearest', exclude=(), verbose=None): closest_idx = np.argmin(dists_to_bad) + (bad % 2) inst._data[bad] = inst._data[closest_idx] - inst.info['bads'] = [ch for ch in inst.info['bads'] if ch in exclude] + inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] return inst diff --git a/mne/channels/layout.py b/mne/channels/layout.py index e59bb80a2a1..9be79f33102 100644 --- a/mne/channels/layout.py +++ b/mne/channels/layout.py @@ -20,8 +20,16 @@ from ..io.pick import pick_types, _picks_to_idx, _FNIRS_CH_TYPES_SPLIT from ..io.constants import FIFF from ..io.meas_info import Info -from ..utils import (_clean_names, warn, _check_ch_locs, fill_doc, - _check_fname, _check_option, _check_sphere, logger) +from ..utils import ( + _clean_names, + warn, + _check_ch_locs, + fill_doc, + _check_fname, + _check_option, + _check_sphere, + logger, +) from .channels import _get_ch_info @@ -74,26 +82,32 @@ def save(self, fname, overwrite=False): height = self.pos[:, 3] fname = _check_fname(fname, overwrite=overwrite, name=fname) if fname.suffix == ".lout": - out_str = '%8.2f %8.2f %8.2f %8.2f\n' % self.box + out_str = "%8.2f %8.2f %8.2f %8.2f\n" % self.box elif fname.suffix == ".lay": - out_str = '' + out_str = "" else: - raise ValueError('Unknown layout type. Should be of type ' - '.lout or .lay.') + raise ValueError("Unknown layout type. Should be of type " ".lout or .lay.") for ii in range(x.shape[0]): - out_str += ('%03d %8.2f %8.2f %8.2f %8.2f %s\n' - % (self.ids[ii], x[ii], y[ii], - width[ii], height[ii], self.names[ii])) + out_str += "%03d %8.2f %8.2f %8.2f %8.2f %s\n" % ( + self.ids[ii], + x[ii], + y[ii], + width[ii], + height[ii], + self.names[ii], + ) - f = open(fname, 'w') + f = open(fname, "w") f.write(out_str) f.close() def __repr__(self): """Return the string representation.""" - return '' % (self.kind, - ', '.join(self.names[:3])) + return "" % ( + self.kind, + ", ".join(self.names[:3]), + ) @fill_doc def plot(self, picks=None, show_axes=False, show=True): @@ -117,6 +131,7 @@ def plot(self, picks=None, show_axes=False, show=True): .. 
versionadded:: 0.12.0 """ from ..viz.topomap import plot_layout + return plot_layout(self, picks=picks, show_axes=show_axes, show=show) @@ -130,7 +145,7 @@ def _read_lout(fname): splits = line.split() if len(splits) == 7: cid, x, y, dx, dy, chkind, nb = splits - name = chkind + ' ' + nb + name = chkind + " " + nb else: cid, x, y, dx, dy, name = splits pos.append(np.array([x, y, dx, dy], dtype=np.float64)) @@ -151,7 +166,7 @@ def _read_lay(fname): splits = line.split() if len(splits) == 7: cid, x, y, dx, dy, chkind, nb = splits - name = chkind + ' ' + nb + name = chkind + " " + nb else: cid, x, y, dx, dy, name = splits pos.append(np.array([x, y, dx, dy], dtype=np.float64)) @@ -263,22 +278,14 @@ def read_layout(fname=None, path="", scale=True, *, kind=None): # kind should be the name as a string, but let's consider the case # where the path to the file is provided instead. kind = Path(kind) - if ( - len(kind.suffix) == 0 - and (path / kind.with_suffix(".lout")).exists() - ): + if len(kind.suffix) == 0 and (path / kind.with_suffix(".lout")).exists(): kind = kind.with_suffix(".lout") - elif ( - len(kind.suffix) == 0 - and (path / kind.with_suffix(".lay")).exists() - ): + elif len(kind.suffix) == 0 and (path / kind.with_suffix(".lay")).exists(): kind = kind.with_suffix(".lay") fname = kind if kind.exists() else path / kind.name if fname.suffix not in (".lout", ".lay"): - raise ValueError( - "Unknown layout type. Should be of type .lout or .lay." - ) + raise ValueError("Unknown layout type. Should be of type .lout or .lay.") kind = fname.stem else: # to be removed along the deprecated argument @@ -317,8 +324,9 @@ def read_layout(fname=None, path="", scale=True, *, kind=None): @fill_doc -def make_eeg_layout(info, radius=0.5, width=None, height=None, exclude='bads', - csd=False): +def make_eeg_layout( + info, radius=0.5, width=None, height=None, exclude="bads", csd=False +): """Create .lout file from EEG electrode digitization. Parameters @@ -348,18 +356,18 @@ def make_eeg_layout(info, radius=0.5, width=None, height=None, exclude='bads', make_grid_layout, generate_2d_layout """ if not (0 <= radius <= 0.5): - raise ValueError('The radius parameter should be between 0 and 0.5.') + raise ValueError("The radius parameter should be between 0 and 0.5.") if width is not None and not (0 <= width <= 1.0): - raise ValueError('The width parameter should be between 0 and 1.') + raise ValueError("The width parameter should be between 0 and 1.") if height is not None and not (0 <= height <= 1.0): - raise ValueError('The height parameter should be between 0 and 1.') + raise ValueError("The height parameter should be between 0 and 1.") pick_kwargs = dict(meg=False, eeg=True, ref_meg=False, exclude=exclude) if csd: pick_kwargs.update(csd=True, eeg=False) picks = pick_types(info, **pick_kwargs) loc2d = _find_topomap_coords(info, picks) - names = [info['chs'][i]['ch_name'] for i in picks] + names = [info["chs"][i]["ch_name"] for i in picks] # Scale [x, y] to be in the range [-0.5, 0.5] # Don't mess with the origin or aspect ratio @@ -376,7 +384,7 @@ def make_eeg_layout(info, radius=0.5, width=None, height=None, exclude='bads', # Some subplot centers will be at the figure edge. Shrink everything so it # fits in the figure. - scaling = min(1 / (1. + width), 1 / (1. 
+ height)) + scaling = min(1 / (1.0 + width), 1 / (1.0 + height)) loc2d *= scaling width *= scaling height *= scaling @@ -385,14 +393,16 @@ def make_eeg_layout(info, radius=0.5, width=None, height=None, exclude='bads', loc2d += 0.5 n_channels = loc2d.shape[0] - pos = np.c_[loc2d[:, 0] - 0.5 * width, - loc2d[:, 1] - 0.5 * height, - width * np.ones(n_channels), - height * np.ones(n_channels)] + pos = np.c_[ + loc2d[:, 0] - 0.5 * width, + loc2d[:, 1] - 0.5 * height, + width * np.ones(n_channels), + height * np.ones(n_channels), + ] box = (0, 1, 0, 1) ids = 1 + np.arange(n_channels) - layout = Layout(box=box, pos=pos, names=names, kind='EEG', ids=ids) + layout = Layout(box=box, pos=pos, names=names, kind="EEG", ids=ids) return layout @@ -416,12 +426,12 @@ def make_grid_layout(info, picks=None, n_col=None): -------- make_eeg_layout, generate_2d_layout """ - picks = _picks_to_idx(info, picks, 'misc') + picks = _picks_to_idx(info, picks, "misc") - names = [info['chs'][k]['ch_name'] for k in picks] + names = [info["chs"][k]["ch_name"] for k in picks] if not names: - raise ValueError('No misc data channels found.') + raise ValueError("No misc data channels found.") ids = list(range(len(picks))) size = len(picks) @@ -439,16 +449,15 @@ def make_grid_layout(info, picks=None, n_col=None): n_row = int(np.ceil(size / float(n_col))) # setup position grid - x, y = np.meshgrid(np.linspace(-0.5, 0.5, n_col), - np.linspace(-0.5, 0.5, n_row)) + x, y = np.meshgrid(np.linspace(-0.5, 0.5, n_col), np.linspace(-0.5, 0.5, n_row)) x, y = x.ravel()[:size], y.ravel()[:size] width, height = _box_size(np.c_[x, y], padding=0.1) # Some axes will be at the figure edge. Shrink everything so it fits in the # figure. Add 0.01 border around everything border_x, border_y = (0.01, 0.01) - x_scaling = 1 / (1. + width + border_x) - y_scaling = 1 / (1. + height + border_y) + x_scaling = 1 / (1.0 + width + border_x) + y_scaling = 1 / (1.0 + height + border_y) x = x * x_scaling y = y * y_scaling width *= x_scaling @@ -459,16 +468,17 @@ def make_grid_layout(info, picks=None, n_col=None): y += 0.5 # calculate pos - pos = np.c_[x - 0.5 * width, y - 0.5 * height, - width * np.ones(size), height * np.ones(size)] + pos = np.c_[ + x - 0.5 * width, y - 0.5 * height, width * np.ones(size), height * np.ones(size) + ] box = (0, 1, 0, 1) - layout = Layout(box=box, pos=pos, names=names, kind='grid-misc', ids=ids) + layout = Layout(box=box, pos=pos, names=names, kind="grid-misc", ids=ids) return layout @fill_doc -def find_layout(info, ch_type=None, exclude='bads'): +def find_layout(info, ch_type=None, exclude="bads"): """Choose a layout based on the channels in the info 'chs' field. Parameters @@ -488,57 +498,70 @@ def find_layout(info, ch_type=None, exclude='bads'): layout : Layout instance | None None if layout not found. 
""" - _check_option('ch_type', ch_type, [None, 'mag', 'grad', 'meg', 'eeg', - 'csd']) - - (has_vv_mag, has_vv_grad, is_old_vv, has_4D_mag, ctf_other_types, - has_CTF_grad, n_kit_grads, has_any_meg, has_eeg_coils, - has_eeg_coils_and_meg, has_eeg_coils_only, - has_neuromag_122_grad, has_csd_coils) = _get_ch_info(info) + _check_option("ch_type", ch_type, [None, "mag", "grad", "meg", "eeg", "csd"]) + + ( + has_vv_mag, + has_vv_grad, + is_old_vv, + has_4D_mag, + ctf_other_types, + has_CTF_grad, + n_kit_grads, + has_any_meg, + has_eeg_coils, + has_eeg_coils_and_meg, + has_eeg_coils_only, + has_neuromag_122_grad, + has_csd_coils, + ) = _get_ch_info(info) has_vv_meg = has_vv_mag and has_vv_grad has_vv_only_mag = has_vv_mag and not has_vv_grad has_vv_only_grad = has_vv_grad and not has_vv_mag if ch_type == "meg" and not has_any_meg: - raise RuntimeError('No MEG channels present. Cannot find MEG layout.') + raise RuntimeError("No MEG channels present. Cannot find MEG layout.") if ch_type == "eeg" and not has_eeg_coils: - raise RuntimeError('No EEG channels present. Cannot find EEG layout.') + raise RuntimeError("No EEG channels present. Cannot find EEG layout.") layout_name = None - if ((has_vv_meg and ch_type is None) or - (any([has_vv_mag, has_vv_grad]) and ch_type == 'meg')): - layout_name = 'Vectorview-all' - elif has_vv_only_mag or (has_vv_meg and ch_type == 'mag'): - layout_name = 'Vectorview-mag' - elif has_vv_only_grad or (has_vv_meg and ch_type == 'grad'): - if info['ch_names'][0].endswith('X'): - layout_name = 'Vectorview-grad_norm' + if (has_vv_meg and ch_type is None) or ( + any([has_vv_mag, has_vv_grad]) and ch_type == "meg" + ): + layout_name = "Vectorview-all" + elif has_vv_only_mag or (has_vv_meg and ch_type == "mag"): + layout_name = "Vectorview-mag" + elif has_vv_only_grad or (has_vv_meg and ch_type == "grad"): + if info["ch_names"][0].endswith("X"): + layout_name = "Vectorview-grad_norm" else: - layout_name = 'Vectorview-grad' + layout_name = "Vectorview-grad" elif has_neuromag_122_grad: - layout_name = 'Neuromag_122' - elif ((has_eeg_coils_only and ch_type in [None, 'eeg']) or - (has_eeg_coils_and_meg and ch_type == 'eeg')): + layout_name = "Neuromag_122" + elif (has_eeg_coils_only and ch_type in [None, "eeg"]) or ( + has_eeg_coils_and_meg and ch_type == "eeg" + ): if not isinstance(info, (dict, Info)): - raise RuntimeError('Cannot make EEG layout, no measurement info ' - 'was passed to `find_layout`') + raise RuntimeError( + "Cannot make EEG layout, no measurement info " + "was passed to `find_layout`" + ) return make_eeg_layout(info, exclude=exclude) - elif has_csd_coils and ch_type in [None, 'csd']: + elif has_csd_coils and ch_type in [None, "csd"]: return make_eeg_layout(info, exclude=exclude, csd=True) elif has_4D_mag: - layout_name = 'magnesWH3600' + layout_name = "magnesWH3600" elif has_CTF_grad: - layout_name = 'CTF-275' + layout_name = "CTF-275" elif n_kit_grads > 0: layout_name = _find_kit_layout(info, n_kit_grads) # If no known layout is found, fall back on automatic layout if layout_name is None: - picks = _picks_to_idx(info, 'data', exclude=(), with_ref_meg=False) - ch_names = [info['ch_names'][pick] for pick in picks] + picks = _picks_to_idx(info, "data", exclude=(), with_ref_meg=False) + ch_names = [info["ch_names"][pick] for pick in picks] xy = _find_topomap_coords(info, picks=picks, ignore_overlap=True) - return generate_2d_layout(xy, ch_names=ch_names, name='custom', - normalize=True) + return generate_2d_layout(xy, ch_names=ch_names, name="custom", 
normalize=True) layout = read_layout(fname=layout_name) if not is_old_vv: @@ -547,8 +570,8 @@ def find_layout(info, ch_type=None, exclude='bads'): layout.names = _clean_names(layout.names, before_dash=True) # Apply mask for excluded channels. - if exclude == 'bads': - exclude = info['bads'] + if exclude == "bads": + exclude = info["bads"] idx = [ii for ii, name in enumerate(layout.names) if name not in exclude] layout.names = [layout.names[ii] for ii in idx] layout.pos = layout.pos[idx] @@ -572,34 +595,69 @@ def _find_kit_layout(info, n_grads): kit_layout : str | None String naming the detected KIT layout or ``None`` if layout is missing. """ - if info['kit_system_id'] is not None: + if info["kit_system_id"] is not None: # avoid circular import from ..io.kit.constants import KIT_LAYOUT - return KIT_LAYOUT.get(info['kit_system_id']) + + return KIT_LAYOUT.get(info["kit_system_id"]) elif n_grads == 160: - return 'KIT-160' + return "KIT-160" elif n_grads == 125: - return 'KIT-125' + return "KIT-125" elif n_grads > 157: - return 'KIT-AD' + return "KIT-AD" # channels which are on the left hemisphere for NY and right for UMD - test_chs = ('MEG 13', 'MEG 14', 'MEG 15', 'MEG 16', 'MEG 25', - 'MEG 26', 'MEG 27', 'MEG 28', 'MEG 29', 'MEG 30', - 'MEG 31', 'MEG 32', 'MEG 57', 'MEG 60', 'MEG 61', - 'MEG 62', 'MEG 63', 'MEG 64', 'MEG 73', 'MEG 90', - 'MEG 93', 'MEG 95', 'MEG 96', 'MEG 105', 'MEG 112', - 'MEG 120', 'MEG 121', 'MEG 122', 'MEG 123', 'MEG 124', - 'MEG 125', 'MEG 126', 'MEG 142', 'MEG 144', 'MEG 153', - 'MEG 154', 'MEG 155', 'MEG 156') - x = [ch['loc'][0] < 0 for ch in info['chs'] if ch['ch_name'] in test_chs] + test_chs = ( + "MEG 13", + "MEG 14", + "MEG 15", + "MEG 16", + "MEG 25", + "MEG 26", + "MEG 27", + "MEG 28", + "MEG 29", + "MEG 30", + "MEG 31", + "MEG 32", + "MEG 57", + "MEG 60", + "MEG 61", + "MEG 62", + "MEG 63", + "MEG 64", + "MEG 73", + "MEG 90", + "MEG 93", + "MEG 95", + "MEG 96", + "MEG 105", + "MEG 112", + "MEG 120", + "MEG 121", + "MEG 122", + "MEG 123", + "MEG 124", + "MEG 125", + "MEG 126", + "MEG 142", + "MEG 144", + "MEG 153", + "MEG 154", + "MEG 155", + "MEG 156", + ) + x = [ch["loc"][0] < 0 for ch in info["chs"] if ch["ch_name"] in test_chs] if np.all(x): - return 'KIT-157' # KIT-NY + return "KIT-157" # KIT-NY elif np.all(np.invert(x)): - raise NotImplementedError("Guessing sensor layout for legacy UMD " - "files is not implemented. Please convert " - "your files using MNE-Python 0.13 or " - "higher.") + raise NotImplementedError( + "Guessing sensor layout for legacy UMD " + "files is not implemented. Please convert " + "your files using MNE-Python 0.13 or " + "higher." + ) else: raise RuntimeError("KIT system could not be determined for data") @@ -660,8 +718,7 @@ def ydiff(a, b): if height is None: # Find all axes that could potentially overlap horizontally. hdist = pdist(points, xdiff) - candidates = [all_combinations[i] for i, d in enumerate(hdist) - if d < width] + candidates = [all_combinations[i] for i, d in enumerate(hdist) if d < width] if len(candidates) == 0: # No axes overlap, take all the height you want. @@ -674,8 +731,7 @@ def ydiff(a, b): elif width is None: # Find all axes that could potentially overlap vertically. vdist = pdist(points, ydiff) - candidates = [all_combinations[i] for i, d in enumerate(vdist) - if d < height] + candidates = [all_combinations[i] for i, d in enumerate(vdist) if d < height] if len(candidates) == 0: # No axes overlap, take all the width you want. 
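A short sketch of how the layout helpers reorganized above are typically used: ``find_layout`` selects a known layout from the measurement info and, as the code above shows, falls back on ``generate_2d_layout`` when no known layout matches. The evoked file name here is hypothetical:

    import mne
    from mne.channels import find_layout

    evoked = mne.read_evokeds("sample_ave.fif", condition=0)  # hypothetical file
    layout = find_layout(evoked.info, ch_type="grad")
    layout.plot()  # inspect the 2D sensor arrangement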
@@ -693,8 +749,9 @@ def ydiff(a, b): @fill_doc -def _find_topomap_coords(info, picks, layout=None, ignore_overlap=False, - to_sphere=True, sphere=None): +def _find_topomap_coords( + info, picks, layout=None, ignore_overlap=False, to_sphere=True, sphere=None +): """Guess the E/MEG layout and return appropriate topomap coordinates. Parameters @@ -714,16 +771,20 @@ def _find_topomap_coords(info, picks, layout=None, ignore_overlap=False, coords : array, shape = (n_chs, 2) 2 dimensional coordinates for each sensor for a topomap plot. """ - picks = _picks_to_idx(info, picks, 'all', exclude=(), allow_empty=False) + picks = _picks_to_idx(info, picks, "all", exclude=(), allow_empty=False) if layout is not None: - chs = [info['chs'][i] for i in picks] - pos = [layout.pos[layout.names.index(ch['ch_name'])] for ch in chs] + chs = [info["chs"][i] for i in picks] + pos = [layout.pos[layout.names.index(ch["ch_name"])] for ch in chs] pos = np.asarray(pos) else: pos = _auto_topomap_coords( - info, picks, ignore_overlap=ignore_overlap, to_sphere=to_sphere, - sphere=sphere) + info, + picks, + ignore_overlap=ignore_overlap, + to_sphere=to_sphere, + sphere=sphere, + ) return pos @@ -756,50 +817,64 @@ def _auto_topomap_coords(info, picks, ignore_overlap, to_sphere, sphere): An array of positions of the 2 dimensional map. """ from scipy.spatial.distance import pdist, squareform + sphere = _check_sphere(sphere, info) - logger.debug(f'Generating coords using: {sphere}') + logger.debug(f"Generating coords using: {sphere}") - picks = _picks_to_idx(info, picks, 'all', exclude=(), allow_empty=False) - chs = [info['chs'][i] for i in picks] + picks = _picks_to_idx(info, picks, "all", exclude=(), allow_empty=False) + chs = [info["chs"][i] for i in picks] # Use channel locations if available - locs3d = np.array([ch['loc'][:3] for ch in chs]) + locs3d = np.array([ch["loc"][:3] for ch in chs]) # If electrode locations are not available, use digization points if not _check_ch_locs(info=info, picks=picks): - logging.warning('Did not find any electrode locations (in the info ' - 'object), will attempt to use digitization points ' - 'instead. However, if digitization points do not ' - 'correspond to the EEG electrodes, this will lead to ' - 'bad results. Please verify that the sensor locations ' - 'in the plot are accurate.') + logging.warning( + "Did not find any electrode locations (in the info " + "object), will attempt to use digitization points " + "instead. However, if digitization points do not " + "correspond to the EEG electrodes, this will lead to " + "bad results. Please verify that the sensor locations " + "in the plot are accurate." + ) # MEG/EOG/ECG sensors don't have digitization points; all requested # channels must be EEG for ch in chs: - if ch['kind'] != FIFF.FIFFV_EEG_CH: - raise ValueError("Cannot determine location of MEG/EOG/ECG " - "channels using digitization points.") - - eeg_ch_names = [ch['ch_name'] for ch in info['chs'] - if ch['kind'] == FIFF.FIFFV_EEG_CH] + if ch["kind"] != FIFF.FIFFV_EEG_CH: + raise ValueError( + "Cannot determine location of MEG/EOG/ECG " + "channels using digitization points." 
+ ) + + eeg_ch_names = [ + ch["ch_name"] for ch in info["chs"] if ch["kind"] == FIFF.FIFFV_EEG_CH + ] # Get EEG digitization points - if info['dig'] is None or len(info['dig']) == 0: - raise RuntimeError('No digitization points found.') - - locs3d = np.array([point['r'] for point in info['dig'] - if point['kind'] == FIFF.FIFFV_POINT_EEG]) + if info["dig"] is None or len(info["dig"]) == 0: + raise RuntimeError("No digitization points found.") + + locs3d = np.array( + [ + point["r"] + for point in info["dig"] + if point["kind"] == FIFF.FIFFV_POINT_EEG + ] + ) if len(locs3d) == 0: - raise RuntimeError('Did not find any digitization points of ' - 'kind FIFFV_POINT_EEG (%d) in the info.' - % FIFF.FIFFV_POINT_EEG) + raise RuntimeError( + "Did not find any digitization points of " + "kind FIFFV_POINT_EEG (%d) in the info." % FIFF.FIFFV_POINT_EEG + ) if len(locs3d) != len(eeg_ch_names): - raise ValueError("Number of EEG digitization points (%d) " - "doesn't match the number of EEG channels " - "(%d)" % (len(locs3d), len(eeg_ch_names))) + raise ValueError( + "Number of EEG digitization points (%d) " + "doesn't match the number of EEG channels " + "(%d)" % (len(locs3d), len(eeg_ch_names)) + ) # We no longer center digitization points on head origin, as we work # in head coordinates always @@ -807,22 +882,24 @@ def _auto_topomap_coords(info, picks, ignore_overlap, to_sphere, sphere): # Match the digitization points with the requested # channels. eeg_ch_locs = dict(zip(eeg_ch_names, locs3d)) - locs3d = np.array([eeg_ch_locs[ch['ch_name']] for ch in chs]) + locs3d = np.array([eeg_ch_locs[ch["ch_name"]] for ch in chs]) # Sometimes we can get nans - locs3d[~np.isfinite(locs3d)] = 0. + locs3d[~np.isfinite(locs3d)] = 0.0 # Duplicate points cause all kinds of trouble during visualization dist = pdist(locs3d) if len(locs3d) > 1 and np.min(dist) < 1e-10 and not ignore_overlap: problematic_electrodes = [ - chs[elec_i]['ch_name'] + chs[elec_i]["ch_name"] for elec_i in squareform(dist < 1e-10).any(axis=0).nonzero()[0] ] - raise ValueError('The following electrodes have overlapping positions,' - ' which causes problems during visualization:\n' + - ', '.join(problematic_electrodes)) + raise ValueError( + "The following electrodes have overlapping positions," + " which causes problems during visualization:\n" + + ", ".join(problematic_electrodes) + ) if to_sphere: # translate to sphere origin, transform/flatten Z, translate back @@ -831,7 +908,7 @@ def _auto_topomap_coords(info, picks, ignore_overlap, to_sphere, sphere): cart_coords = _cart_to_sph(locs3d) out = _pol_to_cart(cart_coords[:, 1:][:, ::-1]) # scale from radians to mm - out *= cart_coords[:, [0]] / (np.pi / 2.) + out *= cart_coords[:, [0]] / (np.pi / 2.0) out += sphere[:2] else: out = _pol_to_cart(_cart_to_sph(locs3d)) @@ -862,18 +939,19 @@ def _topo_to_sphere(pos, eegs): xs += 0.5 - np.mean(xs[eegs]) # Center the points ys += 0.5 - np.mean(ys[eegs]) - xs = xs * 2. - 1. # Values ranging from -1 to 1 - ys = ys * 2. - 1. + xs = xs * 2.0 - 1.0 # Values ranging from -1 to 1 + ys = ys * 2.0 - 1.0 - rs = np.clip(np.sqrt(xs ** 2 + ys ** 2), 0., 1.) + rs = np.clip(np.sqrt(xs**2 + ys**2), 0.0, 1.0) alphas = np.arccos(rs) zs = np.sin(alphas) return np.column_stack([xs, ys, zs]) @fill_doc -def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads', - raise_error=True): +def _pair_grad_sensors( + info, layout=None, topomap_coords=True, exclude="bads", raise_error=True +): """Find the picks for pairing grad channels. 
Parameters @@ -901,18 +979,18 @@ def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads', """ # find all complete pairs of grad channels pairs = defaultdict(list) - grad_picks = pick_types(info, meg='grad', ref_meg=False, exclude=exclude) + grad_picks = pick_types(info, meg="grad", ref_meg=False, exclude=exclude) _, has_vv_grad, *_, has_neuromag_122_grad, _ = _get_ch_info(info) for i in grad_picks: - ch = info['chs'][i] - name = ch['ch_name'] - if has_vv_grad and name.startswith('MEG'): - if name.endswith(('2', '3')): + ch = info["chs"][i] + name = ch["ch_name"] + if has_vv_grad and name.startswith("MEG"): + if name.endswith(("2", "3")): key = name[-4:-1] pairs[key].append(ch) - if has_neuromag_122_grad and name.startswith('MEG'): + if has_neuromag_122_grad and name.startswith("MEG"): key = (int(name[-3:]) - 1) // 2 pairs[key].append(ch) @@ -926,13 +1004,12 @@ def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads', # find the picks corresponding to the grad channels grad_chs = sum(pairs, []) - ch_names = info['ch_names'] - picks = [ch_names.index(c['ch_name']) for c in grad_chs] + ch_names = info["ch_names"] + picks = [ch_names.index(c["ch_name"]) for c in grad_chs] if topomap_coords: shape = (len(pairs), 2, -1) - coords = (_find_topomap_coords(info, picks, layout) - .reshape(shape).mean(axis=1)) + coords = _find_topomap_coords(info, picks, layout).reshape(shape).mean(axis=1) return picks, coords else: return picks @@ -955,8 +1032,8 @@ def _pair_grad_sensors_ch_names_vectorview(ch_names): """ pairs = defaultdict(list) for i, name in enumerate(ch_names): - if name.startswith('MEG'): - if name.endswith(('2', '3')): + if name.startswith("MEG"): + if name.endswith(("2", "3")): key = name[-4:-1] pairs[key].append(i) @@ -983,7 +1060,7 @@ def _pair_grad_sensors_ch_names_neuromag122(ch_names): """ pairs = defaultdict(list) for i, name in enumerate(ch_names): - if name.startswith('MEG'): + if name.startswith("MEG"): key = (int(name[-3:]) - 1) // 2 pairs[key].append(i) @@ -993,7 +1070,7 @@ def _pair_grad_sensors_ch_names_neuromag122(ch_names): return grad_chs -def _merge_ch_data(data, ch_type, names, method='rms'): +def _merge_ch_data(data, ch_type, names, method="rms"): """Merge data from channel pairs. Parameters @@ -1014,7 +1091,7 @@ def _merge_ch_data(data, ch_type, names, method='rms'): names : list List of channel names. """ - if ch_type == 'grad': + if ch_type == "grad": data = _merge_grad_data(data, method) else: assert ch_type in _FNIRS_CH_TYPES_SPLIT @@ -1022,7 +1099,7 @@ def _merge_ch_data(data, ch_type, names, method='rms'): return data, names -def _merge_grad_data(data, method='rms'): +def _merge_grad_data(data, method="rms"): """Merge data from channel pairs using the RMS or mean. Parameters @@ -1038,10 +1115,10 @@ def _merge_grad_data(data, method='rms'): The root mean square or mean for each pair. """ data, orig_shape = data.reshape((len(data) // 2, 2, -1)), data.shape - if method == 'mean': + if method == "mean": data = np.mean(data, axis=1) - elif method == 'rms': - data = np.sqrt(np.sum(data ** 2, axis=1) / 2) + elif method == "rms": + data = np.sqrt(np.sum(data**2, axis=1) / 2) else: raise ValueError('method must be "rms" or "mean", got %s.' 
% method) return data.reshape(data.shape[:1] + orig_shape[1:]) @@ -1070,7 +1147,7 @@ def _merge_nirs_data(data, merged_names): """ to_remove = np.empty(0, dtype=np.int32) for idx, ch in enumerate(merged_names): - if 'x' in ch: + if "x" in ch: indices = np.empty(0, dtype=np.int32) channels = ch.split("x") for sub_ch in channels[1:]: @@ -1084,9 +1161,17 @@ def _merge_nirs_data(data, merged_names): return data, merged_names -def generate_2d_layout(xy, w=.07, h=.05, pad=.02, ch_names=None, - ch_indices=None, name='ecog', bg_image=None, - normalize=True): +def generate_2d_layout( + xy, + w=0.07, + h=0.05, + pad=0.02, + ch_names=None, + ch_indices=None, + name="ecog", + bg_image=None, + normalize=True, +): """Generate a custom 2D layout from xy points. Generates a 2-D layout for plotting with plot_topo methods and @@ -1137,15 +1222,16 @@ def generate_2d_layout(xy, w=.07, h=.05, pad=.02, ch_names=None, .. versionadded:: 0.9.0 """ import matplotlib.pyplot as plt + if ch_indices is None: ch_indices = np.arange(xy.shape[0]) if ch_names is None: - ch_names = ['{}'.format(i) for i in ch_indices] + ch_names = ["{}".format(i) for i in ch_indices] if len(ch_names) != len(ch_indices): - raise ValueError('# channel names and indices must be equal') + raise ValueError("# channel names and indices must be equal") if len(ch_names) != len(xy): - raise ValueError('# channel names and xy vals must be equal') + raise ValueError("# channel names and xy vals must be equal") x, y = xy.copy().astype(float).T @@ -1159,7 +1245,7 @@ def generate_2d_layout(xy, w=.07, h=.05, pad=.02, ch_names=None, # Normalize x and y by their maxes for i_dim in [x, y]: i_dim -= i_dim.min(0) - i_dim /= (i_dim.max(0) - i_dim.min(0)) + i_dim /= i_dim.max(0) - i_dim.min(0) # Create box and pos variable box = _box_size(np.vstack([x, y]).T, padding=pad) diff --git a/mne/channels/montage.py b/mne/channels/montage.py index 178557bc520..11d48a099c8 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -22,22 +22,46 @@ from ..defaults import HEAD_SIZE_DEFAULT from .._freesurfer import get_mni_fiducials from ..viz import plot_montage -from ..transforms import (apply_trans, get_ras_to_neuromag_trans, _sph_to_cart, - _topo_to_sph, _frame_to_str, Transform, - _verbose_frames, _fit_matched_points, - _quat_to_affine, _ensure_trans) -from ..io._digitization import (_count_points_by_type, _ensure_fiducials_head, - _get_dig_eeg, _make_dig_points, write_dig, - _read_dig_fif, _format_dig_points, - _get_fid_coords, _coord_frame_const, - _get_data_as_dict_from_dig) +from ..transforms import ( + apply_trans, + get_ras_to_neuromag_trans, + _sph_to_cart, + _topo_to_sph, + _frame_to_str, + Transform, + _verbose_frames, + _fit_matched_points, + _quat_to_affine, + _ensure_trans, +) +from ..io._digitization import ( + _count_points_by_type, + _ensure_fiducials_head, + _get_dig_eeg, + _make_dig_points, + write_dig, + _read_dig_fif, + _format_dig_points, + _get_fid_coords, + _coord_frame_const, + _get_data_as_dict_from_dig, +) from ..io.meas_info import create_info from ..io.open import fiff_open from ..io.pick import pick_types, _picks_to_idx, channel_type from ..io.constants import FIFF, CHANNEL_LOC_ALIASES -from ..utils import (warn, copy_function_doc_to_method_doc, _pl, verbose, - _check_option, _validate_type, _check_fname, _on_missing, - fill_doc, _docdict) +from ..utils import ( + warn, + copy_function_doc_to_method_doc, + _pl, + verbose, + _check_option, + _validate_type, + _check_fname, + _on_missing, + fill_doc, + _docdict, +) from 
._dig_montage_utils import _read_dig_montage_egi from ._dig_montage_utils import _parse_brainvision_dig_montage @@ -51,132 +75,133 @@ class _BuiltinStandardMontage: _BUILTIN_STANDARD_MONTAGES = [ _BuiltinStandardMontage( - name='standard_1005', - description='Electrodes are named and positioned according to the ' - 'international 10-05 system (343+3 locations)', + name="standard_1005", + description="Electrodes are named and positioned according to the " + "international 10-05 system (343+3 locations)", ), _BuiltinStandardMontage( - name='standard_1020', - description='Electrodes are named and positioned according to the ' - 'international 10-20 system (94+3 locations)', + name="standard_1020", + description="Electrodes are named and positioned according to the " + "international 10-20 system (94+3 locations)", ), _BuiltinStandardMontage( - name='standard_alphabetic', - description='Electrodes are named with LETTER-NUMBER combinations ' - '(A1, B2, F4, …) (65+3 locations)', + name="standard_alphabetic", + description="Electrodes are named with LETTER-NUMBER combinations " + "(A1, B2, F4, …) (65+3 locations)", ), _BuiltinStandardMontage( - name='standard_postfixed', - description='Electrodes are named according to the international ' - '10-20 system using postfixes for intermediate positions ' - '(100+3 locations)', + name="standard_postfixed", + description="Electrodes are named according to the international " + "10-20 system using postfixes for intermediate positions " + "(100+3 locations)", ), _BuiltinStandardMontage( - name='standard_prefixed', - description='Electrodes are named according to the international ' - '10-20 system using prefixes for intermediate positions ' - '(74+3 locations)', + name="standard_prefixed", + description="Electrodes are named according to the international " + "10-20 system using prefixes for intermediate positions " + "(74+3 locations)", ), _BuiltinStandardMontage( - name='standard_primed', + name="standard_primed", description="Electrodes are named according to the international " - "10-20 system using prime marks (' and '') for " - "intermediate positions (100+3 locations)", + "10-20 system using prime marks (' and '') for " + "intermediate positions (100+3 locations)", ), _BuiltinStandardMontage( - name='biosemi16', - description='BioSemi cap with 16 electrodes (16+3 locations)', + name="biosemi16", + description="BioSemi cap with 16 electrodes (16+3 locations)", ), _BuiltinStandardMontage( - name='biosemi32', - description='BioSemi cap with 32 electrodes (32+3 locations)', + name="biosemi32", + description="BioSemi cap with 32 electrodes (32+3 locations)", ), _BuiltinStandardMontage( - name='biosemi64', - description='BioSemi cap with 64 electrodes (64+3 locations)', + name="biosemi64", + description="BioSemi cap with 64 electrodes (64+3 locations)", ), _BuiltinStandardMontage( - name='biosemi128', - description='BioSemi cap with 128 electrodes (128+3 locations)', + name="biosemi128", + description="BioSemi cap with 128 electrodes (128+3 locations)", ), _BuiltinStandardMontage( - name='biosemi160', - description='BioSemi cap with 160 electrodes (160+3 locations)', + name="biosemi160", + description="BioSemi cap with 160 electrodes (160+3 locations)", ), _BuiltinStandardMontage( - name='biosemi256', - description='BioSemi cap with 256 electrodes (256+3 locations)', + name="biosemi256", + description="BioSemi cap with 256 electrodes (256+3 locations)", ), _BuiltinStandardMontage( - name='easycap-M1', - description='EasyCap with 10-05 electrode names (74 
locations)', + name="easycap-M1", + description="EasyCap with 10-05 electrode names (74 locations)", ), _BuiltinStandardMontage( - name='easycap-M10', - description='EasyCap with numbered electrodes (61 locations)', + name="easycap-M10", + description="EasyCap with numbered electrodes (61 locations)", ), _BuiltinStandardMontage( - name='EGI_256', - description='Geodesic Sensor Net (256 locations)', + name="EGI_256", + description="Geodesic Sensor Net (256 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-32', - description='HydroCel Geodesic Sensor Net and Cz (33+3 locations)', + name="GSN-HydroCel-32", + description="HydroCel Geodesic Sensor Net and Cz (33+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-64_1.0', - description='HydroCel Geodesic Sensor Net (64+3 locations)', + name="GSN-HydroCel-64_1.0", + description="HydroCel Geodesic Sensor Net (64+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-65_1.0', - description='HydroCel Geodesic Sensor Net and Cz (65+3 locations)', + name="GSN-HydroCel-65_1.0", + description="HydroCel Geodesic Sensor Net and Cz (65+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-128', - description='HydroCel Geodesic Sensor Net (128+3 locations)', + name="GSN-HydroCel-128", + description="HydroCel Geodesic Sensor Net (128+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-129', - description='HydroCel Geodesic Sensor Net and Cz (129+3 locations)', + name="GSN-HydroCel-129", + description="HydroCel Geodesic Sensor Net and Cz (129+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-256', - description='HydroCel Geodesic Sensor Net (256+3 locations)', + name="GSN-HydroCel-256", + description="HydroCel Geodesic Sensor Net (256+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-257', - description='HydroCel Geodesic Sensor Net and Cz (257+3 locations)', + name="GSN-HydroCel-257", + description="HydroCel Geodesic Sensor Net and Cz (257+3 locations)", ), _BuiltinStandardMontage( - name='mgh60', - description='The (older) 60-channel cap used at MGH (60+3 locations)', + name="mgh60", + description="The (older) 60-channel cap used at MGH (60+3 locations)", ), _BuiltinStandardMontage( - name='mgh70', - description='The (newer) 70-channel BrainVision cap used at MGH ' - '(70+3 locations)', + name="mgh70", + description="The (newer) 70-channel BrainVision cap used at MGH " + "(70+3 locations)", ), _BuiltinStandardMontage( - name='artinis-octamon', - description='Artinis OctaMon fNIRS (8 sources, 2 detectors)', + name="artinis-octamon", + description="Artinis OctaMon fNIRS (8 sources, 2 detectors)", ), _BuiltinStandardMontage( - name='artinis-brite23', - description='Artinis Brite23 fNIRS (11 sources, 7 detectors)', + name="artinis-brite23", + description="Artinis Brite23 fNIRS (11 sources, 7 detectors)", ), _BuiltinStandardMontage( - name='brainproducts-RNP-BA-128', - description='Brain Products with 10-10 electrode names (128 channels)', - ) + name="brainproducts-RNP-BA-128", + description="Brain Products with 10-10 electrode names (128 channels)", + ), ] def _check_get_coord_frame(dig): - dig_coord_frames = sorted(set(d['coord_frame'] for d in dig)) + dig_coord_frames = sorted(set(d["coord_frame"] for d in dig)) if len(dig_coord_frames) != 1: raise RuntimeError( - 'Only a single coordinate frame in dig is supported, got ' - f'{dig_coord_frames}') + "Only a single coordinate frame in dig is supported, got " + f"{dig_coord_frames}" + ) return 
_frame_to_str[dig_coord_frames.pop()] if dig_coord_frames else None @@ -205,15 +230,20 @@ def get_builtin_montages(*, descriptions=False): If ``descriptions=True``, a list of tuples ``(name, description)``. """ if descriptions: - return [ - (m.name, m.description) for m in _BUILTIN_STANDARD_MONTAGES - ] + return [(m.name, m.description) for m in _BUILTIN_STANDARD_MONTAGES] else: return [m.name for m in _BUILTIN_STANDARD_MONTAGES] -def make_dig_montage(ch_pos=None, nasion=None, lpa=None, rpa=None, - hsp=None, hpi=None, coord_frame='unknown'): +def make_dig_montage( + ch_pos=None, + nasion=None, + lpa=None, + rpa=None, + hsp=None, + hpi=None, + coord_frame="unknown", +): r"""Make montage from arrays. Parameters @@ -263,14 +293,19 @@ def make_dig_montage(ch_pos=None, nasion=None, lpa=None, rpa=None, read_dig_localite read_dig_polhemus_isotrak """ - _validate_type(ch_pos, (dict, None), 'ch_pos') + _validate_type(ch_pos, (dict, None), "ch_pos") if ch_pos is None: ch_names = None else: ch_names = list(ch_pos) dig = _make_dig_points( - nasion=nasion, lpa=lpa, rpa=rpa, hpi=hpi, extra_points=hsp, - dig_ch_pos=ch_pos, coord_frame=coord_frame + nasion=nasion, + lpa=lpa, + rpa=rpa, + hpi=hpi, + extra_points=hsp, + dig_ch_pos=ch_pos, + coord_frame=coord_frame, ) return DigMontage(dig=dig, ch_names=ch_names) @@ -308,13 +343,13 @@ class DigMontage: def __init__(self, *, dig=None, ch_names=None): dig = list() if dig is None else dig - _validate_type(item=dig, types=list, item_name='dig') + _validate_type(item=dig, types=list, item_name="dig") ch_names = list() if ch_names is None else ch_names - n_eeg = sum([1 for d in dig if d['kind'] == FIFF.FIFFV_POINT_EEG]) + n_eeg = sum([1 for d in dig if d["kind"] == FIFF.FIFFV_POINT_EEG]) if n_eeg != len(ch_names): raise ValueError( - 'The number of EEG channels (%d) does not match the number' - ' of channel names provided (%d)' % (n_eeg, len(ch_names)) + "The number of EEG channels (%d) does not match the number" + " of channel names provided (%d)" % (n_eeg, len(ch_names)) ) self.dig = dig @@ -323,15 +358,32 @@ def __init__(self, *, dig=None, ch_names=None): def __repr__(self): """Return string representation.""" n_points = _count_points_by_type(self.dig) - return ('').format(**n_points) + return ( + "" + ).format(**n_points) @copy_function_doc_to_method_doc(plot_montage) - def plot(self, scale_factor=20, show_names=True, kind='topomap', show=True, - sphere=None, *, axes=None, verbose=None): - return plot_montage(self, scale_factor=scale_factor, - show_names=show_names, kind=kind, show=show, - sphere=sphere, axes=axes) + def plot( + self, + scale_factor=20, + show_names=True, + kind="topomap", + show=True, + sphere=None, + *, + axes=None, + verbose=None, + ): + return plot_montage( + self, + scale_factor=scale_factor, + show_names=show_names, + kind=kind, + show=show, + sphere=sphere, + axes=axes, + ) @fill_doc def rename_channels(self, mapping, allow_duplicates=False): @@ -347,9 +399,10 @@ def rename_channels(self, mapping, allow_duplicates=False): The instance. Operates in-place. 
""" from .channels import rename_channels - temp_info = create_info(list(self._get_ch_pos()), 1000., 'eeg') + + temp_info = create_info(list(self._get_ch_pos()), 1000.0, "eeg") rename_channels(temp_info, mapping, allow_duplicates) - self.ch_names = temp_info['ch_names'] + self.ch_names = temp_info["ch_names"] @verbose def save(self, fname, *, overwrite=False, verbose=None): @@ -374,20 +427,19 @@ def __iadd__(self, other): and if fiducials are present they should share the same coordinate system and location values. """ + def is_fid_defined(fid): - return not ( - fid.nasion is None and fid.lpa is None and fid.rpa is None - ) + return not (fid.nasion is None and fid.lpa is None and fid.rpa is None) # Check for none duplicated ch_names ch_names_intersection = set(self.ch_names).intersection(other.ch_names) if ch_names_intersection: - raise RuntimeError(( - "Cannot add two DigMontage objects if they contain duplicated" - " channel names. Duplicated channel(s) found: {}." - ).format( - ', '.join(['%r' % v for v in sorted(ch_names_intersection)]) - )) + raise RuntimeError( + ( + "Cannot add two DigMontage objects if they contain duplicated" + " channel names. Duplicated channel(s) found: {}." + ).format(", ".join(["%r" % v for v in sorted(ch_names_intersection)])) + ) # Check for unique matching fiducials self_fid, self_coord = _get_fid_coords(self.dig) @@ -395,20 +447,24 @@ def is_fid_defined(fid): if is_fid_defined(self_fid) and is_fid_defined(other_fid): if self_coord != other_coord: - raise RuntimeError('Cannot add two DigMontage objects if ' - 'fiducial locations are not in the same ' - 'coordinate system.') + raise RuntimeError( + "Cannot add two DigMontage objects if " + "fiducial locations are not in the same " + "coordinate system." + ) for kk in self_fid: if not np.array_equal(self_fid[kk], other_fid[kk]): - raise RuntimeError('Cannot add two DigMontage objects if ' - 'fiducial locations do not match ' - '(%s)' % kk) + raise RuntimeError( + "Cannot add two DigMontage objects if " + "fiducial locations do not match " + "(%s)" % kk + ) # keep self self.dig = _format_dig_points( - self.dig + [d for d in other.dig - if d['kind'] != FIFF.FIFFV_POINT_CARDINAL] + self.dig + + [d for d in other.dig if d["kind"] != FIFF.FIFFV_POINT_CARDINAL] ) else: self.dig = _format_dig_points(self.dig + other.dig) @@ -442,13 +498,13 @@ def __eq__(self, other): return self.dig == other.dig and self.ch_names == other.ch_names def _get_ch_pos(self): - pos = [d['r'] for d in _get_dig_eeg(self.dig)] + pos = [d["r"] for d in _get_dig_eeg(self.dig)] assert len(self.ch_names) == len(pos) return OrderedDict(zip(self.ch_names, pos)) def _get_dig_names(self): NAMED_KIND = (FIFF.FIFFV_POINT_EEG,) - is_eeg = np.array([d['kind'] in NAMED_KIND for d in self.dig]) + is_eeg = np.array([d["kind"] in NAMED_KIND for d in self.dig]) assert len(self.ch_names) == is_eeg.sum() dig_names = [None] * len(self.dig) for ch_name_idx, dig_idx in enumerate(np.where(is_eeg)[0]): @@ -509,16 +565,15 @@ def apply_trans(self, trans, verbose=None): The transformation matrix to be applied. 
%(verbose)s
         """
-        _validate_type(trans, Transform, 'trans')
-        coord_frame = self.get_positions()['coord_frame']
-        trans = _ensure_trans(trans, fro=coord_frame, to=trans['to'])
+        _validate_type(trans, Transform, "trans")
+        coord_frame = self.get_positions()["coord_frame"]
+        trans = _ensure_trans(trans, fro=coord_frame, to=trans["to"])
         for d in self.dig:
-            d['r'] = apply_trans(trans, d['r'])
-            d['coord_frame'] = trans['to']
+            d["r"] = apply_trans(trans, d["r"])
+            d["coord_frame"] = trans["to"]

     @verbose
-    def add_estimated_fiducials(self, subject, subjects_dir=None,
-                                verbose=None):
+    def add_estimated_fiducials(self, subject, subjects_dir=None, verbose=None):
         """Estimate fiducials based on FreeSurfer ``fsaverage`` subject.

         This takes a montage with the ``mri`` coordinate frame,
@@ -558,8 +613,9 @@ def add_estimated_fiducials(self, subject, subjects_dir=None,
         if montage_bunch.coord_frame != FIFF.FIFFV_COORD_MRI:
             raise RuntimeError(
                 f'Montage should be in the "mri" coordinate frame '
-                f'to use `add_estimated_fiducials`. The current coordinate '
-                f'frame is {montage_bunch.coord_frame}')
+                f"to use `add_estimated_fiducials`. The current coordinate "
+                f"frame is {montage_bunch.coord_frame}"
+            )

         # estimate LPA, nasion, RPA from FreeSurfer fsaverage
         fids_mri = list(get_mni_fiducials(subject, subjects_dir))
@@ -598,14 +654,15 @@ def add_mni_fiducials(self, subjects_dir=None, verbose=None):
         if montage_bunch.coord_frame != FIFF.FIFFV_MNE_COORD_MNI_TAL:
             raise RuntimeError(
                 f'Montage should be in the "mni_tal" coordinate frame '
-                f'to use `add_estimated_fiducials`. The current coordinate '
-                f'frame is {montage_bunch.coord_frame}')
+                f"to use `add_mni_fiducials`. The current coordinate "
+                f"frame is {montage_bunch.coord_frame}"
+            )

-        fids_mni = get_mni_fiducials('fsaverage', subjects_dir)
+        fids_mni = get_mni_fiducials("fsaverage", subjects_dir)
         for fid in fids_mni:
             # "mri" and "mni_tal" are equivalent for fsaverage
-            assert fid['coord_frame'] == FIFF.FIFFV_COORD_MRI
-            fid['coord_frame'] = FIFF.FIFFV_MNE_COORD_MNI_TAL
+            assert fid["coord_frame"] == FIFF.FIFFV_COORD_MRI
+            fid["coord_frame"] = FIFF.FIFFV_MNE_COORD_MNI_TAL
         self.dig = fids_mni + self.dig
         return self

@@ -632,7 +689,7 @@ def remove_fiducials(self, verbose=None):
         should not be changed by removing fiducials.
         """
         for d in self.dig.copy():
-            if d['kind'] == FIFF.FIFFV_POINT_CARDINAL:
+            if d["kind"] == FIFF.FIFFV_POINT_CARDINAL:
                 self.dig.remove(d)
         return self

@@ -641,7 +698,7 @@ def remove_fiducials(self, verbose=None):


 def _check_unit_and_get_scaling(unit):
-    _check_option('unit', unit, sorted(VALID_SCALES.keys()))
+    _check_option("unit", unit, sorted(VALID_SCALES.keys()))
     return VALID_SCALES[unit]


@@ -677,11 +734,11 @@ def transform_to_head(montage):
     # Get fiducial points and their coord_frame
     native_head_t = compute_native_head_t(montage)
     montage = montage.copy()  # to avoid inplace modification
-    if native_head_t['from'] != FIFF.FIFFV_COORD_HEAD:
+    if native_head_t["from"] != FIFF.FIFFV_COORD_HEAD:
         for d in montage.dig:
-            if d['coord_frame'] == native_head_t['from']:
-                d['r'] = apply_trans(native_head_t, d['r'])
-                d['coord_frame'] = FIFF.FIFFV_COORD_HEAD
+            if d["coord_frame"] == native_head_t["from"]:
+                d["r"] = apply_trans(native_head_t, d["r"])
+                d["coord_frame"] = FIFF.FIFFV_COORD_HEAD
     _ensure_fiducials_head(montage.dig)
     return montage


@@ -722,9 +779,10 @@ def read_dig_dat(fname):
        a plain text editor.
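
    Examples
    --------
    A minimal usage sketch; the file name here is hypothetical::

        >>> montage = mne.channels.read_dig_dat("electrodes.dat")  # doctest: +SKIP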
""" from ._standard_montage_utils import _check_dupes_odict - fname = _check_fname(fname, overwrite='read', must_exist=True) - with open(fname, 'r') as fid: + fname = _check_fname(fname, overwrite="read", must_exist=True) + + with open(fname, "r") as fid: lines = fid.readlines() ch_names, poss = list(), list() @@ -736,16 +794,17 @@ def read_dig_dat(fname): elif len(items) != 5: raise ValueError( "Error reading %s, line %s has unexpected number of entries:\n" - "%s" % (fname, i, line.rstrip())) + "%s" % (fname, i, line.rstrip()) + ) num = items[1] - if num == '67': + if num == "67": continue # centroid pos = np.array([float(item) for item in items[2:]]) - if num == '78': + if num == "78": nasion = pos - elif num == '76': + elif num == "76": lpa = pos - elif num == '82': + elif num == "82": rpa = pos else: ch_names.append(items[0]) @@ -782,7 +841,7 @@ def read_dig_fif(fname): read_dig_localite make_dig_montage """ - _check_fname(fname, overwrite='read', must_exist=True) + _check_fname(fname, overwrite="read", must_exist=True) # Load the dig data f, tree = fiff_open(fname)[:2] with f as fid: @@ -790,14 +849,14 @@ def read_dig_fif(fname): ch_names = [] for d in dig: - if d['kind'] == FIFF.FIFFV_POINT_EEG: - ch_names.append('EEG%03d' % d['ident']) + if d["kind"] == FIFF.FIFFV_POINT_EEG: + ch_names.append("EEG%03d" % d["ident"]) montage = DigMontage(dig=dig, ch_names=ch_names) return montage -def read_dig_hpts(fname, unit='mm'): +def read_dig_hpts(fname, unit="mm"): """Read historical ``.hpts`` MNE-C files. Parameters @@ -867,26 +926,27 @@ def read_dig_hpts(fname, unit='mm'): """ from ._standard_montage_utils import _str_names, _str - fname = _check_fname(fname, overwrite='read', must_exist=True) + fname = _check_fname(fname, overwrite="read", must_exist=True) _scale = _check_unit_and_get_scaling(unit) - out = np.genfromtxt(fname, comments='#', - dtype=(_str, _str, 'f8', 'f8', 'f8')) - kind, label = _str_names(out['f0']), _str_names(out['f1']) + out = np.genfromtxt(fname, comments="#", dtype=(_str, _str, "f8", "f8", "f8")) + kind, label = _str_names(out["f0"]), _str_names(out["f1"]) kind = [k.lower() for k in kind] - xyz = np.array([out['f%d' % ii] for ii in range(2, 5)]).T + xyz = np.array([out["f%d" % ii] for ii in range(2, 5)]).T xyz *= _scale del _scale - fid_idx_to_label = {'1': 'lpa', '2': 'nasion', '3': 'rpa'} - fid = {fid_idx_to_label[label[ii]]: this_xyz - for ii, this_xyz in enumerate(xyz) if kind[ii] == 'cardinal'} - ch_pos = {label[ii]: this_xyz - for ii, this_xyz in enumerate(xyz) if kind[ii] == 'eeg'} - hpi = np.array([this_xyz for ii, this_xyz in enumerate(xyz) - if kind[ii] == 'hpi']) + fid_idx_to_label = {"1": "lpa", "2": "nasion", "3": "rpa"} + fid = { + fid_idx_to_label[label[ii]]: this_xyz + for ii, this_xyz in enumerate(xyz) + if kind[ii] == "cardinal" + } + ch_pos = { + label[ii]: this_xyz for ii, this_xyz in enumerate(xyz) if kind[ii] == "eeg" + } + hpi = np.array([this_xyz for ii, this_xyz in enumerate(xyz) if kind[ii] == "hpi"]) hpi.shape = (-1, 3) # in case it's empty - hsp = np.array([this_xyz for ii, this_xyz in enumerate(xyz) - if kind[ii] == 'extra']) + hsp = np.array([this_xyz for ii, this_xyz in enumerate(xyz) if kind[ii] == "extra"]) hsp.shape = (-1, 3) # in case it's empty return make_dig_montage(ch_pos=ch_pos, **fid, hpi=hpi, hsp=hsp) @@ -915,12 +975,10 @@ def read_dig_egi(fname): read_dig_polhemus_isotrak make_dig_montage """ - _check_fname(fname, overwrite='read', must_exist=True) + _check_fname(fname, overwrite="read", must_exist=True) data = 
_read_dig_montage_egi(
-        fname=fname,
-        _scaling=1.,
-        _all_data_kwargs_are_none=True
+        fname=fname, _scaling=1.0, _all_data_kwargs_are_none=True
     )

     return make_dig_montage(**data)
@@ -950,7 +1008,7 @@ def read_dig_captrak(fname):
     read_dig_polhemus_isotrak
     make_dig_montage
     """
-    _check_fname(fname, overwrite='read', must_exist=True)
+    _check_fname(fname, overwrite="read", must_exist=True)
     data = _parse_brainvision_dig_montage(fname, scale=1e-3)

     return make_dig_montage(**data)
@@ -1004,7 +1062,7 @@ def read_dig_localite(fname, nasion=None, lpa=None, rpa=None):


 def _get_montage_in_head(montage):
-    coords = set([d['coord_frame'] for d in montage.dig])
+    coords = set([d["coord_frame"] for d in montage.dig])
     montage = montage.copy()
     if len(coords) == 1 and coords.pop() == FIFF.FIFFV_COORD_HEAD:
         _ensure_fiducials_head(montage.dig)
@@ -1023,33 +1081,33 @@ def _set_montage_fnirs(info, montage):
     place.
     """
     from ..preprocessing.nirs import _validate_nirs_info
+
     # Validate that the fNIRS info is correctly formatted
     picks = _validate_nirs_info(info)

     # Modify info['chs'][#]['loc'] in place
     num_fiducials = len(montage.dig) - len(montage.ch_names)
     for ch_idx in picks:
-        ch = info['chs'][ch_idx]['ch_name']
-        source, detector = ch.split(' ')[0].split('_')
-        source_pos = montage.dig[montage.ch_names.index(source)
-                                 + num_fiducials]['r']
-        detector_pos = montage.dig[montage.ch_names.index(detector)
-                                   + num_fiducials]['r']
-
-        info['chs'][ch_idx]['loc'][3:6] = source_pos
-        info['chs'][ch_idx]['loc'][6:9] = detector_pos
+        ch = info["chs"][ch_idx]["ch_name"]
+        source, detector = ch.split(" ")[0].split("_")
+        source_pos = montage.dig[montage.ch_names.index(source) + num_fiducials]["r"]
+        detector_pos = montage.dig[montage.ch_names.index(detector) + num_fiducials][
+            "r"
+        ]
+
+        info["chs"][ch_idx]["loc"][3:6] = source_pos
+        info["chs"][ch_idx]["loc"][6:9] = detector_pos
         midpoint = (source_pos + detector_pos) / 2
-        info['chs'][ch_idx]['loc'][:3] = midpoint
-        info['chs'][ch_idx]['coord_frame'] = FIFF.FIFFV_COORD_HEAD
+        info["chs"][ch_idx]["loc"][:3] = midpoint
+        info["chs"][ch_idx]["coord_frame"] = FIFF.FIFFV_COORD_HEAD

     # Modify info['dig'] in place
     with info._unlock():
-        info['dig'] = montage.dig
+        info["dig"] = montage.dig


 @fill_doc
-def _set_montage(info, montage, match_case=True, match_alias=False,
-                 on_missing='raise'):
+def _set_montage(info, montage, match_case=True, match_alias=False, on_missing="raise"):
     """Apply montage to data.

     With a DigMontage, this function will replace the digitizer info with
@@ -1070,19 +1128,20 @@ def _set_montage(info, montage, match_case=True, match_alias=False,
     -----
     This function will change the info variable in place.
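
    A sketch of the public entry point that ultimately calls this helper
    (the montage name is one of the built-in standards)::

        >>> raw.set_montage("standard_1020")  # doctest: +SKIP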
""" - _validate_type(montage, (DigMontage, None, str), 'montage') + _validate_type(montage, (DigMontage, None, str), "montage") if montage is None: # Next line modifies info['dig'] in place with info._unlock(): - info['dig'] = None - for ch in info['chs']: + info["dig"] = None + for ch in info["chs"]: # Next line modifies info['chs'][#]['loc'] in place - ch['loc'] = np.full(12, np.nan) + ch["loc"] = np.full(12, np.nan) return if isinstance(montage, str): # load builtin montage _check_option( - parameter='montage', value=montage, - allowed_values=[m.name for m in _BUILTIN_STANDARD_MONTAGES] + parameter="montage", + value=montage, + allowed_values=[m.name for m in _BUILTIN_STANDARD_MONTAGES], ) montage = make_standard_montage(montage) @@ -1100,66 +1159,72 @@ def _backcompat_value(pos, ref_pos): # only get the eeg, seeg, dbs, ecog channels picks = pick_types( - info, meg=False, eeg=True, seeg=True, dbs=True, ecog=True, - exclude=()) - non_picks = np.setdiff1d(np.arange(info['nchan']), picks) + info, meg=False, eeg=True, seeg=True, dbs=True, ecog=True, exclude=() + ) + non_picks = np.setdiff1d(np.arange(info["nchan"]), picks) # get the reference position from the loc[3:6] - chs = [info['chs'][ii] for ii in picks] - non_names = [info['chs'][ii]['ch_name'] for ii in non_picks] + chs = [info["chs"][ii] for ii in picks] + non_names = [info["chs"][ii]["ch_name"] for ii in non_picks] del picks - ref_pos = [ch['loc'][3:6] for ch in chs] + ref_pos = [ch["loc"][3:6] for ch in chs] # keep reference location from EEG-like channels if they # already exist and are all the same. custom_eeg_ref_dig = False # Note: ref position is an empty list for fieldtrip data if ref_pos: - if all([np.equal(ref_pos[0], pos).all() for pos in ref_pos]) \ - and not np.equal(ref_pos[0], [0, 0, 0]).all(): + if ( + all([np.equal(ref_pos[0], pos).all() for pos in ref_pos]) + and not np.equal(ref_pos[0], [0, 0, 0]).all() + ): eeg_ref_pos = ref_pos[0] # since we have an EEG reference position, we have # to add it into the info['dig'] as EEG000 custom_eeg_ref_dig = True if not custom_eeg_ref_dig: - refs = set(ch_pos) & {'EEG000', 'REF'} + refs = set(ch_pos) & {"EEG000", "REF"} assert len(refs) <= 1 eeg_ref_pos = np.zeros(3) if not refs else ch_pos.pop(refs.pop()) # This raises based on info being subset/superset of montage - info_names = [ch['ch_name'] for ch in chs] + info_names = [ch["ch_name"] for ch in chs] dig_names = mnt_head._get_dig_names() - ref_names = [None, 'EEG000', 'REF'] + ref_names = [None, "EEG000", "REF"] if match_case: info_names_use = info_names dig_names_use = dig_names non_names_use = non_names else: - ch_pos_use = OrderedDict( - (name.lower(), pos) for name, pos in ch_pos.items()) + ch_pos_use = OrderedDict((name.lower(), pos) for name, pos in ch_pos.items()) info_names_use = [name.lower() for name in info_names] - dig_names_use = [name.lower() if name is not None else name - for name in dig_names] + dig_names_use = [ + name.lower() if name is not None else name for name in dig_names + ] non_names_use = [name.lower() for name in non_names] - ref_names = [name.lower() if name is not None else name - for name in ref_names] + ref_names = [name.lower() if name is not None else name for name in ref_names] n_dup = len(ch_pos) - len(ch_pos_use) if n_dup: - raise ValueError('Cannot use match_case=False as %s montage ' - 'name(s) require case sensitivity' % n_dup) + raise ValueError( + "Cannot use match_case=False as %s montage " + "name(s) require case sensitivity" % n_dup + ) n_dup = len(info_names_use) - 
len(set(info_names_use)) if n_dup: - raise ValueError('Cannot use match_case=False as %s channel ' - 'name(s) require case sensitivity' % n_dup) + raise ValueError( + "Cannot use match_case=False as %s channel " + "name(s) require case sensitivity" % n_dup + ) ch_pos = ch_pos_use del ch_pos_use del dig_names # use lookup table to match unrecognized channel names to known aliases if match_alias: - alias_dict = (match_alias if isinstance(match_alias, dict) else - CHANNEL_LOC_ALIASES) + alias_dict = ( + match_alias if isinstance(match_alias, dict) else CHANNEL_LOC_ALIASES + ) if not match_case: alias_dict = { ch_name.lower(): ch_alias.lower() @@ -1168,16 +1233,11 @@ def _backcompat_value(pos, ref_pos): # excluded ch_alias not in info, to prevent unnecessary mapping and # warning messages based on aliases. - alias_dict = { - ch_name: ch_alias - for ch_name, ch_alias in alias_dict.items() - } + alias_dict = {ch_name: ch_alias for ch_name, ch_alias in alias_dict.items()} info_names_use = [ alias_dict.get(ch_name, ch_name) for ch_name in info_names_use ] - non_names_use = [ - alias_dict.get(ch_name, ch_name) for ch_name in non_names_use - ] + non_names_use = [alias_dict.get(ch_name, ch_name) for ch_name in non_names_use] # warn user if there is not a full overlap of montage with info_chs missing = np.where([use not in ch_pos for use in info_names_use])[0] @@ -1208,42 +1268,47 @@ def _backcompat_value(pos, ref_pos): # will have entries "D1" and "S1". extra = np.where([non in ch_pos for non in non_names_use])[0] if len(extra): - types = '/'.join(sorted(set( - channel_type(info, non_picks[ii]) for ii in extra))) + types = "/".join(sorted(set(channel_type(info, non_picks[ii]) for ii in extra))) names = [non_names[ii] for ii in extra] - warn(f'Not setting position{_pl(extra)} of {len(extra)} {types} ' - f'channel{_pl(extra)} found in montage:\n{names}\n' - 'Consider setting the channel types to be of ' - f'{_docdict["montage_types"]} ' - 'using inst.set_channel_types before calling inst.set_montage, ' - 'or omit these channels when creating your montage.') + warn( + f"Not setting position{_pl(extra)} of {len(extra)} {types} " + f"channel{_pl(extra)} found in montage:\n{names}\n" + "Consider setting the channel types to be of " + f'{_docdict["montage_types"]} ' + "using inst.set_channel_types before calling inst.set_montage, " + "or omit these channels when creating your montage." + ) for ch, use in zip(chs, info_names_use): # Next line modifies info['chs'][#]['loc'] in place if use in ch_pos: - ch['loc'][:6] = _backcompat_value(ch_pos[use], eeg_ref_pos) - ch['coord_frame'] = FIFF.FIFFV_COORD_HEAD + ch["loc"][:6] = _backcompat_value(ch_pos[use], eeg_ref_pos) + ch["coord_frame"] = FIFF.FIFFV_COORD_HEAD del ch_pos # XXX this is probably wrong as it uses the order from the montage # rather than the order of our info['ch_names'] ... 
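    # Keep only the montage dig points whose (possibly lower-cased) names
    # match a picked channel in ``info`` or one of the accepted reference names.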
digpoints = [ - mnt_head.dig[ii] for ii, name in enumerate(dig_names_use) - if name in (info_names_use + ref_names)] + mnt_head.dig[ii] + for ii, name in enumerate(dig_names_use) + if name in (info_names_use + ref_names) + ] # get a copy of the old dig - if info['dig'] is not None: - old_dig = info['dig'].copy() + if info["dig"] is not None: + old_dig = info["dig"].copy() else: old_dig = [] # determine if needed to add an extra EEG REF DigPoint if custom_eeg_ref_dig: # ref_name = 'EEG000' if match_case else 'eeg000' - ref_dig_dict = {'kind': FIFF.FIFFV_POINT_EEG, - 'r': eeg_ref_pos, - 'ident': 0, - 'coord_frame': info['dig'].pop()['coord_frame']} + ref_dig_dict = { + "kind": FIFF.FIFFV_POINT_EEG, + "r": eeg_ref_pos, + "ident": 0, + "coord_frame": info["dig"].pop()["coord_frame"], + } ref_dig_point = _format_dig_points([ref_dig_dict])[0] # only append the reference dig point if it was already # in the old dig @@ -1251,7 +1316,7 @@ def _backcompat_value(pos, ref_pos): digpoints.append(ref_dig_point) # Next line modifies info['dig'] in place with info._unlock(): - info['dig'] = _format_dig_points(digpoints, enforce_order=True) + info["dig"] = _format_dig_points(digpoints, enforce_order=True) del digpoints # TODO: Ideally we would have a check like this, but read_raw_bids for ECoG @@ -1267,7 +1332,7 @@ def _backcompat_value(pos, ref_pos): # 'not happen. Please contact MNE-Python developers.') # Handle fNIRS with source, detector and channel - fnirs_picks = _picks_to_idx(info, 'fnirs', allow_empty=True) + fnirs_picks = _picks_to_idx(info, "fnirs", allow_empty=True) if len(fnirs_picks) > 0: _set_montage_fnirs(info, mnt_head) @@ -1292,13 +1357,16 @@ def _read_isotrak_elp_points(fname): with open(fname) as fid: file_str = fid.read() - points_str = [m.groups() for m in re.finditer(coord_pattern, file_str, - re.MULTILINE)] + points_str = [ + m.groups() for m in re.finditer(coord_pattern, file_str, re.MULTILINE) + ] points = np.array(points_str, dtype=float) return { - 'nasion': points[0], 'lpa': points[1], 'rpa': points[2], - 'points': points[3:] + "nasion": points[0], + "lpa": points[1], + "rpa": points[2], + "points": points[3:], } @@ -1316,12 +1384,13 @@ def _read_isotrak_hsp_points(fname): The dictionary containing locations for 'nasion', 'lpa', 'rpa' and 'points'. """ + def get_hsp_fiducial(line): - return np.fromstring(line.replace('%F', ''), dtype=float, sep='\t') + return np.fromstring(line.replace("%F", ""), dtype=float, sep="\t") with open(fname) as ff: for line in ff: - if 'position of fiducials' in line.lower(): + if "position of fiducials" in line.lower(): break nasion = get_hsp_fiducial(ff.readline()) @@ -1331,20 +1400,20 @@ def get_hsp_fiducial(line): _ = ff.readline() line = ff.readline() if line: - n_points, n_cols = np.fromstring(line, dtype=int, sep='\t') + n_points, n_cols = np.fromstring(line, dtype=int, sep="\t") points = np.fromstring( - string=ff.read(), dtype=float, sep='\t', + string=ff.read(), + dtype=float, + sep="\t", ).reshape(-1, n_cols) assert points.shape[0] == n_points else: points = np.empty((0, 3)) - return { - 'nasion': nasion, 'lpa': lpa, 'rpa': rpa, 'points': points - } + return {"nasion": nasion, "lpa": lpa, "rpa": rpa, "points": points} -def read_dig_polhemus_isotrak(fname, ch_names=None, unit='m'): +def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"): """Read Polhemus digitizer data from a file. 
Parameters @@ -1377,14 +1446,14 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit='m'): read_dig_fif read_dig_localite """ - VALID_FILE_EXT = ('.hsp', '.elp', '.eeg') + VALID_FILE_EXT = (".hsp", ".elp", ".eeg") fname = str(_check_fname(fname, overwrite="read", must_exist=True)) _scale = _check_unit_and_get_scaling(unit) _, ext = op.splitext(fname) - _check_option('fname', ext, VALID_FILE_EXT) + _check_option("fname", ext, VALID_FILE_EXT) - if ext == '.elp': + if ext == ".elp": data = _read_isotrak_elp_points(fname) else: # Default case we read points as hsp since is the most likely scenario @@ -1396,39 +1465,40 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit='m'): pass # noqa if ch_names is None: - keyword = 'hpi' if ext == '.elp' else 'hsp' - data[keyword] = data.pop('points') + keyword = "hpi" if ext == ".elp" else "hsp" + data[keyword] = data.pop("points") else: - points = data.pop('points') + points = data.pop("points") if points.shape[0] == len(ch_names): - data['ch_pos'] = OrderedDict(zip(ch_names, points)) + data["ch_pos"] = OrderedDict(zip(ch_names, points)) else: - raise ValueError(( - "Length of ``ch_names`` does not match the number of points" - " in {fname}. Expected ``ch_names`` length {n_points:d}," - " given {n_chnames:d}" - ).format( - fname=fname, n_points=points.shape[0], n_chnames=len(ch_names) - )) + raise ValueError( + ( + "Length of ``ch_names`` does not match the number of points" + " in {fname}. Expected ``ch_names`` length {n_points:d}," + " given {n_chnames:d}" + ).format(fname=fname, n_points=points.shape[0], n_chnames=len(ch_names)) + ) return make_dig_montage(**data) def _is_polhemus_fastscan(fname): - header = '' - with open(fname, 'r') as fid: + header = "" + with open(fname, "r") as fid: for line in fid: - if not line.startswith('%'): + if not line.startswith("%"): break header += line - return 'FastSCAN' in header + return "FastSCAN" in header @verbose -def read_polhemus_fastscan(fname, unit='mm', on_header_missing='raise', *, - verbose=None): +def read_polhemus_fastscan( + fname, unit="mm", on_header_missing="raise", *, verbose=None +): """Read Polhemus FastSCAN digitizer data from a ``.txt`` file. Parameters @@ -1451,18 +1521,18 @@ def read_polhemus_fastscan(fname, unit='mm', on_header_missing='raise', *, read_dig_polhemus_isotrak make_dig_montage """ - VALID_FILE_EXT = ['.txt'] + VALID_FILE_EXT = [".txt"] fname = str(_check_fname(fname, overwrite="read", must_exist=True)) _scale = _check_unit_and_get_scaling(unit) _, ext = op.splitext(fname) - _check_option('fname', ext, VALID_FILE_EXT) + _check_option("fname", ext, VALID_FILE_EXT) if not _is_polhemus_fastscan(fname): msg = "%s does not contain a valid Polhemus FastSCAN header" % fname _on_missing(on_header_missing, msg) - points = _scale * np.loadtxt(fname, comments='%', ndmin=2) + points = _scale * np.loadtxt(fname, comments="%", ndmin=2) _check_dig_shape(points) return points @@ -1521,66 +1591,75 @@ def read_custom_montage(fname, head_size=HEAD_SIZE_DEFAULT, coord_frame=None): :func:`make_dig_montage` that takes arrays as input. 
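
    Examples
    --------
    A minimal usage sketch; the file name here is hypothetical::

        >>> montage = mne.channels.read_custom_montage("electrodes.sfp")  # doctest: +SKIP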
""" from ._standard_montage_utils import ( - _read_theta_phi_in_degrees, _read_sfp, _read_csd, _read_elc, - _read_elp_besa, _read_brainvision, _read_xyz + _read_theta_phi_in_degrees, + _read_sfp, + _read_csd, + _read_elc, + _read_elp_besa, + _read_brainvision, + _read_xyz, ) + SUPPORTED_FILE_EXT = { - 'eeglab': ('.loc', '.locs', '.eloc', ), - 'hydrocel': ('.sfp', ), - 'matlab': ('.csd', ), - 'asa electrode': ('.elc', ), - 'generic (Theta-phi in degrees)': ('.txt', ), - 'standard BESA spherical': ('.elp', ), # NB: not same as polhemus elp - 'brainvision': ('.bvef', ), - 'xyz': ('.csv', '.tsv', '.xyz'), + "eeglab": ( + ".loc", + ".locs", + ".eloc", + ), + "hydrocel": (".sfp",), + "matlab": (".csd",), + "asa electrode": (".elc",), + "generic (Theta-phi in degrees)": (".txt",), + "standard BESA spherical": (".elp",), # NB: not same as polhemus elp + "brainvision": (".bvef",), + "xyz": (".csv", ".tsv", ".xyz"), } fname = str(_check_fname(fname, overwrite="read", must_exist=True)) _, ext = op.splitext(fname) - _check_option('fname', ext, list(sum(SUPPORTED_FILE_EXT.values(), ()))) + _check_option("fname", ext, list(sum(SUPPORTED_FILE_EXT.values(), ()))) - if ext in SUPPORTED_FILE_EXT['eeglab']: + if ext in SUPPORTED_FILE_EXT["eeglab"]: if head_size is None: - raise ValueError( - "``head_size`` cannot be None for '{}'".format(ext)) + raise ValueError("``head_size`` cannot be None for '{}'".format(ext)) ch_names, pos = _read_eeglab_locations(fname) scale = head_size / np.median(np.linalg.norm(pos, axis=-1)) pos *= scale montage = make_dig_montage( ch_pos=OrderedDict(zip(ch_names, pos)), - coord_frame='head', + coord_frame="head", ) - elif ext in SUPPORTED_FILE_EXT['hydrocel']: + elif ext in SUPPORTED_FILE_EXT["hydrocel"]: montage = _read_sfp(fname, head_size=head_size) - elif ext in SUPPORTED_FILE_EXT['matlab']: + elif ext in SUPPORTED_FILE_EXT["matlab"]: montage = _read_csd(fname, head_size=head_size) - elif ext in SUPPORTED_FILE_EXT['asa electrode']: + elif ext in SUPPORTED_FILE_EXT["asa electrode"]: montage = _read_elc(fname, head_size=head_size) - elif ext in SUPPORTED_FILE_EXT['generic (Theta-phi in degrees)']: + elif ext in SUPPORTED_FILE_EXT["generic (Theta-phi in degrees)"]: if head_size is None: - raise ValueError( - "``head_size`` cannot be None for '{}'".format(ext)) - montage = _read_theta_phi_in_degrees(fname, head_size=head_size, - fid_names=('Nz', 'LPA', 'RPA')) + raise ValueError("``head_size`` cannot be None for '{}'".format(ext)) + montage = _read_theta_phi_in_degrees( + fname, head_size=head_size, fid_names=("Nz", "LPA", "RPA") + ) - elif ext in SUPPORTED_FILE_EXT['standard BESA spherical']: + elif ext in SUPPORTED_FILE_EXT["standard BESA spherical"]: montage = _read_elp_besa(fname, head_size) - elif ext in SUPPORTED_FILE_EXT['brainvision']: + elif ext in SUPPORTED_FILE_EXT["brainvision"]: montage = _read_brainvision(fname, head_size) - elif ext in SUPPORTED_FILE_EXT['xyz']: + elif ext in SUPPORTED_FILE_EXT["xyz"]: montage = _read_xyz(fname) if coord_frame is not None: coord_frame = _coord_frame_const(coord_frame) for d in montage.dig: - d['coord_frame'] = coord_frame + d["coord_frame"] = coord_frame return montage @@ -1602,31 +1681,49 @@ def compute_dev_head_t(montage): """ _, coord_frame = _get_fid_coords(montage.dig) if coord_frame != FIFF.FIFFV_COORD_HEAD: - raise ValueError('montage should have been set to head coordinate ' - 'system with transform_to_head function.') + raise ValueError( + "montage should have been set to head coordinate " + "system with transform_to_head 
function." + ) hpi_head = np.array( - [d['r'] for d in montage.dig - if (d['kind'] == FIFF.FIFFV_POINT_HPI and - d['coord_frame'] == FIFF.FIFFV_COORD_HEAD)], float) + [ + d["r"] + for d in montage.dig + if ( + d["kind"] == FIFF.FIFFV_POINT_HPI + and d["coord_frame"] == FIFF.FIFFV_COORD_HEAD + ) + ], + float, + ) hpi_dev = np.array( - [d['r'] for d in montage.dig - if (d['kind'] == FIFF.FIFFV_POINT_HPI and - d['coord_frame'] == FIFF.FIFFV_COORD_DEVICE)], float) + [ + d["r"] + for d in montage.dig + if ( + d["kind"] == FIFF.FIFFV_POINT_HPI + and d["coord_frame"] == FIFF.FIFFV_COORD_DEVICE + ) + ], + float, + ) if not (len(hpi_head) == len(hpi_dev) and len(hpi_dev) > 0): - raise ValueError(( - "To compute Device-to-Head transformation, the same number of HPI" - " points in device and head coordinates is required. (Got {dev}" - " points in device and {head} points in head coordinate systems)" - ).format(dev=len(hpi_dev), head=len(hpi_head))) + raise ValueError( + ( + "To compute Device-to-Head transformation, the same number of HPI" + " points in device and head coordinates is required. (Got {dev}" + " points in device and {head} points in head coordinate systems)" + ).format(dev=len(hpi_dev), head=len(hpi_head)) + ) trans = _quat_to_affine(_fit_matched_points(hpi_dev, hpi_head)[0]) - return Transform(fro='meg', to='head', trans=trans) + return Transform(fro="meg", to="head", trans=trans) @verbose -def compute_native_head_t(montage, *, on_missing='warn', verbose=None): +def compute_native_head_t(montage, *, on_missing="warn", verbose=None): """Compute the native-to-head transformation for a montage. This uses the fiducials in the native space to transform to compute the @@ -1653,23 +1750,25 @@ def compute_native_head_t(montage, *, on_missing='warn', verbose=None): if coord_frame == FIFF.FIFFV_COORD_HEAD: native_head_t = np.eye(4) else: - fid_keys = ('nasion', 'lpa', 'rpa') + fid_keys = ("nasion", "lpa", "rpa") for key in fid_keys: this_coord = fid_coords[key] if this_coord is None or np.any(np.isnan(this_coord)): msg = ( - f'Fiducial point {key} not found, assuming identity ' - f'{_verbose_frames[coord_frame]} to head transformation') + f"Fiducial point {key} not found, assuming identity " + f"{_verbose_frames[coord_frame]} to head transformation" + ) _on_missing(on_missing, msg, error_klass=RuntimeError) native_head_t = np.eye(4) break else: native_head_t = get_ras_to_neuromag_trans( - *[fid_coords[key] for key in fid_keys]) - return Transform(coord_frame, 'head', native_head_t) + *[fid_coords[key] for key in fid_keys] + ) + return Transform(coord_frame, "head", native_head_t) -def make_standard_montage(kind, head_size='auto'): +def make_standard_montage(kind, head_size="auto"): """Read a generic (built-in) standard montage that ships with MNE-Python. Parameters @@ -1708,15 +1807,17 @@ def make_standard_montage(kind, head_size='auto'): .. 
versionadded:: 0.19.0 """ from ._standard_montage_utils import standard_montage_look_up_table - _validate_type(kind, str, 'kind') + + _validate_type(kind, str, "kind") _check_option( - parameter='kind', value=kind, - allowed_values=[m.name for m in _BUILTIN_STANDARD_MONTAGES] + parameter="kind", + value=kind, + allowed_values=[m.name for m in _BUILTIN_STANDARD_MONTAGES], ) - _validate_type(head_size, ('numeric', str, None), 'head_size') + _validate_type(head_size, ("numeric", str, None), "head_size") if isinstance(head_size, str): - _check_option('head_size', head_size, ('auto',), extra='when str') - if kind.startswith(('standard', 'mgh', 'artinis')): + _check_option("head_size", head_size, ("auto",), extra="when str") + if kind.startswith(("standard", "mgh", "artinis")): head_size = None else: head_size = HEAD_SIZE_DEFAULT @@ -1724,7 +1825,6 @@ def make_standard_montage(kind, head_size='auto'): def _check_dig_shape(pts): - _validate_type(pts, np.ndarray, 'points') + _validate_type(pts, np.ndarray, "points") if pts.ndim != 2 or pts.shape[-1] != 3: - raise ValueError( - f'Points must be of shape (n, 3) instead of {pts.shape}') + raise ValueError(f"Points must be of shape (n, 3) instead of {pts.shape}") diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index 04f07d84ec3..2b719b7e3af 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -13,19 +13,42 @@ from scipy.io import savemat from numpy.testing import assert_array_equal, assert_equal, assert_allclose -from mne.channels import (rename_channels, read_ch_adjacency, combine_channels, - find_ch_adjacency, make_1020_channel_selections, - read_custom_montage, equalize_channels, - get_builtin_ch_adjacencies) +from mne.channels import ( + rename_channels, + read_ch_adjacency, + combine_channels, + find_ch_adjacency, + make_1020_channel_selections, + read_custom_montage, + equalize_channels, + get_builtin_ch_adjacencies, +) from mne.channels.channels import ( - _ch_neighbor_adjacency, _compute_ch_adjacency, - _BUILTIN_CHANNEL_ADJACENCIES, _BuiltinChannelAdjacency + _ch_neighbor_adjacency, + _compute_ch_adjacency, + _BUILTIN_CHANNEL_ADJACENCIES, + _BuiltinChannelAdjacency, +) +from mne.io import ( + read_info, + read_raw_fif, + read_raw_ctf, + read_raw_bti, + read_raw_eeglab, + read_raw_kit, + RawArray, ) -from mne.io import (read_info, read_raw_fif, read_raw_ctf, read_raw_bti, - read_raw_eeglab, read_raw_kit, RawArray) from mne.io.constants import FIFF -from mne import (pick_types, pick_channels, EpochsArray, EvokedArray, - make_ad_hoc_cov, create_info, read_events, Epochs) +from mne import ( + pick_types, + pick_channels, + EpochsArray, + EvokedArray, + make_ad_hoc_cov, + create_info, + read_events, + Epochs, +) from mne.datasets import testing from mne.utils import requires_pandas, requires_version from mne.parallel import parallel_func @@ -38,8 +61,8 @@ testing_path = testing.data_path(download=False) -@pytest.mark.parametrize('preload', (True, False)) -@pytest.mark.parametrize('proj', (True, False)) +@pytest.mark.parametrize("preload", (True, False)) +@pytest.mark.parametrize("proj", (True, False)) def test_reorder_channels(preload, proj): """Test reordering of channels.""" raw = read_raw_fif(raw_fname).crop(0, 0.1).del_proj() @@ -49,7 +72,7 @@ def test_reorder_channels(preload, proj): raw.load_data() # with .reorder_channels if proj and not preload: - with pytest.raises(RuntimeError, match='load data'): + with pytest.raises(RuntimeError, match="load data"): 
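            # with projectors present, reordering requires preloaded data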
raw.copy().reorder_channels(raw.ch_names[::-1]) return raw_new = raw.copy().reorder_channels(raw.ch_names[::-1]) @@ -63,7 +86,7 @@ def test_reorder_channels(preload, proj): raw_new.reorder_channels(raw_new.ch_names[::-1][1:-1]) raw.drop_channels(raw.ch_names[:1] + raw.ch_names[-1:]) assert_array_equal(raw[:][0], raw_new[:][0]) - with pytest.raises(ValueError, match='repeated'): + with pytest.raises(ValueError, match="repeated"): raw.reorder_channels(raw.ch_names[:1] + raw.ch_names[:1]) # and with .pick reord = [1, 0] + list(range(2, len(raw.ch_names))) @@ -77,41 +100,41 @@ def test_rename_channels(): info = read_info(raw_fname) # Error Tests # Test channel name exists in ch_names - mapping = {'EEG 160': 'EEG060'} + mapping = {"EEG 160": "EEG060"} pytest.raises(ValueError, rename_channels, info, mapping) # Test improper mapping configuration - mapping = {'MEG 2641': 1.0} + mapping = {"MEG 2641": 1.0} pytest.raises(TypeError, rename_channels, info, mapping) # Test non-unique mapping configuration - mapping = {'MEG 2641': 'MEG 2642'} + mapping = {"MEG 2641": "MEG 2642"} pytest.raises(ValueError, rename_channels, info, mapping) # Test bad input - pytest.raises(ValueError, rename_channels, info, 1.) - pytest.raises(ValueError, rename_channels, info, 1.) + pytest.raises(ValueError, rename_channels, info, 1.0) + pytest.raises(ValueError, rename_channels, info, 1.0) # Test successful changes # Test ch_name and ch_names are changed info2 = deepcopy(info) # for consistency at the start of each test - info2['bads'] = ['EEG 060', 'EOG 061'] - mapping = {'EEG 060': 'EEG060', 'EOG 061': 'EOG061'} + info2["bads"] = ["EEG 060", "EOG 061"] + mapping = {"EEG 060": "EEG060", "EOG 061": "EOG061"} rename_channels(info2, mapping) - assert info2['chs'][374]['ch_name'] == 'EEG060' - assert info2['ch_names'][374] == 'EEG060' - assert info2['chs'][375]['ch_name'] == 'EOG061' - assert info2['ch_names'][375] == 'EOG061' - assert_array_equal(['EEG060', 'EOG061'], info2['bads']) + assert info2["chs"][374]["ch_name"] == "EEG060" + assert info2["ch_names"][374] == "EEG060" + assert info2["chs"][375]["ch_name"] == "EOG061" + assert info2["ch_names"][375] == "EOG061" + assert_array_equal(["EEG060", "EOG061"], info2["bads"]) info2 = deepcopy(info) - rename_channels(info2, lambda x: x.replace(' ', '')) - assert info2['chs'][373]['ch_name'] == 'EEG059' + rename_channels(info2, lambda x: x.replace(" ", "")) + assert info2["chs"][373]["ch_name"] == "EEG059" info2 = deepcopy(info) - info2['bads'] = ['EEG 060', 'EEG 060'] + info2["bads"] = ["EEG 060", "EEG 060"] rename_channels(info2, mapping) - assert_array_equal(['EEG060', 'EEG060'], info2['bads']) + assert_array_equal(["EEG060", "EEG060"], info2["bads"]) # test that keys in Raw._orig_units will be renamed, too raw = read_raw_fif(raw_fname).crop(0, 0.1) - old, new = 'EEG 060', 'New' - raw._orig_units = {old: 'V'} + old, new = "EEG 060", "New" + raw._orig_units = {old: "V"} raw.rename_channels({old: new}) assert old not in raw._orig_units @@ -123,74 +146,81 @@ def test_set_channel_types(): raw = read_raw_fif(raw_fname) # Error Tests # Test channel name exists in ch_names - mapping = {'EEG 160': 'EEG060'} + mapping = {"EEG 160": "EEG060"} with pytest.raises(ValueError, match=r"name \(EEG 160\) doesn't exist"): raw.set_channel_types(mapping) # Test change to illegal channel type - mapping = {'EOG 061': 'xxx'} - with pytest.raises(ValueError, match='cannot change to this channel type'): + mapping = {"EOG 061": "xxx"} + with pytest.raises(ValueError, match="cannot change to this 
channel type"): raw.set_channel_types(mapping) # Test changing type if in proj - mapping = {'EEG 057': 'dbs', 'EEG 058': 'ecog', 'EEG 059': 'ecg', - 'EEG 060': 'eog', 'EOG 061': 'seeg', 'MEG 2441': 'eeg', - 'MEG 2443': 'eeg', 'MEG 2442': 'hbo', 'EEG 001': 'resp'} + mapping = { + "EEG 057": "dbs", + "EEG 058": "ecog", + "EEG 059": "ecg", + "EEG 060": "eog", + "EOG 061": "seeg", + "MEG 2441": "eeg", + "MEG 2443": "eeg", + "MEG 2442": "hbo", + "EEG 001": "resp", + } raw2 = read_raw_fif(raw_fname) - raw2.info['bads'] = ['EEG 059', 'EEG 060', 'EOG 061'] + raw2.info["bads"] = ["EEG 059", "EEG 060", "EOG 061"] with pytest.raises(RuntimeError, match='type .* in projector "PCA-v1"'): raw2.set_channel_types(mapping) # has prj raw2.add_proj([], remove_existing=True) # Should raise - with pytest.raises(ValueError, match='unit for channel.* has changed'): - raw2.copy().set_channel_types(mapping, on_unit_change='raise') + with pytest.raises(ValueError, match="unit for channel.* has changed"): + raw2.copy().set_channel_types(mapping, on_unit_change="raise") # Should warn - with pytest.warns(RuntimeWarning, match='unit for channel.* has changed'): + with pytest.warns(RuntimeWarning, match="unit for channel.* has changed"): raw2.copy().set_channel_types(mapping) # Shouldn't warn - raw2.set_channel_types(mapping, on_unit_change='ignore') + raw2.set_channel_types(mapping, on_unit_change="ignore") info = raw2.info - assert info['chs'][371]['ch_name'] == 'EEG 057' - assert info['chs'][371]['kind'] == FIFF.FIFFV_DBS_CH - assert info['chs'][371]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][371]['coil_type'] == FIFF.FIFFV_COIL_EEG - assert info['chs'][372]['ch_name'] == 'EEG 058' - assert info['chs'][372]['kind'] == FIFF.FIFFV_ECOG_CH - assert info['chs'][372]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][372]['coil_type'] == FIFF.FIFFV_COIL_EEG - assert info['chs'][373]['ch_name'] == 'EEG 059' - assert info['chs'][373]['kind'] == FIFF.FIFFV_ECG_CH - assert info['chs'][373]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][373]['coil_type'] == FIFF.FIFFV_COIL_NONE - assert info['chs'][374]['ch_name'] == 'EEG 060' - assert info['chs'][374]['kind'] == FIFF.FIFFV_EOG_CH - assert info['chs'][374]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][374]['coil_type'] == FIFF.FIFFV_COIL_NONE - assert info['chs'][375]['ch_name'] == 'EOG 061' - assert info['chs'][375]['kind'] == FIFF.FIFFV_SEEG_CH - assert info['chs'][375]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][375]['coil_type'] == FIFF.FIFFV_COIL_EEG - for idx in pick_channels(raw.ch_names, ['MEG 2441', 'MEG 2443'], - ordered=False): - assert info['chs'][idx]['kind'] == FIFF.FIFFV_EEG_CH - assert info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_EEG - idx = pick_channels(raw.ch_names, ['MEG 2442'])[0] - assert info['chs'][idx]['kind'] == FIFF.FIFFV_FNIRS_CH - assert info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_MOL - assert info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBO + assert info["chs"][371]["ch_name"] == "EEG 057" + assert info["chs"][371]["kind"] == FIFF.FIFFV_DBS_CH + assert info["chs"][371]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][371]["coil_type"] == FIFF.FIFFV_COIL_EEG + assert info["chs"][372]["ch_name"] == "EEG 058" + assert info["chs"][372]["kind"] == FIFF.FIFFV_ECOG_CH + assert info["chs"][372]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][372]["coil_type"] == FIFF.FIFFV_COIL_EEG + assert info["chs"][373]["ch_name"] == "EEG 059" + assert info["chs"][373]["kind"] == 
FIFF.FIFFV_ECG_CH + assert info["chs"][373]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][373]["coil_type"] == FIFF.FIFFV_COIL_NONE + assert info["chs"][374]["ch_name"] == "EEG 060" + assert info["chs"][374]["kind"] == FIFF.FIFFV_EOG_CH + assert info["chs"][374]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][374]["coil_type"] == FIFF.FIFFV_COIL_NONE + assert info["chs"][375]["ch_name"] == "EOG 061" + assert info["chs"][375]["kind"] == FIFF.FIFFV_SEEG_CH + assert info["chs"][375]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][375]["coil_type"] == FIFF.FIFFV_COIL_EEG + for idx in pick_channels(raw.ch_names, ["MEG 2441", "MEG 2443"], ordered=False): + assert info["chs"][idx]["kind"] == FIFF.FIFFV_EEG_CH + assert info["chs"][idx]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_EEG + idx = pick_channels(raw.ch_names, ["MEG 2442"])[0] + assert info["chs"][idx]["kind"] == FIFF.FIFFV_FNIRS_CH + assert info["chs"][idx]["unit"] == FIFF.FIFF_UNIT_MOL + assert info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_FNIRS_HBO # resp channel type - idx = pick_channels(raw.ch_names, ['EEG 001'])[0] - assert info['chs'][idx]['kind'] == FIFF.FIFFV_RESP_CH - assert info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_NONE + idx = pick_channels(raw.ch_names, ["EEG 001"])[0] + assert info["chs"][idx]["kind"] == FIFF.FIFFV_RESP_CH + assert info["chs"][idx]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_NONE # Test meaningful error when setting channel type with unknown unit - raw.info['chs'][0]['unit'] = 0. - ch_types = {raw.ch_names[0]: 'misc'} + raw.info["chs"][0]["unit"] = 0.0 + ch_types = {raw.ch_names[0]: "misc"} pytest.raises(ValueError, raw.set_channel_types, ch_types) @@ -208,17 +238,21 @@ def test_get_builtin_ch_adjacencies(): def test_read_ch_adjacency(tmp_path): """Test reading channel adjacency templates.""" - a = partial(np.array, dtype=' ps # are channels in the correct selection? @@ -405,9 +443,9 @@ def test_1020_selection(): def test_find_ch_adjacency(): """Test computing the adjacency matrix.""" raw = read_raw_fif(raw_fname, preload=True) - sizes = {'mag': 828, 'grad': 1700, 'eeg': 384} - nchans = {'mag': 102, 'grad': 204, 'eeg': 60} - for ch_type in ['mag', 'grad', 'eeg']: + sizes = {"mag": 828, "grad": 1700, "eeg": 384} + nchans = {"mag": 102, "grad": 204, "eeg": 60} + for ch_type in ["mag", "grad", "eeg"]: conn, ch_names = find_ch_adjacency(raw.info, ch_type) # Silly test for checking the number of neighbors. assert_equal(conn.getnnz(), sizes[ch_type]) @@ -415,30 +453,30 @@ def test_find_ch_adjacency(): pytest.raises(ValueError, find_ch_adjacency, raw.info, None) # Test computing the conn matrix with gradiometers. - conn, ch_names = _compute_ch_adjacency(raw.info, 'grad') + conn, ch_names = _compute_ch_adjacency(raw.info, "grad") assert_equal(conn.getnnz(), 2680) # Test ch_type=None. 
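    # (with a single remaining sensor type, it can be inferred automatically)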
- raw.pick_types(meg='mag') + raw.pick_types(meg="mag") find_ch_adjacency(raw.info, None) bti_fname = testing_path / "BTi" / "erm_HFH" / "c,rfDC" bti_config_name = testing_path / "BTi" / "erm_HFH" / "config" raw = read_raw_bti(bti_fname, bti_config_name, None) - _, ch_names = find_ch_adjacency(raw.info, 'mag') - assert 'A1' in ch_names + _, ch_names = find_ch_adjacency(raw.info, "mag") + assert "A1" in ch_names ctf_fname = testing_path / "CTF" / "testdata_ctf_short.ds" raw = read_raw_ctf(ctf_fname) - _, ch_names = find_ch_adjacency(raw.info, 'mag') - assert 'MLC11' in ch_names + _, ch_names = find_ch_adjacency(raw.info, "mag") + assert "MLC11" in ch_names - pytest.raises(ValueError, find_ch_adjacency, raw.info, 'eog') + pytest.raises(ValueError, find_ch_adjacency, raw.info, "eog") raw_kit = read_raw_kit(fname_kit_157) - neighb, ch_names = find_ch_adjacency(raw_kit.info, 'mag') + neighb, ch_names = find_ch_adjacency(raw_kit.info, "mag") assert neighb.data.size == 1329 - assert ch_names[0] == 'MEG 001' + assert ch_names[0] == "MEG 001" @testing.requires_testing_data @@ -446,7 +484,7 @@ def test_neuromag122_adjacency(): """Test computing the adjacency matrix of Neuromag122-Data.""" nm122_fname = testing_path / "misc" / "neuromag122_test_file-raw.fif" raw = read_raw_fif(nm122_fname, preload=True) - conn, ch_names = find_ch_adjacency(raw.info, 'grad') + conn, ch_names = find_ch_adjacency(raw.info, "grad") assert conn.getnnz() == 1564 assert len(ch_names) == 122 assert conn.shape == (122, 122) @@ -463,13 +501,13 @@ def test_drop_channels(): # by default, drop channels raises a ValueError if a channel can't be found m_chs = ["MEG 0111", "MEG blahblah"] - with pytest.raises(ValueError, match='not found, nothing dropped'): + with pytest.raises(ValueError, match="not found, nothing dropped"): raw.drop_channels(m_chs) # ...but this can be turned to a warning - with pytest.warns(RuntimeWarning, match='not found, nothing dropped'): - raw.drop_channels(m_chs, on_missing='warn') + with pytest.warns(RuntimeWarning, match="not found, nothing dropped"): + raw.drop_channels(m_chs, on_missing="warn") # ...or ignored altogether - raw.drop_channels(m_chs, on_missing='ignore') + raw.drop_channels(m_chs, on_missing="ignore") def test_pick_channels(): @@ -477,17 +515,17 @@ def test_pick_channels(): raw = read_raw_fif(raw_fname, preload=True).crop(0, 0.1) # selected correctly 3 channels - raw.pick(['MEG 0113', 'MEG 0112', 'MEG 0111']) + raw.pick(["MEG 0113", "MEG 0112", "MEG 0111"]) assert len(raw.ch_names) == 3 # selected correctly 3 channels and ignored 'meg', and emit warning - with pytest.raises(ValueError, match='not present in the info'): - raw.pick(['MEG 0113', "meg", 'MEG 0112', 'MEG 0111']) + with pytest.raises(ValueError, match="not present in the info"): + raw.pick(["MEG 0113", "meg", "MEG 0112", "MEG 0111"]) names_len = len(raw.ch_names) - raw.pick(['all']) # selected correctly all channels + raw.pick(["all"]) # selected correctly all channels assert len(raw.ch_names) == names_len - raw.pick('all') # selected correctly all channels + raw.pick("all") # selected correctly all channels assert len(raw.ch_names) == names_len @@ -502,16 +540,16 @@ def test_add_reference_channels(): n_evoked_original_channels = len(evoked.ch_names) # Raw object - raw.add_reference_channels(['REF 123']) + raw.add_reference_channels(["REF 123"]) assert len(raw.ch_names) == n_raw_original_channels + 1 assert np.all(raw.get_data()[-1] == 0) # Epochs object - epochs.add_reference_channels(['REF 123']) + 
epochs.add_reference_channels(["REF 123"]) assert epochs._data.shape[1] == epochs_original_shape + 1 # Evoked object - evoked.add_reference_channels(['REF 123']) + evoked.add_reference_channels(["REF 123"]) assert len(evoked.ch_names) == n_evoked_original_channels + 1 assert np.all(evoked._data[-1] == 0) @@ -521,30 +559,35 @@ def test_equalize_channels(): # This function only tests the generic functionality of equalize_channels. # Additional tests for each instance type are included in the accompanying # test suite for each type. - pytest.raises(TypeError, equalize_channels, ['foo', 'bar'], - match='Instances to be modified must be an instance of') + pytest.raises( + TypeError, + equalize_channels, + ["foo", "bar"], + match="Instances to be modified must be an instance of", + ) - raw = RawArray([[1.], [2.], [3.], [4.]], - create_info(['CH1', 'CH2', 'CH3', 'CH4'], sfreq=1.)) - epochs = EpochsArray([[[1.], [2.], [3.]]], - create_info(['CH5', 'CH2', 'CH1'], sfreq=1.)) - cov = make_ad_hoc_cov(create_info(['CH2', 'CH1', 'CH8'], sfreq=1., - ch_types='eeg')) - cov['bads'] = ['CH1'] - ave = EvokedArray([[1.], [2.]], create_info(['CH1', 'CH2'], sfreq=1.)) + raw = RawArray( + [[1.0], [2.0], [3.0], [4.0]], + create_info(["CH1", "CH2", "CH3", "CH4"], sfreq=1.0), + ) + epochs = EpochsArray( + [[[1.0], [2.0], [3.0]]], create_info(["CH5", "CH2", "CH1"], sfreq=1.0) + ) + cov = make_ad_hoc_cov(create_info(["CH2", "CH1", "CH8"], sfreq=1.0, ch_types="eeg")) + cov["bads"] = ["CH1"] + ave = EvokedArray([[1.0], [2.0]], create_info(["CH1", "CH2"], sfreq=1.0)) - raw2, epochs2, cov2, ave2 = equalize_channels([raw, epochs, cov, ave], - copy=True) + raw2, epochs2, cov2, ave2 = equalize_channels([raw, epochs, cov, ave], copy=True) # The Raw object was the first in the list, so should have been used as # template for the ordering of the channels. No bad channels should have # been dropped. 
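    # Only CH1 and CH2 are common to all four instances, ordered as in raw.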
- assert raw2.ch_names == ['CH1', 'CH2'] - assert_array_equal(raw2.get_data(), [[1.], [2.]]) - assert epochs2.ch_names == ['CH1', 'CH2'] - assert_array_equal(epochs2.get_data(), [[[3.], [2.]]]) - assert cov2.ch_names == ['CH1', 'CH2'] - assert cov2['bads'] == cov['bads'] + assert raw2.ch_names == ["CH1", "CH2"] + assert_array_equal(raw2.get_data(), [[1.0], [2.0]]) + assert epochs2.ch_names == ["CH1", "CH2"] + assert_array_equal(epochs2.get_data(), [[[3.0], [2.0]]]) + assert cov2.ch_names == ["CH1", "CH2"] + assert cov2["bads"] == cov["bads"] assert ave2.ch_names == ave.ch_names assert_array_equal(ave2.data, ave.data) @@ -565,7 +608,7 @@ def test_combine_channels(): """Test channel combination on Raw, Epochs, and Evoked.""" raw = read_raw_fif(raw_fname, preload=True) raw_ch_bad = read_raw_fif(raw_fname, preload=True) - raw_ch_bad.info['bads'] = ['MEG 0113', 'MEG 0112'] + raw_ch_bad.info["bads"] = ["MEG 0113", "MEG 0112"] epochs = Epochs(raw, read_events(eve_fname)) evoked = epochs.average() good = dict(foo=[0, 1, 3, 4], bar=[5, 2]) # good grad and mag @@ -583,35 +626,32 @@ def test_combine_channels(): # Test with stimulus channels combine_stim = combine_channels(raw, good, keep_stim=True) target_nchan = len(good) + len(pick_types(raw.info, meg=False, stim=True)) - assert combine_stim.info['nchan'] == target_nchan + assert combine_stim.info["nchan"] == target_nchan # Test results with one ROI good_single = dict(foo=[0, 1, 3, 4]) # good grad - combined_mean = combine_channels(raw, good_single, method='mean') - combined_median = combine_channels(raw, good_single, method='median') - combined_std = combine_channels(raw, good_single, method='std') - foo_mean = np.mean(raw.get_data()[good_single['foo']], axis=0) - foo_median = np.median(raw.get_data()[good_single['foo']], axis=0) - foo_std = np.std(raw.get_data()[good_single['foo']], axis=0) - assert_array_equal(combined_mean.get_data(), - np.expand_dims(foo_mean, axis=0)) - assert_array_equal(combined_median.get_data(), - np.expand_dims(foo_median, axis=0)) - assert_array_equal(combined_std.get_data(), - np.expand_dims(foo_std, axis=0)) + combined_mean = combine_channels(raw, good_single, method="mean") + combined_median = combine_channels(raw, good_single, method="median") + combined_std = combine_channels(raw, good_single, method="std") + foo_mean = np.mean(raw.get_data()[good_single["foo"]], axis=0) + foo_median = np.median(raw.get_data()[good_single["foo"]], axis=0) + foo_std = np.std(raw.get_data()[good_single["foo"]], axis=0) + assert_array_equal(combined_mean.get_data(), np.expand_dims(foo_mean, axis=0)) + assert_array_equal(combined_median.get_data(), np.expand_dims(foo_median, axis=0)) + assert_array_equal(combined_std.get_data(), np.expand_dims(foo_std, axis=0)) # Test bad cases bad1 = dict(foo=[0, 376], bar=[5, 2]) # out of bounds bad2 = dict(foo=[0, 2], bar=[5, 2]) # type mix in same group with pytest.raises(ValueError, match='"method" must be a callable, or'): - combine_channels(raw, good, method='bad_method') + combine_channels(raw, good, method="bad_method") with pytest.raises(TypeError, match='"keep_stim" must be of type bool'): - combine_channels(raw, good, keep_stim='bad_type') + combine_channels(raw, good, keep_stim="bad_type") with pytest.raises(TypeError, match='"drop_bad" must be of type bool'): - combine_channels(raw, good, drop_bad='bad_type') - with pytest.raises(ValueError, match='Some channel indices are out of'): + combine_channels(raw, good, drop_bad="bad_type") + with pytest.raises(ValueError, match="Some channel 
indices are out of"): combine_channels(raw, bad1) - with pytest.raises(ValueError, match='Cannot combine sensors of diff'): + with pytest.raises(ValueError, match="Cannot combine sensors of diff"): combine_channels(raw, bad2) # Test warnings @@ -620,9 +660,9 @@ def test_combine_channels(): warn1 = dict(foo=[375, 375], bar=[5, 2]) # same channel in same group warn2 = dict(foo=[375], bar=[5, 2]) # one channel (last channel) warn3 = dict(foo=[0, 4], bar=[5, 2]) # one good channel left - with pytest.warns(RuntimeWarning, match='Could not find stimulus'): + with pytest.warns(RuntimeWarning, match="Could not find stimulus"): combine_channels(raw_no_stim, good, keep_stim=True) - with pytest.warns(RuntimeWarning, match='Less than 2 channels') as record: + with pytest.warns(RuntimeWarning, match="Less than 2 channels") as record: combine_channels(raw, warn1) combine_channels(raw, warn2) combine_channels(raw_ch_bad, warn3, drop_bad=True) @@ -637,8 +677,7 @@ def test_combine_channels_metadata(): raw = read_raw_fif(raw_fname, preload=True) epochs = Epochs(raw, read_events(eve_fname), preload=True) - metadata = pd.DataFrame({"A": np.arange(len(epochs)), - "B": np.ones(len(epochs))}) + metadata = pd.DataFrame({"A": np.arange(len(epochs)), "B": np.ones(len(epochs))}) epochs.metadata = metadata good = dict(foo=[0, 1, 3, 4], bar=[5, 2]) # good grad and mag diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index f6c71d1ff00..2425db488eb 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -8,8 +8,11 @@ from mne import io, pick_types, pick_channels, read_events, Epochs from mne.channels.interpolation import _make_interpolation_matrix from mne.datasets import testing -from mne.preprocessing.nirs import (optical_density, scalp_coupling_index, - beer_lambert_law) +from mne.preprocessing.nirs import ( + optical_density, + scalp_coupling_index, + beer_lambert_law, +) from mne.io import read_raw_nirx from mne.io.proj import _has_eeg_average_ref_proj from mne.utils import _record_warnings, requires_version @@ -30,95 +33,118 @@ def _load_data(kind): raw = io.read_raw_fif(raw_fname) events = read_events(event_name) # subselect channels for speed - if kind == 'eeg': + if kind == "eeg": picks = pick_types(raw.info, meg=False, eeg=True, exclude=[])[:15] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - preload=True, reject=dict(eeg=80e-6)) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + preload=True, + reject=dict(eeg=80e-6), + ) else: picks = pick_types(raw.info, meg=True, eeg=False, exclude=[])[1:200:2] - assert kind == 'meg' - with pytest.warns(RuntimeWarning, match='projection'): - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - preload=True, - reject=dict(grad=1000e-12, mag=4e-12)) + assert kind == "meg" + with pytest.warns(RuntimeWarning, match="projection"): + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + preload=True, + reject=dict(grad=1000e-12, mag=4e-12), + ) return raw, epochs -@pytest.mark.parametrize('offset', (0., 0.1)) -@pytest.mark.parametrize('avg_proj, ctol', [ - (True, (0.86, 0.93)), - (False, (0.97, 0.99)), -]) -@pytest.mark.parametrize('method, atol', [ - pytest.param(None, 3e-6, marks=pytest.mark.slowtest), # slow on Azure - (dict(eeg='MNE'), 4e-6), -]) -@pytest.mark.filterwarnings('ignore:.*than 20 mm from head frame origin.*') +@pytest.mark.parametrize("offset", (0.0, 0.1)) 
+@pytest.mark.parametrize( + "avg_proj, ctol", + [ + (True, (0.86, 0.93)), + (False, (0.97, 0.99)), + ], +) +@pytest.mark.parametrize( + "method, atol", + [ + pytest.param(None, 3e-6, marks=pytest.mark.slowtest), # slow on Azure + (dict(eeg="MNE"), 4e-6), + ], +) +@pytest.mark.filterwarnings("ignore:.*than 20 mm from head frame origin.*") def test_interpolation_eeg(offset, avg_proj, ctol, atol, method): """Test interpolation of EEG channels.""" - raw, epochs_eeg = _load_data('eeg') + raw, epochs_eeg = _load_data("eeg") epochs_eeg = epochs_eeg.copy() assert not _has_eeg_average_ref_proj(epochs_eeg.info) # Offsetting the coordinate frame should have no effect on the output for inst in (raw, epochs_eeg): - for ch in inst.info['chs']: - if ch['kind'] == io.constants.FIFF.FIFFV_EEG_CH: - ch['loc'][:3] += offset - ch['loc'][3:6] += offset - for d in inst.info['dig']: - d['r'] += offset + for ch in inst.info["chs"]: + if ch["kind"] == io.constants.FIFF.FIFFV_EEG_CH: + ch["loc"][:3] += offset + ch["loc"][3:6] += offset + for d in inst.info["dig"]: + d["r"] += offset # check that interpolation does nothing if no bads are marked - epochs_eeg.info['bads'] = [] + epochs_eeg.info["bads"] = [] evoked_eeg = epochs_eeg.average() kw = dict(method=method) - with pytest.warns(RuntimeWarning, match='Doing nothing'): + with pytest.warns(RuntimeWarning, match="Doing nothing"): evoked_eeg.interpolate_bads(**kw) # create good and bad channels for EEG - epochs_eeg.info['bads'] = [] + epochs_eeg.info["bads"] = [] goods_idx = np.ones(len(epochs_eeg.ch_names), dtype=bool) - goods_idx[epochs_eeg.ch_names.index('EEG 012')] = False + goods_idx[epochs_eeg.ch_names.index("EEG 012")] = False bads_idx = ~goods_idx pos = epochs_eeg._get_channel_positions() evoked_eeg = epochs_eeg.average() if avg_proj: evoked_eeg.set_eeg_reference(projection=True).apply_proj() - assert_allclose(evoked_eeg.data.mean(0), 0., atol=1e-20) + assert_allclose(evoked_eeg.data.mean(0), 0.0, atol=1e-20) ave_before = evoked_eeg.data[bads_idx] # interpolate bad channels for EEG - epochs_eeg.info['bads'] = ['EEG 012'] + epochs_eeg.info["bads"] = ["EEG 012"] evoked_eeg = epochs_eeg.average() if avg_proj: evoked_eeg.set_eeg_reference(projection=True).apply_proj() good_picks = pick_types(evoked_eeg.info, meg=False, eeg=True) - assert_allclose(evoked_eeg.data[good_picks].mean(0), 0., atol=1e-20) + assert_allclose(evoked_eeg.data[good_picks].mean(0), 0.0, atol=1e-20) evoked_eeg_bad = evoked_eeg.copy() bads_picks = pick_channels( - epochs_eeg.ch_names, include=epochs_eeg.info['bads'], ordered=True + epochs_eeg.ch_names, include=epochs_eeg.info["bads"], ordered=True ) evoked_eeg_bad.data[bads_picks, :] = 1e10 # Test first the exclude parameter evoked_eeg_2_bads = evoked_eeg_bad.copy() - evoked_eeg_2_bads.info['bads'] = ['EEG 004', 'EEG 012'] + evoked_eeg_2_bads.info["bads"] = ["EEG 004", "EEG 012"] evoked_eeg_2_bads.data[ - pick_channels(evoked_eeg_bad.ch_names, ['EEG 004', 'EEG 012']) + pick_channels(evoked_eeg_bad.ch_names, ["EEG 004", "EEG 012"]) ] = 1e10 evoked_eeg_interp = evoked_eeg_2_bads.interpolate_bads( - origin=(0., 0., 0.), exclude=['EEG 004'], **kw) - assert evoked_eeg_interp.info['bads'] == ['EEG 004'] - assert np.all(evoked_eeg_interp.get_data('EEG 004') == 1e10) - assert np.all(evoked_eeg_interp.get_data('EEG 012') != 1e10) + origin=(0.0, 0.0, 0.0), exclude=["EEG 004"], **kw + ) + assert evoked_eeg_interp.info["bads"] == ["EEG 004"] + assert np.all(evoked_eeg_interp.get_data("EEG 004") == 1e10) + assert np.all(evoked_eeg_interp.get_data("EEG 
012") != 1e10) # Now test without exclude parameter - evoked_eeg_bad.info['bads'] = ['EEG 012'] + evoked_eeg_bad.info["bads"] = ["EEG 012"] evoked_eeg_interp = evoked_eeg_bad.copy().interpolate_bads( - origin=(0., 0., 0.), **kw) + origin=(0.0, 0.0, 0.0), **kw + ) if avg_proj: - assert_allclose(evoked_eeg_interp.data.mean(0), 0., atol=1e-6) + assert_allclose(evoked_eeg_interp.data.mean(0), 0.0, atol=1e-6) interp_zero = evoked_eeg_interp.data[bads_idx] if method is None: # using pos_good = pos[goods_idx] @@ -136,7 +162,7 @@ def test_interpolation_eeg(offset, avg_proj, ctol, atol, method): # check that interpolation fails when preload is False epochs_eeg.preload = False - with pytest.raises(RuntimeError, match='requires epochs data to be loade'): + with pytest.raises(RuntimeError, match="requires epochs data to be loade"): epochs_eeg.interpolate_bads(**kw) epochs_eeg.preload = True @@ -148,10 +174,10 @@ def test_interpolation_eeg(offset, avg_proj, ctol, atol, method): # check that interpolation fails when preload is False for inst in [raw, epochs_eeg]: - assert hasattr(inst, 'preload') + assert hasattr(inst, "preload") inst.preload = False - inst.info['bads'] = [inst.ch_names[1]] - with pytest.raises(RuntimeError, match='requires.*data to be loaded'): + inst.info["bads"] = [inst.ch_names[1]] + with pytest.raises(RuntimeError, match="requires.*data to be loaded"): inst.interpolate_bads(**kw) # check that interpolation works with few channels @@ -159,11 +185,11 @@ def test_interpolation_eeg(offset, avg_proj, ctol, atol, method): raw_few.pick_channels(raw_few.ch_names[:1] + raw_few.ch_names[3:4]) assert len(raw_few.ch_names) == 2 raw_few.del_proj() - raw_few.info['bads'] = [raw_few.ch_names[-1]] + raw_few.info["bads"] = [raw_few.ch_names[-1]] orig_data = raw_few[1][0] with _record_warnings() as w: raw_few.interpolate_bads(reset_bads=False, **kw) - assert len([ww for ww in w if 'more than' not in str(ww.message)]) == 0 + assert len([ww for ww in w if "more than" not in str(ww.message)]) == 0 new_data = raw_few[1][0] assert (new_data == 0).mean() < 0.5 assert np.corrcoef(new_data, orig_data)[0, 1] > 0.2 @@ -176,82 +202,80 @@ def test_interpolation_meg(): # correlation drops thresh = 0.68 - raw, epochs_meg = _load_data('meg') + raw, epochs_meg = _load_data("meg") # check that interpolation works when non M/EEG channels are present # before MEG channels raw.crop(0, 0.1).load_data().pick_channels(epochs_meg.ch_names) raw.info.normalize_proj() - raw.set_channel_types({raw.ch_names[0]: 'stim'}, on_unit_change='ignore') - raw.info['bads'] = [raw.ch_names[1]] + raw.set_channel_types({raw.ch_names[0]: "stim"}, on_unit_change="ignore") + raw.info["bads"] = [raw.ch_names[1]] raw.load_data() - raw.interpolate_bads(mode='fast') + raw.interpolate_bads(mode="fast") del raw # check that interpolation works for MEG - epochs_meg.info['bads'] = ['MEG 0141'] + epochs_meg.info["bads"] = ["MEG 0141"] evoked = epochs_meg.average() - pick = pick_channels(epochs_meg.info['ch_names'], epochs_meg.info['bads']) + pick = pick_channels(epochs_meg.info["ch_names"], epochs_meg.info["bads"]) # MEG -- raw raw_meg = io.RawArray(data=epochs_meg._data[0], info=epochs_meg.info) - raw_meg.info['bads'] = ['MEG 0141'] + raw_meg.info["bads"] = ["MEG 0141"] data1 = raw_meg[pick, :][0][0] raw_meg.info.normalize_proj() - data2 = raw_meg.interpolate_bads(reset_bads=False, - mode='fast')[pick, :][0][0] + data2 = raw_meg.interpolate_bads(reset_bads=False, mode="fast")[pick, :][0][0] assert np.corrcoef(data1, data2)[0, 1] > thresh # the same 
number of bads as before - assert len(raw_meg.info['bads']) == len(raw_meg.info['bads']) + assert len(raw_meg.info["bads"]) == len(raw_meg.info["bads"]) # MEG -- epochs data1 = epochs_meg.get_data()[:, pick, :].ravel() epochs_meg.info.normalize_proj() - epochs_meg.interpolate_bads(mode='fast') + epochs_meg.interpolate_bads(mode="fast") data2 = epochs_meg.get_data()[:, pick, :].ravel() assert np.corrcoef(data1, data2)[0, 1] > thresh - assert len(epochs_meg.info['bads']) == 0 + assert len(epochs_meg.info["bads"]) == 0 # MEG -- evoked (plus auto origin) data1 = evoked.data[pick] evoked.info.normalize_proj() - data2 = evoked.interpolate_bads(origin='auto').data[pick] + data2 = evoked.interpolate_bads(origin="auto").data[pick] assert np.corrcoef(data1, data2)[0, 1] > thresh # MEG -- with exclude - evoked.info['bads'] = ['MEG 0141', 'MEG 0121'] - pick = pick_channels(evoked.ch_names, evoked.info['bads'], ordered=True) + evoked.info["bads"] = ["MEG 0141", "MEG 0121"] + pick = pick_channels(evoked.ch_names, evoked.info["bads"], ordered=True) evoked.data[pick[-1]] = 1e10 data1 = evoked.data[pick] evoked.info.normalize_proj() - data2 = evoked.interpolate_bads( - origin='auto', exclude=['MEG 0121'] - ).data[pick] + data2 = evoked.interpolate_bads(origin="auto", exclude=["MEG 0121"]).data[pick] assert np.corrcoef(data1[0], data2[0])[0, 1] > thresh assert np.all(data2[1] == 1e10) def _this_interpol(inst, ref_meg=False): from mne.channels.interpolation import _interpolate_bads_meg - _interpolate_bads_meg(inst, ref_meg=ref_meg, mode='fast') + + _interpolate_bads_meg(inst, ref_meg=ref_meg, mode="fast") return inst @pytest.mark.slowtest def test_interpolate_meg_ctf(): """Test interpolation of MEG channels from CTF system.""" - thresh = .85 - tol = .05 # assert the new interpol correlates at least .05 "better" - bad = 'MLC22-2622' # select a good channel to test the interpolation + thresh = 0.85 + tol = 0.05 # assert the new interpol correlates at least .05 "better" + bad = "MLC22-2622" # select a good channel to test the interpolation raw = io.read_raw_fif(raw_fname_ctf).crop(0, 1.0).load_data() # 3 secs raw.apply_gradient_compensation(3) # Show that we have to exclude ref_meg for interpolating CTF MEG-channels # (fixed in #5965): - raw.info['bads'] = [bad] - pick_bad = pick_channels(raw.info['ch_names'], raw.info['bads']) + raw.info["bads"] = [bad] + pick_bad = pick_channels(raw.info["ch_names"], raw.info["bads"]) data_orig = raw[pick_bad, :][0] # mimic old behavior (the ref_meg-arg in _interpolate_bads_meg only serves # this purpose): @@ -260,12 +284,12 @@ def test_interpolate_meg_ctf(): data_interp_no_refmeg = _this_interpol(raw, ref_meg=False)[pick_bad, :][0] R = dict() - R['no_refmeg'] = np.corrcoef(data_orig, data_interp_no_refmeg)[0, 1] - R['with_refmeg'] = np.corrcoef(data_orig, data_interp_refmeg)[0, 1] + R["no_refmeg"] = np.corrcoef(data_orig, data_interp_no_refmeg)[0, 1] + R["with_refmeg"] = np.corrcoef(data_orig, data_interp_refmeg)[0, 1] - print('Corrcoef of interpolated with original channel: ', R) - assert R['no_refmeg'] > R['with_refmeg'] + tol - assert R['no_refmeg'] > thresh + print("Corrcoef of interpolated with original channel: ", R) + assert R["no_refmeg"] > R["with_refmeg"] + tol + assert R["no_refmeg"] > thresh @testing.requires_testing_data @@ -273,33 +297,30 @@ def test_interpolation_ctf_comp(): """Test interpolation with compensated CTF data.""" raw_fname = testing_path / "CTF" / "somMDYO-18av.ds" raw = io.read_raw_ctf(raw_fname, preload=True) - raw.info['bads'] = 
[raw.ch_names[5], raw.ch_names[-5]] - raw.interpolate_bads(mode='fast', origin=(0., 0., 0.04)) - assert raw.info['bads'] == [] + raw.info["bads"] = [raw.ch_names[5], raw.ch_names[-5]] + raw.interpolate_bads(mode="fast", origin=(0.0, 0.0, 0.04)) + assert raw.info["bads"] == [] -@requires_version('pymatreader') +@requires_version("pymatreader") @testing.requires_testing_data def test_interpolation_nirs(): """Test interpolating bad nirs channels.""" - fname = ( - testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_overlap" - ) + fname = testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_overlap" raw_intensity = read_raw_nirx(fname, preload=False) raw_od = optical_density(raw_intensity) sci = scalp_coupling_index(raw_od) - raw_od.info['bads'] = list(compress(raw_od.ch_names, sci < 0.5)) - bad_0 = np.where([name == raw_od.info['bads'][0] for - name in raw_od.ch_names])[0][0] + raw_od.info["bads"] = list(compress(raw_od.ch_names, sci < 0.5)) + bad_0 = np.where([name == raw_od.info["bads"][0] for name in raw_od.ch_names])[0][0] bad_0_std_pre_interp = np.std(raw_od._data[bad_0]) - bads_init = list(raw_od.info['bads']) + bads_init = list(raw_od.info["bads"]) raw_od.interpolate_bads(exclude=bads_init[:2]) - assert raw_od.info['bads'] == bads_init[:2] + assert raw_od.info["bads"] == bads_init[:2] raw_od.interpolate_bads() - assert raw_od.info['bads'] == [] + assert raw_od.info["bads"] == [] assert bad_0_std_pre_interp > np.std(raw_od._data[bad_0]) raw_haemo = beer_lambert_law(raw_od, ppf=6) - raw_haemo.info['bads'] = raw_haemo.ch_names[2:4] - assert raw_haemo.info['bads'] == ['S1_D2 hbo', 'S1_D2 hbr'] + raw_haemo.info["bads"] = raw_haemo.ch_names[2:4] + assert raw_haemo.info["bads"] == ["S1_D2 hbo", "S1_D2 hbr"] raw_haemo.interpolate_bads() - assert raw_haemo.info['bads'] == [] + assert raw_haemo.info["bads"] == [] diff --git a/mne/channels/tests/test_layout.py b/mne/channels/tests/test_layout.py index e17f90cafaf..2362fb2e23b 100644 --- a/mne/channels/tests/test_layout.py +++ b/mne/channels/tests/test_layout.py @@ -9,15 +9,23 @@ from pathlib import Path import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_allclose, assert_equal) +from numpy.testing import ( + assert_array_almost_equal, + assert_array_equal, + assert_allclose, + assert_equal, +) import pytest import matplotlib.pyplot as plt -from mne.channels import (make_eeg_layout, make_grid_layout, read_layout, - find_layout, HEAD_SIZE_DEFAULT) -from mne.channels.layout import (_box_size, _find_topomap_coords, - generate_2d_layout) +from mne.channels import ( + make_eeg_layout, + make_grid_layout, + read_layout, + find_layout, + HEAD_SIZE_DEFAULT, +) +from mne.channels.layout import _box_size, _find_topomap_coords, generate_2d_layout from mne import pick_types, pick_info from mne.io import read_raw_kit, _empty_info, read_info from mne.io.constants import FIFF @@ -34,18 +42,50 @@ def _get_test_info(): """Make test info.""" test_info = _empty_info(1000) - loc = np.array([0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.], - dtype=np.float32) - test_info['chs'] = [ - {'cal': 1, 'ch_name': 'ICA 001', 'coil_type': 0, 'coord_frame': 0, - 'kind': 502, 'loc': loc.copy(), 'logno': 1, 'range': 1.0, 'scanno': 1, - 'unit': -1, 'unit_mul': 0}, - {'cal': 1, 'ch_name': 'ICA 002', 'coil_type': 0, 'coord_frame': 0, - 'kind': 502, 'loc': loc.copy(), 'logno': 2, 'range': 1.0, 'scanno': 2, - 'unit': -1, 'unit_mul': 0}, - {'cal': 0.002142000012099743, 'ch_name': 'EOG 061', 'coil_type': 1, - 'coord_frame': 
0, 'kind': 202, 'loc': loc.copy(), 'logno': 61, - 'range': 1.0, 'scanno': 376, 'unit': 107, 'unit_mul': 0}] + loc = np.array( + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], dtype=np.float32 + ) + test_info["chs"] = [ + { + "cal": 1, + "ch_name": "ICA 001", + "coil_type": 0, + "coord_frame": 0, + "kind": 502, + "loc": loc.copy(), + "logno": 1, + "range": 1.0, + "scanno": 1, + "unit": -1, + "unit_mul": 0, + }, + { + "cal": 1, + "ch_name": "ICA 002", + "coil_type": 0, + "coord_frame": 0, + "kind": 502, + "loc": loc.copy(), + "logno": 2, + "range": 1.0, + "scanno": 2, + "unit": -1, + "unit_mul": 0, + }, + { + "cal": 0.002142000012099743, + "ch_name": "EOG 061", + "coil_type": 1, + "coord_frame": 0, + "kind": 202, + "loc": loc.copy(), + "logno": 61, + "range": 1.0, + "scanno": 376, + "unit": 107, + "unit_mul": 0, + }, + ] test_info._unlocked = False test_info._update_redundant() test_info._check_consistency() @@ -57,7 +97,8 @@ def test_io_layout_lout(tmp_path): layout = read_layout(fname="Vectorview-all", scale=False) layout.save(tmp_path / "foobar.lout", overwrite=True) layout_read = read_layout( - fname=tmp_path / "foobar.lout", scale=False, + fname=tmp_path / "foobar.lout", + scale=False, ) assert_array_almost_equal(layout.pos, layout_read.pos, decimal=2) assert layout.names == layout_read.names @@ -66,15 +107,17 @@ def test_io_layout_lout(tmp_path): # deprecation with pytest.warns(DeprecationWarning, match="should not be provided"): layout_read = read_layout( - fname=tmp_path / "foobar.lout", kind="Vectorview-all", scale=False, + fname=tmp_path / "foobar.lout", + kind="Vectorview-all", + scale=False, ) with pytest.warns(DeprecationWarning, match="should not be provided"): layout_read = read_layout( - fname=tmp_path / "foobar.lout", path=None, scale=False, + fname=tmp_path / "foobar.lout", + path=None, + scale=False, ) - with pytest.warns( - DeprecationWarning, match="'kind' and 'path' are deprecated" - ): + with pytest.warns(DeprecationWarning, match="'kind' and 'path' are deprecated"): layout_read = read_layout(kind="Vectorview-all", scale=False) @@ -82,9 +125,7 @@ def test_io_layout_lay(tmp_path): """Test IO with .lay files.""" layout = read_layout(fname="CTF151", scale=False) layout.save(str(tmp_path / "foobar.lay")) - layout_read = read_layout( - fname=tmp_path / "foobar.lay", scale=False - ) + layout_read = read_layout(fname=tmp_path / "foobar.lay", scale=False) assert_array_almost_equal(layout.pos, layout_read.pos, decimal=2) assert layout.names == layout_read.names @@ -96,60 +137,57 @@ def test_find_topomap_coords(): # Remove extra digitization point, so EEG digitization points match up # with the EEG channels - del info['dig'][85] + del info["dig"][85] # Use channel locations - kwargs = dict(ignore_overlap=False, to_sphere=True, - sphere=HEAD_SIZE_DEFAULT) + kwargs = dict(ignore_overlap=False, to_sphere=True, sphere=HEAD_SIZE_DEFAULT) l0 = _find_topomap_coords(info, picks, **kwargs) # Remove electrode position information, use digitization points from now # on. 
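    # A condensed sketch of the fallback this hunk exercises (assuming an
    # ``info`` that still carries the EEG points in ``info["dig"]``): once
    # every channel ``loc`` is NaN, ``_find_topomap_coords`` has to derive
    # the 2D positions from the digitization alone, and the result should
    # match the location-based coordinates computed just above:
    #
    #     l0 = _find_topomap_coords(info, picks, **kwargs)  # from ch locs
    #     for ch in info["chs"]:
    #         ch["loc"].fill(np.nan)  # wipe locs -> fall back on dig points
    #     l1 = _find_topomap_coords(info, picks, **kwargs)  # from dig
    #     assert_allclose(l1, l0, atol=1e-3)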
- for ch in info['chs']: - ch['loc'].fill(np.nan) + for ch in info["chs"]: + ch["loc"].fill(np.nan) l1 = _find_topomap_coords(info, picks, **kwargs) assert_allclose(l1, l0, atol=1e-3) - for z_pt in ((HEAD_SIZE_DEFAULT, 0., 0.), - (0., HEAD_SIZE_DEFAULT, 0.)): - info['dig'][-1]['r'] = np.array(z_pt) + for z_pt in ((HEAD_SIZE_DEFAULT, 0.0, 0.0), (0.0, HEAD_SIZE_DEFAULT, 0.0)): + info["dig"][-1]["r"] = np.array(z_pt) l1 = _find_topomap_coords(info, picks, **kwargs) - assert_allclose(l1[-1], z_pt[:2], err_msg='Z=0 point moved', atol=1e-6) + assert_allclose(l1[-1], z_pt[:2], err_msg="Z=0 point moved", atol=1e-6) # Test plotting mag topomap without channel locations: it should fail - mag_picks = pick_types(info, meg='mag') - with pytest.raises(ValueError, match='Cannot determine location'): + mag_picks = pick_types(info, meg="mag") + with pytest.raises(ValueError, match="Cannot determine location"): _find_topomap_coords(info, mag_picks, **kwargs) # Test function with too many EEG digitization points: it should fail - info['dig'].append({'r': [1, 2, 3], 'kind': FIFF.FIFFV_POINT_EEG}) - with pytest.raises(ValueError, match='Number of EEG digitization points'): + info["dig"].append({"r": [1, 2, 3], "kind": FIFF.FIFFV_POINT_EEG}) + with pytest.raises(ValueError, match="Number of EEG digitization points"): _find_topomap_coords(info, picks, **kwargs) # Test function with too little EEG digitization points: it should fail info._unlocked = True - info['dig'] = info['dig'][:-2] - with pytest.raises(ValueError, match='Number of EEG digitization points'): + info["dig"] = info["dig"][:-2] + with pytest.raises(ValueError, match="Number of EEG digitization points"): _find_topomap_coords(info, picks, **kwargs) # Electrode positions must be unique - info['dig'].append(info['dig'][-1]) - with pytest.raises(ValueError, match='overlapping positions'): + info["dig"].append(info["dig"][-1]) + with pytest.raises(ValueError, match="overlapping positions"): _find_topomap_coords(info, picks, **kwargs) # Test function without EEG digitization points: it should fail - info['dig'] = [d for d in info['dig'] - if d['kind'] != FIFF.FIFFV_POINT_EEG] - with pytest.raises(RuntimeError, match='Did not find any digitization'): + info["dig"] = [d for d in info["dig"] if d["kind"] != FIFF.FIFFV_POINT_EEG] + with pytest.raises(RuntimeError, match="Did not find any digitization"): _find_topomap_coords(info, picks, **kwargs) # Test function without any digitization points, it should fail - info['dig'] = None - with pytest.raises(RuntimeError, match='No digitization points found'): + info["dig"] = None + with pytest.raises(RuntimeError, match="No digitization points found"): _find_topomap_coords(info, picks, **kwargs) - info['dig'] = [] - with pytest.raises(RuntimeError, match='No digitization points found'): + info["dig"] = [] + with pytest.raises(RuntimeError, match="No digitization points found"): _find_topomap_coords(info, picks, **kwargs) @@ -200,92 +238,92 @@ def test_make_grid_layout(tmp_path): def test_find_layout(): """Test finding layout.""" - pytest.raises(ValueError, find_layout, _get_test_info(), ch_type='meep') + pytest.raises(ValueError, find_layout, _get_test_info(), ch_type="meep") sample_info = read_info(fif_fname) - grads = pick_types(sample_info, meg='grad') + grads = pick_types(sample_info, meg="grad") sample_info2 = pick_info(sample_info, grads) - mags = pick_types(sample_info, meg='mag') + mags = pick_types(sample_info, meg="mag") sample_info3 = pick_info(sample_info, mags) # mock new convention sample_info4 = 
copy.deepcopy(sample_info) - for ii, name in enumerate(sample_info4['ch_names']): - new = name.replace(' ', '') - sample_info4['chs'][ii]['ch_name'] = new + for ii, name in enumerate(sample_info4["ch_names"]): + new = name.replace(" ", "") + sample_info4["chs"][ii]["ch_name"] = new eegs = pick_types(sample_info, meg=False, eeg=True) sample_info5 = pick_info(sample_info, eegs) lout = find_layout(sample_info, ch_type=None) - assert lout.kind == 'Vectorview-all' - assert all(' ' in k for k in lout.names) + assert lout.kind == "Vectorview-all" + assert all(" " in k for k in lout.names) - lout = find_layout(sample_info2, ch_type='meg') - assert_equal(lout.kind, 'Vectorview-all') + lout = find_layout(sample_info2, ch_type="meg") + assert_equal(lout.kind, "Vectorview-all") # test new vector-view lout = find_layout(sample_info4, ch_type=None) - assert_equal(lout.kind, 'Vectorview-all') - assert all(' ' not in k for k in lout.names) + assert_equal(lout.kind, "Vectorview-all") + assert all(" " not in k for k in lout.names) - lout = find_layout(sample_info, ch_type='grad') - assert_equal(lout.kind, 'Vectorview-grad') + lout = find_layout(sample_info, ch_type="grad") + assert_equal(lout.kind, "Vectorview-grad") lout = find_layout(sample_info2) - assert_equal(lout.kind, 'Vectorview-grad') - lout = find_layout(sample_info2, ch_type='grad') - assert_equal(lout.kind, 'Vectorview-grad') - lout = find_layout(sample_info2, ch_type='meg') - assert_equal(lout.kind, 'Vectorview-all') - - lout = find_layout(sample_info, ch_type='mag') - assert_equal(lout.kind, 'Vectorview-mag') + assert_equal(lout.kind, "Vectorview-grad") + lout = find_layout(sample_info2, ch_type="grad") + assert_equal(lout.kind, "Vectorview-grad") + lout = find_layout(sample_info2, ch_type="meg") + assert_equal(lout.kind, "Vectorview-all") + + lout = find_layout(sample_info, ch_type="mag") + assert_equal(lout.kind, "Vectorview-mag") lout = find_layout(sample_info3) - assert_equal(lout.kind, 'Vectorview-mag') - lout = find_layout(sample_info3, ch_type='mag') - assert_equal(lout.kind, 'Vectorview-mag') - lout = find_layout(sample_info3, ch_type='meg') - assert_equal(lout.kind, 'Vectorview-all') - - lout = find_layout(sample_info, ch_type='eeg') - assert_equal(lout.kind, 'EEG') + assert_equal(lout.kind, "Vectorview-mag") + lout = find_layout(sample_info3, ch_type="mag") + assert_equal(lout.kind, "Vectorview-mag") + lout = find_layout(sample_info3, ch_type="meg") + assert_equal(lout.kind, "Vectorview-all") + + lout = find_layout(sample_info, ch_type="eeg") + assert_equal(lout.kind, "EEG") lout = find_layout(sample_info5) - assert_equal(lout.kind, 'EEG') - lout = find_layout(sample_info5, ch_type='eeg') - assert_equal(lout.kind, 'EEG') + assert_equal(lout.kind, "EEG") + lout = find_layout(sample_info5, ch_type="eeg") + assert_equal(lout.kind, "EEG") # no common layout, 'meg' option not supported lout = find_layout(read_info(fname_ctf_raw)) - assert_equal(lout.kind, 'CTF-275') + assert_equal(lout.kind, "CTF-275") fname_bti_raw = bti_dir / "exported4D_linux_raw.fif" lout = find_layout(read_info(fname_bti_raw)) - assert_equal(lout.kind, 'magnesWH3600') + assert_equal(lout.kind, "magnesWH3600") raw_kit = read_raw_kit(fname_kit_157) lout = find_layout(raw_kit.info) - assert_equal(lout.kind, 'KIT-157') + assert_equal(lout.kind, "KIT-157") - raw_kit.info['bads'] = ['MEG 013', 'MEG 014', 'MEG 015', 'MEG 016'] + raw_kit.info["bads"] = ["MEG 013", "MEG 014", "MEG 015", "MEG 016"] raw_kit.info._check_consistency() lout = find_layout(raw_kit.info) - 
assert_equal(lout.kind, 'KIT-157') + assert_equal(lout.kind, "KIT-157") # fallback for missing IDs for val in (35, 52, 54, 1001): with raw_kit.info._unlock(): - raw_kit.info['kit_system_id'] = val + raw_kit.info["kit_system_id"] = val lout = find_layout(raw_kit.info) - assert lout.kind == 'custom' + assert lout.kind == "custom" raw_umd = read_raw_kit(fname_kit_umd) lout = find_layout(raw_umd.info) - assert_equal(lout.kind, 'KIT-UMD-3') + assert_equal(lout.kind, "KIT-UMD-3") # Test plotting lout.plot() lout.plot(picks=np.arange(10)) - plt.close('all') + plt.close("all") def test_box_size(): @@ -357,7 +395,7 @@ def test_generate_2d_layout(): sbg = 15 side = range(snobg) bg_image = np.random.RandomState(42).randn(sbg, sbg) - w, h = [.2, .5] + w, h = [0.2, 0.5] # Generate fake data xy = np.array([(i, j) for i in side for j in side]) @@ -367,9 +405,10 @@ def test_generate_2d_layout(): comp_1, comp_2 = [(5, 0), (7, 0)] assert lt.pos[:, :2].max() == 1 assert lt.pos[:, :2].min() == 0 - with np.errstate(invalid='ignore'): # divide by zero - assert_allclose(xy[comp_2] / float(xy[comp_1]), - lt.pos[comp_2] / float(lt.pos[comp_1])) + with np.errstate(invalid="ignore"): # divide by zero + assert_allclose( + xy[comp_2] / float(xy[comp_1]), lt.pos[comp_2] / float(lt.pos[comp_1]) + ) assert_allclose(lt.pos[0, [2, 3]], [w, h]) # Correct number elements diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index f78e6bb3f2d..ca7347b21ad 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -14,37 +14,60 @@ from functools import partial from string import ascii_lowercase -from numpy.testing import (assert_array_equal, assert_array_less, - assert_allclose, assert_equal) +from numpy.testing import ( + assert_array_equal, + assert_array_less, + assert_allclose, + assert_equal, +) import matplotlib.pyplot as plt from mne import __file__ as _mne_file, create_info, read_evokeds, pick_types from mne.source_space import get_mni_fiducials from mne.utils._testing import assert_object_equal -from mne.channels import (get_builtin_montages, DigMontage, read_dig_dat, - read_dig_egi, read_dig_captrak, read_dig_fif, - make_standard_montage, read_custom_montage, - compute_dev_head_t, make_dig_montage, - read_dig_polhemus_isotrak, compute_native_head_t, - read_polhemus_fastscan, read_dig_localite, - read_dig_hpts) +from mne.channels import ( + get_builtin_montages, + DigMontage, + read_dig_dat, + read_dig_egi, + read_dig_captrak, + read_dig_fif, + make_standard_montage, + read_custom_montage, + compute_dev_head_t, + make_dig_montage, + read_dig_polhemus_isotrak, + compute_native_head_t, + read_polhemus_fastscan, + read_dig_localite, + read_dig_hpts, +) from mne.channels.montage import ( - transform_to_head, _check_get_coord_frame, _BUILTIN_STANDARD_MONTAGES + transform_to_head, + _check_get_coord_frame, + _BUILTIN_STANDARD_MONTAGES, ) from mne.preprocessing import compute_current_source_density from mne.utils import assert_dig_allclose, _record_warnings from mne.bem import _fit_sphere from mne.io.constants import FIFF -from mne.io._digitization import (_format_dig_points, - _get_fid_coords, _get_dig_eeg, - _count_points_by_type) -from mne.transforms import (_ensure_trans, apply_trans, invert_transform, - _get_trans) +from mne.io._digitization import ( + _format_dig_points, + _get_fid_coords, + _get_dig_eeg, + _count_points_by_type, +) +from mne.transforms import _ensure_trans, apply_trans, invert_transform, _get_trans from mne.viz._3d import _fiducial_coords 
from mne.io.kit import read_mrk -from mne.io import (read_raw_brainvision, read_raw_egi, read_raw_fif, - read_fiducials, read_raw_nirx) +from mne.io import ( + read_raw_brainvision, + read_raw_egi, + read_raw_fif, + read_fiducials, + read_raw_nirx, +) from mne.io import RawArray from mne.datasets import testing @@ -88,41 +111,43 @@ def _make_toy_raw(n_channels): return RawArray( data=np.empty([n_channels, 1]), info=create_info( - ch_names=list(ascii_lowercase[:n_channels]), - sfreq=1, ch_types='eeg' - ) + ch_names=list(ascii_lowercase[:n_channels]), sfreq=1, ch_types="eeg" + ), ) def _make_toy_dig_montage(n_channels, **kwargs): return make_dig_montage( - ch_pos=dict(zip( - list(ascii_lowercase[:n_channels]), - np.arange(n_channels * 3).reshape(n_channels, 3), - )), - **kwargs + ch_pos=dict( + zip( + list(ascii_lowercase[:n_channels]), + np.arange(n_channels * 3).reshape(n_channels, 3), + ) + ), + **kwargs, ) def _get_dig_montage_pos(montage): - return np.array([d['r'] for d in _get_dig_eeg(montage.dig)]) + return np.array([d["r"] for d in _get_dig_eeg(montage.dig)]) def test_dig_montage_trans(tmp_path): """Test getting a trans from and applying a trans to a montage.""" nasion, lpa, rpa, *ch_pos = np.random.RandomState(0).randn(10, 3) - ch_pos = {f'EEG{ii:3d}': pos for ii, pos in enumerate(ch_pos, 1)} - montage = make_dig_montage(ch_pos, nasion=nasion, lpa=lpa, rpa=rpa, - coord_frame='mri') + ch_pos = {f"EEG{ii:3d}": pos for ii, pos in enumerate(ch_pos, 1)} + montage = make_dig_montage( + ch_pos, nasion=nasion, lpa=lpa, rpa=rpa, coord_frame="mri" + ) trans = compute_native_head_t(montage) _ensure_trans(trans) # ensure that we can save and load it, too - fname = tmp_path / 'temp-mon.fif' - _check_roundtrip(montage, fname, 'mri') + fname = tmp_path / "temp-mon.fif" + _check_roundtrip(montage, fname, "mri") # test applying a trans position1 = montage.get_positions() montage.apply_trans(trans) - assert montage.get_positions()['coord_frame'] == 'head' + assert montage.get_positions()["coord_frame"] == "head" montage.apply_trans(invert_transform(trans)) position2 = montage.get_positions() assert str(position1) == str(position2) # exactly equal @@ -137,300 +162,356 @@ def test_fiducials(): points = _fiducial_coords(fids, coord_frame) assert points.shape == (3, 3) # Fids - assert_allclose(points[:, 2], 0., atol=1e-6) - assert_allclose(points[::2, 1], 0., atol=1e-6) + assert_allclose(points[:, 2], 0.0, atol=1e-6) + assert_allclose(points[::2, 1], 0.0, atol=1e-6) assert points[2, 0] > 0 # RPA assert points[0, 0] < 0 # LPA # Nasion - assert_allclose(points[1, 0], 0., atol=1e-6) + assert_allclose(points[1, 0], 0.0, atol=1e-6) assert points[1, 1] > 0 def test_documented(): """Test that standard montages are documented.""" - montage_dir = Path(_mne_file).parent / 'channels' / 'data' / 'montages' - montage_files = Path(montage_dir).glob('*') + montage_dir = Path(_mne_file).parent / "channels" / "data" / "montages" + montage_files = Path(montage_dir).glob("*") montage_names = [f.stem for f in montage_files] assert len(montage_names) == len(_BUILTIN_STANDARD_MONTAGES) - assert set(montage_names) == set( - [m.name for m in _BUILTIN_STANDARD_MONTAGES] - ) - - -@pytest.mark.parametrize('reader, file_content, expected_dig, ext, warning', [ - pytest.param( - partial(read_custom_montage, head_size=None), - ('FidNz 0 9.071585155 -2.359754454\n' - 'FidT9 -6.711765 0.040402876 -3.251600355\n' - 'very_very_very_long_name -5.831241498 -4.494821698 4.955347697\n' - 'Cz 0 0 1\n' - 'Cz 0 0 8.899186843'), - 
make_dig_montage( - ch_pos={ - 'very_very_very_long_name': [-5.8312416, -4.4948215, 4.9553475], # noqa - 'Cz': [0., 0., 8.899187], - }, - nasion=[0., 9.071585, -2.3597546], - lpa=[-6.711765, 0.04040287, -3.2516003], - rpa=None, + assert set(montage_names) == set([m.name for m in _BUILTIN_STANDARD_MONTAGES]) + + +@pytest.mark.parametrize( + "reader, file_content, expected_dig, ext, warning", + [ + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "FidNz 0 9.071585155 -2.359754454\n" + "FidT9 -6.711765 0.040402876 -3.251600355\n" + "very_very_very_long_name -5.831241498 -4.494821698 4.955347697\n" + "Cz 0 0 1\n" + "Cz 0 0 8.899186843" + ), + make_dig_montage( + ch_pos={ + "very_very_very_long_name": [ + -5.8312416, + -4.4948215, + 4.9553475, + ], # noqa + "Cz": [0.0, 0.0, 8.899187], + }, + nasion=[0.0, 9.071585, -2.3597546], + lpa=[-6.711765, 0.04040287, -3.2516003], + rpa=None, + ), + "sfp", + (RuntimeWarning, r"Duplicate.*last will be used for Cz \(2\)"), + id="sfp_duplicate", ), - 'sfp', - (RuntimeWarning, r'Duplicate.*last will be used for Cz \(2\)'), - id='sfp_duplicate'), - - pytest.param( - partial(read_custom_montage, head_size=None), - ('FidNz 0 9.071585155 -2.359754454\n' - 'FidT9 -6.711765 0.040402876 -3.251600355\n' - 'headshape 1 2 3\n' - 'headshape 4 5 6\n' - 'Cz 0 0 8.899186843'), - make_dig_montage( - hsp=[ - [1, 2, 3], - [4, 5, 6], - ], - ch_pos={ - 'Cz': [0., 0., 8.899187], - }, - nasion=[0., 9.071585, -2.3597546], - lpa=[-6.711765, 0.04040287, -3.2516003], - rpa=None, + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "FidNz 0 9.071585155 -2.359754454\n" + "FidT9 -6.711765 0.040402876 -3.251600355\n" + "headshape 1 2 3\n" + "headshape 4 5 6\n" + "Cz 0 0 8.899186843" + ), + make_dig_montage( + hsp=[ + [1, 2, 3], + [4, 5, 6], + ], + ch_pos={ + "Cz": [0.0, 0.0, 8.899187], + }, + nasion=[0.0, 9.071585, -2.3597546], + lpa=[-6.711765, 0.04040287, -3.2516003], + rpa=None, + ), + "sfp", + None, + id="sfp_headshape", ), - 'sfp', - None, - id='sfp_headshape'), - - pytest.param( - partial(read_custom_montage, head_size=1), - ('1 0 0.50669 FPz\n' - '2 23 0.71 EOG1\n' - '3 -39.947 0.34459 F3\n' - '4 0 0.25338 Fz\n'), - make_dig_montage( - ch_pos={ - 'EOG1': [0.30873816, 0.72734152, -0.61290705], - 'F3': [-0.56705965, 0.67706631, 0.46906776], - 'FPz': [0., 0.99977915, -0.02101571], - 'Fz': [0., 0.71457525, 0.69955859], - }, - nasion=None, lpa=None, rpa=None, coord_frame='head', + pytest.param( + partial(read_custom_montage, head_size=1), + ( + "1 0 0.50669 FPz\n" + "2 23 0.71 EOG1\n" + "3 -39.947 0.34459 F3\n" + "4 0 0.25338 Fz\n" + ), + make_dig_montage( + ch_pos={ + "EOG1": [0.30873816, 0.72734152, -0.61290705], + "F3": [-0.56705965, 0.67706631, 0.46906776], + "FPz": [0.0, 0.99977915, -0.02101571], + "Fz": [0.0, 0.71457525, 0.69955859], + }, + nasion=None, + lpa=None, + rpa=None, + coord_frame="head", + ), + "loc", + None, + id="EEGLAB", ), - 'loc', - None, - id='EEGLAB'), - - pytest.param( - partial(read_custom_montage, head_size=None, coord_frame='mri'), - "// MatLab Sphere coordinates [degrees] Cartesian coordinates\n" # noqa: E501 - "// Label Theta Phi Radius X Y Z off sphere surface\n" # noqa: E501 - "E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n" # noqa: E501 - "E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000\n" # noqa: E501 - "E31 90.000 -11.000 1.000 0.0000 0.9816 -0.1908 0.00000000000000000\n" # noqa: E501 - "E61 158.000 -17.200 1.000 -0.8857 0.3579 -0.2957 -0.00000000000000022", # noqa: E501 - 
make_dig_montage( - ch_pos={ - 'E1': [0.7677, 0.5934, -0.2419], - 'E3': [0.6084, 0.7704, 0.1908], - 'E31': [0., 0.9816, -0.1908], - 'E61': [-0.8857, 0.3579, -0.2957], - }, - nasion=None, lpa=None, rpa=None, coord_frame='mri', + pytest.param( + partial(read_custom_montage, head_size=None, coord_frame="mri"), + "// MatLab Sphere coordinates [degrees] Cartesian coordinates\n" # noqa: E501 + "// Label Theta Phi Radius X Y Z off sphere surface\n" # noqa: E501 + "E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n" # noqa: E501 + "E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000\n" # noqa: E501 + "E31 90.000 -11.000 1.000 0.0000 0.9816 -0.1908 0.00000000000000000\n" # noqa: E501 + "E61 158.000 -17.200 1.000 -0.8857 0.3579 -0.2957 -0.00000000000000022", # noqa: E501 + make_dig_montage( + ch_pos={ + "E1": [0.7677, 0.5934, -0.2419], + "E3": [0.6084, 0.7704, 0.1908], + "E31": [0.0, 0.9816, -0.1908], + "E61": [-0.8857, 0.3579, -0.2957], + }, + nasion=None, + lpa=None, + rpa=None, + coord_frame="mri", + ), + "csd", + None, + id="matlab", ), - 'csd', - None, - id='matlab'), - - pytest.param( - partial(read_custom_montage, head_size=None), - ('# ASA electrode file\nReferenceLabel avg\nUnitPosition mm\n' - 'NumberPositions= 68\n' - 'Positions\n' - '-86.0761 -19.9897 -47.9860\n' - '85.7939 -20.0093 -48.0310\n' - '0.0083 86.8110 -39.9830\n' - '-86.0761 -24.9897 -67.9860\n' - 'Labels\nLPA\nRPA\nNz\nDummy\n'), - make_dig_montage( - ch_pos={ - 'Dummy': [-0.0860761, -0.0249897, -0.067986], - }, - nasion=[8.3000e-06, 8.6811e-02, -3.9983e-02], - lpa=[-0.0860761, -0.0199897, -0.047986], - rpa=[0.0857939, -0.0200093, -0.048031], + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "# ASA electrode file\nReferenceLabel avg\nUnitPosition mm\n" + "NumberPositions= 68\n" + "Positions\n" + "-86.0761 -19.9897 -47.9860\n" + "85.7939 -20.0093 -48.0310\n" + "0.0083 86.8110 -39.9830\n" + "-86.0761 -24.9897 -67.9860\n" + "Labels\nLPA\nRPA\nNz\nDummy\n" + ), + make_dig_montage( + ch_pos={ + "Dummy": [-0.0860761, -0.0249897, -0.067986], + }, + nasion=[8.3000e-06, 8.6811e-02, -3.9983e-02], + lpa=[-0.0860761, -0.0199897, -0.047986], + rpa=[0.0857939, -0.0200093, -0.048031], + ), + "elc", + None, + id="ASA electrode", ), - 'elc', - None, - id='ASA electrode'), - - pytest.param( - partial(read_custom_montage, head_size=1), - ('Site Theta Phi\n' - 'Fp1 -92 -72\n' - 'Fp2 92 72\n' - 'very_very_very_long_name -92 72\n' - 'O2 92 -90\n'), - make_dig_montage( - ch_pos={ - 'Fp1': [-0.30882875, 0.95047716, -0.0348995], - 'Fp2': [0.30882875, 0.95047716, -0.0348995], - 'very_very_very_long_name': [-0.30882875, -0.95047716, -0.0348995], # noqa - 'O2': [6.11950389e-17, -9.99390827e-01, -3.48994967e-02] - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + partial(read_custom_montage, head_size=1), + ( + "Site Theta Phi\n" + "Fp1 -92 -72\n" + "Fp2 92 72\n" + "very_very_very_long_name -92 72\n" + "O2 92 -90\n" + ), + make_dig_montage( + ch_pos={ + "Fp1": [-0.30882875, 0.95047716, -0.0348995], + "Fp2": [0.30882875, 0.95047716, -0.0348995], + "very_very_very_long_name": [ + -0.30882875, + -0.95047716, + -0.0348995, + ], # noqa + "O2": [6.11950389e-17, -9.99390827e-01, -3.48994967e-02], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "txt", + None, + id="generic theta-phi (txt)", ), - 'txt', - None, - id='generic theta-phi (txt)'), - - pytest.param( - partial(read_custom_montage, head_size=None), - ('FID\t LPA\t -120.03\t 0\t 85\n' - 'FID\t RPA\t 120.03\t 0\t 85\n' - 'FID\t Nz\t 114.03\t 90\t 
85\n' - 'EEG\t F3\t -62.027\t -50.053\t 85\n' - 'EEG\t Fz\t 45.608\t 90\t 85\n' - 'EEG\t F4\t 62.01\t 50.103\t 85\n' - 'EEG\t FCz\t 68.01\t 58.103\t 85\n'), - make_dig_montage( - ch_pos={ - 'F3': [-0.48200427, 0.57551063, 0.39869712], - 'Fz': [3.71915931e-17, 6.07384809e-01, 5.94629038e-01], - 'F4': [0.48142596, 0.57584026, 0.39891983], - 'FCz': [0.41645989, 0.66914889, 0.31827805], - }, - nasion=[4.75366562e-17, 7.76332511e-01, -3.46132681e-01], - lpa=[-7.35898963e-01, 9.01216309e-17, -4.25385374e-01], - rpa=[0.73589896, 0., -0.42538537], + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "FID\t LPA\t -120.03\t 0\t 85\n" + "FID\t RPA\t 120.03\t 0\t 85\n" + "FID\t Nz\t 114.03\t 90\t 85\n" + "EEG\t F3\t -62.027\t -50.053\t 85\n" + "EEG\t Fz\t 45.608\t 90\t 85\n" + "EEG\t F4\t 62.01\t 50.103\t 85\n" + "EEG\t FCz\t 68.01\t 58.103\t 85\n" + ), + make_dig_montage( + ch_pos={ + "F3": [-0.48200427, 0.57551063, 0.39869712], + "Fz": [3.71915931e-17, 6.07384809e-01, 5.94629038e-01], + "F4": [0.48142596, 0.57584026, 0.39891983], + "FCz": [0.41645989, 0.66914889, 0.31827805], + }, + nasion=[4.75366562e-17, 7.76332511e-01, -3.46132681e-01], + lpa=[-7.35898963e-01, 9.01216309e-17, -4.25385374e-01], + rpa=[0.73589896, 0.0, -0.42538537], + ), + "elp", + None, + id="BESA spherical model", ), - 'elp', - None, - id='BESA spherical model'), - - pytest.param( - partial(read_dig_hpts, unit='m'), - ('eeg Fp1 -95.0 -3. -3.\n' - 'eeg AF7 -1 -1 -3\n' - 'eeg A3 -2 -2 2\n' - 'eeg A 0 0 0'), - make_dig_montage( - ch_pos={ - 'A': [0., 0., 0.], 'A3': [-2., -2., 2.], - 'AF7': [-1., -1., -3.], 'Fp1': [-95., -3., -3.], - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + partial(read_dig_hpts, unit="m"), + ( + "eeg Fp1 -95.0 -3. -3.\n" + "eeg AF7 -1 -1 -3\n" + "eeg A3 -2 -2 2\n" + "eeg A 0 0 0" + ), + make_dig_montage( + ch_pos={ + "A": [0.0, 0.0, 0.0], + "A3": [-2.0, -2.0, 2.0], + "AF7": [-1.0, -1.0, -3.0], + "Fp1": [-95.0, -3.0, -3.0], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "hpts", + None, + id="legacy mne-c", ), - 'hpts', - None, - id='legacy mne-c'), - - pytest.param( - read_custom_montage, - ('ch_name, x, y, z\n' - 'Fp1, -95.0, -3., -3.\n' - 'AF7, -1, -1, -3\n' - 'A3, -2, -2, 2\n' - 'A, 0, 0, 0'), - make_dig_montage( - ch_pos={ - 'A': [0., 0., 0.], 'A3': [-2., -2., 2.], - 'AF7': [-1., -1., -3.], 'Fp1': [-95., -3., -3.], - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + read_custom_montage, + ( + "ch_name, x, y, z\n" + "Fp1, -95.0, -3., -3.\n" + "AF7, -1, -1, -3\n" + "A3, -2, -2, 2\n" + "A, 0, 0, 0" + ), + make_dig_montage( + ch_pos={ + "A": [0.0, 0.0, 0.0], + "A3": [-2.0, -2.0, 2.0], + "AF7": [-1.0, -1.0, -3.0], + "Fp1": [-95.0, -3.0, -3.0], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "csv", + None, + id="CSV file", ), - 'csv', - None, - id='CSV file'), - - pytest.param( - read_custom_montage, - ('1\t-95.0\t-3.\t-3.\tFp1\n' - '2\t-1\t-1\t-3\tAF7\n' - '3\t-2\t-2\t2\tA3\n' - '4\t0\t0\t0\tA'), - make_dig_montage( - ch_pos={ - 'A': [0., 0., 0.], 'A3': [-2., -2., 2.], - 'AF7': [-1., -1., -3.], 'Fp1': [-95., -3., -3.], - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + read_custom_montage, + ( + "1\t-95.0\t-3.\t-3.\tFp1\n" + "2\t-1\t-1\t-3\tAF7\n" + "3\t-2\t-2\t2\tA3\n" + "4\t0\t0\t0\tA" + ), + make_dig_montage( + ch_pos={ + "A": [0.0, 0.0, 0.0], + "A3": [-2.0, -2.0, 2.0], + "AF7": [-1.0, -1.0, -3.0], + "Fp1": [-95.0, -3.0, -3.0], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "xyz", + None, + id="XYZ file", ), - 'xyz', - None, - id='XYZ file'), - - pytest.param( 
- read_custom_montage, - ('ch_name\tx\ty\tz\n' - 'Fp1\t-95.0\t-3.\t-3.\n' - 'AF7\t-1\t-1\t-3\n' - 'A3\t-2\t-2\t2\n' - 'A\t0\t0\t0'), - make_dig_montage( - ch_pos={ - 'A': [0., 0., 0.], 'A3': [-2., -2., 2.], - 'AF7': [-1., -1., -3.], 'Fp1': [-95., -3., -3.], - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + read_custom_montage, + ( + "ch_name\tx\ty\tz\n" + "Fp1\t-95.0\t-3.\t-3.\n" + "AF7\t-1\t-1\t-3\n" + "A3\t-2\t-2\t2\n" + "A\t0\t0\t0" + ), + make_dig_montage( + ch_pos={ + "A": [0.0, 0.0, 0.0], + "A3": [-2.0, -2.0, 2.0], + "AF7": [-1.0, -1.0, -3.0], + "Fp1": [-95.0, -3.0, -3.0], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "tsv", + None, + id="TSV file", ), - 'tsv', - None, - id='TSV file'), - - pytest.param( - partial(read_custom_montage, head_size=None), - ('\n' - '\n' - '\n' - ' \n' - ' Fp1\n' - ' -90\n' - ' -72\n' - ' 1\n' - ' 1\n' - ' \n' - ' \n' - ' Fz\n' - ' 45\n' - ' 90\n' - ' 1\n' - ' 2\n' - ' \n' - ' \n' - ' F3\n' - ' -60\n' - ' -51\n' - ' 1\n' - ' 3\n' - ' \n' - ' \n' - ' F7\n' - ' -90\n' - ' -36\n' - ' 1\n' - ' 4\n' - ' \n' - ''), - make_dig_montage( - ch_pos={ - 'Fp1': [-3.09016994e-01, 9.51056516e-01, 6.12323400e-17], - 'Fz': [4.32978028e-17, 7.07106781e-01, 7.07106781e-01], - 'F3': [-0.54500745, 0.67302815, 0.5], - 'F7': [-8.09016994e-01, 5.87785252e-01, 6.12323400e-17], - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + partial(read_custom_montage, head_size=None), + ( + '\n' + "\n" + '\n' + " \n" + " Fp1\n" + " -90\n" + " -72\n" + " 1\n" + " 1\n" + " \n" + " \n" + " Fz\n" + " 45\n" + " 90\n" + " 1\n" + " 2\n" + " \n" + " \n" + " F3\n" + " -60\n" + " -51\n" + " 1\n" + " 3\n" + " \n" + " \n" + " F7\n" + " -90\n" + " -36\n" + " 1\n" + " 4\n" + " \n" + "" + ), + make_dig_montage( + ch_pos={ + "Fp1": [-3.09016994e-01, 9.51056516e-01, 6.12323400e-17], + "Fz": [4.32978028e-17, 7.07106781e-01, 7.07106781e-01], + "F3": [-0.54500745, 0.67302815, 0.5], + "F7": [-8.09016994e-01, 5.87785252e-01, 6.12323400e-17], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "bvef", + None, + id="brainvision", ), - 'bvef', - None, - id='brainvision'), -]) -def test_montage_readers( - reader, file_content, expected_dig, ext, warning, tmp_path -): + ], +) +def test_montage_readers(reader, file_content, expected_dig, ext, warning, tmp_path): """Test that we have an equivalent of read_montage for all file formats.""" fname = tmp_path / f"test.{ext}" - with open(fname, 'w') as fid: + with open(fname, "w") as fid: fid.write(file_content) if warning is None: @@ -447,15 +528,15 @@ def test_montage_readers( assert_allclose(actual_ch_pos[kk], expected_ch_pos[kk], atol=1e-5) assert len(dig_montage.dig) == len(expected_dig.dig) for d1, d2 in zip(dig_montage.dig, expected_dig.dig): - assert d1['coord_frame'] == d2['coord_frame'] - for key in ('coord_frame', 'ident', 'kind'): + assert d1["coord_frame"] == d2["coord_frame"] + for key in ("coord_frame", "ident", "kind"): assert isinstance(d1[key], int) assert isinstance(d2[key], int) with _record_warnings() as w: xform = compute_native_head_t(dig_montage) - assert xform['to'] == FIFF.FIFFV_COORD_HEAD - assert xform['from'] == FIFF.FIFFV_COORD_UNKNOWN - n = int(np.allclose(xform['trans'], np.eye(4))) + assert xform["to"] == FIFF.FIFFV_COORD_HEAD + assert xform["from"] == FIFF.FIFFV_COORD_UNKNOWN + n = int(np.allclose(xform["trans"], np.eye(4))) assert len(w) == n @@ -465,32 +546,34 @@ def test_read_locs(): data = read_custom_montage(locs_montage_fname)._get_ch_pos() assert_allclose( actual=np.stack( - [data[kk] for kk in ('FPz', 'EOG1', 'F3', 
'Fz')] # 4 random chs + [data[kk] for kk in ("FPz", "EOG1", "F3", "Fz")] # 4 random chs ), - desired=[[0., 0.094979, -0.001996], - [0.02933, 0.069097, -0.058226], - [-0.053871, 0.064321, 0.044561], - [0., 0.067885, 0.066458]], - atol=1e-6 + desired=[ + [0.0, 0.094979, -0.001996], + [0.02933, 0.069097, -0.058226], + [-0.053871, 0.064321, 0.044561], + [0.0, 0.067885, 0.066458], + ], + atol=1e-6, ) def test_read_dig_dat(tmp_path): """Test reading *.dat electrode locations.""" rows = [ - ['Nasion', 78, 0.00, 1.00, 0.00], - ['Left', 76, -1.00, 0.00, 0.00], - ['Right', 82, 1.00, -0.00, 0.00], - ['O2', 69, -0.50, -0.90, 0.05], - ['O2', 68, 0.00, 0.01, 0.02], - ['Centroid', 67, 0.00, 0.00, 0.00], + ["Nasion", 78, 0.00, 1.00, 0.00], + ["Left", 76, -1.00, 0.00, 0.00], + ["Right", 82, 1.00, -0.00, 0.00], + ["O2", 69, -0.50, -0.90, 0.05], + ["O2", 68, 0.00, 0.01, 0.02], + ["Centroid", 67, 0.00, 0.00, 0.00], ] # write mock test.dat file fname_temp = tmp_path / "test.dat" - with open(fname_temp, 'w') as fid: + with open(fname_temp, "w") as fid: for row in rows: name = row[0].rjust(10) - data = '\t'.join(map(str, row[1:])) + data = "\t".join(map(str, row[1:])) fid.write("%s\t%s\n" % (name, data)) # construct expected value idents = { @@ -507,15 +590,21 @@ def test_read_dig_dat(tmp_path): 69: FIFF.FIFFV_POINT_EEG, 68: FIFF.FIFFV_POINT_EEG, } - target = {row[0]: {'r': row[2:], 'ident': idents[row[1]], - 'kind': kinds[row[1]], 'coord_frame': 0} - for row in rows[:-1]} - assert_allclose(target['O2']['r'], [0, 0.01, 0.02]) + target = { + row[0]: { + "r": row[2:], + "ident": idents[row[1]], + "kind": kinds[row[1]], + "coord_frame": 0, + } + for row in rows[:-1] + } + assert_allclose(target["O2"]["r"], [0, 0.01, 0.02]) # read it - with pytest.warns(RuntimeWarning, match=r'Duplic.*for O2 \(2\)'): + with pytest.warns(RuntimeWarning, match=r"Duplic.*for O2 \(2\)"): dig = read_dig_dat(fname_temp) - assert set(dig.ch_names) == {'O2'} - keys = chain(['Left', 'Nasion', 'Right'], dig.ch_names) + assert set(dig.ch_names) == {"O2"} + keys = chain(["Left", "Nasion", "Right"], dig.ch_names) target = [target[k] for k in keys] assert dig.dig == target @@ -526,32 +615,29 @@ def test_read_dig_montage_using_polhemus_fastscan(): my_electrode_positions = read_polhemus_fastscan(kit_dir / "test_elp.txt") montage = make_dig_montage( # EEG_CH - ch_pos=dict(zip(ascii_lowercase[:N_EEG_CH], - np.random.RandomState(0).rand(N_EEG_CH, 3))), + ch_pos=dict( + zip(ascii_lowercase[:N_EEG_CH], np.random.RandomState(0).rand(N_EEG_CH, 3)) + ), # NO NAMED points nasion=my_electrode_positions[0], lpa=my_electrode_positions[1], rpa=my_electrode_positions[2], hpi=my_electrode_positions[3:], hsp=read_polhemus_fastscan(kit_dir / "test_hsp.txt"), - # Other defaults - coord_frame='unknown' + coord_frame="unknown", ) assert repr(montage) == ( - '' + "" ) - assert set([d['coord_frame'] for d in montage.dig]) == { - FIFF.FIFFV_COORD_UNKNOWN - } + assert set([d["coord_frame"] for d in montage.dig]) == {FIFF.FIFFV_COORD_UNKNOWN} EXPECTED_FID_IN_POLHEMUS = { - 'nasion': [0.001393, 0.0131613, -0.0046967], - 'lpa': [-0.0624997, -0.0737271, 0.07996], - 'rpa': [-0.0748957, 0.0873785, 0.0811943], + "nasion": [0.001393, 0.0131613, -0.0046967], + "lpa": [-0.0624997, -0.0737271, 0.07996], + "rpa": [-0.0748957, 0.0873785, 0.0811943], } fiducials, fid_coordframe = _get_fid_coords(montage.dig) assert fid_coordframe == FIFF.FIFFV_COORD_UNKNOWN @@ -562,17 +648,17 @@ def test_read_dig_montage_using_polhemus_fastscan(): def 
test_read_dig_montage_using_polhemus_fastscan_error_handling(tmp_path): """Test reading Polhemus FastSCAN errors.""" with open(kit_dir / "test_elp.txt") as fid: - content = fid.read().replace('FastSCAN', 'XxxxXXXX') + content = fid.read().replace("FastSCAN", "XxxxXXXX") - fname = tmp_path / 'faulty_FastSCAN.txt' - with open(fname, 'w') as fid: + fname = tmp_path / "faulty_FastSCAN.txt" + with open(fname, "w") as fid: fid.write(content) - with pytest.raises(ValueError, match='not contain.*Polhemus FastSCAN'): + with pytest.raises(ValueError, match="not contain.*Polhemus FastSCAN"): _ = read_polhemus_fastscan(fname) - fname = tmp_path / 'faulty_FastSCAN.bar' - with open(fname, 'w') as fid: + fname = tmp_path / "faulty_FastSCAN.bar" + with open(fname, "w") as fid: fid.write(content) EXPECTED_ERR_MSG = "allowed value is '.txt', but got '.bar' instead" with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): @@ -582,16 +668,13 @@ def test_read_dig_montage_using_polhemus_fastscan_error_handling(tmp_path): def test_read_dig_polhemus_isotrak_hsp(): """Test reading Polhemus IsoTrak HSP file.""" EXPECTED_FID_IN_POLHEMUS = { - 'nasion': np.array([1.1056e-01, -5.4210e-19, 0]), - 'lpa': np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), - 'rpa': np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), + "nasion": np.array([1.1056e-01, -5.4210e-19, 0]), + "lpa": np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), + "rpa": np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), } - montage = read_dig_polhemus_isotrak( - fname=kit_dir / "test.hsp", ch_names=None - ) + montage = read_dig_polhemus_isotrak(fname=kit_dir / "test.hsp", ch_names=None) assert repr(montage) == ( - '' + "" ) fiducials, fid_coordframe = _get_fid_coords(montage.dig) @@ -604,16 +687,13 @@ def test_read_dig_polhemus_isotrak_hsp(): def test_read_dig_polhemus_isotrak_elp(): """Test reading Polhemus IsoTrak ELP file.""" EXPECTED_FID_IN_POLHEMUS = { - 'nasion': np.array([1.1056e-01, -5.4210e-19, 0]), - 'lpa': np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), - 'rpa': np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), + "nasion": np.array([1.1056e-01, -5.4210e-19, 0]), + "lpa": np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), + "rpa": np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), } - montage = read_dig_polhemus_isotrak( - fname=kit_dir / "test.elp", ch_names=None - ) + montage = read_dig_polhemus_isotrak(fname=kit_dir / "test.elp", ch_names=None) assert repr(montage) == ( - '' + "" ) fiducials, fid_coordframe = _get_fid_coords(montage.dig) @@ -622,35 +702,39 @@ def test_read_dig_polhemus_isotrak_elp(): assert_array_equal(val, EXPECTED_FID_IN_POLHEMUS[kk]) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def isotrak_eeg(tmp_path_factory): """Mock isotrak file with EEG positions.""" _SEED = 42 N_ROWS, N_COLS = 5, 3 content = np.random.RandomState(_SEED).randn(N_ROWS, N_COLS) - fname = tmp_path_factory.mktemp('data') / 'test.eeg' - with open(str(fname), 'w') as fid: - fid.write(( - '3 200\n' - '//Shape file\n' - '//Minor revision number\n' - '2\n' - '//Subject Name\n' - '%N Name \n' - '////Shape code, number of digitized points\n' - )) - fid.write('0 {rows:d}\n'.format(rows=N_ROWS)) - fid.write(( - '//Position of fiducials X+, Y+, Y- on the subject\n' - '%F 0.11056 -5.421e-19 0 \n' - '%F -0.00021075 0.080793 -7.5894e-19 \n' - '%F 0.00021075 -0.080793 -2.8731e-18 \n' - '//No of rows, no of columns; position of digitized points\n' - )) - fid.write('{rows:d} {cols:d}\n'.format(rows=N_ROWS, cols=N_COLS)) + fname = 
tmp_path_factory.mktemp("data") / "test.eeg" + with open(str(fname), "w") as fid: + fid.write( + ( + "3 200\n" + "//Shape file\n" + "//Minor revision number\n" + "2\n" + "//Subject Name\n" + "%N Name \n" + "////Shape code, number of digitized points\n" + ) + ) + fid.write("0 {rows:d}\n".format(rows=N_ROWS)) + fid.write( + ( + "//Position of fiducials X+, Y+, Y- on the subject\n" + "%F 0.11056 -5.421e-19 0 \n" + "%F -0.00021075 0.080793 -7.5894e-19 \n" + "%F 0.00021075 -0.080793 -2.8731e-18 \n" + "//No of rows, no of columns; position of digitized points\n" + ) + ) + fid.write("{rows:d} {cols:d}\n".format(rows=N_ROWS, cols=N_COLS)) for row in content: - fid.write('\t'.join('%0.18e' % cell for cell in row) + '\n') + fid.write("\t".join("%0.18e" % cell for cell in row) + "\n") return str(fname) @@ -660,18 +744,18 @@ def test_read_dig_polhemus_isotrak_eeg(isotrak_eeg): N_CHANNELS = 5 _SEED = 42 EXPECTED_FID_IN_POLHEMUS = { - 'nasion': np.array([1.1056e-01, -5.4210e-19, 0]), - 'lpa': np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), - 'rpa': np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), + "nasion": np.array([1.1056e-01, -5.4210e-19, 0]), + "lpa": np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), + "rpa": np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), } - ch_names = ['eeg {:01d}'.format(ii) for ii in range(N_CHANNELS)] - EXPECTED_CH_POS = dict(zip( - ch_names, np.random.RandomState(_SEED).randn(N_CHANNELS, 3))) + ch_names = ["eeg {:01d}".format(ii) for ii in range(N_CHANNELS)] + EXPECTED_CH_POS = dict( + zip(ch_names, np.random.RandomState(_SEED).randn(N_CHANNELS, 3)) + ) montage = read_dig_polhemus_isotrak(fname=isotrak_eeg, ch_names=ch_names) assert repr(montage) == ( - '' + "" ) fiducials, fid_coordframe = _get_fid_coords(montage.dig) @@ -681,8 +765,8 @@ def test_read_dig_polhemus_isotrak_eeg(isotrak_eeg): assert_array_equal(val, EXPECTED_FID_IN_POLHEMUS[kk]) for kk, dig_point in zip(montage.ch_names, _get_dig_eeg(montage.dig)): - assert_array_equal(dig_point['r'], EXPECTED_CH_POS[kk]) - assert dig_point['coord_frame'] == FIFF.FIFFV_COORD_UNKNOWN + assert_array_equal(dig_point["r"], EXPECTED_CH_POS[kk]) + assert dig_point["coord_frame"] == FIFF.FIFFV_COORD_UNKNOWN def test_read_dig_polhemus_isotrak_error_handling(isotrak_eeg, tmp_path): @@ -697,7 +781,7 @@ def test_read_dig_polhemus_isotrak_error_handling(isotrak_eeg, tmp_path): with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): _ = read_dig_polhemus_isotrak( fname=isotrak_eeg, - ch_names=['eeg {:01d}'.format(ii) for ii in range(N_CHANNELS + 42)] + ch_names=["eeg {:01d}".format(ii) for ii in range(N_CHANNELS + 42)], ) # Check fname extensions @@ -706,7 +790,7 @@ def test_read_dig_polhemus_isotrak_error_handling(isotrak_eeg, tmp_path): with pytest.raises( ValueError, - match="Allowed val.*'.hsp', '.elp', and '.eeg', but got '.bar' instead" + match="Allowed val.*'.hsp', '.elp', and '.eeg', but got '.bar' instead", ): _ = read_dig_polhemus_isotrak(fname=fname, ch_names=None) @@ -714,52 +798,64 @@ def test_read_dig_polhemus_isotrak_error_handling(isotrak_eeg, tmp_path): def test_combining_digmontage_objects(): """Test combining different DigMontage objects.""" rng = np.random.RandomState(0) - fiducials = dict(zip(('nasion', 'lpa', 'rpa'), rng.rand(3, 3))) + fiducials = dict(zip(("nasion", "lpa", "rpa"), rng.rand(3, 3))) # hsp positions are [1X, 1X, 1X] - hsp1 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 11.)) - hsp2 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.)) - hsp3 = make_dig_montage(**fiducials, hsp=np.full((2, 
3), 13.)) + hsp1 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 11.0)) + hsp2 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.0)) + hsp3 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 13.0)) # hpi positions are [2X, 2X, 2X] - hpi1 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 21.)) - hpi2 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 22.)) - hpi3 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 23.)) + hpi1 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 21.0)) + hpi2 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 22.0)) + hpi3 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 23.0)) # channels have positions at 40s, 50s, and 60s. ch_pos1 = make_dig_montage( - **fiducials, - ch_pos={'h': [41, 41, 41], 'b': [42, 42, 42], 'g': [43, 43, 43]} + **fiducials, ch_pos={"h": [41, 41, 41], "b": [42, 42, 42], "g": [43, 43, 43]} ) ch_pos2 = make_dig_montage( - **fiducials, - ch_pos={'n': [51, 51, 51], 'y': [52, 52, 52], 'p': [53, 53, 53]} + **fiducials, ch_pos={"n": [51, 51, 51], "y": [52, 52, 52], "p": [53, 53, 53]} ) ch_pos3 = make_dig_montage( - **fiducials, - ch_pos={'v': [61, 61, 61], 'a': [62, 62, 62], 'l': [63, 63, 63]} + **fiducials, ch_pos={"v": [61, 61, 61], "a": [62, 62, 62], "l": [63, 63, 63]} ) montage = ( - DigMontage() + hsp1 + hsp2 + hsp3 + hpi1 + hpi2 + hpi3 + ch_pos1 + - ch_pos2 + ch_pos3 + DigMontage() + + hsp1 + + hsp2 + + hsp3 + + hpi1 + + hpi2 + + hpi3 + + ch_pos1 + + ch_pos2 + + ch_pos3 ) assert repr(montage) == ( - '' + "" ) EXPECTED_MONTAGE = make_dig_montage( **fiducials, - hsp=np.concatenate([np.full((2, 3), 11.), np.full((2, 3), 12.), - np.full((2, 3), 13.)]), - hpi=np.concatenate([np.full((2, 3), 21.), np.full((2, 3), 22.), - np.full((2, 3), 23.)]), + hsp=np.concatenate( + [np.full((2, 3), 11.0), np.full((2, 3), 12.0), np.full((2, 3), 13.0)] + ), + hpi=np.concatenate( + [np.full((2, 3), 21.0), np.full((2, 3), 22.0), np.full((2, 3), 23.0)] + ), ch_pos={ - 'h': [41, 41, 41], 'b': [42, 42, 42], 'g': [43, 43, 43], - 'n': [51, 51, 51], 'y': [52, 52, 52], 'p': [53, 53, 53], - 'v': [61, 61, 61], 'a': [62, 62, 62], 'l': [63, 63, 63], - } + "h": [41, 41, 41], + "b": [42, 42, 42], + "g": [43, 43, 43], + "n": [51, 51, 51], + "y": [52, 52, 52], + "p": [53, 53, 53], + "v": [61, 61, 61], + "a": [62, 62, 62], + "l": [63, 63, 63], + }, ) # Do some checks to ensure they are the same DigMontage @@ -773,33 +869,33 @@ def test_combining_digmontage_objects(): def test_combining_digmontage_forbiden_behaviors(): """Test combining different DigMontage objects with repeated names.""" rng = np.random.RandomState(0) - fiducials = dict(zip(('nasion', 'lpa', 'rpa'), rng.rand(3, 3))) + fiducials = dict(zip(("nasion", "lpa", "rpa"), rng.rand(3, 3))) dig1 = make_dig_montage( **fiducials, - ch_pos=dict(zip(list('abc'), rng.rand(3, 3))), + ch_pos=dict(zip(list("abc"), rng.rand(3, 3))), ) dig2 = make_dig_montage( **fiducials, - ch_pos=dict(zip(list('bcd'), rng.rand(3, 3))), + ch_pos=dict(zip(list("bcd"), rng.rand(3, 3))), ) dig2_wrong_fid = make_dig_montage( - nasion=rng.rand(3), lpa=rng.rand(3), rpa=rng.rand(3), - ch_pos=dict(zip(list('ghi'), rng.rand(3, 3))), + nasion=rng.rand(3), + lpa=rng.rand(3), + rpa=rng.rand(3), + ch_pos=dict(zip(list("ghi"), rng.rand(3, 3))), ) dig2_wrong_coordframe = make_dig_montage( - **fiducials, - ch_pos=dict(zip(list('ghi'), rng.rand(3, 3))), - coord_frame='meg' + **fiducials, ch_pos=dict(zip(list("ghi"), rng.rand(3, 3))), coord_frame="meg" ) - EXPECTED_ERR_MSG = "Cannot.*duplicated channel.*found: \'b\', \'c\'." 
+ EXPECTED_ERR_MSG = "Cannot.*duplicated channel.*found: 'b', 'c'." with pytest.raises(RuntimeError, match=EXPECTED_ERR_MSG): _ = dig1 + dig2 - with pytest.raises(RuntimeError, match='fiducial locations do not match'): + with pytest.raises(RuntimeError, match="fiducial locations do not match"): _ = dig1 + dig2_wrong_fid - with pytest.raises(RuntimeError, match='not in the same coordinate '): + with pytest.raises(RuntimeError, match="not in the same coordinate "): _ = dig1 + dig2_wrong_coordframe @@ -807,45 +903,57 @@ def test_set_dig_montage(): """Test setting DigMontage with toy understandable points.""" N_CHANNELS, N_HSP, N_HPI = 3, 2, 1 ch_names = list(ascii_lowercase[:N_CHANNELS]) - ch_pos = dict(zip( - ch_names, - np.arange(N_CHANNELS * 3).reshape(N_CHANNELS, 3), - )) + ch_pos = dict( + zip( + ch_names, + np.arange(N_CHANNELS * 3).reshape(N_CHANNELS, 3), + ) + ) - montage_ch_only = make_dig_montage(ch_pos=ch_pos, coord_frame='head') + montage_ch_only = make_dig_montage(ch_pos=ch_pos, coord_frame="head") assert repr(montage_ch_only) == ( - '' + "" ) - info = create_info(ch_names, sfreq=1, ch_types='eeg') + info = create_info(ch_names, sfreq=1, ch_types="eeg") info.set_montage(montage_ch_only) - assert len(info['dig']) == len(montage_ch_only.dig) + 3 # added fiducials + assert len(info["dig"]) == len(montage_ch_only.dig) + 3 # added fiducials - assert_allclose(actual=np.array([ch['loc'][:6] for ch in info['chs']]), - desired=[[0., 1., 2., 0., 0., 0.], - [3., 4., 5., 0., 0., 0.], - [6., 7., 8., 0., 0., 0.]]) + assert_allclose( + actual=np.array([ch["loc"][:6] for ch in info["chs"]]), + desired=[ + [0.0, 1.0, 2.0, 0.0, 0.0, 0.0], + [3.0, 4.0, 5.0, 0.0, 0.0, 0.0], + [6.0, 7.0, 8.0, 0.0, 0.0, 0.0], + ], + ) montage_full = make_dig_montage( ch_pos=dict(**ch_pos, EEG000=np.full(3, 42)), # 4 = 3 egg + 1 eeg_ref - nasion=[1, 1, 1], lpa=[2, 2, 2], rpa=[3, 3, 3], + nasion=[1, 1, 1], + lpa=[2, 2, 2], + rpa=[3, 3, 3], hsp=np.full((N_HSP, 3), 4), hpi=np.full((N_HPI, 3), 4), - coord_frame='head' + coord_frame="head", ) assert repr(montage_full) == ( - '' + "" ) - info = create_info(ch_names, sfreq=1, ch_types='eeg') + info = create_info(ch_names, sfreq=1, ch_types="eeg") info.set_montage(montage_full) - EXPECTED_LEN = sum({'hsp': 2, 'hpi': 1, 'fid': 3, 'eeg': 4}.values()) - assert len(info['dig']) == EXPECTED_LEN - assert_allclose(actual=np.array([ch['loc'][:6] for ch in info['chs']]), - desired=[[0., 1., 2., 42., 42., 42.], - [3., 4., 5., 42., 42., 42.], - [6., 7., 8., 42., 42., 42.]]) + EXPECTED_LEN = sum({"hsp": 2, "hpi": 1, "fid": 3, "eeg": 4}.values()) + assert len(info["dig"]) == EXPECTED_LEN + assert_allclose( + actual=np.array([ch["loc"][:6] for ch in info["chs"]]), + desired=[ + [0.0, 1.0, 2.0, 42.0, 42.0, 42.0], + [3.0, 4.0, 5.0, 42.0, 42.0, 42.0], + [6.0, 7.0, 8.0, 42.0, 42.0, 42.0], + ], + ) def test_set_dig_montage_with_nan_positions(): @@ -854,10 +962,11 @@ def test_set_dig_montage_with_nan_positions(): Test that setting a montage with some NaN positions does not produce NaN fiducials. 
""" + def _ensure_fid_not_nan(info, ch_pos): - montage_kwargs = dict(ch_pos=dict(), coord_frame='head') + montage_kwargs = dict(ch_pos=dict(), coord_frame="head") for ch_idx, ch in enumerate(info.ch_names): - montage_kwargs['ch_pos'][ch] = ch_pos[ch_idx] + montage_kwargs["ch_pos"][ch] = ch_pos[ch_idx] new_montage = make_dig_montage(**montage_kwargs) info = info.copy() @@ -865,7 +974,8 @@ def _ensure_fid_not_nan(info, ch_pos): recovered_montage = info.get_montage() fid_coords, coord_frame = _get_fid_coords( - recovered_montage.dig, raise_error=False) + recovered_montage.dig, raise_error=False + ) for fid_coord in fid_coords.values(): if fid_coord is not None: @@ -873,21 +983,20 @@ def _ensure_fid_not_nan(info, ch_pos): return fid_coords, coord_frame - channels = list('ABCDEF') - info = create_info(channels, 1000, ch_types='seeg') + channels = list("ABCDEF") + info = create_info(channels, 1000, ch_types="seeg") # if all positions are NaN, the fiducials should not be NaN, but None - ch_pos = [info['chs'][ch_idx]['loc'][:3] - for ch_idx in range(len(channels))] + ch_pos = [info["chs"][ch_idx]["loc"][:3] for ch_idx in range(len(channels))] fid_coords, coord_frame = _ensure_fid_not_nan(info, ch_pos) for fid_coord in fid_coords.values(): assert fid_coord is None assert coord_frame is None # if some positions are not NaN, the fiducials should be a non-NaN array - ch_pos[0] = np.array([1., 1.5, 1.]) - ch_pos[1] = np.array([2., 1.5, 1.5]) - ch_pos[2] = np.array([1.25, 1., 1.25]) + ch_pos[0] = np.array([1.0, 1.5, 1.0]) + ch_pos[1] = np.array([2.0, 1.5, 1.5]) + ch_pos[2] = np.array([1.25, 1.0, 1.25]) fid_coords, coord_frame = _ensure_fid_not_nan(info, ch_pos) for fid_coord in fid_coords.values(): assert isinstance(fid_coord, np.ndarray) @@ -908,14 +1017,14 @@ def test_fif_dig_montage(tmp_path): raw_bv_2 = raw_bv.copy() mapping = dict() for ii, ch_name in enumerate(raw_bv.ch_names): - mapping[ch_name] = 'EEG%03d' % (ii + 1,) + mapping[ch_name] = "EEG%03d" % (ii + 1,) raw_bv.rename_channels(mapping) for ii, ch_name in enumerate(raw_bv_2.ch_names): - mapping[ch_name] = 'EEG%03d' % (ii + 33,) + mapping[ch_name] = "EEG%03d" % (ii + 33,) raw_bv_2.rename_channels(mapping) raw_bv.add_channels([raw_bv_2]) - for ch in raw_bv.info['chs']: - ch['kind'] = FIFF.FIFFV_EEG_CH + for ch in raw_bv.info["chs"]: + ch["kind"] = FIFF.FIFFV_EEG_CH # Set the montage raw_bv.set_montage(dig_montage) @@ -925,33 +1034,30 @@ def test_fif_dig_montage(tmp_path): # check info[chs] matches assert_equal(len(raw_bv.ch_names), len(evoked.ch_names) - 1) - for ch_py, ch_c in zip(raw_bv.info['chs'], evoked.info['chs'][:-1]): - assert_equal(ch_py['ch_name'], - ch_c['ch_name'].replace('EEG ', 'EEG')) + for ch_py, ch_c in zip(raw_bv.info["chs"], evoked.info["chs"][:-1]): + assert_equal(ch_py["ch_name"], ch_c["ch_name"].replace("EEG ", "EEG")) # C actually says it's unknown, but it's not (?): # assert_equal(ch_py['coord_frame'], ch_c['coord_frame']) - assert_equal(ch_py['coord_frame'], FIFF.FIFFV_COORD_HEAD) - c_loc = ch_c['loc'].copy() + assert_equal(ch_py["coord_frame"], FIFF.FIFFV_COORD_HEAD) + c_loc = ch_c["loc"].copy() c_loc[c_loc == 0] = np.nan - assert_allclose(ch_py['loc'], c_loc, atol=1e-7) + assert_allclose(ch_py["loc"], c_loc, atol=1e-7) # check info[dig] assert_dig_allclose(raw_bv.info, evoked.info) # Roundtrip of non-FIF start - montage = make_dig_montage(hsp=read_polhemus_fastscan(hsp), - hpi=read_mrk(hpi)) + montage = make_dig_montage(hsp=read_polhemus_fastscan(hsp), hpi=read_mrk(hpi)) elp_points = read_polhemus_fastscan(elp) 
ch_pos = {"EEG%03d" % (k + 1): pos for k, pos in enumerate(elp_points[8:])} - montage += make_dig_montage(nasion=elp_points[0], - lpa=elp_points[1], - rpa=elp_points[2], - ch_pos=ch_pos) - _check_roundtrip(montage, fname_temp, 'unknown') + montage += make_dig_montage( + nasion=elp_points[0], lpa=elp_points[1], rpa=elp_points[2], ch_pos=ch_pos + ) + _check_roundtrip(montage, fname_temp, "unknown") montage = transform_to_head(montage) _check_roundtrip(montage, fname_temp) - montage.dig[0]['coord_frame'] = FIFF.FIFFV_COORD_UNKNOWN - with pytest.raises(RuntimeError, match='Only a single coordinate'): + montage.dig[0]["coord_frame"] = FIFF.FIFFV_COORD_UNKNOWN + with pytest.raises(RuntimeError, match="Only a single coordinate"): montage.save(fname_temp) @@ -963,24 +1069,25 @@ def test_egi_dig_montage(tmp_path): assert coord == FIFF.FIFFV_COORD_UNKNOWN assert_allclose( - actual=np.array([fid[key] for key in ['nasion', 'lpa', 'rpa']]), - desired=[[ 0. , 10.564, -2.051], # noqa - [-8.592, 0.498, -4.128], # noqa - [ 8.592, 0.498, -4.128]], # noqa + actual=np.array([fid[key] for key in ["nasion", "lpa", "rpa"]]), + desired=[ + [0.0, 10.564, -2.051], # noqa + [-8.592, 0.498, -4.128], # noqa + [8.592, 0.498, -4.128], + ], # noqa ) # Test accuracy and embedding within raw object - raw_egi = read_raw_egi(egi_raw_fname, channel_naming='EEG %03d') + raw_egi = read_raw_egi(egi_raw_fname, channel_naming="EEG %03d") raw_egi.set_montage(dig_montage) test_raw_egi = read_raw_fif(egi_fif_fname) assert_equal(len(raw_egi.ch_names), len(test_raw_egi.ch_names)) - for ch_raw, ch_test_raw in zip(raw_egi.info['chs'], - test_raw_egi.info['chs']): - assert_equal(ch_raw['ch_name'], ch_test_raw['ch_name']) - assert_equal(ch_raw['coord_frame'], FIFF.FIFFV_COORD_HEAD) - assert_allclose(ch_raw['loc'], ch_test_raw['loc'], atol=1e-7) + for ch_raw, ch_test_raw in zip(raw_egi.info["chs"], test_raw_egi.info["chs"]): + assert_equal(ch_raw["ch_name"], ch_test_raw["ch_name"]) + assert_equal(ch_raw["coord_frame"], FIFF.FIFFV_COORD_HEAD) + assert_allclose(ch_raw["loc"], ch_test_raw["loc"], atol=1e-7) assert_dig_allclose(raw_egi.info, test_raw_egi.info) @@ -988,14 +1095,14 @@ def test_egi_dig_montage(tmp_path): fid, coord = _get_fid_coords(dig_montage_in_head.dig) assert coord == FIFF.FIFFV_COORD_HEAD assert_allclose( - actual=np.array([fid[key] for key in ['nasion', 'lpa', 'rpa']]), - desired=[[0., 10.278, 0.], [-8.592, 0., 0.], [8.592, 0., 0.]], + actual=np.array([fid[key] for key in ["nasion", "lpa", "rpa"]]), + desired=[[0.0, 10.278, 0.0], [-8.592, 0.0, 0.0], [8.592, 0.0, 0.0]], atol=1e-4, ) # test round-trip IO - fname_temp = tmp_path / 'egi_test.fif' - _check_roundtrip(dig_montage, fname_temp, 'unknown') + fname_temp = tmp_path / "egi_test.fif" + _check_roundtrip(dig_montage, fname_temp, "unknown") _check_roundtrip(dig_montage_in_head, fname_temp) @@ -1007,44 +1114,158 @@ def _pop_montage(dig_montage, ch_name): del dig_montage.dig[dig_idx] del dig_montage.ch_names[name_idx] for k in range(dig_idx, len(dig_montage.dig)): - dig_montage.dig[k]['ident'] -= 1 + dig_montage.dig[k]["ident"] -= 1 @testing.requires_testing_data def test_read_dig_captrak(tmp_path): """Test reading a captrak montage file.""" EXPECTED_CH_NAMES_OLD = [ - 'AF3', 'AF4', 'AF7', 'AF8', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', - 'CP2', 'CP3', 'CP4', 'CP5', 'CP6', 'CPz', 'Cz', 'F1', 'F2', 'F3', 'F4', - 'F5', 'F6', 'F7', 'F8', 'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6', - 'FT10', 'FT7', 'FT8', 'FT9', 'Fp1', 'Fp2', 'Fz', 'GND', 'O1', 'O2', - 'Oz', 'P1', 'P2', 'P3', 
'P4', 'P5', 'P6', 'P7', 'P8', 'PO10', 'PO3', - 'PO4', 'PO7', 'PO8', 'PO9', 'POz', 'Pz', 'REF', 'T7', 'T8', 'TP10', - 'TP7', 'TP8', 'TP9' + "AF3", + "AF4", + "AF7", + "AF8", + "C1", + "C2", + "C3", + "C4", + "C5", + "C6", + "CP1", + "CP2", + "CP3", + "CP4", + "CP5", + "CP6", + "CPz", + "Cz", + "F1", + "F2", + "F3", + "F4", + "F5", + "F6", + "F7", + "F8", + "FC1", + "FC2", + "FC3", + "FC4", + "FC5", + "FC6", + "FT10", + "FT7", + "FT8", + "FT9", + "Fp1", + "Fp2", + "Fz", + "GND", + "O1", + "O2", + "Oz", + "P1", + "P2", + "P3", + "P4", + "P5", + "P6", + "P7", + "P8", + "PO10", + "PO3", + "PO4", + "PO7", + "PO8", + "PO9", + "POz", + "Pz", + "REF", + "T7", + "T8", + "TP10", + "TP7", + "TP8", + "TP9", ] EXPECTED_CH_NAMES = [ - 'T7', 'FC5', 'F7', 'C5', 'FT7', 'FT9', 'TP7', 'TP9', 'P7', 'CP5', - 'PO7', 'C3', 'CP3', 'P5', 'P3', 'PO3', 'PO9', 'O1', 'Oz', 'POz', 'O2', - 'PO4', 'P1', 'Pz', 'P2', 'CP2', 'CP1', 'CPz', 'Cz', 'C1', 'FC1', 'FC3', - 'REF', 'F3', 'F1', 'Fz', 'F5', 'AF7', 'AF3', 'Fp1', 'GND', 'F2', 'AF4', - 'Fp2', 'F4', 'F8', 'F6', 'AF8', 'FC2', 'FC6', 'FC4', 'C2', 'C4', 'P4', - 'CP4', 'PO8', 'P8', 'P6', 'CP6', 'PO10', 'TP10', 'TP8', 'FT10', 'T8', - 'C6', 'FT8' + "T7", + "FC5", + "F7", + "C5", + "FT7", + "FT9", + "TP7", + "TP9", + "P7", + "CP5", + "PO7", + "C3", + "CP3", + "P5", + "P3", + "PO3", + "PO9", + "O1", + "Oz", + "POz", + "O2", + "PO4", + "P1", + "Pz", + "P2", + "CP2", + "CP1", + "CPz", + "Cz", + "C1", + "FC1", + "FC3", + "REF", + "F3", + "F1", + "Fz", + "F5", + "AF7", + "AF3", + "Fp1", + "GND", + "F2", + "AF4", + "Fp2", + "F4", + "F8", + "F6", + "AF8", + "FC2", + "FC6", + "FC4", + "C2", + "C4", + "P4", + "CP4", + "PO8", + "P8", + "P6", + "CP6", + "PO10", + "TP10", + "TP8", + "FT10", + "T8", + "C6", + "FT8", ] assert set(EXPECTED_CH_NAMES) == set(EXPECTED_CH_NAMES_OLD) - montage = read_dig_captrak( - fname=data_path / "montage" / "captrak_coords.bvct" - ) + montage = read_dig_captrak(fname=data_path / "montage" / "captrak_coords.bvct") assert montage.ch_names == EXPECTED_CH_NAMES assert repr(montage) == ( - '' + "" ) montage = transform_to_head(montage) # transform_to_head has to be tested - _check_roundtrip(montage=montage, - fname=str(tmp_path / 'bvct_test.fif')) + _check_roundtrip(montage=montage, fname=str(tmp_path / "bvct_test.fif")) fid, _ = _get_fid_coords(montage.dig) assert_allclose( @@ -1054,64 +1275,65 @@ def test_read_dig_captrak(tmp_path): ) raw_bv = read_raw_brainvision(bv_raw_fname) - raw_bv.set_channel_types({"HEOG": 'eog', "VEOG": 'eog', "ECG": 'ecg'}) + raw_bv.set_channel_types({"HEOG": "eog", "VEOG": "eog", "ECG": "ecg"}) raw_bv.set_montage(montage) test_raw_bv = read_raw_fif(bv_fif_fname) # compare after set_montage using chs loc. 
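# (For reference before the element-wise comparison below: the user-facing
# CapTrak flow being validated, sketched with hypothetical file names;
# transform_to_head is defined in mne.channels.montage.)
from mne.channels import read_dig_captrak
from mne.channels.montage import transform_to_head
from mne.io import read_raw_brainvision

cap = read_dig_captrak("captrak_coords.bvct")
cap = transform_to_head(cap)  # CapTrak frame -> Neuromag head frame
raw_sketch = read_raw_brainvision("recording.vhdr")
raw_sketch.set_channel_types({"HEOG": "eog", "VEOG": "eog", "ECG": "ecg"})
raw_sketch.set_montage(cap)  # positions land in raw.info["chs"][ii]["loc"]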
- for actual, expected in zip(raw_bv.info['chs'], test_raw_bv.info['chs']): - assert_allclose(actual['loc'][:3], expected['loc'][:3]) - if actual['kind'] == FIFF.FIFFV_EEG_CH: - assert_allclose(actual['loc'][3:6], - [-0.005103, 0.05395, 0.144622], rtol=1e-04) + for actual, expected in zip(raw_bv.info["chs"], test_raw_bv.info["chs"]): + assert_allclose(actual["loc"][:3], expected["loc"][:3]) + if actual["kind"] == FIFF.FIFFV_EEG_CH: + assert_allclose( + actual["loc"][3:6], [-0.005103, 0.05395, 0.144622], rtol=1e-04 + ) # https://gist.github.com/larsoner/2264fb5895070d29a8c9aa7c0dc0e8a6 _MGH60 = ( - 'Fp1 Fpz Fp2 ' - 'AF7 AF3 AF4 AF8 ' - 'F7 F5 F3 F1 Fz F2 F4 F6 F8 ' - 'FT9 FT7 FC5 FC1 FC2 FC6 FT8 FT10 ' - 'T9 T7 C5 C3 C1 Cz C2 C4 C6 T8 T10 ' - 'TP9 TP7 CP3 CP1 CP2 CP4 TP8 TP10 ' - 'P7 P5 P3 P1 Pz P2 P4 P6 P8 ' - 'PO7 PO3 PO4 PO8 ' - 'O1 Oz O2 ' - 'Iz' + "Fp1 Fpz Fp2 " + "AF7 AF3 AF4 AF8 " + "F7 F5 F3 F1 Fz F2 F4 F6 F8 " + "FT9 FT7 FC5 FC1 FC2 FC6 FT8 FT10 " + "T9 T7 C5 C3 C1 Cz C2 C4 C6 T8 T10 " + "TP9 TP7 CP3 CP1 CP2 CP4 TP8 TP10 " + "P7 P5 P3 P1 Pz P2 P4 P6 P8 " + "PO7 PO3 PO4 PO8 " + "O1 Oz O2 " + "Iz" ).split() -@pytest.mark.parametrize('rename', ('raw', 'montage', 'custom')) +@pytest.mark.parametrize("rename", ("raw", "montage", "custom")) def test_set_montage_mgh(rename): """Test setting 'mgh60' montage to old fif.""" raw = read_raw_fif(fif_fname) eeg_picks = pick_types(raw.info, meg=False, eeg=True, exclude=()) - assert list(eeg_picks) == [ii for ii, name in enumerate(raw.ch_names) - if name.startswith('EEG')] - orig_pos = np.array([raw.info['chs'][pick]['loc'][:3] - for pick in eeg_picks]) + assert list(eeg_picks) == [ + ii for ii, name in enumerate(raw.ch_names) if name.startswith("EEG") + ] + orig_pos = np.array([raw.info["chs"][pick]["loc"][:3] for pick in eeg_picks]) atol = 1e-6 mon = None - if rename == 'raw': - raw.rename_channels(lambda x: x.replace('EEG ', 'EEG')) - raw.set_montage('mgh60') # test loading with string argument - elif rename == 'montage': - mon = make_standard_montage('mgh60') - mon.rename_channels(lambda x: x.replace('EEG', 'EEG ')) + if rename == "raw": + raw.rename_channels(lambda x: x.replace("EEG ", "EEG")) + raw.set_montage("mgh60") # test loading with string argument + elif rename == "montage": + mon = make_standard_montage("mgh60") + mon.rename_channels(lambda x: x.replace("EEG", "EEG ")) assert [raw.ch_names[pick] for pick in eeg_picks] == mon.ch_names raw.set_montage(mon) else: atol = 3e-3 # different subsets of channel locations - assert rename == 'custom' + assert rename == "custom" assert len(_MGH60) == 60 - mon = make_standard_montage('standard_1020') + mon = make_standard_montage("standard_1020") assert len(mon._get_ch_pos()) == 94 def renamer(x): try: - return 'EEG %03d' % (_MGH60.index(x) + 1,) + return "EEG %03d" % (_MGH60.index(x) + 1,) except ValueError: return x @@ -1122,47 +1344,56 @@ def renamer(x): # first two are 'Fp1' and 'Fz', take them from standard_1020.elc -- # they should not be changed on load! 
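    # (that is, make_standard_montage carries the .elc coordinates over
    # unchanged -- only scaled from mm to m, hence the ``* 1000`` below --
    # and leaves them in the MRI frame, as the coord_frame assertion checks)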
want_pos = [[-29.4367, 83.9171, -6.9900], [0.1123, 88.2470, -1.7130]] - got_pos = [mon.get_positions()['ch_pos'][f'EEG {x:03d}'] * 1000 - for x in range(1, 3)] + got_pos = [ + mon.get_positions()["ch_pos"][f"EEG {x:03d}"] * 1000 for x in range(1, 3) + ] assert_allclose(want_pos, got_pos) - assert mon.dig[0]['coord_frame'] == FIFF.FIFFV_COORD_MRI + assert mon.dig[0]["coord_frame"] == FIFF.FIFFV_COORD_MRI trans = compute_native_head_t(mon) - trans_2 = _get_trans('fsaverage', 'mri', 'head')[0] - assert trans['to'] == trans_2['to'] - assert trans['from'] == trans_2['from'] - assert_allclose(trans['trans'], trans_2['trans'], atol=1e-6) + trans_2 = _get_trans("fsaverage", "mri", "head")[0] + assert trans["to"] == trans_2["to"] + assert trans["from"] == trans_2["from"] + assert_allclose(trans["trans"], trans_2["trans"], atol=1e-6) - new_pos = np.array([ch['loc'][:3] for ch in raw.info['chs'] - if ch['ch_name'].startswith('EEG')]) - assert ((orig_pos != new_pos).all()) + new_pos = np.array( + [ch["loc"][:3] for ch in raw.info["chs"] if ch["ch_name"].startswith("EEG")] + ) + assert (orig_pos != new_pos).all() r0 = _fit_sphere(new_pos)[1] assert_allclose(r0, [-0.001021, 0.014554, 0.041404], atol=1e-4) # spot check: Fp1 and Fpz - assert_allclose(new_pos[:2], [[-0.030903, 0.114585, 0.027867], - [-0.001337, 0.119102, 0.03289]], atol=atol) + assert_allclose( + new_pos[:2], + [[-0.030903, 0.114585, 0.027867], [-0.001337, 0.119102, 0.03289]], + atol=atol, + ) -@pytest.mark.parametrize('fname, montage, n_eeg, n_good, bads', [ - (fif_fname, 'mgh60', 60, 59, ['EEG 053']), - pytest.param(mgh70_fname, 'mgh70', 70, 64, None, - marks=[testing._pytest_mark()]), -]) +@pytest.mark.parametrize( + "fname, montage, n_eeg, n_good, bads", + [ + (fif_fname, "mgh60", 60, 59, ["EEG 053"]), + pytest.param( + mgh70_fname, "mgh70", 70, 64, None, marks=[testing._pytest_mark()] + ), + ], +) def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads): """Test that montages give spatially similar positions.""" # 1. Prepare data: load, set bads (if missing), and filter raw = read_raw_fif(fname).pick_types(eeg=True, exclude=()) if bads is not None: - assert raw.info['bads'] == [] - raw.info['bads'] = bads + assert raw.info["bads"] == [] + raw.info["bads"] = bads assert len(raw.ch_names) == n_eeg - raw.pick_types(eeg=True, exclude='bads').load_data() + raw.pick_types(eeg=True, exclude="bads").load_data() raw.apply_function(lambda x: x - x.mean()) # remove DC raw.filter(None, 40) # remove line noise assert len(raw.ch_names) == n_good - if montage == 'mgh60': + if montage == "mgh60": montage = make_standard_montage(montage) - montage.rename_channels(lambda n: f'EEG {n[-3:]}') + montage.rename_channels(lambda n: f"EEG {n[-3:]}") raw_mon = raw.copy().set_montage(montage) # 2. First test: CSDs should be similar (CSD uses 3D positions) csd = compute_current_source_density(raw).get_data() @@ -1174,8 +1405,8 @@ def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads): bads = [raw.ch_names[idx] for idx in bad_picks] orig_data = raw.get_data(bad_picks) assert_allclose(orig_data, raw_mon.get_data(bad_picks)) - raw.info['bads'] = bads - raw_mon.info['bads'] = bads + raw.info["bads"] = bads + raw_mon.info["bads"] = bads raw.interpolate_bads() raw_mon.interpolate_bads() orig_data = orig_data.ravel() @@ -1185,21 +1416,22 @@ def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads): assert 0.95 < corr < 0.99, corr # 4. 
Third test: project each to a sphere, check cosine angles are small poss = dict() - for kind, this_raw in (('orig', raw), ('mon', raw_mon)): + for kind, this_raw in (("orig", raw), ("mon", raw_mon)): pos = np.array( - list(this_raw.get_montage().get_positions()['ch_pos'].values()), - float) + list(this_raw.get_montage().get_positions()["ch_pos"].values()), float + ) pos -= np.mean(pos, axis=0) pos /= np.linalg.norm(pos, axis=1, keepdims=True) poss[kind] = pos ang = np.rad2deg( # arccos is in [0, pi] - np.arccos(np.minimum(np.sum(poss['orig'] * poss['mon'], axis=1), 1))) + np.arccos(np.minimum(np.sum(poss["orig"] * poss["mon"], axis=1), 1)) + ) assert_array_less(ang, 20) # less than 20 deg assert_array_less(0, ang) # but not equal # XXX: this does not check ch_names + it cannot work because of write_dig -def _check_roundtrip(montage, fname, coord_frame='head'): +def _check_roundtrip(montage, fname, coord_frame="head"): """Check roundtrip writing.""" montage.save(fname, overwrite=True) montage_read = read_dig_fif(fname=fname) @@ -1211,68 +1443,74 @@ def _check_roundtrip(montage, fname, coord_frame='head'): def _fake_montage(ch_names): pos = np.random.RandomState(42).randn(len(ch_names), 3) - return make_dig_montage(ch_pos=dict(zip(ch_names, pos)), - coord_frame='head') + return make_dig_montage(ch_pos=dict(zip(ch_names, pos)), coord_frame="head") cnt_ignore_warns = [ pytest.mark.filterwarnings( - 'ignore:.*Could not parse meas date from the header. Setting to None.' + "ignore:.*Could not parse meas date from the header. Setting to None." ), - pytest.mark.filterwarnings(( - 'ignore:.*Could not define the number of bytes automatically.' - ' Defaulting to 2.') + pytest.mark.filterwarnings( + ( + "ignore:.*Could not define the number of bytes automatically." + " Defaulting to 2." 
+ ) ), ] def test_digmontage_constructor_errors(): """Test proper error messaging.""" - with pytest.raises(ValueError, match='does not match the number'): - _ = DigMontage(ch_names=['foo', 'bar'], dig=list()) + with pytest.raises(ValueError, match="does not match the number"): + _ = DigMontage(ch_names=["foo", "bar"], dig=list()) def test_transform_to_head_and_compute_dev_head_t(): """Test transform_to_head and compute_dev_head_t.""" - EXPECTED_DEV_HEAD_T = \ - [[-3.72201691e-02, -9.98212167e-01, -4.67667497e-02, -7.31583414e-04], - [8.98064989e-01, -5.39382685e-02, 4.36543170e-01, 1.60134431e-02], - [-4.38285221e-01, -2.57513699e-02, 8.98466990e-01, 6.13035748e-02], - [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]] + EXPECTED_DEV_HEAD_T = [ + [-3.72201691e-02, -9.98212167e-01, -4.67667497e-02, -7.31583414e-04], + [8.98064989e-01, -5.39382685e-02, 4.36543170e-01, 1.60134431e-02], + [-4.38285221e-01, -2.57513699e-02, 8.98466990e-01, 6.13035748e-02], + [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00], + ] EXPECTED_FID_IN_POLHEMUS = { - 'nasion': np.array([0.001393, 0.0131613, -0.0046967]), - 'lpa': np.array([-0.0624997, -0.0737271, 0.07996]), - 'rpa': np.array([-0.0748957, 0.0873785, 0.0811943]), + "nasion": np.array([0.001393, 0.0131613, -0.0046967]), + "lpa": np.array([-0.0624997, -0.0737271, 0.07996]), + "rpa": np.array([-0.0748957, 0.0873785, 0.0811943]), } EXPECTED_FID_IN_HEAD = { - 'nasion': np.array([-8.94466792e-18, 1.10559624e-01, -3.85185989e-34]), - 'lpa': np.array([-8.10816716e-02, 6.56321671e-18, 0]), - 'rpa': np.array([8.05048781e-02, -6.47441364e-18, 0]), + "nasion": np.array([-8.94466792e-18, 1.10559624e-01, -3.85185989e-34]), + "lpa": np.array([-8.10816716e-02, 6.56321671e-18, 0]), + "rpa": np.array([8.05048781e-02, -6.47441364e-18, 0]), } hpi_dev = np.array( - [[ 2.13951493e-02, 8.47444056e-02, -5.65431188e-02], # noqa - [ 2.10299433e-02, -8.03141101e-02, -6.34420259e-02], # noqa - [ 1.05916829e-01, 8.18485672e-05, 1.19928083e-02], # noqa - [ 9.26595105e-02, 4.64804385e-02, 8.45141253e-03], # noqa - [ 9.42554419e-02, -4.35206589e-02, 8.78999363e-03]] # noqa + [ + [2.13951493e-02, 8.47444056e-02, -5.65431188e-02], # noqa + [2.10299433e-02, -8.03141101e-02, -6.34420259e-02], # noqa + [1.05916829e-01, 8.18485672e-05, 1.19928083e-02], # noqa + [9.26595105e-02, 4.64804385e-02, 8.45141253e-03], # noqa + [9.42554419e-02, -4.35206589e-02, 8.78999363e-03], + ] # noqa ) hpi_polhemus = np.array( - [[-0.0595004, -0.0704836, 0.075893 ], # noqa - [-0.0646373, 0.0838228, 0.0762123], # noqa - [-0.0135035, 0.0072522, -0.0268405], # noqa - [-0.0202967, -0.0351498, -0.0129305], # noqa - [-0.0277519, 0.0452628, -0.0222407]] # noqa + [ + [-0.0595004, -0.0704836, 0.075893], # noqa + [-0.0646373, 0.0838228, 0.0762123], # noqa + [-0.0135035, 0.0072522, -0.0268405], # noqa + [-0.0202967, -0.0351498, -0.0129305], # noqa + [-0.0277519, 0.0452628, -0.0222407], + ] # noqa ) montage_polhemus = make_dig_montage( - **EXPECTED_FID_IN_POLHEMUS, hpi=hpi_polhemus, coord_frame='unknown' + **EXPECTED_FID_IN_POLHEMUS, hpi=hpi_polhemus, coord_frame="unknown" ) - montage_meg = make_dig_montage(hpi=hpi_dev, coord_frame='meg') + montage_meg = make_dig_montage(hpi=hpi_dev, coord_frame="meg") # Test regular workflow to get dev_head_t montage = montage_polhemus + montage_meg @@ -1280,7 +1518,7 @@ def test_transform_to_head_and_compute_dev_head_t(): for kk in fids: assert_allclose(fids[kk], EXPECTED_FID_IN_POLHEMUS[kk], atol=1e-5) - with pytest.raises(ValueError, match='set to head 
coordinate system'): + with pytest.raises(ValueError, match="set to head coordinate system"): _ = compute_dev_head_t(montage) montage = transform_to_head(montage) @@ -1290,39 +1528,43 @@ def test_transform_to_head_and_compute_dev_head_t(): assert_allclose(fids[kk], EXPECTED_FID_IN_HEAD[kk], atol=1e-5) dev_head_t = compute_dev_head_t(montage) - assert_allclose(dev_head_t['trans'], EXPECTED_DEV_HEAD_T, atol=5e-7) + assert_allclose(dev_head_t["trans"], EXPECTED_DEV_HEAD_T, atol=5e-7) # Test errors when number of HPI points do not match - EXPECTED_ERR_MSG = 'Device-to-Head .*Got 0 .*device and 5 points in head' + EXPECTED_ERR_MSG = "Device-to-Head .*Got 0 .*device and 5 points in head" with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): _ = compute_dev_head_t(transform_to_head(montage_polhemus)) - EXPECTED_ERR_MSG = 'Device-to-Head .*Got 5 .*device and 0 points in head' + EXPECTED_ERR_MSG = "Device-to-Head .*Got 5 .*device and 0 points in head" with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): - _ = compute_dev_head_t(transform_to_head( - montage_meg + make_dig_montage(**EXPECTED_FID_IN_POLHEMUS) - )) + _ = compute_dev_head_t( + transform_to_head( + montage_meg + make_dig_montage(**EXPECTED_FID_IN_POLHEMUS) + ) + ) - EXPECTED_ERR_MSG = 'Device-to-Head .*Got 3 .*device and 5 points in head' + EXPECTED_ERR_MSG = "Device-to-Head .*Got 3 .*device and 5 points in head" with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): - _ = compute_dev_head_t(transform_to_head( - DigMontage(dig=_format_dig_points(montage_meg.dig[:3])) + - montage_polhemus - )) + _ = compute_dev_head_t( + transform_to_head( + DigMontage(dig=_format_dig_points(montage_meg.dig[:3])) + + montage_polhemus + ) + ) def test_set_montage_with_mismatching_ch_names(): """Test setting a DigMontage with mismatching ch_names.""" raw = read_raw_fif(fif_fname) - montage = make_standard_montage('mgh60') + montage = make_standard_montage("mgh60") # 'EEG 001' and 'EEG001' won't match - missing_err = '60 channel positions not present' + missing_err = "60 channel positions not present" with pytest.raises(ValueError, match=missing_err): raw.set_montage(montage) montage.ch_names = [ # modify the names in place - name.replace('EEG', 'EEG ') for name in montage.ch_names + name.replace("EEG", "EEG ") for name in montage.ch_names ] raw.set_montage(montage) # does not raise @@ -1333,90 +1575,97 @@ def test_set_montage_with_mismatching_ch_names(): # should work raw.set_montage(montage, match_case=False) raw.rename_channels(lambda x: x.upper()) # restore - assert 'EEG 001' in raw.ch_names and 'eeg 001' not in raw.ch_names - raw.rename_channels({'EEG 002': 'eeg 001'}) - assert 'EEG 001' in raw.ch_names and 'eeg 001' in raw.ch_names - with pytest.warns(RuntimeWarning, match='changed from V to NA'): - raw.set_channel_types({'eeg 001': 'misc'}) + assert "EEG 001" in raw.ch_names and "eeg 001" not in raw.ch_names + raw.rename_channels({"EEG 002": "eeg 001"}) + assert "EEG 001" in raw.ch_names and "eeg 001" in raw.ch_names + with pytest.warns(RuntimeWarning, match="changed from V to NA"): + raw.set_channel_types({"eeg 001": "misc"}) raw.set_montage(montage) - with pytest.warns(RuntimeWarning, match='changed from NA to V'): - raw.set_channel_types({'eeg 001': 'eeg'}) - with pytest.raises(ValueError, match='1 channel position not present'): + with pytest.warns(RuntimeWarning, match="changed from NA to V"): + raw.set_channel_types({"eeg 001": "eeg"}) + with pytest.raises(ValueError, match="1 channel position not present"): raw.set_montage(montage) - with 
pytest.raises(ValueError, match='match_case=False as 1 channel name'): + with pytest.raises(ValueError, match="match_case=False as 1 channel name"): raw.set_montage(montage, match_case=False) - info = create_info(['EEG 001'], 1000., 'eeg') - mon = make_dig_montage({'EEG 001': np.zeros(3), 'eeg 001': np.zeros(3)}, - nasion=[0, 1., 0], rpa=[1., 0, 0], lpa=[-1., 0, 0]) + info = create_info(["EEG 001"], 1000.0, "eeg") + mon = make_dig_montage( + {"EEG 001": np.zeros(3), "eeg 001": np.zeros(3)}, + nasion=[0, 1.0, 0], + rpa=[1.0, 0, 0], + lpa=[-1.0, 0, 0], + ) info.set_montage(mon) - with pytest.raises(ValueError, match='match_case=False as 1 montage name'): + with pytest.raises(ValueError, match="match_case=False as 1 montage name"): info.set_montage(mon, match_case=False) def test_set_montage_with_sub_super_set_of_ch_names(): """Test info and montage ch_names matching criteria.""" - N_CHANNELS = len('abcdef') - montage = _make_toy_dig_montage(N_CHANNELS, coord_frame='head') + N_CHANNELS = len("abcdef") + montage = _make_toy_dig_montage(N_CHANNELS, coord_frame="head") # montage and info match - info = create_info(ch_names=list('abcdef'), sfreq=1, ch_types='eeg') + info = create_info(ch_names=list("abcdef"), sfreq=1, ch_types="eeg") info.set_montage(montage) # montage is a SUPERset of info - info = create_info(list('abc'), sfreq=1, ch_types='eeg') + info = create_info(list("abc"), sfreq=1, ch_types="eeg") info.set_montage(montage) - assert len(info['dig']) == len(list('abc')) + 3 # 3 fiducials + assert len(info["dig"]) == len(list("abc")) + 3 # 3 fiducials # montage is a SUBset of info - _MSG = 'subset of info. There are 2 .* not present in the DigMontage' - info = create_info(ch_names=list('abcdfgh'), sfreq=1, ch_types='eeg') + _MSG = "subset of info. There are 2 .* not present in the DigMontage" + info = create_info(ch_names=list("abcdfgh"), sfreq=1, ch_types="eeg") with pytest.raises(ValueError, match=_MSG) as exc: info.set_montage(montage) # plus suggestions - assert exc.match('set_channel_types') - assert exc.match('on_missing') + assert exc.match("set_channel_types") + assert exc.match("on_missing") def test_set_montage_with_known_aliases(): """Test matching unrecognized channel locations to known aliases.""" # montage and info match - mock_montage_ch_names = ['POO7', 'POO8'] + mock_montage_ch_names = ["POO7", "POO8"] n_channels = len(mock_montage_ch_names) - montage = make_dig_montage(ch_pos=dict( - zip( - mock_montage_ch_names, - np.arange(n_channels * 3).reshape(n_channels, 3), - )), - coord_frame='head') + montage = make_dig_montage( + ch_pos=dict( + zip( + mock_montage_ch_names, + np.arange(n_channels * 3).reshape(n_channels, 3), + ) + ), + coord_frame="head", + ) - mock_info_ch_names = ['Cb1', 'Cb2'] - info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types='eeg') + mock_info_ch_names = ["Cb1", "Cb2"] + info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types="eeg") info.set_montage(montage, match_alias=True) # work with match_case - mock_info_ch_names = ['cb1', 'cb2'] - info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types='eeg') + mock_info_ch_names = ["cb1", "cb2"] + info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types="eeg") info.set_montage(montage, match_case=False, match_alias=True) # should warn user T1 instead of its alias T9 - mock_info_ch_names = ['Cb1', 'T1'] - info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types='eeg') - with pytest.raises(ValueError, match='T1'): + mock_info_ch_names = ["Cb1", "T1"] + info = 
create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types="eeg") + with pytest.raises(ValueError, match="T1"): info.set_montage(montage, match_case=False, match_alias=True) def test_heterogeneous_ch_type(): """Test ch_names matching criteria with heterogeneous ch_type.""" - VALID_MONTAGE_NAMED_CHS = ('eeg', 'ecog', 'seeg', 'dbs') + VALID_MONTAGE_NAMED_CHS = ("eeg", "ecog", "seeg", "dbs") montage = _make_toy_dig_montage( n_channels=len(VALID_MONTAGE_NAMED_CHS), - coord_frame='head', + coord_frame="head", ) # Montage and info match - info = create_info(montage.ch_names, 1., list(VALID_MONTAGE_NAMED_CHS)) + info = create_info(montage.ch_names, 1.0, list(VALID_MONTAGE_NAMED_CHS)) RawArray(np.zeros((4, 1)), info, copy=None).set_montage(montage) @@ -1425,45 +1674,46 @@ def test_set_montage_coord_frame_in_head_vs_unknown(): N_CHANNELS, NaN = 3, np.nan raw = _make_toy_raw(N_CHANNELS) - montage_in_head = _make_toy_dig_montage(N_CHANNELS, coord_frame='head') - montage_in_unknown = _make_toy_dig_montage( - N_CHANNELS, coord_frame='unknown' - ) + montage_in_head = _make_toy_dig_montage(N_CHANNELS, coord_frame="head") + montage_in_unknown = _make_toy_dig_montage(N_CHANNELS, coord_frame="unknown") montage_in_unknown_with_fid = _make_toy_dig_montage( - N_CHANNELS, coord_frame='unknown', - nasion=[0, 1, 0], lpa=[1, 0, 0], rpa=[-1, 0, 0], + N_CHANNELS, + coord_frame="unknown", + nasion=[0, 1, 0], + lpa=[1, 0, 0], + rpa=[-1, 0, 0], ) assert_allclose( - actual=np.array([ch['loc'] for ch in raw.info['chs']]), - desired=np.full((N_CHANNELS, 12), np.nan) + actual=np.array([ch["loc"] for ch in raw.info["chs"]]), + desired=np.full((N_CHANNELS, 12), np.nan), ) raw.set_montage(montage_in_head) assert_allclose( - actual=np.array([ch['loc'] for ch in raw.info['chs']]), + actual=np.array([ch["loc"] for ch in raw.info["chs"]]), desired=[ - [0., 1., 2., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [3., 4., 5., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [6., 7., 8., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - ] + [0.0, 1.0, 2.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [3.0, 4.0, 5.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [6.0, 7.0, 8.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + ], ) - with pytest.warns(RuntimeWarning, match='assuming identity'): + with pytest.warns(RuntimeWarning, match="assuming identity"): raw.set_montage(montage_in_unknown) raw.set_montage(montage_in_unknown_with_fid) assert_allclose( - actual=np.array([ch['loc'] for ch in raw.info['chs']]), + actual=np.array([ch["loc"] for ch in raw.info["chs"]]), desired=[ - [-0., 1., -2., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [-3., 4., -5., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [-6., 7., -8., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - ] + [-0.0, 1.0, -2.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [-3.0, 4.0, -5.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [-6.0, 7.0, -8.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + ], ) # check no collateral effects from transforming montage - assert _check_get_coord_frame(montage_in_unknown_with_fid.dig) == 'unknown' + assert _check_get_coord_frame(montage_in_unknown_with_fid.dig) == "unknown" assert_array_equal( _get_dig_montage_pos(montage_in_unknown_with_fid), [[0, 1, 2], [3, 4, 5], [6, 7, 8]], @@ -1471,41 +1721,42 @@ def test_set_montage_coord_frame_in_head_vs_unknown(): @testing.requires_testing_data -@pytest.mark.parametrize('ch_type', ('eeg', 'ecog', 'seeg', 'dbs')) +@pytest.mark.parametrize("ch_type", ("eeg", "ecog", "seeg", "dbs")) def 
test_montage_head_frame(ch_type): """Test that head frame is set properly.""" # gh-9446 data = np.random.randn(2, 100) - info = create_info(['a', 'b'], 512, ch_type) - for ch in info['chs']: - assert ch['coord_frame'] == FIFF.FIFFV_COORD_HEAD + info = create_info(["a", "b"], 512, ch_type) + for ch in info["chs"]: + assert ch["coord_frame"] == FIFF.FIFFV_COORD_HEAD raw = RawArray(data, info) - ch_pos = dict(a=[-0.00250136, 0.04913788, 0.05047056], - b=[-0.00528394, 0.05066484, 0.05061559]) - lpa, nasion, rpa = get_mni_fiducials( - 'fsaverage', subjects_dir=subjects_dir) - lpa, nasion, rpa = lpa['r'], nasion['r'], rpa['r'] + ch_pos = dict( + a=[-0.00250136, 0.04913788, 0.05047056], b=[-0.00528394, 0.05066484, 0.05061559] + ) + lpa, nasion, rpa = get_mni_fiducials("fsaverage", subjects_dir=subjects_dir) + lpa, nasion, rpa = lpa["r"], nasion["r"], rpa["r"] montage = make_dig_montage( - ch_pos, coord_frame='mri', nasion=nasion, lpa=lpa, rpa=rpa) + ch_pos, coord_frame="mri", nasion=nasion, lpa=lpa, rpa=rpa + ) mri_head_t = compute_native_head_t(montage) raw.set_montage(montage) pos = apply_trans(mri_head_t, np.array(list(ch_pos.values()))) - for p, ch in zip(pos, raw.info['chs']): - assert ch['coord_frame'] == FIFF.FIFFV_COORD_HEAD - assert_allclose(p, ch['loc'][:3]) + for p, ch in zip(pos, raw.info["chs"]): + assert ch["coord_frame"] == FIFF.FIFFV_COORD_HEAD + assert_allclose(p, ch["loc"][:3]) # Also test that including channels in the montage that will not have their # positions set will emit a warning - with pytest.warns(RuntimeWarning, match='changed from V to NA'): - raw.set_channel_types(dict(a='misc')) - with pytest.warns(RuntimeWarning, match='Not setting .*of 1 misc channel'): + with pytest.warns(RuntimeWarning, match="changed from V to NA"): + raw.set_channel_types(dict(a="misc")) + with pytest.warns(RuntimeWarning, match="Not setting .*of 1 misc channel"): raw.set_montage(montage) # and with a bunch of bad types raw = read_raw_fif(fif_fname) ch_pos = {ch_name: np.zeros(3) for ch_name in raw.ch_names} - mon = make_dig_montage(ch_pos, coord_frame='head') - with pytest.warns(RuntimeWarning, match='316 eog/grad/mag/stim channels'): + mon = make_dig_montage(ch_pos, coord_frame="head") + with pytest.warns(RuntimeWarning, match="316 eog/grad/mag/stim channels"): raw.set_montage(mon) @@ -1514,39 +1765,44 @@ def test_set_montage_with_missing_coordinates(): N_CHANNELS, NaN = 3, np.nan raw = _make_toy_raw(N_CHANNELS) - raw.set_channel_types({ch: 'ecog' for ch in raw.ch_names}) + raw.set_channel_types({ch: "ecog" for ch in raw.ch_names}) # don't include all the channels ch_names = raw.ch_names[1:] n_channels = len(ch_names) ch_coords = np.arange(n_channels * 3).reshape(n_channels, 3) montage_in_mri = make_dig_montage( - ch_pos=dict(zip(ch_names, ch_coords,)), - coord_frame='unknown', - nasion=[0, 1, 0], lpa=[1, 0, 0], rpa=[-1, 0, 0], + ch_pos=dict( + zip( + ch_names, + ch_coords, + ) + ), + coord_frame="unknown", + nasion=[0, 1, 0], + lpa=[1, 0, 0], + rpa=[-1, 0, 0], ) - with pytest.raises(ValueError, match='DigMontage is ' - 'only a subset of info'): + with pytest.raises(ValueError, match="DigMontage is " "only a subset of info"): raw.set_montage(montage_in_mri) - with pytest.raises(ValueError, match='Invalid value'): - raw.set_montage(montage_in_mri, on_missing='foo') + with pytest.raises(ValueError, match="Invalid value"): + raw.set_montage(montage_in_mri, on_missing="foo") - with pytest.raises(TypeError, match='must be an instance'): + with pytest.raises(TypeError, match="must be an 
instance"): raw.set_montage(montage_in_mri, on_missing=True) - with pytest.warns(RuntimeWarning, match='DigMontage is ' - 'only a subset of info'): - raw.set_montage(montage_in_mri, on_missing='warn') + with pytest.warns(RuntimeWarning, match="DigMontage is " "only a subset of info"): + raw.set_montage(montage_in_mri, on_missing="warn") - raw.set_montage(montage_in_mri, on_missing='ignore') + raw.set_montage(montage_in_mri, on_missing="ignore") assert_allclose( - actual=np.array([ch['loc'] for ch in raw.info['chs']]), + actual=np.array([ch["loc"] for ch in raw.info["chs"]]), desired=[ [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], - [0., 1., -2., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [-3., 4., -5., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - ] + [0.0, 1.0, -2.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [-3.0, 4.0, -5.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + ], ) @@ -1559,16 +1815,16 @@ def test_get_montage(): # 1. read in testing data and assert montage roundtrip # for testing dataset: 'test_raw.fif' raw = read_raw_fif(fif_fname) - raw = raw.rename_channels(lambda name: name.replace('EEG ', 'EEG')) + raw = raw.rename_channels(lambda name: name.replace("EEG ", "EEG")) raw2 = raw.copy() # get montage and then set montage and # it should be the same montage = raw.get_montage() - raw.set_montage(montage, on_missing='raise') + raw.set_montage(montage, on_missing="raise") test_montage = raw.get_montage() - assert_object_equal(raw.info['chs'], raw2.info['chs']) + assert_object_equal(raw.info["chs"], raw2.info["chs"]) assert_dig_allclose(raw2.info, raw.info) - assert_object_equal(raw2.info['dig'], raw.info['dig']) + assert_object_equal(raw2.info["dig"], raw.info["dig"]) # the montage does not change assert_object_equal(montage.dig, test_montage.dig) @@ -1578,7 +1834,7 @@ def test_get_montage(): assert_object_equal(test2_montage.dig, test_montage.dig) # 2. now do a standard montage - montage = make_standard_montage('mgh60') + montage = make_standard_montage("mgh60") # set the montage; note renaming to make standard montage map raw.set_montage(montage) @@ -1586,20 +1842,20 @@ def test_get_montage(): # the channel locations should be the same raw2 = raw.copy() test_montage = raw.get_montage() - raw.set_montage(test_montage, on_missing='ignore') + raw.set_montage(test_montage, on_missing="ignore") # the montage should fulfill a roundtrip with make_dig_montage test2_montage = make_dig_montage(**test_montage.get_positions()) assert_object_equal(test2_montage.dig, test_montage.dig) # chs should not change - assert_object_equal(raw2.info['chs'], raw.info['chs']) + assert_object_equal(raw2.info["chs"], raw.info["chs"]) # dig order might be different after set_montage assert montage.ch_names == test_montage.ch_names # note that test_montage will have different coordinate frame # compared to standard montage assert_dig_allclose(raw2.info, raw.info) - assert_object_equal(raw2.info['dig'], raw.info['dig']) + assert_object_equal(raw2.info["dig"], raw.info["dig"]) # 3. 
if montage gets set to None raw.set_montage(None) @@ -1618,14 +1874,14 @@ def test_get_montage(): # of channels mapping = dict() for ii, ch_name in enumerate(raw_bv.ch_names): - mapping[ch_name] = 'EEG%03d' % (ii + 1,) + mapping[ch_name] = "EEG%03d" % (ii + 1,) raw_bv.rename_channels(mapping) for ii, ch_name in enumerate(raw_bv_2.ch_names): - mapping[ch_name] = 'EEG%03d' % (ii + 33,) + mapping[ch_name] = "EEG%03d" % (ii + 33,) raw_bv_2.rename_channels(mapping) raw_bv.add_channels([raw_bv_2]) - for ch in raw_bv.info['chs']: - ch['kind'] = FIFF.FIFFV_EEG_CH + for ch in raw_bv.info["chs"]: + ch["kind"] = FIFF.FIFFV_EEG_CH # Set the montage and roundtrip raw_bv.set_montage(dig_montage) @@ -1633,14 +1889,14 @@ def test_get_montage(): # reset the montage test_montage = raw_bv.get_montage() - raw_bv.set_montage(test_montage, on_missing='ignore') + raw_bv.set_montage(test_montage, on_missing="ignore") # dig order might be different after set_montage - assert_object_equal(raw_bv2.info['dig'], raw_bv.info['dig']) + assert_object_equal(raw_bv2.info["dig"], raw_bv.info["dig"]) assert_dig_allclose(raw_bv2.info, raw_bv.info) # if dig is not set in the info, then montage returns None with raw.info._unlock(): - raw.info['dig'] = None + raw.info["dig"] = None assert raw.get_montage() is None # the montage should fulfill a roundtrip with make_dig_montage @@ -1653,8 +1909,7 @@ def test_read_dig_hpts(): fname = io_dir / "brainvision" / "tests" / "data" / "test.hpts" montage = read_dig_hpts(fname) assert repr(montage) == ( - '' + "" ) @@ -1677,7 +1932,7 @@ def test_plot_montage(): # gh-8025 montage = read_dig_captrak(bvct_dig_montage_fname) montage.plot() - plt.close('all') + plt.close("all") f, ax = plt.subplots(1, 1) montage.plot(axes=ax) @@ -1694,12 +1949,12 @@ def test_plot_montage(): def test_montage_equality(): """Test montage equality.""" rng = np.random.RandomState(0) - fiducials = dict(zip(('nasion', 'lpa', 'rpa'), rng.rand(3, 3))) + fiducials = dict(zip(("nasion", "lpa", "rpa"), rng.rand(3, 3))) # hsp positions are [1X, 1X, 1X] - hsp1 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 11.)) - hsp2 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.)) - hsp2_identical = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.)) + hsp1 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 11.0)) + hsp2 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.0)) + hsp2_identical = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.0)) assert hsp1 != hsp2 assert hsp2 == hsp2_identical @@ -1710,45 +1965,46 @@ def test_montage_add_fiducials(): """Test montage can add estimated fiducials for rpa, lpa, nas.""" # get the fiducials from test file subjects_dir = data_path / "subjects" - subject = 'sample' + subject = "sample" fid_fname = subjects_dir / subject / "bem" / "sample-fiducials.fif" test_fids, test_coord_frame = read_fiducials(fid_fname) - test_fids = np.array([f['r'] for f in test_fids]) + test_fids = np.array([f["r"] for f in test_fids]) # create test montage and add estimated fiducials - test_ch_pos = {'A1': [0, 0, 0]} - montage = make_dig_montage(ch_pos=test_ch_pos, coord_frame='mri') + test_ch_pos = {"A1": [0, 0, 0]} + montage = make_dig_montage(ch_pos=test_ch_pos, coord_frame="mri") montage.add_estimated_fiducials(subject=subject, subjects_dir=subjects_dir) # check that adding MNI fiducials fails because we're in MRI - with pytest.raises(RuntimeError, match='Montage should be in the ' - '"mni_tal" coordinate frame'): + with pytest.raises( + RuntimeError, match="Montage should be in the " 
'"mni_tal" coordinate frame' + ): montage.add_mni_fiducials(subjects_dir=subjects_dir) # check that these fiducials are close to the estimated fiducials ch_pos = montage.get_positions() - fids_est = [ch_pos['lpa'], ch_pos['nasion'], ch_pos['rpa']] + fids_est = [ch_pos["lpa"], ch_pos["nasion"], ch_pos["rpa"]] - dists = np.linalg.norm(test_fids - fids_est, axis=-1) * 1000. # -> mm + dists = np.linalg.norm(test_fids - fids_est, axis=-1) * 1000.0 # -> mm assert (dists < 8).all(), dists # an error should be raised if the montage is not in `mri` coord_frame # which is the FreeSurfer RAS - montage = make_dig_montage(ch_pos=test_ch_pos, coord_frame='mni_tal') - with pytest.raises(RuntimeError, match='Montage should be in the ' - '"mri" coordinate frame'): - montage.add_estimated_fiducials(subject=subject, - subjects_dir=subjects_dir) + montage = make_dig_montage(ch_pos=test_ch_pos, coord_frame="mni_tal") + with pytest.raises( + RuntimeError, match="Montage should be in the " '"mri" coordinate frame' + ): + montage.add_estimated_fiducials(subject=subject, subjects_dir=subjects_dir) # test that adding MNI fiducials works montage.add_mni_fiducials(subjects_dir=subjects_dir) - test_fids = get_mni_fiducials('fsaverage', subjects_dir=subjects_dir) + test_fids = get_mni_fiducials("fsaverage", subjects_dir=subjects_dir) for fid, test_fid in zip(montage.dig[:3], test_fids): - assert_array_equal(fid['r'], test_fid['r']) + assert_array_equal(fid["r"], test_fid["r"]) # test remove fiducials montage.remove_fiducials() - assert all([d['kind'] != FIFF.FIFFV_POINT_CARDINAL for d in montage.dig]) + assert all([d["kind"] != FIFF.FIFFV_POINT_CARDINAL for d in montage.dig]) def test_read_dig_localite(tmp_path): @@ -1773,23 +2029,23 @@ def test_read_dig_localite(tmp_path): 17,ch14,-61.16539571,-61.86866187,26.23986153 18,ch15,-55.82855386,-34.77319103,25.8083942""" - fname = tmp_path / 'localite.csv' - with open(fname, 'w') as f: - for row in contents.split('\n'): - f.write(f'{row.lstrip()}\n') + fname = tmp_path / "localite.csv" + with open(fname, "w") as f: + for row in contents.split("\n"): + f.write(f"{row.lstrip()}\n") montage = read_dig_localite(fname, nasion="Nasion", lpa="LPA", rpa="RPA") - s = '' + s = "" assert repr(montage) == s - assert montage.ch_names == [f'ch{i:02}' for i in range(1, 16)] + assert montage.ch_names == [f"ch{i:02}" for i in range(1, 16)] def test_make_wrong_dig_montage(): """Test that a montage with non numeric is not possible.""" - make_dig_montage(ch_pos={'A1': ['0', '0', '0']}) # converted to floats + make_dig_montage(ch_pos={"A1": ["0", "0", "0"]}) # converted to floats with pytest.raises(ValueError, match="could not convert string to float"): - make_dig_montage(ch_pos={'A1': ['a', 'b', 'c']}) + make_dig_montage(ch_pos={"A1": ["a", "b", "c"]}) with pytest.raises(TypeError, match="instance of ndarray, list, or tuple"): - make_dig_montage(ch_pos={'A1': 5}) + make_dig_montage(ch_pos={"A1": 5}) @testing.requires_testing_data @@ -1805,15 +2061,14 @@ def test_fnirs_montage(): assert num_detectors == 13 # Make a change to the montage before setting - raw.info['chs'][2]['loc'][:3] = [1., 2, 3] + raw.info["chs"][2]["loc"][:3] = [1.0, 2, 3] # Set montage back to original raw.set_montage(mtg) for ch in range(len(raw.ch_names)): - assert_array_equal(info_orig['chs'][ch]['loc'], - raw.info['chs'][ch]['loc']) + assert_array_equal(info_orig["chs"][ch]["loc"], raw.info["chs"][ch]["loc"]) # Mixed channel types not supported yet - raw.set_channel_types({ch_name: 'eeg' for ch_name in raw.ch_names[-2:]}) 
- with pytest.raises(ValueError, match='mix of fNIRS'): + raw.set_channel_types({ch_name: "eeg" for ch_name in raw.ch_names[-2:]}) + with pytest.raises(ValueError, match="mix of fNIRS"): raw.get_montage() diff --git a/mne/channels/tests/test_standard_montage.py b/mne/channels/tests/test_standard_montage.py index 49fffaa4ab3..a9cf8f2cf0a 100644 --- a/mne/channels/tests/test_standard_montage.py +++ b/mne/channels/tests/test_standard_montage.py @@ -8,8 +8,7 @@ import numpy as np -from numpy.testing import (assert_allclose, assert_array_almost_equal, - assert_raises) +from numpy.testing import assert_allclose, assert_array_almost_equal, assert_raises from mne import create_info from mne.channels import make_standard_montage, compute_native_head_t @@ -21,7 +20,7 @@ from mne.transforms import _get_trans, _angle_between_quats, rot_to_quat -@pytest.mark.parametrize('kind', get_builtin_montages()) +@pytest.mark.parametrize("kind", get_builtin_montages()) def test_standard_montages_have_fids(kind): """Test standard montage are all in unknown coord (have fids).""" montage = make_standard_montage(kind) @@ -29,44 +28,47 @@ def test_standard_montages_have_fids(kind): for k, v in fids.items(): assert v is not None, k for d in montage.dig: - if kind.startswith(('artinis', 'standard', 'mgh')): + if kind.startswith(("artinis", "standard", "mgh")): want = FIFF.FIFFV_COORD_MRI else: want = FIFF.FIFFV_COORD_UNKNOWN - assert d['coord_frame'] == want + assert d["coord_frame"] == want def test_standard_montage_errors(): """Test error handling for wrong keys.""" _msg = "Invalid value for the 'kind' parameter..*but got.*not-here" with pytest.raises(ValueError, match=_msg): - _ = make_standard_montage('not-here') - - -@pytest.mark.parametrize('head_size', (HEAD_SIZE_DEFAULT, 0.05)) -@pytest.mark.parametrize('kind, tol', [ - ['EGI_256', 1e-5], - ['easycap-M1', 1e-8], - ['easycap-M10', 1e-8], - ['biosemi128', 1e-8], - ['biosemi16', 1e-8], - ['biosemi160', 1e-8], - ['biosemi256', 1e-8], - ['biosemi32', 1e-8], - ['biosemi64', 1e-8], - ['brainproducts-RNP-BA-128', 1e-8] -]) + _ = make_standard_montage("not-here") + + +@pytest.mark.parametrize("head_size", (HEAD_SIZE_DEFAULT, 0.05)) +@pytest.mark.parametrize( + "kind, tol", + [ + ["EGI_256", 1e-5], + ["easycap-M1", 1e-8], + ["easycap-M10", 1e-8], + ["biosemi128", 1e-8], + ["biosemi16", 1e-8], + ["biosemi160", 1e-8], + ["biosemi256", 1e-8], + ["biosemi32", 1e-8], + ["biosemi64", 1e-8], + ["brainproducts-RNP-BA-128", 1e-8], + ], +) def test_standard_montages_on_sphere(kind, tol, head_size): """Test some standard montage are on sphere.""" kwargs = dict() if head_size != HEAD_SIZE_DEFAULT: - kwargs['head_size'] = head_size + kwargs["head_size"] = head_size montage = make_standard_montage(kind, **kwargs) - eeg_loc = np.array([ch['r'] for ch in _get_dig_eeg(montage.dig)]) + eeg_loc = np.array([ch["r"] for ch in _get_dig_eeg(montage.dig)]) assert_allclose( actual=np.linalg.norm(eeg_loc, axis=1), - desired=np.full((eeg_loc.shape[0], ), head_size), + desired=np.full((eeg_loc.shape[0],), head_size), atol=tol, ) @@ -74,14 +76,14 @@ def test_standard_montages_on_sphere(kind, tol, head_size): def test_standard_superset(): """Test some properties that should hold for superset montages.""" # new montages, tweaked to end up at the same size as the others - m_1005 = make_standard_montage('standard_1005', 0.0970) - m_1020 = make_standard_montage('standard_1020', 0.0991) + m_1005 = make_standard_montage("standard_1005", 0.0970) + m_1020 = make_standard_montage("standard_1020", 0.0991) 
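# For reference, how a user might surface the same relation outside the
# test suite (a sketch reusing the two montages built above):
pos_1005 = m_1005.get_positions()["ch_pos"]  # dict: channel name -> xyz
pos_1020 = m_1020.get_positions()["ch_pos"]
only_in_1020 = set(pos_1020) - set(pos_1005)  # {'O9', 'O10'}, per the asserts
shared = set(pos_1020) & set(pos_1005)  # agree to ~0.1 mm (atol=1e-4 below)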
assert len(set(m_1005.ch_names) - set(m_1020.ch_names)) > 0 # XXX weird that this is not a proper superset... - assert set(m_1020.ch_names) - set(m_1005.ch_names) == {'O10', 'O9'} + assert set(m_1020.ch_names) - set(m_1005.ch_names) == {"O10", "O9"} c_1005 = m_1005._get_ch_pos() for key, value in m_1020._get_ch_pos().items(): - if key not in ('O10', 'O9'): + if key not in ("O10", "O9"): assert_allclose(c_1005[key], value, atol=1e-4, err_msg=key) @@ -93,15 +95,29 @@ def _simulate_artinis_octamon(): """ np.random.seed(42) data = np.absolute(np.random.normal(size=(16, 100))) - ch_names = ['S1_D1 760', 'S1_D1 850', 'S2_D1 760', 'S2_D1 850', - 'S3_D1 760', 'S3_D1 850', 'S4_D1 760', 'S4_D1 850', - 'S5_D2 760', 'S5_D2 850', 'S6_D2 760', 'S6_D2 850', - 'S7_D2 760', 'S7_D2 850', 'S8_D2 760', 'S8_D2 850'] - ch_types = ['fnirs_cw_amplitude' for _ in ch_names] - sfreq = 10. # Hz + ch_names = [ + "S1_D1 760", + "S1_D1 850", + "S2_D1 760", + "S2_D1 850", + "S3_D1 760", + "S3_D1 850", + "S4_D1 760", + "S4_D1 850", + "S5_D2 760", + "S5_D2 850", + "S6_D2 760", + "S6_D2 850", + "S7_D2 760", + "S7_D2 850", + "S8_D2 760", + "S8_D2 850", + ] + ch_types = ["fnirs_cw_amplitude" for _ in ch_names] + sfreq = 10.0 # Hz info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) for i, ch_name in enumerate(ch_names): - info['chs'][i]['loc'][9] = int(ch_name.split(' ')[1]) + info["chs"][i]["loc"][9] = int(ch_name.split(" ")[1]) raw = RawArray(data, info) return raw @@ -115,47 +131,71 @@ def _simulate_artinis_brite23(): """ np.random.seed(0) data = np.random.normal(size=(46, 100)) - sd_names = ['S1_D1', 'S2_D1', 'S3_D1', 'S4_D1', 'S3_D2', 'S4_D2', 'S5_D2', - 'S4_D3', 'S5_D3', 'S6_D3', 'S5_D4', 'S6_D4', 'S7_D4', 'S6_D5', - 'S7_D5', 'S8_D5', 'S7_D6', 'S8_D6', 'S9_D6', 'S8_D7', 'S9_D7', - 'S10_D7', 'S11_D7'] + sd_names = [ + "S1_D1", + "S2_D1", + "S3_D1", + "S4_D1", + "S3_D2", + "S4_D2", + "S5_D2", + "S4_D3", + "S5_D3", + "S6_D3", + "S5_D4", + "S6_D4", + "S7_D4", + "S6_D5", + "S7_D5", + "S8_D5", + "S7_D6", + "S8_D6", + "S9_D6", + "S8_D7", + "S9_D7", + "S10_D7", + "S11_D7", + ] ch_names = [] ch_types = [] for name in sd_names: - ch_names.append(name + ' hbo') - ch_types.append('hbo') - ch_names.append(name + ' hbr') - ch_types.append('hbr') - sfreq = 10. 
# Hz + ch_names.append(name + " hbo") + ch_types.append("hbo") + ch_names.append(name + " hbr") + ch_types.append("hbr") + sfreq = 10.0 # Hz info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) raw = RawArray(data, info) return raw -@pytest.mark.parametrize('kind', ('octamon', 'brite23')) +@pytest.mark.parametrize("kind", ("octamon", "brite23")) def test_set_montage_artinis_fsaverage(kind): """Test that artinis montages match fsaverage's head<->MRI transform.""" # Compare OctaMon and Brite23 to fsaverage - trans_fs, _ = _get_trans('fsaverage') - montage = make_standard_montage(f'artinis-{kind}') + trans_fs, _ = _get_trans("fsaverage") + montage = make_standard_montage(f"artinis-{kind}") trans = compute_native_head_t(montage) - assert trans['to'] == trans_fs['to'] - assert trans['from'] == trans_fs['from'] - translation = 1000 * np.linalg.norm(trans['trans'][:3, 3] - - trans_fs['trans'][:3, 3]) + assert trans["to"] == trans_fs["to"] + assert trans["from"] == trans_fs["from"] + translation = 1000 * np.linalg.norm( + trans["trans"][:3, 3] - trans_fs["trans"][:3, 3] + ) assert 0 < translation < 1 # mm rotation = np.rad2deg( - _angle_between_quats(rot_to_quat(trans['trans'][:3, :3]), - rot_to_quat(trans_fs['trans'][:3, :3]))) + _angle_between_quats( + rot_to_quat(trans["trans"][:3, :3]), rot_to_quat(trans_fs["trans"][:3, :3]) + ) + ) assert 0 < rotation < 1 # degrees def test_set_montage_artinis_basic(): """Test that OctaMon and Brite23 montages are set properly.""" # Test OctaMon montage - montage_octamon = make_standard_montage('artinis-octamon') - montage_brite23 = make_standard_montage('artinis-brite23') + montage_octamon = make_standard_montage("artinis-octamon") + montage_brite23 = make_standard_montage("artinis-brite23") raw = _simulate_artinis_octamon() raw_od = optical_density(raw) old_info = raw.info.copy() @@ -164,82 +204,106 @@ def test_set_montage_artinis_basic(): raw_od.set_montage(montage_octamon) raw_hb = beer_lambert_law(raw_od, ppf=6) # montage needed for BLL # Check that the montage was actually modified - assert_raises(AssertionError, assert_array_almost_equal, - old_info['chs'][0]['loc'][:9], - raw.info['chs'][0]['loc'][:9]) - assert_raises(AssertionError, assert_array_almost_equal, - old_info_od['chs'][0]['loc'][:9], - raw_od.info['chs'][0]['loc'][:9]) + assert_raises( + AssertionError, + assert_array_almost_equal, + old_info["chs"][0]["loc"][:9], + raw.info["chs"][0]["loc"][:9], + ) + assert_raises( + AssertionError, + assert_array_almost_equal, + old_info_od["chs"][0]["loc"][:9], + raw_od.info["chs"][0]["loc"][:9], + ) # Check a known location - assert_array_almost_equal(raw.info['chs'][0]['loc'][:3], - [0.054243, 0.081884, 0.054544]) - assert_array_almost_equal(raw.info['chs'][8]['loc'][:3], - [-0.03013, 0.105097, 0.055894]) - assert_array_almost_equal(raw.info['chs'][12]['loc'][:3], - [-0.055681, 0.086566, 0.055858]) - assert_array_almost_equal(raw_od.info['chs'][12]['loc'][:3], - [-0.055681, 0.086566, 0.055858]) - assert_array_almost_equal(raw_hb.info['chs'][12]['loc'][:3], - [-0.055681, 0.086566, 0.055858]) + assert_array_almost_equal( + raw.info["chs"][0]["loc"][:3], [0.054243, 0.081884, 0.054544] + ) + assert_array_almost_equal( + raw.info["chs"][8]["loc"][:3], [-0.03013, 0.105097, 0.055894] + ) + assert_array_almost_equal( + raw.info["chs"][12]["loc"][:3], [-0.055681, 0.086566, 0.055858] + ) + assert_array_almost_equal( + raw_od.info["chs"][12]["loc"][:3], [-0.055681, 0.086566, 0.055858] + ) + assert_array_almost_equal( + 
raw_hb.info["chs"][12]["loc"][:3], [-0.055681, 0.086566, 0.055858] + ) # Check that locations are identical for a pair of channels (all elements # except the 10th which is the wavelength if not hbo and hbr type) - assert_array_almost_equal(raw.info['chs'][0]['loc'][:9], - raw.info['chs'][1]['loc'][:9]) - assert_array_almost_equal(raw_od.info['chs'][0]['loc'][:9], - raw_od.info['chs'][1]['loc'][:9]) - assert_array_almost_equal(raw_hb.info['chs'][0]['loc'][:9], - raw_hb.info['chs'][1]['loc'][:9]) + assert_array_almost_equal( + raw.info["chs"][0]["loc"][:9], raw.info["chs"][1]["loc"][:9] + ) + assert_array_almost_equal( + raw_od.info["chs"][0]["loc"][:9], raw_od.info["chs"][1]["loc"][:9] + ) + assert_array_almost_equal( + raw_hb.info["chs"][0]["loc"][:9], raw_hb.info["chs"][1]["loc"][:9] + ) # Test Brite23 montage raw = _simulate_artinis_brite23() old_info = raw.info.copy() raw.set_montage(montage_brite23) # Check that the montage was actually modified - assert_raises(AssertionError, assert_array_almost_equal, - old_info['chs'][0]['loc'][:9], - raw.info['chs'][0]['loc'][:9]) + assert_raises( + AssertionError, + assert_array_almost_equal, + old_info["chs"][0]["loc"][:9], + raw.info["chs"][0]["loc"][:9], + ) # Check a known location - assert_array_almost_equal(raw.info['chs'][0]['loc'][:3], - [0.068931, 0.046201, 0.072055]) - assert_array_almost_equal(raw.info['chs'][8]['loc'][:3], - [0.055196, 0.082757, 0.052165]) - assert_array_almost_equal(raw.info['chs'][12]['loc'][:3], - [0.033592, 0.102607, 0.047423]) + assert_array_almost_equal( + raw.info["chs"][0]["loc"][:3], [0.068931, 0.046201, 0.072055] + ) + assert_array_almost_equal( + raw.info["chs"][8]["loc"][:3], [0.055196, 0.082757, 0.052165] + ) + assert_array_almost_equal( + raw.info["chs"][12]["loc"][:3], [0.033592, 0.102607, 0.047423] + ) # Check that locations are identical for a pair of channels (all elements # except the 10th which is the wavelength if not hbo and hbr type) - assert_array_almost_equal(raw.info['chs'][0]['loc'][:9], - raw.info['chs'][1]['loc'][:9]) + assert_array_almost_equal( + raw.info["chs"][0]["loc"][:9], raw.info["chs"][1]["loc"][:9] + ) # Test channel variations raw_old = _simulate_artinis_brite23() # Raw missing some channels that are in the montage: pass raw = raw_old.copy() - raw.pick(['S1_D1 hbo', 'S1_D1 hbr']) - raw.set_montage('artinis-brite23') + raw.pick(["S1_D1 hbo", "S1_D1 hbr"]) + raw.set_montage("artinis-brite23") # Unconventional channel pair: pass raw = raw_old.copy() - info_new = create_info(['S11_D1 hbo', 'S11_D1 hbr'], raw.info['sfreq'], - ['hbo', 'hbr']) + info_new = create_info( + ["S11_D1 hbo", "S11_D1 hbr"], raw.info["sfreq"], ["hbo", "hbr"] + ) new = RawArray(np.random.normal(size=(2, len(raw))), info_new) raw.add_channels([new], force_update_info=True) - raw.set_montage('artinis-brite23') + raw.set_montage("artinis-brite23") # Source not in montage: fail raw = raw_old.copy() - info_new = create_info(['S12_D7 hbo', 'S12_D7 hbr'], raw.info['sfreq'], - ['hbo', 'hbr']) + info_new = create_info( + ["S12_D7 hbo", "S12_D7 hbr"], raw.info["sfreq"], ["hbo", "hbr"] + ) new = RawArray(np.random.normal(size=(2, len(raw))), info_new) raw.add_channels([new], force_update_info=True) - with pytest.raises(ValueError, match='is not in list'): - raw.set_montage('artinis-brite23') + with pytest.raises(ValueError, match="is not in list"): + raw.set_montage("artinis-brite23") # Detector not in montage: fail raw = raw_old.copy() - info_new = create_info(['S11_D8 hbo', 'S11_D8 hbr'], raw.info['sfreq'], - ['hbo', 
'hbr']) + info_new = create_info( + ["S11_D8 hbo", "S11_D8 hbr"], raw.info["sfreq"], ["hbo", "hbr"] + ) new = RawArray(np.random.normal(size=(2, len(raw))), info_new) raw.add_channels([new], force_update_info=True) - with pytest.raises(ValueError, match='is not in list'): - raw.set_montage('artinis-brite23') + with pytest.raises(ValueError, match="is not in list"): + raw.set_montage("artinis-brite23") diff --git a/mne/chpi.py b/mne/chpi.py index 9d80fa6efde..3bbddb1647b 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -29,24 +29,49 @@ from .io.kit.constants import KIT from .io.kit.kit import RawKIT as _RawKIT from .io.meas_info import _simplify_info, Info -from .io.pick import (pick_types, pick_channels, pick_channels_regexp, - pick_info, _picks_to_idx) +from .io.pick import ( + pick_types, + pick_channels, + pick_channels_regexp, + pick_info, + _picks_to_idx, +) from .io.proj import Projection, setup_proj from .io.constants import FIFF from .io.ctf.trans import _make_ctf_coord_trans_set -from .forward import (_magnetic_dipole_field_vec, _create_meg_coils, - _concatenate_coils) +from .forward import _magnetic_dipole_field_vec, _create_meg_coils, _concatenate_coils from .cov import make_ad_hoc_cov, compute_whitener from .dipole import _make_guesses from .fixes import jit -from .preprocessing.maxwell import (_sss_basis, _prep_mf_coils, - _regularize_out, _get_mf_picks_fix_mags) -from .transforms import (apply_trans, invert_transform, _angle_between_quats, - quat_to_rot, rot_to_quat, _fit_matched_points, - _quat_to_affine, als_ras_trans) -from .utils import (verbose, logger, use_log_level, _check_fname, warn, - _validate_type, ProgressBar, _check_option, _pl, - _on_missing, _verbose_safe_false) +from .preprocessing.maxwell import ( + _sss_basis, + _prep_mf_coils, + _regularize_out, + _get_mf_picks_fix_mags, +) +from .transforms import ( + apply_trans, + invert_transform, + _angle_between_quats, + quat_to_rot, + rot_to_quat, + _fit_matched_points, + _quat_to_affine, + als_ras_trans, +) +from .utils import ( + verbose, + logger, + use_log_level, + _check_fname, + warn, + _validate_type, + ProgressBar, + _check_option, + _pl, + _on_missing, + _verbose_safe_false, +) # Eventually we should add: # hpicons @@ -57,6 +82,7 @@ # ############################################################################ # Reading from text or FIF file + def read_head_pos(fname): """Read MaxFilter-formatted head position parameters. @@ -80,12 +106,11 @@ def read_head_pos(fname): ----- .. 
versionadded:: 0.12 """ - _check_fname(fname, must_exist=True, overwrite='read') + _check_fname(fname, must_exist=True, overwrite="read") data = np.loadtxt(fname, skiprows=1) # first line is header, skip it data.shape = (-1, 10) # ensure it's the right size even if empty if np.isnan(data).any(): # make sure we didn't do something dumb - raise RuntimeError('positions could not be read properly from %s' - % fname) + raise RuntimeError("positions could not be read properly from %s" % fname) return data @@ -111,14 +136,15 @@ def write_head_pos(fname, pos): _check_fname(fname, overwrite=True) pos = np.array(pos, np.float64) if pos.ndim != 2 or pos.shape[1] != 10: - raise ValueError('pos must be a 2D array of shape (N, 10)') - with open(fname, 'wb') as fid: - fid.write(' Time q1 q2 q3 q4 q5 ' - 'q6 g-value error velocity\n'.encode('ASCII')) + raise ValueError("pos must be a 2D array of shape (N, 10)") + with open(fname, "wb") as fid: + fid.write( + " Time q1 q2 q3 q4 q5 " + "q6 g-value error velocity\n".encode("ASCII") + ) for p in pos: - fmts = ['% 9.3f'] + ['% 8.5f'] * 9 - fid.write(((' ' + ' '.join(fmts) + '\n') - % tuple(p)).encode('ASCII')) + fmts = ["% 9.3f"] + ["% 8.5f"] * 9 + fid.write(((" " + " ".join(fmts) + "\n") % tuple(p)).encode("ASCII")) def head_pos_to_trans_rot_t(quats): @@ -178,15 +204,14 @@ def extract_chpi_locs_ctf(raw, verbose=None): .. versionadded:: 0.20 """ # Pick channels corresponding to the cHPI positions - hpi_picks = pick_channels_regexp(raw.info['ch_names'], 'HLC00[123][123].*') + hpi_picks = pick_channels_regexp(raw.info["ch_names"], "HLC00[123][123].*") # make sure we get 9 channels if len(hpi_picks) != 9: - raise RuntimeError('Could not find all 9 cHPI channels') + raise RuntimeError("Could not find all 9 cHPI channels") # get indices in alphabetical order - sorted_picks = np.array(sorted(hpi_picks, - key=lambda k: raw.info['ch_names'][k])) + sorted_picks = np.array(sorted(hpi_picks, key=lambda k: raw.info["ch_names"][k])) # make picks to match order of dig cardinial ident codes. # LPA (HPIC002[123]-*), NAS(HPIC001[123]-*), RPA(HPIC003[123]-*) @@ -199,7 +224,7 @@ def extract_chpi_locs_ctf(raw, verbose=None): # transforms tmp_trans = _make_ctf_coord_trans_set(None, None) - ctf_dev_dev_t = tmp_trans['t_ctf_dev_dev'] + ctf_dev_dev_t = tmp_trans["t_ctf_dev_dev"] del tmp_trans # find indices where chpi locations change @@ -216,7 +241,7 @@ def extract_chpi_locs_ctf(raw, verbose=None): @verbose -def extract_chpi_locs_kit(raw, stim_channel='MISC 064', *, verbose=None): +def extract_chpi_locs_kit(raw, stim_channel="MISC 064", *, verbose=None): """Extract cHPI locations from KIT data. Parameters @@ -235,34 +260,35 @@ def extract_chpi_locs_kit(raw, stim_channel='MISC 064', *, verbose=None): ----- .. 
versionadded:: 0.23 """ - _validate_type(raw, (_RawKIT,), 'raw') + _validate_type(raw, (_RawKIT,), "raw") stim_chs = [ - raw.info['ch_names'][pick] for pick in pick_types( - raw.info, stim=True, misc=True, ref_meg=False)] - _validate_type(stim_channel, str, 'stim_channel') - _check_option('stim_channel', stim_channel, stim_chs) + raw.info["ch_names"][pick] + for pick in pick_types(raw.info, stim=True, misc=True, ref_meg=False) + ] + _validate_type(stim_channel, str, "stim_channel") + _check_option("stim_channel", stim_channel, stim_chs) idx = raw.ch_names.index(stim_channel) safe_false = _verbose_safe_false() events_on = find_events( - raw, stim_channel=raw.ch_names[idx], output='onset', - verbose=safe_false)[:, 0] + raw, stim_channel=raw.ch_names[idx], output="onset", verbose=safe_false + )[:, 0] events_off = find_events( - raw, stim_channel=raw.ch_names[idx], output='offset', - verbose=safe_false)[:, 0] + raw, stim_channel=raw.ch_names[idx], output="offset", verbose=safe_false + )[:, 0] bad = False if len(events_on) == 0 or len(events_off) == 0: bad = True else: if events_on[-1] > events_off[-1]: events_on = events_on[:-1] - if events_on.size != events_off.size or not \ - (events_on < events_off).all(): + if events_on.size != events_off.size or not (events_on < events_off).all(): bad = True if bad: raise RuntimeError( - f'Could not find appropriate cHPI intervals from {stim_channel}') + f"Could not find appropriate cHPI intervals from {stim_channel}" + ) # use the midpoint for times - times = (events_on + events_off) / (2 * raw.info['sfreq']) + times = (events_on + events_off) / (2 * raw.info["sfreq"]) del events_on, events_off # XXX remove first two rows. It is unknown currently if there is a way to # determine from the con file the number of initial pulses that @@ -271,24 +297,25 @@ def extract_chpi_locs_kit(raw, stim_channel='MISC 064', *, verbose=None): # may just always be 2... 
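# Illustrative sketch (sample indices and sampling rate assumed, not from any
# real recording): the midpoint computation above converts each cHPI stim
# pulse to a single measurement time by averaging the onset and offset
# samples and dividing by sfreq.
import numpy as np

sfreq = 1000.0                     # Hz (assumed)
events_on = np.array([100, 400])   # pulse onset samples (assumed)
events_off = np.array([200, 500])  # pulse offset samples (assumed)
times = (events_on + events_off) / (2 * sfreq)
print(times)  # -> [0.15 0.45]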
times = times[2:] n_coils = 5 # KIT always has 5 (hard-coded in reader) - header = raw._raw_extras[0]['dirs'][KIT.DIR_INDEX_CHPI_DATA] - dtype = np.dtype([('good', ' 0 else None # grab codes indicating a coil is active - hpi_on = [coil['event_bits'][0] for coil in hpi_sub['hpi_coils']] + hpi_on = [coil["event_bits"][0] for coil in hpi_sub["hpi_coils"]] # not all HPI coils will actually be used - hpi_on = np.array([hpi_on[hc['number'] - 1] for hc in hpi_coils]) + hpi_on = np.array([hpi_on[hc["number"] - 1] for hc in hpi_coils]) # mask for coils that may be active hpi_mask = np.array([event_bit != 0 for event_bit in hpi_on]) hpi_on = hpi_on[hpi_mask] @@ -366,63 +404,71 @@ def get_chpi_info(info, on_missing='raise', verbose=None): @verbose def _get_hpi_initial_fit(info, adjust=False, verbose=None): """Get HPI fit locations from raw.""" - if info['hpi_results'] is None or len(info['hpi_results']) == 0: - raise RuntimeError('no initial cHPI head localization performed') - - hpi_result = info['hpi_results'][-1] - hpi_dig = sorted([d for d in info['dig'] - if d['kind'] == FIFF.FIFFV_POINT_HPI], - key=lambda x: x['ident']) # ascending (dig) order + if info["hpi_results"] is None or len(info["hpi_results"]) == 0: + raise RuntimeError("no initial cHPI head localization performed") + + hpi_result = info["hpi_results"][-1] + hpi_dig = sorted( + [d for d in info["dig"] if d["kind"] == FIFF.FIFFV_POINT_HPI], + key=lambda x: x["ident"], + ) # ascending (dig) order if len(hpi_dig) == 0: # CTF data, probably - hpi_dig = sorted(hpi_result['dig_points'], key=lambda x: x['ident']) - if all(d['coord_frame'] in (FIFF.FIFFV_COORD_DEVICE, - FIFF.FIFFV_COORD_UNKNOWN) - for d in hpi_dig): + hpi_dig = sorted(hpi_result["dig_points"], key=lambda x: x["ident"]) + if all( + d["coord_frame"] in (FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_UNKNOWN) + for d in hpi_dig + ): for dig in hpi_dig: - dig.update(r=apply_trans(info['dev_head_t'], dig['r']), - coord_frame=FIFF.FIFFV_COORD_HEAD) + dig.update( + r=apply_trans(info["dev_head_t"], dig["r"]), + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) # zero-based indexing, dig->info # CTF does not populate some entries so we use .get here - pos_order = hpi_result.get('order', np.arange(1, len(hpi_dig) + 1)) - 1 - used = hpi_result.get('used', np.arange(len(hpi_dig))) - dist_limit = hpi_result.get('dist_limit', 0.005) - good_limit = hpi_result.get('good_limit', 0.98) - goodness = hpi_result.get('goodness', np.ones(len(hpi_dig))) + pos_order = hpi_result.get("order", np.arange(1, len(hpi_dig) + 1)) - 1 + used = hpi_result.get("used", np.arange(len(hpi_dig))) + dist_limit = hpi_result.get("dist_limit", 0.005) + good_limit = hpi_result.get("good_limit", 0.98) + goodness = hpi_result.get("goodness", np.ones(len(hpi_dig))) # this shouldn't happen, eventually we could add the transforms # necessary to put it in head coords - if not all(d['coord_frame'] == FIFF.FIFFV_COORD_HEAD for d in hpi_dig): - raise RuntimeError('cHPI coordinate frame incorrect') + if not all(d["coord_frame"] == FIFF.FIFFV_COORD_HEAD for d in hpi_dig): + raise RuntimeError("cHPI coordinate frame incorrect") # Give the user some info - logger.info('HPIFIT: %s coils digitized in order %s' - % (len(pos_order), ' '.join(str(o + 1) for o in pos_order))) - logger.debug('HPIFIT: %s coils accepted: %s' - % (len(used), ' '.join(str(h) for h in used))) - hpi_rrs = np.array([d['r'] for d in hpi_dig])[pos_order] + logger.info( + "HPIFIT: %s coils digitized in order %s" + % (len(pos_order), " ".join(str(o + 1) for o in pos_order)) + ) + 
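# Illustrative sketch (coordinates invented, not from the patched code):
# _get_hpi_initial_fit goes on to compare the digitized coil positions with
# the refitted ones and accepts them when the summed error stays below
# dist_limit per coil, as in the consistency check a few lines below.
import numpy as np

hpi_rrs = np.array([[0.08, 0.0, 0.06], [0.0, 0.09, 0.06], [-0.08, 0.0, 0.06]])
hpi_rrs_fit = hpi_rrs + 1e-3             # pretend each fit is ~1.7 mm off
errors = np.linalg.norm(hpi_rrs - hpi_rrs_fit, axis=1)
dist_limit = 0.005                       # 5 mm, the default shown above
print(errors.sum() < len(errors) * dist_limit)  # -> True: consistency OK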
logger.debug( + "HPIFIT: %s coils accepted: %s" % (len(used), " ".join(str(h) for h in used)) + ) + hpi_rrs = np.array([d["r"] for d in hpi_dig])[pos_order] assert len(hpi_rrs) >= 3 # Fitting errors - hpi_rrs_fit = sorted([d for d in info['hpi_results'][-1]['dig_points']], - key=lambda x: x['ident']) - hpi_rrs_fit = np.array([d['r'] for d in hpi_rrs_fit]) + hpi_rrs_fit = sorted( + [d for d in info["hpi_results"][-1]["dig_points"]], key=lambda x: x["ident"] + ) + hpi_rrs_fit = np.array([d["r"] for d in hpi_rrs_fit]) # hpi_result['dig_points'] are in FIFFV_COORD_UNKNOWN coords, but this # is probably a misnomer because it should be FIFFV_COORD_DEVICE for this # to work - assert hpi_result['coord_trans']['to'] == FIFF.FIFFV_COORD_HEAD - hpi_rrs_fit = apply_trans(hpi_result['coord_trans']['trans'], hpi_rrs_fit) - if 'moments' in hpi_result: - logger.debug('Hpi coil moments (%d %d):' - % hpi_result['moments'].shape[::-1]) - for moment in hpi_result['moments']: + assert hpi_result["coord_trans"]["to"] == FIFF.FIFFV_COORD_HEAD + hpi_rrs_fit = apply_trans(hpi_result["coord_trans"]["trans"], hpi_rrs_fit) + if "moments" in hpi_result: + logger.debug("Hpi coil moments (%d %d):" % hpi_result["moments"].shape[::-1]) + for moment in hpi_result["moments"]: logger.debug("%g %g %g" % tuple(moment)) errors = np.linalg.norm(hpi_rrs - hpi_rrs_fit, axis=1) - logger.debug('HPIFIT errors: %s mm.' - % ', '.join('%0.1f' % (1000. * e) for e in errors)) + logger.debug( + "HPIFIT errors: %s mm." % ", ".join("%0.1f" % (1000.0 * e) for e in errors) + ) if errors.sum() < len(errors) * dist_limit: - logger.info('HPI consistency of isotrak and hpifit is OK.') + logger.info("HPI consistency of isotrak and hpifit is OK.") elif not adjust and (len(used) == len(hpi_dig)): - warn('HPI consistency of isotrak and hpifit is poor.') + warn("HPI consistency of isotrak and hpifit is poor.") else: # adjust HPI coil locations using the hpifit transformation for hi, (err, r_fit) in enumerate(zip(errors, hpi_rrs_fit)): @@ -430,24 +476,33 @@ def _get_hpi_initial_fit(info, adjust=False, verbose=None): d = 1000 * err if not adjust: if err >= dist_limit: - warn('Discrepancy of HPI coil %d isotrak and hpifit is ' - '%.1f mm!' % (hi + 1, d)) + warn( + "Discrepancy of HPI coil %d isotrak and hpifit is " + "%.1f mm!" % (hi + 1, d) + ) elif hi + 1 not in used: if goodness[hi] >= good_limit: - logger.info('Note: HPI coil %d isotrak is adjusted by ' - '%.1f mm!' % (hi + 1, d)) + logger.info( + "Note: HPI coil %d isotrak is adjusted by " + "%.1f mm!" % (hi + 1, d) + ) hpi_rrs[hi] = r_fit else: - warn('Discrepancy of HPI coil %d isotrak and hpifit of ' - '%.1f mm was not adjusted!' % (hi + 1, d)) - logger.debug('HP fitting limits: err = %.1f mm, gval = %.3f.' - % (1000 * dist_limit, good_limit)) + warn( + "Discrepancy of HPI coil %d isotrak and hpifit of " + "%.1f mm was not adjusted!" % (hi + 1, d) + ) + logger.debug( + "HP fitting limits: err = %.1f mm, gval = %.3f." 
+ % (1000 * dist_limit, good_limit) + ) return hpi_rrs.astype(float) -def _magnetic_dipole_objective(x, B, B2, coils, whitener, too_close, - return_moment=False): +def _magnetic_dipole_objective( + x, B, B2, coils, whitener, too_close, return_moment=False +): """Project data onto right eigenvectors of whitened forward.""" fwd = _magnetic_dipole_field_vec(x[np.newaxis], coils, too_close) out, u, s, one = _magnetic_dipole_delta(fwd, whitener, B, B2) @@ -478,22 +533,27 @@ def _magnetic_dipole_delta_multi(whitened_fwd_svd, B, B2): def _fit_magnetic_dipole(B_orig, x0, too_close, whitener, coils, guesses): """Fit a single bit of data (x0 = pos).""" from scipy.optimize import fmin_cobyla + B = np.dot(whitener, B_orig) B2 = np.dot(B, B) - objective = partial(_magnetic_dipole_objective, B=B, B2=B2, - coils=coils, whitener=whitener, - too_close=too_close) + objective = partial( + _magnetic_dipole_objective, + B=B, + B2=B2, + coils=coils, + whitener=whitener, + too_close=too_close, + ) if guesses is not None: res0 = objective(x0) - res = _magnetic_dipole_delta_multi( - guesses['whitened_fwd_svd'], B, B2) - assert res.shape == (guesses['rr'].shape[0],) + res = _magnetic_dipole_delta_multi(guesses["whitened_fwd_svd"], B, B2) + assert res.shape == (guesses["rr"].shape[0],) idx = np.argmin(res) if res[idx] < res0: - x0 = guesses['rr'][idx] + x0 = guesses["rr"][idx] x = fmin_cobyla(objective, x0, (), rhobeg=1e-3, rhoend=1e-5, disp=False) gof, moment = objective(x, return_moment=True) - gof = 1. - gof / B2 + gof = 1.0 - gof / B2 return x, gof, moment @@ -515,7 +575,7 @@ def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs): # XXX someday we could choose to weight these points by their goodness # of fit somehow. quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0] - gof = 1. - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom + gof = 1.0 - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom return quat, gof @@ -534,7 +594,7 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, bias=True): # equivalent g values. To avoid this, heavily penalize # large rotations. rotation = _angle_between_quats(this_quat[:3], np.zeros(3)) - check_g = g * max(1. - rotation / np.pi, 0) ** 0.25 + check_g = g * max(1.0 - rotation / np.pi, 0) ** 0.25 else: check_g = g if check_g > best_g: @@ -549,61 +609,77 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, bias=True): @verbose -def _setup_hpi_amplitude_fitting(info, t_window, remove_aliased=False, - ext_order=1, allow_empty=False, verbose=None): +def _setup_hpi_amplitude_fitting( + info, t_window, remove_aliased=False, ext_order=1, allow_empty=False, verbose=None +): """Generate HPI structure for HPI localization.""" # grab basic info. - on_missing = 'raise' if not allow_empty else 'ignore' + on_missing = "raise" if not allow_empty else "ignore" hpi_freqs, hpi_pick, hpi_ons = get_chpi_info(info, on_missing=on_missing) - _validate_type(t_window, (str, 'numeric'), 't_window') - if info['line_freq'] is not None: - line_freqs = np.arange(info['line_freq'], info['sfreq'] / 3., - info['line_freq']) + _validate_type(t_window, (str, "numeric"), "t_window") + if info["line_freq"] is not None: + line_freqs = np.arange( + info["line_freq"], info["sfreq"] / 3.0, info["line_freq"] + ) else: line_freqs = np.zeros([0]) - logger.info('Line interference frequencies: %s Hz' - % ' '.join(['%d' % lf for lf in line_freqs])) + logger.info( + "Line interference frequencies: %s Hz" + % " ".join(["%d" % lf for lf in line_freqs]) + ) # worry about resampled/filtered data. 
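# Illustrative sketch (line frequency and sfreq assumed): the interference
# grid built above keeps every power-line harmonic below sfreq / 3.
import numpy as np

line_freq, sfreq = 60.0, 1000.0
line_freqs = np.arange(line_freq, sfreq / 3.0, line_freq)
print(line_freqs)  # -> [ 60. 120. 180. 240. 300.]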
# What to do e.g. if Raw has been resampled and some of our # HPI freqs would now be aliased - highest = info.get('lowpass') - highest = info['sfreq'] / 2. if highest is None else highest + highest = info.get("lowpass") + highest = info["sfreq"] / 2.0 if highest is None else highest keepers = hpi_freqs <= highest if remove_aliased: hpi_freqs = hpi_freqs[keepers] hpi_ons = hpi_ons[keepers] elif not keepers.all(): - raise RuntimeError('Found HPI frequencies %s above the lowpass ' - '(or Nyquist) frequency %0.1f' - % (hpi_freqs[~keepers].tolist(), highest)) + raise RuntimeError( + "Found HPI frequencies %s above the lowpass " + "(or Nyquist) frequency %0.1f" % (hpi_freqs[~keepers].tolist(), highest) + ) # calculate optimal window length. if isinstance(t_window, str): - _check_option('t_window', t_window, ('auto',), extra='if a string') + _check_option("t_window", t_window, ("auto",), extra="if a string") if len(hpi_freqs): all_freqs = np.concatenate((hpi_freqs, line_freqs)) delta_freqs = np.diff(np.unique(all_freqs)) - t_window = max(5. / all_freqs.min(), 1. / delta_freqs.min()) + t_window = max(5.0 / all_freqs.min(), 1.0 / delta_freqs.min()) else: t_window = 0.2 t_window = float(t_window) if t_window <= 0: - raise ValueError('t_window (%s) must be > 0' % (t_window,)) - logger.info('Using time window: %0.1f ms' % (1000 * t_window,)) - window_nsamp = np.rint(t_window * info['sfreq']).astype(int) - model = _setup_hpi_glm(hpi_freqs, line_freqs, info['sfreq'], window_nsamp) + raise ValueError("t_window (%s) must be > 0" % (t_window,)) + logger.info("Using time window: %0.1f ms" % (1000 * t_window,)) + window_nsamp = np.rint(t_window * info["sfreq"]).astype(int) + model = _setup_hpi_glm(hpi_freqs, line_freqs, info["sfreq"], window_nsamp) inv_model = np.linalg.pinv(model) inv_model_reord = _reorder_inv_model(inv_model, len(hpi_freqs)) proj, proj_op, meg_picks = _setup_ext_proj(info, ext_order) # include mag and grad picks separately, for SNR computations - mag_picks = _picks_to_idx(info, 'mag', allow_empty=True) - grad_picks = _picks_to_idx(info, 'grad', allow_empty=True) + mag_picks = _picks_to_idx(info, "mag", allow_empty=True) + grad_picks = _picks_to_idx(info, "grad", allow_empty=True) # Set up magnetic dipole fits hpi = dict( - meg_picks=meg_picks, mag_picks=mag_picks, grad_picks=grad_picks, - hpi_pick=hpi_pick, model=model, inv_model=inv_model, t_window=t_window, - inv_model_reord=inv_model_reord, on=hpi_ons, n_window=window_nsamp, - proj=proj, proj_op=proj_op, freqs=hpi_freqs, line_freqs=line_freqs) + meg_picks=meg_picks, + mag_picks=mag_picks, + grad_picks=grad_picks, + hpi_pick=hpi_pick, + model=model, + inv_model=inv_model, + t_window=t_window, + inv_model_reord=inv_model_reord, + on=hpi_ons, + n_window=window_nsamp, + proj=proj, + proj_op=proj_op, + freqs=hpi_freqs, + line_freqs=line_freqs, + ) return hpi @@ -613,9 +689,14 @@ def _setup_hpi_glm(hpi_freqs, line_freqs, sfreq, window_nsamp): radians_per_sec = 2 * np.pi * np.arange(window_nsamp, dtype=float) / sfreq f_t = hpi_freqs[np.newaxis, :] * radians_per_sec[:, np.newaxis] l_t = line_freqs[np.newaxis, :] * radians_per_sec[:, np.newaxis] - model = [np.sin(f_t), np.cos(f_t), # hpi freqs - np.sin(l_t), np.cos(l_t), # line freqs - slope, np.ones_like(slope)] # drift, DC + model = [ + np.sin(f_t), + np.cos(f_t), # hpi freqs + np.sin(l_t), + np.cos(l_t), # line freqs + slope, + np.ones_like(slope), + ] # drift, DC return np.hstack(model) @@ -628,34 +709,40 @@ def _reorder_inv_model(inv_model, n_freqs): def _setup_ext_proj(info, ext_order): 
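# Illustrative sketch (frequencies and window length assumed): the design
# matrix built by _setup_hpi_glm above has one sin/cos pair per cHPI and line
# frequency plus drift and DC columns; the pseudoinverse then recovers each
# sinusoid's amplitude.
import numpy as np

sfreq, n = 1000.0, 200
freqs = np.array([83.0, 143.0])                  # assumed cHPI freqs (Hz)
rads = 2 * np.pi * np.arange(n) / sfreq
f_t = freqs[np.newaxis, :] * rads[:, np.newaxis]
slope = np.linspace(-0.5, 0.5, n)[:, np.newaxis]
model = np.hstack([np.sin(f_t), np.cos(f_t), slope, np.ones_like(slope)])
data = 2.0 * np.sin(f_t[:, 0]) + 0.5             # one coil at 83 Hz, plus DC
coefs = np.linalg.pinv(model) @ data
print(np.round(coefs, 3))  # -> approximately [2, 0, 0, 0, 0, 0.5]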
from scipy import linalg - meg_picks = pick_types(info, meg=True, eeg=False, exclude='bads') + + meg_picks = pick_types(info, meg=True, eeg=False, exclude="bads") info = pick_info(_simplify_info(info), meg_picks) # makes a copy _, _, _, _, mag_or_fine = _get_mf_picks_fix_mags( - info, int_order=0, ext_order=ext_order, ignore_ref=True, - verbose='error') - mf_coils = _prep_mf_coils(info, verbose='error') + info, int_order=0, ext_order=ext_order, ignore_ref=True, verbose="error" + ) + mf_coils = _prep_mf_coils(info, verbose="error") ext = _sss_basis( - dict(origin=(0., 0., 0.), int_order=0, ext_order=ext_order), - mf_coils).T + dict(origin=(0.0, 0.0, 0.0), int_order=0, ext_order=ext_order), mf_coils + ).T out_removes = _regularize_out(0, 1, mag_or_fine, []) ext = ext[~np.in1d(np.arange(len(ext)), out_removes)] ext = linalg.orth(ext.T).T assert ext.shape[1] == len(meg_picks) proj = Projection( - kind=FIFF.FIFFV_PROJ_ITEM_HOMOG_FIELD, desc='SSS', active=False, - data=dict(data=ext, ncol=info['nchan'], col_names=info['ch_names'], - nrow=len(ext))) + kind=FIFF.FIFFV_PROJ_ITEM_HOMOG_FIELD, + desc="SSS", + active=False, + data=dict( + data=ext, ncol=info["nchan"], col_names=info["ch_names"], nrow=len(ext) + ), + ) with info._unlock(): - info['projs'] = [proj] + info["projs"] = [proj] proj_op, _ = setup_proj( - info, add_eeg_ref=False, activate=False, verbose=_verbose_safe_false()) + info, add_eeg_ref=False, activate=False, verbose=_verbose_safe_false() + ) assert proj_op.shape == (len(meg_picks),) * 2 return proj, proj_op, meg_picks def _time_prefix(fit_time): """Format log messages.""" - return (' t=%0.3f:' % fit_time).ljust(17) + return (" t=%0.3f:" % fit_time).ljust(17) def _fit_chpi_amplitudes(raw, time_sl, hpi, snr=False): @@ -672,32 +759,43 @@ def _fit_chpi_amplitudes(raw, time_sl, hpi, snr=False): # No need to detrend the data because our model has a DC term with use_log_level(False): # loads good channels - this_data = raw[hpi['meg_picks'], time_sl][0] + this_data = raw[hpi["meg_picks"], time_sl][0] # which HPI coils to use - if hpi['hpi_pick'] is not None: + if hpi["hpi_pick"] is not None: with use_log_level(False): # loads hpi_stim channel - chpi_data = raw[hpi['hpi_pick'], time_sl][0] + chpi_data = raw[hpi["hpi_pick"], time_sl][0] - ons = (np.round(chpi_data).astype(np.int64) & - hpi['on'][:, np.newaxis]).astype(bool) + ons = (np.round(chpi_data).astype(np.int64) & hpi["on"][:, np.newaxis]).astype( + bool + ) n_on = ons.all(axis=-1).sum(axis=0) if not (n_on >= 3).all(): return None if snr: return _fast_fit_snr( - this_data, len(hpi['freqs']), hpi['model'], hpi['inv_model'], - hpi['mag_picks'], hpi['grad_picks']) - return _fast_fit(this_data, hpi['proj_op'], len(hpi['freqs']), - hpi['model'], hpi['inv_model_reord']) + this_data, + len(hpi["freqs"]), + hpi["model"], + hpi["inv_model"], + hpi["mag_picks"], + hpi["grad_picks"], + ) + return _fast_fit( + this_data, + hpi["proj_op"], + len(hpi["freqs"]), + hpi["model"], + hpi["inv_model_reord"], + ) @jit() def _fast_fit(this_data, proj, n_freqs, model, inv_model_reord): # first or last window if this_data.shape[1] != model.shape[0]: - model = model[:this_data.shape[1]] + model = model[: this_data.shape[1]] inv_model_reord = _reorder_inv_model(np.linalg.pinv(model), n_freqs) proj_data = proj @ this_data X = inv_model_reord @ proj_data.T @@ -705,7 +803,7 @@ def _fast_fit(this_data, proj, n_freqs, model, inv_model_reord): sin_fit = np.zeros((n_freqs, X.shape[1])) for fi in range(n_freqs): # use SVD across all sensors to estimate the sinusoid 
phase - u, s, vt = np.linalg.svd(X[2 * fi:2 * fi + 2], full_matrices=False) + u, s, vt = np.linalg.svd(X[2 * fi : 2 * fi + 2], full_matrices=False) # the first component holds the predominant phase direction # (so ignore the second, effectively doing s[1] = 0): sin_fit[fi] = vt[0] * s[0] @@ -716,11 +814,11 @@ def _fast_fit(this_data, proj, n_freqs, model, inv_model_reord): def _fast_fit_snr(this_data, n_freqs, model, inv_model, mag_picks, grad_picks): # first or last window if this_data.shape[1] != model.shape[0]: - model = model[:this_data.shape[1]] + model = model[: this_data.shape[1]] inv_model = np.linalg.pinv(model) coefs = np.ascontiguousarray(inv_model) @ np.ascontiguousarray(this_data.T) # average sin & cos terms (special property of sinusoids: power=A²/2) - hpi_power = (coefs[:n_freqs] ** 2 + coefs[n_freqs:(2 * n_freqs)] ** 2) / 2 + hpi_power = (coefs[:n_freqs] ** 2 + coefs[n_freqs : (2 * n_freqs)] ** 2) / 2 resid = this_data - np.ascontiguousarray((model @ coefs).T) # can't use np.var(..., axis=1) with Numba, so do it manually: resid_mean = np.atleast_2d(resid.sum(axis=1) / resid.shape[1]).T @@ -741,59 +839,70 @@ def _fast_fit_snr(this_data, n_freqs, model, inv_model, mag_picks, grad_picks): def _check_chpi_param(chpi_, name): - if name == 'chpi_locs': + if name == "chpi_locs": want_ndims = dict(times=1, rrs=3, moments=3, gofs=2) extra_keys = list() else: - assert name == 'chpi_amplitudes' + assert name == "chpi_amplitudes" want_ndims = dict(times=1, slopes=3) - extra_keys = ['proj'] + extra_keys = ["proj"] _validate_type(chpi_, dict, name) want_keys = list(want_ndims.keys()) + extra_keys if set(want_keys).symmetric_difference(chpi_): - raise ValueError('%s must be a dict with entries %s, got %s' - % (name, want_keys, sorted(chpi_.keys()))) + raise ValueError( + "%s must be a dict with entries %s, got %s" + % (name, want_keys, sorted(chpi_.keys())) + ) n_times = None for key, want_ndim in want_ndims.items(): - key_str = '%s[%s]' % (name, key) + key_str = "%s[%s]" % (name, key) val = chpi_[key] _validate_type(val, np.ndarray, key_str) shape = val.shape if val.ndim != want_ndim: - raise ValueError('%s must have ndim=%d, got %d' - % (key_str, want_ndim, val.ndim)) - if n_times is None and key != 'proj': + raise ValueError( + "%s must have ndim=%d, got %d" % (key_str, want_ndim, val.ndim) + ) + if n_times is None and key != "proj": n_times = shape[0] - if n_times != shape[0] and key != 'proj': - raise ValueError('%s have inconsistent number of time ' - 'points in %s' % (name, want_keys)) - if name == 'chpi_locs': - n_coils = chpi_['rrs'].shape[1] - for key in ('gofs', 'moments'): + if n_times != shape[0] and key != "proj": + raise ValueError( + "%s have inconsistent number of time " + "points in %s" % (name, want_keys) + ) + if name == "chpi_locs": + n_coils = chpi_["rrs"].shape[1] + for key in ("gofs", "moments"): val = chpi_[key] if val.shape[1] != n_coils: - raise ValueError('chpi_locs["rrs"] had values for %d coils but' - ' chpi_locs["%s"] had values for %d coils' - % (n_coils, key, val.shape[1])) - for key in ('rrs', 'moments'): + raise ValueError( + 'chpi_locs["rrs"] had values for %d coils but' + ' chpi_locs["%s"] had values for %d coils' + % (n_coils, key, val.shape[1]) + ) + for key in ("rrs", "moments"): val = chpi_[key] if val.shape[2] != 3: - raise ValueError('chpi_locs["%s"].shape[2] must be 3, got ' - 'shape %s' % (key, shape)) + raise ValueError( + 'chpi_locs["%s"].shape[2] must be 3, got ' "shape %s" % (key, shape) + ) else: - assert name == 'chpi_amplitudes' - 
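# Illustrative sketch (amplitude and frequency assumed): _fast_fit_snr above
# turns sin/cos coefficients into power via the identity
# power(A * sin) = A**2 / 2, which this quick numeric check confirms.
import numpy as np

t = np.arange(10000) / 1000.0  # 10 s at 1 kHz (assumed)
a = 3.0
sig = a * np.sin(2 * np.pi * 7.0 * t)
print(np.isclose(np.mean(sig**2), a**2 / 2))  # -> True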
slopes, proj = chpi_['slopes'], chpi_['proj'] + assert name == "chpi_amplitudes" + slopes, proj = chpi_["slopes"], chpi_["proj"] _validate_type(proj, Projection, 'chpi_amplitudes["proj"]') - n_ch = len(proj['data']['col_names']) + n_ch = len(proj["data"]["col_names"]) if slopes.shape[0] != n_times or slopes.shape[2] != n_ch: - raise ValueError('slopes must have shape[0]==%d and shape[2]==%d,' - ' got shape %s' % (n_times, n_ch, slopes.shape)) + raise ValueError( + "slopes must have shape[0]==%d and shape[2]==%d," + " got shape %s" % (n_times, n_ch, slopes.shape) + ) @verbose -def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, - adjust_dig=False, verbose=None): +def compute_head_pos( + info, chpi_locs, dist_limit=0.005, gof_limit=0.98, adjust_dig=False, verbose=None +): """Compute time-varying head positions. Parameters @@ -825,29 +934,30 @@ def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, ----- .. versionadded:: 0.20 """ - _check_chpi_param(chpi_locs, 'chpi_locs') - _validate_type(info, Info, 'info') - hpi_dig_head_rrs = _get_hpi_initial_fit(info, adjust=adjust_dig, - verbose='error') + _check_chpi_param(chpi_locs, "chpi_locs") + _validate_type(info, Info, "info") + hpi_dig_head_rrs = _get_hpi_initial_fit(info, adjust=adjust_dig, verbose="error") n_coils = len(hpi_dig_head_rrs) - coil_dev_rrs = apply_trans(invert_transform(info['dev_head_t']), - hpi_dig_head_rrs) - dev_head_t = info['dev_head_t']['trans'] + coil_dev_rrs = apply_trans(invert_transform(info["dev_head_t"]), hpi_dig_head_rrs) + dev_head_t = info["dev_head_t"]["trans"] pos_0 = dev_head_t[:3, 3] - last = dict(quat_fit_time=-0.1, coil_dev_rrs=coil_dev_rrs, - quat=np.concatenate([rot_to_quat(dev_head_t[:3, :3]), - dev_head_t[:3, 3]])) + last = dict( + quat_fit_time=-0.1, + coil_dev_rrs=coil_dev_rrs, + quat=np.concatenate([rot_to_quat(dev_head_t[:3, :3]), dev_head_t[:3, 3]]), + ) del coil_dev_rrs quats = [] for fit_time, this_coil_dev_rrs, g_coils in zip( - *(chpi_locs[key] for key in ('times', 'rrs', 'gofs'))): + *(chpi_locs[key] for key in ("times", "rrs", "gofs")) + ): use_idx = np.where(g_coils >= gof_limit)[0] # # 1. Check number of good ones # if len(use_idx) < 3: - gofs = ', '.join(f"{g:0.2f}" for g in g_coils) + gofs = ", ".join(f"{g:0.2f}" for g in g_coils) warn( f"{_time_prefix(fit_time)}{len(use_idx)}/{n_coils} " "good HPI fits, cannot determine the transformation " @@ -861,7 +971,8 @@ def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, # positions) iteratively using different sets of coils. # this_quat, g, use_idx = _fit_chpi_quat_subset( - this_coil_dev_rrs, hpi_dig_head_rrs, use_idx) + this_coil_dev_rrs, hpi_dig_head_rrs, use_idx + ) # # 3. Stop if < 3 good @@ -873,64 +984,87 @@ def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, errs = np.linalg.norm(hpi_dig_head_rrs - est_coil_head_rrs, axis=1) n_good = ((g_coils >= gof_limit) & (errs < dist_limit)).sum() if n_good < 3: - warn(_time_prefix(fit_time) + '%s/%s good HPI fits, cannot ' - 'determine the transformation (%s mm/GOF)!' - % (n_good, n_coils, - ', '.join(f'{1000 * e:0.1f}::{g:0.2f}' - for e, g in zip(errs, g_coils)))) + warn( + _time_prefix(fit_time) + "%s/%s good HPI fits, cannot " + "determine the transformation (%s mm/GOF)!" + % ( + n_good, + n_coils, + ", ".join( + f"{1000 * e:0.1f}::{g:0.2f}" for e, g in zip(errs, g_coils) + ), + ) + ) continue # velocities, in device coords, of HPI coils - dt = fit_time - last['quat_fit_time'] - vs = tuple(1000. 
* np.linalg.norm(last['coil_dev_rrs'] - - this_coil_dev_rrs, axis=1) / dt) - logger.info(_time_prefix(fit_time) + - ('%s/%s good HPI fits, movements [mm/s] = ' + - ' / '.join(['% 8.1f'] * n_coils)) - % ((n_good, n_coils) + vs)) + dt = fit_time - last["quat_fit_time"] + vs = tuple( + 1000.0 + * np.linalg.norm(last["coil_dev_rrs"] - this_coil_dev_rrs, axis=1) + / dt + ) + logger.info( + _time_prefix(fit_time) + + ( + "%s/%s good HPI fits, movements [mm/s] = " + + " / ".join(["% 8.1f"] * n_coils) + ) + % ((n_good, n_coils) + vs) + ) # Log results # MaxFilter averages over a 200 ms window for display, but we don't for ii in range(n_coils): if ii in use_idx: - start, end = ' ', '/' + start, end = " ", "/" else: - start, end = '(', ')' - log_str = (' ' + start + - '{0:6.1f} {1:6.1f} {2:6.1f} / ' + - '{3:6.1f} {4:6.1f} {5:6.1f} / ' + - 'g = {6:0.3f} err = {7:4.1f} ' + - end) - vals = np.concatenate((1000 * hpi_dig_head_rrs[ii], - 1000 * est_coil_head_rrs[ii], - [g_coils[ii], 1000 * errs[ii]])) + start, end = "(", ")" + log_str = ( + " " + + start + + "{0:6.1f} {1:6.1f} {2:6.1f} / " + + "{3:6.1f} {4:6.1f} {5:6.1f} / " + + "g = {6:0.3f} err = {7:4.1f} " + + end + ) + vals = np.concatenate( + ( + 1000 * hpi_dig_head_rrs[ii], + 1000 * est_coil_head_rrs[ii], + [g_coils[ii], 1000 * errs[ii]], + ) + ) if len(use_idx) >= 3: if ii <= 2: - log_str += '{8:6.3f} {9:6.3f} {10:6.3f}' - vals = np.concatenate( - (vals, this_dev_head_t[ii, :3])) + log_str += "{8:6.3f} {9:6.3f} {10:6.3f}" + vals = np.concatenate((vals, this_dev_head_t[ii, :3])) elif ii == 3: - log_str += '{8:6.1f} {9:6.1f} {10:6.1f}' - vals = np.concatenate( - (vals, this_dev_head_t[:3, 3] * 1000.)) + log_str += "{8:6.1f} {9:6.1f} {10:6.1f}" + vals = np.concatenate((vals, this_dev_head_t[:3, 3] * 1000.0)) logger.debug(log_str.format(*vals)) # resulting errors in head coil positions - d = np.linalg.norm(last['quat'][3:] - this_quat[3:]) # m - r = _angle_between_quats(last['quat'][:3], this_quat[:3]) / dt + d = np.linalg.norm(last["quat"][3:] - this_quat[3:]) # m + r = _angle_between_quats(last["quat"][:3], this_quat[:3]) / dt v = d / dt # m/s d = 100 * np.linalg.norm(this_quat[3:] - pos_0) # dis from 1st - logger.debug(' #t = %0.3f, #e = %0.2f cm, #g = %0.3f, ' - '#v = %0.2f cm/s, #r = %0.2f rad/s, #d = %0.2f cm' - % (fit_time, 100 * errs.mean(), g, 100 * v, r, d)) - logger.debug(' #t = %0.3f, #q = %s ' - % (fit_time, ' '.join(map('{:8.5f}'.format, this_quat)))) - - quats.append(np.concatenate(([fit_time], this_quat, [g], - [errs[use_idx].mean()], [v]))) - last['quat_fit_time'] = fit_time - last['quat'] = this_quat - last['coil_dev_rrs'] = this_coil_dev_rrs + logger.debug( + " #t = %0.3f, #e = %0.2f cm, #g = %0.3f, " + "#v = %0.2f cm/s, #r = %0.2f rad/s, #d = %0.2f cm" + % (fit_time, 100 * errs.mean(), g, 100 * v, r, d) + ) + logger.debug( + " #t = %0.3f, #q = %s " + % (fit_time, " ".join(map("{:8.5f}".format, this_quat))) + ) + + quats.append( + np.concatenate(([fit_time], this_quat, [g], [errs[use_idx].mean()], [v])) + ) + last["quat_fit_time"] = fit_time + last["quat"] = this_quat + last["coil_dev_rrs"] = this_coil_dev_rrs quats = np.array(quats, np.float64) quats = np.zeros((0, 10)) if quats.size == 0 else quats return quats @@ -941,9 +1075,10 @@ def _fit_chpi_quat_subset(coil_dev_rrs, coil_head_rrs, use_idx): out_idx = use_idx.copy() if len(use_idx) > 3: # try dropping one (recursively) for di in range(len(use_idx)): - this_use_idx = list(use_idx[:di]) + list(use_idx[di + 1:]) + this_use_idx = list(use_idx[:di]) + list(use_idx[di + 1 :]) 
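# Illustrative sketch (row values invented): each row appended by
# compute_head_pos above follows MaxFilter's 10-column convention
# [t, q1, q2, q3, x, y, z, gof, err, v]; head_pos_to_trans_rot_t splits a
# full array the same way.
import numpy as np

row = np.array([12.0, 0.01, -0.02, 0.03,   # time (s), rotation quaternion
                0.001, 0.002, 0.110,       # device-to-head translation (m)
                0.99, 0.002, 0.001])       # gof, error (m), velocity (m/s)
t, quat, trans = row[0], row[1:4], row[4:7]
gof, err, vel = row[7], row[8], row[9]
print(t, trans)  # -> 12.0 [0.001 0.002 0.11 ]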
this_quat, this_g, this_use_idx = _fit_chpi_quat_subset( - coil_dev_rrs, coil_head_rrs, this_use_idx) + coil_dev_rrs, coil_head_rrs, this_use_idx + ) if this_g > g: quat, g, out_idx = this_quat, this_g, this_use_idx return quat, g, np.array(out_idx, int) @@ -956,8 +1091,9 @@ def _unit_quat_constraint(x): @verbose -def compute_chpi_snr(raw, t_step_min=0.01, t_window='auto', ext_order=1, - tmin=0, tmax=None, verbose=None): +def compute_chpi_snr( + raw, t_step_min=0.01, t_window="auto", ext_order=1, tmin=0, tmax=None, verbose=None +): """Compute time-varying estimates of cHPI SNR. Parameters @@ -988,13 +1124,15 @@ def compute_chpi_snr(raw, t_step_min=0.01, t_window='auto', ext_order=1, ----- .. versionadded:: 0.24 """ - return _compute_chpi_amp_or_snr(raw, t_step_min, t_window, ext_order, - tmin, tmax, verbose, snr=True) + return _compute_chpi_amp_or_snr( + raw, t_step_min, t_window, ext_order, tmin, tmax, verbose, snr=True + ) @verbose -def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto', - ext_order=1, tmin=0, tmax=None, verbose=None): +def compute_chpi_amplitudes( + raw, t_step_min=0.01, t_window="auto", ext_order=1, tmin=0, tmax=None, verbose=None +): """Compute time-varying cHPI amplitudes. Parameters @@ -1040,13 +1178,21 @@ def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto', .. versionadded:: 0.20 """ - return _compute_chpi_amp_or_snr(raw, t_step_min, t_window, ext_order, - tmin, tmax, verbose) + return _compute_chpi_amp_or_snr( + raw, t_step_min, t_window, ext_order, tmin, tmax, verbose + ) -def _compute_chpi_amp_or_snr(raw, t_step_min=0.01, t_window='auto', - ext_order=1, tmin=0, tmax=None, verbose=None, - snr=False): +def _compute_chpi_amp_or_snr( + raw, + t_step_min=0.01, + t_window="auto", + ext_order=1, + tmin=0, + tmax=None, + verbose=None, + snr=False, +): """Compute cHPI amplitude or SNR. See compute_chpi_amplitudes for parameter descriptions. One additional @@ -1055,42 +1201,44 @@ def _compute_chpi_amp_or_snr(raw, t_step_min=0.01, t_window='auto', """ hpi = _setup_hpi_amplitude_fitting(raw.info, t_window, ext_order=ext_order) tmin, tmax = raw._tmin_tmax_to_start_stop(tmin, tmax) - tmin = tmin / raw.info['sfreq'] - tmax = tmax / raw.info['sfreq'] - need_win = hpi['t_window'] / 2. - fit_idxs = raw.time_as_index(np.arange( - tmin + need_win, tmax, t_step_min), use_rounding=True) - logger.info('Fitting %d HPI coil locations at up to %s time points ' - '(%0.1f s duration)' - % (len(hpi['freqs']), len(fit_idxs), tmax - tmin)) + tmin = tmin / raw.info["sfreq"] + tmax = tmax / raw.info["sfreq"] + need_win = hpi["t_window"] / 2.0 + fit_idxs = raw.time_as_index( + np.arange(tmin + need_win, tmax, t_step_min), use_rounding=True + ) + logger.info( + "Fitting %d HPI coil locations at up to %s time points " + "(%0.1f s duration)" % (len(hpi["freqs"]), len(fit_idxs), tmax - tmin) + ) del tmin, tmax sin_fits = dict() - sin_fits['proj'] = hpi['proj'] - sin_fits['times'] = np.round(fit_idxs + raw.first_samp - - hpi['n_window'] / 2.) 
/ raw.info['sfreq'] - n_times = len(sin_fits['times']) - n_freqs = len(hpi['freqs']) - n_chans = len(sin_fits['proj']['data']['col_names']) + sin_fits["proj"] = hpi["proj"] + sin_fits["times"] = ( + np.round(fit_idxs + raw.first_samp - hpi["n_window"] / 2.0) / raw.info["sfreq"] + ) + n_times = len(sin_fits["times"]) + n_freqs = len(hpi["freqs"]) + n_chans = len(sin_fits["proj"]["data"]["col_names"]) if snr: - del sin_fits['proj'] - sin_fits['freqs'] = hpi['freqs'] + del sin_fits["proj"] + sin_fits["freqs"] = hpi["freqs"] ch_types = raw.get_channel_types() - grad_offset = 3 if 'mag' in ch_types else 0 - for ch_type in ('mag', 'grad'): + grad_offset = 3 if "mag" in ch_types else 0 + for ch_type in ("mag", "grad"): if ch_type in ch_types: - for key in ('snr', 'power', 'resid'): - cols = 1 if key == 'resid' else n_freqs - sin_fits[f'{ch_type}_{key}'] = np.empty((n_times, cols)) + for key in ("snr", "power", "resid"): + cols = 1 if key == "resid" else n_freqs + sin_fits[f"{ch_type}_{key}"] = np.empty((n_times, cols)) else: - sin_fits['slopes'] = np.empty((n_times, n_freqs, n_chans)) + sin_fits["slopes"] = np.empty((n_times, n_freqs, n_chans)) message = f"cHPI {'SNRs' if snr else 'amplitudes'}" for mi, midpt in enumerate(ProgressBar(fit_idxs, mesg=message)): # # 0. determine samples to fit. # - time_sl = midpt - hpi['n_window'] // 2 - time_sl = slice(max(time_sl, 0), - min(time_sl + hpi['n_window'], len(raw.times))) + time_sl = midpt - hpi["n_window"] // 2 + time_sl = slice(max(time_sl, 0), min(time_sl + hpi["n_window"], len(raw.times))) # # 1. Fit amplitudes for each channel from each of the N sinusoids @@ -1103,22 +1251,28 @@ def _compute_chpi_amp_or_snr(raw, t_step_min=0.01, t_window='auto', # is returned as a (tiled) vector (again, because Numba) so that's # why below we take amps_or_snrs[0, 2] instead of [:, 2] ch_types = raw.get_channel_types() - if 'mag' in ch_types: - sin_fits['mag_snr'][mi] = amps_or_snrs[:, 0] # SNR - sin_fits['mag_power'][mi] = amps_or_snrs[:, 1] # mean power - sin_fits['mag_resid'][mi] = amps_or_snrs[0, 2] # mean resid - if 'grad' in ch_types: - sin_fits['grad_snr'][mi] = amps_or_snrs[:, grad_offset] - sin_fits['grad_power'][mi] = amps_or_snrs[:, grad_offset + 1] - sin_fits['grad_resid'][mi] = amps_or_snrs[0, grad_offset + 2] + if "mag" in ch_types: + sin_fits["mag_snr"][mi] = amps_or_snrs[:, 0] # SNR + sin_fits["mag_power"][mi] = amps_or_snrs[:, 1] # mean power + sin_fits["mag_resid"][mi] = amps_or_snrs[0, 2] # mean resid + if "grad" in ch_types: + sin_fits["grad_snr"][mi] = amps_or_snrs[:, grad_offset] + sin_fits["grad_power"][mi] = amps_or_snrs[:, grad_offset + 1] + sin_fits["grad_resid"][mi] = amps_or_snrs[0, grad_offset + 2] else: - sin_fits['slopes'][mi] = amps_or_snrs + sin_fits["slopes"][mi] = amps_or_snrs return sin_fits @verbose -def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', - adjust_dig=False, verbose=None): +def compute_chpi_locs( + info, + chpi_amplitudes, + t_step_max=1.0, + too_close="raise", + adjust_dig=False, + verbose=None, +): """Compute locations of each cHPI coils over time. Parameters @@ -1163,19 +1317,18 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', .. 
versionadded:: 0.20 """ # Set up magnetic dipole fits - _check_option('too_close', too_close, ['raise', 'warning', 'info']) - _check_chpi_param(chpi_amplitudes, 'chpi_amplitudes') - _validate_type(info, Info, 'info') + _check_option("too_close", too_close, ["raise", "warning", "info"]) + _check_chpi_param(chpi_amplitudes, "chpi_amplitudes") + _validate_type(info, Info, "info") sin_fits = chpi_amplitudes # use the old name below del chpi_amplitudes - proj = sin_fits['proj'] - meg_picks = pick_channels( - info['ch_names'], proj['data']['col_names'], ordered=True) + proj = sin_fits["proj"] + meg_picks = pick_channels(info["ch_names"], proj["data"]["col_names"], ordered=True) info = pick_info(info, meg_picks) # makes a copy with info._unlock(): - info['projs'] = [proj] + info["projs"] = [proj] del meg_picks, proj - meg_coils = _concatenate_coils(_create_meg_coils(info['chs'], 'accurate')) + meg_coils = _concatenate_coils(_create_meg_coils(info["chs"], "accurate")) # Set up external model for interference suppression safe_false = _verbose_safe_false() @@ -1184,10 +1337,13 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', # Make some location guesses (1 cm grid) R = np.linalg.norm(meg_coils[0], axis=1).min() - guesses = _make_guesses(dict(R=R, r0=np.zeros(3)), 0.01, 0., 0.005, - verbose=safe_false)[0]['rr'] - logger.info('Computing %d HPI location guesses (1 cm grid in a %0.1f cm ' - 'sphere)' % (len(guesses), R * 100)) + guesses = _make_guesses( + dict(R=R, r0=np.zeros(3)), 0.01, 0.0, 0.005, verbose=safe_false + )[0]["rr"] + logger.info( + "Computing %d HPI location guesses (1 cm grid in a %0.1f cm " + "sphere)" % (len(guesses), R * 100) + ) fwd = _magnetic_dipole_field_vec(guesses, meg_coils, too_close) fwd = np.dot(fwd, whitener.T) fwd.shape = (guesses.shape[0], 3, -1) @@ -1195,51 +1351,58 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', guesses = dict(rr=guesses, whitened_fwd_svd=fwd) del fwd, R - iter_ = list(zip(sin_fits['times'], sin_fits['slopes'])) + iter_ = list(zip(sin_fits["times"], sin_fits["slopes"])) chpi_locs = dict(times=[], rrs=[], gofs=[], moments=[]) # setup last iteration structure hpi_dig_dev_rrs = apply_trans( - invert_transform(info['dev_head_t'])['trans'], - _get_hpi_initial_fit(info, adjust=adjust_dig)) - last = dict(sin_fit=None, coil_fit_time=sin_fits['times'][0] - 1, - coil_dev_rrs=hpi_dig_dev_rrs) + invert_transform(info["dev_head_t"])["trans"], + _get_hpi_initial_fit(info, adjust=adjust_dig), + ) + last = dict( + sin_fit=None, + coil_fit_time=sin_fits["times"][0] - 1, + coil_dev_rrs=hpi_dig_dev_rrs, + ) n_hpi = len(hpi_dig_dev_rrs) del hpi_dig_dev_rrs - for fit_time, sin_fit in ProgressBar(iter_, mesg='cHPI locations '): + for fit_time, sin_fit in ProgressBar(iter_, mesg="cHPI locations "): # skip this window if bad if not np.isfinite(sin_fit).all(): continue # check if data has sufficiently changed - if last['sin_fit'] is not None: # first iteration + if last["sin_fit"] is not None: # first iteration corrs = np.array( - [np.corrcoef(s, lst)[0, 1] - for s, lst in zip(sin_fit, last['sin_fit'])]) + [np.corrcoef(s, lst)[0, 1] for s, lst in zip(sin_fit, last["sin_fit"])] + ) corrs *= corrs # check to see if we need to continue - if fit_time - last['coil_fit_time'] <= t_step_max - 1e-7 and \ - (corrs > 0.98).sum() >= 3: + if ( + fit_time - last["coil_fit_time"] <= t_step_max - 1e-7 + and (corrs > 0.98).sum() >= 3 + ): # don't need to refit data continue # update 'last' sin_fit *before* inplace sign mult - 
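# Illustrative sketch (random stand-ins for real amplitude vectors): the
# refit-skipping test above keeps the previous coil fit whenever the squared
# correlation between the current and last sinusoid amplitudes stays high
# for at least 3 coils.
import numpy as np

rng = np.random.default_rng(0)
sin_fit = rng.normal(size=(5, 306))        # 5 coils x 306 channels (assumed)
last_fit = sin_fit + 0.01 * rng.normal(size=sin_fit.shape)
corrs = np.array([np.corrcoef(s, l)[0, 1] for s, l in zip(sin_fit, last_fit)])
corrs *= corrs
print((corrs > 0.98).sum() >= 3)  # -> True: barely changed, skip the refit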
last['sin_fit'] = sin_fit.copy() + last["sin_fit"] = sin_fit.copy() # # 2. Fit magnetic dipole for each coil to obtain coil positions # in device coordinates # - coil_fits = [_fit_magnetic_dipole(f, x0, too_close, whitener, - meg_coils, guesses) - for f, x0 in zip(sin_fit, last['coil_dev_rrs'])] + coil_fits = [ + _fit_magnetic_dipole(f, x0, too_close, whitener, meg_coils, guesses) + for f, x0 in zip(sin_fit, last["coil_dev_rrs"]) + ] rrs, gofs, moments = zip(*coil_fits) - chpi_locs['times'].append(fit_time) - chpi_locs['rrs'].append(rrs) - chpi_locs['gofs'].append(gofs) - chpi_locs['moments'].append(moments) - last['coil_fit_time'] = fit_time - last['coil_dev_rrs'] = rrs - n_times = len(chpi_locs['times']) + chpi_locs["times"].append(fit_time) + chpi_locs["rrs"].append(rrs) + chpi_locs["gofs"].append(gofs) + chpi_locs["moments"].append(moments) + last["coil_fit_time"] = fit_time + last["coil_dev_rrs"] = rrs + n_times = len(chpi_locs["times"]) shapes = dict( times=(n_times,), rrs=(n_times, n_hpi, 3), @@ -1254,17 +1417,32 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', def _chpi_locs_to_times_dig(chpi_locs): """Reformat chpi_locs as list of dig (dict).""" dig = list() - for rrs, gofs in zip(*(chpi_locs[key] for key in ('rrs', 'gofs'))): - dig.append([{'r': rr, 'ident': idx, 'gof': gof, - 'kind': FIFF.FIFFV_POINT_HPI, - 'coord_frame': FIFF.FIFFV_COORD_DEVICE} - for idx, (rr, gof) in enumerate(zip(rrs, gofs), 1)]) - return chpi_locs['times'], dig + for rrs, gofs in zip(*(chpi_locs[key] for key in ("rrs", "gofs"))): + dig.append( + [ + { + "r": rr, + "ident": idx, + "gof": gof, + "kind": FIFF.FIFFV_POINT_HPI, + "coord_frame": FIFF.FIFFV_COORD_DEVICE, + } + for idx, (rr, gof) in enumerate(zip(rrs, gofs), 1) + ] + ) + return chpi_locs["times"], dig @verbose -def filter_chpi(raw, include_line=True, t_step=0.01, t_window='auto', - ext_order=1, allow_line_only=False, verbose=None): +def filter_chpi( + raw, + include_line=True, + t_step=0.01, + t_window="auto", + ext_order=1, + allow_line_only=False, + verbose=None, +): """Remove cHPI and line noise from data. .. note:: This function will only work properly if cHPI was on @@ -1301,73 +1479,80 @@ def filter_chpi(raw, include_line=True, t_step=0.01, t_window='auto', .. 
versionadded:: 0.12 """ - _validate_type(raw, BaseRaw, 'raw') + _validate_type(raw, BaseRaw, "raw") if not raw.preload: - raise RuntimeError('raw data must be preloaded') + raise RuntimeError("raw data must be preloaded") t_step = float(t_step) if t_step <= 0: - raise ValueError('t_step (%s) must be > 0' % (t_step,)) - n_step = int(np.ceil(t_step * raw.info['sfreq'])) - if include_line and raw.info['line_freq'] is None: - raise RuntimeError('include_line=True but raw.info["line_freq"] is ' - 'None, consider setting it to the line frequency') + raise ValueError("t_step (%s) must be > 0" % (t_step,)) + n_step = int(np.ceil(t_step * raw.info["sfreq"])) + if include_line and raw.info["line_freq"] is None: + raise RuntimeError( + 'include_line=True but raw.info["line_freq"] is ' + "None, consider setting it to the line frequency" + ) hpi = _setup_hpi_amplitude_fitting( - raw.info, t_window, remove_aliased=True, ext_order=ext_order, - allow_empty=allow_line_only, verbose=_verbose_safe_false()) + raw.info, + t_window, + remove_aliased=True, + ext_order=ext_order, + allow_empty=allow_line_only, + verbose=_verbose_safe_false(), + ) - fit_idxs = np.arange(0, len(raw.times) + hpi['n_window'] // 2, n_step) - n_freqs = len(hpi['freqs']) + fit_idxs = np.arange(0, len(raw.times) + hpi["n_window"] // 2, n_step) + n_freqs = len(hpi["freqs"]) n_remove = 2 * n_freqs meg_picks = pick_types(raw.info, meg=True, exclude=()) # filter all chs n_times = len(raw.times) - msg = 'Removing %s cHPI' % n_freqs + msg = "Removing %s cHPI" % n_freqs if include_line: - n_remove += 2 * len(hpi['line_freqs']) - msg += ' and %s line harmonic' % len(hpi['line_freqs']) - msg += ' frequencies from %s MEG channels' % len(meg_picks) + n_remove += 2 * len(hpi["line_freqs"]) + msg += " and %s line harmonic" % len(hpi["line_freqs"]) + msg += " frequencies from %s MEG channels" % len(meg_picks) - recon = np.dot(hpi['model'][:, :n_remove], hpi['inv_model'][:n_remove]).T + recon = np.dot(hpi["model"][:, :n_remove], hpi["inv_model"][:n_remove]).T logger.info(msg) chunks = list() # the chunks to subtract last_endpt = 0 - pb = ProgressBar(fit_idxs, mesg='Filtering') + pb = ProgressBar(fit_idxs, mesg="Filtering") for ii, midpt in enumerate(pb): - left_edge = midpt - hpi['n_window'] // 2 - time_sl = slice(max(left_edge, 0), - min(left_edge + hpi['n_window'], len(raw.times))) + left_edge = midpt - hpi["n_window"] // 2 + time_sl = slice( + max(left_edge, 0), min(left_edge + hpi["n_window"], len(raw.times)) + ) this_len = time_sl.stop - time_sl.start - if this_len == hpi['n_window']: + if this_len == hpi["n_window"]: this_recon = recon else: # first or last window - model = hpi['model'][:this_len] + model = hpi["model"][:this_len] inv_model = np.linalg.pinv(model) this_recon = np.dot(model[:, :n_remove], inv_model[:n_remove]).T this_data = raw._data[meg_picks, time_sl] subt_pt = min(midpt + n_step, n_times) if last_endpt != subt_pt: - fit_left_edge = left_edge - time_sl.start + hpi['n_window'] // 2 - fit_sl = slice(fit_left_edge, - fit_left_edge + (subt_pt - last_endpt)) + fit_left_edge = left_edge - time_sl.start + hpi["n_window"] // 2 + fit_sl = slice(fit_left_edge, fit_left_edge + (subt_pt - last_endpt)) chunks.append((subt_pt, np.dot(this_data, this_recon[:, fit_sl]))) last_endpt = subt_pt # Consume (trailing) chunks that are now safe to remove because # our windows will no longer touch them if ii < len(fit_idxs) - 1: - next_left_edge = fit_idxs[ii + 1] - hpi['n_window'] // 2 + next_left_edge = fit_idxs[ii + 1] - hpi["n_window"] // 2 else: 
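# Illustrative sketch (single assumed cHPI frequency, synthetic data): the
# subtraction performed by filter_chpi projects each window onto the
# sinusoidal regressors and removes that component, leaving the rest intact.
import numpy as np

sfreq, n = 1000.0, 500
t = np.arange(n) / sfreq
model = np.column_stack([np.sin(2 * np.pi * 83.0 * t),
                         np.cos(2 * np.pi * 83.0 * t)])  # 83 Hz (assumed)
rng = np.random.default_rng(1)
data = 5 * model[:, 0] + rng.normal(size=n)
recon = model @ np.linalg.pinv(model)      # projection onto the cHPI subspace
cleaned = data - recon @ data
k = int(round(83 * n / sfreq))             # FFT bin nearest 83 Hz
print(abs(np.fft.rfft(cleaned))[k] < abs(np.fft.rfft(data))[k])  # -> True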
next_left_edge = np.inf while len(chunks) > 0 and chunks[0][0] <= next_left_edge: right_edge, chunk = chunks.pop(0) - raw._data[meg_picks, - right_edge - chunk.shape[1]:right_edge] -= chunk + raw._data[meg_picks, right_edge - chunk.shape[1] : right_edge] -= chunk return raw def _compute_good_distances(hpi_coil_dists, new_pos, dist_limit=0.005): """Compute good coils based on distances.""" from scipy.spatial.distance import cdist + these_dists = cdist(new_pos, new_pos) these_dists = np.abs(hpi_coil_dists - these_dists) # there is probably a better algorithm for finding the bad ones... @@ -1375,7 +1560,7 @@ def _compute_good_distances(hpi_coil_dists, new_pos, dist_limit=0.005): use_mask = np.ones(len(hpi_coil_dists), bool) while not good: d = these_dists[use_mask][:, use_mask] - d_bad = (d > dist_limit) + d_bad = d > dist_limit good = not d_bad.any() if not good: if use_mask.sum() == 2: @@ -1389,7 +1574,7 @@ def _compute_good_distances(hpi_coil_dists, new_pos, dist_limit=0.005): @verbose -def get_active_chpi(raw, *, on_missing='raise', verbose=None): +def get_active_chpi(raw, *, on_missing="raise", verbose=None): """Determine how many HPI coils were active for a time point. Parameters @@ -1412,10 +1597,14 @@ def get_active_chpi(raw, *, on_missing='raise', verbose=None): system, _ = _get_meg_system(raw.info) # check whether we have a neuromag system - if system not in ['122m', '306m']: - raise NotImplementedError(('Identifying active HPI channels' - ' is not implemented for other systems' - ' than neuromag.')) + if system not in ["122m", "306m"]: + raise NotImplementedError( + ( + "Identifying active HPI channels" + " is not implemented for other systems" + " than neuromag." + ) + ) # extract hpi info chpi_info = get_chpi_info(raw.info, on_missing=on_missing) if len(chpi_info[2]) == 0: diff --git a/mne/commands/mne_anonymize.py b/mne/commands/mne_anonymize.py index 7c858319265..d4b54000b78 100644 --- a/mne/commands/mne_anonymize.py +++ b/mne/commands/mne_anonymize.py @@ -19,7 +19,7 @@ import mne import os.path as op -ANONYMIZE_FILE_PREFIX = 'anon' +ANONYMIZE_FILE_PREFIX = "anon" def mne_anonymize(fif_fname, out_fname, keep_his, daysback, overwrite): @@ -49,8 +49,7 @@ def mne_anonymize(fif_fname, out_fname, keep_his, daysback, overwrite): dir_name = op.split(fif_fname)[0] if out_fname is None: fif_bname = op.basename(fif_fname) - out_fname = op.join(dir_name, - "{}-{}".format(ANONYMIZE_FILE_PREFIX, fif_bname)) + out_fname = op.join(dir_name, "{}-{}".format(ANONYMIZE_FILE_PREFIX, fif_bname)) elif not op.isabs(out_fname): out_fname = op.join(dir_name, out_fname) @@ -63,20 +62,48 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-f", "--file", type="string", dest="file", - help="Name of file to modify.", metavar="FILE", - default=None) - parser.add_option("-o", "--output", type="string", dest="output", - help="Name of anonymized output file." 
-                      "`anon-` prefix is added to FILE if not given",
-                      metavar="OUTFILE", default=None)
-    parser.add_option("--keep_his", dest="keep_his", action="store_true",
-                      help="Keep the HIS tag (not advised)", default=False)
-    parser.add_option("-d", "--daysback", type="int", dest="daysback",
-                      help="Move dates in file backwards by this many days.",
-                      metavar="N_DAYS", default=None)
-    parser.add_option("--overwrite", dest="overwrite", action="store_true",
-                      help="Overwrite input file.", default=False)
+    parser.add_option(
+        "-f",
+        "--file",
+        type="string",
+        dest="file",
+        help="Name of file to modify.",
+        metavar="FILE",
+        default=None,
+    )
+    parser.add_option(
+        "-o",
+        "--output",
+        type="string",
+        dest="output",
+        help="Name of anonymized output file."
+        "`anon-` prefix is added to FILE if not given",
+        metavar="OUTFILE",
+        default=None,
+    )
+    parser.add_option(
+        "--keep_his",
+        dest="keep_his",
+        action="store_true",
+        help="Keep the HIS tag (not advised)",
+        default=False,
+    )
+    parser.add_option(
+        "-d",
+        "--daysback",
+        type="int",
+        dest="daysback",
+        help="Move dates in file backwards by this many days.",
+        metavar="N_DAYS",
+        default=None,
+    )
+    parser.add_option(
+        "--overwrite",
+        dest="overwrite",
+        action="store_true",
+        help="Overwrite input file.",
+        default=False,
+    )

     options, args = parser.parse_args()
     if options.file is None:
@@ -88,12 +115,12 @@ def run():
     keep_his = options.keep_his
     daysback = options.daysback
     overwrite = options.overwrite
-    if not fname.endswith('.fif'):
-        raise ValueError('%s does not seem to be a .fif file.' % fname)
+    if not fname.endswith(".fif"):
+        raise ValueError("%s does not seem to be a .fif file." % fname)

     mne_anonymize(fname, out_fname, keep_his, daysback, overwrite)


-is_main = (__name__ == '__main__')
+is_main = __name__ == "__main__"
 if is_main:
     run()

diff --git a/mne/commands/mne_browse_raw.py b/mne/commands/mne_browse_raw.py
index 95b4381cc7b..9c338518e85 100644
--- a/mne/commands/mne_browse_raw.py
+++ b/mne/commands/mne_browse_raw.py
@@ -24,57 +24,114 @@ def run():
     from mne.commands.utils import get_optparser, _add_verbose_flag
     from mne.viz import _RAW_CLIP_DEF

-    parser = get_optparser(__file__, usage='usage: %prog raw [options]')
-
-    parser.add_option("--raw", dest="raw_in",
-                      help="Input raw FIF file (can also be specified "
-                      "directly as an argument without the --raw prefix)",
-                      metavar="FILE")
-    parser.add_option("--proj", dest="proj_in",
-                      help="Projector file", metavar="FILE",
-                      default='')
-    parser.add_option("--projoff", dest="proj_off",
-                      help="Disable all projectors",
-                      default=False, action="store_true")
-    parser.add_option("--eve", dest="eve_in",
-                      help="Events file", metavar="FILE",
-                      default='')
-    parser.add_option("-d", "--duration", dest="duration", type="float",
-                      help="Time window for plotting (s)",
-                      default=10.0)
-    parser.add_option("-t", "--start", dest="start", type="float",
-                      help="Initial start time for plotting",
-                      default=0.0)
-    parser.add_option("-n", "--n_channels", dest="n_channels", type="int",
-                      help="Number of channels to plot at a time",
-                      default=20)
-    parser.add_option("-o", "--order", dest="group_by",
-                      help="Order to use for grouping during plotting "
-                      "('type' or 'original')", default='type')
-    parser.add_option("-p", "--preload", dest="preload",
-                      help="Preload raw data (for faster navigaton)",
-                      default=False, action="store_true")
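# Illustrative sketch (option names here are examples, not part of any mne
# command): all of these mne_* scripts follow the same optparse pattern of
# declaring options, parsing, and handing the values to a library call.
from optparse import OptionParser

parser = OptionParser(usage="usage: %prog raw [options]")
parser.add_option("-d", "--duration", type="float", dest="duration",
                  help="Time window for plotting (s)", default=10.0)
options, args = parser.parse_args(["--duration", "5"])
print(options.duration)  # -> 5.0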
parser.add_option("-s", "--show_options", dest="show_options", - help="Show projection options dialog", - default=False) - parser.add_option("--allowmaxshield", dest="maxshield", - help="Allow loading MaxShield processed data", - action="http://github.com/store_true") - parser.add_option("--highpass", dest="highpass", type="float", - help="Display high-pass filter corner frequency", - default=-1) - parser.add_option("--lowpass", dest="lowpass", type="float", - help="Display low-pass filter corner frequency", - default=-1) - parser.add_option("--filtorder", dest="filtorder", type="int", - help="Display filtering IIR order (or 0 to use FIR)", - default=4) - parser.add_option("--clipping", dest="clipping", - help="Enable trace clipping mode, either 'clamp' or " - "'transparent'", default=_RAW_CLIP_DEF) - parser.add_option("--filterchpi", dest="filterchpi", - help="Enable filtering cHPI signals.", default=None, - action="http://github.com/store_true") + parser = get_optparser(__file__, usage="usage: %prog raw [options]") + + parser.add_option( + "--raw", + dest="raw_in", + help="Input raw FIF file (can also be specified " + "directly as an argument without the --raw prefix)", + metavar="FILE", + ) + parser.add_option( + "--proj", dest="proj_in", help="Projector file", metavar="FILE", default="" + ) + parser.add_option( + "--projoff", + dest="proj_off", + help="Disable all projectors", + default=False, + action="http://github.com/store_true", + ) + parser.add_option( + "--eve", dest="eve_in", help="Events file", metavar="FILE", default="" + ) + parser.add_option( + "-d", + "--duration", + dest="duration", + type="float", + help="Time window for plotting (s)", + default=10.0, + ) + parser.add_option( + "-t", + "--start", + dest="start", + type="float", + help="Initial start time for plotting", + default=0.0, + ) + parser.add_option( + "-n", + "--n_channels", + dest="n_channels", + type="int", + help="Number of channels to plot at a time", + default=20, + ) + parser.add_option( + "-o", + "--order", + dest="group_by", + help="Order to use for grouping during plotting " "('type' or 'original')", + default="type", + ) + parser.add_option( + "-p", + "--preload", + dest="preload", + help="Preload raw data (for faster navigaton)", + default=False, + action="http://github.com/store_true", + ) + parser.add_option( + "-s", + "--show_options", + dest="show_options", + help="Show projection options dialog", + default=False, + ) + parser.add_option( + "--allowmaxshield", + dest="maxshield", + help="Allow loading MaxShield processed data", + action="http://github.com/store_true", + ) + parser.add_option( + "--highpass", + dest="highpass", + type="float", + help="Display high-pass filter corner frequency", + default=-1, + ) + parser.add_option( + "--lowpass", + dest="lowpass", + type="float", + help="Display low-pass filter corner frequency", + default=-1, + ) + parser.add_option( + "--filtorder", + dest="filtorder", + type="int", + help="Display filtering IIR order (or 0 to use FIR)", + default=4, + ) + parser.add_option( + "--clipping", + dest="clipping", + help="Enable trace clipping mode, either 'clamp' or " "'transparent'", + default=_RAW_CLIP_DEF, + ) + parser.add_option( + "--filterchpi", + dest="filterchpi", + help="Enable filtering cHPI signals.", + default=None, + action="http://github.com/store_true", + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -97,7 +154,7 @@ def run(): filtorder = options.filtorder clipping = 
options.clipping if isinstance(clipping, str): - if clipping.lower() == 'none': + if clipping.lower() == "none": clipping = None else: try: @@ -113,11 +170,11 @@ def run(): kwargs = dict(preload=preload) if maxshield: - kwargs.update(allow_maxshield='yes') + kwargs.update(allow_maxshield="yes") raw = mne.io.read_raw(raw_in, **kwargs) if len(proj_in) > 0: projs = mne.read_proj(proj_in) - raw.info['projs'] = projs + raw.info["projs"] = projs if len(eve_in) > 0: events = mne.read_events(eve_in) else: @@ -125,17 +182,27 @@ def run(): if filterchpi: if not preload: - raise RuntimeError( - 'Raw data must be preloaded for chpi, use --preload') + raise RuntimeError("Raw data must be preloaded for chpi, use --preload") raw = mne.chpi.filter_chpi(raw) highpass = None if highpass < 0 or filtorder < 0 else highpass lowpass = None if lowpass < 0 or filtorder < 0 else lowpass - raw.plot(duration=duration, start=start, n_channels=n_channels, - group_by=group_by, show_options=show_options, events=events, - highpass=highpass, lowpass=lowpass, filtorder=filtorder, - clipping=clipping, proj=not proj_off, verbose=verbose, - show=True, block=True) + raw.plot( + duration=duration, + start=start, + n_channels=n_channels, + group_by=group_by, + show_options=show_options, + events=events, + highpass=highpass, + lowpass=lowpass, + filtorder=filtorder, + clipping=clipping, + proj=not proj_off, + verbose=verbose, + show=True, + block=True, + ) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_bti2fiff.py b/mne/commands/mne_bti2fiff.py index db3c37fcd8c..88510626822 100644 --- a/mne/commands/mne_bti2fiff.py +++ b/mne/commands/mne_bti2fiff.py @@ -41,29 +41,57 @@ def run(): parser = get_optparser(__file__) - parser.add_option('-p', '--pdf', dest='pdf_fname', - help='Input data file name', metavar='FILE') - parser.add_option('-c', '--config', dest='config_fname', - help='Input config file name', metavar='FILE', - default='config') - parser.add_option('--head_shape', dest='head_shape_fname', - help='Headshape file name', metavar='FILE', - default='hs_file') - parser.add_option('-o', '--out_fname', dest='out_fname', - help='Name of the resulting fiff file', - default='as_data_fname') - parser.add_option('-r', '--rotation_x', dest='rotation_x', type='float', - help='Compensatory rotation about Neuromag x axis, deg', - default=2.0) - parser.add_option('-T', '--translation', dest='translation', type='str', - help='Default translation, meter', - default=(0.00, 0.02, 0.11)) - parser.add_option('--ecg_ch', dest='ecg_ch', type='str', - help='4D ECG channel name', - default='E31') - parser.add_option('--eog_ch', dest='eog_ch', type='str', - help='4D EOG channel names', - default='E63,E64') + parser.add_option( + "-p", "--pdf", dest="pdf_fname", help="Input data file name", metavar="FILE" + ) + parser.add_option( + "-c", + "--config", + dest="config_fname", + help="Input config file name", + metavar="FILE", + default="config", + ) + parser.add_option( + "--head_shape", + dest="head_shape_fname", + help="Headshape file name", + metavar="FILE", + default="hs_file", + ) + parser.add_option( + "-o", + "--out_fname", + dest="out_fname", + help="Name of the resulting fiff file", + default="as_data_fname", + ) + parser.add_option( + "-r", + "--rotation_x", + dest="rotation_x", + type="float", + help="Compensatory rotation about Neuromag x axis, deg", + default=2.0, + ) + parser.add_option( + "-T", + "--translation", + dest="translation", + type="str", + help="Default translation, meter", + default=(0.00, 0.02, 0.11), + ) + 
parser.add_option(
+        "--ecg_ch", dest="ecg_ch", type="str", help="4D ECG channel name", default="E31"
+    )
+    parser.add_option(
+        "--eog_ch",
+        dest="eog_ch",
+        type="str",
+        help="4D EOG channel names",
+        default="E63,E64",
+    )
 
     options, args = parser.parse_args()
 
@@ -78,15 +106,20 @@ def run():
     rotation_x = options.rotation_x
     translation = options.translation
     ecg_ch = options.ecg_ch
-    eog_ch = options.ecg_ch.split(',')
-
-    if out_fname == 'as_data_fname':
-        out_fname = pdf_fname + '_raw.fif'
-
-    raw = read_raw_bti(pdf_fname=pdf_fname, config_fname=config_fname,
-                       head_shape_fname=head_shape_fname,
-                       rotation_x=rotation_x, translation=translation,
-                       ecg_ch=ecg_ch, eog_ch=eog_ch)
+    eog_ch = options.eog_ch.split(",")
+
+    if out_fname == "as_data_fname":
+        out_fname = pdf_fname + "_raw.fif"
+
+    raw = read_raw_bti(
+        pdf_fname=pdf_fname,
+        config_fname=config_fname,
+        head_shape_fname=head_shape_fname,
+        rotation_x=rotation_x,
+        translation=translation,
+        ecg_ch=ecg_ch,
+        eog_ch=eog_ch,
+    )
 
     raw.save(out_fname)
     raw.close()
diff --git a/mne/commands/mne_clean_eog_ecg.py b/mne/commands/mne_clean_eog_ecg.py
index f722a9fea52..b1ffaa74edd 100644
--- a/mne/commands/mne_clean_eog_ecg.py
+++ b/mne/commands/mne_clean_eog_ecg.py
@@ -18,10 +18,18 @@
 import mne
 
 
-def clean_ecg_eog(in_fif_fname, out_fif_fname=None, eog=True, ecg=True,
-                  ecg_proj_fname=None, eog_proj_fname=None,
-                  ecg_event_fname=None, eog_event_fname=None, in_path='.',
-                  quiet=False):
+def clean_ecg_eog(
+    in_fif_fname,
+    out_fif_fname=None,
+    eog=True,
+    ecg=True,
+    ecg_proj_fname=None,
+    eog_proj_fname=None,
+    ecg_event_fname=None,
+    eog_event_fname=None,
+    in_path=".",
+    quiet=False,
+):
     """Clean ECG from raw fif file.
 
     Parameters
@@ -45,65 +53,124 @@ def clean_ecg_eog(in_fif_fname, out_fif_fname=None, eog=True, ecg=True,
 
     # Reading fif File
     raw_in = mne.io.read_raw_fif(in_fif_fname)
 
-    if in_fif_fname.endswith('_raw.fif') or in_fif_fname.endswith('-raw.fif'):
+    if in_fif_fname.endswith("_raw.fif") or in_fif_fname.endswith("-raw.fif"):
         prefix = in_fif_fname[:-8]
     else:
         prefix = in_fif_fname[:-4]
 
     if out_fif_fname is None:
-        out_fif_fname = prefix + '_clean_ecg_eog_raw.fif'
+        out_fif_fname = prefix + "_clean_ecg_eog_raw.fif"
     if ecg_proj_fname is None:
-        ecg_proj_fname = prefix + '_ecg-proj.fif'
+        ecg_proj_fname = prefix + "_ecg-proj.fif"
     if eog_proj_fname is None:
-        eog_proj_fname = prefix + '_eog-proj.fif'
+        eog_proj_fname = prefix + "_eog-proj.fif"
     if ecg_event_fname is None:
-        ecg_event_fname = prefix + '_ecg-eve.fif'
+        ecg_event_fname = prefix + "_ecg-eve.fif"
     if eog_event_fname is None:
-        eog_event_fname = prefix + '_eog-eve.fif'
+        eog_event_fname = prefix + "_eog-eve.fif"
 
-    print('Implementing ECG and EOG artifact rejection on data')
+    print("Implementing ECG and EOG artifact rejection on data")
 
     kwargs = dict() if quiet else dict(stdout=None, stderr=None)
     if ecg:
         ecg_events, _, _ = mne.preprocessing.find_ecg_events(
-            raw_in, reject_by_annotation=True)
+            raw_in, reject_by_annotation=True
+        )
         print("Writing ECG events in %s" % ecg_event_fname)
         mne.write_events(ecg_event_fname, ecg_events)
-        print('Computing ECG projector')
-        command = ('mne_process_raw', '--cd', in_path, '--raw', in_fif_fname,
-                   '--events', ecg_event_fname, '--makeproj',
-                   '--projtmin', '-0.08', '--projtmax', '0.08',
-                   '--saveprojtag', '_ecg-proj', '--projnmag', '2',
-                   '--projngrad', '1', '--projevent', '999', '--highpass', '5',
-                   '--lowpass', '35', '--projmagrej', '4000',
-                   '--projgradrej', '3000')
+        print("Computing ECG projector")
+        command = (
+            "mne_process_raw",
"--cd", + in_path, + "--raw", + in_fif_fname, + "--events", + ecg_event_fname, + "--makeproj", + "--projtmin", + "-0.08", + "--projtmax", + "0.08", + "--saveprojtag", + "_ecg-proj", + "--projnmag", + "2", + "--projngrad", + "1", + "--projevent", + "999", + "--highpass", + "5", + "--lowpass", + "35", + "--projmagrej", + "4000", + "--projgradrej", + "3000", + ) mne.utils.run_subprocess(command, **kwargs) if eog: eog_events = mne.preprocessing.find_eog_events(raw_in) print("Writing EOG events in %s" % eog_event_fname) mne.write_events(eog_event_fname, eog_events) - print('Computing EOG projector') - command = ('mne_process_raw', '--cd', in_path, '--raw', in_fif_fname, - '--events', eog_event_fname, '--makeproj', - '--projtmin', '-0.15', '--projtmax', '0.15', - '--saveprojtag', '_eog-proj', '--projnmag', '2', - '--projngrad', '2', '--projevent', '998', '--lowpass', '35', - '--projmagrej', '4000', '--projgradrej', '3000') + print("Computing EOG projector") + command = ( + "mne_process_raw", + "--cd", + in_path, + "--raw", + in_fif_fname, + "--events", + eog_event_fname, + "--makeproj", + "--projtmin", + "-0.15", + "--projtmax", + "0.15", + "--saveprojtag", + "_eog-proj", + "--projnmag", + "2", + "--projngrad", + "2", + "--projevent", + "998", + "--lowpass", + "35", + "--projmagrej", + "4000", + "--projgradrej", + "3000", + ) mne.utils.run_subprocess(command, **kwargs) if out_fif_fname is not None: # Applying the ECG EOG projector - print('Applying ECG EOG projector') - command = ('mne_process_raw', '--cd', in_path, '--raw', in_fif_fname, - '--proj', in_fif_fname, '--projoff', '--save', - out_fif_fname, '--filteroff', - '--proj', ecg_proj_fname, '--proj', eog_proj_fname) + print("Applying ECG EOG projector") + command = ( + "mne_process_raw", + "--cd", + in_path, + "--raw", + in_fif_fname, + "--proj", + in_fif_fname, + "--projoff", + "--save", + out_fif_fname, + "--filteroff", + "--proj", + ecg_proj_fname, + "--proj", + eog_proj_fname, + ) mne.utils.run_subprocess(command, **kwargs) - print('Done removing artifacts.') + print("Done removing artifacts.") print("Cleaned raw data saved in: %s" % out_fif_fname) - print('IMPORTANT : Please eye-ball the data !!') + print("IMPORTANT : Please eye-ball the data !!") else: - print('Projection not applied to raw data.') + print("Projection not applied to raw data.") def run(): @@ -112,17 +179,41 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-i", "--in", dest="raw_in", - help="Input raw FIF file", metavar="FILE") - parser.add_option("-o", "--out", dest="raw_out", - help="Output raw FIF file", metavar="FILE", - default=None) - parser.add_option("-e", "--no-eog", dest="eog", action="http://github.com/store_false", - help="Remove EOG", default=True) - parser.add_option("-c", "--no-ecg", dest="ecg", action="http://github.com/store_false", - help="Remove ECG", default=True) - parser.add_option("-q", "--quiet", dest="quiet", action="http://github.com/store_true", - help="Suppress mne_process_raw output", default=False) + parser.add_option( + "-i", "--in", dest="raw_in", help="Input raw FIF file", metavar="FILE" + ) + parser.add_option( + "-o", + "--out", + dest="raw_out", + help="Output raw FIF file", + metavar="FILE", + default=None, + ) + parser.add_option( + "-e", + "--no-eog", + dest="eog", + action="http://github.com/store_false", + help="Remove EOG", + default=True, + ) + parser.add_option( + "-c", + "--no-ecg", + dest="ecg", + action="http://github.com/store_false", + help="Remove ECG", + 
default=True, + ) + parser.add_option( + "-q", + "--quiet", + dest="quiet", + action="http://github.com/store_true", + help="Suppress mne_process_raw output", + default=False, + ) options, args = parser.parse_args() diff --git a/mne/commands/mne_compare_fiff.py b/mne/commands/mne_compare_fiff.py index b616a3e4072..fe05d636592 100644 --- a/mne/commands/mne_compare_fiff.py +++ b/mne/commands/mne_compare_fiff.py @@ -18,7 +18,8 @@ def run(): """Run command.""" parser = mne.commands.utils.get_optparser( - __file__, usage='mne compare_fiff ') + __file__, usage="mne compare_fiff " + ) options, args = parser.parse_args() if len(args) != 2: parser.print_help() diff --git a/mne/commands/mne_compute_proj_ecg.py b/mne/commands/mne_compute_proj_ecg.py index c42798be3be..bb366f9d3e2 100644 --- a/mne/commands/mne_compute_proj_ecg.py +++ b/mne/commands/mne_compute_proj_ecg.py @@ -24,97 +24,191 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-i", "--in", dest="raw_in", - help="Input raw FIF file", metavar="FILE") - parser.add_option("--tmin", dest="tmin", type="float", - help="Time before event in seconds", - default=-0.2) - parser.add_option("--tmax", dest="tmax", type="float", - help="Time after event in seconds", - default=0.4) - parser.add_option("-g", "--n-grad", dest="n_grad", type="int", - help="Number of SSP vectors for gradiometers", - default=2) - parser.add_option("-m", "--n-mag", dest="n_mag", type="int", - help="Number of SSP vectors for magnetometers", - default=2) - parser.add_option("-e", "--n-eeg", dest="n_eeg", type="int", - help="Number of SSP vectors for EEG", - default=2) - parser.add_option("--l-freq", dest="l_freq", type="float", - help="Filter low cut-off frequency in Hz", - default=1) - parser.add_option("--h-freq", dest="h_freq", type="float", - help="Filter high cut-off frequency in Hz", - default=100) - parser.add_option("--ecg-l-freq", dest="ecg_l_freq", type="float", - help="Filter low cut-off frequency in Hz used " - "for ECG event detection", - default=5) - parser.add_option("--ecg-h-freq", dest="ecg_h_freq", type="float", - help="Filter high cut-off frequency in Hz used " - "for ECG event detection", - default=35) - parser.add_option("-p", "--preload", dest="preload", - help="Temporary file used during computation " - "(to save memory)", - default=True) - parser.add_option("-a", "--average", dest="average", action="http://github.com/store_true", - help="Compute SSP after averaging", - default=False) - parser.add_option("--proj", dest="proj", - help="Use SSP projections from a fif file.", - default=None) - parser.add_option("--filtersize", dest="filter_length", type="int", - help="Number of taps to use for filtering", - default=2048) - parser.add_option("-j", "--n-jobs", dest="n_jobs", type="int", - help="Number of jobs to run in parallel", - default=1) - parser.add_option("-c", "--channel", dest="ch_name", - help="Channel to use for ECG detection " - "(Required if no ECG found)", - default=None) - parser.add_option("--rej-grad", dest="rej_grad", type="float", - help="Gradiometers rejection parameter " - "in fT/cm (peak to peak amplitude)", - default=2000) - parser.add_option("--rej-mag", dest="rej_mag", type="float", - help="Magnetometers rejection parameter " - "in fT (peak to peak amplitude)", - default=3000) - parser.add_option("--rej-eeg", dest="rej_eeg", type="float", - help="EEG rejection parameter in µV " - "(peak to peak amplitude)", - default=50) - parser.add_option("--rej-eog", dest="rej_eog", type="float", - help="EOG 
rejection parameter in µV " - "(peak to peak amplitude)", - default=250) - parser.add_option("--avg-ref", dest="avg_ref", action="http://github.com/store_true", - help="Add EEG average reference proj", - default=False) - parser.add_option("--no-proj", dest="no_proj", action="http://github.com/store_true", - help="Exclude the SSP projectors currently " - "in the fiff file", - default=False) - parser.add_option("--bad", dest="bad_fname", - help="Text file containing bad channels list " - "(one per line)", - default=None) - parser.add_option("--event-id", dest="event_id", type="int", - help="ID to use for events", - default=999) - parser.add_option("--event-raw", dest="raw_event_fname", - help="raw file to use for event detection", - default=None) - parser.add_option("--tstart", dest="tstart", type="float", - help="Start artifact detection after tstart seconds", - default=0.) - parser.add_option("--qrsthr", dest="qrs_threshold", type="string", - help="QRS detection threshold. Between 0 and 1. Can " - "also be 'auto' for automatic selection", - default='auto') + parser.add_option( + "-i", "--in", dest="raw_in", help="Input raw FIF file", metavar="FILE" + ) + parser.add_option( + "--tmin", + dest="tmin", + type="float", + help="Time before event in seconds", + default=-0.2, + ) + parser.add_option( + "--tmax", + dest="tmax", + type="float", + help="Time after event in seconds", + default=0.4, + ) + parser.add_option( + "-g", + "--n-grad", + dest="n_grad", + type="int", + help="Number of SSP vectors for gradiometers", + default=2, + ) + parser.add_option( + "-m", + "--n-mag", + dest="n_mag", + type="int", + help="Number of SSP vectors for magnetometers", + default=2, + ) + parser.add_option( + "-e", + "--n-eeg", + dest="n_eeg", + type="int", + help="Number of SSP vectors for EEG", + default=2, + ) + parser.add_option( + "--l-freq", + dest="l_freq", + type="float", + help="Filter low cut-off frequency in Hz", + default=1, + ) + parser.add_option( + "--h-freq", + dest="h_freq", + type="float", + help="Filter high cut-off frequency in Hz", + default=100, + ) + parser.add_option( + "--ecg-l-freq", + dest="ecg_l_freq", + type="float", + help="Filter low cut-off frequency in Hz used " "for ECG event detection", + default=5, + ) + parser.add_option( + "--ecg-h-freq", + dest="ecg_h_freq", + type="float", + help="Filter high cut-off frequency in Hz used " "for ECG event detection", + default=35, + ) + parser.add_option( + "-p", + "--preload", + dest="preload", + help="Temporary file used during computation " "(to save memory)", + default=True, + ) + parser.add_option( + "-a", + "--average", + dest="average", + action="http://github.com/store_true", + help="Compute SSP after averaging", + default=False, + ) + parser.add_option( + "--proj", dest="proj", help="Use SSP projections from a fif file.", default=None + ) + parser.add_option( + "--filtersize", + dest="filter_length", + type="int", + help="Number of taps to use for filtering", + default=2048, + ) + parser.add_option( + "-j", + "--n-jobs", + dest="n_jobs", + type="int", + help="Number of jobs to run in parallel", + default=1, + ) + parser.add_option( + "-c", + "--channel", + dest="ch_name", + help="Channel to use for ECG detection " "(Required if no ECG found)", + default=None, + ) + parser.add_option( + "--rej-grad", + dest="rej_grad", + type="float", + help="Gradiometers rejection parameter " "in fT/cm (peak to peak amplitude)", + default=2000, + ) + parser.add_option( + "--rej-mag", + dest="rej_mag", + type="float", + 
help="Magnetometers rejection parameter " "in fT (peak to peak amplitude)", + default=3000, + ) + parser.add_option( + "--rej-eeg", + dest="rej_eeg", + type="float", + help="EEG rejection parameter in µV " "(peak to peak amplitude)", + default=50, + ) + parser.add_option( + "--rej-eog", + dest="rej_eog", + type="float", + help="EOG rejection parameter in µV " "(peak to peak amplitude)", + default=250, + ) + parser.add_option( + "--avg-ref", + dest="avg_ref", + action="http://github.com/store_true", + help="Add EEG average reference proj", + default=False, + ) + parser.add_option( + "--no-proj", + dest="no_proj", + action="http://github.com/store_true", + help="Exclude the SSP projectors currently " "in the fiff file", + default=False, + ) + parser.add_option( + "--bad", + dest="bad_fname", + help="Text file containing bad channels list " "(one per line)", + default=None, + ) + parser.add_option( + "--event-id", + dest="event_id", + type="int", + help="ID to use for events", + default=999, + ) + parser.add_option( + "--event-raw", + dest="raw_event_fname", + help="raw file to use for event detection", + default=None, + ) + parser.add_option( + "--tstart", + dest="tstart", + type="float", + help="Start artifact detection after tstart seconds", + default=0.0, + ) + parser.add_option( + "--qrsthr", + dest="qrs_threshold", + type="string", + help="QRS detection threshold. Between 0 and 1. Can " + "also be 'auto' for automatic selection", + default="auto", + ) options, args = parser.parse_args() @@ -138,10 +232,12 @@ def run(): filter_length = options.filter_length n_jobs = options.n_jobs ch_name = options.ch_name - reject = dict(grad=1e-13 * float(options.rej_grad), - mag=1e-15 * float(options.rej_mag), - eeg=1e-6 * float(options.rej_eeg), - eog=1e-6 * float(options.rej_eog)) + reject = dict( + grad=1e-13 * float(options.rej_grad), + mag=1e-15 * float(options.rej_mag), + eeg=1e-6 * float(options.rej_eeg), + eog=1e-6 * float(options.rej_eog), + ) avg_ref = options.avg_ref no_proj = options.no_proj bad_fname = options.bad_fname @@ -150,30 +246,30 @@ def run(): raw_event_fname = options.raw_event_fname tstart = options.tstart qrs_threshold = options.qrs_threshold - if qrs_threshold != 'auto': + if qrs_threshold != "auto": try: qrs_threshold = float(qrs_threshold) except ValueError: raise ValueError('qrsthr must be "auto" or a float') if bad_fname is not None: - with open(bad_fname, 'r') as fid: + with open(bad_fname, "r") as fid: bads = [w.rstrip() for w in fid.readlines()] - print('Bad channels read : %s' % bads) + print("Bad channels read : %s" % bads) else: bads = [] - if raw_in.endswith('_raw.fif') or raw_in.endswith('-raw.fif'): + if raw_in.endswith("_raw.fif") or raw_in.endswith("-raw.fif"): prefix = raw_in[:-8] else: prefix = raw_in[:-4] - ecg_event_fname = prefix + '_ecg-eve.fif' + ecg_event_fname = prefix + "_ecg-eve.fif" if average: - ecg_proj_fname = prefix + '_ecg_avg-proj.fif' + ecg_proj_fname = prefix + "_ecg_avg-proj.fif" else: - ecg_proj_fname = prefix + '_ecg-proj.fif' + ecg_proj_fname = prefix + "_ecg-proj.fif" raw = mne.io.read_raw_fif(raw_in, preload=preload) @@ -184,10 +280,31 @@ def run(): flat = None projs, events = mne.preprocessing.compute_proj_ecg( - raw, raw_event, tmin, tmax, n_grad, n_mag, n_eeg, l_freq, h_freq, - average, filter_length, n_jobs, ch_name, reject, flat, bads, avg_ref, - no_proj, event_id, ecg_l_freq, ecg_h_freq, tstart, qrs_threshold, - copy=False) + raw, + raw_event, + tmin, + tmax, + n_grad, + n_mag, + n_eeg, + l_freq, + h_freq, + 
average, + filter_length, + n_jobs, + ch_name, + reject, + flat, + bads, + avg_ref, + no_proj, + event_id, + ecg_l_freq, + ecg_h_freq, + tstart, + qrs_threshold, + copy=False, + ) raw.close() @@ -195,7 +312,7 @@ def run(): raw_event.close() if proj_fname is not None: - print('Including SSP projections from : %s' % proj_fname) + print("Including SSP projections from : %s" % proj_fname) # append the ecg projs, so they are last in the list projs = mne.read_proj(proj_fname) + projs diff --git a/mne/commands/mne_compute_proj_eog.py b/mne/commands/mne_compute_proj_eog.py index 3494ffa47af..42c93513122 100644 --- a/mne/commands/mne_compute_proj_eog.py +++ b/mne/commands/mne_compute_proj_eog.py @@ -34,77 +34,184 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-i", "--in", dest="raw_in", - help="Input raw FIF file", metavar="FILE") - parser.add_option("--tmin", dest="tmin", type="float", - help="Time before event in seconds", default=-0.2) - parser.add_option("--tmax", dest="tmax", type="float", - help="Time after event in seconds", default=0.2) - parser.add_option("-g", "--n-grad", dest="n_grad", type="int", - help="Number of SSP vectors for gradiometers", - default=2) - parser.add_option("-m", "--n-mag", dest="n_mag", type="int", - help="Number of SSP vectors for magnetometers", - default=2) - parser.add_option("-e", "--n-eeg", dest="n_eeg", type="int", - help="Number of SSP vectors for EEG", default=2) - parser.add_option("--l-freq", dest="l_freq", type="float", - help="Filter low cut-off frequency in Hz", - default=1) - parser.add_option("--h-freq", dest="h_freq", type="float", - help="Filter high cut-off frequency in Hz", - default=35) - parser.add_option("--eog-l-freq", dest="eog_l_freq", type="float", - help="Filter low cut-off frequency in Hz used for " - "EOG event detection", default=1) - parser.add_option("--eog-h-freq", dest="eog_h_freq", type="float", - help="Filter high cut-off frequency in Hz used for " - "EOG event detection", default=10) - parser.add_option("-p", "--preload", dest="preload", - help="Temporary file used during computation (to " - "save memory)", default=True) - parser.add_option("-a", "--average", dest="average", action="http://github.com/store_true", - help="Compute SSP after averaging", - default=False) - parser.add_option("--proj", dest="proj", - help="Use SSP projections from a fif file.", - default=None) - parser.add_option("--filtersize", dest="filter_length", type="int", - help="Number of taps to use for filtering", - default=2048) - parser.add_option("-j", "--n-jobs", dest="n_jobs", type="int", - help="Number of jobs to run in parallel", default=1) - parser.add_option("--rej-grad", dest="rej_grad", type="float", - help="Gradiometers rejection parameter in fT/cm (peak " - "to peak amplitude)", default=2000) - parser.add_option("--rej-mag", dest="rej_mag", type="float", - help="Magnetometers rejection parameter in fT (peak to " - "peak amplitude)", default=3000) - parser.add_option("--rej-eeg", dest="rej_eeg", type="float", - help="EEG rejection parameter in µV (peak to peak " - "amplitude)", default=50) - parser.add_option("--rej-eog", dest="rej_eog", type="float", - help="EOG rejection parameter in µV (peak to peak " - "amplitude)", default=1e9) - parser.add_option("--avg-ref", dest="avg_ref", action="http://github.com/store_true", - help="Add EEG average reference proj", - default=False) - parser.add_option("--no-proj", dest="no_proj", action="http://github.com/store_true", - help="Exclude the SSP projectors 
currently in the " - "fiff file", default=False) - parser.add_option("--bad", dest="bad_fname", - help="Text file containing bad channels list " - "(one per line)", default=None) - parser.add_option("--event-id", dest="event_id", type="int", - help="ID to use for events", default=998) - parser.add_option("--event-raw", dest="raw_event_fname", - help="raw file to use for event detection", default=None) - parser.add_option("--tstart", dest="tstart", type="float", - help="Start artifact detection after tstart seconds", - default=0.) - parser.add_option("-c", "--channel", dest="ch_name", type="string", - help="Custom EOG channel(s), comma separated", - default=None) + parser.add_option( + "-i", "--in", dest="raw_in", help="Input raw FIF file", metavar="FILE" + ) + parser.add_option( + "--tmin", + dest="tmin", + type="float", + help="Time before event in seconds", + default=-0.2, + ) + parser.add_option( + "--tmax", + dest="tmax", + type="float", + help="Time after event in seconds", + default=0.2, + ) + parser.add_option( + "-g", + "--n-grad", + dest="n_grad", + type="int", + help="Number of SSP vectors for gradiometers", + default=2, + ) + parser.add_option( + "-m", + "--n-mag", + dest="n_mag", + type="int", + help="Number of SSP vectors for magnetometers", + default=2, + ) + parser.add_option( + "-e", + "--n-eeg", + dest="n_eeg", + type="int", + help="Number of SSP vectors for EEG", + default=2, + ) + parser.add_option( + "--l-freq", + dest="l_freq", + type="float", + help="Filter low cut-off frequency in Hz", + default=1, + ) + parser.add_option( + "--h-freq", + dest="h_freq", + type="float", + help="Filter high cut-off frequency in Hz", + default=35, + ) + parser.add_option( + "--eog-l-freq", + dest="eog_l_freq", + type="float", + help="Filter low cut-off frequency in Hz used for " "EOG event detection", + default=1, + ) + parser.add_option( + "--eog-h-freq", + dest="eog_h_freq", + type="float", + help="Filter high cut-off frequency in Hz used for " "EOG event detection", + default=10, + ) + parser.add_option( + "-p", + "--preload", + dest="preload", + help="Temporary file used during computation (to " "save memory)", + default=True, + ) + parser.add_option( + "-a", + "--average", + dest="average", + action="http://github.com/store_true", + help="Compute SSP after averaging", + default=False, + ) + parser.add_option( + "--proj", dest="proj", help="Use SSP projections from a fif file.", default=None + ) + parser.add_option( + "--filtersize", + dest="filter_length", + type="int", + help="Number of taps to use for filtering", + default=2048, + ) + parser.add_option( + "-j", + "--n-jobs", + dest="n_jobs", + type="int", + help="Number of jobs to run in parallel", + default=1, + ) + parser.add_option( + "--rej-grad", + dest="rej_grad", + type="float", + help="Gradiometers rejection parameter in fT/cm (peak " "to peak amplitude)", + default=2000, + ) + parser.add_option( + "--rej-mag", + dest="rej_mag", + type="float", + help="Magnetometers rejection parameter in fT (peak to " "peak amplitude)", + default=3000, + ) + parser.add_option( + "--rej-eeg", + dest="rej_eeg", + type="float", + help="EEG rejection parameter in µV (peak to peak " "amplitude)", + default=50, + ) + parser.add_option( + "--rej-eog", + dest="rej_eog", + type="float", + help="EOG rejection parameter in µV (peak to peak " "amplitude)", + default=1e9, + ) + parser.add_option( + "--avg-ref", + dest="avg_ref", + action="http://github.com/store_true", + help="Add EEG average reference proj", + default=False, + ) + 
parser.add_option( + "--no-proj", + dest="no_proj", + action="http://github.com/store_true", + help="Exclude the SSP projectors currently in the " "fiff file", + default=False, + ) + parser.add_option( + "--bad", + dest="bad_fname", + help="Text file containing bad channels list " "(one per line)", + default=None, + ) + parser.add_option( + "--event-id", + dest="event_id", + type="int", + help="ID to use for events", + default=998, + ) + parser.add_option( + "--event-raw", + dest="raw_event_fname", + help="raw file to use for event detection", + default=None, + ) + parser.add_option( + "--tstart", + dest="tstart", + type="float", + help="Start artifact detection after tstart seconds", + default=0.0, + ) + parser.add_option( + "-c", + "--channel", + dest="ch_name", + type="string", + help="Custom EOG channel(s), comma separated", + default=None, + ) options, args = parser.parse_args() @@ -127,10 +234,12 @@ def run(): preload = options.preload filter_length = options.filter_length n_jobs = options.n_jobs - reject = dict(grad=1e-13 * float(options.rej_grad), - mag=1e-15 * float(options.rej_mag), - eeg=1e-6 * float(options.rej_eeg), - eog=1e-6 * float(options.rej_eog)) + reject = dict( + grad=1e-13 * float(options.rej_grad), + mag=1e-15 * float(options.rej_mag), + eeg=1e-6 * float(options.rej_eeg), + eog=1e-6 * float(options.rej_eog), + ) avg_ref = options.avg_ref no_proj = options.no_proj bad_fname = options.bad_fname @@ -141,23 +250,23 @@ def run(): ch_name = options.ch_name if bad_fname is not None: - with open(bad_fname, 'r') as fid: + with open(bad_fname, "r") as fid: bads = [w.rstrip() for w in fid.readlines()] - print('Bad channels read : %s' % bads) + print("Bad channels read : %s" % bads) else: bads = [] - if raw_in.endswith('_raw.fif') or raw_in.endswith('-raw.fif'): + if raw_in.endswith("_raw.fif") or raw_in.endswith("-raw.fif"): prefix = raw_in[:-8] else: prefix = raw_in[:-4] - eog_event_fname = prefix + '_eog-eve.fif' + eog_event_fname = prefix + "_eog-eve.fif" if average: - eog_proj_fname = prefix + '_eog_avg-proj.fif' + eog_proj_fname = prefix + "_eog_avg-proj.fif" else: - eog_proj_fname = prefix + '_eog-proj.fif' + eog_proj_fname = prefix + "_eog-proj.fif" raw = mne.io.read_raw_fif(raw_in, preload=preload) @@ -168,13 +277,30 @@ def run(): flat = None projs, events = mne.preprocessing.compute_proj_eog( - raw=raw, raw_event=raw_event, tmin=tmin, tmax=tmax, n_grad=n_grad, - n_mag=n_mag, n_eeg=n_eeg, l_freq=l_freq, h_freq=h_freq, - average=average, filter_length=filter_length, - n_jobs=n_jobs, reject=reject, flat=flat, bads=bads, - avg_ref=avg_ref, no_proj=no_proj, event_id=event_id, - eog_l_freq=eog_l_freq, eog_h_freq=eog_h_freq, - tstart=tstart, ch_name=ch_name, copy=False) + raw=raw, + raw_event=raw_event, + tmin=tmin, + tmax=tmax, + n_grad=n_grad, + n_mag=n_mag, + n_eeg=n_eeg, + l_freq=l_freq, + h_freq=h_freq, + average=average, + filter_length=filter_length, + n_jobs=n_jobs, + reject=reject, + flat=flat, + bads=bads, + avg_ref=avg_ref, + no_proj=no_proj, + event_id=event_id, + eog_l_freq=eog_l_freq, + eog_h_freq=eog_h_freq, + tstart=tstart, + ch_name=ch_name, + copy=False, + ) raw.close() @@ -182,7 +308,7 @@ def run(): raw_event.close() if proj_fname is not None: - print('Including SSP projections from : %s' % proj_fname) + print("Including SSP projections from : %s" % proj_fname) # append the eog projs, so they are last in the list projs = mne.read_proj(proj_fname) + projs @@ -196,6 +322,6 @@ def run(): mne.write_events(eog_event_fname, events) -is_main = (__name__ 
== '__main__') +is_main = __name__ == "__main__" if is_main: run() diff --git a/mne/commands/mne_coreg.py b/mne/commands/mne_coreg.py index 0e25c1f44de..dad18d278aa 100644 --- a/mne/commands/mne_coreg.py +++ b/mne/commands/mne_coreg.py @@ -22,51 +22,98 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - default=None, help="Subjects directory") - parser.add_option("-s", "--subject", dest="subject", default=None, - help="Subject name") - parser.add_option("-f", "--fiff", dest="inst", default=None, - help="FIFF file with digitizer data for coregistration") - parser.add_option("-t", "--tabbed", dest="tabbed", action="http://github.com/store_true", - default=False, help="Option for small screens: Combine " - "the data source panel and the coregistration panel " - "into a single panel with tabs.") - parser.add_option("--no-guess-mri", dest="guess_mri_subject", - action='http://github.com/store_false', default=None, - help="Prevent the GUI from automatically guessing and " - "changing the MRI subject when a new head shape source " - "file is selected.") - parser.add_option("--head-opacity", type=float, default=None, - dest="head_opacity", - help="The opacity of the head surface, in the range " - "[0, 1].") - parser.add_option("--high-res-head", - action='http://github.com/store_true', default=False, dest="high_res_head", - help="Use a high-resolution head surface.") - parser.add_option("--low-res-head", - action='http://github.com/store_true', default=False, dest="low_res_head", - help="Use a low-resolution head surface.") - parser.add_option('--trans', dest='trans', default=None, - help='Head<->MRI transform FIF file ("-trans.fif")') - parser.add_option('--interaction', - type=str, default=None, dest='interaction', - help='Interaction style to use, can be "trackball" or ' - '"terrain".') - parser.add_option('--scale', - type=float, default=None, dest='scale', - help='Scale factor for the scene.') - parser.add_option('--simple-rendering', action='http://github.com/store_false', - dest='advanced_rendering', - help='Use simplified OpenGL rendering') + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + default=None, + help="Subjects directory", + ) + parser.add_option( + "-s", "--subject", dest="subject", default=None, help="Subject name" + ) + parser.add_option( + "-f", + "--fiff", + dest="inst", + default=None, + help="FIFF file with digitizer data for coregistration", + ) + parser.add_option( + "-t", + "--tabbed", + dest="tabbed", + action="http://github.com/store_true", + default=False, + help="Option for small screens: Combine " + "the data source panel and the coregistration panel " + "into a single panel with tabs.", + ) + parser.add_option( + "--no-guess-mri", + dest="guess_mri_subject", + action="http://github.com/store_false", + default=None, + help="Prevent the GUI from automatically guessing and " + "changing the MRI subject when a new head shape source " + "file is selected.", + ) + parser.add_option( + "--head-opacity", + type=float, + default=None, + dest="head_opacity", + help="The opacity of the head surface, in the range " "[0, 1].", + ) + parser.add_option( + "--high-res-head", + action="http://github.com/store_true", + default=False, + dest="high_res_head", + help="Use a high-resolution head surface.", + ) + parser.add_option( + "--low-res-head", + action="http://github.com/store_true", + default=False, + 
dest="low_res_head", + help="Use a low-resolution head surface.", + ) + parser.add_option( + "--trans", + dest="trans", + default=None, + help='Head<->MRI transform FIF file ("-trans.fif")', + ) + parser.add_option( + "--interaction", + type=str, + default=None, + dest="interaction", + help='Interaction style to use, can be "trackball" or ' '"terrain".', + ) + parser.add_option( + "--scale", + type=float, + default=None, + dest="scale", + help="Scale factor for the scene.", + ) + parser.add_option( + "--simple-rendering", + action="http://github.com/store_false", + dest="advanced_rendering", + help="Use simplified OpenGL rendering", + ) _add_verbose_flag(parser) options, args = parser.parse_args() if options.low_res_head: if options.high_res_head: - raise ValueError("Can't specify --high-res-head and " - "--low-res-head at the same time.") + raise ValueError( + "Can't specify --high-res-head and " "--low-res-head at the same time." + ) head_high_res = False elif options.high_res_head: head_high_res = True @@ -81,18 +128,25 @@ def run(): if trans is not None: trans = op.expanduser(trans) import faulthandler + faulthandler.enable() mne.gui.coregistration( - options.tabbed, inst=options.inst, subject=options.subject, + options.tabbed, + inst=options.inst, + subject=options.subject, subjects_dir=subjects_dir, guess_mri_subject=options.guess_mri_subject, - head_opacity=options.head_opacity, head_high_res=head_high_res, - trans=trans, scrollable=True, + head_opacity=options.head_opacity, + head_high_res=head_high_res, + trans=trans, + scrollable=True, interaction=options.interaction, scale=options.scale, advanced_rendering=options.advanced_rendering, - show=True, block=True, - verbose=options.verbose) + show=True, + block=True, + verbose=options.verbose, + ) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_flash_bem.py b/mne/commands/mne_flash_bem.py index 3556b58a78d..8ffaf57b816 100644 --- a/mne/commands/mne_flash_bem.py +++ b/mne/commands/mne_flash_bem.py @@ -41,7 +41,7 @@ def _vararg_callback(option, opt_str, value, parser): break value.append(arg) - del parser.rargs[:len(value)] + del parser.rargs[: len(value)] setattr(parser.values, option.dest, value) @@ -51,45 +51,103 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-s", "--subject", dest="subject", - help="Subject name", default=None) - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - help="Subjects directory", default=None) - parser.add_option("-3", "--flash30", "--noflash30", dest="flash30", - action="http://github.com/callback", callback=_vararg_callback, - help=("The 30-degree flip angle data. If no argument do " - "not use flash30. If arguments are given, them as " - "file names.")) - parser.add_option("-5", "--flash5", dest="flash5", - action="http://github.com/callback", callback=_vararg_callback, - help=("Path to the multiecho flash 5 images. 
" - "Can be one file or one per echo."),) - parser.add_option("-r", "--registered", dest="registered", - action="http://github.com/store_true", default=False, - help=("Set if the Flash MRI images have already " - "been registered with the T1.mgz file.")) - parser.add_option("-n", "--noconvert", dest="noconvert", - action="http://github.com/store_true", default=False, - help=("[DEPRECATED] Assume that the Flash MRI images " - "have already been converted to mgz files")) - parser.add_option("-u", "--unwarp", dest="unwarp", - action="http://github.com/store_true", default=False, - help=("Run grad_unwarp with -unwarp " - "option on each of the converted data sets")) - parser.add_option("-o", "--overwrite", dest="overwrite", - action="http://github.com/store_true", default=False, - help="Write over existing .surf files in bem folder") - parser.add_option("-v", "--view", dest="show", action="http://github.com/store_true", - help="Show BEM model in 3D for visual inspection", - default=False) - parser.add_option("--copy", dest="copy", - help="Use copies instead of symlinks for surfaces", - action="http://github.com/store_true") - parser.add_option("-p", "--flash-path", dest="flash_path", - default=None, - help="[DEPRECATED] The directory containing flash5.mgz " - "files (defaults to " - "$SUBJECTS_DIR/$SUBJECT/mri/flash/parameter_maps") + parser.add_option( + "-s", "--subject", dest="subject", help="Subject name", default=None + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=None, + ) + parser.add_option( + "-3", + "--flash30", + "--noflash30", + dest="flash30", + action="http://github.com/callback", + callback=_vararg_callback, + help=( + "The 30-degree flip angle data. If no argument do " + "not use flash30. If arguments are given, them as " + "file names." + ), + ) + parser.add_option( + "-5", + "--flash5", + dest="flash5", + action="http://github.com/callback", + callback=_vararg_callback, + help=( + "Path to the multiecho flash 5 images. " "Can be one file or one per echo." + ), + ) + parser.add_option( + "-r", + "--registered", + dest="registered", + action="http://github.com/store_true", + default=False, + help=( + "Set if the Flash MRI images have already " + "been registered with the T1.mgz file." 
+ ), + ) + parser.add_option( + "-n", + "--noconvert", + dest="noconvert", + action="http://github.com/store_true", + default=False, + help=( + "[DEPRECATED] Assume that the Flash MRI images " + "have already been converted to mgz files" + ), + ) + parser.add_option( + "-u", + "--unwarp", + dest="unwarp", + action="http://github.com/store_true", + default=False, + help=( + "Run grad_unwarp with -unwarp " + "option on each of the converted data sets" + ), + ) + parser.add_option( + "-o", + "--overwrite", + dest="overwrite", + action="http://github.com/store_true", + default=False, + help="Write over existing .surf files in bem folder", + ) + parser.add_option( + "-v", + "--view", + dest="show", + action="http://github.com/store_true", + help="Show BEM model in 3D for visual inspection", + default=False, + ) + parser.add_option( + "--copy", + dest="copy", + help="Use copies instead of symlinks for surfaces", + action="http://github.com/store_true", + ) + parser.add_option( + "-p", + "--flash-path", + dest="flash_path", + default=None, + help="[DEPRECATED] The directory containing flash5.mgz " + "files (defaults to " + "$SUBJECTS_DIR/$SUBJECT/mri/flash/parameter_maps", + ) options, _ = parser.parse_args() @@ -111,15 +169,26 @@ def run(): if options.subject is None: parser.print_help() - raise RuntimeError('The subject argument must be set') + raise RuntimeError("The subject argument must be set") flash5_img = convert_flash_mris( - subject=subject, subjects_dir=subjects_dir, flash5=flash5, - flash30=flash30, unwarp=unwarp, verbose=True + subject=subject, + subjects_dir=subjects_dir, + flash5=flash5, + flash30=flash30, + unwarp=unwarp, + verbose=True, + ) + make_flash_bem( + subject=subject, + subjects_dir=subjects_dir, + overwrite=overwrite, + show=show, + copy=copy, + register=register, + flash5_img=flash5_img, + verbose=True, ) - make_flash_bem(subject=subject, subjects_dir=subjects_dir, - overwrite=overwrite, show=show, copy=copy, - register=register, flash5_img=flash5_img, verbose=True) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_freeview_bem_surfaces.py b/mne/commands/mne_freeview_bem_surfaces.py index f5a65d9fb79..646049b6616 100644 --- a/mne/commands/mne_freeview_bem_surfaces.py +++ b/mne/commands/mne_freeview_bem_surfaces.py @@ -38,39 +38,40 @@ def freeview_bem_surfaces(subject, subjects_dir, method): subject_dir = op.join(subjects_dir, subject) if not op.isdir(subject_dir): - raise ValueError("Wrong path: '{}'. Check subjects-dir or" - "subject argument.".format(subject_dir)) + raise ValueError( + "Wrong path: '{}'. 
Check subjects-dir or "
+            "subject argument.".format(subject_dir)
+        )
 
     env = os.environ.copy()
-    env['SUBJECT'] = subject
-    env['SUBJECTS_DIR'] = subjects_dir
+    env["SUBJECT"] = subject
+    env["SUBJECTS_DIR"] = subjects_dir
 
-    if 'FREESURFER_HOME' not in env:
-        raise RuntimeError('The FreeSurfer environment needs to be set up.')
+    if "FREESURFER_HOME" not in env:
+        raise RuntimeError("The FreeSurfer environment needs to be set up.")
 
-    mri_dir = op.join(subject_dir, 'mri')
-    bem_dir = op.join(subject_dir, 'bem')
-    mri = op.join(mri_dir, 'T1.mgz')
+    mri_dir = op.join(subject_dir, "mri")
+    bem_dir = op.join(subject_dir, "bem")
+    mri = op.join(mri_dir, "T1.mgz")
 
-    if method == 'watershed':
-        bem_dir = op.join(bem_dir, 'watershed')
-        outer_skin = op.join(bem_dir, '%s_outer_skin_surface' % subject)
-        outer_skull = op.join(bem_dir, '%s_outer_skull_surface' % subject)
-        inner_skull = op.join(bem_dir, '%s_inner_skull_surface' % subject)
+    if method == "watershed":
+        bem_dir = op.join(bem_dir, "watershed")
+        outer_skin = op.join(bem_dir, "%s_outer_skin_surface" % subject)
+        outer_skull = op.join(bem_dir, "%s_outer_skull_surface" % subject)
+        inner_skull = op.join(bem_dir, "%s_inner_skull_surface" % subject)
     else:
-        if method == 'flash':
-            bem_dir = op.join(bem_dir, 'flash')
-        outer_skin = op.join(bem_dir, 'outer_skin.surf')
-        outer_skull = op.join(bem_dir, 'outer_skull.surf')
-        inner_skull = op.join(bem_dir, 'inner_skull.surf')
+        if method == "flash":
+            bem_dir = op.join(bem_dir, "flash")
+        outer_skin = op.join(bem_dir, "outer_skin.surf")
+        outer_skull = op.join(bem_dir, "outer_skull.surf")
+        inner_skull = op.join(bem_dir, "inner_skull.surf")
 
     # put together the command
-    cmd = ['freeview']
+    cmd = ["freeview"]
     cmd += ["--volume", mri]
     cmd += ["--surface", "%s:color=red:edgecolor=red" % inner_skull]
     cmd += ["--surface", "%s:color=yellow:edgecolor=yellow" % outer_skull]
-    cmd += ["--surface",
-            "%s:color=255,170,127:edgecolor=255,170,127" % outer_skin]
+    cmd += ["--surface", "%s:color=255,170,127:edgecolor=255,170,127" % outer_skin]
 
     run_subprocess(cmd, env=env, stdout=sys.stdout)
     print("[done]")
@@ -82,18 +83,27 @@ def run():
 
     parser = get_optparser(__file__)
 
-    subject = os.environ.get('SUBJECT')
+    subject = os.environ.get("SUBJECT")
     subjects_dir = get_subjects_dir()
     if subjects_dir is not None:
         subjects_dir = str(subjects_dir)
 
-    parser.add_option("-s", "--subject", dest="subject",
-                      help="Subject name", default=subject)
-    parser.add_option("-d", "--subjects-dir", dest="subjects_dir",
-                      help="Subjects directory", default=subjects_dir)
-    parser.add_option("-m", "--method", dest="method",
-                      help=("Method used to generate the BEM model. "
-                            "Can be flash or watershed."))
+    parser.add_option(
+        "-s", "--subject", dest="subject", help="Subject name", default=subject
+    )
+    parser.add_option(
+        "-d",
+        "--subjects-dir",
+        dest="subjects_dir",
+        help="Subjects directory",
+        default=subjects_dir,
+    )
+    parser.add_option(
+        "-m",
+        "--method",
+        dest="method",
+        help=("Method used to generate the BEM model. 
" "Can be flash or watershed."), + ) options, args = parser.parse_args() diff --git a/mne/commands/mne_kit2fiff.py b/mne/commands/mne_kit2fiff.py index 1317a154c8c..0c6b4545203 100644 --- a/mne/commands/mne_kit2fiff.py +++ b/mne/commands/mne_kit2fiff.py @@ -29,33 +29,50 @@ def run(): parser = get_optparser(__file__) - parser.add_option('--input', dest='input_fname', - help='Input data file name', metavar='filename') - parser.add_option('--mrk', dest='mrk_fname', - help='MEG Marker file name', metavar='filename') - parser.add_option('--elp', dest='elp_fname', - help='Headshape points file name', metavar='filename') - parser.add_option('--hsp', dest='hsp_fname', - help='Headshape file name', metavar='filename') - parser.add_option('--stim', dest='stim', - help='Colon Separated Stimulus Trigger Channels', - metavar='chs') - parser.add_option('--slope', dest='slope', help='Slope direction', - metavar='slope') - parser.add_option('--stimthresh', dest='stimthresh', default=1, - help='Threshold value for trigger channels', - metavar='value') - parser.add_option('--output', dest='out_fname', - help='Name of the resulting fiff file', - metavar='filename') - parser.add_option('--debug', dest='debug', action='http://github.com/store_true', - default=False, - help='Set logging level for terminal output to debug') + parser.add_option( + "--input", dest="input_fname", help="Input data file name", metavar="filename" + ) + parser.add_option( + "--mrk", dest="mrk_fname", help="MEG Marker file name", metavar="filename" + ) + parser.add_option( + "--elp", dest="elp_fname", help="Headshape points file name", metavar="filename" + ) + parser.add_option( + "--hsp", dest="hsp_fname", help="Headshape file name", metavar="filename" + ) + parser.add_option( + "--stim", + dest="stim", + help="Colon Separated Stimulus Trigger Channels", + metavar="chs", + ) + parser.add_option("--slope", dest="slope", help="Slope direction", metavar="slope") + parser.add_option( + "--stimthresh", + dest="stimthresh", + default=1, + help="Threshold value for trigger channels", + metavar="value", + ) + parser.add_option( + "--output", + dest="out_fname", + help="Name of the resulting fiff file", + metavar="filename", + ) + parser.add_option( + "--debug", + dest="debug", + action="http://github.com/store_true", + default=False, + help="Set logging level for terminal output to debug", + ) options, args = parser.parse_args() if options.debug: - mne.set_log_level('debug') + mne.set_log_level("debug") input_fname = options.input_fname if input_fname is None: @@ -63,8 +80,8 @@ def run(): from mne_kit_gui import kit2fiff # noqa except ImportError: raise ImportError( - 'The mne-kit-gui package is required, install it using ' - 'conda or pip') from None + "The mne-kit-gui package is required, install it using " "conda or pip" + ) from None kit2fiff() sys.exit(0) @@ -77,11 +94,17 @@ def run(): out_fname = options.out_fname if isinstance(stim, str): - stim = map(int, stim.split(':')) - - raw = read_raw_kit(input_fname=input_fname, mrk=mrk_fname, elp=elp_fname, - hsp=hsp_fname, stim=stim, slope=slope, - stimthresh=stimthresh) + stim = map(int, stim.split(":")) + + raw = read_raw_kit( + input_fname=input_fname, + mrk=mrk_fname, + elp=elp_fname, + hsp=hsp_fname, + stim=stim, + slope=slope, + stimthresh=stimthresh, + ) raw.save(out_fname) raw.close() diff --git a/mne/commands/mne_make_scalp_surfaces.py b/mne/commands/mne_make_scalp_surfaces.py index 9da7941384c..c5bf03e06a0 100644 --- a/mne/commands/mne_make_scalp_surfaces.py +++ 
b/mne/commands/mne_make_scalp_surfaces.py @@ -27,29 +27,60 @@ def run(): from mne.commands.utils import get_optparser, _add_verbose_flag parser = get_optparser(__file__) - subjects_dir = mne.get_config('SUBJECTS_DIR') + subjects_dir = mne.get_config("SUBJECTS_DIR") - parser.add_option('-o', '--overwrite', dest='overwrite', - action='http://github.com/store_true', - help='Overwrite previously computed surface') - parser.add_option('-s', '--subject', dest='subject', - help='The name of the subject', type='str') - parser.add_option('-m', '--mri', dest='mri', type='str', default='T1.mgz', - help='The MRI file to process using mkheadsurf.') - parser.add_option('-f', '--force', dest='force', action='http://github.com/store_true', - help='Force creation of the surface even if it has ' - 'some topological defects.') - parser.add_option('-t', '--threshold', dest='threshold', type='int', - default=20, help='Threshold value to use with the MRI.') - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - help="Subjects directory", default=subjects_dir) - parser.add_option("-n", "--no-decimate", dest="no_decimate", - help="Disable medium and sparse decimations " - "(dense only)", action='http://github.com/store_true') + parser.add_option( + "-o", + "--overwrite", + dest="overwrite", + action="http://github.com/store_true", + help="Overwrite previously computed surface", + ) + parser.add_option( + "-s", "--subject", dest="subject", help="The name of the subject", type="str" + ) + parser.add_option( + "-m", + "--mri", + dest="mri", + type="str", + default="T1.mgz", + help="The MRI file to process using mkheadsurf.", + ) + parser.add_option( + "-f", + "--force", + dest="force", + action="http://github.com/store_true", + help="Force creation of the surface even if it has " + "some topological defects.", + ) + parser.add_option( + "-t", + "--threshold", + dest="threshold", + type="int", + default=20, + help="Threshold value to use with the MRI.", + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=subjects_dir, + ) + parser.add_option( + "-n", + "--no-decimate", + dest="no_decimate", + help="Disable medium and sparse decimations " "(dense only)", + action="http://github.com/store_true", + ) _add_verbose_flag(parser) options, args = parser.parse_args() - subject = vars(options).get('subject', os.getenv('SUBJECT')) + subject = vars(options).get("subject", os.getenv("SUBJECT")) subjects_dir = options.subjects_dir if subject is None or subjects_dir is None: parser.print_help() @@ -62,7 +93,8 @@ def run(): no_decimate=options.no_decimate, threshold=options.threshold, mri=options.mri, - verbose=options.verbose) + verbose=options.verbose, + ) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_maxfilter.py b/mne/commands/mne_maxfilter.py index 4825b4d5553..182a2c6254b 100644 --- a/mne/commands/mne_maxfilter.py +++ b/mne/commands/mne_maxfilter.py @@ -25,71 +25,157 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-i", "--in", dest="in_fname", - help="Input raw FIF file", metavar="FILE") - parser.add_option("-o", dest="out_fname", - help="Output FIF file (if not set, suffix '_sss' will " - "be used)", metavar="FILE", default=None) - parser.add_option("--origin", dest="origin", - help="Head origin in mm, or a filename to read the " - "origin from. 
If not set it will be estimated from " - "headshape points", default=None) - parser.add_option("--origin-out", dest="origin_out", - help="Filename to use for computed origin", default=None) - parser.add_option("--frame", dest="frame", type="string", - help="Coordinate frame for head center ('device' or " - "'head')", default="device") - parser.add_option("--bad", dest="bad", type="string", - help="List of static bad channels", - default=None) - parser.add_option("--autobad", dest="autobad", type="string", - help="Set automated bad channel detection ('on', 'off', " - "'n')", default="off") - parser.add_option("--skip", dest="skip", - help="Skips raw data sequences, time intervals pairs in " - "s, e.g.: 0 30 120 150", default=None) - parser.add_option("--force", dest="force", action="http://github.com/store_true", - help="Ignore program warnings", - default=False) - parser.add_option("--st", dest="st", action="http://github.com/store_true", - help="Apply the time-domain MaxST extension", - default=False) - parser.add_option("--buflen", dest="st_buflen", type="float", - help="MaxSt buffer length in s", - default=16.0) - parser.add_option("--corr", dest="st_corr", type="float", - help="MaxSt subspace correlation", - default=0.96) - parser.add_option("--trans", dest="mv_trans", - help="Transforms the data into the coil definitions of " - "in_fname, or into the default frame", default=None) - parser.add_option("--movecomp", dest="mv_comp", action="http://github.com/store_true", - help="Estimates and compensates head movements in " - "continuous raw data", default=False) - parser.add_option("--headpos", dest="mv_headpos", action="http://github.com/store_true", - help="Estimates and stores head position parameters, " - "but does not compensate movements", default=False) - parser.add_option("--hp", dest="mv_hp", type="string", - help="Stores head position data in an ascii file", - default=None) - parser.add_option("--hpistep", dest="mv_hpistep", type="float", - help="Sets head position update interval in ms", - default=None) - parser.add_option("--hpisubt", dest="mv_hpisubt", type="string", - help="Subtracts hpi signals: sine amplitudes, amp + " - "baseline, or switch off", default=None) - parser.add_option("--nohpicons", dest="mv_hpicons", action="http://github.com/store_false", - help="Do not check initial consistency isotrak vs " - "hpifit", default=True) - parser.add_option("--linefreq", dest="linefreq", type="float", - help="Sets the basic line interference frequency (50 or " - "60 Hz)", default=None) - parser.add_option("--nooverwrite", dest="overwrite", action="http://github.com/store_false", - help="Do not overwrite output file if it already exists", - default=True) - parser.add_option("--args", dest="mx_args", type="string", - help="Additional command line arguments to pass to " - "MaxFilter", default="") + parser.add_option( + "-i", "--in", dest="in_fname", help="Input raw FIF file", metavar="FILE" + ) + parser.add_option( + "-o", + dest="out_fname", + help="Output FIF file (if not set, suffix '_sss' will " "be used)", + metavar="FILE", + default=None, + ) + parser.add_option( + "--origin", + dest="origin", + help="Head origin in mm, or a filename to read the " + "origin from. 
If not set it will be estimated from " + "headshape points", + default=None, + ) + parser.add_option( + "--origin-out", + dest="origin_out", + help="Filename to use for computed origin", + default=None, + ) + parser.add_option( + "--frame", + dest="frame", + type="string", + help="Coordinate frame for head center ('device' or " "'head')", + default="device", + ) + parser.add_option( + "--bad", + dest="bad", + type="string", + help="List of static bad channels", + default=None, + ) + parser.add_option( + "--autobad", + dest="autobad", + type="string", + help="Set automated bad channel detection ('on', 'off', " "'n')", + default="off", + ) + parser.add_option( + "--skip", + dest="skip", + help="Skips raw data sequences, time intervals pairs in " + "s, e.g.: 0 30 120 150", + default=None, + ) + parser.add_option( + "--force", + dest="force", + action="http://github.com/store_true", + help="Ignore program warnings", + default=False, + ) + parser.add_option( + "--st", + dest="st", + action="http://github.com/store_true", + help="Apply the time-domain MaxST extension", + default=False, + ) + parser.add_option( + "--buflen", + dest="st_buflen", + type="float", + help="MaxSt buffer length in s", + default=16.0, + ) + parser.add_option( + "--corr", + dest="st_corr", + type="float", + help="MaxSt subspace correlation", + default=0.96, + ) + parser.add_option( + "--trans", + dest="mv_trans", + help="Transforms the data into the coil definitions of " + "in_fname, or into the default frame", + default=None, + ) + parser.add_option( + "--movecomp", + dest="mv_comp", + action="http://github.com/store_true", + help="Estimates and compensates head movements in " "continuous raw data", + default=False, + ) + parser.add_option( + "--headpos", + dest="mv_headpos", + action="http://github.com/store_true", + help="Estimates and stores head position parameters, " + "but does not compensate movements", + default=False, + ) + parser.add_option( + "--hp", + dest="mv_hp", + type="string", + help="Stores head position data in an ascii file", + default=None, + ) + parser.add_option( + "--hpistep", + dest="mv_hpistep", + type="float", + help="Sets head position update interval in ms", + default=None, + ) + parser.add_option( + "--hpisubt", + dest="mv_hpisubt", + type="string", + help="Subtracts hpi signals: sine amplitudes, amp + " "baseline, or switch off", + default=None, + ) + parser.add_option( + "--nohpicons", + dest="mv_hpicons", + action="http://github.com/store_false", + help="Do not check initial consistency isotrak vs " "hpifit", + default=True, + ) + parser.add_option( + "--linefreq", + dest="linefreq", + type="float", + help="Sets the basic line interference frequency (50 or " "60 Hz)", + default=None, + ) + parser.add_option( + "--nooverwrite", + dest="overwrite", + action="http://github.com/store_false", + help="Do not overwrite output file if it already exists", + default=True, + ) + parser.add_option( + "--args", + dest="mx_args", + type="string", + help="Additional command line arguments to pass to " "MaxFilter", + default="", + ) options, args = parser.parse_args() @@ -121,30 +207,48 @@ def run(): overwrite = options.overwrite mx_args = options.mx_args - if in_fname.endswith('_raw.fif') or in_fname.endswith('-raw.fif'): + if in_fname.endswith("_raw.fif") or in_fname.endswith("-raw.fif"): prefix = in_fname[:-8] else: prefix = in_fname[:-4] if out_fname is None: if st: - out_fname = prefix + '_tsss.fif' + out_fname = prefix + "_tsss.fif" else: - 
out_fname = prefix + '_sss.fif' + out_fname = prefix + "_sss.fif" if origin is not None and os.path.exists(origin): - with open(origin, 'r') as fid: + with open(origin, "r") as fid: origin = fid.readlines()[0].strip() origin = mne.preprocessing.apply_maxfilter( - in_fname, out_fname, origin, frame, - bad, autobad, skip, force, st, st_buflen, st_corr, mv_trans, - mv_comp, mv_headpos, mv_hp, mv_hpistep, mv_hpisubt, mv_hpicons, - linefreq, mx_args, overwrite) + in_fname, + out_fname, + origin, + frame, + bad, + autobad, + skip, + force, + st, + st_buflen, + st_corr, + mv_trans, + mv_comp, + mv_headpos, + mv_hp, + mv_hpistep, + mv_hpisubt, + mv_hpicons, + linefreq, + mx_args, + overwrite, + ) if origin_out is not None: - with open(origin_out, 'w') as fid: - fid.write(origin + '\n') + with open(origin_out, "w") as fid: + fid.write(origin + "\n") mne.utils.run_command_if_main() diff --git a/mne/commands/mne_prepare_bem_model.py b/mne/commands/mne_prepare_bem_model.py index da308bb737e..ae43ae9533a 100644 --- a/mne/commands/mne_prepare_bem_model.py +++ b/mne/commands/mne_prepare_bem_model.py @@ -20,18 +20,25 @@ def run(): parser = get_optparser(__file__) - parser.add_option('--bem', dest='bem_fname', - help='The name of the file containing the ' - 'triangulations of the BEM surfaces and the ' - 'conductivities of the compartments. The standard ' - 'ending for this file is -bem.fif.', - metavar="FILE") - parser.add_option('--sol', dest='bem_sol_fname', - help='The name of the resulting file containing BEM ' - 'solution (geometry matrix). It uses the linear ' - 'collocation approach. The file should end with ' - '-bem-sof.fif.', - metavar='FILE', default=None) + parser.add_option( + "--bem", + dest="bem_fname", + help="The name of the file containing the " + "triangulations of the BEM surfaces and the " + "conductivities of the compartments. The standard " + "ending for this file is -bem.fif.", + metavar="FILE", + ) + parser.add_option( + "--sol", + dest="bem_sol_fname", + help="The name of the resulting file containing BEM " + "solution (geometry matrix). It uses the linear " + "collocation approach. 
The file should end with " + "-bem-sof.fif.", + metavar="FILE", + default=None, + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -45,10 +52,9 @@ def run(): if bem_sol_fname is None: base, _ = os.path.splitext(bem_fname) - bem_sol_fname = base + '-sol.fif' + bem_sol_fname = base + "-sol.fif" - bem_model = mne.read_bem_surfaces(bem_fname, patch_stats=False, - verbose=verbose) + bem_model = mne.read_bem_surfaces(bem_fname, patch_stats=False, verbose=verbose) bem_solution = mne.make_bem_solution(bem_model, verbose=verbose) mne.write_bem_solution(bem_sol_fname, bem_solution) diff --git a/mne/commands/mne_report.py b/mne/commands/mne_report.py index 2d96570f26f..79818d52bab 100644 --- a/mne/commands/mne_report.py +++ b/mne/commands/mne_report.py @@ -78,7 +78,7 @@ @verbose def log_elapsed(t, verbose=None): """Log elapsed time.""" - logger.info('Report complete in %s seconds' % round(t, 1)) + logger.info("Report complete in %s seconds" % round(t, 1)) def run(): @@ -87,36 +87,72 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-p", "--path", dest="path", - help="Path to folder who MNE-Report must be created") - parser.add_option("-i", "--info", dest="info_fname", - help="File from which info dictionary is to be read", - metavar="FILE") - parser.add_option("-c", "--cov", dest="cov_fname", - help="File from which noise covariance is to be read", - metavar="FILE") - parser.add_option("--bmin", dest="bmin", - help="Time at which baseline correction starts for " - "evokeds", default=None) - parser.add_option("--bmax", dest="bmax", - help="Time at which baseline correction stops for " - "evokeds", default=None) - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - help="The subjects directory") - parser.add_option("-s", "--subject", dest="subject", - help="The subject name") - parser.add_option("--no-browser", dest="no_browser", action='http://github.com/store_false', - help="Do not open MNE-Report in browser") - parser.add_option("--overwrite", dest="overwrite", action='http://github.com/store_false', - help="Overwrite html report if it already exists") - parser.add_option("-j", "--jobs", dest="n_jobs", help="Number of jobs to" - " run in parallel") - parser.add_option("-m", "--mri-decim", type="int", dest="mri_decim", - default=2, help="Integer factor used to decimate " - "BEM plots") - parser.add_option("--image-format", type="str", dest="image_format", - default='png', help="Image format to use " - "(can be 'png' or 'svg')") + parser.add_option( + "-p", + "--path", + dest="path", + help="Path to folder who MNE-Report must be created", + ) + parser.add_option( + "-i", + "--info", + dest="info_fname", + help="File from which info dictionary is to be read", + metavar="FILE", + ) + parser.add_option( + "-c", + "--cov", + dest="cov_fname", + help="File from which noise covariance is to be read", + metavar="FILE", + ) + parser.add_option( + "--bmin", + dest="bmin", + help="Time at which baseline correction starts for " "evokeds", + default=None, + ) + parser.add_option( + "--bmax", + dest="bmax", + help="Time at which baseline correction stops for " "evokeds", + default=None, + ) + parser.add_option( + "-d", "--subjects-dir", dest="subjects_dir", help="The subjects directory" + ) + parser.add_option("-s", "--subject", dest="subject", help="The subject name") + parser.add_option( + "--no-browser", + dest="no_browser", + action="http://github.com/store_false", + help="Do not open MNE-Report in browser", + ) + parser.add_option( + 
"--overwrite", + dest="overwrite", + action="http://github.com/store_false", + help="Overwrite html report if it already exists", + ) + parser.add_option( + "-j", "--jobs", dest="n_jobs", help="Number of jobs to" " run in parallel" + ) + parser.add_option( + "-m", + "--mri-decim", + type="int", + dest="mri_decim", + default=2, + help="Integer factor used to decimate " "BEM plots", + ) + parser.add_option( + "--image-format", + type="str", + dest="image_format", + default="png", + help="Image format to use " "(can be 'png' or 'svg')", + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -144,12 +180,16 @@ def run(): baseline = (bmin, bmax) t0 = time.time() - report = Report(info_fname, subjects_dir=subjects_dir, - subject=subject, baseline=baseline, - cov_fname=cov_fname, verbose=verbose, - image_format=image_format) - report.parse_folder(path, verbose=verbose, n_jobs=n_jobs, - mri_decim=mri_decim) + report = Report( + info_fname, + subjects_dir=subjects_dir, + subject=subject, + baseline=baseline, + cov_fname=cov_fname, + verbose=verbose, + image_format=image_format, + ) + report.parse_folder(path, verbose=verbose, n_jobs=n_jobs, mri_decim=mri_decim) log_elapsed(time.time() - t0, verbose=verbose) report.save(open_browser=open_browser, overwrite=overwrite) diff --git a/mne/commands/mne_setup_forward_model.py b/mne/commands/mne_setup_forward_model.py index 239decefbfe..df7fc5fff4b 100644 --- a/mne/commands/mne_setup_forward_model.py +++ b/mne/commands/mne_setup_forward_model.py @@ -21,51 +21,66 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-s", "--subject", - dest="subject", - help="Subject name (required)", - default=None) - parser.add_option("--model", - dest="model", - help="Output file name. Use a name /-bem.fif", - default=None, - type='string') - parser.add_option('--ico', - dest='ico', - help='The surface ico downsampling to use, e.g. ' - ' 5=20484, 4=5120, 3=1280. If None, no subsampling' - ' is applied.', - default=None, - type='int') - parser.add_option('--brainc', - dest='brainc', - help='Defines the brain compartment conductivity. ' - 'The default value is 0.3 S/m.', - default=0.3, - type='float') - parser.add_option('--skullc', - dest='skullc', - help='Defines the skull compartment conductivity. ' - 'The default value is 0.006 S/m.', - default=None, - type='float') - parser.add_option('--scalpc', - dest='scalpc', - help='Defines the scalp compartment conductivity. ' - 'The default value is 0.3 S/m.', - default=None, - type='float') - parser.add_option('--homog', - dest='homog', - help='Use a single compartment model (brain only) ' - 'instead a three layer one (scalp, skull, and ' - ' brain). If this flag is specified, the options ' - '--skullc and --scalpc are irrelevant.', - default=None, action="http://github.com/store_true") - parser.add_option('-d', '--subjects-dir', - dest='subjects_dir', - help='Subjects directory', - default=None) + parser.add_option( + "-s", "--subject", dest="subject", help="Subject name (required)", default=None + ) + parser.add_option( + "--model", + dest="model", + help="Output file name. Use a name /-bem.fif", + default=None, + type="string", + ) + parser.add_option( + "--ico", + dest="ico", + help="The surface ico downsampling to use, e.g. " + " 5=20484, 4=5120, 3=1280. If None, no subsampling" + " is applied.", + default=None, + type="int", + ) + parser.add_option( + "--brainc", + dest="brainc", + help="Defines the brain compartment conductivity. 
" + "The default value is 0.3 S/m.", + default=0.3, + type="float", + ) + parser.add_option( + "--skullc", + dest="skullc", + help="Defines the skull compartment conductivity. " + "The default value is 0.006 S/m.", + default=None, + type="float", + ) + parser.add_option( + "--scalpc", + dest="scalpc", + help="Defines the scalp compartment conductivity. " + "The default value is 0.3 S/m.", + default=None, + type="float", + ) + parser.add_option( + "--homog", + dest="homog", + help="Use a single compartment model (brain only) " + "instead a three layer one (scalp, skull, and " + " brain). If this flag is specified, the options " + "--skullc and --scalpc are irrelevant.", + default=None, + action="http://github.com/store_true", + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=None, + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -85,11 +100,15 @@ def run(): # Parse conductivity option if homog is True: if skullc is not None: - warn('Trying to set the skull conductivity for a single layer ' - 'model. To use a 3 layer model, do not set the --homog flag.') + warn( + "Trying to set the skull conductivity for a single layer " + "model. To use a 3 layer model, do not set the --homog flag." + ) if scalpc is not None: - warn('Trying to set the scalp conductivity for a single layer ' - 'model. To use a 3 layer model, do not set the --homog flag.') + warn( + "Trying to set the scalp conductivity for a single layer " + "model. To use a 3 layer model, do not set the --homog flag." + ) # Single layer conductivity = [brainc] else: @@ -99,17 +118,19 @@ def run(): scalpc = 0.3 conductivity = [brainc, skullc, scalpc] # Create source space - bem_model = mne.make_bem_model(subject, - ico=ico, - conductivity=conductivity, - subjects_dir=subjects_dir, - verbose=verbose) + bem_model = mne.make_bem_model( + subject, + ico=ico, + conductivity=conductivity, + subjects_dir=subjects_dir, + verbose=verbose, + ) # Generate filename if fname is None: - n_faces = list(str(len(surface['tris'])) for surface in bem_model) - fname = subject + '-' + '-'.join(n_faces) + '-bem.fif' + n_faces = list(str(len(surface["tris"])) for surface in bem_model) + fname = subject + "-" + "-".join(n_faces) + "-bem.fif" else: - if not (fname.endswith('-bem.fif') or fname.endswith('_bem.fif')): + if not (fname.endswith("-bem.fif") or fname.endswith("_bem.fif")): fname = fname + "-bem.fif" # Save to subject's directory subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) diff --git a/mne/commands/mne_setup_source_space.py b/mne/commands/mne_setup_source_space.py index e8b14b78db3..49bf0b9ed06 100644 --- a/mne/commands/mne_setup_source_space.py +++ b/mne/commands/mne_setup_source_space.py @@ -22,64 +22,89 @@ def run(): """Run command.""" from mne.commands.utils import get_optparser, _add_verbose_flag + parser = get_optparser(__file__) - parser.add_option('-s', '--subject', - dest='subject', - help='Subject name (required)', - default=None) - parser.add_option('--src', dest='fname', - help='Output file name. Use a name /-src.fif', - metavar='FILE', default=None) - parser.add_option('--morph', - dest='subject_to', - help='morph the source space to this subject', - default=None) - parser.add_option('--surf', - dest='surface', - help='The surface to use. (default to white)', - default='white', - type='string') - parser.add_option('--spacing', - dest='spacing', - help='Specifies the approximate grid spacing of the ' - 'source space in mm. 
(default to 7mm)', - default=None, - type='int') - parser.add_option('--ico', - dest='ico', - help='use the recursively subdivided icosahedron ' - 'to create the source space.', - default=None, - type='int') - parser.add_option('--oct', - dest='oct', - help='use the recursively subdivided octahedron ' - 'to create the source space.', - default=None, - type='int') - parser.add_option('-d', '--subjects-dir', - dest='subjects_dir', - help='Subjects directory', - default=None) - parser.add_option('-n', '--n-jobs', - dest='n_jobs', - help='The number of jobs to run in parallel ' - '(default 1). Requires the joblib package. ' - 'Will use at most 2 jobs' - ' (one for each hemisphere).', - default=1, - type='int') - parser.add_option('--add-dist', - dest='add_dist', - help='Add distances. Can be "True", "False", or "patch" ' - 'to only compute cortical patch statistics (like the ' - '--cps option in MNE-C; requires SciPy >= 1.3)', - default='True') - parser.add_option('-o', '--overwrite', - dest='overwrite', - help='to write over existing files', - default=None, action="http://github.com/store_true") + parser.add_option( + "-s", "--subject", dest="subject", help="Subject name (required)", default=None + ) + parser.add_option( + "--src", + dest="fname", + help="Output file name. Use a name /-src.fif", + metavar="FILE", + default=None, + ) + parser.add_option( + "--morph", + dest="subject_to", + help="morph the source space to this subject", + default=None, + ) + parser.add_option( + "--surf", + dest="surface", + help="The surface to use. (default to white)", + default="white", + type="string", + ) + parser.add_option( + "--spacing", + dest="spacing", + help="Specifies the approximate grid spacing of the " + "source space in mm. (default to 7mm)", + default=None, + type="int", + ) + parser.add_option( + "--ico", + dest="ico", + help="use the recursively subdivided icosahedron " + "to create the source space.", + default=None, + type="int", + ) + parser.add_option( + "--oct", + dest="oct", + help="use the recursively subdivided octahedron " "to create the source space.", + default=None, + type="int", + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=None, + ) + parser.add_option( + "-n", + "--n-jobs", + dest="n_jobs", + help="The number of jobs to run in parallel " + "(default 1). Requires the joblib package. " + "Will use at most 2 jobs" + " (one for each hemisphere).", + default=1, + type="int", + ) + parser.add_option( + "--add-dist", + dest="add_dist", + help='Add distances. 
Can be "True", "False", or "patch" ' + "to only compute cortical patch statistics (like the " + "--cps option in MNE-C; requires SciPy >= 1.3)", + default="True", + ) + parser.add_option( + "-o", + "--overwrite", + dest="overwrite", + help="to write over existing files", + default=None, + action="http://github.com/store_true", + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -98,8 +123,8 @@ def run(): surface = options.surface n_jobs = options.n_jobs add_dist = options.add_dist - _check_option('add_dist', add_dist, ('True', 'False', 'patch')) - add_dist = {'True': True, 'False': False, 'patch': 'patch'}[add_dist] + _check_option("add_dist", add_dist, ("True", "False", "patch")) + add_dist = {"True": True, "False": False, "patch": "patch"}[add_dist] verbose = True if options.verbose is not None else False overwrite = True if options.overwrite is not None else False @@ -107,10 +132,10 @@ def run(): spacing_options = [ico, oct, spacing] n_options = len([x for x in spacing_options if x is not None]) if n_options > 1: - raise ValueError('Only one spacing option can be set at the same time') + raise ValueError("Only one spacing option can be set at the same time") elif n_options == 0: # Default to oct6 - use_spacing = 'oct6' + use_spacing = "oct6" elif n_options == 1: if ico is not None: use_spacing = "ico" + str(ico) @@ -121,23 +146,31 @@ def run(): # Generate filename if fname is None: if subject_to is None: - fname = subject + '-' + str(use_spacing) + '-src.fif' + fname = subject + "-" + str(use_spacing) + "-src.fif" else: - fname = (subject_to + '-' + subject + '-' + - str(use_spacing) + '-src.fif') + fname = subject_to + "-" + subject + "-" + str(use_spacing) + "-src.fif" else: - if not (fname.endswith('_src.fif') or fname.endswith('-src.fif')): + if not (fname.endswith("_src.fif") or fname.endswith("-src.fif")): fname = fname + "-src.fif" # Create source space - src = mne.setup_source_space(subject=subject, spacing=use_spacing, - surface=surface, subjects_dir=subjects_dir, - n_jobs=n_jobs, add_dist=add_dist, - verbose=verbose) + src = mne.setup_source_space( + subject=subject, + spacing=use_spacing, + surface=surface, + subjects_dir=subjects_dir, + n_jobs=n_jobs, + add_dist=add_dist, + verbose=verbose, + ) # Morph source space if --morph is set if subject_to is not None: - src = mne.morph_source_spaces(src, subject_to=subject_to, - subjects_dir=subjects_dir, - surf=surface, verbose=verbose) + src = mne.morph_source_spaces( + src, + subject_to=subject_to, + subjects_dir=subjects_dir, + surf=surface, + verbose=verbose, + ) # Save source space to file src.save(fname=fname, overwrite=overwrite) diff --git a/mne/commands/mne_show_fiff.py b/mne/commands/mne_show_fiff.py index be31cde2ad8..ed6fccdf89e 100644 --- a/mne/commands/mne_show_fiff.py +++ b/mne/commands/mne_show_fiff.py @@ -24,10 +24,14 @@ def run(): """Run command.""" - parser = mne.commands.utils.get_optparser( - __file__, usage='mne show_fiff ') - parser.add_option("-t", "--tag", dest="tag", - help="provide information about this tag", metavar="TAG") + parser = mne.commands.utils.get_optparser(__file__, usage="mne show_fiff ") + parser.add_option( + "-t", + "--tag", + dest="tag", + help="provide information about this tag", + metavar="TAG", + ) options, args = parser.parse_args() if len(args) != 1: parser.print_help() diff --git a/mne/commands/mne_show_info.py b/mne/commands/mne_show_info.py index 44e1fa79141..dc39491fb6c 100644 --- a/mne/commands/mne_show_info.py +++ b/mne/commands/mne_show_info.py @@ 
-17,8 +17,7 @@ def run(): """Run command.""" - parser = mne.commands.utils.get_optparser( - __file__, usage='mne show_info ') + parser = mne.commands.utils.get_optparser(__file__, usage="mne show_info ") options, args = parser.parse_args() if len(args) != 1: parser.print_help() @@ -26,8 +25,8 @@ def run(): fname = args[0] - if not fname.endswith('.fif'): - raise ValueError('%s does not seem to be a .fif file.' % fname) + if not fname.endswith(".fif"): + raise ValueError("%s does not seem to be a .fif file." % fname) info = mne.io.read_info(fname) print("File : %s" % fname) diff --git a/mne/commands/mne_surf2bem.py b/mne/commands/mne_surf2bem.py index 4cb5ade9662..93a154b2477 100644 --- a/mne/commands/mne_surf2bem.py +++ b/mne/commands/mne_surf2bem.py @@ -25,12 +25,19 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-s", "--surf", dest="surf", - help="Surface in Freesurfer format", metavar="FILE") - parser.add_option("-f", "--fif", dest="fif", - help="FIF file produced", metavar="FILE") - parser.add_option("-i", "--id", dest="id", default=4, - help=("Surface Id (e.g. 4 for head surface)")) + parser.add_option( + "-s", "--surf", dest="surf", help="Surface in Freesurfer format", metavar="FILE" + ) + parser.add_option( + "-f", "--fif", dest="fif", help="FIF file produced", metavar="FILE" + ) + parser.add_option( + "-i", + "--id", + dest="id", + default=4, + help=("Surface Id (e.g. 4 for head surface)"), + ) options, args = parser.parse_args() @@ -39,8 +46,7 @@ def run(): sys.exit(1) print("Converting %s to BEM FIF file." % options.surf) - surf = mne.bem._surfaces_to_bem([options.surf], [int(options.id)], - sigmas=[1]) + surf = mne.bem._surfaces_to_bem([options.surf], [int(options.id)], sigmas=[1]) mne.write_bem_surfaces(options.fif, surf) diff --git a/mne/commands/mne_sys_info.py b/mne/commands/mne_sys_info.py index a09994de8f9..075ff446681 100644 --- a/mne/commands/mne_sys_info.py +++ b/mne/commands/mne_sys_info.py @@ -17,17 +17,31 @@ def run(): """Run command.""" - parser = mne.commands.utils.get_optparser(__file__, usage='mne sys_info') - parser.add_option('-p', '--show-paths', dest='show_paths', - help='Show module paths', action='http://github.com/store_true') - parser.add_option('-d', '--developer', dest='developer', - help='Show additional developer module information', - action='http://github.com/store_true') - parser.add_option('-a', '--ascii', dest='unicode', - help='Use ASCII instead of unicode symbols', - action='http://github.com/store_false', default=True) + parser = mne.commands.utils.get_optparser(__file__, usage="mne sys_info") + parser.add_option( + "-p", + "--show-paths", + dest="show_paths", + help="Show module paths", + action="http://github.com/store_true", + ) + parser.add_option( + "-d", + "--developer", + dest="developer", + help="Show additional developer module information", + action="http://github.com/store_true", + ) + parser.add_option( + "-a", + "--ascii", + dest="unicode", + help="Use ASCII instead of unicode symbols", + action="http://github.com/store_false", + default=True, + ) options, args = parser.parse_args() - dependencies = 'developer' if options.developer else 'user' + dependencies = "developer" if options.developer else "user" if len(args) != 0: parser.print_help() sys.exit(1) @@ -35,7 +49,7 @@ def run(): mne.sys_info( show_paths=options.show_paths, dependencies=dependencies, - unicode=options.unicode + unicode=options.unicode, ) diff --git a/mne/commands/mne_watershed_bem.py 
b/mne/commands/mne_watershed_bem.py index b69a2801fd6..c182c7a0ded 100644 --- a/mne/commands/mne_watershed_bem.py +++ b/mne/commands/mne_watershed_bem.py @@ -23,35 +23,73 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-s", "--subject", dest="subject", - help="Subject name (required)", default=None) - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - help="Subjects directory", default=None) - parser.add_option("-o", "--overwrite", dest="overwrite", - help="Write over existing files", action="http://github.com/store_true") - parser.add_option("-v", "--volume", dest="volume", - help="Defaults to T1", default='T1') - parser.add_option("-a", "--atlas", dest="atlas", - help="Specify the --atlas option for mri_watershed", - default=False, action="http://github.com/store_true") - parser.add_option("-g", "--gcaatlas", dest="gcaatlas", - help="Specify the --brain_atlas option for " - "mri_watershed", default=False, action="http://github.com/store_true") - parser.add_option("-p", "--preflood", dest="preflood", - help="Change the preflood height", default=None) - parser.add_option("--copy", dest="copy", - help="Use copies instead of symlinks for surfaces", - action="http://github.com/store_true") - parser.add_option("-t", "--T1", dest="T1", - help="Whether or not to pass the -T1 flag " - "(can be true, false, 0, or 1). " - "By default it takes the same value as gcaatlas.", - default=None) - parser.add_option("-b", "--brainmask", dest="brainmask", - help="The filename for the brainmask output file " - "relative to the " - "$SUBJECTS_DIR/$SUBJECT/bem/watershed/ directory.", - default="ws") + parser.add_option( + "-s", "--subject", dest="subject", help="Subject name (required)", default=None + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=None, + ) + parser.add_option( + "-o", + "--overwrite", + dest="overwrite", + help="Write over existing files", + action="http://github.com/store_true", + ) + parser.add_option( + "-v", "--volume", dest="volume", help="Defaults to T1", default="T1" + ) + parser.add_option( + "-a", + "--atlas", + dest="atlas", + help="Specify the --atlas option for mri_watershed", + default=False, + action="http://github.com/store_true", + ) + parser.add_option( + "-g", + "--gcaatlas", + dest="gcaatlas", + help="Specify the --brain_atlas option for " "mri_watershed", + default=False, + action="http://github.com/store_true", + ) + parser.add_option( + "-p", + "--preflood", + dest="preflood", + help="Change the preflood height", + default=None, + ) + parser.add_option( + "--copy", + dest="copy", + help="Use copies instead of symlinks for surfaces", + action="http://github.com/store_true", + ) + parser.add_option( + "-t", + "--T1", + dest="T1", + help="Whether or not to pass the -T1 flag " + "(can be true, false, 0, or 1). 
" + "By default it takes the same value as gcaatlas.", + default=None, + ) + parser.add_option( + "-b", + "--brainmask", + dest="brainmask", + help="The filename for the brainmask output file " + "relative to the " + "$SUBJECTS_DIR/$SUBJECT/bem/watershed/ directory.", + default="ws", + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -72,14 +110,23 @@ def run(): T1 = options.T1 if T1 is not None: T1 = T1.lower() - _check_option("--T1", T1, ('true', 'false', '0', '1')) - T1 = T1 in ('true', '1') + _check_option("--T1", T1, ("true", "false", "0", "1")) + T1 = T1 in ("true", "1") verbose = options.verbose - make_watershed_bem(subject=subject, subjects_dir=subjects_dir, - overwrite=overwrite, volume=volume, atlas=atlas, - gcaatlas=gcaatlas, preflood=preflood, copy=copy, - T1=T1, brainmask=brainmask, verbose=verbose) + make_watershed_bem( + subject=subject, + subjects_dir=subjects_dir, + overwrite=overwrite, + volume=volume, + atlas=atlas, + gcaatlas=gcaatlas, + preflood=preflood, + copy=copy, + T1=T1, + brainmask=brainmask, + verbose=verbose, + ) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_what.py b/mne/commands/mne_what.py index 5d281facd0c..ab4a9d5ea8f 100644 --- a/mne/commands/mne_what.py +++ b/mne/commands/mne_what.py @@ -17,7 +17,8 @@ def run(): """Run command.""" from mne.commands.utils import get_optparser - parser = get_optparser(__file__, usage='usage: %prog fname [fname2 ...]') + + parser = get_optparser(__file__, usage="usage: %prog fname [fname2 ...]") options, args = parser.parse_args() for arg in args: print(mne.what(arg)) diff --git a/mne/commands/tests/test_commands.py b/mne/commands/tests/test_commands.py index 995edae59b9..c3bac034339 100644 --- a/mne/commands/tests/test_commands.py +++ b/mne/commands/tests/test_commands.py @@ -9,48 +9,74 @@ from numpy.testing import assert_equal, assert_allclose import mne -from mne import (concatenate_raws, read_bem_surfaces, read_surface, - read_source_spaces, read_bem_solution) +from mne import ( + concatenate_raws, + read_bem_surfaces, + read_surface, + read_source_spaces, + read_bem_solution, +) from mne.bem import ConductorModel, convert_flash_mris -from mne.commands import (mne_browse_raw, mne_bti2fiff, mne_clean_eog_ecg, - mne_compute_proj_ecg, mne_compute_proj_eog, - mne_coreg, mne_kit2fiff, - mne_make_scalp_surfaces, mne_maxfilter, - mne_report, mne_surf2bem, mne_watershed_bem, - mne_compare_fiff, mne_flash_bem, mne_show_fiff, - mne_show_info, mne_what, mne_setup_source_space, - mne_setup_forward_model, mne_anonymize, - mne_prepare_bem_model, mne_sys_info) +from mne.commands import ( + mne_browse_raw, + mne_bti2fiff, + mne_clean_eog_ecg, + mne_compute_proj_ecg, + mne_compute_proj_eog, + mne_coreg, + mne_kit2fiff, + mne_make_scalp_surfaces, + mne_maxfilter, + mne_report, + mne_surf2bem, + mne_watershed_bem, + mne_compare_fiff, + mne_flash_bem, + mne_show_fiff, + mne_show_info, + mne_what, + mne_setup_source_space, + mne_setup_forward_model, + mne_anonymize, + mne_prepare_bem_model, + mne_sys_info, +) from mne.datasets import testing from mne.io import read_raw_fif, read_info -from mne.utils import (requires_mne, requires_freesurfer, ArgvSetter, - _stamp_to_dt, _record_warnings) +from mne.utils import ( + requires_mne, + requires_freesurfer, + ArgvSetter, + _stamp_to_dt, + _record_warnings, +) -base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') -raw_fname = op.join(base_dir, 'test_raw.fif') +base_dir = op.join(op.dirname(__file__), "..", "..", "io", "tests", "data") 
+raw_fname = op.join(base_dir, "test_raw.fif") testing_path = testing.data_path(download=False) -subjects_dir = op.join(testing_path, 'subjects') -bem_model_fname = op.join(testing_path, 'subjects', - 'sample', 'bem', 'sample-320-320-320-bem.fif') +subjects_dir = op.join(testing_path, "subjects") +bem_model_fname = op.join( + testing_path, "subjects", "sample", "bem", "sample-320-320-320-bem.fif" +) def check_usage(module, force_help=False): """Ensure we print usage.""" - args = ('--help',) if force_help else () + args = ("--help",) if force_help else () with ArgvSetter(args) as out: try: module.run() except SystemExit: pass - assert 'Usage: ' in out.stdout.getvalue() + assert "Usage: " in out.stdout.getvalue() @pytest.mark.slowtest def test_browse_raw(): """Test mne browse_raw.""" check_usage(mne_browse_raw) - with ArgvSetter(('--raw', raw_fname)): + with ArgvSetter(("--raw", raw_fname)): with _record_warnings(): # mpl show warning mne_browse_raw.run() @@ -60,7 +86,7 @@ def test_what(): check_usage(mne_browse_raw) with ArgvSetter((raw_fname,)) as out: mne_what.run() - assert 'raw' == out.stdout.getvalue().strip() + assert "raw" == out.stdout.getvalue().strip() def test_bti2fiff(): @@ -78,7 +104,7 @@ def test_show_fiff(): check_usage(mne_show_fiff) with ArgvSetter((raw_fname,)): mne_show_fiff.run() - with ArgvSetter((raw_fname, '--tag=102')): + with ArgvSetter((raw_fname, "--tag=102")): mne_show_fiff.run() @@ -87,42 +113,40 @@ def test_clean_eog_ecg(tmp_path): """Test mne clean_eog_ecg.""" check_usage(mne_clean_eog_ecg) tempdir = str(tmp_path) - raw = concatenate_raws([read_raw_fif(f) - for f in [raw_fname, raw_fname, raw_fname]]) - raw.info['bads'] = ['MEG 2443'] + raw = concatenate_raws([read_raw_fif(f) for f in [raw_fname, raw_fname, raw_fname]]) + raw.info["bads"] = ["MEG 2443"] use_fname = op.join(tempdir, op.basename(raw_fname)) raw.save(use_fname) - with ArgvSetter(('-i', use_fname, '--quiet')): + with ArgvSetter(("-i", use_fname, "--quiet")): mne_clean_eog_ecg.run() - for key, count in (('proj', 2), ('-eve', 3)): - fnames = glob.glob(op.join(tempdir, '*%s.fif' % key)) + for key, count in (("proj", 2), ("-eve", 3)): + fnames = glob.glob(op.join(tempdir, "*%s.fif" % key)) assert len(fnames) == count @pytest.mark.slowtest -@pytest.mark.parametrize('fun', (mne_compute_proj_ecg, mne_compute_proj_eog)) +@pytest.mark.parametrize("fun", (mne_compute_proj_ecg, mne_compute_proj_eog)) def test_compute_proj_exg(tmp_path, fun): """Test mne compute_proj_ecg/eog.""" check_usage(fun) tempdir = str(tmp_path) use_fname = op.join(tempdir, op.basename(raw_fname)) - bad_fname = op.join(tempdir, 'bads.txt') - with open(bad_fname, 'w') as fid: - fid.write('MEG 2443\n') + bad_fname = op.join(tempdir, "bads.txt") + with open(bad_fname, "w") as fid: + fid.write("MEG 2443\n") shutil.copyfile(raw_fname, use_fname) - with ArgvSetter(('-i', use_fname, '--bad=' + bad_fname, - '--rej-eeg', '150')): + with ArgvSetter(("-i", use_fname, "--bad=" + bad_fname, "--rej-eeg", "150")): with _record_warnings(): # samples, sometimes fun.run() - fnames = glob.glob(op.join(tempdir, '*proj.fif')) + fnames = glob.glob(op.join(tempdir, "*proj.fif")) assert len(fnames) == 1 - fnames = glob.glob(op.join(tempdir, '*-eve.fif')) + fnames = glob.glob(op.join(tempdir, "*-eve.fif")) assert len(fnames) == 1 def test_coreg(): """Test mne coreg.""" - assert hasattr(mne_coreg, 'run') + assert hasattr(mne_coreg, "run") def test_kit2fiff(): @@ -136,60 +160,73 @@ def test_kit2fiff(): @testing.requires_testing_data def 
test_make_scalp_surfaces(tmp_path, monkeypatch): """Test mne make_scalp_surfaces.""" - pytest.importorskip('nibabel') - pytest.importorskip('pyvista') + pytest.importorskip("nibabel") + pytest.importorskip("pyvista") check_usage(mne_make_scalp_surfaces) - has = 'SUBJECTS_DIR' in os.environ + has = "SUBJECTS_DIR" in os.environ # Copy necessary files to avoid FreeSurfer call tempdir = str(tmp_path) - surf_path = op.join(subjects_dir, 'sample', 'surf') - surf_path_new = op.join(tempdir, 'sample', 'surf') - os.mkdir(op.join(tempdir, 'sample')) + surf_path = op.join(subjects_dir, "sample", "surf") + surf_path_new = op.join(tempdir, "sample", "surf") + os.mkdir(op.join(tempdir, "sample")) os.mkdir(surf_path_new) - subj_dir = op.join(tempdir, 'sample', 'bem') + subj_dir = op.join(tempdir, "sample", "bem") os.mkdir(subj_dir) - cmd = ('-s', 'sample', '--subjects-dir', tempdir) + cmd = ("-s", "sample", "--subjects-dir", tempdir) monkeypatch.setattr( - mne.bem, 'decimate_surface', - lambda points, triangles, n_triangles: (points, triangles)) - dense_fname = op.join(subj_dir, 'sample-head-dense.fif') - medium_fname = op.join(subj_dir, 'sample-head-medium.fif') + mne.bem, + "decimate_surface", + lambda points, triangles, n_triangles: (points, triangles), + ) + dense_fname = op.join(subj_dir, "sample-head-dense.fif") + medium_fname = op.join(subj_dir, "sample-head-medium.fif") with ArgvSetter(cmd, disable_stdout=False, disable_stderr=False): - monkeypatch.delenv('FREESURFER_HOME') - with pytest.raises(RuntimeError, match='The FreeSurfer environ'): + monkeypatch.delenv("FREESURFER_HOME") + with pytest.raises(RuntimeError, match="The FreeSurfer environ"): mne_make_scalp_surfaces.run() - shutil.copy(op.join(surf_path, 'lh.seghead'), surf_path_new) - monkeypatch.setenv('FREESURFER_HOME', tempdir) + shutil.copy(op.join(surf_path, "lh.seghead"), surf_path_new) + monkeypatch.setenv("FREESURFER_HOME", tempdir) mne_make_scalp_surfaces.run() assert op.isfile(dense_fname) assert op.isfile(medium_fname) - with pytest.raises(OSError, match='overwrite'): + with pytest.raises(OSError, match="overwrite"): mne_make_scalp_surfaces.run() # actually check the outputs head_py = read_bem_surfaces(dense_fname) assert_equal(len(head_py), 1) head_py = head_py[0] - head_c = read_bem_surfaces(op.join(subjects_dir, 'sample', 'bem', - 'sample-head-dense.fif'))[0] - assert_allclose(head_py['rr'], head_c['rr']) + head_c = read_bem_surfaces( + op.join(subjects_dir, "sample", "bem", "sample-head-dense.fif") + )[0] + assert_allclose(head_py["rr"], head_c["rr"]) if not has: - assert 'SUBJECTS_DIR' not in os.environ + assert "SUBJECTS_DIR" not in os.environ def test_maxfilter(): """Test mne maxfilter.""" check_usage(mne_maxfilter) - with ArgvSetter(('-i', raw_fname, '--st', '--movecomp', '--linefreq', '60', - '--trans', raw_fname)) as out: + with ArgvSetter( + ( + "-i", + raw_fname, + "--st", + "--movecomp", + "--linefreq", + "60", + "--trans", + raw_fname, + ) + ) as out: with pytest.warns(RuntimeWarning, match="Don't use"): - os.environ['_MNE_MAXFILTER_TEST'] = 'true' + os.environ["_MNE_MAXFILTER_TEST"] = "true" try: mne_maxfilter.run() finally: - del os.environ['_MNE_MAXFILTER_TEST'] + del os.environ["_MNE_MAXFILTER_TEST"] out = out.stdout.getvalue() - for check in ('maxfilter', '-trans', '-movecomp'): + for check in ("maxfilter", "-trans", "-movecomp"): assert check in out, check @@ -197,16 +234,29 @@ def test_maxfilter(): @testing.requires_testing_data def test_report(tmp_path): """Test mne report.""" - pytest.importorskip('nibabel') 
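pytest.importorskip, used here, turns a missing optional dependency into a skipped test rather than a failure, and returns the imported module when it is available. A standalone sketch of the pattern (hypothetical test name):

    import pytest

    def test_needs_nibabel():
        # Skipped, not failed, when nibabel is absent.
        nib = pytest.importorskip("nibabel")
        assert hasattr(nib, "load")  # the module is usable past this point
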
+ pytest.importorskip("nibabel") check_usage(mne_report) tempdir = str(tmp_path) use_fname = op.join(tempdir, op.basename(raw_fname)) shutil.copyfile(raw_fname, use_fname) - with ArgvSetter(('-p', tempdir, '-i', use_fname, '-d', subjects_dir, - '-s', 'sample', '--no-browser', '-m', '30')): + with ArgvSetter( + ( + "-p", + tempdir, + "-i", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--no-browser", + "-m", + "30", + ) + ): with _record_warnings(): # contour levels mne_report.run() - fnames = glob.glob(op.join(tempdir, '*.html')) + fnames = glob.glob(op.join(tempdir, "*.html")) assert len(fnames) == 1 @@ -218,48 +268,48 @@ def test_surf2bem(): @pytest.mark.timeout(900) # took ~400 s on a local test @pytest.mark.slowtest @pytest.mark.ultraslowtest -@requires_freesurfer('mri_watershed') +@requires_freesurfer("mri_watershed") @testing.requires_testing_data def test_watershed_bem(tmp_path): """Test mne watershed bem.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") check_usage(mne_watershed_bem) # from T1.mgz Mdc = np.array([[-1, 0, 0], [0, 0, -1], [0, 1, 0]]) Pxyz_c = np.array([-5.273613, 9.039085, -27.287964]) # Copy necessary files to tempdir tempdir = str(tmp_path) - mridata_path = op.join(subjects_dir, 'sample', 'mri') - subject_path_new = op.join(tempdir, 'sample') - mridata_path_new = op.join(subject_path_new, 'mri') + mridata_path = op.join(subjects_dir, "sample", "mri") + subject_path_new = op.join(tempdir, "sample") + mridata_path_new = op.join(subject_path_new, "mri") os.makedirs(mridata_path_new) - new_fname = op.join(mridata_path_new, 'T1.mgz') - shutil.copyfile(op.join(mridata_path, 'T1.mgz'), new_fname) + new_fname = op.join(mridata_path_new, "T1.mgz") + shutil.copyfile(op.join(mridata_path, "T1.mgz"), new_fname) old_mode = os.stat(new_fname).st_mode os.chmod(new_fname, 0) - args = ('-d', tempdir, '-s', 'sample', '-o') - with pytest.raises(PermissionError, match=r'read permissions.*T1\.mgz'): + args = ("-d", tempdir, "-s", "sample", "-o") + with pytest.raises(PermissionError, match=r"read permissions.*T1\.mgz"): with ArgvSetter(args): mne_watershed_bem.run() os.chmod(new_fname, old_mode) - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - assert not op.isfile(op.join(subject_path_new, 'bem', '%s.surf' % s)) + for s in ("outer_skin", "outer_skull", "inner_skull"): + assert not op.isfile(op.join(subject_path_new, "bem", "%s.surf" % s)) with ArgvSetter(args): mne_watershed_bem.run() kwargs = dict(rtol=1e-5, atol=1e-5) - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - rr, tris, vol_info = read_surface(op.join(subject_path_new, 'bem', - '%s.surf' % s), - read_metadata=True) + for s in ("outer_skin", "outer_skull", "inner_skull"): + rr, tris, vol_info = read_surface( + op.join(subject_path_new, "bem", "%s.surf" % s), read_metadata=True + ) assert_equal(len(tris), 20480) assert_equal(tris.min(), 0) assert_equal(rr.shape[0], tris.max() + 1) # compare the volume info to the mgz header - assert_allclose(vol_info['xras'], Mdc[0], **kwargs) - assert_allclose(vol_info['yras'], Mdc[1], **kwargs) - assert_allclose(vol_info['zras'], Mdc[2], **kwargs) - assert_allclose(vol_info['cras'], Pxyz_c, **kwargs) + assert_allclose(vol_info["xras"], Mdc[0], **kwargs) + assert_allclose(vol_info["yras"], Mdc[1], **kwargs) + assert_allclose(vol_info["zras"], Mdc[2], **kwargs) + assert_allclose(vol_info["cras"], Pxyz_c, **kwargs) @pytest.mark.timeout(180) # took ~70 s locally @@ -272,33 +322,38 @@ def test_flash_bem(tmp_path): check_usage(mne_flash_bem, 
force_help=True) # Copy necessary files to tempdir tempdir = Path(str(tmp_path)) - mridata_path = Path(subjects_dir) / 'sample' / 'mri' - subject_path_new = tempdir / 'sample' - mridata_path_new = subject_path_new / 'mri' - flash_path = mridata_path_new / 'flash' + mridata_path = Path(subjects_dir) / "sample" / "mri" + subject_path_new = tempdir / "sample" + mridata_path_new = subject_path_new / "mri" + flash_path = mridata_path_new / "flash" flash_path.mkdir(parents=True, exist_ok=True) - bem_path = mridata_path_new / 'bem' + bem_path = mridata_path_new / "bem" bem_path.mkdir(parents=True, exist_ok=True) - shutil.copyfile(op.join(mridata_path, 'T1.mgz'), - op.join(mridata_path_new, 'T1.mgz')) - shutil.copyfile(op.join(mridata_path, 'brain.mgz'), - op.join(mridata_path_new, 'brain.mgz')) + shutil.copyfile( + op.join(mridata_path, "T1.mgz"), op.join(mridata_path_new, "T1.mgz") + ) + shutil.copyfile( + op.join(mridata_path, "brain.mgz"), op.join(mridata_path_new, "brain.mgz") + ) # Copy the available mri/flash/mef*.mgz files from the dataset for kind in (5, 30): - in_fname = mridata_path / "flash" / f'mef{kind:02d}.mgz' - in_fname_echo = flash_path / f'mef{kind:02d}_001.mgz' + in_fname = mridata_path / "flash" / f"mef{kind:02d}.mgz" + in_fname_echo = flash_path / f"mef{kind:02d}_001.mgz" shutil.copyfile(in_fname, flash_path / in_fname_echo.name) # Test mne flash_bem with --noconvert option # (since there are no DICOM Flash images in dataset) - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - assert not op.isfile(subject_path_new / 'bem' / f'{s}.surf') + for s in ("outer_skin", "outer_skull", "inner_skull"): + assert not op.isfile(subject_path_new / "bem" / f"{s}.surf") # First test without flash30 - with ArgvSetter(('-d', tempdir, '-s', 'sample', '-n', '-r', '-3'), - disable_stdout=False, disable_stderr=False): + with ArgvSetter( + ("-d", tempdir, "-s", "sample", "-n", "-r", "-3"), + disable_stdout=False, + disable_stderr=False, + ): mne_flash_bem.run() - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - surf_path = subject_path_new / 'bem' / f'{s}.surf' + for s in ("outer_skin", "outer_skull", "inner_skull"): + surf_path = subject_path_new / "bem" / f"{s}.surf" assert surf_path.exists() surf_path.unlink() # cleanup shutil.rmtree(flash_path / "parameter_maps") # remove old files @@ -313,22 +368,33 @@ def test_flash_bem(tmp_path): # Test with flash5 and flash30 shutil.rmtree(flash_path) # first remove old files - with ArgvSetter(('-d', tempdir, '-s', 'sample', '-n', - '-3', str(mridata_path / "flash" / 'mef30.mgz'), - '-5', str(mridata_path / "flash" / 'mef05.mgz')), - disable_stdout=False, disable_stderr=False): + with ArgvSetter( + ( + "-d", + tempdir, + "-s", + "sample", + "-n", + "-3", + str(mridata_path / "flash" / "mef30.mgz"), + "-5", + str(mridata_path / "flash" / "mef05.mgz"), + ), + disable_stdout=False, + disable_stderr=False, + ): mne_flash_bem.run() kwargs = dict(rtol=1e-5, atol=1e-5) - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - rr, tris = read_surface(op.join(subject_path_new, 'bem', - '%s.surf' % s)) + for s in ("outer_skin", "outer_skull", "inner_skull"): + rr, tris = read_surface(op.join(subject_path_new, "bem", "%s.surf" % s)) assert_equal(len(tris), 5120) assert_equal(tris.min(), 0) assert_equal(rr.shape[0], tris.max() + 1) # compare to the testing flash surfaces - rr_c, tris_c = read_surface(op.join(subjects_dir, 'sample', 'bem', - '%s.surf' % s)) + rr_c, tris_c = read_surface( + op.join(subjects_dir, "sample", "bem", "%s.surf" % s) + ) 
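The assert_allclose checks here compare the generated surfaces element-wise under both a relative and an absolute tolerance: values pass when |actual - desired| <= atol + rtol * |desired|. The same kwargs in isolation:

    import numpy as np
    from numpy.testing import assert_allclose

    desired = np.array([1.0, 2.0, 3.0])
    actual = desired + 5e-6  # well inside atol=1e-5
    assert_allclose(actual, desired, rtol=1e-5, atol=1e-5)  # passes silently
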
assert_allclose(rr, rr_c, **kwargs) assert_allclose(tris, tris_c, **kwargs) @@ -336,29 +402,80 @@ def test_flash_bem(tmp_path): @testing.requires_testing_data def test_setup_source_space(tmp_path): """Test mne setup_source_space.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") check_usage(mne_setup_source_space, force_help=True) # Using the sample dataset use_fname = op.join(tmp_path, "sources-src.fif") # Test command - with ArgvSetter(('--src', use_fname, '-d', subjects_dir, - '-s', 'sample', '--morph', 'sample', - '--add-dist', 'False', '--ico', '3', '--verbose')): + with ArgvSetter( + ( + "--src", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--morph", + "sample", + "--add-dist", + "False", + "--ico", + "3", + "--verbose", + ) + ): mne_setup_source_space.run() src = read_source_spaces(use_fname) assert len(src) == 2 with pytest.raises(Exception): - with ArgvSetter(('--src', use_fname, '-d', subjects_dir, - '-s', 'sample', '--ico', '3', '--oct', '3')): + with ArgvSetter( + ( + "--src", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--ico", + "3", + "--oct", + "3", + ) + ): assert mne_setup_source_space.run() with pytest.raises(Exception): - with ArgvSetter(('--src', use_fname, '-d', subjects_dir, - '-s', 'sample', '--ico', '3', '--spacing', '10')): + with ArgvSetter( + ( + "--src", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--ico", + "3", + "--spacing", + "10", + ) + ): assert mne_setup_source_space.run() with pytest.raises(Exception): - with ArgvSetter(('--src', use_fname, '-d', subjects_dir, - '-s', 'sample', '--ico', '3', '--spacing', '10', - '--oct', '3')): + with ArgvSetter( + ( + "--src", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--ico", + "3", + "--spacing", + "10", + "--oct", + "3", + ) + ): assert mne_setup_source_space.run() @@ -366,17 +483,29 @@ def test_setup_source_space(tmp_path): @testing.requires_testing_data def test_setup_forward_model(tmp_path): """Test mne setup_forward_model.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") check_usage(mne_setup_forward_model, force_help=True) # Using the sample dataset use_fname = op.join(tmp_path, "model-bem.fif") # Test command - with ArgvSetter(('--model', use_fname, '-d', subjects_dir, '--homog', - '-s', 'sample', '--ico', '3', '--verbose')): + with ArgvSetter( + ( + "--model", + use_fname, + "-d", + subjects_dir, + "--homog", + "-s", + "sample", + "--ico", + "3", + "--verbose", + ) + ): mne_setup_forward_model.run() model = read_bem_surfaces(use_fname) assert len(model) == 1 - sol_fname = op.splitext(use_fname)[0] + '-sol.fif' + sol_fname = op.splitext(use_fname)[0] + "-sol.fif" read_bem_solution(sol_fname) @@ -388,8 +517,9 @@ def test_mne_prepare_bem_model(tmp_path): # Using the sample dataset bem_solution_fname = op.join(tmp_path, "bem_solution-bem-sol.fif") # Test command - with ArgvSetter(('--bem', bem_model_fname, '--sol', bem_solution_fname, - '--verbose')): + with ArgvSetter( + ("--bem", bem_model_fname, "--sol", bem_solution_fname, "--verbose") + ): mne_prepare_bem_model.run() bem_solution = read_bem_solution(bem_solution_fname) assert isinstance(bem_solution, ConductorModel) @@ -406,19 +536,19 @@ def test_sys_info(): """Test mne show_info.""" check_usage(mne_sys_info, force_help=True) with ArgvSetter((raw_fname,)): - with pytest.raises(SystemExit, match='1'): + with pytest.raises(SystemExit, match="1"): mne_sys_info.run() with ArgvSetter() as out: mne_sys_info.run() - assert 'numpy' in out.stdout.getvalue() + assert 
"numpy" in out.stdout.getvalue() def test_anonymize(tmp_path): """Test mne anonymize.""" check_usage(mne_anonymize) - out_fname = op.join(tmp_path, 'anon_test_raw.fif') - with ArgvSetter(('-f', raw_fname, '-o', out_fname)): + out_fname = op.join(tmp_path, "anon_test_raw.fif") + with ArgvSetter(("-f", raw_fname, "-o", out_fname)): mne_anonymize.run() info = read_info(out_fname) assert op.exists(out_fname) - assert info['meas_date'] == _stamp_to_dt((946684800, 0)) + assert info["meas_date"] == _stamp_to_dt((946684800, 0)) diff --git a/mne/commands/utils.py b/mne/commands/utils.py index 415f513cad1..80d04ab1729 100644 --- a/mne/commands/utils.py +++ b/mne/commands/utils.py @@ -16,9 +16,13 @@ def _add_verbose_flag(parser): - parser.add_option("--verbose", dest='verbose', - help="Enable verbose mode (printing of log messages).", - default=None, action="http://github.com/store_true") + parser.add_option( + "--verbose", + dest="verbose", + help="Enable verbose mode (printing of log messages).", + default=None, + action="http://github.com/store_true", + ) def load_module(name, path): @@ -38,31 +42,32 @@ def load_module(name, path): """ from importlib.util import spec_from_file_location, module_from_spec + spec = spec_from_file_location(name, path) mod = module_from_spec(spec) spec.loader.exec_module(mod) return mod -def get_optparser(cmdpath, usage=None, prog_prefix='mne', version=None): +def get_optparser(cmdpath, usage=None, prog_prefix="mne", version=None): """Create OptionParser with cmd specific settings (e.g., prog value).""" # Fetch description - mod = load_module('__temp', cmdpath) + mod = load_module("__temp", cmdpath) if mod.__doc__: doc, description, epilog = mod.__doc__, None, None - doc_lines = doc.split('\n') + doc_lines = doc.split("\n") description = doc_lines[0] if len(doc_lines) > 1: - epilog = '\n'.join(doc_lines[1:]) + epilog = "\n".join(doc_lines[1:]) # Get the name of the command command = os.path.basename(cmdpath) command, _ = os.path.splitext(command) - command = command[len(prog_prefix) + 1:] # +1 is for `_` character + command = command[len(prog_prefix) + 1 :] # +1 is for `_` character # Set prog - prog = prog_prefix + ' {}'.format(command) + prog = prog_prefix + " {}".format(command) # Set version if version is None: @@ -70,10 +75,9 @@ def get_optparser(cmdpath, usage=None, prog_prefix='mne', version=None): # monkey patch OptionParser to not wrap epilog OptionParser.format_epilog = lambda self, formatter: self.epilog - parser = OptionParser(prog=prog, - version=version, - description=description, - epilog=epilog, usage=usage) + parser = OptionParser( + prog=prog, version=version, description=description, epilog=epilog, usage=usage + ) return parser @@ -81,8 +85,7 @@ def get_optparser(cmdpath, usage=None, prog_prefix='mne', version=None): def main(): """Entrypoint for mne usage.""" mne_bin_dir = op.dirname(op.dirname(__file__)) - valid_commands = sorted(glob.glob(op.join(mne_bin_dir, - 'commands', 'mne_*.py'))) + valid_commands = sorted(glob.glob(op.join(mne_bin_dir, "commands", "mne_*.py"))) valid_commands = [c.split(op.sep)[-1][4:-3] for c in valid_commands] def print_help(): # noqa @@ -102,6 +105,6 @@ def print_help(): # noqa print_help() else: cmd = sys.argv[1] - cmd = importlib.import_module('.mne_%s' % (cmd,), 'mne.commands') + cmd = importlib.import_module(".mne_%s" % (cmd,), "mne.commands") sys.argv = sys.argv[1:] cmd.run() diff --git a/mne/conftest.py b/mne/conftest.py index 72e95b6e788..c99d2eeebe1 100644 --- a/mne/conftest.py +++ 
b/mne/conftest.py @@ -26,81 +26,93 @@ from mne.fixes import has_numba, _compare_version from mne.io import read_raw_fif, read_raw_ctf, read_raw_nirx, read_raw_snirf from mne.stats import cluster_level -from mne.utils import (_pl, _assert_no_instances, numerics, Bunch, - _check_qt_version, _TempDir, check_version) +from mne.utils import ( + _pl, + _assert_no_instances, + numerics, + Bunch, + _check_qt_version, + _TempDir, + check_version, +) # data from sample dataset from mne.viz._figure import use_browser_backend from mne.viz.backends._utils import _init_mne_qtapp test_path = testing.data_path(download=False) -s_path = op.join(test_path, 'MEG', 'sample') -fname_evoked = op.join(s_path, 'sample_audvis_trunc-ave.fif') -fname_cov = op.join(s_path, 'sample_audvis_trunc-cov.fif') -fname_fwd = op.join(s_path, 'sample_audvis_trunc-meg-eeg-oct-4-fwd.fif') -fname_fwd_full = op.join(s_path, 'sample_audvis_trunc-meg-eeg-oct-6-fwd.fif') -bem_path = op.join(test_path, 'subjects', 'sample', 'bem') -fname_bem = op.join(bem_path, 'sample-1280-bem.fif') -fname_aseg = op.join(test_path, 'subjects', 'sample', 'mri', 'aseg.mgz') -subjects_dir = op.join(test_path, 'subjects') -fname_src = op.join(bem_path, 'sample-oct-4-src.fif') -fname_trans = op.join(s_path, 'sample_audvis_trunc-trans.fif') - -ctf_dir = op.join(test_path, 'CTF') -fname_ctf_continuous = op.join(ctf_dir, 'testdata_ctf.ds') - -nirx_path = test_path / 'NIRx' -snirf_path = test_path / 'SNIRF' -nirsport2 = nirx_path / 'nirsport_v2' / 'aurora_recording _w_short_and_acc' -nirsport2_snirf = ( - snirf_path / 'NIRx' / 'NIRSport2' / '1.0.3' / - '2021-05-05_001.snirf') -nirsport2_2021_9 = nirx_path / 'nirsport_v2' / 'aurora_2021_9' +s_path = op.join(test_path, "MEG", "sample") +fname_evoked = op.join(s_path, "sample_audvis_trunc-ave.fif") +fname_cov = op.join(s_path, "sample_audvis_trunc-cov.fif") +fname_fwd = op.join(s_path, "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif") +fname_fwd_full = op.join(s_path, "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif") +bem_path = op.join(test_path, "subjects", "sample", "bem") +fname_bem = op.join(bem_path, "sample-1280-bem.fif") +fname_aseg = op.join(test_path, "subjects", "sample", "mri", "aseg.mgz") +subjects_dir = op.join(test_path, "subjects") +fname_src = op.join(bem_path, "sample-oct-4-src.fif") +fname_trans = op.join(s_path, "sample_audvis_trunc-trans.fif") + +ctf_dir = op.join(test_path, "CTF") +fname_ctf_continuous = op.join(ctf_dir, "testdata_ctf.ds") + +nirx_path = test_path / "NIRx" +snirf_path = test_path / "SNIRF" +nirsport2 = nirx_path / "nirsport_v2" / "aurora_recording _w_short_and_acc" +nirsport2_snirf = snirf_path / "NIRx" / "NIRSport2" / "1.0.3" / "2021-05-05_001.snirf" +nirsport2_2021_9 = nirx_path / "nirsport_v2" / "aurora_2021_9" nirsport2_20219_snirf = ( - snirf_path / 'NIRx' / 'NIRSport2' / '2021.9' / - '2021-10-01_002.snirf') + snirf_path / "NIRx" / "NIRSport2" / "2021.9" / "2021-10-01_002.snirf" +) # data from mne.io.tests.data -base_dir = op.join(op.dirname(__file__), 'io', 'tests', 'data') -fname_raw_io = op.join(base_dir, 'test_raw.fif') -fname_event_io = op.join(base_dir, 'test-eve.fif') -fname_cov_io = op.join(base_dir, 'test-cov.fif') -fname_evoked_io = op.join(base_dir, 'test-ave.fif') +base_dir = op.join(op.dirname(__file__), "io", "tests", "data") +fname_raw_io = op.join(base_dir, "test_raw.fif") +fname_event_io = op.join(base_dir, "test-eve.fif") +fname_cov_io = op.join(base_dir, "test-cov.fif") +fname_evoked_io = op.join(base_dir, "test-ave.fif") event_id, tmin, tmax = 1, -0.1, 1.0 
-vv_layout = read_layout('Vectorview-all') +vv_layout = read_layout("Vectorview-all") -collect_ignore = [ - 'export/_brainvision.py', - 'export/_eeglab.py', - 'export/_edf.py'] +collect_ignore = ["export/_brainvision.py", "export/_eeglab.py", "export/_edf.py"] def pytest_configure(config): """Configure pytest options.""" # Markers - for marker in ('slowtest', 'ultraslowtest', 'pgtest', 'allow_unclosed', - 'allow_unclosed_pyside2'): - config.addinivalue_line('markers', marker) + for marker in ( + "slowtest", + "ultraslowtest", + "pgtest", + "allow_unclosed", + "allow_unclosed_pyside2", + ): + config.addinivalue_line("markers", marker) # Fixtures - for fixture in ('matplotlib_config', 'close_all', 'check_verbose', - 'qt_config', 'protect_config'): - config.addinivalue_line('usefixtures', fixture) + for fixture in ( + "matplotlib_config", + "close_all", + "check_verbose", + "qt_config", + "protect_config", + ): + config.addinivalue_line("usefixtures", fixture) # pytest-qt uses PYTEST_QT_API, but let's make it respect qtpy's QT_API # if present - if os.getenv('PYTEST_QT_API') is None and os.getenv('QT_API') is not None: - os.environ['PYTEST_QT_API'] = os.environ['QT_API'] + if os.getenv("PYTEST_QT_API") is None and os.getenv("QT_API") is not None: + os.environ["PYTEST_QT_API"] = os.environ["QT_API"] # Warnings # - Once SciPy updates not to have non-integer and non-tuple errors (1.2.0) # we should remove them from here. # - This list should also be considered alongside reset_warnings in # doc/conf.py. - if os.getenv('MNE_IGNORE_WARNINGS_IN_TESTS', '') != 'true': - first_kind = 'error' + if os.getenv("MNE_IGNORE_WARNINGS_IN_TESTS", "") != "true": + first_kind = "error" else: - first_kind = 'always' + first_kind = "always" warning_lines = f" {first_kind}::" warning_lines += r""" # matplotlib->traitlets (notebook) @@ -143,10 +155,10 @@ def pytest_configure(config): # h5py ignore:`product` is deprecated as of NumPy.*:DeprecationWarning """ # noqa: E501 - for warning_line in warning_lines.split('\n'): + for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() - if warning_line and not warning_line.startswith('#'): - config.addinivalue_line('filterwarnings', warning_line) + if warning_line and not warning_line.startswith("#"): + config.addinivalue_line("filterwarnings", warning_line) # Have to be careful with autouse=True, but this is just an int comparison @@ -160,9 +172,10 @@ def check_verbose(request): try: assert mne.utils.logger.level == starting_level except AssertionError: - pytest.fail('.'.join([request.module.__name__, - request.function.__name__]) + - ' modifies logger.level') + pytest.fail( + ".".join([request.module.__name__, request.function.__name__]) + + " modifies logger.level" + ) @pytest.fixture(autouse=True) @@ -170,8 +183,9 @@ def close_all(): """Close all matplotlib plots, regardless of test status.""" # This adds < 1 µS in local testing, and we have ~2500 tests, so ~2 ms max import matplotlib.pyplot as plt + yield - plt.close('all') + plt.close("all") @pytest.fixture(autouse=True) @@ -180,44 +194,46 @@ def add_mne(doctest_namespace): doctest_namespace["mne"] = mne -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def verbose_debug(): """Run a test with debug verbosity.""" - with mne.utils.use_log_level('debug'): + with mne.utils.use_log_level("debug"): yield -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def qt_config(): """Configure the Qt backend for viz tests.""" - os.environ['_MNE_BROWSER_NO_BLOCK'] = 'true' 
+ os.environ["_MNE_BROWSER_NO_BLOCK"] = "true" -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def matplotlib_config(): """Configure matplotlib for viz tests.""" import matplotlib from matplotlib import cbook + # Allow for easy interactive debugging with a call like: # # $ MNE_MPL_TESTING_BACKEND=Qt5Agg pytest mne/viz/tests/test_raw.py -k annotation -x --pdb # noqa: E501 # try: - want = os.environ['MNE_MPL_TESTING_BACKEND'] + want = os.environ["MNE_MPL_TESTING_BACKEND"] except KeyError: - want = 'agg' # don't pop up windows + want = "agg" # don't pop up windows with warnings.catch_warnings(record=True): # ignore warning - warnings.filterwarnings('ignore') + warnings.filterwarnings("ignore") matplotlib.use(want, force=True) import matplotlib.pyplot as plt + assert plt.get_backend() == want # overwrite some params that can horribly slow down tests that # users might have changed locally (but should not otherwise affect # functionality) plt.ioff() - plt.rcParams['figure.dpi'] = 100 + plt.rcParams["figure.dpi"] = 100 try: - plt.rcParams['figure.raise_window'] = False + plt.rcParams["figure.raise_window"] = False except KeyError: # MPL < 3.3 pass @@ -231,21 +247,22 @@ def __init__(self, exception_handler=None, signals=None): cbook.CallbackRegistry = CallbackRegistryReraise -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def azure_windows(): """Determine if running on Azure Windows.""" - return (os.getenv('AZURE_CI_WINDOWS', 'false').lower() == 'true' and - sys.platform.startswith('win')) + return os.getenv( + "AZURE_CI_WINDOWS", "false" + ).lower() == "true" and sys.platform.startswith("win") -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def raw_orig(): """Get raw data without any change to it from mne.io.tests.data.""" raw = read_raw_fif(fname_raw_io, preload=True) return raw -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def raw(): """ Get raw data and pick channels to reduce load for testing. @@ -254,21 +271,21 @@ def raw(): """ raw = read_raw_fif(fname_raw_io, preload=True) # Throws a warning about a changed unit. 
- with pytest.warns(RuntimeWarning, match='unit'): - raw.set_channel_types({raw.ch_names[0]: 'ias'}) + with pytest.warns(RuntimeWarning, match="unit"): + raw.set_channel_types({raw.ch_names[0]: "ias"}) raw.pick_channels(raw.ch_names[:9]) raw.info.normalize_proj() # Fix projectors after subselection return raw -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def raw_ctf(): """Get ctf raw data from mne.io.tests.data.""" raw_ctf = read_raw_ctf(fname_ctf_continuous, preload=True) return raw_ctf -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def events(): """Get events from mne.io.tests.data.""" return read_events(fname_event_io) @@ -278,13 +295,22 @@ def _get_epochs(stop=5, meg=True, eeg=False, n_chan=20): """Get epochs.""" raw = read_raw_fif(fname_raw_io) events = read_events(fname_event_io) - picks = pick_types(raw.info, meg=meg, eeg=eeg, stim=False, - ecg=False, eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=meg, eeg=eeg, stim=False, ecg=False, eog=False, exclude="bads" + ) # Use a subset of channels for plotting speed picks = np.round(np.linspace(0, len(picks) + 1, n_chan)).astype(int) - with pytest.warns(RuntimeWarning, match='projection'): - epochs = Epochs(raw, events[:stop], event_id, tmin, tmax, picks=picks, - proj=False, preload=False) + with pytest.warns(RuntimeWarning, match="projection"): + epochs = Epochs( + raw, + events[:stop], + event_id, + tmin, + tmax, + picks=picks, + proj=False, + preload=False, + ) epochs.info.normalize_proj() # avoid warnings return epochs @@ -311,12 +337,13 @@ def epochs_full(): return _get_epochs(None).load_data() -@pytest.fixture(scope='session', params=[testing._pytest_param()]) +@pytest.fixture(scope="session", params=[testing._pytest_param()]) def _evoked(): # This one is session scoped, so be sure not to modify it (use evoked # instead) - evoked = mne.read_evokeds(fname_evoked, condition='Left Auditory', - baseline=(None, 0)) + evoked = mne.read_evokeds( + fname_evoked, condition="Left Auditory", baseline=(None, 0) + ) evoked.crop(0, 0.2) return evoked @@ -327,7 +354,7 @@ def evoked(_evoked): return _evoked.copy() -@pytest.fixture(scope='function', params=[testing._pytest_param()]) +@pytest.fixture(scope="function", params=[testing._pytest_param()]) def noise_cov(): """Get a noise cov from the testing dataset.""" return mne.read_cov(fname_cov) @@ -339,45 +366,44 @@ def noise_cov_io(): return mne.read_cov(fname_cov_io) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def bias_params_free(evoked, noise_cov): """Provide inputs for free bias functions.""" fwd = mne.read_forward_solution(fname_fwd) return _bias_params(evoked, noise_cov, fwd) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def bias_params_fixed(evoked, noise_cov): """Provide inputs for fixed bias functions.""" fwd = mne.read_forward_solution(fname_fwd) - mne.convert_forward_solution( - fwd, force_fixed=True, surf_ori=True, copy=False) + mne.convert_forward_solution(fwd, force_fixed=True, surf_ori=True, copy=False) return _bias_params(evoked, noise_cov, fwd) def _bias_params(evoked, noise_cov, fwd): evoked.pick_types(meg=True, eeg=True, exclude=()) # restrict to limited set of verts (small src here) and one hemi for speed - vertices = [fwd['src'][0]['vertno'].copy(), []] + vertices = [fwd["src"][0]["vertno"].copy(), []] stc = mne.SourceEstimate( - np.zeros((sum(len(v) for v in vertices), 1)), vertices, 0, 1) + np.zeros((sum(len(v) for v in vertices), 1)), vertices, 0, 1 + ) fwd = 
mne.forward.restrict_forward_to_stc(fwd, stc) - assert fwd['sol']['row_names'] == noise_cov['names'] - assert noise_cov['names'] == evoked.ch_names - evoked = mne.EvokedArray(fwd['sol']['data'].copy(), evoked.info) + assert fwd["sol"]["row_names"] == noise_cov["names"] + assert noise_cov["names"] == evoked.ch_names + evoked = mne.EvokedArray(fwd["sol"]["data"].copy(), evoked.info) data_cov = noise_cov.copy() - data = fwd['sol']['data'] @ fwd['sol']['data'].T + data = fwd["sol"]["data"] @ fwd["sol"]["data"].T data *= 1e-14 # 100 nAm at each source, effectively (1e-18 would be 1 nAm) # This is rank-deficient, so let's make it actually positive semidefinite # by regularizing a tiny bit - data.flat[::data.shape[0] + 1] += mne.make_ad_hoc_cov(evoked.info)['data'] + data.flat[:: data.shape[0] + 1] += mne.make_ad_hoc_cov(evoked.info)["data"] # Do our projection - proj, _, _ = mne.io.proj.make_projector( - data_cov['projs'], data_cov['names']) + proj, _, _ = mne.io.proj.make_projector(data_cov["projs"], data_cov["names"]) data = proj @ data @ proj.T - data_cov['data'][:] = data - assert data_cov['data'].shape[0] == len(noise_cov['names']) - want = np.arange(fwd['sol']['data'].shape[1]) + data_cov["data"][:] = data + assert data_cov["data"].shape[0] == len(noise_cov["names"]) + want = np.arange(fwd["sol"]["data"].shape[1]) if not mne.forward.is_fixed_orient(fwd): want //= 3 return evoked, fwd, noise_cov, data_cov, want @@ -393,42 +419,42 @@ def garbage_collect(): @pytest.fixture def mpl_backend(garbage_collect): """Use for epochs/ica when not implemented with pyqtgraph yet.""" - with use_browser_backend('matplotlib') as backend: + with use_browser_backend("matplotlib") as backend: yield backend backend._close_all() # Skip functions or modules for mne-qt-browser < 0.2.0 -pre_2_0_skip_modules = ['mne.viz.tests.test_epochs', - 'mne.viz.tests.test_ica'] -pre_2_0_skip_funcs = ['test_plot_raw_white', - 'test_plot_raw_selection'] +pre_2_0_skip_modules = ["mne.viz.tests.test_epochs", "mne.viz.tests.test_ica"] +pre_2_0_skip_funcs = ["test_plot_raw_white", "test_plot_raw_selection"] def _check_pyqtgraph(request): # Check Qt qt_version, api = _check_qt_version(return_api=True) - if (not qt_version) or _compare_version(qt_version, '<', '5.12'): - pytest.skip(f'Qt API {api} has version {qt_version} ' - f'but pyqtgraph needs >= 5.12!') + if (not qt_version) or _compare_version(qt_version, "<", "5.12"): + pytest.skip( + f"Qt API {api} has version {qt_version} " f"but pyqtgraph needs >= 5.12!" 
+ ) try: import mne_qt_browser # noqa: F401 + # Check mne-qt-browser version - lower_2_0 = _compare_version(mne_qt_browser.__version__, '<', '0.2.0') + lower_2_0 = _compare_version(mne_qt_browser.__version__, "<", "0.2.0") m_name = request.function.__module__ f_name = request.function.__name__ if lower_2_0 and m_name in pre_2_0_skip_modules: - pytest.skip(f'Test-Module "{m_name}" was skipped for' - f' mne-qt-browser < 0.2.0') + pytest.skip( + f'Test-Module "{m_name}" was skipped for' f" mne-qt-browser < 0.2.0" + ) elif lower_2_0 and f_name in pre_2_0_skip_funcs: - pytest.skip(f'Test "{f_name}" was skipped for ' - f'mne-qt-browser < 0.2.0') + pytest.skip(f'Test "{f_name}" was skipped for ' f"mne-qt-browser < 0.2.0") except Exception: - pytest.skip('Requires mne_qt_browser') + pytest.skip("Requires mne_qt_browser") else: ver = mne_qt_browser.__version__ - if api != 'PyQt5' and _compare_version(ver, '<=', '0.2.6'): - pytest.skip(f'mne_qt_browser {ver} requires PyQt5, API is {api}') + if api != "PyQt5" and _compare_version(ver, "<=", "0.2.6"): + pytest.skip(f"mne_qt_browser {ver} requires PyQt5, API is {api}") @pytest.fixture @@ -436,35 +462,39 @@ def pg_backend(request, garbage_collect): """Use for pyqtgraph-specific test-functions.""" _check_pyqtgraph(request) from mne_qt_browser._pg_figure import MNEQtBrowser - with use_browser_backend('qt') as backend: + + with use_browser_backend("qt") as backend: backend._close_all() yield backend backend._close_all() # This shouldn't be necessary, but let's make sure nothing is stale import mne_qt_browser + mne_qt_browser._browser_instances.clear() - if check_version('mne_qt_browser', min_version='0.4'): - _assert_no_instances( - MNEQtBrowser, f'Closure of {request.node.name}') + if check_version("mne_qt_browser", min_version="0.4"): + _assert_no_instances(MNEQtBrowser, f"Closure of {request.node.name}") -@pytest.fixture(params=[ - 'matplotlib', - pytest.param('qt', marks=pytest.mark.pgtest), -]) +@pytest.fixture( + params=[ + "matplotlib", + pytest.param("qt", marks=pytest.mark.pgtest), + ] +) def browser_backend(request, garbage_collect, monkeypatch): """Parametrizes the name of the browser backend.""" backend_name = request.param - if backend_name == 'qt': + if backend_name == "qt": _check_pyqtgraph(request) with use_browser_backend(backend_name) as backend: backend._close_all() - monkeypatch.setenv('MNE_BROWSE_RAW_SIZE', '10,10') + monkeypatch.setenv("MNE_BROWSE_RAW_SIZE", "10,10") yield backend backend._close_all() - if backend_name == 'qt': + if backend_name == "qt": # This shouldn't be necessary, but let's make sure nothing is stale import mne_qt_browser + mne_qt_browser._browser_instances.clear() @@ -506,9 +536,11 @@ def renderer_interactive(request, options_3d): @contextmanager def _use_backend(backend_name, interactive): from mne.viz.backends.renderer import _use_test_3d_backend + _check_skip_backend(backend_name) with _use_test_3d_backend(backend_name, interactive=interactive): from mne.viz.backends import renderer + try: yield renderer finally: @@ -516,34 +548,39 @@ def _use_backend(backend_name, interactive): def _check_skip_backend(name): - from mne.viz.backends.tests._utils import (has_pyvista, - has_imageio_ffmpeg, - has_pyvistaqt) + from mne.viz.backends.tests._utils import ( + has_pyvista, + has_imageio_ffmpeg, + has_pyvistaqt, + ) from mne.viz.backends._utils import _notebook_vtk_works + if not has_pyvista(): pytest.skip("Test skipped, requires pyvista.") if not has_imageio_ffmpeg(): pytest.skip("Test skipped, requires 
imageio-ffmpeg") - if name == 'pyvistaqt': + if name == "pyvistaqt": if not _check_qt_version(): pytest.skip("Test skipped, requires Qt.") if not has_pyvistaqt(): pytest.skip("Test skipped, requires pyvistaqt") else: - assert name == 'notebook', name + assert name == "notebook", name if not _notebook_vtk_works(): pytest.skip("Test skipped, requires working notebook vtk") -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def pixel_ratio(): """Get the pixel ratio.""" from mne.viz.backends.tests._utils import has_pyvista + # _check_qt_version will init an app for us, so no need for us to do it if not has_pyvista() or not _check_qt_version(): - return 1. + return 1.0 from qtpy.QtWidgets import QMainWindow from qtpy.QtCore import Qt + app = _init_mne_qtapp() app.processEvents() window = QMainWindow() @@ -553,10 +590,10 @@ def pixel_ratio(): return ratio -@pytest.fixture(scope='function', params=[testing._pytest_param()]) +@pytest.fixture(scope="function", params=[testing._pytest_param()]) def subjects_dir_tmp(tmp_path): """Copy MNE-testing-data subjects_dir to a temp dir for manipulation.""" - for key in ('sample', 'fsaverage'): + for key in ("sample", "fsaverage"): shutil.copytree(op.join(subjects_dir, key), str(tmp_path / key)) return str(tmp_path) @@ -564,59 +601,64 @@ def subjects_dir_tmp(tmp_path): @pytest.fixture(params=[testing._pytest_param()]) def subjects_dir_tmp_few(tmp_path): """Copy fewer files to a tmp_path.""" - subjects_path = tmp_path / 'subjects' + subjects_path = tmp_path / "subjects" os.mkdir(subjects_path) # add fsaverage - create_default_subject(subjects_dir=subjects_path, fs_home=test_path, - verbose=True) + create_default_subject(subjects_dir=subjects_path, fs_home=test_path, verbose=True) # add sample (with few files) - sample_path = subjects_path / 'sample' - os.makedirs(sample_path / 'bem') - for dirname in ('mri', 'surf'): + sample_path = subjects_path / "sample" + os.makedirs(sample_path / "bem") + for dirname in ("mri", "surf"): shutil.copytree( - test_path / 'subjects' / 'sample' / dirname, sample_path / dirname) + test_path / "subjects" / "sample" / dirname, sample_path / dirname + ) return subjects_path # Scoping these as session will make things faster, but need to make sure # not to modify them in-place in the tests, so keep them private -@pytest.fixture(scope='session', params=[testing._pytest_param()]) +@pytest.fixture(scope="session", params=[testing._pytest_param()]) def _evoked_cov_sphere(_evoked): """Compute a small evoked/cov/sphere combo for use with forwards.""" evoked = _evoked.copy().pick_types(meg=True) evoked.pick_channels(evoked.ch_names[::4]) assert len(evoked.ch_names) == 77 cov = mne.read_cov(fname_cov) - sphere = mne.make_sphere_model('auto', 'auto', evoked.info) + sphere = mne.make_sphere_model("auto", "auto", evoked.info) return evoked, cov, sphere -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def _fwd_surf(_evoked_cov_sphere): """Compute a forward for a surface source space.""" evoked, cov, sphere = _evoked_cov_sphere src_surf = mne.read_source_spaces(fname_src) return mne.make_forward_solution( - evoked.info, fname_trans, src_surf, sphere, mindist=5.0) + evoked.info, fname_trans, src_surf, sphere, mindist=5.0 + ) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def _fwd_subvolume(_evoked_cov_sphere): """Compute a forward for a surface source space.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") evoked, cov, sphere = _evoked_cov_sphere - volume_labels = 
['Left-Cerebellum-Cortex', 'right-Cerebellum-Cortex'] - with pytest.raises(ValueError, - match=r"Did you mean one of \['Right-Cere"): + volume_labels = ["Left-Cerebellum-Cortex", "right-Cerebellum-Cortex"] + with pytest.raises(ValueError, match=r"Did you mean one of \['Right-Cere"): mne.setup_volume_source_space( - 'sample', pos=20., volume_label=volume_labels, - subjects_dir=subjects_dir) - volume_labels[1] = 'R' + volume_labels[1][1:] + "sample", pos=20.0, volume_label=volume_labels, subjects_dir=subjects_dir + ) + volume_labels[1] = "R" + volume_labels[1][1:] src_vol = mne.setup_volume_source_space( - 'sample', pos=20., volume_label=volume_labels, - subjects_dir=subjects_dir, add_interpolator=False) + "sample", + pos=20.0, + volume_label=volume_labels, + subjects_dir=subjects_dir, + add_interpolator=False, + ) return mne.make_forward_solution( - evoked.info, fname_trans, src_vol, sphere, mindist=5.0) + evoked.info, fname_trans, src_vol, sphere, mindist=5.0 + ) @pytest.fixture @@ -625,52 +667,50 @@ def fwd_volume_small(_fwd_subvolume): return _fwd_subvolume.copy() -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def _all_src_types_fwd(_fwd_surf, _fwd_subvolume): """Create all three forward types (surf, vol, mixed).""" - fwds = dict( - surface=_fwd_surf.copy(), - volume=_fwd_subvolume.copy()) - with pytest.raises(RuntimeError, - match='Invalid source space with kinds'): - fwds['volume']['src'] + fwds['surface']['src'] + fwds = dict(surface=_fwd_surf.copy(), volume=_fwd_subvolume.copy()) + with pytest.raises(RuntimeError, match="Invalid source space with kinds"): + fwds["volume"]["src"] + fwds["surface"]["src"] # mixed (4) - fwd = fwds['surface'].copy() - f2 = fwds['volume'].copy() + fwd = fwds["surface"].copy() + f2 = fwds["volume"].copy() del _fwd_surf, _fwd_subvolume - for keys, axis in [(('source_rr',), 0), - (('source_nn',), 0), - (('sol', 'data'), 1), - (('_orig_sol',), 1)]: + for keys, axis in [ + (("source_rr",), 0), + (("source_nn",), 0), + (("sol", "data"), 1), + (("_orig_sol",), 1), + ]: a, b = fwd, f2 key = keys[0] if len(keys) > 1: a, b = a[key], b[key] key = keys[1] a[key] = np.concatenate([a[key], b[key]], axis=axis) - fwd['sol']['ncol'] = fwd['sol']['data'].shape[1] - fwd['nsource'] = fwd['sol']['ncol'] // 3 - fwd['src'] = fwd['src'] + f2['src'] - fwds['mixed'] = fwd + fwd["sol"]["ncol"] = fwd["sol"]["data"].shape[1] + fwd["nsource"] = fwd["sol"]["ncol"] // 3 + fwd["src"] = fwd["src"] + f2["src"] + fwds["mixed"] = fwd return fwds -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def _all_src_types_inv_evoked(_evoked_cov_sphere, _all_src_types_fwd): """Compute inverses for all source types.""" evoked, cov, _ = _evoked_cov_sphere invs = dict() for kind, fwd in _all_src_types_fwd.items(): - assert fwd['src'].kind == kind - with pytest.warns(RuntimeWarning, match='has been reduced'): - invs[kind] = mne.minimum_norm.make_inverse_operator( - evoked.info, fwd, cov) + assert fwd["src"].kind == kind + with pytest.warns(RuntimeWarning, match="has been reduced"): + invs[kind] = mne.minimum_norm.make_inverse_operator(evoked.info, fwd, cov) return invs, evoked -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def all_src_types_inv_evoked(_all_src_types_inv_evoked): """All source types of inverses, allowing for possible modification.""" invs, evoked = _all_src_types_inv_evoked @@ -679,42 +719,48 @@ def all_src_types_inv_evoked(_all_src_types_inv_evoked): return invs, evoked -@pytest.fixture(scope='function') 
+@pytest.fixture(scope="function") def mixed_fwd_cov_evoked(_evoked_cov_sphere, _all_src_types_fwd): """Compute inverses for all source types.""" evoked, cov, _ = _evoked_cov_sphere - return _all_src_types_fwd['mixed'].copy(), cov.copy(), evoked.copy() + return _all_src_types_fwd["mixed"].copy(), cov.copy(), evoked.copy() -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") @pytest.mark.slowtest @pytest.mark.parametrize(params=[testing._pytest_param()]) def src_volume_labels(): """Create a 7mm source space with labels.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") volume_labels = mne.get_volume_labels_from_aseg(fname_aseg) - with pytest.warns(RuntimeWarning, match='Found no usable.*Left-vessel.*'): + with pytest.warns(RuntimeWarning, match="Found no usable.*Left-vessel.*"): src = mne.setup_volume_source_space( - 'sample', 7., mri='aseg.mgz', volume_label=volume_labels, - add_interpolator=False, bem=fname_bem, - subjects_dir=subjects_dir) + "sample", + 7.0, + mri="aseg.mgz", + volume_label=volume_labels, + add_interpolator=False, + bem=fname_bem, + subjects_dir=subjects_dir, + ) lut, _ = mne.read_freesurfer_lut() assert len(volume_labels) == 46 - assert volume_labels[0] == 'Unknown' - assert lut['Unknown'] == 0 # it will be excluded during label gen + assert volume_labels[0] == "Unknown" + assert lut["Unknown"] == 0 # it will be excluded during label gen return src, tuple(volume_labels), lut def _fail(*args, **kwargs): __tracebackhide__ = True - raise AssertionError('Test should not download') + raise AssertionError("Test should not download") -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def download_is_error(monkeypatch): """Prevent downloading by raising an error when it's attempted.""" import pooch - monkeypatch.setattr(pooch, 'retrieve', _fail) + + monkeypatch.setattr(pooch, "retrieve", _fail) yield @@ -722,14 +768,14 @@ def download_is_error(monkeypatch): def fake_retrieve(monkeypatch, download_is_error): """Monkeypatch pooch.retrieve to avoid downloading (just touch files).""" import pooch + my_func = _FakeFetch() - monkeypatch.setattr(pooch, 'retrieve', my_func) - monkeypatch.setattr(pooch, 'create', my_func) + monkeypatch.setattr(pooch, "retrieve", my_func) + monkeypatch.setattr(pooch, "create", my_func) yield my_func class _FakeFetch: - def __init__(self): self.call_args_list = list() @@ -739,15 +785,15 @@ def call_count(self): # Wrapper for pooch.retrieve(...) and pooch.create(...) def __call__(self, *args, **kwargs): - assert 'path' in kwargs - if 'fname' in kwargs: # pooch.retrieve(...) + assert "path" in kwargs + if "fname" in kwargs: # pooch.retrieve(...) self.call_args_list.append((args, kwargs)) - path = Path(kwargs['path'], kwargs['fname']) + path = Path(kwargs["path"], kwargs["fname"]) path.parent.mkdir(parents=True, exist_ok=True) - path.write_text('test') + path.write_text("test") return path else: # pooch.create(...) 
has been called - self.path = kwargs['path'] + self.path = kwargs["path"] return self # Wrappers for Pooch instances (e.g., in eegbci we pooch.create) @@ -761,20 +807,21 @@ def load_registry(self, registry): # We can't use monkeypatch because its scope (function-level) conflicts with # the requests fixture (module-level), so we live with a module-scoped version # that uses mock -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def options_3d(): """Disable advanced 3d rendering.""" with mock.patch.dict( - os.environ, { + os.environ, + { "MNE_3D_OPTION_ANTIALIAS": "false", "MNE_3D_OPTION_DEPTH_PEELING": "false", "MNE_3D_OPTION_SMOOTH_SHADING": "false", - } + }, ): yield -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def protect_config(): """Protect ~/.mne.""" temp = _TempDir() @@ -786,23 +833,23 @@ def protect_config(): def brain_gc(request): """Ensure that brain can be properly garbage collected.""" keys = ( - 'renderer_interactive', - 'renderer_interactive_pyvistaqt', - 'renderer', - 'renderer_pyvistaqt', - 'renderer_notebook', + "renderer_interactive", + "renderer_interactive_pyvistaqt", + "renderer", + "renderer_pyvistaqt", + "renderer_notebook", ) assert set(request.fixturenames) & set(keys) != set() for key in keys: if key in request.fixturenames: - is_pv = \ - request.getfixturevalue(key)._get_3d_backend() == 'pyvistaqt' + is_pv = request.getfixturevalue(key)._get_3d_backend() == "pyvistaqt" close_func = request.getfixturevalue(key).backend._close_all break if not is_pv: yield return from mne.viz import Brain + ignore = set(id(o) for o in gc.get_objects()) yield close_func() @@ -810,10 +857,10 @@ def brain_gc(request): try: outcome = request.node.harvest_rep_call except Exception: - outcome = 'failed' - if outcome != 'passed': + outcome = "failed" + if outcome != "passed": return - _assert_no_instances(Brain, 'after') + _assert_no_instances(Brain, "after") # Check VTK objs = gc.get_objects() bad = list() @@ -823,11 +870,11 @@ def brain_gc(request): except Exception: # old Python, probably pass else: - if name.startswith('vtk') and id(o) not in ignore: + if name.startswith("vtk") and id(o) not in ignore: bad.append(name) del o del objs, ignore, Brain - assert len(bad) == 0, 'VTK objects linger:\n' + '\n'.join(bad) + assert len(bad) == 0, "VTK objects linger:\n" + "\n".join(bad) _files = list() @@ -838,26 +885,26 @@ def pytest_sessionfinish(session, exitstatus): n = session.config.option.durations if n is None: return - print('\n') + print("\n") try: import pytest_harvest except ImportError: - print('Module-level timings require pytest-harvest') + print("Module-level timings require pytest-harvest") return # get the number to print res = pytest_harvest.get_session_synthesis_dct(session) files = dict() for key, val in res.items(): - parts = Path(key.split(':')[0]).parts + parts = Path(key.split(":")[0]).parts # split mne/tests/test_whatever.py into separate categories since these # are essentially submodule-level tests. 
Keeping just [:3] works,
         # except for mne/viz where we want level-4 granularity
-        split_submodules = (('mne', 'viz'), ('mne', 'preprocessing'))
-        parts = parts[:4 if parts[:2] in split_submodules else 3]
-        if not parts[-1].endswith('.py'):
-            parts = parts + ('',)
-        file_key = '/'.join(parts)
-        files[file_key] = files.get(file_key, 0) + val['pytest_duration_s']
+        split_submodules = (("mne", "viz"), ("mne", "preprocessing"))
+        parts = parts[: 4 if parts[:2] in split_submodules else 3]
+        if not parts[-1].endswith(".py"):
+            parts = parts + ("",)
+        file_key = "/".join(parts)
+        files[file_key] = files.get(file_key, 0) + val["pytest_duration_s"]
     files = sorted(list(files.items()), key=lambda x: x[1])[::-1]
     # print
     _files[:] = files[:n]
@@ -868,36 +915,38 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
     writer = terminalreporter
     n = len(_files)
     if n:
-        writer.line('')  # newline
-        writer.write_sep('=', f'slowest {n} test module{_pl(n)}')
+        writer.line("")  # newline
+        writer.write_sep("=", f"slowest {n} test module{_pl(n)}")
         names, timings = zip(*_files)
-        timings = [f'{timing:0.2f}s total' for timing in timings]
+        timings = [f"{timing:0.2f}s total" for timing in timings]
         rjust = max(len(timing) for timing in timings)
         timings = [timing.rjust(rjust) for timing in timings]
         for name, timing in zip(names, timings):
-            writer.line(f'{timing.ljust(15)}{name}')
+            writer.line(f"{timing.ljust(15)}{name}")


-@pytest.fixture(scope="function", params=('Numba', 'NumPy'))
+@pytest.fixture(scope="function", params=("Numba", "NumPy"))
 def numba_conditional(monkeypatch, request):
     """Test both code paths on machines that have Numba."""
-    assert request.param in ('Numba', 'NumPy')
-    if request.param == 'NumPy' and has_numba:
+    assert request.param in ("Numba", "NumPy")
+    if request.param == "NumPy" and has_numba:
         monkeypatch.setattr(
-            cluster_level, '_get_buddies', cluster_level._get_buddies_fallback)
+            cluster_level, "_get_buddies", cluster_level._get_buddies_fallback
+        )
         monkeypatch.setattr(
-            cluster_level, '_get_selves', cluster_level._get_selves_fallback)
+            cluster_level, "_get_selves", cluster_level._get_selves_fallback
+        )
         monkeypatch.setattr(
-            cluster_level, '_where_first', cluster_level._where_first_fallback)
-        monkeypatch.setattr(
-            numerics, '_arange_div', numerics._arange_div_fallback)
-    if request.param == 'Numba' and not has_numba:
-        pytest.skip('Numba not installed')
+            cluster_level, "_where_first", cluster_level._where_first_fallback
+        )
+        monkeypatch.setattr(numerics, "_arange_div", numerics._arange_div_fallback)
+    if request.param == "Numba" and not has_numba:
+        pytest.skip("Numba not installed")
     yield request.param


 # Create one nbclient and reuse it
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def _nbclient():
     try:
         import nbformat
@@ -906,9 +955,10 @@ def _nbclient():
         from ipywidgets import Button  # noqa
         import ipyvtklink  # noqa
     except Exception as exc:
-        return pytest.skip(f'Skipping Notebook test: {exc}')
+        return pytest.skip(f"Skipping Notebook test: {exc}")
     km = AsyncKernelManager(config=None)
-    nb = nbformat.reads("""
+    nb = nbformat.reads(
+        """
 {
   "cells": [
     {
@@ -934,7 +984,9 @@
   },
   "nbformat": 4,
   "nbformat_minor": 4
-}""", as_version=4)
+}""",
+        as_version=4,
+    )
     client = NotebookClient(nb, km=km)
     yield client
     try:
@@ -943,7 +995,7 @@
         pass


-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
 def nbexec(_nbclient):
     """Execute Python code in a notebook."""
     # Adapted/simplified from nbclient/client.py 
(BSD-3-Clause) @@ -953,7 +1005,7 @@ def execute(code, reset=False): _nbclient.reset_execution_trackers() with _nbclient.setup_kernel(): assert _nbclient.kc is not None - cell = Bunch(cell_type='code', metadata={}, source=dedent(code)) + cell = Bunch(cell_type="code", metadata={}, source=dedent(code)) _nbclient.execute_cell(cell, 0, execution_count=0) _nbclient.set_widgets_metadata() @@ -962,15 +1014,15 @@ def execute(code, reset=False): def pytest_runtest_call(item): """Run notebook code written in Python.""" - if 'nbexec' in getattr(item, 'fixturenames', ()): - nbexec = item.funcargs['nbexec'] - code = inspect.getsource(getattr(item.module, item.name.split('[')[0])) + if "nbexec" in getattr(item, "fixturenames", ()): + nbexec = item.funcargs["nbexec"] + code = inspect.getsource(getattr(item.module, item.name.split("[")[0])) code = code.splitlines() ci = 0 for ci, c in enumerate(code): - if c.startswith(' '): # actual content + if c.startswith(" "): # actual content break - code = '\n'.join(code[ci:]) + code = "\n".join(code[ci:]) def run(nbexec=nbexec, code=code): nbexec(code) @@ -979,27 +1031,32 @@ def run(nbexec=nbexec, code=code): return -@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:') -@pytest.fixture(params=( - [nirsport2, nirsport2_snirf, testing._pytest_param()], - [nirsport2_2021_9, nirsport2_20219_snirf, testing._pytest_param()], -)) +@pytest.mark.filterwarnings("ignore:.*Extraction of measurement.*:") +@pytest.fixture( + params=( + [nirsport2, nirsport2_snirf, testing._pytest_param()], + [nirsport2_2021_9, nirsport2_20219_snirf, testing._pytest_param()], + ) +) def nirx_snirf(request): """Return a (raw_nirx, raw_snirf) matched pair.""" - pytest.importorskip('h5py') + pytest.importorskip("h5py") skipper = request.param[2].marks[0].mark if skipper.args[0]: # will skip - pytest.skip(skipper.kwargs['reason']) - return (read_raw_nirx(request.param[0], preload=True), - read_raw_snirf(request.param[1], preload=True)) + pytest.skip(skipper.kwargs["reason"]) + return ( + read_raw_nirx(request.param[0], preload=True), + read_raw_snirf(request.param[1], preload=True), + ) @pytest.fixture def qt_windows_closed(request): """Ensure that no new Qt windows are open after a test.""" - _check_skip_backend('pyvistaqt') + _check_skip_backend("pyvistaqt") app = _init_mne_qtapp() from qtpy import API_NAME + app.processEvents() gc.collect() n_before = len(app.topLevelWidgets()) @@ -1007,9 +1064,9 @@ def qt_windows_closed(request): yield app.processEvents() gc.collect() - if 'allow_unclosed' in marks: + if "allow_unclosed" in marks: return - if 'allow_unclosed_pyside2' in marks and API_NAME.lower() == 'pyside2': + if "allow_unclosed_pyside2" in marks and API_NAME.lower() == "pyside2": return # Don't check when the test fails report = request.node.stash[_phase_report_key] diff --git a/mne/coreg.py b/mne/coreg.py index 3e21f3ff917..fc9a7a9753f 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -21,6 +21,7 @@ from .io.constants import FIFF from .io.meas_info import Info from .io._digitization import _get_data_as_dict_from_dig + # keep get_mni_fiducials for backward compat (no burden to keep in this # namespace, too) from ._freesurfer import ( @@ -34,52 +35,82 @@ read_source_spaces, # noqa: F401 write_source_spaces, ) -from .surface import (read_surface, write_surface, _normalize_vectors, - complete_surface_info, decimate_surface, - _DistanceQuery) +from .surface import ( + read_surface, + write_surface, + _normalize_vectors, + complete_surface_info, + decimate_surface, + _DistanceQuery, 
+) from .bem import read_bem_surfaces, write_bem_surfaces -from .transforms import (rotation, rotation3d, scaling, translation, Transform, - _read_fs_xfm, _write_fs_xfm, invert_transform, - combine_transforms, _quat_to_euler, - _fit_matched_points, apply_trans, - rot_to_quat, _angle_between_quats) +from .transforms import ( + rotation, + rotation3d, + scaling, + translation, + Transform, + _read_fs_xfm, + _write_fs_xfm, + invert_transform, + combine_transforms, + _quat_to_euler, + _fit_matched_points, + apply_trans, + rot_to_quat, + _angle_between_quats, +) from .channels import make_dig_montage -from .utils import (get_config, get_subjects_dir, logger, pformat, verbose, - warn, fill_doc, _validate_type, - _check_subject, _check_option, _import_nibabel) +from .utils import ( + get_config, + get_subjects_dir, + logger, + pformat, + verbose, + warn, + fill_doc, + _validate_type, + _check_subject, + _check_option, + _import_nibabel, +) from .viz._3d import _fiducial_coords # some path templates -trans_fname = os.path.join('{raw_dir}', '{subject}-trans.fif') -subject_dirname = os.path.join('{subjects_dir}', '{subject}') -bem_dirname = os.path.join(subject_dirname, 'bem') -mri_dirname = os.path.join(subject_dirname, 'mri') -mri_transforms_dirname = os.path.join(subject_dirname, 'mri', 'transforms') -surf_dirname = os.path.join(subject_dirname, 'surf') +trans_fname = os.path.join("{raw_dir}", "{subject}-trans.fif") +subject_dirname = os.path.join("{subjects_dir}", "{subject}") +bem_dirname = os.path.join(subject_dirname, "bem") +mri_dirname = os.path.join(subject_dirname, "mri") +mri_transforms_dirname = os.path.join(subject_dirname, "mri", "transforms") +surf_dirname = os.path.join(subject_dirname, "surf") bem_fname = os.path.join(bem_dirname, "{subject}-{name}.fif") -head_bem_fname = pformat(bem_fname, name='head') -head_sparse_fname = pformat(bem_fname, name='head-sparse') -fid_fname = pformat(bem_fname, name='fiducials') +head_bem_fname = pformat(bem_fname, name="head") +head_sparse_fname = pformat(bem_fname, name="head-sparse") +fid_fname = pformat(bem_fname, name="fiducials") fid_fname_general = os.path.join(bem_dirname, "{head}-fiducials.fif") -src_fname = os.path.join(bem_dirname, '{subject}-{spacing}-src.fif') -_head_fnames = (os.path.join(bem_dirname, 'outer_skin.surf'), - head_sparse_fname, - head_bem_fname) -_high_res_head_fnames = (os.path.join(bem_dirname, '{subject}-head-dense.fif'), - os.path.join(surf_dirname, 'lh.seghead'), - os.path.join(surf_dirname, 'lh.smseghead')) +src_fname = os.path.join(bem_dirname, "{subject}-{spacing}-src.fif") +_head_fnames = ( + os.path.join(bem_dirname, "outer_skin.surf"), + head_sparse_fname, + head_bem_fname, +) +_high_res_head_fnames = ( + os.path.join(bem_dirname, "{subject}-head-dense.fif"), + os.path.join(surf_dirname, "lh.seghead"), + os.path.join(surf_dirname, "lh.smseghead"), +) def _map_fid_name_to_idx(name: str) -> int: """Map a fiducial name to its index in the DigMontage.""" name = name.lower() - if name == 'lpa': + if name == "lpa": return 0 - elif name == 'nasion': + elif name == "nasion": return 1 else: - assert name == 'rpa' + assert name == "rpa" return 2 @@ -90,7 +121,7 @@ def _make_writable(fname): def _make_writable_recursive(path): """Recursively set writable.""" - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): return # can't safely set perms for root, dirs, files in os.walk(path, topdown=False): for f in dirs + files: @@ -132,21 +163,19 @@ def coregister_fiducials(info, fiducials, tol=0.01): 
fiducials, coord_frame_to = read_fiducials(fiducials) else: coord_frame_to = FIFF.FIFFV_COORD_MRI - frames_from = {d['coord_frame'] for d in info['dig']} + frames_from = {d["coord_frame"] for d in info["dig"]} if len(frames_from) > 1: - raise ValueError("info contains fiducials from different coordinate " - "frames") + raise ValueError("info contains fiducials from different coordinate " "frames") else: coord_frame_from = frames_from.pop() - coords_from = _fiducial_coords(info['dig']) + coords_from = _fiducial_coords(info["dig"]) coords_to = _fiducial_coords(fiducials, coord_frame_to) trans = fit_matched_points(coords_from, coords_to, tol=tol) return Transform(coord_frame_from, coord_frame_to, trans) @verbose -def create_default_subject(fs_home=None, update=False, subjects_dir=None, - verbose=None): +def create_default_subject(fs_home=None, update=False, subjects_dir=None, verbose=None): """Create an average brain subject for subjects without structural MRI. Create a copy of fsaverage from the Freesurfer directory in subjects_dir @@ -177,37 +206,43 @@ def create_default_subject(fs_home=None, update=False, subjects_dir=None, """ subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) if fs_home is None: - fs_home = get_config('FREESURFER_HOME', fs_home) + fs_home = get_config("FREESURFER_HOME", fs_home) if fs_home is None: raise ValueError( "FREESURFER_HOME environment variable not found. Please " "specify the fs_home parameter in your call to " - "create_default_subject().") + "create_default_subject()." + ) # make sure freesurfer files exist - fs_src = os.path.join(fs_home, 'subjects', 'fsaverage') + fs_src = os.path.join(fs_home, "subjects", "fsaverage") if not os.path.exists(fs_src): - raise OSError('fsaverage not found at %r. Is fs_home specified ' - 'correctly?' % fs_src) - for name in ('label', 'mri', 'surf'): + raise OSError( + "fsaverage not found at %r. Is fs_home specified " "correctly?" % fs_src + ) + for name in ("label", "mri", "surf"): dirname = os.path.join(fs_src, name) if not os.path.isdir(dirname): - raise OSError("Freesurfer fsaverage seems to be incomplete: No " - "directory named %s found in %s" % (name, fs_src)) + raise OSError( + "Freesurfer fsaverage seems to be incomplete: No " + "directory named %s found in %s" % (name, fs_src) + ) # make sure destination does not already exist - dest = os.path.join(subjects_dir, 'fsaverage') + dest = os.path.join(subjects_dir, "fsaverage") if dest == fs_src: raise OSError( "Your subjects_dir points to the freesurfer subjects_dir (%r). " "The default subject can not be created in the freesurfer " "installation directory; please specify a different " - "subjects_dir." % subjects_dir) + "subjects_dir." % subjects_dir + ) elif (not update) and os.path.exists(dest): raise OSError( "Can not create fsaverage because %r already exists in " "subjects_dir %r. Delete or rename the existing fsaverage " - "subject folder." % ('fsaverage', subjects_dir)) + "subject folder." 
% ("fsaverage", subjects_dir) + ) # copy fsaverage from freesurfer logger.info("Copying fsaverage subject from freesurfer directory...") @@ -216,15 +251,16 @@ def create_default_subject(fs_home=None, update=False, subjects_dir=None, _make_writable_recursive(dest) # copy files from mne - source_fname = os.path.join(os.path.dirname(__file__), 'data', 'fsaverage', - 'fsaverage-%s.fif') - dest_bem = os.path.join(dest, 'bem') + source_fname = os.path.join( + os.path.dirname(__file__), "data", "fsaverage", "fsaverage-%s.fif" + ) + dest_bem = os.path.join(dest, "bem") if not os.path.exists(dest_bem): os.mkdir(dest_bem) logger.info("Copying auxiliary fsaverage files from mne...") - dest_fname = os.path.join(dest_bem, 'fsaverage-%s.fif') + dest_fname = os.path.join(dest_bem, "fsaverage-%s.fif") _make_writable_recursive(dest_bem) - for name in ('fiducials', 'head', 'inner_skull-bem', 'trans'): + for name in ("fiducials", "head", "inner_skull-bem", "trans"): if not os.path.exists(dest_fname % name): shutil.copy(source_fname % name, dest_bem) @@ -249,10 +285,11 @@ def _decimate_points(pts, res=10): The decimated points. """ from scipy.spatial.distance import cdist + pts = np.asarray(pts) # find the bin edges for the voxel space - xmin, ymin, zmin = pts.min(0) - res / 2. + xmin, ymin, zmin = pts.min(0) - res / 2.0 xmax, ymax, zmax = pts.max(0) + res xax = np.arange(xmin, xmax, res) yax = np.arange(ymin, ymax, res) @@ -264,19 +301,18 @@ def _decimate_points(pts, res=10): x = xax[xbins] y = yax[ybins] z = zax[zbins] - mids = np.c_[x, y, z] + res / 2. + mids = np.c_[x, y, z] + res / 2.0 # each point belongs to at most one voxel center, so figure those out # (cKDTree faster than BallTree for these small problems) - tree = _DistanceQuery(mids, method='cKDTree') + tree = _DistanceQuery(mids, method="cKDTree") _, mid_idx = tree.query(pts) # then figure out which to actually use based on proximity # (take advantage of sorting the mid_idx to get our mapping of # pts to nearest voxel midpoint) sort_idx = np.argsort(mid_idx) - bounds = np.cumsum( - np.concatenate([[0], np.bincount(mid_idx, minlength=len(mids))])) + bounds = np.cumsum(np.concatenate([[0], np.bincount(mid_idx, minlength=len(mids))])) assert len(bounds) == len(mids) + 1 out = list() for mi, mid in enumerate(mids): @@ -287,14 +323,13 @@ def _decimate_points(pts, res=10): # But it's faster for many points than making a big boolean indexer # over and over (esp. since each point can only belong to a single # voxel). - use_pts = pts[sort_idx[bounds[mi]:bounds[mi + 1]]] + use_pts = pts[sort_idx[bounds[mi] : bounds[mi + 1]]] if not len(use_pts): out.append([np.inf] * 3) else: - out.append( - use_pts[np.argmin(cdist(use_pts, mid[np.newaxis])[:, 0])]) + out.append(use_pts[np.argmin(cdist(use_pts, mid[np.newaxis])[:, 0])]) out = np.array(out, float).reshape(-1, 3) - out = out[np.abs(out - mids).max(axis=1) < res / 2.] 
+ out = out[np.abs(out - mids).max(axis=1) < res / 2.0] # """ return out @@ -312,7 +347,7 @@ def _trans_from_params(param_info, params): i += 3 if do_translate: - x, y, z = params[i:i + 3] + x, y, z = params[i : i + 3] trans.insert(0, translation(x, y, z)) i += 3 @@ -320,7 +355,7 @@ def _trans_from_params(param_info, params): s = params[i] trans.append(scaling(s, s, s)) elif do_scale == 3: - x, y, z = params[i:i + 3] + x, y, z = params[i : i + 3] trans.append(scaling(x, y, z)) trans = reduce(np.dot, trans) @@ -331,9 +366,17 @@ def _trans_from_params(param_info, params): # XXX this function should be moved out of coreg as used elsewhere -def fit_matched_points(src_pts, tgt_pts, rotate=True, translate=True, - scale=False, tol=None, x0=None, out='trans', - weights=None): +def fit_matched_points( + src_pts, + tgt_pts, + rotate=True, + translate=True, + scale=False, + tol=None, + x0=None, + out="trans", + weights=None, +): """Find a transform between matched sets of points. This minimizes the squared distance between two matching sets of points. @@ -378,13 +421,21 @@ def fit_matched_points(src_pts, tgt_pts, rotate=True, translate=True, src_pts = np.atleast_2d(src_pts) tgt_pts = np.atleast_2d(tgt_pts) if src_pts.shape != tgt_pts.shape: - raise ValueError("src_pts and tgt_pts must have same shape (got " - "{}, {})".format(src_pts.shape, tgt_pts.shape)) + raise ValueError( + "src_pts and tgt_pts must have same shape (got " + "{}, {})".format(src_pts.shape, tgt_pts.shape) + ) if weights is not None: weights = np.asarray(weights, src_pts.dtype) if weights.ndim != 1 or weights.size not in (src_pts.shape[0], 1): - raise ValueError("weights (shape=%s) must be None or have shape " - "(%s,)" % (weights.shape, src_pts.shape[0],)) + raise ValueError( + "weights (shape=%s) must be None or have shape " + "(%s,)" + % ( + weights.shape, + src_pts.shape[0], + ) + ) weights = weights[:, np.newaxis] param_info = (bool(rotate), bool(translate), int(scale)) @@ -397,15 +448,14 @@ def fit_matched_points(src_pts, tgt_pts, rotate=True, translate=True, tgt_pts = np.asarray(tgt_pts, float) if weights is not None: weights = np.asarray(weights, float) - x, s = _fit_matched_points( - src_pts, tgt_pts, weights, bool(param_info[2])) + x, s = _fit_matched_points(src_pts, tgt_pts, weights, bool(param_info[2])) x[:3] = _quat_to_euler(x[:3]) x = np.concatenate((x, [s])) if param_info[2] else x else: x = _generic_fit(src_pts, tgt_pts, param_info, weights, x0) # re-create the final transformation matrix - if (tol is not None) or (out == 'trans'): + if (tol is not None) or (out == "trans"): trans = _trans_from_params(param_info, x) # assess the error of the solution @@ -416,21 +466,24 @@ def fit_matched_points(src_pts, tgt_pts, rotate=True, translate=True, if np.any(err > tol): raise RuntimeError("Error exceeds tolerance. Error = %r" % err) - if out == 'params': + if out == "params": return x - elif out == 'trans': + elif out == "trans": return trans else: - raise ValueError("Invalid out parameter: %r. Needs to be 'params' or " - "'trans'." % out) + raise ValueError( + "Invalid out parameter: %r. Needs to be 'params' or " "'trans'." 
% out + ) def _generic_fit(src_pts, tgt_pts, param_info, weights, x0): from scipy.optimize import leastsq + if param_info[1]: # translate src_pts = np.hstack((src_pts, np.ones((len(src_pts), 1)))) if param_info == (True, False, 0): + def error(x): rx, ry, rz = x trans = rotation3d(rx, ry, rz) @@ -439,9 +492,11 @@ def error(x): if weights is not None: d *= weights return d.ravel() + if x0 is None: x0 = (0, 0, 0) elif param_info == (True, True, 0): + def error(x): rx, ry, rz, tx, ty, tz = x trans = np.dot(translation(tx, ty, tz), rotation(rx, ry, rz)) @@ -450,44 +505,52 @@ def error(x): if weights is not None: d *= weights return d.ravel() + if x0 is None: x0 = (0, 0, 0, 0, 0, 0) elif param_info == (True, True, 1): + def error(x): rx, ry, rz, tx, ty, tz, s = x - trans = reduce(np.dot, (translation(tx, ty, tz), - rotation(rx, ry, rz), - scaling(s, s, s))) + trans = reduce( + np.dot, + (translation(tx, ty, tz), rotation(rx, ry, rz), scaling(s, s, s)), + ) est = np.dot(src_pts, trans.T)[:, :3] d = tgt_pts - est if weights is not None: d *= weights return d.ravel() + if x0 is None: x0 = (0, 0, 0, 0, 0, 0, 1) elif param_info == (True, True, 3): + def error(x): rx, ry, rz, tx, ty, tz, sx, sy, sz = x - trans = reduce(np.dot, (translation(tx, ty, tz), - rotation(rx, ry, rz), - scaling(sx, sy, sz))) + trans = reduce( + np.dot, + (translation(tx, ty, tz), rotation(rx, ry, rz), scaling(sx, sy, sz)), + ) est = np.dot(src_pts, trans.T)[:, :3] d = tgt_pts - est if weights is not None: d *= weights return d.ravel() + if x0 is None: x0 = (0, 0, 0, 0, 0, 0, 1, 1, 1) else: raise NotImplementedError( "The specified parameter combination is not implemented: " - "rotate=%r, translate=%r, scale=%r" % param_info) + "rotate=%r, translate=%r, scale=%r" % param_info + ) x, _, _, _, _ = leastsq(error, x0, full_output=True) return x -def _find_label_paths(subject='fsaverage', pattern=None, subjects_dir=None): +def _find_label_paths(subject="fsaverage", pattern=None, subjects_dir=None): """Find paths to label files in a subject's label directory. 
Parameters @@ -515,7 +578,7 @@ def _find_label_paths(subject='fsaverage', pattern=None, subjects_dir=None): paths = [] for dirpath, _, filenames in os.walk(lbl_dir): rel_dir = os.path.relpath(dirpath, lbl_dir) - for filename in fnmatch.filter(filenames, '*.label'): + for filename in fnmatch.filter(filenames, "*.label"): path = os.path.join(rel_dir, filename) paths.append(path) else: @@ -548,41 +611,56 @@ def _find_mri_paths(subject, skip_fiducials, subjects_dir): paths = {} # directories to create - paths['dirs'] = [bem_dirname, surf_dirname] + paths["dirs"] = [bem_dirname, surf_dirname] # surf/ files - paths['surf'] = [] - surf_fname = os.path.join(surf_dirname, '{name}') - surf_names = ('inflated', 'white', 'orig', 'orig_avg', 'inflated_avg', - 'inflated_pre', 'pial', 'pial_avg', 'smoothwm', 'white_avg', - 'seghead', 'smseghead') - if os.getenv('_MNE_FEW_SURFACES', '') == 'true': # for testing + paths["surf"] = [] + surf_fname = os.path.join(surf_dirname, "{name}") + surf_names = ( + "inflated", + "white", + "orig", + "orig_avg", + "inflated_avg", + "inflated_pre", + "pial", + "pial_avg", + "smoothwm", + "white_avg", + "seghead", + "smseghead", + ) + if os.getenv("_MNE_FEW_SURFACES", "") == "true": # for testing surf_names = surf_names[:4] for surf_name in surf_names: - for hemi in ('lh.', 'rh.'): + for hemi in ("lh.", "rh."): name = hemi + surf_name - path = surf_fname.format(subjects_dir=subjects_dir, - subject=subject, name=name) + path = surf_fname.format( + subjects_dir=subjects_dir, subject=subject, name=name + ) if os.path.exists(path): - paths['surf'].append(pformat(surf_fname, name=name)) - surf_fname = os.path.join(bem_dirname, '{name}') - surf_names = ('inner_skull.surf', 'outer_skull.surf', 'outer_skin.surf') + paths["surf"].append(pformat(surf_fname, name=name)) + surf_fname = os.path.join(bem_dirname, "{name}") + surf_names = ("inner_skull.surf", "outer_skull.surf", "outer_skin.surf") for surf_name in surf_names: - path = surf_fname.format(subjects_dir=subjects_dir, - subject=subject, name=surf_name) + path = surf_fname.format( + subjects_dir=subjects_dir, subject=subject, name=surf_name + ) if os.path.exists(path): - paths['surf'].append(pformat(surf_fname, name=surf_name)) + paths["surf"].append(pformat(surf_fname, name=surf_name)) del surf_names, surf_name, path, hemi # BEM files - paths['bem'] = bem = [] + paths["bem"] = bem = [] path = head_bem_fname.format(subjects_dir=subjects_dir, subject=subject) if os.path.exists(path): - bem.append('head') - bem_pattern = pformat(bem_fname, subjects_dir=subjects_dir, - subject=subject, name='*-bem') - re_pattern = pformat(bem_fname, subjects_dir=subjects_dir, subject=subject, - name='(.+)').replace('\\', '\\\\') + bem.append("head") + bem_pattern = pformat( + bem_fname, subjects_dir=subjects_dir, subject=subject, name="*-bem" + ) + re_pattern = pformat( + bem_fname, subjects_dir=subjects_dir, subject=subject, name="(.+)" + ).replace("\\", "\\\\") for path in iglob(bem_pattern): match = re.match(re_pattern, path) name = match.group(1) @@ -591,54 +669,57 @@ def _find_mri_paths(subject, skip_fiducials, subjects_dir): # fiducials if skip_fiducials: - paths['fid'] = [] + paths["fid"] = [] else: - paths['fid'] = _find_fiducials_files(subject, subjects_dir) + paths["fid"] = _find_fiducials_files(subject, subjects_dir) # check that we found at least one - if len(paths['fid']) == 0: - raise OSError("No fiducials file found for %s. The fiducials " - "file should be named " - "{subject}/bem/{subject}-fiducials.fif. 
In " - "order to scale an MRI without fiducials set " - "skip_fiducials=True." % subject) + if len(paths["fid"]) == 0: + raise OSError( + "No fiducials file found for %s. The fiducials " + "file should be named " + "{subject}/bem/{subject}-fiducials.fif. In " + "order to scale an MRI without fiducials set " + "skip_fiducials=True." % subject + ) # duplicate files (curvature and some surfaces) - paths['duplicate'] = [] - path = os.path.join(surf_dirname, '{name}') - surf_fname = os.path.join(surf_dirname, '{name}') - surf_dup_names = ('curv', 'sphere', 'sphere.reg', 'sphere.reg.avg') + paths["duplicate"] = [] + path = os.path.join(surf_dirname, "{name}") + surf_fname = os.path.join(surf_dirname, "{name}") + surf_dup_names = ("curv", "sphere", "sphere.reg", "sphere.reg.avg") for surf_dup_name in surf_dup_names: - for hemi in ('lh.', 'rh.'): + for hemi in ("lh.", "rh."): name = hemi + surf_dup_name - path = surf_fname.format(subjects_dir=subjects_dir, - subject=subject, name=name) + path = surf_fname.format( + subjects_dir=subjects_dir, subject=subject, name=name + ) if os.path.exists(path): - paths['duplicate'].append(pformat(surf_fname, name=name)) + paths["duplicate"].append(pformat(surf_fname, name=name)) del surf_dup_name, name, path, hemi # transform files (talairach) - paths['transforms'] = [] - transform_fname = os.path.join(mri_transforms_dirname, 'talairach.xfm') + paths["transforms"] = [] + transform_fname = os.path.join(mri_transforms_dirname, "talairach.xfm") path = transform_fname.format(subjects_dir=subjects_dir, subject=subject) if os.path.exists(path): - paths['transforms'].append(transform_fname) + paths["transforms"].append(transform_fname) del transform_fname, path # find source space files - paths['src'] = src = [] + paths["src"] = src = [] bem_dir = bem_dirname.format(subjects_dir=subjects_dir, subject=subject) - fnames = fnmatch.filter(os.listdir(bem_dir), '*-src.fif') - prefix = subject + '-' + fnames = fnmatch.filter(os.listdir(bem_dir), "*-src.fif") + prefix = subject + "-" for fname in fnames: if fname.startswith(prefix): - fname = "{subject}-%s" % fname[len(prefix):] + fname = "{subject}-%s" % fname[len(prefix) :] path = os.path.join(bem_dirname, fname) src.append(path) # find MRIs mri_dir = mri_dirname.format(subjects_dir=subjects_dir, subject=subject) - fnames = fnmatch.filter(os.listdir(mri_dir), '*.mgz') - paths['mri'] = [os.path.join(mri_dir, f) for f in fnames] + fnames = fnmatch.filter(os.listdir(mri_dir), "*.mgz") + paths["mri"] = [os.path.join(mri_dir, f) for f in fnames] return paths @@ -647,17 +728,18 @@ def _find_fiducials_files(subject, subjects_dir): """Find fiducial files.""" fid = [] # standard fiducials - if os.path.exists(fid_fname.format(subjects_dir=subjects_dir, - subject=subject)): + if os.path.exists(fid_fname.format(subjects_dir=subjects_dir, subject=subject)): fid.append(fid_fname) # fiducials with subject name - pattern = pformat(fid_fname_general, subjects_dir=subjects_dir, - subject=subject, head='*') - regex = pformat(fid_fname_general, subjects_dir=subjects_dir, - subject=subject, head='(.+)').replace('\\', '\\\\') + pattern = pformat( + fid_fname_general, subjects_dir=subjects_dir, subject=subject, head="*" + ) + regex = pformat( + fid_fname_general, subjects_dir=subjects_dir, subject=subject, head="(.+)" + ).replace("\\", "\\\\") for path in iglob(pattern): match = re.match(regex, path) - head = match.group(1).replace(subject, '{subject}') + head = match.group(1).replace(subject, "{subject}") fid.append(pformat(fid_fname_general, 
head=head)) return fid @@ -678,8 +760,10 @@ def _is_mri_subject(subject, subjects_dir=None): Whether ``subject`` is an mri subject. """ subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) - return bool(_find_head_bem(subject, subjects_dir) or - _find_head_bem(subject, subjects_dir, high_res=True)) + return bool( + _find_head_bem(subject, subjects_dir) + or _find_head_bem(subject, subjects_dir, high_res=True) + ) def _is_scaled_mri_subject(subject, subjects_dir=None): @@ -720,8 +804,7 @@ def _mri_subject_has_bem(subject, subjects_dir=None): Whether ``subject`` has a bem file. """ subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) - pattern = bem_fname.format(subjects_dir=subjects_dir, subject=subject, - name='*-bem') + pattern = bem_fname.format(subjects_dir=subjects_dir, subject=subject, name="*-bem") fnames = glob(pattern) return bool(len(fnames)) @@ -745,23 +828,28 @@ def read_mri_cfg(subject, subjects_dir=None): fname = subjects_dir / subject / "MRI scaling parameters.cfg" if not fname.exists(): - raise OSError("%r does not seem to be a scaled mri subject: %r does " - "not exist." % (subject, fname)) + raise OSError( + "%r does not seem to be a scaled mri subject: %r does " + "not exist." % (subject, fname) + ) logger.info("Reading MRI cfg file %s" % fname) config = configparser.RawConfigParser() config.read(fname) - n_params = config.getint("MRI Scaling", 'n_params') + n_params = config.getint("MRI Scaling", "n_params") if n_params == 1: - scale = config.getfloat("MRI Scaling", 'scale') + scale = config.getfloat("MRI Scaling", "scale") elif n_params == 3: - scale_str = config.get("MRI Scaling", 'scale') + scale_str = config.get("MRI Scaling", "scale") scale = np.array([float(s) for s in scale_str.split()]) else: raise ValueError("Invalid n_params value in MRI cfg: %i" % n_params) - out = {'subject_from': config.get("MRI Scaling", 'subject_from'), - 'n_params': n_params, 'scale': scale} + out = { + "subject_from": config.get("MRI Scaling", "subject_from"), + "n_params": n_params, + "scale": scale, + } return out @@ -787,15 +875,15 @@ def _write_mri_config(fname, subject_from, subject_to, scale): config = configparser.RawConfigParser() config.add_section("MRI Scaling") - config.set("MRI Scaling", 'subject_from', subject_from) - config.set("MRI Scaling", 'subject_to', subject_to) - config.set("MRI Scaling", 'n_params', str(n_params)) + config.set("MRI Scaling", "subject_from", subject_from) + config.set("MRI Scaling", "subject_to", subject_to) + config.set("MRI Scaling", "n_params", str(n_params)) if n_params == 1: - config.set("MRI Scaling", 'scale', str(scale)) + config.set("MRI Scaling", "scale", str(scale)) else: - config.set("MRI Scaling", 'scale', ' '.join([str(s) for s in scale])) - config.set("MRI Scaling", 'version', '1') - with open(fname, 'w') as fid: + config.set("MRI Scaling", "scale", " ".join([str(s) for s in scale])) + config.set("MRI Scaling", "version", "1") + with open(fname, "w") as fid: config.write(fid) @@ -816,27 +904,38 @@ def _scale_params(subject_to, subject_from, scale, subjects_dir): """ subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) if (subject_from is None) != (scale is None): - raise TypeError("Need to provide either both subject_from and scale " - "parameters, or neither.") + raise TypeError( + "Need to provide either both subject_from and scale " + "parameters, or neither." 
+ ) if subject_from is None: cfg = read_mri_cfg(subject_to, subjects_dir) - subject_from = cfg['subject_from'] - n_params = cfg['n_params'] + subject_from = cfg["subject_from"] + n_params = cfg["n_params"] assert n_params in (1, 3) - scale = cfg['scale'] + scale = cfg["scale"] scale = np.atleast_1d(scale) if scale.ndim != 1 or scale.shape[0] not in (1, 3): - raise ValueError("Invalid shape for scale parameter. Need scalar " - "or array of length 3. Got shape %s." - % (scale.shape,)) + raise ValueError( + "Invalid shape for scale parameter. Need scalar " + "or array of length 3. Got shape %s." % (scale.shape,) + ) n_params = len(scale) return str(subjects_dir), subject_from, scale, n_params == 1 @verbose -def scale_bem(subject_to, bem_name, subject_from=None, scale=None, - subjects_dir=None, *, on_defects='raise', verbose=None): +def scale_bem( + subject_to, + bem_name, + subject_from=None, + scale=None, + subjects_dir=None, + *, + on_defects="raise", + verbose=None, +): """Scale a bem file. Parameters @@ -860,29 +959,36 @@ def scale_bem(subject_to, bem_name, subject_from=None, scale=None, .. versionadded:: 1.0 %(verbose)s """ - subjects_dir, subject_from, scale, uniform = \ - _scale_params(subject_to, subject_from, scale, subjects_dir) + subjects_dir, subject_from, scale, uniform = _scale_params( + subject_to, subject_from, scale, subjects_dir + ) - src = bem_fname.format(subjects_dir=subjects_dir, subject=subject_from, - name=bem_name) - dst = bem_fname.format(subjects_dir=subjects_dir, subject=subject_to, - name=bem_name) + src = bem_fname.format( + subjects_dir=subjects_dir, subject=subject_from, name=bem_name + ) + dst = bem_fname.format(subjects_dir=subjects_dir, subject=subject_to, name=bem_name) if os.path.exists(dst): raise OSError("File already exists: %s" % dst) surfs = read_bem_surfaces(src, on_defects=on_defects) for surf in surfs: - surf['rr'] *= scale + surf["rr"] *= scale if not uniform: - assert len(surf['nn']) > 0 - surf['nn'] /= scale - _normalize_vectors(surf['nn']) + assert len(surf["nn"]) > 0 + surf["nn"] /= scale + _normalize_vectors(surf["nn"]) write_bem_surfaces(dst, surfs) -def scale_labels(subject_to, pattern=None, overwrite=False, subject_from=None, - scale=None, subjects_dir=None): +def scale_labels( + subject_to, + pattern=None, + overwrite=False, + subject_from=None, + scale=None, + subjects_dir=None, +): r"""Scale labels to match a brain that was previously created by scaling. Parameters @@ -907,7 +1013,8 @@ def scale_labels(subject_to, pattern=None, overwrite=False, subject_from=None, Override the ``SUBJECTS_DIR`` environment variable. 
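
    Examples
    --------
    A minimal sketch with illustrative subject names; it assumes the scaled
    subject was previously created with :func:`scale_mri`, so ``subject_from``
    and ``scale`` can be read from its cfg file::

        scale_labels("fsaverage_scaled", pattern="*.label")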
""" subjects_dir, subject_from, scale, _ = _scale_params( - subject_to, subject_from, scale, subjects_dir) + subject_to, subject_from, scale, subjects_dir + ) # find labels paths = _find_label_paths(subject_from, pattern, subjects_dir) @@ -930,15 +1037,31 @@ def scale_labels(subject_to, pattern=None, overwrite=False, subject_from=None, src = src_root / fname l_old = read_label(src) pos = l_old.pos * scale - l_new = Label(l_old.vertices, pos, l_old.values, l_old.hemi, - l_old.comment, subject=subject_to) + l_new = Label( + l_old.vertices, + pos, + l_old.values, + l_old.hemi, + l_old.comment, + subject=subject_to, + ) l_new.save(dst) @verbose -def scale_mri(subject_from, subject_to, scale, overwrite=False, - subjects_dir=None, skip_fiducials=False, labels=True, - annot=False, *, on_defects='raise', verbose=None): +def scale_mri( + subject_from, + subject_to, + scale, + overwrite=False, + subjects_dir=None, + skip_fiducials=False, + labels=True, + annot=False, + *, + on_defects="raise", + verbose=None, +): """Create a scaled copy of an MRI subject. Parameters @@ -984,99 +1107,119 @@ def scale_mri(subject_from, subject_to, scale, overwrite=False, if np.isclose(scale[1], scale[0]) and np.isclose(scale[2], scale[0]): scale = scale[0] # speed up scaling conditionals using a singleton elif scale.shape != (1,): - raise ValueError('scale must have shape (3,) or (1,), got %s' - % (scale.shape,)) + raise ValueError("scale must have shape (3,) or (1,), got %s" % (scale.shape,)) # make sure we have an empty target directory - dest = subject_dirname.format(subject=subject_to, - subjects_dir=subjects_dir) + dest = subject_dirname.format(subject=subject_to, subjects_dir=subjects_dir) if os.path.exists(dest): if not overwrite: - raise OSError("Subject directory for %s already exists: %r" - % (subject_to, dest)) + raise OSError( + "Subject directory for %s already exists: %r" % (subject_to, dest) + ) shutil.rmtree(dest) - logger.debug('create empty directory structure') - for dirname in paths['dirs']: + logger.debug("create empty directory structure") + for dirname in paths["dirs"]: dir_ = dirname.format(subject=subject_to, subjects_dir=subjects_dir) os.makedirs(dir_) - logger.debug('save MRI scaling parameters') - fname = os.path.join(dest, 'MRI scaling parameters.cfg') + logger.debug("save MRI scaling parameters") + fname = os.path.join(dest, "MRI scaling parameters.cfg") _write_mri_config(fname, subject_from, subject_to, scale) - logger.debug('surf files [in mm]') - for fname in paths['surf']: + logger.debug("surf files [in mm]") + for fname in paths["surf"]: src = fname.format(subject=subject_from, subjects_dir=subjects_dir) src = os.path.realpath(src) dest = fname.format(subject=subject_to, subjects_dir=subjects_dir) pts, tri = read_surface(src) write_surface(dest, pts * scale, tri) - logger.debug('BEM files [in m]') - for bem_name in paths['bem']: - scale_bem(subject_to, bem_name, subject_from, scale, subjects_dir, - on_defects=on_defects, verbose=False) + logger.debug("BEM files [in m]") + for bem_name in paths["bem"]: + scale_bem( + subject_to, + bem_name, + subject_from, + scale, + subjects_dir, + on_defects=on_defects, + verbose=False, + ) - logger.debug('fiducials [in m]') - for fname in paths['fid']: + logger.debug("fiducials [in m]") + for fname in paths["fid"]: src = fname.format(subject=subject_from, subjects_dir=subjects_dir) src = os.path.realpath(src) pts, cframe = read_fiducials(src, verbose=False) for pt in pts: - pt['r'] = pt['r'] * scale + pt["r"] = pt["r"] * scale dest = 
fname.format(subject=subject_to, subjects_dir=subjects_dir) write_fiducials(dest, pts, cframe, overwrite=True, verbose=False) - logger.debug('MRIs [nibabel]') - os.mkdir(mri_dirname.format(subjects_dir=subjects_dir, - subject=subject_to)) - for fname in paths['mri']: + logger.debug("MRIs [nibabel]") + os.mkdir(mri_dirname.format(subjects_dir=subjects_dir, subject=subject_to)) + for fname in paths["mri"]: mri_name = os.path.basename(fname) _scale_mri(subject_to, mri_name, subject_from, scale, subjects_dir) - logger.debug('Transforms') - for mri_name in paths['mri']: - if mri_name.endswith('T1.mgz'): - os.mkdir(mri_transforms_dirname.format(subjects_dir=subjects_dir, - subject=subject_to)) - for fname in paths['transforms']: + logger.debug("Transforms") + for mri_name in paths["mri"]: + if mri_name.endswith("T1.mgz"): + os.mkdir( + mri_transforms_dirname.format( + subjects_dir=subjects_dir, subject=subject_to + ) + ) + for fname in paths["transforms"]: xfm_name = os.path.basename(fname) - _scale_xfm(subject_to, xfm_name, mri_name, - subject_from, scale, subjects_dir) + _scale_xfm( + subject_to, xfm_name, mri_name, subject_from, scale, subjects_dir + ) break - logger.debug('duplicate files') - for fname in paths['duplicate']: + logger.debug("duplicate files") + for fname in paths["duplicate"]: src = fname.format(subject=subject_from, subjects_dir=subjects_dir) dest = fname.format(subject=subject_to, subjects_dir=subjects_dir) shutil.copyfile(src, dest) - logger.debug('source spaces') - for fname in paths['src']: + logger.debug("source spaces") + for fname in paths["src"]: src_name = os.path.basename(fname) - scale_source_space(subject_to, src_name, subject_from, scale, - subjects_dir, verbose=False) + scale_source_space( + subject_to, src_name, subject_from, scale, subjects_dir, verbose=False + ) - logger.debug('labels [in m]') - os.mkdir(os.path.join(subjects_dir, subject_to, 'label')) + logger.debug("labels [in m]") + os.mkdir(os.path.join(subjects_dir, subject_to, "label")) if labels: - scale_labels(subject_to, subject_from=subject_from, scale=scale, - subjects_dir=subjects_dir) + scale_labels( + subject_to, + subject_from=subject_from, + scale=scale, + subjects_dir=subjects_dir, + ) - logger.debug('copy *.annot files') + logger.debug("copy *.annot files") # they don't contain scale-dependent information if annot: - src_pattern = os.path.join(subjects_dir, subject_from, 'label', - '*.annot') - dst_dir = os.path.join(subjects_dir, subject_to, 'label') + src_pattern = os.path.join(subjects_dir, subject_from, "label", "*.annot") + dst_dir = os.path.join(subjects_dir, subject_to, "label") for src_file in iglob(src_pattern): shutil.copy(src_file, dst_dir) @verbose -def scale_source_space(subject_to, src_name, subject_from=None, scale=None, - subjects_dir=None, n_jobs=None, verbose=None): +def scale_source_space( + subject_to, + src_name, + subject_from=None, + scale=None, + subjects_dir=None, + n_jobs=None, + verbose=None, +): """Scale a source space for an mri created with scale_mri(). Parameters @@ -1110,8 +1253,9 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None, are updated so that source estimates can be plotted on the original MRI volume. 
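
    Examples
    --------
    A minimal sketch with illustrative names; the scaled subject is assumed to
    exist already (created with :func:`scale_mri`)::

        scale_source_space("fsaverage_scaled", "oct-6")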
""" - subjects_dir, subject_from, scale, uniform = \ - _scale_params(subject_to, subject_from, scale, subjects_dir) + subjects_dir, subject_from, scale, uniform = _scale_params( + subject_to, subject_from, scale, subjects_dir + ) # if n_params==1 scale is a scalar; if n_params==3 scale is a (3,) array # find the source space file names @@ -1121,45 +1265,46 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None, else: match = re.match(r"(oct|ico|vol)-?(\d+)$", src_name) if match: - spacing = '-'.join(match.groups()) + spacing = "-".join(match.groups()) src_pattern = src_fname else: spacing = None src_pattern = os.path.join(bem_dirname, src_name) - src = src_pattern.format(subjects_dir=subjects_dir, subject=subject_from, - spacing=spacing) - dst = src_pattern.format(subjects_dir=subjects_dir, subject=subject_to, - spacing=spacing) + src = src_pattern.format( + subjects_dir=subjects_dir, subject=subject_from, spacing=spacing + ) + dst = src_pattern.format( + subjects_dir=subjects_dir, subject=subject_to, spacing=spacing + ) # read and scale the source space [in m] sss = read_source_spaces(src) - logger.info("scaling source space %s: %s -> %s", spacing, subject_from, - subject_to) + logger.info("scaling source space %s: %s -> %s", spacing, subject_from, subject_to) logger.info("Scale factor: %s", scale) add_dist = False for ss in sss: - ss['subject_his_id'] = subject_to - ss['rr'] *= scale + ss["subject_his_id"] = subject_to + ss["rr"] *= scale # additional tags for volume source spaces - for key in ('vox_mri_t', 'src_mri_t'): + for key in ("vox_mri_t", "src_mri_t"): # maintain transform to original MRI volume ss['mri_volume_name'] if key in ss: - ss[key]['trans'][:3] *= scale[:, np.newaxis] + ss[key]["trans"][:3] *= scale[:, np.newaxis] # distances and patch info if uniform: - if ss['dist'] is not None: - ss['dist'] *= scale[0] + if ss["dist"] is not None: + ss["dist"] *= scale[0] # Sometimes this is read-only due to how it's read - ss['nearest_dist'] = ss['nearest_dist'] * scale - ss['dist_limit'] = ss['dist_limit'] * scale + ss["nearest_dist"] = ss["nearest_dist"] * scale + ss["dist_limit"] = ss["dist_limit"] * scale else: # non-uniform scaling - ss['nn'] /= scale - _normalize_vectors(ss['nn']) - if ss['dist'] is not None: + ss["nn"] /= scale + _normalize_vectors(ss["nn"]) + if ss["dist"] is not None: add_dist = True - dist_limit = float(np.abs(sss[0]['dist_limit'])) - elif ss['nearest'] is not None: + dist_limit = float(np.abs(sss[0]["dist_limit"])) + elif ss["nearest"] is not None: add_dist = True dist_limit = 0 @@ -1173,12 +1318,15 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None, def _scale_mri(subject_to, mri_fname, subject_from, scale, subjects_dir): """Scale an MRI by setting its affine.""" subjects_dir, subject_from, scale, _ = _scale_params( - subject_to, subject_from, scale, subjects_dir) - nibabel = _import_nibabel('scale an MRI') - fname_from = op.join(mri_dirname.format( - subjects_dir=subjects_dir, subject=subject_from), mri_fname) - fname_to = op.join(mri_dirname.format( - subjects_dir=subjects_dir, subject=subject_to), mri_fname) + subject_to, subject_from, scale, subjects_dir + ) + nibabel = _import_nibabel("scale an MRI") + fname_from = op.join( + mri_dirname.format(subjects_dir=subjects_dir, subject=subject_from), mri_fname + ) + fname_to = op.join( + mri_dirname.format(subjects_dir=subjects_dir, subject=subject_to), mri_fname + ) img = nibabel.load(fname_from) zooms = np.array(img.header.get_zooms()) zooms[[0, 2, 1]] *= 
scale @@ -1189,21 +1337,23 @@ def _scale_mri(subject_to, mri_fname, subject_from, scale, subjects_dir): nibabel.save(img, fname_to) -def _scale_xfm(subject_to, xfm_fname, mri_name, subject_from, scale, - subjects_dir): +def _scale_xfm(subject_to, xfm_fname, mri_name, subject_from, scale, subjects_dir): """Scale a transform.""" subjects_dir, subject_from, scale, _ = _scale_params( - subject_to, subject_from, scale, subjects_dir) + subject_to, subject_from, scale, subjects_dir + ) # The nibabel warning should already be there in MRI step, if applicable, # as we only get here if T1.mgz is present (and thus a scaling was # attempted) so we can silently return here. fname_from = os.path.join( - mri_transforms_dirname.format( - subjects_dir=subjects_dir, subject=subject_from), xfm_fname) + mri_transforms_dirname.format(subjects_dir=subjects_dir, subject=subject_from), + xfm_fname, + ) fname_to = op.join( - mri_transforms_dirname.format( - subjects_dir=subjects_dir, subject=subject_to), xfm_fname) + mri_transforms_dirname.format(subjects_dir=subjects_dir, subject=subject_to), + xfm_fname, + ) assert op.isfile(fname_from), fname_from assert op.isdir(op.dirname(fname_to)), op.dirname(fname_to) # The "talairach.xfm" file stores the ras_mni transform. @@ -1228,23 +1378,25 @@ def _scale_xfm(subject_to, xfm_fname, mri_name, subject_from, scale, # prepare the scale (S) transform scale = np.atleast_1d(scale) scale = np.tile(scale, 3) if len(scale) == 1 else scale - S = Transform('mri', 'mri', scaling(*scale)) # F_mri->T_mri + S = Transform("mri", "mri", scaling(*scale)) # F_mri->T_mri # # Get the necessary transforms of the "from" subject # xfm, kind = _read_fs_xfm(fname_from) - assert kind == 'MNI Transform File', kind - _, _, F_mri_ras, _, _ = _read_mri_info(mri_name, units='mm') - F_ras_mni = Transform('ras', 'mni_tal', xfm) + assert kind == "MNI Transform File", kind + _, _, F_mri_ras, _, _ = _read_mri_info(mri_name, units="mm") + F_ras_mni = Transform("ras", "mni_tal", xfm) del xfm # # Get the necessary transforms of the "to" subject # - mri_name = op.join(mri_dirname.format( - subjects_dir=subjects_dir, subject=subject_to), op.basename(mri_name)) - _, _, T_mri_ras, _, _ = _read_mri_info(mri_name, units='mm') + mri_name = op.join( + mri_dirname.format(subjects_dir=subjects_dir, subject=subject_to), + op.basename(mri_name), + ) + _, _, T_mri_ras, _, _ = _read_mri_info(mri_name, units="mm") T_ras_mri = invert_transform(T_mri_ras) del mri_name, T_mri_ras @@ -1253,32 +1405,35 @@ def _scale_xfm(subject_to, xfm_fname, mri_name, subject_from, scale, # T_ras_mni = F_ras_mni @ F_mri_ras @ S⁻¹ @ T_ras_mri # # By moving right to left through the equation. 
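    # As a hedged sketch, the same product written with plain 4x4 arrays
    # would read (matrix names are illustrative, not part of this code):
    #
    #     import numpy as np
    #     T_ras_mni_mat = (
    #         F_ras_mni["trans"] @ F_mri_ras["trans"]
    #         @ np.linalg.inv(S["trans"]) @ T_ras_mri["trans"]
    #     )
    #
    # combine_transforms below builds the same result pairwise, starting from
    # the rightmost factor and checking coordinate frames at each step.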
- T_ras_mni = \ + T_ras_mni = combine_transforms( combine_transforms( - combine_transforms( - combine_transforms( - T_ras_mri, invert_transform(S), 'ras', 'mri'), - F_mri_ras, 'ras', 'ras'), - F_ras_mni, 'ras', 'mni_tal') - _write_fs_xfm(fname_to, T_ras_mni['trans'], kind) + combine_transforms(T_ras_mri, invert_transform(S), "ras", "mri"), + F_mri_ras, + "ras", + "ras", + ), + F_ras_mni, + "ras", + "mni_tal", + ) + _write_fs_xfm(fname_to, T_ras_mni["trans"], kind) def _read_surface(filename, *, on_defects): bem = dict() if filename is not None and op.exists(filename): - if filename.endswith('.fif'): - bem = read_bem_surfaces( - filename, on_defects=on_defects, verbose=False - )[0] + if filename.endswith(".fif"): + bem = read_bem_surfaces(filename, on_defects=on_defects, verbose=False)[0] else: try: bem = read_surface(filename, return_dict=True)[2] - bem['rr'] *= 1e-3 + bem["rr"] *= 1e-3 complete_surface_info(bem, copy=False) except Exception: raise ValueError( "Error loading surface from %s (see " - "Terminal for details)." % filename) + "Terminal for details)." % filename + ) return bem @@ -1320,20 +1475,20 @@ class Coregistration: to create a surrogate MRI subject with the proper scale factors. """ - def __init__(self, info, subject, subjects_dir=None, fiducials='auto', *, - on_defects='raise'): - _validate_type(info, (Info, None), 'info') + def __init__( + self, info, subject, subjects_dir=None, fiducials="auto", *, on_defects="raise" + ): + _validate_type(info, (Info, None), "info") self._info = info self._subject = _check_subject(subject, subject) - self._subjects_dir = str( - get_subjects_dir(subjects_dir, raise_error=True) - ) + self._subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) self._scale_mode = None self._on_defects = on_defects self._rot_trans = None - self._default_parameters = \ - np.array([0., 0., 0., 0., 0., 0., 1., 1., 1.]) + self._default_parameters = np.array( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0] + ) self._rotation = self._default_parameters[:3] self._translation = self._default_parameters[3:6] @@ -1342,14 +1497,14 @@ def __init__(self, info, subject, subjects_dir=None, fiducials='auto', *, self._icp_angle = 0.2 self._icp_distance = 0.2 self._icp_scale = 0.2 - self._icp_fid_matches = ('nearest', 'matched') + self._icp_fid_matches = ("nearest", "matched") self._icp_fid_match = self._icp_fid_matches[0] - self._lpa_weight = 1. - self._nasion_weight = 10. - self._rpa_weight = 1. - self._hsp_weight = 1. - self._eeg_weight = 1. - self._hpi_weight = 1. 
+ self._lpa_weight = 1.0 + self._nasion_weight = 10.0 + self._rpa_weight = 1.0 + self._hsp_weight = 1.0 + self._eeg_weight = 1.0 + self._hpi_weight = 1.0 self._extra_points_filter = None self._setup_digs() @@ -1371,77 +1526,86 @@ def _setup_digs(self): ) else: self._dig_dict = _get_data_as_dict_from_dig( - dig=self._info['dig'], - exclude_ref_channel=False + dig=self._info["dig"], exclude_ref_channel=False ) # adjustments: # set weights to 0 for None input # convert fids to float arrays - for k, w_atr in zip(['nasion', 'lpa', 'rpa', 'hsp', 'hpi'], - ['_nasion_weight', '_lpa_weight', - '_rpa_weight', '_hsp_weight', '_hpi_weight']): + for k, w_atr in zip( + ["nasion", "lpa", "rpa", "hsp", "hpi"], + [ + "_nasion_weight", + "_lpa_weight", + "_rpa_weight", + "_hsp_weight", + "_hpi_weight", + ], + ): if self._dig_dict[k] is None: self._dig_dict[k] = np.zeros((0, 3)) setattr(self, w_atr, 0) - elif k in ['rpa', 'nasion', 'lpa']: + elif k in ["rpa", "nasion", "lpa"]: self._dig_dict[k] = np.array([self._dig_dict[k]], float) def _setup_bem(self): # find high-res head model (if possible) - high_res_path = _find_head_bem(self._subject, self._subjects_dir, - high_res=True) - low_res_path = _find_head_bem(self._subject, self._subjects_dir, - high_res=False) + high_res_path = _find_head_bem(self._subject, self._subjects_dir, high_res=True) + low_res_path = _find_head_bem(self._subject, self._subjects_dir, high_res=False) if high_res_path is None and low_res_path is None: - raise RuntimeError("No standard head model was " - f"found for subject {self._subject}") + raise RuntimeError( + "No standard head model was " f"found for subject {self._subject}" + ) if high_res_path is not None: self._bem_high_res = _read_surface( high_res_path, on_defects=self._on_defects ) - logger.info(f'Using high resolution head model in {high_res_path}') + logger.info(f"Using high resolution head model in {high_res_path}") else: self._bem_high_res = _read_surface( low_res_path, on_defects=self._on_defects ) - logger.info(f'Using low resolution head model in {low_res_path}') + logger.info(f"Using low resolution head model in {low_res_path}") if low_res_path is None: # This should be very rare! 
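            # No low-resolution head surface is available on disk, so fall
            # back to decimating the high-resolution mesh to ~5120 triangles;
            # the rest of the coregistration then has a coarse surface to use.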
- warn('No low-resolution head found, decimating high resolution ' - 'mesh (%d vertices): %s' % (len(self._bem_high_res['rr']), - high_res_path,)) + warn( + "No low-resolution head found, decimating high resolution " + "mesh (%d vertices): %s" + % ( + len(self._bem_high_res["rr"]), + high_res_path, + ) + ) # Create one from the high res one, which we know we have - rr, tris = decimate_surface(self._bem_high_res['rr'], - self._bem_high_res['tris'], - n_triangles=5120) + rr, tris = decimate_surface( + self._bem_high_res["rr"], self._bem_high_res["tris"], n_triangles=5120 + ) # directly set the attributes of bem_low_res self._bem_low_res = complete_surface_info( - dict(rr=rr, tris=tris), copy=False, verbose=False) - else: - self._bem_low_res = _read_surface( - low_res_path, on_defects=self._on_defects + dict(rr=rr, tris=tris), copy=False, verbose=False ) + else: + self._bem_low_res = _read_surface(low_res_path, on_defects=self._on_defects) def _setup_fiducials(self, fids): _validate_type(fids, (str, dict, list)) # find fiducials file fid_accurate = None - if fids == 'auto': - fid_files = _find_fiducials_files(self._subject, - self._subjects_dir) + if fids == "auto": + fid_files = _find_fiducials_files(self._subject, self._subjects_dir) if len(fid_files) > 0: # Read fiducials from disk fid_filename = fid_files[0].format( - subjects_dir=self._subjects_dir, subject=self._subject) - logger.info(f'Using fiducials from: {fid_filename}.') + subjects_dir=self._subjects_dir, subject=self._subject + ) + logger.info(f"Using fiducials from: {fid_filename}.") fids, _ = read_fiducials(fid_filename) fid_accurate = True self._fid_filename = fid_filename else: - fids = 'estimated' + fids = "estimated" - if fids == 'estimated': - logger.info('Estimating fiducials from fsaverage.') + if fids == "estimated": + logger.info("Estimating fiducials from fsaverage.") fid_accurate = False fids = get_mni_fiducials(self._subject, self._subjects_dir) @@ -1450,8 +1614,9 @@ def _setup_fiducials(self, fids): fid_coords = _fiducial_coords(fids) else: assert isinstance(fids, dict) - fid_coords = np.array([fids['lpa'], fids['nasion'], fids['rpa']], - dtype=float) + fid_coords = np.array( + [fids["lpa"], fids["nasion"], fids["rpa"]], dtype=float + ) self._fid_points = fid_coords self._fid_accurate = fid_accurate @@ -1464,12 +1629,11 @@ def _reset_fiducials(self): lpa=self._fid_points[0], nasion=self._fid_points[1], rpa=self._fid_points[2], - coord_frame='mri' + coord_frame="mri", ) self.fiducials = dig_montage - def _update_params(self, rot=None, tra=None, sca=None, - force_update=False): + def _update_params(self, rot=None, tra=None, sca=None, force_update=False): if force_update and tra is None: tra = self._translation rot_changed = False @@ -1485,18 +1649,19 @@ def _update_params(self, rot=None, tra=None, sca=None, self._last_translation = self._translation.copy() self._translation = tra self._head_mri_t = rotation(*self._rotation).T - self._head_mri_t[:3, 3] = \ - -np.dot(self._head_mri_t[:3, :3], tra) - self._transformed_dig_hpi = \ - apply_trans(self._head_mri_t, self._dig_dict['hpi']) - self._transformed_dig_eeg = \ - apply_trans( - self._head_mri_t, self._dig_dict['dig_ch_pos_location']) - self._transformed_dig_extra = \ - apply_trans(self._head_mri_t, - self._filtered_extra_points) - self._transformed_orig_dig_extra = \ - apply_trans(self._head_mri_t, self._dig_dict['hsp']) + self._head_mri_t[:3, 3] = -np.dot(self._head_mri_t[:3, :3], tra) + self._transformed_dig_hpi = apply_trans( + self._head_mri_t, 
self._dig_dict["hpi"] + ) + self._transformed_dig_eeg = apply_trans( + self._head_mri_t, self._dig_dict["dig_ch_pos_location"] + ) + self._transformed_dig_extra = apply_trans( + self._head_mri_t, self._filtered_extra_points + ) + self._transformed_orig_dig_extra = apply_trans( + self._head_mri_t, self._dig_dict["hsp"] + ) self._mri_head_t = rotation(*self._rotation) self._mri_head_t[:3, 3] = np.array(tra) if tra_changed or sca is not None: @@ -1506,27 +1671,32 @@ def _update_params(self, rot=None, tra=None, sca=None, self._scale = sca self._mri_trans = np.eye(4) self._mri_trans[:, :3] *= sca - self._transformed_high_res_mri_points = \ - apply_trans(self._mri_trans, - self._processed_high_res_mri_points) + self._transformed_high_res_mri_points = apply_trans( + self._mri_trans, self._processed_high_res_mri_points + ) self._update_nearest_calc() if tra_changed: - self._nearest_transformed_high_res_mri_idx_orig_hsp = \ + self._nearest_transformed_high_res_mri_idx_orig_hsp = ( self._nearest_calc.query(self._transformed_orig_dig_extra)[1] - self._nearest_transformed_high_res_mri_idx_hpi = \ - self._nearest_calc.query(self._transformed_dig_hpi)[1] - self._nearest_transformed_high_res_mri_idx_eeg = \ - self._nearest_calc.query(self._transformed_dig_eeg)[1] - self._nearest_transformed_high_res_mri_idx_rpa = \ - self._nearest_calc.query( - apply_trans(self._head_mri_t, self._dig_dict['rpa']))[1] - self._nearest_transformed_high_res_mri_idx_nasion = \ - self._nearest_calc.query( - apply_trans(self._head_mri_t, self._dig_dict['nasion']))[1] - self._nearest_transformed_high_res_mri_idx_lpa = \ + ) + self._nearest_transformed_high_res_mri_idx_hpi = self._nearest_calc.query( + self._transformed_dig_hpi + )[1] + self._nearest_transformed_high_res_mri_idx_eeg = self._nearest_calc.query( + self._transformed_dig_eeg + )[1] + self._nearest_transformed_high_res_mri_idx_rpa = self._nearest_calc.query( + apply_trans(self._head_mri_t, self._dig_dict["rpa"]) + )[1] + self._nearest_transformed_high_res_mri_idx_nasion = ( self._nearest_calc.query( - apply_trans(self._head_mri_t, self._dig_dict['lpa']))[1] + apply_trans(self._head_mri_t, self._dig_dict["nasion"]) + )[1] + ) + self._nearest_transformed_high_res_mri_idx_lpa = self._nearest_calc.query( + apply_trans(self._head_mri_t, self._dig_dict["lpa"]) + )[1] def set_scale_mode(self, scale_mode): """Select how to fit the scale parameters. 
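
        Examples
        --------
        A minimal sketch (``coreg`` is assumed to be an existing
        Coregistration instance)::

            coreg.set_scale_mode("uniform")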
@@ -1616,14 +1786,15 @@ def set_scale(self, sca): def _update_nearest_calc(self): self._nearest_calc = _DistanceQuery( - self._processed_high_res_mri_points * self._scale) + self._processed_high_res_mri_points * self._scale + ) @property def _filtered_extra_points(self): if self._extra_points_filter is None: - return self._dig_dict['hsp'] + return self._dig_dict["hsp"] else: - return self._dig_dict['hsp'][self._extra_points_filter] + return self._dig_dict["hsp"][self._extra_points_filter] @property def _parameters(self): @@ -1631,79 +1802,89 @@ def _parameters(self): @property def _last_parameters(self): - return np.concatenate((self._last_rotation, - self._last_translation, self._last_scale)) + return np.concatenate( + (self._last_rotation, self._last_translation, self._last_scale) + ) @property def _changes(self): move = np.linalg.norm(self._last_translation - self._translation) * 1e3 - angle = np.rad2deg(_angle_between_quats( - rot_to_quat(rotation(*self._rotation)[:3, :3]), - rot_to_quat(rotation(*self._last_rotation)[:3, :3]))) + angle = np.rad2deg( + _angle_between_quats( + rot_to_quat(rotation(*self._rotation)[:3, :3]), + rot_to_quat(rotation(*self._last_rotation)[:3, :3]), + ) + ) percs = 100 * (self._scale - self._last_scale) / self._last_scale return move, angle, percs @property def _nearest_transformed_high_res_mri_idx_hsp(self): return self._nearest_calc.query( - apply_trans(self._head_mri_t, self._filtered_extra_points))[1] + apply_trans(self._head_mri_t, self._filtered_extra_points) + )[1] @property def _has_hsp_data(self): - return (self._has_mri_data and - len(self._nearest_transformed_high_res_mri_idx_hsp) > 0) + return ( + self._has_mri_data + and len(self._nearest_transformed_high_res_mri_idx_hsp) > 0 + ) @property def _has_hpi_data(self): - return (self._has_mri_data and - len(self._nearest_transformed_high_res_mri_idx_hpi) > 0) + return ( + self._has_mri_data + and len(self._nearest_transformed_high_res_mri_idx_hpi) > 0 + ) @property def _has_eeg_data(self): - return (self._has_mri_data and - len(self._nearest_transformed_high_res_mri_idx_eeg) > 0) + return ( + self._has_mri_data + and len(self._nearest_transformed_high_res_mri_idx_eeg) > 0 + ) @property def _has_lpa_data(self): - mri_point = self.fiducials.dig[_map_fid_name_to_idx('lpa')] - assert mri_point['ident'] == FIFF.FIFFV_POINT_LPA - has_mri_data = np.any(mri_point['r']) - has_head_data = np.any(self._dig_dict['lpa']) + mri_point = self.fiducials.dig[_map_fid_name_to_idx("lpa")] + assert mri_point["ident"] == FIFF.FIFFV_POINT_LPA + has_mri_data = np.any(mri_point["r"]) + has_head_data = np.any(self._dig_dict["lpa"]) return has_mri_data and has_head_data @property def _has_nasion_data(self): - mri_point = self.fiducials.dig[_map_fid_name_to_idx('nasion')] - assert mri_point['ident'] == FIFF.FIFFV_POINT_NASION - has_mri_data = np.any(mri_point['r']) - has_head_data = np.any(self._dig_dict['nasion']) + mri_point = self.fiducials.dig[_map_fid_name_to_idx("nasion")] + assert mri_point["ident"] == FIFF.FIFFV_POINT_NASION + has_mri_data = np.any(mri_point["r"]) + has_head_data = np.any(self._dig_dict["nasion"]) return has_mri_data and has_head_data @property def _has_rpa_data(self): - mri_point = self.fiducials.dig[_map_fid_name_to_idx('rpa')] - assert mri_point['ident'] == FIFF.FIFFV_POINT_RPA - has_mri_data = np.any(mri_point['r']) - has_head_data = np.any(self._dig_dict['rpa']) + mri_point = self.fiducials.dig[_map_fid_name_to_idx("rpa")] + assert mri_point["ident"] == FIFF.FIFFV_POINT_RPA + has_mri_data = 
np.any(mri_point["r"]) + has_head_data = np.any(self._dig_dict["rpa"]) return has_mri_data and has_head_data @property def _processed_high_res_mri_points(self): - return self._get_processed_mri_points('high') + return self._get_processed_mri_points("high") @property def _processed_low_res_mri_points(self): - return self._get_processed_mri_points('low') + return self._get_processed_mri_points("low") def _get_processed_mri_points(self, res): - bem = self._bem_low_res if res == 'low' else self._bem_high_res - points = bem['rr'].copy() + bem = self._bem_low_res if res == "low" else self._bem_high_res + points = bem["rr"].copy() if self._grow_hair: - assert len(bem['nn']) # should be guaranteed by _read_surface - scaled_hair_dist = (1e-3 * self._grow_hair / - np.array(self._scale)) + assert len(bem["nn"]) # should be guaranteed by _read_surface + scaled_hair_dist = 1e-3 * self._grow_hair / np.array(self._scale) hair = points[:, 2] > points[:, 1] - points[hair] += bem['nn'][hair] * scaled_hair_dist + points[hair] += bem["nn"][hair] * scaled_hair_dist return points @property @@ -1712,20 +1893,24 @@ def _has_mri_data(self): @property def _has_dig_data(self): - return (self._has_mri_data and - len(self._nearest_transformed_high_res_mri_idx_hsp) > 0) + return ( + self._has_mri_data + and len(self._nearest_transformed_high_res_mri_idx_hsp) > 0 + ) @property def _orig_hsp_point_distance(self): mri_points = self._transformed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_orig_hsp] + self._nearest_transformed_high_res_mri_idx_orig_hsp + ] hsp_points = self._transformed_orig_dig_extra return np.linalg.norm(mri_points - hsp_points, axis=-1) def _log_dig_mri_distance(self, prefix): errs_nearest = self.compute_dig_mri_distances() - logger.info(f'{prefix} median distance: ' - f'{np.median(errs_nearest * 1000):6.2f} mm') + logger.info( + f"{prefix} median distance: " f"{np.median(errs_nearest * 1000):6.2f} mm" + ) @property def scale(self): @@ -1739,8 +1924,9 @@ def scale(self): return self._scale.copy() @verbose - def fit_fiducials(self, lpa_weight=1., nasion_weight=10., rpa_weight=1., - verbose=None): + def fit_fiducials( + self, lpa_weight=1.0, nasion_weight=10.0, rpa_weight=1.0, verbose=None + ): """Find rotation and translation to fit all 3 fiducials. Parameters @@ -1758,34 +1944,41 @@ def fit_fiducials(self, lpa_weight=1., nasion_weight=10., rpa_weight=1., self : Coregistration The modified Coregistration object. 
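
        Examples
        --------
        A minimal sketch with illustrative names; ``info`` is assumed to
        contain digitized LPA, nasion, and RPA points::

            coreg = Coregistration(info, "sample", subjects_dir)
            coreg.fit_fiducials(nasion_weight=10.0)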
""" - logger.info('Aligning using fiducials') - self._log_dig_mri_distance('Start') + logger.info("Aligning using fiducials") + self._log_dig_mri_distance("Start") n_scale_params = self._n_scale_params if n_scale_params == 3: # enforce 1 even for 3-axis here (3 points is not enough) - logger.info("Enforcing 1 scaling parameter for fit " - "with fiducials.") + logger.info("Enforcing 1 scaling parameter for fit " "with fiducials.") n_scale_params = 1 self._lpa_weight = lpa_weight self._nasion_weight = nasion_weight self._rpa_weight = rpa_weight - head_pts = np.vstack((self._dig_dict['lpa'], - self._dig_dict['nasion'], - self._dig_dict['rpa'])) + head_pts = np.vstack( + (self._dig_dict["lpa"], self._dig_dict["nasion"], self._dig_dict["rpa"]) + ) mri_pts = np.vstack( - (self.fiducials.dig[0]['r'], # LPA - self.fiducials.dig[1]['r'], # Nasion - self.fiducials.dig[2]['r']) # RPA + ( + self.fiducials.dig[0]["r"], # LPA + self.fiducials.dig[1]["r"], # Nasion + self.fiducials.dig[2]["r"], + ) # RPA ) weights = [lpa_weight, nasion_weight, rpa_weight] if n_scale_params == 0: mri_pts *= self._scale # not done in fit_matched_points x0 = self._parameters - x0 = x0[:6 + n_scale_params] - est = fit_matched_points(mri_pts, head_pts, x0=x0, out='params', - scale=n_scale_params, weights=weights) + x0 = x0[: 6 + n_scale_params] + est = fit_matched_points( + mri_pts, + head_pts, + x0=x0, + out="params", + scale=n_scale_params, + weights=weights, + ) if n_scale_params == 0: self._update_params(rot=est[:3], tra=est[3:6]) else: @@ -1793,7 +1986,7 @@ def fit_fiducials(self, lpa_weight=1., nasion_weight=10., rpa_weight=1., est = np.concatenate([est, [est[-1]] * 2]) assert est.size == 9 self._update_params(rot=est[:3], tra=est[3:6], sca=est[6:9]) - self._log_dig_mri_distance('End ') + self._log_dig_mri_distance("End ") return self def _setup_icp(self, n_scale_params): @@ -1802,34 +1995,47 @@ def _setup_icp(self, n_scale_params): weights = list() if self._has_dig_data and self._hsp_weight > 0: # should be true head_pts.append(self._filtered_extra_points) - mri_pts.append(self._processed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_hsp]) + mri_pts.append( + self._processed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hsp + ] + ) weights.append(np.full(len(head_pts[-1]), self._hsp_weight)) - for key in ('lpa', 'nasion', 'rpa'): - if getattr(self, f'_has_{key}_data'): + for key in ("lpa", "nasion", "rpa"): + if getattr(self, f"_has_{key}_data"): head_pts.append(self._dig_dict[key]) - if self._icp_fid_match == 'matched': + if self._icp_fid_match == "matched": idx = _map_fid_name_to_idx(name=key) - p = self.fiducials.dig[idx]['r'].reshape(1, -1) + p = self.fiducials.dig[idx]["r"].reshape(1, -1) mri_pts.append(p) else: - assert self._icp_fid_match == 'nearest' - mri_pts.append(self._processed_high_res_mri_points[ - getattr( - self, - '_nearest_transformed_high_res_mri_idx_%s' - % (key,))]) - weights.append(np.full(len(mri_pts[-1]), - getattr(self, '_%s_weight' % key))) + assert self._icp_fid_match == "nearest" + mri_pts.append( + self._processed_high_res_mri_points[ + getattr( + self, + "_nearest_transformed_high_res_mri_idx_%s" % (key,), + ) + ] + ) + weights.append( + np.full(len(mri_pts[-1]), getattr(self, "_%s_weight" % key)) + ) if self._has_eeg_data and self._eeg_weight > 0: - head_pts.append(self._dig_dict['dig_ch_pos_location']) - mri_pts.append(self._processed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_eeg]) + 
head_pts.append(self._dig_dict["dig_ch_pos_location"]) + mri_pts.append( + self._processed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_eeg + ] + ) weights.append(np.full(len(mri_pts[-1]), self._eeg_weight)) if self._has_hpi_data and self._hpi_weight > 0: - head_pts.append(self._dig_dict['hpi']) - mri_pts.append(self._processed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_hpi]) + head_pts.append(self._dig_dict["hpi"]) + mri_pts.append( + self._processed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hpi + ] + ) weights.append(np.full(len(mri_pts[-1]), self._hpi_weight)) head_pts = np.concatenate(head_pts) mri_pts = np.concatenate(mri_pts) @@ -1853,14 +2059,23 @@ def set_fid_match(self, match): self : Coregistration The modified Coregistration object. """ - _check_option('match', match, self._icp_fid_matches) + _check_option("match", match, self._icp_fid_matches) self._icp_fid_match = match return self @verbose - def fit_icp(self, n_iterations=20, lpa_weight=1., nasion_weight=10., - rpa_weight=1., hsp_weight=1., eeg_weight=1., hpi_weight=1., - callback=None, verbose=None): + def fit_icp( + self, + n_iterations=20, + lpa_weight=1.0, + nasion_weight=10.0, + rpa_weight=1.0, + hsp_weight=1.0, + eeg_weight=1.0, + hpi_weight=1.0, + callback=None, + verbose=None, + ): """Find MRI scaling, translation, and rotation to match HSP. Parameters @@ -1890,8 +2105,8 @@ def fit_icp(self, n_iterations=20, lpa_weight=1., nasion_weight=10., self : Coregistration The modified Coregistration object. """ - logger.info('Aligning using ICP') - self._log_dig_mri_distance('Start ') + logger.info("Aligning using ICP") + self._log_dig_mri_distance("Start ") n_scale_params = self._n_scale_params self._lpa_weight = lpa_weight self._nasion_weight = nasion_weight @@ -1902,13 +2117,19 @@ def fit_icp(self, n_iterations=20, lpa_weight=1., nasion_weight=10., # Initial guess (current state) est = self._parameters - est = est[:[6, 7, None, 9][n_scale_params]] + est = est[: [6, 7, None, 9][n_scale_params]] # Do the fits, assigning and evaluating at each step for iteration in range(n_iterations): head_pts, mri_pts, weights = self._setup_icp(n_scale_params) - est = fit_matched_points(mri_pts, head_pts, scale=n_scale_params, - x0=est, out='params', weights=weights) + est = fit_matched_points( + mri_pts, + head_pts, + scale=n_scale_params, + x0=est, + out="params", + weights=weights, + ) if n_scale_params == 0: self._update_params(rot=est[:3], tra=est[3:6]) elif n_scale_params == 1: @@ -1917,20 +2138,23 @@ def fit_icp(self, n_iterations=20, lpa_weight=1., nasion_weight=10., else: self._update_params(rot=est[:3], tra=est[3:6], sca=est[6:9]) angle, move, scale = self._changes - self._log_dig_mri_distance(f' ICP {iteration + 1:2d} ') + self._log_dig_mri_distance(f" ICP {iteration + 1:2d} ") if callback is not None: callback(iteration, n_iterations) - if angle <= self._icp_angle and move <= self._icp_distance and \ - all(scale <= self._icp_scale): + if ( + angle <= self._icp_angle + and move <= self._icp_distance + and all(scale <= self._icp_scale) + ): break - self._log_dig_mri_distance('End ') + self._log_dig_mri_distance("End ") return self @property def _n_scale_params(self): if self._scale_mode is None: n_scale_params = 0 - elif self._scale_mode == 'uniform': + elif self._scale_mode == "uniform": n_scale_params = 1 else: n_scale_params = 3 @@ -1957,8 +2181,12 @@ def omit_head_shape_points(self, distance): # find the new filter mask = self._orig_hsp_point_distance <= 
distance n_excluded = np.sum(~mask) - logger.info("Coregistration: Excluding %i head shape points with " - "distance >= %.3f m.", n_excluded, distance) + logger.info( + "Coregistration: Excluding %i head shape points with " + "distance >= %.3f m.", + n_excluded, + distance, + ) # set the filter self._extra_points_filter = mask self._update_params(force_update=True) @@ -1985,7 +2213,7 @@ def compute_dig_mri_distances(self): @property def trans(self): """The head->mri :class:`~mne.transforms.Transform`.""" - return Transform('head', 'mri', self._head_mri_t) + return Transform("head", "mri", self._head_mri_t) def reset(self): """Reset all the parameters affecting the coregistration. @@ -1995,7 +2223,7 @@ def reset(self): self : Coregistration The modified Coregistration object. """ - self._grow_hair = 0. + self._grow_hair = 0.0 self.set_rotation(self._default_parameters[:3]) self.set_translation(self._default_parameters[3:6]) self.set_scale(self._default_parameters[6:9]) @@ -2005,15 +2233,13 @@ def reset(self): def _get_fiducials_distance(self): distance = dict() - for key in ('lpa', 'nasion', 'rpa'): + for key in ("lpa", "nasion", "rpa"): idx = _map_fid_name_to_idx(name=key) - fid = self.fiducials.dig[idx]['r'].reshape(1, -1) + fid = self.fiducials.dig[idx]["r"].reshape(1, -1) transformed_mri = apply_trans(self._mri_trans, fid) - transformed_hsp = apply_trans( - self._head_mri_t, self._dig_dict[key]) - distance[key] = np.linalg.norm( - np.ravel(transformed_mri - transformed_hsp)) + transformed_hsp = apply_trans(self._head_mri_t, self._dig_dict[key]) + distance[key] = np.linalg.norm(np.ravel(transformed_mri - transformed_hsp)) return np.array(list(distance.values())) * 1e3 def _get_fiducials_distance_str(self): @@ -2024,18 +2250,27 @@ def _get_point_distance(self): mri_points = list() hsp_points = list() if self._hsp_weight > 0 and self._has_hsp_data: - mri_points.append(self._transformed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_hsp]) + mri_points.append( + self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hsp + ] + ) hsp_points.append(self._transformed_dig_extra) assert len(mri_points[-1]) == len(hsp_points[-1]) if self._eeg_weight > 0 and self._has_eeg_data: - mri_points.append(self._transformed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_eeg]) + mri_points.append( + self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_eeg + ] + ) hsp_points.append(self._transformed_dig_eeg) assert len(mri_points[-1]) == len(hsp_points[-1]) if self._hpi_weight > 0 and self._has_hpi_data: - mri_points.append(self._transformed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_hpi]) + mri_points.append( + self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hpi + ] + ) hsp_points.append(self._transformed_dig_hpi) assert len(mri_points[-1]) == len(hsp_points[-1]) if all(len(h) == 0 for h in hsp_points): @@ -2051,10 +2286,14 @@ def _get_point_distance_str(self): dists = 1e3 * point_distance av_dist = np.mean(dists) std_dist = np.std(dists) - kinds = [kind for kind, check in - (('HSP', self._hsp_weight > 0 and self._has_hsp_data), - ('EEG', self._eeg_weight > 0 and self._has_eeg_data), - ('HPI', self._hpi_weight > 0 and self._has_hpi_data)) - if check] - kinds = '+'.join(kinds) + kinds = [ + kind + for kind, check in ( + ("HSP", self._hsp_weight > 0 and self._has_hsp_data), + ("EEG", self._eeg_weight > 0 and self._has_eeg_data), + ("HPI", self._hpi_weight > 
0 and self._has_hpi_data), + ) + if check + ] + kinds = "+".join(kinds) return f"{len(dists)} {kinds}: {av_dist:.1f} ± {std_dist:.1f} mm" diff --git a/mne/cov.py b/mne/cov.py index 43c993c6c91..15fd043d022 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -10,55 +10,100 @@ import numpy as np -from .defaults import (_INTERPOLATION_DEFAULT, _EXTRAPOLATE_DEFAULT, - _BORDER_DEFAULT, DEFAULTS) +from .defaults import ( + _INTERPOLATION_DEFAULT, + _EXTRAPOLATE_DEFAULT, + _BORDER_DEFAULT, + DEFAULTS, +) from .io.write import start_and_end_file -from .io.proj import (make_projector, _proj_equal, activate_proj, - _check_projs, _needs_eeg_average_ref_proj, - _has_eeg_average_ref_proj, _read_proj, _write_proj) +from .io.proj import ( + make_projector, + _proj_equal, + activate_proj, + _check_projs, + _needs_eeg_average_ref_proj, + _has_eeg_average_ref_proj, + _read_proj, + _write_proj, +) from .io import fiff_open, RawArray -from .io.pick import (pick_types, pick_channels_cov, pick_channels, pick_info, - _picks_by_type, _pick_data_channels, _picks_to_idx, - _DATA_CH_TYPES_SPLIT) +from .io.pick import ( + pick_types, + pick_channels_cov, + pick_channels, + pick_info, + _picks_by_type, + _pick_data_channels, + _picks_to_idx, + _DATA_CH_TYPES_SPLIT, +) from .io.constants import FIFF from .io.meas_info import _read_bad_channels, create_info, _write_bad_channels from .io.tag import find_tag from .io.tree import dir_tree_find -from .io.write import (start_block, end_block, write_int, write_double, - write_float_matrix, write_string, _safe_name_list, - write_name_list_sanitized) +from .io.write import ( + start_block, + end_block, + write_int, + write_double, + write_float_matrix, + write_string, + _safe_name_list, + write_name_list_sanitized, +) from .defaults import _handle_default from .epochs import Epochs from .event import make_fixed_length_events from .evoked import EvokedArray from .rank import compute_rank -from .utils import (check_fname, logger, verbose, check_version, _time_mask, - warn, copy_function_doc_to_method_doc, _pl, - _undo_scaling_cov, _scaled_array, _validate_type, - _check_option, eigh, fill_doc, _on_missing, - _check_on_missing, _check_fname, _verbose_safe_false) +from .utils import ( + check_fname, + logger, + verbose, + check_version, + _time_mask, + warn, + copy_function_doc_to_method_doc, + _pl, + _undo_scaling_cov, + _scaled_array, + _validate_type, + _check_option, + eigh, + fill_doc, + _on_missing, + _check_on_missing, + _check_fname, + _verbose_safe_false, +) from . import viz -from .fixes import (BaseEstimator, EmpiricalCovariance, _logdet, - empirical_covariance, log_likelihood) +from .fixes import ( + BaseEstimator, + EmpiricalCovariance, + _logdet, + empirical_covariance, + log_likelihood, +) def _check_covs_algebra(cov1, cov2): if cov1.ch_names != cov2.ch_names: - raise ValueError('Both Covariance do not have the same list of ' - 'channels.') - projs1 = [str(c) for c in cov1['projs']] - projs2 = [str(c) for c in cov1['projs']] + raise ValueError("Both Covariance objects must have the same list of channels.") + projs1 = [str(c) for c in cov1["projs"]] + projs2 = [str(c) for c in cov2["projs"]] # compare the two inputs, not cov1 twice if projs1 != projs2: - raise ValueError('Both Covariance do not have the same list of ' - 'SSP projections.') + raise ValueError( + "Both Covariance objects must have the same list of SSP projections."
+ ) def _get_tslice(epochs, tmin, tmax): """Get the slice.""" - mask = _time_mask(epochs.times, tmin, tmax, sfreq=epochs.info['sfreq']) + mask = _time_mask(epochs.times, tmin, tmax, sfreq=epochs.info["sfreq"]) tstart = np.where(mask)[0][0] if tmin is not None else None tend = np.where(mask)[0][-1] + 1 if tmax is not None else None tslice = slice(tstart, tend, None) @@ -116,33 +161,54 @@ class Covariance(dict): """ @verbose - def __init__(self, data, names, bads, projs, nfree, eig=None, eigvec=None, - method=None, loglik=None, *, verbose=None): + def __init__( + self, + data, + names, + bads, + projs, + nfree, + eig=None, + eigvec=None, + method=None, + loglik=None, + *, + verbose=None, + ): """Init of covariance.""" - diag = (data.ndim == 1) + diag = data.ndim == 1 projs = _check_projs(projs) - self.update(data=data, dim=len(data), names=names, bads=bads, - nfree=nfree, eig=eig, eigvec=eigvec, diag=diag, - projs=projs, kind=FIFF.FIFFV_MNE_NOISE_COV) + self.update( + data=data, + dim=len(data), + names=names, + bads=bads, + nfree=nfree, + eig=eig, + eigvec=eigvec, + diag=diag, + projs=projs, + kind=FIFF.FIFFV_MNE_NOISE_COV, + ) if method is not None: - self['method'] = method + self["method"] = method if loglik is not None: - self['loglik'] = loglik + self["loglik"] = loglik @property def data(self): """Numpy array of Noise covariance matrix.""" - return self['data'] + return self["data"] @property def ch_names(self): """Channel names.""" - return self['names'] + return self["names"] @property def nfree(self): """Number of degrees of freedom.""" - return self['nfree'] + return self["nfree"] @verbose def save(self, fname, *, overwrite=False, verbose=None): @@ -157,8 +223,9 @@ def save(self, fname, *, overwrite=False, verbose=None): .. versionadded:: 1.0 %(verbose)s """ - check_fname(fname, 'covariance', ('-cov.fif', '-cov.fif.gz', - '_cov.fif', '_cov.fif.gz')) + check_fname( + fname, "covariance", ("-cov.fif", "-cov.fif.gz", "_cov.fif", "_cov.fif.gz") + ) fname = _check_fname(fname=fname, overwrite=overwrite) with start_and_end_file(fname) as fid: _write_cov(fid, self) @@ -188,35 +255,35 @@ def as_diag(self): This function operates in place. 
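
        Examples
        --------
        A minimal sketch (``cov`` is assumed to be an existing Covariance)::

            cov.as_diag()  # modifies in place and also returns the instance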
""" - if self['diag']: + if self["diag"]: return self - self['diag'] = True - self['data'] = np.diag(self['data']) - self['eig'] = None - self['eigvec'] = None + self["diag"] = True + self["data"] = np.diag(self["data"]) + self["eig"] = None + self["eigvec"] = None return self def _as_square(self): # This is a hack but it works because np.diag() behaves nicely - if self['diag']: - self['diag'] = False + if self["diag"]: + self["diag"] = False self.as_diag() - self['diag'] = False + self["diag"] = False return self def _get_square(self): - if self['diag'] != (self.data.ndim == 1): + if self["diag"] != (self.data.ndim == 1): raise RuntimeError( - 'Covariance attributes inconsistent, got data with ' - 'dimensionality %d but diag=%s' - % (self.data.ndim, self['diag'])) - return np.diag(self.data) if self['diag'] else self.data.copy() + "Covariance attributes inconsistent, got data with " + "dimensionality %d but diag=%s" % (self.data.ndim, self["diag"]) + ) + return np.diag(self.data) if self["diag"] else self.data.copy() def __repr__(self): # noqa: D105 if self.data.ndim == 2: - s = 'size : %s x %s' % self.data.shape + s = "size : %s x %s" % self.data.shape else: # ndim == 1 - s = 'diagonal : %s' % self.data.size + s = "diagonal : %s" % self.data.size s += ", n_samples : %s" % self.nfree s += ", data : %s" % self.data return "" % s @@ -225,43 +292,74 @@ def __add__(self, cov): """Add Covariance taking into account number of degrees of freedom.""" _check_covs_algebra(self, cov) this_cov = cov.copy() - this_cov['data'] = (((this_cov['data'] * this_cov['nfree']) + - (self['data'] * self['nfree'])) / - (self['nfree'] + this_cov['nfree'])) - this_cov['nfree'] += self['nfree'] + this_cov["data"] = ( + (this_cov["data"] * this_cov["nfree"]) + (self["data"] * self["nfree"]) + ) / (self["nfree"] + this_cov["nfree"]) + this_cov["nfree"] += self["nfree"] - this_cov['bads'] = list(set(this_cov['bads']).union(self['bads'])) + this_cov["bads"] = list(set(this_cov["bads"]).union(self["bads"])) return this_cov def __iadd__(self, cov): """Add Covariance taking into account number of degrees of freedom.""" _check_covs_algebra(self, cov) - self['data'][:] = (((self['data'] * self['nfree']) + - (cov['data'] * cov['nfree'])) / - (self['nfree'] + cov['nfree'])) - self['nfree'] += cov['nfree'] + self["data"][:] = ( + (self["data"] * self["nfree"]) + (cov["data"] * cov["nfree"]) + ) / (self["nfree"] + cov["nfree"]) + self["nfree"] += cov["nfree"] - self['bads'] = list(set(self['bads']).union(cov['bads'])) + self["bads"] = list(set(self["bads"]).union(cov["bads"])) return self @verbose @copy_function_doc_to_method_doc(viz.misc.plot_cov) - def plot(self, info, exclude=[], colorbar=True, proj=False, show_svd=True, - show=True, verbose=None): - return viz.misc.plot_cov(self, info, exclude, colorbar, proj, show_svd, - show, verbose) + def plot( + self, + info, + exclude=[], + colorbar=True, + proj=False, + show_svd=True, + show=True, + verbose=None, + ): + return viz.misc.plot_cov( + self, info, exclude, colorbar, proj, show_svd, show, verbose + ) @verbose def plot_topomap( - self, info, ch_type=None, *, scalings=None, proj=False, - noise_cov=None, sensors=True, show_names=False, mask=None, - mask_params=None, contours=6, outlines='head', sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, - cbar_fmt='%3.1f', units=None, axes=None, show=True, verbose=None): + self, + info, + ch_type=None, 
+ *, + scalings=None, + proj=False, + noise_cov=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + show=True, + verbose=None, + ): """Plot a topomap of the covariance diagonal. Parameters @@ -319,22 +417,40 @@ def plot_topomap( # entries is the same as multiplying twice evoked = whiten_evoked(whiten_evoked(evoked, noise_cov), noise_cov) if units is None: - units = 'AU' + units = "AU" if scalings is None: - scalings = 1. + scalings = 1.0 if units is None: - units = {k: f'({v})²' for k, v in DEFAULTS['units'].items()} + units = {k: f"({v})²" for k, v in DEFAULTS["units"].items()} if scalings is None: - scalings = {k: v * v for k, v in DEFAULTS['scalings'].items()} + scalings = {k: v * v for k, v in DEFAULTS["scalings"].items()} return evoked.plot_topomap( - times=[0], ch_type=ch_type, vlim=vlim, cmap=cmap, - sensors=sensors, cnorm=cnorm, colorbar=colorbar, scalings=scalings, - units=units, res=res, size=size, cbar_fmt=cbar_fmt, - proj=proj, show=show, show_names=show_names, - mask=mask, mask_params=mask_params, outlines=outlines, - contours=contours, image_interp=image_interp, axes=axes, - extrapolate=extrapolate, sphere=sphere, border=border, - time_format='') + times=[0], + ch_type=ch_type, + vlim=vlim, + cmap=cmap, + sensors=sensors, + cnorm=cnorm, + colorbar=colorbar, + scalings=scalings, + units=units, + res=res, + size=size, + cbar_fmt=cbar_fmt, + proj=proj, + show=show, + show_names=show_names, + mask=mask, + mask_params=mask_params, + outlines=outlines, + contours=contours, + image_interp=image_interp, + axes=axes, + extrapolate=extrapolate, + sphere=sphere, + border=border, + time_format="", + ) @verbose def pick_channels(self, ch_names, ordered=None, *, verbose=None): @@ -358,13 +474,15 @@ def pick_channels(self, ch_names, ordered=None, *, verbose=None): .. versionadded:: 0.20.0 """ - return pick_channels_cov(self, ch_names, exclude=[], ordered=ordered, - copy=False) + return pick_channels_cov( + self, ch_names, exclude=[], ordered=ordered, copy=False + ) ############################################################################### # IO + @verbose def read_cov(fname, verbose=None): """Read a noise covariance from a FIF file. @@ -385,18 +503,21 @@ def read_cov(fname, verbose=None): -------- write_cov, compute_covariance, compute_raw_covariance """ - check_fname(fname, 'covariance', ('-cov.fif', '-cov.fif.gz', - '_cov.fif', '_cov.fif.gz')) - fname = _check_fname(fname=fname, must_exist=True, overwrite='read') + check_fname( + fname, "covariance", ("-cov.fif", "-cov.fif.gz", "_cov.fif", "_cov.fif.gz") + ) + fname = _check_fname(fname=fname, must_exist=True, overwrite="read") f, tree, _ = fiff_open(fname) with f as fid: - return Covariance(**_read_cov(fid, tree, FIFF.FIFFV_MNE_NOISE_COV, - limited=True)) + return Covariance( + **_read_cov(fid, tree, FIFF.FIFFV_MNE_NOISE_COV, limited=True) + ) ############################################################################### # Estimate from data + @verbose def make_ad_hoc_cov(info, std=None, *, verbose=None): """Create an ad hoc noise covariance. @@ -423,33 +544,51 @@ def make_ad_hoc_cov(info, std=None, *, verbose=None): .. 
versionadded:: 0.9.0 """ picks = pick_types(info, meg=True, eeg=True, exclude=()) - std = _handle_default('noise_std', std) + std = _handle_default("noise_std", std) data = np.zeros(len(picks)) - for meg, eeg, val in zip(('grad', 'mag', False), (False, False, True), - (std['grad'], std['mag'], std['eeg'])): + for meg, eeg, val in zip( + ("grad", "mag", False), + (False, False, True), + (std["grad"], std["mag"], std["eeg"]), + ): these_picks = pick_types(info, meg=meg, eeg=eeg) data[np.searchsorted(picks, these_picks)] = val * val - ch_names = [info['ch_names'][pick] for pick in picks] - return Covariance(data, ch_names, info['bads'], info['projs'], nfree=0) + ch_names = [info["ch_names"][pick] for pick in picks] + return Covariance(data, ch_names, info["bads"], info["projs"], nfree=0) def _check_n_samples(n_samples, n_chan): """Check to see if there are enough samples for reliable cov calc.""" n_samples_min = 10 * (n_chan + 1) // 2 if n_samples <= 0: - raise ValueError('No samples found to compute the covariance matrix') + raise ValueError("No samples found to compute the covariance matrix") if n_samples < n_samples_min: - warn('Too few samples (required : %d got : %d), covariance ' - 'estimate may be unreliable' % (n_samples_min, n_samples)) + warn( + "Too few samples (required : %d got : %d), covariance " + "estimate may be unreliable" % (n_samples_min, n_samples) + ) @verbose -def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None, - flat=None, picks=None, method='empirical', - method_params=None, cv=3, scalings=None, - n_jobs=None, return_estimators=False, - reject_by_annotation=True, rank=None, verbose=None): +def compute_raw_covariance( + raw, + tmin=0, + tmax=None, + tstep=0.2, + reject=None, + flat=None, + picks=None, + method="empirical", + method_params=None, + cv=3, + scalings=None, + n_jobs=None, + return_estimators=False, + reject_by_annotation=True, + rank=None, + verbose=None, +): """Estimate noise covariance matrix from a continuous segment of raw data. It is typically useful to estimate a noise covariance from empty room @@ -557,31 +696,40 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None, baseline correction) subtract the mean across time *for each epoch* (instead of across epochs) for each channel. """ - tmin = 0. if tmin is None else float(tmin) - dt = 1. / raw.info['sfreq'] + tmin = 0.0 if tmin is None else float(tmin) + dt = 1.0 / raw.info["sfreq"] tmax = raw.times[-1] + dt if tmax is None else float(tmax) tstep = tmax - tmin if tstep is None else float(tstep) tstep_m1 = tstep - dt # inclusive! events = make_fixed_length_events(raw, 1, tmin, tmax, tstep) - logger.info('Using up to %s segment%s' % (len(events), _pl(events))) + logger.info("Using up to %s segment%s" % (len(events), _pl(events))) # don't exclude any bad channels, inverses expect all channels present if picks is None: # Need to include all channels e.g. 
if eog rejection is to be used - picks = np.arange(raw.info['nchan']) - pick_mask = np.in1d( - picks, _pick_data_channels(raw.info, with_ref_meg=False)) + picks = np.arange(raw.info["nchan"]) + pick_mask = np.in1d(picks, _pick_data_channels(raw.info, with_ref_meg=False)) else: pick_mask = slice(None) picks = _picks_to_idx(raw.info, picks) - epochs = Epochs(raw, events, 1, 0, tstep_m1, baseline=None, - picks=picks, reject=reject, flat=flat, - verbose=_verbose_safe_false(), - preload=False, proj=False, - reject_by_annotation=reject_by_annotation) + epochs = Epochs( + raw, + events, + 1, + 0, + tstep_m1, + baseline=None, + picks=picks, + reject=reject, + flat=flat, + verbose=_verbose_safe_false(), + preload=False, + proj=False, + reject_by_annotation=reject_by_annotation, + ) if method is None: - method = 'empirical' - if isinstance(method, str) and method == 'empirical': + method = "empirical" + if isinstance(method, str) and method == "empirical": # potentially *much* more memory efficient to do it the iterative way picks = picks[pick_mask] data = 0 @@ -595,13 +743,12 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None, n_samples += raw_segment.shape[1] _check_n_samples(n_samples, len(picks)) data -= mu[:, None] * (mu[None, :] / n_samples) - data /= (n_samples - 1.0) + data /= n_samples - 1.0 logger.info("Number of samples used : %d" % n_samples) - logger.info('[done]') - ch_names = [raw.info['ch_names'][k] for k in picks] - bads = [b for b in raw.info['bads'] if b in ch_names] - return Covariance(data, ch_names, bads, raw.info['projs'], - nfree=n_samples - 1) + logger.info("[done]") + ch_names = [raw.info["ch_names"][k] for k in picks] + bads = [b for b in raw.info["bads"] if b in ch_names] + return Covariance(data, ch_names, bads, raw.info["projs"], nfree=n_samples - 1) del picks, pick_mask # This makes it equivalent to what we used to do (and do above for @@ -611,85 +758,130 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None, epochs._data -= ch_means[np.newaxis, :, np.newaxis] # fake this value so there are no complaints from compute_covariance epochs.baseline = (None, None) - return compute_covariance(epochs, keep_sample_mean=True, method=method, - method_params=method_params, cv=cv, - scalings=scalings, n_jobs=n_jobs, - return_estimators=return_estimators, - rank=rank) - - -def _check_method_params(method, method_params, keep_sample_mean=True, - name='method', allow_auto=True, rank=None): + return compute_covariance( + epochs, + keep_sample_mean=True, + method=method, + method_params=method_params, + cv=cv, + scalings=scalings, + n_jobs=n_jobs, + return_estimators=return_estimators, + rank=rank, + ) + + +def _check_method_params( + method, + method_params, + keep_sample_mean=True, + name="method", + allow_auto=True, + rank=None, +): """Check that method and method_params are usable.""" - accepted_methods = ('auto', 'empirical', 'diagonal_fixed', 'ledoit_wolf', - 'oas', 'shrunk', 'pca', 'factor_analysis', 'shrinkage') + accepted_methods = ( + "auto", + "empirical", + "diagonal_fixed", + "ledoit_wolf", + "oas", + "shrunk", + "pca", + "factor_analysis", + "shrinkage", + ) _method_params = { - 'empirical': {'store_precision': False, 'assume_centered': True}, - 'diagonal_fixed': {'store_precision': False, 'assume_centered': True}, - 'ledoit_wolf': {'store_precision': False, 'assume_centered': True}, - 'oas': {'store_precision': False, 'assume_centered': True}, - 'shrinkage': {'shrinkage': 0.1, 'store_precision': False, - 'assume_centered': 
True}, - 'shrunk': {'shrinkage': np.logspace(-4, 0, 30), - 'store_precision': False, 'assume_centered': True}, - 'pca': {'iter_n_components': None}, - 'factor_analysis': {'iter_n_components': None} + "empirical": {"store_precision": False, "assume_centered": True}, + "diagonal_fixed": {"store_precision": False, "assume_centered": True}, + "ledoit_wolf": {"store_precision": False, "assume_centered": True}, + "oas": {"store_precision": False, "assume_centered": True}, + "shrinkage": { + "shrinkage": 0.1, + "store_precision": False, + "assume_centered": True, + }, + "shrunk": { + "shrinkage": np.logspace(-4, 0, 30), + "store_precision": False, + "assume_centered": True, + }, + "pca": {"iter_n_components": None}, + "factor_analysis": {"iter_n_components": None}, } for ch_type in _DATA_CH_TYPES_SPLIT: - _method_params['diagonal_fixed'][ch_type] = 0.1 + _method_params["diagonal_fixed"][ch_type] = 0.1 if isinstance(method_params, dict): for key, values in method_params.items(): if key not in _method_params: - raise ValueError('key (%s) must be "%s"' % - (key, '" or "'.join(_method_params))) + raise ValueError( + 'key (%s) must be "%s"' % (key, '" or "'.join(_method_params)) + ) _method_params[key].update(method_params[key]) - shrinkage = method_params.get('shrinkage', {}).get('shrinkage', 0.1) + shrinkage = method_params.get("shrinkage", {}).get("shrinkage", 0.1) if not 0 <= shrinkage <= 1: - raise ValueError('shrinkage must be between 0 and 1, got %s' - % (shrinkage,)) + raise ValueError("shrinkage must be between 0 and 1, got %s" % (shrinkage,)) was_auto = False if method is None: - method = ['empirical'] - elif method == 'auto' and allow_auto: + method = ["empirical"] + elif method == "auto" and allow_auto: was_auto = True - method = ['shrunk', 'diagonal_fixed', 'empirical', 'factor_analysis'] + method = ["shrunk", "diagonal_fixed", "empirical", "factor_analysis"] if not isinstance(method, (list, tuple)): method = [method] if not all(k in accepted_methods for k in method): raise ValueError( - 'Invalid {name} ({method}). Accepted values (individually or ' + "Invalid {name} ({method}). 
Accepted values (individually or " 'in a list) are any of "{accepted_methods}" or None.'.format( - name=name, method=method, accepted_methods=accepted_methods)) - if not (isinstance(rank, str) and rank == 'full'): + name=name, method=method, accepted_methods=accepted_methods + ) + ) + if not (isinstance(rank, str) and rank == "full"): if was_auto: - method.pop(method.index('factor_analysis')) + method.pop(method.index("factor_analysis")) for method_ in method: - if method_ in ('pca', 'factor_analysis'): - raise ValueError('%s can so far only be used with rank="full",' - ' got rank=%r' % (method_, rank)) + if method_ in ("pca", "factor_analysis"): + raise ValueError( + '%s can so far only be used with rank="full",' + " got rank=%r" % (method_, rank) + ) if not keep_sample_mean: - if len(method) != 1 or 'empirical' not in method: - raise ValueError('`keep_sample_mean=False` is only supported' - 'with %s="empirical"' % (name,)) + if len(method) != 1 or "empirical" not in method: + raise ValueError( + "`keep_sample_mean=False` is only supported" + 'with %s="empirical"' % (name,) + ) for p, v in _method_params.items(): - if v.get('assume_centered', None) is False: - raise ValueError('`assume_centered` must be True' - ' if `keep_sample_mean` is False') + if v.get("assume_centered", None) is False: + raise ValueError( + "`assume_centered` must be True" " if `keep_sample_mean` is False" + ) return method, _method_params @verbose -def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None, - projs=None, method='empirical', method_params=None, - cv=3, scalings=None, n_jobs=None, - return_estimators=False, on_mismatch='raise', - rank=None, verbose=None): +def compute_covariance( + epochs, + keep_sample_mean=True, + tmin=None, + tmax=None, + projs=None, + method="empirical", + method_params=None, + cv=3, + scalings=None, + n_jobs=None, + return_estimators=False, + on_mismatch="raise", + rank=None, + verbose=None, +): """Estimate noise covariance matrix from epochs. 
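
    A minimal usage sketch (editor's illustration, not part of this diff;
    ``epochs`` is assumed to be a baseline-corrected :class:`mne.Epochs`
    instance)::

        import mne

        cov = mne.compute_covariance(epochs, tmax=0.0, method="shrunk")
        epochs.average().plot_white(cov)  # visual check of the whitening
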
The noise covariance is typically estimated on pre-stimulus periods @@ -859,7 +1051,8 @@ def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None, # scale to natural unit for best stability with MEG/EEG scalings = _check_scalings_user(scalings) method, _method_params = _check_method_params( - method, method_params, keep_sample_mean, rank=rank) + method, method_params, keep_sample_mean, rank=rank + ) del method_params # for multi condition support epochs is required to refer to a list of @@ -878,43 +1071,49 @@ def _unpack_epochs(epochs): epochs = sum([_unpack_epochs(epoch) for epoch in epochs], []) # check for baseline correction - if any(epochs_t.baseline is None and epochs_t.info['highpass'] < 0.5 and - keep_sample_mean for epochs_t in epochs): - warn('Epochs are not baseline corrected, covariance ' - 'matrix may be inaccurate') - - orig = epochs[0].info['dev_head_t'] - _check_on_missing(on_mismatch, 'on_mismatch') + if any( + epochs_t.baseline is None + and epochs_t.info["highpass"] < 0.5 + and keep_sample_mean + for epochs_t in epochs + ): + warn( + "Epochs are not baseline corrected, covariance " "matrix may be inaccurate" + ) + + orig = epochs[0].info["dev_head_t"] + _check_on_missing(on_mismatch, "on_mismatch") for ei, epoch in enumerate(epochs): epoch.info._check_consistency() - if (orig is None) != (epoch.info['dev_head_t'] is None) or \ - (orig is not None and not - np.allclose(orig['trans'], - epoch.info['dev_head_t']['trans'])): - msg = ('MEG<->Head transform mismatch between epochs[0]:\n%s\n\n' - 'and epochs[%s]:\n%s' - % (orig, ei, epoch.info['dev_head_t'])) - _on_missing(on_mismatch, msg, 'on_mismatch') - - bads = epochs[0].info['bads'] + if (orig is None) != (epoch.info["dev_head_t"] is None) or ( + orig is not None + and not np.allclose(orig["trans"], epoch.info["dev_head_t"]["trans"]) + ): + msg = ( + "MEG<->Head transform mismatch between epochs[0]:\n%s\n\n" + "and epochs[%s]:\n%s" % (orig, ei, epoch.info["dev_head_t"]) + ) + _on_missing(on_mismatch, msg, "on_mismatch") + + bads = epochs[0].info["bads"] if projs is None: - projs = epochs[0].info['projs'] + projs = epochs[0].info["projs"] # make sure Epochs are compatible for epochs_t in epochs[1:]: if epochs_t.proj != epochs[0].proj: - raise ValueError('Epochs must agree on the use of projections') - for proj_a, proj_b in zip(epochs_t.info['projs'], projs): + raise ValueError("Epochs must agree on the use of projections") + for proj_a, proj_b in zip(epochs_t.info["projs"], projs): if not _proj_equal(proj_a, proj_b): - raise ValueError('Epochs must have same projectors') + raise ValueError("Epochs must have same projectors") projs = _check_projs(projs) ch_names = epochs[0].ch_names # make sure Epochs are compatible for epochs_t in epochs[1:]: - if epochs_t.info['bads'] != bads: - raise ValueError('Epochs must have same bad channels') + if epochs_t.info["bads"] != bads: + raise ValueError("Epochs must have same bad channels") if epochs_t.ch_names != ch_names: - raise ValueError('Epochs must have same channel names') + raise ValueError("Epochs must have same channel names") picks_list = _picks_by_type(epochs[0].info) picks_meeg = np.concatenate([b for _, b in picks_list]) picks_meeg = np.sort(picks_meeg) @@ -929,7 +1128,6 @@ def _unpack_epochs(epochs): n_epochs = np.zeros(n_epoch_types, dtype=np.int64) for ii, epochs_t in enumerate(epochs): - tslice = _get_tslice(epochs_t, tmin, tmax) for e in epochs_t: e = e[picks_meeg, tslice] @@ -940,8 +1138,10 @@ def _unpack_epochs(epochs): n_samples_epoch = n_samples 
// n_epochs norm_const = np.sum(n_samples_epoch * (n_epochs - 1)) - data_mean = [1.0 / n_epoch * np.dot(mean, mean.T) for n_epoch, mean - in zip(n_epochs, data_mean)] + data_mean = [ + 1.0 / n_epoch * np.dot(mean, mean.T) + for n_epoch, mean in zip(n_epochs, data_mean) + ] info = pick_info(info, picks_meeg) tslice = _get_tslice(epochs[0], tmin, tmax) @@ -960,14 +1160,22 @@ def _unpack_epochs(epochs): epochs = epochs.T # sklearn | C-order cov_data = _compute_covariance_auto( - epochs, method=method, method_params=_method_params, info=info, - cv=cv, n_jobs=n_jobs, stop_early=True, picks_list=picks_list, - scalings=scalings, rank=rank) + epochs, + method=method, + method_params=_method_params, + info=info, + cv=cv, + n_jobs=n_jobs, + stop_early=True, + picks_list=picks_list, + scalings=scalings, + rank=rank, + ) if keep_sample_mean is False: - cov = cov_data['empirical']['data'] + cov = cov_data["empirical"]["data"] # undo scaling - cov *= (n_samples_tot - 1) + cov *= n_samples_tot - 1 # ... apply pre-computed class-wise normalization for mean_cov in data_mean: cov -= mean_cov @@ -975,28 +1183,29 @@ def _unpack_epochs(epochs): covs = list() for this_method, data in cov_data.items(): - cov = Covariance(data.pop('data'), ch_names, info['bads'], projs, - nfree=n_samples_tot - 1) + cov = Covariance( + data.pop("data"), ch_names, info["bads"], projs, nfree=n_samples_tot - 1 + ) # add extra info cov.update(method=this_method, **data) covs.append(cov) - logger.info('Number of samples used : %d' % n_samples_tot) - covs.sort(key=lambda c: c['loglik'], reverse=True) + logger.info("Number of samples used : %d" % n_samples_tot) + covs.sort(key=lambda c: c["loglik"], reverse=True) if len(covs) > 1: - msg = ['log-likelihood on unseen data (descending order):'] + msg = ["log-likelihood on unseen data (descending order):"] for c in covs: - msg.append('%s: %0.3f' % (c['method'], c['loglik'])) - logger.info('\n '.join(msg)) + msg.append("%s: %0.3f" % (c["method"], c["loglik"])) + logger.info("\n ".join(msg)) if return_estimators: out = covs else: out = covs[0] - logger.info('selecting best estimator: {}'.format(out['method'])) + logger.info("selecting best estimator: {}".format(out["method"])) else: out = covs[0] - logger.info('[done]') + logger.info("[done]") return out @@ -1004,11 +1213,12 @@ def _unpack_epochs(epochs): def _check_scalings_user(scalings): if isinstance(scalings, dict): for k, v in scalings.items(): - _check_option('the keys in `scalings`', k, ['mag', 'grad', 'eeg']) + _check_option("the keys in `scalings`", k, ["mag", "grad", "eeg"]) elif scalings is not None and not isinstance(scalings, np.ndarray): - raise TypeError('scalings must be a dict, ndarray, or None, got %s' - % type(scalings)) - scalings = _handle_default('scalings', scalings) + raise TypeError( + "scalings must be a dict, ndarray, or None, got %s" % type(scalings) + ) + scalings = _handle_default("scalings", scalings) return scalings @@ -1021,33 +1231,49 @@ def _eigvec_subspace(eig, eigvec, mask): return eig, eigvec -def _compute_covariance_auto(data, method, info, method_params, cv, - scalings, n_jobs, stop_early, picks_list, rank): +def _compute_covariance_auto( + data, + method, + info, + method_params, + cv, + scalings, + n_jobs, + stop_early, + picks_list, + rank, +): """Compute covariance auto mode.""" # rescale to improve numerical stability orig_rank = rank rank = compute_rank( RawArray(data.T, info, copy=None, verbose=_verbose_safe_false()), - rank, scalings, info) + rank, + scalings, + info, + ) with 
_scaled_array(data.T, picks_list, scalings): C = np.dot(data.T, data) - _, eigvec, mask = _smart_eigh(C, info, rank, proj_subspace=True, - do_compute_rank=False) + _, eigvec, mask = _smart_eigh( + C, info, rank, proj_subspace=True, do_compute_rank=False + ) eigvec = eigvec[mask] data = np.dot(data, eigvec.T) used = np.where(mask)[0] - sub_picks_list = [(key, np.searchsorted(used, picks)) - for key, picks in picks_list] + sub_picks_list = [ + (key, np.searchsorted(used, picks)) for key, picks in picks_list + ] sub_info = pick_info(info, used) if len(used) != len(mask) else info - logger.info('Reducing data rank from %s -> %s' - % (len(mask), eigvec.shape[0])) + logger.info("Reducing data rank from %s -> %s" % (len(mask), eigvec.shape[0])) estimator_cov_info = list() - msg = 'Estimating covariance using %s' + msg = "Estimating covariance using %s" - ok_sklearn = check_version('sklearn') - if not ok_sklearn and (len(method) != 1 or method[0] != 'empirical'): - raise ValueError('scikit-learn is not installed, `method` must be ' - '`empirical`, got %s' % (method,)) + ok_sklearn = check_version("sklearn") + if not ok_sklearn and (len(method) != 1 or method[0] != "empirical"): + raise ValueError( + "scikit-learn is not installed, `method` must be " + "`empirical`, got %s" % (method,) + ) for method_ in method: data_ = data.copy() @@ -1056,20 +1282,21 @@ def _compute_covariance_auto(data, method, info, method_params, cv, mp = method_params[method_] _info = {} - if method_ == 'empirical': + if method_ == "empirical": est = EmpiricalCovariance(**mp) est.fit(data_) estimator_cov_info.append((est, est.covariance_, _info)) del est - elif method_ == 'diagonal_fixed': + elif method_ == "diagonal_fixed": est = _RegCovariance(info=sub_info, **mp) est.fit(data_) estimator_cov_info.append((est, est.covariance_, _info)) del est - elif method_ == 'ledoit_wolf': + elif method_ == "ledoit_wolf": from sklearn.covariance import LedoitWolf + shrinkages = [] lw = LedoitWolf(**mp) @@ -1081,8 +1308,9 @@ def _compute_covariance_auto(data, method, info, method_params, cv, estimator_cov_info.append((sc, sc.covariance_, _info)) del lw, sc - elif method_ == 'oas': + elif method_ == "oas": from sklearn.covariance import OAS + shrinkages = [] oas = OAS(**mp) @@ -1094,58 +1322,65 @@ def _compute_covariance_auto(data, method, info, method_params, cv, estimator_cov_info.append((sc, sc.covariance_, _info)) del oas, sc - elif method_ == 'shrinkage': + elif method_ == "shrinkage": sc = _ShrunkCovariance(**mp) sc.fit(data_) estimator_cov_info.append((sc, sc.covariance_, _info)) del sc - elif method_ == 'shrunk': + elif method_ == "shrunk": from sklearn.model_selection import GridSearchCV from sklearn.covariance import ShrunkCovariance - shrinkage = mp.pop('shrinkage') - tuned_parameters = [{'shrinkage': shrinkage}] + + shrinkage = mp.pop("shrinkage") + tuned_parameters = [{"shrinkage": shrinkage}] shrinkages = [] - gs = GridSearchCV(ShrunkCovariance(**mp), - tuned_parameters, cv=cv) + gs = GridSearchCV(ShrunkCovariance(**mp), tuned_parameters, cv=cv) for ch_type, picks in sub_picks_list: gs.fit(data_[:, picks]) - shrinkages.append((ch_type, gs.best_estimator_.shrinkage, - picks)) + shrinkages.append((ch_type, gs.best_estimator_.shrinkage, picks)) shrinkages = [c[0] for c in zip(shrinkages)] sc = _ShrunkCovariance(shrinkage=shrinkages, **mp) sc.fit(data_) estimator_cov_info.append((sc, sc.covariance_, _info)) del shrinkage, sc - elif method_ == 'pca': - assert orig_rank == 'full' + elif method_ == "pca": + assert orig_rank == "full" 
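            # [Editor's aside -- illustrative sketch, not part of this diff.]
            # The 'pca' and 'factor_analysis' branches below choose the number
            # of components by cross-validated log-likelihood. The same idea,
            # standalone (``X`` is a hypothetical (n_samples, n_features) array):
            #
            #     from sklearn.decomposition import PCA
            #     from sklearn.model_selection import cross_val_score
            #     scores = {n: cross_val_score(PCA(n_components=n), X).mean()
            #               for n in range(5, X.shape[1], 5)}
            #
            # _auto_low_rank_model adds early stopping once the score drops.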
pca, _info = _auto_low_rank_model( - data_, method_, n_jobs=n_jobs, method_params=mp, cv=cv, - stop_early=stop_early) + data_, + method_, + n_jobs=n_jobs, + method_params=mp, + cv=cv, + stop_early=stop_early, + ) pca.fit(data_) estimator_cov_info.append((pca, pca.get_covariance(), _info)) del pca - elif method_ == 'factor_analysis': - assert orig_rank == 'full' + elif method_ == "factor_analysis": + assert orig_rank == "full" fa, _info = _auto_low_rank_model( - data_, method_, n_jobs=n_jobs, method_params=mp, cv=cv, - stop_early=stop_early) + data_, + method_, + n_jobs=n_jobs, + method_params=mp, + cv=cv, + stop_early=stop_early, + ) fa.fit(data_) estimator_cov_info.append((fa, fa.get_covariance(), _info)) del fa else: - raise ValueError('Oh no! Your estimator does not have' - ' a .fit method') - logger.info('Done.') + raise ValueError("Oh no! Your estimator does not have" " a .fit method") + logger.info("Done.") if len(method) > 1: - logger.info('Using cross-validation to select the best estimator.') + logger.info("Using cross-validation to select the best estimator.") out = dict() - for ei, (estimator, cov, runtime_info) in \ - enumerate(estimator_cov_info): + for ei, (estimator, cov, runtime_info) in enumerate(estimator_cov_info): if len(method) > 1: loglik = _cross_val(data, estimator, cv, n_jobs) else: @@ -1169,8 +1404,8 @@ def _gaussian_loglik_scorer(est, X, y=None): # compute empirical covariance of the test set precision = est.get_precision() n_samples, n_features = X.shape - log_like = -.5 * (X * (np.dot(X, precision))).sum(axis=1) - log_like -= .5 * (n_features * log(2. * np.pi) - _logdet(precision)) + log_like = -0.5 * (X * (np.dot(X, precision))).sum(axis=1) + log_like -= 0.5 * (n_features * log(2.0 * np.pi) - _logdet(precision)) out = np.mean(log_like) return out @@ -1178,22 +1413,28 @@ def _gaussian_loglik_scorer(est, X, y=None): def _cross_val(data, est, cv, n_jobs): """Compute cross validation.""" from sklearn.model_selection import cross_val_score - return np.mean(cross_val_score(est, data, cv=cv, n_jobs=n_jobs, - scoring=_gaussian_loglik_scorer)) + return np.mean( + cross_val_score( + est, data, cv=cv, n_jobs=n_jobs, scoring=_gaussian_loglik_scorer + ) + ) -def _auto_low_rank_model(data, mode, n_jobs, method_params, cv, - stop_early=True, verbose=None): + +def _auto_low_rank_model( + data, mode, n_jobs, method_params, cv, stop_early=True, verbose=None +): """Compute latent variable models.""" method_params = deepcopy(method_params) - iter_n_components = method_params.pop('iter_n_components') + iter_n_components = method_params.pop("iter_n_components") if iter_n_components is None: iter_n_components = np.arange(5, data.shape[1], 5) from sklearn.decomposition import PCA, FactorAnalysis - if mode == 'factor_analysis': + + if mode == "factor_analysis": est = FactorAnalysis else: - assert mode == 'pca' + assert mode == "pca" est = PCA est = est(**method_params) est.n_components = 1 @@ -1203,8 +1444,10 @@ def _auto_low_rank_model(data, mode, n_jobs, method_params, cv, # make sure we don't empty the thing if it's a generator max_n = max(list(deepcopy(iter_n_components))) if max_n > data.shape[1]: - warn('You are trying to estimate %i components on matrix ' - 'with %i features.' % (max_n, data.shape[1])) + warn( + "You are trying to estimate %i components on matrix " + "with %i features." 
% (max_n, data.shape[1]) + ) for ii, n in enumerate(iter_n_components): est.n_components = n @@ -1213,30 +1456,34 @@ def _auto_low_rank_model(data, mode, n_jobs, method_params, cv, except ValueError: score = np.inf if np.isinf(score) or score > 0: - logger.info('... infinite values encountered. stopping estimation') + logger.info("... infinite values encountered. stopping estimation") break - logger.info('... rank: %i - loglik: %0.3f' % (n, score)) + logger.info("... rank: %i - loglik: %0.3f" % (n, score)) if score != -np.inf: scores[ii] = score - if (ii >= 3 and np.all(np.diff(scores[ii - 3:ii]) < 0) and stop_early): + if ii >= 3 and np.all(np.diff(scores[ii - 3 : ii]) < 0) and stop_early: # early stop search when loglik has been going down 3 times - logger.info('early stopping parameter search.') + logger.info("early stopping parameter search.") break # happens if rank is too low right form the beginning if np.isnan(scores).all(): - raise RuntimeError('Oh no! Could not estimate covariance because all ' - 'scores were NaN. Please contact the MNE-Python ' - 'developers.') + raise RuntimeError( + "Oh no! Could not estimate covariance because all " + "scores were NaN. Please contact the MNE-Python " + "developers." + ) i_score = np.nanargmax(scores) best = est.n_components = iter_n_components[i_score] - logger.info('... best model at rank = %i' % best) - runtime_info = {'ranks': np.array(iter_n_components), - 'scores': scores, - 'best': best, - 'cv': cv} + logger.info("... best model at rank = %i" % best) + runtime_info = { + "ranks": np.array(iter_n_components), + "scores": scores, + "best": best, + "cv": cv, + } return est, runtime_info @@ -1247,11 +1494,25 @@ def _auto_low_rank_model(data, mode, n_jobs, method_params, cv, class _RegCovariance(BaseEstimator): """Aux class.""" - def __init__(self, info, grad=0.1, mag=0.1, eeg=0.1, seeg=0.1, - ecog=0.1, hbo=0.1, hbr=0.1, fnirs_cw_amplitude=0.1, - fnirs_fd_ac_amplitude=0.1, fnirs_fd_phase=0.1, fnirs_od=0.1, - csd=0.1, dbs=0.1, store_precision=False, - assume_centered=False): + def __init__( + self, + info, + grad=0.1, + mag=0.1, + eeg=0.1, + seeg=0.1, + ecog=0.1, + hbo=0.1, + hbr=0.1, + fnirs_cw_amplitude=0.1, + fnirs_fd_ac_amplitude=0.1, + fnirs_fd_phase=0.1, + fnirs_od=0.1, + csd=0.1, + dbs=0.1, + store_precision=False, + assume_centered=False, + ): self.info = info # For sklearn compat, these cannot (easily?) 
be combined into # a single dictionary @@ -1274,20 +1535,33 @@ def __init__(self, info, grad=0.1, mag=0.1, eeg=0.1, seeg=0.1, def fit(self, X): """Fit covariance model with classical diagonal regularization.""" self.estimator_ = EmpiricalCovariance( - store_precision=self.store_precision, - assume_centered=self.assume_centered) + store_precision=self.store_precision, assume_centered=self.assume_centered + ) self.covariance_ = self.estimator_.fit(X).covariance_ self.covariance_ = 0.5 * (self.covariance_ + self.covariance_.T) cov_ = Covariance( - data=self.covariance_, names=self.info['ch_names'], - bads=self.info['bads'], projs=self.info['projs'], - nfree=len(self.covariance_)) + data=self.covariance_, + names=self.info["ch_names"], + bads=self.info["bads"], + projs=self.info["projs"], + nfree=len(self.covariance_), + ) cov_ = regularize( - cov_, self.info, proj=False, exclude='bads', - grad=self.grad, mag=self.mag, eeg=self.eeg, - ecog=self.ecog, seeg=self.seeg, dbs=self.dbs, - hbo=self.hbo, hbr=self.hbr, rank='full') + cov_, + self.info, + proj=False, + exclude="bads", + grad=self.grad, + mag=self.mag, + eeg=self.eeg, + ecog=self.ecog, + seeg=self.seeg, + dbs=self.dbs, + hbo=self.hbo, + hbr=self.hbr, + rank="full", + ) self.estimator_.covariance_ = self.covariance_ = cov_.data return self @@ -1303,9 +1577,7 @@ def get_precision(self): class _ShrunkCovariance(BaseEstimator): """Aux class.""" - def __init__(self, store_precision, assume_centered, - shrinkage=0.1): - + def __init__(self, store_precision, assume_centered, shrinkage=0.1): self.store_precision = store_precision self.assume_centered = assume_centered self.shrinkage = shrinkage @@ -1313,14 +1585,15 @@ def __init__(self, store_precision, assume_centered, def fit(self, X): """Fit covariance model with oracle shrinkage regularization.""" from sklearn.covariance import shrunk_covariance + self.estimator_ = EmpiricalCovariance( - store_precision=self.store_precision, - assume_centered=self.assume_centered) + store_precision=self.store_precision, assume_centered=self.assume_centered + ) cov = self.estimator_.fit(X).covariance_ if not isinstance(self.shrinkage, (list, tuple)): - shrinkage = [('all', self.shrinkage, np.arange(len(cov)))] + shrinkage = [("all", self.shrinkage, np.arange(len(cov)))] else: shrinkage = self.shrinkage @@ -1328,7 +1601,7 @@ def fit(self, X): for a, b in itt.combinations(shrinkage, 2): picks_i, picks_j = a[2], b[2] ch_ = a[0], b[0] - if 'eeg' in ch_: + if "eeg" in ch_: zero_cross_cov[np.ix_(picks_i, picks_j)] = True zero_cross_cov[np.ix_(picks_j, picks_i)] = True @@ -1337,14 +1610,13 @@ def fit(self, X): # Apply shrinkage to blocks for ch_type, c, picks in shrinkage: sub_cov = cov[np.ix_(picks, picks)] - cov[np.ix_(picks, picks)] = shrunk_covariance(sub_cov, - shrinkage=c) + cov[np.ix_(picks, picks)] = shrunk_covariance(sub_cov, shrinkage=c) # Apply shrinkage to cross-cov for a, b in itt.combinations(shrinkage, 2): shrinkage_i, shrinkage_j = a[1], b[1] picks_i, picks_j = a[2], b[2] - c_ij = np.sqrt((1. - shrinkage_i) * (1. 
- shrinkage_j)) + c_ij = np.sqrt((1.0 - shrinkage_i) * (1.0 - shrinkage_j)) cov[np.ix_(picks_i, picks_j)] *= c_ij cov[np.ix_(picks_j, picks_i)] *= c_ij @@ -1358,10 +1630,11 @@ def fit(self, X): def score(self, X_test, y=None): """Delegate to modified EmpiricalCovariance instance.""" # compute empirical covariance of the test set - test_cov = empirical_covariance(X_test - self.estimator_.location_, - assume_centered=True) + test_cov = empirical_covariance( + X_test - self.estimator_.location_, assume_centered=True + ) if np.any(self.zero_cross_cov_): - test_cov[self.zero_cross_cov_] = 0. + test_cov[self.zero_cross_cov_] = 0.0 res = log_likelihood(test_cov, self.estimator_.get_precision()) return res @@ -1373,6 +1646,7 @@ def get_precision(self): ############################################################################### # Writing + @verbose def write_cov(fname, cov, *, overwrite=False, verbose=None): """Write a noise covariance matrix. @@ -1399,6 +1673,7 @@ def write_cov(fname, cov, *, overwrite=False, verbose=None): ############################################################################### # Prepare for inverse modeling + def _unpack_epochs(epochs): """Aux Function.""" if len(epochs.event_id) > 1: @@ -1418,8 +1693,10 @@ def _get_ch_whitener(A, pca, ch_type, rank): eig[:-rank] = 0.0 mask[:-rank] = False - logger.info(' Setting small %s eigenvalues to zero (%s)' - % (ch_type, 'using PCA' if pca else 'without PCA')) + logger.info( + " Setting small %s eigenvalues to zero (%s)" + % (ch_type, "using PCA" if pca else "without PCA") + ) if pca: # No PCA case. # This line will reduce the actual number of variables in data # and leadfield to the true rank. @@ -1428,8 +1705,15 @@ def _get_ch_whitener(A, pca, ch_type, rank): @verbose -def prepare_noise_cov(noise_cov, info, ch_names=None, rank=None, - scalings=None, on_rank_mismatch='ignore', verbose=None): +def prepare_noise_cov( + noise_cov, + info, + ch_names=None, + rank=None, + scalings=None, + on_rank_mismatch="ignore", + verbose=None, +): """Prepare noise covariance matrix. 
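
    A minimal sketch (editor's illustration; the file name is hypothetical and
    ``info`` is assumed to be the matching :class:`mne.Info`)::

        import mne

        cov = mne.read_cov("sample_audvis-cov.fif")
        cov = mne.cov.prepare_noise_cov(cov, info, rank=None)
        print(cov["eig"].shape, cov["eigvec"].shape)  # decomposition stored
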
Parameters @@ -1461,7 +1745,7 @@ def prepare_noise_cov(noise_cov, info, ch_names=None, rank=None, # reorder C and info to match ch_names order noise_cov_idx = list() missing = list() - ch_names = info['ch_names'] if ch_names is None else ch_names + ch_names = info["ch_names"] if ch_names is None else ch_names for c in ch_names: # this could be try/except ValueError, but it is not the preferred way if c in noise_cov.ch_names: @@ -1469,51 +1753,71 @@ def prepare_noise_cov(noise_cov, info, ch_names=None, rank=None, else: missing.append(c) if len(missing): - raise RuntimeError('Not all channels present in noise covariance:\n%s' - % missing) + raise RuntimeError( + "Not all channels present in noise covariance:\n%s" % missing + ) C = noise_cov._get_square()[np.ix_(noise_cov_idx, noise_cov_idx)] - info = pick_info( - info, pick_channels(info['ch_names'], ch_names, ordered=False)) - projs = info['projs'] + noise_cov['projs'] + info = pick_info(info, pick_channels(info["ch_names"], ch_names, ordered=False)) + projs = info["projs"] + noise_cov["projs"] noise_cov = Covariance( - data=C, names=ch_names, bads=list(noise_cov['bads']), - projs=deepcopy(noise_cov['projs']), nfree=noise_cov['nfree'], - method=noise_cov.get('method', None), - loglik=noise_cov.get('loglik', None)) - - eig, eigvec, _ = _smart_eigh(noise_cov, info, rank, scalings, projs, - ch_names, on_rank_mismatch=on_rank_mismatch) + data=C, + names=ch_names, + bads=list(noise_cov["bads"]), + projs=deepcopy(noise_cov["projs"]), + nfree=noise_cov["nfree"], + method=noise_cov.get("method", None), + loglik=noise_cov.get("loglik", None), + ) + + eig, eigvec, _ = _smart_eigh( + noise_cov, + info, + rank, + scalings, + projs, + ch_names, + on_rank_mismatch=on_rank_mismatch, + ) noise_cov.update(eig=eig, eigvec=eigvec) return noise_cov @verbose -def _smart_eigh(C, info, rank, scalings=None, projs=None, - ch_names=None, proj_subspace=False, do_compute_rank=True, - on_rank_mismatch='ignore', verbose=None): +def _smart_eigh( + C, + info, + rank, + scalings=None, + projs=None, + ch_names=None, + proj_subspace=False, + do_compute_rank=True, + on_rank_mismatch="ignore", + verbose=None, +): """Compute eigh of C taking into account rank and ch_type scalings.""" - scalings = _handle_default('scalings_cov_rank', scalings) - projs = info['projs'] if projs is None else projs - ch_names = info['ch_names'] if ch_names is None else ch_names - if info['ch_names'] != ch_names: - info = pick_info(info, [info['ch_names'].index(c) for c in ch_names]) - assert info['ch_names'] == ch_names + scalings = _handle_default("scalings_cov_rank", scalings) + projs = info["projs"] if projs is None else projs + ch_names = info["ch_names"] if ch_names is None else ch_names + if info["ch_names"] != ch_names: + info = pick_info(info, [info["ch_names"].index(c) for c in ch_names]) + assert info["ch_names"] == ch_names n_chan = len(ch_names) # Create the projection operator proj, ncomp, _ = make_projector(projs, ch_names) if isinstance(C, Covariance): - C = C['data'] + C = C["data"] if ncomp > 0: - logger.info(' Created an SSP operator (subspace dimension = %d)' - % ncomp) + logger.info(" Created an SSP operator (subspace dimension = %d)" % ncomp) C = np.dot(proj, np.dot(C, proj.T)) noise_cov = Covariance(C, ch_names, [], projs, 0) if do_compute_rank: # if necessary rank = compute_rank( - noise_cov, rank, scalings, info, on_rank_mismatch=on_rank_mismatch) + noise_cov, rank, scalings, info, on_rank_mismatch=on_rank_mismatch + ) assert C.ndim == 2 and C.shape[0] == C.shape[1] # time 
saving short-circuit @@ -1524,14 +1828,15 @@ def _smart_eigh(C, info, rank, scalings=None, projs=None, eig = np.zeros(n_chan, dtype) eigvec = np.zeros((n_chan, n_chan), dtype) mask = np.zeros(n_chan, bool) - for ch_type, picks in _picks_by_type(info, meg_combined=True, - ref_meg=False, exclude=[]): + for ch_type, picks in _picks_by_type( + info, meg_combined=True, ref_meg=False, exclude=[] + ): if len(picks) == 0: continue this_C = C[np.ix_(picks, picks)] - if ch_type not in rank and ch_type in ('mag', 'grad'): - this_rank = rank['meg'] # if there is only one or the other + if ch_type not in rank and ch_type in ("mag", "grad"): + this_rank = rank["meg"] # if there is only one or the other else: this_rank = rank[ch_type] @@ -1541,21 +1846,43 @@ def _smart_eigh(C, info, rank, scalings=None, projs=None, e, ev = _eigvec_subspace(e, ev, m) eig[picks], eigvec[np.ix_(picks, picks)], mask[picks] = e, ev, m # XXX : also handle ref for sEEG and ECoG - if ch_type == 'eeg' and _needs_eeg_average_ref_proj(info) and not \ - _has_eeg_average_ref_proj(info, projs=projs): - warn('No average EEG reference present in info["projs"], ' - 'covariance may be adversely affected. Consider recomputing ' - 'covariance using with an average eeg reference projector ' - 'added.') + if ( + ch_type == "eeg" + and _needs_eeg_average_ref_proj(info) + and not _has_eeg_average_ref_proj(info, projs=projs) + ): + warn( + 'No average EEG reference present in info["projs"], ' + "covariance may be adversely affected. Consider recomputing " + "covariance using with an average eeg reference projector " + "added." + ) return eig, eigvec, mask @verbose -def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', - proj=True, seeg=0.1, ecog=0.1, hbo=0.1, hbr=0.1, - fnirs_cw_amplitude=0.1, fnirs_fd_ac_amplitude=0.1, - fnirs_fd_phase=0.1, fnirs_od=0.1, csd=0.1, dbs=0.1, - rank=None, scalings=None, verbose=None): +def regularize( + cov, + info, + mag=0.1, + grad=0.1, + eeg=0.1, + exclude="bads", + proj=True, + seeg=0.1, + ecog=0.1, + hbo=0.1, + hbr=0.1, + fnirs_cw_amplitude=0.1, + fnirs_fd_ac_amplitude=0.1, + fnirs_fd_phase=0.1, + fnirs_od=0.1, + csd=0.1, + dbs=0.1, + rank=None, + scalings=None, + verbose=None, +): """Regularize noise covariance matrix. 
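
    A minimal sketch (editor's illustration; assumes ``cov`` and ``raw.info``
    describe the same channels); the mechanism is spelled out just below::

        cov_reg = mne.cov.regularize(cov, raw.info, mag=0.1, grad=0.1, eeg=0.1)
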
This method works by adding a constant to the diagonal for each @@ -1629,37 +1956,54 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', mne.compute_covariance """ # noqa: E501 from scipy import linalg + cov = cov.copy() info._check_consistency() - scalings = _handle_default('scalings_cov_rank', scalings) - regs = dict(eeg=eeg, seeg=seeg, dbs=dbs, ecog=ecog, hbo=hbo, hbr=hbr, - fnirs_cw_amplitude=fnirs_cw_amplitude, - fnirs_fd_ac_amplitude=fnirs_fd_ac_amplitude, - fnirs_fd_phase=fnirs_fd_phase, fnirs_od=fnirs_od, csd=csd) + scalings = _handle_default("scalings_cov_rank", scalings) + regs = dict( + eeg=eeg, + seeg=seeg, + dbs=dbs, + ecog=ecog, + hbo=hbo, + hbr=hbr, + fnirs_cw_amplitude=fnirs_cw_amplitude, + fnirs_fd_ac_amplitude=fnirs_fd_ac_amplitude, + fnirs_fd_phase=fnirs_fd_phase, + fnirs_od=fnirs_od, + csd=csd, + ) if exclude is None: raise ValueError('exclude must be a list of strings or "bads"') - if exclude == 'bads': - exclude = info['bads'] + cov['bads'] + if exclude == "bads": + exclude = info["bads"] + cov["bads"] picks_dict = {ch_type: [] for ch_type in _DATA_CH_TYPES_SPLIT} - meg_combined = 'auto' if rank != 'full' else False - picks_dict.update(dict(_picks_by_type( - info, meg_combined=meg_combined, exclude=exclude, ref_meg=False))) - if len(picks_dict.get('meg', [])) > 0 and rank != 'full': # combined + meg_combined = "auto" if rank != "full" else False + picks_dict.update( + dict( + _picks_by_type( + info, meg_combined=meg_combined, exclude=exclude, ref_meg=False + ) + ) + ) + if len(picks_dict.get("meg", [])) > 0 and rank != "full": # combined if mag != grad: - raise ValueError('On data where magnetometers and gradiometers ' - 'are dependent (e.g., SSSed data), mag (%s) must ' - 'equal grad (%s)' % (mag, grad)) - logger.info('Regularizing MEG channels jointly') - regs['meg'] = mag + raise ValueError( + "On data where magnetometers and gradiometers " + "are dependent (e.g., SSSed data), mag (%s) must " + "equal grad (%s)" % (mag, grad) + ) + logger.info("Regularizing MEG channels jointly") + regs["meg"] = mag else: regs.update(mag=mag, grad=grad) - if rank != 'full': + if rank != "full": rank = compute_rank(cov, rank, scalings, info) - info_ch_names = info['ch_names'] + info_ch_names = info["ch_names"] ch_names_by_type = dict() for ch_type, picks_type in picks_dict.items(): ch_names_by_type[ch_type] = [info_ch_names[i] for i in picks_type] @@ -1667,7 +2011,8 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', # This actually removes bad channels from the cov, which is not backward # compatible, so let's leave all channels in cov_good = pick_channels_cov( - cov, include=info_ch_names, exclude=exclude, ordered=False) + cov, include=info_ch_names, exclude=exclude, ordered=False + ) ch_names = cov_good.ch_names # Now get the indices for each channel type in the cov @@ -1678,14 +2023,14 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', idx_cov[ch_type].append(i) break else: - raise Exception('channel %s is unknown type' % ch) + raise Exception("channel %s is unknown type" % ch) - C = cov_good['data'] + C = cov_good["data"] assert len(C) == sum(map(len, idx_cov.values())) if proj: - projs = info['projs'] + cov_good['projs'] + projs = info["projs"] + cov_good["projs"] projs = activate_proj(projs) for ch_type in idx_cov: @@ -1702,16 +2047,18 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', this_C = C[np.ix_(idx, idx)] U = np.eye(this_C.shape[0]) this_ch_names = [ch_names[k] for k in idx] - if 
rank == 'full': + if rank == "full": if proj: P, ncomp, _ = make_projector(projs, this_ch_names) if ncomp > 0: # This adjustment ends up being redundant if rank is None: U = linalg.svd(P)[0][:, :-ncomp] - logger.info(' Created an SSP operator for %s ' - '(dimension = %d)' % (desc, ncomp)) + logger.info( + " Created an SSP operator for %s " + "(dimension = %d)" % (desc, ncomp) + ) else: - this_picks = pick_channels(info['ch_names'], this_ch_names) + this_picks = pick_channels(info["ch_names"], this_ch_names) this_info = pick_info(info, this_picks) # Here we could use proj_subspace=True, but this should not matter # since this is already in a loop over channel types @@ -1720,20 +2067,18 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', this_C = np.dot(U.T, np.dot(this_C, U)) sigma = np.mean(np.diag(this_C)) - this_C.flat[::len(this_C) + 1] += reg * sigma # modify diag inplace + this_C.flat[:: len(this_C) + 1] += reg * sigma # modify diag inplace this_C = np.dot(U, np.dot(this_C, U.T)) C[np.ix_(idx, idx)] = this_C # Put data back in correct locations - idx = pick_channels( - cov.ch_names, info_ch_names, exclude=exclude, ordered=False) - cov['data'][np.ix_(idx, idx)] = C + idx = pick_channels(cov.ch_names, info_ch_names, exclude=exclude, ordered=False) + cov["data"][np.ix_(idx, idx)] = C return cov -def _regularized_covariance(data, reg=None, method_params=None, info=None, - rank=None): +def _regularized_covariance(data, reg=None, method_params=None, info=None, rank=None): """Compute a regularized covariance from data using sklearn. This is a convenience wrapper for mne.decoding functions, which @@ -1744,36 +2089,55 @@ def _regularized_covariance(data, reg=None, method_params=None, info=None, cov : ndarray, shape (n_channels, n_channels) The covariance matrix. 
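
    A hypothetical call (editor's sketch; ``data`` is an ndarray of shape
    (n_channels, n_times))::

        cov = _regularized_covariance(data, reg=0.1)  # float -> 'shrinkage'
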
""" - _validate_type(reg, (str, 'numeric', None)) + _validate_type(reg, (str, "numeric", None)) if reg is None: - reg = 'empirical' + reg = "empirical" elif not isinstance(reg, str): reg = float(reg) if method_params is not None: - raise ValueError('If reg is a float, method_params must be None ' - '(got %s)' % (type(method_params),)) - method_params = dict(shrinkage=dict( - shrinkage=reg, assume_centered=True, store_precision=False)) - reg = 'shrinkage' + raise ValueError( + "If reg is a float, method_params must be None " + "(got %s)" % (type(method_params),) + ) + method_params = dict( + shrinkage=dict(shrinkage=reg, assume_centered=True, store_precision=False) + ) + reg = "shrinkage" method, method_params = _check_method_params( - reg, method_params, name='reg', allow_auto=False, rank=rank) + reg, method_params, name="reg", allow_auto=False, rank=rank + ) # use mag instead of eeg here to avoid the cov EEG projection warning - info = create_info(data.shape[-2], 1000., 'mag') if info is None else info + info = create_info(data.shape[-2], 1000.0, "mag") if info is None else info picks_list = _picks_by_type(info) - scalings = _handle_default('scalings_cov_rank', None) + scalings = _handle_default("scalings_cov_rank", None) cov = _compute_covariance_auto( - data.T, method=method, method_params=method_params, - info=info, cv=None, n_jobs=None, stop_early=True, - picks_list=picks_list, scalings=scalings, - rank=rank)[reg]['data'] + data.T, + method=method, + method_params=method_params, + info=info, + cv=None, + n_jobs=None, + stop_early=True, + picks_list=picks_list, + scalings=scalings, + rank=rank, + )[reg]["data"] return cov @verbose -def compute_whitener(noise_cov, info=None, picks=None, rank=None, - scalings=None, return_rank=False, pca=False, - return_colorer=False, on_rank_mismatch='warn', - verbose=None): +def compute_whitener( + noise_cov, + info=None, + picks=None, + rank=None, + scalings=None, + return_rank=False, + pca=False, + return_colorer=False, + on_rank_mismatch="warn", + verbose=None, +): """Compute whitening matrix. Parameters @@ -1824,53 +2188,56 @@ def compute_whitener(noise_cov, info=None, picks=None, rank=None, colorer : ndarray, shape (n_channels, n_channels) or (n_channels, n_nonzero) The coloring matrix. 
""" # noqa: E501 - _validate_type(pca, (str, bool), 'space') - _valid_pcas = (True, 'white', False) + _validate_type(pca, (str, bool), "space") + _valid_pcas = (True, "white", False) if pca not in _valid_pcas: - raise ValueError('space must be one of %s, got %s' - % (_valid_pcas, pca)) + raise ValueError("space must be one of %s, got %s" % (_valid_pcas, pca)) if info is None: - if 'eig' not in noise_cov: - raise ValueError('info can only be None if the noise cov has ' - 'already been prepared with prepare_noise_cov') - ch_names = deepcopy(noise_cov['names']) + if "eig" not in noise_cov: + raise ValueError( + "info can only be None if the noise cov has " + "already been prepared with prepare_noise_cov" + ) + ch_names = deepcopy(noise_cov["names"]) else: picks = _picks_to_idx(info, picks, with_ref_meg=False) - ch_names = [info['ch_names'][k] for k in picks] + ch_names = [info["ch_names"][k] for k in picks] del picks noise_cov = prepare_noise_cov( - noise_cov, info, ch_names, rank, scalings, - on_rank_mismatch=on_rank_mismatch) + noise_cov, info, ch_names, rank, scalings, on_rank_mismatch=on_rank_mismatch + ) n_chan = len(ch_names) - assert n_chan == len(noise_cov['eig']) + assert n_chan == len(noise_cov["eig"]) # Omit the zeroes due to projection - eig = noise_cov['eig'].copy() - nzero = (eig > 0) - eig[~nzero] = 0. # get rid of numerical noise (negative) ones + eig = noise_cov["eig"].copy() + nzero = eig > 0 + eig[~nzero] = 0.0 # get rid of numerical noise (negative) ones - if noise_cov['eigvec'].dtype.kind == 'c': + if noise_cov["eigvec"].dtype.kind == "c": dtype = np.complex128 else: dtype = np.float64 W = np.zeros((n_chan, 1), dtype) W[nzero, 0] = 1.0 / np.sqrt(eig[nzero]) # Rows of eigvec are the eigenvectors - W = W * noise_cov['eigvec'] # C ** -0.5 - C = np.sqrt(eig) * noise_cov['eigvec'].conj().T # C ** 0.5 + W = W * noise_cov["eigvec"] # C ** -0.5 + C = np.sqrt(eig) * noise_cov["eigvec"].conj().T # C ** 0.5 n_nzero = nzero.sum() - logger.info(' Created the whitener using a noise covariance matrix ' - 'with rank %d (%d small eigenvalues omitted)' - % (n_nzero, noise_cov['dim'] - n_nzero)) + logger.info( + " Created the whitener using a noise covariance matrix " + "with rank %d (%d small eigenvalues omitted)" + % (n_nzero, noise_cov["dim"] - n_nzero) + ) # Do the requested projection if pca is True: W = W[nzero] C = C[:, nzero] elif pca is False: - W = np.dot(noise_cov['eigvec'].conj().T, W) - C = np.dot(C, noise_cov['eigvec']) + W = np.dot(noise_cov["eigvec"].conj().T, W) + C = np.dot(C, noise_cov["eigvec"]) # Triage return out = W, ch_names @@ -1882,8 +2249,9 @@ def compute_whitener(noise_cov, info=None, picks=None, rank=None, @verbose -def whiten_evoked(evoked, noise_cov, picks=None, diag=None, rank=None, - scalings=None, verbose=None): +def whiten_evoked( + evoked, noise_cov, picks=None, diag=None, rank=None, scalings=None, verbose=None +): """Whiten evoked data using given noise covariance. 
Parameters @@ -1919,8 +2287,9 @@ def whiten_evoked(evoked, noise_cov, picks=None, diag=None, rank=None, if diag: noise_cov = noise_cov.as_diag() - W, _ = compute_whitener(noise_cov, evoked.info, picks=picks, - rank=rank, scalings=scalings) + W, _ = compute_whitener( + noise_cov, evoked.info, picks=picks, rank=rank, scalings=scalings + ) evoked.data[picks] = np.sqrt(evoked.nave) * np.dot(W, evoked.data[picks]) return evoked @@ -1931,9 +2300,10 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): """Read a noise covariance matrix.""" # Find all covariance matrices from scipy import sparse + covs = dir_tree_find(node, FIFF.FIFFB_MNE_COV) if len(covs) == 0: - raise ValueError('No covariance matrices found') + raise ValueError("No covariance matrices found") # Is any of the covariance matrices a noise covariance for p in range(len(covs)): @@ -1945,7 +2315,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): # Find all the necessary data tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_DIM) if tag is None: - raise ValueError('Covariance matrix dimension not found') + raise ValueError("Covariance matrix dimension not found") dim = int(tag.data.item()) tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_NFREE) @@ -1970,22 +2340,25 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): if tag is None: names = [] else: - names = _safe_name_list(tag.data, 'read', 'names') + names = _safe_name_list(tag.data, "read", "names") if len(names) != dim: - raise ValueError('Number of names does not match ' - 'covariance matrix dimension') + raise ValueError( + "Number of names does not match " "covariance matrix dimension" + ) tag = find_tag(fid, this, FIFF.FIFF_MNE_COV) if tag is None: tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_DIAG) if tag is None: - raise ValueError('No covariance matrix data found') + raise ValueError("No covariance matrix data found") else: # Diagonal is stored data = tag.data diag = True - logger.info(' %d x %d diagonal covariance (kind = ' - '%d) found.' % (dim, dim, cov_kind)) + logger.info( + " %d x %d diagonal covariance (kind = " + "%d) found." % (dim, dim, cov_kind) + ) else: if not sparse.issparse(tag.data): @@ -1994,15 +2367,19 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): data = np.zeros((dim, dim)) data[np.tril(np.ones((dim, dim))) > 0] = vals data = data + data.T - data.flat[::dim + 1] /= 2.0 + data.flat[:: dim + 1] /= 2.0 diag = False - logger.info(' %d x %d full covariance (kind = %d) ' - 'found.' % (dim, dim, cov_kind)) + logger.info( + " %d x %d full covariance (kind = %d) " + "found." % (dim, dim, cov_kind) + ) else: diag = False data = tag.data - logger.info(' %d x %d sparse covariance (kind = %d)' - ' found.' % (dim, dim, cov_kind)) + logger.info( + " %d x %d sparse covariance (kind = %d)" + " found." 
% (dim, dim, cov_kind) + ) # Read the possibly precomputed decomposition tag1 = find_tag(fid, this, FIFF.FIFF_MNE_COV_EIGENVALUES) @@ -2023,20 +2400,28 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): # Put it together assert dim == len(data) assert data.ndim == (1 if diag else 2) - cov = dict(kind=cov_kind, diag=diag, dim=dim, names=names, - data=data, projs=projs, bads=bads, nfree=nfree, eig=eig, - eigvec=eigvec) + cov = dict( + kind=cov_kind, + diag=diag, + dim=dim, + names=names, + data=data, + projs=projs, + bads=bads, + nfree=nfree, + eig=eig, + eigvec=eigvec, + ) if score is not None: - cov['loglik'] = score + cov["loglik"] = score if method is not None: - cov['method'] = method + cov["method"] = method if limited: - del cov['kind'], cov['dim'], cov['diag'] + del cov["kind"], cov["dim"], cov["diag"] return cov - logger.info(' Did not find the desired covariance matrix (kind = %d)' - % cov_kind) + logger.info(" Did not find the desired covariance matrix (kind = %d)" % cov_kind) return None @@ -2046,55 +2431,55 @@ def _write_cov(fid, cov): start_block(fid, FIFF.FIFFB_MNE_COV) # Dimensions etc. - write_int(fid, FIFF.FIFF_MNE_COV_KIND, cov['kind']) - write_int(fid, FIFF.FIFF_MNE_COV_DIM, cov['dim']) - if cov['nfree'] > 0: - write_int(fid, FIFF.FIFF_MNE_COV_NFREE, cov['nfree']) + write_int(fid, FIFF.FIFF_MNE_COV_KIND, cov["kind"]) + write_int(fid, FIFF.FIFF_MNE_COV_DIM, cov["dim"]) + if cov["nfree"] > 0: + write_int(fid, FIFF.FIFF_MNE_COV_NFREE, cov["nfree"]) # Channel names - if cov['names'] is not None and len(cov['names']) > 0: + if cov["names"] is not None and len(cov["names"]) > 0: write_name_list_sanitized( - fid, FIFF.FIFF_MNE_ROW_NAMES, cov['names'], 'cov["names"]') + fid, FIFF.FIFF_MNE_ROW_NAMES, cov["names"], 'cov["names"]' + ) # Data - if cov['diag']: - write_double(fid, FIFF.FIFF_MNE_COV_DIAG, cov['data']) + if cov["diag"]: + write_double(fid, FIFF.FIFF_MNE_COV_DIAG, cov["data"]) else: # Store only lower part of covariance matrix - dim = cov['dim'] + dim = cov["dim"] mask = np.tril(np.ones((dim, dim), dtype=bool)) > 0 - vals = cov['data'][mask].ravel() + vals = cov["data"][mask].ravel() write_double(fid, FIFF.FIFF_MNE_COV, vals) # Eigenvalues and vectors if present - if cov['eig'] is not None and cov['eigvec'] is not None: - write_float_matrix(fid, FIFF.FIFF_MNE_COV_EIGENVECTORS, cov['eigvec']) - write_double(fid, FIFF.FIFF_MNE_COV_EIGENVALUES, cov['eig']) + if cov["eig"] is not None and cov["eigvec"] is not None: + write_float_matrix(fid, FIFF.FIFF_MNE_COV_EIGENVECTORS, cov["eigvec"]) + write_double(fid, FIFF.FIFF_MNE_COV_EIGENVALUES, cov["eig"]) # Projection operator - if cov['projs'] is not None and len(cov['projs']) > 0: - _write_proj(fid, cov['projs']) + if cov["projs"] is not None and len(cov["projs"]) > 0: + _write_proj(fid, cov["projs"]) # Bad channels - _write_bad_channels(fid, cov['bads'], None) + _write_bad_channels(fid, cov["bads"], None) # estimator method - if 'method' in cov: - write_string(fid, FIFF.FIFF_MNE_COV_METHOD, cov['method']) + if "method" in cov: + write_string(fid, FIFF.FIFF_MNE_COV_METHOD, cov["method"]) # negative log-likelihood score - if 'loglik' in cov: - write_double( - fid, FIFF.FIFF_MNE_COV_SCORE, np.array(cov['loglik'])) + if "loglik" in cov: + write_double(fid, FIFF.FIFF_MNE_COV_SCORE, np.array(cov["loglik"])) # Done! 
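    # [Editor's aside -- illustrative round-trip for the writer above;
    #  the file name is hypothetical, not part of this diff.]
    #
    #     mne.write_cov("test-cov.fif", cov, overwrite=True)
    #     cov2 = mne.read_cov("test-cov.fif")
    #     assert cov2.ch_names == cov.ch_names
    #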
end_block(fid, FIFF.FIFFB_MNE_COV)


 @verbose
-def _ensure_cov(cov, name='cov', *, verbose=None):
-    _validate_type(cov, ('path-like', Covariance), name)
-    logger.info('Noise covariance : %s' % (cov,))
+def _ensure_cov(cov, name="cov", *, verbose=None):
+    _validate_type(cov, ("path-like", Covariance), name)
+    logger.info("Noise covariance : %s" % (cov,))
     if not isinstance(cov, Covariance):
         cov = read_cov(cov, verbose=_verbose_safe_false())
     return cov
diff --git a/mne/cuda.py b/mne/cuda.py
index 15a2be2bab7..2b2dab64836 100644
--- a/mne/cuda.py
+++ b/mne/cuda.py
@@ -4,14 +4,22 @@

 import numpy as np

-from .utils import (sizeof_fmt, logger, get_config, warn, _explain_exception,
-                    verbose, fill_doc, _check_option)
+from .utils import (
+    sizeof_fmt,
+    logger,
+    get_config,
+    warn,
+    _explain_exception,
+    verbose,
+    fill_doc,
+    _check_option,
+)

 _cuda_capable = False


-def get_cuda_memory(kind='available'):
+def get_cuda_memory(kind="available"):
     """Get the amount of free memory for CUDA operations.

     Parameters
@@ -25,10 +33,11 @@ def get_cuda_memory(kind='available'):
         The amount of available or total memory as a human-readable string.
     """
     if not _cuda_capable:
-        warn('CUDA not enabled, returning zero for memory')
+        warn("CUDA not enabled, returning zero for memory")
         mem = 0
     else:
         import cupy
+
         mem = cupy.cuda.runtime.memGetInfo()[dict(available=0, total=1)[kind]]
     return sizeof_fmt(mem)

@@ -55,29 +64,30 @@ def init_cuda(ignore_config=False, verbose=None):
     global _cuda_capable
     if _cuda_capable:
         return
-    if not ignore_config and (get_config('MNE_USE_CUDA', 'false').lower() !=
-                              'true'):
-        logger.info('CUDA not enabled in config, skipping initialization')
+    if not ignore_config and (get_config("MNE_USE_CUDA", "false").lower() != "true"):
+        logger.info("CUDA not enabled in config, skipping initialization")
         return
     # Triage possible errors for informative messaging
     _cuda_capable = False
     try:
         import cupy  # noqa
     except ImportError:
-        warn('module cupy not found, CUDA not enabled')
+        warn("module cupy not found, CUDA not enabled")
         return
-    device_id = int(get_config('MNE_CUDA_DEVICE', '0'))
+    device_id = int(get_config("MNE_CUDA_DEVICE", "0"))
     try:
         # Initialize CUDA
         _set_cuda_device(device_id, verbose)
     except Exception:
-        warn('so CUDA device could be initialized, likely a hardware error, '
-             'CUDA not enabled%s' % _explain_exception())
+        warn(
+            "no CUDA device could be initialized, likely a hardware error, "
+            "CUDA not enabled%s" % _explain_exception()
+        )
         return
     _cuda_capable = True
     # Figure out limit for CUDA FFT calculations
-    logger.info('Enabling CUDA with %s available memory' % get_cuda_memory())
+    logger.info("Enabling CUDA with %s available memory" % get_cuda_memory())


 @verbose
@@ -92,28 +102,31 @@ def set_cuda_device(device_id, verbose=None):
     """
     if _cuda_capable:
         _set_cuda_device(device_id, verbose)
-    elif get_config('MNE_USE_CUDA', 'false').lower() == 'true':
+    elif get_config("MNE_USE_CUDA", "false").lower() == "true":
         init_cuda()
         _set_cuda_device(device_id, verbose)
     else:
-        warn('Could not set CUDA device because CUDA is not enabled; either '
-             'run mne.cuda.init_cuda() first, or set the MNE_USE_CUDA config '
-             'variable to "true".')
+        warn(
+            "Could not set CUDA device because CUDA is not enabled; either "
+            "run mne.cuda.init_cuda() first, or set the MNE_USE_CUDA config "
+            'variable to "true".'
+        )
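For orientation, the opt-in path described in the warning above looks like this from the user side. A sketch only; it assumes a working CuPy/CUDA installation and is not part of this patch:

    import mne

    mne.set_config("MNE_USE_CUDA", "true")  # read by init_cuda() above
    mne.cuda.init_cuda(verbose=True)
    print(mne.cuda.get_cuda_memory())       # e.g. '7.9 GB'

    # filtering and resampling can then request the GPU code path:
    # raw.filter(1.0, 40.0, n_jobs="cuda")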
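The _fft_multiply_repeated helper below reduces to rfft, multiply by a precomputed transfer function, irfft. The identity it relies on can be verified with SciPy alone; zero-padding to len(x) + len(h) - 1 makes the circular convolution match linear convolution (illustrative sizes):

    import numpy as np
    from scipy.fft import rfft, irfft

    rng = np.random.default_rng(0)
    x = rng.standard_normal(256)
    h = rng.standard_normal(32)

    n_fft = len(x) + len(h) - 1  # avoid circular wrap-around
    h_fft = rfft(h, n=n_fft)     # computed once, as in cuda_dict
    y = irfft(rfft(x, n=n_fft) * h_fft, n=n_fft)
    assert np.allclose(y, np.convolve(x, h))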


 @verbose
 def _set_cuda_device(device_id, verbose=None):
     """Set the CUDA device."""
     import cupy
+
     cupy.cuda.Device(device_id).use()
-    logger.info('Now using CUDA device {}'.format(device_id))
+    logger.info("Now using CUDA device {}".format(device_id))


 ###############################################################################
 # Repeated FFT multiplication

-def _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft,
-                                      kind='FFT FIR filtering'):
+
+def _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft, kind="FFT FIR filtering"):
     """Set up repeated CUDA FFT multiplication with a given filter.

     Parameters
@@ -154,28 +167,31 @@ def _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft,
     This function is designed to be used with fft_multiply_repeated().
     """
     from scipy.fft import rfft, irfft
-    cuda_dict = dict(n_fft=n_fft, rfft=rfft, irfft=irfft,
-                     h_fft=rfft(h, n=n_fft))
+
+    cuda_dict = dict(n_fft=n_fft, rfft=rfft, irfft=irfft, h_fft=rfft(h, n=n_fft))
     if isinstance(n_jobs, str):
-        _check_option('n_jobs', n_jobs, ('cuda',))
+        _check_option("n_jobs", n_jobs, ("cuda",))
         n_jobs = 1
         init_cuda()
         if _cuda_capable:
             import cupy
+
             try:
                 # do the IFFT normalization now so we don't have to later
-                h_fft = cupy.array(cuda_dict['h_fft'])
-                logger.info('Using CUDA for %s' % kind)
+                h_fft = cupy.array(cuda_dict["h_fft"])
+                logger.info("Using CUDA for %s" % kind)
             except Exception as exp:
-                logger.info('CUDA not used, could not instantiate memory '
-                            '(arrays may be too large: "%s"), falling back to '
-                            'n_jobs=None' % str(exp))
-            cuda_dict.update(h_fft=h_fft,
-                             rfft=_cuda_upload_rfft,
-                             irfft=_cuda_irfft_get)
+                logger.info(
+                    "CUDA not used, could not instantiate memory "
+                    '(arrays may be too large: "%s"), falling back to '
+                    "n_jobs=None" % str(exp)
+                )
+            cuda_dict.update(h_fft=h_fft, rfft=_cuda_upload_rfft, irfft=_cuda_irfft_get)
         else:
-            logger.info('CUDA not used, CUDA could not be initialized, '
-                        'falling back to n_jobs=None')
+            logger.info(
+                "CUDA not used, CUDA could not be initialized, "
+                "falling back to n_jobs=None"
+            )
     return n_jobs, cuda_dict


@@ -199,15 +215,16 @@ def _fft_multiply_repeated(x, cuda_dict):
         Filtered version of x.
     """
     # do the fourier-domain operations
-    x_fft = cuda_dict['rfft'](x, cuda_dict['n_fft'])
-    x_fft *= cuda_dict['h_fft']
-    x = cuda_dict['irfft'](x_fft, cuda_dict['n_fft'])
+    x_fft = cuda_dict["rfft"](x, cuda_dict["n_fft"])
+    x_fft *= cuda_dict["h_fft"]
+    x = cuda_dict["irfft"](x_fft, cuda_dict["n_fft"])
     return x


 ###############################################################################
 # FFT Resampling

+
 def _setup_cuda_fft_resample(n_jobs, W, new_len):
     """Set up CUDA FFT resampling.

@@ -248,52 +265,59 @@ def _setup_cuda_fft_resample(n_jobs, W, new_len):
     This function is designed to be used with fft_resample().
     """
     from scipy.fft import rfft, irfft
+
     cuda_dict = dict(use_cuda=False, rfft=rfft, irfft=irfft)
     rfft_len_x = len(W) // 2 + 1
     # fold the window onto itself (should be symmetric) and truncate
     W = W.copy()
-    W[1:rfft_len_x] = (W[1:rfft_len_x] + W[::-1][:rfft_len_x - 1]) / 2.
+ W[1:rfft_len_x] = (W[1:rfft_len_x] + W[::-1][: rfft_len_x - 1]) / 2.0 W = W[:rfft_len_x] if isinstance(n_jobs, str): - _check_option('n_jobs', n_jobs, ('cuda',)) + _check_option("n_jobs", n_jobs, ("cuda",)) n_jobs = 1 init_cuda() if _cuda_capable: try: import cupy + # do the IFFT normalization now so we don't have to later W = cupy.array(W) - logger.info('Using CUDA for FFT resampling') + logger.info("Using CUDA for FFT resampling") except Exception: - logger.info('CUDA not used, could not instantiate memory ' - '(arrays may be too large), falling back to ' - 'n_jobs=None') + logger.info( + "CUDA not used, could not instantiate memory " + "(arrays may be too large), falling back to " + "n_jobs=None" + ) else: - cuda_dict.update(use_cuda=True, - rfft=_cuda_upload_rfft, - irfft=_cuda_irfft_get) + cuda_dict.update( + use_cuda=True, rfft=_cuda_upload_rfft, irfft=_cuda_irfft_get + ) else: - logger.info('CUDA not used, CUDA could not be initialized, ' - 'falling back to n_jobs=None') - cuda_dict['W'] = W + logger.info( + "CUDA not used, CUDA could not be initialized, " + "falling back to n_jobs=None" + ) + cuda_dict["W"] = W return n_jobs, cuda_dict def _cuda_upload_rfft(x, n, axis=-1): """Upload and compute rfft.""" import cupy + return cupy.fft.rfft(cupy.array(x), n=n, axis=axis) def _cuda_irfft_get(x, n, axis=-1): """Compute irfft and get.""" import cupy + return cupy.fft.irfft(x, n=n, axis=axis).get() @fill_doc -def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, - pad='reflect_limited'): +def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, pad="reflect_limited"): """Do FFT resampling with a filter function (possibly using CUDA). Parameters @@ -327,16 +351,16 @@ def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, old_len = len(x) shorter = new_len < old_len use_len = new_len if shorter else old_len - x_fft = cuda_dict['rfft'](x, None) + x_fft = cuda_dict["rfft"](x, None) if use_len % 2 == 0: nyq = use_len // 2 - x_fft[nyq:nyq + 1] *= 2 if shorter else 0.5 - x_fft *= cuda_dict['W'] - y = cuda_dict['irfft'](x_fft, new_len) + x_fft[nyq : nyq + 1] *= 2 if shorter else 0.5 + x_fft *= cuda_dict["W"] + y = cuda_dict["irfft"](x_fft, new_len) # now let's trim it back to the correct size (if there was padding) if (to_removes > 0).any(): - y = y[to_removes[0]:y.shape[0] - to_removes[1]] + y = y[to_removes[0] : y.shape[0] - to_removes[1]] return y @@ -344,20 +368,28 @@ def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, ############################################################################### # Misc + # this has to go in mne.cuda instead of mne.filter to avoid import errors -def _smart_pad(x, n_pad, pad='reflect_limited'): +def _smart_pad(x, n_pad, pad="reflect_limited"): """Pad vector x.""" n_pad = np.asarray(n_pad) assert n_pad.shape == (2,) if (n_pad == 0).all(): return x elif (n_pad < 0).any(): - raise RuntimeError('n_pad must be non-negative') - if pad == 'reflect_limited': + raise RuntimeError("n_pad must be non-negative") + if pad == "reflect_limited": # need to pad with zeros if len(x) <= npad l_z_pad = np.zeros(max(n_pad[0] - len(x) + 1, 0), dtype=x.dtype) r_z_pad = np.zeros(max(n_pad[1] - len(x) + 1, 0), dtype=x.dtype) - return np.concatenate([l_z_pad, 2 * x[0] - x[n_pad[0]:0:-1], x, - 2 * x[-1] - x[-2:-n_pad[1] - 2:-1], r_z_pad]) + return np.concatenate( + [ + l_z_pad, + 2 * x[0] - x[n_pad[0] : 0 : -1], + x, + 2 * x[-1] - x[-2 : -n_pad[1] - 2 : -1], + r_z_pad, + ] + ) else: return np.pad(x, (tuple(n_pad),), pad) diff --git 
a/mne/datasets/__init__.py b/mne/datasets/__init__.py index ec24f450fd0..1549fa21f8f 100644 --- a/mne/datasets/__init__.py +++ b/mne/datasets/__init__.py @@ -29,19 +29,47 @@ from . import eyelink from . import ucl_opm_auditory from ._fetch import fetch_dataset -from .utils import (_download_all_example_data, fetch_hcp_mmp_parcellation, - fetch_aparc_sub_parcellation, has_dataset) +from .utils import ( + _download_all_example_data, + fetch_hcp_mmp_parcellation, + fetch_aparc_sub_parcellation, + has_dataset, +) from ._fsaverage.base import fetch_fsaverage from ._infant.base import fetch_infant_template from ._phantom.base import fetch_phantom __all__ = [ - '_download_all_example_data', '_fake', 'brainstorm', 'eegbci', - 'fetch_aparc_sub_parcellation', 'fetch_fsaverage', 'fetch_infant_template', - 'fetch_hcp_mmp_parcellation', 'fieldtrip_cmc', 'hf_sef', 'kiloword', - 'misc', 'mtrf', 'multimodal', 'opm', 'phantom_4dbti', 'sample', - 'sleep_physionet', 'somato', 'spm_face', 'ssvep', 'testing', - 'visual_92_categories', 'limo', 'erp_core', 'epilepsy_ecog', - 'fetch_dataset', 'fetch_phantom', 'has_dataset', 'refmeg_noise', - 'fnirs_motor', 'eyelink' + "_download_all_example_data", + "_fake", + "brainstorm", + "eegbci", + "fetch_aparc_sub_parcellation", + "fetch_fsaverage", + "fetch_infant_template", + "fetch_hcp_mmp_parcellation", + "fieldtrip_cmc", + "hf_sef", + "kiloword", + "misc", + "mtrf", + "multimodal", + "opm", + "phantom_4dbti", + "sample", + "sleep_physionet", + "somato", + "spm_face", + "ssvep", + "testing", + "visual_92_categories", + "limo", + "erp_core", + "epilepsy_ecog", + "fetch_dataset", + "fetch_phantom", + "has_dataset", + "refmeg_noise", + "fnirs_motor", + "eyelink", ] diff --git a/mne/datasets/_fake/_fake.py b/mne/datasets/_fake/_fake.py index 61ef7678862..475b7aeb640 100644 --- a/mne/datasets/_fake/_fake.py +++ b/mne/datasets/_fake/_fake.py @@ -4,25 +4,28 @@ # License: BSD Style. 
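All of the fetchers re-exported in this __init__ follow one calling convention: call them to download on first use and get a local path back. Typical usage (network access assumed on the first call):

    import mne

    sample_dir = mne.datasets.sample.data_path()  # Path to MNE-sample-data
    fs_dir = mne.datasets.fetch_fsaverage()       # template head, fetched into the subjects directory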
from ...utils import verbose -from ..utils import (_data_path_doc, _download_mne_dataset, - _get_version, _version_doc) +from ..utils import _data_path_doc, _download_mne_dataset, _get_version, _version_doc @verbose -def data_path(path=None, force_update=False, update_path=False, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=False, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='fake', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="fake", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='fake', - conf='MNE_DATASETS_FAKE_PATH') +data_path.__doc__ = _data_path_doc.format(name="fake", conf="MNE_DATASETS_FAKE_PATH") def get_version(): # noqa: D103 - return _get_version('fake') + return _get_version("fake") -get_version.__doc__ = _version_doc.format(name='fake') +get_version.__doc__ = _version_doc.format(name="fake") diff --git a/mne/datasets/_fetch.py b/mne/datasets/_fetch.py index 578c1cf82ed..37802f817b8 100644 --- a/mne/datasets/_fetch.py +++ b/mne/datasets/_fetch.py @@ -17,8 +17,13 @@ TESTING_VERSIONED, MISC_VERSIONED, ) -from .utils import (_dataset_version, _do_path_update, _get_path, - _log_time_size, _downloader_params) +from .utils import ( + _dataset_version, + _do_path_update, + _get_path, + _log_time_size, + _downloader_params, +) from ..fixes import _compare_version @@ -131,6 +136,7 @@ def fetch_dataset( pass a list of dicts. """ # noqa E501 import pooch + t0 = time.time() if auth is not None: @@ -153,7 +159,7 @@ def fetch_dataset( names = [params["dataset_name"] for params in dataset_params] name = names[0] dataset_dict = dataset_params[0] - config_key = dataset_dict.get('config_key', None) + config_key = dataset_dict.get("config_key", None) folder_name = dataset_dict["folder_name"] # get download path for specific dataset @@ -175,8 +181,9 @@ def fetch_dataset( # get the version of the dataset and then check if the version is outdated data_version = _dataset_version(final_path, name) - outdated = (want_version is not None and - _compare_version(want_version, '>', data_version)) + outdated = want_version is not None and _compare_version( + want_version, ">", data_version + ) if outdated: logger.info( @@ -188,16 +195,13 @@ def fetch_dataset( # return empty string if outdated dataset and we don't want to download if (not force_update) and outdated and not download: logger.info( - 'Dataset out of date but force_update=False and download=False, ' - 'returning empty data_path') + "Dataset out of date but force_update=False and download=False, " + "returning empty data_path" + ) return (empty, data_version) if return_version else empty # reasons to bail early (hf_sef has separate code for this): - if ( - (not force_update) - and (not outdated) - and (not name.startswith("hf_sef_")) - ): + if (not force_update) and (not outdated) and (not name.startswith("hf_sef_")): # ...if target folder exists (otherwise pooch downloads every # time because we don't save the archive files after unpacking, so # pooch can't check its checksum) @@ -215,8 +219,7 @@ def fetch_dataset( else: # If they don't have stdin, just accept the license # https://github.com/mne-tools/mne-python/issues/8513#issuecomment-726823724 # noqa: E501 - answer = _safe_input( - "%sAgree (y/[n])? 
" % _bst_license_text, use="y") + answer = _safe_input("%sAgree (y/[n])? " % _bst_license_text, use="y") if answer.lower() != "y": raise RuntimeError( "You must agree to the license to use this " "dataset" @@ -262,10 +265,11 @@ def fetch_dataset( ) except ValueError as err: err = str(err) - if 'hash of downloaded file' in str(err): + if "hash of downloaded file" in str(err): raise ValueError( - f'{err} Consider using force_update=True to force ' - 'the dataset to be downloaded again.') from None + f"{err} Consider using force_update=True to force " + "the dataset to be downloaded again." + ) from None else: raise fname = use_path / archive_name @@ -291,7 +295,7 @@ def fetch_dataset( data_version = _dataset_version(path, name) # 0.7 < 0.7.git should be False, therefore strip if check_version and ( - _compare_version(data_version, '<', mne_version.strip(".git")) + _compare_version(data_version, "<", mne_version.strip(".git")) ): warn( "The {name} dataset (version {current}) is older than " diff --git a/mne/datasets/_fsaverage/base.py b/mne/datasets/_fsaverage/base.py index d4a8f3d82c0..daa01dc64c2 100644 --- a/mne/datasets/_fsaverage/base.py +++ b/mne/datasets/_fsaverage/base.py @@ -65,19 +65,19 @@ def fetch_fsaverage(subjects_dir=None, *, verbose=None): # subjects_dir = _set_montage_coreg_path(subjects_dir) subjects_dir = op.abspath(op.expanduser(subjects_dir)) - fs_dir = op.join(subjects_dir, 'fsaverage') + fs_dir = op.join(subjects_dir, "fsaverage") os.makedirs(fs_dir, exist_ok=True) _manifest_check_download( - manifest_path=op.join(FSAVERAGE_MANIFEST_PATH, 'root.txt'), + manifest_path=op.join(FSAVERAGE_MANIFEST_PATH, "root.txt"), destination=op.join(subjects_dir), - url='https://osf.io/3bxqt/download?version=2', - hash_='5133fe92b7b8f03ae19219d5f46e4177', + url="https://osf.io/3bxqt/download?version=2", + hash_="5133fe92b7b8f03ae19219d5f46e4177", ) _manifest_check_download( - manifest_path=op.join(FSAVERAGE_MANIFEST_PATH, 'bem.txt'), - destination=op.join(subjects_dir, 'fsaverage'), - url='https://osf.io/7ve8g/download?version=4', - hash_='b31509cdcf7908af6a83dc5ee8f49fb1', + manifest_path=op.join(FSAVERAGE_MANIFEST_PATH, "bem.txt"), + destination=op.join(subjects_dir, "fsaverage"), + url="https://osf.io/7ve8g/download?version=4", + hash_="b31509cdcf7908af6a83dc5ee8f49fb1", ) return fs_dir @@ -85,8 +85,8 @@ def fetch_fsaverage(subjects_dir=None, *, verbose=None): def _get_create_subjects_dir(subjects_dir): subjects_dir = get_subjects_dir(subjects_dir, raise_error=False) if subjects_dir is None: - subjects_dir = _get_path(None, 'MNE_DATA', 'montage coregistration') - subjects_dir = op.join(subjects_dir, 'MNE-fsaverage-data') + subjects_dir = _get_path(None, "MNE_DATA", "montage coregistration") + subjects_dir = op.join(subjects_dir, "MNE-fsaverage-data") os.makedirs(subjects_dir, exist_ok=True) else: subjects_dir = str(subjects_dir) @@ -128,5 +128,5 @@ def _set_montage_coreg_path(subjects_dir=None): subjects_dir = _get_create_subjects_dir(subjects_dir) old_subjects_dir = get_subjects_dir(None, raise_error=False) if old_subjects_dir is None: - set_config('SUBJECTS_DIR', subjects_dir) + set_config("SUBJECTS_DIR", subjects_dir) return subjects_dir diff --git a/mne/datasets/_infant/base.py b/mne/datasets/_infant/base.py index c327c4835e0..196faa7bfc2 100644 --- a/mne/datasets/_infant/base.py +++ b/mne/datasets/_infant/base.py @@ -7,9 +7,9 @@ from ..utils import _manifest_check_download from ...utils import verbose, get_subjects_dir, _check_option, 
_validate_type -_AGES = '2wk 1mo 2mo 3mo 4.5mo 6mo 7.5mo 9mo 10.5mo 12mo 15mo 18mo 2yr' +_AGES = "2wk 1mo 2mo 3mo 4.5mo 6mo 7.5mo 9mo 10.5mo 12mo 15mo 18mo 2yr" # https://github.com/christian-oreilly/infant_template_paper/releases -_ORIGINAL_URL = 'https://github.com/christian-oreilly/infant_template_paper/releases/download/v0.1-alpha/%7Bsubject%7D.zip' # noqa: E501 +_ORIGINAL_URL = "https://github.com/christian-oreilly/infant_template_paper/releases/download/v0.1-alpha/%7Bsubject%7D.zip" # noqa: E501 # Formatted the same way as md5sum *.zip on Ubuntu: _ORIGINAL_HASHES = """ 851737d5f8f246883f2aef9819c6ec29 ANTS10-5Months3T.zip @@ -71,23 +71,24 @@ def fetch_infant_template(age, subjects_dir=None, *, verbose=None): # ... names = sorted(name for name in zip.namelist() if not zipfile.Path(zip, name).is_dir()) # noqa: E501 # ... with open(f'{name}.txt', 'w') as fid: # ... fid.write('\n'.join(names)) - _validate_type(age, str, 'age') - _check_option('age', age, _AGES.split()) + _validate_type(age, str, "age") + _check_option("age", age, _AGES.split()) subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) - unit = dict(wk='Weeks', mo='Months', yr='Years')[age[-2:]] - first = age[:-2].split('.')[0] - dash = '-5' if '.5' in age else '-0' - subject = f'ANTS{first}{dash}{unit}3T' + unit = dict(wk="Weeks", mo="Months", yr="Years")[age[-2:]] + first = age[:-2].split(".")[0] + dash = "-5" if ".5" in age else "-0" + subject = f"ANTS{first}{dash}{unit}3T" # Actually get and create the files subj_dir = subjects_dir / subject os.makedirs(subj_dir, exist_ok=True) # .zip -> hash mapping - orig_hashes = dict(line.strip().split()[::-1] - for line in _ORIGINAL_HASHES.strip().splitlines()) + orig_hashes = dict( + line.strip().split()[::-1] for line in _ORIGINAL_HASHES.strip().splitlines() + ) _manifest_check_download( - manifest_path=op.join(_MANIFEST_PATH, f'{subject}.txt'), + manifest_path=op.join(_MANIFEST_PATH, f"{subject}.txt"), destination=subj_dir, url=_ORIGINAL_URL.format(subject=subject), - hash_=orig_hashes[f'{subject}.zip'], + hash_=orig_hashes[f"{subject}.zip"], ) return subject diff --git a/mne/datasets/_phantom/base.py b/mne/datasets/_phantom/base.py index 8785e3018ec..3d8af0e68ac 100644 --- a/mne/datasets/_phantom/base.py +++ b/mne/datasets/_phantom/base.py @@ -43,19 +43,21 @@ def fetch_phantom(kind, subjects_dir=None, *, verbose=None): .. 
versionadded:: 0.24 """ phantoms = dict( - otaniemi=dict(url='https://osf.io/j5czy/download?version=1', - hash='42d17db5b1db3e30327ffb4cf2649de8'), + otaniemi=dict( + url="https://osf.io/j5czy/download?version=1", + hash="42d17db5b1db3e30327ffb4cf2649de8", + ), ) - _validate_type(kind, str, 'kind') - _check_option('kind', kind, list(phantoms)) + _validate_type(kind, str, "kind") + _check_option("kind", kind, list(phantoms)) subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) - subject = f'phantom_{kind}' + subject = f"phantom_{kind}" subject_dir = subjects_dir / subject os.makedirs(subject_dir, exist_ok=True) _manifest_check_download( - manifest_path=op.join(PHANTOM_MANIFEST_PATH, f'{subject}.txt'), + manifest_path=op.join(PHANTOM_MANIFEST_PATH, f"{subject}.txt"), destination=subjects_dir, - url=phantoms[kind]['url'], - hash_=phantoms[kind]['hash'], + url=phantoms[kind]["url"], + hash_=phantoms[kind]["hash"], ) return subject_dir diff --git a/mne/datasets/brainstorm/__init__.py b/mne/datasets/brainstorm/__init__.py index 8dcf9b79811..e97790f52c6 100644 --- a/mne/datasets/brainstorm/__init__.py +++ b/mne/datasets/brainstorm/__init__.py @@ -1,4 +1,3 @@ """Brainstorm datasets.""" -from . import (bst_raw, bst_resting, bst_auditory, bst_phantom_ctf, - bst_phantom_elekta) +from . import bst_raw, bst_resting, bst_auditory, bst_phantom_ctf, bst_phantom_elekta diff --git a/mne/datasets/brainstorm/bst_auditory.py b/mne/datasets/brainstorm/bst_auditory.py index 41c2f078671..a45dc72b5cf 100644 --- a/mne/datasets/brainstorm/bst_auditory.py +++ b/mne/datasets/brainstorm/bst_auditory.py @@ -2,8 +2,12 @@ # # License: BSD-3-Clause from ...utils import verbose -from ..utils import (_get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) _description = """ URL: http://neuroimage.usc.edu/brainstorm/DatasetAuditory @@ -22,26 +26,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_auditory', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_auditory", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_auditory) dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_auditory) dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_auditory') + return _get_version("bst_auditory") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): diff --git a/mne/datasets/brainstorm/bst_phantom_ctf.py b/mne/datasets/brainstorm/bst_phantom_ctf.py index 87300a82971..147626d33b6 100644 --- a/mne/datasets/brainstorm/bst_phantom_ctf.py +++ b/mne/datasets/brainstorm/bst_phantom_ctf.py @@ -2,8 +2,12 @@ # # 
License: BSD-3-Clause from ...utils import verbose -from ..utils import (_get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) _description = """ URL: http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomCtf @@ -11,26 +15,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_phantom_ctf', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_phantom_ctf", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_phantom_ctf) dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_phantom_ctf) dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_phantom_ctf') + return _get_version("bst_phantom_ctf") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): diff --git a/mne/datasets/brainstorm/bst_phantom_elekta.py b/mne/datasets/brainstorm/bst_phantom_elekta.py index abfa5a68aca..8e5b5a8a69c 100644 --- a/mne/datasets/brainstorm/bst_phantom_elekta.py +++ b/mne/datasets/brainstorm/bst_phantom_elekta.py @@ -2,8 +2,12 @@ # # License: BSD-3-Clause from ...utils import verbose -from ..utils import (_get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) _description = """ URL: http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomElekta @@ -11,27 +15,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_phantom_elekta', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_phantom_elekta", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_phantom_elekta) ' - 'dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_phantom_elekta) " "dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_phantom_elekta') + return _get_version("bst_phantom_elekta") -get_version.__doc__ = _version_doc.format(name='brainstorm') 
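Each Brainstorm fetcher in this file threads the same accept flag through to fetch_dataset, which otherwise prompts interactively with _bst_license_text. A non-interactive sketch (using the data still requires agreeing to the Brainstorm license terms):

    from mne.datasets.brainstorm import bst_auditory

    path = bst_auditory.data_path(accept=True)  # skips the interactive prompt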
+get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): diff --git a/mne/datasets/brainstorm/bst_raw.py b/mne/datasets/brainstorm/bst_raw.py index 0616ca176d5..f8d92e0b26c 100644 --- a/mne/datasets/brainstorm/bst_raw.py +++ b/mne/datasets/brainstorm/bst_raw.py @@ -4,11 +4,16 @@ from functools import partial from ...utils import verbose, get_config -from ..utils import (has_dataset, _get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + has_dataset, + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) -has_brainstorm_data = partial(has_dataset, name='bst_raw') +has_brainstorm_data = partial(has_dataset, name="bst_raw") _description = """ URL: http://neuroimage.usc.edu/brainstorm/DatasetMedianNerveCtf @@ -26,26 +31,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_raw', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_raw", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_raw) dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_raw) dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_raw') + return _get_version("bst_raw") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): # noqa: D103 @@ -55,8 +74,7 @@ def description(): # noqa: D103 def _skip_bstraw_data(): - skip_testing = (get_config('MNE_SKIP_TESTING_DATASET_TESTS', 'false') == - 'true') + skip_testing = get_config("MNE_SKIP_TESTING_DATASET_TESTS", "false") == "true" skip = skip_testing or not has_brainstorm_data() return skip @@ -64,5 +82,7 @@ def _skip_bstraw_data(): def requires_bstraw_data(func): """Skip testing data test.""" import pytest - return pytest.mark.skipif(_skip_bstraw_data(), - reason='Requires brainstorm dataset')(func) + + return pytest.mark.skipif( + _skip_bstraw_data(), reason="Requires brainstorm dataset" + )(func) diff --git a/mne/datasets/brainstorm/bst_resting.py b/mne/datasets/brainstorm/bst_resting.py index e0eb226e863..9e2f8f7e73b 100644 --- a/mne/datasets/brainstorm/bst_resting.py +++ b/mne/datasets/brainstorm/bst_resting.py @@ -2,8 +2,12 @@ # # License: BSD-3-Clause from ...utils import verbose -from ..utils import (_get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) _description = """ URL: http://neuroimage.usc.edu/brainstorm/DatasetResting @@ -14,26 +18,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + 
*, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_resting', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_resting", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_resting) dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_resting) dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_resting') + return _get_version("bst_resting") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): diff --git a/mne/datasets/config.py b/mne/datasets/config.py index ec45dbbf91b..7869f97a78e 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # respective repos, and make a new release of the dataset on GitHub. Then # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓ ↓↓↓ -RELEASES = dict(testing='0.146', misc='0.26') +RELEASES = dict(testing="0.146", misc="0.26") TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}' MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}' @@ -109,240 +109,245 @@ # of the downloaded dataset (ex: "MNE_DATASETS_EEGBCI_PATH"). # Testing and misc are at the top as they're updated most often -MNE_DATASETS['testing'] = dict( - archive_name=f'{TESTING_VERSIONED}.tar.gz', - hash='md5:a2e86fe404f4321408b22f38711d11b7', - url=('https://codeload.github.com/mne-tools/mne-testing-data/' - f'tar.gz/{RELEASES["testing"]}'), +MNE_DATASETS["testing"] = dict( + archive_name=f"{TESTING_VERSIONED}.tar.gz", + hash="md5:a2e86fe404f4321408b22f38711d11b7", + url=( + "https://codeload.github.com/mne-tools/mne-testing-data/" + f'tar.gz/{RELEASES["testing"]}' + ), # In case we ever have to resort to osf.io again... 
# archive_name='mne-testing-data.tar.gz', # hash='md5:c805a5fed8ca46f723e7eec828d90824', # url='https://osf.io/dqfgy/download?version=1', # 0.136 - folder_name='MNE-testing-data', - config_key='MNE_DATASETS_TESTING_PATH', + folder_name="MNE-testing-data", + config_key="MNE_DATASETS_TESTING_PATH", ) -MNE_DATASETS['misc'] = dict( - archive_name=f'{MISC_VERSIONED}.tar.gz', # 'mne-misc-data', - hash='md5:868b484fadd73b1d1a3535b7194a0d03', - url=('https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/' - f'{RELEASES["misc"]}'), - folder_name='MNE-misc-data', - config_key='MNE_DATASETS_MISC_PATH' +MNE_DATASETS["misc"] = dict( + archive_name=f"{MISC_VERSIONED}.tar.gz", # 'mne-misc-data', + hash="md5:868b484fadd73b1d1a3535b7194a0d03", + url=( + "https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/" + f'{RELEASES["misc"]}' + ), + folder_name="MNE-misc-data", + config_key="MNE_DATASETS_MISC_PATH", ) -MNE_DATASETS['fnirs_motor'] = dict( - archive_name='MNE-fNIRS-motor-data.tgz', - hash='md5:c4935d19ddab35422a69f3326a01fef8', - url='https://osf.io/dj3eh/download?version=1', - folder_name='MNE-fNIRS-motor-data', - config_key='MNE_DATASETS_FNIRS_MOTOR_PATH', +MNE_DATASETS["fnirs_motor"] = dict( + archive_name="MNE-fNIRS-motor-data.tgz", + hash="md5:c4935d19ddab35422a69f3326a01fef8", + url="https://osf.io/dj3eh/download?version=1", + folder_name="MNE-fNIRS-motor-data", + config_key="MNE_DATASETS_FNIRS_MOTOR_PATH", ) -MNE_DATASETS['ucl_opm_auditory'] = dict( - archive_name='auditory_OPM_stationary.zip', - hash='md5:9ed0d8d554894542b56f8e7c4c0041fe', - url='https://osf.io/download/mwrt3/?version=1', - folder_name='auditory_OPM_stationary', - config_key='MNE_DATASETS_UCL_OPM_AUDITORY_PATH', +MNE_DATASETS["ucl_opm_auditory"] = dict( + archive_name="auditory_OPM_stationary.zip", + hash="md5:9ed0d8d554894542b56f8e7c4c0041fe", + url="https://osf.io/download/mwrt3/?version=1", + folder_name="auditory_OPM_stationary", + config_key="MNE_DATASETS_UCL_OPM_AUDITORY_PATH", ) -MNE_DATASETS['kiloword'] = dict( - archive_name='MNE-kiloword-data.tar.gz', - hash='md5:3a124170795abbd2e48aae8727e719a8', - url='https://osf.io/qkvf9/download?version=1', - folder_name='MNE-kiloword-data', - config_key='MNE_DATASETS_KILOWORD_PATH', +MNE_DATASETS["kiloword"] = dict( + archive_name="MNE-kiloword-data.tar.gz", + hash="md5:3a124170795abbd2e48aae8727e719a8", + url="https://osf.io/qkvf9/download?version=1", + folder_name="MNE-kiloword-data", + config_key="MNE_DATASETS_KILOWORD_PATH", ) -MNE_DATASETS['multimodal'] = dict( - archive_name='MNE-multimodal-data.tar.gz', - hash='md5:26ec847ae9ab80f58f204d09e2c08367', - url='https://ndownloader.figshare.com/files/5999598', - folder_name='MNE-multimodal-data', - config_key='MNE_DATASETS_MULTIMODAL_PATH', +MNE_DATASETS["multimodal"] = dict( + archive_name="MNE-multimodal-data.tar.gz", + hash="md5:26ec847ae9ab80f58f204d09e2c08367", + url="https://ndownloader.figshare.com/files/5999598", + folder_name="MNE-multimodal-data", + config_key="MNE_DATASETS_MULTIMODAL_PATH", ) -MNE_DATASETS['opm'] = dict( - archive_name='MNE-OPM-data.tar.gz', - hash='md5:370ad1dcfd5c47e029e692c85358a374', - url='https://osf.io/p6ae7/download?version=2', - folder_name='MNE-OPM-data', - config_key='MNE_DATASETS_OPM_PATH', +MNE_DATASETS["opm"] = dict( + archive_name="MNE-OPM-data.tar.gz", + hash="md5:370ad1dcfd5c47e029e692c85358a374", + url="https://osf.io/p6ae7/download?version=2", + 
folder_name="MNE-OPM-data", + config_key="MNE_DATASETS_OPM_PATH", ) -MNE_DATASETS['phantom_4dbti'] = dict( - archive_name='MNE-phantom-4DBTi.zip', - hash='md5:938a601440f3ffa780d20a17bae039ff', - url='https://osf.io/v2brw/download?version=2', - folder_name='MNE-phantom-4DBTi', - config_key='MNE_DATASETS_PHANTOM_4DBTI_PATH', +MNE_DATASETS["phantom_4dbti"] = dict( + archive_name="MNE-phantom-4DBTi.zip", + hash="md5:938a601440f3ffa780d20a17bae039ff", + url="https://osf.io/v2brw/download?version=2", + folder_name="MNE-phantom-4DBTi", + config_key="MNE_DATASETS_PHANTOM_4DBTI_PATH", ) -MNE_DATASETS['sample'] = dict( - archive_name='MNE-sample-data-processed.tar.gz', - hash='md5:e8f30c4516abdc12a0c08e6bae57409c', - url='https://osf.io/86qa2/download?version=6', - folder_name='MNE-sample-data', - config_key='MNE_DATASETS_SAMPLE_PATH', +MNE_DATASETS["sample"] = dict( + archive_name="MNE-sample-data-processed.tar.gz", + hash="md5:e8f30c4516abdc12a0c08e6bae57409c", + url="https://osf.io/86qa2/download?version=6", + folder_name="MNE-sample-data", + config_key="MNE_DATASETS_SAMPLE_PATH", ) -MNE_DATASETS['somato'] = dict( - archive_name='MNE-somato-data.tar.gz', - hash='md5:32fd2f6c8c7eb0784a1de6435273c48b', - url='https://osf.io/tp4sg/download?version=7', - folder_name='MNE-somato-data', - config_key='MNE_DATASETS_SOMATO_PATH' +MNE_DATASETS["somato"] = dict( + archive_name="MNE-somato-data.tar.gz", + hash="md5:32fd2f6c8c7eb0784a1de6435273c48b", + url="https://osf.io/tp4sg/download?version=7", + folder_name="MNE-somato-data", + config_key="MNE_DATASETS_SOMATO_PATH", ) -MNE_DATASETS['spm'] = dict( - archive_name='MNE-spm-face.tar.gz', - hash='md5:9f43f67150e3b694b523a21eb929ea75', - url='https://osf.io/je4s8/download?version=2', - folder_name='MNE-spm-face', - config_key='MNE_DATASETS_SPM_FACE_PATH', +MNE_DATASETS["spm"] = dict( + archive_name="MNE-spm-face.tar.gz", + hash="md5:9f43f67150e3b694b523a21eb929ea75", + url="https://osf.io/je4s8/download?version=2", + folder_name="MNE-spm-face", + config_key="MNE_DATASETS_SPM_FACE_PATH", ) # Visual 92 categories has the dataset split into 2 files. # We define a dictionary holding the items with the same # value across both files: folder name and configuration key. 
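Because the two archive parts registered below share folder_name and config_key, both unpack into a single MNE-visual_92_categories-data directory, and the public fetcher hides the split. A sketch, assuming network access:

    import mne

    # fetches and extracts both .tar.gz parts on first use
    path = mne.datasets.visual_92_categories.data_path()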
-MNE_DATASETS['visual_92_categories'] = dict( - folder_name='MNE-visual_92_categories-data', - config_key='MNE_DATASETS_VISUAL_92_CATEGORIES_PATH', +MNE_DATASETS["visual_92_categories"] = dict( + folder_name="MNE-visual_92_categories-data", + config_key="MNE_DATASETS_VISUAL_92_CATEGORIES_PATH", ) -MNE_DATASETS['visual_92_categories_1'] = dict( - archive_name='MNE-visual_92_categories-data-part1.tar.gz', - hash='md5:74f50bbeb65740903eadc229c9fa759f', - url='https://osf.io/8ejrs/download?version=1', - folder_name='MNE-visual_92_categories-data', - config_key='MNE_DATASETS_VISUAL_92_CATEGORIES_PATH', +MNE_DATASETS["visual_92_categories_1"] = dict( + archive_name="MNE-visual_92_categories-data-part1.tar.gz", + hash="md5:74f50bbeb65740903eadc229c9fa759f", + url="https://osf.io/8ejrs/download?version=1", + folder_name="MNE-visual_92_categories-data", + config_key="MNE_DATASETS_VISUAL_92_CATEGORIES_PATH", ) -MNE_DATASETS['visual_92_categories_2'] = dict( - archive_name='MNE-visual_92_categories-data-part2.tar.gz', - hash='md5:203410a98afc9df9ae8ba9f933370e20', - url='https://osf.io/t4yjp/download?version=1', - folder_name='MNE-visual_92_categories-data', - config_key='MNE_DATASETS_VISUAL_92_CATEGORIES_PATH', +MNE_DATASETS["visual_92_categories_2"] = dict( + archive_name="MNE-visual_92_categories-data-part2.tar.gz", + hash="md5:203410a98afc9df9ae8ba9f933370e20", + url="https://osf.io/t4yjp/download?version=1", + folder_name="MNE-visual_92_categories-data", + config_key="MNE_DATASETS_VISUAL_92_CATEGORIES_PATH", ) -MNE_DATASETS['mtrf'] = dict( - archive_name='mTRF_1.5.zip', - hash='md5:273a390ebbc48da2c3184b01a82e4636', - url='https://osf.io/h85s2/download?version=1', - folder_name='mTRF_1.5', - config_key='MNE_DATASETS_MTRF_PATH' +MNE_DATASETS["mtrf"] = dict( + archive_name="mTRF_1.5.zip", + hash="md5:273a390ebbc48da2c3184b01a82e4636", + url="https://osf.io/h85s2/download?version=1", + folder_name="mTRF_1.5", + config_key="MNE_DATASETS_MTRF_PATH", ) -MNE_DATASETS['refmeg_noise'] = dict( - archive_name='sample_reference_MEG_noise-raw.zip', - hash='md5:779fecd890d98b73a4832e717d7c7c45', - url='https://osf.io/drt6v/download?version=1', - folder_name='MNE-refmeg-noise-data', - config_key='MNE_DATASETS_REFMEG_NOISE_PATH' +MNE_DATASETS["refmeg_noise"] = dict( + archive_name="sample_reference_MEG_noise-raw.zip", + hash="md5:779fecd890d98b73a4832e717d7c7c45", + url="https://osf.io/drt6v/download?version=1", + folder_name="MNE-refmeg-noise-data", + config_key="MNE_DATASETS_REFMEG_NOISE_PATH", ) -MNE_DATASETS['ssvep'] = dict( - archive_name='ssvep_example_data.zip', - hash='md5:af866bbc0f921114ac9d683494fe87d6', - url='https://osf.io/z8h6k/download?version=5', - folder_name='ssvep-example-data', - config_key='MNE_DATASETS_SSVEP_PATH' +MNE_DATASETS["ssvep"] = dict( + archive_name="ssvep_example_data.zip", + hash="md5:af866bbc0f921114ac9d683494fe87d6", + url="https://osf.io/z8h6k/download?version=5", + folder_name="ssvep-example-data", + config_key="MNE_DATASETS_SSVEP_PATH", ) -MNE_DATASETS['erp_core'] = dict( - archive_name='MNE-ERP-CORE-data.tar.gz', - hash='md5:5866c0d6213bd7ac97f254c776f6c4b1', - url='https://osf.io/rzgba/download?version=1', - folder_name='MNE-ERP-CORE-data', - config_key='MNE_DATASETS_ERP_CORE_PATH', +MNE_DATASETS["erp_core"] = dict( + archive_name="MNE-ERP-CORE-data.tar.gz", + hash="md5:5866c0d6213bd7ac97f254c776f6c4b1", + url="https://osf.io/rzgba/download?version=1", + 
folder_name="MNE-ERP-CORE-data", + config_key="MNE_DATASETS_ERP_CORE_PATH", ) -MNE_DATASETS['epilepsy_ecog'] = dict( - archive_name='MNE-epilepsy-ecog-data.tar.gz', - hash='md5:ffb139174afa0f71ec98adbbb1729dea', - url='https://osf.io/z4epq/download?version=1', - folder_name='MNE-epilepsy-ecog-data', - config_key='MNE_DATASETS_EPILEPSY_ECOG_PATH', +MNE_DATASETS["epilepsy_ecog"] = dict( + archive_name="MNE-epilepsy-ecog-data.tar.gz", + hash="md5:ffb139174afa0f71ec98adbbb1729dea", + url="https://osf.io/z4epq/download?version=1", + folder_name="MNE-epilepsy-ecog-data", + config_key="MNE_DATASETS_EPILEPSY_ECOG_PATH", ) # Fieldtrip CMC dataset -MNE_DATASETS['fieldtrip_cmc'] = dict( - archive_name='SubjectCMC.zip', - hash='md5:6f9fd6520f9a66e20994423808d2528c', - url='https://osf.io/j9b6s/download?version=1', - folder_name='MNE-fieldtrip_cmc-data', - config_key='MNE_DATASETS_FIELDTRIP_CMC_PATH' +MNE_DATASETS["fieldtrip_cmc"] = dict( + archive_name="SubjectCMC.zip", + hash="md5:6f9fd6520f9a66e20994423808d2528c", + url="https://osf.io/j9b6s/download?version=1", + folder_name="MNE-fieldtrip_cmc-data", + config_key="MNE_DATASETS_FIELDTRIP_CMC_PATH", ) # brainstorm datasets: -MNE_DATASETS['bst_auditory'] = dict( - archive_name='bst_auditory.tar.gz', - hash='md5:fa371a889a5688258896bfa29dd1700b', - url='https://osf.io/5t9n8/download?version=1', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_auditory"] = dict( + archive_name="bst_auditory.tar.gz", + hash="md5:fa371a889a5688258896bfa29dd1700b", + url="https://osf.io/5t9n8/download?version=1", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) -MNE_DATASETS['bst_phantom_ctf'] = dict( - archive_name='bst_phantom_ctf.tar.gz', - hash='md5:80819cb7f5b92d1a5289db3fb6acb33c', - url='https://osf.io/sxr8y/download?version=1', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_phantom_ctf"] = dict( + archive_name="bst_phantom_ctf.tar.gz", + hash="md5:80819cb7f5b92d1a5289db3fb6acb33c", + url="https://osf.io/sxr8y/download?version=1", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) -MNE_DATASETS['bst_phantom_elekta'] = dict( - archive_name='bst_phantom_elekta.tar.gz', - hash='md5:1badccbe17998d18cc373526e86a7aaf', - url='https://osf.io/dpcku/download?version=1', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_phantom_elekta"] = dict( + archive_name="bst_phantom_elekta.tar.gz", + hash="md5:1badccbe17998d18cc373526e86a7aaf", + url="https://osf.io/dpcku/download?version=1", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) -MNE_DATASETS['bst_raw'] = dict( - archive_name='bst_raw.tar.gz', - hash='md5:fa2efaaec3f3d462b319bc24898f440c', - url='https://osf.io/9675n/download?version=2', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_raw"] = dict( + archive_name="bst_raw.tar.gz", + hash="md5:fa2efaaec3f3d462b319bc24898f440c", + url="https://osf.io/9675n/download?version=2", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) -MNE_DATASETS['bst_resting'] = dict( - archive_name='bst_resting.tar.gz', - hash='md5:70fc7bf9c3b97c4f2eab6260ee4a0430', - url='https://osf.io/m7bd3/download?version=3', - 
folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_resting"] = dict( + archive_name="bst_resting.tar.gz", + hash="md5:70fc7bf9c3b97c4f2eab6260ee4a0430", + url="https://osf.io/m7bd3/download?version=3", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) # HF-SEF -MNE_DATASETS['hf_sef_raw'] = dict( - archive_name='hf_sef_raw.tar.gz', - hash='md5:33934351e558542bafa9b262ac071168', - url='https://zenodo.org/record/889296/files/hf_sef_raw.tar.gz', - folder_name='hf_sef', - config_key='MNE_DATASETS_HF_SEF_PATH', +MNE_DATASETS["hf_sef_raw"] = dict( + archive_name="hf_sef_raw.tar.gz", + hash="md5:33934351e558542bafa9b262ac071168", + url="https://zenodo.org/record/889296/files/hf_sef_raw.tar.gz", + folder_name="hf_sef", + config_key="MNE_DATASETS_HF_SEF_PATH", ) -MNE_DATASETS['hf_sef_evoked'] = dict( - archive_name='hf_sef_evoked.tar.gz', - hash='md5:13d34cb5db584e00868677d8fb0aab2b', +MNE_DATASETS["hf_sef_evoked"] = dict( + archive_name="hf_sef_evoked.tar.gz", + hash="md5:13d34cb5db584e00868677d8fb0aab2b", # Zenodo can be slow, so we use the OSF mirror # url=('https://zenodo.org/record/3523071/files/' # 'hf_sef_evoked.tar.gz'), - url='https://osf.io/25f8d/download?version=2', - folder_name='hf_sef', - config_key='MNE_DATASETS_HF_SEF_PATH', + url="https://osf.io/25f8d/download?version=2", + folder_name="hf_sef", + config_key="MNE_DATASETS_HF_SEF_PATH", ) # "fake" dataset (for testing) -MNE_DATASETS['fake'] = dict( - archive_name='foo.tgz', - hash='md5:3194e9f7b46039bb050a74f3e1ae9908', - url=('https://github.com/mne-tools/mne-testing-data/raw/master/' - 'datasets/foo.tgz'), - folder_name='foo', - config_key='MNE_DATASETS_FAKE_PATH' +MNE_DATASETS["fake"] = dict( + archive_name="foo.tgz", + hash="md5:3194e9f7b46039bb050a74f3e1ae9908", + url=( + "https://github.com/mne-tools/mne-testing-data/raw/master/" "datasets/foo.tgz" + ), + folder_name="foo", + config_key="MNE_DATASETS_FAKE_PATH", ) # eyelink dataset -MNE_DATASETS['eyelink'] = dict( - archive_name='eyelink_example_data.zip', - hash='md5:081950c05f35267458d9c751e178f161', - url=('https://osf.io/r5ndq/download?version=1'), - folder_name='eyelink-example-data', - config_key='MNE_DATASETS_EYELINK_PATH' +MNE_DATASETS["eyelink"] = dict( + archive_name="eyelink_example_data.zip", + hash="md5:081950c05f35267458d9c751e178f161", + url=("https://osf.io/r5ndq/download?version=1"), + folder_name="eyelink-example-data", + config_key="MNE_DATASETS_EYELINK_PATH", ) diff --git a/mne/datasets/eegbci/eegbci.py b/mne/datasets/eegbci/eegbci.py index fd2b0a71e24..4d5b3f9b7d6 100644 --- a/mne/datasets/eegbci/eegbci.py +++ b/mne/datasets/eegbci/eegbci.py @@ -10,8 +10,7 @@ import time from ...utils import _url_to_local_path, verbose, logger -from ..utils import (_do_path_update, _get_path, _log_time_size, - _downloader_params) +from ..utils import _do_path_update, _get_path, _log_time_size, _downloader_params # TODO: remove try/except when our min version is py 3.9 try: @@ -20,12 +19,11 @@ from importlib_resources import files -EEGMI_URL = 'https://physionet.org/files/eegmmidb/1.0.0/' +EEGMI_URL = "https://physionet.org/files/eegmmidb/1.0.0/" @verbose -def data_path(url, path=None, force_update=False, update_path=None, *, - verbose=None): +def data_path(url, path=None, force_update=False, update_path=None, *, verbose=None): """Get path to local copy of EEGMMI dataset URL. 
This is a low-level function useful for getting a local copy of a @@ -73,10 +71,10 @@ def data_path(url, path=None, force_update=False, update_path=None, *, """ # noqa: E501 import pooch - key = 'MNE_DATASETS_EEGBCI_PATH' - name = 'EEGBCI' + key = "MNE_DATASETS_EEGBCI_PATH" + name = "EEGBCI" path = _get_path(path, key, name) - fname = 'MNE-eegbci-data' + fname = "MNE-eegbci-data" destination = _url_to_local_path(url, op.join(path, fname)) destinations = [destination] @@ -101,8 +99,15 @@ def data_path(url, path=None, force_update=False, update_path=None, *, @verbose -def load_data(subject, runs, path=None, force_update=False, update_path=None, - base_url=EEGMI_URL, verbose=None): # noqa: D301 +def load_data( + subject, + runs, + path=None, + force_update=False, + update_path=None, + base_url=EEGMI_URL, + verbose=None, +): # noqa: D301 """Get paths to local copies of EEGBCI dataset files. This will fetch data for the EEGBCI dataset :footcite:`SchalkEtAl2004`, which is also @@ -165,43 +170,46 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, .. footbibliography:: """ # noqa: E501 import pooch + t0 = time.time() - if not hasattr(runs, '__iter__'): + if not hasattr(runs, "__iter__"): runs = [runs] # get local storage path - config_key = 'MNE_DATASETS_EEGBCI_PATH' - folder = 'MNE-eegbci-data' - name = 'EEGBCI' + config_key = "MNE_DATASETS_EEGBCI_PATH" + folder = "MNE-eegbci-data" + name = "EEGBCI" path = _get_path(path, config_key, name) # extract path parts - pattern = r'(?:https?://.*)(files)/(eegmmidb)/(\d+\.\d+\.\d+)/?' + pattern = r"(?:https?://.*)(files)/(eegmmidb)/(\d+\.\d+\.\d+)/?" match = re.compile(pattern).match(base_url) if match is None: - raise ValueError('base_url does not match the expected EEGMI folder ' - 'structure. Please notify MNE-Python developers.') + raise ValueError( + "base_url does not match the expected EEGMI folder " + "structure. Please notify MNE-Python developers." + ) base_path = op.join(path, folder, *match.groups()) # create the download manager fetcher = pooch.create( path=base_path, base_url=base_url, - version=None, # Data versioning is decoupled from MNE-Python version. + version=None, # Data versioning is decoupled from MNE-Python version. registry=None, # Registry is loaded from file, below. 
- retry_if_failed=2 # 2 retries = 3 total attempts + retry_if_failed=2, # 2 retries = 3 total attempts ) # load the checksum registry - registry = files('mne').joinpath('data', 'eegbci_checksums.txt') + registry = files("mne").joinpath("data", "eegbci_checksums.txt") fetcher.load_registry(registry) # fetch the file(s) data_paths = [] sz = 0 for run in runs: - file_part = f'S{subject:03d}/S{subject:03d}R{run:02d}.edf' + file_part = f"S{subject:03d}/S{subject:03d}R{run:02d}.edf" destination = Path(base_path, file_part) data_paths.append(destination) if destination.exists(): @@ -210,7 +218,7 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, else: continue if sz == 0: # log once - logger.info('Downloading EEGBCI data') + logger.info("Downloading EEGBCI data") fetcher.fetch(file_part) # update path in config if desired sz += destination.stat().st_size @@ -230,11 +238,11 @@ def standardize(raw): """ rename = dict() for name in raw.ch_names: - std_name = name.strip('.') + std_name = name.strip(".") std_name = std_name.upper() - if std_name.endswith('Z'): - std_name = std_name[:-1] + 'z' - if std_name.startswith('FP'): - std_name = 'Fp' + std_name[2:] + if std_name.endswith("Z"): + std_name = std_name[:-1] + "z" + if std_name.startswith("FP"): + std_name = "Fp" + std_name[2:] rename[name] = std_name raw.rename_channels(rename) diff --git a/mne/datasets/eegbci/tests/test_eegbci.py b/mne/datasets/eegbci/tests/test_eegbci.py index e60988ff36c..c59c6802ede 100644 --- a/mne/datasets/eegbci/tests/test_eegbci.py +++ b/mne/datasets/eegbci/tests/test_eegbci.py @@ -8,7 +8,6 @@ def test_eegbci_download(tmp_path, fake_retrieve): """Test Sleep Physionet URL handling.""" for subj in range(4): - fnames = eegbci.load_data( - subj + 1, runs=[3], path=tmp_path, update_path=False) + fnames = eegbci.load_data(subj + 1, runs=[3], path=tmp_path, update_path=False) assert len(fnames) == 1, subj assert fake_retrieve.call_count == 4 diff --git a/mne/datasets/epilepsy_ecog/_data.py b/mne/datasets/epilepsy_ecog/_data.py index 33535c1aff0..b6cc93b92bd 100644 --- a/mne/datasets/epilepsy_ecog/_data.py +++ b/mne/datasets/epilepsy_ecog/_data.py @@ -3,25 +3,30 @@ # License: BSD Style. 
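The EEGBCI helpers above compose naturally: load_data returns per-run EDF paths under MNE-eegbci-data, and standardize() repairs the dataset's trailing-dot channel names so a standard montage can be attached. A short sketch (network access assumed; run 3 is one of the motor-execution runs):

    import mne
    from mne.datasets import eegbci

    paths = eegbci.load_data(1, runs=[3])  # subject 1, run 3
    raw = mne.io.read_raw_edf(paths[0], preload=True)
    eegbci.standardize(raw)                # e.g. 'Cz..' -> 'Cz', 'Fp1.' -> 'Fp1'
    raw.set_montage("standard_1005")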
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='epilepsy_ecog', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="epilepsy_ecog", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='epilepsy_ecog', conf='MNE_DATASETS_EPILEPSY_ECOG_PATH') + name="epilepsy_ecog", conf="MNE_DATASETS_EPILEPSY_ECOG_PATH" +) def get_version(): # noqa: D103 - return _get_version('epilepsy_ecog') + return _get_version("epilepsy_ecog") -get_version.__doc__ = _version_doc.format(name='epilepsy_ecog') +get_version.__doc__ = _version_doc.format(name="epilepsy_ecog") diff --git a/mne/datasets/erp_core/erp_core.py b/mne/datasets/erp_core/erp_core.py index 76bd62ca209..8f3aa1e2663 100644 --- a/mne/datasets/erp_core/erp_core.py +++ b/mne/datasets/erp_core/erp_core.py @@ -1,23 +1,28 @@ from ...utils import verbose -from ..utils import (_data_path_doc, - _get_version, _version_doc, _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='erp_core', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="erp_core", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='erp_core', - conf='MNE_DATASETS_ERP_CORE_PATH') +data_path.__doc__ = _data_path_doc.format( + name="erp_core", conf="MNE_DATASETS_ERP_CORE_PATH" +) def get_version(): # noqa: D103 - return _get_version('erp_core') + return _get_version("erp_core") -get_version.__doc__ = _version_doc.format(name='erp_core') +get_version.__doc__ = _version_doc.format(name="erp_core") diff --git a/mne/datasets/eyelink/eyelink.py b/mne/datasets/eyelink/eyelink.py index a08e338ab33..f0a349c3c16 100644 --- a/mne/datasets/eyelink/eyelink.py +++ b/mne/datasets/eyelink/eyelink.py @@ -2,25 +2,30 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='eyelink', processor='unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="eyelink", + processor="unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='eyelink', - conf='MNE_DATASETS_EYELINK_PATH') +data_path.__doc__ = _data_path_doc.format( + name="eyelink", conf="MNE_DATASETS_EYELINK_PATH" +) def get_version(): # noqa: D103 - return _get_version('eyelink') + return _get_version("eyelink") -get_version.__doc__ = _version_doc.format(name='eyelink') +get_version.__doc__ = _version_doc.format(name="eyelink") diff --git a/mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py b/mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py index d7abe1c68f0..cdce53d57a8 100644 --- a/mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py +++ b/mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py @@ -3,25 +3,30 @@ # # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, - _get_version, _version_doc, _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='fieldtrip_cmc', processor='nested_unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="fieldtrip_cmc", + processor="nested_unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='fieldtrip_cmc', conf='MNE_DATASETS_FIELDTRIP_CMC_PATH') + name="fieldtrip_cmc", conf="MNE_DATASETS_FIELDTRIP_CMC_PATH" +) def get_version(): # noqa: D103 - return _get_version('fieldtrip_cmc') + return _get_version("fieldtrip_cmc") -get_version.__doc__ = _version_doc.format(name='fieldtrip_cmc') +get_version.__doc__ = _version_doc.format(name="fieldtrip_cmc") diff --git a/mne/datasets/fnirs_motor/fnirs_motor.py b/mne/datasets/fnirs_motor/fnirs_motor.py index ce0294f9f4e..2c49a32c891 100644 --- a/mne/datasets/fnirs_motor/fnirs_motor.py +++ b/mne/datasets/fnirs_motor/fnirs_motor.py @@ -2,25 +2,30 @@ # License: BSD Style. 
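# The dataset modules being reformatted here (epilepsy_ecog, erp_core,
# eyelink, fieldtrip_cmc, fnirs_motor, ...) all expose the same two entry
# points, so usage is identical across them; a sketch with fnirs_motor
# chosen arbitrarily:
from mne.datasets import fnirs_motor

path = fnirs_motor.data_path()  # downloads on first call, then reuses the cache
print(fnirs_motor.get_version())  # version recorded in the dataset's version.txt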
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='fnirs_motor', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="fnirs_motor", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='fnirs_motor', - conf='MNE_DATASETS_FNIRS_MOTOR_PATH') +data_path.__doc__ = _data_path_doc.format( + name="fnirs_motor", conf="MNE_DATASETS_FNIRS_MOTOR_PATH" +) def get_version(): # noqa: D103 - return _get_version('fnirs_motor') + return _get_version("fnirs_motor") -get_version.__doc__ = _version_doc.format(name='fnirs_motor') +get_version.__doc__ = _version_doc.format(name="fnirs_motor") diff --git a/mne/datasets/hf_sef/hf_sef.py b/mne/datasets/hf_sef/hf_sef.py index 401c3636017..66c25ad12be 100644 --- a/mne/datasets/hf_sef/hf_sef.py +++ b/mne/datasets/hf_sef/hf_sef.py @@ -11,8 +11,9 @@ @verbose -def data_path(dataset='evoked', path=None, force_update=False, - update_path=True, *, verbose=None): +def data_path( + dataset="evoked", path=None, force_update=False, update_path=True, *, verbose=None +): """Get path to local copy of the high frequency SEF dataset. Gets a local copy of the high frequency SEF MEG dataset @@ -46,33 +47,38 @@ def data_path(dataset='evoked', path=None, force_update=False, ---------- .. 
footbibliography::
    """
-    _check_option('dataset', dataset, ('evoked', 'raw'))
-    if dataset == 'raw':
-        data_dict = MNE_DATASETS['hf_sef_raw']
-        data_dict['dataset_name'] = 'hf_sef_raw'
+    _check_option("dataset", dataset, ("evoked", "raw"))
+    if dataset == "raw":
+        data_dict = MNE_DATASETS["hf_sef_raw"]
+        data_dict["dataset_name"] = "hf_sef_raw"
     else:
-        data_dict = MNE_DATASETS['hf_sef_evoked']
-        data_dict['dataset_name'] = 'hf_sef_evoked'
-    config_key = data_dict['config_key']
-    folder_name = data_dict['folder_name']
+        data_dict = MNE_DATASETS["hf_sef_evoked"]
+        data_dict["dataset_name"] = "hf_sef_evoked"
+    config_key = data_dict["config_key"]
+    folder_name = data_dict["folder_name"]

     # get download path for specific dataset
     path = _get_path(path=path, key=config_key, name=folder_name)
     final_path = op.join(path, folder_name)
-    megdir = op.join(final_path, 'MEG', 'subject_a')
-    has_raw = (dataset == 'raw' and op.isdir(megdir) and
-               any('raw' in filename for filename in os.listdir(megdir)))
-    has_evoked = (dataset == 'evoked' and
-                  op.isdir(op.join(final_path, 'subjects')))
+    megdir = op.join(final_path, "MEG", "subject_a")
+    has_raw = (
+        dataset == "raw"
+        and op.isdir(megdir)
+        and any("raw" in filename for filename in os.listdir(megdir))
+    )
+    has_evoked = dataset == "evoked" and op.isdir(op.join(final_path, "subjects"))
     # data already present and no force_update requested: just update the
     # config path and return early
     if (has_raw or has_evoked) and not force_update:
-        _do_path_update(path, update_path, config_key,
-                        folder_name)
+        _do_path_update(path, update_path, config_key, folder_name)
         return final_path
     # instantiate processor that unzips file
-    data_path = _download_mne_dataset(name=data_dict['dataset_name'],
-                                      processor='untar', path=path,
-                                      force_update=force_update,
-                                      update_path=update_path, download=True)
+    data_path = _download_mne_dataset(
+        name=data_dict["dataset_name"],
+        processor="untar",
+        path=path,
+        force_update=force_update,
+        update_path=update_path,
+        download=True,
+    )
     return data_path

diff --git a/mne/datasets/kiloword/kiloword.py b/mne/datasets/kiloword/kiloword.py
index c011365bad3..c6f437ab36e 100644
--- a/mne/datasets/kiloword/kiloword.py
+++ b/mne/datasets/kiloword/kiloword.py
@@ -1,12 +1,13 @@
 # License: BSD Style.

 from ...utils import verbose
-from ..utils import (_get_version, _version_doc, _download_mne_dataset)
+from ..utils import _get_version, _version_doc, _download_mne_dataset


 @verbose
-def data_path(path=None, force_update=False, update_path=True,
-              download=True, *, verbose=None):
+def data_path(
+    path=None, force_update=False, update_path=True, download=True, *, verbose=None
+):
     """Get path to local copy of the kiloword dataset.

     This is the dataset from :footcite:`DufauEtAl2015`.
@@ -44,14 +45,18 @@ def data_path(path=None, force_update=False, update_path=True,
     ..
footbibliography:: """ return _download_mne_dataset( - name='kiloword', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="kiloword", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) def get_version(): """Get dataset version.""" - return _get_version('kiloword') + return _get_version("kiloword") -get_version.__doc__ = _version_doc.format(name='kiloword') +get_version.__doc__ = _version_doc.format(name="kiloword") diff --git a/mne/datasets/limo/limo.py b/mne/datasets/limo/limo.py index e0f1d0f9fa9..47055a0bd91 100644 --- a/mne/datasets/limo/limo.py +++ b/mne/datasets/limo/limo.py @@ -12,17 +12,17 @@ from ...epochs import EpochsArray from ...io.meas_info import create_info from ...utils import _check_pandas_installed, verbose, logger -from ..utils import (_get_path, _do_path_update, _log_time_size, - _downloader_params) +from ..utils import _get_path, _do_path_update, _log_time_size, _downloader_params # root url for LIMO files -root_url = 'https://files.de-1.osf.io/v1/resources/52rea/providers/osfstorage/' +root_url = "https://files.de-1.osf.io/v1/resources/52rea/providers/osfstorage/" @verbose -def data_path(subject, path=None, force_update=False, update_path=None, *, - verbose=None): +def data_path( + subject, path=None, force_update=False, update_path=None, *, verbose=None +): """Get path to local copy of LIMO dataset URL. This is a low-level function useful for getting a local copy of the @@ -69,110 +69,183 @@ def data_path(subject, path=None, force_update=False, update_path=None, *, .. footbibliography:: """ # noqa: E501 import pooch + t0 = time.time() downloader = pooch.HTTPDownloader(**_downloader_params()) # local storage patch - config_key = 'MNE_DATASETS_LIMO_PATH' - name = 'LIMO' - subj = f'S{subject}' + config_key = "MNE_DATASETS_LIMO_PATH" + name = "LIMO" + subj = f"S{subject}" path = _get_path(path, config_key, name) - base_path = op.join(path, 'MNE-limo-data') + base_path = op.join(path, "MNE-limo-data") subject_path = op.join(base_path, subj) # the remote URLs are in the form of UUIDs: urls = dict( - S18={'Yr.mat': '5cf839833a4d9500178a6ff8', - 'LIMO.mat': '5cf83907e650a2001ad592e4'}, - S17={'Yr.mat': '5cf838e83a4d9500168aeb76', - 'LIMO.mat': '5cf83867a542b80019c87602'}, - S16={'Yr.mat': '5cf83857e650a20019d5778f', - 'LIMO.mat': '5cf837dc3a4d9500188a64fe'}, - S15={'Yr.mat': '5cf837cce650a2001ad591e8', - 'LIMO.mat': '5cf83758a542b8001ac7d11d'}, - S14={'Yr.mat': '5cf837493a4d9500198a938f', - 'LIMO.mat': '5cf836e4a542b8001bc7cc53'}, - S13={'Yr.mat': '5cf836d23a4d9500178a6df7', - 'LIMO.mat': '5cf836543a4d9500168ae7cb'}, - S12={'Yr.mat': '5cf83643d4c7d700193e5954', - 'LIMO.mat': '5cf835193a4d9500178a6c92'}, - S11={'Yr.mat': '5cf8356ea542b8001cc81517', - 'LIMO.mat': '5cf834f7d4c7d700163daab8'}, - S10={'Yr.mat': '5cf833b0e650a20019d57454', - 'LIMO.mat': '5cf83204e650a20018d59eb2'}, - S9={'Yr.mat': '5cf83201a542b8001cc811cf', - 'LIMO.mat': '5cf8316c3a4d9500168ae13b'}, - S8={'Yr.mat': '5cf8326ce650a20017d60373', - 'LIMO.mat': '5cf8316d3a4d9500198a8dc5'}, - S7={'Yr.mat': '5cf834a03a4d9500168ae59b', - 'LIMO.mat': '5cf83069e650a20017d600d7'}, - S6={'Yr.mat': '5cf830e6a542b80019c86a70', - 'LIMO.mat': '5cf83057a542b80019c869ca'}, - S5={'Yr.mat': '5cf8115be650a20018d58041', - 'LIMO.mat': '5cf80c0bd4c7d700193e213c'}, - S4={'Yr.mat': '5cf810c9a542b80019c8450a', - 'LIMO.mat': '5cf80bf83a4d9500198a6eb4'}, - S3={'Yr.mat': '5cf80c55d4c7d700163d8f52', - 
'LIMO.mat': '5cf80bdea542b80019c83cab'}, - S2={'Yr.mat': '5cde827123fec40019e01300', - 'LIMO.mat': '5cde82682a50c4001677c259'}, - S1={'Yr.mat': '5d6d3071536cf5001a8b0c78', - 'LIMO.mat': '5d6d305f6f41fc001a3151d8'}, + S18={ + "Yr.mat": "5cf839833a4d9500178a6ff8", + "LIMO.mat": "5cf83907e650a2001ad592e4", + }, + S17={ + "Yr.mat": "5cf838e83a4d9500168aeb76", + "LIMO.mat": "5cf83867a542b80019c87602", + }, + S16={ + "Yr.mat": "5cf83857e650a20019d5778f", + "LIMO.mat": "5cf837dc3a4d9500188a64fe", + }, + S15={ + "Yr.mat": "5cf837cce650a2001ad591e8", + "LIMO.mat": "5cf83758a542b8001ac7d11d", + }, + S14={ + "Yr.mat": "5cf837493a4d9500198a938f", + "LIMO.mat": "5cf836e4a542b8001bc7cc53", + }, + S13={ + "Yr.mat": "5cf836d23a4d9500178a6df7", + "LIMO.mat": "5cf836543a4d9500168ae7cb", + }, + S12={ + "Yr.mat": "5cf83643d4c7d700193e5954", + "LIMO.mat": "5cf835193a4d9500178a6c92", + }, + S11={ + "Yr.mat": "5cf8356ea542b8001cc81517", + "LIMO.mat": "5cf834f7d4c7d700163daab8", + }, + S10={ + "Yr.mat": "5cf833b0e650a20019d57454", + "LIMO.mat": "5cf83204e650a20018d59eb2", + }, + S9={ + "Yr.mat": "5cf83201a542b8001cc811cf", + "LIMO.mat": "5cf8316c3a4d9500168ae13b", + }, + S8={ + "Yr.mat": "5cf8326ce650a20017d60373", + "LIMO.mat": "5cf8316d3a4d9500198a8dc5", + }, + S7={ + "Yr.mat": "5cf834a03a4d9500168ae59b", + "LIMO.mat": "5cf83069e650a20017d600d7", + }, + S6={ + "Yr.mat": "5cf830e6a542b80019c86a70", + "LIMO.mat": "5cf83057a542b80019c869ca", + }, + S5={ + "Yr.mat": "5cf8115be650a20018d58041", + "LIMO.mat": "5cf80c0bd4c7d700193e213c", + }, + S4={ + "Yr.mat": "5cf810c9a542b80019c8450a", + "LIMO.mat": "5cf80bf83a4d9500198a6eb4", + }, + S3={ + "Yr.mat": "5cf80c55d4c7d700163d8f52", + "LIMO.mat": "5cf80bdea542b80019c83cab", + }, + S2={ + "Yr.mat": "5cde827123fec40019e01300", + "LIMO.mat": "5cde82682a50c4001677c259", + }, + S1={ + "Yr.mat": "5d6d3071536cf5001a8b0c78", + "LIMO.mat": "5d6d305f6f41fc001a3151d8", + }, ) # these can't be in the registry file (mne/data/dataset_checksums.txt) # because of filename duplication hashes = dict( - S18={'Yr.mat': 'md5:87f883d442737971a80fc0a35d057e51', - 'LIMO.mat': 'md5:8b4879646f65d7876fa4adf2e40162c5'}, - S17={'Yr.mat': 'md5:7b667ec9eefd7a9996f61ae270e295ee', - 'LIMO.mat': 'md5:22eaca4e6fad54431fd61b307fc426b8'}, - S16={'Yr.mat': 'md5:c877afdb4897426421577e863a45921a', - 'LIMO.mat': 'md5:86672d7afbea1e8c39305bc3f852c8c2'}, - S15={'Yr.mat': 'md5:eea9e0140af598fefc08c886a6f05de5', - 'LIMO.mat': 'md5:aed5cb71ddbfd27c6a3ac7d3e613d07f'}, - S14={'Yr.mat': 'md5:8bd842cfd8588bd5d32e72fdbe70b66e', - 'LIMO.mat': 'md5:1e07d1f36f2eefad435a77530daf2680'}, - S13={'Yr.mat': 'md5:d7925d2af7288b8a5186dfb5dbb63d34', - 'LIMO.mat': 'md5:ba891015d2f9e447955fffa9833404ca'}, - S12={'Yr.mat': 'md5:0e1d05beaa4bf2726e0d0671b78fe41e', - 'LIMO.mat': 'md5:423fd479d71097995b6614ecb11df9ad'}, - S11={'Yr.mat': 'md5:1b0016fb9832e43b71f79c1992fcbbb1', - 'LIMO.mat': 'md5:1a281348c2a41ee899f42731d30cda70'}, - S10={'Yr.mat': 'md5:13c66f60e241b9a9cc576eaf1b55a417', - 'LIMO.mat': 'md5:3c4b41e221eb352a21bbef1a7e006f06'}, - S9={'Yr.mat': 'md5:3ae1d9c3a1d9325deea2f2dddd1ab507', - 'LIMO.mat': 'md5:5e204e2a4bcfe4f535b4b1af469b37f7'}, - S8={'Yr.mat': 'md5:7e9adbca4e03d8d7ce8ea07ccecdc8fd', - 'LIMO.mat': 'md5:88313c21d34428863590e586b2bc3408'}, - S7={'Yr.mat': 'md5:6b5290a6725ecebf1022d5d2789b186d', - 'LIMO.mat': 'md5:8c769219ebc14ce3f595063e84bfc0a9'}, - S6={'Yr.mat': 'md5:420c858a8340bf7c28910b7b0425dc5d', - 'LIMO.mat': 'md5:9cf4e1a405366d6bd0cc6d996e32fd63'}, - S5={'Yr.mat': 'md5:946436cfb474c8debae56ffb1685ecf3', - 
'LIMO.mat': 'md5:241fac95d3a79d2cea081391fb7078bd'}, - S4={'Yr.mat': 'md5:c8216af78ac87b739e86e57b345cafdd', - 'LIMO.mat': 'md5:8e10ef36c2e075edc2f787581ba33459'}, - S3={'Yr.mat': 'md5:ff02e885b65b7b807146f259a30b1b5e', - 'LIMO.mat': 'md5:59b5fb3a9749003133608b5871309e2c'}, - S2={'Yr.mat': 'md5:a4329022e57fd07ceceb7d1735fd2718', - 'LIMO.mat': 'md5:98b284b567f2dd395c936366e404f2c6'}, - S1={'Yr.mat': 'md5:076c0ae78fb71d43409c1877707df30e', - 'LIMO.mat': 'md5:136c8cf89f8f111a11f531bd9fa6ae69'}, + S18={ + "Yr.mat": "md5:87f883d442737971a80fc0a35d057e51", + "LIMO.mat": "md5:8b4879646f65d7876fa4adf2e40162c5", + }, + S17={ + "Yr.mat": "md5:7b667ec9eefd7a9996f61ae270e295ee", + "LIMO.mat": "md5:22eaca4e6fad54431fd61b307fc426b8", + }, + S16={ + "Yr.mat": "md5:c877afdb4897426421577e863a45921a", + "LIMO.mat": "md5:86672d7afbea1e8c39305bc3f852c8c2", + }, + S15={ + "Yr.mat": "md5:eea9e0140af598fefc08c886a6f05de5", + "LIMO.mat": "md5:aed5cb71ddbfd27c6a3ac7d3e613d07f", + }, + S14={ + "Yr.mat": "md5:8bd842cfd8588bd5d32e72fdbe70b66e", + "LIMO.mat": "md5:1e07d1f36f2eefad435a77530daf2680", + }, + S13={ + "Yr.mat": "md5:d7925d2af7288b8a5186dfb5dbb63d34", + "LIMO.mat": "md5:ba891015d2f9e447955fffa9833404ca", + }, + S12={ + "Yr.mat": "md5:0e1d05beaa4bf2726e0d0671b78fe41e", + "LIMO.mat": "md5:423fd479d71097995b6614ecb11df9ad", + }, + S11={ + "Yr.mat": "md5:1b0016fb9832e43b71f79c1992fcbbb1", + "LIMO.mat": "md5:1a281348c2a41ee899f42731d30cda70", + }, + S10={ + "Yr.mat": "md5:13c66f60e241b9a9cc576eaf1b55a417", + "LIMO.mat": "md5:3c4b41e221eb352a21bbef1a7e006f06", + }, + S9={ + "Yr.mat": "md5:3ae1d9c3a1d9325deea2f2dddd1ab507", + "LIMO.mat": "md5:5e204e2a4bcfe4f535b4b1af469b37f7", + }, + S8={ + "Yr.mat": "md5:7e9adbca4e03d8d7ce8ea07ccecdc8fd", + "LIMO.mat": "md5:88313c21d34428863590e586b2bc3408", + }, + S7={ + "Yr.mat": "md5:6b5290a6725ecebf1022d5d2789b186d", + "LIMO.mat": "md5:8c769219ebc14ce3f595063e84bfc0a9", + }, + S6={ + "Yr.mat": "md5:420c858a8340bf7c28910b7b0425dc5d", + "LIMO.mat": "md5:9cf4e1a405366d6bd0cc6d996e32fd63", + }, + S5={ + "Yr.mat": "md5:946436cfb474c8debae56ffb1685ecf3", + "LIMO.mat": "md5:241fac95d3a79d2cea081391fb7078bd", + }, + S4={ + "Yr.mat": "md5:c8216af78ac87b739e86e57b345cafdd", + "LIMO.mat": "md5:8e10ef36c2e075edc2f787581ba33459", + }, + S3={ + "Yr.mat": "md5:ff02e885b65b7b807146f259a30b1b5e", + "LIMO.mat": "md5:59b5fb3a9749003133608b5871309e2c", + }, + S2={ + "Yr.mat": "md5:a4329022e57fd07ceceb7d1735fd2718", + "LIMO.mat": "md5:98b284b567f2dd395c936366e404f2c6", + }, + S1={ + "Yr.mat": "md5:076c0ae78fb71d43409c1877707df30e", + "LIMO.mat": "md5:136c8cf89f8f111a11f531bd9fa6ae69", + }, ) # create the download manager fetcher = pooch.create( path=subject_path, - base_url='', - version=None, # Data versioning is decoupled from MNE-Python version. + base_url="", + version=None, # Data versioning is decoupled from MNE-Python version. 
registry=hashes[subj], - urls={key: f'{root_url}{uuid}' for key, uuid in urls[subj].items()}, - retry_if_failed=2 # 2 retries = 3 total attempts + urls={key: f"{root_url}{uuid}" for key, uuid in urls[subj].items()}, + retry_if_failed=2, # 2 retries = 3 total attempts ) # use our logger level for pooch's logger too pooch.get_logger().setLevel(logger.getEffectiveLevel()) # fetch the data sz = 0 - for fname in ('LIMO.mat', 'Yr.mat'): + for fname in ("LIMO.mat", "Yr.mat"): destination = Path(subject_path, fname) if destination.exists(): if force_update: @@ -180,7 +253,7 @@ def data_path(subject, path=None, force_update=False, update_path=None, *, else: continue if sz == 0: # log once - logger.info('Downloading LIMO data') + logger.info("Downloading LIMO data") # fetch the remote file (if local file missing or has hash mismatch) fetcher.fetch(fname=fname, downloader=downloader) sz += destination.stat().st_size @@ -192,8 +265,7 @@ def data_path(subject, path=None, force_update=False, update_path=None, *, @verbose -def load_data(subject, path=None, force_update=False, update_path=None, - verbose=None): +def load_data(subject, path=None, force_update=False, update_path=None, verbose=None): """Fetch subjects epochs data for the LIMO data set. Parameters @@ -222,45 +294,45 @@ def load_data(subject, path=None, force_update=False, update_path=None, # subject in question if isinstance(subject, int) and 1 <= subject <= 18: - subj = 'S%i' % subject + subj = "S%i" % subject else: - raise ValueError('subject must be an int in the range from 1 to 18') + raise ValueError("subject must be an int in the range from 1 to 18") # set limo path, download and decompress files if not found limo_path = data_path(subject, path, force_update, update_path) # -- 1) import .mat files # epochs info - fname_info = op.join(limo_path, subj, 'LIMO.mat') + fname_info = op.join(limo_path, subj, "LIMO.mat") data_info = loadmat(fname_info) # number of epochs per condition - design = data_info['LIMO']['design'][0][0]['X'][0][0] - data_info = data_info['LIMO']['data'][0][0][0][0] + design = data_info["LIMO"]["design"][0][0]["X"][0][0] + data_info = data_info["LIMO"]["data"][0][0][0][0] # epochs data - fname_eeg = op.join(limo_path, subj, 'Yr.mat') + fname_eeg = op.join(limo_path, subj, "Yr.mat") data = loadmat(fname_eeg) # -- 2) get epochs information from structure # sampling rate - sfreq = data_info['sampling_rate'][0][0] + sfreq = data_info["sampling_rate"][0][0] # tmin and tmax - tmin = data_info['start'][0][0] + tmin = data_info["start"][0][0] # create events matrix sample = np.arange(len(design)) prev_id = np.zeros(len(design)) ev_id = design[:, 1] events = np.array([sample, prev_id, ev_id]).astype(int).T # event ids, such that Face B == 1 - event_id = {'Face/A': 0, 'Face/B': 1} + event_id = {"Face/A": 0, "Face/B": 1} # -- 3) extract channel labels from LIMO structure # get individual labels - labels = data_info['chanlocs']['labels'] + labels = data_info["chanlocs"]["labels"] labels = [label for label, *_ in labels[0]] # get montage - montage = make_standard_montage('biosemi128') + montage = make_standard_montage("biosemi128") # add external electrodes (e.g., eogs) - ch_names = montage.ch_names + ['EXG1', 'EXG2', 'EXG3', 'EXG4'] + ch_names = montage.ch_names + ["EXG1", "EXG2", "EXG3", "EXG4"] # match individual labels to labels in montage found_inds = [ind for ind, name in enumerate(ch_names) if name in labels] missing_chans = [name for name in ch_names if name not in labels] @@ -270,7 +342,7 @@ def load_data(subject, 
path=None, force_update=False, update_path=None, # data is stored as channels x time points x epochs # data['Yr'].shape # <-- see here # transpose to epochs x channels time points - data = np.transpose(data['Yr'], (2, 0, 1)) + data = np.transpose(data["Yr"], (2, 0, 1)) # initialize data in expected order temp_data = np.empty((data.shape[0], len(ch_names), data.shape[2])) # copy over the non-missing data @@ -287,15 +359,16 @@ def load_data(subject, path=None, force_update=False, update_path=None, info = create_info(ch_names, sfreq, types).set_montage(montage) # get faces and noise variables from design matrix event_list = list(events[:, 2]) - faces = ['B' if event else 'A' for event in event_list] + faces = ["B" if event else "A" for event in event_list] noise = list(design[:, 2]) # create epochs metadata - metadata = {'face': faces, 'phase-coherence': noise} + metadata = {"face": faces, "phase-coherence": noise} metadata = pd.DataFrame(metadata) # -- 6) Create custom epochs array - epochs = EpochsArray(data, info, events, tmin, event_id, metadata=metadata, - verbose=False) - epochs.info['bads'] = missing_chans # missing channels are marked as bad. + epochs = EpochsArray( + data, info, events, tmin, event_id, metadata=metadata, verbose=False + ) + epochs.info["bads"] = missing_chans # missing channels are marked as bad. return epochs diff --git a/mne/datasets/misc/_misc.py b/mne/datasets/misc/_misc.py index 85f65332ad1..443aa24787b 100644 --- a/mne/datasets/misc/_misc.py +++ b/mne/datasets/misc/_misc.py @@ -8,19 +8,25 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='misc', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="misc", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) def _pytest_mark(): import pytest + return pytest.mark.skipif( - not has_dataset(name='misc'), reason='Requires misc dataset') + not has_dataset(name="misc"), reason="Requires misc dataset" + ) -data_path.__doc__ = _data_path_doc.format(name='misc', - conf='MNE_DATASETS_MISC_PATH') +data_path.__doc__ = _data_path_doc.format(name="misc", conf="MNE_DATASETS_MISC_PATH") diff --git a/mne/datasets/mtrf/mtrf.py b/mne/datasets/mtrf/mtrf.py index bfc5cd0ba58..1ce4f741a4f 100644 --- a/mne/datasets/mtrf/mtrf.py +++ b/mne/datasets/mtrf/mtrf.py @@ -3,24 +3,27 @@ # License: BSD Style. 
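# A minimal sketch of the LIMO loader reworked above; subject 1 is an
# arbitrary pick from the 18 available subjects:
from mne.datasets import limo

epochs = limo.load_data(subject=1)  # fetches Yr.mat/LIMO.mat if missing
print(epochs.metadata.head())  # per-trial 'face' and 'phase-coherence' columns
evoked_a = epochs["Face/A"].average()  # select trials by event_id, then average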
from ...utils import verbose -from ..utils import (_data_path_doc, - _get_version, _version_doc, _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset -data_name = 'mtrf' +data_name = "mtrf" @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name=data_name, processor='unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name=data_name, + processor="unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name=data_name, - conf='MNE_DATASETS_MTRF_PATH') +data_path.__doc__ = _data_path_doc.format(name=data_name, conf="MNE_DATASETS_MTRF_PATH") def get_version(): # noqa: D103 diff --git a/mne/datasets/multimodal/multimodal.py b/mne/datasets/multimodal/multimodal.py index 4ef0fd38efb..84fbf662e5f 100644 --- a/mne/datasets/multimodal/multimodal.py +++ b/mne/datasets/multimodal/multimodal.py @@ -4,25 +4,30 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='multimodal', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="multimodal", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='multimodal', - conf='MNE_DATASETS_MULTIMODAL_PATH') +data_path.__doc__ = _data_path_doc.format( + name="multimodal", conf="MNE_DATASETS_MULTIMODAL_PATH" +) def get_version(): # noqa: D103 - return _get_version('multimodal') + return _get_version("multimodal") -get_version.__doc__ = _version_doc.format(name='multimodal') +get_version.__doc__ = _version_doc.format(name="multimodal") diff --git a/mne/datasets/opm/opm.py b/mne/datasets/opm/opm.py index 014e91f2029..b2b24f2e3f8 100644 --- a/mne/datasets/opm/opm.py +++ b/mne/datasets/opm/opm.py @@ -4,25 +4,28 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='opm', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="opm", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='opm', - conf='MNE_DATASETS_OPML_PATH') +data_path.__doc__ = _data_path_doc.format(name="opm", conf="MNE_DATASETS_OPML_PATH") def get_version(): # noqa: D103 - return _get_version('opm') + return _get_version("opm") -get_version.__doc__ = _version_doc.format(name='opm') +get_version.__doc__ = _version_doc.format(name="opm") diff --git a/mne/datasets/phantom_4dbti/phantom_4dbti.py b/mne/datasets/phantom_4dbti/phantom_4dbti.py index 2154dee99ce..59c42416d5a 100644 --- a/mne/datasets/phantom_4dbti/phantom_4dbti.py +++ b/mne/datasets/phantom_4dbti/phantom_4dbti.py @@ -3,25 +3,30 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='phantom_4dbti', processor='unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="phantom_4dbti", + processor="unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='phantom_4dbti', conf='MNE_DATASETS_PHANTOM_4DBTI_PATH') + name="phantom_4dbti", conf="MNE_DATASETS_PHANTOM_4DBTI_PATH" +) def get_version(): # noqa: D103 - return _get_version('phantom_4dbti') + return _get_version("phantom_4dbti") -get_version.__doc__ = _version_doc.format(name='phantom_4dbti') +get_version.__doc__ = _version_doc.format(name="phantom_4dbti") diff --git a/mne/datasets/refmeg_noise/refmeg_noise.py b/mne/datasets/refmeg_noise/refmeg_noise.py index 2027a31bacc..e77f3eefaf0 100644 --- a/mne/datasets/refmeg_noise/refmeg_noise.py +++ b/mne/datasets/refmeg_noise/refmeg_noise.py @@ -2,25 +2,30 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='refmeg_noise', processor='unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="refmeg_noise", + processor="unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='refmeg_noise', conf='MNE_DATASETS_REFMEG_NOISE_PATH') + name="refmeg_noise", conf="MNE_DATASETS_REFMEG_NOISE_PATH" +) def get_version(): # noqa: D103 - return _get_version('refmeg_noise') + return _get_version("refmeg_noise") -get_version.__doc__ = _version_doc.format(name='refmeg_noise') +get_version.__doc__ = _version_doc.format(name="refmeg_noise") diff --git a/mne/datasets/sample/sample.py b/mne/datasets/sample/sample.py index 4876b7bc7f7..f5ca6de24c4 100644 --- a/mne/datasets/sample/sample.py +++ b/mne/datasets/sample/sample.py @@ -4,25 +4,30 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='sample', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="sample", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='sample', - conf='MNE_DATASETS_SAMPLE_PATH') +data_path.__doc__ = _data_path_doc.format( + name="sample", conf="MNE_DATASETS_SAMPLE_PATH" +) def get_version(): # noqa: D103 - return _get_version('sample') + return _get_version("sample") -get_version.__doc__ = _version_doc.format(name='sample') +get_version.__doc__ = _version_doc.format(name="sample") diff --git a/mne/datasets/sleep_physionet/_utils.py b/mne/datasets/sleep_physionet/_utils.py index 50f992e7803..bca3284d73b 100644 --- a/mne/datasets/sleep_physionet/_utils.py +++ b/mne/datasets/sleep_physionet/_utils.py @@ -8,27 +8,30 @@ import numpy as np -from ...utils import (verbose, _TempDir, _check_pandas_installed, - _on_missing) +from ...utils import verbose, _TempDir, _check_pandas_installed, _on_missing from ..utils import _get_path, _downloader_params -AGE_SLEEP_RECORDS = op.join(op.dirname(__file__), 'age_records.csv') -TEMAZEPAM_SLEEP_RECORDS = op.join(op.dirname(__file__), - 'temazepam_records.csv') +AGE_SLEEP_RECORDS = op.join(op.dirname(__file__), "age_records.csv") +TEMAZEPAM_SLEEP_RECORDS = op.join(op.dirname(__file__), "temazepam_records.csv") -TEMAZEPAM_RECORDS_URL = 'https://physionet.org/physiobank/database/sleep-edfx/ST-subjects.xls' # noqa: E501 -TEMAZEPAM_RECORDS_URL_SHA1 = 'f52fffe5c18826a2bd4c5d5cb375bb4a9008c885' +TEMAZEPAM_RECORDS_URL = ( + 
"https://physionet.org/physiobank/database/sleep-edfx/ST-subjects.xls" # noqa: E501 +) +TEMAZEPAM_RECORDS_URL_SHA1 = "f52fffe5c18826a2bd4c5d5cb375bb4a9008c885" -AGE_RECORDS_URL = 'https://physionet.org/physiobank/database/sleep-edfx/SC-subjects.xls' # noqa: E501 -AGE_RECORDS_URL_SHA1 = '0ba6650892c5d33a8e2b3f62ce1cc9f30438c54f' +AGE_RECORDS_URL = ( + "https://physionet.org/physiobank/database/sleep-edfx/SC-subjects.xls" # noqa: E501 +) +AGE_RECORDS_URL_SHA1 = "0ba6650892c5d33a8e2b3f62ce1cc9f30438c54f" -sha1sums_fname = op.join(op.dirname(__file__), 'SHA1SUMS') +sha1sums_fname = op.join(op.dirname(__file__), "SHA1SUMS") def _fetch_one(fname, hashsum, path, force_update, base_url): import pooch + # Fetch the file - url = base_url + '/' + fname + url = base_url + "/" + fname destination = op.join(path, fname) if op.isfile(destination) and not force_update: return destination, False @@ -42,7 +45,7 @@ def _fetch_one(fname, hashsum, path, force_update, base_url): known_hash=f"sha1:{hashsum}", path=path, downloader=downloader, - fname=fname + fname=fname, ) return destination, True @@ -75,10 +78,10 @@ def _data_path(path=None, verbose=None): ---------- .. footbibliography:: """ # noqa: E501 - key = 'PHYSIONET_SLEEP_PATH' - name = 'PHYSIONET_SLEEP' + key = "PHYSIONET_SLEEP_PATH" + name = "PHYSIONET_SLEEP" path = _get_path(path, key, name) - return op.join(path, 'physionet-sleep-data') + return op.join(path, "physionet-sleep-data") def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): @@ -89,7 +92,7 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): tmp = _TempDir() # Download subjects info. - subjects_fname = op.join(tmp, 'ST-subjects.xls') + subjects_fname = op.join(tmp, "ST-subjects.xls") downloader = pooch.HTTPDownloader(**_downloader_params()) pooch.retrieve( url=TEMAZEPAM_RECORDS_URL, @@ -100,44 +103,60 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): ) # Load and Massage the checksums. - sha1_df = pd.read_csv(sha1sums_fname, sep=' ', header=None, - names=['sha', 'fname'], engine='python') - select_age_records = (sha1_df.fname.str.startswith('ST') & - sha1_df.fname.str.endswith('edf')) + sha1_df = pd.read_csv( + sha1sums_fname, sep=" ", header=None, names=["sha", "fname"], engine="python" + ) + select_age_records = sha1_df.fname.str.startswith( + "ST" + ) & sha1_df.fname.str.endswith("edf") sha1_df = sha1_df[select_age_records] - sha1_df['id'] = [name[:6] for name in sha1_df.fname] + sha1_df["id"] = [name[:6] for name in sha1_df.fname] # Load and massage the data. 
data = pd.read_excel(subjects_fname, header=[0, 1]) - data = data.set_index(('Subject - age - sex', 'Nr')) - data.index.name = 'subject' + data = data.set_index(("Subject - age - sex", "Nr")) + data.index.name = "subject" data.columns.names = [None, None] - data = (data.set_index([('Subject - age - sex', 'Age'), - ('Subject - age - sex', 'M1/F2')], append=True) - .stack(level=0).reset_index()) - - data = data.rename(columns={('Subject - age - sex', 'Age'): 'age', - ('Subject - age - sex', 'M1/F2'): 'sex', - 'level_3': 'drug'}) - data['id'] = ['ST7{:02d}{:1d}'.format(s, n) - for s, n in zip(data.subject, data['night nr'])] + data = ( + data.set_index( + [("Subject - age - sex", "Age"), ("Subject - age - sex", "M1/F2")], + append=True, + ) + .stack(level=0) + .reset_index() + ) - data = pd.merge(sha1_df, data, how='outer', on='id') - data['record type'] = (data.fname.str.split('-', expand=True)[1] - .str.split('.', expand=True)[0] - .astype('category')) + data = data.rename( + columns={ + ("Subject - age - sex", "Age"): "age", + ("Subject - age - sex", "M1/F2"): "sex", + "level_3": "drug", + } + ) + data["id"] = [ + "ST7{:02d}{:1d}".format(s, n) for s, n in zip(data.subject, data["night nr"]) + ] + + data = pd.merge(sha1_df, data, how="outer", on="id") + data["record type"] = ( + data.fname.str.split("-", expand=True)[1] + .str.split(".", expand=True)[0] + .astype("category") + ) - data = data.set_index(['id', 'subject', 'age', 'sex', 'drug', - 'lights off', 'night nr', 'record type']).unstack() - data.columns = [l1 + '_' + l2 for l1, l2 in data.columns] - data = data.reset_index().drop(columns=['id']) + data = data.set_index( + ["id", "subject", "age", "sex", "drug", "lights off", "night nr", "record type"] + ).unstack() + data.columns = [l1 + "_" + l2 for l1, l2 in data.columns] + data = data.reset_index().drop(columns=["id"]) - data['sex'] = (data.sex.astype('category') - .cat.rename_categories({1: 'male', 2: 'female'})) + data["sex"] = data.sex.astype("category").cat.rename_categories( + {1: "male", 2: "female"} + ) - data['drug'] = data['drug'].str.split(expand=True)[0] - data['subject_orig'] = data['subject'] - data['subject'] = data.index // 2 # to make sure index is from 0 to 21 + data["drug"] = data["drug"].str.split(expand=True)[0] + data["subject_orig"] = data["subject"] + data["subject"] = data.index // 2 # to make sure index is from 0 to 21 # Save the data. data.to_csv(fname, index=False) @@ -146,11 +165,12 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): def _update_sleep_age_records(fname=AGE_SLEEP_RECORDS): """Help function to download Physionet's age dataset records.""" import pooch + pd = _check_pandas_installed() tmp = _TempDir() # Download subjects info. - subjects_fname = op.join(tmp, 'SC-subjects.xls') + subjects_fname = op.join(tmp, "SC-subjects.xls") downloader = pooch.HTTPDownloader(**_downloader_params()) pooch.retrieve( url=AGE_RECORDS_URL, @@ -161,38 +181,46 @@ def _update_sleep_age_records(fname=AGE_SLEEP_RECORDS): ) # Load and Massage the checksums. 
- sha1_df = pd.read_csv(sha1sums_fname, sep=' ', header=None, - names=['sha', 'fname'], engine='python') - select_age_records = (sha1_df.fname.str.startswith('SC') & - sha1_df.fname.str.endswith('edf')) + sha1_df = pd.read_csv( + sha1sums_fname, sep=" ", header=None, names=["sha", "fname"], engine="python" + ) + select_age_records = sha1_df.fname.str.startswith( + "SC" + ) & sha1_df.fname.str.endswith("edf") sha1_df = sha1_df[select_age_records] - sha1_df['id'] = [name[:6] for name in sha1_df.fname] + sha1_df["id"] = [name[:6] for name in sha1_df.fname] # Load and massage the data. data = pd.read_excel(subjects_fname) - data = data.rename(index=str, columns={'sex (F=1)': 'sex', - 'LightsOff': 'lights off'}) - data['sex'] = (data.sex.astype('category') - .cat.rename_categories({1: 'female', 2: 'male'})) + data = data.rename( + index=str, columns={"sex (F=1)": "sex", "LightsOff": "lights off"} + ) + data["sex"] = data.sex.astype("category").cat.rename_categories( + {1: "female", 2: "male"} + ) - data['id'] = ['SC4{:02d}{:1d}'.format(s, n) - for s, n in zip(data.subject, data.night)] + data["id"] = [ + "SC4{:02d}{:1d}".format(s, n) for s, n in zip(data.subject, data.night) + ] - data = data.set_index('id').join(sha1_df.set_index('id')).dropna() + data = data.set_index("id").join(sha1_df.set_index("id")).dropna() - data['record type'] = (data.fname.str.split('-', expand=True)[1] - .str.split('.', expand=True)[0] - .astype('category')) + data["record type"] = ( + data.fname.str.split("-", expand=True)[1] + .str.split(".", expand=True)[0] + .astype("category") + ) - data = data.reset_index().drop(columns=['id']) - data = data[['subject', 'night', 'record type', 'age', 'sex', 'lights off', - 'sha', 'fname']] + data = data.reset_index().drop(columns=["id"]) + data = data[ + ["subject", "night", "record type", "age", "sex", "lights off", "sha", "fname"] + ] # Save the data. data.to_csv(fname, index=False) -def _check_subjects(subjects, n_subjects, missing=None, on_missing='raise'): +def _check_subjects(subjects, n_subjects, missing=None, on_missing="raise"): """Check whether subjects are available. Parameters @@ -214,8 +242,10 @@ def _check_subjects(subjects, n_subjects, missing=None, on_missing='raise'): valid_subjects = np.setdiff1d(valid_subjects, missing) unknown_subjects = np.setdiff1d(subjects, valid_subjects) if unknown_subjects.size > 0: - subjects_list = ', '.join([str(s) for s in unknown_subjects]) - msg = (f'This dataset contains subjects 0 to {n_subjects - 1} with ' - f'missing subjects {missing}. Unknown subjects: ' - f'{subjects_list}.') + subjects_list = ", ".join([str(s) for s in unknown_subjects]) + msg = ( + f"This dataset contains subjects 0 to {n_subjects - 1} with " + f"missing subjects {missing}. Unknown subjects: " + f"{subjects_list}." + ) _on_missing(on_missing, msg) diff --git a/mne/datasets/sleep_physionet/age.py b/mne/datasets/sleep_physionet/age.py index 106d39d4e32..0a7fb174d1c 100644 --- a/mne/datasets/sleep_physionet/age.py +++ b/mne/datasets/sleep_physionet/age.py @@ -15,12 +15,22 @@ data_path = _data_path # expose _data_path(..) as data_path(..) 
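# A minimal sketch of the fetch_data() entry point reworked just below; that
# each returned entry pairs a PSG recording with its hypnogram file is an
# assumption based on how this loader is typically used, not something this
# patch changes:
from mne.datasets.sleep_physionet.age import fetch_data

paths = fetch_data(subjects=[0, 1], recording=[1])
for psg_file, hypnogram_file in paths:
    print(psg_file, hypnogram_file)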
-BASE_URL = 'https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette/' # noqa: E501 +BASE_URL = ( + "https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette/" # noqa: E501 +) @verbose -def fetch_data(subjects, recording=(1, 2), path=None, force_update=False, - base_url=BASE_URL, on_missing='raise', *, verbose=None): # noqa: D301, E501 +def fetch_data( + subjects, + recording=(1, 2), + path=None, + force_update=False, + base_url=BASE_URL, + on_missing="raise", + *, + verbose=None +): # noqa: D301, E501 """Get paths to local copies of PhysioNet Polysomnography dataset files. This will fetch data from the publicly available subjects from PhysioNet's @@ -84,46 +94,53 @@ def fetch_data(subjects, recording=(1, 2), path=None, force_update=False, .. footbibliography:: """ # noqa: E501 t0 = time.time() - records = np.loadtxt(AGE_SLEEP_RECORDS, - skiprows=1, - delimiter=',', - usecols=(0, 1, 2, 6, 7), - dtype={'names': ('subject', 'record', 'type', 'sha', - 'fname'), - 'formats': (' 0: - os.makedirs(op.join(destination, 'foo')) - assert op.isdir(op.join(destination, 'foo')) + os.makedirs(op.join(destination, "foo")) + assert op.isdir(op.join(destination, "foo")) for fname in _zip_fnames: assert not op.isfile(op.join(destination, fname)) for fname in _zip_fnames[:n_have]: - with open(op.join(destination, fname), 'w'): + with open(op.join(destination, fname), "w"): pass with catch_logging() as log: with use_log_level(True): # we mock the pooch.retrieve so these are not used - url = hash_ = '' + url = hash_ = "" _manifest_check_download(manifest_path, destination, url, hash_) log = log.getvalue() n_missing = 3 - n_have - assert ('%d file%s missing from' % (n_missing, _pl(n_missing))) in log - for want in ('Extracting missing', 'Successfully '): + assert ("%d file%s missing from" % (n_missing, _pl(n_missing))) in log + for want in ("Extracting missing", "Successfully "): if n_missing > 0: assert want in log else: @@ -236,10 +264,9 @@ def test_manifest_check_download(tmp_path, n_have, monkeypatch): assert op.isfile(op.join(destination, fname)) -def _fake_mcd(manifest_path, destination, url, hash_, name=None, - fake_files=False): +def _fake_mcd(manifest_path, destination, url, hash_, name=None, fake_files=False): if name is None: - name = url.split('/')[-1].split('.')[0] + name = url.split("/")[-1].split(".")[0] assert name in url assert name in str(destination) assert name in manifest_path @@ -252,16 +279,16 @@ def _fake_mcd(manifest_path, destination, url, hash_, name=None, continue fname = op.join(destination, path) os.makedirs(op.dirname(fname), exist_ok=True) - with open(fname, 'wb'): + with open(fname, "wb"): pass def test_infant(tmp_path, monkeypatch): """Test fetch_infant_template.""" - monkeypatch.setattr(infant_base, '_manifest_check_download', _fake_mcd) - fetch_infant_template('12mo', subjects_dir=tmp_path) - with pytest.raises(ValueError, match='Invalid value for'): - fetch_infant_template('0mo', subjects_dir=tmp_path) + monkeypatch.setattr(infant_base, "_manifest_check_download", _fake_mcd) + fetch_infant_template("12mo", subjects_dir=tmp_path) + with pytest.raises(ValueError, match="Invalid value for"): + fetch_infant_template("0mo", subjects_dir=tmp_path) def test_phantom(tmp_path, monkeypatch): @@ -270,21 +297,25 @@ def test_phantom(tmp_path, monkeypatch): # an actual download here. But it doesn't seem worth it given that # CircleCI will at least test the VectorView one, and this file should # not change often. 
- monkeypatch.setattr(phantom_base, '_manifest_check_download', - partial(_fake_mcd, name='phantom_otaniemi', - fake_files=True)) - fetch_phantom('otaniemi', subjects_dir=tmp_path) - assert op.isfile(tmp_path / 'phantom_otaniemi' / 'mri' / 'T1.mgz') + monkeypatch.setattr( + phantom_base, + "_manifest_check_download", + partial(_fake_mcd, name="phantom_otaniemi", fake_files=True), + ) + fetch_phantom("otaniemi", subjects_dir=tmp_path) + assert op.isfile(tmp_path / "phantom_otaniemi" / "mri" / "T1.mgz") def test_fetch_uncompressed_file(tmp_path): """Test downloading an uncompressed file with our fetch function.""" dataset_dict = dict( - dataset_name='license', - url=('https://raw.githubusercontent.com/mne-tools/mne-python/main/' - 'LICENSE.txt'), - archive_name='LICENSE.foo', - folder_name=op.join(tmp_path, 'foo'), - hash=None) + dataset_name="license", + url=( + "https://raw.githubusercontent.com/mne-tools/mne-python/main/" "LICENSE.txt" + ), + archive_name="LICENSE.foo", + folder_name=op.join(tmp_path, "foo"), + hash=None, + ) fetch_dataset(dataset_dict, path=None, force_update=True) - assert (tmp_path / 'foo' / 'LICENSE.foo').is_file() + assert (tmp_path / "foo" / "LICENSE.foo").is_file() diff --git a/mne/datasets/ucl_opm_auditory/ucl_opm_auditory.py b/mne/datasets/ucl_opm_auditory/ucl_opm_auditory.py index e43443d1480..09853e640de 100644 --- a/mne/datasets/ucl_opm_auditory/ucl_opm_auditory.py +++ b/mne/datasets/ucl_opm_auditory/ucl_opm_auditory.py @@ -2,26 +2,30 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset -_NAME = 'ucl_opm_auditory' -_PROCESSOR = 'unzip' +_NAME = "ucl_opm_auditory" +_PROCESSOR = "unzip" @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name=_NAME, processor=_PROCESSOR, path=path, - force_update=force_update, update_path=update_path, - download=download) + name=_NAME, + processor=_PROCESSOR, + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( name=_NAME, - conf=f'MNE_DATASETS_{_NAME.upper()}_PATH', + conf=f"MNE_DATASETS_{_NAME.upper()}_PATH", ) diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 1fba832abb0..32ff152cd5e 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -24,8 +24,16 @@ from .config import _hcp_mmp_license_text, MNE_DATASETS from ..label import read_labels_from_annot, Label, write_labels_to_annot -from ..utils import (get_config, set_config, logger, _validate_type, - verbose, get_subjects_dir, _pl, _safe_input) +from ..utils import ( + get_config, + set_config, + logger, + _validate_type, + verbose, + get_subjects_dir, + _pl, + _safe_input, +) from ..utils.docs import docdict, _docformat @@ -58,10 +66,10 @@ path : instance of Path Path to {name} dataset directory. 
""" -_data_path_doc_accept = _data_path_doc.split('%(verbose)s') -_data_path_doc_accept[-1] = '%(verbose)s' + _data_path_doc_accept[-1] -_data_path_doc_accept.insert(1, ' %(accept)s') -_data_path_doc_accept = ''.join(_data_path_doc_accept) +_data_path_doc_accept = _data_path_doc.split("%(verbose)s") +_data_path_doc_accept[-1] = "%(verbose)s" + _data_path_doc_accept[-1] +_data_path_doc_accept.insert(1, " %(accept)s") +_data_path_doc_accept = "".join(_data_path_doc_accept) _data_path_doc = _docformat(_data_path_doc, docdict) _data_path_doc_accept = _docformat(_data_path_doc_accept, docdict) @@ -77,69 +85,73 @@ def _dataset_version(path, name): """Get the version of the dataset.""" - ver_fname = op.join(path, 'version.txt') + ver_fname = op.join(path, "version.txt") if op.exists(ver_fname): - with open(ver_fname, 'r') as fid: + with open(ver_fname, "r") as fid: version = fid.readline().strip() # version is on first line else: - logger.debug(f'Version file missing: {ver_fname}') + logger.debug(f"Version file missing: {ver_fname}") # Sample dataset versioning was introduced after 0.3 # SPM dataset was introduced with 0.7 - versions = dict(sample='0.7', spm='0.3') - version = versions.get(name, '0.0') + versions = dict(sample="0.7", spm="0.3") + version = versions.get(name, "0.0") return version def _get_path(path, key, name): """Get a dataset path.""" # 1. Input - _validate_type(path, ('path-like', None), path) + _validate_type(path, ("path-like", None), path) if path is not None: return path # 2. get_config(key) — unless key is None or "" (special get_config values) # 3. get_config('MNE_DATA') - path = get_config(key or 'MNE_DATA', get_config('MNE_DATA')) + path = get_config(key or "MNE_DATA", get_config("MNE_DATA")) if path is not None: path = Path(path).expanduser() if not path.exists(): - msg = (f"Download location {path} as specified by MNE_DATA does " - f"not exist. Either create this directory manually and try " - f"again, or set MNE_DATA to an existing directory.") + msg = ( + f"Download location {path} as specified by MNE_DATA does " + f"not exist. Either create this directory manually and try " + f"again, or set MNE_DATA to an existing directory." + ) raise FileNotFoundError(msg) return path # 4. ~/mne_data (but use a fake home during testing so we don't # unnecessarily create ~/mne_data) - logger.info('Using default location ~/mne_data for %s...' % name) - path = op.join(os.getenv('_MNE_FAKE_HOME_DIR', - op.expanduser("~")), 'mne_data') + logger.info("Using default location ~/mne_data for %s..." 
% name) + path = op.join(os.getenv("_MNE_FAKE_HOME_DIR", op.expanduser("~")), "mne_data") if not op.exists(path): - logger.info('Creating ~/mne_data') + logger.info("Creating ~/mne_data") try: os.mkdir(path) except OSError: - raise OSError("User does not have write permissions " - "at '%s', try giving the path as an " - "argument to data_path() where user has " - "write permissions, for ex:data_path" - "('/home/xyz/me2/')" % (path)) + raise OSError( + "User does not have write permissions " + "at '%s', try giving the path as an " + "argument to data_path() where user has " + "write permissions, for ex:data_path" + "('/home/xyz/me2/')" % (path) + ) return Path(path) def _do_path_update(path, update_path, key, name): """Update path.""" path = op.abspath(path) - identical = get_config(key, '', use_env=False) == path + identical = get_config(key, "", use_env=False) == path if not identical: if update_path is None: update_path = True - if '--update-dataset-path' in sys.argv: - answer = 'y' + if "--update-dataset-path" in sys.argv: + answer = "y" else: - msg = ('Do you want to set the path:\n %s\nas the default ' - '%s dataset path in the mne-python config [y]/n? ' - % (path, name)) - answer = _safe_input(msg, alt='pass update_path=True') - if answer.lower() == 'n': + msg = ( + "Do you want to set the path:\n %s\nas the default " + "%s dataset path in the mne-python config [y]/n? " % (path, name) + ) + answer = _safe_input(msg, alt="pass update_path=True") + if answer.lower() == "n": update_path = False if update_path: @@ -149,14 +161,15 @@ def _do_path_update(path, update_path, key, name): # This is meant to be semi-public: let packages like mne-bids use it to make # sure they don't accidentally set download=True in their tests, too -_MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS = ('mne',) +_MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS = ("mne",) def _check_in_testing_and_raise(name, download): """Check if we're in an MNE test and raise an error if download!=False.""" root_dirs = [ importlib.import_module(ns) - for ns in _MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS] + for ns in _MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS + ] root_dirs = [str(Path(ns.__file__).parent) for ns in root_dirs] check = False func = None @@ -164,7 +177,7 @@ def _check_in_testing_and_raise(name, download): try: # First, traverse out of the data_path() call while frame: - if frame.f_code.co_name in ('data_path', 'load_data'): + if frame.f_code.co_name in ("data_path", "load_data"): func = frame.f_code.co_name frame = frame.f_back.f_back # out of verbose decorator break @@ -177,10 +190,12 @@ def _check_in_testing_and_raise(name, download): # in mne namespace, and # (can't use is_relative_to here until 3.9) if any(str(fname).startswith(rd) for rd in root_dirs) and ( - # in tests/*.py - fname.parent.stem == 'tests' or - # or in a conftest.py - fname.stem == 'conftest.py'): + # in tests/*.py + fname.parent.stem == "tests" + or + # or in a conftest.py + fname.stem == "conftest.py" + ): check = True break frame = frame.f_back @@ -188,12 +203,14 @@ def _check_in_testing_and_raise(name, download): del frame if check and download is not False: raise RuntimeError( - f'Do not download dataset {repr(name)} in tests, pass ' - f'{func}(download=False) to prevent accidental downloads') + f"Do not download dataset {repr(name)} in tests, pass " + f"{func}(download=False) to prevent accidental downloads" + ) -def _download_mne_dataset(name, processor, path, force_update, - update_path, download, accept=False): +def 
_download_mne_dataset( + name, processor, path, force_update, update_path, download, accept=False +): """Aux function for downloading internal MNE datasets.""" import pooch from mne.datasets._fetch import fetch_dataset @@ -202,33 +219,38 @@ def _download_mne_dataset(name, processor, path, force_update, # import pooch library for handling the dataset downloading dataset_params = MNE_DATASETS[name] - dataset_params['dataset_name'] = name - config_key = MNE_DATASETS[name]['config_key'] - folder_name = MNE_DATASETS[name]['folder_name'] + dataset_params["dataset_name"] = name + config_key = MNE_DATASETS[name]["config_key"] + folder_name = MNE_DATASETS[name]["folder_name"] # get download path for specific dataset path = _get_path(path=path, key=config_key, name=name) # instantiate processor that unzips file - if processor == 'nested_untar': + if processor == "nested_untar": processor_ = pooch.Untar(extract_dir=op.join(path, folder_name)) - elif processor == 'nested_unzip': + elif processor == "nested_unzip": processor_ = pooch.Unzip(extract_dir=op.join(path, folder_name)) else: processor_ = processor # handle case of multiple sub-datasets with different urls - if name == 'visual_92_categories': + if name == "visual_92_categories": dataset_params = [] - for name in ['visual_92_categories_1', 'visual_92_categories_2']: + for name in ["visual_92_categories_1", "visual_92_categories_2"]: this_dataset = MNE_DATASETS[name] - this_dataset['dataset_name'] = name + this_dataset["dataset_name"] = name dataset_params.append(this_dataset) - return fetch_dataset(dataset_params=dataset_params, processor=processor_, - path=path, force_update=force_update, - update_path=update_path, download=download, - accept=accept) + return fetch_dataset( + dataset_params=dataset_params, + processor=processor_, + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) def _get_version(name): @@ -238,14 +260,13 @@ def _get_version(name): if not has_dataset(name): return None dataset_params = MNE_DATASETS[name] - dataset_params['dataset_name'] = name - config_key = MNE_DATASETS[name]['config_key'] + dataset_params["dataset_name"] = name + config_key = MNE_DATASETS[name]["config_key"] # get download path for specific dataset path = _get_path(path=None, key=config_key, name=name) - return fetch_dataset(dataset_params, path=path, - return_version=True)[1] + return fetch_dataset(dataset_params, path=path, return_version=True)[1] def has_dataset(name): @@ -268,24 +289,23 @@ def has_dataset(name): from mne.datasets._fetch import fetch_dataset if isinstance(name, dict): - dataset_name = name['dataset_name'] + dataset_name = name["dataset_name"] dataset_params = name else: - dataset_name = 'spm' if name == 'spm_face' else name + dataset_name = "spm" if name == "spm_face" else name dataset_params = MNE_DATASETS[dataset_name] - dataset_params['dataset_name'] = dataset_name + dataset_params["dataset_name"] = dataset_name - config_key = dataset_params['config_key'] + config_key = dataset_params["config_key"] # get download path for specific dataset path = _get_path(path=None, key=config_key, name=dataset_name) - dp = fetch_dataset(dataset_params, path=path, download=False, - check_version=False) - if dataset_name.startswith('bst_'): + dp = fetch_dataset(dataset_params, path=path, download=False, check_version=False) + if dataset_name.startswith("bst_"): check = dataset_name else: - check = MNE_DATASETS[dataset_name]['folder_name'] + check = MNE_DATASETS[dataset_name]["folder_name"] 
return str(dp).endswith(check) @@ -302,51 +322,57 @@ def _download_all_example_data(verbose=True): # verbose=True by default so we get nice status messages. # Consider adding datasets from here to CircleCI for PR-auto-build paths = dict() - for kind in ('sample testing misc spm_face somato hf_sef multimodal ' - 'fnirs_motor opm mtrf fieldtrip_cmc kiloword phantom_4dbti ' - 'refmeg_noise ssvep epilepsy_ecog ucl_opm_auditory eyelink ' - 'erp_core brainstorm.bst_raw brainstorm.bst_auditory ' - 'brainstorm.bst_resting brainstorm.bst_phantom_ctf ' - 'brainstorm.bst_phantom_elekta' - ).split(): - mod = importlib.import_module(f'mne.datasets.{kind}') - data_path_func = getattr(mod, 'data_path') + for kind in ( + "sample testing misc spm_face somato hf_sef multimodal " + "fnirs_motor opm mtrf fieldtrip_cmc kiloword phantom_4dbti " + "refmeg_noise ssvep epilepsy_ecog ucl_opm_auditory eyelink " + "erp_core brainstorm.bst_raw brainstorm.bst_auditory " + "brainstorm.bst_resting brainstorm.bst_phantom_ctf " + "brainstorm.bst_phantom_elekta" + ).split(): + mod = importlib.import_module(f"mne.datasets.{kind}") + data_path_func = getattr(mod, "data_path") kwargs = dict() - if 'accept' in inspect.getfullargspec(data_path_func).args: - kwargs['accept'] = True + if "accept" in inspect.getfullargspec(data_path_func).args: + kwargs["accept"] = True paths[kind] = data_path_func(**kwargs) - logger.info(f'[done {kind}]') + logger.info(f"[done {kind}]") # Now for the exceptions: from . import ( - eegbci, sleep_physionet, limo, fetch_fsaverage, fetch_infant_template, - fetch_hcp_mmp_parcellation, fetch_phantom) + eegbci, + sleep_physionet, + limo, + fetch_fsaverage, + fetch_infant_template, + fetch_hcp_mmp_parcellation, + fetch_phantom, + ) + eegbci.load_data(1, [6, 10, 14], update_path=True) for subj in range(4): eegbci.load_data(subj + 1, runs=[3], update_path=True) - logger.info('[done eegbci]') + logger.info("[done eegbci]") sleep_physionet.age.fetch_data(subjects=[0, 1], recording=[1]) - logger.info('[done sleep_physionet]') + logger.info("[done sleep_physionet]") # If the user has SUBJECTS_DIR, respect it, if not, set it to the EEG one # (probably on CircleCI, or otherwise advanced user) fetch_fsaverage(None) - logger.info('[done fsaverage]') + logger.info("[done fsaverage]") - fetch_infant_template('6mo') - logger.info('[done infant_template]') + fetch_infant_template("6mo") + logger.info("[done infant_template]") - fetch_hcp_mmp_parcellation( - subjects_dir=paths['sample'] / 'subjects', accept=True) - logger.info('[done hcp_mmp_parcellation]') + fetch_hcp_mmp_parcellation(subjects_dir=paths["sample"] / "subjects", accept=True) + logger.info("[done hcp_mmp_parcellation]") - fetch_phantom( - 'otaniemi', subjects_dir=paths['brainstorm.bst_phantom_elekta']) - logger.info('[done phantom]') + fetch_phantom("otaniemi", subjects_dir=paths["brainstorm.bst_phantom_elekta"]) + logger.info("[done phantom]") limo.load_data(subject=1, update_path=True) - logger.info('[done limo]') + logger.info("[done limo]") @verbose @@ -372,13 +398,13 @@ def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) destination = subjects_dir / "fsaverage" / "label" - urls = dict(lh='https://osf.io/p92yb/download', - rh='https://osf.io/4kxny/download') - hashes = dict(lh='9e4d8d6b90242b7e4b0145353436ef77', - rh='dd6464db8e7762d969fc1d8087cd211b') + urls = dict(lh="https://osf.io/p92yb/download", rh="https://osf.io/4kxny/download") + 
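# Editor's illustration (not part of the patch): expected usage of the
# fsaverage/aparc_sub fetchers shown above, assuming the sample dataset is
# already on disk so its subjects directory can serve as subjects_dir.
import mne

subjects_dir = mne.datasets.sample.data_path() / "subjects"
mne.datasets.fetch_fsaverage(subjects_dir=subjects_dir)
mne.datasets.fetch_aparc_sub_parcellation(subjects_dir=subjects_dir)
labels = mne.read_labels_from_annot(
    "fsaverage", parc="aparc_sub", subjects_dir=subjects_dir
)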
hashes = dict( + lh="9e4d8d6b90242b7e4b0145353436ef77", rh="dd6464db8e7762d969fc1d8087cd211b" + ) downloader = pooch.HTTPDownloader(**_downloader_params()) - for hemi in ('lh', 'rh'): - fname = f'{hemi}.aparc_sub.annot' + for hemi in ("lh", "rh"): + fname = f"{hemi}.aparc_sub.annot" fpath = destination / fname if not fpath.is_file(): pooch.retrieve( @@ -391,8 +417,9 @@ def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): @verbose -def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, - accept=False, verbose=None): +def fetch_hcp_mmp_parcellation( + subjects_dir=None, combine=True, *, accept=False, verbose=None +): """Fetch the HCP-MMP parcellation. This will download and install the HCP-MMP parcellation @@ -425,20 +452,22 @@ def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) destination = subjects_dir / "fsaverage" / "label" fnames = [destination / f"{hemi}.HCPMMP1.annot" for hemi in ("lh", "rh")] - urls = dict(lh='https://ndownloader.figshare.com/files/5528816', - rh='https://ndownloader.figshare.com/files/5528819') - hashes = dict(lh='46a102b59b2fb1bb4bd62d51bf02e975', - rh='75e96b331940227bbcb07c1c791c2463') + urls = dict( + lh="https://ndownloader.figshare.com/files/5528816", + rh="https://ndownloader.figshare.com/files/5528819", + ) + hashes = dict( + lh="46a102b59b2fb1bb4bd62d51bf02e975", rh="75e96b331940227bbcb07c1c791c2463" + ) if not all(fname.exists() for fname in fnames): - if accept or '--accept-hcpmmp-license' in sys.argv: - answer = 'y' + if accept or "--accept-hcpmmp-license" in sys.argv: + answer = "y" else: - answer = _safe_input('%s\nAgree (y/[n])? ' % _hcp_mmp_license_text) - if answer.lower() != 'y': - raise RuntimeError('You must agree to the license to use this ' - 'dataset') + answer = _safe_input("%s\nAgree (y/[n])? 
" % _hcp_mmp_license_text) + if answer.lower() != "y": + raise RuntimeError("You must agree to the license to use this " "dataset") downloader = pooch.HTTPDownloader(**_downloader_params()) - for hemi, fpath in zip(('lh', 'rh'), fnames): + for hemi, fpath in zip(("lh", "rh"), fnames): if not op.isfile(fpath): fname = fpath.name pooch.retrieve( @@ -450,82 +479,255 @@ def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, ) if combine: - fnames = [op.join(destination, '%s.HCPMMP1_combined.annot' % hemi) - for hemi in ('lh', 'rh')] + fnames = [ + op.join(destination, "%s.HCPMMP1_combined.annot" % hemi) + for hemi in ("lh", "rh") + ] if all(op.isfile(fname) for fname in fnames): return # otherwise, let's make them - logger.info('Creating combined labels') - groups = OrderedDict([ - ('Primary Visual Cortex (V1)', - ('V1',)), - ('Early Visual Cortex', - ('V2', 'V3', 'V4')), - ('Dorsal Stream Visual Cortex', - ('V3A', 'V3B', 'V6', 'V6A', 'V7', 'IPS1')), - ('Ventral Stream Visual Cortex', - ('V8', 'VVC', 'PIT', 'FFC', 'VMV1', 'VMV2', 'VMV3')), - ('MT+ Complex and Neighboring Visual Areas', - ('V3CD', 'LO1', 'LO2', 'LO3', 'V4t', 'FST', 'MT', 'MST', 'PH')), - ('Somatosensory and Motor Cortex', - ('4', '3a', '3b', '1', '2')), - ('Paracentral Lobular and Mid Cingulate Cortex', - ('24dd', '24dv', '6mp', '6ma', 'SCEF', '5m', '5L', '5mv',)), - ('Premotor Cortex', - ('55b', '6d', '6a', 'FEF', '6v', '6r', 'PEF')), - ('Posterior Opercular Cortex', - ('43', 'FOP1', 'OP4', 'OP1', 'OP2-3', 'PFcm')), - ('Early Auditory Cortex', - ('A1', 'LBelt', 'MBelt', 'PBelt', 'RI')), - ('Auditory Association Cortex', - ('A4', 'A5', 'STSdp', 'STSda', 'STSvp', 'STSva', 'STGa', 'TA2',)), - ('Insular and Frontal Opercular Cortex', - ('52', 'PI', 'Ig', 'PoI1', 'PoI2', 'FOP2', 'FOP3', - 'MI', 'AVI', 'AAIC', 'Pir', 'FOP4', 'FOP5')), - ('Medial Temporal Cortex', - ('H', 'PreS', 'EC', 'PeEc', 'PHA1', 'PHA2', 'PHA3',)), - ('Lateral Temporal Cortex', - ('PHT', 'TE1p', 'TE1m', 'TE1a', 'TE2p', 'TE2a', - 'TGv', 'TGd', 'TF',)), - ('Temporo-Parieto-Occipital Junction', - ('TPOJ1', 'TPOJ2', 'TPOJ3', 'STV', 'PSL',)), - ('Superior Parietal Cortex', - ('LIPv', 'LIPd', 'VIP', 'AIP', 'MIP', - '7PC', '7AL', '7Am', '7PL', '7Pm',)), - ('Inferior Parietal Cortex', - ('PGp', 'PGs', 'PGi', 'PFm', 'PF', 'PFt', 'PFop', - 'IP0', 'IP1', 'IP2',)), - ('Posterior Cingulate Cortex', - ('DVT', 'ProS', 'POS1', 'POS2', 'RSC', 'v23ab', 'd23ab', - '31pv', '31pd', '31a', '23d', '23c', 'PCV', '7m',)), - ('Anterior Cingulate and Medial Prefrontal Cortex', - ('33pr', 'p24pr', 'a24pr', 'p24', 'a24', 'p32pr', 'a32pr', 'd32', - 'p32', 's32', '8BM', '9m', '10v', '10r', '25',)), - ('Orbital and Polar Frontal Cortex', - ('47s', '47m', 'a47r', '11l', '13l', - 'a10p', 'p10p', '10pp', '10d', 'OFC', 'pOFC',)), - ('Inferior Frontal Cortex', - ('44', '45', 'IFJp', 'IFJa', 'IFSp', 'IFSa', '47l', 'p47r',)), - ('DorsoLateral Prefrontal Cortex', - ('8C', '8Av', 'i6-8', 's6-8', 'SFL', '8BL', '9p', '9a', '8Ad', - 'p9-46v', 'a9-46v', '46', '9-46d',)), - ('???', - ('???',))]) + logger.info("Creating combined labels") + groups = OrderedDict( + [ + ("Primary Visual Cortex (V1)", ("V1",)), + ("Early Visual Cortex", ("V2", "V3", "V4")), + ( + "Dorsal Stream Visual Cortex", + ("V3A", "V3B", "V6", "V6A", "V7", "IPS1"), + ), + ( + "Ventral Stream Visual Cortex", + ("V8", "VVC", "PIT", "FFC", "VMV1", "VMV2", "VMV3"), + ), + ( + "MT+ Complex and Neighboring Visual Areas", + ("V3CD", "LO1", "LO2", "LO3", "V4t", "FST", "MT", "MST", "PH"), + ), + ("Somatosensory and Motor Cortex", ("4", 
"3a", "3b", "1", "2")), + ( + "Paracentral Lobular and Mid Cingulate Cortex", + ( + "24dd", + "24dv", + "6mp", + "6ma", + "SCEF", + "5m", + "5L", + "5mv", + ), + ), + ("Premotor Cortex", ("55b", "6d", "6a", "FEF", "6v", "6r", "PEF")), + ( + "Posterior Opercular Cortex", + ("43", "FOP1", "OP4", "OP1", "OP2-3", "PFcm"), + ), + ("Early Auditory Cortex", ("A1", "LBelt", "MBelt", "PBelt", "RI")), + ( + "Auditory Association Cortex", + ( + "A4", + "A5", + "STSdp", + "STSda", + "STSvp", + "STSva", + "STGa", + "TA2", + ), + ), + ( + "Insular and Frontal Opercular Cortex", + ( + "52", + "PI", + "Ig", + "PoI1", + "PoI2", + "FOP2", + "FOP3", + "MI", + "AVI", + "AAIC", + "Pir", + "FOP4", + "FOP5", + ), + ), + ( + "Medial Temporal Cortex", + ( + "H", + "PreS", + "EC", + "PeEc", + "PHA1", + "PHA2", + "PHA3", + ), + ), + ( + "Lateral Temporal Cortex", + ( + "PHT", + "TE1p", + "TE1m", + "TE1a", + "TE2p", + "TE2a", + "TGv", + "TGd", + "TF", + ), + ), + ( + "Temporo-Parieto-Occipital Junction", + ( + "TPOJ1", + "TPOJ2", + "TPOJ3", + "STV", + "PSL", + ), + ), + ( + "Superior Parietal Cortex", + ( + "LIPv", + "LIPd", + "VIP", + "AIP", + "MIP", + "7PC", + "7AL", + "7Am", + "7PL", + "7Pm", + ), + ), + ( + "Inferior Parietal Cortex", + ( + "PGp", + "PGs", + "PGi", + "PFm", + "PF", + "PFt", + "PFop", + "IP0", + "IP1", + "IP2", + ), + ), + ( + "Posterior Cingulate Cortex", + ( + "DVT", + "ProS", + "POS1", + "POS2", + "RSC", + "v23ab", + "d23ab", + "31pv", + "31pd", + "31a", + "23d", + "23c", + "PCV", + "7m", + ), + ), + ( + "Anterior Cingulate and Medial Prefrontal Cortex", + ( + "33pr", + "p24pr", + "a24pr", + "p24", + "a24", + "p32pr", + "a32pr", + "d32", + "p32", + "s32", + "8BM", + "9m", + "10v", + "10r", + "25", + ), + ), + ( + "Orbital and Polar Frontal Cortex", + ( + "47s", + "47m", + "a47r", + "11l", + "13l", + "a10p", + "p10p", + "10pp", + "10d", + "OFC", + "pOFC", + ), + ), + ( + "Inferior Frontal Cortex", + ( + "44", + "45", + "IFJp", + "IFJa", + "IFSp", + "IFSa", + "47l", + "p47r", + ), + ), + ( + "DorsoLateral Prefrontal Cortex", + ( + "8C", + "8Av", + "i6-8", + "s6-8", + "SFL", + "8BL", + "9p", + "9a", + "8Ad", + "p9-46v", + "a9-46v", + "46", + "9-46d", + ), + ), + ("???", ("???",)), + ] + ) assert len(groups) == 23 labels_out = list() - for hemi in ('lh', 'rh'): - labels = read_labels_from_annot('fsaverage', 'HCPMMP1', hemi=hemi, - subjects_dir=subjects_dir, - sort=False) + for hemi in ("lh", "rh"): + labels = read_labels_from_annot( + "fsaverage", "HCPMMP1", hemi=hemi, subjects_dir=subjects_dir, sort=False + ) label_names = [ - '???' if label.name.startswith('???') else - label.name.split('_')[1] for label in labels] + "???" 
if label.name.startswith("???") else label.name.split("_")[1] + for label in labels + ] used = np.zeros(len(labels), bool) for key, want in groups.items(): - assert '\t' not in key - these_labels = [li for li, label_name in enumerate(label_names) - if label_name in want] + assert "\t" not in key + these_labels = [ + li + for li, label_name in enumerate(label_names) + if label_name in want + ] assert not used[these_labels].any() assert len(these_labels) == len(want) used[these_labels] = True @@ -535,38 +737,47 @@ def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, w = np.array([len(label.vertices) for label in these_labels]) w = w / float(w.sum()) color = np.dot(w, [label.color for label in these_labels]) - these_labels = sum(these_labels, - Label([], subject='fsaverage', hemi=hemi)) + these_labels = sum( + these_labels, Label([], subject="fsaverage", hemi=hemi) + ) these_labels.name = key these_labels.color = color labels_out.append(these_labels) assert used.all() assert len(labels_out) == 46 - for hemi, side in (('lh', 'left'), ('rh', 'right')): - table_name = './%s.fsaverage164.label.gii' % (side,) - write_labels_to_annot(labels_out, 'fsaverage', 'HCPMMP1_combined', - hemi=hemi, subjects_dir=subjects_dir, - sort=False, table_name=table_name) + for hemi, side in (("lh", "left"), ("rh", "right")): + table_name = "./%s.fsaverage164.label.gii" % (side,) + write_labels_to_annot( + labels_out, + "fsaverage", + "HCPMMP1_combined", + hemi=hemi, + subjects_dir=subjects_dir, + sort=False, + table_name=table_name, + ) def _manifest_check_download(manifest_path, destination, url, hash_): import pooch - with open(manifest_path, 'r') as fid: + with open(manifest_path, "r") as fid: names = [name.strip() for name in fid.readlines()] manifest_path = op.basename(manifest_path) need = list() for name in names: if not op.isfile(op.join(destination, name)): need.append(name) - logger.info('%d file%s missing from %s in %s' - % (len(need), _pl(need), manifest_path, destination)) + logger.info( + "%d file%s missing from %s in %s" + % (len(need), _pl(need), manifest_path, destination) + ) if len(need) > 0: downloader = pooch.HTTPDownloader(**_downloader_params()) with tempfile.TemporaryDirectory() as path: - logger.info('Downloading missing files remotely') + logger.info("Downloading missing files remotely") - fname_path = op.join(path, 'temp.zip') + fname_path = op.join(path, "temp.zip") pooch.retrieve( url=url, known_hash=f"md5:{hash_}", @@ -575,36 +786,36 @@ def _manifest_check_download(manifest_path, destination, url, hash_): fname=op.basename(fname_path), ) - logger.info('Extracting missing file%s' % (_pl(need),)) - with zipfile.ZipFile(fname_path, 'r') as ff: - members = set(f for f in ff.namelist() if not f.endswith('/')) + logger.info("Extracting missing file%s" % (_pl(need),)) + with zipfile.ZipFile(fname_path, "r") as ff: + members = set(f for f in ff.namelist() if not f.endswith("/")) missing = sorted(members.symmetric_difference(set(names))) if len(missing): - raise RuntimeError('Zip file did not have correct names:' - '\n%s' % ('\n'.join(missing))) + raise RuntimeError( + "Zip file did not have correct names:" + "\n%s" % ("\n".join(missing)) + ) for name in need: ff.extract(name, path=destination) - logger.info('Successfully extracted %d file%s' - % (len(need), _pl(need))) + logger.info("Successfully extracted %d file%s" % (len(need), _pl(need))) def _log_time_size(t0, sz): t = time.time() - t0 - fmt = '%Ss' + fmt = "%Ss" if t > 60: - fmt = f'%Mm{fmt}' + fmt = f"%Mm{fmt}" if t > 3600: 
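# Editor's illustration (not part of the patch): fetching the HCP-MMP
# parcellation and reading back the 46 combined labels built above;
# accept=True (or the --accept-hcpmmp-license flag) skips the license prompt.
import mne

subjects_dir = mne.datasets.sample.data_path() / "subjects"
mne.datasets.fetch_hcp_mmp_parcellation(subjects_dir=subjects_dir, accept=True)
labels = mne.read_labels_from_annot(
    "fsaverage", parc="HCPMMP1_combined", hemi="both", subjects_dir=subjects_dir
)
assert len(labels) == 46  # 23 groups per hemisphere, as asserted in the code above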
- fmt = f'%Hh{fmt}' + fmt = f"%Hh{fmt}" sz = sz / 1048576 # 1024 ** 2 t = time.strftime(fmt, time.gmtime(t)) - logger.info(f'Download complete in {t} ({sz:.1f} MB)') + logger.info(f"Download complete in {t} ({sz:.1f} MB)") def _downloader_params(*, auth=None, token=None): params = dict() - params['progressbar'] = ( - logger.level <= logging.INFO and - get_config('MNE_TQDM', 'tqdm.auto') != 'off' + params["progressbar"] = ( + logger.level <= logging.INFO and get_config("MNE_TQDM", "tqdm.auto") != "off" ) if auth is not None: params["auth"] = auth diff --git a/mne/datasets/visual_92_categories/visual_92_categories.py b/mne/datasets/visual_92_categories/visual_92_categories.py index df687aafb6c..d5fb1c1c8bb 100644 --- a/mne/datasets/visual_92_categories/visual_92_categories.py +++ b/mne/datasets/visual_92_categories/visual_92_categories.py @@ -1,13 +1,13 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_download_mne_dataset, _data_path_doc, _get_version, - _version_doc) +from ..utils import _download_mne_dataset, _data_path_doc, _get_version, _version_doc @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): """ Get path to local copy of visual_92_categories dataset. @@ -43,18 +43,23 @@ def data_path(path=None, force_update=False, update_path=True, human object recognition in space and time. doi: 10.1038/NN.3635 """ return _download_mne_dataset( - name='visual_92_categories', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="visual_92_categories", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='visual_92_categories', conf='MNE_DATASETS_VISUAL_92_CATEGORIES_PATH') + name="visual_92_categories", conf="MNE_DATASETS_VISUAL_92_CATEGORIES_PATH" +) def get_version(): """Get dataset version.""" - return _get_version('visual_92_categories') + return _get_version("visual_92_categories") -get_version.__doc__ = _version_doc.format(name='visual_92_categories') +get_version.__doc__ = _version_doc.format(name="visual_92_categories") diff --git a/mne/decoding/__init__.py b/mne/decoding/__init__.py index 2b0136256b6..099e3c0dd30 100644 --- a/mne/decoding/__init__.py +++ b/mne/decoding/__init__.py @@ -1,8 +1,13 @@ """Decoding and encoding, including machine learning and receptive fields.""" -from .transformer import (PSDEstimator, Vectorizer, - UnsupervisedSpatialFilter, TemporalFilter, - Scaler, FilterEstimator) +from .transformer import ( + PSDEstimator, + Vectorizer, + UnsupervisedSpatialFilter, + TemporalFilter, + Scaler, + FilterEstimator, +) from .mixin import TransformerMixin from .base import BaseEstimator, LinearModel, get_coef, cross_val_multiscore from .csp import CSP, SPoC diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 9d8070b8179..348ee2ee0f7 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -54,7 +54,8 @@ class LinearModel(BaseEstimator): def __init__(self, model=None): # noqa: D102 if model is None: from sklearn.linear_model import LogisticRegression - model = LogisticRegression(solver='liblinear') + + model = LogisticRegression(solver="liblinear") self.model = model self._estimator_type = getattr(model, "_estimator_type", None) @@ -81,18 +82,22 @@ def fit(self, X, y, **fit_params): """ X, y = np.asarray(X), 
np.asarray(y) if X.ndim != 2: - raise ValueError('LinearModel only accepts 2-dimensional X, got ' - '%s instead.' % (X.shape,)) + raise ValueError( + "LinearModel only accepts 2-dimensional X, got " + "%s instead." % (X.shape,) + ) if y.ndim > 2: - raise ValueError('LinearModel only accepts up to 2-dimensional y, ' - 'got %s instead.' % (y.shape,)) + raise ValueError( + "LinearModel only accepts up to 2-dimensional y, " + "got %s instead." % (y.shape,) + ) # fit the Model self.model.fit(X, y, **fit_params) # Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y - inv_Y = 1. + inv_Y = 1.0 X = X - X.mean(0, keepdims=True) if y.ndim == 2 and y.shape[1] != 1: y = y - y.mean(0, keepdims=True) @@ -103,14 +108,14 @@ def fit(self, X, y, **fit_params): @property def filters_(self): - if hasattr(self.model, 'coef_'): + if hasattr(self.model, "coef_"): # Standard Linear Model filters = self.model.coef_ - elif hasattr(self.model.best_estimator_, 'coef_'): + elif hasattr(self.model.best_estimator_, "coef_"): # Linear Model with GridSearchCV filters = self.model.best_estimator_.coef_ else: - raise ValueError('model does not have a `coef_` attribute.') + raise ValueError("model does not have a `coef_` attribute.") if filters.ndim == 2 and filters.shape[0] == 1: filters = filters[0] return filters @@ -213,60 +218,62 @@ def score(self, X, y): def _set_cv(cv, estimator=None, X=None, y=None): """Set the default CV depending on whether clf is classifier/regressor.""" # Detect whether classification or regression - if estimator in ['classifier', 'regressor']: - est_is_classifier = estimator == 'classifier' + if estimator in ["classifier", "regressor"]: + est_is_classifier = estimator == "classifier" else: est_is_classifier = is_classifier(estimator) # Setup CV from sklearn import model_selection as models - from sklearn.model_selection import (check_cv, StratifiedKFold, KFold) + from sklearn.model_selection import check_cv, StratifiedKFold, KFold + if isinstance(cv, (int, np.int64)): XFold = StratifiedKFold if est_is_classifier else KFold cv = XFold(n_splits=cv) elif isinstance(cv, str): if not hasattr(models, cv): - raise ValueError('Unknown cross-validation') + raise ValueError("Unknown cross-validation") cv = getattr(models, cv) cv = cv() cv = check_cv(cv=cv, y=y, classifier=est_is_classifier) # Extract train and test set to retrieve them at predict time - cv_splits = [(train, test) for train, test in - cv.split(X=np.zeros_like(y), y=y)] + cv_splits = [(train, test) for train, test in cv.split(X=np.zeros_like(y), y=y)] if not np.all([len(train) for train, _ in cv_splits]): - raise ValueError('Some folds do not have any train epochs.') + raise ValueError("Some folds do not have any train epochs.") return cv, cv_splits def _check_estimator(estimator, get_params=True): """Check whether an object has the methods required by sklearn.""" - valid_methods = ('predict', 'transform', 'predict_proba', - 'decision_function') - if ( - (not hasattr(estimator, 'fit')) or - (not any(hasattr(estimator, method) for method in valid_methods)) + valid_methods = ("predict", "transform", "predict_proba", "decision_function") + if (not hasattr(estimator, "fit")) or ( + not any(hasattr(estimator, method) for method in valid_methods) ): - raise ValueError('estimator must be a scikit-learn transformer or ' - 'an estimator with the fit and a predict-like (e.g. 
' - 'predict_proba) or a transform method.') + raise ValueError( + "estimator must be a scikit-learn transformer or " + "an estimator with the fit and a predict-like (e.g. " + "predict_proba) or a transform method." + ) - if get_params and not hasattr(estimator, 'get_params'): - raise ValueError('estimator must be a scikit-learn transformer or an ' - 'estimator with the get_params method that allows ' - 'cloning.') + if get_params and not hasattr(estimator, "get_params"): + raise ValueError( + "estimator must be a scikit-learn transformer or an " + "estimator with the get_params method that allows " + "cloning." + ) def _get_inverse_funcs(estimator, terminal=True): """Retrieve the inverse functions of an pipeline or an estimator.""" inverse_func = [False] - if hasattr(estimator, 'steps'): + if hasattr(estimator, "steps"): # if pipeline, retrieve all steps by nesting inverse_func = list() for _, est in estimator.steps: inverse_func.extend(_get_inverse_funcs(est, terminal=False)) - elif hasattr(estimator, 'inverse_transform'): + elif hasattr(estimator, "inverse_transform"): # if not pipeline attempt to retrieve inverse function inverse_func = [estimator.inverse_transform] @@ -284,7 +291,7 @@ def _get_inverse_funcs(estimator, terminal=True): return inverse_func -def get_coef(estimator, attr='filters_', inverse_transform=False): +def get_coef(estimator, attr="filters_", inverse_transform=False): """Retrieve the coefficients of an estimator ending with a Linear Model. This is typically useful to retrieve "spatial filters" or "spatial @@ -312,13 +319,13 @@ def get_coef(estimator, attr='filters_', inverse_transform=False): """ # Get the coefficients of the last estimator in case of nested pipeline est = estimator - while hasattr(est, 'steps'): + while hasattr(est, "steps"): est = est.steps[-1][1] squeeze_first_dim = False # If SlidingEstimator, loop across estimators - if hasattr(est, 'estimators_'): + if hasattr(est, "estimators_"): coef = list() for this_est in est.estimators_: coef.append(get_coef(this_est, attr, inverse_transform)) @@ -326,8 +333,9 @@ def get_coef(estimator, attr='filters_', inverse_transform=False): coef = coef[np.newaxis] # fake a sample dimension squeeze_first_dim = True elif not hasattr(est, attr): - raise ValueError('This estimator does not have a %s attribute:\n%s' - % (attr, est)) + raise ValueError( + "This estimator does not have a %s attribute:\n%s" % (attr, est) + ) else: coef = getattr(est, attr) @@ -337,9 +345,10 @@ def get_coef(estimator, attr='filters_', inverse_transform=False): # inverse pattern e.g. to get back physical units if inverse_transform: - if not hasattr(estimator, 'steps') and not hasattr(est, 'estimators_'): - raise ValueError('inverse_transform can only be applied onto ' - 'pipeline estimators.') + if not hasattr(estimator, "steps") and not hasattr(est, "estimators_"): + raise ValueError( + "inverse_transform can only be applied onto " "pipeline estimators." + ) # The inverse_transform parameter will call this method on any # estimator contained in the pipeline, in reverse order. 
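# Editor's illustration (not part of the patch): the documented use of
# get_coef() with a pipeline ending in LinearModel; data and shapes are
# made up for the sketch.
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from mne.decoding import LinearModel, get_coef

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 32))   # 100 samples, 32 channels
y = rng.integers(0, 2, 100)          # binary labels
clf = make_pipeline(StandardScaler(), LinearModel()).fit(X, y)
# Haufe patterns, projected back through the scaler's inverse_transform:
patterns = get_coef(clf, "patterns_", inverse_transform=True)  # shape (32,)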
for inverse_func in _get_inverse_funcs(estimator)[::-1]: @@ -352,9 +361,18 @@ def get_coef(estimator, attr='filters_', inverse_transform=False): @verbose -def cross_val_multiscore(estimator, X, y=None, groups=None, scoring=None, - cv=None, n_jobs=None, verbose=None, fit_params=None, - pre_dispatch='2*n_jobs'): +def cross_val_multiscore( + estimator, + X, + y=None, + groups=None, + scoring=None, + cv=None, + n_jobs=None, + verbose=None, + fit_params=None, + pre_dispatch="2*n_jobs", +): """Evaluate a score by cross-validation. Parameters @@ -420,6 +438,7 @@ def cross_val_multiscore(estimator, X, y=None, groups=None, scoring=None, from sklearn.base import clone from sklearn.utils import indexable from sklearn.model_selection._split import check_cv + check_scoring = _get_check_scoring() X, y, groups = indexable(X, y, groups) @@ -430,15 +449,23 @@ def cross_val_multiscore(estimator, X, y=None, groups=None, scoring=None, # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. # Note: this parallelization is implemented using MNE Parallel - parallel, p_func, n_jobs = parallel_func(_fit_and_score, n_jobs, - pre_dispatch=pre_dispatch) - position = hasattr(estimator, 'position') + parallel, p_func, n_jobs = parallel_func( + _fit_and_score, n_jobs, pre_dispatch=pre_dispatch + ) + position = hasattr(estimator, "position") scores = parallel( p_func( - estimator=clone(estimator), X=X, y=y, scorer=scorer, train=train, - test=test, fit_params=fit_params, verbose=verbose, + estimator=clone(estimator), + X=X, + y=y, + scorer=scorer, + train=train, + test=test, + fit_params=fit_params, + verbose=verbose, parameters=dict(position=ii % n_jobs) if position else None, - ) for ii, (train, test) in enumerate(cv_iter) + ) + for ii, (train, test) in enumerate(cv_iter) ) return np.array(scores)[:, 0, ...] # flatten over joblib output. @@ -446,11 +473,24 @@ def cross_val_multiscore(estimator, X, y=None, groups=None, scoring=None, # This verbose is necessary to properly set the verbosity level # during parallelization @verbose -def _fit_and_score(estimator, X, y, scorer, train, test, - parameters, fit_params, return_train_score=False, - return_parameters=False, return_n_test_samples=False, - return_times=False, error_score='raise', *, verbose=None, - position=0): +def _fit_and_score( + estimator, + X, + y, + scorer, + train, + test, + parameters, + fit_params, + return_train_score=False, + return_parameters=False, + return_n_test_samples=False, + return_times=False, + error_score="raise", + *, + verbose=None, + position=0 +): """Fit estimator and compute scores for a given dataset split.""" # This code is adapted from sklearn from ..fixes import _check_fit_params @@ -479,19 +519,23 @@ def _fit_and_score(estimator, X, y, scorer, train, test, # Note fit time as time until error fit_duration = dt.datetime.now() - start_time score_duration = dt.timedelta(0) - if error_score == 'raise': + if error_score == "raise": raise elif isinstance(error_score, numbers.Number): test_score = error_score if return_train_score: train_score = error_score - warn("Classifier fit failed. The score on this train-test" - " partition for these parameters will be set to %f. " - "Details: \n%r" % (error_score, e)) + warn( + "Classifier fit failed. The score on this train-test" + " partition for these parameters will be set to %f. " + "Details: \n%r" % (error_score, e) + ) else: - raise ValueError("error_score must be the string 'raise' or a" - " numeric value. 
(Hint: if using 'raise', please" - " make sure that it has been spelled correctly.)") + raise ValueError( + "error_score must be the string 'raise' or a" + " numeric value. (Hint: if using 'raise', please" + " make sure that it has been spelled correctly.)" + ) else: fit_duration = dt.datetime.now() - start_time @@ -505,10 +549,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, if return_n_test_samples: ret.append(_num_samples(X_test)) if return_times: - ret.extend([ - fit_duration.total_seconds(), - score_duration.total_seconds() - ]) + ret.extend([fit_duration.total_seconds(), score_duration.total_seconds()]) if return_parameters: ret.append(parameters) return ret @@ -524,7 +565,7 @@ def _score(estimator, X_test, y_test, scorer): score = scorer(estimator, X_test) else: score = scorer(estimator, X_test, y_test) - if hasattr(score, 'item'): + if hasattr(score, "item"): try: # e.g. unwrap memmapped scalars score = score.item() diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 6e3ed67c163..f4c74ad6d91 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -13,8 +13,7 @@ from .base import BaseEstimator from .mixin import TransformerMixin from ..cov import _regularized_covariance -from ..defaults import (_BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, - _INTERPOLATION_DEFAULT) +from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ..fixes import pinv from ..utils import fill_doc, _check_option, _validate_type, copy_doc @@ -97,13 +96,21 @@ class CSP(TransformerMixin, BaseEstimator): .. footbibliography:: """ - def __init__(self, n_components=4, reg=None, log=None, cov_est='concat', - transform_into='average_power', norm_trace=False, - cov_method_params=None, rank=None, - component_order='mutual_info'): + def __init__( + self, + n_components=4, + reg=None, + log=None, + cov_est="concat", + transform_into="average_power", + norm_trace=False, + cov_method_params=None, + rank=None, + component_order="mutual_info", + ): # Init default CSP if not isinstance(n_components, int): - raise ValueError('n_components must be an integer.') + raise ValueError("n_components must be an integer.") self.n_components = n_components self.rank = rank self.reg = reg @@ -114,37 +121,39 @@ def __init__(self, n_components=4, reg=None, log=None, cov_est='concat', self.cov_est = cov_est # Init default transform_into - self.transform_into = _check_option('transform_into', transform_into, - ['average_power', 'csp_space']) + self.transform_into = _check_option( + "transform_into", transform_into, ["average_power", "csp_space"] + ) # Init default log - if transform_into == 'average_power': + if transform_into == "average_power": if log is not None and not isinstance(log, bool): - raise ValueError('log must be a boolean if transform_into == ' - '"average_power".') + raise ValueError( + "log must be a boolean if transform_into == " '"average_power".' + ) else: if log is not None: - raise ValueError('log must be a None if transform_into == ' - '"csp_space".') + raise ValueError( + "log must be a None if transform_into == " '"csp_space".' 
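# Editor's illustration (not part of the patch): cross_val_multiscore() on
# epoched data with a scikit-learn pipeline; random data, shapes illustrative.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from mne.decoding import Vectorizer, cross_val_multiscore

rng = np.random.default_rng(0)
X = rng.standard_normal((60, 8, 50))  # epochs x channels x times
y = rng.integers(0, 2, 60)
clf = make_pipeline(Vectorizer(), LogisticRegression(solver="liblinear"))
scores = cross_val_multiscore(clf, X, y, cv=5, n_jobs=1)  # shape (5,)
print(scores.mean())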
+ ) self.log = log - _validate_type(norm_trace, bool, 'norm_trace') + _validate_type(norm_trace, bool, "norm_trace") self.norm_trace = norm_trace self.cov_method_params = cov_method_params - self.component_order = _check_option('component_order', - component_order, - ('mutual_info', 'alternate')) + self.component_order = _check_option( + "component_order", component_order, ("mutual_info", "alternate") + ) def _check_Xy(self, X, y=None): """Check input data.""" if not isinstance(X, np.ndarray): - raise ValueError("X should be of type ndarray (got %s)." - % type(X)) + raise ValueError("X should be of type ndarray (got %s)." % type(X)) if y is not None: if len(X) != len(y) or len(y) < 1: - raise ValueError('X and y must have the same length.') + raise ValueError("X and y must have the same length.") if X.ndim < 3: - raise ValueError('X must have at least 3 dimensions.') + raise ValueError("X must have at least 3 dimensions.") def fit(self, X, y): """Estimate the CSP decomposition on epochs. @@ -167,28 +176,30 @@ def fit(self, X, y): n_classes = len(self._classes) if n_classes < 2: raise ValueError("n_classes must be >= 2.") - if n_classes > 2 and self.component_order == 'alternate': - raise ValueError("component_order='alternate' requires two " - "classes, but data contains {} classes; use " - "component_order='mutual_info' " - "instead.".format(n_classes)) + if n_classes > 2 and self.component_order == "alternate": + raise ValueError( + "component_order='alternate' requires two " + "classes, but data contains {} classes; use " + "component_order='mutual_info' " + "instead.".format(n_classes) + ) covs, sample_weights = self._compute_covariance_matrices(X, y) - eigen_vectors, eigen_values = self._decompose_covs(covs, - sample_weights) - ix = self._order_components(covs, sample_weights, eigen_vectors, - eigen_values, self.component_order) + eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights) + ix = self._order_components( + covs, sample_weights, eigen_vectors, eigen_values, self.component_order + ) eigen_vectors = eigen_vectors[:, ix] self.filters_ = eigen_vectors.T self.patterns_ = pinv(eigen_vectors) - pick_filters = self.filters_[:self.n_components] + pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) # compute features (mean power) - X = (X ** 2).mean(axis=2) + X = (X**2).mean(axis=2) # To standardize features self.mean_ = X.mean(axis=0) @@ -215,15 +226,16 @@ def transform(self, X): if not isinstance(X, np.ndarray): raise ValueError("X should be of type ndarray (got %s)." % type(X)) if self.filters_ is None: - raise RuntimeError('No filters available. Please first fit CSP ' - 'decomposition.') + raise RuntimeError( + "No filters available. Please first fit CSP " "decomposition." 
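# Editor's illustration (not part of the patch): a minimal CSP round trip on
# random two-class data; with the defaults (transform_into="average_power",
# log unset), fit_transform() returns log band power per component.
import numpy as np
from mne.decoding import CSP

rng = np.random.default_rng(0)
X = rng.standard_normal((40, 8, 200))  # epochs x channels x times
y = np.repeat([0, 1], 20)
csp = CSP(n_components=4)
features = csp.fit_transform(X, y)     # shape (40, 4)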
+ ) - pick_filters = self.filters_[:self.n_components] + pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) # compute features (mean band power) - if self.transform_into == 'average_power': - X = (X ** 2).mean(axis=2) + if self.transform_into == "average_power": + X = (X**2).mean(axis=2) log = True if self.log is None else self.log if log: X = np.log(X) @@ -238,14 +250,37 @@ def fit_transform(self, X, y, **fit_params): # noqa: D102 @fill_doc def plot_patterns( - self, info, components=None, *, average=None, ch_type=None, - scalings=None, sensors=True, show_names=False, mask=None, - mask_params=None, contours=6, outlines='head', sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap='RdBu_r', vlim=(None, None), cnorm=None, - colorbar=True, cbar_fmt='%3.1f', units=None, axes=None, - name_format='CSP%01d', nrows=1, ncols='auto', show=True): + self, + info, + components=None, + *, + average=None, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format="CSP%01d", + nrows=1, + ncols="auto", + show=True + ): """Plot topographic patterns of components. The patterns explain how the measured data was generated from the @@ -304,38 +339,81 @@ def plot_patterns( from .. import EvokedArray if units is None: - units = 'AU' + units = "AU" if components is None: components = np.arange(self.n_components) # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): - info['sfreq'] = 1. 
+ info["sfreq"] = 1.0 # create an evoked patterns = EvokedArray(self.patterns_.T, info, tmin=0) # the call plot_topomap fig = patterns.plot_topomap( - times=components, average=average, ch_type=ch_type, - scalings=scalings, sensors=sensors, show_names=show_names, - mask=mask, mask_params=mask_params, contours=contours, - outlines=outlines, sphere=sphere, image_interp=image_interp, - extrapolate=extrapolate, border=border, res=res, size=size, - cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, - cbar_fmt=cbar_fmt, units=units, axes=axes, time_format=name_format, - nrows=nrows, ncols=ncols, show=show) + times=components, + average=average, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) return fig @fill_doc def plot_filters( - self, info, components=None, *, average=None, ch_type=None, - scalings=None, sensors=True, show_names=False, mask=None, - mask_params=None, contours=6, outlines='head', sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap='RdBu_r', vlim=(None, None), cnorm=None, - colorbar=True, cbar_fmt='%3.1f', units=None, axes=None, - name_format='CSP%01d', nrows=1, ncols='auto', show=True): + self, + info, + components=None, + *, + average=None, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format="CSP%01d", + nrows=1, + ncols="auto", + show=True + ): """Plot topographic filters of components. The filters are used to extract discriminant neural sources from @@ -394,26 +472,46 @@ def plot_filters( from .. import EvokedArray if units is None: - units = 'AU' + units = "AU" if components is None: components = np.arange(self.n_components) # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): - info['sfreq'] = 1. 
+ info["sfreq"] = 1.0 # create an evoked filters = EvokedArray(self.filters_.T, info, tmin=0) # the call plot_topomap fig = filters.plot_topomap( - times=components, average=average, ch_type=ch_type, - scalings=scalings, sensors=sensors, show_names=show_names, - mask=mask, mask_params=mask_params, contours=contours, - outlines=outlines, sphere=sphere, image_interp=image_interp, - extrapolate=extrapolate, border=border, res=res, size=size, - cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, - cbar_fmt=cbar_fmt, units=units, axes=axes, time_format=name_format, - nrows=nrows, ncols=ncols, show=show) + times=components, + average=average, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) return fig def _compute_covariance_matrices(self, X, y): @@ -444,18 +542,23 @@ def _concat_cov(self, x_class): x_class = np.transpose(x_class, [1, 0, 2]) x_class = x_class.reshape(n_channels, -1) cov = _regularized_covariance( - x_class, reg=self.reg, method_params=self.cov_method_params, - rank=self.rank) + x_class, reg=self.reg, method_params=self.cov_method_params, rank=self.rank + ) weight = x_class.shape[0] return cov, weight def _epoch_cov(self, x_class): """Mean of per-epoch covariances.""" - cov = sum(_regularized_covariance( - this_X, reg=self.reg, - method_params=self.cov_method_params, - rank=self.rank) for this_X in x_class) + cov = sum( + _regularized_covariance( + this_X, + reg=self.reg, + method_params=self.cov_method_params, + rank=self.rank, + ) + for this_X in x_class + ) cov /= len(x_class) weight = len(x_class) @@ -463,6 +566,7 @@ def _epoch_cov(self, x_class): def _decompose_covs(self, covs, sample_weights): from scipy import linalg + n_classes = len(covs) if n_classes == 2: eigen_values, eigen_vectors = linalg.eigh(covs[0], covs.sum(0)) @@ -470,8 +574,9 @@ def _decompose_covs(self, covs, sample_weights): # The multiclass case is adapted from # http://github.com/alexandrebarachant/pyRiemann eigen_vectors, D = _ajd_pham(covs) - eigen_vectors = self._normalize_eigenvectors(eigen_vectors.T, covs, - sample_weights) + eigen_vectors = self._normalize_eigenvectors( + eigen_vectors.T, covs, sample_weights + ) eigen_values = None return eigen_vectors, eigen_values @@ -481,12 +586,11 @@ def _compute_mutual_info(self, covs, sample_weights, eigen_vectors): mutual_info = [] for jj in range(eigen_vectors.shape[1]): aa, bb = 0, 0 - for (cov, prob) in zip(covs, class_probas): - tmp = np.dot(np.dot(eigen_vectors[:, jj].T, cov), - eigen_vectors[:, jj]) + for cov, prob in zip(covs, class_probas): + tmp = np.dot(np.dot(eigen_vectors[:, jj].T, cov), eigen_vectors[:, jj]) aa += prob * np.log(np.sqrt(tmp)) - bb += prob * (tmp ** 2 - 1) - mi = - (aa + (3.0 / 16) * (bb ** 2)) + bb += prob * (tmp**2 - 1) + mi = -(aa + (3.0 / 16) * (bb**2)) mutual_info.append(mi) return mutual_info @@ -496,25 +600,24 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights): mean_cov = np.average(covs, axis=0, weights=sample_weights) for ii in range(eigen_vectors.shape[1]): - tmp = np.dot(np.dot(eigen_vectors[:, ii].T, mean_cov), - eigen_vectors[:, ii]) + tmp = np.dot(np.dot(eigen_vectors[:, ii].T, 
mean_cov), eigen_vectors[:, ii]) eigen_vectors[:, ii] /= np.sqrt(tmp) return eigen_vectors - def _order_components(self, covs, sample_weights, eigen_vectors, - eigen_values, component_order): + def _order_components( + self, covs, sample_weights, eigen_vectors, eigen_values, component_order + ): n_classes = len(self._classes) - if component_order == 'mutual_info' and n_classes > 2: - mutual_info = self._compute_mutual_info(covs, sample_weights, - eigen_vectors) + if component_order == "mutual_info" and n_classes > 2: + mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors) ix = np.argsort(mutual_info)[::-1] - elif component_order == 'mutual_info' and n_classes == 2: + elif component_order == "mutual_info" and n_classes == 2: ix = np.argsort(np.abs(eigen_values - 0.5))[::-1] - elif component_order == 'alternate' and n_classes == 2: + elif component_order == "alternate" and n_classes == 2: i = np.argsort(eigen_values) ix = np.empty_like(i) - ix[1::2] = i[:len(i) // 2] - ix[0::2] = i[len(i) // 2:][::-1] + ix[1::2] = i[: len(i) // 2] + ix[0::2] = i[len(i) // 2 :][::-1] return ix @@ -583,16 +686,16 @@ def _ajd_pham(X, eps=1e-6, max_iter=15): decr += n_epochs * (g12 * np.conj(h12) + g21 * h21) / 2.0 - tmp = 1 + 1.j * 0.5 * np.imag(h12 * h21) - tmp = np.real(tmp + np.sqrt(tmp ** 2 - h12 * h21)) + tmp = 1 + 1.0j * 0.5 * np.imag(h12 * h21) + tmp = np.real(tmp + np.sqrt(tmp**2 - h12 * h21)) tau = np.array([[1, -h12 / tmp], [-h21 / tmp, 1]]) A[[ii, jj], :] = np.dot(tau, A[[ii, jj], :]) tmp = np.c_[A[:, Ii], A[:, Ij]] - tmp = np.reshape(tmp, (n_times * n_epochs, 2), order='F') + tmp = np.reshape(tmp, (n_times * n_epochs, 2), order="F") tmp = np.dot(tmp, tau.T) - tmp = np.reshape(tmp, (n_times, n_epochs * 2), order='F') + tmp = np.reshape(tmp, (n_times, n_epochs * 2), order="F") A[:, Ii] = tmp[:, :n_epochs] A[:, Ij] = tmp[:, n_epochs:] V[[ii, jj], :] = np.dot(tau, V[[ii, jj], :]) @@ -663,19 +766,31 @@ class SPoC(CSP): .. footbibliography:: """ - def __init__(self, n_components=4, reg=None, log=None, - transform_into='average_power', cov_method_params=None, - rank=None): + def __init__( + self, + n_components=4, + reg=None, + log=None, + transform_into="average_power", + cov_method_params=None, + rank=None, + ): """Init of SPoC.""" - super(SPoC, self).__init__(n_components=n_components, reg=reg, log=log, - cov_est="epoch", norm_trace=False, - transform_into=transform_into, rank=rank, - cov_method_params=cov_method_params) + super(SPoC, self).__init__( + n_components=n_components, + reg=reg, + log=log, + cov_est="epoch", + norm_trace=False, + transform_into=transform_into, + rank=rank, + cov_method_params=cov_method_params, + ) # Covariance estimation have to be done on the single epoch level, # unlike CSP where covariance estimation can also be achieved through # concatenation of all epochs from the same class. - delattr(self, 'cov_est') - delattr(self, 'norm_trace') + delattr(self, "cov_est") + delattr(self, "norm_trace") def fit(self, X, y): """Estimate the SPoC decomposition on epochs. @@ -693,6 +808,7 @@ def fit(self, X, y): Returns the modified instance. 
""" from scipy import linalg + self._check_Xy(X, y) if len(np.unique(y)) < 2: @@ -711,8 +827,11 @@ def fit(self, X, y): covs = np.empty((n_epochs, n_channels, n_channels)) for ii, epoch in enumerate(X): covs[ii] = _regularized_covariance( - epoch, reg=self.reg, method_params=self.cov_method_params, - rank=self.rank) + epoch, + reg=self.reg, + method_params=self.cov_method_params, + rank=self.rank, + ) C = covs.mean(0) Cz = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0) @@ -731,11 +850,11 @@ def fit(self, X, y): self.patterns_ = linalg.pinv(evecs).T # n_channels x n_channels self.filters_ = evecs # n_channels x n_channels - pick_filters = self.filters_[:self.n_components] + pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) # compute features (mean band power) - X = (X ** 2).mean(axis=-1) + X = (X**2).mean(axis=-1) # To standardize features self.mean_ = X.mean(axis=0) diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py index 3f125bfb74a..f0dabe4d681 100644 --- a/mne/decoding/ems.py +++ b/mne/decoding/ems.py @@ -41,11 +41,13 @@ class EMS(TransformerMixin, EstimatorMixin): """ def __repr__(self): # noqa: D105 - if hasattr(self, 'filters_'): - return '' % ( - len(self.filters_), len(self.classes_)) + if hasattr(self, "filters_"): + return "" % ( + len(self.filters_), + len(self.classes_), + ) else: - return '' + return "" def fit(self, X, y): """Fit the spatial filters. @@ -67,7 +69,7 @@ def fit(self, X, y): """ classes = np.unique(y) if len(classes) != 2: - raise ValueError('EMS only works for binary classification.') + raise ValueError("EMS only works for binary classification.") self.classes_ = classes filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0) filters /= np.linalg.norm(filters, axis=0)[None, :] @@ -92,8 +94,9 @@ def transform(self, X): @verbose -def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, - verbose=None): +def compute_ems( + epochs, conditions=None, picks=None, n_jobs=None, cv=None, verbose=None +): """Compute event-matched spatial filter on epochs. This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire @@ -141,16 +144,18 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, ---------- .. footbibliography:: """ - logger.info('...computing surrogate time series. This can take some time') + logger.info("...computing surrogate time series. This can take some time") # Default to leave-one-out cv - cv = 'LeaveOneOut' if cv is None else cv + cv = "LeaveOneOut" if cv is None else cv picks = _picks_to_idx(epochs.info, picks) if not len(set(Counter(epochs.events[:, 2]).values())) == 1: - raise ValueError('The same number of epochs is required by ' - 'this function. Please consider ' - '`epochs.equalize_event_counts`') + raise ValueError( + "The same number of epochs is required by " + "this function. 
Please consider " + "`epochs.equalize_event_counts`" + ) if conditions is None: conditions = epochs.event_id.keys() @@ -161,9 +166,10 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, epochs.drop_bad() if len(conditions) != 2: - raise ValueError('Currently this function expects exactly 2 ' - 'conditions but you gave me %i' % - len(conditions)) + raise ValueError( + "Currently this function expects exactly 2 " + "conditions but you gave me %i" % len(conditions) + ) ev = epochs.events[:, 2] # Special care to avoid path dependent mappings and orders @@ -175,10 +181,10 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, # Scale (z-score) the data by channel type # XXX the z-scoring is applied outside the CV, which is not standard. - for ch_type in ['mag', 'grad', 'eeg']: + for ch_type in ["mag", "grad", "eeg"]: if ch_type in epochs: # FIXME should be applied to all sort of data channels - if ch_type == 'eeg': + if ch_type == "eeg": this_picks = pick_types(info, meg=False, eeg=True) else: this_picks = pick_types(info, meg=ch_type, eeg=False) @@ -187,15 +193,16 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, # Setup cross-validation. Need to use _set_cv to deal with sklearn # deprecation of cv objects. y = epochs.events[:, 2] - _, cv_splits = _set_cv(cv, 'classifier', X=y, y=y) + _, cv_splits = _set_cv(cv, "classifier", X=y, y=y) parallel, p_func, n_jobs = parallel_func(_run_ems, n_jobs=n_jobs) # FIXME this parallelization should be removed. # 1) it's numpy computation so it's already efficient, # 2) it duplicates the data in RAM, # 3) the computation is already super fast. - out = parallel(p_func(_ems_diff, data, cond_idx, train, test) - for train, test in cv_splits) + out = parallel( + p_func(_ems_diff, data, cond_idx, train, test) for train, test in cv_splits + ) surrogate_trials, spatial_filter = zip(*out) surrogate_trials = np.array(surrogate_trials) @@ -212,6 +219,6 @@ def _ems_diff(data0, data1): def _run_ems(objective_function, data, cond_idx, train, test): """Run EMS.""" d = objective_function(*(data[np.intersect1d(c, train)] for c in cond_idx)) - d /= np.sqrt(np.sum(d ** 2, axis=0))[None, :] + d /= np.sqrt(np.sum(d**2, axis=0))[None, :] # compute surrogates return np.sum(data[test[0]] * d, axis=0), d diff --git a/mne/decoding/mixin.py b/mne/decoding/mixin.py index c000ae4b74d..d009e0a23ba 100644 --- a/mne/decoding/mixin.py +++ b/mne/decoding/mixin.py @@ -61,23 +61,26 @@ def set_params(self, **params): return self valid_params = self.get_params(deep=True) for key, value in params.items(): - split = key.split('__', 1) + split = key.split("__", 1) if len(split) > 1: # nested objects case name, sub_name = split if name not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (name, self)) + raise ValueError( + "Invalid parameter %s for estimator %s. " + "Check the list of available parameters " + "with `estimator.get_params().keys()`." % (name, self) + ) sub_object = valid_params[name] sub_object.set_params(**{sub_name: value}) else: # simple objects case if key not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (key, self.__class__.__name__)) + raise ValueError( + "Invalid parameter %s for estimator %s. 
" + "Check the list of available parameters " + "with `estimator.get_params().keys()`." + % (key, self.__class__.__name__) + ) setattr(self, key, value) return self diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index cf6e6dd35bc..6fa38a4f72f 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -103,14 +103,25 @@ class ReceptiveField(BaseEstimator): """ # noqa E501 @verbose - def __init__(self, tmin, tmax, sfreq, feature_names=None, estimator=None, - fit_intercept=None, scoring='r2', patterns=False, - n_jobs=None, edge_correction=True, verbose=None): + def __init__( + self, + tmin, + tmax, + sfreq, + feature_names=None, + estimator=None, + fit_intercept=None, + scoring="r2", + patterns=False, + n_jobs=None, + edge_correction=True, + verbose=None, + ): self.feature_names = feature_names self.sfreq = float(sfreq) self.tmin = tmin self.tmax = tmax - self.estimator = 0. if estimator is None else estimator + self.estimator = 0.0 if estimator is None else estimator self.fit_intercept = fit_intercept self.scoring = scoring self.patterns = patterns @@ -123,7 +134,7 @@ def __repr__(self): # noqa: D105 if not isinstance(estimator, str): estimator = type(self.estimator) s += "estimator : %s, " % (estimator,) - if hasattr(self, 'coef_'): + if hasattr(self, "coef_"): if self.feature_names is not None: feats = self.feature_names if len(feats) == 1: @@ -133,7 +144,7 @@ def __repr__(self): # noqa: D105 s += "fit: True" else: s += "fit: False" - if hasattr(self, 'scores_'): + if hasattr(self, "scores_"): s += "scored (%s)" % self.scoring return "" % s @@ -141,12 +152,13 @@ def _delay_and_reshape(self, X, y=None): """Delay and reshape the variables.""" if not isinstance(self.estimator_, TimeDelayingRidge): # X is now shape (n_times, n_epochs, n_feats, n_delays) - X = _delay_time_series(X, self.tmin, self.tmax, self.sfreq, - fill_mean=self.fit_intercept) + X = _delay_time_series( + X, self.tmin, self.tmax, self.sfreq, fill_mean=self.fit_intercept + ) X = _reshape_for_est(X) # Concat times + epochs if y is not None: - y = y.reshape(-1, y.shape[-1], order='F') + y = y.reshape(-1, y.shape[-1], order="F") return X, y def fit(self, X, y): @@ -165,15 +177,20 @@ def fit(self, X, y): The instance so you can chain operations. 
""" from scipy import linalg + if self.scoring not in _SCORERS.keys(): - raise ValueError('scoring must be one of %s, got' - '%s ' % (sorted(_SCORERS.keys()), self.scoring)) + raise ValueError( + "scoring must be one of %s, got" + "%s " % (sorted(_SCORERS.keys()), self.scoring) + ) from sklearn.base import clone + X, y, _, self._y_dim = self._check_dimensions(X, y) if self.tmin > self.tmax: - raise ValueError('tmin (%s) must be at most tmax (%s)' - % (self.tmin, self.tmax)) + raise ValueError( + "tmin (%s) must be at most tmax (%s)" % (self.tmin, self.tmax) + ) # Initialize delays self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq) @@ -184,23 +201,33 @@ def fit(self, X, y): if self.fit_intercept is None: self.fit_intercept = True estimator = TimeDelayingRidge( - self.tmin, self.tmax, self.sfreq, alpha=self.estimator, - fit_intercept=self.fit_intercept, n_jobs=self.n_jobs, - edge_correction=self.edge_correction) + self.tmin, + self.tmax, + self.sfreq, + alpha=self.estimator, + fit_intercept=self.fit_intercept, + n_jobs=self.n_jobs, + edge_correction=self.edge_correction, + ) elif is_regressor(self.estimator): estimator = clone(self.estimator) - if self.fit_intercept is not None and \ - estimator.fit_intercept != self.fit_intercept: + if ( + self.fit_intercept is not None + and estimator.fit_intercept != self.fit_intercept + ): raise ValueError( - 'Estimator fit_intercept (%s) != initialization ' - 'fit_intercept (%s), initialize ReceptiveField with the ' - 'same fit_intercept value or use fit_intercept=None' - % (estimator.fit_intercept, self.fit_intercept)) + "Estimator fit_intercept (%s) != initialization " + "fit_intercept (%s), initialize ReceptiveField with the " + "same fit_intercept value or use fit_intercept=None" + % (estimator.fit_intercept, self.fit_intercept) + ) self.fit_intercept = estimator.fit_intercept else: - raise ValueError('`estimator` must be a float or an instance' - ' of `BaseEstimator`,' - ' got type %s.' % type(self.estimator)) + raise ValueError( + "`estimator` must be a float or an instance" + " of `BaseEstimator`," + " got type %s." % type(self.estimator) + ) self.estimator_ = estimator del estimator _check_estimator(self.estimator_) @@ -211,16 +238,17 @@ def fit(self, X, y): n_delays = len(self.delays_) # Update feature names if we have none - if ((self.feature_names is not None) and - (len(self.feature_names) != n_feats)): - raise ValueError('n_features in X does not match feature names ' - '(%s != %s)' % (n_feats, len(self.feature_names))) + if (self.feature_names is not None) and (len(self.feature_names) != n_feats): + raise ValueError( + "n_features in X does not match feature names " + "(%s != %s)" % (n_feats, len(self.feature_names)) + ) # Create input features X, y = self._delay_and_reshape(X, y) self.estimator_.fit(X, y) - coef = get_coef(self.estimator_, 'coef_') # (n_targets, n_features) + coef = get_coef(self.estimator_, "coef_") # (n_targets, n_features) shape = [n_feats, n_delays] if self._y_dim > 1: shape.insert(0, -1) @@ -230,7 +258,7 @@ def fit(self, X, y): if self.patterns: if isinstance(self.estimator_, TimeDelayingRidge): cov_ = self.estimator_.cov_ / float(n_times * n_epochs - 1) - y = y.reshape(-1, y.shape[-1], order='F') + y = y.reshape(-1, y.shape[-1], order="F") else: X = X - X.mean(0, keepdims=True) cov_ = np.cov(X.T) @@ -241,7 +269,7 @@ def fit(self, X, y): y = y - y.mean(0, keepdims=True) inv_Y = linalg.pinv(np.cov(y.T)) else: - inv_Y = 1. 
/ float(n_times * n_epochs - 1) + inv_Y = 1.0 / float(n_times * n_epochs - 1) del y # Inverse coef according to Haufe's method @@ -267,8 +295,8 @@ def predict(self, X): unaffected by edge artifacts during the time delaying step) can be obtained using ``y_pred[rf.valid_samples_]``. """ - if not hasattr(self, 'delays_'): - raise ValueError('Estimator has not been fit yet.') + if not hasattr(self, "delays_"): + raise ValueError("Estimator has not been fit yet.") X, _, X_dim = self._check_dimensions(X, None, predict=True)[:3] del _ # convert to sklearn and back @@ -277,14 +305,14 @@ def predict(self, X): pred_shape = pred_shape + (self.coef_.shape[0],) X, _ = self._delay_and_reshape(X) y_pred = self.estimator_.predict(X) - y_pred = y_pred.reshape(pred_shape, order='F') + y_pred = y_pred.reshape(pred_shape, order="F") shape = list(y_pred.shape) if X_dim <= 2: shape.pop(1) # epochs extra = 0 else: extra = 1 - shape = shape[:self._y_dim + extra] + shape = shape[: self._y_dim + extra] y_pred.shape = shape return y_pred @@ -319,10 +347,10 @@ def score(self, X, y): y = y[self.valid_samples_] # Re-vectorize and call scorer - y = y.reshape([-1, n_outputs], order='F') - y_pred = y_pred.reshape([-1, n_outputs], order='F') + y = y.reshape([-1, n_outputs], order="F") + y_pred = y_pred.reshape([-1, n_outputs], order="F") assert y.shape == y_pred.shape - scores = scorer_(y, y_pred, multioutput='raw_values') + scores = scorer_(y, y_pred, multioutput="raw_values") return scores def _check_dimensions(self, X, y, predict=False): @@ -337,28 +365,39 @@ def _check_dimensions(self, X, y, predict=False): elif y_dim == 2: y = y[:, np.newaxis, :] # epochs else: - raise ValueError('y must be shape (n_times[, n_epochs]' - '[,n_outputs], got %s' % (y.shape,)) + raise ValueError( + "y must be shape (n_times[, n_epochs]" + "[,n_outputs], got %s" % (y.shape,) + ) elif X.ndim == 3: if y is not None: if y.ndim == 2: y = y[:, :, np.newaxis] # Add an outputs dim elif y.ndim != 3: - raise ValueError('If X has 3 dimensions, ' - 'y must have 2 or 3 dimensions') + raise ValueError( + "If X has 3 dimensions, " "y must have 2 or 3 dimensions" + ) else: - raise ValueError('X must be shape (n_times[, n_epochs],' - ' n_features), got %s' % (X.shape,)) + raise ValueError( + "X must be shape (n_times[, n_epochs]," + " n_features), got %s" % (X.shape,) + ) if y is not None: if X.shape[0] != y.shape[0]: - raise ValueError('X and y do not have the same n_times\n' - '%s != %s' % (X.shape[0], y.shape[0])) + raise ValueError( + "X and y do not have the same n_times\n" + "%s != %s" % (X.shape[0], y.shape[0]) + ) if X.shape[1] != y.shape[1]: - raise ValueError('X and y do not have the same n_epochs\n' - '%s != %s' % (X.shape[1], y.shape[1])) + raise ValueError( + "X and y do not have the same n_epochs\n" + "%s != %s" % (X.shape[1], y.shape[1]) + ) if predict and y.shape[-1] != len(self.estimator_.coef_): - raise ValueError('Number of outputs does not match' - ' estimator coefficients dimensions') + raise ValueError( + "Number of outputs does not match" + " estimator coefficients dimensions" + ) return X, y, X_dim, y_dim @@ -423,15 +462,14 @@ def _delay_time_series(X, tmin, tmax, sfreq, fill_mean=False): use_X = X out[:] = use_X if fill_mean: - out[:] += (mean_value - use_X.mean(axis=0)) + out[:] += mean_value - use_X.mean(axis=0) return delayed def _times_to_delays(tmin, tmax, sfreq): """Convert a tmin/tmax in seconds to delays.""" # Convert seconds to samples - delays = np.arange(int(np.round(tmin * sfreq)), - int(np.round(tmax * sfreq) + 1)) + 
delays = np.arange(int(np.round(tmin * sfreq)), int(np.round(tmax * sfreq) + 1)) return delays @@ -446,37 +484,39 @@ def _delays_to_slice(delays): def _check_delayer_params(tmin, tmax, sfreq): """Check delayer input parameters. For future custom delay support.""" - _validate_type(sfreq, 'numeric', '`sfreq`') + _validate_type(sfreq, "numeric", "`sfreq`") for tlim in (tmin, tmax): - _validate_type(tlim, 'numeric', 'tmin/tmax') + _validate_type(tlim, "numeric", "tmin/tmax") if not tmin <= tmax: - raise ValueError('tmin must be <= tmax') + raise ValueError("tmin must be <= tmax") def _reshape_for_est(X_del): """Convert X_del to a sklearn-compatible shape.""" n_times, n_epochs, n_feats, n_delays = X_del.shape X_del = X_del.reshape(n_times, n_epochs, -1) # concatenate feats - X_del = X_del.reshape(n_times * n_epochs, -1, order='F') + X_del = X_del.reshape(n_times * n_epochs, -1, order="F") return X_del # Create a correlation scikit-learn-style scorer def _corr_score(y_true, y, multioutput=None): from scipy.stats import pearsonr - assert multioutput == 'raw_values' + + assert multioutput == "raw_values" for this_y in (y_true, y): if this_y.ndim != 2: - raise ValueError('inputs must be shape (samples, outputs), got %s' - % (this_y.shape,)) - return np.array([pearsonr(y_true[:, ii], y[:, ii])[0] - for ii in range(y.shape[-1])]) + raise ValueError( + "inputs must be shape (samples, outputs), got %s" % (this_y.shape,) + ) + return np.array([pearsonr(y_true[:, ii], y[:, ii])[0] for ii in range(y.shape[-1])]) def _r2_score(y_true, y, multioutput=None): from sklearn.metrics import r2_score + return r2_score(y_true, y, multioutput=multioutput) -_SCORERS = {'r2': _r2_score, 'corrcoef': _corr_score} +_SCORERS = {"r2": _r2_score, "corrcoef": _corr_score} diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 81c83b256a4..f2671b7ea11 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -8,8 +8,7 @@ from .base import BaseEstimator, _check_estimator from ..fixes import _get_check_scoring from ..parallel import parallel_func -from ..utils import (array_split_idx, ProgressBar, - verbose, fill_doc, _parse_verbose) +from ..utils import array_split_idx, ProgressBar, verbose, fill_doc, _parse_verbose @fill_doc @@ -35,8 +34,9 @@ class SlidingEstimator(BaseEstimator, TransformerMixin): """ @verbose - def __init__(self, base_estimator, scoring=None, n_jobs=None, *, - position=0, verbose=None): # noqa: D102 + def __init__( + self, base_estimator, scoring=None, n_jobs=None, *, position=0, verbose=None + ): # noqa: D102 _check_estimator(base_estimator) self._estimator_type = getattr(base_estimator, "_estimator_type", None) self.base_estimator = base_estimator @@ -46,11 +46,11 @@ def __init__(self, base_estimator, scoring=None, n_jobs=None, *, self.verbose = verbose def __repr__(self): # noqa: D105 - repr_str = '<' + super(SlidingEstimator, self).__repr__() - if hasattr(self, 'estimators_'): + repr_str = "<" + super(SlidingEstimator, self).__repr__() + if hasattr(self, "estimators_"): repr_str = repr_str[:-1] - repr_str += ', fitted with %i estimators' % len(self.estimators_) - return repr_str + '>' + repr_str += ", fitted with %i estimators" % len(self.estimators_) + return repr_str + ">" def fit(self, X, y, **fit_params): """Fit a series of independent estimators to the dataset. 
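The hunks around this point only restyle ``SlidingEstimator``; behaviorally, ``fit`` still clones ``base_estimator`` and fits one clone per task, i.e. per entry along the last axis of ``X``. A minimal usage sketch, not part of the patch: it assumes scikit-learn is installed, and the random data, shapes, and solver are illustrative only, mirroring the tests reformatted later in this series:

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from mne.decoding import SlidingEstimator

    rng = np.random.RandomState(0)
    X = rng.randn(20, 4, 10)   # (n_samples, n_features, n_tasks), e.g. time points
    y = rng.randint(0, 2, 20)  # one binary label per sample

    # One LogisticRegression is fit independently at each of the 10 tasks
    sl = SlidingEstimator(LogisticRegression(solver="liblinear"), scoring="accuracy")
    sl.fit(X, y)
    scores = sl.score(X, y)    # shape (10,): one accuracy per task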
@@ -74,16 +74,16 @@ def fit(self, X, y, **fit_params): """ self._check_Xy(X, y) parallel, p_func, n_jobs = parallel_func( - _sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False) + _sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False + ) self.estimators_ = list() self.fit_params = fit_params # For fitting, the parallelization is across estimators. - context = _create_progressbar_context(self, X, 'Fitting') + context = _create_progressbar_context(self, X, "Fitting") with context as pb: estimators = parallel( - p_func(self.base_estimator, split, y, - pb.subset(pb_idx), **fit_params) + p_func(self.base_estimator, split, y, pb.subset(pb_idx), **fit_params) for pb_idx, split in array_split_idx(X, n_jobs, axis=-1) ) @@ -126,17 +126,17 @@ def _transform(self, X, method): self._check_Xy(X) method = _check_method(self.base_estimator, method) if X.shape[-1] != len(self.estimators_): - raise ValueError('The number of estimators does not match ' - 'X.shape[-1]') + raise ValueError("The number of estimators does not match " "X.shape[-1]") # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False) + _sl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + ) X_splits = np.array_split(X, n_jobs, axis=-1) idx, est_splits = zip(*array_split_idx(self.estimators_, n_jobs)) - context = _create_progressbar_context(self, X, 'Transforming') + context = _create_progressbar_context(self, X, "Transforming") with context as pb: y_pred = parallel( p_func(est, x, method, pb.subset(pb_idx)) @@ -166,7 +166,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators) The transformed values generated by each estimator. """ # noqa: E501 - return self._transform(X, 'transform') + return self._transform(X, "transform") def predict(self, X): """Predict each data slice/task with a series of independent estimators. @@ -188,7 +188,7 @@ def predict(self, X): y_pred : array, shape (n_samples, n_estimators) | (n_samples, n_tasks, n_targets) Predicted values for each estimator/data slice. """ # noqa: E501 - return self._transform(X, 'predict') + return self._transform(X, "predict") def predict_proba(self, X): """Predict each data slice with a series of independent estimators. @@ -210,7 +210,7 @@ def predict_proba(self, X): y_pred : array, shape (n_samples, n_tasks, n_classes) Predicted probabilities for each estimator/data slice/task. """ # noqa: E501 - return self._transform(X, 'predict_proba') + return self._transform(X, "predict_proba") def decision_function(self, X): """Estimate distances of each data slice to the hyperplanes. @@ -233,15 +233,15 @@ def decision_function(self, X): ----- This requires base_estimator to have a ``decision_function`` method. """ # noqa: E501 - return self._transform(X, 'decision_function') + return self._transform(X, "decision_function") def _check_Xy(self, X, y=None): """Aux. function to check input data.""" if y is not None: if len(X) != len(y) or len(y) < 1: - raise ValueError('X and y must have the same length.') + raise ValueError("X and y must have the same length.") if X.ndim < 3: - raise ValueError('X must have at least 3 dimensions.') + raise ValueError("X must have at least 3 dimensions.") def score(self, X, y): """Score each estimator on each task. 
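``score`` below returns one value per task, and when no external scorer is given it is also what ``cross_val_multiscore`` invokes fold by fold. A hedged sketch of that composition, again not part of the patch; the shapes are illustrative and mirror ``test_cross_val_multiscore`` further down in this series:

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import KFold
    from mne.decoding import SlidingEstimator, cross_val_multiscore

    rng = np.random.RandomState(0)
    X = rng.randn(20, 4, 3)  # (n_samples, n_features, n_tasks)
    y = np.arange(20) % 2    # binary labels

    clf = SlidingEstimator(LogisticRegression(solver="liblinear"), scoring="accuracy")
    cv = KFold(2, shuffle=True, random_state=0)
    scores = cross_val_multiscore(clf, X, y, cv=cv)  # shape (2, 3): folds x tasks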
@@ -270,8 +270,7 @@ def score(self, X, y):
 
         self._check_Xy(X)
         if X.shape[-1] != len(self.estimators_):
-            raise ValueError('The number of estimators does not match '
-                             'X.shape[-1]')
+            raise ValueError("The number of estimators does not match " "X.shape[-1]")
 
         scoring = check_scoring(self.base_estimator, self.scoring)
         y = _fix_auc(scoring, y)
@@ -279,21 +278,25 @@
         # For predictions/transforms the parallelization is across the data and
         # not across the estimators to avoid memory load.
         parallel, p_func, n_jobs = parallel_func(
-            _sl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False)
+            _sl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False
+        )
         X_splits = np.array_split(X, n_jobs, axis=-1)
         est_splits = np.array_split(self.estimators_, n_jobs)
-        score = parallel(p_func(est, scoring, x, y)
-                         for (est, x) in zip(est_splits, X_splits))
+        score = parallel(
+            p_func(est, scoring, x, y) for (est, x) in zip(est_splits, X_splits)
+        )
 
         score = np.concatenate(score, axis=0)
         return score
 
     @property
     def classes_(self):
-        if not hasattr(self.estimators_[0], 'classes_'):
-            raise AttributeError('classes_ attribute available only if '
-                                 'base_estimator has it, and estimator %s does'
-                                 ' not' % (self.estimators_[0],))
+        if not hasattr(self.estimators_[0], "classes_"):
+            raise AttributeError(
+                "classes_ attribute available only if "
+                "base_estimator has it, and estimator %s does"
+                " not" % (self.estimators_[0],)
+            )
         return self.estimators_[0].classes_
 
 
@@ -322,6 +325,7 @@ def _sl_fit(estimator, X, y, pb, **fit_params):
         The fitted estimators.
     """
     from sklearn.base import clone
+
     estimators_ = list()
     for ii in range(X.shape[-1]):
         est = clone(estimator)
@@ -410,10 +414,10 @@ def _check_method(estimator, method):
     If method == 'transform' and estimator does not have 'transform', use
     'predict' instead.
     """
-    if method == 'transform' and not hasattr(estimator, 'transform'):
-        method = 'predict'
+    if method == "transform" and not hasattr(estimator, "transform"):
+        method = "predict"
     if not hasattr(estimator, method):
-        ValueError('base_estimator does not have `%s` method.' % method)
+        raise ValueError("base_estimator does not have `%s` method." % method)
     return method
 
 
@@ -435,9 +439,9 @@ class GeneralizingEstimator(SlidingEstimator):
 
     def __repr__(self):  # noqa: D105
         repr_str = super(GeneralizingEstimator, self).__repr__()
-        if hasattr(self, 'estimators_'):
+        if hasattr(self, "estimators_"):
             repr_str = repr_str[:-1]
-            repr_str += ', fitted with %i estimators>' % len(self.estimators_)
+            repr_str += ", fitted with %i estimators>" % len(self.estimators_)
         return repr_str
 
     def _transform(self, X, method):
@@ -446,14 +450,16 @@ def _transform(self, X, method):
         method = _check_method(self.base_estimator, method)
 
         parallel, p_func, n_jobs = parallel_func(
-            _gl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False)
+            _gl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False
+        )
 
-        context = _create_progressbar_context(self, X, 'Transforming')
+        context = _create_progressbar_context(self, X, "Transforming")
         with context as pb:
             y_pred = parallel(
                 p_func(self.estimators_, x_split, method, pb.subset(pb_idx))
                 for pb_idx, x_split in array_split_idx(
-                    X, n_jobs, axis=-1, n_per_split=len(self.estimators_))
+                    X, n_jobs, axis=-1, n_per_split=len(self.estimators_)
+                )
             )
 
         y_pred = np.concatenate(y_pred, axis=2)
@@ -475,7 +481,7 @@ def transform(self, X):
         Xt : array, shape (n_samples, n_estimators, n_slices)
             The transformed values generated by each estimator.
""" - return self._transform(X, 'transform') + return self._transform(X, "transform") def predict(self, X): """Predict each data slice with all possible estimators. @@ -493,7 +499,7 @@ def predict(self, X): y_pred : array, shape (n_samples, n_estimators, n_slices) | (n_samples, n_estimators, n_slices, n_targets) The predicted values for each estimator. """ # noqa: E501 - return self._transform(X, 'predict') + return self._transform(X, "predict") def predict_proba(self, X): """Estimate probabilistic estimates of each data slice with all possible estimators. @@ -515,7 +521,7 @@ def predict_proba(self, X): ----- This requires ``base_estimator`` to have a ``predict_proba`` method. """ # noqa: E501 - return self._transform(X, 'predict_proba') + return self._transform(X, "predict_proba") def decision_function(self, X): """Estimate distances of each data slice to all hyperplanes. @@ -539,7 +545,7 @@ def decision_function(self, X): This requires ``base_estimator`` to have a ``decision_function`` method. """ # noqa: E501 - return self._transform(X, 'decision_function') + return self._transform(X, "decision_function") def score(self, X, y): """Score each of the estimators on the tested dimensions. @@ -565,16 +571,18 @@ def score(self, X, y): # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _gl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False) + _gl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + ) scoring = check_scoring(self.base_estimator, self.scoring) y = _fix_auc(scoring, y) - context = _create_progressbar_context(self, X, 'Scoring') + context = _create_progressbar_context(self, X, "Scoring") with context as pb: score = parallel( p_func(self.estimators_, scoring, x, y, pb.subset(pb_idx)) for pb_idx, x in array_split_idx( - X, n_jobs, axis=-1, n_per_split=len(self.estimators_)) + X, n_jobs, axis=-1, n_per_split=len(self.estimators_) + ) ) score = np.concatenate(score, axis=1) @@ -628,8 +636,7 @@ def _gl_init_pred(y_pred, X, n_train): """Aux. function to GeneralizingEstimator to initialize y_pred.""" n_sample, n_iter = X.shape[0], X.shape[-1] if y_pred.ndim == 3: - y_pred = np.zeros((n_sample, n_train, n_iter, y_pred.shape[-1]), - y_pred.dtype) + y_pred = np.zeros((n_sample, n_train, n_iter, y_pred.shape[-1]), y_pred.dtype) else: y_pred = np.zeros((n_sample, n_train, n_iter), y_pred.dtype) return y_pred @@ -679,30 +686,34 @@ def _gl_score(estimators, scoring, X, y, pb): def _fix_auc(scoring, y): from sklearn.preprocessing import LabelEncoder + # This fixes sklearn's inability to compute roc_auc when y not in [0, 1] # scikit-learn/scikit-learn#6874 if scoring is not None: - score_func = getattr(scoring, '_score_func', None) - kwargs = getattr(scoring, '_kwargs', {}) - if (getattr(score_func, '__name__', '') == 'roc_auc_score' and - kwargs.get('multi_class', 'raise') == 'raise'): + score_func = getattr(scoring, "_score_func", None) + kwargs = getattr(scoring, "_kwargs", {}) + if ( + getattr(score_func, "__name__", "") == "roc_auc_score" + and kwargs.get("multi_class", "raise") == "raise" + ): if np.ndim(y) != 1 or len(set(y)) != 2: - raise ValueError('roc_auc scoring can only be computed for ' - 'two-class problems.') + raise ValueError( + "roc_auc scoring can only be computed for " "two-class problems." 
+ ) y = LabelEncoder().fit_transform(y) return y def _create_progressbar_context(inst, X, message): """Create a progress bar taking into account ``inst.verbose``.""" - multiply = (len(inst.estimators_) - if isinstance(inst, GeneralizingEstimator) else 1) + multiply = len(inst.estimators_) if isinstance(inst, GeneralizingEstimator) else 1 n_steps = X.shape[-1] * max(1, multiply) - mesg = f'{message} {inst.__class__.__name__}' + mesg = f"{message} {inst.__class__.__name__}" - which_tqdm = 'off' if not _check_verbose(inst.verbose) else None - context = ProgressBar(n_steps, mesg=mesg, position=inst.position, - which_tqdm=which_tqdm) + which_tqdm = "off" if not _check_verbose(inst.verbose) else None + context = ProgressBar( + n_steps, mesg=mesg, position=inst.position, which_tqdm=which_tqdm + ) return context diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 4739264f544..8b747e4e350 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -13,8 +13,13 @@ from ..rank import compute_rank from ..time_frequency import psd_array_welch from ..utils import ( - fill_doc, logger, _check_option, _time_mask, _validate_type, - _verbose_safe_false) + fill_doc, + logger, + _check_option, + _time_mask, + _validate_type, + _verbose_safe_false, +) @fill_doc @@ -84,52 +89,66 @@ class SSD(BaseEstimator, TransformerMixin): .. footbibliography:: """ - def __init__(self, info, filt_params_signal, filt_params_noise, - reg=None, n_components=None, picks=None, - sort_by_spectral_ratio=True, return_filtered=False, - n_fft=None, cov_method_params=None, rank=None): + def __init__( + self, + info, + filt_params_signal, + filt_params_noise, + reg=None, + n_components=None, + picks=None, + sort_by_spectral_ratio=True, + return_filtered=False, + n_fft=None, + cov_method_params=None, + rank=None, + ): """Initialize instance.""" dicts = {"signal": filt_params_signal, "noise": filt_params_noise} - for param, dd in [('l', 0), ('h', 0), ('l', 1), ('h', 1)]: - key = ('signal', 'noise')[dd] - if param + '_freq' not in dicts[key]: + for param, dd in [("l", 0), ("h", 0), ("l", 1), ("h", 1)]: + key = ("signal", "noise")[dd] + if param + "_freq" not in dicts[key]: raise ValueError( - '%s must be defined in filter parameters for %s' - % (param + '_freq', key)) - val = dicts[key][param + '_freq'] + "%s must be defined in filter parameters for %s" + % (param + "_freq", key) + ) + val = dicts[key][param + "_freq"] if not isinstance(val, (int, float)): - _validate_type(val, ('numeric',), f'{key} {param}_freq') + _validate_type(val, ("numeric",), f"{key} {param}_freq") # check freq bands - if (filt_params_noise['l_freq'] > filt_params_signal['l_freq'] or - filt_params_signal['h_freq'] > filt_params_noise['h_freq']): - raise ValueError('Wrongly specified frequency bands!\n' - 'The signal band-pass must be within the noise ' - 'band-pass!') - self.picks_ = _picks_to_idx(info, picks, none='data', exclude='bads') + if ( + filt_params_noise["l_freq"] > filt_params_signal["l_freq"] + or filt_params_signal["h_freq"] > filt_params_noise["h_freq"] + ): + raise ValueError( + "Wrongly specified frequency bands!\n" + "The signal band-pass must be within the noise " + "band-pass!" + ) + self.picks_ = _picks_to_idx(info, picks, none="data", exclude="bads") del picks ch_types = _get_channel_types(info, picks=self.picks_, unique=True) if len(ch_types) > 1: - raise ValueError('At this point SSD only supports fitting ' - 'single channel types. 
Your info has %i types' % - (len(ch_types))) + raise ValueError( + "At this point SSD only supports fitting " + "single channel types. Your info has %i types" % (len(ch_types)) + ) self.info = info - self.freqs_signal = (filt_params_signal['l_freq'], - filt_params_signal['h_freq']) - self.freqs_noise = (filt_params_noise['l_freq'], - filt_params_noise['h_freq']) + self.freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"]) + self.freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"]) self.filt_params_signal = filt_params_signal self.filt_params_noise = filt_params_noise # check if boolean if not isinstance(sort_by_spectral_ratio, (bool)): - raise ValueError('sort_by_spectral_ratio must be boolean') + raise ValueError("sort_by_spectral_ratio must be boolean") self.sort_by_spectral_ratio = sort_by_spectral_ratio if n_fft is None: - self.n_fft = int(self.info['sfreq']) + self.n_fft = int(self.info["sfreq"]) else: self.n_fft = int(n_fft) # check if boolean if not isinstance(return_filtered, (bool)): - raise ValueError('return_filtered must be boolean') + raise ValueError("return_filtered must be boolean") self.return_filtered = return_filtered self.reg = reg self.n_components = n_components @@ -138,13 +157,14 @@ def __init__(self, info, filt_params_signal, filt_params_noise, def _check_X(self, X): """Check input data.""" - _validate_type(X, np.ndarray, 'X') - _check_option('X.ndim', X.ndim, (2, 3)) + _validate_type(X, np.ndarray, "X") + _check_option("X.ndim", X.ndim, (2, 3)) n_chan = X.shape[-2] - if n_chan != self.info['nchan']: - raise ValueError('Info must match the input data.' - 'Found %i channels but expected %i.' % - (n_chan, self.info['nchan'])) + if n_chan != self.info["nchan"]: + raise ValueError( + "Info must match the input data." + "Found %i channels but expected %i." % (n_chan, self.info["nchan"]) + ) def fit(self, X, y=None): """Estimate the SSD decomposition on raw or epoched data. @@ -164,13 +184,12 @@ def fit(self, X, y=None): Returns the modified instance. 
""" from scipy import linalg + self._check_X(X) X_aux = X[..., self.picks_, :] - X_signal = filter_data( - X_aux, self.info['sfreq'], **self.filt_params_signal) - X_noise = filter_data( - X_aux, self.info['sfreq'], **self.filt_params_noise) + X_signal = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) + X_noise = filter_data(X_aux, self.info["sfreq"], **self.filt_params_noise) X_noise -= X_signal if X.ndim == 3: X_signal = np.hstack(X_signal) @@ -178,15 +197,24 @@ def fit(self, X, y=None): # prevent rank change when computing cov with rank='full' cov_signal = _regularized_covariance( - X_signal, reg=self.reg, method_params=self.cov_method_params, - rank='full', info=self.info) + X_signal, + reg=self.reg, + method_params=self.cov_method_params, + rank="full", + info=self.info, + ) cov_noise = _regularized_covariance( - X_noise, reg=self.reg, method_params=self.cov_method_params, - rank='full', info=self.info) + X_noise, + reg=self.reg, + method_params=self.cov_method_params, + rank="full", + info=self.info, + ) # project cov to rank subspace - cov_signal, cov_noise, rank_proj = (_dimensionality_reduction( - cov_signal, cov_noise, self.info, self.rank)) + cov_signal, cov_noise, rank_proj = _dimensionality_reduction( + cov_signal, cov_noise, self.info, self.rank + ) eigvals_, eigvects_ = linalg.eigh(cov_signal, cov_noise) # sort in descending order @@ -204,7 +232,7 @@ def fit(self, X, y=None): if self.sort_by_spectral_ratio: _, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd) self.sorter_spec = sorter_spec - logger.info('Done.') + logger.info("Done.") return self def transform(self, X): @@ -224,16 +252,15 @@ def transform(self, X): """ self._check_X(X) if self.filters_ is None: - raise RuntimeError('No filters available. Please first call fit') + raise RuntimeError("No filters available. Please first call fit") if self.return_filtered: X_aux = X[..., self.picks_, :] - X = filter_data(X_aux, self.info['sfreq'], - **self.filt_params_signal) + X = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) X_ssd = self.filters_.T @ X[..., self.picks_, :] if X.ndim == 2: - X_ssd = X_ssd[self.sorter_spec][:self.n_components] + X_ssd = X_ssd[self.sorter_spec][: self.n_components] else: - X_ssd = X_ssd[:, self.sorter_spec, :][:, :self.n_components, :] + X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :] return X_ssd def get_spectral_ratio(self, ssd_sources): @@ -259,7 +286,8 @@ def get_spectral_ratio(self, ssd_sources): .. footbibliography:: """ psd, freqs = psd_array_welch( - ssd_sources, sfreq=self.info['sfreq'], n_fft=self.n_fft) + ssd_sources, sfreq=self.info["sfreq"], n_fft=self.n_fft + ) sig_idx = _time_mask(freqs, *self.freqs_signal) noise_idx = _time_mask(freqs, *self.freqs_noise) if psd.ndim == 3: @@ -275,7 +303,7 @@ def get_spectral_ratio(self, ssd_sources): def inverse_transform(self): """Not implemented yet.""" - raise NotImplementedError('inverse_transform is not yet available.') + raise NotImplementedError("inverse_transform is not yet available.") def apply(self, X): """Remove selected components from the signal. @@ -301,7 +329,7 @@ def apply(self, X): The processed data. 
""" X_ssd = self.transform(X) - pick_patterns = self.patterns_[self.sorter_spec][:self.n_components].T + pick_patterns = self.patterns_[self.sorter_spec][: self.n_components].T X = pick_patterns @ X_ssd return X @@ -309,17 +337,40 @@ def apply(self, X): def _dimensionality_reduction(cov_signal, cov_noise, info, rank): """Perform dimensionality reduction on the covariance matrices.""" from scipy import linalg + n_channels = cov_signal.shape[0] # find ranks of covariance matrices - rank_signal = list(compute_rank( - Covariance(cov_signal, info.ch_names, list(), list(), 0, - verbose=_verbose_safe_false()), - rank, _handle_default('scalings_cov_rank', None), info).values())[0] - rank_noise = list(compute_rank( - Covariance(cov_noise, info.ch_names, list(), list(), 0, - verbose=_verbose_safe_false()), - rank, _handle_default('scalings_cov_rank', None), info).values())[0] + rank_signal = list( + compute_rank( + Covariance( + cov_signal, + info.ch_names, + list(), + list(), + 0, + verbose=_verbose_safe_false(), + ), + rank, + _handle_default("scalings_cov_rank", None), + info, + ).values() + )[0] + rank_noise = list( + compute_rank( + Covariance( + cov_noise, + info.ch_names, + list(), + list(), + 0, + verbose=_verbose_safe_false(), + ), + rank, + _handle_default("scalings_cov_rank", None), + info, + ).values() + )[0] rank = np.min([rank_signal, rank_noise]) # should be identical if rank < n_channels: @@ -330,13 +381,18 @@ def _dimensionality_reduction(cov_signal, cov_noise, info, rank): eigvects = eigvects[:, ix] # compute rank subspace projection matrix rank_proj = np.matmul( - eigvects[:, :rank], np.eye(rank) * (eigvals[:rank]**-0.5)) + eigvects[:, :rank], np.eye(rank) * (eigvals[:rank] ** -0.5) + ) logger.info( - 'Projecting covariance of %i channels to %i rank subspace' - % (n_channels, rank,)) + "Projecting covariance of %i channels to %i rank subspace" + % ( + n_channels, + rank, + ) + ) else: rank_proj = np.eye(n_channels) - logger.info('Preserving covariance rank (%i)' % (rank,)) + logger.info("Preserving covariance rank (%i)" % (rank,)) # project covariance matrices to rank subspace cov_signal = np.matmul(rank_proj.T, np.matmul(cov_signal, rank_proj)) diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 43f8b08d097..c7773a217d4 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -4,18 +4,27 @@ # License: BSD-3-Clause import numpy as np -from numpy.testing import (assert_array_equal, assert_array_almost_equal, - assert_equal, assert_allclose, assert_array_less) +from numpy.testing import ( + assert_array_equal, + assert_array_almost_equal, + assert_equal, + assert_allclose, + assert_array_less, +) import pytest from mne import create_info, EpochsArray from mne.fixes import is_regressor, is_classifier from mne.utils import requires_sklearn -from mne.decoding.base import (_get_inverse_funcs, LinearModel, get_coef, - cross_val_multiscore, BaseEstimator) +from mne.decoding.base import ( + _get_inverse_funcs, + LinearModel, + get_coef, + cross_val_multiscore, + BaseEstimator, +) from mne.decoding.search_light import SlidingEstimator -from mne.decoding import (Scaler, TransformerMixin, Vectorizer, - GeneralizingEstimator) +from mne.decoding import Scaler, TransformerMixin, Vectorizer, GeneralizingEstimator def _make_data(n_samples=1000, n_features=5, n_targets=3): @@ -43,7 +52,7 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): # Define Y latent factors np.random.seed(0) cov_Y = np.eye(n_targets) * 10 + 
np.random.rand(n_targets, n_targets) - cov_Y = (cov_Y + cov_Y.T) / 2. + cov_Y = (cov_Y + cov_Y.T) / 2.0 mean_Y = np.random.rand(n_targets) Y = np.random.multivariate_normal(mean_Y, cov_Y, size=n_samples) @@ -68,19 +77,21 @@ def test_get_coef(): from sklearn.model_selection import GridSearchCV lm_classification = LinearModel() - assert (is_classifier(lm_classification)) + assert is_classifier(lm_classification) lm_regression = LinearModel(Ridge()) - assert (is_regressor(lm_regression)) + assert is_regressor(lm_regression) - parameters = {'kernel': ['linear'], 'C': [1, 10]} + parameters = {"kernel": ["linear"], "C": [1, 10]} lm_gs_classification = LinearModel( - GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None)) - assert (is_classifier(lm_gs_classification)) + GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None) + ) + assert is_classifier(lm_gs_classification) lm_gs_regression = LinearModel( - GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None)) - assert (is_regressor(lm_gs_regression)) + GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None) + ) + assert is_regressor(lm_gs_regression) # Define a classifier, an invertible transformer and an non-invertible one. @@ -113,14 +124,15 @@ def inverse_transform(self, X): for expected_n, est in good_estimators: est.fit(X, y) - assert (expected_n == len(_get_inverse_funcs(est))) + assert expected_n == len(_get_inverse_funcs(est)) bad_estimators = [ Clf(), # no preprocessing Inv(), # final estimator isn't classifier make_pipeline(NoInv(), Clf()), # first step isn't invertible - make_pipeline(Inv(), make_pipeline( - Inv(), NoInv()), Clf()), # nested step isn't invertible + make_pipeline( + Inv(), make_pipeline(Inv(), NoInv()), Clf() + ), # nested step isn't invertible ] for est in bad_estimators: est.fit(X, y) @@ -129,11 +141,12 @@ def inverse_transform(self, X): # II. Test get coef for classification/regression estimators and pipelines rng = np.random.RandomState(0) - for clf in (lm_regression, - lm_gs_classification, - make_pipeline(StandardScaler(), lm_classification), - make_pipeline(StandardScaler(), lm_gs_regression)): - + for clf in ( + lm_regression, + lm_gs_classification, + make_pipeline(StandardScaler(), lm_classification), + make_pipeline(StandardScaler(), lm_gs_regression), + ): # generate some categorical/continuous data # according to the type of estimator. 
if is_classifier(clf): @@ -147,16 +160,16 @@ def inverse_transform(self, X): clf.fit(X, y) # Retrieve final linear model - filters = get_coef(clf, 'filters_', False) - if hasattr(clf, 'steps'): - if hasattr(clf.steps[-1][-1].model, 'best_estimator_'): + filters = get_coef(clf, "filters_", False) + if hasattr(clf, "steps"): + if hasattr(clf.steps[-1][-1].model, "best_estimator_"): # Linear Model with GridSearchCV coefs = clf.steps[-1][-1].model.best_estimator_.coef_ else: # Standard Linear Model coefs = clf.steps[-1][-1].model.coef_ else: - if hasattr(clf.model, 'best_estimator_'): + if hasattr(clf.model, "best_estimator_"): # Linear Model with GridSearchCV coefs = clf.model.best_estimator_.coef_ else: @@ -165,20 +178,19 @@ def inverse_transform(self, X): if coefs.ndim == 2 and coefs.shape[0] == 1: coefs = coefs[0] assert_array_equal(filters, coefs) - patterns = get_coef(clf, 'patterns_', False) - assert (filters[0] != patterns[0]) + patterns = get_coef(clf, "patterns_", False) + assert filters[0] != patterns[0] n_chans = X.shape[1] assert_array_equal(filters.shape, patterns.shape, [n_chans, n_chans]) # Inverse transform linear model - filters_inv = get_coef(clf, 'filters_', True) - assert (filters[0] != filters_inv[0]) - patterns_inv = get_coef(clf, 'patterns_', True) - assert (patterns[0] != patterns_inv[0]) + filters_inv = get_coef(clf, "filters_", True) + assert filters[0] != filters_inv[0] + patterns_inv = get_coef(clf, "patterns_", True) + assert patterns[0] != patterns_inv[0] class _Noop(BaseEstimator, TransformerMixin): - def fit(self, X, y=None): return self @@ -189,15 +201,19 @@ def transform(self, X): @requires_sklearn -@pytest.mark.parametrize('inverse', (True, False)) -@pytest.mark.parametrize('Scale, kwargs', [ - (Scaler, dict(info=None, scalings='mean')), - (_Noop, dict()), -]) +@pytest.mark.parametrize("inverse", (True, False)) +@pytest.mark.parametrize( + "Scale, kwargs", + [ + (Scaler, dict(info=None, scalings="mean")), + (_Noop, dict()), + ], +) def test_get_coef_inverse_transform(inverse, Scale, kwargs): """Test get_coef with and without inverse_transform.""" from sklearn.linear_model import Ridge from sklearn.pipeline import make_pipeline + lm_regression = LinearModel(Ridge()) X, y, A = _make_data(n_samples=1000, n_features=3, n_targets=1) # Check with search_light and combination of preprocessing ending with sl: @@ -208,29 +224,29 @@ def test_get_coef_inverse_transform(inverse, Scale, kwargs): X = np.transpose([X, -X], [1, 2, 0]) # invert X across 2 time samples clf = make_pipeline(Scale(**kwargs), slider) clf.fit(X, y) - patterns = get_coef(clf, 'patterns_', inverse) - filters = get_coef(clf, 'filters_', inverse) + patterns = get_coef(clf, "patterns_", inverse) + filters = get_coef(clf, "filters_", inverse) assert_array_equal(filters.shape, patterns.shape, X.shape[1:]) # the two time samples get inverted patterns assert_equal(patterns[0, 0], -patterns[0, 1]) for t in [0, 1]: filters_t = get_coef( - clf.named_steps['slidingestimator'].estimators_[t], - 'filters_', False) + clf.named_steps["slidingestimator"].estimators_[t], "filters_", False + ) if Scale is _Noop: assert_array_equal(filters_t, filters[:, t]) @requires_sklearn -@pytest.mark.parametrize('n_features', [1, 5]) -@pytest.mark.parametrize('n_targets', [1, 3]) +@pytest.mark.parametrize("n_features", [1, 5]) +@pytest.mark.parametrize("n_targets", [1, 3]) def test_get_coef_multiclass(n_features, n_targets): """Test get_coef on multiclass problems.""" # Check patterns with more than 1 regressor from 
sklearn.linear_model import LinearRegression, Ridge from sklearn.pipeline import make_pipeline - X, Y, A = _make_data( - n_samples=30000, n_features=n_features, n_targets=n_targets) + + X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets) lm = LinearModel(LinearRegression()).fit(X, Y) assert_array_equal(lm.filters_.shape, lm.patterns_.shape) if n_targets == 1: @@ -245,22 +261,22 @@ def test_get_coef_multiclass(n_features, n_targets): clf.fit(X, Y) if n_features > 1 and n_targets > 1: assert_allclose(A, lm.patterns_.T, atol=2e-2) - coef = get_coef(clf, 'patterns_', inverse_transform=True) + coef = get_coef(clf, "patterns_", inverse_transform=True) assert_allclose(lm.patterns_, coef, atol=1e-5) # With epochs, scaler, and vectorizer (typical use case) X_epo = X.reshape(X.shape + (1,)) - info = create_info(n_features, 1000., 'eeg') + info = create_info(n_features, 1000.0, "eeg") lm = LinearModel(Ridge(alpha=1)) clf = make_pipeline( - Scaler(info, scalings=dict(eeg=1.)), # XXX adding this step breaks + Scaler(info, scalings=dict(eeg=1.0)), # XXX adding this step breaks Vectorizer(), lm, ) clf.fit(X_epo, Y) if n_features > 1 and n_targets > 1: assert_allclose(A, lm.patterns_.T, atol=2e-2) - coef = get_coef(clf, 'patterns_', inverse_transform=True) + coef = get_coef(clf, "patterns_", inverse_transform=True) lm_patterns_ = lm.patterns_[..., np.newaxis] assert_allclose(lm_patterns_, coef, atol=1e-5) @@ -269,31 +285,36 @@ def test_get_coef_multiclass(n_features, n_targets): @requires_sklearn -@pytest.mark.parametrize('n_classes, n_channels, n_times', [ - (4, 10, 2), - (4, 3, 2), - (3, 2, 1), - (3, 1, 2), -]) +@pytest.mark.parametrize( + "n_classes, n_channels, n_times", + [ + (4, 10, 2), + (4, 3, 2), + (3, 2, 1), + (3, 1, 2), + ], +) def test_get_coef_multiclass_full(n_classes, n_channels, n_times): """Test a full example with pattern extraction.""" from sklearn.pipeline import make_pipeline from sklearn.linear_model import LogisticRegression from sklearn.model_selection import StratifiedKFold + data = np.zeros((10 * n_classes, n_channels, n_times)) # Make only the first channel informative for ii in range(n_classes): - data[ii * 10:(ii + 1) * 10, 0] = ii + data[ii * 10 : (ii + 1) * 10, 0] = ii events = np.zeros((len(data), 3), int) events[:, 0] = np.arange(len(events)) events[:, 2] = data[:, 0, 0] - info = create_info(n_channels, 1000., 'eeg') + info = create_info(n_channels, 1000.0, "eeg") epochs = EpochsArray(data, info, events, tmin=0) clf = make_pipeline( - Scaler(epochs.info), Vectorizer(), - LinearModel(LogisticRegression(random_state=0, multi_class='ovr')), + Scaler(epochs.info), + Vectorizer(), + LinearModel(LogisticRegression(random_state=0, multi_class="ovr")), ) - scorer = 'roc_auc_ovr_weighted' + scorer = "roc_auc_ovr_weighted" time_gen = GeneralizingEstimator(clf, scorer, verbose=True) X = epochs.get_data() y = epochs.events[:, 2] @@ -306,9 +327,9 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times): assert scores.shape == want assert_array_less(0.8, scores) clf.fit(X, y) - patterns = get_coef(clf, 'patterns_', inverse_transform=True) + patterns = get_coef(clf, "patterns_", inverse_transform=True) assert patterns.shape == (n_classes, n_channels, n_times) - assert_allclose(patterns[:, 1:], 0., atol=1e-7) # no other channels useful + assert_allclose(patterns[:, 1:], 0.0, atol=1e-7) # no other channels useful @requires_sklearn @@ -316,6 +337,7 @@ def test_linearmodel(): """Test LinearModel class for computing filters and patterns.""" # check 
categorical target fit in standard linear model from sklearn.linear_model import LinearRegression + rng = np.random.RandomState(0) clf = LinearModel() n, n_features = 20, 3 @@ -331,9 +353,11 @@ def test_linearmodel(): # check categorical target fit in standard linear model with GridSearchCV from sklearn import svm from sklearn.model_selection import GridSearchCV - parameters = {'kernel': ['linear'], 'C': [1, 10]} + + parameters = {"kernel": ["linear"], "C": [1, 10]} clf = LinearModel( - GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None)) + GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None) + ) clf.fit(X, y) assert_equal(clf.filters_.shape, (n_features,)) assert_equal(clf.patterns_.shape, (n_features,)) @@ -345,10 +369,11 @@ def test_linearmodel(): n_targets = 1 Y = rng.rand(n, n_targets) clf = LinearModel( - GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None)) + GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None) + ) clf.fit(X, y) - assert_equal(clf.filters_.shape, (n_features, )) - assert_equal(clf.patterns_.shape, (n_features, )) + assert_equal(clf.filters_.shape, (n_features,)) + assert_equal(clf.patterns_.shape, (n_features,)) with pytest.raises(ValueError): wrong_y = rng.rand(n, n_features, 99) clf.fit(X, wrong_y) @@ -371,20 +396,21 @@ def test_cross_val_multiscore(): from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score from sklearn.linear_model import LogisticRegression, LinearRegression - logreg = LogisticRegression(solver='liblinear', random_state=0) + logreg = LogisticRegression(solver="liblinear", random_state=0) # compare to cross-val-score X = np.random.rand(20, 3) y = np.arange(20) % 2 cv = KFold(2, random_state=0, shuffle=True) clf = logreg - assert_array_equal(cross_val_score(clf, X, y, cv=cv), - cross_val_multiscore(clf, X, y, cv=cv)) + assert_array_equal( + cross_val_score(clf, X, y, cv=cv), cross_val_multiscore(clf, X, y, cv=cv) + ) # Test with search light X = np.random.rand(20, 4, 3) y = np.arange(20) % 2 - clf = SlidingEstimator(logreg, scoring='accuracy') + clf = SlidingEstimator(logreg, scoring="accuracy") scores_acc = cross_val_multiscore(clf, X, y, cv=cv) assert_array_equal(np.shape(scores_acc), [2, 3]) @@ -399,9 +425,8 @@ def test_cross_val_multiscore(): # raise an error if scoring is defined at cross-val-score level and # search light, because search light does not return a 1-dimensional # prediction. 
- pytest.raises(ValueError, cross_val_multiscore, clf, X, y, cv=cv, - scoring='roc_auc') - clf = SlidingEstimator(logreg, scoring='roc_auc') + pytest.raises(ValueError, cross_val_multiscore, clf, X, y, cv=cv, scoring="roc_auc") + clf = SlidingEstimator(logreg, scoring="roc_auc") scores_auc = cross_val_multiscore(clf, X, y, cv=cv, n_jobs=None) scores_auc_manual = list() for train, test in cv.split(X, y): diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index 6945a812cf7..a505a6c7bdc 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -9,8 +9,7 @@ import numpy as np import pytest -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_equal) +from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_equal from mne import io, Epochs, read_events, pick_types from mne.decoding.csp import CSP, _ajd_pham, SPoC @@ -46,23 +45,39 @@ def simulate_data(target, n_trials=100, n_channels=10, random_state=42): return X, mixing_mat -def deterministic_toy_data(classes=('class_a', 'class_b')): +def deterministic_toy_data(classes=("class_a", "class_b")): """Generate a small deterministic toy data set. Four independent sources are modulated by the target class and mixed into signal space. """ - sources_a = np.array([[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], - dtype=float) * 2 - 1 - - sources_b = np.array([[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], - dtype=float) * 2 - 1 + sources_a = ( + np.array( + [ + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], + ], + dtype=float, + ) + * 2 + - 1 + ) + + sources_b = ( + np.array( + [ + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], + ], + dtype=float, + ) + * 2 + - 1 + ) sources_a[0, :] *= 1 sources_a[1, :] *= 2 @@ -70,10 +85,14 @@ def deterministic_toy_data(classes=('class_a', 'class_b')): sources_b[2, :] *= 3 sources_b[3, :] *= 4 - mixing = np.array([[1.0, 0.8, 0.6, 0.4], - [0.8, 1.0, 0.8, 0.6], - [0.6, 0.8, 1.0, 0.8], - [0.4, 0.6, 0.8, 1.0]]) + mixing = np.array( + [ + [1.0, 0.8, 0.6, 0.4], + [0.8, 1.0, 0.8, 0.6], + [0.6, 0.8, 1.0, 0.8], + [0.4, 0.6, 0.8, 1.0], + ] + ) x_class_a = mixing @ sources_a x_class_b = mixing @ sources_b @@ -89,28 +108,38 @@ def test_csp(): """Test Common Spatial Patterns algorithm on epochs.""" raw = io.read_raw_fif(raw_fname, preload=False) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[2:12:3] # subselect channels -> disable proj! 
raw.add_proj([], remove_existing=True) - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True, proj=False) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + preload=True, + proj=False, + ) epochs_data = epochs.get_data() n_channels = epochs_data.shape[1] y = epochs.events[:, -1] # Init - pytest.raises(ValueError, CSP, n_components='foo', norm_trace=False) - for reg in ['foo', -0.1, 1.1]: + pytest.raises(ValueError, CSP, n_components="foo", norm_trace=False) + for reg in ["foo", -0.1, 1.1]: csp = CSP(reg=reg, norm_trace=False) pytest.raises(ValueError, csp.fit, epochs_data, epochs.events[:, -1]) - for reg in ['oas', 'ledoit_wolf', 0, 0.5, 1.]: + for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]: CSP(reg=reg, norm_trace=False) - for cov_est in ['foo', None]: + for cov_est in ["foo", None]: pytest.raises(ValueError, CSP, cov_est=cov_est, norm_trace=False) - with pytest.raises(TypeError, match='instance of bool'): - CSP(norm_trace='foo') - for cov_est in ['concat', 'epoch']: + with pytest.raises(TypeError, match="instance of bool"): + CSP(norm_trace="foo") + for cov_est in ["concat", "epoch"]: CSP(cov_est=cov_est, norm_trace=False) n_components = 3 @@ -125,33 +154,40 @@ def test_csp(): # Transform X = csp.fit_transform(epochs_data, y) sources = csp.transform(epochs_data) - assert (sources.shape[1] == n_components) - assert (csp.filters_.shape == (n_channels, n_channels)) - assert (csp.patterns_.shape == (n_channels, n_channels)) + assert sources.shape[1] == n_components + assert csp.filters_.shape == (n_channels, n_channels) + assert csp.patterns_.shape == (n_channels, n_channels) assert_array_almost_equal(sources, X) # Test data exception - pytest.raises(ValueError, csp.fit, epochs_data, - np.zeros_like(epochs.events)) + pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) pytest.raises(ValueError, csp.fit, epochs, y) pytest.raises(ValueError, csp.transform, epochs) # Test plots - epochs.pick_types(meg='mag') - cmap = ('RdBu', True) + epochs.pick_types(meg="mag") + cmap = ("RdBu", True) components = np.arange(n_components) for plot in (csp.plot_patterns, csp.plot_filters): plot(epochs.info, components=components, res=12, show=False, cmap=cmap) # Test with more than 2 classes - epochs = Epochs(raw, events, tmin=tmin, tmax=tmax, picks=picks, - event_id=dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4), - baseline=(None, 0), proj=False, preload=True) + epochs = Epochs( + raw, + events, + tmin=tmin, + tmax=tmax, + picks=picks, + event_id=dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4), + baseline=(None, 0), + proj=False, + preload=True, + ) epochs_data = epochs.get_data() n_channels = epochs_data.shape[1] n_channels = epochs_data.shape[1] - for cov_est in ['concat', 'epoch']: + for cov_est in ["concat", "epoch"]: csp = CSP(n_components=n_components, cov_est=cov_est, norm_trace=False) csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) assert_equal(len(csp._classes), 4) @@ -160,31 +196,31 @@ def test_csp(): # Test average power transform n_components = 2 - assert (csp.transform_into == 'average_power') + assert csp.transform_into == "average_power" feature_shape = [len(epochs_data), n_components] X_trans = dict() for log in (None, True, False): csp = CSP(n_components=n_components, log=log, norm_trace=False) - assert (csp.log is log) + assert csp.log is log Xt = csp.fit_transform(epochs_data, epochs.events[:, 2]) assert_array_equal(Xt.shape, feature_shape) X_trans[str(log)] = Xt # 
log=None => log=True - assert_array_almost_equal(X_trans['None'], X_trans['True']) + assert_array_almost_equal(X_trans["None"], X_trans["True"]) # Different normalization return different transform - assert (np.sum((X_trans['True'] - X_trans['False']) ** 2) > 1.) + assert np.sum((X_trans["True"] - X_trans["False"]) ** 2) > 1.0 # Check wrong inputs - pytest.raises(ValueError, CSP, transform_into='average_power', log='foo') + pytest.raises(ValueError, CSP, transform_into="average_power", log="foo") # Test csp space transform - csp = CSP(transform_into='csp_space', norm_trace=False) - assert (csp.transform_into == 'csp_space') - for log in ('foo', True, False): - pytest.raises(ValueError, CSP, transform_into='csp_space', log=log, - norm_trace=False) + csp = CSP(transform_into="csp_space", norm_trace=False) + assert csp.transform_into == "csp_space" + for log in ("foo", True, False): + pytest.raises( + ValueError, CSP, transform_into="csp_space", log=log, norm_trace=False + ) n_components = 2 - csp = CSP(n_components=n_components, transform_into='csp_space', - norm_trace=False) + csp = CSP(n_components=n_components, transform_into="csp_space", norm_trace=False) Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) feature_shape = [len(epochs_data), n_components, epochs_data.shape[2]] assert_array_equal(Xt.shape, feature_shape) @@ -193,7 +229,7 @@ def test_csp(): y = np.array([100] * 50 + [1] * 50) X, A = simulate_data(y) - for cov_est in ['concat', 'epoch']: + for cov_est in ["concat", "epoch"]: # fit csp csp = CSP(n_components=1, cov_est=cov_est, norm_trace=False) csp.fit(X, y) @@ -214,36 +250,35 @@ def test_regularized_csp(): """Test Common Spatial Patterns algorithm using regularized covariance.""" raw = io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) epochs_data = epochs.get_data() n_channels = epochs_data.shape[1] n_components = 3 - reg_cov = [None, 0.05, 'ledoit_wolf', 'oas'] + reg_cov = [None, 0.05, "ledoit_wolf", "oas"] for reg in reg_cov: - csp = CSP(n_components=n_components, reg=reg, norm_trace=False, - rank=None) + csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=None) csp.fit(epochs_data, epochs.events[:, -1]) y = epochs.events[:, -1] X = csp.fit_transform(epochs_data, y) - assert (csp.filters_.shape == (n_channels, n_channels)) - assert (csp.patterns_.shape == (n_channels, n_channels)) - assert_array_almost_equal(csp.fit(epochs_data, y). 
- transform(epochs_data), X) + assert csp.filters_.shape == (n_channels, n_channels) + assert csp.patterns_.shape == (n_channels, n_channels) + assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X) # test init exception - pytest.raises(ValueError, csp.fit, epochs_data, - np.zeros_like(epochs.events)) + pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) pytest.raises(ValueError, csp.fit, epochs, y) pytest.raises(ValueError, csp.transform, epochs) csp.n_components = n_components sources = csp.transform(epochs_data) - assert (sources.shape[1] == n_components) + assert sources.shape[1] == n_components @requires_sklearn @@ -251,11 +286,12 @@ def test_csp_pipeline(): """Test if CSP works in a pipeline.""" from sklearn.svm import SVC from sklearn.pipeline import Pipeline + csp = CSP(reg=1, norm_trace=False) svc = SVC() pipe = Pipeline([("CSP", csp), ("SVC", svc)]) pipe.set_params(CSP__reg=0.2) - assert (pipe.get_params()["CSP__reg"] == 0.2) + assert pipe.get_params()["CSP__reg"] == 0.2 def test_ajd(): @@ -267,15 +303,17 @@ def test_ajd(): seed = np.random.RandomState(0) diags = 2.0 + 0.1 * seed.randn(n_times, n_channels) A = 2 * seed.rand(n_channels, n_channels) - 1 - A /= np.atleast_2d(np.sqrt(np.sum(A ** 2, 1))).T + A /= np.atleast_2d(np.sqrt(np.sum(A**2, 1))).T covmats = np.empty((n_times, n_channels, n_channels)) for i in range(n_times): covmats[i] = np.dot(np.dot(A, np.diag(diags[i])), A.T) V, D = _ajd_pham(covmats) # Results obtained with original matlab implementation - V_matlab = [[-3.507280775058041, -5.498189967306344, 7.720624541198574], - [0.694689013234610, 0.775690358505945, -1.162043086446043], - [-0.592603135588066, -0.598996925696260, 1.009550086271192]] + V_matlab = [ + [-3.507280775058041, -5.498189967306344, 7.720624541198574], + [0.694689013234610, 0.775690358505945, -1.162043086446043], + [-0.592603135588066, -0.598996925696260, 1.009550086271192], + ] assert_array_almost_equal(V, V_matlab) @@ -288,7 +326,7 @@ def test_spoc(): spoc.fit(X, y) Xt = spoc.transform(X) assert_array_equal(Xt.shape, [10, 4]) - spoc = SPoC(n_components=4, transform_into='csp_space') + spoc = SPoC(n_components=4, transform_into="csp_space") spoc.fit(X, y) Xt = spoc.transform(X) assert_array_equal(Xt.shape, [10, 4, 20]) @@ -299,7 +337,7 @@ def test_spoc(): pytest.raises(ValueError, spoc.fit, X, y * 0) # Check that doesn't take CSP-spcific input - pytest.raises(TypeError, SPoC, cov_est='epoch') + pytest.raises(TypeError, SPoC, cov_est="epoch") # Check mixing matrix on simulated data rs = np.random.RandomState(42) @@ -322,33 +360,32 @@ def test_spoc(): def test_csp_twoclass_symmetry(): """Test that CSP is symmetric when swapping classes.""" - x, y = deterministic_toy_data(['class_a', 'class_b']) - csp = CSP(norm_trace=False, transform_into='average_power', log=True) + x, y = deterministic_toy_data(["class_a", "class_b"]) + csp = CSP(norm_trace=False, transform_into="average_power", log=True) log_power = csp.fit_transform(x, y) log_power_ratio_ab = log_power[0] - log_power[1] - x, y = deterministic_toy_data(['class_b', 'class_a']) - csp = CSP(norm_trace=False, transform_into='average_power', log=True) + x, y = deterministic_toy_data(["class_b", "class_a"]) + csp = CSP(norm_trace=False, transform_into="average_power", log=True) log_power = csp.fit_transform(x, y) log_power_ratio_ba = log_power[0] - log_power[1] - assert_array_almost_equal(log_power_ratio_ab, - log_power_ratio_ba) + assert_array_almost_equal(log_power_ratio_ab, log_power_ratio_ba) def 
test_csp_component_ordering(): """Test that CSP component ordering works as expected.""" - x, y = deterministic_toy_data(['class_a', 'class_b']) + x, y = deterministic_toy_data(["class_a", "class_b"]) - pytest.raises(ValueError, CSP, component_order='invalid') + pytest.raises(ValueError, CSP, component_order="invalid") # component_order='alternate' only works with two classes - csp = CSP(component_order='alternate') + csp = CSP(component_order="alternate") with pytest.raises(ValueError): - csp.fit(np.zeros((3, 0, 0)), ['a', 'b', 'c']) + csp.fit(np.zeros((3, 0, 0)), ["a", "b", "c"]) - p_alt = CSP(component_order='alternate').fit(x, y).patterns_ - p_mut = CSP(component_order='mutual_info').fit(x, y).patterns_ + p_alt = CSP(component_order="alternate").fit(x, y).patterns_ + p_mut = CSP(component_order="mutual_info").fit(x, y).patterns_ # This permutation of p_alt and p_mut is explained by the particular # eigenvalues of the toy data: [0.06, 0.1, 0.5, 0.8]. diff --git a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py index b24ebdd75aa..aaeea7c28f7 100644 --- a/mne/decoding/tests/test_ems.py +++ b/mne/decoding/tests/test_ems.py @@ -23,44 +23,55 @@ def test_ems(): """Test event-matched spatial filters.""" from sklearn.model_selection import StratifiedKFold + raw = io.read_raw_fif(raw_fname, preload=False) # create unequal number of events events = read_events(event_name) events[-2, 2] = 3 - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) - pytest.raises(ValueError, compute_ems, epochs, ['aud_l', 'vis_l']) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) + pytest.raises(ValueError, compute_ems, epochs, ["aud_l", "vis_l"]) epochs.equalize_event_counts(epochs.event_id) - pytest.raises(KeyError, compute_ems, epochs, ['blah', 'hahah']) + pytest.raises(KeyError, compute_ems, epochs, ["blah", "hahah"]) surrogates, filters, conditions = compute_ems(epochs) assert_equal(list(set(conditions)), [1, 3]) events = read_events(event_name) event_id2 = dict(aud_l=1, aud_r=2, vis_l=3) - epochs = Epochs(raw, events, event_id2, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, + events, + event_id2, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + preload=True, + ) epochs.equalize_event_counts(epochs.event_id) - n_expected = sum([len(epochs[k]) for k in ['aud_l', 'vis_l']]) + n_expected = sum([len(epochs[k]) for k in ["aud_l", "vis_l"]]) pytest.raises(ValueError, compute_ems, epochs) - surrogates, filters, conditions = compute_ems(epochs, ['aud_r', 'vis_l']) + surrogates, filters, conditions = compute_ems(epochs, ["aud_r", "vis_l"]) assert_equal(n_expected, len(surrogates)) assert_equal(n_expected, len(conditions)) assert_equal(list(set(conditions)), [2, 3]) # test compute_ems cv - epochs = epochs['aud_r', 'vis_l'] + epochs = epochs["aud_r", "vis_l"] epochs.equalize_event_counts(epochs.event_id) cv = StratifiedKFold(n_splits=3) compute_ems(epochs, cv=cv) compute_ems(epochs, cv=2) - pytest.raises(ValueError, compute_ems, epochs, cv='foo') + pytest.raises(ValueError, compute_ems, epochs, cv="foo") pytest.raises(ValueError, compute_ems, epochs, cv=len(epochs) + 1) raw.close() @@ -70,13 +81,13 @@ def test_ems(): X = X / np.std(X) # X scaled 
outside cv in compute_ems Xt, coefs = list(), list() ems = EMS() - assert_equal(ems.__repr__(), '') + assert_equal(ems.__repr__(), "") # manual leave-one-out to avoid sklearn version problem for test in range(len(y)): train = np.setdiff1d(range(len(y)), np.atleast_1d(test)) ems.fit(X[train], y[train]) coefs.append(ems.filters_) Xt.append(ems.transform(X[[test]])) - assert_equal(ems.__repr__(), '') + assert_equal(ems.__repr__(), "") assert_array_almost_equal(filters, np.mean(coefs, axis=0)) assert_array_almost_equal(surrogates, np.vstack(Xt)) diff --git a/mne/decoding/tests/test_receptive_field.py b/mne/decoding/tests/test_receptive_field.py index c5d62fb4c63..9a993b43669 100644 --- a/mne/decoding/tests/test_receptive_field.py +++ b/mne/decoding/tests/test_receptive_field.py @@ -12,10 +12,13 @@ from mne.utils import requires_sklearn from mne.decoding import ReceptiveField, TimeDelayingRidge -from mne.decoding.receptive_field import (_delay_time_series, _SCORERS, - _times_to_delays, _delays_to_slice) -from mne.decoding.time_delaying_ridge import (_compute_reg_neighbors, - _compute_corrs) +from mne.decoding.receptive_field import ( + _delay_time_series, + _SCORERS, + _times_to_delays, + _delays_to_slice, +) +from mne.decoding.time_delaying_ridge import _compute_reg_neighbors, _compute_corrs data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" @@ -26,30 +29,53 @@ event_id = dict(aud_l=1, vis_l=3) # Loading raw data -n_jobs_test = (1, 'cuda') +n_jobs_test = (1, "cuda") def test_compute_reg_neighbors(): """Test fast calculation of laplacian regularizer.""" for reg_type in ( - ('ridge', 'ridge'), - ('ridge', 'laplacian'), - ('laplacian', 'ridge'), - ('laplacian', 'laplacian')): + ("ridge", "ridge"), + ("ridge", "laplacian"), + ("laplacian", "ridge"), + ("laplacian", "laplacian"), + ): for n_ch_x, n_delays in ( - (1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1), - (2, 2), (2, 3), (3, 2), (3, 3), - (2, 4), (4, 2), (3, 4), (4, 3), (4, 4), - (5, 4), (4, 5), (5, 5), - (20, 9), (9, 20)): + (1, 1), + (1, 2), + (2, 1), + (1, 3), + (3, 1), + (1, 4), + (4, 1), + (2, 2), + (2, 3), + (3, 2), + (3, 3), + (2, 4), + (4, 2), + (3, 4), + (4, 3), + (4, 4), + (5, 4), + (4, 5), + (5, 5), + (20, 9), + (9, 20), + ): for normed in (True, False): reg_direct = _compute_reg_neighbors( - n_ch_x, n_delays, reg_type, 'direct', normed=normed) + n_ch_x, n_delays, reg_type, "direct", normed=normed + ) reg_csgraph = _compute_reg_neighbors( - n_ch_x, n_delays, reg_type, 'csgraph', normed=normed) + n_ch_x, n_delays, reg_type, "csgraph", normed=normed + ) assert_allclose( - reg_direct, reg_csgraph, atol=1e-7, - err_msg='%s: %s' % (reg_type, (n_ch_x, n_delays))) + reg_direct, + reg_csgraph, + atol=1e-7, + err_msg="%s: %s" % (reg_type, (n_ch_x, n_delays)), + ) @requires_sklearn @@ -57,19 +83,20 @@ def test_rank_deficiency(): """Test signals that are rank deficient.""" # See GH#4253 from sklearn.linear_model import Ridge + N = 256 - fs = 1. 
+ fs = 1.0 tmin, tmax = -50, 100 reg = 0.1 rng = np.random.RandomState(0) eeg = rng.randn(N, 1) eeg *= 100 eeg = rfft(eeg, axis=0) - eeg[N // 4:] = 0 # rank-deficient lowpass + eeg[N // 4 :] = 0 # rank-deficient lowpass eeg = irfft(eeg, axis=0) win = np.hanning(N // 8) win /= win.mean() - y = np.apply_along_axis(np.convolve, 0, eeg, win, mode='same') + y = np.apply_along_axis(np.convolve, 0, eeg, win, mode="same") y += rng.randn(*y.shape) * 100 for est in (Ridge(reg), reg): @@ -101,14 +128,15 @@ def test_time_delay(): ((-2, 0), 1), ((-2, -1), 1), ((-2, -1), 1), - ((0, .2), 10), - ((-.1, .1), 10)] + ((0, 0.2), 10), + ((-0.1, 0.1), 10), + ] for (tmin, tmax), isfreq in test_tlims: # sfreq must be int/float - with pytest.raises(TypeError, match='`sfreq` must be an instance of'): + with pytest.raises(TypeError, match="`sfreq` must be an instance of"): _delay_time_series(X, tmin, tmax, sfreq=[1]) # Delays must be int/float - with pytest.raises(TypeError, match='.*complex.*'): + with pytest.raises(TypeError, match=".*complex.*"): _delay_time_series(X, np.complex128(tmin), tmax, 1) # Make sure swapaxes works start, stop = int(round(tmin * isfreq)), int(round(tmax * isfreq)) + 1 @@ -128,34 +156,36 @@ def test_time_delay(): del_zero = int(round(-tmin * isfreq)) for ii in range(-2, 3): idx = del_zero + ii - err_msg = '[%s,%s] (%s): %s %s' % (tmin, tmax, isfreq, ii, idx) + err_msg = "[%s,%s] (%s): %s %s" % (tmin, tmax, isfreq, ii, idx) if 0 <= idx < X_delayed.shape[-1]: if ii == 0: - assert_array_equal(X_delayed[:, :, idx], X, - err_msg=err_msg) + assert_array_equal(X_delayed[:, :, idx], X, err_msg=err_msg) elif ii < 0: # negative delay - assert_array_equal(X_delayed[:ii, :, idx], X[-ii:, :], - err_msg=err_msg) - assert_array_equal(X_delayed[ii:, :, idx], 0.) + assert_array_equal( + X_delayed[:ii, :, idx], X[-ii:, :], err_msg=err_msg + ) + assert_array_equal(X_delayed[ii:, :, idx], 0.0) else: - assert_array_equal(X_delayed[ii:, :, idx], X[:-ii, :], - err_msg=err_msg) - assert_array_equal(X_delayed[:ii, :, idx], 0.) 
+ assert_array_equal( + X_delayed[ii:, :, idx], X[:-ii, :], err_msg=err_msg + ) + assert_array_equal(X_delayed[:ii, :, idx], 0.0) @pytest.mark.slowtest # slow on Azure -@pytest.mark.parametrize('n_jobs', n_jobs_test) +@pytest.mark.parametrize("n_jobs", n_jobs_test) @requires_sklearn def test_receptive_field_basic(n_jobs): """Test model prep and fitting.""" from sklearn.linear_model import Ridge + # Make sure estimator pulling works mod = Ridge() rng = np.random.RandomState(1337) # Test the receptive field model # Define parameters for the model and simulate inputs + weights - tmin, tmax = -10., 0 + tmin, tmax = -10.0, 0 n_feats = 3 rng = np.random.RandomState(0) X = rng.randn(10000, n_feats) @@ -163,82 +193,83 @@ def test_receptive_field_basic(n_jobs): # Delay inputs and cut off first 4 values since they'll be cut in the fit X_del = np.concatenate( - _delay_time_series(X, tmin, tmax, 1.).transpose(2, 0, 1), axis=1) + _delay_time_series(X, tmin, tmax, 1.0).transpose(2, 0, 1), axis=1 + ) y = np.dot(X_del, w) # Fit the model and test values - feature_names = ['feature_%i' % ii for ii in [0, 1, 2]] - rf = ReceptiveField(tmin, tmax, 1, feature_names, estimator=mod, - patterns=True) + feature_names = ["feature_%i" % ii for ii in [0, 1, 2]] + rf = ReceptiveField(tmin, tmax, 1, feature_names, estimator=mod, patterns=True) rf.fit(X, y) assert_array_equal(rf.delays_, np.arange(tmin, tmax + 1)) y_pred = rf.predict(X) assert_allclose(y[rf.valid_samples_], y_pred[rf.valid_samples_], atol=1e-2) scores = rf.score(X, y) - assert scores > .99 + assert scores > 0.99 assert_allclose(rf.coef_.T.ravel(), w, atol=1e-3) # Make sure different input shapes work - rf.fit(X[:, np.newaxis:], y[:, np.newaxis]) + rf.fit(X[:, np.newaxis :], y[:, np.newaxis]) rf.fit(X, y[:, np.newaxis]) - with pytest.raises(ValueError, match='If X has 3 .* y must have 2 or 3'): + with pytest.raises(ValueError, match="If X has 3 .* y must have 2 or 3"): rf.fit(X[..., np.newaxis], y) - with pytest.raises(ValueError, match='X must be shape'): + with pytest.raises(ValueError, match="X must be shape"): rf.fit(X[:, 0], y) - with pytest.raises(ValueError, match='X and y do not have the same n_epo'): - rf.fit(X[:, np.newaxis], np.tile(y[:, np.newaxis, np.newaxis], - [1, 2, 1])) - with pytest.raises(ValueError, match='X and y do not have the same n_tim'): + with pytest.raises(ValueError, match="X and y do not have the same n_epo"): + rf.fit(X[:, np.newaxis], np.tile(y[:, np.newaxis, np.newaxis], [1, 2, 1])) + with pytest.raises(ValueError, match="X and y do not have the same n_tim"): rf.fit(X, y[:-2]) - with pytest.raises(ValueError, match='n_features in X does not match'): + with pytest.raises(ValueError, match="n_features in X does not match"): rf.fit(X[:, :1], y) # auto-naming features - feature_names = ['feature_%s' % ii for ii in [0, 1, 2]] - rf = ReceptiveField(tmin, tmax, 1, estimator=mod, - feature_names=feature_names) + feature_names = ["feature_%s" % ii for ii in [0, 1, 2]] + rf = ReceptiveField(tmin, tmax, 1, estimator=mod, feature_names=feature_names) assert_equal(rf.feature_names, feature_names) rf = ReceptiveField(tmin, tmax, 1, estimator=mod) rf.fit(X, y) assert_equal(rf.feature_names, None) # Float becomes ridge - rf = ReceptiveField(tmin, tmax, 1, ['one', 'two', 'three'], estimator=0) + rf = ReceptiveField(tmin, tmax, 1, ["one", "two", "three"], estimator=0) str(rf) # repr works before fit rf.fit(X, y) assert isinstance(rf.estimator_, TimeDelayingRidge) str(rf) # repr works after fit - rf = ReceptiveField(tmin, tmax, 1, ['one'], 
estimator=0) + rf = ReceptiveField(tmin, tmax, 1, ["one"], estimator=0) rf.fit(X[:, [0]], y) str(rf) # repr with one feature # Should only accept estimators or floats - with pytest.raises(ValueError, match='`estimator` must be a float or'): - ReceptiveField(tmin, tmax, 1, estimator='foo').fit(X, y) - with pytest.raises(ValueError, match='`estimator` must be a float or'): + with pytest.raises(ValueError, match="`estimator` must be a float or"): + ReceptiveField(tmin, tmax, 1, estimator="foo").fit(X, y) + with pytest.raises(ValueError, match="`estimator` must be a float or"): ReceptiveField(tmin, tmax, 1, estimator=np.array([1, 2, 3])).fit(X, y) - with pytest.raises(ValueError, match='tmin .* must be at most tmax'): + with pytest.raises(ValueError, match="tmin .* must be at most tmax"): ReceptiveField(5, 4, 1).fit(X, y) # scorers for key, val in _SCORERS.items(): - rf = ReceptiveField(tmin, tmax, 1, ['one'], - estimator=0, scoring=key, patterns=True) + rf = ReceptiveField( + tmin, tmax, 1, ["one"], estimator=0, scoring=key, patterns=True + ) rf.fit(X[:, [0]], y) y_pred = rf.predict(X[:, [0]]).T.ravel()[:, np.newaxis] - assert_allclose(val(y[:, np.newaxis], y_pred, - multioutput='raw_values'), - rf.score(X[:, [0]], y), rtol=1e-2) - with pytest.raises(ValueError, match='inputs must be shape'): - _SCORERS['corrcoef'](y.ravel(), y_pred, multioutput='raw_values') + assert_allclose( + val(y[:, np.newaxis], y_pred, multioutput="raw_values"), + rf.score(X[:, [0]], y), + rtol=1e-2, + ) + with pytest.raises(ValueError, match="inputs must be shape"): + _SCORERS["corrcoef"](y.ravel(), y_pred, multioutput="raw_values") # Need correct scorers - with pytest.raises(ValueError, match='scoring must be one of'): - ReceptiveField(tmin, tmax, 1., scoring='foo').fit(X, y) + with pytest.raises(ValueError, match="scoring must be one of"): + ReceptiveField(tmin, tmax, 1.0, scoring="foo").fit(X, y) -@pytest.mark.parametrize('n_jobs', n_jobs_test) +@pytest.mark.parametrize("n_jobs", n_jobs_test) def test_time_delaying_fast_calc(n_jobs): """Test time delaying and fast calculations.""" X = np.array([[1, 2, 3], [5, 7, 11]]).T # all negative smin, smax = 1, 2 - X_del = _delay_time_series(X, smin, smax, 1.) + X_del = _delay_time_series(X, smin, smax, 1.0) # (n_times, n_features, n_delays) -> (n_times, n_features * n_delays) X_del.shape = (X.shape[0], -1) expected = np.array([[0, 1, 2], [0, 0, 1], [0, 5, 7], [0, 0, 5]]).T @@ -250,30 +281,32 @@ def test_time_delaying_fast_calc(n_jobs): assert_allclose(x_xt, expected) # all positive smin, smax = -2, -1 - X_del = _delay_time_series(X, smin, smax, 1.) + X_del = _delay_time_series(X, smin, smax, 1.0) X_del.shape = (X.shape[0], -1) expected = np.array([[3, 0, 0], [2, 3, 0], [11, 0, 0], [7, 11, 0]]).T assert_allclose(X_del, expected) Xt_X = np.dot(X_del.T, X_del) - expected = [[9, 6, 33, 21], [6, 13, 22, 47], - [33, 22, 121, 77], [21, 47, 77, 170]] + expected = [[9, 6, 33, 21], [6, 13, 22, 47], [33, 22, 121, 77], [21, 47, 77, 170]] assert_allclose(Xt_X, expected) x_xt = _compute_corrs(X, np.zeros((X.shape[0], 1)), smin, smax + 1)[0] assert_allclose(x_xt, expected) # both sides smin, smax = -1, 1 - X_del = _delay_time_series(X, smin, smax, 1.) 
+ X_del = _delay_time_series(X, smin, smax, 1.0) X_del.shape = (X.shape[0], -1) - expected = np.array([[2, 3, 0], [1, 2, 3], [0, 1, 2], - [7, 11, 0], [5, 7, 11], [0, 5, 7]]).T + expected = np.array( + [[2, 3, 0], [1, 2, 3], [0, 1, 2], [7, 11, 0], [5, 7, 11], [0, 5, 7]] + ).T assert_allclose(X_del, expected) Xt_X = np.dot(X_del.T, X_del) - expected = [[13, 8, 3, 47, 31, 15], - [8, 14, 8, 29, 52, 31], - [3, 8, 5, 11, 29, 19], - [47, 29, 11, 170, 112, 55], - [31, 52, 29, 112, 195, 112], - [15, 31, 19, 55, 112, 74]] + expected = [ + [13, 8, 3, 47, 31, 15], + [8, 14, 8, 29, 52, 31], + [3, 8, 5, 11, 29, 19], + [47, 29, 11, 170, 112, 55], + [31, 52, 29, 112, 195, 112], + [15, 31, 19, 55, 112, 74], + ] assert_allclose(Xt_X, expected) x_xt = _compute_corrs(X, np.zeros((X.shape[0], 1)), smin, smax + 1)[0] assert_allclose(x_xt, expected) @@ -281,10 +314,9 @@ def test_time_delaying_fast_calc(n_jobs): # slightly harder to get the non-Toeplitz correction correct X = np.array([[1, 2, 3, 5]]).T smin, smax = 0, 3 - X_del = _delay_time_series(X, smin, smax, 1.) + X_del = _delay_time_series(X, smin, smax, 1.0) X_del.shape = (X.shape[0], -1) - expected = np.array([[1, 2, 3, 5], [0, 1, 2, 3], - [0, 0, 1, 2], [0, 0, 0, 1]]).T + expected = np.array([[1, 2, 3, 5], [0, 1, 2, 3], [0, 0, 1, 2], [0, 0, 0, 1]]).T assert_allclose(X_del, expected) Xt_X = np.dot(X_del.T, X_del) expected = [[39, 23, 13, 5], [23, 14, 8, 3], [13, 8, 5, 2], [5, 3, 2, 1]] @@ -295,18 +327,23 @@ def test_time_delaying_fast_calc(n_jobs): # even worse X = np.array([[1, 2, 3], [5, 7, 11]]).T smin, smax = 0, 2 - X_del = _delay_time_series(X, smin, smax, 1.) + X_del = _delay_time_series(X, smin, smax, 1.0) X_del.shape = (X.shape[0], -1) - expected = np.array([[1, 2, 3], [0, 1, 2], [0, 0, 1], - [5, 7, 11], [0, 5, 7], [0, 0, 5]]).T + expected = np.array( + [[1, 2, 3], [0, 1, 2], [0, 0, 1], [5, 7, 11], [0, 5, 7], [0, 0, 5]] + ).T assert_allclose(X_del, expected) Xt_X = np.dot(X_del.T, X_del) - expected = np.array([[14, 8, 3, 52, 31, 15], - [8, 5, 2, 29, 19, 10], - [3, 2, 1, 11, 7, 5], - [52, 29, 11, 195, 112, 55], - [31, 19, 7, 112, 74, 35], - [15, 10, 5, 55, 35, 25]]) + expected = np.array( + [ + [14, 8, 3, 52, 31, 15], + [8, 5, 2, 29, 19, 10], + [3, 2, 1, 11, 7, 5], + [52, 29, 11, 195, 112, 55], + [31, 19, 7, 112, 74, 35], + [15, 10, 5, 55, 35, 25], + ] + ) assert_allclose(Xt_X, expected) x_xt = _compute_corrs(X, np.zeros((X.shape[0], 1)), smin, smax + 1)[0] assert_allclose(x_xt, expected) @@ -323,10 +360,10 @@ def test_time_delaying_fast_calc(n_jobs): for ii in range(X.shape[1]): kernel = rng.randn(smax - smin + 1) kernel -= np.mean(kernel) - y[:, ii % y.shape[-1]] = np.convolve(X[:, ii], kernel, 'same') + y[:, ii % y.shape[-1]] = np.convolve(X[:, ii], kernel, "same") x_xt, x_yt, n_ch_x, _, _ = _compute_corrs(X, y, smin, smax + 1) - X_del = _delay_time_series(X, smin, smax, 1., fill_mean=False) - x_yt_true = einsum('tfd,to->ofd', X_del, y) + X_del = _delay_time_series(X, smin, smax, 1.0, fill_mean=False) + x_yt_true = einsum("tfd,to->ofd", X_del, y) x_yt_true = np.reshape(x_yt_true, (x_yt_true.shape[0], -1)).T assert_allclose(x_yt, x_yt_true, atol=1e-7, err_msg=(smin, smax)) X_del.shape = (X.shape[0], -1) @@ -334,11 +371,12 @@ def test_time_delaying_fast_calc(n_jobs): assert_allclose(x_xt, x_xt_true, atol=1e-7, err_msg=(smin, smax)) -@pytest.mark.parametrize('n_jobs', n_jobs_test) +@pytest.mark.parametrize("n_jobs", n_jobs_test) @requires_sklearn def test_receptive_field_1d(n_jobs): """Test that the fast solving works like Ridge.""" from 
sklearn.linear_model import Ridge + rng = np.random.RandomState(0) x = rng.randn(500, 1) for delay in range(-2, 3): @@ -356,22 +394,26 @@ def test_receptive_field_1d(n_jobs): y.shape = (y.shape[0],) + (1,) * (ndim - 1) for slim in slims: smin, smax = slim - lap = TimeDelayingRidge(smin, smax, 1., 0.1, 'laplacian', - fit_intercept=False, n_jobs=n_jobs) - for estimator in (Ridge(alpha=0.), Ridge(alpha=0.1), 0., 0.1, - lap): + lap = TimeDelayingRidge( + smin, + smax, + 1.0, + 0.1, + "laplacian", + fit_intercept=False, + n_jobs=n_jobs, + ) + for estimator in (Ridge(alpha=0.0), Ridge(alpha=0.1), 0.0, 0.1, lap): for offset in (-100, 0, 100): - model = ReceptiveField(smin, smax, 1., - estimator=estimator, - n_jobs=n_jobs) + model = ReceptiveField( + smin, smax, 1.0, estimator=estimator, n_jobs=n_jobs + ) use_x = x + offset model.fit(use_x, y) if estimator is lap: continue # these checks are too stringent - assert_allclose(model.estimator_.intercept_, -offset, - atol=1e-1) - assert_array_equal(model.delays_, - np.arange(smin, smax + 1)) + assert_allclose(model.estimator_.intercept_, -offset, atol=1e-1) + assert_array_equal(model.delays_, np.arange(smin, smax + 1)) expected = (model.delays_ == delay).astype(float) expected = expected[np.newaxis] # features if y.ndim == 2: @@ -383,16 +425,19 @@ def test_receptive_field_1d(n_jobs): assert stop - start >= 495 assert_allclose( model.predict(use_x)[model.valid_samples_], - y[model.valid_samples_], atol=1e-2) + y[model.valid_samples_], + atol=1e-2, + ) score = np.mean(model.score(use_x, y)) assert score > 0.9999 -@pytest.mark.parametrize('n_jobs', n_jobs_test) +@pytest.mark.parametrize("n_jobs", n_jobs_test) @requires_sklearn def test_receptive_field_nd(n_jobs): """Test multidimensional support.""" from sklearn.linear_model import Ridge + # multidimensional rng = np.random.RandomState(3) x = rng.randn(1000, 3) @@ -407,55 +452,57 @@ def test_receptive_field_nd(n_jobs): x -= np.mean(x, axis=0) x_off = x + 1e3 expected = [ - [[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 4, 0], - [0, 0, 2, 0, 0, 0]], - [[0, 0, 0, -3, 0, 0], - [0, -1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 4, 0], [0, 0, 2, 0, 0, 0]], + [[0, 0, 0, -3, 0, 0], [0, -1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], ] - tdr_l = TimeDelayingRidge(smin, smax, 1., 0.1, 'laplacian', n_jobs=n_jobs) - tdr_nc = TimeDelayingRidge(smin, smax, 1., 0.1, n_jobs=n_jobs, - edge_correction=False) - for estimator, atol in zip((Ridge(alpha=0.), 0., 0.01, tdr_l, tdr_nc), - (1e-3, 1e-3, 1e-3, 5e-3, 5e-2)): - model = ReceptiveField(smin, smax, 1., - estimator=estimator) + tdr_l = TimeDelayingRidge(smin, smax, 1.0, 0.1, "laplacian", n_jobs=n_jobs) + tdr_nc = TimeDelayingRidge( + smin, smax, 1.0, 0.1, n_jobs=n_jobs, edge_correction=False + ) + for estimator, atol in zip( + (Ridge(alpha=0.0), 0.0, 0.01, tdr_l, tdr_nc), (1e-3, 1e-3, 1e-3, 5e-3, 5e-2) + ): + model = ReceptiveField(smin, smax, 1.0, estimator=estimator) model.fit(x, y) - assert_array_equal(model.delays_, - np.arange(smin, smax + 1)) + assert_array_equal(model.delays_, np.arange(smin, smax + 1)) assert_allclose(model.coef_, expected, atol=atol) - tdr = TimeDelayingRidge(smin, smax, 1., 0.01, reg_type='foo', - n_jobs=n_jobs) - model = ReceptiveField(smin, smax, 1., estimator=tdr) - with pytest.raises(ValueError, match='reg_type entries must be one of'): + tdr = TimeDelayingRidge(smin, smax, 1.0, 0.01, reg_type="foo", n_jobs=n_jobs) + model = ReceptiveField(smin, smax, 1.0, estimator=tdr) + with pytest.raises(ValueError, match="reg_type entries must 
be one of"): model.fit(x, y) - tdr = TimeDelayingRidge(smin, smax, 1., 0.01, reg_type=['laplacian'], - n_jobs=n_jobs) - model = ReceptiveField(smin, smax, 1., estimator=tdr) - with pytest.raises(ValueError, match='reg_type must have two elements'): + tdr = TimeDelayingRidge( + smin, smax, 1.0, 0.01, reg_type=["laplacian"], n_jobs=n_jobs + ) + model = ReceptiveField(smin, smax, 1.0, estimator=tdr) + with pytest.raises(ValueError, match="reg_type must have two elements"): model.fit(x, y) model = ReceptiveField(smin, smax, 1, estimator=tdr, fit_intercept=False) - with pytest.raises(ValueError, match='fit_intercept'): + with pytest.raises(ValueError, match="fit_intercept"): model.fit(x, y) # Now check the intercept_ - tdr = TimeDelayingRidge(smin, smax, 1., 0., n_jobs=n_jobs) - tdr_no = TimeDelayingRidge(smin, smax, 1., 0., fit_intercept=False, - n_jobs=n_jobs) - for estimator in (Ridge(alpha=0.), tdr, - Ridge(alpha=0., fit_intercept=False), tdr_no): + tdr = TimeDelayingRidge(smin, smax, 1.0, 0.0, n_jobs=n_jobs) + tdr_no = TimeDelayingRidge(smin, smax, 1.0, 0.0, fit_intercept=False, n_jobs=n_jobs) + for estimator in ( + Ridge(alpha=0.0), + tdr, + Ridge(alpha=0.0, fit_intercept=False), + tdr_no, + ): # first with no intercept in the data - model = ReceptiveField(smin, smax, 1., estimator=estimator) + model = ReceptiveField(smin, smax, 1.0, estimator=estimator) model.fit(x, y) - assert_allclose(model.estimator_.intercept_, 0., atol=1e-7, - err_msg=repr(estimator)) - assert_allclose(model.coef_, expected, atol=1e-3, - err_msg=repr(estimator)) + assert_allclose( + model.estimator_.intercept_, 0.0, atol=1e-7, err_msg=repr(estimator) + ) + assert_allclose(model.coef_, expected, atol=1e-3, err_msg=repr(estimator)) y_pred = model.predict(x) - assert_allclose(y_pred[model.valid_samples_], - y[model.valid_samples_], - atol=1e-2, err_msg=repr(estimator)) + assert_allclose( + y_pred[model.valid_samples_], + y[model.valid_samples_], + atol=1e-2, + err_msg=repr(estimator), + ) score = np.mean(model.score(x, y)) assert score > 0.9999 @@ -466,12 +513,14 @@ def test_receptive_field_nd(n_jobs): itol = 0.5 ctol = 5e-4 else: - val = itol = 0. - ctol = 2. 
- assert_allclose(model.estimator_.intercept_, val, atol=itol, - err_msg=repr(estimator)) - assert_allclose(model.coef_, expected, atol=ctol, rtol=ctol, - err_msg=repr(estimator)) + val = itol = 0.0 + ctol = 2.0 + assert_allclose( + model.estimator_.intercept_, val, atol=itol, err_msg=repr(estimator) + ) + assert_allclose( + model.coef_, expected, atol=ctol, rtol=ctol, err_msg=repr(estimator) + ) if estimator.fit_intercept: ptol = 1e-2 stol = 0.999999 @@ -479,13 +528,14 @@ def test_receptive_field_nd(n_jobs): ptol = 10 stol = 0.6 y_pred = model.predict(x_off)[model.valid_samples_] - assert_allclose(y_pred, y[model.valid_samples_], - atol=ptol, err_msg=repr(estimator)) + assert_allclose( + y_pred, y[model.valid_samples_], atol=ptol, err_msg=repr(estimator) + ) score = np.mean(model.score(x_off, y)) assert score > stol, estimator - model = ReceptiveField(smin, smax, 1., fit_intercept=False) + model = ReceptiveField(smin, smax, 1.0, fit_intercept=False) model.fit(x_off, y) - assert_allclose(model.estimator_.intercept_, 0., atol=1e-7) + assert_allclose(model.estimator_.intercept_, 0.0, atol=1e-7) score = np.mean(model.score(x_off, y)) assert score > 0.6 @@ -496,7 +546,8 @@ def _make_data(n_feats, n_targets, n_samples, tmin, tmax): w = rng.randn(int((tmax - tmin) + 1) * n_feats, n_targets) # Delay inputs X_del = np.concatenate( - _delay_time_series(X, tmin, tmax, 1.).transpose(2, 0, 1), axis=1) + _delay_time_series(X, tmin, tmax, 1.0).transpose(2, 0, 1), axis=1 + ) y = np.dot(X_del, w) return X, y @@ -506,25 +557,25 @@ def test_inverse_coef(): """Test inverse coefficients computation.""" from sklearn.linear_model import Ridge - tmin, tmax = 0., 10. + tmin, tmax = 0.0, 10.0 n_feats, n_targets, n_samples = 3, 2, 1000 n_delays = int((tmax - tmin) + 1) # Check coefficient dims, for all estimator types X, y = _make_data(n_feats, n_targets, n_samples, tmin, tmax) - tdr = TimeDelayingRidge(tmin, tmax, 1., 0.1, 'laplacian') - for estimator in (0., 0.01, Ridge(alpha=0.), tdr): - rf = ReceptiveField(tmin, tmax, 1., estimator=estimator, - patterns=True) + tdr = TimeDelayingRidge(tmin, tmax, 1.0, 0.1, "laplacian") + for estimator in (0.0, 0.01, Ridge(alpha=0.0), tdr): + rf = ReceptiveField(tmin, tmax, 1.0, estimator=estimator, patterns=True) rf.fit(X, y) - inv_rf = ReceptiveField(tmin, tmax, 1., estimator=estimator, - patterns=True) + inv_rf = ReceptiveField(tmin, tmax, 1.0, estimator=estimator, patterns=True) inv_rf.fit(y, X) - assert_array_equal(rf.coef_.shape, rf.patterns_.shape, - (n_targets, n_feats, n_delays)) - assert_array_equal(inv_rf.coef_.shape, inv_rf.patterns_.shape, - (n_feats, n_targets, n_delays)) + assert_array_equal( + rf.coef_.shape, rf.patterns_.shape, (n_targets, n_feats, n_delays) + ) + assert_array_equal( + inv_rf.coef_.shape, inv_rf.patterns_.shape, (n_feats, n_targets, n_delays) + ) # we should have np.dot(patterns.T,coef) ~ np.eye(n) c0 = rf.coef_.reshape(n_targets, n_feats * n_delays) @@ -536,10 +587,12 @@ def test_inverse_coef(): def test_linalg_warning(): """Test that warnings are issued when no regularization is applied.""" from sklearn.linear_model import Ridge + n_feats, n_targets, n_samples = 5, 60, 50 X, y = _make_data(n_feats, n_targets, n_samples, tmin, tmax) - for estimator in (0., Ridge(alpha=0.)): - rf = ReceptiveField(tmin, tmax, 1., estimator=estimator) - with pytest.warns((RuntimeWarning, UserWarning), - match='[Singular|scipy.linalg.solve]'): + for estimator in (0.0, Ridge(alpha=0.0)): + rf = ReceptiveField(tmin, tmax, 1.0, estimator=estimator) + with 
pytest.warns(
+        (RuntimeWarning, UserWarning), match="[Singular|scipy.linalg.solve]"
+    ):
         rf.fit(y, X)
diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py
index 1bc4f1e1e9a..a531d7b668e 100644
--- a/mne/decoding/tests/test_search_light.py
+++ b/mne/decoding/tests/test_search_light.py
@@ -31,25 +31,25 @@ def test_search_light():
     from sklearn.linear_model import Ridge, LogisticRegression
     from sklearn.pipeline import make_pipeline
     from sklearn.metrics import roc_auc_score, make_scorer
+
     with _record_warnings():  # NumPy module import
         from sklearn.ensemble import BaggingClassifier
         from sklearn.base import is_classifier

-    logreg = LogisticRegression(solver='liblinear', multi_class='ovr',
-                                random_state=0)
+    logreg = LogisticRegression(solver="liblinear", multi_class="ovr", random_state=0)

     X, y = make_data()
     n_epochs, _, n_time = X.shape
     # init
-    pytest.raises(ValueError, SlidingEstimator, 'foo')
+    pytest.raises(ValueError, SlidingEstimator, "foo")
     sl = SlidingEstimator(Ridge())
-    assert (not is_classifier(sl))
-    sl = SlidingEstimator(LogisticRegression(solver='liblinear'))
-    assert (is_classifier(sl))
+    assert not is_classifier(sl)
+    sl = SlidingEstimator(LogisticRegression(solver="liblinear"))
+    assert is_classifier(sl)
     # fit
-    assert_equal(sl.__repr__()[:18], '<SlidingEstimator(')
+    assert_equal(sl.__repr__()[-28:], ", fitted with 10 estimators>")
     pytest.raises(ValueError, sl.fit, X[1:], y)
     pytest.raises(ValueError, sl.fit, X[:, :, 0], y)
     sl.fit(X, y, sample_weight=np.ones_like(y))
@@ -57,38 +57,37 @@
     # transforms
     pytest.raises(ValueError, sl.predict, X[:, :, :2])
     y_pred = sl.predict(X)
-    assert (y_pred.dtype == int)
+    assert y_pred.dtype == int
     assert_array_equal(y_pred.shape, [n_epochs, n_time])
     y_proba = sl.predict_proba(X)
-    assert (y_proba.dtype == float)
+    assert y_proba.dtype == float
     assert_array_equal(y_proba.shape, [n_epochs, n_time, 2])

     # score
     score = sl.score(X, y)
     assert_array_equal(score.shape, [n_time])
-    assert (np.sum(np.abs(score)) != 0)
-    assert (score.dtype == float)
+    assert np.sum(np.abs(score)) != 0
+    assert score.dtype == float

     sl = SlidingEstimator(logreg)
     assert_equal(sl.scoring, None)

     # Scoring method
-    for scoring in ['foo', 999]:
+    for scoring in ["foo", 999]:
         sl = SlidingEstimator(logreg, scoring=scoring)
         sl.fit(X, y)
         pytest.raises((ValueError, TypeError), sl.score, X, y)

     # Check sklearn's roc_auc fix: scikit-learn/scikit-learn#6874
     # -- 3 class problem
-    sl = SlidingEstimator(logreg, scoring='roc_auc')
+    sl = SlidingEstimator(logreg, scoring="roc_auc")
     y = np.arange(len(X)) % 3
     sl.fit(X, y)
-    with pytest.raises(ValueError, match='for two-class'):
+    with pytest.raises(ValueError, match="for two-class"):
         sl.score(X, y)

     # But check that valid ones should work with new enough sklearn
-    if 'multi_class' in signature(roc_auc_score).parameters:
-        scoring = make_scorer(
-            roc_auc_score, needs_proba=True, multi_class='ovo')
+    if "multi_class" in signature(roc_auc_score).parameters:
+        scoring = make_scorer(roc_auc_score, needs_proba=True, multi_class="ovo")
         sl = SlidingEstimator(logreg, scoring=scoring)
         sl.fit(X, y)
         sl.score(X, y)  # smoke test
@@ -97,8 +96,10 @@ def test_search_light():
     y = np.arange(len(X)) % 2 + 1
     sl.fit(X, y)
     score = sl.score(X, y)
-    assert_array_equal(score, [roc_auc_score(y - 1, _y_pred - 1)
-                               for _y_pred in sl.decision_function(X).T])
+    assert_array_equal(
+        score,
+        [roc_auc_score(y - 1, _y_pred - 1) for _y_pred in sl.decision_function(X).T],
+    )
     y = np.arange(len(X)) % 2

     # Cannot pass a metric as a 
scoring parameter
@@ -107,22 +108,23 @@
     pytest.raises(ValueError, sl1.score, X, y)

     # Now use string as scoring
-    sl1 = SlidingEstimator(logreg, scoring='roc_auc')
+    sl1 = SlidingEstimator(logreg, scoring="roc_auc")
     sl1.fit(X, y)

     rng = np.random.RandomState(0)
     X = rng.randn(*X.shape)  # randomize X to avoid AUCs in [0, 1]
     score_sl = sl1.score(X, y)
     assert_array_equal(score_sl.shape, [n_time])
-    assert (score_sl.dtype == float)
+    assert score_sl.dtype == float

     # Check that scoring was applied adequately
     scoring = make_scorer(roc_auc_score, needs_threshold=True)
-    score_manual = [scoring(est, x, y) for est, x in zip(
-        sl1.estimators_, X.transpose(2, 0, 1))]
+    score_manual = [
+        scoring(est, x, y) for est, x in zip(sl1.estimators_, X.transpose(2, 0, 1))
+    ]
     assert_array_equal(score_manual, score_sl)

     # n_jobs
-    sl = SlidingEstimator(logreg, n_jobs=None, scoring='roc_auc')
+    sl = SlidingEstimator(logreg, n_jobs=None, scoring="roc_auc")
     score_1job = sl.fit(X, y).score(X, y)
     sl.n_jobs = 2
     score_njobs = sl.fit(X, y).score(X, y)
@@ -139,10 +141,9 @@ def transform(self, X):
             return super(_LogRegTransformer, self).predict_proba(X)[..., 1]

     logreg_transformer = _LogRegTransformer(
-        random_state=0, multi_class='ovr', solver='liblinear'
+        random_state=0, multi_class="ovr", solver="liblinear"
     )
-    pipe = make_pipeline(SlidingEstimator(logreg_transformer),
-                         logreg)
+    pipe = make_pipeline(SlidingEstimator(logreg_transformer), logreg)
     pipe.fit(X, y)
     pipe.predict(X)

@@ -151,8 +152,7 @@ def transform(self, X):
     y = np.arange(10) % 2
     y_preds = list()
     for n_jobs in [1, 2]:
-        pipe = SlidingEstimator(
-            make_pipeline(Vectorizer(), logreg), n_jobs=n_jobs)
+        pipe = SlidingEstimator(make_pipeline(Vectorizer(), logreg), n_jobs=n_jobs)
         y_preds.append(pipe.fit(X, y).predict(X))
         features_shape = pipe.estimators_[0].steps[0][1].features_shape_
         assert_array_equal(features_shape, [3, 4])
@@ -164,7 +164,7 @@ def transform(self, X):
         pipe = SlidingEstimator(BaggingClassifier(None, 2), n_jobs=n_jobs)
         pipe.fit(X, y)
         pipe.score(X, y)
-    assert (isinstance(pipe.estimators_[0], BaggingClassifier))
+    assert isinstance(pipe.estimators_[0], BaggingClassifier)


 @requires_sklearn
@@ -174,24 +174,23 @@ def test_generalization_light():
     from sklearn.linear_model import LogisticRegression
     from sklearn.metrics import roc_auc_score

-    logreg = LogisticRegression(solver='liblinear', multi_class='ovr',
-                                random_state=0)
+    logreg = LogisticRegression(solver="liblinear", multi_class="ovr", random_state=0)

     X, y = make_data()
     n_epochs, _, n_time = X.shape
     # fit
     gl = GeneralizingEstimator(logreg)
-    assert_equal(repr(gl)[:23], '<GeneralizingEstimator(')
+    assert_equal(gl.__repr__()[-28:], ", fitted with 10 estimators>")

     # transforms
     y_pred = gl.predict(X)
     assert_array_equal(y_pred.shape, [n_epochs, n_time, n_time])
-    assert (y_pred.dtype == int)
+    assert y_pred.dtype == int
     y_proba = gl.predict_proba(X)
-    assert (y_proba.dtype == float)
+    assert y_proba.dtype == float
     assert_array_equal(y_proba.shape, [n_epochs, n_time, n_time, 2])

     # transform to different datasize
@@ -201,23 +200,23 @@ def test_generalization_light():
     # score
     score = gl.score(X[:, :, :3], y)
     assert_array_equal(score.shape, [n_time, 3])
-    assert (np.sum(np.abs(score)) != 0)
-    assert (score.dtype == float)
+    assert np.sum(np.abs(score)) != 0
+    assert score.dtype == float

-    gl = GeneralizingEstimator(logreg, scoring='roc_auc')
+    gl = GeneralizingEstimator(logreg, scoring="roc_auc")
     gl.fit(X, y)
     score = gl.score(X, y)
     auc = roc_auc_score(y, gl.estimators_[0].predict_proba(X[..., 0])[..., 
1]) assert_equal(score[0, 0], auc) - for scoring in ['foo', 999]: + for scoring in ["foo", 999]: gl = GeneralizingEstimator(logreg, scoring=scoring) gl.fit(X, y) pytest.raises((ValueError, TypeError), gl.score, X, y) # Check sklearn's roc_auc fix: scikit-learn/scikit-learn#6874 # -- 3 class problem - gl = GeneralizingEstimator(logreg, scoring='roc_auc') + gl = GeneralizingEstimator(logreg, scoring="roc_auc") y = np.arange(len(X)) % 3 gl.fit(X, y) pytest.raises(ValueError, gl.score, X, y) @@ -225,8 +224,10 @@ def test_generalization_light(): y = np.arange(len(X)) % 2 + 1 gl.fit(X, y) score = gl.score(X, y) - manual_score = [[roc_auc_score(y - 1, _y_pred) for _y_pred in _y_preds] - for _y_preds in gl.decision_function(X).transpose(1, 2, 0)] + manual_score = [ + [roc_auc_score(y - 1, _y_pred) for _y_pred in _y_preds] + for _y_preds in gl.decision_function(X).transpose(1, 2, 0) + ] assert_array_equal(score, manual_score) # n_jobs @@ -246,8 +247,7 @@ def test_generalization_light(): y = np.arange(10) % 2 y_preds = list() for n_jobs in [1, 2]: - pipe = GeneralizingEstimator( - make_pipeline(Vectorizer(), logreg), n_jobs=n_jobs) + pipe = GeneralizingEstimator(make_pipeline(Vectorizer(), logreg), n_jobs=n_jobs) y_preds.append(pipe.fit(X, y).predict(X)) features_shape = pipe.estimators_[0].steps[0][1].features_shape_ assert_array_equal(features_shape, [3, 4]) @@ -255,8 +255,9 @@ def test_generalization_light(): @requires_sklearn -@pytest.mark.parametrize('n_jobs, verbose', - [(1, False), (2, False), (1, True), (2, 'info')]) +@pytest.mark.parametrize( + "n_jobs, verbose", [(1, False), (2, False), (1, True), (2, "info")] +) def test_verbose_arg(capsys, n_jobs, verbose): """Test controlling output with the ``verbose`` argument.""" from sklearn.svm import SVC @@ -267,15 +268,14 @@ def test_verbose_arg(capsys, n_jobs, verbose): # shows progress bar and prints other messages to the console with use_log_level(True): for estimator_object in [SlidingEstimator, GeneralizingEstimator]: - estimator = estimator_object( - clf, n_jobs=n_jobs, verbose=verbose) + estimator = estimator_object(clf, n_jobs=n_jobs, verbose=verbose) estimator = estimator.fit(X, y) estimator.score(X, y) estimator.predict(X) stdout, stderr = capsys.readouterr() if isinstance(verbose, bool) and not verbose: - assert all(channel == '' for channel in (stdout, stderr)) + assert all(channel == "" for channel in (stdout, stderr)) else: assert any(len(channel) > 0 for channel in (stdout, stderr)) @@ -287,6 +287,7 @@ def test_cross_val_predict(): from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.base import BaseEstimator, clone from sklearn.model_selection import cross_val_predict + rng = np.random.RandomState(42) X = rng.randn(10, 1, 3) y = rng.randint(0, 2, 10) @@ -309,7 +310,7 @@ def predict_proba(self, X): with pytest.raises(AttributeError, match="classes_ attribute"): estimator = SlidingEstimator(Classifier()) - cross_val_predict(estimator, X, y, method='predict_proba', cv=2) + cross_val_predict(estimator, X, y, method="predict_proba", cv=2) estimator = SlidingEstimator(LinearDiscriminantAnalysis()) - cross_val_predict(estimator, X, y, method='predict_proba', cv=2) + cross_val_predict(estimator, X, y, method="predict_proba", cv=2) diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 8ba7657b660..4f674242fd8 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from numpy.testing import 
(assert_array_almost_equal, assert_array_equal)
+from numpy.testing import assert_array_almost_equal, assert_array_equal
 from mne import io
 from mne.time_frequency import psd_array_welch
 from mne.decoding.ssd import SSD
@@ -18,9 +18,16 @@
 freqs_noise = 8, 13


-def simulate_data(freqs_sig=[9, 12], n_trials=100, n_channels=20,
-                  n_samples=500, samples_per_second=250,
-                  n_components=5, SNR=0.05, random_state=42):
+def simulate_data(
+    freqs_sig=[9, 12],
+    n_trials=100,
+    n_channels=20,
+    n_samples=500,
+    samples_per_second=250,
+    n_components=5,
+    SNR=0.05,
+    random_state=42,
+):
     """Simulate data according to an instantaneous mixing model.

     Data are simulated in the statistical source space, where n=n_components
@@ -28,9 +35,13 @@ def simulate_data(freqs_sig=[9, 12], n_trials=100, n_channels=20,
     """
     rng = np.random.RandomState(random_state)

-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1,
-                              fir_design='firwin')
+    filt_params_signal = dict(
+        l_freq=freqs_sig[0],
+        h_freq=freqs_sig[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+        fir_design="firwin",
+    )

     # generate an orthogonal mixing matrix
     mixing_mat = np.linalg.svd(rng.randn(n_channels, n_channels))[0]
@@ -44,8 +55,8 @@ def simulate_data(freqs_sig=[9, 12], n_trials=100, n_channels=20,
     X_s = np.dot(mixing_mat[:, :n_components], S_s.T).T
     X_n = np.dot(mixing_mat[:, n_components:], S_n.T).T
     # add noise
-    X_s = X_s / np.linalg.norm(X_s, 'fro')
-    X_n = X_n / np.linalg.norm(X_n, 'fro')
+    X_s = X_s / np.linalg.norm(X_s, "fro")
+    X_n = X_n / np.linalg.norm(X_n, "fro")
     X = SNR * X_s + (1 - SNR) * X_n
     X = X.T
     S = S.T
@@ -58,75 +69,98 @@ def test_ssd():
     X, A, S = simulate_data()
     sf = 250
     n_channels = X.shape[0]
-    info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg')
+    info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg")
     n_components_true = 5

     # Init
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=1, h_trans_bandwidth=1)
+    filt_params_signal = dict(
+        l_freq=freqs_sig[0],
+        h_freq=freqs_sig[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
+    filt_params_noise = dict(
+        l_freq=freqs_noise[0],
+        h_freq=freqs_noise[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
     ssd = SSD(info, filt_params_signal, filt_params_noise)
     # freq no int
-    freq = 'foo'
-    filt_params_signal = dict(l_freq=freq, h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=1, h_trans_bandwidth=1)
-    with pytest.raises(TypeError, match='must be an instance '):
+    freq = "foo"
+    filt_params_signal = dict(
+        l_freq=freq, h_freq=freqs_sig[1], l_trans_bandwidth=1, h_trans_bandwidth=1
+    )
+    filt_params_noise = dict(
+        l_freq=freqs_noise[0],
+        h_freq=freqs_noise[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
+    with pytest.raises(TypeError, match="must be an instance "):
         ssd = SSD(info, filt_params_signal, filt_params_noise)
     # Wrongly specified noise band
     freq = 2
-    filt_params_signal = dict(l_freq=freq, h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=1, h_trans_bandwidth=1)
-    with pytest.raises(ValueError, match='Wrongly specified '):
+    filt_params_signal = dict(
+        l_freq=freq, h_freq=freqs_sig[1], l_trans_bandwidth=1, 
h_trans_bandwidth=1 + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + with pytest.raises(ValueError, match="Wrongly specified "): ssd = SSD(info, filt_params_signal, filt_params_noise) # filt param no dict filt_params_signal = freqs_sig filt_params_noise = freqs_noise - with pytest.raises(ValueError, match='must be defined'): + with pytest.raises(ValueError, match="must be defined"): ssd = SSD(info, filt_params_signal, filt_params_noise) # Data type - filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) ssd = SSD(info, filt_params_signal, filt_params_noise) raw = io.RawArray(X, info) pytest.raises(TypeError, ssd.fit, raw) # check non-boolean return_filtered - with pytest.raises(ValueError, match='return_filtered'): - ssd = SSD(info, filt_params_signal, filt_params_noise, - return_filtered=0) + with pytest.raises(ValueError, match="return_filtered"): + ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) # check non-boolean sort_by_spectral_ratio - with pytest.raises(ValueError, match='sort_by_spectral_ratio'): - ssd = SSD(info, filt_params_signal, filt_params_noise, - sort_by_spectral_ratio=0) + with pytest.raises(ValueError, match="sort_by_spectral_ratio"): + ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) # More than 1 channel type - ch_types = np.reshape([['mag'] * 10, ['eeg'] * 10], n_channels) + ch_types = np.reshape([["mag"] * 10, ["eeg"] * 10], n_channels) info_2 = create_info(ch_names=n_channels, sfreq=sf, ch_types=ch_types) - with pytest.raises(ValueError, match='At this point SSD'): + with pytest.raises(ValueError, match="At this point SSD"): ssd = SSD(info_2, filt_params_signal, filt_params_noise) # Number of channels - info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types='eeg') + info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types="eeg") ssd = SSD(info_3, filt_params_signal, filt_params_noise) pytest.raises(ValueError, ssd.fit, X) # Fit n_components = 10 - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=n_components) + ssd = SSD(info, filt_params_signal, filt_params_noise, n_components=n_components) # Call transform before fit pytest.raises(AttributeError, ssd.transform, X) @@ -134,28 +168,43 @@ def test_ssd(): # Check outputs ssd.fit(X) - assert (ssd.filters_.shape == (n_channels, n_channels)) - assert (ssd.patterns_.shape == (n_channels, n_channels)) + assert ssd.filters_.shape == (n_channels, n_channels) + assert ssd.patterns_.shape == (n_channels, n_channels) # Transform X_ssd = ssd.fit_transform(X) - assert (X_ssd.shape[0] == n_components) + assert X_ssd.shape[0] == n_components # back and forward - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=None, sort_by_spectral_ratio=False) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=None, + sort_by_spectral_ratio=False, + ) ssd.fit(X) X_denoised = ssd.apply(X) assert_array_almost_equal(X_denoised, X) # denoised by low-rank-factorization - ssd = SSD(info, 
filt_params_signal, filt_params_noise,
-              n_components=n_components, sort_by_spectral_ratio=True)
+    ssd = SSD(
+        info,
+        filt_params_signal,
+        filt_params_noise,
+        n_components=n_components,
+        sort_by_spectral_ratio=True,
+    )
     ssd.fit(X)
     X_denoised = ssd.apply(X)
-    assert (np.linalg.matrix_rank(X_denoised) == n_components)
+    assert np.linalg.matrix_rank(X_denoised) == n_components

     # Power ratio ordering
-    ssd = SSD(info, filt_params_signal, filt_params_noise,
-              n_components=None, sort_by_spectral_ratio=False)
+    ssd = SSD(
+        info,
+        filt_params_signal,
+        filt_params_noise,
+        n_components=None,
+        sort_by_spectral_ratio=False,
+    )
     ssd.fit(X)
     spec_ratio, sorter_spec = ssd.get_spectral_ratio(ssd.transform(X))
     # since we know that the number of true components is 5, the relative
@@ -165,12 +214,25 @@
     # Check detected peaks
     # fit ssd
     n_components = n_components_true
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=1, h_trans_bandwidth=1)
-    ssd = SSD(info, filt_params_signal, filt_params_noise,
-              n_components=n_components, sort_by_spectral_ratio=False)
+    filt_params_signal = dict(
+        l_freq=freqs_sig[0],
+        h_freq=freqs_sig[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
+    filt_params_noise = dict(
+        l_freq=freqs_noise[0],
+        h_freq=freqs_noise[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
+    ssd = SSD(
+        info,
+        filt_params_signal,
+        filt_params_noise,
+        n_components=n_components,
+        sort_by_spectral_ratio=False,
+    )
     ssd.fit(X)

     out = ssd.transform(X)
@@ -197,7 +259,7 @@ def test_ssd_epoched_data():
     X, A, S = simulate_data(n_trials=100, n_channels=20, n_samples=500)
     sf = 250
     n_channels = X.shape[0]
-    info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg')
+    info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg")
     n_components_true = 5

     # Build epochs as sliding windows over the continuous raw file
@@ -206,10 +268,18 @@
     X_e = np.reshape(X, (100, 20, 500))

     # Fit
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=4, h_trans_bandwidth=4)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=4, h_trans_bandwidth=4)
+    filt_params_signal = dict(
+        l_freq=freqs_sig[0],
+        h_freq=freqs_sig[1],
+        l_trans_bandwidth=4,
+        h_trans_bandwidth=4,
+    )
+    filt_params_noise = dict(
+        l_freq=freqs_noise[0],
+        h_freq=freqs_noise[1],
+        l_trans_bandwidth=4,
+        h_trans_bandwidth=4,
+    )

     # ssd on epochs
     ssd_e = SSD(info, filt_params_signal, filt_params_noise)
@@ -221,34 +291,44 @@

     # Check if the first 5 components are the same for both
     _, sorter_spec_e = ssd_e.get_spectral_ratio(ssd_e.transform(X_e))
     _, sorter_spec = ssd.get_spectral_ratio(ssd.transform(X))
-    assert_array_equal(sorter_spec_e[:n_components_true],
-                       sorter_spec[:n_components_true])
+    assert_array_equal(
+        sorter_spec_e[:n_components_true], sorter_spec[:n_components_true]
+    )


 @requires_sklearn
 def test_ssd_pipeline():
     """Test if SSD works in a pipeline."""
     from sklearn.pipeline import Pipeline
+
     sf = 250
     X, A, S = simulate_data(n_trials=100, n_channels=20, n_samples=500)
     X_e = np.reshape(X, (100, 20, 500))
     # define binary random output
     y = np.random.randint(2, size=100)
-    info = create_info(ch_names=20, sfreq=sf, ch_types='eeg')
-
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=4, 
h_trans_bandwidth=4) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=4, h_trans_bandwidth=4) + info = create_info(ch_names=20, sfreq=sf, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=4, + h_trans_bandwidth=4, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=4, + h_trans_bandwidth=4, + ) ssd = SSD(info, filt_params_signal, filt_params_noise) csp = CSP() - pipe = Pipeline([('SSD', ssd), ('CSP', csp)]) + pipe = Pipeline([("SSD", ssd), ("CSP", csp)]) pipe.set_params(SSD__n_components=5) pipe.set_params(CSP__n_components=2) out = pipe.fit_transform(X_e, y) - assert (out.shape == (100, 2)) - assert (pipe.get_params()['SSD__n_components'] == 5) + assert out.shape == (100, 2) + assert pipe.get_params()["SSD__n_components"] == 5 def test_sorting(): @@ -260,30 +340,53 @@ def test_sorting(): Xtr, Xte = X[:80], X[80:] sf = 250 n_channels = Xtr.shape[1] - info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg') - - filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=4, h_trans_bandwidth=4) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=4, h_trans_bandwidth=4) + info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=4, + h_trans_bandwidth=4, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=4, + h_trans_bandwidth=4, + ) # check sort_by_spectral_ratio set to False - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=None, sort_by_spectral_ratio=False) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=None, + sort_by_spectral_ratio=False, + ) ssd.fit(Xtr) _, sorter_tr = ssd.get_spectral_ratio(ssd.transform(Xtr)) _, sorter_te = ssd.get_spectral_ratio(ssd.transform(Xte)) assert any(sorter_tr != sorter_te) # check sort_by_spectral_ratio set to True - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=None, sort_by_spectral_ratio=True) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=None, + sort_by_spectral_ratio=True, + ) ssd.fit(Xtr) # check sorters sorter_in = ssd.sorter_spec - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=None, sort_by_spectral_ratio=False) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=None, + sort_by_spectral_ratio=False, + ) ssd.fit(Xtr) _, sorter_out = ssd.get_spectral_ratio(ssd.transform(Xtr)) @@ -297,44 +400,70 @@ def test_return_filtered(): X, _, _ = simulate_data(SNR=0.9, freqs_sig=[4, 13]) sf = 250 n_channels = X.shape[0] - info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg') - - filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) + info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) # return filtered to true - ssd = SSD(info, filt_params_signal, filt_params_noise, - 
sort_by_spectral_ratio=False, return_filtered=True) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + sort_by_spectral_ratio=False, + return_filtered=True, + ) ssd.fit(X) out = ssd.transform(X) psd_out, freqs = psd_array_welch(out[0], sfreq=250, n_fft=250) freqs_up = int(freqs[psd_out > 0.5][0]), int(freqs[psd_out > 0.5][-1]) - assert (freqs_up == freqs_sig) + assert freqs_up == freqs_sig # return filtered to false - ssd = SSD(info, filt_params_signal, filt_params_noise, - sort_by_spectral_ratio=False, return_filtered=False) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + sort_by_spectral_ratio=False, + return_filtered=False, + ) ssd.fit(X) out = ssd.transform(X) psd_out, freqs = psd_array_welch(out[0], sfreq=250, n_fft=250) freqs_up = int(freqs[psd_out > 0.5][0]), int(freqs[psd_out > 0.5][-1]) - assert (freqs_up != freqs_sig) + assert freqs_up != freqs_sig def test_non_full_rank_data(): """Test that the method works with non-full rank data.""" n_channels = 10 X, _, _ = simulate_data(SNR=0.9, freqs_sig=[4, 13], n_channels=n_channels) - info = create_info(ch_names=n_channels, sfreq=250, ch_types='eeg') - - filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) + info = create_info(ch_names=n_channels, sfreq=250, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) # Make data non-full rank rank = 5 diff --git a/mne/decoding/tests/test_time_frequency.py b/mne/decoding/tests/test_time_frequency.py index 5fea1402e68..8d92c8fd72e 100644 --- a/mne/decoding/tests/test_time_frequency.py +++ b/mne/decoding/tests/test_time_frequency.py @@ -15,17 +15,18 @@ def test_timefrequency(): """Test TimeFrequency.""" from sklearn.base import clone + # Init n_freqs = 3 freqs = [20, 21, 22] tf = TimeFrequency(freqs, sfreq=100) - for output in ['avg_power', 'foo', None]: + for output in ["avg_power", "foo", None]: pytest.raises(ValueError, TimeFrequency, freqs, output=output) tf = clone(tf) # Clone estimator freqs_array = np.array(np.asarray(freqs)) - tf = TimeFrequency(freqs_array, 100, "morlet", freqs_array / 5.) 
+ tf = TimeFrequency(freqs_array, 100, "morlet", freqs_array / 5.0) clone(tf) # Fit diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py index 3c53d7e2ca1..1884f926862 100644 --- a/mne/decoding/tests/test_transformer.py +++ b/mne/decoding/tests/test_transformer.py @@ -8,12 +8,22 @@ import numpy as np import pytest -from numpy.testing import (assert_array_equal, assert_array_almost_equal, - assert_allclose, assert_equal) +from numpy.testing import ( + assert_array_equal, + assert_array_almost_equal, + assert_allclose, + assert_equal, +) from mne import io, read_events, Epochs, pick_types -from mne.decoding import (Scaler, FilterEstimator, PSDEstimator, Vectorizer, - UnsupervisedSpatialFilter, TemporalFilter) +from mne.decoding import ( + Scaler, + FilterEstimator, + PSDEstimator, + Vectorizer, + UnsupervisedSpatialFilter, + TemporalFilter, +) from mne.defaults import DEFAULTS from mne.utils import requires_sklearn, check_version, use_log_level @@ -25,29 +35,34 @@ event_name = data_dir / "test-eve.fif" -@pytest.mark.parametrize('info, method', [ - (True, None), - (True, dict(mag=5, grad=10, eeg=20)), - (False, 'mean'), - (False, 'median'), -]) +@pytest.mark.parametrize( + "info, method", + [ + (True, None), + (True, dict(mag=5, grad=10, eeg=20)), + (False, "mean"), + (False, "median"), + ], +) def test_scaler(info, method): """Test methods of Scaler.""" raw = io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) epochs_data = epochs.get_data() y = epochs.events[:, -1] epochs_data_t = epochs_data.transpose([1, 0, 2]) - if method in ('mean', 'median'): - if not check_version('sklearn'): - with pytest.raises(ImportError, match='No module'): + if method in ("mean", "median"): + if not check_version("sklearn"): + with pytest.raises(ImportError, match="No module"): Scaler(info, method) return @@ -57,22 +72,28 @@ def test_scaler(info, method): X = scaler.fit_transform(epochs_data, y) assert_equal(X.shape, epochs_data.shape) if method is None or isinstance(method, dict): - sd = DEFAULTS['scalings'] if method is None else method + sd = DEFAULTS["scalings"] if method is None else method stds = np.zeros(len(picks)) - for key in ('mag', 'grad'): - stds[pick_types(epochs.info, meg=key)] = 1. / sd[key] - stds[pick_types(epochs.info, meg=False, eeg=True)] = 1. 
/ sd['eeg'] + for key in ("mag", "grad"): + stds[pick_types(epochs.info, meg=key)] = 1.0 / sd[key] + stds[pick_types(epochs.info, meg=False, eeg=True)] = 1.0 / sd["eeg"] means = np.zeros(len(epochs.ch_names)) - elif method == 'mean': + elif method == "mean": stds = np.array([np.std(ch_data) for ch_data in epochs_data_t]) means = np.array([np.mean(ch_data) for ch_data in epochs_data_t]) else: # median - percs = np.array([np.percentile(ch_data, [25, 50, 75]) - for ch_data in epochs_data_t]) + percs = np.array( + [np.percentile(ch_data, [25, 50, 75]) for ch_data in epochs_data_t] + ) stds = percs[:, 2] - percs[:, 0] means = percs[:, 1] - assert_allclose(X * stds[:, np.newaxis] + means[:, np.newaxis], - epochs_data, rtol=1e-12, atol=1e-20, err_msg=method) + assert_allclose( + X * stds[:, np.newaxis] + means[:, np.newaxis], + epochs_data, + rtol=1e-12, + atol=1e-20, + err_msg=method, + ) X2 = scaler.fit(epochs_data, y).transform(epochs_data) assert_array_equal(X, X2) @@ -85,8 +106,15 @@ def test_scaler(info, method): pytest.raises(ValueError, Scaler, None, None) pytest.raises(TypeError, scaler.fit, epochs, y) pytest.raises(TypeError, scaler.transform, epochs) - epochs_bad = Epochs(raw, events, event_id, 0, 0.01, baseline=None, - picks=np.arange(len(raw.ch_names))) # non-data chs + epochs_bad = Epochs( + raw, + events, + event_id, + 0, + 0.01, + baseline=None, + picks=np.arange(len(raw.ch_names)), + ) # non-data chs scaler = Scaler(epochs_bad.info, None) pytest.raises(ValueError, scaler.fit, epochs_bad.get_data(), y) @@ -95,34 +123,46 @@ def test_filterestimator(): """Test methods of FilterEstimator.""" raw = io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) epochs_data = epochs.get_data() # Add tests for different combinations of l_freq and h_freq filt = FilterEstimator(epochs.info, l_freq=40, h_freq=80) y = epochs.events[:, -1] X = filt.fit_transform(epochs_data, y) - assert (X.shape == epochs_data.shape) + assert X.shape == epochs_data.shape assert_array_equal(filt.fit(epochs_data, y).transform(epochs_data), X) - filt = FilterEstimator(epochs.info, l_freq=None, h_freq=40, - filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto') + filt = FilterEstimator( + epochs.info, + l_freq=None, + h_freq=40, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + ) y = epochs.events[:, -1] X = filt.fit_transform(epochs_data, y) filt = FilterEstimator(epochs.info, l_freq=1, h_freq=1) y = epochs.events[:, -1] - with pytest.warns(RuntimeWarning, match='longer than the signal'): + with pytest.warns(RuntimeWarning, match="longer than the signal"): pytest.raises(ValueError, filt.fit_transform, epochs_data, y) - filt = FilterEstimator(epochs.info, l_freq=40, h_freq=None, - filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto') + filt = FilterEstimator( + epochs.info, + l_freq=40, + h_freq=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + ) X = filt.fit_transform(epochs_data, y) # Test init exception @@ -134,17 +174,19 @@ def test_psdestimator(): """Test methods of PSDEstimator.""" raw = 
io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) epochs_data = epochs.get_data() psd = PSDEstimator(2 * np.pi, 0, np.inf) y = epochs.events[:, -1] X = psd.fit_transform(epochs_data, y) - assert (X.shape[0] == epochs_data.shape[0]) + assert X.shape[0] == epochs_data.shape[0] assert_array_equal(psd.fit(epochs_data, y).transform(epochs_data), X) # Test init exception @@ -166,15 +208,13 @@ def test_vectorizer(): assert_array_equal(vect.inverse_transform(result[1:]), data[1:]) # check with different shape - assert_equal(vect.fit_transform(np.random.rand(150, 18, 6, 3)).shape, - (150, 324)) + assert_equal(vect.fit_transform(np.random.rand(150, 18, 6, 3)).shape, (150, 324)) assert_equal(vect.fit_transform(data[1:]).shape, (149, 108)) # check if raised errors are working correctly vect.fit(np.random.rand(105, 12, 3)) pytest.raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1)) - pytest.raises(ValueError, vect.inverse_transform, - np.random.rand(102, 12, 12)) + pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12)) @requires_sklearn @@ -182,13 +222,24 @@ def test_unsupervised_spatial_filter(): """Test unsupervised spatial filter.""" from sklearn.decomposition import PCA from sklearn.kernel_ridge import KernelRidge + raw = io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - preload=True, baseline=None, verbose=False) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + preload=True, + baseline=None, + verbose=False, + ) # Test estimator pytest.raises(ValueError, UnsupervisedSpatialFilter, KernelRidge(2)) @@ -218,34 +269,39 @@ def test_temporal_filter(): X = np.random.rand(5, 5, 1200) # Test init test - values = (('10hz', None, 100., 'auto'), (5., '10hz', 100., 'auto'), - (10., 20., 5., 'auto'), (None, None, 100., '5hz')) + values = ( + ("10hz", None, 100.0, "auto"), + (5.0, "10hz", 100.0, "auto"), + (10.0, 20.0, 5.0, "auto"), + (None, None, 100.0, "5hz"), + ) for low, high, sf, ltrans in values: - filt = TemporalFilter(low, high, sf, ltrans, fir_design='firwin') + filt = TemporalFilter(low, high, sf, ltrans, fir_design="firwin") pytest.raises(ValueError, filt.fit_transform, X) # Add tests for different combinations of l_freq and h_freq - for low, high in ((5., 15.), (None, 15.), (5., None)): - filt = TemporalFilter(low, high, sfreq=100., fir_design='firwin') + for low, high in ((5.0, 15.0), (None, 15.0), (5.0, None)): + filt = TemporalFilter(low, high, sfreq=100.0, fir_design="firwin") Xt = filt.fit_transform(X) assert_array_equal(filt.fit_transform(X), Xt) - assert (X.shape == Xt.shape) + assert X.shape == Xt.shape # Test fit and transform numpy type check - with pytest.raises(ValueError, match='Data to be filtered must be'): + with pytest.raises(ValueError, match="Data to be filtered must be"): filt.transform([1, 2]) # Test with 
2 dimensional data array X = np.random.rand(101, 500) - filt = TemporalFilter(l_freq=25., h_freq=50., sfreq=1000., - filter_length=150, fir_design='firwin2') - with use_log_level('error'): # warning about transition bandwidth + filt = TemporalFilter( + l_freq=25.0, h_freq=50.0, sfreq=1000.0, filter_length=150, fir_design="firwin2" + ) + with use_log_level("error"): # warning about transition bandwidth assert_equal(filt.fit_transform(X).shape, X.shape) def test_bad_triage(): """Test for gh-10924.""" - filt = TemporalFilter(l_freq=8, h_freq=60, sfreq=160.) + filt = TemporalFilter(l_freq=8, h_freq=60, sfreq=160.0) # Used to fail with "ValueError: Effective band-stop frequency (135.0) is # too high (maximum based on Nyquist is 80.0)" filt.fit_transform(np.zeros((1, 1, 481))) diff --git a/mne/decoding/time_delaying_ridge.py b/mne/decoding/time_delaying_ridge.py index 2d3d13f1300..2299aa5d861 100644 --- a/mne/decoding/time_delaying_ridge.py +++ b/mne/decoding/time_delaying_ridge.py @@ -13,8 +13,9 @@ from ..utils import warn, ProgressBar, logger -def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, - edge_correction=True): +def _compute_corrs( + X, y, smin, smax, n_jobs=None, fit_intercept=False, edge_correction=True +): """Compute auto- and cross-correlations.""" if fit_intercept: # We could do this in the Fourier domain, too, but it should @@ -27,7 +28,7 @@ def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, X = X - X_offset y = y - y_offset else: - X_offset = y_offset = 0. + X_offset = y_offset = 0.0 if X.ndim == 2: assert y.ndim == 2 X = X[:, np.newaxis, :] @@ -41,7 +42,8 @@ def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, n_fft = next_fast_len(2 * X.shape[0] - 1) _, cuda_dict = _setup_cuda_fft_multiply_repeated( - n_jobs, [1.], n_fft, 'correlation calculations') + n_jobs, [1.0], n_fft, "correlation calculations" + ) del n_jobs # only used to set as CUDA # create our Toeplitz indexer @@ -49,26 +51,27 @@ def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, for ii in range(len_trf): ij[ii, ii:] = np.arange(len_trf - ii) x = np.arange(n_fft - 1, n_fft - len_trf + ii, -1) - ij[ii + 1:, ii] = x + ij[ii + 1 :, ii] = x x_xt = np.zeros([n_ch_x * len_trf] * 2) - x_y = np.zeros((len_trf, n_ch_x, n_ch_y), order='F') + x_y = np.zeros((len_trf, n_ch_x, n_ch_y), order="F") n = n_epochs * (n_ch_x * (n_ch_x + 1) // 2 + n_ch_x) - logger.info('Fitting %d epochs, %d channels' % (n_epochs, n_ch_x)) - pb = ProgressBar(n, mesg='Sample') + logger.info("Fitting %d epochs, %d channels" % (n_epochs, n_ch_x)) + pb = ProgressBar(n, mesg="Sample") count = 0 pb.update(count) for ei in range(n_epochs): this_X = X[:, ei, :] # XXX maybe this is what we should parallelize over CPUs at some point - X_fft = cuda_dict['rfft'](this_X, n=n_fft, axis=0) + X_fft = cuda_dict["rfft"](this_X, n=n_fft, axis=0) X_fft_conj = X_fft.conj() - y_fft = cuda_dict['rfft'](y[:, ei, :], n=n_fft, axis=0) + y_fft = cuda_dict["rfft"](y[:, ei, :], n=n_fft, axis=0) for ch0 in range(n_ch_x): for oi, ch1 in enumerate(range(ch0, n_ch_x)): - this_result = cuda_dict['irfft']( - X_fft[:, ch0] * X_fft_conj[:, ch1], n=n_fft, axis=0) + this_result = cuda_dict["irfft"]( + X_fft[:, ch0] * X_fft_conj[:, ch1], n=n_fft, axis=0 + ) # Our autocorrelation structure is a Toeplitz matrix, but # it's faster to create the Toeplitz ourselves than use # linalg.toeplitz. 
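The Toeplitz shortcut described in the comment above can be checked in isolation: with zero padding to at least 2 * n_times - 1, the FFT-based circular correlation equals the direct lagged dot products, and the ``ij`` indexer then scatters those lag values into the Toeplitz layout that ``scipy.linalg.toeplitz`` would otherwise build explicitly. A minimal, self-contained sketch of that equivalence (the sizes ``n_times`` and ``len_trf`` below are illustrative toys, not values taken from the code):

import numpy as np
from scipy.fft import irfft, next_fast_len, rfft
from scipy.linalg import toeplitz

rng = np.random.default_rng(0)
n_times, len_trf = 64, 4  # toy sizes
x = rng.standard_normal(n_times)

# FFT-based autocorrelation, as in the hunk above: pad, multiply by the
# conjugate spectrum, and transform back
n_fft = next_fast_len(2 * n_times - 1)
xf = rfft(x, n=n_fft)
acorr = irfft(xf * xf.conj(), n=n_fft)

# Direct autocorrelation for lags 0 .. len_trf - 1
direct = np.array([np.dot(x[k:], x[: n_times - k]) for k in range(len_trf)])
assert np.allclose(acorr[:len_trf], direct)

# Up to the edge terms that _edge_correct removes, the Gram matrix of the
# delayed copies of x is this lag vector laid out as a Toeplitz matrix,
# which is what the ``ij`` indexer assembles in place
print(toeplitz(acorr[:len_trf]))
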
@@ -85,40 +88,43 @@ def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, _edge_correct(this_result, this_X, smax, smin, ch0, ch1) # Store the results in our output matrix - x_xt[ch0 * len_trf:(ch0 + 1) * len_trf, - ch1 * len_trf:(ch1 + 1) * len_trf] += this_result + x_xt[ + ch0 * len_trf : (ch0 + 1) * len_trf, + ch1 * len_trf : (ch1 + 1) * len_trf, + ] += this_result if ch0 != ch1: - x_xt[ch1 * len_trf:(ch1 + 1) * len_trf, - ch0 * len_trf:(ch0 + 1) * len_trf] += this_result.T + x_xt[ + ch1 * len_trf : (ch1 + 1) * len_trf, + ch0 * len_trf : (ch0 + 1) * len_trf, + ] += this_result.T count += 1 pb.update(count) # compute the crosscorrelations - cc_temp = cuda_dict['irfft']( - y_fft * X_fft_conj[:, slice(ch0, ch0 + 1)], n=n_fft, axis=0) + cc_temp = cuda_dict["irfft"]( + y_fft * X_fft_conj[:, slice(ch0, ch0 + 1)], n=n_fft, axis=0 + ) if smin < 0 and smax >= 0: x_y[:-smin, ch0] += cc_temp[smin:] - x_y[len_trf - smax:, ch0] += cc_temp[:smax] + x_y[len_trf - smax :, ch0] += cc_temp[:smax] else: x_y[:, ch0] += cc_temp[smin:smax] count += 1 pb.update(count) - x_y = np.reshape(x_y, (n_ch_x * len_trf, n_ch_y), order='F') + x_y = np.reshape(x_y, (n_ch_x * len_trf, n_ch_y), order="F") return x_xt, x_y, n_ch_x, X_offset, y_offset @jit() def _edge_correct(this_result, this_X, smax, smin, ch0, ch1): if smax > 0: - tail = _toeplitz_dot(this_X[-1:-smax:-1, ch0], - this_X[-1:-smax:-1, ch1]) + tail = _toeplitz_dot(this_X[-1:-smax:-1, ch0], this_X[-1:-smax:-1, ch1]) if smin > 0: - tail = tail[smin - 1:, smin - 1:] - this_result[max(-smin + 1, 0):, max(-smin + 1, 0):] -= tail + tail = tail[smin - 1 :, smin - 1 :] + this_result[max(-smin + 1, 0) :, max(-smin + 1, 0) :] -= tail if smin < 0: - head = _toeplitz_dot(this_X[:-smin, ch0], - this_X[:-smin, ch1])[::-1, ::-1] + head = _toeplitz_dot(this_X[:-smin, ch0], this_X[:-smin, ch1])[::-1, ::-1] if smax < 0: head = head[:smax, :smax] this_result[:-smin, :-smin] -= head @@ -136,28 +142,28 @@ def _toeplitz_dot(a, b): assert a.shape == b.shape and a.ndim == 1 out = np.outer(a, b) for ii in range(1, len(a)): - out[ii, ii:] += out[ii - 1, ii - 1:-1] - out[ii + 1:, ii] += out[ii:-1, ii - 1] + out[ii, ii:] += out[ii - 1, ii - 1 : -1] + out[ii + 1 :, ii] += out[ii:-1, ii - 1] return out -def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method='direct', - normed=False): +def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method="direct", normed=False): """Compute regularization parameter from neighbors.""" from scipy import linalg from scipy.sparse.csgraph import laplacian - known_types = ('ridge', 'laplacian') + + known_types = ("ridge", "laplacian") if isinstance(reg_type, str): reg_type = (reg_type,) * 2 if len(reg_type) != 2: - raise ValueError('reg_type must have two elements, got %s' - % (len(reg_type),)) + raise ValueError("reg_type must have two elements, got %s" % (len(reg_type),)) for r in reg_type: if r not in known_types: - raise ValueError('reg_type entries must be one of %s, got %s' - % (known_types, r)) - reg_time = (reg_type[0] == 'laplacian' and n_delays > 1) - reg_chs = (reg_type[1] == 'laplacian' and n_ch_x > 1) + raise ValueError( + "reg_type entries must be one of %s, got %s" % (known_types, r) + ) + reg_time = reg_type[0] == "laplacian" and n_delays > 1 + reg_chs = reg_type[1] == "laplacian" and n_ch_x > 1 if not reg_time and not reg_chs: return np.eye(n_ch_x * n_delays) # regularize time @@ -166,7 +172,7 @@ def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method='direct', stride = n_delays + 1 reg.flat[1::stride] += -1 
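        # (These strided ``flat`` writes assemble the tridiagonal graph
        # Laplacian of a path over the delays without building it entry by
        # entry: ``flat[1::stride]`` hits the first superdiagonal,
        # ``flat[n_delays::stride]`` the first subdiagonal, and the third
        # write adds +1 to the interior diagonal entries, so interior rows
        # read (-1, 2, -1) while the two endpoint rows keep a diagonal of 1.)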
reg.flat[n_delays::stride] += -1 - reg.flat[n_delays + 1:-n_delays - 1:stride] += 1 + reg.flat[n_delays + 1 : -n_delays - 1 : stride] += 1 args = [reg] * n_ch_x reg = linalg.block_diag(*args) else: @@ -178,12 +184,12 @@ def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method='direct', row_offset = block * n_ch_x stride = n_delays * n_ch_x + 1 reg.flat[n_delays:-row_offset:stride] += -1 - reg.flat[n_delays + row_offset::stride] += 1 + reg.flat[n_delays + row_offset :: stride] += 1 reg.flat[row_offset:-n_delays:stride] += -1 - reg.flat[:-(n_delays + row_offset):stride] += 1 + reg.flat[: -(n_delays + row_offset) : stride] += 1 assert np.array_equal(reg[::-1, ::-1], reg) - if method == 'direct': + if method == "direct": if normed: norm = np.sqrt(np.diag(reg)) reg /= norm @@ -201,6 +207,7 @@ def _fit_corrs(x_xt, x_y, n_ch_x, reg_type, alpha, n_ch_in): """Fit the model using correlation matrices.""" # do the regularized solving from scipy import linalg + n_ch_out = x_y.shape[1] assert x_y.shape[0] % n_ch_x == 0 n_delays = x_y.shape[0] // n_ch_x @@ -211,11 +218,13 @@ def _fit_corrs(x_xt, x_y, n_ch_x, reg_type, alpha, n_ch_in): # Note: we must use overwrite_a=False in order to be able to # use the fall-back solution below in case a LinAlgError # is raised - w = linalg.solve(mat, x_y, overwrite_a=False, assume_a='pos') + w = linalg.solve(mat, x_y, overwrite_a=False, assume_a="pos") except np.linalg.LinAlgError: - warn('Singular matrix in solving dual problem. Using ' - 'least-squares solution instead.') - w = linalg.lstsq(mat, x_y, lapack_driver='gelsy')[0] + warn( + "Singular matrix in solving dual problem. Using " + "least-squares solution instead." + ) + w = linalg.lstsq(mat, x_y, lapack_driver="gelsy")[0] w = w.T.reshape([n_ch_out, n_ch_in, n_delays]) return w @@ -270,11 +279,19 @@ class TimeDelayingRidge(BaseEstimator): _estimator_type = "regressor" - def __init__(self, tmin, tmax, sfreq, alpha=0., reg_type='ridge', - fit_intercept=True, n_jobs=None, edge_correction=True): + def __init__( + self, + tmin, + tmax, + sfreq, + alpha=0.0, + reg_type="ridge", + fit_intercept=True, + n_jobs=None, + edge_correction=True, + ): if tmin > tmax: - raise ValueError('tmin must be <= tmax, got %s and %s' - % (tmin, tmax)) + raise ValueError("tmin must be <= tmax, got %s and %s" % (tmin, tmax)) self.tmin = float(tmin) self.tmax = float(tmax) self.sfreq = float(sfreq) @@ -317,15 +334,22 @@ def fit(self, X, y): # might want to allow people to do them separately (e.g., to test # different regularization parameters). self.cov_, x_y_, n_ch_x, X_offset, y_offset = _compute_corrs( - X, y, self._smin, self._smax, self.n_jobs, self.fit_intercept, - self.edge_correction) - self.coef_ = _fit_corrs(self.cov_, x_y_, n_ch_x, - self.reg_type, self.alpha, n_ch_x) + X, + y, + self._smin, + self._smax, + self.n_jobs, + self.fit_intercept, + self.edge_correction, + ) + self.coef_ = _fit_corrs( + self.cov_, x_y_, n_ch_x, self.reg_type, self.alpha, n_ch_x + ) # This is the sklearn formula from LinearModel (will be 0. for no fit) if self.fit_intercept: self.intercept_ = y_offset - np.dot(X_offset, self.coef_.sum(-1).T) else: - self.intercept_ = 0. 
+ self.intercept_ = 0.0 return self def predict(self, X): @@ -355,8 +379,8 @@ def predict(self, X): for oi in range(self.coef_.shape[0]): for fi in range(self.coef_.shape[1]): temp = fftconvolve(X[:, ei, fi], self.coef_[oi, fi]) - temp = temp[max(-smin, 0):][:len(out) - offset] - out[offset:len(temp) + offset, ei, oi] += temp + temp = temp[max(-smin, 0) :][: len(out) - offset] + out[offset : len(temp) + offset, ei, oi] += temp out += self.intercept_ if singleton: out = out[:, 0, :] diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index 330cc1ed5c8..d6ed4f6dd56 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -60,13 +60,22 @@ class TimeFrequency(TransformerMixin, BaseEstimator): """ @verbose - def __init__(self, freqs, sfreq=1.0, method='morlet', n_cycles=7.0, - time_bandwidth=None, use_fft=True, decim=1, output='complex', - n_jobs=1, verbose=None): # noqa: D102 + def __init__( + self, + freqs, + sfreq=1.0, + method="morlet", + n_cycles=7.0, + time_bandwidth=None, + use_fft=True, + decim=1, + output="complex", + n_jobs=1, + verbose=None, + ): # noqa: D102 """Init TimeFrequency transformer.""" # Check non-average output - output = _check_option('output', output, - ['complex', 'power', 'phase']) + output = _check_option("output", output, ["complex", "power", "phase"]) self.freqs = freqs self.sfreq = sfreq @@ -137,10 +146,20 @@ def transform(self, X): X = X[:, np.newaxis, :] # Compute time-frequency - Xt = _compute_tfr(X, self.freqs, self.sfreq, self.method, - self.n_cycles, True, self.time_bandwidth, - self.use_fft, self.decim, self.output, self.n_jobs, - self.verbose) + Xt = _compute_tfr( + X, + self.freqs, + self.sfreq, + self.method, + self.n_cycles, + True, + self.time_bandwidth, + self.use_fft, + self.decim, + self.output, + self.n_jobs, + self.verbose, + ) # Back to original shape if not shape: diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index b6faf66cf97..2d4316e768a 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -13,12 +13,11 @@ from ..filter import filter_data from ..time_frequency import psd_array_multitaper from ..utils import fill_doc, _check_option, _validate_type, verbose -from ..io.pick import (pick_info, _pick_data_channels, _picks_by_type, - _picks_to_idx) +from ..io.pick import pick_info, _pick_data_channels, _picks_by_type, _picks_to_idx from ..cov import _check_scalings_user -class _ConstantScaler(): +class _ConstantScaler: """Scale channel types using constant values.""" def __init__(self, info, scalings, do_scaling=True): @@ -28,15 +27,17 @@ def __init__(self, info, scalings, do_scaling=True): def fit(self, X, y=None): scalings = _check_scalings_user(self._scalings) - picks_by_type = _picks_by_type(pick_info( - self._info, _pick_data_channels(self._info, exclude=()))) + picks_by_type = _picks_by_type( + pick_info(self._info, _pick_data_channels(self._info, exclude=())) + ) std = np.ones(sum(len(p[1]) for p in picks_by_type)) if X.shape[1] != len(std): - raise ValueError('info had %d data channels but X has %d channels' - % (len(std), len(X))) + raise ValueError( + "info had %d data channels but X has %d channels" % (len(std), len(X)) + ) if self._do_scaling: # this is silly, but necessary for completeness for kind, picks in picks_by_type: - std[picks] = 1. 
/ scalings[kind] + std[picks] = 1.0 / scalings[kind] self.std_ = std self.mean_ = np.zeros_like(std) return self @@ -101,31 +102,38 @@ class Scaler(TransformerMixin, BaseEstimator): if ``scalings`` is a dict or None). """ - def __init__(self, info=None, scalings=None, with_mean=True, - with_std=True): # noqa: D102 + def __init__( + self, info=None, scalings=None, with_mean=True, with_std=True + ): # noqa: D102 self.info = info self.with_mean = with_mean self.with_std = with_std self.scalings = scalings if not (scalings is None or isinstance(scalings, (dict, str))): - raise ValueError('scalings type should be dict, str, or None, ' - 'got %s' % type(scalings)) + raise ValueError( + "scalings type should be dict, str, or None, " "got %s" % type(scalings) + ) if isinstance(scalings, str): - _check_option('scalings', scalings, ['mean', 'median']) + _check_option("scalings", scalings, ["mean", "median"]) if scalings is None or isinstance(scalings, dict): if info is None: - raise ValueError('Need to specify "info" if scalings is' - '%s' % type(scalings)) + raise ValueError( + 'Need to specify "info" if scalings is' "%s" % type(scalings) + ) self._scaler = _ConstantScaler(info, scalings, self.with_std) - elif scalings == 'mean': + elif scalings == "mean": from sklearn.preprocessing import StandardScaler + self._scaler = StandardScaler( - with_mean=self.with_mean, with_std=self.with_std) + with_mean=self.with_mean, with_std=self.with_std + ) else: # scalings == 'median': from sklearn.preprocessing import RobustScaler + self._scaler = RobustScaler( - with_centering=self.with_mean, with_scaling=self.with_std) + with_centering=self.with_mean, with_scaling=self.with_std + ) def fit(self, epochs_data, y=None): """Standardize data across channels. @@ -142,7 +150,7 @@ def fit(self, epochs_data, y=None): self : instance of Scaler The modified instance. """ - _validate_type(epochs_data, np.ndarray, 'epochs_data') + _validate_type(epochs_data, np.ndarray, "epochs_data") if epochs_data.ndim == 2: epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape @@ -167,14 +175,13 @@ def transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. """ - _validate_type(epochs_data, np.ndarray, 'epochs_data') + _validate_type(epochs_data, np.ndarray, "epochs_data") if epochs_data.ndim == 2: # can happen with SlidingEstimator if self.info is not None: - assert len(self.info['ch_names']) == epochs_data.shape[1] + assert len(self.info["ch_names"]) == epochs_data.shape[1] epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - return _sklearn_reshape_apply(self._scaler.transform, True, - epochs_data) + return _sklearn_reshape_apply(self._scaler.transform, True, epochs_data) def fit_transform(self, epochs_data, y=None): """Fit to data, then transform it. @@ -221,8 +228,7 @@ def inverse_transform(self, epochs_data): memory usage may be large with big data. 
""" assert epochs_data.ndim == 3, epochs_data.shape - return _sklearn_reshape_apply(self._scaler.inverse_transform, True, - epochs_data) + return _sklearn_reshape_apply(self._scaler.inverse_transform, True, epochs_data) class Vectorizer(TransformerMixin): @@ -282,8 +288,7 @@ def transform(self, X): """ X = np.asarray(X) if X.shape[1:] != self.features_shape_: - raise ValueError("Shape of X used in fit and transform must be " - "same") + raise ValueError("Shape of X used in fit and transform must be " "same") return X.reshape(len(X), -1) def fit_transform(self, X, y=None): @@ -322,8 +327,9 @@ def inverse_transform(self, X): """ X = np.asarray(X) if X.ndim not in (2, 3): - raise ValueError("X should be of 2 or 3 dimensions but has shape " - "%s" % (X.shape,)) + raise ValueError( + "X should be of 2 or 3 dimensions but has shape " "%s" % (X.shape,) + ) return X.reshape(X.shape[:-1] + self.features_shape_) @@ -361,9 +367,19 @@ class PSDEstimator(TransformerMixin): """ @verbose - def __init__(self, sfreq=2 * np.pi, fmin=0, fmax=np.inf, bandwidth=None, - adaptive=False, low_bias=True, n_jobs=None, - normalization='length', *, verbose=None): # noqa: D102 + def __init__( + self, + sfreq=2 * np.pi, + fmin=0, + fmax=np.inf, + bandwidth=None, + adaptive=False, + low_bias=True, + n_jobs=None, + normalization="length", + *, + verbose=None + ): # noqa: D102 self.sfreq = sfreq self.fmin = fmin self.fmax = fmax @@ -389,8 +405,9 @@ def fit(self, epochs_data, y): The modified instance. """ if not isinstance(epochs_data, np.ndarray): - raise ValueError("epochs_data should be of type ndarray (got %s)." - % type(epochs_data)) + raise ValueError( + "epochs_data should be of type ndarray (got %s)." % type(epochs_data) + ) return self @@ -408,13 +425,20 @@ def transform(self, epochs_data): The computed PSD. """ if not isinstance(epochs_data, np.ndarray): - raise ValueError("epochs_data should be of type ndarray (got %s)." - % type(epochs_data)) + raise ValueError( + "epochs_data should be of type ndarray (got %s)." % type(epochs_data) + ) psd, _ = psd_array_multitaper( - epochs_data, sfreq=self.sfreq, fmin=self.fmin, fmax=self.fmax, - bandwidth=self.bandwidth, adaptive=self.adaptive, - low_bias=self.low_bias, normalization=self.normalization, - n_jobs=self.n_jobs) + epochs_data, + sfreq=self.sfreq, + fmin=self.fmin, + fmax=self.fmax, + bandwidth=self.bandwidth, + adaptive=self.adaptive, + low_bias=self.low_bias, + normalization=self.normalization, + n_jobs=self.n_jobs, + ) return psd @@ -469,10 +493,22 @@ class FilterEstimator(TransformerMixin): caution. """ - def __init__(self, info, l_freq, h_freq, picks=None, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', - n_jobs=None, method='fir', iir_params=None, - fir_design='firwin', *, verbose=None): # noqa: D102 + def __init__( + self, + info, + l_freq, + h_freq, + picks=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + n_jobs=None, + method="fir", + iir_params=None, + fir_design="firwin", + *, + verbose=None + ): # noqa: D102 self.info = info self.l_freq = l_freq self.h_freq = h_freq @@ -501,37 +537,39 @@ def fit(self, epochs_data, y): The modified instance. """ if not isinstance(epochs_data, np.ndarray): - raise ValueError("epochs_data should be of type ndarray (got %s)." - % type(epochs_data)) + raise ValueError( + "epochs_data should be of type ndarray (got %s)." 
% type(epochs_data) + ) if self.picks is None: - self.picks = pick_types(self.info, meg=True, eeg=True, - ref_meg=False, exclude=[]) + self.picks = pick_types( + self.info, meg=True, eeg=True, ref_meg=False, exclude=[] + ) if self.l_freq == 0: self.l_freq = None - if self.h_freq is not None and self.h_freq > (self.info['sfreq'] / 2.): + if self.h_freq is not None and self.h_freq > (self.info["sfreq"] / 2.0): self.h_freq = None if self.l_freq is not None and not isinstance(self.l_freq, float): self.l_freq = float(self.l_freq) if self.h_freq is not None and not isinstance(self.h_freq, float): self.h_freq = float(self.h_freq) - if self.info['lowpass'] is None or (self.h_freq is not None and - (self.l_freq is None or - self.l_freq < self.h_freq) and - self.h_freq < - self.info['lowpass']): + if self.info["lowpass"] is None or ( + self.h_freq is not None + and (self.l_freq is None or self.l_freq < self.h_freq) + and self.h_freq < self.info["lowpass"] + ): with self.info._unlock(): - self.info['lowpass'] = self.h_freq + self.info["lowpass"] = self.h_freq - if self.info['highpass'] is None or (self.l_freq is not None and - (self.h_freq is None or - self.l_freq < self.h_freq) and - self.l_freq > - self.info['highpass']): + if self.info["highpass"] is None or ( + self.l_freq is not None + and (self.h_freq is None or self.l_freq < self.h_freq) + and self.l_freq > self.info["highpass"] + ): with self.info._unlock(): - self.info['highpass'] = self.l_freq + self.info["highpass"] = self.l_freq return self @@ -549,15 +587,26 @@ def transform(self, epochs_data): The data after filtering. """ if not isinstance(epochs_data, np.ndarray): - raise ValueError("epochs_data should be of type ndarray (got %s)." - % type(epochs_data)) + raise ValueError( + "epochs_data should be of type ndarray (got %s)." 
% type(epochs_data) + ) epochs_data = np.atleast_3d(epochs_data) return filter_data( - epochs_data, self.info['sfreq'], self.l_freq, self.h_freq, - self.picks, self.filter_length, self.l_trans_bandwidth, - self.h_trans_bandwidth, method=self.method, - iir_params=self.iir_params, n_jobs=self.n_jobs, copy=False, - fir_design=self.fir_design, verbose=False) + epochs_data, + self.info["sfreq"], + self.l_freq, + self.h_freq, + self.picks, + self.filter_length, + self.l_trans_bandwidth, + self.h_trans_bandwidth, + method=self.method, + iir_params=self.iir_params, + n_jobs=self.n_jobs, + copy=False, + fir_design=self.fir_design, + verbose=False, + ) class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): @@ -574,14 +623,17 @@ class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): def __init__(self, estimator, average=False): # noqa: D102 # XXX: Use _check_estimator #3381 - for attr in ('fit', 'transform', 'fit_transform'): + for attr in ("fit", "transform", "fit_transform"): if not hasattr(estimator, attr): - raise ValueError('estimator must be a scikit-learn ' - 'transformer, missing %s method' % attr) + raise ValueError( + "estimator must be a scikit-learn " + "transformer, missing %s method" % attr + ) if not isinstance(average, bool): - raise ValueError("average parameter must be of bool type, got " - "%s instead" % type(bool)) + raise ValueError( + "average parameter must be of bool type, got " "%s instead" % type(bool) + ) self.estimator = estimator self.average = average @@ -606,8 +658,7 @@ def fit(self, X, y=None): else: n_epochs, n_channels, n_times = X.shape # trial as time samples - X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs * - n_times)).T + X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs * n_times)).T self.estimator.fit(X) return self @@ -641,7 +692,7 @@ def transform(self, X): X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ - return self._apply_method(X, 'transform') + return self._apply_method(X, "transform") def inverse_transform(self, X): """Inverse transform the data to its original space. @@ -656,7 +707,7 @@ def inverse_transform(self, X): X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ - return self._apply_method(X, 'inverse_transform') + return self._apply_method(X, "inverse_transform") def _apply_method(self, X, method): """Vectorize time samples as trials, apply method and reshape back. @@ -768,11 +819,22 @@ class TemporalFilter(TransformerMixin): """ @verbose - def __init__(self, l_freq=None, h_freq=None, sfreq=1.0, - filter_length='auto', l_trans_bandwidth='auto', - h_trans_bandwidth='auto', n_jobs=None, method='fir', - iir_params=None, fir_window='hamming', fir_design='firwin', - *, verbose=None): # noqa: D102 + def __init__( + self, + l_freq=None, + h_freq=None, + sfreq=1.0, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + n_jobs=None, + method="fir", + iir_params=None, + fir_window="hamming", + fir_design="firwin", + *, + verbose=None + ): # noqa: D102 self.l_freq = l_freq self.h_freq = h_freq self.sfreq = sfreq @@ -785,9 +847,10 @@ def __init__(self, l_freq=None, h_freq=None, sfreq=1.0, self.fir_window = fir_window self.fir_design = fir_design - if not isinstance(self.n_jobs, int) and self.n_jobs == 'cuda': - raise ValueError('n_jobs must be int or "cuda", got %s instead.' - % type(self.n_jobs)) + if not isinstance(self.n_jobs, int) and self.n_jobs == "cuda": + raise ValueError( + 'n_jobs must be int or "cuda", got %s instead.' 
% type(self.n_jobs) + ) def fit(self, X, y=None): """Do nothing (for scikit-learn compatibility purposes). @@ -824,16 +887,26 @@ def transform(self, X): X = np.atleast_2d(X) if X.ndim > 3: - raise ValueError("Array must be of at max 3 dimensions instead " - "got %s dimensional matrix" % (X.ndim)) + raise ValueError( + "Array must be of at max 3 dimensions instead " + "got %s dimensional matrix" % (X.ndim) + ) shape = X.shape X = X.reshape(-1, shape[-1]) - X = filter_data(X, self.sfreq, self.l_freq, self.h_freq, - filter_length=self.filter_length, - l_trans_bandwidth=self.l_trans_bandwidth, - h_trans_bandwidth=self.h_trans_bandwidth, - n_jobs=self.n_jobs, method=self.method, - iir_params=self.iir_params, copy=False, - fir_window=self.fir_window, fir_design=self.fir_design) + X = filter_data( + X, + self.sfreq, + self.l_freq, + self.h_freq, + filter_length=self.filter_length, + l_trans_bandwidth=self.l_trans_bandwidth, + h_trans_bandwidth=self.h_trans_bandwidth, + n_jobs=self.n_jobs, + method=self.method, + iir_params=self.iir_params, + copy=False, + fir_window=self.fir_window, + fir_design=self.fir_design, + ) return X.reshape(shape) diff --git a/mne/defaults.py b/mne/defaults.py index 16b3b843406..498312caa15 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -7,81 +7,232 @@ from copy import deepcopy DEFAULTS = dict( - color=dict(mag='darkblue', grad='b', eeg='k', eog='k', ecg='m', emg='k', - ref_meg='steelblue', misc='k', stim='k', resp='k', chpi='k', - exci='k', ias='k', syst='k', seeg='saddlebrown', dbs='seagreen', - dipole='k', gof='k', bio='k', ecog='k', hbo='#AA3377', hbr='b', - fnirs_cw_amplitude='k', fnirs_fd_ac_amplitude='k', - fnirs_fd_phase='k', fnirs_od='k', csd='k', whitened='k', - gsr='#666633', temperature='#663333', - eyegaze='k', pupil='k'), - si_units=dict(mag='T', grad='T/m', eeg='V', eog='V', ecg='V', emg='V', - misc='AU', seeg='V', dbs='V', dipole='Am', gof='GOF', - bio='V', ecog='V', hbo='M', hbr='M', ref_meg='T', - fnirs_cw_amplitude='V', fnirs_fd_ac_amplitude='V', - fnirs_fd_phase='rad', fnirs_od='V', csd='V/m²', - whitened='Z', gsr='S', temperature='C', - eyegaze='AU', pupil='AU'), - units=dict(mag='fT', grad='fT/cm', eeg='µV', eog='µV', ecg='µV', emg='µV', - misc='AU', seeg='mV', dbs='µV', dipole='nAm', gof='GOF', - bio='µV', ecog='µV', hbo='µM', hbr='µM', ref_meg='fT', - fnirs_cw_amplitude='V', fnirs_fd_ac_amplitude='V', - fnirs_fd_phase='rad', fnirs_od='V', csd='mV/m²', - whitened='Z', gsr='S', temperature='C', - eyegaze='AU', pupil='AU'), + color=dict( + mag="darkblue", + grad="b", + eeg="k", + eog="k", + ecg="m", + emg="k", + ref_meg="steelblue", + misc="k", + stim="k", + resp="k", + chpi="k", + exci="k", + ias="k", + syst="k", + seeg="saddlebrown", + dbs="seagreen", + dipole="k", + gof="k", + bio="k", + ecog="k", + hbo="#AA3377", + hbr="b", + fnirs_cw_amplitude="k", + fnirs_fd_ac_amplitude="k", + fnirs_fd_phase="k", + fnirs_od="k", + csd="k", + whitened="k", + gsr="#666633", + temperature="#663333", + eyegaze="k", + pupil="k", + ), + si_units=dict( + mag="T", + grad="T/m", + eeg="V", + eog="V", + ecg="V", + emg="V", + misc="AU", + seeg="V", + dbs="V", + dipole="Am", + gof="GOF", + bio="V", + ecog="V", + hbo="M", + hbr="M", + ref_meg="T", + fnirs_cw_amplitude="V", + fnirs_fd_ac_amplitude="V", + fnirs_fd_phase="rad", + fnirs_od="V", + csd="V/m²", + whitened="Z", + gsr="S", + temperature="C", + eyegaze="AU", + pupil="AU", + ), + units=dict( + mag="fT", + grad="fT/cm", + eeg="µV", + eog="µV", + ecg="µV", + emg="µV", + misc="AU", + seeg="mV", + dbs="µV", + 
dipole="nAm", + gof="GOF", + bio="µV", + ecog="µV", + hbo="µM", + hbr="µM", + ref_meg="fT", + fnirs_cw_amplitude="V", + fnirs_fd_ac_amplitude="V", + fnirs_fd_phase="rad", + fnirs_od="V", + csd="mV/m²", + whitened="Z", + gsr="S", + temperature="C", + eyegaze="AU", + pupil="AU", + ), # scalings for the units - scalings=dict(mag=1e15, grad=1e13, eeg=1e6, eog=1e6, emg=1e6, ecg=1e6, - misc=1.0, seeg=1e3, dbs=1e6, ecog=1e6, dipole=1e9, gof=1.0, - bio=1e6, hbo=1e6, hbr=1e6, ref_meg=1e15, - fnirs_cw_amplitude=1.0, fnirs_fd_ac_amplitude=1.0, - fnirs_fd_phase=1., fnirs_od=1.0, csd=1e3, whitened=1., - gsr=1., temperature=1., eyegaze=1., pupil=1.), + scalings=dict( + mag=1e15, + grad=1e13, + eeg=1e6, + eog=1e6, + emg=1e6, + ecg=1e6, + misc=1.0, + seeg=1e3, + dbs=1e6, + ecog=1e6, + dipole=1e9, + gof=1.0, + bio=1e6, + hbo=1e6, + hbr=1e6, + ref_meg=1e15, + fnirs_cw_amplitude=1.0, + fnirs_fd_ac_amplitude=1.0, + fnirs_fd_phase=1.0, + fnirs_od=1.0, + csd=1e3, + whitened=1.0, + gsr=1.0, + temperature=1.0, + eyegaze=1.0, + pupil=1.0, + ), # rough guess for a good plot - scalings_plot_raw=dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, - ecg=5e-4, emg=1e-3, ref_meg=1e-12, misc='auto', - stim=1, resp=1, chpi=1e-4, exci=1, ias=1, syst=1, - seeg=1e-4, dbs=1e-4, bio=1e-6, ecog=1e-4, hbo=10e-6, - hbr=10e-6, whitened=10., fnirs_cw_amplitude=2e-2, - fnirs_fd_ac_amplitude=2e-2, fnirs_fd_phase=2e-1, - fnirs_od=2e-2, csd=200e-4, - dipole=1e-7, gof=1e2, - gsr=1., temperature=0.1, - eyegaze=3e-1, pupil=1e3), - scalings_cov_rank=dict(mag=1e12, grad=1e11, eeg=1e5, # ~100x scalings - seeg=1e1, dbs=1e4, ecog=1e4, hbo=1e4, hbr=1e4), - ylim=dict(mag=(-600., 600.), grad=(-200., 200.), eeg=(-200., 200.), - misc=(-5., 5.), seeg=(-20., 20.), dbs=(-200., 200.), - dipole=(-100., 100.), gof=(0., 1.), bio=(-500., 500.), - ecog=(-200., 200.), hbo=(0, 20), hbr=(0, 20), csd=(-50., 50.), - eyegaze=(0., 5000.), pupil=(0., 5000.)), - titles=dict(mag='Magnetometers', grad='Gradiometers', eeg='EEG', eog='EOG', - ecg='ECG', emg='EMG', misc='misc', seeg='sEEG', dbs='DBS', - bio='BIO', dipole='Dipole', ecog='ECoG', hbo='Oxyhemoglobin', - ref_meg='Reference Magnetometers', - fnirs_cw_amplitude='fNIRS (CW amplitude)', - fnirs_fd_ac_amplitude='fNIRS (FD AC amplitude)', - fnirs_fd_phase='fNIRS (FD phase)', - fnirs_od='fNIRS (OD)', hbr='Deoxyhemoglobin', - gof='Goodness of fit', csd='Current source density', - stim='Stimulus', gsr='Galvanic skin response', - temperature='Temperature', - eyegaze='Eye-tracking (Gaze position)', - pupil='Eye-tracking (Pupil size)', - ), - mask_params=dict(marker='o', - markerfacecolor='w', - markeredgecolor='k', - linewidth=0, - markeredgewidth=1, - markersize=4), + scalings_plot_raw=dict( + mag=1e-12, + grad=4e-11, + eeg=20e-6, + eog=150e-6, + ecg=5e-4, + emg=1e-3, + ref_meg=1e-12, + misc="auto", + stim=1, + resp=1, + chpi=1e-4, + exci=1, + ias=1, + syst=1, + seeg=1e-4, + dbs=1e-4, + bio=1e-6, + ecog=1e-4, + hbo=10e-6, + hbr=10e-6, + whitened=10.0, + fnirs_cw_amplitude=2e-2, + fnirs_fd_ac_amplitude=2e-2, + fnirs_fd_phase=2e-1, + fnirs_od=2e-2, + csd=200e-4, + dipole=1e-7, + gof=1e2, + gsr=1.0, + temperature=0.1, + eyegaze=3e-1, + pupil=1e3, + ), + scalings_cov_rank=dict( + mag=1e12, + grad=1e11, + eeg=1e5, # ~100x scalings + seeg=1e1, + dbs=1e4, + ecog=1e4, + hbo=1e4, + hbr=1e4, + ), + ylim=dict( + mag=(-600.0, 600.0), + grad=(-200.0, 200.0), + eeg=(-200.0, 200.0), + misc=(-5.0, 5.0), + seeg=(-20.0, 20.0), + dbs=(-200.0, 200.0), + dipole=(-100.0, 100.0), + gof=(0.0, 1.0), + bio=(-500.0, 500.0), + ecog=(-200.0, 
200.0), + hbo=(0, 20), + hbr=(0, 20), + csd=(-50.0, 50.0), + eyegaze=(0.0, 5000.0), + pupil=(0.0, 5000.0), + ), + titles=dict( + mag="Magnetometers", + grad="Gradiometers", + eeg="EEG", + eog="EOG", + ecg="ECG", + emg="EMG", + misc="misc", + seeg="sEEG", + dbs="DBS", + bio="BIO", + dipole="Dipole", + ecog="ECoG", + hbo="Oxyhemoglobin", + ref_meg="Reference Magnetometers", + fnirs_cw_amplitude="fNIRS (CW amplitude)", + fnirs_fd_ac_amplitude="fNIRS (FD AC amplitude)", + fnirs_fd_phase="fNIRS (FD phase)", + fnirs_od="fNIRS (OD)", + hbr="Deoxyhemoglobin", + gof="Goodness of fit", + csd="Current source density", + stim="Stimulus", + gsr="Galvanic skin response", + temperature="Temperature", + eyegaze="Eye-tracking (Gaze position)", + pupil="Eye-tracking (Pupil size)", + ), + mask_params=dict( + marker="o", + markerfacecolor="w", + markeredgecolor="k", + linewidth=0, + markeredgewidth=1, + markersize=4, + ), coreg=dict( mri_fid_opacity=1.0, dig_fid_opacity=1.0, - mri_fid_scale=5e-3, dig_fid_scale=8e-3, extra_scale=4e-3, - eeg_scale=4e-3, eegp_scale=20e-3, eegp_height=0.1, + eeg_scale=4e-3, + eegp_scale=20e-3, + eegp_height=0.1, ecog_scale=5e-3, seeg_scale=5e-3, dbs_scale=5e-3, @@ -89,49 +240,74 @@ source_scale=5e-3, detector_scale=5e-3, hpi_scale=4e-3, - head_color=(0.988, 0.89, 0.74), - hpi_color=(1., 0., 1.), - extra_color=(1., 1., 1.), - meg_color=(0., 0.25, 0.5), ref_meg_color=(0.5, 0.5, 0.5), + hpi_color=(1.0, 0.0, 1.0), + extra_color=(1.0, 1.0, 1.0), + meg_color=(0.0, 0.25, 0.5), + ref_meg_color=(0.5, 0.5, 0.5), helmet_color=(0.0, 0.0, 0.6), - eeg_color=(1., 0.596, 0.588), eegp_color=(0.839, 0.15, 0.16), - ecog_color=(1., 1., 1.), + eeg_color=(1.0, 0.596, 0.588), + eegp_color=(0.839, 0.15, 0.16), + ecog_color=(1.0, 1.0, 1.0), dbs_color=(0.82, 0.455, 0.659), - seeg_color=(1., 1., .3), - fnirs_color=(1., .647, 0.), - source_color=(1., .05, 0.), - detector_color=(.3, .15, .15), - lpa_color=(1., 0., 0.), - nasion_color=(0., 1., 0.), - rpa_color=(0., 0., 1.), + seeg_color=(1.0, 1.0, 0.3), + fnirs_color=(1.0, 0.647, 0.0), + source_color=(1.0, 0.05, 0.0), + detector_color=(0.3, 0.15, 0.15), + lpa_color=(1.0, 0.0, 0.0), + nasion_color=(0.0, 1.0, 0.0), + rpa_color=(0.0, 0.0, 1.0), ), noise_std=dict(grad=5e-13, mag=20e-15, eeg=0.2e-6), eloreta_options=dict(eps=1e-6, max_iter=20, force_equal=False), - depth_mne=dict(exp=0.8, limit=10., limit_depth_chs=True, - combine_xyz='spectral', allow_fixed_depth=False), - depth_sparse=dict(exp=0.8, limit=None, limit_depth_chs='whiten', - combine_xyz='fro', allow_fixed_depth=True), - interpolation_method=dict(eeg='spline', meg='MNE', fnirs='nearest'), + depth_mne=dict( + exp=0.8, + limit=10.0, + limit_depth_chs=True, + combine_xyz="spectral", + allow_fixed_depth=False, + ), + depth_sparse=dict( + exp=0.8, + limit=None, + limit_depth_chs="whiten", + combine_xyz="fro", + allow_fixed_depth=True, + ), + interpolation_method=dict(eeg="spline", meg="MNE", fnirs="nearest"), volume_options=dict( - alpha=None, resolution=1., surface_alpha=None, blending='mip', - silhouette_alpha=None, silhouette_linewidth=2.), - prefixes={'k': 1e-3, 'h': 1e-2, '': 1e0, 'd': 1e1, 'c': 1e2, 'm': 1e3, - 'µ': 1e6, 'u': 1e6, 'n': 1e9, 'p': 1e12, 'f': 1e15}, - transform_zooms=dict( - translation=None, rigid=None, affine=None, sdr=None), + alpha=None, + resolution=1.0, + surface_alpha=None, + blending="mip", + silhouette_alpha=None, + silhouette_linewidth=2.0, + ), + prefixes={ + "k": 1e-3, + "h": 1e-2, + "": 1e0, + "d": 1e1, + "c": 1e2, + "m": 1e3, + "µ": 1e6, + "u": 1e6, + "n": 1e9, + "p": 
1e12, + "f": 1e15, + }, + transform_zooms=dict(translation=None, rigid=None, affine=None, sdr=None), transform_niter=dict( translation=(10000, 1000, 100), rigid=(10000, 1000, 100), affine=(10000, 1000, 100), - sdr=(10, 10, 5)), + sdr=(10, 10, 5), + ), volume_label_indices=( # Left and middle 4, # Left-Lateral-Ventricle 5, # Left-Inf-Lat-Vent - 8, # Left-Cerebellum-Cortex - 10, # Left-Thalamus-Proper 11, # Left-Caudate 12, # Left-Putamen @@ -141,44 +317,32 @@ 16, # Brain-Stem 17, # Left-Hippocampus 18, # Left-Amygdala - 26, # Left-Accumbens-area - 28, # Left-VentralDC - # Right 43, # Right-Lateral-Ventricle 44, # Right-Inf-Lat-Vent - 47, # Right-Cerebellum-Cortex - 49, # Right-Thalamus-Proper 50, # Right-Caudate 51, # Right-Putamen 52, # Right-Pallidum 53, # Right-Hippocampus 54, # Right-Amygdala - 58, # Right-Accumbens-area - 60, # Right-VentralDC ), report_stc_plot_kwargs=dict( - views=('lateral', 'medial'), - hemi='split', - backend='pyvistaqt', + views=("lateral", "medial"), + hemi="split", + backend="pyvistaqt", time_viewer=False, show_traces=False, size=(450, 450), - background='white', + background="white", time_label=None, - add_data_kwargs={ - 'colorbar_kwargs': { - 'label_font_size': 12, - 'n_labels': 5 - } - } - ) + add_data_kwargs={"colorbar_kwargs": {"label_font_size": 12, "n_labels": 5}}, + ), ) @@ -201,6 +365,6 @@ def _handle_default(k, v=None): HEAD_SIZE_DEFAULT = 0.095 # in [m] -_BORDER_DEFAULT = 'mean' -_INTERPOLATION_DEFAULT = 'cubic' -_EXTRAPOLATE_DEFAULT = 'auto' +_BORDER_DEFAULT = "mean" +_INTERPOLATION_DEFAULT = "cubic" +_EXTRAPOLATE_DEFAULT = "auto" diff --git a/mne/dipole.py b/mne/dipole.py index 65fe90a39a3..6083b9bfbdd 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -23,21 +23,36 @@ from .transforms import _print_coord_trans, _coord_frame_name, apply_trans from .viz.evoked import _plot_evoked from ._freesurfer import head_to_mni, head_to_mri -from .forward._make_forward import (_get_trans, _setup_bem, - _prep_meg_channels, _prep_eeg_channels) -from .forward._compute_forward import (_compute_forwards_meeg, - _prep_field_computation) - -from .surface import (transform_surface_to, _compute_nearest, - _points_outside_surface) +from .forward._make_forward import ( + _get_trans, + _setup_bem, + _prep_meg_channels, + _prep_eeg_channels, +) +from .forward._compute_forward import _compute_forwards_meeg, _prep_field_computation + +from .surface import transform_surface_to, _compute_nearest, _points_outside_surface from .bem import _bem_find_surface, _bem_surf_name from .source_space import _make_volume_source_space, SourceSpaces from .parallel import parallel_func -from .utils import (logger, verbose, _time_mask, warn, _check_fname, - check_fname, _pl, fill_doc, _check_option, - _svd_lwork, _repeated_svd, _get_blas_funcs, _validate_type, - copy_function_doc_to_method_doc, TimeMixin, - _verbose_safe_false) +from .utils import ( + logger, + verbose, + _time_mask, + warn, + _check_fname, + check_fname, + _pl, + fill_doc, + _check_option, + _svd_lwork, + _repeated_svd, + _get_blas_funcs, + _validate_type, + copy_function_doc_to_method_doc, + TimeMixin, + _verbose_safe_false, +) from .viz import plot_dipole_locations @@ -101,9 +116,20 @@ class Dipole(TimeMixin): """ @verbose - def __init__(self, times, pos, amplitude, ori, gof, - name=None, conf=None, khi2=None, nfree=None, - *, verbose=None): # noqa: D102 + def __init__( + self, + times, + pos, + amplitude, + ori, + gof, + name=None, + conf=None, + khi2=None, + nfree=None, + *, + verbose=None, + ): # noqa: D102 
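        # (Per the class docstring: ``times``, ``amplitude`` and ``gof`` are
        # (n_dipoles,) arrays, ``pos`` and ``ori`` are (n_dipoles, 3), and
        # ``conf``/``khi2``/``nfree`` are optional confidence and fit
        # statistics with one entry per dipole.)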
self._set_times(np.array(times)) self.pos = np.array(pos) self.amplitude = np.array(amplitude) @@ -168,11 +194,12 @@ def crop(self, tmin=None, tmax=None, include_tmax=True, verbose=None): """ sfreq = None if len(self.times) > 1: - sfreq = 1. / np.median(np.diff(self.times)) - mask = _time_mask(self.times, tmin, tmax, sfreq=sfreq, - include_tmax=include_tmax) + sfreq = 1.0 / np.median(np.diff(self.times)) + mask = _time_mask( + self.times, tmin, tmax, sfreq=sfreq, include_tmax=include_tmax + ) self._set_times(self.times[mask]) - for attr in ('pos', 'gof', 'amplitude', 'ori', 'khi2', 'nfree'): + for attr in ("pos", "gof", "amplitude", "ori", "khi2", "nfree"): if getattr(self, attr) is not None: setattr(self, attr, getattr(self, attr)[mask]) for key in self.conf.keys(): @@ -191,21 +218,53 @@ def copy(self): @verbose @copy_function_doc_to_method_doc(plot_dipole_locations) - def plot_locations(self, trans, subject, subjects_dir=None, - mode='orthoview', coord_frame='mri', idx='gof', - show_all=True, ax=None, block=False, show=True, - scale=None, color=None, *, highlight_color='r', - fig=None, title=None, head_source='seghead', - surf='pial', width=None, verbose=None): + def plot_locations( + self, + trans, + subject, + subjects_dir=None, + mode="orthoview", + coord_frame="mri", + idx="gof", + show_all=True, + ax=None, + block=False, + show=True, + scale=None, + color=None, + *, + highlight_color="r", + fig=None, + title=None, + head_source="seghead", + surf="pial", + width=None, + verbose=None, + ): return plot_dipole_locations( - self, trans, subject, subjects_dir, mode, coord_frame, idx, - show_all, ax, block, show, scale=scale, color=color, - highlight_color=highlight_color, fig=fig, title=title, - head_source=head_source, surf=surf, width=width) + self, + trans, + subject, + subjects_dir, + mode, + coord_frame, + idx, + show_all, + ax, + block, + show, + scale=scale, + color=color, + highlight_color=highlight_color, + fig=fig, + title=title, + head_source=head_source, + surf=surf, + width=width, + ) @verbose - def to_mni(self, subject, trans, subjects_dir=None, - verbose=None): + def to_mni(self, subject, trans, subjects_dir=None, verbose=None): """Convert dipole location from head to MNI coordinates. Parameters @@ -221,12 +280,12 @@ def to_mni(self, subject, trans, subjects_dir=None, The MNI coordinates (in mm) of pos. """ mri_head_t, trans = _get_trans(trans) - return head_to_mni(self.pos, subject, mri_head_t, - subjects_dir=subjects_dir, verbose=verbose) + return head_to_mni( + self.pos, subject, mri_head_t, subjects_dir=subjects_dir, verbose=verbose + ) @verbose - def to_mri(self, subject, trans, subjects_dir=None, - verbose=None): + def to_mri(self, subject, trans, subjects_dir=None, verbose=None): """Convert dipole location from head to MRI surface RAS coordinates. Parameters @@ -242,13 +301,24 @@ def to_mri(self, subject, trans, subjects_dir=None, The Freesurfer surface RAS coordinates (in mm) of pos. """ mri_head_t, trans = _get_trans(trans) - return head_to_mri(self.pos, subject, mri_head_t, - subjects_dir=subjects_dir, verbose=verbose, - kind='mri') + return head_to_mri( + self.pos, + subject, + mri_head_t, + subjects_dir=subjects_dir, + verbose=verbose, + kind="mri", + ) @verbose - def to_volume_labels(self, trans, subject='fsaverage', aseg='aparc+aseg', - subjects_dir=None, verbose=None): + def to_volume_labels( + self, + trans, + subject="fsaverage", + aseg="aparc+aseg", + subjects_dir=None, + verbose=None, + ): """Find an ROI in atlas for the dipole positions. 
Parameters @@ -279,16 +349,15 @@ def to_volume_labels(self, trans, subject='fsaverage', aseg='aparc+aseg', lut = {v: k for k, v in lut_inv.items()} # transform to voxel space from head space - pos = self.to_mri(subject, trans, subjects_dir=subjects_dir, - verbose=verbose) + pos = self.to_mri(subject, trans, subjects_dir=subjects_dir, verbose=verbose) pos = apply_trans(mri_vox_t, pos) pos = np.rint(pos).astype(int) # Get voxel value and label from LUT - labels = [lut.get(aseg_data[tuple(coord)], 'Unknown') for coord in pos] + labels = [lut.get(aseg_data[tuple(coord)], "Unknown") for coord in pos] return labels - def plot_amplitudes(self, color='k', show=True): + def plot_amplitudes(self, color="k", show=True): """Plot the dipole amplitudes as a function of time. Parameters @@ -304,6 +373,7 @@ def plot_amplitudes(self, color='k', show=True): The figure object containing the plot. """ from .viz import plot_dipole_amplitudes + return plot_dipole_amplitudes([self], [color], show) def __getitem__(self, item): @@ -334,9 +404,16 @@ def __getitem__(self, item): selected_khi2 = self.khi2[item] if self.khi2 is not None else None selected_nfree = self.nfree[item] if self.nfree is not None else None return Dipole( - selected_times, selected_pos, selected_amplitude, selected_ori, - selected_gof, selected_name, selected_conf, selected_khi2, - selected_nfree) + selected_times, + selected_pos, + selected_amplitude, + selected_ori, + selected_gof, + selected_name, + selected_conf, + selected_khi2, + selected_nfree, + ) def __len__(self): """Return the number of dipoles. @@ -358,7 +435,7 @@ def __len__(self): def _read_dipole_fixed(fname): """Read a fixed dipole FIF file.""" - logger.info('Reading %s ...' % fname) + logger.info("Reading %s ..." % fname) info, nave, aspect_kind, comment, times, data, _ = _read_evoked(fname) return DipoleFixed(info, data, times, nave, aspect_kind, comment=comment) @@ -403,12 +480,13 @@ class DipoleFixed(TimeMixin): """ @verbose - def __init__(self, info, data, times, nave, aspect_kind, - comment='', *, verbose=None): # noqa: D102 + def __init__( + self, info, data, times, nave, aspect_kind, comment="", *, verbose=None + ): # noqa: D102 self.info = info self.nave = nave self._aspect_kind = aspect_kind - self.kind = _aspect_rev.get(aspect_kind, 'unknown') + self.kind = _aspect_rev.get(aspect_kind, "unknown") self.comment = comment self._set_times(np.array(times)) self.data = data @@ -438,7 +516,7 @@ def copy(self): @property def ch_names(self): """Channel names.""" - return self.info['ch_names'] + return self.info["ch_names"] @verbose def save(self, fname, verbose=None): @@ -452,12 +530,20 @@ def save(self, fname, verbose=None): dipole information in FIF format. %(verbose)s """ - check_fname(fname, 'DipoleFixed', ('-dip.fif', '-dip.fif.gz', - '_dip.fif', '_dip.fif.gz',), - ('.fif', '.fif.gz')) + check_fname( + fname, + "DipoleFixed", + ( + "-dip.fif", + "-dip.fif.gz", + "_dip.fif", + "_dip.fif.gz", + ), + (".fif", ".fif.gz"), + ) _write_evokeds(fname, self, check=False) - def plot(self, show=True, time_unit='s'): + def plot(self, show=True, time_unit="s"): """Plot dipole data. Parameters @@ -474,12 +560,27 @@ def plot(self, show=True, time_unit='s'): fig : instance of matplotlib.figure.Figure The figure containing the time courses. 
""" - return _plot_evoked(self, picks=None, exclude=(), unit=True, show=show, - ylim=None, xlim='tight', proj=False, hline=None, - units=None, scalings=None, titles=None, axes=None, - gfp=False, window_title=None, spatial_colors=False, - plot_type="butterfly", selectable=False, - time_unit=time_unit) + return _plot_evoked( + self, + picks=None, + exclude=(), + unit=True, + show=show, + ylim=None, + xlim="tight", + proj=False, + hline=None, + units=None, + scalings=None, + titles=None, + axes=None, + gfp=False, + window_title=None, + spatial_colors=False, + plot_type="butterfly", + selectable=False, + time_unit=time_unit, + ) # ############################################################################# @@ -509,7 +610,7 @@ def read_dipole(fname, verbose=None): .. versionchanged:: 0.20 Support for reading bdip (Xfit binary) format. """ - fname = _check_fname(fname, overwrite='read', must_exist=True) + fname = _check_fname(fname, overwrite="read", must_exist=True) if fname.suffix == ".fif" or fname.name.endswith(".fif.gz"): return _read_dipole_fixed(fname) elif fname.suffix == ".bdip": @@ -526,69 +627,96 @@ def _read_dipole_text(fname): # There is a bug in older np.loadtxt regarding skipping fields, # so just read the data ourselves (need to get name and header anyway) data = list() - with open(fname, 'r') as fid: + with open(fname, "r") as fid: for line in fid: - if not (line.startswith('%') or line.startswith('#')): + if not (line.startswith("%") or line.startswith("#")): need_header = False data.append(line.strip().split()) else: if need_header: def_line = line - if line.startswith('##') or line.startswith('%%'): + if line.startswith("##") or line.startswith("%%"): m = re.search('Name "(.*) dipoles"', line) if m: name = m.group(1) del line data = np.atleast_2d(np.array(data, float)) if def_line is None: - raise OSError('Dipole text file is missing field definition ' - 'comment, cannot parse %s' % (fname,)) + raise OSError( + "Dipole text file is missing field definition " + "comment, cannot parse %s" % (fname,) + ) # actually parse the fields - def_line = def_line.lstrip('%').lstrip('#').strip() + def_line = def_line.lstrip("%").lstrip("#").strip() # MNE writes it out differently than Elekta, let's standardize them... - fields = re.sub(r'([X|Y|Z] )\(mm\)', # "X (mm)", etc. - lambda match: match.group(1).strip() + '/mm', def_line) - fields = re.sub(r'\((.*?)\)', # "Q(nAm)", etc. - lambda match: '/' + match.group(1), fields) - fields = re.sub('(begin|end) ', # "begin" and "end" with no units - lambda match: match.group(1) + '/ms', fields) + fields = re.sub( + r"([X|Y|Z] )\(mm\)", # "X (mm)", etc. + lambda match: match.group(1).strip() + "/mm", + def_line, + ) + fields = re.sub( + r"\((.*?)\)", lambda match: "/" + match.group(1), fields # "Q(nAm)", etc. + ) + fields = re.sub( + "(begin|end) ", # "begin" and "end" with no units + lambda match: match.group(1) + "/ms", + fields, + ) fields = fields.lower().split() - required_fields = ('begin/ms', - 'x/mm', 'y/mm', 'z/mm', - 'q/nam', 'qx/nam', 'qy/nam', 'qz/nam', - 'g/%') - optional_fields = ('khi^2', 'free', # standard ones - # now the confidence fields (up to 5!) - 'vol/mm^3', 'depth/mm', 'long/mm', 'trans/mm', - 'qlong/nam', 'qtrans/nam') + required_fields = ( + "begin/ms", + "x/mm", + "y/mm", + "z/mm", + "q/nam", + "qx/nam", + "qy/nam", + "qz/nam", + "g/%", + ) + optional_fields = ( + "khi^2", + "free", # standard ones + # now the confidence fields (up to 5!) 
+ "vol/mm^3", + "depth/mm", + "long/mm", + "trans/mm", + "qlong/nam", + "qtrans/nam", + ) conf_scales = [1e-9, 1e-3, 1e-3, 1e-3, 1e-9, 1e-9] missing_fields = sorted(set(required_fields) - set(fields)) if len(missing_fields) > 0: - raise RuntimeError('Could not find necessary fields in header: %s' - % (missing_fields,)) + raise RuntimeError( + "Could not find necessary fields in header: %s" % (missing_fields,) + ) handled_fields = set(required_fields) | set(optional_fields) assert len(handled_fields) == len(required_fields) + len(optional_fields) - ignored_fields = sorted(set(fields) - - set(handled_fields) - - {'end/ms'}) + ignored_fields = sorted(set(fields) - set(handled_fields) - {"end/ms"}) if len(ignored_fields) > 0: - warn('Ignoring extra fields in dipole file: %s' % (ignored_fields,)) + warn("Ignoring extra fields in dipole file: %s" % (ignored_fields,)) if len(fields) != data.shape[1]: - raise OSError('More data fields (%s) found than data columns (%s): %s' - % (len(fields), data.shape[1], fields)) + raise OSError( + "More data fields (%s) found than data columns (%s): %s" + % (len(fields), data.shape[1], fields) + ) logger.info("%d dipole(s) found" % len(data)) - if 'end/ms' in fields: - if np.diff(data[:, [fields.index('begin/ms'), - fields.index('end/ms')]], 1, -1).any(): - warn('begin and end fields differed, but only begin will be used ' - 'to store time values') + if "end/ms" in fields: + if np.diff( + data[:, [fields.index("begin/ms"), fields.index("end/ms")]], 1, -1 + ).any(): + warn( + "begin and end fields differed, but only begin will be used " + "to store time values" + ) # Find the correct column in our data array, then scale to proper units idx = [fields.index(field) for field in required_fields] assert len(idx) >= 9 - times = data[:, idx[0]] / 1000. + times = data[:, idx[0]] / 1000.0 pos = 1e-3 * data[:, idx[1:4]] # put data in meters amplitude = data[:, idx[4]] norm = amplitude.copy() @@ -605,36 +733,39 @@ def _read_dipole_text(fname): conf = dict() for field, scale in zip(optional_fields[2:], conf_scales): # confidence if field in fields: - conf[field.split('/')[0]] = scale * data[:, fields.index(field)] + conf[field.split("/")[0]] = scale * data[:, fields.index(field)] return Dipole(times, pos, amplitude, ori, gof, name, conf, khi2, nfree) def _write_dipole_text(fname, dip): - fmt = ' %7.1f %7.1f %8.2f %8.2f %8.2f %8.3f %8.3f %8.3f %8.3f %6.2f' - header = ('# begin end X (mm) Y (mm) Z (mm)' - ' Q(nAm) Qx(nAm) Qy(nAm) Qz(nAm) g/%') - t = dip.times[:, np.newaxis] * 1000. 
+ fmt = " %7.1f %7.1f %8.2f %8.2f %8.2f %8.3f %8.3f %8.3f %8.3f %6.2f" + header = ( + "# begin end X (mm) Y (mm) Z (mm)" + " Q(nAm) Qx(nAm) Qy(nAm) Qz(nAm) g/%" + ) + t = dip.times[:, np.newaxis] * 1000.0 gof = dip.gof[:, np.newaxis] amp = 1e9 * dip.amplitude[:, np.newaxis] out = (t, t, dip.pos / 1e-3, amp, dip.ori * amp, gof) # optional fields - fmts = dict(khi2=(' khi^2', ' %8.1f', 1.), - nfree=(' free', ' %5d', 1), - vol=(' vol/mm^3', ' %9.3f', 1e9), - depth=(' depth/mm', ' %9.3f', 1e3), - long=(' long/mm', ' %8.3f', 1e3), - trans=(' trans/mm', ' %9.3f', 1e3), - qlong=(' Qlong/nAm', ' %10.3f', 1e9), - qtrans=(' Qtrans/nAm', ' %11.3f', 1e9), - ) - for key in ('khi2', 'nfree'): + fmts = dict( + khi2=(" khi^2", " %8.1f", 1.0), + nfree=(" free", " %5d", 1), + vol=(" vol/mm^3", " %9.3f", 1e9), + depth=(" depth/mm", " %9.3f", 1e3), + long=(" long/mm", " %8.3f", 1e3), + trans=(" trans/mm", " %9.3f", 1e3), + qlong=(" Qlong/nAm", " %10.3f", 1e9), + qtrans=(" Qtrans/nAm", " %11.3f", 1e9), + ) + for key in ("khi2", "nfree"): data = getattr(dip, key) if data is not None: header += fmts[key][0] fmt += fmts[key][1] out += (data[:, np.newaxis] * fmts[key][2],) - for key in ('vol', 'depth', 'long', 'trans', 'qlong', 'qtrans'): + for key in ("vol", "depth", "long", "trans", "qlong", "qtrans"): data = dip.conf.get(key) if data is not None: header += fmts[key][0] @@ -643,22 +774,23 @@ def _write_dipole_text(fname, dip): out = np.concatenate(out, axis=-1) # NB CoordinateSystem is hard-coded as Head here - with open(fname, 'wb') as fid: - fid.write('# CoordinateSystem "Head"\n'.encode('utf-8')) - fid.write((header + '\n').encode('utf-8')) + with open(fname, "wb") as fid: + fid.write('# CoordinateSystem "Head"\n'.encode("utf-8")) + fid.write((header + "\n").encode("utf-8")) np.savetxt(fid, out, fmt=fmt) if dip.name is not None: - fid.write(('## Name "%s dipoles" Style "Dipoles"' - % dip.name).encode('utf-8')) + fid.write( + ('## Name "%s dipoles" Style "Dipoles"' % dip.name).encode("utf-8") + ) -_BDIP_ERROR_KEYS = ('depth', 'long', 'trans', 'qlong', 'qtrans') +_BDIP_ERROR_KEYS = ("depth", "long", "trans", "qlong", "qtrans") def _read_dipole_bdip(fname): name = None nfree = None - with open(fname, 'rb') as fid: + with open(fname, "rb") as fid: # Which dipole in a multi-dipole set times = list() pos = list() @@ -669,75 +801,77 @@ def _read_dipole_bdip(fname): khi2 = list() has_errors = None while True: - num = np.frombuffer(fid.read(4), '>i4') + num = np.frombuffer(fid.read(4), ">i4") if len(num) == 0: break - times.append(np.frombuffer(fid.read(4), '>f4')[0]) + times.append(np.frombuffer(fid.read(4), ">f4")[0]) fid.read(4) # end fid.read(12) # r0 - pos.append(np.frombuffer(fid.read(12), '>f4')) - Q = np.frombuffer(fid.read(12), '>f4') + pos.append(np.frombuffer(fid.read(12), ">f4")) + Q = np.frombuffer(fid.read(12), ">f4") amplitude.append(np.linalg.norm(Q)) ori.append(Q / amplitude[-1]) - gof.append(100 * np.frombuffer(fid.read(4), '>f4')[0]) - this_has_errors = bool(np.frombuffer(fid.read(4), '>i4')[0]) + gof.append(100 * np.frombuffer(fid.read(4), ">f4")[0]) + this_has_errors = bool(np.frombuffer(fid.read(4), ">i4")[0]) if has_errors is None: has_errors = this_has_errors for key in _BDIP_ERROR_KEYS: conf[key] = list() assert has_errors == this_has_errors fid.read(4) # Noise level used for error computations - limits = np.frombuffer(fid.read(20), '>f4') # error limits + limits = np.frombuffer(fid.read(20), ">f4") # error limits for key, lim in zip(_BDIP_ERROR_KEYS, limits): conf[key].append(lim) 
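            # ('>i4' / '>f4' are big-endian int32 / float32; the bdip record
            # layout is fixed-width, so fields that are not kept are simply
            # skipped with bare fid.read() calls.)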
fid.read(100) # (5, 5) fully describes the conf. ellipsoid - conf['vol'].append(np.frombuffer(fid.read(4), '>f4')[0]) - khi2.append(np.frombuffer(fid.read(4), '>f4')[0]) + conf["vol"].append(np.frombuffer(fid.read(4), ">f4")[0]) + khi2.append(np.frombuffer(fid.read(4), ">f4")[0]) fid.read(4) # prob fid.read(4) # total noise estimate return Dipole(times, pos, amplitude, ori, gof, name, conf, khi2, nfree) def _write_dipole_bdip(fname, dip): - with open(fname, 'wb+') as fid: + with open(fname, "wb+") as fid: for ti, t in enumerate(dip.times): - fid.write(np.zeros(1, '>i4').tobytes()) # int dipole - fid.write(np.array([t, 0]).astype('>f4').tobytes()) - fid.write(np.zeros(3, '>f4').tobytes()) # r0 - fid.write(dip.pos[ti].astype('>f4').tobytes()) # pos + fid.write(np.zeros(1, ">i4").tobytes()) # int dipole + fid.write(np.array([t, 0]).astype(">f4").tobytes()) + fid.write(np.zeros(3, ">f4").tobytes()) # r0 + fid.write(dip.pos[ti].astype(">f4").tobytes()) # pos Q = dip.amplitude[ti] * dip.ori[ti] - fid.write(Q.astype('>f4').tobytes()) - fid.write(np.array(dip.gof[ti] / 100., '>f4').tobytes()) + fid.write(Q.astype(">f4").tobytes()) + fid.write(np.array(dip.gof[ti] / 100.0, ">f4").tobytes()) has_errors = int(bool(len(dip.conf))) - fid.write(np.array(has_errors, '>i4').tobytes()) # has_errors - fid.write(np.zeros(1, '>f4').tobytes()) # noise level + fid.write(np.array(has_errors, ">i4").tobytes()) # has_errors + fid.write(np.zeros(1, ">f4").tobytes()) # noise level for key in _BDIP_ERROR_KEYS: - val = dip.conf[key][ti] if key in dip.conf else 0. + val = dip.conf[key][ti] if key in dip.conf else 0.0 assert val.shape == () - fid.write(np.array(val, '>f4').tobytes()) - fid.write(np.zeros(25, '>f4').tobytes()) - conf = dip.conf['vol'][ti] if 'vol' in dip.conf else 0. 
- fid.write(np.array(conf, '>f4').tobytes()) + fid.write(np.array(val, ">f4").tobytes()) + fid.write(np.zeros(25, ">f4").tobytes()) + conf = dip.conf["vol"][ti] if "vol" in dip.conf else 0.0 + fid.write(np.array(conf, ">f4").tobytes()) khi2 = dip.khi2[ti] if dip.khi2 is not None else 0 - fid.write(np.array(khi2, '>f4').tobytes()) - fid.write(np.zeros(1, '>f4').tobytes()) # prob - fid.write(np.zeros(1, '>f4').tobytes()) # total noise est + fid.write(np.array(khi2, ">f4").tobytes()) + fid.write(np.zeros(1, ">f4").tobytes()) # prob + fid.write(np.zeros(1, ">f4").tobytes()) # total noise est # ############################################################################# # Fitting + def _dipole_forwards(*, sensors, fwd_data, whitener, rr, n_jobs=None): """Compute the forward solution and do other nice stuff.""" B = _compute_forwards_meeg( - rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs, silent=True) + rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs, silent=True + ) B = np.concatenate(list(B.values()), axis=1) assert np.isfinite(B).all() B_orig = B.copy() # Apply projection and whiten (cov has projections already) _, _, dgemm = _get_ddot_dgemv_dgemm() - B = dgemm(1., B, whitener.T) + B = dgemm(1.0, B, whitener.T) # column normalization doesn't affect our fitting, so skip for now # S = np.sum(B * B, axis=1) # across channels @@ -751,21 +885,30 @@ def _dipole_forwards(*, sensors, fwd_data, whitener, rr, n_jobs=None): @verbose def _make_guesses(surf, grid, exclude, mindist, n_jobs=None, verbose=None): """Make a guess space inside a sphere or BEM surface.""" - if 'rr' in surf: - logger.info('Guess surface (%s) is in %s coordinates' - % (_bem_surf_name[surf['id']], - _coord_frame_name(surf['coord_frame']))) + if "rr" in surf: + logger.info( + "Guess surface (%s) is in %s coordinates" + % (_bem_surf_name[surf["id"]], _coord_frame_name(surf["coord_frame"])) + ) else: - logger.info('Making a spherical guess space with radius %7.1f mm...' - % (1000 * surf['R'])) - logger.info('Filtering (grid = %6.f mm)...' % (1000 * grid)) - src = _make_volume_source_space(surf, grid, exclude, 1000 * mindist, - do_neighbors=False, n_jobs=n_jobs)[0] - assert 'vertno' in src + logger.info( + "Making a spherical guess space with radius %7.1f mm..." + % (1000 * surf["R"]) + ) + logger.info("Filtering (grid = %6.f mm)..." % (1000 * grid)) + src = _make_volume_source_space( + surf, grid, exclude, 1000 * mindist, do_neighbors=False, n_jobs=n_jobs + )[0] + assert "vertno" in src # simplify the result to make things easier later - src = dict(rr=src['rr'][src['vertno']], nn=src['nn'][src['vertno']], - nuse=src['nuse'], coord_frame=src['coord_frame'], - vertno=np.arange(src['nuse']), type='discrete') + src = dict( + rr=src["rr"][src["vertno"]], + nn=src["nn"][src["vertno"]], + nuse=src["nuse"], + coord_frame=src["coord_frame"], + vertno=np.arange(src["nuse"]), + type="discrete", + ) return SourceSpaces([src]) @@ -774,26 +917,26 @@ def _fit_eval(rd, B, B2, *, sensors, fwd_data, whitener, lwork, fwd_svd): if fwd_svd is None: assert sensors is not None fwd = _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - rr=rd[np.newaxis, :])[0] + sensors=sensors, fwd_data=fwd_data, whitener=whitener, rr=rd[np.newaxis, :] + )[0] uu, sing, vv = _repeated_svd(fwd, lwork, overwrite_a=True) else: uu, sing, vv = fwd_svd gof = _dipole_gof(uu, sing, vv, B, B2)[0] # mne-c uses fitness=B2-Bm2, but ours (1-gof) is just a normalized version - return 1. 
- gof + return 1.0 - gof @functools.lru_cache(None) def _get_ddot_dgemv_dgemm(): - return _get_blas_funcs(np.float64, ('dot', 'gemv', 'gemm')) + return _get_blas_funcs(np.float64, ("dot", "gemv", "gemm")) def _dipole_gof(uu, sing, vv, B, B2): """Calculate the goodness of fit from the forward SVD.""" ddot, dgemv, _ = _get_ddot_dgemv_dgemm() - ncomp = 3 if sing[2] / (sing[0] if sing[0] > 0 else 1.) > 0.2 else 2 - one = dgemv(1., vv[:ncomp], B) # np.dot(vv[:ncomp], B) + ncomp = 3 if sing[2] / (sing[0] if sing[0] > 0 else 1.0) > 0.2 else 2 + one = dgemv(1.0, vv[:ncomp], B) # np.dot(vv[:ncomp], B) Bm2 = ddot(one, one) # np.sum(one * one) gof = Bm2 / B2 return gof, one @@ -802,20 +945,21 @@ def _dipole_gof(uu, sing, vv, B, B2): def _fit_Q(*, sensors, fwd_data, whitener, B, B2, B_orig, rd, ori=None): """Fit the dipole moment once the location is known.""" from scipy import linalg - if 'fwd' in fwd_data: + + if "fwd" in fwd_data: # should be a single precomputed "guess" (i.e., fixed position) assert rd is None - fwd = fwd_data['fwd'] + fwd = fwd_data["fwd"] assert fwd.shape[0] == 3 - fwd_orig = fwd_data['fwd_orig'] + fwd_orig = fwd_data["fwd_orig"] assert fwd_orig.shape[0] == 3 - scales = fwd_data['scales'] + scales = fwd_data["scales"] assert scales.shape == (3,) - fwd_svd = fwd_data['fwd_svd'][0] + fwd_svd = fwd_data["fwd_svd"][0] else: fwd, fwd_orig, scales = _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - rr=rd[np.newaxis, :]) + sensors=sensors, fwd_data=fwd_data, whitener=whitener, rr=rd[np.newaxis, :] + ) fwd_svd = None if ori is None: if fwd_svd is None: @@ -838,19 +982,44 @@ def _fit_Q(*, sensors, fwd_data, whitener, B, B2, B_orig, rd, ori=None): return Q, gof, B_residual_noproj, ncomp -def _fit_dipoles(fun, min_dist_to_inner_skull, data, times, guess_rrs, - guess_data, *, sensors, fwd_data, whitener, ori, n_jobs, - rank, rhoend): +def _fit_dipoles( + fun, + min_dist_to_inner_skull, + data, + times, + guess_rrs, + guess_data, + *, + sensors, + fwd_data, + whitener, + ori, + n_jobs, + rank, + rhoend, +): """Fit a single dipole to the given whitened, projected data.""" from scipy.optimize import fmin_cobyla + parallel, p_fun, n_jobs = parallel_func(fun, n_jobs) # parallel over time points res = parallel( p_fun( - min_dist_to_inner_skull, B, t, guess_rrs, guess_data, - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - fmin_cobyla=fmin_cobyla, ori=ori, rank=rank, rhoend=rhoend) - for B, t in zip(data.T, times)) + min_dist_to_inner_skull, + B, + t, + guess_rrs, + guess_data, + sensors=sensors, + fwd_data=fwd_data, + whitener=whitener, + fmin_cobyla=fmin_cobyla, + ori=ori, + rank=rank, + rhoend=rhoend, + ) + for B, t in zip(data.T, times) + ) pos = np.array([r[0] for r in res]) amp = np.array([r[1] for r in res]) ori = np.array([r[2] for r in res]) @@ -858,7 +1027,7 @@ def _fit_dipoles(fun, min_dist_to_inner_skull, data, times, guess_rrs, conf = None if res[0][4] is not None: conf = np.array([r[4] for r in res]) - keys = ['vol', 'depth', 'long', 'trans', 'qlong', 'qtrans'] + keys = ["vol", "depth", "long", "trans", "qlong", "qtrans"] conf = {key: conf[:, ki] for ki, key in enumerate(keys)} khi2 = np.array([r[5] for r in res]) nfree = np.array([r[6] for r in res]) @@ -971,11 +1140,12 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): # And then the confidence interval is the diagonal of C, scaled by 1.96 # (for 95% confidence). 
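The goodness-of-fit logic in _dipole_gof can be reproduced with plain numpy; this sketch uses random stand-ins for the forward matrix and the whitened measurement vector:

import numpy as np

rng = np.random.default_rng(0)
fwd = rng.standard_normal((3, 60))   # 3 dipole components x n_channels
B = rng.standard_normal(60)          # whitened measurement vector

# Project B onto the top singular vectors of the forward matrix and
# compare explained power to total power, as in _dipole_gof above
uu, sing, vv = np.linalg.svd(fwd, full_matrices=False)
ncomp = 3 if sing[2] / (sing[0] if sing[0] > 0 else 1.0) > 0.2 else 2
one = vv[:ncomp] @ B                 # projection coefficients
gof = (one @ one) / (B @ B)          # explained / total power
print(ncomp, gof)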
from scipy import linalg + direction = np.empty((3, 3)) # The coordinate system has the x axis aligned with the dipole orientation, direction[0] = ori # the z axis through the origin of the sphere model - rvec = rd - fwd_data['inner_skull']['r0'] + rvec = rd - fwd_data["inner_skull"]["r0"] direction[2] = rvec - ori * np.dot(ori, rvec) # orthogonalize direction[2] /= np.linalg.norm(direction[2]) # and the y axis perpendical with these forming a right-handed system. @@ -989,15 +1159,19 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): for delta in deltas: this_r = rd[np.newaxis] + delta * direction[ii] fwds.append( - np.dot(Q, _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, - whitener=whitener, rr=this_r)[0])) + np.dot( + Q, + _dipole_forwards( + sensors=sensors, fwd_data=fwd_data, whitener=whitener, rr=this_r + )[0], + ) + ) J[:, ii] = np.diff(fwds, axis=0)[0] / np.diff(deltas)[0] # Get current (Q) deltas in the dipole directions deltas = np.array([-0.01, 0.01]) * np.linalg.norm(Q) this_fwd = _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - rr=rd[np.newaxis])[0] + sensors=sensors, fwd_data=fwd_data, whitener=whitener, rr=rd[np.newaxis] + )[0] for ii in range(3): fwds = [] for delta in deltas: @@ -1018,8 +1192,12 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): # The confidence volume of the dipole location is obtained from by # taking the eigenvalues of the upper left submatrix and computing # v = 4π/3 √(c^3 λ1 λ2 λ3) with c = 7.81, or: - vol_conf = 4 * np.pi / 3. * np.sqrt( - 476.379541 * np.prod(linalg.eigh(C[:3, :3], eigvals_only=True))) + vol_conf = ( + 4 + * np.pi + / 3.0 + * np.sqrt(476.379541 * np.prod(linalg.eigh(C[:3, :3], eigvals_only=True))) + ) conf = np.concatenate([conf, [vol_conf]]) # Now we reorder and subselect the proper columns: # vol, depth, long, trans, Qlong, Qtrans (discard Qdepth, assumed zero) @@ -1029,10 +1207,9 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): def _surface_constraint(rd, surf, min_dist_to_inner_skull): """Surface fitting constraint.""" - dist = _compute_nearest(surf['rr'], rd[np.newaxis, :], - return_dists=True)[1][0] + dist = _compute_nearest(surf["rr"], rd[np.newaxis, :], return_dists=True)[1][0] if _points_outside_surface(rd[np.newaxis, :], surf, 1)[0]: - dist *= -1. + dist *= -1.0 # Once we know the dipole is below the inner skull, # let's check if its distance to the inner skull is at least # min_dist_to_inner_skull. 
This can be enforced by adding a @@ -1046,45 +1223,82 @@ def _sphere_constraint(rd, r0, R_adj): return R_adj - np.sqrt(np.sum((rd - r0) ** 2)) -def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs, - guess_data, *, sensors, fwd_data, whitener, fmin_cobyla, - ori, rank, rhoend): +def _fit_dipole( + min_dist_to_inner_skull, + B_orig, + t, + guess_rrs, + guess_data, + *, + sensors, + fwd_data, + whitener, + fmin_cobyla, + ori, + rank, + rhoend, +): """Fit a single bit of data.""" B = np.dot(whitener, B_orig) # make constraint function to keep the solver within the inner skull - if 'rr' in fwd_data['inner_skull']: # bem - surf = fwd_data['inner_skull'] - constraint = partial(_surface_constraint, surf=surf, - min_dist_to_inner_skull=min_dist_to_inner_skull) + if "rr" in fwd_data["inner_skull"]: # bem + surf = fwd_data["inner_skull"] + constraint = partial( + _surface_constraint, + surf=surf, + min_dist_to_inner_skull=min_dist_to_inner_skull, + ) else: # sphere surf = None constraint = partial( - _sphere_constraint, r0=fwd_data['inner_skull']['r0'], - R_adj=fwd_data['inner_skull']['R'] - min_dist_to_inner_skull) + _sphere_constraint, + r0=fwd_data["inner_skull"]["r0"], + R_adj=fwd_data["inner_skull"]["R"] - min_dist_to_inner_skull, + ) # Find a good starting point (find_best_guess in C) B2 = np.dot(B, B) if B2 == 0: - warn('Zero field found for time %s' % t) + warn("Zero field found for time %s" % t) return np.zeros(3), 0, np.zeros(3), 0, B - idx = np.argmin([ - _fit_eval(guess_rrs[[fi], :], B, B2, fwd_svd=fwd_svd, - fwd_data=None, sensors=None, whitener=None, lwork=None) - for fi, fwd_svd in enumerate(guess_data['fwd_svd'])]) + idx = np.argmin( + [ + _fit_eval( + guess_rrs[[fi], :], + B, + B2, + fwd_svd=fwd_svd, + fwd_data=None, + sensors=None, + whitener=None, + lwork=None, + ) + for fi, fwd_svd in enumerate(guess_data["fwd_svd"]) + ] + ) x0 = guess_rrs[idx] lwork = _svd_lwork((3, B.shape[0])) - fun = partial(_fit_eval, B=B, B2=B2, fwd_data=fwd_data, whitener=whitener, - lwork=lwork, sensors=sensors, fwd_svd=None) + fun = partial( + _fit_eval, + B=B, + B2=B2, + fwd_data=fwd_data, + whitener=whitener, + lwork=lwork, + sensors=sensors, + fwd_svd=None, + ) # Tested minimizers: # Simplex, BFGS, CG, COBYLA, L-BFGS-B, Powell, SLSQP, TNC # Several were similar, but COBYLA won for having a handy constraint # function we can use to ensure we stay inside the inner skull / # smallest sphere - rd_final = fmin_cobyla(fun, x0, (constraint,), consargs=(), - rhobeg=5e-2, rhoend=rhoend, disp=False) + rd_final = fmin_cobyla( + fun, x0, (constraint,), consargs=(), rhobeg=5e-2, rhoend=rhoend, disp=False + ) # simplex = _make_tetra_simplex() + x0 # _simplex_minimize(simplex, 1e-4, 2e-4, fun) @@ -1092,45 +1306,71 @@ def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs, # Compute the dipole moment at the final point Q, gof, residual_noproj, n_comp = _fit_Q( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, B=B, B2=B2, - B_orig=B_orig, rd=rd_final, ori=ori) + sensors=sensors, + fwd_data=fwd_data, + whitener=whitener, + B=B, + B2=B2, + B_orig=B_orig, + rd=rd_final, + ori=ori, + ) khi2 = (1 - gof) * B2 nfree = rank - n_comp amp = np.sqrt(np.dot(Q, Q)) - norm = 1. if amp == 0. else amp + norm = 1.0 if amp == 0.0 else amp ori = Q / norm conf = _fit_confidence( - sensors=sensors, rd=rd_final, Q=Q, ori=ori, whitener=whitener, - fwd_data=fwd_data) + sensors=sensors, rd=rd_final, Q=Q, ori=ori, whitener=whitener, fwd_data=fwd_data + ) - msg = '---- Fitted : %7.1f ms' % (1000. 
* t) + msg = "---- Fitted : %7.1f ms" % (1000.0 * t) if surf is not None: dist_to_inner_skull = _compute_nearest( - surf['rr'], rd_final[np.newaxis, :], return_dists=True)[1][0] - msg += (", distance to inner skull : %2.4f mm" - % (dist_to_inner_skull * 1000.)) + surf["rr"], rd_final[np.newaxis, :], return_dists=True + )[1][0] + msg += ", distance to inner skull : %2.4f mm" % (dist_to_inner_skull * 1000.0) logger.info(msg) return rd_final, amp, ori, gof, conf, khi2, nfree, residual_noproj -def _fit_dipole_fixed(min_dist_to_inner_skull, B_orig, t, guess_rrs, - guess_data, *, sensors, fwd_data, whitener, - fmin_cobyla, ori, rank, rhoend): +def _fit_dipole_fixed( + min_dist_to_inner_skull, + B_orig, + t, + guess_rrs, + guess_data, + *, + sensors, + fwd_data, + whitener, + fmin_cobyla, + ori, + rank, + rhoend, +): """Fit a data using a fixed position.""" B = np.dot(whitener, B_orig) B2 = np.dot(B, B) if B2 == 0: - warn('Zero field found for time %s' % t) + warn("Zero field found for time %s" % t) return np.zeros(3), 0, np.zeros(3), 0, np.zeros(6) # Compute the dipole moment Q, gof, residual_noproj = _fit_Q( - fwd_data=guess_data, whitener=whitener, B=B, B2=B2, B_orig=B_orig, - sensors=sensors, rd=None, ori=ori)[:3] + fwd_data=guess_data, + whitener=whitener, + B=B, + B2=B2, + B_orig=B_orig, + sensors=sensors, + rd=None, + ori=ori, + )[:3] if ori is None: amp = np.sqrt(np.dot(Q, Q)) - norm = 1. if amp == 0. else amp + norm = 1.0 if amp == 0.0 else amp ori = Q / norm else: amp = np.dot(Q, ori) @@ -1143,9 +1383,20 @@ def _fit_dipole_fixed(min_dist_to_inner_skull, B_orig, t, guess_rrs, @verbose -def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, - pos=None, ori=None, rank=None, accuracy='normal', tol=5e-5, - verbose=None): +def fit_dipole( + evoked, + cov, + bem, + trans=None, + min_dist=5.0, + n_jobs=None, + pos=None, + ori=None, + rank=None, + accuracy="normal", + tol=5e-5, + verbose=None, +): """Fit a dipole. Parameters @@ -1219,76 +1470,84 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, .. versionadded:: 0.9.0 """ from scipy import linalg + # This could eventually be adapted to work with other inputs, these # are what is needed: evoked = evoked.copy() - _validate_type(accuracy, str, 'accuracy') - _check_option('accuracy', accuracy, ('accurate', 'normal')) + _validate_type(accuracy, str, "accuracy") + _check_option("accuracy", accuracy, ("accurate", "normal")) # Determine if a list of projectors has an average EEG ref if _needs_eeg_average_ref_proj(evoked.info): - raise ValueError('EEG average reference is mandatory for dipole ' - 'fitting.') + raise ValueError("EEG average reference is mandatory for dipole " "fitting.") if min_dist < 0: - raise ValueError('min_dist should be positive. Got %s' % min_dist) + raise ValueError("min_dist should be positive. Got %s" % min_dist) if ori is not None and pos is None: - raise ValueError('pos must be provided if ori is not None') + raise ValueError("pos must be provided if ori is not None") data = evoked.data if not np.isfinite(data).all(): - raise ValueError('Evoked data must be finite') + raise ValueError("Evoked data must be finite") info = evoked.info times = evoked.times.copy() comment = evoked.comment # Convert the min_dist to meters - min_dist_to_inner_skull = min_dist / 1000. 
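For orientation, a hedged usage sketch of the public fit_dipole() entry point reformatted in this hunk; the file names assume the MNE sample dataset layout and are not part of this patch:

import mne
from mne.datasets import sample

data_path = sample.data_path()
meg_dir = data_path / "MEG" / "sample"
evoked = mne.read_evokeds(meg_dir / "sample_audvis-ave.fif", condition=0)
evoked.crop(0.08, 0.09).pick("meg")   # fit a short window of MEG data
cov = mne.read_cov(meg_dir / "sample_audvis-cov.fif")
bem = data_path / "subjects" / "sample" / "bem" / "sample-5120-bem-sol.fif"
trans = meg_dir / "sample_audvis_raw-trans.fif"

dip, residual = mne.fit_dipole(evoked, cov, bem, trans, min_dist=5.0)
print(dip.pos, dip.gof)  # one position and GOF per fitted time point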
+ min_dist_to_inner_skull = min_dist / 1000.0 del min_dist # Figure out our inputs - neeg = len(pick_types(info, meg=False, eeg=True, ref_meg=False, - exclude=[])) + neeg = len(pick_types(info, meg=False, eeg=True, ref_meg=False, exclude=[])) if isinstance(bem, str): bem_extra = bem else: bem_extra = repr(bem) - logger.info('BEM : %s' % bem_extra) + logger.info("BEM : %s" % bem_extra) mri_head_t, trans = _get_trans(trans) - logger.info('MRI transform : %s' % trans) + logger.info("MRI transform : %s" % trans) safe_false = _verbose_safe_false() bem = _setup_bem(bem, bem_extra, neeg, mri_head_t, verbose=safe_false) - if not bem['is_sphere']: + if not bem["is_sphere"]: # Find the best-fitting sphere - inner_skull = _bem_find_surface(bem, 'inner_skull') + inner_skull = _bem_find_surface(bem, "inner_skull") inner_skull = inner_skull.copy() - R, r0 = _fit_sphere(inner_skull['rr'], disp=False) + R, r0 = _fit_sphere(inner_skull["rr"], disp=False) # r0 back to head frame for logging - r0 = apply_trans(mri_head_t['trans'], r0[np.newaxis, :])[0] - inner_skull['r0'] = r0 - logger.info('Head origin : ' - '%6.1f %6.1f %6.1f mm rad = %6.1f mm.' - % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], 1000 * R)) + r0 = apply_trans(mri_head_t["trans"], r0[np.newaxis, :])[0] + inner_skull["r0"] = r0 + logger.info( + "Head origin : " + "%6.1f %6.1f %6.1f mm rad = %6.1f mm." + % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], 1000 * R) + ) del R, r0 else: - r0 = bem['r0'] - if len(bem.get('layers', [])) > 0: - R = bem['layers'][0]['rad'] - kind = 'rad' + r0 = bem["r0"] + if len(bem.get("layers", [])) > 0: + R = bem["layers"][0]["rad"] + kind = "rad" else: # MEG-only # Use the minimum distance to the MEG sensors as the radius then - R = np.dot(np.linalg.inv(info['dev_head_t']['trans']), - np.hstack([r0, [1.]]))[:3] # r0 -> device - R = R - [info['chs'][pick]['loc'][:3] - for pick in pick_types(info, meg=True, exclude=[])] + R = np.dot( + np.linalg.inv(info["dev_head_t"]["trans"]), np.hstack([r0, [1.0]]) + )[ + :3 + ] # r0 -> device + R = R - [ + info["chs"][pick]["loc"][:3] + for pick in pick_types(info, meg=True, exclude=[]) + ] if len(R) == 0: - raise RuntimeError('No MEG channels found, but MEG-only ' - 'sphere model used') + raise RuntimeError( + "No MEG channels found, but MEG-only " "sphere model used" + ) R = np.min(np.sqrt(np.sum(R * R, axis=1))) # use dist to sensors - kind = 'max_rad' - logger.info('Sphere model : origin at (% 7.2f % 7.2f % 7.2f) mm, ' - '%s = %6.1f mm' - % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], kind, R)) + kind = "max_rad" + logger.info( + "Sphere model : origin at (% 7.2f % 7.2f % 7.2f) mm, " + "%s = %6.1f mm" % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], kind, R) + ) inner_skull = dict(R=R, r0=r0) # NB sphere model defined in head frame del R, r0 @@ -1297,23 +1556,22 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, fixed_position = True pos = np.array(pos, float) if pos.shape != (3,): - raise ValueError('pos must be None or a 3-element array-like,' - ' got %s' % (pos,)) - logger.info('Fixed position : %6.1f %6.1f %6.1f mm' - % tuple(1000 * pos)) + raise ValueError( + "pos must be None or a 3-element array-like," " got %s" % (pos,) + ) + logger.info("Fixed position : %6.1f %6.1f %6.1f mm" % tuple(1000 * pos)) if ori is not None: ori = np.array(ori, float) if ori.shape != (3,): - raise ValueError('oris must be None or a 3-element array-like,' - ' got %s' % (ori,)) + raise ValueError( + "oris must be None or a 3-element array-like," " got %s" % (ori,) + ) norm = 
np.sqrt(np.sum(ori * ori)) if not np.isclose(norm, 1): - raise ValueError('ori must be a unit vector, got length %s' - % (norm,)) - logger.info('Fixed orientation : %6.4f %6.4f %6.4f mm' - % tuple(ori)) + raise ValueError("ori must be a unit vector, got length %s" % (norm,)) + logger.info("Fixed orientation : %6.4f %6.4f %6.4f mm" % tuple(ori)) else: - logger.info('Free orientation : ') + logger.info("Free orientation : ") fit_n_jobs = 1 # only use 1 job to do the guess fitting else: fixed_position = False @@ -1323,39 +1581,37 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, guess_mindist = max(0.005, min_dist_to_inner_skull) guess_exclude = 0.02 - logger.info('Guess grid : %6.1f mm' % (1000 * guess_grid,)) + logger.info("Guess grid : %6.1f mm" % (1000 * guess_grid,)) if guess_mindist > 0.0: - logger.info('Guess mindist : %6.1f mm' - % (1000 * guess_mindist,)) + logger.info("Guess mindist : %6.1f mm" % (1000 * guess_mindist,)) if guess_exclude > 0: - logger.info('Guess exclude : %6.1f mm' - % (1000 * guess_exclude,)) - logger.info(f'Using {accuracy} MEG coil definitions.') + logger.info("Guess exclude : %6.1f mm" % (1000 * guess_exclude,)) + logger.info(f"Using {accuracy} MEG coil definitions.") fit_n_jobs = n_jobs cov = _ensure_cov(cov) - logger.info('') + logger.info("") _print_coord_trans(mri_head_t) - _print_coord_trans(info['dev_head_t']) - logger.info('%d bad channels total' % len(info['bads'])) + _print_coord_trans(info["dev_head_t"]) + logger.info("%d bad channels total" % len(info["bads"])) # Forward model setup (setup_forward_model from setup.c) ch_types = evoked.get_channel_types() sensors = dict() - if 'grad' in ch_types or 'mag' in ch_types: - sensors['meg'] = _prep_meg_channels( - info, exclude='bads', accuracy=accuracy, verbose=verbose) - if 'eeg' in ch_types: - sensors['eeg'] = _prep_eeg_channels( - info, exclude='bads', verbose=verbose) + if "grad" in ch_types or "mag" in ch_types: + sensors["meg"] = _prep_meg_channels( + info, exclude="bads", accuracy=accuracy, verbose=verbose + ) + if "eeg" in ch_types: + sensors["eeg"] = _prep_eeg_channels(info, exclude="bads", verbose=verbose) # Ensure that MEG and/or EEG channels are present if len(sensors) == 0: - raise RuntimeError('No MEG or EEG channels found.') + raise RuntimeError("No MEG or EEG channels found.") # Whitener for the data - logger.info('Decomposing the sensor noise covariance matrix...') + logger.info("Decomposing the sensor noise covariance matrix...") picks = pick_types(info, meg=True, eeg=True, ref_meg=False) # In case we want to more closely match MNE-C for debugging: @@ -1369,63 +1625,85 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, # whitener[nzero, nzero] = 1.0 / np.sqrt(cov['eig'][nzero]) # whitener = np.dot(whitener, cov['eigvec']) - whitener, _, rank = compute_whitener(cov, info, picks=picks, - rank=rank, return_rank=True) + whitener, _, rank = compute_whitener( + cov, info, picks=picks, rank=rank, return_rank=True + ) # Proceed to computing the fits (make_guess_data) if fixed_position: guess_src = dict(nuse=1, rr=pos[np.newaxis], inuse=np.array([True])) - logger.info('Compute forward for dipole location...') + logger.info("Compute forward for dipole location...") else: - logger.info('\n---- Computing the forward solution for the guesses...') - guess_src = _make_guesses(inner_skull, guess_grid, guess_exclude, - guess_mindist, n_jobs=n_jobs)[0] + logger.info("\n---- Computing the forward solution for the guesses...") + guess_src = _make_guesses( + 
inner_skull, guess_grid, guess_exclude, guess_mindist, n_jobs=n_jobs + )[0] # grid coordinates go from mri to head frame - transform_surface_to(guess_src, 'head', mri_head_t) - logger.info('Go through all guess source locations...') + transform_surface_to(guess_src, "head", mri_head_t) + logger.info("Go through all guess source locations...") # inner_skull goes from mri to head frame - if 'rr' in inner_skull: - transform_surface_to(inner_skull, 'head', mri_head_t) + if "rr" in inner_skull: + transform_surface_to(inner_skull, "head", mri_head_t) if fixed_position: - if 'rr' in inner_skull: - check = _surface_constraint(pos, inner_skull, - min_dist_to_inner_skull) + if "rr" in inner_skull: + check = _surface_constraint(pos, inner_skull, min_dist_to_inner_skull) else: check = _sphere_constraint( - pos, inner_skull['r0'], - R_adj=inner_skull['R'] - min_dist_to_inner_skull) + pos, inner_skull["r0"], R_adj=inner_skull["R"] - min_dist_to_inner_skull + ) if check <= 0: - raise ValueError('fixed position is %0.1fmm outside the inner ' - 'skull boundary' % (-1000 * check,)) + raise ValueError( + "fixed position is %0.1fmm outside the inner " + "skull boundary" % (-1000 * check,) + ) # C code computes guesses w/sphere model for speed, don't bother here fwd_data = _prep_field_computation( - guess_src['rr'], sensors=sensors, bem=bem, n_jobs=n_jobs, - verbose=safe_false) - fwd_data['inner_skull'] = inner_skull + guess_src["rr"], sensors=sensors, bem=bem, n_jobs=n_jobs, verbose=safe_false + ) + fwd_data["inner_skull"] = inner_skull guess_fwd, guess_fwd_orig, guess_fwd_scales = _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - rr=guess_src['rr'], n_jobs=fit_n_jobs) + sensors=sensors, + fwd_data=fwd_data, + whitener=whitener, + rr=guess_src["rr"], + n_jobs=fit_n_jobs, + ) # decompose ahead of time - guess_fwd_svd = [linalg.svd(fwd, full_matrices=False) - for fwd in np.array_split(guess_fwd, - len(guess_src['rr']))] - guess_data = dict(fwd=guess_fwd, fwd_svd=guess_fwd_svd, - fwd_orig=guess_fwd_orig, scales=guess_fwd_scales) + guess_fwd_svd = [ + linalg.svd(fwd, full_matrices=False) + for fwd in np.array_split(guess_fwd, len(guess_src["rr"])) + ] + guess_data = dict( + fwd=guess_fwd, + fwd_svd=guess_fwd_svd, + fwd_orig=guess_fwd_orig, + scales=guess_fwd_scales, + ) del guess_fwd, guess_fwd_svd, guess_fwd_orig, guess_fwd_scales # destroyed - logger.info('[done %d source%s]' % (guess_src['nuse'], - _pl(guess_src['nuse']))) + logger.info("[done %d source%s]" % (guess_src["nuse"], _pl(guess_src["nuse"]))) # Do actual fits data = data[picks] - ch_names = [info['ch_names'][p] for p in picks] - proj_op = make_projector(info['projs'], ch_names, info['bads'])[0] + ch_names = [info["ch_names"][p] for p in picks] + proj_op = make_projector(info["projs"], ch_names, info["bads"])[0] fun = _fit_dipole_fixed if fixed_position else _fit_dipole out = _fit_dipoles( - fun, min_dist_to_inner_skull, data, times, guess_src['rr'], - guess_data, sensors=sensors, fwd_data=fwd_data, whitener=whitener, - ori=ori, n_jobs=n_jobs, rank=rank, rhoend=tol) + fun, + min_dist_to_inner_skull, + data, + times, + guess_src["rr"], + guess_data, + sensors=sensors, + fwd_data=fwd_data, + whitener=whitener, + ori=ori, + n_jobs=n_jobs, + rank=rank, + rhoend=tol, + ) assert len(out) == 8 if fixed_position and ori is not None: # DipoleFixed @@ -1433,38 +1711,66 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, out_info = deepcopy(info) loc = np.concatenate([pos, ori, np.zeros(6)]) out_info._unlocked = 
True - out_info['chs'] = [ - dict(ch_name='dip 01', loc=loc, kind=FIFF.FIFFV_DIPOLE_WAVE, - coord_frame=FIFF.FIFFV_COORD_UNKNOWN, unit=FIFF.FIFF_UNIT_AM, - coil_type=FIFF.FIFFV_COIL_DIPOLE, - unit_mul=0, range=1, cal=1., scanno=1, logno=1), - dict(ch_name='goodness', loc=np.full(12, np.nan), - kind=FIFF.FIFFV_GOODNESS_FIT, unit=FIFF.FIFF_UNIT_AM, - coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - coil_type=FIFF.FIFFV_COIL_NONE, - unit_mul=0, range=1., cal=1., scanno=2, logno=100)] - for key in ['hpi_meas', 'hpi_results', 'projs']: + out_info["chs"] = [ + dict( + ch_name="dip 01", + loc=loc, + kind=FIFF.FIFFV_DIPOLE_WAVE, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + unit=FIFF.FIFF_UNIT_AM, + coil_type=FIFF.FIFFV_COIL_DIPOLE, + unit_mul=0, + range=1, + cal=1.0, + scanno=1, + logno=1, + ), + dict( + ch_name="goodness", + loc=np.full(12, np.nan), + kind=FIFF.FIFFV_GOODNESS_FIT, + unit=FIFF.FIFF_UNIT_AM, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + coil_type=FIFF.FIFFV_COIL_NONE, + unit_mul=0, + range=1.0, + cal=1.0, + scanno=2, + logno=100, + ), + ] + for key in ["hpi_meas", "hpi_results", "projs"]: out_info[key] = list() - for key in ['acq_pars', 'acq_stim', 'description', 'dig', - 'experimenter', 'hpi_subsystem', 'proj_id', 'proj_name', - 'subject_info']: + for key in [ + "acq_pars", + "acq_stim", + "description", + "dig", + "experimenter", + "hpi_subsystem", + "proj_id", + "proj_name", + "subject_info", + ]: out_info[key] = None out_info._unlocked = False - out_info['bads'] = [] + out_info["bads"] = [] out_info._update_redundant() out_info._check_consistency() - dipoles = DipoleFixed(out_info, data, times, evoked.nave, - evoked._aspect_kind, comment=comment) + dipoles = DipoleFixed( + out_info, data, times, evoked.nave, evoked._aspect_kind, comment=comment + ) else: - dipoles = Dipole(times, out[0], out[1], out[2], out[3], comment, - out[4], out[5], out[6]) + dipoles = Dipole( + times, out[0], out[1], out[2], out[3], comment, out[4], out[5], out[6] + ) residual = evoked.copy().apply_proj() # set the projs active residual.data[picks] = np.dot(proj_op, out[-1]) - logger.info('%d time points fitted' % len(dipoles.times)) + logger.info("%d time points fitted" % len(dipoles.times)) return dipoles, residual -def get_phantom_dipoles(kind='vectorview'): +def get_phantom_dipoles(kind="vectorview"): """Get standard phantom dipole locations and orientations. Parameters @@ -1493,8 +1799,8 @@ def get_phantom_dipoles(kind='vectorview'): The Elekta phantoms have a radius of 79.5mm, and HPI coil locations in the XY-plane at the axis extrema (e.g., (79.5, 0), (0, -79.5), ...). """ - _check_option('kind', kind, ['vectorview', 'otaniemi']) - if kind == 'vectorview': + _check_option("kind", kind, ["vectorview", "otaniemi"]) + if kind == "vectorview": # these values were pulled from a scanned image provided by # Elekta folks a = np.array([59.7, 48.6, 35.8, 24.8, 37.2, 27.5, 15.8, 7.9]) @@ -1505,7 +1811,7 @@ def get_phantom_dipoles(kind='vectorview'): d = [44.4, 34.0, 21.6, 12.7, 62.4, 51.5, 39.1, 27.9] z = np.concatenate((c, c, d, d)) signs = ([1, -1] * 4 + [-1, 1] * 4) * 2 - elif kind == 'otaniemi': + elif kind == "otaniemi": # these values were pulled from an Neuromag manual # (NM20456A, 13.7.1999, p.65) a = np.array([56.3, 47.6, 39.0, 30.3]) @@ -1515,7 +1821,7 @@ def get_phantom_dipoles(kind='vectorview'): y = np.concatenate((c, c, -a, -b, c, c, b, a)) z = np.concatenate((b, a, b, a, b, a, a, b)) signs = [-1] * 8 + [1] * 16 + [-1] * 8 - pos = np.vstack((x, y, z)).T / 1000. 
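A quick sanity-check sketch for the phantom geometry returned by get_phantom_dipoles below; the expectations in the comments (tangential unit orientations, positions inside the 79.5 mm radius) follow the function's own docstring:

import numpy as np
from mne.dipole import get_phantom_dipoles

pos, ori = get_phantom_dipoles("vectorview")
print(pos.shape, ori.shape)                     # one row per phantom dipole
print(np.linalg.norm(pos, axis=1).max())        # should be < 0.0795 m
print(np.abs(np.sum(pos * ori, axis=1)).max())  # ~0: oris are tangential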
+ pos = np.vstack((x, y, z)).T / 1000.0 # Locs are always in XZ or YZ, and so are the oris. The oris are # also in the same plane and tangential, so it's easy to determine # the orientation. @@ -1525,8 +1831,7 @@ def get_phantom_dipoles(kind='vectorview'): idx = np.where(this_pos == 0)[0] # assert len(idx) == 1 idx = np.setdiff1d(np.arange(3), idx[0]) - this_ori[idx] = (this_pos[idx][::-1] / - np.linalg.norm(this_pos[idx])) * [1, -1] + this_ori[idx] = (this_pos[idx][::-1] / np.linalg.norm(this_pos[idx])) * [1, -1] this_ori *= signs[pi] # Now we have this quality, which we could uncomment to # double-check: @@ -1548,6 +1853,11 @@ def _concatenate_dipoles(dipoles): ori.append(dipole.ori) gof.append(dipole.gof) - return Dipole(np.concatenate(times), np.concatenate(pos), - np.concatenate(amplitude), np.concatenate(ori), - np.concatenate(gof), name=None) + return Dipole( + np.concatenate(times), + np.concatenate(pos), + np.concatenate(amplitude), + np.concatenate(ori), + np.concatenate(gof), + name=None, + ) diff --git a/mne/epochs.py b/mne/epochs.py index 8a9e83d22d9..050d0cfaec4 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -19,66 +19,105 @@ import numpy as np from .io.utils import _construct_bids_filename -from .io.write import (start_and_end_file, start_block, end_block, - write_int, write_float, write_float_matrix, - write_double_matrix, write_complex_float_matrix, - write_complex_double_matrix, write_id, write_string, - _get_split_size, _NEXT_FILE_BUFFER, INT32_MAX) -from .io.meas_info import (read_meas_info, write_meas_info, - _ensure_infos_match, ContainsMixin) +from .io.write import ( + start_and_end_file, + start_block, + end_block, + write_int, + write_float, + write_float_matrix, + write_double_matrix, + write_complex_float_matrix, + write_complex_double_matrix, + write_id, + write_string, + _get_split_size, + _NEXT_FILE_BUFFER, + INT32_MAX, +) +from .io.meas_info import ( + read_meas_info, + write_meas_info, + _ensure_infos_match, + ContainsMixin, +) from .io.open import fiff_open, _get_next_fname from .io.tree import dir_tree_find from .io.tag import read_tag, read_tag_info from .io.constants import FIFF from .io.fiff.raw import _get_fname_rep -from .io.pick import (channel_indices_by_type, channel_type, - pick_channels, pick_info, _pick_data_channels, - _DATA_CH_TYPES_SPLIT, _picks_to_idx) +from .io.pick import ( + channel_indices_by_type, + channel_type, + pick_channels, + pick_info, + _pick_data_channels, + _DATA_CH_TYPES_SPLIT, + _picks_to_idx, +) from .io.proj import setup_proj, ProjMixin from .io.base import BaseRaw, TimeMixin, _get_ch_factors from .bem import _check_origin from .evoked import EvokedArray from .baseline import rescale, _log_rescale, _check_baseline -from .channels.channels import (UpdateChannelsMixin, - SetChannelsMixin, InterpolationMixin) +from .channels.channels import UpdateChannelsMixin, SetChannelsMixin, InterpolationMixin from .filter import detrend, FilterMixin, _check_fun from .parallel import parallel_func -from .event import (_read_events_fif, make_fixed_length_events, - match_event_names) +from .event import _read_events_fif, make_fixed_length_events, match_event_names from .fixes import rng_uniform -from .time_frequency.spectrum import (EpochsSpectrum, SpectrumMixin, - _validate_method) -from .viz import (plot_epochs, plot_epochs_image, - plot_topo_image_epochs, plot_drop_log) -from .utils import (_check_fname, check_fname, logger, verbose, repr_html, - check_random_state, warn, _pl, - sizeof_fmt, SizeMixin, copy_function_doc_to_method_doc, 
- _check_pandas_installed, - _check_preload, GetEpochsMixin, - _prepare_read_metadata, _prepare_write_metadata, - _check_event_id, _gen_events, _check_option, - _check_combine, _build_data_frame, - _check_pandas_index_arguments, _convert_times, - _scale_dataframe_data, _check_time_format, object_size, - _on_missing, _validate_type, _ensure_events, - _path_like) +from .time_frequency.spectrum import EpochsSpectrum, SpectrumMixin, _validate_method +from .viz import plot_epochs, plot_epochs_image, plot_topo_image_epochs, plot_drop_log +from .utils import ( + _check_fname, + check_fname, + logger, + verbose, + repr_html, + check_random_state, + warn, + _pl, + sizeof_fmt, + SizeMixin, + copy_function_doc_to_method_doc, + _check_pandas_installed, + _check_preload, + GetEpochsMixin, + _prepare_read_metadata, + _prepare_write_metadata, + _check_event_id, + _gen_events, + _check_option, + _check_combine, + _build_data_frame, + _check_pandas_index_arguments, + _convert_times, + _scale_dataframe_data, + _check_time_format, + object_size, + _on_missing, + _validate_type, + _ensure_events, + _path_like, +) from .utils.docs import fill_doc -from .annotations import (_write_annotations, _read_annotations_fif, - EpochAnnotationsMixin) +from .annotations import ( + _write_annotations, + _read_annotations_fif, + EpochAnnotationsMixin, +) def _pack_reject_params(epochs): reject_params = dict() - for key in ('reject', 'flat', 'reject_tmin', 'reject_tmax'): + for key in ("reject", "flat", "reject_tmin", "reject_tmax"): val = getattr(epochs, key, None) if val is not None: reject_params[key] = val return reject_params -def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, - overwrite): +def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite): """Split epochs. 
Anything new added to this function also needs to be added to @@ -87,22 +126,22 @@ def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, # insert index in filename base, ext = op.splitext(fname) if part_idx > 0: - if split_naming == 'neuromag': - fname = '%s-%d%s' % (base, part_idx, ext) + if split_naming == "neuromag": + fname = "%s-%d%s" % (base, part_idx, ext) else: - assert split_naming == 'bids' - fname = _construct_bids_filename(base, ext, part_idx, - validate=False) + assert split_naming == "bids" + fname = _construct_bids_filename(base, ext, part_idx, validate=False) _check_fname(fname, overwrite=overwrite) next_fname = None if part_idx < n_parts - 1: - if split_naming == 'neuromag': - next_fname = '%s-%d%s' % (base, part_idx + 1, ext) + if split_naming == "neuromag": + next_fname = "%s-%d%s" % (base, part_idx + 1, ext) else: - assert split_naming == 'bids' - next_fname = _construct_bids_filename(base, ext, part_idx + 1, - validate=False) + assert split_naming == "bids" + next_fname = _construct_bids_filename( + base, ext, part_idx + 1, validate=False + ) next_idx = part_idx + 1 else: next_idx = None @@ -113,12 +152,12 @@ def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): info = epochs.info - meas_id = info['meas_id'] + meas_id = info["meas_id"] start_block(fid, FIFF.FIFFB_MEAS) write_id(fid, FIFF.FIFF_BLOCK_ID) - if info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info['meas_id']) + if info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info["meas_id"]) # Write measurement info write_meas_info(fid, info) @@ -130,21 +169,21 @@ def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): # write events out after getting data to ensure bad events are dropped data = epochs.get_data() - _check_option('fmt', fmt, ['single', 'double']) + _check_option("fmt", fmt, ["single", "double"]) if np.iscomplexobj(data): - if fmt == 'single': + if fmt == "single": write_function = write_complex_float_matrix - elif fmt == 'double': + elif fmt == "double": write_function = write_complex_double_matrix else: - if fmt == 'single': + if fmt == "single": write_function = write_float_matrix - elif fmt == 'double': + elif fmt == "double": write_function = write_double_matrix # Epoch annotations are written if there are any - annotations = getattr(epochs, 'annotations', []) + annotations = getattr(epochs, "annotations", []) if annotations is not None and len(annotations): _write_annotations(fid, annotations) @@ -162,7 +201,7 @@ def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): end_block(fid, FIFF.FIFFB_MNE_METADATA) # First and last sample - first = int(round(epochs.tmin * info['sfreq'])) # round just to be safe + first = int(round(epochs.tmin * info["sfreq"])) # round just to be safe last = first + len(epochs.times) - 1 write_int(fid, FIFF.FIFF_FIRST_SAMPLE, first) write_int(fid, FIFF.FIFF_LAST_SAMPLE, last) @@ -177,10 +216,9 @@ def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX, bmax) # The epochs itself - decal = np.empty(info['nchan']) - for k in range(info['nchan']): - decal[k] = 1.0 / (info['chs'][k]['cal'] * - info['chs'][k].get('scale', 1.0)) + decal = np.empty(info["nchan"]) + for k in range(info["nchan"]): + decal[k] = 1.0 / (info["chs"][k]["cal"] * info["chs"][k].get("scale", 1.0)) data *= decal[np.newaxis, :, np.newaxis] @@ -189,16 +227,13 @@ def _save_part(fid, epochs, fmt, 
n_parts, next_fname, next_idx): # undo modifications to data data /= decal[np.newaxis, :, np.newaxis] - write_string(fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, - json.dumps(epochs.drop_log)) + write_string(fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, json.dumps(epochs.drop_log)) reject_params = _pack_reject_params(epochs) if reject_params: - write_string(fid, FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT, - json.dumps(reject_params)) + write_string(fid, FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT, json.dumps(reject_params)) - write_int(fid, FIFF.FIFF_MNE_EPOCHS_SELECTION, - epochs.selection) + write_int(fid, FIFF.FIFF_MNE_EPOCHS_SELECTION, epochs.selection) # And now write the next file info in case epochs are split on disk if next_fname is not None and n_parts > 1: @@ -216,7 +251,7 @@ def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): def _event_id_string(event_id): - return ';'.join([k + ':' + str(v) for k, v in event_id.items()]) + return ";".join([k + ":" + str(v) for k, v in event_id.items()]) def _merge_events(events, event_id, selection): @@ -226,7 +261,6 @@ def _merge_events(events, event_id, selection): event_idxs_to_delete = list() unique_events, counts = np.unique(events[:, 0], return_counts=True) for ev in unique_events[counts > 1]: - # indices at which the non-unique events happened idxs = (events[:, 0] == ev).nonzero()[0] @@ -242,18 +276,18 @@ def _merge_events(events, event_id, selection): # Else, make a new event_id for the merged event else: - # Find all event_id keys involved in duplicated events. These # keys will be merged to become a new entry in "event_id" event_id_keys = list(event_id.keys()) event_id_vals = list(event_id.values()) - new_key_comps = [event_id_keys[event_id_vals.index(value)] - for value in ev_vals] + new_key_comps = [ + event_id_keys[event_id_vals.index(value)] for value in ev_vals + ] # Check if we already have an entry for merged keys of duplicate # events ... if yes, reuse it for key in event_id: - if set(key.split('/')) == set(new_key_comps): + if set(key.split("/")) == set(new_key_comps): new_event_val = event_id[key] break @@ -261,9 +295,10 @@ def _merge_events(events, event_id, selection): # the event_id dict else: ev_vals = np.unique( - np.concatenate((list(event_id.values()), - events[:, 1:].flatten()), - axis=0)) + np.concatenate( + (list(event_id.values()), events[:, 1:].flatten()), axis=0 + ) + ) if ev_vals[0] > 1: new_event_val = 1 else: @@ -272,7 +307,7 @@ def _merge_events(events, event_id, selection): idx = -1 if len(idx) == 0 else idx[0] new_event_val = ev_vals[idx] + 1 - new_event_id_key = '/'.join(sorted(new_key_comps)) + new_event_id_key = "/".join(sorted(new_key_comps)) event_id[new_event_id_key] = int(new_event_val) # Replace duplicate event times with merged event and remember which @@ -288,8 +323,7 @@ def _merge_events(events, event_id, selection): return new_events, event_id, new_selection -def _handle_event_repeated(events, event_id, event_repeated, selection, - drop_log): +def _handle_event_repeated(events, event_id, event_repeated, selection, drop_log): """Handle repeated events. Note that drop_log will be modified inplace @@ -304,29 +338,34 @@ def _handle_event_repeated(events, event_id, event_repeated, selection, return events, event_id, selection, drop_log # Else, we have duplicates. Triage ... 
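The event_repeated='merge' branch handled below can be exercised end to end; everything in this sketch (channel name, sampling rate, event values) is toy data:

import numpy as np
import mne

info = mne.create_info(["EEG 001"], sfreq=100.0, ch_types="eeg")
raw = mne.io.RawArray(
    np.random.default_rng(0).standard_normal((1, 1000)), info)

# Two different event values land on the same sample (100), so merging
# creates a combined event id for the simultaneous events
events = np.array([[100, 0, 1], [100, 0, 2], [400, 0, 1]])
event_id = {"aud": 1, "vis": 2}
epochs = mne.Epochs(raw, events, event_id, tmin=-0.1, tmax=0.3,
                    baseline=None, event_repeated="merge", preload=True)
print(epochs.event_id)  # now includes a merged 'aud/vis' entry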
- _check_option('event_repeated', event_repeated, ['error', 'drop', 'merge']) + _check_option("event_repeated", event_repeated, ["error", "drop", "merge"]) drop_log = list(drop_log) - if event_repeated == 'error': - raise RuntimeError('Event time samples were not unique. Consider ' - 'setting the `event_repeated` parameter."') + if event_repeated == "error": + raise RuntimeError( + "Event time samples were not unique. Consider " + 'setting the `event_repeated` parameter."' + ) - elif event_repeated == 'drop': - logger.info('Multiple event values for single event times found. ' - 'Keeping the first occurrence and dropping all others.') + elif event_repeated == "drop": + logger.info( + "Multiple event values for single event times found. " + "Keeping the first occurrence and dropping all others." + ) new_events = events[u_ev_idxs] new_selection = selection[u_ev_idxs] drop_ev_idxs = np.setdiff1d(selection, new_selection) for idx in drop_ev_idxs: - drop_log[idx] = drop_log[idx] + ('DROP DUPLICATE',) + drop_log[idx] = drop_log[idx] + ("DROP DUPLICATE",) selection = new_selection - elif event_repeated == 'merge': - logger.info('Multiple event values for single event times found. ' - 'Creating new event value to reflect simultaneous events.') - new_events, event_id, new_selection = \ - _merge_events(events, event_id, selection) + elif event_repeated == "merge": + logger.info( + "Multiple event values for single event times found. " + "Creating new event value to reflect simultaneous events." + ) + new_events, event_id, new_selection = _merge_events(events, event_id, selection) drop_ev_idxs = np.setdiff1d(selection, new_selection) for idx in drop_ev_idxs: - drop_log[idx] = drop_log[idx] + ('MERGE DUPLICATE',) + drop_log[idx] = drop_log[idx] + ("MERGE DUPLICATE",) selection = new_selection drop_log = tuple(drop_log) @@ -338,10 +377,19 @@ def _handle_event_repeated(events, event_id, event_repeated, selection, @fill_doc -class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin, - SetChannelsMixin, InterpolationMixin, FilterMixin, - TimeMixin, SizeMixin, GetEpochsMixin, EpochAnnotationsMixin, - SpectrumMixin): +class BaseEpochs( + ProjMixin, + ContainsMixin, + UpdateChannelsMixin, + SetChannelsMixin, + InterpolationMixin, + FilterMixin, + TimeMixin, + SizeMixin, + GetEpochsMixin, + EpochAnnotationsMixin, + SpectrumMixin, +): """Abstract base class for `~mne.Epochs`-type classes. .. 
note:: @@ -399,22 +447,44 @@ class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin, """ @verbose - def __init__(self, info, data, events, event_id=None, - tmin=-0.2, tmax=0.5, - baseline=(None, 0), raw=None, picks=None, reject=None, - flat=None, decim=1, reject_tmin=None, reject_tmax=None, - detrend=None, proj=True, on_missing='raise', - preload_at_end=False, selection=None, drop_log=None, - filename=None, metadata=None, event_repeated='error', - *, raw_sfreq=None, - annotations=None, verbose=None): # noqa: D102 + def __init__( + self, + info, + data, + events, + event_id=None, + tmin=-0.2, + tmax=0.5, + baseline=(None, 0), + raw=None, + picks=None, + reject=None, + flat=None, + decim=1, + reject_tmin=None, + reject_tmax=None, + detrend=None, + proj=True, + on_missing="raise", + preload_at_end=False, + selection=None, + drop_log=None, + filename=None, + metadata=None, + event_repeated="error", + *, + raw_sfreq=None, + annotations=None, + verbose=None, + ): # noqa: D102 if events is not None: # RtEpochs can have events=None events = _ensure_events(events) events_max = events.max() if events_max > INT32_MAX: raise ValueError( - f'events array values must not exceed {INT32_MAX}, ' - f'got {events_max}') + f"events array values must not exceed {INT32_MAX}, " + f"got {events_max}" + ) event_id = _check_event_id(event_id, events) self.event_id = event_id del event_id @@ -422,8 +492,10 @@ def __init__(self, info, data, events, event_id=None, if events is not None: # RtEpochs can have events=None for key, val in self.event_id.items(): if val not in events[:, 2]: - msg = ('No matching events found for %s ' - '(event id %i)' % (key, val)) + msg = "No matching events found for %s " "(event id %i)" % ( + key, + val, + ) _on_missing(on_missing, msg) # ensure metadata matches original events size @@ -442,23 +514,33 @@ def __init__(self, info, data, events, event_id=None, else: selection = np.array(selection, int) if selection.shape != (len(selected),): - raise ValueError('selection must be shape %s got shape %s' - % (selected.shape, selection.shape)) + raise ValueError( + "selection must be shape %s got shape %s" + % (selected.shape, selection.shape) + ) self.selection = selection if drop_log is None: self.drop_log = tuple( - () if k in self.selection else ('IGNORED',) - for k in range(max(len(self.events), - max(self.selection) + 1))) + () if k in self.selection else ("IGNORED",) + for k in range(max(len(self.events), max(self.selection) + 1)) + ) else: self.drop_log = drop_log self.events = self.events[selected] - self.events, self.event_id, self.selection, self.drop_log = \ - _handle_event_repeated( - self.events, self.event_id, event_repeated, - self.selection, self.drop_log) + ( + self.events, + self.event_id, + self.selection, + self.drop_log, + ) = _handle_event_repeated( + self.events, + self.event_id, + event_repeated, + self.selection, + self.drop_log, + ) # then subselect sub = np.where(np.in1d(selection, self.selection))[0] @@ -477,13 +559,16 @@ def __init__(self, info, data, events, event_id=None, n_events = len(self.events) if n_events > 1: if np.diff(self.events.astype(np.int64)[:, 0]).min() <= 0: - warn('The events passed to the Epochs constructor are not ' - 'chronologically ordered.', RuntimeWarning) + warn( + "The events passed to the Epochs constructor are not " + "chronologically ordered.", + RuntimeWarning, + ) if n_events > 0: - logger.info('%d matching events found' % n_events) + logger.info("%d matching events found" % n_events) else: - raise ValueError('No desired 
events found.') + raise ValueError("No desired events found.") else: self.drop_log = tuple() self.selection = np.array([], int) @@ -491,13 +576,14 @@ def __init__(self, info, data, events, event_id=None, # do not set self.events here, let subclass do it if (detrend not in [None, 0, 1]) or isinstance(detrend, bool): - raise ValueError('detrend must be None, 0, or 1') + raise ValueError("detrend must be None, 0, or 1") self.detrend = detrend self._raw = raw info._check_consistency() - self.picks = _picks_to_idx(info, picks, none='all', exclude=(), - allow_empty=False) + self.picks = _picks_to_idx( + info, picks, none="all", exclude=(), allow_empty=False + ) self.info = pick_info(info, self.picks) del info self._current = 0 @@ -508,48 +594,54 @@ def __init__(self, info, data, events, event_id=None, self._do_baseline = True else: assert decim == 1 - if data.ndim != 3 or data.shape[2] != \ - round((tmax - tmin) * self.info['sfreq']) + 1: - raise RuntimeError('bad data shape') + if ( + data.ndim != 3 + or data.shape[2] != round((tmax - tmin) * self.info["sfreq"]) + 1 + ): + raise RuntimeError("bad data shape") if data.shape[0] != len(self.events): raise ValueError( - 'The number of epochs and the number of events must match') + "The number of epochs and the number of events must match" + ) self.preload = True self._data = data self._do_baseline = False self._offset = None if tmin > tmax: - raise ValueError('tmin has to be less than or equal to tmax') + raise ValueError("tmin has to be less than or equal to tmax") # Handle times - sfreq = float(self.info['sfreq']) + sfreq = float(self.info["sfreq"]) start_idx = int(round(tmin * sfreq)) - self._raw_times = np.arange(start_idx, - int(round(tmax * sfreq)) + 1) / sfreq + self._raw_times = np.arange(start_idx, int(round(tmax * sfreq)) + 1) / sfreq self._set_times(self._raw_times) # check reject_tmin and reject_tmax if reject_tmin is not None: - if (np.isclose(reject_tmin, tmin)): + if np.isclose(reject_tmin, tmin): # adjust for potential small deviations due to sampling freq reject_tmin = self.tmin elif reject_tmin < tmin: - raise ValueError(f'reject_tmin needs to be None or >= tmin ' - f'(got {reject_tmin})') + raise ValueError( + f"reject_tmin needs to be None or >= tmin " f"(got {reject_tmin})" + ) if reject_tmax is not None: - if (np.isclose(reject_tmax, tmax)): + if np.isclose(reject_tmax, tmax): # adjust for potential small deviations due to sampling freq reject_tmax = self.tmax elif reject_tmax > tmax: - raise ValueError(f'reject_tmax needs to be None or <= tmax ' - f'(got {reject_tmax})') + raise ValueError( + f"reject_tmax needs to be None or <= tmax " f"(got {reject_tmax})" + ) if (reject_tmin is not None) and (reject_tmax is not None): if reject_tmin >= reject_tmax: - raise ValueError(f'reject_tmin ({reject_tmin}) needs to be ' - f' < reject_tmax ({reject_tmax})') + raise ValueError( + f"reject_tmin ({reject_tmin}) needs to be " + f" < reject_tmax ({reject_tmax})" + ) self.reject_tmin = reject_tmin self.reject_tmax = reject_tmax @@ -559,11 +651,14 @@ def __init__(self, info, data, events, event_id=None, self.decimate(decim) # baseline correction: replace `None` tuple elements with actual times - self.baseline = _check_baseline(baseline, times=self.times, - sfreq=self.info['sfreq']) + self.baseline = _check_baseline( + baseline, times=self.times, sfreq=self.info["sfreq"] + ) if self.baseline is not None and self.baseline != baseline: - logger.info(f'Setting baseline interval to ' - f'[{self.baseline[0]}, {self.baseline[1]}] s') + 
logger.info( + f"Setting baseline interval to " + f"[{self.baseline[0]}, {self.baseline[1]}] s" + ) logger.info(_log_rescale(self.baseline)) @@ -573,18 +668,16 @@ def __init__(self, info, data, events, event_id=None, self._reject_setup(reject, flat) # do the rest - valid_proj = [True, 'delayed', False] + valid_proj = [True, "delayed", False] if proj not in valid_proj: - raise ValueError('"proj" must be one of %s, not %s' - % (valid_proj, proj)) - if proj == 'delayed': + raise ValueError('"proj" must be one of %s, not %s' % (valid_proj, proj)) + if proj == "delayed": self._do_delayed_proj = True - logger.info('Entering delayed SSP mode.') + logger.info("Entering delayed SSP mode.") else: self._do_delayed_proj = False activate = False if self._do_delayed_proj else proj - self._projector, self.info = setup_proj(self.info, False, - activate=activate) + self._projector, self.info = setup_proj(self.info, False, activate=activate) if preload_at_end: assert self._data is None assert self.preload is False @@ -598,20 +691,19 @@ def __init__(self, info, data, events, event_id=None, self._data[ii] = np.dot(self._projector, epoch) self._filename = str(filename) if filename is not None else filename if raw_sfreq is None: - raw_sfreq = self.info['sfreq'] + raw_sfreq = self.info["sfreq"] self._raw_sfreq = raw_sfreq self._check_consistency() self.set_annotations(annotations) def _check_consistency(self): """Check invariants of epochs object.""" - if hasattr(self, 'events'): + if hasattr(self, "events"): assert len(self.selection) == len(self.events) assert len(self.drop_log) >= len(self.events) - assert len(self.selection) == sum( - (len(dl) == 0 for dl in self.drop_log)) - assert hasattr(self, '_times_readonly') - assert not self.times.flags['WRITEABLE'] + assert len(self.selection) == sum((len(dl) == 0 for dl in self.drop_log)) + assert hasattr(self, "_times_readonly") + assert not self.times.flags["WRITEABLE"] assert isinstance(self.drop_log, tuple) assert all(isinstance(log, tuple) for log in self.drop_log) assert all(isinstance(s, str) for log in self.drop_log for s in log) @@ -678,14 +770,15 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): .. versionadded:: 0.10.0 """ - baseline = _check_baseline(baseline, times=self.times, - sfreq=self.info['sfreq']) + baseline = _check_baseline(baseline, times=self.times, sfreq=self.info["sfreq"]) if self.preload: if self.baseline is not None and baseline is None: - raise RuntimeError('You cannot remove baseline correction ' - 'from preloaded data once it has been ' - 'applied.') + raise RuntimeError( + "You cannot remove baseline correction " + "from preloaded data once it has been " + "applied." 
+ ) self._do_baseline = True picks = self._detrend_picks rescale(self._data, self.times, baseline, copy=False, picks=picks) @@ -704,39 +797,45 @@ def _reject_setup(self, reject, flat): idx = channel_indices_by_type(self.info) reject = deepcopy(reject) if reject is not None else dict() flat = deepcopy(flat) if flat is not None else dict() - for rej, kind in zip((reject, flat), ('reject', 'flat')): + for rej, kind in zip((reject, flat), ("reject", "flat")): if not isinstance(rej, dict): - raise TypeError('reject and flat must be dict or None, not %s' - % type(rej)) + raise TypeError( + "reject and flat must be dict or None, not %s" % type(rej) + ) bads = set(rej.keys()) - set(idx.keys()) if len(bads) > 0: - raise KeyError('Unknown channel types found in %s: %s' - % (kind, bads)) + raise KeyError("Unknown channel types found in %s: %s" % (kind, bads)) for key in idx.keys(): # don't throw an error if rejection/flat would do nothing - if len(idx[key]) == 0 and (np.isfinite(reject.get(key, np.inf)) or - flat.get(key, -1) >= 0): + if len(idx[key]) == 0 and ( + np.isfinite(reject.get(key, np.inf)) or flat.get(key, -1) >= 0 + ): # This is where we could eventually add e.g. # self.allow_missing_reject_keys check to allow users to # provide keys that don't exist in data - raise ValueError("No %s channel found. Cannot reject based on " - "%s." % (key.upper(), key.upper())) + raise ValueError( + "No %s channel found. Cannot reject based on " + "%s." % (key.upper(), key.upper()) + ) # check for invalid values - for rej, kind in zip((reject, flat), ('Rejection', 'Flat')): + for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): if val is None or val < 0: - raise ValueError('%s value must be a number >= 0, not "%s"' - % (kind, val)) + raise ValueError( + '%s value must be a number >= 0, not "%s"' % (kind, val) + ) # now check to see if our rejection and flat are getting more # restrictive old_reject = self.reject if self.reject is not None else dict() old_flat = self.flat if self.flat is not None else dict() - bad_msg = ('{kind}["{key}"] == {new} {op} {old} (old value), new ' - '{kind} values must be at least as stringent as ' - 'previous ones') + bad_msg = ( + '{kind}["{key}"] == {new} {op} {old} (old value), new ' + "{kind} values must be at least as stringent as " + "previous ones" + ) # copy thresholds for channel types that were used previously, but not # passed this time @@ -746,8 +845,14 @@ def _reject_setup(self, reject, flat): for key in reject: if key in old_reject and reject[key] > old_reject[key]: raise ValueError( - bad_msg.format(kind='reject', key=key, new=reject[key], - old=old_reject[key], op='>')) + bad_msg.format( + kind="reject", + key=key, + new=reject[key], + old=old_reject[key], + op=">", + ) + ) # same for flat thresholds for key in set(old_flat) - set(flat): @@ -755,8 +860,10 @@ def _reject_setup(self, reject, flat): for key in flat: if key in old_flat and flat[key] < old_flat[key]: raise ValueError( - bad_msg.format(kind='flat', key=key, new=flat[key], - old=old_flat[key], op='<')) + bad_msg.format( + kind="flat", key=key, new=flat[key], old=old_flat[key], op="<" + ) + ) # after validation, set parameters self._bad_dropped = False @@ -785,20 +892,26 @@ def _is_good_epoch(self, data, verbose=None): if isinstance(data, str): return False, (data,) if data is None: - return False, ('NO_DATA',) + return False, ("NO_DATA",) n_times = len(self.times) if data.shape[1] < n_times: # epoch is too short ie at the end of the data - return False, 
('TOO_SHORT',) + return False, ("TOO_SHORT",) if self.reject is None and self.flat is None: return True, None else: if self._reject_time is not None: data = data[:, self._reject_time] - return _is_good(data, self.ch_names, self._channel_type_idx, - self.reject, self.flat, full_report=True, - ignore_chs=self.info['bads']) + return _is_good( + data, + self.ch_names, + self._channel_type_idx, + self.reject, + self.flat, + full_report=True, + ignore_chs=self.info["bads"], + ) @verbose def _detrend_offset_decim(self, epoch, picks, verbose=None): @@ -819,8 +932,13 @@ def _detrend_offset_decim(self, epoch, picks, verbose=None): # Baseline correct if self._do_baseline: rescale( - epoch, self._raw_times, self.baseline, picks=picks, copy=False, - verbose=False) + epoch, + self._raw_times, + self.baseline, + picks=picks, + copy=False, + verbose=False, + ) # Decimate if necessary (i.e., epoch not preloaded) epoch = epoch[:, self._decim_slice] @@ -883,14 +1001,13 @@ def subtract_evoked(self, evoked=None): .. [1] David et al. "Mechanisms of evoked and induced responses in MEG/EEG", NeuroImage, vol. 31, no. 4, pp. 1580-1591, July 2006. """ - logger.info('Subtracting Evoked from Epochs') + logger.info("Subtracting Evoked from Epochs") if evoked is None: picks = _pick_data_channels(self.info, exclude=[]) evoked = self.average(picks) # find the indices of the channels to use - picks = pick_channels( - evoked.ch_names, include=self.ch_names, ordered=False) + picks = pick_channels(evoked.ch_names, include=self.ch_names, ordered=False) # make sure the omitted channels are not data channels if len(picks) < len(self.ch_names): @@ -898,24 +1015,32 @@ def subtract_evoked(self, evoked=None): diff_ch = list(set(self.ch_names).difference(sel_ch)) diff_idx = [self.ch_names.index(ch) for ch in diff_ch] diff_types = [channel_type(self.info, idx) for idx in diff_idx] - bad_idx = [diff_types.index(t) for t in diff_types if t in - _DATA_CH_TYPES_SPLIT] + bad_idx = [ + diff_types.index(t) for t in diff_types if t in _DATA_CH_TYPES_SPLIT + ] if len(bad_idx) > 0: - bad_str = ', '.join([diff_ch[ii] for ii in bad_idx]) - raise ValueError('The following data channels are missing ' - 'in the evoked response: %s' % bad_str) - logger.info(' The following channels are not included in the ' - 'subtraction: %s' % ', '.join(diff_ch)) + bad_str = ", ".join([diff_ch[ii] for ii in bad_idx]) + raise ValueError( + "The following data channels are missing " + "in the evoked response: %s" % bad_str + ) + logger.info( + " The following channels are not included in the " + "subtraction: %s" % ", ".join(diff_ch) + ) # make sure the times match - if (len(self.times) != len(evoked.times) or - np.max(np.abs(self.times - evoked.times)) >= 1e-7): - raise ValueError('Epochs and Evoked object do not contain ' - 'the same time points.') + if ( + len(self.times) != len(evoked.times) + or np.max(np.abs(self.times - evoked.times)) >= 1e-7 + ): + raise ValueError( + "Epochs and Evoked object do not contain " "the same time points." 
+            )

        # handle SSPs
        if not self.proj and evoked.proj:
-            warn('Evoked has SSP applied while Epochs has not.')
+            warn("Evoked has SSP applied while Epochs has not.")
        if self.proj and not evoked.proj:
            evoked = evoked.copy().apply_proj()

@@ -927,10 +1052,11 @@ def subtract_evoked(self, evoked=None):
            self._data[:, ep_picks, :] -= evoked.data[picks][None, :, :]
        else:
            if self._offset is None:
-                self._offset = np.zeros((len(self.ch_names), len(self.times)),
-                                        dtype=np.float64)
+                self._offset = np.zeros(
+                    (len(self.ch_names), len(self.times)), dtype=np.float64
+                )
            self._offset[ep_picks] -= evoked.data[picks]
-        logger.info('[done]')
+        logger.info("[done]")

        return self

@@ -978,8 +1104,7 @@ def average(self, picks=None, method="mean", by_event_type=False):
        if by_event_type:
            evokeds = list()
            for event_type in self.event_id.keys():
-                ev = self[event_type]._compute_aggregate(picks=picks,
-                                                         mode=method)
+                ev = self[event_type]._compute_aggregate(picks=picks, mode=method)
                ev.comment = event_type
                evokeds.append(ev)
        else:
@@ -999,39 +1124,43 @@ def standard_error(self, picks=None, by_event_type=False):
        -------
        %(std_err_by_event_type_returns)s
        """
-        return self.average(picks=picks, method="std",
-                            by_event_type=by_event_type)
+        return self.average(picks=picks, method="std", by_event_type=by_event_type)

-    def _compute_aggregate(self, picks, mode='mean'):
+    def _compute_aggregate(self, picks, mode="mean"):
        """Compute the mean, median, or std over epochs and return Evoked."""
        # if instance contains ICA channels they won't be included unless picks
        # is specified
        if picks is None:
-            check_ICA = [x.startswith('ICA') for x in self.ch_names]
+            check_ICA = [x.startswith("ICA") for x in self.ch_names]
            if np.all(check_ICA):
-                raise TypeError('picks must be specified (i.e. not None) for '
-                                'ICA channel data')
+                raise TypeError(
+                    "picks must be specified (i.e. not None) for ICA channel data"
+                )
            elif np.any(check_ICA):
-                warn('ICA channels will not be included unless explicitly '
-                     'selected in picks')
+                warn(
+                    "ICA channels will not be included unless explicitly "
+                    "selected in picks"
+                )

        n_channels = len(self.ch_names)
        n_times = len(self.times)

        if self.preload:
            n_events = len(self.events)
-            fun = _check_combine(mode, valid=('mean', 'median', 'std'))
+            fun = _check_combine(mode, valid=("mean", "median", "std"))
            data = fun(self._data)
            assert len(self.events) == len(self._data)
            if data.shape != self._data.shape[1:]:
                raise RuntimeError(
-                    'You passed a function that resulted n data of shape {}, '
-                    'but it should be {}.'.format(
-                        data.shape, self._data.shape[1:]))
+                    "You passed a function that resulted in data of shape {}, "
+                    "but it should be {}.".format(data.shape, self._data.shape[1:])
+                )
        else:
            if mode not in {"mean", "std"}:
-                raise ValueError("If data are not preloaded, can only compute "
-                                 "mean or standard deviation.")
+                raise ValueError(
+                    "If data are not preloaded, can only compute "
+                    "mean or standard deviation."
+                )
            data = np.zeros((n_channels, n_times))
            n_events = 0
            for e in self:
@@ -1049,26 +1178,27 @@ def _compute_aggregate(self, picks, mode='mean'):
        # two (slower) in case there are large numbers
        if mode == "std":
            data_mean = data.copy()
-            data.fill(0.)
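             # The zeroed buffer is reused below for a second pass: with the
             # per-channel mean kept in `data_mean`, the loop accumulates
             # squared deviations and takes sqrt(... / n_events). Standalone,
             # the same two-pass scheme is roughly (illustrative names, not
             # MNE API):
             #     mean = sum(epochs_iter()) / n_events
             #     std = np.sqrt(sum((e - mean) ** 2
             #                       for e in epochs_iter()) / n_events)
             # The division by sqrt(n_events) a few lines down then turns the
             # result into the standard error reported by `standard_error`.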
+ data.fill(0.0) for e in self: data += (e - data_mean) ** 2 data = np.sqrt(data / n_events) if mode == "std": - kind = 'standard_error' + kind = "standard_error" data /= np.sqrt(n_events) else: kind = "average" - return self._evoked_from_epoch_data(data, self.info, picks, n_events, - kind, self._name) + return self._evoked_from_epoch_data( + data, self.info, picks, n_events, kind, self._name + ) @property def _name(self): """Give a nice string representation based on event ids.""" return self._get_name() - def _get_name(self, count='frac', ms='×', sep='+'): + def _get_name(self, count="frac", ms="×", sep="+"): """Generate human-readable name for epochs and evokeds from event_id. Parameters @@ -1084,7 +1214,7 @@ def _get_name(self, count='frac', ms='×', sep='+'): How to separate the different events names. Ignored if only one event type is present. """ - _check_option('count', value=count, allowed_values=['frac', 'total']) + _check_option("count", value=count, allowed_values=["frac", "total"]) if len(self.event_id) == 1: comment = next(iter(self.event_id.keys())) @@ -1094,28 +1224,34 @@ def _get_name(self, count='frac', ms='×', sep='+'): # Take care of padding if ms is None: - ms = ' ' + ms = " " else: - ms = f' {ms} ' + ms = f" {ms} " for event_name, event_code in self.event_id.items(): - if count == 'frac': + if count == "frac": frac = float(counter[event_code]) / len(self.events) - comment = f'{frac:.2f}{ms}{event_name}' + comment = f"{frac:.2f}{ms}{event_name}" else: # 'total' - comment = f'{counter[event_code]}{ms}{event_name}' + comment = f"{counter[event_code]}{ms}{event_name}" comments.append(comment) - comment = f' {sep} '.join(comments) + comment = f" {sep} ".join(comments) return comment - def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, - comment): + def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, comment): """Create an evoked object from epoch data.""" info = deepcopy(info) # don't apply baseline correction; we'll set evoked.baseline manually - evoked = EvokedArray(data, info, tmin=self.times[0], comment=comment, - nave=n_events, kind=kind, baseline=None) + evoked = EvokedArray( + data, + info, + tmin=self.times[0], + comment=comment, + nave=n_events, + kind=kind, + baseline=None, + ) evoked.baseline = self.baseline # the above constructor doesn't recreate the times object precisely @@ -1123,58 +1259,116 @@ def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, evoked._set_times(self.times.copy()) # pick channels - picks = _picks_to_idx(self.info, picks, 'data_or_ica', ()) + picks = _picks_to_idx(self.info, picks, "data_or_ica", ()) ch_names = [evoked.ch_names[p] for p in picks] evoked.pick_channels(ch_names) - if len(evoked.info['ch_names']) == 0: - raise ValueError('No data channel found when averaging.') + if len(evoked.info["ch_names"]) == 0: + raise ValueError("No data channel found when averaging.") if evoked.nave < 1: - warn('evoked object is empty (based on less than 1 epoch)') + warn("evoked object is empty (based on less than 1 epoch)") return evoked @property def ch_names(self): """Channel names.""" - return self.info['ch_names'] + return self.info["ch_names"] @copy_function_doc_to_method_doc(plot_epochs) - def plot(self, picks=None, scalings=None, n_epochs=20, n_channels=20, - title=None, events=None, event_color=None, - order=None, show=True, block=False, decim='auto', noise_cov=None, - butterfly=False, show_scrollbars=True, show_scalebars=True, - epoch_colors=None, event_id=None, group_by='type', - 
precompute=None, use_opengl=None, *, theme=None, - overview_mode=None): - return plot_epochs(self, picks=picks, scalings=scalings, - n_epochs=n_epochs, n_channels=n_channels, - title=title, events=events, event_color=event_color, - order=order, show=show, block=block, decim=decim, - noise_cov=noise_cov, butterfly=butterfly, - show_scrollbars=show_scrollbars, - show_scalebars=show_scalebars, - epoch_colors=epoch_colors, event_id=event_id, - group_by=group_by, precompute=precompute, - use_opengl=use_opengl, theme=theme, - overview_mode=overview_mode) + def plot( + self, + picks=None, + scalings=None, + n_epochs=20, + n_channels=20, + title=None, + events=None, + event_color=None, + order=None, + show=True, + block=False, + decim="auto", + noise_cov=None, + butterfly=False, + show_scrollbars=True, + show_scalebars=True, + epoch_colors=None, + event_id=None, + group_by="type", + precompute=None, + use_opengl=None, + *, + theme=None, + overview_mode=None, + ): + return plot_epochs( + self, + picks=picks, + scalings=scalings, + n_epochs=n_epochs, + n_channels=n_channels, + title=title, + events=events, + event_color=event_color, + order=order, + show=show, + block=block, + decim=decim, + noise_cov=noise_cov, + butterfly=butterfly, + show_scrollbars=show_scrollbars, + show_scalebars=show_scalebars, + epoch_colors=epoch_colors, + event_id=event_id, + group_by=group_by, + precompute=precompute, + use_opengl=use_opengl, + theme=theme, + overview_mode=overview_mode, + ) @copy_function_doc_to_method_doc(plot_topo_image_epochs) - def plot_topo_image(self, layout=None, sigma=0., vmin=None, vmax=None, - colorbar=None, order=None, cmap='RdBu_r', - layout_scale=.95, title=None, scalings=None, - border='none', fig_facecolor='k', fig_background=None, - font_color='w', show=True): + def plot_topo_image( + self, + layout=None, + sigma=0.0, + vmin=None, + vmax=None, + colorbar=None, + order=None, + cmap="RdBu_r", + layout_scale=0.95, + title=None, + scalings=None, + border="none", + fig_facecolor="k", + fig_background=None, + font_color="w", + show=True, + ): return plot_topo_image_epochs( - self, layout=layout, sigma=sigma, vmin=vmin, vmax=vmax, - colorbar=colorbar, order=order, cmap=cmap, - layout_scale=layout_scale, title=title, scalings=scalings, - border=border, fig_facecolor=fig_facecolor, - fig_background=fig_background, font_color=font_color, show=show) + self, + layout=layout, + sigma=sigma, + vmin=vmin, + vmax=vmax, + colorbar=colorbar, + order=order, + cmap=cmap, + layout_scale=layout_scale, + title=title, + scalings=scalings, + border=border, + fig_facecolor=fig_facecolor, + fig_background=fig_background, + font_color=font_color, + show=show, + ) @verbose - def drop_bad(self, reject='existing', flat='existing', verbose=None): + def drop_bad(self, reject="existing", flat="existing", verbose=None): """Drop bad epochs without retaining the epochs data. Should be used before slicing operations. @@ -1206,20 +1400,19 @@ def drop_bad(self, reject='existing', flat='existing', verbose=None): subsequently be applied, `epochs.copy ` should be used. 
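For instance, with a purely illustrative 100 µV peak-to-peak
        threshold for EEG (``reject`` values are given in volts for EEG
        and EOG channels):

        >>> epochs.drop_bad(reject=dict(eeg=100e-6))  # doctest: +SKIP
        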
""" - if reject == 'existing': - if flat == 'existing' and self._bad_dropped: + if reject == "existing": + if flat == "existing" and self._bad_dropped: return reject = self.reject - if flat == 'existing': + if flat == "existing": flat = self.flat - if any(isinstance(rej, str) and rej != 'existing' for - rej in (reject, flat)): + if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)): raise ValueError('reject and flat, if strings, must be "existing"') self._reject_setup(reject, flat) self._get_data(out=False, verbose=verbose) return self - def drop_log_stats(self, ignore=('IGNORED',)): + def drop_log_stats(self, ignore=("IGNORED",)): """Compute the channel stats based on a drop_log from Epochs. Parameters @@ -1239,33 +1432,81 @@ def drop_log_stats(self, ignore=('IGNORED',)): return _drop_log_stats(self.drop_log, ignore) @copy_function_doc_to_method_doc(plot_drop_log) - def plot_drop_log(self, threshold=0, n_max_plot=20, subject=None, - color=(0.9, 0.9, 0.9), width=0.8, ignore=('IGNORED',), - show=True): + def plot_drop_log( + self, + threshold=0, + n_max_plot=20, + subject=None, + color=(0.9, 0.9, 0.9), + width=0.8, + ignore=("IGNORED",), + show=True, + ): if not self._bad_dropped: - raise ValueError("You cannot use plot_drop_log since bad " - "epochs have not yet been dropped. " - "Use epochs.drop_bad().") - return plot_drop_log(self.drop_log, threshold, n_max_plot, subject, - color=color, width=width, ignore=ignore, - show=show) + raise ValueError( + "You cannot use plot_drop_log since bad " + "epochs have not yet been dropped. " + "Use epochs.drop_bad()." + ) + return plot_drop_log( + self.drop_log, + threshold, + n_max_plot, + subject, + color=color, + width=width, + ignore=ignore, + show=show, + ) @copy_function_doc_to_method_doc(plot_epochs_image) - def plot_image(self, picks=None, sigma=0., vmin=None, vmax=None, - colorbar=True, order=None, show=True, units=None, - scalings=None, cmap=None, fig=None, axes=None, - overlay_times=None, combine=None, group_by=None, - evoked=True, ts_args=None, title=None, clear=False): - return plot_epochs_image(self, picks=picks, sigma=sigma, vmin=vmin, - vmax=vmax, colorbar=colorbar, order=order, - show=show, units=units, scalings=scalings, - cmap=cmap, fig=fig, axes=axes, - overlay_times=overlay_times, combine=combine, - group_by=group_by, evoked=evoked, - ts_args=ts_args, title=title, clear=clear) + def plot_image( + self, + picks=None, + sigma=0.0, + vmin=None, + vmax=None, + colorbar=True, + order=None, + show=True, + units=None, + scalings=None, + cmap=None, + fig=None, + axes=None, + overlay_times=None, + combine=None, + group_by=None, + evoked=True, + ts_args=None, + title=None, + clear=False, + ): + return plot_epochs_image( + self, + picks=picks, + sigma=sigma, + vmin=vmin, + vmax=vmax, + colorbar=colorbar, + order=order, + show=show, + units=units, + scalings=scalings, + cmap=cmap, + fig=fig, + axes=axes, + overlay_times=overlay_times, + combine=combine, + group_by=group_by, + evoked=evoked, + ts_args=ts_args, + title=title, + clear=clear, + ) @verbose - def drop(self, indices, reason='USER', verbose=None): + def drop(self, indices, reason="USER", verbose=None): """Drop epochs based on indices or boolean mask. .. 
note:: The indices refer to the current set of undropped epochs @@ -1309,8 +1550,10 @@ def drop(self, indices, reason='USER', verbose=None): keep = np.setdiff1d(np.arange(len(self.events)), try_idx) self._getitem(keep, reason, copy=False, drop_event_id=False) count = len(try_idx) - logger.info('Dropped %d epoch%s: %s' % - (count, _pl(count), ', '.join(map(str, np.sort(try_idx))))) + logger.info( + "Dropped %d epoch%s: %s" + % (count, _pl(count), ", ".join(map(str, np.sort(try_idx)))) + ) return self @@ -1330,8 +1573,17 @@ def _project_epoch(self, epoch): return epoch @verbose - def _get_data(self, out=True, picks=None, item=None, *, units=None, - tmin=None, tmax=None, verbose=None): + def _get_data( + self, + out=True, + picks=None, + item=None, + *, + units=None, + tmin=None, + tmax=None, + verbose=None, + ): """Load all data, dropping bad epochs along the way. Parameters @@ -1354,13 +1606,19 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, if not out: # make sure first and last epoch not out of bounds of raw in_bounds = self.preload or ( - self._get_epoch_from_raw(idx=0) is not None and - self._get_epoch_from_raw(idx=-1) is not None) + self._get_epoch_from_raw(idx=0) is not None + and self._get_epoch_from_raw(idx=-1) is not None + ) # might be BaseEpochs or Epochs, only the latter has the attribute - reject_by_annotation = getattr(self, 'reject_by_annotation', False) - if (self.reject is None and self.flat is None and in_bounds and - self._reject_time is None and not reject_by_annotation): - logger.debug('_get_data is a noop, returning') + reject_by_annotation = getattr(self, "reject_by_annotation", False) + if ( + self.reject is None + and self.flat is None + and in_bounds + and self._reject_time is None + and not reject_by_annotation + ): + logger.debug("_get_data is a noop, returning") self._bad_dropped = True return None start, stop = self._handle_tmin_tmax(tmin, tmax) @@ -1369,8 +1627,9 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, item = slice(None) elif not self._bad_dropped: raise ValueError( - 'item must be None in epochs.get_data() unless bads have been ' - 'dropped. Consider using epochs.drop_bad().') + "item must be None in epochs.get_data() unless bads have been " + "dropped. Consider using epochs.drop_bad()." 
+ ) select = self._item_to_select(item) # indices or slice use_idx = np.arange(len(self.events))[select] n_events = len(use_idx) @@ -1380,15 +1639,17 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, data = self._data else: # we start out with an empty array, allocate only if necessary - data = np.empty((0, len(self.info['ch_names']), len(self.times))) - msg = (f'for {n_events} events and {len(self._raw_times)} ' - 'original time points') + data = np.empty((0, len(self.info["ch_names"]), len(self.times))) + msg = ( + f"for {n_events} events and {len(self._raw_times)} " + "original time points" + ) if self._decim > 1: - msg += ' (prior to decimation)' + msg += " (prior to decimation)" if getattr(self._raw, "preload", False): - logger.info(f'Using data from preloaded Raw {msg} ...') + logger.info(f"Using data from preloaded Raw {msg} ...") else: - logger.info(f'Loading data {msg} ...') + logger.info(f"Loading data {msg} ...") orig_picks = picks if orig_picks is None: @@ -1418,15 +1679,16 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, for ii, idx in enumerate(use_idx): # faster to pre-allocate memory here epoch_noproj = self._get_epoch_from_raw(idx) - epoch_noproj = self._detrend_offset_decim( - epoch_noproj, detrend_picks) + epoch_noproj = self._detrend_offset_decim(epoch_noproj, detrend_picks) if self._do_delayed_proj: epoch_out = epoch_noproj else: epoch_out = self._project_epoch(epoch_noproj) if ii == 0: - data = np.empty((n_events, len(self.ch_names), - len(self.times)), dtype=epoch_out.dtype) + data = np.empty( + (n_events, len(self.ch_names), len(self.times)), + dtype=epoch_out.dtype, + ) data[ii] = epoch_out else: # bads need to be dropped, this might occur after a preload @@ -1448,12 +1710,12 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, else: # from disk epoch_noproj = self._get_epoch_from_raw(idx) epoch_noproj = self._detrend_offset_decim( - epoch_noproj, detrend_picks) + epoch_noproj, detrend_picks + ) epoch = self._project_epoch(epoch_noproj) epoch_out = epoch_noproj if self._do_delayed_proj else epoch - is_good, bad_tuple = self._is_good_epoch( - epoch, verbose=verbose) + is_good, bad_tuple = self._is_good_epoch(epoch, verbose=verbose) if not is_good: assert isinstance(bad_tuple, tuple) assert all(isinstance(x, str) for x in bad_tuple) @@ -1465,9 +1727,11 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, if out or self.preload: # faster to pre-allocate, then trim as necessary if n_out == 0 and not self.preload: - data = np.empty((n_events, epoch_out.shape[0], - epoch_out.shape[1]), - dtype=epoch_out.dtype, order='C') + data = np.empty( + (n_events, epoch_out.shape[0], epoch_out.shape[1]), + dtype=epoch_out.dtype, + order="C", + ) data[n_out] = epoch_out n_out += 1 self.drop_log = tuple(drop_log) @@ -1478,7 +1742,7 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, # adjust the data size if there is a reason to (output or update) if out or self.preload: - if data.flags['OWNDATA'] and data.flags['C_CONTIGUOUS']: + if data.flags["OWNDATA"] and data.flags["C_CONTIGUOUS"]: data.resize((n_out,) + data.shape[1:], refcheck=False) else: data = data[:n_out] @@ -1486,8 +1750,9 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, self._data = data # Now update our properties (excepd data, which is already fixed) - self._getitem(good_idx, None, copy=False, drop_event_id=False, - select_data=False) + self._getitem( + good_idx, None, copy=False, drop_event_id=False, 
select_data=False + ) if out: if orig_picks is not None: @@ -1504,13 +1769,13 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, def _detrend_picks(self): if self._do_baseline: return _pick_data_channels( - self.info, with_ref_meg=True, with_aux=True, exclude=()) + self.info, with_ref_meg=True, with_aux=True, exclude=() + ) else: return [] @fill_doc - def get_data(self, picks=None, item=None, units=None, tmin=None, - tmax=None): + def get_data(self, picks=None, item=None, units=None, tmin=None, tmax=None): """Get all epochs as a 3D array. Parameters @@ -1541,12 +1806,19 @@ def get_data(self, picks=None, item=None, units=None, tmin=None, data : array of shape (n_epochs, n_channels, n_times) A view on epochs data. """ - return self._get_data(picks=picks, item=item, units=units, tmin=tmin, - tmax=tmax) + return self._get_data(picks=picks, item=item, units=units, tmin=tmin, tmax=tmax) @verbose - def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, - channel_wise=True, verbose=None, **kwargs): + def apply_function( + self, + fun, + picks=None, + dtype=None, + n_jobs=None, + channel_wise=True, + verbose=None, + **kwargs, + ): """Apply a function to a subset of channels. %(applyfun_summary_epochs)s @@ -1567,11 +1839,11 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, self : instance of Epochs The epochs object with transformed data. """ - _check_preload(self, 'epochs.apply_function') + _check_preload(self, "epochs.apply_function") picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False) if not callable(fun): - raise ValueError('fun needs to be a function') + raise ValueError("fun needs to be a function") data_in = self._data if dtype is not None and dtype != self._data.dtype: @@ -1584,11 +1856,13 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, # modify data inplace to save memory for idx in picks: self._data[:, idx, :] = np.apply_along_axis( - _fun, -1, data_in[:, idx, :]) + _fun, -1, data_in[:, idx, :] + ) else: # use parallel function - data_picks_new = parallel(p_fun( - fun, data_in[:, p, :], **kwargs) for p in picks) + data_picks_new = parallel( + p_fun(fun, data_in[:, p, :], **kwargs) for p in picks + ) for pp, p in enumerate(picks): self._data[:, p, :] = data_picks_new[pp] else: @@ -1603,60 +1877,66 @@ def filename(self): def __repr__(self): """Build string representation.""" - s = ' %s events ' % len(self.events) - s += '(all good)' if self._bad_dropped else '(good & bad)' - s += ', %g – %g s' % (self.tmin, self.tmax) - s += ', baseline ' + s = " %s events " % len(self.events) + s += "(all good)" if self._bad_dropped else "(good & bad)" + s += ", %g – %g s" % (self.tmin, self.tmax) + s += ", baseline " if self.baseline is None: - s += 'off' + s += "off" else: - s += f'{self.baseline[0]:g} – {self.baseline[1]:g} s' + s += f"{self.baseline[0]:g} – {self.baseline[1]:g} s" if self.baseline != _check_baseline( - self.baseline, times=self.times, sfreq=self.info['sfreq'], - on_baseline_outside_data='adjust'): - s += ' (baseline period was cropped after baseline correction)' - - s += ', ~%s' % (sizeof_fmt(self._size),) - s += ', data%s loaded' % ('' if self.preload else ' not') - s += ', with metadata' if self.metadata is not None else '' + self.baseline, + times=self.times, + sfreq=self.info["sfreq"], + on_baseline_outside_data="adjust", + ): + s += " (baseline period was cropped after baseline correction)" + + s += ", ~%s" % (sizeof_fmt(self._size),) + s += ", data%s loaded" % ("" if self.preload else " 
not") + s += ", with metadata" if self.metadata is not None else "" max_events = 10 - counts = ['%r: %i' % (k, sum(self.events[:, 2] == v)) - for k, v in list(self.event_id.items())[:max_events]] + counts = [ + "%r: %i" % (k, sum(self.events[:, 2] == v)) + for k, v in list(self.event_id.items())[:max_events] + ] if len(self.event_id) > 0: - s += ',' + '\n '.join([''] + counts) + s += "," + "\n ".join([""] + counts) if len(self.event_id) > max_events: not_shown_events = len(self.event_id) - max_events s += f"\n and {not_shown_events} more events ..." class_name = self.__class__.__name__ - class_name = 'Epochs' if class_name == 'BaseEpochs' else class_name - return '<%s | %s>' % (class_name, s) + class_name = "Epochs" if class_name == "BaseEpochs" else class_name + return "<%s | %s>" % (class_name, s) @repr_html def _repr_html_(self): from .html_templates import repr_templates_env + if self.baseline is None: - baseline = 'off' + baseline = "off" else: - baseline = tuple([f'{b:.3f}' for b in self.baseline]) - baseline = f'{baseline[0]} – {baseline[1]} s' + baseline = tuple([f"{b:.3f}" for b in self.baseline]) + baseline = f"{baseline[0]} – {baseline[1]} s" if isinstance(self.event_id, dict): event_strings = [] for k, v in sorted(self.event_id.items()): n_events = sum(self.events[:, 2] == v) - event_strings.append(f'{k}: {n_events}') + event_strings.append(f"{k}: {n_events}") elif isinstance(self.event_id, list): event_strings = [] for k in self.event_id: n_events = sum(self.events[:, 2] == k) - event_strings.append(f'{k}: {n_events}') + event_strings.append(f"{k}: {n_events}") elif isinstance(self.event_id, int): n_events = len(self.events[:, 2]) - event_strings = [f'{self.event_id}: {n_events}'] + event_strings = [f"{self.event_id}: {n_events}"] else: event_strings = None - t = repr_templates_env.get_template('epochs.html.jinja') + t = repr_templates_env.get_template("epochs.html.jinja") t = t.render(epochs=self, baseline=baseline, events=event_strings) return t @@ -1683,20 +1963,22 @@ def crop(self, tmin=None, tmax=None, include_tmax=True, verbose=None): %(notes_tmax_included_by_default)s """ # XXX this could be made to work on non-preloaded data... - _check_preload(self, 'Modifying data of epochs') + _check_preload(self, "Modifying data of epochs") super().crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax) # Adjust rejection period if self.reject_tmin is not None and self.reject_tmin < self.tmin: logger.info( - f'reject_tmin is not in epochs time interval. ' - f'Setting reject_tmin to epochs.tmin ({self.tmin} s)') + f"reject_tmin is not in epochs time interval. " + f"Setting reject_tmin to epochs.tmin ({self.tmin} s)" + ) self.reject_tmin = self.tmin if self.reject_tmax is not None and self.reject_tmax > self.tmax: logger.info( - f'reject_tmax is not in epochs time interval. ' - f'Setting reject_tmax to epochs.tmax ({self.tmax} s)') + f"reject_tmax is not in epochs time interval. 
" + f"Setting reject_tmax to epochs.tmax ({self.tmax} s)" + ) self.reject_tmax = self.tmax return self @@ -1717,7 +1999,7 @@ def __deepcopy__(self, memodict): for k, v in self.__dict__.items(): # drop_log is immutable and _raw is private (and problematic to # deepcopy) - if k in ('drop_log', '_raw', '_times_readonly'): + if k in ("drop_log", "_raw", "_times_readonly"): memodict[id(v)] = v else: v = deepcopy(v, memodict) @@ -1725,8 +2007,15 @@ def __deepcopy__(self, memodict): return result @verbose - def save(self, fname, split_size='2GB', fmt='single', overwrite=False, - split_naming='neuromag', verbose=None): + def save( + self, + fname, + split_size="2GB", + fmt="single", + overwrite=False, + split_naming="neuromag", + verbose=None, + ): """Save epochs in a fif file. Parameters @@ -1765,15 +2054,16 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False, ----- Bad epochs will be dropped before saving the epochs to disk. """ - check_fname(fname, 'epochs', ('-epo.fif', '-epo.fif.gz', - '_epo.fif', '_epo.fif.gz')) + check_fname( + fname, "epochs", ("-epo.fif", "-epo.fif.gz", "_epo.fif", "_epo.fif.gz") + ) # check for file existence and expand `~` if present fname = str(_check_fname(fname=fname, overwrite=overwrite)) split_size_bytes = _get_split_size(split_size) - _check_option('fmt', fmt, ['single', 'double']) + _check_option("fmt", fmt, ["single", "double"]) # to know the length accurately. The get_data() call would drop # bad epochs anyway @@ -1781,12 +2071,12 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False, # total_size tracks sizes that get split # over_size tracks overhead (tags, things that get written to each) if len(self) == 0: - warn('Saving epochs with no data') + warn("Saving epochs with no data") total_size = 0 else: d = self[0].get_data() # this should be guaranteed by subclasses - assert d.dtype in ('>f8', 'c16', 'f8", "c16", "= 1, n_parts if n_parts > 1: - logger.info(f'Splitting into {n_parts} parts') + logger.info(f"Splitting into {n_parts} parts") if n_parts > 100: # This must be an error raise ValueError( - f'Split size {split_size} would result in writing ' - f'{n_parts} files') + f"Split size {split_size} would result in writing " + f"{n_parts} files" + ) if len(self.drop_log) > 100000: - warn(f'epochs.drop_log contains {len(self.drop_log)} entries ' - f'which will incur up to a {sizeof_fmt(drop_size)} writing ' - f'overhead (per split file), consider using ' - f'epochs.reset_drop_log_selection() prior to writing') + warn( + f"epochs.drop_log contains {len(self.drop_log)} entries " + f"which will incur up to a {sizeof_fmt(drop_size)} writing " + f"overhead (per split file), consider using " + f"epochs.reset_drop_log_selection() prior to writing" + ) epoch_idxs = np.array_split(np.arange(n_epochs), n_parts) @@ -1857,11 +2150,12 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False, this_epochs = self[epoch_idx] if n_parts > 1 else self # avoid missing event_ids in splits this_epochs.event_id = self.event_id - _save_split(this_epochs, fname, part_idx, n_parts, fmt, - split_naming, overwrite) + _save_split( + this_epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite + ) @verbose - def export(self, fname, fmt='auto', *, overwrite=False, verbose=None): + def export(self, fname, fmt="auto", *, overwrite=False, verbose=None): """Export Epochs to external formats. 
%(export_fmt_support_epochs)s @@ -1885,9 +2179,10 @@ def export(self, fname, fmt='auto', *, overwrite=False, verbose=None): %(export_eeglab_note)s """ from .export import export_epochs + export_epochs(fname, self, fmt, overwrite=overwrite, verbose=verbose) - def equalize_event_counts(self, event_ids=None, method='mintime'): + def equalize_event_counts(self, event_ids=None, method="mintime"): """Equalize the number of trials in each condition. It tries to make the remaining epochs occurring as close as possible in @@ -1960,16 +2255,23 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): event names were specified explicitly. """ from collections.abc import Iterable - _validate_type(event_ids, types=(Iterable, None), - item_name='event_ids', type_name='list-like or None') + + _validate_type( + event_ids, + types=(Iterable, None), + item_name="event_ids", + type_name="list-like or None", + ) if isinstance(event_ids, str): - raise TypeError(f'event_ids must be list-like or None, but ' - f'received a string: {event_ids}') + raise TypeError( + f"event_ids must be list-like or None, but " + f"received a string: {event_ids}" + ) if event_ids is None: event_ids = list(self.event_id) elif not event_ids: - raise ValueError('event_ids must have at least one element') + raise ValueError("event_ids must have at least one element") if not self._bad_dropped: self.drop_bad() @@ -1982,8 +2284,7 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): tagging = False if "/" in "".join(ids): # make string inputs a list of length 1 - event_ids = [[x] if isinstance(x, str) else x - for x in event_ids] + event_ids = [[x] if isinstance(x, str) else x for x in event_ids] for ids_ in event_ids: # check if tagging is attempted if any([id_ not in ids for id_ in ids_]): tagging = True @@ -1991,19 +2292,24 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): # 2a. for tags, find all the event_ids matched by the tags # 2b. for non-tag ids, just pass them directly # 3. do this for every input - event_ids = [[k for k in ids - if all((tag in k.split("/") - for tag in id_))] # ids matching all tags - if all(id__ not in ids for id__ in id_) - else id_ # straight pass for non-tag inputs - for id_ in event_ids] + event_ids = [ + [ + k for k in ids if all((tag in k.split("/") for tag in id_)) + ] # ids matching all tags + if all(id__ not in ids for id__ in id_) + else id_ # straight pass for non-tag inputs + for id_ in event_ids + ] for ii, id_ in enumerate(event_ids): if len(id_) == 0: - raise KeyError(f"{orig_ids[ii]} not found in the epoch " - "object's event_id.") + raise KeyError( + f"{orig_ids[ii]} not found in the epoch " "object's event_id." + ) elif len({sub_id in ids for sub_id in id_}) != 1: - err = ("Don't mix hierarchical and regular event_ids" - " like in \'%s\'." % ", ".join(id_)) + err = ( + "Don't mix hierarchical and regular event_ids" + " like in '%s'." % ", ".join(id_) + ) raise ValueError(err) # raise for non-orthogonal tags @@ -2011,9 +2317,11 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): events_ = [set(self[x].events[:, 0]) for x in event_ids] doubles = events_[0].intersection(events_[1]) if len(doubles): - raise ValueError("The two sets of epochs are " - "overlapping. Provide an " - "orthogonal selection.") + raise ValueError( + "The two sets of epochs are " + "overlapping. Provide an " + "orthogonal selection." 
+ ) for eq in event_ids: eq_inds.append(self._keys_to_idx(eq)) @@ -2022,14 +2330,25 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): indices = _get_drop_indices(event_times, method) # need to re-index indices indices = np.concatenate([e[idx] for e, idx in zip(eq_inds, indices)]) - self.drop(indices, reason='EQUALIZED_COUNT') + self.drop(indices, reason="EQUALIZED_COUNT") # actually remove the indices return self, indices @verbose - def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, - tmax=None, picks=None, proj=False, *, n_jobs=1, - verbose=None, **method_kw): + def compute_psd( + self, + method="multitaper", + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + *, + n_jobs=1, + verbose=None, + **method_kw, + ): """Perform spectral analysis on sensor data. Parameters @@ -2061,17 +2380,47 @@ def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, self._set_legacy_nfft_default(tmin, tmax, method, method_kw) return EpochsSpectrum( - self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, - picks=picks, proj=proj, n_jobs=n_jobs, verbose=verbose, - **method_kw) + self, + method=method, + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, - proj=False, *, method='auto', average=False, dB=True, - estimate='auto', xscale='linear', area_mode='std', - area_alpha=0.33, color='black', line_alpha=None, - spatial_colors=True, sphere=None, exclude='bads', ax=None, - show=True, n_jobs=1, verbose=None, **method_kw): + def plot_psd( + self, + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + *, + method="auto", + average=False, + dB=True, + estimate="auto", + xscale="linear", + area_mode="std", + area_alpha=0.33, + color="black", + line_alpha=None, + spatial_colors=True, + sphere=None, + exclude="bads", + ax=None, + show=True, + n_jobs=1, + verbose=None, + **method_kw, + ): """%(plot_psd_doc)s. Parameters @@ -2115,17 +2464,44 @@ def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, %(notes_plot_psd_meth)s """ return super().plot_psd( - fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, proj=proj, - reject_by_annotation=False, method=method, average=average, dB=dB, - estimate=estimate, xscale=xscale, area_mode=area_mode, - area_alpha=area_alpha, color=color, line_alpha=line_alpha, - spatial_colors=spatial_colors, sphere=sphere, exclude=exclude, - ax=ax, show=show, n_jobs=n_jobs, verbose=verbose, **method_kw) + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=False, + method=method, + average=average, + dB=dB, + estimate=estimate, + xscale=xscale, + area_mode=area_mode, + area_alpha=area_alpha, + color=color, + line_alpha=line_alpha, + spatial_colors=spatial_colors, + sphere=sphere, + exclude=exclude, + ax=ax, + show=show, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def to_data_frame(self, picks=None, index=None, - scalings=None, copy=True, long_format=False, - time_format=None, *, verbose=None): + def to_data_frame( + self, + picks=None, + index=None, + scalings=None, + copy=True, + long_format=False, + time_format=None, + *, + verbose=None, + ): """Export data in tabular structure as a pandas DataFrame. Channels are converted to columns in the DataFrame. 
By default, @@ -2155,12 +2531,12 @@ def to_data_frame(self, picks=None, index=None, # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # arg checking - valid_index_args = ['time', 'epoch', 'condition'] - valid_time_formats = ['ms', 'timedelta'] + valid_index_args = ["time", "epoch", "condition"] + valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data - picks = _picks_to_idx(self.info, picks, 'all', exclude=()) + picks = _picks_to_idx(self.info, picks, "all", exclude=()) data = self.get_data()[:, picks, :] times = self.times n_epochs, n_picks, n_times = data.shape @@ -2172,18 +2548,25 @@ def to_data_frame(self, picks=None, index=None, mindex = list() times = np.tile(times, n_epochs) times = _convert_times(self, times, time_format) - mindex.append(('time', times)) + mindex.append(("time", times)) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(('condition', np.repeat(conditions, n_times))) - mindex.append(('epoch', np.repeat(self.selection, n_times))) + mindex.append(("condition", np.repeat(conditions, n_times))) + mindex.append(("epoch", np.repeat(self.selection, n_times))) assert all(len(mdx) == len(mindex[0]) for mdx in mindex) # build DataFrame - df = _build_data_frame(self, data, picks, long_format, mindex, index, - default_index=['condition', 'epoch', 'time']) + df = _build_data_frame( + self, + data, + picks, + long_format, + mindex, + index, + default_index=["condition", "epoch", "time"], + ) return df - def as_type(self, ch_type='grad', mode='fast'): + def as_type(self, ch_type="grad", mode="fast"): """Compute virtual epochs using interpolated fields. .. Warning:: Using virtual epochs to compute inverse can yield @@ -2213,10 +2596,11 @@ def as_type(self, ch_type='grad', mode='fast'): .. versionadded:: 0.20.0 """ from .forward import _as_meg_type_inst + return _as_meg_type_inst(self, ch_type=ch_type, mode=mode) -def _drop_log_stats(drop_log, ignore=('IGNORED',)): +def _drop_log_stats(drop_log, ignore=("IGNORED",)): """Compute drop log stats. Parameters @@ -2231,17 +2615,28 @@ def _drop_log_stats(drop_log, ignore=('IGNORED',)): perc : float Total percentage of epochs dropped. """ - if not isinstance(drop_log, tuple) or \ - not all(isinstance(d, tuple) for d in drop_log) or \ - not all(isinstance(s, str) for d in drop_log for s in d): - raise TypeError('drop_log must be a tuple of tuple of str') - perc = 100 * np.mean([len(d) > 0 for d in drop_log - if not any(r in ignore for r in d)]) + if ( + not isinstance(drop_log, tuple) + or not all(isinstance(d, tuple) for d in drop_log) + or not all(isinstance(s, str) for d in drop_log for s in d) + ): + raise TypeError("drop_log must be a tuple of tuple of str") + perc = 100 * np.mean( + [len(d) > 0 for d in drop_log if not any(r in ignore for r in d)] + ) return perc -def make_metadata(events, event_id, tmin, tmax, sfreq, - row_events=None, keep_first=None, keep_last=None): +def make_metadata( + events, + event_id, + tmin, + tmax, + sfreq, + row_events=None, + keep_first=None, + keep_last=None, +): """Generate metadata from events for use with `mne.Epochs`. 
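A minimal usage sketch (the event names, time window, and returned
    triplet here are illustrative):

    >>> metadata, events_out, event_id_out = make_metadata(  # doctest: +SKIP
    ...     events=events, event_id={'stim': 1, 'response': 2},
    ...     tmin=-0.2, tmax=0.5, sfreq=raw.info['sfreq'],
    ...     row_events='stim')

    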
This function mimics the epoching process (it constructs time windows @@ -2369,16 +2764,13 @@ def make_metadata(events, event_id, tmin, tmax, sfreq, """ pd = _check_pandas_installed() - _validate_type(event_id, types=(dict,), item_name='event_id') - _validate_type(row_events, types=(None, str, list, tuple), - item_name='row_events') - _validate_type(keep_first, types=(None, str, list, tuple), - item_name='keep_first') - _validate_type(keep_last, types=(None, str, list, tuple), - item_name='keep_last') + _validate_type(event_id, types=(dict,), item_name="event_id") + _validate_type(row_events, types=(None, str, list, tuple), item_name="row_events") + _validate_type(keep_first, types=(None, str, list, tuple), item_name="keep_first") + _validate_type(keep_last, types=(None, str, list, tuple), item_name="keep_last") if not event_id: - raise ValueError('event_id dictionary must contain at least one entry') + raise ValueError("event_id dictionary must contain at least one entry") def _ensure_list(x): if x is None: @@ -2394,26 +2786,29 @@ def _ensure_list(x): keep_first_and_last = set(keep_first) & set(keep_last) if keep_first_and_last: - raise ValueError(f'The event names in keep_first and keep_last must ' - f'be mutually exclusive. Specified in both: ' - f'{", ".join(sorted(keep_first_and_last))}') + raise ValueError( + f"The event names in keep_first and keep_last must " + f"be mutually exclusive. Specified in both: " + f'{", ".join(sorted(keep_first_and_last))}' + ) del keep_first_and_last - for param_name, values in dict(keep_first=keep_first, - keep_last=keep_last).items(): + for param_name, values in dict(keep_first=keep_first, keep_last=keep_last).items(): for first_last_event_name in values: try: match_event_names(event_id, [first_last_event_name]) except KeyError: raise ValueError( f'Event "{first_last_event_name}", specified in ' - f'{param_name}, cannot be found in event_id dictionary') + f"{param_name}, cannot be found in event_id dictionary" + ) event_name_diff = sorted(set(row_events) - set(event_id.keys())) if event_name_diff: raise ValueError( - f'Present in row_events, but missing from event_id: ' - f'{", ".join(event_name_diff)}') + f"Present in row_events, but missing from event_id: " + f'{", ".join(event_name_diff)}' + ) del event_name_diff # First and last sample of each epoch, relative to the time-locked event @@ -2425,12 +2820,12 @@ def _ensure_list(x): # We create the DataFrame before subsetting the events so we end up with # indices corresponding to the original event indices. 
Not used for now,
    # but might come in handy sometime later
-    events_df = pd.DataFrame(events, columns=('sample', 'prev_id', 'id'))
+    events_df = pd.DataFrame(events, columns=("sample", "prev_id", "id"))
    id_to_name_map = {v: k for k, v in event_id.items()}

    # Only keep events that are of interest
    events = events[np.in1d(events[:, 2], list(event_id.values()))]
-    events_df = events_df.loc[events_df['id'].isin(event_id.values()), :]
+    events_df = events_df.loc[events_df["id"].isin(event_id.values()), :]

    # Prepare & condition the metadata DataFrame

    # Avoid column name duplications if the exact same event name appears in
    # event_id.keys() and keep_first / keep_last simultaneously
    keep_first_cols = [col for col in keep_first if col not in event_id]
    keep_last_cols = [col for col in keep_last if col not in event_id]
-    first_cols = [f'first_{col}' for col in keep_first_cols]
-    last_cols = [f'last_{col}' for col in keep_last_cols]
-
-    columns = ['event_name',
-               *event_id.keys(),
-               *keep_first_cols,
-               *keep_last_cols,
-               *first_cols,
-               *last_cols]
+    first_cols = [f"first_{col}" for col in keep_first_cols]
+    last_cols = [f"last_{col}" for col in keep_last_cols]
+
+    columns = [
+        "event_name",
+        *event_id.keys(),
+        *keep_first_cols,
+        *keep_last_cols,
+        *first_cols,
+        *last_cols,
+    ]

    data = np.empty((len(events_df), len(columns)))
    metadata = pd.DataFrame(data=data, columns=columns, index=events_df.index)

    # Event names
-    metadata.iloc[:, 0] = ''
+    metadata.iloc[:, 0] = ""

    # Event times
    start_idx = 1
-    stop_idx = (start_idx + len(event_id.keys()) +
-                len(keep_first_cols + keep_last_cols))
+    stop_idx = start_idx + len(event_id.keys()) + len(keep_first_cols + keep_last_cols)
    metadata.iloc[:, start_idx:stop_idx] = np.nan

    # keep_first and keep_last names
@@ -2467,22 +2863,23 @@ def _ensure_list(x):

    # We're all set, let's iterate over all events and fill in the 
We will subset this to include only # `row_events` later - for row_event in events_df.itertuples(name='RowEvent'): + for row_event in events_df.itertuples(name="RowEvent"): row_idx = row_event.Index - metadata.loc[row_idx, 'event_name'] = \ - id_to_name_map[row_event.id] + metadata.loc[row_idx, "event_name"] = id_to_name_map[row_event.id] # Determine which events fall into the current epoch window_start_sample = row_event.sample + start_sample window_stop_sample = row_event.sample + stop_sample events_in_window = events_df.loc[ - (events_df['sample'] >= window_start_sample) & - (events_df['sample'] <= window_stop_sample), :] + (events_df["sample"] >= window_start_sample) + & (events_df["sample"] <= window_stop_sample), + :, + ] assert not events_in_window.empty # Store the metadata - for event in events_in_window.itertuples(name='Event'): + for event in events_in_window.itertuples(name="Event"): event_sample = event.sample - row_event.sample event_time = event_sample / sfreq event_time = 0 if np.isclose(event_time, 0) else event_time @@ -2499,31 +2896,29 @@ def _ensure_list(x): # Handle keep_first and keep_last event aggregation for event_group_name in keep_first + keep_last: - if event_name not in match_event_names( - event_id, [event_group_name] - ): + if event_name not in match_event_names(event_id, [event_group_name]): continue if event_group_name in keep_first: - first_last_col = f'first_{event_group_name}' + first_last_col = f"first_{event_group_name}" else: - first_last_col = f'last_{event_group_name}' + first_last_col = f"last_{event_group_name}" old_time = metadata.loc[row_idx, event_group_name] if not np.isnan(old_time): - if ((event_group_name in keep_first and - old_time <= event_time) or - (event_group_name in keep_last and - old_time >= event_time)): + if (event_group_name in keep_first and old_time <= event_time) or ( + event_group_name in keep_last and old_time >= event_time + ): continue if event_group_name not in event_id: # This is an HED. 
Strip redundant information from the # event name - name = (event_name - .replace(event_group_name, '') - .replace('//', '/') - .strip('/')) + name = ( + event_name.replace(event_group_name, "") + .replace("//", "/") + .strip("/") + ) metadata.loc[row_idx, first_last_col] = name del name @@ -2531,12 +2926,11 @@ def _ensure_list(x): # Only keep rows of interest if row_events: - event_id_timelocked = {name: val for name, val in event_id.items() - if name in row_events} - events = events[np.in1d(events[:, 2], - list(event_id_timelocked.values()))] - metadata = metadata.loc[ - metadata['event_name'].isin(event_id_timelocked)] + event_id_timelocked = { + name: val for name, val in event_id.items() if name in row_events + } + events = events[np.in1d(events[:, 2], list(event_id_timelocked.values()))] + metadata = metadata.loc[metadata["event_name"].isin(event_id_timelocked)] assert len(events) == len(metadata) event_id = event_id_timelocked @@ -2648,15 +3042,34 @@ class Epochs(BaseEpochs): """ @verbose - def __init__(self, raw, events, event_id=None, tmin=-0.2, tmax=0.5, - baseline=(None, 0), picks=None, preload=False, reject=None, - flat=None, proj=True, decim=1, reject_tmin=None, - reject_tmax=None, detrend=None, on_missing='raise', - reject_by_annotation=True, metadata=None, - event_repeated='error', verbose=None): # noqa: D102 + def __init__( + self, + raw, + events, + event_id=None, + tmin=-0.2, + tmax=0.5, + baseline=(None, 0), + picks=None, + preload=False, + reject=None, + flat=None, + proj=True, + decim=1, + reject_tmin=None, + reject_tmax=None, + detrend=None, + on_missing="raise", + reject_by_annotation=True, + metadata=None, + event_repeated="error", + verbose=None, + ): # noqa: D102 if not isinstance(raw, BaseRaw): - raise ValueError('The first argument to `Epochs` must be an ' - 'instance of mne.io.BaseRaw') + raise ValueError( + "The first argument to `Epochs` must be an " + "instance of mne.io.BaseRaw" + ) info = deepcopy(raw.info) # proj is on when applied in Raw @@ -2665,17 +3078,34 @@ def __init__(self, raw, events, event_id=None, tmin=-0.2, tmax=0.5, self.reject_by_annotation = reject_by_annotation # keep track of original sfreq (needed for annotations) - raw_sfreq = raw.info['sfreq'] + raw_sfreq = raw.info["sfreq"] # call BaseEpochs constructor super(Epochs, self).__init__( - info, None, events, event_id, tmin, tmax, - metadata=metadata, baseline=baseline, raw=raw, picks=picks, - reject=reject, flat=flat, decim=decim, reject_tmin=reject_tmin, - reject_tmax=reject_tmax, detrend=detrend, - proj=proj, on_missing=on_missing, preload_at_end=preload, - event_repeated=event_repeated, verbose=verbose, - raw_sfreq=raw_sfreq, annotations=raw.annotations) + info, + None, + events, + event_id, + tmin, + tmax, + metadata=metadata, + baseline=baseline, + raw=raw, + picks=picks, + reject=reject, + flat=flat, + decim=decim, + reject_tmin=reject_tmin, + reject_tmax=reject_tmax, + detrend=detrend, + proj=proj, + on_missing=on_missing, + preload_at_end=preload, + event_repeated=event_repeated, + verbose=verbose, + raw_sfreq=raw_sfreq, + annotations=raw.annotations, + ) @verbose def _get_epoch_from_raw(self, idx, verbose=None): @@ -2690,10 +3120,12 @@ def _get_epoch_from_raw(self, idx, verbose=None): """ if self._raw is None: # This should never happen, as raw=None only if preload=True - raise ValueError('An error has occurred, no valid raw file found. 
' - 'Please report this to the mne-python ' - 'developers.') - sfreq = self._raw.info['sfreq'] + raise ValueError( + "An error has occurred, no valid raw file found. " + "Please report this to the mne-python " + "developers." + ) + sfreq = self._raw.info["sfreq"] event_samp = self.events[idx, 0] # Read a data segment from "start" to "stop" in samples first_samp = self._raw.first_samp @@ -2715,10 +3147,15 @@ def _get_epoch_from_raw(self, idx, verbose=None): diff = int(round((self._raw_times[-1] - reject_tmax) * sfreq)) reject_stop = stop - diff - logger.debug(' Getting epoch for %d-%d' % (start, stop)) - data = self._raw._check_bad_segment(start, stop, self.picks, - reject_start, reject_stop, - self.reject_by_annotation) + logger.debug(" Getting epoch for %d-%d" % (start, stop)) + data = self._raw._check_bad_segment( + start, + stop, + self.picks, + reject_start, + reject_stop, + self.reject_by_annotation, + ) return data @@ -2800,38 +3237,72 @@ class EpochsArray(BaseEpochs): """ @verbose - def __init__(self, data, info, events=None, tmin=0, event_id=None, - reject=None, flat=None, reject_tmin=None, - reject_tmax=None, baseline=None, proj=True, - on_missing='raise', metadata=None, selection=None, - *, drop_log=None, raw_sfreq=None, verbose=None): # noqa: D102 + def __init__( + self, + data, + info, + events=None, + tmin=0, + event_id=None, + reject=None, + flat=None, + reject_tmin=None, + reject_tmax=None, + baseline=None, + proj=True, + on_missing="raise", + metadata=None, + selection=None, + *, + drop_log=None, + raw_sfreq=None, + verbose=None, + ): # noqa: D102 dtype = np.complex128 if np.any(np.iscomplex(data)) else np.float64 data = np.asanyarray(data, dtype=dtype) if data.ndim != 3: - raise ValueError('Data must be a 3D array of shape (n_epochs, ' - 'n_channels, n_samples)') + raise ValueError( + "Data must be a 3D array of shape (n_epochs, " "n_channels, n_samples)" + ) - if len(info['ch_names']) != data.shape[1]: - raise ValueError('Info and data must have same number of ' - 'channels.') + if len(info["ch_names"]) != data.shape[1]: + raise ValueError("Info and data must have same number of " "channels.") if events is None: n_epochs = len(data) events = _gen_events(n_epochs) info = info.copy() # do not modify original info - tmax = (data.shape[2] - 1) / info['sfreq'] + tmin + tmax = (data.shape[2] - 1) / info["sfreq"] + tmin super(EpochsArray, self).__init__( - info, data, events, event_id, tmin, tmax, baseline, - reject=reject, flat=flat, reject_tmin=reject_tmin, - reject_tmax=reject_tmax, decim=1, metadata=metadata, - selection=selection, proj=proj, on_missing=on_missing, - drop_log=drop_log, raw_sfreq=raw_sfreq, verbose=verbose) + info, + data, + events, + event_id, + tmin, + tmax, + baseline, + reject=reject, + flat=flat, + reject_tmin=reject_tmin, + reject_tmax=reject_tmax, + decim=1, + metadata=metadata, + selection=selection, + proj=proj, + on_missing=on_missing, + drop_log=drop_log, + raw_sfreq=raw_sfreq, + verbose=verbose, + ) if self.baseline is not None: self._do_baseline = True - if len(events) != np.in1d(self.events[:, 2], - list(self.event_id.values())).sum(): - raise ValueError('The events must only contain event numbers from ' - 'event_id') + if ( + len(events) + != np.in1d(self.events[:, 2], list(self.event_id.values())).sum() + ): + raise ValueError( + "The events must only contain event numbers from " "event_id" + ) detrend_picks = self._detrend_picks for e in self._data: # This is safe without assignment b/c there is no decim @@ -2875,19 +3346,20 @@ def 
combine_event_ids(epochs, old_event_ids, new_event_id, copy=True): new_event_id = {str(new_event_id): new_event_id} else: if not isinstance(new_event_id, dict): - raise ValueError('new_event_id must be a dict or int') + raise ValueError("new_event_id must be a dict or int") if not len(list(new_event_id.keys())) == 1: - raise ValueError('new_event_id dict must have one entry') + raise ValueError("new_event_id dict must have one entry") new_event_num = list(new_event_id.values())[0] new_event_num = operator.index(new_event_num) if new_event_num in epochs.event_id.values(): - raise ValueError('new_event_id value must not already exist') + raise ValueError("new_event_id value must not already exist") # could use .pop() here, but if a latter one doesn't exist, we're # in trouble, so run them all here and pop() later old_event_nums = np.array([epochs.event_id[key] for key in old_event_ids]) # find the ones to replace - inds = np.any(epochs.events[:, 2][:, np.newaxis] == - old_event_nums[np.newaxis, :], axis=1) + inds = np.any( + epochs.events[:, 2][:, np.newaxis] == old_event_nums[np.newaxis, :], axis=1 + ) # replace the event numbers in the events list epochs.events[inds, 2] = new_event_num # delete old entries @@ -2898,7 +3370,7 @@ def combine_event_ids(epochs, old_event_ids, new_event_id, copy=True): return epochs -def equalize_epoch_counts(epochs_list, method='mintime'): +def equalize_epoch_counts(epochs_list, method="mintime"): """Equalize the number of trials in multiple Epoch instances. Parameters @@ -2927,7 +3399,7 @@ def equalize_epoch_counts(epochs_list, method='mintime'): >>> equalize_epoch_counts([epochs1, epochs2]) # doctest: +SKIP """ if not all(isinstance(e, BaseEpochs) for e in epochs_list): - raise ValueError('All inputs must be Epochs instances') + raise ValueError("All inputs must be Epochs instances") # make sure bad epochs are dropped for e in epochs_list: @@ -2936,21 +3408,21 @@ def equalize_epoch_counts(epochs_list, method='mintime'): event_times = [e.events[:, 0] for e in epochs_list] indices = _get_drop_indices(event_times, method) for e, inds in zip(epochs_list, indices): - e.drop(inds, reason='EQUALIZED_COUNT') + e.drop(inds, reason="EQUALIZED_COUNT") def _get_drop_indices(event_times, method): """Get indices to drop from multiple event timing lists.""" small_idx = np.argmin([e.shape[0] for e in event_times]) small_e_times = event_times[small_idx] - _check_option('method', method, ['mintime', 'truncate']) + _check_option("method", method, ["mintime", "truncate"]) indices = list() for e in event_times: - if method == 'mintime': + if method == "mintime": mask = _minimize_time_diff(small_e_times, e) else: mask = np.ones(e.shape[0], dtype=bool) - mask[small_e_times.shape[0]:] = False + mask[small_e_times.shape[0] :] = False indices.append(np.where(np.logical_not(mask))[0]) return indices @@ -2959,6 +3431,7 @@ def _get_drop_indices(event_times, method): def _minimize_time_diff(t_shorter, t_longer): """Find a boolean mask to minimize timing differences.""" from scipy.interpolate import interp1d + keep = np.ones((len(t_longer)), dtype=bool) # special case: length zero or one if len(t_shorter) < 2: # interp1d won't work @@ -2971,8 +3444,7 @@ def _minimize_time_diff(t_shorter, t_longer): x1 = np.arange(len(t_shorter)) # The first set of keep masks to test kwargs = dict(copy=False, bounds_error=False, assume_sorted=True) - shorter_interp = interp1d(x1, t_shorter, fill_value=t_shorter[-1], - **kwargs) + shorter_interp = interp1d(x1, t_shorter, fill_value=t_shorter[-1], **kwargs) 
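# Greedy elimination: each pass of the loop below drops exactly one
     # event from the longer list. Every still-kept event is tried as the
     # candidate to remove; both sequences are interpolated onto a common
     # index grid and each candidate is scored by the summed absolute
     # timing differences in both directions (`np.abs(d1, d1)` writes |d1|
     # in place). The removal with the lowest score becomes permanent.
     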
for ii in range(len(t_longer) - len(t_shorter)): scores.fill(np.inf) # set up the keep masks to test, eliminating any rows that are already @@ -2982,9 +3454,9 @@ def _minimize_time_diff(t_shorter, t_longer): # Check every possible removal to see if it minimizes x2 = np.arange(len(t_longer) - ii - 1) t_keeps = np.array([t_longer[km] for km in keep_mask]) - longer_interp = interp1d(x2, t_keeps, axis=1, - fill_value=t_keeps[:, -1], - **kwargs) + longer_interp = interp1d( + x2, t_keeps, axis=1, fill_value=t_keeps[:, -1], **kwargs + ) d1 = longer_interp(x1) - t_shorter d2 = shorter_interp(x2) - t_keeps scores[keep] = np.abs(d1, d1).sum(axis=1) + np.abs(d2, d2).sum(axis=1) @@ -2993,8 +3465,16 @@ def _minimize_time_diff(t_shorter, t_longer): @verbose -def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False, - ignore_chs=[], verbose=None): +def _is_good( + e, + ch_names, + channel_type_idx, + reject, + flat, + full_report=False, + ignore_chs=[], + verbose=None, +): """Test if data segment e is good according to reject and flat. If full_report=True, it will give True/False as well as a list of all @@ -3003,9 +3483,8 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False, bad_tuple = tuple() has_printed = False checkable = np.ones(len(ch_names), dtype=bool) - checkable[np.array([c in ignore_chs - for c in ch_names], dtype=bool)] = False - for refl, f, t in zip([reject, flat], [np.greater, np.less], ['', 'flat']): + checkable[np.array([c in ignore_chs for c in ch_names], dtype=bool)] = False + for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: for key, thresh in refl.items(): idx = channel_type_idx[key] @@ -3014,14 +3493,17 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False, e_idx = e[idx] deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) checkable_idx = checkable[idx] - idx_deltas = np.where(np.logical_and(f(deltas, thresh), - checkable_idx))[0] + idx_deltas = np.where( + np.logical_and(f(deltas, thresh), checkable_idx) + )[0] if len(idx_deltas) > 0: bad_names = [ch_names[idx[i]] for i in idx_deltas] - if (not has_printed): - logger.info(' Rejecting %s epoch based on %s : ' - '%s' % (t, name, bad_names)) + if not has_printed: + logger.info( + " Rejecting %s epoch based on %s : " + "%s" % (t, name, bad_names) + ) has_printed = True if not full_report: return False @@ -3051,7 +3533,7 @@ def _read_one_epoch_file(f, tree, preload): metadata = None metadata_tree = dir_tree_find(tree, FIFF.FIFFB_MNE_METADATA) if len(metadata_tree) > 0: - for dd in metadata_tree[0]['directory']: + for dd in metadata_tree[0]["directory"]: kind = dd.kind pos = dd.pos if kind == FIFF.FIFF_DESCRIPTION: @@ -3063,7 +3545,7 @@ def _read_one_epoch_file(f, tree, preload): processed = dir_tree_find(meas, FIFF.FIFFB_PROCESSED_DATA) del meas if len(processed) == 0: - raise ValueError('Could not find processed data') + raise ValueError("Could not find processed data") epochs_node = dir_tree_find(tree, FIFF.FIFFB_MNE_EPOCHS) if len(epochs_node) == 0: @@ -3073,7 +3555,7 @@ def _read_one_epoch_file(f, tree, preload): if len(epochs_node) == 0: epochs_node = dir_tree_find(tree, 122) # 122 used before v0.11 if len(epochs_node) == 0: - raise ValueError('Could not find epochs data') + raise ValueError("Could not find epochs data") my_epochs = epochs_node[0] @@ -3086,9 +3568,9 @@ def _read_one_epoch_file(f, tree, preload): drop_log = None raw_sfreq = None reject_params = {} - for k in range(my_epochs['nent']): - kind = 
my_epochs['directory'][k].kind - pos = my_epochs['directory'][k].pos + for k in range(my_epochs["nent"]): + kind = my_epochs["directory"][k].kind + pos = my_epochs["directory"][k].pos if kind == FIFF.FIFF_FIRST_SAMPLE: tag = read_tag(fid, pos) first = int(tag.data.item()) @@ -3128,44 +3610,52 @@ def _read_one_epoch_file(f, tree, preload): baseline = (bmin, bmax) n_samp = last - first + 1 - logger.info(' Found the data of interest:') - logger.info(' t = %10.2f ... %10.2f ms' - % (1000 * first / info['sfreq'], - 1000 * last / info['sfreq'])) - if info['comps'] is not None: - logger.info(' %d CTF compensation matrices available' - % len(info['comps'])) + logger.info(" Found the data of interest:") + logger.info( + " t = %10.2f ... %10.2f ms" + % (1000 * first / info["sfreq"], 1000 * last / info["sfreq"]) + ) + if info["comps"] is not None: + logger.info( + " %d CTF compensation matrices available" % len(info["comps"]) + ) # Inspect the data if data_tag is None: - raise ValueError('Epochs data not found') - epoch_shape = (len(info['ch_names']), n_samp) + raise ValueError("Epochs data not found") + epoch_shape = (len(info["ch_names"]), n_samp) size_expected = len(events) * np.prod(epoch_shape) # on read double-precision is always used if data_tag.type == FIFF.FIFFT_FLOAT: datatype = np.float64 - fmt = '>f4' + fmt = ">f4" elif data_tag.type == FIFF.FIFFT_DOUBLE: datatype = np.float64 - fmt = '>f8' + fmt = ">f8" elif data_tag.type == FIFF.FIFFT_COMPLEX_FLOAT: datatype = np.complex128 - fmt = '>c8' + fmt = ">c8" elif data_tag.type == FIFF.FIFFT_COMPLEX_DOUBLE: datatype = np.complex128 - fmt = '>c16' + fmt = ">c16" fmt_itemsize = np.dtype(fmt).itemsize assert fmt_itemsize in (4, 8, 16) size_actual = data_tag.size // fmt_itemsize - 16 // fmt_itemsize if not size_actual == size_expected: - raise ValueError('Incorrect number of samples (%d instead of %d)' - % (size_actual, size_expected)) + raise ValueError( + "Incorrect number of samples (%d instead of %d)" + % (size_actual, size_expected) + ) # Calibration factors - cals = np.array([[info['chs'][k]['cal'] * - info['chs'][k].get('scale', 1.0)] - for k in range(info['nchan'])], np.float64) + cals = np.array( + [ + [info["chs"][k]["cal"] * info["chs"][k].get("scale", 1.0)] + for k in range(info["nchan"]) + ], + np.float64, + ) # Read the data if preload: @@ -3173,10 +3663,13 @@ def _read_one_epoch_file(f, tree, preload): data *= cals # Put it all together - tmin = first / info['sfreq'] - tmax = last / info['sfreq'] - event_id = ({str(e): e for e in np.unique(events[:, 2])} - if mappings is None else mappings) + tmin = first / info["sfreq"] + tmax = last / info["sfreq"] + event_id = ( + {str(e): e for e in np.unique(events[:, 2])} + if mappings is None + else mappings + ) # In case epochs didn't have a FIFF.FIFF_MNE_EPOCHS_SELECTION tag # (version < 0.8): if selection is None: @@ -3184,9 +3677,25 @@ def _read_one_epoch_file(f, tree, preload): if drop_log is None: drop_log = ((),) * len(events) - return (info, data, data_tag, events, event_id, metadata, tmin, tmax, - baseline, selection, drop_log, epoch_shape, cals, reject_params, - fmt, annotations, raw_sfreq) + return ( + info, + data, + data_tag, + events, + event_id, + metadata, + tmin, + tmax, + baseline, + selection, + drop_log, + epoch_shape, + cals, + reject_params, + fmt, + annotations, + raw_sfreq, + ) @verbose @@ -3213,8 +3722,9 @@ def read_epochs(fname, proj=True, preload=True, verbose=None): class _RawContainer: """Helper for a raw data container.""" - def __init__(self, fid, data_tag, 
event_samps, epoch_shape, - cals, fmt): # noqa: D102 + def __init__( + self, fid, data_tag, event_samps, epoch_shape, cals, fmt + ): # noqa: D102 self.fid = fid self.data_tag = data_tag self.event_samps = event_samps @@ -3248,36 +3758,51 @@ class EpochsFIF(BaseEpochs): """ @verbose - def __init__(self, fname, proj=True, preload=True, - verbose=None): # noqa: D102 + def __init__(self, fname, proj=True, preload=True, verbose=None): # noqa: D102 if _path_like(fname): check_fname( - fname=fname, filetype='epochs', - endings=('-epo.fif', '-epo.fif.gz', '_epo.fif', '_epo.fif.gz') - ) - fname = str( - _check_fname(fname=fname, must_exist=True, overwrite="read") + fname=fname, + filetype="epochs", + endings=("-epo.fif", "-epo.fif.gz", "_epo.fif", "_epo.fif.gz"), ) + fname = str(_check_fname(fname=fname, must_exist=True, overwrite="read")) elif not preload: - raise ValueError('preload must be used with file-like objects') + raise ValueError("preload must be used with file-like objects") fnames = [fname] ep_list = list() raw = list() for fname in fnames: fname_rep = _get_fname_rep(fname) - logger.info('Reading %s ...' % fname_rep) + logger.info("Reading %s ..." % fname_rep) fid, tree, _ = fiff_open(fname, preload=preload) next_fname = _get_next_fname(fid, fname, tree) - (info, data, data_tag, events, event_id, metadata, tmin, tmax, - baseline, selection, drop_log, epoch_shape, cals, - reject_params, fmt, annotations, raw_sfreq) = \ - _read_one_epoch_file(fid, tree, preload) + ( + info, + data, + data_tag, + events, + event_id, + metadata, + tmin, + tmax, + baseline, + selection, + drop_log, + epoch_shape, + cals, + reject_params, + fmt, + annotations, + raw_sfreq, + ) = _read_one_epoch_file(fid, tree, preload) if (events[:, 0] < 0).any(): events = events.copy() - warn('Incorrect events detected on disk, setting event ' - 'numbers to consecutive increasing integers') + warn( + "Incorrect events detected on disk, setting event " + "numbers to consecutive increasing integers" + ) events[:, 0] = np.arange(1, len(events) + 1) # here we ignore missing events, since users should already be # aware of missing events if they have saved data that way @@ -3285,35 +3810,63 @@ def __init__(self, fname, proj=True, preload=True, # correction (data is being baseline-corrected when written to # disk) epoch = BaseEpochs( - info, data, events, event_id, tmin, tmax, + info, + data, + events, + event_id, + tmin, + tmax, baseline=None, - metadata=metadata, on_missing='ignore', - selection=selection, drop_log=drop_log, - proj=False, verbose=False, raw_sfreq=raw_sfreq) + metadata=metadata, + on_missing="ignore", + selection=selection, + drop_log=drop_log, + proj=False, + verbose=False, + raw_sfreq=raw_sfreq, + ) epoch.baseline = baseline epoch._do_baseline = False # might be superfluous but won't hurt ep_list.append(epoch) if not preload: # store everything we need to index back to the original data - raw.append(_RawContainer(fiff_open(fname)[0], data_tag, - events[:, 0].copy(), epoch_shape, - cals, fmt)) + raw.append( + _RawContainer( + fiff_open(fname)[0], + data_tag, + events[:, 0].copy(), + epoch_shape, + cals, + fmt, + ) + ) if next_fname is not None: fnames.append(next_fname) unsafe_annot_add = raw_sfreq is None - (info, data, raw_sfreq, events, event_id, tmin, tmax, metadata, - baseline, selection, drop_log) = _concatenate_epochs( + ( + info, + data, + raw_sfreq, + events, + event_id, + tmin, + tmax, + metadata, + baseline, + selection, + drop_log, + ) = _concatenate_epochs( ep_list, with_data=preload, 
add_offset=False, - on_mismatch='raise', + on_mismatch="raise", ) # we need this uniqueness for non-preloaded data to work properly if len(np.unique(events[:, 0])) != len(events): - raise RuntimeError('Event time samples were not unique') + raise RuntimeError("Event time samples were not unique") # correct the drop log assert len(drop_log) % len(fnames) == 0 @@ -3323,7 +3876,7 @@ def __init__(self, fname, proj=True, preload=True, for i1, i2 in zip(offsets[:-1], offsets[1:]): other_log = drop_log[i1:i2] for k, (a, b) in enumerate(zip(drop_log, other_log)): - if a == ('IGNORED',) and b != ('IGNORED',): + if a == ("IGNORED",) and b != ("IGNORED",): drop_log[k] = b drop_log = tuple(drop_log[:step]) @@ -3331,12 +3884,26 @@ def __init__(self, fname, proj=True, preload=True, # again, ensure we're retaining the baseline period originally loaded # from disk without trying to re-apply baseline correction super(EpochsFIF, self).__init__( - info, data, events, event_id, tmin, tmax, - baseline=None, raw=raw, - proj=proj, preload_at_end=False, on_missing='ignore', - selection=selection, drop_log=drop_log, filename=fname_rep, - metadata=metadata, verbose=verbose, raw_sfreq=raw_sfreq, - annotations=annotations, **reject_params) + info, + data, + events, + event_id, + tmin, + tmax, + baseline=None, + raw=raw, + proj=proj, + preload_at_end=False, + on_missing="ignore", + selection=selection, + drop_log=drop_log, + filename=fname_rep, + metadata=metadata, + verbose=verbose, + raw_sfreq=raw_sfreq, + annotations=annotations, + **reject_params, + ) self.baseline = baseline self._do_baseline = False # use the private property instead of drop_bad so that epochs @@ -3361,8 +3928,10 @@ def _get_epoch_from_raw(self, idx, verbose=None): break else: # read the correct subset of the data - raise RuntimeError('Correct epoch could not be found, please ' - 'contact mne-python developers') + raise RuntimeError( + "Correct epoch could not be found, please " + "contact mne-python developers" + ) # the following is equivalent to this, but faster: # # >>> data = read_tag(raw.fid, raw.data_tag.pos).data.astype(float) @@ -3372,10 +3941,10 @@ def _get_epoch_from_raw(self, idx, verbose=None): # Eventually this could be refactored in io/tag.py if other functions # could make use of it raw.fid.seek(raw.data_tag.pos + offset, 0) - if fmt == '>c8': - read_fmt = '>f4' - elif fmt == '>c16': - read_fmt = '>f8' + if fmt == ">c8": + read_fmt = ">f4" + elif fmt == ">c16": + read_fmt = ">f8" else: read_fmt = fmt data = np.frombuffer(raw.fid.read(size), read_fmt) @@ -3406,9 +3975,11 @@ def bootstrap(epochs, random_state=None): The bootstrap samples """ if not epochs.preload: - raise RuntimeError('Modifying data of epochs is only supported ' - 'when preloading is used. Use preload=True ' - 'in the constructor.') + raise RuntimeError( + "Modifying data of epochs is only supported " + "when preloading is used. Use preload=True " + "in the constructor." 
+ ) rng = check_random_state(random_state) epochs_bootstrap = epochs.copy() @@ -3430,27 +4001,35 @@ def _check_merge_epochs(epochs_list): raise NotImplementedError("Epochs with unequal values for baseline") -def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, - on_mismatch='raise'): +def _concatenate_epochs( + epochs_list, *, with_data=True, add_offset=True, on_mismatch="raise" +): """Auxiliary function for concatenating epochs.""" if not isinstance(epochs_list, (list, tuple)): - raise TypeError('epochs_list must be a list or tuple, got %s' - % (type(epochs_list),)) + raise TypeError( + "epochs_list must be a list or tuple, got %s" % (type(epochs_list),) + ) # to make warning messages only occur once during concatenation warned = False for ei, epochs in enumerate(epochs_list): if not isinstance(epochs, BaseEpochs): - raise TypeError('epochs_list[%d] must be an instance of Epochs, ' - 'got %s' % (ei, type(epochs))) + raise TypeError( + "epochs_list[%d] must be an instance of Epochs, " + "got %s" % (ei, type(epochs)) + ) - if (getattr(epochs, 'annotations', None) is not None and - len(epochs.annotations) > 0 and - not warned): + if ( + getattr(epochs, "annotations", None) is not None + and len(epochs.annotations) > 0 + and not warned + ): warned = True - warn('Concatenation of Annotations within Epochs is not supported ' - 'yet. All annotations will be dropped.') + warn( + "Concatenation of Annotations within Epochs is not supported " + "yet. All annotations will be dropped." + ) # create a copy, so that the Annotations are not modified in place # from the original object @@ -3470,40 +4049,42 @@ def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, event_id = deepcopy(out.event_id) selection = out.selection # offset is the last epoch + tmax + 10 second - shift = int((10 + tmax) * out.info['sfreq']) + shift = int((10 + tmax) * out.info["sfreq"]) events_offset = int(np.max(events[0][:, 0])) + shift events_overflow = False warned = False for ii, epochs in enumerate(epochs_list[1:], 1): - _ensure_infos_match(epochs.info, info, f'epochs[{ii}]', - on_mismatch=on_mismatch) + _ensure_infos_match(epochs.info, info, f"epochs[{ii}]", on_mismatch=on_mismatch) if not np.allclose(epochs.times, epochs_list[0].times): - raise ValueError('Epochs must have same times') + raise ValueError("Epochs must have same times") if epochs.baseline != baseline: - raise ValueError('Baseline must be same for all epochs') + raise ValueError("Baseline must be same for all epochs") if epochs._raw_sfreq != raw_sfreq and not warned: warned = True - warn('The original raw sampling rate of the Epochs does not ' - 'match for all Epochs. Please proceed cautiously.') + warn( + "The original raw sampling rate of the Epochs does not " + "match for all Epochs. Please proceed cautiously." + ) # compare event_id common_keys = list(set(event_id).intersection(set(epochs.event_id))) for key in common_keys: if not event_id[key] == epochs.event_id[key]: - msg = ('event_id values must be the same for identical keys ' - 'for all concatenated epochs. Key "{}" maps to {} in ' - 'some epochs and to {} in others.') - raise ValueError(msg.format(key, event_id[key], - epochs.event_id[key])) + msg = ( + "event_id values must be the same for identical keys " + 'for all concatenated epochs. Key "{}" maps to {} in ' + "some epochs and to {} in others." 
+ ) + raise ValueError(msg.format(key, event_id[key], epochs.event_id[key])) if with_data: epochs.drop_bad() offsets.append(len(epochs)) evs = epochs.events.copy() if len(epochs.events) == 0: - warn('One of the Epochs objects to concatenate was empty.') + warn("One of the Epochs objects to concatenate was empty.") elif add_offset: # We need to cast to a native Python int here to detect an # overflow of a numpy int32 (which is the default on windows) @@ -3511,9 +4092,11 @@ def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, evs[:, 0] += events_offset events_offset += max_timestamp + shift if events_offset > INT32_MAX: - warn(f'Event number greater than {INT32_MAX} created, ' - 'events[:, 0] will be assigned consecutive increasing ' - 'integer values') + warn( + f"Event number greater than {INT32_MAX} created, " + "events[:, 0] will be assigned consecutive increasing " + "integer values" + ) events_overflow = True add_offset = False # we no longer need to add offset events.append(evs) @@ -3531,9 +4114,10 @@ def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, if n_have == 0: metadata = None elif n_have != len(metadata): - raise ValueError('%d of %d epochs instances have metadata, either ' - 'all or none must have metadata' - % (n_have, len(metadata))) + raise ValueError( + "%d of %d epochs instances have metadata, either " + "all or none must have metadata" % (n_have, len(metadata)) + ) else: pd = _check_pandas_installed(strict=False) if pd is not False: @@ -3549,15 +4133,28 @@ def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, if data is None: data = np.empty( (offsets[-1], len(out.ch_names), len(out.times)), - dtype=this_data.dtype) + dtype=this_data.dtype, + ) data[start:stop] = this_data - return (info, data, raw_sfreq, events, event_id, tmin, tmax, metadata, - baseline, selection, drop_log) + return ( + info, + data, + raw_sfreq, + events, + event_id, + tmin, + tmax, + metadata, + baseline, + selection, + drop_log, + ) @verbose -def concatenate_epochs(epochs_list, add_offset=True, *, on_mismatch='raise', - verbose=None): +def concatenate_epochs( + epochs_list, add_offset=True, *, on_mismatch="raise", verbose=None +): """Concatenate a list of `~mne.Epochs` into one `~mne.Epochs` object. .. note:: Unlike `~mne.concatenate_raws`, this function does **not** @@ -3586,8 +4183,19 @@ def concatenate_epochs(epochs_list, add_offset=True, *, on_mismatch='raise', ----- .. 
versionadded:: 0.9.0 """ - (info, data, raw_sfreq, events, event_id, tmin, tmax, metadata, - baseline, selection, drop_log) = _concatenate_epochs( + ( + info, + data, + raw_sfreq, + events, + event_id, + tmin, + tmax, + metadata, + baseline, + selection, + drop_log, + ) = _concatenate_epochs( epochs_list, with_data=True, add_offset=add_offset, @@ -3595,19 +4203,39 @@ def concatenate_epochs(epochs_list, add_offset=True, *, on_mismatch='raise', ) selection = np.where([len(d) == 0 for d in drop_log])[0] out = EpochsArray( - data=data, info=info, events=events, event_id=event_id, - tmin=tmin, baseline=baseline, selection=selection, drop_log=drop_log, - proj=False, on_missing='ignore', metadata=metadata, - raw_sfreq=raw_sfreq) + data=data, + info=info, + events=events, + event_id=event_id, + tmin=tmin, + baseline=baseline, + selection=selection, + drop_log=drop_log, + proj=False, + on_missing="ignore", + metadata=metadata, + raw_sfreq=raw_sfreq, + ) out.drop_bad() return out @verbose -def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, - origin='auto', weight_all=True, int_order=8, ext_order=3, - destination=None, ignore_ref=False, return_mapping=False, - mag_scale=100., verbose=None): +def average_movements( + epochs, + head_pos=None, + orig_sfreq=None, + picks=None, + origin="auto", + weight_all=True, + int_order=8, + ext_order=3, + destination=None, + ignore_ref=False, + return_mapping=False, + mag_scale=100.0, + verbose=None, +): """Average data using Maxwell filtering, transforming using head positions. Parameters @@ -3668,37 +4296,48 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, of children in MEG: Quantification, effects on source estimation, and compensation. NeuroImage 40:541–550, 2008. """ # noqa: E501 - from .preprocessing.maxwell import (_trans_sss_basis, _reset_meg_bads, - _check_usable, _col_norm_pinv, - _get_n_moments, _get_mf_picks_fix_mags, - _prep_mf_coils, _check_destination, - _remove_meg_projs_comps, - _get_coil_scale, _get_sensor_operator) + from .preprocessing.maxwell import ( + _trans_sss_basis, + _reset_meg_bads, + _check_usable, + _col_norm_pinv, + _get_n_moments, + _get_mf_picks_fix_mags, + _prep_mf_coils, + _check_destination, + _remove_meg_projs_comps, + _get_coil_scale, + _get_sensor_operator, + ) + if head_pos is None: - raise TypeError('head_pos must be provided and cannot be None') + raise TypeError("head_pos must be provided and cannot be None") from .chpi import head_pos_to_trans_rot_t + if not isinstance(epochs, BaseEpochs): - raise TypeError('epochs must be an instance of Epochs, not %s' - % (type(epochs),)) - orig_sfreq = epochs.info['sfreq'] if orig_sfreq is None else orig_sfreq + raise TypeError( + "epochs must be an instance of Epochs, not %s" % (type(epochs),) + ) + orig_sfreq = epochs.info["sfreq"] if orig_sfreq is None else orig_sfreq orig_sfreq = float(orig_sfreq) if isinstance(head_pos, np.ndarray): head_pos = head_pos_to_trans_rot_t(head_pos) trn, rot, t = head_pos del head_pos _check_usable(epochs, ignore_ref) - origin = _check_origin(origin, epochs.info, 'head') + origin = _check_origin(origin, epochs.info, "head") recon_trans = _check_destination(destination, epochs.info, True) - logger.info('Aligning and averaging up to %s epochs' - % (len(epochs.events))) + logger.info("Aligning and averaging up to %s epochs" % (len(epochs.events))) if not np.array_equal(epochs.events[:, 0], np.unique(epochs.events[:, 0])): - raise RuntimeError('Epochs must have monotonically increasing events') + raise 
RuntimeError("Epochs must have monotonically increasing events") info_to = epochs.info.copy() - meg_picks, mag_picks, grad_picks, good_mask, _ = \ - _get_mf_picks_fix_mags(info_to, int_order, ext_order, ignore_ref) + meg_picks, mag_picks, grad_picks, good_mask, _ = _get_mf_picks_fix_mags( + info_to, int_order, ext_order, ignore_ref + ) coil_scale, mag_scale = _get_coil_scale( - meg_picks, mag_picks, grad_picks, mag_scale, info_to) + meg_picks, mag_picks, grad_picks, mag_scale, info_to + ) mult = _get_sensor_operator(epochs, meg_picks) n_channels, n_times = len(epochs.ch_names), len(epochs.times) other_picks = np.setdiff1d(np.arange(n_channels), meg_picks) @@ -3711,37 +4350,36 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, # remove MEG bads in "to" info _reset_meg_bads(info_to) # set up variables - w_sum = 0. + w_sum = 0.0 n_in, n_out = _get_n_moments([int_order, ext_order]) - S_decomp = 0. # this will end up being a weighted average + S_decomp = 0.0 # this will end up being a weighted average last_trans = None decomp_coil_scale = coil_scale[good_mask] - exp = dict(int_order=int_order, ext_order=ext_order, head_frame=True, - origin=origin) + exp = dict(int_order=int_order, ext_order=ext_order, head_frame=True, origin=origin) n_in = _get_n_moments(int_order) for ei, epoch in enumerate(epochs): event_time = epochs.events[epochs._current - 1, 0] / orig_sfreq use_idx = np.where(t <= event_time)[0] if len(use_idx) == 0: - trans = info_to['dev_head_t']['trans'] + trans = info_to["dev_head_t"]["trans"] else: use_idx = use_idx[-1] - trans = np.vstack([np.hstack([rot[use_idx], trn[[use_idx]].T]), - [[0., 0., 0., 1.]]]) - loc_str = ', '.join('%0.1f' % tr for tr in (trans[:3, 3] * 1000)) + trans = np.vstack( + [np.hstack([rot[use_idx], trn[[use_idx]].T]), [[0.0, 0.0, 0.0, 1.0]]] + ) + loc_str = ", ".join("%0.1f" % tr for tr in (trans[:3, 3] * 1000)) if last_trans is None or not np.allclose(last_trans, trans): - logger.info(' Processing epoch %s (device location: %s mm)' - % (ei + 1, loc_str)) + logger.info( + " Processing epoch %s (device location: %s mm)" % (ei + 1, loc_str) + ) reuse = False last_trans = trans else: - logger.info(' Processing epoch %s (device location: same)' - % (ei + 1,)) + logger.info(" Processing epoch %s (device location: same)" % (ei + 1,)) reuse = True epoch = epoch.copy() # because we operate inplace if not reuse: - S = _trans_sss_basis(exp, all_coils, trans, - coil_scale=decomp_coil_scale) + S = _trans_sss_basis(exp, all_coils, trans, coil_scale=decomp_coil_scale) # Get the weight from the un-regularized version (eq. 44) weight = np.linalg.norm(S[:, :n_in]) # XXX Eventually we could do cross-talk and fine-cal here @@ -3762,12 +4400,12 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, S_decomp /= w_sum # Get recon matrix # (We would need to include external here for regularization to work) - exp['ext_order'] = 0 + exp["ext_order"] = 0 S_recon = _trans_sss_basis(exp, all_coils_recon, recon_trans) if mult is not None: S_decomp = mult @ S_decomp S_recon = mult @ S_recon - exp['ext_order'] = ext_order + exp["ext_order"] = ext_order # We could determine regularization on basis of destination basis # matrix, restricted to good channels, as regularizing individual # matrices within the loop above does not seem to work. 
But in @@ -3781,19 +4419,26 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, mapping = np.dot(S_recon, pS_ave) # Apply mapping data[meg_picks] = np.dot(mapping, data[meg_picks[good_mask]]) - info_to['dev_head_t'] = recon_trans # set the reconstruction transform - evoked = epochs._evoked_from_epoch_data(data, info_to, picks, - n_events=count, kind='average', - comment=epochs._name) + info_to["dev_head_t"] = recon_trans # set the reconstruction transform + evoked = epochs._evoked_from_epoch_data( + data, info_to, picks, n_events=count, kind="average", comment=epochs._name + ) _remove_meg_projs_comps(evoked, ignore_ref) - logger.info('Created Evoked dataset from %s epochs' % (count,)) + logger.info("Created Evoked dataset from %s epochs" % (count,)) return (evoked, mapping) if return_mapping else evoked @verbose -def make_fixed_length_epochs(raw, duration=1., preload=False, - reject_by_annotation=True, proj=True, overlap=0., - id=1, verbose=None): +def make_fixed_length_epochs( + raw, + duration=1.0, + preload=False, + reject_by_annotation=True, + proj=True, + overlap=0.0, + id=1, + verbose=None, +): """Divide continuous raw data into equal-sized consecutive epochs. Parameters @@ -3829,10 +4474,17 @@ def make_fixed_length_epochs(raw, duration=1., preload=False, ----- .. versionadded:: 0.20 """ - events = make_fixed_length_events(raw, id=id, duration=duration, - overlap=overlap) - delta = 1. / raw.info['sfreq'] - return Epochs(raw, events, event_id=[id], tmin=0, tmax=duration - delta, - baseline=None, preload=preload, - reject_by_annotation=reject_by_annotation, proj=proj, - verbose=verbose) + events = make_fixed_length_events(raw, id=id, duration=duration, overlap=overlap) + delta = 1.0 / raw.info["sfreq"] + return Epochs( + raw, + events, + event_id=[id], + tmin=0, + tmax=duration - delta, + baseline=None, + preload=preload, + reject_by_annotation=reject_by_annotation, + proj=proj, + verbose=verbose, + ) diff --git a/mne/event.py b/mne/event.py index 68f943c3b49..63cb994db8a 100644 --- a/mne/event.py +++ b/mne/event.py @@ -12,9 +12,19 @@ import numpy as np -from .utils import (check_fname, logger, verbose, _get_stim_channel, warn, - _validate_type, _check_option, fill_doc, _check_fname, - _on_missing, _check_on_missing) +from .utils import ( + check_fname, + logger, + verbose, + _get_stim_channel, + warn, + _validate_type, + _check_option, + fill_doc, + _check_fname, + _on_missing, + _check_on_missing, +) from .io.constants import FIFF from .io.tree import dir_tree_find from .io.tag import read_tag @@ -75,8 +85,9 @@ def pick_events(events, include=None, exclude=None, step=False): return events -def define_target_events(events, reference_id, target_id, sfreq, tmin, tmax, - new_id=None, fill_na=None): +def define_target_events( + events, reference_id, target_id, sfreq, tmin, tmax, new_id=None, fill_na=None +): """Define new events by co-occurrence of existing events. 
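A worked example of the sample-window arithmetic this implements (all values assumed, including the rounding of ``tmin``/``tmax`` to samples): with ``sfreq=1000.``, ``tmin=0.05`` and ``tmax=0.3``, a reference event at sample 10000 is re-labeled only if a target event falls strictly between samples 10050 and 10300.

import numpy as np

sfreq, tmin, tmax = 1000.0, 0.05, 0.3
imin, imax = int(round(tmin * sfreq)), int(round(tmax * sfreq))  # 50, 300
events = np.array([[10000, 0, 1],   # reference event (id 1)
                   [10120, 0, 2]])  # candidate target event (id 2)
ref = events[0]
res = events[(events[:, 0] > ref[0] + imin)
             & (events[:, 0] < ref[0] + imax)
             & (events[:, 2] == 2)]
# res -> [[10120, 0, 2]]: the target co-occurs, so the reference would be
# re-labeled; the recorded lag is ref[0] - res[0, 0] = -120 samples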
This function can be used to evaluate events depending on the @@ -125,8 +136,11 @@ def define_target_events(events, reference_id, target_id, sfreq, tmin, tmax, if event[2] == reference_id: lower = event[0] + imin upper = event[0] + imax - res = events[(events[:, 0] > lower) & - (events[:, 0] < upper) & (events[:, 2] == target_id)] + res = events[ + (events[:, 0] > lower) + & (events[:, 0] < upper) + & (events[:, 2] == target_id) + ] if res.any(): lag += [event[0] - res[0][0]] event[2] = new_id @@ -138,8 +152,8 @@ def define_target_events(events, reference_id, target_id, sfreq, tmin, tmax, new_events = np.array(new_events) - with np.errstate(invalid='ignore'): # casting nans - lag = np.abs(lag, dtype='f8') + with np.errstate(invalid="ignore"): # casting nans + lag = np.abs(lag, dtype="f8") if lag.any(): lag *= tsample else: @@ -155,12 +169,12 @@ def _read_events_fif(fid, tree): if len(events) == 0: fid.close() - raise ValueError('Could not find event data') + raise ValueError("Could not find event data") events = events[0] event_list = None event_id = None - for d in events['directory']: + for d in events["directory"]: kind = d.kind pos = d.pos if kind == FIFF.FIFF_MNE_EVENT_LIST: @@ -169,21 +183,20 @@ def _read_events_fif(fid, tree): event_list.shape = (-1, 3) break if event_list is None: - raise ValueError('Could not find any events') - for d in events['directory']: + raise ValueError("Could not find any events") + for d in events["directory"]: kind = d.kind pos = d.pos if kind == FIFF.FIFF_DESCRIPTION: tag = read_tag(fid, pos) event_id = tag.data - m_ = [[s[::-1] for s in m[::-1].split(':', 1)] - for m in event_id.split(';')] + m_ = [[s[::-1] for s in m[::-1].split(":", 1)] for m in event_id.split(";")] event_id = {k: int(v) for v, k in m_} break elif kind == FIFF.FIFF_MNE_EVENT_COMMENTS: tag = read_tag(fid, pos) event_id = tag.data - event_id = event_id.tobytes().decode('latin-1').split('\x00')[:-1] + event_id = event_id.tobytes().decode("latin-1").split("\x00")[:-1] assert len(event_id) == len(event_list) event_id = {k: v[2] for k, v in zip(event_id, event_list)} break @@ -191,8 +204,15 @@ def _read_events_fif(fid, tree): @verbose -def read_events(filename, include=None, exclude=None, mask=None, - mask_type='and', return_event_id=False, verbose=None): +def read_events( + filename, + include=None, + exclude=None, + mask=None, + mask_type="and", + return_event_id=False, + verbose=None, +): """Read :term:`events` from fif or text file. See :ref:`tut-events-vs-annotations` and :ref:`tut-event-arrays` @@ -247,11 +267,22 @@ def read_events(filename, include=None, exclude=None, mask=None, For more information on ``mask`` and ``mask_type``, see :func:`mne.find_events`. 
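A short usage sketch (the file name is hypothetical, and ``return_event_id`` assumes the file stores an event-name mapping):

    >>> events = mne.read_events("sample-eve.fif")  # doctest: +SKIP
    >>> events, event_id = mne.read_events(
    ...     "sample-eve.fif", return_event_id=True)  # doctest: +SKIP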
""" - check_fname(filename, 'events', ('.eve', '-eve.fif', '-eve.fif.gz', - '-eve.lst', '-eve.txt', '_eve.fif', - '_eve.fif.gz', '_eve.lst', '_eve.txt', - '-annot.fif', # MNE-C annot - )) + check_fname( + filename, + "events", + ( + ".eve", + "-eve.fif", + "-eve.fif.gz", + "-eve.lst", + "-eve.txt", + "_eve.fif", + "_eve.fif.gz", + "_eve.lst", + "_eve.txt", + "-annot.fif", # MNE-C annot + ), + ) filename = Path(filename) if filename.suffix in (".fif", ".gz"): fid, tree, _ = fiff_open(filename) @@ -264,7 +295,7 @@ def read_events(filename, include=None, exclude=None, mask=None, # eve/lst files had a second float column that will raise errors lines = np.loadtxt(filename, dtype=np.float64).astype(int) if len(lines) == 0: - raise ValueError('No text lines found') + raise ValueError("No text lines found") if lines.ndim == 1: # Special case for only one event lines = lines[np.newaxis, :] @@ -274,13 +305,12 @@ def read_events(filename, include=None, exclude=None, mask=None, elif len(lines[0]) == 3: goods = [0, 1, 2] else: - raise ValueError('Unknown number of columns in event text file') + raise ValueError("Unknown number of columns in event text file") event_list = lines[:, goods] - if (mask is not None and event_list.shape[0] > 0 and - event_list[0, 2] == 0): + if mask is not None and event_list.shape[0] > 0 and event_list[0, 2] == 0: event_list = event_list[1:] - warn('first row of event file discarded (zero-valued)') + warn("first row of event file discarded (zero-valued)") event_id = None event_list = pick_events(event_list, include, exclude) @@ -289,12 +319,13 @@ def read_events(filename, include=None, exclude=None, mask=None, event_list = _mask_trigs(event_list, mask, mask_type) masked_len = event_list.shape[0] if masked_len < unmasked_len: - warn('{} of {} events masked'.format(unmasked_len - masked_len, - unmasked_len)) + warn( + "{} of {} events masked".format(unmasked_len - masked_len, unmasked_len) + ) out = event_list if return_event_id: if event_id is None: - raise RuntimeError('No event_id found in the file') + raise RuntimeError("No event_id found in the file") out = (out, event_id) return out @@ -321,26 +352,38 @@ def write_events(filename, events, *, overwrite=False, verbose=None): read_events """ filename = _check_fname(filename, overwrite=overwrite) - check_fname(filename, 'events', ('.eve', '-eve.fif', '-eve.fif.gz', - '-eve.lst', '-eve.txt', '_eve.fif', - '_eve.fif.gz', '_eve.lst', '_eve.txt')) - if filename.suffix in ('.fif', '.gz'): + check_fname( + filename, + "events", + ( + ".eve", + "-eve.fif", + "-eve.fif.gz", + "-eve.lst", + "-eve.txt", + "_eve.fif", + "_eve.fif.gz", + "_eve.lst", + "_eve.txt", + ), + ) + if filename.suffix in (".fif", ".gz"): # Start writing... 
with start_and_end_file(filename) as fid: start_block(fid, FIFF.FIFFB_MNE_EVENTS) write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, events.T) end_block(fid, FIFF.FIFFB_MNE_EVENTS) else: - with open(filename, 'w') as f: + with open(filename, "w") as f: for e in events: - f.write('%6d %6d %3d\n' % tuple(e)) + f.write("%6d %6d %3d\n" % tuple(e)) def _find_stim_steps(data, first_samp, pad_start=None, pad_stop=None, merge=0): changed = np.diff(data, axis=1) != 0 idx = np.where(np.all(changed, axis=0))[0] if len(idx) == 0: - return np.empty((0, 3), dtype='int32') + return np.empty((0, 3), dtype="int32") pre_step = data[0, idx] idx += 1 @@ -361,7 +404,7 @@ def _find_stim_steps(data, first_samp, pad_start=None, pad_stop=None, merge=0): if merge != 0: diff = np.diff(steps[:, 0]) - idx = (diff <= abs(merge)) + idx = diff <= abs(merge) if np.any(idx): where = np.where(idx)[0] keep = np.logical_not(idx) @@ -374,15 +417,14 @@ def _find_stim_steps(data, first_samp, pad_start=None, pad_stop=None, merge=0): steps[where, 2] = steps[where + 1, 2] keep = np.insert(keep, 0, True) - is_step = (steps[:, 1] != steps[:, 2]) + is_step = steps[:, 1] != steps[:, 2] keep = np.logical_and(keep, is_step) steps = steps[keep] return steps -def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0, - stim_channel=None): +def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0, stim_channel=None): """Find all steps in data from a stim channel. Parameters @@ -422,24 +464,33 @@ def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0, # pull stim channel from config if necessary stim_channel = _get_stim_channel(stim_channel, raw.info) - picks = pick_channels( - raw.info['ch_names'], include=stim_channel, ordered=False) + picks = pick_channels(raw.info["ch_names"], include=stim_channel, ordered=False) if len(picks) == 0: - raise ValueError('No stim channel found to extract event triggers.') + raise ValueError("No stim channel found to extract event triggers.") data, _ = raw[picks, :] if np.any(data < 0): - warn('Trigger channel contains negative values, using absolute value.') + warn("Trigger channel contains negative values, using absolute value.") data = np.abs(data) # make sure trig channel is positive data = data.astype(np.int64) - return _find_stim_steps(data, raw.first_samp, pad_start=pad_start, - pad_stop=pad_stop, merge=merge) + return _find_stim_steps( + data, raw.first_samp, pad_start=pad_start, pad_stop=pad_stop, merge=merge + ) @verbose -def _find_events(data, first_samp, verbose=None, output='onset', - consecutive='increasing', min_samples=0, mask=None, - uint_cast=False, mask_type='and', initial_event=False): +def _find_events( + data, + first_samp, + verbose=None, + output="onset", + consecutive="increasing", + min_samples=0, + mask=None, + uint_cast=False, + mask_type="and", + initial_event=False, +): """Help find events.""" assert data.shape[0] == 1 # data should be only a row vector @@ -454,42 +505,46 @@ def _find_events(data, first_samp, verbose=None, output='onset', if uint_cast: data = data.astype(np.uint16).astype(np.int64) if data.min() < 0: - warn('Trigger channel contains negative values, using absolute ' - 'value. If data were acquired on a Neuromag system with ' - 'STI016 active, consider using uint_cast=True to work around ' - 'an acquisition bug') + warn( + "Trigger channel contains negative values, using absolute " + "value. 
If data were acquired on a Neuromag system with " + "STI016 active, consider using uint_cast=True to work around " + "an acquisition bug" + ) data = np.abs(data) # make sure trig channel is positive events = _find_stim_steps(data, first_samp, pad_stop=0, merge=merge) initial_value = data[0, 0] if initial_value != 0: if initial_event: - events = np.insert( - events, 0, [first_samp, 0, initial_value], axis=0) + events = np.insert(events, 0, [first_samp, 0, initial_value], axis=0) else: - logger.info('Trigger channel has a non-zero initial value of {} ' - '(consider using initial_event=True to detect this ' - 'event)'.format(initial_value)) + logger.info( + "Trigger channel has a non-zero initial value of {} " + "(consider using initial_event=True to detect this " + "event)".format(initial_value) + ) events = _mask_trigs(events, mask, mask_type) # Determine event onsets and offsets - if consecutive == 'increasing': - onsets = (events[:, 2] > events[:, 1]) - offsets = np.logical_and(np.logical_or(onsets, (events[:, 2] == 0)), - (events[:, 1] > 0)) + if consecutive == "increasing": + onsets = events[:, 2] > events[:, 1] + offsets = np.logical_and( + np.logical_or(onsets, (events[:, 2] == 0)), (events[:, 1] > 0) + ) elif consecutive: - onsets = (events[:, 2] > 0) - offsets = (events[:, 1] > 0) + onsets = events[:, 2] > 0 + offsets = events[:, 1] > 0 else: - onsets = (events[:, 1] == 0) - offsets = (events[:, 2] == 0) + onsets = events[:, 1] == 0 + offsets = events[:, 2] == 0 onset_idx = np.where(onsets)[0] offset_idx = np.where(offsets)[0] if len(onset_idx) == 0 or len(offset_idx) == 0: - return np.empty((0, 3), dtype='int32') + return np.empty((0, 3), dtype="int32") # delete orphaned onsets/offsets if onset_idx[0] > offset_idx[0]: @@ -500,12 +555,12 @@ def _find_events(data, first_samp, verbose=None, output='onset', logger.info("Removing orphaned onset at the end of the file.") onset_idx = np.delete(onset_idx, -1) - if output == 'onset': + if output == "onset": events = events[onset_idx] - elif output == 'step': + elif output == "step": idx = np.union1d(onset_idx, offset_idx) events = events[idx] - elif output == 'offset': + elif output == "offset": event_id = events[onset_idx, 2] events = events[offset_idx] events[:, 1] = events[:, 2] @@ -523,20 +578,32 @@ def _find_events(data, first_samp, verbose=None, output='onset', def _find_unique_events(events): """Uniquify events (ie remove duplicated rows.""" e = np.ascontiguousarray(events).view( - np.dtype((np.void, events.dtype.itemsize * events.shape[1]))) + np.dtype((np.void, events.dtype.itemsize * events.shape[1])) + ) _, idx = np.unique(e, return_index=True) n_dupes = len(events) - len(idx) if n_dupes > 0: - warn("Some events are duplicated in your different stim channels." - " %d events were ignored during deduplication." % n_dupes) + warn( + "Some events are duplicated in your different stim channels." + " %d events were ignored during deduplication." % n_dupes + ) return events[idx] @verbose -def find_events(raw, stim_channel=None, output='onset', - consecutive='increasing', min_duration=0, - shortest_event=2, mask=None, uint_cast=False, - mask_type='and', initial_event=False, verbose=None): +def find_events( + raw, + stim_channel=None, + output="onset", + consecutive="increasing", + min_duration=0, + shortest_event=2, + mask=None, + uint_cast=False, + mask_type="and", + initial_event=False, + verbose=None, +): """Find :term:`events` from raw file. 
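A minimal call sketch (path and stim-channel name hypothetical):

    >>> raw = mne.io.read_raw_fif("raw.fif")  # doctest: +SKIP
    >>> events = mne.find_events(raw, stim_channel="STI 014")  # doctest: +SKIP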
See :ref:`tut-events-vs-annotations` and :ref:`tut-event-arrays` @@ -683,42 +750,53 @@ def find_events(raw, stim_channel=None, output='onset', ---------------- 2 '0000010' """ - min_samples = min_duration * raw.info['sfreq'] + min_samples = min_duration * raw.info["sfreq"] # pull stim channel from config if necessary try: stim_channel = _get_stim_channel(stim_channel, raw.info) except ValueError: if len(raw.annotations) > 0: - raise ValueError("No stim channels found, but the raw object has " - "annotations. Consider using " - "mne.events_from_annotations to convert these to " - "events.") + raise ValueError( + "No stim channels found, but the raw object has " + "annotations. Consider using " + "mne.events_from_annotations to convert these to " + "events." + ) else: raise - picks = pick_channels(raw.info['ch_names'], include=stim_channel) + picks = pick_channels(raw.info["ch_names"], include=stim_channel) if len(picks) == 0: - raise ValueError('No stim channel found to extract event triggers.') + raise ValueError("No stim channel found to extract event triggers.") data, _ = raw[picks, :] events_list = [] for d in data: - events = _find_events(d[np.newaxis, :], raw.first_samp, - verbose=verbose, output=output, - consecutive=consecutive, min_samples=min_samples, - mask=mask, uint_cast=uint_cast, - mask_type=mask_type, initial_event=initial_event) + events = _find_events( + d[np.newaxis, :], + raw.first_samp, + verbose=verbose, + output=output, + consecutive=consecutive, + min_samples=min_samples, + mask=mask, + uint_cast=uint_cast, + mask_type=mask_type, + initial_event=initial_event, + ) # add safety check for spurious events (for ex. from neuromag syst.) by # checking the number of low sample events n_short_events = np.sum(np.diff(events[:, 0]) < shortest_event) if n_short_events > 0: - raise ValueError("You have %i events shorter than the " - "shortest_event. These are very unusual and you " - "may want to set min_duration to a larger value " - "e.g. x / raw.info['sfreq']. Where x = 1 sample " - "shorter than the shortest event " - "length." % (n_short_events)) + raise ValueError( + "You have %i events shorter than the " + "shortest_event. These are very unusual and you " + "may want to set min_duration to a larger value " + "e.g. x / raw.info['sfreq']. Where x = 1 sample " + "shorter than the shortest event " + "length." 
% (n_short_events) + ) events_list.append(events) @@ -730,7 +808,7 @@ def find_events(raw, stim_channel=None, output='onset', def _mask_trigs(events, mask, mask_type): """Mask digital trigger values.""" - _check_option('mask_type', mask_type, ['not_and', 'and']) + _check_option("mask_type", mask_type, ["not_and", "and"]) if mask is not None: _validate_type(mask, "int", "mask", "int or None") n_events = len(events) @@ -738,11 +816,13 @@ def _mask_trigs(events, mask, mask_type): return events.copy() if mask is not None: - if mask_type == 'not_and': + if mask_type == "not_and": mask = np.bitwise_not(mask) - elif mask_type != 'and': - raise ValueError("'mask_type' should be either 'and'" - " or 'not_and', instead of '%s'" % mask_type) + elif mask_type != "and": + raise ValueError( + "'mask_type' should be either 'and'" + " or 'not_and', instead of '%s'" % mask_type + ) events[:, 1:] = np.bitwise_and(events[:, 1:], mask) events = events[events[:, 1] != events[:, 2]] @@ -841,8 +921,9 @@ def shift_time_events(events, ids, tshift, sfreq): @fill_doc -def make_fixed_length_events(raw, id=1, start=0, stop=None, duration=1., - first_samp=True, overlap=0.): +def make_fixed_length_events( + raw, id=1, start=0, stop=None, duration=1.0, first_samp=True, overlap=0.0 +): """Make a set of :term:`events` separated by a fixed duration. Parameters @@ -875,14 +956,16 @@ def make_fixed_length_events(raw, id=1, start=0, stop=None, duration=1., %(events)s """ from .io.base import BaseRaw + _validate_type(raw, BaseRaw, "raw") _validate_type(id, int, "id") _validate_type(duration, "numeric", "duration") _validate_type(overlap, "numeric", "overlap") duration, overlap = float(duration), float(overlap) if not 0 <= overlap < duration: - raise ValueError('overlap must be >=0 but < duration (%s), got %s' - % (duration, overlap)) + raise ValueError( + "overlap must be >=0 but < duration (%s), got %s" % (duration, overlap) + ) start = raw.time_as_index(start, use_rounding=True)[0] if stop is not None: @@ -895,16 +978,17 @@ def make_fixed_length_events(raw, id=1, start=0, stop=None, duration=1., else: stop = min([stop, len(raw.times)]) # Make sure we don't go out the end of the file: - stop -= int(np.round(raw.info['sfreq'] * duration)) + stop -= int(np.round(raw.info["sfreq"] * duration)) # This should be inclusive due to how we generally use start and stop... 
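The onset spacing computed in the next line is ``sfreq * (duration - overlap)`` samples; a worked example with assumed values:

import numpy as np

sfreq, duration, overlap = 100.0, 1.0, 0.5
ts = np.arange(0, 301, sfreq * (duration - overlap)).astype(int)
# -> array([  0,  50, 100, 150, 200, 250, 300]): one onset every 50 samples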
- ts = np.arange(start, stop + 1, - raw.info['sfreq'] * (duration - overlap)).astype(int) + ts = np.arange(start, stop + 1, raw.info["sfreq"] * (duration - overlap)).astype( + int + ) n_events = len(ts) if n_events == 0: - raise ValueError('No events produced, check the values of start, ' - 'stop, and duration') - events = np.c_[ts, np.zeros(n_events, dtype=int), - id * np.ones(n_events, dtype=int)] + raise ValueError( + "No events produced, check the values of start, " "stop, and duration" + ) + events = np.c_[ts, np.zeros(n_events, dtype=int), id * np.ones(n_events, dtype=int)] return events @@ -935,10 +1019,10 @@ def concatenate_events(events, first_samps, last_samps): mne.concatenate_raws """ _validate_type(events, list, "events") - if not (len(events) == len(last_samps) and - len(events) == len(first_samps)): - raise ValueError('events, first_samps, and last_samps must all have ' - 'the same lengths') + if not (len(events) == len(last_samps) and len(events) == len(first_samps)): + raise ValueError( + "events, first_samps, and last_samps must all have " "the same lengths" + ) first_samps = np.array(first_samps) last_samps = np.array(last_samps) n_samps = np.cumsum(last_samps - first_samps + 1) @@ -994,85 +1078,125 @@ class AcqParserFIF: """ # DACQ variables always start with one of these - _acq_var_magic = ['ERF', 'DEF', 'ACQ', 'TCP'] + _acq_var_magic = ["ERF", "DEF", "ACQ", "TCP"] # averager related DACQ variable names (without preceding 'ERF') # old versions (DACQ < 3.4) - _dacq_vars_compat = ('megMax', 'megMin', 'megNoise', 'megSlope', - 'megSpike', 'eegMax', 'eegMin', 'eegNoise', - 'eegSlope', 'eegSpike', 'eogMax', 'ecgMax', 'ncateg', - 'nevent', 'stimSource', 'triggerMap', 'update', - 'artefIgnore', 'averUpdate') - - _event_vars_compat = ('Comment', 'Delay') - - _cat_vars = ('Comment', 'Display', 'Start', 'State', 'End', 'Event', - 'Nave', 'ReqEvent', 'ReqWhen', 'ReqWithin', 'SubAve') + _dacq_vars_compat = ( + "megMax", + "megMin", + "megNoise", + "megSlope", + "megSpike", + "eegMax", + "eegMin", + "eegNoise", + "eegSlope", + "eegSpike", + "eogMax", + "ecgMax", + "ncateg", + "nevent", + "stimSource", + "triggerMap", + "update", + "artefIgnore", + "averUpdate", + ) + + _event_vars_compat = ("Comment", "Delay") + + _cat_vars = ( + "Comment", + "Display", + "Start", + "State", + "End", + "Event", + "Nave", + "ReqEvent", + "ReqWhen", + "ReqWithin", + "SubAve", + ) # new versions only (DACQ >= 3.4) - _dacq_vars = _dacq_vars_compat + ('magMax', 'magMin', 'magNoise', - 'magSlope', 'magSpike', 'version') - - _event_vars = _event_vars_compat + ('Name', 'Channel', 'NewBits', - 'OldBits', 'NewMask', 'OldMask') + _dacq_vars = _dacq_vars_compat + ( + "magMax", + "magMin", + "magNoise", + "magSlope", + "magSpike", + "version", + ) + + _event_vars = _event_vars_compat + ( + "Name", + "Channel", + "NewBits", + "OldBits", + "NewMask", + "OldMask", + ) def __init__(self, info): # noqa: D102 - acq_pars = info['acq_pars'] + acq_pars = info["acq_pars"] if not acq_pars: - raise ValueError('No acquisition parameters') + raise ValueError("No acquisition parameters") self.acq_dict = dict(self._acqpars_gen(acq_pars)) - if 'ERFversion' in self.acq_dict: + if "ERFversion" in self.acq_dict: self.compat = False # DACQ ver >= 3.4 - elif 'ERFncateg' in self.acq_dict: # probably DACQ < 3.4 + elif "ERFncateg" in self.acq_dict: # probably DACQ < 3.4 self.compat = True else: - raise ValueError('Cannot parse acquisition parameters') + raise ValueError("Cannot parse acquisition parameters") dacq_vars = 
self._dacq_vars_compat if self.compat else self._dacq_vars # set instance variables for var in dacq_vars: - val = self.acq_dict['ERF' + var] - if var[:3] in ['mag', 'meg', 'eeg', 'eog', 'ecg']: + val = self.acq_dict["ERF" + var] + if var[:3] in ["mag", "meg", "eeg", "eog", "ecg"]: val = float(val) - elif var in ['ncateg', 'nevent']: + elif var in ["ncateg", "nevent"]: val = int(val) setattr(self, var.lower(), val) - self.stimsource = ( - 'Internal' if self.stimsource == '1' else 'External') + self.stimsource = "Internal" if self.stimsource == "1" else "External" # collect all events and categories self._events = self._events_from_acq_pars() self._categories = self._categories_from_acq_pars() # mark events that are used by a category for cat in self._categories.values(): - if cat['event']: - self._events[cat['event']]['in_use'] = True - if cat['reqevent']: - self._events[cat['reqevent']]['in_use'] = True + if cat["event"]: + self._events[cat["event"]]["in_use"] = True + if cat["reqevent"]: + self._events[cat["reqevent"]]["in_use"] = True # make mne rejection dicts based on the averager parameters - self.reject = {'grad': self.megmax, 'eeg': self.eegmax, - 'eog': self.eogmax, 'ecg': self.ecgmax} + self.reject = { + "grad": self.megmax, + "eeg": self.eegmax, + "eog": self.eogmax, + "ecg": self.ecgmax, + } if not self.compat: - self.reject['mag'] = self.magmax - self.reject = {k: float(v) for k, v in self.reject.items() - if float(v) > 0} - self.flat = {'grad': self.megmin, 'eeg': self.eegmin} + self.reject["mag"] = self.magmax + self.reject = {k: float(v) for k, v in self.reject.items() if float(v) > 0} + self.flat = {"grad": self.megmin, "eeg": self.eegmin} if not self.compat: - self.flat['mag'] = self.magmin - self.flat = {k: float(v) for k, v in self.flat.items() - if float(v) > 0} + self.flat["mag"] = self.magmin + self.flat = {k: float(v) for k, v in self.flat.items() if float(v) > 0} def __repr__(self): # noqa: D105 - s = ' bits for old DACQ versions - _compat_event_lookup = {1: 1, 2: 2, 3: 4, 4: 8, 5: 16, 6: 32, 7: 3, - 8: 5, 9: 6, 10: 7, 11: 9, 12: 10, 13: 11, - 14: 12, 15: 13, 16: 14, 17: 15} + _compat_event_lookup = { + 1: 1, + 2: 2, + 3: 4, + 4: 8, + 5: 16, + 6: 32, + 7: 3, + 8: 5, + 9: 6, + 10: 7, + 11: 9, + 12: 10, + 13: 11, + 14: 12, + 15: 13, + 16: 14, + 17: 15, + } events = dict() for evnum in range(1, self.nevent + 1): evnum_s = str(evnum).zfill(2) # '01', '02' etc. evdi = dict() - event_vars = (self._event_vars_compat if self.compat - else self._event_vars) + event_vars = self._event_vars_compat if self.compat else self._event_vars for var in event_vars: # name of DACQ variable, e.g. 'ERFeventNewBits01' - acq_key = 'ERFevent' + var + evnum_s + acq_key = "ERFevent" + var + evnum_s # corresponding dict key, e.g. 
'newbits' dict_key = var.lower() val = self.acq_dict[acq_key] # type convert numeric values - if dict_key in ['newbits', 'oldbits', 'newmask', 'oldmask']: + if dict_key in ["newbits", "oldbits", "newmask", "oldmask"]: val = int(val) - elif dict_key in ['delay']: + elif dict_key in ["delay"]: val = float(val) evdi[dict_key] = val - evdi['in_use'] = False # __init__() will set this - evdi['index'] = evnum + evdi["in_use"] = False # __init__() will set this + evdi["index"] = evnum if self.compat: - evdi['name'] = str(evnum) - evdi['oldmask'] = 63 - evdi['newmask'] = 63 - evdi['oldbits'] = 0 - evdi['newbits'] = _compat_event_lookup[evnum] + evdi["name"] = str(evnum) + evdi["oldmask"] = 63 + evdi["newmask"] = 63 + evdi["oldbits"] = 0 + evdi["newbits"] = _compat_event_lookup[evnum] events[evnum] = evdi return events def _acqpars_gen(self, acq_pars): """Yield key/value pairs from ``info['acq_pars'])``.""" - key, val = '', '' + key, val = "", "" for line in acq_pars.split(): if any([line.startswith(x) for x in self._acq_var_magic]): key = line - val = '' + val = "" else: if not key: - raise ValueError('Cannot parse acquisition parameters') + raise ValueError("Cannot parse acquisition parameters") # DACQ splits items with spaces into multiple lines - val += ' ' + line if val else line + val += " " + line if val else line yield key, val def _categories_from_acq_pars(self): @@ -1210,20 +1349,20 @@ def _categories_from_acq_pars(self): catdi = dict() # read all category variables for var in self._cat_vars: - acq_key = 'ERFcat' + var + catnum + acq_key = "ERFcat" + var + catnum class_key = var.lower() val = self.acq_dict[acq_key] catdi[class_key] = val # some type conversions - catdi['display'] = (catdi['display'] == '1') - catdi['state'] = (catdi['state'] == '1') - for key in ['start', 'end', 'reqwithin']: + catdi["display"] = catdi["display"] == "1" + catdi["state"] = catdi["state"] == "1" + for key in ["start", "end", "reqwithin"]: catdi[key] = float(catdi[key]) - for key in ['nave', 'event', 'reqevent', 'reqwhen', 'subave']: + for key in ["nave", "event", "reqevent", "reqwhen", "subave"]: catdi[key] = int(catdi[key]) # some convenient extra (non-DACQ) vars - catdi['index'] = int(catnum) # index of category in DACQ list - cats[catdi['comment']] = catdi + catdi["index"] = int(catnum) # index of category in DACQ list + cats[catdi["comment"]] = catdi return cats def _events_mne_to_dacq(self, mne_events): @@ -1239,13 +1378,13 @@ def _events_mne_to_dacq(self, mne_events): events_ = mne_events.copy() events_[:, 1:3] = 0 for n, ev in self._events.items(): - if ev['in_use']: + if ev["in_use"]: pre_ok = ( - np.bitwise_and(ev['oldmask'], - mne_events[:, 1]) == ev['oldbits']) + np.bitwise_and(ev["oldmask"], mne_events[:, 1]) == ev["oldbits"] + ) post_ok = ( - np.bitwise_and(ev['newmask'], - mne_events[:, 2]) == ev['newbits']) + np.bitwise_and(ev["newmask"], mne_events[:, 2]) == ev["newbits"] + ) ok_ind = np.where(pre_ok & post_ok) events_[ok_ind, 2] |= 1 << (n - 1) return events_ @@ -1257,8 +1396,8 @@ def _mne_events_to_category_t0(self, cat, mne_events, sfreq): Then the zero times for the epochs are obtained by considering the reference and conditional (required) events and the delay to stimulus. """ - cat_ev = cat['event'] - cat_reqev = cat['reqevent'] + cat_ev = cat["event"] + cat_reqev = cat["reqevent"] # first convert mne events to dacq event list events = self._events_mne_to_dacq(mne_events) # next, take req. 
events and delays into account @@ -1268,25 +1407,25 @@ def _mne_events_to_category_t0(self, cat, mne_events, sfreq): refEvents_t = times[refEvents_inds] if cat_reqev: # indices of times where req. event occurs - reqEvents_inds = np.where(events[:, 2] & ( - 1 << cat_reqev - 1))[0] + reqEvents_inds = np.where(events[:, 2] & (1 << cat_reqev - 1))[0] reqEvents_t = times[reqEvents_inds] # relative (to refevent) time window where req. event # must occur (e.g. [0 .2]) - twin = [0, (-1)**(cat['reqwhen']) * cat['reqwithin']] + twin = [0, (-1) ** (cat["reqwhen"]) * cat["reqwithin"]] win = np.round(np.array(sorted(twin)) * sfreq) # to samples refEvents_wins = refEvents_t[:, None] + win req_acc = np.zeros(refEvents_inds.shape, dtype=bool) for t in reqEvents_t: # mark time windows where req. condition is satisfied reqEvent_in_win = np.logical_and( - t >= refEvents_wins[:, 0], t <= refEvents_wins[:, 1]) + t >= refEvents_wins[:, 0], t <= refEvents_wins[:, 1] + ) req_acc |= reqEvent_in_win # drop ref. events where req. event condition is not satisfied refEvents_inds = refEvents_inds[np.where(req_acc)] refEvents_t = times[refEvents_inds] # adjust for trigger-stimulus delay by delaying the ref. event - refEvents_t += int(np.round(self._events[cat_ev]['delay'] * sfreq)) + refEvents_t += int(np.round(self._events[cat_ev]["delay"] * sfreq)) return refEvents_t @property @@ -1295,8 +1434,7 @@ def categories(self): Only returns categories marked active in DACQ. """ - cats = sorted(self._categories_in_use.values(), - key=lambda cat: cat['index']) + cats = sorted(self._categories_in_use.values(), key=lambda cat: cat["index"]) return cats @property @@ -1305,19 +1443,27 @@ def events(self): Only returns events that are in use (referred to by a category). """ - evs = sorted(self._events_in_use.values(), key=lambda ev: ev['index']) + evs = sorted(self._events_in_use.values(), key=lambda ev: ev["index"]) return evs @property def _categories_in_use(self): - return {k: v for k, v in self._categories.items() if v['state']} + return {k: v for k, v in self._categories.items() if v["state"]} @property def _events_in_use(self): - return {k: v for k, v in self._events.items() if v['in_use']} - - def get_condition(self, raw, condition=None, stim_channel=None, mask=None, - uint_cast=None, mask_type='and', delayed_lookup=True): + return {k: v for k, v in self._events.items() if v["in_use"]} + + def get_condition( + self, + raw, + condition=None, + stim_channel=None, + mask=None, + uint_cast=None, + mask_type="and", + delayed_lookup=True, + ): """Get averaging parameters for a condition (averaging category). Output is designed to be used with the Epochs class to extract the @@ -1389,35 +1535,45 @@ def get_condition(self, raw, condition=None, stim_channel=None, mask=None, for cat in condition: if isinstance(cat, str): cat = self[cat] - mne_events = find_events(raw, stim_channel=stim_channel, mask=mask, - mask_type=mask_type, output='step', - uint_cast=uint_cast, consecutive=True, - verbose=False, shortest_event=1) + mne_events = find_events( + raw, + stim_channel=stim_channel, + mask=mask, + mask_type=mask_type, + output="step", + uint_cast=uint_cast, + consecutive=True, + verbose=False, + shortest_event=1, + ) if delayed_lookup: ind = np.where(np.diff(mne_events[:, 0]) == 1)[0] if 1 in np.diff(ind): - raise ValueError('There are several subsequent ' - 'transitions on the trigger channel. ' - 'This will not work well with ' - 'delayed_lookup=True. 
You may want to ' - 'check your trigger data and ' - 'set delayed_lookup=False.') + raise ValueError( + "There are several subsequent " + "transitions on the trigger channel. " + "This will not work well with " + "delayed_lookup=True. You may want to " + "check your trigger data and " + "set delayed_lookup=False." + ) mne_events[ind, 2] = mne_events[ind + 1, 2] mne_events = np.delete(mne_events, ind + 1, axis=0) - sfreq = raw.info['sfreq'] + sfreq = raw.info["sfreq"] cat_t0_ = self._mne_events_to_category_t0(cat, mne_events, sfreq) # make it compatible with the usual events array - cat_t0 = np.c_[cat_t0_, np.zeros(cat_t0_.shape), - cat['index'] * np.ones(cat_t0_.shape) - ].astype(np.uint32) - cat_id = {cat['comment']: cat['index']} - tmin, tmax = cat['start'], cat['end'] - conds_data.append(dict(events=cat_t0, event_id=cat_id, - tmin=tmin, tmax=tmax)) + cat_t0 = np.c_[ + cat_t0_, np.zeros(cat_t0_.shape), cat["index"] * np.ones(cat_t0_.shape) + ].astype(np.uint32) + cat_id = {cat["comment"]: cat["index"]} + tmin, tmax = cat["start"], cat["end"] + conds_data.append( + dict(events=cat_t0, event_id=cat_id, tmin=tmin, tmax=tmax) + ) return conds_data[0] if len(conds_data) == 1 else conds_data -def match_event_names(event_names, keys, *, on_missing='raise'): +def match_event_names(event_names, keys, *, on_missing="raise"): """Search a collection of event names for matching (sub-)groups of events. This function is particularly helpful when using grouped event names @@ -1468,10 +1624,7 @@ def match_event_names(event_names, keys, *, on_missing='raise'): event_names = list(event_names) # ensure we have a list of `keys` - if ( - isinstance(keys, (Sequence, np.ndarray)) and - not isinstance(keys, str) - ): + if isinstance(keys, (Sequence, np.ndarray)) and not isinstance(keys, str): keys = list(keys) else: keys = [keys] @@ -1481,19 +1634,20 @@ def match_event_names(event_names, keys, *, on_missing='raise'): # form the hierarchical event name mapping for key in keys: if not isinstance(key, str): - raise ValueError(f'keys must be strings, got {type(key)} ({key})') + raise ValueError(f"keys must be strings, got {type(key)} ({key})") matches.extend( - name for name in event_names - if set(key.split('/')).issubset(name.split('/')) + name + for name in event_names + if set(key.split("/")).issubset(name.split("/")) ) if not matches: _on_missing( on_missing=on_missing, msg=f'Event name "{key}" could not be found. 
The following events ' - f'are present in the data: {", ".join(event_names)}', - error_klass=KeyError + f'are present in the data: {", ".join(event_names)}', + error_klass=KeyError, ) matches = sorted(set(matches)) # deduplicate if necessary diff --git a/mne/evoked.py b/mne/evoked.py index 29c6e5ca1cb..e41c8c10cbb 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -12,20 +12,39 @@ import numpy as np from .baseline import rescale, _log_rescale, _check_baseline -from .channels.channels import (UpdateChannelsMixin, - SetChannelsMixin, InterpolationMixin) +from .channels.channels import UpdateChannelsMixin, SetChannelsMixin, InterpolationMixin from .channels.layout import _merge_ch_data, _pair_grad_sensors -from .defaults import (_INTERPOLATION_DEFAULT, _EXTRAPOLATE_DEFAULT, - _BORDER_DEFAULT) +from .defaults import _INTERPOLATION_DEFAULT, _EXTRAPOLATE_DEFAULT, _BORDER_DEFAULT from .filter import detrend, FilterMixin, _check_fun -from .utils import (check_fname, logger, verbose, warn, sizeof_fmt, repr_html, - SizeMixin, copy_function_doc_to_method_doc, _validate_type, - fill_doc, _check_option, _build_data_frame, - _check_pandas_installed, _check_pandas_index_arguments, - _convert_times, _scale_dataframe_data, _check_time_format, - _check_preload, _check_fname, TimeMixin) -from .viz import (plot_evoked, plot_evoked_topomap, plot_evoked_field, - plot_evoked_image, plot_evoked_topo) +from .utils import ( + check_fname, + logger, + verbose, + warn, + sizeof_fmt, + repr_html, + SizeMixin, + copy_function_doc_to_method_doc, + _validate_type, + fill_doc, + _check_option, + _build_data_frame, + _check_pandas_installed, + _check_pandas_index_arguments, + _convert_times, + _scale_dataframe_data, + _check_time_format, + _check_preload, + _check_fname, + TimeMixin, +) +from .viz import ( + plot_evoked, + plot_evoked_topomap, + plot_evoked_field, + plot_evoked_image, + plot_evoked_topo, +) from .viz.evoked import plot_evoked_white, plot_evoked_joint from .viz.topomap import _topomap_animation @@ -34,37 +53,58 @@ from .io.tag import read_tag from .io.tree import dir_tree_find from .io.pick import pick_types, _picks_to_idx, _FNIRS_CH_TYPES_SPLIT -from .io.meas_info import (ContainsMixin, read_meas_info, write_meas_info, - _read_extended_ch_info, _rename_list, - _ensure_infos_match) +from .io.meas_info import ( + ContainsMixin, + read_meas_info, + write_meas_info, + _read_extended_ch_info, + _rename_list, + _ensure_infos_match, +) from .io.proj import ProjMixin -from .io.write import (start_and_end_file, start_block, end_block, - write_int, write_string, write_float_matrix, - write_id, write_float, write_complex_float_matrix) +from .io.write import ( + start_and_end_file, + start_block, + end_block, + write_int, + write_string, + write_float_matrix, + write_id, + write_float, + write_complex_float_matrix, +) from .io.base import _check_maxshield, _get_ch_factors from .parallel import parallel_func from .time_frequency.spectrum import Spectrum, SpectrumMixin, _validate_method _aspect_dict = { - 'average': FIFF.FIFFV_ASPECT_AVERAGE, - 'standard_error': FIFF.FIFFV_ASPECT_STD_ERR, - 'single_epoch': FIFF.FIFFV_ASPECT_SINGLE, - 'partial_average': FIFF.FIFFV_ASPECT_SUBAVERAGE, - 'alternating_subaverage': FIFF.FIFFV_ASPECT_ALTAVERAGE, - 'sample_cut_out_by_graph': FIFF.FIFFV_ASPECT_SAMPLE, - 'power_density_spectrum': FIFF.FIFFV_ASPECT_POWER_DENSITY, - 'dipole_amplitude_cuvre': FIFF.FIFFV_ASPECT_DIPOLE_WAVE, - 'squid_modulation_lower_bound': FIFF.FIFFV_ASPECT_IFII_LOW, - 'squid_modulation_upper_bound': 
FIFF.FIFFV_ASPECT_IFII_HIGH, - 'squid_gate_setting': FIFF.FIFFV_ASPECT_GATE, + "average": FIFF.FIFFV_ASPECT_AVERAGE, + "standard_error": FIFF.FIFFV_ASPECT_STD_ERR, + "single_epoch": FIFF.FIFFV_ASPECT_SINGLE, + "partial_average": FIFF.FIFFV_ASPECT_SUBAVERAGE, + "alternating_subaverage": FIFF.FIFFV_ASPECT_ALTAVERAGE, + "sample_cut_out_by_graph": FIFF.FIFFV_ASPECT_SAMPLE, + "power_density_spectrum": FIFF.FIFFV_ASPECT_POWER_DENSITY, + "dipole_amplitude_cuvre": FIFF.FIFFV_ASPECT_DIPOLE_WAVE, + "squid_modulation_lower_bound": FIFF.FIFFV_ASPECT_IFII_LOW, + "squid_modulation_upper_bound": FIFF.FIFFV_ASPECT_IFII_HIGH, + "squid_gate_setting": FIFF.FIFFV_ASPECT_GATE, } _aspect_rev = {val: key for key, val in _aspect_dict.items()} @fill_doc -class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, - InterpolationMixin, FilterMixin, TimeMixin, SizeMixin, - SpectrumMixin): +class Evoked( + ProjMixin, + ContainsMixin, + UpdateChannelsMixin, + SetChannelsMixin, + InterpolationMixin, + FilterMixin, + TimeMixin, + SizeMixin, + SpectrumMixin, +): """Evoked data. Parameters @@ -123,17 +163,28 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, """ @verbose - def __init__(self, fname, condition=None, proj=True, - kind='average', allow_maxshield=False, *, - verbose=None): # noqa: D102 + def __init__( + self, + fname, + condition=None, + proj=True, + kind="average", + allow_maxshield=False, + *, + verbose=None, + ): # noqa: D102 _validate_type(proj, bool, "'proj'") # Read the requested data - fname = str( - _check_fname(fname=fname, must_exist=True, overwrite="read") - ) - self.info, self.nave, self._aspect_kind, self.comment, times, \ - self.data, self.baseline = _read_evoked(fname, condition, kind, - allow_maxshield) + fname = str(_check_fname(fname=fname, must_exist=True, overwrite="read")) + ( + self.info, + self.nave, + self._aspect_kind, + self.comment, + times, + self.data, + self.baseline, + ) = _read_evoked(fname, condition, kind, allow_maxshield) self._set_times(times) self._raw_times = self.times.copy() self._decim = 1 @@ -152,7 +203,7 @@ def kind(self): @kind.setter def kind(self, kind): - _check_option('kind', kind, list(_aspect_dict.keys())) + _check_option("kind", kind, list(_aspect_dict.keys())) self._aspect_kind = _aspect_dict[kind] @property @@ -200,8 +251,9 @@ def get_data(self, picks=None, units=None, tmin=None, tmax=None): return data @verbose - def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, - verbose=None, **kwargs): + def apply_function( + self, fun, picks=None, dtype=None, n_jobs=None, verbose=None, **kwargs + ): """Apply a function to a subset of channels. %(applyfun_summary_evoked)s @@ -221,18 +273,18 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, self : instance of Evoked The evoked object with transformed data. 
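
        Notes
        -----
        ``fun`` is applied channel-wise, i.e. it receives each picked channel
        as a 1D array and should return an array of the same shape. A minimal
        sketch (the lambda is a hypothetical example)::

            evoked.apply_function(lambda x: x - x.mean())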
""" - _check_preload(self, 'evoked.apply_function') + _check_preload(self, "evoked.apply_function") picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False) if not callable(fun): - raise ValueError('fun needs to be a function') + raise ValueError("fun needs to be a function") data_in = self._data if dtype is not None and dtype != self._data.dtype: self._data = self._data.astype(dtype) # check the dimension of the incoming evoked data - _check_option('evoked.ndim', self._data.ndim, [2]) + _check_option("evoked.ndim", self._data.ndim, [2]) parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs) if n_jobs == 1: @@ -241,8 +293,9 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs) else: # use parallel function - data_picks_new = parallel(p_fun( - fun, data_in[p, :], **kwargs) for p in picks) + data_picks_new = parallel( + p_fun(fun, data_in[p, :], **kwargs) for p in picks + ) for pp, p in enumerate(picks): self._data[p, :] = data_picks_new[pp] @@ -270,11 +323,12 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): .. versionadded:: 0.13.0 """ - baseline = _check_baseline(baseline, times=self.times, - sfreq=self.info['sfreq']) + baseline = _check_baseline(baseline, times=self.times, sfreq=self.info["sfreq"]) if self.baseline is not None and baseline is None: - raise ValueError('The data has already been baseline-corrected. ' - 'Cannot remove existing baseline correction.') + raise ValueError( + "The data has already been baseline-corrected. " + "Cannot remove existing baseline correction." + ) elif baseline is None: # Do not rescale logger.info(_log_rescale(None)) @@ -309,7 +363,7 @@ def save(self, fname, *, overwrite=False, verbose=None): write_evokeds(fname, self, overwrite=overwrite) @verbose - def export(self, fname, fmt='auto', *, overwrite=False, verbose=None): + def export(self, fname, fmt="auto", *, overwrite=False, verbose=None): """Export Evoked to external formats. 
%(export_fmt_support_evoked)s @@ -330,6 +384,7 @@ def export(self, fname, fmt='auto', *, overwrite=False, verbose=None): %(export_warning_note_evoked)s """ from .export import export_evokeds + export_evokeds(fname, self, fmt, overwrite=overwrite, verbose=verbose) def __repr__(self): # noqa: D105 @@ -341,15 +396,18 @@ def __repr__(self): # noqa: D105 comment = self.comment s = "'%s' (%s, N=%s)" % (comment, self.kind, self.nave) s += ", %0.5g – %0.5g s" % (self.times[0], self.times[-1]) - s += ', baseline ' + s += ", baseline " if self.baseline is None: - s += 'off' + s += "off" else: - s += f'{self.baseline[0]:g} – {self.baseline[1]:g} s' + s += f"{self.baseline[0]:g} – {self.baseline[1]:g} s" if self.baseline != _check_baseline( - self.baseline, times=self.times, sfreq=self.info['sfreq'], - on_baseline_outside_data='adjust'): - s += ' (baseline period was cropped after baseline correction)' + self.baseline, + times=self.times, + sfreq=self.info["sfreq"], + on_baseline_outside_data="adjust", + ): + s += " (baseline period was cropped after baseline correction)" s += ", %s ch" % self.data.shape[0] s += ", ~%s" % (sizeof_fmt(self._size),) return "" % s @@ -357,122 +415,328 @@ def __repr__(self): # noqa: D105 @repr_html def _repr_html_(self): from .html_templates import repr_templates_env + if self.baseline is None: - baseline = 'off' + baseline = "off" else: - baseline = tuple([f'{b:.3f}' for b in self.baseline]) - baseline = f'{baseline[0]} – {baseline[1]} s' + baseline = tuple([f"{b:.3f}" for b in self.baseline]) + baseline = f"{baseline[0]} – {baseline[1]} s" - t = repr_templates_env.get_template('evoked.html.jinja') + t = repr_templates_env.get_template("evoked.html.jinja") t = t.render(evoked=self, baseline=baseline) return t @property def ch_names(self): """Channel names.""" - return self.info['ch_names'] + return self.info["ch_names"] @copy_function_doc_to_method_doc(plot_evoked) - def plot(self, picks=None, exclude='bads', unit=True, show=True, ylim=None, - xlim='tight', proj=False, hline=None, units=None, scalings=None, - titles=None, axes=None, gfp=False, window_title=None, - spatial_colors='auto', zorder='unsorted', selectable=True, - noise_cov=None, time_unit='s', sphere=None, *, highlight=None, - verbose=None): + def plot( + self, + picks=None, + exclude="bads", + unit=True, + show=True, + ylim=None, + xlim="tight", + proj=False, + hline=None, + units=None, + scalings=None, + titles=None, + axes=None, + gfp=False, + window_title=None, + spatial_colors="auto", + zorder="unsorted", + selectable=True, + noise_cov=None, + time_unit="s", + sphere=None, + *, + highlight=None, + verbose=None, + ): return plot_evoked( - self, picks=picks, exclude=exclude, unit=unit, show=show, - ylim=ylim, proj=proj, xlim=xlim, hline=hline, units=units, - scalings=scalings, titles=titles, axes=axes, gfp=gfp, - window_title=window_title, spatial_colors=spatial_colors, - zorder=zorder, selectable=selectable, noise_cov=noise_cov, - time_unit=time_unit, sphere=sphere, highlight=highlight, - verbose=verbose) + self, + picks=picks, + exclude=exclude, + unit=unit, + show=show, + ylim=ylim, + proj=proj, + xlim=xlim, + hline=hline, + units=units, + scalings=scalings, + titles=titles, + axes=axes, + gfp=gfp, + window_title=window_title, + spatial_colors=spatial_colors, + zorder=zorder, + selectable=selectable, + noise_cov=noise_cov, + time_unit=time_unit, + sphere=sphere, + highlight=highlight, + verbose=verbose, + ) @copy_function_doc_to_method_doc(plot_evoked_image) - def plot_image(self, picks=None, 
exclude='bads', unit=True, show=True, - clim=None, xlim='tight', proj=False, units=None, - scalings=None, titles=None, axes=None, cmap='RdBu_r', - colorbar=True, mask=None, mask_style=None, - mask_cmap='Greys', mask_alpha=.25, time_unit='s', - show_names=None, group_by=None, sphere=None): + def plot_image( + self, + picks=None, + exclude="bads", + unit=True, + show=True, + clim=None, + xlim="tight", + proj=False, + units=None, + scalings=None, + titles=None, + axes=None, + cmap="RdBu_r", + colorbar=True, + mask=None, + mask_style=None, + mask_cmap="Greys", + mask_alpha=0.25, + time_unit="s", + show_names=None, + group_by=None, + sphere=None, + ): return plot_evoked_image( - self, picks=picks, exclude=exclude, unit=unit, show=show, - clim=clim, xlim=xlim, proj=proj, units=units, scalings=scalings, - titles=titles, axes=axes, cmap=cmap, colorbar=colorbar, mask=mask, - mask_style=mask_style, mask_cmap=mask_cmap, mask_alpha=mask_alpha, - time_unit=time_unit, show_names=show_names, group_by=group_by, - sphere=sphere) + self, + picks=picks, + exclude=exclude, + unit=unit, + show=show, + clim=clim, + xlim=xlim, + proj=proj, + units=units, + scalings=scalings, + titles=titles, + axes=axes, + cmap=cmap, + colorbar=colorbar, + mask=mask, + mask_style=mask_style, + mask_cmap=mask_cmap, + mask_alpha=mask_alpha, + time_unit=time_unit, + show_names=show_names, + group_by=group_by, + sphere=sphere, + ) @copy_function_doc_to_method_doc(plot_evoked_topo) - def plot_topo(self, layout=None, layout_scale=0.945, color=None, - border='none', ylim=None, scalings=None, title=None, - proj=False, vline=[0.0], fig_background=None, - merge_grads=False, legend=True, axes=None, - background_color='w', noise_cov=None, exclude='bads', - show=True): + def plot_topo( + self, + layout=None, + layout_scale=0.945, + color=None, + border="none", + ylim=None, + scalings=None, + title=None, + proj=False, + vline=[0.0], + fig_background=None, + merge_grads=False, + legend=True, + axes=None, + background_color="w", + noise_cov=None, + exclude="bads", + show=True, + ): """ Notes ----- .. 
versionadded:: 0.10.0 """ return plot_evoked_topo( - self, layout=layout, layout_scale=layout_scale, - color=color, border=border, ylim=ylim, scalings=scalings, - title=title, proj=proj, vline=vline, fig_background=fig_background, - merge_grads=merge_grads, legend=legend, axes=axes, - background_color=background_color, noise_cov=noise_cov, - exclude=exclude, show=show) + self, + layout=layout, + layout_scale=layout_scale, + color=color, + border=border, + ylim=ylim, + scalings=scalings, + title=title, + proj=proj, + vline=vline, + fig_background=fig_background, + merge_grads=merge_grads, + legend=legend, + axes=axes, + background_color=background_color, + noise_cov=noise_cov, + exclude=exclude, + show=show, + ) @copy_function_doc_to_method_doc(plot_evoked_topomap) def plot_topomap( - self, times="auto", *, average=None, ch_type=None, scalings=None, - proj=False, sensors=True, show_names=False, mask=None, - mask_params=None, contours=6, outlines='head', sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, - cbar_fmt='%3.1f', units=None, axes=None, time_unit='s', - time_format=None, nrows=1, ncols='auto', show=True): + self, + times="auto", + *, + average=None, + ch_type=None, + scalings=None, + proj=False, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + time_unit="s", + time_format=None, + nrows=1, + ncols="auto", + show=True, + ): return plot_evoked_topomap( - self, times=times, ch_type=ch_type, vlim=vlim, cmap=cmap, - cnorm=cnorm, sensors=sensors, colorbar=colorbar, scalings=scalings, - units=units, res=res, size=size, cbar_fmt=cbar_fmt, - time_unit=time_unit, time_format=time_format, proj=proj, show=show, - show_names=show_names, mask=mask, mask_params=mask_params, - outlines=outlines, contours=contours, image_interp=image_interp, - average=average, axes=axes, extrapolate=extrapolate, sphere=sphere, - border=border, nrows=nrows, ncols=ncols) + self, + times=times, + ch_type=ch_type, + vlim=vlim, + cmap=cmap, + cnorm=cnorm, + sensors=sensors, + colorbar=colorbar, + scalings=scalings, + units=units, + res=res, + size=size, + cbar_fmt=cbar_fmt, + time_unit=time_unit, + time_format=time_format, + proj=proj, + show=show, + show_names=show_names, + mask=mask, + mask_params=mask_params, + outlines=outlines, + contours=contours, + image_interp=image_interp, + average=average, + axes=axes, + extrapolate=extrapolate, + sphere=sphere, + border=border, + nrows=nrows, + ncols=ncols, + ) @copy_function_doc_to_method_doc(plot_evoked_field) - def plot_field(self, surf_maps, time=None, time_label='t = %0.0f ms', - n_jobs=None, fig=None, vmax=None, n_contours=21, - *, interaction='http://github.com/terrain', verbose=None): - return plot_evoked_field(self, surf_maps, time=time, - time_label=time_label, n_jobs=n_jobs, - fig=fig, vmax=vmax, n_contours=n_contours, - interaction=interaction, verbose=verbose) + def plot_field( + self, + surf_maps, + time=None, + time_label="t = %0.0f ms", + n_jobs=None, + fig=None, + vmax=None, + n_contours=21, + *, + interaction="http://github.com/terrain", + verbose=None, + ): + return plot_evoked_field( + self, + surf_maps, + 
time=time, + time_label=time_label, + n_jobs=n_jobs, + fig=fig, + vmax=vmax, + n_contours=n_contours, + interaction=interaction, + verbose=verbose, + ) @copy_function_doc_to_method_doc(plot_evoked_white) - def plot_white(self, noise_cov, show=True, rank=None, time_unit='s', - sphere=None, axes=None, verbose=None): + def plot_white( + self, + noise_cov, + show=True, + rank=None, + time_unit="s", + sphere=None, + axes=None, + verbose=None, + ): return plot_evoked_white( - self, noise_cov=noise_cov, rank=rank, show=show, - time_unit=time_unit, sphere=sphere, axes=axes, verbose=verbose) + self, + noise_cov=noise_cov, + rank=rank, + show=show, + time_unit=time_unit, + sphere=sphere, + axes=axes, + verbose=verbose, + ) @copy_function_doc_to_method_doc(plot_evoked_joint) - def plot_joint(self, times="peaks", title='', picks=None, - exclude='bads', show=True, ts_args=None, - topomap_args=None): - return plot_evoked_joint(self, times=times, title=title, picks=picks, - exclude=exclude, show=show, ts_args=ts_args, - topomap_args=topomap_args) + def plot_joint( + self, + times="peaks", + title="", + picks=None, + exclude="bads", + show=True, + ts_args=None, + topomap_args=None, + ): + return plot_evoked_joint( + self, + times=times, + title=title, + picks=picks, + exclude=exclude, + show=show, + ts_args=ts_args, + topomap_args=topomap_args, + ) @fill_doc - def animate_topomap(self, ch_type=None, times=None, frame_rate=None, - butterfly=False, blit=True, show=True, time_unit='s', - sphere=None, *, image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, vmin=None, vmax=None, - verbose=None): + def animate_topomap( + self, + ch_type=None, + times=None, + frame_rate=None, + butterfly=False, + blit=True, + show=True, + time_unit="s", + sphere=None, + *, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + vmin=None, + vmax=None, + verbose=None, + ): """Make animation of evoked data as topomap timeseries. The animation can be paused/resumed with left mouse button. @@ -530,12 +794,23 @@ def animate_topomap(self, ch_type=None, times=None, frame_rate=None, .. versionadded:: 0.12.0 """ return _topomap_animation( - self, ch_type=ch_type, times=times, frame_rate=frame_rate, - butterfly=butterfly, blit=blit, show=show, time_unit=time_unit, - sphere=sphere, image_interp=image_interp, - extrapolate=extrapolate, vmin=vmin, vmax=vmax, verbose=verbose) + self, + ch_type=ch_type, + times=times, + frame_rate=frame_rate, + butterfly=butterfly, + blit=blit, + show=show, + time_unit=time_unit, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + vmin=vmin, + vmax=vmax, + verbose=verbose, + ) - def as_type(self, ch_type='grad', mode='fast'): + def as_type(self, ch_type="grad", mode="fast"): """Compute virtual evoked using interpolated fields. .. Warning:: Using virtual evoked to compute inverse can yield @@ -565,6 +840,7 @@ def as_type(self, ch_type='grad', mode='fast'): .. 
versionadded:: 0.9.0 """ from .forward import _as_meg_type_inst + return _as_meg_type_inst(self, ch_type=ch_type, mode=mode) @fill_doc @@ -612,14 +888,21 @@ def __neg__(self): out = self.copy() out.data *= -1 - if out.comment is not None and ' + ' in out.comment: - out.comment = f'({out.comment})' # multiple conditions in evoked + if out.comment is not None and " + " in out.comment: + out.comment = f"({out.comment})" # multiple conditions in evoked out.comment = f'- {out.comment or "unknown"}' return out - def get_peak(self, ch_type=None, tmin=None, tmax=None, - mode='abs', time_as_index=False, merge_grads=False, - return_amplitude=False): + def get_peak( + self, + ch_type=None, + tmin=None, + tmax=None, + mode="abs", + time_as_index=False, + merge_grads=False, + return_amplitude=False, + ): """Get location and latency of peak amplitude. Parameters @@ -660,11 +943,19 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, .. versionadded:: 0.16 """ # noqa: E501 - supported = ('mag', 'grad', 'eeg', 'seeg', 'dbs', 'ecog', 'misc', - 'None') + _FNIRS_CH_TYPES_SPLIT + supported = ( + "mag", + "grad", + "eeg", + "seeg", + "dbs", + "ecog", + "misc", + "None", + ) + _FNIRS_CH_TYPES_SPLIT types_used = self.get_channel_types(unique=True, only_data_chs=True) - _check_option('ch_type', str(ch_type), supported) + _check_option("ch_type", str(ch_type), supported) if ch_type is not None and ch_type not in types_used: raise ValueError( @@ -674,29 +965,31 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, elif len(types_used) > 1 and ch_type is None: raise RuntimeError( 'Multiple data channel types found. Please pass the "ch_type" ' - 'parameter.' + "parameter." ) if merge_grads: - if ch_type != 'grad': + if ch_type != "grad": raise ValueError('Channel type must be "grad" for merge_grads') - elif mode == 'neg': - raise ValueError('Negative mode (mode=neg) does not make ' - 'sense with merge_grads=True') + elif mode == "neg": + raise ValueError( + "Negative mode (mode=neg) does not make " + "sense with merge_grads=True" + ) meg = eeg = misc = seeg = dbs = ecog = fnirs = False picks = None - if ch_type in ('mag', 'grad'): + if ch_type in ("mag", "grad"): meg = ch_type - elif ch_type == 'eeg': + elif ch_type == "eeg": eeg = True - elif ch_type == 'misc': + elif ch_type == "misc": misc = True - elif ch_type == 'seeg': + elif ch_type == "seeg": seeg = True - elif ch_type == 'dbs': + elif ch_type == "dbs": dbs = True - elif ch_type == 'ecog': + elif ch_type == "ecog": ecog = True elif ch_type in _FNIRS_CH_TYPES_SPLIT: fnirs = ch_type @@ -705,9 +998,17 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, if merge_grads: picks = _pair_grad_sensors(self.info, topomap_coords=False) else: - picks = pick_types(self.info, meg=meg, eeg=eeg, misc=misc, - seeg=seeg, ecog=ecog, ref_meg=False, - fnirs=fnirs, dbs=dbs) + picks = pick_types( + self.info, + meg=meg, + eeg=eeg, + misc=misc, + seeg=seeg, + ecog=ecog, + ref_meg=False, + fnirs=fnirs, + dbs=dbs, + ) data = self.data ch_names = self.ch_names @@ -717,13 +1018,11 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, if merge_grads: data, _ = _merge_ch_data(data, ch_type, []) - ch_names = [ch_name[:-1] + 'X' for ch_name in ch_names[::2]] + ch_names = [ch_name[:-1] + "X" for ch_name in ch_names[::2]] - ch_idx, time_idx, max_amp = _get_peak(data, self.times, tmin, - tmax, mode) + ch_idx, time_idx, max_amp = _get_peak(data, self.times, tmin, tmax, mode) - out = (ch_names[ch_idx], time_idx if time_as_index else - self.times[time_idx]) + out = 
(ch_names[ch_idx], time_idx if time_as_index else self.times[time_idx]) if return_amplitude: out += (max_amp,) @@ -731,9 +1030,20 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, return out @verbose - def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, - tmax=None, picks=None, proj=False, *, n_jobs=1, - verbose=None, **method_kw): + def compute_psd( + self, + method="multitaper", + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + *, + n_jobs=1, + verbose=None, + **method_kw, + ): """Perform spectral analysis on sensor data. Parameters @@ -765,17 +1075,48 @@ def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, self._set_legacy_nfft_default(tmin, tmax, method, method_kw) return Spectrum( - self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, - picks=picks, proj=proj, reject_by_annotation=False, n_jobs=n_jobs, - verbose=verbose, **method_kw) + self, + method=method, + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=False, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, - proj=False, *, method='auto', average=False, dB=True, - estimate='auto', xscale='linear', area_mode='std', - area_alpha=0.33, color='black', line_alpha=None, - spatial_colors=True, sphere=None, exclude='bads', ax=None, - show=True, n_jobs=1, verbose=None, **method_kw): + def plot_psd( + self, + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + *, + method="auto", + average=False, + dB=True, + estimate="auto", + xscale="linear", + area_mode="std", + area_alpha=0.33, + color="black", + line_alpha=None, + spatial_colors=True, + sphere=None, + exclude="bads", + ax=None, + show=True, + n_jobs=1, + verbose=None, + **method_kw, + ): """%(plot_psd_doc)s. Parameters @@ -819,17 +1160,44 @@ def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, %(notes_plot_psd_meth)s """ return super().plot_psd( - fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, proj=proj, - reject_by_annotation=False, method=method, average=average, dB=dB, - estimate=estimate, xscale=xscale, area_mode=area_mode, - area_alpha=area_alpha, color=color, line_alpha=line_alpha, - spatial_colors=spatial_colors, sphere=sphere, exclude=exclude, - ax=ax, show=show, n_jobs=n_jobs, verbose=verbose, **method_kw) + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=False, + method=method, + average=average, + dB=dB, + estimate=estimate, + xscale=xscale, + area_mode=area_mode, + area_alpha=area_alpha, + color=color, + line_alpha=line_alpha, + spatial_colors=spatial_colors, + sphere=sphere, + exclude=exclude, + ax=ax, + show=show, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def to_data_frame(self, picks=None, index=None, - scalings=None, copy=True, long_format=False, - time_format=None, *, verbose=None): + def to_data_frame( + self, + picks=None, + index=None, + scalings=None, + copy=True, + long_format=False, + time_format=None, + *, + verbose=None, + ): """Export data in tabular structure as a pandas DataFrame. Channels are converted to columns in the DataFrame. 
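
        A minimal sketch (hypothetical arguments)::

            df = evoked.to_data_frame(picks='eeg', time_format='ms')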
By default, @@ -856,12 +1224,12 @@ def to_data_frame(self, picks=None, index=None, # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # arg checking - valid_index_args = ['time'] - valid_time_formats = ['ms', 'timedelta'] + valid_index_args = ["time"] + valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data - picks = _picks_to_idx(self.info, picks, 'all', exclude=()) + picks = _picks_to_idx(self.info, picks, "all", exclude=()) data = self.data[picks, :] times = self.times data = data.T @@ -871,10 +1239,11 @@ def to_data_frame(self, picks=None, index=None, # prepare extra columns / multiindex mindex = list() times = _convert_times(self, times, time_format) - mindex.append(('time', times)) + mindex.append(("time", times)) # build DataFrame - df = _build_data_frame(self, data, picks, long_format, mindex, index, - default_index=['time']) + df = _build_data_frame( + self, data, picks, long_format, mindex, index, default_index=["time"] + ) return df @@ -919,26 +1288,40 @@ class EvokedArray(Evoked): """ @verbose - def __init__(self, data, info, tmin=0., comment='', nave=1, kind='average', - baseline=None, *, verbose=None): # noqa: D102 + def __init__( + self, + data, + info, + tmin=0.0, + comment="", + nave=1, + kind="average", + baseline=None, + *, + verbose=None, + ): # noqa: D102 dtype = np.complex128 if np.iscomplexobj(data) else np.float64 data = np.asanyarray(data, dtype=dtype) if data.ndim != 2: - raise ValueError('Data must be a 2D array of shape (n_channels, ' - 'n_samples), got shape %s' % (data.shape,)) + raise ValueError( + "Data must be a 2D array of shape (n_channels, " + "n_samples), got shape %s" % (data.shape,) + ) - if len(info['ch_names']) != np.shape(data)[0]: - raise ValueError('Info (%s) and data (%s) must have same number ' - 'of channels.' % (len(info['ch_names']), - np.shape(data)[0])) + if len(info["ch_names"]) != np.shape(data)[0]: + raise ValueError( + "Info (%s) and data (%s) must have same number " + "of channels." 
% (len(info["ch_names"]), np.shape(data)[0]) + ) self.data = data - self.first = int(round(tmin * info['sfreq'])) + self.first = int(round(tmin * info["sfreq"])) self.last = self.first + np.shape(data)[-1] - 1 - self._set_times(np.arange(self.first, self.last + 1, - dtype=np.float64) / info['sfreq']) + self._set_times( + np.arange(self.first, self.last + 1, dtype=np.float64) / info["sfreq"] + ) self._raw_times = self.times.copy() self._decim = 1 self.info = info.copy() # do not modify original info @@ -950,8 +1333,10 @@ def __init__(self, data, info, tmin=0., comment='', nave=1, kind='average', self._projector = None _validate_type(self.kind, "str", "kind") if self.kind not in _aspect_dict: - raise ValueError('unknown kind "%s", should be "average" or ' - '"standard_error"' % (self.kind,)) + raise ValueError( + 'unknown kind "%s", should be "average" or ' + '"standard_error"' % (self.kind,) + ) self._aspect_kind = _aspect_dict[self.kind] self.baseline = baseline @@ -964,16 +1349,16 @@ def _get_entries(fid, evoked_node, allow_maxshield=False): comments = list() aspect_kinds = list() for ev in evoked_node: - for k in range(ev['nent']): - my_kind = ev['directory'][k].kind - pos = ev['directory'][k].pos + for k in range(ev["nent"]): + my_kind = ev["directory"][k].kind + pos = ev["directory"][k].pos if my_kind == FIFF.FIFF_COMMENT: tag = read_tag(fid, pos) comments.append(tag.data) my_aspect = _get_aspect(ev, allow_maxshield)[0] - for k in range(my_aspect['nent']): - my_kind = my_aspect['directory'][k].kind - pos = my_aspect['directory'][k].pos + for k in range(my_aspect["nent"]): + my_kind = my_aspect["directory"][k].kind + pos = my_aspect["directory"][k].pos if my_kind == FIFF.FIFF_ASPECT_KIND: tag = read_tag(fid, pos) aspect_kinds.append(int(tag.data.item())) @@ -981,11 +1366,10 @@ def _get_entries(fid, evoked_node, allow_maxshield=False): aspect_kinds = np.atleast_1d(aspect_kinds) if len(comments) != len(aspect_kinds) or len(comments) == 0: fid.close() - raise ValueError('Dataset names in FIF file ' - 'could not be found.') + raise ValueError("Dataset names in FIF file " "could not be found.") t = [_aspect_rev[a] for a in aspect_kinds] - t = ['"' + c + '" (' + tt + ')' for tt, c in zip(t, comments)] - t = '\n'.join(t) + t = ['"' + c + '" (' + tt + ")" for tt, c in zip(t, comments)] + t = "\n".join(t) return comments, aspect_kinds, t @@ -998,7 +1382,7 @@ def _get_aspect(evoked, allow_maxshield): aspect = dir_tree_find(evoked, FIFF.FIFFB_IAS_ASPECT) is_maxshield = True if len(aspect) > 1: - logger.info('Multiple data aspects found. Taking first one.') + logger.info("Multiple data aspects found. Taking first one.") return aspect[0], is_maxshield @@ -1018,16 +1402,17 @@ def _check_evokeds_ch_names_times(all_evoked): if ev.ch_names != ch_names: if set(ev.ch_names) != set(ch_names): raise ValueError( - "%s and %s do not contain the same channels." % (evoked, - ev)) + "%s and %s do not contain the same channels." 
% (evoked, ev) + ) else: warn("Order of channels differs, reordering channels ...") ev = ev.copy() ev.reorder_channels(ch_names) all_evoked[ii + 1] = ev if not np.max(np.abs(ev.times - evoked.times)) < 1e-7: - raise ValueError("%s and %s do not contain the same time instants" - % (evoked, ev)) + raise ValueError( + "%s and %s do not contain the same time instants" % (evoked, ev) + ) return all_evoked @@ -1066,8 +1451,8 @@ def combine_evoked(all_evoked, weights): """ naves = np.array([evk.nave for evk in all_evoked], float) if isinstance(weights, str): - _check_option('weights', weights, ['nave', 'equal']) - if weights == 'nave': + _check_option("weights", weights, ["nave", "equal"]) + if weights == "nave": weights = naves / naves.sum() else: weights = np.ones_like(naves) / len(naves) @@ -1075,7 +1460,7 @@ def combine_evoked(all_evoked, weights): weights = np.array(weights, float) if weights.ndim != 1 or weights.size != len(all_evoked): - raise ValueError('weights must be the same size as all_evoked') + raise ValueError("weights must be the same size as all_evoked") # cf. https://en.wikipedia.org/wiki/Weighted_arithmetic_mean, section on # "weighted sample variance". The variance of a weighted sample mean is: @@ -1087,7 +1472,7 @@ def combine_evoked(all_evoked, weights): # σ² = w₁² / nave₁ + w₂² / nave₂ + ... + wₙ² / naveₙ # # And our resulting nave is the reciprocal of this: - new_nave = 1. / np.sum(weights ** 2 / naves) + new_nave = 1.0 / np.sum(weights**2 / naves) # This general formula is equivalent to formulae in Matti's manual # (pp 128-129), where: # new_nave = sum(naves) when weights='nave' and @@ -1097,37 +1482,44 @@ def combine_evoked(all_evoked, weights): evoked = all_evoked[0].copy() # use union of bad channels - bads = list(set(b for e in all_evoked for b in e.info['bads'])) - evoked.info['bads'] = bads + bads = list(set(b for e in all_evoked for b in e.info["bads"])) + evoked.info["bads"] = bads evoked.data = sum(w * e.data for w, e in zip(weights, all_evoked)) evoked.nave = new_nave - comment = '' + comment = "" for idx, (w, e) in enumerate(zip(weights, all_evoked)): # pick sign - sign = '' if w >= 0 else '-' + sign = "" if w >= 0 else "-" # format weight - weight = '' if np.isclose(abs(w), 1.) else f'{abs(w):0.3f}' + weight = "" if np.isclose(abs(w), 1.0) else f"{abs(w):0.3f}" # format multiplier - multiplier = ' × ' if weight else '' + multiplier = " × " if weight else "" # format comment - if e.comment is not None and ' + ' in e.comment: # multiple conditions - this_comment = f'({e.comment})' + if e.comment is not None and " + " in e.comment: # multiple conditions + this_comment = f"({e.comment})" else: this_comment = f'{e.comment or "unknown"}' # assemble everything if idx == 0: - comment += f'{sign}{weight}{multiplier}{this_comment}' + comment += f"{sign}{weight}{multiplier}{this_comment}" else: comment += f' {sign or "+"} {weight}{multiplier}{this_comment}' # special-case: combine_evoked([e1, -e2], [1, -1]) - evoked.comment = comment.replace(' - - ', ' + ') + evoked.comment = comment.replace(" - - ", " + ") return evoked @verbose -def read_evokeds(fname, condition=None, baseline=None, kind='average', - proj=True, allow_maxshield=False, verbose=None): +def read_evokeds( + fname, + condition=None, + baseline=None, + kind="average", + proj=True, + allow_maxshield=False, + verbose=None, +): """Read evoked dataset(s). Parameters @@ -1182,9 +1574,8 @@ def read_evokeds(fname, condition=None, baseline=None, kind='average', reading. 
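
    Notes
    -----
    A minimal sketch (hypothetical file name and condition name); passing
    ``condition`` returns a single dataset instead of a list::

        evoked = read_evokeds('sample-ave.fif', condition='Left Auditory')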
""" fname = str(_check_fname(fname, overwrite="read", must_exist=True)) - check_fname(fname, 'evoked', ('-ave.fif', '-ave.fif.gz', - '_ave.fif', '_ave.fif.gz')) - logger.info('Reading %s ...' % fname) + check_fname(fname, "evoked", ("-ave.fif", "-ave.fif.gz", "_ave.fif", "_ave.fif.gz")) + logger.info("Reading %s ..." % fname) return_list = True if condition is None: evoked_node = _get_evoked_node(fname) @@ -1195,16 +1586,23 @@ def read_evokeds(fname, condition=None, baseline=None, kind='average', out = [] for c in condition: - evoked = Evoked(fname, c, kind=kind, proj=proj, - allow_maxshield=allow_maxshield, - verbose=verbose) + evoked = Evoked( + fname, + c, + kind=kind, + proj=proj, + allow_maxshield=allow_maxshield, + verbose=verbose, + ) if baseline is None and evoked.baseline is None: logger.info(_log_rescale(None)) elif baseline is None and evoked.baseline is not None: # Don't touch an existing baseline bmin, bmax = evoked.baseline - logger.info(f'Loaded Evoked data is baseline-corrected ' - f'(baseline: [{bmin:g}, {bmax:g}] s)') + logger.info( + f"Loaded Evoked data is baseline-corrected " + f"(baseline: [{bmin:g}, {bmax:g}] s)" + ) else: evoked.apply_baseline(baseline) out.append(evoked) @@ -1212,10 +1610,10 @@ def read_evokeds(fname, condition=None, baseline=None, kind='average', return out if return_list else out[0] -def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): +def _read_evoked(fname, condition=None, kind="average", allow_maxshield=False): """Read evoked data from a FIF file.""" if fname is None: - raise ValueError('No evoked filename specified') + raise ValueError("No evoked filename specified") f, tree, _ = fiff_open(fname) with f as fid: @@ -1225,47 +1623,47 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): # Locate the data of interest processed = dir_tree_find(meas, FIFF.FIFFB_PROCESSED_DATA) if len(processed) == 0: - raise ValueError('Could not find processed data') + raise ValueError("Could not find processed data") evoked_node = dir_tree_find(meas, FIFF.FIFFB_EVOKED) if len(evoked_node) == 0: - raise ValueError('Could not find evoked data') + raise ValueError("Could not find evoked data") # find string-based entry if isinstance(condition, str): if kind not in _aspect_dict.keys(): - raise ValueError('kind must be "average" or ' - '"standard_error"') + raise ValueError('kind must be "average" or ' '"standard_error"') - comments, aspect_kinds, t = _get_entries(fid, evoked_node, - allow_maxshield) - goods = (np.in1d(comments, [condition]) & - np.in1d(aspect_kinds, [_aspect_dict[kind]])) + comments, aspect_kinds, t = _get_entries(fid, evoked_node, allow_maxshield) + goods = np.in1d(comments, [condition]) & np.in1d( + aspect_kinds, [_aspect_dict[kind]] + ) found_cond = np.where(goods)[0] if len(found_cond) != 1: - raise ValueError('condition "%s" (%s) not found, out of ' - 'found datasets:\n%s' - % (condition, kind, t)) + raise ValueError( + 'condition "%s" (%s) not found, out of ' + "found datasets:\n%s" % (condition, kind, t) + ) condition = found_cond[0] elif condition is None: if len(evoked_node) > 1: - _, _, conditions = _get_entries(fid, evoked_node, - allow_maxshield) - raise TypeError("Evoked file has more than one " - "condition, the condition parameters " - "must be specified from:\n%s" % conditions) + _, _, conditions = _get_entries(fid, evoked_node, allow_maxshield) + raise TypeError( + "Evoked file has more than one " + "condition, the condition parameters " + "must be specified from:\n%s" % 
conditions + ) else: condition = 0 if condition >= len(evoked_node) or condition < 0: - raise ValueError('Data set selector out of range') + raise ValueError("Data set selector out of range") my_evoked = evoked_node[condition] # Identify the aspects with info._unlock(): - my_aspect, info['maxshield'] = _get_aspect(my_evoked, - allow_maxshield) + my_aspect, info["maxshield"] = _get_aspect(my_evoked, allow_maxshield) # Now find the data in the evoked block nchan = 0 @@ -1273,9 +1671,9 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): chs = [] baseline = bmin = bmax = None comment = last = first = first_time = nsamp = None - for k in range(my_evoked['nent']): - my_kind = my_evoked['directory'][k].kind - pos = my_evoked['directory'][k].pos + for k in range(my_evoked["nent"]): + my_kind = my_evoked["directory"][k].kind + pos = my_evoked["directory"][k].pos if my_kind == FIFF.FIFF_COMMENT: tag = read_tag(fid, pos) comment = tag.data @@ -1308,7 +1706,7 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): bmax = float(tag.data.item()) if comment is None: - comment = 'No comment' + comment = "No comment" if bmin is not None or bmax is not None: # None's should've been replaced with floats @@ -1318,27 +1716,31 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): # Local channel information? if nchan > 0: if chs is None: - raise ValueError('Local channel information was not found ' - 'when it was expected.') + raise ValueError( + "Local channel information was not found " "when it was expected." + ) if len(chs) != nchan: - raise ValueError('Number of channels and number of ' - 'channel definitions are different') + raise ValueError( + "Number of channels and number of " + "channel definitions are different" + ) ch_names_mapping = _read_extended_ch_info(chs, my_evoked, fid) - info['chs'] = chs - info['bads'][:] = _rename_list(info['bads'], ch_names_mapping) - logger.info(' Found channel information in evoked data. ' - 'nchan = %d' % nchan) + info["chs"] = chs + info["bads"][:] = _rename_list(info["bads"], ch_names_mapping) + logger.info( + " Found channel information in evoked data. 
" "nchan = %d" % nchan + ) if sfreq > 0: - info['sfreq'] = sfreq + info["sfreq"] = sfreq # Read the data in the aspect block nave = 1 epoch = [] - for k in range(my_aspect['nent']): - kind = my_aspect['directory'][k].kind - pos = my_aspect['directory'][k].pos + for k in range(my_aspect["nent"]): + kind = my_aspect["directory"][k].kind + pos = my_aspect["directory"][k].pos if kind == FIFF.FIFF_COMMENT: tag = read_tag(fid, pos) comment = tag.data @@ -1353,16 +1755,17 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): epoch.append(tag) nepoch = len(epoch) - if nepoch != 1 and nepoch != info['nchan']: - raise ValueError('Number of epoch tags is unreasonable ' - '(nepoch = %d nchan = %d)' - % (nepoch, info['nchan'])) + if nepoch != 1 and nepoch != info["nchan"]: + raise ValueError( + "Number of epoch tags is unreasonable " + "(nepoch = %d nchan = %d)" % (nepoch, info["nchan"]) + ) if nepoch == 1: # Only one epoch data = epoch[0].data # May need a transpose if the number of channels is one - if data.shape[1] == 1 and info['nchan'] == 1: + if data.shape[1] == 1 and info["nchan"] == 1: data = data.T else: # Put the old style epochs together @@ -1373,37 +1776,43 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): data = data.astype(np.complex128) if first_time is not None and nsamp is not None: - times = first_time + np.arange(nsamp) / info['sfreq'] + times = first_time + np.arange(nsamp) / info["sfreq"] elif first is not None: nsamp = last - first + 1 - times = np.arange(first, last + 1) / info['sfreq'] + times = np.arange(first, last + 1) / info["sfreq"] else: - raise RuntimeError('Could not read time parameters') + raise RuntimeError("Could not read time parameters") del first, last if nsamp is not None and data.shape[1] != nsamp: - raise ValueError('Incorrect number of samples (%d instead of ' - ' %d)' % (data.shape[1], nsamp)) - logger.info(' Found the data of interest:') - logger.info(' t = %10.2f ... %10.2f ms (%s)' - % (1000 * times[0], 1000 * times[-1], comment)) - if info['comps'] is not None: - logger.info(' %d CTF compensation matrices available' - % len(info['comps'])) - logger.info(' nave = %d - aspect type = %d' - % (nave, aspect_kind)) + raise ValueError( + "Incorrect number of samples (%d instead of " + " %d)" % (data.shape[1], nsamp) + ) + logger.info(" Found the data of interest:") + logger.info( + " t = %10.2f ... %10.2f ms (%s)" + % (1000 * times[0], 1000 * times[-1], comment) + ) + if info["comps"] is not None: + logger.info( + " %d CTF compensation matrices available" % len(info["comps"]) + ) + logger.info(" nave = %d - aspect type = %d" % (nave, aspect_kind)) # Calibrate - cals = np.array([info['chs'][k]['cal'] * - info['chs'][k].get('scale', 1.0) - for k in range(info['nchan'])]) + cals = np.array( + [ + info["chs"][k]["cal"] * info["chs"][k].get("scale", 1.0) + for k in range(info["nchan"]) + ] + ) data *= cals[:, np.newaxis] return info, nave, aspect_kind, comment, times, data, baseline @verbose -def write_evokeds(fname, evoked, *, on_mismatch='raise', overwrite=False, - verbose=None): +def write_evokeds(fname, evoked, *, on_mismatch="raise", overwrite=False, verbose=None): """Write an evoked dataset to a file. 
Parameters @@ -1436,15 +1845,15 @@ def write_evokeds(fname, evoked, *, on_mismatch='raise', overwrite=False, _write_evokeds(fname, evoked, on_mismatch=on_mismatch, overwrite=overwrite) -def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', - overwrite=False): +def _write_evokeds(fname, evoked, check=True, *, on_mismatch="raise", overwrite=False): """Write evoked data.""" from .dipole import DipoleFixed # avoid circular import fname = _check_fname(fname=fname, overwrite=overwrite) if check: - check_fname(fname, 'evoked', ('-ave.fif', '-ave.fif.gz', - '_ave.fif', '_ave.fif.gz')) + check_fname( + fname, "evoked", ("-ave.fif", "-ave.fif.gz", "_ave.fif", "_ave.fif.gz") + ) if not isinstance(evoked, (list, tuple)): evoked = [evoked] @@ -1452,11 +1861,10 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', warned = False # Create the file and save the essentials with start_and_end_file(fname) as fid: - start_block(fid, FIFF.FIFFB_MEAS) write_id(fid, FIFF.FIFF_BLOCK_ID) - if evoked[0].info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, evoked[0].info['meas_id']) + if evoked[0].info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, evoked[0].info["meas_id"]) # Write measurement info write_meas_info(fid, evoked[0].info) @@ -1465,9 +1873,12 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', start_block(fid, FIFF.FIFFB_PROCESSED_DATA) for ei, e in enumerate(evoked): if ei: - _ensure_infos_match(info1=evoked[0].info, info2=e.info, - name=f'evoked[{ei}]', - on_mismatch=on_mismatch) + _ensure_infos_match( + info1=evoked[0].info, + info2=e.info, + name=f"evoked[{ei}]", + on_mismatch=on_mismatch, + ) start_block(fid, FIFF.FIFFB_EVOKED) # Comment is optional @@ -1487,7 +1898,7 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX, bmax) # The evoked data itself - if e.info.get('maxshield'): + if e.info.get("maxshield"): aspect = FIFF.FIFFB_IAS_ASPECT else: aspect = FIFF.FIFFB_ASPECT @@ -1497,17 +1908,20 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', # convert nave to integer to comply with FIFF spec nave_int = int(round(e.nave)) if nave_int != e.nave and not warned: - warn('converting "nave" to integer before saving evoked; this ' - 'can have a minor effect on the scale of source ' - 'estimates that are computed using "nave".') + warn( + 'converting "nave" to integer before saving evoked; this ' + "can have a minor effect on the scale of source " + 'estimates that are computed using "nave".' + ) warned = True write_int(fid, FIFF.FIFF_NAVE, nave_int) del nave_int - decal = np.zeros((e.info['nchan'], 1)) - for k in range(e.info['nchan']): - decal[k] = 1.0 / (e.info['chs'][k]['cal'] * - e.info['chs'][k].get('scale', 1.0)) + decal = np.zeros((e.info["nchan"], 1)) + for k in range(e.info["nchan"]): + decal[k] = 1.0 / ( + e.info["chs"][k]["cal"] * e.info["chs"][k].get("scale", 1.0) + ) if np.iscomplexobj(e.data): write_function = write_complex_float_matrix @@ -1522,7 +1936,7 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', end_block(fid, FIFF.FIFFB_MEAS) -def _get_peak(data, times, tmin=None, tmax=None, mode='abs'): +def _get_peak(data, times, tmin=None, tmax=None, mode="abs"): """Get feature-index and time of maximum signal from 2D array. Note. This is a 'getter', not a 'finder'. 
For non-evoked type @@ -1553,7 +1967,7 @@ def _get_peak(data, times, tmin=None, tmax=None, mode='abs'): max_amp : float Amplitude of the maximum response. """ - _check_option('mode', mode, ['abs', 'neg', 'pos']) + _check_option("mode", mode, ["abs", "neg", "pos"]) if tmin is None: tmin = times[0] @@ -1562,36 +1976,37 @@ def _get_peak(data, times, tmin=None, tmax=None, mode='abs'): if tmin < times.min() or tmax > times.max(): if tmin < times.min(): - param_name = 'tmin' + param_name = "tmin" param_val = tmin else: - param_name = 'tmax' + param_name = "tmax" param_val = tmax raise ValueError( - f'{param_name} ({param_val}) is out of bounds. It must be ' - f'between {times.min()} and {times.max()}' + f"{param_name} ({param_val}) is out of bounds. It must be " + f"between {times.min()} and {times.max()}" ) elif tmin > tmax: - raise ValueError(f'tmin ({tmin}) must be <= tmax ({tmax})') + raise ValueError(f"tmin ({tmin}) must be <= tmax ({tmax})") time_win = (times >= tmin) & (times <= tmax) mask = np.ones_like(data).astype(bool) mask[:, time_win] = False maxfun = np.argmax - if mode == 'pos': + if mode == "pos": if not np.any(data[~mask] > 0): - raise ValueError('No positive values encountered. Cannot ' - 'operate in pos mode.') - elif mode == 'neg': + raise ValueError( + "No positive values encountered. Cannot " "operate in pos mode." + ) + elif mode == "neg": if not np.any(data[~mask] < 0): - raise ValueError('No negative values encountered. Cannot ' - 'operate in neg mode.') + raise ValueError( + "No negative values encountered. Cannot " "operate in neg mode." + ) maxfun = np.argmin - masked_index = np.ma.array(np.abs(data) if mode == 'abs' else data, - mask=mask) + masked_index = np.ma.array(np.abs(data) if mode == "abs" else data, mask=mask) max_loc, max_time = np.unravel_index(maxfun(masked_index), data.shape) diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index 91e0c08b94d..ff61ee939fb 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -5,6 +5,7 @@ import os from ..utils import _check_pybv_installed + _check_pybv_installed() from pybv._export import _export_mne_raw # noqa: E402 diff --git a/mne/export/_edf.py b/mne/export/_edf.py index 8a6b1370470..3666aae30fe 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -6,6 +6,7 @@ import numpy as np from ..utils import _check_edflib_installed, warn + _check_edflib_installed() from EDFlib.edfwriter import EDFwriter # noqa: E402 @@ -14,7 +15,7 @@ def _try_to_set_value(header, key, value, channel_index=None): """Set key/value pairs in EDF header.""" # all EDFLib set functions are set # for example "setPatientName()" - func_name = f'set{key}' + func_name = f"set{key}" func = getattr(header, func_name) # some setter functions are indexed by channels @@ -25,9 +26,9 @@ def _try_to_set_value(header, key, value, channel_index=None): # a nonzero return value indicates an error if return_val != 0: - raise RuntimeError(f"Setting {key} with {value} " - f"returned an error value " - f"{return_val}.") + raise RuntimeError( + f"Setting {key} with {value} " f"returned an error value " f"{return_val}." + ) @contextmanager @@ -49,11 +50,12 @@ def _export_raw(fname, raw, physical_range, add_ch_type): technician information, allow writing those here. 
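
    A minimal sketch of the public entry point that reaches this helper
    (hypothetical file name; the EDF format is selected from the
    extension)::

        raw.export('recording.edf', physical_range='auto', add_ch_type=False)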
""" # scale to save data in EDF - phys_dims = 'uV' + phys_dims = "uV" # get EEG-related data in uV - units = dict(eeg='uV', ecog='uV', seeg='uV', eog='uV', ecg='uV', emg='uV', - bio='uV', dbs='uV') + units = dict( + eeg="uV", ecog="uV", seeg="uV", eog="uV", ecg="uV", emg="uV", bio="uV", dbs="uV" + ) digital_min = -32767 digital_max = 32767 @@ -65,8 +67,8 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # remove extra STI channels orig_ch_types = raw.get_channel_types() drop_chs = [] - if 'stim' in orig_ch_types: - stim_index = np.argwhere(np.array(orig_ch_types) == 'stim') + if "stim" in orig_ch_types: + stim_index = np.argwhere(np.array(orig_ch_types) == "stim") stim_index = np.atleast_1d(stim_index.squeeze()).tolist() drop_chs.extend([raw.ch_names[idx] for idx in stim_index]) @@ -77,17 +79,19 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # Note: we can write these other channels, such as 'misc' # but these are simply a "catch all" for unknown or undesired # channels. - voltage_types = list(units) + ['stim', 'misc'] + voltage_types = list(units) + ["stim", "misc"] non_voltage_ch = [ch not in voltage_types for ch in orig_ch_types] if any(non_voltage_ch): - warn(f"Non-voltage channels detected: {non_voltage_ch}. MNE-Python's " - 'EDF exporter only supports voltage-based channels, because the ' - 'EDF format cannot accommodate much of the accompanying data ' - 'necessary for channel types like MEG and fNIRS (channel ' - 'orientations, coordinate frame transforms, etc). You can ' - 'override this restriction by setting those channel types to ' - '"misc" but no guarantees are made of the fidelity of that ' - 'approach.') + warn( + f"Non-voltage channels detected: {non_voltage_ch}. MNE-Python's " + "EDF exporter only supports voltage-based channels, because the " + "EDF format cannot accommodate much of the accompanying data " + "necessary for channel types like MEG and fNIRS (channel " + "orientations, coordinate frame transforms, etc). You can " + "override this restriction by setting those channel types to " + '"misc" but no guarantees are made of the fidelity of that ' + "approach." + ) ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] ch_types = np.array(raw.get_channel_types(picks=ch_names)) @@ -97,28 +101,29 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # Sampling frequency in EDF only supports integers, so to allow for # float sampling rates from Raw, we adjust the output sampling rate # for all channels and the data record duration. - sfreq = raw.info['sfreq'] + sfreq = raw.info["sfreq"] if float(sfreq).is_integer(): out_sfreq = int(sfreq) data_record_duration = None else: out_sfreq = np.floor(sfreq).astype(int) - data_record_duration = int(np.around( - out_sfreq / sfreq, decimals=6) * 1e6) + data_record_duration = int(np.around(out_sfreq / sfreq, decimals=6) * 1e6) - warn(f'Data has a non-integer sampling rate of {sfreq}; writing to ' - 'EDF format may cause a small change to sample times.') + warn( + f"Data has a non-integer sampling rate of {sfreq}; writing to " + "EDF format may cause a small change to sample times." 
+ ) # get any filter information applied to the data - lowpass = raw.info['lowpass'] - highpass = raw.info['highpass'] - linefreq = raw.info['line_freq'] + lowpass = raw.info["lowpass"] + highpass = raw.info["highpass"] + linefreq = raw.info["line_freq"] filter_str_info = f"HP:{highpass}Hz LP:{lowpass}Hz N:{linefreq}Hz" # get the entire dataset in uV data = raw.get_data(units=units, picks=ch_names) - if physical_range == 'auto': + if physical_range == "auto": # get max and min for each channel type data ch_types_phys_max = dict() ch_types_phys_min = dict() @@ -156,54 +161,60 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # set channel data for idx, ch in enumerate(ch_names): ch_type = ch_types[idx] - signal_label = f'{ch_type.upper()} {ch}' if add_ch_type else ch + signal_label = f"{ch_type.upper()} {ch}" if add_ch_type else ch if len(signal_label) > 16: - raise RuntimeError(f'Signal label for {ch} ({ch_type}) is ' - f'longer than 16 characters, which is not ' - f'supported in EDF. Please shorten the ' - f'channel name before exporting to EDF.') - - if physical_range == 'auto': + raise RuntimeError( + f"Signal label for {ch} ({ch_type}) is " + f"longer than 16 characters, which is not " + f"supported in EDF. Please shorten the " + f"channel name before exporting to EDF." + ) + + if physical_range == "auto": # take the channel type minimum and maximum pmin = ch_types_phys_min[ch_type] pmax = ch_types_phys_max[ch_type] - for key, val in [('PhysicalMaximum', pmax), - ('PhysicalMinimum', pmin), - ('DigitalMaximum', digital_max), - ('DigitalMinimum', digital_min), - ('PhysicalDimension', phys_dims), - ('SampleFrequency', out_sfreq), - ('SignalLabel', signal_label), - ('PreFilter', filter_str_info)]: + for key, val in [ + ("PhysicalMaximum", pmax), + ("PhysicalMinimum", pmin), + ("DigitalMaximum", digital_max), + ("DigitalMinimum", digital_min), + ("PhysicalDimension", phys_dims), + ("SampleFrequency", out_sfreq), + ("SignalLabel", signal_label), + ("PreFilter", filter_str_info), + ]: _try_to_set_value(hdl, key, val, channel_index=idx) # set patient info - subj_info = raw.info.get('subject_info') + subj_info = raw.info.get("subject_info") if subj_info is not None: - birthday = subj_info.get('birthday') + birthday = subj_info.get("birthday") # get the full name of subject if available - first_name = subj_info.get('first_name') - last_name = subj_info.get('last_name') - first_name = first_name or '' - last_name = last_name or '' - joiner = '' + first_name = subj_info.get("first_name") + last_name = subj_info.get("last_name") + first_name = first_name or "" + last_name = last_name or "" + joiner = "" if len(first_name) and len(last_name): - joiner = ' ' + joiner = " " name = joiner.join([first_name, last_name]) - hand = subj_info.get('hand') - sex = subj_info.get('sex') + hand = subj_info.get("hand") + sex = subj_info.get("sex") if birthday is not None: - if hdl.setPatientBirthDate(birthday[0], birthday[1], - birthday[2]) != 0: + if hdl.setPatientBirthDate(birthday[0], birthday[1], birthday[2]) != 0: raise RuntimeError( f"Setting patient birth date to {birthday} " - f"returned an error") - for key, val in [('PatientName', name), - ('PatientGender', sex), - ('AdditionalPatientInfo', f'hand={hand}')]: + f"returned an error" + ) + for key, val in [ + ("PatientName", name), + ("PatientGender", sex), + ("AdditionalPatientInfo", f"hand={hand}"), + ]: # EDFwriter compares integer encodings of sex and will # raise a TypeError if value is None as returned by # subj_info.get(key) if key is 
missing. @@ -211,25 +222,33 @@ def _export_raw(fname, raw, physical_range, add_ch_type): _try_to_set_value(hdl, key, val) # set measurement date - meas_date = raw.info['meas_date'] + meas_date = raw.info["meas_date"] if meas_date: subsecond = int(meas_date.microsecond / 100) - if hdl.setStartDateTime(year=meas_date.year, month=meas_date.month, - day=meas_date.day, hour=meas_date.hour, - minute=meas_date.minute, - second=meas_date.second, - subsecond=subsecond) != 0: - raise RuntimeError(f"Setting start date time {meas_date} " - f"returned an error") - - device_info = raw.info.get('device_info') + if ( + hdl.setStartDateTime( + year=meas_date.year, + month=meas_date.month, + day=meas_date.day, + hour=meas_date.hour, + minute=meas_date.minute, + second=meas_date.second, + subsecond=subsecond, + ) + != 0 + ): + raise RuntimeError( + f"Setting start date time {meas_date} " f"returned an error" + ) + + device_info = raw.info.get("device_info") if device_info is not None: - device_type = device_info.get('type') - _try_to_set_value(hdl, 'Equipment', device_type) + device_type = device_info.get("type") + _try_to_set_value(hdl, "Equipment", device_type) # set data record duration if data_record_duration is not None: - _try_to_set_value(hdl, 'DataRecordDuration', data_record_duration) + _try_to_set_value(hdl, "DataRecordDuration", data_record_duration) # compute number of data records to loop over n_blocks = np.ceil(n_times / out_sfreq).astype(int) @@ -260,29 +279,36 @@ def _export_raw(fname, raw, physical_range, add_ch_type): ch_data = data[jdx, start_samp:end_samp] # assign channel data to the buffer and write to EDF - buf[:len(ch_data)] = ch_data + buf[: len(ch_data)] = ch_data err = hdl.writeSamples(buf) if err != 0: raise RuntimeError( f"writeSamples() for channel{ch_names[jdx]} " - f"returned error: {err}") + f"returned error: {err}" + ) # there was an incomplete datarecord if len(ch_data) != len(buf): - warn(f'EDF format requires equal-length data blocks, ' - f'so {(len(buf) - len(ch_data)) / sfreq} seconds of ' - 'zeros were appended to all channels when writing the ' - 'final block.') + warn( + f"EDF format requires equal-length data blocks, " + f"so {(len(buf) - len(ch_data)) / sfreq} seconds of " + "zeros were appended to all channels when writing the " + "final block." + ) # write annotations if annots is not None: - for desc, onset, duration in zip(raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration): + for desc, onset, duration in zip( + raw.annotations.description, + raw.annotations.onset, + raw.annotations.duration, + ): # annotations are written in terms of 100 microseconds onset = onset * 10000 duration = duration * 10000 if hdl.writeAnnotation(onset, duration, desc) != 0: - raise RuntimeError(f'writeAnnotation() returned an error ' - f'trying to write {desc} at {onset} ' - f'for {duration} seconds.') + raise RuntimeError( + f"writeAnnotation() returned an error " + f"trying to write {desc} at {onset} " + f"for {duration} seconds." 
+ ) diff --git a/mne/export/_eeglab.py b/mne/export/_eeglab.py index 00d566c13fe..3fd1cc55902 100644 --- a/mne/export/_eeglab.py +++ b/mne/export/_eeglab.py @@ -5,6 +5,7 @@ import numpy as np from ..utils import _check_eeglabio_installed + _check_eeglabio_installed() import eeglabio.raw # noqa: E402 import eeglabio.epochs # noqa: E402 @@ -15,20 +16,27 @@ def _export_raw(fname, raw): raw.load_data() # remove extra epoc and STI channels - drop_chs = ['epoc'] + drop_chs = ["epoc"] # filenames attribute of RawArray is filled with None - if raw.filenames[0] and not (raw.filenames[0].endswith('.fif')): - drop_chs.append('STI 014') + if raw.filenames[0] and not (raw.filenames[0].endswith(".fif")): + drop_chs.append("STI 014") ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] - cart_coords = _get_als_coords_from_chs(raw.info['chs'], drop_chs) + cart_coords = _get_als_coords_from_chs(raw.info["chs"], drop_chs) - annotations = [raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration] + annotations = [ + raw.annotations.description, + raw.annotations.onset, + raw.annotations.duration, + ] eeglabio.raw.export_set( - fname, data=raw.get_data(picks=ch_names), sfreq=raw.info['sfreq'], - ch_names=ch_names, ch_locs=cart_coords, annotations=annotations) + fname, + data=raw.get_data(picks=ch_names), + sfreq=raw.info["sfreq"], + ch_names=ch_names, + ch_locs=cart_coords, + annotations=annotations, + ) def _export_epochs(fname, epochs): @@ -37,21 +45,31 @@ def _export_epochs(fname, epochs): epochs.load_data() # remove extra epoc and STI channels - drop_chs = ['epoc', 'STI 014'] + drop_chs = ["epoc", "STI 014"] ch_names = [ch for ch in epochs.ch_names if ch not in drop_chs] - cart_coords = _get_als_coords_from_chs(epochs.info['chs'], drop_chs) + cart_coords = _get_als_coords_from_chs(epochs.info["chs"], drop_chs) if epochs.annotations: - annot = [epochs.annotations.description, epochs.annotations.onset, - epochs.annotations.duration] + annot = [ + epochs.annotations.description, + epochs.annotations.onset, + epochs.annotations.duration, + ] else: annot = None eeglabio.epochs.export_set( - fname, data=epochs.get_data(picks=ch_names), - sfreq=epochs.info['sfreq'], events=epochs.events, - tmin=epochs.tmin, tmax=epochs.tmax, ch_names=ch_names, - event_id=epochs.event_id, ch_locs=cart_coords, annotations=annot) + fname, + data=epochs.get_data(picks=ch_names), + sfreq=epochs.info["sfreq"], + events=epochs.events, + tmin=epochs.tmin, + tmax=epochs.tmax, + ch_names=ch_names, + event_id=epochs.event_id, + ch_locs=cart_coords, + annotations=annot, + ) def _get_als_coords_from_chs(chs, drop_chs=None): @@ -63,8 +81,7 @@ def _get_als_coords_from_chs(chs, drop_chs=None): """ if drop_chs is None: drop_chs = [] - cart_coords = np.array([d['loc'][:3] for d in chs - if d['ch_name'] not in drop_chs]) + cart_coords = np.array([d["loc"][:3] for d in chs if d["ch_name"] not in drop_chs]) if cart_coords.any(): # has coordinates # (-y x z) to (x y z) cart_coords[:, 0] = -cart_coords[:, 0] # -y to y diff --git a/mne/export/_egimff.py b/mne/export/_egimff.py index 65418d35d6c..2fc1e66ef9e 100644 --- a/mne/export/_egimff.py +++ b/mne/export/_egimff.py @@ -15,8 +15,7 @@ @verbose -def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, - verbose=None): +def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, verbose=None): """Export evoked dataset to MFF. %(export_warning)s @@ -49,18 +48,22 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, (e.g. 
'HydroCel GSN 256 1.0'). This field is automatically populated when using MFF read functions. """ - mffpy = _import_mffpy('Export evokeds to MFF.') + mffpy = _import_mffpy("Export evokeds to MFF.") import pytz + info = evoked[0].info - if np.round(info['sfreq']) != info['sfreq']: - raise ValueError('Sampling frequency must be a whole number. ' - f'sfreq: {info["sfreq"]}') - sampling_rate = int(info['sfreq']) + if np.round(info["sfreq"]) != info["sfreq"]: + raise ValueError( + "Sampling frequency must be a whole number. " f'sfreq: {info["sfreq"]}' + ) + sampling_rate = int(info["sfreq"]) # check for unapplied projectors - if any(not proj['active'] for proj in evoked[0].info['projs']): - warn('Evoked instance has unapplied projectors. Consider applying ' - 'them before exporting with evoked.apply_proj().') + if any(not proj["active"] for proj in evoked[0].info["projs"]): + warn( + "Evoked instance has unapplied projectors. Consider applying " + "them before exporting with evoked.apply_proj()." + ) # Initialize writer # Future changes: conditions based on version or mffpy requirement if @@ -70,11 +73,11 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, os.remove(fname) if op.isfile(fname) else shutil.rmtree(fname) writer = mffpy.Writer(fname) current_time = pytz.utc.localize(datetime.datetime.utcnow()) - writer.addxml('fileInfo', recordTime=current_time) + writer.addxml("fileInfo", recordTime=current_time) try: - device = info['device_info']['type'] + device = info["device_info"]["type"] except (TypeError, KeyError): - raise ValueError('No device type. Cannot determine sensor layout.') + raise ValueError("No device type. Cannot determine sensor layout.") writer.add_coordinates_and_sensor_layout(device) # Add EEG data @@ -88,11 +91,11 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, # Add categories categories_content = _categories_content_from_evokeds(evoked) - writer.addxml('categories', categories=categories_content) + writer.addxml("categories", categories=categories_content) # Add history if history: - writer.addxml('historyEntries', entries=history) + writer.addxml("historyEntries", entries=history) writer.write() @@ -103,14 +106,20 @@ def _categories_content_from_evokeds(evoked): begin_time = 0 for ave in evoked: # Times are converted to microseconds - sfreq = ave.info['sfreq'] + sfreq = ave.info["sfreq"] duration = np.round(len(ave.times) / sfreq * 1e6).astype(int) end_time = begin_time + duration event_time = begin_time - np.round(ave.tmin * 1e6).astype(int) eeg_bads = _get_bad_eeg_channels(ave.info) content[ave.comment] = [ - _build_segment_content(begin_time, end_time, event_time, eeg_bads, - name='Average', nsegs=ave.nave) + _build_segment_content( + begin_time, + end_time, + event_time, + eeg_bads, + name="Average", + nsegs=ave.nave, + ) ] begin_time += duration return content @@ -122,49 +131,47 @@ def _get_bad_eeg_channels(info): Given a list of only the EEG channels in file, return the indices of this list (starting at 1) that correspond to bad channels. 
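    A minimal sketch of the intended behavior (the channel names and
    sampling rate below are made up for illustration)::

        import mne

        info = mne.create_info(["E1", "E2", "E3"], sfreq=250.0, ch_types="eeg")
        info["bads"] = ["E2"]
        # _get_bad_eeg_channels(info) would return [2]: indices are
        # relative to the EEG-only channel list and start at 1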
""" - if len(info['bads']) == 0: + if len(info["bads"]) == 0: return [] eeg_channels = pick_types(info, eeg=True, exclude=[]) - bad_channels = pick_channels(info['ch_names'], info['bads']) + bad_channels = pick_channels(info["ch_names"], info["bads"]) bads_elementwise = np.isin(eeg_channels, bad_channels) return list(np.flatnonzero(bads_elementwise) + 1) -def _build_segment_content(begin_time, end_time, event_time, eeg_bads, - status='unedited', name=None, pns_bads=None, - nsegs=None): +def _build_segment_content( + begin_time, + end_time, + event_time, + eeg_bads, + status="unedited", + name=None, + pns_bads=None, + nsegs=None, +): """Build content for a single segment in categories.xml. Segments are sorted into categories in categories.xml. In a segmented MFF each category can contain multiple segments, but in an averaged MFF each category only contains one segment (the average). """ - channel_status = [{ - 'signalBin': 1, - 'exclusion': 'badChannels', - 'channels': eeg_bads - }] + channel_status = [ + {"signalBin": 1, "exclusion": "badChannels", "channels": eeg_bads} + ] if pns_bads: - channel_status.append({ - 'signalBin': 2, - 'exclusion': 'badChannels', - 'channels': pns_bads - }) + channel_status.append( + {"signalBin": 2, "exclusion": "badChannels", "channels": pns_bads} + ) content = { - 'status': status, - 'beginTime': begin_time, - 'endTime': end_time, - 'evtBegin': event_time, - 'evtEnd': event_time, - 'channelStatus': channel_status, + "status": status, + "beginTime": begin_time, + "endTime": end_time, + "evtBegin": event_time, + "evtEnd": event_time, + "channelStatus": channel_status, } if name: - content['name'] = name + content["name"] = name if nsegs: - content['keys'] = { - '#seg': { - 'type': 'long', - 'data': nsegs - } - } + content["keys"] = {"#seg": {"type": "long", "data": nsegs}} return content diff --git a/mne/export/_export.py b/mne/export/_export.py index c26927d1755..5afa420540c 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -9,8 +9,16 @@ @verbose -def export_raw(fname, raw, fmt='auto', physical_range='auto', - add_ch_type=False, *, overwrite=False, verbose=None): +def export_raw( + fname, + raw, + fmt="auto", + physical_range="auto", + add_ch_type=False, + *, + overwrite=False, + verbose=None, +): """Export Raw to external formats. %(export_fmt_support_raw)s @@ -40,30 +48,39 @@ def export_raw(fname, raw, fmt='auto', physical_range='auto', """ fname = str(_check_fname(fname, overwrite=overwrite)) supported_export_formats = { # format : (extensions,) - 'eeglab': ('set',), - 'edf': ('edf',), - 'brainvision': ('eeg', 'vmrk', 'vhdr',) + "eeglab": ("set",), + "edf": ("edf",), + "brainvision": ( + "eeg", + "vmrk", + "vhdr", + ), } fmt = _infer_check_export_fmt(fmt, fname, supported_export_formats) # check for unapplied projectors - if any(not proj['active'] for proj in raw.info['projs']): - warn('Raw instance has unapplied projectors. Consider applying ' - 'them before exporting with raw.apply_proj().') + if any(not proj["active"] for proj in raw.info["projs"]): + warn( + "Raw instance has unapplied projectors. Consider applying " + "them before exporting with raw.apply_proj()." 
+ ) - if fmt == 'eeglab': + if fmt == "eeglab": from ._eeglab import _export_raw + _export_raw(fname, raw) - elif fmt == 'edf': + elif fmt == "edf": from ._edf import _export_raw + _export_raw(fname, raw, physical_range, add_ch_type) - elif fmt == 'brainvision': + elif fmt == "brainvision": from ._brainvision import _export_raw + _export_raw(fname, raw, overwrite) @verbose -def export_epochs(fname, epochs, fmt='auto', *, overwrite=False, verbose=None): +def export_epochs(fname, epochs, fmt="auto", *, overwrite=False, verbose=None): """Export Epochs to external formats. %(export_fmt_support_epochs)s @@ -90,23 +107,25 @@ def export_epochs(fname, epochs, fmt='auto', *, overwrite=False, verbose=None): """ fname = str(_check_fname(fname, overwrite=overwrite)) supported_export_formats = { - 'eeglab': ('set',), + "eeglab": ("set",), } fmt = _infer_check_export_fmt(fmt, fname, supported_export_formats) # check for unapplied projectors - if any(not proj['active'] for proj in epochs.info['projs']): - warn('Epochs instance has unapplied projectors. Consider applying ' - 'them before exporting with epochs.apply_proj().') + if any(not proj["active"] for proj in epochs.info["projs"]): + warn( + "Epochs instance has unapplied projectors. Consider applying " + "them before exporting with epochs.apply_proj()." + ) - if fmt == 'eeglab': + if fmt == "eeglab": from ._eeglab import _export_epochs + _export_epochs(fname, epochs) @verbose -def export_evokeds(fname, evoked, fmt='auto', *, overwrite=False, - verbose=None): +def export_evokeds(fname, evoked, fmt="auto", *, overwrite=False, verbose=None): """Export evoked dataset to external formats. This function is a wrapper for format-specific export functions. The export @@ -143,16 +162,16 @@ def export_evokeds(fname, evoked, fmt='auto', *, overwrite=False, """ fname = str(_check_fname(fname, overwrite=overwrite)) supported_export_formats = { - 'mff': ('mff',), + "mff": ("mff",), } fmt = _infer_check_export_fmt(fmt, fname, supported_export_formats) if not isinstance(evoked, list): evoked = [evoked] - logger.info(f'Exporting evoked dataset to {fname}...') + logger.info(f"Exporting evoked dataset to {fname}...") - if fmt == 'mff': + if fmt == "mff": export_evokeds_mff(fname, evoked, overwrite=overwrite) @@ -174,26 +193,30 @@ def _infer_check_export_fmt(fmt, fname, supported_formats): Dictionary containing supported formats (as keys) and each format's corresponding file extensions in a tuple (e.g., {'eeglab': ('set',)}) """ - _validate_type(fmt, str, 'fmt') + _validate_type(fmt, str, "fmt") fmt = fmt.lower() if fmt == "auto": fmt = op.splitext(fname)[1] if fmt: fmt = fmt[1:].lower() # find fmt in supported formats dict's tuples - fmt = next((k for k, v in supported_formats.items() if fmt in v), - fmt) # default to original fmt for raising error later + fmt = next( + (k for k, v in supported_formats.items() if fmt in v), fmt + ) # default to original fmt for raising error later else: - raise ValueError(f"Couldn't infer format from filename {fname}" - " (no extension found)") + raise ValueError( + f"Couldn't infer format from filename {fname}" " (no extension found)" + ) if fmt not in supported_formats: supported = [] for format, extensions in supported_formats.items(): - ext_str = ', '.join(f'*.{ext}' for ext in extensions) - supported.append(f'{format} ({ext_str})') - - supported_str = ', '.join(supported) - raise ValueError(f"Format '{fmt}' is not supported. 
" - f"Supported formats are {supported_str}.") + ext_str = ", ".join(f"*.{ext}" for ext in extensions) + supported.append(f"{format} ({ext_str})") + + supported_str = ", ".join(supported) + raise ValueError( + f"Format '{fmt}' is not supported. " + f"Supported formats are {supported_str}." + ) return fmt diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 27e29ab343f..4aeada34543 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -10,39 +10,49 @@ import pytest import numpy as np -from numpy.testing import (assert_allclose, assert_array_almost_equal, - assert_array_equal) +from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal -from mne import (read_epochs_eeglab, Epochs, read_evokeds, read_evokeds_mff, - Annotations) +from mne import read_epochs_eeglab, Epochs, read_evokeds, read_evokeds_mff, Annotations from mne.datasets import testing, misc from mne.export import export_evokeds, export_evokeds_mff from mne.fixes import _compare_version -from mne.io import (RawArray, read_raw_fif, read_raw_eeglab, read_raw_edf, - read_raw_brainvision) +from mne.io import ( + RawArray, + read_raw_fif, + read_raw_eeglab, + read_raw_edf, + read_raw_brainvision, +) from mne.io.meas_info import create_info -from mne.utils import (_check_eeglabio_installed, requires_version, - object_diff, _check_edflib_installed, _resource_path, - _check_pybv_installed, _record_warnings) +from mne.utils import ( + _check_eeglabio_installed, + requires_version, + object_diff, + _check_edflib_installed, + _resource_path, + _check_pybv_installed, + _record_warnings, +) from mne.tests.test_epochs import _get_data -fname_evoked = _resource_path('mne.io.tests.data', 'test-ave.fif') -fname_raw = _resource_path('mne.io.tests.data', 'test_raw.fif') +fname_evoked = _resource_path("mne.io.tests.data", "test-ave.fif") +fname_raw = _resource_path("mne.io.tests.data", "test_raw.fif") data_path = testing.data_path(download=False) egi_evoked_fname = data_path / "EGI" / "test_egi_evoked.mff" misc_path = misc.data_path(download=False) -@pytest.mark.skipif(not _check_pybv_installed(strict=False), - reason='pybv not installed') +@pytest.mark.skipif( + not _check_pybv_installed(strict=False), reason="pybv not installed" +) @pytest.mark.parametrize( - ['meas_date', 'orig_time', 'ext'], [ - [None, None, '.vhdr'], - [datetime(2022, 12, 3, 19, 1, 10, 720100, tzinfo=timezone.utc), - None, - '.eeg'], - ]) + ["meas_date", "orig_time", "ext"], + [ + [None, None, ".vhdr"], + [datetime(2022, 12, 3, 19, 1, 10, 720100, tzinfo=timezone.utc), None, ".eeg"], + ], +) def test_export_raw_pybv(tmp_path, meas_date, orig_time, ext): """Test saving a Raw instance to BrainVision format via pybv.""" raw = read_raw_fif(fname_raw, preload=True) @@ -66,39 +76,39 @@ def test_export_raw_pybv(tmp_path, meas_date, orig_time, ext): ) raw.set_annotations(annots) - temp_fname = tmp_path / ('test' + ext) + temp_fname = tmp_path / ("test" + ext) with pytest.warns(RuntimeWarning, match="'short' format. 
Converting"): raw.export(temp_fname) - raw_read = read_raw_brainvision(str(temp_fname).replace('.eeg', '.vhdr')) + raw_read = read_raw_brainvision(str(temp_fname).replace(".eeg", ".vhdr")) assert raw.ch_names == raw_read.ch_names assert_allclose(raw.times, raw_read.times) assert_allclose(raw.get_data(), raw_read.get_data()) -@requires_version('pymatreader') -@pytest.mark.skipif(not _check_eeglabio_installed(strict=False), - reason='eeglabio not installed') +@requires_version("pymatreader") +@pytest.mark.skipif( + not _check_eeglabio_installed(strict=False), reason="eeglabio not installed" +) def test_export_raw_eeglab(tmp_path): """Test saving a Raw instance to EEGLAB's set format.""" raw = read_raw_fif(fname_raw, preload=True) raw.apply_proj() temp_fname = tmp_path / "test.set" raw.export(temp_fname) - raw.drop_channels([ch for ch in ['epoc'] - if ch in raw.ch_names]) + raw.drop_channels([ch for ch in ["epoc"] if ch in raw.ch_names]) - with pytest.warns(RuntimeWarning, match='is above the 99th percentile'): - raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units='m') + with pytest.warns(RuntimeWarning, match="is above the 99th percentile"): + raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units="m") assert raw.ch_names == raw_read.ch_names - cart_coords = np.array([d['loc'][:3] for d in raw.info['chs']]) # just xyz - cart_coords_read = np.array([d['loc'][:3] for d in raw_read.info['chs']]) + cart_coords = np.array([d["loc"][:3] for d in raw.info["chs"]]) # just xyz + cart_coords_read = np.array([d["loc"][:3] for d in raw_read.info["chs"]]) assert_allclose(cart_coords, cart_coords_read) assert_allclose(raw.times, raw_read.times) assert_allclose(raw.get_data(), raw_read.get_data()) # test overwrite - with pytest.raises(FileExistsError, match='Destination file exists'): + with pytest.raises(FileExistsError, match="Destination file exists"): raw.export(temp_fname, overwrite=False) raw.export(temp_fname, overwrite=True) @@ -107,29 +117,41 @@ def test_export_raw_eeglab(tmp_path): # test warning with unapplied projectors raw = read_raw_fif(fname_raw, preload=True) - with pytest.warns(RuntimeWarning, - match='Raw instance has unapplied projectors.'): + with pytest.warns(RuntimeWarning, match="Raw instance has unapplied projectors."): raw.export(temp_fname, overwrite=True) -@pytest.mark.skipif(not _check_edflib_installed(strict=False), - reason='edflib-python not installed') +@pytest.mark.skipif( + not _check_edflib_installed(strict=False), reason="edflib-python not installed" +) def test_double_export_edf(tmp_path): """Test exporting an EDF file multiple times.""" rng = np.random.RandomState(123456) - format = 'edf' - ch_types = ['eeg', 'eeg', 'stim', 'ecog', 'ecog', 'seeg', 'eog', 'ecg', - 'emg', 'dbs', 'bio'] + format = "edf" + ch_types = [ + "eeg", + "eeg", + "stim", + "ecog", + "ecog", + "seeg", + "eog", + "ecg", + "emg", + "dbs", + "bio", + ] info = create_info(len(ch_types), sfreq=1000, ch_types=ch_types) data = rng.random(size=(len(ch_types), 1000)) * 1e-5 # include subject info and measurement date - info['subject_info'] = dict(first_name='mne', last_name='python', - birthday=(1992, 1, 20), sex=1, hand=3) + info["subject_info"] = dict( + first_name="mne", last_name="python", birthday=(1992, 1, 20), sex=1, hand=3 + ) raw = RawArray(data, info) # export once - temp_fname = tmp_path / f'test.{format}' + temp_fname = tmp_path / f"test.{format}" raw.export(temp_fname, add_ch_type=True) raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) @@ -139,15 
+161,15 @@ def test_double_export_edf(tmp_path): raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) # stim channel should be dropped - raw.drop_channels('2') + raw.drop_channels("2") assert raw.ch_names == raw_read.ch_names # only compare the original length, since extra zeros are appended orig_raw_len = len(raw) assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4) - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) # check channel types except for 'bio', which loses its type orig_ch_types = raw.get_channel_types() @@ -155,28 +177,41 @@ def test_double_export_edf(tmp_path): assert_array_equal(orig_ch_types, read_ch_types) # check handling of missing subject metadata - del info['subject_info']['sex'] + del info["subject_info"]["sex"] raw_2 = RawArray(data, info) raw_2.export(temp_fname, add_ch_type=True, overwrite=True) -@pytest.mark.skipif(not _check_edflib_installed(strict=False), - reason='edflib-python not installed') +@pytest.mark.skipif( + not _check_edflib_installed(strict=False), reason="edflib-python not installed" +) def test_export_edf_annotations(tmp_path): """Test that exporting EDF preserves annotations.""" rng = np.random.RandomState(123456) - format = 'edf' - ch_types = ['eeg', 'eeg', 'stim', 'ecog', 'ecog', 'seeg', - 'eog', 'ecg', 'emg', 'dbs', 'bio'] + format = "edf" + ch_types = [ + "eeg", + "eeg", + "stim", + "ecog", + "ecog", + "seeg", + "eog", + "ecg", + "emg", + "dbs", + "bio", + ] ch_names = np.arange(len(ch_types)).astype(str).tolist() - info = create_info(ch_names, sfreq=1000, - ch_types=ch_types) - data = rng.random(size=(len(ch_names), 2000)) * 1.e-5 + info = create_info(ch_names, sfreq=1000, ch_types=ch_types) + data = rng.random(size=(len(ch_names), 2000)) * 1.0e-5 raw = RawArray(data, info) annotations = Annotations( - onset=[0.01, 0.05, 0.90, 1.05], duration=[0, 1, 0, 0], - description=['test1', 'test2', 'test3', 'test4']) + onset=[0.01, 0.05, 0.90, 1.05], + duration=[0, 1, 0, 0], + description=["test1", "test2", "test3", "test4"], + ) raw.set_annotations(annotations) # export @@ -187,33 +222,37 @@ def test_export_edf_annotations(tmp_path): raw_read = read_raw_edf(temp_fname, preload=True) assert_array_equal(raw.annotations.onset, raw_read.annotations.onset) assert_array_equal(raw.annotations.duration, raw_read.annotations.duration) - assert_array_equal(raw.annotations.description, - raw_read.annotations.description) + assert_array_equal(raw.annotations.description, raw_read.annotations.description) -@pytest.mark.skipif(not _check_edflib_installed(strict=False), - reason='edflib-python not installed') +@pytest.mark.skipif( + not _check_edflib_installed(strict=False), reason="edflib-python not installed" +) def test_rawarray_edf(tmp_path): """Test saving a Raw array with integer sfreq to EDF.""" rng = np.random.RandomState(12345) - format = 'edf' - ch_types = ['eeg', 'eeg', 'stim', 'ecog', 'seeg', 'eog', 'ecg', 'emg', - 'dbs', 'bio'] + format = "edf" + ch_types = ["eeg", "eeg", "stim", "ecog", "seeg", "eog", "ecg", "emg", "dbs", "bio"] ch_names = np.arange(len(ch_types)).astype(str).tolist() - info = create_info(ch_names, sfreq=1000, - ch_types=ch_types) + info = create_info(ch_names, sfreq=1000, ch_types=ch_types) data = rng.random(size=(len(ch_names), 1000)) * 1e-5 # include subject info and measurement date - subject_info 
= dict(first_name='mne', last_name='python', - birthday=(1992, 1, 20), sex=1, hand=3) - info['subject_info'] = subject_info + subject_info = dict( + first_name="mne", last_name="python", birthday=(1992, 1, 20), sex=1, hand=3 + ) + info["subject_info"] = subject_info raw = RawArray(data, info) time_now = datetime.now() - meas_date = datetime(year=time_now.year, month=time_now.month, - day=time_now.day, hour=time_now.hour, - minute=time_now.minute, second=time_now.second, - tzinfo=timezone.utc) + meas_date = datetime( + year=time_now.year, + month=time_now.month, + day=time_now.day, + hour=time_now.hour, + minute=time_now.minute, + second=time_now.second, + tzinfo=timezone.utc, + ) raw.set_meas_date(meas_date) temp_fname = tmp_path / f"test.{format}" @@ -221,82 +260,84 @@ def test_rawarray_edf(tmp_path): raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) # stim channel should be dropped - raw.drop_channels('2') + raw.drop_channels("2") assert raw.ch_names == raw_read.ch_names # only compare the original length, since extra zeros are appended orig_raw_len = len(raw) assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4) - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) # check channel types except for 'bio', which loses its type orig_ch_types = raw.get_channel_types() read_ch_types = raw_read.get_channel_types() assert_array_equal(orig_ch_types, read_ch_types) - assert raw.info['meas_date'] == raw_read.info['meas_date'] + assert raw.info["meas_date"] == raw_read.info["meas_date"] # channel name can't be longer than 16 characters with the type added raw_bad = raw.copy() - raw_bad.rename_channels({'1': 'abcdefghijklmnopqrstuvwxyz'}) - with pytest.raises(RuntimeError, match='Signal label'), \ - pytest.warns(RuntimeWarning, match='Data has a non-integer'): + raw_bad.rename_channels({"1": "abcdefghijklmnopqrstuvwxyz"}) + with pytest.raises(RuntimeError, match="Signal label"), pytest.warns( + RuntimeWarning, match="Data has a non-integer" + ): raw_bad.export(temp_fname, overwrite=True) # include bad birthday that is non-EDF compliant bad_info = info.copy() - bad_info['subject_info']['birthday'] = (1700, 1, 20) + bad_info["subject_info"]["birthday"] = (1700, 1, 20) raw = RawArray(data, bad_info) - with pytest.raises(RuntimeError, match='Setting patient birth date'): + with pytest.raises(RuntimeError, match="Setting patient birth date"): raw.export(temp_fname, overwrite=True) # include bad measurement date that is non-EDF compliant raw = RawArray(data, info) meas_date = datetime(year=1984, month=1, day=1, tzinfo=timezone.utc) raw.set_meas_date(meas_date) - with pytest.raises(RuntimeError, match='Setting start date time'): + with pytest.raises(RuntimeError, match="Setting start date time"): raw.export(temp_fname, overwrite=True) # test that warning is raised if there are non-voltage based channels raw = RawArray(data, info) - raw.set_channel_types({'9': 'hbr'}, on_unit_change='ignore') - with pytest.warns(RuntimeWarning, match='Non-voltage channels'): + raw.set_channel_types({"9": "hbr"}, on_unit_change="ignore") + with pytest.warns(RuntimeWarning, match="Non-voltage channels"): raw.export(temp_fname, overwrite=True) # data should match up to the non-accepted channel raw_read = read_raw_edf(temp_fname, preload=True) orig_raw_len = len(raw) assert_array_almost_equal( 
- raw.get_data()[:-1, :], raw_read.get_data()[:, :orig_raw_len], - decimal=4) - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + raw.get_data()[:-1, :], raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) # the data should still match though raw_read = read_raw_edf(temp_fname, preload=True) - raw.drop_channels('2') + raw.drop_channels("2") assert raw.ch_names == raw_read.ch_names orig_raw_len = len(raw) assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4) - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) -@pytest.mark.skipif(not _check_edflib_installed(strict=False), - reason='edflib-python not installed') +@pytest.mark.skipif( + not _check_edflib_installed(strict=False), reason="edflib-python not installed" +) @pytest.mark.parametrize( - ['dataset', 'format'], [ - ['test', 'edf'], - pytest.param('misc', 'edf', marks=[pytest.mark.slowtest, - misc._pytest_mark()]), - ]) + ["dataset", "format"], + [ + ["test", "edf"], + pytest.param("misc", "edf", marks=[pytest.mark.slowtest, misc._pytest_mark()]), + ], +) def test_export_raw_edf(tmp_path, dataset, format): """Test saving a Raw instance to EDF format.""" - if dataset == 'test': + if dataset == "test": raw = read_raw_fif(fname_raw) - elif dataset == 'misc': + elif dataset == "misc": fname = misc_path / "ecog" / "sample_ecog_ieeg.fif" raw = read_raw_fif(fname) @@ -309,31 +350,27 @@ def test_export_raw_edf(tmp_path, dataset, format): # test runtime errors with pytest.warns() as record: raw.export(temp_fname, physical_range=(-1e6, 0)) - if dataset == 'test': - assert any( - "Data has a non-integer" in str(rec.message) for rec in record - ) + if dataset == "test": + assert any("Data has a non-integer" in str(rec.message) for rec in record) assert any("The maximum" in str(rec.message) for rec in record) remove(temp_fname) with pytest.warns() as record: raw.export(temp_fname, physical_range=(0, 1e6)) - if dataset == 'test': - assert any( - "Data has a non-integer" in str(rec.message) for rec in record - ) + if dataset == "test": + assert any("Data has a non-integer" in str(rec.message) for rec in record) assert any("The minimum" in str(rec.message) for rec in record) remove(temp_fname) - if dataset == 'test': - with pytest.warns(RuntimeWarning, match='Data has a non-integer'): + if dataset == "test": + with pytest.warns(RuntimeWarning, match="Data has a non-integer"): raw.export(temp_fname) - elif dataset == 'misc': - with pytest.warns(RuntimeWarning, match='EDF format requires'): + elif dataset == "misc": + with pytest.warns(RuntimeWarning, match="EDF format requires"): raw.export(temp_fname) - if 'epoc' in raw.ch_names: - raw.drop_channels(['epoc']) + if "epoc" in raw.ch_names: + raw.drop_channels(["epoc"]) raw_read = read_raw_edf(temp_fname, preload=True) assert orig_ch_names == raw_read.ch_names @@ -346,7 +383,8 @@ def test_export_raw_edf(tmp_path, dataset, format): # will result in a resolution of 0.09 uV. This resolution # though is acceptable for most EEG manufacturers. 
assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4) + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) # Due to the data record duration limitations of EDF files, one # cannot store arbitrary float sampling rate exactly. Usually this @@ -354,46 +392,43 @@ def test_export_raw_edf(tmp_path, dataset, format): # decimal points. This for practical purposes does not matter # but will result in an error when say the number of time points # is very very large. - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) -@pytest.mark.xfail(reason='eeglabio (usage?) bugs that should be fixed') -@requires_version('pymatreader') -@pytest.mark.skipif(not _check_eeglabio_installed(strict=False), - reason='eeglabio not installed') -@pytest.mark.parametrize('preload', (True, False)) +@pytest.mark.xfail(reason="eeglabio (usage?) bugs that should be fixed") +@requires_version("pymatreader") +@pytest.mark.skipif( + not _check_eeglabio_installed(strict=False), reason="eeglabio not installed" +) +@pytest.mark.parametrize("preload", (True, False)) def test_export_epochs_eeglab(tmp_path, preload): """Test saving an Epochs instance to EEGLAB's set format.""" import eeglabio + raw, events = _get_data()[:2] raw.load_data() epochs = Epochs(raw, events, preload=preload) temp_fname = tmp_path / "test.set" # TODO: eeglabio 0.2 warns about invalid events - if _compare_version(eeglabio.__version__, '==', '0.0.2-1'): + if _compare_version(eeglabio.__version__, "==", "0.0.2-1"): ctx = _record_warnings else: ctx = nullcontext with ctx(): epochs.export(temp_fname) - epochs.drop_channels([ch for ch in ['epoc', 'STI 014'] - if ch in epochs.ch_names]) + epochs.drop_channels([ch for ch in ["epoc", "STI 014"] if ch in epochs.ch_names]) epochs_read = read_epochs_eeglab(temp_fname) assert epochs.ch_names == epochs_read.ch_names - cart_coords = np.array([d['loc'][:3] - for d in epochs.info['chs']]) # just xyz - cart_coords_read = np.array([d['loc'][:3] - for d in epochs_read.info['chs']]) + cart_coords = np.array([d["loc"][:3] for d in epochs.info["chs"]]) # just xyz + cart_coords_read = np.array([d["loc"][:3] for d in epochs_read.info["chs"]]) assert_allclose(cart_coords, cart_coords_read) - assert_array_equal(epochs.events[:, 0], - epochs_read.events[:, 0]) # latency + assert_array_equal(epochs.events[:, 0], epochs_read.events[:, 0]) # latency assert epochs.event_id.keys() == epochs_read.event_id.keys() # just keys assert_allclose(epochs.times, epochs_read.times) assert_allclose(epochs.get_data(), epochs_read.get_data()) # test overwrite - with pytest.raises(FileExistsError, match='Destination file exists'): + with pytest.raises(FileExistsError, match="Destination file exists"): epochs.export(temp_fname, overwrite=False) with ctx(): epochs.export(temp_fname, overwrite=True) @@ -404,45 +439,46 @@ def test_export_epochs_eeglab(tmp_path, preload): # test warning with unapplied projectors epochs = Epochs(raw, events, preload=preload, proj=False) - with pytest.warns(RuntimeWarning, - match='Epochs instance has unapplied projectors.'): + with pytest.warns( + RuntimeWarning, match="Epochs instance has unapplied projectors." 
+ ): epochs.export(Path(temp_fname), overwrite=True) -@pytest.mark.filterwarnings('ignore::FutureWarning') -@requires_version('mffpy', '0.5.7') +@pytest.mark.filterwarnings("ignore::FutureWarning") +@requires_version("mffpy", "0.5.7") @testing.requires_testing_data -@pytest.mark.parametrize('fmt', ('auto', 'mff')) -@pytest.mark.parametrize('do_history', (True, False)) +@pytest.mark.parametrize("fmt", ("auto", "mff")) +@pytest.mark.parametrize("do_history", (True, False)) def test_export_evokeds_to_mff(tmp_path, fmt, do_history): """Test exporting evoked dataset to MFF.""" evoked = read_evokeds_mff(egi_evoked_fname) export_fname = tmp_path / "evoked.mff" history = [ { - 'name': 'Test Segmentation', - 'method': 'Segmentation', - 'settings': ['Setting 1', 'Setting 2'], - 'results': ['Result 1', 'Result 2'] + "name": "Test Segmentation", + "method": "Segmentation", + "settings": ["Setting 1", "Setting 2"], + "results": ["Result 1", "Result 2"], }, { - 'name': 'Test Averaging', - 'method': 'Averaging', - 'settings': ['Setting 1', 'Setting 2'], - 'results': ['Result 1', 'Result 2'] - } + "name": "Test Averaging", + "method": "Averaging", + "settings": ["Setting 1", "Setting 2"], + "results": ["Result 1", "Result 2"], + }, ] if do_history: export_evokeds_mff(export_fname, evoked, history=history) else: export_evokeds(export_fname, evoked, fmt=fmt) # Drop non-EEG channels - evoked = [ave.drop_channels(['ECG', 'EMG']) for ave in evoked] + evoked = [ave.drop_channels(["ECG", "EMG"]) for ave in evoked] evoked_exported = read_evokeds_mff(export_fname) assert len(evoked) == len(evoked_exported) for ave, ave_exported in zip(evoked, evoked_exported): # Compare infos - assert object_diff(ave_exported.info, ave.info) == '' + assert object_diff(ave_exported.info, ave.info) == "" # Compare data assert_allclose(ave_exported.data, ave.data) # Compare properties @@ -452,16 +488,14 @@ def test_export_evokeds_to_mff(tmp_path, fmt, do_history): assert_allclose(ave_exported.times, ave.times) # test overwrite - with pytest.raises(FileExistsError, match='Destination file exists'): + with pytest.raises(FileExistsError, match="Destination file exists"): if do_history: - export_evokeds_mff(export_fname, evoked, history=history, - overwrite=False) + export_evokeds_mff(export_fname, evoked, history=history, overwrite=False) else: export_evokeds(export_fname, evoked, overwrite=False) if do_history: - export_evokeds_mff(export_fname, evoked, history=history, - overwrite=True) + export_evokeds_mff(export_fname, evoked, history=history, overwrite=True) else: export_evokeds(export_fname, evoked, overwrite=True) @@ -469,35 +503,33 @@ def test_export_evokeds_to_mff(tmp_path, fmt, do_history): evoked[0].export(export_fname, overwrite=True) -@pytest.mark.filterwarnings('ignore::FutureWarning') -@requires_version('mffpy', '0.5.7') +@pytest.mark.filterwarnings("ignore::FutureWarning") +@requires_version("mffpy", "0.5.7") @testing.requires_testing_data def test_export_to_mff_no_device(): """Test no device type throws ValueError.""" - evoked = read_evokeds_mff(egi_evoked_fname, condition='Category 1') - evoked.info['device_info'] = None - with pytest.raises(ValueError, match='No device type.'): - export_evokeds('output.mff', evoked) + evoked = read_evokeds_mff(egi_evoked_fname, condition="Category 1") + evoked.info["device_info"] = None + with pytest.raises(ValueError, match="No device type."): + export_evokeds("output.mff", evoked) -@pytest.mark.filterwarnings('ignore::FutureWarning') -@requires_version('mffpy', '0.5.7') 
+@pytest.mark.filterwarnings("ignore::FutureWarning") +@requires_version("mffpy", "0.5.7") def test_export_to_mff_incompatible_sfreq(): """Test non-whole number sampling frequency throws ValueError.""" evoked = read_evokeds(fname_evoked) with pytest.raises(ValueError, match=f'sfreq: {evoked[0].info["sfreq"]}'): - export_evokeds('output.mff', evoked) + export_evokeds("output.mff", evoked) -@pytest.mark.parametrize('fmt,ext', [ - ('EEGLAB', 'set'), - ('EDF', 'edf'), - ('BrainVision', 'vhdr'), - ('auto', 'vhdr') -]) +@pytest.mark.parametrize( + "fmt,ext", + [("EEGLAB", "set"), ("EDF", "edf"), ("BrainVision", "vhdr"), ("auto", "vhdr")], +) def test_export_evokeds_unsupported_format(fmt, ext): """Test exporting evoked dataset to non-supported formats.""" evoked = read_evokeds(fname_evoked) errstr = fmt.lower() if fmt != "auto" else "vhdr" with pytest.raises(ValueError, match=f"Format '{errstr}' is not .*"): - export_evokeds(f'output.{ext}', evoked, fmt=fmt) + export_evokeds(f"output.{ext}", evoked, fmt=fmt) diff --git a/mne/filter.py b/mne/filter.py index 2fc0d10b2c4..5277a1fd502 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -8,11 +8,25 @@ from .annotations import _annotations_starts_stops from .io.pick import _picks_to_idx -from .cuda import (_setup_cuda_fft_multiply_repeated, _fft_multiply_repeated, - _setup_cuda_fft_resample, _fft_resample, _smart_pad) +from .cuda import ( + _setup_cuda_fft_multiply_repeated, + _fft_multiply_repeated, + _setup_cuda_fft_resample, + _fft_resample, + _smart_pad, +) from .parallel import parallel_func -from .utils import (logger, verbose, sum_squared, warn, _pl, - _check_preload, _validate_type, _check_option, _ensure_int) +from .utils import ( + logger, + verbose, + sum_squared, + warn, + _pl, + _check_preload, + _validate_type, + _check_option, + _ensure_int, +) from ._ola import _COLA # These values from Ifeachor and Jervis. @@ -66,20 +80,178 @@ def next_fast_len(target): Copied from SciPy with minor modifications. 
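    A small usage sketch (return values follow from the 5-smooth table
    in the function body)::

        next_fast_len(1021)  # -> 1024, i.e. 2 ** 10
        next_fast_len(1000)  # -> 1000, already 5-smooth (2 ** 3 * 5 ** 3)

    Padding an FFT to such a length keeps its factors limited to 2, 3,
    and 5, which is typically much faster than a large prime length.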
""" from bisect import bisect_left - hams = (8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48, - 50, 54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128, - 135, 144, 150, 160, 162, 180, 192, 200, 216, 225, 240, 243, 250, - 256, 270, 288, 300, 320, 324, 360, 375, 384, 400, 405, 432, 450, - 480, 486, 500, 512, 540, 576, 600, 625, 640, 648, 675, 720, 729, - 750, 768, 800, 810, 864, 900, 960, 972, 1000, 1024, 1080, 1125, - 1152, 1200, 1215, 1250, 1280, 1296, 1350, 1440, 1458, 1500, 1536, - 1600, 1620, 1728, 1800, 1875, 1920, 1944, 2000, 2025, 2048, 2160, - 2187, 2250, 2304, 2400, 2430, 2500, 2560, 2592, 2700, 2880, 2916, - 3000, 3072, 3125, 3200, 3240, 3375, 3456, 3600, 3645, 3750, 3840, - 3888, 4000, 4050, 4096, 4320, 4374, 4500, 4608, 4800, 4860, 5000, - 5120, 5184, 5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250, 6400, - 6480, 6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100, - 8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000) + + hams = ( + 8, + 9, + 10, + 12, + 15, + 16, + 18, + 20, + 24, + 25, + 27, + 30, + 32, + 36, + 40, + 45, + 48, + 50, + 54, + 60, + 64, + 72, + 75, + 80, + 81, + 90, + 96, + 100, + 108, + 120, + 125, + 128, + 135, + 144, + 150, + 160, + 162, + 180, + 192, + 200, + 216, + 225, + 240, + 243, + 250, + 256, + 270, + 288, + 300, + 320, + 324, + 360, + 375, + 384, + 400, + 405, + 432, + 450, + 480, + 486, + 500, + 512, + 540, + 576, + 600, + 625, + 640, + 648, + 675, + 720, + 729, + 750, + 768, + 800, + 810, + 864, + 900, + 960, + 972, + 1000, + 1024, + 1080, + 1125, + 1152, + 1200, + 1215, + 1250, + 1280, + 1296, + 1350, + 1440, + 1458, + 1500, + 1536, + 1600, + 1620, + 1728, + 1800, + 1875, + 1920, + 1944, + 2000, + 2025, + 2048, + 2160, + 2187, + 2250, + 2304, + 2400, + 2430, + 2500, + 2560, + 2592, + 2700, + 2880, + 2916, + 3000, + 3072, + 3125, + 3200, + 3240, + 3375, + 3456, + 3600, + 3645, + 3750, + 3840, + 3888, + 4000, + 4050, + 4096, + 4320, + 4374, + 4500, + 4608, + 4800, + 4860, + 5000, + 5120, + 5184, + 5400, + 5625, + 5760, + 5832, + 6000, + 6075, + 6144, + 6250, + 6400, + 6480, + 6561, + 6750, + 6912, + 7200, + 7290, + 7500, + 7680, + 7776, + 8000, + 8100, + 8192, + 8640, + 8748, + 9000, + 9216, + 9375, + 9600, + 9720, + 10000, + ) if target <= 6: return target @@ -92,7 +264,7 @@ def next_fast_len(target): if target <= hams[-1]: return hams[bisect_left(hams, target)] - match = float('inf') # Anything found will be smaller + match = float("inf") # Anything found will be smaller p5 = 1 while p5 < target: p35 = p5 @@ -121,8 +293,16 @@ def next_fast_len(target): return match -def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, - n_jobs=None, copy=True, pad='reflect_limited'): +def _overlap_add_filter( + x, + h, + n_fft=None, + phase="zero", + picks=None, + n_jobs=None, + copy=True, + pad="reflect_limited", +): """Filter the signal x using h with overlap-add FFTs. 
Parameters @@ -162,12 +342,12 @@ def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, # response _check_zero_phase_length(len(h), phase) if len(h) == 1: - return x * h ** 2 if phase == 'zero-double' else x * h + return x * h**2 if phase == "zero-double" else x * h n_edge = max(min(len(h), x.shape[1]) - 1, 0) - logger.debug('Smart-padding with: %s samples on each edge' % n_edge) + logger.debug("Smart-padding with: %s samples on each edge" % n_edge) n_x = x.shape[1] + 2 * n_edge - if phase == 'zero-double': + if phase == "zero-double": h = np.convolve(h, h[::-1]) # Determine FFT length to use @@ -176,10 +356,14 @@ def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, max_fft = n_x if max_fft >= min_fft: # cost function based on number of multiplications - N = 2 ** np.arange(np.ceil(np.log2(min_fft)), - np.ceil(np.log2(max_fft)) + 1, dtype=int) - cost = (np.ceil(n_x / (N - len(h) + 1).astype(np.float64)) * - N * (np.log2(N) + 1)) + N = 2 ** np.arange( + np.ceil(np.log2(min_fft)), np.ceil(np.log2(max_fft)) + 1, dtype=int + ) + cost = ( + np.ceil(n_x / (N - len(h) + 1).astype(np.float64)) + * N + * (np.log2(N) + 1) + ) # add a heuristic term to prevent too-long FFT's which are slow # (not predicted by mult. cost alone, 4e-5 exp. determined) @@ -189,10 +373,12 @@ def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, else: # Use only a single block n_fft = next_fast_len(min_fft) - logger.debug('FFT block length: %s' % n_fft) + logger.debug("FFT block length: %s" % n_fft) if n_fft < min_fft: - raise ValueError('n_fft is too short, has to be at least ' - '2 * len(h) - 1 (%s), got %s' % (min_fft, n_fft)) + raise ValueError( + "n_fft is too short, has to be at least " + "2 * len(h) - 1 (%s), got %s" % (min_fft, n_fft) + ) # Figure out if we should use CUDA n_jobs, cuda_dict = _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft) @@ -202,11 +388,13 @@ def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, parallel, p_fun, _ = parallel_func(_1d_overlap_filter, n_jobs) if n_jobs == 1: for p in picks: - x[p] = _1d_overlap_filter(x[p], len(h), n_edge, phase, - cuda_dict, pad, n_fft) + x[p] = _1d_overlap_filter( + x[p], len(h), n_edge, phase, cuda_dict, pad, n_fft + ) else: - data_new = parallel(p_fun(x[p], len(h), n_edge, phase, - cuda_dict, pad, n_fft) for p in picks) + data_new = parallel( + p_fun(x[p], len(h), n_edge, phase, cuda_dict, pad, n_fft) for p in picks + ) for pp, p in enumerate(picks): x[p] = data_new[pp] @@ -223,7 +411,7 @@ def _1d_overlap_filter(x, n_h, n_edge, phase, cuda_dict, pad, n_fft): n_seg = n_fft - n_h + 1 n_segments = int(np.ceil(n_x / float(n_seg))) - shift = ((n_h - 1) // 2 if phase.startswith('zero') else 0) + n_edge + shift = ((n_h - 1) // 2 if phase.startswith("zero") else 0) + n_edge # Now the actual filtering step is identical for zero-phase (filtfilt-like) # or single-pass @@ -242,13 +430,14 @@ def _1d_overlap_filter(x, n_h, n_edge, phase, cuda_dict, pad, n_fft): x_filtered[start_filt:stop_filt] += prod[start_prod:stop_prod] # Remove mirrored edges that we added and cast (n_edge can be zero) - x_filtered = x_filtered[:n_x - 2 * n_edge].astype(x.dtype) + x_filtered = x_filtered[: n_x - 2 * n_edge].astype(x.dtype) return x_filtered def _filter_attenuation(h, freq, gain): """Compute minimum attenuation at stop frequency.""" from scipy.signal import freqz + _, filt_resp = freqz(h.ravel(), worN=np.pi * freq) filt_resp = np.abs(filt_resp) # use amplitude response filt_resp[np.where(gain == 1)] = 0 @@ -269,12 +458,13 @@ def 
_prep_for_filtering(x, copy, picks=None): x.shape = (np.prod(x.shape[:-1]), x.shape[-1]) if len(orig_shape) == 3: n_epochs, n_channels, n_times = orig_shape - offset = np.repeat(np.arange(0, n_channels * n_epochs, n_channels), - len(picks)) + offset = np.repeat(np.arange(0, n_channels * n_epochs, n_channels), len(picks)) picks = np.tile(picks, n_epochs) + offset elif len(orig_shape) > 3: - raise ValueError('picks argument is not supported for data with more' - ' than three dimensions') + raise ValueError( + "picks argument is not supported for data with more" + " than three dimensions" + ) assert all(0 <= pick < x.shape[0] for pick in picks) # guaranteed by above return x, orig_shape, picks @@ -283,6 +473,7 @@ def _prep_for_filtering(x, copy, picks=None): def _firwin_design(N, freq, gain, window, sfreq): """Construct a FIR filter using firwin.""" from scipy.signal import firwin + assert freq[0] == 0 assert len(freq) > 1 assert len(freq) == len(gain) @@ -297,30 +488,37 @@ def _firwin_design(N, freq, gain, window, sfreq): assert this_gain in (0, 1) if this_gain != prev_gain: # Get the correct N to satistify the requested transition bandwidth - transition = (prev_freq - this_freq) / 2. + transition = (prev_freq - this_freq) / 2.0 this_N = int(round(_length_factors[window] / transition)) - this_N += (1 - this_N % 2) # make it odd + this_N += 1 - this_N % 2 # make it odd if this_N > N: - raise ValueError('The requested filter length %s is too short ' - 'for the requested %0.2f Hz transition band, ' - 'which requires %s samples' - % (N, transition * sfreq / 2., this_N)) + raise ValueError( + "The requested filter length %s is too short " + "for the requested %0.2f Hz transition band, " + "which requires %s samples" % (N, transition * sfreq / 2.0, this_N) + ) # Construct a lowpass - this_h = firwin(this_N, (prev_freq + this_freq) / 2., - window=window, pass_zero=True, fs=freq[-1] * 2) + this_h = firwin( + this_N, + (prev_freq + this_freq) / 2.0, + window=window, + pass_zero=True, + fs=freq[-1] * 2, + ) assert this_h.shape == (this_N,) offset = (N - this_N) // 2 if this_gain == 0: - h[offset:N - offset] -= this_h + h[offset : N - offset] -= this_h else: - h[offset:N - offset] += this_h + h[offset : N - offset] += this_h prev_gain = this_gain prev_freq = this_freq return h -def _construct_fir_filter(sfreq, freq, gain, filter_length, phase, fir_window, - fir_design): +def _construct_fir_filter( + sfreq, freq, gain, filter_length, phase, fir_window, fir_design +): """Filter signal using gain control points in the frequency domain. The filter impulse response is constructed from a Hann window (window @@ -358,50 +556,53 @@ def _construct_fir_filter(sfreq, freq, gain, filter_length, phase, fir_window, Filter coefficients. """ assert freq[0] == 0 - if fir_design == 'firwin2': + if fir_design == "firwin2": from scipy.signal import firwin2 as fir_design else: - assert fir_design == 'firwin' + assert fir_design == "firwin" fir_design = partial(_firwin_design, sfreq=sfreq) from scipy.signal import minimum_phase # issue a warning if attenuation is less than this - min_att_db = 12 if phase == 'minimum' else 20 + min_att_db = 12 if phase == "minimum" else 20 # normalize frequencies - freq = np.array(freq) / (sfreq / 2.) 
+ freq = np.array(freq) / (sfreq / 2.0) if freq[0] != 0 or freq[-1] != 1: - raise ValueError('freq must start at 0 and end an Nyquist (%s), got %s' - % (sfreq / 2., freq)) + raise ValueError( + "freq must start at 0 and end an Nyquist (%s), got %s" % (sfreq / 2.0, freq) + ) gain = np.array(gain) # Use overlap-add filter with a fixed length N = _check_zero_phase_length(filter_length, phase, gain[-1]) # construct symmetric (linear phase) filter - if phase == 'minimum': + if phase == "minimum": h = fir_design(N * 2 - 1, freq, gain, window=fir_window) h = minimum_phase(h) else: h = fir_design(N, freq, gain, window=fir_window) assert h.size == N att_db, att_freq = _filter_attenuation(h, freq, gain) - if phase == 'zero-double': + if phase == "zero-double": att_db += 6 if att_db < min_att_db: - att_freq *= sfreq / 2. - warn('Attenuation at stop frequency %0.2f Hz is only %0.2f dB. ' - 'Increase filter_length for higher attenuation.' - % (att_freq, att_db)) + att_freq *= sfreq / 2.0 + warn( + "Attenuation at stop frequency %0.2f Hz is only %0.2f dB. " + "Increase filter_length for higher attenuation." % (att_freq, att_db) + ) return h def _check_zero_phase_length(N, phase, gain_nyq=0): N = int(N) if N % 2 == 0: - if phase == 'zero': - raise RuntimeError('filter_length must be odd if phase="zero", ' - 'got %s' % N) - elif phase == 'zero-double' and gain_nyq == 1: + if phase == "zero": + raise RuntimeError( + 'filter_length must be odd if phase="zero", ' "got %s" % N + ) + elif phase == "zero-double" and gain_nyq == 1: N += 1 return N @@ -410,39 +611,43 @@ def _check_coefficients(system): """Check for filter stability.""" if isinstance(system, tuple): from scipy.signal import tf2zpk + z, p, k = tf2zpk(*system) else: # sos from scipy.signal import sos2zpk + z, p, k = sos2zpk(system) if np.any(np.abs(p) > 1.0): - raise RuntimeError('Filter poles outside unit circle, filter will be ' - 'unstable. Consider using different filter ' - 'coefficients.') + raise RuntimeError( + "Filter poles outside unit circle, filter will be " + "unstable. Consider using different filter " + "coefficients." 
+ ) -def _iir_filter(x, iir_params, picks, n_jobs, copy, phase='zero'): +def _iir_filter(x, iir_params, picks, n_jobs, copy, phase="zero"): """Call filtfilt or lfilter.""" # set up array for filtering, reshape to 2D, operate on last axis from scipy.signal import filtfilt, sosfiltfilt, lfilter, sosfilt + x, orig_shape, picks = _prep_for_filtering(x, copy, picks) - if phase in ('zero', 'zero-double'): - padlen = min(iir_params['padlen'], x.shape[-1] - 1) - if 'sos' in iir_params: - fun = partial(sosfiltfilt, sos=iir_params['sos'], padlen=padlen, - axis=-1) - _check_coefficients(iir_params['sos']) + if phase in ("zero", "zero-double"): + padlen = min(iir_params["padlen"], x.shape[-1] - 1) + if "sos" in iir_params: + fun = partial(sosfiltfilt, sos=iir_params["sos"], padlen=padlen, axis=-1) + _check_coefficients(iir_params["sos"]) else: - fun = partial(filtfilt, b=iir_params['b'], a=iir_params['a'], - padlen=padlen, axis=-1) - _check_coefficients((iir_params['b'], iir_params['a'])) + fun = partial( + filtfilt, b=iir_params["b"], a=iir_params["a"], padlen=padlen, axis=-1 + ) + _check_coefficients((iir_params["b"], iir_params["a"])) else: - if 'sos' in iir_params: - fun = partial(sosfilt, sos=iir_params['sos'], axis=-1) - _check_coefficients(iir_params['sos']) + if "sos" in iir_params: + fun = partial(sosfilt, sos=iir_params["sos"], axis=-1) + _check_coefficients(iir_params["sos"]) else: - fun = partial(lfilter, b=iir_params['b'], a=iir_params['a'], - axis=-1) - _check_coefficients((iir_params['b'], iir_params['a'])) + fun = partial(lfilter, b=iir_params["b"], a=iir_params["a"], axis=-1) + _check_coefficients((iir_params["b"], iir_params["a"])) parallel, p_fun, n_jobs = parallel_func(fun, n_jobs) if n_jobs == 1: for p in picks: @@ -472,14 +677,15 @@ def estimate_ringing_samples(system, max_try=100000): The approximate ringing. """ from scipy import signal + if isinstance(system, tuple): # TF - kind = 'ba' + kind = "ba" b, a = system - zi = [0.] * (len(a) - 1) + zi = [0.0] * (len(a) - 1) else: - kind = 'sos' + kind = "sos" sos = system - zi = [[0.] * 2] * len(sos) + zi = [[0.0] * 2] * len(sos) n_per_chunk = 1000 n_chunks_max = int(np.ceil(max_try / float(n_per_chunk))) x = np.zeros(n_per_chunk) @@ -487,7 +693,7 @@ def estimate_ringing_samples(system, max_try=100000): last_good = n_per_chunk thresh_val = 0 for ii in range(n_chunks_max): - if kind == 'ba': + if kind == "ba": h, zi = signal.lfilter(b, a, x, zi=zi) else: h, zi = signal.sosfilt(sos, x, zi=zi) @@ -501,24 +707,32 @@ def estimate_ringing_samples(system, max_try=100000): idx = (ii - 1) * n_per_chunk + last_good break else: - warn('Could not properly estimate ringing for the filter') + warn("Could not properly estimate ringing for the filter") idx = n_per_chunk * n_chunks_max return idx _ftype_dict = { - 'butter': 'Butterworth', - 'cheby1': 'Chebyshev I', - 'cheby2': 'Chebyshev II', - 'ellip': 'Cauer/elliptic', - 'bessel': 'Bessel/Thomson', + "butter": "Butterworth", + "cheby1": "Chebyshev I", + "cheby2": "Chebyshev II", + "ellip": "Cauer/elliptic", + "bessel": "Bessel/Thomson", } @verbose -def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, - btype=None, return_copy=True, *, phase='zero', - verbose=None): +def construct_iir_filter( + iir_params, + f_pass=None, + f_stop=None, + sfreq=None, + btype=None, + return_copy=True, + *, + phase="zero", + verbose=None, +): """Use IIR parameters to get filtering coefficients. 
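The ``_iir_filter`` hunk above is what gives ``phase='forward'`` its causal behavior: a single pass through ``lfilter``/``sosfilt`` instead of the two-pass ``filtfilt`` variants. A minimal sketch of the difference, using an illustrative 4th-order Butterworth design that is not taken from this patch:

    import numpy as np
    from scipy.signal import butter, sosfilt, sosfiltfilt

    sfreq = 1000.0
    t = np.arange(2000) / sfreq
    x = np.sin(2 * np.pi * 10.0 * t)
    sos = butter(4, 40.0 / (sfreq / 2.0), btype="low", output="sos")
    y_causal = sosfilt(sos, x)    # phase='forward': one pass, group delay remains
    y_zero = sosfiltfilt(sos, x)  # phase='zero'/'zero-double': two passes, no delay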
This function works like a wrapper for iirdesign and iirfilter in @@ -636,136 +850,190 @@ def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, :ref:`disc-filtering` and :ref:`tut-filter-resample`. """ # noqa: E501 from scipy.signal import iirfilter, iirdesign, freqz, sosfreqz - known_filters = ('bessel', 'butter', 'butterworth', 'cauer', 'cheby1', - 'cheby2', 'chebyshev1', 'chebyshev2', 'chebyshevi', - 'chebyshevii', 'ellip', 'elliptic') + + known_filters = ( + "bessel", + "butter", + "butterworth", + "cauer", + "cheby1", + "cheby2", + "chebyshev1", + "chebyshev2", + "chebyshevi", + "chebyshevii", + "ellip", + "elliptic", + ) if not isinstance(iir_params, dict): - raise TypeError('iir_params must be a dict, got %s' % type(iir_params)) + raise TypeError("iir_params must be a dict, got %s" % type(iir_params)) # if the filter has been designed, we're good to go Wp = None - if 'sos' in iir_params: - system = iir_params['sos'] - output = 'sos' - elif 'a' in iir_params and 'b' in iir_params: - system = (iir_params['b'], iir_params['a']) - output = 'ba' + if "sos" in iir_params: + system = iir_params["sos"] + output = "sos" + elif "a" in iir_params and "b" in iir_params: + system = (iir_params["b"], iir_params["a"]) + output = "ba" else: - output = iir_params.get('output', 'sos') - _check_option('output', output, ('ba', 'sos')) + output = iir_params.get("output", "sos") + _check_option("output", output, ("ba", "sos")) # ensure we have a valid ftype - if 'ftype' not in iir_params: - raise RuntimeError('ftype must be an entry in iir_params if ''b'' ' - 'and ''a'' are not specified') - ftype = iir_params['ftype'] + if "ftype" not in iir_params: + raise RuntimeError( + "ftype must be an entry in iir_params if " + "b" + " " + "and " + "a" + " are not specified" + ) + ftype = iir_params["ftype"] if ftype not in known_filters: - raise RuntimeError('ftype must be in filter_dict from ' - 'scipy.signal (e.g., butter, cheby1, etc.) not ' - '%s' % ftype) + raise RuntimeError( + "ftype must be in filter_dict from " + "scipy.signal (e.g., butter, cheby1, etc.) 
not " + "%s" % ftype + ) # use order-based design f_pass = np.atleast_1d(f_pass) if f_pass.ndim > 1: - raise ValueError('frequencies must be 1D, got %dD' % f_pass.ndim) - edge_freqs = ', '.join('%0.2f' % (f,) for f in f_pass) + raise ValueError("frequencies must be 1D, got %dD" % f_pass.ndim) + edge_freqs = ", ".join("%0.2f" % (f,) for f in f_pass) Wp = f_pass / (float(sfreq) / 2) # IT will de designed ftype_nice = _ftype_dict.get(ftype, ftype) - _validate_type(phase, str, 'phase') - _check_option('phase', phase, ('zero', 'zero-double', 'forward')) - if phase in ('zero-double', 'zero'): - ptype = 'zero-phase (two-pass forward and reverse) non-causal' + _validate_type(phase, str, "phase") + _check_option("phase", phase, ("zero", "zero-double", "forward")) + if phase in ("zero-double", "zero"): + ptype = "zero-phase (two-pass forward and reverse) non-causal" else: - ptype = 'non-linear phase (one-pass forward) causal' - logger.info('') - logger.info('IIR filter parameters') - logger.info('---------------------') - logger.info(f'{ftype_nice} {btype} {ptype} filter:') + ptype = "non-linear phase (one-pass forward) causal" + logger.info("") + logger.info("IIR filter parameters") + logger.info("---------------------") + logger.info(f"{ftype_nice} {btype} {ptype} filter:") # SciPy designs forward for -3dB, so forward-backward is -6dB - if 'order' in iir_params: - singleton = btype in ('low', 'lowpass', 'high', 'highpass') + if "order" in iir_params: + singleton = btype in ("low", "lowpass", "high", "highpass") use_Wp = Wp.item() if singleton else Wp - kwargs = dict(N=iir_params['order'], Wn=use_Wp, btype=btype, - ftype=ftype, output=output) - for key in ('rp', 'rs'): + kwargs = dict( + N=iir_params["order"], + Wn=use_Wp, + btype=btype, + ftype=ftype, + output=output, + ) + for key in ("rp", "rs"): if key in iir_params: kwargs[key] = iir_params[key] system = iirfilter(**kwargs) - if phase in ('zero', 'zero-double'): - ptype, pmul = '(effective, after forward-backward)', 2 + if phase in ("zero", "zero-double"): + ptype, pmul = "(effective, after forward-backward)", 2 else: - ptype, pmul = '(forward)', 1 - logger.info('- Filter order %d %s' - % (pmul * iir_params['order'] * len(Wp), ptype)) + ptype, pmul = "(forward)", 1 + logger.info( + "- Filter order %d %s" % (pmul * iir_params["order"] * len(Wp), ptype) + ) else: # use gpass / gstop design Ws = np.asanyarray(f_stop) / (float(sfreq) / 2) - if 'gpass' not in iir_params or 'gstop' not in iir_params: - raise ValueError('iir_params must have at least ''gstop'' and' - ' ''gpass'' (or ''N'') entries') - system = iirdesign(Wp, Ws, iir_params['gpass'], - iir_params['gstop'], ftype=ftype, output=output) + if "gpass" not in iir_params or "gstop" not in iir_params: + raise ValueError( + "iir_params must have at least " + "gstop" + " and" + " " + "gpass" + " (or " + "N" + ") entries" + ) + system = iirdesign( + Wp, + Ws, + iir_params["gpass"], + iir_params["gstop"], + ftype=ftype, + output=output, + ) if system is None: - raise RuntimeError('coefficients could not be created from iir_params') + raise RuntimeError("coefficients could not be created from iir_params") # do some sanity checks _check_coefficients(system) # get the gains at the cutoff frequencies if Wp is not None: - if output == 'sos': + if output == "sos": cutoffs = sosfreqz(system, worN=Wp * np.pi)[1] else: cutoffs = freqz(system[0], system[1], worN=Wp * np.pi)[1] cutoffs = 20 * np.log10(np.abs(cutoffs)) # 2 * 20 here because we do forward-backward filtering - if phase in ('zero', 
'zero-double'): + if phase in ("zero", "zero-double"): cutoffs *= 2 - cutoffs = ', '.join(['%0.2f' % (c,) for c in cutoffs]) - logger.info('- Cutoff%s at %s Hz: %s dB' - % (_pl(f_pass), edge_freqs, cutoffs)) + cutoffs = ", ".join(["%0.2f" % (c,) for c in cutoffs]) + logger.info("- Cutoff%s at %s Hz: %s dB" % (_pl(f_pass), edge_freqs, cutoffs)) # now deal with padding - if 'padlen' not in iir_params: + if "padlen" not in iir_params: padlen = estimate_ringing_samples(system) else: - padlen = iir_params['padlen'] + padlen = iir_params["padlen"] if return_copy: iir_params = deepcopy(iir_params) iir_params.update(dict(padlen=padlen)) - if output == 'sos': + if output == "sos": iir_params.update(sos=system) else: iir_params.update(b=system[0], a=system[1]) - logger.info('') + logger.info("") return iir_params def _check_method(method, iir_params, extra_types=()): """Parse method arguments.""" - allowed_types = ['iir', 'fir', 'fft'] + list(extra_types) - _validate_type(method, 'str', 'method') - _check_option('method', method, allowed_types) - if method == 'fft': - method = 'fir' # use the better name - if method == 'iir': + allowed_types = ["iir", "fir", "fft"] + list(extra_types) + _validate_type(method, "str", "method") + _check_option("method", method, allowed_types) + if method == "fft": + method = "fir" # use the better name + if method == "iir": if iir_params is None: iir_params = dict() - if len(iir_params) == 0 or (len(iir_params) == 1 and - 'output' in iir_params): - iir_params = dict(order=4, ftype='butter', - output=iir_params.get('output', 'sos')) + if len(iir_params) == 0 or (len(iir_params) == 1 and "output" in iir_params): + iir_params = dict( + order=4, ftype="butter", output=iir_params.get("output", "sos") + ) elif iir_params is not None: raise ValueError('iir_params must be None if method != "iir"') return iir_params, method @verbose -def filter_data(data, sfreq, l_freq, h_freq, picks=None, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', - n_jobs=None, method='fir', iir_params=None, copy=True, - phase='zero', fir_window='hamming', fir_design='firwin', - pad='reflect_limited', *, verbose=None): +def filter_data( + data, + sfreq, + l_freq, + h_freq, + picks=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + n_jobs=None, + method="fir", + iir_params=None, + copy=True, + phase="zero", + fir_window="hamming", + fir_design="firwin", + pad="reflect_limited", + *, + verbose=None, +): """Filter a subset of channels. 
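The new option reaches users through ``filter_data``, whose reformatted signature follows. A minimal usage sketch, assuming ``iir_params=None`` so the 4th-order Butterworth default from ``_check_method`` applies:

    import numpy as np
    from mne.filter import filter_data

    data = np.random.randn(5, 10000)  # (n_channels, n_times), must be float64
    # One-pass, causal IIR band-pass between 1 and 40 Hz:
    out = filter_data(data, sfreq=1000.0, l_freq=1.0, h_freq=40.0,
                      method="iir", phase="forward")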
Parameters @@ -834,21 +1102,42 @@ def filter_data(data, sfreq, l_freq, h_freq, picks=None, filter_length='auto', data = _check_filterable(data) iir_params, method = _check_method(method, iir_params) filt = create_filter( - data, sfreq, l_freq, h_freq, filter_length, l_trans_bandwidth, - h_trans_bandwidth, method, iir_params, phase, fir_window, fir_design) - if method in ('fir', 'fft'): - data = _overlap_add_filter(data, filt, None, phase, picks, n_jobs, - copy, pad) + data, + sfreq, + l_freq, + h_freq, + filter_length, + l_trans_bandwidth, + h_trans_bandwidth, + method, + iir_params, + phase, + fir_window, + fir_design, + ) + if method in ("fir", "fft"): + data = _overlap_add_filter(data, filt, None, phase, picks, n_jobs, copy, pad) else: data = _iir_filter(data, filt, picks, n_jobs, copy, phase) return data @verbose -def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', - method='fir', iir_params=None, phase='zero', - fir_window='hamming', fir_design='firwin', verbose=None): +def create_filter( + data, + sfreq, + l_freq, + h_freq, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + method="fir", + iir_params=None, + phase="zero", + fir_window="hamming", + fir_design="firwin", + verbose=None, +): r"""Create a FIR or IIR filter. ``l_freq`` and ``h_freq`` are the frequencies below which and above @@ -967,61 +1256,127 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', """ sfreq = float(sfreq) if sfreq < 0: - raise ValueError('sfreq must be positive') + raise ValueError("sfreq must be positive") # If no data specified, sanity checking will be skipped if data is None: - logger.info('No data specified. Sanity checks related to the length of' - ' the signal relative to the filter order will be' - ' skipped.') + logger.info( + "No data specified. Sanity checks related to the length of" + " the signal relative to the filter order will be" + " skipped." + ) if h_freq is not None: h_freq = np.array(h_freq, float).ravel() - if (h_freq > (sfreq / 2.)).any(): - raise ValueError('h_freq (%s) must be less than the Nyquist ' - 'frequency %s' % (h_freq, sfreq / 2.)) + if (h_freq > (sfreq / 2.0)).any(): + raise ValueError( + "h_freq (%s) must be less than the Nyquist " + "frequency %s" % (h_freq, sfreq / 2.0) + ) if l_freq is not None: l_freq = np.array(l_freq, float).ravel() if (l_freq == 0).all(): l_freq = None iir_params, method = _check_method(method, iir_params) if l_freq is None and h_freq is None: - data, sfreq, _, _, _, _, filter_length, phase, fir_window, \ - fir_design = _triage_filter_params( - data, sfreq, None, None, None, None, - filter_length, method, phase, fir_window, fir_design) - if method == 'iir': + ( + data, + sfreq, + _, + _, + _, + _, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + None, + None, + None, + None, + filter_length, + method, + phase, + fir_window, + fir_design, + ) + if method == "iir": out = dict() if iir_params is None else deepcopy(iir_params) - out.update(b=np.array([1.]), a=np.array([1.])) + out.update(b=np.array([1.0]), a=np.array([1.0])) else: - freq = [0, sfreq / 2.] - gain = [1., 1.] 
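For reference, ``create_filter`` returns two different things depending on ``method``; a short sketch (the random data here is only to satisfy the length sanity checks):

    import numpy as np
    from mne.filter import create_filter

    data = np.random.randn(1, 5000)
    h = create_filter(data, sfreq=1000.0, l_freq=1.0, h_freq=40.0,
                      method="fir")   # ndarray of FIR taps
    iir = create_filter(data, sfreq=1000.0, l_freq=1.0, h_freq=40.0,
                        method="iir")  # dict with 'sos' and 'padlen' entries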
+ freq = [0, sfreq / 2.0] + gain = [1.0, 1.0] if l_freq is None and h_freq is not None: h_freq = h_freq.item() - logger.info('Setting up low-pass filter at %0.2g Hz' % (h_freq,)) - data, sfreq, _, f_p, _, f_s, filter_length, phase, fir_window, \ - fir_design = _triage_filter_params( - data, sfreq, None, h_freq, None, h_trans_bandwidth, - filter_length, method, phase, fir_window, fir_design) - if method == 'iir': - out = construct_iir_filter(iir_params, f_p, f_s, sfreq, 'lowpass', - phase=phase) + logger.info("Setting up low-pass filter at %0.2g Hz" % (h_freq,)) + ( + data, + sfreq, + _, + f_p, + _, + f_s, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + None, + h_freq, + None, + h_trans_bandwidth, + filter_length, + method, + phase, + fir_window, + fir_design, + ) + if method == "iir": + out = construct_iir_filter( + iir_params, f_p, f_s, sfreq, "lowpass", phase=phase + ) else: # 'fir' freq = [0, f_p, f_s] gain = [1, 1, 0] - if f_s != sfreq / 2.: - freq += [sfreq / 2.] + if f_s != sfreq / 2.0: + freq += [sfreq / 2.0] gain += [0] elif l_freq is not None and h_freq is None: l_freq = l_freq.item() - logger.info('Setting up high-pass filter at %0.2g Hz' % (l_freq,)) - data, sfreq, pass_, _, stop, _, filter_length, phase, fir_window, \ - fir_design = _triage_filter_params( - data, sfreq, l_freq, None, l_trans_bandwidth, None, - filter_length, method, phase, fir_window, fir_design) - if method == 'iir': - out = construct_iir_filter(iir_params, pass_, stop, sfreq, - 'highpass', phase=phase) + logger.info("Setting up high-pass filter at %0.2g Hz" % (l_freq,)) + ( + data, + sfreq, + pass_, + _, + stop, + _, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + l_freq, + None, + l_trans_bandwidth, + None, + filter_length, + method, + phase, + fir_window, + fir_design, + ) + if method == "iir": + out = construct_iir_filter( + iir_params, pass_, stop, sfreq, "highpass", phase=phase + ) else: # 'fir' - freq = [stop, pass_, sfreq / 2.] + freq = [stop, pass_, sfreq / 2.0] gain = [0, 1, 1] if stop != 0: freq = [0] + freq @@ -1029,22 +1384,47 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', elif l_freq is not None and h_freq is not None: if (l_freq < h_freq).any(): l_freq, h_freq = l_freq.item(), h_freq.item() - logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz' - % (l_freq, h_freq)) - data, sfreq, f_p1, f_p2, f_s1, f_s2, filter_length, phase, \ - fir_window, fir_design = _triage_filter_params( - data, sfreq, l_freq, h_freq, l_trans_bandwidth, - h_trans_bandwidth, filter_length, method, phase, - fir_window, fir_design) - if method == 'iir': - out = construct_iir_filter(iir_params, [f_p1, f_p2], - [f_s1, f_s2], sfreq, 'bandpass', - phase=phase) + logger.info( + "Setting up band-pass filter from %0.2g - %0.2g Hz" % (l_freq, h_freq) + ) + ( + data, + sfreq, + f_p1, + f_p2, + f_s1, + f_s2, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + l_freq, + h_freq, + l_trans_bandwidth, + h_trans_bandwidth, + filter_length, + method, + phase, + fir_window, + fir_design, + ) + if method == "iir": + out = construct_iir_filter( + iir_params, + [f_p1, f_p2], + [f_s1, f_s2], + sfreq, + "bandpass", + phase=phase, + ) else: # 'fir' freq = [f_s1, f_p1, f_p2, f_s2] gain = [0, 1, 1, 0] - if f_s2 != sfreq / 2.: - freq += [sfreq / 2.] 
+ if f_s2 != sfreq / 2.0: + freq += [sfreq / 2.0] gain += [0] if f_s1 != 0: freq = [0] + freq @@ -1053,54 +1433,100 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', # This could possibly be removed after 0.14 release, but might # as well leave it in to sanity check notch_filter if len(l_freq) != len(h_freq): - raise ValueError('l_freq and h_freq must be the same length') - msg = 'Setting up band-stop filter' + raise ValueError("l_freq and h_freq must be the same length") + msg = "Setting up band-stop filter" if len(l_freq) == 1: l_freq, h_freq = l_freq.item(), h_freq.item() - msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq) + msg += " from %0.2g - %0.2g Hz" % (h_freq, l_freq) logger.info(msg) # Note: order of outputs is intentionally switched here! - data, sfreq, f_s1, f_s2, f_p1, f_p2, filter_length, phase, \ - fir_window, fir_design = _triage_filter_params( - data, sfreq, h_freq, l_freq, h_trans_bandwidth, - l_trans_bandwidth, filter_length, method, phase, - fir_window, fir_design, bands='arr', reverse=True) - if method == 'iir': + ( + data, + sfreq, + f_s1, + f_s2, + f_p1, + f_p2, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + h_freq, + l_freq, + h_trans_bandwidth, + l_trans_bandwidth, + filter_length, + method, + phase, + fir_window, + fir_design, + bands="arr", + reverse=True, + ) + if method == "iir": if len(f_p1) != 1: - raise ValueError('Multiple stop-bands can only be used ' - 'with FIR filtering') - out = construct_iir_filter(iir_params, [f_p1[0], f_p2[0]], - [f_s1[0], f_s2[0]], sfreq, - 'bandstop', phase=phase) + raise ValueError( + "Multiple stop-bands can only be used " "with FIR filtering" + ) + out = construct_iir_filter( + iir_params, + [f_p1[0], f_p2[0]], + [f_s1[0], f_s2[0]], + sfreq, + "bandstop", + phase=phase, + ) else: # 'fir' freq = np.r_[f_p1, f_s1, f_s2, f_p2] - gain = np.r_[np.ones_like(f_p1), np.zeros_like(f_s1), - np.zeros_like(f_s2), np.ones_like(f_p2)] + gain = np.r_[ + np.ones_like(f_p1), + np.zeros_like(f_s1), + np.zeros_like(f_s2), + np.ones_like(f_p2), + ] order = np.argsort(freq) freq = freq[order] gain = gain[order] if freq[0] != 0: - freq = np.r_[[0.], freq] - gain = np.r_[[1.], gain] - if freq[-1] != sfreq / 2.: - freq = np.r_[freq, [sfreq / 2.]] - gain = np.r_[gain, [1.]] + freq = np.r_[[0.0], freq] + gain = np.r_[[1.0], gain] + if freq[-1] != sfreq / 2.0: + freq = np.r_[freq, [sfreq / 2.0]] + gain = np.r_[gain, [1.0]] if np.any(np.abs(np.diff(gain, 2)) > 1): - raise ValueError('Stop bands are not sufficiently ' - 'separated.') - if method == 'fir': - out = _construct_fir_filter(sfreq, freq, gain, filter_length, phase, - fir_window, fir_design) + raise ValueError("Stop bands are not sufficiently " "separated.") + if method == "fir": + out = _construct_fir_filter( + sfreq, freq, gain, filter_length, phase, fir_window, fir_design + ) return out @verbose -def notch_filter(x, Fs, freqs, filter_length='auto', notch_widths=None, - trans_bandwidth=1, method='fir', iir_params=None, - mt_bandwidth=None, p_value=0.05, picks=None, n_jobs=None, - copy=True, phase='zero', fir_window='hamming', - fir_design='firwin', pad='reflect_limited', *, - verbose=None): +def notch_filter( + x, + Fs, + freqs, + filter_length="auto", + notch_widths=None, + trans_bandwidth=1, + method="fir", + iir_params=None, + mt_bandwidth=None, + p_value=0.05, + picks=None, + n_jobs=None, + copy=True, + phase="zero", + fir_window="hamming", + fir_design="firwin", + pad="reflect_limited", + *, + verbose=None, +): 
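A minimal usage sketch for the signature above: remove 60 Hz mains noise plus two harmonics through the default FIR path. When ``notch_widths`` is None, each notch spans ``freq / 200`` Hz, per the code below:

    import numpy as np
    from mne.filter import notch_filter

    x = np.random.randn(3, 20000)
    x_clean = notch_filter(x, Fs=1000.0, freqs=np.arange(60.0, 181.0, 60.0))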
r"""Notch filter for the signal x. Applies a zero-phase notch filter to the signal x, operating on the last @@ -1184,42 +1610,65 @@ def notch_filter(x, Fs, freqs, filter_length='auto', notch_widths=None, & Hemant Bokil, Oxford University Press, New York, 2008. Please cite this in publications if method 'spectrum_fit' is used. """ - x = _check_filterable(x, 'notch filtered', 'notch_filter') - iir_params, method = _check_method(method, iir_params, ['spectrum_fit']) + x = _check_filterable(x, "notch filtered", "notch_filter") + iir_params, method = _check_method(method, iir_params, ["spectrum_fit"]) if freqs is not None: freqs = np.atleast_1d(freqs) - elif method != 'spectrum_fit': - raise ValueError('freqs=None can only be used with method ' - 'spectrum_fit') + elif method != "spectrum_fit": + raise ValueError("freqs=None can only be used with method " "spectrum_fit") # Only have to deal with notch_widths for non-autodetect if freqs is not None: if notch_widths is None: notch_widths = freqs / 200.0 elif np.any(notch_widths < 0): - raise ValueError('notch_widths must be >= 0') + raise ValueError("notch_widths must be >= 0") else: notch_widths = np.atleast_1d(notch_widths) if len(notch_widths) == 1: notch_widths = notch_widths[0] * np.ones_like(freqs) elif len(notch_widths) != len(freqs): - raise ValueError('notch_widths must be None, scalar, or the ' - 'same length as freqs') + raise ValueError( + "notch_widths must be None, scalar, or the " "same length as freqs" + ) - if method in ('fir', 'iir'): + if method in ("fir", "iir"): # Speed this up by computing the fourier coefficients once tb_2 = trans_bandwidth / 2.0 - lows = [freq - nw / 2.0 - tb_2 - for freq, nw in zip(freqs, notch_widths)] - highs = [freq + nw / 2.0 + tb_2 - for freq, nw in zip(freqs, notch_widths)] - xf = filter_data(x, Fs, highs, lows, picks, filter_length, tb_2, tb_2, - n_jobs, method, iir_params, copy, phase, fir_window, - fir_design, pad=pad) - elif method == 'spectrum_fit': - xf = _mt_spectrum_proc(x, Fs, freqs, notch_widths, mt_bandwidth, - p_value, picks, n_jobs, copy, filter_length) + lows = [freq - nw / 2.0 - tb_2 for freq, nw in zip(freqs, notch_widths)] + highs = [freq + nw / 2.0 + tb_2 for freq, nw in zip(freqs, notch_widths)] + xf = filter_data( + x, + Fs, + highs, + lows, + picks, + filter_length, + tb_2, + tb_2, + n_jobs, + method, + iir_params, + copy, + phase, + fir_window, + fir_design, + pad=pad, + ) + elif method == "spectrum_fit": + xf = _mt_spectrum_proc( + x, + Fs, + freqs, + notch_widths, + mt_bandwidth, + p_value, + picks, + n_jobs, + copy, + filter_length, + ) return xf @@ -1230,26 +1679,37 @@ def _get_window_thresh(n_times, sfreq, mt_bandwidth, p_value): # figure out what tapers to use window_fun, _, _ = _compute_mt_params( - n_times, sfreq, mt_bandwidth, False, False, verbose=False) + n_times, sfreq, mt_bandwidth, False, False, verbose=False + ) # F-stat of 1-p point threshold = stats.f.ppf(1 - p_value / n_times, 2, 2 * len(window_fun) - 2) return window_fun, threshold -def _mt_spectrum_proc(x, sfreq, line_freqs, notch_widths, mt_bandwidth, - p_value, picks, n_jobs, copy, filter_length): +def _mt_spectrum_proc( + x, + sfreq, + line_freqs, + notch_widths, + mt_bandwidth, + p_value, + picks, + n_jobs, + copy, + filter_length, +): """Call _mt_spectrum_remove.""" # set up array for filtering, reshape to 2D, operate on last axis x, orig_shape, picks = _prep_for_filtering(x, copy, picks) - if isinstance(filter_length, str) and filter_length == 'auto': - filter_length = '10s' + if 
isinstance(filter_length, str) and filter_length == "auto": + filter_length = "10s" if filter_length is None: filter_length = x.shape[-1] - filter_length = min(_to_samples(filter_length, sfreq, '', ''), x.shape[-1]) + filter_length = min(_to_samples(filter_length, sfreq, "", ""), x.shape[-1]) get_wt = partial( - _get_window_thresh, sfreq=sfreq, mt_bandwidth=mt_bandwidth, - p_value=p_value) + _get_window_thresh, sfreq=sfreq, mt_bandwidth=mt_bandwidth, p_value=p_value + ) window_fun, threshold = get_wt(filter_length) parallel, p_fun, n_jobs = parallel_func(_mt_spectrum_remove_win, n_jobs) if n_jobs == 1: @@ -1257,34 +1717,41 @@ def _mt_spectrum_proc(x, sfreq, line_freqs, notch_widths, mt_bandwidth, for ii, x_ in enumerate(x): if ii in picks: x[ii], f = _mt_spectrum_remove_win( - x_, sfreq, line_freqs, notch_widths, window_fun, threshold, - get_wt) + x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_wt + ) freq_list.append(f) else: - data_new = parallel(p_fun(x_, sfreq, line_freqs, notch_widths, - window_fun, threshold, get_wt) - for xi, x_ in enumerate(x) - if xi in picks) + data_new = parallel( + p_fun(x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_wt) + for xi, x_ in enumerate(x) + if xi in picks + ) freq_list = [d[1] for d in data_new] data_new = np.array([d[0] for d in data_new]) x[picks, :] = data_new # report found frequencies, but do some sanitizing first by binning into # 1 Hz bins - counts = Counter(sum((np.unique(np.round(ff)).tolist() - for f in freq_list for ff in f), list())) - kind = 'Detected' if line_freqs is None else 'Removed' - found_freqs = '\n'.join(f' {freq:6.2f} : ' - f'{counts[freq]:4d} window{_pl(counts[freq])}' - for freq in sorted(counts)) or ' None' - logger.info(f'{kind} notch frequencies (Hz):\n{found_freqs}') + counts = Counter( + sum((np.unique(np.round(ff)).tolist() for f in freq_list for ff in f), list()) + ) + kind = "Detected" if line_freqs is None else "Removed" + found_freqs = ( + "\n".join( + f" {freq:6.2f} : " f"{counts[freq]:4d} window{_pl(counts[freq])}" + for freq in sorted(counts) + ) + or " None" + ) + logger.info(f"{kind} notch frequencies (Hz):\n{found_freqs}") x.shape = orig_shape return x -def _mt_spectrum_remove_win(x, sfreq, line_freqs, notch_widths, - window_fun, threshold, get_thresh): +def _mt_spectrum_remove_win( + x, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh +): n_times = x.shape[-1] n_samples = window_fun.shape[1] n_overlap = (n_samples + 1) // 2 @@ -1295,31 +1762,32 @@ def _mt_spectrum_remove_win(x, sfreq, line_freqs, notch_widths, # Define how to process a chunk of data def process(x_): out = _mt_spectrum_remove( - x_, sfreq, line_freqs, notch_widths, window_fun, threshold, - get_thresh) + x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh + ) rm_freqs.append(out[1]) return (out[0],) # must return a tuple # Define how to store a chunk of fully processed data (it's trivial) def store(x_): stop = idx[0] + x_.shape[-1] - x_out[..., idx[0]:stop] += x_ + x_out[..., idx[0] : stop] += x_ idx[0] = stop - _COLA(process, store, n_times, n_samples, n_overlap, sfreq, - verbose=False).feed(x) + _COLA(process, store, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x) assert idx[0] == n_times return x_out, rm_freqs -def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths, - window_fun, threshold, get_thresh): +def _mt_spectrum_remove( + x, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh +): """Use MT-spectrum to remove line frequencies. 
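The multitaper path can also run without a frequency list. A sketch, assuming enough data for the 10 s windows that ``filter_length='auto'`` maps to below:

    import numpy as np
    from mne.filter import notch_filter

    x = np.random.randn(2, 60000)
    # freqs=None is only valid for method='spectrum_fit': sinusoids whose
    # multitaper F-statistic exceeds the threshold are fit and subtracted,
    # and the detected frequencies are logged per 1 Hz bin.
    x_clean = notch_filter(x, Fs=1000.0, freqs=None, method="spectrum_fit")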
Based on Chronux. If line_freqs is specified, all freqs within notch_width
of each line_freq are set to zero.
"""
from .time_frequency.multitaper import _mt_spectra
+
assert x.ndim == 1
if x.shape[-1] != window_fun.shape[-1]:
window_fun, threshold = get_thresh(x.shape[-1])
@@ -1342,8 +1810,7 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths,
x_p, freqs = _mt_spectra(x[np.newaxis, :], window_fun, sfreq)

# sum of the product of x_p and H0 across tapers (1, n_freqs)
- x_p_H0 = np.sum(x_p[:, tapers_odd, :] *
- H0[np.newaxis, :, np.newaxis], axis=1)
+ x_p_H0 = np.sum(x_p[:, tapers_odd, :] * H0[np.newaxis, :, np.newaxis], axis=1)

# resulting calculated amplitudes for all freqs
A = x_p_H0 / H0_sq
@@ -1357,8 +1824,9 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths,
# numerator for F-statistic
num = (n_tapers - 1) * (A * A.conj()).real * H0_sq
# denominator for F-statistic
- den = (np.sum(np.abs(x_p[:, tapers_odd, :] - x_hat) ** 2, 1) +
- np.sum(np.abs(x_p[:, tapers_even, :]) ** 2, 1))
+ den = np.sum(np.abs(x_p[:, tapers_odd, :] - x_hat) ** 2, 1) + np.sum(
+ np.abs(x_p[:, tapers_even, :]) ** 2, 1
+ )
den[den == 0] = np.inf
f_stat = num / den
@@ -1367,10 +1835,11 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths,
rm_freqs = freqs[indices]
else:
# specify frequencies
- indices_1 = np.unique([np.argmin(np.abs(freqs - lf))
- for lf in line_freqs])
- indices_2 = [np.logical_and(freqs > lf - nw / 2., freqs < lf + nw / 2.)
- for lf, nw in zip(line_freqs, notch_widths)]
+ indices_1 = np.unique([np.argmin(np.abs(freqs - lf)) for lf in line_freqs])
+ indices_2 = [
+ np.logical_and(freqs > lf - nw / 2.0, freqs < lf + nw / 2.0)
+ for lf, nw in zip(line_freqs, notch_widths)
+ ]
indices_2 = np.where(np.any(np.array(indices_2), axis=0))[0]
indices = np.unique(np.r_[indices_1, indices_2])
rm_freqs = freqs[indices]
@@ -1390,7 +1859,7 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths,
return x - datafit, rm_freqs


-def _check_filterable(x, kind='filtered', alternative='filter'):
+def _check_filterable(x, kind="filtered", alternative="filter"):
# Let's be fairly strict about this -- users can easily coerce to ndarray
# at their end, and we already should do it internally any time we are
# using these low-level functions. At the same time, let's
@@ -1399,6 +1868,7 @@ def _check_filterable(x, kind='filtered', alternative='filter'):
from .io.base import BaseRaw
from .epochs import BaseEpochs
from .evoked import Evoked
+
if isinstance(x, (BaseRaw, BaseEpochs, Evoked)):
try:
name = x.__class__.__name__
except Exception:
pass
else:
raise TypeError(
- 'This low-level function only operates on np.ndarray '
- f'instances. To get a {kind} {name} instance, use a method '
- f'like `inst_new = inst.copy().{alternative}(...)` '
- 'instead.')
- _validate_type(x, (np.ndarray, list, tuple), f'Data to be {kind}')
+ "This low-level function only operates on np.ndarray "
+ f"instances. To get a {kind} {name} instance, use a method "
+ f"like `inst_new = inst.copy().{alternative}(...)` "
+ "instead."
+ ) + _validate_type(x, (np.ndarray, list, tuple), f"Data to be {kind}") x = np.asanyarray(x) if x.dtype != np.float64: - raise ValueError('Data to be %s must be real floating, got %s' - % (kind, x.dtype,)) + raise ValueError( + "Data to be %s must be real floating, got %s" + % ( + kind, + x.dtype, + ) + ) return x @@ -1424,8 +1900,18 @@ def _resamp_ratio_len(up, down, n): @verbose -def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', - n_jobs=None, pad='reflect_limited', *, verbose=None): +def resample( + x, + up=1.0, + down=1.0, + npad=100, + axis=-1, + window="boxcar", + n_jobs=None, + pad="reflect_limited", + *, + verbose=None, +): """Resample an array. Operates along the last dimension of the array. @@ -1469,16 +1955,19 @@ def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', """ from scipy.signal import get_window from scipy.fft import ifftshift, fftfreq + # check explicitly for backwards compatibility if not isinstance(axis, int): - err = ("The axis parameter needs to be an integer (got %s). " - "The axis parameter was missing from this function for a " - "period of time, you might be intending to specify the " - "subsequent window parameter." % repr(axis)) + err = ( + "The axis parameter needs to be an integer (got %s). " + "The axis parameter was missing from this function for a " + "period of time, you might be intending to specify the " + "subsequent window parameter." % repr(axis) + ) raise TypeError(err) # make sure our arithmetic will work - x = _check_filterable(x, 'resampled', 'resample') + x = _check_filterable(x, "resampled", "resample") ratio, final_len = _resamp_ratio_len(up, down, x.shape[axis]) del up, down if axis < 0: @@ -1489,11 +1978,11 @@ def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', orig_shape = x.shape x_len = orig_shape[-1] if x_len == 0: - warn('x has zero length along last axis, returning a copy of x') + warn("x has zero length along last axis, returning a copy of x") return x.copy() bad_msg = 'npad must be "auto" or an integer' if isinstance(npad, str): - if npad != 'auto': + if npad != "auto": raise ValueError(bad_msg) # Figure out reasonable pad that gets us to a power of 2 min_add = min(x_len // 8, 100) * 2 @@ -1520,14 +2009,13 @@ def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', if window is not None: if callable(window): W = window(fftfreq(orig_len)) - elif isinstance(window, np.ndarray) and \ - window.shape == (orig_len,): + elif isinstance(window, np.ndarray) and window.shape == (orig_len,): W = window else: W = ifftshift(get_window(window, orig_len)) else: W = np.ones(orig_len) - W *= (float(new_len) / float(orig_len)) + W *= float(new_len) / float(orig_len) # figure out if we should use CUDA n_jobs, cuda_dict = _setup_cuda_fft_resample(n_jobs, W, new_len) @@ -1538,11 +2026,11 @@ def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', if n_jobs == 1: y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x.dtype) for xi, x_ in enumerate(x_flat): - y[xi] = _fft_resample(x_, new_len, npads, to_removes, - cuda_dict, pad) + y[xi] = _fft_resample(x_, new_len, npads, to_removes, cuda_dict, pad) else: - y = parallel(p_fun(x_, new_len, npads, to_removes, cuda_dict, pad) - for x_ in x_flat) + y = parallel( + p_fun(x_, new_len, npads, to_removes, cuda_dict, pad) for x_ in x_flat + ) y = np.array(y) # Restore the original array shape (modified for resampling) @@ -1588,8 +2076,7 @@ def _resample_stim_channels(stim_data, up, down): # out-of-bounds, which can happen (having 
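A usage sketch for ``resample`` as reformatted above; ``up`` and ``down`` act only through their ratio:

    import numpy as np
    from mne.filter import resample

    x = np.random.randn(4, 1000)          # 1000 samples at, say, 1000 Hz
    y = resample(x, up=250.0, down=1000.0, npad="auto")
    print(y.shape)                        # (4, 250), i.e. a 250 Hz equivalent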
one sample more than # expected) due to padding sample_picks = np.minimum( - (np.arange(resampled_n_samples) / ratio).astype(int), - n_samples - 1 + (np.arange(resampled_n_samples) / ratio).astype(int), n_samples - 1 ) # Create windows starting from sample_picks[i], ending at sample_picks[i+1] @@ -1598,7 +2085,7 @@ def _resample_stim_channels(stim_data, up, down): # Use the first non-zero value in each window for window_i, window in enumerate(windows): for stim_num, stim in enumerate(stim_data): - nonzero = stim[window[0]:window[1]].nonzero()[0] + nonzero = stim[window[0] : window[1]].nonzero()[0] if len(nonzero) > 0: val = stim[window[0] + nonzero[0]] else: @@ -1637,14 +2124,15 @@ def detrend(x, order=1, axis=-1): True """ from scipy.signal import detrend + if axis > len(x.shape): - raise ValueError('x does not have %d axes' % axis) + raise ValueError("x does not have %d axes" % axis) if order == 0: - fit = 'constant' + fit = "constant" elif order == 1: - fit = 'linear' + fit = "linear" else: - raise ValueError('order must be 0 or 1') + raise ValueError("order must be 0 or 1") y = detrend(x, axis=axis, type=fit) @@ -1659,31 +2147,33 @@ def detrend(x, order=1, axis=-1): # (Hamming) then δs = 10 ** (53 / -20.), which means that the passband # deviation should be 20 * np.log10(1 + 10 ** (53 / -20.)) == 0.0194. _fir_window_dict = { - 'hann': dict(name='Hann', ripple=0.0546, attenuation=44), - 'hamming': dict(name='Hamming', ripple=0.0194, attenuation=53), - 'blackman': dict(name='Blackman', ripple=0.0017, attenuation=74), + "hann": dict(name="Hann", ripple=0.0546, attenuation=44), + "hamming": dict(name="Hamming", ripple=0.0194, attenuation=53), + "blackman": dict(name="Blackman", ripple=0.0017, attenuation=74), } _known_fir_windows = tuple(sorted(_fir_window_dict.keys())) -_known_phases_fir = ('linear', 'zero', 'zero-double', 'minimum') -_known_phases_iir = ('zero', 'zero-double', 'forward') -_known_fir_designs = ('firwin', 'firwin2') +_known_phases_fir = ("linear", "zero", "zero-double", "minimum") +_known_phases_iir = ("zero", "zero-double", "forward") +_known_fir_designs = ("firwin", "firwin2") _fir_design_dict = { - 'firwin': 'Windowed time-domain', - 'firwin2': 'Windowed frequency-domain', + "firwin": "Windowed time-domain", + "firwin2": "Windowed frequency-domain", } def _to_samples(filter_length, sfreq, phase, fir_design): - _validate_type(filter_length, (str, 'int-like'), 'filter_length') + _validate_type(filter_length, (str, "int-like"), "filter_length") if isinstance(filter_length, str): filter_length = filter_length.lower() - err_msg = ('filter_length, if a string, must be a ' - 'human-readable time, e.g. "10s", or "auto", not ' - '"%s"' % filter_length) - if filter_length.lower().endswith('ms'): + err_msg = ( + "filter_length, if a string, must be a " + 'human-readable time, e.g. 
"10s", or "auto", not ' + '"%s"' % filter_length + ) + if filter_length.lower().endswith("ms"): mult_fact = 1e-3 filter_length = filter_length[:-2] - elif filter_length[-1].lower() == 's': + elif filter_length[-1].lower() == "s": mult_fact = 1 filter_length = filter_length[:-1] else: @@ -1693,54 +2183,62 @@ def _to_samples(filter_length, sfreq, phase, fir_design): filter_length = float(filter_length) except ValueError: raise ValueError(err_msg) - filter_length = max(int(np.ceil(filter_length * mult_fact * - sfreq)), 1) - if fir_design == 'firwin': + filter_length = max(int(np.ceil(filter_length * mult_fact * sfreq)), 1) + if fir_design == "firwin": filter_length += (filter_length - 1) % 2 - filter_length = _ensure_int(filter_length, 'filter_length') + filter_length = _ensure_int(filter_length, "filter_length") return filter_length -def _triage_filter_params(x, sfreq, l_freq, h_freq, - l_trans_bandwidth, h_trans_bandwidth, - filter_length, method, phase, fir_window, - fir_design, bands='scalar', reverse=False): +def _triage_filter_params( + x, + sfreq, + l_freq, + h_freq, + l_trans_bandwidth, + h_trans_bandwidth, + filter_length, + method, + phase, + fir_window, + fir_design, + bands="scalar", + reverse=False, +): """Validate and automate filter parameter selection.""" - _validate_type(phase, 'str', 'phase') - if method == 'fir': - _check_option('phase', phase, _known_phases_fir, - extra='when FIR filtering') + _validate_type(phase, "str", "phase") + if method == "fir": + _check_option("phase", phase, _known_phases_fir, extra="when FIR filtering") else: - _check_option('phase', phase, _known_phases_iir, - extra='when IIR filtering') - _validate_type(fir_window, 'str', 'fir_window') - _check_option('fir_window', fir_window, _known_fir_windows) - _validate_type(fir_design, 'str', 'fir_design') - _check_option('fir_design', fir_design, _known_fir_designs) + _check_option("phase", phase, _known_phases_iir, extra="when IIR filtering") + _validate_type(fir_window, "str", "fir_window") + _check_option("fir_window", fir_window, _known_fir_windows) + _validate_type(fir_design, "str", "fir_design") + _check_option("fir_design", fir_design, _known_fir_designs) # Helpers for reporting - report_phase = 'non-linear phase' if phase == 'minimum' else 'zero-phase' - causality = 'causal' if phase == 'minimum' else 'non-causal' - if phase == 'zero-double': - report_pass = 'two-pass forward and reverse' + report_phase = "non-linear phase" if phase == "minimum" else "zero-phase" + causality = "causal" if phase == "minimum" else "non-causal" + if phase == "zero-double": + report_pass = "two-pass forward and reverse" else: - report_pass = 'one-pass' + report_pass = "one-pass" if l_freq is not None: if h_freq is not None: - kind = 'bandstop' if reverse else 'bandpass' + kind = "bandstop" if reverse else "bandpass" else: - kind = 'highpass' + kind = "highpass" assert not reverse elif h_freq is not None: - kind = 'lowpass' + kind = "lowpass" assert not reverse else: - kind = 'allpass' + kind = "allpass" def float_array(c): return np.array(c, float).ravel() - if bands == 'arr': + if bands == "arr": cast = float_array else: cast = float @@ -1748,164 +2246,193 @@ def float_array(c): if l_freq is not None: l_freq = cast(l_freq) if np.any(l_freq <= 0): - raise ValueError('highpass frequency %s must be greater than zero' - % (l_freq,)) + raise ValueError( + "highpass frequency %s must be greater than zero" % (l_freq,) + ) if h_freq is not None: h_freq = cast(h_freq) - if np.any(h_freq >= sfreq / 2.): - raise 
ValueError('lowpass frequency %s must be less than Nyquist ' - '(%s)' % (h_freq, sfreq / 2.)) + if np.any(h_freq >= sfreq / 2.0): + raise ValueError( + "lowpass frequency %s must be less than Nyquist " + "(%s)" % (h_freq, sfreq / 2.0) + ) dB_cutoff = False # meaning, don't try to compute or report - if bands == 'scalar' or (len(h_freq) == 1 and len(l_freq) == 1): - if phase == 'zero': - dB_cutoff = '-6 dB' - elif phase == 'zero-double': - dB_cutoff = '-12 dB' + if bands == "scalar" or (len(h_freq) == 1 and len(l_freq) == 1): + if phase == "zero": + dB_cutoff = "-6 dB" + elif phase == "zero-double": + dB_cutoff = "-12 dB" # we go to the next power of two when in FIR and zero-double mode - if method == 'iir': + if method == "iir": # Ignore these parameters, effectively l_stop, h_stop = l_freq, h_freq else: # method == 'fir' l_stop = h_stop = None - logger.info('') - logger.info('FIR filter parameters') - logger.info('---------------------') - logger.info('Designing a %s, %s, %s %s filter:' - % (report_pass, report_phase, causality, kind)) - logger.info('- %s design (%s) method' - % (_fir_design_dict[fir_design], fir_design)) + logger.info("") + logger.info("FIR filter parameters") + logger.info("---------------------") + logger.info( + "Designing a %s, %s, %s %s filter:" + % (report_pass, report_phase, causality, kind) + ) + logger.info( + "- %s design (%s) method" % (_fir_design_dict[fir_design], fir_design) + ) this_dict = _fir_window_dict[fir_window] - if fir_design == 'firwin': - logger.info('- {name:s} window with {ripple:0.4f} passband ripple ' - 'and {attenuation:d} dB stopband attenuation' - .format(**this_dict)) + if fir_design == "firwin": + logger.info( + "- {name:s} window with {ripple:0.4f} passband ripple " + "and {attenuation:d} dB stopband attenuation".format(**this_dict) + ) else: - logger.info('- {name:s} window'.format(**this_dict)) + logger.info("- {name:s} window".format(**this_dict)) if l_freq is not None: # high-pass component if isinstance(l_trans_bandwidth, str): - if l_trans_bandwidth != 'auto': - raise ValueError('l_trans_bandwidth must be "auto" if ' - 'string, got "%s"' % l_trans_bandwidth) - l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.), - l_freq) + if l_trans_bandwidth != "auto": + raise ValueError( + 'l_trans_bandwidth must be "auto" if ' + 'string, got "%s"' % l_trans_bandwidth + ) + l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.0), l_freq) l_trans_rep = np.array(l_trans_bandwidth, float) if l_trans_rep.size == 1: - l_trans_rep = f'{l_trans_rep.item():0.2f}' - with np.printoptions(precision=2, floatmode='fixed'): - msg = f'- Lower transition bandwidth: {l_trans_rep} Hz' + l_trans_rep = f"{l_trans_rep.item():0.2f}" + with np.printoptions(precision=2, floatmode="fixed"): + msg = f"- Lower transition bandwidth: {l_trans_rep} Hz" if dB_cutoff: l_freq_rep = np.array(l_freq, float) if l_freq_rep.size == 1: - l_freq_rep = f'{l_freq_rep.item():0.2f}' - cutoff_rep = np.array( - l_freq - l_trans_bandwidth / 2., float) + l_freq_rep = f"{l_freq_rep.item():0.2f}" + cutoff_rep = np.array(l_freq - l_trans_bandwidth / 2.0, float) if cutoff_rep.size == 1: - cutoff_rep = f'{cutoff_rep.item():0.2f}' + cutoff_rep = f"{cutoff_rep.item():0.2f}" # Could be an array - logger.info(f'- Lower passband edge: {l_freq_rep}') - msg += f' ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)' + logger.info(f"- Lower passband edge: {l_freq_rep}") + msg += f" ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)" logger.info(msg) l_trans_bandwidth = cast(l_trans_bandwidth) if 
np.any(l_trans_bandwidth <= 0): - raise ValueError('l_trans_bandwidth must be positive, got %s' - % (l_trans_bandwidth,)) + raise ValueError( + "l_trans_bandwidth must be positive, got %s" % (l_trans_bandwidth,) + ) l_stop = l_freq - l_trans_bandwidth if reverse: # band-stop style l_stop += l_trans_bandwidth l_freq += l_trans_bandwidth if np.any(l_stop < 0): - raise ValueError('Filter specification invalid: Lower stop ' - 'frequency negative (%0.2f Hz). Increase pass' - ' frequency or reduce the transition ' - 'bandwidth (l_trans_bandwidth)' % l_stop) + raise ValueError( + "Filter specification invalid: Lower stop " + "frequency negative (%0.2f Hz). Increase pass" + " frequency or reduce the transition " + "bandwidth (l_trans_bandwidth)" % l_stop + ) if h_freq is not None: # low-pass component if isinstance(h_trans_bandwidth, str): - if h_trans_bandwidth != 'auto': - raise ValueError('h_trans_bandwidth must be "auto" if ' - 'string, got "%s"' % h_trans_bandwidth) - h_trans_bandwidth = np.minimum(np.maximum(0.25 * h_freq, 2.), - sfreq / 2. - h_freq) + if h_trans_bandwidth != "auto": + raise ValueError( + 'h_trans_bandwidth must be "auto" if ' + 'string, got "%s"' % h_trans_bandwidth + ) + h_trans_bandwidth = np.minimum( + np.maximum(0.25 * h_freq, 2.0), sfreq / 2.0 - h_freq + ) h_trans_rep = np.array(h_trans_bandwidth, float) if h_trans_rep.size == 1: - h_trans_rep = f'{h_trans_rep.item():0.2f}' - with np.printoptions(precision=2, floatmode='fixed'): - msg = f'- Upper transition bandwidth: {h_trans_rep} Hz' + h_trans_rep = f"{h_trans_rep.item():0.2f}" + with np.printoptions(precision=2, floatmode="fixed"): + msg = f"- Upper transition bandwidth: {h_trans_rep} Hz" if dB_cutoff: h_freq_rep = np.array(h_freq, float) if h_freq_rep.size == 1: - h_freq_rep = f'{h_freq_rep.item():0.2f}' - cutoff_rep = np.array( - h_freq + h_trans_bandwidth / 2., float) + h_freq_rep = f"{h_freq_rep.item():0.2f}" + cutoff_rep = np.array(h_freq + h_trans_bandwidth / 2.0, float) if cutoff_rep.size == 1: - cutoff_rep = f'{cutoff_rep.item():0.2f}' - logger.info(f'- Upper passband edge: {h_freq_rep} Hz') - msg += f' ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)' + cutoff_rep = f"{cutoff_rep.item():0.2f}" + logger.info(f"- Upper passband edge: {h_freq_rep} Hz") + msg += f" ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)" logger.info(msg) h_trans_bandwidth = cast(h_trans_bandwidth) if np.any(h_trans_bandwidth <= 0): - raise ValueError('h_trans_bandwidth must be positive, got %s' - % (h_trans_bandwidth,)) + raise ValueError( + "h_trans_bandwidth must be positive, got %s" % (h_trans_bandwidth,) + ) h_stop = h_freq + h_trans_bandwidth if reverse: # band-stop style h_stop -= h_trans_bandwidth h_freq -= h_trans_bandwidth if np.any(h_stop > sfreq / 2): - raise ValueError('Effective band-stop frequency (%s) is too ' - 'high (maximum based on Nyquist is %s)' - % (h_stop, sfreq / 2.)) + raise ValueError( + "Effective band-stop frequency (%s) is too " + "high (maximum based on Nyquist is %s)" % (h_stop, sfreq / 2.0) + ) - if isinstance(filter_length, str) and filter_length.lower() == 'auto': + if isinstance(filter_length, str) and filter_length.lower() == "auto": filter_length = filter_length.lower() h_check = l_check = np.inf if h_freq is not None: h_check = min(np.atleast_1d(h_trans_bandwidth)) if l_freq is not None: l_check = min(np.atleast_1d(l_trans_bandwidth)) - mult_fact = 2. if fir_design == 'firwin2' else 1. 
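A worked example of the 'auto' defaults this function computes, assuming the Hamming entry of ``_length_factors`` (defined elsewhere in this file) is 3.3:

    import numpy as np

    sfreq, l_freq, h_freq = 1000.0, 1.0, 40.0
    l_trans = min(max(0.25 * l_freq, 2.0), l_freq)                # 1.0 Hz
    h_trans = min(max(0.25 * h_freq, 2.0), sfreq / 2.0 - h_freq)  # 10.0 Hz
    # 'auto' FIR length for firwin/hamming: 3.3 / (narrowest transition) seconds
    n_taps = int(np.ceil(3.3 / min(l_trans, h_trans) * sfreq))
    n_taps += (n_taps - 1) % 2  # firwin and phase='zero' need an odd length
    print(n_taps)               # 3301 samples (3.301 s)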
- filter_length = '%ss' % (_length_factors[fir_window] * mult_fact / - float(min(h_check, l_check)),) + mult_fact = 2.0 if fir_design == "firwin2" else 1.0 + filter_length = "%ss" % ( + _length_factors[fir_window] * mult_fact / float(min(h_check, l_check)), + ) next_pow_2 = False # disable old behavior else: - next_pow_2 = ( - isinstance(filter_length, str) and phase == 'zero-double') + next_pow_2 = isinstance(filter_length, str) and phase == "zero-double" filter_length = _to_samples(filter_length, sfreq, phase, fir_design) # use correct type of filter (must be odd length for firwin and for # zero phase) - if fir_design == 'firwin' or phase == 'zero': + if fir_design == "firwin" or phase == "zero": filter_length += (filter_length - 1) % 2 - logger.info('- Filter length: %s samples (%0.3f s)' - % (filter_length, filter_length / sfreq)) - logger.info('') + logger.info( + "- Filter length: %s samples (%0.3f s)" + % (filter_length, filter_length / sfreq) + ) + logger.info("") if filter_length <= 0: - raise ValueError('filter_length must be positive, got %s' - % (filter_length,)) + raise ValueError( + "filter_length must be positive, got %s" % (filter_length,) + ) if next_pow_2: filter_length = 2 ** int(np.ceil(np.log2(filter_length))) - if fir_design == 'firwin': + if fir_design == "firwin": filter_length += (filter_length - 1) % 2 # If we have data supplied, do a sanity check if x is not None: x = _check_filterable(x) len_x = x.shape[-1] - if method != 'fir': + if method != "fir": filter_length = len_x if filter_length > len_x and not (l_freq is None and h_freq is None): - warn('filter_length (%s) is longer than the signal (%s), ' - 'distortion is likely. Reduce filter length or filter a ' - 'longer signal.' % (filter_length, len_x)) - - logger.debug('Using filter length: %s' % filter_length) - return (x, sfreq, l_freq, h_freq, l_stop, h_stop, filter_length, phase, - fir_window, fir_design) + warn( + "filter_length (%s) is longer than the signal (%s), " + "distortion is likely. Reduce filter length or filter a " + "longer signal." 
% (filter_length, len_x) + ) + + logger.debug("Using filter length: %s" % filter_length) + return ( + x, + sfreq, + l_freq, + h_freq, + l_stop, + h_stop, + filter_length, + phase, + fir_window, + fir_design, + ) class FilterMixin: @@ -1957,26 +2484,40 @@ def savgol_filter(self, h_freq, verbose=None): >>> evoked.plot() # doctest:+SKIP """ # noqa: E501 from scipy.signal import savgol_filter - _check_preload(self, 'inst.savgol_filter') + + _check_preload(self, "inst.savgol_filter") h_freq = float(h_freq) - if h_freq >= self.info['sfreq'] / 2.: - raise ValueError('h_freq must be less than half the sample rate') + if h_freq >= self.info["sfreq"] / 2.0: + raise ValueError("h_freq must be less than half the sample rate") # savitzky-golay filtering - window_length = (int(np.round(self.info['sfreq'] / - h_freq)) // 2) * 2 + 1 - logger.info('Using savgol length %d' % window_length) - self._data[:] = savgol_filter(self._data, axis=-1, polyorder=5, - window_length=window_length) + window_length = (int(np.round(self.info["sfreq"] / h_freq)) // 2) * 2 + 1 + logger.info("Using savgol length %d" % window_length) + self._data[:] = savgol_filter( + self._data, axis=-1, polyorder=5, window_length=window_length + ) return self @verbose - def filter(self, l_freq, h_freq, picks=None, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', n_jobs=None, - method='fir', iir_params=None, phase='zero', - fir_window='hamming', fir_design='firwin', - skip_by_annotation=('edge', 'bad_acq_skip'), pad='edge', *, - verbose=None): + def filter( + self, + l_freq, + h_freq, + picks=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + n_jobs=None, + method="fir", + iir_params=None, + phase="zero", + fir_window="hamming", + fir_design="firwin", + skip_by_annotation=("edge", "bad_acq_skip"), + pad="edge", + *, + verbose=None, + ): """Filter a subset of channels. Parameters @@ -2045,38 +2586,62 @@ def filter(self, l_freq, h_freq, picks=None, filter_length='auto', .. 
versionadded:: 0.15 """ from .io.base import BaseRaw - _check_preload(self, 'inst.filter') - if pad is None and method != 'iir': - pad = 'edge' - update_info, picks = _filt_check_picks(self.info, picks, - l_freq, h_freq) + + _check_preload(self, "inst.filter") + if pad is None and method != "iir": + pad = "edge" + update_info, picks = _filt_check_picks(self.info, picks, l_freq, h_freq) if isinstance(self, BaseRaw): # Deal with annotations onsets, ends = _annotations_starts_stops( - self, skip_by_annotation, invert=True) - logger.info('Filtering raw data in %d contiguous segment%s' - % (len(onsets), _pl(onsets))) + self, skip_by_annotation, invert=True + ) + logger.info( + "Filtering raw data in %d contiguous segment%s" + % (len(onsets), _pl(onsets)) + ) else: onsets, ends = np.array([0]), np.array([self._data.shape[1]]) max_idx = (ends - onsets).argmax() for si, (start, stop) in enumerate(zip(onsets, ends)): # Only output filter params once (for info level), and only warn # once about the length criterion (longest segment is too short) - use_verbose = verbose if si == max_idx else 'error' + use_verbose = verbose if si == max_idx else "error" filter_data( - self._data[:, start:stop], self.info['sfreq'], l_freq, h_freq, - picks, filter_length, l_trans_bandwidth, h_trans_bandwidth, - n_jobs, method, iir_params, copy=False, phase=phase, - fir_window=fir_window, fir_design=fir_design, pad=pad, - verbose=use_verbose) + self._data[:, start:stop], + self.info["sfreq"], + l_freq, + h_freq, + picks, + filter_length, + l_trans_bandwidth, + h_trans_bandwidth, + n_jobs, + method, + iir_params, + copy=False, + phase=phase, + fir_window=fir_window, + fir_design=fir_design, + pad=pad, + verbose=use_verbose, + ) # update info if filter is applied to all data channels, # and it's not a band-stop filter _filt_update_info(self.info, update_info, l_freq, h_freq) return self @verbose - def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=None, - pad='edge', *, verbose=None): + def resample( + self, + sfreq, + npad="auto", + window="boxcar", + n_jobs=None, + pad="edge", + *, + verbose=None, + ): """Resample data. If appropriate, an anti-aliasing filter is applied before resampling. @@ -2114,23 +2679,26 @@ def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=None, """ from .epochs import BaseEpochs from .evoked import Evoked + # Should be guaranteed by our inheritance, and the fact that # mne.io.base.BaseRaw overrides this method assert isinstance(self, (BaseEpochs, Evoked)) - _check_preload(self, 'inst.resample') + _check_preload(self, "inst.resample") sfreq = float(sfreq) - o_sfreq = self.info['sfreq'] - self._data = resample(self._data, sfreq, o_sfreq, npad, window=window, - n_jobs=n_jobs, pad=pad) - lowpass = self.info.get('lowpass') + o_sfreq = self.info["sfreq"] + self._data = resample( + self._data, sfreq, o_sfreq, npad, window=window, n_jobs=n_jobs, pad=pad + ) + lowpass = self.info.get("lowpass") lowpass = np.inf if lowpass is None else lowpass with self.info._unlock(): - self.info['lowpass'] = min(lowpass, sfreq / 2.) 
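A sketch of the instance-level entry point refactored here; the file name is illustrative, and ``preload=True`` is required by ``_check_preload``:

    import mne

    raw = mne.io.read_raw_fif("sample_raw.fif", preload=True)
    # Causal IIR high-pass; annotated 'edge'/'bad_acq_skip' spans are
    # filtered as separate contiguous segments, per skip_by_annotation.
    raw.filter(l_freq=1.0, h_freq=None, method="iir", phase="forward")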
- self.info['sfreq'] = float(sfreq) - new_times = (np.arange(self._data.shape[-1], dtype=np.float64) / - sfreq + self.times[0]) + self.info["lowpass"] = min(lowpass, sfreq / 2.0) + self.info["sfreq"] = float(sfreq) + new_times = ( + np.arange(self._data.shape[-1], dtype=np.float64) / sfreq + self.times[0] + ) # adjust indirectly affected variables self._set_times(new_times) self._raw_times = self.times @@ -2138,8 +2706,9 @@ def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=None, return self @verbose - def apply_hilbert(self, picks=None, envelope=False, n_jobs=None, - n_fft='auto', *, verbose=None): + def apply_hilbert( + self, picks=None, envelope=False, n_jobs=None, n_fft="auto", *, verbose=None + ): """Compute analytic signal or envelope for a subset of channels. Parameters @@ -2203,18 +2772,22 @@ def apply_hilbert(self, picks=None, envelope=False, n_jobs=None, by computing the analytic signal in sensor space, applying the MNE inverse, and computing the envelope in source space. """ - _check_preload(self, 'inst.apply_hilbert') + _check_preload(self, "inst.apply_hilbert") if n_fft is None: n_fft = len(self.times) elif isinstance(n_fft, str): - if n_fft != 'auto': - raise ValueError('n_fft must be an integer, string, or None, ' - 'got %s' % (type(n_fft),)) + if n_fft != "auto": + raise ValueError( + "n_fft must be an integer, string, or None, " + "got %s" % (type(n_fft),) + ) n_fft = next_fast_len(len(self.times)) n_fft = int(n_fft) if n_fft < len(self.times): - raise ValueError("n_fft (%d) must be at least the number of time " - "points (%d)" % (n_fft, len(self.times))) + raise ValueError( + "n_fft (%d) must be at least the number of time " + "points (%d)" % (n_fft, len(self.times)) + ) dtype = None if envelope else np.complex128 picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False) args, kwargs = (), dict(n_fft=n_fft, envelope=envelope) @@ -2228,12 +2801,13 @@ def apply_hilbert(self, picks=None, envelope=False, n_jobs=None, # modify data inplace to save memory for idx in picks: self._data[..., idx, :] = _check_fun( - _my_hilbert, data_in[..., idx, :], *args, **kwargs) + _my_hilbert, data_in[..., idx, :], *args, **kwargs + ) else: # use parallel function data_picks_new = parallel( - p_fun(_my_hilbert, data_in[..., p, :], *args, **kwargs) - for p in picks) + p_fun(_my_hilbert, data_in[..., p, :], *args, **kwargs) for p in picks + ) for pp, p in enumerate(picks): self._data[..., p, :] = data_picks_new[pp] return self @@ -2244,10 +2818,11 @@ def _check_fun(fun, d, *args, **kwargs): want_shape = d.shape d = fun(d, *args, **kwargs) if not isinstance(d, np.ndarray): - raise TypeError('Return value must be an ndarray') + raise TypeError("Return value must be an ndarray") if d.shape != want_shape: - raise ValueError('Return data must have shape %s not %s' - % (want_shape, d.shape)) + raise ValueError( + "Return data must have shape %s not %s" % (want_shape, d.shape) + ) return d @@ -2271,6 +2846,7 @@ def _my_hilbert(x, n_fft=None, envelope=False): The hilbert transform of the signal, or the envelope. 
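What ``_my_hilbert`` computes per channel, as a standalone sketch using SciPy directly:

    import numpy as np
    from scipy.fft import next_fast_len
    from scipy.signal import hilbert

    x = np.random.randn(1001)
    n_fft = next_fast_len(x.size)             # what n_fft='auto' resolves to (1024 here)
    analytic = hilbert(x, N=n_fft)[: x.size]  # crop the padding back off
    envelope = np.abs(analytic)               # envelope=True keeps only the magnitude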
""" from scipy.signal import hilbert + n_x = x.shape[-1] out = hilbert(x, N=n_fft, axis=-1)[..., :n_x] if envelope: @@ -2279,9 +2855,14 @@ def _my_hilbert(x, n_fft=None, envelope=False): @verbose -def design_mne_c_filter(sfreq, l_freq=None, h_freq=40., - l_trans_bandwidth=None, h_trans_bandwidth=5., - verbose=None): +def design_mne_c_filter( + sfreq, + l_freq=None, + h_freq=40.0, + l_trans_bandwidth=None, + h_trans_bandwidth=5.0, + verbose=None, +): """Create a FIR filter like that used by MNE-C. Parameters @@ -2315,39 +2896,39 @@ def design_mne_c_filter(sfreq, l_freq=None, h_freq=40., and ones in the passband, with squared cosine ramps in between. """ from scipy.fft import irfft + n_freqs = (4096 + 2 * 2048) // 2 + 1 freq_resp = np.ones(n_freqs) l_freq = 0 if l_freq is None else float(l_freq) if l_trans_bandwidth is None: l_width = 3 else: - l_width = (int(((n_freqs - 1) * l_trans_bandwidth) / - (0.5 * sfreq)) + 1) // 2 + l_width = (int(((n_freqs - 1) * l_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 l_start = int(((n_freqs - 1) * l_freq) / (0.5 * sfreq)) - h_freq = sfreq / 2. if h_freq is None else float(h_freq) - h_width = (int(((n_freqs - 1) * h_trans_bandwidth) / - (0.5 * sfreq)) + 1) // 2 + h_freq = sfreq / 2.0 if h_freq is None else float(h_freq) + h_width = (int(((n_freqs - 1) * h_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 h_start = int(((n_freqs - 1) * h_freq) / (0.5 * sfreq)) - logger.info('filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d ' - 'hpw : %d lpw : %d' % (l_freq, h_freq, l_start, h_start, - n_freqs, l_width, h_width)) + logger.info( + "filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d " + "hpw : %d lpw : %d" + % (l_freq, h_freq, l_start, h_start, n_freqs, l_width, h_width) + ) if l_freq > 0: start = l_start - l_width + 1 stop = start + 2 * l_width - 1 if start < 0 or stop >= n_freqs: - raise RuntimeError('l_freq too low or l_trans_bandwidth too large') - freq_resp[:start] = 0. - k = np.arange(-l_width + 1, l_width) / float(l_width) + 3. - freq_resp[start:stop] = np.cos(np.pi / 4. * k) ** 2 + raise RuntimeError("l_freq too low or l_trans_bandwidth too large") + freq_resp[:start] = 0.0 + k = np.arange(-l_width + 1, l_width) / float(l_width) + 3.0 + freq_resp[start:stop] = np.cos(np.pi / 4.0 * k) ** 2 - if h_freq < sfreq / 2.: + if h_freq < sfreq / 2.0: start = h_start - h_width + 1 stop = start + 2 * h_width - 1 if start < 0 or stop >= n_freqs: - raise RuntimeError('h_freq too high or h_trans_bandwidth too ' - 'large') - k = np.arange(-h_width + 1, h_width) / float(h_width) + 1. - freq_resp[start:stop] *= np.cos(np.pi / 4. * k) ** 2 + raise RuntimeError("h_freq too high or h_trans_bandwidth too " "large") + k = np.arange(-h_width + 1, h_width) / float(h_width) + 1.0 + freq_resp[start:stop] *= np.cos(np.pi / 4.0 * k) ** 2 freq_resp[stop:] = 0.0 # Get the time-domain version of this signal h = irfft(freq_resp, n=2 * len(freq_resp) - 1) @@ -2357,32 +2938,44 @@ def design_mne_c_filter(sfreq, l_freq=None, h_freq=40., def _filt_check_picks(info, picks, h_freq, l_freq): from .io.pick import _picks_to_idx + update_info = False # This will pick *all* data channels - picks = _picks_to_idx(info, picks, 'data_or_ica', exclude=()) + picks = _picks_to_idx(info, picks, "data_or_ica", exclude=()) if h_freq is not None or l_freq is not None: - data_picks = _picks_to_idx(info, None, 'data_or_ica', exclude=(), - allow_empty=True) + data_picks = _picks_to_idx( + info, None, "data_or_ica", exclude=(), allow_empty=True + ) if len(data_picks) == 0: - logger.info('No data channels found. 
The highpass and '
-                        'lowpass values in the measurement info will not '
-                        'be updated.')
+            logger.info(
+                "No data channels found. The highpass and "
+                "lowpass values in the measurement info will not "
+                "be updated."
+            )
         elif np.in1d(data_picks, picks).all():
             update_info = True
         else:
-            logger.info('Filtering a subset of channels. The highpass and '
-                        'lowpass values in the measurement info will not '
-                        'be updated.')
+            logger.info(
+                "Filtering a subset of channels. The highpass and "
+                "lowpass values in the measurement info will not "
+                "be updated."
+            )
     return update_info, picks


 def _filt_update_info(info, update_info, l_freq, h_freq):
     if update_info:
-        if h_freq is not None and (l_freq is None or l_freq < h_freq) and \
-                (info["lowpass"] is None or h_freq < info['lowpass']):
+        if (
+            h_freq is not None
+            and (l_freq is None or l_freq < h_freq)
+            and (info["lowpass"] is None or h_freq < info["lowpass"])
+        ):
             with info._unlock():
-                info['lowpass'] = float(h_freq)
-        if l_freq is not None and (h_freq is None or l_freq < h_freq) and \
-                (info["highpass"] is None or l_freq > info['highpass']):
+                info["lowpass"] = float(h_freq)
+        if (
+            l_freq is not None
+            and (h_freq is None or l_freq < h_freq)
+            and (info["highpass"] is None or l_freq > info["highpass"])
+        ):
             with info._unlock():
-                info['highpass'] = float(l_freq)
+                info["highpass"] = float(l_freq)
diff --git a/mne/fixes.py b/mne/fixes.py
index b59439b3b88..c05dfaec344 100644
--- a/mne/fixes.py
+++ b/mne/fixes.py
@@ -30,6 +30,7 @@
 # from the standard library with the release of Python 3.12. For version
 # comparisons, we use `packaging.version.parse`.

+
 def _compare_version(version_a, operator, version_b):
     """Compare two version strings via a user-specified operator.

@@ -49,14 +50,16 @@ def _compare_version(version_a, operator, version_b):
         The result of the version comparison.
     """
     from packaging.version import parse
+
     with warnings.catch_warnings(record=True):
-        warnings.simplefilter('ignore')
+        warnings.simplefilter("ignore")
         return eval(f'parse("{version_a}") {operator} parse("{version_b}")')


 ###############################################################################
 # Misc

+
 def _median_complex(data, axis):
     """Compute marginal median on complex data safely.

@@ -65,8 +68,9 @@ def _median_complex(data, axis):
     """
     # np.median must be passed real arrays for the desired result
     if np.iscomplexobj(data):
-        data = (np.median(np.real(data), axis=axis)
-                + 1j * np.median(np.imag(data), axis=axis))
+        data = np.median(np.real(data), axis=axis) + 1j * np.median(
+            np.imag(data), axis=axis
+        )
     else:
         data = np.median(data, axis=axis)
     return data
@@ -79,19 +83,21 @@ def _safe_svd(A, **kwargs):
     # For SciPy 0.18 and up, we can work around it by using
     # lapack_driver='gesvd' instead.
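    # [Editorial sketch, not part of this patch: the fallback implemented just
    # below can be exercised directly against public SciPy, whose default
    # 'gesdd' LAPACK driver may fail to converge on matrices that the slower
    # 'gesvd' driver still handles:
    #
    #     import numpy as np
    #     from scipy import linalg
    #     A = np.random.default_rng(0).normal(size=(10, 4))
    #     try:
    #         U, s, Vt = linalg.svd(A)  # fast default 'gesdd' driver
    #     except np.linalg.LinAlgError:
    #         U, s, Vt = linalg.svd(A, lapack_driver="gesvd")  # robust fallback
    # ]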
from scipy import linalg
-    if kwargs.get('overwrite_a', False):
-        raise ValueError('Cannot set overwrite_a=True with this function')
+
+    if kwargs.get("overwrite_a", False):
+        raise ValueError("Cannot set overwrite_a=True with this function")
     try:
         return linalg.svd(A, **kwargs)
     except np.linalg.LinAlgError as exp:
         from .utils import warn
-        warn('SVD error (%s), attempting to use GESVD instead of GESDD'
-             % (exp,))
-        return linalg.svd(A, lapack_driver='gesvd', **kwargs)
+
+        warn("SVD error (%s), attempting to use GESVD instead of GESDD" % (exp,))
+        return linalg.svd(A, lapack_driver="gesvd", **kwargs)


 def _csc_matrix_cast(x):
     from scipy.sparse import csc_matrix
+
     return csc_matrix(x)


@@ -102,25 +108,26 @@ def _csc_matrix_cast(x):
 def rng_uniform(rng):
     """Get the uniform/randint from the rng."""
     # prefer Generator.integers, fall back to RandomState.randint
-    return getattr(rng, 'integers', getattr(rng, 'randint', None))
+    return getattr(rng, "integers", getattr(rng, "randint", None))


 def _validate_sos(sos):
     """Helper to validate a SOS input"""
     sos = np.atleast_2d(sos)
     if sos.ndim != 2:
-        raise ValueError('sos array must be 2D')
+        raise ValueError("sos array must be 2D")
     n_sections, m = sos.shape
     if m != 6:
-        raise ValueError('sos array must be shape (n_sections, 6)')
+        raise ValueError("sos array must be shape (n_sections, 6)")
     if not (sos[:, 3] == 1).all():
-        raise ValueError('sos[:, 3] should be all ones')
+        raise ValueError("sos[:, 3] should be all ones")
     return sos, n_sections


 ###############################################################################
 # Misc utilities

+
 # get_fdata() requires knowing the dtype ahead of time, so let's triage on our
 # own instead
 def _get_img_fdata(img):
@@ -134,22 +141,30 @@ def _read_volume_info(fobj):
     versions of nibabel (<=2.1.0) don't have it.
""" volume_info = dict() - head = np.fromfile(fobj, '>i4', 1) + head = np.fromfile(fobj, ">i4", 1) if not np.array_equal(head, [20]): # Read two bytes more - head = np.concatenate([head, np.fromfile(fobj, '>i4', 2)]) + head = np.concatenate([head, np.fromfile(fobj, ">i4", 2)]) if not np.array_equal(head, [2, 0, 20]): warnings.warn("Unknown extension code.") return volume_info - volume_info['head'] = head - for key in ['valid', 'filename', 'volume', 'voxelsize', 'xras', 'yras', - 'zras', 'cras']: - pair = fobj.readline().decode('utf-8').split('=') + volume_info["head"] = head + for key in [ + "valid", + "filename", + "volume", + "voxelsize", + "xras", + "yras", + "zras", + "cras", + ]: + pair = fobj.readline().decode("utf-8").split("=") if pair[0].strip() != key or len(pair) != 2: - raise OSError('Error parsing volume info.') - if key in ('valid', 'filename'): + raise OSError("Error parsing volume info.") + if key in ("valid", "filename"): volume_info[key] = pair[1].strip() - elif key == 'volume': + elif key == "volume": volume_info[key] = np.array(pair[1].split()).astype(int) else: volume_info[key] = np.array(pair[1].split()).astype(float) @@ -194,24 +209,24 @@ def is_regressor(estimator): _DEFAULT_TAGS = { - 'non_deterministic': False, - 'requires_positive_X': False, - 'requires_positive_y': False, - 'X_types': ['2darray'], - 'poor_score': False, - 'no_validation': False, - 'multioutput': False, + "non_deterministic": False, + "requires_positive_X": False, + "requires_positive_y": False, + "X_types": ["2darray"], + "poor_score": False, + "no_validation": False, + "multioutput": False, "allow_nan": False, - 'stateless': False, - 'multilabel': False, - '_skip_test': False, - '_xfail_checks': False, - 'multioutput_only': False, - 'binary_only': False, - 'requires_fit': True, - 'preserves_dtype': [np.float64], - 'requires_y': False, - 'pairwise': False, + "stateless": False, + "multilabel": False, + "_skip_test": False, + "_xfail_checks": False, + "multioutput_only": False, + "binary_only": False, + "requires_fit": True, + "preserves_dtype": [np.float64], + "requires_y": False, + "pairwise": False, } @@ -230,7 +245,7 @@ def _get_param_names(cls): """Get parameter names for the estimator""" # fetch the constructor or the original constructor before # deprecation wrapping if any - init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + init = getattr(cls.__init__, "deprecated_original", cls.__init__) if init is object.__init__: # No explicit constructor to introspect return [] @@ -239,16 +254,20 @@ def _get_param_names(cls): # to represent init_signature = inspect.signature(init) # Consider the constructor parameters excluding 'self' - parameters = [p for p in init_signature.parameters.values() - if p.name != 'self' and p.kind != p.VAR_KEYWORD] + parameters = [ + p + for p in init_signature.parameters.values() + if p.name != "self" and p.kind != p.VAR_KEYWORD + ] for p in parameters: if p.kind == p.VAR_POSITIONAL: - raise RuntimeError("scikit-learn estimators should always " - "specify their parameters in the signature" - " of their __init__ (no varargs)." - " %s with constructor %s doesn't " - " follow this convention." - % (cls, init_signature)) + raise RuntimeError( + "scikit-learn estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." 
% (cls, init_signature) + ) # Extract and sort argument names excluding 'self' return sorted([p.name for p in parameters]) @@ -283,9 +302,9 @@ def get_params(self, deep=True): warnings.filters.pop(0) # XXX: should we rather test if instance of estimator? - if deep and hasattr(value, 'get_params'): + if deep and hasattr(value, "get_params"): deep_items = value.get_params().items() - out.update((key + '__' + k, val) for k, val in deep_items) + out.update((key + "__" + k, val) for k, val in deep_items) out[key] = value return out @@ -312,24 +331,27 @@ def set_params(self, **params): return self valid_params = self.get_params(deep=True) for key, value in params.items(): - split = key.split('__', 1) + split = key.split("__", 1) if len(split) > 1: # nested objects case name, sub_name = split if name not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (name, self)) + raise ValueError( + "Invalid parameter %s for estimator %s. " + "Check the list of available parameters " + "with `estimator.get_params().keys()`." % (name, self) + ) sub_object = valid_params[name] sub_object.set_params(**{sub_name: value}) else: # simple objects case if key not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (key, self.__class__.__name__)) + raise ValueError( + "Invalid parameter %s for estimator %s. " + "Check the list of available parameters " + "with `estimator.get_params().keys()`." + % (key, self.__class__.__name__) + ) setattr(self, key, value) return self @@ -338,7 +360,7 @@ def __repr__(self): pprint(self.get_params(deep=False), params) params.seek(0) class_name = self.__class__.__name__ - return '%s(%s)' % (class_name, params.read().strip()) + return "%s(%s)" % (class_name, params.read().strip()) # __getstate__ and __setstate__ are omitted because they only contain # conditionals that are not satisfied by our objects (e.g., @@ -350,7 +372,7 @@ def _more_tags(self): def _get_tags(self): collected_tags = {} for base_class in reversed(inspect.getmro(self.__class__)): - if hasattr(base_class, '_more_tags'): + if hasattr(base_class, "_more_tags"): # need the if because mixins might not have _more_tags # but might do redundant work in estimators # (i.e. calling more tags on BaseEstimator multiple times) @@ -391,21 +413,25 @@ def _check_fit_params(X, fit_params, indices=None): indexing. """ try: - from sklearn.utils.validation import \ - _check_fit_params as _sklearn_check_fit_params + from sklearn.utils.validation import ( + _check_fit_params as _sklearn_check_fit_params, + ) + return _sklearn_check_fit_params(X, fit_params, indices) except ImportError: from sklearn.model_selection import _validation - fit_params_validated = \ - {k: _validation._index_param_value(X, v, indices) - for k, v in fit_params.items()} + fit_params_validated = { + k: _validation._index_param_value(X, v, indices) + for k, v in fit_params.items() + } return fit_params_validated ############################################################################### # Copied from sklearn to simplify code paths + def empirical_covariance(X, assume_centered=False): """Computes the Maximum likelihood covariance estimator @@ -432,8 +458,9 @@ def empirical_covariance(X, assume_centered=False): X = np.reshape(X, (1, -1)) if X.shape[0] == 1: - warnings.warn("Only one sample available. 
" - "You may want to reshape your data array") + warnings.warn( + "Only one sample available. " "You may want to reshape your data array" + ) if assume_centered: covariance = np.dot(X.T, X) / X.shape[0] @@ -471,6 +498,7 @@ class EmpiricalCovariance(BaseEstimator): (stored only if store_precision is True) """ + def __init__(self, store_precision=True, assume_centered=False): self.store_precision = store_precision self.assume_centered = assume_centered @@ -489,6 +517,7 @@ def _set_covariance(self, covariance): """ from scipy import linalg + # covariance = check_array(covariance) # set covariance self.covariance_ = covariance @@ -508,6 +537,7 @@ def get_precision(self): """ from scipy import linalg + if self.store_precision: precision = self.precision_ else: @@ -535,8 +565,7 @@ def fit(self, X, y=None): self.location_ = np.zeros(X.shape[1]) else: self.location_ = X.mean(0) - covariance = empirical_covariance( - X, assume_centered=self.assume_centered) + covariance = empirical_covariance(X, assume_centered=self.assume_centered) self._set_covariance(covariance) return self @@ -563,15 +592,13 @@ def score(self, X_test, y=None): estimator of its covariance matrix. """ # compute empirical covariance of the test set - test_cov = empirical_covariance( - X_test - self.location_, assume_centered=True) + test_cov = empirical_covariance(X_test - self.location_, assume_centered=True) # compute log likelihood res = log_likelihood(test_cov, self.get_precision()) return res - def error_norm(self, comp_cov, norm='frobenius', scaling=True, - squared=True): + def error_norm(self, comp_cov, norm="frobenius", scaling=True, squared=True): """Computes the Mean Squared Error between two covariance estimators. Parameters @@ -597,16 +624,18 @@ def error_norm(self, comp_cov, norm='frobenius', scaling=True, `self` and `comp_cov` covariance estimators. """ from scipy import linalg + # compute the error error = comp_cov - self.covariance_ # compute the error norm if norm == "frobenius": - squared_norm = np.sum(error ** 2) + squared_norm = np.sum(error**2) elif norm == "spectral": squared_norm = np.amax(linalg.svdvals(np.dot(error.T, error))) else: raise NotImplementedError( - "Only spectral and frobenius norms are implemented") + "Only spectral and frobenius norms are implemented" + ) # optionally scale the error norm if scaling: squared_norm = squared_norm / error.shape[0] @@ -637,8 +666,7 @@ def mahalanobis(self, observations): precision = self.get_precision() # compute mahalanobis distances centered_obs = observations - self.location_ - mahalanobis_dist = np.sum( - np.dot(centered_obs, precision) * centered_obs, 1) + mahalanobis_dist = np.sum(np.dot(centered_obs, precision) * centered_obs, 1) return mahalanobis_dist @@ -663,17 +691,19 @@ def log_likelihood(emp_cov, precision): sample mean of the log-likelihood """ p = precision.shape[0] - log_likelihood_ = - np.sum(emp_cov * precision) + _logdet(precision) + log_likelihood_ = -np.sum(emp_cov * precision) + _logdet(precision) log_likelihood_ -= p * np.log(2 * np.pi) - log_likelihood_ /= 2. 
+ log_likelihood_ /= 2.0 return log_likelihood_ # sklearn uses np.linalg for this, but ours is more robust to zero eigenvalues + def _logdet(A): """Compute the log det of a positive semidefinite matrix.""" from scipy import linalg + vals = linalg.eigvalsh(A) # avoid negative (numerical errors) or zero (semi-definite matrix) values tol = vals.max() * vals.size * np.finfo(np.float64).eps @@ -694,37 +724,37 @@ def _infer_dimension_(spectrum, n_samples, n_features): def _assess_dimension_(spectrum, rank, n_samples, n_features): from scipy.special import gammaln + if rank > len(spectrum): - raise ValueError("The tested rank cannot exceed the rank of the" - " dataset") + raise ValueError("The tested rank cannot exceed the rank of the" " dataset") - pu = -rank * log(2.) + pu = -rank * log(2.0) for i in range(rank): - pu += (gammaln((n_features - i) / 2.) - - log(np.pi) * (n_features - i) / 2.) + pu += gammaln((n_features - i) / 2.0) - log(np.pi) * (n_features - i) / 2.0 pl = np.sum(np.log(spectrum[:rank])) - pl = -pl * n_samples / 2. + pl = -pl * n_samples / 2.0 if rank == n_features: pv = 0 v = 1 else: v = np.sum(spectrum[rank:]) / (n_features - rank) - pv = -np.log(v) * n_samples * (n_features - rank) / 2. + pv = -np.log(v) * n_samples * (n_features - rank) / 2.0 - m = n_features * rank - rank * (rank + 1.) / 2. - pp = log(2. * np.pi) * (m + rank + 1.) / 2. + m = n_features * rank - rank * (rank + 1.0) / 2.0 + pp = log(2.0 * np.pi) * (m + rank + 1.0) / 2.0 - pa = 0. + pa = 0.0 spectrum_ = spectrum.copy() spectrum_[rank:n_features] = v for i in range(rank): for j in range(i + 1, len(spectrum)): - pa += log((spectrum[i] - spectrum[j]) * - (1. / spectrum_[j] - 1. / spectrum_[i])) + log(n_samples) + pa += log( + (spectrum[i] - spectrum[j]) * (1.0 / spectrum_[j] - 1.0 / spectrum_[i]) + ) + log(n_samples) - ll = pu + pl + pv + pp - pa / 2. - rank * log(n_samples) / 2. + ll = pu + pl + pv + pp - pa / 2.0 - rank * log(n_samples) / 2.0 return ll @@ -762,23 +792,30 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): """ out = np.cumsum(arr, axis=axis, dtype=np.float64) expected = np.sum(arr, axis=axis, dtype=np.float64) - if not np.all(np.isclose(out.take(-1, axis=axis), expected, rtol=rtol, - atol=atol, equal_nan=True)): - warnings.warn('cumsum was found to be unstable: ' - 'its last element does not correspond to sum', - RuntimeWarning) + if not np.all( + np.isclose( + out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True + ) + ): + warnings.warn( + "cumsum was found to be unstable: " + "its last element does not correspond to sum", + RuntimeWarning, + ) return out ############################################################################### # From nilearn + def _crop_colorbar(cbar, cbar_vmin, cbar_vmax): """ crop a colorbar to show from cbar_vmin to cbar_vmax Used when symmetric_cbar=False is used. 
""" import matplotlib + if (cbar_vmin is None) and (cbar_vmax is None): return cbar_tick_locs = cbar.locator.locs @@ -786,8 +823,7 @@ def _crop_colorbar(cbar, cbar_vmin, cbar_vmax): cbar_vmax = cbar_tick_locs.max() if cbar_vmin is None: cbar_vmin = cbar_tick_locs.min() - new_tick_locs = np.linspace(cbar_vmin, cbar_vmax, - len(cbar_tick_locs)) + new_tick_locs = np.linspace(cbar_vmin, cbar_vmax, len(cbar_tick_locs)) cbar.ax.set_ylim(cbar_vmin, cbar_vmax) X = cbar._mesh()[0] @@ -797,9 +833,11 @@ def _crop_colorbar(cbar, cbar_vmin, cbar_vmax): ii = [0, 1, N - 2, N - 1, 2 * N - 1, 2 * N - 2, N + 1, N, 0] x = X.T.reshape(-1)[ii] y = Y.T.reshape(-1)[ii] - xy = (np.column_stack([y, x]) - if cbar.orientation == 'horizontal' else - np.column_stack([x, y])) + xy = ( + np.column_stack([y, x]) + if cbar.orientation == "horizontal" + else np.column_stack([x, y]) + ) cbar.outline.set_xy(xy) cbar.set_ticks(new_tick_locs) @@ -812,29 +850,36 @@ def _crop_colorbar(cbar, cbar_vmin, cbar_vmax): # Here we choose different defaults to speed things up by default try: import numba - if _compare_version(numba.__version__, '<', '0.53.1'): + + if _compare_version(numba.__version__, "<", "0.53.1"): raise ImportError prange = numba.prange - def jit(nopython=True, nogil=True, fastmath=True, cache=True, - **kwargs): # noqa - return numba.jit(nopython=nopython, nogil=nogil, fastmath=fastmath, - cache=cache, **kwargs) + + def jit(nopython=True, nogil=True, fastmath=True, cache=True, **kwargs): # noqa + return numba.jit( + nopython=nopython, nogil=nogil, fastmath=fastmath, cache=cache, **kwargs + ) + except Exception: # could be ImportError, SystemError, etc. has_numba = False else: - has_numba = (os.getenv('MNE_USE_NUMBA', 'true').lower() == 'true') + has_numba = os.getenv("MNE_USE_NUMBA", "true").lower() == "true" if not has_numba: + def jit(**kwargs): # noqa def _jit(func): return func + return _jit + prange = range bincount = np.bincount mean = np.mean else: + @jit() def bincount(x, weights, minlength): # noqa: D103 out = np.zeros(minlength) @@ -865,6 +910,7 @@ def mean(array, axis): ############################################################################### # Matplotlib + # workaround: plt.close() doesn't spawn close_event on Agg backend # https://github.com/matplotlib/matplotlib/issues/18609 # scheduled to be fixed by MPL 3.6 @@ -872,13 +918,15 @@ def _close_event(fig): """Force calling of the MPL figure close event.""" from .utils import logger from matplotlib import backend_bases + try: fig.canvas.callbacks.process( - 'close_event', backend_bases.CloseEvent( - name='close_event', canvas=fig.canvas)) - logger.debug(f'Called {fig!r}.canvas.close_event()') + "close_event", + backend_bases.CloseEvent(name="close_event", canvas=fig.canvas), + ) + logger.debug(f"Called {fig!r}.canvas.close_event()") except ValueError: # old mpl with Qt - logger.debug(f'Calling {fig!r}.canvas.close_event() failed') + logger.debug(f"Calling {fig!r}.canvas.close_event() failed") pass # pragma: no cover @@ -891,7 +939,7 @@ def _is_last_row(ax): def _sharex(ax1, ax2): - if hasattr(ax1.axes, 'sharex'): + if hasattr(ax1.axes, "sharex"): ax1.axes.sharex(ax2) else: ax1.get_shared_x_axes().join(ax1, ax2) @@ -900,6 +948,7 @@ def _sharex(ax1, ax2): ############################################################################### # SciPy deprecation of pinv + pinvh rcond (never worked properly anyway) in 1.7 + def pinvh(a, rtol=None): """Compute a pseudo-inverse of a Hermitian matrix.""" s, u = np.linalg.eigh(a) @@ -907,7 +956,7 @@ def pinvh(a, 
rtol=None): if rtol is None: rtol = s.size * np.finfo(s.dtype).eps maxS = np.max(np.abs(s)) - above_cutoff = (abs(s) > maxS * rtol) + above_cutoff = abs(s) > maxS * rtol psigma_diag = 1.0 / s[above_cutoff] u = u[:, above_cutoff] return (u * psigma_diag) @ u.conj().T @@ -929,9 +978,12 @@ def pinv(a, rtol=None): ############################################################################### # h5py uses np.product which is deprecated in NumPy 1.25 + @contextmanager def _numpy_h5py_dep(): # h5io uses np.product with warnings.catch_warnings(record=True): - warnings.filterwarnings('ignore', '`product` is deprecated.*', DeprecationWarning) + warnings.filterwarnings( + "ignore", "`product` is deprecated.*", DeprecationWarning + ) yield diff --git a/mne/forward/__init__.py b/mne/forward/__init__.py index 83788b8f706..c5fbeced9a4 100644 --- a/mne/forward/__init__.py +++ b/mne/forward/__init__.py @@ -1,22 +1,48 @@ """Forward modeling code.""" -from .forward import (Forward, read_forward_solution, write_forward_solution, - is_fixed_orient, _read_forward_meas_info, - _select_orient_forward, - compute_orient_prior, compute_depth_prior, - apply_forward, apply_forward_raw, - restrict_forward_to_stc, restrict_forward_to_label, - average_forward_solutions, _stc_src_sel, - _fill_measurement_info, _apply_forward, - _subject_from_forward, convert_forward_solution, - _merge_fwds, _do_forward_solution) -from ._make_forward import (make_forward_solution, _prepare_for_forward, - _prep_meg_channels, _prep_eeg_channels, - _to_forward_dict, _create_meg_coils, - _read_coil_defs, _transform_orig_meg_coils, - make_forward_dipole, use_coil_def) -from ._compute_forward import (_magnetic_dipole_field_vec, _compute_forwards, - _concatenate_coils) -from ._field_interpolation import (_make_surface_mapping, make_field_map, - _as_meg_type_inst, _map_meg_or_eeg_channels) +from .forward import ( + Forward, + read_forward_solution, + write_forward_solution, + is_fixed_orient, + _read_forward_meas_info, + _select_orient_forward, + compute_orient_prior, + compute_depth_prior, + apply_forward, + apply_forward_raw, + restrict_forward_to_stc, + restrict_forward_to_label, + average_forward_solutions, + _stc_src_sel, + _fill_measurement_info, + _apply_forward, + _subject_from_forward, + convert_forward_solution, + _merge_fwds, + _do_forward_solution, +) +from ._make_forward import ( + make_forward_solution, + _prepare_for_forward, + _prep_meg_channels, + _prep_eeg_channels, + _to_forward_dict, + _create_meg_coils, + _read_coil_defs, + _transform_orig_meg_coils, + make_forward_dipole, + use_coil_def, +) +from ._compute_forward import ( + _magnetic_dipole_field_vec, + _compute_forwards, + _concatenate_coils, +) +from ._field_interpolation import ( + _make_surface_mapping, + make_field_map, + _as_meg_type_inst, + _map_meg_or_eeg_channels, +) from . 
import _lead_dots  # for testing purposes
diff --git a/mne/forward/_compute_forward.py b/mne/forward/_compute_forward.py
index 9b4ee7dba1c..5fe55906220 100644
--- a/mne/forward/_compute_forward.py
+++ b/mne/forward/_compute_forward.py
@@ -30,21 +30,22 @@
 # #############################################################################
 # COIL SPECIFICATION AND FIELD COMPUTATION MATRIX

+
 def _dup_coil_set(coils, coord_frame, t):
     """Make a duplicate."""
-    if t is not None and coord_frame != t['from']:
-        raise RuntimeError('transformation frame does not match the coil set')
+    if t is not None and coord_frame != t["from"]:
+        raise RuntimeError("transformation frame does not match the coil set")
     coils = deepcopy(coils)
     if t is not None:
-        coord_frame = t['to']
+        coord_frame = t["to"]
         for coil in coils:
-            for key in ('ex', 'ey', 'ez'):
+            for key in ("ex", "ey", "ez"):
                 if key in coil:
-                    coil[key] = apply_trans(t['trans'], coil[key], False)
-            coil['r0'] = apply_trans(t['trans'], coil['r0'])
-            coil['rmag'] = apply_trans(t['trans'], coil['rmag'])
-            coil['cosmag'] = apply_trans(t['trans'], coil['cosmag'], False)
-            coil['coord_frame'] = t['to']
+                    coil[key] = apply_trans(t["trans"], coil[key], False)
+            coil["r0"] = apply_trans(t["trans"], coil["r0"])
+            coil["rmag"] = apply_trans(t["trans"], coil["rmag"])
+            coil["cosmag"] = apply_trans(t["trans"], coil["cosmag"], False)
+            coil["coord_frame"] = t["to"]
     return coils, coord_frame


@@ -53,10 +54,9 @@ def _check_coil_frame(coils, coord_frame, bem):
     if coord_frame != FIFF.FIFFV_COORD_MRI:
         if coord_frame == FIFF.FIFFV_COORD_HEAD:
             # Make a transformed duplicate
-            coils, coord_Frame = _dup_coil_set(coils, coord_frame,
-                                               bem['head_mri_t'])
+            coils, coord_frame = _dup_coil_set(coils, coord_frame, bem["head_mri_t"])
         else:
-            raise RuntimeError('Bad coil coordinate frame %s' % coord_frame)
+            raise RuntimeError("Bad coil coordinate frame %s" % coord_frame)
     return coils, coord_frame


@@ -88,12 +88,17 @@ def _lin_field_coeff(surf, mult, rmags, cosmags, ws, bins, n_jobs):
         (?)
""" parallel, p_fun, n_jobs = parallel_func( - _do_lin_field_coeff, n_jobs, max_jobs=len(surf['tris'])) + _do_lin_field_coeff, n_jobs, max_jobs=len(surf["tris"]) + ) nas = np.array_split - coeffs = parallel(p_fun(surf['rr'], t, tn, ta, rmags, cosmags, ws, bins) - for t, tn, ta in zip(nas(surf['tris'], n_jobs), - nas(surf['tri_nn'], n_jobs), - nas(surf['tri_area'], n_jobs))) + coeffs = parallel( + p_fun(surf["rr"], t, tn, ta, rmags, cosmags, ws, bins) + for t, tn, ta in zip( + nas(surf["tris"], n_jobs), + nas(surf["tri_nn"], n_jobs), + nas(surf["tri_area"], n_jobs), + ) + ) return mult * np.sum(coeffs, axis=0) @@ -154,22 +159,21 @@ def _do_lin_field_coeff(bem_rr, tris, tn, ta, rmags, cosmags, ws, bins): for ti in range(3): x = np.sum(c[:, ti], axis=-1) x /= den[:, tri[ti]] / tri_area - coeff[:, tri[ti]] += \ - bincount(bins, weights=x, minlength=bins[-1] + 1) + coeff[:, tri[ti]] += bincount(bins, weights=x, minlength=bins[-1] + 1) return coeff def _concatenate_coils(coils): """Concatenate MEG coil parameters.""" - rmags = np.concatenate([coil['rmag'] for coil in coils]) - cosmags = np.concatenate([coil['cosmag'] for coil in coils]) - ws = np.concatenate([coil['w'] for coil in coils]) - n_int = np.array([len(coil['rmag']) for coil in coils]) + rmags = np.concatenate([coil["rmag"] for coil in coils]) + cosmags = np.concatenate([coil["cosmag"] for coil in coils]) + ws = np.concatenate([coil["w"] for coil in coils]) + n_int = np.array([len(coil["rmag"]) for coil in coils]) if n_int[-1] == 0: # We assume each sensor has at least one integration point, # which should be a safe assumption. But let's check it here, since # our code elsewhere relies on bins[-1] + 1 being the number of sensors - raise RuntimeError('not supported') + raise RuntimeError("not supported") bins = np.repeat(np.arange(len(n_int)), n_int) return rmags, cosmags, ws, bins @@ -208,8 +212,8 @@ def _bem_specify_coils(bem, coils, coord_frame, mults, n_jobs): # Process each of the surfaces rmags, cosmags, ws, bins = _triage_coils(coils) del coils - lens = np.cumsum(np.r_[0, [len(s['rr']) for s in bem['surfs']]]) - sol = np.zeros((bins[-1] + 1, bem['solution'].shape[1])) + lens = np.cumsum(np.r_[0, [len(s["rr"]) for s in bem["surfs"]]]) + sol = np.zeros((bins[-1] + 1, bem["solution"].shape[1])) lims = np.concatenate([np.arange(0, sol.shape[0], 100), [sol.shape[0]]]) # Put through the bem (in channel-based chunks to save memory) @@ -217,10 +221,11 @@ def _bem_specify_coils(bem, coils, coord_frame, mults, n_jobs): mask = np.logical_and(bins >= start, bins < stop) r, c, w, b = rmags[mask], cosmags[mask], ws[mask], bins[mask] - start # Compute coeffs for each surface, one at a time - for o1, o2, surf, mult in zip(lens[:-1], lens[1:], - bem['surfs'], bem['field_mult']): + for o1, o2, surf, mult in zip( + lens[:-1], lens[1:], bem["surfs"], bem["field_mult"] + ): coeff = _lin_field_coeff(surf, mult, r, c, w, b, n_jobs) - sol[start:stop] += np.dot(coeff, bem['solution'][o1:o2]) + sol[start:stop] += np.dot(coeff, bem["solution"][o1:o2]) sol *= mults return sol @@ -242,20 +247,22 @@ def _bem_specify_els(bem, els, mults): sol : ndarray, shape (n_EEG_sensors, n_BEM_vertices) EEG solution """ - sol = np.zeros((len(els), bem['solution'].shape[1])) - scalp = bem['surfs'][0] + sol = np.zeros((len(els), bem["solution"].shape[1])) + scalp = bem["surfs"][0] # Operate on all integration points for all electrodes (in MRI coords) - rrs = np.concatenate([apply_trans(bem['head_mri_t']['trans'], el['rmag']) - for el in els], axis=0) - ws = 
np.concatenate([el['w'] for el in els])
+    rrs = np.concatenate(
+        [apply_trans(bem["head_mri_t"]["trans"], el["rmag"]) for el in els], axis=0
+    )
+    ws = np.concatenate([el["w"] for el in els])
     tri_weights, tri_idx = _project_onto_surface(rrs, scalp)
     tri_weights *= ws[:, np.newaxis]
-    weights = np.matmul(tri_weights[:, np.newaxis],
-                        bem['solution'][scalp['tris'][tri_idx]])[:, 0]
+    weights = np.matmul(
+        tri_weights[:, np.newaxis], bem["solution"][scalp["tris"][tri_idx]]
+    )[:, 0]
     # there are way more vertices than electrodes generally, so let's iterate
     # over the electrodes
-    edges = np.concatenate([[0], np.cumsum([len(el['w']) for el in els])])
+    edges = np.concatenate([[0], np.cumsum([len(el["w"]) for el in els])])
     for ii, (start, stop) in enumerate(zip(edges[:-1], edges[1:])):
         sol[ii] = weights[start:stop].sum(0)
     sol *= mults
@@ -302,7 +309,7 @@ def _bem_inf_pots(mri_rr, bem_rr, mri_Q=None):
         this_diff = bem_rr - rr
         diff_norm = np.sum(this_diff * this_diff, axis=1)
         diff_norm *= np.sqrt(diff_norm)
-        diff_norm[diff_norm == 0] = 1.
+        diff_norm[diff_norm == 0] = 1.0
         if mri_Q is not None:
             this_diff = np.dot(this_diff, mri_Q.T)
         this_diff /= diff_norm.reshape(-1, 1)
@@ -310,6 +317,7 @@
     return diff

+
 # This function has been refactored to process all points simultaneously
 # def _bem_inf_field(rd, Q, rp, d):
 #     """Infinite-medium magnetic field. See (7) in Mosher, 1999"""
@@ -370,8 +378,7 @@ def _bem_inf_fields(rr, rmag, cosmag):

 @fill_doc
-def _bem_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, n_jobs,
-                      coil_type):
+def _bem_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, n_jobs, coil_type):
     """Calculate the magnetic field or electric potential forward solution.

     The code is very similar between EEG and MEG potentials, so combine them.
@@ -404,22 +411,25 @@ def _bem_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, n_jobs,
     # Both MEG and EEG have the infinite-medium potentials
     # This could be just vectorized, but eats too much memory, so instead we
     # reduce memory by chunking within _do_inf_pots and parallelize, too:
-    parallel, p_fun, n_jobs = parallel_func(
-        _do_inf_pots, n_jobs, max_jobs=len(rr))
+    parallel, p_fun, n_jobs = parallel_func(_do_inf_pots, n_jobs, max_jobs=len(rr))
     nas = np.array_split
-    B = np.sum(parallel(p_fun(mri_rr, sr.copy(), np.ascontiguousarray(mri_Q),
-                              np.array(sol))  # copy and contig
-                        for sr, sol in zip(nas(bem_rr, n_jobs),
-                                           nas(solution.T, n_jobs))), axis=0)
+    B = np.sum(
+        parallel(
+            p_fun(
+                mri_rr, sr.copy(), np.ascontiguousarray(mri_Q), np.array(sol)
+            )  # copy and contig
+            for sr, sol in zip(nas(bem_rr, n_jobs), nas(solution.T, n_jobs))
+        ),
+        axis=0,
+    )
     # The copy()s above should make it so the whole objects don't need to be
     # pickled...

     # Only MEG coils are sensitive to the primary current distribution.
-    if coil_type == 'meg':
+    if coil_type == "meg":
         # Primary current contribution (can be calc.
in coil/dipole coords) parallel, p_fun, n_jobs = parallel_func(_do_prim_curr, n_jobs) - pcc = np.concatenate(parallel(p_fun(r, coils) - for r in nas(rr, n_jobs)), axis=0) + pcc = np.concatenate(parallel(p_fun(r, coils) for r in nas(rr, n_jobs)), axis=0) B += pcc B *= _MAG_FACTOR return B @@ -451,8 +461,9 @@ def _do_prim_curr(rr, coils): pp = _bem_inf_fields(rr[start:stop], rmags, cosmags) pp *= ws pp.shape = (3 * (stop - start), -1) - pc[3 * start:3 * stop] = [bincount(bins, this_pp, bins[-1] + 1) - for this_pp in pp] + pc[3 * start : 3 * stop] = [ + bincount(bins, this_pp, bins[-1] + 1) for this_pp in pp + ] return pc @@ -493,21 +504,21 @@ def _do_inf_pots(mri_rr, bem_rr, mri_Q, sol): # v0 in Hämäläinen et al., 1989 == v_inf in Mosher, et al., 1999 v0s = _bem_inf_pots(mri_rr[start:stop], bem_rr, mri_Q) v0s = v0s.reshape(-1, v0s.shape[2]) - B[3 * start:3 * stop] = np.dot(v0s, sol) + B[3 * start : 3 * stop] = np.dot(v0s, sol) return B # ############################################################################# # SPHERE COMPUTATION -def _sphere_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, - n_jobs, coil_type): + +def _sphere_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, n_jobs, coil_type): """Do potential or field for spherical model.""" - fun = _eeg_spherepot_coil if coil_type == 'eeg' else _sphere_field - parallel, p_fun, n_jobs = parallel_func( - fun, n_jobs, max_jobs=len(rr)) - B = np.concatenate(parallel(p_fun(r, coils, sphere=solution) - for r in np.array_split(rr, n_jobs))) + fun = _eeg_spherepot_coil if coil_type == "eeg" else _sphere_field + parallel, p_fun, n_jobs = parallel_func(fun, n_jobs, max_jobs=len(rr)) + B = np.concatenate( + parallel(p_fun(r, coils, sphere=solution) for r in np.array_split(rr, n_jobs)) + ) return B @@ -521,7 +532,7 @@ def _sphere_field(rrs, coils, sphere): by Matti Hämäläinen, February 1990 """ rmags, cosmags, ws, bins = _triage_coils(coils) - return _do_sphere_field(rrs, rmags, cosmags, ws, bins, sphere['r0']) + return _do_sphere_field(rrs, rmags, cosmags, ws, bins, sphere["r0"]) @jit() @@ -557,8 +568,9 @@ def _do_sphere_field(rrs, rmags, cosmags, ws, bins, r0): _jit_cross(v1, rr_, cosmags) v2 = np.empty((cosmags.shape[0], 3)) _jit_cross(v2, rr_, this_poss) - xx = ((good * ws).reshape(-1, 1) * - (v1 / F.reshape(-1, 1) + v2 * g.reshape(-1, 1))) + xx = (good * ws).reshape(-1, 1) * ( + v1 / F.reshape(-1, 1) + v2 * g.reshape(-1, 1) + ) for jj in range(3): zz = bincount(bins, xx[:, jj], n_coils) B[3 * ri + jj, :] = zz @@ -573,24 +585,24 @@ def _eeg_spherepot_coil(rrs, coils, sphere): del coils # Shift to the sphere model coordinates - rrs = rrs - sphere['r0'] + rrs = rrs - sphere["r0"] B = np.zeros((3 * len(rrs), n_coils)) for ri, rr in enumerate(rrs): # Only process dipoles inside the innermost sphere - if np.sqrt(np.dot(rr, rr)) >= sphere['layers'][0]['rad']: + if np.sqrt(np.dot(rr, rr)) >= sphere["layers"][0]["rad"]: continue # fwd_eeg_spherepot_vec vval_one = np.zeros((len(rmags), 3)) # Make a weighted sum over the equivalence parameters - for eq in range(sphere['nfit']): + for eq in range(sphere["nfit"]): # Scale the dipole position - rd = sphere['mu'][eq] * rr + rd = sphere["mu"][eq] * rr rd2 = np.sum(rd * rd) rd2_inv = 1.0 / rd2 # Go over all electrodes - this_pos = rmags - sphere['r0'] + this_pos = rmags - sphere["r0"] # Scale location onto the surface of the sphere (not used) # if sphere['scale_pos']: @@ -616,17 +628,19 @@ def _eeg_spherepot_coil(rrs, coils, sphere): c2 = a3 + (a + r) / (r * F) # Mix them together and scale 
by lambda/(rd*rd) - m1 = (c1 - c2 * rrd) + m1 = c1 - c2 * rrd m2 = c2 * rd2 - vval_one += (sphere['lambda'][eq] * rd2_inv * - (m1[:, np.newaxis] * rd + - m2[:, np.newaxis] * this_pos)) + vval_one += ( + sphere["lambda"][eq] + * rd2_inv + * (m1[:, np.newaxis] * rd + m2[:, np.newaxis] * this_pos) + ) # compute total result xx = vval_one * ws[:, np.newaxis] zz = np.array([bincount(bins, x, bins[-1] + 1) for x in xx.T]) - B[3 * ri:3 * ri + 3, :] = zz + B[3 * ri : 3 * ri + 3, :] = zz # finishing by scaling by 1/(4*M_PI) B *= 0.25 / np.pi return B @@ -642,14 +656,14 @@ def _triage_coils(coils): _MIN_DIST_LIMIT = 1e-5 -def _magnetic_dipole_field_vec(rrs, coils, too_close='raise'): +def _magnetic_dipole_field_vec(rrs, coils, too_close="raise"): rmags, cosmags, ws, bins = _triage_coils(coils) fwd, min_dist = _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close) if min_dist < _MIN_DIST_LIMIT: - msg = 'Coil too close (dist = %g mm)' % (min_dist * 1000,) - if too_close == 'raise': + msg = "Coil too close (dist = %g mm)" % (min_dist * 1000,) + if too_close == "raise": raise RuntimeError(msg) - func = warn if too_close == 'warning' else logger.info + func = warn if too_close == "warning" else logger.info func(msg) return fwd @@ -682,7 +696,7 @@ def _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close): dist2 = dist2_.reshape(-1, 1) dist = np.sqrt(dist2) min_dist = min(dist.min(), min_dist) - if min_dist < _MIN_DIST_LIMIT and too_close == 'raise': + if min_dist < _MIN_DIST_LIMIT and too_close == "raise": break t_ = np.sum(diff * cosmags, axis=1) t = t_.reshape(-1, 1) @@ -696,6 +710,7 @@ def _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close): # ############################################################################# # MAIN TRIAGING FUNCTION + @verbose def _prep_field_computation(rr, *, sensors, bem, n_jobs, verbose=None): """Precompute and store some things that are used for both MEG and EEG. 
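[Editorial note, not part of this patch: _prep_field_computation exists so that
the expensive BEM setup runs once while the per-dipole evaluation can be called
repeatedly -- a later hunk notes the split is done "to save (potentially) a lot
of time when e.g. dipole fitting". A minimal sketch of that precompute/evaluate
pattern, using hypothetical stand-in names rather than MNE's actual API:

    import numpy as np

    def _prep(model):
        # expensive step, run once: assemble whatever the evaluations reuse
        return dict(solution=np.linalg.pinv(model))

    def _evaluate(rr, fwd_data):
        # cheap step, run per batch of source positions
        return rr @ fwd_data["solution"]

    fwd_data = _prep(np.eye(3))
    for rr in (np.ones((4, 3)), np.zeros((2, 3))):
        B = _evaluate(rr, fwd_data)
]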
@@ -717,44 +732,47 @@ def _prep_field_computation(rr, *, sensors, bem, n_jobs, verbose=None): %(verbose)s """ bem_rr = mults = mri_Q = head_mri_t = None - if not bem['is_sphere']: - if bem['bem_method'] != FIFF.FIFFV_BEM_APPROX_LINEAR: - raise RuntimeError('only linear collocation supported') + if not bem["is_sphere"]: + if bem["bem_method"] != FIFF.FIFFV_BEM_APPROX_LINEAR: + raise RuntimeError("only linear collocation supported") # Store (and apply soon) μ_0/(4π) factor before source computations - mults = np.repeat(bem['source_mult'] / (4.0 * np.pi), - [len(s['rr']) for s in bem['surfs']])[np.newaxis, :] + mults = np.repeat( + bem["source_mult"] / (4.0 * np.pi), [len(s["rr"]) for s in bem["surfs"]] + )[np.newaxis, :] # Get positions of BEM points for every surface - bem_rr = np.concatenate([s['rr'] for s in bem['surfs']]) + bem_rr = np.concatenate([s["rr"] for s in bem["surfs"]]) # The dipole location and orientation must be transformed - head_mri_t = bem['head_mri_t'] - mri_Q = bem['head_mri_t']['trans'][:3, :3].T + head_mri_t = bem["head_mri_t"] + mri_Q = bem["head_mri_t"]["trans"][:3, :3].T solutions = dict() for coil_type in sensors: - coils = sensors[coil_type]['defs'] - if not bem['is_sphere']: - if coil_type == 'meg': + coils = sensors[coil_type]["defs"] + if not bem["is_sphere"]: + if coil_type == "meg": # MEG field computation matrices for BEM - start = 'Composing the field computation matrix' - logger.info('\n' + start + '...') + start = "Composing the field computation matrix" + logger.info("\n" + start + "...") cf = FIFF.FIFFV_COORD_HEAD # multiply solution by "mults" here for simplicity solution = _bem_specify_coils(bem, coils, cf, mults, n_jobs) else: # Compute solution for EEG sensor - logger.info('Setting up for EEG...') + logger.info("Setting up for EEG...") solution = _bem_specify_els(bem, coils, mults) else: solution = bem - if coil_type == 'eeg': - logger.info('Using the equivalent source approach in the ' - 'homogeneous sphere for EEG') - sensors[coil_type]['defs'] = _triage_coils(coils) + if coil_type == "eeg": + logger.info( + "Using the equivalent source approach in the " + "homogeneous sphere for EEG" + ) + sensors[coil_type]["defs"] = _triage_coils(coils) solutions[coil_type] = solution # Get appropriate forward physics function depending on sphere or BEM model - fun = _sphere_pot_or_field if bem['is_sphere'] else _bem_pot_or_field + fun = _sphere_pot_or_field if bem["is_sphere"] else _bem_pot_or_field # Update fwd_data with # bem_rr (3D BEM vertex positions) @@ -764,8 +782,8 @@ def _prep_field_computation(rr, *, sensors, bem, n_jobs, verbose=None): # solutions (len 2 list; [ndarray, shape (n_MEG_sens, n BEM vertices), # ndarray, shape (n_EEG_sens, n BEM vertices)] fwd_data = dict( - bem_rr=bem_rr, mri_Q=mri_Q, head_mri_t=head_mri_t, fun=fun, - solutions=solutions) + bem_rr=bem_rr, mri_Q=mri_Q, head_mri_t=head_mri_t, fun=fun, solutions=solutions + ) return fwd_data @@ -775,26 +793,34 @@ def _compute_forwards_meeg(rr, *, sensors, fwd_data, n_jobs, silent=False): Bs = dict() # The dipole location and orientation must be transformed to mri coords mri_rr = None - if fwd_data['head_mri_t'] is not None: - mri_rr = np.ascontiguousarray( - apply_trans(fwd_data['head_mri_t']['trans'], rr)) - mri_Q, bem_rr, fun = fwd_data['mri_Q'], fwd_data['bem_rr'], fwd_data['fun'] - solutions = fwd_data['solutions'] + if fwd_data["head_mri_t"] is not None: + mri_rr = np.ascontiguousarray(apply_trans(fwd_data["head_mri_t"]["trans"], rr)) + mri_Q, bem_rr, fun = fwd_data["mri_Q"], 
fwd_data["bem_rr"], fwd_data["fun"]
+    solutions = fwd_data["solutions"]
     del fwd_data
     for coil_type, sens in sensors.items():
-        coils = sens['defs']
-        compensator = sens.get('compensator', None)
-        post_picks = sens.get('post_picks', None)
+        coils = sens["defs"]
+        compensator = sens.get("compensator", None)
+        post_picks = sens.get("post_picks", None)
         solution = solutions.get(coil_type, None)

         # Do the actual forward calculation for a list of MEG/EEG sensors
         if not silent:
-            logger.info('Computing %s at %d source location%s '
-                        '(free orientations)...'
-                        % (coil_type.upper(), len(rr), _pl(rr)))
+            logger.info(
+                "Computing %s at %d source location%s "
+                "(free orientations)..." % (coil_type.upper(), len(rr), _pl(rr))
+            )
         # Calculate forward solution using spherical or BEM model
-        B = fun(rr, mri_rr, mri_Q, coils=coils, solution=solution,
-                bem_rr=bem_rr, n_jobs=n_jobs, coil_type=coil_type)
+        B = fun(
+            rr,
+            mri_rr,
+            mri_Q,
+            coils=coils,
+            solution=solution,
+            bem_rr=bem_rr,
+            n_jobs=n_jobs,
+            coil_type=coil_type,
+        )

         # Compensate if needed (only done for MEG systems w/compensation)
         if compensator is not None:
@@ -810,16 +836,16 @@ def _compute_forwards(rr, *, bem, sensors, n_jobs, verbose=None):
     """Compute the MEG and EEG forward solutions."""
     # Split calculation into two steps to save (potentially) a lot of time
     # when e.g. dipole fitting
-    solver = bem.get('solver', 'mne')
-    _check_option('solver', solver, ('mne', 'openmeeg'))
-    if bem['is_sphere'] or solver == 'mne':
-        fwd_data = _prep_field_computation(
-            rr, sensors=sensors, bem=bem, n_jobs=n_jobs)
+    solver = bem.get("solver", "mne")
+    _check_option("solver", solver, ("mne", "openmeeg"))
+    if bem["is_sphere"] or solver == "mne":
+        fwd_data = _prep_field_computation(rr, sensors=sensors, bem=bem, n_jobs=n_jobs)
         Bs = _compute_forwards_meeg(
-            rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs)
+            rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs
+        )
     else:
         Bs = _compute_forwards_openmeeg(rr, bem=bem, sensors=sensors)
-    n_sensors_want = sum(len(s['ch_names']) for s in sensors.values())
+    n_sensors_want = sum(len(s["ch_names"]) for s in sensors.values())
     n_sensors = sum(B.shape[1] for B in Bs.values())
     n_sources = list(Bs.values())[0].shape[0]
     assert (n_sources, n_sensors) == (len(rr) * 3, n_sensors_want)
@@ -830,30 +856,30 @@ def _compute_forwards_openmeeg(rr, *, bem, sensors):
     """Compute the MEG and EEG forward solutions for OpenMEEG."""
     if len(bem["surfs"]) != 3:
         raise RuntimeError("Only 3-layer BEM is supported for OpenMEEG.")
-    om = _import_openmeeg('compute a forward solution using OpenMEEG')
+    om = _import_openmeeg("compute a forward solution using OpenMEEG")
     hminv = om.SymMatrix(bem["solution"])
-    geom = _make_openmeeg_geometry(bem, invert_transform(bem['head_mri_t']))
+    geom = _make_openmeeg_geometry(bem, invert_transform(bem["head_mri_t"]))

     # Make dipoles for all XYZ orientations
     dipoles = np.c_[
         np.kron(rr.T, np.ones(3)[None, :]).T,
-        np.kron(np.ones(len(rr))[:, None],
-                np.eye(3)),
+        np.kron(np.ones(len(rr))[:, None], np.eye(3)),
     ]
     dipoles = np.asfortranarray(dipoles)
     dipoles = om.Matrix(dipoles)
     dsm = om.DipSourceMat(geom, dipoles, "Brain")
     Bs = dict()
-    if 'eeg' in sensors:
-        rmags, _, ws, bins = _concatenate_coils(sensors['eeg']['defs'])
+    if "eeg" in sensors:
+        rmags, _, ws, bins = _concatenate_coils(sensors["eeg"]["defs"])
         rmags = np.asfortranarray(rmags.astype(np.float64))
         eeg_sensors = om.Sensors(om.Matrix(np.asfortranarray(rmags)), geom)
         h2em = om.Head2EEGMat(geom, eeg_sensors)
         eeg_fwd_full = om.GainEEG(hminv, dsm,
h2em).array() - Bs['eeg'] = np.array([bincount(bins, ws * x, bins[-1] + 1) - for x in eeg_fwd_full.T], float) - if 'meg' in sensors: - rmags, cosmags, ws, bins = _concatenate_coils(sensors['meg']['defs']) + Bs["eeg"] = np.array( + [bincount(bins, ws * x, bins[-1] + 1) for x in eeg_fwd_full.T], float + ) + if "meg" in sensors: + rmags, cosmags, ws, bins = _concatenate_coils(sensors["meg"]["defs"]) rmags = np.asfortranarray(rmags.astype(np.float64)) cosmags = np.asfortranarray(cosmags.astype(np.float64)) labels = [str(ii) for ii in range(len(rmags))] @@ -862,13 +888,14 @@ def _compute_forwards_openmeeg(rr, *, bem, sensors): h2mm = om.Head2MEGMat(geom, meg_sensors) ds2mm = om.DipSource2MEGMat(dipoles, meg_sensors) meg_fwd_full = om.GainMEG(hminv, dsm, h2mm, ds2mm).array() - B = np.array([bincount(bins, ws * x, bins[-1] + 1) - for x in meg_fwd_full.T], float) - compensator = sensors['meg'].get('compensator', None) - post_picks = sensors['meg'].get('post_picks', None) + B = np.array( + [bincount(bins, ws * x, bins[-1] + 1) for x in meg_fwd_full.T], float + ) + compensator = sensors["meg"].get("compensator", None) + post_picks = sensors["meg"].get("post_picks", None) if compensator is not None: B = B @ compensator.T if post_picks is not None: B = B[:, post_picks] - Bs['meg'] = B + Bs["meg"] = B return Bs diff --git a/mne/forward/_field_interpolation.py b/mne/forward/_field_interpolation.py index fdc21ab8e9c..acad17a7fca 100644 --- a/mne/forward/_field_interpolation.py +++ b/mne/forward/_field_interpolation.py @@ -18,11 +18,13 @@ from ..surface import get_head_surf, get_meg_helmet_surf from ..transforms import transform_surface_to, _find_trans, _get_trans from ._make_forward import _create_meg_coils, _create_eeg_els, _read_coil_defs -from ._lead_dots import (_do_self_dots, _do_surface_dots, _get_legen_table, - _do_cross_dots) -from ..utils import ( - logger, verbose, _check_option, _reg_pinv, _pl, _check_fname +from ._lead_dots import ( + _do_self_dots, + _do_surface_dots, + _get_legen_table, + _do_cross_dots, ) +from ..utils import logger, verbose, _check_option, _reg_pinv, _pl, _check_fname from ..epochs import EpochsArray, BaseEpochs from ..evoked import Evoked, EvokedArray @@ -30,9 +32,10 @@ def _setup_dots(mode, info, coils, ch_type): """Set up dot products.""" from scipy.interpolate import interp1d + int_rad = 0.06 noise = make_ad_hoc_cov(info, dict(mag=20e-15, grad=5e-13, eeg=1e-6)) - n_coeff, interp = (50, 'nearest') if mode == 'fast' else (100, 'linear') + n_coeff, interp = (50, "nearest") if mode == "fast" else (100, "linear") lut, n_fact = _get_legen_table(ch_type, False, n_coeff, verbose=False) lut_fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, interp, axis=0) return int_rad, noise, lut_fun, n_fact @@ -40,27 +43,27 @@ def _setup_dots(mode, info, coils, ch_type): def _compute_mapping_matrix(fmd, info): """Do the hairy computations.""" - logger.info(' Preparing the mapping matrix...') + logger.info(" Preparing the mapping matrix...") # assemble a projector and apply it to the data - ch_names = fmd['ch_names'] - projs = info.get('projs', list()) + ch_names = fmd["ch_names"] + projs = info.get("projs", list()) proj_op = make_projector(projs, ch_names)[0] - proj_dots = np.dot(proj_op.T, np.dot(fmd['self_dots'], proj_op)) + proj_dots = np.dot(proj_op.T, np.dot(fmd["self_dots"], proj_op)) - noise_cov = fmd['noise'] + noise_cov = fmd["noise"] # Whiten - if not noise_cov['diag']: + if not noise_cov["diag"]: raise NotImplementedError # this shouldn't happen - whitener = np.diag(1.0 / 
np.sqrt(noise_cov["data"].ravel()))
     whitened_dots = np.dot(whitener.T, np.dot(proj_dots, whitener))

     # SVD is numerically better than the eigenvalue decomposition even if
     # mat is supposed to be symmetric and positive definite
-    if fmd.get('pinv_method', 'tsvd') == 'tsvd':
-        inv, fmd['nest'] = _pinv_trunc(whitened_dots, fmd['miss'])
+    if fmd.get("pinv_method", "tsvd") == "tsvd":
+        inv, fmd["nest"] = _pinv_trunc(whitened_dots, fmd["miss"])
     else:
-        assert fmd['pinv_method'] == 'tikhonov', fmd['pinv_method']
-        inv, fmd['nest'] = _pinv_tikhonov(whitened_dots, fmd['miss'])
+        assert fmd["pinv_method"] == "tikhonov", fmd["pinv_method"]
+        inv, fmd["nest"] = _pinv_tikhonov(whitened_dots, fmd["miss"])

     # Sandwich with the whitener
     inv_whitened = np.dot(whitener.T, np.dot(inv, whitener))
@@ -71,13 +74,14 @@ def _compute_mapping_matrix(fmd, info):

     # Finally sandwich in the selection matrix
     # This one picks up the correct lead field projection
-    mapping_mat = np.dot(fmd['surface_dots'], inv_whitened_proj)
+    mapping_mat = np.dot(fmd["surface_dots"], inv_whitened_proj)

     # Optionally apply the average electrode reference to the final field map
-    if fmd['kind'] == 'eeg' and _has_eeg_average_ref_proj(info):
+    if fmd["kind"] == "eeg" and _has_eeg_average_ref_proj(info):
         logger.info(
-            '    The map has an average electrode reference '
-            f'({mapping_mat.shape[0]} channels)')
+            "    The map has an average electrode reference "
+            f"({mapping_mat.shape[0]} channels)"
+        )
         mapping_mat -= np.mean(mapping_mat, axis=0)
     return mapping_mat


 def _pinv_trunc(x, miss):
     """Compute pseudoinverse, truncating at most "miss" fraction of varexp."""
     from scipy import linalg
+
     u, s, v = linalg.svd(x, full_matrices=False)

     # Eigenvalue truncation
     varexp = np.cumsum(s)
     varexp /= varexp[-1]
     n = np.where(varexp >= (1.0 - miss))[0][0] + 1
-    logger.info('    Truncating at %d/%d components to omit less than %g '
-                '(%0.2g)' % (n, len(s), miss, 1. - varexp[n - 1]))
+    logger.info(
+        "    Truncating at %d/%d components to omit less than %g "
+        "(%0.2g)" % (n, len(s), miss, 1.0 - varexp[n - 1])
+    )
+    s = 1.
/ s[:n] + logger.info( + " Truncating at %d/%d components to omit less than %g " + "(%0.2g)" % (n, len(s), miss, 1.0 - varexp[n - 1]) + ) + s = 1.0 / s[:n] inv = ((u[:, :n] * s) @ v[:n]).T return inv, n @@ -101,8 +108,10 @@ def _pinv_trunc(x, miss): def _pinv_tikhonov(x, reg): # _reg_pinv requires square Hermitian, which we have here inv, _, n = _reg_pinv(x, reg=reg, rank=None) - logger.info(f' Truncating at {n}/{len(x)} components and regularizing ' - f'with α={reg:0.1e}') + logger.info( + f" Truncating at {n}/{len(x)} components and regularizing " + f"with α={reg:0.1e}" + ) return inv, n @@ -131,58 +140,76 @@ def _map_meg_or_eeg_channels(info_from, info_to, mode, origin, miss=None): """ # no need to apply trans because both from and to coils are in device # coordinates - info_kinds = set(ch['kind'] for ch in info_to['chs']) - info_kinds |= set(ch['kind'] for ch in info_from['chs']) + info_kinds = set(ch["kind"] for ch in info_to["chs"]) + info_kinds |= set(ch["kind"] for ch in info_from["chs"]) if FIFF.FIFFV_REF_MEG_CH in info_kinds: # refs same as MEG info_kinds |= set([FIFF.FIFFV_MEG_CH]) info_kinds -= set([FIFF.FIFFV_REF_MEG_CH]) info_kinds = sorted(info_kinds) # This should be guaranteed by the callers - assert (len(info_kinds) == 1 and info_kinds[0] in ( - FIFF.FIFFV_MEG_CH, FIFF.FIFFV_EEG_CH)) - kind = 'eeg' if info_kinds[0] == FIFF.FIFFV_EEG_CH else 'meg' + assert len(info_kinds) == 1 and info_kinds[0] in ( + FIFF.FIFFV_MEG_CH, + FIFF.FIFFV_EEG_CH, + ) + kind = "eeg" if info_kinds[0] == FIFF.FIFFV_EEG_CH else "meg" # # Step 1. Prepare the coil definitions # - if kind == 'meg': + if kind == "meg": templates = _read_coil_defs(verbose=False) - coils_from = _create_meg_coils(info_from['chs'], 'normal', - info_from['dev_head_t'], templates) - coils_to = _create_meg_coils(info_to['chs'], 'normal', - info_to['dev_head_t'], templates) - pinv_method = 'tsvd' + coils_from = _create_meg_coils( + info_from["chs"], "normal", info_from["dev_head_t"], templates + ) + coils_to = _create_meg_coils( + info_to["chs"], "normal", info_to["dev_head_t"], templates + ) + pinv_method = "tsvd" miss = 1e-4 else: - coils_from = _create_eeg_els(info_from['chs']) - coils_to = _create_eeg_els(info_to['chs']) - pinv_method = 'tikhonov' + coils_from = _create_eeg_els(info_from["chs"]) + coils_to = _create_eeg_els(info_to["chs"]) + pinv_method = "tikhonov" miss = 1e-1 - if _has_eeg_average_ref_proj(info_from) and \ - not _has_eeg_average_ref_proj(info_to): + if _has_eeg_average_ref_proj(info_from) and not _has_eeg_average_ref_proj( + info_to + ): raise RuntimeError( - 'info_to must have an average EEG reference projector if ' - 'info_from has one') + "info_to must have an average EEG reference projector if " + "info_from has one" + ) origin = _check_origin(origin, info_from) # # Step 2. 
     #
-    int_rad, noise, lut_fun, n_fact = _setup_dots(
-        mode, info_from, coils_from, kind)
-    logger.info(f'    Computing dot products for {len(coils_from)} '
-                f'{kind.upper()} channel{_pl(coils_from)}...')
-    self_dots = _do_self_dots(int_rad, False, coils_from, origin, kind,
-                              lut_fun, n_fact, n_jobs=None)
-    logger.info(f'    Computing cross products for {len(coils_from)} → '
-                f'{len(coils_to)} {kind.upper()} channel{_pl(coils_to)}...')
-    cross_dots = _do_cross_dots(int_rad, False, coils_from, coils_to,
-                                origin, kind, lut_fun, n_fact).T
-
-    ch_names = [c['ch_name'] for c in info_from['chs']]
-    fmd = dict(kind=kind, ch_names=ch_names,
-               origin=origin, noise=noise, self_dots=self_dots,
-               surface_dots=cross_dots, int_rad=int_rad, miss=miss,
-               pinv_method=pinv_method)
+    int_rad, noise, lut_fun, n_fact = _setup_dots(mode, info_from, coils_from, kind)
+    logger.info(
+        f"    Computing dot products for {len(coils_from)} "
+        f"{kind.upper()} channel{_pl(coils_from)}..."
+    )
+    self_dots = _do_self_dots(
+        int_rad, False, coils_from, origin, kind, lut_fun, n_fact, n_jobs=None
+    )
+    logger.info(
+        f"    Computing cross products for {len(coils_from)} → "
+        f"{len(coils_to)} {kind.upper()} channel{_pl(coils_to)}..."
+    )
+    cross_dots = _do_cross_dots(
+        int_rad, False, coils_from, coils_to, origin, kind, lut_fun, n_fact
+    ).T
+
+    ch_names = [c["ch_name"] for c in info_from["chs"]]
+    fmd = dict(
+        kind=kind,
+        ch_names=ch_names,
+        origin=origin,
+        noise=noise,
+        self_dots=self_dots,
+        surface_dots=cross_dots,
+        int_rad=int_rad,
+        miss=miss,
+        pinv_method=pinv_method,
+    )
 
     #
     # Step 3. Compute the mapping matrix
@@ -191,7 +218,7 @@ def _map_meg_or_eeg_channels(info_from, info_to, mode, origin, miss=None):
     return mapping
 
 
-def _as_meg_type_inst(inst, ch_type='grad', mode='fast'):
+def _as_meg_type_inst(inst, ch_type="grad", mode="fast"):
     """Compute virtual evoked using interpolated fields in mag/grad channels.
 
     Parameters
@@ -210,30 +237,31 @@
     inst : instance of mne.EvokedArray or mne.EpochsArray
         The transformed evoked object containing only virtual channels.
     """
-    _check_option('ch_type', ch_type, ['mag', 'grad'])
+    _check_option("ch_type", ch_type, ["mag", "grad"])
 
     # pick the original and destination channels
-    pick_from = pick_types(inst.info, meg=True, eeg=False,
-                           ref_meg=False)
-    pick_to = pick_types(inst.info, meg=ch_type, eeg=False,
-                         ref_meg=False)
+    pick_from = pick_types(inst.info, meg=True, eeg=False, ref_meg=False)
+    pick_to = pick_types(inst.info, meg=ch_type, eeg=False, ref_meg=False)
 
     if len(pick_to) == 0:
-        raise ValueError('No channels matching the destination channel type'
-                         ' found in info. Please pass an evoked containing'
-                         'both the original and destination channels. Only the'
-                         ' locations of the destination channels will be used'
-                         ' for interpolation.')
+        raise ValueError(
+            "No channels matching the destination channel type"
+            " found in info. Please pass an evoked containing"
+            " both the original and destination channels. Only the"
+            " locations of the destination channels will be used"
+            " for interpolation."
+ ) info_from = pick_info(inst.info, pick_from) info_to = pick_info(inst.info, pick_to) # XXX someday we should probably expose the origin mapping = _map_meg_or_eeg_channels( - info_from, info_to, origin=(0., 0., 0.04), mode=mode) + info_from, info_to, origin=(0.0, 0.0, 0.04), mode=mode + ) # compute data by multiplying by the 'gain matrix' from # original sensors to virtual sensors - if hasattr(inst, 'get_data'): + if hasattr(inst, "get_data"): data = inst.get_data() else: data = inst.data @@ -242,8 +270,7 @@ def _as_meg_type_inst(inst, ch_type='grad', mode='fast'): if ndim == 2: data = data[np.newaxis, :, :] - data_ = np.empty((data.shape[0], len(mapping), data.shape[2]), - dtype=data.dtype) + data_ = np.empty((data.shape[0], len(mapping), data.shape[2]), dtype=data.dtype) for d, d_ in zip(data, data_): d_[:] = np.dot(mapping, d[pick_from]) @@ -251,28 +278,41 @@ def _as_meg_type_inst(inst, ch_type='grad', mode='fast'): info = pick_info(inst.info, sel=pick_to, copy=True) # change channel names to emphasize they contain interpolated data - for ch in info['chs']: - ch['ch_name'] += '_v' + for ch in info["chs"]: + ch["ch_name"] += "_v" info._update_redundant() info._check_consistency() if isinstance(inst, Evoked): assert ndim == 2 data_ = data_[0] # undo new axis - inst_ = EvokedArray(data_, info, tmin=inst.times[0], - comment=inst.comment, nave=inst.nave) + inst_ = EvokedArray( + data_, info, tmin=inst.times[0], comment=inst.comment, nave=inst.nave + ) else: assert isinstance(inst, BaseEpochs) - inst_ = EpochsArray(data_, info, tmin=inst.tmin, - events=inst.events, - event_id=inst.event_id, - metadata=inst.metadata) + inst_ = EpochsArray( + data_, + info, + tmin=inst.tmin, + events=inst.events, + event_id=inst.event_id, + metadata=inst.metadata, + ) return inst_ @verbose -def _make_surface_mapping(info, surf, ch_type='meg', trans=None, mode='fast', - n_jobs=None, origin=(0., 0., 0.04), verbose=None): +def _make_surface_mapping( + info, + surf, + ch_type="meg", + trans=None, + mode="fast", + n_jobs=None, + origin=(0.0, 0.0, 0.04), + verbose=None, +): """Re-map M/EEG data to a surface. Parameters @@ -303,88 +343,108 @@ def _make_surface_mapping(info, surf, ch_type='meg', trans=None, mode='fast', A n_vertices x n_sensors array that remaps the MEG or EEG data, as `new_data = np.dot(mapping, data)`. """ - if not all(key in surf for key in ['rr', 'nn']): + if not all(key in surf for key in ["rr", "nn"]): raise KeyError('surf must have both "rr" and "nn"') - if 'coord_frame' not in surf: - raise KeyError('The surface coordinate frame must be specified ' - 'in surf["coord_frame"]') - _check_option('mode', mode, ['accurate', 'fast']) + if "coord_frame" not in surf: + raise KeyError( + "The surface coordinate frame must be specified " 'in surf["coord_frame"]' + ) + _check_option("mode", mode, ["accurate", "fast"]) # deal with coordinate frames here -- always go to "head" (easiest) orig_surf = surf - surf = transform_surface_to(deepcopy(surf), 'head', trans) + surf = transform_surface_to(deepcopy(surf), "head", trans) origin = _check_origin(origin, info) # # Step 1. 
Prepare the coil definitions # Do the dot products, assume surf in head coords # - _check_option('ch_type', ch_type, ['meg', 'eeg']) - if ch_type == 'meg': + _check_option("ch_type", ch_type, ["meg", "eeg"]) + if ch_type == "meg": picks = pick_types(info, meg=True, eeg=False, ref_meg=False) - logger.info('Prepare MEG mapping...') + logger.info("Prepare MEG mapping...") else: picks = pick_types(info, meg=False, eeg=True, ref_meg=False) - logger.info('Prepare EEG mapping...') + logger.info("Prepare EEG mapping...") if len(picks) == 0: - raise RuntimeError('cannot map, no channels found') + raise RuntimeError("cannot map, no channels found") # XXX this code does not do any checking for compensation channels, # but it seems like this must be intentional from the ref_meg=False # (presumably from the C code) - dev_head_t = info['dev_head_t'] + dev_head_t = info["dev_head_t"] info = pick_info(_simplify_info(info), picks) - info['dev_head_t'] = dev_head_t + info["dev_head_t"] = dev_head_t # create coil defs in head coordinates - if ch_type == 'meg': + if ch_type == "meg": # Put them in head coordinates - coils = _create_meg_coils(info['chs'], 'normal', info['dev_head_t']) - type_str = 'coils' + coils = _create_meg_coils(info["chs"], "normal", info["dev_head_t"]) + type_str = "coils" miss = 1e-4 # Smoothing criterion for MEG else: # EEG - coils = _create_eeg_els(info['chs']) - type_str = 'electrodes' + coils = _create_eeg_els(info["chs"]) + type_str = "electrodes" miss = 1e-3 # Smoothing criterion for EEG # # Step 2. Calculate the dot products # int_rad, noise, lut_fun, n_fact = _setup_dots(mode, info, coils, ch_type) - logger.info('Computing dot products for %i %s...' % (len(coils), type_str)) - self_dots = _do_self_dots(int_rad, False, coils, origin, ch_type, - lut_fun, n_fact, n_jobs) - sel = np.arange(len(surf['rr'])) # eventually we should do sub-selection - logger.info('Computing dot products for %i surface locations...' - % len(sel)) - surface_dots = _do_surface_dots(int_rad, False, coils, surf, sel, - origin, ch_type, lut_fun, n_fact, - n_jobs) + logger.info("Computing dot products for %i %s..." % (len(coils), type_str)) + self_dots = _do_self_dots( + int_rad, False, coils, origin, ch_type, lut_fun, n_fact, n_jobs + ) + sel = np.arange(len(surf["rr"])) # eventually we should do sub-selection + logger.info("Computing dot products for %i surface locations..." % len(sel)) + surface_dots = _do_surface_dots( + int_rad, False, coils, surf, sel, origin, ch_type, lut_fun, n_fact, n_jobs + ) # # Step 4. 
Return the result # - fmd = dict(kind=ch_type, surf=surf, ch_names=info['ch_names'], coils=coils, - origin=origin, noise=noise, self_dots=self_dots, - surface_dots=surface_dots, int_rad=int_rad, miss=miss) - logger.info('Field mapping data ready') - - fmd['data'] = _compute_mapping_matrix(fmd, info) + fmd = dict( + kind=ch_type, + surf=surf, + ch_names=info["ch_names"], + coils=coils, + origin=origin, + noise=noise, + self_dots=self_dots, + surface_dots=surface_dots, + int_rad=int_rad, + miss=miss, + ) + logger.info("Field mapping data ready") + + fmd["data"] = _compute_mapping_matrix(fmd, info) # bring the original back, whatever coord frame it was in - fmd['surf'] = orig_surf + fmd["surf"] = orig_surf # Remove some unnecessary fields - del fmd['self_dots'] - del fmd['surface_dots'] - del fmd['int_rad'] - del fmd['miss'] + del fmd["self_dots"] + del fmd["surface_dots"] + del fmd["int_rad"] + del fmd["miss"] return fmd @verbose -def make_field_map(evoked, trans='auto', subject=None, subjects_dir=None, - ch_type=None, mode='fast', meg_surf='helmet', - origin=(0., 0., 0.04), n_jobs=None, *, - head_source=('bem', 'head'), verbose=None): +def make_field_map( + evoked, + trans="auto", + subject=None, + subjects_dir=None, + ch_type=None, + mode="fast", + meg_surf="helmet", + origin=(0.0, 0.0, 0.04), + n_jobs=None, + *, + head_source=("bem", "head"), + verbose=None, +): """Compute surface maps used for field display in 3D. Parameters @@ -433,9 +493,9 @@ def make_field_map(evoked, trans='auto', subject=None, subjects_dir=None, info = evoked.info if ch_type is None: - types = [t for t in ['eeg', 'meg'] if t in evoked] + types = [t for t in ["eeg", "meg"] if t in evoked] else: - _check_option('ch_type', ch_type, ['eeg', 'meg']) + _check_option("ch_type", ch_type, ["eeg", "meg"]) types = [ch_type] if subjects_dir is not None: @@ -446,35 +506,40 @@ def make_field_map(evoked, trans='auto', subject=None, subjects_dir=None, name="subjects_dir", need_dir=True, ) - if isinstance(trans, str) and trans == 'auto': + if isinstance(trans, str) and trans == "auto": # let's try to do this in MRI coordinates so they're easy to plot trans = _find_trans(subject, subjects_dir) - trans, trans_type = _get_trans(trans, fro='head', to='mri') + trans, trans_type = _get_trans(trans, fro="head", to="mri") - if 'eeg' in types and trans_type == 'identity': - logger.info('No trans file available. EEG data ignored.') - types.remove('eeg') + if "eeg" in types and trans_type == "identity": + logger.info("No trans file available. 
EEG data ignored.") + types.remove("eeg") if len(types) == 0: - raise RuntimeError('No data available for mapping.') + raise RuntimeError("No data available for mapping.") - _check_option('meg_surf', meg_surf, ['helmet', 'head']) + _check_option("meg_surf", meg_surf, ["helmet", "head"]) surfs = [] for this_type in types: - if this_type == 'meg' and meg_surf == 'helmet': + if this_type == "meg" and meg_surf == "helmet": surf = get_meg_helmet_surf(info, trans) else: - surf = get_head_surf( - subject, source=head_source, subjects_dir=subjects_dir) + surf = get_head_surf(subject, source=head_source, subjects_dir=subjects_dir) surfs.append(surf) surf_maps = list() for this_type, this_surf in zip(types, surfs): - this_map = _make_surface_mapping(evoked.info, this_surf, this_type, - trans, n_jobs=n_jobs, origin=origin, - mode=mode) + this_map = _make_surface_mapping( + evoked.info, + this_surf, + this_type, + trans, + n_jobs=n_jobs, + origin=origin, + mode=mode, + ) surf_maps.append(this_map) return surf_maps diff --git a/mne/forward/_lead_dots.py b/mne/forward/_lead_dots.py index a97bac9d660..3eda719ac59 100644 --- a/mne/forward/_lead_dots.py +++ b/mne/forward/_lead_dots.py @@ -20,6 +20,7 @@ ############################################################################## # FAST LEGENDRE (DERIVATIVE) POLYNOMIALS USING LOOKUP TABLE + def _next_legen_der(n, x, p0, p01, p0d, p0dd): """Compute the next Legendre polynomial and its derivatives.""" # only good for n > 1 ! @@ -46,50 +47,56 @@ def _get_legen_der(xx, n_coeff=100): p0dds[:2] = [0.0, 0.0] for n in range(2, n_coeff): p0s[n], p0ds[n], p0dds[n] = _next_legen_der( - n, x, p0s[n - 1], p0s[n - 2], p0ds[n - 1], p0dds[n - 1]) + n, x, p0s[n - 1], p0s[n - 2], p0ds[n - 1], p0dds[n - 1] + ) return coeffs @verbose -def _get_legen_table(ch_type, volume_integral=False, n_coeff=100, - n_interp=20000, force_calc=False, verbose=None): +def _get_legen_table( + ch_type, + volume_integral=False, + n_coeff=100, + n_interp=20000, + force_calc=False, + verbose=None, +): """Return a (generated) LUT of Legendre (derivative) polynomial coeffs.""" if n_interp % 2 != 0: - raise RuntimeError('n_interp must be even') - fname = op.join(_get_extra_data_path(), 'tables') + raise RuntimeError("n_interp must be even") + fname = op.join(_get_extra_data_path(), "tables") if not op.isdir(fname): # Updated due to API change (GH 1167) os.makedirs(fname) - if ch_type == 'meg': - fname = op.join(fname, 'legder_%s_%s.bin' % (n_coeff, n_interp)) + if ch_type == "meg": + fname = op.join(fname, "legder_%s_%s.bin" % (n_coeff, n_interp)) leg_fun = _get_legen_der - extra_str = ' derivative' + extra_str = " derivative" lut_shape = (n_interp + 1, n_coeff, 3) else: # 'eeg' - fname = op.join(fname, 'legval_%s_%s.bin' % (n_coeff, n_interp)) + fname = op.join(fname, "legval_%s_%s.bin" % (n_coeff, n_interp)) leg_fun = _get_legen - extra_str = '' + extra_str = "" lut_shape = (n_interp + 1, n_coeff) if not op.isfile(fname) or force_calc: - logger.info('Generating Legendre%s table...' % extra_str) + logger.info("Generating Legendre%s table..." % extra_str) x_interp = np.linspace(-1, 1, n_interp + 1) lut = leg_fun(x_interp, n_coeff).astype(np.float32) if not force_calc: - with open(fname, 'wb') as fid: + with open(fname, "wb") as fid: fid.write(lut.tobytes()) else: - logger.info('Reading Legendre%s table...' % extra_str) - with open(fname, 'rb', buffering=0) as fid: + logger.info("Reading Legendre%s table..." 
% extra_str) + with open(fname, "rb", buffering=0) as fid: lut = np.fromfile(fid, np.float32) lut.shape = lut_shape # we need this for the integration step n_fact = np.arange(1, n_coeff, dtype=float) - if ch_type == 'meg': + if ch_type == "meg": n_facts = list() # multn, then mult, then multn * (n + 1) if volume_integral: - n_facts.append(n_fact / ((2.0 * n_fact + 1.0) * - (2.0 * n_fact + 3.0))) + n_facts.append(n_fact / ((2.0 * n_fact + 1.0) * (2.0 * n_fact + 3.0))) else: n_facts.append(n_fact / (2.0 * n_fact + 1.0)) n_facts.append(n_facts[0] / (n_fact + 1.0)) @@ -167,8 +174,13 @@ def _comp_sums_meg(beta, ctheta, lut_fun, n_fact, volume_integral): bbeta = np.tile(beta[start:stop][np.newaxis], (n_fact.shape[0], 1)) bbeta[0] *= beta[start:stop] np.cumprod(bbeta, axis=0, out=bbeta) # run inplace - np.einsum('ji,jk,ijk->ki', bbeta, n_fact, lut_fun(ctheta[start:stop]), - out=sums[:, start:stop]) + np.einsum( + "ji,jk,ijk->ki", + bbeta, + n_fact, + lut_fun(ctheta[start:stop]), + out=sums[:, start:stop], + ) return sums @@ -179,8 +191,21 @@ def _comp_sums_meg(beta, ctheta, lut_fun, n_fact, volume_integral): _eeg_const = 1.0 / (4.0 * np.pi) -def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, - w1, w2s, volume_integral, lut, n_fact, ch_type): +def _fast_sphere_dot_r0( + r, + rr1_orig, + rr2s, + lr1, + lr2s, + cosmags1, + cosmags2s, + w1, + w2s, + volume_integral, + lut, + n_fact, + ch_type, +): """Lead field dot product computation for M/EEG in the sphere model. Parameters @@ -230,7 +255,7 @@ def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, cosmags2 = np.concatenate(cosmags2s) # outer product, sum over coords - ct = np.einsum('ik,jk->ij', rr1_orig, rr2) + ct = np.einsum("ik,jk->ij", rr1_orig, rr2) np.clip(ct, -1, 1, ct) # expand axes @@ -239,9 +264,10 @@ def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, lr1lr2 = lr1[:, np.newaxis] * lr2[np.newaxis, :] beta = (r * r) / lr1lr2 - if ch_type == 'meg': - sums = _comp_sums_meg(beta.flatten(), ct.flatten(), lut, n_fact, - volume_integral) + if ch_type == "meg": + sums = _comp_sums_meg( + beta.flatten(), ct.flatten(), lut, n_fact, volume_integral + ) sums.shape = (4,) + beta.shape # Accumulate the result, a little bit streamlined version @@ -252,21 +278,23 @@ def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, # n2c1 = np.sum(cosmags2 * rr1, axis=2) # n2c2 = np.sum(cosmags2 * rr2, axis=2) # n1n2 = np.sum(cosmags1 * cosmags2, axis=2) - n1c1 = np.einsum('ik,ijk->ij', cosmags1, rr1) - n1c2 = np.einsum('ik,ijk->ij', cosmags1, rr2) - n2c1 = np.einsum('jk,ijk->ij', cosmags2, rr1) - n2c2 = np.einsum('jk,ijk->ij', cosmags2, rr2) - n1n2 = np.einsum('ik,jk->ij', cosmags1, cosmags2) + n1c1 = np.einsum("ik,ijk->ij", cosmags1, rr1) + n1c2 = np.einsum("ik,ijk->ij", cosmags1, rr2) + n2c1 = np.einsum("jk,ijk->ij", cosmags2, rr1) + n2c2 = np.einsum("jk,ijk->ij", cosmags2, rr2) + n1n2 = np.einsum("ik,jk->ij", cosmags1, cosmags2) part1 = ct * n1c1 * n2c2 part2 = n1c1 * n2c1 + n1c2 * n2c2 - result = (n1c1 * n2c2 * sums[0] + - (2.0 * part1 - part2) * sums[1] + - (n1n2 + part1 - part2) * sums[2] + - (n1c2 - ct * n1c1) * (n2c1 - ct * n2c2) * sums[3]) + result = ( + n1c1 * n2c2 * sums[0] + + (2.0 * part1 - part2) * sums[1] + + (n1n2 + part1 - part2) * sums[2] + + (n1c2 - ct * n1c1) * (n2c1 - ct * n2c2) * sums[3] + ) # Give it a finishing touch! 
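The commented-out np.sum() lines retained in the hunk above document what the einsum contractions in _fast_sphere_dot_r0 compute. A minimal standalone check of that equivalence (an illustrative sketch only; the shapes and random arrays below are stand-ins, not MNE data or API):

    import numpy as np

    rng = np.random.default_rng(0)
    cosmags1 = rng.standard_normal((4, 3))  # stand-in for (n_coils1, 3) coil normals
    rr1 = rng.standard_normal((4, 5, 3))    # stand-in for (n_coils1, n_coils2, 3) unit vectors

    # "ik,ijk->ij": for each pair (i, j), dot cosmags1[i] with rr1[i, j],
    # i.e. the broadcasted sum over the coordinate axis
    n1c1 = np.einsum("ik,ijk->ij", cosmags1, rr1)
    assert np.allclose(n1c1, np.sum(cosmags1[:, np.newaxis, :] * rr1, axis=2))

    # "ik,jk->ij" is an ordinary matrix product against the transpose
    a, b = rng.standard_normal((4, 3)), rng.standard_normal((6, 3))
    assert np.allclose(np.einsum("ik,jk->ij", a, b), a @ b.T)

Note that the reformatting in these hunks only re-wraps the calls and switches quote style; the contraction strings themselves are unchanged.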
- result *= (_meg_const / lr1lr2) + result *= _meg_const / lr1lr2 if volume_integral: result *= r else: # 'eeg' @@ -281,7 +309,7 @@ def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, if w1 is not None: result *= w1[:, np.newaxis] for ii, w2 in enumerate(w2s): - out[ii] = np.sum(result[:, offset:offset + len(w2)], axis=sum_axis) + out[ii] = np.sum(result[:, offset : offset + len(w2)], axis=sum_axis) offset += len(w2) return out @@ -314,40 +342,52 @@ def _do_self_dots(intrad, volume, coils, r0, ch_type, lut, n_fact, n_jobs): products : array, shape (n_coils, n_coils) The integration products. """ - if ch_type == 'eeg': + if ch_type == "eeg": intrad = intrad * 0.7 # convert to normalized distances from expansion center - rmags = [coil['rmag'] - r0[np.newaxis, :] for coil in coils] + rmags = [coil["rmag"] - r0[np.newaxis, :] for coil in coils] rlens = [np.sqrt(np.sum(r * r, axis=1)) for r in rmags] rmags = [r / rl[:, np.newaxis] for r, rl in zip(rmags, rlens)] - cosmags = [coil['cosmag'] for coil in coils] - ws = [coil['w'] for coil in coils] + cosmags = [coil["cosmag"] for coil in coils] + ws = [coil["w"] for coil in coils] parallel, p_fun, n_jobs = parallel_func(_do_self_dots_subset, n_jobs) - prods = parallel(p_fun(intrad, rmags, rlens, cosmags, - ws, volume, lut, n_fact, ch_type, idx) - for idx in np.array_split(np.arange(len(rmags)), n_jobs)) + prods = parallel( + p_fun(intrad, rmags, rlens, cosmags, ws, volume, lut, n_fact, ch_type, idx) + for idx in np.array_split(np.arange(len(rmags)), n_jobs) + ) products = np.sum(prods, axis=0) return products -def _do_self_dots_subset(intrad, rmags, rlens, cosmags, ws, volume, lut, - n_fact, ch_type, idx): +def _do_self_dots_subset( + intrad, rmags, rlens, cosmags, ws, volume, lut, n_fact, ch_type, idx +): """Parallelize.""" # all possible combinations of two magnetometers products = np.zeros((len(rmags), len(rmags))) for ci1 in idx: ci2 = ci1 + 1 res = _fast_sphere_dot_r0( - intrad, rmags[ci1], rmags[:ci2], rlens[ci1], rlens[:ci2], - cosmags[ci1], cosmags[:ci2], ws[ci1], ws[:ci2], volume, lut, - n_fact, ch_type) + intrad, + rmags[ci1], + rmags[:ci2], + rlens[ci1], + rlens[:ci2], + cosmags[ci1], + cosmags[:ci2], + ws[ci1], + ws[:ci2], + volume, + lut, + n_fact, + ch_type, + ) products[ci1, :ci2] = res products[:ci2, ci1] = res return products -def _do_cross_dots(intrad, volume, coils1, coils2, r0, ch_type, - lut, n_fact): +def _do_cross_dots(intrad, volume, coils1, coils2, r0, ch_type, lut, n_fact): """Compute lead field dot product integrations between two coil sets. The code is a direct translation of MNE-C code found in @@ -378,10 +418,10 @@ def _do_cross_dots(intrad, volume, coils1, coils2, r0, ch_type, products : array, shape (n_coils, n_coils) The integration products. 
""" - if ch_type == 'eeg': + if ch_type == "eeg": intrad = intrad * 0.7 - rmags1 = [coil['rmag'] - r0[np.newaxis, :] for coil in coils1] - rmags2 = [coil['rmag'] - r0[np.newaxis, :] for coil in coils2] + rmags1 = [coil["rmag"] - r0[np.newaxis, :] for coil in coils1] + rmags2 = [coil["rmag"] - r0[np.newaxis, :] for coil in coils2] rlens1 = [np.sqrt(np.sum(r * r, axis=1)) for r in rmags1] rlens2 = [np.sqrt(np.sum(r * r, axis=1)) for r in rmags2] @@ -389,24 +429,37 @@ def _do_cross_dots(intrad, volume, coils1, coils2, r0, ch_type, rmags1 = [r / rl[:, np.newaxis] for r, rl in zip(rmags1, rlens1)] rmags2 = [r / rl[:, np.newaxis] for r, rl in zip(rmags2, rlens2)] - ws1 = [coil['w'] for coil in coils1] - ws2 = [coil['w'] for coil in coils2] + ws1 = [coil["w"] for coil in coils1] + ws2 = [coil["w"] for coil in coils2] - cosmags1 = [coil['cosmag'] for coil in coils1] - cosmags2 = [coil['cosmag'] for coil in coils2] + cosmags1 = [coil["cosmag"] for coil in coils1] + cosmags2 = [coil["cosmag"] for coil in coils2] products = np.zeros((len(rmags1), len(rmags2))) for ci1 in range(len(coils1)): res = _fast_sphere_dot_r0( - intrad, rmags1[ci1], rmags2, rlens1[ci1], rlens2, cosmags1[ci1], - cosmags2, ws1[ci1], ws2, volume, lut, n_fact, ch_type) + intrad, + rmags1[ci1], + rmags2, + rlens1[ci1], + rlens2, + cosmags1[ci1], + cosmags2, + ws1[ci1], + ws2, + volume, + lut, + n_fact, + ch_type, + ) products[ci1, :] = res return products @fill_doc -def _do_surface_dots(intrad, volume, coils, surf, sel, r0, ch_type, - lut, n_fact, n_jobs): +def _do_surface_dots( + intrad, volume, coils, surf, sel, r0, ch_type, lut, n_fact, n_jobs +): """Compute the map construction products. Parameters @@ -438,15 +491,15 @@ def _do_surface_dots(intrad, volume, coils, surf, sel, r0, ch_type, The integration products. 
""" # convert to normalized distances from expansion center - rmags = [coil['rmag'] - r0[np.newaxis, :] for coil in coils] + rmags = [coil["rmag"] - r0[np.newaxis, :] for coil in coils] rlens = [np.sqrt(np.sum(r * r, axis=1)) for r in rmags] rmags = [r / rl[:, np.newaxis] for r, rl in zip(rmags, rlens)] - cosmags = [coil['cosmag'] for coil in coils] - ws = [coil['w'] for coil in coils] + cosmags = [coil["cosmag"] for coil in coils] + ws = [coil["w"] for coil in coils] rref = None refl = None # virt_ref = False - if ch_type == 'eeg': + if ch_type == "eeg": intrad = intrad * 0.7 # The virtual ref code is untested and unused, so it is # commented out for now @@ -455,24 +508,54 @@ def _do_surface_dots(intrad, volume, coils, surf, sel, r0, ch_type, # refl = np.sqrt(np.sum(rref * rref, axis=1)) # rref /= refl[:, np.newaxis] - rsurf = surf['rr'][sel] - r0[np.newaxis, :] + rsurf = surf["rr"][sel] - r0[np.newaxis, :] lsurf = np.sqrt(np.sum(rsurf * rsurf, axis=1)) rsurf /= lsurf[:, np.newaxis] - this_nn = surf['nn'][sel] + this_nn = surf["nn"][sel] # loop over the coils parallel, p_fun, n_jobs = parallel_func(_do_surface_dots_subset, n_jobs) - prods = parallel(p_fun(intrad, rsurf, rmags, rref, refl, lsurf, rlens, - this_nn, cosmags, ws, volume, lut, n_fact, ch_type, - idx) - for idx in np.array_split(np.arange(len(rmags)), n_jobs)) + prods = parallel( + p_fun( + intrad, + rsurf, + rmags, + rref, + refl, + lsurf, + rlens, + this_nn, + cosmags, + ws, + volume, + lut, + n_fact, + ch_type, + idx, + ) + for idx in np.array_split(np.arange(len(rmags)), n_jobs) + ) products = np.sum(prods, axis=0) return products -def _do_surface_dots_subset(intrad, rsurf, rmags, rref, refl, lsurf, rlens, - this_nn, cosmags, ws, volume, lut, n_fact, ch_type, - idx): +def _do_surface_dots_subset( + intrad, + rsurf, + rmags, + rref, + refl, + lsurf, + rlens, + this_nn, + cosmags, + ws, + volume, + lut, + n_fact, + ch_type, + idx, +): """Parallelize. Parameters @@ -507,8 +590,20 @@ def _do_surface_dots_subset(intrad, rsurf, rmags, rref, refl, lsurf, rlens, The integration products. 
""" products = _fast_sphere_dot_r0( - intrad, rsurf, rmags, lsurf, rlens, this_nn, cosmags, None, ws, - volume, lut, n_fact, ch_type).T + intrad, + rsurf, + rmags, + lsurf, + rlens, + this_nn, + cosmags, + None, + ws, + volume, + lut, + n_fact, + ch_type, + ).T if rref is not None: raise NotImplementedError # we don't ever use this, isn't tested # vres = _fast_sphere_dot_r0( diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 34be8d023cd..44783393ab8 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -21,23 +21,35 @@ from ..io.compensator import get_current_comp, make_compensator from ..io.pick import _has_kit_refs, pick_types, pick_info from ..io.constants import FIFF, FWD -from ..transforms import (_ensure_trans, transform_surface_to, apply_trans, - _get_trans, _print_coord_trans, _coord_frame_name, - Transform, invert_transform) +from ..transforms import ( + _ensure_trans, + transform_surface_to, + apply_trans, + _get_trans, + _print_coord_trans, + _coord_frame_name, + Transform, + invert_transform, +) from ..utils import logger, verbose, warn, _pl, _validate_type, _check_fname -from ..source_space import (_ensure_src, _filter_source_spaces, - _make_discrete_source_space, _complete_vol_src) +from ..source_space import ( + _ensure_src, + _filter_source_spaces, + _make_discrete_source_space, + _complete_vol_src, +) from ..source_estimate import VolSourceEstimate from ..surface import _normalize_vectors, _CheckInside from ..bem import read_bem_solution, _bem_find_surface, ConductorModel -from .forward import (Forward, _merge_fwds, convert_forward_solution, - _FWD_ORDER) +from .forward import Forward, _merge_fwds, convert_forward_solution, _FWD_ORDER -_accuracy_dict = dict(point=FWD.COIL_ACCURACY_POINT, - normal=FWD.COIL_ACCURACY_NORMAL, - accurate=FWD.COIL_ACCURACY_ACCURATE) +_accuracy_dict = dict( + point=FWD.COIL_ACCURACY_POINT, + normal=FWD.COIL_ACCURACY_NORMAL, + accurate=FWD.COIL_ACCURACY_ACCURATE, +) _extra_coil_def_fname = None @@ -63,11 +75,11 @@ def _read_coil_defs(verbose=None): The global variable "_extra_coil_def_fname" can be used to prepend additional definitions. These are never added to the registry. 
""" - coil_dir = op.join(op.split(__file__)[0], '..', 'data') + coil_dir = op.join(op.split(__file__)[0], "..", "data") coils = list() if _extra_coil_def_fname is not None: coils += _read_coil_def_file(_extra_coil_def_fname, use_registry=False) - coils += _read_coil_def_file(op.join(coil_dir, 'coil_def.dat')) + coils += _read_coil_def_file(op.join(coil_dir, "coil_def.dat")) return coils @@ -81,23 +93,28 @@ def _read_coil_def_file(fname, use_registry=True): if not use_registry or fname not in _coil_registry: big_val = 0.5 coils = list() - with open(fname, 'r') as fid: + with open(fname, "r") as fid: lines = fid.readlines() lines = lines[::-1] while len(lines) > 0: line = lines.pop().strip() - if line[0] == '#' and len(line) > 0: + if line[0] == "#" and len(line) > 0: continue desc_start = line.find('"') desc_end = len(line) - 1 assert line.strip()[desc_end] == '"' desc = line[desc_start:desc_end] - vals = np.fromstring(line[:desc_start].strip(), - dtype=float, sep=' ') + vals = np.fromstring(line[:desc_start].strip(), dtype=float, sep=" ") assert len(vals) == 6 npts = int(vals[3]) - coil = dict(coil_type=vals[1], coil_class=vals[0], desc=desc, - accuracy=vals[2], size=vals[4], base=vals[5]) + coil = dict( + coil_type=vals[1], + coil_class=vals[0], + desc=desc, + accuracy=vals[2], + size=vals[4], + base=vals[5], + ) # get parameters of each component rmag = list() cosmag = list() @@ -105,13 +122,13 @@ def _read_coil_def_file(fname, use_registry=True): for p in range(npts): # get next non-comment line line = lines.pop() - while line[0] == '#': + while line[0] == "#": line = lines.pop() - vals = np.fromstring(line, sep=' ') + vals = np.fromstring(line, sep=" ") if len(vals) != 7: raise RuntimeError( - f'Could not interpret line {p + 1} as 7 points:\n' - f'{line}') + f"Could not interpret line {p + 1} as 7 points:\n" f"{line}" + ) # Read and verify data for each integration point w.append(vals[0]) rmag.append(vals[[1, 2, 3]]) @@ -119,11 +136,11 @@ def _read_coil_def_file(fname, use_registry=True): w = np.array(w) rmag = np.array(rmag) cosmag = np.array(cosmag) - size = np.sqrt(np.sum(cosmag ** 2, axis=1)) - if np.any(np.sqrt(np.sum(rmag ** 2, axis=1)) > big_val): - raise RuntimeError('Unreasonable integration point') + size = np.sqrt(np.sum(cosmag**2, axis=1)) + if np.any(np.sqrt(np.sum(rmag**2, axis=1)) > big_val): + raise RuntimeError("Unreasonable integration point") if np.any(size <= 0): - raise RuntimeError('Unreasonable normal') + raise RuntimeError("Unreasonable normal") cosmag /= size[:, np.newaxis] coil.update(dict(w=w, cosmag=cosmag, rmag=rmag)) coils.append(coil) @@ -131,70 +148,92 @@ def _read_coil_def_file(fname, use_registry=True): _coil_registry[fname] = coils if use_registry: coils = deepcopy(_coil_registry[fname]) - logger.info('%d coil definition%s read', len(coils), _pl(coils)) + logger.info("%d coil definition%s read", len(coils), _pl(coils)) return coils def _create_meg_coil(coilset, ch, acc, do_es): """Create a coil definition using templates, transform if necessary.""" # Also change the coordinate frame if so desired - if ch['kind'] not in [FIFF.FIFFV_MEG_CH, FIFF.FIFFV_REF_MEG_CH]: - raise RuntimeError('%s is not a MEG channel' % ch['ch_name']) + if ch["kind"] not in [FIFF.FIFFV_MEG_CH, FIFF.FIFFV_REF_MEG_CH]: + raise RuntimeError("%s is not a MEG channel" % ch["ch_name"]) # Simple linear search from the coil definitions for coil in coilset: - if coil['coil_type'] == (ch['coil_type'] & 0xFFFF) and \ - coil['accuracy'] == acc: + if coil["coil_type"] == (ch["coil_type"] & 
0xFFFF) and coil["accuracy"] == acc: break else: - raise RuntimeError('Desired coil definition not found ' - '(type = %d acc = %d)' % (ch['coil_type'], acc)) + raise RuntimeError( + "Desired coil definition not found " + "(type = %d acc = %d)" % (ch["coil_type"], acc) + ) # Apply a coordinate transformation if so desired - coil_trans = _loc_to_coil_trans(ch['loc']) + coil_trans = _loc_to_coil_trans(ch["loc"]) # Create the result - res = dict(chname=ch['ch_name'], coil_class=coil['coil_class'], - accuracy=coil['accuracy'], base=coil['base'], size=coil['size'], - type=ch['coil_type'], w=coil['w'], desc=coil['desc'], - coord_frame=FIFF.FIFFV_COORD_DEVICE, rmag_orig=coil['rmag'], - cosmag_orig=coil['cosmag'], coil_trans_orig=coil_trans, - r0=coil_trans[:3, 3], - rmag=apply_trans(coil_trans, coil['rmag']), - cosmag=apply_trans(coil_trans, coil['cosmag'], False)) + res = dict( + chname=ch["ch_name"], + coil_class=coil["coil_class"], + accuracy=coil["accuracy"], + base=coil["base"], + size=coil["size"], + type=ch["coil_type"], + w=coil["w"], + desc=coil["desc"], + coord_frame=FIFF.FIFFV_COORD_DEVICE, + rmag_orig=coil["rmag"], + cosmag_orig=coil["cosmag"], + coil_trans_orig=coil_trans, + r0=coil_trans[:3, 3], + rmag=apply_trans(coil_trans, coil["rmag"]), + cosmag=apply_trans(coil_trans, coil["cosmag"], False), + ) if do_es: - r0_exey = (np.dot(coil['rmag'][:, :2], coil_trans[:3, :2].T) + - coil_trans[:3, 3]) - res.update(ex=coil_trans[:3, 0], ey=coil_trans[:3, 1], - ez=coil_trans[:3, 2], r0_exey=r0_exey) + r0_exey = np.dot(coil["rmag"][:, :2], coil_trans[:3, :2].T) + coil_trans[:3, 3] + res.update( + ex=coil_trans[:3, 0], + ey=coil_trans[:3, 1], + ez=coil_trans[:3, 2], + r0_exey=r0_exey, + ) return res def _create_eeg_el(ch, t=None): """Create an electrode definition, transform coords if necessary.""" - if ch['kind'] != FIFF.FIFFV_EEG_CH: - raise RuntimeError('%s is not an EEG channel. Cannot create an ' - 'electrode definition.' % ch['ch_name']) + if ch["kind"] != FIFF.FIFFV_EEG_CH: + raise RuntimeError( + "%s is not an EEG channel. Cannot create an " + "electrode definition." 
% ch["ch_name"] + ) if t is None: - t = Transform('head', 'head') # identity, no change - if t.from_str != 'head': - raise RuntimeError('Inappropriate coordinate transformation') + t = Transform("head", "head") # identity, no change + if t.from_str != "head": + raise RuntimeError("Inappropriate coordinate transformation") - r0ex = _loc_to_eeg_loc(ch['loc']) + r0ex = _loc_to_eeg_loc(ch["loc"]) if r0ex.shape[1] == 1: # no reference - w = np.array([1.]) + w = np.array([1.0]) else: # has reference - w = np.array([1., -1.]) + w = np.array([1.0, -1.0]) # Optional coordinate transformation - r0ex = apply_trans(t['trans'], r0ex.T) + r0ex = apply_trans(t["trans"], r0ex.T) # The electrode location cosmag = r0ex.copy() _normalize_vectors(cosmag) - res = dict(chname=ch['ch_name'], coil_class=FWD.COILC_EEG, w=w, - accuracy=_accuracy_dict['normal'], type=ch['coil_type'], - coord_frame=t['to'], rmag=r0ex, cosmag=cosmag) + res = dict( + chname=ch["ch_name"], + coil_class=FWD.COILC_EEG, + w=w, + accuracy=_accuracy_dict["normal"], + type=ch["coil_type"], + coord_frame=t["to"], + rmag=r0ex, + cosmag=cosmag, + ) return res @@ -212,16 +251,24 @@ def _transform_orig_meg_coils(coils, t, do_es=True): if t is None: return for coil in coils: - coil_trans = np.dot(t['trans'], coil['coil_trans_orig']) + coil_trans = np.dot(t["trans"], coil["coil_trans_orig"]) coil.update( - coord_frame=t['to'], r0=coil_trans[:3, 3], - rmag=apply_trans(coil_trans, coil['rmag_orig']), - cosmag=apply_trans(coil_trans, coil['cosmag_orig'], False)) + coord_frame=t["to"], + r0=coil_trans[:3, 3], + rmag=apply_trans(coil_trans, coil["rmag_orig"]), + cosmag=apply_trans(coil_trans, coil["cosmag_orig"], False), + ) if do_es: - r0_exey = (np.dot(coil['rmag_orig'][:, :2], - coil_trans[:3, :2].T) + coil_trans[:3, 3]) - coil.update(ex=coil_trans[:3, 0], ey=coil_trans[:3, 1], - ez=coil_trans[:3, 2], r0_exey=r0_exey) + r0_exey = ( + np.dot(coil["rmag_orig"][:, :2], coil_trans[:3, :2].T) + + coil_trans[:3, 3] + ) + coil.update( + ex=coil_trans[:3, 0], + ey=coil_trans[:3, 1], + ez=coil_trans[:3, 2], + r0_exey=r0_exey, + ) def _create_eeg_els(chs): @@ -230,47 +277,58 @@ def _create_eeg_els(chs): @verbose -def _setup_bem(bem, bem_extra, neeg, mri_head_t, allow_none=False, - verbose=None): +def _setup_bem(bem, bem_extra, neeg, mri_head_t, allow_none=False, verbose=None): """Set up a BEM for forward computation, making a copy and modifying.""" if allow_none and bem is None: return None - logger.info('') - _validate_type(bem, ('path-like', ConductorModel), bem) + logger.info("") + _validate_type(bem, ("path-like", ConductorModel), bem) if not isinstance(bem, ConductorModel): - logger.info('Setting up the BEM model using %s...\n' % bem_extra) + logger.info("Setting up the BEM model using %s...\n" % bem_extra) bem = read_bem_solution(bem) else: bem = bem.copy() - if bem['is_sphere']: - logger.info('Using the sphere model.\n') - if len(bem['layers']) == 0 and neeg > 0: - raise RuntimeError('Spherical model has zero shells, cannot use ' - 'with EEG data') - if bem['coord_frame'] != FIFF.FIFFV_COORD_HEAD: - raise RuntimeError('Spherical model is not in head coordinates') + if bem["is_sphere"]: + logger.info("Using the sphere model.\n") + if len(bem["layers"]) == 0 and neeg > 0: + raise RuntimeError( + "Spherical model has zero shells, cannot use " "with EEG data" + ) + if bem["coord_frame"] != FIFF.FIFFV_COORD_HEAD: + raise RuntimeError("Spherical model is not in head coordinates") else: - if bem['surfs'][0]['coord_frame'] != FIFF.FIFFV_COORD_MRI: + if 
bem["surfs"][0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: + raise RuntimeError( + "BEM is in %s coordinates, should be in MRI" + % (_coord_frame_name(bem["surfs"][0]["coord_frame"]),) + ) + if neeg > 0 and len(bem["surfs"]) == 1: raise RuntimeError( - 'BEM is in %s coordinates, should be in MRI' - % (_coord_frame_name(bem['surfs'][0]['coord_frame']),)) - if neeg > 0 and len(bem['surfs']) == 1: - raise RuntimeError('Cannot use a homogeneous (1-layer BEM) model ' - 'for EEG forward calculations, consider ' - 'using a 3-layer BEM instead') - logger.info('Employing the head->MRI coordinate transform with the ' - 'BEM model.') + "Cannot use a homogeneous (1-layer BEM) model " + "for EEG forward calculations, consider " + "using a 3-layer BEM instead" + ) + logger.info( + "Employing the head->MRI coordinate transform with the " "BEM model." + ) # fwd_bem_set_head_mri_t: Set the coordinate transformation - bem['head_mri_t'] = _ensure_trans(mri_head_t, 'head', 'mri') - logger.info('BEM model %s is now set up' % op.split(bem_extra)[1]) - logger.info('') + bem["head_mri_t"] = _ensure_trans(mri_head_t, "head", "mri") + logger.info("BEM model %s is now set up" % op.split(bem_extra)[1]) + logger.info("") return bem @verbose -def _prep_meg_channels(info, accuracy='accurate', exclude=(), *, - ignore_ref=False, head_frame=True, do_es=False, - verbose=None): +def _prep_meg_channels( + info, + accuracy="accurate", + exclude=(), + *, + ignore_ref=False, + head_frame=True, + do_es=False, + verbose=None, +): """Prepare MEG coil definitions for forward calculation.""" # Find MEG channels ref_meg = True if not ignore_ref else False @@ -278,7 +336,7 @@ def _prep_meg_channels(info, accuracy='accurate', exclude=(), *, # Make sure MEG coils exist if len(picks) <= 0: - raise RuntimeError('Could not find any MEG channels') + raise RuntimeError("Could not find any MEG channels") info_meg = pick_info(info, picks) del picks @@ -287,95 +345,110 @@ def _prep_meg_channels(info, accuracy='accurate', exclude=(), *, # Get MEG compensation channels compensator = post_picks = None - ch_names = info_meg['ch_names'] + ch_names = info_meg["ch_names"] if not ignore_ref: ref_picks = pick_types(info, meg=False, ref_meg=True, exclude=exclude) ncomp = len(ref_picks) - if (ncomp > 0): - logger.info(f'Read {ncomp} MEG compensation channels from info') + if ncomp > 0: + logger.info(f"Read {ncomp} MEG compensation channels from info") # We need to check to make sure these are NOT KIT refs if _has_kit_refs(info, ref_picks): raise NotImplementedError( - 'Cannot create forward solution with KIT reference ' + "Cannot create forward solution with KIT reference " 'channels. Consider using "ignore_ref=True" in ' - 'calculation') - logger.info( - f'{len(info["comps"])} compensation data sets in info') + "calculation" + ) + logger.info(f'{len(info["comps"])} compensation data sets in info') # Compose a compensation data set if necessary # adapted from mne_make_ctf_comp() from mne_ctf_comp.c - logger.info('Setting up compensation data...') + logger.info("Setting up compensation data...") comp_num = get_current_comp(info) if comp_num is None or comp_num == 0: - logger.info(' No compensation set. Nothing more to do.') + logger.info(" No compensation set. 
Nothing more to do.") else: compensator = make_compensator( - info_meg, 0, comp_num, exclude_comp_chs=False) - logger.info( - f' Desired compensation data ({comp_num}) found.') - logger.info(' All compensation channels found.') - logger.info(' Preselector created.') - logger.info(' Compensation data matrix created.') - logger.info(' Postselector created.') - post_picks = pick_types( - info_meg, meg=True, ref_meg=False, exclude=exclude) + info_meg, 0, comp_num, exclude_comp_chs=False + ) + logger.info(f" Desired compensation data ({comp_num}) found.") + logger.info(" All compensation channels found.") + logger.info(" Preselector created.") + logger.info(" Compensation data matrix created.") + logger.info(" Postselector created.") + post_picks = pick_types(info_meg, meg=True, ref_meg=False, exclude=exclude) ch_names = [ch_names[pick] for pick in post_picks] # Create coil descriptions with transformation to head or device frame templates = _read_coil_defs() if head_frame: - _print_coord_trans(info['dev_head_t']) - transform = info['dev_head_t'] + _print_coord_trans(info["dev_head_t"]) + transform = info["dev_head_t"] else: transform = None megcoils = _create_meg_coils( - info_meg['chs'], accuracy, transform, templates, do_es=do_es) + info_meg["chs"], accuracy, transform, templates, do_es=do_es + ) # Check that coordinate frame is correct and log it if head_frame: - assert megcoils[0]['coord_frame'] == FIFF.FIFFV_COORD_HEAD - logger.info('MEG coil definitions created in head coordinates.') + assert megcoils[0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD + logger.info("MEG coil definitions created in head coordinates.") else: - assert megcoils[0]['coord_frame'] == FIFF.FIFFV_COORD_DEVICE - logger.info('MEG coil definitions created in device coordinate.') + assert megcoils[0]["coord_frame"] == FIFF.FIFFV_COORD_DEVICE + logger.info("MEG coil definitions created in device coordinate.") return dict( - defs=megcoils, ch_names=ch_names, compensator=compensator, - info=info_meg, post_picks=post_picks) + defs=megcoils, + ch_names=ch_names, + compensator=compensator, + info=info_meg, + post_picks=post_picks, + ) @verbose def _prep_eeg_channels(info, exclude=(), verbose=None): """Prepare EEG electrode definitions for forward calculation.""" - info_extra = 'info' + info_extra = "info" # Find EEG electrodes - picks = pick_types(info, meg=False, eeg=True, ref_meg=False, - exclude=exclude) + picks = pick_types(info, meg=False, eeg=True, ref_meg=False, exclude=exclude) # Make sure EEG electrodes exist neeg = len(picks) if neeg <= 0: - raise RuntimeError('Could not find any EEG channels') + raise RuntimeError("Could not find any EEG channels") # Get channel info and names for EEG channels - eegchs = pick_info(info, picks)['chs'] - eegnames = [info['ch_names'][p] for p in picks] - logger.info('Read %3d EEG channels from %s' % (len(picks), info_extra)) + eegchs = pick_info(info, picks)["chs"] + eegnames = [info["ch_names"][p] for p in picks] + logger.info("Read %3d EEG channels from %s" % (len(picks), info_extra)) # Create EEG electrode descriptions eegels = _create_eeg_els(eegchs) - logger.info('Head coordinate coil definitions created.') + logger.info("Head coordinate coil definitions created.") return dict(defs=eegels, ch_names=eegnames) @verbose -def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs, - bem_extra='', trans='', info_extra='', - meg=True, eeg=True, ignore_ref=False, - allow_bem_none=False, verbose=None): +def _prepare_for_forward( + src, + mri_head_t, + info, + bem, + mindist, + 
n_jobs, + bem_extra="", + trans="", + info_extra="", + meg=True, + eeg=True, + ignore_ref=False, + allow_bem_none=False, + verbose=None, +): """Prepare for forward computation. The sensors dict contains keys for each sensor type, e.g. 'meg', 'eeg'. @@ -389,116 +462,157 @@ def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs, compensator """ # Read the source locations - logger.info('') + logger.info("") # let's make a copy in case we modify something src = _ensure_src(src).copy() - nsource = sum(s['nuse'] for s in src) + nsource = sum(s["nuse"] for s in src) if nsource == 0: - raise RuntimeError('No sources are active in these source spaces. ' - '"do_all" option should be used.') - logger.info('Read %d source spaces a total of %d active source locations' - % (len(src), nsource)) + raise RuntimeError( + "No sources are active in these source spaces. " + '"do_all" option should be used.' + ) + logger.info( + "Read %d source spaces a total of %d active source locations" + % (len(src), nsource) + ) # Delete some keys to clean up the source space: - for key in ['working_dir', 'command_line']: + for key in ["working_dir", "command_line"]: if key in src.info: del src.info[key] # Read the MRI -> head coordinate transformation - logger.info('') + logger.info("") _print_coord_trans(mri_head_t) # make a new dict with the relevant information - arg_list = [info_extra, trans, src, bem_extra, meg, eeg, mindist, - n_jobs, verbose] - cmd = 'make_forward_solution(%s)' % (', '.join([str(a) for a in arg_list])) + arg_list = [info_extra, trans, src, bem_extra, meg, eeg, mindist, n_jobs, verbose] + cmd = "make_forward_solution(%s)" % (", ".join([str(a) for a in arg_list])) mri_id = dict(machid=np.zeros(2, np.int32), version=0, secs=0, usecs=0) info_trans = str(trans) if isinstance(trans, Path) else trans - info = Info(chs=info['chs'], comps=info['comps'], - dev_head_t=info['dev_head_t'], mri_file=info_trans, - mri_id=mri_id, - meas_file=info_extra, meas_id=None, working_dir=os.getcwd(), - command_line=cmd, bads=info['bads'], mri_head_t=mri_head_t) + info = Info( + chs=info["chs"], + comps=info["comps"], + dev_head_t=info["dev_head_t"], + mri_file=info_trans, + mri_id=mri_id, + meas_file=info_extra, + meas_id=None, + working_dir=os.getcwd(), + command_line=cmd, + bads=info["bads"], + mri_head_t=mri_head_t, + ) info._update_redundant() info._check_consistency() - logger.info('') + logger.info("") sensors = dict() if meg and len(pick_types(info, meg=True, ref_meg=False, exclude=[])) > 0: - sensors['meg'] = _prep_meg_channels(info, ignore_ref=ignore_ref) + sensors["meg"] = _prep_meg_channels(info, ignore_ref=ignore_ref) if eeg and len(pick_types(info, eeg=True, exclude=[])) > 0: - sensors['eeg'] = _prep_eeg_channels(info) + sensors["eeg"] = _prep_eeg_channels(info) # Check that some channels were found if len(sensors) == 0: - raise RuntimeError('No MEG or EEG channels found.') + raise RuntimeError("No MEG or EEG channels found.") # pick out final info - info = pick_info(info, pick_types(info, meg=meg, eeg=eeg, ref_meg=False, - exclude=[])) + info = pick_info( + info, pick_types(info, meg=meg, eeg=eeg, ref_meg=False, exclude=[]) + ) # Transform the source spaces into the appropriate coordinates # (will either be HEAD or MRI) for s in src: - transform_surface_to(s, 'head', mri_head_t) - logger.info('Source spaces are now in %s coordinates.' - % _coord_frame_name(s['coord_frame'])) + transform_surface_to(s, "head", mri_head_t) + logger.info( + "Source spaces are now in %s coordinates." 
% _coord_frame_name(s["coord_frame"]) + ) # Prepare the BEM model - eegnames = sensors.get('eeg', dict()).get('ch_names', []) - bem = _setup_bem(bem, bem_extra, len(eegnames), mri_head_t, - allow_none=allow_bem_none) + eegnames = sensors.get("eeg", dict()).get("ch_names", []) + bem = _setup_bem( + bem, bem_extra, len(eegnames), mri_head_t, allow_none=allow_bem_none + ) del eegnames # Circumvent numerical problems by excluding points too close to the skull, # and check that sensors are not inside any BEM surface if bem is not None: - if not bem['is_sphere']: - check_surface = 'inner skull surface' - inner_skull = _bem_find_surface(bem, 'inner_skull') + if not bem["is_sphere"]: + check_surface = "inner skull surface" + inner_skull = _bem_find_surface(bem, "inner_skull") check_inside = _filter_source_spaces( - inner_skull, mindist, mri_head_t, src, n_jobs) - logger.info('') - if len(bem['surfs']) == 3: - check_surface = 'scalp surface' - check_inside = _CheckInside(_bem_find_surface(bem, 'head')) + inner_skull, mindist, mri_head_t, src, n_jobs + ) + logger.info("") + if len(bem["surfs"]) == 3: + check_surface = "scalp surface" + check_inside = _CheckInside(_bem_find_surface(bem, "head")) else: - check_surface = 'outermost sphere shell' - if len(bem['layers']) == 0: + check_surface = "outermost sphere shell" + if len(bem["layers"]) == 0: + def check_inside(x): return np.zeros(len(x), bool) + else: + def check_inside(x): - return (np.linalg.norm(x - bem['r0'], axis=1) < - bem['layers'][-1]['rad']) - if 'meg' in sensors: + return ( + np.linalg.norm(x - bem["r0"], axis=1) < bem["layers"][-1]["rad"] + ) + + if "meg" in sensors: meg_loc = apply_trans( invert_transform(mri_head_t), - np.array([coil['r0'] for coil in sensors['meg']['defs']])) + np.array([coil["r0"] for coil in sensors["meg"]["defs"]]), + ) n_inside = check_inside(meg_loc).sum() if n_inside: raise RuntimeError( - f'Found {n_inside} MEG sensor{_pl(n_inside)} inside the ' - f'{check_surface}, perhaps coordinate frames and/or ' - 'coregistration must be incorrect') + f"Found {n_inside} MEG sensor{_pl(n_inside)} inside the " + f"{check_surface}, perhaps coordinate frames and/or " + "coregistration must be incorrect" + ) - rr = np.concatenate([s['rr'][s['vertno']] for s in src]) + rr = np.concatenate([s["rr"][s["vertno"]] for s in src]) if len(rr) < 1: - raise RuntimeError('No points left in source space after excluding ' - 'points close to inner skull.') + raise RuntimeError( + "No points left in source space after excluding " + "points close to inner skull." + ) # deal with free orientations: source_nn = np.tile(np.eye(3), (len(rr), 1)) - update_kwargs = dict(nchan=len(info['ch_names']), nsource=len(rr), - info=info, src=src, source_nn=source_nn, - source_rr=rr, surf_ori=False, mri_head_t=mri_head_t) + update_kwargs = dict( + nchan=len(info["ch_names"]), + nsource=len(rr), + info=info, + src=src, + source_nn=source_nn, + source_rr=rr, + surf_ori=False, + mri_head_t=mri_head_t, + ) return sensors, rr, info, update_kwargs, bem @verbose -def make_forward_solution(info, trans, src, bem, meg=True, eeg=True, *, - mindist=0.0, ignore_ref=False, n_jobs=None, - verbose=None): +def make_forward_solution( + info, + trans, + src, + bem, + meg=True, + eeg=True, + *, + mindist=0.0, + ignore_ref=False, + n_jobs=None, + verbose=None, +): """Calculate a forward solution for a subject. 
Parameters @@ -561,61 +675,72 @@ def make_forward_solution(info, trans, src, bem, meg=True, eeg=True, *, # (could also be HEAD to MRI) mri_head_t, trans = _get_trans(trans) if isinstance(bem, ConductorModel): - bem_extra = 'instance of ConductorModel' + bem_extra = "instance of ConductorModel" else: bem_extra = bem - _validate_type(info, ('path-like', Info), 'info') + _validate_type(info, ("path-like", Info), "info") if not isinstance(info, Info): info_extra = op.split(info)[1] - info = _check_fname(info, must_exist=True, overwrite='read', - name='info') + info = _check_fname(info, must_exist=True, overwrite="read", name="info") info = read_info(info, verbose=False) else: - info_extra = 'instance of Info' + info_extra = "instance of Info" # Report the setup - logger.info('Source space : %s' % src) - logger.info('MRI -> head transform : %s' % trans) - logger.info('Measurement data : %s' % info_extra) - if isinstance(bem, ConductorModel) and bem['is_sphere']: - logger.info('Sphere model : origin at %s mm' - % (bem['r0'],)) - logger.info('Standard field computations') + logger.info("Source space : %s" % src) + logger.info("MRI -> head transform : %s" % trans) + logger.info("Measurement data : %s" % info_extra) + if isinstance(bem, ConductorModel) and bem["is_sphere"]: + logger.info("Sphere model : origin at %s mm" % (bem["r0"],)) + logger.info("Standard field computations") else: - logger.info('Conductor model : %s' % bem_extra) - logger.info('Accurate field computations') - logger.info('Do computations in %s coordinates', - _coord_frame_name(FIFF.FIFFV_COORD_HEAD)) - logger.info('Free source orientations') + logger.info("Conductor model : %s" % bem_extra) + logger.info("Accurate field computations") + logger.info( + "Do computations in %s coordinates", _coord_frame_name(FIFF.FIFFV_COORD_HEAD) + ) + logger.info("Free source orientations") # Create MEG coils and EEG electrodes in the head coordinate frame sensors, rr, info, update_kwargs, bem = _prepare_for_forward( - src, mri_head_t, info, bem, mindist, n_jobs, bem_extra, trans, - info_extra, meg, eeg, ignore_ref) - del (src, mri_head_t, trans, info_extra, bem_extra, mindist, - meg, eeg, ignore_ref) + src, + mri_head_t, + info, + bem, + mindist, + n_jobs, + bem_extra, + trans, + info_extra, + meg, + eeg, + ignore_ref, + ) + del (src, mri_head_t, trans, info_extra, bem_extra, mindist, meg, eeg, ignore_ref) # Time to do the heavy lifting: MEG first, then EEG fwds = _compute_forwards(rr, bem=bem, sensors=sensors, n_jobs=n_jobs) # merge forwards - fwds = {key: _to_forward_dict(fwds[key], sensors[key]['ch_names']) - for key in _FWD_ORDER if key in fwds} + fwds = { + key: _to_forward_dict(fwds[key], sensors[key]["ch_names"]) + for key in _FWD_ORDER + if key in fwds + } fwd = _merge_fwds(fwds, verbose=False) del fwds - logger.info('') + logger.info("") # Don't transform the source spaces back into MRI coordinates (which is # done in the C code) because mne-python assumes forward solution source # spaces are in head coords. fwd.update(**update_kwargs) - logger.info('Finished.') + logger.info("Finished.") return fwd @verbose -def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, - verbose=None): +def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, verbose=None): """Convert dipole object to source estimate and calculate forward operator. 
The instance of Dipole is converted to a discrete source space, @@ -662,6 +787,7 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, """ if isinstance(dipole, list): from ..dipole import _concatenate_dipoles # To avoid circular import + dipole = _concatenate_dipoles(dipole) # Make copies to avoid mangling original dipole @@ -674,31 +800,29 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, # NB information about dipole orientation enters here, then no more sources = dict(rr=pos, nn=ori) # Dipole objects must be in the head frame - src = _complete_vol_src( - [_make_discrete_source_space(sources, coord_frame='head')]) + src = _complete_vol_src([_make_discrete_source_space(sources, coord_frame="head")]) # Forward operator created for channels in info (use pick_info to restrict) # Use defaults for most params, including min_dist - fwd = make_forward_solution(info, trans, src, bem, n_jobs=n_jobs, - verbose=verbose) + fwd = make_forward_solution(info, trans, src, bem, n_jobs=n_jobs, verbose=verbose) # Convert from free orientations to fixed (in-place) - convert_forward_solution(fwd, surf_ori=False, force_fixed=True, - copy=False, use_cps=False, verbose=None) + convert_forward_solution( + fwd, surf_ori=False, force_fixed=True, copy=False, use_cps=False, verbose=None + ) # Check for omissions due to proximity to inner skull in # make_forward_solution, which will result in an exception - if fwd['src'][0]['nuse'] != len(pos): - inuse = fwd['src'][0]['inuse'].astype(bool) - head = ('The following dipoles are outside the inner skull boundary') - msg = len(head) * '#' + '\n' + head + '\n' - for (t, pos) in zip(times[np.logical_not(inuse)], - pos[np.logical_not(inuse)]): - msg += ' t={:.0f} ms, pos=({:.0f}, {:.0f}, {:.0f}) mm\n'.\ - format(t * 1000., pos[0] * 1000., - pos[1] * 1000., pos[2] * 1000.) - msg += len(head) * '#' + if fwd["src"][0]["nuse"] != len(pos): + inuse = fwd["src"][0]["inuse"].astype(bool) + head = "The following dipoles are outside the inner skull boundary" + msg = len(head) * "#" + "\n" + head + "\n" + for t, pos in zip(times[np.logical_not(inuse)], pos[np.logical_not(inuse)]): + msg += " t={:.0f} ms, pos=({:.0f}, {:.0f}, {:.0f}) mm\n".format( + t * 1000.0, pos[0] * 1000.0, pos[1] * 1000.0, pos[2] * 1000.0 + ) + msg += len(head) * "#" logger.error(msg) - raise ValueError('One or more dipoles outside the inner skull.') + raise ValueError("One or more dipoles outside the inner skull.") # multiple dipoles (rr and nn) per time instant allowed # uneven sampling in time returns list @@ -706,8 +830,10 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, if len(timepoints) > 1: tdiff = np.diff(timepoints) if not np.allclose(tdiff, tdiff[0]): - warn('Unique time points of dipoles unevenly spaced: returned ' - 'stc will be a list, one for each time point.') + warn( + "Unique time points of dipoles unevenly spaced: returned " + "stc will be a list, one for each time point." 
+ ) tstep = -1.0 else: tstep = tdiff[0] @@ -722,39 +848,64 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, row = 0 for tpind, tp in enumerate(timepoints): amp = amplitude[np.in1d(times, tp)] - data[row:row + len(amp), tpind] = amp + data[row : row + len(amp), tpind] = amp row += len(amp) if tstep > 0: - stc = VolSourceEstimate(data, vertices=[fwd['src'][0]['vertno']], - tmin=timepoints[0], - tstep=tstep, subject=None) + stc = VolSourceEstimate( + data, + vertices=[fwd["src"][0]["vertno"]], + tmin=timepoints[0], + tstep=tstep, + subject=None, + ) else: # Must return a list of stc, one for each time point stc = [] for col, tp in enumerate(timepoints): - stc += [VolSourceEstimate(data[:, col][:, np.newaxis], - vertices=[fwd['src'][0]['vertno']], - tmin=tp, tstep=0.001, subject=None)] + stc += [ + VolSourceEstimate( + data[:, col][:, np.newaxis], + vertices=[fwd["src"][0]["vertno"]], + tmin=tp, + tstep=0.001, + subject=None, + ) + ] return fwd, stc -def _to_forward_dict(fwd, names, fwd_grad=None, - coord_frame=FIFF.FIFFV_COORD_HEAD, - source_ori=FIFF.FIFFV_MNE_FREE_ORI): +def _to_forward_dict( + fwd, + names, + fwd_grad=None, + coord_frame=FIFF.FIFFV_COORD_HEAD, + source_ori=FIFF.FIFFV_MNE_FREE_ORI, +): """Convert forward solution matrices to dicts.""" assert names is not None - sol = dict(data=fwd.T, nrow=fwd.shape[1], ncol=fwd.shape[0], - row_names=names, col_names=[]) - fwd = Forward(sol=sol, source_ori=source_ori, nsource=sol['ncol'], - coord_frame=coord_frame, sol_grad=None, - nchan=sol['nrow'], _orig_source_ori=source_ori, - _orig_sol=sol['data'].copy(), _orig_sol_grad=None) + sol = dict( + data=fwd.T, nrow=fwd.shape[1], ncol=fwd.shape[0], row_names=names, col_names=[] + ) + fwd = Forward( + sol=sol, + source_ori=source_ori, + nsource=sol["ncol"], + coord_frame=coord_frame, + sol_grad=None, + nchan=sol["nrow"], + _orig_source_ori=source_ori, + _orig_sol=sol["data"].copy(), + _orig_sol_grad=None, + ) if fwd_grad is not None: - sol_grad = dict(data=fwd_grad.T, nrow=fwd_grad.shape[1], - ncol=fwd_grad.shape[0], row_names=names, - col_names=[]) - fwd.update(dict(sol_grad=sol_grad), - _orig_sol_grad=sol_grad['data'].copy()) + sol_grad = dict( + data=fwd_grad.T, + nrow=fwd_grad.shape[1], + ncol=fwd_grad.shape[0], + row_names=names, + col_names=[], + ) + fwd.update(dict(sol_grad=sol_grad), _orig_sol_grad=sol_grad["data"].copy()) return fwd diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 17ed07f8ac4..2f1e1c0b89d 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -25,30 +25,57 @@ from ..io.open import fiff_open from ..io.tree import dir_tree_find from ..io.tag import find_tag, read_tag -from ..io.matrix import (_read_named_matrix, _transpose_named_matrix, - write_named_matrix) -from ..io.meas_info import (_read_bad_channels, write_info, _write_ch_infos, - _read_extended_ch_info, _make_ch_names_mapping, - _write_bad_channels) -from ..io.pick import (pick_channels_forward, pick_info, pick_channels, - pick_types) -from ..io.write import (write_int, start_block, end_block, write_coord_trans, - write_string, start_and_end_file, write_id) +from ..io.matrix import _read_named_matrix, _transpose_named_matrix, write_named_matrix +from ..io.meas_info import ( + _read_bad_channels, + write_info, + _write_ch_infos, + _read_extended_ch_info, + _make_ch_names_mapping, + _write_bad_channels, +) +from ..io.pick import pick_channels_forward, pick_info, pick_channels, pick_types +from ..io.write import ( + write_int, + start_block, + end_block, + 
write_coord_trans, + write_string, + start_and_end_file, + write_id, +) from ..io.base import BaseRaw from ..evoked import Evoked, EvokedArray from ..epochs import BaseEpochs -from ..source_space import (_read_source_spaces_from_tree, - find_source_space_hemi, _set_source_space_vertices, - _write_source_spaces_to_fid, _get_src_nn, - _src_kind_dict) +from ..source_space import ( + _read_source_spaces_from_tree, + find_source_space_hemi, + _set_source_space_vertices, + _write_source_spaces_to_fid, + _get_src_nn, + _src_kind_dict, +) from ..source_estimate import _BaseVectorSourceEstimate, _BaseSourceEstimate from ..surface import _normal_orth -from ..transforms import (transform_surface_to, invert_transform, - write_trans) -from ..utils import (_check_fname, get_subjects_dir, has_mne_c, warn, - run_subprocess, check_fname, logger, verbose, fill_doc, - _validate_type, _check_compensation_grade, _check_option, - _check_stc_units, _stamp_to_dt, _on_missing, repr_html) +from ..transforms import transform_surface_to, invert_transform, write_trans +from ..utils import ( + _check_fname, + get_subjects_dir, + has_mne_c, + warn, + run_subprocess, + check_fname, + logger, + verbose, + fill_doc, + _validate_type, + _check_compensation_grade, + _check_option, + _check_stc_units, + _stamp_to_dt, + _on_missing, + repr_html, +) from ..label import Label @@ -133,69 +160,73 @@ def copy(self): return Forward(deepcopy(self)) def _get_src_type_and_ori_for_repr(self): - src_types = np.array([src['type'] for src in self['src']]) - - if (src_types == 'surf').all(): - src_type = 'Surface with %d vertices' % self['nsource'] - elif (src_types == 'vol').all(): - src_type = 'Volume with %d grid points' % self['nsource'] - elif (src_types == 'discrete').all(): - src_type = 'Discrete with %d dipoles' % self['nsource'] + src_types = np.array([src["type"] for src in self["src"]]) + + if (src_types == "surf").all(): + src_type = "Surface with %d vertices" % self["nsource"] + elif (src_types == "vol").all(): + src_type = "Volume with %d grid points" % self["nsource"] + elif (src_types == "discrete").all(): + src_type = "Discrete with %d dipoles" % self["nsource"] else: - count_string = '' - if (src_types == 'surf').any(): - count_string += '%d surface, ' % (src_types == 'surf').sum() - if (src_types == 'vol').any(): - count_string += '%d volume, ' % (src_types == 'vol').sum() - if (src_types == 'discrete').any(): - count_string += '%d discrete, ' \ - % (src_types == 'discrete').sum() - count_string = count_string.rstrip(', ') - src_type = ('Mixed (%s) with %d vertices' - % (count_string, self['nsource'])) - - if self['source_ori'] == FIFF.FIFFV_MNE_UNKNOWN_ORI: - src_ori = 'Unknown' - elif self['source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI: - src_ori = 'Fixed' - elif self['source_ori'] == FIFF.FIFFV_MNE_FREE_ORI: - src_ori = 'Free' + count_string = "" + if (src_types == "surf").any(): + count_string += "%d surface, " % (src_types == "surf").sum() + if (src_types == "vol").any(): + count_string += "%d volume, " % (src_types == "vol").sum() + if (src_types == "discrete").any(): + count_string += "%d discrete, " % (src_types == "discrete").sum() + count_string = count_string.rstrip(", ") + src_type = "Mixed (%s) with %d vertices" % (count_string, self["nsource"]) + + if self["source_ori"] == FIFF.FIFFV_MNE_UNKNOWN_ORI: + src_ori = "Unknown" + elif self["source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI: + src_ori = "Fixed" + elif self["source_ori"] == FIFF.FIFFV_MNE_FREE_ORI: + src_ori = "Free" return src_type, src_ori def __repr__(self): 
"""Summarize forward info instead of printing all.""" - entr = ' 0: - raise ValueError('Width of matrix must be a multiple of n') + raise ValueError("Width of matrix must be a multiple of n") tmp = np.arange(ma * bdn, dtype=np.int64).reshape(bdn, ma) tmp = np.tile(tmp, (1, n)) @@ -279,7 +311,7 @@ def _get_tag_int(fid, node, name, id_): tag = find_tag(fid, node, id_) if tag is None: fid.close() - raise ValueError(name + ' tag not found') + raise ValueError(name + " tag not found") return int(tag.data.item()) @@ -290,42 +322,44 @@ def _read_one(fid, node): return None one = Forward() - one['source_ori'] = _get_tag_int(fid, node, 'Source orientation', - FIFF.FIFF_MNE_SOURCE_ORIENTATION) - one['coord_frame'] = _get_tag_int(fid, node, 'Coordinate frame', - FIFF.FIFF_MNE_COORD_FRAME) - one['nsource'] = _get_tag_int(fid, node, 'Number of sources', - FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS) - one['nchan'] = _get_tag_int(fid, node, 'Number of channels', - FIFF.FIFF_NCHAN) + one["source_ori"] = _get_tag_int( + fid, node, "Source orientation", FIFF.FIFF_MNE_SOURCE_ORIENTATION + ) + one["coord_frame"] = _get_tag_int( + fid, node, "Coordinate frame", FIFF.FIFF_MNE_COORD_FRAME + ) + one["nsource"] = _get_tag_int( + fid, node, "Number of sources", FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS + ) + one["nchan"] = _get_tag_int(fid, node, "Number of channels", FIFF.FIFF_NCHAN) try: - one['sol'] = _read_named_matrix(fid, node, - FIFF.FIFF_MNE_FORWARD_SOLUTION, - transpose=True) - one['_orig_sol'] = one['sol']['data'].copy() + one["sol"] = _read_named_matrix( + fid, node, FIFF.FIFF_MNE_FORWARD_SOLUTION, transpose=True + ) + one["_orig_sol"] = one["sol"]["data"].copy() except Exception: - logger.error('Forward solution data not found') + logger.error("Forward solution data not found") raise try: fwd_type = FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD - one['sol_grad'] = _read_named_matrix(fid, node, fwd_type, - transpose=True) - one['_orig_sol_grad'] = one['sol_grad']['data'].copy() + one["sol_grad"] = _read_named_matrix(fid, node, fwd_type, transpose=True) + one["_orig_sol_grad"] = one["sol_grad"]["data"].copy() except Exception: - one['sol_grad'] = None + one["sol_grad"] = None - if one['sol']['data'].shape[0] != one['nchan'] or \ - (one['sol']['data'].shape[1] != one['nsource'] and - one['sol']['data'].shape[1] != 3 * one['nsource']): - raise ValueError('Forward solution matrix has wrong dimensions') + if one["sol"]["data"].shape[0] != one["nchan"] or ( + one["sol"]["data"].shape[1] != one["nsource"] + and one["sol"]["data"].shape[1] != 3 * one["nsource"] + ): + raise ValueError("Forward solution matrix has wrong dimensions") - if one['sol_grad'] is not None: - if one['sol_grad']['data'].shape[0] != one['nchan'] or \ - (one['sol_grad']['data'].shape[1] != 3 * one['nsource'] and - one['sol_grad']['data'].shape[1] != 3 * 3 * one['nsource']): - raise ValueError('Forward solution gradient matrix has ' - 'wrong dimensions') + if one["sol_grad"] is not None: + if one["sol_grad"]["data"].shape[0] != one["nchan"] or ( + one["sol_grad"]["data"].shape[1] != 3 * one["nsource"] + and one["sol_grad"]["data"].shape[1] != 3 * 3 * one["nsource"] + ): + raise ValueError("Forward solution gradient matrix has " "wrong dimensions") return one @@ -352,30 +386,30 @@ def _read_forward_meas_info(tree, fid): # Information from the MRI file parent_mri = dir_tree_find(tree, FIFF.FIFFB_MNE_PARENT_MRI_FILE) if len(parent_mri) == 0: - raise ValueError('No parent MEG information found in operator') + raise ValueError("No parent MEG information found in 
operator") parent_mri = parent_mri[0] tag = find_tag(fid, parent_mri, FIFF.FIFF_MNE_FILE_NAME) - info['mri_file'] = tag.data if tag is not None else None + info["mri_file"] = tag.data if tag is not None else None tag = find_tag(fid, parent_mri, FIFF.FIFF_PARENT_FILE_ID) - info['mri_id'] = tag.data if tag is not None else None + info["mri_id"] = tag.data if tag is not None else None # Information from the MEG file parent_meg = dir_tree_find(tree, FIFF.FIFFB_MNE_PARENT_MEAS_FILE) if len(parent_meg) == 0: - raise ValueError('No parent MEG information found in operator') + raise ValueError("No parent MEG information found in operator") parent_meg = parent_meg[0] tag = find_tag(fid, parent_meg, FIFF.FIFF_MNE_FILE_NAME) - info['meas_file'] = tag.data if tag is not None else None + info["meas_file"] = tag.data if tag is not None else None tag = find_tag(fid, parent_meg, FIFF.FIFF_PARENT_FILE_ID) - info['meas_id'] = tag.data if tag is not None else None + info["meas_id"] = tag.data if tag is not None else None # Add channel information - info['chs'] = chs = list() - for k in range(parent_meg['nent']): - kind = parent_meg['directory'][k].kind - pos = parent_meg['directory'][k].pos + info["chs"] = chs = list() + for k in range(parent_meg["nent"]): + kind = parent_meg["directory"][k].kind + pos = parent_meg["directory"][k].pos if kind == FIFF.FIFF_CH_INFO: tag = read_tag(fid, pos) chs.append(tag.data) @@ -389,51 +423,50 @@ def _read_forward_meas_info(tree, fid): coord_device = FIFF.FIFFV_COORD_DEVICE coord_ctf_head = FIFF.FIFFV_MNE_COORD_CTF_HEAD if tag is None: - raise ValueError('MRI/head coordinate transformation not found') + raise ValueError("MRI/head coordinate transformation not found") cand = tag.data - if cand['from'] == coord_mri and cand['to'] == coord_head: - info['mri_head_t'] = cand + if cand["from"] == coord_mri and cand["to"] == coord_head: + info["mri_head_t"] = cand else: - raise ValueError('MRI/head coordinate transformation not found') + raise ValueError("MRI/head coordinate transformation not found") # Get the MEG device <-> head coordinate transformation tag = find_tag(fid, parent_meg, FIFF.FIFF_COORD_TRANS) if tag is None: - raise ValueError('MEG/head coordinate transformation not found') + raise ValueError("MEG/head coordinate transformation not found") cand = tag.data - if cand['from'] == coord_device and cand['to'] == coord_head: - info['dev_head_t'] = cand - elif cand['from'] == coord_ctf_head and cand['to'] == coord_head: - info['ctf_head_t'] = cand + if cand["from"] == coord_device and cand["to"] == coord_head: + info["dev_head_t"] = cand + elif cand["from"] == coord_ctf_head and cand["to"] == coord_head: + info["ctf_head_t"] = cand else: - raise ValueError('MEG/head coordinate transformation not found') + raise ValueError("MEG/head coordinate transformation not found") - info['bads'] = _read_bad_channels( - fid, parent_meg, ch_names_mapping=ch_names_mapping) + info["bads"] = _read_bad_channels( + fid, parent_meg, ch_names_mapping=ch_names_mapping + ) # clean up our bad list, old versions could have non-existent bads - info['bads'] = [bad for bad in info['bads'] if bad in info['ch_names']] + info["bads"] = [bad for bad in info["bads"] if bad in info["ch_names"]] # Check if a custom reference has been applied tag = find_tag(fid, parent_mri, FIFF.FIFF_MNE_CUSTOM_REF) if tag is None: tag = find_tag(fid, parent_mri, 236) # Constant 236 used before v0.11 - info['custom_ref_applied'] = ( - int(tag.data.item()) if tag is not None else False - ) + info["custom_ref_applied"] = 
int(tag.data.item()) if tag is not None else False info._unlocked = False return info def _subject_from_forward(forward): """Get subject id from inverse operator.""" - return forward['src']._subject + return forward["src"]._subject # This sets the forward solution order (and gives human-readable names) _FWD_ORDER = dict( - meg='MEG', - eeg='EEG', + meg="MEG", + eeg="EEG", ) @@ -455,28 +488,30 @@ def _merge_fwds(fwds, *, verbose=None): b = fwds[key] a_kind, b_kind = _FWD_ORDER[first_key], _FWD_ORDER[key] combined.append(b_kind) - if (a['sol']['data'].shape[1] != b['sol']['data'].shape[1] or - a['source_ori'] != b['source_ori'] or - a['nsource'] != b['nsource'] or - a['coord_frame'] != b['coord_frame']): + if ( + a["sol"]["data"].shape[1] != b["sol"]["data"].shape[1] + or a["source_ori"] != b["source_ori"] + or a["nsource"] != b["nsource"] + or a["coord_frame"] != b["coord_frame"] + ): raise ValueError( - f'The {a_kind} and {b_kind} forward solutions do not match') - for k in ('sol', 'sol_grad'): + f"The {a_kind} and {b_kind} forward solutions do not match" + ) + for k in ("sol", "sol_grad"): if a[k] is None: continue - a[k]['data'] = np.r_[a[k]['data'], b[k]['data']] - a[f'_orig_{k}'] = np.r_[a[f'_orig_{k}'], b[f'_orig_{k}']] - a[k]['nrow'] = a[k]['nrow'] + b[k]['nrow'] - a[k]['row_names'] = a[k]['row_names'] + b[k]['row_names'] - a['nchan'] = a['nchan'] + b['nchan'] + a[k]["data"] = np.r_[a[k]["data"], b[k]["data"]] + a[f"_orig_{k}"] = np.r_[a[f"_orig_{k}"], b[f"_orig_{k}"]] + a[k]["nrow"] = a[k]["nrow"] + b[k]["nrow"] + a[k]["row_names"] = a[k]["row_names"] + b[k]["row_names"] + a["nchan"] = a["nchan"] + b["nchan"] if len(fwds) > 1: logger.info(f' Forward solutions combined: {", ".join(combined)}') return fwd @verbose -def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, - verbose=None): +def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbose=None): """Read a forward solution a.k.a. lead field. Parameters @@ -515,27 +550,28 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, surface-based, fixed orientation cannot be reverted after loading the forward solution with :func:`read_forward_solution`. """ - check_fname(fname, 'forward', ('-fwd.fif', '-fwd.fif.gz', - '_fwd.fif', '_fwd.fif.gz')) - fname = _check_fname(fname=fname, must_exist=True, overwrite='read') + check_fname( + fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz") + ) + fname = _check_fname(fname=fname, must_exist=True, overwrite="read") # Open the file, create directory - logger.info('Reading forward solution from %s...' % fname) + logger.info("Reading forward solution from %s..." 
% fname) f, tree, _ = fiff_open(fname) with f as fid: # Find all forward solutions fwds = dir_tree_find(tree, FIFF.FIFFB_MNE_FORWARD_SOLUTION) if len(fwds) == 0: - raise ValueError('No forward solutions in %s' % fname) + raise ValueError("No forward solutions in %s" % fname) # Parent MRI data parent_mri = dir_tree_find(tree, FIFF.FIFFB_MNE_PARENT_MRI_FILE) if len(parent_mri) == 0: - raise ValueError('No parent MRI information in %s' % fname) + raise ValueError("No parent MRI information in %s" % fname) parent_mri = parent_mri[0] src = _read_source_spaces_from_tree(fid, tree, patch_stats=False) for s in src: - s['id'] = find_source_space_hemi(s) + s["id"] = find_source_space_hemi(s) fwd = None @@ -545,8 +581,9 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, for k in range(len(fwds)): tag = find_tag(fid, fwds[k], FIFF.FIFF_MNE_INCLUDED_METHODS) if tag is None: - raise ValueError('Methods not listed for one of the forward ' - 'solutions') + raise ValueError( + "Methods not listed for one of the forward " "solutions" + ) if tag.data == FIFF.FIFFV_MNE_MEG: megnode = fwds[k] @@ -556,26 +593,30 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, fwds = dict() megfwd = _read_one(fid, megnode) if megfwd is not None: - fwds['meg'] = megfwd + fwds["meg"] = megfwd if is_fixed_orient(megfwd): - ori = 'fixed' + ori = "fixed" else: - ori = 'free' - logger.info(' Read MEG forward solution (%d sources, ' - '%d channels, %s orientations)' - % (megfwd['nsource'], megfwd['nchan'], ori)) + ori = "free" + logger.info( + " Read MEG forward solution (%d sources, " + "%d channels, %s orientations)" + % (megfwd["nsource"], megfwd["nchan"], ori) + ) del megfwd eegfwd = _read_one(fid, eegnode) if eegfwd is not None: - fwds['eeg'] = eegfwd + fwds["eeg"] = eegfwd if is_fixed_orient(eegfwd): - ori = 'fixed' + ori = "fixed" else: - ori = 'free' - logger.info(' Read EEG forward solution (%d sources, ' - '%d channels, %s orientations)' - % (eegfwd['nsource'], eegfwd['nchan'], ori)) + ori = "free" + logger.info( + " Read EEG forward solution (%d sources, " + "%d channels, %s orientations)" + % (eegfwd["nsource"], eegfwd["nchan"], ori) + ) del eegfwd fwd = _merge_fwds(fwds) @@ -584,22 +625,25 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, # Get the MRI <-> head coordinate transformation tag = find_tag(fid, parent_mri, FIFF.FIFF_COORD_TRANS) if tag is None: - raise ValueError('MRI/head coordinate transformation not found') + raise ValueError("MRI/head coordinate transformation not found") mri_head_t = tag.data - if (mri_head_t['from'] != FIFF.FIFFV_COORD_MRI or - mri_head_t['to'] != FIFF.FIFFV_COORD_HEAD): + if ( + mri_head_t["from"] != FIFF.FIFFV_COORD_MRI + or mri_head_t["to"] != FIFF.FIFFV_COORD_HEAD + ): mri_head_t = invert_transform(mri_head_t) - if (mri_head_t['from'] != FIFF.FIFFV_COORD_MRI or - mri_head_t['to'] != FIFF.FIFFV_COORD_HEAD): + if ( + mri_head_t["from"] != FIFF.FIFFV_COORD_MRI + or mri_head_t["to"] != FIFF.FIFFV_COORD_HEAD + ): fid.close() - raise ValueError('MRI/head coordinate transformation not ' - 'found') - fwd['mri_head_t'] = mri_head_t + raise ValueError("MRI/head coordinate transformation not " "found") + fwd["mri_head_t"] = mri_head_t # # get parent MEG info # - fwd['info'] = _read_forward_meas_info(tree, fid) + fwd["info"] = _read_forward_meas_info(tree, fid) # MNE environment parent_env = dir_tree_find(tree, FIFF.FIFFB_MNE_ENV) @@ -607,20 +651,22 @@ def read_forward_solution(fname, include=(), exclude=(), *, 
ordered=None, parent_env = parent_env[0] tag = find_tag(fid, parent_env, FIFF.FIFF_MNE_ENV_WORKING_DIR) if tag is not None: - with fwd['info']._unlock(): - fwd['info']['working_dir'] = tag.data + with fwd["info"]._unlock(): + fwd["info"]["working_dir"] = tag.data tag = find_tag(fid, parent_env, FIFF.FIFF_MNE_ENV_COMMAND_LINE) if tag is not None: - with fwd['info']._unlock(): - fwd['info']['command_line'] = tag.data + with fwd["info"]._unlock(): + fwd["info"]["command_line"] = tag.data # Transform the source spaces to the correct coordinate frame # if necessary # Make sure forward solution is in either the MRI or HEAD coordinate frame - if fwd['coord_frame'] not in (FIFF.FIFFV_COORD_MRI, FIFF.FIFFV_COORD_HEAD): - raise ValueError('Only forward solutions computed in MRI or head ' - 'coordinates are acceptable') + if fwd["coord_frame"] not in (FIFF.FIFFV_COORD_MRI, FIFF.FIFFV_COORD_HEAD): + raise ValueError( + "Only forward solutions computed in MRI or head " + "coordinates are acceptable" + ) # Transform each source space to the HEAD or MRI coordinate frame, # depending on the coordinate frame of the forward solution @@ -629,45 +675,47 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, nuse = 0 for s in src: try: - s = transform_surface_to(s, fwd['coord_frame'], mri_head_t) + s = transform_surface_to(s, fwd["coord_frame"], mri_head_t) except Exception as inst: - raise ValueError('Could not transform source space (%s)' % inst) + raise ValueError("Could not transform source space (%s)" % inst) - nuse += s['nuse'] + nuse += s["nuse"] # Make sure the number of sources match after transformation - if nuse != fwd['nsource']: - raise ValueError('Source spaces do not match the forward solution.') + if nuse != fwd["nsource"]: + raise ValueError("Source spaces do not match the forward solution.") - logger.info(' Source spaces transformed to the forward solution ' - 'coordinate frame') - fwd['src'] = src + logger.info( + " Source spaces transformed to the forward solution " "coordinate frame" + ) + fwd["src"] = src # Handle the source locations and orientations - fwd['source_rr'] = np.concatenate([ss['rr'][ss['vertno'], :] - for ss in src], axis=0) + fwd["source_rr"] = np.concatenate([ss["rr"][ss["vertno"], :] for ss in src], axis=0) # Store original source orientations - fwd['_orig_source_ori'] = fwd['source_ori'] + fwd["_orig_source_ori"] = fwd["source_ori"] # Deal with include and exclude pick_channels_forward(fwd, include=include, exclude=exclude, copy=False) if is_fixed_orient(fwd, orig=True): - fwd['source_nn'] = np.concatenate([_src['nn'][_src['vertno'], :] - for _src in fwd['src']], axis=0) - fwd['source_ori'] = FIFF.FIFFV_MNE_FIXED_ORI - fwd['surf_ori'] = True + fwd["source_nn"] = np.concatenate( + [_src["nn"][_src["vertno"], :] for _src in fwd["src"]], axis=0 + ) + fwd["source_ori"] = FIFF.FIFFV_MNE_FIXED_ORI + fwd["surf_ori"] = True else: - fwd['source_nn'] = np.kron(np.ones((fwd['nsource'], 1)), np.eye(3)) - fwd['source_ori'] = FIFF.FIFFV_MNE_FREE_ORI - fwd['surf_ori'] = False + fwd["source_nn"] = np.kron(np.ones((fwd["nsource"], 1)), np.eye(3)) + fwd["source_ori"] = FIFF.FIFFV_MNE_FREE_ORI + fwd["surf_ori"] = False return Forward(fwd) @verbose -def convert_forward_solution(fwd, surf_ori=False, force_fixed=False, - copy=True, use_cps=True, *, verbose=None): +def convert_forward_solution( + fwd, surf_ori=False, force_fixed=False, copy=True, use_cps=True, *, verbose=None +): """Convert forward solution between different source orientations. 
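Operates on a copy of ``fwd`` by default; pass ``copy=False`` to convert the
operator in place.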
Parameters @@ -690,28 +738,34 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False, The modified forward solution. """ from scipy import sparse + fwd = fwd.copy() if copy else fwd if force_fixed is True: surf_ori = True - if any([src['type'] == 'vol' for src in fwd['src']]) and force_fixed: + if any([src["type"] == "vol" for src in fwd["src"]]) and force_fixed: raise ValueError( - 'Forward operator was generated with sources from a ' - 'volume source space. Conversion to fixed orientation is not ' - 'possible. Consider using a discrete source space if you have ' - 'meaningful normal orientations.') + "Forward operator was generated with sources from a " + "volume source space. Conversion to fixed orientation is not " + "possible. Consider using a discrete source space if you have " + "meaningful normal orientations." + ) if surf_ori and use_cps: - if any(s.get('patch_inds') is not None for s in fwd['src']): - logger.info(' Average patch normals will be employed in ' - 'the rotation to the local surface coordinates..' - '..') + if any(s.get("patch_inds") is not None for s in fwd["src"]): + logger.info( + " Average patch normals will be employed in " + "the rotation to the local surface coordinates.." + ".." + ) else: use_cps = False - logger.info(' No patch info available. The standard source ' - 'space normals will be employed in the rotation ' - 'to the local surface coordinates....') + logger.info( + " No patch info available. The standard source " + "space normals will be employed in the rotation " + "to the local surface coordinates...." + ) # We need to change these entries (only): # 1. source_nn @@ -723,78 +777,79 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False, if is_fixed_orient(fwd, orig=True) or (force_fixed and not use_cps): # Fixed - fwd['source_nn'] = np.concatenate([_get_src_nn(s, use_cps) - for s in fwd['src']], axis=0) + fwd["source_nn"] = np.concatenate( + [_get_src_nn(s, use_cps) for s in fwd["src"]], axis=0 + ) if not is_fixed_orient(fwd, orig=True): - logger.info(' Changing to fixed-orientation forward ' - 'solution with surface-based source orientations...') - fix_rot = _block_diag(fwd['source_nn'].T, 1) + logger.info( + " Changing to fixed-orientation forward " + "solution with surface-based source orientations..." 
+ ) + fix_rot = _block_diag(fwd["source_nn"].T, 1) # newer versions of numpy require explicit casting here, so *= no # longer works - fwd['sol']['data'] = (fwd['_orig_sol'] * - fix_rot).astype('float32') - fwd['sol']['ncol'] = fwd['nsource'] - if fwd['sol_grad'] is not None: + fwd["sol"]["data"] = (fwd["_orig_sol"] * fix_rot).astype("float32") + fwd["sol"]["ncol"] = fwd["nsource"] + if fwd["sol_grad"] is not None: x = sparse.block_diag([fix_rot] * 3) - fwd['sol_grad']['data'] = fwd['_orig_sol_grad'] * x # dot prod - fwd['sol_grad']['ncol'] = 3 * fwd['nsource'] - fwd['source_ori'] = FIFF.FIFFV_MNE_FIXED_ORI - fwd['surf_ori'] = True + fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] * x # dot prod + fwd["sol_grad"]["ncol"] = 3 * fwd["nsource"] + fwd["source_ori"] = FIFF.FIFFV_MNE_FIXED_ORI + fwd["surf_ori"] = True elif surf_ori: # Free, surf-oriented # Rotate the local source coordinate systems - fwd['source_nn'] = np.kron(np.ones((fwd['nsource'], 1)), np.eye(3)) - logger.info(' Converting to surface-based source orientations...') + fwd["source_nn"] = np.kron(np.ones((fwd["nsource"], 1)), np.eye(3)) + logger.info(" Converting to surface-based source orientations...") # Actually determine the source orientations pp = 0 - for s in fwd['src']: - if s['type'] in ['surf', 'discrete']: + for s in fwd["src"]: + if s["type"] in ["surf", "discrete"]: nn = _get_src_nn(s, use_cps) - stop = pp + 3 * s['nuse'] - fwd['source_nn'][pp:stop] = _normal_orth(nn).reshape(-1, 3) + stop = pp + 3 * s["nuse"] + fwd["source_nn"][pp:stop] = _normal_orth(nn).reshape(-1, 3) pp = stop del nn else: - pp += 3 * s['nuse'] + pp += 3 * s["nuse"] # Rotate the solution components as well if force_fixed: - fwd['source_nn'] = fwd['source_nn'][2::3, :] - fix_rot = _block_diag(fwd['source_nn'].T, 1) + fwd["source_nn"] = fwd["source_nn"][2::3, :] + fix_rot = _block_diag(fwd["source_nn"].T, 1) # newer versions of numpy require explicit casting here, so *= no # longer works - fwd['sol']['data'] = (fwd['_orig_sol'] * - fix_rot).astype('float32') - fwd['sol']['ncol'] = fwd['nsource'] - if fwd['sol_grad'] is not None: + fwd["sol"]["data"] = (fwd["_orig_sol"] * fix_rot).astype("float32") + fwd["sol"]["ncol"] = fwd["nsource"] + if fwd["sol_grad"] is not None: x = sparse.block_diag([fix_rot] * 3) - fwd['sol_grad']['data'] = fwd['_orig_sol_grad'] * x # dot prod - fwd['sol_grad']['ncol'] = 3 * fwd['nsource'] - fwd['source_ori'] = FIFF.FIFFV_MNE_FIXED_ORI - fwd['surf_ori'] = True + fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] * x # dot prod + fwd["sol_grad"]["ncol"] = 3 * fwd["nsource"] + fwd["source_ori"] = FIFF.FIFFV_MNE_FIXED_ORI + fwd["surf_ori"] = True else: - surf_rot = _block_diag(fwd['source_nn'].T, 3) - fwd['sol']['data'] = fwd['_orig_sol'] * surf_rot - fwd['sol']['ncol'] = 3 * fwd['nsource'] - if fwd['sol_grad'] is not None: + surf_rot = _block_diag(fwd["source_nn"].T, 3) + fwd["sol"]["data"] = fwd["_orig_sol"] * surf_rot + fwd["sol"]["ncol"] = 3 * fwd["nsource"] + if fwd["sol_grad"] is not None: x = sparse.block_diag([surf_rot] * 3) - fwd['sol_grad']['data'] = fwd['_orig_sol_grad'] * x # dot prod - fwd['sol_grad']['ncol'] = 9 * fwd['nsource'] - fwd['source_ori'] = FIFF.FIFFV_MNE_FREE_ORI - fwd['surf_ori'] = True + fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] * x # dot prod + fwd["sol_grad"]["ncol"] = 9 * fwd["nsource"] + fwd["source_ori"] = FIFF.FIFFV_MNE_FREE_ORI + fwd["surf_ori"] = True else: # Free, cartesian - logger.info(' Cartesian source orientations...') - fwd['source_nn'] = np.tile(np.eye(3), (fwd['nsource'], 1)) - 
fwd['sol']['data'] = fwd['_orig_sol'].copy() - fwd['sol']['ncol'] = 3 * fwd['nsource'] - if fwd['sol_grad'] is not None: - fwd['sol_grad']['data'] = fwd['_orig_sol_grad'].copy() - fwd['sol_grad']['ncol'] = 9 * fwd['nsource'] - fwd['source_ori'] = FIFF.FIFFV_MNE_FREE_ORI - fwd['surf_ori'] = False - - logger.info(' [done]') + logger.info(" Cartesian source orientations...") + fwd["source_nn"] = np.tile(np.eye(3), (fwd["nsource"], 1)) + fwd["sol"]["data"] = fwd["_orig_sol"].copy() + fwd["sol"]["ncol"] = 3 * fwd["nsource"] + if fwd["sol_grad"] is not None: + fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"].copy() + fwd["sol_grad"]["ncol"] = 9 * fwd["nsource"] + fwd["source_ori"] = FIFF.FIFFV_MNE_FREE_ORI + fwd["surf_ori"] = False + + logger.info(" [done]") return fwd @@ -832,8 +887,9 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None): surface-based, fixed orientation cannot be reverted after loading the forward solution with :func:`read_forward_solution`. """ - check_fname(fname, 'forward', ('-fwd.fif', '-fwd.fif.gz', - '_fwd.fif', '_fwd.fif.gz')) + check_fname( + fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz") + ) # check for file existence and expand `~` if present fname = _check_fname(fname, overwrite) @@ -849,10 +905,10 @@ def _write_forward_solution(fid, fwd): # start_block(fid, FIFF.FIFFB_MNE_ENV) write_id(fid, FIFF.FIFF_BLOCK_ID) - data = fwd['info'].get('working_dir', None) + data = fwd["info"].get("working_dir", None) if data is not None: write_string(fid, FIFF.FIFF_MNE_ENV_WORKING_DIR, data) - data = fwd['info'].get('command_line', None) + data = fwd["info"].get("command_line", None) if data is not None: write_string(fid, FIFF.FIFF_MNE_ENV_COMMAND_LINE, data) end_block(fid, FIFF.FIFFB_MNE_ENV) @@ -861,118 +917,138 @@ def _write_forward_solution(fid, fwd): # Information from the MRI file # start_block(fid, FIFF.FIFFB_MNE_PARENT_MRI_FILE) - write_string(fid, FIFF.FIFF_MNE_FILE_NAME, fwd['info']['mri_file']) - if fwd['info']['mri_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_FILE_ID, fwd['info']['mri_id']) + write_string(fid, FIFF.FIFF_MNE_FILE_NAME, fwd["info"]["mri_file"]) + if fwd["info"]["mri_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_FILE_ID, fwd["info"]["mri_id"]) # store the MRI to HEAD transform in MRI file - write_coord_trans(fid, fwd['info']['mri_head_t']) + write_coord_trans(fid, fwd["info"]["mri_head_t"]) end_block(fid, FIFF.FIFFB_MNE_PARENT_MRI_FILE) # write measurement info - write_forward_meas_info(fid, fwd['info']) + write_forward_meas_info(fid, fwd["info"]) # invert our original source space transform src = list() - for s in fwd['src']: + for s in fwd["src"]: s = deepcopy(s) try: # returns source space to original coordinate frame # usually MRI - s = transform_surface_to(s, fwd['mri_head_t']['from'], - fwd['mri_head_t']) + s = transform_surface_to(s, fwd["mri_head_t"]["from"], fwd["mri_head_t"]) except Exception as inst: - raise ValueError('Could not transform source space (%s)' % inst) + raise ValueError("Could not transform source space (%s)" % inst) src.append(s) # # Write the source spaces (again) # _write_source_spaces_to_fid(fid, src) - n_vert = sum([ss['nuse'] for ss in src]) - if fwd['_orig_source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI: + n_vert = sum([ss["nuse"] for ss in src]) + if fwd["_orig_source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI: n_col = n_vert else: n_col = 3 * n_vert # Undo transformations - sol = fwd['_orig_sol'].copy() - if fwd['sol_grad'] is not None: - sol_grad = fwd['_orig_sol_grad'].copy() + sol 
= fwd["_orig_sol"].copy() + if fwd["sol_grad"] is not None: + sol_grad = fwd["_orig_sol_grad"].copy() else: sol_grad = None - if fwd['surf_ori'] is True: - if fwd['_orig_source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI: - warn('The forward solution, which is stored on disk now, is based ' - 'on a forward solution with fixed orientation. Please note ' - 'that the transformation to surface-based, fixed orientation ' - 'cannot be reverted after loading the forward solution with ' - 'read_forward_solution.', RuntimeWarning) + if fwd["surf_ori"] is True: + if fwd["_orig_source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI: + warn( + "The forward solution, which is stored on disk now, is based " + "on a forward solution with fixed orientation. Please note " + "that the transformation to surface-based, fixed orientation " + "cannot be reverted after loading the forward solution with " + "read_forward_solution.", + RuntimeWarning, + ) else: - warn('This forward solution is based on a forward solution with ' - 'free orientation. The original forward solution is stored ' - 'on disk in X/Y/Z RAS coordinates. Any transformation ' - '(surface orientation or fixed orientation) will be ' - 'reverted. To reapply any transformation to the forward ' - 'operator please apply convert_forward_solution after ' - 'reading the forward solution with read_forward_solution.', - RuntimeWarning) + warn( + "This forward solution is based on a forward solution with " + "free orientation. The original forward solution is stored " + "on disk in X/Y/Z RAS coordinates. Any transformation " + "(surface orientation or fixed orientation) will be " + "reverted. To reapply any transformation to the forward " + "operator please apply convert_forward_solution after " + "reading the forward solution with read_forward_solution.", + RuntimeWarning, + ) # # MEG forward solution # - picks_meg = pick_types(fwd['info'], meg=True, eeg=False, ref_meg=False, - exclude=[]) - picks_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False, - exclude=[]) + picks_meg = pick_types(fwd["info"], meg=True, eeg=False, ref_meg=False, exclude=[]) + picks_eeg = pick_types(fwd["info"], meg=False, eeg=True, ref_meg=False, exclude=[]) n_meg = len(picks_meg) n_eeg = len(picks_eeg) - row_names_meg = [fwd['sol']['row_names'][p] for p in picks_meg] - row_names_eeg = [fwd['sol']['row_names'][p] for p in picks_eeg] + row_names_meg = [fwd["sol"]["row_names"][p] for p in picks_meg] + row_names_eeg = [fwd["sol"]["row_names"][p] for p in picks_eeg] if n_meg > 0: - meg_solution = dict(data=sol[picks_meg], nrow=n_meg, ncol=n_col, - row_names=row_names_meg, col_names=[]) + meg_solution = dict( + data=sol[picks_meg], + nrow=n_meg, + ncol=n_col, + row_names=row_names_meg, + col_names=[], + ) _transpose_named_matrix(meg_solution) start_block(fid, FIFF.FIFFB_MNE_FORWARD_SOLUTION) write_int(fid, FIFF.FIFF_MNE_INCLUDED_METHODS, FIFF.FIFFV_MNE_MEG) - write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, fwd['coord_frame']) - write_int(fid, FIFF.FIFF_MNE_SOURCE_ORIENTATION, - fwd['_orig_source_ori']) + write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, fwd["coord_frame"]) + write_int(fid, FIFF.FIFF_MNE_SOURCE_ORIENTATION, fwd["_orig_source_ori"]) write_int(fid, FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS, n_vert) write_int(fid, FIFF.FIFF_NCHAN, n_meg) write_named_matrix(fid, FIFF.FIFF_MNE_FORWARD_SOLUTION, meg_solution) if sol_grad is not None: - meg_solution_grad = dict(data=sol_grad[picks_meg], - nrow=n_meg, ncol=n_col * 3, - row_names=row_names_meg, col_names=[]) + meg_solution_grad = dict( + 
data=sol_grad[picks_meg], + nrow=n_meg, + ncol=n_col * 3, + row_names=row_names_meg, + col_names=[], + ) _transpose_named_matrix(meg_solution_grad) - write_named_matrix(fid, FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD, - meg_solution_grad) + write_named_matrix( + fid, FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD, meg_solution_grad + ) end_block(fid, FIFF.FIFFB_MNE_FORWARD_SOLUTION) # # EEG forward solution # if n_eeg > 0: - eeg_solution = dict(data=sol[picks_eeg], nrow=n_eeg, ncol=n_col, - row_names=row_names_eeg, col_names=[]) + eeg_solution = dict( + data=sol[picks_eeg], + nrow=n_eeg, + ncol=n_col, + row_names=row_names_eeg, + col_names=[], + ) _transpose_named_matrix(eeg_solution) start_block(fid, FIFF.FIFFB_MNE_FORWARD_SOLUTION) write_int(fid, FIFF.FIFF_MNE_INCLUDED_METHODS, FIFF.FIFFV_MNE_EEG) - write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, fwd['coord_frame']) - write_int(fid, FIFF.FIFF_MNE_SOURCE_ORIENTATION, - fwd['_orig_source_ori']) + write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, fwd["coord_frame"]) + write_int(fid, FIFF.FIFF_MNE_SOURCE_ORIENTATION, fwd["_orig_source_ori"]) write_int(fid, FIFF.FIFF_NCHAN, n_eeg) write_int(fid, FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS, n_vert) write_named_matrix(fid, FIFF.FIFF_MNE_FORWARD_SOLUTION, eeg_solution) if sol_grad is not None: - eeg_solution_grad = dict(data=sol_grad[picks_eeg], - nrow=n_eeg, ncol=n_col * 3, - row_names=row_names_eeg, col_names=[]) + eeg_solution_grad = dict( + data=sol_grad[picks_eeg], + nrow=n_eeg, + ncol=n_col * 3, + row_names=row_names_eeg, + col_names=[], + ) _transpose_named_matrix(eeg_solution_grad) - write_named_matrix(fid, FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD, - eeg_solution_grad) + write_named_matrix( + fid, FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD, eeg_solution_grad + ) end_block(fid, FIFF.FIFFB_MNE_FORWARD_SOLUTION) end_block(fid, FIFF.FIFFB_MNE) @@ -995,9 +1071,9 @@ def is_fixed_orient(forward, orig=False): Whether or not it is fixed orientation. 
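Notes
-----
With ``orig=True`` the orientation the solution was originally computed
with (``_orig_source_ori``) is queried, so an operator converted with
:func:`convert_forward_solution` and ``force_fixed=True`` returns ``True``
only for ``orig=False``.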
""" if orig: # if we want to know about the original version - fixed_ori = (forward['_orig_source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI) + fixed_ori = forward["_orig_source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI else: # most of the time we want to know about the current version - fixed_ori = (forward['source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI) + fixed_ori = forward["source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI return fixed_ori @@ -1016,25 +1092,25 @@ def write_forward_meas_info(fid, info): # Information from the MEG file # start_block(fid, FIFF.FIFFB_MNE_PARENT_MEAS_FILE) - write_string(fid, FIFF.FIFF_MNE_FILE_NAME, info['meas_file']) - if info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info['meas_id']) + write_string(fid, FIFF.FIFF_MNE_FILE_NAME, info["meas_file"]) + if info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info["meas_id"]) # get transformation from CTF and DEVICE to HEAD coordinate frame - meg_head_t = info.get('dev_head_t', info.get('ctf_head_t')) + meg_head_t = info.get("dev_head_t", info.get("ctf_head_t")) if meg_head_t is None: fid.close() - raise ValueError('Head<-->sensor transform not found') + raise ValueError("Head<-->sensor transform not found") write_coord_trans(fid, meg_head_t) ch_names_mapping = dict() - if 'chs' in info: + if "chs" in info: # Channel information - ch_names_mapping = _make_ch_names_mapping(info['chs']) - write_int(fid, FIFF.FIFF_NCHAN, len(info['chs'])) - _write_ch_infos(fid, info['chs'], False, ch_names_mapping) - if 'bads' in info and len(info['bads']) > 0: + ch_names_mapping = _make_ch_names_mapping(info["chs"]) + write_int(fid, FIFF.FIFF_NCHAN, len(info["chs"])) + _write_ch_infos(fid, info["chs"], False, ch_names_mapping) + if "bads" in info and len(info["bads"]) > 0: # Bad channels - _write_bad_channels(fid, info['bads'], ch_names_mapping) + _write_bad_channels(fid, info["bads"], ch_names_mapping) end_block(fid, FIFF.FIFFB_MNE_PARENT_MEAS_FILE) @@ -1042,82 +1118,89 @@ def write_forward_meas_info(fid, info): def _select_orient_forward(forward, info, noise_cov=None, copy=True): """Prepare forward solution for inverse solvers.""" # fwd['sol']['row_names'] may be different order from fwd['info']['chs'] - fwd_sol_ch_names = forward['sol']['row_names'] + fwd_sol_ch_names = forward["sol"]["row_names"] all_ch_names = set(fwd_sol_ch_names) - all_bads = set(info['bads']) + all_bads = set(info["bads"]) if noise_cov is not None: - all_ch_names &= set(noise_cov['names']) - all_bads |= set(noise_cov['bads']) + all_ch_names &= set(noise_cov["names"]) + all_bads |= set(noise_cov["bads"]) else: - noise_cov = dict(bads=info['bads']) - ch_names = [c['ch_name'] for c in info['chs'] - if c['ch_name'] not in all_bads and - c['ch_name'] in all_ch_names] - - if not len(info['bads']) == len(noise_cov['bads']) or \ - not all(b in noise_cov['bads'] for b in info['bads']): - logger.info('info["bads"] and noise_cov["bads"] do not match, ' - 'excluding bad channels from both') + noise_cov = dict(bads=info["bads"]) + ch_names = [ + c["ch_name"] + for c in info["chs"] + if c["ch_name"] not in all_bads and c["ch_name"] in all_ch_names + ] + + if not len(info["bads"]) == len(noise_cov["bads"]) or not all( + b in noise_cov["bads"] for b in info["bads"] + ): + logger.info( + 'info["bads"] and noise_cov["bads"] do not match, ' + "excluding bad channels from both" + ) # check the compensation grade - _check_compensation_grade(forward['info'], info, 'forward') + _check_compensation_grade(forward["info"], info, "forward") n_chan = len(ch_names) 
logger.info("Computing inverse operator with %d channels." % n_chan) - forward = pick_channels_forward(forward, ch_names, ordered=True, - copy=copy) - info_idx = [info['ch_names'].index(name) for name in ch_names] + forward = pick_channels_forward(forward, ch_names, ordered=True, copy=copy) + info_idx = [info["ch_names"].index(name) for name in ch_names] info_picked = pick_info(info, info_idx) - forward['info']._check_consistency() + forward["info"]._check_consistency() info_picked._check_consistency() return forward, info_picked -def _triage_loose(src, loose, fixed='auto'): - _validate_type(loose, (str, dict, 'numeric'), 'loose') - _validate_type(fixed, (str, bool), 'fixed') +def _triage_loose(src, loose, fixed="auto"): + _validate_type(loose, (str, dict, "numeric"), "loose") + _validate_type(fixed, (str, bool), "fixed") orig_loose = loose if isinstance(loose, str): - _check_option('loose', loose, ('auto',)) + _check_option("loose", loose, ("auto",)) if fixed is True: - loose = 0. + loose = 0.0 else: # False or auto - loose = 0.2 if src.kind == 'surface' else 1. - src_types = set(_src_kind_dict[s['type']] for s in src) + loose = 0.2 if src.kind == "surface" else 1.0 + src_types = set(_src_kind_dict[s["type"]] for s in src) if not isinstance(loose, dict): loose = float(loose) loose = {key: loose for key in src_types} loose_keys = set(loose.keys()) if loose_keys != src_types: raise ValueError( - f'loose, if dict, must have keys {sorted(src_types)} to match the ' - f'source space, got {sorted(loose_keys)}') + f"loose, if dict, must have keys {sorted(src_types)} to match the " + f"source space, got {sorted(loose_keys)}" + ) # if fixed is auto it can be ignored, if it's False it can be ignored, # only really need to care about fixed=True if fixed is True: - if not all(v == 0. for v in loose.values()): + if not all(v == 0.0 for v in loose.values()): raise ValueError( 'When using fixed=True, loose must be 0. or "auto", ' - f'got {orig_loose}') + f"got {orig_loose}" + ) elif fixed is False: - if any(v == 0. for v in loose.values()): + if any(v == 0.0 for v in loose.values()): raise ValueError( - 'If loose==0., then fixed must be True or "auto", got False') + 'If loose==0., then fixed must be True or "auto", got False' + ) del fixed for key, this_loose in loose.items(): - if key not in ('surface', 'discrete') and this_loose != 1: + if key not in ("surface", "discrete") and this_loose != 1: raise ValueError( 'loose parameter has to be 1 or "auto" for non-surface/' - f'discrete source spaces, got loose["{key}"] = {this_loose}') + f'discrete source spaces, got loose["{key}"] = {this_loose}' + ) if not 0 <= this_loose <= 1: - raise ValueError( - f'loose ({key}) must be between 0 and 1, got {this_loose}') + raise ValueError(f"loose ({key}) must be between 0 and 1, got {this_loose}") return loose @verbose -def compute_orient_prior(forward, loose='auto', verbose=None): +def compute_orient_prior(forward, loose="auto", verbose=None): """Compute orientation prior. Parameters @@ -1136,40 +1219,46 @@ def compute_orient_prior(forward, loose='auto', verbose=None): -------- compute_depth_prior """ - _validate_type(forward, Forward, 'forward') - n_sources = forward['sol']['data'].shape[1] + _validate_type(forward, Forward, "forward") + n_sources = forward["sol"]["data"].shape[1] - loose = _triage_loose(forward['src'], loose) + loose = _triage_loose(forward["src"], loose) orient_prior = np.ones(n_sources, dtype=np.float64) if is_fixed_orient(forward): - if any(v > 0. 
for v in loose.values()): - raise ValueError('loose must be 0. with forward operator ' - 'with fixed orientation, got %s' % (loose,)) + if any(v > 0.0 for v in loose.values()): + raise ValueError( + "loose must be 0. with forward operator " + "with fixed orientation, got %s" % (loose,) + ) return orient_prior - if all(v == 1. for v in loose.values()): + if all(v == 1.0 for v in loose.values()): return orient_prior # We actually need non-unity prior, compute it for each source space # separately - if not forward['surf_ori']: - raise ValueError('Forward operator is not oriented in surface ' - 'coordinates. loose parameter should be 1. ' - 'not %s.' % (loose,)) + if not forward["surf_ori"]: + raise ValueError( + "Forward operator is not oriented in surface " + "coordinates. loose parameter should be 1. " + "not %s." % (loose,) + ) start = 0 logged = dict() - for s in forward['src']: - this_type = _src_kind_dict[s['type']] + for s in forward["src"]: + this_type = _src_kind_dict[s["type"]] use_loose = loose[this_type] if not logged.get(this_type): - if use_loose == 1.: - name = 'free' + if use_loose == 1.0: + name = "free" else: - name = 'fixed' if use_loose == 0. else 'loose' - logger.info(f'Applying {name.ljust(5)} dipole orientations to ' - f'{this_type.ljust(7)} source spaces: {use_loose}') + name = "fixed" if use_loose == 0.0 else "loose" + logger.info( + f"Applying {name.ljust(5)} dipole orientations to " + f"{this_type.ljust(7)} source spaces: {use_loose}" + ) logged[this_type] = True - stop = start + 3 * s['nuse'] + stop = start + 3 * s["nuse"] orient_prior[start:stop:3] *= use_loose - orient_prior[start + 1:stop:3] *= use_loose + orient_prior[start + 1 : stop : 3] *= use_loose start = stop return orient_prior @@ -1177,27 +1266,38 @@ def compute_orient_prior(forward, loose='auto', verbose=None): def _restrict_gain_matrix(G, info): """Restrict gain matrix entries for optimal depth weighting.""" # Figure out which ones have been used - if len(info['chs']) != G.shape[0]: - raise ValueError('G.shape[0] (%d) and length of info["chs"] (%d) ' - 'do not match' % (G.shape[0], len(info['chs']))) + if len(info["chs"]) != G.shape[0]: + raise ValueError( + 'G.shape[0] (%d) and length of info["chs"] (%d) ' + "do not match" % (G.shape[0], len(info["chs"])) + ) for meg, eeg, kind in ( - ('grad', False, 'planar'), - ('mag', False, 'magnetometer or axial gradiometer'), - (False, True, 'EEG')): + ("grad", False, "planar"), + ("mag", False, "magnetometer or axial gradiometer"), + (False, True, "EEG"), + ): sel = pick_types(info, meg=meg, eeg=eeg, ref_meg=False, exclude=[]) if len(sel) > 0: - logger.info(' %d %s channels' % (len(sel), kind)) + logger.info(" %d %s channels" % (len(sel), kind)) break else: - warn('Could not find MEG or EEG channels to limit depth channels') + warn("Could not find MEG or EEG channels to limit depth channels") sel = slice(None) return G[sel] @verbose -def compute_depth_prior(forward, info, exp=0.8, limit=10.0, - limit_depth_chs=False, combine_xyz='spectral', - noise_cov=None, rank=None, verbose=None): +def compute_depth_prior( + forward, + info, + exp=0.8, + limit=10.0, + limit_depth_chs=False, + combine_xyz="spectral", + noise_cov=None, + rank=None, + verbose=None, +): """Compute depth prior for depth weighting. Parameters @@ -1278,39 +1378,44 @@ def compute_depth_prior(forward, info, exp=0.8, limit=10.0, SI units (such as EEG being orders of magnitude larger than MEG). 
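A minimal call sketch (``fwd`` and ``info`` stand in for a loaded forward
operator and its measurement info)::

    depth_prior = compute_depth_prior(fwd, info, exp=0.8, limit=10.0)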
""" from ..cov import Covariance, compute_whitener - _validate_type(forward, Forward, 'forward') - patch_areas = forward.get('patch_areas', None) + + _validate_type(forward, Forward, "forward") + patch_areas = forward.get("patch_areas", None) is_fixed_ori = is_fixed_orient(forward) - G = forward['sol']['data'] - logger.info('Creating the depth weighting matrix...') - _validate_type(noise_cov, (Covariance, None), 'noise_cov', - 'Covariance or None') - _validate_type(limit_depth_chs, (str, bool), 'limit_depth_chs') + G = forward["sol"]["data"] + logger.info("Creating the depth weighting matrix...") + _validate_type(noise_cov, (Covariance, None), "noise_cov", "Covariance or None") + _validate_type(limit_depth_chs, (str, bool), "limit_depth_chs") if isinstance(limit_depth_chs, str): - if limit_depth_chs != 'whiten': - raise ValueError('limit_depth_chs, if str, must be "whiten", got ' - '%s' % (limit_depth_chs,)) + if limit_depth_chs != "whiten": + raise ValueError( + 'limit_depth_chs, if str, must be "whiten", got ' + "%s" % (limit_depth_chs,) + ) if not isinstance(noise_cov, Covariance): - raise ValueError('With limit_depth_chs="whiten", noise_cov must be' - ' a Covariance, got %s' % (type(noise_cov),)) + raise ValueError( + 'With limit_depth_chs="whiten", noise_cov must be' + " a Covariance, got %s" % (type(noise_cov),) + ) if combine_xyz is not False: # private / expert option - _check_option('combine_xyz', combine_xyz, ('fro', 'spectral')) + _check_option("combine_xyz", combine_xyz, ("fro", "spectral")) # If possible, pick best depth-weighting channels if limit_depth_chs is True: G = _restrict_gain_matrix(G, info) - elif limit_depth_chs == 'whiten': - whitener, _ = compute_whitener(noise_cov, info, pca=True, rank=rank, - verbose=False) + elif limit_depth_chs == "whiten": + whitener, _ = compute_whitener( + noise_cov, info, pca=True, rank=rank, verbose=False + ) G = np.dot(whitener, G) # Compute the gain matrix - if is_fixed_ori or combine_xyz in ('fro', False): - d = np.sum(G ** 2, axis=0) + if is_fixed_ori or combine_xyz in ("fro", False): + d = np.sum(G**2, axis=0) if not (is_fixed_ori or combine_xyz is False): d = d.reshape(-1, 3).sum(axis=1) # Spherical leadfield can be zero at the center - d[d == 0.] = np.min(d[d != 0.]) + d[d == 0.0] = np.min(d[d != 0.0]) else: # 'spectral' # n_pos = G.shape[1] // 3 # The following is equivalent to this, but 4-10x faster @@ -1320,22 +1425,22 @@ def compute_depth_prior(forward, info, exp=0.8, limit=10.0, # x = np.dot(Gk.T, Gk) # d[k] = linalg.svdvals(x)[0] G.shape = (G.shape[0], -1, 3) - d = np.linalg.norm(np.einsum('svj,svk->vjk', G, G), # vector dot prods - ord=2, axis=(1, 2)) # ord=2 spectral (largest s.v.) + d = np.linalg.norm( + np.einsum("svj,svk->vjk", G, G), ord=2, axis=(1, 2) # vector dot prods + ) # ord=2 spectral (largest s.v.) 
G.shape = (G.shape[0], -1) # XXX Currently the fwd solns never have "patch_areas" defined if patch_areas is not None: if not is_fixed_ori and combine_xyz is False: patch_areas = np.repeat(patch_areas, 3) - d /= patch_areas ** 2 - logger.info(' Patch areas taken into account in the depth ' - 'weighting') + d /= patch_areas**2 + logger.info(" Patch areas taken into account in the depth " "weighting") w = 1.0 / d if limit is not None: ws = np.sort(w) - weight_limit = limit ** 2 + weight_limit = limit**2 if limit_depth_chs is False: # match old mne-python behavior # we used to do ind = np.argmin(ws), but this is 0 by sort above @@ -1350,13 +1455,13 @@ def compute_depth_prior(forward, info, exp=0.8, limit=10.0, limit = ws[ind] n_limit = ind - logger.info(' limit = %d/%d = %f' - % (n_limit + 1, len(d), - np.sqrt(limit / ws[0]))) + logger.info( + " limit = %d/%d = %f" % (n_limit + 1, len(d), np.sqrt(limit / ws[0])) + ) scale = 1.0 / limit - logger.info(' scale = %g exp = %g' % (scale, exp)) + logger.info(" scale = %g exp = %g" % (scale, exp)) w = np.minimum(w / limit, 1) - depth_prior = w ** exp + depth_prior = w**exp if not (is_fixed_ori or combine_xyz is False): depth_prior = np.repeat(depth_prior, 3) @@ -1364,8 +1469,9 @@ def compute_depth_prior(forward, info, exp=0.8, limit=10.0, return depth_prior -def _stc_src_sel(src, stc, on_missing='raise', - extra=', likely due to forward calculations'): +def _stc_src_sel( + src, stc, on_missing="raise", extra=", likely due to forward calculations" +): """Select the vertex indices of a source space using a source estimate.""" if isinstance(stc, list): vertices = stc @@ -1374,14 +1480,16 @@ def _stc_src_sel(src, stc, on_missing='raise', vertices = stc.vertices del stc if not len(src) == len(vertices): - raise RuntimeError('Mismatch between number of source spaces (%s) and ' - 'STC vertices (%s)' % (len(src), len(vertices))) + raise RuntimeError( + "Mismatch between number of source spaces (%s) and " + "STC vertices (%s)" % (len(src), len(vertices)) + ) src_sels, stc_sels, out_vertices = [], [], [] src_offset = stc_offset = 0 for s, v in zip(src, vertices): - joint_sel = np.intersect1d(s['vertno'], v) - src_sels.append(np.searchsorted(s['vertno'], joint_sel) + src_offset) - src_offset += len(s['vertno']) + joint_sel = np.intersect1d(s["vertno"], v) + src_sels.append(np.searchsorted(s["vertno"], joint_sel) + src_offset) + src_offset += len(s["vertno"]) idx = np.searchsorted(v, joint_sel) stc_sels.append(idx + stc_offset) stc_offset += len(v) @@ -1393,20 +1501,21 @@ def _stc_src_sel(src, stc, on_missing='raise', n_stc = sum(len(v) for v in vertices) n_joint = len(src_sel) if n_joint != n_stc: - msg = ('Only %i of %i SourceEstimate %s found in ' - 'source space%s' - % (n_joint, n_stc, 'vertex' if n_stc == 1 else 'vertices', - extra)) + msg = "Only %i of %i SourceEstimate %s found in " "source space%s" % ( + n_joint, + n_stc, + "vertex" if n_stc == 1 else "vertices", + extra, + ) _on_missing(on_missing, msg) return src_sel, stc_sel, out_vertices def _fill_measurement_info(info, fwd, sfreq, data): """Fill the measurement info of a Raw or Evoked object.""" - sel = pick_channels( - info['ch_names'], fwd['sol']['row_names'], ordered=False) + sel = pick_channels(info["ch_names"], fwd["sol"]["row_names"], ordered=False) info = pick_info(info, sel) - info['bads'] = [] + info["bads"] = [] now = time() sec = np.floor(now) @@ -1414,41 +1523,49 @@ def _fill_measurement_info(info, fwd, sfreq, data): # this is probably correct based on what's done in meas_info.py... 
with info._unlock(check_after=True): - info.update(meas_id=fwd['info']['meas_id'], file_id=info['meas_id'], - meas_date=_stamp_to_dt((int(sec), int(usec))), - highpass=0., lowpass=sfreq / 2., sfreq=sfreq, projs=[]) + info.update( + meas_id=fwd["info"]["meas_id"], + file_id=info["meas_id"], + meas_date=_stamp_to_dt((int(sec), int(usec))), + highpass=0.0, + lowpass=sfreq / 2.0, + sfreq=sfreq, + projs=[], + ) # reorder data (which is in fwd order) to match that of info - order = [fwd['sol']['row_names'].index(name) for name in info['ch_names']] + order = [fwd["sol"]["row_names"].index(name) for name in info["ch_names"]] data = data[order] return info, data @verbose -def _apply_forward(fwd, stc, start=None, stop=None, on_missing='raise', - use_cps=True, verbose=None): +def _apply_forward( + fwd, stc, start=None, stop=None, on_missing="raise", use_cps=True, verbose=None +): """Apply forward model and return data, times, ch_names.""" - _validate_type(stc, _BaseSourceEstimate, 'stc', 'SourceEstimate') - _validate_type(fwd, Forward, 'fwd') + _validate_type(stc, _BaseSourceEstimate, "stc", "SourceEstimate") + _validate_type(fwd, Forward, "fwd") if isinstance(stc, _BaseVectorSourceEstimate): vector = True fwd = convert_forward_solution(fwd, force_fixed=False, surf_ori=False) else: vector = False if not is_fixed_orient(fwd): - fwd = convert_forward_solution(fwd, force_fixed=True, - use_cps=use_cps) + fwd = convert_forward_solution(fwd, force_fixed=True, use_cps=use_cps) if np.all(stc.data > 0): - warn('Source estimate only contains currents with positive values. ' - 'Use pick_ori="normal" when computing the inverse to compute ' - 'currents not current magnitudes.') + warn( + "Source estimate only contains currents with positive values. " + 'Use pick_ori="normal" when computing the inverse to compute ' + "currents not current magnitudes." + ) _check_stc_units(stc) - src_sel, stc_sel, _ = _stc_src_sel(fwd['src'], stc, on_missing=on_missing) - gain = fwd['sol']['data'] + src_sel, stc_sel, _ = _stc_src_sel(fwd["src"], stc, on_missing=on_missing) + gain = fwd["sol"]["data"] stc_sel = slice(None) if len(stc_sel) == len(stc.data) else stc_sel times = stc.times[start:stop].copy() stc_data = stc.data[stc_sel, ..., start:stop].reshape(-1, len(times)) @@ -1458,15 +1575,23 @@ def _apply_forward(fwd, stc, start=None, stop=None, on_missing='raise', gain = gain[:, src_sel].reshape(len(gain), -1) # save some memory if possible - logger.info('Projecting source estimate to sensor space...') + logger.info("Projecting source estimate to sensor space...") data = np.dot(gain, stc_data) - logger.info('[done]') + logger.info("[done]") return data, times @verbose -def apply_forward(fwd, stc, info, start=None, stop=None, use_cps=True, - on_missing='raise', verbose=None): +def apply_forward( + fwd, + stc, + info, + start=None, + stop=None, + use_cps=True, + on_missing="raise", + verbose=None, +): """Project source space currents to sensor space using a forward operator. The sensor space data is computed for all channels present in fwd. Use @@ -1507,19 +1632,22 @@ def apply_forward(fwd, stc, info, start=None, stop=None, use_cps=True, -------- apply_forward_raw: Compute sensor space data and return a Raw object. 
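Notes
-----
A minimal usage sketch (``fwd``, ``stc`` and ``info`` are placeholders for
objects loaded elsewhere)::

    evoked = apply_forward(fwd, stc, info)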
""" - _validate_type(info, Info, 'info') - _validate_type(fwd, Forward, 'forward') + _validate_type(info, Info, "info") + _validate_type(fwd, Forward, "forward") info._check_consistency() # make sure evoked_template contains all channels in fwd - for ch_name in fwd['sol']['row_names']: - if ch_name not in info['ch_names']: - raise ValueError('Channel %s of forward operator not present in ' - 'evoked_template.' % ch_name) + for ch_name in fwd["sol"]["row_names"]: + if ch_name not in info["ch_names"]: + raise ValueError( + "Channel %s of forward operator not present in " + "evoked_template." % ch_name + ) # project the source estimate to the sensor space - data, times = _apply_forward(fwd, stc, start, stop, on_missing=on_missing, - use_cps=use_cps) + data, times = _apply_forward( + fwd, stc, start, stop, on_missing=on_missing, use_cps=use_cps + ) # fill the measurement info sfreq = float(1.0 / stc.tstep) @@ -1534,8 +1662,16 @@ def apply_forward(fwd, stc, info, start=None, stop=None, use_cps=True, @verbose -def apply_forward_raw(fwd, stc, info, start=None, stop=None, - on_missing='raise', use_cps=True, verbose=None): +def apply_forward_raw( + fwd, + stc, + info, + start=None, + stop=None, + on_missing="raise", + use_cps=True, + verbose=None, +): """Project source space currents to sensor space using a forward operator. The sensor space data is computed for all channels present in fwd. Use @@ -1577,19 +1713,21 @@ def apply_forward_raw(fwd, stc, info, start=None, stop=None, apply_forward: Compute sensor space data and return an Evoked object. """ # make sure info contains all channels in fwd - for ch_name in fwd['sol']['row_names']: - if ch_name not in info['ch_names']: - raise ValueError('Channel %s of forward operator not present in ' - 'info.' % ch_name) + for ch_name in fwd["sol"]["row_names"]: + if ch_name not in info["ch_names"]: + raise ValueError( + "Channel %s of forward operator not present in " "info." % ch_name + ) # project the source estimate to the sensor space - data, times = _apply_forward(fwd, stc, start, stop, on_missing=on_missing, - use_cps=use_cps) + data, times = _apply_forward( + fwd, stc, start, stop, on_missing=on_missing, use_cps=use_cps + ) sfreq = 1.0 / stc.tstep info, data = _fill_measurement_info(info, fwd, sfreq, data) with info._unlock(): - info['projs'] = [] + info["projs"] = [] # store sensor data in Raw object using the info raw = RawArray(data, info, first_samp=int(np.round(times[0] * sfreq))) raw._projector = None @@ -1597,7 +1735,7 @@ def apply_forward_raw(fwd, stc, info, start=None, stop=None, @fill_doc -def restrict_forward_to_stc(fwd, stc, on_missing='ignore'): +def restrict_forward_to_stc(fwd, stc, on_missing="ignore"): """Restrict forward operator to active sources in a source estimate. 
Parameters @@ -1620,9 +1758,9 @@ def restrict_forward_to_stc(fwd, stc, on_missing='ignore'): -------- restrict_forward_to_label """ - _validate_type(on_missing, str, 'on_missing') - _check_option('on_missing', on_missing, ('ignore', 'warn', 'raise')) - src_sel, _, vertices = _stc_src_sel(fwd['src'], stc, on_missing=on_missing) + _validate_type(on_missing, str, "on_missing") + _check_option("on_missing", on_missing, ("ignore", "warn", "raise")) + src_sel, _, vertices = _stc_src_sel(fwd["src"], stc, on_missing=on_missing) del stc return _restrict_forward_to_src_sel(fwd, src_sel) @@ -1630,46 +1768,47 @@ def restrict_forward_to_stc(fwd, stc, on_missing='ignore'): def _restrict_forward_to_src_sel(fwd, src_sel): fwd_out = deepcopy(fwd) # figure out the vertno we are keeping - idx_sel = np.concatenate([[[si] * len(s['vertno']), s['vertno']] - for si, s in enumerate(fwd['src'])], axis=-1) + idx_sel = np.concatenate( + [[[si] * len(s["vertno"]), s["vertno"]] for si, s in enumerate(fwd["src"])], + axis=-1, + ) assert idx_sel.ndim == 2 and idx_sel.shape[0] == 2 - assert idx_sel.shape[1] == fwd['nsource'] + assert idx_sel.shape[1] == fwd["nsource"] idx_sel = idx_sel[:, src_sel] - fwd_out['source_rr'] = fwd['source_rr'][src_sel] - fwd_out['nsource'] = len(src_sel) + fwd_out["source_rr"] = fwd["source_rr"][src_sel] + fwd_out["nsource"] = len(src_sel) if is_fixed_orient(fwd): idx = src_sel - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel() else: idx = (3 * src_sel[:, None] + np.arange(3)).ravel() - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel() - fwd_out['source_nn'] = fwd['source_nn'][idx] - fwd_out['sol']['data'] = fwd['sol']['data'][:, idx] - if fwd['sol_grad'] is not None: - fwd_out['sol_grad']['data'] = fwd['sol_grad']['data'][:, idx_grad] - fwd_out['sol']['ncol'] = len(idx) + fwd_out["source_nn"] = fwd["source_nn"][idx] + fwd_out["sol"]["data"] = fwd["sol"]["data"][:, idx] + if fwd["sol_grad"] is not None: + fwd_out["sol_grad"]["data"] = fwd["sol_grad"]["data"][:, idx_grad] + fwd_out["sol"]["ncol"] = len(idx) if is_fixed_orient(fwd, orig=True): idx = src_sel - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel() else: idx = (3 * src_sel[:, None] + np.arange(3)).ravel() - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel() - fwd_out['_orig_sol'] = fwd['_orig_sol'][:, idx] - if fwd['sol_grad'] is not None: - fwd_out['_orig_sol_grad'] = fwd['_orig_sol_grad'][:, idx_grad] + fwd_out["_orig_sol"] = fwd["_orig_sol"][:, idx] + if fwd["sol_grad"] is not None: + fwd_out["_orig_sol_grad"] = fwd["_orig_sol_grad"][:, idx_grad] - vertices = [idx_sel[1][idx_sel[0] == si] - for si in range(len(fwd_out['src']))] - _set_source_space_vertices(fwd_out['src'], vertices) + vertices = [idx_sel[1][idx_sel[0] == si] for si in range(len(fwd_out["src"]))] + _set_source_space_vertices(fwd_out["src"], vertices) return fwd_out @@ -1701,92 +1840,106 @@ def restrict_forward_to_label(fwd, labels): # Get vertices separately of each hemisphere from all label for label in labels: _validate_type(label, Label, "label", "Label or list") - i = 0 if label.hemi == 'lh' else 1 + i = 0 if label.hemi == "lh" else 1 vertices[i] = np.append(vertices[i], label.vertices) # Remove duplicates and sort vertices = [np.unique(vert_hemi) for vert_hemi in vertices] 
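
    # ``vertices`` now holds the sorted, unique vertex numbers per hemisphere;
    # the code below copies the forward solution, zeroes out its source and
    # gain entries, then refills it hemisphere by hemisphere with only the
    # columns that match those vertices.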
fwd_out = deepcopy(fwd) - fwd_out['source_rr'] = np.zeros((0, 3)) - fwd_out['nsource'] = 0 - fwd_out['source_nn'] = np.zeros((0, 3)) - fwd_out['sol']['data'] = np.zeros((fwd['sol']['data'].shape[0], 0)) - fwd_out['_orig_sol'] = np.zeros((fwd['_orig_sol'].shape[0], 0)) - if fwd['sol_grad'] is not None: - fwd_out['sol_grad']['data'] = np.zeros( - (fwd['sol_grad']['data'].shape[0], 0)) - fwd_out['_orig_sol_grad'] = np.zeros( - (fwd['_orig_sol_grad'].shape[0], 0)) - fwd_out['sol']['ncol'] = 0 - nuse_lh = fwd['src'][0]['nuse'] + fwd_out["source_rr"] = np.zeros((0, 3)) + fwd_out["nsource"] = 0 + fwd_out["source_nn"] = np.zeros((0, 3)) + fwd_out["sol"]["data"] = np.zeros((fwd["sol"]["data"].shape[0], 0)) + fwd_out["_orig_sol"] = np.zeros((fwd["_orig_sol"].shape[0], 0)) + if fwd["sol_grad"] is not None: + fwd_out["sol_grad"]["data"] = np.zeros((fwd["sol_grad"]["data"].shape[0], 0)) + fwd_out["_orig_sol_grad"] = np.zeros((fwd["_orig_sol_grad"].shape[0], 0)) + fwd_out["sol"]["ncol"] = 0 + nuse_lh = fwd["src"][0]["nuse"] for i in range(2): - fwd_out['src'][i]['vertno'] = np.array([], int) - fwd_out['src'][i]['nuse'] = 0 - fwd_out['src'][i]['inuse'] = fwd['src'][i]['inuse'].copy() - fwd_out['src'][i]['inuse'].fill(0) - fwd_out['src'][i]['use_tris'] = np.array([[]], int) - fwd_out['src'][i]['nuse_tri'] = np.array([0]) + fwd_out["src"][i]["vertno"] = np.array([], int) + fwd_out["src"][i]["nuse"] = 0 + fwd_out["src"][i]["inuse"] = fwd["src"][i]["inuse"].copy() + fwd_out["src"][i]["inuse"].fill(0) + fwd_out["src"][i]["use_tris"] = np.array([[]], int) + fwd_out["src"][i]["nuse_tri"] = np.array([0]) # src_sel is idx to cols in fwd that are in any label per hemi - src_sel = np.intersect1d(fwd['src'][i]['vertno'], vertices[i]) - src_sel = np.searchsorted(fwd['src'][i]['vertno'], src_sel) + src_sel = np.intersect1d(fwd["src"][i]["vertno"], vertices[i]) + src_sel = np.searchsorted(fwd["src"][i]["vertno"], src_sel) # Reconstruct each src - vertno = fwd['src'][i]['vertno'][src_sel] - fwd_out['src'][i]['inuse'][vertno] = 1 - fwd_out['src'][i]['nuse'] += len(vertno) - fwd_out['src'][i]['vertno'] = np.where(fwd_out['src'][i]['inuse'])[0] + vertno = fwd["src"][i]["vertno"][src_sel] + fwd_out["src"][i]["inuse"][vertno] = 1 + fwd_out["src"][i]["nuse"] += len(vertno) + fwd_out["src"][i]["vertno"] = np.where(fwd_out["src"][i]["inuse"])[0] # Reconstruct part of fwd that is not sol data src_sel += i * nuse_lh # Add column shift to right hemi - fwd_out['source_rr'] = np.vstack([fwd_out['source_rr'], - fwd['source_rr'][src_sel]]) - fwd_out['nsource'] += len(src_sel) + fwd_out["source_rr"] = np.vstack( + [fwd_out["source_rr"], fwd["source_rr"][src_sel]] + ) + fwd_out["nsource"] += len(src_sel) if is_fixed_orient(fwd): idx = src_sel - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel() else: idx = (3 * src_sel[:, None] + np.arange(3)).ravel() - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel() - fwd_out['source_nn'] = np.vstack( - [fwd_out['source_nn'], fwd['source_nn'][idx]]) - fwd_out['sol']['data'] = np.hstack( - [fwd_out['sol']['data'], fwd['sol']['data'][:, idx]]) - if fwd['sol_grad'] is not None: - fwd_out['sol_grad']['data'] = np.hstack( - [fwd_out['sol_grad']['data'], - fwd['sol_rad']['data'][:, idx_grad]]) - fwd_out['sol']['ncol'] += len(idx) + fwd_out["source_nn"] = np.vstack([fwd_out["source_nn"], fwd["source_nn"][idx]]) + fwd_out["sol"]["data"] = np.hstack( + 
[fwd_out["sol"]["data"], fwd["sol"]["data"][:, idx]] + ) + if fwd["sol_grad"] is not None: + fwd_out["sol_grad"]["data"] = np.hstack( + [fwd_out["sol_grad"]["data"], fwd["sol_rad"]["data"][:, idx_grad]] + ) + fwd_out["sol"]["ncol"] += len(idx) if is_fixed_orient(fwd, orig=True): idx = src_sel - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel() else: idx = (3 * src_sel[:, None] + np.arange(3)).ravel() - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel() - fwd_out['_orig_sol'] = np.hstack( - [fwd_out['_orig_sol'], fwd['_orig_sol'][:, idx]]) - if fwd['sol_grad'] is not None: - fwd_out['_orig_sol_grad'] = np.hstack( - [fwd_out['_orig_sol_grad'], - fwd['_orig_sol_grad'][:, idx_grad]]) + fwd_out["_orig_sol"] = np.hstack( + [fwd_out["_orig_sol"], fwd["_orig_sol"][:, idx]] + ) + if fwd["sol_grad"] is not None: + fwd_out["_orig_sol_grad"] = np.hstack( + [fwd_out["_orig_sol_grad"], fwd["_orig_sol_grad"][:, idx_grad]] + ) return fwd_out -def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, - mindist=None, bem=None, mri=None, trans=None, - eeg=True, meg=True, fixed=False, grad=False, - mricoord=False, overwrite=False, subjects_dir=None, - verbose=None): +def _do_forward_solution( + subject, + meas, + fname=None, + src=None, + spacing=None, + mindist=None, + bem=None, + mri=None, + trans=None, + eeg=True, + meg=True, + fixed=False, + grad=False, + mricoord=False, + overwrite=False, + subjects_dir=None, + verbose=None, +): """Calculate a forward solution for a subject using MNE-C routines. This is kept around for testing purposes. @@ -1852,7 +2005,7 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, The generated forward solution. 
""" if not has_mne_c(): - raise RuntimeError('mne command line tools could not be found') + raise RuntimeError("mne command line tools could not be found") # check for file existence temp_dir = Path(tempfile.mkdtemp()) @@ -1862,9 +2015,9 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, _validate_type(subject, "str", "subject") # check for meas to exist as string, or try to make evoked - _validate_type(meas, ('path-like', BaseRaw, BaseEpochs, Evoked), 'meas') + _validate_type(meas, ("path-like", BaseRaw, BaseEpochs, Evoked), "meas") if isinstance(meas, (BaseRaw, BaseEpochs, Evoked)): - meas_file = op.join(temp_dir, 'info.fif') + meas_file = op.join(temp_dir, "info.fif") write_info(meas_file, meas.info) meas = meas_file else: @@ -1872,11 +2025,11 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, # deal with trans/mri if mri is not None and trans is not None: - raise ValueError('trans and mri cannot both be specified') + raise ValueError("trans and mri cannot both be specified") if mri is None and trans is None: # MNE allows this to default to a trans/mri in the subject's dir, # but let's be safe here and force the user to pass us a trans/mri - raise ValueError('Either trans or mri must be specified') + raise ValueError("Either trans or mri must be specified") if trans is not None: if isinstance(trans, dict): @@ -1885,8 +2038,10 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, try: write_trans(trans, trans_data) except Exception: - raise OSError('trans was a dict, but could not be ' - 'written to disk as a transform file') + raise OSError( + "trans was a dict, but could not be " + "written to disk as a transform file" + ) elif isinstance(trans, (str, Path, PathLike)): _check_fname(trans, "read", must_exist=True, name="trans") trans = Path(trans) @@ -1899,8 +2054,10 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, try: write_trans(mri, mri_data) except Exception: - raise OSError('mri was a dict, but could not be ' - 'written to disk as a transform file') + raise OSError( + "mri was a dict, but could not be " + "written to disk as a transform file" + ) elif isinstance(mri, (str, Path, PathLike)): _check_fname(mri, "read", must_exist=True, name="mri") mri = Path(mri) @@ -1909,37 +2066,45 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, # deal with meg/eeg if not meg and not eeg: - raise ValueError('meg or eeg (or both) must be True') + raise ValueError("meg or eeg (or both) must be True") if not fname.suffix == ".fif": - raise ValueError('Forward name does not end with .fif') + raise ValueError("Forward name does not end with .fif") path = fname.parent.absolute() fname = fname.name # deal with mindist if mindist is not None: if isinstance(mindist, str): - if not mindist.lower() == 'all': + if not mindist.lower() == "all": raise ValueError('mindist, if string, must be "all"') - mindist = ['--all'] + mindist = ["--all"] else: - mindist = ['--mindist', '%g' % mindist] + mindist = ["--mindist", "%g" % mindist] # src, spacing, bem - for element, name, kind in zip((src, spacing, bem), - ("src", "spacing", "bem"), - ('path-like', 'str', 'path-like')): + for element, name, kind in zip( + (src, spacing, bem), + ("src", "spacing", "bem"), + ("path-like", "str", "path-like"), + ): if element is not None: _validate_type(element, kind, name, "%s or None" % kind) # put together the actual call - cmd = ['mne_do_forward_solution', - '--subject', subject, - 
'--meas', meas, - '--fwd', fname, - '--destdir', str(path)] + cmd = [ + "mne_do_forward_solution", + "--subject", + subject, + "--meas", + meas, + "--fwd", + fname, + "--destdir", + str(path), + ] if src is not None: - cmd += ['--src', src] + cmd += ["--src", src] if spacing is not None: if spacing.isdigit(): pass # spacing in mm @@ -1948,36 +2113,38 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, match = re.match(r"(oct|ico)-?(\d+)$", spacing) if match is None: raise ValueError("Invalid spacing parameter: %r" % spacing) - spacing = '-'.join(match.groups()) - cmd += ['--spacing', spacing] + spacing = "-".join(match.groups()) + cmd += ["--spacing", spacing] if mindist is not None: cmd += mindist if bem is not None: - cmd += ['--bem', bem] + cmd += ["--bem", bem] if mri is not None: - cmd += ['--mri', '%s' % str(mri.absolute())] + cmd += ["--mri", "%s" % str(mri.absolute())] if trans is not None: - cmd += ['--trans', '%s' % str(trans.absolute())] + cmd += ["--trans", "%s" % str(trans.absolute())] if not meg: - cmd.append('--eegonly') + cmd.append("--eegonly") if not eeg: - cmd.append('--megonly') + cmd.append("--megonly") if fixed: - cmd.append('--fixed') + cmd.append("--fixed") if grad: - cmd.append('--grad') + cmd.append("--grad") if mricoord: - cmd.append('--mricoord') + cmd.append("--mricoord") if overwrite: - cmd.append('--overwrite') + cmd.append("--overwrite") env = os.environ.copy() subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) - env['SUBJECTS_DIR'] = subjects_dir + env["SUBJECTS_DIR"] = subjects_dir try: - logger.info('Running forward solution generation command with ' - 'subjects_dir %s' % subjects_dir) + logger.info( + "Running forward solution generation command with " + "subjects_dir %s" % subjects_dir + ) run_subprocess(cmd, env=env) except Exception: raise @@ -2011,19 +2178,19 @@ def average_forward_solutions(fwds, weights=None, verbose=None): # check for fwds being a list _validate_type(fwds, list, "fwds") if not len(fwds) > 0: - raise ValueError('fwds must not be empty') + raise ValueError("fwds must not be empty") # check weights if weights is None: weights = np.ones(len(fwds)) weights = np.asanyarray(weights) # in case it's a list, convert it if not np.all(weights >= 0): - raise ValueError('weights must be non-negative') + raise ValueError("weights must be non-negative") if not len(weights) == len(fwds): - raise ValueError('weights must be None or the same length as fwds') + raise ValueError("weights must be None or the same length as fwds") w_sum = np.sum(weights) if not w_sum > 0: - raise ValueError('weights cannot all be zero') + raise ValueError("weights cannot all be zero") weights /= w_sum # check our forward solutions @@ -2031,32 +2198,49 @@ def average_forward_solutions(fwds, weights=None, verbose=None): # check to make sure it's a forward solution _validate_type(fwd, dict, "each entry in fwds", "dict") # check to make sure the dict is actually a fwd - check_keys = ['info', 'sol_grad', 'nchan', 'src', 'source_nn', 'sol', - 'source_rr', 'source_ori', 'surf_ori', 'coord_frame', - 'mri_head_t', 'nsource'] + check_keys = [ + "info", + "sol_grad", + "nchan", + "src", + "source_nn", + "sol", + "source_rr", + "source_ori", + "surf_ori", + "coord_frame", + "mri_head_t", + "nsource", + ] if not all(key in fwd for key in check_keys): - raise KeyError('forward solution dict does not have all standard ' - 'entries, cannot compute average.') + raise KeyError( + "forward solution dict does not have all standard " + "entries, 
cannot compute average." + ) # check forward solution compatibility - if any(fwd['sol'][k] != fwds[0]['sol'][k] - for fwd in fwds[1:] for k in ['nrow', 'ncol']): - raise ValueError('Forward solutions have incompatible dimensions') - if any(fwd[k] != fwds[0][k] for fwd in fwds[1:] - for k in ['source_ori', 'surf_ori', 'coord_frame']): - raise ValueError('Forward solutions have incompatible orientations') + if any( + fwd["sol"][k] != fwds[0]["sol"][k] for fwd in fwds[1:] for k in ["nrow", "ncol"] + ): + raise ValueError("Forward solutions have incompatible dimensions") + if any( + fwd[k] != fwds[0][k] + for fwd in fwds[1:] + for k in ["source_ori", "surf_ori", "coord_frame"] + ): + raise ValueError("Forward solutions have incompatible orientations") # actually average them (solutions and gradients) fwd_ave = deepcopy(fwds[0]) - fwd_ave['sol']['data'] *= weights[0] - fwd_ave['_orig_sol'] *= weights[0] + fwd_ave["sol"]["data"] *= weights[0] + fwd_ave["_orig_sol"] *= weights[0] for fwd, w in zip(fwds[1:], weights[1:]): - fwd_ave['sol']['data'] += w * fwd['sol']['data'] - fwd_ave['_orig_sol'] += w * fwd['_orig_sol'] - if fwd_ave['sol_grad'] is not None: - fwd_ave['sol_grad']['data'] *= weights[0] - fwd_ave['_orig_sol_grad'] *= weights[0] + fwd_ave["sol"]["data"] += w * fwd["sol"]["data"] + fwd_ave["_orig_sol"] += w * fwd["_orig_sol"] + if fwd_ave["sol_grad"] is not None: + fwd_ave["sol_grad"]["data"] *= weights[0] + fwd_ave["_orig_sol_grad"] *= weights[0] for fwd, w in zip(fwds[1:], weights[1:]): - fwd_ave['sol_grad']['data'] += w * fwd['sol_grad']['data'] - fwd_ave['_orig_sol_grad'] += w * fwd['_orig_sol_grad'] + fwd_ave["sol_grad"]["data"] += w * fwd["sol_grad"]["data"] + fwd_ave["_orig_sol_grad"] += w * fwd["_orig_sol_grad"] return fwd_ave diff --git a/mne/forward/tests/test_field_interpolation.py b/mne/forward/tests/test_field_interpolation.py index 9adf3915870..036a4a58af9 100644 --- a/mne/forward/tests/test_field_interpolation.py +++ b/mne/forward/tests/test_field_interpolation.py @@ -3,16 +3,24 @@ import numpy as np from numpy.polynomial import legendre -from numpy.testing import (assert_allclose, assert_array_equal, assert_equal, - assert_array_almost_equal) +from numpy.testing import ( + assert_allclose, + assert_array_equal, + assert_equal, + assert_array_almost_equal, +) from scipy.interpolate import interp1d import pytest import mne from mne.forward import _make_surface_mapping, make_field_map -from mne.forward._lead_dots import (_comp_sum_eeg, _comp_sums_meg, - _get_legen_table, _do_cross_dots) +from mne.forward._lead_dots import ( + _comp_sum_eeg, + _comp_sums_meg, + _get_legen_table, + _do_cross_dots, +) from mne.forward._make_forward import _create_meg_coils from mne.forward._field_interpolation import _setup_dots from mne.surface import get_meg_helmet_surf, get_head_surf @@ -21,15 +29,14 @@ from mne.io import read_raw_fif -base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') -raw_fname = op.join(base_dir, 'test_raw.fif') -evoked_fname = op.join(base_dir, 'test-ave.fif') -raw_ctf_fname = op.join(base_dir, 'test_ctf_raw.fif') +base_dir = op.join(op.dirname(__file__), "..", "..", "io", "tests", "data") +raw_fname = op.join(base_dir, "test_raw.fif") +evoked_fname = op.join(base_dir, "test-ave.fif") +raw_ctf_fname = op.join(base_dir, "test_ctf_raw.fif") data_path = testing.data_path(download=False) -trans_fname = op.join(data_path, 'MEG', 'sample', - 'sample_audvis_trunc-trans.fif') -subjects_dir = op.join(data_path, 'subjects') +trans_fname = 
op.join(data_path, "MEG", "sample", "sample_audvis_trunc-trans.fif") +subjects_dir = op.join(data_path, "subjects") @testing.requires_testing_data @@ -41,29 +48,30 @@ def test_field_map_ctf(): evoked = Epochs(raw, events).average() evoked.pick_channels(evoked.ch_names[:50]) # crappy mapping but faster # smoke test - passing trans_fname as pathlib.Path as additional check - make_field_map(evoked, trans=Path(trans_fname), subject='sample', - subjects_dir=subjects_dir) + make_field_map( + evoked, trans=Path(trans_fname), subject="sample", subjects_dir=subjects_dir + ) def test_legendre_val(): """Test Legendre polynomial (derivative) equivalence.""" rng = np.random.RandomState(0) # check table equiv - xs = np.linspace(-1., 1., 1000) + xs = np.linspace(-1.0, 1.0, 1000) n_terms = 100 # True, numpy vals_np = legendre.legvander(xs, n_terms - 1) # Table approximation - for nc, interp in zip([100, 50], ['nearest', 'linear']): - lut, n_fact = _get_legen_table('eeg', n_coeff=nc, force_calc=True) - lut_fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, interp, - axis=0) + for nc, interp in zip([100, 50], ["nearest", "linear"]): + lut, n_fact = _get_legen_table("eeg", n_coeff=nc, force_calc=True) + lut_fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, interp, axis=0) vals_i = lut_fun(xs) # Need a "1:" here because we omit the first coefficient in our table! - assert_allclose(vals_np[:, 1:vals_i.shape[1] + 1], vals_i, - rtol=1e-2, atol=5e-3) + assert_allclose( + vals_np[:, 1 : vals_i.shape[1] + 1], vals_i, rtol=1e-2, atol=5e-3 + ) # Now let's look at our sums ctheta = rng.rand(20, 30) * 2.0 - 1.0 @@ -74,24 +82,27 @@ def test_legendre_val(): # compare to numpy n = np.arange(1, n_terms, dtype=float)[:, np.newaxis, np.newaxis] coeffs = np.zeros((n_terms,) + beta.shape) - coeffs[1:] = (np.cumprod([beta] * (n_terms - 1), axis=0) * - (2.0 * n + 1.0) * (2.0 * n + 1.0) / n) + coeffs[1:] = ( + np.cumprod([beta] * (n_terms - 1), axis=0) + * (2.0 * n + 1.0) + * (2.0 * n + 1.0) + / n + ) # can't use tensor=False here b/c it isn't in old numpy c2 = np.empty((20, 30)) for ci1 in range(20): for ci2 in range(30): - c2[ci1, ci2] = legendre.legval(ctheta[ci1, ci2], - coeffs[:, ci1, ci2]) + c2[ci1, ci2] = legendre.legval(ctheta[ci1, ci2], coeffs[:, ci1, ci2]) assert_allclose(c1, c2, 1e-2, 1e-3) # close enough... 
# compare fast and slow for MEG ctheta = rng.rand(20 * 30) * 2.0 - 1.0 beta = rng.rand(20 * 30) * 0.8 - lut, n_fact = _get_legen_table('meg', n_coeff=10, force_calc=True) - fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, 'nearest', axis=0) + lut, n_fact = _get_legen_table("meg", n_coeff=10, force_calc=True) + fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, "nearest", axis=0) coeffs = _comp_sums_meg(beta, ctheta, fun, n_fact, False) - lut, n_fact = _get_legen_table('meg', n_coeff=20, force_calc=True) - fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, 'linear', axis=0) + lut, n_fact = _get_legen_table("meg", n_coeff=20, force_calc=True) + fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, "linear", axis=0) coeffs = _comp_sums_meg(beta, ctheta, fun, n_fact, False) @@ -99,10 +110,10 @@ def test_legendre_table(): """Test Legendre table calculation.""" # double-check our table generation n = 10 - for ch_type in ['eeg', 'meg']: + for ch_type in ["eeg", "meg"]: lut1, n_fact1 = _get_legen_table(ch_type, n_coeff=25, force_calc=True) - lut1 = lut1[:, :n - 1].copy() - n_fact1 = n_fact1[:n - 1].copy() + lut1 = lut1[:, : n - 1].copy() + n_fact1 = n_fact1[: n - 1].copy() lut2, n_fact2 = _get_legen_table(ch_type, n_coeff=n, force_calc=True) assert_allclose(lut1, lut2) assert_allclose(n_fact1, n_fact2) @@ -111,77 +122,93 @@ def test_legendre_table(): @testing.requires_testing_data def test_make_field_map_eeg(): """Test interpolation of EEG field onto head.""" - evoked = read_evokeds(evoked_fname, condition='Left Auditory') - evoked.info['bads'] = ['MEG 2443', 'EEG 053'] # add some bads - surf = get_head_surf('sample', subjects_dir=subjects_dir) + evoked = read_evokeds(evoked_fname, condition="Left Auditory") + evoked.info["bads"] = ["MEG 2443", "EEG 053"] # add some bads + surf = get_head_surf("sample", subjects_dir=subjects_dir) # we must have trans if surface is in MRI coords - pytest.raises(ValueError, _make_surface_mapping, evoked.info, surf, 'eeg') + pytest.raises(ValueError, _make_surface_mapping, evoked.info, surf, "eeg") evoked.pick_types(meg=False, eeg=True) - fmd = make_field_map(evoked, trans_fname, - subject='sample', subjects_dir=subjects_dir) + fmd = make_field_map( + evoked, trans_fname, subject="sample", subjects_dir=subjects_dir + ) # trans is necessary for EEG only - pytest.raises(RuntimeError, make_field_map, evoked, None, - subject='sample', subjects_dir=subjects_dir) - - fmd = make_field_map(evoked, trans_fname, - subject='sample', subjects_dir=subjects_dir) + pytest.raises( + RuntimeError, + make_field_map, + evoked, + None, + subject="sample", + subjects_dir=subjects_dir, + ) + + fmd = make_field_map( + evoked, trans_fname, subject="sample", subjects_dir=subjects_dir + ) assert len(fmd) == 1 - assert_array_equal(fmd[0]['data'].shape, (642, 59)) # maps data onto surf - assert len(fmd[0]['ch_names']) == 59 + assert_array_equal(fmd[0]["data"].shape, (642, 59)) # maps data onto surf + assert len(fmd[0]["ch_names"]) == 59 @testing.requires_testing_data @pytest.mark.slowtest def test_make_field_map_meg(): """Test interpolation of MEG field onto helmet | head.""" - evoked = read_evokeds(evoked_fname, condition='Left Auditory') + evoked = read_evokeds(evoked_fname, condition="Left Auditory") info = evoked.info surf = get_meg_helmet_surf(info) # let's reduce the number of channels by a bunch to speed it up - info['bads'] = info['ch_names'][:200] + info["bads"] = info["ch_names"][:200] # bad ch_type - pytest.raises(ValueError, _make_surface_mapping, info, surf, 'foo') + 
pytest.raises(ValueError, _make_surface_mapping, info, surf, "foo") # bad mode - pytest.raises(ValueError, _make_surface_mapping, info, surf, 'meg', - mode='foo') + pytest.raises(ValueError, _make_surface_mapping, info, surf, "meg", mode="foo") # no picks evoked_eeg = evoked.copy().pick_types(meg=False, eeg=True) - pytest.raises(RuntimeError, _make_surface_mapping, evoked_eeg.info, - surf, 'meg') + pytest.raises(RuntimeError, _make_surface_mapping, evoked_eeg.info, surf, "meg") # bad surface def - nn = surf['nn'] - del surf['nn'] - pytest.raises(KeyError, _make_surface_mapping, info, surf, 'meg') - surf['nn'] = nn - cf = surf['coord_frame'] - del surf['coord_frame'] - pytest.raises(KeyError, _make_surface_mapping, info, surf, 'meg') - surf['coord_frame'] = cf + nn = surf["nn"] + del surf["nn"] + pytest.raises(KeyError, _make_surface_mapping, info, surf, "meg") + surf["nn"] = nn + cf = surf["coord_frame"] + del surf["coord_frame"] + pytest.raises(KeyError, _make_surface_mapping, info, surf, "meg") + surf["coord_frame"] = cf # now do it with make_field_map evoked.pick_types(meg=True, eeg=False) evoked.info.normalize_proj() # avoid projection warnings - fmd = make_field_map(evoked, None, - subject='sample', subjects_dir=subjects_dir) - assert (len(fmd) == 1) - assert_array_equal(fmd[0]['data'].shape, (304, 106)) # maps data onto surf - assert len(fmd[0]['ch_names']) == 106 + fmd = make_field_map(evoked, None, subject="sample", subjects_dir=subjects_dir) + assert len(fmd) == 1 + assert_array_equal(fmd[0]["data"].shape, (304, 106)) # maps data onto surf + assert len(fmd[0]["ch_names"]) == 106 - pytest.raises(ValueError, make_field_map, evoked, ch_type='foobar') + pytest.raises(ValueError, make_field_map, evoked, ch_type="foobar") # now test the make_field_map on head surf for MEG evoked.pick_types(meg=True, eeg=False) evoked.info.normalize_proj() - fmd = make_field_map(evoked, trans_fname, meg_surf='head', - subject='sample', subjects_dir=subjects_dir) + fmd = make_field_map( + evoked, + trans_fname, + meg_surf="head", + subject="sample", + subjects_dir=subjects_dir, + ) assert len(fmd) == 1 - assert_array_equal(fmd[0]['data'].shape, (642, 106)) # maps data onto surf - assert len(fmd[0]['ch_names']) == 106 + assert_array_equal(fmd[0]["data"].shape, (642, 106)) # maps data onto surf + assert len(fmd[0]["ch_names"]) == 106 - pytest.raises(ValueError, make_field_map, evoked, meg_surf='foobar', - subjects_dir=subjects_dir, trans=trans_fname) + pytest.raises( + ValueError, + make_field_map, + evoked, + meg_surf="foobar", + subjects_dir=subjects_dir, + trans=trans_fname, + ) @testing.requires_testing_data @@ -192,31 +219,45 @@ def test_make_field_map_meeg(): picks = picks[::10] evoked.pick_channels([evoked.ch_names[p] for p in picks]) evoked.info.normalize_proj() - maps = make_field_map(evoked, trans_fname, subject='sample', - subjects_dir=subjects_dir, verbose='debug') - assert_equal(maps[0]['data'].shape, (642, 6)) # EEG->Head - assert_equal(maps[1]['data'].shape, (304, 31)) # MEG->Helmet + maps = make_field_map( + evoked, + trans_fname, + subject="sample", + subjects_dir=subjects_dir, + verbose="debug", + ) + assert_equal(maps[0]["data"].shape, (642, 6)) # EEG->Head + assert_equal(maps[1]["data"].shape, (304, 31)) # MEG->Helmet # reasonable ranges maxs = (1.2, 2.0) # before #4418, was (1.1, 2.0) mins = (-0.8, -1.3) # before #4418, was (-0.6, -1.2) assert_equal(len(maxs), len(maps)) for map_, max_, min_ in zip(maps, maxs, mins): - assert_allclose(map_['data'].max(), max_, rtol=5e-2) - 
assert_allclose(map_['data'].min(), min_, rtol=5e-2) + assert_allclose(map_["data"].max(), max_, rtol=5e-2) + assert_allclose(map_["data"].min(), min_, rtol=5e-2) # calculated from correct looking mapping on 2015/12/26 - assert_allclose(np.sqrt(np.sum(maps[0]['data'] ** 2)), 19.0903, # 16.6088, - atol=1e-3, rtol=1e-3) - assert_allclose(np.sqrt(np.sum(maps[1]['data'] ** 2)), 19.4748, # 20.1245, - atol=1e-3, rtol=1e-3) + assert_allclose( + np.sqrt(np.sum(maps[0]["data"] ** 2)), 19.0903, atol=1e-3, rtol=1e-3 # 16.6088, + ) + assert_allclose( + np.sqrt(np.sum(maps[1]["data"] ** 2)), 19.4748, atol=1e-3, rtol=1e-3 # 20.1245, + ) def _setup_args(info): """Configure args for test_as_meg_type_evoked.""" - coils = _create_meg_coils(info['chs'], 'normal', info['dev_head_t']) - int_rad, _, lut_fun, n_fact = _setup_dots('fast', info, coils, 'meg') - my_origin = np.array([0., 0., 0.04]) - args_dict = dict(intrad=int_rad, volume=False, coils1=coils, r0=my_origin, - ch_type='meg', lut=lut_fun, n_fact=n_fact) + coils = _create_meg_coils(info["chs"], "normal", info["dev_head_t"]) + int_rad, _, lut_fun, n_fact = _setup_dots("fast", info, coils, "meg") + my_origin = np.array([0.0, 0.0, 0.04]) + args_dict = dict( + intrad=int_rad, + volume=False, + coils1=coils, + r0=my_origin, + ch_type="meg", + lut=lut_fun, + n_fact=n_fact, + ) return args_dict @@ -226,23 +267,30 @@ def test_as_meg_type_evoked(): # validation tests raw = read_raw_fif(raw_fname) events = mne.find_events(raw) - picks = pick_types(raw.info, meg=True, eeg=True, stim=True, - ecg=True, eog=True, include=['STI 014'], - exclude='bads') + picks = pick_types( + raw.info, + meg=True, + eeg=True, + stim=True, + ecg=True, + eog=True, + include=["STI 014"], + exclude="bads", + ) epochs = mne.Epochs(raw, events, picks=picks) evoked = epochs.average() with pytest.raises(ValueError, match="Invalid value for the 'ch_type'"): - evoked.as_type('meg') + evoked.as_type("meg") with pytest.raises(ValueError, match="Invalid value for the 'ch_type'"): - evoked.copy().pick_types(meg='grad').as_type('meg') + evoked.copy().pick_types(meg="grad").as_type("meg") # channel names - ch_names = evoked.info['ch_names'] + ch_names = evoked.info["ch_names"] virt_evoked = evoked.copy().pick_channels(ch_names=ch_names[:10:1]) virt_evoked.info.normalize_proj() - virt_evoked = virt_evoked.as_type('mag') - assert (all(ch.endswith('_v') for ch in virt_evoked.info['ch_names'])) + virt_evoked = virt_evoked.as_type("mag") + assert all(ch.endswith("_v") for ch in virt_evoked.info["ch_names"]) # pick from and to channels evoked_from = evoked.copy().pick_channels(ch_names=ch_names[2:10:3]) @@ -252,8 +300,8 @@ def test_as_meg_type_evoked(): # set up things args1, args2 = _setup_args(info_from), _setup_args(info_to) - args1.update(coils2=args2['coils1']) - args2.update(coils2=args1['coils1']) + args1.update(coils2=args2["coils1"]) + args2.update(coils2=args1["coils1"]) # test cross dots cross_dots1 = _do_cross_dots(**args1) @@ -263,14 +311,13 @@ def test_as_meg_type_evoked(): # correlation test evoked = evoked.pick_channels(ch_names=ch_names[:10:]).copy() - data1 = evoked.pick_types(meg='grad').data.ravel() - data2 = evoked.as_type('grad').data.ravel() - assert (np.corrcoef(data1, data2)[0, 1] > 0.95) + data1 = evoked.pick_types(meg="grad").data.ravel() + data2 = evoked.as_type("grad").data.ravel() + assert np.corrcoef(data1, data2)[0, 1] > 0.95 # Do it with epochs - virt_epochs = \ - epochs.copy().load_data().pick_channels(ch_names=ch_names[:10:1]) + virt_epochs = 
epochs.copy().load_data().pick_channels(ch_names=ch_names[:10:1]) virt_epochs.info.normalize_proj() - virt_epochs = virt_epochs.as_type('mag') - assert (all(ch.endswith('_v') for ch in virt_epochs.info['ch_names'])) + virt_epochs = virt_epochs.as_type("mag") + assert all(ch.endswith("_v") for ch in virt_epochs.info["ch_names"]) assert_allclose(virt_epochs.get_data().mean(0), virt_evoked.data) diff --git a/mne/forward/tests/test_forward.py b/mne/forward/tests/test_forward.py index ff244d9e0bf..59f53b349c0 100644 --- a/mne/forward/tests/test_forward.py +++ b/mne/forward/tests/test_forward.py @@ -3,57 +3,65 @@ import pytest import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_equal, - assert_array_equal, assert_allclose) +from numpy.testing import ( + assert_array_almost_equal, + assert_equal, + assert_array_equal, + assert_allclose, +) from mne.datasets import testing -from mne import (read_forward_solution, apply_forward, apply_forward_raw, - average_forward_solutions, write_forward_solution, - convert_forward_solution, SourceEstimate, pick_types_forward, - read_evokeds, VectorSourceEstimate) +from mne import ( + read_forward_solution, + apply_forward, + apply_forward_raw, + average_forward_solutions, + write_forward_solution, + convert_forward_solution, + SourceEstimate, + pick_types_forward, + read_evokeds, + VectorSourceEstimate, +) from mne.io import read_info from mne.label import read_label from mne.utils import requires_mne, run_subprocess -from mne.forward import (restrict_forward_to_stc, restrict_forward_to_label, - Forward, is_fixed_orient, compute_orient_prior, - compute_depth_prior) +from mne.forward import ( + restrict_forward_to_stc, + restrict_forward_to_label, + Forward, + is_fixed_orient, + compute_orient_prior, + compute_depth_prior, +) from mne.channels import equalize_channels data_path = testing.data_path(download=False) -fname_meeg = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_meeg_grad = ( - data_path - / "MEG" - / "sample" - / "sample_audvis_trunc-meg-eeg-oct-2-grad-fwd.fif" + data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-2-grad-fwd.fif" ) fname_evoked = ( - Path(__file__).parent.parent.parent - / "io" - / "tests" - / "data" - / "test-ave.fif" + Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test-ave.fif" ) def assert_forward_allclose(f1, f2, rtol=1e-7): """Compare two potentially converted forward solutions.""" - assert_allclose(f1['sol']['data'], f2['sol']['data'], rtol=rtol) - assert f1['sol']['ncol'] == f2['sol']['ncol'] - assert f1['sol']['ncol'] == f1['sol']['data'].shape[1] - assert_allclose(f1['source_nn'], f2['source_nn'], rtol=rtol) - if f1['sol_grad'] is not None: - assert (f2['sol_grad'] is not None) - assert_allclose(f1['sol_grad']['data'], f2['sol_grad']['data']) - assert f1['sol_grad']['ncol'] == f2['sol_grad']['ncol'] - assert f1['sol_grad']['ncol'] == f1['sol_grad']['data'].shape[1] + assert_allclose(f1["sol"]["data"], f2["sol"]["data"], rtol=rtol) + assert f1["sol"]["ncol"] == f2["sol"]["ncol"] + assert f1["sol"]["ncol"] == f1["sol"]["data"].shape[1] + assert_allclose(f1["source_nn"], f2["source_nn"], rtol=rtol) + if f1["sol_grad"] is not None: + assert f2["sol_grad"] is not None + assert_allclose(f1["sol_grad"]["data"], f2["sol_grad"]["data"]) + assert f1["sol_grad"]["ncol"] == f2["sol_grad"]["ncol"] + assert f1["sol_grad"]["ncol"] == 
f1["sol_grad"]["data"].shape[1] else: - assert (f2['sol_grad'] is None) - assert f1['source_ori'] == f2['source_ori'] - assert f1['surf_ori'] == f2['surf_ori'] - assert f1['src'][0]['coord_frame'] == f1['src'][0]['coord_frame'] + assert f2["sol_grad"] is None + assert f1["source_ori"] == f2["source_ori"] + assert f1["surf_ori"] == f2["surf_ori"] + assert f1["src"][0]["coord_frame"] == f1["src"][0]["coord_frame"] @testing.requires_testing_data @@ -61,33 +69,33 @@ def test_convert_forward(): """Test converting forward solution between different representations.""" fwd = read_forward_solution(fname_meeg_grad) fwd_repr = repr(fwd) - assert ('306' in fwd_repr) - assert ('60' in fwd_repr) - assert (fwd_repr) - assert (isinstance(fwd, Forward)) + assert "306" in fwd_repr + assert "60" in fwd_repr + assert fwd_repr + assert isinstance(fwd, Forward) # look at surface orientation fwd_surf = convert_forward_solution(fwd, surf_ori=True) # go back fwd_new = convert_forward_solution(fwd_surf, surf_ori=False) - assert (repr(fwd_new)) - assert (isinstance(fwd_new, Forward)) + assert repr(fwd_new) + assert isinstance(fwd_new, Forward) assert_forward_allclose(fwd, fwd_new) del fwd_new gc.collect() # now go to fixed - fwd_fixed = convert_forward_solution(fwd_surf, surf_ori=True, - force_fixed=True, use_cps=False) + fwd_fixed = convert_forward_solution( + fwd_surf, surf_ori=True, force_fixed=True, use_cps=False + ) del fwd_surf gc.collect() - assert (repr(fwd_fixed)) - assert (isinstance(fwd_fixed, Forward)) - assert (is_fixed_orient(fwd_fixed)) + assert repr(fwd_fixed) + assert isinstance(fwd_fixed, Forward) + assert is_fixed_orient(fwd_fixed) # now go back to cartesian (original condition) - fwd_new = convert_forward_solution(fwd_fixed, surf_ori=False, - force_fixed=False) - assert (repr(fwd_new)) - assert (isinstance(fwd_new, Forward)) + fwd_new = convert_forward_solution(fwd_fixed, surf_ori=False, force_fixed=False) + assert repr(fwd_new) + assert isinstance(fwd_new, Forward) assert_forward_allclose(fwd, fwd_new) del fwd, fwd_new, fwd_fixed gc.collect() @@ -100,86 +108,86 @@ def test_io_forward(tmp_path): # do extensive tests with MEEG + grad n_channels, n_src = 366, 108 fwd = read_forward_solution(fname_meeg_grad) - assert (isinstance(fwd, Forward)) + assert isinstance(fwd, Forward) fwd = read_forward_solution(fname_meeg_grad) fwd = convert_forward_solution(fwd, surf_ori=True) - leadfield = fwd['sol']['data'] + leadfield = fwd["sol"]["data"] assert_equal(leadfield.shape, (n_channels, n_src)) - assert_equal(len(fwd['sol']['row_names']), n_channels) - fname_temp = tmp_path / 'test-fwd.fif' - with pytest.warns(RuntimeWarning, match='stored on disk'): + assert_equal(len(fwd["sol"]["row_names"]), n_channels) + fname_temp = tmp_path / "test-fwd.fif" + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_temp, fwd, overwrite=True) fwd = read_forward_solution(fname_meeg_grad) fwd = convert_forward_solution(fwd, surf_ori=True) fwd_read = read_forward_solution(fname_temp) fwd_read = convert_forward_solution(fwd_read, surf_ori=True) - leadfield = fwd_read['sol']['data'] + leadfield = fwd_read["sol"]["data"] assert_equal(leadfield.shape, (n_channels, n_src)) - assert_equal(len(fwd_read['sol']['row_names']), n_channels) - assert_equal(len(fwd_read['info']['chs']), n_channels) - assert ('dev_head_t' in fwd_read['info']) - assert ('mri_head_t' in fwd_read) - assert_array_almost_equal(fwd['sol']['data'], fwd_read['sol']['data']) + assert_equal(len(fwd_read["sol"]["row_names"]), 
n_channels) + assert_equal(len(fwd_read["info"]["chs"]), n_channels) + assert "dev_head_t" in fwd_read["info"] + assert "mri_head_t" in fwd_read + assert_array_almost_equal(fwd["sol"]["data"], fwd_read["sol"]["data"]) fwd = read_forward_solution(fname_meeg) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=False) - with pytest.warns(RuntimeWarning, match='stored on disk'): + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=False) + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_temp, fwd, overwrite=True) fwd_read = read_forward_solution(fname_temp) - fwd_read = convert_forward_solution(fwd_read, surf_ori=True, - force_fixed=True, use_cps=False) - assert (repr(fwd_read)) - assert (isinstance(fwd_read, Forward)) - assert (is_fixed_orient(fwd_read)) + fwd_read = convert_forward_solution( + fwd_read, surf_ori=True, force_fixed=True, use_cps=False + ) + assert repr(fwd_read) + assert isinstance(fwd_read, Forward) + assert is_fixed_orient(fwd_read) assert_forward_allclose(fwd, fwd_read) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) - leadfield = fwd['sol']['data'] + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) + leadfield = fwd["sol"]["data"] assert_equal(leadfield.shape, (n_channels, 1494 / 3)) - assert_equal(len(fwd['sol']['row_names']), n_channels) - assert_equal(len(fwd['info']['chs']), n_channels) - assert ('dev_head_t' in fwd['info']) - assert ('mri_head_t' in fwd) - assert (fwd['surf_ori']) - with pytest.warns(RuntimeWarning, match='stored on disk'): + assert_equal(len(fwd["sol"]["row_names"]), n_channels) + assert_equal(len(fwd["info"]["chs"]), n_channels) + assert "dev_head_t" in fwd["info"] + assert "mri_head_t" in fwd + assert fwd["surf_ori"] + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_temp, fwd, overwrite=True) fwd_read = read_forward_solution(fname_temp) - fwd_read = convert_forward_solution(fwd_read, surf_ori=True, - force_fixed=True, use_cps=True) - assert (repr(fwd_read)) - assert (isinstance(fwd_read, Forward)) - assert (is_fixed_orient(fwd_read)) + fwd_read = convert_forward_solution( + fwd_read, surf_ori=True, force_fixed=True, use_cps=True + ) + assert repr(fwd_read) + assert isinstance(fwd_read, Forward) + assert is_fixed_orient(fwd_read) assert_forward_allclose(fwd, fwd_read) fwd = read_forward_solution(fname_meeg_grad) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) - leadfield = fwd['sol']['data'] + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) + leadfield = fwd["sol"]["data"] assert_equal(leadfield.shape, (n_channels, n_src / 3)) - assert_equal(len(fwd['sol']['row_names']), n_channels) - assert_equal(len(fwd['info']['chs']), n_channels) - assert ('dev_head_t' in fwd['info']) - assert ('mri_head_t' in fwd) - assert (fwd['surf_ori']) - with pytest.warns(RuntimeWarning, match='stored on disk'): + assert_equal(len(fwd["sol"]["row_names"]), n_channels) + assert_equal(len(fwd["info"]["chs"]), n_channels) + assert "dev_head_t" in fwd["info"] + assert "mri_head_t" in fwd + assert fwd["surf_ori"] + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_temp, fwd, overwrite=True) fwd_read = read_forward_solution(fname_temp) - fwd_read = convert_forward_solution(fwd_read, surf_ori=True, - force_fixed=True, use_cps=True) - assert (repr(fwd_read)) - assert 
(isinstance(fwd_read, Forward)) - assert (is_fixed_orient(fwd_read)) + fwd_read = convert_forward_solution( + fwd_read, surf_ori=True, force_fixed=True, use_cps=True + ) + assert repr(fwd_read) + assert isinstance(fwd_read, Forward) + assert is_fixed_orient(fwd_read) assert_forward_allclose(fwd, fwd_read) # test warnings on bad filenames fwd = read_forward_solution(fname_meeg_grad) - fwd_badname = tmp_path / 'test-bad-name.fif.gz' - with pytest.warns(RuntimeWarning, match='end with'): + fwd_badname = tmp_path / "test-bad-name.fif.gz" + with pytest.warns(RuntimeWarning, match="end with"): write_forward_solution(fwd_badname, fwd) - with pytest.warns(RuntimeWarning, match='end with'): + with pytest.warns(RuntimeWarning, match="end with"): read_forward_solution(fwd_badname) fwd = read_forward_solution(fname_meeg) @@ -198,53 +206,55 @@ def test_apply_forward(): t_start = 0.123 fwd = read_forward_solution(fname_meeg) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) fwd = pick_types_forward(fwd, meg=True) assert isinstance(fwd, Forward) - vertno = [fwd['src'][0]['vertno'], fwd['src'][1]['vertno']] + vertno = [fwd["src"][0]["vertno"], fwd["src"][1]["vertno"]] stc_data = np.ones((len(vertno[0]) + len(vertno[1]), n_times)) stc = SourceEstimate(stc_data, vertno, tmin=t_start, tstep=1.0 / sfreq) - gain_sum = np.sum(fwd['sol']['data'], axis=1) + gain_sum = np.sum(fwd["sol"]["data"], axis=1) # Evoked evoked = read_evokeds(fname_evoked, condition=0) evoked.pick_types(meg=True) - with pytest.warns(RuntimeWarning, match='only .* positive values'): + with pytest.warns(RuntimeWarning, match="only .* positive values"): evoked = apply_forward(fwd, stc, evoked.info, start=start, stop=stop) data = evoked.data times = evoked.times # do some tests - assert_array_almost_equal(evoked.info['sfreq'], sfreq) + assert_array_almost_equal(evoked.info["sfreq"], sfreq) assert_array_almost_equal(np.sum(data, axis=1), n_times * gain_sum) assert_array_almost_equal(times[0], t_start) assert_array_almost_equal(times[-1], t_start + (n_times - 1) / sfreq) # vector stc_vec = VectorSourceEstimate( - fwd['source_nn'][:, :, np.newaxis] * stc.data[:, np.newaxis], - stc.vertices, stc.tmin, stc.tstep) - with pytest.warns(RuntimeWarning, match='very large'): + fwd["source_nn"][:, :, np.newaxis] * stc.data[:, np.newaxis], + stc.vertices, + stc.tmin, + stc.tstep, + ) + with pytest.warns(RuntimeWarning, match="very large"): evoked_2 = apply_forward(fwd, stc_vec, evoked.info) assert np.abs(evoked_2.data).mean() > 1e-5 assert_allclose(evoked.data, evoked_2.data, atol=1e-10) # Raw - with pytest.warns(RuntimeWarning, match='only .* positive values'): - raw_proj = apply_forward_raw(fwd, stc, evoked.info, start=start, - stop=stop) + with pytest.warns(RuntimeWarning, match="only .* positive values"): + raw_proj = apply_forward_raw(fwd, stc, evoked.info, start=start, stop=stop) data, times = raw_proj[:, :] # do some tests - assert_array_almost_equal(raw_proj.info['sfreq'], sfreq) + assert_array_almost_equal(raw_proj.info["sfreq"], sfreq) assert_array_almost_equal(np.sum(data, axis=1), n_times * gain_sum) - atol = 1. 
/ sfreq + atol = 1.0 / sfreq assert_allclose(raw_proj.first_samp / sfreq, t_start, atol=atol) - assert_allclose(raw_proj.last_samp / sfreq, - t_start + (n_times - 1) / sfreq, atol=atol) + assert_allclose( + raw_proj.last_samp / sfreq, t_start + (n_times - 1) / sfreq, atol=atol + ) @testing.requires_testing_data @@ -257,47 +267,47 @@ def test_restrict_forward_to_stc(tmp_path): t_start = 0.123 fwd = read_forward_solution(fname_meeg) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) fwd = pick_types_forward(fwd, meg=True) - vertno = [fwd['src'][0]['vertno'][0:15], fwd['src'][1]['vertno'][0:5]] + vertno = [fwd["src"][0]["vertno"][0:15], fwd["src"][1]["vertno"][0:5]] stc_data = np.ones((len(vertno[0]) + len(vertno[1]), n_times)) stc = SourceEstimate(stc_data, vertno, tmin=t_start, tstep=1.0 / sfreq) fwd_out = restrict_forward_to_stc(fwd, stc) - assert (isinstance(fwd_out, Forward)) + assert isinstance(fwd_out, Forward) - assert_equal(fwd_out['sol']['ncol'], 20) - assert_equal(fwd_out['src'][0]['nuse'], 15) - assert_equal(fwd_out['src'][1]['nuse'], 5) - assert_equal(fwd_out['src'][0]['vertno'], fwd['src'][0]['vertno'][0:15]) - assert_equal(fwd_out['src'][1]['vertno'], fwd['src'][1]['vertno'][0:5]) + assert_equal(fwd_out["sol"]["ncol"], 20) + assert_equal(fwd_out["src"][0]["nuse"], 15) + assert_equal(fwd_out["src"][1]["nuse"], 5) + assert_equal(fwd_out["src"][0]["vertno"], fwd["src"][0]["vertno"][0:15]) + assert_equal(fwd_out["src"][1]["vertno"], fwd["src"][1]["vertno"][0:5]) fwd = read_forward_solution(fname_meeg) fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=False) fwd = pick_types_forward(fwd, meg=True) - vertno = [fwd['src'][0]['vertno'][0:15], fwd['src'][1]['vertno'][0:5]] + vertno = [fwd["src"][0]["vertno"][0:15], fwd["src"][1]["vertno"][0:5]] stc_data = np.ones((len(vertno[0]) + len(vertno[1]), n_times)) stc = SourceEstimate(stc_data, vertno, tmin=t_start, tstep=1.0 / sfreq) fwd_out = restrict_forward_to_stc(fwd, stc) - assert_equal(fwd_out['sol']['ncol'], 60) - assert_equal(fwd_out['src'][0]['nuse'], 15) - assert_equal(fwd_out['src'][1]['nuse'], 5) - assert_equal(fwd_out['src'][0]['vertno'], fwd['src'][0]['vertno'][0:15]) - assert_equal(fwd_out['src'][1]['vertno'], fwd['src'][1]['vertno'][0:5]) + assert_equal(fwd_out["sol"]["ncol"], 60) + assert_equal(fwd_out["src"][0]["nuse"], 15) + assert_equal(fwd_out["src"][1]["nuse"], 5) + assert_equal(fwd_out["src"][0]["vertno"], fwd["src"][0]["vertno"][0:15]) + assert_equal(fwd_out["src"][1]["vertno"], fwd["src"][1]["vertno"][0:5]) # Test saving the restricted forward object. This only works if all fields # are properly accounted for. 
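    # (a stale or missing entry in fwd_out would surface in this round trip
    # as a write, read, or allclose failure)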
- fname_copy = tmp_path / 'copy-fwd.fif' - with pytest.warns(RuntimeWarning, match='stored on disk'): + fname_copy = tmp_path / "copy-fwd.fif" + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_copy, fwd_out, overwrite=True) fwd_out_read = read_forward_solution(fname_copy) - fwd_out_read = convert_forward_solution(fwd_out_read, surf_ori=True, - force_fixed=False) + fwd_out_read = convert_forward_solution( + fwd_out_read, surf_ori=True, force_fixed=False + ) assert_forward_allclose(fwd_out, fwd_out_read) @@ -305,63 +315,61 @@ def test_restrict_forward_to_stc(tmp_path): def test_restrict_forward_to_label(tmp_path): """Test restriction of source space to label.""" fwd = read_forward_solution(fname_meeg) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) fwd = pick_types_forward(fwd, meg=True) label_path = data_path / "MEG" / "sample" / "labels" - labels = ['Aud-lh', 'Vis-rh'] + labels = ["Aud-lh", "Vis-rh"] label_lh = read_label(label_path / (labels[0] + ".label")) label_rh = read_label(label_path / (labels[1] + ".label")) fwd_out = restrict_forward_to_label(fwd, [label_lh, label_rh]) - src_sel_lh = np.intersect1d(fwd['src'][0]['vertno'], label_lh.vertices) - src_sel_lh = np.searchsorted(fwd['src'][0]['vertno'], src_sel_lh) - vertno_lh = fwd['src'][0]['vertno'][src_sel_lh] + src_sel_lh = np.intersect1d(fwd["src"][0]["vertno"], label_lh.vertices) + src_sel_lh = np.searchsorted(fwd["src"][0]["vertno"], src_sel_lh) + vertno_lh = fwd["src"][0]["vertno"][src_sel_lh] - nuse_lh = fwd['src'][0]['nuse'] - src_sel_rh = np.intersect1d(fwd['src'][1]['vertno'], label_rh.vertices) - src_sel_rh = np.searchsorted(fwd['src'][1]['vertno'], src_sel_rh) - vertno_rh = fwd['src'][1]['vertno'][src_sel_rh] + nuse_lh = fwd["src"][0]["nuse"] + src_sel_rh = np.intersect1d(fwd["src"][1]["vertno"], label_rh.vertices) + src_sel_rh = np.searchsorted(fwd["src"][1]["vertno"], src_sel_rh) + vertno_rh = fwd["src"][1]["vertno"][src_sel_rh] src_sel_rh += nuse_lh - assert_equal(fwd_out['sol']['ncol'], len(src_sel_lh) + len(src_sel_rh)) - assert_equal(fwd_out['src'][0]['nuse'], len(src_sel_lh)) - assert_equal(fwd_out['src'][1]['nuse'], len(src_sel_rh)) - assert_equal(fwd_out['src'][0]['vertno'], vertno_lh) - assert_equal(fwd_out['src'][1]['vertno'], vertno_rh) + assert_equal(fwd_out["sol"]["ncol"], len(src_sel_lh) + len(src_sel_rh)) + assert_equal(fwd_out["src"][0]["nuse"], len(src_sel_lh)) + assert_equal(fwd_out["src"][1]["nuse"], len(src_sel_rh)) + assert_equal(fwd_out["src"][0]["vertno"], vertno_lh) + assert_equal(fwd_out["src"][1]["vertno"], vertno_rh) fwd = read_forward_solution(fname_meeg) fwd = pick_types_forward(fwd, meg=True) label_path = data_path / "MEG" / "sample" / "labels" - labels = ['Aud-lh', 'Vis-rh'] + labels = ["Aud-lh", "Vis-rh"] label_lh = read_label(label_path / (labels[0] + ".label")) label_rh = read_label(label_path / (labels[1] + ".label")) fwd_out = restrict_forward_to_label(fwd, [label_lh, label_rh]) - src_sel_lh = np.intersect1d(fwd['src'][0]['vertno'], label_lh.vertices) - src_sel_lh = np.searchsorted(fwd['src'][0]['vertno'], src_sel_lh) - vertno_lh = fwd['src'][0]['vertno'][src_sel_lh] + src_sel_lh = np.intersect1d(fwd["src"][0]["vertno"], label_lh.vertices) + src_sel_lh = np.searchsorted(fwd["src"][0]["vertno"], src_sel_lh) + vertno_lh = fwd["src"][0]["vertno"][src_sel_lh] - nuse_lh = fwd['src'][0]['nuse'] - src_sel_rh = 
np.intersect1d(fwd['src'][1]['vertno'], label_rh.vertices) - src_sel_rh = np.searchsorted(fwd['src'][1]['vertno'], src_sel_rh) - vertno_rh = fwd['src'][1]['vertno'][src_sel_rh] + nuse_lh = fwd["src"][0]["nuse"] + src_sel_rh = np.intersect1d(fwd["src"][1]["vertno"], label_rh.vertices) + src_sel_rh = np.searchsorted(fwd["src"][1]["vertno"], src_sel_rh) + vertno_rh = fwd["src"][1]["vertno"][src_sel_rh] src_sel_rh += nuse_lh - assert_equal(fwd_out['sol']['ncol'], - 3 * (len(src_sel_lh) + len(src_sel_rh))) - assert_equal(fwd_out['src'][0]['nuse'], len(src_sel_lh)) - assert_equal(fwd_out['src'][1]['nuse'], len(src_sel_rh)) - assert_equal(fwd_out['src'][0]['vertno'], vertno_lh) - assert_equal(fwd_out['src'][1]['vertno'], vertno_rh) + assert_equal(fwd_out["sol"]["ncol"], 3 * (len(src_sel_lh) + len(src_sel_rh))) + assert_equal(fwd_out["src"][0]["nuse"], len(src_sel_lh)) + assert_equal(fwd_out["src"][1]["nuse"], len(src_sel_rh)) + assert_equal(fwd_out["src"][0]["vertno"], vertno_lh) + assert_equal(fwd_out["src"][1]["vertno"], vertno_rh) # Test saving the restricted forward object. This only works if all fields # are properly accounted for. - fname_copy = tmp_path / 'copy-fwd.fif' + fname_copy = tmp_path / "copy-fwd.fif" write_forward_solution(fname_copy, fwd_out, overwrite=True) fwd_out_read = read_forward_solution(fname_copy) assert_forward_allclose(fwd_out, fwd_out_read) @@ -387,20 +395,27 @@ def test_average_forward_solution(tmp_path): # try an easy case fwd_copy = average_forward_solutions([fwd]) - assert (isinstance(fwd_copy, Forward)) - assert_array_equal(fwd['sol']['data'], fwd_copy['sol']['data']) + assert isinstance(fwd_copy, Forward) + assert_array_equal(fwd["sol"]["data"], fwd_copy["sol"]["data"]) # modify a fwd solution, save it, use MNE to average with old one - fwd_copy['sol']['data'] *= 0.5 - fname_copy = str(tmp_path / 'copy-fwd.fif') + fwd_copy["sol"]["data"] *= 0.5 + fname_copy = str(tmp_path / "copy-fwd.fif") write_forward_solution(fname_copy, fwd_copy, overwrite=True) - cmd = ('mne_average_forward_solutions', '--fwd', fname_meeg, '--fwd', - fname_copy, '--out', fname_copy) + cmd = ( + "mne_average_forward_solutions", + "--fwd", + fname_meeg, + "--fwd", + fname_copy, + "--out", + fname_copy, + ) run_subprocess(cmd) # now let's actually do it, with one filename and one fwd fwd_ave = average_forward_solutions([fwd, fwd_copy]) - assert_array_equal(0.75 * fwd['sol']['data'], fwd_ave['sol']['data']) + assert_array_equal(0.75 * fwd["sol"]["data"], fwd_ave["sol"]["data"]) # fwd_ave_mne = read_forward_solution(fname_copy) # assert_array_equal(fwd_ave_mne['sol']['data'], fwd_ave['sol']['data']) @@ -416,32 +431,32 @@ def test_priors(): # Depth prior fwd = read_forward_solution(fname_meeg) assert not is_fixed_orient(fwd) - n_sources = fwd['nsource'] + n_sources = fwd["nsource"] info = read_info(fname_evoked) depth_prior = compute_depth_prior(fwd, info, exp=0.8) assert depth_prior.shape == (3 * n_sources,) - depth_prior = compute_depth_prior(fwd, info, exp=0.) - assert_array_equal(depth_prior, 1.) 
+ depth_prior = compute_depth_prior(fwd, info, exp=0.0) + assert_array_equal(depth_prior, 1.0) with pytest.raises(ValueError, match='must be "whiten"'): - compute_depth_prior(fwd, info, limit_depth_chs='foo') - with pytest.raises(ValueError, match='noise_cov must be a Covariance'): - compute_depth_prior(fwd, info, limit_depth_chs='whiten') + compute_depth_prior(fwd, info, limit_depth_chs="foo") + with pytest.raises(ValueError, match="noise_cov must be a Covariance"): + compute_depth_prior(fwd, info, limit_depth_chs="whiten") fwd_fixed = convert_forward_solution(fwd, force_fixed=True) depth_prior = compute_depth_prior(fwd_fixed, info=info) assert depth_prior.shape == (n_sources,) # Orientation prior - orient_prior = compute_orient_prior(fwd, 1.) - assert_array_equal(orient_prior, 1.) - orient_prior = compute_orient_prior(fwd_fixed, 0.) - assert_array_equal(orient_prior, 1.) - with pytest.raises(ValueError, match='oriented in surface coordinates'): + orient_prior = compute_orient_prior(fwd, 1.0) + assert_array_equal(orient_prior, 1.0) + orient_prior = compute_orient_prior(fwd_fixed, 0.0) + assert_array_equal(orient_prior, 1.0) + with pytest.raises(ValueError, match="oriented in surface coordinates"): compute_orient_prior(fwd, 0.5) fwd_surf_ori = convert_forward_solution(fwd, surf_ori=True) orient_prior = compute_orient_prior(fwd_surf_ori, 0.5) - assert all(np.in1d(orient_prior, (0.5, 1.))) - with pytest.raises(ValueError, match='between 0 and 1'): + assert all(np.in1d(orient_prior, (0.5, 1.0))) + with pytest.raises(ValueError, match="between 0 and 1"): compute_orient_prior(fwd_surf_ori, -0.5) - with pytest.raises(ValueError, match='with fixed orientation'): + with pytest.raises(ValueError, match="with fixed orientation"): compute_orient_prior(fwd_fixed, 0.5) @@ -449,8 +464,8 @@ def test_priors(): def test_equalize_channels(): """Test equalization of channels for instances of Forward.""" fwd1 = read_forward_solution(fname_meeg) - fwd1.pick_channels(['EEG 001', 'EEG 002', 'EEG 003']) - fwd2 = fwd1.copy().pick_channels(['EEG 002', 'EEG 001'], ordered=True) + fwd1.pick_channels(["EEG 001", "EEG 002", "EEG 003"]) + fwd2 = fwd1.copy().pick_channels(["EEG 002", "EEG 001"], ordered=True) fwd1, fwd2 = equalize_channels([fwd1, fwd2]) - assert fwd1.ch_names == ['EEG 001', 'EEG 002'] - assert fwd2.ch_names == ['EEG 001', 'EEG 002'] + assert fwd1.ch_names == ["EEG 001", "EEG 002"] + assert fwd2.ch_names == ["EEG 001", "EEG 002"] diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py index b23a1ec2f6e..627822aca95 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -11,37 +11,50 @@ from mne.datasets import testing from mne.io import read_raw_fif, read_raw_kit, read_raw_bti, read_info from mne.io.constants import FIFF -from mne import (read_forward_solution, write_forward_solution, - make_forward_solution, convert_forward_solution, - setup_volume_source_space, read_source_spaces, create_info, - make_sphere_model, pick_types_forward, pick_info, pick_types, - read_evokeds, read_cov, read_dipole, - get_volume_labels_from_aseg) +from mne import ( + read_forward_solution, + write_forward_solution, + make_forward_solution, + convert_forward_solution, + setup_volume_source_space, + read_source_spaces, + create_info, + make_sphere_model, + pick_types_forward, + pick_info, + pick_types, + read_evokeds, + read_cov, + read_dipole, + get_volume_labels_from_aseg, +) from mne.surface import _get_ico_surface from mne.transforms import 
Transform -from mne.utils import (requires_mne, run_subprocess, catch_logging, - requires_mne_mark, requires_openmeeg_mark) +from mne.utils import ( + requires_mne, + run_subprocess, + catch_logging, + requires_mne_mark, + requires_openmeeg_mark, +) from mne.forward._make_forward import _create_meg_coils, make_forward_dipole from mne.forward._compute_forward import _magnetic_dipole_field_vec from mne.forward import Forward, _do_forward_solution, use_coil_def from mne.dipole import Dipole, fit_dipole from mne.simulation import simulate_evoked from mne.source_estimate import VolSourceEstimate -from mne.source_space import (write_source_spaces, _compare_source_spaces, - setup_source_space) +from mne.source_space import ( + write_source_spaces, + _compare_source_spaces, + setup_source_space, +) from mne.forward.tests.test_forward import assert_forward_allclose data_path = testing.data_path(download=False) -fname_meeg = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_raw = ( - Path(__file__).parent.parent.parent - / "io" - / "tests" - / "data" - / "test_raw.fif" + Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif" ) fname_evo = data_path / "MEG" / "sample" / "sample_audvis_trunc-ave.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" @@ -49,9 +62,7 @@ fname_trans = data_path / "MEG" / "sample" / "sample_audvis_trunc-trans.fif" subjects_dir = data_path / "subjects" fname_src = subjects_dir / "sample" / "bem" / "sample-oct-4-src.fif" -fname_bem = ( - subjects_dir / "sample" / "bem" / "sample-1280-1280-1280-bem-sol.fif" -) +fname_bem = subjects_dir / "sample" / "bem" / "sample-1280-1280-1280-bem-sol.fif" fname_aseg = subjects_dir / "sample" / "mri" / "aseg.mgz" fname_bem_meg = subjects_dir / "sample" / "bem" / "sample-1280-bem-sol.fif" @@ -70,9 +81,9 @@ def _col_corrs(a, b): a_std = np.sqrt((a * a).mean(0)) b_std = np.sqrt((b * b).mean(0)) all_zero = (a_std == 0) & (b_std == 0) - num[all_zero] = 1. - a_std[all_zero] = 1. - b_std[all_zero] = 1. + num[all_zero] = 1.0 + a_std[all_zero] = 1.0 + b_std[all_zero] = 1.0 return num / (a_std * b_std) @@ -81,67 +92,94 @@ def _rdm(a, b): a_norm = np.linalg.norm(a, axis=0) b_norm = np.linalg.norm(b, axis=0) all_zero = (a_norm == 0) & (b_norm == 0) - a_norm[all_zero] = 1. - b_norm[all_zero] = 1. 
+ a_norm[all_zero] = 1.0 + b_norm[all_zero] = 1.0 return a_norm / b_norm -def _compare_forwards(fwd, fwd_py, n_sensors, n_src, - meg_rtol=1e-4, meg_atol=1e-9, - meg_corr_tol=0.99, meg_rdm_tol=0.01, - eeg_rtol=1e-3, eeg_atol=1e-3, - eeg_corr_tol=0.99, eeg_rdm_tol=0.01): +def _compare_forwards( + fwd, + fwd_py, + n_sensors, + n_src, + meg_rtol=1e-4, + meg_atol=1e-9, + meg_corr_tol=0.99, + meg_rdm_tol=0.01, + eeg_rtol=1e-3, + eeg_atol=1e-3, + eeg_corr_tol=0.99, + eeg_rdm_tol=0.01, +): """Test forwards.""" # check source spaces - assert len(fwd['src']) == len(fwd_py['src']) - _compare_source_spaces(fwd['src'], fwd_py['src'], mode='approx') + assert len(fwd["src"]) == len(fwd_py["src"]) + _compare_source_spaces(fwd["src"], fwd_py["src"], mode="approx") for surf_ori, force_fixed in product([False, True], [False, True]): # use copy here to leave our originals unmodified - fwd = convert_forward_solution(fwd, surf_ori, force_fixed, copy=True, - use_cps=True) - fwd_py = convert_forward_solution(fwd_py, surf_ori, force_fixed, - copy=True, use_cps=True) + fwd = convert_forward_solution( + fwd, surf_ori, force_fixed, copy=True, use_cps=True + ) + fwd_py = convert_forward_solution( + fwd_py, surf_ori, force_fixed, copy=True, use_cps=True + ) check_src = n_src // 3 if force_fixed else n_src - for key in ('nchan', 'source_rr', 'source_ori', - 'surf_ori', 'coord_frame', 'nsource'): - assert_allclose(fwd_py[key], fwd[key], rtol=1e-4, atol=1e-7, - err_msg=key) + for key in ( + "nchan", + "source_rr", + "source_ori", + "surf_ori", + "coord_frame", + "nsource", + ): + assert_allclose(fwd_py[key], fwd[key], rtol=1e-4, atol=1e-7, err_msg=key) # In surf_ori=True only Z matters for source_nn if surf_ori and not force_fixed: ori_sl = slice(2, None, 3) else: ori_sl = slice(None) - assert_allclose(fwd_py['source_nn'][ori_sl], fwd['source_nn'][ori_sl], - rtol=1e-4, atol=1e-6) - assert_allclose(fwd_py['mri_head_t']['trans'], - fwd['mri_head_t']['trans'], rtol=1e-5, atol=1e-8) - - assert fwd_py['sol']['data'].shape == (n_sensors, check_src) - assert len(fwd['sol']['row_names']) == n_sensors - assert len(fwd_py['sol']['row_names']) == n_sensors + assert_allclose( + fwd_py["source_nn"][ori_sl], fwd["source_nn"][ori_sl], rtol=1e-4, atol=1e-6 + ) + assert_allclose( + fwd_py["mri_head_t"]["trans"], + fwd["mri_head_t"]["trans"], + rtol=1e-5, + atol=1e-8, + ) + + assert fwd_py["sol"]["data"].shape == (n_sensors, check_src) + assert len(fwd["sol"]["row_names"]) == n_sensors + assert len(fwd_py["sol"]["row_names"]) == n_sensors # check MEG - fwd_meg = fwd['sol']['data'][:306, ori_sl] - fwd_meg_py = fwd_py['sol']['data'][:306, ori_sl] - assert_allclose(fwd_meg, fwd_meg_py, rtol=meg_rtol, atol=meg_atol, - err_msg='MEG mismatch') + fwd_meg = fwd["sol"]["data"][:306, ori_sl] + fwd_meg_py = fwd_py["sol"]["data"][:306, ori_sl] + assert_allclose( + fwd_meg, fwd_meg_py, rtol=meg_rtol, atol=meg_atol, err_msg="MEG mismatch" + ) meg_corrs = _col_corrs(fwd_meg, fwd_meg_py) - assert_array_less(meg_corr_tol, meg_corrs, err_msg='MEG corr/MAG') + assert_array_less(meg_corr_tol, meg_corrs, err_msg="MEG corr/MAG") meg_rdm = _rdm(fwd_meg, fwd_meg_py) - assert_allclose(meg_rdm, 1, atol=meg_rdm_tol, err_msg='MEG RDM') + assert_allclose(meg_rdm, 1, atol=meg_rdm_tol, err_msg="MEG RDM") # check EEG - if fwd['sol']['data'].shape[0] > 306: - fwd_eeg = fwd['sol']['data'][306:, ori_sl] - fwd_eeg_py = fwd['sol']['data'][306:, ori_sl] - assert_allclose(fwd_eeg, fwd_eeg_py, rtol=eeg_rtol, atol=eeg_atol, - err_msg='EEG mismatch') + if 
fwd["sol"]["data"].shape[0] > 306: + fwd_eeg = fwd["sol"]["data"][306:, ori_sl] + fwd_eeg_py = fwd["sol"]["data"][306:, ori_sl] + assert_allclose( + fwd_eeg, + fwd_eeg_py, + rtol=eeg_rtol, + atol=eeg_atol, + err_msg="EEG mismatch", + ) # To test so-called MAG we use correlation (related to cosine # similarity) and also RDM to test the amplitude mismatch eeg_corrs = _col_corrs(fwd_eeg, fwd_eeg_py) - assert_array_less(eeg_corr_tol, eeg_corrs, err_msg='EEG corr/MAG') + assert_array_less(eeg_corr_tol, eeg_corrs, err_msg="EEG corr/MAG") eeg_rdm = _rdm(fwd_eeg, fwd_eeg_py) - assert_allclose(eeg_rdm, 1, atol=eeg_rdm_tol, err_msg='EEG RDM') + assert_allclose(eeg_rdm, 1, atol=eeg_rdm_tol, err_msg="EEG RDM") def test_magnetic_dipole(): @@ -149,24 +187,24 @@ def test_magnetic_dipole(): info = read_info(fname_raw) picks = pick_types(info, meg=True, eeg=False, exclude=[]) info = pick_info(info, picks[:12]) - coils = _create_meg_coils(info['chs'], 'normal', None) + coils = _create_meg_coils(info["chs"], "normal", None) # magnetic dipole far (meters!) from device origin - r0 = np.array([0., 13., -6.]) - for ch, coil in zip(info['chs'], coils): - rr = (ch['loc'][:3] + r0) / 2. # get halfway closer + r0 = np.array([0.0, 13.0, -6.0]) + for ch, coil in zip(info["chs"], coils): + rr = (ch["loc"][:3] + r0) / 2.0 # get halfway closer far_fwd = _magnetic_dipole_field_vec(r0[np.newaxis, :], [coil]) near_fwd = _magnetic_dipole_field_vec(rr[np.newaxis, :], [coil]) - ratio = 8. if ch['ch_name'][-1] == '1' else 16. # grad vs mag + ratio = 8.0 if ch["ch_name"][-1] == "1" else 16.0 # grad vs mag assert_allclose(np.median(near_fwd / far_fwd), ratio, atol=1e-1) # degenerate case - r0 = coils[0]['rmag'][[0]] - with pytest.raises(RuntimeError, match='Coil too close'): + r0 = coils[0]["rmag"][[0]] + with pytest.raises(RuntimeError, match="Coil too close"): _magnetic_dipole_field_vec(r0, coils[:1]) - with pytest.warns(RuntimeWarning, match='Coil too close'): - fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close='warning') + with pytest.warns(RuntimeWarning, match="Coil too close"): + fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close="warning") assert not np.isfinite(fwd).any() - with np.errstate(invalid='ignore'): - fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close='info') + with np.errstate(invalid="ignore"): + fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close="info") assert not np.isfinite(fwd).any() @@ -181,116 +219,183 @@ def test_make_forward_solution_kit(tmp_path, fname_src_small): fname_kit_raw = kit_dir / "test_bin_raw.fif" # first use mne-C: convert file, make forward solution - fwd = _do_forward_solution('sample', fname_kit_raw, src=fname_src_small, - bem=fname_bem_meg, mri=trans_path, - eeg=False, meg=True, subjects_dir=subjects_dir) - assert (isinstance(fwd, Forward)) + fwd = _do_forward_solution( + "sample", + fname_kit_raw, + src=fname_src_small, + bem=fname_bem_meg, + mri=trans_path, + eeg=False, + meg=True, + subjects_dir=subjects_dir, + ) + assert isinstance(fwd, Forward) # now let's use python with the same raw file src = read_source_spaces(fname_src_small) - fwd_py = make_forward_solution(fname_kit_raw, trans_path, src, - fname_bem_meg, eeg=False, meg=True) + fwd_py = make_forward_solution( + fname_kit_raw, trans_path, src, fname_bem_meg, eeg=False, meg=True + ) _compare_forwards(fwd, fwd_py, 157, n_src_small) - assert (isinstance(fwd_py, Forward)) + assert isinstance(fwd_py, Forward) # now let's use mne-python all the way raw_py = read_raw_kit(sqd_path, mrk_path, elp_path, 
hsp_path) # without ignore_ref=True, this should throw an error: - with pytest.raises(NotImplementedError, match='Cannot.*KIT reference'): - make_forward_solution(raw_py.info, src=src, eeg=False, meg=True, - bem=fname_bem_meg, trans=trans_path) + with pytest.raises(NotImplementedError, match="Cannot.*KIT reference"): + make_forward_solution( + raw_py.info, + src=src, + eeg=False, + meg=True, + bem=fname_bem_meg, + trans=trans_path, + ) # check that asking for eeg channels (even if they don't exist) is handled - meg_only_info = pick_info(raw_py.info, pick_types(raw_py.info, meg=True, - eeg=False)) - fwd_py = make_forward_solution(meg_only_info, src=src, meg=True, eeg=True, - bem=fname_bem_meg, trans=trans_path, - ignore_ref=True) - _compare_forwards(fwd, fwd_py, 157, n_src_small, - meg_rtol=1e-3, meg_atol=1e-7) + meg_only_info = pick_info(raw_py.info, pick_types(raw_py.info, meg=True, eeg=False)) + fwd_py = make_forward_solution( + meg_only_info, + src=src, + meg=True, + eeg=True, + bem=fname_bem_meg, + trans=trans_path, + ignore_ref=True, + ) + _compare_forwards(fwd, fwd_py, 157, n_src_small, meg_rtol=1e-3, meg_atol=1e-7) @requires_mne def test_make_forward_solution_bti(fname_src_small): """Test BTI end-to-end versus C.""" - bti_pdf = bti_dir / 'test_pdf_linux' - bti_config = bti_dir / 'test_config_linux' - bti_hs = bti_dir / 'test_hs_linux' - fname_bti_raw = bti_dir / 'exported4D_linux_raw.fif' + bti_pdf = bti_dir / "test_pdf_linux" + bti_config = bti_dir / "test_config_linux" + bti_hs = bti_dir / "test_hs_linux" + fname_bti_raw = bti_dir / "exported4D_linux_raw.fif" raw_py = read_raw_bti(bti_pdf, bti_config, bti_hs, preload=False) src = read_source_spaces(fname_src_small) - fwd_py = make_forward_solution(raw_py.info, src=src, eeg=False, meg=True, - bem=fname_bem_meg, trans=trans_path) - fwd = _do_forward_solution('sample', fname_bti_raw, src=fname_src_small, - bem=fname_bem_meg, mri=trans_path, - eeg=False, meg=True, subjects_dir=subjects_dir) + fwd_py = make_forward_solution( + raw_py.info, src=src, eeg=False, meg=True, bem=fname_bem_meg, trans=trans_path + ) + fwd = _do_forward_solution( + "sample", + fname_bti_raw, + src=fname_src_small, + bem=fname_bem_meg, + mri=trans_path, + eeg=False, + meg=True, + subjects_dir=subjects_dir, + ) _compare_forwards(fwd, fwd_py, 248, n_src_small) -@pytest.mark.parametrize('other', [ - pytest.param('MNE-C', marks=requires_mne_mark()), - pytest.param('openmeeg', marks=requires_openmeeg_mark()), -]) +@pytest.mark.parametrize( + "other", + [ + pytest.param("MNE-C", marks=requires_mne_mark()), + pytest.param("openmeeg", marks=requires_openmeeg_mark()), + ], +) def test_make_forward_solution_ctf(tmp_path, fname_src_small, other): """Test CTF w/compensation against MNE-C or OpenMEEG.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") src = read_source_spaces(fname_src_small) raw = read_raw_fif(fname_ctf_raw) assert raw.compensation_grade == 3 - if other == 'openmeeg': - mindist = 20. + if other == "openmeeg": + mindist = 20.0 n_src_want = 51 else: - assert other == 'MNE-C' - mindist = 0. + assert other == "MNE-C" + mindist = 0.0 n_src_want = n_src_small assert n_src_want == 108 - mindist = 20. if other == 'openmeeg' else 0. 
+ mindist = 20.0 if other == "openmeeg" else 0.0 fwd_py = make_forward_solution( - fname_ctf_raw, fname_trans, src, fname_bem_meg, eeg=False, - mindist=mindist, verbose=True) - - if other == 'openmeeg': + fname_ctf_raw, + fname_trans, + src, + fname_bem_meg, + eeg=False, + mindist=mindist, + verbose=True, + ) + + if other == "openmeeg": # TODO: This should be a 1-layer, but it's broken # (some correlations become negative!)... bem_surfaces = read_bem_surfaces(fname_bem) # fname_bem_meg - bem = make_bem_solution(bem_surfaces, solver='openmeeg') + bem = make_bem_solution(bem_surfaces, solver="openmeeg") # TODO: These tolerances are bad tol_kwargs = dict(meg_atol=1, meg_corr_tol=0.65, meg_rdm_tol=0.6) fwd = make_forward_solution( - fname_ctf_raw, fname_trans, src, bem, eeg=False, mindist=mindist, - verbose=True) + fname_ctf_raw, + fname_trans, + src, + bem, + eeg=False, + mindist=mindist, + verbose=True, + ) else: - assert other == 'MNE-C' + assert other == "MNE-C" bem = None tol_kwargs = dict() fwd = _do_forward_solution( - 'sample', fname_ctf_raw, mri=fname_trans, src=fname_src_small, - bem=fname_bem_meg, eeg=False, meg=True, subjects_dir=subjects_dir, - mindist=mindist) + "sample", + fname_ctf_raw, + mri=fname_trans, + src=fname_src_small, + bem=fname_bem_meg, + eeg=False, + meg=True, + subjects_dir=subjects_dir, + mindist=mindist, + ) _compare_forwards(fwd, fwd_py, 274, n_src_want, **tol_kwargs) # CTF with compensation changed in python ctf_raw = read_raw_fif(fname_ctf_raw) - ctf_raw.info['bads'] = ['MRO24-2908'] # test that it works with some bads + ctf_raw.info["bads"] = ["MRO24-2908"] # test that it works with some bads ctf_raw.apply_gradient_compensation(2) fwd_py = make_forward_solution( - ctf_raw.info, fname_trans, src, fname_bem_meg, eeg=False, meg=True, - mindist=mindist) - if other == 'openmeeg': + ctf_raw.info, + fname_trans, + src, + fname_bem_meg, + eeg=False, + meg=True, + mindist=mindist, + ) + if other == "openmeeg": assert bem is not None fwd = make_forward_solution( - ctf_raw.info, fname_trans, src, bem, eeg=False, mindist=mindist, - verbose=True) + ctf_raw.info, + fname_trans, + src, + bem, + eeg=False, + mindist=mindist, + verbose=True, + ) else: fwd = _do_forward_solution( - 'sample', ctf_raw, mri=fname_trans, src=fname_src_small, - bem=fname_bem_meg, eeg=False, meg=True, subjects_dir=subjects_dir, - mindist=mindist) + "sample", + ctf_raw, + mri=fname_trans, + src=fname_src_small, + bem=fname_bem_meg, + eeg=False, + meg=True, + subjects_dir=subjects_dir, + mindist=mindist, + ) _compare_forwards(fwd, fwd_py, 274, n_src_want, **tol_kwargs) - fname_temp = tmp_path / 'test-ctf-fwd.fif' + fname_temp = tmp_path / "test-ctf-fwd.fif" write_forward_solution(fname_temp, fwd_py) fwd_py2 = read_forward_solution(fname_temp) _compare_forwards(fwd_py, fwd_py2, 274, n_src_want, **tol_kwargs) @@ -303,25 +408,32 @@ def test_make_forward_solution_basic(): with catch_logging() as log: # make sure everything can be path-like (gh #10872) fwd_py = make_forward_solution( - Path(fname_raw), Path(fname_trans), Path(fname_src), - Path(fname_bem), mindist=5., verbose=True) + Path(fname_raw), + Path(fname_trans), + Path(fname_src), + Path(fname_bem), + mindist=5.0, + verbose=True, + ) log = log.getvalue() - assert 'Total 258/258 points inside the surface' in log - assert (isinstance(fwd_py, Forward)) + assert "Total 258/258 points inside the surface" in log + assert isinstance(fwd_py, Forward) fwd = read_forward_solution(fname_meeg) - assert (isinstance(fwd, Forward)) + assert isinstance(fwd, 
Forward) _compare_forwards(fwd, fwd_py, 366, 1494, meg_rtol=1e-3) # Homogeneous model - with pytest.raises(RuntimeError, match='homogeneous.*1-layer.*EEG'): - make_forward_solution(fname_raw, fname_trans, fname_src, - fname_bem_meg) + with pytest.raises(RuntimeError, match="homogeneous.*1-layer.*EEG"): + make_forward_solution(fname_raw, fname_trans, fname_src, fname_bem_meg) @requires_openmeeg_mark() -@pytest.mark.parametrize("n_layers", [ - 3, - pytest.param(1, marks=pytest.mark.xfail(raises=RuntimeError)), -]) +@pytest.mark.parametrize( + "n_layers", + [ + 3, + pytest.param(1, marks=pytest.mark.xfail(raises=RuntimeError)), + ], +) @testing.requires_testing_data def test_make_forward_solution_openmeeg(n_layers): """Test making M-EEG forward solution from OpenMEEG.""" @@ -329,33 +441,45 @@ def test_make_forward_solution_openmeeg(n_layers): bem_surfaces = read_bem_surfaces(fname_bem) raw = read_raw_fif(fname_raw) n_sensors = 366 - ch_types = ['eeg', 'meg'] + ch_types = ["eeg", "meg"] if n_layers == 1: - ch_types = ['meg'] + ch_types = ["meg"] bem_surfaces = bem_surfaces[-1:] - assert bem_surfaces[0]['id'] == FIFF.FIFFV_BEM_SURF_ID_BRAIN + assert bem_surfaces[0]["id"] == FIFF.FIFFV_BEM_SURF_ID_BRAIN n_sensors = 306 raw.pick(ch_types) n_sources_kept = 501 // 3 fwds = dict() for solver in ["openmeeg", "mne"]: bem = make_bem_solution(bem_surfaces, solver=solver) - assert bem['solver'] == solver + assert bem["solver"] == solver with catch_logging() as log: # make sure everything can be path-like (gh #10872) fwd = make_forward_solution( - raw.info, Path(fname_trans), Path(fname_src), - bem, mindist=20., verbose=True) + raw.info, + Path(fname_trans), + Path(fname_src), + bem, + mindist=20.0, + verbose=True, + ) log = log.getvalue() - assert 'Total 258/258 points inside the surface' in log - assert (isinstance(fwd, Forward)) + assert "Total 258/258 points inside the surface" in log + assert isinstance(fwd, Forward) fwds[solver] = fwd del fwd - _compare_forwards(fwds["openmeeg"], - fwds["mne"], n_sensors, n_sources_kept * 3, - meg_atol=1, eeg_atol=100, - meg_corr_tol=0.98, eeg_corr_tol=0.98, - meg_rdm_tol=0.1, eeg_rdm_tol=0.2) + _compare_forwards( + fwds["openmeeg"], + fwds["mne"], + n_sensors, + n_sources_kept * 3, + meg_atol=1, + eeg_atol=100, + meg_corr_tol=0.98, + eeg_corr_tol=0.98, + meg_rdm_tol=0.1, + eeg_rdm_tol=0.2, + ) def test_make_forward_solution_discrete(tmp_path, small_surf_src): @@ -363,31 +487,36 @@ def test_make_forward_solution_discrete(tmp_path, small_surf_src): # smoke test for depth weighting and discrete source spaces src = small_surf_src src = src + setup_volume_source_space( - pos=dict(rr=src[0]['rr'][src[0]['vertno'][:3]].copy(), - nn=src[0]['nn'][src[0]['vertno'][:3]].copy())) + pos=dict( + rr=src[0]["rr"][src[0]["vertno"][:3]].copy(), + nn=src[0]["nn"][src[0]["vertno"][:3]].copy(), + ) + ) sphere = make_sphere_model() - fwd = make_forward_solution(fname_raw, fname_trans, src, sphere, - meg=True, eeg=False) + fwd = make_forward_solution( + fname_raw, fname_trans, src, sphere, meg=True, eeg=False + ) convert_forward_solution(fwd, surf_ori=True) n_src_small = 108 # this is the resulting # of verts in fwd -@pytest.fixture(scope='module', params=[testing._pytest_param()]) +@pytest.fixture(scope="module", params=[testing._pytest_param()]) def small_surf_src(): """Create a small surface source space.""" - pytest.importorskip('nibabel') - src = setup_source_space('sample', 'oct2', subjects_dir=subjects_dir, - add_dist=False) - assert sum(s['nuse'] for s in src) * 3 == n_src_small 
+ pytest.importorskip("nibabel") + src = setup_source_space( + "sample", "oct2", subjects_dir=subjects_dir, add_dist=False + ) + assert sum(s["nuse"] for s in src) * 3 == n_src_small return src @pytest.fixture() def fname_src_small(tmp_path, small_surf_src): """Create a small source space.""" - fname_src_small = tmp_path / 'sample-oct-2-src.fif' + fname_src_small = tmp_path / "sample-oct-2-src.fif" write_source_spaces(fname_src_small, small_surf_src) return fname_src_small @@ -396,39 +525,65 @@ def fname_src_small(tmp_path, small_surf_src): @pytest.mark.timeout(90) # can take longer than 60 s on Travis def test_make_forward_solution_sphere(tmp_path, fname_src_small): """Test making a forward solution with a sphere model.""" - out_name = tmp_path / 'tmp-fwd.fif' - run_subprocess(['mne_forward_solution', '--meg', '--eeg', - '--meas', fname_raw, '--src', fname_src_small, - '--mri', fname_trans, '--fwd', out_name]) + out_name = tmp_path / "tmp-fwd.fif" + run_subprocess( + [ + "mne_forward_solution", + "--meg", + "--eeg", + "--meas", + fname_raw, + "--src", + fname_src_small, + "--mri", + fname_trans, + "--fwd", + out_name, + ] + ) fwd = read_forward_solution(out_name) sphere = make_sphere_model(verbose=True) src = read_source_spaces(fname_src_small) - fwd_py = make_forward_solution(fname_raw, fname_trans, src, sphere, - meg=True, eeg=True, verbose=True) - _compare_forwards(fwd, fwd_py, 366, 108, - meg_rtol=5e-1, meg_atol=1e-6, - eeg_rtol=5e-1, eeg_atol=5e-1) + fwd_py = make_forward_solution( + fname_raw, fname_trans, src, sphere, meg=True, eeg=True, verbose=True + ) + _compare_forwards( + fwd, + fwd_py, + 366, + 108, + meg_rtol=5e-1, + meg_atol=1e-6, + eeg_rtol=5e-1, + eeg_atol=5e-1, + ) # Since the above is pretty lax, let's check a different way for meg, eeg in zip([True, False], [False, True]): fwd_ = pick_types_forward(fwd, meg=meg, eeg=eeg) fwd_py_ = pick_types_forward(fwd, meg=meg, eeg=eeg) - assert_allclose(np.corrcoef(fwd_['sol']['data'].ravel(), - fwd_py_['sol']['data'].ravel())[0, 1], - 1.0, rtol=1e-3) + assert_allclose( + np.corrcoef(fwd_["sol"]["data"].ravel(), fwd_py_["sol"]["data"].ravel())[ + 0, 1 + ], + 1.0, + rtol=1e-3, + ) # Number of layers in the sphere model doesn't matter for MEG # (as long as no sources are omitted due to distance) - assert len(sphere['layers']) == 4 - fwd = make_forward_solution(fname_raw, fname_trans, src, sphere, - meg=True, eeg=False) + assert len(sphere["layers"]) == 4 + fwd = make_forward_solution( + fname_raw, fname_trans, src, sphere, meg=True, eeg=False + ) sphere_1 = make_sphere_model(head_radius=None) - assert len(sphere_1['layers']) == 0 - assert_array_equal(sphere['r0'], sphere_1['r0']) - fwd_1 = make_forward_solution(fname_raw, fname_trans, src, sphere, - meg=True, eeg=False) + assert len(sphere_1["layers"]) == 0 + assert_array_equal(sphere["r0"], sphere_1["r0"]) + fwd_1 = make_forward_solution( + fname_raw, fname_trans, src, sphere, meg=True, eeg=False + ) _compare_forwards(fwd, fwd_1, 306, 108, meg_rtol=1e-12, meg_atol=1e-12) # Homogeneous model sphere = make_sphere_model(head_radius=None) - with pytest.raises(RuntimeError, match='zero shells.*EEG'): + with pytest.raises(RuntimeError, match="zero shells.*EEG"): make_forward_solution(fname_raw, fname_trans, src, sphere) @@ -436,7 +591,7 @@ def test_make_forward_solution_sphere(tmp_path, fname_src_small): @testing.requires_testing_data def test_forward_mixed_source_space(tmp_path): """Test making the forward solution for a mixed source space.""" - pytest.importorskip('nibabel') + 
pytest.importorskip("nibabel") # get the surface source space rng = np.random.RandomState(0) surf = read_source_spaces(fname_src) @@ -444,42 +599,49 @@ def test_forward_mixed_source_space(tmp_path): # setup two volume source spaces label_names = get_volume_labels_from_aseg(fname_aseg) vol_labels = rng.choice(label_names, 2) - with pytest.warns(RuntimeWarning, match='Found no usable.*CC_Mid_Ant.*'): - vol1 = setup_volume_source_space('sample', pos=20., mri=fname_aseg, - volume_label=vol_labels[0], - add_interpolator=False) - vol2 = setup_volume_source_space('sample', pos=20., mri=fname_aseg, - volume_label=vol_labels[1], - add_interpolator=False) + with pytest.warns(RuntimeWarning, match="Found no usable.*CC_Mid_Ant.*"): + vol1 = setup_volume_source_space( + "sample", + pos=20.0, + mri=fname_aseg, + volume_label=vol_labels[0], + add_interpolator=False, + ) + vol2 = setup_volume_source_space( + "sample", + pos=20.0, + mri=fname_aseg, + volume_label=vol_labels[1], + add_interpolator=False, + ) # merge surfaces and volume src = surf + vol1 + vol2 # calculate forward solution fwd = make_forward_solution(fname_raw, fname_trans, src, fname_bem) - assert (repr(fwd)) + assert repr(fwd) # extract source spaces - src_from_fwd = fwd['src'] + src_from_fwd = fwd["src"] # get the coordinate frame of each source space - coord_frames = np.array([s['coord_frame'] for s in src_from_fwd]) + coord_frames = np.array([s["coord_frame"] for s in src_from_fwd]) # assert that all source spaces are in head coordinates - assert ((coord_frames == FIFF.FIFFV_COORD_HEAD).all()) + assert (coord_frames == FIFF.FIFFV_COORD_HEAD).all() # run tests for SourceSpaces.export_volume - fname_img = tmp_path / 'temp-image.mgz' + fname_img = tmp_path / "temp-image.mgz" # head coordinates and mri_resolution, but trans file - with pytest.raises(ValueError, match='trans containing mri to head'): + with pytest.raises(ValueError, match="trans containing mri to head"): src_from_fwd.export_volume(fname_img, mri_resolution=True, trans=None) # head coordinates and mri_resolution, but wrong trans file - vox_mri_t = vol1[0]['vox_mri_t'] - with pytest.raises(ValueError, match='head<->mri, got mri_voxel->mri'): - src_from_fwd.export_volume(fname_img, mri_resolution=True, - trans=vox_mri_t) + vox_mri_t = vol1[0]["vox_mri_t"] + with pytest.raises(ValueError, match="head<->mri, got mri_voxel->mri"): + src_from_fwd.export_volume(fname_img, mri_resolution=True, trans=vox_mri_t) @pytest.mark.slowtest @@ -490,11 +652,11 @@ def test_make_forward_dipole(tmp_path): evoked = read_evokeds(fname_evo)[0] cov = read_cov(fname_cov) - cov['projs'] = [] # avoid proj warning + cov["projs"] = [] # avoid proj warning dip_c = read_dipole(fname_dip) # Only use magnetometers for speed! - picks = pick_types(evoked.info, meg='mag', eeg=False)[::8] + picks = pick_types(evoked.info, meg="mag", eeg=False)[::8] evoked.pick_channels([evoked.ch_names[p] for p in picks]) evoked.info.normalize_proj() info = evoked.info @@ -503,18 +665,19 @@ def test_make_forward_dipole(tmp_path): # in the test dataset. 
n_test_dipoles = 3 # minimum 3 needed to get uneven sampling in time dipsel = np.sort(rng.permutation(np.arange(len(dip_c)))[:n_test_dipoles]) - dip_test = Dipole(times=dip_c.times[dipsel], - pos=dip_c.pos[dipsel], - amplitude=dip_c.amplitude[dipsel], - ori=dip_c.ori[dipsel], - gof=dip_c.gof[dipsel]) + dip_test = Dipole( + times=dip_c.times[dipsel], + pos=dip_c.pos[dipsel], + amplitude=dip_c.amplitude[dipsel], + ori=dip_c.ori[dipsel], + gof=dip_c.gof[dipsel], + ) sphere = make_sphere_model(head_radius=0.1) # Warning emitted due to uneven sampling in time - with pytest.warns(RuntimeWarning, match='unevenly spaced'): - fwd, stc = make_forward_dipole(dip_test, sphere, info, - trans=fname_trans) + with pytest.warns(RuntimeWarning, match="unevenly spaced"): + fwd, stc = make_forward_dipole(dip_test, sphere, info, trans=fname_trans) # stc is list of VolSourceEstimate's assert isinstance(stc, list) @@ -526,8 +689,7 @@ def test_make_forward_dipole(tmp_path): times, pos, amplitude, ori, gof = [], [], [], [], [] nave = 400 # add a tiny amount of noise to the simulated evokeds for s in stc: - evo_test = simulate_evoked(fwd, s, info, cov, - nave=nave, random_state=rng) + evo_test = simulate_evoked(fwd, s, info, cov, nave=nave, random_state=rng) # evo_test.add_proj(make_eeg_average_ref_proj(evo_test.info)) dfit, resid = fit_dipole(evo_test, cov, sphere, None) times += dfit.times.tolist() @@ -544,14 +706,16 @@ def test_make_forward_dipole(tmp_path): diff = dip_test.pos - dip_fit.pos corr = np.corrcoef(dip_test.pos.ravel(), dip_fit.pos.ravel())[0, 1] dist = np.sqrt(np.mean(np.sum(diff * diff, axis=1))) - gc_dist = 180 / np.pi * \ - np.mean(np.arccos(np.sum(dip_test.ori * dip_fit.ori, axis=1))) + gc_dist = ( + 180 / np.pi * np.mean(np.arccos(np.sum(dip_test.ori * dip_fit.ori, axis=1))) + ) amp_err = np.sqrt(np.mean((dip_test.amplitude - dip_fit.amplitude) ** 2)) # Make sure each coordinate is close to reference # NB tolerance should be set relative to snr of simulated evoked! - assert_allclose(dip_fit.pos, dip_test.pos, rtol=0, atol=1e-2, - err_msg='position mismatch') + assert_allclose( + dip_fit.pos, dip_test.pos, rtol=0, atol=1e-2, err_msg="position mismatch" + ) assert dist < 1e-2 # within 1 cm assert corr > 0.985 assert gc_dist < 20 # less than 20 degrees @@ -560,20 +724,22 @@ def test_make_forward_dipole(tmp_path): # Make sure rejection works with BEM: one dipole at z=1m # NB _make_forward.py:_prepare_for_forward will raise a RuntimeError # if no points are left after min_dist exclusions, hence 2 dips here! 
- dip_outside = Dipole(times=[0., 0.001], - pos=[[0., 0., 1.0], [0., 0., 0.040]], - amplitude=[100e-9, 100e-9], - ori=[[1., 0., 0.], [1., 0., 0.]], gof=1) - with pytest.raises(ValueError, match='outside the inner skull'): + dip_outside = Dipole( + times=[0.0, 0.001], + pos=[[0.0, 0.0, 1.0], [0.0, 0.0, 0.040]], + amplitude=[100e-9, 100e-9], + ori=[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + gof=1, + ) + with pytest.raises(ValueError, match="outside the inner skull"): make_forward_dipole(dip_outside, fname_bem, info, fname_trans) # if we get this far, can safely assume the code works with BEMs too # -> use sphere again below for speed # Now make an evenly sampled set of dipoles, some simultaneous, # should return a VolSourceEstimate regardless - times = [0., 0., 0., 0.001, 0.001, 0.002] - pos = np.random.rand(6, 3) * 0.020 + \ - np.array([0., 0., 0.040])[np.newaxis, :] + times = [0.0, 0.0, 0.0, 0.001, 0.001, 0.002] + pos = np.random.rand(6, 3) * 0.020 + np.array([0.0, 0.0, 0.040])[np.newaxis, :] amplitude = np.random.rand(6) * 100e-9 ori = np.eye(6, 3) + np.eye(6, 3, -3) gof = np.arange(len(times)) / len(times) # arbitrary @@ -581,61 +747,63 @@ def test_make_forward_dipole(tmp_path): dip_even_samp = Dipole(times, pos, amplitude, ori, gof) # I/O round-trip - fname = str(tmp_path / 'test-fwd.fif') - with pytest.warns(RuntimeWarning, match='free orientation'): + fname = str(tmp_path / "test-fwd.fif") + with pytest.warns(RuntimeWarning, match="free orientation"): write_forward_solution(fname, fwd) - fwd_read = convert_forward_solution( - read_forward_solution(fname), force_fixed=True) + fwd_read = convert_forward_solution(read_forward_solution(fname), force_fixed=True) assert_forward_allclose(fwd, fwd_read, rtol=1e-6) - fwd, stc = make_forward_dipole(dip_even_samp, sphere, info, - trans=fname_trans) + fwd, stc = make_forward_dipole(dip_even_samp, sphere, info, trans=fname_trans) assert isinstance(stc, VolSourceEstimate) - assert_allclose(stc.times, np.arange(0., 0.003, 0.001)) + assert_allclose(stc.times, np.arange(0.0, 0.003, 0.001)) # Test passing a list of Dipoles instead of a single Dipole object - fwd2, stc2 = make_forward_dipole([dip_even_samp[0], dip_even_samp[1:]], - sphere, info, trans=fname_trans) - assert_array_equal(fwd['sol']['data'], fwd2['sol']['data']) + fwd2, stc2 = make_forward_dipole( + [dip_even_samp[0], dip_even_samp[1:]], sphere, info, trans=fname_trans + ) + assert_array_equal(fwd["sol"]["data"], fwd2["sol"]["data"]) assert_array_equal(stc.data, stc2.data) @testing.requires_testing_data def test_make_forward_no_meg(tmp_path): """Test that we can make and I/O forward solution with no MEG channels.""" - pos = dict(rr=[[0.05, 0, 0]], nn=[[0, 0, 1.]]) + pos = dict(rr=[[0.05, 0, 0]], nn=[[0, 0, 1.0]]) src = setup_volume_source_space(pos=pos) bem = make_sphere_model() trans = None - montage = make_standard_montage('standard_1020') - info = create_info(['Cz'], 1000., 'eeg').set_montage(montage) + montage = make_standard_montage("standard_1020") + info = create_info(["Cz"], 1000.0, "eeg").set_montage(montage) fwd = make_forward_solution(info, trans, src, bem) - fname = tmp_path / 'test-fwd.fif' + fname = tmp_path / "test-fwd.fif" write_forward_solution(fname, fwd) fwd_read = read_forward_solution(fname) - assert_allclose(fwd['sol']['data'], fwd_read['sol']['data']) + assert_allclose(fwd["sol"]["data"], fwd_read["sol"]["data"]) def test_use_coil_def(tmp_path): """Test use_coil_def.""" - info = create_info(1, 1000., 'mag') - info['chs'][0]['coil_type'] = 9999 - info['chs'][0]['loc'][:] = 
[0, 0, 0.02, 1, 0, 0, 0, 1, 0, 0, 0, 1] - sphere = make_sphere_model((0., 0., 0.), 0.01) + info = create_info(1, 1000.0, "mag") + info["chs"][0]["coil_type"] = 9999 + info["chs"][0]["loc"][:] = [0, 0, 0.02, 1, 0, 0, 0, 1, 0, 0, 0, 1] + sphere = make_sphere_model((0.0, 0.0, 0.0), 0.01) src = setup_volume_source_space(pos=5, sphere=sphere) - trans = Transform('head', 'mri', None) - with pytest.raises(RuntimeError, match='coil definition not found'): + trans = Transform("head", "mri", None) + with pytest.raises(RuntimeError, match="coil definition not found"): make_forward_solution(info, trans, src, sphere) - coil_fname = tmp_path / 'coil_def.dat' - with open(coil_fname, 'w') as fid: - fid.write("""# custom cube coil def + coil_fname = tmp_path / "coil_def.dat" + with open(coil_fname, "w") as fid: + fid.write( + """# custom cube coil def 1 9999 2 8 3e-03 0.000e+00 "Test" - 0.1250 -0.750e-03 -0.750e-03 -0.750e-03 0.000 0.000""") - with pytest.raises(RuntimeError, match='Could not interpret'): + 0.1250 -0.750e-03 -0.750e-03 -0.750e-03 0.000 0.000""" + ) + with pytest.raises(RuntimeError, match="Could not interpret"): with use_coil_def(coil_fname): make_forward_solution(info, trans, src, sphere) - with open(coil_fname, 'w') as fid: - fid.write("""# custom cube coil def + with open(coil_fname, "w") as fid: + fid.write( + """# custom cube coil def 1 9999 2 8 3e-03 0.000e+00 "Test" 0.1250 -0.750e-03 -0.750e-03 -0.750e-03 0.000 0.000 1.000 0.1250 -0.750e-03 0.750e-03 -0.750e-03 0.000 0.000 1.000 @@ -644,7 +812,8 @@ def test_use_coil_def(tmp_path): 0.1250 -0.750e-03 -0.750e-03 0.750e-03 0.000 0.000 1.000 0.1250 -0.750e-03 0.750e-03 0.750e-03 0.000 0.000 1.000 0.1250 0.750e-03 -0.750e-03 0.750e-03 0.000 0.000 1.000 - 0.1250 0.750e-03 0.750e-03 0.750e-03 0.000 0.000 1.000""") + 0.1250 0.750e-03 0.750e-03 0.750e-03 0.000 0.000 1.000""" + ) with use_coil_def(coil_fname): make_forward_solution(info, trans, src, sphere) @@ -653,27 +822,27 @@ def test_use_coil_def(tmp_path): @testing.requires_testing_data def test_sensors_inside_bem(): """Test that sensors inside the BEM are problematic.""" - rr = _get_ico_surface(1)['rr'] + rr = _get_ico_surface(1)["rr"] rr /= np.linalg.norm(rr, axis=1, keepdims=True) rr *= 0.1 assert len(rr) == 42 - info = create_info(len(rr), 1000., 'mag') - info['dev_head_t'] = Transform('meg', 'head', np.eye(4)) - for ii, ch in enumerate(info['chs']): - ch['loc'][:] = np.concatenate((rr[ii], np.eye(3).ravel())) - trans = Transform('head', 'mri', np.eye(4)) - trans['trans'][2, 3] = 0.03 - sphere_noshell = make_sphere_model((0., 0., 0.), None) - sphere = make_sphere_model((0., 0., 0.), 1.01) - with pytest.raises(RuntimeError, match='.* 15 MEG.*inside the scalp.*'): + info = create_info(len(rr), 1000.0, "mag") + info["dev_head_t"] = Transform("meg", "head", np.eye(4)) + for ii, ch in enumerate(info["chs"]): + ch["loc"][:] = np.concatenate((rr[ii], np.eye(3).ravel())) + trans = Transform("head", "mri", np.eye(4)) + trans["trans"][2, 3] = 0.03 + sphere_noshell = make_sphere_model((0.0, 0.0, 0.0), None) + sphere = make_sphere_model((0.0, 0.0, 0.0), 1.01) + with pytest.raises(RuntimeError, match=".* 15 MEG.*inside the scalp.*"): make_forward_solution(info, trans, fname_src, fname_bem) make_forward_solution(info, trans, fname_src, fname_bem_meg) # okay make_forward_solution(info, trans, fname_src, sphere_noshell) # okay - with pytest.raises(RuntimeError, match='.* 42 MEG.*outermost sphere sh.*'): + with pytest.raises(RuntimeError, match=".* 42 MEG.*outermost sphere sh.*"): 
make_forward_solution(info, trans, fname_src, sphere) - sphere = make_sphere_model((0., 0., 2.0), 1.01) # weird, but okay + sphere = make_sphere_model((0.0, 0.0, 2.0), 1.01) # weird, but okay make_forward_solution(info, trans, fname_src, sphere) - for ch in info['chs']: - ch['loc'][:3] *= 0.1 - with pytest.raises(RuntimeError, match='.* 42 MEG.*the inner skull.*'): + for ch in info["chs"]: + ch["loc"][:3] *= 0.1 + with pytest.raises(RuntimeError, match=".* 42 MEG.*the inner skull.*"): make_forward_solution(info, trans, fname_src, fname_bem_meg) diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py index 7dffe749732..0bd08f62ad5 100644 --- a/mne/gui/__init__.py +++ b/mne/gui/__init__.py @@ -8,14 +8,32 @@ @verbose -def coregistration(tabbed=False, split=True, width=None, inst=None, - subject=None, subjects_dir=None, guess_mri_subject=None, - height=None, head_opacity=None, head_high_res=None, - trans=None, scrollable=True, *, - orient_to_surface=True, scale_by_distance=True, - mark_inside=True, interaction=None, scale=None, - advanced_rendering=None, head_inside=True, - fullscreen=None, show=True, block=False, verbose=None): +def coregistration( + tabbed=False, + split=True, + width=None, + inst=None, + subject=None, + subjects_dir=None, + guess_mri_subject=None, + height=None, + head_opacity=None, + head_high_res=None, + trans=None, + scrollable=True, + *, + orient_to_surface=True, + scale_by_distance=True, + mark_inside=True, + interaction=None, + scale=None, + advanced_rendering=None, + head_inside=True, + fullscreen=None, + show=True, + block=False, + verbose=None, +): """Coregister an MRI with a subject's head shape. The GUI can be launched through the command line interface: @@ -127,13 +145,13 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, .. youtube:: ALV5qqMHLlQ """ unsupported_params = { - 'tabbed': (tabbed, False), - 'split': (split, True), - 'scrollable': (scrollable, True), - 'head_inside': (head_inside, True), - 'guess_mri_subject': guess_mri_subject, - 'scale': scale, - 'advanced_rendering': advanced_rendering, + "tabbed": (tabbed, False), + "split": (split, True), + "scrollable": (scrollable, True), + "head_inside": (head_inside, True), + "guess_mri_subject": guess_mri_subject, + "scale": scale, + "advanced_rendering": advanced_rendering, } for key, val in unsupported_params.items(): if isinstance(val, tuple): @@ -141,45 +159,44 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, else: to_raise = val is not None if to_raise: - warn(f"The parameter {key} is not supported with" - " the pyvistaqt 3d backend. It will be ignored.") + warn( + f"The parameter {key} is not supported with" + " the pyvistaqt 3d backend. It will be ignored." 
+ ) config = get_config() if guess_mri_subject is None: - guess_mri_subject = config.get( - 'MNE_COREG_GUESS_MRI_SUBJECT', 'true') == 'true' + guess_mri_subject = config.get("MNE_COREG_GUESS_MRI_SUBJECT", "true") == "true" if head_high_res is None: - head_high_res = config.get('MNE_COREG_HEAD_HIGH_RES', 'true') == 'true' + head_high_res = config.get("MNE_COREG_HEAD_HIGH_RES", "true") == "true" if advanced_rendering is None: - advanced_rendering = \ - config.get('MNE_COREG_ADVANCED_RENDERING', 'true') == 'true' + advanced_rendering = ( + config.get("MNE_COREG_ADVANCED_RENDERING", "true") == "true" + ) if head_opacity is None: - head_opacity = config.get('MNE_COREG_HEAD_OPACITY', 0.8) + head_opacity = config.get("MNE_COREG_HEAD_OPACITY", 0.8) if head_inside is None: - head_inside = \ - config.get('MNE_COREG_HEAD_INSIDE', 'true').lower() == 'true' + head_inside = config.get("MNE_COREG_HEAD_INSIDE", "true").lower() == "true" if width is None: - width = config.get('MNE_COREG_WINDOW_WIDTH', 800) + width = config.get("MNE_COREG_WINDOW_WIDTH", 800) if height is None: - height = config.get('MNE_COREG_WINDOW_HEIGHT', 600) + height = config.get("MNE_COREG_WINDOW_HEIGHT", 600) if subjects_dir is None: - if 'SUBJECTS_DIR' in config: - subjects_dir = config['SUBJECTS_DIR'] - elif 'MNE_COREG_SUBJECTS_DIR' in config: - subjects_dir = config['MNE_COREG_SUBJECTS_DIR'] + if "SUBJECTS_DIR" in config: + subjects_dir = config["SUBJECTS_DIR"] + elif "MNE_COREG_SUBJECTS_DIR" in config: + subjects_dir = config["MNE_COREG_SUBJECTS_DIR"] if orient_to_surface is None: - orient_to_surface = (config.get('MNE_COREG_ORIENT_TO_SURFACE', '') == - 'true') + orient_to_surface = config.get("MNE_COREG_ORIENT_TO_SURFACE", "") == "true" if scale_by_distance is None: - scale_by_distance = (config.get('MNE_COREG_SCALE_BY_DISTANCE', '') == - 'true') + scale_by_distance = config.get("MNE_COREG_SCALE_BY_DISTANCE", "") == "true" if interaction is None: - interaction = config.get('MNE_COREG_INTERACTION', 'terrain') + interaction = config.get("MNE_COREG_INTERACTION", "terrain") if mark_inside is None: - mark_inside = config.get('MNE_COREG_MARK_INSIDE', '') == 'true' + mark_inside = config.get("MNE_COREG_MARK_INSIDE", "") == "true" if scale is None: - scale = config.get('MNE_COREG_SCENE_SCALE', 0.16) + scale = config.get("MNE_COREG_SCENE_SCALE", 0.16) if fullscreen is None: - fullscreen = config.get('MNE_COREG_FULLSCREEN', '') == 'true' + fullscreen = config.get("MNE_COREG_FULLSCREEN", "") == "true" head_opacity = float(head_opacity) head_inside = bool(head_inside) width = int(width) @@ -188,23 +205,44 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING from ._coreg import CoregistrationUI + if MNE_3D_BACKEND_TESTING: show = block = False return CoregistrationUI( - info_file=inst, subject=subject, subjects_dir=subjects_dir, - head_resolution=head_high_res, head_opacity=head_opacity, - orient_glyphs=orient_to_surface, scale_by_distance=scale_by_distance, - mark_inside=mark_inside, trans=trans, size=(width, height), show=show, - block=block, interaction=interaction, fullscreen=fullscreen, - verbose=verbose + info_file=inst, + subject=subject, + subjects_dir=subjects_dir, + head_resolution=head_high_res, + head_opacity=head_opacity, + orient_glyphs=orient_to_surface, + scale_by_distance=scale_by_distance, + mark_inside=mark_inside, + trans=trans, + size=(width, height), + show=show, + block=block, + interaction=interaction, + fullscreen=fullscreen, + 
verbose=verbose, ) -@deprecated('Use the :mod:`mne-gui-addons:mne_gui_addons` package instead, ' - 'will be removed in version 1.5.0') +@deprecated( + "Use the :mod:`mne-gui-addons:mne_gui_addons` package instead, " + "will be removed in version 1.5.0" +) @verbose -def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None, - groups=None, show=True, block=False, verbose=None): +def locate_ieeg( + info, + trans, + base_image, + subject=None, + subjects_dir=None, + groups=None, + show=True, + block=False, + verbose=None, +): """Locate intracranial electrode contacts. Parameters @@ -242,55 +280,79 @@ def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None, from ..viz.backends._utils import _qt_app_exec from ._ieeg_locate import IntracranialElectrodeLocator from qtpy.QtWidgets import QApplication + mne_gui = None # get application app = QApplication.instance() if app is None: app = QApplication(["Intracranial Electrode Locator"]) gui = IntracranialElectrodeLocator( - info, trans, base_image, subject=subject, subjects_dir=subjects_dir, - groups=groups, show=show, verbose=verbose) + info, + trans, + base_image, + subject=subject, + subjects_dir=subjects_dir, + groups=groups, + show=show, + verbose=verbose, + ) if block: _qt_app_exec(app) - return mne_gui.locate_ieeg( - info=info, trans=trans, base_image=base_image, - subject=subject, subjects_dir=subjects_dir, - groups=groups, show=show, block=block) if mne_gui else gui + return ( + mne_gui.locate_ieeg( + info=info, + trans=trans, + base_image=base_image, + subject=subject, + subjects_dir=subjects_dir, + groups=groups, + show=show, + block=block, + ) + if mne_gui + else gui + ) class _GUIScraper: """Scrape GUI outputs.""" def __repr__(self): - return '' + return "" def __call__(self, block, block_vars, gallery_conf): from ._ieeg_locate import IntracranialElectrodeLocator from ._coreg import CoregistrationUI + gui_classes = ( IntracranialElectrodeLocator, CoregistrationUI, ) try: - from mne_gui_addons._ieeg_locate import IntracranialElectrodeLocator # noqa: E501 + from mne_gui_addons._ieeg_locate import ( + IntracranialElectrodeLocator, + ) # noqa: E501 except Exception: pass else: gui_classes = gui_classes + (IntracranialElectrodeLocator,) from sphinx_gallery.scrapers import figure_rst from qtpy import QtGui - for gui in block_vars['example_globals'].values(): - if (isinstance(gui, gui_classes) and - not getattr(gui, '_scraped', False) and - gallery_conf['builder_name'] == 'html'): + + for gui in block_vars["example_globals"].values(): + if ( + isinstance(gui, gui_classes) + and not getattr(gui, "_scraped", False) + and gallery_conf["builder_name"] == "html" + ): gui._scraped = True # monkey-patch but it's easy enough - img_fname = next(block_vars['image_path_iterator']) + img_fname = next(block_vars["image_path_iterator"]) # TODO fix in window refactor - window = gui if hasattr(gui, 'grab') else gui._renderer._window + window = gui if hasattr(gui, "grab") else gui._renderer._window # window is QWindow # https://doc.qt.io/qt-5/qwidget.html#grab pixmap = window.grab() - if hasattr(gui, '_renderer'): # if no renderer, no need + if hasattr(gui, "_renderer"): # if no renderer, no need # Now the tricky part: we need to get the 3D renderer, # extract the image from it, and put it in the correct # place in the pixmap. 
The easiest way to do this is @@ -302,8 +364,8 @@ def __call__(self, block, block_vars, gallery_conf): # https://doc.qt.io/qt-5/qwidget.html#mapTo # https://doc.qt.io/qt-5/qpainter.html#drawPixmap-1 QtGui.QPainter(pixmap).drawPixmap( - plotter.mapTo(window, plotter.rect().topLeft()), - sub_pixmap) + plotter.mapTo(window, plotter.rect().topLeft()), sub_pixmap + ) # https://doc.qt.io/qt-5/qpixmap.html#save pixmap.save(img_fname) try: # for compatibility with both GUIs, will be refactored @@ -311,6 +373,5 @@ def __call__(self, block, block_vars, gallery_conf): except Exception: pass gui.close() - return figure_rst( - [img_fname], gallery_conf['src_dir'], 'GUI') - return '' + return figure_rst([img_fname], gallery_conf["src_dir"], "GUI") + return "" diff --git a/mne/gui/_core.py b/mne/gui/_core.py index b40f16621b3..03ba89b79c4 100644 --- a/mne/gui/_core.py +++ b/mne/gui/_core.py @@ -11,9 +11,16 @@ from qtpy import QtCore from qtpy.QtCore import Slot, Qt -from qtpy.QtWidgets import (QMainWindow, QGridLayout, - QVBoxLayout, QHBoxLayout, QLabel, - QMessageBox, QWidget, QLineEdit) +from qtpy.QtWidgets import ( + QMainWindow, + QGridLayout, + QVBoxLayout, + QHBoxLayout, + QLabel, + QMessageBox, + QWidget, + QLineEdit, +) from matplotlib import patheffects from matplotlib.backends.backend_qt5agg import FigureCanvas @@ -24,28 +31,35 @@ from ..viz.utils import safe_event from ..surface import _read_mri_surface, _marching_cubes from ..transforms import apply_trans, _frame_to_str -from ..utils import (logger, _check_fname, verbose, warn, get_subjects_dir, - _import_nibabel) +from ..utils import ( + logger, + _check_fname, + verbose, + warn, + get_subjects_dir, + _import_nibabel, +) from ..viz.backends._utils import _qt_safe_window -_IMG_LABELS = [['I', 'P'], ['I', 'L'], ['P', 'L']] +_IMG_LABELS = [["I", "P"], ["I", "L"], ["P", "L"]] _ZOOM_STEP_SIZE = 5 @verbose def _load_image(img, verbose=None): """Load data from a 3D image file (e.g. 
CT, MR).""" - nib = _import_nibabel('use GUI') + nib = _import_nibabel("use GUI") if not isinstance(img, nib.spatialimages.SpatialImage): - logger.debug(f'Loading {img}') - _check_fname(img, overwrite='read', must_exist=True) + logger.debug(f"Loading {img}") + _check_fname(img, overwrite="read", must_exist=True) img = nib.load(img) # get data orig_data = np.array(img.dataobj).astype(np.float32) # reorient data to RAS ornt = nib.orientations.axcodes2ornt( - nib.orientations.aff2axcodes(img.affine)).astype(int) - ras_ornt = nib.orientations.axcodes2ornt('RAS') + nib.orientations.aff2axcodes(img.affine) + ).astype(int) + ras_ornt = nib.orientations.axcodes2ornt("RAS") ornt_trans = nib.orientations.ornt_transform(ornt, ras_ornt) img_data = nib.orientations.apply_orientation(orig_data, ornt_trans) orig_mgh = nib.MGHImage(orig_data, img.affine) @@ -55,14 +69,20 @@ def _load_image(img, verbose=None): return img_data, vox_ras_t, vox_scan_ras_t -def _make_mpl_plot(width=4, height=4, dpi=300, tight=True, hide_axes=True, - facecolor='black', invert=True): +def _make_mpl_plot( + width=4, + height=4, + dpi=300, + tight=True, + hide_axes=True, + facecolor="black", + invert=True, +): fig = Figure(figsize=(width, height), dpi=dpi) canvas = FigureCanvas(fig) ax = fig.subplots() if tight: - fig.subplots_adjust(bottom=0, left=0, right=1, top=1, - wspace=0, hspace=0) + fig.subplots_adjust(bottom=0, left=0, right=1, top=1, wspace=0, hspace=0) ax.set_facecolor(facecolor) # clean up excess plot text, invert if invert: @@ -82,9 +102,8 @@ class SliceBrowser(QMainWindow): (0, 1), ) - @_qt_safe_window(splash='_renderer.figure.splash', window='') - def __init__(self, base_image=None, subject=None, subjects_dir=None, - verbose=None): + @_qt_safe_window(splash="_renderer.figure.splash", window="") + def __init__(self, base_image=None, subject=None, subjects_dir=None, verbose=None): """GUI for browsing slices of anatomical images.""" # initialize QMainWindow class super(SliceBrowser, self).__init__() @@ -92,10 +111,11 @@ def __init__(self, base_image=None, subject=None, subjects_dir=None, self._verbose = verbose # if bad/None subject, will raise an informative error when loading MRI - subject = os.environ.get('SUBJECT') if subject is None else subject + subject = os.environ.get("SUBJECT") if subject is None else subject subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=False)) - self._subject_dir = op.join(subjects_dir, subject) \ - if subject and subjects_dir else None + self._subject_dir = ( + op.join(subjects_dir, subject) if subject and subjects_dir else None + ) self._load_image_data(base_image=base_image) # GUI design @@ -108,10 +128,11 @@ def __init__(self, base_image=None, subject=None, subjects_dir=None, self._plt_grid.addWidget(canvas, i // 2, i % 2) self._figs.append(fig) self._renderer = _get_renderer( - name='Slice Browser', size=(400, 400), bgcolor='w') + name="Slice Browser", size=(400, 400), bgcolor="w" + ) self._plt_grid.addWidget(self._renderer.plotter, 1, 1) - self._set_ras([0., 0., 0.], update_plots=False) + self._set_ras([0.0, 0.0, 0.0], update_plots=False) self._plot_images() @@ -141,10 +162,14 @@ def _load_image_data(self, base_image=None): self._head = None self._lh = self._rh = None else: - mri_img = 'brain' if op.isfile(op.join( - self._subject_dir, 'mri', 'brain.mgz')) else 'T1' + mri_img = ( + "brain" + if op.isfile(op.join(self._subject_dir, "mri", "brain.mgz")) + else "T1" + ) self._mri_data, vox_ras_t, vox_scan_ras_t = _load_image( - op.join(self._subject_dir, 'mri', 
f'{mri_img}.mgz')) + op.join(self._subject_dir, "mri", f"{mri_img}.mgz") + ) # ready alternate base image if provided, otherwise use brain/T1 if base_image is None: @@ -153,19 +178,22 @@ def _load_image_data(self, base_image=None): self._vox_ras_t = vox_ras_t self._vox_scan_ras_t = vox_scan_ras_t else: - self._base_data, self._vox_ras_t, self._vox_scan_ras_t = \ - _load_image(base_image) + self._base_data, self._vox_ras_t, self._vox_scan_ras_t = _load_image( + base_image + ) if self._mri_data is not None: - if self._mri_data.shape != self._base_data.shape or \ - not np.allclose(self._vox_ras_t, vox_ras_t, rtol=1e-6): + if self._mri_data.shape != self._base_data.shape or not np.allclose( + self._vox_ras_t, vox_ras_t, rtol=1e-6 + ): raise ValueError( - 'Base image is not aligned to MRI, got ' - f'Base shape={self._base_data.shape}, ' - f'MRI shape={self._mri_data.shape}, ' - f'Base affine={vox_ras_t} and ' - f'MRI affine={self._vox_ras_t}, ' - 'please provide an aligned image or do not use the ' - '``subject`` and ``subjects_dir`` arguments') + "Base image is not aligned to MRI, got " + f"Base shape={self._base_data.shape}, " + f"MRI shape={self._mri_data.shape}, " + f"Base affine={vox_ras_t} and " + f"MRI affine={self._vox_ras_t}, " + "please provide an aligned image or do not use the " + "``subject`` and ``subjects_dir`` arguments" + ) self._ras_vox_t = np.linalg.inv(self._vox_ras_t) self._scan_ras_vox_t = np.linalg.inv(self._vox_scan_ras_t) @@ -176,113 +204,171 @@ def _load_image_data(self, base_image=None): # number. This code assumes 1mm isotropic... img_delta = 0.5 self._img_extents = list( - [-img_delta, self._voxel_sizes[idx[0]] - img_delta, - -img_delta, self._voxel_sizes[idx[1]] - img_delta] - for idx in self._xy_idx) + [ + -img_delta, + self._voxel_sizes[idx[0]] - img_delta, + -img_delta, + self._voxel_sizes[idx[1]] - img_delta, + ] + for idx in self._xy_idx + ) if self._subject_dir is not None: - if op.exists(op.join(self._subject_dir, 'surf', 'lh.seghead')): + if op.exists(op.join(self._subject_dir, "surf", "lh.seghead")): self._head = _read_mri_surface( - op.join(self._subject_dir, 'surf', 'lh.seghead')) - assert _frame_to_str[self._head['coord_frame']] == 'mri' + op.join(self._subject_dir, "surf", "lh.seghead") + ) + assert _frame_to_str[self._head["coord_frame"]] == "mri" else: - warn('`seghead` not found, using marching cubes on base image ' - 'for head plot, use :ref:`mne.bem.make_scalp_surfaces` ' - 'to add the scalp surface instead') + warn( + "`seghead` not found, using marching cubes on base image " + "for head plot, use :ref:`mne.bem.make_scalp_surfaces` " + "to add the scalp surface instead" + ) self._head = None if self._subject_dir is not None: # allow ?h.pial.T1 if ?h.pial doesn't exist # end with '' for better file not found error - for img in ('', '.T1', '.T2', ''): + for img in ("", ".T1", ".T2", ""): surf_fname = op.join( - self._subject_dir, 'surf', '{hemi}' + f'.pial{img}') - if op.isfile(surf_fname.format(hemi='lh')): + self._subject_dir, "surf", "{hemi}" + f".pial{img}" + ) + if op.isfile(surf_fname.format(hemi="lh")): break - if op.exists(surf_fname.format(hemi='lh')): - self._lh = _read_mri_surface(surf_fname.format(hemi='lh')) - assert _frame_to_str[self._lh['coord_frame']] == 'mri' - self._rh = _read_mri_surface(surf_fname.format(hemi='rh')) - assert _frame_to_str[self._rh['coord_frame']] == 'mri' + if op.exists(surf_fname.format(hemi="lh")): + self._lh = _read_mri_surface(surf_fname.format(hemi="lh")) + assert _frame_to_str[self._lh["coord_frame"]] == 
"mri" + self._rh = _read_mri_surface(surf_fname.format(hemi="rh")) + assert _frame_to_str[self._rh["coord_frame"]] == "mri" else: - warn('`pial` surface not found, skipping adding to 3D ' - 'plot. This indicates the Freesurfer recon-all ' - 'has not finished or has been modified and ' - 'these files have been deleted.') + warn( + "`pial` surface not found, skipping adding to 3D " + "plot. This indicates the Freesurfer recon-all " + "has not finished or has been modified and " + "these files have been deleted." + ) self._lh = self._rh = None def _plot_images(self): """Use the MRI or CT to make plots.""" # Plot sagittal (0), coronal (1) or axial (2) view - self._images = dict(base=list(), cursor_v=list(), cursor_h=list(), - bounds=list()) + self._images = dict( + base=list(), cursor_v=list(), cursor_h=list(), bounds=list() + ) img_min = np.nanmin(self._base_data) img_max = np.nanmax(self._base_data) - text_kwargs = dict(fontsize='medium', weight='bold', color='#66CCEE', - family='monospace', ha='center', va='center', - path_effects=[patheffects.withStroke( - linewidth=4, foreground="k", alpha=0.75)]) + text_kwargs = dict( + fontsize="medium", + weight="bold", + color="#66CCEE", + family="monospace", + ha="center", + va="center", + path_effects=[ + patheffects.withStroke(linewidth=4, foreground="k", alpha=0.75) + ], + ) xyz = apply_trans(self._ras_vox_t, self._ras) for axis in range(3): plot_x_idx, plot_y_idx = self._xy_idx[axis] fig = self._figs[axis] ax = fig.axes[0] - img_data = np.take(self._base_data, self._current_slice[axis], - axis=axis).T - self._images['base'].append(ax.imshow( - img_data, cmap='gray', aspect='auto', zorder=1, - vmin=img_min, vmax=img_max)) + img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T + self._images["base"].append( + ax.imshow( + img_data, + cmap="gray", + aspect="auto", + zorder=1, + vmin=img_min, + vmax=img_max, + ) + ) img_extent = self._img_extents[axis] # x0, x1, y0, y1 w, h = np.diff(np.array(img_extent).reshape(2, 2), axis=1)[:, 0] - self._images['bounds'].append(Rectangle( - img_extent[::2], w, h, edgecolor='w', facecolor='none', - alpha=0.25, lw=0.5, zorder=1.5)) - ax.add_patch(self._images['bounds'][-1]) + self._images["bounds"].append( + Rectangle( + img_extent[::2], + w, + h, + edgecolor="w", + facecolor="none", + alpha=0.25, + lw=0.5, + zorder=1.5, + ) + ) + ax.add_patch(self._images["bounds"][-1]) v_x = (xyz[plot_x_idx],) * 2 v_y = img_extent[2:4] - self._images['cursor_v'].append(ax.plot( - v_x, v_y, color='lime', linewidth=0.5, alpha=0.5, zorder=8)[0]) + self._images["cursor_v"].append( + ax.plot(v_x, v_y, color="lime", linewidth=0.5, alpha=0.5, zorder=8)[0] + ) h_y = (xyz[plot_y_idx],) * 2 h_x = img_extent[0:2] - self._images['cursor_h'].append(ax.plot( - h_x, h_y, color='lime', linewidth=0.5, alpha=0.5, zorder=8)[0]) + self._images["cursor_h"].append( + ax.plot(h_x, h_y, color="lime", linewidth=0.5, alpha=0.5, zorder=8)[0] + ) # label axes - self._figs[axis].text(0.5, 0.075, _IMG_LABELS[axis][0], - **text_kwargs) - self._figs[axis].text(0.075, 0.5, _IMG_LABELS[axis][1], - **text_kwargs) + self._figs[axis].text(0.5, 0.075, _IMG_LABELS[axis][0], **text_kwargs) + self._figs[axis].text(0.075, 0.5, _IMG_LABELS[axis][1], **text_kwargs) self._figs[axis].axes[0].axis(img_extent) + self._figs[axis].canvas.mpl_connect("scroll_event", self._on_scroll) self._figs[axis].canvas.mpl_connect( - 'scroll_event', self._on_scroll) - self._figs[axis].canvas.mpl_connect( - 'button_release_event', partial(self._on_click, axis=axis)) + 
"button_release_event", partial(self._on_click, axis=axis) + ) # add head and brain in mm (convert from m) if self._head is None: - logger.debug('Using marching cubes on the base image for the ' - '3D visualization panel') + logger.debug( + "Using marching cubes on the base image for the " + "3D visualization panel" + ) # in this case, leave in voxel coordinates - rr, tris = _marching_cubes(np.where( - self._base_data < np.quantile(self._base_data, 0.95), 0, 1), - [1])[0] + rr, tris = _marching_cubes( + np.where(self._base_data < np.quantile(self._base_data, 0.95), 0, 1), + [1], + )[0] # marching cubes transposes dimensions so flip rr = apply_trans(self._vox_ras_t, rr[:, ::-1]) self._renderer.mesh( - *rr.T, triangles=tris, color='gray', opacity=0.2, - reset_camera=False, render=False) + *rr.T, + triangles=tris, + color="gray", + opacity=0.2, + reset_camera=False, + render=False, + ) self._renderer.set_camera(focalpoint=rr.mean(axis=0)) else: self._renderer.mesh( - *self._head['rr'].T * 1000, triangles=self._head['tris'], - color='gray', opacity=0.2, reset_camera=False, render=False) + *self._head["rr"].T * 1000, + triangles=self._head["tris"], + color="gray", + opacity=0.2, + reset_camera=False, + render=False, + ) if self._lh is not None and self._rh is not None: self._renderer.mesh( - *self._lh['rr'].T * 1000, triangles=self._lh['tris'], - color='white', opacity=0.2, reset_camera=False, render=False) + *self._lh["rr"].T * 1000, + triangles=self._lh["tris"], + color="white", + opacity=0.2, + reset_camera=False, + render=False, + ) self._renderer.mesh( - *self._rh['rr'].T * 1000, triangles=self._rh['tris'], - color='white', opacity=0.2, reset_camera=False, render=False) - self._renderer.set_camera(azimuth=90, elevation=90, distance=300, - focalpoint=tuple(self._ras)) + *self._rh["rr"].T * 1000, + triangles=self._rh["tris"], + color="white", + opacity=0.2, + reset_camera=False, + render=False, + ) + self._renderer.set_camera( + azimuth=90, elevation=90, distance=300, focalpoint=tuple(self._ras) + ) # update plots self._draw() self._renderer._update() @@ -291,19 +377,19 @@ def _configure_status_bar(self, hbox=None): """Make a bar at the bottom with information in it.""" hbox = QHBoxLayout() if hbox is None else hbox - self._intensity_label = QLabel('') # update later + self._intensity_label = QLabel("") # update later hbox.addWidget(self._intensity_label) - VOX_label = QLabel('VOX =') - self._VOX_textbox = QLineEdit('') # update later + VOX_label = QLabel("VOX =") + self._VOX_textbox = QLineEdit("") # update later self._VOX_textbox.setMaximumHeight(25) self._VOX_textbox.setMinimumWidth(75) self._VOX_textbox.focusOutEvent = self._update_VOX hbox.addWidget(VOX_label) hbox.addWidget(self._VOX_textbox) - RAS_label = QLabel('RAS =') - self._RAS_textbox = QLineEdit('') # update later + RAS_label = QLabel("RAS =") + self._RAS_textbox = QLineEdit("") # update later self._RAS_textbox.setMaximumHeight(25) self._RAS_textbox.setMinimumWidth(150) self._RAS_textbox.focusOutEvent = self._update_RAS @@ -318,7 +404,8 @@ def _update_camera(self, render=False): # needs fix, distance moves when focal point updates distance=self._renderer.plotter.camera.distance * 0.9, focalpoint=tuple(self._ras), - reset_camera=False) + reset_camera=False, + ) def _on_scroll(self, event): """Process mouse scroll wheel event to zoom.""" @@ -328,8 +415,8 @@ def _zoom(self, sign=1, draw=False): """Zoom in on the image.""" delta = _ZOOM_STEP_SIZE * sign for axis, fig in enumerate(self._figs): - xcur = 
self._images['cursor_v'][axis].get_xdata()[0]
-            ycur = self._images['cursor_h'][axis].get_ydata()[0]
+            xcur = self._images["cursor_v"][axis].get_xdata()[0]
+            ycur = self._images["cursor_h"][axis].get_ydata()[0]
             rx, ry = [self._voxel_ratios[idx] for idx in self._xy_idx[axis]]
             xmin, xmax = fig.axes[0].get_xlim()
             ymin, ymax = fig.axes[0].get_ylim()
@@ -352,37 +439,38 @@ def _zoom(self, sign=1, draw=False):
 
     @Slot()
     def _update_RAS(self, event):
         """Interpret user input to the RAS textbox."""
-        ras = self._convert_text(self._RAS_textbox.text(), 'ras')
+        ras = self._convert_text(self._RAS_textbox.text(), "ras")
         if ras is not None:
             self._set_ras(ras)
 
     @Slot()
     def _update_VOX(self, event):
         """Interpret user input to the VOX textbox."""
-        ras = self._convert_text(self._VOX_textbox.text(), 'vox')
+        ras = self._convert_text(self._VOX_textbox.text(), "vox")
         if ras is not None:
             self._set_ras(ras)
 
     def _convert_text(self, text, text_kind):
-        text = text.replace('\n', '')
-        vals = text.split(',')
+        text = text.replace("\n", "")
+        vals = text.split(",")
         if len(vals) != 3:
-            vals = text.split(' ')  # spaces also okay as in freesurfer
+            vals = text.split(" ")  # spaces also okay as in freesurfer
         vals = [var.lstrip().rstrip() for var in vals]
         try:
             vals = np.array([float(var) for var in vals]).reshape(3)
         except Exception:
             self._update_moved()  # resets RAS label
             return
-        if text_kind == 'vox':
+        if text_kind == "vox":
             vox = vals
             ras = apply_trans(self._vox_ras_t, vox)
         else:
-            assert text_kind == 'ras'
+            assert text_kind == "ras"
             ras = vals
             vox = apply_trans(self._ras_vox_t, ras)
-        wrong_size = any(var < 0 or var > n - 1 for var, n in
-                         zip(vox, self._voxel_sizes))
+        wrong_size = any(
+            var < 0 or var > n - 1 for var, n in zip(vox, self._voxel_sizes)
+        )
         if wrong_size:
             self._update_moved()  # resets RAS label
             return
@@ -405,18 +493,18 @@ def set_RAS(self, ras):
     def _set_ras(self, ras, update_plots=True):
         ras = np.asarray(ras, dtype=float)
         assert ras.shape == (3,)
-        msg = ', '.join(f'{x:0.2f}' for x in ras)
-        logger.debug(f'Trying RAS: ({msg}) mm')
+        msg = ", ".join(f"{x:0.2f}" for x in ras)
+        logger.debug(f"Trying RAS: ({msg}) mm")
         # clip to valid
         vox = apply_trans(self._ras_vox_t, ras)
-        vox = np.array([
-            np.clip(d, 0, self._voxel_sizes[ii] - 1)
-            for ii, d in enumerate(vox)])
+        vox = np.array(
+            [np.clip(d, 0, self._voxel_sizes[ii] - 1) for ii, d in enumerate(vox)]
+        )
         # transform back, make write-only
         self._ras_safe = apply_trans(self._vox_ras_t, vox)
-        self._ras_safe.flags['WRITEABLE'] = False
-        msg = ', '.join(f'{x:0.2f}' for x in self._ras_safe)
-        logger.debug(f'Setting RAS: ({msg}) mm')
+        self._ras_safe.flags["WRITEABLE"] = False
+        msg = ", ".join(f"{x:0.2f}" for x in self._ras_safe)
+        logger.debug(f"Setting RAS: ({msg}) mm")
         if update_plots:
             self._move_cursors_to_pos()
 
@@ -440,15 +528,14 @@ def _current_slice(self):
 
     def _draw(self, axis=None):
         """Update the figures with a draw call."""
-        for axis in (range(3) if axis is None else [axis]):
+        for axis in range(3) if axis is None else [axis]:
             self._figs[axis].canvas.draw()
 
     def _update_base_images(self, axis=None, draw=False):
         """Update the base images."""
         for axis in range(3) if axis is None else [axis]:
-            img_data = np.take(self._base_data, self._current_slice[axis],
-                               axis=axis).T
-            self._images['base'][axis].set_data(img_data)
+            img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
+            self._images["base"][axis].set_data(img_data)
             if draw:
                 self._draw(axis)
 
@@ -462,23 +549,25 @@ def _move_cursors_to_pos(self):
         """Move the cursors to a 
position.""" for axis in range(3): x, y = self._vox[list(self._xy_idx[axis])] - self._images['cursor_v'][axis].set_xdata([x, x]) - self._images['cursor_h'][axis].set_ydata([y, y]) + self._images["cursor_v"][axis].set_xdata([x, x]) + self._images["cursor_h"][axis].set_ydata([y, y]) self._update_images(draw=True) self._update_moved() def _show_help(self): """Show the help menu.""" QMessageBox.information( - self, 'Help', + self, + "Help", "Help:\n" "'+'/'-': zoom\nleft/right arrow: left/right\n" "up/down arrow: superior/inferior\n" - "left angle bracket/right angle bracket: anterior/posterior") + "left angle bracket/right angle bracket: anterior/posterior", + ) def keyPressEvent(self, event): """Execute functions when the user presses a key.""" - if event.key() == 'escape': + if event.key() == "escape": self.close() elif event.key() == QtCore.Qt.Key_Return: @@ -487,25 +576,37 @@ def keyPressEvent(self, event): widget.clearFocus() self.setFocus() # removing focus calls focus out event - elif event.text() == 'h': + elif event.text() == "h": self._show_help() - elif event.text() in ('=', '+', '-'): - self._zoom(sign=-2 * (event.text() == '-') + 1, draw=True) + elif event.text() in ("=", "+", "-"): + self._zoom(sign=-2 * (event.text() == "-") + 1, draw=True) # Changing slices - elif event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down, - QtCore.Qt.Key_Left, QtCore.Qt.Key_Right, - QtCore.Qt.Key_Comma, QtCore.Qt.Key_Period, - QtCore.Qt.Key_PageUp, QtCore.Qt.Key_PageDown): + elif event.key() in ( + QtCore.Qt.Key_Up, + QtCore.Qt.Key_Down, + QtCore.Qt.Key_Left, + QtCore.Qt.Key_Right, + QtCore.Qt.Key_Comma, + QtCore.Qt.Key_Period, + QtCore.Qt.Key_PageUp, + QtCore.Qt.Key_PageDown, + ): ras = np.array(self._ras) if event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down): ras[2] += 2 * (event.key() == QtCore.Qt.Key_Up) - 1 elif event.key() in (QtCore.Qt.Key_Left, QtCore.Qt.Key_Right): ras[0] += 2 * (event.key() == QtCore.Qt.Key_Right) - 1 else: - ras[1] += 2 * (event.key() == QtCore.Qt.Key_PageUp or - event.key() == QtCore.Qt.Key_Period) - 1 + ras[1] += ( + 2 + * ( + event.key() == QtCore.Qt.Key_PageUp + or event.key() == QtCore.Qt.Key_Period + ) + - 1 + ) self._set_ras(ras) def _on_click(self, event, axis): @@ -516,18 +617,17 @@ def _on_click(self, event, axis): logger.debug(f'Clicked {"XYZ"[axis]} ({axis}) axis at pos {pos}') xyz = self._vox xyz[list(self._xy_idx[axis])] = pos - logger.debug(f'Using voxel {list(xyz)}') + logger.debug(f"Using voxel {list(xyz)}") ras = apply_trans(self._vox_ras_t, xyz) self._set_ras(ras) def _update_moved(self): """Update when cursor position changes.""" - self._RAS_textbox.setText('{:.2f}, {:.2f}, {:.2f}'.format( - *self._ras)) - self._VOX_textbox.setText('{:3d}, {:3d}, {:3d}'.format( - *self._current_slice)) - self._intensity_label.setText('intensity = {:.2f}'.format( - self._base_data[tuple(self._current_slice)])) + self._RAS_textbox.setText("{:.2f}, {:.2f}, {:.2f}".format(*self._ras)) + self._VOX_textbox.setText("{:3d}, {:3d}, {:3d}".format(*self._current_slice)) + self._intensity_label.setText( + "intensity = {:.2f}".format(self._base_data[tuple(self._current_slice)]) + ) @safe_event def closeEvent(self, event): diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index d5126ba2fee..a9f26038107 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -21,23 +21,49 @@ from ..io.meas_info import _empty_info from ..io._read_raw import supported as raw_supported_types from ..bem import make_bem_solution, write_bem_solution -from ..coreg import (Coregistration, 
_is_mri_subject, scale_mri, bem_fname,
-                     _mri_subject_has_bem, fid_fname, _map_fid_name_to_idx,
-                     _find_head_bem)
-from ..viz._3d import (_plot_head_surface, _plot_head_fiducials,
-                       _plot_head_shape_points, _plot_mri_fiducials,
-                       _plot_hpi_coils, _plot_sensors, _plot_helmet)
+from ..coreg import (
+    Coregistration,
+    _is_mri_subject,
+    scale_mri,
+    bem_fname,
+    _mri_subject_has_bem,
+    fid_fname,
+    _map_fid_name_to_idx,
+    _find_head_bem,
+)
+from ..viz._3d import (
+    _plot_head_surface,
+    _plot_head_fiducials,
+    _plot_head_shape_points,
+    _plot_mri_fiducials,
+    _plot_hpi_coils,
+    _plot_sensors,
+    _plot_helmet,
+)
 from ..viz.backends._utils import _qt_app_exec, _qt_safe_window
 from ..viz.utils import safe_event
-from ..transforms import (read_trans, write_trans, _ensure_trans, _get_trans,
-                          rotation_angles, _get_transforms_to_coord_frame)
-from ..utils import (get_subjects_dir, check_fname, _check_fname, fill_doc,
-                     verbose, logger, _validate_type)
+from ..transforms import (
+    read_trans,
+    write_trans,
+    _ensure_trans,
+    _get_trans,
+    rotation_angles,
+    _get_transforms_to_coord_frame,
+)
+from ..utils import (
+    get_subjects_dir,
+    check_fname,
+    _check_fname,
+    fill_doc,
+    verbose,
+    logger,
+    _validate_type,
+)
 from ..surface import _DistanceQuery, _CheckInside
 from ..channels import read_dig_fif
 
 
-class _WorkerData():
+class _WorkerData:
     def __init__(self, name, params=None):
         self._name = name
         self._params = params
@@ -50,9 +76,9 @@ def _get_subjects(sdir):
         dir_content = os.listdir(sdir)
         subjects = [s for s in dir_content if _is_mri_subject(s, sdir)]
         if len(subjects) == 0:
-            subjects.append('')
+            subjects.append("")
     else:
-        subjects = ['']
+        subjects = [""]
     return sorted(subjects)
 
 
@@ -136,28 +162,47 @@ class CoregistrationUI(HasTraits):
     _scale_mode = Unicode()
     _icp_fid_match = Unicode()
 
-    @_qt_safe_window(splash='_renderer.figure.splash',
-                     window='_renderer.figure.plotter')
+    @_qt_safe_window(
+        splash="_renderer.figure.splash", window="_renderer.figure.plotter"
+    )
     @verbose
-    def __init__(self, info_file, *, subject=None, subjects_dir=None,
-                 fiducials='auto', head_resolution=None,
-                 head_opacity=None, hpi_coils=None,
-                 head_shape_points=None, eeg_channels=None, orient_glyphs=None,
-                 scale_by_distance=None, mark_inside=None,
-                 sensor_opacity=None, trans=None, size=None, bgcolor=None,
-                 show=True, block=False, fullscreen=False,
-                 interaction='terrain', verbose=None):
+    def __init__(
+        self,
+        info_file,
+        *,
+        subject=None,
+        subjects_dir=None,
+        fiducials="auto",
+        head_resolution=None,
+        head_opacity=None,
+        hpi_coils=None,
+        head_shape_points=None,
+        eeg_channels=None,
+        orient_glyphs=None,
+        scale_by_distance=None,
+        mark_inside=None,
+        sensor_opacity=None,
+        trans=None,
+        size=None,
+        bgcolor=None,
+        show=True,
+        block=False,
+        fullscreen=False,
+        interaction="terrain",
+        verbose=None,
+    ):
         from ..viz.backends.renderer import _get_renderer
 
         def _get_default(var, val):
             return var if var is not None else val
+
         self._actors = dict()
         self._surfaces = dict()
         self._widgets = dict()
         self._verbose = verbose
         self._plot_locked = False
         self._params_locked = False
-        self._refresh_rate_ms = max(int(round(1000. 
/ 60.)), 1) + self._refresh_rate_ms = max(int(round(1000.0 / 60.0)), 1) self._redraws_pending = set() self._parameter_mutex = threading.Lock() self._redraw_mutex = threading.Lock() @@ -176,8 +221,8 @@ def _get_default(var, val): self._mri_scale_modified = False self._accept_close_event = True self._fid_colors = tuple( - DEFAULTS['coreg'][f'{key}_color'] for key in - ('lpa', 'nasion', 'rpa')) + DEFAULTS["coreg"][f"{key}_color"] for key in ("lpa", "nasion", "rpa") + ) self._defaults = dict( size=_get_default(size, (800, 600)), bgcolor=_get_default(bgcolor, "grey"), @@ -198,8 +243,8 @@ def _get_default(var, val): subject_to="", scale_modes=["None", "uniform", "3-axis"], scale_mode="None", - icp_fid_matches=('nearest', 'matched'), - icp_fid_match='matched', + icp_fid_matches=("nearest", "matched"), + icp_fid_match="matched", icp_n_iterations=20, omit_hsp_distance=10.0, lock_head_opacity=self._head_opacity < 1.0, @@ -221,7 +266,7 @@ def _get_default(var, val): subject = _get_default(subject, _get_subjects(subjects_dir)[0]) # setup the window - splash = 'Initializing coregistration GUI...' if show else False + splash = "Initializing coregistration GUI..." if show else False self._renderer = _get_renderer( size=self._defaults["size"], bgcolor=self._defaults["bgcolor"], @@ -233,13 +278,15 @@ def _get_default(var, val): self._renderer.set_interaction(interaction) # coregistration model setup - self._immediate_redraw = (self._renderer._kind != 'qt') + self._immediate_redraw = self._renderer._kind != "qt" self._info = info self._fiducials = fiducials self.coreg = Coregistration( - info=self._info, subject=subject, subjects_dir=subjects_dir, + info=self._info, + subject=subject, + subjects_dir=subjects_dir, fiducials=fiducials, - on_defects='ignore' # safe due to interactive visual inspection + on_defects="ignore", # safe due to interactive visual inspection ) fid_accurate = self.coreg._fid_accurate for fid in self._defaults["weights"].keys(): @@ -286,8 +333,8 @@ def _get_default(var, val): # internally self._set_fiducials_file(self.coreg._fid_filename) else: - self._set_head_resolution('high') - self._forward_widget_command('high_res_head', "set_value", True) + self._set_head_resolution("high") + self._forward_widget_command("high_res_head", "set_value", True) self._set_lock_fids(True) # hack to make the dig disappear self._update_fiducials_label() self._update_fiducials() @@ -301,20 +348,21 @@ def _get_default(var, val): if show: self._renderer.show() # update the view once shown - views = {True: dict(azimuth=90, elevation=90), # front - False: dict(azimuth=180, elevation=90)} # left + views = { + True: dict(azimuth=90, elevation=90), # front + False: dict(azimuth=180, elevation=90), + } # left self._renderer.set_camera(distance=None, **views[self._lock_fids]) self._redraw() # XXX: internal plotter/renderer should not be exposed if not self._immediate_redraw: - self._renderer.plotter.add_callback( - self._redraw, self._refresh_rate_ms) + self._renderer.plotter.add_callback(self._redraw, self._refresh_rate_ms) self._renderer.plotter.show_axes() # initialization does not count as modification by the user self._trans_modified = False self._mri_fids_modified = False self._mri_scale_modified = False - if block and self._renderer._kind != 'notebook': + if block and self._renderer._kind != "notebook": _qt_app_exec(self._renderer.figure.store["app"]) def _set_subjects_dir(self, subjects_dir): @@ -330,10 +378,8 @@ def _set_subjects_dir(self, subjects_dir): ) ) subjects = _get_subjects(subjects_dir) - 
low_res_path = _find_head_bem( - subjects[0], subjects_dir, high_res=False) - high_res_path = _find_head_bem( - subjects[0], subjects_dir, high_res=True) + low_res_path = _find_head_bem(subjects[0], subjects_dir, high_res=False) + high_res_path = _find_head_bem(subjects[0], subjects_dir, high_res=True) valid = low_res_path is not None or high_res_path is not None except Exception: valid = False @@ -352,7 +398,7 @@ def _set_lock_fids(self, state): def _set_fiducials_file(self, fname): if fname is None: - fids = 'auto' + fids = "auto" else: fname = str( _check_fname( @@ -373,17 +419,11 @@ def _set_fiducials_file(self, fname): if fname is None: self._set_lock_fids(False) - self._forward_widget_command( - 'reload_mri_fids', 'set_enabled', False - ) + self._forward_widget_command("reload_mri_fids", "set_enabled", False) else: self._set_lock_fids(True) - self._forward_widget_command( - 'reload_mri_fids', 'set_enabled', True - ) - self._display_message( - f"Loading MRI fiducials from {fname}... Done!" - ) + self._forward_widget_command("reload_mri_fids", "set_enabled", True) + self._display_message(f"Loading MRI fiducials from {fname}... Done!") def _set_current_fiducial(self, fid): self._current_fiducial = fid.lower() @@ -394,17 +434,23 @@ def _set_info_file(self, fname): # info file can be anything supported by read_raw try: - check_fname(fname, 'info', tuple(raw_supported_types.keys()), - endings_err=tuple(raw_supported_types.keys())) + check_fname( + fname, + "info", + tuple(raw_supported_types.keys()), + endings_err=tuple(raw_supported_types.keys()), + ) fname = str(_check_fname(fname, overwrite="read")) # cast to str # ctf ds `files` are actually directories - if fname.endswith(('.ds',)): + if fname.endswith((".ds",)): info_file = _check_fname( - fname, overwrite='read', must_exist=True, need_dir=True) + fname, overwrite="read", must_exist=True, need_dir=True + ) else: info_file = _check_fname( - fname, overwrite='read', must_exist=True, need_dir=False) + fname, overwrite="read", must_exist=True, need_dir=False + ) valid = True except OSError: valid = False @@ -450,14 +496,12 @@ def _set_grow_hair(self, value): def _set_subject_to(self, value): self._subject_to = value - self._forward_widget_command( - "save_subject", "set_enabled", len(value) > 0) + self._forward_widget_command("save_subject", "set_enabled", len(value) > 0) if self._check_subject_exists(): style = dict(border="2px solid #ff0000") else: style = dict(border="initial") - self._forward_widget_command( - "subject_to", "set_style", style) + self._forward_widget_command("subject_to", "set_style", style) def _set_scale_mode(self, mode): self._scale_mode = mode @@ -470,7 +514,7 @@ def _set_fiducial(self, value, coord): coords = ["X", "Y", "Z"] coord_idx = coords.index(coord) - self.coreg.fiducials.dig[fid_idx]['r'][coord_idx] = value / 1e3 + self.coreg.fiducials.dig[fid_idx]["r"][coord_idx] = value / 1e3 self._update_plot("mri_fids") def _set_parameter(self, value, mode_name, coord, plot_locked=False): @@ -482,10 +526,9 @@ def _set_parameter(self, value, mode_name, coord, plot_locked=False): return if mode_name == "scale" and self._scale_mode == "uniform": with self._lock(params=True): - self._forward_widget_command( - ["sY", "sZ"], "set_value", value) + self._forward_widget_command(["sY", "sZ"], "set_value", value) with self._parameter_mutex: - self. 
_set_parameter_safe(value, mode_name, coord) + self._set_parameter_safe(value, mode_name, coord) if not plot_locked: self._update_plot("sensors") @@ -521,9 +564,9 @@ def _set_icp_fid_match(self, method): def _set_point_weight(self, weight, point): funcs = { - 'hpi': '_set_hpi_coils', - 'hsp': '_set_head_shape_points', - 'eeg': '_set_eeg_channels', + "hpi": "_set_hpi_coils", + "hsp": "_set_head_shape_points", + "eeg": "_set_eeg_channels", } if point in funcs.keys(): getattr(self, funcs[point])(weight > 0) @@ -567,70 +610,90 @@ def _lock_fids_changed(self, change=None): # MRI fiducials "save_mri_fids", # View options - "helmet", "head_opacity", "high_res_head", + "helmet", + "head_opacity", + "high_res_head", # Digitization source - "info_file", "grow_hair", "omit_distance", "omit", "reset_omit", + "info_file", + "grow_hair", + "omit_distance", + "omit", + "reset_omit", # Scaling - "scaling_mode", "sX", "sY", "sZ", + "scaling_mode", + "sX", + "sY", + "sZ", # Transformation - "tX", "tY", "tZ", - "rX", "rY", "rZ", + "tX", + "tY", + "tZ", + "rX", + "rY", + "rZ", # Fitting buttons - "fit_fiducials", "fit_icp", + "fit_fiducials", + "fit_icp", # Transformation I/O - "save_trans", "load_trans", + "save_trans", + "load_trans", "reset_trans", # ICP - "icp_n_iterations", "icp_fid_match", "reset_fitting_options", + "icp_n_iterations", + "icp_fid_match", + "reset_fitting_options", # Weights - "hsp_weight", "eeg_weight", "hpi_weight", - "lpa_weight", "nasion_weight", "rpa_weight", + "hsp_weight", + "eeg_weight", + "hpi_weight", + "lpa_weight", + "nasion_weight", + "rpa_weight", ] fits_widgets = ["fits_fiducials", "fits_icp"] fid_widgets = ["fid_X", "fid_Y", "fid_Z", "fids_file", "fids"] if self._lock_fids: self._forward_widget_command(locked_widgets, "set_enabled", True) self._forward_widget_command( - 'head_opacity', 'set_value', self._old_head_opacity + "head_opacity", "set_value", self._old_head_opacity ) self._scale_mode_changed() self._display_message() self._update_distance_estimation() else: self._old_head_opacity = self._head_opacity - self._forward_widget_command( - 'head_opacity', 'set_value', 1.0 - ) + self._forward_widget_command("head_opacity", "set_value", 1.0) self._forward_widget_command(locked_widgets, "set_enabled", False) self._forward_widget_command(fits_widgets, "set_enabled", False) - self._display_message("Placing MRI fiducials - " - f"{self._current_fiducial.upper()}") + self._display_message( + "Placing MRI fiducials - " f"{self._current_fiducial.upper()}" + ) self._set_sensors_visibility(self._lock_fids) self._forward_widget_command("lock_fids", "set_value", self._lock_fids) - self._forward_widget_command(fid_widgets, "set_enabled", - not self._lock_fids) + self._forward_widget_command(fid_widgets, "set_enabled", not self._lock_fids) @observe("_current_fiducial") def _current_fiducial_changed(self, change=None): self._update_fiducials() self._follow_fiducial_view() if not self._lock_fids: - self._display_message("Placing MRI fiducials - " - f"{self._current_fiducial.upper()}") + self._display_message( + "Placing MRI fiducials - " f"{self._current_fiducial.upper()}" + ) @observe("_info_file") def _info_file_changed(self, change=None): if not self._info_file: return - elif self._info_file.endswith(('.fif', '.fif.gz')): + elif self._info_file.endswith((".fif", ".fif.gz")): fid, tree, _ = fiff_open(self._info_file) fid.close() if len(dir_tree_find(tree, FIFF.FIFFB_MEAS_INFO)) > 0: self._info = read_info(self._info_file, verbose=False) elif len(dir_tree_find(tree, 
FIFF.FIFFB_ISOTRAK)) > 0: self._info = _empty_info(1) - self._info['dig'] = read_dig_fif(fname=self._info_file).dig + self._info["dig"] = read_dig_fif(fname=self._info_file).dig self._info._unlocked = False else: self._info = read_raw(self._info_file).info @@ -689,10 +752,12 @@ def _scale_mode_changed(self, change=None): mode = None if self._scale_mode == "None" else self._scale_mode self.coreg.set_scale_mode(mode) if self._lock_fids: - self._forward_widget_command(locked_widgets, "set_enabled", - mode is not None) - self._forward_widget_command("fits_fiducials", "set_enabled", - mode not in (None, "3-axis")) + self._forward_widget_command( + locked_widgets, "set_enabled", mode is not None + ) + self._forward_widget_command( + "fits_fiducials", "set_enabled", mode not in (None, "3-axis") + ) if self._scale_mode == "uniform": self._forward_widget_command(["sY", "sZ"], "set_enabled", False) @@ -712,13 +777,15 @@ def _run_worker(self, queue, jobs): def _configure_dialogs(self): from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + for name, buttons in zip( - ["overwrite_subject", "overwrite_subject_exit"], - [["Yes", "No"], ["Yes", "Discard", "Cancel"]]): + ["overwrite_subject", "overwrite_subject_exit"], + [["Yes", "No"], ["Yes", "Discard", "Cancel"]], + ): self._widgets[name] = self._renderer._dialog_create( title="CoregistrationUI", text="The name of the output subject used to " - "save the scaled anatomy already exists.", + "save the scaled anatomy already exists.", info_text="Do you want to overwrite?", callback=self._overwrite_subject_callback, buttons=buttons, @@ -731,11 +798,13 @@ def _configure_worker(self): "_parameter_queue": dict(set_parameter=self._set_parameter), } for queue_name, jobs in work_plan.items(): - t = threading.Thread(target=partial( - self._run_worker, - queue=getattr(self, queue_name), - jobs=jobs, - )) + t = threading.Thread( + target=partial( + self._run_worker, + queue=getattr(self, queue_name), + jobs=jobs, + ) + ) t.daemon = True t.start() @@ -744,14 +813,15 @@ def _configure_picking(self): self._on_mouse_move, self._on_button_press, self._on_button_release, - self._on_pick + self._on_pick, ) def _configure_legend(self): - colors = \ - [np.array(DEFAULTS['coreg'][f"{fid.lower()}_color"]).astype(float) - for fid in self._defaults['fiducials']] - labels = list(zip(self._defaults['fiducials'], colors)) + colors = [ + np.array(DEFAULTS["coreg"][f"{fid.lower()}_color"]).astype(float) + for fid in self._defaults["fiducials"] + ] + labels = list(zip(self._defaults["fiducials"], colors)) mri_fids_legend_actor = self._renderer.legend(labels=labels) self._update_actor("mri_fids_legend", mri_fids_legend_actor) @@ -772,16 +842,16 @@ def _redraw(self, *, verbose=None): # We need at least "head" before "hsp", because the grow_hair param # for head sets the rr that are used for inside/outside hsp redraws_ordered = sorted( - self._redraws_pending, - key=lambda key: list(draw_map).index(key)) - logger.debug(f'Redrawing {redraws_ordered}') + self._redraws_pending, key=lambda key: list(draw_map).index(key) + ) + logger.debug(f"Redrawing {redraws_ordered}") for ki, key in enumerate(redraws_ordered): - logger.debug(f'{ki}. Drawing {repr(key)}') + logger.debug(f"{ki}. 
Drawing {repr(key)}") draw_map[key]() self._redraws_pending.clear() self._renderer._update() # necessary for MacOS - if platform.system() == 'Darwin': + if platform.system() == "Darwin": self._renderer._process_events() def _on_mouse_move(self, vtk_picker, event): @@ -813,8 +883,10 @@ def _on_pick(self, vtk_picker, event): return pos = np.array(vtk_picker.GetPickPosition()) vtk_cell = mesh.GetCell(cell_id) - cell = [vtk_cell.GetPointId(point_id) for point_id - in range(vtk_cell.GetNumberOfPoints())] + cell = [ + vtk_cell.GetPointId(point_id) + for point_id in range(vtk_cell.GetNumberOfPoints()) + ] vertices = mesh.points[cell] idx = np.argmin(abs(vertices - pos), axis=0) vertex_id = cell[idx[0]] @@ -828,14 +900,16 @@ def _on_pick(self, vtk_picker, event): self._update_plot("mri_fids") def _reset_fitting_parameters(self): - self._forward_widget_command("icp_n_iterations", "set_value", - self._defaults["icp_n_iterations"]) - self._forward_widget_command("icp_fid_match", "set_value", - self._defaults["icp_fid_match"]) - weights_widgets = [f"{w}_weight" - for w in self._defaults["weights"].keys()] - self._forward_widget_command(weights_widgets, "set_value", - list(self._defaults["weights"].values())) + self._forward_widget_command( + "icp_n_iterations", "set_value", self._defaults["icp_n_iterations"] + ) + self._forward_widget_command( + "icp_fid_match", "set_value", self._defaults["icp_fid_match"] + ) + weights_widgets = [f"{w}_weight" for w in self._defaults["weights"].keys()] + self._forward_widget_command( + weights_widgets, "set_value", list(self._defaults["weights"].values()) + ) def _reset_fiducials(self): self._set_current_fiducial(self._defaults["fiducial"]) @@ -843,21 +917,22 @@ def _reset_fiducials(self): def _omit_hsp(self): self.coreg.omit_head_shape_points(self._omit_hsp_distance / 1e3) n_omitted = np.sum(~self.coreg._extra_points_filter) - n_remaining = len(self.coreg._dig_dict['hsp']) - n_omitted + n_remaining = len(self.coreg._dig_dict["hsp"]) - n_omitted self._update_plot("hsp") self._update_distance_estimation() self._display_message( - f"{n_omitted} head shape points omitted, " - f"{n_remaining} remaining.") + f"{n_omitted} head shape points omitted, " f"{n_remaining} remaining." + ) def _reset_omit_hsp_filter(self): self.coreg._extra_points_filter = None self.coreg._update_params(force_update=True) self._update_plot("hsp") self._update_distance_estimation() - n_total = len(self.coreg._dig_dict['hsp']) + n_total = len(self.coreg._dig_dict["hsp"]) self._display_message( - f"No head shape point is omitted, the total is {n_total}.") + f"No head shape point is omitted, the total is {n_total}." 
+ ) @verbose def _update_plot(self, changes="all", verbose=None): @@ -866,9 +941,8 @@ def _update_plot(self, changes="all", verbose=None): try: fun_name = inspect.currentframe().f_back.f_back.f_code.co_name except Exception: # just in case one of these attrs is missing - fun_name = 'unknown' - logger.debug( - f'Updating plots based on {fun_name}: {repr(changes)}') + fun_name = "unknown" + logger.debug(f"Updating plots based on {fun_name}: {repr(changes)}") if self._plot_locked: return if self._info is None: @@ -876,15 +950,20 @@ def _update_plot(self, changes="all", verbose=None): self._to_cf_t = dict(mri=dict(trans=np.eye(4)), head=None) else: self._to_cf_t = _get_transforms_to_coord_frame( - self._info, self.coreg.trans, coord_frame=self._coord_frame) + self._info, self.coreg.trans, coord_frame=self._coord_frame + ) all_keys = ( - 'head', 'mri_fids', # MRI first - 'hsp', 'hpi', 'eeg', 'head_fids', # then dig - 'helmet', - ) - if changes == 'all': + "head", + "mri_fids", # MRI first + "hsp", + "hpi", + "eeg", + "head_fids", # then dig + "helmet", + ) + if changes == "all": changes = list(all_keys) - elif changes == 'sensors': + elif changes == "sensors": changes = all_keys[2:] # omit MRI ones elif isinstance(changes, str): changes = [changes] @@ -894,7 +973,7 @@ def _update_plot(self, changes="all", verbose=None): # it would reduce "jerkiness" of the updates, but this should at least # work okay bad = changes.difference(set(all_keys)) - assert len(bad) == 0, f'Unknown changes: {bad}' + assert len(bad) == 0, f"Unknown changes: {bad}" self._redraws_pending.update(changes) if self._immediate_redraw: self._redraw() @@ -913,15 +992,24 @@ def _lock(self, plot=False, params=False, scale_mode=False, fitting=False): self.coreg._scale_mode = None if fitting: widgets = [ - "sX", "sY", "sZ", - "tX", "tY", "tZ", - "rX", "rY", "rZ", - "fit_icp", "fit_fiducials", "fits_icp", "fits_fiducials" + "sX", + "sY", + "sZ", + "tX", + "tY", + "tZ", + "rX", + "rY", + "rZ", + "fit_icp", + "fit_fiducials", + "fits_icp", + "fits_fiducials", ] states = [ self._forward_widget_command( - w, "is_enabled", None, - input_value=False, output_value=True) + w, "is_enabled", None, input_value=False, output_value=True + ) for w in widgets ] self._forward_widget_command(widgets, "set_enabled", False) @@ -939,21 +1027,19 @@ def _lock(self, plot=False, params=False, scale_mode=False, fitting=False): self._forward_widget_command(w, "set_enabled", states[idx]) def _display_message(self, msg=""): - self._forward_widget_command('status_message', 'set_value', msg) + self._forward_widget_command("status_message", "set_value", msg) + self._forward_widget_command("status_message", "show", None, input_value=False) self._forward_widget_command( - 'status_message', 'show', None, input_value=False - ) - self._forward_widget_command( - 'status_message', 'update', None, input_value=False + "status_message", "update", None, input_value=False ) if msg: logger.info(msg) def _follow_fiducial_view(self): fid = self._current_fiducial.lower() - view = dict(lpa='left', rpa='right', nasion='front') - kwargs = dict(front=(90., 90.), left=(180, 90), right=(0., 90)) - kwargs = dict(zip(('azimuth', 'elevation'), kwargs[view[fid]])) + view = dict(lpa="left", rpa="right", nasion="front") + kwargs = dict(front=(90.0, 90.0), left=(180, 90), right=(0.0, 90)) + kwargs = dict(zip(("azimuth", "elevation"), kwargs[view[fid]])) if not self._lock_fids: self._renderer.set_camera(distance=None, **kwargs) @@ -963,35 +1049,39 @@ def _update_fiducials(self): return idx = 
_map_fid_name_to_idx(name=fid) - val = self.coreg.fiducials.dig[idx]['r'] * 1e3 + val = self.coreg.fiducials.dig[idx]["r"] * 1e3 with self._lock(plot=True): - self._forward_widget_command( - ["fid_X", "fid_Y", "fid_Z"], "set_value", val) + self._forward_widget_command(["fid_X", "fid_Y", "fid_Z"], "set_value", val) def _update_distance_estimation(self): - value = self.coreg._get_fiducials_distance_str() + '\n' + \ - self.coreg._get_point_distance_str() + value = ( + self.coreg._get_fiducials_distance_str() + + "\n" + + self.coreg._get_point_distance_str() + ) dists = self.coreg.compute_dig_mri_distances() * 1e3 if self._hsp_weight > 0: - value += "\nHSP <-> MRI (mean/min/max): "\ - f"{np.mean(dists):.2f} "\ + value += ( + "\nHSP <-> MRI (mean/min/max): " + f"{np.mean(dists):.2f} " f"/ {np.min(dists):.2f} / {np.max(dists):.2f} mm" + ) self._forward_widget_command("fit_label", "set_value", value) def _update_parameters(self): with self._lock(plot=True, params=True): # rotation deg = np.rad2deg(self.coreg._rotation) - logger.debug(f' Rotation: {deg}') + logger.debug(f" Rotation: {deg}") self._forward_widget_command(["rX", "rY", "rZ"], "set_value", deg) # translation mm = self.coreg._translation * 1e3 - logger.debug(f' Translation: {mm}') + logger.debug(f" Translation: {mm}") self._forward_widget_command(["tX", "tY", "tZ"], "set_value", mm) # scale sc = self.coreg._scale * 1e2 - logger.debug(f' Scale: {sc}') + logger.debug(f" Scale: {sc}") self._forward_widget_command(["sX", "sY", "sZ"], "set_value", sc) def _reset(self, keep_trans=False): @@ -1011,8 +1101,9 @@ def _reset(self, keep_trans=False): self._update_parameters() self._update_distance_estimation() - def _forward_widget_command(self, names, command, value, - input_value=True, output_value=False): + def _forward_widget_command( + self, names, command, value, input_value=True, output_value=False + ): """Invoke a method of one or more widgets if the widgets exist. Parameters @@ -1035,11 +1126,7 @@ def _forward_widget_command(self, names, command, value, ``None`` if ``output_value`` is ``False``, and the return value of ``command`` otherwise. 
""" - _validate_type( - item=names, - types=(str, list), - item_name='names' - ) + _validate_type(item=names, types=(str, list), item_name="names") if isinstance(names, str): names = [names] @@ -1058,8 +1145,7 @@ def _forward_widget_command(self, names, command, value, return ret def _set_sensors_visibility(self, state): - sensors = ["head_fiducials", "hpi_coils", "head_shape_points", - "eeg_channels"] + sensors = ["head_fiducials", "hpi_coils", "head_shape_points", "eeg_channels"] for sensor in sensors: if sensor in self._actors and self._actors[sensor] is not None: actors = self._actors[sensor] @@ -1070,14 +1156,18 @@ def _set_sensors_visibility(self, state): def _update_actor(self, actor_name, actor): # XXX: internal plotter/renderer should not be exposed - self._renderer.plotter.remove_actor(self._actors.get(actor_name), - render=False) + self._renderer.plotter.remove_actor(self._actors.get(actor_name), render=False) self._actors[actor_name] = actor def _add_mri_fiducials(self): mri_fids_actors = _plot_mri_fiducials( - self._renderer, self.coreg._fid_points, self._subjects_dir, - self._subject, self._to_cf_t, self._fid_colors) + self._renderer, + self.coreg._fid_points, + self._subjects_dir, + self._subject, + self._to_cf_t, + self._fid_colors, + ) # disable picking on the markers for actor in mri_fids_actors: actor.SetPickable(False) @@ -1085,19 +1175,24 @@ def _add_mri_fiducials(self): def _add_head_fiducials(self): head_fids_actors = _plot_head_fiducials( - self._renderer, self._info, self._to_cf_t, self._fid_colors) + self._renderer, self._info, self._to_cf_t, self._fid_colors + ) self._update_actor("head_fiducials", head_fids_actors) def _add_hpi_coils(self): if self._hpi_coils: hpi_actors = _plot_hpi_coils( - self._renderer, self._info, self._to_cf_t, + self._renderer, + self._info, + self._to_cf_t, opacity=self._defaults["sensor_opacity"], scale=DEFAULTS["coreg"]["extra_scale"], orient_glyphs=self._orient_glyphs, scale_by_distance=self._scale_by_distance, - surf=self._head_geo, check_inside=self._check_inside, - nearest=self._nearest) + surf=self._head_geo, + check_inside=self._check_inside, + nearest=self._nearest, + ) else: hpi_actors = None self._update_actor("hpi_coils", hpi_actors) @@ -1105,13 +1200,18 @@ def _add_hpi_coils(self): def _add_head_shape_points(self): if self._head_shape_points: hsp_actors = _plot_head_shape_points( - self._renderer, self._info, self._to_cf_t, + self._renderer, + self._info, + self._to_cf_t, opacity=self._defaults["sensor_opacity"], orient_glyphs=self._orient_glyphs, scale_by_distance=self._scale_by_distance, - mark_inside=self._mark_inside, surf=self._head_geo, + mark_inside=self._mark_inside, + surf=self._head_geo, mask=self.coreg._extra_points_filter, - check_inside=self._check_inside, nearest=self._nearest) + check_inside=self._check_inside, + nearest=self._nearest, + ) else: hsp_actors = None self._update_actor("head_shape_points", hsp_actors) @@ -1122,14 +1222,23 @@ def _add_eeg_channels(self): picks = pick_types(self._info, eeg=(len(eeg) > 0), fnirs=True) if len(picks) > 0: actors = _plot_sensors( - self._renderer, self._info, self._to_cf_t, picks, - meg=False, eeg=eeg, fnirs=["sources", "detectors"], - warn_meg=False, head_surf=self._head_geo, units='m', + self._renderer, + self._info, + self._to_cf_t, + picks, + meg=False, + eeg=eeg, + fnirs=["sources", "detectors"], + warn_meg=False, + head_surf=self._head_geo, + units="m", sensor_opacity=self._defaults["sensor_opacity"], orient_glyphs=self._orient_glyphs, 
scale_by_distance=self._scale_by_distance, - surf=self._head_geo, check_inside=self._check_inside, - nearest=self._nearest) + surf=self._head_geo, + check_inside=self._check_inside, + nearest=self._nearest, + ) sens_actors = actors["eeg"] sens_actors.extend(actors["fnirs"]) else: @@ -1141,22 +1250,34 @@ def _add_eeg_channels(self): def _add_head_surface(self): bem = None if self._head_resolution: - surface = 'head-dense' - key = 'high' + surface = "head-dense" + key = "high" else: - surface = 'head' - key = 'low' + surface = "head" + key = "low" try: head_actor, head_surf, _ = _plot_head_surface( - self._renderer, surface, self._subject, - self._subjects_dir, bem, self._coord_frame, self._to_cf_t, - alpha=self._head_opacity) + self._renderer, + surface, + self._subject, + self._subjects_dir, + bem, + self._coord_frame, + self._to_cf_t, + alpha=self._head_opacity, + ) except OSError: head_actor, head_surf, _ = _plot_head_surface( - self._renderer, "head", self._subject, self._subjects_dir, - bem, self._coord_frame, self._to_cf_t, - alpha=self._head_opacity) - key = 'low' + self._renderer, + "head", + self._subject, + self._subjects_dir, + bem, + self._coord_frame, + self._to_cf_t, + alpha=self._head_opacity, + ) + key = "low" self._update_actor("head", head_actor) # mark head surface mesh to restrict picking head_surf._picking_target = True @@ -1170,16 +1291,16 @@ def _add_head_surface(self): nn = self._surfaces["head"].point_normals assert nn.shape == (len(rr), 3), nn.shape self._head_geo = dict(rr=rr, tris=tris, nn=nn) - self._check_inside = _CheckInside(head_surf, mode='pyvista') + self._check_inside = _CheckInside(head_surf, mode="pyvista") self._nearest = _DistanceQuery(rr) def _add_helmet(self): if self._helmet: - logger.debug('Drawing helmet') - head_mri_t = _get_trans(self.coreg.trans, 'head', 'mri')[0] + logger.debug("Drawing helmet") + head_mri_t = _get_trans(self.coreg.trans, "head", "mri")[0] helmet_actor, _, _ = _plot_helmet( - self._renderer, self._info, self._to_cf_t, head_mri_t, - self._coord_frame) + self._renderer, self._info, self._to_cf_t, head_mri_t, self._coord_frame + ) else: helmet_actor = None self._update_actor("helmet", helmet_actor) @@ -1199,7 +1320,8 @@ def _fits_fiducials(self): ) end = time.time() self._display_message( - f"Fitting fiducials finished in {end - start:.2f} seconds.") + f"Fitting fiducials finished in {end - start:.2f} seconds." + ) self._update_plot("sensors") self._update_parameters() self._update_distance_estimation() @@ -1214,13 +1336,12 @@ def _fits_icp(self): def _fit_icp_real(self, *, update_head): with self._lock(params=True, fitting=True): self._current_icp_iterations = 0 - updates = ['hsp', 'hpi', 'eeg', 'head_fids', 'helmet'] + updates = ["hsp", "hpi", "eeg", "head_fids", "helmet"] if update_head: - updates.insert(0, 'head') + updates.insert(0, "head") def callback(iteration, n_iterations): - self._display_message( - f"Fitting ICP - iteration {iteration + 1}") + self._display_message(f"Fitting ICP - iteration {iteration + 1}") self._update_plot(updates) self._current_icp_iterations += 1 self._update_distance_estimation() @@ -1240,11 +1361,13 @@ def callback(iteration, n_iterations): self._display_message() self._display_message( f"Fitting ICP finished in {end - start:.2f} seconds and " - f"{self._current_icp_iterations} iterations.") + f"{self._current_icp_iterations} iterations." 
+ ) del self._current_icp_iterations def _task_save_subject(self): from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + if MNE_3D_BACKEND_TESTING: self._save_subject() else: @@ -1252,12 +1375,21 @@ def _task_save_subject(self): def _task_set_parameter(self, value, mode_name, coord): from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + if MNE_3D_BACKEND_TESTING: self._set_parameter(value, mode_name, coord, self._plot_locked) else: - self._parameter_queue.put(_WorkerData("set_parameter", dict( - value=value, mode_name=mode_name, coord=coord, - plot_locked=self._plot_locked))) + self._parameter_queue.put( + _WorkerData( + "set_parameter", + dict( + value=value, + mode_name=mode_name, + coord=coord, + plot_locked=self._plot_locked, + ), + ) + ) def _overwrite_subject_callback(self, button_name): if button_name == "Yes": @@ -1270,9 +1402,10 @@ def _overwrite_subject_callback(self, button_name): def _check_subject_exists(self): if not self._subject_to: return False - subject_dirname = os.path.join('{subjects_dir}', '{subject}') - dest = subject_dirname.format(subject=self._subject_to, - subjects_dir=self._subjects_dir) + subject_dirname = os.path.join("{subjects_dir}", "{subject}") + dest = subject_dirname.format( + subject=self._subject_to, subjects_dir=self._subjects_dir + ) return os.path.exists(dest) def _save_subject(self, exit_mode=False): @@ -1286,19 +1419,19 @@ def _save_subject_callback(self, overwrite=False): self._display_message(f"Saving {self._subject_to}...") default_cursor = self._renderer._window_get_cursor() self._renderer._window_set_cursor( - self._renderer._window_new_cursor("WaitCursor")) + self._renderer._window_new_cursor("WaitCursor") + ) # prepare bem bem_names = [] if self._scale_mode != "None": - can_prepare_bem = _mri_subject_has_bem( - self._subject, self._subjects_dir) + can_prepare_bem = _mri_subject_has_bem(self._subject, self._subjects_dir) else: can_prepare_bem = False if can_prepare_bem: - pattern = bem_fname.format(subjects_dir=self._subjects_dir, - subject=self._subject, - name='(.+-bem)') + pattern = bem_fname.format( + subjects_dir=self._subjects_dir, subject=self._subject, name="(.+-bem)" + ) bem_dir, pattern = os.path.split(pattern) for filename in os.listdir(bem_dir): match = re.match(pattern, filename) @@ -1309,10 +1442,15 @@ def _save_subject_callback(self, overwrite=False): try: self._display_message(f"Scaling {self._subject_to}...") scale_mri( - subject_from=self._subject, subject_to=self._subject_to, - scale=self.coreg._scale, overwrite=overwrite, - subjects_dir=self._subjects_dir, skip_fiducials=True, - labels=True, annot=True, on_defects='ignore' + subject_from=self._subject, + subject_to=self._subject_to, + scale=self.coreg._scale, + overwrite=overwrite, + subjects_dir=self._subjects_dir, + skip_fiducials=True, + labels=True, + annot=True, + on_defects="ignore", ) except Exception: logger.error(f"Error scaling {self._subject_to}") @@ -1324,16 +1462,17 @@ def _save_subject_callback(self, overwrite=False): for bem_name in bem_names: try: self._display_message(f"Computing {bem_name} solution...") - bem_file = bem_fname.format(subjects_dir=self._subjects_dir, - subject=self._subject_to, - name=bem_name) + bem_file = bem_fname.format( + subjects_dir=self._subjects_dir, + subject=self._subject_to, + name=bem_name, + ) bemsol = make_bem_solution(bem_file) - write_bem_solution(bem_file[:-4] + '-sol.fif', bemsol) + write_bem_solution(bem_file[:-4] + "-sol.fif", bemsol) except Exception: logger.error(f"Error computing {bem_name} solution") 
             else:
-                self._display_message(f"Computing {bem_name} solution..."
-                                      " Done!")
+                self._display_message(f"Computing {bem_name} solution..." " Done!")
         self._display_message(f"Saving {self._subject_to}... Done!")
         self._renderer._window_set_cursor(default_cursor)
         self._mri_scale_modified = False
@@ -1342,7 +1481,7 @@ def _save_mri_fiducials(self, fname):
         self._display_message(f"Saving {fname}...")
         dig_montage = self.coreg.fiducials
         write_fiducials(
-            fname=fname, pts=dig_montage.dig, coord_frame='mri', overwrite=True
+            fname=fname, pts=dig_montage.dig, coord_frame="mri", overwrite=True
         )
         self._set_fiducials_file(fname)
         self._display_message(f"Saving {fname}... Done!")
@@ -1350,13 +1489,13 @@ def _save_trans(self, fname):
 
     def _save_trans(self, fname):
         write_trans(fname, self.coreg.trans, overwrite=True)
-        self._display_message(
-            f"{fname} transform file is saved.")
+        self._display_message(f"{fname} transform file is saved.")
         self._trans_modified = False
 
     def _load_trans(self, fname):
-        mri_head_t = _ensure_trans(read_trans(fname, return_all=True),
-                                   'mri', 'head')['trans']
+        mri_head_t = _ensure_trans(read_trans(fname, return_all=True), "mri", "head")[
+            "trans"
+        ]
         rot_x, rot_y, rot_z = rotation_angles(mri_head_t)
         x, y, z = mri_head_t[:3, 3]
         self.coreg._update_params(
@@ -1366,17 +1505,16 @@ def _load_trans(self, fname):
         self._update_parameters()
         self._update_distance_estimation()
         self._update_plot()
-        self._display_message(
-            f"{fname} transform file is loaded.")
+        self._display_message(f"{fname} transform file is loaded.")
 
     def _update_fiducials_label(self):
         if self._fiducials_file is None:
             text = (
-                '<p><strong>No custom MRI fiducials loaded!</strong></p>'
-                '<p>MRI fiducials could not be found in the standard '
-                'location. The displayed initial MRI fiducial locations '
-                '(diamonds) were derived from fsaverage. Place, lock, and '
-                'save fiducials to discard this message.</p>'
+                "<p><strong>No custom MRI fiducials loaded!</strong></p>"
+                "<p>MRI fiducials could not be found in the standard "
+                "location. The displayed initial MRI fiducial locations "
+                "(diamonds) were derived from fsaverage. Place, lock, and "
+                "save fiducials to discard this message.</p>"
             )
         else:
             assert self._fiducials_file == fid_fname.format(
@@ -1384,30 +1522,24 @@ def _update_fiducials_label(self):
             )
             assert self.coreg._fid_accurate is True
             text = (
-                f'<p><strong>MRI fiducials (diamonds) loaded from '
-                f'standard location:</strong></p>'
-                f'<p>{self._fiducials_file}</p>'
+                f"<p><strong>MRI fiducials (diamonds) loaded from "
+                f"standard location:</strong></p>"
+                f"<p>{self._fiducials_file}</p>"
             )
-        self._forward_widget_command(
-            'mri_fiducials_label', 'set_value', text
-        )
+        self._forward_widget_command("mri_fiducials_label", "set_value", text)
 
     def _configure_dock(self):
-        if self._renderer._kind == 'notebook':
+        if self._renderer._kind == "notebook":
            collapse = True  # collapsible and collapsed
         else:
             collapse = None  # not collapsible
-        self._renderer._dock_initialize(
-            name="Input", area="left", max_width="350px"
-        )
+        self._renderer._dock_initialize(name="Input", area="left", max_width="350px")
         mri_subject_layout = self._renderer._dock_add_group_box(
             name="MRI Subject",
             collapse=collapse,
         )
-        subjects_dir_layout = self._renderer._dock_add_layout(
-            vertical=False
-        )
+        subjects_dir_layout = self._renderer._dock_add_layout(vertical=False)
         self._widgets["subjects_dir_field"] = self._renderer._dock_add_text(
             name="subjects_dir_field",
             value=self._subjects_dir,
@@ -1422,7 +1554,7 @@ def _configure_dock(self):
             is_directory=True,
             icon=True,
             tooltip="Load the path to the directory containing the "
-                    "FreeSurfer subjects",
+            "FreeSurfer subjects",
             layout=subjects_dir_layout,
         )
         self._renderer._layout_add_widget(
@@ -1444,38 +1576,33 @@ def _configure_dock(self):
             collapse=collapse,
         )
         # Add MRI fiducials I/O widgets
-        self._widgets['mri_fiducials_label'] = self._renderer._dock_add_label(
-            value='',  # Will be filled via _update_fiducials_label()
+        self._widgets["mri_fiducials_label"] = self._renderer._dock_add_label(
+            value="",  # Will be filled via _update_fiducials_label()
             layout=mri_fiducials_layout,
-            selectable=True
+            selectable=True,
         )
         # Reload & Save buttons go into their own layout widget
-        mri_fiducials_button_layout = self._renderer._dock_add_layout(
-            vertical=False
-        )
+        mri_fiducials_button_layout = self._renderer._dock_add_layout(vertical=False)
         self._renderer._layout_add_widget(
-            layout=mri_fiducials_layout,
-            widget=mri_fiducials_button_layout
+            layout=mri_fiducials_layout, widget=mri_fiducials_button_layout
         )
         self._widgets["reload_mri_fids"] = self._renderer._dock_add_button(
-            name='Reload MRI Fid.',
+            name="Reload MRI Fid.",
             callback=lambda: self._set_fiducials_file(self._fiducials_file),
             tooltip="Reload MRI fiducials from the standard location",
             layout=mri_fiducials_button_layout,
        )
         # Disable reload button until we've actually loaded a fiducial file
         # (happens in _set_fiducials_file method)
-        self._forward_widget_command('reload_mri_fids', 'set_enabled', False)
+        self._forward_widget_command("reload_mri_fids", "set_enabled", False)
         self._widgets["save_mri_fids"] = self._renderer._dock_add_button(
             name="Save MRI Fid.",
             callback=lambda: self._save_mri_fiducials(
-                fid_fname.format(
-                    subjects_dir=self._subjects_dir, subject=self._subject
-                )
+                fid_fname.format(subjects_dir=self._subjects_dir, subject=self._subject)
             ),
             tooltip="Save MRI fiducials to the standard location. 
Fiducials " - "must be locked first!", + "must be locked first!", layout=mri_fiducials_button_layout, ) self._widgets["lock_fids"] = self._renderer._dock_add_check_box( @@ -1497,7 +1624,7 @@ def _configure_dock(self): name = f"fid_{coord}" self._widgets[name] = self._renderer._dock_add_spin_box( name=coord, - value=0., + value=0.0, rng=[-1e3, 1e3], callback=partial( self._set_fiducial, @@ -1509,16 +1636,13 @@ def _configure_dock(self): tooltip=f"Set the {coord} fiducial coordinate", layout=fiducial_coords_layout, ) - self._renderer._layout_add_widget( - mri_fiducials_layout, fiducial_coords_layout) + self._renderer._layout_add_widget(mri_fiducials_layout, fiducial_coords_layout) dig_source_layout = self._renderer._dock_add_group_box( name="Info source with digitization", collapse=collapse, ) - info_file_layout = self._renderer._dock_add_layout( - vertical=False - ) + info_file_layout = self._renderer._dock_add_layout(vertical=False) self._widgets["info_file_field"] = self._renderer._dock_add_text( name="info_file_field", value=self._info_file, @@ -1531,8 +1655,7 @@ def _configure_dock(self): desc="Load", func=self._set_info_file, icon=True, - tooltip="Load the FIFF file with digitization data for " - "coregistration", + tooltip="Load the FIFF file with digitization data for " "coregistration", layout=info_file_layout, ) self._renderer._layout_add_widget( @@ -1561,7 +1684,7 @@ def _configure_dock(self): name="Omit", callback=self._omit_hsp, tooltip="Exclude the head shape points that are far away from " - "the MRI head", + "the MRI head", layout=omit_hsp_layout_2, ) self._widgets["reset_omit"] = self._renderer._dock_add_button( @@ -1629,7 +1752,7 @@ def _configure_dock(self): self._widgets[name] = self._renderer._dock_add_spin_box( name=name, value=attr[coords.index(coord)] * 1e2, - rng=[1., 10000.], # percent + rng=[1.0, 10000.0], # percent callback=partial( self._set_parameter, mode_name="scale", @@ -1647,18 +1770,17 @@ def _configure_dock(self): name="Fit fiducials with scaling", callback=self._fits_fiducials, tooltip="Find MRI scaling, rotation, and translation to fit all " - "3 fiducials", + "3 fiducials", layout=fit_scale_layout, ) self._widgets["fits_icp"] = self._renderer._dock_add_button( name="Fit ICP with scaling", callback=self._fits_icp, tooltip="Find MRI scaling, rotation, and translation to match the " - "head shape points", + "head shape points", layout=fit_scale_layout, ) - self._renderer._layout_add_widget( - scale_params_layout, fit_scale_layout) + self._renderer._layout_add_widget(scale_params_layout, fit_scale_layout) subject_to_layout = self._renderer._dock_add_layout(vertical=False) self._widgets["subject_to"] = self._renderer._dock_add_text( name="subject-to", @@ -1673,8 +1795,7 @@ def _configure_dock(self): tooltip="Save scaled anatomy", layout=subject_to_layout, ) - self._renderer._layout_add_widget( - mri_scaling_layout, subject_to_layout) + self._renderer._layout_add_widget(mri_scaling_layout, subject_to_layout) param_layout = self._renderer._dock_add_group_box( name="Translation (t) and Rotation (r)", collapse=collapse, @@ -1699,8 +1820,8 @@ def _configure_dock(self): double=True, step=1, tooltip=f"Set the {coord} {mode_name.lower()}" - f" parameter (in {unit})", - layout=coord_layout + f" parameter (in {unit})", + layout=coord_layout, ) self._renderer._layout_add_widget(param_layout, coord_layout) @@ -1714,8 +1835,7 @@ def _configure_dock(self): self._widgets["fit_icp"] = self._renderer._dock_add_button( name="Fit ICP", callback=self._fit_icp, - tooltip="Find 
rotation and translation to match the " - "head shape points", + tooltip="Find rotation and translation to match the " "head shape points", layout=fit_layout, ) self._renderer._layout_add_widget(param_layout, fit_layout) @@ -1731,7 +1851,7 @@ def _configure_dock(self): func=self._save_trans, tooltip="Save the transform file to disk", layout=save_trans_layout, - filter='Head->MRI transformation (*-trans.fif *_trans.fif)', + filter="Head->MRI transformation (*-trans.fif *_trans.fif)", initial_directory=str(Path(self._info_file).parent), ) self._widgets["load_trans"] = self._renderer._dock_add_file_button( @@ -1740,7 +1860,7 @@ def _configure_dock(self): func=self._load_trans, tooltip="Load the transform file from disk", layout=save_trans_layout, - filter='Head->MRI transformation (*-trans.fif *_trans.fif)', + filter="Head->MRI transformation (*-trans.fif *_trans.fif)", initial_directory=str(Path(self._info_file).parent), ) self._renderer._layout_add_widget(trans_layout, save_trans_layout) @@ -1782,15 +1902,14 @@ def _configure_dock(self): name="Weights", layout=fitting_options_layout, ) - for point, fid in zip(("HSP", "EEG", "HPI"), - self._defaults["fiducials"]): + for point, fid in zip(("HSP", "EEG", "HPI"), self._defaults["fiducials"]): weight_layout = self._renderer._dock_add_layout(vertical=False) point_lower = point.lower() name = f"{point_lower}_weight" self._widgets[name] = self._renderer._dock_add_spin_box( name=point, value=getattr(self, f"_{point_lower}_weight"), - rng=[0., 100.], + rng=[0.0, 100.0], callback=partial(self._set_point_weight, point=point_lower), compact=True, double=True, @@ -1803,7 +1922,7 @@ def _configure_dock(self): self._widgets[name] = self._renderer._dock_add_spin_box( name=fid, value=getattr(self, f"_{fid_lower}_weight"), - rng=[0., 100.], + rng=[0.0, 100.0], callback=partial(self._set_point_weight, point=fid_lower), compact=True, double=True, @@ -1811,23 +1930,21 @@ def _configure_dock(self): layout=weight_layout, ) self._renderer._layout_add_widget(weights_layout, weight_layout) - self._widgets['reset_fitting_options'] = ( - self._renderer._dock_add_button( - name="Reset Fitting Options", - callback=self._reset_fitting_parameters, - tooltip="Reset all the fitting parameters to default value", - layout=fitting_options_layout, - ) + self._widgets["reset_fitting_options"] = self._renderer._dock_add_button( + name="Reset Fitting Options", + callback=self._reset_fitting_parameters, + tooltip="Reset all the fitting parameters to default value", + layout=fitting_options_layout, ) self._renderer._dock_add_stretch() def _configure_status_bar(self): self._renderer._status_bar_initialize() - self._widgets['status_message'] = self._renderer._status_bar_add_label( + self._widgets["status_message"] = self._renderer._status_bar_add_label( "", stretch=1 ) self._forward_widget_command( - 'status_message', 'hide', value=None, input_value=False + "status_message", "hide", value=None, input_value=False ) def _clean(self): @@ -1851,17 +1968,16 @@ def close(self): def _close_dialog_callback(self, button_name): from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + self._accept_close_event = True if button_name == "Save": if self._trans_modified: - self._forward_widget_command( - "save_trans", "set_value", None) + self._forward_widget_command("save_trans", "set_value", None) # cancel means _save_trans is not called if self._trans_modified: self._accept_close_event = False if self._mri_fids_modified: - self._forward_widget_command( - "save_mri_fids", "set_value", None) + 
self._forward_widget_command("save_mri_fids", "set_value", None) if self._mri_scale_modified: if self._subject_to: self._save_subject(exit_mode=True) @@ -1869,7 +1985,7 @@ def _close_dialog_callback(self, button_name): dialog = self._renderer._dialog_create( title="CoregistrationUI", text="The name of the output subject used to " - "save the scaled anatomy is not set.", + "save the scaled anatomy is not set.", info_text="Please set a subject name", callback=lambda x: None, buttons=["Ok"], @@ -1883,9 +1999,9 @@ def _close_dialog_callback(self, button_name): assert button_name == "Discard" def _close_callback(self): - if self._trans_modified or self._mri_fids_modified or \ - self._mri_scale_modified: + if self._trans_modified or self._mri_fids_modified or self._mri_scale_modified: from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + # prepare the dialog's text text = "The following is/are not saved:" text += "