Skip to content

Commit 43a9c3d

Browse files
committed
2023 day 17 refactor for clarity
1 parent 44c74f3 commit 43a9c3d

File tree

1 file changed

+35
-33
lines changed

1 file changed

+35
-33
lines changed

2023/src/day23.scala

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,21 @@ case class Maze(grid: Vector[Vector[Char]]):
4848

4949
def apply(p: Point): Char = grid(p.y)(p.x)
5050

51-
val xRange = grid.head.indices
52-
val yRange = grid.indices
51+
val xRange: Range = grid.head.indices
52+
val yRange: Range = grid.indices
5353

5454
def points: Iterator[Point] = for
5555
y <- yRange.iterator
5656
x <- xRange.iterator
5757
yield Point(x, y)
5858

59-
val walkable = points.filter(p => grid(p.y)(p.x) != '#').toSet
60-
val start = walkable.minBy(_.y)
61-
val end = walkable.maxBy(_.y)
59+
val walkable: Set[Point] = points.filter(p => grid(p.y)(p.x) != '#').toSet
60+
val start: Point = walkable.minBy(_.y)
61+
val end: Point = walkable.maxBy(_.y)
62+
63+
val junctions: Set[Point] = walkable.filter: p =>
64+
Dir.values.map(p.move).count(walkable) > 2
65+
.toSet + start + end
6266

6367
val slopes = Map.from[Point, Dir]:
6468
points.collect:
@@ -67,32 +71,30 @@ case class Maze(grid: Vector[Vector[Char]]):
6771
case p if apply(p) == '>' => p -> Dir.E
6872
case p if apply(p) == '<' => p -> Dir.W
6973

70-
val nodes: Set[Point] = walkable.filter: p =>
71-
Dir.values.map(p.move).count(walkable) > 2
72-
.toSet + start + end
73-
74+
def connectedJunctions(pos: Point)(using maze: Maze) = List.from[(Point, Int)]:
75+
def walk(pos: Point, dir: Dir): Option[Point] =
76+
val p = pos.move(dir)
77+
Option.when(maze.walkable(p) && maze.slopes.get(p).forall(_ == dir))(p)
7478

75-
def next(pos: Point, dir: Dir)(using maze: Maze): List[(Point, Dir)] =
76-
for
77-
d <- List(dir, dir.turnRight, dir.turnLeft)
78-
p = pos.move(d)
79-
if maze.slopes.get(p).forall(_ == d)
80-
if maze.walkable(p)
81-
yield p -> d
79+
def search(pos: Point, facing: Dir, dist: Int): Option[(Point, Int)] =
80+
if maze.junctions.contains(pos) then Some(pos, dist) else
81+
val adjacentSearch = for
82+
nextFacing <- LazyList(facing, facing.turnRight, facing.turnLeft)
83+
nextPos <- walk(pos, nextFacing)
84+
yield search(nextPos, nextFacing, dist + 1)
8285

83-
def nodesFrom(pos: Point)(using maze: Maze) = List.from[(Point, Int)]:
84-
def search(p: Point, d: Dir, dist: Int): Option[(Point, Int)] =
85-
next(p, d) match
86-
case (p, d) :: Nil if maze.nodes(p) => Some(p, dist + 1)
87-
case (p, d) :: Nil => search(p, d, dist + 1)
88-
case _ => None
86+
if adjacentSearch.size == 1 then adjacentSearch.head else None
8987

90-
Dir.values.flatMap(next(pos, _)).distinct.flatMap(search(_, _, 1))
88+
for
89+
d <- Dir.values
90+
p <- walk(pos, d)
91+
junction <- search(p, d, 1)
92+
yield junction
9193

9294
def longestDownhillHike(using maze: Maze): Int =
9395
def search(pos: Point, dist: Int)(using maze: Maze): Int =
9496
if pos == maze.end then dist else
95-
nodesFrom(pos).foldLeft(0):
97+
connectedJunctions(pos).foldLeft(0):
9698
case (max, (n, d)) => max.max(search(n, dist + d))
9799

98100
search(maze.start, 0)
@@ -101,21 +103,21 @@ def longestHike(using maze: Maze): Int =
101103
type Index = Int
102104

103105
val indexOf: Map[Point, Index] =
104-
maze.nodes.toList.sortBy(_.dist(maze.start)).zipWithIndex.toMap
106+
maze.junctions.toList.sortBy(_.dist(maze.start)).zipWithIndex.toMap
105107

106108
val adjacent: Map[Index, List[(Index, Int)]] =
107-
maze.nodes.toList.flatMap: p1 =>
108-
nodesFrom(p1).flatMap: (p2, d) =>
109+
maze.junctions.toList.flatMap: p1 =>
110+
connectedJunctions(p1).flatMap: (p2, d) =>
109111
val forward = indexOf(p1) -> (indexOf(p2), d)
110112
val reverse = indexOf(p2) -> (indexOf(p1), d)
111113
List(forward, reverse)
112114
.groupMap(_._1)(_._2)
113115

114-
def search(node: Index, visited: BitSet, dist: Int): Int =
115-
if node == indexOf(maze.end) then dist else
116-
adjacent(node).foldLeft(0):
117-
case (max, (n, d)) =>
118-
if visited(n) then max else
119-
max.max(search(n, visited + n, dist + d))
116+
def search(junction: Index, visited: BitSet, totalDist: Int): Int =
117+
if junction == indexOf(maze.end) then totalDist else
118+
adjacent(junction).foldLeft(0):
119+
case (longest, (nextJunct, dist)) =>
120+
if visited(nextJunct) then longest else
121+
longest.max(search(nextJunct, visited + nextJunct, totalDist + dist))
120122

121123
search(indexOf(maze.start), BitSet.empty, 0)

0 commit comments

Comments
 (0)