diff --git a/src/strategies/factory.py b/src/strategies/factory.py index 597447f..257576f 100644 --- a/src/strategies/factory.py +++ b/src/strategies/factory.py @@ -160,68 +160,63 @@ def _build_regime_mappings(self) -> Dict[int, List[StrategyType]]: """ Build mappings from market regimes to preferred strategies. - 8-regime market classification system: - 0: Deep Bear Market - 1: Bull Trending - 2: High Volatility Trending - 3: Uncertain/Transitional - 4: Low Volatility/Sideways - 5: Moderate Bull/Sideways - 6: Recovery Phase - 7: High Volatility Uncertain - 8: Extreme Volatility + Regime integers match RegimeType in src/data/regime_labeler.py, which + produces the training labels the regime detector learns. The mapping + below must stay aligned with that enum so that the strategy recommended + at inference matches the regime the model was trained to identify. + + 8-regime market classification system (RegimeType): + 0: BULL_TRENDING - Strong upward momentum + 1: BEAR_TRENDING - Strong downward momentum + 2: HIGH_VOLATILITY - Elevated volatility environment + 3: LOW_VOLATILITY - Subdued volatility environment + 4: SIDEWAYS_RANGING - Consolidation patterns + 5: RECOVERY - Post-decline bounce patterns + 6: DISTRIBUTION - Pre-decline weakening + 7: CRISIS - Extreme stress conditions Returns: Dictionary mapping regime numbers to preferred strategy types """ return { - 0: [ # Deep Bear Market - StrategyType.LONG_PUT, - StrategyType.BEAR_PUT_SPREAD, - StrategyType.BEAR_CALL_SPREAD, - StrategyType.LONG_STRADDLE, - ], - 1: [ # Bull Trending + 0: [ # BULL_TRENDING StrategyType.LONG_CALL, StrategyType.BULL_CALL_SPREAD, StrategyType.BULL_PUT_SPREAD, ], - 2: [ # High Volatility Trending - StrategyType.LONG_STRADDLE, - StrategyType.LONG_STRANGLE, - StrategyType.BULL_CALL_SPREAD, + 1: [ # BEAR_TRENDING + StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, + StrategyType.BEAR_CALL_SPREAD, ], - 3: [ # Uncertain/Transitional - StrategyType.IRON_CONDOR, - StrategyType.SHORT_STRANGLE, - StrategyType.BUTTERFLY, + 2: [ # HIGH_VOLATILITY + StrategyType.LONG_STRADDLE, + StrategyType.LONG_STRANGLE, ], - 4: [ # Low Volatility/Sideways + 3: [ # LOW_VOLATILITY StrategyType.CALENDAR_CALL, StrategyType.CALENDAR_PUT, StrategyType.SHORT_STRADDLE, - StrategyType.IRON_CONDOR, ], - 5: [ # Moderate Bull/Sideways - StrategyType.BULL_PUT_SPREAD, - StrategyType.SHORT_STRANGLE, + 4: [ # SIDEWAYS_RANGING StrategyType.IRON_CONDOR, + StrategyType.SHORT_STRANGLE, + StrategyType.BUTTERFLY, ], - 6: [ # Recovery Phase + 5: [ # RECOVERY StrategyType.BULL_CALL_SPREAD, StrategyType.LONG_CALL, StrategyType.BULL_PUT_SPREAD, ], - 7: [ # High Volatility Uncertain - StrategyType.LONG_STRADDLE, - StrategyType.LONG_STRANGLE, + 6: [ # DISTRIBUTION + StrategyType.BEAR_CALL_SPREAD, StrategyType.IRON_CONDOR, + StrategyType.SHORT_STRANGLE, ], - 8: [ # Extreme Volatility + 7: [ # CRISIS + StrategyType.LONG_PUT, StrategyType.LONG_STRADDLE, - StrategyType.LONG_STRANGLE, - StrategyType.SHORT_STRANGLE, # For mean reversion + StrategyType.BEAR_PUT_SPREAD, ], } @@ -346,15 +341,14 @@ def _assess_expected_performance(self, strategy: BaseStrategy, regime: int) -> s Performance category string """ regime_performance_map = { - 0: "defensive", # Deep Bear - 1: "growth", # Bull Trending - 2: "volatile", # High Vol Trending - 3: "neutral", # Uncertain - 4: "income", # Low Vol Sideways - 5: "moderate", # Moderate Bull - 6: "recovery", # Recovery Phase - 7: "adaptive", # High Vol Uncertain - 8: "speculative", # Extreme Volatility + 0: "growth", # BULL_TRENDING + 1: "defensive", # BEAR_TRENDING + 2: "volatile", # HIGH_VOLATILITY + 3: "income", # LOW_VOLATILITY + 4: "neutral", # SIDEWAYS_RANGING + 5: "recovery", # RECOVERY + 6: "defensive", # DISTRIBUTION + 7: "defensive", # CRISIS } return regime_performance_map.get(regime, "unknown") @@ -372,7 +366,7 @@ def _assess_risk_for_regime(self, strategy: BaseStrategy, regime: int) -> str: """ base_risk = strategy.risk_level.value - high_risk_regimes = [0, 2, 7, 8] # Volatile or extreme conditions + high_risk_regimes = [1, 2, 6, 7] # BEAR_TRENDING, HIGH_VOLATILITY, DISTRIBUTION, CRISIS if regime in high_risk_regimes: if base_risk in ["low", "medium"]: return "appropriate_risk" diff --git a/tests/strategies/test_factory.py b/tests/strategies/test_factory.py index 0a77c9d..11ca7b3 100644 --- a/tests/strategies/test_factory.py +++ b/tests/strategies/test_factory.py @@ -25,7 +25,7 @@ def factory(self): def bull_conditions(self): """Bullish market conditions.""" return MarketConditions( - regime=1, # Bull trending + regime=0, # BULL_TRENDING (RegimeType) volatility_rank=0.4, trend_strength=0.7, time_to_expiration=35, @@ -37,7 +37,7 @@ def bull_conditions(self): def bear_conditions(self): """Bearish market conditions.""" return MarketConditions( - regime=0, # Deep bear + regime=1, # BEAR_TRENDING (RegimeType) volatility_rank=0.7, trend_strength=-0.8, time_to_expiration=30, @@ -49,7 +49,7 @@ def bear_conditions(self): def low_vol_conditions(self): """Low volatility market conditions.""" return MarketConditions( - regime=4, # Low vol sideways + regime=3, # LOW_VOLATILITY (RegimeType) volatility_rank=0.2, trend_strength=0.1, time_to_expiration=40, @@ -116,7 +116,7 @@ def test_regime_mappings_coverage(self, factory): def test_bull_regime_recommendations(self, factory, bull_conditions): """Test recommendations for bullish market regime.""" - recommendations = factory.get_recommended_strategies(1, bull_conditions) + recommendations = factory.get_recommended_strategies(0, bull_conditions) assert len(recommendations) > 0 assert len(recommendations) <= 5 # Default max @@ -140,7 +140,7 @@ def test_bull_regime_recommendations(self, factory, bull_conditions): def test_bear_regime_recommendations(self, factory, bear_conditions): """Test recommendations for bearish market regime.""" - recommendations = factory.get_recommended_strategies(0, bear_conditions) + recommendations = factory.get_recommended_strategies(1, bear_conditions) assert len(recommendations) > 0 @@ -159,7 +159,7 @@ def test_bear_regime_recommendations(self, factory, bear_conditions): def test_low_volatility_recommendations(self, factory, low_vol_conditions): """Test recommendations for low volatility regime.""" - recommendations = factory.get_recommended_strategies(4, low_vol_conditions) + recommendations = factory.get_recommended_strategies(3, low_vol_conditions) assert len(recommendations) > 0 @@ -169,7 +169,6 @@ def test_low_volatility_recommendations(self, factory, low_vol_conditions): StrategyType.CALENDAR_CALL, StrategyType.CALENDAR_PUT, StrategyType.SHORT_STRADDLE, - StrategyType.IRON_CONDOR ] # At least some low vol strategies should be recommended @@ -295,16 +294,16 @@ def test_strategy_factory_performance(self, factory): def test_regime_strategy_alignment(self, factory): """Test that recommended strategies align with regime characteristics.""" - # Test specific regime-strategy alignments + # Test specific regime-strategy alignments (RegimeType numbering) - # Regime 0 (Deep Bear) should favor protective strategies - bear_recs = factory.get_recommended_strategies(0) + # Regime 1 (BEAR_TRENDING) should favor protective strategies + bear_recs = factory.get_recommended_strategies(1) bear_types = [rec.strategy_type for rec in bear_recs] - protective_strategies = [StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, StrategyType.LONG_STRADDLE] + protective_strategies = [StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, StrategyType.BEAR_CALL_SPREAD] assert any(st in bear_types for st in protective_strategies), "Bear regime should favor protective strategies" - # Regime 4 (Low Vol) should favor time decay strategies - low_vol_recs = factory.get_recommended_strategies(4) + # Regime 3 (LOW_VOLATILITY) should favor time decay strategies + low_vol_recs = factory.get_recommended_strategies(3) low_vol_types = [rec.strategy_type for rec in low_vol_recs] time_decay_strategies = [StrategyType.CALENDAR_CALL, StrategyType.CALENDAR_PUT, StrategyType.SHORT_STRADDLE] assert any(st in low_vol_types for st in time_decay_strategies), "Low vol regime should favor time decay" @@ -345,6 +344,47 @@ def test_recommendation_consistency(self, factory): assert rec1.strategy_type == rec2.strategy_type assert abs(rec1.confidence - rec2.confidence) < 0.001 # Should be very close + def test_regime_mappings_match_labeler_numbering(self, factory): + """Guard against regime/strategy inversion (issue #15). + + The factory's regime integers must match RegimeType in + src/data/regime_labeler.py, which produces the training labels the + regime detector learns. RegimeType is not imported here directly + because src.data.regime_labeler currently fails to import (issue #1), + so the canonical numbering is asserted inline as the contract: + + 0: BULL_TRENDING 5: RECOVERY + 1: BEAR_TRENDING 6: DISTRIBUTION + 2: HIGH_VOLATILITY 7: CRISIS + 3: LOW_VOLATILITY + 4: SIDEWAYS_RANGING + + If this test fails, the factory mapping and the labeler have diverged + and the model will recommend strategies for the wrong regime. + """ + bullish = {StrategyType.LONG_CALL, StrategyType.BULL_CALL_SPREAD, StrategyType.BULL_PUT_SPREAD} + bearish = {StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, StrategyType.BEAR_CALL_SPREAD} + + # Regime 0 is BULL_TRENDING: must lean bullish, must not be net bearish. + bull_types = {rec.strategy_type for rec in factory.get_recommended_strategies(0)} + assert bull_types & bullish, "Regime 0 (BULL_TRENDING) must recommend bullish strategies" + assert not (bull_types & bearish), "Regime 0 (BULL_TRENDING) must not recommend bearish strategies" + + # Regime 1 is BEAR_TRENDING: must lean bearish, must not be net bullish. + bear_types = {rec.strategy_type for rec in factory.get_recommended_strategies(1)} + assert bear_types & bearish, "Regime 1 (BEAR_TRENDING) must recommend bearish strategies" + assert not (bear_types & bullish), "Regime 1 (BEAR_TRENDING) must not recommend bullish strategies" + + # Regime 7 is CRISIS: must include protective strategies. + crisis_types = {rec.strategy_type for rec in factory.get_recommended_strategies(7)} + assert crisis_types & {StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, StrategyType.LONG_STRADDLE}, \ + "Regime 7 (CRISIS) must recommend protective strategies" + + # Only regimes 0-7 exist; the detector outputs 8 regimes, so regime 8 + # must not be present (it would be unreachable at inference). + assert set(factory._regime_mappings.keys()) == set(range(8)), \ + "Factory must define exactly regimes 0-7 to match an 8-class detector" + def test_comprehensive_factory_validation(self, factory): """Comprehensive validation of factory functionality.""" # High-level test ensuring factory meets all requirements diff --git a/tests/strategies/test_integration.py b/tests/strategies/test_integration.py index f7bcac0..4539e9c 100644 --- a/tests/strategies/test_integration.py +++ b/tests/strategies/test_integration.py @@ -119,7 +119,7 @@ def test_regime_based_recommendations(self, factory, sample_conditions): def test_recommendations_ranking(self, factory): """Test recommendations are properly ranked by confidence.""" conditions = MarketConditions( - regime=1, # Bull trending + regime=1, # BEAR_TRENDING (RegimeType) volatility_rank=0.4, trend_strength=0.6, time_to_expiration=30, @@ -212,10 +212,10 @@ def test_market_condition_validation(self, factory): """Test market condition validation across strategies.""" # Test various market conditions test_conditions = [ - MarketConditions(0, 0.8, -0.8, 30, 10000, 0.05), # Deep bear, high vol - MarketConditions(1, 0.4, 0.6, 30, 10000, 0.05), # Bull trending - MarketConditions(4, 0.2, 0.1, 30, 10000, 0.05), # Low vol sideways - MarketConditions(7, 0.9, 0.0, 30, 10000, 0.05), # High vol uncertain + MarketConditions(0, 0.8, 0.8, 30, 10000, 0.05), # BULL_TRENDING + MarketConditions(1, 0.7, -0.6, 30, 10000, 0.05), # BEAR_TRENDING + MarketConditions(4, 0.2, 0.1, 30, 10000, 0.05), # SIDEWAYS_RANGING + MarketConditions(7, 0.9, -0.8, 30, 10000, 0.05), # CRISIS ] strategies = factory.get_all_strategies()[:5] # Test subset of strategies