Skip to content

Commit 5ba3cc7

Browse files
First pass on docstrings
1 parent 48c663a commit 5ba3cc7

File tree

1 file changed

+145
-2
lines changed

1 file changed

+145
-2
lines changed

pytensor/tensor/einsum.py

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,117 @@ def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs
6161

6262

6363
def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
64+
"""
65+
Create an array with values increasing along the specified axis.
66+
67+
Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers
68+
increasing along the specified axis.
69+
70+
Parameters
71+
----------
72+
shape: TensorVariable
73+
The shape of the array to be created.
74+
axis: int
75+
The axis along which to fill the array with increasing values.
76+
77+
Returns
78+
-------
79+
TensorVariable
80+
An array with values increasing along the specified axis.
81+
82+
Examples
83+
--------
84+
In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``:
85+
86+
.. testcode::
87+
88+
import pytensor as pt
89+
shape = pt.as_tensor('shape', (5,))
90+
print(pt._iota(shape, 0).eval())
91+
92+
.. testoutput::
93+
94+
[0., 1., 2., 3., 4.]
95+
96+
In higher dimensions, it will look like many concatenated `pt.arange`:
97+
98+
.. testcode::
99+
100+
shape = pt.as_tensor('shape', (5, 5))
101+
print(pt._iota(shape, 1).eval())
102+
103+
.. testoutput::
104+
105+
[[0., 1., 2., 3., 4.],
106+
[0., 1., 2., 3., 4.],
107+
[0., 1., 2., 3., 4.],
108+
[0., 1., 2., 3., 4.],
109+
[0., 1., 2., 3., 4.]]
110+
111+
Setting ``axis=0`` above would result in the transpose of the output.
112+
"""
64113
len_shape = get_vector_length(shape)
65114
axis = normalize_axis_index(axis, len_shape)
66115
values = arange(shape[axis])
67116
return broadcast_to(shape_padright(values, len_shape - axis - 1), shape)
68117

69118

70-
def _delta(shape, axes: Sequence[int]) -> TensorVariable:
71-
"""This utility function exists for creating Kronecker delta arrays."""
119+
def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable:
120+
"""
121+
Create a Kroncker delta tensor.
122+
123+
The Kroncker delta function is defined:
124+
125+
.. math::
126+
127+
\\delta(i, j) = \begin{cases} 1 & \text{if} \\quad i = j \\ 0 & \text{otherwise} \\end{cases}
128+
129+
To create a Kronecker tensor, the delta function is applied elementwise to the axes specified. The result is a
130+
tensor of booleans, with ``True`` where the axis indices coincide, and ``False`` otherwise. See below for examples.
131+
132+
Parameters
133+
----------
134+
shape: TensorVariable
135+
The shape of the tensor to be created. Note that `_delta` is not defined for 1d tensors, because there is no
136+
second axis against which to compare.
137+
axes: sequence of int
138+
Axes whose indices should be compared. Note that `_delta` is not defined for a single axis, because there is no
139+
second axis against which to compare.
140+
141+
Examples
142+
--------
143+
An easy case to understand is when the shape is square and the number of axes is equal to the number of dimensions.
144+
This will result in a generalized identity tensor, with ``True`` along the main diagonal:
145+
146+
.. testcode::
147+
148+
from pytensor.tensor.einsum import _delta
149+
print(_delta((5, 5), (0, 1)).eval())
150+
151+
.. testoutput::
152+
153+
[[ True False False False False]
154+
[False True False False False]
155+
[False False True False False]
156+
[False False False True False]
157+
[False False False False True]]
158+
159+
In the case where the shape is not square, the result will be a tensor with ``True`` along the main diagonal and
160+
``False`` elsewhere:
161+
162+
.. testcode::
163+
164+
from pytensor.tensor.einsum import _delta
165+
print(_delta((3, 2), (0, 1)).eval())
166+
167+
.. testoutput::
168+
169+
[[ True False]
170+
[False True]
171+
[False False]]
172+
"""
173+
if len(axes) == 1:
174+
raise ValueError("Need at least two axes to create a delta tensor")
72175
base_shape = stack([shape[axis] for axis in axes])
73176
iotas = [_iota(base_shape, i) for i in range(len(axes))]
74177
eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
@@ -81,6 +184,46 @@ def _general_dot(
81184
axes: Sequence[Sequence[int]], # Should be length 2,
82185
batch_axes: Sequence[Sequence[int]], # Should be length 2,
83186
) -> TensorVariable:
187+
"""
188+
Generalized dot product between two tensors.
189+
190+
Ultimately ``_general_dot`` is a call to `tensor_dot`, performing a multiply-and-sum ("dot") operation between two
191+
tensors, along a requested dimension. This function further generalizes this operation by allowing arbitrary
192+
batch dimensions to be specified for each tensor.
193+
194+
195+
Parameters
196+
----------
197+
vars: tuple[TensorVariable, TensorVariable]
198+
The tensors to be ``tensor_dot``ed
199+
axes: Sequence[Sequence[int]]
200+
The axes along which to perform the dot product. Should be a sequence of two sequences, one for each tensor.
201+
batch_axes: Sequence[Sequence[int]]
202+
The batch axes for each tensor. Should be a sequence of two sequences, one for each tensor.
203+
204+
Returns
205+
-------
206+
TensorVariable
207+
The result of the ``tensor_dot`` product.
208+
209+
Examples
210+
--------
211+
Perform a batched dot product between two 3d tensors:
212+
213+
.. testcode::
214+
215+
import pytensor.tensor as pt
216+
from pytensor.tensor.einsum import _general_dot
217+
A = pt.tensor(shape = (3, 4, 5))
218+
B = pt.tensor(shape = (3, 5, 2))
219+
220+
result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
221+
print(result.type.shape)
222+
223+
.. testoutput::
224+
225+
(3, 4, 2)
226+
"""
84227
# Shortcut for non batched case
85228
if not batch_axes[0] and not batch_axes[1]:
86229
return tensordot(*vars, axes=axes)

0 commit comments

Comments
 (0)