@@ -351,7 +351,9 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None):
351
351
)
352
352
st = list (X .strides )
353
353
perm = sorted (
354
- range (X .ndim ), key = lambda d : builtins .abs (st [d ]), reverse = True
354
+ range (X .ndim ),
355
+ key = lambda d : builtins .abs (st [d ]) if X .shape [d ] > 1 else 0 ,
356
+ reverse = True ,
355
357
)
356
358
inv_perm = sorted (range (X .ndim ), key = lambda i : perm [i ])
357
359
sh = X .shape
@@ -395,9 +397,14 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
395
397
max_ndim = max (nd1 , nd2 )
396
398
st1 += [0 ] * (max_ndim - len (st1 ))
397
399
st2 += [0 ] * (max_ndim - len (st2 ))
400
+ sh1 = list (X1 .shape ) + [0 ] * (max_ndim - nd1 )
401
+ sh2 = list (X2 .shape ) + [0 ] * (max_ndim - nd2 )
398
402
perm = sorted (
399
403
range (max_ndim ),
400
- key = lambda d : (builtins .abs (st1 [d ]), builtins .abs (st2 [d ])),
404
+ key = lambda d : (
405
+ builtins .abs (st1 [d ]) if sh1 [d ] > 1 else 0 ,
406
+ builtins .abs (st2 [d ]) if sh2 [d ] > 1 else 0 ,
407
+ ),
401
408
reverse = True ,
402
409
)
403
410
inv_perm = sorted (range (max_ndim ), key = lambda i : perm [i ])
@@ -417,6 +424,74 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
417
424
return dpt .permute_dims (R , inv_perm )
418
425
419
426
427
+ def _empty_like_triple_orderK (X1 , X2 , X3 , dt , res_shape , usm_type , dev ):
428
+ if not isinstance (X1 , dpt .usm_ndarray ):
429
+ raise TypeError (f"Expected usm_ndarray, got { type (X1 )} " )
430
+ if not isinstance (X2 , dpt .usm_ndarray ):
431
+ raise TypeError (f"Expected usm_ndarray, got { type (X2 )} " )
432
+ if not isinstance (X3 , dpt .usm_ndarray ):
433
+ raise TypeError (f"Expected usm_ndarray, got { type (X3 )} " )
434
+ nd1 = X1 .ndim
435
+ nd2 = X2 .ndim
436
+ nd3 = X3 .ndim
437
+ if X1 .shape == res_shape and X2 .shape == res_shape and len (res_shape ) > nd3 :
438
+ return _empty_like_pair_orderK (X1 , X2 , dt , res_shape , usm_type , dev )
439
+ elif (
440
+ X2 .shape == res_shape and X3 .shape == res_shape and len (res_shape ) > nd1
441
+ ):
442
+ return _empty_like_pair_orderK (X2 , X3 , dt , res_shape , usm_type , dev )
443
+ elif (
444
+ X1 .shape == res_shape and X3 .shape == res_shape and len (res_shape ) > nd2
445
+ ):
446
+ return _empty_like_pair_orderK (X1 , X3 , dt , res_shape , usm_type , dev )
447
+ fl1 = X1 .flags
448
+ fl2 = X2 .flags
449
+ fl3 = X3 .flags
450
+ if fl1 ["C" ] or fl2 ["C" ] or fl3 ["C" ]:
451
+ return dpt .empty (
452
+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "C"
453
+ )
454
+ if fl1 ["F" ] and fl2 ["F" ] and fl3 ["F" ]:
455
+ return dpt .empty (
456
+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "F"
457
+ )
458
+ st1 = list (X1 .strides )
459
+ st2 = list (X2 .strides )
460
+ st3 = list (X3 .strides )
461
+ max_ndim = max (nd1 , nd2 , nd3 )
462
+ st1 += [0 ] * (max_ndim - len (st1 ))
463
+ st2 += [0 ] * (max_ndim - len (st2 ))
464
+ st3 += [0 ] * (max_ndim - len (st3 ))
465
+ sh1 = list (X1 .shape ) + [0 ] * (max_ndim - nd1 )
466
+ sh2 = list (X2 .shape ) + [0 ] * (max_ndim - nd2 )
467
+ sh3 = list (X3 .shape ) + [0 ] * (max_ndim - nd3 )
468
+ perm = sorted (
469
+ range (max_ndim ),
470
+ key = lambda d : (
471
+ builtins .abs (st1 [d ]) if sh1 [d ] > 1 else 0 ,
472
+ builtins .abs (st2 [d ]) if sh2 [d ] > 1 else 0 ,
473
+ builtins .abs (st3 [d ]) if sh3 [d ] > 1 else 0 ,
474
+ ),
475
+ reverse = True ,
476
+ )
477
+ inv_perm = sorted (range (max_ndim ), key = lambda i : perm [i ])
478
+ st1_sorted = [st1 [i ] for i in perm ]
479
+ st2_sorted = [st2 [i ] for i in perm ]
480
+ st3_sorted = [st3 [i ] for i in perm ]
481
+ sh = res_shape
482
+ sh_sorted = tuple (sh [i ] for i in perm )
483
+ R = dpt .empty (sh_sorted , dtype = dt , usm_type = usm_type , device = dev , order = "C" )
484
+ if max (min (st1_sorted ), min (st2_sorted ), min (st3_sorted )) < 0 :
485
+ sl = tuple (
486
+ slice (None , None , - 1 )
487
+ if (st1_sorted [i ] < 0 and st2_sorted [i ] < 0 and st3_sorted [i ] < 0 )
488
+ else slice (None , None , None )
489
+ for i in range (nd1 )
490
+ )
491
+ R = R [sl ]
492
+ return dpt .permute_dims (R , inv_perm )
493
+
494
+
420
495
def copy (usm_ary , order = "K" ):
421
496
"""copy(ary, order="K")
422
497
0 commit comments