Skip to content

Commit 020e781

Browse files
committed
Add support for safety checks for pointers from cpp1
1 parent e1b6443 commit 020e781

File tree

6 files changed

+169
-15
lines changed

6 files changed

+169
-15
lines changed

include/cpp2util.h

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@
203203
#include <cstddef>
204204
#include <utility>
205205
#include <cstdio>
206+
#include <span>
206207

207208
#if defined(CPP2_USE_SOURCE_LOCATION)
208209
#include <source_location>
@@ -498,6 +499,135 @@ class out {
498499
}(PARAM1)
499500
//--------------------------------------------------------------------
500501

502+
//-----------------------------------------------------------------------
503+
//
504+
// cpp2::safety_check() ensures that cpp1 pointers are also covered by safetychecks
505+
//
506+
//-----------------------------------------------------------------------
507+
//
508+
template <typename... Ts>
509+
inline constexpr auto program_violates_lifetime_safety_guarantee = sizeof...(Ts) < 0;
510+
511+
template <typename T>
512+
requires std::is_pointer_v<T>
513+
class safetychecked_pointer {
514+
T ptr;
515+
public:
516+
517+
constexpr safetychecked_pointer(T ptr) : ptr{ptr} {}
518+
519+
constexpr operator T&() noexcept { return ptr; }
520+
521+
template <typename... Ts> void operator+ () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
522+
template <typename... Ts> void operator- () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
523+
template <typename X> void operator+ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
524+
template <typename X> void operator- (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
525+
template <typename X> void operator* (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
526+
template <typename X> void operator/ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
527+
template <typename X> void operator% (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
528+
template <typename X> void operator^ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
529+
template <typename X> void operator& (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
530+
template <typename X> void operator| (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
531+
532+
template <typename... Ts> void operator++ (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
533+
template <typename... Ts> void operator-- (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
534+
template <typename... Ts> void operator[] (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
535+
template <typename X> void operator+= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
536+
template <typename X> void operator-= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
537+
template <typename X> void operator*= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
538+
template <typename X> void operator/= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
539+
540+
template <typename... Ts> void operator~ () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
541+
template <typename X > void operator%= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
542+
template <typename X > void operator^= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
543+
template <typename X > void operator&= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
544+
template <typename X > void operator|= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
545+
template <typename X > void operator<<=(X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
546+
template <typename X > void operator>>=(X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
547+
template <typename X > void operator<< (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
548+
template <typename X > void operator>> (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
549+
550+
template <typename X > friend void operator+ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
551+
template <typename X > friend void operator- (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
552+
template <typename X > friend void operator* (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
553+
template <typename X > friend void operator/ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
554+
template <typename X > friend void operator% (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
555+
template <typename X > friend void operator^ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
556+
template <typename X > friend void operator& (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
557+
template <typename X > friend void operator| (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
558+
559+
560+
template <typename X>
561+
requires (std::is_same_v<T,X> || std::is_base_of_v<T, X>)
562+
constexpr safetychecked_pointer& operator=(X lhs) noexcept {
563+
ptr = lhs;
564+
return *this;
565+
}
566+
567+
template <typename X>
568+
requires std::is_same_v<std::nullptr_t,X>
569+
constexpr void operator=(X lhs) noexcept { static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer assignment from null is illegal"); }
570+
571+
template <typename X>
572+
requires std::is_integral_v<X>
573+
constexpr void operator=(X lhs) noexcept { static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer assignment from integer is illegal"); }
574+
575+
bool operator!() const { return !ptr; }
576+
577+
constexpr safetychecked_pointer<T*> operator&() noexcept { return &ptr; }
578+
579+
constexpr auto operator*() noexcept {
580+
if constexpr (std::is_pointer_v<CPP2_TYPEOF(*ptr)>) {
581+
return safetychecked_pointer<CPP2_TYPEOF(*ptr)>(*ptr);
582+
} else {
583+
return *ptr;
584+
}
585+
}
586+
587+
constexpr T operator->() const noexcept { return ptr; }
588+
};
589+
590+
template <typename X>
591+
requires ( !std::is_pointer_v<std::remove_cvref_t<X>>
592+
&& std::is_copy_constructible_v<X> )
593+
inline constexpr auto safety_check(X const& x) {
594+
return x;
595+
}
596+
597+
template <typename R, typename... Args>
598+
inline constexpr auto safety_check(R (&x)(Args...)) {
599+
return x;
600+
}
601+
602+
template <typename X>
603+
requires std::is_rvalue_reference_v<X>
604+
inline constexpr decltype(auto) safety_check(X&& x) {
605+
return std::forward<X>(x);
606+
}
607+
608+
template <typename X>
609+
requires (std::is_pointer_v<std::remove_cvref_t<X>> && !std::is_bounded_array_v<X>)
610+
inline constexpr auto safety_check(X const& x) {
611+
return safetychecked_pointer(x);
612+
}
613+
614+
template <typename X>
615+
requires (!std::is_pointer_v<std::remove_cvref_t<X>> && !std::is_function_v<X> && !std::is_bounded_array_v<X>)
616+
inline constexpr auto& safety_check(X& x) {
617+
return x;
618+
}
619+
620+
template <typename X>
621+
requires (!std::is_copy_constructible_v<X>)
622+
inline constexpr auto safety_check(X&& x) {
623+
return std::forward<X>(x);
624+
}
625+
626+
template <typename X>
627+
requires std::is_bounded_array_v<X>
628+
inline constexpr auto safety_check(X const& x) {
629+
return std::span(x);
630+
}
501631

502632
//-----------------------------------------------------------------------
503633
//

regression-tests/test-results/pure2-intro-example-hello-2022.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ auto println(auto const& x, auto const& len) -> void;
1919
std::span view { vec };
2020

2121
for ( auto&& cpp2_range = view; auto& str : cpp2_range ) {
22-
auto len { decorate(str) };
22+
auto len { cpp2::safety_check(decorate(str)) };
2323
println(str, len);
2424
}
2525
}

regression-tests/test-results/pure2-stdio.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
//
1717
[[nodiscard]] auto main() -> int{
1818
std::string s { "Fred" };
19-
auto myfile { fopen("xyzzy", "w") };
19+
auto myfile { cpp2::safety_check(fopen("xyzzy", "w")) };
2020
CPP2_UFCS(fprintf, myfile, "Hello %s with UFCS!", CPP2_UFCS_0(c_str, s));
2121
CPP2_UFCS_0(fclose, myfile);
2222
}

source/cppfront.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ class cppfront
621621
bool violates_bounds_safety = false;
622622
bool violates_initialization_safety = false;
623623
bool suppress_move_from_last_use = false;
624+
bool needs_safetycheck = false;
624625

625626
// For lowering
626627
//
@@ -1091,7 +1092,7 @@ class cppfront
10911092

10921093
in_definite_init = is_definite_initialization(n.identifier);
10931094
if (!in_definite_init && !in_parameter_list) {
1094-
if (auto decl = sema.get_local_declaration_of(*n.identifier);
1095+
if (auto decl = sema.get_declaration_of(*n.identifier);
10951096
is_local_name &&
10961097
decl &&
10971098
// note pointer equality: if we're not in the actual declaration of n.identifier
@@ -1564,10 +1565,10 @@ class cppfront
15641565
return true;
15651566

15661567
if (decl->declaration->dereference) {
1567-
auto deref = sema.get_local_declaration_of(*decl->declaration->dereference);
1568+
auto deref = sema.get_declaration_of(*decl->declaration->dereference);
15681569
return is_it_pointer_declaration(deref, deref_cnt+decl->declaration->dereference_cnt, addr_cnt);
15691570
} else if (decl->declaration->address_of) {
1570-
auto addr = sema.get_local_declaration_of(*decl->declaration->address_of);
1571+
auto addr = sema.get_declaration_of(*decl->declaration->address_of);
15711572
return is_it_pointer_declaration(addr, deref_cnt, addr_cnt+1);
15721573
}
15731574

@@ -1590,6 +1591,7 @@ class cppfront
15901591
{
15911592
assert(n.expr);
15921593
last_postfix_expr_was_pointer = false;
1594+
bool add_safetycheck = false;
15931595

15941596
// Check that this isn't pointer arithmentic
15951597
// (initial partial implementation)
@@ -1601,7 +1603,7 @@ class cppfront
16011603
{
16021604
auto& unqual = std::get<id_expression_node::unqualified>(id->id);
16031605
assert(unqual);
1604-
auto decl = sema.get_local_declaration_of(*unqual->identifier);
1606+
auto decl = sema.get_declaration_of(*unqual->identifier, true);
16051607

16061608
bool is_pointer = false;
16071609
if (decl && decl->declaration) {
@@ -1612,8 +1614,8 @@ class cppfront
16121614
}
16131615
}
16141616

1615-
// TODO: Generalize this -- for now we detect only multi-level cases of the form "p: ***int = ...;"
1616-
// We don't recognize pointer types that are deduced or from Cpp1
1617+
// if initialized by something suspicious (that we have no information about) we need to add cpp1 safety checks
1618+
add_safetycheck = !decl && needs_safetycheck;
16171619
if (is_it_pointer_declaration(decl) || !unqual->pointer_declarators.empty() || is_pointer) {
16181620
if (n.ops.empty()) {
16191621
last_postfix_expr_was_pointer = true;
@@ -1641,6 +1643,16 @@ class cppfront
16411643
}
16421644
}
16431645

1646+
std::shared_ptr<void> _on_return;
1647+
1648+
if (add_safetycheck) {
1649+
needs_safetycheck = false;
1650+
printer.print_cpp2("cpp2::safety_check(", n.position());
1651+
_on_return = [](auto l) { return std::shared_ptr<void>(nullptr, l); }([&](auto){
1652+
printer.print_cpp2(")", n.position());
1653+
});
1654+
}
1655+
16441656
// Simple case: If there are no .ops, just emit the expression
16451657
if (n.ops.empty()) {
16461658
emit(*n.expr);
@@ -2592,7 +2604,9 @@ class cppfront
25922604

25932605
push_need_expression_list_parens(false);
25942606
assert( n.initializer );
2607+
needs_safetycheck = n.initializer->suspicious_initialization;
25952608
emit( *n.initializer, false );
2609+
needs_safetycheck = false;
25962610
pop_need_expression_list_parens();
25972611

25982612
printer.print_cpp2( " }", n.position() );

source/parse.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ struct statement_node
693693
{
694694
token const* let;
695695
std::unique_ptr<parameter_declaration_list_node> let_params;
696+
token const* suspicious_initialization = nullptr;
696697

697698
enum active { expression=0, compound, selection, declaration, return_, iteration, contract, inspect };
698699
std::variant<
@@ -2880,6 +2881,8 @@ class parser
28802881
}
28812882
}
28822883

2884+
token const* suspicious_initialization = nullptr;
2885+
28832886
if (deduced_type) {
28842887
if (peek(1)->type() == lexeme::Ampersand) {
28852888
n->address_of = &curr();
@@ -2890,13 +2893,19 @@ class parser
28902893
while(peek(n->dereference_cnt+1)->type() == lexeme::Multiply) {
28912894
n->dereference_cnt += 1;
28922895
}
2896+
}
2897+
else if ((peek(1)->type() == lexeme::LeftParen && curr().type() != lexeme::Colon)
2898+
|| curr().type() == lexeme::Identifier ) {
2899+
suspicious_initialization = &curr();
28932900
}
28942901
}
28952902

28962903
if (!(n->initializer = statement(semicolon_required, n->equal_sign))) {
28972904
error("ill-formed initializer");
28982905
next();
28992906
return {};
2907+
} else {
2908+
n->initializer->suspicious_initialization = suspicious_initialization;
29002909
}
29012910
}
29022911

source/sema.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,8 @@ class sema
226226
{
227227
}
228228

229-
// Get the declaration of t within the same named function
230-
//
231-
auto get_local_declaration_of(token const& t) -> declaration_sym const*
229+
230+
auto get_declaration_of(token const& t, bool look_beyond_current_function = false) -> declaration_sym const*
232231
{
233232
// First find the position the query is coming from
234233
// and remember its depth
@@ -255,9 +254,11 @@ class sema
255254
// Don't look beyond the start of the current named (has identifier) function
256255
// (an unnamed function is ok to look beyond)
257256
assert(decl.declaration);
258-
if (decl.declaration->type.index() == declaration_node::function &&
259-
decl.declaration->identifier)
260-
{
257+
if (
258+
decl.declaration->type.index() == declaration_node::function
259+
&& decl.declaration->identifier
260+
&& !look_beyond_current_function
261+
) {
261262
return nullptr;
262263
}
263264

@@ -895,7 +896,7 @@ class sema
895896
{
896897
// Put this into the table if it's a use of an object in scope
897898
// or it's a 'copy' parameter
898-
if (auto decl = get_local_declaration_of(t);
899+
if (auto decl = get_declaration_of(t);
899900
decl
900901
)
901902
{

0 commit comments

Comments
 (0)