@@ -2380,12 +2380,6 @@ class TestQr:
2380
2380
)
2381
2381
@pytest .mark .parametrize ("mode" , ["r" , "raw" , "complete" , "reduced" ])
2382
2382
def test_qr (self , dtype , shape , mode ):
2383
- if (
2384
- is_cuda_device ()
2385
- and mode in ["complete" , "reduced" ]
2386
- and shape in [(16 , 16 ), (2 , 2 , 4 )]
2387
- ):
2388
- pytest .skip ("SAT-7589" )
2389
2383
a = generate_random_numpy_array (shape , dtype , seed_value = 81 )
2390
2384
ia = dpnp .array (a )
2391
2385
@@ -2398,24 +2392,48 @@ def test_qr(self, dtype, shape, mode):
2398
2392
2399
2393
# check decomposition
2400
2394
if mode in ("complete" , "reduced" ):
2401
- if a .ndim == 2 :
2402
- assert_almost_equal (
2403
- dpnp .dot (dpnp_q , dpnp_r ),
2404
- a ,
2405
- decimal = 5 ,
2406
- )
2407
- else : # a.ndim > 2
2408
- assert_almost_equal (
2409
- dpnp .matmul (dpnp_q , dpnp_r ),
2410
- a ,
2411
- decimal = 5 ,
2412
- )
2395
+ assert_almost_equal (
2396
+ dpnp .matmul (dpnp_q , dpnp_r ),
2397
+ a ,
2398
+ decimal = 5 ,
2399
+ )
2413
2400
else : # mode=="raw"
2414
2401
assert_dtype_allclose (dpnp_q , np_q )
2415
2402
2416
2403
if mode in ("raw" , "r" ):
2417
2404
assert_dtype_allclose (dpnp_r , np_r )
2418
2405
2406
+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
2407
+ @pytest .mark .parametrize (
2408
+ "shape" ,
2409
+ [(32 , 32 ), (8 , 16 , 16 )],
2410
+ ids = [
2411
+ "(32, 32)" ,
2412
+ "(8, 16, 16)" ,
2413
+ ],
2414
+ )
2415
+ @pytest .mark .parametrize ("mode" , ["r" , "raw" , "complete" , "reduced" ])
2416
+ def test_qr_large (self , dtype , shape , mode ):
2417
+ a = generate_random_numpy_array (shape , dtype , seed_value = 81 )
2418
+ ia = dpnp .array (a )
2419
+ if mode == "r" :
2420
+ np_r = numpy .linalg .qr (a , mode )
2421
+ dpnp_r = dpnp .linalg .qr (ia , mode )
2422
+ else :
2423
+ np_q , np_r = numpy .linalg .qr (a , mode )
2424
+ dpnp_q , dpnp_r = dpnp .linalg .qr (ia , mode )
2425
+ # check decomposition
2426
+ if mode in ("complete" , "reduced" ):
2427
+ assert_almost_equal (
2428
+ dpnp .matmul (dpnp_q , dpnp_r ),
2429
+ a ,
2430
+ decimal = 5 ,
2431
+ )
2432
+ else : # mode=="raw"
2433
+ assert_dtype_allclose (dpnp_q , np_q , factor = 12 )
2434
+ if mode in ("raw" , "r" ):
2435
+ assert_dtype_allclose (dpnp_r , np_r , factor = 12 )
2436
+
2419
2437
@pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
2420
2438
@pytest .mark .parametrize (
2421
2439
"shape" ,
0 commit comments