Skip to content

Commit 48f9e5b

Browse files
[SYCL-MLIR] Use VectorType for ext_vector_type (#7342)
Globals of `ext_vector_type` type used to represent as `memref<3xi64>`, which lowers to array type pointer in LLVM. This can cause link failures when files are compiled from different compilers. This PR changes globals of `ext_vector_type` type to be represented as `memref<vector<3xi64>>`, which lowers to vector type pointer, the same type as directly compile from `clang`. Note: After this PR, `parallel_for.cpp` can run successfully. Test case added in intel/llvm-test-suite#1374. Signed-off-by: Tsang, Whitney <[email protected]>
1 parent 62eaebe commit 48f9e5b

File tree

5 files changed

+70
-37
lines changed

5 files changed

+70
-37
lines changed

polygeist/tools/cgeist/Lib/CGExpr.cc

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,16 @@ MLIRScanner::VisitExtVectorElementExpr(clang::ExtVectorElementExpr *expr) {
2929
expr->getEncodedElementAccess(indices);
3030
assert(indices.size() == 1 &&
3131
"The support for higher dimensions to be implemented.");
32-
auto idx = castToIndex(getMLIRLocation(expr->getAccessorLoc()),
33-
builder.create<ConstantIntOp>(loc, indices[0], 32));
3432
assert(base.isReference);
35-
base.isReference = false;
36-
auto mt = base.val.getType().cast<MemRefType>();
37-
auto shape = std::vector<int64_t>(mt.getShape());
38-
if (shape.size() == 1) {
39-
shape[0] = -1;
40-
} else {
41-
shape.erase(shape.begin());
42-
}
43-
auto mt0 =
44-
mlir::MemRefType::get(shape, mt.getElementType(),
45-
MemRefLayoutAttrInterface(), mt.getMemorySpace());
46-
base.val = builder.create<polygeist::SubIndexOp>(loc, mt0, base.val,
47-
getConstantIndex(0));
48-
return CommonArrayLookup(base, idx, base.isReference);
33+
assert(base.val.getType().isa<MemRefType>() &&
34+
"Expecting ExtVectorElementExpr to have memref type");
35+
auto MT = base.val.getType().cast<MemRefType>();
36+
assert(MT.getElementType().isa<mlir::VectorType>() &&
37+
"Expecting ExtVectorElementExpr to have memref of vector elements");
38+
auto Idx = builder.create<ConstantIntOp>(loc, indices[0], 64);
39+
mlir::Value Val = base.getValue(builder);
40+
return ValueCategory(builder.create<LLVM::ExtractElementOp>(loc, Val, Idx),
41+
/*IsReference*/ false);
4942
}
5043

5144
ValueCategory MLIRScanner::VisitConstantExpr(clang::ConstantExpr *expr) {
@@ -1690,7 +1683,6 @@ ValueCategory MLIRScanner::VisitCastExpr(CastExpr *E) {
16901683
lres.dump();
16911684
}
16921685
#endif
1693-
assert(prev.isReference);
16941686
return ValueCategory(lres, /*isReference*/ false);
16951687
}
16961688
case clang::CastKind::CK_IntegralToFloating: {

polygeist/tools/cgeist/Lib/CodeGenTypes.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,6 +1453,8 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
14531453
bool SubRef = false;
14541454
auto ET = getMLIRType(AT->getElementType(), &SubRef, AllowMerge);
14551455
int64_t Size = AT->getNumElements();
1456+
if (isa<clang::ExtVectorType>(T))
1457+
return mlir::VectorType::get(Size, ET);
14561458
if (MemRefABI && SubRef) {
14571459
auto MT = ET.cast<MemRefType>();
14581460
auto Shape2 = std::vector<int64_t>(MT.getShape());
@@ -1530,6 +1532,10 @@ mlir::Type CodeGenTypes::getMLIRType(clang::QualType QT, bool *ImplicitRef,
15301532
}
15311533

15321534
if (isa<clang::VectorType>(PTT) || isa<clang::ComplexType>(PTT)) {
1535+
if (auto VT = SubType.dyn_cast<mlir::VectorType>())
1536+
// FIXME: We should create memref of rank 0.
1537+
// Details: https://github.com/intel/llvm/issues/7354
1538+
return mlir::MemRefType::get(Outer, SubType);
15331539
if (SubType.isa<MemRefType>()) {
15341540
assert(SubRef);
15351541
auto MT = SubType.cast<MemRefType>();

polygeist/tools/cgeist/Lib/clang-mlir.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
#include "llvm/Support/Debug.h"
4747
#include "llvm/Support/raw_ostream.h"
4848

49+
#define DEBUG_TYPE "cgeist"
50+
4951
using namespace clang;
5052
using namespace llvm;
5153
using namespace clang::driver;
@@ -2086,19 +2088,19 @@ MLIRASTConsumer::GetOrCreateGlobal(const ValueDecl *FD, std::string prefix,
20862088
mlir::Type rt = getTypes().getMLIRType(FD->getType());
20872089
auto *VD = dyn_cast<VarDecl>(FD);
20882090
LLVM_DEBUG({
2089-
if (!VD)
2090-
FD->dump();
2091+
if (!VD) {
2092+
llvm::dbgs() << "GetOrCreateGlobal ";
2093+
VD->dump(llvm::dbgs());
2094+
}
20912095
});
20922096
VD = VD->getCanonicalDecl();
20932097
unsigned memspace = VD ? CGM.getContext().getTargetAddressSpace(
20942098
CGM.GetGlobalVarAddressSpace(VD))
20952099
: CGM.getDataLayout().getDefaultGlobalsAddressSpace();
20962100
bool isArray = isa<clang::ArrayType>(FD->getType());
2097-
bool isExtVectorType =
2098-
isa<clang::ExtVectorType>(FD->getType()->getUnqualifiedDesugaredType());
20992101

21002102
mlir::MemRefType mr;
2101-
if (!isArray && !isExtVectorType) {
2103+
if (!isArray) {
21022104
mr = mlir::MemRefType::get({}, rt, {}, memspace);
21032105
} else {
21042106
auto mt = rt.cast<mlir::MemRefType>();
Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
// RUN: cgeist %s --function=* -S | FileCheck %s
2+
// RUN: cgeist %s --function=* -S -emit-llvm | FileCheck %s --check-prefix=LLVM
3+
4+
#include <cstddef>
25

36
typedef size_t size_t_vec __attribute__((ext_vector_type(3)));
47

@@ -12,13 +15,30 @@ size_t evt2() {
1215
return stv.x;
1316
}
1417

15-
// CHECK: func.func @_Z3evtv() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
16-
// CHECK-NEXT: %alloca = memref.alloca() : memref<1x3xi32>
17-
// CHECK-NEXT: %0 = affine.load %alloca[0, 0] : memref<1x3xi32>
18-
// CHECK-NEXT: return %0 : i32
18+
// CHECK: func.func @_Z3evtv() -> i64 attributes {llvm.linkage = #llvm.linkage<external>} {
19+
// CHECK-NEXT: %c0_i64 = arith.constant 0 : i64
20+
// CHECK-NEXT: %alloca = memref.alloca() : memref<1xvector<3xi64>>
21+
// CHECK-NEXT: %0 = affine.load %alloca[0] : memref<1xvector<3xi64>>
22+
// CHECK-NEXT: %1 = llvm.extractelement %0[%c0_i64 : i64] : vector<3xi64>
23+
// CHECK-NEXT: return %1 : i64
1924
// CHECK-NEXT: }
20-
// CHECK: func.func @_Z4evt2v() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
21-
// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xi32>
22-
// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xi32>
23-
// CHECK-NEXT: return %1 : i32
25+
// CHECK: func.func @_Z4evt2v() -> i64 attributes {llvm.linkage = #llvm.linkage<external>} {
26+
// CHECK-NEXT: %c0_i64 = arith.constant 0 : i64
27+
// CHECK-NEXT: %0 = memref.get_global @stv : memref<vector<3xi64>>
28+
// CHECK-NEXT: %alloca = memref.alloca() : memref<1xindex>
29+
// CHECK-NEXT: %reshape = memref.reshape %0(%alloca) : (memref<vector<3xi64>>, memref<1xindex>) -> memref<1xvector<3xi64>>
30+
// CHECK-NEXT: %1 = affine.load %reshape[0] : memref<1xvector<3xi64>>
31+
// CHECK-NEXT: %2 = llvm.extractelement %1[%c0_i64 : i64] : vector<3xi64>
32+
// CHECK-NEXT: return %2 : i64
2433
// CHECK-NEXT: }
34+
35+
// LLVM: @stv = external global <3 x i64>
36+
// LLVM-LABEL: define i64 @_Z3evtv() !dbg !3 {
37+
// LLVM-NEXT: %1 = alloca <3 x i64>, align 32, !dbg !7
38+
// LLVM-NEXT: %2 = load <3 x i64>, <3 x i64>* %1, align 32
39+
// LLVM-NEXT: %3 = extractelement <3 x i64> %2, i64 0
40+
// LLVM-NEXT: ret i64 %3
41+
// LLVM-LABEL: define i64 @_Z4evt2v() !dbg !9 {
42+
// LLVM-NEXT: %1 = load <3 x i64>, <3 x i64>* @stv, align 32
43+
// LLVM-NEXT: %2 = extractelement <3 x i64> %1, i64 0
44+
// LLVM-NEXT: ret i64 %2

polygeist/tools/cgeist/Test/Verification/sycl/constructors.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,26 @@
1414
// CHECK-DAG: memref.global @__spirv_BuiltInNumSubgroups : memref<i32, 1>
1515
// CHECK-DAG: memref.global @__spirv_BuiltInSubgroupMaxSize : memref<i32, 1>
1616
// CHECK-DAG: memref.global @__spirv_BuiltInSubgroupSize : memref<i32, 1>
17-
// CHECK-DAG: memref.global @__spirv_BuiltInLocalInvocationId : memref<3xi64, 1>
18-
// CHECK-DAG: memref.global @__spirv_BuiltInWorkgroupId : memref<3xi64, 1>
19-
// CHECK-DAG: memref.global @__spirv_BuiltInWorkgroupSize : memref<3xi64, 1>
20-
// CHECK-DAG: memref.global @__spirv_BuiltInNumWorkgroups : memref<3xi64, 1>
21-
// CHECK-DAG: memref.global @__spirv_BuiltInGlobalOffset : memref<3xi64, 1>
22-
// CHECK-DAG: memref.global @__spirv_BuiltInGlobalSize : memref<3xi64, 1>
23-
// CHECK-DAG: memref.global @__spirv_BuiltInGlobalInvocationId : memref<3xi64, 1>
17+
// CHECK-DAG: memref.global @__spirv_BuiltInLocalInvocationId : memref<vector<3xi64>, 1>
18+
// CHECK-DAG: memref.global @__spirv_BuiltInWorkgroupId : memref<vector<3xi64>, 1>
19+
// CHECK-DAG: memref.global @__spirv_BuiltInWorkgroupSize : memref<vector<3xi64>, 1>
20+
// CHECK-DAG: memref.global @__spirv_BuiltInNumWorkgroups : memref<vector<3xi64>, 1>
21+
// CHECK-DAG: memref.global @__spirv_BuiltInGlobalOffset : memref<vector<3xi64>, 1>
22+
// CHECK-DAG: memref.global @__spirv_BuiltInGlobalSize : memref<vector<3xi64>, 1>
23+
// CHECK-DAG: memref.global @__spirv_BuiltInGlobalInvocationId : memref<vector<3xi64>, 1>
24+
25+
// CHECK-LLVM-DAG: @__spirv_BuiltInSubgroupLocalInvocationId = external addrspace(1) global i32
26+
// CHECK-LLVM-DAG: @__spirv_BuiltInSubgroupId = external addrspace(1) global i32
27+
// CHECK-LLVM-DAG: @__spirv_BuiltInNumSubgroups = external addrspace(1) global i32
28+
// CHECK-LLVM-DAG: @__spirv_BuiltInSubgroupMaxSize = external addrspace(1) global i32
29+
// CHECK-LLVM-DAG: @__spirv_BuiltInSubgroupSize = external addrspace(1) global i32
30+
// CHECK-LLVM-DAG: @__spirv_BuiltInLocalInvocationId = external addrspace(1) global <3 x i64>
31+
// CHECK-LLVM-DAG: @__spirv_BuiltInWorkgroupId = external addrspace(1) global <3 x i64>
32+
// CHECK-LLVM-DAG: @__spirv_BuiltInWorkgroupSize = external addrspace(1) global <3 x i64>
33+
// CHECK-LLVM-DAG: @__spirv_BuiltInNumWorkgroups = external addrspace(1) global <3 x i64>
34+
// CHECK-LLVM-DAG: @__spirv_BuiltInGlobalOffset = external addrspace(1) global <3 x i64>
35+
// CHECK-LLVM-DAG: @__spirv_BuiltInGlobalSize = external addrspace(1) global <3 x i64>
36+
// CHECK-LLVM-DAG: @__spirv_BuiltInGlobalInvocationId = external addrspace(1) global <3 x i64>
2437

2538
// Ensure the spirv functions that reference these globals are not filtered out
2639
// CHECK-DAG: func.func @_Z28__spirv_GlobalInvocationId_xv()

0 commit comments

Comments
 (0)