From 1731ce9945277549c1164e464d19879e97a38777 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Tue, 26 Jan 2016 14:03:01 +0100 Subject: [PATCH] handle double-setting of state --- plato/core.py | 19 ++++++++++++++----- plato/interfaces/helpers.py | 1 - plato/test_core.py | 30 +++++++++++++++--------------- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/plato/core.py b/plato/core.py index 46e86a4..65ab40e 100644 --- a/plato/core.py +++ b/plato/core.py @@ -1,4 +1,4 @@ -from collections import OrderedDict +from collections import OrderedDict, Counter from functools import partial import inspect import logging @@ -726,19 +726,28 @@ def __init__(self, swallow_updates = False): def __enter__(self): self._outer_catcher = _get_state_catcher() _set_state_catcher(self) - self._updates = [] + self._updates = {} return self def __exit__(self, *args): _set_state_catcher(self._outer_catcher) - def add_update(self, shared_var, new_val): - self._updates.append((shared_var, new_val)) + def add_update(self, shared_var, new_val, add_multiple_updates = False): + + if shared_var in self._updates: + if add_multiple_updates: + self._updates[shared_var] = self._updates[shared_var] + new_val + else: + raise Exception("You updated variable %s with value %s, but it already was being updated with value %s. " + "\nIf your intention was to add these updates, call this function with add_multiple_updates=True" + % (shared_var, new_val, self._updates[shared_var])) + else: + self._updates[shared_var] = new_val if self._outer_catcher is not None and not self.swallow_updates: # Allows for nested StateCatchers (outer ones do not have to worry about inner ones stealing their updates) self._outer_catcher.add_update(shared_var, new_val) def get_updates(self): - return self._updates + return self._updates.items() def assert_compatible_shape(actual_shape, desired_shape, name = None): diff --git a/plato/interfaces/helpers.py b/plato/interfaces/helpers.py index 401e2ff..3c38f66 100644 --- a/plato/interfaces/helpers.py +++ b/plato/interfaces/helpers.py @@ -95,7 +95,6 @@ def get_named_activation_function(activation_name): 'safenorm-relu': lambda x: normalize_safely(tt.maximum(x, 0), axis = -1), 'balanced-relu': lambda x: tt.maximum(x, 0)*(2*(tt.arange(x.shape[-1]) % 2)-1), # Glorot et al. Deep Sparse Rectifier Networks 'prenorm-relu': lambda x: tt.maximum(normalize_safely(x, axis = -1, degree = 2), 0), - 'linear': lambda x: x, 'leaky-relu-0.01': lambda x: tt.maximum(0.01*x, x), 'maxout': lambda x: tt.max(x, axis=1), # We expect (n_samples, n_maps, n_dims) data and flatten to (n_samples, n_dims) }[activation_name] diff --git a/plato/test_core.py b/plato/test_core.py index 87fb8a2..3bcaa38 100644 --- a/plato/test_core.py +++ b/plato/test_core.py @@ -474,19 +474,19 @@ def mat_mult(a, b): if __name__ == '__main__': - # test_ival_ishape() - # test_catch_sneaky_updates() - # test_catch_non_updates() + test_ival_ishape() + test_catch_sneaky_updates() + test_catch_non_updates() test_scan() - # test_strrep() - # test_omniscence() - # test_named_arguments() - # test_stateless_symbolic_function() - # test_stateful_symbolic_function() - # test_debug_trace() - # test_method_caching_bug() - # test_pure_updater() - # test_function_format_checking() - # test_callable_format_checking() - # test_inhereting_from_decorated() - # test_dual_decoration() + test_strrep() + test_omniscence() + test_named_arguments() + test_stateless_symbolic_function() + test_stateful_symbolic_function() + test_debug_trace() + test_method_caching_bug() + test_pure_updater() + test_function_format_checking() + test_callable_format_checking() + test_inhereting_from_decorated() + test_dual_decoration()