We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2f569c3 commit 8554782Copy full SHA for 8554782
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -392,7 +392,19 @@ def aten_ops_sigmoid(
392
)
393
394
395
-@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
+def index_dtype_validator(node: Node) -> bool:
396
+ index = node.args[1]
397
+ for ind in index:
398
+ if ind is not None:
399
+ val = ind.meta.get("val")
400
+ if val is not None and val.dtype != torch.int32:
401
+ return False
402
+ return True
403
+
404
405
+@dynamo_tensorrt_converter(
406
+ torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
407
+)
408
@enforce_tensor_types(
409
{
410
0: (TRTTensor,),
0 commit comments