Skip to content

Commit b9f420e

Browse files
add helper methods to field extension (#69)
* add helper methods to field extension * add test
1 parent ed9dda1 commit b9f420e

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

stac_pydantic/api/extensions/fields.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Set
1+
from typing import Dict, Optional, Set
22

33
from pydantic import BaseModel
44

@@ -10,3 +10,41 @@ class FieldsExtension(BaseModel):
1010

1111
includes: Optional[Set[str]]
1212
excludes: Optional[Set[str]]
13+
14+
def _get_field_dict(self, fields: Set[str]) -> Dict:
15+
"""Internal method to create a dictionary for advanced include or exclude of pydantic fields on model export
16+
17+
Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude
18+
"""
19+
field_dict = {}
20+
for field in fields:
21+
if "." in field:
22+
parent, key = field.split(".")
23+
if parent not in field_dict:
24+
field_dict[parent] = {key}
25+
else:
26+
field_dict[parent].add(key)
27+
else:
28+
field_dict[field] = ... # type:ignore
29+
return field_dict
30+
31+
@property
32+
def filter(self) -> Dict:
33+
"""
34+
Create dictionary of fields to include/exclude on model export based on the included and excluded fields passed
35+
to the API. The output of this property may be passed to pydantic's serialization methods to include or exclude
36+
certain keys.
37+
38+
Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude
39+
"""
40+
include = set()
41+
# If only include is specified, add fields to the set
42+
if self.includes and not self.excludes:
43+
include = include.union(self.includes)
44+
# If both include + exclude specified, find the difference between sets
45+
elif self.includes and self.excludes:
46+
include = include.union(self.includes) - self.excludes
47+
return {
48+
"include": self._get_field_dict(include),
49+
"exclude": self._get_field_dict(self.excludes),
50+
}

tests/test_api_extensions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from datetime import datetime
2+
3+
import pytest
4+
from shapely.geometry import Polygon
5+
6+
from stac_pydantic import Item
7+
from stac_pydantic.api.extensions.fields import FieldsExtension
8+
9+
10+
def test_fields_filter():
11+
fields = FieldsExtension(
12+
includes={"id", "geometry", "properties.foo"}, excludes={"properties.bar"}
13+
)
14+
15+
item = Item(
16+
id="test-fields-filter",
17+
geometry=Polygon.from_bounds(0, 0, 0, 0),
18+
properties={"datetime": datetime.utcnow(), "foo": "foo", "bar": "bar"},
19+
assets={},
20+
links=[],
21+
bbox=[0, 0, 0, 0],
22+
)
23+
24+
d = item.to_dict(**fields.filter)
25+
assert d.pop("id") == item.id
26+
assert d.pop("geometry") == item.geometry
27+
props = d.pop("properties")
28+
assert props["foo"] == "foo"
29+
30+
assert not props.get("bar")
31+
assert not d

0 commit comments

Comments
 (0)