Skip to content

Commit be6acbe

Browse files
committed
chore: Test update for int type dynamic shape input
1 parent 0f7977e commit be6acbe

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,25 @@ def example_tensor(
             )
 
         if isinstance(self.shape, dict):
-            return torch.rand(self.shape[optimization_profile_field]).to(
-                dtype=self.dtype.to(torch.dtype, use_default=True)
-            )
+            if (
+                self.dtype == dtype.u8
+                or self.dtype == dtype.i8
+                or self.dtype == dtype.i32
+                or self.dtype == dtype.i64
+            ):
+                type = self.dtype.to(torch.dtype, use_default=True)
+                min_value = torch.iinfo(type).min
+                max_value = torch.iinfo(type).max
+                return torch.randint(
+                    min_value,
+                    max_value,
+                    self.shape[optimization_profile_field],
+                    dtype=type,
+                )
+            else:
+                return torch.rand(self.shape[optimization_profile_field]).to(
+                    dtype=self.dtype.to(torch.dtype, use_default=True)
+                )
         else:
             raise RuntimeError(
                 f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"

tests/py/dynamo/conversion/harness.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def run_test_with_dynamic_shape(
         input_specs,
         rtol=1e-03,
         atol=1e-03,
-        output_dtypes=None,
+        check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
         use_example_tensors=True,
@@ -372,15 +372,25 @@ def run_test_with_dynamic_shape(
         # We replicate this behavior here
         compilation_settings = CompilationSettings(truncate_double=True)
 
+        output_dtypes = None
+        if check_dtype:
+            output_dtypes = infer_module_output_dtypes(
+                mod,
+                input_specs,
+                compilation_settings.device,
+                truncate_double=compilation_settings.truncate_double,
+            )
+
         interp = TRTInterpreter(
             mod,
             input_specs,
             output_dtypes=output_dtypes,
             compilation_settings=compilation_settings,
         )
+
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. max shape) for testing dynamic shape
         inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
         if not use_example_tensors:
             inputs_max = [spec.torch_tensor for spec in input_specs]
-        super().run_test(mod, inputs_max, interp, rtol, atol, pyt_inputs=pyt_inputs)
+        super().run_test(mod, inputs_max, interp, rtol, atol, check_dtype=check_dtype, pyt_inputs=pyt_inputs)

0 commit comments

Comments
 (0)