Skip to content

Raise AttributeError on attempts to access unset oneof fields #510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 22.3.0
rev: 23.1.0
hooks:
- id: black
args: ["--target-version", "py310"]

- repo: https://github.com/PyCQA/doc8
rev: 0.10.1
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ packages = [

[tool.poetry.dependencies]
python = "^3.7"
black = { version = ">=19.3b0", optional = true }
black = { version = ">=23.1.0", optional = true }
grpclib = "^0.4.1"
importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true }
Expand Down Expand Up @@ -62,7 +62,7 @@ cmd = "mypy src --ignore-missing-imports"
help = "Check types with mypy"

[tool.poe.tasks.format]
cmd = "black . --exclude tests/output_"
cmd = "black . --exclude tests/output_ --target-version py310"
help = "Apply black formatting to source code"

[tool.poe.tasks.docs]
Expand Down
46 changes: 37 additions & 9 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,28 @@ def __repr__(self) -> str:
def __getattribute__(self, name: str) -> Any:
"""
Lazily initialize default values to avoid infinite recursion for recursive
message types
message types.
Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields.
"""
try:
group_current = super().__getattribute__("_group_current")
except AttributeError:
pass
else:
if name not in {"__class__", "_betterproto"}:
group = self._betterproto.oneof_group_by_field.get(name)
if group is not None and group_current[group] != name:
if sys.version_info < (3, 10):
raise AttributeError(
f"{group!r} is set to {group_current[group]!r}, not {name!r}"
)
else:
raise AttributeError(
f"{group!r} is set to {group_current[group]!r}, not {name!r}",
name=name,
obj=self,
)

value = super().__getattribute__(name)
if value is not PLACEHOLDER:
return value
Expand Down Expand Up @@ -761,7 +781,10 @@ def __bytes__(self) -> bytes:
"""
output = bytearray()
for field_name, meta in self._betterproto.meta_by_field_name.items():
value = getattr(self, field_name)
try:
value = getattr(self, field_name)
except AttributeError:
continue

if value is None:
# Optional items should be skipped. This is used for the Google
Expand All @@ -775,9 +798,7 @@ def __bytes__(self) -> bytes:
# Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value.
selected_in_group = (
meta.group and self._group_current[meta.group] == field_name
)
selected_in_group = bool(meta.group)

# Empty messages can still be sent on the wire if they were
# set (or received empty).
Expand Down Expand Up @@ -1016,7 +1037,12 @@ def parse(self: T, data: bytes) -> T:
parsed.wire_type, meta, field_name, parsed.value
)

current = getattr(self, field_name)
try:
current = getattr(self, field_name)
except AttributeError:
current = self._get_field_default(field_name)
setattr(self, field_name, current)

if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
Expand Down Expand Up @@ -1077,7 +1103,10 @@ def to_dict(
defaults = self._betterproto.default_gen
for field_name, meta in self._betterproto.meta_by_field_name.items():
field_is_repeated = defaults[field_name] is list
value = getattr(self, field_name)
try:
value = getattr(self, field_name)
except AttributeError:
value = self._get_field_default(field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
Expand Down Expand Up @@ -1209,7 +1238,7 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T:

if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = getattr(self, field_name)
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
Expand Down Expand Up @@ -1486,7 +1515,6 @@ def _validate_field_groups(cls, values):
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore

for group, field_set in group_to_one_ofs.items():

if len(field_set) == 1:
(field,) = field_set
field_name = field.name
Expand Down
1 change: 0 additions & 1 deletion src/betterproto/grpc/grpclib_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ async def _call_rpc_handler_server_stream(
stream: grpclib.server.Stream,
request: Any,
) -> None:

response_iter = handler(request)
# check if response is actually an AsyncIterator
# this might be false if the method just returns without
Expand Down
1 change: 0 additions & 1 deletion src/betterproto/plugin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@


def outputfile_compiler(output_file: OutputTemplate) -> str:

templates_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "templates")
)
Expand Down
1 change: 0 additions & 1 deletion src/betterproto/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def _make_one_of_field_compiler(
proto_obj: "FieldDescriptorProto",
path: List[int],
) -> FieldCompiler:

pydantic = output_package.pydantic_dataclasses
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
return Cls(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof():

# None of these fields were explicitly set BUT they should not actually be null
# themselves
assert isinstance(message.foo, Foo)
assert isinstance(message2.foo, Foo)
assert not hasattr(message, "foo")
assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER
assert not hasattr(message2, "foo")
assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER

assert isinstance(message_reference.foo, ReferenceFoo)
assert isinstance(message_reference2.foo, ReferenceFoo)
Expand Down
13 changes: 6 additions & 7 deletions tests/inputs/oneof_enum/test_oneof_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value():
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
)

assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
assert not hasattr(message, "move")
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)

Expand All @@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value():
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
)
assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
assert not hasattr(message, "move")
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
assert message.signal == Signal.RESIGN
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)

Expand All @@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS
assert not hasattr(message, "signal")
assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))
46 changes: 46 additions & 0 deletions tests/oneof_pattern_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from dataclasses import dataclass

import pytest

import betterproto


def test_oneof_pattern_matching():
@dataclass
class Sub(betterproto.Message):
val: int = betterproto.int32_field(1)

@dataclass
class Foo(betterproto.Message):
bar: int = betterproto.int32_field(1, group="group1")
baz: str = betterproto.string_field(2, group="group1")
sub: Sub = betterproto.message_field(3, group="group2")
abc: str = betterproto.string_field(4, group="group2")

foo = Foo(baz="test1", abc="test2")

match foo:
case Foo(bar=_):
pytest.fail("Matched 'bar' instead of 'baz'")
case Foo(baz=v):
assert v == "test1"
case _:
pytest.fail("Matched neither 'bar' nor 'baz'")

match foo:
case Foo(sub=_):
pytest.fail("Matched 'sub' instead of 'abc'")
case Foo(abc=v):
assert v == "test2"
case _:
pytest.fail("Matched neither 'sub' nor 'abc'")

foo.sub = Sub(val=1)

match foo:
case Foo(sub=Sub(val=v)):
assert v == 1
case Foo(abc=v):
pytest.fail("Matched 'abc' instead of 'sub'")
case _:
pytest.fail("Matched neither 'sub' nor 'abc'")
22 changes: 18 additions & 4 deletions tests/test_features.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import sys
from copy import (
copy,
deepcopy,
Expand All @@ -18,6 +19,8 @@
Optional,
)

import pytest

import betterproto


Expand Down Expand Up @@ -151,17 +154,18 @@ class Foo(betterproto.Message):
foo.baz = "test"

# Other oneof fields should now be unset
assert foo.bar == 0
assert not hasattr(foo, "bar")
assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(foo, "group1")[0] == "baz"

foo.sub.val = 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a behaviour change?

Copy link
Contributor Author

@a-khabarov a-khabarov Jul 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR makes it impossible to use foo.sub.val = 1 when foo.sub is unset in the group.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a bit of a tradeoff. With this change the users of betterproto can no longer use the foo.sub.val = 1 syntax for fields that are unset in groups, but this also means that there is less risk of them accidentally changing which field is set in a group.

foo.sub = Sub(val=1)
assert betterproto.serialized_on_wire(foo.sub)

foo.abc = "test"

# Group 1 shouldn't be touched, group 2 should have reset
assert foo.sub.val == 0
assert betterproto.serialized_on_wire(foo.sub) is False
assert not hasattr(foo, "sub")
assert object.__getattribute__(foo, "sub") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(foo, "group2")[0] == "abc"

# Zero value should always serialize for one-of
Expand All @@ -176,6 +180,16 @@ class Foo(betterproto.Message):
assert betterproto.which_one_of(foo2, "group2")[0] == ""


@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="pattern matching is only supported in python3.10+",
)
def test_oneof_pattern_matching():
from .oneof_pattern_matching import test_oneof_pattern_matching

test_oneof_pattern_matching()


def test_json_casing():
@dataclass
class CasingTest(betterproto.Message):
Expand Down