Commit 657ec73

[mlir][ArmSVE] Lower predicate-sized vector.create_masks to whilelt (#95531)
This produces better/more canonical codegen than the generic LLVM lowering, whose output is a pattern the backend currently does not recognize. See: #81840.
1 parent 995835f commit 657ec73
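
For example (a sketch distilled from the tests in this commit; value names are illustrative), a predicate-sized mask such as

  %mask = vector.create_mask %n : vector<[4]xi1>

now lowers directly to the SVE whilelt intrinsic (with the index operand converted to i64) instead of the generic expansion:

  %zero = llvm.mlir.zero : i64
  %mask = "arm_sve.intr.whilelt"(%zero, %n) : (i64, i64) -> vector<[4]xi1>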

4 files changed: +90 −2 lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 7 additions & 0 deletions
@@ -552,4 +552,11 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
                  Arg<AnyScalableVector, "v3">:$v3,
                  Arg<AnyScalableVector, "v3">:$v4)>;
 
+def WhileLTIntrOp :
+  ArmSVE_IntrOp<"whilelt",
+    [TypeIs<"res", SVEPredicate>, Pure],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[0]>,
+  Arguments<(ins I64:$base, I64:$n)>;
+
 #endif // ARMSVE_OPS

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 39 additions & 1 deletion
@@ -140,6 +140,40 @@ using ConvertFromSvboolOpLowering =
 using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
 using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
 
+/// Converts `vector.create_mask` ops that match the size of an SVE predicate
+/// to the `whilelt` intrinsic. This produces more canonical codegen than the
+/// generic LLVM lowering; see https://github.com/llvm/llvm-project/issues/81840
+/// for more details. Note that we can't use the (more general) active.lane.mask
+/// intrinsic, as its semantics don't neatly map onto `vector.create_mask`: it
+/// does an unsigned comparison (whereas `create_mask` is signed), and it is
+/// UB/poison if `n` is zero (whereas `create_mask` just returns an all-false mask).
+struct CreateMaskOpLowering
+    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                  vector::CreateMaskOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto maskType = createMaskOp.getVectorType();
+    if (maskType.getRank() != 1 || !maskType.isScalable())
+      return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
+
+    // TODO: Support masks which are multiples of SVE predicates.
+    auto maskBaseSize = maskType.getDimSize(0);
+    if (maskBaseSize < 2 || maskBaseSize > 16 ||
+        !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "not SVE predicate-sized");
+
+    auto loc = createMaskOp.getLoc();
+    auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
+    rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
+                                               adaptor.getOperands()[0]);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.

@@ -169,6 +203,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ConvertFromSvboolOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering>(converter);
+  // Add vector.create_mask conversion with a high benefit as it produces much
+  // nicer code than the generic lowering.
+  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
   // clang-format on
 }

@@ -191,7 +228,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ConvertToSvboolIntrOp,
                     ConvertFromSvboolIntrOp,
                     ZipX2IntrOp,
-                    ZipX4IntrOp>();
+                    ZipX4IntrOp,
+                    WhileLTIntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,
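
The doc comment above rules out the more general get.active.lane.mask intrinsic; a minimal sketch of the difference at the boundary case (an illustration, not code from the commit):

  // Well-defined: vector.create_mask with a zero (or negative) bound
  // simply yields an all-false mask, using a signed comparison.
  %c0 = arith.constant 0 : index
  %all_false = vector.create_mask %c0 : vector<[4]xi1>
  // By contrast, @llvm.get.active.lane.mask(base, n) compares
  // *unsigned* and is poison when n == 0, so it cannot faithfully
  // implement the op above.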

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir

Lines changed: 29 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -cse -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
 
 func.func @arm_sve_sdot(%a: vector<[16]xi8>,
                         %b: vector<[16]xi8>,

@@ -211,3 +211,31 @@ func.func @arm_sve_zip_x4(
   %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
   return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_predicate_sized_create_masks(
+// CHECK-SAME: %[[INDEX:.*]]: i64
+func.func @arm_sve_predicate_sized_create_masks(%index: index) -> (vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>) {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i64
+  // CHECK: %[[P2:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[2]xi1>
+  %0 = vector.create_mask %index : vector<[2]xi1>
+  // CHECK: %[[P4:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[4]xi1>
+  %1 = vector.create_mask %index : vector<[4]xi1>
+  // CHECK: %[[P8:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[8]xi1>
+  %2 = vector.create_mask %index : vector<[8]xi1>
+  // CHECK: %[[P16:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[16]xi1>
+  %3 = vector.create_mask %index : vector<[16]xi1>
+  return %0, %1, %2, %3 : vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_unsupported_create_masks
+func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>) {
+  // CHECK-NOT: arm_sve.intr.whilelt
+  %0 = vector.create_mask %index : vector<[1]xi1>
+  %1 = vector.create_mask %index : vector<[7]xi1>
+  %2 = vector.create_mask %index : vector<[32]xi1>
+  return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
+}
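
For reference, the lane semantics these tests rely on (a paraphrase of SVE's WHILELT, not text from the commit): lane i of whilelt(base, n) is active iff base + i < n as a signed comparison, so whilelt(0, n) activates exactly the first n lanes, and none when n <= 0:

  llvm.func @whilelt_semantics(%n: i64) -> vector<[8]xi1> {
    // Active lanes: { i | 0 + i < %n (signed) }, i.e. the first %n lanes;
    // all-false when %n <= 0, matching vector.create_mask.
    %zero = llvm.mlir.zero : i64
    %mask = "arm_sve.intr.whilelt"(%zero, %n) : (i64, i64) -> vector<[8]xi1>
    llvm.return %mask : vector<[8]xi1>
  }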

mlir/test/Target/LLVMIR/arm-sve.mlir

Lines changed: 15 additions & 0 deletions
@@ -356,3 +356,18 @@ llvm.func @arm_sve_zip_x4(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>,
       -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)>
   llvm.return
 }
+
+// CHECK-LABEL: arm_sve_whilelt(
+// CHECK-SAME: i64 %[[BASE:[0-9]+]],
+// CHECK-SAME: i64 %[[N:[0-9]+]]
+llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
+  // CHECK: call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %1 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[2]xi1>
+  // CHECK: call <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %2 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[4]xi1>
+  // CHECK: call <vscale x 8 x i1> @llvm.aarch64.sve.whilelt.nxv8i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %3 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[8]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.whilelt.nxv16i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
+  llvm.return
+}
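
As an aside on the intrinsic names checked above (a reading of the TableGen definition, not stated in the commit): the overloaded result and operand slots of WhileLTIntrOp are what select the type-mangled overload during export to LLVM IR. For instance:

  llvm.func @whilelt_mangling(%base: i64, %n: i64) -> vector<[4]xi1> {
    // The vector<[4]xi1> result and i64 operands pick the overload
    // @llvm.aarch64.sve.whilelt.nxv4i1.i64 in the generated LLVM IR.
    %p = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[4]xi1>
    llvm.return %p : vector<[4]xi1>
  }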
