Skip to content

Feature/flattened member count #1166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 36 additions & 40 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,30 @@
)


def _get_flattened_member_count(ty):
"""Return the number of fields in an instance of a given StructModel."""
flattened_member_count = 0
members = ty._members
for member in members:
if isinstance(member, types.UniTuple):
flattened_member_count += member.count
elif isinstance(
member,
(
types.scalars.Integer,
types.misc.PyObject,
types.misc.RawPointer,
types.misc.CPointer,
types.misc.MemInfoPointer,
),
):
flattened_member_count += 1
else:
raise UnreachableError

return flattened_member_count


class GenericPointerModel(PrimitiveModel):
def __init__(self, dmm, fe_type):
adrsp = (
Expand Down Expand Up @@ -68,26 +92,7 @@ def __init__(self, dmm, fe_type):
@property
def flattened_field_count(self):
"""Return the number of fields in an instance of a USMArrayModel."""
flattened_member_count = 0
members = self._members
for member in members:
if isinstance(member, types.UniTuple):
flattened_member_count += member.count
elif isinstance(
member,
(
types.scalars.Integer,
types.misc.PyObject,
types.misc.RawPointer,
types.misc.CPointer,
types.misc.MemInfoPointer,
),
):
flattened_member_count += 1
else:
raise UnreachableError

return flattened_member_count
return _get_flattened_member_count(self)


class DpnpNdArrayModel(StructModel):
Expand Down Expand Up @@ -121,26 +126,7 @@ def __init__(self, dmm, fe_type):
@property
def flattened_field_count(self):
"""Return the number of fields in an instance of a DpnpNdArrayModel."""
flattened_member_count = 0
members = self._members
for member in members:
if isinstance(member, types.UniTuple):
flattened_member_count += member.count
elif isinstance(
member,
(
types.scalars.Integer,
types.misc.PyObject,
types.misc.RawPointer,
types.misc.CPointer,
types.misc.MemInfoPointer,
),
):
flattened_member_count += 1
else:
raise UnreachableError

return flattened_member_count
return _get_flattened_member_count(self)


class SyclQueueModel(StructModel):
Expand Down Expand Up @@ -211,6 +197,11 @@ def __init__(self, dmm, fe_type):
]
super(RangeModel, self).__init__(dmm, fe_type, members)

@property
def flattened_field_count(self):
"""Return the number of fields in an instance of a RangeModel."""
return _get_flattened_member_count(self)


class NdRangeModel(StructModel):
"""The native data model for a
Expand All @@ -229,6 +220,11 @@ def __init__(self, dmm, fe_type):
]
super(NdRangeModel, self).__init__(dmm, fe_type, members)

@property
def flattened_field_count(self):
"""Return the number of fields in an instance of a NdRangeModel."""
return _get_flattened_member_count(self)


def _init_data_model_manager() -> datamodel.DataModelManager:
"""Initializes a DpexKernelTarget-specific data model manager.
Expand Down
21 changes: 21 additions & 0 deletions numba_dpex/tests/core/types/DpnpNdArray/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from numba import types
from numba.core.datamodel import default_manager, models
from numba.core.registry import cpu_target

from numba_dpex.core.datamodel.models import (
DpnpNdArrayModel,
USMArrayModel,
dpex_data_model_manager,
)
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray


Expand All @@ -31,3 +33,22 @@ def test_dpnp_ndarray_Model():
"""

assert issubclass(DpnpNdArrayModel, models.StructModel)


def test_flattened_member_count():
"""Test that the number of flattened member count matches the number of
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager

for ndim in range(4):
dty = DpnpNdArray(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)
27 changes: 27 additions & 0 deletions numba_dpex/tests/core/types/USMNdArray/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba.core.registry import cpu_target

from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.types.usm_ndarray_type import USMNdArray


def test_flattened_member_count():
"""Test that the number of flattened member count matches the number of
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager

for ndim in range(4):
dty = USMNdArray(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)
23 changes: 23 additions & 0 deletions numba_dpex/tests/core/types/range_types/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@

import pytest
from numba.core.datamodel import default_manager
from numba.core.registry import cpu_target

from numba_dpex.core.datamodel.models import (
NdRangeModel,
RangeModel,
dpex_data_model_manager,
)
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.types.range_types import NdRangeType, RangeType

rfields = ["ndim", "dim0", "dim1", "dim2"]
ndrfields = ["ndim", "gdim0", "gdim1", "gdim2", "ldim0", "ldim1", "ldim2"]
range_tys = [RangeType, NdRangeType]


def test_datamodel_registration():
Expand Down Expand Up @@ -58,3 +61,23 @@ def test_ndrange_model_fields(field):
dm.get_field_position(field)
except:
pytest.fail(f"Expected field {field} not present in NdRangeModel")


@pytest.mark.parametrize("range_type", range_tys)
def test_flattened_member_count(range_type):
"""Test that the number of flattened member count matches the number of
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager

for ndim in range(1, 3):
dty = range_type(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)