@@ -903,7 +903,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
903
903
}
904
904
};
905
905
906
- struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern <GlobalLoadLDSOp> {
906
+ struct GlobalLoadLDSOpLowering
907
+ : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
907
908
GlobalLoadLDSOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
908
909
: ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
909
910
@@ -918,6 +919,10 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
918
919
size_t elemSizeInBits = elemType.getIntOrFloatBitWidth ();
919
920
if (elemSizeInBits % 8 != 0 )
920
921
return op.emitOpError (" element size must be a multiple of 8" );
922
+
923
+ // TODO: instead of only transfering one element per thread, we could
924
+ // augment it to transfer multiple elements per thread by issuing multiple
925
+ // `global_load_lds` instructions.
921
926
auto loadWidth = elemSizeInBits / 8 ;
922
927
923
928
// TODO: add chipset support check
@@ -934,37 +939,41 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
934
939
Value memrefSrc = op.getSrc ();
935
940
Value memrefDst = op.getDst ();
936
941
937
- // Collapse src memref with indices:
938
- auto flattenIndex = [&](Value memref, MemRefType memrefType,
939
- ValueRange indices) -> std::optional<Value> {
942
+ // Collapse src memref with indices, returns the base pointer and linearized
943
+ // index.
944
+ auto flattenIndex =
945
+ [&](Value memref, MemRefType memrefType,
946
+ ValueRange indices) -> std::optional<std::pair<Value, Value>> {
940
947
MemRefDescriptor memRefDescriptor (memref);
941
948
int64_t offset = 0 ;
942
949
SmallVector<int64_t , 5 > strides;
943
950
if (failed (memrefType.getStridesAndOffset (strides, offset)))
944
951
return {};
945
- return getLinearIndexI32 (rewriter, loc, memRefDescriptor, indices,
946
- strides);
952
+ return std::make_pair (
953
+ memRefDescriptor.bufferPtr (rewriter, loc, *getTypeConverter (),
954
+ memrefType),
955
+ getLinearIndexI32 (rewriter, loc, memRefDescriptor, indices, strides));
947
956
};
948
957
949
958
// Source
950
- auto optSrcIdx = flattenIndex (src, cast<MemRefType>(memrefSrc.getType ()),
951
- op.getSrcIndices ());
952
- if (!optSrcIdx )
959
+ auto optSrcBuffer = flattenIndex (src, cast<MemRefType>(memrefSrc.getType ()),
960
+ op.getSrcIndices ());
961
+ if (!optSrcBuffer )
953
962
return op.emitOpError (" failed to flatten source memref indices" );
954
- auto optDstIdx = flattenIndex (dst, cast<MemRefType>(memrefDst.getType ()),
955
- op.getDstIndices ());
956
- if (!optDstIdx )
963
+ auto optDstBuffer = flattenIndex (dst, cast<MemRefType>(memrefDst.getType ()),
964
+ op.getDstIndices ());
965
+ if (!optDstBuffer )
957
966
return op.emitOpError (" failed to flatten destination memref indices" );
958
967
959
- Type srcPtrType =
960
- LLVM::LLVMPointerType::get (rewriter.getContext (), 1 );
961
- Type dstPtrType =
962
- LLVM::LLVMPointerType::get (rewriter.getContext (), 3 );
968
+ Type srcPtrType = LLVM::LLVMPointerType::get (rewriter.getContext (), 1 );
969
+ Type dstPtrType = LLVM::LLVMPointerType::get (rewriter.getContext (), 3 );
963
970
Value srcPtr = rewriter.create <LLVM::GEPOp>(
964
- loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
965
-
971
+ loc, srcPtrType, elemType, optSrcBuffer->first ,
972
+ ArrayRef<Value>({optSrcBuffer->second }));
973
+
966
974
Value dstPtr = rewriter.create <LLVM::GEPOp>(
967
- loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
975
+ loc, dstPtrType, elemType, optDstBuffer->first ,
976
+ ArrayRef<Value>({optDstBuffer->second }));
968
977
969
978
rewriter.replaceOpWithNewOp <ROCDL::GlobalLoadLDSOp>(
970
979
op, srcPtr, dstPtr, createI32Constant (rewriter, loc, loadWidth),
0 commit comments