Skip to content

Commit fe73652

Browse files
Support subclass inits for Url and MultiHostUrl (#1508)
1 parent 8568136 commit fe73652

File tree

6 files changed

+100
-15
lines changed

6 files changed

+100
-15
lines changed

python/pydantic_core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class ErrorTypeInfo(_TypedDict):
124124
"""Example of context values."""
125125

126126

127-
class MultiHostHost(_TypedDict):
127+
class MultiHostHost(_TypedDict, total=False):
128128
"""
129129
A host part of a multi-host URL.
130130
"""

python/pydantic_core/_pydantic_core.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ class Url(SupportsAllComparisons):
583583
scheme: str,
584584
username: str | None = None,
585585
password: str | None = None,
586-
host: str,
586+
host: str | None = None,
587587
port: int | None = None,
588588
path: str | None = None,
589589
query: str | None = None,
@@ -596,7 +596,7 @@ class Url(SupportsAllComparisons):
596596
scheme: The scheme part of the URL.
597597
username: The username part of the URL, or omit for no username.
598598
password: The password part of the URL, or omit for no password.
599-
host: The host part of the URL.
599+
host: The host part of the URL, or omit for no host.
600600
port: The port part of the URL, or omit for no port.
601601
path: The path part of the URL, or omit for no path.
602602
query: The query part of the URL, or omit for no query.

python/pydantic_core/core_schema.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3655,6 +3655,7 @@ class MyModel:
36553655

36563656
class UrlSchema(TypedDict, total=False):
36573657
type: Required[Literal['url']]
3658+
cls: Type[Any]
36583659
max_length: int
36593660
allowed_schemes: List[str]
36603661
host_required: bool # default False
@@ -3669,6 +3670,7 @@ class UrlSchema(TypedDict, total=False):
36693670

36703671
def url_schema(
36713672
*,
3673+
cls: Type[Any] | None = None,
36723674
max_length: int | None = None,
36733675
allowed_schemes: list[str] | None = None,
36743676
host_required: bool | None = None,
@@ -3693,6 +3695,7 @@ def url_schema(
36933695
```
36943696
36953697
Args:
3698+
cls: The class to use for the URL build (a subclass of `pydantic_core.Url`)
36963699
max_length: The maximum length of the URL
36973700
allowed_schemes: The allowed URL schemes
36983701
host_required: Whether the URL must have a host
@@ -3706,6 +3709,7 @@ def url_schema(
37063709
"""
37073710
return _dict_not_none(
37083711
type='url',
3712+
cls=cls,
37093713
max_length=max_length,
37103714
allowed_schemes=allowed_schemes,
37113715
host_required=host_required,
@@ -3721,6 +3725,7 @@ def url_schema(
37213725

37223726
class MultiHostUrlSchema(TypedDict, total=False):
37233727
type: Required[Literal['multi-host-url']]
3728+
cls: Type[Any]
37243729
max_length: int
37253730
allowed_schemes: List[str]
37263731
host_required: bool # default False
@@ -3735,6 +3740,7 @@ class MultiHostUrlSchema(TypedDict, total=False):
37353740

37363741
def multi_host_url_schema(
37373742
*,
3743+
cls: Type[Any] | None = None,
37383744
max_length: int | None = None,
37393745
allowed_schemes: list[str] | None = None,
37403746
host_required: bool | None = None,
@@ -3759,6 +3765,7 @@ def multi_host_url_schema(
37593765
```
37603766
37613767
Args:
3768+
cls: The class to use for the URL build (a subclass of `pydantic_core.MultiHostUrl`)
37623769
max_length: The maximum length of the URL
37633770
allowed_schemes: The allowed URL schemes
37643771
host_required: Whether the URL must have a host
@@ -3772,6 +3779,7 @@ def multi_host_url_schema(
37723779
"""
37733780
return _dict_not_none(
37743781
type='multi-host-url',
3782+
cls=cls,
37753783
max_length=max_length,
37763784
allowed_schemes=allowed_schemes,
37773785
host_required=host_required,

src/url.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,12 @@ impl PyUrl {
156156
}
157157

158158
#[classmethod]
159-
#[pyo3(signature=(*, scheme, host, username=None, password=None, port=None, path=None, query=None, fragment=None))]
159+
#[pyo3(signature=(*, scheme, host=None, username=None, password=None, port=None, path=None, query=None, fragment=None))]
160160
#[allow(clippy::too_many_arguments)]
161161
pub fn build<'py>(
162162
cls: &Bound<'py, PyType>,
163163
scheme: &str,
164-
host: &str,
164+
host: Option<&str>,
165165
username: Option<&str>,
166166
password: Option<&str>,
167167
port: Option<u16>,
@@ -172,7 +172,7 @@ impl PyUrl {
172172
let url_host = UrlHostParts {
173173
username: username.map(Into::into),
174174
password: password.map(Into::into),
175-
host: Some(host.into()),
175+
host: host.map(Into::into),
176176
port,
177177
};
178178
let mut url = format!("{scheme}://{url_host}");
@@ -423,6 +423,7 @@ impl PyMultiHostUrl {
423423
}
424424
}
425425

426+
#[cfg_attr(debug_assertions, derive(Debug))]
426427
pub struct UrlHostParts {
427428
username: Option<String>,
428429
password: Option<String>,
@@ -440,11 +441,12 @@ impl FromPyObject<'_> for UrlHostParts {
440441
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
441442
let py = ob.py();
442443
let dict = ob.downcast::<PyDict>()?;
444+
443445
Ok(UrlHostParts {
444-
username: dict.get_as(intern!(py, "username"))?,
445-
password: dict.get_as(intern!(py, "password"))?,
446-
host: dict.get_as(intern!(py, "host"))?,
447-
port: dict.get_as(intern!(py, "port"))?,
446+
username: dict.get_as::<Option<_>>(intern!(py, "username"))?.flatten(),
447+
password: dict.get_as::<Option<_>>(intern!(py, "password"))?.flatten(),
448+
host: dict.get_as::<Option<_>>(intern!(py, "host"))?.flatten(),
449+
port: dict.get_as::<Option<_>>(intern!(py, "port"))?.flatten(),
448450
})
449451
}
450452
}

src/validators/url.rs

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::str::Chars;
44

55
use pyo3::intern;
66
use pyo3::prelude::*;
7-
use pyo3::types::{PyDict, PyList};
7+
use pyo3::types::{PyDict, PyList, PyType};
88

99
use ahash::AHashSet;
1010
use url::{ParseError, SyntaxViolation, Url};
@@ -26,6 +26,7 @@ type AllowedSchemas = Option<(AHashSet<String>, String)>;
2626
#[derive(Debug, Clone)]
2727
pub struct UrlValidator {
2828
strict: bool,
29+
cls: Option<Py<PyType>>,
2930
max_length: Option<usize>,
3031
allowed_schemes: AllowedSchemas,
3132
host_required: bool,
@@ -47,6 +48,7 @@ impl BuildValidator for UrlValidator {
4748

4849
Ok(Self {
4950
strict: is_strict(schema, config)?,
51+
cls: schema.get_as(intern!(schema.py(), "cls"))?,
5052
max_length: schema.get_as(intern!(schema.py(), "max_length"))?,
5153
host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false),
5254
default_host: schema.get_as(intern!(schema.py(), "default_host"))?,
@@ -59,7 +61,7 @@ impl BuildValidator for UrlValidator {
5961
}
6062
}
6163

62-
impl_py_gc_traverse!(UrlValidator {});
64+
impl_py_gc_traverse!(UrlValidator { cls });
6365

6466
impl Validator for UrlValidator {
6567
fn validate<'py>(
@@ -93,7 +95,31 @@ impl Validator for UrlValidator {
9395
Ok(()) => {
9496
// Lax rather than strict to preserve V2.4 semantic that str wins over url in union
9597
state.floor_exactness(Exactness::Lax);
96-
Ok(either_url.into_py(py))
98+
99+
if let Some(url_subclass) = &self.cls {
100+
// TODO: we do an extra build for a subclass here, we should avoid this
101+
// in v2.11 for perf reasons, but this is a worthwhile patch for now
102+
// given that we want isinstance to work properly for subclasses of Url
103+
let py_url = match either_url {
104+
EitherUrl::Py(py_url) => py_url.get().clone(),
105+
EitherUrl::Rust(rust_url) => PyUrl::new(rust_url),
106+
};
107+
108+
let py_url = PyUrl::build(
109+
url_subclass.bind(py),
110+
py_url.scheme(),
111+
py_url.host(),
112+
py_url.username(),
113+
py_url.password(),
114+
py_url.port(),
115+
py_url.path().filter(|path| *path != "/"),
116+
py_url.query(),
117+
py_url.fragment(),
118+
)?;
119+
Ok(py_url.into_py(py))
120+
} else {
121+
Ok(either_url.into_py(py))
122+
}
97123
}
98124
Err(error_type) => Err(ValError::new(error_type, input)),
99125
}
@@ -186,6 +212,7 @@ impl CopyFromPyUrl for EitherUrl<'_> {
186212
#[derive(Debug, Clone)]
187213
pub struct MultiHostUrlValidator {
188214
strict: bool,
215+
cls: Option<Py<PyType>>,
189216
max_length: Option<usize>,
190217
allowed_schemes: AllowedSchemas,
191218
host_required: bool,
@@ -213,6 +240,7 @@ impl BuildValidator for MultiHostUrlValidator {
213240
}
214241
Ok(Self {
215242
strict: is_strict(schema, config)?,
243+
cls: schema.get_as(intern!(schema.py(), "cls"))?,
216244
max_length: schema.get_as(intern!(schema.py(), "max_length"))?,
217245
allowed_schemes,
218246
host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false),
@@ -225,7 +253,7 @@ impl BuildValidator for MultiHostUrlValidator {
225253
}
226254
}
227255

228-
impl_py_gc_traverse!(MultiHostUrlValidator {});
256+
impl_py_gc_traverse!(MultiHostUrlValidator { cls });
229257

230258
impl Validator for MultiHostUrlValidator {
231259
fn validate<'py>(
@@ -258,7 +286,38 @@ impl Validator for MultiHostUrlValidator {
258286
Ok(()) => {
259287
// Lax rather than strict to preserve V2.4 semantic that str wins over url in union
260288
state.floor_exactness(Exactness::Lax);
261-
Ok(multi_url.into_py(py))
289+
290+
if let Some(url_subclass) = &self.cls {
291+
// TODO: we do an extra build for a subclass here, we should avoid this
292+
// in v2.11 for perf reasons, but this is a worthwhile patch for now
293+
// given that we want isinstance to work properly for subclasses of Url
294+
let py_url = match multi_url {
295+
EitherMultiHostUrl::Py(py_url) => py_url.get().clone(),
296+
EitherMultiHostUrl::Rust(rust_url) => rust_url,
297+
};
298+
299+
let hosts = py_url
300+
.hosts(py)?
301+
.into_iter()
302+
.map(|host| host.extract().expect("host should be a valid UrlHostParts"))
303+
.collect();
304+
305+
let py_url = PyMultiHostUrl::build(
306+
url_subclass.bind(py),
307+
py_url.scheme(),
308+
Some(hosts),
309+
py_url.path().filter(|path| *path != "/"),
310+
py_url.query(),
311+
py_url.fragment(),
312+
None,
313+
None,
314+
None,
315+
None,
316+
)?;
317+
Ok(py_url.into_py(py))
318+
} else {
319+
Ok(multi_url.into_py(py))
320+
}
262321
}
263322
Err(error_type) => Err(ValError::new(error_type, input)),
264323
}

tests/validators/test_url.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,3 +1305,19 @@ def test_url_build() -> None:
13051305
)
13061306
assert url == Url('postgresql://testuser:[email protected]:5432/database?sslmode=require#test')
13071307
assert str(url) == 'postgresql://testuser:[email protected]:5432/database?sslmode=require#test'
1308+
1309+
1310+
def test_url_subclass() -> None:
1311+
class UrlSubclass(Url):
1312+
pass
1313+
1314+
validator = SchemaValidator(core_schema.url_schema(cls=UrlSubclass))
1315+
assert isinstance(validator.validate_python('http://example.com'), UrlSubclass)
1316+
1317+
1318+
def test_multi_host_url_subclass() -> None:
1319+
class MultiHostUrlSubclass(MultiHostUrl):
1320+
pass
1321+
1322+
validator = SchemaValidator(core_schema.multi_host_url_schema(cls=MultiHostUrlSubclass))
1323+
assert isinstance(validator.validate_python('http://example.com'), MultiHostUrlSubclass)

0 commit comments

Comments
 (0)