Skip to content

Commit 75d645a

Browse files
roll back tests for model forward
1 parent 8616b99 commit 75d645a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

onmt/tests/test_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def nmtmodel_forward(self, opt, source_l=3, bsize=1):
134134

135135
test_src, test_tgt, test_length = self.get_batch(source_l=source_l,
136136
bsize=bsize)
137-
outputs, attn, _, _ = model(test_src, test_tgt, test_length)
137+
outputs, attn = model(test_src, test_tgt, test_length)
138138
outputsize = torch.zeros(source_l - 1, bsize, opt.dec_rnn_size)
139139
# Make sure that output has the correct size and type
140140
self.assertEqual(outputs.size(), outputsize.size())
@@ -168,7 +168,7 @@ def imagemodel_forward(self, opt, tgt_l=2, bsize=1, h=15, w=17):
168168
h=h, w=w,
169169
bsize=bsize,
170170
tgt_l=tgt_l)
171-
outputs, attn, _, _ = model(test_src, test_tgt, test_length)
171+
outputs, attn = model(test_src, test_tgt, test_length)
172172
outputsize = torch.zeros(tgt_l - 1, bsize, opt.dec_rnn_size)
173173
# Make sure that output has the correct size and type
174174
self.assertEqual(outputs.size(), outputsize.size())
@@ -206,7 +206,7 @@ def audiomodel_forward(self, opt, tgt_l=7, bsize=3, t=37):
206206
sample_rate=opt.sample_rate,
207207
window_size=opt.window_size,
208208
t=t, tgt_l=tgt_l)
209-
outputs, attn, _, _ = model(test_src, test_tgt, test_length)
209+
outputs, attn = model(test_src, test_tgt, test_length)
210210
outputsize = torch.zeros(tgt_l - 1, bsize, opt.dec_rnn_size)
211211
# Make sure that output has the correct size and type
212212
self.assertEqual(outputs.size(), outputsize.size())

0 commit comments

Comments
 (0)