diff --git a/presp/prescriptor/nn.py b/presp/prescriptor/nn.py index 721cf0a..a811564 100644 --- a/presp/prescriptor/nn.py +++ b/presp/prescriptor/nn.py @@ -31,6 +31,8 @@ def __init__(self, model_params: list[dict], device: str = "cpu"): layers.append(torch.nn.Linear(**layer)) elif layer_type == "tanh": layers.append(torch.nn.Tanh(**layer)) + elif layer_type == "relu": + layers.append(torch.nn.ReLU(**layer)) elif layer_type == "sigmoid": layers.append(torch.nn.Sigmoid(**layer)) elif layer_type == "softmax":