Skip to content

Commit 9b328f1

Browse files
committed
Fix flake8 errors. Minor refactoring and bugfixes
1 parent 01e9840 commit 9b328f1

20 files changed

+534
-587
lines changed

arrayfire/__init__.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,12 @@
4949
"""
5050

5151
from .algorithm import (
52-
accum, all_true, allTrueByKey, any_true, anyTrueByKey, count, countByKey, diff1, diff2, imax, imin, max, maxByKey, min, minByKey, product, productByKey, scan, scan_by_key, set_intersect,
53-
set_union, set_unique, sort, sort_by_key, sort_index, sum, sumByKey, where)
52+
accum, all_true, allTrueByKey, any_true, anyTrueByKey, count, countByKey, diff1, diff2, imax, imin, max, maxByKey,
53+
min, minByKey, product, productByKey, scan, scan_by_key, set_intersect, set_union, set_unique, sort, sort_by_key,
54+
sort_index, sum, sumByKey, where)
5455
from .arith import (
55-
abs, acos, acosh, arg, asin, asinh, atan, atan2, atanh, cast, cbrt, ceil, clamp, conjg, cos, cosh, cplx, erf, erfc, exp,
56-
expm1, factorial, floor, hypot, imag, isinf, isnan, iszero, lgamma, log, log1p, log2, log10, maxof, minof,
56+
abs, acos, acosh, arg, asin, asinh, atan, atan2, atanh, cast, cbrt, ceil, clamp, conjg, cos, cosh, cplx, erf, erfc,
57+
exp, expm1, factorial, floor, hypot, imag, isinf, isnan, iszero, lgamma, log, log1p, log2, log10, maxof, minof,
5758
mod, pow,
5859
pow2, real, rem, root, round, rsqrt, sigmoid, sign, sin, sinh, sqrt, tan, tanh, tgamma, trunc)
5960
from .array import (
@@ -64,45 +65,50 @@
6465
from .blas import dot, matmul, matmulNT, matmulTN, matmulTT, gemm
6566
from .cuda import get_native_id, get_stream, set_native_id
6667
from .data import (
67-
constant, diag, flat, flip, identity, iota, join, lookup, lower, moddims, pad, range, reorder, replace, select, shift,
68-
tile, upper)
68+
constant, diag, flat, flip, identity, iota, join, lookup, lower, moddims, pad, range, reorder, replace, select,
69+
shift, tile, upper)
6970
from .library import (
70-
BACKEND, BINARYOP, CANNY_THRESHOLD, COLORMAP, CONNECTIVITY, CONV_DOMAIN, CONV_GRADIENT, CONV_MODE, CSPACE, DIFFUSION, ERR, FLUX,
71-
HOMOGRAPHY, IMAGE_FORMAT, INTERP, ITERATIVE_DECONV, INVERSE_DECONV, MARKER, MATCH, MATPROP, MOMENT, NORM, PAD, RANDOM_ENGINE, STORAGE, TOPK, VARIANCE, YCC_STD, Dtype, Source, AF_VER_MAJOR, FORGE_VER_MAJOR)
71+
BACKEND, BINARYOP, CANNY_THRESHOLD, COLORMAP, CONNECTIVITY, CONV_DOMAIN, CONV_GRADIENT, CONV_MODE, CSPACE,
72+
DIFFUSION, ERR, FLUX, HOMOGRAPHY, IMAGE_FORMAT, INTERP, ITERATIVE_DECONV, INVERSE_DECONV, MARKER, MATCH, MATPROP,
73+
MOMENT, NORM, PAD, RANDOM_ENGINE, STORAGE, TOPK, VARIANCE, YCC_STD, Dtype, Source, AF_VER_MAJOR, FORGE_VER_MAJOR)
7274
from .device import (
7375
alloc_device, alloc_host, alloc_pinned, device_gc, device_info, device_mem_info, eval, free_device, free_host,
7476
free_pinned, get_device, get_device_count, get_device_ptr, get_manual_eval_flag, info,
75-
info_str, init, is_dbl_supported, is_half_supported, is_locked_array, lock_array, lock_device_ptr, print_mem_info, set_device,
76-
set_manual_eval_flag, sync, unlock_array, unlock_device_ptr)
77+
info_str, init, is_dbl_supported, is_half_supported, is_locked_array, lock_array, lock_device_ptr, print_mem_info,
78+
set_device, set_manual_eval_flag, sync, unlock_array, unlock_device_ptr)
7779
from .graphics import Window
7880
from .image import (
79-
anisotropic_diffusion, bilateral, canny, color_space, confidenceCC, dilate, dilate3, erode, erode3, gaussian_kernel, gradient,
80-
gray2rgb, hist_equal, histogram, hsv2rgb, is_image_io_available, iterativeDeconv, inverseDeconv, load_image, load_image_native, maxfilt,
81-
mean_shift, minfilt, moments, regions, resize, rgb2gray, rgb2hsv, rgb2ycbcr, rotate, sat, save_image,
82-
save_image_native, scale, skew, sobel_derivatives, sobel_filter, transform, translate, unwrap, wrap, ycbcr2rgb)
81+
anisotropic_diffusion, bilateral, canny, color_space, confidenceCC, dilate, dilate3, erode, erode3,
82+
gaussian_kernel, gradient, gray2rgb, hist_equal, histogram, hsv2rgb, is_image_io_available, iterativeDeconv,
83+
inverseDeconv, load_image, load_image_native, maxfilt, mean_shift, minfilt, moments, regions, resize, rgb2gray,
84+
rgb2hsv, rgb2ycbcr, rotate, sat, save_image, save_image_native, scale, skew, sobel_derivatives, sobel_filter,
85+
transform, translate, unwrap, wrap, ycbcr2rgb)
8386
from .index import Index, ParallelRange, Seq
8487
from .interop import AF_NUMBA_FOUND, AF_NUMPY_FOUND, AF_PYCUDA_FOUND, AF_PYOPENCL_FOUND, to_array
8588
from .lapack import (
86-
cholesky, cholesky_inplace, det, inverse, is_lapack_available, lu, lu_inplace, norm, pinverse, qr, qr_inplace, rank, solve,
87-
solve_lu, svd, svd_inplace)
89+
cholesky, cholesky_inplace, det, inverse, is_lapack_available, lu, lu_inplace, norm, pinverse, qr, qr_inplace,
90+
rank, solve, solve_lu, svd, svd_inplace)
8891
from .library import (
8992
get_active_backend, get_available_backends, get_backend, get_backend_count, get_backend_id, get_device_id,
90-
get_size_of, safe_call, set_backend)
93+
get_size_of, safe_call, set_backend, to_str)
9194
from .ml import convolve2GradientNN
9295
from .random import (
9396
Random_Engine, get_default_random_engine, get_seed, randn, randu, set_default_random_engine_type,
9497
set_seed)
9598
from .signal import (
96-
approx1, approx1_uniform, approx2, approx2_uniform, convolve, convolve1, convolve2, convolve2NN, convolve2_separable, convolve3, dft, fft, fft2, fft2_c2r,
97-
fft2_inplace, fft2_r2c, fft3, fft3_c2r, fft3_inplace, fft3_r2c, fft_c2r, fft_convolve, fft_convolve1,
98-
fft_convolve2, fft_convolve3, fft_inplace, fft_r2c, fir, idft, ifft, ifft2, ifft2_inplace, ifft3, ifft3_inplace,
99-
ifft_inplace, iir, medfilt, medfilt1, medfilt2, set_fft_plan_cache_size)
99+
approx1, approx1_uniform, approx2, approx2_uniform, convolve, convolve1, convolve2, convolve2NN,
100+
convolve2_separable, convolve3, dft, fft, fft2, fft2_c2r, fft2_inplace, fft2_r2c, fft3, fft3_c2r, fft3_inplace,
101+
fft3_r2c, fft_c2r, fft_convolve, fft_convolve1, fft_convolve2, fft_convolve3, fft_inplace, fft_r2c, fir, idft,
102+
ifft, ifft2, ifft2_inplace, ifft3, ifft3_inplace, ifft_inplace, iir, medfilt, medfilt1, medfilt2,
103+
set_fft_plan_cache_size)
100104
from .sparse import (
101105
convert_sparse, convert_sparse_to_dense, create_sparse, create_sparse_from_dense, create_sparse_from_host,
102106
sparse_get_col_idx, sparse_get_info, sparse_get_nnz, sparse_get_row_idx, sparse_get_storage, sparse_get_values)
103107
from .statistics import corrcoef, cov, mean, meanvar, median, stdev, topk, var
104108
from .timer import timeit
105-
from .util import dim4, dim4_to_tuple, implicit_dtype, number_dtype, to_str, get_reversion, get_version, to_dtype, to_typecode, to_c_type
109+
from .util import (
110+
dim4, dim4_to_tuple, implicit_dtype, number_dtype, get_reversion, get_version, to_dtype, to_typecode,
111+
to_c_type)
106112

107113
try:
108114
# FIXME: pycuda imported but unused
@@ -133,7 +139,7 @@
133139
"broadcast",
134140
# blas
135141
"dot", "matmul", "matmulNT", "matmulTN", "matmulTT", "gemm",
136-
#cuda
142+
# cuda
137143
"get_native_id", "get_stream", "set_native_id",
138144
# data
139145
"constant", "diag", "flat", "flip", "identity", "iota", "join", "lookup",
@@ -169,7 +175,7 @@
169175
"norm", "pinverse", "qr", "qr_inplace", "rank", "solve", "solve_lu", "svd", "svd_inplace",
170176
# library
171177
"get_active_backend", "get_available_backends", "get_backend", "get_backend_count",
172-
"get_backend_id", "get_device_id", "get_size_of", "safe_call", "set_backend",
178+
"get_backend_id", "get_device_id", "get_size_of", "safe_call", "set_backend", "to_str",
173179
# ml
174180
"convolve2GradientNN",
175181
# random
@@ -192,6 +198,6 @@
192198
# timer
193199
"timeit",
194200
# util
195-
"dim4", "dim4_to_tuple", "implicit_dtype", "number_dtype", "to_str", "get_reversion",
201+
"dim4", "dim4_to_tuple", "implicit_dtype", "number_dtype", "get_reversion",
196202
"get_version", "to_dtype", "to_typecode", "to_c_type"
197203
]

arrayfire/algorithm.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#######################################################
2-
# Copyright (c) 2019, ArrayFire
2+
# Copyright (c) 2020, ArrayFire
33
# All rights reserved.
44
#
55
# This file is distributed under 3-clause BSD license.
@@ -14,11 +14,13 @@
1414
from .array import Array
1515
from .library import backend, safe_call, BINARYOP, c_bool_t, c_double_t, c_int_t, c_pointer, c_uint_t
1616

17+
1718
def _parallel_dim(a, dim, c_func):
1819
out = Array()
1920
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim)))
2021
return out
2122

23+
2224
def _reduce_all(a, c_func):
2325
real = c_double_t(0)
2426
imag = c_double_t(0)
@@ -29,11 +31,13 @@ def _reduce_all(a, c_func):
2931
imag = imag.value
3032
return real if imag == 0 else real + imag * 1j
3133

34+
3235
def _nan_parallel_dim(a, dim, c_func, nan_val):
3336
out = Array()
3437
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim), c_double_t(nan_val)))
3538
return out
3639

40+
3741
def _nan_reduce_all(a, c_func, nan_val):
3842
real = c_double_t(0)
3943
imag = c_double_t(0)
@@ -44,6 +48,7 @@ def _nan_reduce_all(a, c_func, nan_val):
4448
imag = imag.value
4549
return real if imag == 0 else real + imag * 1j
4650

51+
4752
def _FNSD(dim, dims):
4853
if dim >= 0:
4954
return int(dim)
@@ -55,20 +60,26 @@ def _FNSD(dim, dims):
5560
break
5661
return int(fnsd)
5762

63+
5864
def _rbk_dim(keys, vals, dim, c_func):
5965
keys_out = Array()
6066
vals_out = Array()
6167
rdim = _FNSD(dim, vals.dims())
6268
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim)))
6369
return keys_out, vals_out
6470

71+
6572
def _nan_rbk_dim(a, dim, c_func, nan_val):
6673
keys_out = Array()
6774
vals_out = Array()
75+
# FIXME: vals is undefined
6876
rdim = _FNSD(dim, vals.dims())
69-
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
77+
# FIXME: keys is undefined
78+
safe_call(c_func(
79+
c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
7080
return keys_out, vals_out
7181

82+
7283
def sum(a, dim=None, nan_val=None):
7384
"""
7485
Calculate the sum of all the elements along a specified dimension.
@@ -88,18 +99,16 @@ def sum(a, dim=None, nan_val=None):
8899
The sum of all elements in `a` along dimension `dim`.
89100
If `dim` is `None`, sum of the entire Array is returned.
90101
"""
91-
if nan_val is not None:
92-
if dim is not None:
102+
if nan_val:
103+
if dim:
93104
return _nan_parallel_dim(a, dim, backend.get().af_sum_nan, nan_val)
94105
return _nan_reduce_all(a, backend.get().af_sum_nan_all, nan_val)
95106

96-
if dim is not None:
107+
if dim:
97108
return _parallel_dim(a, dim, backend.get().af_sum)
98109
return _reduce_all(a, backend.get().af_sum_all)
99110

100111

101-
102-
103112
def sumByKey(keys, vals, dim=-1, nan_val=None):
104113
"""
105114
Calculate the sum of elements along a specified dimension according to a key.
@@ -122,10 +131,10 @@ def sumByKey(keys, vals, dim=-1, nan_val=None):
122131
values: af.Array or scalar number
123132
The sum of all elements in `vals` along dimension `dim` according to keys
124133
"""
125-
if (nan_val is not None):
134+
if nan_val:
126135
return _nan_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key_nan, nan_val)
127-
else:
128-
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
136+
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
137+
129138

130139
def product(a, dim=None, nan_val=None):
131140
"""
@@ -178,10 +187,10 @@ def productByKey(keys, vals, dim=-1, nan_val=None):
178187
values: af.Array or scalar number
179188
The product of all elements in `vals` along dimension `dim` according to keys
180189
"""
181-
if (nan_val is not None):
190+
if nan_val is not None:
182191
return _nan_rbk_dim(keys, vals, dim, backend.get().af_product_by_key_nan, nan_val)
183-
else:
184-
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
192+
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
193+
185194

186195
def min(a, dim=None):
187196
"""
@@ -227,6 +236,7 @@ def minByKey(keys, vals, dim=-1):
227236
"""
228237
return _rbk_dim(keys, vals, dim, backend.get().af_min_by_key)
229238

239+
230240
def max(a, dim=None):
231241
"""
232242
Find the maximum value of all the elements along a specified dimension.
@@ -271,6 +281,7 @@ def maxByKey(keys, vals, dim=-1):
271281
"""
272282
return _rbk_dim(keys, vals, dim, backend.get().af_max_by_key)
273283

284+
274285
def all_true(a, dim=None):
275286
"""
276287
Check if all the elements along a specified dimension are true.
@@ -315,6 +326,7 @@ def allTrueByKey(keys, vals, dim=-1):
315326
"""
316327
return _rbk_dim(keys, vals, dim, backend.get().af_all_true_by_key)
317328

329+
318330
def any_true(a, dim=None):
319331
"""
320332
Check if any the elements along a specified dimension are true.
@@ -334,8 +346,8 @@ def any_true(a, dim=None):
334346
"""
335347
if dim is not None:
336348
return _parallel_dim(a, dim, backend.get().af_any_true)
337-
else:
338-
return _reduce_all(a, backend.get().af_any_true_all)
349+
return _reduce_all(a, backend.get().af_any_true_all)
350+
339351

340352
def anyTrueByKey(keys, vals, dim=-1):
341353
"""
@@ -359,6 +371,7 @@ def anyTrueByKey(keys, vals, dim=-1):
359371
"""
360372
return _rbk_dim(keys, vals, dim, backend.get().af_any_true_by_key)
361373

374+
362375
def count(a, dim=None):
363376
"""
364377
Count the number of non zero elements in an array along a specified dimension.
@@ -378,8 +391,7 @@ def count(a, dim=None):
378391
"""
379392
if dim is not None:
380393
return _parallel_dim(a, dim, backend.get().af_count)
381-
else:
382-
return _reduce_all(a, backend.get().af_count_all)
394+
return _reduce_all(a, backend.get().af_count_all)
383395

384396

385397
def countByKey(keys, vals, dim=-1):
@@ -404,6 +416,7 @@ def countByKey(keys, vals, dim=-1):
404416
"""
405417
return _rbk_dim(keys, vals, dim, backend.get().af_count_by_key)
406418

419+
407420
def imin(a, dim=None):
408421
"""
409422
Find the value and location of the minimum value along a specified dimension

arrayfire/arith.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def cast(a, dtype):
7777
out : af.Array
7878
array containing the values from `a` after converting to `dtype`.
7979
"""
80-
out=Array()
80+
out = Array()
8181
safe_call(backend.get().af_cast(c_pointer(out.arr), a.arr, dtype.value))
8282
return out
8383

@@ -156,15 +156,8 @@ def clamp(val, low, high):
156156
vdims = dim4_to_tuple(val.dims())
157157
vty = val.type()
158158

159-
if not is_low_array:
160-
low_arr = constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
161-
else:
162-
low_arr = low.arr
163-
164-
if not is_high_array:
165-
high_arr = constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
166-
else:
167-
high_arr = high.arr
159+
low_arr = low.arr if is_low_array else constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
160+
high_arr = high.arr if is_high_array else constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
168161

169162
safe_call(backend.get().af_clamp(c_pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))
170163

@@ -1003,6 +996,7 @@ def sqrt(a):
1003996
"""
1004997
return _arith_unary_func(a, backend.get().af_sqrt)
1005998

999+
10061000
def rsqrt(a):
10071001
"""
10081002
Reciprocal or inverse square root of each element in the array.

0 commit comments

Comments
 (0)