6
6
import os
7
7
import inspect
8
8
import logging
9
+ import asyncio
9
10
from datetime import datetime
10
11
from azure .ai .evaluation ._common ._experimental import experimental
11
- from typing import Any , Callable , Dict , List , Optional , Union , cast
12
+ from typing import Any , Callable , Dict , List , Optional , Union , cast , Coroutine , TypeVar , Awaitable
12
13
from azure .ai .evaluation ._common .math import list_mean_nan_safe
13
14
from azure .ai .evaluation ._constants import CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT
14
15
from azure .ai .evaluation ._evaluators import (
@@ -192,10 +193,17 @@ async def callback(
192
193
context = latest_message .get ("context" , None )
193
194
latest_context = None
194
195
try :
196
+ is_async = self ._is_async_function (target )
195
197
if self ._check_target_returns_context (target ):
196
- response , latest_context = target (query = application_input )
198
+ if is_async :
199
+ response , latest_context = await target (query = application_input )
200
+ else :
201
+ response , latest_context = target (query = application_input )
197
202
else :
198
- response = target (query = application_input )
203
+ if is_async :
204
+ response = await target (query = application_input )
205
+ else :
206
+ response = target (query = application_input )
199
207
except Exception as e :
200
208
response = f"Something went wrong { e !s} "
201
209
@@ -465,7 +473,7 @@ def _get_evaluators(
465
473
blame = ErrorBlame .USER_ERROR ,
466
474
)
467
475
return evaluators_dict
468
-
476
+
469
477
@staticmethod
470
478
def _check_target_returns_context (target : Callable ) -> bool :
471
479
"""
@@ -478,6 +486,15 @@ def _check_target_returns_context(target: Callable) -> bool:
478
486
ret_type = sig .return_annotation
479
487
if ret_type == inspect .Signature .empty :
480
488
return False
489
+
490
+ # Check for Coroutine/Awaitable return types for async functions
491
+ origin = getattr (ret_type , "__origin__" , None )
492
+ if origin is not None and (origin is Coroutine or origin is Awaitable ):
493
+ args = getattr (ret_type , "__args__" , None )
494
+ if args and len (args ) > 0 :
495
+ # For async functions, check the actual return type inside the Coroutine
496
+ ret_type = args [- 1 ]
497
+
481
498
if ret_type is tuple :
482
499
return True
483
500
return False
@@ -494,13 +511,33 @@ def _check_target_returns_str(target: Callable) -> bool:
494
511
ret_type = sig .return_annotation
495
512
if ret_type == inspect .Signature .empty :
496
513
return False
514
+
515
+ # Check for Coroutine/Awaitable return types for async functions
516
+ origin = getattr (ret_type , "__origin__" , None )
517
+ if origin is not None and (origin is Coroutine or origin is Awaitable ):
518
+ args = getattr (ret_type , "__args__" , None )
519
+ if args and len (args ) > 0 :
520
+ # For async functions, check the actual return type inside the Coroutine
521
+ ret_type = args [- 1 ]
522
+
497
523
if ret_type is str :
498
524
return True
499
525
return False
500
526
501
-
502
527
@staticmethod
503
- def _check_target_is_callback (target :Callable ) -> bool :
528
+ def _is_async_function (target : Callable ) -> bool :
529
+ """
530
+ Checks if the target function is an async function.
531
+
532
+ :param target: The target function to check.
533
+ :type target: Callable
534
+ :return: True if the target function is async, False otherwise.
535
+ :rtype: bool
536
+ """
537
+ return asyncio .iscoroutinefunction (target )
538
+
539
+ @staticmethod
540
+ def _check_target_is_callback (target : Callable ) -> bool :
504
541
sig = inspect .signature (target )
505
542
param_names = list (sig .parameters .keys ())
506
543
return 'messages' in param_names and 'stream' in param_names and 'session_state' in param_names and 'context' in param_names
@@ -630,7 +667,7 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult:
630
667
631
668
async def __call__ (
632
669
self ,
633
- target : Union [Callable , AzureOpenAIModelConfiguration , OpenAIModelConfiguration ],
670
+ target : Union [Callable , Awaitable [ Any ], AzureOpenAIModelConfiguration , OpenAIModelConfiguration ],
634
671
evaluators : List [_SafetyEvaluator ] = [],
635
672
evaluation_name : Optional [str ] = None ,
636
673
num_turns : int = 1 ,
@@ -644,12 +681,12 @@ async def __call__(
644
681
jailbreak_data_path : Optional [Union [str , os .PathLike ]] = None ,
645
682
output_path : Optional [Union [str , os .PathLike ]] = None ,
646
683
data_paths : Optional [Union [Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]] = None
647
- ) -> Union [Dict [str , EvaluationResult ], Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]:
684
+ ) -> Union [Dict [str , EvaluationResult ], Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]:
648
685
'''
649
686
Evaluates the target function based on the provided parameters.
650
687
651
- :param target: The target function to call during the evaluation.
652
- :type target: Callable
688
+ :param target: The target function to call during the evaluation. This can be a synchronous or asynchronous function.
689
+ :type target: Union[ Callable, Awaitable[Any], AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
653
690
:param evaluators: A list of SafetyEvaluator.
654
691
:type evaluators: List[_SafetyEvaluator]
655
692
:param evaluation_name: The display name name of the evaluation.
0 commit comments