Skip to content

Commit 628fab7

Browse files
committed
aten::select
1 parent 6cad83d commit 628fab7

File tree

1 file changed

+73
-0
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+73
-0
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,76 @@ def index_select(
390390
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
391391

392392
return gather_layer.get_output(0)
393+
394+
395+
def scatter_value(
396+
ctx: ConversionContext,
397+
target: Target,
398+
source_ir: Optional[SourceIR],
399+
name: str,
400+
input: TRTTensor,
401+
dim: Shape,
402+
index: Shape,
403+
value: TRTTensor,
404+
) -> TRTTensor:
405+
if not isinstance(input, TRTTensor):
406+
raise RuntimeError(
407+
f"scatter_tensor received input {input} that is not part "
408+
"of the TensorRT region!"
409+
)
410+
411+
ranks = len(input.shape)
412+
dim = get_positive_dim(cast(int, dim), ranks)
413+
dynamic_shape = has_dynamic_shape(input.shape)
414+
if dynamic_shape:
415+
# Check whether slice target dim is dynamic shape dim
416+
assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
417+
418+
input_dims = len(input.shape)
419+
for i in range(0, input_dims):
420+
if index[i] >= input.shape[i]:
421+
raise RuntimeError(
422+
f"cannot have index greater than the dimension length! {input.shape[dim]}"
423+
)
424+
value_tensor = value * torch.ones(index.shape)
425+
scatter_layer = ctx.net.add_scatter(input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT)
426+
scatter_layer.set_axis(dim)
427+
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
428+
out = scatter_layer.get_output(0)
429+
return out
430+
431+
432+
def scatter_src(
433+
ctx: ConversionContext,
434+
target: Target,
435+
source_ir: Optional[SourceIR],
436+
name: str,
437+
input: TRTTensor,
438+
dim: Shape,
439+
index: Shape,
440+
src: float,
441+
) -> TRTTensor:
442+
if not isinstance(input, TRTTensor):
443+
raise RuntimeError(
444+
f"scatter_tensor received input {input} that is not part "
445+
"of the TensorRT region!"
446+
)
447+
448+
ranks = len(input.shape)
449+
dim = get_positive_dim(cast(int, dim), ranks)
450+
dynamic_shape = has_dynamic_shape(input.shape)
451+
if dynamic_shape:
452+
# Check whether slice target dim is dynamic shape dim
453+
assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
454+
455+
input_dims = len(input.shape)
456+
for i in range(0, input_dims):
457+
if index[i] >= input.shape[i]:
458+
raise RuntimeError(
459+
f"cannot have index greater than the dimension length! {input.shape[dim]}"
460+
)
461+
scatter_layer = ctx.net.add_scatter(input, index, src, trt.tensorrt.ScatterModekELEMENT)
462+
scatter_layer.set_axis(dim)
463+
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
464+
out = scatter_layer.get_output(0)
465+
return out

0 commit comments

Comments
 (0)