Skip to content

Commit dbcc454

Browse files
authored
[MLIR][Vector] Allow Scalable Dim in OneDimMultiReductionToTwoDim (#89978)
To correctly lower multi_reduction of 1-dim scalable vector, e.g., <[4]xf32>
1 parent 45b59cb commit dbcc454

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,10 @@ struct OneDimMultiReductionToTwoDim
437437
auto loc = multiReductionOp.getLoc();
438438
auto srcVectorType = multiReductionOp.getSourceVectorType();
439439
auto srcShape = srcVectorType.getShape();
440-
auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
441-
srcVectorType.getElementType());
440+
auto castedType = VectorType::get(
441+
ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
442+
ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
443+
442444
auto accType =
443445
VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
444446
assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
@@ -455,10 +457,11 @@ struct OneDimMultiReductionToTwoDim
455457
loc, accType, multiReductionOp.getAcc());
456458
Value castMask;
457459
if (maskableOp.isMasked()) {
458-
auto maskType = llvm::cast<ShapedType>(mask.getType());
459-
auto castMaskType =
460-
VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
461-
maskType.getElementType());
460+
auto maskType = llvm::cast<VectorType>(mask.getType());
461+
auto castMaskType = VectorType::get(
462+
ArrayRef<int64_t>{1, maskType.getShape().back()},
463+
maskType.getElementType(),
464+
ArrayRef<bool>{false, maskType.getScalableDims().back()});
462465
castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
463466
}
464467

mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,23 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
281281
// CHECK: %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
282282
// CHECK: return %[[VAL_163]] : vector<8x[4]xf32>
283283

284+
// Check that OneDimMultiReductionToTwoDim handles scalable dim
285+
func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
286+
%0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
287+
return %0 : f32
288+
}
289+
290+
// CHECK-LABEL: func.func @scalable_dim_1d(
291+
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
292+
// CHECK-SAME: %[[ARG_1:.*]]: f32,
293+
// CHECK-SAME: %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
294+
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 0 : index
295+
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
296+
// CHECK: %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
297+
// CHECK: %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][%[[VAL_0]] : index] : vector<1xf32>
298+
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
299+
// CHECK: return %[[VAL_4]] : f32
300+
284301
module attributes {transform.with_named_sequence} {
285302
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
286303
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

0 commit comments

Comments
 (0)