@@ -728,6 +728,52 @@ func.func @async_tma_load_multicast(
  func.return
}
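+// Each store below should lower to a single
+// nvvm.cp.async.bulk.tensor.global.shared.cta op, with one box[...] coordinate
+// per tensor-map dimension (1d through 5d).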
+func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+                           %buffer1d: memref<128xf32,3>,
+                           %buffer2d: memref<32x32xf32,3>,
+                           %buffer3d: memref<2x32x32xf32,3>,
+                           %buffer4d: memref<2x2x32x32xf32,3>,
+                           %buffer5d: memref<2x2x2x32x32xf32,3>) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}]
+  nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}]
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1] : memref<32x32xf32,3> -> !tensorMap2d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0] : memref<2x32x32xf32,3> -> !tensorMap3d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0] : memref<2x2x32x32xf32,3> -> !tensorMap4d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0] : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+  func.return
+}
+
+
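+// Same 1d-5d lowering as above, but each store carries an i1 predicate that
+// must be forwarded to the NVVM op.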
+func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+                                     %buffer1d: memref<128xf32,3>,
+                                     %buffer2d: memref<32x32xf32,3>,
+                                     %buffer3d: memref<2x32x32xf32,3>,
+                                     %buffer4d: memref<2x2x32x32xf32,3>,
+                                     %buffer5d: memref<2x2x2x32x32xf32,3>,
+                                     %p: i1) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p : memref<32x32xf32,3> -> !tensorMap2d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p : memref<2x32x32xf32,3> -> !tensorMap3d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p : memref<2x2x32x32xf32,3> -> !tensorMap4d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], predicate = %p : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+  func.return
+}
+
func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
  %crd0 = arith.constant 64 : index
  %crd1 = arith.constant 128 : index