Skip to content

Support LOCAL_ADDR and LOCAL_PORT header in client Request #1450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions httplib.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,8 @@ struct Request {

std::string remote_addr;
int remote_port = -1;
std::string local_addr;
int local_port = -1;

// for server
std::string version;
Expand Down Expand Up @@ -514,6 +516,7 @@ class Stream {
virtual ssize_t read(char *ptr, size_t size) = 0;
virtual ssize_t write(const char *ptr, size_t size) = 0;
virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0;
virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0;
virtual socket_t socket() const = 0;

template <typename... Args>
Expand Down Expand Up @@ -1778,6 +1781,7 @@ class BufferStream : public Stream {
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
void get_local_ip_and_port(std::string &ip, int &port) const override;
socket_t socket() const override;

const std::string &get_buffer() const;
Expand Down Expand Up @@ -2446,6 +2450,7 @@ class SocketStream : public Stream {
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
void get_local_ip_and_port(std::string &ip, int &port) const override;
socket_t socket() const override;

private:
Expand Down Expand Up @@ -2475,6 +2480,7 @@ class SSLSocketStream : public Stream {
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
void get_local_ip_and_port(std::string &ip, int &port) const override;
socket_t socket() const override;

private:
Expand Down Expand Up @@ -2843,9 +2849,9 @@ inline socket_t create_client_socket(
return sock;
}

inline bool get_remote_ip_and_port(const struct sockaddr_storage &addr,
socklen_t addr_len, std::string &ip,
int &port) {
inline bool get_ip_and_port(const struct sockaddr_storage &addr,
socklen_t addr_len, std::string &ip,
int &port) {
if (addr.ss_family == AF_INET) {
port = ntohs(reinterpret_cast<const struct sockaddr_in *>(&addr)->sin_port);
} else if (addr.ss_family == AF_INET6) {
Expand All @@ -2866,6 +2872,15 @@ inline bool get_remote_ip_and_port(const struct sockaddr_storage &addr,
return true;
}

inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) {
struct sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
if (!getsockname(sock, reinterpret_cast<struct sockaddr *>(&addr),
&addr_len)) {
get_ip_and_port(addr, addr_len, ip, port);
}
}

inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) {
struct sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
Expand All @@ -2890,7 +2905,7 @@ inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) {
return;
}
#endif
get_remote_ip_and_port(addr, addr_len, ip, port);
get_ip_and_port(addr, addr_len, ip, port);
}
}

Expand Down Expand Up @@ -4517,8 +4532,8 @@ inline void hosted_at(const std::string &hostname,
*reinterpret_cast<struct sockaddr_storage *>(rp->ai_addr);
std::string ip;
int dummy = -1;
if (detail::get_remote_ip_and_port(addr, sizeof(struct sockaddr_storage),
ip, dummy)) {
if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage),
ip, dummy)) {
addrs.push_back(ip);
}
}
Expand Down Expand Up @@ -4808,6 +4823,11 @@ inline void SocketStream::get_remote_ip_and_port(std::string &ip,
return detail::get_remote_ip_and_port(sock_, ip, port);
}

inline void SocketStream::get_local_ip_and_port(std::string &ip,
int &port) const {
return detail::get_local_ip_and_port(sock_, ip, port);
}

inline socket_t SocketStream::socket() const { return sock_; }

// Buffer stream implementation
Expand All @@ -4833,6 +4853,9 @@ inline ssize_t BufferStream::write(const char *ptr, size_t size) {
inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/,
int & /*port*/) const {}

inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/,
int & /*port*/) const {}

inline socket_t BufferStream::socket() const { return 0; }

inline const std::string &BufferStream::get_buffer() const { return buffer; }
Expand Down Expand Up @@ -5812,6 +5835,10 @@ Server::process_request(Stream &strm, bool close_connection,
req.set_header("REMOTE_ADDR", req.remote_addr);
req.set_header("REMOTE_PORT", std::to_string(req.remote_port));

strm.get_local_ip_and_port(req.local_addr, req.local_port);
req.set_header("LOCAL_ADDR", req.local_addr);
req.set_header("LOCAL_PORT", std::to_string(req.local_port));

if (req.has_header("Range")) {
const auto &range_header_value = req.get_header_value("Range");
if (!detail::parse_range_header(range_header_value, req.ranges)) {
Expand Down Expand Up @@ -7409,6 +7436,11 @@ inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip,
detail::get_remote_ip_and_port(sock_, ip, port);
}

inline void SSLSocketStream::get_local_ip_and_port(std::string &ip,
int &port) const {
detail::get_local_ip_and_port(sock_, ip, port);
}

inline socket_t SSLSocketStream::socket() const { return sock_; }

static SSLInit sslinit_;
Expand Down
7 changes: 5 additions & 2 deletions test/fuzzing/server_fuzzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ class FuzzedStream : public httplib::Stream {

ssize_t write(const std::string &s) { return write(s.data(), s.size()); }

std::string get_remote_addr() const { return ""; }

bool is_readable() const override { return true; }

bool is_writable() const override { return true; }
Expand All @@ -33,6 +31,11 @@ class FuzzedStream : public httplib::Stream {
port = 8080;
}

void get_local_ip_and_port(std::string &ip, int &port) const override {
ip = "127.0.0.1";
port = 8080;
}

socket_t socket() const override { return 0; }

private:
Expand Down
20 changes: 20 additions & 0 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,17 @@ class ServerTest : public ::testing::Test {
std::stoi(req.get_header_value("REMOTE_PORT")));
res.set_content(remote_addr.c_str(), "text/plain");
})
.Get("/local_addr",
[&](const Request &req, Response &res) {
EXPECT_TRUE(req.has_header("LOCAL_PORT"));
EXPECT_TRUE(req.has_header("LOCAL_ADDR"));
auto local_addr = req.get_header_value("LOCAL_ADDR");
auto local_port = req.get_header_value("LOCAL_PORT");
EXPECT_EQ(req.local_addr, local_addr);
EXPECT_EQ(req.local_port, std::stoi(local_port));
res.set_content(local_addr.append(":").append(local_port),
"text/plain");
})
.Get("/endwith%",
[&](const Request & /*req*/, Response &res) {
res.set_content("Hello World!", "text/plain");
Expand Down Expand Up @@ -2810,6 +2821,15 @@ TEST_F(ServerTest, GetMethodRemoteAddr) {
EXPECT_TRUE(res->body == "::1" || res->body == "127.0.0.1");
}

TEST_F(ServerTest, GetMethodLocalAddr) {
auto res = cli_.Get("/local_addr");
ASSERT_TRUE(res);
EXPECT_EQ(200, res->status);
EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
EXPECT_TRUE(res->body == std::string("::1:").append(to_string(PORT)) ||
res->body == std::string("127.0.0.1:").append(to_string(PORT)));
}

TEST_F(ServerTest, HTTPResponseSplitting) {
auto res = cli_.Get("/http_response_splitting");
ASSERT_TRUE(res);
Expand Down