Skip to content

Commit 01e1256

Browse files
authored
[SYCL-MLIR]: Implement 'getMLIRType' for clang::BuiltinType (#7975)
The current implementation in `getMLIRType` for clang builtin types is incomplete. This PR adds support for most clang::BuiltinType and separate the implementation into a private function. In particular support for OpenCL builtin types is added. This PR also fixes one additional test in the LLVM tests suite (/SYCL/Basic/sampler/sampler.cpp). --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 6e98388 commit 01e1256

File tree

4 files changed

+246
-61
lines changed

4 files changed

+246
-61
lines changed

polygeist/tools/cgeist/Lib/CodeGenTypes.cc

Lines changed: 180 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "TypeUtils.h"
1515
#include "utils.h"
1616

17+
#include "clang/../../lib/CodeGen/CGOpenCLRuntime.h"
1718
#include "clang/../../lib/CodeGen/CodeGenModule.h"
1819
#include "clang/AST/ASTContext.h"
1920
#include "clang/CodeGen/CGFunctionInfo.h"
@@ -1430,11 +1431,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
14301431
return LLVM::LLVMStructType::getLiteral(TheModule->getContext(), Types);
14311432
}
14321433

1433-
const auto *T = QT->getUnqualifiedDesugaredType();
1434-
if (T->isVoidType()) {
1435-
mlir::OpBuilder Builder(TheModule->getContext());
1436-
return Builder.getNoneType();
1437-
}
1434+
const clang::Type *T = QT->getUnqualifiedDesugaredType();
14381435

14391436
if (const auto *AT = dyn_cast<clang::ArrayType>(T)) {
14401437
const auto *PTT = AT->getElementType()->getUnqualifiedDesugaredType();
@@ -1477,29 +1474,7 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
14771474
bool SubRef = false;
14781475
auto ET = getMLIRType(AT->getElementType(), &SubRef, AllowMerge);
14791476
int64_t Size = AT->getNumElements();
1480-
if (isa<clang::ExtVectorType>(T))
1481-
return mlir::VectorType::get(Size, ET);
1482-
1483-
if (MemRefABI && SubRef) {
1484-
auto MT = ET.cast<MemRefType>();
1485-
auto Shape2 = std::vector<int64_t>(MT.getShape());
1486-
Shape2.insert(Shape2.begin(), Size);
1487-
if (ImplicitRef)
1488-
*ImplicitRef = true;
1489-
return mlir::MemRefType::get(Shape2, MT.getElementType(),
1490-
MemRefLayoutAttrInterface(),
1491-
MT.getMemorySpace());
1492-
}
1493-
1494-
if (!MemRefABI || !AllowMerge ||
1495-
ET.isa<LLVM::LLVMPointerType, LLVM::LLVMArrayType,
1496-
LLVM::LLVMFunctionType, LLVM::LLVMStructType>())
1497-
return LLVM::LLVMFixedVectorType::get(ET, Size);
1498-
1499-
if (ImplicitRef)
1500-
*ImplicitRef = true;
1501-
1502-
return mlir::MemRefType::get({Size}, ET);
1477+
return mlir::VectorType::get(Size, ET);
15031478
}
15041479

15051480
if (const auto *FT = dyn_cast<clang::FunctionProtoType>(T)) {
@@ -1595,43 +1570,190 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
15951570
CGM.getContext().getTargetAddressSpace(PointeeType.getAddressSpace()));
15961571
}
15971572

1598-
if (T->isBuiltinType() || isa<clang::EnumType>(T)) {
1599-
if (T->isBooleanType()) {
1600-
OpBuilder Builder(TheModule->getContext());
1601-
return Builder.getIntegerType(8);
1602-
}
1603-
1604-
llvm::Type *Ty = CGM.getTypes().ConvertType(QualType(T, 0));
1573+
if (isa<clang::EnumType>(T)) {
16051574
mlir::OpBuilder Builder(TheModule->getContext());
1606-
if (Ty->isVoidTy())
1607-
return Builder.getNoneType();
1608-
if (Ty->isFloatTy())
1609-
return Builder.getF32Type();
1610-
if (Ty->isDoubleTy())
1611-
return Builder.getF64Type();
1612-
if (Ty->isX86_FP80Ty())
1613-
return Builder.getF80Type();
1614-
if (Ty->isFP128Ty())
1615-
return Builder.getF128Type();
1616-
if (Ty->is16bitFPTy()) {
1617-
CGEIST_WARNING({
1618-
if (CGM.getTarget().shouldEmitFloat16WithExcessPrecision()) {
1619-
llvm::WithColor::warning() << "Experimental usage of _Float16. Code "
1620-
"generated will be illegal "
1621-
"for this target. Use with caution.\n";
1622-
}
1623-
});
1624-
return Builder.getF16Type();
1625-
}
1626-
1627-
if (auto *IT = dyn_cast<llvm::IntegerType>(Ty))
1628-
return Builder.getIntegerType(IT->getBitWidth());
1575+
llvm::Type *Ty = CGM.getTypes().ConvertType(QualType(T, 0));
1576+
return Builder.getIntegerType(cast<llvm::IntegerType>(Ty)->getBitWidth());
16291577
}
16301578

1579+
if (T->isBuiltinType())
1580+
return getMLIRType(cast<clang::BuiltinType>(T));
1581+
16311582
LLVM_DEBUG(llvm::dbgs() << "QT: "; QT->dump(); llvm::dbgs() << "\n");
16321583
llvm_unreachable("unhandled type");
16331584
}
16341585

1586+
mlir::Type CodeGenTypes::getMLIRType(const clang::BuiltinType *BT) const {
1587+
assert(BT && "Expecting valid pointer");
1588+
1589+
mlir::OpBuilder Builder(TheModule->getContext());
1590+
mlir::LLVM::TypeFromLLVMIRTranslator TypeTranslator(*TheModule->getContext());
1591+
1592+
switch (BT->getKind()) {
1593+
case BuiltinType::Void:
1594+
return Builder.getNoneType();
1595+
1596+
case BuiltinType::ObjCId:
1597+
case BuiltinType::ObjCClass:
1598+
case BuiltinType::ObjCSel:
1599+
return Builder.getIntegerType(8);
1600+
1601+
case BuiltinType::Bool:
1602+
// TODO: boolean types should be represented as i1 rather than i8.
1603+
return Builder.getIntegerType(8);
1604+
1605+
case BuiltinType::Char_S:
1606+
case BuiltinType::Char_U:
1607+
case BuiltinType::SChar:
1608+
case BuiltinType::UChar:
1609+
case BuiltinType::Short:
1610+
case BuiltinType::UShort:
1611+
case BuiltinType::Int:
1612+
case BuiltinType::UInt:
1613+
case BuiltinType::Long:
1614+
case BuiltinType::ULong:
1615+
case BuiltinType::LongLong:
1616+
case BuiltinType::ULongLong:
1617+
case BuiltinType::WChar_S:
1618+
case BuiltinType::WChar_U:
1619+
case BuiltinType::Char8:
1620+
case BuiltinType::Char16:
1621+
case BuiltinType::Char32:
1622+
case BuiltinType::ShortAccum:
1623+
case BuiltinType::Accum:
1624+
case BuiltinType::LongAccum:
1625+
case BuiltinType::UShortAccum:
1626+
case BuiltinType::UAccum:
1627+
case BuiltinType::ULongAccum:
1628+
case BuiltinType::ShortFract:
1629+
case BuiltinType::Fract:
1630+
case BuiltinType::LongFract:
1631+
case BuiltinType::UShortFract:
1632+
case BuiltinType::UFract:
1633+
case BuiltinType::ULongFract:
1634+
case BuiltinType::SatShortAccum:
1635+
case BuiltinType::SatAccum:
1636+
case BuiltinType::SatLongAccum:
1637+
case BuiltinType::SatUShortAccum:
1638+
case BuiltinType::SatUAccum:
1639+
case BuiltinType::SatULongAccum:
1640+
case BuiltinType::SatShortFract:
1641+
case BuiltinType::SatFract:
1642+
case BuiltinType::SatLongFract:
1643+
case BuiltinType::SatUShortFract:
1644+
case BuiltinType::SatUFract:
1645+
case BuiltinType::SatULongFract:
1646+
return Builder.getIntegerType(Context.getTypeSize(BT));
1647+
1648+
case BuiltinType::Float16:
1649+
case BuiltinType::Half:
1650+
case BuiltinType::BFloat16:
1651+
return Builder.getF16Type();
1652+
1653+
case BuiltinType::Float:
1654+
return Builder.getF32Type();
1655+
1656+
case BuiltinType::Double:
1657+
return Builder.getF64Type();
1658+
1659+
case BuiltinType::LongDouble:
1660+
case BuiltinType::Float128:
1661+
case BuiltinType::Ibm128:
1662+
return Builder.getF128Type();
1663+
1664+
case BuiltinType::NullPtr:
1665+
// Model std::nullptr_t as i8*
1666+
return getPointerOrMemRefType(Builder.getIntegerType(8), 0);
1667+
1668+
case BuiltinType::UInt128:
1669+
case BuiltinType::Int128:
1670+
return Builder.getIntegerType(128);
1671+
1672+
#define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix) \
1673+
case BuiltinType::Id:
1674+
#include "clang/Basic/OpenCLImageTypes.def"
1675+
#define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix) \
1676+
case BuiltinType::Sampled##Id:
1677+
#define IMAGE_WRITE_TYPE(Type, Id, Ext)
1678+
#define IMAGE_READ_WRITE_TYPE(Type, Id, Ext)
1679+
#include "clang/Basic/OpenCLImageTypes.def"
1680+
#define EXT_OPAQUE_TYPE(ExtType, Id, Ext) case BuiltinType::Id:
1681+
#include "clang/Basic/OpenCLExtensionTypes.def"
1682+
case BuiltinType::OCLSampler:
1683+
case BuiltinType::OCLEvent:
1684+
case BuiltinType::OCLClkEvent:
1685+
case BuiltinType::OCLQueue:
1686+
case BuiltinType::OCLReserveID:
1687+
return TypeTranslator.translateType(
1688+
CGM.getOpenCLRuntime().convertOpenCLSpecificType(BT));
1689+
1690+
case BuiltinType::SveInt8:
1691+
case BuiltinType::SveUint8:
1692+
case BuiltinType::SveInt8x2:
1693+
case BuiltinType::SveUint8x2:
1694+
case BuiltinType::SveInt8x3:
1695+
case BuiltinType::SveUint8x3:
1696+
case BuiltinType::SveInt8x4:
1697+
case BuiltinType::SveUint8x4:
1698+
case BuiltinType::SveInt16:
1699+
case BuiltinType::SveUint16:
1700+
case BuiltinType::SveInt16x2:
1701+
case BuiltinType::SveUint16x2:
1702+
case BuiltinType::SveInt16x3:
1703+
case BuiltinType::SveUint16x3:
1704+
case BuiltinType::SveInt16x4:
1705+
case BuiltinType::SveUint16x4:
1706+
case BuiltinType::SveInt32:
1707+
case BuiltinType::SveUint32:
1708+
case BuiltinType::SveInt32x2:
1709+
case BuiltinType::SveUint32x2:
1710+
case BuiltinType::SveInt32x3:
1711+
case BuiltinType::SveUint32x3:
1712+
case BuiltinType::SveInt32x4:
1713+
case BuiltinType::SveUint32x4:
1714+
case BuiltinType::SveInt64:
1715+
case BuiltinType::SveUint64:
1716+
case BuiltinType::SveInt64x2:
1717+
case BuiltinType::SveUint64x2:
1718+
case BuiltinType::SveInt64x3:
1719+
case BuiltinType::SveUint64x3:
1720+
case BuiltinType::SveInt64x4:
1721+
case BuiltinType::SveUint64x4:
1722+
case BuiltinType::SveBool:
1723+
case BuiltinType::SveFloat16:
1724+
case BuiltinType::SveFloat16x2:
1725+
case BuiltinType::SveFloat16x3:
1726+
case BuiltinType::SveFloat16x4:
1727+
case BuiltinType::SveFloat32:
1728+
case BuiltinType::SveFloat32x2:
1729+
case BuiltinType::SveFloat32x3:
1730+
case BuiltinType::SveFloat32x4:
1731+
case BuiltinType::SveFloat64:
1732+
case BuiltinType::SveFloat64x2:
1733+
case BuiltinType::SveFloat64x3:
1734+
case BuiltinType::SveFloat64x4:
1735+
case BuiltinType::SveBFloat16:
1736+
case BuiltinType::SveBFloat16x2:
1737+
case BuiltinType::SveBFloat16x3:
1738+
case BuiltinType::SveBFloat16x4:
1739+
llvm_unreachable("Unexpected ARM type");
1740+
1741+
#define PPC_VECTOR_TYPE(Name, Id, Size) case BuiltinType::Id:
1742+
#include "clang/Basic/PPCTypes.def"
1743+
#define RVV_TYPE(Name, Id, SingletonId) case BuiltinType::Id:
1744+
#include "clang/Basic/RISCVVTypes.def"
1745+
llvm_unreachable("Unexpected PPC type");
1746+
1747+
case BuiltinType::Dependent:
1748+
#define BUILTIN_TYPE(Id, SingletonId)
1749+
#define PLACEHOLDER_TYPE(Id, SingletonId) case BuiltinType::Id:
1750+
#include "clang/AST/BuiltinTypes.def"
1751+
llvm_unreachable("Unexpected placeholder builtin type!");
1752+
}
1753+
1754+
llvm_unreachable("Unexpected builtin type!");
1755+
}
1756+
16351757
// Note: In principle we should always create a memref here because we want to
16361758
// avoid lowering the abstraction level at this point in the compilation flow.
16371759
// However, cgeist treats type inconsistently, it expects memref for SYCL

polygeist/tools/cgeist/Lib/CodeGenTypes.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ namespace clang {
2222
class ASTContext;
2323
class CodeGenOptions;
2424
class FunctionDecl;
25+
class GlobalDecl;
26+
class RecordDecl;
27+
class BuiltinType;
2528
class QualType;
2629
class RecordType;
27-
class RecordDecl;
2830
class Type;
29-
class GlobalDecl;
3031

3132
namespace CodeGen {
3233
class ABIInfo;
@@ -104,6 +105,8 @@ class CodeGenTypes {
104105

105106
bool getCPUAndFeaturesAttributes(clang::GlobalDecl GD,
106107
AttrBuilder &Attrs) const;
108+
109+
mlir::Type getMLIRType(const clang::BuiltinType *BT) const;
107110
};
108111

109112
} // namespace CodeGen

polygeist/tools/cgeist/Test/Verification/float16.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
// COM: sapphirerapids supports _Float16 natively
55

6-
// CHECK-EXTEND: warning: Experimental usage of _Float16.
76
// CHECK-NATIVE-NOT: warning: Experimental usage of _Float16.
87

98
// CHECK-EXTEND-LABEL: func.func @type(%arg0: f16) -> f16
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: clang++ -fsycl -fsycl-device-only -O0 -w -emit-mlir -o - %s | FileCheck %s
2+
3+
#include <sycl/sycl.hpp>
4+
5+
using namespace sycl;
6+
7+
// CHECK-LABEL: func.func @_Z13opencl_float2Dv2_f(
8+
// CHECK: %arg0: vector<2xf32> {llvm.noundef})
9+
SYCL_EXTERNAL void opencl_float2(__cl_float2 var) {}
10+
11+
// CHECK-LABEL: func.func @_Z13opencl_float4Dv4_f(
12+
// CHECK: %arg0: vector<4xf32> {llvm.noundef})
13+
SYCL_EXTERNAL void opencl_float4(__cl_float4 var) {}
14+
15+
// CHECK-LABEL: func.func @_Z15scalable_vec2_tN4sycl3_V13vecIfLi2EEE(
16+
// CHECK: %arg0: memref<?x!sycl_vec_f32_2_> {llvm.align = 8 : i64, llvm.byval = !sycl_vec_f32_2_, llvm.noundef})
17+
SYCL_EXTERNAL void scalable_vec2_t(sycl::vec<float, 2> var) {}
18+
19+
// CHECK-LABEL: func.func @_Z15scalable_vec4_tN4sycl3_V13vecIfLi4EEE(
20+
// CHECK: %arg0: memref<?x!sycl_vec_f32_4_> {llvm.align = 16 : i64, llvm.byval = !sycl_vec_f32_4_, llvm.noundef})
21+
SYCL_EXTERNAL void scalable_vec4_t(sycl::vec<float, 4> var) {}
22+
23+
// CHECK-LABEL: func.func @_Z8float128g(
24+
// CHECK: %arg0: f128 {llvm.noundef})
25+
SYCL_EXTERNAL void float128(__float128 var) {}
26+
27+
// CHECK-LABEL: func.func @_Z6int128n(
28+
// CHECK: %arg0: i128 {llvm.noundef})
29+
SYCL_EXTERNAL void int128(__int128 var) {}
30+
31+
// CHECK-LABEL: func.func @_Z19opencl_image1d_ro_t14ocl_image1d_ro(
32+
// CHECK: %arg0: !llvm.ptr<struct<"opencl.image1d_ro_t", opaque>, 1>)
33+
SYCL_EXTERNAL void opencl_image1d_ro_t(detail::opencl_image_type<1, access::mode::read, access::target::image>::type var) {}
34+
35+
// CHECK-LABEL: func.func @_Z19opencl_image1d_wo_t14ocl_image1d_wo(
36+
// CHECK: %arg0: !llvm.ptr<struct<"opencl.image1d_wo_t", opaque>, 1>)
37+
SYCL_EXTERNAL void opencl_image1d_wo_t(detail::opencl_image_type<1, access::mode::write, access::target::image>::type var) {}
38+
39+
// CHECK-LABEL: func.func @_Z25opencl_image1d_array_ro_t20ocl_image1d_array_ro(
40+
// CHECK: %arg0: !llvm.ptr<struct<"opencl.image1d_array_ro_t", opaque>, 1>)
41+
SYCL_EXTERNAL void opencl_image1d_array_ro_t(detail::opencl_image_type<1, access::mode::read, access::target::image_array>::type var) {}
42+
43+
// CHECK-LABEL: func.func @_Z25opencl_image1d_array_wo_t20ocl_image1d_array_wo(
44+
// CHECK: %arg0: !llvm.ptr<struct<"opencl.image1d_array_wo_t", opaque>, 1>)
45+
SYCL_EXTERNAL void opencl_image1d_array_wo_t(detail::opencl_image_type<1, access::mode::write, access::target::image_array>::type var) {}
46+
47+
// CHECK-LABEL: func.func @_Z16opencl_sampler_t11ocl_sampler(
48+
// CHECK: %arg0: !llvm.ptr<struct<"opencl.sampler_t", opaque>, 2>)
49+
SYCL_EXTERNAL void opencl_sampler_t(__ocl_sampler_t var) {}
50+
51+
// CHECK-LABEL: func.func @_Z33opencl_sampled_image_array1d_ro_t38__spirv_SampledImage__image1d_array_ro(
52+
// CHECK: %arg0: !llvm.ptr<struct<"spirv.SampledImage.image1d_array_ro_t.1", opaque>, 1>)
53+
SYCL_EXTERNAL void opencl_sampled_image_array1d_ro_t(__ocl_sampled_image1d_array_ro_t var) {}
54+
55+
// CHECK-LABEL: func.func @_Z12opencl_vec_tDv4_j(
56+
// CHECK: %arg0: vector<4xi32> {llvm.noundef})
57+
SYCL_EXTERNAL void opencl_vec_t(__ocl_vec_t<uint32_t, 4> var) {}
58+
59+
// CHECK-LABEL: func.func @_Z14opencl_event_t9ocl_event(
60+
// CHECK: %arg0: !llvm.ptr<struct<"opencl.event_t", opaque>, 4>)
61+
SYCL_EXTERNAL void opencl_event_t(__ocl_event_t var) {}

0 commit comments

Comments
 (0)