Skip to content

Commit a12754b

Browse files
committed
chore: Cast condition if type is not bool
1 parent 0814b4b commit a12754b

File tree

1 file changed

+8
-1
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/condition

1 file changed

+8
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
from typing import Optional, Union
22

33
import numpy as np
4+
import tensorrt as trt
45
import torch
56
from torch.fx.node import Target
67
from torch_tensorrt.dynamo._SourceIR import SourceIR
78
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
89
from torch_tensorrt.dynamo.conversion.converter_utils import (
910
broadcastable,
11+
cast_trt_tensor,
1012
get_trt_tensor,
1113
prepend_ones,
1214
set_layer_name,
1315
)
16+
from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
1417
from torch_tensorrt.fx.types import TRTTensor
1518

1619

@@ -32,8 +35,12 @@ def where(
3235
max_shape_len = max(len(x_shape), len(y_shape), len(condition_shape))
3336

3437
if not isinstance(condition, TRTTensor):
35-
assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
3638
condition = get_trt_tensor(ctx, condition, f"{name}_condition")
39+
40+
if condition.dtype != trt.bool:
41+
condition = cast_trt_tensor(ctx, condition, trt.float32, f"{name}_cast")
42+
condition = ne(ctx, target, source_ir, f"{name}_cond_zero", condition, 0)
43+
3744
diff = max_shape_len - len(condition_shape)
3845
if diff > 0:
3946
condition = prepend_ones(ctx, condition, f"{name}_condition_broadcast", diff)

0 commit comments

Comments
 (0)