Skip to content

Commit ab87f14

Browse files
committed
Add tests example of a partally unrolled loop
1 parent 2014ffe commit ab87f14

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
<log> start loop
2+
0
3+
2
4+
4
5+
<log> start loop
6+
6
7+
8
8+
10
9+
<log> start loop
10+
12
11+
14
12+
16
13+
<log> start loop
14+
18
15+
20
16+
22
17+
<log> start loop
18+
24
19+
26
20+
28
21+
<log> start loop
22+
30
23+
32
24+
34
25+
<log> start loop
26+
36
27+
38
28+
40
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import scala.annotation.tailrec
2+
import scala.quoted._
3+
4+
object Macro {
5+
6+
inline def unrolledForeach(inline unrollSize: Int, seq: Array[Int])(f: => Int => Unit): Unit = // or f: Int => Unit
7+
~unrolledForeachImpl(unrollSize, '(seq), '(f))
8+
9+
private def unrolledForeachImpl(unrollSize: Int, seq: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
10+
val size = (~seq).length
11+
assert(size % (~unrollSize.toExpr) == 0) // for simplicity of the implementation
12+
var i = 0
13+
while (i < size) {
14+
println("<log> start loop")
15+
~{
16+
for (j <- new UnrolledRange(0, unrollSize)) '{
17+
val element = (~seq)(i + ~j.toExpr)
18+
~f('(element)) // or `(~f)(element)` if `f` should not be inlined
19+
}
20+
}
21+
i += ~unrollSize.toExpr
22+
}
23+
24+
}
25+
26+
private class UnrolledRange(start: Int, end: Int) {
27+
def foreach(f: Int => Expr[Unit]): Expr[Unit] = {
28+
@tailrec def loop(i: Int, acc: Expr[Unit]): Expr[Unit] =
29+
if (i >= 0) loop(i - 1, '{ ~f(i); ~acc })
30+
else acc
31+
loop(end - 1, '())
32+
}
33+
}
34+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
2+
import scala.quoted._
3+
4+
object Test {
5+
def main(args: Array[String]): Unit = {
6+
val arr = Array.tabulate[Int](21)(x => x)
7+
Macro.unrolledForeach(3, arr) { (x: Int) =>
8+
System.out.println(2 * x)
9+
}
10+
11+
/* unrooled code:
12+
13+
val size: Int = arr.length()
14+
assert(size % 3 == 0)
15+
var i: Int = 0
16+
while (i < size) {
17+
println("<log> start loop")
18+
val element$1: Int = arr(i)
19+
val x$1: Int = element$1
20+
System.out.println(2 * x$1)
21+
val element$2: Int = arr(i + 1)
22+
val x$2: Int = element$2
23+
System.out.println(2 * x$2)
24+
val element$3: Int = arr(i + 2)
25+
val x$3: Int = element$3
26+
System.out.println(2 * x$3)
27+
i = i + 3
28+
}
29+
*/
30+
}
31+
32+
}

0 commit comments

Comments
 (0)