Skip to content

Commit 8dfba80

Browse files
committed
add output_dtypes in test
1 parent 9a054ce commit 8dfba80

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/py/dynamo/converters/harness.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def run_test(
266266
precision=torch.float,
267267
check_dtype=True,
268268
disable_passes=False,
269+
output_dtypes=None,
269270
):
270271
mod.eval()
271272
mod = self.generate_graph(
@@ -284,6 +285,7 @@ def run_test(
284285
interp = TRTInterpreter(
285286
mod,
286287
Input.from_tensors(inputs),
288+
output_dtypes=output_dtypes,
287289
)
288290
super().run_test(
289291
mod,
@@ -306,6 +308,7 @@ def run_test_with_dynamic_shape(
306308
rtol=1e-03,
307309
atol=1e-03,
308310
disable_passes=False,
311+
output_dtypes=None,
309312
):
310313
mod.eval()
311314
inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
@@ -321,6 +324,7 @@ def run_test_with_dynamic_shape(
321324
interp = TRTInterpreter(
322325
mod,
323326
input_specs,
327+
output_dtypes=output_dtypes,
324328
)
325329
# Since the lowering is based on optimal shape. We need to test with
326330
# different shape(for ex. max shape) for testing dynamic shape

0 commit comments

Comments
 (0)