Skip to content

[SandboxVec][Legality] Check opcodes and types #113741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H

class Utils {
#include "llvm/SandboxIR/Type.h"

namespace llvm::sandboxir {

class VecUtils {
public:
/// \Returns the number of elements in \p Ty. That is the number of lanes if a
/// fixed vector or 1 if scalar. ScalableVectors have unknown size and
Expand All @@ -25,6 +29,8 @@ class Utils {
static Type *getElementType(Type *Ty) {
return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
}
}
};

} // namespace llvm::sandboxir

#endif LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
20 changes: 19 additions & 1 deletion llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "llvm/SandboxIR/Utils.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm::sandboxir {

Expand All @@ -26,7 +27,24 @@ void LegalityResult::dump() const {
std::optional<ResultReason>
LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
ArrayRef<Value *> Bndl) {
// TODO: Unimplemented.
auto *I0 = cast<Instruction>(Bndl[0]);
auto Opcode = I0->getOpcode();
// If they have different opcodes, then we cannot form a vector (for now).
if (any_of(drop_begin(Bndl), [Opcode](Value *V) {
return cast<Instruction>(V)->getOpcode() != Opcode;
}))
return ResultReason::DiffOpcodes;

// If not the same scalar type, Pack. This will accept scalars and vectors as
// long as the element type is the same.
Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0));
if (any_of(drop_begin(Bndl), [ElmTy0](Value *V) {
return VecUtils::getElementType(Utils::getExpectedType(V)) != ElmTy0;
}))
return ResultReason::DiffTypes;

// TODO: Missing checks

return std::nullopt;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ add_llvm_unittest(SandboxVectorizerTests
LegalityTest.cpp
SchedulerTest.cpp
SeedCollectorTest.cpp
VecUtilsTest.cpp
)
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@ struct LegalityTest : public testing::Test {

TEST_F(LegalityTest, Legality) {
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
%gep0 = getelementptr float, ptr %ptr, i32 0
%gep1 = getelementptr float, ptr %ptr, i32 1
%gep3 = getelementptr float, ptr %ptr, i32 3
%ld0 = load float, ptr %gep0
%ld1 = load float, ptr %gep0
store float %ld0, ptr %gep0
store float %ld1, ptr %gep1
store <2 x float> %vec2, ptr %gep1
store <3 x float> %vec3, ptr %gep3
store i8 %arg, ptr %gep1
ret void
}
)IR");
Expand All @@ -46,10 +50,14 @@ define void @foo(ptr %ptr) {
auto It = BB->begin();
[[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
[[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
auto *StVec3 = cast<sandboxir::StoreInst>(&*It++);
auto *StI8 = cast<sandboxir::StoreInst>(&*It++);

sandboxir::LegalityAnalysis Legality;
const auto &Result = Legality.canVectorize({St0, St1});
Expand All @@ -62,6 +70,23 @@ define void @foo(ptr %ptr) {
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::NotInstructions);
}
{
// Check DiffOpcodes
const auto &Result = Legality.canVectorize({St0, Ld0});
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffOpcodes);
}
{
// Check DiffTypes
EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({St0, StVec2})));
EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({StVec2, StVec3})));

const auto &Result = Legality.canVectorize({St0, StI8});
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffTypes);
}
}

#ifndef NDEBUG
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- VecUtilsTest.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/SandboxIR/Context.h"
#include "llvm/SandboxIR/Type.h"
#include "gtest/gtest.h"

using namespace llvm;

struct VecUtilsTest : public testing::Test {
LLVMContext C;
};

TEST_F(VecUtilsTest, GetNumElements) {
sandboxir::Context Ctx(C);
auto *ElemTy = sandboxir::Type::getInt32Ty(Ctx);
EXPECT_EQ(sandboxir::VecUtils::getNumElements(ElemTy), 1);
auto *VTy = sandboxir::FixedVectorType::get(ElemTy, 2);
EXPECT_EQ(sandboxir::VecUtils::getNumElements(VTy), 2);
auto *VTy1 = sandboxir::FixedVectorType::get(ElemTy, 1);
EXPECT_EQ(sandboxir::VecUtils::getNumElements(VTy1), 1);
}

TEST_F(VecUtilsTest, GetElementType) {
sandboxir::Context Ctx(C);
auto *ElemTy = sandboxir::Type::getInt32Ty(Ctx);
EXPECT_EQ(sandboxir::VecUtils::getElementType(ElemTy), ElemTy);
auto *VTy = sandboxir::FixedVectorType::get(ElemTy, 2);
EXPECT_EQ(sandboxir::VecUtils::getElementType(VTy), ElemTy);
}
Loading