Skip to content

Commit 6faac1d

Browse files
authored
Raise AttributeError on attempts to access unset oneof fields (#510)
1 parent 098989e commit 6faac1d

File tree

11 files changed

+116
-29
lines changed

11 files changed

+116
-29
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ repos:
88
- id: isort
99

1010
- repo: https://github.com/psf/black
11-
rev: 22.3.0
11+
rev: 23.1.0
1212
hooks:
1313
- id: black
14+
args: ["--target-version", "py310"]
1415

1516
- repo: https://github.com/PyCQA/doc8
1617
rev: 0.10.1

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ packages = [
1313

1414
[tool.poetry.dependencies]
1515
python = "^3.7"
16-
black = { version = ">=19.3b0", optional = true }
16+
black = { version = ">=23.1.0", optional = true }
1717
grpclib = "^0.4.1"
1818
importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
1919
jinja2 = { version = ">=3.0.3", optional = true }
@@ -62,7 +62,7 @@ cmd = "mypy src --ignore-missing-imports"
6262
help = "Check types with mypy"
6363

6464
[tool.poe.tasks.format]
65-
cmd = "black . --exclude tests/output_"
65+
cmd = "black . --exclude tests/output_ --target-version py310"
6666
help = "Apply black formatting to source code"
6767

6868
[tool.poe.tasks.docs]

src/betterproto/__init__.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -693,8 +693,28 @@ def __repr__(self) -> str:
693693
def __getattribute__(self, name: str) -> Any:
694694
"""
695695
Lazily initialize default values to avoid infinite recursion for recursive
696-
message types
696+
message types.
697+
Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields.
697698
"""
699+
try:
700+
group_current = super().__getattribute__("_group_current")
701+
except AttributeError:
702+
pass
703+
else:
704+
if name not in {"__class__", "_betterproto"}:
705+
group = self._betterproto.oneof_group_by_field.get(name)
706+
if group is not None and group_current[group] != name:
707+
if sys.version_info < (3, 10):
708+
raise AttributeError(
709+
f"{group!r} is set to {group_current[group]!r}, not {name!r}"
710+
)
711+
else:
712+
raise AttributeError(
713+
f"{group!r} is set to {group_current[group]!r}, not {name!r}",
714+
name=name,
715+
obj=self,
716+
)
717+
698718
value = super().__getattribute__(name)
699719
if value is not PLACEHOLDER:
700720
return value
@@ -761,7 +781,10 @@ def __bytes__(self) -> bytes:
761781
"""
762782
output = bytearray()
763783
for field_name, meta in self._betterproto.meta_by_field_name.items():
764-
value = getattr(self, field_name)
784+
try:
785+
value = getattr(self, field_name)
786+
except AttributeError:
787+
continue
765788

766789
if value is None:
767790
# Optional items should be skipped. This is used for the Google
@@ -775,9 +798,7 @@ def __bytes__(self) -> bytes:
775798
# Note that proto3 field presence/optional fields are put in a
776799
# synthetic single-item oneof by protoc, which helps us ensure we
777800
# send the value even if the value is the default zero value.
778-
selected_in_group = (
779-
meta.group and self._group_current[meta.group] == field_name
780-
)
801+
selected_in_group = bool(meta.group)
781802

782803
# Empty messages can still be sent on the wire if they were
783804
# set (or received empty).
@@ -1016,7 +1037,12 @@ def parse(self: T, data: bytes) -> T:
10161037
parsed.wire_type, meta, field_name, parsed.value
10171038
)
10181039

1019-
current = getattr(self, field_name)
1040+
try:
1041+
current = getattr(self, field_name)
1042+
except AttributeError:
1043+
current = self._get_field_default(field_name)
1044+
setattr(self, field_name, current)
1045+
10201046
if meta.proto_type == TYPE_MAP:
10211047
# Value represents a single key/value pair entry in the map.
10221048
current[value.key] = value.value
@@ -1077,7 +1103,10 @@ def to_dict(
10771103
defaults = self._betterproto.default_gen
10781104
for field_name, meta in self._betterproto.meta_by_field_name.items():
10791105
field_is_repeated = defaults[field_name] is list
1080-
value = getattr(self, field_name)
1106+
try:
1107+
value = getattr(self, field_name)
1108+
except AttributeError:
1109+
value = self._get_field_default(field_name)
10811110
cased_name = casing(field_name).rstrip("_") # type: ignore
10821111
if meta.proto_type == TYPE_MESSAGE:
10831112
if isinstance(value, datetime):
@@ -1209,7 +1238,7 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T:
12091238

12101239
if value[key] is not None:
12111240
if meta.proto_type == TYPE_MESSAGE:
1212-
v = getattr(self, field_name)
1241+
v = self._get_field_default(field_name)
12131242
cls = self._betterproto.cls_by_field[field_name]
12141243
if isinstance(v, list):
12151244
if cls == datetime:
@@ -1486,7 +1515,6 @@ def _validate_field_groups(cls, values):
14861515
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
14871516

14881517
for group, field_set in group_to_one_ofs.items():
1489-
14901518
if len(field_set) == 1:
14911519
(field,) = field_set
14921520
field_name = field.name

src/betterproto/grpc/grpclib_server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ async def _call_rpc_handler_server_stream(
2121
stream: grpclib.server.Stream,
2222
request: Any,
2323
) -> None:
24-
2524
response_iter = handler(request)
2625
# check if response is actually an AsyncIterator
2726
# this might be false if the method just returns without

src/betterproto/plugin/compiler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222

2323
def outputfile_compiler(output_file: OutputTemplate) -> str:
24-
2524
templates_folder = os.path.abspath(
2625
os.path.join(os.path.dirname(__file__), "..", "templates")
2726
)

src/betterproto/plugin/parser.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def _make_one_of_field_compiler(
159159
proto_obj: "FieldDescriptorProto",
160160
path: List[int],
161161
) -> FieldCompiler:
162-
163162
pydantic = output_package.pydantic_dataclasses
164163
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
165164
return Cls(

tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof():
5050

5151
# None of these fields were explicitly set BUT they should not actually be null
5252
# themselves
53-
assert isinstance(message.foo, Foo)
54-
assert isinstance(message2.foo, Foo)
53+
assert not hasattr(message, "foo")
54+
assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER
55+
assert not hasattr(message2, "foo")
56+
assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER
5557

5658
assert isinstance(message_reference.foo, ReferenceFoo)
5759
assert isinstance(message_reference2.foo, ReferenceFoo)

tests/inputs/oneof_enum/test_oneof_enum.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value():
1818
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
1919
)
2020

21-
assert message.move == Move(
22-
x=0, y=0
23-
) # Proto3 will default this as there is no null
21+
assert not hasattr(message, "move")
22+
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
2423
assert message.signal == Signal.PASS
2524
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)
2625

@@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value():
3332
message.from_json(
3433
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
3534
)
36-
assert message.move == Move(
37-
x=0, y=0
38-
) # Proto3 will default this as there is no null
35+
assert not hasattr(message, "move")
36+
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
3937
assert message.signal == Signal.RESIGN
4038
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)
4139

@@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set():
4442
message = Test()
4543
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
4644
assert message.move == Move(x=2, y=3)
47-
assert message.signal == Signal.PASS
45+
assert not hasattr(message, "signal")
46+
assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER
4847
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

tests/oneof_pattern_matching.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from dataclasses import dataclass
2+
3+
import pytest
4+
5+
import betterproto
6+
7+
8+
def test_oneof_pattern_matching():
9+
@dataclass
10+
class Sub(betterproto.Message):
11+
val: int = betterproto.int32_field(1)
12+
13+
@dataclass
14+
class Foo(betterproto.Message):
15+
bar: int = betterproto.int32_field(1, group="group1")
16+
baz: str = betterproto.string_field(2, group="group1")
17+
sub: Sub = betterproto.message_field(3, group="group2")
18+
abc: str = betterproto.string_field(4, group="group2")
19+
20+
foo = Foo(baz="test1", abc="test2")
21+
22+
match foo:
23+
case Foo(bar=_):
24+
pytest.fail("Matched 'bar' instead of 'baz'")
25+
case Foo(baz=v):
26+
assert v == "test1"
27+
case _:
28+
pytest.fail("Matched neither 'bar' nor 'baz'")
29+
30+
match foo:
31+
case Foo(sub=_):
32+
pytest.fail("Matched 'sub' instead of 'abc'")
33+
case Foo(abc=v):
34+
assert v == "test2"
35+
case _:
36+
pytest.fail("Matched neither 'sub' nor 'abc'")
37+
38+
foo.sub = Sub(val=1)
39+
40+
match foo:
41+
case Foo(sub=Sub(val=v)):
42+
assert v == 1
43+
case Foo(abc=v):
44+
pytest.fail("Matched 'abc' instead of 'sub'")
45+
case _:
46+
pytest.fail("Matched neither 'sub' nor 'abc'")

tests/test_features.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import sys
23
from copy import (
34
copy,
45
deepcopy,
@@ -18,6 +19,8 @@
1819
Optional,
1920
)
2021

22+
import pytest
23+
2124
import betterproto
2225

2326

@@ -151,17 +154,18 @@ class Foo(betterproto.Message):
151154
foo.baz = "test"
152155

153156
# Other oneof fields should now be unset
154-
assert foo.bar == 0
157+
assert not hasattr(foo, "bar")
158+
assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER
155159
assert betterproto.which_one_of(foo, "group1")[0] == "baz"
156160

157-
foo.sub.val = 1
161+
foo.sub = Sub(val=1)
158162
assert betterproto.serialized_on_wire(foo.sub)
159163

160164
foo.abc = "test"
161165

162166
# Group 1 shouldn't be touched, group 2 should have reset
163-
assert foo.sub.val == 0
164-
assert betterproto.serialized_on_wire(foo.sub) is False
167+
assert not hasattr(foo, "sub")
168+
assert object.__getattribute__(foo, "sub") == betterproto.PLACEHOLDER
165169
assert betterproto.which_one_of(foo, "group2")[0] == "abc"
166170

167171
# Zero value should always serialize for one-of
@@ -176,6 +180,16 @@ class Foo(betterproto.Message):
176180
assert betterproto.which_one_of(foo2, "group2")[0] == ""
177181

178182

183+
@pytest.mark.skipif(
184+
sys.version_info < (3, 10),
185+
reason="pattern matching is only supported in python3.10+",
186+
)
187+
def test_oneof_pattern_matching():
188+
from .oneof_pattern_matching import test_oneof_pattern_matching
189+
190+
test_oneof_pattern_matching()
191+
192+
179193
def test_json_casing():
180194
@dataclass
181195
class CasingTest(betterproto.Message):

0 commit comments

Comments
 (0)