
Hardware Acceleration via JAX & Rendering Pipeline #7

@elprofesoriqo

Here is a proposed roadmap and feature set to scale NetForge_RL:

  1. The JAX Rewrite (Massive Parallelization)
  • Rewriting the core environment logic in JAX as pure functions, with randomness threaded through explicit jax.random.PRNGKey values, makes it possible to vectorize environments with jax.vmap.
  • The Goal: Run 4,000+ parallel instances of NetForge_RL directly on a single GPU/TPU, completely bypassing Python's GIL and multiprocessing overhead.
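The JAX rewrite described above can be sketched as follows. This is a minimal sketch under hypothetical assumptions: the network state is a single integer array of per-node status (0 = secure, 1 = compromised, 2 = defended), and `reset`, `step`, and the attack/defend rules are illustrative stand-ins, not NetForge_RL's real dynamics. What it shows is the shape of the code: pure functions that take an explicit `jax.random.PRNGKey`, batched with `jax.vmap` and compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy dynamics, not NetForge_RL's actual rules.
NUM_NODES = 8
SECURE, COMPROMISED, DEFENDED = 0, 1, 2


def reset(key):
    """Start with every node secure except one random entry point."""
    status = jnp.full((NUM_NODES,), SECURE, dtype=jnp.int32)
    entry = jax.random.randint(key, (), 0, NUM_NODES)
    return status.at[entry].set(COMPROMISED)


def step(key, status, attack_target, defend_target):
    """Pure transition: no mutation, all randomness comes from `key`."""
    # The attack succeeds with probability 0.5 unless the target is defended.
    success = jax.random.bernoulli(key, 0.5)
    allowed = status[attack_target] != DEFENDED
    attacked = status.at[attack_target].set(COMPROMISED)
    status = jnp.where(success & allowed, attacked, status)
    # The defender then hardens its chosen node.
    status = status.at[defend_target].set(DEFENDED)
    num_compromised = jnp.sum(status == COMPROMISED)
    return status, num_compromised


NUM_ENVS = 4096
reset_keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENVS)
states = jax.vmap(reset)(reset_keys)  # shape (NUM_ENVS, NUM_NODES)

# vmap maps over the leading axis of every argument; jit compiles the whole
# batch into one XLA program, so no Python loop runs per environment.
batched_step = jax.jit(jax.vmap(step))
step_keys = jax.random.split(jax.random.PRNGKey(1), NUM_ENVS)
attack = jnp.zeros(NUM_ENVS, dtype=jnp.int32)  # every attacker targets node 0
defend = jnp.ones(NUM_ENVS, dtype=jnp.int32)   # every defender hardens node 1
states, num_compromised = batched_step(step_keys, states, attack, defend)
```

Because `step` never mutates its inputs, the same function could later be wrapped in `jax.lax.scan` to roll out whole episodes on-device; the actual steps per second on a given GPU/TPU would still need to be measured.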
  2. API Standardization (JaxMARL Compatibility)
  • Aligning the API with JaxMARL.
  • Implementing a pure functional step(key, state, actions) method that returns the next state together with per-agent dictionaries of observations, rewards, and dones.
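The API alignment above could look roughly like the sketch below. It assumes a JaxMARL-style convention in which `step(key, state, actions)` returns observations, the new state, rewards, dones, and infos, with an `"__all__"` entry in `dones`; the exact base class and signature should be checked against JaxMARL's `MultiAgentEnv` before committing to it. `NetForgeSketch`, `EnvState`, the agent names `"red"`/`"blue"`, and the reward scheme are all hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

SECURE, COMPROMISED, DEFENDED = 0, 1, 2


class EnvState(NamedTuple):
    # NamedTuples are JAX pytrees, so the state can pass through jit/vmap.
    status: jax.Array  # per-node status codes
    t: jax.Array       # current step index


class NetForgeSketch:
    """Hypothetical multi-agent wrapper with a JaxMARL-style interface."""

    agents = ("red", "blue")

    def __init__(self, num_nodes: int = 8, max_steps: int = 50):
        self.num_nodes = num_nodes
        self.max_steps = max_steps

    def get_obs(self, state: EnvState) -> dict:
        # Fully observable toy case: both agents see every node's status.
        obs = state.status.astype(jnp.float32)
        return {agent: obs for agent in self.agents}

    def reset(self, key):
        status = jnp.zeros(self.num_nodes, dtype=jnp.int32)
        entry = jax.random.randint(key, (), 0, self.num_nodes)
        state = EnvState(status.at[entry].set(COMPROMISED), jnp.int32(0))
        return self.get_obs(state), state

    def step(self, key, state: EnvState, actions: dict):
        # Red attacks a node (50% success unless defended); blue defends one.
        target = actions["red"]
        success = jax.random.bernoulli(key, 0.5) & (state.status[target] != DEFENDED)
        status = jnp.where(success, state.status.at[target].set(COMPROMISED), state.status)
        status = status.at[actions["blue"]].set(DEFENDED)
        state = EnvState(status, state.t + 1)

        # Zero-sum toy reward: red gains exactly what blue loses.
        n_compromised = jnp.sum(status == COMPROMISED).astype(jnp.float32)
        rewards = {"red": n_compromised, "blue": -n_compromised}
        done = state.t >= self.max_steps
        dones = {agent: done for agent in self.agents}
        dones["__all__"] = done
        return self.get_obs(state), state, rewards, dones, {}


# Usage: one jitted step of a single environment.
env = NetForgeSketch()
obs, state = env.reset(jax.random.PRNGKey(0))
actions = {"red": jnp.int32(0), "blue": jnp.int32(1)}
obs, state, rewards, dones, infos = jax.jit(env.step)(jax.random.PRNGKey(1), state, actions)
```

Since actions, observations, and rewards are plain dicts of arrays (pytrees), `jax.vmap(env.step)` would batch this interface the same way as the single-array version.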
  3. Rendering and Visualization Pipeline
  • Proposal: Keep the core JAX environment completely separated from rendering to maintain blazing-fast SPS (Steps Per Second) during training.
  • Implement an independent render() method that pulls the state from the GPU to the CPU (jax.device_get()) and uses NetworkX combined with Matplotlib or Pygame to draw the network topology.
  • Nodes could change color according to their current state (e.g., green = secure, red = compromised, blue = defended). Frames can be buffered and saved via moviepy to automatically generate .mp4 or .gif files during evaluation episodes.
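The rendering path described above could be sketched as follows, assuming the same hypothetical status encoding (0 = secure, 1 = compromised, 2 = defended) and an adjacency matrix for the topology; `node_colors` and `render_frame` are illustrative names, not existing NetForge_RL functions. `jax.device_get` copies the state off the accelerator once, and NetworkX/Matplotlib are imported lazily inside `render_frame` so the training path never loads them.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical status codes -> colors from the proposal.
STATUS_COLORS = {0: "green", 1: "red", 2: "blue"}  # secure, compromised, defended


def node_colors(status) -> list:
    """Map per-node status codes to colors after one device-to-host copy."""
    host_status = np.asarray(jax.device_get(status))
    return [STATUS_COLORS[int(code)] for code in host_status]


def render_frame(adjacency, status, ax=None):
    """Draw the topology with NetworkX/Matplotlib, colored by node status."""
    # Lazy imports: training code that never renders never loads these.
    import matplotlib.pyplot as plt
    import networkx as nx

    graph = nx.from_numpy_array(np.asarray(jax.device_get(adjacency)))
    if ax is None:
        _, ax = plt.subplots()
    pos = nx.spring_layout(graph, seed=0)  # fixed seed: stable layout across frames
    nx.draw_networkx(graph, pos=pos, node_color=node_colors(status), ax=ax)
    ax.set_axis_off()
    return ax


# Example: a 4-node ring where node 1 is compromised and node 2 is defended.
ring = jnp.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
status = jnp.array([0, 1, 2, 0], dtype=jnp.int32)
colors = node_colors(status)  # → ["green", "red", "blue", "green"]
```

In an evaluation loop, each drawn figure could be rasterized (e.g. via Matplotlib's `fig.canvas.buffer_rgba()`) and the resulting frames handed to moviepy; that glue is left out of this sketch.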

Labels: enhancement (New feature or request)
