Commit 370bab5

[MLIR][Linalg] Scalable Vectorization of Reduction

Allow scalable vectorization of linalg::reduce and linalg::generic ops with a reduction iterator. For now, only reductions on the trailing dimension are supported.
1 parent: b298e2d
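
As a quick illustration (a minimal sketch that mirrors the tests added below; see them for the full pipelines and expected output), scalable vectorization of a 1-D linalg.reduce is requested through the transform dialect, with [[4]] marking the single (trailing) reduction dim as scalable:

// Sketch mirroring the tests below: match the reduction and request a
// scalable vector size for its trailing (and only) dimension.
%0 = transform.structured.match ops{["linalg.reduce"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op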

3 files changed: +131 −1 lines changed


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 27 additions & 1 deletion
@@ -582,6 +582,12 @@ static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
 }
 
+static bool isLinalgReduction(LinalgOp &op) {
+  return isa<linalg::ReduceOp>(op) ||
+         (isa<linalg::GenericOp>(op) &&
+          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
+}
+
 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
 /// currently being vectorized. If `dest` has null rank, build an memref.store.
@@ -1773,6 +1779,9 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
   if (isa<ConvolutionOpInterface>(op.getOperation()))
     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
+  if (isLinalgReduction(op))
+    return reductionPreconditions(op);
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -1942,13 +1951,30 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (inputVectorSizes.empty())
     return success();
 
+  auto linalgOp = dyn_cast<LinalgOp>(op);
+  if (linalgOp && isLinalgReduction(linalgOp)) {
+    LDBG("Checking reduce op dims for scalable vectorization\n");
+    auto iteratorTypes = linalgOp.getIteratorTypesArray();
+    assert(iteratorTypes.size() == inputScalableVecDims.size() &&
+           "Number of iterator types and input scalable dims mismatch");
+    // For now, only support scalable vectorization of a reduction on the
+    // trailing dim.
+    for (size_t i = 0; i < inputScalableVecDims.size() - 1; ++i) {
+      if (inputScalableVecDims[i] && isReductionIterator(iteratorTypes[i])) {
+        LDBG("Non-trailing reduction dim requested for scalable "
+             "vectorization\n");
+        return failure();
+      }
+    }
+    return success();
+  }
+
   bool isScalable = inputScalableVecDims.back();
   if (!isScalable)
     return success();
 
   // Only element-wise and 1d depthwise conv ops supported in the presence of
   // scalable dims.
-  auto linalgOp = dyn_cast<LinalgOp>(op);
   return success(linalgOp && (isElementwise(linalgOp) ||
                               isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
 }
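
To make the new scalable-dims precondition concrete, here is a hypothetical pair of requests (a sketch, not taken from this commit) for a 2-D op: marking a non-trailing reduction dim as scalable is rejected, while the trailing-dim form exercised by the tests below is accepted.

// Hypothetical sketch: with iterator_types = ["reduction", "parallel"],
// a scalable leading (reduction) dim trips the non-trailing-dim check:
//   transform.structured.vectorize %op vector_sizes [[4], 1] : !transform.any_op  // fails
// With iterator_types = ["parallel", "reduction"], a scalable trailing
// reduction dim passes the precondition (see the 2-D test below):
//   transform.structured.vectorize %op vector_sizes [1, [4]] : !transform.any_op  // succeeds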

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 80 additions & 0 deletions
@@ -142,3 +142,83 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+func.func @vectorize_dynamic_reduction_1d(%arg0: tensor<?xf32>,
+                                          %arg1: tensor<f32>) -> tensor<f32> {
+
+  %0 = linalg.reduce ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) dimensions = [0]
+  (%in: f32, %init: f32) {
+    %0 = arith.addf %in, %init : f32
+    linalg.yield %0 : f32
+  }
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_reduction_1d(
+// CHECK-SAME:    %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:         %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_1:.*]] = tensor.dim %[[ARG_0]], %[[VAL_0]] : tensor<?xf32>
+// CHECK:         %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[VAL_4:.*]] = vector.create_mask %[[VAL_1]] : vector<[4]xi1>
+// CHECK:         %[[VAL_5:.*]] = vector.mask %[[VAL_4]] { vector.transfer_read %[[ARG_0]][%[[VAL_2]]], %[[VAL_3]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK:         %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[VAL_7:.*]] = vector.transfer_read %[[ARG_1]][], %[[VAL_6]] : tensor<f32>, vector<f32>
+// CHECK:         %[[VAL_8:.*]] = vector.extractelement %[[VAL_7]][] : vector<f32>
+// CHECK:         %[[VAL_9:.*]] = vector.mask %[[VAL_4]] { vector.multi_reduction <add>, %[[VAL_5]], %[[VAL_8]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+// CHECK:         %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : f32 to vector<f32>
+// CHECK:         %[[VAL_11:.*]] = vector.transfer_write %[[VAL_10]], %[[ARG_1]][] : vector<f32>, tensor<f32>
+// CHECK:         return %[[VAL_11]] : tensor<f32>
+// CHECK:       }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @vectorize_dynamic_reduction_2d(%arg0: tensor<?x?xf32>,
+                                          %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?xf32>) {
+  ^bb(%in: f32, %out: f32) :
+    %0 = arith.addf %in, %out : f32
+    linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_reduction_2d(
+// CHECK-SAME:    %[[ARG_0:.*]]: tensor<?x?xf32>, %[[ARG_1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:         %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_1:.*]] = tensor.dim %[[ARG_0]], %[[VAL_0]] : tensor<?x?xf32>
+// CHECK:         %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_3:.*]] = tensor.dim %[[ARG_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK:         %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[VAL_6:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_3]] : vector<1x[4]xi1>
+// CHECK:         %[[VAL_7:.*]] = vector.mask %[[VAL_6]] { vector.transfer_read %[[ARG_0]][%[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK:         %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[VAL_9:.*]] = vector.create_mask %[[VAL_1]] : vector<1xi1>
+// CHECK:         %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %[[ARG_1]][%[[VAL_4]]], %[[VAL_8]] {in_bounds = [true]} : tensor<?xf32>, vector<1xf32> } : vector<1xi1> -> vector<1xf32>
+// CHECK:         %[[VAL_11:.*]] = vector.mask %[[VAL_6]] { vector.multi_reduction <add>, %[[VAL_7]], %[[VAL_10]] [1] : vector<1x[4]xf32> to vector<1xf32> } : vector<1x[4]xi1> -> vector<1xf32>
+// CHECK:         %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_13:.*]] = vector.mask %[[VAL_9]] { vector.transfer_write %[[VAL_11]], %[[ARG_1]][%[[VAL_12]]] {in_bounds = [true]} : vector<1xf32>, tensor<?xf32> } : vector<1xi1> -> tensor<?xf32>
+// CHECK:         return %[[VAL_13]] : tensor<?xf32>
+// CHECK:       }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] : !transform.any_op
+    transform.yield
+  }
+}
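
A note on reading the scalable types in these checks: vector<[4]xf32> holds 4 × vscale elements, where vscale is fixed by the hardware at runtime, and the generated mask trims the vector to the tensor's dynamic size.

// Informal illustration (assumes a runtime where vscale = 2, so
// vector<[4]xf32> has 8 lanes): for a dynamic dim of size 5,
//   %mask = vector.create_mask %c5 : vector<[4]xi1>
// enables only the first 5 lanes, so the masked transfer_read and
// multi_reduction never read past the end of the tensor.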

mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Lines changed: 24 additions & 0 deletions
@@ -298,6 +298,30 @@ func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) ->
 // CHECK:         %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
 // CHECK:         return %[[VAL_4]] : f32
 
+func.func @scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
+  %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @scalable_dim_2d(
+// CHECK-SAME:    %[[ARG_0:.*]]: vector<2x[4]xf32>,
+// CHECK-SAME:    %[[ARG_1:.*]]: vector<2xf32>,
+// CHECK-SAME:    %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
+// CHECK-DAG:     %[[CON_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[CON_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[CON_2:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:         %[[VAL_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:         %[[VAL_1:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
+// CHECK:         %[[VAL_2:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
+// CHECK:         %[[VAL_3:.*]] = vector.mask %[[VAL_2]] { vector.reduction <add>, %[[VAL_0]], %[[VAL_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:         %[[VAL_4:.*]] = vector.insertelement %[[VAL_3]], %[[CON_2]][%[[CON_1]] : index] : vector<2xf32>
+// CHECK:         %[[VAL_5:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:         %[[VAL_6:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
+// CHECK:         %[[VAL_7:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
+// CHECK:         %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.reduction <add>, %[[VAL_5]], %[[VAL_6]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:         %[[VAL_9:.*]] = vector.insertelement %[[VAL_8]], %[[VAL_4]][%[[CON_0]] : index] : vector<2xf32>
+// CHECK:         return %[[VAL_9]] : vector<2xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">