def _read_isotrak_pos_points(fname):
    """Read Polhemus Isotrak digitizer data from a ``.pos`` file.

    Parameters
    ----------
    fname : path-like
        The filepath of .pos Polhemus Isotrak file.

    Returns
    -------
    out : dict of arrays
        The dictionary containing locations for 'nasion', 'lpa', 'rpa'
        and 'points'. Each fiducial is an array of shape (3,) and 'points'
        has shape (n_points, 3), also when no point line was found.

    Raises
    ------
    ValueError
        If the file has no ``nasion``, ``left`` or ``right`` line.
    """
    with open(fname) as fid:
        file_str = fid.read()

    int_pat = r"[+-]?\d+"
    # A coordinate must contain a decimal point, so integer-only tokens are
    # never taken for coordinates.
    float_pat = r"[+-]?(?:\d+\.\d*|\d*\.\d+)(?:[eE][+-]?\d+)?"
    coords_pat = rf"\s+({float_pat})\s+({float_pat})\s+({float_pat})"

    # Get all lines which are points: an integer index followed by x, y, z
    pattern_points = re.compile(rf"^\s*({int_pat}){coords_pat}", re.MULTILINE)
    points = [
        [float(value) for value in row[1:]]
        for row in pattern_points.findall(file_str)
    ]

    # Get nasion, left and right: a word followed by x, y, z
    label_pat = r"[A-Za-z]+"
    pattern_labels = re.compile(rf"^\s*({label_pat}){coords_pat}", re.MULTILINE)
    fiducials = {}
    for name, *xyz in pattern_labels.findall(file_str):
        # keep the first occurrence of each label
        fiducials.setdefault(name, np.array([float(value) for value in xyz]))

    # output key -> label used in the file
    wanted = {"nasion": "nasion", "lpa": "left", "rpa": "right"}
    missing = [label for label in wanted.values() if label not in fiducials]
    if missing:
        raise ValueError(
            f"Could not find fiducial line(s) {missing} in .pos file {fname}"
        )

    out = {key: fiducials[label] for key, label in wanted.items()}
    # reshape so that an empty file still yields a (0, 3) array
    out["points"] = np.array(points, dtype=float).reshape(-1, 3)
    return out
""" - if len(src) != 2: - raise ValueError("Only source spaces with 2 hemisphers are accepted") + if len(src) > 2 or len(src) == 0: + raise ValueError( + "Only source spaces with between one and two " + + "hemispheres are accepted, was {len(src)}" + ) + + if len(src) == 1 and label.hemi == "both": + raise ValueError( + 'Cannot use hemisphere label "both" when source' + + "space contains a single hemisphere." + ) - lh_vertno = src[0]["vertno"] - rh_vertno = src[1]["vertno"] + hemis = {} + + # Build hemisphere info dictionary + if label.hemi == "both": + hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]} + hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]} + elif label.hemi in ("lh", "rh"): + # If two sources available, the hemisphere's ID must be looked up. + # If only a single source, the ID is zero. + index_ = ("lh", "rh").index(label.hemi) if len(src) == 2 else 0 + hemis[label.hemi] = {"id": index_, "vertno": src[index_]["vertno"]} + else: + raise Exception(f'Unknown hemisphere type "{label.hemi}"') # get source orientations ori = list() - if label.hemi in ("lh", "both"): - vertices = label.vertices if label.hemi == "lh" else label.lh.vertices - vertno_sel = np.intersect1d(lh_vertno, vertices) - ori.append(src[0]["nn"][vertno_sel]) - if label.hemi in ("rh", "both"): - vertices = label.vertices if label.hemi == "rh" else label.rh.vertices - vertno_sel = np.intersect1d(rh_vertno, vertices) - ori.append(src[1]["nn"][vertno_sel]) + for hemi, hemi_infos in hemis.items(): + # When the label is lh or rh, get vertices directly + if label.hemi == hemi: + vertices = label.vertices + # In the case where label is "both", get label.hemi.vertices + # (so either label.lh.vertices or label.rh.vertices) + else: + vertices = getattr(label, hemi).vertices + vertno_sel = np.intersect1d(hemi_infos["vertno"], vertices) + ori.append(src[hemi_infos["id"]]["nn"][vertno_sel]) if len(ori) == 0: raise Exception(f'Unknown hemisphere type "{label.hemi}"') ori = np.concatenate(ori, axis=0) 
def _compute_pca_quantities(U, s, V, flip):
    """Build the sign-aligned, power-scaled first component from an SVD.

    Parameters
    ----------
    U, s, V : array
        Left singular vectors, singular values and right singular vectors,
        as returned by a thin SVD of the label data.
    flip : None | scalar | array
        Sign-flip information. ``None`` (used for volumetric data, where a
        flip is meaningless) is treated as ``1``. A scalar is multiplied
        with the first left singular vector before summing; an array is
        combined with it through a dot product.

    Returns
    -------
    result : array
        ``sign * scale * V[0]``, with ``scale`` equal to the norm of ``s``
        divided by ``sqrt(len(U))`` (average power in the label).
    """
    flip = np.asarray(1 if flip is None else flip)
    if flip.ndim == 0:
        # Any scalar (Python or NumPy, int or float): np.dot with a scalar
        # would return a vector instead of a single sign.
        sign = np.sign((flip * U[:, 0]).sum())
    else:
        sign = np.sign(np.dot(U[:, 0], flip))
    # use average power in label for scaling
    scale = np.linalg.norm(s) / np.sqrt(len(U))
    return sign * scale * V[0]


def _pca_flip(flip, data, max_rank=None):
    """Summarize label data by its sign-aligned first PCA component.

    Parameters
    ----------
    flip : None | scalar | array
        Sign-flip information, see ``_compute_pca_quantities``.
    data : array, shape (n_sources, n_times)
        The data of the sources in the label.
    max_rank : int | None
        If not None, only the leading ``max_rank`` singular triplets are
        kept (this also restricts the singular values entering the scale).
        The default (None) keeps all of them, so two-argument calls
        ``_pca_flip(flip, data)`` remain valid.

    Returns
    -------
    result : array, shape (n_times,)
        The pooled time course. Zeros for zero sources; the plain mean
        (i.e., the row itself) for a single source.

    Notes
    -----
    An exact, deterministic thin SVD is used instead of a randomized one,
    so the result is reproducible and no scikit-learn import is needed.
    """
    n_sources = data.shape[0]
    if n_sources == 0:
        # nothing to pool: avoid the NaN (and warning) of an empty mean
        return np.zeros(data.shape[1:], dtype=data.dtype)
    if n_sources < 2:
        return data.mean(axis=0)  # Trivial accumulator

    from scipy import linalg

    try:
        U, s, V = linalg.svd(data, full_matrices=False)
    except np.linalg.LinAlgError:
        # the default LAPACK driver can fail to converge; retry with gesvd
        U, s, V = linalg.svd(data, full_matrices=False, lapack_driver="gesvd")
    if max_rank is not None:
        U, s, V = U[:, :max_rank], s[:max_rank], V[:max_rank]
    # determine sign-flip and scaling
    return _compute_pca_quantities(U, s, V, flip)
def _get_allowed_label_modes(stc):
    """Return the tuple of label-extraction modes permitted for ``stc``.

    Vector source estimates are limited to ``"mean"``, ``"max"`` and
    ``"auto"``; volume source estimates additionally allow ``"pca_flip"``;
    anything else gets the default modes.
    """
    # The vector test comes first, so an object that passes both isinstance
    # checks is handled as a vector estimate.
    if isinstance(stc, _BaseVectorSourceEstimate):
        modes = ("mean", "max", "auto")
    elif isinstance(stc, _BaseVolSourceEstimate):
        modes = ("mean", "pca_flip", "max", "auto")
    else:
        modes = _get_default_label_modes()
    return modes
restricted to + # only first max_channels singular vals/vecs (quite fast) + # - We project vertidx to be from Nvertex x Nsources + # to Nvertex x rank. + # - We compute SVD on Nvertex x rank + # As rank << Nsources, we end up saving a lot of computations. + tmp_array = vertidx @ u_b @ np.diag(s_b) + U, S, v_tmp = np.linalg.svd(tmp_array, full_matrices=False) + V = v_tmp @ vh_b + label_tc[i] = _compute_pca_quantities(U, S, V, flip) else: - this_data = stc.data[vertidx] - label_tc[i] = func(flip, this_data) - + if isinstance(vertidx, sparse.csr_array): + assert mri_resolution + assert vertidx.shape[1] == stc.data.shape[0] + this_data = np.reshape(stc.data, (stc.data.shape[0], -1)) + this_data = vertidx @ this_data + this_data.shape = (this_data.shape[0],) + stc.data.shape[1:] + else: + this_data = stc.data[vertidx] + if mode == "pca_flip": + label_tc[i] = func(flip, this_data, max_channels) + else: + label_tc[i] = func(flip, this_data) + logger.debug(f"Done with label {i}") if mode is not None: offset = nvert[:-n_mean].sum() # effectively :2 or :0 for i, nv in enumerate(nvert[2:]): diff --git a/mne/source_space/_source_space.py b/mne/source_space/_source_space.py index d64989961cf..90741d62bb6 100644 --- a/mne/source_space/_source_space.py +++ b/mne/source_space/_source_space.py @@ -642,6 +642,7 @@ def export_volume( # Get shape, inuse array and interpolation matrix from volume sources src = src_types["volume"][0] + src["mri_file"] = src["mri_volume_name"] aseg_data = None if mri_resolution: # read the mri file used to generate volumes diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index e4fa5a36b25..b8af56be21c 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -73,7 +73,14 @@ read_inverse_operator, ) from mne.morph_map import _make_morph_map_hemi -from mne.source_estimate import _get_vol_mask, _make_stc, grade_to_tris +from mne.source_estimate import ( + _get_vol_mask, + _make_stc, + 
@testing.requires_testing_data
@pytest.mark.parametrize(
    "label_type, mri_res, test_label, cf, call",
    [
        (str, False, False, "head", "meth"),  # head frame
        (str, False, str, "mri", "func"),  # fastest, default for testing
        (str, True, str, "mri", "func"),  # fastest, default for testing
        (str, True, False, "mri", "func"),  # mri_resolution
        (list, True, False, "mri", "func"),  # volume label as list
        (dict, True, False, "mri", "func"),  # volume label as dict
    ],
)
def test_extract_label_time_course_volume_pca_flip(
    src_volume_labels, label_type, mri_res, test_label, cf, call
):
    """Test extraction of label timecourses on VolumetricSourceEstimate with PCA.

    The ``"pca_flip"`` output of ``extract_label_time_course`` (function or
    method, depending on ``call``) is checked for its log messages, its
    shape and, where a sign-flip vector is available, its values against a
    direct ``_pca_flip`` call on the label's rows.
    """
    # Setup of data
    src_labels, volume_labels, lut = src_volume_labels
    n_tot = 46
    assert n_tot == len(src_labels)
    inv = read_inverse_operator(fname_inv_vol)
    if cf == "head":
        src = inv["src"]
    else:
        src = read_source_spaces(fname_src_vol)
    klass = VolVectorSourceEstimate._scalar_class
    vertices = [src[0]["vertno"]]
    n_verts = len(src[0]["vertno"])
    n_times = 50
    end_shape = (n_times,)
    # row i is constant over time with value i + 1
    data = np.arange(1, n_verts + 1)
    data = np.repeat(data[..., np.newaxis], n_times, -1)
    stcs = [klass(data.astype(float), vertices, 0, 1)]
    # same truncation rank as the ``max_channels=400`` default of the
    # extraction code; passed explicitly to ``_pca_flip``
    max_rank = 400

    def eltc(*args, **kwargs):
        if call == "func":
            return extract_label_time_course(stcs, *args, **kwargs)
        else:
            return [stcs[0].extract_label_time_course(*args, **kwargs)]

    # triage "labels" argument
    if mri_res:
        # All should be there
        missing = []
    else:
        # Nearest misses these
        missing = [
            "Left-vessel",
            "Right-vessel",
            "5th-Ventricle",
            "non-WM-hypointensities",
        ]
    n_want = len(src_labels)
    if label_type is str:
        labels = fname_aseg
    elif label_type is list:
        labels = (fname_aseg, volume_labels)
    else:
        assert label_type is dict
        labels = (fname_aseg, {k: lut[k] for k in volume_labels})
        assert mri_res
        assert len(missing) == 0
        # we're going to add one that won't exist
        missing = ["intentionally_bad"]
        labels[1][missing[0]] = 10000
        n_want += 1
        n_tot += 1
    n_want -= len(missing)

    # actually do the testing
    labels_expanded = _volume_labels(src, labels, mri_res)
    _, src_flip = _prepare_label_extraction(
        stcs[0], labels_expanded, src, "pca_flip", "ignore", bool(mri_res)
    )

    mode = "pca_flip"
    with catch_logging() as log:
        label_tc = eltc(
            labels,
            src,
            mode=mode,
            allow_empty="ignore",
            mri_resolution=mri_res,
            verbose=True,
        )
    log = log.getvalue()
    assert re.search("^Reading atlas.*aseg\\.mgz\n", log) is not None
    if len(missing):
        # assert that the missing ones get logged
        assert "does not contain" in log
        assert repr(missing) in log
    else:
        assert "does not contain" not in log
    assert f"\n{n_want}/{n_tot} atlas regions had at least" in log
    assert len(label_tc) == 1
    label_tc = label_tc[0]
    assert label_tc.shape == (n_tot,) + end_shape
    # let's test some actual values by trusting the masks provided by
    # setup_volume_source_space. mri_resolution=True does some
    # interpolation so we should not expect equivalence, False does
    # nearest so we should.
    if mri_res:
        rtol = 0.8  # loose: interpolation changes the values
    else:
        rtol = 0.0
    for si, s in enumerate(src_labels):
        these = data[np.isin(src[0]["vertno"], s["vertno"])]
        assert len(these) == s["nuse"]
        if si == 0 and s["seg_name"] == "Unknown":
            continue  # unknown is crappy
        flip = None
        if s["nuse"] == 0:
            if mri_res:
                # this one is totally due to interpolation, so no easy
                # test here
                continue
            want = 0.0
        else:
            flip = src_flip[si]
            # NOTE(review): when no flip vector is returned for a label,
            # its values are not checked at all -- confirm this is intended
            want = None if flip is None else _pca_flip(flip, these, max_rank)
        if want is not None:
            assert_allclose(label_tc[si], want, atol=1e-6, rtol=rtol)
        # compare with in_label, only on every fourth for speed
        if test_label is False or si % 4 != 0:
            continue
        in_label = stcs[0].in_label(s["seg_name"], fname_aseg, src).data
        assert in_label.shape == (s["nuse"],) + end_shape
        if want is None:
            continue  # nothing to compare the values against
        if np.all(want == 0):
            assert in_label.shape[0] == 0
        else:
            in_label = _pca_flip(flip, in_label, max_rank)
            assert_allclose(in_label, want, atol=1e-6, rtol=rtol)