@@ -15,36 +15,49 @@ use crate::tools::SchemaDict;
15
15
use crate :: SchemaValidator ;
16
16
17
17
static SCHEMA_DEFINITION_URL : GILOnceCell < SchemaValidator > = GILOnceCell :: new ( ) ;
18
+ static SCHEMA_DEFINITION_OMIT_SLASH_URL : GILOnceCell < SchemaValidator > = GILOnceCell :: new ( ) ;
18
19
19
20
#[ pyclass( name = "Url" , module = "pydantic_core._pydantic_core" , subclass) ]
20
21
#[ derive( Clone ) ]
21
22
#[ cfg_attr( debug_assertions, derive( Debug ) ) ]
22
23
pub struct PyUrl {
23
24
lib_url : Url ,
25
+ omit_trailing_slash : bool ,
24
26
}
25
27
26
28
impl PyUrl {
27
- pub fn new ( lib_url : Url ) -> Self {
28
- Self { lib_url }
29
+ pub fn new ( lib_url : Url , omit_trailing_slash : Option < bool > ) -> Self {
30
+ Self {
31
+ lib_url,
32
+ omit_trailing_slash : omit_trailing_slash. unwrap_or ( false ) ,
33
+ }
29
34
}
30
35
31
36
pub fn into_url ( self ) -> Url {
32
37
self . lib_url
33
38
}
34
39
}
35
40
36
- fn build_schema_validator ( py : Python , schema_type : & str ) -> SchemaValidator {
41
+ fn build_schema_validator ( py : Python , schema_type : & str , omit_trailing_slash : bool ) -> SchemaValidator {
37
42
let schema: & PyDict = PyDict :: new ( py) ;
38
43
schema. set_item ( "type" , schema_type) . unwrap ( ) ;
44
+ // TODO: it seems wrong, do it?
45
+ schema. set_item ( "omit_trailing_slash" , omit_trailing_slash) . unwrap ( ) ;
39
46
SchemaValidator :: py_new ( py, schema, None ) . unwrap ( )
40
47
}
41
48
42
49
#[ pymethods]
43
50
impl PyUrl {
44
51
#[ new]
45
- pub fn py_new ( py : Python , url : & PyAny ) -> PyResult < Self > {
46
- let schema_obj = SCHEMA_DEFINITION_URL
47
- . get_or_init ( py, || build_schema_validator ( py, "url" ) )
52
+ pub fn py_new ( py : Python , url : & PyAny , omit_trailing_slash : Option < bool > ) -> PyResult < Self > {
53
+ let omit = omit_trailing_slash. unwrap_or ( false ) ;
54
+ let schema = if omit {
55
+ & SCHEMA_DEFINITION_OMIT_SLASH_URL
56
+ } else {
57
+ & SCHEMA_DEFINITION_URL
58
+ } ;
59
+ let schema_obj = schema
60
+ . get_or_init ( py, || build_schema_validator ( py, "url" , omit) )
48
61
. validate_python ( py, url, None , None , None , None ) ?;
49
62
schema_obj. extract ( py)
50
63
}
@@ -89,6 +102,7 @@ impl PyUrl {
89
102
pub fn path ( & self ) -> Option < & str > {
90
103
match self . lib_url . path ( ) {
91
104
"" => None ,
105
+ path if self . omit_trailing_slash && path == "/" => None ,
92
106
path => Some ( path) ,
93
107
}
94
108
}
@@ -114,15 +128,21 @@ impl PyUrl {
114
128
115
129
// string representation of the URL, with punycode decoded when appropriate
116
130
pub fn unicode_string ( & self ) -> String {
117
- unicode_url ( & self . lib_url )
131
+ unicode_url ( & self . lib_url , self . omit_trailing_slash )
118
132
}
119
133
120
- pub fn __str__ ( & self ) -> & str {
121
- self . lib_url . as_str ( )
134
+ pub fn __str__ ( & self ) -> String {
135
+ if self . omit_trailing_slash && self . lib_url . path ( ) == "/" {
136
+ let start = before_path_length ( & self . lib_url ) ;
137
+ let mut s = self . lib_url . to_string ( ) ;
138
+ s. replace_range ( start..=start, "" ) ;
139
+ return s;
140
+ }
141
+ self . lib_url . to_string ( )
122
142
}
123
143
124
144
pub fn __repr__ ( & self ) -> String {
125
- format ! ( "Url('{}')" , self . lib_url )
145
+ format ! ( "Url('{}')" , self . __str__ ( ) )
126
146
}
127
147
128
148
fn __richcmp__ ( & self , other : & Self , op : CompareOp ) -> PyResult < bool > {
@@ -151,12 +171,12 @@ impl PyUrl {
151
171
self . clone ( ) . into_py ( py)
152
172
}
153
173
154
- fn __getnewargs__ ( & self ) -> ( & str , ) {
174
+ fn __getnewargs__ ( & self ) -> ( String , ) {
155
175
( self . __str__ ( ) , )
156
176
}
157
177
158
178
#[ classmethod]
159
- #[ pyo3( signature= ( * , scheme, host, username= None , password= None , port= None , path= None , query= None , fragment= None ) ) ]
179
+ #[ pyo3( signature = ( * , scheme, host, username = None , password = None , port = None , path = None , query = None , fragment = None ) ) ]
160
180
#[ allow( clippy:: too_many_arguments) ]
161
181
pub fn build < ' a > (
162
182
cls : & ' a PyType ,
@@ -198,13 +218,15 @@ impl PyUrl {
198
218
pub struct PyMultiHostUrl {
199
219
ref_url : PyUrl ,
200
220
extra_urls : Option < Vec < Url > > ,
221
+ omit_trailing_slash : bool ,
201
222
}
202
223
203
224
impl PyMultiHostUrl {
204
- pub fn new ( ref_url : Url , extra_urls : Option < Vec < Url > > ) -> Self {
225
+ pub fn new ( ref_url : Url , extra_urls : Option < Vec < Url > > , omit_trailing_slash : Option < bool > ) -> Self {
205
226
Self {
206
- ref_url : PyUrl :: new ( ref_url) ,
227
+ ref_url : PyUrl :: new ( ref_url, omit_trailing_slash ) ,
207
228
extra_urls,
229
+ omit_trailing_slash : omit_trailing_slash. unwrap_or ( false ) ,
208
230
}
209
231
}
210
232
@@ -214,13 +236,20 @@ impl PyMultiHostUrl {
214
236
}
215
237
216
238
static SCHEMA_DEFINITION_MULTI_HOST_URL : GILOnceCell < SchemaValidator > = GILOnceCell :: new ( ) ;
239
+ static SCHEMA_DEFINITION_MULTI_HOST_OMIT_SLASH_URL : GILOnceCell < SchemaValidator > = GILOnceCell :: new ( ) ;
217
240
218
241
#[ pymethods]
219
242
impl PyMultiHostUrl {
220
243
#[ new]
221
- pub fn py_new ( py : Python , url : & PyAny ) -> PyResult < Self > {
222
- let schema_obj = SCHEMA_DEFINITION_MULTI_HOST_URL
223
- . get_or_init ( py, || build_schema_validator ( py, "multi-host-url" ) )
244
+ pub fn py_new ( py : Python , url : & PyAny , omit_trailing_slash : Option < bool > ) -> PyResult < Self > {
245
+ let omit = omit_trailing_slash. unwrap_or ( false ) ;
246
+ let schema = if omit {
247
+ & SCHEMA_DEFINITION_MULTI_HOST_OMIT_SLASH_URL
248
+ } else {
249
+ & SCHEMA_DEFINITION_MULTI_HOST_URL
250
+ } ;
251
+ let schema_obj = schema
252
+ . get_or_init ( py, || build_schema_validator ( py, "multi-host-url" , omit) )
224
253
. validate_python ( py, url, None , None , None , None ) ?;
225
254
schema_obj. extract ( py)
226
255
}
@@ -281,8 +310,9 @@ impl PyMultiHostUrl {
281
310
let hosts = extra_urls
282
311
. iter ( )
283
312
. map ( |url| {
284
- let str = unicode_url ( url) ;
285
- str[ host_offset..str. len ( ) - sub] . to_string ( )
313
+ let str = unicode_url ( url, self . omit_trailing_slash ) ;
314
+ let _sub = if self . omit_trailing_slash { 0 } else { sub } ;
315
+ str[ host_offset..str. len ( ) - _sub] . to_string ( )
286
316
} )
287
317
. collect :: < Vec < String > > ( )
288
318
. join ( "," ) ;
@@ -298,7 +328,7 @@ impl PyMultiHostUrl {
298
328
let schema = self . ref_url . lib_url . scheme ( ) ;
299
329
let host_offset = schema. len ( ) + 3 ;
300
330
301
- let mut full_url = self . ref_url . lib_url . to_string ( ) ;
331
+ let mut full_url = self . ref_url . __str__ ( ) ;
302
332
full_url. insert ( host_offset, ',' ) ;
303
333
304
334
// special urls will have had a trailing slash added, non-special urls will not
@@ -356,7 +386,7 @@ impl PyMultiHostUrl {
356
386
}
357
387
358
388
#[ classmethod]
359
- #[ pyo3( signature= ( * , scheme, hosts= None , path= None , query= None , fragment= None , host= None , username= None , password= None , port= None ) ) ]
389
+ #[ pyo3( signature = ( * , scheme, hosts = None , path = None , query = None , fragment = None , host = None , username = None , password = None , port = None ) ) ]
360
390
#[ allow( clippy:: too_many_arguments) ]
361
391
pub fn build < ' a > (
362
392
cls : & ' a PyType ,
@@ -480,19 +510,34 @@ fn host_to_dict<'a>(py: Python<'a>, lib_url: &Url) -> PyResult<&'a PyDict> {
480
510
Ok ( dict)
481
511
}
482
512
483
- fn unicode_url ( lib_url : & Url ) -> String {
513
+ fn unicode_url ( lib_url : & Url , omit_trailing_slash : bool ) -> String {
484
514
let mut s = lib_url. to_string ( ) ;
485
515
486
516
match lib_url. host ( ) {
487
517
Some ( url:: Host :: Domain ( domain) ) if is_punnycode_domain ( lib_url, domain) => {
488
518
if let Some ( decoded) = decode_punycode ( domain) {
489
519
// replace the range containing the punycode domain with the decoded domain
490
- let start = lib_url. scheme ( ) . len ( ) + 3 ;
520
+ let before_path = before_path_length ( lib_url) ;
521
+ let start = before_path
522
+ - domain. len ( )
523
+ - match lib_url. port ( ) {
524
+ Some ( port) => 1 + port. to_string ( ) . len ( ) ,
525
+ None => 0 ,
526
+ } ;
527
+ if omit_trailing_slash && lib_url. path ( ) == "/" {
528
+ s. replace_range ( before_path..=before_path, "" ) ;
529
+ }
491
530
s. replace_range ( start..start + domain. len ( ) , & decoded) ;
492
531
}
493
532
s
494
533
}
495
- _ => s,
534
+ _ => {
535
+ if omit_trailing_slash && lib_url. path ( ) == "/" {
536
+ let before_path = before_path_length ( lib_url) ;
537
+ s. replace_range ( before_path..=before_path, "" ) ;
538
+ }
539
+ s
540
+ }
496
541
}
497
542
}
498
543
@@ -520,3 +565,21 @@ fn is_punnycode_domain(lib_url: &Url, domain: &str) -> bool {
520
565
/// Returns `true` when `schema` is one of `http`, `https`, `ws`, `wss`, `ftp` or `file`.
///
/// The comparison is exact and case-sensitive.
pub fn schema_is_special(schema: &str) -> bool {
    const SPECIAL_SCHEMES: [&str; 6] = ["http", "https", "ws", "wss", "ftp", "file"];
    SPECIAL_SCHEMES.contains(&schema)
}
568
+
569
+ fn before_path_length ( url : & Url ) -> usize {
570
+ let length = url. scheme ( ) . len ( )
571
+ + 3 // :// part
572
+ + match url. username ( ) {
573
+ "" => 0 ,
574
+ // for colon (:) and at (@) signs we're adding +2
575
+ username => 2 + username. len ( ) + url. password ( ) . unwrap_or ( "" ) . len ( ) ,
576
+ }
577
+ + url. host_str ( ) . unwrap ( ) . len ( )
578
+ + match url. port ( ) {
579
+ // for colon (:) +1
580
+ Some ( port) => 1 + port. to_string ( ) . len ( ) ,
581
+ None => 0 ,
582
+ } ;
583
+
584
+ length
585
+ }
0 commit comments