Skip to content

Commit ac82540

Browse files
committed
aten::select
1 parent 4b608f0 commit ac82540

File tree

1 file changed

+73
-0
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+73
-0
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,76 @@ def index(
370370
reshape_output = reshape_layer.get_output(0)
371371

372372
return reshape_output
373+
374+
375+
def scatter_value(
376+
ctx: ConversionContext,
377+
target: Target,
378+
source_ir: Optional[SourceIR],
379+
name: str,
380+
input: TRTTensor,
381+
dim: Shape,
382+
index: Shape,
383+
value: TRTTensor,
384+
) -> TRTTensor:
385+
if not isinstance(input, TRTTensor):
386+
raise RuntimeError(
387+
f"scatter_tensor received input {input} that is not part "
388+
"of the TensorRT region!"
389+
)
390+
391+
ranks = len(input.shape)
392+
dim = get_positive_dim(cast(int, dim), ranks)
393+
dynamic_shape = has_dynamic_shape(input.shape)
394+
if dynamic_shape:
395+
# Check whether slice target dim is dynamic shape dim
396+
assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
397+
398+
input_dims = len(input.shape)
399+
for i in range(0, input_dims):
400+
if index[i] >= input.shape[i]:
401+
raise RuntimeError(
402+
f"cannot have index greater than the dimension length! {input.shape[dim]}"
403+
)
404+
value_tensor = value * torch.ones(index.shape)
405+
scatter_layer = ctx.net.add_scatter(input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT)
406+
scatter_layer.set_axis(dim)
407+
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
408+
out = scatter_layer.get_output(0)
409+
return out
410+
411+
412+
def scatter_src(
413+
ctx: ConversionContext,
414+
target: Target,
415+
source_ir: Optional[SourceIR],
416+
name: str,
417+
input: TRTTensor,
418+
dim: Shape,
419+
index: Shape,
420+
src: float,
421+
) -> TRTTensor:
422+
if not isinstance(input, TRTTensor):
423+
raise RuntimeError(
424+
f"scatter_tensor received input {input} that is not part "
425+
"of the TensorRT region!"
426+
)
427+
428+
ranks = len(input.shape)
429+
dim = get_positive_dim(cast(int, dim), ranks)
430+
dynamic_shape = has_dynamic_shape(input.shape)
431+
if dynamic_shape:
432+
# Check whether slice target dim is dynamic shape dim
433+
assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
434+
435+
input_dims = len(input.shape)
436+
for i in range(0, input_dims):
437+
if index[i] >= input.shape[i]:
438+
raise RuntimeError(
439+
f"cannot have index greater than the dimension length! {input.shape[dim]}"
440+
)
441+
scatter_layer = ctx.net.add_scatter(input, index, src, trt.tensorrt.ScatterModekELEMENT)
442+
scatter_layer.set_axis(dim)
443+
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
444+
out = scatter_layer.get_output(0)
445+
return out

0 commit comments

Comments
 (0)