Skip to content

[SandboxVec][Legality] Fix mask on diamond reuse with shuffle #126963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
#include <algorithm>

namespace llvm::sandboxir {
Expand Down Expand Up @@ -85,11 +86,13 @@ class InstrMaps {
/// Update the map to reflect that \p Origs got vectorized into \p Vec.
void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
for (auto [Lane, Orig] : enumerate(Origs)) {
unsigned Lane = 0;
for (Value *Orig : Origs) {
auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
assert(Pair.second && "Orig already exists in the map!");
(void)Pair;
OrigToLaneMap[Orig] = Lane;
Lane += VecUtils::getNumLanes(Orig);
}
}
void clear() {
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,20 @@ CollectDescr
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
Vec.reserve(Bndl.size());
for (auto [Lane, V] : enumerate(Bndl)) {
uint32_t LaneAccum;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use this uninitialized in LaneAccum += VLanes;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is dead code that I accidentally removed in the follow-up patch instead of this one, while maintaining the patch chain. I already pushed the fix: e75e617

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see you sent e75e617 to fix that

for (auto [Elm, V] : enumerate(Bndl)) {
uint32_t VLanes = VecUtils::getNumLanes(V);
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
// If there is a vector containing `V`, then get the lane it came from.
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
// This could be a vector, like <2 x float> in which case the mask needs
// to enumerate all lanes.
for (int Ln = 0; Ln != VLanes; ++Ln)
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
} else {
Vec.emplace_back(V);
}
LaneAccum += VLanes;
}
return CollectDescr(std::move(Vec));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
const ShuffleMask &Mask =
cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
NewVec = createShuffle(VecOp, Mask, UserBB);
assert(NewVec->getType() == VecOp->getType() &&
"Expected same type! Bad mask ?");
break;
}
case LegalityResultID::DiamondReuseMultiInput: {
Expand Down
22 changes: 22 additions & 0 deletions llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,28 @@ define void @diamondWithShuffle(ptr %ptr) {
ret void
}

; Same but with <2 x float> elements instead of scalars.
define void @diamondWithShuffleFromVec(ptr %ptr) {
; CHECK-LABEL: define void @diamondWithShuffleFromVec(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
; CHECK-NEXT: [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
; CHECK-NEXT: [[VSHUF:%.*]] = shufflevector <4 x float> [[VECL]], <4 x float> [[VECL]], <4 x i32> <i32 2, i32 3, i32 0, i32 1>
; CHECK-NEXT: [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VSHUF]]
; CHECK-NEXT: store <4 x float> [[VEC]], ptr [[PTR0]], align 8
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
%ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
%ld0 = load <2 x float>, ptr %ptr0
%ld1 = load <2 x float>, ptr %ptr1
%sub0 = fsub <2 x float> %ld0, %ld1
%sub1 = fsub <2 x float> %ld1, %ld0
store <2 x float> %sub0, ptr %ptr0
store <2 x float> %sub1, ptr %ptr1
ret void
}

define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
; CHECK-LABEL: define void @diamondMultiInput(
; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,30 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
}

TEST_F(InstrMapsTest, VectorLanes) {
parseIR(C, R"IR(
define void @foo(<2 x i8> %v0, <2 x i8> %v1, <4 x i8> %v2, <4 x i8> %v3) {
%vadd0 = add <2 x i8> %v0, %v1
%vadd1 = add <2 x i8> %v0, %v1
%vadd2 = add <4 x i8> %v2, %v3
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();

auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
auto *VAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
auto *VAdd2 = cast<sandboxir::BinaryOperator>(&*It++);

sandboxir::InstrMaps IMaps(Ctx);

// Check that the vector lanes are calculated correctly.
IMaps.registerVector({VAdd0, VAdd1}, VAdd2);
EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd0), 0U);
EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd1), 2U);
}