Skip to content

Commit af13503

Browse files
committed
Implement product of row vector and matrix
1 parent 58b57ab commit af13503

File tree

2 files changed: +113 additions, −8 deletions

src/linalg/impl_linalg.rs

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,27 @@ type blas_index = c_int; // blas index type
4242
impl<A, S> ArrayBase<S, Ix1>
4343
where S: Data<Elem=A>,
4444
{
45-
    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is one-dimensional, then the operation is a vector dot
    /// product, which is the sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product). In this case,
    /// `self` and `rhs` must be the same length.
    ///
    /// If `Rhs` is two-dimensional, then the operation is matrix
    /// multiplication, where `self` is treated as a row vector. In this case,
    /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
    /// shape *N*.
    ///
    /// **Panics** if the array shapes are incompatible.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
        where Self: Dot<Rhs>
    {
        // Dispatch on the type of `rhs`: the `Dot` impl selected by the
        // compiler decides between vector dot product and row-vector × matrix.
        Dot::dot(self, rhs)
    }
5967

6068
fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
@@ -156,6 +164,49 @@ pub trait Dot<Rhs> {
156164
fn dot(&self, rhs: &Rhs) -> Self::Output;
157165
}
158166

167+
impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
    where S: Data<Elem=A>,
          S2: Data<Elem=A>,
          A: LinalgScalar,
{
    // Dot product of two 1-D arrays is a scalar.
    type Output = A;

    /// Compute the dot product of one-dimensional arrays.
    ///
    /// The dot product is a sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product).
    ///
    /// **Panics** if the arrays are not of the same length.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    {
        // Same implementation the inherent 1-D `dot` used before this trait
        // impl existed; `dot_impl` handles the (optional) BLAS fast path.
        self.dot_impl(rhs)
    }
}
187+
188+
impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
    where S: Data<Elem=A>,
          S2: Data<Elem=A>,
          A: LinalgScalar,
{
    // Row vector (shape M) times matrix (shape M × N) yields a vector (shape N).
    type Output = Array<A, Ix1>;

    /// Perform the matrix multiplication of the row vector `self` and
    /// rectangular matrix `rhs`.
    ///
    /// The array shapes must agree in the way that
    /// if `self` is *M*, then `rhs` is *M* × *N*.
    ///
    /// Return a result array with shape *N*.
    ///
    /// **Panics** if shapes are incompatible.
    fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1>
    {
        // vᵀ · M equals Mᵀ · v, so transpose `rhs` and reuse the existing
        // matrix × column-vector product instead of writing a new kernel.
        rhs.t().dot(self)
    }
}
209+
159210
impl<A, S> ArrayBase<S, Ix2>
160211
where S: Data<Elem=A>,
161212
{

tests/oper.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,18 @@ fn reference_mat_vec_mul<A, S, S2>(lhs: &ArrayBase<S, Ix2>, rhs: &ArrayBase<S2,
328328
.into_shape(m).unwrap()
329329
}
330330

331+
// simple, slow, correct (hopefully) mat mul
332+
fn reference_vec_mat_mul<A, S, S2>(lhs: &ArrayBase<S, Ix1>, rhs: &ArrayBase<S2, Ix2>)
333+
-> Array1<A>
334+
where A: LinalgScalar,
335+
S: Data<Elem=A>,
336+
S2: Data<Elem=A>,
337+
{
338+
let (m, (_, n)) = (lhs.dim(), rhs.dim());
339+
reference_mat_mul(&lhs.to_owned().into_shape((1, m)).unwrap(), rhs)
340+
.into_shape(n).unwrap()
341+
}
342+
331343
#[test]
332344
fn mat_mul() {
333345
let (m, n, k) = (8, 8, 8);
@@ -703,3 +715,45 @@ fn gen_mat_vec_mul() {
703715
}
704716
}
705717
}
718+
719+
#[test]
fn vec_mat_mul() {
    // (rows, cols) matrix shapes: square, non-square, and odd sizes.
    let sizes = vec![(4, 4),
                     (8, 8),
                     (17, 15),
                     (4, 17),
                     (17, 3),
                     (19, 18),
                     (16, 17),
                     (15, 16),
                     (67, 63),
                    ];
    // test different strides (including negative, i.e. reversed, strides)
    for &s1 in &[1, 2, -1, -2] {
        for &s2 in &[1, 2, -1, -2] {
            for &(m, n) in &sizes {
                for &rev in &[false, true] {
                    let mut b = range_mat64(m, n);
                    if rev {
                        // Also exercise the transposed memory layout of `b`.
                        b = b.reversed_axes();
                    }
                    // Re-read dims: `reversed_axes` swaps them.
                    let (m, n) = b.dim();
                    let a = range1_mat64(m);
                    let mut c = range1_mat64(n);
                    // Start `answer` equal to `c` so the elements NOT written
                    // through the strided slice below stay identical in both.
                    let mut answer = c.clone();

                    {
                        // Strided views of the operands under test.
                        let b = b.slice(s![..;s1, ..;s2]);
                        let a = a.slice(s![..;s1]);

                        // Expected values from the slow reference product.
                        let answer_part = reference_vec_mat_mul(&a, &b);
                        answer.slice_mut(s![..;s2]).assign(&answer_part);

                        // Values from the implementation under test.
                        c.slice_mut(s![..;s2]).assign(&a.dot(&b));
                    }
                    assert_close(c.view(), answer.view());
                }
            }
        }
    }
}

0 commit comments

Comments
 (0)