Skip to content

Middleware support #816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions httplib.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ inline void default_socket_options(socket_t sock) {
class Server {
public:
using Handler = std::function<void(const Request &, Response &)>;
using HandlerWithReturn = std::function<bool(const Request &, Response &)>;
using HandlerWithContentReader = std::function<void(
const Request &, Response &, const ContentReader &content_reader)>;
using Expect100ContinueHandler =
Expand Down Expand Up @@ -627,7 +628,11 @@ class Server {
const char *mime);
void set_file_request_handler(Handler handler);

void set_error_handler(HandlerWithReturn handler);
void set_error_handler(Handler handler);
void set_pre_routing_handler(HandlerWithReturn handler);
void set_post_routing_handler(Handler handler);

void set_expect_100_continue_handler(Expect100ContinueHandler handler);
void set_logger(Logger logger);

Expand Down Expand Up @@ -734,7 +739,9 @@ class Server {
Handlers delete_handlers_;
HandlersForContentReader delete_handlers_for_content_reader_;
Handlers options_handlers_;
Handler error_handler_;
HandlerWithReturn error_handler_;
HandlerWithReturn pre_routing_handler_;
Handler post_routing_handler_;
Logger logger_;
Expect100ContinueHandler expect_100_continue_handler_;

Expand Down Expand Up @@ -4160,14 +4167,23 @@ inline void Server::set_file_request_handler(Handler handler) {
file_request_handler_ = std::move(handler);
}

inline void Server::set_error_handler(Handler handler) {
inline void Server::set_error_handler(HandlerWithReturn handler) {
error_handler_ = std::move(handler);
}

inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }
inline void Server::set_error_handler(Handler handler) {
error_handler_ = [handler](const Request &req, Response &res) {
handler(req, res);
return true;
};
}

inline void Server::set_socket_options(SocketOptions socket_options) {
socket_options_ = std::move(socket_options);
inline void Server::set_pre_routing_handler(HandlerWithReturn handler) {
pre_routing_handler_ = std::move(handler);
}

inline void Server::set_post_routing_handler(Handler handler) {
post_routing_handler_ = std::move(handler);
}

inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); }
Expand All @@ -4177,6 +4193,12 @@ Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) {
expect_100_continue_handler_ = std::move(handler);
}

inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }

inline void Server::set_socket_options(SocketOptions socket_options) {
socket_options_ = std::move(socket_options);
}

inline void Server::set_keep_alive_max_count(size_t count) {
keep_alive_max_count_ = count;
}
Expand Down Expand Up @@ -4268,16 +4290,15 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
bool need_apply_ranges) {
assert(res.status != -1);

if (400 <= res.status && error_handler_) {
error_handler_(req, res);
if (400 <= res.status && error_handler_ && error_handler_(req, res)) {
need_apply_ranges = true;
}

std::string content_type;
std::string boundary;
if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }

// Preapre additional headers
// Prepare additional headers
if (close_connection || req.get_header_value("Connection") == "close") {
res.set_header("Connection", "close");
} else {
Expand All @@ -4301,6 +4322,8 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
res.set_header("Accept-Ranges", "bytes");
}

if (post_routing_handler_) { post_routing_handler_(req, res); }

// Response line and headers
{
detail::BufferStream bstrm;
Expand Down Expand Up @@ -4604,6 +4627,8 @@ inline bool Server::listen_internal() {
}

inline bool Server::routing(Request &req, Response &res, Stream &strm) {
if (pre_routing_handler_ && pre_routing_handler_(req, res)) { return true; }

// File handler
bool is_head_request = req.method == "HEAD";
if ((req.method == "GET" || is_head_request) &&
Expand Down Expand Up @@ -5302,7 +5327,7 @@ inline bool ClientImpl::write_content_with_provider(Stream &strm,

inline bool ClientImpl::write_request(Stream &strm, const Request &req,
bool close_connection, Error &error) {
// Prepare additonal headers
// Prepare additional headers
Headers headers;
if (close_connection) { headers.emplace("Connection", "close"); }

Expand Down
71 changes: 71 additions & 0 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,77 @@ TEST(ErrorHandlerTest, ContentLength) {
ASSERT_FALSE(svr.is_running());
}

TEST(RoutingHandlerTest, PreRoutingHandler) {
Server svr;

svr.set_pre_routing_handler([](const Request &req, Response &res) {
if (req.path == "/routing_handler") {
res.set_header("PRE_ROUTING", "on");
res.set_content("Routing Handler", "text/plain");
return true;
}
return false;
});

svr.set_error_handler([](const Request & /*req*/, Response &res) {
res.set_content("Error", "text/html");
});

svr.set_post_routing_handler([](const Request &req, Response &res) {
if (req.path == "/routing_handler") {
res.set_header("POST_ROUTING", "on");
}
});

svr.Get("/hi", [](const Request & /*req*/, Response &res) {
res.set_content("Hello World!\n", "text/plain");
});

auto thread = std::thread([&]() { svr.listen(HOST, PORT); });

// Give GET time to get a few messages.
std::this_thread::sleep_for(std::chrono::seconds(1));

{
Client cli(HOST, PORT);

auto res = cli.Get("/routing_handler");
ASSERT_TRUE(res);
EXPECT_EQ(200, res->status);
EXPECT_EQ("Routing Handler", res->body);
EXPECT_EQ(1, res->get_header_value_count("PRE_ROUTING"));
EXPECT_EQ("on", res->get_header_value("PRE_ROUTING"));
EXPECT_EQ(1, res->get_header_value_count("POST_ROUTING"));
EXPECT_EQ("on", res->get_header_value("POST_ROUTING"));
}

{
Client cli(HOST, PORT);

auto res = cli.Get("/hi");
ASSERT_TRUE(res);
EXPECT_EQ(200, res->status);
EXPECT_EQ("Hello World!\n", res->body);
EXPECT_EQ(0, res->get_header_value_count("PRE_ROUTING"));
EXPECT_EQ(0, res->get_header_value_count("POST_ROUTING"));
}

{
Client cli(HOST, PORT);

auto res = cli.Get("/aaa");
ASSERT_TRUE(res);
EXPECT_EQ(404, res->status);
EXPECT_EQ("Error", res->body);
EXPECT_EQ(0, res->get_header_value_count("PRE_ROUTING"));
EXPECT_EQ(0, res->get_header_value_count("POST_ROUTING"));
}

svr.stop();
thread.join();
ASSERT_FALSE(svr.is_running());
}

TEST(InvalidFormatTest, StatusCode) {
Server svr;

Expand Down