Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test
  • Loading branch information
IlyasMoutawwakil committed Sep 2, 2025
commit daabb4e9939c34a18f9a93f1fabfc3ca56a73c30
18 changes: 13 additions & 5 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import unittest
from typing import Any, Dict

import numpy as np
from huggingface_hub.constants import default_cache_path
from PIL import Image
from transformers import AutoTokenizer
from transformers.pipelines import Pipeline

from optimum.onnxruntime import ORTModelForFeatureExtraction
from optimum.pipelines import pipeline


Expand Down Expand Up @@ -236,6 +233,8 @@ def test_audio_classification_pipeline(self):

def test_pipeline_with_ort_model(self):
"""Test ORT pipeline with a model already in ONNX format"""
from optimum.onnxruntime import ORTModelForFeatureExtraction

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
model = ORTModelForFeatureExtraction.from_pretrained("distilbert-base-cased", export=True)
pipe = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, accelerator="ort")
Expand All @@ -257,8 +256,17 @@ def test_pipeline_with_custom_model_id(self):
self.assertIsInstance(result, list)
self.assertIsInstance(result[0], list)

def tearDown(self):
    """Delete the directory at ``default_cache_path`` after every test.

    Any error raised while removing the tree (for instance because the
    directory is already gone) is ignored.

    NOTE(review): ``default_cache_path`` comes from
    ``huggingface_hub.constants`` and is presumably the user-wide Hugging Face
    cache, not a test-private temp dir. If so, this wipes every cached model on
    the machine and forces re-downloads for each test — confirm this is intended.
    """
    cache_dir = default_cache_path
    shutil.rmtree(cache_dir, ignore_errors=True)
def test_pipeline_with_invalid_task(self):
    """Test ORT pipeline with an unsupported task.

    Requesting a pipeline for an unknown task name must raise ``KeyError``
    whose message contains ``"Unknown task invalid-task"``.
    """
    # assertRaisesRegex re.search-es str(exception); the fragment has no regex
    # metacharacters, so this is equivalent to a plain substring check.
    expected_fragment = "Unknown task invalid-task"
    with self.assertRaisesRegex(KeyError, expected_fragment):
        pipeline(task="invalid-task", accelerator="ort")

def test_pipeline_with_invalid_accelerator(self):
    """Test ORT pipeline with an unsupported accelerator.

    Passing an unrecognized ``accelerator`` value must raise ``ValueError``
    whose message contains
    ``"Accelerator invalid-accelerator not recognized"``.
    """
    # The expected fragment contains no regex metacharacters, so the regex
    # search performed by assertRaisesRegex acts as a substring match.
    expected_fragment = "Accelerator invalid-accelerator not recognized"
    with self.assertRaisesRegex(ValueError, expected_fragment):
        pipeline(task="text-classification", accelerator="invalid-accelerator")


if __name__ == "__main__":
Expand Down
Loading