
Commit 42526d2

Authored by krzysz00, with kuhar and pashu123
[mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. (#125594)
This commit adds support for casting memrefs into fat raw buffer pointers to the AMDGPU dialect. Fat raw buffer pointers - or, in LLVM terms, `ptr addrspace(7)` - allow encapsulating a buffer descriptor (as produced by the make.buffer.rsrc intrinsic or provided by some API) into a pointer that supports ordinary pointer operations like load or store. This allows people to take advantage of the additional semantics that buffer_load and similar instructions provide without forcing the use of entirely separate amdgpu.raw_buffer_* operations. Operations on fat raw buffer pointers are translated to the corresponding LLVM intrinsics by the backend.

This commit also defines a #amdgpu.address_space<> attribute so that AMDGPU-specific memory spaces can be represented. Only #amdgpu.address_space<fat_raw_buffer> will work correctly with the memref dialect; the other possible address spaces are included for completeness.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Prashant Kumar <[email protected]>
1 parent 7371f69 commit 42526d2
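
As a rough sketch of how the new pieces fit together (the function, SSA names, and memref type below are invented for illustration and are not taken from this commit's tests), a memref can be cast into the new address space and then read through ordinary memref operations:

    func.func @example(%src: memref<64xf32>, %i: index) -> f32 {
      // Wrap the source memref as a buffer fat pointer (ptr addrspace(7)).
      %fat = amdgpu.fat_raw_buffer_cast %src
          : memref<64xf32> to memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>
      // Ordinary memref ops on the result lower to buffer intrinsics.
      %v = memref.load %fat[%i]
          : memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>
      return %v : f32
    }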

File tree: 13 files changed, +684 −71 lines


mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h

Lines changed: 12 additions & 4 deletions
@@ -16,18 +16,26 @@ namespace mlir {
 
 class LLVMTypeConverter;
 class RewritePatternSet;
+class TypeConverter;
 class Pass;
 
 #define GEN_PASS_DECL_CONVERTAMDGPUTOROCDLPASS
 #include "mlir/Conversion/Passes.h.inc"
 
-/// Note: The ROCDL target does not support the LLVM bfloat type at this time
-/// and so this function will add conversions to change all `bfloat` uses
-/// to `i16`.
-void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter,
+/// Note: This function will also add conversions for the AMDGPU-specific
+/// address spaces, but those can be added separately using
+/// populateAMDGPUMemorySpaceAttributeConversions().
+void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns,
                                              amdgpu::Chipset chipset);
 
+/// Remap AMDGPU memory spaces to LLVM address spaces
+/// by mapping amdgpu::AddressSpace::fat_raw_buffer to ptr addrspace(7),
+/// amdgpu::AddressSpace::buffer_rsrc to ptr addrspace(8), and
+/// amdgpu::AddressSpace::fat_strided_buffer to ptr addrspace(9).
+void populateAMDGPUMemorySpaceAttributeConversions(
+    TypeConverter &typeConverter);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_AMDGPUTOROCDL_AMDGPUTOROCDL_H_

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 105 additions & 0 deletions
@@ -9,8 +9,11 @@
 #ifndef AMDGPU
 #define AMDGPU
 
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/EnumAttr.td"
+include "mlir/IR/Properties.td"
 include "mlir/IR/OpBase.td"
 
 def AMDGPU_Dialect : Dialect {
@@ -32,6 +35,45 @@ def AMDGPU_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AMDGPU general attribute definitions
+//===----------------------------------------------------------------------===//
+
+def AMDGPU_AddressSpace : I32EnumAttr<"AddressSpace",
+    "AMDGPU-specific address spaces",
+    [
+      I32EnumAttrCase<"FatRawBuffer", 0, "fat_raw_buffer">,
+      I32EnumAttrCase<"BufferRsrc", 1, "buffer_rsrc">,
+      I32EnumAttrCase<"FatStructuredBuffer", 2, "fat_structured_buffer">,
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
+                                       "address_space"> {
+  let description = [{
+    AMDGPU-specific memory spaces that may not have exact analogues on other
+    GPU targets or backends.
+
+    - `fat_raw_buffer` is the memory space used when a memref is stored as
+      a "buffer fat pointer" - that is, a buffer resource (that is set up to
+      use raw byte-level indexing) along with its offset. The AMDGPU backend
+      implements `ptr addrspace(7)` to represent these fat pointers so that
+      buffer resources (which allow advanced features like bounds checking or
+      cache swizzling) can be used like ordinary LLVM pointers or memrefs.
+      See also the `fat_raw_buffer_cast` operation.
+    - `buffer_rsrc` is the memory space for `ptr addrspace(8)`, representing a
+      buffer resource. It should not be used for memrefs, since it does not
+      support indexing.
+    - `fat_structured_buffer` represents `ptr addrspace(9)`, a buffer resource
+      that carries both an index and an offset field, which are used for the
+      complex structured indexing primarily seen in graphics applications.
+      This is also incompatible with the simple indexing model supported by
+      memref.
+  }];
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // AMDGPU Op definitions
 //===----------------------------------------------------------------------===//
@@ -118,6 +160,69 @@ def AMDGPU_PackedStochRoundFp8Op :
   let hasVerifier = 1;
 }
 
+def AMDGPU_FatRawBufferCastOp :
+    AMDGPU_Op<"fat_raw_buffer_cast",
+      [Pure,
+       DeclareOpInterfaceMethods<InferTypeOpInterface>,
+       ViewLikeOpInterface, AttrSizedOperandSegments]>,
+    Arguments<(ins AnyMemRef:$source,
+      Optional<I32>:$validBytes,
+      Optional<I<14>>:$cacheSwizzleStride,
+      DefaultValuedProp<BoolProp, "true">:$boundsCheck,
+      UnitProp:$resetOffset)>,
+    Results<(outs AnyMemRef:$result)> {
+  let summary = "Create a raw buffer fat pointer that matches `memref`";
+  let description = [{
+    Wraps the memory pointed to by `source` as a raw buffer fat pointer, or,
+    in LLVM terms, a `ptr addrspace(7)`, returning a memref that has the same
+    sizes and layout but the `#amdgpu.address_space<fat_raw_buffer>`
+    address space.
+
+    This memref can be used with standard memref operations like `memref.load`,
+    `memref.store`, and `memref.atomicrmw`, which will be lowered to the
+    relevant buffer intrinsics. (`vector.masked_load/store` will work once
+    there's backend support for lowering them, and then this document will be
+    updated.)
+
+    If `validBytes` is given, it is the number of bytes that will be valid as
+    an offset to `out`. If it is not provided, this will be inferred from
+    the size of the memref during lowering. This size is
+    max_{d = 0 upto rank(source)} (sizes[d] * strides[d]) * sizeof(element type).
+
+    The flags of the buffer descriptor will be set up to enable raw usage -
+    for example, stride = 0, add_tid = 0, and so on. The `boundsCheck`
+    property determines if bounds checking is enabled or not (on architectures
+    where this can be controlled - that is, on RDNA chips).
+
+    If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled
+    on architectures that support it. This swizzling, unlike the main swizzling
+    mode (whose usage makes a buffer non-raw), does not affect index
+    calculation, but does affect cache behavior. Mixing access between
+    cache-swizzled raw buffers and other forms of memory access, like ordinary
+    pointer loads or unswizzled buffer pointers, can cause incorrect behavior
+    and must be avoided.
+
+    This operation preserves the sizes, strides, and offset of the input
+    memref - they'll be added in by `memref.load` later. However, if
+    `resetOffset` is set, that offset will be added to the base pointer.
+    If the value of the memref's offset is not uniform (independent of the
+    lane/thread ID), this will lead to substantially decreased performance
+    due to the need for a waterfall loop on the base address of the buffer
+    resource.
+  }];
+
+  let extraClassDeclaration = [{
+    Value getViewSource() { return getSource(); }
+  }];
+
+  let assemblyFormat = [{
+    $source oilist (`validBytes` `(` $validBytes `)`
+                    | `cacheSwizzleStride` `(` $cacheSwizzleStride `)`
+                    | `boundsCheck` `(` $boundsCheck `)`
+                    | `resetOffset` $resetOffset )
+    attr-dict `:` type($source) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
 /// Raw buffer load
 def AMDGPU_RawBufferLoadOp :
     AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h

Lines changed: 2 additions & 0 deletions
@@ -18,7 +18,9 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
@@ -21,12 +21,15 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
 void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
                                           RewritePatternSet &patterns,
                                           Chipset chipset);
+
+void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns);
 } // namespace amdgpu
 } // namespace mlir

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 20 additions & 0 deletions
@@ -31,4 +31,24 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
                       "Chipset that these operations will run on">];
 }
 
+def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
+  let summary = "Resolve memref.extract_strided_metadata on AMDGPU ops";
+  let description = [{
+    This pass rewrites `memref.extract_strided_metadata` operations
+    targeting the AMDGPU dialect casts.
+
+    The patterns in this pass should normally be run alongside those in
+    -expand-strided-metadata, and creating a pass that combines those two
+    sets of patterns is the recommended way to use this functionality.
+    However, this pass (which will likely need a second -expand-strided-metadata
+    after it) is provided so that simple use cases do not need to create
+    custom passes. These patterns have not been added to
+    -expand-strided-metadata to prevent the memref dialect from depending on
+    platform-specific code.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "memref::MemRefDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
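
As a usage sketch (an assumption about how these pieces compose, not a pipeline taken from this commit), the new pass is meant to run next to the upstream metadata expansion, e.g.:

    mlir-opt input.mlir \
      --amdgpu-resolve-strided-metadata \
      --expand-strided-metadata

Per the description above, the recommended approach is a single pass combining both pattern sets; the standalone pipeline here is only for simple use cases.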
