From 2c6dd621d6a0246607d22e3c873962258ac614a9 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Thu, 17 Dec 2020 16:50:49 +0100 Subject: [PATCH] Prototypes can use keras constraints. --- protoflow/applications/glvq.py | 4 ++++ protoflow/layers/prototypes.py | 10 +++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/protoflow/applications/glvq.py b/protoflow/applications/glvq.py index 50fa785..1ed961c 100644 --- a/protoflow/applications/glvq.py +++ b/protoflow/applications/glvq.py @@ -12,6 +12,7 @@ class _GLVQ(Network): layer used by GLVQ-like models including GMLVQ and LVQMLN. """ + def compile(self, loss=None, squashing="sigmoid_beta", @@ -59,6 +60,7 @@ class GLVQ(_GLVQ): prototype_initializer (str) : Method to use to set the initial prototype locations. (default: "mean") """ + def __init__(self, nclasses, input_dim, @@ -66,6 +68,7 @@ def __init__(self, prototype_initializer="zeros", trainable_prototypes=True, prototypes_dtype="float32", + prototype_constraint=None, distance_fn=squared_euclidean_distance, **kwargs): super().__init__(**kwargs) @@ -75,6 +78,7 @@ def __init__(self, prototypes_per_class=prototypes_per_class, prototype_initializer=prototype_initializer, trainable_prototypes=trainable_prototypes, + prototype_constraint=prototype_constraint, dtype=prototypes_dtype, ) self.distance_fn = distance_fn diff --git a/protoflow/layers/prototypes.py b/protoflow/layers/prototypes.py index 9ab47d9..4f7a709 100644 --- a/protoflow/layers/prototypes.py +++ b/protoflow/layers/prototypes.py @@ -6,14 +6,18 @@ from protoflow.modules import initializers +from tensorflow.python.keras import constraints + class _Prototypes(tf.keras.layers.Layer): """Base class for Prototype layers in ProtoFlow.""" + def __init__(self, nclasses=None, prototypes_per_class=1, prototype_distribution=None, - prototype_initializer='zeros', + prototype_initializer="zeros", + prototype_constraint=None, trainable_prototypes=True, **kwargs): if 'input_shape' not in kwargs and 'input_dim' in kwargs: @@ -28,6 +32,7 @@ def __init__(self, assert self.nclasses == len(prototype_distribution) self.prototype_distribution = prototype_distribution self.prototype_initializer = initializers.get(prototype_initializer) + self.prototype_constraint = constraints.get(prototype_constraint) self.trainable_prototypes = trainable_prototypes # Make a label list and flatten the list of lists using itertools @@ -55,6 +60,7 @@ def get_config(self): class Prototypes1D(_Prototypes): """Point Prototypes.""" + def build(self, input_shape): num_of_prototypes = sum(self.prototype_distribution) self.prototypes = self.add_weight( @@ -62,6 +68,7 @@ def build(self, input_shape): shape=(num_of_prototypes, input_shape[-1]), dtype=self.dtype, initializer=self.prototype_initializer, + constraint=self.prototype_constraint, trainable=self.trainable_prototypes) super().build(input_shape) @@ -77,6 +84,7 @@ class AppendPrototypes1D(Prototypes1D): `shape_transformation`: Callable that is to be applied to get a matrix. """ + def __init__(self, shape_transformation=None, **kwargs): self.shape_transformation = shape_transformation or tf.identity super().__init__(**kwargs)