diff --git a/API/api.py b/API/api.py index 9f97dd4..bc3af9b 100644 --- a/API/api.py +++ b/API/api.py @@ -9,6 +9,14 @@ from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import uvicorn +from growth_color_models import GrowthPredictor, ColorPredictor +try: + import torch + TORCH_OK = True + +except ImportError: + TORCH_OK = False + print("[Torch] PyTorch not installed — growth/colour prediction unavailable.") # Plant data PLANTS_CSV_DATA = """name,pet_safe,space,water,sunlight,temperature,pollen_allergies,existing_plants @@ -324,51 +332,35 @@ def _infer(pil_img): } # Growth / colour model classes (inline — no external file needed) -try: - import torch - import torch.nn as nn - - class GrowthPredictor(nn.Module): - """Simple MLP that predicts height growth (cm) from 9 numeric features.""" - def __init__(self, input_size: int = 9): - super().__init__() - self.net = nn.Sequential( - nn.Linear(input_size, 128), nn.ReLU(), - nn.Linear(128, 64), nn.ReLU(), - nn.Linear(64, 32), nn.ReLU(), - nn.Linear(32, 1), - ) - self.register_buffer("X_mean", torch.zeros(input_size)) - self.register_buffer("X_std", torch.ones(input_size)) - - def forward(self, x: "torch.Tensor") -> "torch.Tensor": - return self.net(x).squeeze(-1) - - class ColorPredictor(nn.Module): - """Simple MLP that predicts plant colour class (5 classes) from 14 features.""" - def __init__(self, input_size: int = 14, hidden_size: int = 32): - super().__init__() - self.net = nn.Sequential( - nn.Linear(input_size, hidden_size), nn.ReLU(), - nn.Linear(hidden_size, hidden_size), nn.ReLU(), - nn.Linear(hidden_size, 5), - ) - self.register_buffer("X_mean", torch.zeros(input_size)) - self.register_buffer("X_std", torch.ones(input_size)) - - def forward(self, x: "torch.Tensor") -> "torch.Tensor": - return self.net(x) - TORCH_OK = True - print("[Torch] GrowthPredictor and ColorPredictor defined successfully.") +class DataVector(BaseModel): + days_passed: float + avg_direct_light: float + avg_indirect_light: float + avg_nighttime: float + avg_temp: float + min_temp: float + max_temp: float + times_watered: float + initial_height: float + color_before: List[int] -except ImportError: - TORCH_OK = False - print("[Torch] PyTorch not installed — growth/colour prediction unavailable.") +# Loading grwoth/color models +growth_model = GrowthPredictor() +color_model = ColorPredictor() + +color_checkpoint = torch.load("API/weights/health_model.pth", weights_only=True) +growth_checpoint = torch.load("API/weights/regression_model.pth", weights_only=True) +color_model.load_state_dict(color_checkpoint['model_state']) +growth_model.out.load_state_dict(growth_checpoint) + +X_mean = color_checkpoint['X_mean'] +X_std = color_checkpoint['X_std'] + + # Pydantic model – growth / colour prediction class DataVector(BaseModel): - user_id: Optional[int] = None days_passed: float avg_direct_light: float avg_indirect_light: float @@ -380,45 +372,6 @@ class DataVector(BaseModel): initial_height: float color_before: List[int] -# Load / initialise growth models at startup -_growth_model = None -_color_model = None - -GROWTH_MODEL_PATH = Path(__file__).resolve().parent.parent / "growth_predictor.pt" -COLOR_MODEL_PATH = Path(__file__).resolve().parent.parent / "color_predictor.pt" - -def _load_torch_models(): - global _growth_model, _color_model - if not TORCH_OK: - print("[Torch] Skipping model load — PyTorch not available.") - return - try: - import torch - gm = GrowthPredictor(9) - cm = ColorPredictor(14, 32) - - # Load saved weights if checkpoint files exist - if GROWTH_MODEL_PATH.exists(): - state = torch.load(str(GROWTH_MODEL_PATH), map_location="cpu") - gm.load_state_dict(state, strict=False) - print(f"[Torch] Loaded growth weights from {GROWTH_MODEL_PATH}") - else: - print("[Torch] No growth checkpoint found — using untrained weights (random predictions).") - - if COLOR_MODEL_PATH.exists(): - state = torch.load(str(COLOR_MODEL_PATH), map_location="cpu") - cm.load_state_dict(state, strict=False) - print(f"[Torch] Loaded colour weights from {COLOR_MODEL_PATH}") - else: - print("[Torch] No colour checkpoint found — using untrained weights (random predictions).") - - gm.eval(); cm.eval() - _growth_model = gm - _color_model = cm - print("[Torch] Growth and colour models ready.") - except Exception as e: - print(f"[Torch] Model load failed: {e}") - # FastAPI app app = FastAPI(title="Plant Care Unified API", version="2.0.0") @@ -429,10 +382,6 @@ def _load_torch_models(): allow_headers=["*"], ) -@app.on_event("startup") -def _startup(): - _load_torch_models() - @app.get("/", include_in_schema=False) def root(): return RedirectResponse(url="/docs") @@ -443,8 +392,8 @@ def health(): "success": True, "status": "ok", "version": "2.0.0", "models": { "detection": "trained" if CNN_MODEL_PATH.exists() else "fallback", - "growth": "loaded" if _growth_model else "unavailable", - "colour": "loaded" if _color_model else "unavailable", + "growth": "loaded" if growth_model else "unavailable", + "colour": "loaded" if color_model else "unavailable", }, } @@ -497,47 +446,40 @@ def detect_b64(body: DetectBase64Request): return res except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/fwd") + +# --- Prediction endpoints --- @app.post("/growth") async def predict_growth(vector: DataVector): if not TORCH_OK: raise HTTPException(status_code=503, detail="PyTorch is not installed. Run: pip install torch") - if _growth_model is None or _color_model is None: + if growth_model is None or color_model is None: raise HTTPException(status_code=503, detail="Growth/colour models failed to initialise. Check server logs.") - import torch - numeric = [ - vector.days_passed, vector.avg_direct_light, vector.avg_indirect_light, - vector.avg_nighttime, vector.avg_temp, vector.min_temp, vector.max_temp, - vector.times_watered, vector.initial_height, - ] - color_oh = list(vector.color_before) + flat_vector = list(vector.model_dump().values()) + flat_vector.pop() + inp = torch.tensor(flat_vector, dtype=torch.float32).unsqueeze(0) - inp_growth = torch.tensor(numeric, dtype=torch.float32).unsqueeze(0) - inp_color = torch.tensor(numeric + color_oh, dtype=torch.float32).unsqueeze(0) - - try: - inp_growth_n = (inp_growth - _growth_model.X_mean) / (_growth_model.X_std + 1e-8) - except Exception: - inp_growth_n = inp_growth - try: - inp_color_n = (inp_color - _color_model.X_mean) / (_color_model.X_std + 1e-8) - except Exception: - inp_color_n = inp_color + inp_norm = (inp - X_mean) / (X_std) + + inp_c = torch.tensor(flat_vector).unsqueeze(0) + inp_cb = torch.tensor(vector.color_before).unsqueeze(0) + inp_c_norm = (inp_c - X_mean) / (X_std) + inp_c_final = torch.cat([inp_c_norm, inp_cb], dim=1) + inp_norm = torch.cat([inp_norm, inp_cb], dim=1) + growth_model.eval() + color_model.eval() with torch.no_grad(): - growth_pred = float(_growth_model(inp_growth_n).item()) - logits = _color_model(inp_color_n) - color_idx = int(torch.argmax(torch.softmax(logits, dim=1), dim=1).item()) + pred = growth_model(inp_norm).item() + + logits = color_model(inp_c_final) + probs = torch.softmax(logits, dim=1) + color = torch.argmax(probs, dim=1).item() + + return {"guess" : pred, "color": color} - return { - "guess": round(growth_pred, 3), - "color": color_idx, - "inputs": {"numeric": numeric, "color_before": color_oh}, - } # Entry point if __name__ == "__main__": diff --git a/API/growth_color_models.py b/API/growth_color_models.py new file mode 100644 index 0000000..5a5510e --- /dev/null +++ b/API/growth_color_models.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class GrowthPredictor(nn.Module): + def __init__(self): + super().__init__() + self.out = nn.Linear(14, 1) + + def forward(self, x): + return self.out(x) + +class ColorPredictor(nn.Module): + def __init__(self): + super().__init__() + self.net = nn.Sequential( + nn.Linear(14, 64), + nn.ReLU(), + nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 5) + ) + + def forward(self, x): + return self.net(x) diff --git a/API/weights/health_model.pth b/API/weights/health_model.pth new file mode 100644 index 0000000..7e42abe Binary files /dev/null and b/API/weights/health_model.pth differ diff --git a/API/weights/regression_model.pth b/API/weights/regression_model.pth new file mode 100644 index 0000000..80311ca Binary files /dev/null and b/API/weights/regression_model.pth differ diff --git a/ui/pages/growth.py b/ui/pages/growth.py index e74ffc5..069bbbf 100644 --- a/ui/pages/growth.py +++ b/ui/pages/growth.py @@ -17,8 +17,8 @@ class GrowthPage(BasePage): FIELDS = [ ("days_passed", "Days passed"), - ("avg_direct_light", "Avg direct light (lux)"), - ("avg_indirect_light", "Avg indirect light (lux)"), + ("avg_direct_light", "Avg direct light (hrs)"), + ("avg_indirect_light", "Avg indirect light (hrs)"), ("avg_nighttime", "Avg nighttime (hrs)"), ("avg_temp", "Avg temperature (C)"), ("min_temp", "Min temperature (C)"), @@ -91,7 +91,7 @@ def on_color(e, v=c): api_row = tk.Frame(left_inner, bg=BG_CARD); api_row.pack(fill="x", pady=(0,10)) tk.Label(api_row, text="API URL", font=self.f_small, bg=BG_CARD, fg=TEXT_SEC).pack(anchor="w", pady=(0,4)) - self._api_url_var = tk.StringVar(value="http://localhost:8000/fwd") + self._api_url_var = tk.StringVar(value="http://localhost:5000/growth") api_frame = tk.Frame(api_row, bg=BG_GLASS); api_frame.pack(fill="x") tk.Frame(api_frame, bg=TEAL, width=3).pack(side="left", fill="y") api_entry = tk.Entry(api_frame, textvariable=self._api_url_var, @@ -145,17 +145,6 @@ def _animate(self): def _run_prediction(self): # Import the model lazily so the GUI works even before models/growth.py exists. - try: - from models.growth import predict_growth - except ModuleNotFoundError: - messagebox.showinfo( - "Model not available yet", - "The growth model (models/growth.py) hasn't been added yet.\n" - "The page will work as soon as it's in place.") - return - except Exception as ex: - messagebox.showerror("Model Error", f"Could not load the growth model:\n{ex}") - return data = {} for key, label in self.FIELDS: @@ -167,13 +156,21 @@ def _run_prediction(self): messagebox.showerror("Invalid Input", f"'{label}' must be a number."); return color = self._color_var.get() api_url = self._api_url_var.get().strip() + vector = [1 if c == color else 0 for c in self.COLORS] + payload = {**data, "color_before": vector} if not api_url: messagebox.showwarning("Missing API URL", "Please enter the growth API URL."); return self._show_loading() def _call(): try: - rep = predict_growth(data, color, api_url, user_id=1, timeout=10) - self.after(0, lambda: self._show_result(rep)) + import requests as req_lib + response = req_lib.post(api_url, json=payload, timeout=10) + if response.status_code == 200: + rep = response.json() + self.after(0, lambda: self._show_result(rep)) + else: + msg = f"Server returned status {response.status_code}.\n{response.text[:200]}" + self.after(0, lambda m=msg: self._show_error(m)) except Exception as ex: self.after(0, lambda e=ex: self._show_error(str(e))) threading.Thread(target=_call, daemon=True).start() @@ -202,7 +199,7 @@ def _show_result(self, rep): height_row.pack(fill="x", pady=(0,8)) tk.Label(height_row, text="Predicted height growth", font=self.f_small, bg=BG_GLASS, fg=TEXT_SEC).pack(anchor="w") val_row = tk.Frame(height_row, bg=BG_GLASS); val_row.pack(anchor="w") - tk.Label(val_row, text=f"{guess}", font=("Segoe UI",30,"bold"), bg=BG_GLASS, fg=TEAL).pack(side="left") + tk.Label(val_row, text=f"{guess:.3f}", font=("Segoe UI",30,"bold"), bg=BG_GLASS, fg=TEAL).pack(side="left") tk.Label(val_row, text=" cm", font=("Segoe UI",12), bg=BG_GLASS, fg=TEXT_SEC).pack(side="left", pady=(10,0)) color_row = tk.Frame(inner, bg=BG_GLASS, padx=16, pady=14)