@@ -117,19 +117,24 @@ def adaptive_avg_pool1d(
117
117
output_size : Union [int , Sequence [int ]],
118
118
) -> TRTTensor :
119
119
def start_index (idx : int , out_dim : int , in_dim : int ) -> int :
120
+ """Calculate the start index of each pooling window"""
120
121
return math .floor ((float (idx ) * float (in_dim )) / out_dim )
121
122
122
123
def end_index (idx : int , out_dim : int , in_dim : int ) -> int :
124
+ """Calculate the end index of each pooling window"""
123
125
return math .ceil ((float (idx + 1 ) * float (in_dim )) / out_dim )
124
126
125
127
in_dim = input .shape [- 1 ]
126
128
out_dim = output_size if isinstance (output_size , int ) else output_size [0 ]
127
129
output_list = []
128
130
131
+ # iterate over each output dimension
129
132
for i in range (out_dim ):
133
+ # calculate the start and end index of each pooling window
130
134
start = start_index (i , out_dim , in_dim )
131
135
end = end_index (i , out_dim , in_dim )
132
136
137
+ # slice the input tensor from start to end index, the result of which is the window waiting for pooling
133
138
slices = []
134
139
for j in range (start , end ):
135
140
slice = impl .select .select (
@@ -148,6 +153,7 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
148
153
slices = impl .cat .cat (
149
154
ctx , target , source_ir , f"{ name } _slices_cat_{ i } " , slices , dim = - 1
150
155
)
156
+ # calculate the mean of the slices (average pooling output) and append to the output list
151
157
output_list .append (
152
158
impl .reduce .mean (
153
159
ctx , target , source_ir , f"{ name } _sum_{ i } " , slices , dim = - 1 , keepdim = True
0 commit comments