Changes from all commits (30 commits)
bcc07f1
example description update.
Shrecki Oct 15, 2024
99b593a
pca_flip allowed for volumetric
Shrecki Oct 16, 2024
00e0487
[FEAT] Volumetric PCA flip implementation
Shrecki Jan 30, 2025
3b3b754
Merge remote-tracking branch 'upstream/main'
Shrecki Jan 30, 2025
d88591f
Merge branch 'main' into pca_flip_volume
Shrecki Jan 30, 2025
2e56dcb
Removed source_space_custom_atlas example - should be object of separ…
Shrecki Jan 31, 2025
a677854
[MISC] Changelog update
Shrecki Jan 31, 2025
8a74ffe
[MISC] Fixed changelog
Shrecki Jan 31, 2025
dd15522
[FIX] Removed erroneous path from test case
Shrecki Jan 31, 2025
6839d73
[autofix.ci] apply automated fixes
autofix-ci[bot] Jan 31, 2025
379614e
[FEAT] Simplify label code and remove cruft code
Shrecki Mar 13, 2025
d3523b7
Merge branch 'pca_flip_volume' of github.com:Shrecki/mne-python into …
Shrecki Mar 13, 2025
33f911d
[FIX] Removed trivial branch
Shrecki Mar 13, 2025
256331b
Merge remote-tracking branch 'upstream/main'
Shrecki Mar 13, 2025
1626e5f
Merge branch 'main' into pca_flip_volume
Shrecki Mar 13, 2025
ec82986
[FIX] label_sign_flip incorrectly handled hemispheres
Shrecki Mar 14, 2025
0ca43cf
Imports moved up top
Shrecki Mar 25, 2025
ee4a174
Updating mri_name to save volumetric source
Shrecki May 15, 2025
3ce6b38
Fix of PCA flip in volume: returned constant 0 as flips meaningless i…
Shrecki May 15, 2025
7ae37e5
Fixed pca flip branch
Shrecki May 15, 2025
d9580da
Handling of flip being an int
Shrecki May 15, 2025
fd71779
Using numpy svd instead of scipy
Shrecki May 15, 2025
69937d2
PCA flip for volumetric is now using randomized SVD to manage to run …
Shrecki May 16, 2025
8888c13
Simplification of PCA flip
Shrecki May 16, 2025
775ec80
Logging
Shrecki May 16, 2025
0fd4be5
Found a trick to make everything much faster with only two svds
Shrecki May 16, 2025
1a89099
Flip handling in _compute_pca_quantitites
Shrecki May 16, 2025
b02c14c
Feat: montage now supports .pos information file
Shrecki May 23, 2025
8a96fec
Float convert in digitization
Shrecki May 23, 2025
1074c1a
Convert dig points to numpy array
Shrecki May 23, 2025
1 change: 1 addition & 0 deletions doc/changes/devel/13092.newfeature.rst
@@ -0,0 +1 @@
Add PCA-flip to pool sources in source reconstruction in :func:`mne.extract_label_time_course`, by :newcontrib:`Fabrice Guibert`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
@@ -86,6 +86,7 @@
.. _Evgeny Goldstein: https://github.com/evgenygoldstein
.. _Ezequiel Mikulan: https://github.com/ezemikulan
.. _Fabrice Guibert: https://github.com/Shrecki
.. _Fahimeh Mamashli: https://github.com/fmamashli
.. _Farzin Negahbani: https://github.com/Farzin-Negahbani
.. _Federico Raimondo: https://github.com/fraimondo
52 changes: 51 additions & 1 deletion mne/channels/montage.py
@@ -1383,6 +1383,54 @@ def _read_isotrak_elp_points(fname):
    }


def _read_isotrak_pos_points(fname):
    """Read Polhemus Isotrak digitizer data from a ``.pos`` file.

    Parameters
    ----------
    fname : path-like
        The filepath of the ``.pos`` Polhemus Isotrak file.

    Returns
    -------
    out : dict of arrays
        The dictionary containing locations for 'nasion', 'lpa', 'rpa'
        and 'points'.
    """
    with open(fname) as fid:
        file_str = fid.read()

    # Get all lines that are points
    int_pat = r"[+-]?\d+"
    float_pat = r"[+-]?(?:\d+\.\d*|\d*\.\d+)(?:[eE][+-]?\d+)?"
    pattern_points = re.compile(
        rf"^\s*({int_pat})\s+({float_pat})\s+({float_pat})\s+({float_pat})",
        re.MULTILINE,
    )
    points = pattern_points.findall(file_str)

    # Get nasion, left and right
    label_pat = r"[A-Za-z]+"
    pattern_labels = re.compile(
        rf"^\s*({label_pat})\s+({float_pat})\s+({float_pat})\s+({float_pat})",
        re.MULTILINE,
    )
    labels = pattern_labels.findall(file_str)

    return {
        "nasion": np.array(
            [tuple(map(float, x[1:])) for x in labels if x[0] == "nasion"][0]
        ),
        "lpa": np.array(
            [tuple(map(float, x[1:])) for x in labels if x[0] == "left"][0]
        ),
        "rpa": np.array(
            [tuple(map(float, x[1:])) for x in labels if x[0] == "right"][0]
        ),
        "points": np.array([tuple(map(float, x[1:])) for x in points]),
    }
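
To make the parsing concrete, here is a small self-contained sketch that runs the same two regular expressions on a hypothetical ``.pos`` snippet (the exact Polhemus layout is an assumption here: integer-indexed head-shape rows plus ``nasion``/``left``/``right`` fiducial rows):

```python
import re

import numpy as np

# Hypothetical .pos content; the real file layout may differ.
pos_text = """\
1   0.1200   0.0500   0.0300
2   0.1180  -0.0490   0.0310
nasion   0.1000   0.0000   0.0000
left    -0.0050   0.0700   0.0000
right   -0.0050  -0.0700   0.0000
"""

int_pat = r"[+-]?\d+"
float_pat = r"[+-]?(?:\d+\.\d*|\d*\.\d+)(?:[eE][+-]?\d+)?"
# Integer-indexed rows -> head-shape points
points = re.findall(
    rf"^\s*({int_pat})\s+({float_pat})\s+({float_pat})\s+({float_pat})",
    pos_text,
    re.MULTILINE,
)
# Alphabetic-label rows -> fiducials
labels = re.findall(
    rf"^\s*([A-Za-z]+)\s+({float_pat})\s+({float_pat})\s+({float_pat})",
    pos_text,
    re.MULTILINE,
)
print(np.array([tuple(map(float, p[1:])) for p in points]))    # shape (2, 3)
print({lab[0]: tuple(map(float, lab[1:])) for lab in labels})  # fiducials
```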


def _read_isotrak_hsp_points(fname):
"""Read Polhemus Isotrak digitizer data from a ``.hsp`` file.

Expand Down Expand Up @@ -1459,7 +1507,7 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"):
read_dig_fif
read_dig_localite
"""
VALID_FILE_EXT = (".hsp", ".elp", ".eeg")
VALID_FILE_EXT = (".hsp", ".elp", ".eeg", ".pos")
fname = str(_check_fname(fname, overwrite="read", must_exist=True))
_scale = _check_unit_and_get_scaling(unit)

Expand All @@ -1468,6 +1516,8 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"):

if ext == ".elp":
data = _read_isotrak_elp_points(fname)
elif ext == ".pos":
data = _read_isotrak_pos_points(fname)
else:
# Default case we read points as hsp since is the most likely scenario
data = _read_isotrak_hsp_points(fname)
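
With the dispatch above in place, the new reader is reachable from the public API. A hedged end-to-end sketch (it first writes the same hypothetical ``.pos`` snippet as above to disk; ``digitization.pos`` is an illustrative name):

```python
from pathlib import Path

import mne

fname = Path("digitization.pos")
fname.write_text(
    "1   0.1200   0.0500   0.0300\n"
    "2   0.1180  -0.0490   0.0310\n"
    "nasion   0.1000   0.0000   0.0000\n"
    "left    -0.0050   0.0700   0.0000\n"
    "right   -0.0050  -0.0700   0.0000\n"
)
# With this patch applied, the .pos extension routes to
# _read_isotrak_pos_points and yields a DigMontage with the three
# fiducials plus the head-shape points.
montage = mne.channels.read_dig_polhemus_isotrak(fname, unit="m")
print(montage)
```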
46 changes: 34 additions & 12 deletions mne/label.py
@@ -1460,22 +1460,44 @@ def label_sign_flip(label, src):
    flip : array
        Sign flip vector (contains 1 or -1).
    """
    if len(src) != 2:
        raise ValueError("Only source spaces with 2 hemisphers are accepted")
    if len(src) > 2 or len(src) == 0:

Review comment (Member):

A better / more modern check would be something like:

_validate_type(src, SourceSpaces, "src")
_check_option("source space kind", src.kind, ("volume", "surface"))
if src.kind == "volume" and len(src) != 1:
    raise ValueError(
        "Only single-segment volumes are supported, got labelized volume "
        "source space"
    )

And incidentally I think eventually we could add support for segmented volume source spaces, as well as mixed source spaces (once surface + volume are fully supported, mixed isn't too bad after that). But probably not as part of this PR!

Reply (Author):

Probably! To support this, the (clarified) code in label_sign_flip with the hemis dictionary might be recycled:

hemis = {}

# Build hemisphere info dictionary
if label.hemi == "both":
    hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]}
    hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]}
elif label.hemi in ("lh", "rh"):
    hemis[label.hemi] = {"id": 0, "vertno": src[0]["vertno"]}
else:
    raise Exception(f'Unknown hemisphere type "{label.hemi}"')

        raise ValueError(
            "Only source spaces with between one and two "
            f"hemispheres are accepted, got {len(src)}"
        )

    if len(src) == 1 and label.hemi == "both":
        raise ValueError(
            'Cannot use hemisphere label "both" when the '
            "source space contains a single hemisphere."
        )

    lh_vertno = src[0]["vertno"]
    rh_vertno = src[1]["vertno"]
    hemis = {}

    # Build hemisphere info dictionary
    if label.hemi == "both":
        hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]}
        hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]}
    elif label.hemi in ("lh", "rh"):
        # If two sources available, the hemisphere's ID must be looked up.
        # If only a single source, the ID is zero.
        index_ = ("lh", "rh").index(label.hemi) if len(src) == 2 else 0
        hemis[label.hemi] = {"id": index_, "vertno": src[index_]["vertno"]}
    else:
        raise Exception(f'Unknown hemisphere type "{label.hemi}"')

    # get source orientations
    ori = list()
    if label.hemi in ("lh", "both"):
        vertices = label.vertices if label.hemi == "lh" else label.lh.vertices
        vertno_sel = np.intersect1d(lh_vertno, vertices)
        ori.append(src[0]["nn"][vertno_sel])
    if label.hemi in ("rh", "both"):
        vertices = label.vertices if label.hemi == "rh" else label.rh.vertices
        vertno_sel = np.intersect1d(rh_vertno, vertices)
        ori.append(src[1]["nn"][vertno_sel])
    for hemi, hemi_infos in hemis.items():
        # When the label is lh or rh, get vertices directly
        if label.hemi == hemi:
            vertices = label.vertices
        # In the case where label is "both", get label.hemi.vertices
        # (so either label.lh.vertices or label.rh.vertices)
        else:
            vertices = getattr(label, hemi).vertices
        vertno_sel = np.intersect1d(hemi_infos["vertno"], vertices)
        ori.append(src[hemi_infos["id"]]["nn"][vertno_sel])
    if len(ori) == 0:
        raise Exception(f'Unknown hemisphere type "{label.hemi}"')
    ori = np.concatenate(ori, axis=0)
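
Downstream of this loop (outside the hunk shown), the concatenated ``ori`` matrix is what the sign flips are derived from: the dominant orientation is the first right-singular vector of the normals, and each source is flipped to agree with it. A minimal numpy schematic of that idea, on synthetic normals:

```python
import numpy as np

rng = np.random.default_rng(0)
# 10 noisy normals roughly along +z, half stored with inverted sign,
# as can happen across a folded cortical label.
ori = np.array([0.0, 0.0, 1.0]) + 0.2 * rng.standard_normal((10, 3))
ori[::2] *= -1

_, _, Vh = np.linalg.svd(ori, full_matrices=False)
flip = np.sign(ori @ Vh[0])          # +/-1 per source
aligned = flip[:, np.newaxis] * ori  # flipped normals
print((aligned @ Vh[0] > 0).all())   # True: all agree in sign
```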
87 changes: 69 additions & 18 deletions mne/source_estimate.py
@@ -19,7 +19,7 @@
from .cov import Covariance
from .evoked import _get_peak
from .filter import FilterMixin, _check_fun, resample
from .fixes import _eye_array, _safe_svd
from .fixes import _eye_array
from .parallel import parallel_func
from .source_space._source_space import (
    SourceSpaces,
@@ -3375,13 +3375,29 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):
    return ico


def _pca_flip(flip, data):
    U, s, V = _safe_svd(data, full_matrices=False)
    # determine sign-flip
    sign = np.sign(np.dot(U[:, 0], flip))
    # use average power in label for scaling
    scale = np.linalg.norm(s) / np.sqrt(len(data))
    return sign * scale * V[0]
def _compute_pca_quantities(U, s, V, flip):
    if flip is None:  # Case of volumetric data: flip is meaningless
        flip = 1
    if isinstance(flip, int):
        sign = np.sign((flip * U[:, 0]).sum())
    else:
        sign = np.sign(np.dot(U[:, 0], flip))
    scale = np.linalg.norm(s) / np.sqrt(len(U))
    result = sign * scale * V[0]
    return result


def _pca_flip(flip, data, max_rank):
    result = None
    if data.shape[0] < 2:
        result = data.mean(axis=0)  # Trivial accumulator
    else:
        from sklearn.utils.extmath import randomized_svd

        U, s, V = randomized_svd(data, n_components=max_rank)
        # determine sign-flip.
        result = _compute_pca_quantities(U, s, V, flip)
    return result
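
A small synthetic sketch of the two helpers above, exercising both the surface path (``flip`` is a vector of signs) and the volumetric fallback (``flip is None``); the shapes and rank are arbitrary:

```python
import numpy as np
from sklearn.utils.extmath import randomized_svd

rng = np.random.default_rng(0)
data = rng.standard_normal((20, 100))  # n_sources x n_times
flip = np.ones(20)                     # surface case: one sign per source

# Mirrors _pca_flip -> _compute_pca_quantities on the surface path.
U, s, V = randomized_svd(data, n_components=10, random_state=0)
sign = np.sign(np.dot(U[:, 0], flip))
scale = np.linalg.norm(s) / np.sqrt(len(U))
tc = sign * scale * V[0]               # extracted time course, shape (100,)

# Volumetric path: flip is None and is replaced by the integer 1, so
# only the sign of U[:, 0].sum() matters.
sign_vol = np.sign((1 * U[:, 0]).sum())
print(tc.shape, sign, sign_vol)
```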


_label_funcs = {
@@ -3433,6 +3449,8 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
    # only computes vertex indices and label_flip will be list of None.
    from .label import BiHemiLabel, Label, label_sign_flip

    logger.debug(f"Selected mode: {mode}")

    # if source estimate provided in stc, get vertices from source space and
    # check that they are the same as in the stcs
    _check_stc_src(stc, src)
@@ -3644,8 +3662,10 @@ def _get_default_label_modes():


def _get_allowed_label_modes(stc):
    if isinstance(stc, _BaseVolSourceEstimate | _BaseVectorSourceEstimate):
    if isinstance(stc, _BaseVectorSourceEstimate):
        return ("mean", "max", "auto")
    elif isinstance(stc, _BaseVolSourceEstimate):
        return ("mean", "pca_flip", "max", "auto")
    else:
        return _get_default_label_modes()
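
For reference, a hedged sketch of what the newly allowed volumetric ``pca_flip`` mode enables through the public API, assuming the MNE sample dataset (downloaded on first use) and its shipped whole-brain volume inverse; the label names are standard FreeSurfer ``aseg`` labels:

```python
import mne
from mne.datasets import sample
from mne.minimum_norm import apply_inverse, read_inverse_operator

data_path = sample.data_path()
meg_path = data_path / "MEG" / "sample"
fname_aseg = data_path / "subjects" / "sample" / "mri" / "aseg.mgz"

evoked = mne.read_evokeds(
    meg_path / "sample_audvis-ave.fif", condition=0, baseline=(None, 0)
)
inv = read_inverse_operator(meg_path / "sample_audvis-meg-vol-7-meg-inv.fif")
stc = apply_inverse(evoked, inv)  # a VolSourceEstimate

# With this patch applied, mode="pca_flip" is accepted for volume STCs.
label_tc = mne.extract_label_time_course(
    stc,
    (fname_aseg, ["Left-Amygdala", "Right-Amygdala"]),
    src=inv["src"],
    mode="pca_flip",
)
print(label_tc.shape)  # (n_labels, n_times)
```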

@@ -3659,6 +3679,7 @@ def _gen_extract_label_time_course(
    allow_empty=False,
    mri_resolution=True,
    verbose=None,
    max_channels=400,
):
    # loop through source estimates and extract time series
    if src is None and mode in ["mean", "max"]:
        else:
            # For other modes, initialize the label_tc array
            label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)

            pca_volume = mode == "pca_flip" and kind == "volume"
            if pca_volume:
                from sklearn.utils.extmath import randomized_svd

                logger.debug("First SVD for PCA volume on stc data")
                u_b, s_b, vh_b = randomized_svd(stc.data, max_channels)
            for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
                if vertidx is not None:
                    if isinstance(vertidx, sparse.csr_array):
                        assert mri_resolution
                        assert vertidx.shape[1] == stc.data.shape[0]
                        this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
                        this_data = vertidx @ this_data
                        this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
                    if pca_volume:
                        # Use a trick for efficiency:
                        #   stc = Ub @ Sb @ Vhb
                        #   full_data = vertidx @ stc
                        #             = vertidx @ Ub @ Sb @ Vhb
                        # Consider U_f, s_f, Vh_f = SVD(vertidx @ Ub @ Sb).
                        # Then U, S, Vh = SVD(full_data) is such that
                        # U = U_f, S = s_f and Vh = Vh_f @ Vhb.
                        # This trick is more efficient, because:
                        # - We compute a first SVD once on stc, restricted
                        #   to only the first max_channels singular
                        #   vals/vecs (quite fast).
                        # - We project vertidx from Nvertex x Nsources
                        #   down to Nvertex x rank.
                        # - We compute the SVD on Nvertex x rank.
                        # As rank << Nsources, we end up saving a lot of
                        # computations.
                        tmp_array = vertidx @ u_b @ np.diag(s_b)
                        U, S, v_tmp = np.linalg.svd(tmp_array, full_matrices=False)
                        V = v_tmp @ vh_b
                        label_tc[i] = _compute_pca_quantities(U, S, V, flip)
                    else:
                        this_data = stc.data[vertidx]
                        label_tc[i] = func(flip, this_data)

                    if isinstance(vertidx, sparse.csr_array):
                        assert mri_resolution
                        assert vertidx.shape[1] == stc.data.shape[0]
                        this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
                        this_data = vertidx @ this_data
                        this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
                    else:
                        this_data = stc.data[vertidx]
                    if mode == "pca_flip":
                        label_tc[i] = func(flip, this_data, max_channels)
                    else:
                        label_tc[i] = func(flip, this_data)
                logger.debug(f"Done with label {i}")
        if mode is not None:
            offset = nvert[:-n_mean].sum()  # effectively :2 or :0
            for i, nv in enumerate(nvert[2:]):
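
The algebra in the comment above can be verified numerically. A self-contained sketch with exactly low-rank synthetic data (all names here are illustrative stand-ins for ``stc.data`` and ``vertidx``):

```python
import numpy as np
from sklearn.utils.extmath import randomized_svd

rng = np.random.default_rng(0)
k = 5  # true rank of the synthetic data
stc_data = rng.standard_normal((300, k)) @ rng.standard_normal((k, 200))
vertidx = rng.standard_normal((40, 300))  # stand-in projection matrix

# Direct route: one SVD of the projected data per label.
U_d, S_d, Vh_d = np.linalg.svd(vertidx @ stc_data, full_matrices=False)

# Two-SVD route from the patch: factor stc_data once, then per label
# decompose only the small (40 x k) matrix.
u_b, s_b, vh_b = randomized_svd(stc_data, n_components=k, random_state=0)
tmp = vertidx @ u_b @ np.diag(s_b)
U, S, v_tmp = np.linalg.svd(tmp, full_matrices=False)
Vh = v_tmp @ vh_b

print(np.allclose(S, S_d[:k]))                               # same spectrum
print(np.allclose(U @ np.diag(S) @ Vh, vertidx @ stc_data))  # same product
```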
1 change: 1 addition & 0 deletions mne/source_space/_source_space.py
@@ -642,6 +642,7 @@ def export_volume(

        # Get shape, inuse array and interpolation matrix from volume sources
        src = src_types["volume"][0]
        src["mri_file"] = src["mri_volume_name"]
        aseg_data = None
        if mri_resolution:
            # read the mri file used to generate volumes