@@ -121,8 +121,8 @@ define void @a_plus_b_t(ptr %Aptr, ptr %Bptr, ptr %C) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
- ; CHECK-NEXT: [[MFADD1:%.*]] = fadd <9 x double> [[A]], [[B]]
- ; CHECK-NEXT: [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD1]], i32 3, i32 3)
+ ; CHECK-NEXT: [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
+ ; CHECK-NEXT: [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
; CHECK-NEXT: store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
; CHECK-NEXT: ret void
;
@@ -203,8 +203,8 @@ define void @atbt_plus_kctdt_t(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, doubl
; CHECK-NEXT: [[MMUL2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[C]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[MMUL1]], i32 3, i32 3, i32 3)
- ; CHECK-NEXT: [[MFADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
- ; CHECK-NEXT: store <9 x double> [[MFADD]], ptr [[E:%.*]], align 128
+ ; CHECK-NEXT: [[MADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
+ ; CHECK-NEXT: store <9 x double> [[MADD]], ptr [[E:%.*]], align 128
; CHECK-NEXT: ret void
;
entry:
@@ -257,3 +257,25 @@ entry:
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)
+
+
+ ; (a * b + c)^T -> (a * b)^T + c^T with integer types.
+ define noundef <4 x i32> @mul_add_transpose_int(<4 x i32> noundef %a, <4 x i32> noundef %b, <4 x i32> noundef %c) {
+ ; CHECK-LABEL: @mul_add_transpose_int(
+ ; CHECK-NEXT: entry:
+ ; CHECK-NEXT: [[TMP0:%.*]] = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], i32 2, i32 2, i32 2)
+ ; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[TMP0]], i32 2, i32 2)
+ ; CHECK-NEXT: [[C_T:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[C:%.*]], i32 2, i32 2)
+ ; CHECK-NEXT: [[MADD:%.*]] = add <4 x i32> [[TMP1]], [[C_T]]
+ ; CHECK-NEXT: ret <4 x i32> [[MADD]]
+ ;
+ entry:
+ %mul = tail call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2)
+ %add = add <4 x i32> %mul, %c
+ %t = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %add, i32 2, i32 2)
+ ret <4 x i32> %t
+ }
+
+ declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32 immarg, i32 immarg, i32 immarg)
+
+ declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32 immarg, i32 immarg)