@@ -2,6 +2,7 @@ package fpinscalalib
2
2
3
3
import org .scalatest .{FlatSpec , Matchers }
4
4
import fpinscalalib .List ._
5
+ import fpinscalalib .Tree ._
5
6
6
7
/** @param name functional_data_structures
7
8
*/
@@ -521,5 +522,113 @@ object FunctionalDataStructuresSection extends FlatSpec with Matchers with org.s
521
522
hasSubsequence(l, List (0 , 1 )) shouldBe res1
522
523
hasSubsequence(l, Nil ) shouldBe res2
523
524
}
525
+
526
+ /**
527
+ * = Trees =
528
+ *
529
+ * List is just one example of what’s called an algebraic data type (ADT). An ADT is just a data type defined by one
530
+ * or more data constructors, each of which may contain zero or more arguments. We say that the data type is the sum
531
+ * or union of its data constructors, and each data constructor is the product of its arguments, hence the name
532
+ * algebraic data type.
533
+ *
534
+ * Algebraic data types can be used to define other data structures. Let’s define a simple binary tree data structure:
535
+ *
536
+ * {{{
537
+ * sealed trait Tree[+A]
538
+ * case class Leaf[A](value: A) extends Tree[A]
539
+ * case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
540
+ * }}}
541
+ *
542
+ * Pattern matching again provides a convenient way of operating over elements of our ADT. Let’s try writing a few
543
+ * functions. For starters, let's try to implement a function `size` to count the number of nodes (leaves and branches)
544
+ * in a tree:
545
+ */
546
+
547
+ def treeSizeAssert (res0 : Int , res1 : Int ): Unit = {
548
+ def size [A ](t : Tree [A ]): Int = t match {
549
+ case Leaf (_) => res0
550
+ case Branch (l, r) => res1 + size(l) + size(r)
551
+ }
552
+
553
+ def t = Branch (Branch (Leaf (1 ), Leaf (2 )), Leaf (3 ))
554
+ size(t) shouldBe 5
555
+ }
556
+
557
+ /**
558
+ * Following a similar implementation, we can write a function `maximum` that returns the maximum element in a
559
+ * Tree[Int]:
560
+ *
561
+ * {{{
562
+ * def maximum(t: Tree[Int]): Int = t match {
563
+ * case Leaf(n) => n
564
+ * case Branch(l,r) => maximum(l) max maximum(r)
565
+ * }
566
+ * }}}
567
+ *
568
+ * In the same fashion, let's implement a function `depth` that returns the maximum path length from the root of a
569
+ * tree to any leaf.
570
+ */
571
+
572
+ def treeDepthAssert (res0 : Int , res1 : Int ): Unit = {
573
+ def depth [A ](t : Tree [A ]): Int = t match {
574
+ case Leaf (_) => res0
575
+ case Branch (l,r) => res1 + (depth(l) max depth(r))
576
+ }
577
+ def t = Branch (Branch (Leaf (1 ), Leaf (2 )), Leaf (3 ))
578
+ depth(t) shouldBe 2
579
+ }
580
+
581
+ /**
582
+ * We can also write a function `map`, analogous to the method of the same name on `List`, that modifies each element
583
+ * in a tree with a given function:
584
+ *
585
+ * {{{
586
+ * def map[A,B](t: Tree[A])(f: A => B): Tree[B] = t match {
587
+ * case Leaf(a) => Leaf(f(a))
588
+ * case Branch(l,r) => Branch(map(l)(f), map(r)(f))
589
+ * }
590
+ * }}}
591
+ *
592
+ * Let's try it out in the following exercise:
593
+ */
594
+
595
+ def treeMapAssert (res0 : Branch [Int ]): Unit = {
596
+ def t = Branch (Branch (Leaf (1 ), Leaf (2 )), Leaf (3 ))
597
+ Tree .map(t)(_ * 2 ) shouldBe res0
598
+ }
599
+
600
+ /**
601
+ * To wrap this section up, let's generalize `size`, `maximum`, `depth` and `map`, writing a new function `fold` that
602
+ * abstracts over their similarities:
603
+ *
604
+ * {{{
605
+ * def fold[A,B](t: Tree[A])(f: A => B)(g: (B,B) => B): B = t match {
606
+ * case Leaf(a) => f(a)
607
+ * case Branch(l,r) => g(fold(l)(f)(g), fold(r)(f)(g))
608
+ * }
609
+ * }}}
610
+ *
611
+ * Let's try to reimplement `size`, `maximum`, `depth`, and `map` in terms of this more general function:
612
+ */
613
+
614
+ def treeFoldAssert (res0 : Int , res1 : Int , res2 : Int , res3 : Int , res4 : Branch [Boolean ]): Unit = {
615
+ def sizeViaFold [A ](t : Tree [A ]): Int =
616
+ fold(t)(a => res0)(res1 + _ + _)
617
+
618
+ def maximumViaFold (t : Tree [Int ]): Int =
619
+ fold(t)(a => a)(_ max _)
620
+
621
+ def depthViaFold [A ](t : Tree [A ]): Int =
622
+ fold(t)(a => res2)((d1,d2) => res3 + (d1 max d2))
623
+
624
+ def mapViaFold [A ,B ](t : Tree [A ])(f : A => B ): Tree [B ] =
625
+ fold(t)(a => Leaf (f(a)): Tree [B ])(Branch (_,_))
626
+
627
+ def t = Branch (Branch (Leaf (1 ), Leaf (2 )), Leaf (3 ))
628
+ sizeViaFold(t) shouldBe 5
629
+ maximumViaFold(t) shouldBe 3
630
+ depthViaFold(t) shouldBe 2
631
+ mapViaFold(t)(_ % 2 == 0 ) shouldBe res4
632
+ }
524
633
}
525
634
0 commit comments