Skip to content

Commit f32a6c2

Browse files
authored
Add async support for _SafetyEvaluation (#40623)
1 parent a55603d commit f32a6c2

File tree

2 files changed

+102
-10
lines changed

2 files changed

+102
-10
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import os
77
import inspect
88
import logging
9+
import asyncio
910
from datetime import datetime
1011
from azure.ai.evaluation._common._experimental import experimental
11-
from typing import Any, Callable, Dict, List, Optional, Union, cast
12+
from typing import Any, Callable, Dict, List, Optional, Union, cast, Coroutine, TypeVar, Awaitable
1213
from azure.ai.evaluation._common.math import list_mean_nan_safe
1314
from azure.ai.evaluation._constants import CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT
1415
from azure.ai.evaluation._evaluators import (
@@ -192,10 +193,17 @@ async def callback(
192193
context = latest_message.get("context", None)
193194
latest_context = None
194195
try:
196+
is_async = self._is_async_function(target)
195197
if self._check_target_returns_context(target):
196-
response, latest_context = target(query=application_input)
198+
if is_async:
199+
response, latest_context = await target(query=application_input)
200+
else:
201+
response, latest_context = target(query=application_input)
197202
else:
198-
response = target(query=application_input)
203+
if is_async:
204+
response = await target(query=application_input)
205+
else:
206+
response = target(query=application_input)
199207
except Exception as e:
200208
response = f"Something went wrong {e!s}"
201209

@@ -465,7 +473,7 @@ def _get_evaluators(
465473
blame=ErrorBlame.USER_ERROR,
466474
)
467475
return evaluators_dict
468-
476+
469477
@staticmethod
470478
def _check_target_returns_context(target: Callable) -> bool:
471479
"""
@@ -478,6 +486,15 @@ def _check_target_returns_context(target: Callable) -> bool:
478486
ret_type = sig.return_annotation
479487
if ret_type == inspect.Signature.empty:
480488
return False
489+
490+
# Check for Coroutine/Awaitable return types for async functions
491+
origin = getattr(ret_type, "__origin__", None)
492+
if origin is not None and (origin is Coroutine or origin is Awaitable):
493+
args = getattr(ret_type, "__args__", None)
494+
if args and len(args) > 0:
495+
# For async functions, check the actual return type inside the Coroutine
496+
ret_type = args[-1]
497+
481498
if ret_type is tuple:
482499
return True
483500
return False
@@ -494,13 +511,33 @@ def _check_target_returns_str(target: Callable) -> bool:
494511
ret_type = sig.return_annotation
495512
if ret_type == inspect.Signature.empty:
496513
return False
514+
515+
# Check for Coroutine/Awaitable return types for async functions
516+
origin = getattr(ret_type, "__origin__", None)
517+
if origin is not None and (origin is Coroutine or origin is Awaitable):
518+
args = getattr(ret_type, "__args__", None)
519+
if args and len(args) > 0:
520+
# For async functions, check the actual return type inside the Coroutine
521+
ret_type = args[-1]
522+
497523
if ret_type is str:
498524
return True
499525
return False
500526

501-
502527
@staticmethod
503-
def _check_target_is_callback(target:Callable) -> bool:
528+
def _is_async_function(target: Callable) -> bool:
529+
"""
530+
Checks if the target function is an async function.
531+
532+
:param target: The target function to check.
533+
:type target: Callable
534+
:return: True if the target function is async, False otherwise.
535+
:rtype: bool
536+
"""
537+
return asyncio.iscoroutinefunction(target)
538+
539+
@staticmethod
540+
def _check_target_is_callback(target: Callable) -> bool:
504541
sig = inspect.signature(target)
505542
param_names = list(sig.parameters.keys())
506543
return 'messages' in param_names and 'stream' in param_names and 'session_state' in param_names and 'context' in param_names
@@ -630,7 +667,7 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult:
630667

631668
async def __call__(
632669
self,
633-
target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
670+
target: Union[Callable, Awaitable[Any], AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
634671
evaluators: List[_SafetyEvaluator] = [],
635672
evaluation_name: Optional[str] = None,
636673
num_turns : int = 1,
@@ -644,12 +681,12 @@ async def __call__(
644681
jailbreak_data_path: Optional[Union[str, os.PathLike]] = None,
645682
output_path: Optional[Union[str, os.PathLike]] = None,
646683
data_paths: Optional[Union[Dict[str, str], Dict[str, Union[str,os.PathLike]]]] = None
647-
) -> Union[Dict[str, EvaluationResult], Dict[str, str], Dict[str, Union[str,os.PathLike]]]:
684+
) -> Union[Dict[str, EvaluationResult], Dict[str, str], Dict[str, Union[str,os.PathLike]]]:
648685
'''
649686
Evaluates the target function based on the provided parameters.
650687
651-
:param target: The target function to call during the evaluation.
652-
:type target: Callable
688+
:param target: The target function to call during the evaluation. This can be a synchronous or asynchronous function.
689+
:type target: Union[Callable, Awaitable[Any], AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
653690
:param evaluators: A list of SafetyEvaluator.
654691
:type evaluators: List[_SafetyEvaluator]
655692
:param evaluation_name: The display name of the evaluation.

sdk/evaluation/azure-ai-evaluation/tests/unittests/test_safety_evaluation.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ def mock_target_with_context_fn() -> tuple:
4545
return mock_target_with_context_fn
4646

4747

48+
@pytest.fixture
49+
def mock_async_target():
    """Build a coroutine-function target that answers any query with a fixed string.

    The inner function keeps its ``-> str`` annotation so that the
    return-annotation checks exercised by the tests see a string-returning target.
    """

    async def mock_async_target_fn(query: str) -> str:
        canned_reply = "mock async response"
        return canned_reply

    return mock_async_target_fn
54+
55+
56+
@pytest.fixture
57+
def mock_async_target_with_context():
    """Build a coroutine-function target that returns a (response, context) pair.

    The inner function keeps its ``-> tuple`` annotation so that the
    return-annotation checks exercised by the tests see a tuple-returning target.
    """

    async def mock_async_target_with_context_fn(query: str) -> tuple:
        reply = "mock async response"
        context = "mock async context"
        return (reply, context)

    return mock_async_target_with_context_fn
62+
63+
4864
@pytest.fixture
4965
def mock_eval_result_dict():
5066
jailbreak = {
@@ -122,6 +138,16 @@ def test_check_target_returns_context_false(self, safety_eval, mock_target):
122138
def test_check_target_returns_context_true(self, safety_eval, mock_target_with_context):
123139
assert safety_eval._check_target_returns_context(mock_target_with_context)
124140

141+
def test_check_target_returns_context_async(self, safety_eval, mock_async_target, mock_async_target_with_context):
    """Context detection on coroutine-function targets follows their return annotation."""
    # Async target annotated ``-> str``: must not be reported as returning context.
    plain_has_context = safety_eval._check_target_returns_context(mock_async_target)
    assert not plain_has_context

    # Async target annotated ``-> tuple``: must be reported as returning context.
    tuple_has_context = safety_eval._check_target_returns_context(mock_async_target_with_context)
    assert tuple_has_context
146+
147+
def test_check_target_returns_str_async(self, safety_eval, mock_async_target):
    """A coroutine-function target annotated ``-> str`` is recognised as string-returning."""
    returns_str = safety_eval._check_target_returns_str(mock_async_target)
    assert returns_str
150+
125151
def test_validate_inputs_groundedness_no_source(self, safety_eval, mock_target):
126152
with pytest.raises(EvaluationException) as exc_info:
127153
safety_eval._validate_inputs(
@@ -243,3 +269,32 @@ async def test_simulate_no_results(self, mock_call, mock_init, safety_eval, mock
243269
target=mock_target, adversarial_scenario=AdversarialScenario.ADVERSARIAL_QA
244270
)
245271
assert "outputs generated by the simulator" in str(exc_info.value)
272+
273+
def test_is_async_function(self, safety_eval, mock_target, mock_async_target):
    """``_is_async_function`` separates the sync fixture target from the async one."""
    # Regular ``def`` target: not async.
    sync_is_async = safety_eval._is_async_function(mock_target)
    assert not sync_is_async

    # ``async def`` target: async.
    async_is_async = safety_eval._is_async_function(mock_async_target)
    assert async_is_async
278+
279+
@pytest.mark.asyncio
@patch("azure.ai.evaluation._safety_evaluation._safety_evaluation._SafetyEvaluation._simulate")
@patch("azure.ai.evaluation._evaluate._evaluate.evaluate")
async def test_call_with_async_target(self, mock_evaluate, mock_simulate, safety_eval, mock_async_target):
    """Invoking the evaluation with an async target hands that target to ``_simulate``.

    Both ``_simulate`` and ``evaluate`` are patched, so this only checks how the
    target is passed along and the shape of the returned mapping.
    """
    # Canned outputs for the two patched collaborators.
    simulated_files = {"MockSimulator": "MockSimulator_Data.jsonl"}
    evaluation_output = {"metrics": {}, "rows": [], "studio_url": "test_url"}
    mock_simulate.return_value = simulated_files
    mock_evaluate.return_value = evaluation_output

    # Run the evaluation entry point with the coroutine-function target.
    result = await safety_eval(target=mock_async_target)

    # The result is a dict keyed by the simulator name.
    assert isinstance(result, dict)
    assert "MockSimulator" in result

    # ``_simulate`` ran exactly once and received the async target as a keyword argument.
    mock_simulate.assert_called_once()
    _, simulate_kwargs = mock_simulate.call_args
    assert simulate_kwargs["target"] == mock_async_target

0 commit comments

Comments
 (0)