|
| 1 | +# coding: utf-8 |
| 2 | +# 2021/7/1 @ tongshiwei |
| 3 | + |
| 4 | +import pandas as pd |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | +from torch import nn |
| 8 | +from EduCDM import GDDINA |
| 9 | +from .loss import PairSCELoss, HarmonicLoss, loss_mask |
| 10 | +from tqdm import tqdm |
| 11 | +from longling.ML.metrics import ranking_report |
| 12 | + |
| 13 | + |
| 14 | +class DINA(GDDINA): |
| 15 | + def __init__(self, user_num, item_num, knowledge_num, ste=False, zeta=0.5): |
| 16 | + super(DINA, self).__init__(user_num, item_num, knowledge_num, ste) |
| 17 | + self.zeta = zeta |
| 18 | + |
| 19 | + def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...: |
| 20 | + point_loss_function = nn.BCELoss() |
| 21 | + pair_loss_function = PairSCELoss() |
| 22 | + loss_function = HarmonicLoss(self.zeta) |
| 23 | + |
| 24 | + trainer = torch.optim.Adam(self.dina_net.parameters(), lr, weight_decay=1e-4) |
| 25 | + |
| 26 | + for e in range(epoch): |
| 27 | + point_losses = [] |
| 28 | + pair_losses = [] |
| 29 | + losses = [] |
| 30 | + for batch_data in tqdm(train_data, "Epoch %s" % e): |
| 31 | + user_id, item_id, knowledge, score, n_samples, *neg_users = batch_data |
| 32 | + user_id: torch.Tensor = user_id.to(device) |
| 33 | + item_id: torch.Tensor = item_id.to(device) |
| 34 | + knowledge: torch.Tensor = knowledge.to(device) |
| 35 | + predicted_pos_score: torch.Tensor = self.dina_net(user_id, item_id, knowledge) |
| 36 | + score: torch.Tensor = score.to(device) |
| 37 | + neg_score = 1 - score |
| 38 | + |
| 39 | + point_loss = point_loss_function(predicted_pos_score, score) |
| 40 | + predicted_neg_scores = [] |
| 41 | + if neg_users: |
| 42 | + for neg_user in neg_users: |
| 43 | + predicted_neg_score = self.dina_net(neg_user, item_id, knowledge) |
| 44 | + predicted_neg_scores.append(predicted_neg_score) |
| 45 | + |
| 46 | + # prediction loss |
| 47 | + pair_pred_loss_list = [] |
| 48 | + for i, predicted_neg_score in enumerate(predicted_neg_scores): |
| 49 | + pair_pred_loss_list.append( |
| 50 | + pair_loss_function( |
| 51 | + predicted_pos_score, |
| 52 | + predicted_neg_score, |
| 53 | + score - neg_score |
| 54 | + ) |
| 55 | + ) |
| 56 | + |
| 57 | + pair_loss = sum(loss_mask(pair_pred_loss_list, n_samples)) |
| 58 | + else: |
| 59 | + pair_loss = 0 |
| 60 | + |
| 61 | + loss = loss_function(point_loss, pair_loss) |
| 62 | + |
| 63 | + # back propagation |
| 64 | + trainer.zero_grad() |
| 65 | + loss.backward() |
| 66 | + trainer.step() |
| 67 | + |
| 68 | + point_losses.append(point_loss.mean().item()) |
| 69 | + pair_losses.append(pair_loss.mean().item() if not isinstance(pair_loss, int) else pair_loss) |
| 70 | + losses.append(loss.item()) |
| 71 | + print( |
| 72 | + "[Epoch %d] Loss: %.6f, PointLoss: %.6f, PairLoss: %.6f" % ( |
| 73 | + e, float(np.mean(losses)), float(np.mean(point_losses)), float(np.mean(pair_losses)) |
| 74 | + ) |
| 75 | + ) |
| 76 | + |
| 77 | + if test_data is not None: |
| 78 | + eval_data = self.eval(test_data) |
| 79 | + print("[Epoch %d]\n%s" % (e, eval_data)) |
| 80 | + |
| 81 | + def eval(self, test_data, device="cpu"): |
| 82 | + self.dina_net.eval() |
| 83 | + y_pred = [] |
| 84 | + y_true = [] |
| 85 | + items = [] |
| 86 | + for batch_data in tqdm(test_data, "evaluating"): |
| 87 | + user_id, item_id, knowledge, response = batch_data |
| 88 | + user_id: torch.Tensor = user_id.to(device) |
| 89 | + item_id: torch.Tensor = item_id.to(device) |
| 90 | + pred: torch.Tensor = self.dina_net(user_id, item_id, knowledge) |
| 91 | + y_pred.extend(pred.tolist()) |
| 92 | + y_true.extend(response.tolist()) |
| 93 | + items.extend(item_id.tolist()) |
| 94 | + |
| 95 | + df = pd.DataFrame({ |
| 96 | + "item_id": items, |
| 97 | + "score": y_true, |
| 98 | + "pred": y_pred, |
| 99 | + }) |
| 100 | + |
| 101 | + ground_truth = [] |
| 102 | + prediction = [] |
| 103 | + |
| 104 | + for _, group_df in tqdm(df.groupby("item_id"), "formatting item df"): |
| 105 | + ground_truth.append(group_df["score"].values) |
| 106 | + prediction.append(group_df["pred"].values) |
| 107 | + |
| 108 | + self.dina_net.train() |
| 109 | + |
| 110 | + return ranking_report( |
| 111 | + ground_truth, |
| 112 | + y_pred=prediction, |
| 113 | + coerce="padding" |
| 114 | + ) |
0 commit comments