Skip to content

implement omit_trailing_slash feature for PyUrl [7186] #1218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,13 @@ class Url(SupportsAllComparisons):
by Mozilla.
"""

def __new__(cls, url: str) -> Self:
def __new__(cls, url: str, omit_trailing_slash: bool = False) -> Self:
"""
Create a new `Url` instance.

Args:
url: String representation of a URL.
omit_trailing_slash: Whether to omit the trailing slash (applies only when the path is exactly "/")

Returns:
A new `Url` instance.
Expand Down Expand Up @@ -590,12 +591,13 @@ class MultiHostUrl(SupportsAllComparisons):
by Mozilla.
"""

def __new__(cls, url: str) -> Self:
def __new__(cls, url: str, omit_trailing_slash: bool = False) -> Self:
"""
Create a new `MultiHostUrl` instance.

Args:
url: String representation of a URL.
omit_trailing_slash: Whether to omit the trailing slash (applies only when the path is exactly "/")

Returns:
A new `MultiHostUrl` instance.
Expand Down
8 changes: 8 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3454,6 +3454,7 @@ class UrlSchema(TypedDict, total=False):
default_host: str
default_port: int
default_path: str
omit_trailing_slash: bool # default False
strict: bool
ref: str
metadata: Any
Expand All @@ -3468,6 +3469,7 @@ def url_schema(
default_host: str | None = None,
default_port: int | None = None,
default_path: str | None = None,
omit_trailing_slash: bool | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
Expand All @@ -3492,6 +3494,7 @@ def url_schema(
default_host: The default host to use if the URL does not have a host
default_port: The default port to use if the URL does not have a port
default_path: The default path to use if the URL does not have a path
omit_trailing_slash: Whether to omit the trailing slash (applies only when the path is exactly "/")
strict: Whether to use strict URL parsing
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
Expand All @@ -3506,6 +3509,7 @@ def url_schema(
default_port=default_port,
default_path=default_path,
strict=strict,
omit_trailing_slash=omit_trailing_slash,
ref=ref,
metadata=metadata,
serialization=serialization,
Expand All @@ -3520,6 +3524,7 @@ class MultiHostUrlSchema(TypedDict, total=False):
default_host: str
default_port: int
default_path: str
omit_trailing_slash: bool
strict: bool
ref: str
metadata: Any
Expand All @@ -3534,6 +3539,7 @@ def multi_host_url_schema(
default_host: str | None = None,
default_port: int | None = None,
default_path: str | None = None,
omit_trailing_slash: bool | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
Expand All @@ -3558,6 +3564,7 @@ def multi_host_url_schema(
default_host: The default host to use if the URL does not have a host
default_port: The default port to use if the URL does not have a port
default_path: The default path to use if the URL does not have a path
omit_trailing_slash: Whether to omit the trailing slash (applies only when the path is exactly "/")
strict: Whether to use strict URL parsing
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
Expand All @@ -3571,6 +3578,7 @@ def multi_host_url_schema(
default_host=default_host,
default_port=default_port,
default_path=default_path,
omit_trailing_slash=omit_trailing_slash,
strict=strict,
ref=ref,
metadata=metadata,
Expand Down
2 changes: 1 addition & 1 deletion src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}
ObType::Url => {
let py_url: PyUrl = value.extract().map_err(py_err_se_err)?;
serializer.serialize_str(py_url.__str__())
serializer.serialize_str(py_url.__str__().as_str())
}
ObType::MultiHostUrl => {
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;
Expand Down
111 changes: 87 additions & 24 deletions src/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,49 @@ use crate::tools::SchemaDict;
use crate::SchemaValidator;

static SCHEMA_DEFINITION_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();
static SCHEMA_DEFINITION_OMIT_SLASH_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();

// Python-visible `Url` class: a thin wrapper around a parsed `Url` plus a
// rendering flag. Plain `//` comments are used here on purpose so the Python
// class docstring is left untouched.
#[pyclass(name = "Url", module = "pydantic_core._pydantic_core", subclass)]
#[derive(Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct PyUrl {
// The parsed URL this wrapper exposes; `into_url` hands it back by value.
lib_url: Url,
// When true and the URL's path is exactly "/", `path()` reports `None` and
// `__str__` renders the URL without that slash. `PyUrl::new` defaults it to false.
omit_trailing_slash: bool,
}

impl PyUrl {
pub fn new(lib_url: Url) -> Self {
Self { lib_url }
pub fn new(lib_url: Url, omit_trailing_slash: Option<bool>) -> Self {
Self {
lib_url,
omit_trailing_slash: omit_trailing_slash.unwrap_or(false),
}
}

pub fn into_url(self) -> Url {
self.lib_url
}
}

fn build_schema_validator(py: Python, schema_type: &str) -> SchemaValidator {
fn build_schema_validator(py: Python, schema_type: &str, omit_trailing_slash: bool) -> SchemaValidator {
let schema: &PyDict = PyDict::new(py);
schema.set_item("type", schema_type).unwrap();
// TODO: setting this key on the schema dict looks questionable; confirm it is the right approach.
schema.set_item("omit_trailing_slash", omit_trailing_slash).unwrap();
SchemaValidator::py_new(py, schema, None).unwrap()
}

#[pymethods]
impl PyUrl {
#[new]
pub fn py_new(py: Python, url: &PyAny) -> PyResult<Self> {
let schema_obj = SCHEMA_DEFINITION_URL
.get_or_init(py, || build_schema_validator(py, "url"))
pub fn py_new(py: Python, url: &PyAny, omit_trailing_slash: Option<bool>) -> PyResult<Self> {
let omit = omit_trailing_slash.unwrap_or(false);
let schema = if omit {
&SCHEMA_DEFINITION_OMIT_SLASH_URL
} else {
&SCHEMA_DEFINITION_URL
};
let schema_obj = schema
.get_or_init(py, || build_schema_validator(py, "url", omit))
.validate_python(py, url, None, None, None, None)?;
schema_obj.extract(py)
}
Expand Down Expand Up @@ -89,6 +102,7 @@ impl PyUrl {
pub fn path(&self) -> Option<&str> {
match self.lib_url.path() {
"" => None,
path if self.omit_trailing_slash && path == "/" => None,
path => Some(path),
}
}
Expand All @@ -114,15 +128,21 @@ impl PyUrl {

// string representation of the URL, with punycode decoded when appropriate
pub fn unicode_string(&self) -> String {
unicode_url(&self.lib_url)
unicode_url(&self.lib_url, self.omit_trailing_slash)
}

pub fn __str__(&self) -> &str {
self.lib_url.as_str()
pub fn __str__(&self) -> String {
if self.omit_trailing_slash && self.lib_url.path() == "/" {
let start = before_path_length(&self.lib_url);
let mut s = self.lib_url.to_string();
s.replace_range(start..=start, "");
return s;
}
self.lib_url.to_string()
}

pub fn __repr__(&self) -> String {
format!("Url('{}')", self.lib_url)
format!("Url('{}')", self.__str__())
}

fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
Expand Down Expand Up @@ -151,12 +171,12 @@ impl PyUrl {
self.clone().into_py(py)
}

fn __getnewargs__(&self) -> (&str,) {
fn __getnewargs__(&self) -> (String,) {
(self.__str__(),)
}

#[classmethod]
#[pyo3(signature=(*, scheme, host, username=None, password=None, port=None, path=None, query=None, fragment=None))]
#[pyo3(signature = (*, scheme, host, username = None, password = None, port = None, path = None, query = None, fragment = None))]
#[allow(clippy::too_many_arguments)]
pub fn build<'a>(
cls: &'a PyType,
Expand Down Expand Up @@ -198,13 +218,15 @@ impl PyUrl {
// A URL that may name several hosts: one reference URL plus optional extra URLs.
pub struct PyMultiHostUrl {
// The reference URL, wrapped as a `PyUrl` built with the same `omit_trailing_slash` flag.
ref_url: PyUrl,
// Additional parsed URLs, if any — presumably one per extra host; confirm against the validator.
extra_urls: Option<Vec<Url>>,
// Copy of the flag passed to `ref_url` (defaults to false); used when rendering the extra URLs.
omit_trailing_slash: bool,
}

impl PyMultiHostUrl {
pub fn new(ref_url: Url, extra_urls: Option<Vec<Url>>) -> Self {
pub fn new(ref_url: Url, extra_urls: Option<Vec<Url>>, omit_trailing_slash: Option<bool>) -> Self {
Self {
ref_url: PyUrl::new(ref_url),
ref_url: PyUrl::new(ref_url, omit_trailing_slash),
extra_urls,
omit_trailing_slash: omit_trailing_slash.unwrap_or(false),
}
}

Expand All @@ -214,13 +236,20 @@ impl PyMultiHostUrl {
}

static SCHEMA_DEFINITION_MULTI_HOST_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();
static SCHEMA_DEFINITION_MULTI_HOST_OMIT_SLASH_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();

#[pymethods]
impl PyMultiHostUrl {
#[new]
pub fn py_new(py: Python, url: &PyAny) -> PyResult<Self> {
let schema_obj = SCHEMA_DEFINITION_MULTI_HOST_URL
.get_or_init(py, || build_schema_validator(py, "multi-host-url"))
pub fn py_new(py: Python, url: &PyAny, omit_trailing_slash: Option<bool>) -> PyResult<Self> {
let omit = omit_trailing_slash.unwrap_or(false);
let schema = if omit {
&SCHEMA_DEFINITION_MULTI_HOST_OMIT_SLASH_URL
} else {
&SCHEMA_DEFINITION_MULTI_HOST_URL
};
let schema_obj = schema
.get_or_init(py, || build_schema_validator(py, "multi-host-url", omit))
.validate_python(py, url, None, None, None, None)?;
schema_obj.extract(py)
}
Expand Down Expand Up @@ -281,8 +310,9 @@ impl PyMultiHostUrl {
let hosts = extra_urls
.iter()
.map(|url| {
let str = unicode_url(url);
str[host_offset..str.len() - sub].to_string()
let str = unicode_url(url, self.omit_trailing_slash);
let _sub = if self.omit_trailing_slash { 0 } else { sub };
str[host_offset..str.len() - _sub].to_string()
})
.collect::<Vec<String>>()
.join(",");
Expand All @@ -298,7 +328,7 @@ impl PyMultiHostUrl {
let schema = self.ref_url.lib_url.scheme();
let host_offset = schema.len() + 3;

let mut full_url = self.ref_url.lib_url.to_string();
let mut full_url = self.ref_url.__str__();
full_url.insert(host_offset, ',');

// special urls will have had a trailing slash added, non-special urls will not
Expand Down Expand Up @@ -356,7 +386,7 @@ impl PyMultiHostUrl {
}

#[classmethod]
#[pyo3(signature=(*, scheme, hosts=None, path=None, query=None, fragment=None, host=None, username=None, password=None, port=None))]
#[pyo3(signature = (*, scheme, hosts = None, path = None, query = None, fragment = None, host = None, username = None, password = None, port = None))]
#[allow(clippy::too_many_arguments)]
pub fn build<'a>(
cls: &'a PyType,
Expand Down Expand Up @@ -480,19 +510,34 @@ fn host_to_dict<'a>(py: Python<'a>, lib_url: &Url) -> PyResult<&'a PyDict> {
Ok(dict)
}

fn unicode_url(lib_url: &Url) -> String {
fn unicode_url(lib_url: &Url, omit_trailing_slash: bool) -> String {
let mut s = lib_url.to_string();

match lib_url.host() {
Some(url::Host::Domain(domain)) if is_punnycode_domain(lib_url, domain) => {
if let Some(decoded) = decode_punycode(domain) {
// replace the range containing the punycode domain with the decoded domain
let start = lib_url.scheme().len() + 3;
let before_path = before_path_length(lib_url);
let start = before_path
- domain.len()
- match lib_url.port() {
Some(port) => 1 + port.to_string().len(),
None => 0,
};
if omit_trailing_slash && lib_url.path() == "/" {
s.replace_range(before_path..=before_path, "");
}
s.replace_range(start..start + domain.len(), &decoded);
}
s
}
_ => s,
_ => {
if omit_trailing_slash && lib_url.path() == "/" {
let before_path = before_path_length(lib_url);
s.replace_range(before_path..=before_path, "");
}
s
}
}
}

Expand Down Expand Up @@ -520,3 +565,21 @@ fn is_punnycode_domain(lib_url: &Url, domain: &str) -> bool {
/// Reports whether `schema` is one of the schemes this module treats as special:
/// `http`, `https`, `ws`, `wss`, `ftp` or `file`.
///
/// The comparison is exact, so differently-cased or longer names do not match.
pub fn schema_is_special(schema: &str) -> bool {
    const SPECIAL_SCHEMES: [&str; 6] = ["http", "https", "ws", "wss", "ftp", "file"];
    SPECIAL_SCHEMES.contains(&schema)
}

fn before_path_length(url: &Url) -> usize {
let length = url.scheme().len()
+ 3 // :// part
+ match url.username() {
"" => 0,
// for colon (:) and at (@) signs we're adding +2
username => 2 + username.len() + url.password().unwrap_or("").len(),
}
+ url.host_str().unwrap().len()
+ match url.port() {
// for colon (:) +1
Some(port) => 1 + port.to_string().len(),
None => 0,
};

length
}
Loading