Skip to content

Commit 7a7f919

Browse files
authored
[SandboxVec][Legality] Fix mask on diamond reuse with shuffle (#126963)
This patch fixes a bug in the creation of shuffle masks when vectorizing vectors in case of a diamond reuse with shuffle. The mask needs to enumerate all elements of a vector, not treat the original vector value as a single element. That is: if vectorizing two <2 x float> vectors into a <4 x float> the mask needs to have 4 indices, not just 2.
1 parent 9478822 commit 7a7f919

File tree

5 files changed

+63
-3
lines changed

5 files changed

+63
-3
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/SandboxIR/Value.h"
1919
#include "llvm/Support/Casting.h"
2020
#include "llvm/Support/raw_ostream.h"
21+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
2122
#include <algorithm>
2223

2324
namespace llvm::sandboxir {
@@ -85,11 +86,13 @@ class InstrMaps {
8586
/// Update the map to reflect that \p Origs got vectorized into \p Vec.
8687
void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
8788
auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
88-
for (auto [Lane, Orig] : enumerate(Origs)) {
89+
unsigned Lane = 0;
90+
for (Value *Orig : Origs) {
8991
auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
9092
assert(Pair.second && "Orig already exists in the map!");
9193
(void)Pair;
9294
OrigToLaneMap[Orig] = Lane;
95+
Lane += VecUtils::getNumLanes(Orig);
9396
}
9497
}
9598
void clear() {

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,20 @@ CollectDescr
202202
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
203203
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
204204
Vec.reserve(Bndl.size());
205-
for (auto [Lane, V] : enumerate(Bndl)) {
205+
uint32_t LaneAccum;
206+
for (auto [Elm, V] : enumerate(Bndl)) {
207+
uint32_t VLanes = VecUtils::getNumLanes(V);
206208
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
207209
// If there is a vector containing `V`, then get the lane it came from.
208210
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
209-
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
211+
// This could be a vector, like <2 x float> in which case the mask needs
212+
// to enumerate all lanes.
213+
for (int Ln = 0; Ln != VLanes; ++Ln)
214+
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
210215
} else {
211216
Vec.emplace_back(V);
212217
}
218+
LaneAccum += VLanes;
213219
}
214220
return CollectDescr(std::move(Vec));
215221
}

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
328328
const ShuffleMask &Mask =
329329
cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
330330
NewVec = createShuffle(VecOp, Mask, UserBB);
331+
assert(NewVec->getType() == VecOp->getType() &&
332+
"Expected same type! Bad mask ?");
331333
break;
332334
}
333335
case LegalityResultID::DiamondReuseMultiInput: {

llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,28 @@ define void @diamondWithShuffle(ptr %ptr) {
243243
ret void
244244
}
245245

246+
; Same but with <2 x float> elements instead of scalars.
247+
define void @diamondWithShuffleFromVec(ptr %ptr) {
248+
; CHECK-LABEL: define void @diamondWithShuffleFromVec(
249+
; CHECK-SAME: ptr [[PTR:%.*]]) {
250+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
251+
; CHECK-NEXT: [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
252+
; CHECK-NEXT: [[VSHUF:%.*]] = shufflevector <4 x float> [[VECL]], <4 x float> [[VECL]], <4 x i32> <i32 2, i32 3, i32 0, i32 1>
253+
; CHECK-NEXT: [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VSHUF]]
254+
; CHECK-NEXT: store <4 x float> [[VEC]], ptr [[PTR0]], align 8
255+
; CHECK-NEXT: ret void
256+
;
257+
%ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
258+
%ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
259+
%ld0 = load <2 x float>, ptr %ptr0
260+
%ld1 = load <2 x float>, ptr %ptr1
261+
%sub0 = fsub <2 x float> %ld0, %ld1
262+
%sub1 = fsub <2 x float> %ld1, %ld0
263+
store <2 x float> %sub0, ptr %ptr0
264+
store <2 x float> %sub1, ptr %ptr1
265+
ret void
266+
}
267+
246268
define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
247269
; CHECK-LABEL: define void @diamondMultiInput(
248270
; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,30 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
8585
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
8686
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
8787
}
88+
89+
TEST_F(InstrMapsTest, VectorLanes) {
90+
parseIR(C, R"IR(
91+
define void @foo(<2 x i8> %v0, <2 x i8> %v1, <4 x i8> %v2, <4 x i8> %v3) {
92+
%vadd0 = add <2 x i8> %v0, %v1
93+
%vadd1 = add <2 x i8> %v0, %v1
94+
%vadd2 = add <4 x i8> %v2, %v3
95+
ret void
96+
}
97+
)IR");
98+
llvm::Function *LLVMF = &*M->getFunction("foo");
99+
sandboxir::Context Ctx(C);
100+
auto *F = Ctx.createFunction(LLVMF);
101+
auto *BB = &*F->begin();
102+
auto It = BB->begin();
103+
104+
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
105+
auto *VAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
106+
auto *VAdd2 = cast<sandboxir::BinaryOperator>(&*It++);
107+
108+
sandboxir::InstrMaps IMaps(Ctx);
109+
110+
// Check that the vector lanes are calculated correctly.
111+
IMaps.registerVector({VAdd0, VAdd1}, VAdd2);
112+
EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd0), 0U);
113+
EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd1), 2U);
114+
}

0 commit comments

Comments
 (0)