
Commit 166e094

[feat] use PosLinear to replace clipper operation
1 parent ee5ddfe commit 166e094

File tree

1 file changed, +9 -20 lines changed


EduCDM/NCDM/NCDM.py

Lines changed: 9 additions & 20 deletions
@@ -12,6 +12,12 @@
 from EduCDM import CDM


+class PosLinear(nn.Linear):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = 2 * F.relu(1 * torch.neg(self.weight)) + self.weight
+        return F.linear(input, weight, self.bias)
+
+
 class Net(nn.Module):

     def __init__(self, knowledge_n, exer_n, student_n):
@@ -28,11 +34,11 @@ def __init__(self, knowledge_n, exer_n, student_n):
         self.student_emb = nn.Embedding(self.emb_num, self.stu_dim)
         self.k_difficulty = nn.Embedding(self.exer_n, self.knowledge_dim)
         self.e_difficulty = nn.Embedding(self.exer_n, 1)
-        self.prednet_full1 = nn.Linear(self.prednet_input_len, self.prednet_len1)
+        self.prednet_full1 = PosLinear(self.prednet_input_len, self.prednet_len1)
         self.drop_1 = nn.Dropout(p=0.5)
-        self.prednet_full2 = nn.Linear(self.prednet_len1, self.prednet_len2)
+        self.prednet_full2 = PosLinear(self.prednet_len1, self.prednet_len2)
         self.drop_2 = nn.Dropout(p=0.5)
-        self.prednet_full3 = nn.Linear(self.prednet_len2, 1)
+        self.prednet_full3 = PosLinear(self.prednet_len2, 1)

         # initialize
         for name, param in self.named_parameters():
@@ -53,22 +59,6 @@ def forward(self, stu_id, input_exercise, input_knowledge_point):

         return output_1.view(-1)

-    def apply_clipper(self):
-        clipper = NoneNegClipper()
-        self.prednet_full1.apply(clipper)
-        self.prednet_full2.apply(clipper)
-        self.prednet_full3.apply(clipper)
-
-
-class NoneNegClipper(object):
-    def __init__(self):
-        super(NoneNegClipper, self).__init__()
-
-    def __call__(self, module):
-        if hasattr(module, 'weight'):
-            w = module.weight.data
-            module.weight.data = torch.clamp(w, min=0.).detach()
-

 class NCDM(CDM):
     '''Neural Cognitive Diagnosis Model'''
@@ -98,7 +88,6 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
-               self.ncdm_net.apply_clipper()

                epoch_losses.append(loss.mean().item())
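Context on the change (not part of the diff above): PosLinear folds the non-negativity constraint into the forward pass. The expression 2 * relu(-w) + w equals |w| elementwise, so the effective weight handed to F.linear is always non-negative while the stored parameter stays unconstrained and differentiable; the removed NoneNegClipper enforced the same constraint by clamping the stored weights to min=0 after every optimizer step. Below is a minimal sketch that checks this equivalence; the layer sizes and random tensors are illustrative only, not taken from the repository.

import torch
import torch.nn as nn
import torch.nn.functional as F


class PosLinear(nn.Linear):
    # Same forward as in the commit: the effective weight is 2*relu(-w) + w == |w|.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = 2 * F.relu(1 * torch.neg(self.weight)) + self.weight
        return F.linear(input, weight, self.bias)


layer = PosLinear(4, 3)
with torch.no_grad():
    layer.weight.copy_(torch.randn(3, 4))  # stored weights may well be negative

effective = 2 * F.relu(-layer.weight) + layer.weight
assert torch.allclose(effective, layer.weight.abs())  # reparameterization == |w|
assert bool((effective >= 0).all())                    # non-negative, as the clipper enforced

# Unlike the old NoneNegClipper (torch.clamp(..., min=0.) after optimizer.step),
# gradients still flow through the stored weight, so no post-step correction is needed.
out = layer(torch.randn(2, 4)).sum()
out.backward()
print(layer.weight.grad.shape)  # torch.Size([3, 4])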
