Skip to content

Commit 5e73e8d

Browse files
committed
doc
1 parent b11db3c commit 5e73e8d

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

onnx_array_api/reference/evaluator_yield.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,22 @@ class YieldEvaluator:
139139
140140
:param onnx_model: model to run
141141
:param recursive: dig into subgraph and functions as well
142+
:param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator`
142143
"""
143144

144145
def __init__(
145146
self,
146147
onnx_model: ModelProto,
147148
recursive: bool = False,
148-
cls=ExtendedReferenceEvaluator,
149+
cls: Optional[type[ExtendedReferenceEvaluator]] = None,
149150
):
150151
assert not recursive, "recursive=True is not yet implemented"
151152
self.onnx_model = onnx_model
152-
self.evaluator = cls(onnx_model) if cls is not None else None
153+
self.evaluator = (
154+
cls(onnx_model)
155+
if cls is not None
156+
else ExtendedReferenceEvaluator(onnx_model)
157+
)
153158

154159
def enumerate_results(
155160
self,

0 commit comments

Comments
 (0)