We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 87fd100 + c31416e commit e2c2d83Copy full SHA for e2c2d83
model.py
@@ -42,8 +42,7 @@ def __repr__(self):
42
class InferenceBatchSoftmax(nn.Module):
43
def forward(self, input_):
44
if not self.training:
45
- batch_size = input_.size()[0]
46
- return torch.stack([F.softmax(input_[i], dim=1) for i in range(batch_size)], 0)
+ return F.softmax(input_, dim=-1)
47
else:
48
return input_
49
0 commit comments