Skip to content

Commit 564ebc8

Browse files
committed
Make it work
1 parent c072e66 commit 564ebc8

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
903903
}
904904
};
905905

906-
struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
906+
struct GlobalLoadLDSOpLowering
907+
: public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
907908
GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
908909
: ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
909910

@@ -918,6 +919,10 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
918919
size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
919920
if (elemSizeInBits % 8 != 0)
920921
return op.emitOpError("element size must be a multiple of 8");
922+
923+
// TODO: instead of only transferring one element per thread, we could
924+
// augment it to transfer multiple elements per thread by issuing multiple
925+
// `global_load_lds` instructions.
921926
auto loadWidth = elemSizeInBits / 8;
922927

923928
// TODO: add chipset support check
@@ -934,37 +939,41 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
934939
Value memrefSrc = op.getSrc();
935940
Value memrefDst = op.getDst();
936941

937-
// Collapse src memref with indices:
938-
auto flattenIndex = [&](Value memref, MemRefType memrefType,
939-
ValueRange indices) -> std::optional<Value> {
942+
// Collapse src memref with indices; returns the base pointer and linearized
943+
// index.
944+
auto flattenIndex =
945+
[&](Value memref, MemRefType memrefType,
946+
ValueRange indices) -> std::optional<std::pair<Value, Value>> {
940947
MemRefDescriptor memRefDescriptor(memref);
941948
int64_t offset = 0;
942949
SmallVector<int64_t, 5> strides;
943950
if (failed(memrefType.getStridesAndOffset(strides, offset)))
944951
return {};
945-
return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
946-
strides);
952+
return std::make_pair(
953+
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
954+
memrefType),
955+
getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
947956
};
948957

949958
// Source
950-
auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
951-
op.getSrcIndices());
952-
if (!optSrcIdx)
959+
auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
960+
op.getSrcIndices());
961+
if (!optSrcBuffer)
953962
return op.emitOpError("failed to flatten source memref indices");
954-
auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
955-
op.getDstIndices());
956-
if (!optDstIdx)
963+
auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
964+
op.getDstIndices());
965+
if (!optDstBuffer)
957966
return op.emitOpError("failed to flatten destination memref indices");
958967

959-
Type srcPtrType =
960-
LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
961-
Type dstPtrType =
962-
LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
968+
Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
969+
Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
963970
Value srcPtr = rewriter.create<LLVM::GEPOp>(
964-
loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
965-
971+
loc, srcPtrType, elemType, optSrcBuffer->first,
972+
ArrayRef<Value>({optSrcBuffer->second}));
973+
966974
Value dstPtr = rewriter.create<LLVM::GEPOp>(
967-
loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
975+
loc, dstPtrType, elemType, optDstBuffer->first,
976+
ArrayRef<Value>({optDstBuffer->second}));
968977

969978
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
970979
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),

0 commit comments

Comments
 (0)