Skip to content

Commit d5b8fcf

Browse files
committed
Add variables as options using ODS
1 parent 1025f2b commit d5b8fcf

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,17 @@ def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
3737
attempts to optimize reads/writes from a memref representing GPU shared
3838
memory in order to avoid bank conflicts.
3939
}];
40-
4140
let dependentDialects = [
4241
"memref::MemRefDialect", "vector::VectorDialect"
4342
];
43+
let options = [
44+
Option<"sharedMemoryLineSizeBytes", "shared-memory-line-size-bytes", "uint64_t",
45+
/*default=*/"128",
46+
"Shared memory line size in bytes">,
47+
Option<"defaultVectorSizeBits", "default-vector-size-bits", "uint64_t",
48+
/*default=*/"128",
49+
"Default vector size in bits">,
50+
];
4451
}
4552

4653
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_

mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
149149
LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
150150
Operation *parentOp, Value memrefValue, int64_t sharedMemoryLineSizeBytes,
151151
int64_t defaultVectorSizeBits) {
152-
153152
auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
154153
if (!memRefType ||
155154
!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -239,20 +238,12 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
239238

240239
struct OptimizeSharedMemoryPass
241240
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
242-
243241
public:
244-
OptimizeSharedMemoryPass()
245-
: OptimizeSharedMemoryBase(),
246-
sharedMemoryLineSizeBytes(sharedMemoryLineSizeBytes = 128),
247-
defaultVectorSizeBits(defaultVectorSizeBits = 128){};
248-
249-
OptimizeSharedMemoryPass(int64_t sharedMemoryLineSizeBytes,
250-
int64_t defaultVectorSizeBits)
251-
: OptimizeSharedMemoryBase(),
252-
sharedMemoryLineSizeBytes(sharedMemoryLineSizeBytes),
253-
defaultVectorSizeBits(defaultVectorSizeBits){};
254-
242+
OptimizeSharedMemoryPass() = default;
243+
OptimizeSharedMemoryPass(const OptimizeSharedMemoryOptions &options)
244+
: OptimizeSharedMemoryBase(options) {}
255245
void runOnOperation() override {
246+
256247
Operation *op = getOperation();
257248
SmallVector<memref::AllocOp> shmAllocOps;
258249
op->walk([&](memref::AllocOp allocOp) {
@@ -268,8 +259,4 @@ struct OptimizeSharedMemoryPass
268259
return;
269260
}
270261
}
271-
272-
private:
273-
int64_t sharedMemoryLineSizeBytes;
274-
int64_t defaultVectorSizeBits;
275262
};

0 commit comments

Comments
 (0)