@@ -50,12 +50,12 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
                                  int64_t srcDim, int64_t tgtDim) {
-  /// Adjust the src index to change how often the permutation changes
-  /// if necessary.
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
   Value src = indices[srcDim];
 
-  /// We only want to permute every N iterations of the target dim where N is
-  /// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
       1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                         memrefTy.getElementTypeBitWidth()) /
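
For intuition about the permuteEveryN computation in the hunk above, here is a minimal standalone sketch of the same arithmetic. The line-size constant and the memref shape below are illustrative assumptions, not values taken from this patch:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Assumed values for illustration only; the real constant is defined
  // earlier in the file and the shape comes from the memref being optimized.
  const int64_t kSharedMemoryLineSizeBytes = 64; // assumed line size
  const int64_t tgtDimSize = 16;                 // memrefTy.getDimSize(tgtDim)
  const int64_t elemBitWidth = 16;               // e.g. f16 elements

  // Bytes touched by one step along the target dim.
  const int64_t tgtDimBytes = (tgtDimSize * elemBitWidth) / 8; // 32 bytes

  // Permute only every N iterations of the target dim, i.e.
  // N = max(1, lineSizeBytes / dimSizeBytes(tgtDim)) = max(1, 64 / 32) = 2.
  const int64_t permuteEveryN =
      std::max<int64_t>(1, kSharedMemoryLineSizeBytes / tgtDimBytes);

  std::cout << "permuteEveryN = " << permuteEveryN << "\n";
}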
@@ -110,8 +110,8 @@ static void transformIndices(OpBuilder &builder, Location loc,
       permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
 }
 
-/// Return all operations within `parentOp` that read from or write to
-/// `shmMemRef`.
+// Return all operations within `parentOp` that read from or write to
+// `shmMemRef`.
 static LogicalResult
 getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                       SmallVector<Operation *, 16> &readOps,
@@ -131,8 +131,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
       writeOps.push_back(op);
   });
 
-  /// Restrict to a supported set of ops. We also require at least 2D access,
-  /// although this could be relaxed.
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
   if (llvm::any_of(readOps, [](Operation *op) {
         return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
                    op) ||
@@ -149,23 +149,22 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
   return success();
 }
 
-mlir::LogicalResult
-mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                 Value memrefValue) {
+LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                         Value memrefValue) {
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
     return failure();
 
-  /// Abort if the given value has any sub-views; we do not do any alias
-  /// analysis.
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
   bool hasSubView = false;
   parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
   if (hasSubView)
     return failure();
 
-  /// Check if this is necessary given the assumption of 128b accesses:
-  /// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
   const int64_t rowsPerLine =
       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
@@ -175,8 +174,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   if (rowsPerLine >= threadGroupSize)
     return failure();
 
-  /// Get sets of operations within the function that read/write to shared
-  /// memory.
+  // Get sets of operations within the function that read/write to shared
+  // memory.
   SmallVector<Operation *, 16> shmReadOps;
   SmallVector<Operation *, 16> shmWriteOps;
   if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
@@ -191,7 +190,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   int64_t tgtDim = memRefType.getRank() - 1;
   int64_t srcDim = memRefType.getRank() - 2;
 
-  /// Transform indices for the ops writing to shared memory.
+  // Transform indices for the ops writing to shared memory.
   while (!shmWriteOps.empty()) {
     Operation *shmWriteOp = shmWriteOps.pop_back_val();
     builder.setInsertionPoint(shmWriteOp);
@@ -203,7 +202,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }
 
-  /// Transform indices for the ops reading from shared memory.
+  // Transform indices for the ops reading from shared memory.
   while (!shmReadOps.empty()) {
     Operation *shmReadOp = shmReadOps.pop_back_val();
     builder.setInsertionPoint(shmReadOp);
@@ -218,7 +217,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }
 
-void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+std::optional<LogicalResult>
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -228,8 +228,9 @@ void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   for (auto allocOp : shmAllocOps) {
     if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
                                                           allocOp.getMemref())))
-      return;
+      return failure();
   }
+  return success();
 }
 
 struct OptimizeSharedMemoryPass
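
With the return type changed to std::optional<LogicalResult>, callers can now propagate a failure instead of the function silently returning. A rough sketch of how a caller might consume the new entry point follows; the header paths and the surrounding function are assumptions for illustration, not code from this patch:

#include <optional>

#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h" // assumed header path
#include "mlir/Dialect/Func/IR/FuncOps.h"              // assumed header path

using namespace mlir;

// Hypothetical caller (e.g. from a pass's runOnOperation); not the pass body
// declared above.
static LogicalResult runOnFunc(func::FuncOp funcOp) {
  std::optional<LogicalResult> result =
      amdgpu::optimizeSharedMemoryReadsAndWritesOp(funcOp);
  // Treat an explicit failure as failure; a missing result is taken here to
  // mean there was nothing to optimize.
  if (result && failed(*result))
    return failure();
  return success();
}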