Skip to content

Commit d20dad1

Browse files
authored
Port of bare minimum of batch engine code
A port of the bare minimum batch engine code from the Promptflow repo. The focus was on expediency rather than elegance, and some items are still stubs or don't do much currently. A follow-up PR will integrate this new code so that it can be invoked as part of an evaluate call. Things that still need to be done:
- Handle errors properly for OpenAI requests
- Handle "target" functions being passed
- Capture token usage from OpenAI requests (though this does not currently work)
- Handle logging in a cleaner way
- Determine what tracing (if any) is needed, and either implement that code or remove it entirely
- Determine whether we still need to save to a local folder structure, whether this should be changed to save to a single file (optionally), or whether that code should be removed outright
1 parent 3516c3c commit d20dad1

25 files changed

+1945
-19
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,15 @@
33
# ---------------------------------------------------------
44
from .eval_run_context import EvalRunContext
55
from .code_client import CodeClient
6-
from .proxy_client import ProxyClient
6+
from .proxy_client import ProxyClient, ProxyRun
7+
from ._run_submitter_client import RunSubmitterClient
78
from .target_run_context import TargetRunContext
8-
from .proxy_client import ProxyRun
99

10-
__all__ = ["CodeClient", "ProxyClient", "EvalRunContext", "TargetRunContext", "ProxyRun"]
10+
__all__ = [
11+
"CodeClient",
12+
"ProxyClient",
13+
"EvalRunContext",
14+
"TargetRunContext",
15+
"ProxyRun",
16+
"RunSubmitterClient",
17+
]
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
import logging
6+
import pandas as pd
7+
import sys
8+
from collections import defaultdict
9+
from concurrent.futures import Future, ThreadPoolExecutor
10+
from os import PathLike
11+
from typing import Any, Callable, Dict, Final, List, Mapping, Optional, Sequence, Union, cast
12+
13+
from .batch_clients import BatchClientRun, HasAsyncCallable
14+
from ..._legacy._batch_engine._run_submitter import RunSubmitter
15+
from ..._legacy._batch_engine._config import BatchEngineConfig
16+
from ..._legacy._batch_engine._run import Run
17+
18+
19+
LOGGER = logging.getLogger(__name__)
20+
21+
22+
class RunSubmitterClient:
    """Batch client that executes evaluator flows through the ported legacy
    ``RunSubmitter`` batch engine.

    Implements the ``BatchClient`` protocol (``run``, ``get_details``,
    ``get_metrics``, ``get_run_summary``). The handle returned by :meth:`run`
    is a :class:`concurrent.futures.Future` resolving to a legacy ``Run``;
    the getter methods block on that future via :meth:`_get_run`.
    """

    def __init__(self, config: Optional[BatchEngineConfig] = None) -> None:
        """Initialize the client.

        :param config: Optional batch engine configuration. When omitted, a
            default configuration with ``use_async=True`` and this module's
            logger is created.
        :type config: Optional[BatchEngineConfig]
        """
        self._config = config or BatchEngineConfig(LOGGER, use_async=True)
        # NOTE(review): this executor is never shut down explicitly; its worker
        # threads live until interpreter exit. Confirm whether that is intended.
        self._thread_pool = ThreadPoolExecutor(thread_name_prefix="evaluators_thread")

    def run(
        self,
        flow: Callable,
        data: Union[str, PathLike, pd.DataFrame],
        column_mapping: Optional[Dict[str, str]] = None,
        evaluator_name: Optional[str] = None,
        **kwargs: Any,
    ) -> BatchClientRun:
        """Submit ``flow`` for asynchronous execution over each row of ``data``.

        :param flow: The evaluator callable to run once per input row.
        :type flow: Callable
        :param data: The input data. The wider type hint is kept for protocol
            compatibility, but only a pandas DataFrame is accepted here.
        :type data: Union[str, PathLike, pd.DataFrame]
        :param column_mapping: Maps flow input names to data columns. Required.
        :type column_mapping: Optional[Dict[str, str]]
        :param evaluator_name: Used as the prefix for the generated run name.
        :type evaluator_name: Optional[str]
        :keyword kwargs: Forwarded to ``RunSubmitter.submit`` (``created_on``
            and ``storage_creator`` are extracted explicitly).
        :return: A future resolving to the completed run.
        :rtype: BatchClientRun
        :raises ValueError: If ``data`` is not a DataFrame, or if
            ``column_mapping`` is missing or empty.
        """
        if not isinstance(data, pd.DataFrame):
            # Should never get here
            raise ValueError("Data must be a pandas DataFrame")
        if not column_mapping:
            raise ValueError("Column mapping must be provided")

        # The column mappings are indexed by "data" to indicate they come from
        # the data input. Wrap each row in a dict under a "data" key so mapping
        # resolution can address the original input values.
        inputs = [{"data": input_data} for input_data in data.to_dict(orient="records")]

        # The engine always runs async behind the scenes; unwrap the
        # evaluator's async implementation when it exposes one.
        if isinstance(flow, HasAsyncCallable):
            flow = flow._to_async()  # pylint: disable=protected-access

        run_submitter = RunSubmitter(self._config)
        run_future = self._thread_pool.submit(
            run_submitter.submit,
            dynamic_callable=flow,
            inputs=inputs,
            column_mapping=column_mapping,
            name_prefix=evaluator_name,
            created_on=kwargs.pop("created_on", None),
            storage_creator=kwargs.pop("storage_creator", None),
            **kwargs,
        )

        return run_future

    def get_details(self, client_run: BatchClientRun, all_results: bool = False) -> pd.DataFrame:
        """Return the run's inputs and outputs as a flat DataFrame.

        Columns are named ``inputs.<name>`` and ``outputs.<name>``. Unless
        ``all_results`` is set, at most ``config.default_num_results`` rows
        are taken from each of the inputs and outputs sequences.

        :param client_run: The run handle returned by :meth:`run`.
        :type client_run: BatchClientRun
        :param all_results: When True, include every row instead of the
            configured default cap.
        :type all_results: bool
        :return: The flattened details of the run.
        :rtype: pd.DataFrame
        """
        run = self._get_run(client_run)

        data: Dict[str, List[Any]] = defaultdict(list)
        stop_at: Final[int] = sys.maxsize if all_results else self._config.default_num_results

        def _update(prefix: str, items: Sequence[Mapping[str, Any]]) -> None:
            # Flatten each line's mapping into prefixed columns, truncating at stop_at.
            for i, line in enumerate(items):
                if i >= stop_at:
                    break
                for k, value in line.items():
                    data[f"{prefix}.{k}"].append(value)

        _update("inputs", run.inputs)
        _update("outputs", run.outputs)

        # dict insertion order is preserved by the DataFrame constructor, so the
        # original's reindex over the same keys was a no-op and has been dropped.
        return pd.DataFrame(data)

    def get_metrics(self, client_run: BatchClientRun) -> Dict[str, Any]:
        """Return the metrics of the run as a plain dictionary.

        :param client_run: The run handle returned by :meth:`run`.
        :type client_run: BatchClientRun
        :return: The metrics of the run.
        :rtype: Dict[str, Any]
        """
        run = self._get_run(client_run)
        return dict(run.metrics)

    def get_run_summary(self, client_run: BatchClientRun) -> Dict[str, Any]:
        """Return a summary (status, duration, line counts) for the run.

        :param client_run: The run handle returned by :meth:`run`.
        :type client_run: BatchClientRun
        :return: The summary of the run.
        :rtype: Dict[str, Any]
        """
        run = self._get_run(client_run)

        # run.result may be None if the run never produced a result.
        total_lines = run.result.total_lines if run.result else 0
        failed_lines = run.result.failed_lines if run.result else 0

        return {
            "status": run.status.value,
            "duration": str(run.duration),
            "completed_lines": total_lines - failed_lines,
            "failed_lines": failed_lines,
            # "log_path": "",
        }

    @staticmethod
    def _get_run(run: BatchClientRun) -> Run:
        # The handle handed out by run() is really a Future[Run]; block until done.
        return cast(Future[Run], run).result()
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
import pandas
6+
from os import PathLike
7+
from typing import Any, Awaitable, Callable, Dict, Optional, Protocol, Union, runtime_checkable
8+
9+
10+
class BatchClientRun(Protocol):
    """Marker protocol for the opaque run handle returned by a batch client."""
14+
15+
16+
@runtime_checkable
class HasAsyncCallable(Protocol):
    """Structural type for objects that can expose an awaitable entry point.

    Because this is a runtime-checkable protocol, ``isinstance`` checks only
    verify the presence of a ``_to_async`` attribute, not its signature.
    """

    def _to_async(self) -> Callable[[Any, Any], Awaitable[Any]]:
        """Return an async-callable version of this object."""
        ...
21+
22+
23+
class BatchClient(Protocol):
    """The protocol for the batch client. This allows for running a flow on a data source
    and getting the details of the run."""

    def run(
        self,
        flow: Callable,
        data: Union[str, PathLike, pandas.DataFrame],
        column_mapping: Optional[Dict[str, str]] = None,
        evaluator_name: Optional[str] = None,
        **kwargs: Any,
    ) -> BatchClientRun:
        """Run the given flow on the data with the given column mapping.

        :param flow: The flow to run.
        :type flow: Callable
        :param data: The JSONL file containing the data to run the flow on,
            or the already loaded data.
        :type data: Union[str, PathLike, pandas.DataFrame]
        :param column_mapping: The column mapping to use.
        :type column_mapping: Optional[Dict[str, str]]
        :param evaluator_name: The name to use for the run.
        :type evaluator_name: Optional[str]
        :keyword kwargs: Additional keyword arguments to pass to the flow.
        :return: The result of the batch client run.
        :rtype: BatchClientRun
        """
        ...

    def get_details(self, client_run: BatchClientRun, all_results: bool = False) -> pandas.DataFrame:
        """Get the details of the run.

        :param client_run: The run to get the details of.
        :type client_run: BatchClientRun
        :param all_results: Whether to get all results.
        :type all_results: bool
        :return: The details of the run.
        :rtype: pandas.DataFrame
        """
        ...

    def get_metrics(self, client_run: BatchClientRun) -> Dict[str, Any]:
        """Get the metrics of the run.

        :param client_run: The run to get the metrics of.
        :type client_run: BatchClientRun
        :return: The metrics of the run.
        :rtype: Dict[str, Any]
        """
        ...

    def get_run_summary(self, client_run: BatchClientRun) -> Dict[str, Any]:
        """Get the summary of the run.

        :param client_run: The run to get the summary of.
        :type client_run: BatchClientRun
        :return: The summary of the run.
        :rtype: Dict[str, Any]
        """
        ...
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
# NOTE: This is a direct port of the bare minimum needed for BatchEngine functionality from
#       the original Promptflow code. The goal here is expediency, not elegance. As such,
#       parts of this code may be a little "quirky", seem incomplete in places, or contain
#       longer TODO comments than usual. In a future code update, large swaths of this code
#       will be refactored or deleted outright.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
from dataclasses import dataclass
6+
from logging import Logger
7+
8+
from ..._constants import PF_BATCH_TIMEOUT_SEC_DEFAULT
9+
10+
11+
@dataclass
class BatchEngineConfig:
    """Context for a batch of evaluations. This will contain the configuration,
    logging, and other needed information."""

    logger: Logger
    """The logger to use for logging messages."""

    batch_timeout_seconds: int = PF_BATCH_TIMEOUT_SEC_DEFAULT
    """The maximum amount of time to wait for all evaluations in the batch to complete."""

    run_timeout_seconds: int = 600
    """The maximum amount of time to wait for an evaluation to run against a single entry
    in the data input to complete."""

    max_concurrency: int = 10
    """The maximum number of evaluations to run concurrently."""

    use_async: bool = True
    """Whether to use asynchronous evaluation."""

    default_num_results: int = 100
    """The default number of results to return if you don't ask for all results."""

    def __post_init__(self):
        # Fail fast on configuration values that could never yield a useful run.
        if self.logger is None:
            raise ValueError("logger cannot be None")
        for field_name in (
            "batch_timeout_seconds",
            "run_timeout_seconds",
            "max_concurrency",
            "default_num_results",
        ):
            # Each of these settings must be strictly positive.
            if getattr(self, field_name) <= 0:
                raise ValueError(f"{field_name} must be greater than 0")

0 commit comments

Comments
 (0)