Skip to content

Commit 31b03db

Browse files
committed
Add support for safety checks for pointers from cpp1
1 parent d7bbdb1 commit 31b03db

File tree

4 files changed

+150
-6
lines changed

4 files changed

+150
-6
lines changed

include/cpp2util.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,129 @@ class out {
494494
}(PARAM1)
495495
//--------------------------------------------------------------------
496496

497+
//-----------------------------------------------------------------------
498+
//
499+
// cpp2::safety_check() ensures that cpp1 pointers are also covered by safetychecks
500+
//
501+
//-----------------------------------------------------------------------
502+
//
503+
template <typename... Ts>
504+
inline constexpr auto program_violates_lifetime_safety_guarantee = sizeof...(Ts) < 0;
505+
506+
template <typename T>
507+
requires std::is_pointer_v<T>
508+
class safetychecked_pointer {
509+
T ptr;
510+
public:
511+
512+
constexpr safetychecked_pointer(T ptr) : ptr{ptr} {}
513+
514+
constexpr operator T&() noexcept { return ptr; }
515+
516+
template <typename... Ts> void operator+ () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
517+
template <typename... Ts> void operator- () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
518+
template <typename X> void operator+ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
519+
template <typename X> void operator- (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
520+
template <typename X> void operator* (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
521+
template <typename X> void operator/ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
522+
template <typename X> void operator% (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
523+
template <typename X> void operator^ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
524+
template <typename X> void operator& (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
525+
template <typename X> void operator| (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
526+
527+
template <typename... Ts> void operator++ (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
528+
template <typename... Ts> void operator-- (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
529+
template <typename... Ts> void operator[] (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
530+
template <typename X> void operator+= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
531+
template <typename X> void operator-= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
532+
template <typename X> void operator*= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
533+
template <typename X> void operator/= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
534+
535+
template <typename... Ts> void operator~ () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
536+
template <typename X > void operator%= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
537+
template <typename X > void operator^= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
538+
template <typename X > void operator&= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
539+
template <typename X > void operator|= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
540+
template <typename X > void operator<<=(X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
541+
template <typename X > void operator>>=(X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
542+
template <typename X > void operator<< (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
543+
template <typename X > void operator>> (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
544+
545+
template <typename X > friend void operator+ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
546+
template <typename X > friend void operator- (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
547+
template <typename X > friend void operator* (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
548+
template <typename X > friend void operator/ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
549+
template <typename X > friend void operator% (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
550+
template <typename X > friend void operator^ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
551+
template <typename X > friend void operator& (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
552+
template <typename X > friend void operator| (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
553+
554+
555+
template <typename X>
556+
requires (std::is_same_v<T,X> || std::is_base_of_v<T, X>)
557+
constexpr safetychecked_pointer& operator=(X lhs) noexcept {
558+
ptr = lhs;
559+
return *this;
560+
}
561+
562+
template <typename X>
563+
requires std::is_same_v<std::nullptr_t,X>
564+
constexpr void operator=(X lhs) noexcept { static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer assignment from null is illegal"); }
565+
566+
template <typename X>
567+
requires std::is_integral_v<X>
568+
constexpr void operator=(X lhs) noexcept { static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer assignment from integer is illegal"); }
569+
570+
bool operator!() const { return !ptr; }
571+
572+
constexpr safetychecked_pointer<T*> operator&() noexcept { return &ptr; }
573+
574+
constexpr auto operator*() noexcept {
575+
if constexpr (std::is_pointer_v<CPP2_TYPEOF(*ptr)>) {
576+
return safetychecked_pointer<CPP2_TYPEOF(*ptr)>(*ptr);
577+
} else {
578+
return *ptr;
579+
}
580+
}
581+
582+
constexpr T operator->() const noexcept { return ptr; }
583+
};
584+
585+
template <typename X>
586+
requires ( !std::is_pointer_v<std::remove_cvref_t<X>>
587+
&& std::is_copy_constructible_v<X> )
588+
decltype(auto) safety_check(const X& x) {
589+
return x;
590+
}
591+
592+
template <typename X>
593+
requires std::is_rvalue_reference_v<X>
594+
decltype(auto) safety_check(X&& x) {
595+
return x;
596+
}
597+
598+
template <typename X>
599+
requires std::is_pointer_v<std::remove_cvref_t<X>>
600+
auto safety_check(X x) {
601+
return safetychecked_pointer(x);
602+
}
603+
604+
template <typename X>
605+
requires (!std::is_pointer_v<std::remove_cvref_t<X>> && !std::is_function_v<X>)
606+
auto& safety_check(X& x) {
607+
return x;
608+
}
609+
610+
template <typename X>
611+
requires (!std::is_copy_constructible_v<X>)
612+
auto safety_check(X&& x) {
613+
return std::forward<X>(x);
614+
}
615+
616+
template <typename X, std::size_t N>
617+
decltype(auto) safety_check(X (&x)[N]) {
618+
static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");
619+
}
497620

498621
//-----------------------------------------------------------------------
499622
//

source/cppfront.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ class cppfront
621621
bool violates_bounds_safety = false;
622622
bool violates_initialization_safety = false;
623623
bool suppress_move_from_last_use = false;
624+
bool needs_safetycheck = false;
624625

625626
// For lowering
626627
//
@@ -1579,6 +1580,7 @@ class cppfront
15791580
{
15801581
assert(n.expr);
15811582
last_postfix_expr_was_pointer = false;
1583+
bool add_safetycheck = false;
15821584

15831585
// Check that this isn't pointer arithmentic
15841586
// (initial partial implementation)
@@ -1590,9 +1592,10 @@ class cppfront
15901592
{
15911593
auto& unqual = std::get<id_expression_node::unqualified>(id->id);
15921594
assert(unqual);
1593-
auto decl = sema.get_declaration_of(*unqual->identifier);
1594-
// TODO: Generalize this -- for now we detect only multi-level cases of the form "p: ***int = ...;"
1595-
// We don't recognize pointer types from Cpp1
1595+
auto decl = sema.get_declaration_of(*unqual->identifier, true);
1596+
1597+
// if initialized by something suspicious (that we have no information about) we need to add cpp1 safety checks
1598+
add_safetycheck = !decl && needs_safetycheck;
15961599
if (is_it_pointer_declaration(decl)) {
15971600
if (n.ops.empty()) {
15981601
last_postfix_expr_was_pointer = true;
@@ -1620,6 +1623,16 @@ class cppfront
16201623
}
16211624
}
16221625

1626+
std::shared_ptr<void> _on_return;
1627+
1628+
if (add_safetycheck) {
1629+
needs_safetycheck = false;
1630+
printer.print_cpp2("cpp2::safety_check(", n.position());
1631+
_on_return = [](auto l) { return std::shared_ptr<void>(nullptr, l); }([&](auto){
1632+
printer.print_cpp2(")", n.position());
1633+
});
1634+
}
1635+
16231636
// Simple case: If there are no .ops, just emit the expression
16241637
if (n.ops.empty()) {
16251638
emit(*n.expr);
@@ -2578,6 +2591,7 @@ class cppfront
25782591

25792592
push_need_expression_list_parens(false);
25802593
assert( n.initializer );
2594+
needs_safetycheck = n.initializer->suspicious_initialization;
25812595
emit( *n.initializer, false );
25822596
pop_need_expression_list_parens();
25832597

source/parse.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,7 @@ struct statement_node
692692
{
693693
token const* let;
694694
std::unique_ptr<parameter_declaration_list_node> let_params;
695+
token const* suspicious_initialization = nullptr;
695696

696697
enum active { expression=0, compound, selection, declaration, return_, iteration, contract, inspect };
697698
std::variant<
@@ -2883,6 +2884,8 @@ class parser
28832884
}
28842885
}
28852886

2887+
token const* suspicious_initialization = nullptr;
2888+
28862889
if (deduced_type) {
28872890
if (peek(1)->type() == lexeme::Ampersand) {
28882891
n->address_of = &curr();
@@ -2892,13 +2895,18 @@ class parser
28922895
while(peek(n->dereference_cnt+1)->type() == lexeme::Multiply) {
28932896
n->dereference_cnt += 1;
28942897
}
2898+
} else if ((peek(1)->type() == lexeme::LeftParen && curr().type() != lexeme::Colon)
2899+
|| curr().type() == lexeme::Identifier ) {
2900+
suspicious_initialization = &curr();
28952901
}
28962902
}
28972903

28982904
if (!(n->initializer = statement(semicolon_required, n->equal_sign))) {
28992905
error("ill-formed initializer");
29002906
next();
29012907
return {};
2908+
} else {
2909+
n->initializer->suspicious_initialization = suspicious_initialization;
29022910
}
29032911
}
29042912

source/sema.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ class sema
227227
}
228228

229229

230-
auto get_declaration_of(token const& t) -> declaration_sym const*
230+
auto get_declaration_of(token const& t, bool look_beyond_current_function = false) -> declaration_sym const*
231231
{
232232
// First find the position the query is coming from
233233
// and remember its depth
@@ -250,9 +250,8 @@ class sema
250250
{
251251
auto const& decl = std::get<symbol::active::declaration>(i->sym);
252252

253-
// Don't look beyond the current function
254253
assert(decl.declaration);
255-
if (decl.declaration->type.index() == declaration_node::function) {
254+
if (!look_beyond_current_function && decl.declaration->type.index() == declaration_node::function) {
256255
return nullptr;
257256
}
258257

0 commit comments

Comments
 (0)