@@ -1148,9 +1148,10 @@ template <typename SourceOp, unsigned OpCount>
1148
1148
void ValidateOpCount () {
1149
1149
OpCountValidator<SourceOp, OpCount>();
1150
1150
}
1151
+ } // namespace
1151
1152
1152
- static LogicalResult HandleMultidimensionalVectors (
1153
- Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
1153
+ static LogicalResult handleMultidimensionalVectors (
1154
+ Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
1154
1155
std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
1155
1156
ConversionPatternRewriter &rewriter) {
1156
1157
auto vectorType = op->getResult (0 ).getType ().dyn_cast <VectorType>();
@@ -1179,139 +1180,125 @@ static LogicalResult HandleMultidimensionalVectors(
1179
1180
return success ();
1180
1181
}
1181
1182
1182
- // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
1183
- // Ops for N-ary ops with one result. This supports higher-dimensional vector
1184
- // types.
1185
- template <typename SourceOp, typename TargetOp, unsigned OpCount>
1186
- struct NaryOpLLVMOpLowering : public ConvertOpToLLVMPattern <SourceOp> {
1187
- using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
1188
- using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
1189
-
1190
- // Convert the type of the result to an LLVM type, pass operands as is,
1191
- // preserve attributes.
1192
- LogicalResult
1193
- matchAndRewrite (Operation *op, ArrayRef<Value> operands,
1194
- ConversionPatternRewriter &rewriter) const override {
1195
- ValidateOpCount<SourceOp, OpCount>();
1196
- static_assert (
1197
- std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
1198
- " expected single result op" );
1199
- static_assert (std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
1200
- SourceOp>::value,
1201
- " expected same operands and result type" );
1202
-
1203
- // Cannot convert ops if their operands are not of LLVM type.
1204
- for (Value operand : operands) {
1205
- if (!operand || !operand.getType ().isa <LLVM::LLVMType>())
1206
- return failure ();
1207
- }
1183
+ LogicalResult LLVM::detail::vectorOneToOneRewrite (
1184
+ Operation *op, StringRef targetOp, ValueRange operands,
1185
+ LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
1186
+ assert (!operands.empty ());
1208
1187
1209
- auto llvmArrayTy = operands[0 ].getType ().cast <LLVM::LLVMType>();
1188
+ // Cannot convert ops if their operands are not of LLVM type.
1189
+ if (!llvm::all_of (operands.getTypes (),
1190
+ [](Type t) { return t.isa <LLVM::LLVMType>(); }))
1191
+ return failure ();
1210
1192
1211
- if (!llvmArrayTy.isArrayTy ()) {
1212
- auto newOp = rewriter.create <TargetOp>(
1213
- op->getLoc (), operands[0 ].getType (), operands, op->getAttrs ());
1214
- rewriter.replaceOp (op, newOp.getResult ());
1215
- return success ();
1216
- }
1193
+ auto llvmArrayTy = operands[0 ].getType ().cast <LLVM::LLVMType>();
1194
+ if (!llvmArrayTy.isArrayTy ())
1195
+ return oneToOneRewrite (op, targetOp, operands, typeConverter, rewriter);
1217
1196
1218
- if (succeeded (HandleMultidimensionalVectors (
1219
- op, operands, this ->typeConverter ,
1220
- [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
1221
- return rewriter.create <TargetOp>(op->getLoc (), llvmVectorTy,
1222
- operands, op->getAttrs ());
1223
- },
1224
- rewriter)))
1225
- return success ();
1226
- return failure ();
1227
- }
1228
- };
1197
+ auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
1198
+ ValueRange operands) {
1199
+ OperationState state (op->getLoc (), targetOp);
1200
+ state.addTypes (llvmVectorTy);
1201
+ state.addOperands (operands);
1202
+ state.addAttributes (op->getAttrs ());
1203
+ return rewriter.createOperation (state)->getResult (0 );
1204
+ };
1229
1205
1230
- template <typename SourceOp, typename TargetOp>
1231
- using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1 >;
1232
- template <typename SourceOp, typename TargetOp>
1233
- using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2 >;
1206
+ return handleMultidimensionalVectors (op, operands, typeConverter, callback,
1207
+ rewriter);
1208
+ }
1234
1209
1210
+ namespace {
1235
1211
// Specific lowerings.
1236
1212
// FIXME: this should be tablegen'ed.
1237
- struct AbsFOpLowering : public UnaryOpLLVMOpLowering <AbsFOp, LLVM::FAbsOp> {
1213
+ struct AbsFOpLowering
1214
+ : public VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp> {
1238
1215
using Super::Super;
1239
1216
};
1240
- struct CeilFOpLowering : public UnaryOpLLVMOpLowering <CeilFOp, LLVM::FCeilOp> {
1217
+ struct CeilFOpLowering
1218
+ : public VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp> {
1241
1219
using Super::Super;
1242
1220
};
1243
- struct CosOpLowering : public UnaryOpLLVMOpLowering <CosOp, LLVM::CosOp> {
1221
+ struct CosOpLowering : public VectorConvertToLLVMPattern <CosOp, LLVM::CosOp> {
1244
1222
using Super::Super;
1245
1223
};
1246
- struct ExpOpLowering : public UnaryOpLLVMOpLowering <ExpOp, LLVM::ExpOp> {
1224
+ struct ExpOpLowering : public VectorConvertToLLVMPattern <ExpOp, LLVM::ExpOp> {
1247
1225
using Super::Super;
1248
1226
};
1249
- struct LogOpLowering : public UnaryOpLLVMOpLowering <LogOp, LLVM::LogOp> {
1227
+ struct LogOpLowering : public VectorConvertToLLVMPattern <LogOp, LLVM::LogOp> {
1250
1228
using Super::Super;
1251
1229
};
1252
- struct Log10OpLowering : public UnaryOpLLVMOpLowering <Log10Op, LLVM::Log10Op> {
1230
+ struct Log10OpLowering
1231
+ : public VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op> {
1253
1232
using Super::Super;
1254
1233
};
1255
- struct Log2OpLowering : public UnaryOpLLVMOpLowering <Log2Op, LLVM::Log2Op> {
1234
+ struct Log2OpLowering
1235
+ : public VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op> {
1256
1236
using Super::Super;
1257
1237
};
1258
- struct NegFOpLowering : public UnaryOpLLVMOpLowering <NegFOp, LLVM::FNegOp> {
1238
+ struct NegFOpLowering
1239
+ : public VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp> {
1259
1240
using Super::Super;
1260
1241
};
1261
- struct AddIOpLowering : public BinaryOpLLVMOpLowering <AddIOp, LLVM::AddOp> {
1242
+ struct AddIOpLowering : public VectorConvertToLLVMPattern <AddIOp, LLVM::AddOp> {
1262
1243
using Super::Super;
1263
1244
};
1264
- struct SubIOpLowering : public BinaryOpLLVMOpLowering <SubIOp, LLVM::SubOp> {
1245
+ struct SubIOpLowering : public VectorConvertToLLVMPattern <SubIOp, LLVM::SubOp> {
1265
1246
using Super::Super;
1266
1247
};
1267
- struct MulIOpLowering : public BinaryOpLLVMOpLowering <MulIOp, LLVM::MulOp> {
1248
+ struct MulIOpLowering : public VectorConvertToLLVMPattern <MulIOp, LLVM::MulOp> {
1268
1249
using Super::Super;
1269
1250
};
1270
1251
struct SignedDivIOpLowering
1271
- : public BinaryOpLLVMOpLowering <SignedDivIOp, LLVM::SDivOp> {
1252
+ : public VectorConvertToLLVMPattern <SignedDivIOp, LLVM::SDivOp> {
1272
1253
using Super::Super;
1273
1254
};
1274
- struct SqrtOpLowering : public UnaryOpLLVMOpLowering <SqrtOp, LLVM::SqrtOp> {
1255
+ struct SqrtOpLowering
1256
+ : public VectorConvertToLLVMPattern<SqrtOp, LLVM::SqrtOp> {
1275
1257
using Super::Super;
1276
1258
};
1277
1259
struct UnsignedDivIOpLowering
1278
- : public BinaryOpLLVMOpLowering <UnsignedDivIOp, LLVM::UDivOp> {
1260
+ : public VectorConvertToLLVMPattern <UnsignedDivIOp, LLVM::UDivOp> {
1279
1261
using Super::Super;
1280
1262
};
1281
1263
struct SignedRemIOpLowering
1282
- : public BinaryOpLLVMOpLowering <SignedRemIOp, LLVM::SRemOp> {
1264
+ : public VectorConvertToLLVMPattern <SignedRemIOp, LLVM::SRemOp> {
1283
1265
using Super::Super;
1284
1266
};
1285
1267
struct UnsignedRemIOpLowering
1286
- : public BinaryOpLLVMOpLowering <UnsignedRemIOp, LLVM::URemOp> {
1268
+ : public VectorConvertToLLVMPattern <UnsignedRemIOp, LLVM::URemOp> {
1287
1269
using Super::Super;
1288
1270
};
1289
- struct AndOpLowering : public BinaryOpLLVMOpLowering <AndOp, LLVM::AndOp> {
1271
+ struct AndOpLowering : public VectorConvertToLLVMPattern <AndOp, LLVM::AndOp> {
1290
1272
using Super::Super;
1291
1273
};
1292
- struct OrOpLowering : public BinaryOpLLVMOpLowering <OrOp, LLVM::OrOp> {
1274
+ struct OrOpLowering : public VectorConvertToLLVMPattern <OrOp, LLVM::OrOp> {
1293
1275
using Super::Super;
1294
1276
};
1295
- struct XOrOpLowering : public BinaryOpLLVMOpLowering <XOrOp, LLVM::XOrOp> {
1277
+ struct XOrOpLowering : public VectorConvertToLLVMPattern <XOrOp, LLVM::XOrOp> {
1296
1278
using Super::Super;
1297
1279
};
1298
- struct AddFOpLowering : public BinaryOpLLVMOpLowering <AddFOp, LLVM::FAddOp> {
1280
+ struct AddFOpLowering
1281
+ : public VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp> {
1299
1282
using Super::Super;
1300
1283
};
1301
- struct SubFOpLowering : public BinaryOpLLVMOpLowering <SubFOp, LLVM::FSubOp> {
1284
+ struct SubFOpLowering
1285
+ : public VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp> {
1302
1286
using Super::Super;
1303
1287
};
1304
- struct MulFOpLowering : public BinaryOpLLVMOpLowering <MulFOp, LLVM::FMulOp> {
1288
+ struct MulFOpLowering
1289
+ : public VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp> {
1305
1290
using Super::Super;
1306
1291
};
1307
- struct DivFOpLowering : public BinaryOpLLVMOpLowering <DivFOp, LLVM::FDivOp> {
1292
+ struct DivFOpLowering
1293
+ : public VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp> {
1308
1294
using Super::Super;
1309
1295
};
1310
- struct RemFOpLowering : public BinaryOpLLVMOpLowering <RemFOp, LLVM::FRemOp> {
1296
+ struct RemFOpLowering
1297
+ : public VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp> {
1311
1298
using Super::Super;
1312
1299
};
1313
1300
struct CopySignOpLowering
1314
- : public BinaryOpLLVMOpLowering <CopySignOp, LLVM::CopySignOp> {
1301
+ : public VectorConvertToLLVMPattern <CopySignOp, LLVM::CopySignOp> {
1315
1302
using Super::Super;
1316
1303
};
1317
1304
struct SelectOpLowering
@@ -1695,24 +1682,21 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
1695
1682
if (!vectorType)
1696
1683
return failure ();
1697
1684
1698
- if (succeeded (HandleMultidimensionalVectors (
1699
- op, operands, typeConverter,
1700
- [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
1701
- auto splatAttr = SplatElementsAttr::get (
1702
- mlir::VectorType::get ({llvmVectorTy.getUnderlyingType ()
1703
- ->getVectorNumElements ()},
1704
- floatType),
1705
- floatOne);
1706
- auto one = rewriter.create <LLVM::ConstantOp>(loc, llvmVectorTy,
1707
- splatAttr);
1708
- auto sqrt =
1709
- rewriter.create <LLVM::SqrtOp>(loc, llvmVectorTy, operands[0 ]);
1710
- return rewriter.create <LLVM::FDivOp>(loc, llvmVectorTy, one,
1711
- sqrt);
1712
- },
1713
- rewriter)))
1714
- return success ();
1715
- return failure ();
1685
+ return handleMultidimensionalVectors (
1686
+ op, operands, typeConverter,
1687
+ [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
1688
+ auto splatAttr = SplatElementsAttr::get (
1689
+ mlir::VectorType::get (
1690
+ {llvmVectorTy.getUnderlyingType ()->getVectorNumElements ()},
1691
+ floatType),
1692
+ floatOne);
1693
+ auto one =
1694
+ rewriter.create <LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
1695
+ auto sqrt =
1696
+ rewriter.create <LLVM::SqrtOp>(loc, llvmVectorTy, operands[0 ]);
1697
+ return rewriter.create <LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
1698
+ },
1699
+ rewriter);
1716
1700
}
1717
1701
};
1718
1702
0 commit comments