Skip to content

Commit 3489e38

Browse files
authored
[SYCL-MLIR] Fix compound assignment operators (#7285)
This commit fixes the following aspects of compound assignment operators: 1. Evaluation order: 1. RHS; 2. LHS; 3. Result. 2. In case of LHS and RHS type mismatch: cast to RHS type, operate, cast back to LHS type and store. 3. Perform _Float16 expansion before operating. 4. Reuse code from the operators being invoked. Signed-off-by: Victor Perez <[email protected]>
1 parent 46b2ba3 commit 3489e38

File tree

10 files changed

+665
-255
lines changed

10 files changed

+665
-255
lines changed

polygeist/tools/cgeist/Lib/CGExpr.cc

Lines changed: 293 additions & 21 deletions
Large diffs are not rendered by default.

polygeist/tools/cgeist/Lib/ValueCategory.cc

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,102 @@ template <typename OpTy> inline void warnUnconstrainedCast() {
281281

282282
ValueCategory ValueCategory::FPTrunc(OpBuilder &Builder,
283283
Type PromotionType) const {
284+
assert(val.getType().isa<FloatType>() &&
285+
"Expecting floating point source type");
286+
assert(PromotionType.isa<FloatType>() &&
287+
"Expecting floating point promotion type");
288+
assert(val.getType().getIntOrFloatBitWidth() >=
289+
PromotionType.getIntOrFloatBitWidth() &&
290+
"Source type must be wider than promotion type");
291+
284292
warnUnconstrainedCast<arith::TruncFOp>();
285293
return Cast<arith::TruncFOp>(Builder, PromotionType);
286294
}
287295

288296
ValueCategory ValueCategory::FPExt(OpBuilder &Builder,
289297
Type PromotionType) const {
298+
assert(val.getType().isa<FloatType>() &&
299+
"Expecting floating point source type");
300+
assert(PromotionType.isa<FloatType>() &&
301+
"Expecting floating point promotion type");
302+
assert(val.getType().getIntOrFloatBitWidth() <=
303+
PromotionType.getIntOrFloatBitWidth() &&
304+
"Source type must be narrower than promotion type");
305+
290306
warnUnconstrainedCast<arith::ExtFOp>();
291307
return Cast<arith::ExtFOp>(Builder, PromotionType);
292308
}
309+
310+
ValueCategory ValueCategory::SIToFP(OpBuilder &Builder,
311+
Type PromotionType) const {
312+
assert(val.getType().isa<IntegerType>() && "Expecting int source type");
313+
assert(PromotionType.isa<FloatType>() &&
314+
"Expecting floating point promotion type");
315+
316+
warnUnconstrainedCast<arith::SIToFPOp>();
317+
return Cast<arith::SIToFPOp>(Builder, PromotionType);
318+
}
319+
320+
ValueCategory ValueCategory::UIToFP(OpBuilder &Builder,
321+
Type PromotionType) const {
322+
assert(val.getType().isa<IntegerType>() && "Expecting int source type");
323+
assert(PromotionType.isa<FloatType>() &&
324+
"Expecting floating point promotion type");
325+
326+
warnUnconstrainedCast<arith::UIToFPOp>();
327+
return Cast<arith::UIToFPOp>(Builder, PromotionType);
328+
}
329+
330+
ValueCategory ValueCategory::FPToUI(OpBuilder &Builder,
331+
Type PromotionType) const {
332+
assert(val.getType().isa<FloatType>() &&
333+
"Expecting floating point source type");
334+
assert(PromotionType.isa<IntegerType>() &&
335+
"Expecting integer promotion type");
336+
337+
warnUnconstrainedCast<arith::FPToUIOp>();
338+
return Cast<arith::FPToUIOp>(Builder, PromotionType);
339+
}
340+
341+
ValueCategory ValueCategory::FPToSI(OpBuilder &Builder,
342+
Type PromotionType) const {
343+
assert(val.getType().isa<FloatType>() &&
344+
"Expecting floating point source type");
345+
assert(PromotionType.isa<IntegerType>() &&
346+
"Expecting integer promotion type");
347+
348+
warnUnconstrainedCast<arith::FPToSIOp>();
349+
return Cast<arith::FPToSIOp>(Builder, PromotionType);
350+
}
351+
352+
ValueCategory ValueCategory::IntCast(OpBuilder &Builder, Type PromotionType,
353+
bool IsSigned) const {
354+
assert(val.getType().isa<IntegerType>() && "Expecting integer source type");
355+
assert(PromotionType.isa<IntegerType>() &&
356+
"Expecting integer promotion type");
357+
358+
if (val.getType() == PromotionType)
359+
return *this;
360+
361+
auto SrcIntTy = val.getType().cast<IntegerType>();
362+
auto DstIntTy = PromotionType.cast<IntegerType>();
363+
364+
const unsigned SrcBits = SrcIntTy.getWidth();
365+
const unsigned DstBits = DstIntTy.getWidth();
366+
367+
auto Res = [&]() -> Value {
368+
if (SrcBits == DstBits)
369+
return Builder.createOrFold<arith::BitcastOp>(Builder.getUnknownLoc(),
370+
PromotionType, val);
371+
if (SrcBits > DstBits)
372+
return Builder.createOrFold<arith::TruncIOp>(Builder.getUnknownLoc(),
373+
PromotionType, val);
374+
if (IsSigned)
375+
return Builder.createOrFold<arith::ExtSIOp>(Builder.getUnknownLoc(),
376+
PromotionType, val);
377+
return Builder.createOrFold<arith::ExtUIOp>(Builder.getUnknownLoc(),
378+
PromotionType, val);
379+
}();
380+
381+
return {Res, /*IsReference*/ false};
382+
}

polygeist/tools/cgeist/Lib/ValueCategory.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
class ValueCategory {
1818
private:
1919
template <typename OpTy>
20-
ValueCategory Cast(mlir::OpBuilder &builder, mlir::Type PromotionType) const {
20+
ValueCategory Cast(mlir::OpBuilder &Builder, mlir::Type PromotionType) const {
2121
if (val.getType() == PromotionType)
2222
return *this;
2323
return {
24-
builder.createOrFold<OpTy>(builder.getUnknownLoc(), PromotionType, val),
24+
Builder.createOrFold<OpTy>(Builder.getUnknownLoc(), PromotionType, val),
2525
false};
2626
}
2727

@@ -35,17 +35,27 @@ class ValueCategory {
3535
ValueCategory(mlir::Value val, bool isReference);
3636

3737
// TODO: rename to 'loadVariable'? getValue seems to generic.
38-
mlir::Value getValue(mlir::OpBuilder &builder) const;
39-
void store(mlir::OpBuilder &builder, ValueCategory toStore,
38+
mlir::Value getValue(mlir::OpBuilder &Builder) const;
39+
void store(mlir::OpBuilder &Builder, ValueCategory toStore,
4040
bool isArray) const;
4141
// TODO: rename to storeVariable?
42-
void store(mlir::OpBuilder &builder, mlir::Value toStore) const;
43-
ValueCategory dereference(mlir::OpBuilder &builder) const;
42+
void store(mlir::OpBuilder &Builder, mlir::Value toStore) const;
43+
ValueCategory dereference(mlir::OpBuilder &Builder) const;
4444

45-
ValueCategory FPTrunc(mlir::OpBuilder &builder,
45+
ValueCategory FPTrunc(mlir::OpBuilder &Builder,
4646
mlir::Type PromotionType) const;
4747

48-
ValueCategory FPExt(mlir::OpBuilder &builder, mlir::Type PromotionType) const;
48+
ValueCategory FPExt(mlir::OpBuilder &Builder, mlir::Type PromotionType) const;
49+
ValueCategory IntCast(mlir::OpBuilder &Builder, mlir::Type PromotionType,
50+
bool IsSigned) const;
51+
ValueCategory SIToFP(mlir::OpBuilder &Builder,
52+
mlir::Type PromotionType) const;
53+
ValueCategory UIToFP(mlir::OpBuilder &Builder,
54+
mlir::Type PromotionType) const;
55+
ValueCategory FPToUI(mlir::OpBuilder &Builder,
56+
mlir::Type PromotionType) const;
57+
ValueCategory FPToSI(mlir::OpBuilder &Builder,
58+
mlir::Type PromotionType) const;
4959
};
5060

5161
#endif /* CLANG_MLIR_VALUE_CATEGORY */

polygeist/tools/cgeist/Lib/clang-mlir.cc

Lines changed: 0 additions & 205 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,211 +1472,6 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
14721472
return rhs;
14731473
}
14741474

1475-
case clang::BinaryOperator::Opcode::BO_AddAssign: {
1476-
assert(lhs.isReference);
1477-
auto prev = lhs.getValue(builder);
1478-
1479-
mlir::Value result;
1480-
if (auto postTy = prev.getType().dyn_cast<mlir::FloatType>()) {
1481-
mlir::Value rhsV = rhs.getValue(builder);
1482-
auto prevTy = rhsV.getType().cast<mlir::FloatType>();
1483-
if (prevTy == postTy) {
1484-
} else if (prevTy.getWidth() < postTy.getWidth()) {
1485-
rhsV = builder.create<mlir::arith::ExtFOp>(loc, postTy, rhsV);
1486-
} else {
1487-
rhsV = builder.create<mlir::arith::TruncFOp>(loc, postTy, rhsV);
1488-
}
1489-
assert(rhsV.getType() == prev.getType());
1490-
result = builder.create<AddFOp>(loc, prev, rhsV);
1491-
} else if (auto pt =
1492-
prev.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
1493-
result = builder.create<LLVM::GEPOp>(
1494-
loc, pt, prev, std::vector<mlir::Value>({rhs.getValue(builder)}));
1495-
} else if (auto postTy = prev.getType().dyn_cast<mlir::IntegerType>()) {
1496-
mlir::Value rhsV = rhs.getValue(builder);
1497-
auto prevTy = rhsV.getType().cast<mlir::IntegerType>();
1498-
if (prevTy == postTy) {
1499-
} else if (prevTy.getWidth() < postTy.getWidth()) {
1500-
if (signedType) {
1501-
rhsV = builder.create<arith::ExtSIOp>(loc, postTy, rhsV);
1502-
} else {
1503-
rhsV = builder.create<arith::ExtUIOp>(loc, postTy, rhsV);
1504-
}
1505-
} else {
1506-
rhsV = builder.create<arith::TruncIOp>(loc, postTy, rhsV);
1507-
}
1508-
assert(rhsV.getType() == prev.getType());
1509-
result = builder.create<AddIOp>(loc, prev, rhsV);
1510-
} else if (auto postTy = prev.getType().dyn_cast<mlir::MemRefType>()) {
1511-
mlir::Value rhsV = rhs.getValue(builder);
1512-
auto shape = std::vector<int64_t>(postTy.getShape());
1513-
shape[0] = -1;
1514-
postTy = mlir::MemRefType::get(shape, postTy.getElementType(),
1515-
MemRefLayoutAttrInterface(),
1516-
postTy.getMemorySpace());
1517-
auto ptradd = rhsV;
1518-
ptradd = castToIndex(loc, ptradd);
1519-
result = builder.create<polygeist::SubIndexOp>(loc, postTy, prev, ptradd);
1520-
} else {
1521-
assert(false && "Unsupported add assign type");
1522-
}
1523-
lhs.store(builder, result);
1524-
return lhs;
1525-
}
1526-
case clang::BinaryOperator::Opcode::BO_SubAssign: {
1527-
assert(lhs.isReference);
1528-
auto prev = lhs.getValue(builder);
1529-
1530-
mlir::Value result;
1531-
if (prev.getType().isa<mlir::FloatType>()) {
1532-
auto right = rhs.getValue(builder);
1533-
if (right.getType() != prev.getType()) {
1534-
auto prevTy = right.getType().cast<mlir::FloatType>();
1535-
auto postTy =
1536-
Glob.getTypes().getMLIRType(BO->getType()).cast<mlir::FloatType>();
1537-
1538-
if (prevTy.getWidth() < postTy.getWidth()) {
1539-
right = builder.create<arith::ExtFOp>(loc, postTy, right);
1540-
} else {
1541-
right = builder.create<arith::TruncFOp>(loc, postTy, right);
1542-
}
1543-
}
1544-
if (right.getType() != prev.getType()) {
1545-
BO->dump();
1546-
llvm::errs() << " p:" << prev << " r:" << right << "\n";
1547-
}
1548-
assert(right.getType() == prev.getType());
1549-
result = builder.create<SubFOp>(loc, prev, right);
1550-
} else {
1551-
result = builder.create<SubIOp>(loc, prev, rhs.getValue(builder));
1552-
}
1553-
lhs.store(builder, result);
1554-
return lhs;
1555-
}
1556-
case clang::BinaryOperator::Opcode::BO_MulAssign: {
1557-
assert(lhs.isReference);
1558-
auto prev = lhs.getValue(builder);
1559-
1560-
mlir::Value result;
1561-
if (prev.getType().isa<mlir::FloatType>()) {
1562-
auto right = rhs.getValue(builder);
1563-
if (right.getType() != prev.getType()) {
1564-
auto prevTy = right.getType().cast<mlir::FloatType>();
1565-
auto postTy =
1566-
Glob.getTypes().getMLIRType(BO->getType()).cast<mlir::FloatType>();
1567-
1568-
if (prevTy.getWidth() < postTy.getWidth()) {
1569-
right = builder.create<arith::ExtFOp>(loc, postTy, right);
1570-
} else {
1571-
right = builder.create<arith::TruncFOp>(loc, postTy, right);
1572-
}
1573-
}
1574-
if (right.getType() != prev.getType()) {
1575-
BO->dump();
1576-
llvm::errs() << " p:" << prev << " r:" << right << "\n";
1577-
}
1578-
assert(right.getType() == prev.getType());
1579-
result = builder.create<MulFOp>(loc, prev, right);
1580-
} else {
1581-
result = builder.create<MulIOp>(loc, prev, rhs.getValue(builder));
1582-
}
1583-
lhs.store(builder, result);
1584-
return lhs;
1585-
}
1586-
case clang::BinaryOperator::Opcode::BO_DivAssign: {
1587-
assert(lhs.isReference);
1588-
auto prev = lhs.getValue(builder);
1589-
1590-
mlir::Value result;
1591-
if (prev.getType().isa<mlir::FloatType>()) {
1592-
mlir::Value val = rhs.getValue(builder);
1593-
auto prevTy = val.getType().cast<mlir::FloatType>();
1594-
auto postTy = prev.getType().cast<mlir::FloatType>();
1595-
1596-
if (prevTy.getWidth() < postTy.getWidth()) {
1597-
val = builder.create<arith::ExtFOp>(loc, postTy, val);
1598-
} else if (prevTy.getWidth() > postTy.getWidth()) {
1599-
val = builder.create<arith::TruncFOp>(loc, postTy, val);
1600-
}
1601-
result = builder.create<arith::DivFOp>(loc, prev, val);
1602-
} else {
1603-
if (signedType)
1604-
result =
1605-
builder.create<arith::DivSIOp>(loc, prev, rhs.getValue(builder));
1606-
else
1607-
result =
1608-
builder.create<arith::DivUIOp>(loc, prev, rhs.getValue(builder));
1609-
}
1610-
lhs.store(builder, result);
1611-
return lhs;
1612-
}
1613-
case clang::BinaryOperator::Opcode::BO_ShrAssign: {
1614-
assert(lhs.isReference);
1615-
auto prev = lhs.getValue(builder);
1616-
1617-
mlir::Value result;
1618-
1619-
if (signedType)
1620-
result = builder.create<ShRSIOp>(loc, prev, rhs.getValue(builder));
1621-
else
1622-
result = builder.create<ShRUIOp>(loc, prev, rhs.getValue(builder));
1623-
lhs.store(builder, result);
1624-
return lhs;
1625-
}
1626-
case clang::BinaryOperator::Opcode::BO_ShlAssign: {
1627-
assert(lhs.isReference);
1628-
auto prev = lhs.getValue(builder);
1629-
1630-
mlir::Value result =
1631-
builder.create<ShLIOp>(loc, prev, rhs.getValue(builder));
1632-
lhs.store(builder, result);
1633-
return lhs;
1634-
}
1635-
case clang::BinaryOperator::Opcode::BO_RemAssign: {
1636-
assert(lhs.isReference);
1637-
auto prev = lhs.getValue(builder);
1638-
1639-
mlir::Value result;
1640-
1641-
if (prev.getType().isa<mlir::FloatType>()) {
1642-
result = builder.create<RemFOp>(loc, prev, rhs.getValue(builder));
1643-
} else {
1644-
if (signedType)
1645-
result = builder.create<RemSIOp>(loc, prev, rhs.getValue(builder));
1646-
else
1647-
result = builder.create<RemUIOp>(loc, prev, rhs.getValue(builder));
1648-
}
1649-
lhs.store(builder, result);
1650-
return lhs;
1651-
}
1652-
case clang::BinaryOperator::Opcode::BO_AndAssign: {
1653-
assert(lhs.isReference);
1654-
auto prev = lhs.getValue(builder);
1655-
1656-
mlir::Value result =
1657-
builder.create<AndIOp>(loc, prev, rhs.getValue(builder));
1658-
lhs.store(builder, result);
1659-
return lhs;
1660-
}
1661-
case clang::BinaryOperator::Opcode::BO_OrAssign: {
1662-
assert(lhs.isReference);
1663-
auto prev = lhs.getValue(builder);
1664-
1665-
mlir::Value result =
1666-
builder.create<OrIOp>(loc, prev, rhs.getValue(builder));
1667-
lhs.store(builder, result);
1668-
return lhs;
1669-
}
1670-
case clang::BinaryOperator::Opcode::BO_XorAssign: {
1671-
assert(lhs.isReference);
1672-
auto prev = lhs.getValue(builder);
1673-
1674-
mlir::Value result =
1675-
builder.create<XOrIOp>(loc, prev, rhs.getValue(builder));
1676-
lhs.store(builder, result);
1677-
return lhs;
1678-
}
1679-
16801475
default: {
16811476
BO->dump();
16821477
assert(0 && "unhandled opcode");

0 commit comments

Comments
 (0)