Skip to content

Commit 436e99b

Browse files
committed
chore: Update for random input
1 parent 7bbb07a commit 436e99b

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,20 @@ def example_tensor(
352352
)
353353
else:
354354
if isinstance(self.shape, tuple):
355-
return torch.rand(self.shape).to(
356-
dtype=self.dtype.to(torch.dtype, use_default=True)
357-
)
355+
if self.dtype in [dtype.u8, dtype.i8, dtype.i32, dtype.i64]:
356+
type = self.dtype.to(torch.dtype, use_default=True)
357+
return torch.randint(
358+
torch.iinfo(type).min,
359+
torch.iinfo(type).max,
360+
self.shape,
361+
dtype=type,
362+
)
363+
elif self.dtype == dtype.b:
364+
return torch.rand(self.shape) < 0.5
365+
else:
366+
return torch.rand(self.shape).to(
367+
dtype=self.dtype.to(torch.dtype, use_default=True)
368+
)
358369
else:
359370
RuntimeError(
360371
f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})"
@@ -380,6 +391,8 @@ def example_tensor(
380391
self.shape[optimization_profile_field],
381392
dtype=type,
382393
)
394+
elif self.dtype == dtype.b:
395+
return torch.rand(self.shape[optimization_profile_field]) < 0.5
383396
else:
384397
return torch.rand(self.shape[optimization_profile_field]).to(
385398
dtype=self.dtype.to(torch.dtype, use_default=True)

tests/py/dynamo/conversion/harness.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def run_test_compare_tensor_attributes_only(
319319
expected_ops,
320320
comparators: List[Tuple[Callable, List]],
321321
precision=torch.float,
322-
output_dtypes=None,
322+
check_dtype=True,
323323
use_dynamo_tracer=False,
324324
enable_passes=False,
325325
):
@@ -337,6 +337,15 @@ def run_test_compare_tensor_attributes_only(
337337
debug=True,
338338
)
339339

340+
output_dtypes = None
341+
if check_dtype:
342+
output_dtypes = infer_module_output_dtypes(
343+
mod,
344+
[Input.from_tensor(i) for i in inputs],
345+
compilation_settings.device,
346+
truncate_double=compilation_settings.truncate_double,
347+
)
348+
340349
interp = TRTInterpreter(
341350
mod,
342351
Input.from_tensors(inputs),
@@ -393,4 +402,12 @@ def run_test_with_dynamic_shape(
393402
inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
394403
if not use_example_tensors:
395404
inputs_max = [spec.torch_tensor for spec in input_specs]
396-
super().run_test(mod, inputs_max, interp, rtol, atol, check_dtype=check_dtype, pyt_inputs=pyt_inputs)
405+
super().run_test(
406+
mod,
407+
inputs_max,
408+
interp,
409+
rtol,
410+
atol,
411+
check_dtype=check_dtype,
412+
pyt_inputs=pyt_inputs,
413+
)

0 commit comments

Comments
 (0)