Skip to content

Commit af6e3c0

Browse files
authored
[mlir][math] Fix intrinsic conversions to LLVM for 0D-vector types (#141020)
`vector<t>` types are not compatible with the LLVM type system – with the current approach employed within `LLVMTypeConverter`, they must be explicitly converted into `vector<1xt>` when lowering. Employ this rule within the conversion patterns for intrinsics that are handled directly within `MathToLLVM`: `math.ctlz` `.cttz`, `.absi`, `.expm1`, `.log1p`, `.rsqrt`, `.isnan`, `.isfinite`. This change does not cover/test patterns that are based off `VectorConvertToLLVMPattern` template from `LLVMCommon/VectorPattern.h`. --------- Signed-off-by: Artem Gindinson <[email protected]>
1 parent 0ada5c7 commit af6e3c0

File tree

2 files changed

+163
-54
lines changed

2 files changed

+163
-54
lines changed

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 69 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
7373
using ATan2OpLowering =
7474
ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
7575
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
76+
// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
77+
// may be better to separate the patterns.
7678
template <typename MathOp, typename LLVMOp>
7779
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
7880
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
@@ -81,26 +83,29 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
8183
LogicalResult
8284
matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
8385
ConversionPatternRewriter &rewriter) const override {
86+
const auto &typeConverter = *this->getTypeConverter();
8487
auto operandType = adaptor.getOperand().getType();
85-
86-
if (!operandType || !LLVM::isCompatibleType(operandType))
88+
auto llvmOperandType = typeConverter.convertType(operandType);
89+
if (!llvmOperandType)
8790
return failure();
8891

8992
auto loc = op.getLoc();
9093
auto resultType = op.getResult().getType();
94+
auto llvmResultType = typeConverter.convertType(resultType);
95+
if (!llvmResultType)
96+
return failure();
9197

92-
if (!isa<LLVM::LLVMArrayType>(operandType)) {
93-
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
94-
false);
98+
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
99+
rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
100+
adaptor.getOperand(), false);
95101
return success();
96102
}
97103

98-
auto vectorType = dyn_cast<VectorType>(resultType);
99-
if (!vectorType)
104+
if (!isa<VectorType>(llvmResultType))
100105
return failure();
101106

102107
return LLVM::detail::handleMultidimensionalVectors(
103-
op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
108+
op.getOperation(), adaptor.getOperands(), typeConverter,
104109
[&](Type llvm1DVectorTy, ValueRange operands) {
105110
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
106111
false);
@@ -123,40 +128,42 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
123128
LogicalResult
124129
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
125130
ConversionPatternRewriter &rewriter) const override {
131+
const auto &typeConverter = *this->getTypeConverter();
126132
auto operandType = adaptor.getOperand().getType();
127-
128-
if (!operandType || !LLVM::isCompatibleType(operandType))
133+
auto llvmOperandType = typeConverter.convertType(operandType);
134+
if (!llvmOperandType)
129135
return failure();
130136

131137
auto loc = op.getLoc();
132138
auto resultType = op.getResult().getType();
133-
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
139+
auto floatType = cast<FloatType>(
140+
typeConverter.convertType(getElementTypeOrSelf(resultType)));
134141
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
135142
ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
136143
ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
137144

138-
if (!isa<LLVM::LLVMArrayType>(operandType)) {
145+
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
139146
LLVM::ConstantOp one;
140-
if (LLVM::isCompatibleVectorType(operandType)) {
147+
if (LLVM::isCompatibleVectorType(llvmOperandType)) {
141148
one = rewriter.create<LLVM::ConstantOp>(
142-
loc, operandType,
143-
SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
149+
loc, llvmOperandType,
150+
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
151+
floatOne));
144152
} else {
145-
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
153+
one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
146154
}
147155
auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
148156
expAttrs.getAttrs());
149157
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
150-
op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
158+
op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
151159
return success();
152160
}
153161

154-
auto vectorType = dyn_cast<VectorType>(resultType);
155-
if (!vectorType)
162+
if (!isa<VectorType>(resultType))
156163
return rewriter.notifyMatchFailure(op, "expected vector result type");
157164

158165
return LLVM::detail::handleMultidimensionalVectors(
159-
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
166+
op.getOperation(), adaptor.getOperands(), typeConverter,
160167
[&](Type llvm1DVectorTy, ValueRange operands) {
161168
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
162169
auto splatAttr = SplatElementsAttr::get(
@@ -181,41 +188,43 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
181188
LogicalResult
182189
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
183190
ConversionPatternRewriter &rewriter) const override {
191+
const auto &typeConverter = *this->getTypeConverter();
184192
auto operandType = adaptor.getOperand().getType();
185-
186-
if (!operandType || !LLVM::isCompatibleType(operandType))
193+
auto llvmOperandType = typeConverter.convertType(operandType);
194+
if (!llvmOperandType)
187195
return rewriter.notifyMatchFailure(op, "unsupported operand type");
188196

189197
auto loc = op.getLoc();
190198
auto resultType = op.getResult().getType();
191-
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
199+
auto floatType = cast<FloatType>(
200+
typeConverter.convertType(getElementTypeOrSelf(resultType)));
192201
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
193202
ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
194203
ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
195204

196-
if (!isa<LLVM::LLVMArrayType>(operandType)) {
205+
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
197206
LLVM::ConstantOp one =
198-
LLVM::isCompatibleVectorType(operandType)
207+
isa<VectorType>(llvmOperandType)
199208
? rewriter.create<LLVM::ConstantOp>(
200-
loc, operandType,
201-
SplatElementsAttr::get(cast<ShapedType>(resultType),
209+
loc, llvmOperandType,
210+
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
202211
floatOne))
203-
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
212+
: rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
213+
floatOne);
204214

205215
auto add = rewriter.create<LLVM::FAddOp>(
206-
loc, operandType, ValueRange{one, adaptor.getOperand()},
216+
loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
207217
addAttrs.getAttrs());
208-
rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
209-
logAttrs.getAttrs());
218+
rewriter.replaceOpWithNewOp<LLVM::LogOp>(
219+
op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
210220
return success();
211221
}
212222

213-
auto vectorType = dyn_cast<VectorType>(resultType);
214-
if (!vectorType)
223+
if (!isa<VectorType>(resultType))
215224
return rewriter.notifyMatchFailure(op, "expected vector result type");
216225

217226
return LLVM::detail::handleMultidimensionalVectors(
218-
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
227+
op.getOperation(), adaptor.getOperands(), typeConverter,
219228
[&](Type llvm1DVectorTy, ValueRange operands) {
220229
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
221230
auto splatAttr = SplatElementsAttr::get(
@@ -241,40 +250,42 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
241250
LogicalResult
242251
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
243252
ConversionPatternRewriter &rewriter) const override {
253+
const auto &typeConverter = *this->getTypeConverter();
244254
auto operandType = adaptor.getOperand().getType();
245-
246-
if (!operandType || !LLVM::isCompatibleType(operandType))
255+
auto llvmOperandType = typeConverter.convertType(operandType);
256+
if (!llvmOperandType)
247257
return failure();
248258

249259
auto loc = op.getLoc();
250260
auto resultType = op.getResult().getType();
251-
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
261+
auto floatType = cast<FloatType>(
262+
typeConverter.convertType(getElementTypeOrSelf(resultType)));
252263
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
253264
ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
254265
ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
255266

256-
if (!isa<LLVM::LLVMArrayType>(operandType)) {
267+
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
257268
LLVM::ConstantOp one;
258-
if (LLVM::isCompatibleVectorType(operandType)) {
269+
if (isa<VectorType>(llvmOperandType)) {
259270
one = rewriter.create<LLVM::ConstantOp>(
260-
loc, operandType,
261-
SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
271+
loc, llvmOperandType,
272+
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
273+
floatOne));
262274
} else {
263-
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
275+
one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
264276
}
265277
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
266278
sqrtAttrs.getAttrs());
267279
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
268-
op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
280+
op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
269281
return success();
270282
}
271283

272-
auto vectorType = dyn_cast<VectorType>(resultType);
273-
if (!vectorType)
284+
if (!isa<VectorType>(resultType))
274285
return failure();
275286

276287
return LLVM::detail::handleMultidimensionalVectors(
277-
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
288+
op.getOperation(), adaptor.getOperands(), typeConverter,
278289
[&](Type llvm1DVectorTy, ValueRange operands) {
279290
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
280291
auto splatAttr = SplatElementsAttr::get(
@@ -298,13 +309,15 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
298309
LogicalResult
299310
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
300311
ConversionPatternRewriter &rewriter) const override {
301-
auto operandType = adaptor.getOperand().getType();
302-
303-
if (!operandType || !LLVM::isCompatibleType(operandType))
312+
const auto &typeConverter = *this->getTypeConverter();
313+
auto operandType =
314+
typeConverter.convertType(adaptor.getOperand().getType());
315+
auto resultType = typeConverter.convertType(op.getResult().getType());
316+
if (!operandType || !resultType)
304317
return failure();
305318

306319
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
307-
op, op.getType(), adaptor.getOperand(), llvm::fcNan);
320+
op, resultType, adaptor.getOperand(), llvm::fcNan);
308321
return success();
309322
}
310323
};
@@ -315,13 +328,15 @@ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
315328
LogicalResult
316329
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
317330
ConversionPatternRewriter &rewriter) const override {
318-
auto operandType = adaptor.getOperand().getType();
319-
320-
if (!operandType || !LLVM::isCompatibleType(operandType))
331+
const auto &typeConverter = *this->getTypeConverter();
332+
auto operandType =
333+
typeConverter.convertType(adaptor.getOperand().getType());
334+
auto resultType = typeConverter.convertType(op.getResult().getType());
335+
if (!operandType || !resultType)
321336
return failure();
322337

323338
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
324-
op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
339+
op, resultType, adaptor.getOperand(), llvm::fcFinite);
325340
return success();
326341
}
327342
};

0 commit comments

Comments
 (0)