Skip to content

[SandboxVec][Legality] Fix legality of SelectInst #125005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,15 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
return std::nullopt;
return ResultReason::DiffOpcodes;
}
case Instruction::Opcode::Select:
case Instruction::Opcode::Select: {
auto *Sel0 = cast<SelectInst>(Bndl[0]);
auto *Cond0 = Sel0->getCondition();
if (VecUtils::getNumLanes(Cond0) != VecUtils::getNumLanes(Sel0))
// TODO: For now we don't vectorize if the lanes in the condition don't
// match those of the select instruction.
return ResultReason::Unimplemented;
return std::nullopt;
}
case Instruction::Opcode::FNeg:
case Instruction::Opcode::Add:
case Instruction::Opcode::FAdd:
Expand Down
97 changes: 97 additions & 0 deletions llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2 -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s

; This file includes tests for opcodes that need special checks.

; TODO: Selects whose condition has a different number of lanes than the
; select instruction itself need special treatment.
; Both selects below take a scalar i1 condition (%cond0 / %cond1) while
; producing a <2 x i8> value. The expected output keeps the two selects as they
; are and packs their results, element by element, into a <4 x i8> that feeds a
; single wide store.
; NOTE: %op0 and %op1 are not referenced in the function body.
define void @selects_with_diff_cond_lanes(ptr %ptr, i1 %cond0, i1 %cond1, <2 x i8> %op0, <2 x i8> %op1) {
; CHECK-LABEL: define void @selects_with_diff_cond_lanes(
; CHECK-SAME: ptr [[PTR:%.*]], i1 [[COND0:%.*]], i1 [[COND1:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 1
; CHECK-NEXT: [[LD0:%.*]] = load <2 x i8>, ptr [[PTR0]], align 2
; CHECK-NEXT: [[LD1:%.*]] = load <2 x i8>, ptr [[PTR1]], align 2
; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[COND0]], <2 x i8> [[LD0]], <2 x i8> [[LD0]]
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1]], <2 x i8> [[LD1]], <2 x i8> [[LD1]]
; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i8> [[SEL0]], i32 0
; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i8> poison, i8 [[VPACK]], i32 0
; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i8> [[SEL0]], i32 1
; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i8> [[VPACK1]], i8 [[VPACK2]], i32 1
; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i8> [[SEL1]], i32 0
; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i8> [[VPACK3]], i8 [[VPACK4]], i32 2
; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i8> [[SEL1]], i32 1
; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i8> [[VPACK5]], i8 [[VPACK6]], i32 3
; CHECK-NEXT: store <4 x i8> [[VPACK7]], ptr [[PTR0]], align 2
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
%ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
%ld0 = load <2 x i8>, ptr %ptr0
%ld1 = load <2 x i8>, ptr %ptr1
%sel0 = select i1 %cond0, <2 x i8> %ld0, <2 x i8> %ld0
%sel1 = select i1 %cond1, <2 x i8> %ld1, <2 x i8> %ld1
store <2 x i8> %sel0, ptr %ptr0
store <2 x i8> %sel1, ptr %ptr1
ret void
}

; TODO: Selects that share the same condition need special treatment.
; Both selects below use the same scalar i1 %cond while producing a <2 x i8>
; value, so the condition's lane count differs from that of the select. The
; expected output keeps the two selects as they are and packs their results,
; element by element, into a <4 x i8> that feeds a single wide store.
; NOTE: %op0 and %op1 are not referenced in the function body.
define void @selects_with_common_condition_but_diff_lanes(ptr %ptr, i1 %cond, <2 x i8> %op0, <2 x i8> %op1) {
; CHECK-LABEL: define void @selects_with_common_condition_but_diff_lanes(
; CHECK-SAME: ptr [[PTR:%.*]], i1 [[COND:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 1
; CHECK-NEXT: [[LD0:%.*]] = load <2 x i8>, ptr [[PTR0]], align 2
; CHECK-NEXT: [[LD1:%.*]] = load <2 x i8>, ptr [[PTR1]], align 2
; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[COND]], <2 x i8> [[LD0]], <2 x i8> [[LD0]]
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND]], <2 x i8> [[LD1]], <2 x i8> [[LD1]]
; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i8> [[SEL0]], i32 0
; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i8> poison, i8 [[VPACK]], i32 0
; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i8> [[SEL0]], i32 1
; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i8> [[VPACK1]], i8 [[VPACK2]], i32 1
; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i8> [[SEL1]], i32 0
; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i8> [[VPACK3]], i8 [[VPACK4]], i32 2
; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i8> [[SEL1]], i32 1
; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i8> [[VPACK5]], i8 [[VPACK6]], i32 3
; CHECK-NEXT: store <4 x i8> [[VPACK7]], ptr [[PTR0]], align 2
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
%ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
%ld0 = load <2 x i8>, ptr %ptr0
%ld1 = load <2 x i8>, ptr %ptr1
%sel0 = select i1 %cond, <2 x i8> %ld0, <2 x i8> %ld0
%sel1 = select i1 %cond, <2 x i8> %ld1, <2 x i8> %ld1
store <2 x i8> %sel0, ptr %ptr0
store <2 x i8> %sel1, ptr %ptr1
ret void
}

; Selects whose condition has the same number of lanes as the select
; instruction itself can be vectorized as usual.
; Both selects below take a <2 x i1> condition and produce a <2 x i8> value.
; The expected output packs the two conditions into a <4 x i1>, replaces the
; two loads with one <4 x i8> load, and emits a single <4 x i1> select followed
; by one <4 x i8> store.
; NOTE: %op0 and %op1 are not referenced in the function body.
define void @selects_same_cond_lanes(ptr %ptr, <2 x i1> %cond0, <2 x i1> %cond1, <2 x i8> %op0, <2 x i8> %op1) {
; CHECK-LABEL: define void @selects_same_cond_lanes(
; CHECK-SAME: ptr [[PTR:%.*]], <2 x i1> [[COND0:%.*]], <2 x i1> [[COND1:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i1> [[COND0]], i32 0
; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i1> poison, i1 [[VPACK]], i32 0
; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i1> [[COND0]], i32 1
; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i1> [[VPACK1]], i1 [[VPACK2]], i32 1
; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i1> [[COND1]], i32 0
; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i1> [[VPACK3]], i1 [[VPACK4]], i32 2
; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i1> [[COND1]], i32 1
; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i1> [[VPACK5]], i1 [[VPACK6]], i32 3
; CHECK-NEXT: [[VECL:%.*]] = load <4 x i8>, ptr [[PTR0]], align 2
; CHECK-NEXT: [[VEC:%.*]] = select <4 x i1> [[VPACK7]], <4 x i8> [[VECL]], <4 x i8> [[VECL]]
; CHECK-NEXT: store <4 x i8> [[VEC]], ptr [[PTR0]], align 2
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
%ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
%ld0 = load <2 x i8>, ptr %ptr0
%ld1 = load <2 x i8>, ptr %ptr1
%sel0 = select <2 x i1> %cond0, <2 x i8> %ld0, <2 x i8> %ld0
%sel1 = select <2 x i1> %cond1, <2 x i8> %ld1, <2 x i8> %ld1
store <2 x i8> %sel0, ptr %ptr0
store <2 x i8> %sel1, ptr %ptr1
ret void
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F,

TEST_F(LegalityTest, LegalitySkipSchedule) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2, i1 %c0, i1 %c1) {
entry:
%gep0 = getelementptr float, ptr %ptr, i32 0
%gep1 = getelementptr float, ptr %ptr, i32 1
Expand All @@ -93,6 +93,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
%trunc32to8 = trunc i32 %v2 to i8
%cmpSLT = icmp slt i64 %v0, %v1
%cmpSGT = icmp sgt i64 %v0, %v1
%sel0 = select i1 %c0, <2 x float> %vec2, <2 x float> %vec2
%sel1 = select i1 %c1, <2 x float> %vec2, <2 x float> %vec2
ret void
}
)IR");
Expand Down Expand Up @@ -128,6 +130,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
auto *Sel0 = cast<sandboxir::SelectInst>(&*It++);
auto *Sel1 = cast<sandboxir::SelectInst>(&*It++);

llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
Expand Down Expand Up @@ -241,6 +245,15 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::RepeatedInstrs);
}
{
// For now don't vectorize Selects when the number of elements of conditions
// doesn't match the operands.
const auto &Result =
Legality.canVectorize({Sel0, Sel1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::Unimplemented);
}
}

TEST_F(LegalityTest, LegalitySchedule) {
Expand Down