Skip to content

Commit 72a0298

Browse files
Merge master into update_elementwise_docs
2 parents 28eeccb + 1c375af commit 72a0298

File tree

2 files changed

+117
-35
lines changed

2 files changed

+117
-35
lines changed

.github/workflows/cron-run-tests.yaml

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
name: Run tests suite
2+
on:
3+
# For Branch-Protection check. Only the default branch is supported. See
4+
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection
5+
branch_protection_rule:
6+
# To guarantee Maintained check is occasionally updated. See
7+
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained
8+
schedule:
9+
- cron: '28 2 * * *'
10+
workflow_dispatch:
11+
12+
permissions: read-all
13+
14+
env:
15+
PACKAGE_NAME: dpnp
16+
CHANNELS: '-c dppy/label/dev -c https://software.repos.intel.com/python/conda/ -c conda-forge --override-channels'
17+
TEST_ENV_NAME: test
18+
19+
jobs:
20+
test:
21+
name: Test ['${{ matrix.runner }}', python='${{ matrix.python }}']
22+
23+
runs-on: ${{ matrix.runner }}
24+
25+
defaults:
26+
run:
27+
shell: ${{ matrix.runner == 'windows-2019' && 'cmd /C CALL {0}' || 'bash -el {0}' }}
28+
29+
permissions:
30+
# Needed to cancel any previous runs that are not completed for a given workflow
31+
actions: write
32+
33+
strategy:
34+
matrix:
35+
python: ['3.9', '3.10', '3.11', '3.12']
36+
runner: [ubuntu-22.04, ubuntu-24.04, windows-2019]
37+
38+
continue-on-error: false
39+
40+
steps:
41+
- name: Cancel Previous Runs
42+
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # 0.12.1
43+
with:
44+
access_token: ${{ github.token }}
45+
46+
- name: Setup miniconda
47+
uses: conda-incubator/setup-miniconda@d2e6a045a86077fb6cad6f5adf368e9076ddaa8d # v3.1.0
48+
with:
49+
miniforge-version: latest
50+
use-mamba: 'true'
51+
channels: conda-forge
52+
conda-remove-defaults: 'true'
53+
python-version: ${{ matrix.python }}
54+
activate-environment: ${{ env.TEST_ENV_NAME }}
55+
56+
- name: Install dpnp
57+
run: |
58+
mamba install ${{ env.PACKAGE_NAME }} pytest ${{ env.CHANNELS }}
59+
env:
60+
MAMBA_NO_LOW_SPEED_LIMIT: 1
61+
62+
- name: List installed packages
63+
run: mamba list
64+
65+
- name: Activate OCL CPU RT
66+
if: ${{ matrix.runner }} == 'windows-2019'
67+
shell: pwsh
68+
run: |
69+
$script_path="$env:CONDA_PREFIX\Scripts\set-intel-ocl-icd-registry.ps1"
70+
if (Test-Path $script_path) {
71+
&$script_path
72+
} else {
73+
Write-Warning "File $script_path was NOT found!"
74+
}
75+
# Check the variable assisting OpenCL CPU driver to find TBB DLLs which are not located where it expects them by default
76+
$cl_cfg="$env:CONDA_PREFIX\Library\lib\cl.cfg"
77+
Get-Content -Tail 5 -Path $cl_cfg
78+
79+
- name: Smoke test
80+
run: |
81+
python -m dpctl -f
82+
python -c "import dpnp; print(dpnp.__version__)"
83+
84+
- name: Run tests
85+
run: |
86+
python -m pytest -ra --pyargs ${{ env.PACKAGE_NAME }}.tests
87+
env:
88+
SYCL_CACHE_PERSISTENT: 1

dpnp/dpnp_iface_statistics.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,11 @@
4747
import dpnp
4848

4949
# pylint: disable=no-name-in-module
50-
from .dpnp_algo import (
51-
dpnp_correlate,
52-
)
50+
from .dpnp_algo import dpnp_correlate
5351
from .dpnp_array import dpnp_array
54-
from .dpnp_utils import (
55-
call_origin,
56-
get_usm_allocations,
57-
)
52+
from .dpnp_utils import call_origin, get_usm_allocations
5853
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
59-
from .dpnp_utils.dpnp_utils_statistics import (
60-
dpnp_cov,
61-
)
54+
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov
6255

6356
__all__ = [
6457
"amax",
@@ -276,60 +269,61 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False):
276269
"""
277270

278271
dpnp.check_supported_arrays_type(a)
272+
usm_type, exec_q = get_usm_allocations([a, weights])
273+
279274
if weights is None:
280275
avg = dpnp.mean(a, axis=axis, keepdims=keepdims)
281276
scl = dpnp.asanyarray(
282277
avg.dtype.type(a.size / avg.size),
283-
usm_type=a.usm_type,
284-
sycl_queue=a.sycl_queue,
278+
usm_type=usm_type,
279+
sycl_queue=exec_q,
285280
)
286281
else:
287-
if not isinstance(weights, (dpnp_array, dpt.usm_ndarray)):
288-
wgt = dpnp.asanyarray(
289-
weights, usm_type=a.usm_type, sycl_queue=a.sycl_queue
282+
if not dpnp.is_supported_array_type(weights):
283+
weights = dpnp.asarray(
284+
weights, usm_type=usm_type, sycl_queue=exec_q
290285
)
291-
else:
292-
get_usm_allocations([a, weights])
293-
wgt = weights
294286

295-
if not dpnp.issubdtype(a.dtype, dpnp.inexact):
287+
a_dtype = a.dtype
288+
if not dpnp.issubdtype(a_dtype, dpnp.inexact):
296289
default_dtype = dpnp.default_float_type(a.device)
297-
result_dtype = dpnp.result_type(a.dtype, wgt.dtype, default_dtype)
290+
res_dtype = dpnp.result_type(a_dtype, weights.dtype, default_dtype)
298291
else:
299-
result_dtype = dpnp.result_type(a.dtype, wgt.dtype)
292+
res_dtype = dpnp.result_type(a_dtype, weights.dtype)
300293

301294
# Sanity checks
302-
if a.shape != wgt.shape:
295+
wgt_shape = weights.shape
296+
a_shape = a.shape
297+
if a_shape != wgt_shape:
303298
if axis is None:
304299
raise TypeError(
305300
"Axis must be specified when shapes of input array and "
306301
"weights differ."
307302
)
308-
if wgt.ndim != 1:
303+
if weights.ndim != 1:
309304
raise TypeError(
310305
"1D weights expected when shapes of input array and "
311306
"weights differ."
312307
)
313-
if wgt.shape[0] != a.shape[axis]:
308+
if wgt_shape[0] != a_shape[axis]:
314309
raise ValueError(
315310
"Length of weights not compatible with specified axis."
316311
)
317312

318-
# setup wgt to broadcast along axis
319-
wgt = dpnp.broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape)
320-
wgt = wgt.swapaxes(-1, axis)
313+
# setup weights to broadcast along axis
314+
weights = dpnp.broadcast_to(
315+
weights, (a.ndim - 1) * (1,) + wgt_shape
316+
)
317+
weights = weights.swapaxes(-1, axis)
321318

322-
scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims)
319+
scl = weights.sum(axis=axis, dtype=res_dtype, keepdims=keepdims)
323320
if dpnp.any(scl == 0.0):
324321
raise ZeroDivisionError("Weights sum to zero, can't be normalized")
325322

326-
# result_datatype
327-
avg = (
328-
dpnp.multiply(a, wgt).sum(
329-
axis=axis, dtype=result_dtype, keepdims=keepdims
330-
)
331-
/ scl
323+
avg = dpnp.multiply(a, weights).sum(
324+
axis=axis, dtype=res_dtype, keepdims=keepdims
332325
)
326+
avg /= scl
333327

334328
if returned:
335329
if scl.shape != avg.shape:
@@ -556,7 +550,7 @@ def cov(
556550
557551
"""
558552

559-
if not isinstance(m, (dpnp_array, dpt.usm_ndarray)):
553+
if not dpnp.is_supported_array_type(m):
560554
pass
561555
elif m.ndim > 2:
562556
pass

0 commit comments

Comments
 (0)