@@ -73,6 +73,8 @@ using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73
73
using ATan2OpLowering =
74
74
ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
75
75
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
76
+ // TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
77
+ // may be better to separate the patterns.
76
78
template <typename MathOp, typename LLVMOp>
77
79
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern <MathOp> {
78
80
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
@@ -81,26 +83,29 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
81
83
LogicalResult
82
84
matchAndRewrite (MathOp op, typename MathOp::Adaptor adaptor,
83
85
ConversionPatternRewriter &rewriter) const override {
86
+ const auto &typeConverter = *this ->getTypeConverter ();
84
87
auto operandType = adaptor.getOperand ().getType ();
85
-
86
- if (!operandType || ! LLVM::isCompatibleType (operandType) )
88
+ auto llvmOperandType = typeConverter. convertType (operandType);
89
+ if (!llvmOperandType )
87
90
return failure ();
88
91
89
92
auto loc = op.getLoc ();
90
93
auto resultType = op.getResult ().getType ();
94
+ auto llvmResultType = typeConverter.convertType (resultType);
95
+ if (!llvmResultType)
96
+ return failure ();
91
97
92
- if (!isa<LLVM::LLVMArrayType>(operandType )) {
93
- rewriter.replaceOpWithNewOp <LLVMOp>(op, resultType, adaptor. getOperand () ,
94
- false );
98
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
99
+ rewriter.replaceOpWithNewOp <LLVMOp>(op, llvmResultType ,
100
+ adaptor. getOperand (), false );
95
101
return success ();
96
102
}
97
103
98
- auto vectorType = dyn_cast<VectorType>(resultType);
99
- if (!vectorType)
104
+ if (!isa<VectorType>(llvmResultType))
100
105
return failure ();
101
106
102
107
return LLVM::detail::handleMultidimensionalVectors (
103
- op.getOperation (), adaptor.getOperands (), * this -> getTypeConverter () ,
108
+ op.getOperation (), adaptor.getOperands (), typeConverter ,
104
109
[&](Type llvm1DVectorTy, ValueRange operands) {
105
110
return rewriter.create <LLVMOp>(loc, llvm1DVectorTy, operands[0 ],
106
111
false );
@@ -123,40 +128,42 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
123
128
LogicalResult
124
129
matchAndRewrite (math::ExpM1Op op, OpAdaptor adaptor,
125
130
ConversionPatternRewriter &rewriter) const override {
131
+ const auto &typeConverter = *this ->getTypeConverter ();
126
132
auto operandType = adaptor.getOperand ().getType ();
127
-
128
- if (!operandType || ! LLVM::isCompatibleType (operandType) )
133
+ auto llvmOperandType = typeConverter. convertType (operandType);
134
+ if (!llvmOperandType )
129
135
return failure ();
130
136
131
137
auto loc = op.getLoc ();
132
138
auto resultType = op.getResult ().getType ();
133
- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
139
+ auto floatType = cast<FloatType>(
140
+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
134
141
auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
135
142
ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs (op);
136
143
ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs (op);
137
144
138
- if (!isa<LLVM::LLVMArrayType>(operandType )) {
145
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
139
146
LLVM::ConstantOp one;
140
- if (LLVM::isCompatibleVectorType (operandType )) {
147
+ if (LLVM::isCompatibleVectorType (llvmOperandType )) {
141
148
one = rewriter.create <LLVM::ConstantOp>(
142
- loc, operandType,
143
- SplatElementsAttr::get (cast<ShapedType>(resultType), floatOne));
149
+ loc, llvmOperandType,
150
+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType),
151
+ floatOne));
144
152
} else {
145
- one = rewriter.create <LLVM::ConstantOp>(loc, operandType , floatOne);
153
+ one = rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType , floatOne);
146
154
}
147
155
auto exp = rewriter.create <LLVM::ExpOp>(loc, adaptor.getOperand (),
148
156
expAttrs.getAttrs ());
149
157
rewriter.replaceOpWithNewOp <LLVM::FSubOp>(
150
- op, operandType , ValueRange{exp, one}, subAttrs.getAttrs ());
158
+ op, llvmOperandType , ValueRange{exp, one}, subAttrs.getAttrs ());
151
159
return success ();
152
160
}
153
161
154
- auto vectorType = dyn_cast<VectorType>(resultType);
155
- if (!vectorType)
162
+ if (!isa<VectorType>(resultType))
156
163
return rewriter.notifyMatchFailure (op, " expected vector result type" );
157
164
158
165
return LLVM::detail::handleMultidimensionalVectors (
159
- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
166
+ op.getOperation (), adaptor.getOperands (), typeConverter ,
160
167
[&](Type llvm1DVectorTy, ValueRange operands) {
161
168
auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
162
169
auto splatAttr = SplatElementsAttr::get (
@@ -181,41 +188,43 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
181
188
LogicalResult
182
189
matchAndRewrite (math::Log1pOp op, OpAdaptor adaptor,
183
190
ConversionPatternRewriter &rewriter) const override {
191
+ const auto &typeConverter = *this ->getTypeConverter ();
184
192
auto operandType = adaptor.getOperand ().getType ();
185
-
186
- if (!operandType || ! LLVM::isCompatibleType (operandType) )
193
+ auto llvmOperandType = typeConverter. convertType (operandType);
194
+ if (!llvmOperandType )
187
195
return rewriter.notifyMatchFailure (op, " unsupported operand type" );
188
196
189
197
auto loc = op.getLoc ();
190
198
auto resultType = op.getResult ().getType ();
191
- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
199
+ auto floatType = cast<FloatType>(
200
+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
192
201
auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
193
202
ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs (op);
194
203
ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs (op);
195
204
196
- if (!isa<LLVM::LLVMArrayType>(operandType )) {
205
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
197
206
LLVM::ConstantOp one =
198
- LLVM::isCompatibleVectorType (operandType )
207
+ isa<VectorType>(llvmOperandType )
199
208
? rewriter.create <LLVM::ConstantOp>(
200
- loc, operandType ,
201
- SplatElementsAttr::get (cast<ShapedType>(resultType ),
209
+ loc, llvmOperandType ,
210
+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType ),
202
211
floatOne))
203
- : rewriter.create <LLVM::ConstantOp>(loc, operandType, floatOne);
212
+ : rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType,
213
+ floatOne);
204
214
205
215
auto add = rewriter.create <LLVM::FAddOp>(
206
- loc, operandType , ValueRange{one, adaptor.getOperand ()},
216
+ loc, llvmOperandType , ValueRange{one, adaptor.getOperand ()},
207
217
addAttrs.getAttrs ());
208
- rewriter.replaceOpWithNewOp <LLVM::LogOp>(op, operandType, ValueRange{add},
209
- logAttrs.getAttrs ());
218
+ rewriter.replaceOpWithNewOp <LLVM::LogOp>(
219
+ op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs ());
210
220
return success ();
211
221
}
212
222
213
- auto vectorType = dyn_cast<VectorType>(resultType);
214
- if (!vectorType)
223
+ if (!isa<VectorType>(resultType))
215
224
return rewriter.notifyMatchFailure (op, " expected vector result type" );
216
225
217
226
return LLVM::detail::handleMultidimensionalVectors (
218
- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
227
+ op.getOperation (), adaptor.getOperands (), typeConverter ,
219
228
[&](Type llvm1DVectorTy, ValueRange operands) {
220
229
auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
221
230
auto splatAttr = SplatElementsAttr::get (
@@ -241,40 +250,42 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
241
250
LogicalResult
242
251
matchAndRewrite (math::RsqrtOp op, OpAdaptor adaptor,
243
252
ConversionPatternRewriter &rewriter) const override {
253
+ const auto &typeConverter = *this ->getTypeConverter ();
244
254
auto operandType = adaptor.getOperand ().getType ();
245
-
246
- if (!operandType || ! LLVM::isCompatibleType (operandType) )
255
+ auto llvmOperandType = typeConverter. convertType (operandType);
256
+ if (!llvmOperandType )
247
257
return failure ();
248
258
249
259
auto loc = op.getLoc ();
250
260
auto resultType = op.getResult ().getType ();
251
- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
261
+ auto floatType = cast<FloatType>(
262
+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
252
263
auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
253
264
ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs (op);
254
265
ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs (op);
255
266
256
- if (!isa<LLVM::LLVMArrayType>(operandType )) {
267
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
257
268
LLVM::ConstantOp one;
258
- if (LLVM::isCompatibleVectorType (operandType )) {
269
+ if (isa<VectorType>(llvmOperandType )) {
259
270
one = rewriter.create <LLVM::ConstantOp>(
260
- loc, operandType,
261
- SplatElementsAttr::get (cast<ShapedType>(resultType), floatOne));
271
+ loc, llvmOperandType,
272
+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType),
273
+ floatOne));
262
274
} else {
263
- one = rewriter.create <LLVM::ConstantOp>(loc, operandType , floatOne);
275
+ one = rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType , floatOne);
264
276
}
265
277
auto sqrt = rewriter.create <LLVM::SqrtOp>(loc, adaptor.getOperand (),
266
278
sqrtAttrs.getAttrs ());
267
279
rewriter.replaceOpWithNewOp <LLVM::FDivOp>(
268
- op, operandType , ValueRange{one, sqrt}, divAttrs.getAttrs ());
280
+ op, llvmOperandType , ValueRange{one, sqrt}, divAttrs.getAttrs ());
269
281
return success ();
270
282
}
271
283
272
- auto vectorType = dyn_cast<VectorType>(resultType);
273
- if (!vectorType)
284
+ if (!isa<VectorType>(resultType))
274
285
return failure ();
275
286
276
287
return LLVM::detail::handleMultidimensionalVectors (
277
- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
288
+ op.getOperation (), adaptor.getOperands (), typeConverter ,
278
289
[&](Type llvm1DVectorTy, ValueRange operands) {
279
290
auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
280
291
auto splatAttr = SplatElementsAttr::get (
@@ -298,13 +309,15 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
298
309
LogicalResult
299
310
matchAndRewrite (math::IsNaNOp op, OpAdaptor adaptor,
300
311
ConversionPatternRewriter &rewriter) const override {
301
- auto operandType = adaptor.getOperand ().getType ();
302
-
303
- if (!operandType || !LLVM::isCompatibleType (operandType))
312
+ const auto &typeConverter = *this ->getTypeConverter ();
313
+ auto operandType =
314
+ typeConverter.convertType (adaptor.getOperand ().getType ());
315
+ auto resultType = typeConverter.convertType (op.getResult ().getType ());
316
+ if (!operandType || !resultType)
304
317
return failure ();
305
318
306
319
rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(
307
- op, op. getType () , adaptor.getOperand (), llvm::fcNan);
320
+ op, resultType , adaptor.getOperand (), llvm::fcNan);
308
321
return success ();
309
322
}
310
323
};
@@ -315,13 +328,15 @@ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
315
328
LogicalResult
316
329
matchAndRewrite (math::IsFiniteOp op, OpAdaptor adaptor,
317
330
ConversionPatternRewriter &rewriter) const override {
318
- auto operandType = adaptor.getOperand ().getType ();
319
-
320
- if (!operandType || !LLVM::isCompatibleType (operandType))
331
+ const auto &typeConverter = *this ->getTypeConverter ();
332
+ auto operandType =
333
+ typeConverter.convertType (adaptor.getOperand ().getType ());
334
+ auto resultType = typeConverter.convertType (op.getResult ().getType ());
335
+ if (!operandType || !resultType)
321
336
return failure ();
322
337
323
338
rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(
324
- op, op. getType () , adaptor.getOperand (), llvm::fcFinite);
339
+ op, resultType , adaptor.getOperand (), llvm::fcFinite);
325
340
return success ();
326
341
}
327
342
};
0 commit comments