
Commit 73e10e0

jessegrabowski authored and ricardoV94 committed
Add docstrings
1 parent 0541aaa commit 73e10e0

File tree

1 file changed: +157 -11 lines changed


pytensor/tensor/einsum.py

Lines changed: 157 additions & 11 deletions
@@ -149,9 +149,42 @@ def _general_dot(
 def contraction_list_from_path(
     subscripts: str, operands: Sequence[TensorLike], path: PATH
 ):
-    """TODO Docstrings
+    """
+    Generate a list of contraction steps based on the provided einsum path.
+
+    Code adapted from einsum_opt: https://github.com/dgasmith/opt_einsum/blob/94c62a05d5ebcedd30f59c90b9926de967ed10b5/opt_einsum/contract.py#L369
+
+    When all shapes are known, the linked einsum_opt implementation is preferred. This implementation is used when
+    some or all shapes are not known. As a result, contraction will (always?) be done left-to-right, pushing intermediate
+    results to the end of the stack.
 
-    Code adapted from einsum_opt
+    Parameters
+    ----------
+    subscripts: str
+        Einsum signature string describing the computation to be performed.
+
+    operands: Sequence[TensorLike]
+        Tensors described by the subscripts.
+
+    path: tuple[tuple[int] | tuple[int, int]]
+        A list of tuples, where each tuple describes the indices of the operands to be contracted, sorted in the order
+        they should be contracted.
+
+    Returns
+    -------
+    contraction_list: list
+        A list of tuples, where each tuple describes a contraction step. Each tuple contains the following elements:
+        - contraction_inds: tuple[int]
+            The indices of the operands to be contracted
+        - idx_removed: str
+            The contracted indices (those removed from the einsum string at this step)
+        - einsum_str: str
+            The einsum string for the contraction step
+        - remaining: None
+            The remaining indices. Included to match the output of opt_einsum.contract_path, but not used.
+        - do_blas: None
+            Whether to use BLAS to perform this step. Included to match the output of opt_einsum.contract_path,
+            but not used.
     """
     fake_operands = [
         np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
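A minimal usage sketch of the new docstring's description (not part of this commit). It assumes ``contraction_list_from_path`` remains importable from ``pytensor.tensor.einsum`` and that each step is the 5-tuple documented above; the operand names and signature are illustrative only.

.. code-block:: python

    # Build the default left-to-right path used for operands with unknown shapes
    # and inspect the contraction steps it produces.
    import pytensor.tensor as pt
    from pytensor.tensor.einsum import contraction_list_from_path

    operands = [pt.tensor(name, shape=(None, None)) for name in "ABC"]
    path = [(1, 0) for _ in range(len(operands) - 1)]  # contract pairwise, left to right

    contraction_list = contraction_list_from_path("ij,jk,kl->il", operands, path)
    for contraction_inds, idx_removed, einsum_str, remaining, do_blas in contraction_list:
        print(contraction_inds, idx_removed, einsum_str)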
@@ -200,9 +233,13 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
 
     Code adapted from JAX: https://github.com/google/jax/blob/534d32a24d7e1efdef206188bb11ae48e9097092/jax/_src/numpy/lax_numpy.py#L5283
 
+    Einsum allows the user to specify a wide range of operations on tensors using the Einstein summation convention. Using
+    this notation, many common linear algebraic operations can be succinctly described on higher order tensors.
+
     Parameters
     ----------
     subscripts: str
+        Einsum signature string describing the computation to be performed.
 
     operands: sequence of TensorVariable
         Tensors to be multiplied and summed.
@@ -211,7 +248,110 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     -------
     TensorVariable
         The result of the einsum operation.
+
+    See Also
+    --------
+    pytensor.tensor.tensordot: Generalized dot product between two tensors
+    pytensor.tensor.dot: Matrix multiplication between two tensors
+    numpy.einsum: The numpy implementation of einsum
+
+    Examples
+    --------
+    Inputs to `pt.einsum` are a string describing the operation to be performed (the "subscripts"), and a sequence of
+    tensors to be operated on. The string must conform to the following rules:
+
+    1. The string gives inputs and (optionally) outputs. Inputs and outputs are separated by "->".
+    2. The input side of the string is a comma-separated list of indices. For each comma-separated index string, there
+       must be a corresponding tensor in the input sequence.
+    3. For each index string, the number of dimensions in the corresponding tensor must match the number of characters
+       in the index string.
+    4. Indices are arbitrary strings of characters. If an index appears multiple times on the input side, it must have
+       the same shape in each input.
+    5. The indices on the output side must be a subset of the indices on the input side -- you cannot introduce new
+       indices in the output.
+    6. Ellipses ("...") can be used to elide multiple indices. This is useful when you have a large number of "batch"
+       dimensions that are not implicated in the operation.
+
+    Finally, two rules about these indices govern how computation is carried out:
+
+    1. Repeated indices on the input side indicate how the tensor should be "aligned" for multiplication.
+    2. Indices that appear on the input side but not the output side are summed over.
+
+    The operation of these rules is best understood via examples:
+
+    Example 1: Matrix multiplication
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.matrix("A")
+        B = pt.matrix("B")
+        C = pt.einsum("ij, jk -> ik", A, B)
+
+    This computation is equivalent to :code:`C = A @ B`. Notice that the ``j`` index is repeated on the input side of the
+    signature, and does not appear on the output side. This indicates that the ``j`` dimension of the first tensor should be
+    multiplied with the ``j`` dimension of the second tensor, and the resulting tensor's ``j`` dimension should be summed
+    away.
+
+    Example 2: Batched matrix multiplication
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.tensor("A", shape=(None, 4, 5))
+        B = pt.tensor("B", shape=(None, 5, 6))
+        C = pt.einsum("bij, bjk -> bik", A, B)
+
+    This computation is also equivalent to :code:`C = A @ B` because of PyTensor's built-in broadcasting rules, but
+    the einsum signature is more explicit about the batch dimensions. The ``b`` and ``j`` indices are repeated on the
+    input side. Unlike ``j``, the ``b`` index is also present on the output side, indicating that the batch dimension
+    should **not** be summed away. As a result, multiplication will be performed over the ``b, j`` dimensions, and then
+    the ``j`` dimension will be summed over. The resulting tensor will have shape ``(None, 4, 6)``.
+
+    Example 3: Batched matrix multiplication with ellipses
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.tensor("A", shape=(4, None, None, None, 5))
+        B = pt.tensor("B", shape=(5, None, None, None, 6))
+        C = pt.einsum("i...j, j...k -> ...ik", A, B)
+
+    This case is the same as above, but inputs ``A`` and ``B`` have multiple batch dimensions. To avoid writing out all
+    of the batch dimensions (which we do not care about), we can use ellipses to elide over these dimensions. Notice
+    also that we are not required to "sort" the input dimensions in any way. In this example, we are taking a dot
+    product between the last dimension of ``A`` and the first dimension of ``B``, which is perfectly valid.
+
+    Example 4: Outer product
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        x = pt.tensor("x", shape=(3,))
+        y = pt.tensor("y", shape=(4,))
+        z = pt.einsum("i, j -> ij", x, y)
+
+    This computation is equivalent to :code:`pt.outer(x, y)`. Notice that no indices are repeated on the input side,
+    and the output side has two indices. Since there are no indices to align on, the einsum operation will simply
+    multiply the two tensors elementwise, broadcasting dimensions ``i`` and ``j``.
+
+    Example 5: Convolution
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        x = pt.tensor("x", shape=(None, None, None, None, None, None))
+        w = pt.tensor("w", shape=(None, None, None, None))
+        y = pt.einsum("bchwkt,fckt->bfhw", x, w)
+
+    Given a batch of images ``x`` with dimensions ``(batch, channel, height, width, kernel_size, num_filters)``
+    and a filter ``w``, with dimensions ``(num_filters, channels, kernel_size, num_filters)``, this einsum operation
+    computes the convolution of ``x`` with ``w``. Multiplication is aligned on the batch, num_filters, height, and width
+    dimensions. The channel, kernel_size, and num_filters dimensions are summed over. The resulting tensor has shape
+    ``(batch, num_filters, height, width)``, reflecting the fact that information from each channel has been mixed
+    together.
     """
+
     # TODO: Is this doing something clever about unknown shapes?
     # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
     # using einsum_call=True here is an internal api for opt_einsum... sorry
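A quick way to sanity-check the docstring examples above is to compile the graph and compare against ``numpy.einsum``. A minimal sketch (not part of this commit), assuming the usual ``import pytensor.tensor as pt`` convention:

.. code-block:: python

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    # Example 1 from the docstring: matrix multiplication
    A = pt.matrix("A")
    B = pt.matrix("B")
    C = pt.einsum("ij, jk -> ik", A, B)

    # Compile the symbolic graph and verify the result numerically
    f = pytensor.function([A, B], C)
    a = np.random.normal(size=(3, 4))
    b = np.random.normal(size=(4, 5))
    np.testing.assert_allclose(f(a, b), np.einsum("ij, jk -> ik", a, b))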
@@ -224,21 +364,24 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     shapes = [operand.type.shape for operand in operands]
 
     if None in itertools.chain.from_iterable(shapes):
-        # We mark optimize = False, even in cases where there is no ordering optimization to be done
-        # because the inner graph may have to accommodate dynamic shapes.
-        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        # Case 1: At least one of the operands has an unknown shape. In this case, we can't use opt_einsum to optimize
+        # the contraction order, so we just use a default path of (1,0) contractions. This will work left-to-right,
+        # pushing intermediate results to the end of the stack.
+        # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will
+        # match more often
+
+        # If shapes become known later we will likely want to rebuild the Op (unless we inline it)
         if len(operands) == 1:
             path = [(0,)]
         else:
-            # Create default path of repeating (1,0) that executes left to right cyclically
-            # with intermediate outputs being pushed to the end of the stack
-            # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
             path = [(1, 0) for i in range(len(operands) - 1)]
         contraction_list = contraction_list_from_path(subscripts, operands, path)
-        optimize = (
-            len(operands) <= 2
-        )  # If there are only 1 or 2 operands, there is no optimization to be done?
+
+        # If there are only 1 or 2 operands, there is no optimization to be done?
+        optimize = len(operands) <= 2
     else:
+        # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
+        # contraction order.
        _, contraction_list = contract_path(
            subscripts,
            *shapes,
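The Case 2 branch above defers the contraction ordering to opt_einsum. A rough standalone illustration of that call (an assumption-laden sketch, not the commit's exact invocation; ``shapes=True`` tells opt_einsum the positional arguments are shapes, and ``einsum_call=True`` is the internal API mentioned in the comments above):

.. code-block:: python

    from opt_einsum import contract_path

    shapes = [(4, 5), (5, 6), (6, 7)]
    _, contraction_list = contract_path(
        "ij,jk,kl->il",
        *shapes,
        shapes=True,        # positional arguments are shapes, not arrays
        einsum_call=True,   # return the raw contraction list instead of a summary
        optimize="optimal",
    )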
@@ -253,6 +396,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
     ) -> tuple[TensorVariable, str]:
+        """Reduce unique indices (those that appear only once) in a given contraction step via summing."""
         if uniques:
             axes = [names.index(name) for name in uniques]
             operand = operand.sum(axes)
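A plain-numpy illustration of the idea behind ``sum_uniques`` (a conceptual sketch, not the code above): an index that appears in only one operand and not in the output can be summed away before the pairwise contraction.

.. code-block:: python

    import numpy as np

    x = np.random.normal(size=(3, 4, 5))
    v = np.ones(5)

    # In "ijk,k->i", the index "j" is unique to x, so it can be summed out first.
    expected = np.einsum("ijk,k->i", x, v)
    np.testing.assert_allclose(x.sum(axis=1) @ v, expected)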
@@ -265,6 +409,8 @@ def sum_repeats(
         counts: collections.Counter,
         keep_names: str,
     ) -> tuple[TensorVariable, str]:
+        """Reduce repeated indices in a given contraction step via summation against an identity matrix."""
+
         for name, count in counts.items():
             if count > 1:
                 axes = [i for i, n in enumerate(names) if n == name]
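A small numpy illustration of the "summation against an identity matrix" trick named in the ``sum_repeats`` docstring (a conceptual sketch only, not the implementation): a repeated index such as ``ii`` can be reduced by masking with an identity matrix and summing one of the duplicated axes.

.. code-block:: python

    import numpy as np

    x = np.random.normal(size=(4, 4))
    eye = np.eye(4)

    # "ii->i" (extract the diagonal) expressed as an identity mask plus a sum
    np.testing.assert_allclose((x * eye).sum(axis=1), np.einsum("ii->i", x))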
