Skip to content

Commit 6cb0320

Browse files
authored
Better sampling and relax dependency (#2082)
1 parent 54c777a commit 6cb0320

File tree

4 files changed

+12
-18
lines changed

4 files changed

+12
-18
lines changed

.github/workflows/push.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- name: Install dependencies
2020
run: |
2121
python -m pip install --upgrade pip
22-
pip install --upgrade setuptools==50.3.0
22+
pip install --upgrade setuptools
2323
pip install -e .
2424
pip install -r requirements.opt.txt
2525
pip install flake8

onmt/translate/greedy_search.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,8 @@ def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp):
8383
logits = sample_topp(logits, keep_topp)
8484
if keep_topk > 0:
8585
logits = sample_topk(logits, keep_topk)
86-
dist = torch.distributions.Multinomial(
87-
logits=logits, total_count=1)
88-
topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
86+
dist = torch.distributions.Categorical(logits=logits)
87+
topk_ids = dist.sample().view(-1, 1)
8988
topk_scores = logits.gather(dim=1, index=topk_ids)
9089
return topk_ids, topk_scores
9190

requirements.opt.txt

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
cffi==1.14.3
2-
joblib==0.17.0
3-
numba==0.43.0
4-
llvmlite==0.32.1
5-
pyrouge==0.1.3
1+
pyrouge
62
git+git://github.com/NVIDIA/apex.git@700d6825e205732c1d6be511306ca4e595297070
7-
sentencepiece==0.1.94
8-
subword-nmt==0.3.7
3+
sentencepiece>=0.1.94
4+
subword-nmt>=0.3.7

setup.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121
},
2222
python_requires=">=3.5",
2323
install_requires=[
24-
"tqdm>=4.51,<5",
25-
"torch==1.6.0",
24+
"torch>=1.6.0",
2625
"torchtext==0.5.0",
27-
"configargparse>=1.2.3,<2",
28-
"tensorboard>=2.3,<3",
29-
"flask==1.1.2",
30-
"waitress==1.4.4",
26+
"configargparse",
27+
"tensorboard>=2.3",
28+
"flask",
29+
"waitress",
3130
"pyonmttok>=1.23,<2;platform_system=='Linux' or platform_system=='Darwin'",
32-
"pyyaml==5.4",
31+
"pyyaml",
3332
],
3433
entry_points={
3534
"console_scripts": [

0 commit comments

Comments
 (0)