Skip to content

Add tests example of a partially unrolled loop #4736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions tests/run-with-compiler/quote-unrolled-foreach.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
while (i.<(size)) {
val element: scala.Int = arr.apply(i)
f.apply(element)
i = i.+(1)
}
})

((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
while (i.<(size)) {
val element: java.lang.String = arr.apply(i)
f.apply(element)
i = i.+(1)
}
})

((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
while (i.<(size)) {
val element: java.lang.String = arr.apply(i)
f.apply(element)
i = i.+(1)
}
})

((arr: scala.Array[scala.Int]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
while (i.<(size)) {
val element: scala.Int = arr.apply(i)

((i: scala.Int) => java.lang.System.out.println(i)).apply(element)
i = i.+(1)
}
})

((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
if (size.%(3).!=(0)) throw new scala.Exception("...") else ()
while (i.<(size)) {
f.apply(arr.apply(i))
f.apply(arr.apply(i.+(1)))
f.apply(arr.apply(i.+(2)))
i = i.+(3)
}
})

((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
if (size.%(4).!=(0)) throw new scala.Exception("...") else ()
while (i.<(size)) {
f.apply(arr.apply(i.+(0)))
f.apply(arr.apply(i.+(1)))
f.apply(arr.apply(i.+(2)))
f.apply(arr.apply(i.+(3)))
i = i.+(4)
}
})

{
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](4)
array.update(0, 1)
array.update(1, 2)
array.update(2, 3)
array.update(3, 4)
(array: scala.Array[scala.Int])
}

{
val arr1: scala.Array[scala.Int] = {
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](4)
array.update(0, 1)
array.update(1, 3)
array.update(2, 4)
array.update(3, 5)
(array: scala.Array[scala.Int])
}
val size: scala.Int = arr1.length
var i: scala.Int = 0
while (i.<(size)) {
val element: scala.Int = arr1.apply(i)

((x: scala.Int) => scala.Predef.println(x)).apply(element)
i = i.+(1)
}
}
134 changes: 134 additions & 0 deletions tests/run-with-compiler/quote-unrolled-foreach.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import scala.annotation.tailrec
import scala.quoted._

object Test {
implicit val toolbox: scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox.make

def main(args: Array[String]): Unit = {
val code1 = '{ (arr: Array[Int], f: Int => Unit) => ~foreach1('(arr), '(f)) }
println(code1.show)
println()

val code1Tpe = '{ (arr: Array[String], f: String => Unit) => ~foreach1Tpe1('(arr), '(f)) }
println(code1Tpe.show)
println()

val code1Tpe2 = '{ (arr: Array[String], f: String => Unit) => ~foreach1Tpe1('(arr), '(f)) }
println(code1Tpe2.show)
println()

val code2 = '{ (arr: Array[Int]) => ~foreach1('(arr), '(i => System.out.println(i))) }
println(code2.show)
println()

val code3 = '{ (arr: Array[Int], f: Int => Unit) => ~foreach3('(arr), '(f)) }
println(code3.show)
println()

val code4 = '{ (arr: Array[Int], f: Int => Unit) => ~foreach4('(arr), '(f), 4) }
println(code4.show)
println()

val liftedArray = Array(1, 2, 3, 4).toExpr
println(liftedArray.show)
println()


def printAll(arr: Array[Int]) = '{
val arr1 = ~arr.toExpr
~foreach1('(arr1), '(x => println(x)))
}

println(printAll(Array(1, 3, 4, 5)).show)

}

def foreach1(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
val size = (~arrRef).length
var i = 0
while (i < size) {
val element: Int = (~arrRef)(i)
(~f)(element)
i += 1
}
}

def foreach1Tpe1[T](arrRef: Expr[Array[T]], f: Expr[T => Unit])(implicit t: Type[T]): Expr[Unit] = '{
val size = (~arrRef).length
var i = 0
while (i < size) {
val element: ~t = (~arrRef)(i)
(~f)(element)
i += 1
}
}

def foreach1Tpe2[T: Type](arrRef: Expr[Array[T]], f: Expr[T => Unit]): Expr[Unit] = '{
val size = (~arrRef).length
var i = 0
while (i < size) {
val element: T = (~arrRef)(i)
(~f)(element)
i += 1
}
}

def foreach2(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
val size = (~arrRef).length
var i = 0
while (i < size) {
val element = (~arrRef)(i)
~f('(element)) // Use AppliedFuntion
i += 1
}
}

def foreach3(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
val size = (~arrRef).length
var i = 0
if (size % 3 != 0) throw new Exception("...")// for simplicity of the implementation
while (i < size) {
(~f)((~arrRef)(i))
(~f)((~arrRef)(i + 1))
(~f)((~arrRef)(i + 2))
i += 3
}
}

def foreach3_2(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
val size = (~arrRef).length
var i = 0
if (size % 3 != 0) throw new Exception("...")// for simplicity of the implementation
while (i < size) {
(~f)((~arrRef)(i))
(~f)((~arrRef)(i + 1))
(~f)((~arrRef)(i + 2))
i += 3
}
}

def foreach4(arrRef: Expr[Array[Int]], f: Expr[Int => Unit], unrollSize: Int): Expr[Unit] = '{
val size = (~arrRef).length
var i = 0
if (size % ~unrollSize.toExpr != 0) throw new Exception("...") // for simplicity of the implementation
while (i < size) {
~foreachInRange(0, unrollSize)(j => '{ (~f)((~arrRef)(i + ~j.toExpr)) })
i += ~unrollSize.toExpr
}
}

implicit object ArrayIntIsLiftable extends Liftable[Array[Int]] {
override def toExpr(x: Array[Int]): Expr[Array[Int]] = '{
val array = new Array[Int](~x.length.toExpr)
~foreachInRange(0, x.length)(i => '{ array(~i.toExpr) = ~x(i).toExpr})
array
}
}

def foreachInRange(start: Int, end: Int)(f: Int => Expr[Unit]): Expr[Unit] = {
@tailrec def unroll(i: Int, acc: Expr[Unit]): Expr[Unit] =
if (i < end) unroll(i + 1, '{ ~acc; ~f(i) }) else acc
if (start < end) unroll(start + 1, f(start)) else '()
}

}
28 changes: 28 additions & 0 deletions tests/run/quote-unrolled-foreach.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<log> start loop
0
6
12
<log> start loop
18
24
30
<log> start loop
36
42
48
<log> start loop
54
60
66
<log> start loop
72
78
84
<log> start loop
90
96
102
<log> start loop
108
114
120
24 changes: 24 additions & 0 deletions tests/run/quote-unrolled-foreach/quoted_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import scala.annotation.tailrec
import scala.quoted._

object Macro {

inline def unrolledForeach(inline unrollSize: Int, seq: Array[Int])(f: => Int => Unit): Unit = // or f: Int => Unit
~unrolledForeachImpl(unrollSize, '(seq), '(f))

private def unrolledForeachImpl(unrollSize: Int, seq: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
val size = (~seq).length
assert(size % (~unrollSize.toExpr) == 0) // for simplicity of the implementation
var i = 0
while (i < size) {
println("<log> start loop")
~{
@tailrec def loop(j: Int, acc: Expr[Unit]): Expr[Unit] =
if (j >= 0) loop(j - 1, '{ ~f('((~seq)(i + ~j.toExpr))); ~acc })
else acc
loop(unrollSize - 1, '())
}
i += ~unrollSize.toExpr
}
}
}
29 changes: 29 additions & 0 deletions tests/run/quote-unrolled-foreach/quoted_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

import scala.quoted._

object Test {
def main(args: Array[String]): Unit = {
val arr = Array.tabulate[Int](21)(x => 3 * x)
Macro.unrolledForeach(3, arr) { (x: Int) =>
System.out.println(2 * x)
}

/* unrooled code:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add a test to https://github.com/lampepfl/dotty/blob/master/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala to check that the code produced by the compiler actually is what you think it is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't thinks InlineBytecodeTests.scala supports separate compilation. Not sure how much work it would be to make it support it. It is probably better to just test the code directly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by "test the code directly"? You're doing code generation here, if you don't test that the code you generate looks like the way you think it should look like, I can guarantee you that you will have hard-to-find bugs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean test the high-level code generated instead of the final bytecode.


val size: Int = arr.length()
assert(size % 3 == 0)
var i: Int = 0
while (i < size) {
println("<log> start loop")
val x$1: Int = arr(i)
System.out.println(2 * x$1)
val x$2: Int = arr(i + 1)
System.out.println(2 * x$2)
val x$3: Int = arr(i + 2)
System.out.println(2 * x$3)
i = i + 3
}
*/
}

}