
[Tune] Ray Tune memory leak: RSS grows linearly with trial count in tune.run() #64231

Description

@zhangsikai123

What happened + What you expected to happen

Ray Tune memory leak: RSS grows linearly with trial count in tune.run()

Environment

  • Ray version: 2.55.1
  • Python: 3.10.13
  • OS: Debian GNU/Linux 12 (bookworm)

Description

When running a large number of trials via tune.run() from a single driver process on one machine (Ray is started locally by tune.run(); resources_per_trial={"cpu": 1}), the driver process RSS grows linearly with the number of completed trials, eventually leading to OOM.

Reproduction

"""
python test_ray_tune.py [--trials 4000] [--report-interval 100]
"""
import argparse, os, psutil, ray
from ray import tune
from ray.tune import Callback
from ray.tune.search.basic_variant import BasicVariantGenerator

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--trials", type=int, default=4000)
    p.add_argument("--report-interval", type=int, default=100)
    return p.parse_args()

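class RssReporter(Callback):
    """Tune callback that prints the driver's RSS every `interval` completed trials."""

    def __init__(self, interval, process=None):
        self.interval = interval
        self.process = process  # psutil.Process of the driver; None reports 0
        self.count = 0

    def on_trial_complete(self, iteration, trials, trial, **info):
        self.count += 1
        if self.count % self.interval:
            return  # only report on every `interval`-th completed trial
        rss = self.process.memory_info().rss if self.process else 0
        print(f"trial#{self.count}: RSS={rss / 1e9:.3f} GB")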

def train_fn(config):
    # Deliberately trivial trainable: one cheap computation, one scalar report.
    x, y, z = config["x"], config["y"], config["z"]
    tune.report({"score": x**2 + y**2 + (ord(z) - 97)})

def main():
    args = parse_args()
    process = psutil.Process(os.getpid())

    tune.run(
        train_fn,
        config={"x": tune.uniform(-10, 10), "y": tune.uniform(-5, 5), "z": tune.choice(["a","b","c"])},
        metric="score", mode="min",
        num_samples=args.trials,
        search_alg=BasicVariantGenerator(),
        resources_per_trial={"cpu": 1}, verbose=1,
        callbacks=[RssReporter(args.report_interval, process)],
    )
    print(f"Final RSS: {process.memory_info().rss / 1e9:.3f} GB")
    ray.shutdown()

if __name__ == "__main__":
    main()
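To see which driver-side objects accumulate, one option is to diff tracemalloc snapshots every few hundred trials, for example from inside RssReporter.on_trial_complete next to the RSS print. The sketch below is a hypothetical helper (AllocationGrowthTracker is not a Ray API) built only on the standard-library tracemalloc module. Note the assumption it rests on: tracemalloc only sees memory allocated through Python's allocator, so growth in Ray's native (C++) components would not appear. A flat tracemalloc diff alongside rising RSS would itself point away from Python-level objects.

```python
import tracemalloc


class AllocationGrowthTracker:
    """Hypothetical diagnostic helper (not part of Ray).

    Takes tracemalloc snapshots and reports which source lines gained the
    most Python-heap memory since the previous checkpoint.
    """

    def __init__(self, top_n=5, frames=1):
        self.top_n = top_n
        # Start tracing only if it is not already on (e.g. via PYTHONTRACEMALLOC).
        if not tracemalloc.is_tracing():
            tracemalloc.start(frames)
        self._previous = tracemalloc.take_snapshot()

    def checkpoint(self):
        """Return up to top_n (location, bytes_gained) pairs since the last call."""
        current = tracemalloc.take_snapshot()
        # Group allocations by file:line and diff against the previous snapshot.
        diffs = current.compare_to(self._previous, "lineno")
        self._previous = current
        grown = sorted(
            (d for d in diffs if d.size_diff > 0),
            key=lambda d: d.size_diff,
            reverse=True,
        )
        return [(str(d.traceback), d.size_diff) for d in grown[: self.top_n]]
```

A tracker created in RssReporter.__init__ could print tracker.checkpoint() whenever the RSS line is printed. Tracing adds CPU and memory overhead, so it is best enabled only for a diagnostic run.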

Observed behavior

$ python test_ray_tune.py --trials 4000 --report-interval 100 | grep RSS
   trial     RSS_GB
trial#100: RSS=0.672 GB
trial#200: RSS=0.680 GB
trial#300: RSS=0.686 GB
trial#400: RSS=0.692 GB
trial#500: RSS=0.702 GB
trial#600: RSS=0.704 GB
....
trial#1000: RSS=1.084 GB

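RSS grows monotonically with trial count, from 0.672 GB at trial 100 to 1.084 GB at trial 1000 (about 0.41 GB over 900 trials). The increments are not uniform (roughly 6 MB per 100 trials up to trial 600, considerably faster afterwards), but RSS never plateaus, which suggests per-trial state is retained by the driver and never released.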
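To put a number on the growth instead of eyeballing it, the printed lines can be parsed and fitted with an ordinary least-squares line; the slope is an estimate of the bytes retained per completed trial. A minimal sketch, assuming the "trial#N: RSS=X GB" format printed by RssReporter (the helper names parse_rss_lines and per_trial_growth are hypothetical):

```python
import re

# Matches the lines printed by RssReporter, e.g. "trial#100: RSS=0.672 GB".
_RSS_LINE = re.compile(r"trial#(\d+): RSS=([0-9.]+) GB")


def parse_rss_lines(text):
    """Return (completed_trials, rss_bytes) pairs found in the script's output."""
    return [
        (int(trials), round(float(gb) * 1e9))  # the script prints GB as 1e9 bytes
        for trials, gb in _RSS_LINE.findall(text)
    ]


def per_trial_growth(samples):
    """Ordinary least-squares slope of RSS versus completed trials, in bytes/trial."""
    n = len(samples)
    if n < 2:
        raise ValueError("need at least two (trials, rss) samples")
    mean_t = sum(t for t, _ in samples) / n
    mean_r = sum(r for _, r in samples) / n
    cov = sum((t - mean_t) * (r - mean_r) for t, r in samples)
    var = sum((t - mean_t) ** 2 for t, _ in samples)
    if var == 0:
        raise ValueError("samples must cover at least two distinct trial counts")
    return cov / var
```

With only seven samples and a large jump between trials 600 and 1000, such a fit is dominated by the last point; feeding it every sample from a complete run would give a more trustworthy estimate.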

Expected behavior

For lightweight trials that each report a single scalar, the driver process RSS should remain relatively stable after initial ramp-up, not grow linearly to the point of OOM.
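One way to make "relatively stable after initial ramp-up" checkable in a regression test is sketched below. The criterion is an assumption, not something Ray specifies: the first half of the samples is treated as ramp-up, and afterwards RSS may rise by at most a caller-chosen tolerance. The function name has_plateaued is hypothetical; samples are (completed_trials, rss_bytes) pairs.

```python
def has_plateaued(samples, tolerance_bytes):
    """Hypothetical acceptance check for the expected behavior.

    The first half of the run (by trial count) is treated as ramp-up;
    in the remaining samples RSS may exceed its first value by at most
    tolerance_bytes.
    """
    ordered = sorted(samples)
    steady = ordered[len(ordered) // 2:]
    if len(steady) < 2:
        raise ValueError("need at least three samples")
    baseline = steady[0][1]
    return max(rss for _, rss in steady) - baseline <= tolerance_bytes
```

Under this criterion, a leak of a fixed size per trial fails for any tolerance smaller than the growth accumulated over the second half of the run.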

Impact

For workloads with thousands of trials (hyperparameter sweeps, NAS, etc.), the driver process can OOM before all trials complete. This blocks production hyperparameter search runs.

Issue Severity

High: It blocks me from completing my task.

Metadata

Assignees: none
Labels: bug, community-backlog, stability, triage, tune
Project status: Todo
Milestone: none
Linked branches or pull requests: none