Skip to content

Commit 495c981

Browse files
committed
add comments for readability
1 parent 3eff8e6 commit 495c981

File tree

1 file changed

+6
-0
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+6
-0
lines changed

py/torch_tensorrt/dynamo/conversion/impl/pool.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,24 @@ def adaptive_avg_pool1d(
117117
output_size: Union[int, Sequence[int]],
118118
) -> TRTTensor:
119119
def start_index(idx: int, out_dim: int, in_dim: int) -> int:
120+
"""Calculate the start index of each pooling window"""
120121
return math.floor((float(idx) * float(in_dim)) / out_dim)
121122

122123
def end_index(idx: int, out_dim: int, in_dim: int) -> int:
124+
"""Calculate the end index of each pooling window"""
123125
return math.ceil((float(idx + 1) * float(in_dim)) / out_dim)
124126

125127
in_dim = input.shape[-1]
126128
out_dim = output_size if isinstance(output_size, int) else output_size[0]
127129
output_list = []
128130

131+
# iterate over each output dimension
129132
for i in range(out_dim):
133+
# calculate the start and end index of each pooling window
130134
start = start_index(i, out_dim, in_dim)
131135
end = end_index(i, out_dim, in_dim)
132136

137+
# slice the input tensor from start to end index, the result of which is the window waiting for pooling
133138
slices = []
134139
for j in range(start, end):
135140
slice = impl.select.select(
@@ -148,6 +153,7 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
148153
slices = impl.cat.cat(
149154
ctx, target, source_ir, f"{name}_slices_cat_{i}", slices, dim=-1
150155
)
156+
# calculate the mean of the slices (average pooling output) and append to the output list
151157
output_list.append(
152158
impl.reduce.mean(
153159
ctx, target, source_ir, f"{name}_sum_{i}", slices, dim=-1, keepdim=True

0 commit comments

Comments
 (0)