Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 56 additions & 114 deletions API/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
from growth_color_models import GrowthPredictor, ColorPredictor
try:
import torch
TORCH_OK = True

except ImportError:
TORCH_OK = False
print("[Torch] PyTorch not installed — growth/colour prediction unavailable.")

# Plant data
PLANTS_CSV_DATA = """name,pet_safe,space,water,sunlight,temperature,pollen_allergies,existing_plants
Expand Down Expand Up @@ -324,51 +332,35 @@ def _infer(pil_img):
}

# Growth / colour model classes (inline — no external file needed)
try:
import torch
import torch.nn as nn

class GrowthPredictor(nn.Module):
"""Simple MLP that predicts height growth (cm) from 9 numeric features."""
def __init__(self, input_size: int = 9):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_size, 128), nn.ReLU(),
nn.Linear(128, 64), nn.ReLU(),
nn.Linear(64, 32), nn.ReLU(),
nn.Linear(32, 1),
)
self.register_buffer("X_mean", torch.zeros(input_size))
self.register_buffer("X_std", torch.ones(input_size))

def forward(self, x: "torch.Tensor") -> "torch.Tensor":
return self.net(x).squeeze(-1)

class ColorPredictor(nn.Module):
"""Simple MLP that predicts plant colour class (5 classes) from 14 features."""
def __init__(self, input_size: int = 14, hidden_size: int = 32):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_size, hidden_size), nn.ReLU(),
nn.Linear(hidden_size, hidden_size), nn.ReLU(),
nn.Linear(hidden_size, 5),
)
self.register_buffer("X_mean", torch.zeros(input_size))
self.register_buffer("X_std", torch.ones(input_size))

def forward(self, x: "torch.Tensor") -> "torch.Tensor":
return self.net(x)

TORCH_OK = True
print("[Torch] GrowthPredictor and ColorPredictor defined successfully.")
class DataVector(BaseModel):
days_passed: float
avg_direct_light: float
avg_indirect_light: float
avg_nighttime: float
avg_temp: float
min_temp: float
max_temp: float
times_watered: float
initial_height: float
color_before: List[int]

except ImportError:
TORCH_OK = False
print("[Torch] PyTorch not installed — growth/colour prediction unavailable.")
# Loading grwoth/color models

growth_model = GrowthPredictor()
color_model = ColorPredictor()

color_checkpoint = torch.load("API/weights/health_model.pth", weights_only=True)
growth_checpoint = torch.load("API/weights/regression_model.pth", weights_only=True)
color_model.load_state_dict(color_checkpoint['model_state'])
growth_model.out.load_state_dict(growth_checpoint)

X_mean = color_checkpoint['X_mean']
X_std = color_checkpoint['X_std']


# Pydantic model – growth / colour prediction
class DataVector(BaseModel):
user_id: Optional[int] = None
days_passed: float
avg_direct_light: float
avg_indirect_light: float
Expand All @@ -380,45 +372,6 @@ class DataVector(BaseModel):
initial_height: float
color_before: List[int]

# Load / initialise growth models at startup
_growth_model = None
_color_model = None

GROWTH_MODEL_PATH = Path(__file__).resolve().parent.parent / "growth_predictor.pt"
COLOR_MODEL_PATH = Path(__file__).resolve().parent.parent / "color_predictor.pt"

def _load_torch_models():
global _growth_model, _color_model
if not TORCH_OK:
print("[Torch] Skipping model load — PyTorch not available.")
return
try:
import torch
gm = GrowthPredictor(9)
cm = ColorPredictor(14, 32)

# Load saved weights if checkpoint files exist
if GROWTH_MODEL_PATH.exists():
state = torch.load(str(GROWTH_MODEL_PATH), map_location="cpu")
gm.load_state_dict(state, strict=False)
print(f"[Torch] Loaded growth weights from {GROWTH_MODEL_PATH}")
else:
print("[Torch] No growth checkpoint found — using untrained weights (random predictions).")

if COLOR_MODEL_PATH.exists():
state = torch.load(str(COLOR_MODEL_PATH), map_location="cpu")
cm.load_state_dict(state, strict=False)
print(f"[Torch] Loaded colour weights from {COLOR_MODEL_PATH}")
else:
print("[Torch] No colour checkpoint found — using untrained weights (random predictions).")

gm.eval(); cm.eval()
_growth_model = gm
_color_model = cm
print("[Torch] Growth and colour models ready.")
except Exception as e:
print(f"[Torch] Model load failed: {e}")

# FastAPI app
app = FastAPI(title="Plant Care Unified API", version="2.0.0")

Expand All @@ -429,10 +382,6 @@ def _load_torch_models():
allow_headers=["*"],
)

@app.on_event("startup")
def _startup():
_load_torch_models()

@app.get("/", include_in_schema=False)
def root():
return RedirectResponse(url="/docs")
Expand All @@ -443,8 +392,8 @@ def health():
"success": True, "status": "ok", "version": "2.0.0",
"models": {
"detection": "trained" if CNN_MODEL_PATH.exists() else "fallback",
"growth": "loaded" if _growth_model else "unavailable",
"colour": "loaded" if _color_model else "unavailable",
"growth": "loaded" if growth_model else "unavailable",
"colour": "loaded" if color_model else "unavailable",
},
}

Expand Down Expand Up @@ -497,47 +446,40 @@ def detect_b64(body: DetectBase64Request):
return res
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/fwd")
# --- Prediction endpoints ---
@app.post("/growth")
async def predict_growth(vector: DataVector):
if not TORCH_OK:
raise HTTPException(status_code=503,
detail="PyTorch is not installed. Run: pip install torch")
if _growth_model is None or _color_model is None:
if growth_model is None or color_model is None:
raise HTTPException(status_code=503,
detail="Growth/colour models failed to initialise. Check server logs.")
import torch

numeric = [
vector.days_passed, vector.avg_direct_light, vector.avg_indirect_light,
vector.avg_nighttime, vector.avg_temp, vector.min_temp, vector.max_temp,
vector.times_watered, vector.initial_height,
]
color_oh = list(vector.color_before)
flat_vector = list(vector.model_dump().values())
flat_vector.pop()
inp = torch.tensor(flat_vector, dtype=torch.float32).unsqueeze(0)

inp_growth = torch.tensor(numeric, dtype=torch.float32).unsqueeze(0)
inp_color = torch.tensor(numeric + color_oh, dtype=torch.float32).unsqueeze(0)

try:
inp_growth_n = (inp_growth - _growth_model.X_mean) / (_growth_model.X_std + 1e-8)
except Exception:
inp_growth_n = inp_growth
try:
inp_color_n = (inp_color - _color_model.X_mean) / (_color_model.X_std + 1e-8)
except Exception:
inp_color_n = inp_color
inp_norm = (inp - X_mean) / (X_std)

inp_c = torch.tensor(flat_vector).unsqueeze(0)
inp_cb = torch.tensor(vector.color_before).unsqueeze(0)
inp_c_norm = (inp_c - X_mean) / (X_std)
inp_c_final = torch.cat([inp_c_norm, inp_cb], dim=1)
inp_norm = torch.cat([inp_norm, inp_cb], dim=1)

growth_model.eval()
color_model.eval()
with torch.no_grad():
growth_pred = float(_growth_model(inp_growth_n).item())
logits = _color_model(inp_color_n)
color_idx = int(torch.argmax(torch.softmax(logits, dim=1), dim=1).item())
pred = growth_model(inp_norm).item()

logits = color_model(inp_c_final)
probs = torch.softmax(logits, dim=1)
color = torch.argmax(probs, dim=1).item()

return {"guess" : pred, "color": color}

return {
"guess": round(growth_pred, 3),
"color": color_idx,
"inputs": {"numeric": numeric, "color_before": color_oh},
}

# Entry point
if __name__ == "__main__":
Expand Down
25 changes: 25 additions & 0 deletions API/growth_color_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class GrowthPredictor(nn.Module):
def __init__(self):
super().__init__()
self.out = nn.Linear(14, 1)

def forward(self, x):
return self.out(x)

class ColorPredictor(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(14, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 5)
)

def forward(self, x):
return self.net(x)
Binary file added API/weights/health_model.pth
Binary file not shown.
Binary file added API/weights/regression_model.pth
Binary file not shown.
31 changes: 14 additions & 17 deletions ui/pages/growth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class GrowthPage(BasePage):
FIELDS = [
("days_passed", "Days passed"),
("avg_direct_light", "Avg direct light (lux)"),
("avg_indirect_light", "Avg indirect light (lux)"),
("avg_direct_light", "Avg direct light (hrs)"),
("avg_indirect_light", "Avg indirect light (hrs)"),
("avg_nighttime", "Avg nighttime (hrs)"),
("avg_temp", "Avg temperature (C)"),
("min_temp", "Min temperature (C)"),
Expand Down Expand Up @@ -91,7 +91,7 @@ def on_color(e, v=c):

api_row = tk.Frame(left_inner, bg=BG_CARD); api_row.pack(fill="x", pady=(0,10))
tk.Label(api_row, text="API URL", font=self.f_small, bg=BG_CARD, fg=TEXT_SEC).pack(anchor="w", pady=(0,4))
self._api_url_var = tk.StringVar(value="http://localhost:8000/fwd")
self._api_url_var = tk.StringVar(value="http://localhost:5000/growth")
api_frame = tk.Frame(api_row, bg=BG_GLASS); api_frame.pack(fill="x")
tk.Frame(api_frame, bg=TEAL, width=3).pack(side="left", fill="y")
api_entry = tk.Entry(api_frame, textvariable=self._api_url_var,
Expand Down Expand Up @@ -145,17 +145,6 @@ def _animate(self):

def _run_prediction(self):
# Import the model lazily so the GUI works even before models/growth.py exists.
try:
from models.growth import predict_growth
except ModuleNotFoundError:
messagebox.showinfo(
"Model not available yet",
"The growth model (models/growth.py) hasn't been added yet.\n"
"The page will work as soon as it's in place.")
return
except Exception as ex:
messagebox.showerror("Model Error", f"Could not load the growth model:\n{ex}")
return

data = {}
for key, label in self.FIELDS:
Expand All @@ -167,13 +156,21 @@ def _run_prediction(self):
messagebox.showerror("Invalid Input", f"'{label}' must be a number."); return
color = self._color_var.get()
api_url = self._api_url_var.get().strip()
vector = [1 if c == color else 0 for c in self.COLORS]
payload = {**data, "color_before": vector}
if not api_url:
messagebox.showwarning("Missing API URL", "Please enter the growth API URL."); return
self._show_loading()
def _call():
try:
rep = predict_growth(data, color, api_url, user_id=1, timeout=10)
self.after(0, lambda: self._show_result(rep))
import requests as req_lib
response = req_lib.post(api_url, json=payload, timeout=10)
if response.status_code == 200:
rep = response.json()
self.after(0, lambda: self._show_result(rep))
else:
msg = f"Server returned status {response.status_code}.\n{response.text[:200]}"
self.after(0, lambda m=msg: self._show_error(m))
except Exception as ex:
self.after(0, lambda e=ex: self._show_error(str(e)))
threading.Thread(target=_call, daemon=True).start()
Expand Down Expand Up @@ -202,7 +199,7 @@ def _show_result(self, rep):
height_row.pack(fill="x", pady=(0,8))
tk.Label(height_row, text="Predicted height growth", font=self.f_small, bg=BG_GLASS, fg=TEXT_SEC).pack(anchor="w")
val_row = tk.Frame(height_row, bg=BG_GLASS); val_row.pack(anchor="w")
tk.Label(val_row, text=f"{guess}", font=("Segoe UI",30,"bold"), bg=BG_GLASS, fg=TEAL).pack(side="left")
tk.Label(val_row, text=f"{guess:.3f}", font=("Segoe UI",30,"bold"), bg=BG_GLASS, fg=TEAL).pack(side="left")
tk.Label(val_row, text=" cm", font=("Segoe UI",12), bg=BG_GLASS, fg=TEXT_SEC).pack(side="left", pady=(10,0))

color_row = tk.Frame(inner, bg=BG_GLASS, padx=16, pady=14)
Expand Down
Loading