Here is a proposed roadmap and feature set to scale NetForge_RL:
- The JAX Rewrite (Massive Parallelization)
- Rewriting the core environment logic in JAX (as pure functions with explicit jax.random.PRNGKey handling) makes it possible to vectorize environments with jax.vmap.
- The Goal: Run 4,000+ parallel instances of NetForge_RL directly on a single GPU/TPU, completely bypassing Python's GIL and multiprocessing overhead.
- API Standardization (JaxMARL Compatibility)
- Aligning the API with JaxMARL.
- Implementing a pure functional step(key, state, actions) method that returns the next state along with per-agent dictionaries of observations, rewards, and dones.
- Rendering and Visualization Pipeline
- Proposal: Keep the core JAX environment completely separate from rendering, so that drawing never reduces training throughput (SPS, steps per second).
- Implement an independent render() method that pulls the state from the GPU to the CPU (jax.device_get()) and uses NetworkX combined with Matplotlib or Pygame to draw the network topology.
- Visual logic could involve nodes changing colors based on their current state (e.g., green = secure, red = compromised, blue = defended). Frames can be buffered and saved via moviepy to automatically generate .mp4 or .gif files during evaluation episodes.
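
To make the proposal concrete, here is a minimal sketch of what a pure functional environment, batched with jax.vmap, might look like. Everything specific in it is an assumption for illustration only: the `EnvState` container, the two agent names, the toy attack/defend dynamics, the `__all__` entry in `dones`, and the `node_colors` helper are hypothetical and not taken from the existing NetForge_RL code. The actual drawing with NetworkX/Matplotlib is left out; the helper only shows the GPU-to-CPU handoff via jax.device_get().

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# All constants and names below are hypothetical placeholders.
NUM_NODES = 8
MAX_STEPS = 50
AGENTS = ("attacker", "defender")
SECURE, COMPROMISED, DEFENDED = 0, 1, 2
COLOR_BY_STATUS = ("green", "red", "blue")  # secure, compromised, defended


class EnvState(NamedTuple):
    node_status: jax.Array  # shape (NUM_NODES,), int32 status per node
    t: jax.Array            # scalar int32 step counter


def _observe(state):
    # Toy observation: every agent sees the full node status vector.
    return {agent: state.node_status for agent in AGENTS}


def reset(key):
    # Start with one randomly chosen compromised node.
    first = jax.random.randint(key, (), 0, NUM_NODES)
    status = jnp.zeros(NUM_NODES, jnp.int32).at[first].set(COMPROMISED)
    state = EnvState(status, jnp.zeros((), jnp.int32))
    return _observe(state), state


def step(key, state, actions):
    # Pure function: no mutation, all randomness comes from `key`.
    status = state.node_status
    target = actions["attacker"]
    # An attack succeeds with probability 0.5, and only on a secure node.
    hit = jax.random.bernoulli(key, 0.5) & (status[target] == SECURE)
    status = status.at[target].set(jnp.where(hit, COMPROMISED, status[target]))
    # The defender's action is applied last and always succeeds.
    status = status.at[actions["defender"]].set(DEFENDED)

    new_state = EnvState(status, state.t + 1)
    n_compromised = jnp.sum(status == COMPROMISED).astype(jnp.float32)
    rewards = {"attacker": n_compromised, "defender": -n_compromised}
    done = new_state.t >= MAX_STEPS
    dones = {agent: done for agent in AGENTS}
    dones["__all__"] = done
    return _observe(new_state), new_state, rewards, dones


def node_colors(batched_state, env_index=0):
    # Copy the batched state to host memory, then pick one environment.
    status = jax.device_get(batched_state.node_status)[env_index]
    return [COLOR_BY_STATUS[int(s)] for s in status]


# Vectorize over a leading batch axis and compile once.
NUM_ENVS = 4096
batched_reset = jax.jit(jax.vmap(reset))
batched_step = jax.jit(jax.vmap(step))

key = jax.random.PRNGKey(0)
key, reset_key, step_key = jax.random.split(key, 3)
obs, state = batched_reset(jax.random.split(reset_key, NUM_ENVS))

actions = {
    "attacker": jnp.zeros(NUM_ENVS, jnp.int32),  # every attacker targets node 0
    "defender": jnp.ones(NUM_ENVS, jnp.int32),   # every defender protects node 1
}
obs, state, rewards, dones = batched_step(
    jax.random.split(step_key, NUM_ENVS), state, actions
)
colors = node_colors(state, env_index=0)
```

Because `reset` and `step` take and return only arrays and pytrees of arrays, the same functions serve single-environment debugging and batched training. A renderer would consume something like `colors` on the CPU, outside the jitted code path.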