Skip to content

Commit 23149d5

Browse files
committed
[mlir] Added ctlz and cttz to math dialect and LLVM dialect
Count leading/trailing zeros are an existing LLVM intrinsic. Added LLVM support for the intrinsics with lowerings from the math dialect to LLVM dialect. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D115206
1 parent 7d62b68 commit 23149d5

File tree

5 files changed

+159
-0
lines changed

5 files changed

+159
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,12 @@ class LLVM_TernarySameArgsIntrinsicOp<string func, list<OpTrait> traits = []> :
14011401
let arguments = (ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c);
14021402
}
14031403

1404+
class LLVM_CountZerosIntrinsicOp<string func, list<OpTrait> traits = []> :
1405+
LLVM_OneResultIntrOp<func, [], [0],
1406+
!listconcat([NoSideEffect], traits)> {
1407+
let arguments = (ins LLVM_Type:$in, I<1>:$zero_undefined);
1408+
}
1409+
14041410
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
14051411
def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
14061412
def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">;
@@ -1421,6 +1427,8 @@ def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">;
14211427
def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
14221428
def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
14231429
def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">;
1430+
def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrinsicOp<"ctlz">;
1431+
def LLVM_CountTrailingZerosOp : LLVM_CountZerosIntrinsicOp<"cttz">;
14241432
def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">;
14251433
def LLVM_MaxNumOp : LLVM_BinarySameArgsIntrinsicOp<"maxnum">;
14261434
def LLVM_MinNumOp : LLVM_BinarySameArgsIntrinsicOp<"minnum">;

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,54 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
297297
}];
298298
}
299299

300+
//===----------------------------------------------------------------------===//
301+
// CountLeadingZerosOp
302+
//===----------------------------------------------------------------------===//
303+
304+
def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
305+
let summary = "counts the leading zeros an integer value";
306+
let description = [{
307+
The `ctlz` operation computes the number of leading zeros of an integer value.
308+
309+
Example:
310+
311+
```mlir
312+
// Scalar ctlz function value.
313+
%a = math.ctlz %b : i32
314+
315+
// SIMD vector element-wise ctlz function value.
316+
%f = math.ctlz %g : vector<4xi16>
317+
318+
// Tensor element-wise ctlz function value.
319+
%x = math.ctlz %y : tensor<4x?xi8>
320+
```
321+
}];
322+
}
323+
324+
//===----------------------------------------------------------------------===//
325+
// CountTrailingZerosOp
326+
//===----------------------------------------------------------------------===//
327+
328+
def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
329+
let summary = "counts the trailing zeros an integer value";
330+
let description = [{
331+
The `cttz` operation computes the number of trailing zeros of an integer value.
332+
333+
Example:
334+
335+
```mlir
336+
// Scalar cttz function value.
337+
%a = math.cttz %b : i32
338+
339+
// SIMD vector element-wise cttz function value.
340+
%f = math.cttz %g : vector<4xi16>
341+
342+
// Tensor element-wise cttz function value.
343+
%x = math.cttz %y : tensor<4x?xi8>
344+
```
345+
}];
346+
}
347+
300348
//===----------------------------------------------------------------------===//
301349
// CtPopOp
302350
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,54 @@ using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
3838
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
3939
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
4040

41+
// A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
42+
template <typename MathOp, typename LLVMOp>
43+
struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
44+
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
45+
using Super = CountOpLowering<MathOp, LLVMOp>;
46+
47+
LogicalResult
48+
matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
49+
ConversionPatternRewriter &rewriter) const override {
50+
auto operandType = adaptor.getOperand().getType();
51+
52+
if (!operandType || !LLVM::isCompatibleType(operandType))
53+
return failure();
54+
55+
auto loc = op.getLoc();
56+
auto resultType = op.getResult().getType();
57+
auto boolType = rewriter.getIntegerType(1);
58+
auto boolZero = rewriter.getIntegerAttr(boolType, 0);
59+
60+
if (!operandType.template isa<LLVM::LLVMArrayType>()) {
61+
LLVM::ConstantOp zero =
62+
rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
63+
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
64+
zero);
65+
return success();
66+
}
67+
68+
auto vectorType = resultType.template dyn_cast<VectorType>();
69+
if (!vectorType)
70+
return failure();
71+
72+
return LLVM::detail::handleMultidimensionalVectors(
73+
op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
74+
[&](Type llvm1DVectorTy, ValueRange operands) {
75+
LLVM::ConstantOp zero =
76+
rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
77+
return rewriter.replaceOpWithNewOp<LLVMOp>(op, llvm1DVectorTy,
78+
operands[0], zero);
79+
},
80+
rewriter);
81+
}
82+
};
83+
84+
using CountLeadingZerosOpLowering =
85+
CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
86+
using CountTrailingZerosOpLowering =
87+
CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
88+
4189
// A `expm1` is converted into `exp - 1`.
4290
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
4391
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -222,6 +270,8 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
222270
CeilOpLowering,
223271
CopySignOpLowering,
224272
CosOpLowering,
273+
CountLeadingZerosOpLowering,
274+
CountTrailingZerosOpLowering,
225275
CtPopFOpLowering,
226276
ExpOpLowering,
227277
Exp2OpLowering,

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,39 @@ func @sine(%arg0 : f32) {
7474

7575
// -----
7676

77+
// CHECK-LABEL: func @ctlz(
78+
// CHECK-SAME: i32
79+
func @ctlz(%arg0 : i32) {
80+
// CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
81+
// CHECK: "llvm.intr.ctlz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
82+
%0 = math.ctlz %arg0 : i32
83+
std.return
84+
}
85+
86+
// -----
87+
88+
// CHECK-LABEL: func @cttz(
89+
// CHECK-SAME: i32
90+
func @cttz(%arg0 : i32) {
91+
// CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
92+
// CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
93+
%0 = math.cttz %arg0 : i32
94+
std.return
95+
}
96+
97+
// -----
98+
99+
// CHECK-LABEL: func @cttz_vec(
100+
// CHECK-SAME: i32
101+
func @cttz_vec(%arg0 : vector<4xi32>) {
102+
// CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
103+
// CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (vector<4xi32>, i1) -> vector<4xi32>
104+
%0 = math.cttz %arg0 : vector<4xi32>
105+
std.return
106+
}
107+
108+
// -----
109+
77110
// CHECK-LABEL: func @ctpop(
78111
// CHECK-SAME: i32
79112
func @ctpop(%arg0 : i32) {

mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,26 @@ llvm.func @bitreverse_test(%arg0: i32, %arg1: vector<8xi32>) {
135135
llvm.return
136136
}
137137

138+
// CHECK-LABEL: @ctlz_test
139+
llvm.func @ctlz_test(%arg0: i32, %arg1: vector<8xi32>) {
140+
%i1 = llvm.mlir.constant(false) : i1
141+
// CHECK: call i32 @llvm.ctlz.i32
142+
"llvm.intr.ctlz"(%arg0, %i1) : (i32, i1) -> i32
143+
// CHECK: call <8 x i32> @llvm.ctlz.v8i32
144+
"llvm.intr.ctlz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
145+
llvm.return
146+
}
147+
148+
// CHECK-LABEL: @cttz_test
149+
llvm.func @cttz_test(%arg0: i32, %arg1: vector<8xi32>) {
150+
%i1 = llvm.mlir.constant(false) : i1
151+
// CHECK: call i32 @llvm.cttz.i32
152+
"llvm.intr.cttz"(%arg0, %i1) : (i32, i1) -> i32
153+
// CHECK: call <8 x i32> @llvm.cttz.v8i32
154+
"llvm.intr.cttz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
155+
llvm.return
156+
}
157+
138158
// CHECK-LABEL: @ctpop_test
139159
llvm.func @ctpop_test(%arg0: i32, %arg1: vector<8xi32>) {
140160
// CHECK: call i32 @llvm.ctpop.i32

0 commit comments

Comments
 (0)