We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a0f22cd commit 5e6f3bdCopy full SHA for 5e6f3bd
py/torch_tensorrt/_Input.py
@@ -364,9 +364,13 @@ def example_tensor(
364
)
365
366
if isinstance(self.shape, dict):
367
- return torch.rand(self.shape[optimization_profile_field]).to(
368
- dtype=self.dtype.to(torch.dtype, use_default=True)
369
- )
+ if self.dtype == dtype.b:
+ return torch.rand(self.shape[optimization_profile_field]) < 0.5
+ else:
370
+ return torch.rand(self.shape[optimization_profile_field]).to(
371
+ dtype=self.dtype.to(torch.dtype, use_default=True)
372
+ )
373
+
374
else:
375
raise RuntimeError(
376
f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
0 commit comments