@@ -42,19 +42,27 @@ type blas_index = c_int; // blas index type
42
42
impl < A , S > ArrayBase < S , Ix1 >
43
43
where S : Data < Elem =A > ,
44
44
{
45
- /// Compute the dot product of one-dimensional arrays.
45
+ /// Perform dot product or matrix multiplication of arrays `self` and `rhs` .
46
46
///
47
- /// The dot product is a sum of the elementwise products (no conjugation
48
- /// of complex operands, and thus not their inner product).
47
+ /// `Rhs` may be either a one-dimensional or a two-dimensional array.
49
48
///
50
- /// **Panics** if the arrays are not of the same length.<br>
49
+ /// If `Rhs` is one-dimensional, then the operation is a vector dot
50
+ /// product, which is the sum of the elementwise products (no conjugation
51
+ /// of complex operands, and thus not their inner product). In this case,
52
+ /// `self` and `rhs` must be the same length.
53
+ ///
54
+ /// If `Rhs` is two-dimensional, then the operation is matrix
55
+ /// multiplication, where `self` is treated as a row vector. In this case,
56
+ /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
57
+ /// shape *N*.
58
+ ///
59
+ /// **Panics** if the array shapes are incompatible.<br>
51
60
/// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
52
61
/// layout allows.
53
- pub fn dot < S2 > ( & self , rhs : & ArrayBase < S2 , Ix1 > ) -> A
54
- where S2 : Data < Elem =A > ,
55
- A : LinalgScalar ,
62
+ pub fn dot < Rhs > ( & self , rhs : & Rhs ) -> <Self as Dot < Rhs > >:: Output
63
+ where Self : Dot < Rhs >
56
64
{
57
- self . dot_impl ( rhs)
65
+ Dot :: dot ( self , rhs)
58
66
}
59
67
60
68
fn dot_generic < S2 > ( & self , rhs : & ArrayBase < S2 , Ix1 > ) -> A
@@ -156,6 +164,49 @@ pub trait Dot<Rhs> {
156
164
fn dot ( & self , rhs : & Rhs ) -> Self :: Output ;
157
165
}
158
166
167
+ impl < A , S , S2 > Dot < ArrayBase < S2 , Ix1 > > for ArrayBase < S , Ix1 >
168
+ where S : Data < Elem =A > ,
169
+ S2 : Data < Elem =A > ,
170
+ A : LinalgScalar ,
171
+ {
172
+ type Output = A ;
173
+
174
+ /// Compute the dot product of one-dimensional arrays.
175
+ ///
176
+ /// The dot product is a sum of the elementwise products (no conjugation
177
+ /// of complex operands, and thus not their inner product).
178
+ ///
179
+ /// **Panics** if the arrays are not of the same length.<br>
180
+ /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
181
+ /// layout allows.
182
+ fn dot ( & self , rhs : & ArrayBase < S2 , Ix1 > ) -> A
183
+ {
184
+ self . dot_impl ( rhs)
185
+ }
186
+ }
187
+
188
+ impl < A , S , S2 > Dot < ArrayBase < S2 , Ix2 > > for ArrayBase < S , Ix1 >
189
+ where S : Data < Elem =A > ,
190
+ S2 : Data < Elem =A > ,
191
+ A : LinalgScalar ,
192
+ {
193
+ type Output = Array < A , Ix1 > ;
194
+
195
+ /// Perform the matrix multiplication of the row vector `self` and
196
+ /// rectangular matrix `rhs`.
197
+ ///
198
+ /// The array shapes must agree in the way that
199
+ /// if `self` is *M*, then `rhs` is *M* × *N*.
200
+ ///
201
+ /// Return a result array with shape *N*.
202
+ ///
203
+ /// **Panics** if shapes are incompatible.
204
+ fn dot ( & self , rhs : & ArrayBase < S2 , Ix2 > ) -> Array < A , Ix1 >
205
+ {
206
+ rhs. t ( ) . dot ( self )
207
+ }
208
+ }
209
+
159
210
impl < A , S > ArrayBase < S , Ix2 >
160
211
where S : Data < Elem =A > ,
161
212
{
0 commit comments