Skip to content

Commit b3fa90b

Browse files
committed
triggering event is based on the set value of num scm per minutes not just set threshold
1 parent 7ec0ba6 commit b3fa90b

File tree

1 file changed

+30
-22
lines changed

1 file changed

+30
-22
lines changed

realtime_decoder/stimulation.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def __init__(self, comm, rank, config, trodes_client):
195195
self._arm_2_posterior = [0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.26,0.27]
196196
self._task_state_2_start_time = None
197197
self._timepoints_per_sec = self._config['sampling_rate']['spikes']
198-
198+
self._elapsed_minutes = 0
199199

200200
def handle_message(self, msg, mpi_status):
201201
"""Process a (non neural data) received MPI message"""
@@ -223,14 +223,14 @@ def _update_gui_params(self, gui_msg):
223223

224224
# manual control of the replay target arm is only allowed
225225
# for a non-instructive task
226+
'''
226227
if not self._config['stimulation']['instructive']:
227228
arm = gui_msg.replay_target_arm
228229
self.p_replay['target_arm'] = arm
229230
self.class_log.info(
230231
"Non instructive task: Updated replay target arm to {arm}"
231232
)
232-
233-
self.p_replay['primary_arm_threshold'] = gui_msg.posterior_threshold
233+
'''
234234
self.p['max_center_well_dist'] = gui_msg.max_center_well_distance
235235
self.p_ripples['num_above_thresh'] = gui_msg.num_above_threshold
236236
self.p_head['min_duration'] = gui_msg.min_duration
@@ -244,7 +244,16 @@ def _update_gui_params(self, gui_msg):
244244
#NOTE(DS): I am using this as a starting spatial bin for target location
245245
self._well_angle_range = gui_msg.well_angle_range
246246
self._within_angle_range = gui_msg.within_angle_range
247-
self.p_replay['secondary_arm_threshold'] = gui_msg.min_duration
247+
248+
self.p_replay['primary_arm_threshold'] = gui_msg.posterior_threshold
249+
self.p_replay['secondary_arm_threshold'] = gui_msg.min_duration #NOTE(DS): previously secondary_arm_threshold was about two decoder, but now i am using as arm 2 threshold
250+
251+
#NOTE(DS): I am using this variable to get scm_rate per arm-- default will be 1?
252+
print(f"gui_msg.replay_target_arm is {gui_msg.replay_target_arm} with type: {type(gui_msg.replay_target_arm)}")
253+
print(f"current: {self._num_scm_each_arm_per_minute} and new : {float(gui_msg.replay_target_arm)}")
254+
if self._num_scm_each_arm_per_minute != float(gui_msg.replay_target_arm):
255+
self._num_scm_each_arm_per_minute = float(gui_msg.replay_target_arm)
256+
self._update_scm_threshold()
248257

249258

250259
def _update_ripples(self, msg):
@@ -936,6 +945,14 @@ def _handle_replay(self, arm, msg):
936945
trodes_of_spike = self._enc_ci_buff[self._enc_ci_buff != 0]
937946
#avg_arm_ps = np.mean(self._region_ps_buff[ind],axis = 0) #NOTE(DS): target arm + whole center
938947

948+
curent_time = msg[0]['bin_timestamp_l']
949+
if self._task_state == 2:
950+
# NOTE(DS): This computes the threshold to match the scm rate
951+
if self._task_state_2_start_time == None:
952+
self._task_state_2_start_time = msg[0]['bin_timestamp_l']
953+
954+
self._elapsed_minutes = (curent_time - self._task_state_2_start_time)/self._timepoints_per_sec/60 #minutes
955+
939956
if num_unique < self.p_replay['min_unique_trodes']:
940957
print(f"Replay arm {arm} detected less than min unique trodes in ts {self._task_state}")
941958
else:
@@ -962,7 +979,7 @@ def _handle_replay(self, arm, msg):
962979

963980
send_shortcut = self._check_send_shortcut(
964981
self.p_replay['enabled']
965-
) and above_threshold
982+
) and (above_threshold or num_spikes_in_event >= 6) # NOTE(DS): num_spikes_in_event >6 is to detect SWR
966983

967984
if num_unique >= self.p_replay['min_unique_trodes']:
968985

@@ -987,23 +1004,7 @@ def _handle_replay(self, arm, msg):
9871004
print(f" ")
9881005

9891006
if (np.sum(self._num_rewards[1:]) in [5,10,20,40]) and self._automatic_threshold_update:
990-
991-
curent_time = msg[0]['bin_timestamp_l']
992-
elapsed_minutes = (curent_time - self._task_state_2_start_time)/self._timepoints_per_sec/60 #minutes
993-
994-
desired_number_of_scm = np.ceil(self._num_scm_each_arm_per_minute * elapsed_minutes)
995-
index_for_desired_number_of_scm = -int(desired_number_of_scm + 1)
996-
997-
self.p_replay['primary_arm_threshold'] = np.sort(self._arm_1_posterior)[index_for_desired_number_of_scm]
998-
self.p_replay['secondary_arm_threshold'] = np.sort(self._arm_2_posterior)[index_for_desired_number_of_scm]
999-
1000-
print(f"new arm 1 thresh: {self.p_replay['primary_arm_threshold']} and arm 2 thresh: {self.p_replay['secondary_arm_threshold']}")
1001-
1002-
# NOTE(DS): This computes the threshold to match the scm rate
1003-
if self._task_state == 2 and self._task_state_2_start_time == None:
1004-
self._task_state_2_start_time = msg[0]['bin_timestamp_l']
1005-
1006-
1007+
self._update_scm_threshold()
10071008

10081009
self.write_record(
10091010
binary_record.RecordIDs.STIM_MESSAGE,
@@ -1031,7 +1032,14 @@ def _handle_replay(self, arm, msg):
10311032
*self._arm_ps_buff.mean(axis=1).flatten()
10321033
)
10331034

1035+
def _update_scm_threshold(self):
1036+
desired_number_of_scm = np.ceil(self._num_scm_each_arm_per_minute * self._elapsed_minutes)
1037+
index_for_desired_number_of_scm = -int(desired_number_of_scm + 1)
1038+
1039+
self.p_replay['primary_arm_threshold'] = np.sort(self._arm_1_posterior)[index_for_desired_number_of_scm]
1040+
self.p_replay['secondary_arm_threshold'] = np.sort(self._arm_2_posterior)[index_for_desired_number_of_scm]
10341041

1042+
print(f"new arm 1 thresh: {self.p_replay['primary_arm_threshold']} and arm 2 thresh: {self.p_replay['secondary_arm_threshold']}")
10351043

10361044

10371045
def _find_replay_instructive(self, msg):

0 commit comments

Comments
 (0)