Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions plato/tools/common/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def running_mean_and_variance(data, decay = None, shape = None, elementwise=True
var_new = s_new
add_update(mean_last, mean_new)
add_update(s_last, s_new)
return var_new
return mean_new, var_new


@symbolic
Expand All @@ -93,4 +93,4 @@ def running_variance(data, decay=None, shape = None, elementwise=True, initial_v
:param shape:
:return:
"""
return running_mean_and_variance(data=data, decay=decay, shape=shape, elementwise=elementwise, initial_var = initial_value)
return running_mean_and_variance(data=data, decay=decay, shape=shape, elementwise=elementwise, initial_var = initial_value)[1]
28 changes: 27 additions & 1 deletion plato/tools/common/online_predictors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager

from plato.interfaces.decorators import symbolic_simple, symbolic_updater
from plato.interfaces.interfaces import IParameterized
from plato.tools.optimization.cost import get_named_cost_function
Expand Down Expand Up @@ -87,6 +89,25 @@ def parameters(self):
return self._function.parameters + opt_params


_LOCAL_LOSSES = None


def declare_local_loss(loss):
if _LOCAL_LOSSES is not None:
_LOCAL_LOSSES.append(loss)


@contextmanager
def capture_local_losses():
global _LOCAL_LOSSES
assert _LOCAL_LOSSES is None, "Local loss book already open"
_LOCAL_LOSSES = []
try:
yield _LOCAL_LOSSES
finally:
_LOCAL_LOSSES = None


class CompiledSymbolicPredictor(IPredictor, IParameterized):
"""
A Predictor containing the compiled methods for a SymbolicPredictor.
Expand Down Expand Up @@ -125,7 +146,12 @@ def __call__(self, x):
raise NotImplementedError()

def train(self, x, y, cost_fcn, optimizer, assert_all_params_optimized=False, regularization_cost = None):
cost = cost_fcn(self.train_call(x), y)
with capture_local_losses() as local_losses:
cost = cost_fcn(self.train_call(x), y)

if len(local_losses)>0:
cost = cost + sum(local_losses)

if regularization_cost is not None:
cost = cost + regularization_cost(self.parameters)
if isinstance(optimizer, dict):
Expand Down
4 changes: 2 additions & 2 deletions plato/tools/mlp/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def __call__(self, x):
return x

@symbolic
def get_layer_activations(self, x):
activations = []
def get_layer_activations(self, x, include_input = False):
activations = [x] if include_input else []
for lay in self.layers:
x = lay(x)
activations.append(x)
Expand Down