[mlir][nvgpu] Add nvgpu.tma.async.store
#77811
Changes from all commits
@@ -995,6 +995,29 @@ struct NVGPUTmaAsyncLoadOpLowering
    return success();
  }
};

struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {}, rewriter);
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), dest, coords,
        adaptor.getPredicate());
    return success();
  }
};

Review comment (on the coordinate loop): Not for this PR, but we should finally add reversal to the coords to make the dimension ordering work as MLIR normally expects.

Reply: Agreed, let's do this in another PR.
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
@@ -1639,6 +1662,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
      NVGPUMBarrierTestWaitLowering,        // nvgpu.mbarrier.test_wait_parity
      NVGPUMBarrierTryWaitParityLowering,   // nvgpu.mbarrier.try_wait_parity
      NVGPUTmaAsyncLoadOpLowering,          // nvgpu.tma.async.load
      NVGPUTmaAsyncStoreOpLowering,         // nvgpu.tma.async.store
      NVGPUTmaCreateDescriptorOpLowering,   // nvgpu.tma.create.descriptor
      NVGPUTmaPrefetchOpLowering,           // nvgpu.tma.prefetch.descriptor
      NVGPUMBarrierArriveExpectTxLowering,  // nvgpu.mbarrier.arrive.expect_tx
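For context, here is a minimal sketch of what the new op could look like at the nvgpu level before conversion. The assembly syntax, descriptor attributes, and shapes below are assumptions for illustration rather than code from this PR; the pattern above rewrites such an op into NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp, truncating each index coordinate to i32 along the way.

```mlir
// Sketch only: assumed syntax and attributes, not taken from this PR.
!tma_desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32, 3>,
                                        swizzle = none, l2promo = none,
                                        oob = zero, interleave = none>

func.func @tma_store_2d(%src: memref<32x32xf32, 3>, %map: !tma_desc,
                        %crd0: index, %crd1: index) {
  // Asynchronously copy the shared-memory tile %src to the location in
  // global memory described by %map at coordinates (%crd0, %crd1).
  nvgpu.tma.async.store %src to %map[%crd0, %crd1]
      : memref<32x32xf32, 3> -> !tma_desc
  return
}
```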
Review comment: I think we also need to add the commit/wait_group version for TMA stores? Otherwise this instruction is not that useful, unless you expect people to be calling the NVVM versions for now.
Reply: Valid point. I suggest using these Ops directly from the NVVM dialect. Why? The NVGPU dialect acts as a bridge between high-level dialects (such as memref, vector, etc.) and the NVVM dialect. The cp.async.bulk.commit.group is straightforward to use, and introducing it to the NVGPU dialect would only result in replication. If we need higher-level abstractions (like what we have for operations such as mbarrier.group or nvgpu.tma.descriptor), then we can introduce specific Ops in the NVGPU dialect. What do you think?
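For illustration, a sketch of mixing the two dialects as suggested: the nvvm.cp.async.bulk.commit.group spelling is assumed from the op name mentioned above, and the types and attributes are placeholders rather than anything defined in this PR.

```mlir
// Sketch only: op spellings, types, and attributes are assumptions.
!tma_desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32, 3>,
                                        swizzle = none, l2promo = none,
                                        oob = zero, interleave = none>

func.func @store_and_commit(%src: memref<32x32xf32, 3>, %map: !tma_desc,
                            %x: index, %y: index) {
  // Issue the asynchronous TMA store through the NVGPU dialect.
  nvgpu.tma.async.store %src to %map[%x, %y]
      : memref<32x32xf32, 3> -> !tma_desc
  // Close the group of outstanding bulk async copies directly via NVVM.
  nvvm.cp.async.bulk.commit.group
  // Before %src (shared memory) can be reused, the group still has to be
  // waited on; see the note below about cp.async.bulk.wait_group.read.
  return
}
```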
Follow-up: (Just noticed that cp.async.bulk.wait_group.read is missing in NVVM.)
Reply: Sure, sounds good to me. Perhaps it would be a good idea to add a pointer to the NVVM op to the description here.