Skip to content

Commit 0b1c365

Browse files
This commit fixes two issues with content filtering
-- `ZipContentFilter` logic for patterns that have a single allowed pattern. Earlier this would include all files. -- `generate_content_hash` also considers the patterns before generating the hash so that the hash generated matches the content of the zip.
1 parent 4f77bfc commit 0b1c365

File tree

2 files changed

+101
-10
lines changed

2 files changed

+101
-10
lines changed

package.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,20 @@ def emit_dir_content(base_dir):
235235
yield os.path.normpath(os.path.join(root, name))
236236

237237

238-
def generate_content_hash(source_paths, hash_func=hashlib.sha256, log=None):
238+
def generate_content_hash(
239+
source_paths,
240+
content_filter,
241+
hash_func=hashlib.sha256,
242+
log=None,
243+
):
239244
"""
240245
Generate a content hash of the source paths.
246+
247+
:param content_filter: Callable[[str], Iterable[str]
248+
A function that filters the content of the source paths. Given a path
249+
to a file or directory, it should return an iterable of paths to files
250+
that should be included in the hash. At present we pass in the
251+
ZipContentFilter.filter method for this purpose.
241252
"""
242253

243254
if log:
@@ -248,8 +259,7 @@ def generate_content_hash(source_paths, hash_func=hashlib.sha256, log=None):
248259
for source_path in source_paths:
249260
if os.path.isdir(source_path):
250261
source_dir = source_path
251-
_log = log if log.isEnabledFor(DEBUG3) else None
252-
for source_file in list_files(source_dir, log=_log):
262+
for source_file in content_filter(source_dir):
253263
update_hash(hash_obj, source_dir, source_file)
254264
if log:
255265
log.debug(os.path.join(source_dir, source_file))
@@ -589,10 +599,8 @@ def apply(path):
589599
op, regex = r
590600
neg = op is operator.not_
591601
m = regex.fullmatch(path)
592-
if neg and m:
593-
d = False
594-
elif m:
595-
d = True
602+
m = bool(m)
603+
d = not m if neg else m
596604
if d:
597605
return path
598606

@@ -648,6 +656,7 @@ class BuildPlanManager:
648656
def __init__(self, args, log=None):
649657
self._args = args
650658
self._source_paths = None
659+
self._patterns = []
651660
self._log = log or logging.root
652661

653662
def hash(self, extra_paths):
@@ -660,7 +669,11 @@ def hash(self, extra_paths):
660669
# runtime value, build command, and content of the build paths
661670
# because they can have an effect on the resulting archive.
662671
self._log.debug("Computing content hash on files...")
663-
content_hash = generate_content_hash(content_hash_paths, log=self._log)
672+
content_filter = ZipContentFilter(args=self._args)
673+
content_filter.compile(self._patterns)
674+
content_hash = generate_content_hash(
675+
content_hash_paths, content_filter.filter, log=self._log
676+
)
664677
return content_hash
665678

666679
def plan(self, source_path, query):
@@ -800,7 +813,9 @@ def commands_step(path, commands):
800813
patterns = claim.get("patterns")
801814
commands = claim.get("commands")
802815
if patterns:
803-
step("set:filter", patterns_list(self._args, patterns))
816+
patterns_as_list = patterns_list(self._args, patterns)
817+
self._patterns.extend(patterns_as_list)
818+
step("set:filter", patterns_as_list)
804819
if commands:
805820
commands_step(path, commands)
806821
else:

tests/test_zip_source.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import os
2+
from pathlib import Path
23
from unittest.mock import MagicMock, Mock
34

4-
from package import BuildPlanManager
5+
import pytest
6+
7+
from package import BuildPlanManager, ZipContentFilter, datatree
58

69

710
def test_zip_source_path_sh_work_dir():
@@ -44,3 +47,76 @@ def test_zip_source_path():
4447

4548
zip_source_path = zs.write_dirs.call_args_list[0][0][0]
4649
assert zip_source_path == f"{os.getcwd()}"
50+
51+
52+
@pytest.fixture
53+
def source_path(tmp_path: str) -> Path:
54+
"""Creates a tmp stage dir for running tests."""
55+
tmp_path = Path(tmp_path)
56+
source = tmp_path / "some_dir"
57+
source.mkdir()
58+
for files in ["file.py", "file2.py", "README.md", "requirements.txt"]:
59+
(source / files).touch()
60+
yield source
61+
62+
63+
def test_zip_content_filter(source_path: Path):
64+
"""Test the zip content filter does not take all positive."""
65+
args = Mock()
66+
query_data = {
67+
"runtime": "python",
68+
"source_path": {
69+
"path": str(source_path),
70+
"patterns": [".*.py$"],
71+
},
72+
}
73+
query = datatree("prepare_query", **query_data)
74+
75+
file_filter = ZipContentFilter(args=args)
76+
file_filter.compile(query.source_path.patterns)
77+
filtered = list(file_filter.filter(query.source_path.path))
78+
expected = [str(source_path / fname) for fname in ["file.py", "file2.py"]]
79+
assert filtered == sorted(expected)
80+
81+
# Test that filtering with empty patterns returns all files.
82+
file_filter = ZipContentFilter(args=args)
83+
file_filter.compile([])
84+
filtered = list(file_filter.filter(query.source_path.path))
85+
expected = [
86+
str(source_path / fname)
87+
for fname in
88+
["file.py", "file2.py", "README.md", "requirements.txt"]
89+
]
90+
assert filtered == sorted(expected)
91+
92+
93+
def test_generate_hash(source_path: Path):
94+
"""Tests prepare hash generation and also packaging."""
95+
args = Mock()
96+
97+
query_data = {
98+
"runtime": "python",
99+
"source_path": {
100+
"path": str(source_path),
101+
"patterns": ["!.*", ".*.py$"],
102+
},
103+
}
104+
query = datatree("prepare_query", **query_data)
105+
106+
bpm = BuildPlanManager(args)
107+
bpm.plan(query.source_path, query)
108+
hash1 = bpm.hash([]).hexdigest()
109+
110+
# Add a new file that does not match the pattern.
111+
(source_path / "file3.pyc").touch()
112+
bpm.plan(query.source_path, query)
113+
hash2 = bpm.hash([]).hexdigest()
114+
# Both hashes should still be the same.
115+
assert hash1 == hash2
116+
117+
# Add a new file that does match the pattern.
118+
(source_path / "file4.py").touch()
119+
bpm.plan(query.source_path, query)
120+
hash3 = bpm.hash([]).hexdigest()
121+
# Hash should be different.
122+
assert hash1 != hash3

0 commit comments

Comments
 (0)