Skip to content

fix: PyTorch 2.x compatibility for GPU inference#44

Open
jimmyag2026-prog wants to merge 1 commit into
opendilab:mainfrom
jimmyag2026-prog:fix/pytorch-2x-compatibility
Open

fix: PyTorch 2.x compatibility for GPU inference#44
jimmyag2026-prog wants to merge 1 commit into
opendilab:mainfrom
jimmyag2026-prog:fix/pytorch-2x-compatibility

Conversation

@jimmyag2026-prog

Copy link
Copy Markdown

Summary

This PR fixes 18 device mismatch issues that prevent DI-star from running with PyTorch ≥2.0 on GPU. DI-star was originally developed for PyTorch 1.7.1, and PyTorch 2.x changed the behavior of:

  1. torch._six module removed (PyTorch 1.11+) — imports replaced with stdlib equivalents
  2. Device mismatch in GPU tensor indexingtorch.arange(bs) defaults to CPU but is used to index GPU tensors
  3. Mixed device in post-processing — model output (GPU) indexed into CPU observation tensors

Changes

distar/ctools/ — torch._six compatibility (3 files)

File Change
torch_utils/optimizer_util.py torch._six.infmath.inf
torch_utils/grad_clip.py torch._six.infmath.inf
data/collate_fn.py torch._six.string_classes(str, bytes)

distar/agent/default/model/policy.py — GPU tensor indexing

Line 35: SELECTED_UNITS_MASK[action[...]]SELECTED_UNITS_MASK.to(action[...].device)[action[...]]

distar/agent/default/model/head/action_arg_head.py — All 16 torch.arange(bs) calls

Replaced torch.arange(bs)torch.arange(bs, device=ae.device) (with flag.device in _get_key_mask scope where ae is not available).

distar/agent/default/agent.py — Post-processing device mismatch

Lines 404, 407: Added .cpu() when indexing CPU observations with GPU model output tensors.

Testing

  • Model: rl_model (Zerg) vs Elite bot (Zerg) — highest difficulty
  • Maps: KingsCove, KairosJunction, NewRepugnancy (random)
  • Result: All 3 maps — WIN (Outcome: [1])
  • Performance: GPU inference ~3.9× faster than CPU fallback
  • No algorithmic or behavioral changes — only device management fixes

Compatibility

Tested with:

  • PyTorch 2.4.1+cu121
  • CUDA 12.1 driver (backward compatible with 13.1)
  • Python 3.8
  • Ubuntu 22.04
  • NVIDIA RTX 4090

The original PyTorch 1.7.1 behavior is preserved — this PR only ensures the same code works correctly under newer PyTorch versions.

Fix 18 device mismatch issues that crash DI-star under PyTorch 2.x:

- torch._six module removed in PyTorch 1.11: replace with stdlib equivalents
  (optimizer_util.py, grad_clip.py, collate_fn.py)
- SELECTED_UNITS_MASK on CPU while action tensors on GPU (policy.py:35)
- torch.arange(bs) creates CPU tensors used to index GPU tensors
  (action_arg_head.py: all 16 occurrences)
- Selected unit indice on GPU while observation on CPU in post_process
  (agent.py:404, 407)

Tested: rl_model (Zerg) vs Elite bot (Zerg) on 3 random maps, all wins.
GPU inference ~3.9x faster than CPU mode with identical results.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant