Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 87 additions & 8 deletions neo/rawio/monkeylogicrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np
import struct
from typing import Union

from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
_spike_channel_dtype, _event_channel_dtype)
Expand All @@ -22,7 +23,8 @@ class MLBlock(dict):
'integers': (8, 'Q'),
'uint8': (1, 'B'),
'single': (4, 'f'),
'double': (8, 'd')}
'double': (8, 'd'),
'uint64': (8, 'L')}

@staticmethod
def generate_block(f):
Expand All @@ -41,6 +43,9 @@ def generate_block(f):
# print(var_name)

LT = f.read(8)
# print(len(LT))
if len(LT) == 0:
return None
LT = struct.unpack('Q', LT)[0]
# print(f'LT: {LT}')
var_type = f.read(LT)
Expand Down Expand Up @@ -76,6 +81,11 @@ def __repr__(self):
if self.data is None:
shape = 0
dt = ''

elif not hasattr(self.data, '__len__'):
shape = (1, )
dt = f' dtype: {self.var_type}'

else:
shape = getattr(self.data, 'shape', len(self.data))
dt = f' dtype: {self.var_type}'
Expand Down Expand Up @@ -132,9 +142,12 @@ def read_data(self, f, recursive=False):

for field in range(n_fields * np.prod(self.var_size)):
bl = MLBlock.generate_block(f)
if recursive:
self[bl.var_name] = bl
bl.read_data(f, recursive=recursive)
if bl is None:
pass
else:
if recursive:
self[bl.var_name] = bl
bl.read_data(f, recursive=recursive)

elif self.var_type == 'cell':
# cells are always 2D
Expand All @@ -144,11 +157,18 @@ def read_data(self, f, recursive=False):
bl = MLBlock.generate_block(f)
if recursive:
data[field] = bl

if bl is None:
pass
else:
bl.read_data(f, recursive=recursive)

bl.read_data(f, recursive=recursive)
data = data.reshape(self.var_size)
self.data = data

elif self.var_type == 'function_handle':
pass

else:
raise ValueError(f'unknown variable type {self.var_type}')

Expand Down Expand Up @@ -216,6 +236,9 @@ def _parse_header(self):

exclude_signals = ['SampleInterval']

print(self.ml_blocks.keys())
print(self.ml_blocks.values())

# rawio configuration
signal_streams = []
signal_channels = []
Expand All @@ -230,14 +253,40 @@ def _parse_header(self):
def _register_signal(sig_block, name):
nonlocal stream_id
nonlocal chan_id
if not isinstance(sig_data, dict) and any(sig_data.shape):
if not isinstance(sig_data, dict) and any(np.shape(sig_data)):
signal_streams.append((name, stream_id))

sr = 1 # TODO: Where to get the sampling rate info?
# ML/Trial1/AnalogData/SampleInterval
# ML/MLConfig/HighFrequencyDAQ/SampleRate
# ML/MLConfig/VoiceRecording/SampleRate
# ML/MLConfig/AISampleRate
# ML/TrialRecord/LastTrialAnalogData/SampleInterval

dtype = type(sig_data)
units = '' # TODO: Where to find the unit info?
# Degree of visual angle is default coordinate system used by ML, see here:
# https://monkeylogic.nimh.nih.gov/docs_CoordinateConversion.html

# /ML/MLConfig/DiagonalSize
# /ML/MLConfig/ViewingDistance

# ML/MLConfig/Screen has details about screen, specifically:
# /ML/MLConfig/Screen/Xsize
# /ML/MLConfig/Screen/Ysize
# /ML/MLConfig/Screen/PixelsPerDegree
# /ML/MLConfig/Screen/RefreshRate
# /ML/MLConfig/Screen/FrameLength
# /ML/MLConfig/Screen/VBlankLength

gain = 1 # TODO: Where to find the gain info?

# ML/MLConfig/Webcam/1/Property/Gain

offset = 0 # TODO: Can signals have an offset in ML?

# /ML/MLConfig/EyeTransform/2/offset if it exists

stream_id = 0 # all analog data belong to same stream

if sig_block.shape[1] == 1:
Expand All @@ -254,9 +303,14 @@ def _register_signal(sig_block, name):
for sig_name, sig_data in ana_block.items():
if sig_name in exclude_signals:
continue

print(sig_name)
print(sig_data)
# print(sig_sub_name)
# print(sig_sub_data)

# 1st level signals ('Trial1'/'AnalogData'/<signal>')
if not isinstance(sig_data, dict) and any(sig_data.shape):
if not isinstance(sig_data, dict) and any(np.shape(sig_data)):
_register_signal(sig_data, name=sig_name)

# 2nd level signals ('Trial1'/'AnalogData'/<signal_group>/<signal>')
Expand Down Expand Up @@ -317,17 +371,42 @@ def _filter_keys(full_dict, ignore_keys, simplify=True):

ml_ann = {k: v for k, v in self.ml_blocks.items() if k in ['MLConfig', 'TrialRecord']}
ml_ann = _filter_keys(ml_ann, ignore_annotations)
# normalize annotation values, convert arrays to lists

def recursively_replace_arrays(container: Union[dict, list]) -> None:
"""
Replace numpy arrays in nested dictionary and list structures by lists
"""

if isinstance(container, dict):
iterator = container.items()
elif isinstance(container, list):
iterator = enumerate(container)

for k, v in iterator:
if isinstance(v, dict):
recursively_replace_arrays(v)
elif isinstance(v, np.ndarray):
container[k] = list(v)
recursively_replace_arrays(container[k])
elif isinstance(v, list):
recursively_replace_arrays(container[k])

recursively_replace_arrays(ml_ann)

bl_ann = self.raw_annotations['blocks'][0]
bl_ann.update(ml_ann)

for trial_id in self.trial_ids:
ml_trial = self.ml_blocks[f'Trial{trial_id}']
assert ml_trial['Trial'] == trial_id

recursively_replace_arrays(ml_trial)

seg_ann = self.raw_annotations['blocks'][0]['segments'][trial_id-1]
seg_ann.update(_filter_keys(ml_trial, ignore_annotations))

event_ann = seg_ann['events'][0] # 0 is event
# event_ann = seg_ann['events'][0] # 0 is event
# epoch_ann = seg_ann['events'][1] # 1 is epoch

def _segment_t_start(self, block_index, seg_index):
Expand Down