Skip to content

Commit 6e98388

Browse files
[SYCL-MLIR] Fix store integer to vector pointer (#8028)
This PR fixes the bug reported in #7822 (review). Signed-off-by: Tsang, Whitney <[email protected]>
1 parent fbe1c86 commit 6e98388

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

polygeist/tools/cgeist/Lib/ValueCategory.cc

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,17 @@ void ValueCategory::store(mlir::OpBuilder &builder, mlir::Value toStore) const {
129129
builder.create<LLVM::NullOp>(nt.getLoc(), pt.getElementType());
130130
}
131131
}
132+
133+
if (Index) {
134+
auto ElemTy =
135+
pt.getElementType().cast<mlir::VectorType>().getElementType();
136+
assert(ElemTy == toStore.getType() &&
137+
"Vector insertion element mismatch");
138+
ValueCategory Vec{builder.create<mlir::LLVM::LoadOp>(loc, val), false};
139+
Vec = Vec.InsertElement(builder, loc, toStore, *Index);
140+
toStore = Vec.val;
141+
}
142+
132143
if (toStore.getType() != pt.getElementType()) {
133144
if (auto mt = toStore.getType().dyn_cast<MemRefType>()) {
134145
if (auto spt =
@@ -157,8 +168,10 @@ void ValueCategory::store(mlir::OpBuilder &builder, mlir::Value toStore) const {
157168
assert(mt.getShape().size() == 1 && "must have size 1");
158169

159170
if (Index) {
160-
auto VT = mt.getElementType().cast<mlir::VectorType>().getElementType();
161-
assert(VT == toStore.getType() && "Vector insertion element mismatch");
171+
auto ElemTy =
172+
mt.getElementType().cast<mlir::VectorType>().getElementType();
173+
assert(ElemTy == toStore.getType() &&
174+
"Vector insertion element mismatch");
162175
const auto C0 = builder.createOrFold<arith::ConstantIntOp>(
163176
loc, 0, builder.getI64Type());
164177
ValueCategory Vec{builder.createOrFold<memref::LoadOp>(loc, val, C0),

polygeist/tools/cgeist/Test/Verification/sycl/structvec.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@ struct structvec {
1717
// CHECK-LABEL: func.func @_Z10test_store9structvecic(%arg0: !llvm.ptr<struct<(vector<2xi8>)>> {llvm.align = 2 : i64, llvm.byval = !llvm.struct<(vector<2xi8>)>, llvm.noundef}, %arg1: i32 {llvm.noundef}, %arg2: i8 {llvm.noundef, llvm.signext}) -> !llvm.struct<(vector<2xi8>)>
1818
// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64
1919
// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.struct<(vector<2xi8>)> : (i64) -> !llvm.ptr<struct<(vector<2xi8>)>>
20-
// CHECK-NEXT: %1 = llvm.addrspacecast %0 : !llvm.ptr<struct<(vector<2xi8>)>> to !llvm.ptr<struct<(vector<2xi8>)>, 4>
21-
// CHECK-NEXT: %2 = llvm.addrspacecast %arg0 : !llvm.ptr<struct<(vector<2xi8>)>> to !llvm.ptr<struct<(vector<2xi8>)>, 4>
22-
// CHECK-NEXT: call @_ZN9structvecC1EOS_(%1, %2) : (!llvm.ptr<struct<(vector<2xi8>)>, 4>, !llvm.ptr<struct<(vector<2xi8>)>, 4>) -> ()
23-
// CHECK-NEXT: %3 = llvm.load %0 : !llvm.ptr<struct<(vector<2xi8>)>>
24-
// CHECK-NEXT: return %3 : !llvm.struct<(vector<2xi8>)>
20+
// CHECK-NEXT: %1 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr<struct<(vector<2xi8>)>>) -> !llvm.ptr<vector<2xi8>>
21+
// CHECK-NEXT: %2 = llvm.load %1 : !llvm.ptr<vector<2xi8>>
22+
// CHECK-NEXT: %3 = vector.insertelement %arg2, %2[%arg1 : i32] : vector<2xi8>
23+
// CHECK-NEXT: llvm.store %3, %1 : !llvm.ptr<vector<2xi8>>
24+
// CHECK-NEXT: %4 = llvm.addrspacecast %0 : !llvm.ptr<struct<(vector<2xi8>)>> to !llvm.ptr<struct<(vector<2xi8>)>, 4>
25+
// CHECK-NEXT: %5 = llvm.addrspacecast %arg0 : !llvm.ptr<struct<(vector<2xi8>)>> to !llvm.ptr<struct<(vector<2xi8>)>, 4>
26+
// CHECK-NEXT: call @_ZN9structvecC1EOS_(%4, %5) : (!llvm.ptr<struct<(vector<2xi8>)>, 4>, !llvm.ptr<struct<(vector<2xi8>)>, 4>) -> ()
27+
// CHECK-NEXT: %6 = llvm.load %0 : !llvm.ptr<struct<(vector<2xi8>)>>
28+
// CHECK-NEXT: return %6 : !llvm.struct<(vector<2xi8>)>
2529
// CHECK-NEXT: }
2630

2731
// CHECK-LABEL: func.func @_ZN9structvecC1EOS_(%arg0: !llvm.ptr<struct<(vector<2xi8>)>, 4> {llvm.align = 2 : i64, llvm.dereferenceable_or_null = 2 : i64, llvm.noundef}, %arg1: !llvm.ptr<struct<(vector<2xi8>)>, 4> {llvm.align = 2 : i64, llvm.dereferenceable = 2 : i64, llvm.noundef})
@@ -72,7 +76,29 @@ SYCL_EXTERNAL structvec test_store(structvec sv, int idx, char el) {
7276
// CHECK-NEXT: call @_ZN9structvecC1EOS_(%17, %18) : (!llvm.ptr<struct<(vector<2xi8>)>, 4>, !llvm.ptr<struct<(vector<2xi8>)>, 4>) -> ()
7377
// CHECK-NEXT: %19 = llvm.load %0 : !llvm.ptr<struct<(vector<2xi8>)>>
7478
// CHECK-NEXT: return %19 : !llvm.struct<(vector<2xi8>)>
75-
// CHECK-NEXT: }
79+
// CHECK-NEXT: }
80+
81+
// CHECK-LABEL: func.func @_ZN9structvecC1ESt16initializer_listIcE(%arg0: !llvm.ptr<struct<(vector<2xi8>)>, 4> {llvm.align = 2 : i64, llvm.dereferenceable_or_null = 2 : i64, llvm.noundef}, %arg1: !llvm.ptr<!llvm.struct<(memref<?xi8, 4>, i64)>> {llvm.align = 8 : i64, llvm.byval = !llvm.struct<(memref<?xi8, 4>, i64)>, llvm.noundef})
82+
// CHECK-DAG: %c2 = arith.constant 2 : index
83+
// CHECK-DAG: %c0 = arith.constant 0 : index
84+
// CHECK-DAG: %c1 = arith.constant 1 : index
85+
// CHECK-DAG: %c0_i8 = arith.constant 0 : i8
86+
// CHECK-NEXT: scf.for %arg2 = %c0 to %c2 step %c1 {
87+
// CHECK-NEXT: %0 = arith.index_cast %arg2 : index to i32
88+
// CHECK-NEXT: %1 = llvm.addrspacecast %arg1 : !llvm.ptr<!llvm.struct<(memref<?xi8, 4>, i64)>> to !llvm.ptr<!llvm.struct<(memref<?xi8, 4>, i64)>, 4>
89+
// CHECK-NEXT: %2 = func.call @_ZNKSt16initializer_listIcE5beginEv(%1) : (!llvm.ptr<!llvm.struct<(memref<?xi8, 4>, i64)>, 4>) -> memref<?xi8, 4>
90+
// CHECK-NEXT: %3 = arith.index_castui %0 : i32 to index
91+
// CHECK-NEXT: %4 = memref.load %2[%3] : memref<?xi8, 4>
92+
// CHECK-NEXT: %5 = arith.cmpi ne, %4, %c0_i8 : i8
93+
// CHECK-NEXT: %6 = arith.extui %5 : i1 to i32
94+
// CHECK-NEXT: %7 = arith.trunci %6 : i32 to i8
95+
// CHECK-NEXT: %8 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr<struct<(vector<2xi8>)>, 4>) -> !llvm.ptr<vector<2xi8>, 4>
96+
// CHECK-NEXT: %9 = llvm.load %8 : !llvm.ptr<vector<2xi8>, 4>
97+
// CHECK-NEXT: %10 = vector.insertelement %7, %9[%0 : i32] : vector<2xi8>
98+
// CHECK-NEXT: llvm.store %10, %8 : !llvm.ptr<vector<2xi8>, 4>
99+
// CHECK-NEXT: }
100+
// CHECK-NEXT: return
101+
// CHECK-NEXT: }
76102

77103
SYCL_EXTERNAL structvec test_init() {
78104
structvec sv{0, 1};

0 commit comments

Comments
 (0)