Skip to content

[SYCL-MLIR] Generate i1 for scalar Boolean values #8010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 23, 2023
Merged
9 changes: 7 additions & 2 deletions polygeist/tools/cgeist/Lib/CGDecl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ using namespace mlir;

ValueCategory MLIRScanner::VisitVarDecl(clang::VarDecl *Decl) {
Decl = Decl->getCanonicalDecl();
mlir::Type SubType = Glob.getTypes().getMLIRType(Decl->getType());
mlir::Type SubType = Glob.getTypes().getMLIRTypeForMem(Decl->getType());
const unsigned MemType = Decl->hasAttr<clang::CUDASharedAttr>() ? 5 : 0;
bool LLVMABI = false, IsArray = false;

Expand Down Expand Up @@ -72,7 +72,12 @@ ValueCategory MLIRScanner::VisitVarDecl(clang::VarDecl *Decl) {
Init->dump();
assert(false);
}
SubType = InitExpr.val.getType();
const auto InitType = InitExpr.val.getType();
const auto IsNotBoolean = !InitType.isInteger(1);
assert((IsNotBoolean || SubType.isInteger(8)) &&
"Wrong Boolean initialization");
if (IsNotBoolean)
SubType = InitType;
}
}
} else if (auto *Ava = Decl->getAttr<clang::AlignValueAttr>()) {
Expand Down
31 changes: 23 additions & 8 deletions polygeist/tools/cgeist/Lib/CGExpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,14 @@ mlir::Attribute MLIRScanner::InitializeValueByInitListExpr(mlir::Value ToInit,
ValueCategory(ToInit, /*isReference*/ true).store(Builder, Sub, IsArray);
if (!Sub.isReference)
if (auto MT = ToInit.getType().dyn_cast<MemRefType>()) {
if (auto Cop = Sub.val.getDefiningOp<arith::ConstantIntOp>())
return DenseElementsAttr::get(
RankedTensorType::get(std::vector<int64_t>({1}),
MT.getElementType()),
Cop.getValue());
if (auto Cop = Sub.val.getDefiningOp<arith::ConstantIntOp>()) {
const auto C = Cop.getValue();
const auto CT = C.getType();
const auto ET = MT.getElementType();
assert((CT == ET || (CT.isInteger(1) && ET.isInteger(8))) &&
"Expecting same width but for boolean values");
return DenseElementsAttr::get(RankedTensorType::get(1, CT), C);
}
if (auto Cop = Sub.val.getDefiningOp<arith::ConstantFloatOp>())
return DenseElementsAttr::get(
RankedTensorType::get(std::vector<int64_t>({1}),
Expand Down Expand Up @@ -648,12 +651,24 @@ ValueCategory MLIRScanner::VisitLambdaExpr(clang::LambdaExpr *Expr) {
auto Val = Result.val;

if (auto MT = Val.getType().dyn_cast<MemRefType>()) {
auto ET = MT.getElementType();
if (ET.isInteger(1)) {
ET = Builder.getIntegerType(8);
const auto Zero = getConstantIndex(0);
const auto Scalar =
ValueCategory(Builder.create<memref::LoadOp>(Loc, Val, Zero),
/*IsReference*/ false)
.IntCast(Builder, Loc, ET, /*IsSigned*/ false);
Val = Builder.create<memref::AllocaOp>(
Loc, MemRefType::get(1, ET, MT.getLayout(), MT.getMemorySpace()));
Builder.create<memref::StoreOp>(Loc, Scalar.val, Val, Zero);
}
auto Shape = std::vector<int64_t>(MT.getShape());
Shape[0] = ShapedType::kDynamic;
Val = Builder.create<memref::CastOp>(
Loc,
MemRefType::get(Shape, MT.getElementType(),
MemRefLayoutAttrInterface(), MT.getMemorySpace()),
MemRefType::get(Shape, ET, MemRefLayoutAttrInterface(),
MT.getMemorySpace()),
Val);
}

Expand Down Expand Up @@ -2630,7 +2645,7 @@ ValueCategory MLIRScanner::EmitPointerArithmetic(const BinOpInfo &Info) {
return Result.BitCast(Builder, Loc, Pointer.val.getType());
}

auto ElemTy = Glob.getTypes().getMLIRType(ElementType);
auto ElemTy = Glob.getTypes().getMLIRTypeForMem(ElementType);
if (CGM.getLangOpts().isSignedOverflowDefined()) {
if (Optional<Value> NewIndex =
castSubIndexOpIndex(Builder, Loc, Pointer, Index.val, IsSigned))
Expand Down
45 changes: 30 additions & 15 deletions polygeist/tools/cgeist/Lib/CodeGenTypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1246,6 +1246,23 @@ void CodeGenTypes::constructAttributeList(
AttrList.addAttrs(FuncAttrsBuilder, RetAttrsBuilder, ArgAttrs);
}

// Returns the MLIR type to use when a value of type QT lives in memory.
//
// The scalar type is obtained from getMLIRType (ImplicitRef and AllowMerge are
// forwarded to it untouched). An i1 result that does not originate from a
// _BitInt type is replaced by an integer whose width is the size reported by
// the AST context for QT; every other type is returned as is.
mlir::Type CodeGenTypes::getMLIRTypeForMem(clang::QualType QT,
                                           bool *ImplicitRef, bool AllowMerge) {
  assert(!QT->isConstantMatrixType() && "Unsupported type");

  const mlir::Type ScalarTy = getMLIRType(QT, ImplicitRef, AllowMerge);

  // TODO: Check for the boolean vector case.

  // _BitInt types and anything that is not i1 keep their scalar form.
  if (QT->isBitIntType() || !ScalarTy.isInteger(1))
    return ScalarTy;

  // Remaining case: an i1 scalar (a bool); store it using the size the target
  // specifies for QT.
  const auto StorageBits = Context.getTypeSize(QT);
  return mlir::IntegerType::get(TheModule->getContext(), StorageBits);
}

mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
bool AllowMerge) {
if (const auto *ET = dyn_cast<clang::ElaboratedType>(QT))
Expand Down Expand Up @@ -1274,7 +1291,8 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,

if (const auto *DT = dyn_cast<clang::DecayedType>(QT)) {
bool AssumeRef = false;
auto MLIRTy = getMLIRType(DT->getOriginalType(), &AssumeRef, AllowMerge);
clang::QualType OrigTy = DT->getOriginalType();
auto MLIRTy = getMLIRType(OrigTy, &AssumeRef, AllowMerge);
if (MemRefABI && AssumeRef) {
// Constant array types like `int A[30][20]` will be converted to LLVM
// type `[20 x i32]* %0`, which has the outermost dimension size erased,
Expand All @@ -1283,14 +1301,11 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
// specifically handle this case by unwrapping the clang-adjusted
// type, to get the corresponding ConstantArrayType with the full
// dimensions.
if (MemRefFullRank) {
clang::QualType OrigTy = DT->getOriginalType();
if (OrigTy->isConstantArrayType()) {
SmallVector<int64_t, 4> Shape;
clang::QualType ElemTy;
getConstantArrayShapeAndElemType(OrigTy, Shape, ElemTy);
return mlir::MemRefType::get(Shape, getMLIRType(ElemTy));
}
if (MemRefFullRank && OrigTy->isConstantArrayType()) {
SmallVector<int64_t, 4> Shape;
clang::QualType ElemTy;
getConstantArrayShapeAndElemType(OrigTy, Shape, ElemTy);
return mlir::MemRefType::get(Shape, getMLIRTypeForMem(ElemTy));
}

// If -memref-fullrank is unset or it cannot be fulfilled.
Expand Down Expand Up @@ -1400,7 +1415,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
if (CXRD) {
for (auto F : CXRD->bases()) {
bool SubRef = false;
auto Ty = getMLIRType(F.getType(), &SubRef, /*AllowMerge*/ false);
auto Ty = getMLIRTypeForMem(F.getType(), &SubRef, /*AllowMerge*/ false);
assert(!SubRef);
InnerLLVM |= Ty.isa<LLVM::LLVMPointerType, LLVM::LLVMStructType,
LLVM::LLVMArrayType>();
Expand All @@ -1410,7 +1425,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,

for (auto *F : RT->getDecl()->fields()) {
bool SubRef = false;
auto Ty = getMLIRType(F->getType(), &SubRef, /*AllowMerge*/ false);
auto Ty = getMLIRTypeForMem(F->getType(), &SubRef, /*AllowMerge*/ false);
assert(!SubRef);
InnerLLVM |= Ty.isa<LLVM::LLVMPointerType, LLVM::LLVMStructType,
LLVM::LLVMArrayType>();
Expand Down Expand Up @@ -1441,7 +1456,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
}

bool SubRef = false;
auto ET = getMLIRType(AT->getElementType(), &SubRef, AllowMerge);
auto ET = getMLIRTypeForMem(AT->getElementType(), &SubRef, AllowMerge);
int64_t Size = ShapedType::kDynamic;
if (const auto *CAT = dyn_cast<clang::ConstantArrayType>(AT))
Size = CAT->getSize().getZExtValue();
Expand Down Expand Up @@ -1508,7 +1523,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
}

bool SubRef = false;
auto SubType = getMLIRType(PointeeType, &SubRef, /*AllowMerge*/ true);
auto SubType = getMLIRTypeForMem(PointeeType, &SubRef, /*AllowMerge*/ true);

if (!MemRefABI ||
SubType.isa<LLVM::LLVMArrayType, LLVM::LLVMStructType,
Expand Down Expand Up @@ -1599,8 +1614,8 @@ mlir::Type CodeGenTypes::getMLIRType(const clang::BuiltinType *BT) const {
return Builder.getIntegerType(8);

case BuiltinType::Bool:
// TODO: boolean types should be represented as i1 rather than i8.
return Builder.getIntegerType(8);
// Note that we always return bool as i1 for use as a scalar type.
return Builder.getIntegerType(1);

case BuiltinType::Char_S:
case BuiltinType::Char_U:
Expand Down
8 changes: 8 additions & 0 deletions polygeist/tools/cgeist/Lib/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ class CodeGenTypes {
unsigned &CallingConv, bool AttrOnCallSite,
bool IsThunk);

/// Convert the Clang type QT into the mlir::Type used for its memory
/// representation.
///
/// This differs from getMLIRType in that it is used to convert to the memory
/// representation for a type. For example, the scalar representation for
/// _Bool is i1, but the memory representation is usually i8 or i32, depending
/// on the target.
///
/// ImplicitRef and AllowMerge are passed through to getMLIRType and carry the
/// same meaning as they do there.
mlir::Type getMLIRTypeForMem(clang::QualType QT, bool *ImplicitRef = nullptr,
bool AllowMerge = true);
// TODO: Possibly create a SYCLTypeCache
mlir::Type getMLIRType(clang::QualType QT, bool *ImplicitRef = nullptr,
bool AllowMerge = true);
Expand Down
34 changes: 17 additions & 17 deletions polygeist/tools/cgeist/Lib/TypeUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
llvm::SmallVector<mlir::Type, 4> Body;

for (const auto *Field : RD->fields())
Body.push_back(CGT.getMLIRType(Field->getType()));
Body.push_back(CGT.getMLIRTypeForMem(Field->getType()));

if (const auto *CTS =
llvm::dyn_cast<clang::ClassTemplateSpecializationDecl>(RD)) {
Expand All @@ -193,7 +193,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
}
case TypeEnum::Accessor: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
const auto Dim =
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
const auto MemAccessMode = static_cast<mlir::sycl::MemoryAccessMode>(
Expand All @@ -206,7 +206,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
// TODO: we should push the non-empty base classes in a more general way.
if (MemTargetMode == mlir::sycl::MemoryTargetMode::Local) {
assert(Body.empty());
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
}

return mlir::sycl::AccessorType::get(CGT.getModule()->getContext(), Type,
Expand All @@ -227,7 +227,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
}
case TypeEnum::Atomic: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
const int AddrSpace =
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
return mlir::sycl::AtomicType::get(
Expand All @@ -236,12 +236,12 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
}
case TypeEnum::GetOp: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
return mlir::sycl::GetOpType::get(CGT.getModule()->getContext(), Type);
}
case TypeEnum::GetScalarOp: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
return mlir::sycl::GetScalarOpType::get(CGT.getModule()->getContext(),
Type, Body);
}
Expand All @@ -260,7 +260,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
case TypeEnum::ID: {
const auto Dim =
CTS->getTemplateArgs().get(0).getAsIntegral().getExtValue();
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
return mlir::sycl::IDType::get(CGT.getModule()->getContext(), Dim, Body);
}
case TypeEnum::ItemBase: {
Expand All @@ -287,7 +287,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
}
case TypeEnum::LocalAccessorBase: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
const auto Dim =
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
const auto MemAccessMode = static_cast<mlir::sycl::MemoryAccessMode>(
Expand All @@ -297,10 +297,10 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
}
case TypeEnum::LocalAccessor: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
const auto Dim =
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
return mlir::sycl::LocalAccessorType::get(CGT.getModule()->getContext(),
Type, Dim, Body);
}
Expand All @@ -316,7 +316,7 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
}
case TypeEnum::MultiPtr: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
const int AddrSpace =
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
const int DecAccess =
Expand All @@ -343,36 +343,36 @@ mlir::Type getSYCLType(const clang::RecordType *RT,
case TypeEnum::Range: {
const auto Dim =
CTS->getTemplateArgs().get(0).getAsIntegral().getExtValue();
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
return mlir::sycl::RangeType::get(CGT.getModule()->getContext(), Dim,
Body);
}
case TypeEnum::TupleCopyAssignableValueHolder: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
const auto IsTriviallyCopyAssignable =
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
Body.push_back(CGT.getMLIRType(CTS->bases_begin()->getType()));
Body.push_back(CGT.getMLIRTypeForMem(CTS->bases_begin()->getType()));
return mlir::sycl::TupleCopyAssignableValueHolderType::get(
CGT.getModule()->getContext(), Type, IsTriviallyCopyAssignable, Body);
}
case TypeEnum::TupleValueHolder: {
const auto Type =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
return mlir::sycl::TupleValueHolderType::get(
CGT.getModule()->getContext(), Type, Body);
}
case TypeEnum::Vec: {
const auto ElemType =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType());
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType());
const auto NumElems =
CTS->getTemplateArgs().get(1).getAsIntegral().getExtValue();
return mlir::sycl::VecType::get(CGT.getModule()->getContext(), ElemType,
NumElems, Body);
}
case TypeEnum::SwizzleOp: {
const auto VecType =
CGT.getMLIRType(CTS->getTemplateArgs().get(0).getAsType())
CGT.getMLIRTypeForMem(CTS->getTemplateArgs().get(0).getAsType())
.cast<mlir::sycl::VecType>();
const auto IndexesArgs = CTS->getTemplateArgs().get(4).getPackAsArray();
SmallVector<int> Indexes;
Expand Down
14 changes: 11 additions & 3 deletions polygeist/tools/cgeist/Lib/clang-mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1740,7 +1740,7 @@ MLIRASTConsumer::getOrCreateGlobal(const clang::ValueDecl &VD,
return Globals[Name];

const bool IsArray = isa<clang::ArrayType>(VD.getType());
const Type MLIRType = getTypes().getMLIRType(VD.getType());
const Type MLIRType = getTypes().getMLIRTypeForMem(VD.getType());
const clang::VarDecl *Var = cast<clang::VarDecl>(VD).getCanonicalDecl();
const unsigned MemSpace =
CGM.getContext().getTargetAddressSpace(CGM.GetGlobalVarAddressSpace(Var));
Expand Down Expand Up @@ -1840,10 +1840,18 @@ MLIRASTConsumer::getOrCreateGlobal(const clang::ValueDecl &VD,

auto Op = VC.val.getDefiningOp<arith::ConstantOp>();
assert(Op && "Could not find the initializer constant expression");
const auto IT = Op.getType();
const auto ET = VarTy.getElementType();
if (IT != ET) {
assert(IT.isInteger(1) && ET.isInteger(8) &&
"Expecting same width but for boolean values");
Op = VC.IntCast(Builder, Op.getLoc(), ET, false)
.val.getDefiningOp<arith::ConstantOp>();
assert(Op && "Folding failed");
}

auto InitialVal = SplatElementsAttr::get(
RankedTensorType::get(VarTy.getShape(), VarTy.getElementType()),
Op.getValue());
RankedTensorType::get(VarTy.getShape(), ET), Op.getValue());
GlobalOp.setInitialValueAttr(InitialVal);
}
}
Expand Down
Loading