diff --git a/nn_pruning/modules/quantization.py b/nn_pruning/modules/quantization.py
index f4b92184..2b198a80 100644
--- a/nn_pruning/modules/quantization.py
+++ b/nn_pruning/modules/quantization.py
@@ -12,8 +12,7 @@
     prepare_fx,
     prepare_qat_fx,
 )
-from transformers.modeling_fx_utils import symbolic_trace
-
+from transformers.utils.fx import symbolic_trace
 from .quantization_config import create_qconfig
@@ -113,11 +112,10 @@ def _prepare(
     else:
         model.eval()
 
-    traced = symbolic_trace(
-        model, input_names=input_names, batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices
-    )
+    traced = symbolic_trace(model, input_names=input_names)
 
     change_attention_mask_value(traced)
+    # traced=model
 
     prepare_custom_config_dict = {"preserved_attributes": ["config", "dummy_inputs"]}
     prepared_model = torch_prepare_fn(traced, qconfig_dict, prepare_custom_config_dict)
diff --git a/setup.py b/setup.py
index 1690cfda..a2a69bbb 100644
--- a/setup.py
+++ b/setup.py
@@ -5,14 +5,23 @@ def readme():
     with open("README.md") as f:
         return f.read()
 
+
 extras = {
-    "tests": ["pytest"],
-    "examples": ["numpy>=1.2.0", "datasets>=1.4.1", "ipywidgets>=7.6.3", "matplotlib>=3.3.4", "pandas>=1.2.3"],
+    "tests": ["pytest", "transformers==4.15.0", "torch==1.9"],
+    "examples": [
+        "numpy>=1.2.0",
+        "datasets>=1.4.1",
+        "ipywidgets>=7.6.3",
+        "matplotlib>=3.3.4",
+        "pandas>=1.2.3",
+    ],
 }
 
+
 def combine_requirements(base_keys):
     return list(set(k for v in base_keys for k in extras[v]))
+
+
 extras["dev"] = combine_requirements([k for k in extras if k != "examples"])
@@ -33,7 +42,7 @@ def combine_requirements(base_keys):
     author_email="",
     license="MIT",
     packages=["nn_pruning", "nn_pruning.modules"],
-    install_requires=["click", "transformers>=4.3.0", "torch>=1.6", "scikit-learn>=0.24"],
+    install_requires=["click", "transformers>=4.3", "torch>=1.6", "scikit-learn>=0.24"],
     extras_require=extras,
     test_suite="nose.collector",
     tests_require=["nose", "nose-cover3"],
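
For context, here is a minimal sketch of the tracing-plus-prepare flow this diff migrates to, assuming the versions pinned in the new `tests` extra (transformers 4.15, torch 1.9). The model checkpoint, input names, and qconfig below are illustrative, not taken from the repo:

```python
import torch
from torch.quantization.quantize_fx import prepare_fx
from transformers import AutoModelForSequenceClassification
from transformers.utils.fx import symbolic_trace  # new import path (was transformers.modeling_fx_utils)

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
model.eval()  # prepare_fx for post-training quantization expects eval mode

# New tracer API: the shape arguments (batch_size, sequence_length,
# num_choices) removed in this diff are gone; only the names of the
# forward() inputs to trace are needed.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# FX graph-mode quantization entry point, mirroring the patched _prepare():
# preserve the attributes nn_pruning reads back off the prepared model.
qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
prepared = prepare_fx(
    traced,
    qconfig_dict,
    prepare_custom_config_dict={"preserved_attributes": ["config", "dummy_inputs"]},
)
```

The simplification in `_prepare` follows from this API shift: since the transformers tracer now derives concrete shapes internally, the multi-argument `symbolic_trace` call collapses to a single line, and the `tests` extra pins transformers==4.15.0 so the new `transformers.utils.fx` import path is guaranteed to exist.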