Skip to content

Commit 89d4a98

Browse files
committed
chore: Test update for int type dynamic shape input
1 parent 610057c commit 89d4a98

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,25 @@ def example_tensor(
364364
)
365365

366366
if isinstance(self.shape, dict):
367-
return torch.rand(self.shape[optimization_profile_field]).to(
368-
dtype=self.dtype.to(torch.dtype, use_default=True)
369-
)
367+
if (
368+
self.dtype == dtype.u8
369+
or self.dtype == dtype.i8
370+
or self.dtype == dtype.i32
371+
or self.dtype == dtype.i64
372+
):
373+
type = self.dtype.to(torch.dtype, use_default=True)
374+
min_value = torch.iinfo(type).min
375+
max_value = torch.iinfo(type).max
376+
return torch.randint(
377+
min_value,
378+
max_value,
379+
self.shape[optimization_profile_field],
380+
dtype=type,
381+
)
382+
else:
383+
return torch.rand(self.shape[optimization_profile_field]).to(
384+
dtype=self.dtype.to(torch.dtype, use_default=True)
385+
)
370386
else:
371387
raise RuntimeError(
372388
f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"

tests/py/dynamo/conversion/harness.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def run_test_with_dynamic_shape(
338338
input_specs,
339339
rtol=1e-03,
340340
atol=1e-03,
341-
output_dtypes=None,
341+
check_dtype=True,
342342
use_dynamo_tracer=False,
343343
enable_passes=False,
344344
):
@@ -355,13 +355,23 @@ def run_test_with_dynamic_shape(
355355
# We replicate this behavior here
356356
compilation_settings = CompilationSettings(truncate_double=True)
357357

358+
output_dtypes = None
359+
if check_dtype:
360+
output_dtypes = infer_module_output_dtypes(
361+
mod,
362+
input_specs,
363+
compilation_settings.device,
364+
truncate_double=compilation_settings.truncate_double,
365+
)
366+
358367
interp = TRTInterpreter(
359368
mod,
360369
input_specs,
361370
output_dtypes=output_dtypes,
362371
compilation_settings=compilation_settings,
363372
)
373+
364374
# Since the lowering is based on optimal shape. We need to test with
365375
# different shape(for ex. max shape) for testing dynamic shape
366376
inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
367-
super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
377+
super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol, check_dtype)

0 commit comments

Comments
 (0)