14
14
from .helper import (
15
15
assert_dtype_allclose ,
16
16
get_all_dtypes ,
17
- get_complex_dtypes ,
17
+ get_float_complex_dtypes ,
18
18
has_support_aspect64 ,
19
19
is_cpu_device ,
20
20
)
@@ -678,6 +678,11 @@ def test_norm3(array, ord, axis):
678
678
679
679
680
680
class TestQr :
681
+ # Set numpy.random.seed for test methods to prevent
682
+ # random generation of the input singular matrix
683
+ def setup_method (self ):
684
+ numpy .random .seed (81 )
685
+
681
686
# TODO: New packages that fix issue CMPLRLLVM-53771 are only available in internal CI.
682
687
# Skip the tests on cpu until these packages are available for the external CI.
683
688
# Specifically dpcpp_linux-64>=2024.1.0
@@ -702,7 +707,9 @@ class TestQr:
702
707
ids = ["r" , "raw" , "complete" , "reduced" ],
703
708
)
704
709
def test_qr (self , dtype , shape , mode ):
705
- a = numpy .random .rand (* shape ).astype (dtype )
710
+ a = numpy .random .randn (* shape ).astype (dtype )
711
+ if numpy .issubdtype (dtype , numpy .complexfloating ):
712
+ a += 1j * numpy .random .randn (* shape )
706
713
ia = inp .array (a )
707
714
708
715
if mode == "r" :
@@ -772,7 +779,7 @@ def test_qr_empty(self, dtype, shape, mode):
772
779
ids = ["r" , "raw" , "complete" , "reduced" ],
773
780
)
774
781
def test_qr_strides (self , mode ):
775
- a = numpy .random .rand (5 , 5 )
782
+ a = numpy .random .randn (5 , 5 )
776
783
ia = inp .array (a )
777
784
778
785
# positive strides
@@ -1032,6 +1039,11 @@ def test_slogdet_errors(self):
1032
1039
1033
1040
1034
1041
class TestSvd :
1042
+ # Set numpy.random.seed for test methods to prevent
1043
+ # random generation of the input singular matrix
1044
+ def setup_method (self ):
1045
+ numpy .random .seed (81 )
1046
+
1035
1047
def get_tol (self , dtype ):
1036
1048
tol = 1e-06
1037
1049
if dtype in (inp .float32 , inp .complex64 ):
@@ -1121,18 +1133,19 @@ def test_svd(self, dtype, shape):
1121
1133
dp_a , dp_u , dp_s , dp_vt , np_u , np_s , np_vt , True
1122
1134
)
1123
1135
1124
- @pytest .mark .parametrize ("dtype" , get_complex_dtypes ())
1136
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
1125
1137
@pytest .mark .parametrize ("compute_vt" , [True , False ], ids = ["True" , "False" ])
1126
1138
@pytest .mark .parametrize (
1127
1139
"shape" ,
1128
1140
[(2 , 2 ), (16 , 16 )],
1129
- ids = ["(2,2)" , "(16, 16)" ],
1141
+ ids = ["(2, 2)" , "(16, 16)" ],
1130
1142
)
1131
1143
def test_svd_hermitian (self , dtype , compute_vt , shape ):
1132
- a = numpy .random .randn (* shape ) + 1j * numpy .random .randn (* shape )
1133
- a = numpy .conj (a .T ) @ a
1144
+ a = numpy .random .randn (* shape ).astype (dtype )
1145
+ if numpy .issubdtype (dtype , numpy .complexfloating ):
1146
+ a += 1j * numpy .random .randn (* shape )
1147
+ a = (a + a .conj ().T ) / 2
1134
1148
1135
- a = a .astype (dtype )
1136
1149
dp_a = inp .array (a )
1137
1150
1138
1151
if compute_vt :
@@ -1167,3 +1180,155 @@ def test_svd_errors(self):
1167
1180
# a.ndim < 2
1168
1181
a_dp_ndim_1 = a_dp .flatten ()
1169
1182
assert_raises (inp .linalg .LinAlgError , inp .linalg .svd , a_dp_ndim_1 )
1183
+
1184
+
1185
+ class TestPinv :
1186
+ # Set numpy.random.seed for test methods to prevent
1187
+ # random generation of the input singular matrix
1188
+ def setup_method (self ):
1189
+ numpy .random .seed (81 )
1190
+
1191
+ def get_tol (self , dtype ):
1192
+ tol = 1e-06
1193
+ if dtype in (inp .float32 , inp .complex64 ):
1194
+ tol = 1e-04
1195
+ elif not has_support_aspect64 () and dtype in (
1196
+ inp .int32 ,
1197
+ inp .int64 ,
1198
+ None ,
1199
+ ):
1200
+ tol = 1e-04
1201
+ self ._tol = tol
1202
+
1203
+ def check_types_shapes (self , dp_B , np_B ):
1204
+ if has_support_aspect64 ():
1205
+ assert dp_B .dtype == np_B .dtype
1206
+ else :
1207
+ assert dp_B .dtype .kind == np_B .dtype .kind
1208
+
1209
+ assert dp_B .shape == np_B .shape
1210
+
1211
+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
1212
+ @pytest .mark .parametrize (
1213
+ "shape" ,
1214
+ [(2 , 2 ), (3 , 4 ), (5 , 3 ), (16 , 16 ), (2 , 2 , 2 ), (2 , 4 , 2 ), (2 , 2 , 4 )],
1215
+ ids = [
1216
+ "(2, 2)" ,
1217
+ "(3, 4)" ,
1218
+ "(5, 3)" ,
1219
+ "(16, 16)" ,
1220
+ "(2, 2, 2)" ,
1221
+ "(2, 4, 2)" ,
1222
+ "(2, 2, 4)" ,
1223
+ ],
1224
+ )
1225
+ def test_pinv (self , dtype , shape ):
1226
+ a = numpy .random .randn (* shape ).astype (dtype )
1227
+ if numpy .issubdtype (dtype , numpy .complexfloating ):
1228
+ a += 1j * numpy .random .randn (* shape )
1229
+ a_dp = inp .array (a )
1230
+
1231
+ B = numpy .linalg .pinv (a )
1232
+ B_dp = inp .linalg .pinv (a_dp )
1233
+
1234
+ self .check_types_shapes (B_dp , B )
1235
+ self .get_tol (dtype )
1236
+ tol = self ._tol
1237
+ assert_allclose (B_dp , B , rtol = tol , atol = tol )
1238
+
1239
+ if a .ndim == 2 :
1240
+ reconstructed = inp .dot (a_dp , inp .dot (B_dp , a_dp ))
1241
+ else : # a.ndim > 2
1242
+ reconstructed = inp .matmul (a_dp , inp .matmul (B_dp , a_dp ))
1243
+
1244
+ assert_allclose (reconstructed , a_dp , rtol = tol , atol = tol )
1245
+
1246
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
1247
+ @pytest .mark .parametrize (
1248
+ "shape" ,
1249
+ [(2 , 2 ), (16 , 16 )],
1250
+ ids = ["(2, 2)" , "(16, 16)" ],
1251
+ )
1252
+ def test_pinv_hermitian (self , dtype , shape ):
1253
+ a = numpy .random .randn (* shape ).astype (dtype )
1254
+ if numpy .issubdtype (dtype , numpy .complexfloating ):
1255
+ a += 1j * numpy .random .randn (* shape )
1256
+ a = (a + a .conj ().T ) / 2
1257
+
1258
+ a_dp = inp .array (a )
1259
+
1260
+ B = numpy .linalg .pinv (a , hermitian = True )
1261
+ B_dp = inp .linalg .pinv (a_dp , hermitian = True )
1262
+
1263
+ self .check_types_shapes (B_dp , B )
1264
+ self .get_tol (dtype )
1265
+ tol = self ._tol
1266
+
1267
+ reconstructed = inp .dot (inp .dot (a_dp , B_dp ), a_dp )
1268
+ assert_allclose (reconstructed , a_dp , rtol = tol , atol = tol )
1269
+
1270
+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
1271
+ @pytest .mark .parametrize (
1272
+ "shape" ,
1273
+ [(0 , 0 ), (0 , 2 ), (2 , 0 ), (2 , 0 , 3 ), (2 , 3 , 0 ), (0 , 2 , 3 )],
1274
+ ids = [
1275
+ "(0, 0)" ,
1276
+ "(0, 2)" ,
1277
+ "(2 ,0)" ,
1278
+ "(2, 0, 3)" ,
1279
+ "(2, 3, 0)" ,
1280
+ "(0, 2, 3)" ,
1281
+ ],
1282
+ )
1283
+ def test_pinv_empty (self , dtype , shape ):
1284
+ a = numpy .empty (shape , dtype = dtype )
1285
+ a_dp = inp .array (a )
1286
+
1287
+ B = numpy .linalg .pinv (a )
1288
+ B_dp = inp .linalg .pinv (a_dp )
1289
+
1290
+ assert_dtype_allclose (B_dp , B )
1291
+
1292
+ def test_pinv_strides (self ):
1293
+ a = numpy .random .randn (5 , 5 )
1294
+ a_dp = inp .array (a )
1295
+
1296
+ self .get_tol (a_dp .dtype )
1297
+ tol = self ._tol
1298
+
1299
+ # positive strides
1300
+ B = numpy .linalg .pinv (a [::2 , ::2 ])
1301
+ B_dp = inp .linalg .pinv (a_dp [::2 , ::2 ])
1302
+ assert_allclose (B_dp , B , rtol = tol , atol = tol )
1303
+
1304
+ # negative strides
1305
+ B = numpy .linalg .pinv (a [::- 2 , ::- 2 ])
1306
+ B_dp = inp .linalg .pinv (a_dp [::- 2 , ::- 2 ])
1307
+ assert_allclose (B_dp , B , rtol = tol , atol = tol )
1308
+
1309
+ def test_pinv_errors (self ):
1310
+ a_dp = inp .array ([[1 , 2 ], [3 , 4 ]], dtype = "float32" )
1311
+
1312
+ # unsupported type `a`
1313
+ a_np = inp .asnumpy (a_dp )
1314
+ assert_raises (TypeError , inp .linalg .pinv , a_np )
1315
+
1316
+ # unsupported type `rcond`
1317
+ rcond = numpy .array (0.5 , dtype = "float32" )
1318
+ assert_raises (TypeError , inp .linalg .pinv , a_dp , rcond )
1319
+ assert_raises (TypeError , inp .linalg .pinv , a_dp , [0.5 ])
1320
+
1321
+ # non-broadcastable `rcond`
1322
+ rcond_dp = inp .array ([0.5 ], dtype = "float32" )
1323
+ assert_raises (ValueError , inp .linalg .pinv , a_dp , rcond_dp )
1324
+
1325
+ # a.ndim < 2
1326
+ a_dp_ndim_1 = a_dp .flatten ()
1327
+ assert_raises (inp .linalg .LinAlgError , inp .linalg .pinv , a_dp_ndim_1 )
1328
+
1329
+ # diffetent queue
1330
+ a_queue = dpctl .SyclQueue ()
1331
+ rcond_queue = dpctl .SyclQueue ()
1332
+ a_dp_q = inp .array (a_dp , sycl_queue = a_queue )
1333
+ rcond_dp_q = inp .array ([0.5 ], dtype = "float32" , sycl_queue = rcond_queue )
1334
+ assert_raises (ValueError , inp .linalg .pinv , a_dp_q , rcond_dp_q )
0 commit comments