Skip to content

Commit d74c3cd

Browse files
author
Javier de Silóniz Sandino
committed
Exercises regarding Trees
1 parent 712895a commit d74c3cd

File tree

3 files changed

+209
-0
lines changed

3 files changed

+209
-0
lines changed

src/main/scala/fpinscalalib/FunctionalDataStructuresSection.scala

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package fpinscalalib
22

33
import org.scalatest.{FlatSpec, Matchers}
44
import fpinscalalib.List._
5+
import fpinscalalib.Tree._
56

67
/** @param name functional_data_structures
78
*/
@@ -521,5 +522,113 @@ object FunctionalDataStructuresSection extends FlatSpec with Matchers with org.s
521522
hasSubsequence(l, List(0, 1)) shouldBe res1
522523
hasSubsequence(l, Nil) shouldBe res2
523524
}
525+
526+
/**
527+
* = Trees =
528+
*
529+
* List is just one example of what’s called an algebraic data type (ADT). An ADT is just a data type defined by one
530+
* or more data constructors, each of which may contain zero or more arguments. We say that the data type is the sum
531+
* or union of its data constructors, and each data constructor is the product of its arguments, hence the name
532+
* algebraic data type.
533+
*
534+
* Algebraic data types can be used to define other data structures. Let’s define a simple binary tree data structure:
535+
*
536+
* {{{
537+
* sealed trait Tree[+A]
538+
* case class Leaf[A](value: A) extends Tree[A]
539+
* case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
540+
* }}}
541+
*
542+
* Pattern matching again provides a convenient way of operating over elements of our ADT. Let’s try writing a few
543+
* functions. For starters, let's try to implement a function `size` to count the number of nodes (leaves and branches)
544+
* in a tree:
545+
*/
546+
547+
def treeSizeAssert(res0: Int, res1: Int): Unit = {
548+
def size[A](t: Tree[A]): Int = t match {
549+
case Leaf(_) => res0
550+
case Branch(l, r) => res1 + size(l) + size(r)
551+
}
552+
553+
def t = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
554+
size(t) shouldBe 5
555+
}
556+
557+
/**
558+
* Following a similar implementation, we can write a function `maximum` that returns the maximum element in a
559+
* Tree[Int]:
560+
*
561+
* {{{
562+
* def maximum(t: Tree[Int]): Int = t match {
563+
* case Leaf(n) => n
564+
* case Branch(l,r) => maximum(l) max maximum(r)
565+
* }
566+
* }}}
567+
*
568+
* In the same fashion, let's implement a function `depth` that returns the maximum path length from the root of a
569+
* tree to any leaf.
570+
*/
571+
572+
def treeDepthAssert(res0: Int, res1: Int): Unit = {
573+
def depth[A](t: Tree[A]): Int = t match {
574+
case Leaf(_) => res0
575+
case Branch(l,r) => res1 + (depth(l) max depth(r))
576+
}
577+
def t = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
578+
depth(t) shouldBe 2
579+
}
580+
581+
/**
582+
* We can also write a function `map`, analogous to the method of the same name on `List`, that modifies each element
583+
* in a tree with a given function:
584+
*
585+
* {{{
586+
* def map[A,B](t: Tree[A])(f: A => B): Tree[B] = t match {
587+
* case Leaf(a) => Leaf(f(a))
588+
* case Branch(l,r) => Branch(map(l)(f), map(r)(f))
589+
* }
590+
* }}}
591+
*
592+
* Let's try it out in the following exercise:
593+
*/
594+
595+
def treeMapAssert(res0: Branch[Int]): Unit = {
596+
def t = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
597+
Tree.map(t)(_ * 2) shouldBe res0
598+
}
599+
600+
/**
601+
* To wrap this section up, let's generalize `size`, `maximum`, `depth` and `map`, writing a new function `fold` that
602+
* abstracts over their similarities:
603+
*
604+
* {{{
605+
* def fold[A,B](t: Tree[A])(f: A => B)(g: (B,B) => B): B = t match {
606+
* case Leaf(a) => f(a)
607+
* case Branch(l,r) => g(fold(l)(f)(g), fold(r)(f)(g))
608+
* }
609+
* }}}
610+
*
611+
* Let's try to reimplement `size`, `maximum`, `depth`, and `map` in terms of this more general function:
612+
*/
613+
614+
def treeFoldAssert(res0: Int, res1: Int, res2: Int, res3: Int, res4: Branch[Boolean]): Unit = {
615+
def sizeViaFold[A](t: Tree[A]): Int =
616+
fold(t)(a => res0)(res1 + _ + _)
617+
618+
def maximumViaFold(t: Tree[Int]): Int =
619+
fold(t)(a => a)(_ max _)
620+
621+
def depthViaFold[A](t: Tree[A]): Int =
622+
fold(t)(a => res2)((d1,d2) => res3 + (d1 max d2))
623+
624+
def mapViaFold[A,B](t: Tree[A])(f: A => B): Tree[B] =
625+
fold(t)(a => Leaf(f(a)): Tree[B])(Branch(_,_))
626+
627+
def t = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
628+
sizeViaFold(t) shouldBe 5
629+
maximumViaFold(t) shouldBe 3
630+
depthViaFold(t) shouldBe 2
631+
mapViaFold(t)(_ % 2 == 0) shouldBe res4
632+
}
524633
}
525634

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package fpinscalalib
2+
3+
// The following implementation of the binary tree is provided by Manning as a solution to the multiple implementation
4+
// exercises found in the "Functional Programming in Scala" book. The original code can be found in the following URL:
5+
//
6+
// https://github.com/fpinscala/fpinscala/commit/5bf1138f3aa71ff91babaa99613313fb9ac48b27
7+
8+
sealed trait Tree[+A]
9+
case class Leaf[A](value: A) extends Tree[A]
10+
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
11+
12+
13+
object Tree {
14+
def size[A](t: Tree[A]): Int = t match {
15+
case Leaf(_) => 1
16+
case Branch(l,r) => 1 + size(l) + size(r)
17+
}
18+
19+
/*
20+
We're using the method `max` that exists on all `Int` values rather than an explicit `if` expression.
21+
22+
Note how similar the implementation is to `size`. We'll abstract out the common pattern in a later exercise.
23+
*/
24+
def maximum(t: Tree[Int]): Int = t match {
25+
case Leaf(n) => n
26+
case Branch(l,r) => maximum(l) max maximum(r)
27+
}
28+
29+
/*
30+
Again, note how similar the implementation is to `size` and `maximum`.
31+
*/
32+
def depth[A](t: Tree[A]): Int = t match {
33+
case Leaf(_) => 0
34+
case Branch(l,r) => 1 + (depth(l) max depth(r))
35+
}
36+
37+
def map[A,B](t: Tree[A])(f: A => B): Tree[B] = t match {
38+
case Leaf(a) => Leaf(f(a))
39+
case Branch(l,r) => Branch(map(l)(f), map(r)(f))
40+
}
41+
42+
/*
43+
Like `foldRight` for lists, `fold` receives a "handler" for each of the data constructors of the type, and recursively
44+
accumulates some value using these handlers. As with `foldRight`, `fold(t)(Leaf(_))(Branch(_,_)) == t`, and we can use
45+
this function to implement just about any recursive function that would otherwise be defined by pattern matching.
46+
*/
47+
def fold[A,B](t: Tree[A])(f: A => B)(g: (B,B) => B): B = t match {
48+
case Leaf(a) => f(a)
49+
case Branch(l,r) => g(fold(l)(f)(g), fold(r)(f)(g))
50+
}
51+
52+
def sizeViaFold[A](t: Tree[A]): Int =
53+
fold(t)(a => 1)(1 + _ + _)
54+
55+
def maximumViaFold(t: Tree[Int]): Int =
56+
fold(t)(a => a)(_ max _)
57+
58+
def depthViaFold[A](t: Tree[A]): Int =
59+
fold(t)(a => 0)((d1,d2) => 1 + (d1 max d2))
60+
61+
/*
62+
Note the type annotation required on the expression `Leaf(f(a))`. Without this annotation, we get an error like this:
63+
64+
type mismatch;
65+
found : fpinscala.datastructures.Branch[B]
66+
required: fpinscala.datastructures.Leaf[B]
67+
fold(t)(a => Leaf(f(a)))(Branch(_,_))
68+
^
69+
70+
This error is an unfortunate consequence of Scala using subtyping to encode algebraic data types. Without the
71+
annotation, the result type of the fold gets inferred as `Leaf[B]` and it is then expected that the second argument
72+
to `fold` will return `Leaf[B]`, which it doesn't (it returns `Branch[B]`). Really, we'd prefer Scala to
73+
infer `Tree[B]` as the result type in both cases. When working with algebraic data types in Scala, it's somewhat
74+
common to define helper functions that simply call the corresponding data constructors but give the less specific
75+
result type:
76+
77+
def leaf[A](a: A): Tree[A] = Leaf(a)
78+
def branch[A](l: Tree[A], r: Tree[A]): Tree[A] = Branch(l, r)
79+
*/
80+
def mapViaFold[A,B](t: Tree[A])(f: A => B): Tree[B] =
81+
fold(t)(a => Leaf(f(a)): Tree[B])(Branch(_,_))
82+
}

src/test/scala/fpinscalalib/FunctionalDataStructuresSpec.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,22 @@ class FunctionalDataStructuresSpec extends Spec with Checkers {
8282
def `list hasSubsequence asserts` = {
8383
check(Test.testSuccess(FunctionalDataStructuresSection.listHasSubsequenceAssert _, true :: false :: true :: HNil))
8484
}
85+
86+
def `tree size asserts` = {
87+
check(Test.testSuccess(FunctionalDataStructuresSection.treeSizeAssert _, 1 :: 1 :: HNil))
88+
}
89+
90+
def `tree depth asserts` = {
91+
check(Test.testSuccess(FunctionalDataStructuresSection.treeDepthAssert _, 0 :: 1 :: HNil))
92+
}
93+
94+
def `tree map asserts` = {
95+
check(Test.testSuccess(FunctionalDataStructuresSection.treeMapAssert _,
96+
Branch[Int](Branch[Int](Leaf[Int](2), Leaf[Int](4)), Leaf[Int](6)) :: HNil))
97+
}
98+
99+
def `tree fold asserts` = {
100+
check(Test.testSuccess(FunctionalDataStructuresSection.treeFoldAssert _,
101+
1 :: 1 :: 0 :: 1 :: Branch(Branch(Leaf(false), Leaf(true)), Leaf(false)) :: HNil))
102+
}
85103
}

0 commit comments

Comments
 (0)