Skip to content

Commit d356268

Browse files
authored
[SYCL-MLIR] Generate i1 for scalar Boolean values (#8010)
Keep using i8 for memory Boolean values. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent 37eb288 commit d356268

File tree

14 files changed

+402
-86
lines changed

14 files changed

+402
-86
lines changed

polygeist/tools/cgeist/Lib/CGDecl.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using namespace mlir;
2626

2727
ValueCategory MLIRScanner::VisitVarDecl(clang::VarDecl *Decl) {
2828
Decl = Decl->getCanonicalDecl();
29-
mlir::Type SubType = Glob.getTypes().getMLIRType(Decl->getType());
29+
mlir::Type SubType = Glob.getTypes().getMLIRTypeForMem(Decl->getType());
3030
const unsigned MemType = Decl->hasAttr<clang::CUDASharedAttr>() ? 5 : 0;
3131
bool LLVMABI = false, IsArray = false;
3232

@@ -72,7 +72,12 @@ ValueCategory MLIRScanner::VisitVarDecl(clang::VarDecl *Decl) {
7272
Init->dump();
7373
assert(false);
7474
}
75-
SubType = InitExpr.val.getType();
75+
const auto InitType = InitExpr.val.getType();
76+
const auto IsNotBoolean = !InitType.isInteger(1);
77+
assert((IsNotBoolean || SubType.isInteger(8)) &&
78+
"Wrong Boolean initialization");
79+
if (IsNotBoolean)
80+
SubType = InitType;
7681
}
7782
}
7883
} else if (auto *Ava = Decl->getAttr<clang::AlignValueAttr>()) {

polygeist/tools/cgeist/Lib/CGExpr.cc

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -338,11 +338,14 @@ mlir::Attribute MLIRScanner::InitializeValueByInitListExpr(mlir::Value ToInit,
338338
ValueCategory(ToInit, /*isReference*/ true).store(Builder, Sub, IsArray);
339339
if (!Sub.isReference)
340340
if (auto MT = ToInit.getType().dyn_cast<MemRefType>()) {
341-
if (auto Cop = Sub.val.getDefiningOp<arith::ConstantIntOp>())
342-
return DenseElementsAttr::get(
343-
RankedTensorType::get(std::vector<int64_t>({1}),
344-
MT.getElementType()),
345-
Cop.getValue());
341+
if (auto Cop = Sub.val.getDefiningOp<arith::ConstantIntOp>()) {
342+
const auto C = Cop.getValue();
343+
const auto CT = C.getType();
344+
const auto ET = MT.getElementType();
345+
assert((CT == ET || (CT.isInteger(1) && ET.isInteger(8))) &&
346+
"Expecting same width but for boolean values");
347+
return DenseElementsAttr::get(RankedTensorType::get(1, CT), C);
348+
}
346349
if (auto Cop = Sub.val.getDefiningOp<arith::ConstantFloatOp>())
347350
return DenseElementsAttr::get(
348351
RankedTensorType::get(std::vector<int64_t>({1}),
@@ -648,12 +651,24 @@ ValueCategory MLIRScanner::VisitLambdaExpr(clang::LambdaExpr *Expr) {
648651
auto Val = Result.val;
649652

650653
if (auto MT = Val.getType().dyn_cast<MemRefType>()) {
654+
auto ET = MT.getElementType();
655+
if (ET.isInteger(1)) {
656+
ET = Builder.getIntegerType(8);
657+
const auto Zero = getConstantIndex(0);
658+
const auto Scalar =
659+
ValueCategory(Builder.create<memref::LoadOp>(Loc, Val, Zero),
660+
/*IsReference*/ false)
661+
.IntCast(Builder, Loc, ET, /*IsSigned*/ false);
662+
Val = Builder.create<memref::AllocaOp>(
663+
Loc, MemRefType::get(1, ET, MT.getLayout(), MT.getMemorySpace()));
664+
Builder.create<memref::StoreOp>(Loc, Scalar.val, Val, Zero);
665+
}
651666
auto Shape = std::vector<int64_t>(MT.getShape());
652667
Shape[0] = ShapedType::kDynamic;
653668
Val = Builder.create<memref::CastOp>(
654669
Loc,
655-
MemRefType::get(Shape, MT.getElementType(),
656-
MemRefLayoutAttrInterface(), MT.getMemorySpace()),
670+
MemRefType::get(Shape, ET, MemRefLayoutAttrInterface(),
671+
MT.getMemorySpace()),
657672
Val);
658673
}
659674

@@ -2630,7 +2645,7 @@ ValueCategory MLIRScanner::EmitPointerArithmetic(const BinOpInfo &Info) {
26302645
return Result.BitCast(Builder, Loc, Pointer.val.getType());
26312646
}
26322647

2633-
auto ElemTy = Glob.getTypes().getMLIRType(ElementType);
2648+
auto ElemTy = Glob.getTypes().getMLIRTypeForMem(ElementType);
26342649
if (CGM.getLangOpts().isSignedOverflowDefined()) {
26352650
if (Optional<Value> NewIndex =
26362651
castSubIndexOpIndex(Builder, Loc, Pointer, Index.val, IsSigned))

polygeist/tools/cgeist/Lib/CodeGenTypes.cc

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,6 +1246,23 @@ void CodeGenTypes::constructAttributeList(
12461246
AttrList.addAttrs(FuncAttrsBuilder, RetAttrsBuilder, ArgAttrs);
12471247
}
12481248

1249+
mlir::Type CodeGenTypes::getMLIRTypeForMem(clang::QualType QT,
1250+
bool *ImplicitRef, bool AllowMerge) {
1251+
assert(!QT->isConstantMatrixType() && "Unsupported type");
1252+
1253+
const auto R = getMLIRType(QT, ImplicitRef, AllowMerge);
1254+
1255+
// TODO: Check for the boolean vector case.
1256+
1257+
// If this is a bool type map this integer to the target-specified size.
1258+
if (!QT->isBitIntType() && R.isInteger(1))
1259+
return mlir::IntegerType::get(TheModule->getContext(),
1260+
Context.getTypeSize(QT));
1261+
1262+
// Else, don't map it.
1263+
return R;
1264+
}
1265+
12491266
mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
12501267
bool AllowMerge) {
12511268
if (const auto *ET = dyn_cast<clang::ElaboratedType>(QT))
@@ -1274,7 +1291,8 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
12741291

12751292
if (const auto *DT = dyn_cast<clang::DecayedType>(QT)) {
12761293
bool AssumeRef = false;
1277-
auto MLIRTy = getMLIRType(DT->getOriginalType(), &AssumeRef, AllowMerge);
1294+
clang::QualType OrigTy = DT->getOriginalType();
1295+
auto MLIRTy = getMLIRType(OrigTy, &AssumeRef, AllowMerge);
12781296
if (MemRefABI && AssumeRef) {
12791297
// Constant array types like `int A[30][20]` will be converted to LLVM
12801298
// type `[20 x i32]* %0`, which has the outermost dimension size erased,
@@ -1283,14 +1301,11 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
12831301
// specifically handle this case by unwrapping the clang-adjusted
12841302
// type, to get the corresponding ConstantArrayType with the full
12851303
// dimensions.
1286-
if (MemRefFullRank) {
1287-
clang::QualType OrigTy = DT->getOriginalType();
1288-
if (OrigTy->isConstantArrayType()) {
1289-
SmallVector<int64_t, 4> Shape;
1290-
clang::QualType ElemTy;
1291-
getConstantArrayShapeAndElemType(OrigTy, Shape, ElemTy);
1292-
return mlir::MemRefType::get(Shape, getMLIRType(ElemTy));
1293-
}
1304+
if (MemRefFullRank && OrigTy->isConstantArrayType()) {
1305+
SmallVector<int64_t, 4> Shape;
1306+
clang::QualType ElemTy;
1307+
getConstantArrayShapeAndElemType(OrigTy, Shape, ElemTy);
1308+
return mlir::MemRefType::get(Shape, getMLIRTypeForMem(ElemTy));
12941309
}
12951310

12961311
// If -memref-fullrank is unset or it cannot be fulfilled.
@@ -1400,7 +1415,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
14001415
if (CXRD) {
14011416
for (auto F : CXRD->bases()) {
14021417
bool SubRef = false;
1403-
auto Ty = getMLIRType(F.getType(), &SubRef, /*AllowMerge*/ false);
1418+
auto Ty = getMLIRTypeForMem(F.getType(), &SubRef, /*AllowMerge*/ false);
14041419
assert(!SubRef);
14051420
InnerLLVM |= Ty.isa<LLVM::LLVMPointerType, LLVM::LLVMStructType,
14061421
LLVM::LLVMArrayType>();
@@ -1410,7 +1425,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
14101425

14111426
for (auto *F : RT->getDecl()->fields()) {
14121427
bool SubRef = false;
1413-
auto Ty = getMLIRType(F->getType(), &SubRef, /*AllowMerge*/ false);
1428+
auto Ty = getMLIRTypeForMem(F->getType(), &SubRef, /*AllowMerge*/ false);
14141429
assert(!SubRef);
14151430
InnerLLVM |= Ty.isa<LLVM::LLVMPointerType, LLVM::LLVMStructType,
14161431
LLVM::LLVMArrayType>();
@@ -1441,7 +1456,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
14411456
}
14421457

14431458
bool SubRef = false;
1444-
auto ET = getMLIRType(AT->getElementType(), &SubRef, AllowMerge);
1459+
auto ET = getMLIRTypeForMem(AT->getElementType(), &SubRef, AllowMerge);
14451460
int64_t Size = ShapedType::kDynamic;
14461461
if (const auto *CAT = dyn_cast<clang::ConstantArrayType>(AT))
14471462
Size = CAT->getSize().getZExtValue();
@@ -1508,7 +1523,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
15081523
}
15091524

15101525
bool SubRef = false;
1511-
auto SubType = getMLIRType(PointeeType, &SubRef, /*AllowMerge*/ true);
1526+
auto SubType = getMLIRTypeForMem(PointeeType, &SubRef, /*AllowMerge*/ true);
15121527

15131528
if (!MemRefABI ||
15141529
SubType.isa<LLVM::LLVMArrayType, LLVM::LLVMStructType,
@@ -1599,8 +1614,8 @@ mlir::Type CodeGenTypes::getMLIRType(const clang::BuiltinType *BT) const {
15991614
return Builder.getIntegerType(8);
16001615

16011616
case BuiltinType::Bool:
1602-
// TODO: boolean types should be represented as i1 rather than i8.
1603-
return Builder.getIntegerType(8);
1617+
// Note that we always return bool as i1 for use as a scalar type.
1618+
return Builder.getIntegerType(1);
16041619

16051620
case BuiltinType::Char_S:
16061621
case BuiltinType::Char_U:

polygeist/tools/cgeist/Lib/CodeGenTypes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ class CodeGenTypes {
8383
unsigned &CallingConv, bool AttrOnCallSite,
8484
bool IsThunk);
8585

86+
/// Convert type T into an mlir::Type.
87+
///
88+
/// This differs from getMLIRType in that it is used to convert to the memory
89+
/// representation for a type. For example, the scalar representation for
90+
/// _Bool is i1, but the memory representation is usually i8 or i32, depending
91+
/// on the target.
92+
mlir::Type getMLIRTypeForMem(clang::QualType QT, bool *ImplicitRef = nullptr,
93+
bool AllowMerge = true);
8694
// TODO: Possibly create a SYCLTypeCache
8795
mlir::Type getMLIRType(clang::QualType QT, bool *ImplicitRef = nullptr,
8896
bool AllowMerge = true);

polygeist/tools/cgeist/Lib/TypeUtils.cc

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
176176
llvm::SmallVector<mlir::Type, 4> Body;
177177

178178
for (const auto *Field : RD->fields())
179-
Body.push_back(CGT.getMLIRType(Field->getType()));
179+
Body.push_back(CGT.getMLIRTypeForMem(Field->getType()));
180180

181181
if (const auto *CTS =
182182
llvm::dyn_cast<clang::ClassTemplateSpecializationDecl>(RD)) {
@@ -193,7 +193,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
193193
}
194194
case TypeEnum::Accessor: {
195195
const auto Type =
196-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
196+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
197197
const auto Dim =
198198
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
199199
const auto MemAccessMode = static_cast<mlir::sycl::MemoryAccessMode>(
@@ -206,7 +206,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
206206
// TODO: we should push the non-empty base classes in a more general way.
207207
if (MemTargetMode == mlir::sycl::MemoryTargetMode::Local) {
208208
assert(Body.empty());
209-
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
209+
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
210210
}
211211

212212
return mlir::sycl::AccessorType::get(CGT.getModule()->getContext(), Type,
@@ -227,7 +227,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
227227
}
228228
case TypeEnum::Atomic: {
229229
const auto Type =
230-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
230+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
231231
const int AddrSpace =
232232
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
233233
return mlir::sycl::AtomicType::get(
@@ -236,12 +236,12 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
236236
}
237237
case TypeEnum::GetOp: {
238238
const auto Type =
239-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
239+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
240240
return mlir::sycl::GetOpType::get(CGT.getModule()->getContext(), Type);
241241
}
242242
case TypeEnum::GetScalarOp: {
243243
const auto Type =
244-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
244+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
245245
return mlir::sycl::GetScalarOpType::get(CGT.getModule()->getContext(),
246246
Type, Body);
247247
}
@@ -260,7 +260,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
260260
case TypeEnum::ID: {
261261
const auto Dim =
262262
CTS->getTemplateArgs().get(0).getAsIntegral().getExtValue();
263-
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
263+
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
264264
return mlir::sycl::IDType::get(CGT.getModule()->getContext(), Dim, Body);
265265
}
266266
case TypeEnum::ItemBase: {
@@ -287,7 +287,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
287287
}
288288
case TypeEnum::LocalAccessorBase: {
289289
const auto Type =
290-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
290+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
291291
const auto Dim =
292292
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
293293
const auto MemAccessMode = static_cast<mlir::sycl::MemoryAccessMode>(
@@ -297,10 +297,10 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
297297
}
298298
case TypeEnum::LocalAccessor: {
299299
const auto Type =
300-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
300+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
301301
const auto Dim =
302302
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
303-
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
303+
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
304304
return mlir::sycl::LocalAccessorType::get(CGT.getModule()->getContext(),
305305
Type, Dim, Body);
306306
}
@@ -316,7 +316,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
316316
}
317317
case TypeEnum::MultiPtr: {
318318
const auto Type =
319-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
319+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
320320
const int AddrSpace =
321321
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
322322
const int DecAccess =
@@ -343,36 +343,36 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
343343
case TypeEnum::Range: {
344344
const auto Dim =
345345
CTS->getTemplateArgs().get(0).getAsIntegral().getExtValue();
346-
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
346+
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
347347
return mlir::sycl::RangeType::get(CGT.getModule()->getContext(), Dim,
348348
Body);
349349
}
350350
case TypeEnum::TupleCopyAssignableValueHolder: {
351351
const auto Type =
352-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
352+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
353353
const auto IsTriviallyCopyAssignable =
354354
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
355-
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
355+
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
356356
return mlir::sycl::TupleCopyAssignableValueHolderType::get(
357357
CGT.getModule()->getContext(), Type, IsTriviallyCopyAssignable, Body);
358358
}
359359
case TypeEnum::TupleValueHolder: {
360360
const auto Type =
361-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
361+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
362362
return mlir::sycl::TupleValueHolderType::get(
363363
CGT.getModule()->getContext(), Type, Body);
364364
}
365365
case TypeEnum::Vec: {
366366
const auto ElemType =
367-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
367+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
368368
const auto NumElems =
369369
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
370370
return mlir::sycl::VecType::get(CGT.getModule()->getContext(), ElemType,
371371
NumElems, Body);
372372
}
373373
case TypeEnum::SwizzleOp: {
374374
const auto VecType =
375-
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType())
375+
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType())
376376
.cast<mlir::sycl::VecType>();
377377
const auto IndexesArgs = CTS->getTemplateArgs().get(4).getPackAsArray();
378378
SmallVector<int> Indexes;

polygeist/tools/cgeist/Lib/clang-mlir.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,7 +1740,7 @@ MLIRASTConsumer::getOrCreateGlobal(const clang::ValueDecl &VD,
17401740
return Globals[Name];
17411741

17421742
const bool IsArray = isa<clang::ArrayType>(VD.getType());
1743-
const Type MLIRType = getTypes().getMLIRType(VD.getType());
1743+
const Type MLIRType = getTypes().getMLIRTypeForMem(VD.getType());
17441744
const clang::VarDecl *Var = cast<clang::VarDecl>(VD).getCanonicalDecl();
17451745
const unsigned MemSpace =
17461746
CGM.getContext().getTargetAddressSpace(CGM.GetGlobalVarAddressSpace(Var));
@@ -1840,10 +1840,18 @@ MLIRASTConsumer::getOrCreateGlobal(const clang::ValueDecl &VD,
18401840

18411841
auto Op = VC.val.getDefiningOp<arith::ConstantOp>();
18421842
assert(Op && "Could not find the initializer constant expression");
1843+
const auto IT = Op.getType();
1844+
const auto ET = VarTy.getElementType();
1845+
if (IT != ET) {
1846+
assert(IT.isInteger(1) && ET.isInteger(8) &&
1847+
"Expecting same width but for boolean values");
1848+
Op = VC.IntCast(Builder, Op.getLoc(), ET, false)
1849+
.val.getDefiningOp<arith::ConstantOp>();
1850+
assert(Op && "Folding failed");
1851+
}
18431852

18441853
auto InitialVal = SplatElementsAttr::get(
1845-
RankedTensorType::get(VarTy.getShape(), VarTy.getElementType()),
1846-
Op.getValue());
1854+
RankedTensorType::get(VarTy.getShape(), ET), Op.getValue());
18471855
GlobalOp.setInitialValueAttr(InitialVal);
18481856
}
18491857
}

0 commit comments

Comments
 (0)