Skip to content

Commit 56548e1

Browse files
committed
[Matrix] Fix a crash in VisitSelectInst due to iteration length mismatch
1 parent 639c19d commit 56548e1

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,14 +2326,13 @@ class LowerMatrixIntrinsics {
23262326
MatrixTy A = getMatrix(OpA, Shape, Builder);
23272327
MatrixTy B = getMatrix(OpB, Shape, Builder);
23282328

2329-
Value *CondV[2];
2329+
SmallVector<Value*> CondV;
23302330
if (isa<FixedVectorType>(Cond->getType())) {
23312331
MatrixTy C = getMatrix(Cond, Shape, Builder);
2332-
CondV[0] = C.getVector(0);
2333-
CondV[1] = C.getVector(1);
2332+
llvm::copy(C.vectors(), std::back_inserter(CondV));
23342333
} else {
2335-
CondV[0] = Cond;
2336-
CondV[1] = Cond;
2334+
CondV.resize(A.getNumVectors());
2335+
std::fill(CondV.begin(), CondV.end(), Cond);
23372336
}
23382337

23392338
for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors()))

llvm/test/Transforms/LowerMatrixIntrinsics/select.ll

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,64 @@ define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
144144
store <4 x float> %op, ptr %out
145145
ret void
146146
}
147+
148+
define void @select_2x2_vcond_shape4(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
149+
; CHECK-LABEL: @select_2x2_vcond_shape4(
150+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <4 x float>, ptr [[LHS:%.*]], align 16
151+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
152+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <4 x float>, ptr [[RHS:%.*]], align 4
153+
; CHECK-NEXT: [[TMP1:%.*]] = select <4 x i1> [[COL_LOAD1]], <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD2]]
154+
; CHECK-NEXT: store <4 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
155+
; CHECK-NEXT: ret void
156+
;
157+
%lhsv = load <4 x float>, ptr %lhs
158+
%condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1)
159+
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 4, i1 false, i32 4, i32 1)
160+
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
161+
store <4 x float> %op, ptr %out
162+
ret void
163+
}
164+
165+
define void @select_2x2_vcond_shape5(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
166+
; CHECK-LABEL: @select_2x2_vcond_shape5(
167+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <1 x float>, ptr [[LHS:%.*]], align 16
168+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 1
169+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <1 x float>, ptr [[VEC_GEP]], align 4
170+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[LHS]], i64 2
171+
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <1 x float>, ptr [[VEC_GEP2]], align 8
172+
; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr float, ptr [[LHS]], i64 3
173+
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <1 x float>, ptr [[VEC_GEP4]], align 4
174+
; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <1 x i1>, ptr [[COND:%.*]], align 1
175+
; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr i1, ptr [[COND]], i64 1
176+
; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <1 x i1>, ptr [[VEC_GEP7]], align 1
177+
; CHECK-NEXT: [[VEC_GEP9:%.*]] = getelementptr i1, ptr [[COND]], i64 2
178+
; CHECK-NEXT: [[COL_LOAD10:%.*]] = load <1 x i1>, ptr [[VEC_GEP9]], align 1
179+
; CHECK-NEXT: [[VEC_GEP11:%.*]] = getelementptr i1, ptr [[COND]], i64 3
180+
; CHECK-NEXT: [[COL_LOAD12:%.*]] = load <1 x i1>, ptr [[VEC_GEP11]], align 1
181+
; CHECK-NEXT: [[COL_LOAD13:%.*]] = load <1 x float>, ptr [[RHS:%.*]], align 4
182+
; CHECK-NEXT: [[VEC_GEP14:%.*]] = getelementptr float, ptr [[RHS]], i64 1
183+
; CHECK-NEXT: [[COL_LOAD15:%.*]] = load <1 x float>, ptr [[VEC_GEP14]], align 4
184+
; CHECK-NEXT: [[VEC_GEP16:%.*]] = getelementptr float, ptr [[RHS]], i64 2
185+
; CHECK-NEXT: [[COL_LOAD17:%.*]] = load <1 x float>, ptr [[VEC_GEP16]], align 4
186+
; CHECK-NEXT: [[VEC_GEP18:%.*]] = getelementptr float, ptr [[RHS]], i64 3
187+
; CHECK-NEXT: [[COL_LOAD19:%.*]] = load <1 x float>, ptr [[VEC_GEP18]], align 4
188+
; CHECK-NEXT: [[TMP1:%.*]] = select <1 x i1> [[COL_LOAD6]], <1 x float> [[COL_LOAD]], <1 x float> [[COL_LOAD13]]
189+
; CHECK-NEXT: [[TMP2:%.*]] = select <1 x i1> [[COL_LOAD8]], <1 x float> [[COL_LOAD1]], <1 x float> [[COL_LOAD15]]
190+
; CHECK-NEXT: [[TMP3:%.*]] = select <1 x i1> [[COL_LOAD10]], <1 x float> [[COL_LOAD3]], <1 x float> [[COL_LOAD17]]
191+
; CHECK-NEXT: [[TMP4:%.*]] = select <1 x i1> [[COL_LOAD12]], <1 x float> [[COL_LOAD5]], <1 x float> [[COL_LOAD19]]
192+
; CHECK-NEXT: store <1 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
193+
; CHECK-NEXT: [[VEC_GEP20:%.*]] = getelementptr float, ptr [[OUT]], i64 1
194+
; CHECK-NEXT: store <1 x float> [[TMP2]], ptr [[VEC_GEP20]], align 4
195+
; CHECK-NEXT: [[VEC_GEP21:%.*]] = getelementptr float, ptr [[OUT]], i64 2
196+
; CHECK-NEXT: store <1 x float> [[TMP3]], ptr [[VEC_GEP21]], align 8
197+
; CHECK-NEXT: [[VEC_GEP22:%.*]] = getelementptr float, ptr [[OUT]], i64 3
198+
; CHECK-NEXT: store <1 x float> [[TMP4]], ptr [[VEC_GEP22]], align 4
199+
; CHECK-NEXT: ret void
200+
;
201+
%lhsv = load <4 x float>, ptr %lhs
202+
%condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 1, i1 false, i32 1, i32 4)
203+
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 1, i1 false, i32 1, i32 4)
204+
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
205+
store <4 x float> %op, ptr %out
206+
ret void
207+
}

0 commit comments

Comments
 (0)