Skip to content

Commit d22034e

Browse files
authored
Merge pull request #140 from Xilinx/liangta.mul
add PDLL mul/div/mod/sub/exp2/log2 operators
2 parents b1eb021 + 93b750b commit d22034e

File tree

10 files changed

+1431
-101
lines changed

10 files changed

+1431
-101
lines changed

mlir/include/mlir/Dialect/PDL/IR/Builtins.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,19 @@ namespace pdl {
2626
void registerBuiltins(PDLPatternModule &pdlPattern);
2727

2828
namespace builtin {
29+
enum class BinaryOpKind {
30+
add,
31+
sub,
32+
mul,
33+
div,
34+
mod,
35+
};
36+
37+
enum class UnaryOpKind {
38+
log2,
39+
exp2,
40+
};
41+
2942
LogicalResult createDictionaryAttr(PatternRewriter &rewriter,
3043
PDLResultList &results,
3144
ArrayRef<PDLValue> args);
@@ -35,8 +48,27 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
3548
Attribute createArrayAttr(PatternRewriter &rewriter);
3649
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
3750
Attribute element);
51+
template <BinaryOpKind T>
52+
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
53+
llvm::ArrayRef<PDLValue> args);
54+
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
55+
llvm::ArrayRef<PDLValue> args);
56+
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,
57+
llvm::ArrayRef<PDLValue> args);
58+
LogicalResult mod(PatternRewriter &rewriter, PDLResultList &results,
59+
llvm::ArrayRef<PDLValue> args);
3860
LogicalResult add(PatternRewriter &rewriter, PDLResultList &results,
3961
llvm::ArrayRef<PDLValue> args);
62+
LogicalResult sub(PatternRewriter &rewriter, PDLResultList &results,
63+
llvm::ArrayRef<PDLValue> args);
64+
LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
65+
llvm::ArrayRef<PDLValue> args);
66+
LogicalResult exp2(PatternRewriter &rewriter, PDLResultList &results,
67+
llvm::ArrayRef<PDLValue> args);
68+
69+
template <BinaryOpKind T>
70+
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
71+
llvm::ArrayRef<PDLValue> args);
4072
} // namespace builtin
4173
} // namespace pdl
4274
} // namespace mlir

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 228 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
#include <cstdint>
33
#include <llvm/ADT/APFloat.h>
44
#include <llvm/ADT/APInt.h>
5+
#include <llvm/ADT/APSInt.h>
56
#include <llvm/ADT/ArrayRef.h>
7+
#include <llvm/ADT/TypeSwitch.h>
68
#include <llvm/Support/Casting.h>
79
#include <llvm/Support/ErrorHandling.h>
810
#include <mlir/Dialect/PDL/IR/Builtins.h>
@@ -56,52 +58,200 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
5658
return rewriter.getArrayAttr(values);
5759
}
5860

59-
LogicalResult add(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
60-
llvm::ArrayRef<mlir::PDLValue> args) {
61-
assert(args.size() == 2 && "Expected 2 arguments");
61+
template <UnaryOpKind T>
62+
LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
63+
ArrayRef<PDLValue> args) {
64+
assert(args.size() == 1 && "Expected one operand for unary operation");
65+
auto operandAttr = args[0].cast<Attribute>();
66+
67+
if (auto operandIntAttr = dyn_cast_or_null<IntegerAttr>(operandAttr)) {
68+
auto integerType = cast<IntegerType>(operandIntAttr.getType());
69+
auto bitWidth = integerType.getIntOrFloatBitWidth();
70+
71+
if constexpr (T == UnaryOpKind::exp2) {
72+
uint64_t resultVal =
73+
integerType.isUnsigned() || integerType.isSignless()
74+
? std::pow(2, operandIntAttr.getValue().getZExtValue())
75+
: std::pow(2, operandIntAttr.getValue().getSExtValue());
76+
77+
APInt resultInt(bitWidth, resultVal, integerType.isSigned());
78+
79+
bool isOverflow = integerType.isSigned()
80+
? resultInt.slt(operandIntAttr.getValue())
81+
: resultInt.ult(operandIntAttr.getValue());
82+
83+
if (isOverflow)
84+
return failure();
85+
86+
results.push_back(rewriter.getIntegerAttr(integerType, resultInt));
87+
} else if constexpr (T == UnaryOpKind::log2) {
88+
auto getIntegerAsAttr = [&](const APSInt &value) {
89+
int32_t log2Value = value.exactLogBase2();
90+
assert(log2Value >= 0 &&
91+
"log2 of an integer is expected to return an exact integer.");
92+
return rewriter.getIntegerAttr(
93+
integerType,
94+
APSInt(APInt(bitWidth, log2Value), integerType.isUnsigned()));
95+
};
96+
// for log2 we treat signless integer as signed
97+
if (integerType.isSignless())
98+
results.push_back(
99+
getIntegerAsAttr(APSInt(operandIntAttr.getValue(), false)));
100+
else
101+
results.push_back(getIntegerAsAttr(operandIntAttr.getAPSInt()));
102+
} else {
103+
llvm::llvm_unreachable_internal(
104+
"encountered an unsupported unary operator");
105+
return failure();
106+
}
107+
return success();
108+
}
109+
110+
if (auto operandFloatAttr = dyn_cast_or_null<FloatAttr>(operandAttr)) {
111+
// auto floatType = operandFloatAttr.getType();
112+
113+
if constexpr (T == UnaryOpKind::exp2) {
114+
// auto maxVal = APFloat::getLargest(llvm::APFloat::IEEEhalf());
115+
// auto minVal = APFloat::getSmallest(llvm::APFloat::IEEEhalf());
116+
117+
auto type = operandFloatAttr.getType();
118+
119+
return TypeSwitch<Type, LogicalResult>(type)
120+
.template Case<Float64Type>([&results, &rewriter,
121+
&operandFloatAttr](auto floatType) {
122+
APFloat resultAPFloat(
123+
std::exp2(operandFloatAttr.getValue().convertToDouble()));
124+
125+
// check overflow
126+
if (!resultAPFloat.isNormal())
127+
return failure();
128+
129+
results.push_back(rewriter.getFloatAttr(floatType, resultAPFloat));
130+
return success();
131+
})
132+
.template Case<Float32Type, Float16Type, BFloat16Type>(
133+
[&results, &rewriter, &operandFloatAttr](auto floatType) {
134+
APFloat resultAPFloat(
135+
std::exp2(operandFloatAttr.getValue().convertToFloat()));
136+
137+
// check overflow and underflow
138+
// If overflow happens, resultAPFloat is inf
139+
// If underflow happens, resultAPFloat is 0
140+
if (!resultAPFloat.isNormal())
141+
return failure();
142+
143+
results.push_back(
144+
rewriter.getFloatAttr(floatType, resultAPFloat));
145+
return success();
146+
})
147+
.Default([](Type /*type*/) { return failure(); });
148+
} else if constexpr (T == UnaryOpKind::log2) {
149+
auto minF32 = APFloat::getSmallest(llvm::APFloat::IEEEsingle());
150+
151+
APFloat resultFloat((float)operandFloatAttr.getValue().getExactLog2());
152+
results.push_back(
153+
rewriter.getFloatAttr(operandFloatAttr.getType(), resultFloat));
154+
}
155+
return success();
156+
}
157+
return failure();
158+
}
159+
160+
template <BinaryOpKind T>
161+
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
162+
llvm::ArrayRef<PDLValue> args) {
163+
assert(args.size() == 2 && "Expected two operands for binary operation");
62164
auto lhsAttr = args[0].cast<Attribute>();
63165
auto rhsAttr = args[1].cast<Attribute>();
64166

65-
// Integer
66167
if (auto lhsIntAttr = dyn_cast_or_null<IntegerAttr>(lhsAttr)) {
67168
auto rhsIntAttr = dyn_cast_or_null<IntegerAttr>(rhsAttr);
68-
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType())
169+
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType()) {
69170
return failure();
171+
}
70172

71173
auto integerType = lhsIntAttr.getType();
72-
73-
bool isOverflow;
74-
llvm::APInt resultAPInt;
75-
if (integerType.isUnsignedInteger() || integerType.isSignlessInteger()) {
76-
resultAPInt =
77-
lhsIntAttr.getValue().uadd_ov(rhsIntAttr.getValue(), isOverflow);
174+
APInt resultAPInt;
175+
bool isOverflow = false;
176+
if constexpr (T == BinaryOpKind::add) {
177+
if (integerType.isSignlessInteger() || integerType.isUnsignedInteger()) {
178+
resultAPInt =
179+
lhsIntAttr.getValue().uadd_ov(rhsIntAttr.getValue(), isOverflow);
180+
} else {
181+
resultAPInt =
182+
lhsIntAttr.getValue().sadd_ov(rhsIntAttr.getValue(), isOverflow);
183+
}
184+
} else if constexpr (T == BinaryOpKind::sub) {
185+
if (integerType.isSignlessInteger() || integerType.isUnsignedInteger()) {
186+
resultAPInt =
187+
lhsIntAttr.getValue().usub_ov(rhsIntAttr.getValue(), isOverflow);
188+
} else {
189+
resultAPInt =
190+
lhsIntAttr.getValue().ssub_ov(rhsIntAttr.getValue(), isOverflow);
191+
}
192+
} else if constexpr (T == BinaryOpKind::mul) {
193+
if (integerType.isSignlessInteger() || integerType.isUnsignedInteger()) {
194+
resultAPInt =
195+
lhsIntAttr.getValue().umul_ov(rhsIntAttr.getValue(), isOverflow);
196+
} else {
197+
resultAPInt =
198+
lhsIntAttr.getValue().smul_ov(rhsIntAttr.getValue(), isOverflow);
199+
}
200+
} else if constexpr (T == BinaryOpKind::div) {
201+
if (integerType.isSignlessInteger() || integerType.isUnsignedInteger()) {
202+
resultAPInt = lhsIntAttr.getValue().udiv(rhsIntAttr.getValue());
203+
} else {
204+
resultAPInt =
205+
lhsIntAttr.getValue().sdiv_ov(rhsIntAttr.getValue(), isOverflow);
206+
}
207+
} else if constexpr (T == BinaryOpKind::mod) {
208+
if (integerType.isSignlessInteger() || integerType.isUnsignedInteger()) {
209+
resultAPInt = lhsIntAttr.getValue().urem(rhsIntAttr.getValue());
210+
} else {
211+
resultAPInt = lhsIntAttr.getValue().srem(rhsIntAttr.getValue());
212+
}
78213
} else {
79-
resultAPInt =
80-
lhsIntAttr.getValue().sadd_ov(rhsIntAttr.getValue(), isOverflow);
214+
assert(false && "Unsupported binary operator");
81215
}
82216

83-
if (isOverflow) {
217+
if (isOverflow)
84218
return failure();
85-
}
86219

87220
results.push_back(rewriter.getIntegerAttr(integerType, resultAPInt));
88221
return success();
89222
}
90223

91-
// Float
92224
if (auto lhsFloatAttr = dyn_cast_or_null<FloatAttr>(lhsAttr)) {
93225
auto rhsFloatAttr = dyn_cast_or_null<FloatAttr>(rhsAttr);
94-
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType())
226+
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType()) {
95227
return failure();
228+
}
96229

97230
APFloat lhsVal = lhsFloatAttr.getValue();
98231
APFloat rhsVal = rhsFloatAttr.getValue();
99232
APFloat resultVal(lhsVal);
100233
auto floatType = lhsFloatAttr.getType();
101234

102-
bool isOverflow =
103-
resultVal.add(rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
104-
if (isOverflow) {
235+
APFloat::opStatus operationStatus;
236+
if constexpr (T == BinaryOpKind::add) {
237+
operationStatus =
238+
resultVal.add(rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
239+
} else if constexpr (T == BinaryOpKind::sub) {
240+
operationStatus =
241+
resultVal.subtract(rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
242+
} else if constexpr (T == BinaryOpKind::mul) {
243+
operationStatus =
244+
resultVal.multiply(rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
245+
} else if constexpr (T == BinaryOpKind::div) {
246+
operationStatus =
247+
resultVal.divide(rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
248+
} else if constexpr (T == BinaryOpKind::mod) {
249+
operationStatus = resultVal.mod(rhsVal);
250+
} else {
251+
assert(false && "Unsupported binary operator");
252+
}
253+
254+
if (operationStatus != APFloat::opOK) {
105255
return failure();
106256
}
107257

@@ -110,6 +260,41 @@ LogicalResult add(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
110260
}
111261
return failure();
112262
}
263+
264+
LogicalResult add(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
265+
llvm::ArrayRef<mlir::PDLValue> args) {
266+
return binaryOp<BinaryOpKind::add>(rewriter, results, args);
267+
}
268+
269+
LogicalResult sub(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
270+
llvm::ArrayRef<mlir::PDLValue> args) {
271+
return binaryOp<BinaryOpKind::sub>(rewriter, results, args);
272+
}
273+
274+
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
275+
llvm::ArrayRef<PDLValue> args) {
276+
return binaryOp<BinaryOpKind::mul>(rewriter, results, args);
277+
}
278+
279+
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,
280+
llvm::ArrayRef<PDLValue> args) {
281+
return binaryOp<BinaryOpKind::div>(rewriter, results, args);
282+
}
283+
284+
LogicalResult mod(PatternRewriter &rewriter, PDLResultList &results,
285+
ArrayRef<PDLValue> args) {
286+
return binaryOp<BinaryOpKind::mod>(rewriter, results, args);
287+
}
288+
289+
LogicalResult exp2(PatternRewriter &rewriter, PDLResultList &results,
290+
llvm::ArrayRef<PDLValue> args) {
291+
return unaryOp<UnaryOpKind::exp2>(rewriter, results, args);
292+
}
293+
294+
LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
295+
llvm::ArrayRef<PDLValue> args) {
296+
return unaryOp<UnaryOpKind::log2>(rewriter, results, args);
297+
}
113298
} // namespace builtin
114299

115300
void registerBuiltins(PDLPatternModule &pdlPattern) {
@@ -128,6 +313,27 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
128313
pdlPattern.registerConstraintFunctionWithResults(
129314
"__builtin_addEntryToDictionaryAttr_constraint",
130315
addEntryToDictionaryAttr);
131-
pdlPattern.registerConstraintFunctionWithResults("__builtin_add", add);
316+
pdlPattern.registerRewriteFunction("__builtin_mulRewrite", mul);
317+
pdlPattern.registerRewriteFunction("__builtin_divRewrite", div);
318+
pdlPattern.registerRewriteFunction("__builtin_modRewrite", mod);
319+
pdlPattern.registerRewriteFunction("__builtin_addRewrite", add);
320+
pdlPattern.registerRewriteFunction("__builtin_subRewrite", sub);
321+
pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2);
322+
pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2);
323+
324+
pdlPattern.registerConstraintFunctionWithResults("__builtin_mulConstraint",
325+
mul);
326+
pdlPattern.registerConstraintFunctionWithResults("__builtin_divConstraint",
327+
div);
328+
pdlPattern.registerConstraintFunctionWithResults("__builtin_modConstraint",
329+
mod);
330+
pdlPattern.registerConstraintFunctionWithResults("__builtin_addConstraint",
331+
add);
332+
pdlPattern.registerConstraintFunctionWithResults("__builtin_subConstraint",
333+
sub);
334+
pdlPattern.registerConstraintFunctionWithResults("__builtin_log2Constraint",
335+
log2);
336+
pdlPattern.registerConstraintFunctionWithResults("__builtin_exp2Constraint",
337+
exp2);
132338
}
133-
} // namespace mlir::pdl
339+
} // namespace mlir::pdl

mlir/lib/Tools/PDLL/Parser/Lexer.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ int Lexer::getNextChar() {
139139
return static_cast<unsigned char>(curChar);
140140
case 0: {
141141
// A nul character in the stream is either the end of the current buffer
142-
// or a random nul in the file. Disambiguate that here.
142+
// or a random nul in the file. Disambiguate that here.r
143143
if (curPtr - 1 != curBuffer.end())
144144
return 0;
145145

@@ -197,7 +197,7 @@ Token Lexer::lexToken() {
197197
++curPtr;
198198
return formToken(Token::arrow, tokStart);
199199
}
200-
return emitError(tokStart, "unexpected character");
200+
return formToken(Token::sub, tokStart);
201201
case ':':
202202
return formToken(Token::colon, tokStart);
203203
case ',':
@@ -220,6 +220,10 @@ Token Lexer::lexToken() {
220220
return formToken(Token::l_square, tokStart);
221221
case ']':
222222
return formToken(Token::r_square, tokStart);
223+
case '*':
224+
return formToken(Token::mul, tokStart);
225+
case '%':
226+
return formToken(Token::mod, tokStart);
223227
case '+':
224228
return formToken(Token::add, tokStart);
225229
case '<':
@@ -243,7 +247,7 @@ Token Lexer::lexToken() {
243247
return emitError(tokStart, "unterminated comment, expected '*/'");
244248
continue;
245249
}
246-
return emitError(tokStart, "unexpected character");
250+
return formToken(Token::div, tokStart);
247251

248252
// Ignore whitespace characters.
249253
case 0:
@@ -371,6 +375,8 @@ Token Lexer::lexIdentifier(const char *tokStart) {
371375
.Case("ValueRange", Token::kw_ValueRange)
372376
.Case("with", Token::kw_with)
373377
.Case("_", Token::underscore)
378+
.Case("log2", Token::log2)
379+
.Case("exp2", Token::exp2)
374380
.Default(Token::identifier);
375381
return Token(kind, str);
376382
}

0 commit comments

Comments
 (0)