@@ -195,7 +195,7 @@ def __init__(self, comm, rank, config, trodes_client):
195
195
self ._arm_2_posterior = [0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.25 ,0.26 ,0.27 ]
196
196
self ._task_state_2_start_time = None
197
197
self ._timepoints_per_sec = self ._config ['sampling_rate' ]['spikes' ]
198
-
198
+ self . _elapsed_minutes = 0
199
199
200
200
def handle_message (self , msg , mpi_status ):
201
201
"""Process a (non neural data) received MPI message"""
@@ -223,14 +223,14 @@ def _update_gui_params(self, gui_msg):
223
223
224
224
# manual control of the replay target arm is only allowed
225
225
# for a non-instructive task
226
+ '''
226
227
if not self._config['stimulation']['instructive']:
227
228
arm = gui_msg.replay_target_arm
228
229
self.p_replay['target_arm'] = arm
229
230
self.class_log.info(
230
231
"Non instructive task: Updated replay target arm to {arm}"
231
232
)
232
-
233
- self .p_replay ['primary_arm_threshold' ] = gui_msg .posterior_threshold
233
+ '''
234
234
self .p ['max_center_well_dist' ] = gui_msg .max_center_well_distance
235
235
self .p_ripples ['num_above_thresh' ] = gui_msg .num_above_threshold
236
236
self .p_head ['min_duration' ] = gui_msg .min_duration
@@ -244,7 +244,16 @@ def _update_gui_params(self, gui_msg):
244
244
#NOTE(DS): I am using this as a starting spatial bin for target location
245
245
self ._well_angle_range = gui_msg .well_angle_range
246
246
self ._within_angle_range = gui_msg .within_angle_range
247
- self .p_replay ['secondary_arm_threshold' ] = gui_msg .min_duration
247
+
248
+ self .p_replay ['primary_arm_threshold' ] = gui_msg .posterior_threshold
249
+ self .p_replay ['secondary_arm_threshold' ] = gui_msg .min_duration #NOTE(DS): previously secondary_arm_threshold was about two decoder, but now i am using as arm 2 threshold
250
+
251
+ #NOTE(DS): I am using this variable to get scm_rate per arm-- default will be 1?
252
+ print (f"gui_msg.replay_target_arm is { gui_msg .replay_target_arm } with type: { type (gui_msg .replay_target_arm )} " )
253
+ print (f"current: { self ._num_scm_each_arm_per_minute } and new : { float (gui_msg .replay_target_arm )} " )
254
+ if self ._num_scm_each_arm_per_minute != float (gui_msg .replay_target_arm ):
255
+ self ._num_scm_each_arm_per_minute = float (gui_msg .replay_target_arm )
256
+ self ._update_scm_threshold ()
248
257
249
258
250
259
def _update_ripples (self , msg ):
@@ -936,6 +945,14 @@ def _handle_replay(self, arm, msg):
936
945
trodes_of_spike = self ._enc_ci_buff [self ._enc_ci_buff != 0 ]
937
946
#avg_arm_ps = np.mean(self._region_ps_buff[ind],axis = 0) #NOTE(DS): target arm + whole center
938
947
948
+ curent_time = msg [0 ]['bin_timestamp_l' ]
949
+ if self ._task_state == 2 :
950
+ # NOTE(DS): This computes the threshold to match the scm rate
951
+ if self ._task_state_2_start_time == None :
952
+ self ._task_state_2_start_time = msg [0 ]['bin_timestamp_l' ]
953
+
954
+ self ._elapsed_minutes = (curent_time - self ._task_state_2_start_time )/ self ._timepoints_per_sec / 60 #minutes
955
+
939
956
if num_unique < self .p_replay ['min_unique_trodes' ]:
940
957
print (f"Replay arm { arm } detected less than min unique trodes in ts { self ._task_state } " )
941
958
else :
@@ -962,7 +979,7 @@ def _handle_replay(self, arm, msg):
962
979
963
980
send_shortcut = self ._check_send_shortcut (
964
981
self .p_replay ['enabled' ]
965
- ) and above_threshold
982
+ ) and ( above_threshold or num_spikes_in_event >= 6 ) # NOTE(DS): num_spikes_in_event >6 is to detect SWR
966
983
967
984
if num_unique >= self .p_replay ['min_unique_trodes' ]:
968
985
@@ -987,23 +1004,7 @@ def _handle_replay(self, arm, msg):
987
1004
print (f" " )
988
1005
989
1006
if (np .sum (self ._num_rewards [1 :]) in [5 ,10 ,20 ,40 ]) and self ._automatic_threshold_update :
990
-
991
- curent_time = msg [0 ]['bin_timestamp_l' ]
992
- elapsed_minutes = (curent_time - self ._task_state_2_start_time )/ self ._timepoints_per_sec / 60 #minutes
993
-
994
- desired_number_of_scm = np .ceil (self ._num_scm_each_arm_per_minute * elapsed_minutes )
995
- index_for_desired_number_of_scm = - int (desired_number_of_scm + 1 )
996
-
997
- self .p_replay ['primary_arm_threshold' ] = np .sort (self ._arm_1_posterior )[index_for_desired_number_of_scm ]
998
- self .p_replay ['secondary_arm_threshold' ] = np .sort (self ._arm_2_posterior )[index_for_desired_number_of_scm ]
999
-
1000
- print (f"new arm 1 thresh: { self .p_replay ['primary_arm_threshold' ]} and arm 2 thresh: { self .p_replay ['secondary_arm_threshold' ]} " )
1001
-
1002
- # NOTE(DS): This computes the threshold to match the scm rate
1003
- if self ._task_state == 2 and self ._task_state_2_start_time == None :
1004
- self ._task_state_2_start_time = msg [0 ]['bin_timestamp_l' ]
1005
-
1006
-
1007
+ self ._update_scm_threshold ()
1007
1008
1008
1009
self .write_record (
1009
1010
binary_record .RecordIDs .STIM_MESSAGE ,
@@ -1031,7 +1032,14 @@ def _handle_replay(self, arm, msg):
1031
1032
* self ._arm_ps_buff .mean (axis = 1 ).flatten ()
1032
1033
)
1033
1034
1035
+ def _update_scm_threshold (self ):
1036
+ desired_number_of_scm = np .ceil (self ._num_scm_each_arm_per_minute * self ._elapsed_minutes )
1037
+ index_for_desired_number_of_scm = - int (desired_number_of_scm + 1 )
1038
+
1039
+ self .p_replay ['primary_arm_threshold' ] = np .sort (self ._arm_1_posterior )[index_for_desired_number_of_scm ]
1040
+ self .p_replay ['secondary_arm_threshold' ] = np .sort (self ._arm_2_posterior )[index_for_desired_number_of_scm ]
1034
1041
1042
+ print (f"new arm 1 thresh: { self .p_replay ['primary_arm_threshold' ]} and arm 2 thresh: { self .p_replay ['secondary_arm_threshold' ]} " )
1035
1043
1036
1044
1037
1045
def _find_replay_instructive (self , msg ):
0 commit comments