Skip to content

Simplify CORS checks and don't restrict host names. #7876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion ports/espressif/common-hal/socketpool/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ STATIC void socket_select_task(void *arg) {
}

assert(num_triggered > 0);
assert(!FD_ISSET(socket_change_fd, &excptfds));

// Notice event trigger
if (FD_ISSET(socket_change_fd, &readfds)) {
Expand Down
80 changes: 34 additions & 46 deletions supervisor/shared/web_workflow/web_workflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
* THE SOFTWARE.
*/

// Include strchrnul()
#define _GNU_SOURCE

#include <stdarg.h>
#include <string.h>

Expand Down Expand Up @@ -85,8 +88,8 @@ typedef struct {
char destination[256];
char header_key[64];
char header_value[256];
// We store the origin so we can reply back with it.
char origin[64];
char origin[64]; // We store the origin so we can reply back with it.
char host[64]; // We store the host to check against origin.
size_t content_length;
size_t offset;
uint64_t timestamp_ms;
Expand Down Expand Up @@ -454,49 +457,33 @@ static bool _endswith(const char *str, const char *suffix) {
return strcmp(str + (strlen(str) - strlen(suffix)), suffix) == 0;
}

const char *ok_hosts[] = {
"127.0.0.1",
"localhost",
};

static bool _origin_ok(const char *origin) {
const char *http = "http://";
const char http_scheme[] = "http://";
#define PREFIX_HTTP_LEN (sizeof(http_scheme) - 1)

// note: redirected requests send an Origin of "null" and will be caught by this
if (strncmp(origin, http, strlen(http)) != 0) {
return false;
}
// These are prefix checks up to : so that any port works.
// TODO: Support DHCP hostname in addition to MDNS.
const char *end;
#if CIRCUITPY_MDNS
if (!common_hal_mdns_server_deinited(&mdns)) {
const char *local = ".local";
const char *hostname = common_hal_mdns_server_get_hostname(&mdns);
end = origin + strlen(http) + strlen(hostname) + strlen(local);
if (strncmp(origin + strlen(http), hostname, strlen(hostname)) == 0 &&
strncmp(origin + strlen(http) + strlen(hostname), local, strlen(local)) == 0 &&
(end[0] == '\0' || end[0] == ':')) {
return true;
}
static bool _origin_ok(_request *request) {
// Origin may be 'null'
if (request->origin[0] == '\0') {
return true;
}
#endif

_update_encoded_ip();
end = origin + strlen(http) + strlen(_our_ip_encoded);
if (strncmp(origin + strlen(http), _our_ip_encoded, strlen(_our_ip_encoded)) == 0 &&
(end[0] == '\0' || end[0] == ':')) {
// Origin has http prefix?
if (strncmp(request->origin, http_scheme, PREFIX_HTTP_LEN) != 0) {
// Not HTTP scheme request - ok
request->origin[0] = '\0';
return true;
}

for (size_t i = 0; i < MP_ARRAY_SIZE(ok_hosts); i++) {
// Allows any port
end = origin + strlen(http) + strlen(ok_hosts[i]);
if (strncmp(origin + strlen(http), ok_hosts[i], strlen(ok_hosts[i])) == 0
&& (end[0] == '\0' || end[0] == ':')) {
// Host given?
if (request->host[0] != '\0') {
// OK if host and origin match (fqdn + port #)
if (strcmp(request->host, &request->origin[PREFIX_HTTP_LEN]) == 0) {
return true;
}
// DEBUG: OK if origin is 'localhost' (ignoring port #)
*strchrnul(&request->origin[PREFIX_HTTP_LEN], ':') = '\0';
if (strcmp(&request->origin[PREFIX_HTTP_LEN], "localhost") == 0) {
return true;
}
}
// Otherwise deny request
return false;
}

Expand All @@ -517,8 +504,8 @@ static void _cors_header(socketpool_socket_obj_t *socket, _request *request) {
_send_strs(socket,
"Access-Control-Allow-Credentials: true\r\n",
"Vary: Origin, Accept, Upgrade\r\n",
"Access-Control-Allow-Origin: *\r\n",
NULL);
"Access-Control-Allow-Origin: ",
(request->origin[0] == '\0') ? "*" : request->origin, "\r\n", NULL);
}

static void _reply_continue(socketpool_socket_obj_t *socket, _request *request) {
Expand Down Expand Up @@ -1086,11 +1073,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
#else
_reply_missing(socket, request);
#endif

// For now until CORS is sorted, allow always the origin requester.
// Note: caller knows who we are better than us. CORS is not security
// unless browser cooperates. Do not rely on mDNS or IP.
} else if (strlen(request->origin) > 0 && !_origin_ok(request->origin)) {
} else if (!_origin_ok(request)) {
_reply_forbidden(socket, request);
} else if (strncmp(request->path, "/fs/", 4) == 0) {
if (strcasecmp(request->method, "OPTIONS") == 0) {
Expand Down Expand Up @@ -1314,6 +1297,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
static void _reset_request(_request *request) {
request->state = STATE_METHOD;
request->origin[0] = '\0';
request->host[0] = '\0';
request->content_length = 0;
request->offset = 0;
request->timestamp_ms = 0;
Expand All @@ -1340,6 +1324,7 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request)
if (len == 0 || len == -MP_ENOTCONN) {
// Disconnect - clear 'in-progress'
_reset_request(request);
common_hal_socketpool_socket_close(socket);
}
break;
}
Expand Down Expand Up @@ -1421,14 +1406,17 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request)
request->redirect = strncmp(request->header_value, cp_local, strlen(cp_local)) == 0 &&
(strlen(request->header_value) == strlen(cp_local) ||
request->header_value[strlen(cp_local)] == ':');
strncpy(request->host, request->header_value, sizeof(request->host) - 1);
request->host[sizeof(request->host) - 1] = '\0';
} else if (strcasecmp(request->header_key, "Content-Length") == 0) {
request->content_length = strtoul(request->header_value, NULL, 10);
} else if (strcasecmp(request->header_key, "Expect") == 0) {
request->expect = strcmp(request->header_value, "100-continue") == 0;
} else if (strcasecmp(request->header_key, "Accept") == 0) {
request->json = strcasecmp(request->header_value, "application/json") == 0;
} else if (strcasecmp(request->header_key, "Origin") == 0) {
strcpy(request->origin, request->header_value);
strncpy(request->origin, request->header_value, sizeof(request->origin) - 1);
request->origin[sizeof(request->origin) - 1] = '\0';
} else if (strcasecmp(request->header_key, "X-Timestamp") == 0) {
request->timestamp_ms = strtoull(request->header_value, NULL, 10);
} else if (strcasecmp(request->header_key, "Upgrade") == 0) {
Expand Down