Skip to content

Commit 68b6f39

Browse files
authored
[MLIR][AMDGPU] Fix bug in GatherToLDSOpLowering, get the correct MemRefType for destination (#142915)
This PR fixes a bug in GatherToLDSOpLowering: the destination's MemRefType was mistakenly taken from the source operand instead of the destination operand. Additionally, some related typos are corrected. CC: @krzysz00 @umangyadav @lialan
1 parent bd33eef commit 68b6f39

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

llvm/docs/AMDGPUUsage.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,12 +1215,12 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
12151215
denormalization mode, enabled traps, and floating point exceptions.
12161216
The format is a 64-bit concatenation of the MODE and TRAPSTS registers.
12171217

1218-
:ref:`llvm.set.fpenv<int_set_fpenv>` Sets the floating point environment to the specifies state.
1218+
:ref:`llvm.set.fpenv<int_set_fpenv>` Sets the floating point environment to the specified state.
12191219
llvm.amdgcn.load.to.lds.p<1/7> Loads values from global memory (either in the form of a global pointer or
12201220
a raw fat buffer pointer) to LDS. The size of the data copied can be 1, 2,
12211221
or 4 bytes (and gfx950 also allows 12 or 16 bytes). The LDS pointer
12221222
argument should be wavefront-uniform; the global pointer need not be.
1223-
The LDS pointer is implicitly offset by 4 * lane_id bytes for sies <= 4 bytes
1223+
The LDS pointer is implicitly offset by 4 * lane_id bytes for sizes <= 4 bytes
12241224
and 16 * lane_id bytes for larger sizes. This lowers to `global_load_lds`,
12251225
`buffer_load_* ... lds`, or `global_load__* ... lds` depending on address
12261226
space and architecture. `amdgcn.global.load.lds` has the same semantics as

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1101,7 +1101,7 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
11011101
Location loc = op.getLoc();
11021102

11031103
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1104-
auto dstMemRefType = cast<MemRefType>(op.getSrc().getType());
1104+
auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
11051105

11061106
// TODO: instead of only transferring one element per thread, we could
11071107
// augment it to transfer multiple elements per thread by issuing multiple

mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
3131
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
3232
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
3333

34-
// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
35-
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
34+
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
35+
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
3636
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
3737

3838
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -65,8 +65,8 @@ func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrs
6565
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
6666
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
6767

68-
// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
69-
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
68+
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
69+
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
7070
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
7171

7272
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -103,8 +103,8 @@ func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_add
103103
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
104104
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
105105

106-
// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
107-
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
106+
// CHECK: %[[C128:.*]] = llvm.mlir.constant(128 : index) : i64
107+
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C128]] : i64
108108
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
109109

110110
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -130,7 +130,9 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
130130
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
131131
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
132132
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
133-
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
133+
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
134+
// CHECK: %[[DSTIDX:.*]] = llvm.mul %[[DSTIDX_CAST]], %[[C64]] : i64
135+
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX]]]
134136
// CHECK: rocdl.load.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
135137
%alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
136138
%c0 = arith.constant 0 : index
@@ -166,8 +168,8 @@ func.func @fat_buffer_load_to_rocdl_f32(%global : memref<128x72xf32, #amdgpu_fat
166168
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
167169
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
168170

169-
// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
170-
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
171+
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
172+
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
171173
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
172174

173175
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]

0 commit comments

Comments
 (0)