Skip to content

Commit 2119b33

Browse files
committed
Address review comments
1 parent dfc4285 commit 2119b33

File tree

5 files changed

+214
-23
lines changed

5 files changed

+214
-23
lines changed

src/stac_api_validator/__main__.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,10 @@
6565
help="Query parameter key and value to pass for authorization, e.g., 'key=xyz'.",
6666
)
6767
@click.option(
68-
"--auth-headers",
69-
help="Auth headers to attach to the request, e.g., {'Authorization': 'api-key <api-key>'}",
68+
"-H",
69+
"--headers",
70+
multiple=True,
71+
help="Headers to attach to the main request and dependent pystac requests, curl syntax",
7072
)
7173
def main(
7274
log_level: str,
@@ -76,20 +78,29 @@ def main(
7678
geometry: Optional[str],
7779
auth_bearer_token: Optional[str] = None,
7880
auth_query_parameter: Optional[str] = None,
79-
auth_headers: Optional[str] = None,
81+
headers: Optional[List[str]] = None,
8082
) -> int:
8183
"""STAC API Validator."""
8284
logging.basicConfig(stream=sys.stdout, level=log_level)
8385

8486
try:
87+
processed_headers = {}
88+
if headers:
89+
processed_headers.update(
90+
{
91+
key.strip(): value.strip()
92+
for key, value in (header.split(":") for header in headers)
93+
}
94+
)
95+
8596
(warnings, errors) = validate_api(
8697
root_url=root_url,
8798
ccs_to_validate=conformance_classes,
8899
collection=collection,
89100
geometry=geometry,
90101
auth_bearer_token=auth_bearer_token,
91102
auth_query_parameter=auth_query_parameter,
92-
auth_headers=auth_headers,
103+
headers=processed_headers,
93104
)
94105
except Exception as e:
95106
click.secho(

src/stac_api_validator/validations.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Validations module."""
2-
import ast
32
import itertools
43
import json
54
import logging
@@ -279,7 +278,8 @@ def is_geojson_type(maybe_type: Optional[str]) -> bool:
279278
def get_catalog(data_dict: Dict[str, Any], r_session: Session) -> Catalog:
280279
stac_io = StacIO.default()
281280
if r_session.headers and r_session.headers.get("Authorization"):
282-
stac_io.headers = {"Authorization": str(r_session.headers["Authorization"])}
281+
stac_io.headers = r_session.headers
282+
stac_io.headers["Accept-Encoding"] = "*"
283283
catalog = Catalog.from_dict(data_dict)
284284
catalog._stac_io = stac_io
285285
return catalog
@@ -513,7 +513,7 @@ def validate_api(
513513
geometry: Optional[str],
514514
auth_bearer_token: Optional[str],
515515
auth_query_parameter: Optional[str],
516-
auth_headers: Optional[str],
516+
headers: Optional[Dict[str, str]],
517517
) -> Tuple[Warnings, Errors]:
518518
warnings = Warnings()
519519
errors = Errors()
@@ -525,8 +525,8 @@ def validate_api(
525525
if auth_query_parameter and (xs := auth_query_parameter.split("=", 1)):
526526
r_session.params = {xs[0]: xs[1]}
527527

528-
if auth_headers:
529-
r_session.headers.update(ast.literal_eval(auth_headers))
528+
if headers:
529+
r_session.headers.update(headers)
530530

531531
_, landing_page_body, landing_page_headers = retrieve(
532532
Method.GET, root_url, errors, Context.CORE, r_session
@@ -596,9 +596,6 @@ def validate_api(
596596

597597
if not errors:
598598
try:
599-
headers = {}
600-
if r_session.headers and r_session.headers.get("Authorization"):
601-
headers["Authorization"] = str(r_session.headers["Authorization"])
602599
catalog = Client.open(root_url, headers=headers)
603600
catalog.validate()
604601
for child in catalog.get_children():
@@ -743,11 +740,7 @@ def validate_browseable(
743740
link["href"],
744741
errors,
745742
Context.BROWSEABLE,
746-
params={
747-
"ids": item.id,
748-
"collections": getattr(item, "collection", None)
749-
or getattr(item, "collection_id", None),
750-
},
743+
params={"ids": item.id, "collections": item.collection_id},
751744
r_session=r_session,
752745
)
753746
if body and len(body.get("features", [])) != 1:

tests/resources/sample-item.json

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
{
2+
"type": "Feature",
3+
"stac_version": "1.0.0",
4+
"id": "CS3-20160503_132131_05",
5+
"properties": {
6+
"datetime": "2016-05-03T13:22:30.040000Z",
7+
"title": "A CS3 item",
8+
"license": "PDDL-1.0",
9+
"providers": [
10+
{
11+
"name": "CoolSat",
12+
"roles": ["producer", "licensor"],
13+
"url": "https://cool-sat.com/"
14+
}
15+
]
16+
},
17+
"geometry": {
18+
"type": "Polygon",
19+
"coordinates": [
20+
[
21+
[-122.308150179, 37.488035566],
22+
[-122.597502109, 37.538869539],
23+
[-122.576687533, 37.613537207],
24+
[-122.2880486, 37.562818007],
25+
[-122.308150179, 37.488035566]
26+
]
27+
]
28+
},
29+
"links": [
30+
{
31+
"rel": "collection",
32+
"href": "https://raw.githubusercontent.com/radiantearth/stac-spec/v0.8.1/collection-spec/examples/sentinel2.json"
33+
}
34+
],
35+
"assets": {
36+
"analytic": {
37+
"href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/analytic.tif",
38+
"title": "4-Band Analytic",
39+
"product": "http://cool-sat.com/catalog/products/analytic.json",
40+
"type": "image/tiff; application=geotiff; profile=cloud-optimized",
41+
"roles": ["data", "analytic"]
42+
},
43+
"thumbnail": {
44+
"href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/thumbnail.png",
45+
"title": "Thumbnail",
46+
"type": "image/png",
47+
"roles": ["thumbnail"]
48+
}
49+
},
50+
"bbox": [-122.59750209, 37.48803556, -122.2880486, 37.613537207],
51+
"stac_extensions": [],
52+
"collection": "CS3"
53+
}

tests/test_main.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@ def test_retrieve_called_with_auth_headers(
2626
"The import hook that typeguard uses seems to break the mock below."
2727
)
2828

29-
expected_auth_header = {"Authorization": "api-key fake-api-key-value"}
29+
expected_headers = {
30+
"User-Agent": "python-requests/2.28.2",
31+
"Accept-Encoding": "gzip, deflate",
32+
"Accept": "*/*",
33+
"Connection": "keep-alive",
34+
"Authorization": "api-key fake-api-key-value",
35+
}
36+
3037
with unittest.mock.patch(
3138
"stac_api_validator.validations.retrieve"
3239
) as retrieve_mock:
@@ -37,12 +44,10 @@ def test_retrieve_called_with_auth_headers(
3744
"https://invalid",
3845
"--conformance",
3946
"core",
40-
"--auth-headers",
41-
f"{expected_auth_header}",
47+
"-H",
48+
"Authorization: api-key fake-api-key-value",
4249
],
4350
)
4451
assert retrieve_mock.call_count == 1
4552
r_session = retrieve_mock.call_args.args[-1]
46-
assert (
47-
r_session.headers["Authorization"] == expected_auth_header["Authorization"]
48-
)
53+
assert r_session.headers == expected_headers

tests/test_validations.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""
2+
Test cases for the 'validations' module
3+
"""
4+
import json
5+
import os
6+
import pathlib
7+
import unittest.mock
8+
from copy import copy
9+
10+
import pystac
11+
import pytest
12+
import requests
13+
14+
from stac_api_validator import validations
15+
16+
17+
@pytest.fixture
18+
def r_session():
19+
yield requests.Session()
20+
21+
22+
@pytest.fixture
23+
def catalog_dict():
24+
current_path = pathlib.Path(os.path.dirname(os.path.abspath(__file__)))
25+
26+
with open(current_path / "resources" / "landing_page.json") as f:
27+
# Load the contents of the file into a Python dictionary
28+
data = json.load(f)
29+
30+
yield data
31+
32+
33+
@pytest.fixture
34+
def sample_item():
35+
current_path = pathlib.Path(os.path.dirname(os.path.abspath(__file__)))
36+
37+
with open(current_path / "resources" / "sample-item.json") as f:
38+
# Load the contents of the file into a Python dictionary
39+
data = json.load(f)
40+
41+
yield pystac.Item.from_dict(data)
42+
43+
44+
@pytest.fixture
45+
def expected_headers():
46+
yield {
47+
"User-Agent": "python-requests/2.28.2",
48+
"Accept-Encoding": "gzip, deflate",
49+
"Accept": "*/*",
50+
"Connection": "keep-alive",
51+
"Authorization": "api-key fake-api-key-value",
52+
}
53+
54+
55+
def test_get_catalog(r_session, catalog_dict, expected_headers):
56+
r_session.headers = copy(expected_headers)
57+
expected_headers.update({"Accept-Encoding": "*"})
58+
59+
catalog = validations.get_catalog(catalog_dict, r_session)
60+
assert catalog._stac_io.headers == expected_headers
61+
62+
63+
def test_retrieve(r_session, expected_headers):
64+
headers = {"Authorization": "api-key fake-api-key-value"}
65+
r_session.send = unittest.mock.MagicMock()
66+
r_session.send.status_code = 500
67+
68+
validations.retrieve(
69+
validations.Method.GET,
70+
"https://invalid",
71+
validations.Errors(),
72+
validations.Context.CORE,
73+
r_session=r_session,
74+
headers=headers,
75+
)
76+
assert r_session.send.call_count == 1
77+
prepared_request_headers = r_session.send.call_args_list[0].args[0].headers
78+
assert prepared_request_headers == expected_headers
79+
80+
81+
def test_validate_api(request, r_session, expected_headers):
82+
if request.config.getoption("typeguard_packages"):
83+
pytest.skip(
84+
"The import hook that typeguard uses seems to break the mock below."
85+
)
86+
headers = {"Authorization": "api-key fake-api-key-value"}
87+
88+
with unittest.mock.patch(
89+
"stac_api_validator.validations.retrieve"
90+
) as retrieve_mock:
91+
retrieve_mock.return_value = None, None, None
92+
validations.validate_api(
93+
"https://invalid",
94+
ccs_to_validate=["core"],
95+
collection=None,
96+
geometry=None,
97+
auth_bearer_token=None,
98+
auth_query_parameter=None,
99+
headers=headers,
100+
)
101+
assert retrieve_mock.call_count == 1
102+
r_session = retrieve_mock.call_args.args[-1]
103+
assert r_session.headers == expected_headers
104+
105+
106+
def test_validate_browseable(
107+
request, r_session, catalog_dict, sample_item, expected_headers
108+
):
109+
if request.config.getoption("typeguard_packages"):
110+
pytest.skip(
111+
"The import hook that typeguard uses seems to break the mock below."
112+
)
113+
114+
r_session.headers = copy(expected_headers)
115+
116+
with unittest.mock.patch(
117+
"stac_api_validator.validations.get_catalog"
118+
) as get_catalog_mock:
119+
get_catalog_mock.get_all_items.return_value = [sample_item]
120+
121+
validations.validate_browseable(
122+
catalog_dict,
123+
errors=validations.Errors(),
124+
warnings=validations.Warnings(),
125+
r_session=r_session,
126+
)
127+
assert get_catalog_mock.call_count == 1
128+
session_from_mock = get_catalog_mock.call_args.args[-1]
129+
assert session_from_mock.headers == expected_headers

0 commit comments

Comments
 (0)