Skip to content

Commit 12cefcc

Browse files
committed
[Matrix] Skip already fused instructions before trying to fuse multiply.
lowerDotProduct called above may already lower a matrix multiply and mark it as procssed by adding it to FusedInsts. Don't try to process it again in LowerMatrixMultiplyFused by checking if FusedInsts. Without this change, we trigger an assertion when trying to erase the same original matrix multiply twice.
1 parent f623df6 commit 12cefcc

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,8 @@ class LowerMatrixIntrinsics {
10141014

10151015
// Third, try to fuse candidates.
10161016
for (CallInst *CI : MaybeFusableInsts)
1017-
LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
1017+
if (!FusedInsts.contains(CI))
1018+
LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
10181019

10191020
Changed = !FusedInsts.empty();
10201021

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s
3+
4+
define void @test(ptr %p, <8 x i32> %x) {
5+
; CHECK-LABEL: define void @test(
6+
; CHECK-SAME: ptr [[P:%.*]], <8 x i32> [[X:%.*]]) {
7+
; CHECK-NEXT: [[L:%.*]] = load <8 x i32>, ptr [[P]], align 4
8+
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> zeroinitializer
9+
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 1>
10+
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 2>
11+
; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 3>
12+
; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 4>
13+
; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 5>
14+
; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 6>
15+
; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 7>
16+
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <1 x i32> [[SPLIT]], i64 0
17+
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x i32> poison, i32 [[TMP1]], i64 0
18+
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <1 x i32> [[SPLIT1]], i64 0
19+
; CHECK-NEXT: [[TMP4:%.*]] = insertelement <8 x i32> [[TMP2]], i32 [[TMP3]], i64 1
20+
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <1 x i32> [[SPLIT2]], i64 0
21+
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <8 x i32> [[TMP4]], i32 [[TMP5]], i64 2
22+
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <1 x i32> [[SPLIT3]], i64 0
23+
; CHECK-NEXT: [[TMP8:%.*]] = insertelement <8 x i32> [[TMP6]], i32 [[TMP7]], i64 3
24+
; CHECK-NEXT: [[TMP9:%.*]] = extractelement <1 x i32> [[SPLIT4]], i64 0
25+
; CHECK-NEXT: [[TMP10:%.*]] = insertelement <8 x i32> [[TMP8]], i32 [[TMP9]], i64 4
26+
; CHECK-NEXT: [[TMP11:%.*]] = extractelement <1 x i32> [[SPLIT5]], i64 0
27+
; CHECK-NEXT: [[TMP12:%.*]] = insertelement <8 x i32> [[TMP10]], i32 [[TMP11]], i64 5
28+
; CHECK-NEXT: [[TMP13:%.*]] = extractelement <1 x i32> [[SPLIT6]], i64 0
29+
; CHECK-NEXT: [[TMP14:%.*]] = insertelement <8 x i32> [[TMP12]], i32 [[TMP13]], i64 6
30+
; CHECK-NEXT: [[TMP15:%.*]] = extractelement <1 x i32> [[SPLIT7]], i64 0
31+
; CHECK-NEXT: [[TMP16:%.*]] = insertelement <8 x i32> [[TMP14]], i32 [[TMP15]], i64 7
32+
; CHECK-NEXT: [[TMP17:%.*]] = mul <8 x i32> [[L]], [[TMP16]]
33+
; CHECK-NEXT: [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP17]])
34+
; CHECK-NEXT: [[TMP19:%.*]] = insertelement <1 x i32> poison, i32 [[TMP18]], i64 0
35+
; CHECK-NEXT: [[E:%.*]] = extractelement <1 x i32> [[TMP19]], i64 0
36+
; CHECK-NEXT: store i32 [[E]], ptr [[P]], align 4
37+
; CHECK-NEXT: ret void
38+
;
39+
%l = load <8 x i32>, ptr %p, align 4
40+
%t = tail call <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32> %x, i32 1, i32 8)
41+
%m = tail call <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32> %l, <8 x i32> %t, i32 1, i32 8, i32 1)
42+
%e = extractelement <1 x i32> %m, i64 0
43+
store i32 %e, ptr %p, align 4
44+
ret void
45+
}
46+
47+
declare <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32>, i32 immarg, i32 immarg)
48+
49+
declare <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32>, <8 x i32>, i32 immarg, i32 immarg, i32 immarg)

0 commit comments

Comments
 (0)