 from EduCDM import CDM
 
 
+class PosLinear(nn.Linear):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = 2 * F.relu(1 * torch.neg(self.weight)) + self.weight
+        return F.linear(input, weight, self.bias)
+
+
 class Net(nn.Module):
 
     def __init__(self, knowledge_n, exer_n, student_n):
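Note on the hunk above: the reparameterised weight 2 * F.relu(1 * torch.neg(self.weight)) + self.weight is just the element-wise absolute value of the stored parameter, so every effective weight PosLinear applies is non-negative while the underlying nn.Linear parameter stays unconstrained. A minimal standalone check (illustrative values, not part of this diff):

    import torch
    import torch.nn.functional as F

    w = torch.tensor([-1.5, 0.0, 2.0])
    print(2 * F.relu(1 * torch.neg(w)) + w)   # tensor([1.5000, 0.0000, 2.0000])
    print(torch.abs(w))                       # identical: the expression equals |w|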
@@ -28,11 +34,11 @@ def __init__(self, knowledge_n, exer_n, student_n):
         self.student_emb = nn.Embedding(self.emb_num, self.stu_dim)
         self.k_difficulty = nn.Embedding(self.exer_n, self.knowledge_dim)
         self.e_difficulty = nn.Embedding(self.exer_n, 1)
-        self.prednet_full1 = nn.Linear(self.prednet_input_len, self.prednet_len1)
+        self.prednet_full1 = PosLinear(self.prednet_input_len, self.prednet_len1)
         self.drop_1 = nn.Dropout(p=0.5)
-        self.prednet_full2 = nn.Linear(self.prednet_len1, self.prednet_len2)
+        self.prednet_full2 = PosLinear(self.prednet_len1, self.prednet_len2)
         self.drop_2 = nn.Dropout(p=0.5)
-        self.prednet_full3 = nn.Linear(self.prednet_len2, 1)
+        self.prednet_full3 = PosLinear(self.prednet_len2, 1)
 
         # initialize
         for name, param in self.named_parameters():
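Replacing the three prediction-net nn.Linear layers with PosLinear bakes NCDM's monotonicity assumption into the network itself: because the effective weights are non-negative, raising an input to any of these layers can never lower its output. A small sketch of that property, assuming PosLinear as defined above (shapes and values are illustrative only):

    import torch

    layer = PosLinear(3, 1)             # drop-in replacement for nn.Linear(3, 1)
    x = torch.rand(1, 3)
    x_higher = x.clone()
    x_higher[0, 1] += 0.3               # raise one input coordinate
    assert layer(x_higher) >= layer(x)  # non-negative effective weights => monotone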
@@ -53,22 +59,6 @@ def forward(self, stu_id, input_exercise, input_knowledge_point):
 
         return output_1.view(-1)
 
-    def apply_clipper(self):
-        clipper = NoneNegClipper()
-        self.prednet_full1.apply(clipper)
-        self.prednet_full2.apply(clipper)
-        self.prednet_full3.apply(clipper)
-
-
-class NoneNegClipper(object):
-    def __init__(self):
-        super(NoneNegClipper, self).__init__()
-
-    def __call__(self, module):
-        if hasattr(module, 'weight'):
-            w = module.weight.data
-            module.weight.data = torch.clamp(w, min=0.).detach()
-
 
 class NCDM(CDM):
     '''Neural Cognitive Diagnosis Model'''
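The deleted NoneNegClipper enforced non-negativity by overwriting module.weight.data with torch.clamp(w, min=0.) after every optimiser step, a projection applied outside autograd. With PosLinear the constraint sits inside forward(), so a currently negative parameter still receives a gradient and can keep moving during training. A standalone sketch of that gradient path (illustrative values, not part of this diff):

    import torch
    import torch.nn.functional as F

    w = torch.tensor([-0.5], requires_grad=True)
    effective = 2 * F.relu(1 * torch.neg(w)) + w   # |w|, as used in PosLinear.forward
    effective.sum().backward()
    print(w.grad)                                  # tensor([-1.]): the negative weight still gets a gradient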
@@ -98,7 +88,6 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
-                self.ncdm_net.apply_clipper()
 
                 epoch_losses.append(loss.mean().item())
 
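With positivity handled inside the layers, the training step above is now the plain PyTorch zero_grad / backward / step sequence, with nothing to re-project afterwards. A self-contained micro-sketch of one such step, assuming Net as defined in this file and a BCE objective as in NCDM.train (batch values are made up for illustration):

    import torch
    import torch.nn.functional as F

    net = Net(knowledge_n=4, exer_n=5, student_n=3)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.002)

    stu_id = torch.tensor([0])
    exer_id = torch.tensor([1])
    knowledge = torch.ones(1, 4)
    label = torch.tensor([1.0])

    pred = net(stu_id, exer_id, knowledge)
    loss = F.binary_cross_entropy(pred, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()   # no self.ncdm_net.apply_clipper() needed any more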