
Commit abcbbe7

[MLIR][VectorToLLVM] Handle scalable dim in createVectorLengthValue() (llvm#93361)
LLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics. For a scalable vector type, this should be calculated as VectorScaleOp multiplied by the base vector length, e.g. for <[4]xf32> we should return vscale * 4.
1 parent 19b43e1 commit abcbbe7
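
As a rough illustration of the example in the commit message (not copied from the patch; SSA names are invented and surrounding ops are omitted), the effective vector length for a vector<[4]xf32> operand is built from the ops this patch emits, before they are themselves lowered to the LLVM dialect:

// Hedged sketch: effective vector length for vector<[4]xf32> (names invented).
%base   = llvm.mlir.constant(4 : i32) : i32   // base vector length
%vscale = vector.vscale                       // runtime scale factor (index)
%vs_i32 = arith.index_cast %vscale : index to i32
%vl     = arith.muli %base, %vs_i32 : i32     // vl = vscale * 4
// %vl is then passed as the explicit vector-length operand of a VP intrinsic,
// e.g. "llvm.intr.vp.reduce.fadd"(%acc, %input, %mask, %vl).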

2 files changed: +123 −2 lines

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 13 additions & 2 deletions
@@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                     llvmType);
 }
 
-/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
+/// Creates a value with the 1-D vector shape provided in `llvmType`.
 /// This is used as effective vector length by some intrinsics supporting
 /// dynamic vector lengths at runtime.
 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
@@ -532,9 +532,20 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
   auto vShape = vType.getShape();
   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
 
-  return rewriter.create<LLVM::ConstantOp>(
+  Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
       loc, rewriter.getI32Type(),
       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
+
+  if (!vType.getScalableDims()[0])
+    return baseVecLength;
+
+  // For a scalable vector type, create and return `vScale * baseVecLength`.
+  Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
+  vScale =
+      rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
+  Value scalableVecLength =
+      rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
+  return scalableVecLength;
 }
 
 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
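
A note for reading the tests below: the helper emits vector.vscale, arith.index_cast, and arith.muli, but in the FileCheck'd conversion output vector.vscale shows up in its converted form (llvm.intr.vscale plus an unrealized_conversion_cast back to index). So the scalable test cases expect roughly this sequence (a sketch with invented SSA names, mirroring the CHECK lines):

// Hedged sketch of the converted vector-length computation for vector<[16]xf32>.
%base   = llvm.mlir.constant(16 : i32) : i32
%vscale = "llvm.intr.vscale"() : () -> i64
%as_idx = builtin.unrealized_conversion_cast %vscale : i64 to index
%as_i32 = arith.index_cast %as_idx : index to i32
%vl     = arith.muli %base, %as_i32 : i32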

mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir

Lines changed: 110 additions & 0 deletions
@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
 // CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
 
 
+// -----
+
+func.func @masked_reduce_add_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_f32_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
+
+
 // -----
 
 func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
@@ -110,6 +129,24 @@ func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
 
 // -----
 
+func.func @masked_reduce_minf_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <minnumf>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minf_f32_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
+
+// -----
+
 func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
   %0 = vector.mask %mask { vector.reduction <maxnumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
   return %0 : f32
@@ -167,6 +204,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
 // CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
 
 
+// -----
+
+func.func @masked_reduce_add_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_i8_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+
+
 // -----
 
 func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
@@ -197,6 +253,24 @@ func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -
 
 // -----
 
+func.func @masked_reduce_minui_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minui_i8_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+
+// -----
+
 func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
   %0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
   return %0 : i8
@@ -239,6 +313,24 @@ func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -
 
 // -----
 
+func.func @masked_reduce_maxsi_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_maxsi_i8_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+
+// -----
+
 func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
   %0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
   return %0 : i8
@@ -280,4 +372,22 @@ func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
 // CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
 // CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
 
+// -----
+
+func.func @masked_reduce_xor_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_xor_i8_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+
