Skip to content

Commit 7fb9bbe

Browse files
committed
[mlir][Memref] Add memref.memory_space_cast and its lowerings
Address space casts are present in common MLIR targets (LLVM, SPIRV). Some planned rewrites (such as one of the potential fixes to the fact that the AMDGPU backend requires alloca() to live in address space 5 / the GPU private memory space) may require such casts to be inserted into MLIR code, where those address spaces could be represented by arbitrary memory space attributes. Therefore, we define memref.memory_space_cast and its lowerings. Depends on D141293 Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D141148
1 parent b7a2ff2 commit 7fb9bbe

File tree

11 files changed

+506
-24
lines changed

11 files changed

+506
-24
lines changed

mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,13 @@ class UnrankedMemRefDescriptor : public StructBuilder {
218218
LLVM::LLVMPointerType elemPtrType,
219219
Value alignedPtr);
220220

221+
/// Builds IR for getting the pointer to the offset's location.
222+
/// Returns a pointer to a convertType(index), which points to the beggining
223+
/// of a struct {index, index[rank], index[rank]}.
224+
static Value offsetBasePtr(OpBuilder &builder, Location loc,
225+
LLVMTypeConverter &typeConverter,
226+
Value memRefDescPtr,
227+
LLVM::LLVMPointerType elemPtrType);
221228
/// Builds IR extracting the offset from the descriptor.
222229
static Value offset(OpBuilder &builder, Location loc,
223230
LLVMTypeConverter &typeConverter, Value memRefDescPtr,

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,54 @@ def LoadOp : MemRef_Op<"load",
11771177
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
11781178
}
11791179

1180+
//===----------------------------------------------------------------------===//
1181+
// MemorySpaceCastOp
1182+
//===----------------------------------------------------------------------===//
1183+
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
1184+
DeclareOpInterfaceMethods<CastOpInterface>,
1185+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1186+
MemRefsNormalizable,
1187+
Pure,
1188+
SameOperandsAndResultElementType,
1189+
SameOperandsAndResultShape,
1190+
ViewLikeOpInterface
1191+
]> {
1192+
let summary = "memref memory space cast operation";
1193+
let description = [{
1194+
This operation casts memref values between memory spaces.
1195+
The input and result will be memrefs of the same types and shape that alias
1196+
the same underlying memory, though, for some casts on some targets,
1197+
the underlying values of the pointer stored in the memref may be affected
1198+
by the cast.
1199+
1200+
The input and result must have the same shape, element type, rank, and layout.
1201+
1202+
If the source and target address spaces are the same, this operation is a noop.
1203+
1204+
Example:
1205+
1206+
```mlir
1207+
// Cast a GPU private memory attribution into a generic pointer
1208+
%2 = memref.memory_space_cast %1 : memref<?xf32, 5> to memref<?xf32>
1209+
// Cast a generic pointer to workgroup-local memory
1210+
%4 = memref.memory_space_cast %3 : memref<5x4xi32> to memref<5x34xi32, 3>
1211+
// Cast between two non-default memory spaces
1212+
%6 = memref.memory_space_cast %5
1213+
: memref<*xmemref<?xf32>, 5> to memref<*xmemref<?xf32>, 3>
1214+
```
1215+
}];
1216+
1217+
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
1218+
let results = (outs AnyRankedOrUnrankedMemRef:$dest);
1219+
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
1220+
1221+
let extraClassDeclaration = [{
1222+
Value getViewSource() { return getSource(); }
1223+
}];
1224+
1225+
let hasFolder = 1;
1226+
}
1227+
11801228
//===----------------------------------------------------------------------===//
11811229
// PrefetchOp
11821230
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "MemRefDescriptor.h"
1111
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1212
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1314
#include "mlir/IR/Builders.h"
1415
#include "mlir/Support/MathExtras.h"
1516

@@ -457,10 +458,9 @@ void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
457458
builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
458459
}
459460

460-
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
461-
LLVMTypeConverter &typeConverter,
462-
Value memRefDescPtr,
463-
LLVM::LLVMPointerType elemPtrType) {
461+
Value UnrankedMemRefDescriptor::offsetBasePtr(
462+
OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
463+
Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
464464
auto [elementPtrPtr, elemPtrPtrType] =
465465
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
466466

@@ -473,30 +473,26 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
473473
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()),
474474
offsetGep);
475475
}
476+
return offsetGep;
477+
}
476478

477-
return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(),
478-
offsetGep);
479+
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
480+
LLVMTypeConverter &typeConverter,
481+
Value memRefDescPtr,
482+
LLVM::LLVMPointerType elemPtrType) {
483+
Value offsetPtr =
484+
offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
485+
return builder.create<LLVM::LoadOp>(loc, offsetPtr);
479486
}
480487

481488
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
482489
LLVMTypeConverter &typeConverter,
483490
Value memRefDescPtr,
484491
LLVM::LLVMPointerType elemPtrType,
485492
Value offset) {
486-
auto [elementPtrPtr, elemPtrPtrType] =
487-
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
488-
489-
Value offsetGep =
490-
builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
491-
elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
492-
493-
if (!elemPtrType.isOpaque()) {
494-
offsetGep = builder.create<LLVM::BitcastOp>(
495-
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()),
496-
offsetGep);
497-
}
498-
499-
builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
493+
Value offsetPtr =
494+
offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
495+
builder.create<LLVM::StoreOp>(loc, offset, offsetPtr);
500496
}
501497

502498
Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
#include "mlir/Dialect/Func/IR/FuncOps.h"
1818
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1919
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
2021
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2122
#include "mlir/IR/AffineMap.h"
2223
#include "mlir/IR/IRMapping.h"
2324
#include "mlir/Pass/Pass.h"
25+
#include "mlir/Support/MathExtras.h"
2426
#include "llvm/ADT/SmallBitVector.h"
2527
#include <optional>
2628

@@ -1096,6 +1098,118 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
10961098
}
10971099
};
10981100

1101+
struct MemorySpaceCastOpLowering
1102+
: public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
1103+
using ConvertOpToLLVMPattern<
1104+
memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
1105+
1106+
LogicalResult
1107+
matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1108+
ConversionPatternRewriter &rewriter) const override {
1109+
Location loc = op.getLoc();
1110+
1111+
Type resultType = op.getDest().getType();
1112+
if (auto resultTypeR = resultType.dyn_cast<MemRefType>()) {
1113+
auto resultDescType =
1114+
typeConverter->convertType(resultTypeR).cast<LLVM::LLVMStructType>();
1115+
Type newPtrType = resultDescType.getBody()[0];
1116+
1117+
SmallVector<Value> descVals;
1118+
MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
1119+
descVals);
1120+
descVals[0] =
1121+
rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
1122+
descVals[1] =
1123+
rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
1124+
Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
1125+
resultTypeR, descVals);
1126+
rewriter.replaceOp(op, result);
1127+
return success();
1128+
}
1129+
if (auto resultTypeU = resultType.dyn_cast<UnrankedMemRefType>()) {
1130+
// Since the type converter won't be doing this for us, get the address
1131+
// space.
1132+
auto sourceType = op.getSource().getType().cast<UnrankedMemRefType>();
1133+
FailureOr<unsigned> maybeSourceAddrSpace =
1134+
getTypeConverter()->getMemRefAddressSpace(sourceType);
1135+
if (failed(maybeSourceAddrSpace))
1136+
return rewriter.notifyMatchFailure(loc,
1137+
"non-integer source address space");
1138+
unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1139+
FailureOr<unsigned> maybeResultAddrSpace =
1140+
getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1141+
if (failed(maybeResultAddrSpace))
1142+
return rewriter.notifyMatchFailure(loc,
1143+
"non-integer result address space");
1144+
unsigned resultAddrSpace = *maybeResultAddrSpace;
1145+
1146+
UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
1147+
Value rank = sourceDesc.rank(rewriter, loc);
1148+
Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
1149+
1150+
// Create and allocate storage for new memref descriptor.
1151+
auto result = UnrankedMemRefDescriptor::undef(
1152+
rewriter, loc, typeConverter->convertType(resultTypeU));
1153+
result.setRank(rewriter, loc, rank);
1154+
SmallVector<Value, 1> sizes;
1155+
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1156+
result, resultAddrSpace, sizes);
1157+
Value resultUnderlyingSize = sizes.front();
1158+
Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
1159+
loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
1160+
result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1161+
1162+
// Copy pointers, performing address space casts.
1163+
Type llvmElementType =
1164+
typeConverter->convertType(sourceType.getElementType());
1165+
LLVM::LLVMPointerType sourceElemPtrType =
1166+
getTypeConverter()->getPointerType(llvmElementType, sourceAddrSpace);
1167+
auto resultElemPtrType =
1168+
getTypeConverter()->getPointerType(llvmElementType, resultAddrSpace);
1169+
1170+
Value allocatedPtr = sourceDesc.allocatedPtr(
1171+
rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1172+
Value alignedPtr =
1173+
sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
1174+
sourceUnderlyingDesc, sourceElemPtrType);
1175+
allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1176+
loc, resultElemPtrType, allocatedPtr);
1177+
alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1178+
loc, resultElemPtrType, alignedPtr);
1179+
1180+
result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1181+
resultElemPtrType, allocatedPtr);
1182+
result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1183+
resultUnderlyingDesc, resultElemPtrType, alignedPtr);
1184+
1185+
// Copy all the index-valued operands.
1186+
Value sourceIndexVals =
1187+
sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1188+
sourceUnderlyingDesc, sourceElemPtrType);
1189+
Value resultIndexVals =
1190+
result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1191+
resultUnderlyingDesc, resultElemPtrType);
1192+
1193+
int64_t bytesToSkip =
1194+
2 *
1195+
ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1196+
Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
1197+
loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1198+
Value copySize = rewriter.create<LLVM::SubOp>(
1199+
loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
1200+
Type llvmBool = typeConverter->convertType(rewriter.getI1Type());
1201+
Value nonVolatile = rewriter.create<LLVM::ConstantOp>(
1202+
loc, llvmBool, rewriter.getBoolAttr(false));
1203+
rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
1204+
copySize, nonVolatile);
1205+
1206+
rewriter.replaceOp(op, ValueRange{result});
1207+
return success();
1208+
}
1209+
return rewriter.notifyMatchFailure(loc, "unexpected memref type");
1210+
}
1211+
};
1212+
10991213
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
11001214
/// memref type. In unranked case, the fields are extracted from the underlying
11011215
/// ranked descriptor.
@@ -1785,6 +1899,7 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
17851899
LoadOpLowering,
17861900
MemRefCastOpLowering,
17871901
MemRefCopyOpLowering,
1902+
MemorySpaceCastOpLowering,
17881903
MemRefReinterpretCastOpLowering,
17891904
MemRefReshapeOpLowering,
17901905
PrefetchOpLowering,

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,17 @@ class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
223223
ConversionPatternRewriter &rewriter) const override;
224224
};
225225

226+
/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
227+
class MemorySpaceCastOpPattern final
228+
: public OpConversionPattern<memref::MemorySpaceCastOp> {
229+
public:
230+
using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;
231+
232+
LogicalResult
233+
matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
234+
ConversionPatternRewriter &rewriter) const override;
235+
};
236+
226237
/// Converts memref.store to spirv.Store.
227238
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
228239
public:
@@ -552,6 +563,74 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
552563
return success();
553564
}
554565

566+
//===----------------------------------------------------------------------===//
567+
// MemorySpaceCastOp
568+
//===----------------------------------------------------------------------===//
569+
570+
LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
571+
memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
572+
ConversionPatternRewriter &rewriter) const {
573+
Location loc = addrCastOp.getLoc();
574+
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
575+
if (!typeConverter.allows(spirv::Capability::Kernel))
576+
return rewriter.notifyMatchFailure(
577+
loc, "address space casts require kernel capability");
578+
579+
auto sourceType = addrCastOp.getSource().getType().dyn_cast<MemRefType>();
580+
if (!sourceType)
581+
return rewriter.notifyMatchFailure(
582+
loc, "SPIR-V lowering requires ranked memref types");
583+
auto resultType = addrCastOp.getResult().getType().cast<MemRefType>();
584+
585+
auto sourceStorageClassAttr =
586+
sourceType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
587+
if (!sourceStorageClassAttr)
588+
return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
589+
diag << "source address space " << sourceType.getMemorySpace()
590+
<< " must be a SPIR-V storage class";
591+
});
592+
auto resultStorageClassAttr =
593+
resultType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
594+
if (!resultStorageClassAttr)
595+
return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
596+
diag << "result address space " << resultType.getMemorySpace()
597+
<< " must be a SPIR-V storage class";
598+
});
599+
600+
spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
601+
spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
602+
603+
Value result = adaptor.getSource();
604+
Type resultPtrType = typeConverter.convertType(resultType);
605+
Type genericPtrType = resultPtrType;
606+
// SPIR-V doesn't have a general address space cast operation. Instead, it has
607+
// conversions to and from generic pointers. To implement the general case,
608+
// we use specific-to-generic conversions when the source class is not
609+
// generic. Then when the result storage class is not generic, we convert the
610+
// generic pointer (either the input on ar intermediate result) to theat
611+
// class. This also means that we'll need the intermediate generic pointer
612+
// type if neither the source or destination have it.
613+
if (sourceSc != spirv::StorageClass::Generic &&
614+
resultSc != spirv::StorageClass::Generic) {
615+
Type intermediateType =
616+
MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
617+
sourceType.getLayout(),
618+
rewriter.getAttr<spirv::StorageClassAttr>(
619+
spirv::StorageClass::Generic));
620+
genericPtrType = typeConverter.convertType(intermediateType);
621+
}
622+
if (sourceSc != spirv::StorageClass::Generic) {
623+
result =
624+
rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
625+
}
626+
if (resultSc != spirv::StorageClass::Generic) {
627+
result =
628+
rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
629+
}
630+
rewriter.replaceOp(addrCastOp, result);
631+
return success();
632+
}
633+
555634
LogicalResult
556635
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
557636
ConversionPatternRewriter &rewriter) const {
@@ -577,9 +656,9 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
577656
namespace mlir {
578657
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
579658
RewritePatternSet &patterns) {
580-
patterns
581-
.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
582-
IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
583-
typeConverter, patterns.getContext());
659+
patterns.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern,
660+
IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
661+
MemorySpaceCastOpPattern, StoreOpPattern>(typeConverter,
662+
patterns.getContext());
584663
}
585664
} // namespace mlir

0 commit comments

Comments
 (0)