|
| 1 | +""" |
| 2 | +Safe interval path planner |
| 3 | +
|
| 4 | +TODO: populate docstring |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +from PathPlanning.TimeBasedPathPlanning.GridWithDynamicObstacles import ( |
| 10 | + Grid, |
| 11 | + Interval, |
| 12 | + ObstacleArrangement, |
| 13 | + Position, |
| 14 | +) |
| 15 | +import heapq |
| 16 | +from collections.abc import Generator |
| 17 | +import random |
| 18 | +from dataclasses import dataclass |
| 19 | +from functools import total_ordering |
| 20 | + |
| 21 | + |
| 22 | +# Seed randomness for reproducibility |
| 23 | +RANDOM_SEED = 50 |
| 24 | +random.seed(RANDOM_SEED) |
| 25 | +np.random.seed(RANDOM_SEED) |
| 26 | + |
| 27 | +@dataclass() |
| 28 | +# Note: Total_ordering is used instead of adding `order=True` to the @dataclass decorator because |
| 29 | +# this class needs to override the __lt__ and __eq__ methods to ignore parent_index. Parent |
| 30 | +# index is just used to track the path found by the algorithm, and has no effect on the quality |
| 31 | +# of a node. |
| 32 | +@total_ordering |
| 33 | +class Node: |
| 34 | + position: Position |
| 35 | + time: int |
| 36 | + heuristic: int |
| 37 | + parent_index: int |
| 38 | + interval: Interval |
| 39 | + |
| 40 | + """ |
| 41 | + This is what is used to drive node expansion. The node with the lowest value is expanded next. |
| 42 | + This comparison prioritizes the node with the lowest cost-to-come (self.time) + cost-to-go (self.heuristic) |
| 43 | + """ |
| 44 | + def __lt__(self, other: object): |
| 45 | + if not isinstance(other, Node): |
| 46 | + return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}") |
| 47 | + # TODO: assumption that these two carry all the info needed for intervals. I think that makes sense but should think about it |
| 48 | + return (self.time + self.heuristic) < (other.time + other.heuristic) |
| 49 | + |
| 50 | + """ |
| 51 | + TODO - note about interval being included here |
| 52 | + """ |
| 53 | + def __eq__(self, other: object): |
| 54 | + if not isinstance(other, Node): |
| 55 | + return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}") |
| 56 | + return self.position == other.position and self.time == other.time and self.interval == other.interval |
| 57 | + |
| 58 | + |
| 59 | +class NodePath: |
| 60 | + path: list[Node] |
| 61 | + positions_at_time: dict[int, Position] = {} |
| 62 | + |
| 63 | + def __init__(self, path: list[Node]): |
| 64 | + self.path = path |
| 65 | + for node in path: |
| 66 | + self.positions_at_time[node.time] = node.position |
| 67 | + |
| 68 | + """ |
| 69 | + Get the position of the path at a given time |
| 70 | + """ |
| 71 | + def get_position(self, time: int) -> Position | None: |
| 72 | + return self.positions_at_time.get(time) |
| 73 | + |
| 74 | + """ |
| 75 | + Time stamp of the last node in the path |
| 76 | + """ |
| 77 | + def goal_reached_time(self) -> int: |
| 78 | + return self.path[-1].time |
| 79 | + |
| 80 | + def __repr__(self): |
| 81 | + repr_string = "" |
| 82 | + for i, node in enumerate(self.path): |
| 83 | + repr_string += f"{i}: {node}\n" |
| 84 | + return repr_string |
| 85 | + |
| 86 | + |
| 87 | +class SafeIntervalPathPlanner: |
| 88 | + grid: Grid |
| 89 | + start: Position |
| 90 | + goal: Position |
| 91 | + |
| 92 | + def __init__(self, grid: Grid, start: Position, goal: Position): |
| 93 | + self.grid = grid |
| 94 | + self.start = start |
| 95 | + self.goal = goal |
| 96 | + |
| 97 | + def plan(self, verbose: bool = False) -> NodePath: |
| 98 | + |
| 99 | + safe_intervals = self.grid.get_safe_intervals() |
| 100 | + print(safe_intervals[0, 0]) |
| 101 | + |
| 102 | + open_set: list[Node] = [] |
| 103 | + first_node_interval = safe_intervals[self.start.x, self.start.y][0] |
| 104 | + heapq.heappush( |
| 105 | + open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1, first_node_interval) |
| 106 | + ) |
| 107 | + |
| 108 | + expanded_list: list[Node] = [] |
| 109 | + # TODO: copy pasta from Grid file |
| 110 | + # 2d np array of lists of (entry time, interval tuples) |
| 111 | + # TODO: use a dataclass for the tuple |
| 112 | + visited_intervals = np.empty((self.grid.grid_size[0], self.grid.grid_size[1]), dtype=object) |
| 113 | + visited_intervals[:] = [[[] for _ in range(visited_intervals.shape[1])] for _ in range(visited_intervals.shape[0])] |
| 114 | + while open_set: |
| 115 | + expanded_node: Node = heapq.heappop(open_set) |
| 116 | + if verbose: |
| 117 | + print("Expanded node:", expanded_node) |
| 118 | + |
| 119 | + if expanded_node.time + 1 >= self.grid.time_limit: |
| 120 | + if verbose: |
| 121 | + print(f"\tSkipping node that is past time limit: {expanded_node}") |
| 122 | + continue |
| 123 | + |
| 124 | + if expanded_node.position == self.goal: |
| 125 | + print(f"Found path to goal after {len(expanded_list)} expansions") |
| 126 | + path = [] |
| 127 | + path_walker: Node = expanded_node |
| 128 | + while path_walker.parent_index != -1: |
| 129 | + path.append(path_walker) |
| 130 | + path_walker = expanded_list[path_walker.parent_index] |
| 131 | + # TODO: fix hack around bad while condiiotn |
| 132 | + path.append(path_walker) |
| 133 | + |
| 134 | + # reverse path so it goes start -> goal |
| 135 | + path.reverse() |
| 136 | + return NodePath(path) |
| 137 | + |
| 138 | + expanded_idx = len(expanded_list) |
| 139 | + expanded_list.append(expanded_node) |
| 140 | + visited_intervals[expanded_node.position.x, expanded_node.position.y].append((expanded_node.time, expanded_node.interval)) |
| 141 | + |
| 142 | + # if len(expanded_set) > 100: |
| 143 | + # blarg |
| 144 | + |
| 145 | + for child in self.generate_successors(expanded_node, expanded_idx, verbose, safe_intervals, visited_intervals): |
| 146 | + heapq.heappush(open_set, child) |
| 147 | + |
| 148 | + raise Exception("No path found") |
| 149 | + |
| 150 | + """ |
| 151 | + Generate possible successors of the provided `parent_node` |
| 152 | + """ |
| 153 | + # TODO: is intervals being passed by ref? (i think so?) |
| 154 | + def generate_successors( |
| 155 | + self, parent_node: Node, parent_node_idx: int, verbose: bool, intervals: np.ndarray, visited_intervals: np.ndarray |
| 156 | + ) -> list[Node]: |
| 157 | + new_nodes = [] |
| 158 | + diffs = [ |
| 159 | + Position(0, 0), |
| 160 | + Position(1, 0), |
| 161 | + Position(-1, 0), |
| 162 | + Position(0, 1), |
| 163 | + Position(0, -1), |
| 164 | + ] |
| 165 | + for diff in diffs: |
| 166 | + new_pos = parent_node.position + diff |
| 167 | + if not self.grid.valid_position(new_pos): |
| 168 | + continue |
| 169 | + |
| 170 | + current_interval = parent_node.interval |
| 171 | + |
| 172 | + new_cell_intervals: list[Interval] = intervals[new_pos.x, new_pos.y] |
| 173 | + for interval in new_cell_intervals: |
| 174 | + # if interval ends before current starts, skip |
| 175 | + if interval.end_time < current_interval.start_time: |
| 176 | + continue |
| 177 | + |
| 178 | + # if interval starts after current ends, break |
| 179 | + # TODO: assumption here that intervals are sorted (they should be) |
| 180 | + if interval.start_time > current_interval.end_time: |
| 181 | + break |
| 182 | + |
| 183 | + # TODO: this bit feels wonky |
| 184 | + # if we have already expanded a node in this interval with a <= starting time, continue |
| 185 | + better_node_expanded = False |
| 186 | + for visited in visited_intervals[new_pos.x, new_pos.y]: |
| 187 | + if interval == visited[1] and visited[0] <= parent_node.time + 1: |
| 188 | + better_node_expanded = True |
| 189 | + break |
| 190 | + if better_node_expanded: |
| 191 | + continue |
| 192 | + |
| 193 | + # We know there is some overlap. Generate successor at the earliest possible time the |
| 194 | + # new interval can be entered |
| 195 | + # TODO: dont love the optionl usage here |
| 196 | + new_node_t = None |
| 197 | + for possible_t in range(max(parent_node.time + 1, interval.start_time), min(current_interval.end_time, interval.end_time)): |
| 198 | + if self.grid.valid_position(new_pos, possible_t): |
| 199 | + new_node_t = possible_t |
| 200 | + break |
| 201 | + |
| 202 | + if new_node_t: |
| 203 | + # TODO: should be able to break here? |
| 204 | + new_nodes.append(Node( |
| 205 | + new_pos, |
| 206 | + # entry is max of interval start and parent node start time (get there as soon as possible) |
| 207 | + max(parent_node.time + 1, interval.start_time), |
| 208 | + self.calculate_heuristic(new_pos), |
| 209 | + parent_node_idx, |
| 210 | + interval, |
| 211 | + )) |
| 212 | + |
| 213 | + return new_nodes |
| 214 | + |
| 215 | + def calculate_heuristic(self, position) -> int: |
| 216 | + diff = self.goal - position |
| 217 | + return abs(diff.x) + abs(diff.y) |
| 218 | + |
| 219 | + |
| 220 | +show_animation = True |
| 221 | +verbose = True |
| 222 | + |
| 223 | +# TODO: viz shows obstacle finish 1 cell above the goal? |
| 224 | +def main(): |
| 225 | + start = Position(1, 18) |
| 226 | + goal = Position(19, 19) |
| 227 | + grid_side_length = 21 |
| 228 | + grid = Grid( |
| 229 | + np.array([grid_side_length, grid_side_length]), |
| 230 | + # TODO: if this is set to 0, still get some obstacles with random set |
| 231 | + num_obstacles=22, |
| 232 | + obstacle_avoid_points=[start, goal], |
| 233 | + obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1, |
| 234 | + # obstacle_arrangement=ObstacleArrangement.RANDOM, |
| 235 | + ) |
| 236 | + |
| 237 | + planner = SafeIntervalPathPlanner(grid, start, goal) |
| 238 | + path = planner.plan(verbose) |
| 239 | + |
| 240 | + if verbose: |
| 241 | + print(f"Path: {path}") |
| 242 | + |
| 243 | + if not show_animation: |
| 244 | + return |
| 245 | + |
| 246 | + fig = plt.figure(figsize=(10, 7)) |
| 247 | + ax = fig.add_subplot( |
| 248 | + autoscale_on=False, |
| 249 | + xlim=(0, grid.grid_size[0] - 1), |
| 250 | + ylim=(0, grid.grid_size[1] - 1), |
| 251 | + ) |
| 252 | + ax.set_aspect("equal") |
| 253 | + ax.grid() |
| 254 | + ax.set_xticks(np.arange(0, grid_side_length, 1)) |
| 255 | + ax.set_yticks(np.arange(0, grid_side_length, 1)) |
| 256 | + |
| 257 | + (start_and_goal,) = ax.plot([], [], "mD", ms=15, label="Start and Goal") |
| 258 | + start_and_goal.set_data([start.x, goal.x], [start.y, goal.y]) |
| 259 | + (obs_points,) = ax.plot([], [], "ro", ms=15, label="Obstacles") |
| 260 | + (path_points,) = ax.plot([], [], "bo", ms=10, label="Path Found") |
| 261 | + ax.legend(bbox_to_anchor=(1.05, 1)) |
| 262 | + |
| 263 | + # for stopping simulation with the esc key. |
| 264 | + plt.gcf().canvas.mpl_connect( |
| 265 | + "key_release_event", lambda event: [exit(0) if event.key == "escape" else None] |
| 266 | + ) |
| 267 | + |
| 268 | + for i in range(0, path.goal_reached_time()): |
| 269 | + obs_positions = grid.get_obstacle_positions_at_time(i) |
| 270 | + obs_points.set_data(obs_positions[0], obs_positions[1]) |
| 271 | + path_position = path.get_position(i) |
| 272 | + path_points.set_data([path_position.x], [path_position.y]) |
| 273 | + plt.pause(0.2) |
| 274 | + plt.show() |
| 275 | + |
| 276 | + |
| 277 | +if __name__ == "__main__": |
| 278 | + main() |
0 commit comments