@@ -61,14 +61,117 @@ def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs
61
61
62
62
63
63
def _iota (shape : TensorVariable , axis : int ) -> TensorVariable :
64
+ """
65
+ Create an array with values increasing along the specified axis.
66
+
67
+ Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers
68
+ increasing along the specified axis.
69
+
70
+ Parameters
71
+ ----------
72
+ shape: TensorVariable
73
+ The shape of the array to be created.
74
+ axis: int
75
+ The axis along which to fill the array with increasing values.
76
+
77
+ Returns
78
+ -------
79
+ TensorVariable
80
+ An array with values increasing along the specified axis.
81
+
82
+ Examples
83
+ --------
84
+ In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``:
85
+
86
+ .. testcode::
87
+
88
+ import pytensor as pt
89
+ shape = pt.as_tensor('shape', (5,))
90
+ print(pt._iota(shape, 0).eval())
91
+
92
+ .. testoutput::
93
+
94
+ [0., 1., 2., 3., 4.]
95
+
96
+ In higher dimensions, it will look like many concatenated `pt.arange`:
97
+
98
+ .. testcode::
99
+
100
+ shape = pt.as_tensor('shape', (5, 5))
101
+ print(pt._iota(shape, 1).eval())
102
+
103
+ .. testoutput::
104
+
105
+ [[0., 1., 2., 3., 4.],
106
+ [0., 1., 2., 3., 4.],
107
+ [0., 1., 2., 3., 4.],
108
+ [0., 1., 2., 3., 4.],
109
+ [0., 1., 2., 3., 4.]]
110
+
111
+ Setting ``axis=0`` above would result in the transpose of the output.
112
+ """
64
113
len_shape = get_vector_length (shape )
65
114
axis = normalize_axis_index (axis , len_shape )
66
115
values = arange (shape [axis ])
67
116
return broadcast_to (shape_padright (values , len_shape - axis - 1 ), shape )
68
117
69
118
70
- def _delta (shape , axes : Sequence [int ]) -> TensorVariable :
71
- """This utility function exists for creating Kronecker delta arrays."""
119
+ def _delta (shape : TensorVariable , axes : Sequence [int ]) -> TensorVariable :
120
+ """
121
+ Create a Kroncker delta tensor.
122
+
123
+ The Kroncker delta function is defined:
124
+
125
+ .. math::
126
+
127
+ \\ delta(i, j) = \b egin{cases} 1 & \t ext{if} \\ quad i = j \\ 0 & \t ext{otherwise} \\ end{cases}
128
+
129
+ To create a Kronecker tensor, the delta function is applied elementwise to the axes specified. The result is a
130
+ tensor of booleans, with ``True`` where the axis indices coincide, and ``False`` otherwise. See below for examples.
131
+
132
+ Parameters
133
+ ----------
134
+ shape: TensorVariable
135
+ The shape of the tensor to be created. Note that `_delta` is not defined for 1d tensors, because there is no
136
+ second axis against which to compare.
137
+ axes: sequence of int
138
+ Axes whose indices should be compared. Note that `_delta` is not defined for a single axis, because there is no
139
+ second axis against which to compare.
140
+
141
+ Examples
142
+ --------
143
+ An easy case to understand is when the shape is square and the number of axes is equal to the number of dimensions.
144
+ This will result in a generalized identity tensor, with ``True`` along the main diagonal:
145
+
146
+ .. testcode::
147
+
148
+ from pytensor.tensor.einsum import _delta
149
+ print(_delta((5, 5), (0, 1)).eval())
150
+
151
+ .. testoutput::
152
+
153
+ [[ True False False False False]
154
+ [False True False False False]
155
+ [False False True False False]
156
+ [False False False True False]
157
+ [False False False False True]]
158
+
159
+ In the case where the shape is not square, the result will be a tensor with ``True`` along the main diagonal and
160
+ ``False`` elsewhere:
161
+
162
+ .. testcode::
163
+
164
+ from pytensor.tensor.einsum import _delta
165
+ print(_delta((3, 2), (0, 1)).eval())
166
+
167
+ .. testoutput::
168
+
169
+ [[ True False]
170
+ [False True]
171
+ [False False]]
172
+ """
173
+ if len (axes ) == 1 :
174
+ raise ValueError ("Need at least two axes to create a delta tensor" )
72
175
base_shape = stack ([shape [axis ] for axis in axes ])
73
176
iotas = [_iota (base_shape , i ) for i in range (len (axes ))]
74
177
eyes = [eq (i1 , i2 ) for i1 , i2 in pairwise (iotas )]
@@ -81,6 +184,46 @@ def _general_dot(
81
184
axes : Sequence [Sequence [int ]], # Should be length 2,
82
185
batch_axes : Sequence [Sequence [int ]], # Should be length 2,
83
186
) -> TensorVariable :
187
+ """
188
+ Generalized dot product between two tensors.
189
+
190
+ Ultimately ``_general_dot`` is a call to `tensor_dot`, performing a multiply-and-sum ("dot") operation between two
191
+ tensors, along a requested dimension. This function further generalizes this operation by allowing arbitrary
192
+ batch dimensions to be specified for each tensor.
193
+
194
+
195
+ Parameters
196
+ ----------
197
+ vars: tuple[TensorVariable, TensorVariable]
198
+ The tensors to be ``tensor_dot``ed
199
+ axes: Sequence[Sequence[int]]
200
+ The axes along which to perform the dot product. Should be a sequence of two sequences, one for each tensor.
201
+ batch_axes: Sequence[Sequence[int]]
202
+ The batch axes for each tensor. Should be a sequence of two sequences, one for each tensor.
203
+
204
+ Returns
205
+ -------
206
+ TensorVariable
207
+ The result of the ``tensor_dot`` product.
208
+
209
+ Examples
210
+ --------
211
+ Perform a batched dot product between two 3d tensors:
212
+
213
+ .. testcode::
214
+
215
+ import pytensor.tensor as pt
216
+ from pytensor.tensor.einsum import _general_dot
217
+ A = pt.tensor(shape = (3, 4, 5))
218
+ B = pt.tensor(shape = (3, 5, 2))
219
+
220
+ result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
221
+ print(result.type.shape)
222
+
223
+ .. testoutput::
224
+
225
+ (3, 4, 2)
226
+ """
84
227
# Shortcut for non batched case
85
228
if not batch_axes [0 ] and not batch_axes [1 ]:
86
229
return tensordot (* vars , axes = axes )
0 commit comments