Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions deepdrivewe/examples/openmm_ntl9_ddwe/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from deepdrivewe.recyclers import LowRecycler
from deepdrivewe.resamplers import LOFLowResampler

from proxystore.store import get_store

class InferenceConfig(BaseModel):
"""Arguments for the inference module."""
Expand Down Expand Up @@ -62,28 +63,40 @@ class InferenceConfig(BaseModel):


def run_inference(
sim_output: list[SimResult],
sim_output: list,
train_output: TrainResult,
basis_states: BasisStates,
Comment thread
NikJur marked this conversation as resolved.
target_states: list[TargetState],
config: InferenceConfig,
output_dir: Path,
) -> tuple[list[SimMetadata], list[SimMetadata], IterationMetadata]:
"""Run inference on the input data."""

# Initialise the store and resolve the simulation results from their keys
store = get_store('file-store')
if store is None:
raise RuntimeError("ProxyStore 'file-store' not found on worker.")

resolved_sims = [store.get(key) for key in sim_output]

# Resolve train_output if it was passed as a key
if not hasattr(train_output, 'checkpoint_path'):
train_output = store.get(train_output)

Comment thread
NikJur marked this conversation as resolved.
# Make the output directory
itetation = sim_output[0].metadata.iteration_id
itetation = resolved_sims[0].metadata.iteration_id
output_dir = output_dir / f'{itetation:06d}'
output_dir.mkdir(parents=True, exist_ok=True)

# Extract the rmsd pcoord from the last frame of each simulation
pcoords = [sim.metadata.pcoord[-1][0] for sim in sim_output]
pcoords = [sim.metadata.pcoord[-1][0] for sim in resolved_sims]

print(f'Progress coordinates: {pcoords}')
print(f'Best progress coordinate: {min(pcoords)}')
print(f'Num input simulations: {len(sim_output)}')
print(f'Num input simulations: {len(resolved_sims)}')

# Extract the simulation metadata
cur_sims = [sim.metadata for sim in sim_output]
cur_sims = [sim.metadata for sim in resolved_sims]

# Load the model and history
model, history = warmstart_model(
Expand All @@ -92,8 +105,8 @@ def run_inference(
)

# Extract the last frame contact maps and rmsd from each simulation
contact_maps = [sim.data['contact_maps'][-1] for sim in sim_output]
pcoords = [sim.data['pcoords'][-1] for sim in sim_output]
contact_maps = [sim.data['contact_maps'][-1] for sim in resolved_sims]
pcoords = [sim.data['pcoords'][-1] for sim in resolved_sims]

# Convert to int16
contact_maps = [x.astype(np.int16) for x in contact_maps]
Expand Down Expand Up @@ -156,3 +169,4 @@ def run_inference(
result = resampler.run(cur_sims, binner, recycler)

return result

32 changes: 23 additions & 9 deletions deepdrivewe/examples/openmm_ntl9_ddwe/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from deepdrivewe.ai import ConvolutionalVAE
from deepdrivewe.ai import ConvolutionalVAEConfig

from proxystore.store import get_store

class TrainConfig(BaseModel):
"""Arguments for the training module."""
Expand All @@ -33,16 +34,36 @@ class TrainConfig(BaseModel):


def run_train(
sim_output: list[SimResult],
sim_output: list, # List of raw ProxyStore Keys
config: TrainConfig,
Comment thread
NikJur marked this conversation as resolved.
output_dir: Path,
) -> TrainResult:
"""Train the model on the simulation output."""
# Make the output directory
itetation = sim_output[0].metadata.iteration_id
Comment thread
NikJur marked this conversation as resolved.
output_dir = output_dir / f'{itetation:06d}'

# Manually resolve the keys using the registered 'file-store'
store = get_store('file-store')
if store is None:
raise RuntimeError("ProxyStore 'file-store' is not initialized on the worker.")

# store.get(key) retrieves the object without the destructive 'evict' behavior
resolved_sims = [store.get(key) for key in sim_output]
print(f"DEBUG: Successfully resolved {len(resolved_sims)} simulation objects", flush=True)

# Make the output directory using the first resolved object
iteration = resolved_sims[0].metadata.iteration_id
output_dir = output_dir / f'{iteration:06d}'
output_dir.mkdir(parents=True, exist_ok=True)

# Extract contact maps and pcoords from the resolved objects
contact_maps = np.concatenate(
[sim.data['contact_maps'] for sim in resolved_sims],
axis=0 # join along the frame/sample axis
)
Comment thread
NikJur marked this conversation as resolved.
pcoords = np.concatenate([sim.data['pcoords'] for sim in resolved_sims])
pcoords = pcoords.flatten()

# Load the model configuration
model_config = ConvolutionalVAEConfig.from_yaml(config.config_path)

Expand All @@ -52,13 +73,6 @@ def run_train(
checkpoint_path=config.checkpoint_path,
)

# Extract the last frame contact maps and rmsd from each simulation
contact_maps = np.concatenate(
[sim.data['contact_maps'] for sim in sim_output],
)
pcoords = np.concatenate([sim.data['pcoords'] for sim in sim_output])
pcoords = pcoords.flatten()

# Fit the model
checkpoint_path = model.fit(
x=contact_maps,
Expand Down
3 changes: 3 additions & 0 deletions deepdrivewe/parsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from pydantic import Field
from pydantic import model_validator

from parsl.addresses import address_by_hostname

Comment thread
NikJur marked this conversation as resolved.

class BaseComputeConfig(BaseModel, ABC):
"""Compute config (HPC platform, number of GPUs, etc)."""
Expand Down Expand Up @@ -250,6 +252,7 @@ def _get_htex(self, label: str, num_nodes: int) -> HighThroughputExecutor:
return HighThroughputExecutor(
label=label,
available_accelerators=1, # 1 GH per node
address=address_by_hostname(), # dynamically set address from default 'localhost' to prevent IPv4 validation errors and ensure scaling
Comment thread
NikJur marked this conversation as resolved.
cores_per_worker=72,
cpu_affinity='alternating',
prefetch_capacity=0,
Expand Down
51 changes: 45 additions & 6 deletions deepdrivewe/workflows/ddwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def process_simulation_result(self, result: Result) -> None:
if not result.success:
self.logger.error(
f'Simulation failed after {result.retries}'
f'/{result.max_retries} attempts, quitting workflow.',
f' result={result}',
f'/{result.max_retries} attempts, quitting workflow.' # Hanging commata removed -> previously caused issues if error reported | syntax
f' result={result}'
)
Comment thread
NikJur marked this conversation as resolved.
self.done.set()
return
Expand All @@ -127,8 +127,30 @@ def process_simulation_result(self, result: Result) -> None:
# extract the proxied objects. The non-streaming case will
# need to extract and re-proxy the objects twice (once for
# the train task and once for the inference task).
output = result.value if self.streaming else extract(result.value)
self.sim_output.append(output)

# This method extracts results from the ProxyStore backend and re-registers
# them to prevent automated cache eviction, ensuring data remains available
# for both training and inference agents.

from proxystore.store import get_store
from proxystore.proxy import extract
Comment thread
NikJur marked this conversation as resolved.

# Initialise store and safely handle proxied data
store = get_store('file-store')
raw_data = result.value

if not self.streaming:
# Check if the result is a proxy before extraction to handle hybrid workflows
if hasattr(raw_data, '__proxy_wrapped__'):
raw_data = extract(raw_data)

# Re-register data as a persistent key to bypass default 'evict-on-read' behaviour
key = store.put(raw_data)

else:
key = raw_data

Comment thread
NikJur marked this conversation as resolved.
self.sim_output.append(key)

# If we have all the simulation results, submit a train task
if len(self.sim_output) == len(self.ensemble.next_sims):
Expand Down Expand Up @@ -160,8 +182,23 @@ def process_train_result(self, result: Result) -> None:
self.done.set()
return

# Store the training output
self.train_output = result.value
# This method ensures the trained model checkpoint persists by manually
Comment thread
NikJur marked this conversation as resolved.
# registering a non-evicting key in ProxyStore

from proxystore.store import get_store
from proxystore.proxy import extract

# Initialize store to handle model weight persistence
store = get_store('file-store')
raw_train_data = result.value

# SAFE EXTRACTION: Ensure we have a concrete object before re-registration.
# This prevents the inference task from encountering an evicted proxy.
if hasattr(raw_train_data, '__proxy_wrapped__'):
raw_train_data = extract(raw_train_data)

# Storing hard-copy key for inference
self.train_output = store.put(raw_train_data)

# TODO: What should we do in the streaming case?
# Does the process_train_result method even run?
Expand All @@ -182,6 +219,8 @@ def process_inference_result(self, result: Result) -> None:
self.done.set()
return

# Safe extraction could be added here for consistency, but it is not strictly needed: this is the final step in the proxy life cycle.
Comment thread
NikJur marked this conversation as resolved.

# Unpack the output
cur_sims, next_sims, metadata = result.value

Expand Down
Loading