Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 41 additions & 47 deletions src/strategies/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,68 +160,63 @@ def _build_regime_mappings(self) -> Dict[int, List[StrategyType]]:
"""
Build mappings from market regimes to preferred strategies.

8-regime market classification system:
0: Deep Bear Market
1: Bull Trending
2: High Volatility Trending
3: Uncertain/Transitional
4: Low Volatility/Sideways
5: Moderate Bull/Sideways
6: Recovery Phase
7: High Volatility Uncertain
8: Extreme Volatility
Regime integers match RegimeType in src/data/regime_labeler.py, which
produces the training labels the regime detector learns. The mapping
below must stay aligned with that enum so that the strategy recommended
at inference matches the regime the model was trained to identify.

8-regime market classification system (RegimeType):
0: BULL_TRENDING - Strong upward momentum
1: BEAR_TRENDING - Strong downward momentum
2: HIGH_VOLATILITY - Elevated volatility environment
3: LOW_VOLATILITY - Subdued volatility environment
4: SIDEWAYS_RANGING - Consolidation patterns
5: RECOVERY - Post-decline bounce patterns
6: DISTRIBUTION - Pre-decline weakening
7: CRISIS - Extreme stress conditions

Returns:
Dictionary mapping regime numbers to preferred strategy types
"""
return {
0: [ # Deep Bear Market
StrategyType.LONG_PUT,
StrategyType.BEAR_PUT_SPREAD,
StrategyType.BEAR_CALL_SPREAD,
StrategyType.LONG_STRADDLE,
],
1: [ # Bull Trending
0: [ # BULL_TRENDING
StrategyType.LONG_CALL,
StrategyType.BULL_CALL_SPREAD,
StrategyType.BULL_PUT_SPREAD,
],
2: [ # High Volatility Trending
StrategyType.LONG_STRADDLE,
StrategyType.LONG_STRANGLE,
StrategyType.BULL_CALL_SPREAD,
1: [ # BEAR_TRENDING
StrategyType.LONG_PUT,
StrategyType.BEAR_PUT_SPREAD,
StrategyType.BEAR_CALL_SPREAD,
],
3: [ # Uncertain/Transitional
StrategyType.IRON_CONDOR,
StrategyType.SHORT_STRANGLE,
StrategyType.BUTTERFLY,
2: [ # HIGH_VOLATILITY
StrategyType.LONG_STRADDLE,
StrategyType.LONG_STRANGLE,
],
4: [ # Low Volatility/Sideways
3: [ # LOW_VOLATILITY
StrategyType.CALENDAR_CALL,
StrategyType.CALENDAR_PUT,
StrategyType.SHORT_STRADDLE,
StrategyType.IRON_CONDOR,
],
5: [ # Moderate Bull/Sideways
StrategyType.BULL_PUT_SPREAD,
StrategyType.SHORT_STRANGLE,
4: [ # SIDEWAYS_RANGING
StrategyType.IRON_CONDOR,
StrategyType.SHORT_STRANGLE,
StrategyType.BUTTERFLY,
],
6: [ # Recovery Phase
5: [ # RECOVERY
StrategyType.BULL_CALL_SPREAD,
StrategyType.LONG_CALL,
StrategyType.BULL_PUT_SPREAD,
],
7: [ # High Volatility Uncertain
StrategyType.LONG_STRADDLE,
StrategyType.LONG_STRANGLE,
6: [ # DISTRIBUTION
StrategyType.BEAR_CALL_SPREAD,
StrategyType.IRON_CONDOR,
StrategyType.SHORT_STRANGLE,
],
8: [ # Extreme Volatility
7: [ # CRISIS
StrategyType.LONG_PUT,
StrategyType.LONG_STRADDLE,
StrategyType.LONG_STRANGLE,
StrategyType.SHORT_STRANGLE, # For mean reversion
StrategyType.BEAR_PUT_SPREAD,
],
}

Expand Down Expand Up @@ -346,15 +341,14 @@ def _assess_expected_performance(self, strategy: BaseStrategy, regime: int) -> s
Performance category string
"""
regime_performance_map = {
0: "defensive", # Deep Bear
1: "growth", # Bull Trending
2: "volatile", # High Vol Trending
3: "neutral", # Uncertain
4: "income", # Low Vol Sideways
5: "moderate", # Moderate Bull
6: "recovery", # Recovery Phase
7: "adaptive", # High Vol Uncertain
8: "speculative", # Extreme Volatility
0: "growth", # BULL_TRENDING
1: "defensive", # BEAR_TRENDING
2: "volatile", # HIGH_VOLATILITY
3: "income", # LOW_VOLATILITY
4: "neutral", # SIDEWAYS_RANGING
5: "recovery", # RECOVERY
6: "defensive", # DISTRIBUTION
7: "defensive", # CRISIS
}

return regime_performance_map.get(regime, "unknown")
Expand All @@ -372,7 +366,7 @@ def _assess_risk_for_regime(self, strategy: BaseStrategy, regime: int) -> str:
"""
base_risk = strategy.risk_level.value

high_risk_regimes = [0, 2, 7, 8] # Volatile or extreme conditions
high_risk_regimes = [1, 2, 6, 7] # BEAR_TRENDING, HIGH_VOLATILITY, DISTRIBUTION, CRISIS
if regime in high_risk_regimes:
if base_risk in ["low", "medium"]:
return "appropriate_risk"
Expand Down
66 changes: 53 additions & 13 deletions tests/strategies/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def factory(self):
def bull_conditions(self):
"""Bullish market conditions."""
return MarketConditions(
regime=1, # Bull trending
regime=0, # BULL_TRENDING (RegimeType)
volatility_rank=0.4,
trend_strength=0.7,
time_to_expiration=35,
Expand All @@ -37,7 +37,7 @@ def bull_conditions(self):
def bear_conditions(self):
"""Bearish market conditions."""
return MarketConditions(
regime=0, # Deep bear
regime=1, # BEAR_TRENDING (RegimeType)
volatility_rank=0.7,
trend_strength=-0.8,
time_to_expiration=30,
Expand All @@ -49,7 +49,7 @@ def bear_conditions(self):
def low_vol_conditions(self):
"""Low volatility market conditions."""
return MarketConditions(
regime=4, # Low vol sideways
regime=3, # LOW_VOLATILITY (RegimeType)
volatility_rank=0.2,
trend_strength=0.1,
time_to_expiration=40,
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_regime_mappings_coverage(self, factory):

def test_bull_regime_recommendations(self, factory, bull_conditions):
"""Test recommendations for bullish market regime."""
recommendations = factory.get_recommended_strategies(1, bull_conditions)
recommendations = factory.get_recommended_strategies(0, bull_conditions)

assert len(recommendations) > 0
assert len(recommendations) <= 5 # Default max
Expand All @@ -140,7 +140,7 @@ def test_bull_regime_recommendations(self, factory, bull_conditions):

def test_bear_regime_recommendations(self, factory, bear_conditions):
"""Test recommendations for bearish market regime."""
recommendations = factory.get_recommended_strategies(0, bear_conditions)
recommendations = factory.get_recommended_strategies(1, bear_conditions)

assert len(recommendations) > 0

Expand All @@ -159,7 +159,7 @@ def test_bear_regime_recommendations(self, factory, bear_conditions):

def test_low_volatility_recommendations(self, factory, low_vol_conditions):
"""Test recommendations for low volatility regime."""
recommendations = factory.get_recommended_strategies(4, low_vol_conditions)
recommendations = factory.get_recommended_strategies(3, low_vol_conditions)

assert len(recommendations) > 0

Expand All @@ -169,7 +169,6 @@ def test_low_volatility_recommendations(self, factory, low_vol_conditions):
StrategyType.CALENDAR_CALL,
StrategyType.CALENDAR_PUT,
StrategyType.SHORT_STRADDLE,
StrategyType.IRON_CONDOR
]

# At least some low vol strategies should be recommended
Expand Down Expand Up @@ -295,16 +294,16 @@ def test_strategy_factory_performance(self, factory):

def test_regime_strategy_alignment(self, factory):
"""Test that recommended strategies align with regime characteristics."""
# Test specific regime-strategy alignments
# Test specific regime-strategy alignments (RegimeType numbering)

# Regime 0 (Deep Bear) should favor protective strategies
bear_recs = factory.get_recommended_strategies(0)
# Regime 1 (BEAR_TRENDING) should favor protective strategies
bear_recs = factory.get_recommended_strategies(1)
bear_types = [rec.strategy_type for rec in bear_recs]
protective_strategies = [StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, StrategyType.LONG_STRADDLE]
protective_strategies = [StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, StrategyType.BEAR_CALL_SPREAD]
assert any(st in bear_types for st in protective_strategies), "Bear regime should favor protective strategies"

# Regime 4 (Low Vol) should favor time decay strategies
low_vol_recs = factory.get_recommended_strategies(4)
# Regime 3 (LOW_VOLATILITY) should favor time decay strategies
low_vol_recs = factory.get_recommended_strategies(3)
low_vol_types = [rec.strategy_type for rec in low_vol_recs]
time_decay_strategies = [StrategyType.CALENDAR_CALL, StrategyType.CALENDAR_PUT, StrategyType.SHORT_STRADDLE]
assert any(st in low_vol_types for st in time_decay_strategies), "Low vol regime should favor time decay"
Expand Down Expand Up @@ -345,6 +344,47 @@ def test_recommendation_consistency(self, factory):
assert rec1.strategy_type == rec2.strategy_type
assert abs(rec1.confidence - rec2.confidence) < 0.001 # Should be very close

def test_regime_mappings_match_labeler_numbering(self, factory):
"""Guard against regime/strategy inversion (issue #15).

The factory's regime integers must match RegimeType in
src/data/regime_labeler.py, which produces the training labels the
regime detector learns. RegimeType is not imported here directly
because src.data.regime_labeler currently fails to import (issue #1),
so the canonical numbering is asserted inline as the contract:

0: BULL_TRENDING 5: RECOVERY
1: BEAR_TRENDING 6: DISTRIBUTION
2: HIGH_VOLATILITY 7: CRISIS
3: LOW_VOLATILITY
4: SIDEWAYS_RANGING

If this test fails, the factory mapping and the labeler have diverged
and the model will recommend strategies for the wrong regime.
"""
bullish = {StrategyType.LONG_CALL, StrategyType.BULL_CALL_SPREAD, StrategyType.BULL_PUT_SPREAD}
bearish = {StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, StrategyType.BEAR_CALL_SPREAD}

# Regime 0 is BULL_TRENDING: must lean bullish, must not be net bearish.
bull_types = {rec.strategy_type for rec in factory.get_recommended_strategies(0)}
assert bull_types & bullish, "Regime 0 (BULL_TRENDING) must recommend bullish strategies"
assert not (bull_types & bearish), "Regime 0 (BULL_TRENDING) must not recommend bearish strategies"

# Regime 1 is BEAR_TRENDING: must lean bearish, must not be net bullish.
bear_types = {rec.strategy_type for rec in factory.get_recommended_strategies(1)}
assert bear_types & bearish, "Regime 1 (BEAR_TRENDING) must recommend bearish strategies"
assert not (bear_types & bullish), "Regime 1 (BEAR_TRENDING) must not recommend bullish strategies"

# Regime 7 is CRISIS: must include protective strategies.
crisis_types = {rec.strategy_type for rec in factory.get_recommended_strategies(7)}
assert crisis_types & {StrategyType.LONG_PUT, StrategyType.BEAR_PUT_SPREAD, StrategyType.LONG_STRADDLE}, \
"Regime 7 (CRISIS) must recommend protective strategies"

# Only regimes 0-7 exist; the detector outputs 8 regimes, so regime 8
# must not be present (it would be unreachable at inference).
assert set(factory._regime_mappings.keys()) == set(range(8)), \
"Factory must define exactly regimes 0-7 to match an 8-class detector"

def test_comprehensive_factory_validation(self, factory):
"""Comprehensive validation of factory functionality."""
# High-level test ensuring factory meets all requirements
Expand Down
10 changes: 5 additions & 5 deletions tests/strategies/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_regime_based_recommendations(self, factory, sample_conditions):
def test_recommendations_ranking(self, factory):
"""Test recommendations are properly ranked by confidence."""
conditions = MarketConditions(
regime=1, # Bull trending
regime=1, # BEAR_TRENDING (RegimeType)
volatility_rank=0.4,
trend_strength=0.6,
time_to_expiration=30,
Expand Down Expand Up @@ -212,10 +212,10 @@ def test_market_condition_validation(self, factory):
"""Test market condition validation across strategies."""
# Test various market conditions
test_conditions = [
MarketConditions(0, 0.8, -0.8, 30, 10000, 0.05), # Deep bear, high vol
MarketConditions(1, 0.4, 0.6, 30, 10000, 0.05), # Bull trending
MarketConditions(4, 0.2, 0.1, 30, 10000, 0.05), # Low vol sideways
MarketConditions(7, 0.9, 0.0, 30, 10000, 0.05), # High vol uncertain
MarketConditions(0, 0.8, 0.8, 30, 10000, 0.05), # BULL_TRENDING
MarketConditions(1, 0.7, -0.6, 30, 10000, 0.05), # BEAR_TRENDING
MarketConditions(4, 0.2, 0.1, 30, 10000, 0.05), # SIDEWAYS_RANGING
MarketConditions(7, 0.9, -0.8, 30, 10000, 0.05), # CRISIS
]

strategies = factory.get_all_strategies()[:5] # Test subset of strategies
Expand Down