Skip to content

Commit 689c0b5

Browse files
tomaarsen and echarlaix authored
Propagate library_name parameter in from_pretrained to export (#2328)
* Propagate library_name parameter in from_pretrained to export

  Required to avoid automatically inferring the library_name.

* Use class attribute for ORTModel instead

  Under modeling_diffusion it looks like ORTModel isn't used.

* Add test case

* Update optimum/onnxruntime/modeling_ort.py

---------

Co-authored-by: Ella Charlaix <[email protected]>
1 parent 53f39a6 commit 689c0b5

File tree

3 files changed

+30
-0
lines changed

3 files changed

+30
-0
lines changed

optimum/onnxruntime/modeling_ort.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -148,6 +148,7 @@ class ORTModel(ORTSessionMixin, OptimizedModel):
148148

149149
model_type = "onnx_model"
150150
auto_model_class = AutoModel
151+
_library_name: Optional[str] = None
151152

152153
def __init__(
153154
self,
@@ -431,6 +432,7 @@ def _export(
431432
local_files_only=local_files_only,
432433
force_download=force_download,
433434
trust_remote_code=trust_remote_code,
435+
library_name=cls._library_name,
434436
)
435437
maybe_save_preprocessors(model_id, model_save_path, src_subfolder=subfolder)
436438

@@ -628,6 +630,7 @@ class ORTModelForFeatureExtraction(ORTModel):
628630
"""
629631

630632
auto_model_class = AutoModel
633+
_library_name: Optional[str] = "transformers"
631634

632635
@add_start_docstrings_to_model_forward(
633636
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -753,6 +756,7 @@ class ORTModelForMaskedLM(ORTModel):
753756
"""
754757

755758
auto_model_class = AutoModelForMaskedLM
759+
_library_name: Optional[str] = "transformers"
756760

757761
@add_start_docstrings_to_model_forward(
758762
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -855,6 +859,7 @@ class ORTModelForQuestionAnswering(ORTModel):
855859
"""
856860

857861
auto_model_class = AutoModelForQuestionAnswering
862+
_library_name: Optional[str] = "transformers"
858863

859864
@add_start_docstrings_to_model_forward(
860865
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -974,6 +979,7 @@ class ORTModelForSequenceClassification(ORTModel):
974979
"""
975980

976981
auto_model_class = AutoModelForSequenceClassification
982+
_library_name: Optional[str] = "transformers"
977983

978984
@add_start_docstrings_to_model_forward(
979985
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1077,6 +1083,7 @@ class ORTModelForTokenClassification(ORTModel):
10771083
"""
10781084

10791085
auto_model_class = AutoModelForTokenClassification
1086+
_library_name: Optional[str] = "transformers"
10801087

10811088
@add_start_docstrings_to_model_forward(
10821089
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1173,6 +1180,7 @@ class ORTModelForMultipleChoice(ORTModel):
11731180
"""
11741181

11751182
auto_model_class = AutoModelForMultipleChoice
1183+
_library_name: Optional[str] = "transformers"
11761184

11771185
@add_start_docstrings_to_model_forward(
11781186
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1376,6 +1384,7 @@ class ORTModelForSemanticSegmentation(ORTModel):
13761384
"""
13771385

13781386
auto_model_class = AutoModelForSemanticSegmentation
1387+
_library_name: Optional[str] = "transformers"
13791388

13801389
@add_start_docstrings_to_model_forward(
13811390
ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
@@ -1479,6 +1488,7 @@ class ORTModelForAudioClassification(ORTModel):
14791488
"""
14801489

14811490
auto_model_class = AutoModelForAudioClassification
1491+
_library_name: Optional[str] = "transformers"
14821492

14831493
@add_start_docstrings_to_model_forward(
14841494
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1577,6 +1587,7 @@ class ORTModelForCTC(ORTModel):
15771587
"""
15781588

15791589
auto_model_class = AutoModelForCTC
1590+
_library_name: Optional[str] = "transformers"
15801591

15811592
@add_start_docstrings_to_model_forward(
15821593
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1681,6 +1692,7 @@ class ORTModelForAudioXVector(ORTModel):
16811692
"""
16821693

16831694
auto_model_class = AutoModelForAudioXVector
1695+
_library_name: Optional[str] = "transformers"
16841696

16851697
@add_start_docstrings_to_model_forward(
16861698
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1770,6 +1782,7 @@ class ORTModelForAudioFrameClassification(ORTModel):
17701782
"""
17711783

17721784
auto_model_class = AutoModelForAudioFrameClassification
1785+
_library_name: Optional[str] = "transformers"
17731786

17741787
@add_start_docstrings_to_model_forward(
17751788
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
@@ -1850,6 +1863,7 @@ class ORTModelForImageToImage(ORTModel):
18501863
"""
18511864

18521865
auto_model_class = AutoModelForImageToImage
1866+
_library_name: Optional[str] = "transformers"
18531867

18541868
@add_start_docstrings_to_model_forward(
18551869
ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")

optimum/onnxruntime/modeling_seq2seq.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1280,6 +1280,7 @@ def _export(
12801280
local_files_only=local_files_only,
12811281
force_download=force_download,
12821282
trust_remote_code=trust_remote_code,
1283+
library_name=cls._library_name,
12831284
)
12841285
maybe_save_preprocessors(model_id, model_save_path, src_subfolder=subfolder)
12851286

tests/onnxruntime/test_modeling.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1340,6 +1340,21 @@ def test_compare_to_io_binding(self, model_arch):
13401340

13411341
gc.collect()
13421342

1343+
def test_load_sentence_transformers_model_as_fill_mask(self):
1344+
model_id = "sparse-encoder-testing/splade-bert-tiny-nq"
1345+
onnx_model = ORTModelForMaskedLM.from_pretrained(model_id)
1346+
tokenizer = get_preprocessor(model_id)
1347+
MASK_TOKEN = tokenizer.mask_token
1348+
pipe = pipeline("fill-mask", model=onnx_model, tokenizer=tokenizer, device=0)
1349+
text = f"The capital of France is {MASK_TOKEN}."
1350+
outputs = pipe(text)
1351+
1352+
self.assertEqual(pipe.device, onnx_model.device)
1353+
self.assertGreaterEqual(outputs[0]["score"], 0.0)
1354+
self.assertIsInstance(outputs[0]["token_str"], str)
1355+
1356+
gc.collect()
1357+
13431358

13441359
class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
13451360
SUPPORTED_ARCHITECTURES = [

0 commit comments

Comments (0)