@@ -696,7 +696,7 @@ def good_loss_bound(model):
696696a = second_layer_attention (matrices , attn_1 )
697697loss = 1 - (torch .nansum (a * weights_1 ) / (weights_1 .sum ()))
698698print (a [~ torch .isnan (a )].min ())
699- print (a [ ~ torch .isnan ( a )]. mean ( ))
699+ print (torch .nansum ( a * weights_1 ) / ( weights_1 . sum () ))
700700print (a [~ torch .isnan (a )].max ())
701701while loss > 0.5 :
702702 # torch.autograd.set_detect_anomaly(True)
@@ -724,38 +724,37 @@ def good_loss_bound(model):
724724 .bool ()
725725 .to (device )
726726)
727+ weights_2 = ein .array (
728+ lambda i , j , k : where (k > 0 , where (j > k , where (j < 7 , 1 , 0 ), 0 ), 0 )
729+ * ((d_voc - 1 ) * ((d_voc - 1 ) ** (j - 2 ))),
730+ sizes = [d_voc , n_ctx , n_ctx ],
731+ ).to (device )
727732# %%
728733optimiser = torch .optim .AdamW (
729- model_1 .parameters (), lr = 1 , betas = (0.9 , 0.999 ), weight_decay = 0
734+ model_1 .parameters (), lr = 1e-2 , betas = (0.9 , 0.999 ), weight_decay = 1. 0
730735)
731736# %%
732- optimiser = torch .optim .SGD (model_1 .parameters (), lr = 100 )
737+ # optimiser = torch.optim.SGD(model_1.parameters(), lr=100)
733738# %%
734- a = loss_bound (model_1 , 3 )[4 ]
735- loss = 1 - a [valid ].mean ()
736- print (a [valid ].min ())
737- print (a [valid ].mean ())
738- print (a [valid ].max ())
739- for i in range (1 ):
740- print (i + 1 )
741-
739+ bound = loss_bound (model_1 )[1 ]
740+ loss = 1 - (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
741+ print (bound [valid ].min ())
742+ print (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
743+ print (bound [valid ].max ())
744+ while loss > 0.5 :
745+ # torch.autograd.set_detect_anomaly(True)
742746 loss .backward ()
747+ # torch.nn.utils.clip_grad_norm_(model_1.parameters(), max_norm=1.0)
743748 optimiser .step ()
744- for param in model_1 .parameters ():
745- if param .requires_grad :
746- print (param .grad .norm ()) # Check gradient norms
747-
748749 optimiser .zero_grad ()
749- a = loss_bound (model_1 , 3 )[4 ]
750- loss = 1 - a [valid ].mean ()
751- print (a [valid ].min ())
752- print (a [valid ].mean ())
753- print (a [valid ].max ())
754- if i % 10 == 1 :
755- r = loss_bound (model_1 , 4 )[5 ]
756- print (r [valid ].min ())
757- print (r [valid ].mean ())
758- print (r [valid ].max ())
750+ bound = loss_bound (model_1 )[1 ]
751+ loss = 1 - (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
752+ counter += 1
753+ print (counter )
754+ print (bound [valid ].min ())
755+ print (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
756+ print (bound [valid ].max ())
757+
759758
760759# %%
761760ModelMatrixLoggingOptions .all (
0 commit comments