Skip to content

Commit 4c894b7

Browse files
Merge pull request #703 from BindsNET/sparse_connection
SparseConnection support
2 parents 731206e + 6b6e4b5 commit 4c894b7

File tree

11 files changed

+495
-188
lines changed

11 files changed

+495
-188
lines changed

bindsnet/evaluation/evaluation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,11 @@ def assign_labels(
4444
indices = torch.nonzero(labels == i).view(-1)
4545

4646
# Compute average firing rates for this label.
47+
selected_spikes = torch.index_select(
48+
spikes, dim=0, index=torch.tensor(indices)
49+
)
4750
rates[:, i] = alpha * rates[:, i] + (
48-
torch.sum(spikes[indices], 0) / n_labeled
51+
torch.sum(selected_spikes, 0) / n_labeled
4952
)
5053

5154
# Compute proportions of spike activity per class.
@@ -111,6 +114,8 @@ def all_activity(
111114

112115
# Sum over time dimension (spike ordering doesn't matter).
113116
spikes = spikes.sum(1)
117+
if spikes.is_sparse:
118+
spikes = spikes.to_dense()
114119

115120
rates = torch.zeros((n_samples, n_labels), device=spikes.device)
116121
for i in range(n_labels):
@@ -152,6 +157,8 @@ def proportion_weighting(
152157

153158
# Sum over time dimension (spike ordering doesn't matter).
154159
spikes = spikes.sum(1)
160+
if spikes.is_sparse:
161+
spikes = spikes.to_dense()
155162

156163
rates = torch.zeros((n_samples, n_labels), device=spikes.device)
157164
for i in range(n_labels):

bindsnet/learning/MCC_learning.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,12 @@ def update(self, **kwargs) -> None:
102102
if ((self.min is not None) or (self.max is not None)) and not isinstance(
103103
self, NoOp
104104
):
105-
self.feature_value.clamp_(self.min, self.max)
105+
if self.feature_value.is_sparse:
106+
self.feature_value = (
107+
self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse()
108+
)
109+
else:
110+
self.feature_value.clamp_(self.min, self.max)
106111

107112
@abstractmethod
108113
def reset_state_variables(self) -> None:
@@ -247,10 +252,15 @@ def _connection_update(self, **kwargs) -> None:
247252
torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt
248253
)
249254
else:
250-
self.feature_value -= (
251-
self.reduction(torch.bmm(source_s, target_x), dim=0)
252-
* self.connection.dt
253-
)
255+
if self.feature_value.is_sparse:
256+
self.feature_value -= (
257+
torch.bmm(source_s, target_x) * self.connection.dt
258+
).to_sparse()
259+
else:
260+
self.feature_value -= (
261+
self.reduction(torch.bmm(source_s, target_x), dim=0)
262+
* self.connection.dt
263+
)
254264
del source_s, target_x
255265

256266
# Post-synaptic update.
@@ -278,10 +288,15 @@ def _connection_update(self, **kwargs) -> None:
278288
torch.mean(self.average_buffer_post, dim=0) * self.connection.dt
279289
)
280290
else:
281-
self.feature_value += (
282-
self.reduction(torch.bmm(source_x, target_s), dim=0)
283-
* self.connection.dt
284-
)
291+
if self.feature_value.is_sparse:
292+
self.feature_value += (
293+
torch.bmm(source_x, target_s) * self.connection.dt
294+
).to_sparse()
295+
else:
296+
self.feature_value += (
297+
self.reduction(torch.bmm(source_x, target_s), dim=0)
298+
* self.connection.dt
299+
)
285300
del source_x, target_s
286301

287302
super().update()
@@ -508,16 +523,16 @@ def _connection_update(self, **kwargs) -> None:
508523
self.average_buffer_index + 1
509524
) % self.average_update
510525

511-
if self.continues_update:
512-
self.feature_value += self.nu[0] * torch.mean(
513-
self.average_buffer, dim=0
514-
)
515-
elif self.average_buffer_index == 0:
516-
self.feature_value += self.nu[0] * torch.mean(
517-
self.average_buffer, dim=0
518-
)
526+
if self.continues_update or self.average_buffer_index == 0:
527+
update = self.nu[0] * torch.mean(self.average_buffer, dim=0)
528+
if self.feature_value.is_sparse:
529+
update = update.to_sparse()
530+
self.feature_value += update
519531
else:
520-
self.feature_value += self.nu[0] * self.reduction(update, dim=0)
532+
update = self.nu[0] * self.reduction(update, dim=0)
533+
if self.feature_value.is_sparse:
534+
update = update.to_sparse()
535+
self.feature_value += update
521536

522537
# Update P^+ and P^- values.
523538
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -686,14 +701,16 @@ def _connection_update(self, **kwargs) -> None:
686701
self.average_buffer_index + 1
687702
) % self.average_update
688703

689-
if self.continues_update:
690-
self.feature_value += torch.mean(self.average_buffer, dim=0)
691-
elif self.average_buffer_index == 0:
692-
self.feature_value += torch.mean(self.average_buffer, dim=0)
704+
if self.continues_update or self.average_buffer_index == 0:
705+
update = torch.mean(self.average_buffer, dim=0)
706+
if self.feature_value.is_sparse:
707+
update = update.to_sparse()
708+
self.feature_value += update
693709
else:
694-
self.feature_value += (
695-
self.nu[0] * self.connection.dt * reward * self.eligibility_trace
696-
)
710+
update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
711+
if self.feature_value.is_sparse:
712+
update = update.to_sparse()
713+
self.feature_value += update
697714

698715
# Update P^+ and P^- values.
699716
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) # Decay

bindsnet/learning/learning.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ def update(self) -> None:
9898
(self.connection.wmin != -np.inf).any()
9999
or (self.connection.wmax != np.inf).any()
100100
) and not isinstance(self, NoOp):
101-
self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
101+
if self.connection.w.is_sparse:
102+
raise Exception("SparseConnection isn't supported for wmin\\wmax")
103+
else:
104+
self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
102105

103106

104107
class NoOp(LearningRule):
@@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None:
396399
if self.nu[0].any():
397400
source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
398401
target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
399-
self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
402+
update = self.reduction(torch.bmm(source_s, target_x), dim=0)
403+
if self.connection.w.is_sparse:
404+
update = update.to_sparse()
405+
self.connection.w -= update
400406
del source_s, target_x
401407

402408
# Post-synaptic update.
@@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None:
405411
self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
406412
)
407413
source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
408-
self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
414+
update = self.reduction(torch.bmm(source_x, target_s), dim=0)
415+
if self.connection.w.is_sparse:
416+
update = update.to_sparse()
417+
self.connection.w += update
409418
del source_x, target_s
410419

411420
super().update()
@@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None:
11131122

11141123
# Pre-synaptic update.
11151124
update = self.reduction(torch.bmm(source_s, target_x), dim=0)
1125+
if self.connection.w.is_sparse:
1126+
update = update.to_sparse()
11161127
self.connection.w += self.nu[0] * update
11171128

11181129
# Post-synaptic update.
11191130
update = self.reduction(torch.bmm(source_x, target_s), dim=0)
1131+
if self.connection.w.is_sparse:
1132+
update = update.to_sparse()
11201133
self.connection.w += self.nu[1] * update
11211134

11221135
super().update()
@@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None:
15421555
a_minus = torch.tensor(a_minus, device=self.connection.w.device)
15431556

15441557
# Compute weight update based on the eligibility value of the past timestep.
1545-
update = reward * self.eligibility
1546-
self.connection.w += self.nu[0] * self.reduction(update, dim=0)
1558+
update = self.reduction(reward * self.eligibility, dim=0)
1559+
if self.connection.w.is_sparse:
1560+
update = update.to_sparse()
1561+
self.connection.w += self.nu[0] * update
15471562

15481563
# Update P^+ and P^- values.
15491564
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None:
22142229
self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
22152230
self.eligibility_trace += self.eligibility / self.tc_e_trace
22162231

2232+
update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
2233+
if self.connection.w.is_sparse:
2234+
update = update.to_sparse()
22172235
# Compute weight update.
2218-
self.connection.w += (
2219-
self.nu[0] * self.connection.dt * reward * self.eligibility_trace
2220-
)
2236+
self.connection.w += update
22212237

22222238
# Update P^+ and P^- values.
22232239
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None:
29362952
) * source_x[:, None]
29372953

29382954
# Compute weight update.
2939-
self.connection.w += self.nu[0] * reward * self.eligibility_trace
2955+
update = self.nu[0] * reward * self.eligibility_trace
2956+
if self.connection.w.is_sparse:
2957+
update = update.to_sparse()
2958+
self.connection.w += update
29402959

29412960
super().update()

bindsnet/models/models.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44
import torch
55
from scipy.spatial.distance import euclidean
66
from torch.nn.modules.utils import _pair
7+
from torch import device
78

89
from bindsnet.learning import PostPre
10+
from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
911
from bindsnet.network import Network
1012
from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
11-
from bindsnet.network.topology import Connection, LocalConnection
13+
from bindsnet.network.topology import (
14+
Connection,
15+
LocalConnection,
16+
MulticompartmentConnection,
17+
)
18+
from bindsnet.network.topology_features import Weight
1219

1320

1421
class TwoLayerNetwork(Network):
@@ -94,6 +101,9 @@ class DiehlAndCook2015(Network):
94101
def __init__(
95102
self,
96103
n_inpt: int,
104+
device: str = "cpu",
105+
batch_size: int = None,
106+
sparse: bool = False,
97107
n_neurons: int = 100,
98108
exc: float = 22.5,
99109
inh: float = 17.5,
@@ -170,27 +180,44 @@ def __init__(
170180

171181
# Connections
172182
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
173-
input_exc_conn = Connection(
183+
input_exc_conn = MulticompartmentConnection(
174184
source=input_layer,
175185
target=exc_layer,
176-
w=w,
177-
update_rule=PostPre,
178-
nu=nu,
179-
reduction=reduction,
180-
wmin=wmin,
181-
wmax=wmax,
182-
norm=norm,
186+
device=device,
187+
pipeline=[
188+
Weight(
189+
"weight",
190+
w,
191+
range=[wmin, wmax],
192+
norm=norm,
193+
reduction=reduction,
194+
nu=nu,
195+
learning_rule=MMCPostPre,
196+
sparse=sparse,
197+
batch_size=batch_size,
198+
)
199+
],
183200
)
184201
w = self.exc * torch.diag(torch.ones(self.n_neurons))
185-
exc_inh_conn = Connection(
186-
source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
202+
if sparse:
203+
w = w.unsqueeze(0).expand(batch_size, -1, -1)
204+
exc_inh_conn = MulticompartmentConnection(
205+
source=exc_layer,
206+
target=inh_layer,
207+
device=device,
208+
pipeline=[Weight("weight", w, range=[0, self.exc], sparse=sparse)],
187209
)
188210
w = -self.inh * (
189211
torch.ones(self.n_neurons, self.n_neurons)
190212
- torch.diag(torch.ones(self.n_neurons))
191213
)
192-
inh_exc_conn = Connection(
193-
source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0
214+
if sparse:
215+
w = w.unsqueeze(0).expand(batch_size, -1, -1)
216+
inh_exc_conn = MulticompartmentConnection(
217+
source=inh_layer,
218+
target=exc_layer,
219+
device=device,
220+
pipeline=[Weight("weight", w, range=[-self.inh, 0], sparse=sparse)],
194221
)
195222

196223
# Add to network

bindsnet/network/monitors.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
time: Optional[int] = None,
4646
batch_size: int = 1,
4747
device: str = "cpu",
48+
sparse: Optional[bool] = False,
4849
):
4950
# language=rst
5051
"""
@@ -62,6 +63,7 @@ def __init__(
6263
self.time = time
6364
self.batch_size = batch_size
6465
self.device = device
66+
self.sparse = sparse
6567

6668
# if time is not specified the monitor variable accumulate the logs
6769
if self.time is None:
@@ -98,11 +100,12 @@ def record(self) -> None:
98100
for v in self.state_vars:
99101
data = getattr(self.obj, v).unsqueeze(0)
100102
# self.recording[v].append(data.detach().clone().to(self.device))
101-
self.recording[v].append(
102-
torch.empty_like(data, device=self.device, requires_grad=False).copy_(
103-
data, non_blocking=True
104-
)
105-
)
103+
record = torch.empty_like(
104+
data, device=self.device, requires_grad=False
105+
).copy_(data, non_blocking=True)
106+
if self.sparse:
107+
record = record.to_sparse()
108+
self.recording[v].append(record)
106109
# remove the oldest element (first in the list)
107110
if self.time is not None:
108111
self.recording[v].pop(0)

0 commit comments

Comments (0)