
Commit 89d5fe9

add test
1 parent 3d2078d commit 89d5fe9


mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 127 additions & 0 deletions
@@ -772,6 +772,133 @@ func.func @warpgroup_mma_128_128_64(
return
}

// CHECK-LABEL: @warpgroup_mma_store(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
func.func @warpgroup_mma_store(
%result1 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
%result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
%matrixD: memref<128x128xf32,3>) {
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[DB:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[arg2]] :
// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32

// ### Store {d0, d1} of each thread ###

// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]] : i32
// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]] : i32
// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
// CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>

// ### Store {d2, d3} of each thread ###

// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
// CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>

// ### Store {d4, d5} of each thread ###

// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
// CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>

// ### Store {d6, d7} of each thread ###

// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
// CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>

// Pattern continues similarly 28 more times until {... d62, d63}

// ### Store {d64, d65} of each thread ###

// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: %[[S312:.+]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: %[[S314:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[S312]] : i32
// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[S312]] : i32
// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]]
// CHECK: %[[S321:.+]] = llvm.urem %[[S318]]
// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S314]] : i32
// CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
// CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
// CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
// CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32
// CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32
// CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32
// CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32
// CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32
// CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
// CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
// CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
// CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
// CHECK: %[[S337:.+]] = llvm.extractvalue %[[DB]][0]
// CHECK: %[[S338:.+]] = llvm.extractvalue %[[DB]][1]
// CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>

// Pattern continues similarly 31 more times until {... d126, d127}

nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD :
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
to memref<128x128xf32,3>
return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1
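
For orientation only: the CHECK lines above pin down the per-thread index arithmetic that the nvgpu.warpgroup.mma.store lowering is expected to emit. Each thread derives a (row, column) pair from its thread id and stores two adjacent f32 fragment values per step. Below is a minimal Python sketch of that arithmetic as it appears in the checked IR; the function name and the i/j parameters are illustrative and not part of the commit, and the struct-element ordering is not modeled.

def fragment_coords(tid, i, j):
    # %[[S9]] / %[[S10]]: split the thread id into a lane and a warp index.
    lane = tid % 32
    warp = tid // 32
    # Row (%[[S18]]): each warp covers 16 rows; i selects the 8-row sub-tile.
    row = warp * 16 + lane // 4 + i * 8
    # Column (%[[S21]]): four lanes share a row, two adjacent columns each;
    # j selects the 8-column tile.
    col = (lane % 4) * 2 + j * 8
    return row, col

# Example: lane 3 of warp 1 (tid = 35), first row sub-tile, second column tile.
print(fragment_coords(35, 0, 1))  # -> (16, 14)

Each checked pair of memref.store ops writes to (row, col) and (row, col + 1); the second accumulator (%result2) is stored with an additional row offset of 64, matching the llvm.mlir.constant(64 : i32) in the %[[S325]]/%[[S326]] checks.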
