Skip to content

Commit 54a33ee

Browse files
committed
it works and is WAY faster than a*
1 parent 30a61ad commit 54a33ee

File tree

2 files changed

+332
-1
lines changed

2 files changed

+332
-1
lines changed

PathPlanning/TimeBasedPathPlanning/GridWithDynamicObstacles.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def __sub__(self, other):
3030
f"Subtraction not supported for Position and {type(other)}"
3131
)
3232

33+
@dataclass
34+
class Interval:
35+
start_time: int
36+
end_time: int
3337

3438
class ObstacleArrangement(Enum):
3539
# Random obstacle positions and movements
@@ -180,10 +184,13 @@ def obstacle_arrangement_1(self, obs_count: int) -> list[list[Position]]:
180184
output:
181185
bool: True if position/time combination is valid, False otherwise
182186
"""
183-
def valid_position(self, position: Position, t: int) -> bool:
187+
def valid_position(self, position: Position, t: int = None) -> bool:
184188
# Check if new position is in grid
185189
if not self.inside_grid_bounds(position):
186190
return False
191+
192+
if not t:
193+
return True
187194

188195
# Check if new position is not occupied at time t
189196
return self.reservation_matrix[position.x, position.y, t] == 0
@@ -231,6 +238,52 @@ def get_obstacle_positions_at_time(self, t: int) -> tuple[list[int], list[int]]:
231238
y_positions.append(obs_path[t].y)
232239
return (x_positions, y_positions)
233240

241+
"""
242+
Returns safe intervals for each cell
243+
"""
244+
def get_safe_intervals(self) -> np.ndarray:
245+
intervals = np.empty((self.grid_size[0], self.grid_size[1]), dtype=object)
246+
intervals[:] = [[[] for _ in range(intervals.shape[1])] for _ in range(intervals.shape[0])]
247+
for x in range(intervals.shape[0]):
248+
for y in range(intervals.shape[1]):
249+
intervals[x, y] = self.get_safe_intervals_at_cell(Position(x, y))
250+
251+
return intervals
252+
253+
def get_safe_intervals_at_cell(self, cell: Position) -> list[Interval]:
254+
vals = self.reservation_matrix[cell.x, cell.y, :]
255+
# Find where the array is zero
256+
zero_mask = (vals == 0)
257+
258+
# Identify transitions between zero and nonzero elements
259+
diff = np.diff(zero_mask.astype(int))
260+
261+
# Start indices: where zeros begin (1 after a nonzero)
262+
start_indices = np.where(diff == 1)[0] + 1
263+
264+
# End indices: where zeros stop (just before a nonzero)
265+
end_indices = np.where(diff == -1)[0]
266+
267+
# Handle edge cases if the array starts or ends with zeros
268+
if zero_mask[0]: # If the first element is zero, add index 0 to start_indices
269+
start_indices = np.insert(start_indices, 0, 0)
270+
if zero_mask[-1]: # If the last element is zero, add the last index to end_indices
271+
end_indices = np.append(end_indices, len(vals) - 1)
272+
273+
# Create pairs of (first zero, last zero)
274+
# TODO - this is generating np.int instead of normal int, is that alright?
275+
intervals = [Interval(start, end) for start, end in zip(start_indices, end_indices)]
276+
277+
print(f"intervals at position {cell} : {intervals}")
278+
279+
# for i in range(len(intervals)):
280+
for interval in intervals:
281+
if interval.start_time == interval.end_time:
282+
print("AAAAAAAAAA matching! ", interval.start_time)
283+
# TODO: hate this modification in the loop
284+
intervals.remove(interval)
285+
286+
return intervals
234287

235288
show_animation = True
236289

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""
2+
Safe interval path planner
3+
4+
TODO: populate docstring
5+
"""
6+
7+
import numpy as np
8+
import matplotlib.pyplot as plt
9+
from PathPlanning.TimeBasedPathPlanning.GridWithDynamicObstacles import (
10+
Grid,
11+
Interval,
12+
ObstacleArrangement,
13+
Position,
14+
)
15+
import heapq
16+
from collections.abc import Generator
17+
import random
18+
from dataclasses import dataclass
19+
from functools import total_ordering
20+
21+
22+
# Seed randomness for reproducibility
23+
RANDOM_SEED = 50
24+
random.seed(RANDOM_SEED)
25+
np.random.seed(RANDOM_SEED)
26+
27+
@dataclass()
28+
# Note: Total_ordering is used instead of adding `order=True` to the @dataclass decorator because
29+
# this class needs to override the __lt__ and __eq__ methods to ignore parent_index. Parent
30+
# index is just used to track the path found by the algorithm, and has no effect on the quality
31+
# of a node.
32+
@total_ordering
33+
class Node:
34+
position: Position
35+
time: int
36+
heuristic: int
37+
parent_index: int
38+
interval: Interval
39+
40+
"""
41+
This is what is used to drive node expansion. The node with the lowest value is expanded next.
42+
This comparison prioritizes the node with the lowest cost-to-come (self.time) + cost-to-go (self.heuristic)
43+
"""
44+
def __lt__(self, other: object):
45+
if not isinstance(other, Node):
46+
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
47+
# TODO: assumption that these two carry all the info needed for intervals. I think that makes sense but should think about it
48+
return (self.time + self.heuristic) < (other.time + other.heuristic)
49+
50+
"""
51+
TODO - note about interval being included here
52+
"""
53+
def __eq__(self, other: object):
54+
if not isinstance(other, Node):
55+
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
56+
return self.position == other.position and self.time == other.time and self.interval == other.interval
57+
58+
59+
class NodePath:
60+
path: list[Node]
61+
positions_at_time: dict[int, Position] = {}
62+
63+
def __init__(self, path: list[Node]):
64+
self.path = path
65+
for node in path:
66+
self.positions_at_time[node.time] = node.position
67+
68+
"""
69+
Get the position of the path at a given time
70+
"""
71+
def get_position(self, time: int) -> Position | None:
72+
return self.positions_at_time.get(time)
73+
74+
"""
75+
Time stamp of the last node in the path
76+
"""
77+
def goal_reached_time(self) -> int:
78+
return self.path[-1].time
79+
80+
def __repr__(self):
81+
repr_string = ""
82+
for i, node in enumerate(self.path):
83+
repr_string += f"{i}: {node}\n"
84+
return repr_string
85+
86+
87+
class SafeIntervalPathPlanner:
88+
grid: Grid
89+
start: Position
90+
goal: Position
91+
92+
def __init__(self, grid: Grid, start: Position, goal: Position):
93+
self.grid = grid
94+
self.start = start
95+
self.goal = goal
96+
97+
def plan(self, verbose: bool = False) -> NodePath:
98+
99+
safe_intervals = self.grid.get_safe_intervals()
100+
print(safe_intervals[0, 0])
101+
102+
open_set: list[Node] = []
103+
first_node_interval = safe_intervals[self.start.x, self.start.y][0]
104+
heapq.heappush(
105+
open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1, first_node_interval)
106+
)
107+
108+
expanded_list: list[Node] = []
109+
# TODO: copy pasta from Grid file
110+
# 2d np array of lists of (entry time, interval tuples)
111+
# TODO: use a dataclass for the tuple
112+
visited_intervals = np.empty((self.grid.grid_size[0], self.grid.grid_size[1]), dtype=object)
113+
visited_intervals[:] = [[[] for _ in range(visited_intervals.shape[1])] for _ in range(visited_intervals.shape[0])]
114+
while open_set:
115+
expanded_node: Node = heapq.heappop(open_set)
116+
if verbose:
117+
print("Expanded node:", expanded_node)
118+
119+
if expanded_node.time + 1 >= self.grid.time_limit:
120+
if verbose:
121+
print(f"\tSkipping node that is past time limit: {expanded_node}")
122+
continue
123+
124+
if expanded_node.position == self.goal:
125+
print(f"Found path to goal after {len(expanded_list)} expansions")
126+
path = []
127+
path_walker: Node = expanded_node
128+
while path_walker.parent_index != -1:
129+
path.append(path_walker)
130+
path_walker = expanded_list[path_walker.parent_index]
131+
# TODO: fix hack around bad while condiiotn
132+
path.append(path_walker)
133+
134+
# reverse path so it goes start -> goal
135+
path.reverse()
136+
return NodePath(path)
137+
138+
expanded_idx = len(expanded_list)
139+
expanded_list.append(expanded_node)
140+
visited_intervals[expanded_node.position.x, expanded_node.position.y].append((expanded_node.time, expanded_node.interval))
141+
142+
# if len(expanded_set) > 100:
143+
# blarg
144+
145+
for child in self.generate_successors(expanded_node, expanded_idx, verbose, safe_intervals, visited_intervals):
146+
heapq.heappush(open_set, child)
147+
148+
raise Exception("No path found")
149+
150+
"""
151+
Generate possible successors of the provided `parent_node`
152+
"""
153+
# TODO: is intervals being passed by ref? (i think so?)
154+
def generate_successors(
155+
self, parent_node: Node, parent_node_idx: int, verbose: bool, intervals: np.ndarray, visited_intervals: np.ndarray
156+
) -> list[Node]:
157+
new_nodes = []
158+
diffs = [
159+
Position(0, 0),
160+
Position(1, 0),
161+
Position(-1, 0),
162+
Position(0, 1),
163+
Position(0, -1),
164+
]
165+
for diff in diffs:
166+
new_pos = parent_node.position + diff
167+
if not self.grid.valid_position(new_pos):
168+
continue
169+
170+
current_interval = parent_node.interval
171+
172+
new_cell_intervals: list[Interval] = intervals[new_pos.x, new_pos.y]
173+
for interval in new_cell_intervals:
174+
# if interval ends before current starts, skip
175+
if interval.end_time < current_interval.start_time:
176+
continue
177+
178+
# if interval starts after current ends, break
179+
# TODO: assumption here that intervals are sorted (they should be)
180+
if interval.start_time > current_interval.end_time:
181+
break
182+
183+
# TODO: this bit feels wonky
184+
# if we have already expanded a node in this interval with a <= starting time, continue
185+
better_node_expanded = False
186+
for visited in visited_intervals[new_pos.x, new_pos.y]:
187+
if interval == visited[1] and visited[0] <= parent_node.time + 1:
188+
better_node_expanded = True
189+
break
190+
if better_node_expanded:
191+
continue
192+
193+
# We know there is some overlap. Generate successor at the earliest possible time the
194+
# new interval can be entered
195+
# TODO: dont love the optionl usage here
196+
new_node_t = None
197+
for possible_t in range(max(parent_node.time + 1, interval.start_time), min(current_interval.end_time, interval.end_time)):
198+
if self.grid.valid_position(new_pos, possible_t):
199+
new_node_t = possible_t
200+
break
201+
202+
if new_node_t:
203+
# TODO: should be able to break here?
204+
new_nodes.append(Node(
205+
new_pos,
206+
# entry is max of interval start and parent node start time (get there as soon as possible)
207+
max(parent_node.time + 1, interval.start_time),
208+
self.calculate_heuristic(new_pos),
209+
parent_node_idx,
210+
interval,
211+
))
212+
213+
return new_nodes
214+
215+
def calculate_heuristic(self, position) -> int:
216+
diff = self.goal - position
217+
return abs(diff.x) + abs(diff.y)
218+
219+
220+
show_animation = True
221+
verbose = True
222+
223+
# TODO: viz shows obstacle finish 1 cell above the goal?
224+
def main():
225+
start = Position(1, 18)
226+
goal = Position(19, 19)
227+
grid_side_length = 21
228+
grid = Grid(
229+
np.array([grid_side_length, grid_side_length]),
230+
# TODO: if this is set to 0, still get some obstacles with random set
231+
num_obstacles=22,
232+
obstacle_avoid_points=[start, goal],
233+
obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1,
234+
# obstacle_arrangement=ObstacleArrangement.RANDOM,
235+
)
236+
237+
planner = SafeIntervalPathPlanner(grid, start, goal)
238+
path = planner.plan(verbose)
239+
240+
if verbose:
241+
print(f"Path: {path}")
242+
243+
if not show_animation:
244+
return
245+
246+
fig = plt.figure(figsize=(10, 7))
247+
ax = fig.add_subplot(
248+
autoscale_on=False,
249+
xlim=(0, grid.grid_size[0] - 1),
250+
ylim=(0, grid.grid_size[1] - 1),
251+
)
252+
ax.set_aspect("equal")
253+
ax.grid()
254+
ax.set_xticks(np.arange(0, grid_side_length, 1))
255+
ax.set_yticks(np.arange(0, grid_side_length, 1))
256+
257+
(start_and_goal,) = ax.plot([], [], "mD", ms=15, label="Start and Goal")
258+
start_and_goal.set_data([start.x, goal.x], [start.y, goal.y])
259+
(obs_points,) = ax.plot([], [], "ro", ms=15, label="Obstacles")
260+
(path_points,) = ax.plot([], [], "bo", ms=10, label="Path Found")
261+
ax.legend(bbox_to_anchor=(1.05, 1))
262+
263+
# for stopping simulation with the esc key.
264+
plt.gcf().canvas.mpl_connect(
265+
"key_release_event", lambda event: [exit(0) if event.key == "escape" else None]
266+
)
267+
268+
for i in range(0, path.goal_reached_time()):
269+
obs_positions = grid.get_obstacle_positions_at_time(i)
270+
obs_points.set_data(obs_positions[0], obs_positions[1])
271+
path_position = path.get_position(i)
272+
path_points.set_data([path_position.x], [path_position.y])
273+
plt.pause(0.2)
274+
plt.show()
275+
276+
277+
if __name__ == "__main__":
278+
main()

0 commit comments

Comments
 (0)