@@ -772,6 +772,133 @@ func.func @warpgroup_mma_128_128_64(
  return
}

+ // CHECK-LABEL: @warpgroup_mma_store(
+ // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+ func.func @warpgroup_mma_store(
+     %result1 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+     %result2 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+     %matrixD: memref<128x128xf32, 3>) {
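+
+ // Each accumulator of type !nvgpu.warpgroup.accumulator<fragmented =
+ // vector<64x128xf32>> lowers to an !llvm.struct of 64 f32s: 64x128 = 8192
+ // elements spread over the 128 threads of a warpgroup is 64 values per
+ // thread. The checks below follow how each pair of those values is
+ // written back with plain memref.store ops.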
+ // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[arg2]] :
+ // CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
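+ // An interpretation of the constants above (not part of the test): 32 is
+ // the warp size, 4 the number of threads per row group, 2 the adjacent
+ // elements each thread owns, 8 the row/column tile stride, 1 the step to
+ // the neighboring column, and 16 the rows covered by each warp.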
+
+ // ### Store {d0, d1} of each thread ###
+
+ // CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+ // CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]] : i32
+ // CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]] : i32
+ // CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
+ // CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
+ // CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
+ // CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
+ // CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
+ // CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
+ // CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
+ // CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
+ // CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
+ // CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
+ // CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
+ // CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
+ // CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
+ // CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
+ // CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+ // CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
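+
+ // Put differently (inferred from the checks): with lane = tid % 32 and
+ // warp = tid / 32, d0 is stored to [lane/4 + warp*16, (lane%4)*2] and d1
+ // to the column immediately after it, matching the wgmma accumulator
+ // fragment layout.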
+
+ // ### Store {d2, d3} of each thread ###
+
+ // CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
+ // CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
+ // CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
+ // CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
+ // CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
+ // CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
+ // CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
+ // CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+ // CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
+
+ // ### Store {d4, d5} of each thread ###
+
+ // CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
+ // CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
+ // CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
+ // CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
+ // CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
+ // CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
+ // CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
+ // CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+ // CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
+
+ // ### Store {d6, d7} of each thread ###
+
+ // CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
+ // CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
+ // CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
+ // CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
+ // CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
+ // CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
+ // CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
+ // CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
+ // CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+ // CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
+
+ // The pattern continues in the same way 28 more times, up to {d62, d63}.
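+
+ // A compact reading of the full sequence (inferred, not spelled out in
+ // the test): for column group j = 0..15 and row step i = 0..1, struct
+ // elements [4*j + 2*i] and [4*j + 2*i + 1] are stored to row
+ // lane/4 + warp*16 + i*8 at columns (lane%4)*2 + j*8 and the next one.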
+
+ // ### Store {d64, d65} of each thread ###
+
+ // CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[S312:.+]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[S314:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+ // CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[S312]] : i32
+ // CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[S312]] : i32
+ // CHECK: %[[S320:.+]] = llvm.udiv %[[S318]]
+ // CHECK: %[[S321:.+]] = llvm.urem %[[S318]]
+ // CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S314]] : i32
+ // CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
+ // CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
+ // CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
+ // CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32
+ // CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32
+ // CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32
+ // CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32
+ // CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32
+ // CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
+ // CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
+ // CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
+ // CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
+ // CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0]
+ // CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]
+ // CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
+ // CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
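+
+ // Note (a reading of the checks): the second accumulator %[[S1]] goes
+ // through the same index math, plus a constant row offset of 64
+ // (%[[S325]]), so it fills rows 64..127 of the same memref.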
+
+ // The pattern continues in the same way 31 more times, up to {d126, d127}.
+
+ nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD :
+     !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+     !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+     to memref<128x128xf32, 3>
+ return
+ }
+
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %0 = transform.structured.match ops{["func.func"]} in %arg1