Construct train step from an objective function and optimizer #595

Description

@seanmor5

Right now the only way to construct a train step is using a loss function and an optimizer:

def train_step(model, loss, optimizer, opts \\ []) do

This is suitable for most cases, but in some instances it may be easier to let the user pass an objective function to differentiate through, rather than just the loss function. In the default train step, the constructed objective function is:

  objective_fn = fn trainable_parameters, model_state, loss_scale_state, inp, tar ->
    # hack to use trainable parameters as grad
    model_state =
      update_in(model_state, [Access.key!(:data)], fn data ->
        tree_merge(data, trainable_parameters, fn _, _, v -> v end)
      end)

    model_out = forward_model_fn.(model_state, inp)
    unscaled_loss = loss_fn.(tar, model_out.prediction)
    scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)

    {model_out, scaled_loss, unscaled_loss}
  end

If we can clean this form up a bit and get rid of the hack, this could be a useful API for constructing more complex training objectives without needing to re-implement the entire train step.
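To make the idea concrete, here is a rough, dependency-free sketch of what building a step from an objective function plus an optimizer update might look like. Everything in it is hypothetical and not part of Axon: the module name `ObjectiveStepSketch`, the simplified objective signature `(params, inp, tar) -> {model_out, loss}` (no model state or loss scaling), and the `update_fn` shape are all assumptions made for illustration. It uses plain floats and central finite differences in place of the automatic differentiation a real implementation would use.

```elixir
defmodule ObjectiveStepSketch do
  # Hypothetical helper for illustration only; not an Axon API.
  #
  # objective_fn: (params, inp, tar) -> {model_out, loss}
  # update_fn:    (params, grads)    -> new_params
  def train_step(objective_fn, update_fn) do
    fn params, inp, tar ->
      {model_out, loss} = objective_fn.(params, inp, tar)

      # Differentiate only the loss element of the objective's output.
      grads =
        numerical_grad(params, fn p ->
          {_out, l} = objective_fn.(p, inp, tar)
          l
        end)

      %{params: update_fn.(params, grads), loss: loss, model_out: model_out}
    end
  end

  # Central finite differences over a flat map of float parameters.
  # This stands in for automatic differentiation in this sketch.
  defp numerical_grad(params, loss_fn, eps \\ 1.0e-6) do
    Map.new(params, fn {key, value} ->
      plus = loss_fn.(Map.put(params, key, value + eps))
      minus = loss_fn.(Map.put(params, key, value - eps))
      {key, (plus - minus) / (2 * eps)}
    end)
  end
end

# A custom objective: squared error plus an L2 penalty on the weight.
# The penalty is the kind of term a plain loss function cannot express,
# because it depends on the parameters and not just on the prediction.
objective_fn = fn %{w: w, b: b}, x, y ->
  pred = w * x + b
  err = pred - y
  {pred, err * err + 0.01 * w * w}
end

# Plain gradient descent with a fixed learning rate of 0.1.
sgd = fn params, grads ->
  Map.new(params, fn {key, value} -> {key, value - 0.1 * grads[key]} end)
end

step = ObjectiveStepSketch.train_step(objective_fn, sgd)
state = step.(%{w: 0.0, b: 0.0}, 1.0, 2.0)
```

In this sketch the caller owns the whole objective, so regularizers or multi-term losses need no changes to the step itself. How model state and loss scaling would be threaded through such an API is left open here, as it is in the issue.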