Skip to content

[clang][bytecode] Allow adding offsets to function pointers #105641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions clang/lib/AST/ByteCode/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,34 +885,60 @@ bool Compiler<Emitter>::VisitPointerArithBinOp(const BinaryOperator *E) {
if (!LT || !RT)
return false;

// Visit the given pointer expression and optionally convert to a PT_Ptr.
auto visitAsPointer = [&](const Expr *E, PrimType T) -> bool {
if (!this->visit(E))
return false;
if (T != PT_Ptr)
return this->emitDecayPtr(T, PT_Ptr, E);
return true;
};

if (LHS->getType()->isPointerType() && RHS->getType()->isPointerType()) {
if (Op != BO_Sub)
return false;

assert(E->getType()->isIntegerType());
if (!visit(RHS) || !visit(LHS))
if (!visitAsPointer(RHS, *RT) || !visitAsPointer(LHS, *LT))
return false;

return this->emitSubPtr(classifyPrim(E->getType()), E);
}

PrimType OffsetType;
if (LHS->getType()->isIntegerType()) {
if (!visit(RHS) || !visit(LHS))
if (!visitAsPointer(RHS, *RT))
return false;
if (!this->visit(LHS))
return false;
OffsetType = *LT;
} else if (RHS->getType()->isIntegerType()) {
if (!visit(LHS) || !visit(RHS))
if (!visitAsPointer(LHS, *LT))
return false;
if (!this->visit(RHS))
return false;
OffsetType = *RT;
} else {
return false;
}

if (Op == BO_Add)
return this->emitAddOffset(OffsetType, E);
else if (Op == BO_Sub)
return this->emitSubOffset(OffsetType, E);
// Do the operation and optionally transform to
// result pointer type.
if (Op == BO_Add) {
if (!this->emitAddOffset(OffsetType, E))
return false;

if (classifyPrim(E) != PT_Ptr)
return this->emitDecayPtr(PT_Ptr, classifyPrim(E), E);
return true;
} else if (Op == BO_Sub) {
if (!this->emitSubOffset(OffsetType, E))
return false;

if (classifyPrim(E) != PT_Ptr)
return this->emitDecayPtr(PT_Ptr, classifyPrim(E), E);
return true;
}

return false;
}
Expand Down
43 changes: 43 additions & 0 deletions clang/lib/AST/ByteCode/FunctionPointer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//===----------------------- FunctionPointer.cpp ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "FunctionPointer.h"

namespace clang {
namespace interp {

APValue FunctionPointer::toAPValue(const ASTContext &) const {
if (!Func)
return APValue(static_cast<Expr *>(nullptr), CharUnits::Zero(), {},
/*OnePastTheEnd=*/false, /*IsNull=*/true);

if (!Valid)
return APValue(static_cast<Expr *>(nullptr),
CharUnits::fromQuantity(getIntegerRepresentation()), {},
/*OnePastTheEnd=*/false, /*IsNull=*/false);

if (Func->getDecl())
return APValue(Func->getDecl(), CharUnits::fromQuantity(Offset), {},
/*OnePastTheEnd=*/false, /*IsNull=*/false);
return APValue(Func->getExpr(), CharUnits::fromQuantity(Offset), {},
/*OnePastTheEnd=*/false, /*IsNull=*/false);
}

void FunctionPointer::print(llvm::raw_ostream &OS) const {
OS << "FnPtr(";
if (Func && Valid)
OS << Func->getName();
else if (Func)
OS << reinterpret_cast<uintptr_t>(Func);
else
OS << "nullptr";
OS << ") + " << Offset;
}

} // namespace interp
} // namespace clang
41 changes: 10 additions & 31 deletions clang/lib/AST/ByteCode/FunctionPointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,29 @@

#include "Function.h"
#include "Primitives.h"
#include "clang/AST/APValue.h"

namespace clang {
class ASTContext;
class APValue;
namespace interp {

class FunctionPointer final {
private:
const Function *Func;
uint64_t Offset;
bool Valid;

public:
FunctionPointer() = default;
FunctionPointer(const Function *Func) : Func(Func), Valid(true) {}
FunctionPointer(const Function *Func, uint64_t Offset = 0)
: Func(Func), Offset(Offset), Valid(true) {}

FunctionPointer(uintptr_t IntVal, const Descriptor *Desc = nullptr)
: Func(reinterpret_cast<const Function *>(IntVal)), Valid(false) {}
: Func(reinterpret_cast<const Function *>(IntVal)), Offset(0),
Valid(false) {}

const Function *getFunction() const { return Func; }
uint64_t getOffset() const { return Offset; }
bool isZero() const { return !Func; }
bool isValid() const { return Valid; }
bool isWeak() const {
Expand All @@ -39,33 +43,8 @@ class FunctionPointer final {
return Func->getDecl()->isWeak();
}

APValue toAPValue(const ASTContext &) const {
if (!Func)
return APValue(static_cast<Expr *>(nullptr), CharUnits::Zero(), {},
/*OnePastTheEnd=*/false, /*IsNull=*/true);

if (!Valid)
return APValue(static_cast<Expr *>(nullptr),
CharUnits::fromQuantity(getIntegerRepresentation()), {},
/*OnePastTheEnd=*/false, /*IsNull=*/false);

if (Func->getDecl())
return APValue(Func->getDecl(), CharUnits::Zero(), {},
/*OnePastTheEnd=*/false, /*IsNull=*/false);
return APValue(Func->getExpr(), CharUnits::Zero(), {},
/*OnePastTheEnd=*/false, /*IsNull=*/false);
}

void print(llvm::raw_ostream &OS) const {
OS << "FnPtr(";
if (Func && Valid)
OS << Func->getName();
else if (Func)
OS << reinterpret_cast<uintptr_t>(Func);
else
OS << "nullptr";
OS << ")";
}
APValue toAPValue(const ASTContext &) const;
void print(llvm::raw_ostream &OS) const;

std::string toDiagnosticString(const ASTContext &Ctx) const {
if (!Func)
Expand All @@ -79,7 +58,7 @@ class FunctionPointer final {
}

ComparisonCategoryResult compare(const FunctionPointer &RHS) const {
if (Func == RHS.Func)
if (Func == RHS.Func && Offset == RHS.Offset)
return ComparisonCategoryResult::Equal;
return ComparisonCategoryResult::Unordered;
}
Expand Down
37 changes: 32 additions & 5 deletions clang/lib/AST/ByteCode/Interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1857,8 +1857,23 @@ bool OffsetHelper(InterpState &S, CodePtr OpPC, const T &Offset,
else
S.Stk.push<Pointer>(V - O, Ptr.asIntPointer().Desc);
return true;
} else if (Ptr.isFunctionPointer()) {
uint64_t O = static_cast<uint64_t>(Offset);
uint64_t N;
if constexpr (Op == ArithOp::Add)
N = Ptr.getByteOffset() + O;
else
N = Ptr.getByteOffset() - O;

if (N > 1)
S.CCEDiag(S.Current->getSource(OpPC), diag::note_constexpr_array_index)
<< N << /*non-array*/ true << 0;
S.Stk.push<Pointer>(Ptr.asFunctionPointer().getFunction(), N);
return true;
}

assert(Ptr.isBlockPointer());

uint64_t MaxIndex = static_cast<uint64_t>(Ptr.getNumElems());
uint64_t Index;
if (Ptr.isOnePastEnd())
Expand Down Expand Up @@ -2024,10 +2039,15 @@ inline bool SubPtr(InterpState &S, CodePtr OpPC) {
return true;
}

T A = LHS.isElementPastEnd() ? T::from(LHS.getNumElems())
: T::from(LHS.getIndex());
T B = RHS.isElementPastEnd() ? T::from(RHS.getNumElems())
: T::from(RHS.getIndex());
T A = LHS.isBlockPointer()
? (LHS.isElementPastEnd() ? T::from(LHS.getNumElems())
: T::from(LHS.getIndex()))
: T::from(LHS.getIntegerRepresentation());
T B = RHS.isBlockPointer()
? (RHS.isElementPastEnd() ? T::from(RHS.getNumElems())
: T::from(RHS.getIndex()))
: T::from(RHS.getIntegerRepresentation());

return AddSubMulHelper<T, T::sub, std::minus>(S, OpPC, A.bitWidth(), A, B);
}

Expand Down Expand Up @@ -2905,8 +2925,15 @@ inline bool DecayPtr(InterpState &S, CodePtr OpPC) {

if constexpr (std::is_same_v<FromT, FunctionPointer> &&
std::is_same_v<ToT, Pointer>) {
S.Stk.push<Pointer>(OldPtr.getFunction());
S.Stk.push<Pointer>(OldPtr.getFunction(), OldPtr.getOffset());
return true;
} else if constexpr (std::is_same_v<FromT, Pointer> &&
std::is_same_v<ToT, FunctionPointer>) {
if (OldPtr.isFunctionPointer()) {
S.Stk.push<FunctionPointer>(OldPtr.asFunctionPointer().getFunction(),
OldPtr.getByteOffset());
return true;
}
}

S.Stk.push<ToT>(ToT(OldPtr.getIntegerRepresentation(), nullptr));
Expand Down
6 changes: 4 additions & 2 deletions clang/lib/AST/ByteCode/Pointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class Pointer {
if (isIntegralPointer())
return asIntPointer().Value + (Offset * elemSize());
if (isFunctionPointer())
return asFunctionPointer().getIntegerRepresentation();
return asFunctionPointer().getIntegerRepresentation() + Offset;
return reinterpret_cast<uint64_t>(asBlockPointer().Pointee) + Offset;
}

Expand Down Expand Up @@ -551,7 +551,7 @@ class Pointer {
}

/// Returns the byte offset from the start.
unsigned getByteOffset() const {
uint64_t getByteOffset() const {
if (isIntegralPointer())
return asIntPointer().Value + Offset;
if (isOnePastEnd())
Expand Down Expand Up @@ -614,6 +614,8 @@ class Pointer {

/// Checks if the pointer is pointing to a zero-size array.
bool isZeroSizeArray() const {
if (isFunctionPointer())
return false;
if (const auto *Desc = getFieldDesc())
return Desc->isZeroSizeArray();
return false;
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ add_clang_library(clangAST
ByteCode/EvalEmitter.cpp
ByteCode/Frame.cpp
ByteCode/Function.cpp
ByteCode/FunctionPointer.cpp
ByteCode/InterpBuiltin.cpp
ByteCode/Floating.cpp
ByteCode/EvaluationResult.cpp
Expand Down
16 changes: 16 additions & 0 deletions clang/test/AST/ByteCode/c.c
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,19 @@ void T1(void) {

enum teste1 test1f(void), (*test1)(void) = test1f; // pedantic-warning {{ISO C forbids forward references to 'enum' types}}
enum teste1 { TEST1 };


void func(void) {
_Static_assert(func + 1 - func == 1, ""); // pedantic-warning {{arithmetic on a pointer to the function type}} \
// pedantic-warning {{arithmetic on pointers to the function type}} \
// pedantic-warning {{not an integer constant expression}}
_Static_assert(func + 0xdead000000000000UL - 0xdead000000000000UL == func, ""); // pedantic-warning 2{{arithmetic on a pointer to the function type}} \
// pedantic-warning {{not an integer constant expression}} \
// pedantic-note {{cannot refer to element 16045481047390945280 of non-array object in a constant expression}}
_Static_assert(func + 1 != func, ""); // pedantic-warning {{arithmetic on a pointer to the function type}} \
// pedantic-warning {{expression is not an integer constant expression}}
func + 0xdead000000000000UL; // all-warning {{expression result unused}} \
// pedantic-warning {{arithmetic on a pointer to the function type}}
func - 0xdead000000000000UL; // all-warning {{expression result unused}} \
// pedantic-warning {{arithmetic on a pointer to the function type}}
}
Loading