diff --git a/ane_transformers/huggingface/test_distilbert.py b/ane_transformers/huggingface/test_distilbert.py index b14f14d..f5e909d 100644 --- a/ane_transformers/huggingface/test_distilbert.py +++ b/ane_transformers/huggingface/test_distilbert.py @@ -3,6 +3,7 @@ # Copyright (C) 2022 Apple Inc. All Rights Reserved. # +import einops from ane_transformers import testing_utils import collections import coremltools as ct @@ -62,7 +63,12 @@ def setUpClass(cls): cls.models[ 'test'] = ane_transformers.DistilBertForSequenceClassification( cls.models['ref'].config).eval() - cls.models['test'].load_state_dict(cls.models['ref'].state_dict()) + ref_model_state = cls.models['ref'].state_dict() + ref_model_state['pre_classifier.weight'] = einops.rearrange( + ref_model_state['pre_classifier.weight'], 'd n -> d n 1 1') + ref_model_state['classifier.weight'] = einops.rearrange( + ref_model_state['classifier.weight'], 'n d -> n d 1 1') + cls.models['test'].load_state_dict(ref_model_state) logger.info("Initialized and restored test model") # Cache tokenized inputs and forward pass results on both the reference and test networks diff --git a/requirements.txt b/requirements.txt index 30438b7..538bb8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch>=1.10.0,<=1.11.0 +torch>=2.0.0 transformers>=4.18.0 coremltools>=5.2.0 yapf diff --git a/setup.py b/setup.py index 5d659ae..39ef969 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ long_description_content_type='text/markdown', author='Apple Inc.', install_requires=[ - "torch>=1.10.0,<=1.11.0", + "torch>=2.0.0", "coremltools>=5.2.0", "transformers>=4.18.0", "protobuf>=3.1.0,<=3.20.1",