diff --git a/benchmarks/mos2.py b/benchmarks/mos2.py index 7f9788e..f4eb1ff 100755 --- a/benchmarks/mos2.py +++ b/benchmarks/mos2.py @@ -2,227 +2,156 @@ import itertools import functools -import logging -import time import sys import json import typing as t -import polars +import numpy import pane +import pynvml -from phaser.utils.num import get_backend_module, to_numpy, Sampling -from phaser.plan import ReconsPlan, EnginePlan, EngineHook, BackendName -from phaser.state import ReconsState, IterState, PartialReconsState, Patterns -from phaser.execute import Observer, initialize_reconstruction, prepare_for_engine +from phaser.utils.num import get_backend_devices, get_backend_module, Sampling, set_default_device +from phaser.plan import ReconsPlan, EngineHook, BackendName +from phaser.state import PreparedRecons +from phaser.execute import execute_engine, initialize_reconstruction +N_WARMUP: int = 2 -class BenchmarkObserver(Observer): - def __init__(self, n_warmup: int = 2): - self.n_warmup: int = n_warmup - self.iter_times: t.List[float] = [] - super().__init__() - def update_iteration(self, state: t.Union[ReconsState, PartialReconsState], - i: int, n: int, error: t.Optional[float] = None): - finish_time = time.monotonic() +def sizeof_fmt(num, suffix="B"): + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024.0: + return f"{num:3.1f} {unit}{suffix}" + num /= 1024.0 + return f"{num:.1f} Yi{suffix}" - if self.iter_start_time is not None: - delta = finish_time - self.iter_start_time - time_s = f" [{self._format_mmss(delta)}]" - if i > self.n_warmup: - self.iter_times.append(delta) - else: - time_s = "" - - w = len(str(n)) - - error_s = f" Error: {error:.3e}" if error is not None else "" - logging.info(f"Finished iter {i:{w}}/{n}{time_s}{error_s}") +def print_memory_usage(file=None): + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + print(f"GPU memory usage: {sizeof_fmt(info.used)}/{sizeof_fmt(info.total)}", file=file) - state.iter = IterState(self.engine_i, i + 1, self.start_iter + i + 1) - self.iter_start_time = finish_time - -@functools.lru_cache(1) -def initialize(sim_size: int = 128) -> t.Tuple[ReconsPlan, Patterns, ReconsState]: +#@functools.lru_cache(1) +def initialize(sim_size: int = 128) -> t.Tuple[PreparedRecons, ReconsPlan]: plan = ReconsPlan.from_data({ "name": "mos2_grad", "backend": "jax", 'dtype': 'float32', 'raw_data': { 'type': 'empad', - 'path': '~/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0_x64_y64_4DSTEM.raw', - 'diff_step': 1.0, - 'kv': 120.0 + 'path': '~/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0.json', }, 'post_load': [ {'type': 'poisson', 'scale': 1.0e6}, ], - 'init_probe': {'type': 'focused', 'conv_angle': 25.0, 'defocus': 300.0}, - 'init_object': 'random', - 'init_scan': {'type': 'raster', 'shape': (64, 64), 'step_size': 0.6}, 'post_init': [], 'engines': [], }) - xp = get_backend_module(plan.backend) - (patterns, state) = initialize_reconstruction(plan, xp, Observer()) + recons = initialize_reconstruction(plan) if sim_size != 128: # pad reconstruction - new_sampling = Sampling((sim_size, sim_size), extent=tuple(state.probe.sampling.extent)) + new_sampling = Sampling((sim_size, sim_size), extent=tuple(recons.state.probe.sampling.extent)) print(f"Resampling probe and patterns to shape {new_sampling.shape}...", file=sys.stderr, flush=True) - state.probe.data = state.probe.sampling.resample(state.probe.data, new_sampling) - patterns.patterns = state.probe.sampling.resample_recip(patterns.patterns, new_sampling) - patterns.pattern_mask = state.probe.sampling.resample_recip(patterns.pattern_mask, new_sampling) - state.probe.sampling = new_sampling + recons.state.probe.data = recons.state.probe.sampling.resample(recons.state.probe.data, new_sampling) + recons.patterns.patterns = recons.state.probe.sampling.resample_recip(recons.patterns.patterns, new_sampling) + recons.patterns.pattern_mask = recons.state.probe.sampling.resample_recip(recons.patterns.pattern_mask, new_sampling) + recons.state.probe.sampling = new_sampling - return (plan, patterns, state.to_numpy()) + return (recons.to_numpy(), plan) -def benchmark_lsqml(grouping: int, sim_size: int, backend: BackendName) -> t.List[float]: - (plan, patterns, init_state) = initialize(sim_size) +def benchmark_grad( + grouping: int, sim_size: int, backend: BackendName, + unroll: t.Union[int, bool] = 10, +) -> t.List[float]: + (recons, plan) = initialize(sim_size) xp = get_backend_module(backend) + recons = recons.to_xp(xp) - engine = pane.convert({ - 'type': 'conventional', - 'probe_modes': 4, - 'niter': 12, - 'grouping': grouping, - 'noise_model': {'type': 'amplitude', 'eps': 1.0e-4}, - 'solver': { - 'type': 'lsqml', - 'gamma': 1.0e-4, - }, - 'iter_constraints': [], - 'group_constraints': [ - {'type': 'clamp_object_amplitude', 'amplitude': 1.1}, - ], - 'update_probe': True, - 'update_object': True, - 'update_positions': False, - }, EngineHook) - - observer = BenchmarkObserver() - - (patterns, state) = prepare_for_engine(patterns, init_state, xp, t.cast(EnginePlan, engine.props)) - - state = engine({ - 'data': patterns, - 'state': state, - 'dtype': patterns.patterns.dtype, - 'xp': xp, - 'recons_name': plan.name, - 'seed': None, - 'engine_i': 0, - 'observer': observer - }) - - iter_times = observer.iter_times - print(f"Mean time: {sum(iter_times) / len(iter_times):.3f} s", file=sys.stderr) - return iter_times - - -def benchmark_grad(grouping: int, sim_size: int) -> t.List[float]: - (plan, patterns, init_state) = initialize(sim_size) - xp = get_backend_module('jax') + devices = get_backend_devices(xp) + print(f"Available devices: {list(devices)}", file=sys.stderr) + print(f"Using device '{devices[0]}'", file=sys.stderr) + set_default_device(devices[0], xp) engine = pane.convert({ 'type': 'gradient', + 'buffer_n_groups': 16 if grouping < 256 else 2, + 'jit_unroll_slices': unroll, 'probe_modes': 4, - 'niter': 12, + 'niter': 15, 'grouping': grouping, 'noise_model': {'type': 'amplitude', 'eps': 1.0e-4}, 'solvers': { 'object': { - 'type': 'sgd', - 'learning_rate': 1.0, - 'momentum': 0.99, + 'type': 'adam', + 'learning_rate': 1.0e-3, + 'nesterov': True, }, 'probe': { - 'type': 'sgd', + 'type': 'adam', 'learning_rate': 1.0e-3, - 'momentum': 0.99, + 'nesterov': True, }, }, 'regularizers': [ - {'type': 'obj_l1', 'cost': 15.0}, ], - 'iter_constraints': [ ], + 'iter_constraints': [ + {'type': 'clamp_object_amplitude', 'amplitude': 1.0}, + ], 'group_constraints': [ - {'type': 'clamp_object_amplitude', 'amplitude': 1.1}, ], 'update_probe': True, 'update_object': True, 'update_positions': False, + 'save': False, 'save_images': False, }, EngineHook) - observer = BenchmarkObserver() - - (patterns, state) = prepare_for_engine(patterns, init_state, xp, t.cast(EnginePlan, engine.props)) + recons = execute_engine(recons, engine) - state = engine({ - 'data': patterns, - 'state': state, - 'dtype': patterns.patterns.dtype, - 'xp': xp, - 'recons_name': plan.name, - 'seed': None, - 'engine_i': 0, - 'observer': observer - }) + iter_times: t.List[float] = numpy.diff(recons.state.progress['time'].values).tolist()[N_WARMUP:] - iter_times = observer.iter_times + print(f"Iter times: {iter_times}", file=sys.stderr) print(f"Mean time: {sum(iter_times) / len(iter_times):.3f} s", file=sys.stderr) return iter_times if __name__ == '__main__': + pynvml.nvmlInit() import jax device_name = jax.devices()[0].device_kind print(f"device: {device_name}", file=sys.stderr) - for sim_size, backend, grouping in itertools.product((128, 192), ('cupy', 'jax'), (16, 32, 64, 128)): - #for sim_size, backend, grouping in itertools.product((128,), ('cupy', 'jax'), (128,)): - try: - iter_times = benchmark_lsqml(grouping, sim_size, backend) - except Exception as e: - print(f"Failed to run, error:\n{e}", file=sys.stderr) + backend = 'jax' + + for sim_size, unroll, grouping in itertools.product((128, 192), (5,), (8, 4, 16, 32, 64, 128, 256, 512, 1024)): + if backend == 'jax': + import jax.version + backend_version = jax.version.__version__ else: - json.dump({ - 'engine': 'lsqml', - 'backend': backend, - 'sim_size': sim_size, - 'n_positions': 4096, - 'n_slices': 1, - 'grouping': grouping, - 'device': device_name, - 'code': 'v3', - 'iter_times': iter_times, - }, sys.stdout) - sys.stdout.write("\n") - sys.stdout.flush() + raise NotImplementedError() - for sim_size, grouping in itertools.product((128, 192), (16, 32, 64, 128)): - #for sim_size, grouping in itertools.product((128,), (128,)): + print(f"\nRunning grad, sim_size={sim_size} backend={backend!r} grouping={grouping}...", file=sys.stderr) + print_memory_usage(file=sys.stderr) try: - iter_times = benchmark_grad(grouping, sim_size) + iter_times = benchmark_grad(grouping, sim_size, backend, unroll=unroll) except Exception as e: print(f"Failed to run, error:\n{e}", file=sys.stderr) else: - json.dump({ 'engine': 'grad', - 'backend': 'jax', + 'backend': backend, + 'backend_version': backend_version, 'sim_size': sim_size, - 'n_positions': 4096, + 'n_positions': 64*64, 'n_slices': 1, + 'n_modes': 4, 'grouping': grouping, 'device': device_name, - 'code': 'v3', + 'code': 'v5_unroll5', 'iter_times': iter_times, }, sys.stdout) sys.stdout.write("\n") diff --git a/benchmarks/mos2_lsqml.py b/benchmarks/mos2_lsqml.py new file mode 100755 index 0000000..d2c4ef6 --- /dev/null +++ b/benchmarks/mos2_lsqml.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 + +import itertools +import functools +import sys +import json +import typing as t + +import numpy +import pane +import pynvml + +from phaser.utils.num import get_backend_devices, get_backend_module, Sampling, set_default_device, to_device +from phaser.plan import ReconsPlan, EngineHook, BackendName +from phaser.state import PreparedRecons +from phaser.execute import execute_engine, initialize_reconstruction + +N_WARMUP: int = 2 + + +def sizeof_fmt(num, suffix="B"): + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024.0: + return f"{num:3.1f} {unit}{suffix}" + num /= 1024.0 + return f"{num:.1f} Yi{suffix}" + + +def print_memory_usage(file=None): + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + print(f"GPU memory usage: {sizeof_fmt(info.used)}/{sizeof_fmt(info.total)}", file=file) + + +#@functools.lru_cache(1) +def initialize(sim_size: int = 128) -> t.Tuple[PreparedRecons, ReconsPlan]: + plan = ReconsPlan.from_data({ + "name": "mos2_grad", + "backend": "jax", + 'dtype': 'float32', + 'raw_data': { + 'type': 'empad', + 'path': '~/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0.json', + }, + 'post_load': [ + {'type': 'poisson', 'scale': 1.0e6}, + ], + 'post_init': [], + 'engines': [], + }) + + recons = initialize_reconstruction(plan) + + if sim_size != 128: + # pad reconstruction + new_sampling = Sampling((sim_size, sim_size), extent=tuple(recons.state.probe.sampling.extent)) + print(f"Resampling probe and patterns to shape {new_sampling.shape}...", file=sys.stderr, flush=True) + recons.state.probe.data = recons.state.probe.sampling.resample(recons.state.probe.data, new_sampling) + recons.patterns.patterns = recons.state.probe.sampling.resample_recip(recons.patterns.patterns, new_sampling) + recons.patterns.pattern_mask = recons.state.probe.sampling.resample_recip(recons.patterns.pattern_mask, new_sampling) + recons.state.probe.sampling = new_sampling + + return (recons.to_numpy(), plan) + + +def benchmark_grad( + grouping: int, sim_size: int, backend: BackendName, + unroll: t.Union[int, bool] = 10, +) -> t.List[float]: + (recons, plan) = initialize(sim_size) + xp = get_backend_module(backend) + recons = recons.to_xp(xp) + + devices = get_backend_devices(xp) + print(f"Available devices: {list(devices)}", file=sys.stderr) + print(f"Using device '{devices[0]}'", file=sys.stderr) + set_default_device(to_device(devices[0], xp), xp) + + engine = pane.convert({ + 'type': 'gradient', + 'buffer_n_groups': 16 if grouping < 256 else 2, + 'jit_unroll_slices': unroll, + 'probe_modes': 4, + 'niter': 15, + 'grouping': grouping, + 'noise_model': {'type': 'amplitude', 'eps': 1.0e-4}, + 'solvers': { + 'object': { + 'type': 'adam', + 'learning_rate': 1.0e-3, + 'nesterov': True, + }, + 'probe': { + 'type': 'adam', + 'learning_rate': 1.0e-3, + 'nesterov': True, + }, + }, + 'regularizers': [ + ], + 'iter_constraints': [ + {'type': 'clamp_object_amplitude', 'amplitude': 1.0}, + ], + 'group_constraints': [ + ], + 'update_probe': True, + 'update_object': True, + 'update_positions': False, + 'save': False, 'save_images': False, + }, EngineHook) + + recons = execute_engine(recons, engine) + + iter_times: t.List[float] = numpy.diff(recons.state.progress['time'].values).tolist()[N_WARMUP:] + + print(f"Iter times: {iter_times}", file=sys.stderr) + print(f"Mean time: {sum(iter_times) / len(iter_times):.3f} s", file=sys.stderr) + return iter_times + + +def benchmark_lsqml( + grouping: int, sim_size: int, backend: BackendName, + unroll: t.Union[int, bool] = 10, +) -> t.List[float]: + (recons, plan) = initialize(sim_size) + xp = get_backend_module(backend) + recons = recons.to_xp(xp) + + devices = get_backend_devices(xp) + print(f"Available devices: {list(devices)}", file=sys.stderr) + print(f"Using device '{devices[0]}'", file=sys.stderr) + set_default_device(to_device(devices[0], xp), xp) + + engine = pane.convert({ + 'type': 'conventional', + 'buffer_n_groups': 16 if grouping < 256 else 2, + 'jit_unroll_slices': unroll, + 'probe_modes': 4, + 'niter': 15, + 'grouping': grouping, + 'noise_model': {'type': 'amplitude', 'eps': 1.0e-1}, + 'solver': { + 'type': 'lsqml', + 'beta_probe': 0.1, + 'beta_object': 0.1, + 'gamma': 1.0e-4, + 'illum_reg_object': 1.0e-2, + 'illum_reg_probe': 1.0e-2, + }, + 'position_solver': { + 'type': 'momentum', + 'momentum': 0.90, + 'step_size': 8.0e-2, + 'max_step_size': 0.2, + }, + 'iter_constraints': [ + {'type': 'clamp_object_amplitude', 'amplitude': 1.0}, + ], + 'group_constraints': [ + ], + 'update_probe': True, + 'update_object': True, + 'update_positions': False, + 'save': False, 'save_images': False, + }, EngineHook) + + recons = execute_engine(recons, engine) + + iter_times: t.List[float] = numpy.diff(recons.state.progress['time'].values).tolist()[N_WARMUP:] + + print(f"Iter times: {iter_times}", file=sys.stderr) + print(f"Mean time: {sum(iter_times) / len(iter_times):.3f} s", file=sys.stderr) + return iter_times + + +if __name__ == '__main__': + pynvml.nvmlInit() + import jax + + device_name = jax.devices()[0].device_kind + print(f"device: {device_name}", file=sys.stderr) + + for backend in ('cupy', 'jax'): + for sim_size, unroll, grouping in itertools.product((128, 192), (5,), (8, 4, 16, 32, 64, 128, 256, 512, 1024)): + if backend == 'jax': + import jax.version + backend_version = jax.version.__version__ + else: + import cupy + backend_version = cupy.__version__ + + print(f"\nRunning lsqml, sim_size={sim_size} backend={backend!r} grouping={grouping}...", file=sys.stderr) + print_memory_usage(file=sys.stderr) + try: + iter_times = benchmark_lsqml(grouping, sim_size, backend, unroll=unroll) + except Exception as e: + print(f"Failed to run, error:\n{e}", file=sys.stderr) + else: + json.dump({ + 'engine': 'lsqml', + 'backend': backend, + 'backend_version': backend_version, + 'sim_size': sim_size, + 'n_positions': 64*64, + 'n_slices': 1, + 'n_modes': 4, + 'grouping': grouping, + 'device': device_name, + 'code': 'v5_unroll5', + 'iter_times': iter_times, + }, sys.stdout) + sys.stdout.write("\n") + sys.stdout.flush() diff --git a/benchmarks/si.py b/benchmarks/si.py index 648b022..abe2116 100755 --- a/benchmarks/si.py +++ b/benchmarks/si.py @@ -2,232 +2,159 @@ import itertools import functools -import logging -import time import sys import json import typing as t -import polars +import numpy import pane +import pynvml -from phaser.utils.num import get_backend_module, to_numpy, Sampling -from phaser.plan import ReconsPlan, EnginePlan, EngineHook, BackendName -from phaser.state import ReconsState, IterState, PartialReconsState, Patterns -from phaser.execute import Observer, initialize_reconstruction, prepare_for_engine +from phaser.utils.num import get_backend_devices, get_backend_module, Sampling, set_default_device +from phaser.plan import ReconsPlan, EngineHook, BackendName +from phaser.state import PreparedRecons +from phaser.execute import execute_engine, initialize_reconstruction +N_WARMUP: int = 2 -class BenchmarkObserver(Observer): - def __init__(self, n_warmup: int = 2): - self.n_warmup: int = n_warmup - self.iter_times: t.List[float] = [] - super().__init__() - def update_iteration(self, state: t.Union[ReconsState, PartialReconsState], - i: int, n: int, error: t.Optional[float] = None): - finish_time = time.monotonic() +def sizeof_fmt(num, suffix="B"): + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024.0: + return f"{num:3.1f} {unit}{suffix}" + num /= 1024.0 + return f"{num:.1f} Yi{suffix}" - if self.iter_start_time is not None: - delta = finish_time - self.iter_start_time - time_s = f" [{self._format_mmss(delta)}]" - if i > self.n_warmup: - self.iter_times.append(delta) - else: - time_s = "" - - w = len(str(n)) - - error_s = f" Error: {error:.3e}" if error is not None else "" - logging.info(f"Finished iter {i:{w}}/{n}{time_s}{error_s}") +def print_memory_usage(file=None): + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + print(f"GPU memory usage: {sizeof_fmt(info.used)}/{sizeof_fmt(info.total)}", file=file) - state.iter = IterState(self.engine_i, i + 1, self.start_iter + i + 1) - self.iter_start_time = finish_time - -@functools.lru_cache(1) -def initialize(sim_size: int = 128) -> t.Tuple[ReconsPlan, Patterns, ReconsState]: +#@functools.lru_cache(1) +def initialize(sim_size: int = 128) -> t.Tuple[PreparedRecons, ReconsPlan]: plan = ReconsPlan.from_data({ "name": "si", "backend": "jax", 'dtype': 'float32', 'raw_data': { 'type': 'empad', - 'path': "~/Downloads/si-final4/Si_110_Sn_300kV_conv25_defocus20_tds/Si_110_Sn_300kV_conv25_defocus20_tds_199.70_dstep0.8_x80_y80_4DSTEM.raw", - 'diff_step': 0.8, - 'kv': 300.0 + 'path': '~/Downloads/si-final4/Si_110_Sn_300kV_conv25_defocus20_tds/Si_110_Sn_300kV_conv25_defocus20_tds_199.70_dstep0.8.json', }, 'post_load': [ - {'type': 'poisson', 'scale': 14.5e6}, + {'type': 'poisson', 'scale': 1.0e6}, ], - 'init_probe': {'type': 'focused', 'conv_angle': 25.0, 'defocus': 200.0}, - 'init_object': 'random', - 'init_scan': {'type': 'raster', 'shape': (80, 80), 'step_size': 0.3}, 'post_init': [], 'slices': {'n': 10, 'total_thickness': 200.0}, 'engines': [], }) - xp = get_backend_module(plan.backend) - (patterns, state) = initialize_reconstruction(plan, xp, Observer()) + recons = initialize_reconstruction(plan) if sim_size != 128: # pad reconstruction - new_sampling = Sampling((sim_size, sim_size), extent=tuple(state.probe.sampling.extent)) + new_sampling = Sampling((sim_size, sim_size), extent=tuple(recons.state.probe.sampling.extent)) print(f"Resampling probe and patterns to shape {new_sampling.shape}...", file=sys.stderr, flush=True) - state.probe.data = state.probe.sampling.resample(state.probe.data, new_sampling) - patterns.patterns = state.probe.sampling.resample_recip(patterns.patterns, new_sampling) - patterns.pattern_mask = state.probe.sampling.resample_recip(patterns.pattern_mask, new_sampling) - state.probe.sampling = new_sampling + recons.state.probe.data = recons.state.probe.sampling.resample(recons.state.probe.data, new_sampling) + recons.patterns.patterns = recons.state.probe.sampling.resample_recip(recons.patterns.patterns, new_sampling) + recons.patterns.pattern_mask = recons.state.probe.sampling.resample_recip(recons.patterns.pattern_mask, new_sampling) + recons.state.probe.sampling = new_sampling - return (plan, patterns, state.to_numpy()) + return (recons.to_numpy(), plan) -def benchmark_lsqml(grouping: int, sim_size: int, backend: BackendName) -> t.List[float]: - (plan, patterns, init_state) = initialize(sim_size) +def benchmark_grad( + grouping: int, sim_size: int, backend: BackendName, + unroll: t.Union[int, bool] = 10, +) -> t.List[float]: + (recons, plan) = initialize(sim_size) xp = get_backend_module(backend) + recons = recons.to_xp(xp) - engine = pane.convert({ - 'type': 'conventional', - 'probe_modes': 4, - 'niter': 12, - 'grouping': grouping, - 'noise_model': {'type': 'amplitude', 'eps': 1.0e-4}, - 'solver': { - 'type': 'lsqml', - 'gamma': 1.0e-4, - }, - 'iter_constraints': [ - {'type': 'layers', 'sigma': 100.0, 'weight': 0.8}, - ], - 'group_constraints': [ - {'type': 'clamp_object_amplitude', 'amplitude': 1.1}, - ], - 'update_probe': True, - 'update_object': True, - 'update_positions': False, - }, EngineHook) - - observer = BenchmarkObserver() - - (patterns, state) = prepare_for_engine(patterns, init_state, xp, t.cast(EnginePlan, engine.props)) - - state = engine({ - 'data': patterns, - 'state': state, - 'dtype': patterns.patterns.dtype, - 'xp': xp, - 'recons_name': plan.name, - 'seed': None, - 'engine_i': 0, - 'observer': observer - }) - - return observer.iter_times - - -def benchmark_grad(grouping: int, sim_size: int) -> t.List[float]: - (plan, patterns, init_state) = initialize(sim_size) - xp = get_backend_module('jax') + devices = get_backend_devices(xp) + print(f"Available devices: {list(devices)}", file=sys.stderr) + print(f"Using device '{devices[0]}'", file=sys.stderr) + set_default_device(devices[0], xp) engine = pane.convert({ 'type': 'gradient', + 'buffer_n_groups': 16 if grouping < 256 else 2, + 'jit_unroll_slices': unroll, 'probe_modes': 4, - 'niter': 12, + 'niter': 15, 'grouping': grouping, - 'noise_model': {'type': 'amplitude', 'eps': 1.0e-4}, + 'noise_model': {'type': 'amplitude', 'eps': 1.0e-1}, 'solvers': { 'object': { - 'type': 'sgd', - 'learning_rate': 1.0, - 'momentum': 0.99, + 'type': 'adam', + 'learning_rate': 1.0e-3, + 'nesterov': True, }, 'probe': { - 'type': 'sgd', + 'type': 'adam', 'learning_rate': 1.0e-3, - 'momentum': 0.99, + 'nesterov': True, }, }, 'regularizers': [ - {'type': 'obj_l1', 'cost': 15.0}, ], 'iter_constraints': [ - {'type': 'layers', 'sigma': 100.0, 'weight': 0.8}, + {'type': 'clamp_object_amplitude', 'amplitude': 1.0}, ], 'group_constraints': [ - {'type': 'clamp_object_amplitude', 'amplitude': 1.1}, ], 'update_probe': True, 'update_object': True, 'update_positions': False, + 'save': False, 'save_images': False, }, EngineHook) - observer = BenchmarkObserver() + recons = execute_engine(recons, engine) - (patterns, state) = prepare_for_engine(patterns, init_state, xp, t.cast(EnginePlan, engine.props)) + iter_times: t.List[float] = numpy.diff(recons.state.progress['time'].values).tolist()[N_WARMUP:] - state = engine({ - 'data': patterns, - 'state': state, - 'dtype': patterns.patterns.dtype, - 'xp': xp, - 'recons_name': plan.name, - 'seed': None, - 'engine_i': 0, - 'observer': observer - }) - - return observer.iter_times + print(f"Iter times: {iter_times}", file=sys.stderr) + print(f"Mean time: {sum(iter_times) / len(iter_times):.3f} s", file=sys.stderr) + return iter_times if __name__ == '__main__': + pynvml.nvmlInit() import jax device_name = jax.devices()[0].device_kind print(f"device: {device_name}", file=sys.stderr) - for sim_size, backend, grouping in itertools.product((128, 192), ('cupy', 'jax'), (16, 32, 64, 128)): - try: - iter_times = benchmark_lsqml(grouping, sim_size, backend) - except Exception as e: - print(f"Failed to run, error:\n{e}", file=sys.stderr) + backend = 'jax' + + for sim_size, unroll, grouping in itertools.product((128,), (5, False), (8, 4, 16, 32, 64, 128, 256, 512, 1024)): + if backend == 'jax': + import jax.version + backend_version = jax.version.__version__ else: - json.dump({ - 'engine': 'lsqml', - 'backend': backend, - 'sim_size': sim_size, - 'n_positions': 80*80, - 'n_slices': 10, - 'grouping': grouping, - 'device': device_name, - 'code': 'v1', - 'iter_times': iter_times, - }, sys.stdout) - sys.stdout.write("\n") - sys.stdout.flush() + raise NotImplementedError() - """ - for sim_size, grouping in itertools.product([128, 192], [16, 32, 64, 128]): - print(f"Running {sim_size}x{sim_size}, grouping {grouping}", file=sys.stderr) + print(f"\nRunning grad, sim_size={sim_size} backend={backend!r} grouping={grouping}...", file=sys.stderr) + print_memory_usage(file=sys.stderr) try: - iter_times = benchmark_grad(grouping, sim_size) + iter_times = benchmark_grad(grouping, sim_size, backend, unroll=unroll) except Exception as e: print(f"Failed to run, error:\n{e}", file=sys.stderr) else: json.dump({ 'engine': 'grad', - 'backend': 'jax', + 'backend': backend, + 'backend_version': backend_version, 'sim_size': sim_size, 'n_positions': 80*80, 'n_slices': 10, + 'n_modes': 4, 'grouping': grouping, 'device': device_name, - 'code': 'v2', + 'code': 'v5_unroll5' if unroll else 'v5', 'iter_times': iter_times, }, sys.stdout) sys.stdout.write("\n") sys.stdout.flush() - """ - #df = polars.DataFrame(rows, orient='row') - #df.write_ndjson(sys.stdout) # type: ignore \ No newline at end of file diff --git a/benchmarks/si_lsqml.py b/benchmarks/si_lsqml.py new file mode 100755 index 0000000..2585a45 --- /dev/null +++ b/benchmarks/si_lsqml.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 + +import itertools +import functools +import sys +import json +import typing as t + +import numpy +import pane +import pynvml + +from phaser.utils.num import get_backend_devices, get_backend_module, Sampling, set_default_device, to_device +from phaser.plan import ReconsPlan, EngineHook, BackendName +from phaser.state import PreparedRecons +from phaser.execute import execute_engine, initialize_reconstruction + +N_WARMUP: int = 2 + + +def sizeof_fmt(num, suffix="B"): + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024.0: + return f"{num:3.1f} {unit}{suffix}" + num /= 1024.0 + return f"{num:.1f} Yi{suffix}" + + +def print_memory_usage(file=None): + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + print(f"GPU memory usage: {sizeof_fmt(info.used)}/{sizeof_fmt(info.total)}", file=file) + + +#@functools.lru_cache(1) +def initialize(sim_size: int = 128) -> t.Tuple[PreparedRecons, ReconsPlan]: + plan = ReconsPlan.from_data({ + "name": "si", + "backend": "jax", + 'dtype': 'float32', + 'raw_data': { + 'type': 'empad', + 'path': '~/Downloads/si-final4/Si_110_Sn_300kV_conv25_defocus20_tds/Si_110_Sn_300kV_conv25_defocus20_tds_199.70_dstep0.8.json', + }, + 'post_load': [ + {'type': 'poisson', 'scale': 1.0e6}, + ], + 'post_init': [], + 'slices': {'n': 10, 'total_thickness': 200.0}, + 'engines': [], + }) + + recons = initialize_reconstruction(plan) + + if sim_size != 128: + # pad reconstruction + new_sampling = Sampling((sim_size, sim_size), extent=tuple(recons.state.probe.sampling.extent)) + print(f"Resampling probe and patterns to shape {new_sampling.shape}...", file=sys.stderr, flush=True) + recons.state.probe.data = recons.state.probe.sampling.resample(recons.state.probe.data, new_sampling) + recons.patterns.patterns = recons.state.probe.sampling.resample_recip(recons.patterns.patterns, new_sampling) + recons.patterns.pattern_mask = recons.state.probe.sampling.resample_recip(recons.patterns.pattern_mask, new_sampling) + recons.state.probe.sampling = new_sampling + + return (recons.to_numpy(), plan) + + +def benchmark_grad( + grouping: int, sim_size: int, backend: BackendName, + unroll: t.Union[int, bool] = 10, +) -> t.List[float]: + (recons, plan) = initialize(sim_size) + xp = get_backend_module(backend) + recons = recons.to_xp(xp) + + devices = get_backend_devices(xp) + print(f"Available devices: {list(devices)}", file=sys.stderr) + print(f"Using device '{devices[0]}'", file=sys.stderr) + set_default_device(to_device(devices[0], xp), xp) + + engine = pane.convert({ + 'type': 'gradient', + 'buffer_n_groups': 16 if grouping < 256 else 2, + 'jit_unroll_slices': unroll, + 'probe_modes': 4, + 'niter': 15, + 'grouping': grouping, + 'noise_model': {'type': 'amplitude', 'eps': 1.0e-4}, + 'solvers': { + 'object': { + 'type': 'adam', + 'learning_rate': 1.0e-3, + 'nesterov': True, + }, + 'probe': { + 'type': 'adam', + 'learning_rate': 1.0e-3, + 'nesterov': True, + }, + }, + 'regularizers': [ + ], + 'iter_constraints': [ + {'type': 'clamp_object_amplitude', 'amplitude': 1.0}, + ], + 'group_constraints': [ + ], + 'update_probe': True, + 'update_object': True, + 'update_positions': False, + 'save': False, 'save_images': False, + }, EngineHook) + + recons = execute_engine(recons, engine) + + iter_times: t.List[float] = numpy.diff(recons.state.progress['time'].values).tolist()[N_WARMUP:] + + print(f"Iter times: {iter_times}", file=sys.stderr) + print(f"Mean time: {sum(iter_times) / len(iter_times):.3f} s", file=sys.stderr) + return iter_times + + +def benchmark_lsqml( + grouping: int, sim_size: int, backend: BackendName, + unroll: t.Union[int, bool] = 10, +) -> t.List[float]: + (recons, plan) = initialize(sim_size) + xp = get_backend_module(backend) + recons = recons.to_xp(xp) + + devices = get_backend_devices(xp) + print(f"Available devices: {list(devices)}", file=sys.stderr) + print(f"Using device '{devices[0]}'", file=sys.stderr) + set_default_device(to_device(devices[0], xp), xp) + + engine = pane.convert({ + 'type': 'conventional', + 'buffer_n_groups': 16 if grouping < 256 else 2, + 'jit_unroll_slices': unroll, + 'probe_modes': 4, + 'niter': 15, + 'grouping': grouping, + 'noise_model': {'type': 'amplitude', 'eps': 1.0e-1}, + 'solver': { + 'type': 'lsqml', + 'beta_probe': 0.1, + 'beta_object': 0.1, + 'gamma': 1.0e-4, + 'illum_reg_object': 1.0e-2, + 'illum_reg_probe': 1.0e-2, + }, + 'position_solver': { + 'type': 'momentum', + 'momentum': 0.90, + 'step_size': 8.0e-2, + 'max_step_size': 0.2, + }, + 'iter_constraints': [ + {'type': 'clamp_object_amplitude', 'amplitude': 1.0}, + ], + 'group_constraints': [ + ], + 'update_probe': True, + 'update_object': True, + 'update_positions': False, + 'save': False, 'save_images': False, + }, EngineHook) + + recons = execute_engine(recons, engine) + + iter_times: t.List[float] = numpy.diff(recons.state.progress['time'].values).tolist()[N_WARMUP:] + + print(f"Iter times: {iter_times}", file=sys.stderr) + print(f"Mean time: {sum(iter_times) / len(iter_times):.3f} s", file=sys.stderr) + return iter_times + + +if __name__ == '__main__': + pynvml.nvmlInit() + import jax + + device_name = jax.devices()[0].device_kind + print(f"device: {device_name}", file=sys.stderr) + + for backend in ('jax', 'cupy'): + for sim_size, unroll, grouping in itertools.product((128, 192), (5,), (8, 4, 16, 32, 64, 128, 256, 512, 1024)): + if backend == 'jax': + import jax.version + backend_version = jax.version.__version__ + else: + import cupy + backend_version = cupy.__version__ + + print(f"\nRunning lsqml, sim_size={sim_size} backend={backend!r} grouping={grouping}...", file=sys.stderr) + print_memory_usage(file=sys.stderr) + try: + iter_times = benchmark_lsqml(grouping, sim_size, backend, unroll=unroll) + except Exception as e: + print(f"Failed to run, error:\n{e}", file=sys.stderr) + else: + json.dump({ + 'engine': 'lsqml', + 'backend': backend, + 'backend_version': backend_version, + 'sim_size': sim_size, + 'n_positions': 80*80, + 'n_slices': 10, + 'n_modes': 4, + 'grouping': grouping, + 'device': device_name, + 'code': 'v5_unroll5', + 'iter_times': iter_times, + }, sys.stdout) + sys.stdout.write("\n") + sys.stdout.flush() diff --git a/benchmarks/wse2.py b/benchmarks/wse2.py new file mode 100644 index 0000000..a42473e --- /dev/null +++ b/benchmarks/wse2.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +import itertools +import functools +import sys +import json +import typing as t + +import numpy +import pane +import pynvml + +from phaser.utils.num import get_backend_devices, get_backend_module, Sampling, set_default_device +from phaser.plan import ReconsPlan, EngineHook, BackendName +from phaser.state import PreparedRecons +from phaser.execute import execute_engine, initialize_reconstruction + +N_WARMUP: int = 2 + + +def sizeof_fmt(num, suffix="B"): + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024.0: + return f"{num:3.1f} {unit}{suffix}" + num /= 1024.0 + return f"{num:.1f} Yi{suffix}" + + +def print_memory_usage(file=None): + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + print(f"GPU memory usage: {sizeof_fmt(info.used)}/{sizeof_fmt(info.total)}", file=file) + + +#@functools.lru_cache(1) +def initialize(sim_size: int = 128) -> t.Tuple[PreparedRecons, ReconsPlan]: + plan = ReconsPlan.from_data({ + "name": "si", + "backend": "jax", + 'dtype': 'float32', + 'raw_data': { + 'type': 'empad', + 'path': '~/Downloads/ptyrad_paper/00_data/tBL_WSe2/Panel_g-h_Themis/WSe2.json', + }, + 'post_load': [], + 'post_init': [], + 'slices': {'n': 6, 'total_thickness': 60.0}, + 'engines': [], + }) + + recons = initialize_reconstruction(plan) + + if sim_size != 128: + # pad reconstruction + new_sampling = Sampling((sim_size, sim_size), extent=tuple(recons.state.probe.sampling.extent)) + print(f"Resampling probe and patterns to shape {new_sampling.shape}...", file=sys.stderr, flush=True) + recons.state.probe.data = recons.state.probe.sampling.resample(recons.state.probe.data, new_sampling) + recons.patterns.patterns = recons.state.probe.sampling.resample_recip(recons.patterns.patterns, new_sampling) + recons.patterns.pattern_mask = recons.state.probe.sampling.resample_recip(recons.patterns.pattern_mask, new_sampling) + recons.state.probe.sampling = new_sampling + + return (recons.to_numpy(), plan) + + +def benchmark_grad( + grouping: int, sim_size: int, backend: BackendName, + unroll: t.Union[int, bool] = 10, +) -> t.List[float]: + (recons, plan) = initialize(sim_size) + xp = get_backend_module(backend) + recons = recons.to_xp(xp) + + devices = get_backend_devices(xp) + print(f"Available devices: {list(devices)}", file=sys.stderr) + print(f"Using device '{devices[0]}'", file=sys.stderr) + set_default_device(devices[0], xp) + + engine = pane.convert({ + 'type': 'gradient', + 'buffer_n_groups': 16 if grouping < 256 else 2, + 'jit_unroll_slices': unroll, + 'probe_modes': 6, + 'niter': 15, + 'grouping': grouping, + 'noise_model': {'type': 'amplitude', 'eps': 1.0e-1}, + 'solvers': { + 'object': { + 'type': 'adam', + 'learning_rate': 1.0e-3, + 'nesterov': True, + }, + 'probe': { + 'type': 'adam', + 'learning_rate': 1.0e-3, + 'nesterov': True, + }, + }, + 'regularizers': [ + ], + 'iter_constraints': [ + {'type': 'layers', 'sigma': 20.0, 'weight': 0.9}, + {'type': 'clamp_object_amplitude', 'amplitude': 1.0}, + ], + 'group_constraints': [ + ], + 'update_probe': True, + 'update_object': True, + 'update_positions': False, + 'save': False, 'save_images': False, + }, EngineHook) + + """ + import h5py + from phaser.utils.optics import make_hermetian_modes + + recons.state.probe.data = make_hermetian_modes( + numpy.sum(recons.state.probe.data, axis=0), 8 + ) + recons.state.write_hdf5("init_state.h5") + + f = h5py.File("raw_data.h5", mode='w') + try: + f.create_dataset('patterns', data=recons.patterns.patterns) + f.create_dataset('pattern_mask', data=recons.patterns.pattern_mask) + finally: + f.close() + """ + + recons = execute_engine(recons, engine) + + iter_times: t.List[float] = numpy.diff(recons.state.progress['time'].values).tolist()[N_WARMUP:] + + print(f"Iter times: {iter_times}", file=sys.stderr) + print(f"Mean time: {sum(iter_times) / len(iter_times):.3f} s", file=sys.stderr) + return iter_times + + +if __name__ == '__main__': + pynvml.nvmlInit() + import jax + + device_name = jax.devices()[0].device_kind + print(f"device: {device_name}", file=sys.stderr) + + backend = 'jax' + + for sim_size, unroll, grouping in itertools.product((128,), (True, False), (4, 8, 16, 32, 64, 128, 256, 512, 1024)): + if backend == 'jax': + import jax.version + backend_version = jax.version.__version__ + else: + raise NotImplementedError() + + print(f"\nRunning grad, sim_size={sim_size} backend={backend!r} grouping={grouping}...", file=sys.stderr) + print_memory_usage(file=sys.stderr) + try: + iter_times = benchmark_grad(grouping, sim_size, backend, unroll=unroll) + except Exception as e: + print(f"Failed to run, error:\n{e}", file=sys.stderr) + else: + json.dump({ + 'engine': 'grad', + 'backend': backend, + 'backend_version': backend_version, + 'sim_size': sim_size, + 'n_positions': 128*128, + 'n_slices': 6, + 'n_modes': 6, + 'grouping': grouping, + 'device': device_name, + 'code': 'v5_unroll' if unroll else 'v5', + 'iter_times': iter_times, + }, sys.stdout) + sys.stdout.write("\n") + sys.stdout.flush() diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 416186e..d8a94f2 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -41,7 +41,7 @@ def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None return (sim, None) -@partial(jit, donate_argnames=('obj',), cupy_fuse=True) +@partial(jit, donate_argnames=('obj',), cupy_fuse=False) def clamp_amplitude( obj: NDArray[numpy.complexfloating], min: t.Union[float, numpy.floating, None], diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 2264805..7e5fc98 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -1,7 +1,7 @@ import logging from phaser.utils.misc import mask_fraction_of_groups -from phaser.utils.num import assert_dtype, cast_array_module, to_numpy, to_complex_dtype +from phaser.utils.num import assert_dtype, cast_array_module, check_finite, to_numpy, to_complex_dtype from phaser.observer import Observer from phaser.hooks import EngineArgs from phaser.plan import ConventionalEnginePlan @@ -35,14 +35,20 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: group_constraints=group_constraints, iter_constraints=iter_constraints, xp=xp, dtype=dtype ) - patterns = args['data'].patterns - pattern_mask = xp.asarray(args['data'].pattern_mask) - - assert_dtype(patterns, dtype) - assert_dtype(pattern_mask, dtype) + assert_dtype(args['data'].patterns, dtype) + assert_dtype(args['data'].pattern_mask, dtype) assert_dtype(sim.state.object.data, cdtype) assert_dtype(sim.state.probe.data, cdtype) + pattern_mask = xp.asarray(args['data'].pattern_mask) + # load/stream patterns + if props.buffer_n_groups is None: + logging.info("Loading raw data to GPU ('buffer_n_groups' is disabled)...") + patterns = xp.asarray(args['data'].patterns) + else: + logging.info(f"Streaming raw data to GPU (buffering {props.buffer_n_groups} groups)") + patterns = args['data'].patterns + solver = props.solver(props) sim = solver.init(sim) groups = GroupManager(sim.state.scan, props.grouping, props.compact, seed=seed) @@ -100,6 +106,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: ) assert_dtype(sim.state.object.data, cdtype) assert_dtype(sim.state.probe.data, cdtype) + check_finite(sim.state.object.data, sim.state.probe.data, context=f"NaN or inf encountered, iteration {i}") sim = sim.apply_iter_constraints() diff --git a/phaser/engines/conventional/solvers.py b/phaser/engines/conventional/solvers.py index d689d95..b442b35 100644 --- a/phaser/engines/conventional/solvers.py +++ b/phaser/engines/conventional/solvers.py @@ -5,7 +5,7 @@ import numpy from numpy.typing import NDArray -from phaser.utils.num import cast_array_module, at, abs2, fft2, ifft2, jit, check_finite, to_complex_dtype, to_numpy +from phaser.utils.num import cast_array_module, at, abs2, fft2, ifft2, jit, check_finite, to_complex_dtype, to_numpy, xp_is_jax from phaser.hooks.solver import ConventionalSolver from phaser.types import process_schedule from phaser.plan import ConventionalEnginePlan, LSQMLSolverPlan, EPIESolverPlan @@ -16,9 +16,9 @@ class LSQMLSolver(ConventionalSolver): - def __init__(self, plan: ConventionalEnginePlan, props: LSQMLSolverPlan): + def __init__(self, engine_plan: ConventionalEnginePlan, props: LSQMLSolverPlan): self.plan: LSQMLSolverPlan = props - self.engine_plan: ConventionalEnginePlan = plan + self.engine_plan: ConventionalEnginePlan = engine_plan @classmethod def name(cls) -> str: @@ -31,8 +31,24 @@ def init(self, sim: SimulationState) -> SimulationState: self.obj_mag: NDArray[numpy.floating] = xp.zeros(sim.state.probe.data.shape[-2:], dtype=sim.dtype) self.probe_mag: NDArray[numpy.floating] = xp.zeros_like(sim.state.object.data, dtype=sim.dtype) + if self.engine_plan.jit_unroll_slices and xp_is_jax(xp): + self.logger.warning(f"'jit_unroll_slices' set to '{self.engine_plan.jit_unroll_slices!r}'. " + "This can result in slow compilation, use with care.") + self.jit_unroll_slices = False if self.engine_plan.jit_unroll_slices is None else self.engine_plan.jit_unroll_slices + return sim + def iter_patterns( + self, groups: t.Iterator[NDArray[numpy.int_]], patterns: NDArray[numpy.floating], xp: t.Any + ) -> t.Iterable[t.Tuple[NDArray[numpy.int_], NDArray[numpy.floating]]]: + xp = cast_array_module(xp) + + if self.engine_plan.buffer_n_groups is None: + return ((group, patterns[tuple(xp.asarray(group))]) for group in groups) + return stream_patterns( + groups, patterns, xp=xp, buf_n=self.engine_plan.buffer_n_groups + ) + def presolve( self, sim: SimulationState, @@ -44,7 +60,7 @@ def presolve( rescale_factors = [] # precompute obj_mag, probe_mag, and rescale probe intensity - for (group, group_patterns) in stream_patterns(groups, patterns, xp=sim.xp, buf_n=self.engine_plan.buffer_n_groups): + for (group, group_patterns) in self.iter_patterns(groups, patterns, sim.xp): (self.obj_mag, self.probe_mag, group_rescale_factors) = lsqml_dry_run( sim, group, group_patterns, props=propagators, pattern_mask=pattern_mask, obj_mag=self.obj_mag, probe_mag=self.probe_mag @@ -89,8 +105,7 @@ def run_iteration( pos_update = xp.zeros_like(sim.state.scan, dtype=sim.dtype) iter_errors = [] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups, patterns, xp=xp, - buf_n=self.engine_plan.buffer_n_groups)): + for (group_i, (group, group_patterns)) in enumerate(self.iter_patterns(groups, patterns, xp)): group_calc_error = calc_error and calc_error_mask[group_i] (sim, new_obj_mag, new_probe_mag, errors, group_pos_update) = lsqml_run( @@ -102,13 +117,13 @@ def run_iteration( update_probe=update_probe, update_position=update_positions, calc_error=group_calc_error, + jit_unroll_slices=self.jit_unroll_slices, illum_reg_object=illum_reg_object, illum_reg_probe=illum_reg_probe, gamma=gamma, ) - check_finite(sim.state.object.data, sim.state.probe.data, context=f"object or probe, group {group_i}") - #assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype) - #assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype) + if self.engine_plan.check_every_group: + check_finite(sim.state.object.data, sim.state.probe.data, context=f"object or probe, group {group_i}") sim = sim.apply_group_constraints(group) @@ -171,7 +186,7 @@ def run_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], st @partial( jit, donate_argnames=('sim', 'new_obj_mag', 'new_probe_mag'), - static_argnames=('update_object', 'update_probe', 'update_position', 'calc_error'), + static_argnames=('update_object', 'update_probe', 'update_position', 'calc_error', 'jit_unroll_slices'), ) def lsqml_run( sim: SimulationState, @@ -189,6 +204,7 @@ def lsqml_run( update_probe: bool = True, update_position: bool = True, calc_error: bool = True, + jit_unroll_slices: t.Union[int, bool] = False, illum_reg_object: float, illum_reg_probe: float, gamma: float, @@ -223,7 +239,7 @@ def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], st props = tilt_propagators(sim.ky, sim.kx, sim.state, props, sim.state.tilt[tuple(group)] if sim.state.tilt is not None else None) - (group_probe_mag, psi) = slice_forwards(props, (group_probe_mag, psi), sim_slice) + (group_probe_mag, psi) = slice_forwards(props, (group_probe_mag, psi), sim_slice, jit_unroll_slices=jit_unroll_slices) new_obj_mag += group_obj_mag new_probe_mag += group_probe_mag @@ -270,7 +286,7 @@ def update_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], return (sim, chi) - (sim, chi) = slice_backwards(props, (sim, chi), update_slice) + (sim, chi) = slice_backwards(props, (sim, chi), update_slice, jit_unroll_slices=jit_unroll_slices) if update_position: def calc_pos_step(probes_fft: NDArray[numpy.complexfloating], kx: NDArray[numpy.floating]) -> NDArray[numpy.floating]: @@ -301,8 +317,25 @@ def name(cls) -> str: def init(self, sim: SimulationState) -> SimulationState: self.logger = logging.getLogger(__name__) + + if self.engine_plan.jit_unroll_slices and xp_is_jax(sim.xp): + self.logger.warning(f"'jit_unroll_slices' set to '{self.engine_plan.jit_unroll_slices!r}'. " + "This can result in slow compilation, use with care.") + self.jit_unroll_slices = False if self.engine_plan.jit_unroll_slices is None else self.engine_plan.jit_unroll_slices + return sim + def iter_patterns( + self, groups: t.Iterator[NDArray[numpy.int_]], patterns: NDArray[numpy.floating], xp: t.Any + ) -> t.Iterable[t.Tuple[NDArray[numpy.int_], NDArray[numpy.floating]]]: + xp = cast_array_module(xp) + + if self.engine_plan.buffer_n_groups is None: + return ((group, patterns[tuple(xp.asarray(group))]) for group in groups) + return stream_patterns( + groups, patterns, xp=xp, buf_n=self.engine_plan.buffer_n_groups + ) + def presolve( self, sim: SimulationState, @@ -312,8 +345,7 @@ def presolve( propagators: t.Optional[NDArray[numpy.complexfloating]], ) -> SimulationState: rescale_factors = [] - for (group, group_patterns) in stream_patterns(groups, patterns, xp=sim.xp, - buf_n=self.engine_plan.buffer_n_groups): + for (group, group_patterns) in self.iter_patterns(groups, patterns, sim.xp): group_rescale_factors = epie_dry_run( sim, group, group_patterns, pattern_mask=pattern_mask, props=propagators ) @@ -351,8 +383,7 @@ def run_iteration( beta_object = process_schedule(self.plan.beta_object)({'state': sim.state, 'niter': self.engine_plan.niter}) beta_probe = process_schedule(self.plan.beta_probe)({'state': sim.state, 'niter': self.engine_plan.niter}) - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups, patterns, xp=xp, - buf_n=self.engine_plan.buffer_n_groups)): + for (group_i, (group, group_patterns)) in enumerate(self.iter_patterns(groups, patterns, xp)): group_calc_error = calc_error and calc_error_mask[group_i] (sim, errors, group_pos_update) = epie_run( @@ -363,10 +394,10 @@ def run_iteration( beta_probe=beta_probe, update_object=update_object, update_probe=update_probe, + jit_unroll_slices=self.jit_unroll_slices, ) - check_finite(sim.state.object.data, sim.state.probe.data, context=f"object or probe, group {group_i}") - #assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype) - #assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype) + if self.engine_plan.check_every_group: + check_finite(sim.state.object.data, sim.state.probe.data, context=f"object or probe, group {group_i}") sim = sim.apply_group_constraints(group) @@ -412,7 +443,11 @@ def run_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], ps return exp_intensity / model_intensity -@partial(jit, donate_argnames=('sim',), static_argnames=('update_object', 'update_probe', 'update_position', 'calc_error')) +@partial( + jit, + donate_argnames=('sim',), + static_argnames=('update_object', 'update_probe', 'update_position', 'calc_error', 'jit_unroll_slices') +) def epie_run( sim: SimulationState, group: NDArray[numpy.integer], @@ -425,6 +460,7 @@ def epie_run( update_probe: bool = True, update_position: bool = True, calc_error: bool = True, + jit_unroll_slices: t.Union[int, bool] = False, ) -> t.Tuple[SimulationState, t.Optional[NDArray[numpy.floating]], t.Optional[NDArray[numpy.floating]]]: xp = cast_array_module(sim.xp) obj_grid = sim.state.object.sampling @@ -444,7 +480,7 @@ def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], ps props = tilt_propagators(sim.ky, sim.kx, sim.state, props, sim.state.tilt[tuple(group)] if sim.state.tilt is not None else None) - psi = slice_forwards(props, psi, sim_slice) + psi = slice_forwards(props, psi, sim_slice, jit_unroll_slices=jit_unroll_slices) model_wave = fft2(psi[-1] * group_obj[:, -1, None]) # sum over incoherent modes @@ -484,7 +520,7 @@ def update_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], return (sim, chi) - (sim, chi) = slice_backwards(props, (sim, chi), update_slice) + (sim, chi) = slice_backwards(props, (sim, chi), update_slice, jit_unroll_slices=jit_unroll_slices) if update_position: def calc_pos_step(probes_fft: NDArray[numpy.complexfloating], kx: NDArray[numpy.floating]) -> NDArray[numpy.floating]: diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 489ac37..2efbd46 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -158,6 +158,8 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: observer: Observer = args.get('observer', Observer()) state = args['state'] seed = args['seed'] + # default to 10 slices + jit_unroll_slices = 10 if props.jit_unroll_slices is None else props.jit_unroll_slices noise_model = props.noise_model(None) @@ -290,7 +292,7 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple pattern_mask=pattern_mask, probe_int=probe_int, xp=xp, dtype=dtype, - jit_unroll_slices=props.jit_unroll_slices, + jit_unroll_slices=jit_unroll_slices, ) if props.check_every_group and not numpy.isfinite(float(losses_gpu['total_loss'])): raise ValueError(f"NaN or inf encountered, group {group_i}") diff --git a/phaser/plan.py b/phaser/plan.py index 0e68ad6..753d13b 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -64,13 +64,13 @@ class EnginePlan(Dataclass, kw_only=True): entire dataset to the device. """ - jit_unroll_slices: t.Union[bool, int] = 10 + jit_unroll_slices: t.Union[None, bool, int] = None """ Slices to unroll during JIT compilation (JAX backend only). Larger unrolling may be faster, at the expense of increased compilation time. `True` or `0` unrolls all slices, `False` or `1` disables unrolling. - `10` should be a good default value. + Defaults vary by engine (currently `10` for the gradient descent engine). """ update_probe: FlagLike = True diff --git a/phaser/utils/num.py b/phaser/utils/num.py index 7cb6b6e..3ca6639 100644 --- a/phaser/utils/num.py +++ b/phaser/utils/num.py @@ -113,6 +113,13 @@ def get(self, name: BackendName): return None if t.TYPE_CHECKING else self.inner[name] + def try_get(self, name: BackendName): + name = self._normalize(name) + if name == 'numpy': + return numpy + + return self.inner.get(name) + def __getitem__(self, name: BackendName): if (backend := self.get(name)) is not None: return backend @@ -377,13 +384,15 @@ def is_torch(arr: t.Any) -> bool: def xp_is_cupy(xp: t.Any) -> bool: - return xp is sys.modules.get('cupy') + if (cupy := _BACKEND_LOADER.try_get('cupy')) is None: + return False + return xp is cupy def xp_is_jax(xp: t.Any) -> bool: return xp is sys.modules.get('jax.numpy') def xp_is_torch(xp: t.Any) -> bool: - if (torch := _BACKEND_LOADER.get('torch')) is None: + if (torch := _BACKEND_LOADER.try_get('torch')) is None: return False return xp is torch @@ -794,7 +803,7 @@ def ufunc_outer(ufunc: numpy.ufunc, x: ArrayLike, y: ArrayLike) -> numpy.ndarray from ._jax_kernels import outer return outer(ufunc, x, y) - if not t.TYPE_CHECKING and is_torch(x): + if not t.TYPE_CHECKING and (is_torch(x) or is_cupy(x)): return ufunc(x[(..., *((None,) * y.ndim))], y[(*((None,) * x.ndim), ...)]) return ufunc.outer(x, y) @@ -1090,4 +1099,4 @@ def at(arr: NDArray[DTypeT], idx: IndexLike) -> _AtImpl[DTypeT]: 'abs2', 'split_array', 'unstack', 'at', 'ufunc_outer', 'check_finite', 'Sampling', 'IndexLike', -] \ No newline at end of file +] diff --git a/pyproject.toml b/pyproject.toml index 3d08640..34fda05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ - "numpy>=2.0,<2.6", # tested on 2.3 + "numpy>=2.0,<2.7", # tested on 2.4 "scipy>=1.7.0,<1.19", # tested on 1.11, 1.16 "matplotlib~=3.8", "h5py~=3.8", @@ -37,7 +37,7 @@ dependencies = [ "rich>=12.0.0,<15", "tifffile>=2023.8.25", "optree>=0.13.0", - "py-pane==0.11.4", + "py-pane==0.11.5", "typing_extensions~=4.7", ] @@ -55,10 +55,14 @@ cupy12 = [ "cupy-cuda12x>=12.0.0", "pynvml>=11.0.0", ] +cupy13 = [ + "cupy-cuda13x>=12.0.0", + "pynvml>=11.0.0", +] jax = [ # 0.4.25 is last version supporting cuda 11.8, we need to support it. - # tested on 0.4.25, 0.5.x, 0.6.x, and 0.7.0 - "jax>=0.4.25,<0.8", + # tested on 0.4.25, 0.5.x, 0.6.x, 0.7.0, and 0.10.0 + "jax>=0.4.25,<0.11", "optax>=0.2.2", ] torch = [