Skip to content

Commit 49ce940

Browse files
committed
Add support for safety checks for pointers from cpp1
1 parent 9132abc commit 49ce940

File tree

4 files changed

+150
-6
lines changed

4 files changed

+150
-6
lines changed

include/cpp2util.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,129 @@ class out {
494494
}(PARAM1)
495495
//--------------------------------------------------------------------
496496

497+
//-----------------------------------------------------------------------
498+
//
499+
// cpp2::safety_check() ensures that cpp1 pointers are also covered by safetychecks
500+
//
501+
//-----------------------------------------------------------------------
502+
//
503+
template <typename... Ts>
504+
inline constexpr auto program_violates_lifetime_safety_guarantee = sizeof...(Ts) < 0;
505+
506+
template <typename T>
507+
requires std::is_pointer_v<T>
508+
class safetychecked_pointer {
509+
T ptr;
510+
public:
511+
512+
constexpr safetychecked_pointer(T ptr) : ptr{ptr} {}
513+
514+
constexpr operator T&() noexcept { return ptr; }
515+
516+
template <typename... Ts> void operator+ () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
517+
template <typename... Ts> void operator- () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
518+
template <typename X> void operator+ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
519+
template <typename X> void operator- (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
520+
template <typename X> void operator* (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
521+
template <typename X> void operator/ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
522+
template <typename X> void operator% (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
523+
template <typename X> void operator^ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
524+
template <typename X> void operator& (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
525+
template <typename X> void operator| (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
526+
527+
template <typename... Ts> void operator++ (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
528+
template <typename... Ts> void operator-- (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
529+
template <typename... Ts> void operator[] (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
530+
template <typename X> void operator+= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
531+
template <typename X> void operator-= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
532+
template <typename X> void operator*= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
533+
template <typename X> void operator/= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
534+
535+
template <typename... Ts> void operator~ () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
536+
template <typename X > void operator%= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
537+
template <typename X > void operator^= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
538+
template <typename X > void operator&= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
539+
template <typename X > void operator|= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
540+
template <typename X > void operator<<=(X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
541+
template <typename X > void operator>>=(X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
542+
template <typename X > void operator<< (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
543+
template <typename X > void operator>> (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
544+
545+
template <typename X > friend void operator+ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
546+
template <typename X > friend void operator- (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
547+
template <typename X > friend void operator* (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
548+
template <typename X > friend void operator/ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
549+
template <typename X > friend void operator% (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
550+
template <typename X > friend void operator^ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
551+
template <typename X > friend void operator& (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
552+
template <typename X > friend void operator| (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
553+
554+
555+
template <typename X>
556+
requires (std::is_same_v<T,X> || std::is_base_of_v<T, X>)
557+
constexpr safetychecked_pointer& operator=(X lhs) noexcept {
558+
ptr = lhs;
559+
return *this;
560+
}
561+
562+
template <typename X>
563+
requires std::is_same_v<std::nullptr_t,X>
564+
constexpr void operator=(X lhs) noexcept { static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer assignment from null is illegal"); }
565+
566+
template <typename X>
567+
requires std::is_integral_v<X>
568+
constexpr void operator=(X lhs) noexcept { static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer assignment from integer is illegal"); }
569+
570+
bool operator!() const { return !ptr; }
571+
572+
constexpr safetychecked_pointer<T*> operator&() noexcept { return &ptr; }
573+
574+
constexpr auto operator*() noexcept {
575+
if constexpr (std::is_pointer_v<CPP2_TYPEOF(*ptr)>) {
576+
return safetychecked_pointer<CPP2_TYPEOF(*ptr)>(*ptr);
577+
} else {
578+
return *ptr;
579+
}
580+
}
581+
582+
constexpr T operator->() const noexcept { return ptr; }
583+
};
584+
585+
template <typename X>
586+
requires ( !std::is_pointer_v<std::remove_cvref_t<X>>
587+
&& std::is_copy_constructible_v<X> )
588+
decltype(auto) safety_check(const X& x) {
589+
return x;
590+
}
591+
592+
template <typename X>
593+
requires std::is_rvalue_reference_v<X>
594+
decltype(auto) safety_check(X&& x) {
595+
return x;
596+
}
597+
598+
template <typename X>
599+
requires std::is_pointer_v<std::remove_cvref_t<X>>
600+
auto safety_check(X x) {
601+
return safetychecked_pointer(x);
602+
}
603+
604+
template <typename X>
605+
requires (!std::is_pointer_v<std::remove_cvref_t<X>> && !std::is_function_v<X>)
606+
auto& safety_check(X& x) {
607+
return x;
608+
}
609+
610+
template <typename X>
611+
requires (!std::is_copy_constructible_v<X>)
612+
auto safety_check(X&& x) {
613+
return std::forward<X>(x);
614+
}
615+
616+
template <typename X, std::size_t N>
617+
decltype(auto) safety_check(X (&x)[N]) {
618+
static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");
619+
}
497620

498621
//-----------------------------------------------------------------------
499622
//

source/cppfront.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ class cppfront
621621
bool violates_bounds_safety = false;
622622
bool violates_initialization_safety = false;
623623
bool suppress_move_from_last_use = false;
624+
bool needs_safetycheck = false;
624625

625626
// For lowering
626627
//
@@ -1590,6 +1591,7 @@ class cppfront
15901591
{
15911592
assert(n.expr);
15921593
last_postfix_expr_was_pointer = false;
1594+
bool add_safetycheck = false;
15931595

15941596
// Check that this isn't pointer arithmentic
15951597
// (initial partial implementation)
@@ -1601,7 +1603,7 @@ class cppfront
16011603
{
16021604
auto& unqual = std::get<id_expression_node::unqualified>(id->id);
16031605
assert(unqual);
1604-
auto decl = sema.get_declaration_of(*unqual->identifier);
1606+
auto decl = sema.get_declaration_of(*unqual->identifier, true);
16051607

16061608
bool is_pointer = false;
16071609
if (decl && decl->declaration) {
@@ -1612,8 +1614,8 @@ class cppfront
16121614
}
16131615
}
16141616

1615-
// TODO: Generalize this -- for now we detect only multi-level cases of the form "p: ***int = ...;"
1616-
// We don't recognize pointer types that are deduced or from Cpp1
1617+
// if initialized by something suspicious (that we have no information about) we need to add cpp1 safety checks
1618+
add_safetycheck = !decl && needs_safetycheck;
16171619
if (is_it_pointer_declaration(decl) || !unqual->pointer_declarators.empty() || is_pointer) {
16181620
if (n.ops.empty()) {
16191621
last_postfix_expr_was_pointer = true;
@@ -1641,6 +1643,16 @@ class cppfront
16411643
}
16421644
}
16431645

1646+
std::shared_ptr<void> _on_return;
1647+
1648+
if (add_safetycheck) {
1649+
needs_safetycheck = false;
1650+
printer.print_cpp2("cpp2::safety_check(", n.position());
1651+
_on_return = [](auto l) { return std::shared_ptr<void>(nullptr, l); }([&](auto){
1652+
printer.print_cpp2(")", n.position());
1653+
});
1654+
}
1655+
16441656
// Simple case: If there are no .ops, just emit the expression
16451657
if (n.ops.empty()) {
16461658
emit(*n.expr);
@@ -2592,7 +2604,9 @@ class cppfront
25922604

25932605
push_need_expression_list_parens(false);
25942606
assert( n.initializer );
2607+
needs_safetycheck = n.initializer->suspicious_initialization;
25952608
emit( *n.initializer, false );
2609+
needs_safetycheck = false;
25962610
pop_need_expression_list_parens();
25972611

25982612
printer.print_cpp2( " }", n.position() );

source/parse.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ struct statement_node
693693
{
694694
token const* let;
695695
std::unique_ptr<parameter_declaration_list_node> let_params;
696+
token const* suspicious_initialization = nullptr;
696697

697698
enum active { expression=0, compound, selection, declaration, return_, iteration, contract, inspect };
698699
std::variant<
@@ -2876,6 +2877,8 @@ class parser
28762877
}
28772878
}
28782879

2880+
token const* suspicious_initialization = nullptr;
2881+
28792882
if (deduced_type) {
28802883
if (peek(1)->type() == lexeme::Ampersand) {
28812884
n->address_of = &curr();
@@ -2885,13 +2888,18 @@ class parser
28852888
while(peek(n->dereference_cnt+1)->type() == lexeme::Multiply) {
28862889
n->dereference_cnt += 1;
28872890
}
2891+
} else if ((peek(1)->type() == lexeme::LeftParen && curr().type() != lexeme::Colon)
2892+
|| curr().type() == lexeme::Identifier ) {
2893+
suspicious_initialization = &curr();
28882894
}
28892895
}
28902896

28912897
if (!(n->initializer = statement(semicolon_required, n->equal_sign))) {
28922898
error("ill-formed initializer");
28932899
next();
28942900
return {};
2901+
} else {
2902+
n->initializer->suspicious_initialization = suspicious_initialization;
28952903
}
28962904
}
28972905

source/sema.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ class sema
227227
}
228228

229229

230-
auto get_declaration_of(token const& t) -> declaration_sym const*
230+
auto get_declaration_of(token const& t, bool look_beyond_current_function = false) -> declaration_sym const*
231231
{
232232
// First find the position the query is coming from
233233
// and remember its depth
@@ -250,9 +250,8 @@ class sema
250250
{
251251
auto const& decl = std::get<symbol::active::declaration>(i->sym);
252252

253-
// Don't look beyond the current function
254253
assert(decl.declaration);
255-
if (decl.declaration->type.index() == declaration_node::function) {
254+
if (!look_beyond_current_function && decl.declaration->type.index() == declaration_node::function) {
256255
return nullptr;
257256
}
258257

0 commit comments

Comments
 (0)