Skip to content

Commit c4d2f40

Browse files
author
Diptorup Deb
committed
Unit tests for flattened_member_count data model property.
1 parent 6e74aea commit c4d2f40

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

numba_dpex/tests/core/types/DpnpNdArray/test_models.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
from numba import types
66
from numba.core.datamodel import default_manager, models
7+
from numba.core.registry import cpu_target
78

89
from numba_dpex.core.datamodel.models import (
910
DpnpNdArrayModel,
1011
USMArrayModel,
1112
dpex_data_model_manager,
1213
)
14+
from numba_dpex.core.descriptor import dpex_kernel_target
1315
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray
1416

1517

@@ -31,3 +33,22 @@ def test_dpnp_ndarray_Model():
3133
"""
3234

3335
assert issubclass(DpnpNdArrayModel, models.StructModel)
36+
37+
38+
def test_flattened_member_count():
39+
"""Test that the number of flattened member count matches the number of
40+
flattened args generated by the CpuTarget's ArgPacker.
41+
"""
42+
43+
cputargetctx = cpu_target.target_context
44+
kerneltargetctx = dpex_kernel_target.target_context
45+
dpex_dmm = kerneltargetctx.data_model_manager
46+
47+
for ndim in range(4):
48+
dty = DpnpNdArray(ndim)
49+
argty_tuple = tuple([dty])
50+
datamodel = dpex_dmm.lookup(dty)
51+
num_flattened_args = datamodel.flattened_field_count
52+
ap = cputargetctx.get_arg_packer(argty_tuple)
53+
54+
assert num_flattened_args == len(ap._be_args)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from numba.core.registry import cpu_target
6+
7+
from numba_dpex.core.descriptor import dpex_kernel_target
8+
from numba_dpex.core.types.usm_ndarray_type import USMNdArray
9+
10+
11+
def test_flattened_member_count():
12+
"""Test that the number of flattened member count matches the number of
13+
flattened args generated by the CpuTarget's ArgPacker.
14+
"""
15+
16+
cputargetctx = cpu_target.target_context
17+
kerneltargetctx = dpex_kernel_target.target_context
18+
dpex_dmm = kerneltargetctx.data_model_manager
19+
20+
for ndim in range(4):
21+
dty = USMNdArray(ndim)
22+
argty_tuple = tuple([dty])
23+
datamodel = dpex_dmm.lookup(dty)
24+
num_flattened_args = datamodel.flattened_field_count
25+
ap = cputargetctx.get_arg_packer(argty_tuple)
26+
27+
assert num_flattened_args == len(ap._be_args)

numba_dpex/tests/core/types/range_types/test_data_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44

55
import pytest
66
from numba.core.datamodel import default_manager
7+
from numba.core.registry import cpu_target
78

89
from numba_dpex.core.datamodel.models import (
910
NdRangeModel,
1011
RangeModel,
1112
dpex_data_model_manager,
1213
)
14+
from numba_dpex.core.descriptor import dpex_kernel_target
1315
from numba_dpex.core.types.range_types import NdRangeType, RangeType
1416

1517
rfields = ["ndim", "dim0", "dim1", "dim2"]
1618
ndrfields = ["ndim", "gdim0", "gdim1", "gdim2", "ldim0", "ldim1", "ldim2"]
19+
range_tys = [RangeType, NdRangeType]
1720

1821

1922
def test_datamodel_registration():
@@ -58,3 +61,23 @@ def test_ndrange_model_fields(field):
5861
dm.get_field_position(field)
5962
except:
6063
pytest.fail(f"Expected field {field} not present in NdRangeModel")
64+
65+
66+
@pytest.mark.parametrize("range_type", range_tys)
67+
def test_flattened_member_count(range_type):
68+
"""Test that the number of flattened member count matches the number of
69+
flattened args generated by the CpuTarget's ArgPacker.
70+
"""
71+
72+
cputargetctx = cpu_target.target_context
73+
kerneltargetctx = dpex_kernel_target.target_context
74+
dpex_dmm = kerneltargetctx.data_model_manager
75+
76+
for ndim in range(1, 3):
77+
dty = range_type(ndim)
78+
argty_tuple = tuple([dty])
79+
datamodel = dpex_dmm.lookup(dty)
80+
num_flattened_args = datamodel.flattened_field_count
81+
ap = cputargetctx.get_arg_packer(argty_tuple)
82+
83+
assert num_flattened_args == len(ap._be_args)

0 commit comments

Comments
 (0)