@@ -264,57 +264,131 @@ function vml_prefix(t::DataType)
264
264
error (" unknown type $t " )
265
265
end
266
266
267
# `alldense(xs...)` is `true` only when every argument passes the contiguity check
# below; callers use it to decide whether the unit-stride code path is applicable.
if isdefined(Base, :_checkcontiguous)
    # This Julia version ships its own contiguity check; delegate to it.
    alldense(@nospecialize(x)) = Base._checkcontiguous(Bool, x)
else
    # Fallback: a `DenseArray` counts as dense, and the wrappers listed here are
    # dense exactly when their parent is. Anything else is treated as non-dense.
    alldense(x) = x isa DenseArray
    alldense(x::Base.ReshapedArray) = alldense(parent(x))
    alldense(x::Base.FastContiguousSubArray) = alldense(parent(x))
    alldense(x::Base.ReinterpretArray) = alldense(parent(x))
end
# Variadic form: short-circuits on the first argument that is not dense.
alldense(x, y, z...) = alldense(x) && alldense(y, z...)
276
+
277
# `merge_adjacent_dim(sizes, strides[, n])` folds consecutive dimensions, starting at
# dimension `n`, into a single run and returns `(sz, st, n)`: the merged extent, its
# stride, and the index of the last dimension that was folded in.
if isdefined(Base, :merge_adjacent_dim)
    # Prefer the implementation provided by Base when it exists.
    const merge_adjacent_dim = Base.merge_adjacent_dim
else
    merge_adjacent_dim(::Dims{0}, ::Dims{0}) = 1, 1, 0
    merge_adjacent_dim(apsz::Dims{1}, apst::Dims{1}) = apsz[1], apst[1], 1
    function merge_adjacent_dim(apsz::Dims{N}, apst::Dims{N}, n::Int = 1) where {N}
        sz, st = apsz[n], apst[n]
        while n < N
            szₙ, stₙ = apsz[n + 1], apst[n + 1]
            if sz == 1
                # A singleton run carries no layout information: adopt the next
                # dimension's size and stride as the run.
                sz, st = szₙ, stₙ
            elseif stₙ == st * sz || szₙ == 1
                # The next dimension starts exactly where the current run ends
                # (or is a singleton), so it can be folded into the run.
                sz *= szₙ
            else
                break
            end
            n += 1
        end
        return sz, st, n
    end
end
298
+
299
# Return a tuple holding `stride1` of each argument, in argument order.
getstrides(x...) = map(stride1, x)
300
"""
    stride1(x::AbstractArray)

Return the single stride with which the elements of `x` can be walked as if `x` were
a vector.

- If `alldense(x)` holds, the stride is `1`.
- For a one-dimensional `x`, it is `stride(x, 1)`.
- Otherwise the dimensions are merged with `merge_adjacent_dim`; if every dimension
  folds into one run, that run's stride is returned.

Throws an `ArgumentError` when the dimensions cannot be folded into a single run.
"""
function stride1(x::AbstractArray)
    alldense(x) && return 1
    ndims(x) == 1 && return stride(x, 1)
    szs::Dims = size(x)
    sts::Dims = strides(x)
    _, st, n = merge_adjacent_dim(szs, sts)
    # `n` is the last dimension that could be merged; anything short of `ndims(x)`
    # means the layout is not expressible with one stride.
    n === ndims(x) && return st
    throw(ArgumentError("only support vector like inputs"))
end
309
+
267
310
"""
    def_unary_op(tin, tout, jlname, jlname!, mklname; vmltype = tin)

Generate (via `@eval`) the methods wrapping a one-input, one-output MKL routine:

- `jlname!(out, A)` writes the result for `A::AbstractArray{tin}` into
  `out::AbstractArray{tout}` and returns `out`; throws `DimensionMismatch` when the
  sizes differ.
- `jlname!(A)` (only generated when `tin == tout`) operates in place and returns `A`.
- `jlname(A)` allocates the output with `similar(A, tout)` and forwards to `jlname!`.

The MKL symbol is built from `vml_prefix(vmltype)` and `mklname`; the variant with an
`I` suffix is called with explicit strides when the arguments are not all dense.
"""
function def_unary_op(tin, tout, jlname, jlname!, mklname;
        vmltype = tin)
    mklfn = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$(mklname)I"))
    mklfndense = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$mklname"))
    exports = Symbol[]
    # NOTE(review): `@isdefined` tests whether the *variable* is bound in this scope;
    # `jlname`/`jlname!` are arguments and hence always bound, so `exports` appears to
    # stay empty. Confirm whether exporting was intended before changing this.
    (@isdefined jlname) || push!(exports, jlname)
    (@isdefined jlname!) || push!(exports, jlname!)
    @eval begin
        function ($jlname!)(out::AbstractArray{$tout}, A::AbstractArray{$tin})
            size(out) == size(A) || throw(DimensionMismatch())
            # `sts` is only assigned when `alldense` is false, which is exactly the
            # case in which the `else` branch can be reached.
            if alldense(out, A) || ((sts = getstrides(out, A)) == (1, 1))
                ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing,
                      (Int, Ptr{$tin}, Ptr{$tout}),
                      length(A), A, out)
            else
                stᵒ, stᴬ = sts
                ccall(($mklfn, MKL_jll.libmkl_rt), Nothing,
                      (Int, Ptr{$tin}, Int, Ptr{$tout}, Int),
                      length(A), A, stᴬ, out, stᵒ)
            end
            vml_check_error()
            return out
        end
        $(if tin == tout
            quote
                function $(jlname!)(A::AbstractArray{$tin})
                    if alldense(A) || ((sts = getstrides(A)) == (1,))
                        ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing,
                              (Int, Ptr{$tin}, Ptr{$tout}),
                              length(A), A, A)
                    else
                        (stᴬ,) = sts
                        ccall(($mklfn, MKL_jll.libmkl_rt), Nothing,
                              (Int, Ptr{$tin}, Int, Ptr{$tout}, Int),
                              length(A), A, stᴬ, A, stᴬ)
                    end
                    vml_check_error()
                    return A
                end
            end
        end)
        ($jlname)(A::AbstractArray{$tin}) = $(jlname!)(similar(A, $tout), A)
        $(isempty(exports) ? nothing : Expr(:export, exports...))
    end
end
298
347
299
348
"""
    def_binary_op(tin, tout, jlname, jlname!, mklname, broadcast)

Generate (via `@eval`) the methods wrapping a two-input, one-output MKL routine:

- `jlname!(out, A, B)` writes the result for `A` and `B` (both `AbstractArray{tin}`)
  into `out::AbstractArray{tout}` and returns `out`; throws `DimensionMismatch` when
  the inputs differ in size or when `out` does not match them.
- `jlname(A, B)` allocates the output with `similar(A, tout)` and forwards to
  `jlname!`.

The `broadcast` argument is accepted but not used by this function.
"""
function def_binary_op(tin, tout, jlname, jlname!, mklname, broadcast)
    mklfndense = Base.Meta.quot(Symbol("$(vml_prefix(tin))$mklname"))
    mklfn = Base.Meta.quot(Symbol("$(vml_prefix(tin))$(mklname)I"))
    exports = Symbol[]
    # NOTE(review): `@isdefined` tests whether the *variable* is bound in this scope;
    # the arguments always are, so `exports` appears to stay empty — confirm intent.
    (@isdefined jlname) || push!(exports, jlname)
    (@isdefined jlname!) || push!(exports, jlname!)
    @eval begin
        $(isempty(exports) ? nothing : Expr(:export, exports...))
        function ($jlname!)(out::AbstractArray{$tout}, A::AbstractArray{$tin}, B::AbstractArray{$tin})
            size(A) == size(B) || throw(DimensionMismatch("Input arrays need to have the same size"))
            size(out) == size(A) || throw(DimensionMismatch("Output array need to have the same size with input"))
            # `sts` is only assigned when `alldense` is false, which is exactly the
            # case in which the `else` branch can be reached.
            if alldense(out, A, B) || ((sts = getstrides(out, A, B)) == (1, 1, 1))
                ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing,
                      (Int, Ptr{$tin}, Ptr{$tin}, Ptr{$tout}),
                      length(A), A, B, out)
            else
                stᵒ, stᴬ, stᴮ = sts
                ccall(($mklfn, MKL_jll.libmkl_rt), Nothing,
                      (Int, Ptr{$tin}, Int, Ptr{$tin}, Int, Ptr{$tout}, Int),
                      length(A), A, stᴬ, B, stᴮ, out, stᵒ)
            end
            vml_check_error()
            return out
        end
        ($jlname)(A::AbstractArray{$tin}, B::AbstractArray{$tin}) = ($jlname!)(similar(A, $tout), A, B)
    end
end
371
+
372
"""
    def_one2two_op(tin, tout, jlname, jlname!, mklname)

Generate (via `@eval`) the methods wrapping a one-input, two-output MKL routine:

- `jlname!(out1, out2, A)` writes the two results for `A::AbstractArray{tin}` into
  `out1` and `out2` (both `AbstractArray{tout}`) and returns `(out1, out2)`; throws
  `DimensionMismatch` when the outputs differ in size or do not match `A`.
- `jlname(A)` allocates both outputs with `similar(A, tout)` and forwards to
  `jlname!`.
"""
function def_one2two_op(tin, tout, jlname, jlname!, mklname)
    mklfndense = Base.Meta.quot(Symbol("$(vml_prefix(tin))$mklname"))
    mklfn = Base.Meta.quot(Symbol("$(vml_prefix(tin))$(mklname)I"))
    exports = Symbol[]
    # NOTE(review): `@isdefined` tests whether the *variable* is bound in this scope;
    # the arguments always are, so `exports` appears to stay empty — confirm intent.
    (@isdefined jlname) || push!(exports, jlname)
    (@isdefined jlname!) || push!(exports, jlname!)
    @eval begin
        $(isempty(exports) ? nothing : Expr(:export, exports...))
        function ($jlname!)(out1::AbstractArray{$tout}, out2::AbstractArray{$tout}, A::AbstractArray{$tin})
            size(out1) == size(out2) || throw(DimensionMismatch("Output arrays need to have the same size"))
            size(A) == size(out2) || throw(DimensionMismatch("Output array need to have the same size with input"))
            # Both outputs hold `tout` elements, so both pointer slots are declared
            # `Ptr{tout}`; declaring the first as `Ptr{tin}` only works when the two
            # element types coincide.
            if alldense(out1, out2, A) || ((sts = getstrides(out1, out2, A)) == (1, 1, 1))
                ccall(($mklfndense, MKL_jll.libmkl_rt), Nothing,
                      (Int, Ptr{$tin}, Ptr{$tout}, Ptr{$tout}),
                      length(A), A, out1, out2)
            else
                st¹, st², stᴬ = sts
                ccall(($mklfn, MKL_jll.libmkl_rt), Nothing,
                      (Int, Ptr{$tin}, Int, Ptr{$tout}, Int, Ptr{$tout}, Int),
                      length(A), A, stᴬ, out1, st¹, out2, st²)
            end
            vml_check_error()
            return out1, out2
        end
        ($jlname)(A::AbstractArray{$tin}) = ($jlname!)(similar(A, $tout), similar(A, $tout), A)
    end
end
0 commit comments