Skip to content

Commit c31e837

Browse files
committed
add sketch of HEDAnnotations [ci skip]
1 parent 6f6ccdc commit c31e837

File tree

2 files changed

+139
-2
lines changed

2 files changed

+139
-2
lines changed

mne/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ __all__ = [
1111
"Evoked",
1212
"EvokedArray",
1313
"Forward",
14+
"HEDAnnotations",
1415
"Info",
1516
"Label",
1617
"MixedSourceEstimate",
@@ -260,6 +261,7 @@ from ._freesurfer import (
260261
)
261262
from .annotations import (
262263
Annotations,
264+
HEDAnnotations,
263265
annotations_from_events,
264266
count_annotations,
265267
events_from_annotations,

mne/annotations.py

Lines changed: 137 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
verbose,
5353
warn,
5454
)
55+
from .utils.check import _soft_import
5556

5657
# For testing windows_like_datetime, we monkeypatch "datetime" in this module.
5758
# Keep the true datetime object around for _validate_type use.
@@ -151,6 +152,7 @@ class Annotations:
151152
--------
152153
mne.annotations_from_events
153154
mne.events_from_annotations
155+
mne.HEDAnnotations
154156
155157
Notes
156158
-----
@@ -288,7 +290,7 @@ def orig_time(self):
288290

289291
def __eq__(self, other):
290292
"""Compare to another Annotations instance."""
291-
if not isinstance(other, Annotations):
293+
if not isinstance(other, type(self)):
292294
return False
293295
return (
294296
np.array_equal(self.onset, other.onset)
@@ -567,6 +569,8 @@ def _sort(self):
567569
self.duration = self.duration[order]
568570
self.description = self.description[order]
569571
self.ch_names = self.ch_names[order]
572+
if hasattr(self, "hed_tags"):
573+
self.hed_tags = self.hed_tags[order]
570574

571575
@verbose
572576
def crop(
@@ -758,7 +762,138 @@ def rename(self, mapping, verbose=None):
758762
return self
759763

760764

761-
# TODO: Add support for HED annotations for use in epoching.
765+
class HEDAnnotations(Annotations):
766+
"""Annotations object for annotating segments of raw data with HED tags.
767+
768+
Parameters
769+
----------
770+
onset : array of float, shape (n_annotations,)
771+
The starting time of annotations in seconds after ``orig_time``.
772+
duration : array of float, shape (n_annotations,) | float
773+
Durations of the annotations in seconds. If a float, all the
774+
annotations are given the same duration.
775+
description : array of str, shape (n_annotations,) | str
776+
Array of strings containing description for each annotation. If a
777+
string, all the annotations are given the same description. To reject
778+
epochs, use description starting with keyword 'bad'. See example above.
779+
hed_tags : array of str, shape (n_annotations,) | str
780+
Array of strings containing a HED tag for each annotation. If a single string
781+
is provided, all annotations are given the same HED tag.
782+
hed_version : str
783+
The HED schema version against which to validate the HED tags.
784+
orig_time : float | str | datetime | tuple of int | None
785+
A POSIX Timestamp, datetime or a tuple containing the timestamp as the
786+
first element and microseconds as the second element. Determines the
787+
starting time of annotation acquisition. If None (default),
788+
starting time is determined from beginning of raw data acquisition.
789+
In general, ``raw.info['meas_date']`` (or None) can be used for syncing
790+
the annotations with raw data if their acquisition is started at the
791+
same time. If it is a string, it should conform to the ISO8601 format.
792+
More precisely to this '%%Y-%%m-%%d %%H:%%M:%%S.%%f' particular case of
793+
the ISO8601 format where the delimiter between date and time is ' '.
794+
%(ch_names_annot)s
795+
796+
See Also
797+
--------
798+
mne.Annotations
799+
800+
Notes
801+
-----
802+
803+
.. versionadded:: 1.10
804+
"""
805+
806+
def __init__(
807+
self,
808+
onset,
809+
duration,
810+
description,
811+
hed_tags,
812+
hed_version="latest", # TODO @VisLab what is a sensible default here?
813+
orig_time=None,
814+
ch_names=None,
815+
):
816+
hed = _soft_import("hed", "validation of HED tags in annotations") # noqa
817+
# TODO is some sort of initialization of the HED cache directory necessary?
818+
super().__init__(
819+
onset=onset,
820+
duration=duration,
821+
description=description,
822+
orig_time=orig_time,
823+
ch_names=ch_names,
824+
)
825+
# TODO validate the HED version the user claims to be using.
826+
self.hed_version = hed_version
827+
self._update_hed_tags(hed_tags=hed_tags)
828+
829+
def _update_hed_tags(self, hed_tags):
830+
if len(hed_tags) != len(self):
831+
raise ValueError(
832+
f"Number of HED tags ({len(hed_tags)}) must match the number of "
833+
f"annotations ({len(self)})."
834+
)
835+
# TODO insert validation of HED tags here
836+
self.hed_tags = hed_tags
837+
838+
def __eq__(self, other):
839+
"""Compare to another HEDAnnotations instance."""
840+
return (
841+
super().__eq__(self, other)
842+
and np.array_equal(self.hed_tags, other.hed_tags)
843+
and self.hed_version == other.hed_version
844+
)
845+
846+
def __repr__(self):
847+
"""Show a textual summary of the object."""
848+
counter = Counter(self.hed_tags)
849+
kinds = ", ".join(["{} ({})".format(*k) for k in sorted(counter.items())])
850+
kinds = (": " if len(kinds) > 0 else "") + kinds
851+
ch_specific = ", channel-specific" if self._any_ch_names() else ""
852+
s = (
853+
f"HEDAnnotations | {len(self.onset)} segment"
854+
f"{_pl(len(self.onset))}{ch_specific}{kinds}"
855+
)
856+
return "<" + shorten(s, width=77, placeholder=" ...") + ">"
857+
858+
def __getitem__(self, key, *, with_ch_names=None):
859+
"""Propagate indexing and slicing to the underlying numpy structure."""
860+
result = super().__getitem__(self, key, with_ch_names=with_ch_names)
861+
if isinstance(result, OrderedDict):
862+
result["hed_tags"] = self.hed_tags[key]
863+
else:
864+
key = list(key) if isinstance(key, tuple) else key
865+
hed_tags = self.hed_tags[key]
866+
return HEDAnnotations(
867+
result.onset,
868+
result.duration,
869+
result.description,
870+
hed_tags,
871+
hed_version=self.hed_version,
872+
orig_time=self.orig_time,
873+
ch_names=result.ch_names,
874+
)
875+
876+
def append(self, onset, duration, description, ch_names=None):
877+
"""TODO."""
878+
pass
879+
880+
def count(self):
881+
"""TODO. Unlike Annotations.count, keys should be HED tags not descriptions."""
882+
pass
883+
884+
def crop(
885+
self, tmin=None, tmax=None, emit_warning=False, use_orig_time=True, verbose=None
886+
):
887+
"""TODO."""
888+
pass
889+
890+
def delete(self, idx):
891+
"""TODO."""
892+
pass
893+
894+
def to_data_frame(self, time_format="datetime"):
895+
"""TODO."""
896+
pass
762897

763898

764899
class EpochAnnotationsMixin:

0 commit comments

Comments
 (0)