[mlir][nvgpu] Add nvgpu.tma.async.store #77811

Merged (1 commit, Jan 15, 2024)
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -661,6 +661,28 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]

}

def NVGPU_TmaAsyncStoreOp : NVGPU_Op<"tma.async.store", [AttrSizedOperandSegments]> {
let summary = "TMA asynchronous store";
let description = [{
The Op stores a tile memory region from shared memory into global memory using
Tensor Memory Access (TMA).

`$tensorMapDescriptor` is a tensor map descriptor that carries information about
the tile shape. The descriptor is created by `nvgpu.tma.create.descriptor`.
}];
let arguments = (ins Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
Variadic<Index>:$coordinates,
Optional<I1>:$predicate);
let assemblyFormat = [{
$src `to` $tensorMapDescriptor `[` $coordinates `]`
(`,` `predicate` `=` $predicate^)?
attr-dict `:` type($src)
`->` type($tensorMapDescriptor)
}];
let hasVerifier = 1;
}
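
A minimal usage sketch of the op defined above (editor's illustration: the !tensorMap2d alias stands for an !nvgpu.tensormap.descriptor type and the SSA names are hypothetical, mirroring the tests at the end of this PR):

// Store a 32x32 f32 tile from shared memory (address space 3) into the global
// tensor described by the TMA descriptor; a predicate operand is optional.
nvgpu.tma.async.store %src to %desc[%x, %y] : memref<32x32xf32,3> -> !tensorMap2d
nvgpu.tma.async.store %src to %desc[%x, %y], predicate = %p : memref<32x32xf32,3> -> !tensorMap2d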
Member:

I think we also need to add the commit/wait_group version for TMA stores? Otherwise this instruction is not that useful? Unless you expect people to be calling the NVVM versions for now.

Member Author:

Valid point. I suggest using these Ops directly from the NVVM dialect.

Why? The NVGPU dialect acts as a bridge between high-level dialects (such as memref, vector, etc.) and the NVVM dialect. The cp.async.bulk.commit.group is straightforward to use, and introducing it to the NVGPU dialect would only result in duplication.

If we need higher-level abstractions (as we have for operations like mbarrier.group or nvgpu.tma.descriptor), then we can introduce specific Ops in the NVGPU dialect.

What do you think?

Member Author:

(just noticed that cp.async.bulk.wait_group.read is missing in NVVM)

Member:

Sure, sounds good to me. Perhaps it would be a good idea to add a pointer to the NVVM op to the description here.
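
For reference, the NVVM op in question is nvvm.cp.async.bulk.tensor.global.shared.cta, which the lowering added below targets; group commit/wait would then be issued through the NVVM-level ops named in this thread. A rough before/after sketch (editor's illustration; SSA names are hypothetical and the NVVM operand syntax is abbreviated, see the CHECK lines in the tests below):

// nvgpu level:
nvgpu.tma.async.store %src to %desc[%x, %y] : memref<32x32xf32,3> -> !tensorMap2d
// after NVGPU-to-NVVM conversion (populateNVGPUToNVVMConversionPatterns), roughly:
nvvm.cp.async.bulk.tensor.global.shared.cta %desc, %srcPtr, box[%x, %y]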


def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
let summary = "TMA create descriptor";
let description = [{
24 changes: 24 additions & 0 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -995,6 +995,29 @@ struct NVGPUTmaAsyncLoadOpLowering
return success();
}
};

struct NVGPUTmaAsyncStoreOpLowering
: public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
LogicalResult
matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
adaptor.getSrc(), {}, rewriter);
SmallVector<Value> coords = adaptor.getCoordinates();
for (auto [index, value] : llvm::enumerate(coords)) {
Member:

Not for this PR, but we should finally add reversal of the coords to make the dimension ordering work as MLIR normally expects.

Member Author:

agreed, let's do this in another PR

coords[index] = truncToI32(b, value);
}

rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
op, adaptor.getTensorMapDescriptor(), dest, coords,
adaptor.getPredicate());
return success();
}
};

struct NVGPUGenerateWarpgroupDescriptorLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
using ConvertOpToLLVMPattern<
@@ -1639,6 +1662,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -402,6 +402,29 @@ LogicalResult TmaAsyncLoadOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncStoreOp
//===----------------------------------------------------------------------===//

LogicalResult TmaAsyncStoreOp::verify() {
std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
*this, getTensorMapDescriptor().getType(), getSrc().getType());
if (error.has_value())
return error.value();

if (getCoordinates().size() > kMaxTMATensorDimension) {
return emitError() << "Maximum " << kMaxTMATensorDimension
<< " coordinates are supported.";
}
if (getCoordinates().size() !=
size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
return emitError() << "number of coordinates do not match with the rank of "
"tensor descriptor map.";
}

return success();
}

LogicalResult TmaCreateDescriptorOp::verify() {
if (getBoxDimensions().size() > kMaxTMATensorDimension) {
return emitError() << "Maximum " << kMaxTMATensorDimension
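
The new verifier gives the op early diagnostics. A sketch of IR it rejects (editor's illustration with hypothetical names: a 2-D tensor map descriptor addressed with a single coordinate fails the rank check above):

// error: the number of coordinates must match the rank of the tensor map descriptor
nvgpu.tma.async.store %src to %desc2d[%x] : memref<32x32xf32,3> -> !tensorMap2d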
46 changes: 46 additions & 0 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -728,6 +728,52 @@ func.func @async_tma_load_multicast(
func.return
}

func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
%buffer1d: memref<128xf32,3>,
%buffer2d: memref<32x32xf32,3>,
%buffer3d: memref<2x32x32xf32,3>,
%buffer4d: memref<2x2x32x32xf32,3>,
%buffer5d: memref<2x2x2x32x32xf32,3>) {
%c0 = arith.constant 0 : index
%crd0 = arith.constant 0 : index
%crd1 = arith.constant 0 : index
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}]
nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}]
nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1] : memref<32x32xf32,3> -> !tensorMap2d
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0] : memref<2x32x32xf32,3> -> !tensorMap3d
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0] : memref<2x2x32x32xf32,3> -> !tensorMap4d
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0] : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
func.return
}


func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
%buffer1d: memref<128xf32,3>,
%buffer2d: memref<32x32xf32,3>,
%buffer3d: memref<2x32x32xf32,3>,
%buffer4d: memref<2x2x32x32xf32,3>,
%buffer5d: memref<2x2x2x32x32xf32,3>,
%p: i1) {
%c0 = arith.constant 0 : index
%crd0 = arith.constant 0 : index
%crd1 = arith.constant 0 : index
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p : memref<32x32xf32,3> -> !tensorMap2d
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p : memref<2x32x32xf32,3> -> !tensorMap3d
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p : memref<2x2x32x32xf32,3> -> !tensorMap4d
// CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], predicate = %p : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
func.return
}

func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
%crd0 = arith.constant 64 : index
%crd1 = arith.constant 128 : index