Skip to content

Commit 75bf2e6

Browse files
committed
Add support for safety checks for pointers from cpp1
1 parent 8cb1fd1 commit 75bf2e6

File tree

6 files changed

+155
-6
lines changed

6 files changed

+155
-6
lines changed

include/cpp2util.h

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@
199199
#include <cstddef>
200200
#include <utility>
201201
#include <cstdio>
202+
#include <span>
202203

203204
#if defined(CPP2_USE_SOURCE_LOCATION)
204205
#include <source_location>
@@ -494,6 +495,130 @@ class out {
494495
}(PARAM1)
495496
//--------------------------------------------------------------------
496497

498+
//-----------------------------------------------------------------------
499+
//
500+
// cpp2::safety_check() ensures that cpp1 pointers are also covered by safetychecks
501+
//
502+
//-----------------------------------------------------------------------
503+
//
504+
template <typename... Ts>
505+
inline constexpr auto program_violates_lifetime_safety_guarantee = sizeof...(Ts) < 0;
506+
507+
template <typename T>
508+
requires std::is_pointer_v<T>
509+
class safetychecked_pointer {
510+
T ptr;
511+
public:
512+
513+
constexpr safetychecked_pointer(T ptr) : ptr{ptr} {}
514+
515+
constexpr operator T&() noexcept { return ptr; }
516+
517+
template <typename... Ts> void operator+ () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
518+
template <typename... Ts> void operator- () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
519+
template <typename X> void operator+ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
520+
template <typename X> void operator- (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
521+
template <typename X> void operator* (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
522+
template <typename X> void operator/ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
523+
template <typename X> void operator% (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
524+
template <typename X> void operator^ (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
525+
template <typename X> void operator& (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
526+
template <typename X> void operator| (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
527+
528+
template <typename... Ts> void operator++ (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
529+
template <typename... Ts> void operator-- (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
530+
template <typename... Ts> void operator[] (Ts...) const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
531+
template <typename X> void operator+= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
532+
template <typename X> void operator-= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
533+
template <typename X> void operator*= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
534+
template <typename X> void operator/= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
535+
536+
template <typename... Ts> void operator~ () const {static_assert(program_violates_lifetime_safety_guarantee<Ts...>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
537+
template <typename X > void operator%= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
538+
template <typename X > void operator^= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
539+
template <typename X > void operator&= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
540+
template <typename X > void operator|= (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
541+
template <typename X > void operator<<=(X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
542+
template <typename X > void operator>>=(X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
543+
template <typename X > void operator<< (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
544+
template <typename X > void operator>> (X) const {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
545+
546+
template <typename X > friend void operator+ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
547+
template <typename X > friend void operator- (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
548+
template <typename X > friend void operator* (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
549+
template <typename X > friend void operator/ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer arithmetic is illegal - use std::span or gsl::span instead");}
550+
template <typename X > friend void operator% (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
551+
template <typename X > friend void operator^ (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
552+
template <typename X > friend void operator& (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
553+
template <typename X > friend void operator| (X, const safetychecked_pointer&) {static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer bitwise manipulation is illegal - use std::bit_cast to convert to raw bytes first");}
554+
555+
556+
template <typename X>
557+
requires (std::is_same_v<T,X> || std::is_base_of_v<T, X>)
558+
constexpr safetychecked_pointer& operator=(X lhs) noexcept {
559+
ptr = lhs;
560+
return *this;
561+
}
562+
563+
template <typename X>
564+
requires std::is_same_v<std::nullptr_t,X>
565+
constexpr void operator=(X lhs) noexcept { static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer assignment from null is illegal"); }
566+
567+
template <typename X>
568+
requires std::is_integral_v<X>
569+
constexpr void operator=(X lhs) noexcept { static_assert(program_violates_lifetime_safety_guarantee<X>, "pointer assignment from integer is illegal"); }
570+
571+
bool operator!() const { return !ptr; }
572+
573+
constexpr safetychecked_pointer<T*> operator&() noexcept { return &ptr; }
574+
575+
constexpr auto operator*() noexcept {
576+
if constexpr (std::is_pointer_v<CPP2_TYPEOF(*ptr)>) {
577+
return safetychecked_pointer<CPP2_TYPEOF(*ptr)>(*ptr);
578+
} else {
579+
return *ptr;
580+
}
581+
}
582+
583+
constexpr T operator->() const noexcept { return ptr; }
584+
};
585+
586+
template <typename X>
587+
requires ( !std::is_pointer_v<std::remove_cvref_t<X>>
588+
&& std::is_copy_constructible_v<X> )
589+
inline constexpr decltype(auto) safety_check(X const& x) {
590+
return x;
591+
}
592+
593+
template <typename X>
594+
requires std::is_rvalue_reference_v<X>
595+
inline constexpr decltype(auto) safety_check(X&& x) {
596+
return x;
597+
}
598+
599+
template <typename X>
600+
requires (std::is_pointer_v<std::remove_cvref_t<X>> && !std::is_bounded_array_v<X>)
601+
inline constexpr auto safety_check(X const& x) {
602+
return safetychecked_pointer(x);
603+
}
604+
605+
template <typename X>
606+
requires (!std::is_pointer_v<std::remove_cvref_t<X>> && !std::is_function_v<X> && !std::is_bounded_array_v<X>)
607+
inline constexpr auto& safety_check(X& x) {
608+
return x;
609+
}
610+
611+
template <typename X>
612+
requires (!std::is_copy_constructible_v<X>)
613+
inline constexpr auto safety_check(X&& x) {
614+
return std::forward<X>(x);
615+
}
616+
617+
template <typename X>
618+
requires std::is_bounded_array_v<X>
619+
inline constexpr auto safety_check(X const& x) {
620+
return std::span(x);
621+
}
497622

498623
//-----------------------------------------------------------------------
499624
//

regression-tests/test-results/pure2-intro-example-hello-2022.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ auto println(auto const& x, auto const& len) -> void;
1919
std::span view { vec };
2020

2121
for ( auto&& cpp2_range = view; auto& str : cpp2_range ) {
22-
auto len { decorate(str) };
22+
auto len { cpp2::safety_check(decorate(str)) };
2323
println(str, len);
2424
}
2525
}

regression-tests/test-results/pure2-stdio.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
//
1717
[[nodiscard]] auto main() -> int{
1818
std::string s { "Fred" };
19-
auto myfile { fopen("xyzzy", "w") };
19+
auto myfile { cpp2::safety_check(fopen("xyzzy", "w")) };
2020
CPP2_UFCS(fprintf, myfile, "Hello %s with UFCS!", CPP2_UFCS_0(c_str, s));
2121
CPP2_UFCS_0(fclose, myfile);
2222
}

source/cppfront.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ class cppfront
621621
bool violates_bounds_safety = false;
622622
bool violates_initialization_safety = false;
623623
bool suppress_move_from_last_use = false;
624+
bool needs_safetycheck = false;
624625

625626
// For lowering
626627
//
@@ -1590,6 +1591,7 @@ class cppfront
15901591
{
15911592
assert(n.expr);
15921593
last_postfix_expr_was_pointer = false;
1594+
bool add_safetycheck = false;
15931595

15941596
// Check that this isn't pointer arithmentic
15951597
// (initial partial implementation)
@@ -1601,7 +1603,7 @@ class cppfront
16011603
{
16021604
auto& unqual = std::get<id_expression_node::unqualified>(id->id);
16031605
assert(unqual);
1604-
auto decl = sema.get_declaration_of(*unqual->identifier);
1606+
auto decl = sema.get_declaration_of(*unqual->identifier, true);
16051607

16061608
bool is_pointer = false;
16071609
if (decl && decl->declaration) {
@@ -1612,8 +1614,8 @@ class cppfront
16121614
}
16131615
}
16141616

1615-
// TODO: Generalize this -- for now we detect only multi-level cases of the form "p: ***int = ...;"
1616-
// We don't recognize pointer types that are deduced or from Cpp1
1617+
// if initialized by something suspicious (that we have no information about) we need to add cpp1 safety checks
1618+
add_safetycheck = !decl && needs_safetycheck;
16171619
if (is_it_pointer_declaration(decl) || !unqual->pointer_declarators.empty() || is_pointer) {
16181620
if (n.ops.empty()) {
16191621
last_postfix_expr_was_pointer = true;
@@ -1641,6 +1643,16 @@ class cppfront
16411643
}
16421644
}
16431645

1646+
std::shared_ptr<void> _on_return;
1647+
1648+
if (add_safetycheck) {
1649+
needs_safetycheck = false;
1650+
printer.print_cpp2("cpp2::safety_check(", n.position());
1651+
_on_return = [](auto l) { return std::shared_ptr<void>(nullptr, l); }([&](auto){
1652+
printer.print_cpp2(")", n.position());
1653+
});
1654+
}
1655+
16441656
// Simple case: If there are no .ops, just emit the expression
16451657
if (n.ops.empty()) {
16461658
emit(*n.expr);
@@ -2592,7 +2604,9 @@ class cppfront
25922604

25932605
push_need_expression_list_parens(false);
25942606
assert( n.initializer );
2607+
needs_safetycheck = n.initializer->suspicious_initialization;
25952608
emit( *n.initializer, false );
2609+
needs_safetycheck = false;
25962610
pop_need_expression_list_parens();
25972611

25982612
printer.print_cpp2( " }", n.position() );

source/parse.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ struct statement_node
693693
{
694694
token const* let;
695695
std::unique_ptr<parameter_declaration_list_node> let_params;
696+
token const* suspicious_initialization = nullptr;
696697

697698
enum active { expression=0, compound, selection, declaration, return_, iteration, contract, inspect };
698699
std::variant<
@@ -2880,6 +2881,8 @@ class parser
28802881
}
28812882
}
28822883

2884+
token const* suspicious_initialization = nullptr;
2885+
28832886
if (deduced_type) {
28842887
if (peek(1)->type() == lexeme::Ampersand) {
28852888
n->address_of = &curr();
@@ -2890,13 +2893,19 @@ class parser
28902893
while(peek(n->dereference_cnt+1)->type() == lexeme::Multiply) {
28912894
n->dereference_cnt += 1;
28922895
}
2896+
}
2897+
else if ((peek(1)->type() == lexeme::LeftParen && curr().type() != lexeme::Colon)
2898+
|| curr().type() == lexeme::Identifier ) {
2899+
suspicious_initialization = &curr();
28932900
}
28942901
}
28952902

28962903
if (!(n->initializer = statement(semicolon_required, n->equal_sign))) {
28972904
error("ill-formed initializer");
28982905
next();
28992906
return {};
2907+
} else {
2908+
n->initializer->suspicious_initialization = suspicious_initialization;
29002909
}
29012910
}
29022911

source/sema.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ class sema
227227
}
228228

229229

230-
auto get_declaration_of(token const& t) -> declaration_sym const*
230+
auto get_declaration_of(token const& t, bool look_beyond_current_function = false) -> declaration_sym const*
231231
{
232232
// First find the position the query is coming from
233233
// and remember its depth
@@ -253,6 +253,7 @@ class sema
253253
assert(decl.declaration);
254254
if (
255255
decl.declaration->type.index() == declaration_node::function // Don't look beyond the current function
256+
&& !look_beyond_current_function
256257
) {
257258
return nullptr;
258259
}

0 commit comments

Comments
 (0)