Skip to content

Commit ca4b802

Browse files
committed
implement omit_trailing_slash feature for PyUrl
1 parent c6301fe commit ca4b802

File tree

6 files changed

+329
-32
lines changed

6 files changed

+329
-32
lines changed

python/pydantic_core/_pydantic_core.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,12 +454,13 @@ class Url(SupportsAllComparisons):
454454
by Mozilla.
455455
"""
456456

457-
def __new__(cls, url: str) -> Self:
457+
def __new__(cls, url: str, omit_trailing_slash: bool = False) -> Self:
458458
"""
459459
Create a new `Url` instance.
460460
461461
Args:
462462
url: String representation of a URL.
463+
omit_trailing_slash: Whether to omit trailing slash (only if path == "/")
463464
464465
Returns:
465466
A new `Url` instance.
@@ -590,12 +591,13 @@ class MultiHostUrl(SupportsAllComparisons):
590591
by Mozilla.
591592
"""
592593

593-
def __new__(cls, url: str) -> Self:
594+
def __new__(cls, url: str, omit_trailing_slash: bool = False) -> Self:
594595
"""
595596
Create a new `MultiHostUrl` instance.
596597
597598
Args:
598599
url: String representation of a URL.
600+
omit_trailing_slash: Whether to omit trailing slash (only if path == "/")
599601
600602
Returns:
601603
A new `MultiHostUrl` instance.

python/pydantic_core/core_schema.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3454,6 +3454,7 @@ class UrlSchema(TypedDict, total=False):
34543454
default_host: str
34553455
default_port: int
34563456
default_path: str
3457+
omit_trailing_slash: bool # default False
34573458
strict: bool
34583459
ref: str
34593460
metadata: Any
@@ -3468,6 +3469,7 @@ def url_schema(
34683469
default_host: str | None = None,
34693470
default_port: int | None = None,
34703471
default_path: str | None = None,
3472+
omit_trailing_slash: bool | None = None,
34713473
strict: bool | None = None,
34723474
ref: str | None = None,
34733475
metadata: Any = None,
@@ -3492,6 +3494,7 @@ def url_schema(
34923494
default_host: The default host to use if the URL does not have a host
34933495
default_port: The default port to use if the URL does not have a port
34943496
default_path: The default path to use if the URL does not have a path
3497+
omit_trailing_slash: Whether to omit trailing slash (only if path == "/")
34953498
strict: Whether to use strict URL parsing
34963499
ref: optional unique identifier of the schema, used to reference the schema in other places
34973500
metadata: Any other information you want to include with the schema, not used by pydantic-core
@@ -3506,6 +3509,7 @@ def url_schema(
35063509
default_port=default_port,
35073510
default_path=default_path,
35083511
strict=strict,
3512+
omit_trailing_slash=omit_trailing_slash,
35093513
ref=ref,
35103514
metadata=metadata,
35113515
serialization=serialization,
@@ -3520,6 +3524,7 @@ class MultiHostUrlSchema(TypedDict, total=False):
35203524
default_host: str
35213525
default_port: int
35223526
default_path: str
3527+
omit_trailing_slash: bool
35233528
strict: bool
35243529
ref: str
35253530
metadata: Any
@@ -3534,6 +3539,7 @@ def multi_host_url_schema(
35343539
default_host: str | None = None,
35353540
default_port: int | None = None,
35363541
default_path: str | None = None,
3542+
omit_trailing_slash: bool | None = None,
35373543
strict: bool | None = None,
35383544
ref: str | None = None,
35393545
metadata: Any = None,
@@ -3558,6 +3564,7 @@ def multi_host_url_schema(
35583564
default_host: The default host to use if the URL does not have a host
35593565
default_port: The default port to use if the URL does not have a port
35603566
default_path: The default path to use if the URL does not have a path
3567+
omit_trailing_slash: Whether to omit trailing slash (only if path == "/")
35613568
strict: Whether to use strict URL parsing
35623569
ref: optional unique identifier of the schema, used to reference the schema in other places
35633570
metadata: Any other information you want to include with the schema, not used by pydantic-core
@@ -3571,6 +3578,7 @@ def multi_host_url_schema(
35713578
default_host=default_host,
35723579
default_port=default_port,
35733580
default_path=default_path,
3581+
omit_trailing_slash=omit_trailing_slash,
35743582
strict=strict,
35753583
ref=ref,
35763584
metadata=metadata,

src/serializers/infer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
444444
}
445445
ObType::Url => {
446446
let py_url: PyUrl = value.extract().map_err(py_err_se_err)?;
447-
serializer.serialize_str(py_url.__str__())
447+
serializer.serialize_str(py_url.__str__().as_str())
448448
}
449449
ObType::MultiHostUrl => {
450450
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;

src/url.rs

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,49 @@ use crate::tools::SchemaDict;
1515
use crate::SchemaValidator;
1616

1717
static SCHEMA_DEFINITION_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();
18+
static SCHEMA_DEFINITION_OMIT_SLASH_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();
1819

1920
#[pyclass(name = "Url", module = "pydantic_core._pydantic_core", subclass)]
2021
#[derive(Clone)]
2122
#[cfg_attr(debug_assertions, derive(Debug))]
2223
pub struct PyUrl {
2324
lib_url: Url,
25+
omit_trailing_slash: bool,
2426
}
2527

2628
impl PyUrl {
27-
pub fn new(lib_url: Url) -> Self {
28-
Self { lib_url }
29+
pub fn new(lib_url: Url, omit_trailing_slash: Option<bool>) -> Self {
30+
Self {
31+
lib_url,
32+
omit_trailing_slash: omit_trailing_slash.unwrap_or(false),
33+
}
2934
}
3035

3136
pub fn into_url(self) -> Url {
3237
self.lib_url
3338
}
3439
}
3540

36-
fn build_schema_validator(py: Python, schema_type: &str) -> SchemaValidator {
41+
fn build_schema_validator(py: Python, schema_type: &str, omit_trailing_slash: bool) -> SchemaValidator {
3742
let schema: &PyDict = PyDict::new(py);
3843
schema.set_item("type", schema_type).unwrap();
44+
// TODO: it seems wrong, do it?
45+
schema.set_item("omit_trailing_slash", omit_trailing_slash).unwrap();
3946
SchemaValidator::py_new(py, schema, None).unwrap()
4047
}
4148

4249
#[pymethods]
4350
impl PyUrl {
4451
#[new]
45-
pub fn py_new(py: Python, url: &PyAny) -> PyResult<Self> {
46-
let schema_obj = SCHEMA_DEFINITION_URL
47-
.get_or_init(py, || build_schema_validator(py, "url"))
52+
pub fn py_new(py: Python, url: &PyAny, omit_trailing_slash: Option<bool>) -> PyResult<Self> {
53+
let omit = omit_trailing_slash.unwrap_or(false);
54+
let schema = if omit {
55+
&SCHEMA_DEFINITION_OMIT_SLASH_URL
56+
} else {
57+
&SCHEMA_DEFINITION_URL
58+
};
59+
let schema_obj = schema
60+
.get_or_init(py, || build_schema_validator(py, "url", omit))
4861
.validate_python(py, url, None, None, None, None)?;
4962
schema_obj.extract(py)
5063
}
@@ -89,6 +102,7 @@ impl PyUrl {
89102
pub fn path(&self) -> Option<&str> {
90103
match self.lib_url.path() {
91104
"" => None,
105+
path if self.omit_trailing_slash && path == "/" => None,
92106
path => Some(path),
93107
}
94108
}
@@ -114,15 +128,21 @@ impl PyUrl {
114128

115129
// string representation of the URL, with punycode decoded when appropriate
116130
pub fn unicode_string(&self) -> String {
117-
unicode_url(&self.lib_url)
131+
unicode_url(&self.lib_url, self.omit_trailing_slash)
118132
}
119133

120-
pub fn __str__(&self) -> &str {
121-
self.lib_url.as_str()
134+
pub fn __str__(&self) -> String {
135+
if self.omit_trailing_slash && self.lib_url.path() == "/" {
136+
let start = before_path_length(&self.lib_url);
137+
let mut s = self.lib_url.to_string();
138+
s.replace_range(start..=start, "");
139+
return s;
140+
}
141+
self.lib_url.to_string()
122142
}
123143

124144
pub fn __repr__(&self) -> String {
125-
format!("Url('{}')", self.lib_url)
145+
format!("Url('{}')", self.__str__())
126146
}
127147

128148
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
@@ -151,12 +171,12 @@ impl PyUrl {
151171
self.clone().into_py(py)
152172
}
153173

154-
fn __getnewargs__(&self) -> (&str,) {
174+
fn __getnewargs__(&self) -> (String,) {
155175
(self.__str__(),)
156176
}
157177

158178
#[classmethod]
159-
#[pyo3(signature=(*, scheme, host, username=None, password=None, port=None, path=None, query=None, fragment=None))]
179+
#[pyo3(signature = (*, scheme, host, username = None, password = None, port = None, path = None, query = None, fragment = None))]
160180
#[allow(clippy::too_many_arguments)]
161181
pub fn build<'a>(
162182
cls: &'a PyType,
@@ -198,13 +218,15 @@ impl PyUrl {
198218
pub struct PyMultiHostUrl {
199219
ref_url: PyUrl,
200220
extra_urls: Option<Vec<Url>>,
221+
omit_trailing_slash: bool,
201222
}
202223

203224
impl PyMultiHostUrl {
204-
pub fn new(ref_url: Url, extra_urls: Option<Vec<Url>>) -> Self {
225+
pub fn new(ref_url: Url, extra_urls: Option<Vec<Url>>, omit_trailing_slash: Option<bool>) -> Self {
205226
Self {
206-
ref_url: PyUrl::new(ref_url),
227+
ref_url: PyUrl::new(ref_url, omit_trailing_slash),
207228
extra_urls,
229+
omit_trailing_slash: omit_trailing_slash.unwrap_or(false),
208230
}
209231
}
210232

@@ -214,13 +236,20 @@ impl PyMultiHostUrl {
214236
}
215237

216238
static SCHEMA_DEFINITION_MULTI_HOST_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();
239+
static SCHEMA_DEFINITION_MULTI_HOST_OMIT_SLASH_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();
217240

218241
#[pymethods]
219242
impl PyMultiHostUrl {
220243
#[new]
221-
pub fn py_new(py: Python, url: &PyAny) -> PyResult<Self> {
222-
let schema_obj = SCHEMA_DEFINITION_MULTI_HOST_URL
223-
.get_or_init(py, || build_schema_validator(py, "multi-host-url"))
244+
pub fn py_new(py: Python, url: &PyAny, omit_trailing_slash: Option<bool>) -> PyResult<Self> {
245+
let omit = omit_trailing_slash.unwrap_or(false);
246+
let schema = if omit {
247+
&SCHEMA_DEFINITION_MULTI_HOST_OMIT_SLASH_URL
248+
} else {
249+
&SCHEMA_DEFINITION_MULTI_HOST_URL
250+
};
251+
let schema_obj = schema
252+
.get_or_init(py, || build_schema_validator(py, "multi-host-url", omit))
224253
.validate_python(py, url, None, None, None, None)?;
225254
schema_obj.extract(py)
226255
}
@@ -281,8 +310,9 @@ impl PyMultiHostUrl {
281310
let hosts = extra_urls
282311
.iter()
283312
.map(|url| {
284-
let str = unicode_url(url);
285-
str[host_offset..str.len() - sub].to_string()
313+
let str = unicode_url(url, self.omit_trailing_slash);
314+
let _sub = if self.omit_trailing_slash { 0 } else { sub };
315+
str[host_offset..str.len() - _sub].to_string()
286316
})
287317
.collect::<Vec<String>>()
288318
.join(",");
@@ -298,7 +328,7 @@ impl PyMultiHostUrl {
298328
let schema = self.ref_url.lib_url.scheme();
299329
let host_offset = schema.len() + 3;
300330

301-
let mut full_url = self.ref_url.lib_url.to_string();
331+
let mut full_url = self.ref_url.__str__();
302332
full_url.insert(host_offset, ',');
303333

304334
// special urls will have had a trailing slash added, non-special urls will not
@@ -356,7 +386,7 @@ impl PyMultiHostUrl {
356386
}
357387

358388
#[classmethod]
359-
#[pyo3(signature=(*, scheme, hosts=None, path=None, query=None, fragment=None, host=None, username=None, password=None, port=None))]
389+
#[pyo3(signature = (*, scheme, hosts = None, path = None, query = None, fragment = None, host = None, username = None, password = None, port = None))]
360390
#[allow(clippy::too_many_arguments)]
361391
pub fn build<'a>(
362392
cls: &'a PyType,
@@ -480,19 +510,34 @@ fn host_to_dict<'a>(py: Python<'a>, lib_url: &Url) -> PyResult<&'a PyDict> {
480510
Ok(dict)
481511
}
482512

483-
fn unicode_url(lib_url: &Url) -> String {
513+
fn unicode_url(lib_url: &Url, omit_trailing_slash: bool) -> String {
484514
let mut s = lib_url.to_string();
485515

486516
match lib_url.host() {
487517
Some(url::Host::Domain(domain)) if is_punnycode_domain(lib_url, domain) => {
488518
if let Some(decoded) = decode_punycode(domain) {
489519
// replace the range containing the punycode domain with the decoded domain
490-
let start = lib_url.scheme().len() + 3;
520+
let before_path = before_path_length(lib_url);
521+
let start = before_path
522+
- domain.len()
523+
- match lib_url.port() {
524+
Some(port) => 1 + port.to_string().len(),
525+
None => 0,
526+
};
527+
if omit_trailing_slash && lib_url.path() == "/" {
528+
s.replace_range(before_path..=before_path, "");
529+
}
491530
s.replace_range(start..start + domain.len(), &decoded);
492531
}
493532
s
494533
}
495-
_ => s,
534+
_ => {
535+
if omit_trailing_slash && lib_url.path() == "/" {
536+
let before_path = before_path_length(lib_url);
537+
s.replace_range(before_path..=before_path, "");
538+
}
539+
s
540+
}
496541
}
497542
}
498543

@@ -520,3 +565,21 @@ fn is_punnycode_domain(lib_url: &Url, domain: &str) -> bool {
520565
/// Reports whether `schema` is one of the schemes listed below; the
/// comparison is exact (case-sensitive).
pub fn schema_is_special(schema: &str) -> bool {
    const SPECIAL_SCHEMES: [&str; 6] = ["http", "https", "ws", "wss", "ftp", "file"];
    SPECIAL_SCHEMES.contains(&schema)
}
568+
569+
fn before_path_length(url: &Url) -> usize {
570+
let length = url.scheme().len()
571+
+ 3 // :// part
572+
+ match url.username() {
573+
"" => 0,
574+
// for colon (:) and at (@) signs we're adding +2
575+
username => 2 + username.len() + url.password().unwrap_or("").len(),
576+
}
577+
+ url.host_str().unwrap().len()
578+
+ match url.port() {
579+
// for colon (:) +1
580+
Some(port) => 1 + port.to_string().len(),
581+
None => 0,
582+
};
583+
584+
length
585+
}

0 commit comments

Comments
 (0)