Skip to content

Commit 1ece135

Browse files
author
Diptorup Deb
authored
Merge pull request #1166 from IntelPython/feature/flattened_member_count
Feature/flattened member count
2 parents 1dda097 + c4d2f40 commit 1ece135

File tree

4 files changed

+107
-40
lines changed

4 files changed

+107
-40
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,30 @@
2020
)
2121

2222

23+
def _get_flattened_member_count(ty):
24+
"""Return the number of fields in an instance of a given StructModel."""
25+
flattened_member_count = 0
26+
members = ty._members
27+
for member in members:
28+
if isinstance(member, types.UniTuple):
29+
flattened_member_count += member.count
30+
elif isinstance(
31+
member,
32+
(
33+
types.scalars.Integer,
34+
types.misc.PyObject,
35+
types.misc.RawPointer,
36+
types.misc.CPointer,
37+
types.misc.MemInfoPointer,
38+
),
39+
):
40+
flattened_member_count += 1
41+
else:
42+
raise UnreachableError
43+
44+
return flattened_member_count
45+
46+
2347
class GenericPointerModel(PrimitiveModel):
2448
def __init__(self, dmm, fe_type):
2549
adrsp = (
@@ -68,26 +92,7 @@ def __init__(self, dmm, fe_type):
6892
@property
6993
def flattened_field_count(self):
7094
"""Return the number of fields in an instance of a USMArrayModel."""
71-
flattened_member_count = 0
72-
members = self._members
73-
for member in members:
74-
if isinstance(member, types.UniTuple):
75-
flattened_member_count += member.count
76-
elif isinstance(
77-
member,
78-
(
79-
types.scalars.Integer,
80-
types.misc.PyObject,
81-
types.misc.RawPointer,
82-
types.misc.CPointer,
83-
types.misc.MemInfoPointer,
84-
),
85-
):
86-
flattened_member_count += 1
87-
else:
88-
raise UnreachableError
89-
90-
return flattened_member_count
95+
return _get_flattened_member_count(self)
9196

9297

9398
class DpnpNdArrayModel(StructModel):
@@ -121,26 +126,7 @@ def __init__(self, dmm, fe_type):
121126
@property
122127
def flattened_field_count(self):
123128
"""Return the number of fields in an instance of a DpnpNdArrayModel."""
124-
flattened_member_count = 0
125-
members = self._members
126-
for member in members:
127-
if isinstance(member, types.UniTuple):
128-
flattened_member_count += member.count
129-
elif isinstance(
130-
member,
131-
(
132-
types.scalars.Integer,
133-
types.misc.PyObject,
134-
types.misc.RawPointer,
135-
types.misc.CPointer,
136-
types.misc.MemInfoPointer,
137-
),
138-
):
139-
flattened_member_count += 1
140-
else:
141-
raise UnreachableError
142-
143-
return flattened_member_count
129+
return _get_flattened_member_count(self)
144130

145131

146132
class SyclQueueModel(StructModel):
@@ -211,6 +197,11 @@ def __init__(self, dmm, fe_type):
211197
]
212198
super(RangeModel, self).__init__(dmm, fe_type, members)
213199

200+
@property
201+
def flattened_field_count(self):
202+
"""Return the number of fields in an instance of a RangeModel."""
203+
return _get_flattened_member_count(self)
204+
214205

215206
class NdRangeModel(StructModel):
216207
"""The native data model for a
@@ -229,6 +220,11 @@ def __init__(self, dmm, fe_type):
229220
]
230221
super(NdRangeModel, self).__init__(dmm, fe_type, members)
231222

223+
@property
224+
def flattened_field_count(self):
225+
"""Return the number of fields in an instance of a NdRangeModel."""
226+
return _get_flattened_member_count(self)
227+
232228

233229
def _init_data_model_manager() -> datamodel.DataModelManager:
234230
"""Initializes a DpexKernelTarget-specific data model manager.

numba_dpex/tests/core/types/DpnpNdArray/test_models.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
from numba import types
66
from numba.core.datamodel import default_manager, models
7+
from numba.core.registry import cpu_target
78

89
from numba_dpex.core.datamodel.models import (
910
DpnpNdArrayModel,
1011
USMArrayModel,
1112
dpex_data_model_manager,
1213
)
14+
from numba_dpex.core.descriptor import dpex_kernel_target
1315
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray
1416

1517

@@ -31,3 +33,22 @@ def test_dpnp_ndarray_Model():
3133
"""
3234

3335
assert issubclass(DpnpNdArrayModel, models.StructModel)
36+
37+
38+
def test_flattened_member_count():
39+
"""Test that the number of flattened member count matches the number of
40+
flattened args generated by the CpuTarget's ArgPacker.
41+
"""
42+
43+
cputargetctx = cpu_target.target_context
44+
kerneltargetctx = dpex_kernel_target.target_context
45+
dpex_dmm = kerneltargetctx.data_model_manager
46+
47+
for ndim in range(4):
48+
dty = DpnpNdArray(ndim)
49+
argty_tuple = tuple([dty])
50+
datamodel = dpex_dmm.lookup(dty)
51+
num_flattened_args = datamodel.flattened_field_count
52+
ap = cputargetctx.get_arg_packer(argty_tuple)
53+
54+
assert num_flattened_args == len(ap._be_args)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from numba.core.registry import cpu_target
6+
7+
from numba_dpex.core.descriptor import dpex_kernel_target
8+
from numba_dpex.core.types.usm_ndarray_type import USMNdArray
9+
10+
11+
def test_flattened_member_count():
12+
"""Test that the number of flattened member count matches the number of
13+
flattened args generated by the CpuTarget's ArgPacker.
14+
"""
15+
16+
cputargetctx = cpu_target.target_context
17+
kerneltargetctx = dpex_kernel_target.target_context
18+
dpex_dmm = kerneltargetctx.data_model_manager
19+
20+
for ndim in range(4):
21+
dty = USMNdArray(ndim)
22+
argty_tuple = tuple([dty])
23+
datamodel = dpex_dmm.lookup(dty)
24+
num_flattened_args = datamodel.flattened_field_count
25+
ap = cputargetctx.get_arg_packer(argty_tuple)
26+
27+
assert num_flattened_args == len(ap._be_args)

numba_dpex/tests/core/types/range_types/test_data_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44

55
import pytest
66
from numba.core.datamodel import default_manager
7+
from numba.core.registry import cpu_target
78

89
from numba_dpex.core.datamodel.models import (
910
NdRangeModel,
1011
RangeModel,
1112
dpex_data_model_manager,
1213
)
14+
from numba_dpex.core.descriptor import dpex_kernel_target
1315
from numba_dpex.core.types.range_types import NdRangeType, RangeType
1416

1517
rfields = ["ndim", "dim0", "dim1", "dim2"]
1618
ndrfields = ["ndim", "gdim0", "gdim1", "gdim2", "ldim0", "ldim1", "ldim2"]
19+
range_tys = [RangeType, NdRangeType]
1720

1821

1922
def test_datamodel_registration():
@@ -58,3 +61,23 @@ def test_ndrange_model_fields(field):
5861
dm.get_field_position(field)
5962
except:
6063
pytest.fail(f"Expected field {field} not present in NdRangeModel")
64+
65+
66+
@pytest.mark.parametrize("range_type", range_tys)
67+
def test_flattened_member_count(range_type):
68+
"""Test that the number of flattened member count matches the number of
69+
flattened args generated by the CpuTarget's ArgPacker.
70+
"""
71+
72+
cputargetctx = cpu_target.target_context
73+
kerneltargetctx = dpex_kernel_target.target_context
74+
dpex_dmm = kerneltargetctx.data_model_manager
75+
76+
for ndim in range(1, 3):
77+
dty = range_type(ndim)
78+
argty_tuple = tuple([dty])
79+
datamodel = dpex_dmm.lookup(dty)
80+
num_flattened_args = datamodel.flattened_field_count
81+
ap = cputargetctx.get_arg_packer(argty_tuple)
82+
83+
assert num_flattened_args == len(ap._be_args)

0 commit comments

Comments
 (0)