Skip to content

Commit 04ed07b

Browse files
committed
[mlir] StandardToLLVM: clean up conversion patterns for vector operations
Summary: Provide a public VectorConvertToLLVMPattern utility class to implement conversions with automatic unrolling of operation on multidimensional vectors to lists of operations on single-dimensional vectors when lowering to the LLVM dialect. Drop the template-based check on the number of operands since the actual implementation does not depend on the operand number anymore. This check only creates spurious concepts (UnaryOpLowering, BinaryOpLowering, etc). Differential Revision: https://reviews.llvm.org/D76865
1 parent 987fbae commit 04ed07b

File tree

2 files changed

+105
-93
lines changed

2 files changed

+105
-93
lines changed

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,11 @@ LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
416416
ValueRange operands,
417417
LLVMTypeConverter &typeConverter,
418418
ConversionPatternRewriter &rewriter);
419+
420+
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
421+
ValueRange operands,
422+
LLVMTypeConverter &typeConverter,
423+
ConversionPatternRewriter &rewriter);
419424
} // namespace detail
420425
} // namespace LLVM
421426

@@ -441,6 +446,29 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
441446
}
442447
};
443448

449+
/// Basic lowering implementation for rewriting from Ops to LLVM Dialect Ops
450+
/// with one result. This supports higher-dimensional vector types.
451+
template <typename SourceOp, typename TargetOp>
452+
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
453+
public:
454+
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
455+
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
456+
457+
LogicalResult
458+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
459+
ConversionPatternRewriter &rewriter) const override {
460+
static_assert(
461+
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
462+
"expected single result op");
463+
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
464+
SourceOp>::value,
465+
"expected same operands and result type");
466+
return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(),
467+
operands, this->typeConverter,
468+
rewriter);
469+
}
470+
};
471+
444472
/// Derived class that automatically populates legalization information for
445473
/// different LLVM ops.
446474
class LLVMConversionTarget : public ConversionTarget {

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 77 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,9 +1148,10 @@ template <typename SourceOp, unsigned OpCount>
11481148
void ValidateOpCount() {
11491149
OpCountValidator<SourceOp, OpCount>();
11501150
}
1151+
} // namespace
11511152

1152-
static LogicalResult HandleMultidimensionalVectors(
1153-
Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
1153+
static LogicalResult handleMultidimensionalVectors(
1154+
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
11541155
std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
11551156
ConversionPatternRewriter &rewriter) {
11561157
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
@@ -1179,139 +1180,125 @@ static LogicalResult HandleMultidimensionalVectors(
11791180
return success();
11801181
}
11811182

1182-
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
1183-
// Ops for N-ary ops with one result. This supports higher-dimensional vector
1184-
// types.
1185-
template <typename SourceOp, typename TargetOp, unsigned OpCount>
1186-
struct NaryOpLLVMOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
1187-
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
1188-
using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
1189-
1190-
// Convert the type of the result to an LLVM type, pass operands as is,
1191-
// preserve attributes.
1192-
LogicalResult
1193-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1194-
ConversionPatternRewriter &rewriter) const override {
1195-
ValidateOpCount<SourceOp, OpCount>();
1196-
static_assert(
1197-
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
1198-
"expected single result op");
1199-
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
1200-
SourceOp>::value,
1201-
"expected same operands and result type");
1202-
1203-
// Cannot convert ops if their operands are not of LLVM type.
1204-
for (Value operand : operands) {
1205-
if (!operand || !operand.getType().isa<LLVM::LLVMType>())
1206-
return failure();
1207-
}
1183+
LogicalResult LLVM::detail::vectorOneToOneRewrite(
1184+
Operation *op, StringRef targetOp, ValueRange operands,
1185+
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
1186+
assert(!operands.empty());
12081187

1209-
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
1188+
// Cannot convert ops if their operands are not of LLVM type.
1189+
if (!llvm::all_of(operands.getTypes(),
1190+
[](Type t) { return t.isa<LLVM::LLVMType>(); }))
1191+
return failure();
12101192

1211-
if (!llvmArrayTy.isArrayTy()) {
1212-
auto newOp = rewriter.create<TargetOp>(
1213-
op->getLoc(), operands[0].getType(), operands, op->getAttrs());
1214-
rewriter.replaceOp(op, newOp.getResult());
1215-
return success();
1216-
}
1193+
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
1194+
if (!llvmArrayTy.isArrayTy())
1195+
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
12171196

1218-
if (succeeded(HandleMultidimensionalVectors(
1219-
op, operands, this->typeConverter,
1220-
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
1221-
return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
1222-
operands, op->getAttrs());
1223-
},
1224-
rewriter)))
1225-
return success();
1226-
return failure();
1227-
}
1228-
};
1197+
auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
1198+
ValueRange operands) {
1199+
OperationState state(op->getLoc(), targetOp);
1200+
state.addTypes(llvmVectorTy);
1201+
state.addOperands(operands);
1202+
state.addAttributes(op->getAttrs());
1203+
return rewriter.createOperation(state)->getResult(0);
1204+
};
12291205

1230-
template <typename SourceOp, typename TargetOp>
1231-
using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
1232-
template <typename SourceOp, typename TargetOp>
1233-
using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
1206+
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
1207+
rewriter);
1208+
}
12341209

1210+
namespace {
12351211
// Specific lowerings.
12361212
// FIXME: this should be tablegen'ed.
1237-
struct AbsFOpLowering : public UnaryOpLLVMOpLowering<AbsFOp, LLVM::FAbsOp> {
1213+
struct AbsFOpLowering
1214+
: public VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp> {
12381215
using Super::Super;
12391216
};
1240-
struct CeilFOpLowering : public UnaryOpLLVMOpLowering<CeilFOp, LLVM::FCeilOp> {
1217+
struct CeilFOpLowering
1218+
: public VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp> {
12411219
using Super::Super;
12421220
};
1243-
struct CosOpLowering : public UnaryOpLLVMOpLowering<CosOp, LLVM::CosOp> {
1221+
struct CosOpLowering : public VectorConvertToLLVMPattern<CosOp, LLVM::CosOp> {
12441222
using Super::Super;
12451223
};
1246-
struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::ExpOp> {
1224+
struct ExpOpLowering : public VectorConvertToLLVMPattern<ExpOp, LLVM::ExpOp> {
12471225
using Super::Super;
12481226
};
1249-
struct LogOpLowering : public UnaryOpLLVMOpLowering<LogOp, LLVM::LogOp> {
1227+
struct LogOpLowering : public VectorConvertToLLVMPattern<LogOp, LLVM::LogOp> {
12501228
using Super::Super;
12511229
};
1252-
struct Log10OpLowering : public UnaryOpLLVMOpLowering<Log10Op, LLVM::Log10Op> {
1230+
struct Log10OpLowering
1231+
: public VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op> {
12531232
using Super::Super;
12541233
};
1255-
struct Log2OpLowering : public UnaryOpLLVMOpLowering<Log2Op, LLVM::Log2Op> {
1234+
struct Log2OpLowering
1235+
: public VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op> {
12561236
using Super::Super;
12571237
};
1258-
struct NegFOpLowering : public UnaryOpLLVMOpLowering<NegFOp, LLVM::FNegOp> {
1238+
struct NegFOpLowering
1239+
: public VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp> {
12591240
using Super::Super;
12601241
};
1261-
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
1242+
struct AddIOpLowering : public VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp> {
12621243
using Super::Super;
12631244
};
1264-
struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
1245+
struct SubIOpLowering : public VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp> {
12651246
using Super::Super;
12661247
};
1267-
struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
1248+
struct MulIOpLowering : public VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp> {
12681249
using Super::Super;
12691250
};
12701251
struct SignedDivIOpLowering
1271-
: public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
1252+
: public VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp> {
12721253
using Super::Super;
12731254
};
1274-
struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
1255+
struct SqrtOpLowering
1256+
: public VectorConvertToLLVMPattern<SqrtOp, LLVM::SqrtOp> {
12751257
using Super::Super;
12761258
};
12771259
struct UnsignedDivIOpLowering
1278-
: public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
1260+
: public VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp> {
12791261
using Super::Super;
12801262
};
12811263
struct SignedRemIOpLowering
1282-
: public BinaryOpLLVMOpLowering<SignedRemIOp, LLVM::SRemOp> {
1264+
: public VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp> {
12831265
using Super::Super;
12841266
};
12851267
struct UnsignedRemIOpLowering
1286-
: public BinaryOpLLVMOpLowering<UnsignedRemIOp, LLVM::URemOp> {
1268+
: public VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp> {
12871269
using Super::Super;
12881270
};
1289-
struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
1271+
struct AndOpLowering : public VectorConvertToLLVMPattern<AndOp, LLVM::AndOp> {
12901272
using Super::Super;
12911273
};
1292-
struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
1274+
struct OrOpLowering : public VectorConvertToLLVMPattern<OrOp, LLVM::OrOp> {
12931275
using Super::Super;
12941276
};
1295-
struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
1277+
struct XOrOpLowering : public VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp> {
12961278
using Super::Super;
12971279
};
1298-
struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
1280+
struct AddFOpLowering
1281+
: public VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp> {
12991282
using Super::Super;
13001283
};
1301-
struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
1284+
struct SubFOpLowering
1285+
: public VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp> {
13021286
using Super::Super;
13031287
};
1304-
struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
1288+
struct MulFOpLowering
1289+
: public VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp> {
13051290
using Super::Super;
13061291
};
1307-
struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
1292+
struct DivFOpLowering
1293+
: public VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp> {
13081294
using Super::Super;
13091295
};
1310-
struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
1296+
struct RemFOpLowering
1297+
: public VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp> {
13111298
using Super::Super;
13121299
};
13131300
struct CopySignOpLowering
1314-
: public BinaryOpLLVMOpLowering<CopySignOp, LLVM::CopySignOp> {
1301+
: public VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp> {
13151302
using Super::Super;
13161303
};
13171304
struct SelectOpLowering
@@ -1695,24 +1682,21 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
16951682
if (!vectorType)
16961683
return failure();
16971684

1698-
if (succeeded(HandleMultidimensionalVectors(
1699-
op, operands, typeConverter,
1700-
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
1701-
auto splatAttr = SplatElementsAttr::get(
1702-
mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
1703-
->getVectorNumElements()},
1704-
floatType),
1705-
floatOne);
1706-
auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
1707-
splatAttr);
1708-
auto sqrt =
1709-
rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
1710-
return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
1711-
sqrt);
1712-
},
1713-
rewriter)))
1714-
return success();
1715-
return failure();
1685+
return handleMultidimensionalVectors(
1686+
op, operands, typeConverter,
1687+
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
1688+
auto splatAttr = SplatElementsAttr::get(
1689+
mlir::VectorType::get(
1690+
{llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
1691+
floatType),
1692+
floatOne);
1693+
auto one =
1694+
rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
1695+
auto sqrt =
1696+
rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
1697+
return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
1698+
},
1699+
rewriter);
17161700
}
17171701
};
17181702

0 commit comments

Comments
 (0)