Skip to content

Commit fc148a4

Browse files
georgemitenkovantiagainst
authored andcommitted
[MLIR][SPIRVToLLVM] Added conversion for SPIR-V comparison ops
Implemented `FComparePattern` and `IComparePattern` classes that provide conversion of SPIR-V comparison ops (such as `spv.FOrdGreaterThanEqual` and others) to LLVM dialect. Also added tests in `comparison-ops-to-llvm.mlir`. Differential Revision: https://reviews.llvm.org/D81487
1 parent 1022b5e commit fc148a4

File tree

2 files changed

+440
-14
lines changed

2 files changed

+440
-14
lines changed

mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,50 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
4747
return success();
4848
}
4949
};
50+
51+
/// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
52+
template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
53+
class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
54+
public:
55+
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
56+
57+
LogicalResult
58+
matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
59+
ConversionPatternRewriter &rewriter) const override {
60+
61+
auto dstType = this->typeConverter.convertType(operation.getType());
62+
if (!dstType)
63+
return failure();
64+
65+
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
66+
operation, dstType,
67+
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
68+
operation.operand1(), operation.operand2());
69+
return success();
70+
}
71+
};
72+
73+
/// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
74+
template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
75+
class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
76+
public:
77+
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
78+
79+
LogicalResult
80+
matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
81+
ConversionPatternRewriter &rewriter) const override {
82+
83+
auto dstType = this->typeConverter.convertType(operation.getType());
84+
if (!dstType)
85+
return failure();
86+
87+
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
88+
operation, dstType,
89+
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
90+
operation.operand1(), operation.operand2());
91+
return success();
92+
}
93+
};
5094
} // namespace
5195

5296
//===----------------------------------------------------------------------===//
@@ -56,19 +100,48 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
56100
void mlir::populateSPIRVToLLVMConversionPatterns(
57101
MLIRContext *context, LLVMTypeConverter &typeConverter,
58102
OwningRewritePatternList &patterns) {
59-
patterns.insert<DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
60-
DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
61-
DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
62-
DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
63-
DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
64-
DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
65-
DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
66-
DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
67-
DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
68-
DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
69-
DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
70-
DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
71-
DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
72-
DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>>(
103+
patterns.insert<
104+
// Arithmetic ops
105+
DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
106+
DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
107+
DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
108+
DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
109+
DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
110+
DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
111+
DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
112+
DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
113+
DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
114+
DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
115+
DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
116+
117+
// Bitwise ops
118+
DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
119+
DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
120+
DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
121+
122+
// Comparison ops
123+
IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
124+
IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
125+
FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
126+
FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
127+
FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
128+
FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
129+
FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
130+
FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
131+
FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
132+
FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
133+
FComparePattern<spirv::FUnordGreaterThanEqualOp,
134+
LLVM::FCmpPredicate::uge>,
135+
FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
136+
FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
137+
FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
138+
IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
139+
IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
140+
IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
141+
IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
142+
IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
143+
IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
144+
IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
145+
IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>>(
73146
context, typeConverter);
74147
}

0 commit comments

Comments
 (0)