Skip to content

Commit f10153f

Browse files
committed
[Matrix] Handle integer types when distributing transposes across adds.
The current code did not properly account for integer matrixes. Check if the operands are floating point or integer matrixes and use FAdd/Add accordingly. This is already done for other cases, like multiplies. Fixes llvm#62281.
1 parent 737820e commit f10153f

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -838,10 +838,13 @@ class LowerMatrixIntrinsics {
838838
auto NewInst = distributeTransposes(
839839
TAMA, {R, C}, TAMB, {R, C}, Builder,
840840
[&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
841-
auto *FAdd =
842-
cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
843-
setShapeInfo(FAdd, Shape0);
844-
return FAdd;
841+
bool IsFP = I.getType()->isFPOrFPVectorTy();
842+
auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
843+
: LocalBuilder.CreateAdd(T0, T1, "madd");
844+
845+
auto *Result = cast<Instruction>(Add);
846+
setShapeInfo(Result, Shape0);
847+
return Result;
845848
});
846849
updateShapeAndReplaceAllUsesWith(I, NewInst);
847850
eraseFromParentAndMove(&I, II, BB);

llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ define void @a_plus_b_t(ptr %Aptr, ptr %Bptr, ptr %C) {
121121
; CHECK-NEXT: entry:
122122
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
123123
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
124-
; CHECK-NEXT: [[MFADD1:%.*]] = fadd <9 x double> [[A]], [[B]]
125-
; CHECK-NEXT: [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD1]], i32 3, i32 3)
124+
; CHECK-NEXT: [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
125+
; CHECK-NEXT: [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
126126
; CHECK-NEXT: store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
127127
; CHECK-NEXT: ret void
128128
;
@@ -203,8 +203,8 @@ define void @atbt_plus_kctdt_t(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, doubl
203203
; CHECK-NEXT: [[MMUL2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
204204
; CHECK-NEXT: [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[C]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
205205
; CHECK-NEXT: [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[MMUL1]], i32 3, i32 3, i32 3)
206-
; CHECK-NEXT: [[MFADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
207-
; CHECK-NEXT: store <9 x double> [[MFADD]], ptr [[E:%.*]], align 128
206+
; CHECK-NEXT: [[MADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
207+
; CHECK-NEXT: store <9 x double> [[MADD]], ptr [[E:%.*]], align 128
208208
; CHECK-NEXT: ret void
209209
;
210210
entry:
@@ -257,3 +257,25 @@ entry:
257257
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
258258
declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
259259
declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)
260+
261+
262+
; (a * b + c)^T -> (a * b)^T + b^T with integer types.
263+
define noundef <4 x i32> @mul_add_transpose_int(<4 x i32> noundef %a, <4 x i32> noundef %b, <4 x i32> noundef %c) {
264+
; CHECK-LABEL: @mul_add_transpose_int(
265+
; CHECK-NEXT: entry:
266+
; CHECK-NEXT: [[TMP0:%.*]] = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], i32 2, i32 2, i32 2)
267+
; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[TMP0]], i32 2, i32 2)
268+
; CHECK-NEXT: [[C_T:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[C:%.*]], i32 2, i32 2)
269+
; CHECK-NEXT: [[MADD:%.*]] = add <4 x i32> [[TMP1]], [[C_T]]
270+
; CHECK-NEXT: ret <4 x i32> [[MADD]]
271+
;
272+
entry:
273+
%mul = tail call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2)
274+
%add = add <4 x i32> %mul, %c
275+
%t = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %add, i32 2, i32 2)
276+
ret <4 x i32> %t
277+
}
278+
279+
declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32 immarg, i32 immarg, i32 immarg)
280+
281+
declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32 immarg, i32 immarg)

0 commit comments

Comments
 (0)