Skip to content

Commit 499abb2

Browse files
committed
Add generic type attribute mapping infrastructure, use it in GpuToX
Remapping memory spaces is a function often needed in type conversions, most often when going to LLVM or to/from SPIR-V (a future commit), and it is possible that such remappings may become more common in the future as dialects take advantage of the more generic memory space infrastructure. Currently, memory space remappings are handled by running a special-purpose conversion pass before the main conversion that changes the address space attributes. In this commit, this approach is replaced by adding a notion of type attribute conversions to TypeConverter, which is then used to convert memory space attributes. Then, we use this infrastructure throughout the *ToLLVM conversions. This has the advantage of loosening the requirements on the inputs to those passes from "all address spaces must be integers" to "all memory spaces must be convertible to integer spaces", a looser requirement that reduces the coupling between portions of MLIR. On top of that, this change leads to the removal of most of the calls to getMemorySpaceAsInt(), bringing us closer to removing it. (A rework of the SPIR-V conversions to use this new system will be in a follow-up commit.) As a note, one long-term motivation for this change is that I would eventually like to add an allocaMemorySpace key to MLIR data layouts and then call getMemRefAddressSpace(allocaMemorySpace) in the relevant *ToLLVM in order to ensure all alloca()s, whether incoming or produced during the LLVM lowering, have the correct address space for a given target. I expect that the type attribute conversion system may be useful in other contexts. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D142159
1 parent 3c565c2 commit 499abb2

File tree

24 files changed

+411
-384
lines changed

24 files changed

+411
-384
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ class LLVMTypeConverter : public TypeConverter {
147147
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
148148
const DataLayout &layout);
149149

150+
/// Return the LLVM address space corresponding to the memory space of the
151+
/// memref type `type` or failure if the memory space cannot be converted to
152+
/// an integer.
153+
FailureOr<unsigned> getMemRefAddressSpace(BaseMemRefType type);
154+
150155
/// Check if a memref type can be converted to a bare pointer.
151156
static bool canConvertToBarePtr(BaseMemRefType type);
152157

mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def AMDGPU_Dialect : Dialect {
2525

2626

2727
let dependentDialects = [
28-
"arith::ArithDialect"
28+
"arith::ArithDialect",
29+
"gpu::GPUDialect"
2930
];
3031
let useDefaultAttributePrinterParser = 1;
3132
}

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,6 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
6161
}
6262

6363
namespace gpu {
64-
/// A function that maps a MemorySpace enum to a target-specific integer value.
65-
using MemorySpaceMapping =
66-
std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
67-
68-
/// Populates type conversion rules for lowering memory space attributes to
69-
/// numeric values.
70-
void populateMemorySpaceAttributeTypeConversions(
71-
TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
72-
73-
/// Populates patterns to lower memory space attributes to numeric values.
74-
void populateMemorySpaceLoweringPatterns(TypeConverter &typeConverter,
75-
RewritePatternSet &patterns);
76-
77-
/// Populates legality rules for lowering memory space attriutes to numeric
78-
/// values.
79-
void populateLowerMemorySpaceOpLegality(ConversionTarget &target);
80-
8164
/// Returns the default annotation name for GPU binary blobs.
8265
std::string getDefaultGpuBinaryAnnotation();
8366

mlir/include/mlir/Dialect/GPU/Transforms/Passes.td

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,4 @@ def GpuMapParallelLoopsPass
3737
let dependentDialects = ["mlir::gpu::GPUDialect"];
3838
}
3939

40-
def GPULowerMemorySpaceAttributesPass
41-
: Pass<"gpu-lower-memory-space-attributes"> {
42-
let summary = "Assign numeric values to memref memory space symbolic placeholders";
43-
let description = [{
44-
Updates all memref types that have a memory space attribute
45-
that is a `gpu::AddressSpaceAttr`. These attributes are
46-
changed to `IntegerAttr`'s using a mapping that is given in the
47-
options.
48-
}];
49-
let options = [
50-
Option<"privateAddrSpace", "private", "unsigned", "5",
51-
"private address space numeric value">,
52-
Option<"workgroupAddrSpace", "workgroup", "unsigned", "3",
53-
"workgroup address space numeric value">,
54-
Option<"globalAddrSpace", "global", "unsigned", "1",
55-
"global address space numeric value">
56-
];
57-
}
58-
5940
#endif // MLIR_DIALECT_GPU_PASSES

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
namespace mlir {
2222

2323
// Forward declarations.
24+
class Attribute;
2425
class Block;
2526
class ConversionPatternRewriter;
2627
class MLIRContext;
@@ -87,6 +88,34 @@ class TypeConverter {
8788
SmallVector<Type, 4> argTypes;
8889
};
8990

91+
/// The general result of a type attribute conversion callback, allowing
92+
/// for early termination. The default constructor creates the na case.
93+
class AttributeConversionResult {
94+
public:
95+
constexpr AttributeConversionResult() : impl() {}
96+
AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
97+
98+
static AttributeConversionResult result(Attribute attr);
99+
static AttributeConversionResult na();
100+
static AttributeConversionResult abort();
101+
102+
bool hasResult() const;
103+
bool isNa() const;
104+
bool isAbort() const;
105+
106+
Attribute getResult() const;
107+
108+
private:
109+
AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
110+
111+
llvm::PointerIntPair<Attribute, 2> impl;
112+
// Note that na is 0 so that we can use PointerIntPair's default
113+
// constructor.
114+
static constexpr unsigned naTag = 0;
115+
static constexpr unsigned resultTag = 1;
116+
static constexpr unsigned abortTag = 2;
117+
};
118+
90119
/// Register a conversion function. A conversion function must be convertible
91120
/// to any of the following forms(where `T` is a class derived from `Type`:
92121
/// * std::optional<Type>(T)
@@ -156,6 +185,34 @@ class TypeConverter {
156185
wrapMaterialization<T>(std::forward<FnT>(callback)));
157186
}
158187

188+
/// Register a conversion function for attributes within types. Type
189+
/// converters may call this function in order to allow hooking into the
190+
/// translation of attributes that exist within types. For example, a type
191+
/// converter for the `memref` type could use these conversions to convert
192+
/// memory spaces or layouts in an extensible way.
193+
///
194+
/// The conversion functions take a non-null Type or subclass of Type and a
195+
/// non-null Attribute (or subclass of Attribute), and returns a
196+
/// `AttributeConversionResult`. This result can either contain an `Attribute`,
197+
/// which may be `nullptr`, representing the conversion's success,
198+
/// `AttributeConversionResult::na()` (the default empty value), indicating
199+
/// that the conversion function did not apply and that further conversion
200+
/// functions should be checked, or `AttributeConversionResult::abort()`
201+
/// indicating that the conversion process should be aborted.
202+
///
203+
/// Registered conversion functions are called in the reverse of the order in
204+
/// which they were registered.
205+
template <
206+
typename FnT,
207+
typename T =
208+
typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
209+
typename A =
210+
typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
211+
void addTypeAttributeConversion(FnT &&callback) {
212+
registerTypeAttributeConversion(
213+
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
214+
}
215+
159216
/// Convert the given type. This function should return failure if no valid
160217
/// conversion exists, success otherwise. If the new set of types is empty,
161218
/// the type is removed and any usages of the existing value are expected to
@@ -226,6 +283,12 @@ class TypeConverter {
226283
resultType, inputs);
227284
}
228285

286+
/// Convert an attribute `attr` present within the type `type` using
287+
/// the registered conversion functions. If no applicable conversion has been
288+
/// registered, return std::nullopt. Note that the empty attribute/`nullptr`
289+
/// is a valid return value for this function.
290+
std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
291+
229292
private:
230293
/// The signature of the callback used to convert a type. If the new set of
231294
/// types is empty, the type is removed and any usages of the existing value
@@ -237,6 +300,10 @@ class TypeConverter {
237300
using MaterializationCallbackFn = std::function<std::optional<Value>(
238301
OpBuilder &, Type, ValueRange, Location)>;
239302

303+
/// The signature of the callback used to convert a type attribute.
304+
using TypeAttributeConversionCallbackFn =
305+
std::function<AttributeConversionResult(Type, Attribute)>;
306+
240307
/// Attempt to materialize a conversion using one of the provided
241308
/// materialization functions.
242309
Value materializeConversion(
@@ -311,6 +378,32 @@ class TypeConverter {
311378
};
312379
}
313380

381+
/// Generate a wrapper for the given type attribute conversion callback. The
382+
/// callback may take any subclass of `Attribute` and the wrapper will check
383+
/// for the target attribute to be of the expected class before calling the
384+
/// callback.
385+
template <typename T, typename A, typename FnT>
386+
TypeAttributeConversionCallbackFn
387+
wrapTypeAttributeConversion(FnT &&callback) {
388+
return [callback = std::forward<FnT>(callback)](
389+
Type type, Attribute attr) -> AttributeConversionResult {
390+
if (T derivedType = type.dyn_cast<T>()) {
391+
if (A derivedAttr = attr.dyn_cast_or_null<A>())
392+
return callback(derivedType, derivedAttr);
393+
}
394+
return AttributeConversionResult::na();
395+
};
396+
}
397+
398+
/// Register a type attribute conversion, clearing caches.
399+
void
400+
registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
401+
typeAttributeConversions.emplace_back(std::move(callback));
402+
// Clear type conversions in case a memory space is lingering inside.
403+
cachedDirectConversions.clear();
404+
cachedMultiConversions.clear();
405+
}
406+
314407
/// The set of registered conversion functions.
315408
SmallVector<ConversionCallbackFn, 4> conversions;
316409

@@ -319,6 +412,9 @@ class TypeConverter {
319412
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
320413
SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
321414

415+
/// The list of registered type attribute conversion functions.
416+
SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
417+
322418
/// A set of cached conversions to avoid recomputing in the common case.
323419
/// Direct 1-1 conversions are the most common, so this cache stores the
324420
/// successful 1-1 conversions as well as all failed conversions.

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "GPUOpsLowering.h"
1010
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11+
#include "mlir/IR/Attributes.h"
1112
#include "mlir/IR/Builders.h"
1213
#include "mlir/IR/BuiltinTypes.h"
1314
#include "llvm/ADT/STLExtras.h"
@@ -474,3 +475,18 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
474475
rewriter.replaceOp(op, result);
475476
return success();
476477
}
478+
479+
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
480+
return IntegerAttr::get(IntegerType::get(ctx, 64), space);
481+
}
482+
483+
void mlir::populateGpuMemorySpaceAttributeConversions(
484+
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
485+
typeConverter.addTypeAttributeConversion(
486+
[mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
487+
gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
488+
unsigned addressSpace = mapping(memorySpace);
489+
return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
490+
addressSpace);
491+
});
492+
}

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
112112
}
113113
};
114114

115+
/// A function that maps a MemorySpace enum to a target-specific integer value.
116+
using MemorySpaceMapping =
117+
std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
118+
119+
/// Populates memory space attribute conversion rules for lowering
120+
/// gpu.address_space to integer values.
121+
void populateGpuMemorySpaceAttributeConversions(
122+
TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
115123
} // namespace mlir
116124

117125
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -241,38 +241,26 @@ struct LowerGpuOpsToNVVMOpsPass
241241
return signalPassFailure();
242242
}
243243

244-
// MemRef conversion for GPU to NVVM lowering.
245-
{
246-
RewritePatternSet patterns(m.getContext());
247-
TypeConverter typeConverter;
248-
typeConverter.addConversion([](Type t) { return t; });
249-
// NVVM uses alloca in the default address space to represent private
250-
// memory allocations, so drop private annotations. NVVM uses address
251-
// space 3 for shared memory. NVVM uses the default address space to
252-
// represent global memory.
253-
gpu::populateMemorySpaceAttributeTypeConversions(
254-
typeConverter, [](gpu::AddressSpace space) -> unsigned {
255-
switch (space) {
256-
case gpu::AddressSpace::Global:
257-
return static_cast<unsigned>(
258-
NVVM::NVVMMemorySpace::kGlobalMemorySpace);
259-
case gpu::AddressSpace::Workgroup:
260-
return static_cast<unsigned>(
261-
NVVM::NVVMMemorySpace::kSharedMemorySpace);
262-
case gpu::AddressSpace::Private:
263-
return 0;
264-
}
265-
llvm_unreachable("unknown address space enum value");
266-
return 0;
267-
});
268-
gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
269-
ConversionTarget target(getContext());
270-
gpu::populateLowerMemorySpaceOpLegality(target);
271-
if (failed(applyFullConversion(m, target, std::move(patterns))))
272-
return signalPassFailure();
273-
}
274-
275244
LLVMTypeConverter converter(m.getContext(), options);
245+
// NVVM uses alloca in the default address space to represent private
246+
// memory allocations, so drop private annotations. NVVM uses address
247+
// space 3 for shared memory. NVVM uses the default address space to
248+
// represent global memory.
249+
populateGpuMemorySpaceAttributeConversions(
250+
converter, [](gpu::AddressSpace space) -> unsigned {
251+
switch (space) {
252+
case gpu::AddressSpace::Global:
253+
return static_cast<unsigned>(
254+
NVVM::NVVMMemorySpace::kGlobalMemorySpace);
255+
case gpu::AddressSpace::Workgroup:
256+
return static_cast<unsigned>(
257+
NVVM::NVVMMemorySpace::kSharedMemorySpace);
258+
case gpu::AddressSpace::Private:
259+
return 0;
260+
}
261+
llvm_unreachable("unknown address space enum value");
262+
return 0;
263+
});
276264
// Lowering for MMAMatrixType.
277265
converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
278266
return convertMMAToLLVMType(type);

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -132,33 +132,21 @@ struct LowerGpuOpsToROCDLOpsPass
132132
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
133133
}
134134

135-
// Apply memory space lowering. The target uses 3 for workgroup memory and 5
136-
// for private memory.
137-
{
138-
RewritePatternSet patterns(ctx);
139-
TypeConverter typeConverter;
140-
typeConverter.addConversion([](Type t) { return t; });
141-
gpu::populateMemorySpaceAttributeTypeConversions(
142-
typeConverter, [](gpu::AddressSpace space) {
143-
switch (space) {
144-
case gpu::AddressSpace::Global:
145-
return 1;
146-
case gpu::AddressSpace::Workgroup:
147-
return 3;
148-
case gpu::AddressSpace::Private:
149-
return 5;
150-
}
151-
llvm_unreachable("unknown address space enum value");
152-
return 0;
153-
});
154-
ConversionTarget target(getContext());
155-
gpu::populateLowerMemorySpaceOpLegality(target);
156-
gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
157-
if (failed(applyFullConversion(m, target, std::move(patterns))))
158-
return signalPassFailure();
159-
}
160-
161135
LLVMTypeConverter converter(ctx, options);
136+
populateGpuMemorySpaceAttributeConversions(
137+
converter, [](gpu::AddressSpace space) {
138+
switch (space) {
139+
case gpu::AddressSpace::Global:
140+
return 1;
141+
case gpu::AddressSpace::Workgroup:
142+
return 3;
143+
case gpu::AddressSpace::Private:
144+
return 5;
145+
}
146+
llvm_unreachable("unknown address space enum value");
147+
return 0;
148+
});
149+
162150
RewritePatternSet llvmPatterns(ctx);
163151

164152
mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
112112
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
113113
auto elementType = type.getElementType();
114114
auto structElementType = typeConverter->convertType(elementType);
115-
return getTypeConverter()->getPointerType(structElementType,
116-
type.getMemorySpaceAsInt());
115+
auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
116+
if (failed(addressSpace))
117+
return {};
118+
return getTypeConverter()->getPointerType(structElementType, *addressSpace);
117119
}
118120

119121
void ConvertToLLVMPattern::getMemRefDescriptorSizes(

0 commit comments

Comments
 (0)