@@ -47,6 +47,50 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
47
47
return success ();
48
48
}
49
49
};
50
+
51
+ // / Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
52
+ template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
53
+ class FComparePattern : public SPIRVToLLVMConversion <SPIRVOp> {
54
+ public:
55
+ using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
56
+
57
+ LogicalResult
58
+ matchAndRewrite (SPIRVOp operation, ArrayRef<Value> operands,
59
+ ConversionPatternRewriter &rewriter) const override {
60
+
61
+ auto dstType = this ->typeConverter .convertType (operation.getType ());
62
+ if (!dstType)
63
+ return failure ();
64
+
65
+ rewriter.template replaceOpWithNewOp <LLVM::FCmpOp>(
66
+ operation, dstType,
67
+ rewriter.getI64IntegerAttr (static_cast <int64_t >(predicate)),
68
+ operation.operand1 (), operation.operand2 ());
69
+ return success ();
70
+ }
71
+ };
72
+
73
+ // / Converts SPIR-V integer comparisons to llvm.icmp "predicate"
74
+ template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
75
+ class IComparePattern : public SPIRVToLLVMConversion <SPIRVOp> {
76
+ public:
77
+ using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
78
+
79
+ LogicalResult
80
+ matchAndRewrite (SPIRVOp operation, ArrayRef<Value> operands,
81
+ ConversionPatternRewriter &rewriter) const override {
82
+
83
+ auto dstType = this ->typeConverter .convertType (operation.getType ());
84
+ if (!dstType)
85
+ return failure ();
86
+
87
+ rewriter.template replaceOpWithNewOp <LLVM::ICmpOp>(
88
+ operation, dstType,
89
+ rewriter.getI64IntegerAttr (static_cast <int64_t >(predicate)),
90
+ operation.operand1 (), operation.operand2 ());
91
+ return success ();
92
+ }
93
+ };
50
94
} // namespace
51
95
52
96
// ===----------------------------------------------------------------------===//
@@ -56,19 +100,48 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
56
100
void mlir::populateSPIRVToLLVMConversionPatterns (
57
101
MLIRContext *context, LLVMTypeConverter &typeConverter,
58
102
OwningRewritePatternList &patterns) {
59
- patterns.insert <DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
60
- DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
61
- DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
62
- DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
63
- DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
64
- DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
65
- DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
66
- DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
67
- DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
68
- DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
69
- DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
70
- DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
71
- DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
72
- DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>>(
103
+ patterns.insert <
104
+ // Arithmetic ops
105
+ DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
106
+ DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
107
+ DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
108
+ DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
109
+ DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
110
+ DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
111
+ DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
112
+ DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
113
+ DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
114
+ DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
115
+ DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
116
+
117
+ // Bitwise ops
118
+ DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
119
+ DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
120
+ DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
121
+
122
+ // Comparison ops
123
+ IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
124
+ IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
125
+ FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
126
+ FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
127
+ FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
128
+ FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
129
+ FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
130
+ FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
131
+ FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
132
+ FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
133
+ FComparePattern<spirv::FUnordGreaterThanEqualOp,
134
+ LLVM::FCmpPredicate::uge>,
135
+ FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
136
+ FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
137
+ FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
138
+ IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
139
+ IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
140
+ IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
141
+ IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
142
+ IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
143
+ IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
144
+ IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
145
+ IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>>(
73
146
context, typeConverter);
74
147
}
0 commit comments