@@ -149,9 +149,42 @@ def _general_dot(
def contraction_list_from_path(
    subscripts: str, operands: Sequence[TensorLike], path: PATH
):
- """TODO Docstrings
152
+ """
153
+ Generate a list of contraction steps based on the provided einsum path.
154
+
155
+ Code adapted from einsum_opt: https://github.com/dgasmith/opt_einsum/blob/94c62a05d5ebcedd30f59c90b9926de967ed10b5/opt_einsum/contract.py#L369
156
+
157
+ When all shapes are known, the linked einsum_opt implementation is preferred. This implementation is used when
158
+ some or all shapes are not known. As a result, contraction will (always?) be done left-to-right, pushing intermediate
159
+ results to the end of the stack.
153
160
154
- Code adapted from einsum_opt
161
+ Parameters
162
+ ----------
163
+ subscripts: str
164
+ Einsum signature string describing the computation to be performed.
165
+
166
+ operands: Sequence[TensorLike]
167
+ Tensors described by the subscripts.
168
+
169
+ path: tuple[tuple[int] | tuple[int, int]]
170
+ A list of tuples, where each tuple describes the indices of the operands to be contracted, sorted in the order
171
+ they should be contracted.
172
+
173
+ Returns
174
+ -------
175
+ contraction_list: list
176
+ A list of tuples, where each tuple describes a contraction step. Each tuple contains the following elements:
177
+ - contraction_inds: tuple[int]
178
+ The indices of the operands to be contracted
179
+ - idx_removed: str
180
+ The indices of the contracted indices (those removed from the einsum string at this step)
181
+ - einsum_str: str
182
+ The einsum string for the contraction step
183
+ - remaining: None
184
+ The remaining indices. Included to match the output of opt_einsum.contract_path, but not used.
185
+ - do_blas: None
186
+ Whether to use blas to perform this step. Included to match the output of opt_einsum.contract_path,
187
+ but not used.
155
188
"""
    fake_operands = [
        np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
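For reference, the contraction-list structure documented above mirrors what ``opt_einsum.contract_path`` returns through its internal ``einsum_call=True`` mode (the same internal API mentioned later in this diff). A minimal sketch of inspecting that structure, assuming ``opt_einsum`` is installed and using arbitrary shapes:

.. code-block:: python

    import numpy as np
    from opt_einsum import contract_path

    A, B, C = (np.empty(shape) for shape in [(4, 5), (5, 6), (6, 7)])

    # einsum_call=True returns the raw contraction list instead of a printable report
    _, contraction_list = contract_path("ij,jk,kl->il", A, B, C, einsum_call=True)
    for contraction_inds, idx_removed, einsum_str, remaining, do_blas in contraction_list:
        print(contraction_inds, idx_removed, einsum_str)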
@@ -200,9 +233,13 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:

    Code adapted from JAX: https://github.com/google/jax/blob/534d32a24d7e1efdef206188bb11ae48e9097092/jax/_src/numpy/lax_numpy.py#L5283

+    Einsum allows the user to specify a wide range of operations on tensors using the Einstein summation convention.
+    Using this notation, many common linear algebraic operations on higher-order tensors can be described succinctly.
+
    Parameters
    ----------
    subscripts: str
+        Einsum signature string describing the computation to be performed.

    operands: sequence of TensorVariable
        Tensors to be multiplied and summed.
@@ -211,7 +248,110 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
    -------
    TensorVariable
        The result of the einsum operation.
+
+    See Also
+    --------
+    pytensor.tensor.tensordot: Generalized dot product between two tensors
+    pytensor.tensor.dot: Matrix multiplication between two tensors
+    numpy.einsum: The numpy implementation of einsum
+
+    Examples
+    --------
+    Inputs to `pt.einsum` are a string describing the operation to be performed (the "subscripts"), and a sequence of
+    tensors to be operated on. The string must follow these rules:
+
+    1. The string gives inputs and (optionally) outputs. Inputs and outputs are separated by "->".
+    2. The input side of the string is a comma-separated list of indices. For each comma-separated index string, there
+       must be a corresponding tensor in the input sequence.
+    3. For each index string, the number of dimensions in the corresponding tensor must match the number of characters
+       in the index string.
+    4. Indices are arbitrary characters. If an index appears multiple times on the input side, the corresponding
+       dimensions must have the same shape in each input.
+    5. The indices on the output side must be a subset of the indices on the input side -- you cannot introduce new
+       indices in the output.
+    6. Ellipses ("...") can be used to elide multiple indices. This is useful when you have a large number of "batch"
+       dimensions that are not implicated in the operation.
+
+    Finally, two rules about these indices govern how computation is carried out:
+
+    1. Repeated indices on the input side indicate how the tensors should be "aligned" for multiplication.
+    2. Indices that appear on the input side but not the output side are summed over.
+
+    The operation of these rules is best understood via examples:
+
+    Example 1: Matrix multiplication
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.matrix("A")
+        B = pt.matrix("B")
+        C = pt.einsum("ij, jk -> ik", A, B)
+
+    This computation is equivalent to :code:`C = A @ B`. Notice that the ``j`` index is repeated on the input side of
+    the signature, and does not appear on the output side. This indicates that the ``j`` dimension of the first tensor
+    should be multiplied with the ``j`` dimension of the second tensor, and the resulting tensor's ``j`` dimension
+    should be summed away.
296
+ Example 2: Batched matrix multiplication
297
+
298
+ .. code-block:: python
299
+
300
+ import pytensor as pt
301
+ A = pt.tensor("A", shape=(None, 4, 5))
302
+ B = pt.tensor("B", shape=(None, 5, 6))
303
+ C = pt.einsum("bij, bjk -> bik", A, B)
304
+
305
+ This computation is also equivalent to :code:`C = A @ B` because of Pytensor's built-in broadcasting rules, but
306
+ the einsum signature is more explicit about the batch dimensions. The ``b`` and ``j`` indices are repeated on the
307
+ input side. Unlike ``j``, the ``b`` index is also present on the output side, indicating that the batch dimension
308
+ should **not** be summed away. As a result, multiplication will be performed over the ``b, j`` dimensions, and then
309
+ the ``j`` dimension will be summed over. The resulting tensor will have shape ``(None, 4, 6)``.
+
+    Example 3: Batched matrix multiplication with ellipses
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.tensor("A", shape=(4, None, None, None, 5))
+        B = pt.tensor("B", shape=(5, None, None, None, 6))
+        C = pt.einsum("i...j, j...k -> ...ik", A, B)
+
+    This case is the same as above, but inputs ``A`` and ``B`` have multiple batch dimensions. To avoid writing out all
+    of the batch dimensions (which we do not care about), we can use ellipses to elide over these dimensions. Notice
+    also that we are not required to "sort" the input dimensions in any way. In this example, we are doing a dot
+    product between the last dimension of ``A`` and the first dimension of ``B``, which is perfectly valid.
+
+    Example 4: Outer product
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        x = pt.tensor("x", shape=(3,))
+        y = pt.tensor("y", shape=(4,))
+        z = pt.einsum("i, j -> ij", x, y)
+
+    This computation is equivalent to :code:`pt.outer(x, y)`. Notice that no indices are repeated on the input side,
+    and the output side has two indices. Since there are no indices to align on, the einsum operation will simply
+    multiply the two tensors elementwise, broadcasting dimensions ``i`` and ``j``.
+
+    Example 5: Convolution
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        x = pt.tensor("x", shape=(None, None, None, None, None, None))
+        w = pt.tensor("w", shape=(None, None, None, None))
+        y = pt.einsum("bchwkt,fckt->bfhw", x, w)
+
+    Given a batch of image patches ``x`` with dimensions ``(batch, channel, height, width, kernel_height, kernel_width)``
+    and a filter bank ``w`` with dimensions ``(num_filters, channel, kernel_height, kernel_width)``, this einsum
+    operation computes the convolution of ``x`` with ``w``. Multiplication is aligned on the channel, kernel_height,
+    and kernel_width dimensions, which are then summed over. The resulting tensor has shape
+    ``(batch, num_filters, height, width)``, reflecting the fact that information from each channel has been mixed
+    together.
    """
+
    # TODO: Is this doing something clever about unknown shapes?
    # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
    # using einsum_call=True here is an internal api for opt_einsum... sorry
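The docstring examples above only build symbolic graphs; for completeness, a minimal end-to-end check of the matrix-multiplication case (an illustrative sketch, not part of this diff) would compile and evaluate the graph with ``pytensor.function``:

.. code-block:: python

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    A = pt.matrix("A")
    B = pt.matrix("B")
    C = pt.einsum("ij, jk -> ik", A, B)

    # Compile the symbolic graph and check it against plain matrix multiplication
    f = pytensor.function([A, B], C)
    rng = np.random.default_rng(0)
    a, b = rng.normal(size=(3, 4)), rng.normal(size=(4, 5))
    np.testing.assert_allclose(f(a, b), a @ b)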
@@ -224,21 +364,24 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
    shapes = [operand.type.shape for operand in operands]

    if None in itertools.chain.from_iterable(shapes):
-        # We mark optimize = False, even in cases where there is no ordering optimization to be done
-        # because the inner graph may have to accommodate dynamic shapes.
-        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        # Case 1: At least one of the operands has an unknown shape. In this case, we can't use opt_einsum to optimize
+        # the contraction order, so we just use a default path of (1,0) contractions. This will work left-to-right,
+        # pushing intermediate results to the end of the stack.
+        # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will
+        # match more often.
+
+        # If shapes become known later we will likely want to rebuild the Op (unless we inline it)
        if len(operands) == 1:
            path = [(0,)]
        else:
-            # Create default path of repeating (1,0) that executes left to right cyclically
-            # with intermediate outputs being pushed to the end of the stack
-            # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
            path = [(1, 0) for i in range(len(operands) - 1)]
        contraction_list = contraction_list_from_path(subscripts, operands, path)
-        optimize = (
-            len(operands) <= 2
-        )  # If there are only 1 or 2 operands, there is no optimization to be done?
+
+        # If there are only 1 or 2 operands, there is no optimization to be done?
+        optimize = len(operands) <= 2
    else:
+        # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
+        # contraction order.
        _, contraction_list = contract_path(
            subscripts,
            *shapes,
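For intuition about the default path (an illustrative aside, not part of this diff): NumPy's ``einsum`` accepts an explicit contraction path in the same ``(1, 0)`` tuple format, so the behaviour of the fallback path can be checked directly against an ordinary chained matrix product:

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 5))
    B = rng.normal(size=(5, 6))
    C = rng.normal(size=(6, 7))

    # The fallback path used above when shapes are unknown: repeat (1, 0) so each
    # step contracts the two operands at the front of the stack and pushes the
    # intermediate result to the end.
    default_path = [(1, 0) for _ in range(3 - 1)]

    # NumPy accepts an explicit contraction path of this form via `optimize`
    result = np.einsum("ij,jk,kl->il", A, B, C, optimize=["einsum_path", *default_path])
    np.testing.assert_allclose(result, A @ B @ C)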
@@ -253,6 +396,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
    def sum_uniques(
        operand: TensorVariable, names: str, uniques: list[str]
    ) -> tuple[TensorVariable, str]:
+        """Reduce unique indices (those that appear only once) in a given contraction step via summing."""
        if uniques:
            axes = [names.index(name) for name in uniques]
            operand = operand.sum(axes)
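As a concrete illustration of the reduction ``sum_uniques`` performs (a sketch against NumPy semantics, not part of this diff): an index that appears in only one operand and nowhere in the output can simply be summed away up front.

.. code-block:: python

    import numpy as np

    x = np.arange(24.0).reshape(2, 3, 4)

    # "j" appears only once and is absent from the output, so einsum just sums over it
    np.testing.assert_allclose(np.einsum("ijk->ik", x), x.sum(axis=1))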
@@ -265,6 +409,8 @@ def sum_repeats(
        counts: collections.Counter,
        keep_names: str,
    ) -> tuple[TensorVariable, str]:
+        """Reduce repeated indices in a given contraction step via summation against an identity matrix."""
+
        for name, count in counts.items():
            if count > 1:
                axes = [i for i, n in enumerate(names) if n == name]
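Similarly, the identity-matrix trick that the ``sum_repeats`` docstring refers to can be illustrated with NumPy (a sketch, not part of this diff): a repeated index within a single operand selects the diagonal, which is equivalent to multiplying by an identity matrix and summing.

.. code-block:: python

    import numpy as np

    x = np.arange(9.0).reshape(3, 3)

    # "ii->" is the trace; multiplying by an identity matrix and summing gives the same result
    np.testing.assert_allclose(np.einsum("ii->", x), (x * np.eye(3)).sum())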