Skip to content

Commit 4947030

Browse files
committed
feat!: configurable key to ID conversion
1 parent 1708fcf commit 4947030

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

graphene_sqlalchemy/converter.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ def set_non_null_many_relationships(non_null_flag):
9393
use_non_null_many_relationships = non_null_flag
9494

9595

96+
use_id_type_for_keys = True
97+
98+
99+
def set_id_for_keys(id_flag):
100+
global use_id_type_for_keys
101+
use_id_type_for_keys = id_flag
102+
103+
96104
def get_column_doc(column):
97105
return getattr(column, "doc", None)
98106

@@ -259,18 +267,34 @@ def inner(fn):
259267
convert_sqlalchemy_composite.register = _register_composite_class
260268

261269

270+
def _is_primary_or_foreign_key(column):
271+
return getattr(column, "primary_key", False) or (
272+
len(getattr(column, "foreign_keys", [])) > 0
273+
)
274+
275+
262276
def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs):
263277
column = column_prop.columns[0]
264-
# The converter expects a type to find the right conversion function.
265-
# If we get an instance instead, we need to convert it to a type.
266-
# The conversion function will still be able to access the instance via the column argument.
278+
# We only use the converter if no type was specified using the ORMField
267279
if "type_" not in field_kwargs:
268-
column_type = getattr(column, "type", None)
269-
if not isinstance(column_type, type):
270-
column_type = type(column_type)
280+
# If the column is a primary key, we use the ID typ
281+
if use_id_type_for_keys and _is_primary_or_foreign_key(column):
282+
field_type = graphene.ID
283+
else:
284+
# The converter expects a type to find the right conversion function.
285+
# If we get an instance instead, we need to convert it to a type.
286+
# The conversion function will still be able to access the instance via the column argument.
287+
column_type = getattr(column, "type", None)
288+
if not isinstance(column_type, type):
289+
column_type = type(column_type)
290+
291+
field_type = convert_sqlalchemy_type(
292+
column_type, column=column, registry=registry
293+
)
294+
271295
field_kwargs.setdefault(
272296
"type_",
273-
convert_sqlalchemy_type(column_type, column=column, registry=registry),
297+
field_type,
274298
)
275299
field_kwargs.setdefault("required", not is_column_nullable(column))
276300
field_kwargs.setdefault("description", get_column_doc(column))
@@ -385,10 +409,6 @@ def convert_column_to_int_or_id(
385409
registry: Registry = None,
386410
**kwargs,
387411
):
388-
# fixme drop the primary key processing from here in another pr
389-
if column is not None:
390-
if getattr(column, "primary_key", False) is True:
391-
return graphene.ID
392412
return graphene.Int
393413

394414

graphene_sqlalchemy/tests/test_converter.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
import sqlalchemy
77
import sqlalchemy_utils as sqa_utils
8-
from sqlalchemy import Column, func, select, types
8+
from sqlalchemy import Column, ForeignKey, func, select, types
99
from sqlalchemy.dialects import postgresql
1010
from sqlalchemy.ext.declarative import declarative_base
1111
from sqlalchemy.ext.hybrid import hybrid_property
@@ -42,11 +42,13 @@ def mock_resolver():
4242
pass
4343

4444

45-
def get_field(sqlalchemy_type, **column_kwargs):
45+
def get_field(sqlalchemy_type, *column_args, **column_kwargs):
4646
class Model(declarative_base()):
4747
__tablename__ = "model"
4848
id_ = Column(types.Integer, primary_key=True)
49-
column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs)
49+
column = Column(
50+
sqlalchemy_type, *column_args, doc="Custom Help Text", **column_kwargs
51+
)
5052

5153
column_prop = inspect(Model).column_attrs["column"]
5254
return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)
@@ -381,12 +383,28 @@ def test_should_integer_convert_int():
381383
assert get_field(types.Integer()).type == graphene.Int
382384

383385

384-
def test_should_primary_integer_convert_id():
386+
def test_should_key_integer_convert_id():
385387
assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull(
386388
graphene.ID
387389
)
388390

389391

392+
def test_should_primary_string_convert_id():
393+
assert get_field(types.String(), primary_key=True).type == graphene.NonNull(
394+
graphene.ID
395+
)
396+
397+
398+
def test_should_primary_uuid_convert_id():
399+
assert get_field(sqa_utils.UUIDType, primary_key=True).type == graphene.NonNull(
400+
graphene.ID
401+
)
402+
403+
404+
def test_should_foreign_key_convert_id():
405+
assert get_field(types.Integer(), ForeignKey("model.id_")).type == graphene.ID
406+
407+
390408
def test_should_boolean_convert_boolean():
391409
assert get_field(types.Boolean()).type == graphene.Boolean
392410

0 commit comments

Comments
 (0)