diff --git a/.gitignore b/.gitignore index 9f098a5..5d22dd8 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ EOL # OS .DS_Store Thumbs.db + +get_tree.sh diff --git a/backend/app/config.py b/backend/app/config.py index c04fcfd..44328e2 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -9,7 +9,13 @@ class Settings(BaseSettings): # ML Model Paths MODEL_CLASSIFIER_PATH: str = "app/ml_models/classifier_model.pkl" - MODEL_REGRESSOR_PATH: str = "app/ml_models/regressor_model.pkl" + + # Regression/Ensemble Model Paths (commented out for classification-only mode) + # MODEL_REGRESSOR_PATH: str = "app/ml_models/regressor_model.pkl" # Legacy - kept for compatibility + # MODEL_XGBOOST_REGRESSOR_PATH: str = "app/ml_models/xgboost_regressor.pkl" + # MODEL_ATTENTIVEFP_PATH: str = "app/ml_models/attentivefp_regressor.pt" + # MODEL_DIMENET_PATH: str = "app/ml_models/dimenet_regressor.pt" + # MODEL_BLENDER_PATH: str = "app/ml_models/blender_model.pkl" # Feature Extraction Settings FEATURE_COUNT: int = 4200 diff --git a/backend/app/models.py b/backend/app/models.py index af9d7a4..1af20aa 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -22,7 +22,10 @@ class PredictionResult(BaseModel): smiles: str = Field(..., description="Input SMILES string") prediction: float = Field(..., description="Predicted permeability value") confidence: float = Field(..., description="Model confidence score") + uncertainty: Optional[float] = Field(None, description="Prediction uncertainty from ensemble variance") + ensemble_std: Optional[float] = Field(None, description="Standard deviation of ensemble predictions") classifier_prediction: int = Field(..., description="Binary classifier prediction (0 or 1)") + ensemble_predictions: Optional[List[float]] = Field(None, description="Individual ensemble model predictions") features: Optional[PredictionFeatures] = Field(None, description="Extracted molecular features") error: Optional[str] = Field(None, description="Error message if prediction failed") diff --git a/backend/app/utils/processing.py b/backend/app/utils/processing.py index 52df67a..a40c474 100644 --- a/backend/app/utils/processing.py +++ b/backend/app/utils/processing.py @@ -55,30 +55,46 @@ def combine_features(features: Dict[str, Any]) -> np.ndarray: return combined.reshape(1, -1) -async def process_batch(smiles_batch: List[str], session) -> List[PredictionResponse]: - """Process a batch of SMILES strings.""" - results = [] - features_list = [] - valid_smiles = [] +def validate_feature_dimensions(features: Dict[str, Any]) -> bool: + """Validate that extracted features match expected dimensions.""" + try: + # Check Morgan fingerprint dimensions + if len(features['morgan_fingerprint']) != settings.FEATURE_COUNT: + return False + + # Check descriptor keys + expected_descriptors = { + 'MolWt', 'LogP', 'TPSA', 'NumHDonors', + 'NumHAcceptors', 'NumRotatableBonds', 'NumAromaticRings' + } + if not expected_descriptors.issubset(features['descriptors'].keys()): + return False + + return True + except (KeyError, TypeError): + return False - for smiles in smiles_batch: - try: - # Get features as int8 for memory efficiency - features = smiles_to_features(smiles) - features_list.append(features) - valid_smiles.append(smiles) - except ValueError as e: - error_msg = f"Invalid SMILES structure: {str(e)}. Please check for correct syntax and valid atoms." - results.append(PredictionResponse(smiles=smiles, prediction=0.0, probability=0.0, error=error_msg)) - if features_list: - # Stack the features and convert to float32 just before prediction - features_array = np.vstack(features_list).astype(np.float32) - input_name = session.get_inputs()[0].name - predictions = session.run(None, {input_name: features_array})[0] - - for smiles, pred in zip(valid_smiles, predictions): - prob = float(pred[1]) - results.append(PredictionResponse(smiles=smiles, prediction=1 if prob >= 0.5 else 0, probability=prob)) - - return results +def normalize_features(feature_vector: np.ndarray) -> np.ndarray: + """Apply z-score normalization to continuous features, leave binary fingerprints unchanged.""" + # First 4200 features are Morgan fingerprints (binary, leave unchanged) + fingerprint_features = feature_vector[:, :settings.FEATURE_COUNT] + + # Remaining features are continuous descriptors (apply z-score normalization) + if feature_vector.shape[1] > settings.FEATURE_COUNT: + descriptor_features = feature_vector[:, settings.FEATURE_COUNT:] + + # Simple z-score normalization (mean=0, std=1) + descriptor_mean = np.mean(descriptor_features, axis=0) + descriptor_std = np.std(descriptor_features, axis=0) + descriptor_std = np.where(descriptor_std == 0, 1, descriptor_std) # Avoid division by zero + + normalized_descriptors = (descriptor_features - descriptor_mean) / descriptor_std + + # Combine normalized descriptors with unchanged fingerprints + normalized_features = np.concatenate([fingerprint_features, normalized_descriptors], axis=1) + else: + # Only fingerprint features + normalized_features = fingerprint_features + + return normalized_features diff --git a/backend/app/worker.py b/backend/app/worker.py index 00180e1..0fa3568 100644 --- a/backend/app/worker.py +++ b/backend/app/worker.py @@ -1,13 +1,10 @@ from celery import Celery -from celery.signals import worker_process_init -from typing import List, Dict, Any, Union +from typing import List, Dict, Any, Optional import logging import pickle -import joblib import numpy as np import os -from datetime import datetime -from pathlib import Path +import torch from app.config import settings from app.utils.logger import setup_logging @@ -31,120 +28,143 @@ enable_utc=True, ) -# Global model variables - initialized by worker_process_init +# Load ML models at startup classifier_model = None -regressor_models = None # Will hold ensemble of models -def _load_model(model_path: Union[str, Path]): - """Load a model from file, trying different formats.""" - model_path = Path(model_path) - try: - # Try joblib first (scikit-learn models) - return joblib.load(model_path) - except Exception: - try: - # Try pickle as fallback - with open(model_path, 'rb') as f: - return pickle.load(f) - except Exception as e: - raise RuntimeError(f"Could not load model from {model_path}: {str(e)}") - -def _load_ensemble_regressors(): - """ - Load ensemble of regressor models. - Currently loads single model but structured for future ensemble expansion. - """ - # TODO: In future, load multiple models (XGBoost, AttentiveFP, DimeNet++) - # For now, load the single regressor model - base_regressor = _load_model(settings.MODEL_REGRESSOR_PATH) - - # Structure as list for future ensemble expansion - return { - 'models': [base_regressor], - 'model_names': ['base_regressor'], - 'weights': [1.0] # Equal weighting for future ensemble - } - -@worker_process_init.connect -def init_worker(**kwargs): - """ - Initialize models when worker process starts. - This is called once per worker process and is more robust than checking globals. - """ - global classifier_model, regressor_models +def load_models(): + """Load the classification model (simplified for classification-only mode).""" + global classifier_model try: - logger.info("Initializing worker process - loading ML models...") - # Load binary classifier - classifier_model = _load_model(settings.MODEL_CLASSIFIER_PATH) - logger.info(f"Classifier model loaded: {settings.MODEL_CLASSIFIER_PATH}") - - # Load ensemble regressors - regressor_models = _load_ensemble_regressors() - logger.info(f"Regressor ensemble loaded with {len(regressor_models['models'])} model(s)") - - logger.info("Worker process initialization complete") - + classifier_path = settings.MODEL_CLASSIFIER_PATH + if os.path.exists(classifier_path): + with open(classifier_path, 'rb') as f: + classifier_model = pickle.load(f) + logger.info("Classifier model loaded successfully") + else: + logger.warning(f"Classifier model not found at {classifier_path}") + + if classifier_model is None: + logger.error("Classifier model could not be loaded - check model file path") + except Exception as e: - logger.error(f"Failed to initialize worker process: {e}") + logger.error(f"Failed to load classifier model: {e}") raise -def _predict_with_ensemble(feature_vector: np.ndarray) -> Dict[str, float]: - """ - Predict using ensemble of regressor models and calculate confidence interval. - Currently uses single model but structured for future ensemble expansion. - """ - global regressor_models - - predictions = [] - - # Get predictions from all models in ensemble - for i, model in enumerate(regressor_models['models']): - weight = regressor_models['weights'][i] - pred = model.predict(feature_vector)[0] - predictions.append(pred * weight) - - # Calculate ensemble statistics - ensemble_prediction = np.mean(predictions) - ensemble_std = np.std(predictions) if len(predictions) > 1 else 0.0 - - # Calculate confidence based on ensemble variance - # Lower variance = higher confidence (inverse relationship) - # TODO: Calibrate this confidence calculation based on validation data - confidence_from_variance = max(0.1, 1.0 - min(ensemble_std / ensemble_prediction, 0.9)) if ensemble_prediction > 0 else 0.1 - +# COMMENTED OUT - Regression/Ensemble functionality (for future use) +# ensemble_regressors = {} +# blender_model = None + +# def load_ensemble_models(): +# """Load all models for the two-step ensemble prediction pipeline.""" +# global ensemble_regressors, blender_model +# +# try: +# # Load ensemble regressors +# regressor_paths = { +# 'xgboost': os.path.join(os.path.dirname(settings.MODEL_CLASSIFIER_PATH), 'xgboost_regressor.pkl'), +# 'attentivefp': os.path.join(os.path.dirname(settings.MODEL_CLASSIFIER_PATH), 'attentivefp_regressor.pt'), +# 'dimenet': os.path.join(os.path.dirname(settings.MODEL_CLASSIFIER_PATH), 'dimenet_regressor.pt') +# } +# +# for name, path in regressor_paths.items(): +# if os.path.exists(path): +# if path.endswith('.pkl'): +# with open(path, 'rb') as f: +# ensemble_regressors[name] = pickle.load(f) +# elif path.endswith('.pt'): +# # For PyTorch models, we'll need the model architecture loaded separately +# # For now, just log that we found the file +# logger.info(f"Found {name} model at {path} (PyTorch loading not implemented yet)") +# logger.info(f"Regressor {name} loaded successfully") +# else: +# logger.warning(f"Regressor {name} not found at {path}") +# +# # Load blender model +# blender_path = os.path.join(os.path.dirname(settings.MODEL_CLASSIFIER_PATH), 'blender_model.pkl') +# if os.path.exists(blender_path): +# with open(blender_path, 'rb') as f: +# blender_model = pickle.load(f) +# logger.info("Blender model loaded successfully") +# else: +# logger.warning(f"Blender model not found at {blender_path}") +# +# except Exception as e: +# logger.error(f"Failed to load ensemble models: {e}") +# raise + + +# COMMENTED OUT - Ensemble prediction functions (for future use) +# def get_ensemble_predictions(feature_vector: np.ndarray) -> List[float]: +# """Get predictions from all available ensemble regressors.""" +# predictions = [] +# +# # XGBoost regressor +# if 'xgboost' in ensemble_regressors: +# try: +# xgb_pred = ensemble_regressors['xgboost'].predict(feature_vector)[0] +# predictions.append(float(xgb_pred)) +# except Exception as e: +# logger.warning(f"XGBoost regressor failed: {e}") +# +# # PyTorch models (AttentiveFP, DimeNet++) - placeholder for now +# # TODO: Implement when PyTorch model architectures are available +# for model_name in ['attentivefp', 'dimenet']: +# if model_name in ensemble_regressors: +# logger.warning(f"{model_name} prediction not yet implemented") +# +# return predictions + + +# def calculate_confidence_interval(predictions: List[float], classifier_confidence: float) -> Dict[str, float]: +# """Calculate calibrated confidence interval from ensemble variance.""" +# if len(predictions) == 0: +# return {'confidence': 0.0, 'uncertainty': 1.0, 'ensemble_std': 0.0} +# +# if len(predictions) == 1: +# # Single model - use classifier confidence +# return { +# 'confidence': classifier_confidence, +# 'uncertainty': 1.0 - classifier_confidence, +# 'ensemble_std': 0.0 +# } +# +# # Multiple models - calculate ensemble statistics +# ensemble_mean = np.mean(predictions) +# ensemble_std = np.std(predictions) +# +# # Combine classifier confidence with ensemble uncertainty +# # Higher std = lower confidence +# ensemble_confidence = classifier_confidence * np.exp(-ensemble_std) +# +# return { +# 'confidence': float(ensemble_confidence), +# 'uncertainty': float(ensemble_std), +# 'ensemble_std': float(ensemble_std) +# } + + +def calculate_classification_confidence(classifier_proba: np.ndarray) -> Dict[str, float]: + """Calculate confidence metrics for classification predictions.""" + max_proba = float(np.max(classifier_proba)) return { - 'prediction': float(ensemble_prediction), - 'confidence_from_ensemble': float(confidence_from_variance), - 'ensemble_std': float(ensemble_std), - 'individual_predictions': [float(p) for p in predictions] + 'confidence': max_proba, + 'uncertainty': 1.0 - max_proba, + 'class_probabilities': classifier_proba.tolist() } -def _ensure_feature_consistency(features: Dict[str, Any]) -> np.ndarray: - """ - Ensure deterministic feature vector ordering for model input. - Combines molecular descriptors and Morgan fingerprint in consistent order. - """ - # Get molecular descriptors in sorted key order for consistency - descriptor_values = [features['descriptors'][key] for key in sorted(features['descriptors'].keys())] - - # Combine with Morgan fingerprint - feature_vector = np.array(descriptor_values + features['morgan_fingerprint']) - - # Reshape for model input (models expect 2D array) - return feature_vector.reshape(1, -1) @celery_app.task(bind=True, name="predict_permeability") -def predict_permeability(self, smiles_list: List[str], created_at: str = None, job_name: str = None) -> Dict[str, Any]: +def predict_permeability(self, smiles_list: List[str]) -> Dict[str, Any]: """ - Predict permeability for a list of SMILES strings using the two-stage pipeline. + Predict permeability for a list of SMILES strings using classification model only. """ try: - # Models should be loaded by worker_process_init signal - if classifier_model is None or regressor_models is None: - raise RuntimeError("Models not initialized. Worker process initialization may have failed.") + # Ensure models are loaded + if classifier_model is None: + load_models() results = [] @@ -152,43 +172,30 @@ def predict_permeability(self, smiles_list: List[str], created_at: str = None, j try: # Extract features using processing.py logic features = smiles_to_comprehensive_features(smiles) - feature_vector = _ensure_feature_consistency(features) - - # Stage 1: Binary classification (near-zero vs non-zero accumulation) - classifier_pred = classifier_model.predict(feature_vector)[0] - classifier_prob = classifier_model.predict_proba(feature_vector)[0] + feature_vector = combine_features(features) - if classifier_pred == 0: # Near-zero accumulation - prediction = 0.0 - confidence = float(classifier_prob[0]) # Confidence in "near-zero" prediction - ensemble_info = None + # Classification prediction + if classifier_model is not None: + classifier_pred = classifier_model.predict(feature_vector)[0] + classifier_prob = classifier_model.predict_proba(feature_vector)[0] + confidence_stats = calculate_classification_confidence(classifier_prob) else: - # Stage 2: Ensemble regression for specific permeability level - ensemble_result = _predict_with_ensemble(feature_vector) - prediction = ensemble_result['prediction'] - - # Combine classifier confidence with ensemble confidence - classifier_confidence = float(classifier_prob[1]) # Confidence in "non-zero" prediction - ensemble_confidence = ensemble_result['confidence_from_ensemble'] - - # Overall confidence is combination of both stages - # TODO: Calibrate this combination based on validation data - confidence = (classifier_confidence + ensemble_confidence) / 2.0 - - ensemble_info = { - 'ensemble_std': ensemble_result['ensemble_std'], - 'individual_predictions': ensemble_result['individual_predictions'], - 'classifier_confidence': classifier_confidence, - 'ensemble_confidence': ensemble_confidence - } + # Fallback if no classifier + classifier_pred = 0 + confidence_stats = {'confidence': 0.0, 'uncertainty': 1.0, 'class_probabilities': [0.5, 0.5]} + logger.warning("No classifier model available - defaulting to non-permeant") + + # Convert classification to binary prediction + prediction = 1 if classifier_pred == 1 else 0 # 1 = permeant, 0 = non-permeant result = { 'smiles': smiles, 'prediction': prediction, - 'confidence': confidence, + 'confidence': confidence_stats['confidence'], + 'uncertainty': confidence_stats['uncertainty'], + 'class_probabilities': confidence_stats['class_probabilities'], 'classifier_prediction': int(classifier_pred), 'features': features, - 'ensemble_info': ensemble_info, # Additional debugging info 'error': None } @@ -196,28 +203,23 @@ def predict_permeability(self, smiles_list: List[str], created_at: str = None, j logger.error(f"Error processing SMILES {smiles}: {e}") result = { 'smiles': smiles, - 'prediction': 0.0, + 'prediction': 0, 'confidence': 0.0, + 'uncertainty': 1.0, + 'class_probabilities': [0.5, 0.5], 'classifier_prediction': 0, 'features': None, - 'ensemble_info': None, 'error': str(e) } results.append(result) - # Capture completion timestamp - completed_at = datetime.now().isoformat() - return { 'status': 'completed', 'results': results, 'total_processed': len(results), 'successful': len([r for r in results if r['error'] is None]), - 'failed': len([r for r in results if r['error'] is not None]), - 'created_at': created_at or datetime.now().isoformat(), - 'completed_at': completed_at, - 'job_name': job_name + 'failed': len([r for r in results if r['error'] is not None]) } except Exception as e: @@ -229,5 +231,8 @@ def predict_permeability(self, smiles_list: List[str], created_at: str = None, j 'results': [] } -# Models are now initialized by @worker_process_init signal -# This ensures proper initialization per worker process \ No newline at end of file +# Initialize models when worker starts +try: + load_models() +except Exception as e: + logger.warning(f"Models not loaded at startup: {e}") \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt index 69d1c21..05c13a9 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -7,6 +7,7 @@ python-dotenv numpy rdkit scikit-learn +xgboost joblib pydantic pydantic-settings diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index c841ea6..1011e21 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -1,54 +1,7 @@ import pytest -from fastapi.testclient import TestClient -from unittest.mock import patch, MagicMock, AsyncMock import numpy as np -import uuid -from datetime import datetime -from app.utils.processing import smiles_to_features - -# Mock model validation and Celery tasks before importing app -mock_models = { - "classifier": MagicMock(), - "regressor": MagicMock() -} -mock_models["classifier"].predict.return_value = np.array([1]) -mock_models["classifier"].predict_proba.return_value = np.array([[0.3, 0.7]]) -mock_models["regressor"].predict.return_value = np.array([0.85]) - -mock_celery_task = MagicMock() -mock_celery_task.id = str(uuid.uuid4()) -mock_celery_task.state = 'SUCCESS' -mock_celery_task.result = { - 'results': [{ - 'smiles': 'CC(=O)OC1=CC=CC=C1C(=O)O', - 'prediction': 0.85, - 'confidence': 0.7, - 'classifier_prediction': 1, - 'features': { - 'morgan_fingerprint': [1, 0, 1, 0] * 1050, # 4200 features - 'descriptors': { - 'MolWt': 180.16, - 'LogP': 1.19, - 'TPSA': 63.6, - 'NumHDonors': 1, - 'NumHAcceptors': 4, - 'NumRotatableBonds': 3, - 'NumAromaticRings': 1 - } - } - }], - 'total_processed': 1, - 'successful': 1, - 'failed': 0 -} - -with patch("app.utils.validation.validate_models", return_value=mock_models), \ - patch("app.worker.celery_app.send_task", return_value=mock_celery_task), \ - patch("app.worker.celery_app.AsyncResult", return_value=mock_celery_task): - from app.main import app - -client = TestClient(app) +from app.utils.processing import smiles_to_features, smiles_to_comprehensive_features, combine_features # Test data VALID_SMILES = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin @@ -69,211 +22,35 @@ def test_smiles_to_features_invalid(): smiles_to_features(INVALID_SMILES) -def test_submit_prediction_job_single(): - """Test GraphQL mutation to submit a single prediction job.""" - query = """ - mutation SubmitPredictionJob($input: PredictionJobInput!) { - submitPredictionJob(jobInput: $input) { - jobId - status - createdAt - progress - error - } - } - """ - variables = { - "input": { - "smilesList": [VALID_SMILES], - "jobName": "test_single_prediction" - } - } - - response = client.post("/graphql", json={"query": query, "variables": variables}) - assert response.status_code == 200 - - data = response.json() - assert "data" in data - assert "submitPredictionJob" in data["data"] - - job_data = data["data"]["submitPredictionJob"] - assert job_data["jobId"] is not None - assert job_data["status"] == "submitted" - assert job_data["error"] is None - - -def test_submit_prediction_job_batch(): - """Test GraphQL mutation to submit a batch prediction job.""" - query = """ - mutation SubmitPredictionJob($input: PredictionJobInput!) { - submitPredictionJob(jobInput: $input) { - jobId - status - createdAt - progress - error - } - } - """ - variables = { - "input": { - "smilesList": [VALID_SMILES, "CCO", "CCC"], # Multiple SMILES - "jobName": "test_batch_prediction" - } - } - - response = client.post("/graphql", json={"query": query, "variables": variables}) - assert response.status_code == 200 - - data = response.json() - job_data = data["data"]["submitPredictionJob"] - assert job_data["jobId"] is not None - assert job_data["status"] == "submitted" - assert "3 compounds" in job_data["progress"] - - -def test_submit_prediction_job_empty_list(): - """Test GraphQL mutation with empty SMILES list should return error.""" - query = """ - mutation SubmitPredictionJob($input: PredictionJobInput!) { - submitPredictionJob(jobInput: $input) { - jobId - status - error - } - } - """ - variables = { - "input": { - "smilesList": [], - "jobName": "test_empty" - } - } - - response = client.post("/graphql", json={"query": query, "variables": variables}) - assert response.status_code == 200 - - data = response.json() - job_data = data["data"]["submitPredictionJob"] - assert job_data["status"] == "error" - assert "cannot be empty" in job_data["error"] - - -def test_get_prediction_result(): - """Test GraphQL query to get prediction results.""" - query = """ - query GetPredictionResult($jobId: String!) { - getPredictionResult(jobId: $jobId) { - ... on JobResult { - status - results { - smiles - prediction - confidence - classifierPrediction - features { - morganFingerprint - descriptors { - molWt - logP - tpsa - numHDonors - numHAcceptors - numRotatableBonds - numAromaticRings - } - } - error - } - totalProcessed - successful - failed - jobId - createdAt - completedAt - } - ... on JobStatus { - jobId - status - createdAt - progress - error - } - } - } - """ - variables = {"jobId": mock_celery_task.id} +def test_smiles_to_comprehensive_features(): + """Test comprehensive feature extraction.""" + features = smiles_to_comprehensive_features(VALID_SMILES) - response = client.post("/graphql", json={"query": query, "variables": variables}) - assert response.status_code == 200 + # Check structure + assert 'morgan_fingerprint' in features + assert 'descriptors' in features - data = response.json() - assert "data" in data - result_data = data["data"]["getPredictionResult"] + # Check Morgan fingerprint + assert len(features['morgan_fingerprint']) == 4200 + assert all(isinstance(x, int) for x in features['morgan_fingerprint']) - assert result_data["status"] == "completed" - assert result_data["totalProcessed"] == 1 - assert result_data["successful"] == 1 - assert result_data["failed"] == 0 - assert len(result_data["results"]) == 1 - - prediction = result_data["results"][0] - assert prediction["smiles"] == VALID_SMILES - assert prediction["prediction"] == 0.85 - assert prediction["confidence"] == 0.7 - assert prediction["classifierPrediction"] == 1 - assert prediction["features"] is not None - - -def test_get_job_status(): - """Test GraphQL query to get job status.""" - query = """ - query GetJobStatus($jobId: String!) { - getJobStatus(jobId: $jobId) { - jobId - status - createdAt - progress - error - } + # Check descriptors + expected_descriptors = { + 'MolWt', 'LogP', 'TPSA', 'NumHDonors', + 'NumHAcceptors', 'NumRotatableBonds', 'NumAromaticRings' } - """ - variables = {"jobId": mock_celery_task.id} - - response = client.post("/graphql", json={"query": query, "variables": variables}) - assert response.status_code == 200 + assert expected_descriptors.issubset(features['descriptors'].keys()) - data = response.json() - status_data = data["data"]["getJobStatus"] - - assert status_data["jobId"] == mock_celery_task.id - assert status_data["status"] == "completed" - assert status_data["error"] is None + # Check descriptor values are reasonable for aspirin + assert features['descriptors']['MolWt'] > 100 # Should be around 180 + assert features['descriptors']['NumAromaticRings'] >= 1 # Aspirin has benzene ring -def test_get_nonexistent_job(): - """Test GraphQL query with non-existent job ID.""" - fake_job_id = str(uuid.uuid4()) - - # Create a mock for non-existent job - mock_nonexistent_task = MagicMock() - mock_nonexistent_task.state = 'PENDING' - - query = """ - query GetJobStatus($jobId: String!) { - getJobStatus(jobId: $jobId) { - jobId - status - error - } - } - """ - variables = {"jobId": fake_job_id} +def test_combine_features(): + """Test feature combination into single vector.""" + features = smiles_to_comprehensive_features(VALID_SMILES) + combined = combine_features(features) - with patch("app.worker.celery_app.AsyncResult", return_value=mock_nonexistent_task): - response = client.post("/graphql", json={"query": query, "variables": variables}) - assert response.status_code == 200 - - data = response.json() - status_data = data["data"]["getJobStatus"] - assert status_data["status"] == "pending" + # Should be 2D array with shape (1, n_features) + assert combined.shape[0] == 1 + assert combined.shape[1] == 4200 + 7 # 4200 Morgan + 7 descriptors diff --git a/frontend/package.json b/frontend/package.json index 88054ba..9f6ee13 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -9,11 +9,14 @@ "lint": "next lint" }, "dependencies": { + "@apollo/client": "^3.13.8", "@radix-ui/react-alert-dialog": "^1.1.5", + "@radix-ui/react-progress": "^1.1.7", "@radix-ui/react-slot": "^1.1.1", "@radix-ui/react-tabs": "^1.1.2", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", + "graphql": "^16.11.0", "lucide-react": "^0.474.0", "next": "15.1.5", "react": "^19.0.0", diff --git a/frontend/src/app/api/graphql/route.ts b/frontend/src/app/api/graphql/route.ts new file mode 100644 index 0000000..40b207b --- /dev/null +++ b/frontend/src/app/api/graphql/route.ts @@ -0,0 +1,46 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { mockResolvers } from '@/lib/mock-resolvers'; + +export async function POST(request: NextRequest) { + try { + const body = await request.json(); + const { query, variables } = body; + + // Simple GraphQL query parsing and routing + if (query.includes('submitPredictionJob') && !query.includes('Batch')) { + const result = mockResolvers.Mutation.submitPredictionJob(null, variables); + return NextResponse.json({ + data: { submitPredictionJob: result } + }); + } + + if (query.includes('submitBatchPredictionJob')) { + const result = mockResolvers.Mutation.submitBatchPredictionJob(null, variables); + return NextResponse.json({ + data: { submitBatchPredictionJob: result } + }); + } + + if (query.includes('getPredictionResult')) { + const result = mockResolvers.Query.getPredictionResult(null, variables); + return NextResponse.json({ + data: { getPredictionResult: result } + }); + } + + return NextResponse.json({ + errors: [{ message: 'Unknown query' }] + }, { status: 400 }); + + } catch (error) { + return NextResponse.json({ + errors: [{ message: 'Invalid request' }] + }, { status: 400 }); + } +} + +export async function GET() { + return NextResponse.json({ + message: 'GraphQL endpoint - use POST requests' + }); +} \ No newline at end of file diff --git a/frontend/src/app/editor/page.tsx b/frontend/src/app/editor/page.tsx new file mode 100644 index 0000000..16c8f2d --- /dev/null +++ b/frontend/src/app/editor/page.tsx @@ -0,0 +1,74 @@ +'use client'; + +import { useState } from 'react'; +import { useRouter } from 'next/navigation'; +import ChemicalEditor from '@/components/ChemicalEditor'; +import { Button } from '@/components/ui/button'; +import { ArrowLeft, ArrowRight } from 'lucide-react'; + +export default function EditorPage() { + const router = useRouter(); + const [selectedSmiles, setSelectedSmiles] = useState(''); + + const handleSmilesGenerated = (smiles: string) => { + setSelectedSmiles(smiles); + }; + + const handleProceedToPrediction = () => { + // Navigate to main page with the SMILES string + const params = new URLSearchParams({ smiles: selectedSmiles }); + router.push(`/?${params}`); + }; + + return ( +
+
+ {/* Navigation */} +
+ + + {selectedSmiles && ( + + )} +
+ + {/* Chemical Editor */} + + + {/* Selected SMILES Display */} + {selectedSmiles && ( +
+

Selected Structure:

+
+ {selectedSmiles} +
+
+ + +
+
+ )} +
+
+ ); +} \ No newline at end of file diff --git a/frontend/src/app/layout.tsx b/frontend/src/app/layout.tsx index f7fa87e..1d0f2e6 100644 --- a/frontend/src/app/layout.tsx +++ b/frontend/src/app/layout.tsx @@ -1,6 +1,7 @@ import type { Metadata } from "next"; import { Geist, Geist_Mono } from "next/font/google"; import "./globals.css"; +import { GraphQLProvider } from "@/lib/apollo-provider"; const geistSans = Geist({ variable: "--font-geist-sans", @@ -13,8 +14,8 @@ const geistMono = Geist_Mono({ }); export const metadata: Metadata = { - title: "Create Next App", - description: "Generated by create next app", + title: "Perm-Predict - Chemical Permeability Prediction", + description: "Advanced machine learning-based prediction of chemical accumulation in bacteria", }; export default function RootLayout({ @@ -27,7 +28,9 @@ export default function RootLayout({ - {children} + + {children} + ); diff --git a/frontend/src/app/page.tsx b/frontend/src/app/page.tsx index 85a09ae..69b755f 100644 --- a/frontend/src/app/page.tsx +++ b/frontend/src/app/page.tsx @@ -1,11 +1,45 @@ -'use client' +'use client'; -import { PredictionForm } from '@/components/PredictionForm' +import { useSearchParams } from 'next/navigation'; +import { useEffect, useState } from 'react'; +import { Button } from '@/components/ui/button'; +import { Palette } from 'lucide-react'; +import Link from 'next/link'; +import PredictionForm from '@/components/PredictionForm'; export default function Home() { + const searchParams = useSearchParams(); + const [initialSmiles, setInitialSmiles] = useState(''); + + useEffect(() => { + const smilesParam = searchParams.get('smiles'); + if (smilesParam) { + setInitialSmiles(smilesParam); + } + }, [searchParams]); + return (
- +
+ {/* Header with navigation */} +
+

+ Perm-Predict +

+

+ Advanced machine learning-based prediction of chemical accumulation in bacteria +

+ + + + +
+ + +
- ) + ); } \ No newline at end of file diff --git a/frontend/src/components/ChemicalEditor.tsx b/frontend/src/components/ChemicalEditor.tsx new file mode 100644 index 0000000..37cdcf4 --- /dev/null +++ b/frontend/src/components/ChemicalEditor.tsx @@ -0,0 +1,162 @@ +import React, { useState } from 'react'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; +import { Alert, AlertDescription } from '@/components/ui/alert'; +import { Copy, Download, Upload, Palette, Info } from 'lucide-react'; + +interface ChemicalEditorProps { + onSmilesGenerated: (smiles: string) => void; +} + +const ChemicalEditor = ({ onSmilesGenerated }: ChemicalEditorProps) => { + const [smilesInput, setSmilesInput] = useState(''); + const [showEditor, setShowEditor] = useState(false); + + // Sample common chemical structures for quick testing + const sampleStructures = [ + { name: 'Ethanol', smiles: 'CCO', description: 'Simple alcohol' }, + { name: 'Caffeine', smiles: 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C', description: 'Stimulant alkaloid' }, + { name: 'Aspirin', smiles: 'CC(=O)OC1=CC=CC=C1C(=O)O', description: 'Pain reliever' }, + { name: 'Glucose', smiles: 'C([C@@H]1[C@H]([C@@H]([C@H]([C@H](O1)O)O)O)O)O', description: 'Simple sugar' }, + { name: 'Benzene', smiles: 'c1ccccc1', description: 'Aromatic hydrocarbon' }, + { name: 'Penicillin G', smiles: 'CC1([C@@H](N2[C@H](S1)[C@@H](C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C', description: 'Antibiotic' }, + ]; + + const copyToClipboard = (smiles: string) => { + navigator.clipboard.writeText(smiles); + }; + + const handleUseSample = (smiles: string) => { + setSmilesInput(smiles); + onSmilesGenerated(smiles); + }; + + const handleManualInput = () => { + if (smilesInput.trim()) { + onSmilesGenerated(smilesInput.trim()); + } + }; + + return ( + + + + + Chemical Structure Editor + + + Draw or input chemical structures to generate SMILES notation + + + + + + Sample Structures + Manual SMILES + + Draw Structure + (Coming Soon) + + + + +
+ {sampleStructures.map((structure, index) => ( +
+
+
+

{structure.name}

+

{structure.description}

+
+ +
+
+ {structure.smiles} + +
+
+ ))} +
+
+ + +
+
+ setSmilesInput(e.target.value)} + className="font-mono" + onKeyPress={(e) => e.key === 'Enter' && handleManualInput()} + /> +
+
+ + +
+
+
+ + + + + +
+

Interactive Chemical Editor Coming Soon!

+

+ We're working on integrating Ketcher, a professional chemical structure editor. + For now, you can use the sample structures or input SMILES notation manually. +

+
+

• Draw chemical structures with mouse/touch

+

• Automatic SMILES generation

+

• Structure validation and optimization

+

• Import/export various chemical formats

+
+
+
+
+ + {/* Placeholder for future Ketcher integration */} +
+
+ +

Chemical Structure Editor

+

Will be integrated here

+
+
+
+
+
+
+ ); +}; + +export default ChemicalEditor; \ No newline at end of file diff --git a/frontend/src/components/PredictionForm.tsx b/frontend/src/components/PredictionForm.tsx index f2c8388..d87a174 100644 --- a/frontend/src/components/PredictionForm.tsx +++ b/frontend/src/components/PredictionForm.tsx @@ -1,128 +1,265 @@ -import React, { useState } from 'react'; +import React, { useState, useEffect } from 'react'; +import { useMutation, useLazyQuery } from '@apollo/client'; import { Alert, AlertDescription } from '@/components/ui/alert'; import { Button } from '@/components/ui/button'; import { Input } from '@/components/ui/input'; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; +import { Progress } from '@/components/ui/progress'; +import { Loader2, CheckCircle, AlertCircle } from 'lucide-react'; import PredictionResults from './PredictionResults'; +import { SUBMIT_PREDICTION_JOB, SUBMIT_BATCH_PREDICTION_JOB, GET_PREDICTION_RESULT } from '@/lib/graphql/queries'; -import type { PredictionResult, SinglePredictionRequest } from '@/lib/types' +import type { PredictionResult, JobResponse, JobResult } from '@/lib/types' -const PredictionForm = () => { - const [smilesInput, setSmilesInput] = useState(''); - const [file, setFile] = useState(null); +interface PredictionFormProps { + initialSmiles?: string; +} + +const PredictionForm = ({ initialSmiles = '' }: PredictionFormProps) => { + const [smilesInput, setSmilesInput] = useState(initialSmiles); + const [batchInput, setBatchInput] = useState(''); const [results, setResults] = useState([]); const [error, setError] = useState(''); - const [loading, setLoading] = useState(false); + const [currentJobId, setCurrentJobId] = useState(null); + const [jobStatus, setJobStatus] = useState<'idle' | 'pending' | 'processing' | 'completed' | 'failed'>('idle'); + const [progress, setProgress] = useState(0); + + // GraphQL hooks + const [submitPredictionJob] = useMutation(SUBMIT_PREDICTION_JOB); + const [submitBatchPredictionJob] = useMutation(SUBMIT_BATCH_PREDICTION_JOB); + const [getPredictionResult, { data: jobResult, stopPolling }] = useLazyQuery(GET_PREDICTION_RESULT, { + pollInterval: 2000, + errorPolicy: 'all', + }); + + // Effect to handle job polling results + useEffect(() => { + if (jobResult?.getPredictionResult) { + const result = jobResult.getPredictionResult as JobResult; + setJobStatus(result.status); + + if (result.status === 'processing') { + setProgress(prev => Math.min(prev + 10, 90)); // Simulate progress + } else if (result.status === 'completed') { + setProgress(100); + stopPolling(); + + if (result.result) { + if (Array.isArray(result.result)) { + setResults(result.result); + } else { + setResults([result.result]); + } + } + + // Reset after a delay + setTimeout(() => { + setJobStatus('idle'); + setCurrentJobId(null); + setProgress(0); + }, 2000); + } else if (result.status === 'failed') { + setError(result.error || 'Prediction failed'); + stopPolling(); + setJobStatus('idle'); + setCurrentJobId(null); + setProgress(0); + } + } + }, [jobResult, stopPolling]); + + // Update input when initialSmiles changes + useEffect(() => { + if (initialSmiles) { + setSmilesInput(initialSmiles); + } + }, [initialSmiles]); - const handleSinglePrediction = async (e) => { + const handleSinglePrediction = async (e: React.FormEvent) => { e.preventDefault(); setError(''); - setLoading(true); - + setResults([]); + setProgress(0); + try { - const response = await fetch('/api/predict/single', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ smiles: smilesInput }) + const { data } = await submitPredictionJob({ + variables: { smiles: smilesInput } }); - - const data = await response.json() as PredictionResult; - if (data.error) { - setError(data.error); - } else { - setResults([data]); + + if (data?.submitPredictionJob) { + const jobResponse = data.submitPredictionJob as JobResponse; + setCurrentJobId(jobResponse.jobId); + setJobStatus('pending'); + setProgress(10); + + // Start polling for results + getPredictionResult({ variables: { jobId: jobResponse.jobId } }); } } catch (err) { - setError('Failed to get prediction. Please try again.'); - } finally { - setLoading(false); + setError('Failed to submit prediction. Please try again.'); + setJobStatus('idle'); } }; - const handleFileUpload = async (e) => { + const handleBatchPrediction = async (e: React.FormEvent) => { e.preventDefault(); - if (!file) { - setError('Please select a file'); + if (!batchInput.trim()) { + setError('Please enter SMILES strings'); return; } setError(''); - setLoading(true); - - const formData = new FormData(); - formData.append('file', file); + setResults([]); + setProgress(0); + + // Parse SMILES strings (one per line or comma-separated) + const smilesStrings = batchInput + .split(/[\n,]/) + .map(s => s.trim()) + .filter(s => s.length > 0); + + if (smilesStrings.length === 0) { + setError('No valid SMILES strings found'); + return; + } try { - const response = await fetch('/api/predict/batch', { - method: 'POST', - body: formData + const { data } = await submitBatchPredictionJob({ + variables: { smilesStrings } }); - - const data = await response.json() as PredictionResult[]; - if (response.ok) { - setResults(data); - } else { - setError(data.detail || 'Failed to process file'); + + if (data?.submitBatchPredictionJob) { + const jobResponse = data.submitBatchPredictionJob as JobResponse; + setCurrentJobId(jobResponse.jobId); + setJobStatus('pending'); + setProgress(10); + + // Start polling for results + getPredictionResult({ variables: { jobId: jobResponse.jobId } }); } } catch (err) { - setError('Failed to upload file. Please try again.'); - } finally { - setLoading(false); + setError('Failed to submit batch prediction. Please try again.'); + setJobStatus('idle'); + } + }; + + const getStatusIcon = () => { + switch (jobStatus) { + case 'pending': + case 'processing': + return ; + case 'completed': + return ; + case 'failed': + return ; + default: + return null; } }; + const getStatusText = () => { + switch (jobStatus) { + case 'pending': + return 'Queued for processing...'; + case 'processing': + return 'Running prediction model...'; + case 'completed': + return 'Prediction completed!'; + case 'failed': + return 'Prediction failed'; + default: + return ''; + } + }; + + const isProcessing = jobStatus === 'pending' || jobStatus === 'processing'; + return (
- Chemical Permeability Prediction + + Chemical Permeability Prediction + {getStatusIcon()} + - Enter SMILES notation or upload a CSV file to predict compound permeability + Enter SMILES notation to predict compound permeability using machine learning - + Single Prediction Batch Prediction - +
setSmilesInput(e.target.value)} - className="w-full" + className="w-full font-mono" + disabled={isProcessing} />
-
- -
+ +
- setFile(e.target.files[0])} - className="w-full" +