@@ -1055,6 +1055,136 @@ func.func @warpgroup_mma_store(
1055
1055
return
1056
1056
}
1057
1057
1058
// Lowering test for `nvgpu.warpgroup.mma.store` over a sweep of accumulator
// widths. For every N in 16, 24, ..., 256 the function takes a 64xN f32
// fragmented accumulator and a 64xN f32 memref and stores the former into the
// latter. The expectations below require a run of N/2 `memref.store` ops on
// the matching memref type for each width, in the same order as the store ops
// in the body.
//
// NOTE(review): `%shmem_m64n8k` has no matching accumulator argument and is
// never stored to. It is kept so the function signature stays unchanged;
// presumably the N = 8 case was left out on purpose -- confirm.
// CHECK-LABEL: @warpgroup_mma_store_multiple
func.func @warpgroup_mma_store_multiple(
    // Destination buffers, one per width.
    %shmem_m64n8k : memref<64x8xf32>,
    %shmem_m64n16k : memref<64x16xf32>,
    %shmem_m64n24k : memref<64x24xf32>,
    %shmem_m64n32k : memref<64x32xf32>,
    %shmem_m64n40k : memref<64x40xf32>,
    %shmem_m64n48k : memref<64x48xf32>,
    %shmem_m64n56k : memref<64x56xf32>,
    %shmem_m64n64k : memref<64x64xf32>,
    %shmem_m64n72k : memref<64x72xf32>,
    %shmem_m64n80k : memref<64x80xf32>,
    %shmem_m64n88k : memref<64x88xf32>,
    %shmem_m64n96k : memref<64x96xf32>,
    %shmem_m64n104k : memref<64x104xf32>,
    %shmem_m64n112k : memref<64x112xf32>,
    %shmem_m64n120k : memref<64x120xf32>,
    %shmem_m64n128k : memref<64x128xf32>,
    %shmem_m64n136k : memref<64x136xf32>,
    %shmem_m64n144k : memref<64x144xf32>,
    %shmem_m64n152k : memref<64x152xf32>,
    %shmem_m64n160k : memref<64x160xf32>,
    %shmem_m64n168k : memref<64x168xf32>,
    %shmem_m64n176k : memref<64x176xf32>,
    %shmem_m64n184k : memref<64x184xf32>,
    %shmem_m64n192k : memref<64x192xf32>,
    %shmem_m64n200k : memref<64x200xf32>,
    %shmem_m64n208k : memref<64x208xf32>,
    %shmem_m64n216k : memref<64x216xf32>,
    %shmem_m64n224k : memref<64x224xf32>,
    %shmem_m64n232k : memref<64x232xf32>,
    %shmem_m64n240k : memref<64x240xf32>,
    %shmem_m64n248k : memref<64x248xf32>,
    %shmem_m64n256k : memref<64x256xf32>,
    // Accumulator fragments to be stored, one per width starting at N = 16.
    %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
    %res_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>>,
    %res_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
    %res_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>>,
    %res_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>>,
    %res_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>>,
    %res_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
    %res_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>>,
    %res_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>>,
    %res_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>>,
    %res_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>>,
    %res_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>>,
    %res_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>>,
    %res_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>>,
    %res_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    %res_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>>,
    %res_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>>,
    %res_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>>,
    %res_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>>,
    %res_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>>,
    %res_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>>,
    %res_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>>,
    %res_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>>,
    %res_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>>,
    %res_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>>,
    %res_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>>,
    %res_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>>,
    %res_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>>,
    %res_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>>,
    %res_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>>,
    %res_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>>) {
    // Expected output: for each width N, exactly N/2 consecutive stores into
    // the 64xN memref. The directive order must mirror the op order below.
    // CHECK-COUNT-8: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32>
    // CHECK-COUNT-12: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x24xf32>
    // CHECK-COUNT-16: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x32xf32>
    // CHECK-COUNT-20: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x40xf32>
    // CHECK-COUNT-24: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x48xf32>
    // CHECK-COUNT-28: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x56xf32>
    // CHECK-COUNT-32: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf32>
    // CHECK-COUNT-36: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x72xf32>
    // CHECK-COUNT-40: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x80xf32>
    // CHECK-COUNT-44: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x88xf32>
    // CHECK-COUNT-48: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x96xf32>
    // CHECK-COUNT-52: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x104xf32>
    // CHECK-COUNT-56: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x112xf32>
    // CHECK-COUNT-60: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x120xf32>
    // CHECK-COUNT-64: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x128xf32>
    // CHECK-COUNT-68: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x136xf32>
    // CHECK-COUNT-72: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x144xf32>
    // CHECK-COUNT-76: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x152xf32>
    // CHECK-COUNT-80: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x160xf32>
    // CHECK-COUNT-84: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x168xf32>
    // CHECK-COUNT-88: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x176xf32>
    // CHECK-COUNT-92: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x184xf32>
    // CHECK-COUNT-96: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x192xf32>
    // CHECK-COUNT-100: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x200xf32>
    // CHECK-COUNT-104: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x208xf32>
    // CHECK-COUNT-108: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x216xf32>
    // CHECK-COUNT-112: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x224xf32>
    // CHECK-COUNT-116: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x232xf32>
    // CHECK-COUNT-120: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x240xf32>
    // CHECK-COUNT-124: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x248xf32>
    // CHECK-COUNT-128: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x256xf32>
    // Ops under test, in increasing width order.
    nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32>
    nvgpu.warpgroup.mma.store %res_m64n24k, %shmem_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>> to memref<64x24xf32>
    nvgpu.warpgroup.mma.store %res_m64n32k, %shmem_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>> to memref<64x32xf32>
    nvgpu.warpgroup.mma.store %res_m64n40k, %shmem_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>> to memref<64x40xf32>
    nvgpu.warpgroup.mma.store %res_m64n48k, %shmem_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>> to memref<64x48xf32>
    nvgpu.warpgroup.mma.store %res_m64n56k, %shmem_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>> to memref<64x56xf32>
    nvgpu.warpgroup.mma.store %res_m64n64k, %shmem_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>> to memref<64x64xf32>
    nvgpu.warpgroup.mma.store %res_m64n72k, %shmem_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>> to memref<64x72xf32>
    nvgpu.warpgroup.mma.store %res_m64n80k, %shmem_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>> to memref<64x80xf32>
    nvgpu.warpgroup.mma.store %res_m64n88k, %shmem_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>> to memref<64x88xf32>
    nvgpu.warpgroup.mma.store %res_m64n96k, %shmem_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>> to memref<64x96xf32>
    nvgpu.warpgroup.mma.store %res_m64n104k, %shmem_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>> to memref<64x104xf32>
    nvgpu.warpgroup.mma.store %res_m64n112k, %shmem_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>> to memref<64x112xf32>
    nvgpu.warpgroup.mma.store %res_m64n120k, %shmem_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>> to memref<64x120xf32>
    nvgpu.warpgroup.mma.store %res_m64n128k, %shmem_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to memref<64x128xf32>
    nvgpu.warpgroup.mma.store %res_m64n136k, %shmem_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>> to memref<64x136xf32>
    nvgpu.warpgroup.mma.store %res_m64n144k, %shmem_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>> to memref<64x144xf32>
    nvgpu.warpgroup.mma.store %res_m64n152k, %shmem_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>> to memref<64x152xf32>
    nvgpu.warpgroup.mma.store %res_m64n160k, %shmem_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>> to memref<64x160xf32>
    nvgpu.warpgroup.mma.store %res_m64n168k, %shmem_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>> to memref<64x168xf32>
    nvgpu.warpgroup.mma.store %res_m64n176k, %shmem_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>> to memref<64x176xf32>
    nvgpu.warpgroup.mma.store %res_m64n184k, %shmem_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>> to memref<64x184xf32>
    nvgpu.warpgroup.mma.store %res_m64n192k, %shmem_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>> to memref<64x192xf32>
    nvgpu.warpgroup.mma.store %res_m64n200k, %shmem_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>> to memref<64x200xf32>
    nvgpu.warpgroup.mma.store %res_m64n208k, %shmem_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>> to memref<64x208xf32>
    nvgpu.warpgroup.mma.store %res_m64n216k, %shmem_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>> to memref<64x216xf32>
    nvgpu.warpgroup.mma.store %res_m64n224k, %shmem_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>> to memref<64x224xf32>
    nvgpu.warpgroup.mma.store %res_m64n232k, %shmem_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>> to memref<64x232xf32>
    nvgpu.warpgroup.mma.store %res_m64n240k, %shmem_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>> to memref<64x240xf32>
    nvgpu.warpgroup.mma.store %res_m64n248k, %shmem_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>> to memref<64x248xf32>
    nvgpu.warpgroup.mma.store %res_m64n256k, %shmem_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>> to memref<64x256xf32>
    return
}
1187
+
1058
1188
func.func @warpgroup_mma_init () {
1059
1189
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
1060
1190
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
0 commit comments