@@ -102,7 +102,12 @@ def update(self, **kwargs) -> None:
102102 if ((self .min is not None ) or (self .max is not None )) and not isinstance (
103103 self , NoOp
104104 ):
105- self .feature_value .clamp_ (self .min , self .max )
105+ if self .feature_value .is_sparse :
106+ self .feature_value = (
107+ self .feature_value .to_dense ().clamp_ (self .min , self .max ).to_sparse ()
108+ )
109+ else :
110+ self .feature_value .clamp_ (self .min , self .max )
106111
107112 @abstractmethod
108113 def reset_state_variables (self ) -> None :
@@ -247,10 +252,15 @@ def _connection_update(self, **kwargs) -> None:
247252 torch .mean (self .average_buffer_pre , dim = 0 ) * self .connection .dt
248253 )
249254 else :
250- self .feature_value -= (
251- self .reduction (torch .bmm (source_s , target_x ), dim = 0 )
252- * self .connection .dt
253- )
255+ if self .feature_value .is_sparse :
256+ self .feature_value -= (
257+ torch .bmm (source_s , target_x ) * self .connection .dt
258+ ).to_sparse ()
259+ else :
260+ self .feature_value -= (
261+ self .reduction (torch .bmm (source_s , target_x ), dim = 0 )
262+ * self .connection .dt
263+ )
254264 del source_s , target_x
255265
256266 # Post-synaptic update.
@@ -278,10 +288,15 @@ def _connection_update(self, **kwargs) -> None:
278288 torch .mean (self .average_buffer_post , dim = 0 ) * self .connection .dt
279289 )
280290 else :
281- self .feature_value += (
282- self .reduction (torch .bmm (source_x , target_s ), dim = 0 )
283- * self .connection .dt
284- )
291+ if self .feature_value .is_sparse :
292+ self .feature_value += (
293+ torch .bmm (source_x , target_s ) * self .connection .dt
294+ ).to_sparse ()
295+ else :
296+ self .feature_value += (
297+ self .reduction (torch .bmm (source_x , target_s ), dim = 0 )
298+ * self .connection .dt
299+ )
285300 del source_x , target_s
286301
287302 super ().update ()
@@ -508,16 +523,16 @@ def _connection_update(self, **kwargs) -> None:
508523 self .average_buffer_index + 1
509524 ) % self .average_update
510525
511- if self .continues_update :
512- self .feature_value += self .nu [0 ] * torch .mean (
513- self .average_buffer , dim = 0
514- )
515- elif self .average_buffer_index == 0 :
516- self .feature_value += self .nu [0 ] * torch .mean (
517- self .average_buffer , dim = 0
518- )
526+ if self .continues_update or self .average_buffer_index == 0 :
527+ update = self .nu [0 ] * torch .mean (self .average_buffer , dim = 0 )
528+ if self .feature_value .is_sparse :
529+ update = update .to_sparse ()
530+ self .feature_value += update
519531 else :
520- self .feature_value += self .nu [0 ] * self .reduction (update , dim = 0 )
532+ update = self .nu [0 ] * self .reduction (update , dim = 0 )
533+ if self .feature_value .is_sparse :
534+ update = update .to_sparse ()
535+ self .feature_value += update
521536
522537 # Update P^+ and P^- values.
523538 self .p_plus *= torch .exp (- self .connection .dt / self .tc_plus )
@@ -686,14 +701,16 @@ def _connection_update(self, **kwargs) -> None:
686701 self .average_buffer_index + 1
687702 ) % self .average_update
688703
689- if self .continues_update :
690- self .feature_value += torch .mean (self .average_buffer , dim = 0 )
691- elif self .average_buffer_index == 0 :
692- self .feature_value += torch .mean (self .average_buffer , dim = 0 )
704+ if self .continues_update or self .average_buffer_index == 0 :
705+ update = torch .mean (self .average_buffer , dim = 0 )
706+ if self .feature_value .is_sparse :
707+ update = update .to_sparse ()
708+ self .feature_value += update
693709 else :
694- self .feature_value += (
695- self .nu [0 ] * self .connection .dt * reward * self .eligibility_trace
696- )
710+ update = self .nu [0 ] * self .connection .dt * reward * self .eligibility_trace
711+ if self .feature_value .is_sparse :
712+ update = update .to_sparse ()
713+ self .feature_value += update
697714
698715 # Update P^+ and P^- values.
699716 self .p_plus *= torch .exp (- self .connection .dt / self .tc_plus ) # Decay
0 commit comments