Skip to content

Commit 5c6fb36

Browse files
authored
[SYCL-MLIR][Opaque] Fix nd_item.get_group_linear_id lowering (#11016)
Use correct base type for `llvm.getelementptr` operation: it should be `nd_item`, not `group`. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent a9127b2 commit 5c6fb36

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

mlir-sycl/lib/Conversion/SYCLToLLVM/DPCPP.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,14 +1974,14 @@ class NDItemGetGroupLinearIDPattern
19741974
// body is dropped and we can create group type more easily.
19751975
const auto indices = GetMemberPattern<NDItemGroup>::getIndices();
19761976
assert(indices.size() == 1 && "Expecting a single index");
1977-
const auto groupTy =
1978-
cast<NdItemType>(op.getNDItem().getType().getElementType())
1979-
.getBody()[indices[0]];
1977+
auto ndItemTy = cast<NdItemType>(op.getNDItem().getType().getElementType());
1978+
const auto groupTy = ndItemTy.getBody()[indices[0]];
19801979
const auto convGroupTy = getTypeConverter()->convertType(groupTy);
1980+
Type convNDItemTy = getTypeConverter()->convertType(ndItemTy);
19811981
bool useOpaquePointers = getTypeConverter()->useOpaquePointers();
19821982
auto group = GetMemberPattern<NDItemGroup>::getRef(
1983-
rewriter, loc, convGroupTy, adaptor.getNDItem(), std::nullopt,
1984-
useOpaquePointers);
1983+
rewriter, loc, (useOpaquePointers) ? convNDItemTy : convGroupTy,
1984+
adaptor.getNDItem(), std::nullopt, useOpaquePointers);
19851985
const auto thisTy = MemRefType::get(ShapedType::kDynamic, groupTy);
19861986
// We have the already converted group, but, in order to not replicate
19871987
// `sycl.group.get_group_linear_id` conversion to LLVM, we just reuse that

mlir-sycl/test/Conversion/SYCLToLLVM/sycl-methods-to-llvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ func.func @test_3(%nd: memref<?x!sycl_nd_item_3_>) -> i64 {
10331033

10341034
// CHECK-LABEL: llvm.func @test_1(
10351035
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i64 {
1036-
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::group.1", {{.*}}>
1036+
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::nd_item.1", {{.*}}>
10371037
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr inbounds %[[VAL_1]][0, 3, 0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::group.1", {{.*}}>
10381038
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> i64
10391039
// CHECK: llvm.return %[[VAL_4]] : i64
@@ -1045,7 +1045,7 @@ func.func @test_1(%nd: memref<?x!sycl_nd_item_1_>) -> i64 {
10451045

10461046
// CHECK-LABEL: llvm.func @test_2(
10471047
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i64 {
1048-
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::group.2", {{.*}}>
1048+
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::nd_item.2", {{.*}}>
10491049
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr inbounds %[[VAL_1]][0, 3, 0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::group.2", {{.*}}>
10501050
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> i64
10511051
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr inbounds %[[VAL_1]][0, 2, 0, 0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::group.2", {{.*}}>
@@ -1063,7 +1063,7 @@ func.func @test_2(%nd: memref<?x!sycl_nd_item_2_>) -> i64 {
10631063

10641064
// CHECK-LABEL: llvm.func @test_3(
10651065
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i64 {
1066-
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::group.3", {{.*}}>
1066+
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::nd_item.3", {{.*}}>
10671067
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr inbounds %[[VAL_1]][0, 3, 0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::group.3", {{.*}}>
10681068
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> i64
10691069
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr inbounds %[[VAL_1]][0, 2, 0, 0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::group.3", {{.*}}>

0 commit comments

Comments
 (0)