Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ logs/

# Dependency directories
venv/
.venv/
env/
env.bak/
env.tmp/
Expand All @@ -46,4 +47,6 @@ env.production.local/
Thumbs.db

# QASM files
*.qasm
*.qasm

results/*
13 changes: 9 additions & 4 deletions env_creator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from gymenv_qsimpy import QSimPyEnv
from env_wrapper import ScaleQSimPyEnv
from gymnasium.experimental.wrappers import RescaleObservationV0, DtypeObservationV0
from env_wrapper import ScaleQSimPyEnv , SerializableEnvWrapper
from gymnasium.wrappers import RescaleObservation, DtypeObservation
import numpy as np


Expand All @@ -10,20 +10,25 @@ def qsimpy_env_creator(env_config):
config = config if config is not None else {}
if dataset is None:
raise ValueError("Dataset is not specified")

env = QSimPyEnv(dataset=dataset, config=config)
env = SerializableEnvWrapper(env)

obs_filter = env_config.pop("obs_filter", None)
reward_filter = env_config.pop("reward_filter", None)

if obs_filter is not None:
if obs_filter == "rescale_-1_1":
env = RescaleObservationV0(
env = RescaleObservation(
env=env,
min_obs=np.ones((env.obs_dim,), dtype=np.float32) * -1,
max_obs=np.ones((env.obs_dim,), dtype=np.float32) * 1,
)
env = DtypeObservationV0(env, dtype=np.float32)
env = DtypeObservation(env, dtype=np.float32)

if reward_filter is not None:
if reward_filter == "scale_2x":
env = ScaleQSimPyEnv(env, scale=env_config.pop("reward_scale", 2))

# for i in range(10) : print(type(env))
return env
50 changes: 41 additions & 9 deletions env_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import gymnasium as gym
from gymnasium.core import Env
from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward
# Make sure these are here if not already
from numpy.random import default_rng
import simpy
import numpy as np
from gymnasium.spaces import Box


class ScaleQSimPyEnv(gym.RewardWrapper):
Expand All @@ -13,12 +14,43 @@ def __init__(self, env: Env, scale: float):
def reward(self, reward):
reward *= self.scaling_factor
return reward
class SerializableEnvWrapper(gym.Wrapper):
    """Wrapper that customises how the wrapped environment is pickled.

    The wrapped env object itself is never pickled.  ``__getstate__`` stores
    the env's state under ``"env_state"`` and its ``dataset_path`` under
    ``"_dataset_path"``; ``__setstate__`` builds a fresh ``QSimPyEnv`` from
    that path and restores the saved state onto it.
    """

    def __getattr__(self, name):
        """Forward attributes not found on the wrapper to the wrapped env.

        Raises:
            AttributeError: if ``env`` itself is not set yet, or the wrapped
                env does not have ``name``.
        """
        # __getattr__ only runs after normal lookup failed.  If ``env`` is the
        # missing name (e.g. a bare instance created by pickle/copy before
        # __setstate__ ran), forwarding to ``self.env`` would re-enter this
        # method until RecursionError; fail with AttributeError instead so
        # hasattr()/getattr(default) keep working.
        if name == "env":
            raise AttributeError(name)
        return getattr(self.env, name)

    def __getstate__(self):
        """Return a picklable dict describing the wrapper and its env.

        Returns:
            dict: a copy of the wrapper ``__dict__`` without ``"env"``, plus
            ``"env_state"`` and ``"_dataset_path"``.
        """
        # Start with wrapper __dict__
        state = self.__dict__.copy()

        # Replace self.env with its safe state
        if hasattr(self.env, "__getstate__"):
            state["env_state"] = self.env.__getstate__()
        else:
            state["env_state"] = self.env.__dict__.copy()

        # Don't pickle the actual env object directly
        if "env" in state:
            del state["env"]

        # Debug: check for generators in wrapper state
        # NOTE(review): this is a heuristic -- it drops every iterable that is
        # not one of the listed types (so sets, deques, ... are dropped too).
        for k, v in list(state.items()):
            if hasattr(v, "__iter__") and not isinstance(v, (list, tuple, dict, str, bytes, np.ndarray)):
                print(f"[WRAPPER-PICKLE] Removing generator-like object at key '{k}' ({type(v)})")
                del state[k]

        # Preserve dataset path for reconstruction
        state["_dataset_path"] = getattr(self.env, "dataset_path", None)
        return state

    def __setstate__(self, state):
        """Rebuild the wrapped env and restore the wrapper attributes.

        Args:
            state: dict produced by ``__getstate__``.

        Raises:
            ValueError: if ``state`` carries no (truthy) ``"_dataset_path"``.
        """
        # Imported here rather than at module level -- presumably to avoid an
        # import cycle with gymenv_qsimpy; TODO confirm.
        from gymenv_qsimpy import QSimPyEnv

        dataset_path = state.pop("_dataset_path", None)
        if not dataset_path:
            raise ValueError("Missing dataset path for deserialization")

        # NOTE(review): the env is rebuilt from the dataset path only; any
        # ``config`` passed at creation time is not part of the saved state.
        new_env = QSimPyEnv(dataset=dataset_path)

        env_state = state.pop("env_state", None)
        if env_state is not None:
            # ``object`` supplies no default __setstate__, so mirror the
            # fallback used in __getstate__: use the env's own hook when it
            # has one, otherwise restore the attribute dict directly.
            restore = getattr(new_env, "__setstate__", None)
            if restore is not None:
                restore(env_state)
            else:
                new_env.__dict__.update(env_state)

        super().__init__(new_env)
        self.__dict__.update(state)


class GymNormalizeObservation(NormalizeObservation):
    """``NormalizeObservation`` whose observation space is an unbounded Box.

    After the parent initialiser runs, ``observation_space`` is replaced by a
    ``Box`` of length ``env.obs_dim`` spanning ``-inf`` to ``+inf``.
    """

    def __init__(self, env: Env, *args, **kwargs):
        super().__init__(env, *args, **kwargs)
        self.observation_space = Box(
            low=np.ones((self.env.obs_dim,)) * -np.inf,
            high=np.ones((self.env.obs_dim,)) * np.inf,
        )
2 changes: 1 addition & 1 deletion evaluator_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def enhanced_greedy_policy(env, skipped_list):
current_obs = env.current_obs
current_obs = env.unwrapped.current_obs
qnode_start_index = 4 # Adjust based on the actual qtask observation length
qnode_obs_length = 3 # Number of values per qnode in the observation

Expand Down
Loading