[mlir][nvvm] Improve cp.async.bulk.tensor.shared.cluster.global for multicast #72429

Merged Nov 16, 2023 (2 commits)
37 changes: 31 additions & 6 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1405,13 +1405,29 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
Arguments<(ins LLVM_PointerShared:$dstMem,
LLVM_AnyPointer:$tmaDescriptor,
LLVM_PointerShared:$mbar,
Optional<I16>:$multicastMask,
Variadic<I32>:$coordinates,
PtxPredicate:$predicate)> {
let description = [{
Initiates an asynchronous copy operation on the tensor data from global
memory to shared memory.

The `multicastMask` operand is optional. When it is present, the Op copies
data from global memory to shared memory of multiple CTAs in the cluster.
Operand `multicastMask` specifies the destination CTAs in the cluster such
that each bit position in the 16-bit `multicastMask` operand corresponds to
the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.

[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
}];
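
As a usage illustration (a minimal sketch, assuming %dest, %tmaDescriptor, %barrier, and %crd0 are defined as in the tests further down): a 1-d load multicast to the CTAs with ctaid 0 and 1 sets mask bits 0 and 1.

%mask = llvm.mlir.constant(3 : i16) : i16
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    multicast_mask = %mask, box [%crd0]
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32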

Collaborator:
Do we have to make it another op? Can't the existing op below be extended to have the multicast mask as an attribute and generate the right PTX in this case?

Contributor:

+1 for extending the existing Op.

I am not sure we can have the mask as an attribute (since it may not always be a compile-time constant).

However, we can use the same Op with the mask as an optional operand. That way, if we have the mask available, we generate the multicast variant (but use the existing one otherwise).

Member Author:

I was about to add this idea to the PR to gather preferences. Using the same Op is perfectly fine for this.

We've divided the original PTX instruction into two Ops:
- load: nvvm.cp.async.bulk.tensor.shared.cluster.global
- store: nvvm.cp.async.bulk.tensor.global.shared.cta

I believe separating these two makes sense since they address different concerns.
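
For illustration, a rough sketch of the two forms side by side (the load follows the examples below; the store operand order is a sketch based on the tma_store test at the end of this diff):

// load: global -> shared::cluster, completion signaled through an mbarrier
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0]
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32
// store: shared::cta -> global
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box [%crd0]
    : !llvm.ptr, !llvm.ptr<3>, i32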

> I am not sure we can have the mask as an attribute (since it may not always be a compile-time constant).
> However, we can use the same Op with the mask as an optional operand. That way, if we have the mask available, we generate the multicast variant (but use the existing one otherwise).

I think this is the way I will go.

Member Author (@grypp), Nov 16, 2023:

I have updated the PR to use the existing Op.

Yet, when we include other features such as the L2 cache hint and im2col, the Op will grow. Personally, I find this consistent with PTX. Do you have any concerns? If not, I can put up a follow-up PR to support the remaining features.

For example, the current op is below:

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
box [%crd0,%crd1,%crd2,%crd3] 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32

with multicast_mask

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3] 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32

with multicast_mask + l2_cache_hint

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, l2_cache_hint = %cache, 
box [%crd0,%crd1,%crd2,%crd3] 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32

with multicast_mask + l2_cache_hint + im2col

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache, 
box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2] 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16

Same as above with predicate

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
box [%crd0,%crd1,%crd2,%crd3], 
predicate = %p 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3], 
predicate = %p  
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, l2_cache_hint = %cache, 
box [%crd0,%crd1,%crd2,%crd3], 
predicate = %p  
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, l2_cache_hint = %cache, 
box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2], 
predicate = %p  
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16, i1

Contributor:

> I have updated the PR to use the existing Op.

The updated version looks good to me.

> Yet, when we include other features such as the L2 cache hint and im2col, the Op will grow. Personally, I find this consistent with PTX. Do you have any concerns? If not, I can put up a follow-up PR to support the remaining features.

I do not see any concerns. We can extend it the same way for the cache hint.

I believe im2col itself will be a variadic operand (since it can have size 1, 2, or 3). So, as long as we can have an operand that is both variadic and optional, we are good with this direction.
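
For reference, that combination is expressible: with more than one optional/variadic operand group, the Op would carry the AttrSizedOperandSegments trait, which records each group's size in the generic form. A sketch (segment sizes here are illustrative: dstMem, tmaDescriptor, mbar, multicastMask, two coordinates, no predicate):

"nvvm.cp.async.bulk.tensor.shared.cluster.global"(%dest, %tmaDescriptor, %barrier, %multicastMask, %crd0, %crd1)
    {operandSegmentSizes = array<i32: 1, 1, 1, 1, 2, 0>}
    : (!llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32) -> ()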


Collaborator:

A shorter mnemonic that captures the operation, combined with attributes, seems more like MLIR to me :)

Likely more friendly for the user to create as well!

let assemblyFormat = [{
$dstMem `,`
$tmaDescriptor `,`
$mbar `,`
( `multicast_mask` `=` $multicastMask^ `,` )?
`box` `[`$coordinates `]`
(`,` `predicate` `=` $predicate^)?
attr-dict `:` type(operands)
}];
@@ -1422,11 +1438,20 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
std::string ptx = "cp.async.bulk.tensor.";
ptx += std::to_string(dim) + "d.";
ptx += "shared::cluster.global.mbarrier::complete_tx::bytes";
if(getMulticastMask()) {
ptx += ".multicast::cluster";
if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
} else {
if(dim == 1) ptx += " [%0], [%1, {%3} ], [%2];";
if(dim == 2) ptx += " [%0], [%1, {%3, %4} ], [%2];";
if(dim == 3) ptx += " [%0], [%1, {%3, %4, %5} ], [%2];";
if(dim == 4) ptx += " [%0], [%1, {%3, %4, %5, %6} ], [%2];";
if(dim == 5) ptx += " [%0], [%1, {%3, %4, %5, %6, %7} ], [%2];";
}
return ptx;
}
}];
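
To make the operand renumbering concrete: when the mask is present, it becomes asm operand %3 and the coordinates shift up by one. For dim = 2, the two strings this helper returns are:

cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [%0], [%1, {%4, %5} ], [%2], %3;
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1, {%3, %4} ], [%2];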
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -975,7 +975,7 @@ struct NVGPUTmaAsyncLoadOpLowering
}

rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
op, dest, adaptor.getTensorMapDescriptor(), barrier, Value(), coords,
adaptor.getPredicate());
return success();
}
45 changes: 45 additions & 0 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -130,6 +130,51 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
return
}

// CHECK-LABEL: @tma_load_multicast1d
func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r,b"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32,i1
return
}

// CHECK-LABEL: @tma_load_multicast2d
func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r,b"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i1
return
}

// CHECK-LABEL: @tma_load_multicast3d
func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r,b"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i1
return
}

// CHECK-LABEL: @tma_load_multicast4d
func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r,b"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1
return
}

// CHECK-LABEL: @tma_load_multicast5d
func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r,b"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32, i1
return
}

// CHECK-LABEL: @tma_store_1d
func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"