# Source code for commonroad_reach.data_structure.reach.reach_set_py_graph_online

import pyximport

pyximport.install()

import os
import pickle
import logging
import warnings
from functools import lru_cache
from collections import defaultdict
from typing import Optional, List, Dict, Tuple

import numpy as np
from scipy import sparse

from commonroad.planning.planning_problem import PlanningProblem
from commonroad.scenario.scenario import Scenario
from commonroad_reach.__version__ import __version__
from commonroad_reach.data_structure.collision_checker import CollisionChecker
from commonroad_reach.data_structure.configuration import Configuration, VehicleConfiguration, ReachableSetConfiguration
from commonroad_reach.data_structure.reach.reach_node import ReachNodeMultiGeneration, ReachNode
from commonroad_reach.data_structure.reach.reach_polygon import ReachPolygon
from commonroad_reach.data_structure.reach.reach_set import ReachableSet
from commonroad_reach.data_structure.regular_grid import RegularGrid
import commonroad_reach.utility.logger as util_logger

logger = logging.getLogger(__name__)


class PyGraphReachableSetOnline(ReachableSet):
    """
    Online step in the graph-based reachable set computation with Python backend.

    The reachability graph (node attributes, parent/grandparent adjacency matrices and bounding boxes)
    is restored once from the pickle file of the offline computation. For every step, a boolean
    reachability grid is propagated through the adjacency matrices and intersected with the obstacle grid.

    NOTE: several methods are memoised with ``functools.lru_cache``. The caches key on ``self`` and are
    cleared explicitly in :meth:`initialize_new_scenario` and :meth:`_prune_nodes_not_reaching_final_step`;
    any new code that changes the reachability grids or the configuration has to clear them as well.
    """

    def __init__(self, config: Configuration):
        """
        Restores the offline reachability graph and prepares the computation for the configured scenario.

        :param config: configuration providing scenario, planning problem and reachable set settings
        """
        super().__init__(config)
        self._num_time_steps_offline_computation = 0
        self._collision_checker: Optional[CollisionChecker] = None
        # step -> boolean grid; only cells set to True are considered reachable
        self._reachability_grid: Dict[int, np.ndarray] = {}
        self.obstacle_grid: Optional[RegularGrid] = None

        # raw data loaded from the offline computation result
        self.dict_time_to_list_tuples_reach_node_attributes = {}
        self.dict_time_to_adjacency_matrices_parent = {}
        self.dict_time_to_adjacency_matrices_grandparent = {}
        self.reachset_bb_ll: Dict[int, np.ndarray] = dict()
        self.reachset_bb_ur: Dict[int, np.ndarray] = dict()

        # contains all pre-computed nodes
        self._dict_time_to_reachable_set_all: Dict[int, List[ReachNodeMultiGeneration]] = defaultdict(list)
        self._dict_time_to_drivable_area_all: Dict[int, List[ReachPolygon]] = defaultdict(list)

        self._restore_reachability_graph()
        self._initialize_collision_checker()
        self.initialize_new_scenario(self.config.scenario, self.config.planning_problem)

        logger.info("PyGraphReachableSetOnline initialized.")

    @property
    def max_evaluated_step(self) -> int:
        """Largest step for which a reachability grid is stored."""
        return max(self._reachability_grid)

    def _dict_step_to_drivable_area(self) -> Dict[int, List[ReachPolygon]]:
        """Collects the drivable area of every computed step in a dictionary."""
        return {t: self.drivable_area_at_step(t) for t in self._list_steps_computed}

    def _dict_step_to_reachable_set(self) -> Dict[int, List[ReachNodeMultiGeneration]]:
        """Collects the reachable set of every computed step in a dictionary."""
        return {t: self.reachable_set_at_step(t) for t in self._list_steps_computed}

    @lru_cache(128)
    def _occ_grid_at_step(self, step: int) -> np.ndarray:
        """
        Returns the occupancy grid of the obstacle grid at the given step as a column vector.

        NOTE(review): the result is AND-ed with the propagated cells, so truthy entries presumably mark
        cells that are allowed (collision-free) - confirm against RegularGrid.

        :param step: step for which the grid is queried
        :return: array of shape (n_cells, 1)
        """
        occupied_grid_obs = self.obstacle_grid.occupancy_grid_at_step(step, self.reachset_translation(step))
        return occupied_grid_obs.reshape([-1, 1])

    @lru_cache(128)
    def drivable_area_at_step(self, step: int) -> List[ReachPolygon]:
        """
        Returns the translated position rectangles of all cells that are reachable at the given step.

        :param step: step for which the drivable area is retrieved
        :return: list of polygons; empty (with a logged warning) if the step has not been computed
        """
        if step not in self._list_steps_computed:
            util_logger.print_and_log_warning(logger,
                                              f"Given step {step} for drivable area retrieval is out of range.")
            return []

        rectangle_list_all = self._dict_time_to_drivable_area_all[step]
        drivable_area = []
        for index_reachset, reachable in enumerate(self._reachability_grid[step].flatten()):
            if not reachable:
                continue

            try:
                vertices = rectangle_list_all[index_reachset].vertices
                # NOTE(review): in-place addition; if `vertices` exposes the polygon's internal array,
                # this also shifts the pre-computed rectangle - confirm against ReachPolygon.
                vertices += self.reachset_translation(step)
                drivable_area.append(ReachPolygon(vertices, fix_vertices=False))

            except Exception:
                # best effort: a cell whose rectangle cannot be restored is skipped, but no longer silently
                logger.debug("Skipping cell %s at step %s in drivable area retrieval.",
                             index_reachset, step, exc_info=True)
                continue

        return drivable_area

    @lru_cache(128)
    def reachable_set_at_step(self, step: int) -> List[ReachNodeMultiGeneration]:
        """
        Returns the translated reach nodes of all cells that are reachable at the given step.

        For steps larger than zero, the parent-child relationships to the nodes of the previous step
        are restored as well.

        :param step: step for which the reachable set is retrieved
        :return: list of nodes; empty (with a logged warning) if the step has not been computed
        """
        if step not in self._list_steps_computed:
            util_logger.print_and_log_warning(logger,
                                              f"Given step {step} for reachable set retrieval is out of range.")
            return []

        reachset_list_all = self._dict_time_to_reachable_set_all[step]
        translation = self.reachset_translation(step)
        # nodes are appended in ascending cell index; _restore_parent_node_relationships relies on this order
        reachset = [reachset_list_all[index_reachset[0]].translate(p_lon_off=translation[0],
                                                                   p_lat_off=translation[1],
                                                                   v_lon_off=self.config.planning.v_lon_initial)
                    for index_reachset in np.argwhere(self._reachability_grid[step])]

        if step > 0:
            self._restore_parent_node_relationships(reachset, step)

        return reachset

    @lru_cache(128)
    def time_step(self, time_index: int) -> int:
        """
        Converts relative time index (initial time_index = 0) to time_step (initial step = step_start)
        """
        return time_index + self.step_start

    @lru_cache(128)
    def reachset_translation(self, step: int) -> np.ndarray:
        """
        Translation of initial state at the given step.

        Computed as the initial position plus the initial velocity times the elapsed time (dt * step).
        """
        return self.config.planning.p_initial \
            + self.config.planning.v_initial * self.config.scenario.dt * step

    def initialize_new_scenario(self, scenario: Optional[Scenario] = None,
                                planning_problem: Optional[PlanningProblem] = None):
        """
        Resets online computation for evaluation of new scenario and/or planning problem;
        thus, avoid time for parsing pickle file again.

        :param scenario: new scenario (keep old scenario if None)
        :param planning_problem: new planning_problem (keep old planning_problem if None)
        """
        update_cc = planning_problem is not None or scenario is not None

        if scenario is not None:
            self.config.scenario = scenario

        if planning_problem is not None:
            self.config.update(scenario, planning_problem)

        # Both caches hold values derived from the configuration (initial state, dt, step_start),
        # so they are dropped on every reset to rule out stale entries.
        self.reachset_translation.cache_clear()
        self.time_step.cache_clear()

        if update_cc:
            self._initialize_collision_checker()
            self.obstacle_grid = RegularGrid(self.reachset_bb_ll, self.reachset_bb_ur,
                                             self._collision_checker.cpp_collision_checker,
                                             self.config.reachable_set.size_grid,
                                             self.config.reachable_set.size_grid,
                                             self.config.planning,
                                             a_lon=self.config.vehicle.ego.a_max,
                                             a_lat=self.config.vehicle.ego.a_max,
                                             t_f=self.config.scenario.dt * self._num_time_steps_offline_computation,
                                             grid_shapes=self._grid_shapes)

        # discard all results of the previous evaluation
        self._reachability_grid.clear()
        self.reachable_set_at_step.cache_clear()
        self.drivable_area_at_step.cache_clear()
        self._occ_grid_at_step.cache_clear()

        self._list_steps_computed = [0]
        # the initial reachable set consists of a single reachable cell
        self._reachability_grid[self.step_start] = np.ones((1, 1), dtype=bool)

    def _restore_parent_node_relationships(self, reachset: List[ReachNode], step: int):
        """
        Restores parent-child relationships.

        :param reachset: nodes of the given step, ordered by ascending index of their reachable cell
        :param step: step of the nodes in reachset
        """
        if step == 0:
            return

        grid_prev = self._reachability_grid[step - 1].flatten()
        grid_current = self._reachability_grid[step]

        # entry i = number of reachable cells with an index smaller than i, i.e., the position of cell i
        # in the list of reachable nodes of the respective step
        ind_2_list_index_prev = np.insert(np.cumsum(grid_prev), 0, 0)
        ind_2_list_index_current = np.insert(np.cumsum(grid_current), 0, 0)

        # convert once instead of densifying the full matrix for every reachable cell
        adjacency_parent = np.asarray(self.dict_time_to_adjacency_matrices_parent[step].todense())

        list_nodes_parent = None
        for index_reachset in np.flatnonzero(grid_current):
            reachable_parents = np.logical_and(grid_prev, adjacency_parent[index_reachset, :])
            list_index_parent = np.flatnonzero(reachable_parents)
            if list_index_parent.size == 0:
                continue

            if list_nodes_parent is None:
                # retrieved lazily: only needed if at least one node has a reachable parent
                list_nodes_parent = self.reachable_set_at_step(step - 1)

            node = reachset[ind_2_list_index_current[index_reachset]]
            for index_parent in list_index_parent:
                try:
                    # same mapping as for the current step; an additional "- 1" would select the
                    # preceding node (and the last node of the list for the first reachable parent)
                    parent = list_nodes_parent[ind_2_list_index_prev[index_parent]]

                except IndexError:
                    continue

                parent.add_child_node(node)
                node.add_parent_node(parent)

    def _restore_reachability_graph(self):
        """
        Restores reachability graph from the offline computation result.
        """
        self.dict_time_to_list_tuples_reach_node_attributes, self.dict_time_to_adjacency_matrices_parent, \
            self.dict_time_to_adjacency_matrices_grandparent, \
            self.reachset_bb_ll, self.reachset_bb_ur = self._load_offline_computation_result()

        self._num_time_steps_offline_computation = len(self.dict_time_to_list_tuples_reach_node_attributes)

        # number of grid cells per axis: extent of the bounding box divided by the grid size
        size_grid = self.config.reachable_set.size_grid
        self._grid_shapes = {t: (round((self.reachset_bb_ur[t][0] - ll[0]) / size_grid),
                                 round((self.reachset_bb_ur[t][1] - ll[1]) / size_grid))
                             for t, ll in self.reachset_bb_ll.items()}

        self._restore_reachable_sets(self.dict_time_to_list_tuples_reach_node_attributes)

    def _load_offline_computation_result(self) -> Tuple:
        """
        Loads pickle file generated in the offline computation step.

        :return: node attributes, parent adjacency matrices, grandparent adjacency matrices,
            lower-left and upper-right corners of the bounding boxes
        :raises ValueError: if the file was created with a different version of commonroad-reach
        :raises AssertionError: if the coordinate system or the number of multi-steps does not match
        """
        util_logger.print_and_log_info(logger, "* Loading offline computation result...")
        path_file_pickle = os.path.join(self.config.general.path_offline_data,
                                        self.config.reachable_set.name_pickle_offline)

        # NOTE: unpickling can execute arbitrary code; only load offline data from trusted sources.
        with open(path_file_pickle, "rb") as file_pickle:
            dict_data = pickle.load(file_pickle)

        if dict_data["__version__"] != __version__:
            raise ValueError(f"Offline data was created with an older version of commonroad-reach "
                             f"{dict_data['__version__']}. "
                             f"Please recreate the file with the current version {__version__} !")

        assert dict_data["coordinate_system"] == self.config.planning.coordinate_system, \
            f"pickle file was created for coordinate_system={dict_data['coordinate_system']}," \
            f"not {self.config.planning.coordinate_system}!"

        self._validate_configurations(self.config.reachable_set, self.config.vehicle,
                                      dict_data["config.reachable_set"], dict_data["config.vehicle"])

        return dict_data["node_attributes"], dict_data["adjacency_matrices_parent"], \
            dict_data["adjacency_matrices_grandparent"], \
            dict_data["reachset_bb_ll"], dict_data["reachset_bb_ur"]

    @staticmethod
    def _validate_configurations(reachset_config_online: ReachableSetConfiguration,
                                 vehicle_config_online: VehicleConfiguration,
                                 reachset_config_offline: ReachableSetConfiguration,
                                 vehicle_config_offline: VehicleConfiguration):
        """
        Ensures that original configuration from the offline data is used for relevant parameters.

        Relevant parameters of the online configuration that differ from the offline configuration are
        overwritten in place; a warning is emitted and logged for each of them.

        :raises AssertionError: if more multi-steps are requested than were pre-computed
        """

        def validate_and_update_config(config_online, config_offline, relevant_attributes):
            for attr in vars(config_online):
                if attr not in relevant_attributes or not hasattr(config_offline, attr):
                    continue

                online_value = getattr(config_online, attr)
                offline_value = getattr(config_offline, attr)
                if online_value != offline_value:
                    warn_text = f"Parameter {config_online.__class__.__name__}.{attr}=" \
                                f"{online_value}!={offline_value}, which was " \
                                f"used to create {reachset_config_online.name_pickle_offline}. Overwriting value..."
                    warnings.warn(warn_text)
                    logger.warning(warn_text)
                    setattr(config_online, attr, offline_value)

        relevant_attributes_reachset = ["size_grid"]
        relevant_attributes_ego_vehicle = ["a_lon_max", "a_lon_min", "a_lat_max", "a_lat_min", "a_max", "v_lon_max"]

        validate_and_update_config(reachset_config_online, reachset_config_offline,
                                   relevant_attributes_reachset)
        validate_and_update_config(vehicle_config_online.ego, vehicle_config_offline.ego,
                                   relevant_attributes_ego_vehicle)

        assert reachset_config_online.n_multi_steps <= reachset_config_offline.n_multi_steps, \
            f"pre-computed only {reachset_config_offline.n_multi_steps} multi-steps " \
            f"but requested {reachset_config_online.n_multi_steps}"

    def _restore_reachable_sets(self, dict_time_to_list_tuples_reach_node_attributes: Dict[int, Tuple]):
        """
        Restores reachable sets from the offline computation result.

        For every attribute tuple, a reach node (position-velocity polygons in x and y) and the
        corresponding position rectangle are created and stored under the tuple's time step.
        """
        for time_step, list_tuples_attribute in dict_time_to_list_tuples_reach_node_attributes.items():
            # reconstruct nodes in the reachability graph
            for tuple_attribute in list_tuples_attribute:
                p_x_min, p_y_min, p_x_max, p_y_max, v_x_min, v_y_min, v_x_max, v_y_max = tuple_attribute

                polygon_x = ReachPolygon.from_rectangle_vertices(p_x_min, v_x_min, p_x_max, v_x_max)
                polygon_y = ReachPolygon.from_rectangle_vertices(p_y_min, v_y_min, p_y_max, v_y_max)
                node = ReachNodeMultiGeneration(polygon_x, polygon_y, time_step)
                self._dict_time_to_reachable_set_all[time_step].append(node)

                position_rectangle = ReachPolygon.from_rectangle_vertices(p_x_min, p_y_min, p_x_max, p_y_max)
                self._dict_time_to_drivable_area_all[time_step].append(position_rectangle)

    def _initialize_collision_checker(self):
        """Creates a collision checker from the current configuration."""
        self._collision_checker = CollisionChecker(self.config)

    def compute(self, step_start: int = 1, step_end: Optional[int] = None):
        """
        Computes the reachability grids from step_start to step_end (inclusive).

        Stops with a logged warning as soon as a step exceeds the number of pre-computed steps; in that
        case no pruning is performed.

        :param step_start: first step to compute
        :param step_end: last step to compute (self.step_end if None)
        """
        if step_end is None:
            step_end = self.step_end

        for step in range(step_start, step_end + 1):
            # NOTE(review): _forward_propagation raises already for step == number of pre-computed steps,
            # whereas this check only triggers for larger steps - confirm the intended boundary.
            if step > self._num_time_steps_offline_computation:
                util_logger.print_and_log_warning(logger, f"Time step {step} is out of range, max allowed: "
                                                          f"{self._num_time_steps_offline_computation}")
                return

            self._forward_propagation(step, self.config.reachable_set.n_multi_steps)
            self._list_steps_computed.append(step)

        if self.config.reachable_set.prune_nodes_not_reaching_final_step:
            self._prune_nodes_not_reaching_final_step()

    def _forward_propagation(self, step: int, n_multi_steps: int):
        """
        Propagates current reachability grid and excludes forbidden states.

        :param step: initial step of the reachable set
        :param n_multi_steps: number of previous time steps considered for the propagation
        :raises ValueError: if step is not smaller than the number of pre-computed steps
        """
        if step >= self._num_time_steps_offline_computation:
            raise ValueError(f'Reached max number of offline computed time steps '
                             f'({self._num_time_steps_offline_computation})!')

        reachability_grid_prop = self.dict_time_to_adjacency_matrices_parent[step].dot(
            self._reachability_grid[step - 1].reshape([-1, 1]))
        if sparse.issparse(reachability_grid_prop):
            reachability_grid_prop = reachability_grid_prop.toarray()

        # propagate grandparents:
        if step > 1:
            for delta_time_step, adj_matrix_gp in self.dict_time_to_adjacency_matrices_grandparent[step].items():
                # get time index of grandparent to propagate
                step_gp = step - delta_time_step
                if delta_time_step > n_multi_steps or step_gp not in self._reachability_grid:
                    continue

                prop_grandparent_tmp = adj_matrix_gp.dot(self._reachability_grid[step_gp].reshape([-1, 1]))
                if sparse.issparse(prop_grandparent_tmp):
                    prop_grandparent_tmp = prop_grandparent_tmp.toarray()

                # intersect with propagated cells of other time steps
                reachability_grid_prop = np.logical_and(reachability_grid_prop, prop_grandparent_tmp)

        # intersect propagated cells with occupied cells
        self._reachability_grid[step] = np.logical_and(reachability_grid_prop.reshape([-1, 1]),
                                                       self._occ_grid_at_step(step))

    def _prune_nodes_not_reaching_final_step(self):
        """
        Prunes nodes that do not reach the final time step.
        """
        for i_t in range(self.max_evaluated_step - 1, 0, -1):
            self._backward_step(i_t)

        self._pruned = True
        # cached results were derived from the unpruned grids
        self.reachable_set_at_step.cache_clear()
        self.drivable_area_at_step.cache_clear()

    def _backward_step(self, step: int):
        """
        Iterates through reachability graph backward in time, discards nodes that do not have a child node.

        :param step: step whose reachability grid is pruned based on the grid of step + 1
        """
        if step < 0:
            logger.warning('Reached max number of backward time steps')
            return

        # explicit matrix product (as in the forward propagation); "*" is element-wise for arrays
        reachability_grid_prop = self.dict_time_to_adjacency_matrices_parent[step + 1].transpose().dot(
            self._reachability_grid[step + 1].reshape([-1, 1]))
        if sparse.issparse(reachability_grid_prop):
            reachability_grid_prop = reachability_grid_prop.toarray()

        # intersect propagated cells with occupied cells
        reachability_grid_prop_pruned = np.logical_and(reachability_grid_prop.reshape([-1, 1]),
                                                       self._occ_grid_at_step(step))
        self._reachability_grid[step] = \
            np.logical_and(self._reachability_grid[step], reachability_grid_prop_pruned)