Skip to content

Commit 78a08e2

Browse files
committed
pass extra argument in arguments validator
1 parent c7daf16 commit 78a08e2

File tree

3 files changed

+101
-66
lines changed

3 files changed

+101
-66
lines changed

src/build_tools.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use std::error::Error;
22
use std::fmt;
33

4-
use pyo3::exceptions::PyException;
4+
use crate::tools::py_err;
5+
use pyo3::exceptions::{PyException, PyTypeError};
56
use pyo3::prelude::*;
67
use pyo3::types::{PyDict, PyList, PyString};
78
use pyo3::{intern, FromPyObject, PyErrArguments};
@@ -195,3 +196,28 @@ impl ExtraBehavior {
195196
Ok(res)
196197
}
197198
}
199+
200+
impl ToPyObject for ExtraBehavior {
201+
fn to_object(&self, py: Python) -> PyObject {
202+
match self {
203+
ExtraBehavior::Allow => ExtraBehavior::Allow.to_object(py),
204+
ExtraBehavior::Ignore => ExtraBehavior::Ignore.to_object(py),
205+
ExtraBehavior::Forbid => ExtraBehavior::Forbid.to_object(py),
206+
}
207+
}
208+
}
209+
210+
impl FromPyObject<'_> for ExtraBehavior {
211+
fn extract(obj: &PyAny) -> PyResult<Self> {
212+
if let Ok(string) = obj.extract::<String>() {
213+
Ok(match string.as_str() {
214+
"allow" => ExtraBehavior::Allow,
215+
"ignore" => ExtraBehavior::Ignore,
216+
"forbid" => ExtraBehavior::Forbid,
217+
_ => return py_err!(PyTypeError; "Invalid string for ExtraBehavior"),
218+
})
219+
} else {
220+
py_err!(PyTypeError; "Expected string value allow, ignore or forbid, got {}", obj.get_type())
221+
}
222+
}
223+
}

src/validators/arguments.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyTuple};
55
use ahash::AHashSet;
66

77
use crate::build_tools::py_schema_err;
8-
use crate::build_tools::schema_or_config_same;
8+
use crate::build_tools::{schema_or_config_same, ExtraBehavior};
99
use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult};
1010
use crate::input::{GenericArguments, Input, ValidationMatch};
1111
use crate::lookup_key::LookupKey;
@@ -31,6 +31,7 @@ pub struct ArgumentsValidator {
3131
var_args_validator: Option<Box<CombinedValidator>>,
3232
var_kwargs_validator: Option<Box<CombinedValidator>>,
3333
loc_by_alias: bool,
34+
extra: ExtraBehavior,
3435
}
3536

3637
impl BuildValidator for ArgumentsValidator {
@@ -73,7 +74,7 @@ impl BuildValidator for ArgumentsValidator {
7374
}
7475
None => Some(LookupKey::from_string(py, &name)),
7576
};
76-
kwarg_key = Some(PyString::new(py, &name).into());
77+
kwarg_key = Some(PyString::intern(py, &name).into());
7778
}
7879

7980
let schema: &PyAny = arg.get_as_req(intern!(py, "schema"))?;
@@ -119,6 +120,9 @@ impl BuildValidator for ArgumentsValidator {
119120
None => None,
120121
},
121122
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
123+
extra: config
124+
.get_as(intern!(py, "extra_fields_behavior"))?
125+
.unwrap_or(ExtraBehavior::Forbid),
122126
}
123127
.into())
124128
}
@@ -166,7 +170,7 @@ impl Validator for ArgumentsValidator {
166170
py: Python<'data>,
167171
input: &'data impl Input<'data>,
168172
state: &mut ValidationState,
169-
) -> ValResult<PyObject> {
173+
) -> ValResult<'data, PyObject> {
170174
let args = input.validate_args()?;
171175

172176
let mut output_args: Vec<PyObject> = Vec::with_capacity(self.positional_params_count);
@@ -307,15 +311,16 @@ impl Validator for ArgumentsValidator {
307311
Err(err) => return Err(err),
308312
},
309313
None => {
310-
errors.push(ValLineError::new_with_loc(
311-
ErrorTypeDefaults::UnexpectedKeywordArgument,
312-
value,
313-
raw_key.as_loc_item(),
314-
));
314+
if let ExtraBehavior::Forbid = self.extra {
315+
errors.push(ValLineError::new_with_loc(
316+
ErrorTypeDefaults::UnexpectedKeywordArgument,
317+
value,
318+
raw_key.as_loc_item(),
319+
));
320+
}
315321
}
316322
}
317-
}
318-
}
323+
}}
319324
}
320325
}
321326
}};

tests/validators/test_arguments.py

Lines changed: 59 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -775,57 +775,56 @@ def test_alias_populate_by_name(py_and_json: PyAndJson, input_value, expected):
775775
assert v.validate_test(input_value) == expected
776776

777777

778-
def validate(function):
779-
"""
780-
a demo validation decorator to test arguments
781-
"""
782-
parameters = signature(function).parameters
783-
784-
type_hints = get_type_hints(function)
785-
mode_lookup = {
786-
Parameter.POSITIONAL_ONLY: 'positional_only',
787-
Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword',
788-
Parameter.KEYWORD_ONLY: 'keyword_only',
789-
}
790-
791-
arguments_schema = []
792-
schema = {'type': 'arguments', 'arguments_schema': arguments_schema}
793-
for i, (name, p) in enumerate(parameters.items()):
794-
if p.annotation is p.empty:
795-
annotation = Any
796-
else:
797-
annotation = type_hints[name]
798-
799-
assert annotation in (bool, int, float, str, Any), f'schema for {annotation} not implemented'
800-
if annotation in (bool, int, float, str):
801-
arg_schema = {'type': annotation.__name__}
802-
else:
803-
assert annotation is Any
804-
arg_schema = {'type': 'any'}
805-
806-
if p.kind in mode_lookup:
807-
if p.default is not p.empty:
808-
arg_schema = {'type': 'default', 'schema': arg_schema, 'default': p.default}
809-
s = {'name': name, 'mode': mode_lookup[p.kind], 'schema': arg_schema}
810-
arguments_schema.append(s)
811-
elif p.kind == Parameter.VAR_POSITIONAL:
812-
schema['var_args_schema'] = arg_schema
813-
else:
814-
assert p.kind == Parameter.VAR_KEYWORD, p.kind
815-
schema['var_kwargs_schema'] = arg_schema
816-
817-
validator = SchemaValidator(schema)
818-
819-
@wraps(function)
820-
def wrapper(*args, **kwargs):
821-
validated_args, validated_kwargs = validator.validate_python(ArgsKwargs(args, kwargs))
822-
return function(*validated_args, **validated_kwargs)
823-
824-
return wrapper
778+
def validate(config=None):
779+
def decorator(function):
780+
parameters = signature(function).parameters
781+
type_hints = get_type_hints(function)
782+
mode_lookup = {
783+
Parameter.POSITIONAL_ONLY: 'positional_only',
784+
Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword',
785+
Parameter.KEYWORD_ONLY: 'keyword_only',
786+
}
825787

788+
arguments_schema = []
789+
schema = {'type': 'arguments', 'arguments_schema': arguments_schema}
790+
for i, (name, p) in enumerate(parameters.items()):
791+
if p.annotation is p.empty:
792+
annotation = Any
793+
else:
794+
annotation = type_hints[name]
795+
796+
assert annotation in (bool, int, float, str, Any), f'schema for {annotation} not implemented'
797+
if annotation in (bool, int, float, str):
798+
arg_schema = {'type': annotation.__name__}
799+
else:
800+
assert annotation is Any
801+
arg_schema = {'type': 'any'}
802+
803+
if p.kind in mode_lookup:
804+
if p.default is not p.empty:
805+
arg_schema = {'type': 'default', 'schema': arg_schema, 'default': p.default}
806+
s = {'name': name, 'mode': mode_lookup[p.kind], 'schema': arg_schema}
807+
arguments_schema.append(s)
808+
elif p.kind == Parameter.VAR_POSITIONAL:
809+
schema['var_args_schema'] = arg_schema
810+
else:
811+
assert p.kind == Parameter.VAR_KEYWORD, p.kind
812+
schema['var_kwargs_schema'] = arg_schema
813+
814+
validator = SchemaValidator(schema, config=config)
815+
816+
@wraps(function)
817+
def wrapper(*args, **kwargs):
818+
# Validate arguments using the original schema
819+
validated_args, validated_kwargs = validator.validate_python(ArgsKwargs(args, kwargs))
820+
return function(*validated_args, **validated_kwargs)
821+
822+
return wrapper
823+
824+
return decorator
826825

827826
def test_function_any():
828-
@validate
827+
@validate()
829828
def foobar(a, b, c):
830829
return a, b, c
831830

@@ -842,7 +841,7 @@ def foobar(a, b, c):
842841

843842

844843
def test_function_types():
845-
@validate
844+
@validate()
846845
def foobar(a: int, b: int, *, c: int):
847846
return a, b, c
848847

@@ -894,8 +893,8 @@ def test_function_positional_only(import_execute):
894893
# language=Python
895894
m = import_execute(
896895
"""
897-
def create_function(validate):
898-
@validate
896+
def create_function(validate, config = None):
897+
@validate(config = config)
899898
def foobar(a: int, b: int, /, c: int):
900899
return a, b, c
901900
return foobar
@@ -915,15 +914,20 @@ def foobar(a: int, b: int, /, c: int):
915914
},
916915
{'type': 'unexpected_keyword_argument', 'loc': ('b',), 'msg': 'Unexpected keyword argument', 'input': 2},
917916
]
918-
917+
# Allowin extras using the config
918+
foobar = m.create_function(validate, config={'title': 'func', 'extra_fields_behavior': 'allow'})
919+
assert foobar('1', '2', c=3, d=4) == (1, 2, 3)
920+
# Ignore works similar than allow
921+
foobar = m.create_function(validate, config={'title': 'func', 'extra_fields_behavior': 'ignore'})
922+
assert foobar('1', '2', c=3, d=4) == (1, 2, 3)
919923

920924
@pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python3.10 or higher')
921925
def test_function_positional_only_default(import_execute):
922926
# language=Python
923927
m = import_execute(
924928
"""
925929
def create_function(validate):
926-
@validate
930+
@validate()
927931
def foobar(a: int, b: int = 42, /):
928932
return a, b
929933
return foobar
@@ -940,7 +944,7 @@ def test_function_positional_kwargs(import_execute):
940944
m = import_execute(
941945
"""
942946
def create_function(validate):
943-
@validate
947+
@validate()
944948
def foobar(a: int, b: int, /, **kwargs: bool):
945949
return a, b, kwargs
946950
return foobar
@@ -953,7 +957,7 @@ def foobar(a: int, b: int, /, **kwargs: bool):
953957

954958

955959
def test_function_args_kwargs():
956-
@validate
960+
@validate()
957961
def foobar(*args, **kwargs):
958962
return args, kwargs
959963

0 commit comments

Comments
 (0)