Skip to content

feat: Add ML-powered error oracle using aprender #27

Description

@noahgift

Summary

Implement an ML-powered error classification oracle for decy (C-to-Rust transpiler) using the same architecture proven in depyler-oracle. This oracle will classify C/Rust transpilation errors and provide intelligent fix suggestions.


Performance Targets (from depyler-oracle)

| Metric | Target | Depyler Achieved |
|---|---|---|
| Accuracy | >90% | 97.73% |
| F1 Score | >0.90 | 0.9743 |
| 5-Fold CV | >85% | 90.89% |
| Training Time | <1 sec | 0.8 sec |
| Model Size | <1 MB | 503 KB |
| Predictions/sec | >1,000 | 4,140 |
| CLI Latency | <500ms | 366ms |

Directory Structure

decy/
├── crates/
│   └── decy-oracle/
│       ├── Cargo.toml
│       ├── src/
│       │   ├── lib.rs           # Main oracle API
│       │   ├── features.rs      # Feature extraction from errors
│       │   ├── corpus.rs        # Training corpus management
│       │   ├── automl_tuning.rs # AutoML hyperparameter optimization
│       │   └── categories.rs    # Error category definitions
│       ├── tests/
│       │   ├── model_evaluation.rs  # Accuracy & CV tests
│       │   └── integration.rs       # End-to-end tests
│       └── models/
│           └── .gitkeep         # Model files (gitignored)
├── examples/
│   └── oracle_demo.rs           # Demo usage
└── scripts/
    └── train_oracle.sh          # Training script

Step 1: Create Crate Structure

# From decy root
mkdir -p crates/decy-oracle/src
mkdir -p crates/decy-oracle/tests
mkdir -p crates/decy-oracle/models
touch crates/decy-oracle/models/.gitkeep

Step 2: Cargo.toml

# crates/decy-oracle/Cargo.toml
# Manifest for the decy-oracle crate.

[package]
name = "decy-oracle"
version = "0.1.0"
edition = "2021"
description = "ML-powered error classification oracle for decy C-to-Rust transpiler"
license = "MIT OR Apache-2.0"
repository = "https://github.com/paiml/decy"

[features]
# Both optional aprender capabilities are switched on by default; build with
# `--no-default-features` to opt out of them.
default = ["gpu", "compressed-models"]
gpu = ["aprender/gpu"]                           # RTX 4090 via wgpu backend
compressed-models = ["aprender/format-compression"]  # zstd lossless compression

[dependencies]
# ML library - MUST use crates.io version, not git
aprender = { version = "0.10", default-features = true }

# Serialization
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

# Error handling
thiserror = "2.0"
anyhow = "1.0"

# Logging
tracing = "0.1"

[dev-dependencies]
# Used only by tests and benchmarks.
tempfile = "3.10"
criterion = { version = "0.5", features = ["html_reports"] }

# Benchmark target with the default libtest harness disabled.
# NOTE(review): Cargo looks for this target at benches/oracle_benchmarks.rs,
# which is not part of the directory structure listed in this issue - confirm
# the file is added, otherwise the manifest will not resolve.
[[bench]]
name = "oracle_benchmarks"
harness = false

Step 3: Error Categories (categories.rs)

// crates/decy-oracle/src/categories.rs

use serde::{Deserialize, Serialize};

/// C-to-Rust transpilation error categories for ML classification.
///
/// Every variant has an explicit `u8` discriminant (`#[repr(u8)]`), so a
/// category can be turned into a numeric class label with `as u8`.
/// Discriminants are allocated in blocks of ten per error family, which makes
/// the label space sparse rather than contiguous. `Unknown` (255) is the
/// fallback category.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum ErrorCategory {
    // C Parsing Errors (0-9)
    CParsePreprocessor = 0,        // #include, #define issues
    CParseDeclaration = 1,         // Variable/function declaration errors
    CParseSyntax = 2,              // General syntax errors
    CParseTypeMismatch = 3,        // Type incompatibility in C

    // Ownership/Lifetime Errors (10-19)
    OwnershipBorrowConflict = 10,  // Cannot borrow as mutable
    OwnershipMovedValue = 11,      // Use of moved value
    OwnershipLifetimeMissing = 12, // Missing lifetime parameter
    OwnershipDanglingPtr = 13,     // Dangling pointer detection

    // Type System Errors (20-29)
    TypeInferenceFailed = 20,      // Cannot infer type
    TypeMismatch = 21,             // Expected X, found Y
    TypeUnsizedLocal = 22,         // Unsized type in local
    TypeTraitNotImpl = 23,         // Trait not implemented

    // Unsafe Code Errors (30-39)
    UnsafeBlockRequired = 30,      // Needs unsafe block
    UnsafeDerefRawPtr = 31,        // Raw pointer dereference
    UnsafeFfiCall = 32,            // FFI safety issue
    UnsafeStaticMut = 33,          // Mutable static access

    // Memory Layout Errors (40-49)
    LayoutStructPadding = 40,      // Struct layout mismatch
    LayoutArraySize = 41,          // Array size issues
    LayoutUnionAccess = 42,        // Union field access
    LayoutBitfield = 43,           // Bitfield transpilation

    // Standard Library Mapping (50-59)
    StdlibStringConv = 50,         // char*/String conversion
    StdlibAllocFree = 51,          // malloc/free to Box/Vec
    StdlibFileOps = 52,            // FILE* to std::fs
    StdlibPrintf = 53,             // printf to format!

    // Macro Expansion (60-69)
    MacroExpansionFailed = 60,     // Macro couldn't expand
    MacroRecursive = 61,           // Recursive macro issue
    MacroVariadic = 62,            // Variadic macro issues

    // Code Generation (70-79)
    CodegenUnsupported = 70,       // Feature not supported
    CodegenAmbiguous = 71,         // Multiple valid translations
    CodegenPerformance = 72,       // Performance concern

    // Unknown (fallback); also the result of decoding an unassigned label
    Unknown = 255,
}

impl ErrorCategory {
    /// Number of categories (for ML output layer).
    ///
    /// Counts every declared variant: the 30 specific categories plus the
    /// `Unknown` fallback. Labels are sparse (0..=72 and 255), so this is a
    /// class count, not an upper bound on label values.
    pub const COUNT: usize = 31;

    /// Convert from u8 class label.
    ///
    /// Inverse of `category as u8` for every declared variant. Any label that
    /// is not assigned to a variant decodes to [`ErrorCategory::Unknown`], so
    /// this function never fails.
    pub fn from_label(label: u8) -> Self {
        match label {
            0 => Self::CParsePreprocessor,
            1 => Self::CParseDeclaration,
            2 => Self::CParseSyntax,
            3 => Self::CParseTypeMismatch,
            10 => Self::OwnershipBorrowConflict,
            11 => Self::OwnershipMovedValue,
            12 => Self::OwnershipLifetimeMissing,
            13 => Self::OwnershipDanglingPtr,
            20 => Self::TypeInferenceFailed,
            21 => Self::TypeMismatch,
            22 => Self::TypeUnsizedLocal,
            23 => Self::TypeTraitNotImpl,
            30 => Self::UnsafeBlockRequired,
            31 => Self::UnsafeDerefRawPtr,
            32 => Self::UnsafeFfiCall,
            33 => Self::UnsafeStaticMut,
            40 => Self::LayoutStructPadding,
            41 => Self::LayoutArraySize,
            42 => Self::LayoutUnionAccess,
            43 => Self::LayoutBitfield,
            50 => Self::StdlibStringConv,
            51 => Self::StdlibAllocFree,
            52 => Self::StdlibFileOps,
            53 => Self::StdlibPrintf,
            60 => Self::MacroExpansionFailed,
            61 => Self::MacroRecursive,
            62 => Self::MacroVariadic,
            70 => Self::CodegenUnsupported,
            71 => Self::CodegenAmbiguous,
            72 => Self::CodegenPerformance,
            // Includes 255 itself as well as every unassigned label.
            _ => Self::Unknown,
        }
    }

    /// Human-readable fix suggestion for C-to-Rust transpilation.
    ///
    /// Returns a static, one-line hint for the category. The match is
    /// exhaustive on purpose: adding a variant without a suggestion is a
    /// compile error.
    pub fn fix_suggestion(&self) -> &'static str {
        match self {
            // C Parsing
            Self::CParsePreprocessor => "Check #include paths and #define macros",
            Self::CParseDeclaration => "Verify declaration syntax and types",
            Self::CParseSyntax => "Check C syntax near the error location",
            Self::CParseTypeMismatch => "Cast or convert types explicitly in C source",

            // Ownership/Lifetime
            Self::OwnershipBorrowConflict => "Use RefCell for interior mutability or restructure borrows",
            Self::OwnershipMovedValue => "Clone the value or use references instead",
            Self::OwnershipLifetimeMissing => "Add explicit lifetime annotations to function signature",
            Self::OwnershipDanglingPtr => "Use Option<&T> or ensure pointer validity with lifetime bounds",

            // Type System
            Self::TypeInferenceFailed => "Add explicit type annotations",
            Self::TypeMismatch => "Convert types with .into() or explicit cast",
            Self::TypeUnsizedLocal => "Use Box<T> or reference for unsized types",
            Self::TypeTraitNotImpl => "Implement required trait or use trait bounds",

            // Unsafe
            Self::UnsafeBlockRequired => "Wrap operation in unsafe {} block with safety comment",
            Self::UnsafeDerefRawPtr => "Validate pointer before deref, add safety invariant",
            Self::UnsafeFfiCall => "Ensure FFI contract is maintained, document assumptions",
            Self::UnsafeStaticMut => "Consider using AtomicXxx or Mutex for thread safety",

            // Memory Layout
            Self::LayoutStructPadding => "Use #[repr(C)] to match C struct layout",
            Self::LayoutArraySize => "Use const generics or Vec for dynamic arrays",
            Self::LayoutUnionAccess => "Use enum with variants or unsafe union access",
            Self::LayoutBitfield => "Use bitflags crate or manual bit manipulation",

            // Stdlib Mapping
            Self::StdlibStringConv => "Use CStr/CString for FFI, String for owned Rust strings",
            Self::StdlibAllocFree => "Replace malloc/free with Box::new/Vec::new",
            Self::StdlibFileOps => "Use std::fs::File and std::io traits",
            Self::StdlibPrintf => "Use format! or println! macros",

            // Macros
            Self::MacroExpansionFailed => "Manually inline macro or use Rust macro_rules!",
            Self::MacroRecursive => "Refactor recursive macro to iterative form",
            // Rust has no variadic generics; repetition in macro_rules! is the
            // closest equivalent of a C variadic macro.
            Self::MacroVariadic => "Use a macro_rules! macro with repetition or a slice/tuple-based approach",

            // Codegen
            Self::CodegenUnsupported => "Feature requires manual transpilation",
            Self::CodegenAmbiguous => "Review generated options and select best fit",
            Self::CodegenPerformance => "Consider alternative Rust idiom for performance",

            Self::Unknown => "Review the error message for details",
        }
    }
}

Step 4: Feature Extraction (features.rs)

// crates/decy-oracle/src/features.rs

use crate::categories::ErrorCategory;

/// Fixed-width numeric encoding of one transpilation error, used as model input.
///
/// The vector is assembled, in order, from:
/// - a one-hot encoding of the pipeline phase
/// - indicators for selected rustc error codes
/// - two message-size features
/// - keyword indicators over the lower-cased message (bag of words)
/// - C-specific pattern indicators over the C source, when provided
///
/// Remaining slots are zero so that every vector has 128 entries.
pub struct ErrorFeatures {
    pub features: Vec<f32>,
}

impl ErrorFeatures {
    /// Build the feature vector for a single error.
    ///
    /// * `phase` - pipeline stage: "parse", "analyze", "codegen" or "rustc"
    /// * `error_code` - rustc error code such as `E0382`, if there is one
    /// * `message` - raw diagnostic text
    /// * `c_source` - the C snippet involved, if available
    pub fn extract(
        phase: &str,
        error_code: Option<&str>,
        message: &str,
        c_source: Option<&str>,
    ) -> Self {
        // Total width of the vector handed to the model.
        const WIDTH: usize = 128;

        // The order of every table below defines the feature layout.
        const PHASES: [&str; 4] = ["parse", "analyze", "codegen", "rustc"];
        const RUSTC_CODES: [&str; 8] = [
            "E0382", // Use of moved value
            "E0502", // Borrow conflict
            "E0503", // Cannot use while borrowed
            "E0505", // Move out of borrowed
            "E0106", // Missing lifetime
            "E0277", // Trait not implemented
            "E0308", // Mismatched types
            "E0133", // Unsafe required
        ];
        // Message keywords emitted before the C-source block:
        // ownership/borrow, type system, unsafe.
        const LEADING_KEYWORDS: [&str; 15] = [
            "borrow",
            "moved",
            "lifetime",
            "mutable",
            "immutable",
            "dropped",
            "type",
            "expected",
            "found",
            "trait",
            "implement",
            "unsafe",
            "raw pointer",
            "dereference",
            "ffi",
        ];
        // Patterns searched in the lower-cased C source ("->" and "*" cover
        // pointer dereference and pointer declaration).
        const C_PATTERNS: [&str; 10] = [
            "malloc", "free", "->", "*", "struct", "union", "#define", "#include", "printf",
            "typedef",
        ];
        // Message keywords emitted after the C-source block:
        // macro/preprocessor, stdlib mapping, layout.
        const TRAILING_KEYWORDS: [&str; 11] = [
            "macro",
            "expand",
            "preprocessor",
            "string",
            "cstr",
            "file",
            "alloc",
            "layout",
            "padding",
            "alignment",
            "repr",
        ];

        /// Encode a boolean indicator as 1.0 / 0.0.
        fn flag(hit: bool) -> f32 {
            if hit {
                1.0
            } else {
                0.0
            }
        }

        let lowered = message.to_lowercase();
        let mut out: Vec<f32> = Vec::with_capacity(WIDTH);

        out.extend(PHASES.iter().map(|name| flag(phase == *name)));
        // With no error code every comparison is false, giving eight zeros.
        out.extend(RUSTC_CODES.iter().map(|code| flag(error_code == Some(*code))));

        // Message size: byte length capped at 1.0, line count scaled by 20.
        out.push((message.len() as f32 / 500.0).min(1.0));
        out.push(message.lines().count() as f32 / 20.0);

        out.extend(LEADING_KEYWORDS.iter().map(|kw| flag(lowered.contains(*kw))));

        match c_source {
            Some(snippet) => {
                let snippet = snippet.to_lowercase();
                out.extend(C_PATTERNS.iter().map(|pat| flag(snippet.contains(*pat))));
            }
            // Keep the layout stable when no C source was supplied.
            None => out.extend(std::iter::repeat(0.0_f32).take(C_PATTERNS.len())),
        }

        out.extend(TRAILING_KEYWORDS.iter().map(|kw| flag(lowered.contains(*kw))));

        // Zero-fill up to the fixed width; never shrink.
        if out.len() < WIDTH {
            out.resize(WIDTH, 0.0);
        }

        Self { features: out }
    }

    /// Borrow the feature values as a slice.
    pub fn as_slice(&self) -> &[f32] {
        self.features.as_slice()
    }
}

Step 5: Training Corpus (corpus.rs)

// crates/decy-oracle/src/corpus.rs

use crate::categories::ErrorCategory;
use crate::features::ErrorFeatures;
use serde::{Deserialize, Serialize};

/// Training example with features and label.
///
/// Holds the raw inputs of one error (not the extracted feature vector)
/// together with the category it should be classified as.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Pipeline phase that produced the error, e.g. "parse" or "rustc".
    pub phase: String,
    /// Rust compiler error code such as "E0382", when there is one.
    pub error_code: Option<String>,
    /// Diagnostic message text.
    pub message: String,
    /// C source snippet associated with the error, when available.
    pub c_source: Option<String>,
    /// Ground-truth label for this example.
    pub category: ErrorCategory,
}

/// Training corpus management: an ordered collection of labelled examples.
#[derive(Debug, Clone, Default)]
pub struct Corpus {
    examples: Vec<TrainingExample>,
}

impl Corpus {
    /// Create an empty corpus.
    pub fn new() -> Self {
        Self::default()
    }

    /// Load from JSON file.
    ///
    /// The file must contain a JSON array of [`TrainingExample`] objects.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or is not valid JSON of
    /// the expected shape.
    pub fn load(path: &std::path::Path) -> anyhow::Result<Self> {
        let content = std::fs::read_to_string(path)?;
        let examples: Vec<TrainingExample> = serde_json::from_str(&content)?;
        Ok(Self { examples })
    }

    /// Generate synthetic training data for C-to-Rust transpilation.
    ///
    /// Produces exactly `count` examples by sampling a fixed set of templates
    /// with a deterministic generator (fixed seed), so repeated calls with the
    /// same `count` return the same corpus. Every third example gets a
    /// `(variant N)` suffix on its message for a little textual variety.
    pub fn generate_synthetic(count: usize) -> Self {
        let mut examples = Vec::with_capacity(count);
        let mut rng_seed = 42u64;

        // (phase, rustc error code, message, C source, label)
        let templates = [
            // Ownership errors
            ("rustc", Some("E0382"), "error[E0382]: use of moved value: `ptr`", Some("char *ptr = malloc(100);"), ErrorCategory::OwnershipMovedValue),
            ("rustc", Some("E0502"), "error[E0502]: cannot borrow `data` as mutable because it is also borrowed as immutable", Some("int *data;"), ErrorCategory::OwnershipBorrowConflict),
            ("rustc", Some("E0106"), "error[E0106]: missing lifetime specifier", Some("struct Node { Node *next; };"), ErrorCategory::OwnershipLifetimeMissing),

            // Type errors
            ("rustc", Some("E0308"), "error[E0308]: mismatched types: expected `i32`, found `*mut i32`", Some("int *x = 42;"), ErrorCategory::TypeMismatch),
            ("rustc", Some("E0277"), "error[E0277]: the trait bound `*mut c_void: Send` is not satisfied", None, ErrorCategory::TypeTraitNotImpl),
            ("analyze", None, "cannot infer type for variable `result`", Some("auto result = func();"), ErrorCategory::TypeInferenceFailed),

            // Unsafe errors
            ("rustc", Some("E0133"), "error[E0133]: dereference of raw pointer is unsafe", Some("*ptr = 10;"), ErrorCategory::UnsafeDerefRawPtr),
            ("analyze", None, "unsafe block required for FFI call to external function", Some("extern int foo();"), ErrorCategory::UnsafeFfiCall),
            ("rustc", None, "use of mutable static is unsafe", Some("static int counter = 0;"), ErrorCategory::UnsafeStaticMut),

            // Parse errors
            ("parse", None, "failed to expand #include <missing.h>", Some("#include <missing.h>"), ErrorCategory::CParsePreprocessor),
            ("parse", None, "syntax error: unexpected token ';'", Some("int x = ;"), ErrorCategory::CParseSyntax),
            ("parse", None, "invalid declaration: conflicting types", Some("int x; char x;"), ErrorCategory::CParseDeclaration),

            // Layout errors
            ("codegen", None, "struct padding mismatch: C size 12, Rust size 16", Some("struct S { char a; int b; };"), ErrorCategory::LayoutStructPadding),
            ("codegen", None, "cannot determine array size at compile time", Some("int arr[n];"), ErrorCategory::LayoutArraySize),
            ("codegen", None, "union field access requires unsafe block", Some("union U { int i; float f; };"), ErrorCategory::LayoutUnionAccess),

            // Stdlib mapping
            ("codegen", None, "cannot convert char* to String directly", Some("char *s = \"hello\";"), ErrorCategory::StdlibStringConv),
            ("codegen", None, "malloc/free should use Box or Vec", Some("int *p = malloc(sizeof(int));"), ErrorCategory::StdlibAllocFree),
            ("codegen", None, "printf format string needs conversion", Some("printf(\"%d\", x);"), ErrorCategory::StdlibPrintf),

            // Macro errors
            ("parse", None, "macro expansion failed: recursive macro detected", Some("#define X X+1"), ErrorCategory::MacroRecursive),
            ("parse", None, "variadic macro not supported", Some("#define LOG(...) printf(__VA_ARGS__)"), ErrorCategory::MacroVariadic),
        ];

        for i in 0..count {
            // Linear congruential step; wrapping arithmetic is intentional.
            rng_seed = rng_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            // Reduce in u64 *before* narrowing so the chosen template is the
            // same on 32-bit and 64-bit targets (a plain `as usize` would
            // truncate the state on 32-bit platforms).
            let idx = (rng_seed % templates.len() as u64) as usize;
            // The template tuple is `Copy`, so it is read out by value.
            let (phase, error_code, message, c_source, category) = templates[idx];

            let varied_message = if i % 3 == 0 {
                format!("{} (variant {})", message, i)
            } else {
                message.to_string()
            };

            examples.push(TrainingExample {
                phase: phase.to_string(),
                error_code: error_code.map(String::from),
                message: varied_message,
                c_source: c_source.map(String::from),
                category,
            });
        }

        Self { examples }
    }

    /// Convert to feature matrix (X) and labels (y).
    ///
    /// Row `i` of the matrix is the feature vector extracted from example `i`
    /// and `y[i]` is that example's category as its `u8` discriminant.
    pub fn to_training_data(&self) -> (Vec<Vec<f32>>, Vec<u8>) {
        self.examples
            .iter()
            .map(|example| {
                let features = ErrorFeatures::extract(
                    &example.phase,
                    example.error_code.as_deref(),
                    &example.message,
                    example.c_source.as_deref(),
                );
                (features.features, example.category as u8)
            })
            .unzip()
    }

    /// Number of examples in the corpus.
    pub fn len(&self) -> usize {
        self.examples.len()
    }

    /// Whether the corpus contains no examples.
    pub fn is_empty(&self) -> bool {
        self.examples.is_empty()
    }
}

Step 6: Main Oracle Implementation (lib.rs)

// crates/decy-oracle/src/lib.rs

pub mod categories;
pub mod corpus;
pub mod features;
pub mod automl_tuning;

use aprender::tree::RandomForestClassifier;
use aprender::format::{self, Compression, ModelType, SaveOptions};
use aprender::model_selection::{cross_validate, KFold};
use std::path::Path;
use thiserror::Error;

pub use categories::ErrorCategory;
pub use corpus::{Corpus, TrainingExample};
pub use features::ErrorFeatures;

/// Errors reported by the oracle's training, persistence and loading paths.
#[derive(Error, Debug)]
pub enum OracleError {
    /// Saving or loading the serialized model failed; carries the underlying
    /// error rendered as text.
    #[error("Model error: {0}")]
    Model(String),
    /// Fitting or cross-validating the classifier failed; carries the
    /// underlying error rendered as text.
    #[error("Training error: {0}")]
    Training(String),
    /// I/O failure; `#[from]` lets `?` convert a `std::io::Error` directly.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}

/// Crate-wide result alias using [`OracleError`] as the error type.
pub type Result<T> = std::result::Result<T, OracleError>;

/// Oracle configuration: hyperparameters applied to the random forest
/// classifier before it is fitted.
#[derive(Debug, Clone)]
pub struct OracleConfig {
    /// Number of trees in forest
    /// IMPORTANT: 100 is sufficient. 10,000 causes 15+ min training!
    pub n_estimators: usize,
    /// Maximum tree depth
    pub max_depth: usize,
    /// Random seed for reproducibility; `None` leaves the classifier's seed
    /// unset.
    pub random_state: Option<u64>,
}

/// Defaults: 100 trees, depth 10, fixed seed 42.
impl Default for OracleConfig {
    fn default() -> Self {
        Self {
            n_estimators: 100,  // NOT 10,000!
            max_depth: 10,
            random_state: Some(42),
        }
    }
}

/// ML-powered error classification oracle for C-to-Rust transpilation.
///
/// Wraps a random forest classifier together with the configuration it was
/// created with.
pub struct Oracle {
    classifier: RandomForestClassifier,
    config: OracleConfig,
}

impl Oracle {
    /// Load existing model or train new one.
    ///
    /// If `model_path` exists and can be loaded, that model is returned.
    /// A load failure is logged as a warning and falls through to training
    /// on `corpus`.
    ///
    /// Note: a freshly trained model is NOT written to `model_path`; call
    /// [`Oracle::save`] afterwards if it should be reused on the next run.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Training`] if training is needed and fails.
    pub fn load_or_train(
        model_path: &Path,
        corpus: &Corpus,
        config: OracleConfig,
    ) -> Result<Self> {
        if model_path.exists() {
            tracing::info!("Loading existing model from {:?}", model_path);
            match Self::load(model_path, config.clone()) {
                Ok(oracle) => return Ok(oracle),
                Err(e) => {
                    tracing::warn!("Failed to load model: {}, retraining", e);
                }
            }
        }

        tracing::info!("Training new model with {} examples", corpus.len());
        Self::train(corpus, config)
    }

    /// Train new model from corpus.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Training`] if the corpus is empty or if fitting
    /// the classifier fails.
    pub fn train(corpus: &Corpus, config: OracleConfig) -> Result<Self> {
        // Fail early with a clear message instead of handing the classifier
        // an empty feature matrix.
        if corpus.is_empty() {
            return Err(OracleError::Training(
                "cannot train on an empty corpus".to_string(),
            ));
        }

        let (x, y) = corpus.to_training_data();

        let mut classifier = RandomForestClassifier::new();
        classifier
            .set_n_estimators(config.n_estimators)
            .set_max_depth(config.max_depth);

        if let Some(seed) = config.random_state {
            classifier.set_random_state(seed);
        }

        let x_ref: Vec<&[f32]> = x.iter().map(|row| row.as_slice()).collect();

        classifier
            .fit(&x_ref, &y)
            .map_err(|e| OracleError::Training(e.to_string()))?;

        Ok(Self { classifier, config })
    }

    /// Load model from file.
    ///
    /// `config` is stored alongside the loaded classifier as given; it is not
    /// read from the file.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Model`] if the file cannot be loaded.
    pub fn load(path: &Path, config: OracleConfig) -> Result<Self> {
        let classifier: RandomForestClassifier = format::load(path)
            .map_err(|e| OracleError::Model(e.to_string()))?;

        Ok(Self { classifier, config })
    }

    /// Save model to file (with zstd compression).
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Model`] if the model cannot be written.
    pub fn save(&self, path: &Path) -> Result<()> {
        let options = SaveOptions::default()
            .with_name("decy-oracle")
            .with_description("RandomForest error classification model for decy C-to-Rust transpiler")
            .with_compression(Compression::ZstdDefault); // 14x smaller!

        format::save(&self.classifier, ModelType::RandomForest, path, options)
            .map_err(|e| OracleError::Model(e.to_string()))?;

        Ok(())
    }

    /// Classify transpilation error and return category with confidence.
    ///
    /// The confidence is the largest class probability reported for the
    /// input. If the classifier yields no prediction the category is
    /// [`ErrorCategory::Unknown`]; if it yields no usable probability the
    /// confidence is `0.0`.
    pub fn classify(
        &self,
        phase: &str,
        error_code: Option<&str>,
        message: &str,
        c_source: Option<&str>,
    ) -> (ErrorCategory, f32) {
        let features = ErrorFeatures::extract(phase, error_code, message, c_source);
        let x = [features.as_slice()];

        let predictions = self.classifier.predict(&x);
        let probabilities = self.classifier.predict_proba(&x);

        let label = predictions
            .first()
            .copied()
            .unwrap_or(ErrorCategory::Unknown as u8);
        // NaN entries are skipped and the rest compared with a total order,
        // so a malformed probability can never panic this method.
        let confidence = probabilities
            .first()
            .and_then(|probs| {
                probs
                    .iter()
                    .cloned()
                    .filter(|p| !p.is_nan())
                    .max_by(|a, b| a.total_cmp(b))
            })
            .unwrap_or(0.0);

        (ErrorCategory::from_label(label), confidence)
    }

    /// Get fix suggestion for transpilation error.
    ///
    /// Formats the classification as
    /// `"[NN% confident] <Category>: <suggestion>"`.
    pub fn suggest_fix(
        &self,
        phase: &str,
        error_code: Option<&str>,
        message: &str,
        c_source: Option<&str>,
    ) -> String {
        let (category, confidence) = self.classify(phase, error_code, message, c_source);

        format!(
            "[{:.0}% confident] {:?}: {}",
            confidence * 100.0,
            category,
            category.fix_suggestion()
        )
    }
}

/// Cross-validation for model evaluation.
///
/// Builds a classifier from `config`, runs `folds`-fold cross-validation over
/// the whole corpus and returns the mean of the per-fold scores.
///
/// # Errors
///
/// Returns [`OracleError::Training`] if cross-validation fails or produces no
/// fold scores (which would otherwise surface as a `NaN` mean).
pub fn evaluate_model(corpus: &Corpus, config: &OracleConfig, folds: usize) -> Result<f64> {
    let (x, y) = corpus.to_training_data();
    let x_ref: Vec<&[f32]> = x.iter().map(|row| row.as_slice()).collect();

    let mut classifier = RandomForestClassifier::new();
    classifier
        .set_n_estimators(config.n_estimators)
        .set_max_depth(config.max_depth);

    if let Some(seed) = config.random_state {
        classifier.set_random_state(seed);
    }

    let kfold = KFold::new(folds);
    let results = cross_validate(&classifier, &x_ref, &y, &kfold)
        .map_err(|e| OracleError::Training(e.to_string()))?;

    // An empty score list would make the mean 0.0 / 0.0 = NaN; report it as
    // an error rather than returning a meaningless accuracy.
    let fold_count = results.len();
    if fold_count == 0 {
        return Err(OracleError::Training(format!(
            "cross-validation with {} folds produced no scores",
            folds
        )));
    }

    let mean_accuracy = results.iter().sum::<f64>() / fold_count as f64;
    Ok(mean_accuracy)
}

Step 7: AutoML Tuning (automl_tuning.rs)

// crates/decy-oracle/src/automl_tuning.rs

use crate::{Corpus, OracleConfig, OracleError, Result};
use aprender::automl::{RandomSearch, SearchSpace};
use aprender::model_selection::{cross_validate, StratifiedKFold};
use aprender::tree::RandomForestClassifier;

/// AutoML hyperparameter tuning results: the best configuration found by a
/// search, plus how it scored.
#[derive(Debug, Clone)]
pub struct TuningResults {
    /// Number of trees of the best-scoring trial.
    pub best_n_estimators: usize,
    /// Maximum tree depth of the best-scoring trial.
    pub best_max_depth: usize,
    /// Mean cross-validation score of the best-scoring trial.
    pub best_cv_score: f64,
    /// Number of trials the search was asked to run.
    pub trials_run: usize,
}

/// Tune hyperparameters using AutoML.
///
/// Runs `n_trials` rounds of random search over `n_estimators` (50..200) and
/// `max_depth` (5..20). Each candidate is scored by the mean of a 3-fold
/// stratified cross-validation, and the best-scoring pair is returned. When
/// no trial scores above zero, the values of `OracleConfig::default()` are
/// reported with a score of `0.0`.
///
/// # Errors
///
/// Returns [`OracleError::Training`] as soon as any cross-validation run
/// fails.
pub fn tune_hyperparameters(
    corpus: &Corpus,
    n_trials: usize,
) -> Result<TuningResults> {
    let (samples, labels) = corpus.to_training_data();
    let sample_refs: Vec<&[f32]> = samples.iter().map(Vec::as_slice).collect();

    let space = SearchSpace::new()
        .add_int("n_estimators", 50, 200)
        .add_int("max_depth", 5, 20);
    let mut search = RandomSearch::new(space, 42);

    // Start from the default configuration so the result is well defined even
    // if no trial ever improves on a score of zero.
    let fallback = OracleConfig::default();
    let mut outcome = TuningResults {
        best_n_estimators: fallback.n_estimators,
        best_max_depth: fallback.max_depth,
        best_cv_score: 0.0,
        trials_run: n_trials,
    };

    for trial in 0..n_trials {
        let params = search.suggest();
        let n_estimators = params.get_int("n_estimators") as usize;
        let max_depth = params.get_int("max_depth") as usize;

        let mut classifier = RandomForestClassifier::new();
        classifier
            .set_n_estimators(n_estimators)
            .set_max_depth(max_depth)
            .set_random_state(42);

        let kfold = StratifiedKFold::new(3);
        let fold_scores = cross_validate(&classifier, &sample_refs, &labels, &kfold)
            .map_err(|e| OracleError::Training(e.to_string()))?;
        let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;

        // Strictly-greater keeps the earliest trial on ties.
        if mean_score > outcome.best_cv_score {
            outcome.best_cv_score = mean_score;
            outcome.best_n_estimators = n_estimators;
            outcome.best_max_depth = max_depth;
            tracing::info!(
                "Trial {}: New best score {:.4} (n_estimators={}, max_depth={})",
                trial, mean_score, n_estimators, max_depth
            );
        }

        // Feed the result back to the search regardless of whether it won.
        search.report(params, mean_score);
    }

    Ok(outcome)
}

Step 8: Model Evaluation Tests

// crates/decy-oracle/tests/model_evaluation.rs

use decy_oracle::{Corpus, Oracle, OracleConfig, evaluate_model};

/// The mean 5-fold cross-validation accuracy on a 1000-example synthetic
/// corpus must exceed 85% with the default configuration.
#[test]
fn test_accuracy_threshold() {
    let training_set = Corpus::generate_synthetic(1000);
    let settings = OracleConfig::default();

    let cv_accuracy = evaluate_model(&training_set, &settings, 5).unwrap();
    let percent = cv_accuracy * 100.0;

    assert!(cv_accuracy > 0.85, "5-fold CV accuracy {:.2}% must be >85%", percent);
}

/// Training on a 1000-example synthetic corpus with the default configuration
/// must finish in under one second of wall-clock time.
// NOTE(review): wall-clock assertions depend on build profile and machine
// load - confirm this is only expected to hold for optimized builds.
#[test]
fn test_training_time() {
    let training_set = Corpus::generate_synthetic(1000);
    let settings = OracleConfig::default();

    let timer = std::time::Instant::now();
    let _model = Oracle::train(&training_set, settings).unwrap();
    let seconds = timer.elapsed().as_secs_f64();

    assert!(seconds < 1.0, "Training time {:.2}s must be <1s", seconds);
}

/// A model trained on a 1000-example synthetic corpus must serialize to a
/// file smaller than 1,000,000 bytes.
#[test]
fn test_model_size() {
    let oracle =
        Oracle::train(&Corpus::generate_synthetic(1000), OracleConfig::default()).unwrap();

    // The temporary directory is removed when `scratch` goes out of scope.
    let scratch = tempfile::tempdir().unwrap();
    let model_path = scratch.path().join("test_model.bin");
    oracle.save(&model_path).unwrap();

    let size = std::fs::metadata(&model_path).unwrap().len();
    assert!(size < 1_000_000, "Model size {} bytes must be <1MB", size);
}

/// The per-fold scores of a 5-fold cross-validation must have a standard
/// deviation below 0.05.
///
/// The classifier is seeded so the check is reproducible, and the per-fold
/// scores are printed before the assertion so they are visible when it fails.
#[test]
fn test_cv_variance() {
    use aprender::model_selection::{cross_validate, KFold};
    use aprender::tree::RandomForestClassifier;

    let corpus = Corpus::generate_synthetic(1000);
    let (x, y) = corpus.to_training_data();
    let x_ref: Vec<&[f32]> = x.iter().map(|row| row.as_slice()).collect();

    let mut classifier = RandomForestClassifier::new();
    // Same hyperparameters and seed as `OracleConfig::default()`.
    classifier
        .set_n_estimators(100)
        .set_max_depth(10)
        .set_random_state(42);

    let kfold = KFold::new(5);
    let results = cross_validate(&classifier, &x_ref, &y, &kfold).unwrap();

    // Population standard deviation of the fold scores.
    let mean = results.iter().sum::<f64>() / results.len() as f64;
    let variance = results.iter()
        .map(|x| (x - mean).powi(2))
        .sum::<f64>() / results.len() as f64;
    let std_dev = variance.sqrt();

    println!("5-Fold CV Results:");
    for (i, score) in results.iter().enumerate() {
        println!("  Fold {}: {:.2}%", i + 1, score * 100.0);
    }
    println!("  Mean: {:.2}% +/- {:.2}%", mean * 100.0, std_dev * 100.0);

    assert!(
        std_dev < 0.05,
        "CV std dev {:.4} must be <0.05 (indicates overfitting)",
        std_dev
    );
}

Step 9: Reproduction Commands

Initial Setup

git clone https://github.com/paiml/decy.git
cd decy

# Verify GPU support
nvidia-smi

# Build with GPU support
cargo build --release -p decy-oracle --features gpu

Training

cargo run --release -p decy-oracle --example train -- --synthetic --count 5000

# Expected output:
# Training with 5000 synthetic examples...
# Training complete in 0.8s
# Model saved to models/oracle.bin (503 KB)

Evaluation

cargo test -p decy-oracle --test model_evaluation -- --nocapture

# Expected output:
# test_accuracy_threshold ... ok (accuracy: 97%+)
# test_training_time ... ok (0.8s)
# test_model_size ... ok (503 KB)
# test_cv_variance ... ok (std_dev: <0.05)

Validation Checklist

| Check | Command | Expected Result |
|---|---|---|
| Compiles | `cargo build -p decy-oracle` | No errors |
| Tests pass | `cargo test -p decy-oracle` | All green |
| Accuracy >90% | See test output | 97%+ |
| Training <1s | See test output | 0.8s |
| Model <1MB | `ls -lh models/oracle.bin` | 503 KB |
| CV variance <5% | See test output | <5% |
| Clippy clean | `cargo clippy -p decy-oracle -- -D warnings` | No warnings |

Integration with decy CLI

// In decy/src/main.rs or decy-core
use decy_oracle::{Oracle, OracleConfig};

fn handle_transpilation_error(error: &TranspilationError) {
    let oracle = Oracle::load_or_train(
        Path::new("models/oracle.bin"),
        &Corpus::generate_synthetic(1000),
        OracleConfig::default(),
    ).unwrap();

    let suggestion = oracle.suggest_fix(
        &error.phase,
        error.rust_error_code.as_deref(),
        &error.message,
        error.c_source.as_deref(),
    );

    eprintln!("Suggestion: {}", suggestion);
}

References


Dependencies (crates.io versions)

aprender = "0.10"   # ML library
trueno = "0.7"      # GPU backend (via aprender)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions