Skip to content

Commit 1e27dbb

Browse files
committed
Allow for generic auth headers with pystac + validators
collection or collection_id Add a test to ensure auth headers arg works Move function around Fix mypy and suppress type checks for mocked test Address review comments Fix mypy
1 parent 5e0dddc commit 1e27dbb

File tree

6 files changed

+281
-9
lines changed

6 files changed

+281
-9
lines changed

noxfile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def mypy(session: Session) -> None:
151151
"""Type-check using mypy."""
152152
args = session.posargs or ["src", "tests", "docs/conf.py"]
153153
session.install(".")
154-
session.install("mypy", "pytest", "types-requests", "types-PyYAML")
154+
session.install("mypy", "pytest", "types-requests", "types-PyYAML", "typeguard")
155155
session.run("mypy", *args)
156156
if not session.posargs:
157157
session.run("mypy", f"--python-executable={sys.executable}", "noxfile.py")
@@ -161,7 +161,7 @@ def mypy(session: Session) -> None:
161161
def tests(session: Session) -> None:
162162
"""Run the test suite."""
163163
session.install(".")
164-
session.install("coverage[toml]", "pytest", "pygments")
164+
session.install("coverage[toml]", "pytest", "pygments", "typeguard")
165165
try:
166166
session.run("coverage", "run", "--parallel", "-m", "pytest", *session.posargs)
167167
finally:

src/stac_api_validator/__main__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@
130130
"--transaction-collection",
131131
help="The name of the collection to use for Transaction Extension tests.",
132132
)
133+
@click.option(
134+
"-H",
135+
"--headers",
136+
multiple=True,
137+
help="Headers to attach to the main request and dependent pystac requests, curl syntax",
138+
)
133139
def main(
134140
log_level: str,
135141
root_url: str,
@@ -154,11 +160,21 @@ def main(
154160
query_in_field: Optional[str] = None,
155161
query_in_values: Optional[str] = None,
156162
transaction_collection: Optional[str] = None,
163+
headers: Optional[List[str]] = None,
157164
) -> int:
158165
"""STAC API Validator."""
159166
logging.basicConfig(stream=sys.stdout, level=log_level)
160167

161168
try:
169+
processed_headers = {}
170+
if headers:
171+
processed_headers.update(
172+
{
173+
key.strip(): value.strip()
174+
for key, value in (header.split(":") for header in headers)
175+
}
176+
)
177+
162178
(warnings, errors) = validate_api(
163179
root_url=root_url,
164180
ccs_to_validate=conformance_classes,
@@ -184,6 +200,7 @@ def main(
184200
query_in_values,
185201
),
186202
transaction_collection=transaction_collection,
203+
headers=processed_headers,
187204
)
188205
except Exception as e:
189206
click.secho(

src/stac_api_validator/validations.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pystac import Collection
2727
from pystac import Item
2828
from pystac import ItemCollection
29+
from pystac import StacIO
2930
from pystac import STACValidationError
3031
from pystac_client import Client
3132
from requests import Request
@@ -297,6 +298,16 @@ def is_geojson_type(maybe_type: Optional[str]) -> bool:
297298
)
298299

299300

301+
def get_catalog(data_dict: Dict[str, Any], r_session: Session) -> Catalog:
302+
stac_io = StacIO.default()
303+
if r_session.headers and r_session.headers.get("Authorization"):
304+
stac_io.headers = r_session.headers # noqa, type: ignore
305+
stac_io.headers["Accept-Encoding"] = "*"
306+
catalog = Catalog.from_dict(data_dict)
307+
catalog._stac_io = stac_io
308+
return catalog
309+
310+
300311
# def is_json_or_geojson_type(maybe_type: Optional[str]) -> bool:
301312
# return maybe_type and (is_json_type(maybe_type) or is_geojson_type(maybe_type))
302313

@@ -380,9 +391,8 @@ def retrieve(
380391
additional: Optional[str] = "",
381392
content_type: Optional[str] = None,
382393
) -> Tuple[int, Optional[Dict[str, Any]], Optional[Mapping[str, str]]]:
383-
resp = r_session.send(
384-
Request(method.value, url, headers=headers, params=params, json=body).prepare()
385-
)
394+
request = Request(method.value, url, headers=headers, params=params, json=body)
395+
resp = r_session.send(r_session.prepare_request(request))
386396

387397
# todo: handle connection exception, etc.
388398
# todo: handle timeout
@@ -533,6 +543,7 @@ def validate_api(
533543
validate_pagination: bool,
534544
query_config: QueryConfig,
535545
transaction_collection: Optional[str],
546+
headers: Optional[Dict[str, str]],
536547
) -> Tuple[Warnings, Errors]:
537548
warnings = Warnings()
538549
errors = Errors()
@@ -544,6 +555,9 @@ def validate_api(
544555
if auth_query_parameter and (xs := auth_query_parameter.split("=", 1)):
545556
r_session.params = {xs[0]: xs[1]}
546557

558+
if headers:
559+
r_session.headers.update(headers)
560+
547561
_, landing_page_body, landing_page_headers = retrieve(
548562
Method.GET, root_url, errors, Context.CORE, r_session
549563
)
@@ -700,7 +714,7 @@ def validate_api(
700714

701715
if not errors:
702716
try:
703-
catalog = Client.open(root_url)
717+
catalog = Client.open(root_url, headers=headers)
704718
catalog.validate()
705719
for child in catalog.get_children():
706720
child.validate()
@@ -807,7 +821,8 @@ def validate_core(
807821
# this validates, among other things, that the child and item link relations reference
808822
# valid STAC Catalogs, Collections, and/or Items
809823
try:
810-
list(take(1000, Catalog.from_dict(root_body).get_all_items()))
824+
catalog = get_catalog(root_body, r_session)
825+
list(take(1000, catalog.get_all_items()))
811826
except pystac.errors.STACTypeError as e:
812827
errors += (
813828
f"[{Context.CORE}] Error while traversing Catalog child/item links to find Items: {e} "
@@ -835,14 +850,15 @@ def validate_browseable(
835850
# check that at least a few of the items that can be reached from child/item link relations
836851
# can be found through search
837852
try:
838-
for item in take(10, Catalog.from_dict(root_body).get_all_items()):
853+
catalog = get_catalog(root_body, r_session)
854+
for item in take(10, catalog.get_all_items()):
839855
if link := link_by_rel(root_body.get("links"), "search"):
840856
_, body, _ = retrieve(
841857
Method.GET,
842858
link["href"],
843859
errors,
844860
Context.BROWSEABLE,
845-
params={"ids": item.id, "collections": item.collection},
861+
params={"ids": item.id, "collections": item.collection_id},
846862
r_session=r_session,
847863
)
848864
if body and len(body.get("features", [])) != 1:

tests/resources/sample-item.json

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
{
2+
"type": "Feature",
3+
"stac_version": "1.0.0",
4+
"id": "CS3-20160503_132131_05",
5+
"properties": {
6+
"datetime": "2016-05-03T13:22:30.040000Z",
7+
"title": "A CS3 item",
8+
"license": "PDDL-1.0",
9+
"providers": [
10+
{
11+
"name": "CoolSat",
12+
"roles": ["producer", "licensor"],
13+
"url": "https://cool-sat.com/"
14+
}
15+
]
16+
},
17+
"geometry": {
18+
"type": "Polygon",
19+
"coordinates": [
20+
[
21+
[-122.308150179, 37.488035566],
22+
[-122.597502109, 37.538869539],
23+
[-122.576687533, 37.613537207],
24+
[-122.2880486, 37.562818007],
25+
[-122.308150179, 37.488035566]
26+
]
27+
]
28+
},
29+
"links": [
30+
{
31+
"rel": "collection",
32+
"href": "https://raw.githubusercontent.com/radiantearth/stac-spec/v0.8.1/collection-spec/examples/sentinel2.json"
33+
}
34+
],
35+
"assets": {
36+
"analytic": {
37+
"href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/analytic.tif",
38+
"title": "4-Band Analytic",
39+
"product": "http://cool-sat.com/catalog/products/analytic.json",
40+
"type": "image/tiff; application=geotiff; profile=cloud-optimized",
41+
"roles": ["data", "analytic"]
42+
},
43+
"thumbnail": {
44+
"href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/thumbnail.png",
45+
"title": "Thumbnail",
46+
"type": "image/png",
47+
"roles": ["thumbnail"]
48+
}
49+
},
50+
"bbox": [-122.59750209, 37.48803556, -122.2880486, 37.613537207],
51+
"stac_extensions": [],
52+
"collection": "CS3"
53+
}

tests/test_main.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Test cases for the __main__ module."""
2+
import unittest.mock
3+
24
import pytest
35
from click.testing import CliRunner
46

@@ -14,3 +16,38 @@ def runner() -> CliRunner:
1416
def test_main_fails(runner: CliRunner) -> None:
1517
result = runner.invoke(__main__.main)
1618
assert result.exit_code == 2
19+
20+
21+
def test_retrieve_called_with_auth_headers(
22+
request: pytest.FixtureRequest, runner: CliRunner
23+
) -> None:
24+
if request.config.getoption("typeguard_packages"):
25+
pytest.skip(
26+
"The import hook that typeguard uses seems to break the mock below."
27+
)
28+
29+
expected_headers = {
30+
"User-Agent": "python-requests/2.28.2",
31+
"Accept-Encoding": "gzip, deflate",
32+
"Accept": "*/*",
33+
"Connection": "keep-alive",
34+
"Authorization": "api-key fake-api-key-value",
35+
}
36+
37+
with unittest.mock.patch(
38+
"stac_api_validator.validations.retrieve"
39+
) as retrieve_mock:
40+
runner.invoke(
41+
__main__.main,
42+
args=[
43+
"--root-url",
44+
"https://invalid",
45+
"--conformance",
46+
"core",
47+
"-H",
48+
"Authorization: api-key fake-api-key-value",
49+
],
50+
)
51+
assert retrieve_mock.call_count == 1
52+
r_session = retrieve_mock.call_args.args[-1]
53+
assert r_session.headers == expected_headers

0 commit comments

Comments
 (0)