Skip to content

Commit a7163f9

Browse files
refine more code
1 parent bb40851 commit a7163f9

File tree

3 files changed

+4
-6
lines changed

3 files changed

+4
-6
lines changed

array_api_compat/paddle/_info.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ def _dtypes(self, kind):
170170
int32 = paddle.int32
171171
int64 = paddle.int64
172172
uint8 = paddle.uint8
173-
# uint16, uint32, and uint64 are present in newer versions of pytorch,
174-
# but they aren't generally supported by the array API functions, so
173+
# uint16, uint32, and uint64 are not fully supported in paddle,
175174
# we omit them from this function.
176175
float32 = paddle.float32
177176
float64 = paddle.float64

array_api_compat/paddle/linalg.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@
2828
from ._aliases import matmul, matrix_transpose, tensordot
2929

3030
# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
31-
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
32-
31+
# first axis with size 3)
3332

3433
# paddle.cross also does not support broadcasting when it would add new
35-
# dimensions https://github.com/pytorch/pytorch/issues/39656
34+
# dimensions
3635
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
3736
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
3837
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):

tests/test_array_namespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
import warnings
44

5-
# import jax
5+
import jax
66
import numpy as np
77
import pytest
88
import torch

0 commit comments

Comments
 (0)