From 9137f282019d245b4912d9280b427efb4cfe7d1b Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Wed, 11 Oct 2017 08:05:07 +0200 Subject: [PATCH] eeehhh --- plato/tools/common/basic.py | 4 ++-- plato/tools/common/online_predictors.py | 28 ++++++++++++++++++++++++- plato/tools/mlp/mlp.py | 4 ++-- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/plato/tools/common/basic.py b/plato/tools/common/basic.py index 2ead9e9..c9288ed 100644 --- a/plato/tools/common/basic.py +++ b/plato/tools/common/basic.py @@ -81,7 +81,7 @@ def running_mean_and_variance(data, decay = None, shape = None, elementwise=True var_new = s_new add_update(mean_last, mean_new) add_update(s_last, s_new) - return var_new + return mean_new, var_new @symbolic @@ -93,4 +93,4 @@ def running_variance(data, decay=None, shape = None, elementwise=True, initial_v :param shape: :return: """ - return running_mean_and_variance(data=data, decay=decay, shape=shape, elementwise=elementwise, initial_var = initial_value) + return running_mean_and_variance(data=data, decay=decay, shape=shape, elementwise=elementwise, initial_var = initial_value)[1] diff --git a/plato/tools/common/online_predictors.py b/plato/tools/common/online_predictors.py index 4bc6281..3302b94 100644 --- a/plato/tools/common/online_predictors.py +++ b/plato/tools/common/online_predictors.py @@ -1,4 +1,6 @@ from abc import ABCMeta, abstractmethod +from contextlib import contextmanager + from plato.interfaces.decorators import symbolic_simple, symbolic_updater from plato.interfaces.interfaces import IParameterized from plato.tools.optimization.cost import get_named_cost_function @@ -87,6 +89,25 @@ def parameters(self): return self._function.parameters + opt_params +_LOCAL_LOSSES = None + + +def declare_local_loss(loss): + if _LOCAL_LOSSES is not None: + _LOCAL_LOSSES.append(loss) + + +@contextmanager +def capture_local_losses(): + global _LOCAL_LOSSES + assert _LOCAL_LOSSES is None, "Local loss book already open" + _LOCAL_LOSSES = [] + try: + yield _LOCAL_LOSSES + finally: + _LOCAL_LOSSES = None + + class CompiledSymbolicPredictor(IPredictor, IParameterized): """ A Predictor containing the compiled methods for a SymbolicPredictor. @@ -125,7 +146,12 @@ def __call__(self, x): raise NotImplementedError() def train(self, x, y, cost_fcn, optimizer, assert_all_params_optimized=False, regularization_cost = None): - cost = cost_fcn(self.train_call(x), y) + with capture_local_losses() as local_losses: + cost = cost_fcn(self.train_call(x), y) + + if len(local_losses)>0: + cost = cost + sum(local_losses) + if regularization_cost is not None: cost = cost + regularization_cost(self.parameters) if isinstance(optimizer, dict): diff --git a/plato/tools/mlp/mlp.py b/plato/tools/mlp/mlp.py index c97a2e6..2b86baa 100644 --- a/plato/tools/mlp/mlp.py +++ b/plato/tools/mlp/mlp.py @@ -28,8 +28,8 @@ def __call__(self, x): return x @symbolic - def get_layer_activations(self, x): - activations = [] + def get_layer_activations(self, x, include_input = False): + activations = [x] if include_input else [] for lay in self.layers: x = lay(x) activations.append(x)