# Source code for commonroad_reach.data_structure.reach.reach_node

import copy
from collections import defaultdict
from typing import Optional, Dict, Set, List

from shapely import affinity

from commonroad_reach.data_structure.reach.reach_polygon import ReachPolygon


class ReachNode:
    """
    Node within a reachability graph, also used in reachable set computations.

    .. note::
        - Each node is a Cartesian product of longitudinal and lateral polygons.
        - Curvilinear coordinate system: polygon_lon is a polygon in the longitudinal p-v domain, and polygon_lat is
          a polygon in the lateral p-v domain.
        - Cartesian coordinate system: polygons are in the x-v and y-v domains, respectively.
    """
    # Class-wide counter handing out a unique id to every constructed node (see reset_class_id_counter()).
    cnt_id = 0

    def __init__(self, polygon_lon: ReachPolygon, polygon_lat: ReachPolygon, step: int = -1):
        """
        :param polygon_lon: polygon in the longitudinal p-v domain; a falsy value leaves the bounds unset
        :param polygon_lat: polygon in the lateral p-v domain; a falsy value leaves the bounds unset
        :param step: step of the node, defaults to -1
        """
        self._polygon_lon: ReachPolygon = polygon_lon
        self._polygon_lat: ReachPolygon = polygon_lat
        # Cached bounds laid out as (p_min, v_min, p_max, v_max); read by the p_*/v_* properties below.
        self._bounds_lon = polygon_lon.bounds if polygon_lon else None
        self._bounds_lat = polygon_lat.bounds if polygon_lat else None

        # The position rectangle can only be built once both bounds are known.
        self.position_rectangle: Optional[ReachPolygon] = None
        if self._bounds_lon and self._bounds_lat:
            self.update_position_rectangle()

        self.id = ReachNode.cnt_id
        ReachNode.cnt_id += 1
        self.step = step

        self.list_nodes_parent: List[ReachNode] = list()
        self.list_nodes_child: List[ReachNode] = list()
        # the node from which the current node is propagated
        self.source_propagation = None

    def __repr__(self):
        return f"ReachNode(step={self.step}, id={self.id})"

    def __eq__(self, other: object) -> bool:
        """
        Two nodes are equal when both their id and step match.
        """
        if not isinstance(other, ReachNode):
            # Let Python try the reflected operation; it falls back to identity, i.e. False.
            return NotImplemented
        return self.id == other.id and self.step == other.step

    def __key(self):
        return self.id, self.step

    def __hash__(self):
        # Consistent with __eq__: hashed on (id, step).
        return hash(self.__key())

    @property
    def polygon_lon(self) -> ReachPolygon:
        """
        Polygon in the longitudinal direction. See note of :class:`ReachNode`.
        """
        return self._polygon_lon

    @polygon_lon.setter
    def polygon_lon(self, polygon: ReachPolygon):
        # NOTE: only the cached bounds are refreshed; call update_position_rectangle() to rebuild the rectangle.
        self._polygon_lon = polygon
        self._bounds_lon = polygon.bounds

    @property
    def polygon_lat(self) -> ReachPolygon:
        """
        Polygon in the lateral direction. See note of :class:`ReachNode`.
        """
        return self._polygon_lat

    @polygon_lat.setter
    def polygon_lat(self, polygon: ReachPolygon):
        # NOTE: only the cached bounds are refreshed; call update_position_rectangle() to rebuild the rectangle.
        self._polygon_lat = polygon
        self._bounds_lat = polygon.bounds

    @property
    def p_lon_min(self):
        """
        Minimum position in the longitudinal direction.
        """
        return self._bounds_lon[0]

    @property
    def p_lon_max(self):
        """
        Maximum position in the longitudinal direction.
        """
        return self._bounds_lon[2]

    @property
    def v_lon_min(self):
        """
        Minimum velocity in the longitudinal direction.
        """
        return self._bounds_lon[1]

    @property
    def v_lon_max(self):
        """
        Maximum velocity in the longitudinal direction.
        """
        return self._bounds_lon[3]

    @property
    def p_lat_min(self):
        """
        Minimum position in the lateral direction.
        """
        return self._bounds_lat[0]

    @property
    def p_lat_max(self):
        """
        Maximum position in the lateral direction.
        """
        return self._bounds_lat[2]

    @property
    def v_lat_min(self):
        """
        Minimum velocity in the lateral direction.
        """
        return self._bounds_lat[1]

    @property
    def v_lat_max(self):
        """
        Maximum velocity in the lateral direction.
        """
        return self._bounds_lat[3]

    @property
    def p_x_min(self):
        """
        Minimum x-position in the Cartesian coordinate system.
        """
        return self.p_lon_min

    @property
    def p_x_max(self):
        """
        Maximum x-position in the Cartesian coordinate system.
        """
        return self.p_lon_max

    @property
    def v_x_min(self):
        """
        Minimum x-velocity in the Cartesian coordinate system.
        """
        return self.v_lon_min

    @property
    def v_x_max(self):
        """
        Maximum x-velocity in the Cartesian coordinate system.
        """
        return self.v_lon_max

    @property
    def p_y_min(self):
        """
        Minimum y-position in the Cartesian coordinate system.
        """
        return self.p_lat_min

    @property
    def p_y_max(self):
        """
        Maximum y-position in the Cartesian coordinate system.
        """
        return self.p_lat_max

    @property
    def v_y_min(self):
        """
        Minimum y-velocity in the Cartesian coordinate system.
        """
        return self.v_lat_min

    @property
    def v_y_max(self):
        """
        Maximum y-velocity in the Cartesian coordinate system.
        """
        return self.v_lat_max

    def clone(self) -> "ReachNode":
        """
        Returns a clone of the reach node.

        The polygons are cloned without convexification and the clone receives a fresh id, so it does not compare
        equal to this node. The parent/child lists are deep-copied: the copied nodes keep their ids and steps and
        therefore still compare equal to the originals, but they are distinct objects.
        """
        node_clone = ReachNode(self.polygon_lon.clone(convexify=False),
                               self.polygon_lat.clone(convexify=False),
                               self.step)
        # NOTE(review): deepcopy also copies everything reachable from these nodes (their own parent/child lists,
        # polygons, ...), which may be the whole graph if links are bidirectional -- confirm this is intended.
        node_clone.list_nodes_parent = copy.deepcopy(self.list_nodes_parent)
        node_clone.list_nodes_child = copy.deepcopy(self.list_nodes_child)
        node_clone.source_propagation = self.source_propagation

        return node_clone

    def update_position_rectangle(self):
        """
        Updates the position rectangle based on the latest position attributes.
        """
        tuple_vertices_rectangle = (self.p_lon_min, self.p_lat_min, self.p_lon_max, self.p_lat_max)
        self.position_rectangle = ReachPolygon.from_rectangle_vertices(*tuple_vertices_rectangle)

    def translate(self, p_lon_off: float = 0.0, v_lon_off: float = 0.0, p_lat_off: float = 0.0,
                  v_lat_off: float = 0.0):
        """
        Returns a copy translated by input offsets.

        Positions are shifted along the x-axis and velocities along the y-axis of the respective polygons. The copy
        keeps this node's step but receives a new id and has no parent/child links.
        """
        return ReachNode(
            ReachPolygon.from_polygon(affinity.translate(self.polygon_lon, xoff=p_lon_off, yoff=v_lon_off)),
            ReachPolygon.from_polygon(affinity.translate(self.polygon_lat, xoff=p_lat_off, yoff=v_lat_off)),
            step=self.step)

    def add_parent_node(self, node_parent: "ReachNode"):
        """
        Appends a parent node unless an equal node (same id and step) is already present.
        """
        if node_parent not in self.list_nodes_parent:
            self.list_nodes_parent.append(node_parent)

    def remove_parent_node(self, node_parent: "ReachNode") -> bool:
        """
        Removes a parent node. Returns True if it was present, otherwise False.
        """
        if node_parent in self.list_nodes_parent:
            self.list_nodes_parent.remove(node_parent)
            return True

        return False

    def add_child_node(self, node_child: "ReachNode"):
        """
        Appends a child node unless an equal node (same id and step) is already present.
        """
        if node_child not in self.list_nodes_child:
            self.list_nodes_child.append(node_child)

    def remove_child_node(self, node_child: "ReachNode") -> bool:
        """
        Removes a child node. Returns True if it was present, otherwise False.
        """
        if node_child in self.list_nodes_child:
            self.list_nodes_child.remove(node_child)
            return True

        return False

    def _refresh_bounds(self):
        """
        Re-caches the bounds of both polygons and rebuilds the position rectangle after the polygons changed.
        """
        self._bounds_lon = self._polygon_lon.bounds
        self._bounds_lat = self._polygon_lat.bounds
        self.update_position_rectangle()

    def intersect_in_position_domain(self, p_lon_min: float, p_lat_min: float, p_lon_max: float, p_lat_max: float):
        """
        Perform intersection in the position domain.

        Each polygon is clipped by the halfspaces p <= p_max and -p <= -p_min (i.e. p >= p_min); the cached bounds
        and the position rectangle are refreshed afterwards.
        """
        self._polygon_lon = self.polygon_lon.intersect_halfspace(1, 0, p_lon_max)
        self._polygon_lon = self.polygon_lon.intersect_halfspace(-1, 0, -p_lon_min)
        self._polygon_lat = self.polygon_lat.intersect_halfspace(1, 0, p_lat_max)
        self._polygon_lat = self.polygon_lat.intersect_halfspace(-1, 0, -p_lat_min)

        self._refresh_bounds()

    def intersect_in_velocity_domain(self, v_lon_min: float, v_lat_min: float, v_lon_max: float, v_lat_max: float):
        """
        Perform intersection in the velocity domain.

        Each polygon is clipped by the halfspaces v <= v_max and -v <= -v_min (i.e. v >= v_min). Clipping in
        velocity can also shrink the position extent, so the cached bounds and the position rectangle are refreshed
        afterwards (previously they were left stale).
        """
        self._polygon_lon = self.polygon_lon.intersect_halfspace(0, 1, v_lon_max)
        self._polygon_lon = self.polygon_lon.intersect_halfspace(0, -1, -v_lon_min)
        self._polygon_lat = self.polygon_lat.intersect_halfspace(0, 1, v_lat_max)
        self._polygon_lat = self.polygon_lat.intersect_halfspace(0, -1, -v_lat_min)

        self._refresh_bounds()

    @classmethod
    def reset_class_id_counter(cls):
        """
        Resets the id counter so that the next constructed node receives id 0.
        """
        cls.cnt_id = 0
class ReachNodeMultiGeneration(ReachNode):
    """
    Node within a reachability graph, also used in reachable set computations.

    In addition to :class:`ReachNode`, this class holds sets of reach nodes across generations (grandparents and
    grandchildren), grouped by their step difference to this node.
    """

    def __init__(self, polygon_lon, polygon_lat, step: int = -1):
        super().__init__(polygon_lon, polygon_lat, step)

        # step difference (> 1) -> set of grandparent / grandchild nodes at that distance
        self.dict_time_to_set_nodes_grandparent: Dict[int, Set[ReachNodeMultiGeneration]] = defaultdict(set)
        self.dict_time_to_set_nodes_grandchild: Dict[int, Set[ReachNodeMultiGeneration]] = defaultdict(set)

    def add_grandparent_node(self, node_grandparent: "ReachNodeMultiGeneration") -> bool:
        """
        Adds a grandparent node, keyed by ``self.step - node_grandparent.step``.

        :return: True if the node was added, False if it was already present
        """
        delta_steps = self.step - node_grandparent.step
        assert delta_steps > 1, f"not a grand_parent: node_grandparent.step={node_grandparent.step}, " \
                                f"self.step={self.step}"

        set_nodes = self.dict_time_to_set_nodes_grandparent[delta_steps]
        if node_grandparent in set_nodes:
            return False

        set_nodes.add(node_grandparent)
        return True

    def remove_grandparent_node(self, node_grandparent: "ReachNodeMultiGeneration") -> bool:
        """
        Removes a grandparent node.

        :return: True if the node was present and removed, otherwise False
        """
        delta_steps = self.step - node_grandparent.step
        # .get() so that a miss does not make the defaultdict insert an empty set for this step difference
        set_nodes = self.dict_time_to_set_nodes_grandparent.get(delta_steps)
        if set_nodes is None or node_grandparent not in set_nodes:
            return False

        set_nodes.remove(node_grandparent)
        return True

    def add_grandchild_node(self, node_grandchild: "ReachNodeMultiGeneration") -> bool:
        """
        Adds a grandchild node, keyed by ``node_grandchild.step - self.step``.

        :return: True if the node was added, False if it was already present
        """
        delta_steps = node_grandchild.step - self.step
        assert delta_steps > 1, f"not a grandchild: node_grandchild.step={node_grandchild.step}, " \
                                f"self.step={self.step}"

        set_nodes = self.dict_time_to_set_nodes_grandchild[delta_steps]
        if node_grandchild in set_nodes:
            return False

        set_nodes.add(node_grandchild)
        return True

    def remove_grandchild_node(self, node_grandchild: "ReachNodeMultiGeneration") -> bool:
        """
        Removes a grandchild node.

        :return: True if the node was present and removed, otherwise False
        """
        delta_steps = node_grandchild.step - self.step
        # .get() so that a miss does not make the defaultdict insert an empty set for this step difference
        set_nodes = self.dict_time_to_set_nodes_grandchild.get(delta_steps)
        if set_nodes is None or node_grandchild not in set_nodes:
            return False

        set_nodes.remove(node_grandchild)
        return True