Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api/events.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Events

Annotations
AcqParserFIF
HEDAnnotations
concatenate_events
count_events
find_events
Expand Down
2 changes: 2 additions & 0 deletions mne/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ __all__ = [
"Evoked",
"EvokedArray",
"Forward",
"HEDAnnotations",
"Info",
"Label",
"MixedSourceEstimate",
Expand Down Expand Up @@ -260,6 +261,7 @@ from ._freesurfer import (
)
from .annotations import (
Annotations,
HEDAnnotations,
annotations_from_events,
count_annotations,
events_from_annotations,
Expand Down
249 changes: 247 additions & 2 deletions mne/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
write_name_list_sanitized,
write_string,
)
from .fixes import _compare_version
from .utils import (
_check_dict_keys,
_check_dt,
Expand All @@ -52,6 +53,7 @@
verbose,
warn,
)
from .utils.check import _soft_import

# For testing windows_like_datetime, we monkeypatch "datetime" in this module.
# Keep the true datetime object around for _validate_type use.
Expand Down Expand Up @@ -151,6 +153,7 @@ class Annotations:
--------
mne.annotations_from_events
mne.events_from_annotations
mne.HEDAnnotations

Notes
-----
Expand Down Expand Up @@ -288,7 +291,7 @@ def orig_time(self):

def __eq__(self, other):
"""Compare to another Annotations instance."""
if not isinstance(other, Annotations):
if not isinstance(other, type(self)):
return False
return (
np.array_equal(self.onset, other.onset)
Expand Down Expand Up @@ -567,6 +570,12 @@ def _sort(self):
self.duration = self.duration[order]
self.description = self.description[order]
self.ch_names = self.ch_names[order]
if hasattr(self, "hed_string"):
self.hed_string._objs = [self.hed_string._objs[i] for i in order]
for i in order:
self.hed_string.__setitem__(
i, self.hed_string._objs[i].get_original_hed_string()
)

@verbose
def crop(
Expand Down Expand Up @@ -758,6 +767,241 @@ def rename(self, mapping, verbose=None):
return self


class _HEDStrings(list):
"""Subclass of list that will validate before __setitem__."""

def __init__(self, *args, hed_version, **kwargs):
self._hed = _soft_import("hed", "validation of HED tags in annotations")
self._schema = self._hed.load_schema_version(hed_version)
super().__init__(*args, **kwargs)
self._objs = [self._validate_hed_string(item, self._schema) for item in self]

def __setitem__(self, key, value):
"""Validate value first, before assigning."""
hs = self._validate_hed_string(value, self._schema)
super().__setitem__(key, hs.get_original_hed_string())
self._objs[key] = hs

def _validate_hed_string(self, value, schema):
# create HedString object and validate it
hs = self._hed.HedString(value, schema)
# handle any errors
error_handler = self._hed.errors.ErrorHandler(check_for_warnings=False)
issues = hs.validate(allow_placeholders=False, error_handler=error_handler)
error_string = self._hed.get_printable_issue_string(issues)
if len(error_string):
raise ValueError(f"A HED string failed to validate:\n {error_string}")
hs.sort()
return hs

def append(self, item):
"""Append an item to the end of the HEDString list."""
hs = self._validate_hed_string(item, self._schema)
super().append(hs.get_original_hed_string())
self._objs.append(hs)


@fill_doc
class HEDAnnotations(Annotations):
"""Annotations object for annotating segments of raw data with HED tags.

Parameters
----------
onset : array of float, shape (n_annotations,)
The starting time of annotations in seconds after ``orig_time``.
duration : array of float, shape (n_annotations,) | float
Durations of the annotations in seconds. If a float, all the
annotations are given the same duration.
description : array of str, shape (n_annotations,) | str
Array of strings containing description for each annotation. If a
string, all the annotations are given the same description. To reject
epochs, use description starting with keyword 'bad'. See example above.
hed_string : array of str, shape (n_annotations,) | str
Sequence of strings containing a HED tag (or comma-separated list of HED tags)
for each annotation. If a single string is provided, all annotations are
assigned the same HED string.
hed_version : str
The HED schema version against which to validate the HED strings.
orig_time : float | str | datetime | tuple of int | None
A POSIX Timestamp, datetime or a tuple containing the timestamp as the
first element and microseconds as the second element. Determines the
starting time of annotation acquisition. If None (default),
starting time is determined from beginning of raw data acquisition.
In general, ``raw.info['meas_date']`` (or None) can be used for syncing
the annotations with raw data if their acquisition is started at the
same time. If it is a string, it should conform to the ISO8601 format.
More precisely to this '%%Y-%%m-%%d %%H:%%M:%%S.%%f' particular case of
the ISO8601 format where the delimiter between date and time is ' '.
%(ch_names_annot)s

See Also
--------
mne.Annotations

Notes
-----

.. versionadded:: 1.10
"""

def __init__(
self,
onset,
duration,
description,
hed_string,
hed_version="8.3.0",
orig_time=None,
ch_names=None,
):
self._hed_version = hed_version
self.hed_string = _HEDStrings(hed_string, hed_version=self._hed_version)
super().__init__(
onset=onset,
duration=duration,
description=description,
orig_time=orig_time,
ch_names=ch_names,
)

def __eq__(self, other):
"""Compare to another HEDAnnotations instance."""
_slf = self.hed_string
_oth = other.hed_string

if _compare_version(self._hed_version, "<", other._hed_version):
_slf = [_slf._validate_hed_string(v, _oth._schema) for v in _slf._objs]
elif _compare_version(self._hed_version, ">", other._hed_version):
_oth = [_oth._validate_hed_string(v, _slf._schema) for v in _oth._objs]
return super().__eq__(other) and _slf == _oth

def __repr__(self):
"""Show a textual summary of the object."""
counter = Counter([hs.get_as_short() for hs in self.hed_string._objs])

# textwrap.shorten won't work: we remove all spaces and shouldn't split on `-`
def _shorten(text, width=74, placeholder=" ..."):
parts = text.split(",")
out = parts[0]
for part in parts[1:]:
# +1 for the comma ↓↓↓
if width < len(out) + 1 + len(part) + len(placeholder):
break
out = f"{out},{part}"
return out + placeholder

kinds = [
f"{_shorten(k, width=74):<74} ({v})" for k, v in sorted(counter.items())
]
if len(kinds) > 5:
kinds = [*kinds[:5], f"... and {len(kinds) - 5} more"]
kinds = "\n ".join(kinds)
if len(kinds):
kinds = f":\n {kinds}\n"
ch_specific = ", channel-specific" if self._any_ch_names() else ""
s = (
f"HEDAnnotations | {len(self.onset)} segment"
f"{_pl(len(self.onset))}{ch_specific}{kinds}"
)
return f"<{s}>"

def __getitem__(self, key, *, with_ch_names=None):
"""Propagate indexing and slicing to the underlying structure."""
result = super().__getitem__(key, with_ch_names=with_ch_names)
if isinstance(result, OrderedDict):
result["hed_string"] = self.hed_string[key]
return result
else:
key = list(key) if isinstance(key, tuple) else key
hed_string = [self.hed_string[key]]
return HEDAnnotations(
result.onset,
result.duration,
result.description,
hed_string=hed_string,
hed_version=self._hed_version,
orig_time=self.orig_time,
ch_names=result.ch_names,
)

def __getstate__(self):
"""Make serialization work, by removing module reference."""
return dict(
_orig_time=self._orig_time,
onset=self.onset,
duration=self.duration,
description=self.description,
ch_names=self.ch_names,
hed_string=list(self.hed_string),
_hed_version=self._hed_version,
)

def __setstate__(self, state):
"""Unpack from serialized format."""
self._orig_time = state["_orig_time"]
self.onset = state["onset"]
self.duration = state["duration"]
self.description = state["description"]
self.ch_names = state["ch_names"]
self._hed_version = state["_hed_version"]
self.hed_string = _HEDStrings(
state["hed_string"], hed_version=self._hed_version
)

@fill_doc
def append(self, *, onset, duration, description, hed_string, ch_names=None):
"""Add an annotated segment. Operates inplace.

Parameters
----------
onset : float | array-like
Annotation time onset from the beginning of the recording in
seconds.
duration : float | array-like
Duration of the annotation in seconds.
description : str | array-like
Description for the annotation. To reject epochs, use description
starting with keyword 'bad'.
hed_string : array of str, shape (n_annotations,) | str
Sequence of strings containing a HED tag (or comma-separated list of HED
tags) for each annotation. If a single string is provided, all annotations
are assigned the same HED string.
%(ch_names_annot)s

Returns
-------
self : mne.HEDAnnotations
The modified HEDAnnotations object.
"""
self.hed_string.append(hed_string)
super().append(
onset=onset, duration=duration, description=description, ch_names=ch_names
)

def crop(
self, tmin=None, tmax=None, emit_warning=False, use_orig_time=True, verbose=None
):
"""TODO."""
pass

def delete(self, idx):
"""Remove an annotation. Operates inplace.

Parameters
----------
idx : int | array-like of int
Index of the annotation to remove. Can be array-like to remove multiple
indices.
"""
_ = self.hed_string._objs.pop(idx)
_ = self.hed_string.pop(idx)
super().delete(idx)

def to_data_frame(self, time_format="datetime"):
"""TODO."""
pass


class EpochAnnotationsMixin:
"""Mixin class for Annotations in Epochs."""

Expand Down Expand Up @@ -1732,5 +1976,6 @@ def count_annotations(annotations):
>>> count_annotations(annotations)
{'T0': 2, 'T1': 1}
"""
types, counts = np.unique(annotations.description, return_counts=True)
field = "hed_string" if isinstance(annotations, HEDAnnotations) else "description"
types, counts = np.unique(getattr(annotations, field), return_counts=True)
return {str(t): int(count) for t, count in zip(types, counts)}
Loading
Loading