Skip to content

Commit e4516a2

Browse files
committed
Exposing select layer
1 parent 78756c4 commit e4516a2

File tree

1 file changed

+12
-2
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/condition

1 file changed

+12
-2
lines changed

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,18 @@ def where(
9696
ctx, target, source_ir, f"{name}_y_expand", y_val, output_shape
9797
)
9898

99-
select_layer = ctx.net.add_select(condition_val, x_val, y_val)
99+
return select(ctx, target, source_ir, name, x_val, y_val, condition)
100100

101-
set_layer_name(select_layer, target, f"{name}_select")
102101

102+
def select(
103+
ctx: ConversionContext,
104+
target: Target,
105+
source_ir: Optional[SourceIR],
106+
name: str,
107+
input: Union[TRTTensor, np.ndarray, torch.Tensor],
108+
other: Union[TRTTensor, np.ndarray, torch.Tensor],
109+
condition: Union[TRTTensor, np.ndarray, torch.Tensor],
110+
) -> TRTTensor:
111+
select_layer = ctx.net.add_select(condition, input, other)
112+
set_layer_name(select_layer, target, name + "_select", source_ir)
103113
return select_layer.get_output(0)

0 commit comments

Comments
 (0)