Skip to content

Commit fc6499e

Browse files
committed
chore: Better random bool values for example_tensor
1 parent 00d1d85 commit fc6499e

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,13 @@ def example_tensor(
364364
)
365365

366366
if isinstance(self.shape, dict):
367-
return torch.rand(self.shape[optimization_profile_field]).to(
368-
dtype=self.dtype.to(torch.dtype, use_default=True)
369-
)
367+
if self.dtype == dtype.b:
368+
return torch.rand(self.shape[optimization_profile_field]) < 0.5
369+
else:
370+
return torch.rand(self.shape[optimization_profile_field]).to(
371+
dtype=self.dtype.to(torch.dtype, use_default=True)
372+
)
373+
370374
else:
371375
raise RuntimeError(
372376
f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"

0 commit comments

Comments
 (0)