From e34704411fd68d5e287dce97c6f90501249636a7 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Tue, 29 Apr 2025 13:21:30 -0700 Subject: [PATCH] Added relu to nn prescriptor --- presp/prescriptor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/presp/prescriptor/nn.py b/presp/prescriptor/nn.py index 721cf0a..a811564 100644 --- a/presp/prescriptor/nn.py +++ b/presp/prescriptor/nn.py @@ -31,6 +31,8 @@ def __init__(self, model_params: list[dict], device: str = "cpu"): layers.append(torch.nn.Linear(**layer)) elif layer_type == "tanh": layers.append(torch.nn.Tanh(**layer)) + elif layer_type == "relu": + layers.append(torch.nn.ReLU(**layer)) elif layer_type == "sigmoid": layers.append(torch.nn.Sigmoid(**layer)) elif layer_type == "softmax":