diff --git a/tests/integrations/huggingface_hub/test_huggingface_hub.py b/tests/integrations/huggingface_hub/test_huggingface_hub.py index b9ab4df5bf..ffeb6acbb5 100644 --- a/tests/integrations/huggingface_hub/test_huggingface_hub.py +++ b/tests/integrations/huggingface_hub/test_huggingface_hub.py @@ -34,6 +34,15 @@ ) +def get_hf_provider_inference_client(): + # The provider parameter was added in version 0.28.0 of huggingface_hub + return ( + InferenceClient(model="test-model", provider="hf-inference") + if HF_VERSION >= (0, 28, 0) + else InferenceClient(model="test-model") + ) + + def _add_mock_response( httpx_mock, rsps, method, url, json=None, status=200, body=None, headers=None ): @@ -616,7 +625,7 @@ def test_chat_completion( ) events = capture_events() - client = InferenceClient(model="test-model") + client = get_hf_provider_inference_client() with sentry_sdk.start_transaction(name="test"): client.chat_completion( @@ -688,7 +697,7 @@ def test_chat_completion_streaming( ) events = capture_events() - client = InferenceClient(model="test-model") + client = get_hf_provider_inference_client() with sentry_sdk.start_transaction(name="test"): _ = list( @@ -752,7 +761,7 @@ def test_chat_completion_api_error( sentry_init(traces_sample_rate=1.0) events = capture_events() - client = InferenceClient(model="test-model") + client = get_hf_provider_inference_client() with sentry_sdk.start_transaction(name="test"): with pytest.raises(HfHubHTTPError): @@ -804,7 +813,7 @@ def test_span_status_error(sentry_init, capture_events, mock_hf_api_with_errors) sentry_init(traces_sample_rate=1.0) events = capture_events() - client = InferenceClient(model="test-model") + client = get_hf_provider_inference_client() with sentry_sdk.start_transaction(name="test"): with pytest.raises(HfHubHTTPError): @@ -849,7 +858,7 @@ def test_chat_completion_with_tools( ) events = capture_events() - client = InferenceClient(model="test-model") + client = get_hf_provider_inference_client() tools = [ { @@ -938,7 +947,7 @@ def test_chat_completion_streaming_with_tools( ) events = capture_events() - client = InferenceClient(model="test-model") + client = get_hf_provider_inference_client() tools = [ {