Skip to content

Commit 2b0d171

Browse files
committed
Implemented Eye Op in PyTorch
- Added support for diagonal offset (param `k`)
1 parent 5c8afae commit 2b0d171

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.link.utils import fgraph_to_python
88
from pytensor.raise_op import CheckAndRaise
9-
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
9+
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye
1010

1111

1212
@singledispatch
@@ -89,3 +89,19 @@ def arange(start, stop, step):
8989
return torch.arange(start, stop, step, dtype=dtype)
9090

9191
return arange
92+
93+
94+
@pytorch_funcify.register(Eye)
95+
def pytorch_funcify_eye(op, **kwargs):
96+
dtype = getattr(torch, op.dtype)
97+
98+
def eye(N, M, k):
99+
mjr, mnr = (M, N) if k > 0 else (N, M)
100+
k_abs = abs(k)
101+
zeros = torch.zeros(N, M, dtype=dtype)
102+
if k_abs < mjr:
103+
l_ones = min(mjr - k_abs, mnr)
104+
return zeros.diagonal_scatter(torch.ones(l_ones, dtype=dtype), k)
105+
return zeros
106+
107+
return eye

tests/link/pytorch/test_basic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.graph.fg import FunctionGraph
1313
from pytensor.graph.op import Op
1414
from pytensor.raise_op import CheckAndRaise
15-
from pytensor.tensor import alloc, arange, as_tensor, empty
15+
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
1616
from pytensor.tensor.type import scalar, vector
1717

1818

@@ -235,3 +235,20 @@ def test_arange():
235235
FunctionGraph([start, stop, step], [out]),
236236
[np.array(1), np.array(10), np.array(2)],
237237
)
238+
239+
240+
def test_eye():
241+
N = scalar("N", dtype="int64")
242+
M = scalar("M", dtype="int64")
243+
k = scalar("k", dtype="int64")
244+
245+
out = eye(N, M, k, dtype="int16")
246+
247+
trange = range(1, 6)
248+
for _N in trange:
249+
for _M in trange:
250+
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
251+
compare_pytorch_and_py(
252+
FunctionGraph([N, M, k], [out]),
253+
[np.array(_N + 1), np.array(_M + 1), np.array(_k)],
254+
)

0 commit comments

Comments
 (0)