Skip to content

Commit 6b8ba4e

Browse files
committed
Added middleware support
1 parent ddf41d2 commit 6b8ba4e

File tree

2 files changed

+105
-9
lines changed

2 files changed

+105
-9
lines changed

httplib.h

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ inline void default_socket_options(socket_t sock) {
597597
class Server {
598598
public:
599599
using Handler = std::function<void(const Request &, Response &)>;
600+
using HandlerWithReturn = std::function<bool(const Request &, Response &)>;
600601
using HandlerWithContentReader = std::function<void(
601602
const Request &, Response &, const ContentReader &content_reader)>;
602603
using Expect100ContinueHandler =
@@ -627,7 +628,11 @@ class Server {
627628
const char *mime);
628629
void set_file_request_handler(Handler handler);
629630

631+
void set_error_handler(HandlerWithReturn handler);
630632
void set_error_handler(Handler handler);
633+
void set_pre_routing_handler(HandlerWithReturn handler);
634+
void set_post_routing_handler(Handler handler);
635+
631636
void set_expect_100_continue_handler(Expect100ContinueHandler handler);
632637
void set_logger(Logger logger);
633638

@@ -734,7 +739,9 @@ class Server {
734739
Handlers delete_handlers_;
735740
HandlersForContentReader delete_handlers_for_content_reader_;
736741
Handlers options_handlers_;
737-
Handler error_handler_;
742+
HandlerWithReturn error_handler_;
743+
HandlerWithReturn pre_routing_handler_;
744+
Handler post_routing_handler_;
738745
Logger logger_;
739746
Expect100ContinueHandler expect_100_continue_handler_;
740747

@@ -4160,14 +4167,23 @@ inline void Server::set_file_request_handler(Handler handler) {
41604167
file_request_handler_ = std::move(handler);
41614168
}
41624169

4163-
inline void Server::set_error_handler(Handler handler) {
4170+
inline void Server::set_error_handler(HandlerWithReturn handler) {
41644171
error_handler_ = std::move(handler);
41654172
}
41664173

4167-
inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }
4174+
inline void Server::set_error_handler(Handler handler) {
4175+
error_handler_ = [handler](const Request &req, Response &res) {
4176+
handler(req, res);
4177+
return true;
4178+
};
4179+
}
41684180

4169-
inline void Server::set_socket_options(SocketOptions socket_options) {
4170-
socket_options_ = std::move(socket_options);
4181+
inline void Server::set_pre_routing_handler(HandlerWithReturn handler) {
4182+
pre_routing_handler_ = std::move(handler);
4183+
}
4184+
4185+
inline void Server::set_post_routing_handler(Handler handler) {
4186+
post_routing_handler_ = std::move(handler);
41714187
}
41724188

41734189
inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); }
@@ -4177,6 +4193,12 @@ Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) {
41774193
expect_100_continue_handler_ = std::move(handler);
41784194
}
41794195

4196+
inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }
4197+
4198+
inline void Server::set_socket_options(SocketOptions socket_options) {
4199+
socket_options_ = std::move(socket_options);
4200+
}
4201+
41804202
inline void Server::set_keep_alive_max_count(size_t count) {
41814203
keep_alive_max_count_ = count;
41824204
}
@@ -4268,16 +4290,15 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
42684290
bool need_apply_ranges) {
42694291
assert(res.status != -1);
42704292

4271-
if (400 <= res.status && error_handler_) {
4272-
error_handler_(req, res);
4293+
if (400 <= res.status && error_handler_ && error_handler_(req, res)) {
42734294
need_apply_ranges = true;
42744295
}
42754296

42764297
std::string content_type;
42774298
std::string boundary;
42784299
if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
42794300

4280-
// Preapre additional headers
4301+
// Prepare additional headers
42814302
if (close_connection || req.get_header_value("Connection") == "close") {
42824303
res.set_header("Connection", "close");
42834304
} else {
@@ -4301,6 +4322,8 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
43014322
res.set_header("Accept-Ranges", "bytes");
43024323
}
43034324

4325+
if (post_routing_handler_) { post_routing_handler_(req, res); }
4326+
43044327
// Response line and headers
43054328
{
43064329
detail::BufferStream bstrm;
@@ -4604,6 +4627,8 @@ inline bool Server::listen_internal() {
46044627
}
46054628

46064629
inline bool Server::routing(Request &req, Response &res, Stream &strm) {
4630+
if (pre_routing_handler_ && pre_routing_handler_(req, res)) { return true; }
4631+
46074632
// File handler
46084633
bool is_head_request = req.method == "HEAD";
46094634
if ((req.method == "GET" || is_head_request) &&
@@ -5302,7 +5327,7 @@ inline bool ClientImpl::write_content_with_provider(Stream &strm,
53025327

53035328
inline bool ClientImpl::write_request(Stream &strm, const Request &req,
53045329
bool close_connection, Error &error) {
5305-
// Prepare additonal headers
5330+
// Prepare additional headers
53065331
Headers headers;
53075332
if (close_connection) { headers.emplace("Connection", "close"); }
53085333

test/test.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,77 @@ TEST(ErrorHandlerTest, ContentLength) {
953953
ASSERT_FALSE(svr.is_running());
954954
}
955955

956+
TEST(RoutingHandlerTest, PreRoutingHandler) {
957+
Server svr;
958+
959+
svr.set_pre_routing_handler([](const Request &req, Response &res) {
960+
if (req.path == "/routing_handler") {
961+
res.set_header("PRE_ROUTING", "on");
962+
res.set_content("Routing Handler", "text/plain");
963+
return true;
964+
}
965+
return false;
966+
});
967+
968+
svr.set_error_handler([](const Request & /*req*/, Response &res) {
969+
res.set_content("Error", "text/html");
970+
});
971+
972+
svr.set_post_routing_handler([](const Request &req, Response &res) {
973+
if (req.path == "/routing_handler") {
974+
res.set_header("POST_ROUTING", "on");
975+
}
976+
});
977+
978+
svr.Get("/hi", [](const Request & /*req*/, Response &res) {
979+
res.set_content("Hello World!\n", "text/plain");
980+
});
981+
982+
auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
983+
984+
// Give GET time to get a few messages.
985+
std::this_thread::sleep_for(std::chrono::seconds(1));
986+
987+
{
988+
Client cli(HOST, PORT);
989+
990+
auto res = cli.Get("/routing_handler");
991+
ASSERT_TRUE(res);
992+
EXPECT_EQ(200, res->status);
993+
EXPECT_EQ("Routing Handler", res->body);
994+
EXPECT_EQ(1, res->get_header_value_count("PRE_ROUTING"));
995+
EXPECT_EQ("on", res->get_header_value("PRE_ROUTING"));
996+
EXPECT_EQ(1, res->get_header_value_count("POST_ROUTING"));
997+
EXPECT_EQ("on", res->get_header_value("POST_ROUTING"));
998+
}
999+
1000+
{
1001+
Client cli(HOST, PORT);
1002+
1003+
auto res = cli.Get("/hi");
1004+
ASSERT_TRUE(res);
1005+
EXPECT_EQ(200, res->status);
1006+
EXPECT_EQ("Hello World!\n", res->body);
1007+
EXPECT_EQ(0, res->get_header_value_count("PRE_ROUTING"));
1008+
EXPECT_EQ(0, res->get_header_value_count("POST_ROUTING"));
1009+
}
1010+
1011+
{
1012+
Client cli(HOST, PORT);
1013+
1014+
auto res = cli.Get("/aaa");
1015+
ASSERT_TRUE(res);
1016+
EXPECT_EQ(404, res->status);
1017+
EXPECT_EQ("Error", res->body);
1018+
EXPECT_EQ(0, res->get_header_value_count("PRE_ROUTING"));
1019+
EXPECT_EQ(0, res->get_header_value_count("POST_ROUTING"));
1020+
}
1021+
1022+
svr.stop();
1023+
thread.join();
1024+
ASSERT_FALSE(svr.is_running());
1025+
}
1026+
9561027
TEST(InvalidFormatTest, StatusCode) {
9571028
Server svr;
9581029

0 commit comments

Comments
 (0)