Skip to content

Commit 9cba571

Browse files
committed
Add basic iterator for matrices
This implements only enough for range based for loops Fixes #41
1 parent 3568f65 commit 9cba571

File tree

5 files changed

+72
-2
lines changed

5 files changed

+72
-2
lines changed

cpp11test/R/cpp11.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ gibbs_rcpp2 <- function(N, thin) {
5656
.Call("_cpp11test_gibbs_rcpp2", N, thin)
5757
}
5858

59+
row_sums <- function(x) {
60+
.Call("_cpp11test_row_sums", x)
61+
}
62+
5963
protect_one_ <- function(x, n) {
6064
invisible(.Call("_cpp11test_protect_one_", x, n))
6165
}

cpp11test/src/cpp11.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ extern "C" SEXP _cpp11test_gibbs_rcpp2(SEXP N, SEXP thin) {
103103
return cpp11::as_sexp(gibbs_rcpp2(cpp11::unmove(cpp11::as_cpp<int>(N)), cpp11::unmove(cpp11::as_cpp<int>(thin))));
104104
END_CPP11
105105
}
106+
// matrix.cpp
107+
cpp11::doubles row_sums(cpp11::doubles_matrix x);
108+
extern "C" SEXP _cpp11test_row_sums(SEXP x) {
109+
BEGIN_CPP11
110+
return cpp11::as_sexp(row_sums(cpp11::unmove(cpp11::as_cpp<cpp11::doubles_matrix>(x))));
111+
END_CPP11
112+
}
106113
// protect.cpp
107114
void protect_one_(SEXP x, int n);
108115
extern "C" SEXP _cpp11test_protect_one_(SEXP x, SEXP n) {
@@ -314,6 +321,7 @@ extern SEXP _cpp11test_rcpp_sum_dbl_accumulate_(SEXP);
314321
extern SEXP _cpp11test_rcpp_sum_dbl_for_(SEXP);
315322
extern SEXP _cpp11test_rcpp_sum_dbl_foreach_(SEXP);
316323
extern SEXP _cpp11test_remove_altrep(SEXP);
324+
extern SEXP _cpp11test_row_sums(SEXP);
317325
extern SEXP _cpp11test_sum_dbl_accumulate_(SEXP);
318326
extern SEXP _cpp11test_sum_dbl_accumulate2_(SEXP);
319327
extern SEXP _cpp11test_sum_dbl_for_(SEXP);
@@ -355,6 +363,7 @@ static const R_CallMethodDef CallEntries[] = {
355363
{"_cpp11test_rcpp_sum_dbl_for_", (DL_FUNC) &_cpp11test_rcpp_sum_dbl_for_, 1},
356364
{"_cpp11test_rcpp_sum_dbl_foreach_", (DL_FUNC) &_cpp11test_rcpp_sum_dbl_foreach_, 1},
357365
{"_cpp11test_remove_altrep", (DL_FUNC) &_cpp11test_remove_altrep, 1},
366+
{"_cpp11test_row_sums", (DL_FUNC) &_cpp11test_row_sums, 1},
358367
{"_cpp11test_sum_dbl_accumulate_", (DL_FUNC) &_cpp11test_sum_dbl_accumulate_, 1},
359368
{"_cpp11test_sum_dbl_accumulate2_", (DL_FUNC) &_cpp11test_sum_dbl_accumulate2_, 1},
360369
{"_cpp11test_sum_dbl_for_", (DL_FUNC) &_cpp11test_sum_dbl_for_, 1},

cpp11test/src/matrix.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,21 @@ using namespace Rcpp;
6666

6767
return (mat);
6868
}
69+
70+
[[cpp11::register]] cpp11::doubles row_sums(cpp11::doubles_matrix x) {
71+
cpp11::writable::doubles sums(x.nrow());
72+
73+
int i = 0;
74+
for (auto& row : x) {
75+
for (auto&& val : row) {
76+
if (cpp11::is_na(val)) {
77+
sums[i] = NA_REAL;
78+
break;
79+
}
80+
sums[i] += val;
81+
}
82+
++i;
83+
}
84+
85+
return sums;
86+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
test_that("row_sums gives same result as rowSums", {
2+
x <- cbind(x1 = 3, x2 = c(4:1, 2:5))
3+
expect_equal(row_sums(x), rowSums(x))
4+
5+
# With missing values
6+
y <- cbind(x1 = 3, x2 = c(4:1, 2:5))
7+
y[3, ] <- NA; x[4, 2] <- NA
8+
expect_equal(row_sums(x), rowSums(x))
9+
})

inst/include/cpp11/matrix.hpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class matrix {
1010
V vector_;
1111
int nrow_;
1212

13+
public:
1314
class row {
1415
private:
1516
matrix& parent_;
@@ -18,6 +19,35 @@ class matrix {
1819
public:
1920
row(matrix& parent, R_xlen_t row) : parent_(parent), row_(row) {}
2021
T operator[](const int pos) { return parent_.vector_[row_ + (pos * parent_.nrow_)]; }
22+
23+
class iterator {
24+
private:
25+
row& row_;
26+
int pos_;
27+
28+
public:
29+
iterator(row& row, R_xlen_t pos) : row_(row), pos_(pos) {}
30+
iterator begin() const { return row_.parent_.vector_iterator(&this, 0); }
31+
iterator end() const { return iterator(&this, row_.size()); }
32+
inline iterator& operator++() {
33+
++pos_;
34+
return *this;
35+
}
36+
bool operator!=(const iterator& rhs) {
37+
return !(pos_ == rhs.pos_ && row_.row_ == rhs.row_.row_);
38+
}
39+
T operator*() const { return row_[pos_]; };
40+
};
41+
42+
iterator begin() { return iterator(*this, 0); }
43+
iterator end() { return iterator(*this, size()); }
44+
R_xlen_t size() const { return parent_.vector_.size() / parent_.nrow_; }
45+
bool operator!=(const row& rhs) { return row_ != rhs.row_; }
46+
row& operator++() {
47+
++row_;
48+
return *this;
49+
}
50+
row& operator*() { return *this; }
2151
};
2252
friend row;
2353

@@ -55,7 +85,8 @@ class matrix {
5585

5686
T operator()(int row, int col) { return vector_[row + (col * nrow_)]; }
5787

58-
// operator cpp11::matrix<V, T>() { return SEXP(); }
88+
row begin() { return {*this, 0}; }
89+
row end() { return {*this, nrow_}; }
5990
};
6091

6192
using doubles_matrix = matrix<r_vector<double>, double>;
@@ -71,5 +102,4 @@ using strings_matrix = matrix<r_vector<r_string>, r_vector<r_string>::proxy>;
71102
} // namespace writable
72103

73104
// TODO: Add tests for Matrix class
74-
75105
}; // namespace cpp11

0 commit comments

Comments
 (0)