KrishSingaria/graph-based-neural-A-star

Generative Neural A*

This repository contains the implementation of Generative Neural-A*, a novel approach that accelerates A* search using generative AI. By combining a Graph Attention Network (GAT) encoder with a Conditional Diffusion Model generator, this system learns to predict highly optimized, maze-specific heuristic functions $h^*(n)$ that significantly reduce the number of node expansions required during pathfinding while preserving path optimality.

Overview

A draft paper is available in the paper folder.

Traditional A* relies on static heuristics (such as Manhattan or Euclidean distance), which do not adapt to complex obstacle topologies. Generative Neural-A* replaces these static heuristics by treating heuristic generation as a conditional denoising process. The system takes the maze structure and the start and goal locations as input, then "diffuses" a smooth, topology-aware heuristic map (the "Logarithmic Trench") that guides A* search around walls and obstacles.
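For orientation, one common way to write a single conditional denoising step is the standard DDPM noise-prediction update shown here. This is a sketch under assumptions: the README does not state which sampler or noise schedule the repository uses, and the symbols ($x_t$ for the noisy heuristic map at step $t$, $c$ for the GAT conditioning, $\epsilon_\theta$ for the denoising network, $\alpha_t$, $\bar{\alpha}_t$, $\sigma_t$ for the schedule) are introduced for illustration only.

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t, c) \right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I), \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s
$$

Under this reading, sampling starts from pure noise $x_T$ and applies the update for $t = T, \dots, 1$; the final $x_0$ would play the role of the generated heuristic map.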

Key Components

Architecture

  • Graph Attention Network (GAT) Encoder: Processes the grid as a PyTorch Geometric graph, capturing spatial relationships and obstacle layouts.
  • Conditional Diffusion Model: Acts as the generator. It takes the GAT embeddings as conditioning and performs a reverse diffusion process to output the continuous heuristic heatmap.
  • Neurosymbolic Pipeline: The generated heatmap is dynamically normalized and passed as a custom heuristic function to a standard A* search algorithm.
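The hand-off from heatmap to search described in the last bullet can be sketched as follows. This is a minimal illustration, not the code in utils/search_algos.py: the names `astar` and `heatmap_heuristic`, the min-max normalization, and the `scale` parameter are all assumptions, since the README does not say how the dynamic normalization is computed.

```python
import heapq

def heatmap_heuristic(heatmap, scale=1.0):
    """Wrap a 2D heatmap as a heuristic h(node).

    Assumption: min-max normalization to [0, scale]; the repository's
    dynamic normalization may differ.
    """
    flat = [v for row in heatmap for v in row]
    lo, hi = min(flat), max(flat)
    span = (hi - lo) or 1.0  # avoid division by zero on a flat map

    def h(node):
        r, c = node
        return scale * (heatmap[r][c] - lo) / span

    return h

def astar(grid, start, goal, heuristic):
    """A* on a 4-connected grid; grid[r][c] == 1 marks a wall.

    Returns (path, expansions), or (None, expansions) if no path exists.
    """
    rows, cols = len(grid), len(grid[0])
    g_cost = {start: 0}
    parent = {start: None}
    frontier = [(heuristic(start), 0, start)]  # entries are (f, g, node)
    closed = set()
    expansions = 0
    while frontier:
        _, g, node = heapq.heappop(frontier)
        if node in closed:
            continue  # stale queue entry
        closed.add(node)
        expansions += 1
        if node == goal:
            path = []
            while node is not None:
                path.append(node)
                node = parent[node]
            return path[::-1], expansions
        r, c = node
        for nb in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
            nr, nc = nb
            if not (0 <= nr < rows and 0 <= nc < cols):
                continue
            if grid[nr][nc] == 1 or nb in closed:
                continue
            new_g = g + 1
            if new_g < g_cost.get(nb, float("inf")):
                g_cost[nb] = new_g
                parent[nb] = node
                heapq.heappush(frontier, (new_g + heuristic(nb), new_g, nb))
    return None, expansions

# Toy 3x3 maze; the heatmap stands in for a generated heuristic map.
grid = [[0, 0, 0],
        [1, 1, 0],
        [0, 0, 0]]
heatmap = [[6, 5, 4],
           [0, 0, 3],
           [0, 1, 2]]
path, expansions = astar(grid, (0, 0), (2, 0), heatmap_heuristic(heatmap, scale=6.0))
```

Because closed nodes are never reopened in this sketch, the returned path is only guaranteed to be optimal when the heuristic is consistent; a learned heatmap offers no such guarantee on its own.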

Project Structure

generative-neural-astar/
│
├── data/                       # Generated datasets
│   ├── raw/                    # Unprocessed 64x64 grid dictionaries
│   ├── processed/              # PyTorch Geometric graph objects (.pt files)
│   └── evaluation/             # Unseen test set data
│
├── models/                     # Neural Network Architectures
│   ├── gat_encoder.py          # The Graph Attention Network Encoder
│   └── diffusion_generator.py  # The Conditional Diffusion Model Generator
│
├── utils/                      # Core Logic & Helpers
│   ├── maze_env.py             # Maze generation logic
│   ├── graph_builder.py        # Converts 2D grids to PyG Graphs
│   ├── search_algos.py         # BFS (Ground Truth) and standard A* engine
│   └── visualizer.py           # Matplotlib utilities for paths and heatmaps
│
├── scripts/                    # Executable pipeline scripts
│   ├── generate_data.py        # Generates training and evaluation mazes
│   ├── train.py                # End-to-end training of the GAT and Diffusion model
│   ├── evaluate.py             # Evaluates Neural A* vs Standard A* efficiency
│   └── paper_plots.py          # Generates figures for the academic paper
│
├── checkpoints/                # Model weights (system_best.pt)
├── results/                    # Generated output plots from evaluation
└── requirements.txt            # Python dependencies
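As an illustration of what a grid-to-graph conversion like the one in utils/graph_builder.py might involve, the sketch below builds an undirected edge list in the two-row `[sources, targets]` layout that PyTorch Geometric expects for `edge_index`. The function name `grid_to_edges`, the 4-connectivity, and the choice to drop wall cells rather than keep them as featured nodes are assumptions, not details taken from the repository.

```python
def grid_to_edges(grid):
    """Map free cells (value 0) to node ids and list their 4-connected edges.

    Returns (node_id, edge_index) where edge_index is [sources, targets]
    with each undirected edge stored in both directions.
    """
    rows, cols = len(grid), len(grid[0])
    node_id = {}
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] == 0:
                node_id[(r, c)] = len(node_id)  # ids follow row-major order

    sources, targets = [], []
    for (r, c), u in node_id.items():
        # Looking only down and right visits each undirected edge once.
        for nb in ((r + 1, c), (r, c + 1)):
            v = node_id.get(nb)
            if v is not None:
                sources += [u, v]
                targets += [v, u]
    return node_id, [sources, targets]

node_id, edge_index = grid_to_edges([[0, 0],
                                     [1, 0]])
```

In a PyTorch Geometric setting, the nested list would typically be converted with `torch.tensor(edge_index, dtype=torch.long)` and stored on a `torch_geometric.data.Data` object alongside per-node features.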

Installation

Ensure you have Python 3.9+ and a CUDA-enabled GPU (recommended). Install the required dependencies:

pip install -r requirements.txt

Note: The requirements.txt specifies cu121 for PyTorch to ensure GPU compatibility. Adjust this if you are using a different CUDA version.

Pipeline Usage

Run the scripts from the root directory using the -m flag.

1. Data Generation

Generate the maze datasets (20,000 training mazes, 100 evaluation mazes). This process will automatically calculate the ground truth optimal heuristics (the "Logarithmic Trench") via BFS.

python -m scripts.generate_data
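The BFS ground truth mentioned above can be pictured as a distance-to-goal map over the free cells. The sketch below assumes unit step costs and 4-connectivity, and the function name `bfs_distance_map` is hypothetical. It stops at raw BFS distances: the transformation that turns them into the "Logarithmic Trench" is not specified in this README and is not reproduced here.

```python
from collections import deque

def bfs_distance_map(grid, goal):
    """Exact shortest-path distance from every free cell to the goal.

    grid[r][c] == 1 marks a wall. Walls and unreachable cells stay None.
    """
    rows, cols = len(grid), len(grid[0])
    dist = [[None] * cols for _ in range(rows)]
    dist[goal[0]][goal[1]] = 0
    queue = deque([goal])
    while queue:
        r, c = queue.popleft()
        for nr, nc in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
            inside = 0 <= nr < rows and 0 <= nc < cols
            if inside and grid[nr][nc] == 0 and dist[nr][nc] is None:
                dist[nr][nc] = dist[r][c] + 1
                queue.append((nr, nc))
    return dist

dist = bfs_distance_map([[0, 0, 0],
                         [1, 1, 0],
                         [0, 0, 0]], goal=(2, 0))
```

Searching outward from the goal yields the distance for every cell in a single pass, which is why BFS suits dataset generation better than running one search per cell.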

2. Training

Train the Generative Neural-A* system (GAT encoder + diffusion U-Net) end-to-end. This script saves the best weights to checkpoints/system_best.pt.

python -m scripts.train

3. Evaluation

Test the trained model on the unseen evaluation set. The script pits Standard A* (with Manhattan heuristic) against Generative Neural A*, calculating node expansions and path optimality metrics. Visualizations of the paths are saved in the results/ directory.

python -m scripts.evaluate
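The two metrics named above could be computed per maze along these lines. This is a hedged sketch: `compare_runs` and its dictionary keys are hypothetical names, the exact definitions used by scripts/evaluate.py may differ, and the numbers in the usage line are placeholders, not measured results.

```python
def compare_runs(baseline_expansions, neural_expansions, optimal_length, neural_length):
    """Summarize one maze: expansion savings and path-length overhead.

    Assumes baseline_expansions > 0 and optimal_length > 0.
    """
    reduction_pct = 100.0 * (baseline_expansions - neural_expansions) / baseline_expansions
    length_ratio = neural_length / optimal_length  # 1.0 means the path is optimal
    return {
        "expansion_reduction_pct": reduction_pct,
        "path_length_ratio": length_ratio,
        "is_optimal": neural_length == optimal_length,
    }

# Placeholder inputs for illustration only.
summary = compare_runs(baseline_expansions=200, neural_expansions=50,
                       optimal_length=40, neural_length=42)
```

Reporting the length ratio alongside the expansion reduction makes any trade between speed and optimality visible for each maze.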

4. Paper Plots

Generate all figures and ablation studies for the paper (Pareto frontiers, heuristic evolution, and diffusion denoising timelines).

python -m scripts.paper_plots

Generated Figures

The scripts/paper_plots.py script automatically creates the following visualizations in the project root:

  • fig_1_pareto_frontier.png: Efficiency vs. optimality tradeoff across different heuristic scales.
  • fig_2_heuristic_evolution.png: Visualizes the mathematical transition from pure BFS to the Logarithmic Trench.
  • fig_3_denoising_timeline.png: Step-by-step reverse diffusion process generating the heuristic map.
  • fig_4_pareto_paths.png: Side-by-side path comparisons demonstrating dynamic normalization scaling.

The repository release contains the trained model.
