ML4T Diagnostic
ML4T Diagnostic Documentation
Feature validation, strategy diagnostics, and Deflated Sharpe Ratio

API Reference

This reference is organized by import surface rather than by source tree alone.

Use case | Import surface
Stable application code | ml4t.diagnostic.api
Notebook and exploratory work | ml4t.diagnostic
Metrics and feature statistics | ml4t.diagnostic.metrics
Statistical primitives | ml4t.diagnostic.evaluation.stats
Splitters and fold persistence | ml4t.diagnostic.splitters
Signal analysis | ml4t.diagnostic.signal
Backtest bridges | ml4t.diagnostic.integration
Plotly figures and dashboards | ml4t.diagnostic.visualization

Stable API (ml4t.diagnostic.api)

Use this module when you want imports that are less sensitive to future re-export cleanup at the package root.

Category | Objects
Validation workflows | ValidatedCrossValidation, ValidatedCrossValidationConfig, validated_cross_val_score, ValidationResult, ValidationFoldResult
Diagnostics | FeatureDiagnostics, FeatureDiagnosticsResult, TradeAnalysis, PortfolioAnalysis, BarrierAnalysis
Signal analysis | analyze_signal, SignalResult
Splitters | CombinatorialCV, WalkForwardCV
Metrics | cross_sectional_ic_series, cross_sectional_ic, pooled_ic, compute_ic_hac_stats, compute_mdi_importance, compute_permutation_importance, compute_shap_importance, compute_h_statistic, compute_shap_interactions, analyze_ml_importance, analyze_interactions

Package-Level Convenience API (ml4t.diagnostic)

The package root re-exports the most common classes and configs for interactive use:

Category | Objects
Core workflows | ValidatedCrossValidation, FeatureSelector, analyze_signal, BarrierAnalysis
Result types | SignalResult, data-quality schemas
Configuration | DiagnosticConfig, StatisticalConfig, PortfolioConfig, TradeConfig, SignalConfig, EventConfig, BarrierConfig, ReportConfig, RuntimeConfig
Optional visuals | selected barrier-analysis plot functions when viz dependencies are installed

Configuration

config

ML4T Diagnostic Configuration System.

This module provides comprehensive Pydantic v2 configuration schemas for the ML4T Diagnostic framework, covering:

  • Feature Evaluation: Diagnostics, cross-feature analysis, feature-outcome relationships
  • Portfolio Evaluation: Risk/return metrics, Bayesian comparison
  • Statistical Framework: PSR, MinTRL, DSR, FDR for multiple testing correction
  • Reporting: HTML, JSON, visualization settings

Examples:

Quick start with defaults:

>>> from ml4t.diagnostic.config import DiagnosticConfig
>>> config = DiagnosticConfig()

Custom configuration:

>>> config = DiagnosticConfig(
...     stationarity=StationaritySettings(significance_level=0.01),
...     ic=ICSettings(lag_structure=[0, 1, 5, 10]),
... )

Load from YAML:

>>> config = DiagnosticConfig.from_yaml("config.yaml")

Use presets:

>>> config = DiagnosticConfig.for_quick_analysis()

DiagnosticConfig

Bases: BaseConfig

Consolidated configuration for feature analysis (single-level nesting).

Provides comprehensive feature diagnostics with direct access to all settings:

  • config.stationarity.enabled (not config.module_a.stationarity.enabled)
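To make the single-level nesting concrete, here is a minimal standalone sketch that uses standard-library dataclasses instead of the real Pydantic models. The class names ending in `Sketch` and all default values are assumptions for illustration only; they are not the library's classes or defaults.

```python
from dataclasses import dataclass, field


@dataclass
class StationaritySketch:
    """Hypothetical stand-in for StationaritySettings (defaults are assumed)."""
    enabled: bool = True
    significance_level: float = 0.05


@dataclass
class ICSketch:
    """Hypothetical stand-in for ICSettings (defaults are assumed)."""
    lag_structure: list = field(default_factory=lambda: [0, 1, 5])


@dataclass
class DiagnosticSketch:
    """Hypothetical stand-in for DiagnosticConfig: each settings group sits one level down."""
    stationarity: StationaritySketch = field(default_factory=StationaritySketch)
    ic: ICSketch = field(default_factory=ICSketch)


# Override one group; every other group keeps its defaults.
config = DiagnosticSketch(stationarity=StationaritySketch(significance_level=0.01))

print(config.stationarity.enabled)             # one attribute hop reaches a group setting
print(config.stationarity.significance_level)
print(config.ic.lag_structure)
```

Each settings group is reachable in a single attribute hop, which is the access pattern the bullet above describes.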

Examples

>>> config = DiagnosticConfig(
...     stationarity=StationaritySettings(significance_level=0.01),
...     ic=ICSettings(lag_structure=[0, 1, 5, 10, 21]),
... )
>>> config.to_yaml("diagnostic_config.yaml")

for_quick_analysis classmethod

for_quick_analysis()

Preset for quick exploratory analysis.

Source code in src/ml4t/diagnostic/config/feature_config.py
@classmethod
def for_quick_analysis(cls) -> DiagnosticConfig:
    """Preset for quick exploratory analysis."""
    return cls(
        stationarity=StationaritySettings(pp_enabled=False),
        volatility=VolatilitySettings(detect_clustering=False),
        distribution=DistributionSettings(detect_outliers=False),
        correlation=CorrelationSettings(lag_correlations=False),
        pca=PCASettings(enabled=False),
        clustering=ClusteringSettings(enabled=False),
        ic=ICSettings(hac_adjustment=False, compute_decay=False),
        ml_diagnostics=MLDiagnosticsSettings(shap_analysis=False, drift_detection=False),
    )

for_research classmethod

for_research()

Preset for academic research (comprehensive).

Source code in src/ml4t/diagnostic/config/feature_config.py
@classmethod
def for_research(cls) -> DiagnosticConfig:
    """Preset for academic research (comprehensive)."""
    return cls(
        stationarity=StationaritySettings(pp_enabled=True),
        volatility=VolatilitySettings(window_sizes=[10, 21, 63]),
        distribution=DistributionSettings(
            detect_outliers=True,
            normality_tests=[
                NormalityTest.JARQUE_BERA,
                NormalityTest.SHAPIRO,
                NormalityTest.ANDERSON,
            ],
        ),
        correlation=CorrelationSettings(
            methods=[
                CorrelationMethod.PEARSON,
                CorrelationMethod.SPEARMAN,
                CorrelationMethod.KENDALL,
            ],
            lag_correlations=True,
        ),
        pca=PCASettings(enabled=True),
        clustering=ClusteringSettings(enabled=True),
        ic=ICSettings(lag_structure=[0, 1, 5, 10, 21], hac_adjustment=True, compute_decay=True),
        binary_classification=BinaryClassificationSettings(enabled=True),
        threshold_analysis=ThresholdAnalysisSettings(enabled=True),
        ml_diagnostics=MLDiagnosticsSettings(shap_analysis=True, drift_detection=True),
    )

for_production classmethod

for_production()

Preset for production monitoring (fast, focused on drift).

Source code in src/ml4t/diagnostic/config/feature_config.py
@classmethod
def for_production(cls) -> DiagnosticConfig:
    """Preset for production monitoring (fast, focused on drift)."""
    return cls(
        stationarity=StationaritySettings(pp_enabled=False),
        acf=ACFSettings(enabled=False),
        volatility=VolatilitySettings(enabled=False),
        distribution=DistributionSettings(test_normality=False, compute_moments=True),
        correlation=CorrelationSettings(lag_correlations=False),
        pca=PCASettings(enabled=False),
        clustering=ClusteringSettings(enabled=False),
        ic=ICSettings(compute_decay=False),
        ml_diagnostics=MLDiagnosticsSettings(
            feature_importance=True, drift_detection=True, drift_window=21
        ),
    )
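The presets above share one pattern: a classmethod that builds the config with a few settings groups overridden. The sketch below mirrors that pattern with plain dataclasses and shows one way to list what a preset changes relative to the defaults. The `Sketch` classes, their defaults, and the helpers `flatten` and `changed_settings` are hypothetical. With the real Pydantic v2 models, `model_dump()` would presumably take the place of `asdict`.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class PCASketch:
    enabled: bool = True  # assumed default, for illustration


@dataclass
class ICSketch:
    hac_adjustment: bool = True  # assumed default, for illustration
    compute_decay: bool = True   # assumed default, for illustration


@dataclass
class ConfigSketch:
    pca: PCASketch = field(default_factory=PCASketch)
    ic: ICSketch = field(default_factory=ICSketch)

    @classmethod
    def for_quick_analysis(cls) -> "ConfigSketch":
        """Preset: switch off the expensive pieces, keep everything else."""
        return cls(
            pca=PCASketch(enabled=False),
            ic=ICSketch(hac_adjustment=False, compute_decay=False),
        )


def flatten(nested: dict, prefix: str = "") -> dict:
    """Turn {'ic': {'compute_decay': True}} into {'ic.compute_decay': True}."""
    flat = {}
    for key, value in nested.items():
        dotted = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, dotted + "."))
        else:
            flat[dotted] = value
    return flat


def changed_settings(default, preset) -> dict:
    """Map each differing setting to a (default, preset) pair."""
    a, b = flatten(asdict(default)), flatten(asdict(preset))
    return {key: (a[key], b[key]) for key in a if a[key] != b[key]}


changes = changed_settings(ConfigSketch(), ConfigSketch.for_quick_analysis())
for key, (before, after) in changes.items():
    print(f"{key}: {before} -> {after}")
```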

StatisticalConfig

Bases: BaseConfig

Consolidated configuration for statistical testing.

Orchestrates advanced Sharpe ratio analysis with multiple testing correction.

Examples

>>> config = StatisticalConfig(
...     psr=PSRSettings(target_sharpe=1.0),
...     dsr=DSRSettings(n_trials=500),
... )

Or use presets:

>>> config = StatisticalConfig.for_research()

for_quick_check classmethod

for_quick_check()

Preset for quick overfitting check (PSR + DSR only).

Source code in src/ml4t/diagnostic/config/sharpe_config.py
@classmethod
def for_quick_check(cls) -> StatisticalConfig:
    """Preset for quick overfitting check (PSR + DSR only)."""
    return cls(
        psr=PSRSettings(compute_for_thresholds=None),
        mintrl=MinTRLSettings(enabled=False),
        dsr=DSRSettings(n_trials=100),
        fdr=FDRSettings(enabled=False),
    )
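As background on what the PSR settings parameterize, the Probabilistic Sharpe Ratio is commonly defined as the probability that the true Sharpe ratio exceeds a target, given the observed Sharpe ratio, the sample length, and the skewness and kurtosis of returns. The sketch below is a standalone illustration of that textbook formula, not this library's implementation. The function name and arguments are hypothetical, and Sharpe ratios are assumed to be per-period (not annualized).

```python
import math
from statistics import NormalDist


def probabilistic_sharpe_ratio(
    sr_hat: float,       # observed per-period Sharpe ratio
    sr_target: float,    # benchmark Sharpe ratio to beat
    n_obs: int,          # number of return observations
    skew: float = 0.0,   # sample skewness of returns
    kurtosis: float = 3.0,  # sample (non-excess) kurtosis; 3.0 for normal returns
) -> float:
    """P(true Sharpe > sr_target) under the usual asymptotic approximation."""
    # Standard error term, widened by negative skew and fat tails.
    denom = math.sqrt(1.0 - skew * sr_hat + (kurtosis - 1.0) / 4.0 * sr_hat**2)
    z = (sr_hat - sr_target) * math.sqrt(n_obs - 1) / denom
    return NormalDist().cdf(z)


psr_at_target = probabilistic_sharpe_ratio(0.1, 0.1, 100)
psr_normal = probabilistic_sharpe_ratio(0.1, 0.0, 253)
psr_fat_tails = probabilistic_sharpe_ratio(0.1, 0.0, 253, skew=-1.0, kurtosis=6.0)

print(f"observed == target: {psr_at_target:.3f}")
print(f"normal returns:     {psr_normal:.3f}")
print(f"skewed, fat tails:  {psr_fat_tails:.3f}")
```

Negative skew and excess kurtosis lower the probability for the same observed Sharpe ratio, which is why thresholds and confidence levels are configured together.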

for_research classmethod

for_research()

Preset for academic research (comprehensive analysis).

Source code in src/ml4t/diagnostic/config/sharpe_config.py
@classmethod
def for_research(cls) -> StatisticalConfig:
    """Preset for academic research (comprehensive analysis)."""
    return cls(
        psr=PSRSettings(
            compute_for_thresholds=[0.0, 0.5, 1.0, 1.5, 2.0],
            confidence_level=0.99,
        ),
        mintrl=MinTRLSettings(compute_for_thresholds=[0.0, 0.5, 1.0]),
        dsr=DSRSettings(n_trials=500, prob_zero_sharpe=0.5),
        fdr=FDRSettings(
            method=FDRMethod.BENJAMINI_YEKUTIELI,
            alpha=0.05,
        ),
    )
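For orientation on the multiple-testing methods named in these presets, the sketch below contrasts a Bonferroni correction with the Benjamini-Yekutieli step-up procedure on a handful of p-values. It is a standalone illustration of the standard procedures, not the library's code. The function names and the p-values are made up for the example.

```python
def bonferroni_reject(p_values: list, alpha: float) -> list:
    """Reject H0_i when p_i <= alpha / m (controls the family-wise error rate)."""
    m = len(p_values)
    return [p <= alpha / m for p in p_values]


def benjamini_yekutieli_reject(p_values: list, alpha: float) -> list:
    """Step-up FDR control that remains valid under arbitrary dependence."""
    m = len(p_values)
    c_m = sum(1.0 / i for i in range(1, m + 1))  # harmonic correction factor
    order = sorted(range(m), key=lambda i: p_values[i])

    # Find the largest rank k whose sorted p-value clears its threshold.
    k_max = 0
    for rank, idx in enumerate(order, start=1):
        if p_values[idx] <= rank * alpha / (m * c_m):
            k_max = rank

    # Reject every hypothesis ranked at or below k_max.
    reject = [False] * m
    for idx in order[:k_max]:
        reject[idx] = True
    return reject


p_values = [0.04, 0.001, 0.2, 0.012, 0.008]  # illustrative values only
bonf = bonferroni_reject(p_values, alpha=0.05)
by = benjamini_yekutieli_reject(p_values, alpha=0.05)

print("Bonferroni rejects:", sum(bonf))
print("BY rejects:        ", sum(by))
```

On this input Bonferroni rejects two hypotheses and Benjamini-Yekutieli rejects three, which illustrates why a Bonferroni setting is the more conservative choice.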

for_publication classmethod

for_publication()

Preset for academic publication (very conservative).

Source code in src/ml4t/diagnostic/config/sharpe_config.py
@classmethod
def for_publication(cls) -> StatisticalConfig:
    """Preset for academic publication (very conservative)."""
    return cls(
        psr=PSRSettings(confidence_level=0.99, target_sharpe=0.5),
        mintrl=MinTRLSettings(confidence_level=0.99, target_sharpe=0.5),
        dsr=DSRSettings(
            n_trials=1000,
            prob_zero_sharpe=0.8,
            variance_inflation=1.5,
        ),
        fdr=FDRSettings(
            method=FDRMethod.BONFERRONI,
            alpha=0.01,
        ),
    )
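To see why `n_trials` matters for the Deflated Sharpe Ratio, it helps to look at the benchmark that DSR deflates against: the Sharpe ratio one would expect from the best of N unskilled trials. The sketch below uses the commonly cited approximation for that expected maximum. It is a standalone illustration; the function name and the variance value are assumptions, and it does not model how this library applies `prob_zero_sharpe` or `variance_inflation`.

```python
import math
from statistics import NormalDist

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant


def expected_max_sharpe(n_trials: int, sharpe_variance: float) -> float:
    """Approximate E[max Sharpe] across n_trials (> 1) strategies with zero true Sharpe."""
    nd = NormalDist()
    a = nd.inv_cdf(1.0 - 1.0 / n_trials)
    b = nd.inv_cdf(1.0 - 1.0 / (n_trials * math.e))
    return math.sqrt(sharpe_variance) * ((1.0 - EULER_GAMMA) * a + EULER_GAMMA * b)


# Assumed cross-trial variance of per-period Sharpe estimates (illustrative).
variance = 0.01
benchmarks = {n: expected_max_sharpe(n, variance) for n in (100, 500, 1000)}

for n, value in benchmarks.items():
    print(f"n_trials={n:>4}: expected max Sharpe ~ {value:.3f}")
```

The benchmark rises with the number of trials, so a preset with a larger `n_trials` demands a higher observed Sharpe ratio before the result counts as significant.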

PortfolioConfig

Bases: BaseConfig

Consolidated configuration for portfolio evaluation.

Orchestrates portfolio performance analysis with metrics, Bayesian comparison, time aggregation, and drawdown analysis.

Examples

>>> config = PortfolioConfig(
...     metrics=MetricsSettings(risk_free_rate=0.02),
...     bayesian=BayesianSettings(enabled=True),
... )
>>> config.to_yaml("portfolio_config.yaml")

for_quick_analysis classmethod

for_quick_analysis()

Preset for quick exploratory analysis.

Source code in src/ml4t/diagnostic/config/portfolio_config.py
@classmethod
def for_quick_analysis(cls) -> PortfolioConfig:
    """Preset for quick exploratory analysis."""
    return cls(
        metrics=MetricsSettings(metrics=[PortfolioMetric.SHARPE, PortfolioMetric.MAX_DRAWDOWN]),
        bayesian=BayesianSettings(enabled=False),
        aggregation=AggregationSettings(compute_rolling=False),
        drawdown=DrawdownSettings(compute_recovery_time=False),
    )

for_research classmethod

for_research()

Preset for academic research.

Source code in src/ml4t/diagnostic/config/portfolio_config.py
@classmethod
def for_research(cls) -> PortfolioConfig:
    """Preset for academic research."""
    return cls(
        metrics=MetricsSettings(
            metrics=[
                PortfolioMetric.SHARPE,
                PortfolioMetric.SORTINO,
                PortfolioMetric.CALMAR,
                PortfolioMetric.MAX_DRAWDOWN,
                PortfolioMetric.VAR,
                PortfolioMetric.CVAR,
                PortfolioMetric.OMEGA,
            ]
        ),
        bayesian=BayesianSettings(enabled=True, n_samples=50000),
        aggregation=AggregationSettings(
            frequencies=[TimeFrequency.DAILY, TimeFrequency.WEEKLY, TimeFrequency.MONTHLY],
            compute_rolling=True,
            rolling_windows=[21, 63, 126, 252],
        ),
        drawdown=DrawdownSettings(compute_underwater_curve=True, top_n_drawdowns=10),
    )

for_production classmethod

for_production()

Preset for production monitoring.

Source code in src/ml4t/diagnostic/config/portfolio_config.py
@classmethod
def for_production(cls) -> PortfolioConfig:
    """Preset for production monitoring."""
    return cls(
        metrics=MetricsSettings(
            metrics=[PortfolioMetric.SHARPE, PortfolioMetric.MAX_DRAWDOWN, PortfolioMetric.VAR]
        ),
        bayesian=BayesianSettings(enabled=False),
        aggregation=AggregationSettings(
            frequencies=[TimeFrequency.DAILY], compute_rolling=True, rolling_windows=[21, 63]
        ),
        drawdown=DrawdownSettings(compute_recovery_time=False),
    )

TradeConfig

Bases: BaseConfig

Consolidated configuration for trade analysis.

Combines trade extraction, filtering, SHAP alignment, error pattern clustering, and hypothesis generation into a single configuration.

Examples

>>> config = TradeConfig(
...     extraction=ExtractionSettings(n_worst=50),
...     clustering=ClusteringSettings(min_cluster_size=10),
... )
>>> config.to_yaml("trade_config.yaml")

n_worst property

n_worst

Number of worst trades (shortcut).

n_best property

n_best

Number of best trades (shortcut).

warn_low_min_trades classmethod

warn_low_min_trades(v)

Warn if min_trades is very low.

Source code in src/ml4t/diagnostic/config/trade_analysis_config.py
@field_validator("min_trades_for_clustering")
@classmethod
def warn_low_min_trades(cls, v: int) -> int:
    """Warn if min_trades is very low."""
    if v < 10:
        import warnings

        warnings.warn(
            f"min_trades_for_clustering={v} may not identify reliable patterns. Use >= 20.",
            stacklevel=2,
        )
    return v
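The validator warns rather than raises, so a low value still produces a usable config. The sketch below copies the validator body into a plain function (no Pydantic involved) and records the warnings to show which values trigger it.

```python
import warnings


def warn_low_min_trades(v: int) -> int:
    """Plain-function copy of the validator body shown above."""
    if v < 10:
        warnings.warn(
            f"min_trades_for_clustering={v} may not identify reliable patterns. Use >= 20.",
            stacklevel=2,
        )
    return v


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure nothing is suppressed
    results = [warn_low_min_trades(v) for v in (5, 10, 15, 20)]

messages = [str(w.message) for w in caught]
print(results)        # every value is returned unchanged
print(len(messages))  # number of warnings raised
```

Only values below 10 warn. Values from 10 to 19 pass silently even though the message recommends at least 20.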

for_quick_diagnostics classmethod

for_quick_diagnostics()

Preset for quick diagnostics (minimal clustering).

Source code in src/ml4t/diagnostic/config/trade_analysis_config.py
@classmethod
def for_quick_diagnostics(cls) -> TradeConfig:
    """Preset for quick diagnostics (minimal clustering)."""
    return cls(
        extraction=ExtractionSettings(n_worst=20, n_best=10),
        alignment=AlignmentSettings(top_n_features=10),
        clustering=ClusteringSettings(min_cluster_size=3, max_clusters=5),
        hypothesis=HypothesisSettings(template_library="minimal", max_per_cluster=3),
        min_trades_for_clustering=10,
        generate_visualizations=False,
    )

for_deep_analysis classmethod

for_deep_analysis()

Preset for comprehensive analysis.

Source code in src/ml4t/diagnostic/config/trade_analysis_config.py
@classmethod
def for_deep_analysis(cls) -> TradeConfig:
    """Preset for comprehensive analysis."""
    return cls(
        extraction=ExtractionSettings(n_worst=50, n_best=20, compute_statistics=True),
        alignment=AlignmentSettings(top_n_features=None, mode="average"),
        clustering=ClusteringSettings(
            method=ClusteringMethod.HIERARCHICAL,
            linkage=LinkageMethod.WARD,
            min_cluster_size=10,
            max_clusters=None,
            normalization="l2",
        ),
        hypothesis=HypothesisSettings(
            min_confidence=0.6,
            max_per_cluster=10,
            include_interactions=True,
            template_library="comprehensive",
        ),
        min_trades_for_clustering=30,
        generate_visualizations=True,
    )

for_production classmethod

for_production()

Preset for production monitoring (efficient, focused).

Source code in src/ml4t/diagnostic/config/trade_analysis_config.py
@classmethod
def for_production(cls) -> TradeConfig:
    """Preset for production monitoring (efficient, focused)."""
    return cls(
        extraction=ExtractionSettings(n_worst=20, n_best=5, group_by_symbol=True),
        alignment=AlignmentSettings(top_n_features=15),
        clustering=ClusteringSettings(min_cluster_size=5, max_clusters=8),
        hypothesis=HypothesisSettings(min_confidence=0.7, max_per_cluster=3),
        min_trades_for_clustering=15,
        generate_visualizations=False,
        cache_shap_vectors=True,
    )

SignalConfig

Bases: BaseConfig

Consolidated configuration for signal analysis.

Combines analysis settings, RAS adjustment, visualization, and multi-signal batch analysis into a single configuration class.

Examples

>>> config = SignalConfig(
...     analysis=AnalysisSettings(quantiles=10, periods=(1, 5)),
...     visualization=VisualizationSettings(theme="dark"),
... )
>>> config.to_yaml("signal_config.yaml")

quantiles property

quantiles

Number of quantiles (shortcut).

periods property

periods

Forward return periods (shortcut).

filter_zscore property

filter_zscore

Outlier z-score threshold (shortcut).

compute_turnover property

compute_turnover

Compute turnover metrics (shortcut).

validate_quantile_labels_count

validate_quantile_labels_count()

Ensure quantile_labels matches quantiles count if provided.

Source code in src/ml4t/diagnostic/config/signal_config.py
@model_validator(mode="after")
def validate_quantile_labels_count(self) -> SignalConfig:
    """Ensure quantile_labels matches quantiles count if provided."""
    if self.analysis.quantile_labels is not None:
        if len(self.analysis.quantile_labels) != self.analysis.quantiles:
            raise ValueError(
                f"quantile_labels length ({len(self.analysis.quantile_labels)}) "
                f"must match quantiles ({self.analysis.quantiles})"
            )
    return self
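The same consistency check can be sketched without Pydantic by using a dataclass `__post_init__` hook, which plays a role similar to an "after" model validator. `AnalysisSketch` and the label strings are hypothetical stand-ins, not the library's classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AnalysisSketch:
    """Hypothetical stand-in for the analysis settings group."""
    quantiles: int = 5
    quantile_labels: Optional[list] = None

    def __post_init__(self) -> None:
        # Labels are optional, but if given there must be one per quantile.
        if self.quantile_labels is not None:
            if len(self.quantile_labels) != self.quantiles:
                raise ValueError(
                    f"quantile_labels length ({len(self.quantile_labels)}) "
                    f"must match quantiles ({self.quantiles})"
                )


ok = AnalysisSketch(quantiles=3, quantile_labels=["low", "mid", "high"])

try:
    AnalysisSketch(quantiles=5, quantile_labels=["low", "high"])
    error = None
except ValueError as exc:
    error = str(exc)

print(error)
```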

EventConfig

Bases: BaseConfig

Configuration for event study analysis.

Configures the event study methodology including window parameters, abnormal return model, and statistical test.

Attributes

window : WindowSettings
    Window configuration (estimation and event periods)
model : str
    Model for computing normal/expected returns
test : str
    Statistical test for significance
confidence_level : float
    Confidence level for intervals
min_estimation_obs : int
    Minimum observations in estimation window

Examples

>>> config = EventConfig(
...     window=WindowSettings(estimation_start=-252, event_end=10),
...     model="market_model",
...     test="boehmer",
... )

alpha property

alpha

Significance level (1 - confidence_level).

BarrierConfig

Bases: BaseConfig

Consolidated configuration for barrier analysis.

Combines analysis settings, column mappings, and visualization options into a single configuration class.

Examples

>>> config = BarrierConfig(
...     analysis=AnalysisSettings(n_quantiles=5),
...     visualization=VisualizationSettings(theme="dark"),
... )
>>> config.to_yaml("barrier_config.yaml")

n_quantiles property

n_quantiles

Number of quantiles (shortcut).

significance_level property

significance_level

Significance level (shortcut).

decile_method property

decile_method

Decile method (shortcut).

min_observations_per_quantile property

min_observations_per_quantile

Minimum observations per quantile (shortcut).

filter_zscore property

filter_zscore

Z-score filter threshold (shortcut).

drop_timeout property

drop_timeout

Drop timeout outcomes (shortcut).

bootstrap_n_resamples property

bootstrap_n_resamples

Bootstrap resamples (shortcut).

hit_rate_min_observations property

hit_rate_min_observations

Hit rate minimum observations (shortcut).

profit_factor_epsilon property

profit_factor_epsilon

Profit factor epsilon (shortcut).

signal_col property

signal_col

Signal column name (shortcut).

date_col property

date_col

Date column name (shortcut).

asset_col property

asset_col

Asset column name (shortcut).

label_col property

label_col

Label column name (shortcut).

label_return_col property

label_return_col

Label return column name (shortcut).

label_bars_col property

label_bars_col

Label bars column name (shortcut).

validate_column_uniqueness

validate_column_uniqueness()

Ensure column names don't conflict.

Source code in src/ml4t/diagnostic/config/barrier_config.py
@model_validator(mode="after")
def validate_column_uniqueness(self) -> BarrierConfig:
    """Ensure column names don't conflict."""
    cols = [self.columns.signal_col, self.columns.date_col, self.columns.asset_col]
    if len(cols) != len(set(cols)):
        raise ValueError(f"Column names must be unique: {cols}")
    return self
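A minimal sketch of the uniqueness check as a plain function follows. The function name and the column names passed to it are made up for the example. As in the source above, only the signal, date, and asset columns take part in the check.

```python
def check_column_uniqueness(signal_col: str, date_col: str, asset_col: str) -> list:
    """Raise if two of the three roles point at the same column name."""
    cols = [signal_col, date_col, asset_col]
    if len(cols) != len(set(cols)):
        raise ValueError(f"Column names must be unique: {cols}")
    return cols


accepted = check_column_uniqueness("signal", "date", "asset")

try:
    check_column_uniqueness("value", "date", "value")  # same column used twice
    conflict = None
except ValueError as exc:
    conflict = str(exc)

print(accepted)
print(conflict)
```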

ReportConfig

Bases: BaseConfig

Top-level configuration for reporting (Module E).

Orchestrates report generation:

  • Output formats (HTML, JSON, PDF)
  • HTML settings (templates, themes, tables)
  • Visualization (plots, colors, interactivity)
  • JSON structure

Attributes:

  • output_format (OutputFormatConfig): Output format configuration
  • html (HTMLConfig): HTML report configuration
  • visualization (VisualizationConfig): Visualization configuration
  • json_config (JSONConfig): JSON output configuration
  • lazy_rendering (bool): Don't generate plots until accessed
  • cache_plots (bool): Cache generated plots
  • parallel_plotting (bool): Generate plots in parallel
  • n_jobs (int): Parallel jobs for plotting

Examples:

>>> # Quick start with defaults
>>> config = ReportConfig()
>>> reporter = Reporter(config)
>>> reporter.generate(results, output_name="my_strategy")
>>> # Load from YAML
>>> config = ReportConfig.from_yaml("report_config.yaml")
>>> # Custom configuration
>>> config = ReportConfig(
...     output_format=OutputFormatConfig(
...         formats=[ReportFormat.HTML, ReportFormat.PDF]
...     ),
...     html=HTMLConfig(
...         template=ReportTemplate.SUMMARY,
...         theme=ReportTheme.PROFESSIONAL
...     ),
...     visualization=VisualizationConfig(
...         plot_dpi=300,
...         save_plots=True
...     )
... )

for_quick_report classmethod

for_quick_report()

Preset for quick HTML-only report (minimal plots).

Returns:

  • ReportConfig: Config optimized for speed

Source code in src/ml4t/diagnostic/config/report_config.py
@classmethod
def for_quick_report(cls) -> ReportConfig:
    """Preset for quick HTML-only report (minimal plots).

    Returns:
        Config optimized for speed
    """
    return cls(
        output_format=OutputFormatConfig(formats=[ReportFormat.HTML]),
        html=HTMLConfig(
            template=ReportTemplate.SUMMARY,
            interactive_plots=False,  # Faster static plots
        ),
        visualization=VisualizationConfig(
            correlation_heatmap=True,
            time_series_plots=False,
            distribution_plots=False,
            scatter_plots=False,
        ),
        lazy_rendering=True,
    )

for_publication classmethod

for_publication()

Preset for publication-quality reports (high-res, all plots).

Returns:

  • ReportConfig: Config optimized for publication

Source code in src/ml4t/diagnostic/config/report_config.py
@classmethod
def for_publication(cls) -> ReportConfig:
    """Preset for publication-quality reports (high-res, all plots).

    Returns:
        Config optimized for publication
    """
    return cls(
        output_format=OutputFormatConfig(
            formats=[ReportFormat.HTML, ReportFormat.PDF],
            compress=True,
        ),
        html=HTMLConfig(
            template=ReportTemplate.FULL,
            theme=ReportTheme.PROFESSIONAL,
            table_format=TableFormat.STYLED,
        ),
        visualization=VisualizationConfig(
            plot_dpi=300,
            plot_format="pdf",
            save_plots=True,
            correlation_heatmap=True,
            time_series_plots=True,
            distribution_plots=True,
            scatter_plots=True,
        ),
        json_config=JSONConfig(pretty_print=True, include_metadata=True),
        cache_plots=True,
        parallel_plotting=True,
    )

for_programmatic_access classmethod

for_programmatic_access()

Preset for programmatic access (JSON only, no plots).

Returns:

  • ReportConfig: Config optimized for API/programmatic use

Source code in src/ml4t/diagnostic/config/report_config.py
@classmethod
def for_programmatic_access(cls) -> ReportConfig:
    """Preset for programmatic access (JSON only, no plots).

    Returns:
        Config optimized for API/programmatic use
    """
    return cls(
        output_format=OutputFormatConfig(formats=[ReportFormat.JSON]),
        visualization=VisualizationConfig(
            correlation_heatmap=False,
            time_series_plots=False,
            distribution_plots=False,
            scatter_plots=False,
        ),
        json_config=JSONConfig(
            pretty_print=False,  # Compact for parsing
            include_raw_data=True,  # Include data for downstream processing
            export_dataframes=DataFrameExportFormat.SPLIT,  # Efficient format
        ),
        lazy_rendering=True,
    )

RuntimeConfig

Bases: BaseConfig

Configuration for execution settings.

Centralizes computational resources, caching, and randomness across all evaluation functions. Pass as a separate parameter to analysis functions.

Attributes:

  • n_jobs (int): Number of parallel jobs (-1 for all cores, 1 for serial)
  • cache_enabled (bool): Enable caching of expensive computations
  • cache_dir (Path): Directory for cache storage
  • cache_ttl (int | None): Cache time-to-live in seconds (None for no expiration)
  • verbose (bool): Enable verbose output
  • random_state (int | None): Random seed for reproducibility

Examples:

>>> from ml4t.diagnostic.config import RuntimeConfig, DiagnosticConfig
>>> runtime = RuntimeConfig(n_jobs=4, verbose=True)
>>> result = analyze_features(df, config=DiagnosticConfig(), runtime=runtime)

model_post_init

model_post_init(__context)

Create cache directory if it doesn't exist.

Source code in src/ml4t/diagnostic/config/base.py
def model_post_init(self, __context: Any) -> None:
    """Create cache directory if it doesn't exist."""
    if self.cache_enabled:
        self.cache_dir.mkdir(parents=True, exist_ok=True)
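One consequence of this hook is that constructing the config touches the filesystem whenever caching is enabled. The sketch below reproduces that behavior with a dataclass `__post_init__` and a temporary directory. `RuntimeSketch` is a hypothetical stand-in, not the library's class.

```python
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class RuntimeSketch:
    """Hypothetical stand-in for RuntimeConfig with only the cache fields."""
    cache_dir: Path
    cache_enabled: bool = True

    def __post_init__(self) -> None:
        # Same logic as model_post_init above: create the directory only when caching is on.
        if self.cache_enabled:
            self.cache_dir.mkdir(parents=True, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    enabled = RuntimeSketch(cache_dir=Path(tmp) / "a" / "cache")
    disabled = RuntimeSketch(cache_dir=Path(tmp) / "b" / "cache", cache_enabled=False)
    created = (enabled.cache_dir.is_dir(), disabled.cache_dir.exists())

print(created)
```

With caching disabled, no directory is created, so read-only environments can still construct the config.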

ValidatedCrossValidationConfig

Bases: BaseConfig

Configuration for ValidatedCrossValidation orchestration.

Signal Analysis

signal

Signal analysis for factor/alpha evaluation.

This module provides tools for analyzing the predictive power of signals (factors) for future returns.

Main Entry Point

analyze_signal : Compute IC, quantile returns, spread, and turnover for a factor signal. This is the recommended way to use this module.

Example

>>> from ml4t.diagnostic.signal import analyze_signal
>>> result = analyze_signal(factor_df, prices_df)
>>> print(result.summary())
>>> result.to_json("results.json")

Building Blocks

For custom workflows, use the component functions:

  • prepare_data : Join factor with prices and compute forward returns
  • extract_signal_ic_series : Extract per-date IC values for one horizon
  • compute_quantile_returns : Compute returns by quantile
  • compute_turnover : Compute factor turnover rate
  • filter_outliers : Remove cross-sectional outliers
  • quantize_factor : Assign quantile labels

SignalResult dataclass

SignalResult(
    ic,
    ic_std,
    ic_t_stat,
    ic_p_value,
    ic_ir=dict(),
    ic_positive_pct=dict(),
    ic_series=dict(),
    quantile_returns=dict(),
    spread=dict(),
    spread_t_stat=dict(),
    spread_p_value=dict(),
    monotonicity=dict(),
    ic_dates=dict(),
    quantile_returns_std=dict(),
    count_by_quantile=dict(),
    spread_std=dict(),
    turnover=None,
    autocorrelation=None,
    half_life=None,
    n_assets=0,
    n_dates=0,
    date_range=("", ""),
    periods=(),
    quantiles=5,
)

Immutable result from signal analysis.

All metrics are keyed by period (e.g., "1D", "5D", "21D").

Attributes

ic : dict[str, float]
    Mean IC by period.
ic_std : dict[str, float]
    IC standard deviation by period.
ic_t_stat : dict[str, float]
    T-statistic for IC != 0.
ic_p_value : dict[str, float]
    P-value for IC significance.
ic_ir : dict[str, float]
    Information Ratio (IC mean / IC std) by period.
ic_positive_pct : dict[str, float]
    Percentage of periods with positive IC.
ic_series : dict[str, list[float]]
    IC time series by period.
quantile_returns : dict[str, dict[int, float]]
    Mean returns by period and quantile.
spread : dict[str, float]
    Top minus bottom quantile spread.
spread_t_stat : dict[str, float]
    T-statistic for spread.
spread_p_value : dict[str, float]
    P-value for spread significance.
monotonicity : dict[str, float]
    Rank correlation of quantile returns (how monotonic).
turnover : dict[str, float] | None
    Mean turnover rate by period.
autocorrelation : list[float] | None
    Factor autocorrelation at lags 1, 2, ...
half_life : float | None
    Estimated signal half-life in periods.
n_assets : int
    Number of unique assets.
n_dates : int
    Number of unique dates.
date_range : tuple[str, str]
    (first_date, last_date).
periods : tuple[int, ...]
    Forward return periods analyzed.
quantiles : int
    Number of quantiles used.
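As a rough guide to how the IC summary fields relate to one another, the sketch below derives them from a single IC series. Only the Information Ratio definition (IC mean / IC std) comes from the attribute list above. The plain one-sample t-statistic, the 0 to 100 percentage scale, and the sample values are assumptions, and the library may apply adjustments that this sketch does not.

```python
import math
from statistics import mean, stdev


def summarize_ic(ic_series: list) -> dict:
    """Summary statistics for one period's IC series (illustrative only)."""
    n = len(ic_series)
    ic_mean = mean(ic_series)
    ic_std = stdev(ic_series)  # sample standard deviation
    return {
        "ic": ic_mean,
        "ic_std": ic_std,
        "ic_ir": ic_mean / ic_std,                     # IC mean / IC std
        "ic_t_stat": ic_mean / (ic_std / math.sqrt(n)),  # assumed: plain one-sample t
        "ic_positive_pct": 100.0 * sum(v > 0 for v in ic_series) / n,
    }


stats = summarize_ic([0.02, 0.04, -0.01, 0.03])  # made-up IC values
for name, value in stats.items():
    print(f"{name}: {value:.4f}")
```

Under these assumptions the t-statistic is the Information Ratio scaled by the square root of the number of dates, so a modest IR can still be significant over a long sample.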

summary

summary()

Human-readable summary of results.

Source code in src/ml4t/diagnostic/signal/result.py
def summary(self) -> str:
    """Human-readable summary of results."""
    lines = [
        f"Signal Analysis: {self.n_assets} assets, {self.n_dates} dates",
        f"Date range: {self.date_range[0]} to {self.date_range[1]}",
        f"Periods: {self.periods}, Quantiles: {self.quantiles}",
        "",
        "IC Summary:",
    ]

    for period in [f"{p}D" for p in self.periods]:
        ic_val = self.ic.get(period, float("nan"))
        t = self.ic_t_stat.get(period, float("nan"))
        p = self.ic_p_value.get(period, float("nan"))
        ir = self.ic_ir.get(period, float("nan"))
        pos_pct = self.ic_positive_pct.get(period, float("nan"))
        sig = "*" if p < 0.05 else ""
        lines.append(
            f"  {period}: IC={ic_val:+.4f} (t={t:.2f}, p={p:.3f}){sig}, IR={ir:.2f}, +%={pos_pct:.0f}%"
        )

    lines.append("\nSpread (Top - Bottom):")
    for period in [f"{p}D" for p in self.periods]:
        spread = self.spread.get(period, float("nan"))
        t = self.spread_t_stat.get(period, float("nan"))
        p = self.spread_p_value.get(period, float("nan"))
        sig = "*" if p < 0.05 else ""
        lines.append(f"  {period}: {spread:+.4f} (t={t:.2f}, p={p:.3f}){sig}")

    lines.append("\nMonotonicity:")
    for period in [f"{p}D" for p in self.periods]:
        mono = self.monotonicity.get(period, float("nan"))
        lines.append(f"  {period}: {mono:+.3f}")

    if self.turnover:
        lines.append("\nTurnover:")
        for period in [f"{p}D" for p in self.periods]:
            t = self.turnover.get(period, float("nan"))
            lines.append(f"  {period}: {t:.1%}")

    if self.half_life is not None:
        lines.append(f"\nHalf-life: {self.half_life:.1f} periods")

    return "\n".join(lines)

to_ic_result

to_ic_result(period=None)

Convert to SignalICResult for visualization functions.

Parameters

period : int | str | None
    Specific period (e.g. 21 or "21D"). If None, includes all periods aligned to their common date intersection.

Returns

SignalICResult
    Pydantic model compatible with plot_ic_ts, plot_ic_histogram, etc.

Raises

ValueError
    If ic_dates is empty (result created without date capture).

Examples

>>> result = analyze_signal(factor_df, prices_df)
>>> plot_ic_ts(result.to_ic_result())
>>> plot_ic_ts(result.to_ic_result(period=21))

Source code in src/ml4t/diagnostic/signal/result.py
def to_ic_result(self, period: int | str | None = None) -> SignalICResult:
    """Convert to SignalICResult for visualization functions.

    Parameters
    ----------
    period : int | str | None
        Specific period (e.g. 21 or "21D"). If None, includes all periods
        aligned to their common date intersection.

    Returns
    -------
    SignalICResult
        Pydantic model compatible with plot_ic_ts, plot_ic_histogram, etc.

    Raises
    ------
    ValueError
        If ic_dates is empty (result created without date capture).

    Examples
    --------
    >>> result = analyze_signal(factor_df, prices_df)
    >>> plot_ic_ts(result.to_ic_result())
    >>> plot_ic_ts(result.to_ic_result(period=21))
    """
    from ml4t.diagnostic.results.signal_results.ic import SignalICResult

    if not self.ic_dates:
        raise ValueError("ic_dates not available. Re-run analyze_signal() to capture dates.")

    period_keys: list[str]
    if period is not None:
        key = f"{period}D" if isinstance(period, int) else str(period)
        if not key.endswith("D"):
            key = f"{key}D"
        if key not in self.ic:
            raise ValueError(f"Period '{key}' not found. Available: {list(self.ic.keys())}")
        period_keys = [key]
    else:
        period_keys = [f"{p}D" for p in self.periods]

    # Find common date intersection across requested periods
    date_sets = [set(self.ic_dates[k]) for k in period_keys if k in self.ic_dates]
    if not date_sets:
        common_dates: list[str] = []
    else:
        common = date_sets[0]
        for ds in date_sets[1:]:
            common = common & ds
        common_dates = sorted(common)

    # Build ic_by_date aligned to common dates
    ic_by_date: dict[str, list[float]] = {}
    for key in period_keys:
        if key not in self.ic_dates or key not in self.ic_series:
            ic_by_date[key] = []
            continue
        date_to_ic = dict(zip(self.ic_dates[key], self.ic_series[key], strict=False))
        ic_by_date[key] = [date_to_ic[d] for d in common_dates if d in date_to_ic]

    return SignalICResult(
        ic_by_date=ic_by_date,
        dates=common_dates,
        ic_mean={k: self.ic[k] for k in period_keys},
        ic_std={k: self.ic_std[k] for k in period_keys},
        ic_t_stat={k: self.ic_t_stat[k] for k in period_keys},
        ic_p_value={k: self.ic_p_value[k] for k in period_keys},
        ic_positive_pct={k: self.ic_positive_pct.get(k, 0.0) for k in period_keys},
        ic_ir={k: self.ic_ir.get(k, 0.0) for k in period_keys},
    )
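
The date-alignment step in to_ic_result() can be illustrated in isolation. The following is a minimal sketch in plain Python, not the library's code: align_ic_series and the sample data are hypothetical, and it assumes dates are ISO-formatted strings so that lexical sorting matches chronological order.

```python
def align_ic_series(ic_dates, ic_series, period_keys):
    """Align per-period IC values to the dates shared by all requested periods."""
    date_sets = [set(ic_dates[k]) for k in period_keys if k in ic_dates]
    common_dates = sorted(set.intersection(*date_sets)) if date_sets else []

    aligned = {}
    for key in period_keys:
        if key not in ic_dates or key not in ic_series:
            aligned[key] = []  # no dates captured for this period
            continue
        date_to_ic = dict(zip(ic_dates[key], ic_series[key]))
        aligned[key] = [date_to_ic[d] for d in common_dates]
    return common_dates, aligned


# Hypothetical data: the 5D series has one fewer date than the 1D series.
ic_dates = {
    "1D": ["2024-01-02", "2024-01-03", "2024-01-04"],
    "5D": ["2024-01-02", "2024-01-03"],
}
ic_series = {"1D": [0.05, -0.01, 0.03], "5D": [0.08, 0.02]}

dates, aligned = align_ic_series(ic_dates, ic_series, ["1D", "5D"])
print(dates)    # only the two dates present in both series
print(aligned)  # the 1D series loses its last value
```

Because longer horizons typically run out of forward returns earlier, intersecting the dates keeps every series the same length as the shared date axis.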

to_quantile_result

to_quantile_result()

Convert to QuantileAnalysisResult for visualization functions.

Returns

QuantileAnalysisResult
    Pydantic model compatible with plot_quantile_returns_bar, etc.

Examples

>>> result = analyze_signal(factor_df, prices_df)
>>> plot_quantile_returns_bar(result.to_quantile_result())

Source code in src/ml4t/diagnostic/signal/result.py
def to_quantile_result(self) -> QuantileAnalysisResult:
    """Convert to QuantileAnalysisResult for visualization functions.

    Returns
    -------
    QuantileAnalysisResult
        Pydantic model compatible with plot_quantile_returns_bar, etc.

    Examples
    --------
    >>> result = analyze_signal(factor_df, prices_df)
    >>> plot_quantile_returns_bar(result.to_quantile_result())
    """
    from ml4t.diagnostic.results.signal_results.quantile import QuantileAnalysisResult

    period_keys = [f"{p}D" for p in self.periods]
    quantile_labels = [f"Q{i}" for i in range(1, self.quantiles + 1)]

    # Convert int quantile keys to string labels: {1: val} -> {"Q1": val}
    def _relabel(d: dict[int, float]) -> dict[str, float]:
        return {f"Q{k}": v for k, v in sorted(d.items())}

    mean_returns = {pk: _relabel(self.quantile_returns[pk]) for pk in period_keys}

    # std_returns: use captured data or fill with 0.0
    std_returns: dict[str, dict[str, float]] = {}
    for pk in period_keys:
        if pk in self.quantile_returns_std:
            std_returns[pk] = _relabel(self.quantile_returns_std[pk])
        else:
            std_returns[pk] = dict.fromkeys(quantile_labels, 0.0)

    # count_by_quantile with string labels
    count_by_q: dict[str, int]
    if self.count_by_quantile:
        count_by_q = {f"Q{k}": v for k, v in sorted(self.count_by_quantile.items())}
    else:
        count_by_q = dict.fromkeys(quantile_labels, 0)

    # Spread metrics
    spread_mean = {pk: self.spread.get(pk, 0.0) for pk in period_keys}
    spread_std_d = {pk: self.spread_std.get(pk, 0.0) for pk in period_keys}
    spread_t = {pk: self.spread_t_stat.get(pk, 0.0) for pk in period_keys}
    spread_p = {pk: self.spread_p_value.get(pk, 1.0) for pk in period_keys}

    # Confidence intervals: spread ± 1.96 * spread_std
    z = 1.96
    spread_ci_lower: dict[str, float] = {}
    spread_ci_upper: dict[str, float] = {}
    for pk in period_keys:
        s = spread_mean[pk]
        se = spread_std_d[pk]
        if math.isfinite(se):
            spread_ci_lower[pk] = s - z * se
            spread_ci_upper[pk] = s + z * se
        else:
            spread_ci_lower[pk] = float("nan")
            spread_ci_upper[pk] = float("nan")

    # Monotonicity derivation
    is_monotonic: dict[str, bool] = {}
    monotonicity_direction: dict[str, str] = {}
    rank_correlation: dict[str, float] = {}
    for pk in period_keys:
        rho = self.monotonicity.get(pk, 0.0)
        rank_correlation[pk] = rho
        is_monotonic[pk] = abs(rho) > 0.8
        if rho > 0.8:
            monotonicity_direction[pk] = "increasing"
        elif rho < -0.8:
            monotonicity_direction[pk] = "decreasing"
        else:
            monotonicity_direction[pk] = "none"

    return QuantileAnalysisResult(
        n_quantiles=self.quantiles,
        quantile_labels=quantile_labels,
        periods=period_keys,
        mean_returns=mean_returns,
        std_returns=std_returns,
        count_by_quantile=count_by_q,
        spread_mean=spread_mean,
        spread_std=spread_std_d,
        spread_t_stat=spread_t,
        spread_p_value=spread_p,
        spread_ci_lower=spread_ci_lower,
        spread_ci_upper=spread_ci_upper,
        is_monotonic=is_monotonic,
        monotonicity_direction=monotonicity_direction,
        rank_correlation=rank_correlation,
    )
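
The two derived quantities in to_quantile_result(), the monotonicity label and the spread confidence interval, reduce to a few lines of arithmetic. This is a hedged sketch of that logic with hypothetical helper names (classify_monotonicity, spread_confidence_interval); the 0.8 threshold and z = 1.96 are taken from the source above.

```python
import math


def classify_monotonicity(rho, threshold=0.8):
    """Label a rank correlation as increasing, decreasing, or none."""
    if rho > threshold:
        return "increasing"
    if rho < -threshold:
        return "decreasing"
    return "none"  # also reached when rho is NaN, since NaN comparisons are False


def spread_confidence_interval(spread, se, z=1.96):
    """Return (lower, upper) as spread -/+ z * se, or NaNs if se is not finite."""
    if not math.isfinite(se):
        return float("nan"), float("nan")
    return spread - z * se, spread + z * se


print(classify_monotonicity(0.9))    # increasing
print(classify_monotonicity(-0.85))  # decreasing
print(classify_monotonicity(0.5))    # none

lower, upper = spread_confidence_interval(0.004, 0.001)
print(lower, upper)  # roughly 0.00204 and 0.00596
```

Note that the interval treats spread_std as a standard error, which matches how analyze_signal derives it from the spread and its t-statistic.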

to_tear_sheet

to_tear_sheet(signal_name='signal')

Convert to full SignalTearSheet for dashboard display.

Bundles to_ic_result() and to_quantile_result() into a SignalTearSheet.

Parameters

signal_name : str
    Name for the signal (used in dashboard title).

Returns

SignalTearSheet
    Pydantic model with IC and quantile analysis components.

Examples

>>> result = analyze_signal(factor_df, prices_df)
>>> tear_sheet = result.to_tear_sheet("momentum_21d")
>>> tear_sheet.show()

Source code in src/ml4t/diagnostic/signal/result.py
def to_tear_sheet(self, signal_name: str = "signal") -> SignalTearSheet:
    """Convert to full SignalTearSheet for dashboard display.

    Bundles to_ic_result() and to_quantile_result() into a SignalTearSheet.

    Parameters
    ----------
    signal_name : str
        Name for the signal (used in dashboard title).

    Returns
    -------
    SignalTearSheet
        Pydantic model with IC and quantile analysis components.

    Examples
    --------
    >>> result = analyze_signal(factor_df, prices_df)
    >>> tear_sheet = result.to_tear_sheet("momentum_21d")
    >>> tear_sheet.show()
    """
    from ml4t.diagnostic.results.signal_results.tearsheet import SignalTearSheet

    ic_result = self.to_ic_result() if self.ic_dates else None
    quantile_result = self.to_quantile_result()

    return SignalTearSheet(
        signal_name=signal_name,
        n_assets=self.n_assets,
        n_dates=self.n_dates,
        date_range=self.date_range,
        ic_analysis=ic_result,
        quantile_analysis=quantile_result,
    )

to_dict

to_dict()

Export to dictionary.

Source code in src/ml4t/diagnostic/signal/result.py
def to_dict(self) -> dict[str, Any]:
    """Export to dictionary."""
    return asdict(self)

to_json

to_json(path=None, indent=2)

Export to JSON string or file.

Parameters

path : str | None
    If provided, write to file. Otherwise return string.
indent : int
    JSON indentation level.

Returns

str
    JSON string.

Source code in src/ml4t/diagnostic/signal/result.py
def to_json(self, path: str | None = None, indent: int = 2) -> str:
    """Export to JSON string or file.

    Parameters
    ----------
    path : str | None
        If provided, write to file. Otherwise return string.
    indent : int
        JSON indentation level.

    Returns
    -------
    str
        JSON string.
    """
    data = self.to_dict()

    def convert(obj: Any) -> Any:
        if isinstance(obj, float) and (obj != obj):  # NaN check
            return None
        if isinstance(obj, tuple):
            return list(obj)
        return obj

    def serialize(d: Any) -> Any:
        if isinstance(d, dict):
            return {str(k): serialize(v) for k, v in d.items()}
        if isinstance(d, list):
            return [serialize(v) for v in d]
        return convert(d)

    serialized = serialize(data)
    json_str = json.dumps(serialized, indent=indent)

    if path:
        with open(path, "w") as f:
            f.write(json_str)

    return json_str
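
The reason for the NaN conversion is that Python's json.dumps writes NaN as a bare NaN token by default, which strict JSON parsers reject. The sketch below shows the same idea in a single recursive helper; to_jsonable is a hypothetical name, and unlike the helper above it also recurses into tuples.

```python
import json
import math


def to_jsonable(obj):
    """Recursively make a nested structure safe for strict JSON output."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, float) and math.isnan(obj):
        return None  # serialized as JSON null
    return obj


# Hypothetical payload shaped like a few SignalResult fields.
payload = {
    "spread_std": {"1D": float("nan")},
    "periods": (1, 5),
    "quantile_returns": {"1D": {1: -0.001, 2: 0.002}},
}

# allow_nan=False would raise if any NaN had slipped through.
text = json.dumps(to_jsonable(payload), allow_nan=False)
print(text)
```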

from_json classmethod

from_json(path)

Load from JSON file.

Parameters

path : str
    Path to JSON file.

Returns

SignalResult
    Loaded result.

Source code in src/ml4t/diagnostic/signal/result.py
@classmethod
def from_json(cls, path: str) -> SignalResult:
    """Load from JSON file.

    Parameters
    ----------
    path : str
        Path to JSON file.

    Returns
    -------
    SignalResult
        Loaded result.
    """
    with open(path) as f:
        data = json.load(f)

    # Convert lists back to tuples for immutable fields
    if "date_range" in data:
        data["date_range"] = tuple(data["date_range"])
    if "periods" in data:
        data["periods"] = tuple(data["periods"])

    # Convert quantile keys back to int
    if "quantile_returns" in data:
        data["quantile_returns"] = {
            period: {int(k): v for k, v in qr.items()}
            for period, qr in data["quantile_returns"].items()
        }
    if "quantile_returns_std" in data:
        data["quantile_returns_std"] = {
            period: {int(k): v for k, v in qr.items()}
            for period, qr in data["quantile_returns_std"].items()
        }
    if "count_by_quantile" in data:
        data["count_by_quantile"] = {int(k): v for k, v in data["count_by_quantile"].items()}

    return cls(**data)
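
The key conversions in from_json() exist because JSON object keys are always strings, so integer quantile keys do not survive a round trip unchanged. A small standard-library illustration with hypothetical sample data:

```python
import json

# Quantile returns keyed by period, then by integer quantile.
original = {"1D": {1: -0.001, 5: 0.002}}

decoded = json.loads(json.dumps(original))
print(decoded)  # inner keys are now the strings "1" and "5"

# Same restoration pattern as from_json(): cast inner keys back to int.
restored = {
    period: {int(k): v for k, v in qr.items()}
    for period, qr in decoded.items()
}
print(restored == original)  # True
```

As far as the code shown here goes, values that to_json() wrote as null are not converted back, so a field such as spread_std may hold None rather than NaN after loading.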

analyze_signal

analyze_signal(
    factor,
    prices,
    *,
    periods=(1, 5, 21),
    quantiles=5,
    filter_zscore=3.0,
    quantile_method="quantile",
    ic_method="spearman",
    compute_turnover_flag=True,
    autocorrelation_lags=10,
    min_assets=10,
    factor_col="factor",
    date_col="date",
    asset_col="asset",
    price_col="price",
)

Analyze a factor signal.

This is the main entry point for signal analysis. Computes IC, quantile returns, spread, monotonicity, and optionally turnover/autocorrelation.

Parameters

factor : DataFrame
    Factor data with columns: date, asset, factor. Higher factor values should predict higher returns.
prices : DataFrame
    Price data with columns: date, asset, price.
periods : tuple[int, ...]
    Forward return periods in trading days (default: 1, 5, 21 days).
quantiles : int
    Number of quantiles for grouping assets (default: 5 quintiles).
filter_zscore : float | None
    Z-score threshold for outlier filtering. None disables.
quantile_method : str
    "quantile" (equal frequency) or "uniform" (equal width).
ic_method : str
    "spearman" (rank correlation) or "pearson" (linear correlation).
compute_turnover_flag : bool
    Whether to compute turnover and autocorrelation metrics.
autocorrelation_lags : int
    Number of lags for autocorrelation analysis.
min_assets : int
    Minimum assets per date for IC computation.
factor_col, date_col, asset_col, price_col : str
    Column names.

Returns

SignalResult
    Analysis results with IC, quantile returns, spread, monotonicity, and optionally turnover metrics.

Examples

Basic usage:

>>> result = analyze_signal(factor_df, prices_df)
>>> print(result.summary())
>>> result.to_json("results.json")

With custom parameters:

>>> result = analyze_signal(
...     factor_df, prices_df,
...     periods=(1, 5, 21, 63),
...     quantiles=10,
...     ic_method="pearson",
... )

Source code in src/ml4t/diagnostic/signal/core.py
def analyze_signal(
    factor: pl.DataFrame | pd.DataFrame,
    prices: pl.DataFrame | pd.DataFrame,
    *,
    periods: tuple[int, ...] = (1, 5, 21),
    quantiles: int = 5,
    filter_zscore: float | None = 3.0,
    quantile_method: str = "quantile",
    ic_method: str = "spearman",
    compute_turnover_flag: bool = True,
    autocorrelation_lags: int = 10,
    min_assets: int = 10,
    factor_col: str = "factor",
    date_col: str = "date",
    asset_col: str = "asset",
    price_col: str = "price",
) -> SignalResult:
    """Analyze a factor signal.

    This is the main entry point for signal analysis. Computes IC, quantile
    returns, spread, monotonicity, and optionally turnover/autocorrelation.

    Parameters
    ----------
    factor : DataFrame
        Factor data with columns: date, asset, factor.
        Higher factor values should predict higher returns.
    prices : DataFrame
        Price data with columns: date, asset, price.
    periods : tuple[int, ...]
        Forward return periods in trading days (default: 1, 5, 21 days).
    quantiles : int
        Number of quantiles for grouping assets (default: 5 quintiles).
    filter_zscore : float | None
        Z-score threshold for outlier filtering. None disables.
    quantile_method : str
        "quantile" (equal frequency) or "uniform" (equal width).
    ic_method : str
        "spearman" (rank correlation) or "pearson" (linear correlation).
    compute_turnover_flag : bool
        Whether to compute turnover and autocorrelation metrics.
    autocorrelation_lags : int
        Number of lags for autocorrelation analysis.
    min_assets : int
        Minimum assets per date for IC computation.
    factor_col, date_col, asset_col, price_col : str
        Column names.

    Returns
    -------
    SignalResult
        Analysis results with IC, quantile returns, spread, monotonicity,
        and optionally turnover metrics.

    Examples
    --------
    Basic usage:

    >>> result = analyze_signal(factor_df, prices_df)
    >>> print(result.summary())
    >>> result.to_json("results.json")

    With custom parameters:

    >>> result = analyze_signal(
    ...     factor_df, prices_df,
    ...     periods=(1, 5, 21, 63),
    ...     quantiles=10,
    ...     ic_method="pearson",
    ... )
    """
    # Prepare data
    data = prepare_data(
        factor,
        prices,
        periods,
        quantiles,
        filter_zscore,
        quantile_method,
        factor_col,
        date_col,
        asset_col,
        price_col,
    )

    # Extract metadata
    n_assets = data.select(asset_col).n_unique()
    n_dates = data.select(date_col).n_unique()
    all_dates = data.select(date_col).unique().sort(date_col).to_series().to_list()
    date_range = (str(all_dates[0]), str(all_dates[-1])) if all_dates else ("", "")

    # Count by quantile (period-independent, compute once)
    count_by_quantile: dict[int, int] = {}
    for q_val, cnt in data.group_by("quantile").len().sort("quantile").iter_rows():
        count_by_quantile[int(q_val)] = cnt

    # Initialize result dicts
    ic: dict[str, float] = {}
    ic_std: dict[str, float] = {}
    ic_t_stat: dict[str, float] = {}
    ic_p_value: dict[str, float] = {}
    ic_ir: dict[str, float] = {}
    ic_positive_pct: dict[str, float] = {}
    ic_series: dict[str, list[float]] = {}
    ic_dates: dict[str, list[str]] = {}
    quantile_returns: dict[str, dict[int, float]] = {}
    quantile_returns_std: dict[str, dict[int, float]] = {}
    spread: dict[str, float] = {}
    spread_t_stat: dict[str, float] = {}
    spread_p_value: dict[str, float] = {}
    spread_std: dict[str, float] = {}
    monotonicity: dict[str, float] = {}

    # Compute metrics for each period
    for period in periods:
        period_key = f"{period}D"

        # IC
        dates, ic_vals = extract_signal_ic_series(
            data, period, ic_method, factor_col, date_col, asset_col, min_assets
        )
        summary = compute_ic_summary(ic_vals)

        ic[period_key] = summary["mean"]
        ic_std[period_key] = summary["std"]
        ic_t_stat[period_key] = summary["t_stat"]
        ic_p_value[period_key] = summary["p_value"]
        ic_series[period_key] = ic_vals
        ic_dates[period_key] = [str(d) for d in dates]

        # IC Information Ratio and positive percentage
        if summary["std"] > 0:
            ic_ir[period_key] = summary["mean"] / summary["std"]
        else:
            ic_ir[period_key] = 0.0
        if ic_vals:
            ic_positive_pct[period_key] = sum(1 for x in ic_vals if x > 0) / len(ic_vals) * 100
        else:
            ic_positive_pct[period_key] = 0.0

        # Quantile returns
        q_returns = compute_quantile_returns(data, period, quantiles)
        quantile_returns[period_key] = q_returns

        # Quantile return standard deviations
        return_col = f"{period}D_fwd_return"
        q_detail = (
            data.filter(pl.col(return_col).is_not_null())
            .group_by("quantile")
            .agg(pl.col(return_col).std().alias("std_return"))
            .sort("quantile")
        )
        q_std: dict[int, float] = {}
        for row in q_detail.iter_rows(named=True):
            std_val = row["std_return"]
            q_std[int(row["quantile"])] = float(std_val) if std_val is not None else 0.0
        quantile_returns_std[period_key] = q_std

        # Spread
        spread_stats = compute_spread(data, period, quantiles)
        spread[period_key] = spread_stats["spread"]
        spread_t_stat[period_key] = spread_stats["t_stat"]
        spread_p_value[period_key] = spread_stats["p_value"]

        # Spread std (derive from identity: se = spread / t_stat)
        t = spread_stats["t_stat"]
        if abs(t) > 1e-12:
            spread_std[period_key] = abs(spread_stats["spread"] / t)
        else:
            spread_std[period_key] = float("nan")

        # Monotonicity
        monotonicity[period_key] = monotonicity_score(q_returns)

    # Turnover (optional)
    turnover_dict: dict[str, float] | None = None
    autocorr: list[float] | None = None
    half_life: float | None = None

    if compute_turnover_flag:
        turnover_val = compute_turnover(data, quantiles, date_col, asset_col)
        turnover_dict = {f"{p}D": turnover_val for p in periods}

        lags = list(range(1, autocorrelation_lags + 1))
        autocorr = compute_autocorrelation(data, lags, date_col, asset_col, factor_col)
        half_life = estimate_half_life(autocorr)

    return SignalResult(
        ic=ic,
        ic_std=ic_std,
        ic_t_stat=ic_t_stat,
        ic_p_value=ic_p_value,
        ic_ir=ic_ir,
        ic_positive_pct=ic_positive_pct,
        ic_series=ic_series,
        ic_dates=ic_dates,
        quantile_returns=quantile_returns,
        quantile_returns_std=quantile_returns_std,
        count_by_quantile=count_by_quantile,
        spread=spread,
        spread_t_stat=spread_t_stat,
        spread_p_value=spread_p_value,
        spread_std=spread_std,
        monotonicity=monotonicity,
        turnover=turnover_dict,
        autocorrelation=autocorr,
        half_life=half_life,
        n_assets=n_assets,
        n_dates=n_dates,
        date_range=date_range,
        periods=periods,
        quantiles=quantiles,
    )
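
Three of the per-period values that analyze_signal() derives inline (the IC information ratio, the share of positive IC dates, and the spread standard error) are simple enough to check by hand. The sketch below restates them as standalone functions; the function names and inputs are hypothetical, while the formulas and guards follow the source above.

```python
import math


def ic_information_ratio(mean_ic, std_ic):
    """Mean IC divided by its standard deviation; 0.0 when std is not positive."""
    return mean_ic / std_ic if std_ic > 0 else 0.0


def positive_pct(ic_vals):
    """Percentage (0-100) of dates with a strictly positive IC."""
    if not ic_vals:
        return 0.0
    return sum(1 for x in ic_vals if x > 0) / len(ic_vals) * 100


def spread_standard_error(spread, t_stat):
    """Recover the standard error from t = spread / se; NaN if t is ~0 or NaN."""
    if abs(t_stat) > 1e-12:
        return abs(spread / t_stat)
    return float("nan")


print(ic_information_ratio(0.05, 0.1))           # 0.5
print(positive_pct([0.05, -0.01, 0.03, 0.02]))   # 75.0
print(spread_standard_error(0.004, 2.0))         # 0.002
print(spread_standard_error(0.0, 0.0))           # nan
```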

prepare_data

prepare_data(
    factor,
    prices,
    periods=(1, 5, 21),
    quantiles=5,
    filter_zscore=3.0,
    quantile_method="quantile",
    factor_col="factor",
    date_col="date",
    asset_col="asset",
    price_col="price",
)

Prepare factor data for analysis.

Joins factor with prices, computes forward returns, filters outliers, and assigns quantiles.

Parameters

factor : DataFrame
    Factor data with columns: date, asset, factor.
prices : DataFrame
    Price data with columns: date, asset, price.
periods : tuple[int, ...]
    Forward return periods in trading days.
quantiles : int
    Number of quantiles.
filter_zscore : float | None
    Z-score threshold for outlier filtering. None disables.
quantile_method : str
    "quantile" (equal frequency) or "uniform" (equal width).
factor_col, date_col, asset_col, price_col : str
    Column names.

Returns

pl.DataFrame
    Prepared data with: date, asset, factor, quantile, {period}D_fwd_return.

Source code in src/ml4t/diagnostic/signal/core.py
def prepare_data(
    factor: pl.DataFrame | pd.DataFrame,
    prices: pl.DataFrame | pd.DataFrame,
    periods: tuple[int, ...] = (1, 5, 21),
    quantiles: int = 5,
    filter_zscore: float | None = 3.0,
    quantile_method: str = "quantile",
    factor_col: str = "factor",
    date_col: str = "date",
    asset_col: str = "asset",
    price_col: str = "price",
) -> pl.DataFrame:
    """Prepare factor data for analysis.

    Joins factor with prices, computes forward returns, filters outliers,
    and assigns quantiles.

    Parameters
    ----------
    factor : DataFrame
        Factor data with columns: date, asset, factor.
    prices : DataFrame
        Price data with columns: date, asset, price.
    periods : tuple[int, ...]
        Forward return periods in trading days.
    quantiles : int
        Number of quantiles.
    filter_zscore : float | None
        Z-score threshold for outlier filtering. None disables.
    quantile_method : str
        "quantile" (equal frequency) or "uniform" (equal width).
    factor_col, date_col, asset_col, price_col : str
        Column names.

    Returns
    -------
    pl.DataFrame
        Prepared data with: date, asset, factor, quantile, {period}D_fwd_return.
    """
    # Convert to Polars
    factor_pl = ensure_polars(factor)
    prices_pl = ensure_polars(prices)

    # Compute forward returns
    data = compute_forward_returns(factor_pl, prices_pl, periods, date_col, asset_col, price_col)

    # Filter outliers
    if filter_zscore is not None and filter_zscore > 0:
        data = filter_outliers(data, filter_zscore, factor_col, date_col)

    # Assign quantiles
    method = QuantileMethod.QUANTILE if quantile_method == "quantile" else QuantileMethod.UNIFORM
    data = quantize_factor(data, quantiles, method, factor_col, date_col)

    return data

extract_signal_ic_series

extract_signal_ic_series(
    data,
    period,
    method="spearman",
    factor_col="factor",
    date_col="date",
    asset_col="asset",
    min_obs=10,
)

Extract valid per-date IC values for a single signal horizon.

Parameters

data : pl.DataFrame
    Factor data with factor and forward return columns.
period : int
    Forward return period in days.
method : str, default "spearman"
    Correlation method ("spearman" or "pearson").
factor_col : str, default "factor"
    Factor column name.
date_col : str, default "date"
    Date column name.
asset_col : str, default "asset"
    Asset/entity column used for panel joins.
min_obs : int, default 10
    Minimum observations per date.

Returns

tuple[list[Any], list[float]]
    (dates, ic_values) for dates with valid IC.

Source code in src/ml4t/diagnostic/signal/signal_ic.py
def extract_signal_ic_series(
    data: pl.DataFrame,
    period: int,
    method: str = "spearman",
    factor_col: str = "factor",
    date_col: str = "date",
    asset_col: str = "asset",
    min_obs: int = 10,
) -> tuple[list[Any], list[float]]:
    """Extract valid per-date IC values for a single signal horizon.

    Parameters
    ----------
    data : pl.DataFrame
        Factor data with factor and forward return columns.
    period : int
        Forward return period in days.
    method : str, default "spearman"
        Correlation method ("spearman" or "pearson").
    factor_col : str, default "factor"
        Factor column name.
    date_col : str, default "date"
        Date column name.
    asset_col : str, default "asset"
        Asset/entity column used for panel joins.
    min_obs : int, default 10
        Minimum observations per date.

    Returns
    -------
    tuple[list[Any], list[float]]
        (dates, ic_values) for dates with valid IC.
    """
    return_col = f"{period}D_fwd_return"

    pred_df = data.select([date_col, asset_col, factor_col])
    ret_df = data.select([date_col, asset_col, return_col])

    ic_df = cross_sectional_ic_series(
        predictions=pred_df,
        returns=ret_df,
        pred_col=factor_col,
        ret_col=return_col,
        date_col=date_col,
        entity_col=asset_col,
        method=method,
        min_obs=min_obs,
    )

    if ic_df.height == 0:
        return [], []

    ic_clean = ic_df.filter(
        (pl.col("n_obs") >= min_obs) & pl.col("ic").cast(pl.Float64).is_finite()
    )
    dates = ic_clean[date_col].to_list()
    ic_values = ic_clean["ic"].cast(pl.Float64).to_list()
    return dates, ic_values
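
To make the idea of a per-date IC concrete, here is a dependency-free sketch of a Spearman IC computed date by date with a minimum-observation filter. It is an illustration only, not the implementation behind cross_sectional_ic_series: per_date_spearman_ic, average_ranks, pearson, and the sample rows are all hypothetical.

```python
from collections import defaultdict
from statistics import mean


def average_ranks(values):
    """Rank values 1..n, giving tied values their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def pearson(x, y):
    mx, my = mean(x), mean(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    return cov / (vx * vy) ** 0.5 if vx > 0 and vy > 0 else float("nan")


def per_date_spearman_ic(rows, min_obs=3):
    """rows: iterable of (date, factor, forward_return); returns (dates, ics)."""
    by_date = defaultdict(list)
    for date, factor, fwd in rows:
        if fwd is not None:
            by_date[date].append((factor, fwd))
    dates, ics = [], []
    for date in sorted(by_date):
        if len(by_date[date]) < min_obs:
            continue  # too few assets for a meaningful correlation
        f, r = zip(*by_date[date])
        ic = pearson(average_ranks(f), average_ranks(r))
        if ic == ic:  # drop NaN
            dates.append(date)
            ics.append(ic)
    return dates, ics


rows = (
    [("2024-01-02", f, 0.01 * f) for f in (1, 2, 3, 4)]      # perfectly aligned
    + [("2024-01-03", f, -0.01 * f) for f in (1, 2, 3, 4)]   # perfectly inverted
    + [("2024-01-04", 1, 0.01), ("2024-01-04", 2, 0.02)]     # below min_obs
)
dates, ics = per_date_spearman_ic(rows)
print(dates, ics)
```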

compute_ic_summary

compute_ic_summary(ic_series)

Compute summary statistics for an IC series.

Parameters

ic_series : list[float]
    IC values over time.

Returns

dict[str, float]
    mean, std, t_stat, p_value, pct_positive

Source code in src/ml4t/diagnostic/signal/signal_ic.py
def compute_ic_summary(
    ic_series: list[float],
) -> dict[str, float]:
    """Compute summary statistics for an IC series.

    Parameters
    ----------
    ic_series : list[float]
        IC values over time.

    Returns
    -------
    dict[str, float]
        mean, std, t_stat, p_value, pct_positive
    """
    summary = compute_ic_summary_stats(ic_series)
    return {
        "mean": float(summary["mean_ic"]),
        "std": float(summary["std_ic"]),
        "t_stat": float(summary["t_stat"]),
        "p_value": float(summary["p_value"]),
        "pct_positive": float(summary["pct_positive"]),
    }
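
The helper compute_ic_summary_stats is not shown on this page. As a rough guide to what such a summary usually contains, the sketch below assumes the conventional one-sample t-statistic of the mean IC against zero, using the sample standard deviation; the library's actual computation may differ, and ic_summary_sketch is a hypothetical name.

```python
import math
from statistics import mean, stdev


def ic_summary_sketch(ic_values):
    """Mean, sample std, and t-statistic (mean / (std / sqrt(n))) of an IC series."""
    n = len(ic_values)
    if n < 2:
        nan = float("nan")
        return {"mean": nan, "std": nan, "t_stat": nan}
    m = mean(ic_values)
    s = stdev(ic_values)  # sample standard deviation (n - 1 denominator)
    t = m / (s / math.sqrt(n)) if s > 0 else float("nan")
    return {"mean": m, "std": s, "t_stat": t}


summary = ic_summary_sketch([0.02, 0.04, 0.06, 0.08])
print(summary)  # mean near 0.05, t_stat near 3.873
```

A p-value would follow from the t-distribution with n - 1 degrees of freedom, which is omitted here to keep the sketch free of third-party dependencies.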

compute_quantile_returns

compute_quantile_returns(
    data, period, n_quantiles, quantile_col="quantile"
)

Compute mean forward returns by quantile.

Parameters

data : pl.DataFrame
    Data with quantile and forward return columns.
period : int
    Forward return period in days.
n_quantiles : int
    Number of quantiles.
quantile_col : str, default "quantile"
    Quantile column name.

Returns

dict[int, float]
    Mean return by quantile (1 = lowest factor).

Source code in src/ml4t/diagnostic/signal/quantile.py
def compute_quantile_returns(
    data: pl.DataFrame,
    period: int,
    n_quantiles: int,
    quantile_col: str = "quantile",
) -> dict[int, float]:
    """Compute mean forward returns by quantile.

    Parameters
    ----------
    data : pl.DataFrame
        Data with quantile and forward return columns.
    period : int
        Forward return period in days.
    n_quantiles : int
        Number of quantiles.
    quantile_col : str, default "quantile"
        Quantile column name.

    Returns
    -------
    dict[int, float]
        Mean return by quantile (1 = lowest factor).
    """
    return_col = f"{period}D_fwd_return"

    if return_col not in data.columns:
        return dict.fromkeys(range(1, n_quantiles + 1), float("nan"))

    result: dict[int, float] = {}

    quantile_means = (
        data.filter(pl.col(return_col).is_not_null())
        .group_by(quantile_col)
        .agg(pl.col(return_col).mean().alias("mean_return"))
        .sort(quantile_col)
    )

    for row in quantile_means.iter_rows(named=True):
        result[int(row[quantile_col])] = float(row["mean_return"])

    # Fill missing quantiles
    for q in range(1, n_quantiles + 1):
        if q not in result:
            result[q] = float("nan")

    return result
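
The grouping and gap-filling behaviour can be mimicked without Polars. This minimal sketch (hypothetical function name and data) averages returns per quantile, skips missing returns, and fills quantiles with no observations with NaN, mirroring the steps above.

```python
import math
from collections import defaultdict


def quantile_mean_returns(rows, n_quantiles):
    """rows: iterable of (quantile, forward_return or None)."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for q, ret in rows:
        if ret is None:
            continue  # null forward returns are excluded from the mean
        sums[q] += ret
        counts[q] += 1
    return {
        q: (sums[q] / counts[q] if counts[q] else float("nan"))
        for q in range(1, n_quantiles + 1)
    }


rows = [(1, -0.01), (1, -0.03), (2, 0.0), (3, 0.02), (3, None)]
result = quantile_mean_returns(rows, n_quantiles=4)
print(result)  # quantile 4 has no data, so its mean is NaN
```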

compute_spread

compute_spread(
    data, period, n_quantiles, quantile_col="quantile"
)

Compute long-short spread and statistics.

Parameters

data : pl.DataFrame
    Data with quantile and forward return columns.
period : int
    Forward return period in days.
n_quantiles : int
    Number of quantiles.
quantile_col : str, default "quantile"
    Quantile column name.

Returns

dict[str, float]
    spread, t_stat, p_value

Source code in src/ml4t/diagnostic/signal/quantile.py
def compute_spread(
    data: pl.DataFrame,
    period: int,
    n_quantiles: int,
    quantile_col: str = "quantile",
) -> dict[str, float]:
    """Compute long-short spread and statistics.

    Parameters
    ----------
    data : pl.DataFrame
        Data with quantile and forward return columns.
    period : int
        Forward return period in days.
    n_quantiles : int
        Number of quantiles.
    quantile_col : str, default "quantile"
        Quantile column name.

    Returns
    -------
    dict[str, float]
        spread, t_stat, p_value
    """
    return_col = f"{period}D_fwd_return"

    if return_col not in data.columns:
        return {
            "spread": float("nan"),
            "t_stat": float("nan"),
            "p_value": float("nan"),
        }

    top_returns = data.filter(pl.col(quantile_col) == n_quantiles)[return_col].to_numpy()
    bottom_returns = data.filter(pl.col(quantile_col) == 1)[return_col].to_numpy()

    top_returns = top_returns[~np.isnan(top_returns)]
    bottom_returns = bottom_returns[~np.isnan(bottom_returns)]

    if len(top_returns) < 2 or len(bottom_returns) < 2:
        return {
            "spread": float("nan"),
            "t_stat": float("nan"),
            "p_value": float("nan"),
        }

    spread = float(np.mean(top_returns) - np.mean(bottom_returns))
    t_stat, p_value = ttest_ind(top_returns, bottom_returns)

    return {
        "spread": spread,
        "t_stat": float(t_stat),
        "p_value": float(p_value),
    }
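
For intuition, the spread and its t-statistic can be reproduced by hand. The sketch assumes that ttest_ind in the source is scipy.stats.ttest_ind called with its default equal-variance (pooled) setting; the import is not shown on this page, so treat that as an assumption. The function name and sample returns are hypothetical.

```python
import math
from statistics import mean, variance


def spread_and_pooled_t(top, bottom):
    """Top-minus-bottom mean return and the pooled two-sample t-statistic."""
    n1, n2 = len(top), len(bottom)
    spread = mean(top) - mean(bottom)
    # Pooled sample variance across both groups.
    sp2 = ((n1 - 1) * variance(top) + (n2 - 1) * variance(bottom)) / (n1 + n2 - 2)
    se = math.sqrt(sp2 * (1 / n1 + 1 / n2))
    return spread, spread / se, se


top = [0.03, 0.01, 0.02]       # returns of the highest quantile
bottom = [-0.01, 0.00, 0.01]   # returns of the lowest quantile
spread, t_stat, se = spread_and_pooled_t(top, bottom)
print(spread, t_stat, se)  # spread near 0.02, t near 2.449
```

The relation se = spread / t_stat is the same identity analyze_signal uses to back out spread_std.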

compute_turnover

compute_turnover(
    data,
    n_quantiles,
    date_col="date",
    asset_col="asset",
    quantile_col="quantile",
)

Compute mean turnover rate across quantiles.

Turnover = fraction of assets that change quantile each period.

Parameters

data : pl.DataFrame
    Data with date, asset, and quantile columns.
n_quantiles : int
    Number of quantiles.
date_col, asset_col, quantile_col : str
    Column names.

Returns

float
    Mean turnover rate (0-1).

Source code in src/ml4t/diagnostic/signal/turnover.py
def compute_turnover(
    data: pl.DataFrame,
    n_quantiles: int,
    date_col: str = "date",
    asset_col: str = "asset",
    quantile_col: str = "quantile",
) -> float:
    """Compute mean turnover rate across quantiles.

    Turnover = fraction of assets that change quantile each period.

    Parameters
    ----------
    data : pl.DataFrame
        Data with date, asset, and quantile columns.
    n_quantiles : int
        Number of quantiles.
    date_col, asset_col, quantile_col : str
        Column names.

    Returns
    -------
    float
        Mean turnover rate (0-1).
    """
    unique_dates = data.select(date_col).unique().sort(date_col).to_series().to_list()

    if len(unique_dates) < 2:
        return float("nan")

    # Pre-compute asset sets per (date, quantile) using dict comprehension
    asset_lists = (
        data.group_by([date_col, quantile_col])
        .agg(pl.col(asset_col).alias("assets"))
        .sort([date_col, quantile_col])
    )
    # Use rows() for faster iteration (returns tuples)
    asset_sets: dict[tuple[Any, int], set[Any]] = {
        (row[0], row[1]): set(row[2]) for row in asset_lists.rows()
    }

    # Compute turnover for each quantile
    all_turnovers: list[float] = []

    for q in range(1, n_quantiles + 1):
        q_turnovers: list[float] = []

        for i in range(len(unique_dates) - 1):
            date_t = unique_dates[i]
            date_t1 = unique_dates[i + 1]

            assets_t = asset_sets.get((date_t, q), set())
            assets_t1 = asset_sets.get((date_t1, q), set())

            if assets_t and assets_t1:
                overlap = len(assets_t & assets_t1)
                turnover = 1 - overlap / max(len(assets_t), len(assets_t1))
                q_turnovers.append(turnover)

        if q_turnovers:
            all_turnovers.append(float(np.mean(q_turnovers)))

    return float(np.nanmean(all_turnovers)) if all_turnovers else float("nan")
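
The core of the calculation is the turnover of one quantile between two consecutive dates, defined in the source as one minus the overlap divided by the larger of the two set sizes. A small sketch with a hypothetical helper name and hypothetical tickers:

```python
def quantile_turnover(assets_t, assets_t1):
    """Turnover of one quantile between consecutive dates, or None if a side is empty."""
    if not assets_t or not assets_t1:
        return None  # the source skips such date pairs
    overlap = len(assets_t & assets_t1)
    return 1 - overlap / max(len(assets_t), len(assets_t1))


# Half of the top quantile is replaced.
print(quantile_turnover({"A", "B", "C", "D"}, {"A", "B", "E", "F"}))  # 0.5
# Unchanged membership.
print(quantile_turnover({"A", "B"}, {"A", "B"}))                      # 0.0
# The quantile doubles in size: dividing by the larger set still gives 0.5.
print(quantile_turnover({"A", "B"}, {"A", "B", "C", "D"}))            # 0.5
```

Dividing by the larger set means growth or shrinkage of a quantile counts as turnover even when no asset leaves it.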

estimate_half_life

estimate_half_life(autocorrelations)

Estimate half-life from autocorrelation decay.

Half-life is the lag where autocorrelation drops to 50% of lag-1 value.

Parameters

autocorrelations : list[float]
    Autocorrelation at lags 1, 2, 3, ...

Returns

float | None
    Half-life in periods, or None if undefined.

Source code in src/ml4t/diagnostic/signal/turnover.py
def estimate_half_life(autocorrelations: list[float]) -> float | None:
    """Estimate half-life from autocorrelation decay.

    Half-life is the lag where autocorrelation drops to 50% of lag-1 value.

    Parameters
    ----------
    autocorrelations : list[float]
        Autocorrelation at lags 1, 2, 3, ...

    Returns
    -------
    float | None
        Half-life in periods, or None if undefined.
    """
    valid_ac = [ac for ac in autocorrelations if not np.isnan(ac)]

    if len(valid_ac) < 2 or valid_ac[0] <= 0:
        return None

    threshold = 0.5 * valid_ac[0]

    for i, ac in enumerate(valid_ac):
        if ac < threshold:
            if i > 0:
                # Linear interpolation
                return i + (valid_ac[i - 1] - threshold) / (valid_ac[i - 1] - ac)
            return float(i + 1)

    return None  # Never decayed below threshold
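
A worked example helps with the interpolation. The sketch below restates the logic using only the standard library (half_life_sketch is a hypothetical name) and applies it to a made-up autocorrelation profile.

```python
import math


def half_life_sketch(autocorrelations):
    """Lag at which autocorrelation first falls below half its lag-1 value."""
    valid = [ac for ac in autocorrelations if not math.isnan(ac)]
    if len(valid) < 2 or valid[0] <= 0:
        return None
    threshold = 0.5 * valid[0]
    for i in range(1, len(valid)):
        if valid[i] < threshold:
            # valid[i - 1] is lag i and valid[i] is lag i + 1, so interpolate
            # linearly between those two lags to find the crossing point.
            return i + (valid[i - 1] - threshold) / (valid[i - 1] - valid[i])
    return None  # never decayed below the threshold


# Lag-1 autocorrelation is 0.8, so the threshold is 0.4. It is crossed between
# lag 2 (0.6) and lag 3 (0.3): 2 + (0.6 - 0.4) / (0.6 - 0.3), about 2.67.
half_life = half_life_sketch([0.8, 0.6, 0.3, 0.1])
print(half_life)
print(half_life_sketch([0.9, 0.8, 0.7]))  # None: still above 0.45 at the last lag
```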

Metrics

Use ml4t.diagnostic.metrics for reusable metric and feature-statistic helpers.

metrics

Metrics module for ML4T Diagnostic.

Provides statistical metrics and percentile computation utilities for model evaluation.

pooled_ic

pooled_ic(
    predictions,
    returns,
    method="spearman",
    confidence_intervals=False,
    alpha=0.05,
)

Calculate pooled IC across all observations.

This is a global correlation across the supplied arrays. For cross-sectional ranking strategies, prefer cross_sectional_ic or cross_sectional_ic_series so IC is computed per date before reduction.

Source code in src/ml4t/diagnostic/metrics/ic.py
def pooled_ic(
    predictions: Union[pl.Series, pd.Series, "NDArray[Any]"],
    returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
    method: str = "spearman",
    confidence_intervals: bool = False,
    alpha: float = 0.05,
) -> float | dict[str, float]:
    """Calculate pooled IC across all observations.

    This is a global correlation across the supplied arrays. For cross-sectional
    ranking strategies, prefer :func:`cross_sectional_ic` or
    :func:`cross_sectional_ic_series` so IC is computed per date before reduction.
    """
    return information_coefficient(
        predictions=predictions,
        returns=returns,
        method=method,
        confidence_intervals=confidence_intervals,
        alpha=alpha,
    )
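
To see why the distinction matters, the following sketch compares a pooled Spearman correlation with the mean of per-date correlations on a toy panel. It is a standard-library illustration only: the helpers `average_ranks`, `pearson`, and `spearman` are hypothetical stand-ins written for this example and do not call the package.

```python
def average_ranks(values):
    """Rank values from 1..n, giving tied values their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        for k in range(start, end + 1):
            ranks[order[k]] = (start + end) / 2 + 1
        start = end + 1
    return ranks


def pearson(x, y):
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    var_x = sum((a - mx) ** 2 for a in x)
    var_y = sum((b - my) ** 2 for b in y)
    return cov / (var_x * var_y) ** 0.5


def spearman(x, y):
    return pearson(average_ranks(x), average_ranks(y))


# Toy panel: within each date the prediction ranks assets in the wrong order,
# but date 2 has higher predictions and higher returns overall.
panel = {
    "2024-01-01": ([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]),
    "2024-01-02": ([11.0, 12.0, 13.0], [13.0, 12.0, 11.0]),
}

per_date_ic = {d: spearman(p, r) for d, (p, r) in panel.items()}
mean_per_date_ic = sum(per_date_ic.values()) / len(per_date_ic)

all_preds = [v for p, _ in panel.values() for v in p]
all_rets = [v for _, r in panel.values() for v in r]
pooled = spearman(all_preds, all_rets)

print(mean_per_date_ic, pooled)
```

The per-date IC is -1 on both dates, yet the pooled value is positive because it also rewards the shift in level between dates. A cross-sectional ranking strategy cannot trade on that shift.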

cross_sectional_ic_series

cross_sectional_ic_series(
    predictions,
    returns,
    pred_col="prediction",
    ret_col="forward_return",
    date_col="date",
    entity_col=None,
    method="spearman",
    min_obs=10,
)

Compute a per-date cross-sectional IC time series.

This function computes the Information Coefficient for each time period (typically daily), enabling temporal analysis of prediction quality.

Parameters

predictions : Union[pl.DataFrame, pd.DataFrame] DataFrame with predictions, indexed or with date column
returns : Union[pl.DataFrame, pd.DataFrame] DataFrame with forward returns, matching predictions structure
pred_col : str, default "prediction" Column name for predictions/features
ret_col : str, default "forward_return" Column name for forward returns
date_col : str, default "date" Column name for dates (for grouping by period)
entity_col : str or list[str] or None, default None Entity column(s) for panel data (e.g., "symbol" or ["symbol"]). When provided, the join includes the entity columns to avoid Cartesian products. Required for cross-sectional data with multiple entities per date.
method : str, default "spearman" Correlation method: "spearman" or "pearson"
min_obs : int, default 10 Minimum observations per period for valid IC calculation

Returns

Union[pl.DataFrame, pd.DataFrame] Time series of IC values with columns: [date_col, 'ic', 'n_obs']

Examples

Panel data with multiple symbols per date

    pred_df = pl.DataFrame({
        "date": ["2024-01-01"] * 4 + ["2024-01-02"] * 4,
        "symbol": ["SPY", "QQQ", "IWM", "DIA"] * 2,
        "prediction": np.random.randn(8),
    })
    ret_df = pl.DataFrame({
        "date": ["2024-01-01"] * 4 + ["2024-01-02"] * 4,
        "symbol": ["SPY", "QQQ", "IWM", "DIA"] * 2,
        "forward_return": np.random.randn(8) * 0.02,
    })
    ic_series = cross_sectional_ic_series(pred_df, ret_df, entity_col="symbol")

Note that this toy panel has only four observations per date, so with the default min_obs=10 every ic value is null; pass a smaller min_obs to obtain values on a panel this small.

Source code in src/ml4t/diagnostic/metrics/ic.py
def cross_sectional_ic_series(
    predictions: pl.DataFrame | pd.DataFrame,
    returns: pl.DataFrame | pd.DataFrame,
    pred_col: str = "prediction",
    ret_col: str = "forward_return",
    date_col: str = "date",
    entity_col: str | list[str] | None = None,
    method: str = "spearman",
    min_obs: int = 10,
) -> pl.DataFrame | pd.DataFrame:
    """Compute a per-date cross-sectional IC time series.

    This function computes the Information Coefficient for each time period
    (typically daily), enabling temporal analysis of prediction quality.

    Parameters
    ----------
    predictions : Union[pl.DataFrame, pd.DataFrame]
        DataFrame with predictions, indexed or with date column
    returns : Union[pl.DataFrame, pd.DataFrame]
        DataFrame with forward returns, matching predictions structure
    pred_col : str, default "prediction"
        Column name for predictions/features
    ret_col : str, default "forward_return"
        Column name for forward returns
    date_col : str, default "date"
        Column name for dates (for grouping by period)
    entity_col : str or list[str] or None, default None
        Entity column(s) for panel data (e.g., "symbol" or ["symbol"]).
        When provided, join includes entity columns to avoid Cartesian
        products. Required for cross-sectional data with multiple
        entities per date.
    method : str, default "spearman"
        Correlation method: "spearman" or "pearson"
    min_obs : int, default 10
        Minimum observations per period for valid IC calculation

    Returns
    -------
    Union[pl.DataFrame, pd.DataFrame]
        Time series of IC values with columns: [date_col, 'ic', 'n_obs']

    Examples
    --------
    >>> # Panel data with multiple symbols per date
    >>> pred_df = pl.DataFrame({
    ...     "date": ["2024-01-01"] * 4 + ["2024-01-02"] * 4,
    ...     "symbol": ["SPY", "QQQ", "IWM", "DIA"] * 2,
    ...     "prediction": np.random.randn(8),
    ... })
    >>> ret_df = pl.DataFrame({
    ...     "date": ["2024-01-01"] * 4 + ["2024-01-02"] * 4,
    ...     "symbol": ["SPY", "QQQ", "IWM", "DIA"] * 2,
    ...     "forward_return": np.random.randn(8) * 0.02,
    ... })
    >>> ic_series = cross_sectional_ic_series(pred_df, ret_df, entity_col="symbol")
    """
    output_as_pandas = isinstance(predictions, pd.DataFrame)

    # Build join columns (date + entity for panel data)
    join_on: list[str] = [date_col]
    if entity_col is not None:
        if isinstance(entity_col, str):
            join_on.append(entity_col)
        else:
            join_on.extend(entity_col)

    predictions_pl = (
        predictions
        if isinstance(predictions, pl.DataFrame)
        else pl.from_pandas(cast(pd.DataFrame, predictions))
    )
    returns_pl = (
        returns
        if isinstance(returns, pl.DataFrame)
        else pl.from_pandas(cast(pd.DataFrame, returns))
    )

    # Merge predictions and returns
    df = predictions_pl.join(returns_pl, on=join_on, how="inner")

    if method not in ("spearman", "pearson"):
        raise ValueError(f"Unknown method: {method!r}. Use 'spearman' or 'pearson'.")

    # Vectorized per-date IC via polars: rank within each date group, then
    # correlate ranks per group. This avoids a Python loop over dates.
    valid_expr = pl.col(pred_col).is_finite() & pl.col(ret_col).is_finite()

    if method == "spearman":
        # Set invalid rows to null so rank() leaves them as null within
        # each date group; pl.corr ignores null pairs.
        df_valid = df.with_columns(
            [
                pl.when(valid_expr).then(pl.col(pred_col)).otherwise(None).alias("__pred_valid"),
                pl.when(valid_expr).then(pl.col(ret_col)).otherwise(None).alias("__ret_valid"),
            ]
        ).with_columns(
            [
                pl.col("__pred_valid").rank(method="average").over(date_col).alias("__pred_r"),
                pl.col("__ret_valid").rank(method="average").over(date_col).alias("__ret_r"),
            ]
        )
        p_col, r_col = "__pred_r", "__ret_r"
    else:
        df_valid = df.with_columns(
            [
                pl.when(valid_expr).then(pl.col(pred_col)).otherwise(None).alias("__pred_valid"),
                pl.when(valid_expr).then(pl.col(ret_col)).otherwise(None).alias("__ret_valid"),
            ]
        )
        p_col, r_col = "__pred_valid", "__ret_valid"

    # Preserve every date present in the join, even those where all rows
    # were invalid (matches the old per-group behavior of emitting a row
    # with n_obs=0 and ic=NaN).
    all_dates = df.select(date_col).unique().sort(date_col)
    grouped = df_valid.group_by(date_col, maintain_order=False).agg(
        [
            valid_expr.sum().alias("n_obs"),
            pl.corr(pl.col(p_col), pl.col(r_col)).alias("ic"),
        ]
    )
    ic_series_pl = (
        all_dates.join(grouped, on=date_col, how="left")
        .with_columns(
            [
                pl.col("n_obs").fill_null(0),
                pl.when(pl.col("n_obs") >= min_obs).then(pl.col("ic")).otherwise(None).alias("ic"),
            ]
        )
        .sort(date_col)
    )

    if output_as_pandas:
        return ic_series_pl.to_pandas()
    return ic_series_pl
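
The role of entity_col in the join can be illustrated without polars. The sketch below uses a naive, hypothetical `inner_join` helper over lists of dicts (not the library's join) to show how joining on the date alone pairs every prediction with every return on that date.

```python
symbols = ["SPY", "QQQ", "IWM", "DIA"]
pred_rows = [
    {"date": "2024-01-01", "symbol": s, "prediction": float(i)}
    for i, s in enumerate(symbols)
]
ret_rows = [
    {"date": "2024-01-01", "symbol": s, "forward_return": 0.01 * i}
    for i, s in enumerate(symbols)
]


def inner_join(left, right, keys):
    """Naive inner join on the given key columns (illustrative only)."""
    return [
        {**left_row, **right_row}
        for left_row in left
        for right_row in right
        if all(left_row[k] == right_row[k] for k in keys)
    ]


date_only = inner_join(pred_rows, ret_rows, ["date"])
date_and_symbol = inner_join(pred_rows, ret_rows, ["date", "symbol"])
print(len(date_only), len(date_and_symbol))
```

With four symbols on one date, the date-only join produces 16 rows, most of which pair one symbol's prediction with another symbol's return. Joining on date and symbol keeps the 4 correctly matched rows.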

cross_sectional_ic

cross_sectional_ic(
    predictions,
    returns,
    pred_col="prediction",
    ret_col="forward_return",
    date_col="date",
    entity_col=None,
    method="spearman",
    min_obs=10,
)

Compute per-date cross-sectional IC and return aggregate statistics.

Returns the canonical non-HAC summary for the valid per-date IC series: mean, sample standard deviation, t-statistic, p-value, percent positive, number of periods, and non-annualized IC information ratio.

Source code in src/ml4t/diagnostic/metrics/ic.py
def cross_sectional_ic(
    predictions: pl.DataFrame | pd.DataFrame,
    returns: pl.DataFrame | pd.DataFrame,
    pred_col: str = "prediction",
    ret_col: str = "forward_return",
    date_col: str = "date",
    entity_col: str | list[str] | None = None,
    method: str = "spearman",
    min_obs: int = 10,
) -> dict[str, float | int]:
    """Compute per-date cross-sectional IC and return aggregate statistics.

    Returns the canonical non-HAC summary for the valid per-date IC series:
    mean, sample standard deviation, t-statistic, p-value, percent positive,
    number of periods, and non-annualized IC information ratio.
    """
    ic_df = cross_sectional_ic_series(
        predictions=predictions,
        returns=returns,
        pred_col=pred_col,
        ret_col=ret_col,
        date_col=date_col,
        entity_col=entity_col,
        method=method,
        min_obs=min_obs,
    )

    from ml4t.diagnostic.metrics.ic_inference import compute_ic_summary_stats

    summary = compute_ic_summary_stats(ic_df, ic_col="ic")
    mean_ic = float(summary["mean_ic"])
    std_ic = float(summary["std_ic"])
    ic_ir = mean_ic / std_ic if np.isfinite(std_ic) and std_ic > 0 else np.nan
    return {
        "ic_mean": mean_ic,
        "ic_std": std_ic,
        "ic_t": float(summary["t_stat"]),
        "p_value": float(summary["p_value"]),
        "pct_positive": float(summary["pct_positive"]),
        "n_periods": int(summary["n_periods"]),
        "ic_ir": float(ic_ir),
    }
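
As a small worked example of these summary statistics, the sketch below computes the mean, sample standard deviation, information ratio, and t-statistic for a made-up IC series using only the standard library. The p-value is omitted because it needs a Student-t distribution function.

```python
import math
import statistics

ic_values = [0.02, 0.04, 0.06, 0.00, 0.03]  # made-up per-date ICs

n = len(ic_values)
mean_ic = statistics.fmean(ic_values)
std_ic = statistics.stdev(ic_values)         # sample standard deviation (ddof=1)
ic_ir = mean_ic / std_ic                     # non-annualized information ratio
t_stat = mean_ic / (std_ic / math.sqrt(n))   # one-sample t-statistic for mean IC
pct_positive = sum(v > 0 for v in ic_values) / n

print(mean_ic, std_ic, ic_ir, t_stat, pct_positive)
```

The t-statistic equals the information ratio multiplied by sqrt(n), so for a fixed IR the evidence for a nonzero mean IC grows with the number of valid periods.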

compute_ic_summary_stats

compute_ic_summary_stats(ic_series, ic_col='ic')

Compute naive summary stats for an IC time series.

This is the canonical non-HAC summary used by signal-level diagnostics and simple fallback paths where robust autocorrelation adjustment is disabled.

Source code in src/ml4t/diagnostic/metrics/ic_inference.py
def compute_ic_summary_stats(
    ic_series: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]", list[float]],
    ic_col: str = "ic",
) -> dict[str, float | int]:
    """Compute naive summary stats for an IC time series.

    This is the canonical non-HAC summary used by signal-level diagnostics and
    simple fallback paths where robust autocorrelation adjustment is disabled.
    """
    ic_values: NDArray[Any]
    if isinstance(ic_series, pl.DataFrame | pd.DataFrame):
        # polars and pandas columns both expose .to_numpy(), so one branch suffices
        ic_values = ic_series[ic_col].to_numpy()
    else:
        ic_values = np.asarray(ic_series).flatten()

    ic_clean: NDArray[Any] = ic_values[~np.isnan(ic_values)]
    n = len(ic_clean)

    if n < 2:
        return {
            "mean_ic": np.nan,
            "std_ic": np.nan,
            "t_stat": np.nan,
            "p_value": np.nan,
            "pct_positive": np.nan,
            "n_periods": n,
        }

    mean_ic = float(np.mean(ic_clean))
    std_ic = float(np.std(ic_clean, ddof=1))
    t_stat = mean_ic / (std_ic / np.sqrt(n)) if std_ic > 0 else np.nan
    p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=n - 1)) if not np.isnan(t_stat) else np.nan

    return {
        "mean_ic": mean_ic,
        "std_ic": std_ic,
        "t_stat": float(t_stat),
        "p_value": float(p_value),
        "pct_positive": float(np.mean(ic_clean > 0)),
        "n_periods": n,
    }

compute_ic_hac_stats

compute_ic_hac_stats(
    ic_series,
    ic_col="ic",
    maxlags=None,
    kernel="bartlett",
    use_correction=True,
)

Compute HAC-adjusted significance statistics for IC time series.

Uses Newey-West HAC (Heteroskedasticity and Autocorrelation Consistent) standard errors to account for autocorrelation in IC time series. This provides robust t-statistics and p-values when IC exhibits serial correlation.

The Newey-West estimator accounts for:
1. Heteroskedasticity: Non-constant variance in IC over time
2. Autocorrelation: Serial correlation in IC values
3. Lag selection: Automatic selection of optimal lag window

Parameters

ic_series : Union[pl.DataFrame, pd.DataFrame, np.ndarray] Time series of IC values (from cross_sectional_ic_series)
ic_col : str, default "ic" Column name for IC values (if DataFrame)
maxlags : int | None, default None Maximum lag for HAC adjustment. If None, uses the Newey-West formula maxlags = floor(4 * (T/100)^(2/9)), where T is the sample size
kernel : str, default "bartlett" Kernel function for lag weighting:
- "bartlett": Triangular kernel (Newey-West default)
- "uniform": Equal weights
- "parzen": Parzen kernel
use_correction : bool, default True Apply small-sample correction to standard errors

Returns

dict[str, float] Dictionary with HAC-adjusted statistics:
- mean_ic: Mean IC across time series
- hac_se: HAC-adjusted standard error
- t_stat: t-statistic (mean_ic / hac_se)
- p_value: Two-tailed p-value for H0: IC = 0
- n_periods: Number of observations
- effective_lags: Number of lags used in HAC adjustment
- naive_se: Standard OLS standard error (for comparison)
- naive_t_stat: Naive t-statistic without HAC adjustment

Examples

Compute the IC series first, then the HAC-adjusted statistics:

    ic_series = cross_sectional_ic_series(pred_df, ret_df)
    stats = compute_ic_hac_stats(ic_series)
    print(f"Mean IC: {stats['mean_ic']:.4f}")
    print(f"HAC t-stat: {stats['t_stat']:.2f}")
    print(f"P-value: {stats['p_value']:.4f}")
    print(f"Significant: {stats['p_value'] < 0.05}")

Output:

    Mean IC: 0.0234
    HAC t-stat: 2.14
    P-value: 0.0327
    Significant: True

Compare with the naive statistics. The adjustment factor is the ratio of the HAC standard error to the naive one, which equals the naive t-statistic divided by the HAC t-statistic (3.45 / 2.14 here):

    print(f"Naive t-stat: {stats['naive_t_stat']:.2f}")
    print(f"HAC adjustment factor: {stats['hac_se'] / stats['naive_se']:.2f}x")

Output:

    Naive t-stat: 3.45
    HAC adjustment factor: 1.61x

Notes

HAC Adjustment Interpretation:
- HAC SE > Naive SE: Positive autocorrelation detected
- HAC SE < Naive SE: Negative autocorrelation (rare)
- HAC SE ~ Naive SE: Little autocorrelation

The Newey-West automatic lag selection formula is

    maxlags = floor(4 * (T/100)^(2/9))

For example:
- T=100 -> maxlags=4
- T=252 -> maxlags=4
- T=500 -> maxlags=5
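
The rule is easy to evaluate directly. The sketch below computes it for a few sample sizes and applies the same clamping to the range 1 to T // 2 that the source listing for this function uses; `newey_west_lags` is a hypothetical helper name for this example.

```python
import math


def newey_west_lags(n_obs):
    """Rule-of-thumb lag count floor(4 * (T/100)^(2/9)), clamped to [1, T // 2]."""
    lags = math.floor(4 * (n_obs / 100) ** (2 / 9))
    return min(max(1, lags), n_obs // 2)


for t in (100, 252, 500, 1000):
    print(t, newey_west_lags(t))
```

Because the exponent is only 2/9, the lag count grows slowly: one year of daily IC values (T=252) still gives 4 lags, and T=1000 gives 6.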

References

[1] Newey, W. K., & West, K. D. (1987). "A Simple, Positive Semi-Definite, Heteroskedasticity and Autocorrelation Consistent Covariance Matrix." Econometrica, 55(3), 703-708.
[2] Andrews, D. W. K. (1991). "Heteroskedasticity and Autocorrelation Consistent Covariance Matrix Estimation." Econometrica, 59(3), 817-858.

Source code in src/ml4t/diagnostic/metrics/ic_inference.py
def compute_ic_hac_stats(
    ic_series: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    ic_col: str = "ic",
    maxlags: int | None = None,
    kernel: str = "bartlett",
    use_correction: bool = True,
) -> dict[str, float]:
    """Compute HAC-adjusted significance statistics for IC time series.

    Uses Newey-West HAC (Heteroskedasticity and Autocorrelation Consistent)
    standard errors to account for autocorrelation in IC time series. This
    provides robust t-statistics and p-values when IC exhibits serial correlation.

    The Newey-West estimator accounts for:
    1. Heteroskedasticity: Non-constant variance in IC over time
    2. Autocorrelation: Serial correlation in IC values
    3. Lag selection: Automatic selection of optimal lag window

    Parameters
    ----------
    ic_series : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Time series of IC values (from cross_sectional_ic_series)
    ic_col : str, default "ic"
        Column name for IC values (if DataFrame)
    maxlags : int | None, default None
        Maximum lag for HAC adjustment. If None, uses Newey-West formula:
        maxlags = floor(4 * (T/100)^(2/9))
        where T is the sample size
    kernel : str, default "bartlett"
        Kernel function for lag weighting:
        - "bartlett": Triangular kernel (Newey-West default)
        - "uniform": Equal weights
        - "parzen": Parzen kernel
    use_correction : bool, default True
        Apply small-sample correction to standard errors

    Returns
    -------
    dict[str, float]
        Dictionary with HAC-adjusted statistics:
        - mean_ic: Mean IC across time series
        - hac_se: HAC-adjusted standard error
        - t_stat: t-statistic (mean_ic / hac_se)
        - p_value: Two-tailed p-value for H0: IC = 0
        - n_periods: Number of observations
        - effective_lags: Number of lags used in HAC adjustment
        - naive_se: Standard OLS standard error (for comparison)
        - naive_t_stat: Naive t-statistic without HAC adjustment

    Examples
    --------
    >>> # Compute IC series first
    >>> ic_series = cross_sectional_ic_series(pred_df, ret_df)
    >>>
    >>> # Compute HAC-adjusted statistics
    >>> stats = compute_ic_hac_stats(ic_series)
    >>> print(f"Mean IC: {stats['mean_ic']:.4f}")
    >>> print(f"HAC t-stat: {stats['t_stat']:.2f}")
    >>> print(f"P-value: {stats['p_value']:.4f}")
    >>> print(f"Significant: {stats['p_value'] < 0.05}")
    Mean IC: 0.0234
    HAC t-stat: 2.14
    P-value: 0.0327
    Significant: True
    >>>
    >>> # Compare with naive statistics
    >>> print(f"Naive t-stat: {stats['naive_t_stat']:.2f}")
    >>> print(f"HAC adjustment factor: {stats['hac_se'] / stats['naive_se']:.2f}x")
    Naive t-stat: 3.45
    HAC adjustment factor: 1.61x

    Notes
    -----
    HAC Adjustment Interpretation:
    - HAC SE > Naive SE: Positive autocorrelation detected
    - HAC SE < Naive SE: Negative autocorrelation (rare)
    - HAC SE ~ Naive SE: Little autocorrelation

    The Newey-West automatic lag selection formula is:
        maxlags = floor(4 * (T/100)^(2/9))

    For example:
    - T=100 -> maxlags=4
    - T=252 -> maxlags=4
    - T=500 -> maxlags=5

    References
    ----------
    .. [1] Newey, W. K., & West, K. D. (1987). "A Simple, Positive Semi-Definite,
           Heteroskedasticity and Autocorrelation Consistent Covariance Matrix."
           Econometrica, 55(3), 703-708.
    .. [2] Andrews, D. W. K. (1991). "Heteroskedasticity and Autocorrelation
           Consistent Covariance Matrix Estimation." Econometrica, 59(3), 817-858.
    """
    # Extract IC values
    ic_values: NDArray[Any]
    if isinstance(ic_series, pl.DataFrame | pd.DataFrame):
        is_polars = isinstance(ic_series, pl.DataFrame)
        if is_polars:
            ic_values = cast(pl.DataFrame, ic_series)[ic_col].to_numpy()
        else:
            ic_values = cast(pd.DataFrame, ic_series)[ic_col].to_numpy()
    else:
        ic_values = np.asarray(ic_series).flatten()

    # Remove NaN values
    ic_clean: NDArray[Any] = ic_values[~np.isnan(ic_values)]

    # Validate sufficient data
    n = len(ic_clean)
    if n < 3:
        return {
            "mean_ic": np.nan,
            "hac_se": np.nan,
            "t_stat": np.nan,
            "p_value": np.nan,
            "n_periods": n,
            "effective_lags": 0,
            "naive_se": np.nan,
            "naive_t_stat": np.nan,
        }

    # Compute mean IC
    mean_ic = float(np.mean(ic_clean))

    # Compute naive (OLS) standard error
    naive_var = float(np.var(ic_clean, ddof=1))  # Sample variance
    naive_se = np.sqrt(naive_var / n)  # Standard error of mean
    naive_t_stat = mean_ic / naive_se if naive_se > 0 else np.nan

    # Determine optimal lags if not specified
    if maxlags is None:
        # Newey-West automatic lag selection formula
        # maxlags = floor(4 * (T/100)^(2/9))
        maxlags = int(np.floor(4 * (n / 100) ** (2 / 9)))
        maxlags = max(1, maxlags)  # At least 1 lag
        maxlags = min(maxlags, n // 2)  # No more than T/2

    # Fit OLS model: IC ~ constant (testing if mean IC != 0)
    # This is equivalent to a one-sample t-test
    exog = np.ones((n, 1))  # Just constant term
    y = ic_clean.reshape(-1, 1)

    # Compute HAC covariance matrix
    try:
        # Fit OLS model
        model = OLS(y, exog)
        ols_results = model.fit()

        # Get HAC-robust covariance matrix
        hac_cov = cov_hac(
            ols_results,
            nlags=maxlags,
            weights_func=_get_kernel_weights(kernel),
            use_correction=use_correction,
        )

        # Extract HAC variance (it's a 1x1 matrix for the constant)
        hac_var = hac_cov[0, 0]
        hac_se = np.sqrt(hac_var)

    except Exception as e:
        # If HAC computation fails, fall back to naive SE
        print(f"Warning: HAC computation failed ({e}), using naive SE")
        hac_se = naive_se

    # Compute HAC-adjusted t-statistic
    t_stat = mean_ic / hac_se if hac_se > 0 else np.nan

    # Compute two-tailed p-value
    # Use t-distribution with n-1 degrees of freedom
    p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=n - 1)) if not np.isnan(t_stat) else np.nan

    return {
        "mean_ic": float(mean_ic),
        "hac_se": float(hac_se),
        "t_stat": float(t_stat),
        "p_value": float(p_value),
        "n_periods": n,
        "effective_lags": maxlags,
        "naive_se": float(naive_se),
        "naive_t_stat": float(naive_t_stat),
    }
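
For intuition about what the HAC adjustment does to the standard error of the mean, the sketch below implements a Bartlett-kernel estimator by hand for a constant-only model. It is meant to mirror the idea of the computation above, but it is an assumption-laden illustration: it has not been validated against the statsmodels output, and `hac_se_of_mean` and `naive_se_of_mean` are hypothetical names.

```python
import math


def hac_se_of_mean(values, maxlags):
    """Bartlett-kernel (Newey-West) standard error of the sample mean."""
    n = len(values)
    mean = sum(values) / n
    resid = [v - mean for v in values]

    # Lag-0 term plus Bartlett-weighted sums of lagged residual products.
    s = sum(e * e for e in resid)
    for lag in range(1, maxlags + 1):
        weight = 1.0 - lag / (maxlags + 1)
        s += 2.0 * weight * sum(resid[t] * resid[t - lag] for t in range(lag, n))

    # Sandwich form for a constant-only regression, with an n / (n - 1)
    # small-sample correction (one estimated parameter).
    return math.sqrt(s / n**2 * n / (n - 1))


def naive_se_of_mean(values):
    n = len(values)
    mean = sum(values) / n
    sample_var = sum((v - mean) ** 2 for v in values) / (n - 1)
    return math.sqrt(sample_var / n)


# Made-up IC series that moves in blocks, i.e. is positively autocorrelated.
ic = [0.04] * 3 + [0.0] * 3 + [0.04] * 3 + [0.0] * 3
naive_se = naive_se_of_mean(ic)
hac_se = hac_se_of_mean(ic, maxlags=2)
print(naive_se, hac_se)
```

With zero lags the estimator reduces to the naive standard error. For this block-structured series the lagged products are positive on balance, so the HAC standard error is larger and the corresponding t-statistic smaller.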

compute_ic_decay

compute_ic_decay(
    predictions,
    prices,
    horizons=None,
    pred_col="prediction",
    price_col="close",
    date_col="date",
    group_col=None,
    method="spearman",
    estimate_half_life=True,
)

Analyze how IC decays over prediction horizons.

Computes IC at multiple forward-looking horizons to understand how long predictions retain predictive power. Faster IC decay indicates shorter signal persistence.

This is critical for:
1. Determining optimal holding periods
2. Understanding alpha decay dynamics
3. Identifying when to retrain models
4. Avoiding stale predictions

Parameters

predictions : Union[pl.DataFrame, pd.DataFrame] DataFrame with predictions, must have pred_col, date_col, and optionally group_col
prices : Union[pl.DataFrame, pd.DataFrame] DataFrame with prices, must have price_col, date_col, and optionally group_col
horizons : list[int] | None, default None List of forward horizons in days. If None, uses 1, 2, 5, 10, and 21.
pred_col : str, default "prediction" Column name for predictions
price_col : str, default "close" Column name for prices
date_col : str, default "date" Column name for dates
group_col : str | None, default None Column name for grouping (e.g., "symbol" for multi-asset)
method : str, default "spearman" Correlation method: "spearman" or "pearson"
estimate_half_life : bool, default True Whether to estimate IC half-life (horizon where IC drops to 50% of initial)

Returns

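dict[str, Any] Dictionary with decay analysis:
- ic_by_horizon: dict mapping horizon -> IC value
- horizons: list of horizons analyzed
- decay_rate: exponential decay rate, or None if it cannot be estimated
- half_life: estimated half-life in days (if estimate_half_life=True), or None if it cannot be estimated
- optimal_horizon: horizon with the highest absolute IC
- optimal_ic: IC at the optimal horizon, or None if undefined
- n_observations: dict mapping horizon -> observation count (approximated by the number of rows in predictions)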

Examples

Analyze IC decay for multi-asset predictions

    decay = compute_ic_decay(
        predictions=pred_df,
        prices=price_df,
        horizons=[1, 2, 5, 10, 21],
        group_col="symbol"
    )
    print(f"IC at 1-day: {decay['ic_by_horizon'].get(1):.3f}")
    print(f"IC at 21-day: {decay['ic_by_horizon'].get(21):.3f}")
    print(f"Half-life: {decay['half_life']:.1f} days")
    print(f"Optimal horizon: {decay['optimal_horizon']} days")

Output:

    IC at 1-day: 0.045
    IC at 21-day: 0.012
    Half-life: 8.3 days
    Optimal horizon: 1 days

Notes

IC Decay Patterns:
- Fast decay: IC drops >50% within 5 days -> high-frequency signal
- Moderate decay: IC half-life 5-20 days -> medium-term signal
- Slow decay: IC half-life >20 days -> long-term signal
- No decay: IC stable -> structural/fundamental signal

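Half-life is estimated by fitting an exponential decay through a linear regression of ln(IC) on the horizon:

    IC(h) = IC(0) * exp(-lambda * h)
    half_life = ln(2) / lambda

If any IC value is not positive, the fit is attempted on |IC| instead.

The optimal horizon is the horizon with the maximum absolute IC, which is useful for choosing a rebalancing frequency.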

References

[1] Kakushadze, Z. (2016). "101 Formulaic Alphas." Wilmott, 2016(84), 72-81.

Source code in src/ml4t/diagnostic/metrics/ic_inference.py
def compute_ic_decay(
    predictions: pl.DataFrame | pd.DataFrame,
    prices: pl.DataFrame | pd.DataFrame,
    horizons: list[int] | None = None,
    pred_col: str = "prediction",
    price_col: str = "close",
    date_col: str = "date",
    group_col: str | None = None,
    method: str = "spearman",
    estimate_half_life: bool = True,
) -> dict[str, Any]:
    """Analyze how IC decays over prediction horizons.

    Computes IC at multiple forward-looking horizons to understand how long
    predictions retain predictive power. Faster IC decay indicates shorter
    signal persistence.

    This is critical for:
    1. Determining optimal holding periods
    2. Understanding alpha decay dynamics
    3. Identifying when to retrain models
    4. Avoiding stale predictions

    Parameters
    ----------
    predictions : Union[pl.DataFrame, pd.DataFrame]
        DataFrame with predictions, must have pred_col, date_col, and optionally group_col
    prices : Union[pl.DataFrame, pd.DataFrame]
        DataFrame with prices, must have price_col, date_col, and optionally group_col
    horizons : list[int] | None, default None
        List of forward horizons in days. If None, uses 1, 2, 5, 10, and 21.
    pred_col : str, default "prediction"
        Column name for predictions
    price_col : str, default "close"
        Column name for prices
    date_col : str, default "date"
        Column name for dates
    group_col : str | None, default None
        Column name for grouping (e.g., "symbol" for multi-asset)
    method : str, default "spearman"
        Correlation method: "spearman" or "pearson"
    estimate_half_life : bool, default True
        Whether to estimate IC half-life (horizon where IC drops to 50% of initial)

    Returns
    -------
    dict[str, Any]
        Dictionary with decay analysis:
        - ic_by_horizon: dict mapping horizon -> IC value
        - horizons: list of horizons analyzed
        - decay_rate: exponential decay rate (if estimable)
        - half_life: estimated half-life in days (if estimate_half_life=True)
        - optimal_horizon: horizon with highest absolute IC
        - optimal_ic: IC at the optimal horizon, or None if undefined
        - n_observations: number of observations per horizon

    Examples
    --------
    >>> # Analyze IC decay for multi-asset predictions
    >>> decay = compute_ic_decay(
    ...     predictions=pred_df,
    ...     prices=price_df,
    ...     horizons=[1, 2, 5, 10, 21],
    ...     group_col="symbol"
    ... )
    >>> print(f"IC at 1-day: {decay['ic_by_horizon'].get(1):.3f}")
    >>> print(f"IC at 21-day: {decay['ic_by_horizon'].get(21):.3f}")
    >>> print(f"Half-life: {decay['half_life']:.1f} days")
    >>> print(f"Optimal horizon: {decay['optimal_horizon']} days")
    IC at 1-day: 0.045
    IC at 21-day: 0.012
    Half-life: 8.3 days
    Optimal horizon: 1 days

    Notes
    -----
    IC Decay Patterns:
    - Fast decay: IC drops >50% within 5 days -> high-frequency signal
    - Moderate decay: IC half-life 5-20 days -> medium-term signal
    - Slow decay: IC half-life >20 days -> long-term signal
    - No decay: IC stable -> structural/fundamental signal

    Half-life is estimated by fitting exponential decay:
        IC(h) = IC(0) * exp(-lambda * h)
        half_life = ln(2) / lambda

    Optimal horizon is the horizon with maximum absolute IC, useful for
    determining best rebalancing frequency.

    References
    ----------
    .. [1] Kakushadze, Z. (2016). "101 Formulaic Alphas." Wilmott, 2016(84), 72-81.
    """
    # Set default horizons if not provided
    if horizons is None:
        horizons = [1, 2, 5, 10, 21]

    # Ensure horizons are sorted
    horizons = sorted(horizons)

    # Compute IC for each horizon using compute_ic_by_horizon
    ic_results = compute_ic_by_horizon(
        predictions=predictions,
        prices=prices,
        horizons=horizons,
        pred_col=pred_col,
        price_col=price_col,
        date_col=date_col,
        group_col=group_col,
        method=method,
    )

    # Extract IC values and observation counts
    ic_by_horizon: dict[int, float] = {}
    n_obs_by_horizon: dict[int, int] = {}

    for horizon, ic_value in ic_results.items():
        ic_by_horizon[horizon] = ic_value
        # Note: compute_ic_by_horizon returns just IC values, not counts
        # We'll approximate n_obs from the input data
        n_obs_by_horizon[horizon] = len(predictions)

    # Find optimal horizon (highest absolute IC)
    optimal_ic: float
    optimal_horizon: int | None
    if ic_by_horizon:
        optimal_horizon = max(ic_by_horizon.keys(), key=lambda h: abs(ic_by_horizon[h]))
        optimal_ic = ic_by_horizon[optimal_horizon]
    else:
        optimal_horizon = None
        optimal_ic = np.nan

    # Estimate decay rate and half-life
    decay_rate = np.nan
    half_life = np.nan

    if estimate_half_life and len(ic_by_horizon) >= 2:
        # Extract horizons and IC values for fitting
        h_vals = np.array(list(ic_by_horizon.keys()))
        ic_vals = np.array([ic_by_horizon[h] for h in h_vals])

        # Remove NaN values
        valid_mask = ~np.isnan(ic_vals)
        h_vals = h_vals[valid_mask]
        ic_vals = ic_vals[valid_mask]

        if len(h_vals) >= 2 and np.all(ic_vals > 0):
            # Fit exponential decay: IC(h) = IC(0) * exp(-lambda * h)
            # Take log: ln(IC(h)) = ln(IC(0)) - lambda * h
            # This is linear regression: y = a + b*x where b = -lambda

            try:
                log_ic = np.log(ic_vals)

                # Linear regression
                coeffs = np.polyfit(h_vals, log_ic, deg=1)
                decay_rate = -coeffs[0]  # -lambda from the linear fit

                # Half-life: t_{1/2} = ln(2) / lambda
                if decay_rate > 0:
                    half_life = np.log(2) / decay_rate
                elif decay_rate < 0:
                    # Negative decay rate means IC is increasing (unusual)
                    half_life = np.inf
                else:
                    half_life = np.nan

            except (ValueError, np.linalg.LinAlgError):
                # Fitting failed (e.g., all IC values identical)
                decay_rate = np.nan
                half_life = np.nan

        elif len(h_vals) >= 2:
            # Can't fit exponential if IC values are not all positive
            # Try fitting to absolute values
            try:
                abs_ic_vals = np.abs(ic_vals)
                if np.all(abs_ic_vals > 0):
                    log_abs_ic = np.log(abs_ic_vals)
                    coeffs = np.polyfit(h_vals, log_abs_ic, deg=1)
                    decay_rate = -coeffs[0]

                    half_life = np.log(2) / decay_rate if decay_rate > 0 else np.nan
            except (ValueError, np.linalg.LinAlgError):
                pass

    return {
        "ic_by_horizon": ic_by_horizon,
        "horizons": horizons,
        "decay_rate": float(decay_rate) if not np.isnan(decay_rate) else None,
        "half_life": float(half_life)
        if not np.isnan(half_life) and not np.isinf(half_life)
        else None,
        "optimal_horizon": optimal_horizon,
        "optimal_ic": optimal_ic if not np.isnan(optimal_ic) else None,
        "n_observations": n_obs_by_horizon,
    }
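
The log-linear fit used for the half-life can be reproduced on synthetic data. The sketch below generates IC values from a known decay rate and recovers it with an ordinary least-squares slope. It uses only the standard library instead of np.polyfit, and `fit_half_life` is a hypothetical name for this example.

```python
import math


def fit_half_life(horizons, ic_values):
    """Fit ln(IC) = a - lambda * h by least squares; return (lambda, half-life)."""
    log_ic = [math.log(v) for v in ic_values]  # requires strictly positive ICs
    h_mean = sum(horizons) / len(horizons)
    y_mean = sum(log_ic) / len(log_ic)
    slope = sum(
        (h - h_mean) * (y - y_mean) for h, y in zip(horizons, log_ic)
    ) / sum((h - h_mean) ** 2 for h in horizons)
    decay_rate = -slope
    half_life = math.log(2) / decay_rate if decay_rate > 0 else None
    return decay_rate, half_life


# Synthetic, noise-free IC curve with a known decay rate of 0.1 per day.
horizons = [1, 2, 5, 10, 21]
true_lambda = 0.1
ic_values = [0.05 * math.exp(-true_lambda * h) for h in horizons]

decay_rate, half_life = fit_half_life(horizons, ic_values)
print(decay_rate, half_life)
```

On noise-free data the fit recovers lambda = 0.1, which corresponds to a half-life of ln(2) / 0.1, roughly 6.9 days. With real IC estimates the points scatter around the line, and the half-life is only as reliable as the fit.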

compute_conditional_ic

compute_conditional_ic(
    feature_a,
    feature_b,
    forward_returns,
    date_col=None,
    group_col=None,
    n_quantiles=5,
    method="spearman",
    min_periods=10,
)

Compute IC of feature_a conditional on quantiles of feature_b.

This measures how feature_a's predictive power varies across different regimes defined by feature_b. Strong variation suggests feature interaction, which is critical for understanding when features work best.

This is a key ingredient for the Feature Interaction Tear Sheet, enabling analysis like: "Does momentum (feature_a) work better in high or low volatility (feature_b) regimes?"

Parameters

feature_a : DataFrame/Series/ndarray Feature to evaluate (IC will be computed for this). If a DataFrame with date_col/group_col, IC is computed per date. If a Series/array, it must align with feature_b and forward_returns
feature_b : DataFrame/Series/ndarray Conditioning feature (used to create quantile bins). Must match feature_a structure
forward_returns : DataFrame/Series/ndarray Forward returns to predict. Must match feature_a structure
date_col : str | None, default None Column name for dates (for panel data grouping). If specified, quantiles are computed cross-sectionally per date
group_col : str | None, default None Column name for groups/assets (for panel data)
n_quantiles : int, default 5 Number of quantile bins for feature_b
method : str, default "spearman" Correlation method: "spearman" or "pearson"
min_periods : int, default 10 Minimum observations per quantile for valid IC calculation

Returns

dict[str, Any]
    Dictionary with:
    - quantile_ics: IC of feature_a in each quantile of feature_b (array; NaN where a quantile has fewer than min_periods observations)
    - quantile_labels: Labels for each quantile (list of str)
    - quantile_bounds: Mean value of feature_b in each quantile (dict)
    - ic_variation: Std dev of ICs across quantiles (float, or None if fewer than two quantiles have a valid IC)
    - ic_range: Max - min IC (float, or None if fewer than two quantiles have a valid IC)
    - significance_pvalue: p-value from a t-distribution approximation (float, or None if it cannot be computed)
    - test_statistic: IC range divided by IC std dev, a heuristic rather than a Kruskal-Wallis H statistic (float, or None)
    - n_quantiles: Number of quantiles (int)
    - n_obs_per_quantile: Observations in each quantile (dict)
    - interpretation: Automated insight generation (str)

Examples

>>> import numpy as np
>>> import pandas as pd
>>>
>>> # Does momentum work better in high or low volatility?
>>> np.random.seed(42)
>>> n = 1000
>>> volatility = np.random.randn(n)
>>> momentum = np.random.randn(n)
>>> # Returns depend on momentum only when volatility is high
>>> noise = 0.1 * np.random.randn(n)
>>> returns = np.where(volatility > 0, momentum + noise, noise)
>>>
>>> result = compute_conditional_ic(momentum, volatility, returns)
>>> print(f"IC Range: {result['ic_range']:.3f}")
>>> print(f"P-value: {result['significance_pvalue']:.3f}")
>>> print(result['interpretation'])

The printed values depend on the random draw. By construction, the IC of momentum is near zero in the low-volatility quantiles and close to 1 in the high-volatility quantiles.

Notes

Use Cases:
- Regime-dependent feature effectiveness
- Feature interaction discovery
- Risk factor analysis (does alpha persist in different market conditions?)
- Conditional portfolio construction

Panel Data Handling: When date_col is specified, quantiles are computed WITHIN each cross-section (date) to avoid lookahead bias. This ensures quantile bins are time-consistent.
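The difference between pooled and per-date binning can be illustrated with pandas. This is a sketch under assumptions: the column names and the synthetic panel are invented for illustration, and `pd.qcut` stands in for whatever quantile assignment the library uses internally.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
dates = pd.date_range("2024-01-01", periods=3, freq="D")
panel = pd.DataFrame(
    {
        "date": dates.repeat(10),
        "asset": np.tile([f"A{i}" for i in range(10)], 3),
        # Volatility level drifts upward from one date to the next
        "volatility": rng.random(30) + np.repeat([0.0, 5.0, 10.0], 10),
    }
)

# Pooled bins: one set of edges over all dates, so early rows are ranked
# against values that were only observed later
panel["q_pooled"] = pd.qcut(panel["volatility"], 5, labels=False) + 1

# Cross-sectional bins: separate edges for each date
panel["q_by_date"] = (
    panel.groupby("date")["volatility"]
    .transform(lambda s: pd.qcut(s, 5, labels=False))
    .astype(int)
    + 1
)

print(panel.groupby("date")["q_pooled"].unique())
print(panel.groupby("date")["q_by_date"].nunique())
```

With pooled edges the first date never reaches the top bins, because its values are ranked against later, higher readings. With per-date edges every date spans all five bins.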

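Statistical Significance: The implementation shown below does not run a full Kruskal-Wallis test (non-parametric one-way ANOVA), because only one IC is available per quantile. It uses a placeholder heuristic instead: the test statistic is the IC range divided by the IC standard deviation, and the p-value comes from a two-sided t-distribution approximation with (number of valid quantiles - 1) degrees of freedom. Treat the p-value as indicative only; bootstrapping an IC distribution per quantile would give a more robust test.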
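Reading the source listing further down, the reported statistic and p-value are derived from the per-quantile ICs alone. The sketch below mirrors that arithmetic; `range_over_std_pvalue` is a hypothetical name, and the input ICs are made-up numbers used only to exercise the formula.

```python
import numpy as np
from scipy.stats import t


def range_over_std_pvalue(quantile_ics):
    """Heuristic statistic and p-value from one IC per quantile (hypothetical helper)."""
    ics = np.asarray(quantile_ics, dtype=float)
    ics = ics[~np.isnan(ics)]
    if len(ics) < 3:
        return float("nan"), float("nan")
    ic_range = ics.max() - ics.min()
    ic_std = ics.std()  # population std, as np.std computes by default
    stat = ic_range / (ic_std + 1e-10)
    # Two-sided p-value from a t-distribution with (n - 1) degrees of freedom
    p = 2 * (1 - t.cdf(abs(stat), len(ics) - 1))
    return float(stat), float(p)


ics = np.array([0.01, 0.02, 0.10, 0.24, 0.25])
stat, p = range_over_std_pvalue(ics)
stat_scaled, p_scaled = range_over_std_pvalue(10 * ics)
print(stat, p)
print(stat_scaled, p_scaled)
```

Because range and standard deviation scale together, multiplying every IC by a constant leaves the statistic essentially unchanged. The p-value therefore reflects the shape of the IC profile across quantiles, not the size of the IC differences or the number of observations behind them.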

Comparison to SHAP Interactions:
- Conditional IC: Fast, interpretable, requires no model, pairwise only
- SHAP interactions: Slow, model-specific, captures complex interactions

Use conditional IC for quick screening and SHAP for a deep dive on specific pairs.

References

This metric combines concepts from:
- Alphalens factor analysis (cross-sectional IC)
- Conditional independence testing
- Interaction effect analysis from experimental design

Source code in src/ml4t/diagnostic/metrics/conditional.py
def compute_conditional_ic(
    feature_a: Union[pl.DataFrame, pd.DataFrame, pl.Series, pd.Series, "NDArray[Any]"],
    feature_b: Union[pl.DataFrame, pd.DataFrame, pl.Series, pd.Series, "NDArray[Any]"],
    forward_returns: Union[pl.DataFrame, pd.DataFrame, pl.Series, pd.Series, "NDArray[Any]"],
    date_col: str | None = None,
    group_col: str | None = None,
    n_quantiles: int = 5,
    method: str = "spearman",
    min_periods: int = 10,
) -> dict[str, Any]:
    """Compute IC of feature_a conditional on quantiles of feature_b.

    This measures how feature_a's predictive power varies across different
    regimes defined by feature_b. Strong variation suggests feature interaction,
    which is critical for understanding when features work best.

    This is a key ingredient for the Feature Interaction Tear Sheet, enabling
    analysis like: "Does momentum (feature_a) work better in high or low
    volatility (feature_b) regimes?"

    Parameters
    ----------
    feature_a : DataFrame/Series/ndarray
        Feature to evaluate (IC will be computed for this)
        If DataFrame, the first column other than date_col/group_col is used
        If Series/array, must align with feature_b and forward_returns
    feature_b : DataFrame/Series/ndarray
        Conditioning feature (used to create quantile bins)
        Must match feature_a structure
    forward_returns : DataFrame/Series/ndarray
        Forward returns to predict
        Must match feature_a structure
    date_col : str | None, default None
        Column name for dates (for panel data grouping)
        If specified, quantiles computed cross-sectionally per date
    group_col : str | None, default None
        Column name for groups/assets (for panel data)
    n_quantiles : int, default 5
        Number of quantile bins for feature_b
    method : str, default "spearman"
        Correlation method: "spearman" or "pearson"
    min_periods : int, default 10
        Minimum observations per quantile for valid IC calculation

    Returns
    -------
    dict[str, Any]
        Dictionary with:
        - quantile_ics: IC of feature_a in each quantile of feature_b (array)
        - quantile_labels: Labels for each quantile (list of str)
        - quantile_bounds: Mean value of feature_b in each quantile (dict)
        - ic_variation: Std dev of ICs across quantiles (float)
        - ic_range: Max - min IC (float)
        - significance_pvalue: Statistical test p-value (float)
        - test_statistic: IC range divided by IC std dev, a heuristic (float)
        - n_quantiles: Number of quantiles (int)
        - n_obs_per_quantile: Observations in each quantile (dict)
        - interpretation: Automated insight generation (str)

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>>
    >>> # Does momentum work better in high or low volatility?
    >>> np.random.seed(42)
    >>> n = 1000
    >>> volatility = np.random.randn(n)
    >>> momentum = np.random.randn(n)
    >>> # Returns depend on momentum only when volatility is high
    >>> noise = 0.1 * np.random.randn(n)
    >>> returns = np.where(volatility > 0, momentum + noise, noise)
    >>>
    >>> result = compute_conditional_ic(momentum, volatility, returns)
    >>> print(f"IC Range: {result['ic_range']:.3f}")
    >>> print(f"P-value: {result['significance_pvalue']:.3f}")
    >>> print(result['interpretation'])
    >>> # Printed values depend on the random draw; IC is near zero in the
    >>> # low-volatility quantiles and close to 1 in the high-volatility ones.

    Notes
    -----
    **Use Cases**:
    - Regime-dependent feature effectiveness
    - Feature interaction discovery
    - Risk factor analysis (does alpha persist in different market conditions?)
    - Conditional portfolio construction

    **Panel Data Handling**:
    When date_col is specified, quantiles are computed WITHIN each cross-section
    (date) to avoid lookahead bias. This ensures quantile bins are time-consistent.

    **Statistical Significance**:
    The current implementation has only one IC per quantile, so it does not
    run a full Kruskal-Wallis test. It uses a placeholder heuristic: IC range
    divided by IC std dev, with a two-sided t-distribution p-value
    (df = valid quantiles - 1). Treat the p-value as indicative only.

    **Comparison to SHAP Interactions**:
    - Conditional IC: Fast, interpretable, requires no model, pairwise only
    - SHAP interactions: Slow, model-specific, captures complex interactions
    Use conditional IC for quick screening, SHAP for deep dive on specific pairs

    References
    ----------
    This metric combines concepts from:
    - Alphalens factor analysis (cross-sectional IC)
    - Conditional independence testing
    - Interaction effect analysis from experimental design
    """
    adapter = DataFrameAdapter()
    quantile_labels = [f"Q{i + 1}" for i in range(n_quantiles)]

    # Handle Series/array inputs
    if isinstance(feature_a, pl.Series | pd.Series | np.ndarray):
        if date_col is not None or group_col is not None:
            raise ValueError(
                "date_col and group_col require DataFrame inputs with those columns. "
                "For Series/array inputs, use None for both."
            )
        # Convert to arrays
        feat_a_arr = adapter.to_numpy(feature_a).flatten()
        feat_b_arr = adapter.to_numpy(feature_b).flatten()
        ret_arr = adapter.to_numpy(forward_returns).flatten()

        # Validate lengths
        if not (len(feat_a_arr) == len(feat_b_arr) == len(ret_arr)):
            raise ValueError(
                f"All inputs must have same length. Got: feature_a={len(feat_a_arr)}, "
                f"feature_b={len(feat_b_arr)}, forward_returns={len(ret_arr)}"
            )

        # Remove NaN rows
        valid_mask = ~(np.isnan(feat_a_arr) | np.isnan(feat_b_arr) | np.isnan(ret_arr))
        feat_a_clean = feat_a_arr[valid_mask]
        feat_b_clean = feat_b_arr[valid_mask]
        ret_clean = ret_arr[valid_mask]

        if len(feat_a_clean) < min_periods * n_quantiles:
            return _empty_conditional_ic_result(
                n_quantiles, "Insufficient data for conditional IC analysis"
            )

        quantile_ids = _assign_quantile_labels(feat_b_clean, n_quantiles)
        if np.all(quantile_ids == -1):
            return _empty_conditional_ic_result(
                n_quantiles,
                "not enough unique values for requested quantiles",
                cannot_compute=True,
            )

        # Compute IC for each quantile
        ic_by_quantile: list[float] = []
        quantile_bounds: dict[Any, float] = {}
        n_obs_per_quantile: dict[Any, int] = {}
        ic_series_list: list[float] = []

        for i, q_label in enumerate(quantile_labels, start=1):
            mask = quantile_ids == i
            n_obs = int(np.sum(mask))
            if n_obs < min_periods:
                ic_by_quantile.append(np.nan)
                quantile_bounds[q_label] = np.nan
                n_obs_per_quantile[q_label] = n_obs
                continue

            # Compute IC for this quantile (confidence_intervals=False returns float)
            ic_result = pooled_ic(feat_a_clean[mask], ret_clean[mask], method=method)
            # When confidence_intervals=False, returns float; otherwise dict
            if isinstance(ic_result, dict):
                ic_val = float(ic_result.get("ic", np.nan))
            else:
                ic_val = float(ic_result)
            ic_by_quantile.append(ic_val)
            quantile_bounds[q_label] = float(np.mean(feat_b_clean[mask]))
            n_obs_per_quantile[q_label] = n_obs

            # Store individual IC values for statistical test
            # (approximation: use bootstrap or treat IC as single observation)
            ic_series_list.append(ic_val)

    else:
        # DataFrame input with Polars-first internal path
        if isinstance(feature_a, pl.DataFrame):
            df_a = feature_a.clone()
        elif isinstance(feature_a, pd.DataFrame):
            df_a = pl.from_pandas(feature_a)
        else:
            raise TypeError(f"feature_a must be DataFrame in this branch, got {type(feature_a)}")

        if isinstance(feature_b, pl.DataFrame):
            df_b = feature_b.clone()
        elif isinstance(feature_b, pd.DataFrame):
            df_b = pl.from_pandas(feature_b)
        else:
            raise TypeError(f"feature_b must be DataFrame in this branch, got {type(feature_b)}")

        if isinstance(forward_returns, pl.DataFrame):
            df_ret = forward_returns.clone()
        elif isinstance(forward_returns, pd.DataFrame):
            df_ret = pl.from_pandas(forward_returns)
        else:
            raise TypeError(
                f"forward_returns must be DataFrame in this branch, got {type(forward_returns)}"
            )

        # Validate structure
        if date_col is not None and date_col not in df_a.columns:
            raise ValueError(f"date_col '{date_col}' not found in feature_a DataFrame")
        if group_col is not None and group_col not in df_a.columns:
            raise ValueError(f"group_col '{group_col}' not found in feature_a DataFrame")

        # Infer feature column names (assume single value column after date/group)
        meta_cols = [c for c in [date_col, group_col] if c is not None]
        feat_a_col = [c for c in df_a.columns if c not in meta_cols][0]
        feat_b_col = [c for c in df_b.columns if c not in meta_cols][0]
        ret_col = [c for c in df_ret.columns if c not in meta_cols][0]

        # Assemble aligned arrays (same row order as current behavior)
        feat_a_arr = np.asarray(df_a[feat_a_col].to_numpy(), dtype=np.float64)
        feat_b_arr = np.asarray(df_b[feat_b_col].to_numpy(), dtype=np.float64)
        ret_arr = np.asarray(df_ret[ret_col].to_numpy(), dtype=np.float64)

        valid_mask = ~(np.isnan(feat_a_arr) | np.isnan(feat_b_arr) | np.isnan(ret_arr))
        feat_a_clean = feat_a_arr[valid_mask]
        feat_b_clean = feat_b_arr[valid_mask]
        ret_clean = ret_arr[valid_mask]

        if len(feat_a_clean) < min_periods * n_quantiles:
            return _empty_conditional_ic_result(
                n_quantiles, "Insufficient data for conditional IC analysis"
            )
        if len(feat_a_clean) == 0:
            return _empty_conditional_ic_result(n_quantiles, "No valid quantiles after filtering")

        if date_col is not None:
            date_arr = np.asarray(df_a[date_col].to_numpy())[valid_mask]
            quantile_ids = np.full(len(feat_b_clean), -1, dtype=np.int16)

            # Cross-sectional quantiles per date group.
            for date_value in np.unique(date_arr):
                group_mask = date_arr == date_value
                group_ids = _assign_quantile_labels(feat_b_clean[group_mask], n_quantiles)
                quantile_ids[group_mask] = group_ids
        else:
            quantile_ids = _assign_quantile_labels(feat_b_clean, n_quantiles)
            if np.all(quantile_ids == -1):
                return _empty_conditional_ic_result(
                    n_quantiles,
                    "not enough unique values for requested quantiles",
                    cannot_compute=True,
                )

        valid_quantile_mask = quantile_ids > 0
        if not np.any(valid_quantile_mask):
            return _empty_conditional_ic_result(n_quantiles, "No valid quantiles after filtering")

        feat_a_quant = feat_a_clean[valid_quantile_mask]
        feat_b_quant = feat_b_clean[valid_quantile_mask]
        ret_quant = ret_clean[valid_quantile_mask]
        quantile_ids = quantile_ids[valid_quantile_mask]

        ic_by_quantile = []
        quantile_bounds = {}
        n_obs_per_quantile = {}
        ic_series_list = []

        for i, q_label in enumerate(quantile_labels, start=1):
            mask = quantile_ids == i
            n_obs = int(np.sum(mask))
            if n_obs < min_periods:
                ic_by_quantile.append(np.nan)
                quantile_bounds[q_label] = np.nan
                n_obs_per_quantile[q_label] = n_obs
                continue

            ic_result = pooled_ic(feat_a_quant[mask], ret_quant[mask], method=method)
            if isinstance(ic_result, dict):
                ic_val = float(ic_result.get("ic", np.nan))
            else:
                ic_val = float(ic_result)
            ic_by_quantile.append(ic_val)
            quantile_bounds[q_label] = float(np.mean(feat_b_quant[mask]))
            n_obs_per_quantile[q_label] = n_obs
            ic_series_list.append(ic_val)

    # Convert to arrays
    ic_array = np.array(ic_by_quantile)

    # Remove NaN ICs for statistics
    valid_ics = ic_array[~np.isnan(ic_array)]

    if len(valid_ics) < 2:
        ic_variation = None
        ic_range = None
        test_statistic = None
        pvalue = None
        interpretation = "Insufficient valid quantiles for interaction analysis"
    else:
        # Compute variation metrics
        ic_variation = float(np.std(valid_ics))
        ic_range = float(np.max(valid_ics) - np.min(valid_ics))

        # Significance heuristic (a placeholder, not a full Kruskal-Wallis test).
        # Only one IC is available per quantile, so a rank test across groups
        # cannot be run. Bootstrapping IC distributions or computing an IC time
        # series per quantile would give a more robust test.
        if len(valid_ics) >= 3:
            try:
                # Heuristic statistic: IC range scaled by IC standard deviation
                test_statistic = ic_range / (ic_variation + 1e-10)
                # Two-sided p-value from a t-distribution approximation
                from scipy.stats import t

                df_test = len(valid_ics) - 1
                pvalue = 2 * (1 - t.cdf(abs(test_statistic), df_test))
            except Exception:
                test_statistic = np.nan
                pvalue = np.nan
        else:
            test_statistic = np.nan
            pvalue = np.nan

        # Generate interpretation
        if np.isnan(pvalue):
            interpretation = (
                f"IC varies across quantiles: range={ic_range:.3f}, std={ic_variation:.3f}. "
                "Statistical significance could not be determined."
            )
        elif ic_range > 0.1 and pvalue < 0.05:
            ic_min = float(np.min(valid_ics))
            ic_max = float(np.max(valid_ics))
            interpretation = (
                f"Strong interaction detected: IC ranges from {ic_min:.3f} to {ic_max:.3f} "
                f"across feature_b quantiles (p={pvalue:.3f}). "
                "Feature A's predictive power is highly regime-dependent."
            )
        elif ic_range > 0.05 and pvalue < 0.05:
            interpretation = (
                f"Moderate interaction detected: IC range={ic_range:.3f} (p={pvalue:.3f}). "
                "Feature A's effectiveness varies across feature_b regimes."
            )
        elif pvalue < 0.05:
            interpretation = (
                f"Weak but significant interaction detected (p={pvalue:.3f}). "
                "Some regime-dependence in feature A's predictive power."
            )
        else:
            interpretation = (
                f"No significant interaction detected (p={pvalue:.3f}). "
                "Feature A's predictive power is consistent across feature_b quantiles."
            )

    return {
        "quantile_ics": ic_array,
        "quantile_labels": quantile_labels,
        "quantile_bounds": quantile_bounds,
        "ic_variation": float(ic_variation)
        if ic_variation is not None and not np.isnan(ic_variation)
        else None,
        "ic_range": float(ic_range) if ic_range is not None and not np.isnan(ic_range) else None,
        "significance_pvalue": float(pvalue)
        if pvalue is not None and not np.isnan(pvalue)
        else None,
        "test_statistic": float(test_statistic)
        if test_statistic is not None and not np.isnan(test_statistic)
        else None,
        "n_quantiles": n_quantiles,
        "n_obs_per_quantile": n_obs_per_quantile,
        "interpretation": interpretation,
    }

compute_monotonicity

compute_monotonicity(
    features,
    outcomes,
    n_quantiles=5,
    feature_col=None,
    outcome_col=None,
    method="spearman",
)

Test monotonic relationship between feature values and outcomes.

Monotonicity is a key property for predictive features: we expect higher (or lower) feature values to consistently correspond to higher outcomes. Non-monotonic relationships often indicate:
1. Feature needs transformation (e.g., absolute value, log)
2. Feature has regime-dependent behavior
3. Feature is not truly predictive

This function bins features into quantiles and checks if mean outcomes increase/decrease monotonically across bins.
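The binning step can be sketched as follows. It follows the same `np.percentile` and `np.digitize` pattern that appears in the source listing below, but `quantile_means_sketch` is a hypothetical helper and it assumes NaN-free inputs with no empty bins.

```python
import numpy as np


def quantile_means_sketch(feature, outcome, n_quantiles=5):
    """Mean outcome per feature quantile (hypothetical helper, NaN-free inputs)."""
    edges = np.percentile(feature, np.linspace(0, 100, n_quantiles + 1))
    # Use interior edges only, so bin ids run 0 .. n_quantiles - 1
    bins = np.digitize(feature, edges[1:-1])
    return [float(outcome[bins == q].mean()) for q in range(n_quantiles)]


rng = np.random.default_rng(7)
x = rng.standard_normal(2000)
y = x + 0.5 * rng.standard_normal(2000)  # outcome rises with the feature

means = quantile_means_sketch(x, y)
print([round(m, 2) for m in means])
```

For a feature with a clean positive relationship to the outcome, the per-quantile means rise from Q1 to Q5.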

Parameters

features : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Feature values to test
outcomes : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Outcome values (typically returns)
n_quantiles : int, default 5
    Number of quantile bins (5 = quintiles, 10 = deciles)
feature_col : str | None, default None
    Column name for features (if DataFrame)
outcome_col : str | None, default None
    Column name for outcomes (if DataFrame)
method : str, default "spearman"
    Correlation method: "spearman" or "pearson"

Returns

dict[str, Any]
    Dictionary with monotonicity analysis:
    - correlation: Spearman/Pearson correlation
    - p_value: Statistical significance of correlation
    - quantile_means: Mean outcome per quantile
    - quantile_labels: Quantile labels (Q1, Q2, ...)
    - is_monotonic: Boolean, True if every adjacent quantile pair is ordered in the direction of the correlation sign (monotonicity_score == 1.0)
    - monotonicity_score: Fraction of adjacent quantile pairs ordered in that direction (0-1)
    - direction: "increasing", "decreasing", "mostly_increasing", "mostly_decreasing", "non_monotonic", or "insufficient_data"
    - n_observations: Total observations after dropping NaN rows
    - n_per_quantile: Observations per quantile

Examples

>>> # Test if momentum predicts returns
>>> features = df['momentum']
>>> outcomes = df['forward_return']
>>> result = compute_monotonicity(features, outcomes, n_quantiles=5)
>>>
>>> print(f"Correlation: {result['correlation']:.3f}")
>>> print(f"P-value: {result['p_value']:.4f}")
>>> print(f"Monotonic: {result['is_monotonic']}")
>>> print(f"Direction: {result['direction']}")
>>> print(f"Quantile means: {result['quantile_means']}")
Correlation: 0.156
P-value: 0.0001
Monotonic: True
Direction: increasing
Quantile means: [-0.002, 0.001, 0.003, 0.005, 0.008]

Notes

Monotonicity Score:
- 1.0: Perfect monotonicity (all adjacent quantiles ordered correctly)
- 0.8-1.0: Strong monotonicity (minor violations)
- 0.6-0.8: Moderate monotonicity
- <0.6: Weak or no monotonicity
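A small worked example of the score, assuming NaN-free quantile means. `monotonicity_score_sketch` and its `expect_increasing` flag are hypothetical; in the source listing the expected direction is taken from the sign of the correlation instead.

```python
def monotonicity_score_sketch(quantile_means, expect_increasing=True):
    """Fraction of adjacent quantile pairs ordered in the expected direction."""
    pairs = list(zip(quantile_means[:-1], quantile_means[1:]))
    if not pairs:
        return 0.0
    ordered = sum((b > a) if expect_increasing else (b < a) for a, b in pairs)
    return ordered / len(pairs)


perfect = monotonicity_score_sketch([-0.002, 0.001, 0.003, 0.005, 0.008])
one_violation = monotonicity_score_sketch([-0.002, 0.003, 0.001, 0.005, 0.008])
print(perfect)        # 1.0
print(one_violation)  # 0.75
```

With five quantiles there are only four adjacent pairs, so the score can only take the values 0, 0.25, 0.5, 0.75, and 1.0. A single violation already gives 0.75, which falls below the 0.8 band; the 0.8-1.0 band needs at least six quantiles to be reachable.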

Common Patterns:
- Monotonic increasing: Good positive predictor
- Monotonic decreasing: Good negative predictor (consider sign flip)
- U-shaped: Consider absolute value or squared feature
- Flat: Feature not predictive
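The U-shaped case can be illustrated on synthetic data. This is a sketch only: the data-generating process is invented for the example, and it calls `scipy.stats.spearmanr` directly instead of the library.

```python
import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(3)
x = rng.standard_normal(5000)
# Outcome depends on the magnitude of the feature, not its sign
y = x**2 + 0.5 * rng.standard_normal(5000)

rho_raw, _ = spearmanr(x, y)
rho_abs, _ = spearmanr(np.abs(x), y)
print(f"raw feature:    {rho_raw:+.3f}")
print(f"absolute value: {rho_abs:+.3f}")
```

The raw feature shows almost no rank correlation with the outcome even though the relationship is strong, while the absolute-value transform recovers it. That is the pattern the "U-shaped" entry points to.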

References

[1] Kakushadze, Z., & Serur, J. A. (2018). "151 Trading Strategies."

Source code in src/ml4t/diagnostic/metrics/monotonicity.py
def compute_monotonicity(
    features: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    outcomes: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    n_quantiles: int = 5,
    feature_col: str | None = None,
    outcome_col: str | None = None,
    method: str = "spearman",
) -> dict[str, Any]:
    """Test monotonic relationship between feature values and outcomes.

    Monotonicity is a key property for predictive features - we expect higher
    (or lower) feature values to consistently correspond to higher outcomes.
    Non-monotonic relationships often indicate:
    1. Feature needs transformation (e.g., absolute value, log)
    2. Feature has regime-dependent behavior
    3. Feature is not truly predictive

    This function bins features into quantiles and checks if mean outcomes
    increase/decrease monotonically across bins.

    Parameters
    ----------
    features : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature values to test
    outcomes : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Outcome values (typically returns)
    n_quantiles : int, default 5
        Number of quantile bins (5 = quintiles, 10 = deciles)
    feature_col : str | None, default None
        Column name for features (if DataFrame)
    outcome_col : str | None, default None
        Column name for outcomes (if DataFrame)
    method : str, default "spearman"
        Correlation method: "spearman" or "pearson"

    Returns
    -------
    dict[str, Any]
        Dictionary with monotonicity analysis:
        - correlation: Spearman/Pearson correlation
        - p_value: Statistical significance of correlation
        - quantile_means: Mean outcome per quantile
        - quantile_labels: Quantile labels (Q1, Q2, ...)
        - is_monotonic: Boolean, True if strictly monotonic
        - monotonicity_score: Fraction of quantile pairs that are monotonic (0-1)
        - direction: "increasing", "decreasing", "mostly_increasing",
          "mostly_decreasing", "non_monotonic", or "insufficient_data"
        - n_observations: Total observations
        - n_per_quantile: Observations per quantile

    Examples
    --------
    >>> # Test if momentum predicts returns
    >>> features = df['momentum']
    >>> outcomes = df['forward_return']
    >>> result = compute_monotonicity(features, outcomes, n_quantiles=5)
    >>>
    >>> print(f"Correlation: {result['correlation']:.3f}")
    >>> print(f"P-value: {result['p_value']:.4f}")
    >>> print(f"Monotonic: {result['is_monotonic']}")
    >>> print(f"Direction: {result['direction']}")
    >>> print(f"Quantile means: {result['quantile_means']}")
    Correlation: 0.156
    P-value: 0.0001
    Monotonic: True
    Direction: increasing
    Quantile means: [-0.002, 0.001, 0.003, 0.005, 0.008]

    Notes
    -----
    Monotonicity Score:
    - 1.0: Perfect monotonicity (all adjacent quantiles ordered correctly)
    - 0.8-1.0: Strong monotonicity (minor violations)
    - 0.6-0.8: Moderate monotonicity
    - <0.6: Weak or no monotonicity

    Common Patterns:
    - Monotonic increasing: Good positive predictor
    - Monotonic decreasing: Good negative predictor (consider sign flip)
    - U-shaped: Consider absolute value or squared feature
    - Flat: Feature not predictive

    References
    ----------
    .. [1] Kakushadze, Z., & Serur, J. A. (2018). "151 Trading Strategies."
    """
    # Extract feature and outcome arrays
    feature_vals: NDArray[Any]
    if isinstance(features, pl.DataFrame) or isinstance(features, pd.DataFrame):
        if feature_col is None:
            raise ValueError("feature_col must be specified for DataFrame input")
        feature_vals = features[feature_col].to_numpy()
    else:
        feature_vals = np.asarray(features).flatten()

    outcome_vals: NDArray[Any]
    if isinstance(outcomes, pl.DataFrame) or isinstance(outcomes, pd.DataFrame):
        if outcome_col is None:
            raise ValueError("outcome_col must be specified for DataFrame input")
        outcome_vals = outcomes[outcome_col].to_numpy()
    else:
        outcome_vals = np.asarray(outcomes).flatten()

    # Validate inputs
    if len(feature_vals) != len(outcome_vals):
        raise ValueError(
            f"Features ({len(feature_vals)}) and outcomes ({len(outcome_vals)}) must have same length"
        )

    # Remove NaN values
    valid_mask = ~(np.isnan(feature_vals.astype(float)) | np.isnan(outcome_vals.astype(float)))
    feature_clean = feature_vals[valid_mask]
    outcome_clean = outcome_vals[valid_mask]

    n = len(feature_clean)
    if n < n_quantiles * 2:
        # Insufficient data for quantile analysis
        return {
            "correlation": np.nan,
            "p_value": np.nan,
            "quantile_means": [],
            "quantile_labels": [],
            "is_monotonic": False,
            "monotonicity_score": 0.0,
            "direction": "insufficient_data",
            "n_observations": n,
            "n_per_quantile": [],
        }

    # Compute correlation
    if method == "spearman":
        correlation, p_value = spearmanr(feature_clean, outcome_clean)
    elif method == "pearson":
        correlation, p_value = stats.pearsonr(feature_clean, outcome_clean)
    else:
        raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'.")

    # Create quantile bins
    quantile_edges = np.linspace(0, 100, n_quantiles + 1)
    quantile_bins = np.percentile(feature_clean, quantile_edges)

    # Assign observations to quantiles
    quantile_assignments = np.digitize(feature_clean, quantile_bins[1:-1])  # 0-indexed bins

    # Compute mean outcome per quantile
    quantile_means = []
    n_per_quantile = []

    for q in range(n_quantiles):
        mask = quantile_assignments == q
        if np.sum(mask) > 0:
            quantile_means.append(float(np.mean(outcome_clean[mask])))
            n_per_quantile.append(int(np.sum(mask)))
        else:
            quantile_means.append(np.nan)
            n_per_quantile.append(0)

    # Check monotonicity
    # Count how many adjacent pairs are ordered correctly
    monotonic_pairs = 0
    total_pairs = 0

    for i in range(len(quantile_means) - 1):
        if not (np.isnan(quantile_means[i]) or np.isnan(quantile_means[i + 1])):
            total_pairs += 1
            # Check if ordered (either increasing or decreasing)
            if correlation > 0:
                # Expect increasing
                if quantile_means[i + 1] > quantile_means[i]:
                    monotonic_pairs += 1
            # Expect decreasing
            elif quantile_means[i + 1] < quantile_means[i]:
                monotonic_pairs += 1

    monotonicity_score = monotonic_pairs / total_pairs if total_pairs > 0 else 0.0

    # Strict monotonicity check (all pairs ordered correctly)
    is_monotonic = monotonicity_score == 1.0

    # Determine direction
    if is_monotonic:
        direction = "increasing" if correlation > 0 else "decreasing"
    elif monotonicity_score >= 0.8:
        direction = "mostly_" + ("increasing" if correlation > 0 else "decreasing")
    else:
        direction = "non_monotonic"

    # Create quantile labels
    quantile_labels = [f"Q{i + 1}" for i in range(n_quantiles)]

    return {
        "correlation": float(correlation),
        "p_value": float(p_value),
        "quantile_means": quantile_means,
        "quantile_labels": quantile_labels,
        "is_monotonic": is_monotonic,
        "monotonicity_score": float(monotonicity_score),
        "direction": direction,
        "n_observations": n,
        "n_per_quantile": n_per_quantile,
    }

compute_mdi_importance

compute_mdi_importance(
    model, feature_names=None, normalize=True
)

Compute Mean Decrease in Impurity (MDI) feature importance from tree-based models.

MDI measures how much each feature contributes to decreasing the weighted impurity (Gini for classification, MSE/MAE for regression) across all trees. This is computed during model training and is available via the model's feature_importances_ attribute.

Supported Models:
- LightGBM: lightgbm.LGBMClassifier, lightgbm.LGBMRegressor (recommended)
- XGBoost: xgboost.XGBClassifier, xgboost.XGBRegressor (recommended)
- sklearn: RandomForestClassifier, RandomForestRegressor (not recommended - slow)
- sklearn: GradientBoostingClassifier, GradientBoostingRegressor (not recommended - slow)

Not supported:
- sklearn's HistGradientBoosting* (doesn't expose feature_importances_)

Parameters

model : Any
    Fitted tree-based model with feature_importances_ attribute.
    Must be one of: LightGBM, XGBoost, or sklearn tree ensembles.
feature_names : list[str] | None, default None
    Feature names for labeling. If None, uses feature names from model or generates numeric names.
normalize : bool, default True
    If True, ensures importances sum to 1.0 (some models already normalize).

Returns

dict[str, Any]
    Dictionary with MDI importance results:
    - importances: Feature importance values (sorted descending)
    - feature_names: Feature labels (sorted by importance)
    - n_features: Number of features
    - normalized: Whether values sum to 1.0
    - model_type: Type of model used
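The normalization and sorting described above can be sketched in plain NumPy. The raw values and feature names here are invented for illustration (for instance, some models report raw split counts that do not sum to 1), and this is not the library's code.

```python
import numpy as np

# Hypothetical raw importances, e.g. split counts that do not sum to 1
raw = np.array([12.0, 0.0, 30.0, 18.0])
names = ["f0", "f1", "f2", "f3"]

# Normalize so the importances sum to 1.0 (guarding against an all-zero vector)
total = raw.sum()
importances = raw / total if total > 0 else raw

# Sort descending and reorder the labels to match
order = np.argsort(importances)[::-1]
sorted_importances = importances[order]
sorted_names = [names[i] for i in order]

print(sorted_names)        # ['f2', 'f3', 'f0', 'f1']
print(sorted_importances)  # [0.5 0.3 0.2 0. ]
```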

Raises

AttributeError
    If model doesn't have feature_importances_ attribute
ImportError
    If LightGBM/XGBoost not installed and trying to use those models

Examples

>>> import lightgbm as lgb
>>> from sklearn.datasets import make_classification
>>>
>>> # Train LightGBM model
>>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
>>> model = lgb.LGBMClassifier(n_estimators=100, random_state=42)
>>> model.fit(X, y)
>>>
>>> # Extract MDI importance
>>> mdi = compute_mdi_importance(
...     model=model,
...     feature_names=[f'feature_{i}' for i in range(10)]
... )
>>>
>>> # Results are sorted descending by importance
>>> print(mdi['feature_names'])  # doctest: +SKIP
['feature_3', 'feature_0', ...]
>>> print(mdi['model_type'])
lightgbm.LGBMClassifier

Notes

MDI vs PFI (Permutation Feature Importance):

MDI Advantages:
- Very fast: Computed during training (no additional overhead)
- No additional data required
- Deterministic: Same result every time

MDI Disadvantages:
- Biased toward high-cardinality features: Features with many unique values get inflated importance even if not truly predictive
- Only for tree-based models: Not model-agnostic
- Train set importance: May not reflect test set predictive power
- Correlated features: Can split importance between correlated predictors

When to use MDI:
- Quick exploratory analysis
- When computational budget is limited
- When working with tree-based models exclusively

When to use PFI instead:
- Need unbiased importance estimates
- Have high-cardinality categorical features
- Want model-agnostic importance
- Need to validate importance on test set

Comparison workflow:

>>> # Compare MDI and PFI
>>> mdi = compute_mdi_importance(model, feature_names=features)
>>> pfi = compute_permutation_importance(model, X_test, y_test, feature_names=features)
>>>
>>> # Large discrepancies may indicate:
>>> # - High-cardinality bias in MDI
>>> # - Correlated features splitting importance
>>> # - Overfitting (high MDI, low PFI)
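One way to look for such discrepancies is to compare where each feature lands in the two sorted lists. The sketch below assumes result dictionaries shaped like the documented return values; rank_shift and the sample data are hypothetical.

```python
def rank_shift(mdi: dict, pfi: dict) -> dict[str, int]:
    """Return PFI rank minus MDI rank per feature (0 means same position)."""
    mdi_rank = {name: r for r, name in enumerate(mdi["feature_names"])}
    pfi_rank = {name: r for r, name in enumerate(pfi["feature_names"])}
    return {name: pfi_rank[name] - mdi_rank[name] for name in mdi_rank}


# Hypothetical results: both lists are already sorted by descending importance.
mdi = {"feature_names": ["id_hash", "momentum", "volume"],
       "importances": [0.6, 0.3, 0.1]}
pfi = {"feature_names": ["momentum", "volume", "id_hash"],
       "importances_mean": [0.08, 0.02, 0.0]}

shifts = rank_shift(mdi, pfi)
print(shifts)  # {'id_hash': 2, 'momentum': -1, 'volume': -1}
```

A large positive shift, as for id_hash here, marks a feature that MDI ranks highly but PFI does not, which is the pattern the comments above associate with high-cardinality bias or overfitting.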

Performance notes:

  • LightGBM and XGBoost: Production-ready speed and accuracy (RECOMMENDED)
  • sklearn RandomForest/GradientBoosting: 10-100x slower, avoid for large datasets
  • sklearn HistGradientBoosting: Fast but doesn't expose feature_importances_ (use PFI instead)

References
  • Breiman, L. (2001). "Random Forests". Machine Learning.
  • Louppe, G. et al. (2013). "Understanding variable importances in forests of randomized trees". NeurIPS.
  • Strobl, C. et al. (2007). "Bias in random forest variable importance measures". BMC Bioinformatics.
Source code in src/ml4t/diagnostic/metrics/importance_classical.py
def compute_mdi_importance(
    model: Any,
    feature_names: list[str] | None = None,
    normalize: bool = True,
) -> dict[str, Any]:
    """Compute Mean Decrease in Impurity (MDI) feature importance from tree-based models.

    MDI measures how much each feature contributes to decreasing the weighted
    impurity (Gini for classification, MSE/MAE for regression) across all trees.
    This is computed during model training and is available via the model's
    `feature_importances_` attribute.

    **Supported Models**:
    - LightGBM: `lightgbm.LGBMClassifier`, `lightgbm.LGBMRegressor` (recommended)
    - XGBoost: `xgboost.XGBClassifier`, `xgboost.XGBRegressor` (recommended)
    - sklearn: `RandomForestClassifier`, `RandomForestRegressor` (not recommended - slow)
    - sklearn: `GradientBoostingClassifier`, `GradientBoostingRegressor` (not recommended - slow)

    **Not supported**:
    - sklearn's HistGradientBoosting* (doesn't expose feature_importances_)

    Parameters
    ----------
    model : Any
        Fitted tree-based model with `feature_importances_` attribute.
        Must be one of: LightGBM, XGBoost, or sklearn tree ensembles.
    feature_names : list[str] | None, default None
        Feature names for labeling. If None, uses feature names from model
        or generates numeric names.
    normalize : bool, default True
        If True, ensures importances sum to 1.0 (some models already normalize).

    Returns
    -------
    dict[str, Any]
        Dictionary with MDI importance results:
        - importances: Feature importance values (sorted descending)
        - feature_names: Feature labels (sorted by importance)
        - n_features: Number of features
        - normalized: Whether normalization was requested (echoes `normalize`;
          an all-zero importance vector is left unchanged)
        - model_type: Type of model used

    Raises
    ------
    AttributeError
        If model doesn't have `feature_importances_` attribute
    ValueError
        If the number of feature names does not match the number of importances
    ImportError
        If LightGBM/XGBoost not installed and trying to use those models

    Examples
    --------
    >>> import lightgbm as lgb
    >>> from sklearn.datasets import make_classification
    >>>
    >>> # Train LightGBM model
    >>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
    >>> model = lgb.LGBMClassifier(n_estimators=100, random_state=42)
    >>> model.fit(X, y)
    >>>
    >>> # Extract MDI importance
    >>> mdi = compute_mdi_importance(
    ...     model=model,
    ...     feature_names=[f'feature_{i}' for i in range(10)]
    ... )
    >>>
    >>> # Results are sorted descending by importance
    >>> print(mdi['feature_names'])  # doctest: +SKIP
    ['feature_3', 'feature_0', ...]
    >>> print(mdi['model_type'])
    lightgbm.sklearn.LGBMClassifier

    Notes
    -----
    **MDI vs PFI** (Permutation Feature Importance):

    **MDI Advantages**:
    - Very fast: Computed during training (no additional overhead)
    - No additional data required
    - Deterministic: Same result every time

    **MDI Disadvantages**:
    - **Biased toward high-cardinality features**: Features with many unique values
      get inflated importance even if not truly predictive
    - **Only for tree-based models**: Not model-agnostic
    - **Train set importance**: May not reflect test set predictive power
    - **Correlated features**: Can split importance between correlated predictors

    **When to use MDI**:
    - Quick exploratory analysis
    - When computational budget is limited
    - When working with tree-based models exclusively

    **When to use PFI instead**:
    - Need unbiased importance estimates
    - Have high-cardinality categorical features
    - Want model-agnostic importance
    - Need to validate importance on test set

    **Comparison workflow**:
    >>> # Compare MDI and PFI
    >>> mdi = compute_mdi_importance(model, feature_names=features)
    >>> pfi = compute_permutation_importance(model, X_test, y_test, feature_names=features)
    >>>
    >>> # Large discrepancies may indicate:
    >>> # - High-cardinality bias in MDI
    >>> # - Correlated features splitting importance
    >>> # - Overfitting (high MDI, low PFI)

    **Performance notes**:
    - LightGBM and XGBoost: Production-ready speed and accuracy (RECOMMENDED)
    - sklearn RandomForest/GradientBoosting: 10-100x slower, avoid for large datasets
    - sklearn HistGradientBoosting: Fast but doesn't expose feature_importances_ (use PFI instead)

    References
    ----------
    - Breiman, L. (2001). "Random Forests". Machine Learning.
    - Louppe, G. et al. (2013). "Understanding variable importances in forests of
      randomized trees". NeurIPS.
    - Strobl, C. et al. (2007). "Bias in random forest variable importance measures".
      BMC Bioinformatics.
    """
    # Check if model has feature_importances_
    if not hasattr(model, "feature_importances_"):
        raise AttributeError(
            f"Model of type {type(model).__name__} does not have 'feature_importances_' attribute. "
            "MDI is only available for tree-based models (LightGBM, XGBoost, sklearn tree ensembles)."
        )

    # Extract raw importances
    importances = model.feature_importances_

    # Get feature names
    if feature_names is None:
        # Try to get from model
        if hasattr(model, "feature_name_"):
            # LightGBM
            feature_names = model.feature_name_
        elif hasattr(model, "get_booster") and hasattr(model.get_booster(), "feature_names"):
            # XGBoost
            feature_names = model.get_booster().feature_names
            if feature_names is None and hasattr(model, "feature_names_in_"):
                feature_names = list(model.feature_names_in_)
        elif hasattr(model, "feature_names_in_"):
            # sklearn
            feature_names = list(model.feature_names_in_)
        if feature_names is None:
            # Fallback to numeric names
            feature_names = [f"feature_{i}" for i in range(len(importances))]
    else:
        feature_names = list(feature_names)

    # Validate length match
    if len(feature_names) != len(importances):
        raise ValueError(
            f"Number of feature names ({len(feature_names)}) does not match number of importances ({len(importances)})"
        )

    # Normalize if requested
    if normalize:
        importance_sum = importances.sum()
        if importance_sum > 0:
            importances = importances / importance_sum
        else:
            # All zeros - already normalized
            pass

    # Sort by importance (descending)
    sorted_idx = np.argsort(importances)[::-1]

    # Determine model type
    model_type = f"{type(model).__module__}.{type(model).__name__}"

    return {
        "importances": importances[sorted_idx],
        "feature_names": [feature_names[i] for i in sorted_idx],
        "n_features": len(feature_names),
        "normalized": normalize,
        "model_type": model_type,
    }

compute_permutation_importance

compute_permutation_importance(
    model,
    X,
    y,
    feature_names=None,
    scoring=None,
    n_repeats=10,
    random_state=42,
    n_jobs=None,
)

Compute Permutation Feature Importance (PFI) for model-agnostic feature ranking.

Permutation Feature Importance measures the increase in model error when a feature's values are randomly shuffled. Features with high importance cause large performance drops when permuted, indicating they are critical for the model's predictions.

This is a model-agnostic method that works with any fitted estimator. Unlike model-specific importance measures (e.g., tree-based feature importances), it is not biased toward high-cardinality features.
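The shuffle-and-rescore idea can be reproduced by hand. The following is a minimal NumPy sketch under stated assumptions: predict is a stand-in for a fitted model that uses the true coefficients, r2 is a hand-written score, and the data are synthetic. It is not the function documented here.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = 2.0 * X[:, 0] + 0.5 * X[:, 1]            # column 2 is pure noise


def predict(X):
    # Stand-in for a fitted model (assumption): applies the true coefficients.
    return 2.0 * X[:, 0] + 0.5 * X[:, 1]


def r2(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


baseline = r2(y, predict(X))
n_repeats = 10
drops = np.zeros((X.shape[1], n_repeats))    # shape (n_features, n_repeats)
for j in range(X.shape[1]):
    for r in range(n_repeats):
        X_perm = X.copy()
        X_perm[:, j] = rng.permutation(X_perm[:, j])   # break the link to y
        drops[j, r] = baseline - r2(y, predict(X_perm))

importances_mean = drops.mean(axis=1)
print(importances_mean)
```

Column 0 carries the largest coefficient, so shuffling it costs the most; column 2 is never read by predict, so its importance is exactly zero.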

Parameters

model : Any
    Fitted sklearn-compatible estimator (must have predict or predict_proba)
X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Feature matrix (n_samples, n_features)
y : Union[pl.Series, pd.Series, np.ndarray]
    Target values (n_samples,)
feature_names : list[str] | None, default None
    Feature names for labeling. If None, uses column names from DataFrame or generates numeric names for arrays
scoring : str | Callable | None, default None
    Scoring function to evaluate model performance. If None, uses model's default score method. Common options:

  • Classification: 'accuracy', 'roc_auc', 'f1'
  • Regression: 'r2', 'neg_mean_squared_error', 'neg_mean_absolute_error'

n_repeats : int, default 10
    Number of times to permute each feature (more repeats = more stable estimates)
random_state : int | None, default 42
    Random seed for reproducibility
n_jobs : int | None, default None
    Number of parallel jobs (-1 for all CPUs)

Returns

dict[str, Any]
    Dictionary with permutation importance results (per-feature entries are sorted by descending mean importance):

  • importances_mean: Mean importance per feature
  • importances_std: Standard deviation of importance per feature
  • importances_raw: All permutation results (n_features, n_repeats)
  • feature_names: Feature labels
  • baseline_score: Model score before permutation
  • n_repeats: Number of permutation rounds
  • scoring: Scoring function used ("default" when scoring is None)
  • n_features: Number of features
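The raw matrix makes it possible to screen features by how far their mean importance sits above the repeat-to-repeat noise. The sketch below assumes a result shaped like the dictionary above; the numbers and the two-standard-deviation threshold are illustrative choices, not library defaults.

```python
import numpy as np

# Hypothetical result with 3 features and 4 repeats.
pfi = {
    "feature_names": ["momentum", "volume", "noise"],
    "importances_raw": np.array([
        [0.10, 0.12, 0.11, 0.09],
        [0.02, 0.03, 0.02, 0.03],
        [0.01, -0.01, 0.00, -0.02],
    ]),
}

raw = pfi["importances_raw"]
mean = raw.mean(axis=1)
std = raw.std(axis=1)          # population standard deviation (NumPy default)

# Keep features whose mean importance exceeds two standard deviations.
keep = [name for name, m, s in zip(pfi["feature_names"], mean, std) if m - 2 * s > 0]
print(keep)  # ['momentum', 'volume']
```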

Examples

>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.datasets import make_classification
>>> from ml4t.diagnostic.api import compute_permutation_importance
>>>
>>> # Train a simple model
>>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
>>> model = RandomForestClassifier(n_estimators=10, random_state=42)
>>> model.fit(X, y)
>>>
>>> # Compute permutation importance
>>> pfi = compute_permutation_importance(
...     model=model,
...     X=X,
...     y=y,
...     n_repeats=10,
...     scoring='accuracy'
... )
>>>
>>> # Results are sorted descending by importance
>>> print(pfi['baseline_score'])
0.92
>>> print(pfi['feature_names'])  # doctest: +SKIP
['feature_0', 'feature_3', ...]

Notes

Interpretation:

  • Importance = 0: Feature not useful
  • Importance > 0: Feature contributes to predictions
  • Importance < 0: Feature hurts performance (may indicate overfitting)
  • Higher importance = More critical feature

Advantages over MDI (Mean Decrease in Impurity):

  • Model-agnostic: Works with any estimator
  • Unbiased: Not inflated by high-cardinality features
  • Realistic: Measures actual predictive power, not just tree splits

Computational Cost:

  • Time complexity: O(n_features * n_repeats * prediction_time)
  • Can be slow for large datasets or complex models
  • Use n_jobs=-1 for parallel computation

Best Practices:

  • Use hold-out validation set (not training data) for unbiased estimates
  • Increase n_repeats (20-30) for more stable results
  • Check for negative importances (may indicate model instability)
  • Compare with other importance methods (SHAP, MDI) for robustness

References

  • Breiman, L. (2001). "Random Forests". Machine Learning, 45(1), 5-32.

Source code in src/ml4t/diagnostic/metrics/importance_classical.py
def compute_permutation_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    scoring: str | Callable | None = None,
    n_repeats: int = 10,
    random_state: int | None = 42,
    n_jobs: int | None = None,
) -> dict[str, Any]:
    """Compute Permutation Feature Importance (PFI) for model-agnostic feature ranking.

    Permutation Feature Importance measures the increase in model error when a
    feature's values are randomly shuffled. Features with high importance cause
    large performance drops when permuted, indicating they are critical for
    the model's predictions.

    This is a model-agnostic method that works with any fitted estimator.
    Unlike model-specific importance measures (e.g., tree-based feature
    importances), it is not biased toward high-cardinality features.

    Parameters
    ----------
    model : Any
        Fitted sklearn-compatible estimator (must have `predict` or `predict_proba`)
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,)
    feature_names : list[str] | None, default None
        Feature names for labeling. If None, uses column names from DataFrame
        or generates numeric names for arrays
    scoring : str | Callable | None, default None
        Scoring function to evaluate model performance. If None, uses model's
        default score method. Common options:
        - Classification: 'accuracy', 'roc_auc', 'f1'
        - Regression: 'r2', 'neg_mean_squared_error', 'neg_mean_absolute_error'
    n_repeats : int, default 10
        Number of times to permute each feature (more repeats = more stable estimates)
    random_state : int | None, default 42
        Random seed for reproducibility
    n_jobs : int | None, default None
        Number of parallel jobs (-1 for all CPUs)

    Returns
    -------
    dict[str, Any]
        Dictionary with permutation importance results:
        - importances_mean: Mean importance per feature
        - importances_std: Standard deviation of importance per feature
        - importances_raw: All permutation results (n_features, n_repeats)
        - feature_names: Feature labels
        - baseline_score: Model score before permutation
        - n_repeats: Number of permutation rounds
        - scoring: Scoring function used ("default" when `scoring` is None)
        - n_features: Number of features

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import make_classification
    >>>
    >>> # Train a simple model
    >>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
    >>> model = RandomForestClassifier(n_estimators=10, random_state=42)
    >>> model.fit(X, y)
    >>>
    >>> # Compute permutation importance
    >>> pfi = compute_permutation_importance(
    ...     model=model,
    ...     X=X,
    ...     y=y,
    ...     n_repeats=10,
    ...     scoring='accuracy'
    ... )
    >>>
    >>> # Results are sorted descending by importance
    >>> print(pfi['baseline_score'])
    0.92
    >>> print(pfi['feature_names'])  # doctest: +SKIP
    ['feature_0', 'feature_3', ...]

    Notes
    -----
    **Interpretation**:
    - Importance = 0: Feature not useful
    - Importance > 0: Feature contributes to predictions
    - Importance < 0: Feature hurts performance (may indicate overfitting)
    - Higher importance = More critical feature

    **Advantages over MDI** (Mean Decrease in Impurity):
    - Model-agnostic: Works with any estimator
    - Unbiased: Not inflated by high-cardinality features
    - Realistic: Measures actual predictive power, not just tree splits

    **Computational Cost**:
    - Time complexity: O(n_features * n_repeats * prediction_time)
    - Can be slow for large datasets or complex models
    - Use n_jobs=-1 for parallel computation

    **Best Practices**:
    - Use hold-out validation set (not training data) for unbiased estimates
    - Increase n_repeats (20-30) for more stable results
    - Check for negative importances (may indicate model instability)
    - Compare with other importance methods (SHAP, MDI) for robustness

    References
    ----------
    .. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
    """
    from sklearn.inspection import permutation_importance as sklearn_pfi

    # Convert inputs to numpy arrays
    X_array: NDArray[Any]
    if isinstance(X, pl.DataFrame):
        if feature_names is None:
            feature_names = X.columns
        X_array = X.to_numpy()
    elif isinstance(X, pd.DataFrame):
        if feature_names is None:
            feature_names = X.columns.tolist()
        X_array = X.to_numpy()
    else:
        X_array = np.asarray(X)
        if feature_names is None:
            feature_names = [f"feature_{i}" for i in range(X_array.shape[1])]

    # Type assertion: feature_names is guaranteed to be set at this point
    assert feature_names is not None, "feature_names should be set by this point"

    y_array: NDArray[Any]
    if isinstance(y, (pl.Series, pd.Series)):
        y_array = y.to_numpy()
    else:
        y_array = np.asarray(y)

    # Compute baseline score
    if scoring is None:
        baseline_score = model.score(X_array, y_array)
    else:
        from sklearn.metrics import get_scorer

        scorer = get_scorer(scoring) if isinstance(scoring, str) else scoring
        baseline_score = scorer(model, X_array, y_array)

    # Compute permutation importance using sklearn
    result = sklearn_pfi(
        estimator=model,
        X=X_array,
        y=y_array,
        scoring=scoring,
        n_repeats=n_repeats,
        random_state=random_state,
        n_jobs=n_jobs,
    )

    # Extract and format results
    importances_mean = result.importances_mean
    importances_std = result.importances_std
    importances_raw = result.importances  # Shape: (n_features, n_repeats)

    # Sort by importance (descending)
    sorted_idx = np.argsort(importances_mean)[::-1]

    return {
        "importances_mean": importances_mean[sorted_idx],
        "importances_std": importances_std[sorted_idx],
        "importances_raw": importances_raw[sorted_idx],
        "feature_names": [feature_names[i] for i in sorted_idx],
        "baseline_score": float(baseline_score),
        "n_repeats": n_repeats,
        "scoring": scoring if scoring is not None else "default",
        "n_features": len(feature_names),
    }

compute_shap_importance

compute_shap_importance(
    model,
    X,
    feature_names=None,
    check_additivity=True,
    max_samples=None,
    explainer_type="auto",
    use_gpu="auto",
    background_data=None,
    explainer_kwargs=None,
    show_progress=False,
    performance_warning=True,
)

Compute SHAP (SHapley Additive exPlanations) values and aggregate to feature importance.

SHAP values provide a unified measure of feature importance based on Shapley values from cooperative game theory. Each feature's contribution to a prediction is calculated by considering all possible feature coalitions, satisfying key properties like additivity and consistency.
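For a handful of features, the coalition-based definition can be evaluated exactly by brute force. The sketch below is an illustration of Shapley values, not of any SHAP explainer: the toy model is made up, and "removing" a feature is simplified to replacing it with a single baseline value, whereas SHAP explainers typically average over a background dataset.

```python
from itertools import combinations
from math import factorial


def model(x):
    # Toy model with one additive term and one interaction term.
    return 3.0 * x[0] + 2.0 * x[1] * x[2]


def value(subset, x, baseline):
    # Features outside the coalition take their baseline value (simplifying assumption).
    z = [x[i] if i in subset else baseline[i] for i in range(len(x))]
    return model(z)


def shapley_values(x, baseline):
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            for S in combinations(others, k):
                weight = factorial(k) * factorial(n - k - 1) / factorial(n)
                gain = value(set(S) | {i}, x, baseline) - value(set(S), x, baseline)
                phi[i] += weight * gain
    return phi


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(x, baseline)
print([round(p, 6) for p in phi])  # [3.0, 6.0, 6.0]
```

The attributions sum to model(x) - model(baseline) = 15, which is the additivity property, and the interaction term's contribution of 12 is split evenly between the two features that produce it.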

Key advantages over MDI and PFI:

  • Theoretically sound: Based on game theory (Shapley values)
  • Consistent: If a model changes so that a feature's marginal contribution increases or stays the same, its attribution does not decrease
  • Local explanations: Provides per-prediction feature contributions
  • Interaction-aware: Accounts for feature interactions naturally
  • Unbiased: No bias toward high-cardinality features (unlike MDI)
  • Model-agnostic: Works with ANY sklearn-compatible model (v1.1+)

Multi-Explainer Support:

This function automatically selects the best SHAP explainer for your model:

  • TreeExplainer: Fast, exact computation for tree-based models
  • LinearExplainer: Fast, exact computation for linear models
  • KernelExplainer: Model-agnostic fallback (slower but universal)
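  • DeepExplainer: Optimized for neural networks (TensorFlow/PyTorch); the documented 'auto' cascade covers Tree, Linear, and Kernel only, so request it with explainer_type='deep'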
Parameters

model : Any
    Fitted model compatible with SHAP explainers.
X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Feature matrix for SHAP computation (typically test/validation set). Shape: (n_samples, n_features)
feature_names : list[str] | None, default None
    Feature names for labeling. If None, uses column names from DataFrame or generates numeric names for arrays
check_additivity : bool, default True
    Verify that SHAP values sum to model predictions (sanity check). Only supported by TreeExplainer. Disable for speed if you trust the implementation.
max_samples : int | None, default None
    Maximum number of samples to compute SHAP values for. If X has more rows, a random subset is drawn without replacement using a fixed seed (42).
explainer_type : str, default 'auto'
    SHAP explainer to use:

  • 'auto': Automatic selection (Tree -> Linear -> Kernel cascade)
  • 'tree': Force TreeExplainer
  • 'linear': Force LinearExplainer
  • 'kernel': Force KernelExplainer
  • 'deep': Force DeepExplainer

use_gpu : Union[bool, str], default 'auto'
    Enable GPU acceleration for SHAP computation
background_data : np.ndarray | None, default None
    Background dataset for KernelExplainer
explainer_kwargs : dict | None, default None
    Additional keyword arguments passed to the explainer constructor
show_progress : bool, default False
    Show progress bar for SHAP computation (requires tqdm)
performance_warning : bool, default True
    Issue warning if computation will take >10 seconds
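The effect of max_samples can be sketched as follows. The sampling calls mirror the source listing for this function (a seeded generator and a draw without replacement); the small matrix is made up for illustration.

```python
import numpy as np

X = np.arange(20, dtype=float).reshape(10, 2)    # 10 samples, 2 features
max_samples = 4

rng = np.random.default_rng(42)                   # fixed seed, so the subset is reproducible
idx = rng.choice(X.shape[0], size=max_samples, replace=False)
X_sub = X[idx]

print(X_sub.shape)  # (4, 2)
```

The selected rows are not returned in their original order, so row k of the resulting SHAP matrix corresponds to row idx[k] of X rather than row k.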

Returns

dict[str, Any]
    Dictionary with SHAP importance results:

  • shap_values: SHAP values array, shape (n_samples, n_features)
  • importances: Mean absolute SHAP values per feature (sorted descending)
  • feature_names: Feature labels (sorted by importance)
  • base_value: Expected model output (average prediction)
  • n_features: Number of features
  • n_samples: Number of samples used for SHAP computation
  • model_type: Type of model used
  • explainer_type: Which explainer was used
  • additivity_verified: Echoes check_additivity; the check itself runs only with TreeExplainer
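The aggregation from per-sample SHAP values to one score per feature can be sketched in a few lines. The helper name to_importances and the sample values are hypothetical; the 3D handling follows the branch in the source listing for arrays shaped (n_samples, n_features, n_classes).

```python
import numpy as np


def to_importances(shap_values: np.ndarray) -> np.ndarray:
    """Collapse SHAP values to one non-negative score per feature."""
    if shap_values.ndim == 3:
        if shap_values.shape[2] == 2:
            shap_values = shap_values[:, :, 1]                  # binary: positive class
        else:
            shap_values = np.mean(np.abs(shap_values), axis=2)  # multiclass: average over classes
    return np.mean(np.abs(shap_values), axis=0)


sv = np.array([[0.2, -0.1],
               [-0.4, 0.1],
               [0.3, 0.1]])
imp = to_importances(sv)
print(np.round(imp, 6).tolist())  # [0.3, 0.1]
```

Taking absolute values before averaging keeps positive and negative contributions from cancelling, so a feature that pushes predictions strongly in both directions still ranks as important.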

Raises

ImportError
    If shap library not installed
ValueError
    If model is not supported by specified explainer, if X is not 2D, or if the number of feature names does not match the columns of X
RuntimeError
    If SHAP computation fails

Source code in src/ml4t/diagnostic/metrics/importance_shap.py
def compute_shap_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    check_additivity: bool = True,
    max_samples: int | None = None,
    explainer_type: str = "auto",
    use_gpu: bool | str = "auto",
    background_data: Union["NDArray[Any]", None] = None,
    explainer_kwargs: dict | None = None,
    show_progress: bool = False,
    performance_warning: bool = True,
) -> dict[str, Any]:
    """Compute SHAP (SHapley Additive exPlanations) values and aggregate to feature importance.

    SHAP values provide a unified measure of feature importance based on Shapley values
    from cooperative game theory. Each feature's contribution to a prediction is
    calculated by considering all possible feature coalitions, satisfying key
    properties like additivity and consistency.

    **Key advantages over MDI and PFI**:

    - **Theoretically sound**: Based on game theory (Shapley values)
    - **Consistent**: If a model changes so that a feature's marginal contribution
      increases or stays the same, its attribution does not decrease
    - **Local explanations**: Provides per-prediction feature contributions
    - **Interaction-aware**: Accounts for feature interactions naturally
    - **Unbiased**: No bias toward high-cardinality features (unlike MDI)
    - **Model-agnostic**: Works with ANY sklearn-compatible model (v1.1+)

    **Multi-Explainer Support**:

    This function automatically selects the best SHAP explainer for your model:

    - **TreeExplainer**: Fast, exact computation for tree-based models
    - **LinearExplainer**: Fast, exact computation for linear models
    - **KernelExplainer**: Model-agnostic fallback (slower but universal)
    - **DeepExplainer**: Optimized for neural networks (TensorFlow/PyTorch)

    Parameters
    ----------
    model : Any
        Fitted model compatible with SHAP explainers.
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix for SHAP computation (typically test/validation set)
        Shape: (n_samples, n_features)
    feature_names : list[str] | None, default None
        Feature names for labeling. If None, uses column names from DataFrame
        or generates numeric names for arrays
    check_additivity : bool, default True
        Verify that SHAP values sum to model predictions (sanity check).
        Only supported by TreeExplainer. Disable for speed if you trust the
        implementation.
    max_samples : int | None, default None
        Maximum number of samples to compute SHAP values for. If X has more
        rows, a random subset is drawn without replacement using a fixed seed (42).
    explainer_type : str, default 'auto'
        SHAP explainer to use:
        - 'auto': Automatic selection (Tree -> Linear -> Kernel cascade)
        - 'tree': Force TreeExplainer
        - 'linear': Force LinearExplainer
        - 'kernel': Force KernelExplainer
        - 'deep': Force DeepExplainer
    use_gpu : Union[bool, str], default 'auto'
        Enable GPU acceleration for SHAP computation
    background_data : np.ndarray | None, default None
        Background dataset for KernelExplainer
    explainer_kwargs : dict | None, default None
        Additional keyword arguments passed to the explainer constructor
    show_progress : bool, default False
        Show progress bar for SHAP computation (requires tqdm)
    performance_warning : bool, default True
        Issue warning if computation will take >10 seconds

    Returns
    -------
    dict[str, Any]
        Dictionary with SHAP importance results:
        - shap_values: SHAP values array, shape (n_samples, n_features)
        - importances: Mean absolute SHAP values per feature (sorted descending)
        - feature_names: Feature labels (sorted by importance)
        - base_value: Expected model output (average prediction)
        - n_features: Number of features
        - n_samples: Number of samples used for SHAP computation
        - model_type: Type of model used
        - explainer_type: Which explainer was used
        - additivity_verified: Echoes `check_additivity`; the check itself runs
          only with TreeExplainer

    Raises
    ------
    ImportError
        If shap library not installed
    ValueError
        If model is not supported by specified explainer, if X is not 2D,
        or if the number of feature names does not match the columns of X
    RuntimeError
        If SHAP computation fails
    """
    # Check if shap is installed
    try:
        import shap  # noqa: F401 (availability check)
    except ImportError as e:
        raise ImportError(
            "SHAP library is not installed. Install with: pip install ml4t-diagnostic[ml] or: pip install shap>=0.43.0"
        ) from e

    # Convert X to appropriate format
    if isinstance(X, pl.DataFrame):
        X_array = X.to_numpy()
        if feature_names is None:
            feature_names = X.columns
    elif isinstance(X, pd.DataFrame):
        X_array = X.values
        if feature_names is None:
            feature_names = list(X.columns)
    else:
        X_array = np.asarray(X)

    # Validate shape before accessing shape[1]
    if X_array.ndim != 2:
        raise ValueError(f"X must be 2D array, got shape {X_array.shape}")

    # Set default feature names if needed (after shape validation)
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(X_array.shape[1])]

    # Ensure feature_names is a list
    if feature_names is not None:
        feature_names = list(feature_names)

    n_samples_full, n_features = X_array.shape

    # Subsample if requested
    if max_samples is not None and n_samples_full > max_samples:
        # Use random sampling for representative subset
        rng = np.random.default_rng(42)
        sample_idx = rng.choice(n_samples_full, size=max_samples, replace=False)
        X_array = X_array[sample_idx]
        n_samples = max_samples
    else:
        n_samples = n_samples_full

    # Validate feature names length
    if len(feature_names) != n_features:
        raise ValueError(
            f"Number of feature names ({len(feature_names)}) does not match number of features in X ({n_features})"
        )

    # Get appropriate explainer (auto-selects or uses explicit type)
    if explainer_kwargs is None:
        explainer_kwargs = {}

    explainer, explainer_type_used, ms_per_sample = _get_explainer(
        model=model,
        X_array=X_array,
        explainer_type=explainer_type,
        use_gpu=use_gpu,
        background_data=background_data,
        **explainer_kwargs,
    )

    # Issue performance warning if needed
    _estimate_computation_time(
        explainer_type=explainer_type_used,
        n_samples=n_samples,
        ms_per_sample=ms_per_sample,
        performance_warning=performance_warning,
    )

    # Compute SHAP values with optional progress bar
    try:
        # Only TreeExplainer supports check_additivity parameter
        shap_kwargs = {}
        if explainer_type_used == "tree":
            shap_kwargs["check_additivity"] = check_additivity

        if show_progress:
            try:
                from tqdm.auto import tqdm

                # Wrap computation with progress bar for slow explainers
                if explainer_type_used == "kernel":
                    # For kernel, show progress
                    with tqdm(total=n_samples, desc="Computing SHAP values") as pbar:
                        shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
                        pbar.update(n_samples)
                else:
                    # For tree/linear/deep, just compute (fast enough)
                    shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
            except ImportError:
                # tqdm not available, compute without progress bar
                shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
        else:
            shap_values_raw = explainer.shap_values(X_array, **shap_kwargs)
    except Exception as e:
        raise RuntimeError(
            f"Failed to compute SHAP values with {explainer_type_used}Explainer. "
            f"Model type: {type(model).__name__}. Error: {e}"
        ) from e

    # Handle binary classification (returns list of arrays OR 3D array)
    if isinstance(shap_values_raw, list):
        if len(shap_values_raw) == 2:
            # Binary classification (older SHAP versions)
            shap_values = shap_values_raw[1]
        else:
            # Multiclass - use first class for importance
            shap_values = shap_values_raw[0]
    else:
        shap_values = shap_values_raw
        # Handle 3D array for binary/multiclass (newer SHAP versions)
        if shap_values.ndim == 3:
            if shap_values.shape[2] == 2:
                # Binary classification: take positive class (index 1)
                shap_values = shap_values[:, :, 1]
            else:
                # Multiclass: aggregate across classes (mean absolute)
                shap_values = np.mean(np.abs(shap_values), axis=2)

    # Validate SHAP values shape
    if shap_values.shape != (n_samples, n_features):
        raise RuntimeError(
            f"Unexpected SHAP values shape: {shap_values.shape}, expected ({n_samples}, {n_features})"
        )

    # Compute feature importance as mean absolute SHAP value
    importances = np.mean(np.abs(shap_values), axis=0)

    # Sort by importance (descending)
    sorted_idx = np.argsort(importances)[::-1]

    # Get base value (expected value)
    base_value = explainer.expected_value
    if isinstance(base_value, list | np.ndarray):
        # For binary/multiclass, take positive class or first class
        base_value = base_value[1] if len(base_value) == 2 else base_value[0]

    # Determine model type
    model_type = f"{type(model).__module__}.{type(model).__name__}"

    return {
        "shap_values": shap_values,
        "importances": importances[sorted_idx],
        "feature_names": [feature_names[i] for i in sorted_idx],
        "base_value": float(base_value),
        "n_features": n_features,
        "n_samples": n_samples,
        "model_type": model_type,
        "explainer_type": explainer_type_used,
        "additivity_verified": check_additivity,
    }
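
The class-handling and aggregation steps in the listing above can be tried on a synthetic array without installing SHAP. This is a minimal sketch, not the library's code: `raw` is a random stand-in for explainer output in the newer 3D layout, and `reduce_to_2d` is a hypothetical helper name introduced only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_classes = 6, 4, 3

# Synthetic stand-in for SHAP output shaped (n_samples, n_features, n_classes)
raw = rng.normal(size=(n_samples, n_features, n_classes))


def reduce_to_2d(values: np.ndarray) -> np.ndarray:
    """Collapse a 3D SHAP array to (n_samples, n_features)."""
    if values.ndim == 3:
        if values.shape[2] == 2:
            # Binary classification: keep the positive class (index 1)
            return values[:, :, 1]
        # Multiclass: mean absolute value across classes
        return np.mean(np.abs(values), axis=2)
    return values


shap_2d = reduce_to_2d(raw)

# Global importance: mean absolute SHAP value per feature, sorted descending
importances = np.mean(np.abs(shap_2d), axis=0)
order = np.argsort(importances)[::-1]
ranked = importances[order]
```

Note that the multiclass branch returns non-negative values, so the per-sample signs are lost in that case; only the binary branch keeps signed SHAP values.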

compute_mda_importance

compute_mda_importance(
    model,
    X,
    y,
    feature_names=None,
    feature_groups=None,
    removal_method="mean",
    scoring=None,
    _n_jobs=None,
)

Compute Mean Decrease in Accuracy (MDA) by feature removal.

MDA measures the drop in model performance when features are removed or neutralized. Unlike Permutation Feature Importance (PFI) which shuffles feature values, MDA replaces feature values with a constant (mean, median, or zero), simulating complete feature unavailability.
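
The difference between the two perturbations can be seen on a single column. The following is a small sketch on synthetic data (the array `X` and the column index are assumptions chosen for illustration): constant replacement collapses the column to one value, while permutation keeps the same values in a different row order.

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(loc=5.0, scale=2.0, size=(200, 3))
col = 1

# MDA-style removal: every row receives the same constant (here the mean)
X_removed = X.copy()
X_removed[:, col] = X[:, col].mean()

# PFI-style permutation: same values, shuffled across rows
X_permuted = X.copy()
X_permuted[:, col] = rng.permutation(X[:, col])

# Spread of the perturbed column under each scheme
removed_range = np.ptp(X_removed[:, col])
permuted_std = np.std(X_permuted[:, col])
original_std = np.std(X[:, col])
```

Here `removed_range` is zero because the column is constant, whereas `permuted_std` matches `original_std` because the marginal distribution is unchanged.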

This approach naturally supports feature groups (e.g., one-hot encoded categoricals, related features like lat/lon) by removing multiple features simultaneously and measuring the joint importance.

Supported Models:

  • Any fitted sklearn-compatible estimator with a score() or predict() method
  • Classification: LogisticRegression, RandomForest, XGBoost, LightGBM, etc.
  • Regression: LinearRegression, Ridge, GradientBoosting, etc.

Parameters

model : Any
    Fitted sklearn-compatible estimator (must have a score() or predict() method)
X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Feature matrix (n_samples, n_features)
y : Union[pl.Series, pd.Series, np.ndarray]
    Target values (n_samples,)
feature_names : list[str] | None, default None
    Feature names for labeling. If None, uses column names from the DataFrame or generates names of the form feature_0, feature_1, ... for arrays
feature_groups : dict[str, list[str]] | None, default None
    Dictionary mapping group names to lists of feature names. When provided, computes importance for feature groups instead of individual features. Example: {"location": ["lat", "lon"], "time": ["hour", "day", "month"]}
removal_method : str, default "mean"
    How to neutralize features:
    - "mean": Replace with feature mean (recommended for continuous features)
    - "median": Replace with feature median (robust to outliers)
    - "zero": Replace with zero (can distort if zero is out-of-distribution)
scoring : str | Callable | None, default None
    Scoring function used to evaluate model performance. If None, uses the model's default score method. A string is resolved with sklearn.metrics.get_scorer; a callable is invoked as scoring(model, X, y). Common options:
    - Classification: 'accuracy', 'roc_auc', 'f1'
    - Regression: 'r2', 'neg_mean_squared_error', 'neg_mean_absolute_error'
_n_jobs : int | None, default None
    Reserved for parallel scoring (-1 for all CPUs). The signature names this parameter _n_jobs, and the implementation shown below does not read it, so scoring runs serially.
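
When `scoring` is a callable, the source listing further down invokes it as `scorer(model, X, y)`, which is the scikit-learn scorer convention rather than the `metric(y_true, y_pred)` convention. The sketch below shows a callable of that shape; `MeanOfFeatures` and `neg_mae_scorer` are hypothetical names used only for illustration and are not part of the library.

```python
import numpy as np


class MeanOfFeatures:
    """Hypothetical 'fitted' model: predicts the row mean of the features."""

    def predict(self, X):
        return np.asarray(X).mean(axis=1)


def neg_mae_scorer(model, X, y):
    """Scorer with the (estimator, X, y) signature; higher is better."""
    return -float(np.mean(np.abs(model.predict(X) - y)))


X = np.array([[1.0, 3.0], [2.0, 4.0], [0.0, 2.0]])
y = np.array([2.0, 3.0, 1.0])

perfect_score = neg_mae_scorer(MeanOfFeatures(), X, y)
worse_score = neg_mae_scorer(MeanOfFeatures(), X, y + 1.0)
```

Because importance is computed as baseline score minus removed score, the callable should return larger values for better fits, which is why the error is negated here.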

Returns

dict[str, Any]
    Dictionary with MDA importance results:
    - importances: Performance drop per feature/group (sorted descending)
    - feature_names: Feature/group labels (sorted by importance)
    - baseline_score: Model score before feature removal
    - removal_method: Method used to neutralize features
    - scoring: Scoring function used
    - n_features: Number of features/groups evaluated

Raises

ValueError
    If removal_method is not one of: "mean", "median", "zero"
ValueError
    If feature_groups contains unknown feature names
ValueError
    If X and y have different numbers of samples

Examples

>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.datasets import make_classification
>>> import numpy as np
>>>
>>> # Train a simple model
>>> X, y = make_classification(n_samples=1000, n_features=10, n_informative=3, random_state=42)
>>> model = RandomForestClassifier(n_estimators=50, random_state=42)
>>> model.fit(X, y)
>>>
>>> # Compute MDA importance
>>> mda = compute_mda_importance(
...     model=model,
...     X=X,
...     y=y,
...     removal_method='mean',
...     scoring='accuracy'
... )
>>>
>>> # Examine results
>>> print(f"Baseline score: {mda['baseline_score']:.3f}")
>>> print(f"Most important feature: {next(iter(mda['feature_names']))}")
>>> print(f"Importance (accuracy drop): {next(iter(mda['importances'])):.3f}")
Baseline score: 0.920
Most important feature: feature_3
Importance (accuracy drop): 0.124

Feature Groups Example:

>>> # Group related features (e.g., one-hot encoded categorical)
>>> feature_groups = {
...     "category_A": ["feature_0", "feature_1", "feature_2"],
...     "category_B": ["feature_3", "feature_4"],
...     "numeric": ["feature_5", "feature_6", "feature_7"]
... }
>>>
>>> mda_groups = compute_mda_importance(
...     model=model,
...     X=X,
...     y=y,
...     feature_groups=feature_groups,
...     removal_method='mean'
... )
>>>
>>> # See which group is most important
>>> print(f"Most important group: {next(iter(mda_groups['feature_names']))}")
>>> print(f"Group importance: {next(iter(mda_groups['importances'])):.3f}")

Notes

MDA vs PFI (Permutation Feature Importance):

MDA Characteristics:

  • Removes feature completely (sets to constant)
  • Simulates true feature unavailability
  • May show larger importance drops than PFI
  • Naturally supports feature groups
  • Similar computational cost to PFI

PFI Characteristics:

  • Shuffles feature values (breaks feature-target relationship)
  • Preserves feature distribution
  • May show smaller importance drops
  • Requires additional logic for feature groups
  • More commonly used in literature

When to use MDA:

  • Want to simulate complete feature removal
  • Need to evaluate feature groups jointly
  • Want more conservative importance estimates
  • Comparing "with feature" vs "without feature" scenarios

When to use PFI instead:

  • Want to match published baselines (PFI more common)
  • Need to preserve feature distributions
  • Want less conservative importance estimates

Feature Groups: Feature groups are useful for:

  • One-hot encoded categoricals (remove all dummy variables together)
  • Related features (lat/lon, year/month/day)
  • Multi-dimensional embeddings
  • Polynomial features of same base feature

Removing feature groups jointly captures their combined importance, which can be higher than the sum of individual importances due to interactions between features in the group.
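
A toy case makes the gap between joint and individual importance concrete. The sketch below assumes a deliberately simple setup that is not taken from the library: two redundant copies of one signal and a fixed hypothetical model that predicts 1 when the feature sum is positive. Neutralizing either copy alone barely changes the predictions, while neutralizing both removes the signal entirely.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
X = np.column_stack([x, x])  # two redundant copies of one signal
y = (x > 0).astype(int)


def accuracy(X_eval: np.ndarray) -> float:
    """Hypothetical fixed model: predict 1 when the feature sum is positive."""
    predictions = (X_eval.sum(axis=1) > 0).astype(int)
    return float(np.mean(predictions == y))


def drop_when_removed(cols: list[int]) -> float:
    """Baseline accuracy minus accuracy with the given columns set to their mean."""
    X_removed = X.copy()
    for c in cols:
        X_removed[:, c] = X[:, c].mean()
    return accuracy(X) - accuracy(X_removed)


drop_0 = drop_when_removed([0])
drop_1 = drop_when_removed([1])
drop_joint = drop_when_removed([0, 1])
```

In this setup `drop_joint` is far larger than `drop_0 + drop_1`, which is the situation where evaluating a group as a unit gives a more faithful picture than summing per-feature scores.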

Removal Methods:

  • mean: Most common choice for continuous features. Replaces the feature with its mean over the rows of X passed to the function (not a stored training-set statistic), which keeps the column at a typical, in-range value.

  • median: More robust to outliers than mean. Useful for features with skewed distributions or outliers.

  • zero: Simple but can be problematic if zero is out-of-distribution for a feature (e.g., if feature is always positive). Use with caution.
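
To see how the three choices differ in practice, the sketch below computes each replacement value for one skewed, strictly positive feature. The log-normal sample is an assumption made for illustration; the point is only that mean and median diverge under skew and that zero can fall outside the observed range.

```python
import numpy as np

rng = np.random.default_rng(1)
# Skewed, strictly positive feature (synthetic stand-in)
feature = rng.lognormal(mean=0.0, sigma=1.0, size=5000)

replacements = {
    "mean": float(np.mean(feature)),
    "median": float(np.median(feature)),
    "zero": 0.0,
}

# Does each replacement value lie inside the observed range of the feature?
in_range = {
    name: bool(feature.min() <= value <= feature.max())
    for name, value in replacements.items()
}
```

For this feature the mean sits above the median because of the long right tail, and zero is below every observed value, so a model evaluated with "zero" sees inputs it never encountered.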

Computational Cost:

  • Time complexity: O(n_features * prediction_time) or O(n_groups * prediction_time)
  • Same order as PFI (one evaluation per feature/group)
  • Cannot be trivially parallelized (requires data modification)
  • Faster than SHAP for large datasets

Comparison with Other Methods:

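| Method | Speed    | Groups | Local | Theory      | Bias |
|--------|----------|--------|-------|-------------|------|
| MDI    | Fastest  | No     | No    | Weak        | Yes  |
| PFI    | Slow     | Hard   | No    | Strong      | No   |
| MDA    | Slow     | Yes    | No    | Strong      | No   |
| SHAP   | Medium   | No     | Yes   | Strongest   | No   |
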
  • Speed: MDI instant (from training), PFI/MDA slow (repeated scoring), SHAP medium (depends on data size)
  • Groups: MDA naturally supports, PFI requires workarounds, MDI/SHAP no
  • Local: SHAP provides per-sample importances, others are global only
  • Theory: SHAP has strongest game-theoretic foundation, PFI/MDA empirical
  • Bias: MDI biased toward high-cardinality features, others unbiased

Best Practices:

  • Use validation/test set (not training data) for unbiased estimates
  • Compare MDA with PFI and SHAP for robustness
  • Use feature groups for one-hot encoded categoricals
  • Choose removal_method based on feature distributions
  • Verify model still makes reasonable predictions after removal

References

.. [ALT] A. Altmann, L. Toloşi, O. Sander, T. Lengauer, "Permutation importance: a corrected feature importance measure", Bioinformatics, 26(10), 1340-1347, 2010.

.. [FIS] A. Fisher, C. Rudin, F. Dominici, "All Models are Wrong, but Many are Useful: Learning a Variable's Importance by Studying an Entire Class of Prediction Models Simultaneously", JMLR, 20(177):1-81, 2019.

Source code in src/ml4t/diagnostic/metrics/importance_mda.py
def compute_mda_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    feature_groups: dict[str, list[str]] | None = None,
    removal_method: str = "mean",
    scoring: str | Callable | None = None,
    _n_jobs: int | None = None,
) -> dict[str, Any]:
    """Compute Mean Decrease in Accuracy (MDA) by feature removal.

    MDA measures the drop in model performance when features are removed or
    neutralized. Unlike Permutation Feature Importance (PFI) which shuffles
    feature values, MDA replaces feature values with a constant (mean, median,
    or zero), simulating complete feature unavailability.

    This approach naturally supports feature groups (e.g., one-hot encoded
    categoricals, related features like lat/lon) by removing multiple features
    simultaneously and measuring the joint importance.

    **Supported Models**:
    - Any fitted sklearn-compatible estimator with `score()` or `predict()` method
    - Classification: LogisticRegression, RandomForest, XGBoost, LightGBM, etc.
    - Regression: LinearRegression, Ridge, GradientBoosting, etc.

    Parameters
    ----------
    model : Any
        Fitted sklearn-compatible estimator (must have `score()` or `predict()` method)
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,)
    feature_names : list[str] | None, default None
        Feature names for labeling. If None, uses column names from DataFrame
        or generates numeric names for arrays
    feature_groups : dict[str, list[str]] | None, default None
        Dictionary mapping group names to lists of feature names.
        When provided, computes importance for feature groups instead of
        individual features. Example: {"location": ["lat", "lon"],
        "time": ["hour", "day", "month"]}
    removal_method : str, default "mean"
        How to neutralize features:
        - "mean": Replace with feature mean (recommended for continuous features)
        - "median": Replace with feature median (robust to outliers)
        - "zero": Replace with zero (can distort if zero is out-of-distribution)
    scoring : str | Callable | None, default None
        Scoring function to evaluate model performance. If None, uses model's
        default score method. Common options:
        - Classification: 'accuracy', 'roc_auc', 'f1'
        - Regression: 'r2', 'neg_mean_squared_error', 'neg_mean_absolute_error'
    _n_jobs : int | None, default None
        Reserved for parallel scoring (-1 for all CPUs).
        Note: The current implementation does not read this argument, so
        scoring runs serially.

    Returns
    -------
    dict[str, Any]
        Dictionary with MDA importance results:
        - importances: Performance drop per feature/group (sorted descending)
        - feature_names: Feature/group labels (sorted by importance)
        - baseline_score: Model score before feature removal
        - removal_method: Method used to neutralize features
        - scoring: Scoring function used
        - n_features: Number of features/groups evaluated

    Raises
    ------
    ValueError
        If removal_method is not one of: "mean", "median", "zero"
    ValueError
        If feature_groups contains unknown feature names
    ValueError
        If X and y have different numbers of samples

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import make_classification
    >>> import numpy as np
    >>>
    >>> # Train a simple model
    >>> X, y = make_classification(n_samples=1000, n_features=10, n_informative=3, random_state=42)
    >>> model = RandomForestClassifier(n_estimators=50, random_state=42)
    >>> model.fit(X, y)
    >>>
    >>> # Compute MDA importance
    >>> mda = compute_mda_importance(
    ...     model=model,
    ...     X=X,
    ...     y=y,
    ...     removal_method='mean',
    ...     scoring='accuracy'
    ... )
    >>>
    >>> # Examine results
    >>> print(f"Baseline score: {mda['baseline_score']:.3f}")
    >>> print(f"Most important feature: {next(iter(mda['feature_names']))}")
    >>> print(f"Importance (accuracy drop): {next(iter(mda['importances'])):.3f}")
    Baseline score: 0.920
    Most important feature: feature_3
    Importance (accuracy drop): 0.124

    **Feature Groups Example**:

    >>> # Group related features (e.g., one-hot encoded categorical)
    >>> feature_groups = {
    ...     "category_A": ["feature_0", "feature_1", "feature_2"],
    ...     "category_B": ["feature_3", "feature_4"],
    ...     "numeric": ["feature_5", "feature_6", "feature_7"]
    ... }
    >>>
    >>> mda_groups = compute_mda_importance(
    ...     model=model,
    ...     X=X,
    ...     y=y,
    ...     feature_groups=feature_groups,
    ...     removal_method='mean'
    ... )
    >>>
    >>> # See which group is most important
    >>> print(f"Most important group: {next(iter(mda_groups['feature_names']))}")
    >>> print(f"Group importance: {next(iter(mda_groups['importances'])):.3f}")

    Notes
    -----
    **MDA vs PFI** (Permutation Feature Importance):

    **MDA Characteristics**:
    - Removes feature completely (sets to constant)
    - Simulates true feature unavailability
    - May show larger importance drops than PFI
    - Naturally supports feature groups
    - Similar computational cost to PFI

    **PFI Characteristics**:
    - Shuffles feature values (breaks feature-target relationship)
    - Preserves feature distribution
    - May show smaller importance drops
    - Requires additional logic for feature groups
    - More commonly used in literature

    **When to use MDA**:
    - Want to simulate complete feature removal
    - Need to evaluate feature groups jointly
    - Want more conservative importance estimates
    - Comparing "with feature" vs "without feature" scenarios

    **When to use PFI instead**:
    - Want to match published baselines (PFI more common)
    - Need to preserve feature distributions
    - Want less conservative importance estimates

    **Feature Groups**:
    Feature groups are useful for:
    - One-hot encoded categoricals (remove all dummy variables together)
    - Related features (lat/lon, year/month/day)
    - Multi-dimensional embeddings
    - Polynomial features of same base feature

    Removing feature groups jointly captures their combined importance,
    which can be higher than the sum of individual importances due to
    interactions between features in the group.

    **Removal Methods**:

    - **mean**: Most common choice for continuous features. Replaces feature
      with its mean over the rows of X passed to the function (not a stored
      training-set statistic), keeping the column at a typical, in-range value.

    - **median**: More robust to outliers than mean. Useful for features with
      skewed distributions or outliers.

    - **zero**: Simple but can be problematic if zero is out-of-distribution
      for a feature (e.g., if feature is always positive). Use with caution.

    **Computational Cost**:
    - Time complexity: O(n_features * prediction_time) or O(n_groups * prediction_time)
    - Same order as PFI (one evaluation per feature/group)
    - Cannot be trivially parallelized (requires data modification)
    - Faster than SHAP for large datasets

    **Comparison with Other Methods**:

    | Method | Speed    | Groups | Local | Theory      | Bias |
    |--------|----------|--------|-------|-------------|------|
    | MDI    | Fastest  | No     | No    | Weak        | Yes  |
    | PFI    | Slow     | Hard   | No    | Strong      | No   |
    | MDA    | Slow     | Yes    | No    | Strong      | No   |
    | SHAP   | Medium   | No     | Yes   | Strongest   | No   |

    - **Speed**: MDI instant (from training), PFI/MDA slow (repeated scoring),
      SHAP medium (depends on data size)
    - **Groups**: MDA naturally supports, PFI requires workarounds, MDI/SHAP no
    - **Local**: SHAP provides per-sample importances, others are global only
    - **Theory**: SHAP has strongest game-theoretic foundation, PFI/MDA empirical
    - **Bias**: MDI biased toward high-cardinality features, others unbiased

    **Best Practices**:
    - Use validation/test set (not training data) for unbiased estimates
    - Compare MDA with PFI and SHAP for robustness
    - Use feature groups for one-hot encoded categoricals
    - Choose removal_method based on feature distributions
    - Verify model still makes reasonable predictions after removal

    References
    ----------
    .. [ALT] A. Altmann, L. Toloşi, O. Sander, T. Lengauer,
       "Permutation importance: a corrected feature importance measure",
       Bioinformatics, 26(10), 1340-1347, 2010.
    .. [FIS] A. Fisher, C. Rudin, F. Dominici,
       "All Models are Wrong, but Many are Useful: Learning a Variable's
       Importance by Studying an Entire Class of Prediction Models Simultaneously",
       JMLR, 20(177):1-81, 2019.
    """
    # Validate removal method
    valid_methods = ["mean", "median", "zero"]
    if removal_method not in valid_methods:
        raise ValueError(f"removal_method must be one of {valid_methods}, got '{removal_method}'")

    # Convert inputs to numpy arrays and extract feature names
    if isinstance(X, pl.DataFrame):
        if feature_names is None:
            feature_names = list(X.columns)  # Polars columns is already a list
        X_array = X.to_numpy()
    elif isinstance(X, pd.DataFrame):
        if feature_names is None:
            feature_names = X.columns.tolist()
        X_array = X.values
    else:
        X_array = np.asarray(X)
        if feature_names is None:
            feature_names = [f"feature_{i}" for i in range(X_array.shape[1])]

    y_array: NDArray[Any]
    if isinstance(y, (pl.Series, pd.Series)):
        y_array = y.to_numpy()
    else:
        y_array = np.asarray(y)

    # Validate dimensions
    n_samples, n_features = X_array.shape
    if len(y_array) != n_samples:
        raise ValueError(
            f"X and y have inconsistent numbers of samples: {n_samples} vs {len(y_array)}"
        )

    # Set up scoring function
    if scoring is None:
        scorer = None
        baseline_score = model.score(X_array, y_array)
        scoring_name = "default"
    else:
        from sklearn.metrics import get_scorer

        scorer = get_scorer(scoring) if isinstance(scoring, str) else scoring
        baseline_score = scorer(model, X_array, y_array)
        scoring_name = scoring if isinstance(scoring, str) else "custom"

    # Compute feature replacement values based on removal method
    if removal_method == "mean":
        replacement_values = np.mean(X_array, axis=0)
    elif removal_method == "median":
        replacement_values = np.median(X_array, axis=0)
    else:  # removal_method == "zero"
        replacement_values = np.zeros(n_features)

    # Determine whether we're evaluating individual features or groups
    if feature_groups is not None:
        # Validate feature groups (feature_names is always set by this point)
        assert feature_names is not None
        all_group_features: set[str] = set()
        for group_name, features in feature_groups.items():
            for feat in features:
                if feat not in feature_names:
                    raise ValueError(
                        f"Feature '{feat}' in group '{group_name}' not found in feature_names"
                    )
                all_group_features.add(feat)

        # Map feature names to indices
        feature_name_to_idx = {name: idx for idx, name in enumerate(feature_names)}

        # Compute importance for each group
        importances_list = []
        group_names = []

        for group_name, features in feature_groups.items():
            # Get indices for all features in this group
            feature_indices = [feature_name_to_idx[feat] for feat in features]

            # Create modified data with group features removed
            X_removed = X_array.copy()
            for idx in feature_indices:
                X_removed[:, idx] = replacement_values[idx]

            # Compute score with group removed
            removed_score = (
                model.score(X_removed, y_array)
                if scorer is None
                else scorer(model, X_removed, y_array)
            )

            # Importance is the drop in performance
            importance = baseline_score - removed_score
            importances_list.append(importance)
            group_names.append(group_name)

        importances = np.array(importances_list)
        eval_feature_names = group_names
        n_eval_features = len(feature_groups)

    else:
        # Compute importance for individual features
        importances_list = []

        for feature_idx in range(n_features):
            # Create modified data with feature removed
            X_removed = X_array.copy()
            X_removed[:, feature_idx] = replacement_values[feature_idx]

            # Compute score with feature removed
            removed_score = (
                model.score(X_removed, y_array)
                if scorer is None
                else scorer(model, X_removed, y_array)
            )

            # Importance is the drop in performance
            importance = baseline_score - removed_score
            importances_list.append(importance)

        importances = np.array(importances_list)
        eval_feature_names = feature_names
        n_eval_features = n_features

    # Sort by importance (descending)
    sorted_idx = np.argsort(importances)[::-1]

    # Type assertion: eval_feature_names is guaranteed to be set
    assert eval_feature_names is not None, "eval_feature_names should be set by this point"

    return {
        "importances": importances[sorted_idx],
        "feature_names": [eval_feature_names[i] for i in sorted_idx],
        "baseline_score": float(baseline_score),
        "removal_method": removal_method,
        "scoring": scoring_name,
        "n_features": n_eval_features,
    }
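
One NumPy detail is worth checking when calling the function above with integer-typed features. The listing assigns a floating-point replacement value into a copy of `X_array`, and NumPy casts on assignment to the destination dtype. The sketch below illustrates that casting behavior on a tiny hypothetical matrix; converting to float beforehand (an assumption about how a caller might guard against it, not something the library does) preserves the exact mean.

```python
import numpy as np

# Hypothetical feature matrix with an integer dtype
X_int = np.array([[1, 10], [2, 20], [4, 30]])
col_mean = X_int[:, 0].mean()  # 7 / 3, a non-integer value

# Same assignment pattern as the listing: the float is cast to int
X_removed = X_int.copy()
X_removed[:, 0] = col_mean

# Casting the matrix to float first keeps the replacement value intact
X_float = X_int.astype(float)
X_float[:, 0] = col_mean
```

With integer input the "mean" replacement is truncated toward zero, so passing `X.astype(float)` (or a float DataFrame) is a cheap precaution when features are stored as counts or codes.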

analyze_ml_importance

analyze_ml_importance(
    model,
    X,
    y,
    feature_names=None,
    methods=None,
    scoring=None,
    n_repeats=10,
    random_state=42,
)

Comprehensive ML feature importance analysis comparing multiple methods.

Run multiple importance methods and generate a comparison report with consensus ranking and interpretation.

Use this when you need to answer: "Which features does my model rely on, and do different importance methods agree?"

The integrated analysis includes:

  • Individual method results (MDI, PFI, MDA, SHAP)
  • Consensus ranking (features important across methods)
  • Method agreement/disagreement analysis
  • Auto-generated insights and warnings

Why Compare Methods?

Different importance methods measure different aspects:

  • MDI (Mean Decrease Impurity): Fast, but biased toward high-cardinality features
  • PFI (Permutation): Unbiased, measures predictive importance
  • MDA (Mean Decrease Accuracy): Similar to PFI but removes features completely
  • SHAP: Theoretically sound, based on game theory

Strong consensus across methods indicates robust feature importance. Disagreement suggests model-specific artifacts or feature interactions.
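
Agreement between two methods can be quantified as the Spearman correlation of their rank vectors. The sketch below uses hypothetical importances for four features and computes Spearman's rho as the Pearson correlation of tie-free ranks with NumPy only; `ranks_desc` and `spearman_no_ties` are illustrative helper names, and the library's own listing uses `spearmanr` instead.

```python
import numpy as np


def ranks_desc(importances: np.ndarray) -> np.ndarray:
    """Rank features so that 0 is the most important (assumes no ties)."""
    return np.argsort(np.argsort(importances)[::-1])


def spearman_no_ties(r1: np.ndarray, r2: np.ndarray) -> float:
    """Spearman rho as the Pearson correlation of two tie-free rank vectors."""
    return float(np.corrcoef(r1, r2)[0, 1])


# Hypothetical importances for the same four features from two methods
mdi = np.array([0.40, 0.30, 0.20, 0.10])
pfi = np.array([0.35, 0.05, 0.25, 0.15])

mdi_ranks = ranks_desc(mdi)
pfi_ranks = ranks_desc(pfi)
agreement = spearman_no_ties(mdi_ranks, pfi_ranks)
self_agreement = spearman_no_ties(mdi_ranks, mdi_ranks)
```

Values near 1 indicate that the methods order the features similarly; values near 0 or below are the disagreement case described above and merit a closer look at individual features.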

Parameters

model : Any
    Fitted model. Requirements vary by method:
    - MDI: Must have feature_importances_ (tree-based models)
    - PFI, MDA: Must have predict() or score()
    - SHAP: Must be compatible with TreeExplainer
X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Feature matrix (n_samples, n_features)
y : Union[pl.Series, pd.Series, np.ndarray]
    Target values (n_samples,)
feature_names : list[str] | None, default None
    Feature names for labeling. If None, uses column names from DataFrame or generates numeric names
methods : list[str] | None, default ["mdi", "pfi", "shap"]
    Which methods to run. Options: "mdi", "pfi", "mda", "shap"
scoring : str | Callable | None, default None
    Scoring metric for PFI and MDA
n_repeats : int, default 10
    Number of permutations for PFI
random_state : int | None, default 42
    Random seed for reproducibility

Returns

dict[str, Any]
    Comprehensive analysis results:
    - method_results: Dict of individual method outputs
    - consensus_ranking: Features ranked by average rank across methods
    - method_agreement: Spearman correlations between method rankings
    - top_features_consensus: Features in top 10 for ALL methods
    - warnings: Detected issues
    - interpretation: Auto-generated summary
    - methods_run: Methods successfully executed
    - methods_failed: Failed methods with error messages
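
The consensus_ranking and top_features_consensus entries can be understood from a small worked case. The sketch below assumes hypothetical per-method ranks for four features and uses top_n = 2 instead of 10 so the intersection is visible; it mirrors the averaging and set-intersection idea, not the library's exact code.

```python
import numpy as np

feature_names = ["f0", "f1", "f2", "f3"]

# Hypothetical per-method ranks aligned to feature_names (0 = most important)
rankings = {
    "mdi": np.array([0, 1, 2, 3]),
    "pfi": np.array([0, 3, 1, 2]),
}

# Consensus: order features by their average rank across methods
avg_ranks = np.mean(list(rankings.values()), axis=0)
consensus_order = np.argsort(avg_ranks, kind="stable")
consensus_ranking = [feature_names[i] for i in consensus_order]

# Top-N consensus: features that appear in every method's top N
top_n = 2
top_sets = [
    {feature_names[i] for i in np.argsort(ranks)[:top_n]}
    for ranks in rankings.values()
]
top_features_consensus = set.intersection(*top_sets)
```

Here the average ranks are 0, 2, 1.5 and 2.5, so the consensus order is f0, f2, f1, f3, and only f0 is in the top two for both methods.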

Raises

ValueError
    If no methods specified or all methods fail

Examples

>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.datasets import make_classification
>>>
>>> # Create synthetic dataset
>>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
>>> model = RandomForestClassifier(n_estimators=50, random_state=42)
>>> model.fit(X, y)
>>>
>>> # Comprehensive importance analysis
>>> result = analyze_ml_importance(model, X, y, methods=["mdi", "pfi"])
>>>
>>> # Quick summary
>>> print(result["interpretation"])

Source code in src/ml4t/diagnostic/metrics/importance_analysis.py
def analyze_ml_importance(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    methods: list[str] | None = None,
    scoring: str | Callable | None = None,
    n_repeats: int = 10,
    random_state: int | None = 42,
) -> dict[str, Any]:
    """Comprehensive ML feature importance analysis comparing multiple methods.

    Run multiple importance methods and generate a comparison report with
    consensus ranking and interpretation.

    Use this when you need to answer: "Which features does my model rely on,
    and do different importance methods agree?"

    The integrated analysis includes:
    - Individual method results (MDI, PFI, MDA, SHAP)
    - Consensus ranking (features important across methods)
    - Method agreement/disagreement analysis
    - Auto-generated insights and warnings

    **Why Compare Methods?**

    Different importance methods measure different aspects:
    - **MDI** (Mean Decrease Impurity): Fast, but biased toward high-cardinality features
    - **PFI** (Permutation): Unbiased, measures predictive importance
    - **MDA** (Mean Decrease Accuracy): Similar to PFI but removes features completely
    - **SHAP**: Theoretically sound, based on game theory

    Strong consensus across methods indicates robust feature importance.
    Disagreement suggests model-specific artifacts or feature interactions.

    Parameters
    ----------
    model : Any
        Fitted model. Requirements vary by method:
        - MDI: Must have `feature_importances_` (tree-based models)
        - PFI, MDA: Must have `predict()` or `score()`
        - SHAP: Must be compatible with TreeExplainer
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,)
    feature_names : list[str] | None, default None
        Feature names for labeling. If None, uses column names from DataFrame
        or generates numeric names
    methods : list[str] | None, default ["mdi", "pfi", "shap"]
        Which methods to run. Options: "mdi", "pfi", "mda", "shap"
    scoring : str | Callable | None, default None
        Scoring metric for PFI and MDA
    n_repeats : int, default 10
        Number of permutations for PFI
    random_state : int | None, default 42
        Random seed for reproducibility

    Returns
    -------
    dict[str, Any]
        Comprehensive analysis results:
        - method_results: Dict of individual method outputs
        - consensus_ranking: Features ranked by average rank across methods
        - method_agreement: Spearman correlations between method rankings
        - top_features_consensus: Features in top 10 for ALL methods
        - warnings: Detected issues
        - interpretation: Auto-generated summary
        - methods_run: Methods successfully executed
        - methods_failed: Failed methods with error messages

    Raises
    ------
    ValueError
        If no methods specified or all methods fail

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import make_classification
    >>>
    >>> # Create synthetic dataset
    >>> X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
    >>> model = RandomForestClassifier(n_estimators=50, random_state=42)
    >>> model.fit(X, y)
    >>>
    >>> # Comprehensive importance analysis
    >>> result = analyze_ml_importance(model, X, y, methods=["mdi", "pfi"])
    >>>
    >>> # Quick summary
    >>> print(result["interpretation"])
    """
    if methods is None:
        methods = ["mdi", "pfi", "shap"]

    if not methods:
        raise ValueError("At least one method must be specified")

    # Extract feature names if not provided
    if feature_names is None:
        if isinstance(X, pl.DataFrame | pd.DataFrame):
            feature_names = list(X.columns)
        else:
            # Generate numeric feature names
            n_features = X.shape[1] if hasattr(X, "shape") else len(X[0])
            feature_names = [f"f{i}" for i in range(n_features)]

    # Run each method with try/except for optional dependencies
    results = {}
    method_failures = []

    if "mdi" in methods:
        try:
            results["mdi"] = compute_mdi_importance(model, feature_names=feature_names)
        except Exception as e:
            method_failures.append(("mdi", str(e)))

    if "pfi" in methods:
        try:
            results["pfi"] = compute_permutation_importance(
                model,
                X,
                y,
                feature_names=feature_names,
                scoring=scoring,
                n_repeats=n_repeats,
                random_state=random_state,
            )
        except Exception as e:
            method_failures.append(("pfi", str(e)))

    if "mda" in methods:
        try:
            results["mda"] = compute_mda_importance(
                model, X, y, feature_names=feature_names, scoring=scoring
            )
        except Exception as e:
            method_failures.append(("mda", str(e)))

    if "shap" in methods:
        try:
            results["shap"] = compute_shap_importance(model, X, feature_names=feature_names)
        except ImportError:
            method_failures.append(
                (
                    "shap",
                    "shap library not installed. Install with: pip install ml4t-diagnostic[ml]",
                )
            )
        except Exception as e:
            method_failures.append(("shap", str(e)))

    # Check if at least one method succeeded
    if not results:
        error_msg = "All methods failed:\n" + "\n".join(
            f"  - {method}: {error}" for method, error in method_failures
        )
        raise ValueError(error_msg)

    # 2. Compute consensus ranking
    # Convert each method's importance to rankings (0 = most important)
    rankings = {}
    for method_name, result in results.items():
        # Get feature names and importances for this method
        method_feature_names = result["feature_names"]

        if method_name == "pfi":
            importances = result["importances_mean"]
        elif method_name in ["shap", "mdi", "mda"]:
            importances = result["importances"]
        else:
            # Shouldn't happen, but handle gracefully
            continue

        # Create a mapping from feature name to importance
        feature_to_importance = dict(zip(method_feature_names, importances, strict=False))

        # Map to our canonical feature_names list (handle missing features)
        importance_values = np.array(
            [feature_to_importance.get(fname, 0.0) for fname in feature_names]
        )

        # Rank (higher importance = lower rank number, i.e., rank 0 is most important)
        ranks = np.argsort(np.argsort(importance_values)[::-1])
        rankings[method_name] = ranks

    # Average ranks across methods
    avg_ranks = np.mean(list(rankings.values()), axis=0)
    consensus_order = np.argsort(avg_ranks)

    # Get feature names in consensus order
    consensus_ranking = [feature_names[i] for i in consensus_order]

    # 3. Compute method agreement (Spearman correlation between rankings)
    method_agreement = {}
    method_names = list(rankings.keys())
    for i, m1 in enumerate(method_names):
        for m2 in method_names[i + 1 :]:
            corr, _ = spearmanr(rankings[m1], rankings[m2])
            method_agreement[f"{m1}_vs_{m2}"] = float(corr)

    # 4. Identify consensus top features (top 10 in all methods)
    top_n = 10
    top_features_by_method = {}
    for method_name, result in results.items():
        # Get top N feature names from this method
        method_top_features = result["feature_names"][:top_n]
        top_features_by_method[method_name] = set(method_top_features)

    consensus_top = (
        set.intersection(*top_features_by_method.values()) if top_features_by_method else set()
    )

    # 5. Generate warnings
    warnings = []

    # Warning: High MDI but low PFI (possible overfitting)
    if "mdi" in results and "pfi" in results:
        mdi_top = set(results["mdi"]["feature_names"][:5])
        pfi_top = set(results["pfi"]["feature_names"][:5])
        disagreement = mdi_top - pfi_top
        if disagreement:
            warnings.append(
                f"Features {disagreement} rank high in MDI but not PFI - possible overfitting to tree structure"
            )

    # Warning: Low agreement between methods
    if method_agreement:
        min_agreement = min(method_agreement.values())
        if min_agreement < 0.5:
            warnings.append(
                f"Low agreement between methods (min correlation: {min_agreement:.2f}) - results may be unreliable"
            )

    # Add method failures to warnings
    if method_failures:
        for method, error in method_failures:
            warnings.append(f"Method '{method}' failed: {error}")

    # 6. Generate interpretation
    interpretation = _generate_ml_importance_interpretation(
        consensus_ranking[:10],
        method_agreement,
        warnings,
        len(consensus_top),
    )

    return {
        "method_results": results,
        "consensus_ranking": consensus_ranking,
        "method_agreement": method_agreement,
        "top_features_consensus": list(consensus_top),
        "warnings": warnings,
        "interpretation": interpretation,
        "methods_run": list(results.keys()),
        "methods_failed": method_failures,
    }
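
The consensus step in the listing above can be illustrated in isolation. The following is a minimal sketch with hypothetical feature names and importance scores (none of these values come from the library). It shows how the double `argsort` turns raw importances into 0-based ranks and how averaging those ranks yields a consensus order.

```python
import numpy as np

# Hypothetical importance scores for four features from three methods.
feature_names = ["momentum", "volatility", "volume", "spread"]
importances = {
    "mdi": np.array([0.40, 0.30, 0.20, 0.10]),
    "pfi": np.array([0.35, 0.25, 0.05, 0.15]),
    "shap": np.array([0.10, 0.50, 0.15, 0.25]),
}

rankings = {}
for method, values in importances.items():
    # argsort()[::-1] lists feature indices from most to least important;
    # the outer argsort inverts that permutation, giving each feature's rank.
    rankings[method] = np.argsort(np.argsort(values)[::-1])

# Average the per-method ranks, then order features by that average.
avg_ranks = np.mean(list(rankings.values()), axis=0)
consensus = [feature_names[i] for i in np.argsort(avg_ranks)]

for method, ranks in rankings.items():
    print(method, ranks.tolist())
print("consensus:", consensus)
```

The example data avoids tied average ranks because `np.argsort` with its default sort does not guarantee the order of tied entries.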

compute_h_statistic

compute_h_statistic(
    model,
    X,
    feature_pairs=None,
    feature_names=None,
    n_samples=100,
    grid_resolution=20,
)

Compute Friedman's H-statistic for feature interaction strength.

The H-statistic (Friedman & Popescu 2008) measures how much of the variation in predictions can be attributed to interactions between feature pairs, beyond their individual main effects.

Algorithm:

1. For each feature pair (j, k):
   - Compute 2D partial dependence PD_{jk}(x_j, x_k)
   - Compute 1D partial dependences PD_j(x_j) and PD_k(x_k)
   - Compute H^2 = sum[(PD_{jk} - PD_j - PD_k)^2] / sum[PD_{jk}^2], using centered partial dependences
   - H ranges from 0 (no interaction) to 1 (pure interaction)
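
The algorithm can be tried on toy prediction functions without any fitted model. The sketch below is an illustration under stated assumptions, not the library's implementation: `pairwise_h` is a hypothetical helper, the two lambda "models" are made up, and each partial-dependence function is centered on its own grid mean (the zero-mean convention of Friedman and Popescu). The source listing for this function subtracts a single constant instead, so its numbers can differ.

```python
import numpy as np


def pairwise_h(predict, X, i, j, grid_resolution=10):
    """Grid-based estimate of H for features i and j (illustrative only)."""
    grid_i = np.linspace(X[:, i].min(), X[:, i].max(), grid_resolution)
    grid_j = np.linspace(X[:, j].min(), X[:, j].max(), grid_resolution)

    def partial_dependence(fixed):
        # Overwrite the chosen columns with fixed values, then average predictions.
        X_tmp = X.copy()
        for col, value in fixed.items():
            X_tmp[:, col] = value
        return predict(X_tmp).mean()

    pd_i = np.array([partial_dependence({i: a}) for a in grid_i])
    pd_j = np.array([partial_dependence({j: b}) for b in grid_j])
    pd_ij = np.array(
        [[partial_dependence({i: a, j: b}) for b in grid_j] for a in grid_i]
    )

    # Assumption: center each function on its own grid mean.
    pd_i = pd_i - pd_i.mean()
    pd_j = pd_j - pd_j.mean()
    pd_ij = pd_ij - pd_ij.mean()

    # Whatever the two main effects cannot explain is attributed to interaction.
    interaction = pd_ij - pd_i[:, np.newaxis] - pd_j[np.newaxis, :]
    denominator = np.sum(pd_ij**2)
    if denominator < 1e-10:
        return 0.0
    return float(np.sqrt(np.sum(interaction**2) / denominator))


rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 3))

# A purely additive function should show no interaction; a product should.
h_additive = pairwise_h(lambda data: 2.0 * data[:, 0] - data[:, 1], X, 0, 1)
h_product = pairwise_h(lambda data: data[:, 0] * data[:, 1], X, 0, 1)
print(f"additive: {h_additive:.4f}, product: {h_product:.4f}")
```

The cost is one `predict` call per grid cell, so the 2D step scales with the square of `grid_resolution`; this is why the number of samples and the grid size are both capped by parameters.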

Parameters

model : Any
    Trained model with .predict() method
X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Feature matrix (n_samples, n_features)
feature_pairs : list[tuple[int, int]] | list[tuple[str, str]] | None, default None
    List of (i, j) pairs to test. If None, tests all pairs.
feature_names : list[str] | None, default None
    Feature names. If None, uses column names or f0, f1, ...
n_samples : int, default 100
    Number of samples to use for PD computation (subsample if needed)
grid_resolution : int, default 20
    Grid size for PD evaluation

Returns

dict[str, Any]
    Dictionary with:
    - h_statistics: List of (feature_i, feature_j, H_value) sorted by H descending
    - feature_names: List of feature names used
    - n_features: Number of features
    - n_pairs_tested: Number of pairs tested
    - n_samples_used: Number of samples used after subsampling
    - grid_resolution: Grid size used for PD evaluation
    - computation_time: Time in seconds

References
  • Friedman, J. H., & Popescu, B. E. (2008). Predictive learning via rule ensembles. The Annals of Applied Statistics, 2(3), 916-954.

Examples

import lightgbm as lgb
model = lgb.LGBMRegressor()
model.fit(X_train, y_train)
results = compute_h_statistic(model, X_test)
for feat_i, feat_j, h_val in results["h_statistics"]:
    print(f"  {feat_i} x {feat_j}: H = {h_val:.4f}")

Source code in src/ml4t/diagnostic/metrics/interactions.py
def compute_h_statistic(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    feature_pairs: list[tuple[int, int]] | list[tuple[str, str]] | None = None,
    feature_names: list[str] | None = None,
    n_samples: int = 100,
    grid_resolution: int = 20,
) -> dict[str, Any]:
    """Compute Friedman's H-statistic for feature interaction strength.

    The H-statistic (Friedman & Popescu 2008) measures how much of the variation
    in predictions can be attributed to interactions between feature pairs, beyond
    their individual main effects.

    **Algorithm**:
    1. For each feature pair (j, k):
       - Compute 2D partial dependence PD_{jk}(x_j, x_k)
       - Compute 1D partial dependences PD_j(x_j) and PD_k(x_k)
       - Compute H^2 = sum[(PD_{jk} - PD_j - PD_k)^2] / sum[PD_{jk}^2], using centered partial dependences
       - H ranges from 0 (no interaction) to 1 (pure interaction)

    Parameters
    ----------
    model : Any
        Trained model with .predict() method
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    feature_pairs : list[tuple[int, int]] | list[tuple[str, str]] | None, default None
        List of (i, j) pairs to test. If None, tests all pairs.
    feature_names : list[str] | None, default None
        Feature names. If None, uses column names or f0, f1, ...
    n_samples : int, default 100
        Number of samples to use for PD computation (subsample if needed)
    grid_resolution : int, default 20
        Grid size for PD evaluation

    Returns
    -------
    dict[str, Any]
        Dictionary with:
        - h_statistics: List of (feature_i, feature_j, H_value) sorted by H descending
        - feature_names: List of feature names used
        - n_features: Number of features
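        - n_pairs_tested: Number of pairs tested
        - n_samples_used: Number of samples used after subsampling
        - grid_resolution: Grid size used for PD evaluation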
        - computation_time: Time in seconds

    References
    ----------
    - Friedman, J. H., & Popescu, B. E. (2008). Predictive learning via rule ensembles.
      The Annals of Applied Statistics, 2(3), 916-954.

    Examples
    --------
    >>> import lightgbm as lgb
    >>> model = lgb.LGBMRegressor()
    >>> model.fit(X_train, y_train)
    >>> results = compute_h_statistic(model, X_test)
    >>> for feat_i, feat_j, h_val in results["h_statistics"]:
    ...     print(f"  {feat_i} x {feat_j}: H = {h_val:.4f}")
    """
    start_time = time.time()

    # Convert input to numpy
    X_array, feature_names_list = _to_numpy_with_feature_names(X, feature_names)
    feature_names = feature_names_list

    n_total_samples, n_features = X_array.shape

    # Subsample if needed
    if n_total_samples > n_samples:
        rng = np.random.default_rng(42)
        indices = rng.choice(n_total_samples, size=n_samples, replace=False)
        X_sample = X_array[indices]
    else:
        X_sample = X_array
        n_samples = n_total_samples

    # Generate feature pairs if not provided - always convert to int pairs
    pairs_int: list[tuple[int, int]]
    if feature_pairs is None:
        # Test all pairs
        pairs_int = [(i, j) for i in range(n_features) for j in range(i + 1, n_features)]
    elif feature_names and len(feature_pairs) > 0 and isinstance(feature_pairs[0][0], str):
        # Convert string pairs to indices
        name_to_idx = {name: idx for idx, name in enumerate(feature_names)}
        pairs_int = [(name_to_idx[str(i)], name_to_idx[str(j)]) for i, j in feature_pairs]
    else:
        # Already integer pairs
        pairs_int = [(int(i), int(j)) for i, j in feature_pairs]

    # Ensure feature_names is a list for indexing
    feature_names_list = list(feature_names)

    h_results: list[tuple[str, str, float]] = []

    for feat_i, feat_j in pairs_int:
        # Create grids for features i and j
        x_i_grid = np.linspace(
            float(X_sample[:, feat_i].min()), float(X_sample[:, feat_i].max()), grid_resolution
        )
        x_j_grid = np.linspace(
            float(X_sample[:, feat_j].min()), float(X_sample[:, feat_j].max()), grid_resolution
        )

        # Compute 2D partial dependence PD_{ij}
        pd_2d = np.zeros((grid_resolution, grid_resolution))
        for gi, x_i_val in enumerate(x_i_grid):
            for gj, x_j_val in enumerate(x_j_grid):
                # Replace features i and j with grid values
                X_temp = X_sample.copy()
                X_temp[:, feat_i] = x_i_val
                X_temp[:, feat_j] = x_j_val
                # Average prediction over all samples
                pd_2d[gi, gj] = model.predict(X_temp).mean()

        # Compute 1D partial dependences PD_i and PD_j
        pd_i = np.zeros(grid_resolution)
        for gi, x_i_val in enumerate(x_i_grid):
            X_temp = X_sample.copy()
            X_temp[:, feat_i] = x_i_val
            pd_i[gi] = model.predict(X_temp).mean()

        pd_j = np.zeros(grid_resolution)
        for gj, x_j_val in enumerate(x_j_grid):
            X_temp = X_sample.copy()
            X_temp[:, feat_j] = x_j_val
            pd_j[gj] = model.predict(X_temp).mean()

        # Compute H-statistic
        # H^2 = sum[(PD_{ij} - PD_i - PD_j + PD_const)^2] / sum[(PD_{ij} - PD_const)^2]

        # For numerical stability, center everything
        pd_const = pd_2d.mean()
        pd_i_centered = pd_i - pd_const
        pd_j_centered = pd_j - pd_const
        pd_2d_centered = pd_2d - pd_const

        # Interaction component: PD_{ij} - PD_i - PD_j
        # Need to broadcast pd_i and pd_j to 2D
        pd_i_broadcast = pd_i_centered[:, np.newaxis]  # Shape: (grid_resolution, 1)
        pd_j_broadcast = pd_j_centered[np.newaxis, :]  # Shape: (1, grid_resolution)

        interaction = pd_2d_centered - pd_i_broadcast - pd_j_broadcast

        # H-statistic
        numerator = np.sum(interaction**2)
        denominator = np.sum(pd_2d_centered**2)

        if denominator > 1e-10:  # Avoid division by zero
            h_squared = numerator / denominator
            h_stat = np.sqrt(max(0, h_squared))  # Ensure non-negative
        else:
            h_stat = 0.0

        h_results.append((feature_names_list[feat_i], feature_names_list[feat_j], float(h_stat)))

    # Sort by H-statistic descending
    h_results.sort(key=lambda x: x[2], reverse=True)

    computation_time = time.time() - start_time

    return {
        "h_statistics": h_results,
        "feature_names": feature_names,
        "n_features": n_features,
        "n_pairs_tested": len(h_results),
        "n_samples_used": n_samples,
        "grid_resolution": grid_resolution,
        "computation_time": computation_time,
    }

compute_shap_interactions

compute_shap_interactions(
    model,
    X,
    feature_names=None,
    _check_additivity=False,
    max_samples=None,
    top_k=None,
)

Compute SHAP interaction values for feature pairs.

SHAP interaction values decompose the SHAP value of each feature into:
  • Main effect (the feature's individual contribution)
  • Interaction effects (how the feature's impact changes with other features)

Parameters

model : Any
    Trained tree-based model
X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Feature matrix (n_samples, n_features)
feature_names : list[str] | None, default None
    Feature names. If None, uses column names or f0, f1, ...
_check_additivity : bool, default False
    Internal parameter (not used for interaction values)
max_samples : int | None, default None
    Maximum samples to use (subsample if larger)
top_k : int | None, default None
    Return only top K interactions by absolute magnitude

Returns

dict[str, Any]
    Dictionary with:
    - interaction_matrix: (n_features, n_features) mean absolute interactions
    - feature_names: List of feature names
    - top_interactions: List of (feature_i, feature_j, mean_interaction) sorted by magnitude
    - n_features: Number of features
    - n_samples_used: Number of samples used
    - computation_time: Time in seconds

Notes
  • Requires shap package (install with: pip install ml4t-diagnostic[ml])
  • Only works with tree-based models (uses TreeExplainer)
  • Interaction matrix is symmetric: interaction(i,j) = interaction(j,i)
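
The aggregation from per-sample interaction values to a ranked pair list can be shown without installing shap. This is a minimal sketch on a hand-written, hypothetical array of the shape the function expects after class handling, (n_samples, n_features, n_features). It takes the mean absolute value per cell and reads pairs from the upper triangle only.

```python
import numpy as np

feature_names = ["f0", "f1", "f2"]

# Hypothetical per-sample interaction values for two samples.
# Each (n_features, n_features) slice is symmetric; the diagonal holds main effects.
values = np.array(
    [
        [[0.5, 0.2, -0.1], [0.2, 0.3, 0.0], [-0.1, 0.0, 0.1]],
        [[0.4, -0.4, 0.3], [-0.4, 0.2, 0.0], [0.3, 0.0, 0.2]],
    ]
)

# Mean absolute value per cell, so opposite-signed samples do not cancel out.
interaction_matrix = np.mean(np.abs(values), axis=0)

# Read the upper triangle only: the matrix is symmetric and the diagonal is not a pair.
pairs = [
    (feature_names[i], feature_names[j], float(interaction_matrix[i, j]))
    for i in range(len(feature_names))
    for j in range(i + 1, len(feature_names))
]
pairs.sort(key=lambda item: item[2], reverse=True)

for feat_i, feat_j, strength in pairs:
    print(f"{feat_i} x {feat_j}: {strength:.3f}")
```

Taking the absolute value before averaging matters for the f0/f1 pair here: its signed values (0.2 and -0.4) would partly cancel under a plain mean.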
Source code in src/ml4t/diagnostic/metrics/interactions.py
def compute_shap_interactions(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    _check_additivity: bool = False,
    max_samples: int | None = None,
    top_k: int | None = None,
) -> dict[str, Any]:
    """Compute SHAP interaction values for feature pairs.

    SHAP interaction values decompose the SHAP value of each feature into:
    - Main effect (the feature's individual contribution)
    - Interaction effects (how the feature's impact changes with other features)

    Parameters
    ----------
    model : Any
        Trained tree-based model
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    feature_names : list[str] | None, default None
        Feature names. If None, uses column names or f0, f1, ...
    _check_additivity : bool, default False
        Internal parameter (not used for interaction values)
    max_samples : int | None, default None
        Maximum samples to use (subsample if larger)
    top_k : int | None, default None
        Return only top K interactions by absolute magnitude

    Returns
    -------
    dict[str, Any]
        Dictionary with:
        - interaction_matrix: (n_features, n_features) mean absolute interactions
        - feature_names: List of feature names
        - top_interactions: List of (feature_i, feature_j, mean_interaction) sorted by magnitude
        - n_features: Number of features
        - n_samples_used: Number of samples used
        - computation_time: Time in seconds

    Notes
    -----
    - Requires shap package (install with: pip install ml4t-diagnostic[ml])
    - Only works with tree-based models (uses TreeExplainer)
    - Interaction matrix is symmetric: interaction(i,j) = interaction(j,i)
    """
    start_time = time.time()

    # Check shap availability
    try:
        import shap
    except ImportError as e:
        raise ImportError(
            "SHAP is required for interaction values. "
            "Install with: pip install ml4t-diagnostic[ml] "
            "or: pip install shap>=0.43.0"
        ) from e

    # Convert input to numpy and extract feature names
    X_array, feature_names = _to_numpy_with_feature_names(X, feature_names)

    n_total_samples, n_features = X_array.shape

    # Subsample if needed
    if max_samples is not None and n_total_samples > max_samples:
        rng = np.random.default_rng(42)
        indices = rng.choice(n_total_samples, size=max_samples, replace=False)
        X_sample = X_array[indices]
        n_samples_used = max_samples
    else:
        X_sample = X_array
        n_samples_used = n_total_samples

    # Compute SHAP interaction values using TreeExplainer
    explainer = shap.TreeExplainer(model)
    shap_interaction_values = explainer.shap_interaction_values(X_sample)

    # Handle multi-output models (classification)
    if isinstance(shap_interaction_values, list):
        # List format: use positive class for binary, average for multiclass
        if len(shap_interaction_values) == 2:
            shap_interaction_values = shap_interaction_values[1]
        else:
            shap_interaction_values = np.mean(shap_interaction_values, axis=0)

    # Check if we have a 4D array (n_samples, n_features, n_features, n_classes)
    if shap_interaction_values.ndim == 4:
        if shap_interaction_values.shape[-1] == 2:
            # Binary classification: use positive class (index 1)
            shap_interaction_values = shap_interaction_values[:, :, :, 1]
        else:
            # Multiclass: average absolute values across classes
            shap_interaction_values = np.mean(np.abs(shap_interaction_values), axis=-1)

    # Shape should now be: (n_samples, n_features, n_features)

    # Compute mean absolute interaction matrix
    interaction_matrix = np.mean(np.abs(shap_interaction_values), axis=0)

    # Ensure 2D matrix (n_features, n_features)
    if interaction_matrix.ndim != 2:
        raise ValueError(
            f"Interaction matrix should be 2D but got shape {interaction_matrix.shape}. "
            f"Raw SHAP values shape: {shap_interaction_values.shape}"
        )

    # Extract top interactions (off-diagonal, upper triangle to avoid duplicates)
    interactions_list = []
    for i in range(n_features):
        for j in range(i + 1, n_features):  # Upper triangle only
            mean_interaction = float(interaction_matrix[i, j])
            interactions_list.append((feature_names[i], feature_names[j], mean_interaction))

    # Sort by absolute interaction strength descending
    interactions_list.sort(key=lambda x: abs(x[2]), reverse=True)

    # Limit to top K if requested
    if top_k is not None:
        interactions_list = interactions_list[:top_k]

    computation_time = time.time() - start_time

    return {
        "interaction_matrix": interaction_matrix,
        "feature_names": feature_names,
        "top_interactions": interactions_list,
        "n_features": n_features,
        "n_samples_used": n_samples_used,
        "computation_time": computation_time,
    }

analyze_interactions

analyze_interactions(
    model,
    X,
    y,
    feature_pairs=None,
    methods=None,
    n_quantiles=5,
    grid_resolution=20,
    max_samples=200,
)

Comprehensive feature interaction analysis comparing multiple methods.

Run multiple interaction detection methods and generate a comparison report with consensus ranking and interpretation.

Use this when you need to answer: "Which feature pairs interact in my model, and do different interaction methods agree?"

The integrated analysis includes:
  • Individual method results (Conditional IC, H-statistic, SHAP interactions)
  • Consensus ranking (interactions important across methods)
  • Method agreement/disagreement analysis
  • Auto-generated insights and warnings
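
The consensus idea can be sketched in plain Python. The example below uses hypothetical per-method outputs and a hypothetical `canonical` helper; it mirrors the approach in the source listing (order-insensitive pair keys, worst rank for missing pairs, average rank, top-N intersection) but is not the library code itself.

```python
# Hypothetical per-method output: pairs sorted from strongest to weakest.
method_interactions = {
    "conditional_ic": [("a", "c", 0.12), ("a", "b", 0.08), ("b", "c", 0.01)],
    "h_statistic": [("b", "a", 0.90), ("a", "c", 0.40), ("b", "c", 0.10)],
    "shap": [("a", "b", 0.30), ("b", "c", 0.20), ("a", "c", 0.05)],
}
requested_pairs = [("a", "b"), ("a", "c"), ("b", "c")]


def canonical(pair):
    """Order a pair alphabetically so (b, a) and (a, b) share one key."""
    first, second = pair
    return (min(first, second), max(first, second))


avg_rank = {}
for pair in requested_pairs:
    ranks = []
    for interactions in method_interactions.values():
        position = {canonical(item[:2]): idx for idx, item in enumerate(interactions)}
        # A pair absent from a method's list gets the worst possible rank.
        ranks.append(position.get(canonical(pair), len(interactions)))
    avg_rank[pair] = sum(ranks) / len(ranks)

# Lower average rank means stronger consensus.
consensus = sorted(requested_pairs, key=avg_rank.get)

# Pairs that appear in the top N of every method.
top_n = 2
consensus_top = set.intersection(
    *({canonical(item[:2]) for item in interactions[:top_n]}
      for interactions in method_interactions.values())
)

print(consensus)
print(consensus_top)
```

Ranks rather than raw scores are averaged because the three methods report values on different scales.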

Parameters

model : Any
    Fitted model. Requirements vary by method:
    - Conditional IC: Not used (computed from X and y only)
    - H-statistic: Must have predict() method
    - SHAP: Must be compatible with TreeExplainer
X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
    Feature matrix (n_samples, n_features)
y : Union[pl.Series, pd.Series, np.ndarray]
    Target values (n_samples,)
feature_pairs : list[tuple[str, str]] | None, default None
    Specific feature pairs to analyze. If None, tests all pairs.
methods : list[str] | None, default ["conditional_ic", "h_statistic", "shap"]
    Which methods to run.
n_quantiles : int, default 5
    Number of quantile bins for Conditional IC
grid_resolution : int, default 20
    Grid size for partial dependence in H-statistic
max_samples : int, default 200
    Maximum samples for SHAP and H-statistic

Returns

dict[str, Any]
    Comprehensive analysis results:
    - method_results: Dict of individual method outputs
    - consensus_ranking: List of (feature_a, feature_b, avg_rank, scores_by_method) tuples, sorted by average rank across methods (lowest first)
    - method_agreement: Spearman correlations between method rankings, keyed by (method_a, method_b) tuples
    - top_interactions_consensus: Pairs in top 10 for ALL methods
    - warnings: Detected issues
    - interpretation: Auto-generated summary
    - methods_run: Methods successfully executed
    - methods_failed: Failed methods with error messages

Raises

ValueError
    If no methods are specified, if a feature pair does not have exactly two elements or names an unknown feature, or if all methods fail
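
The agreement check and its warning threshold can be sketched with hypothetical rank vectors. The source listing calls `spearmanr` (its import is not shown in this section). To stay dependency-light, this sketch instead uses the identity that Spearman correlation equals Pearson correlation computed on the ranks, which holds here because the example ranks contain no ties. `spearman_from_ranks` is a hypothetical helper name.

```python
import numpy as np


def spearman_from_ranks(ranks_a, ranks_b):
    """Spearman correlation for tie-free rank vectors (Pearson on the ranks)."""
    return float(np.corrcoef(ranks_a, ranks_b)[0, 1])


# Hypothetical ranks of the same four feature pairs under three methods.
rankings = {
    "conditional_ic": np.array([0, 1, 2, 3]),
    "h_statistic": np.array([0, 1, 3, 2]),
    "shap": np.array([3, 2, 1, 0]),
}

# One correlation per unordered method pair, keyed by a (method_a, method_b) tuple.
names = list(rankings)
method_agreement = {
    (m1, m2): spearman_from_ranks(rankings[m1], rankings[m2])
    for i, m1 in enumerate(names)
    for m2 in names[i + 1 :]
}

# Flag the run when the weakest pairwise agreement falls below 0.5.
warnings = []
min_agreement = min(method_agreement.values())
if min_agreement < 0.5:
    warnings.append(
        f"Low agreement between methods (min correlation: {min_agreement:.2f})"
    )

for key, corr in method_agreement.items():
    print(key, round(corr, 2))
print(warnings)
```

A single strongly disagreeing method is enough to trigger the warning, since the check uses the minimum pairwise correlation rather than the average.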

Source code in src/ml4t/diagnostic/metrics/interactions.py
def analyze_interactions(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_pairs: list[tuple[str, str]] | None = None,
    methods: list[str] | None = None,
    n_quantiles: int = 5,
    grid_resolution: int = 20,
    max_samples: int = 200,
) -> dict[str, Any]:
    """Comprehensive feature interaction analysis comparing multiple methods.

    Run multiple interaction detection methods and generate a comparison report
    with consensus ranking and interpretation.

    Use this when you need to answer: "Which feature pairs interact in my model,
    and do different interaction methods agree?"

    The integrated analysis includes:
    - Individual method results (Conditional IC, H-statistic, SHAP interactions)
    - Consensus ranking (interactions important across methods)
    - Method agreement/disagreement analysis
    - Auto-generated insights and warnings

    Parameters
    ----------
    model : Any
        Fitted model. Requirements vary by method:
        - Conditional IC: Not used (computed from X and y only)
        - H-statistic: Must have `predict()` method
        - SHAP: Must be compatible with TreeExplainer
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,)
    feature_pairs : list[tuple[str, str]] | None, default None
        Specific feature pairs to analyze. If None, tests all pairs.
    methods : list[str] | None, default ["conditional_ic", "h_statistic", "shap"]
        Which methods to run.
    n_quantiles : int, default 5
        Number of quantile bins for Conditional IC
    grid_resolution : int, default 20
        Grid size for partial dependence in H-statistic
    max_samples : int, default 200
        Maximum samples for SHAP and H-statistic

    Returns
    -------
    dict[str, Any]
        Comprehensive analysis results:
        - method_results: Dict of individual method outputs
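        - consensus_ranking: List of (feature_a, feature_b, avg_rank, scores_by_method)
          tuples, sorted by average rank across methods (lowest first)
        - method_agreement: Spearman correlations between method rankings,
          keyed by (method_a, method_b) tuples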
        - top_interactions_consensus: Pairs in top 10 for ALL methods
        - warnings: Detected issues
        - interpretation: Auto-generated summary
        - methods_run: Methods successfully executed
        - methods_failed: Failed methods with error messages

    Raises
    ------
    ValueError
        If no methods are specified, if a feature pair does not have exactly
        two elements or names an unknown feature, or if all methods fail
    """
    if methods is None:
        methods = ["conditional_ic", "h_statistic", "shap"]

    if not methods:
        raise ValueError("At least one method must be specified")

    # Extract feature names if not provided
    if isinstance(X, pl.DataFrame | pd.DataFrame):
        feature_names = list(X.columns)
    else:
        # Generate numeric feature names
        n_features = X.shape[1] if hasattr(X, "shape") else len(X[0])
        feature_names = [f"f{i}" for i in range(n_features)]

    # Determine feature pairs to analyze
    if feature_pairs is None:
        # Test all pairs
        n_features = len(feature_names)
        all_pairs = []
        for i in range(n_features):
            for j in range(i + 1, n_features):
                all_pairs.append((feature_names[i], feature_names[j]))
        feature_pairs = all_pairs
    else:
        # Validate provided pairs
        feature_set = set(feature_names)
        for pair in feature_pairs:
            if len(pair) != 2:
                raise ValueError(f"Feature pair must have exactly 2 elements: {pair}")
            if pair[0] not in feature_set or pair[1] not in feature_set:
                raise ValueError(
                    f"Feature pair contains unknown features: {pair}. Available features: {feature_names}"
                )

    # Run each method with try/except for optional dependencies and errors
    results = {}
    method_failures = []

    if "conditional_ic" in methods:
        try:
            # For Conditional IC, we need to run it for each pair
            ic_results: list[tuple[str, str, float | None]] = []
            for feat_a, feat_b in feature_pairs:
                # Extract columns
                x_a: pl.Series | pd.Series | NDArray[Any]
                x_b: pl.Series | pd.Series | NDArray[Any]
                if isinstance(X, pl.DataFrame) or isinstance(X, pd.DataFrame):
                    x_a = X[feat_a]
                    x_b = X[feat_b]
                else:
                    # numpy array - need to find indices
                    idx_a = feature_names.index(feat_a)
                    idx_b = feature_names.index(feat_b)
                    X_arr = cast("NDArray[Any]", X)
                    x_a = X_arr[:, idx_a]
                    x_b = X_arr[:, idx_b]

                result = compute_conditional_ic(
                    feature_a=x_a,
                    feature_b=x_b,
                    forward_returns=y,
                    n_quantiles=n_quantiles,
                )

                # Extract interaction strength metric
                ic_range = result.get("ic_range", 0.0)
                ic_results.append((feat_a, feat_b, ic_range))

            # Sort by IC range descending
            ic_results.sort(key=lambda x: abs(x[2]) if x[2] is not None else 0.0, reverse=True)

            results["conditional_ic"] = {
                "top_interactions": ic_results,
                "n_pairs_tested": len(ic_results),
            }
        except Exception as e:
            method_failures.append(("conditional_ic", str(e)))

    if "h_statistic" in methods:
        try:
            # Convert feature pairs to indices for h_statistic
            pair_indices = []
            for feat_a, feat_b in feature_pairs:
                idx_a = feature_names.index(feat_a)
                idx_b = feature_names.index(feat_b)
                pair_indices.append((idx_a, idx_b))

            results["h_statistic"] = compute_h_statistic(
                model,
                X,
                feature_pairs=pair_indices,
                feature_names=feature_names,
                n_samples=max_samples,
                grid_resolution=grid_resolution,
            )
        except Exception as e:
            method_failures.append(("h_statistic", str(e)))

    if "shap" in methods:
        try:
            shap_result = compute_shap_interactions(
                model,
                X,
                feature_names=feature_names,
                max_samples=max_samples,
            )

            # Filter to requested pairs if feature_pairs was specified
            if feature_pairs is not None:
                pair_set = set(feature_pairs) | {(b, a) for a, b in feature_pairs}
                filtered_interactions = [
                    (a, b, score)
                    for a, b, score in shap_result["top_interactions"]
                    if (a, b) in pair_set or (b, a) in pair_set
                ]
                shap_result["top_interactions"] = filtered_interactions

            results["shap"] = shap_result
        except ImportError:
            method_failures.append(
                (
                    "shap",
                    "shap library not installed. Install with: pip install ml4t-diagnostic[ml]",
                )
            )
        except Exception as e:
            method_failures.append(("shap", str(e)))

    # Check if at least one method succeeded
    if not results:
        error_msg = "All methods failed:\n" + "\n".join(
            f"  - {method}: {error}" for method, error in method_failures
        )
        raise ValueError(error_msg)

    # 2. Compute consensus ranking
    rankings: dict[str, NDArray[Any]] = {}
    for method_name, result in results.items():
        # Get interaction scores for this method
        method_interactions: list[tuple[str, str, float]]
        if "top_interactions" in result:
            method_interactions = cast(list[tuple[str, str, float]], result["top_interactions"])
        elif "h_statistics" in result:
            method_interactions = cast(list[tuple[str, str, float]], result["h_statistics"])
        else:
            continue

        # Create a mapping from pair to rank
        pair_to_rank: dict[tuple[str, str], int] = {}
        for rank_idx, interaction_tuple in enumerate(method_interactions):
            feat_a_int, feat_b_int = str(interaction_tuple[0]), str(interaction_tuple[1])
            pair_key = (min(feat_a_int, feat_b_int), max(feat_a_int, feat_b_int))
            pair_to_rank[pair_key] = rank_idx

        # Map all requested pairs to ranks (handle missing pairs)
        ranks_array: list[int] = []
        for feat_a, feat_b in feature_pairs:
            pair_key = (min(feat_a, feat_b), max(feat_a, feat_b))
            rank_val = pair_to_rank.get(pair_key, len(method_interactions))
            ranks_array.append(rank_val)

        rankings[method_name] = np.array(ranks_array)

    # Average ranks across methods
    avg_ranks = np.mean(list(rankings.values()), axis=0)

    # Create consensus ranking with scores from each method
    consensus_ranking: list[tuple[str, str, float, dict[str, float]]] = []
    for idx, avg_rank in enumerate(avg_ranks):
        feat_a, feat_b = feature_pairs[idx]
        pair_tuple: tuple[str, str] = (min(feat_a, feat_b), max(feat_a, feat_b))

        # Collect scores from each method
        scores_dict: dict[str, float] = {}
        for method_name, result in results.items():
            method_ints: list[tuple[str, str, float]]
            if "top_interactions" in result:
                method_ints = cast(list[tuple[str, str, float]], result["top_interactions"])
            elif "h_statistics" in result:
                method_ints = cast(list[tuple[str, str, float]], result["h_statistics"])
            else:
                continue

            for int_tuple in method_ints:
                check_pair = (
                    min(str(int_tuple[0]), str(int_tuple[1])),
                    max(str(int_tuple[0]), str(int_tuple[1])),
                )
                if check_pair == pair_tuple:
                    scores_dict[method_name] = float(int_tuple[2])
                    break

        consensus_ranking.append((feat_a, feat_b, float(avg_rank), scores_dict))

    # Sort by average rank
    consensus_ranking.sort(key=lambda x: x[2])

    # 3. Compute method agreement (Spearman correlation between rankings)
    method_agreement = {}
    method_names = list(rankings.keys())
    for i, m1 in enumerate(method_names):
        for m2 in method_names[i + 1 :]:
            corr, _ = spearmanr(rankings[m1], rankings[m2])
            method_agreement[(m1, m2)] = float(corr)

    # 4. Identify consensus top interactions (top 10 in all methods)
    top_n = 10
    top_interactions_by_method: dict[str, set[tuple[str, str]]] = {}
    for method_name, result in results.items():
        method_ints_list: list[tuple[str, str, float]]
        if "top_interactions" in result:
            method_ints_list = cast(list[tuple[str, str, float]], result["top_interactions"])
        elif "h_statistics" in result:
            method_ints_list = cast(list[tuple[str, str, float]], result["h_statistics"])
        else:
            continue

        method_top_pairs: list[tuple[str, str]] = []
        for int_entry in method_ints_list[:top_n]:
            pair_sorted: tuple[str, str] = (
                min(str(int_entry[0]), str(int_entry[1])),
                max(str(int_entry[0]), str(int_entry[1])),
            )
            method_top_pairs.append(pair_sorted)
        top_interactions_by_method[method_name] = set(method_top_pairs)

    if top_interactions_by_method:
        consensus_top_pairs = set.intersection(*top_interactions_by_method.values())
    else:
        consensus_top_pairs = set()

    consensus_top_list = list(consensus_top_pairs)

    # 5. Generate warnings
    warnings = []

    # Warning: Disagreement between specific methods
    if "conditional_ic" in results and "h_statistic" in results:
        ic_interactions: list[tuple[str, str, float]]
        if "top_interactions" in results["conditional_ic"]:
            ic_interactions = cast(
                list[tuple[str, str, float]], results["conditional_ic"]["top_interactions"]
            )
        else:
            ic_interactions = []

        h_interactions: list[tuple[str, str, float]] = cast(
            list[tuple[str, str, float]], results["h_statistic"].get("h_statistics", [])
        )

        ic_top: set[tuple[str, str]] = {
            (min(str(x[0]), str(x[1])), max(str(x[0]), str(x[1]))) for x in ic_interactions[:5]
        }
        h_top: set[tuple[str, str]] = {
            (min(str(x[0]), str(x[1])), max(str(x[0]), str(x[1]))) for x in h_interactions[:5]
        }

        disagreement = ic_top - h_top
        if disagreement:
            pairs_str = ", ".join([f"({a}, {b})" for a, b in disagreement])
            warnings.append(
                f"Pairs {pairs_str} rank high in Conditional IC but not H-statistic - "
                "possible regime-specific interaction (time-varying)"
            )

    # Warning: Low agreement between methods
    if method_agreement:
        min_agreement = min(method_agreement.values())
        if min_agreement < 0.5:
            warnings.append(
                f"Low agreement between methods (min correlation: {min_agreement:.2f}) - "
                "results may be unreliable or methods capture different interaction types"
            )

    # Add method failures to warnings
    if method_failures:
        for method, error in method_failures:
            warnings.append(f"Method '{method}' failed: {error}")

    # 6. Generate interpretation
    top_pairs = [(a, b) for a, b, _, _ in consensus_ranking[:10]]
    interpretation = _generate_interaction_interpretation(
        top_pairs,
        method_agreement,
        warnings,
        len(consensus_top_list),
    )

    return {
        "method_results": results,
        "consensus_ranking": consensus_ranking,
        "method_agreement": method_agreement,
        "top_interactions_consensus": consensus_top_list,
        "warnings": warnings,
        "interpretation": interpretation,
        "methods_run": list(results.keys()),
        "methods_failed": method_failures,
    }

sharpe_ratio

sharpe_ratio(
    returns,
    risk_free_rate=0.0,
    periods_per_year=252,
    confidence_intervals=False,
    alpha=0.05,
    bootstrap_samples=1000,
    random_state=None,
)

Calculate annualized Sharpe ratio.

The Sharpe Ratio measures risk-adjusted returns by dividing excess returns by return volatility. Higher values indicate better risk-adjusted performance.

Parameters

returns : Union[pl.Series, pd.Series, np.ndarray]
    Time series of periodic returns
risk_free_rate : float, default 0.0
    Annual risk-free rate
periods_per_year : int, default 252
    Number of return periods per year
confidence_intervals : bool, default False
    Whether to compute bootstrap confidence intervals
alpha : float, default 0.05
    Significance level for confidence intervals
bootstrap_samples : int, default 1000
    Number of bootstrap samples for confidence intervals
random_state : Optional[int], default None
    Random seed for reproducible bootstrap samples

Returns

Union[float, dict]
    If confidence_intervals=False: Sharpe ratio value
    If confidence_intervals=True: dict with 'sharpe', 'lower_ci', 'upper_ci'

Examples

>>> returns = np.array([0.01, 0.02, -0.01, 0.03, 0.00])
>>> sharpe = sharpe_ratio(returns, periods_per_year=252)
>>> print(f"Sharpe Ratio: {sharpe:.3f}")

>>> # With confidence intervals
>>> result = sharpe_ratio(returns, confidence_intervals=True, random_state=42)
>>> print(f"Sharpe: {result['sharpe']:.3f}")
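
To make the annualization step concrete, here is a minimal NumPy sketch of the textbook calculation. It is not the library's implementation: the helper name `annualized_sharpe`, the division of the annual risk-free rate by `periods_per_year`, and the use of the sample standard deviation (`ddof=1`) are all assumptions made for illustration, so the result may differ from what `sharpe_ratio` returns.

```python
import numpy as np


def annualized_sharpe(returns, risk_free_rate=0.0, periods_per_year=252):
    """Illustrative annualized Sharpe ratio (hypothetical helper, not library code)."""
    r = np.asarray(returns, dtype=float)
    r = r[~np.isnan(r)]  # ignore missing observations
    if r.size < 2:
        return float("nan")
    # Assumption: the annual rate is converted to a per-period rate by simple division.
    excess = r - risk_free_rate / periods_per_year
    # Assumption: sample standard deviation (ddof=1).
    vol = excess.std(ddof=1)
    if vol == 0:
        return float("nan")
    # Scale the per-period ratio by sqrt(periods_per_year) to annualize it.
    return float(excess.mean() / vol * np.sqrt(periods_per_year))


returns = np.array([0.01, 0.02, -0.01, 0.03, 0.00])
sharpe = annualized_sharpe(returns)
print(f"Sharpe Ratio (sketch): {sharpe:.3f}")
```

The square-root scaling assumes returns are uncorrelated across periods. With strongly autocorrelated returns the annualized figure overstates or understates risk-adjusted performance.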

Source code in src/ml4t/diagnostic/metrics/risk_adjusted.py
def sharpe_ratio(
    returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
    risk_free_rate: float = 0.0,
    periods_per_year: int = 252,
    confidence_intervals: bool = False,
    alpha: float = 0.05,
    bootstrap_samples: int = 1000,
    random_state: int | None = None,
) -> float | dict[str, float]:
    """Calculate annualized Sharpe ratio.

    The Sharpe Ratio measures risk-adjusted returns by dividing excess returns
    by return volatility. Higher values indicate better risk-adjusted performance.

    Parameters
    ----------
    returns : Union[pl.Series, pd.Series, np.ndarray]
        Time series of periodic returns
    risk_free_rate : float, default 0.0
        Annual risk-free rate
    periods_per_year : int, default 252
        Number of return periods per year
    confidence_intervals : bool, default False
        Whether to compute bootstrap confidence intervals
    alpha : float, default 0.05
        Significance level for confidence intervals
    bootstrap_samples : int, default 1000
        Number of bootstrap samples for confidence intervals
    random_state : Optional[int], default None
        Random seed for reproducible bootstrap samples

    Returns
    -------
    Union[float, dict]
        If confidence_intervals=False: Sharpe ratio value
        If confidence_intervals=True: dict with 'sharpe', 'lower_ci', 'upper_ci'

    Examples
    --------
    >>> returns = np.array([0.01, 0.02, -0.01, 0.03, 0.00])
    >>> sharpe = sharpe_ratio(returns, periods_per_year=252)
    >>> print(f"Sharpe Ratio: {sharpe:.3f}")

    >>> # With confidence intervals
    >>> result = sharpe_ratio(returns, confidence_intervals=True, random_state=42)
    >>> print(f"Sharpe: {result['sharpe']:.3f}")
    """
    if confidence_intervals:
        return sharpe_ratio_with_ci(
            returns,
            risk_free_rate=risk_free_rate,
            periods_per_year=periods_per_year,
            alpha=alpha,
            bootstrap_samples=bootstrap_samples,
            random_state=random_state,
        )
    return periodic_sharpe_ratio(
        returns,
        periodic_risk_free_rate=_periodic_risk_free_rate(risk_free_rate, periods_per_year),
        annualization_factor=periods_per_year,
    )

sortino_ratio

sortino_ratio(
    returns,
    risk_free_rate=0.0,
    periods_per_year=252,
    target_return=0.0,
)

Calculate annualized Sortino ratio.

The Sortino Ratio is similar to Sharpe ratio but only penalizes downside volatility, making it more appropriate for asymmetric return distributions.

Parameters

returns : Union[pl.Series, pd.Series, np.ndarray]
    Time series of periodic returns
risk_free_rate : float, default 0.0
    Annual risk-free rate
periods_per_year : int, default 252
    Number of return periods per year
target_return : float, default 0.0
    Periodic target threshold after subtracting the risk-free rate

Returns

float
    Sortino ratio value

Examples

>>> returns = np.array([0.01, 0.02, -0.01, 0.03, -0.02])
>>> sortino = sortino_ratio(returns, periods_per_year=252)
>>> print(f"Sortino Ratio: {sortino:.3f}")
Sortino Ratio: 0.894
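
The sketch below shows one common way to compute downside deviation, the quantity that separates Sortino from Sharpe. It is an assumption-laden illustration, not the library's code: `annualized_sortino` is a hypothetical helper, downside deviation is taken as the root mean square of shortfalls below the target over all periods, and the annual risk-free rate is divided by `periods_per_year`. Other conventions exist, so its numbers need not match the output of `sortino_ratio`.

```python
import numpy as np


def annualized_sortino(returns, risk_free_rate=0.0, periods_per_year=252,
                       target_return=0.0):
    """Illustrative Sortino ratio (hypothetical helper, not library code)."""
    r = np.asarray(returns, dtype=float)
    r = r[~np.isnan(r)]
    if r.size == 0:
        return float("nan")
    # Assumption: per-period risk-free rate by simple division.
    excess = r - risk_free_rate / periods_per_year
    # Only returns below the target contribute; gains are clipped to zero shortfall.
    shortfall = np.minimum(excess - target_return, 0.0)
    # Assumption: RMS of shortfalls over ALL periods, not just the losing ones.
    downside_dev = np.sqrt(np.mean(shortfall ** 2))
    if downside_dev == 0:
        return float("nan")
    return float((excess.mean() - target_return) / downside_dev
                 * np.sqrt(periods_per_year))


returns = np.array([0.01, 0.02, -0.01, 0.03, -0.02])
sortino_sketch = annualized_sortino(returns)
```

Because upside volatility is ignored, two series with the same Sharpe ratio can have very different Sortino ratios when one of them is positively skewed.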

Source code in src/ml4t/diagnostic/metrics/risk_adjusted.py
def sortino_ratio(
    returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
    risk_free_rate: float = 0.0,
    periods_per_year: int = 252,
    target_return: float = 0.0,
) -> float:
    """Calculate annualized Sortino ratio.

    The Sortino Ratio is similar to Sharpe ratio but only penalizes downside
    volatility, making it more appropriate for asymmetric return distributions.

    Parameters
    ----------
    returns : Union[pl.Series, pd.Series, np.ndarray]
        Time series of periodic returns
    risk_free_rate : float, default 0.0
        Annual risk-free rate
    periods_per_year : int, default 252
        Number of return periods per year
    target_return : float, default 0.0
        Periodic target threshold after subtracting the risk-free rate

    Returns
    -------
    float
        Sortino ratio value

    Examples
    --------
    >>> returns = np.array([0.01, 0.02, -0.01, 0.03, -0.02])
    >>> sortino = sortino_ratio(returns, periods_per_year=252)
    >>> print(f"Sortino Ratio: {sortino:.3f}")
    Sortino Ratio: 0.894
    """
    return periodic_sortino_ratio(
        returns,
        periodic_risk_free_rate=_periodic_risk_free_rate(risk_free_rate, periods_per_year),
        annualization_factor=periods_per_year,
        target_return=target_return,
    )

maximum_drawdown

maximum_drawdown(returns, cumulative=False)

Calculate Maximum Drawdown and related statistics.

Maximum Drawdown measures the largest peak-to-trough decline in cumulative returns. It is the worst loss an investor would have suffered over the sample by entering at a peak and exiting at the subsequent trough.

Parameters

returns : Union[pl.Series, pd.Series, np.ndarray]
    Time series of returns (or cumulative returns if cumulative=True)
cumulative : bool, default False
    Whether input is already cumulative returns

Returns

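dict
    Dictionary with 'max_drawdown', 'max_drawdown_duration', 'peak_date', 'trough_date'. Despite their names, 'peak_date' and 'trough_date' are integer positions in the input after NaN values have been removed, not timestamps. All four values are NaN when the input contains no non-NaN observations, and zero when no drawdown is found.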

Examples

>>> returns = np.array([0.10, -0.05, 0.08, -0.12, 0.03])
>>> dd = maximum_drawdown(returns)
>>> print(f"Max Drawdown: {dd['max_drawdown']:.3f}")
Max Drawdown: -0.102
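
For intuition, the peak-to-trough logic can be sketched with a running maximum over the compounded wealth curve. This is a hypothetical helper (`drawdown_from_returns`), not the Numba routine the library calls, and it measures drawdown as a fraction of the peak wealth level. The library's internal convention is not shown on this page, so its figures may differ from this sketch.

```python
import numpy as np


def drawdown_from_returns(returns):
    """Illustrative max drawdown on a compounded wealth curve (not library code)."""
    r = np.asarray(returns, dtype=float)
    r = r[~np.isnan(r)]
    if r.size == 0:
        return {"max_drawdown": float("nan"), "peak_idx": -1, "trough_idx": -1}
    wealth = np.cumprod(1.0 + r)                   # growth of one unit of capital
    running_peak = np.maximum.accumulate(wealth)   # highest wealth seen so far
    drawdown = wealth / running_peak - 1.0         # <= 0 everywhere
    trough = int(np.argmin(drawdown))              # deepest point
    peak = int(np.argmax(wealth[: trough + 1]))    # peak preceding that trough
    return {
        "max_drawdown": float(drawdown[trough]),
        "peak_idx": peak,
        "trough_idx": trough,
    }


dd_sketch = drawdown_from_returns([0.10, -0.05, 0.08, -0.12, 0.03])
```

As with the library function, the peak and trough are reported as integer positions, so they must be mapped back to a date index by the caller.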

Source code in src/ml4t/diagnostic/metrics/risk_adjusted.py
def maximum_drawdown(
    returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
    cumulative: bool = False,
) -> dict[str, float]:
    """Calculate Maximum Drawdown and related statistics.

    Maximum Drawdown measures the largest peak-to-trough decline in cumulative
    returns. It represents the worst-case loss an investor would experience.

    Parameters
    ----------
    returns : Union[pl.Series, pd.Series, np.ndarray]
        Time series of returns (or cumulative returns if cumulative=True)
    cumulative : bool, default False
        Whether input is already cumulative returns

    Returns
    -------
    dict
        Dictionary with 'max_drawdown', 'max_drawdown_duration', 'peak_date', 'trough_date'

    Examples
    --------
    >>> returns = np.array([0.10, -0.05, 0.08, -0.12, 0.03])
    >>> dd = maximum_drawdown(returns)
    >>> print(f"Max Drawdown: {dd['max_drawdown']:.3f}")
    Max Drawdown: -0.102
    """
    # Import here to avoid circular dependency
    from ml4t.diagnostic.core.numba_utils import calculate_drawdown_numba

    # Convert to numpy array
    ret_array = DataFrameAdapter.to_numpy(returns).flatten()

    # Remove NaN values
    ret_clean = ret_array[~np.isnan(ret_array)]

    if len(ret_clean) == 0:
        return {
            "max_drawdown": np.nan,
            "max_drawdown_duration": np.nan,
            "peak_date": np.nan,
            "trough_date": np.nan,
        }

    # Calculate cumulative returns if needed
    if cumulative:
        cum_returns = ret_clean
    else:
        cum_returns = np.cumprod(1 + ret_clean) - 1  # Compound returns

    # Use Numba-optimized function
    max_drawdown_val, dd_duration, peak_idx, trough_idx = calculate_drawdown_numba(cum_returns)

    # Handle case where no drawdown was found
    if peak_idx == -1:
        return {
            "max_drawdown": 0.0,
            "max_drawdown_duration": 0,
            "peak_date": 0,
            "trough_date": 0,
        }

    return {
        "max_drawdown": float(max_drawdown_val),
        "max_drawdown_duration": int(dd_duration),
        "peak_date": int(peak_idx),
        "trough_date": int(trough_idx),
    }

hit_rate

hit_rate(predictions, returns)

Calculate hit rate (percentage of correct directional predictions).

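Hit rate measures the percentage of predictions whose sign matches the sign of the subsequent return. Pairs containing NaN are dropped before scoring, and observations where either the prediction or the return is exactly zero are counted as correct.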

Parameters

predictions : Union[pl.Series, pd.Series, np.ndarray]
    Model predictions or scores
returns : Union[pl.Series, pd.Series, np.ndarray]
    Forward returns corresponding to predictions

Returns

float
    Hit rate as a percentage (0-100)

Examples

>>> predictions = np.array([0.1, -0.2, 0.3, -0.1])
>>> returns = np.array([0.02, -0.01, 0.05, 0.01])  # Note: last one wrong direction
>>> hr = hit_rate(predictions, returns)
>>> print(f"Hit Rate: {hr:.1f}%")
Hit Rate: 75.0%
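
The effect of the zero-handling rule is easiest to see in a standalone re-statement of the logic. The sketch below mirrors the steps in the source listing using plain NumPy. `directional_hit_rate` is a hypothetical name and the function skips the library's `DataFrameAdapter` conversion, so treat it as an illustration only.

```python
import numpy as np


def directional_hit_rate(predictions, returns):
    """Illustrative re-statement of the sign-matching rule (not library code)."""
    p = np.asarray(predictions, dtype=float).ravel()
    r = np.asarray(returns, dtype=float).ravel()
    if p.shape != r.shape:
        raise ValueError("Predictions and returns must have the same length")
    valid = ~(np.isnan(p) | np.isnan(r))   # drop pairs with a missing value
    p, r = p[valid], r[valid]
    if p.size == 0:
        return float("nan")
    correct = np.sign(p) == np.sign(r)
    correct |= (p == 0) | (r == 0)         # zeros on either side count as hits
    return float(correct.mean() * 100)


predictions = np.array([0.1, -0.2, 0.3, -0.1])
returns = np.array([0.02, -0.01, 0.05, 0.01])
hr_sketch = directional_hit_rate(predictions, returns)

# A zero prediction is scored as correct even though it takes no view.
hr_with_zero = directional_hit_rate([0.0, 0.2], [0.01, -0.03])
```

If a model emits many exact zeros (for example, a thresholded signal), consider filtering them out before scoring so that abstentions do not inflate the hit rate.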

Source code in src/ml4t/diagnostic/metrics/basic.py
def hit_rate(
    predictions: Union[pl.Series, pd.Series, "NDArray"],
    returns: Union[pl.Series, pd.Series, "NDArray"],
) -> float:
    """Calculate hit rate (percentage of correct directional predictions).

    Hit rate measures what percentage of predictions correctly identify the
    direction of subsequent returns (positive/negative).

    Parameters
    ----------
    predictions : Union[pl.Series, pd.Series, np.ndarray]
        Model predictions or scores
    returns : Union[pl.Series, pd.Series, np.ndarray]
        Forward returns corresponding to predictions

    Returns
    -------
    float
        Hit rate as a percentage (0-100)

    Examples
    --------
    >>> predictions = np.array([0.1, -0.2, 0.3, -0.1])
    >>> returns = np.array([0.02, -0.01, 0.05, 0.01])  # Note: last one wrong direction
    >>> hr = hit_rate(predictions, returns)
    >>> print(f"Hit Rate: {hr:.1f}%")
    Hit Rate: 75.0%
    """
    # Convert inputs to numpy
    pred_array = DataFrameAdapter.to_numpy(predictions).flatten()
    ret_array = DataFrameAdapter.to_numpy(returns).flatten()

    # Validate inputs
    if len(pred_array) != len(ret_array):
        raise ValueError("Predictions and returns must have the same length")

    # Remove NaN pairs
    valid_mask = ~(np.isnan(pred_array) | np.isnan(ret_array))
    pred_clean = pred_array[valid_mask]
    ret_clean = ret_array[valid_mask]

    if len(pred_clean) == 0:
        return np.nan

    # Calculate directional accuracy
    pred_direction = np.sign(pred_clean)
    ret_direction = np.sign(ret_clean)

    # Count correct predictions (same sign)
    correct_predictions = pred_direction == ret_direction

    # Treat observations with a zero prediction or a zero return as correct.
    # Note: this raises the reported hit rate whenever zeros are present.
    zero_mask = (pred_clean == 0) | (ret_clean == 0)
    correct_predictions[zero_mask] = True

    hit_rate_value = np.mean(correct_predictions) * 100

    return float(hit_rate_value)

compute_forward_returns

compute_forward_returns(
    prices,
    periods=1,
    price_col="close",
    group_col=None,
    date_col="date",
    output_col_template="fwd_ret_{period}",
    nan_to_null=False,
)

Compute forward returns for given periods.

This is a helper function for IC analysis, computing the forward-looking returns that will be correlated with predictions/features.

Parameters

prices : Union[pl.DataFrame, pd.DataFrame]
    Price data with at least price_col and optionally group_col
periods : Union[int, list[int]], default 1
    Forward periods to compute (e.g., [1, 5, 21] for 1d, 1w, 1m)
price_col : str, default "close"
    Column name containing prices
group_col : str | None, default None
    Column for grouping (e.g., 'symbol' for multi-asset)
date_col : str | None, default "date"
    Optional date/timestamp column used to sort before computing forward returns. Set to None to skip sorting.
output_col_template : str, default "fwd_ret_{period}"
    Template for generated forward return column names. Use {period} placeholder for the horizon integer.
nan_to_null : bool, default False
    Convert NaN forward return values to null/None in Polars output.

Returns

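Union[pl.DataFrame, pd.DataFrame]
    The input data with forward return columns appended (fwd_ret_1, fwd_ret_5, etc. under the default template). The output type matches the input type: pandas in, pandas out.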

Examples

>>> prices = pl.DataFrame({
...     "date": ["2024-01-01", "2024-01-02", "2024-01-03"],
...     "close": [100.0, 102.0, 101.0]
... })
>>> fwd_returns = compute_forward_returns(prices, periods=[1, 2])
>>> print(fwd_returns.columns)
['date', 'close', 'fwd_ret_1', 'fwd_ret_2']
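
The underlying arithmetic is `price[t + period] / price[t] - 1`, computed separately within each group. The NumPy sketch below illustrates that calculation and why grouping matters. `forward_returns` is a hypothetical helper rather than the library function, it assumes rows are already in time order, and it marks unavailable values with NaN (the library works on Polars frames, where shifted-out rows are null).

```python
import numpy as np


def forward_returns(prices, period=1, groups=None):
    """Illustrative forward returns per group (hypothetical helper, not library code)."""
    if period < 1:
        raise ValueError("period must be a positive integer")
    p = np.asarray(prices, dtype=float)
    out = np.full(p.shape, np.nan)
    # Without groups, treat the whole series as a single asset.
    labels = np.zeros(p.size, dtype=int) if groups is None else np.asarray(groups)
    for g in np.unique(labels):
        idx = np.flatnonzero(labels == g)      # this asset's rows, in row order
        if period < idx.size:
            # Pair each row with the same asset's row `period` steps later.
            out[idx[:-period]] = p[idx[period:]] / p[idx[:-period]] - 1.0
    return out


fwd_1 = forward_returns([100.0, 102.0, 101.0], period=1)
fwd_2 = forward_returns([100.0, 102.0, 101.0], period=2)

# Two interleaved assets: without grouping, A's return would be computed from B's price.
fwd_grouped = forward_returns([100.0, 50.0, 110.0, 55.0], period=1,
                              groups=["A", "B", "A", "B"])
```

The last `period` rows of each asset have no forward return. Drop them, or mask them out, before correlating with features.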

Source code in src/ml4t/diagnostic/metrics/basic.py
def compute_forward_returns(
    prices: pl.DataFrame | pd.DataFrame,
    periods: int | list[int] = 1,
    price_col: str = "close",
    group_col: str | None = None,
    date_col: str | None = "date",
    output_col_template: str = "fwd_ret_{period}",
    nan_to_null: bool = False,
) -> pl.DataFrame | pd.DataFrame:
    """Compute forward returns for given periods.

    This is a helper function for IC analysis, computing the forward-looking
    returns that will be correlated with predictions/features.

    Parameters
    ----------
    prices : Union[pl.DataFrame, pd.DataFrame]
        Price data with at least price_col and optionally group_col
    periods : Union[int, list[int]], default 1
        Forward periods to compute (e.g., [1, 5, 21] for 1d, 1w, 1m)
    price_col : str, default "close"
        Column name containing prices
    group_col : str | None, default None
        Column for grouping (e.g., 'symbol' for multi-asset)
    date_col : str | None, default "date"
        Optional date/timestamp column used to sort before computing
        forward returns. Set to None to skip sorting.
    output_col_template : str, default "fwd_ret_{period}"
        Template for generated forward return column names.
        Use ``{period}`` placeholder for the horizon integer.
    nan_to_null : bool, default False
        Convert NaN forward return values to null/None in Polars output.

    Returns
    -------
    Union[pl.DataFrame, pd.DataFrame]
        DataFrame with forward return columns: fwd_ret_1, fwd_ret_5, etc.

    Examples
    --------
    >>> prices = pl.DataFrame({
    ...     "date": ["2024-01-01", "2024-01-02", "2024-01-03"],
    ...     "close": [100.0, 102.0, 101.0]
    ... })
    >>> fwd_returns = compute_forward_returns(prices, periods=[1, 2])
    >>> print(fwd_returns.columns)
    ['date', 'close', 'fwd_ret_1', 'fwd_ret_2']
    """
    output_as_pandas = isinstance(prices, pd.DataFrame)

    # Ensure periods is a list
    if isinstance(periods, int):
        periods = [periods]

    df = (
        prices.clone()
        if isinstance(prices, pl.DataFrame)
        else pl.from_pandas(cast(pd.DataFrame, prices))
    )

    if date_col is not None and date_col in df.columns:
        if group_col is not None and group_col in df.columns:
            df = df.sort([group_col, date_col])
        else:
            df = df.sort(date_col)

    if group_col is not None:
        # Group-wise forward returns
        return_cols = []
        for period in periods:
            col_name = output_col_template.format(period=period)
            return_cols.append(col_name)
            df = df.with_columns(
                ((pl.col(price_col).shift(-period).over(group_col) / pl.col(price_col)) - 1).alias(
                    col_name
                )
            )
    else:
        # Simple forward returns
        return_cols = []
        for period in periods:
            col_name = output_col_template.format(period=period)
            return_cols.append(col_name)
            df = df.with_columns(
                ((pl.col(price_col).shift(-period) / pl.col(price_col)) - 1).alias(col_name)
            )

    if nan_to_null:
        df = df.with_columns(
            [
                pl.when(pl.col(c).is_nan()).then(None).otherwise(pl.col(c)).alias(c)
                for c in return_cols
            ]
        )

    if output_as_pandas:
        return df.to_pandas()
    return df

Cross-Validation

splitters

Time-series cross-validation splitters for financial data.

This module provides cross-validation methods designed specifically for financial time-series data, addressing common issues like data leakage and backtest overfitting.

BaseSplitter

Bases: ABC

Abstract base class for all ml4t-diagnostic time-series splitters.

This class defines the interface that all splitters must implement to ensure compatibility with scikit-learn's model selection tools while providing additional functionality for financial time-series validation.

All splitters should support purging (removing training data that could leak information into test data) and embargo (adding gaps between train and test sets to account for serial correlation).

Session-Aware Splitting

Splitters can optionally align fold boundaries to trading session boundaries by setting align_to_sessions=True. This requires the data to have a session column (default: 'session_date') that identifies trading sessions.

Trading sessions are atomic units that should never be split across train/test folds. For intraday data (e.g., CME futures with Sunday 5pm - Friday 4pm sessions), this prevents subtle lookahead bias from mid-session splits.

Integration with qdata library:

The session column should be added using the qdata library's session assignment functionality:

from qdata import DataManager

manager = DataManager()
df = manager.load(symbol="BTC", exchange="CME", calendar="CME_Globex_Crypto")
# df now has 'session_date' column automatically assigned

Or manually using SessionAssigner:

from ml4t.data.sessions import SessionAssigner

assigner = SessionAssigner.from_exchange('CME')
df_with_sessions = assigner.assign_sessions(df)

Then use with ml4t-diagnostic splitters:

from ml4t.diagnostic.splitters import WalkForwardCV

cv = WalkForwardCV(
    n_splits=5,
    align_to_sessions=True,  # Align folds to session boundaries
    session_col='session_date'
)

for train_idx, test_idx in cv.split(df_with_sessions):
    # Fold boundaries respect session boundaries
    pass
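
To illustrate what "aligning to sessions" means in practice, the sketch below divides whole sessions, rather than rows, across folds. It is a hypothetical helper (`session_aligned_folds`), not the library's algorithm, and it assumes rows are sorted by time and that each session's rows are contiguous.

```python
import numpy as np


def session_aligned_folds(session_ids, n_folds):
    """Split row positions into contiguous folds without cutting a session.

    Hypothetical illustration; assumes rows are time-sorted and each
    session occupies one contiguous run of rows.
    """
    session_ids = np.asarray(session_ids)
    # First row of every session, in chronological (row) order.
    _, first_pos = np.unique(session_ids, return_index=True)
    starts = np.sort(first_pos)
    # Row boundaries: start of each session plus the end of the data.
    bounds = np.append(starts, session_ids.size)
    # Distribute whole sessions across folds as evenly as possible.
    session_chunks = np.array_split(np.arange(starts.size), n_folds)
    return [
        np.arange(bounds[chunk[0]], bounds[chunk[-1] + 1])
        for chunk in session_chunks
        if chunk.size  # skip empty folds when n_folds exceeds the session count
    ]


sessions = np.array(["d1"] * 3 + ["d2"] * 2 + ["d3"] * 4 + ["d4"] * 1)
session_folds = session_aligned_folds(sessions, n_folds=2)
```

Fold sizes become unequal in rows (sessions differ in length), which is the price paid for never splitting a session.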

split abstractmethod

split(X, y=None, groups=None)

Generate indices to split data into training and test sets.

Parameters

X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
    Training data with shape (n_samples, n_features).

y : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
    Target variable with shape (n_samples,). Always ignored but kept for scikit-learn compatibility.

groups : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
    Group labels for samples, used for multi-asset splitting. Shape (n_samples,).

Yields:

train : numpy.ndarray
    The training set indices for that split.

test : numpy.ndarray
    The testing set indices for that split.

Notes:

The indices returned are integer positions, not labels or timestamps. This ensures compatibility with numpy array indexing and scikit-learn.
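
As an illustration of the contract described above (a generator of `(train, test)` integer-position arrays), here is a minimal expanding-window generator. It is a standalone sketch, not a `BaseSplitter` subclass and not how the library's splitters are implemented. The name `expanding_window_split` and the `gap` parameter are assumptions introduced for the example.

```python
import numpy as np


def expanding_window_split(n_samples, n_splits=3, gap=0):
    """Yield (train, test) integer-position arrays for an expanding window.

    Hypothetical sketch: `gap` drops that many samples between the end of
    the training window and the start of the test window.
    """
    # n_splits test folds plus one initial training block -> n_splits + 2 edges.
    fold_edges = np.linspace(0, n_samples, n_splits + 2, dtype=int)
    for i in range(1, n_splits + 1):
        test = np.arange(fold_edges[i], fold_edges[i + 1])
        train = np.arange(0, max(fold_edges[i] - gap, 0))
        yield train, test


X = np.arange(40).reshape(20, 2)
folds = list(expanding_window_split(len(X), n_splits=3, gap=1))
for train, test in folds:
    # Positions index NumPy arrays directly; use .iloc for pandas objects.
    X_train, X_test = X[train], X[test]
```

Because the arrays hold positions rather than labels, the same indices work for any container that supports positional indexing.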

Source code in src/ml4t/diagnostic/splitters/base.py
@abstractmethod
def split(
    self,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
    groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
    """Generate indices to split data into training and test sets.

    Parameters
    ----------
    X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
        Training data with shape (n_samples, n_features).

    y : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
        Target variable with shape (n_samples,). Always ignored but kept
        for scikit-learn compatibility.

    groups : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
        Group labels for samples, used for multi-asset splitting.
        Shape (n_samples,).

    Yields:
    ------
    train : numpy.ndarray
        The training set indices for that split.

    test : numpy.ndarray
        The testing set indices for that split.

    Notes:
    -----
    The indices returned are integer positions, not labels or timestamps.
    This ensures compatibility with numpy array indexing and scikit-learn.
    """

get_n_splits

get_n_splits(X=None, y=None, groups=None)

Return the number of splitting iterations in the cross-validator.

Parameters

X : polars.DataFrame, pandas.DataFrame, numpy.ndarray, or None, default=None
    Training data. Some splitters may use properties of X to determine the number of splits.

y : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
    Always ignored, exists for compatibility.

groups : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
    Group labels. Some splitters may use this to determine splits.

Returns:

n_splits : int
    The number of splitting iterations.

Notes:

Most splitters can determine the number of splits from their parameters alone, but some (like GroupKFold variants) may need to inspect the data.

Source code in src/ml4t/diagnostic/splitters/base.py
def get_n_splits(
    self,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"] | None = None,
    y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
    groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
) -> int:
    """Return the number of splitting iterations in the cross-validator.

    Parameters
    ----------
    X : polars.DataFrame, pandas.DataFrame, numpy.ndarray, or None, default=None
        Training data. Some splitters may use properties of X to determine
        the number of splits.

    y : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
        Always ignored, exists for compatibility.

    groups : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
        Group labels. Some splitters may use this to determine splits.

    Returns:
    -------
    n_splits : int
        The number of splitting iterations.

    Notes:
    -----
    Most splitters can determine the number of splits from their parameters
    alone, but some (like GroupKFold variants) may need to inspect the data.
    """
    raise NotImplementedError(
        f"{self.__class__.__name__} must implement get_n_splits()",
    )

__repr__

__repr__()

Return a string representation of the splitter.

Source code in src/ml4t/diagnostic/splitters/base.py
def __repr__(self) -> str:
    """Return a string representation of the splitter."""
    return f"{self.__class__.__name__}()"

CombinatorialCV

CombinatorialCV(
    config=None,
    *,
    n_groups=8,
    n_test_groups=2,
    label_horizon=0,
    embargo_size=None,
    embargo_pct=None,
    max_combinations=None,
    random_state=None,
    align_to_sessions=False,
    session_col="session_date",
    timestamp_col=None,
    isolate_groups=True,
)

Bases: BaseSplitter

Combinatorial Cross-Validation for backtest overfitting detection.

Combinatorial Purged Cross-Validation (CPCV) partitions the time series into N contiguous groups and, for each of the C(N,k) ways of choosing k of them, uses those k groups for testing and the remaining N-k for training. This generates multiple backtest paths instead of a single chronological split, providing a robust assessment of strategy performance and enabling detection of backtest overfitting.

How It Works
  1. Partitioning: Divide time-series data into N contiguous groups of equal size
  2. Combination Generation: Generate all C(N,k) combinations of choosing k groups for testing
  3. Label Overlap Removal: For each combination, remove training samples whose labels overlap test data
  4. Embargo Buffer: Optionally add buffer periods after test groups to exclude autocorrelated samples
  5. Multi-Asset Handling: When groups are provided, handle each asset independently
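
Steps 1 and 2 above can be sketched in a few lines with `numpy.array_split` and `itertools.combinations`. This is an illustration of the partition-and-combine idea only: `combinatorial_splits` is a hypothetical function, it applies no purging or embargo, and the library may size its groups differently when the sample count is not divisible by N.

```python
from itertools import combinations

import numpy as np


def combinatorial_splits(n_samples, n_groups, n_test_groups):
    """Yield raw (train, test) positions for every choice of test groups.

    Hypothetical sketch without purging or embargo; requires
    n_test_groups < n_groups so the training set is never empty.
    """
    # Step 1: contiguous, near-equal groups of row positions.
    groups = np.array_split(np.arange(n_samples), n_groups)
    # Step 2: every way of choosing which groups form the test set.
    for test_ids in combinations(range(n_groups), n_test_groups):
        test = np.concatenate([groups[g] for g in test_ids])
        train = np.concatenate(
            [groups[g] for g in range(n_groups) if g not in test_ids]
        )
        yield train, test


raw_splits = list(combinatorial_splits(n_samples=200, n_groups=6, n_test_groups=2))
```

Note that in most combinations some training groups come after a test group in time, which is why the embargo step exists.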
Label Horizon (label_horizon)

Why needed? When labels are forward-looking (e.g., 5-day returns), training samples near the test set have labels that "see into" the test period. Without removing these, the model trains on information about test outcomes, leading to inflated performance estimates.

How it works: For each test group with range [t_start, t_end]:

1. Remove train samples that precede the test group and satisfy: t_start - label_horizon <= t_train < t_start
2. This ensures no training sample's label period overlaps with test samples

Example:

Test group: samples 100-119 (20 samples)
label_horizon: 5 samples
Removes: training samples 95-99
Reason: Sample 95's label (computed from samples 95-100) overlaps test data
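
The example above (test group 100-119, horizon 5, samples 95-99 removed) can be reproduced with a short positional filter. This is a sketch under the assumption that `label_horizon` is an integer number of samples. `purge_before_test` is a hypothetical helper, not a library function.

```python
import numpy as np


def purge_before_test(train_idx, test_start, label_horizon):
    """Drop training positions whose label window reaches into the test block.

    Hypothetical sketch for an integer label_horizon measured in samples.
    """
    train_idx = np.asarray(train_idx)
    # A sample at position t has a label spanning t .. t + label_horizon.
    overlaps = (train_idx >= test_start - label_horizon) & (train_idx < test_start)
    return train_idx[~overlaps]


# Training data on both sides of the test group 100-119.
train = np.concatenate([np.arange(0, 100), np.arange(120, 200)])
kept = purge_before_test(train, test_start=100, label_horizon=5)
removed = np.setdiff1d(train, kept)
```

Only samples before the test group are affected here. Samples after it are handled by the embargo.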
Embargo Buffer (embargo_size)

Why needed? Unlike walk-forward CV where training always precedes test, CPCV can have training groups that follow test groups chronologically. Samples immediately after test data may be autocorrelated with it.

How it works: Remove training samples in a buffer zone after each test group:

- embargo_size: Absolute number of samples (e.g., 10 samples)
- embargo_pct: Percentage of total samples (e.g., 0.01 = 1%)

Example:

Test group: samples 100-119
embargo_size: 5 samples
Additional removal: training samples 120-124
When this matters: If predicting volatility and the test period has a volatility
spike, samples 120-124 likely share similar volatility due to clustering.
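
A positional sketch of the embargo step follows. It is illustrative only: `apply_embargo` is a hypothetical helper, it handles integer sample counts rather than `pd.Timedelta` values, and converting `embargo_pct` to a sample count by truncation is an assumption about rounding that the text above does not specify.

```python
import numpy as np


def apply_embargo(train_idx, test_end, n_samples, embargo_size=None, embargo_pct=None):
    """Drop training positions in a buffer right after the test block.

    Hypothetical sketch; test_end is the last test position (inclusive).
    """
    train_idx = np.asarray(train_idx)
    if embargo_size is None:
        # Assumption: a fraction of the sample count, truncated to whole samples.
        embargo_size = int(n_samples * embargo_pct) if embargo_pct else 0
    in_embargo = (train_idx > test_end) & (train_idx <= test_end + embargo_size)
    return train_idx[~in_embargo]


train = np.concatenate([np.arange(0, 100), np.arange(120, 200)])

# Absolute embargo: removes samples 120-124, as in the example above.
kept_abs = apply_embargo(train, test_end=119, n_samples=200, embargo_size=5)

# Relative embargo: 1% of 200 samples -> 2 samples (120 and 121).
kept_pct = apply_embargo(train, test_end=119, n_samples=200, embargo_pct=0.01)
```

A relative embargo scales with the dataset, which is convenient when the same configuration is reused across datasets of different length.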
Multi-Asset Handling

When groups parameter is provided (e.g., asset symbols), CPCV handles each asset independently. This prevents cross-asset leakage:

Process:
1. For each asset, find its training and test samples
2. Apply label_horizon/embargo only to that asset's data
3. Combine results across all assets

Why Important? Without per-asset handling, information could leak between assets that trade at different times (e.g., European markets vs US markets).
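
The per-asset process can be sketched as a loop that purges each asset's training rows against that asset's own test window. Everything here is an assumption made for illustration: `purge_per_asset` is a hypothetical helper, it uses an explicit integer time axis, and it covers only the label-horizon step (no embargo).

```python
import numpy as np


def purge_per_asset(train_idx, test_idx, assets, times, label_horizon):
    """Purge each asset's training rows against its own test window.

    Hypothetical sketch: `times` is an integer time axis aligned with the rows.
    """
    train_idx, test_idx = np.asarray(train_idx), np.asarray(test_idx)
    assets, times = np.asarray(assets), np.asarray(times)
    kept = []
    for asset in np.unique(assets[train_idx]):
        a_train = train_idx[assets[train_idx] == asset]
        a_test = test_idx[assets[test_idx] == asset]
        if a_test.size == 0:
            kept.append(a_train)          # asset not in test: nothing to purge
            continue
        t_start = times[a_test].min()     # this asset's own test start
        t = times[a_train]
        overlaps = (t >= t_start - label_horizon) & (t < t_start)
        kept.append(a_train[~overlaps])
    return np.sort(np.concatenate(kept))


# Rows interleave two assets: row 2t is asset A at time t, row 2t+1 is asset B.
assets = np.array(["A", "B"] * 10)
times = np.repeat(np.arange(10), 2)
# A is tested from time 6 onward, B only from time 8 onward.
test_idx = np.array([12, 14, 16, 18, 17, 19])
train_idx = np.setdiff1d(np.arange(20), test_idx)
kept_rows = purge_per_asset(train_idx, test_idx, assets, times, label_horizon=2)
```

In this toy data a single global cutoff at time 6 would wrongly discard B's rows at times 4 and 5 while keeping B's rows at times 6 and 7, whose labels do reach into B's test window.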

Based on Bailey et al. (2014) "The Probability of Backtest Overfitting" and López de Prado (2018) "Advances in Financial Machine Learning".

Parameters

n_groups : int, default=8
    Number of contiguous groups to partition the time series into.

n_test_groups : int, default=2
    Number of groups to use for testing in each combination.

label_horizon : int or pd.Timedelta, default=0
    How far ahead labels look into the future. Removes training samples whose prediction targets overlap with test data.

embargo_size : int or pd.Timedelta, optional
    Buffer zone after test groups. Excludes autocorrelated training samples that follow test data chronologically.

embargo_pct : float, optional
    Embargo size as percentage of total samples (alternative to embargo_size).

max_combinations : int, optional
    Maximum number of combinations to generate. If None, generates all C(N,k). Use this to limit computational cost for large N.

random_state : int, optional
    Random seed for combination sampling when max_combinations is set.

align_to_sessions : bool, default=False
    If True, align group boundaries to trading session boundaries. Requires X to have a session column (specified by session_col parameter).

    Trading sessions should be assigned using the qdata library before cross-validation:
    - Use DataManager with exchange/calendar parameters, or
    - Use SessionAssigner.from_exchange('CME') directly

session_col : str, default='session_date'
    Name of the column containing session identifiers. Only used if align_to_sessions=True. This column should be added by qdata.sessions.SessionAssigner.

isolate_groups : bool, default=True
    If True, prevent the same group (asset/symbol) from appearing in both train and test sets. This is enabled by default for CPCV as it's designed for multi-asset validation.

    Requires passing groups parameter to split() method with asset IDs.

    Note: CPCV already applies per-asset purging when groups are provided. This parameter provides additional group isolation guarantee.

Attributes:

n_groups_ : int The number of groups.

int

The number of test groups.

Examples:

>>> import numpy as np
>>> from ml4t.diagnostic.splitters import CombinatorialCV
>>> X = np.arange(200).reshape(200, 1)
>>> cv = CombinatorialCV(n_groups=6, n_test_groups=2, label_horizon=5)
>>> combinations = list(cv.split(X))
>>> print(f"Generated {len(combinations)} combinations")
Generated 15 combinations

>>> # Each combination provides train/test indices
>>> for i, (train, test) in enumerate(combinations[:3]):
...     print(f"Combination {i+1}: Train={len(train)}, Test={len(test)}")
Combination 1: Train=125, Test=50
Combination 2: Train=125, Test=50
Combination 3: Train=125, Test=50

Notes:

The total number of combinations is C(n_groups, n_test_groups). For large values, this can become computationally expensive:
- C(8,2) = 28 combinations
- C(10,3) = 120 combinations
- C(12,4) = 495 combinations

Use max_combinations to limit computational cost for large datasets.
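
The counts above follow directly from the binomial coefficient, and capping them amounts to drawing a subset of the full list. The sketch below checks the arithmetic with `math.comb` and shows one plausible way to subsample. `sample_combinations` is a hypothetical helper, and uniform sampling without replacement is an assumption, not a description of how the library selects combinations.

```python
import math
import random
from itertools import combinations


def sample_combinations(n_groups, n_test_groups, max_combinations=None,
                        random_state=None):
    """Return all test-group combinations, or a reproducible random subset.

    Hypothetical sketch; the sampling scheme is an assumption.
    """
    all_combos = list(combinations(range(n_groups), n_test_groups))
    if max_combinations is None or max_combinations >= len(all_combos):
        return all_combos
    rng = random.Random(random_state)  # seeded for reproducibility
    return sorted(rng.sample(all_combos, max_combinations))


counts = {(n, k): math.comb(n, k) for n, k in [(8, 2), (10, 3), (12, 4)]}
subset = sample_combinations(12, 4, max_combinations=50, random_state=0)
```

Enumerating every combination before sampling is fine for the sizes listed here. For much larger N it would be better to sample index tuples directly.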

Initialize CombinatorialCV.

This splitter uses a config-first architecture. You can either:
1. Pass a config object: CombinatorialCV(config=my_config)
2. Pass individual parameters: CombinatorialCV(n_groups=8, n_test_groups=2)

Parameters are automatically converted to a config object internally, ensuring a single source of truth for all validation and logic.
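The config-first pattern can be sketched without the library. Everything in this example is hypothetical (SketchConfig, SketchSplitter, and the choice to raise ValueError on conflicting arguments); it only illustrates how keyword parameters can be folded into one config object that every property reads from.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for a splitter config (not the library's class)."""

    n_groups: int = 8
    n_test_groups: int = 2


class SketchSplitter:
    """Accepts either a config object or keyword parameters, not both."""

    def __init__(self, config=None, *, n_groups=8, n_test_groups=2):
        if config is not None:
            # Reject keyword arguments that differ from their defaults.
            defaults = SketchConfig()
            passed = {"n_groups": n_groups, "n_test_groups": n_test_groups}
            conflicts = [k for k, v in passed.items() if v != getattr(defaults, k)]
            if conflicts:
                raise ValueError(f"Pass config or parameters, not both: {conflicts}")
            self.config = config
        else:
            # Build the config from the individual parameters.
            self.config = SketchConfig(n_groups=n_groups, n_test_groups=n_test_groups)

    @property
    def n_groups(self):
        # Properties read through to the config, the single source of truth.
        return self.config.n_groups
```

Because both construction paths end in the same config object, validation only has to live in one place.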

Examples

>>> # Approach 1: Direct parameters (convenient)
>>> cv = CombinatorialCV(n_groups=10, n_test_groups=3)
>>>
>>> # Approach 2: Config object (for serialization/reproducibility)
>>> from ml4t.diagnostic.splitters.config import CombinatorialConfig
>>> config = CombinatorialConfig(n_groups=10, n_test_groups=3)
>>> cv = CombinatorialCV(config=config)
>>>
>>> # Config can be serialized
>>> config.to_json("cpcv_config.json")
>>> loaded = CombinatorialConfig.from_json("cpcv_config.json")
>>> cv = CombinatorialCV(config=loaded)

Source code in src/ml4t/diagnostic/splitters/combinatorial.py
def __init__(
    self,
    config: CombinatorialConfig | None = None,
    *,
    n_groups: int = 8,
    n_test_groups: int = 2,
    label_horizon: int | pd.Timedelta = 0,
    embargo_size: int | pd.Timedelta | None = None,
    embargo_pct: float | None = None,
    max_combinations: int | None = None,
    random_state: int | None = None,
    align_to_sessions: bool = False,
    session_col: str = "session_date",
    timestamp_col: str | None = None,
    isolate_groups: bool = True,
) -> None:
    """Initialize CombinatorialCV.

    This splitter uses a config-first architecture. You can either:
    1. Pass a config object: CombinatorialCV(config=my_config)
    2. Pass individual parameters: CombinatorialCV(n_groups=8, n_test_groups=2)

    Parameters are automatically converted to a config object internally,
    ensuring a single source of truth for all validation and logic.

    Examples
    --------
    >>> # Approach 1: Direct parameters (convenient)
    >>> cv = CombinatorialCV(n_groups=10, n_test_groups=3)
    >>>
    >>> # Approach 2: Config object (for serialization/reproducibility)
    >>> from ml4t.diagnostic.splitters.config import CombinatorialConfig
    >>> config = CombinatorialConfig(n_groups=10, n_test_groups=3)
    >>> cv = CombinatorialCV(config=config)
    >>>
    >>> # Config can be serialized
    >>> config.to_json("cpcv_config.json")
    >>> loaded = CombinatorialConfig.from_json("cpcv_config.json")
    >>> cv = CombinatorialCV(config=loaded)
    """
    # Config-first: either use provided config or create from params
    if config is not None:
        # Verify no conflicting parameters when config is provided
        self._validate_no_param_conflicts(
            n_groups,
            n_test_groups,
            label_horizon,
            embargo_size,
            embargo_pct,
            max_combinations,
            random_state,
            align_to_sessions,
            session_col,
            timestamp_col,
            isolate_groups,
        )
        self.config = config
    else:
        # Create config from individual parameters
        # Note: embargo validation (mutual exclusivity) handled by config
        self.config = self._create_config_from_params(
            n_groups,
            n_test_groups,
            label_horizon,
            embargo_size,
            embargo_pct,
            max_combinations,
            random_state,
            align_to_sessions,
            session_col,
            timestamp_col,
            isolate_groups,
        )

    # Use parameter if provided, otherwise use config value
    # This allows random_state to be passed either via config or direct parameter
    self.random_state = random_state if random_state is not None else self.config.random_state

n_groups property

n_groups

Number of groups to partition timeline into.

n_test_groups property

n_test_groups

Number of groups per test set.

label_horizon property

label_horizon

Forward-looking period of labels (int samples or Timedelta).

embargo_size property

embargo_size

Embargo buffer size (int samples or Timedelta).

embargo_pct property

embargo_pct

Embargo size as percentage of total samples.

max_combinations property

max_combinations

Maximum number of folds to generate.

align_to_sessions property

align_to_sessions

Whether to align group boundaries to sessions.

session_col property

session_col

Column name containing session identifiers.

timestamp_col property

timestamp_col

Column name containing timestamps for time-based operations.

isolate_groups property

isolate_groups

Whether to prevent group overlap between train/test.

get_n_splits

get_n_splits(X=None, y=None, groups=None)

Get number of splits (combinations).

Parameters

X : array-like, optional

Always ignored, exists for compatibility.

y : array-like, optional

Always ignored, exists for compatibility.

groups : array-like, optional

Always ignored, exists for compatibility.

Returns:

n_splits : int Number of combinations that will be generated.

Source code in src/ml4t/diagnostic/splitters/combinatorial.py
def get_n_splits(
    self,
    X: pl.DataFrame | pd.DataFrame | NDArray[Any] | None = None,
    y: pl.Series | pd.Series | NDArray[Any] | None = None,
    groups: pl.Series | pd.Series | NDArray[Any] | None = None,
) -> int:
    """Get number of splits (combinations).

    Parameters
    ----------
    X : array-like, optional
        Always ignored, exists for compatibility.

    y : array-like, optional
        Always ignored, exists for compatibility.

    groups : array-like, optional
        Always ignored, exists for compatibility.

    Returns:
    -------
    n_splits : int
        Number of combinations that will be generated.
    """
    del X, y, groups  # Unused, for sklearn compatibility
    total_combinations = math.comb(self.n_groups, self.n_test_groups)

    if self.max_combinations is None:
        return total_combinations
    return min(self.max_combinations, total_combinations)

split

split(X, y=None, groups=None)

Generate train/test indices for combinatorial splits with purging and embargo.

This method generates all combinations C(N,k) of train/test splits, applying purging and embargo to prevent information leakage. Each yielded split represents an independent backtest path.
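As a rough illustration of how the C(N,k) train/test splits arise before any purging or embargo is applied, the following sketch enumerates test-group combinations over contiguous blocks. The helper names and the way remainder samples are distributed across groups are assumptions made for this example, not the library's internals.

```python
from itertools import combinations


def contiguous_boundaries(n_samples, n_groups):
    """Split range(n_samples) into n_groups contiguous [start, end) blocks."""
    base, extra = divmod(n_samples, n_groups)
    bounds, start = [], 0
    for g in range(n_groups):
        # Assumption: the first `extra` groups absorb one leftover sample each.
        end = start + base + (1 if g < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def raw_combinatorial_splits(n_samples, n_groups, n_test_groups):
    """Yield (train, test) index lists for every choice of test groups."""
    bounds = contiguous_boundaries(n_samples, n_groups)
    for test_groups in combinations(range(n_groups), n_test_groups):
        test = [i for g in test_groups for i in range(*bounds[g])]
        train = [
            i
            for g in range(n_groups)
            if g not in test_groups
            for i in range(*bounds[g])
        ]
        yield train, test
```

These raw training sets still contain samples whose labels overlap the test groups, which is why purging and embargo are applied afterwards.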

Parameters

X : DataFrame or ndarray of shape (n_samples, n_features)

Training data. Must have a datetime index if using Timedelta-based label_horizon or embargo_size.

y : Series or ndarray of shape (n_samples,), optional

Target variable. Not used in splitting logic, but accepted for API compatibility with scikit-learn.

groups : Series or ndarray of shape (n_samples,), optional

Group labels for samples (e.g., asset symbols for multi-asset strategies).

When provided:
- Purging is applied independently per group (asset)
- Prevents information leakage across groups
- Essential for multi-asset portfolio validation

Example: groups = df["symbol"]  # ["AAPL", "MSFT", "GOOGL", ...]

Yields

train : ndarray of shape (n_train_samples,)

Indices of training samples for this combination. Purging and embargo have been applied to remove:
- Samples overlapping with test labels (purging)
- Samples in embargo buffer after test groups (embargo)

test : ndarray of shape (n_test_samples,)

Indices of test samples for this combination. Consists of samples from the k selected test groups.

Raises

ValueError

If X has incompatible shape or missing required columns (e.g., session_col when align_to_sessions=True).

TypeError

If X index is not datetime when using Timedelta parameters.

Notes

Number of Combinations: Generates C(n_groups, n_test_groups) combinations. For example:
- C(8,2) = 28 combinations
- C(10,3) = 120 combinations
- C(12,4) = 495 combinations

Use the max_combinations parameter to limit the number of splits generated.

Purging Logic: For each test group:
1. Identify test sample range [t_start, t_end]
2. Remove training samples where: t_train > t_start - label_horizon
3. This prevents training on samples whose labels overlap with test period

Embargo Logic: After purging, additionally remove training samples:
- In range [t_end + 1, t_end + embargo_size]
- This accounts for serial correlation in financial time series
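The two rules above can be sketched for a single test block with integer positions. This is a simplified reading, not the library's implementation: purge_and_embargo is a hypothetical helper, it assumes purging only affects training samples that precede the block, and it takes the strict inequality from the purging note at face value.

```python
def purge_and_embargo(train, t_start, t_end, label_horizon=0, embargo_size=0):
    """Drop training indices that conflict with one test block [t_start, t_end]."""
    kept = []
    for t in train:
        # Purge: the sample precedes the test block but its label reaches into it.
        purged = t_start - label_horizon < t < t_start
        # Embargo: the sample falls in the buffer right after the test block.
        embargoed = t_end < t <= t_end + embargo_size
        if not (purged or embargoed):
            kept.append(t)
    return kept


# Test block 100..199 with training data on both sides of it.
train = list(range(0, 100)) + list(range(200, 300))
clean = purge_and_embargo(train, 100, 199, label_horizon=5, embargo_size=2)
print(len(train) - len(clean), "training samples removed")
```

With several test groups in one combination, the same filter would be applied once per test block.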

Multi-Asset Handling: When groups is provided:
1. For each asset, find its training and test indices
2. Apply purging/embargo independently to that asset's data
3. Combine purged results across all assets
4. This prevents cross-asset information leakage

Session Alignment: When align_to_sessions=True:
- Group boundaries align to trading session boundaries
- Ensures each group contains complete trading days/sessions
- Requires X to have column specified by session_col parameter
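One way to picture session alignment is to assign whole sessions, not individual rows, to groups. The sketch below is an assumption-laden illustration: session_groups is a hypothetical helper, sessions are ordered by first appearance, and leftover sessions go to the earliest groups. It returns exact row indices per group, so it also works when rows of different sessions are interleaved.

```python
def session_groups(session_ids, n_groups):
    """Assign each row to a group so that no session is split across groups."""
    # Unique sessions in order of first appearance.
    sessions = list(dict.fromkeys(session_ids))
    if len(sessions) < n_groups:
        raise ValueError("need at least as many sessions as groups")
    base, extra = divmod(len(sessions), n_groups)
    session_to_group, pos = {}, 0
    for g in range(n_groups):
        size = base + (1 if g < extra else 0)
        for s in sessions[pos:pos + size]:
            session_to_group[s] = g
        pos += size
    # Collect exact row indices per group (rows need not be contiguous).
    groups = [[] for _ in range(n_groups)]
    for i, s in enumerate(session_ids):
        groups[session_to_group[s]].append(i)
    return groups
```

For example, four sessions of unequal length split into two groups keep every session intact, even though the groups end up with different row counts.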

Examples

Basic usage with purging:

>>> import polars as pl
>>> from ml4t.diagnostic.splitters import CombinatorialCV
>>>
>>> # Create sample data
>>> n = 1000
>>> X = pl.DataFrame({"feature1": range(n), "feature2": range(n, 2*n)})
>>> y = pl.Series(range(n))
>>>
>>> # Configure CPCV
>>> cv = CombinatorialCV(
...     n_groups=8,
...     n_test_groups=2,
...     label_horizon=5,
...     embargo_size=2
... )
>>>
>>> # Generate splits
>>> for fold, (train_idx, test_idx) in enumerate(cv.split(X)):
...     print(f"Fold {fold}: Train={len(train_idx)}, Test={len(test_idx)}")
Fold 0: Train=739, Test=250
Fold 1: Train=739, Test=250
...

Multi-asset usage:

>>> # Multi-asset data with symbol column
>>> symbols = pl.Series(["AAPL"] * 250 + ["MSFT"] * 250 +
...                      ["GOOGL"] * 250 + ["AMZN"] * 250)
>>>
>>> cv = CombinatorialCV(
...     n_groups=6,
...     n_test_groups=2,
...     label_horizon=5,
...     embargo_size=2,
...     isolate_groups=True
... )
>>>
>>> for train_idx, test_idx in cv.split(X, groups=symbols):
...     # Purging applied independently per asset
...     train_symbols = symbols[train_idx].unique()
...     test_symbols = symbols[test_idx].unique()

Session-aligned usage:

>>> import pandas as pd
>>>
>>> # Intraday data with session dates
>>> df = pd.DataFrame({
...     "timestamp": pd.date_range("2024-01-01", periods=1000, freq="1min"),
...     "session_date": pd.date_range("2024-01-01", periods=1000, freq="1min").date,
...     "feature1": range(1000)
... })
>>>
>>> cv = CombinatorialCV(
...     n_groups=10,
...     n_test_groups=2,
...     label_horizon=pd.Timedelta(minutes=30),
...     embargo_size=pd.Timedelta(minutes=15),
...     align_to_sessions=True,
...     session_col="session_date"
... )
>>>
>>> for train_idx, test_idx in cv.split(df):
...     # Group boundaries aligned to session boundaries
...     pass

See Also

CombinatorialConfig : Configuration object for CPCV parameters
apply_purging_and_embargo : Low-level purging/embargo function
BaseSplitter : Base class for all splitters

Source code in src/ml4t/diagnostic/splitters/combinatorial.py
def split(
    self,
    X: pl.DataFrame | pd.DataFrame | NDArray[Any],
    y: pl.Series | pd.Series | NDArray[Any] | None = None,
    groups: pl.Series | pd.Series | NDArray[Any] | None = None,
) -> Generator[tuple[NDArray[np.intp], NDArray[np.intp]], None, None]:
    """Generate train/test indices for combinatorial splits with purging and embargo.

    This method generates all combinations C(N,k) of train/test splits, applying
    purging and embargo to prevent information leakage. Each yielded split represents
    an independent backtest path.

    Parameters
    ----------
    X : DataFrame or ndarray of shape (n_samples, n_features)
        Training data. Must have a datetime index if using Timedelta-based
        label_horizon or embargo_size.

    y : Series or ndarray of shape (n_samples,), optional
        Target variable. Not used in splitting logic, but accepted for
        API compatibility with scikit-learn.

    groups : Series or ndarray of shape (n_samples,), optional
        Group labels for samples (e.g., asset symbols for multi-asset strategies).

        When provided:
        - Purging is applied independently per group (asset)
        - Prevents information leakage across groups
        - Essential for multi-asset portfolio validation

        Example: ``groups = df["symbol"]``  # ["AAPL", "MSFT", "GOOGL", ...]

    Yields
    ------
    train : ndarray of shape (n_train_samples,)
        Indices of training samples for this combination.
        Purging and embargo have been applied to remove:
        - Samples overlapping with test labels (purging)
        - Samples in embargo buffer after test groups (embargo)

    test : ndarray of shape (n_test_samples,)
        Indices of test samples for this combination.
        Consists of samples from the k selected test groups.

    Raises
    ------
    ValueError
        If X has incompatible shape or missing required columns
        (e.g., session_col when align_to_sessions=True).

    TypeError
        If X index is not datetime when using Timedelta parameters.

    Notes
    -----
    **Number of Combinations**:
        Generates C(n_groups, n_test_groups) combinations. For example:
        - C(8,2) = 28 combinations
        - C(10,3) = 120 combinations
        - C(12,4) = 495 combinations

        Use ``max_combinations`` parameter to limit the number of splits generated.

    **Purging Logic**:
        For each test group:
        1. Identify test sample range [t_start, t_end]
        2. Remove training samples where: t_train > t_start - label_horizon
        3. This prevents training on samples whose labels overlap with test period

    **Embargo Logic**:
        After purging, additionally remove training samples:
        - In range [t_end + 1, t_end + embargo_size]
        - This accounts for serial correlation in financial time series

    **Multi-Asset Handling**:
        When ``groups`` is provided:
        1. For each asset, find its training and test indices
        2. Apply purging/embargo independently to that asset's data
        3. Combine purged results across all assets
        4. This prevents cross-asset information leakage

    **Session Alignment**:
        When ``align_to_sessions=True``:
        - Group boundaries align to trading session boundaries
        - Ensures each group contains complete trading days/sessions
        - Requires X to have column specified by ``session_col`` parameter

    Examples
    --------
    Basic usage with purging::

        >>> import polars as pl
        >>> from ml4t.diagnostic.splitters import CombinatorialCV
        >>>
        >>> # Create sample data
        >>> n = 1000
        >>> X = pl.DataFrame({"feature1": range(n), "feature2": range(n, 2*n)})
        >>> y = pl.Series(range(n))
        >>>
        >>> # Configure CPCV
        >>> cv = CombinatorialCV(
        ...     n_groups=8,
        ...     n_test_groups=2,
        ...     label_horizon=5,
        ...     embargo_size=2
        ... )
        >>>
        >>> # Generate splits
        >>> for fold, (train_idx, test_idx) in enumerate(cv.split(X)):
        ...     print(f"Fold {fold}: Train={len(train_idx)}, Test={len(test_idx)}")
        Fold 0: Train=739, Test=250
        Fold 1: Train=739, Test=250
        ...

    Multi-asset usage::

        >>> # Multi-asset data with symbol column
        >>> symbols = pl.Series(["AAPL"] * 250 + ["MSFT"] * 250 +
        ...                      ["GOOGL"] * 250 + ["AMZN"] * 250)
        >>>
        >>> cv = CombinatorialCV(
        ...     n_groups=6,
        ...     n_test_groups=2,
        ...     label_horizon=5,
        ...     embargo_size=2,
        ...     isolate_groups=True
        ... )
        >>>
        >>> for train_idx, test_idx in cv.split(X, groups=symbols):
        ...     # Purging applied independently per asset
        ...     train_symbols = symbols[train_idx].unique()
        ...     test_symbols = symbols[test_idx].unique()

    Session-aligned usage::

        >>> import pandas as pd
        >>>
        >>> # Intraday data with session dates
        >>> df = pd.DataFrame({
        ...     "timestamp": pd.date_range("2024-01-01", periods=1000, freq="1min"),
        ...     "session_date": pd.date_range("2024-01-01", periods=1000, freq="1min").date,
        ...     "feature1": range(1000)
        ... })
        >>>
        >>> cv = CombinatorialCV(
        ...     n_groups=10,
        ...     n_test_groups=2,
        ...     label_horizon=pd.Timedelta(minutes=30),
        ...     embargo_size=pd.Timedelta(minutes=15),
        ...     align_to_sessions=True,
        ...     session_col="session_date"
        ... )
        >>>
        >>> for train_idx, test_idx in cv.split(df):
        ...     # Group boundaries aligned to session boundaries
        ...     pass

    See Also
    --------
    CombinatorialConfig : Configuration object for CPCV parameters
    apply_purging_and_embargo : Low-level purging/embargo function
    BaseSplitter : Base class for all splitters
    """
    # Validate inputs (no numpy conversion - performance optimization)
    n_samples = self._validate_inputs(X, y, groups)

    # Validate session alignment if enabled
    self._validate_session_alignment(X, self.align_to_sessions, self.session_col)

    # Extract timestamps if available (supports both Polars and pandas)
    timestamps = self._extract_timestamps(X, self.timestamp_col)

    # Create group indices or boundaries
    # For session-aligned mode, we need exact indices (not boundaries) to handle
    # non-contiguous/interleaved data correctly
    if self.align_to_sessions:
        # align_to_sessions requires X to be a DataFrame (validation enforces this)
        # Use new method that returns exact indices per group
        group_indices_list = self._create_session_group_indices(
            cast(pl.DataFrame | pd.DataFrame, X)
        )
        use_exact_indices = True
        # Also create boundaries for backward compatibility with purging logic
        group_boundaries = [
            (int(indices[0]), int(indices[-1]) + 1) if len(indices) > 0 else (0, 0)
            for indices in group_indices_list
        ]
    else:
        group_boundaries = self._create_group_boundaries(n_samples)
        group_indices_list = None
        use_exact_indices = False

    # Generate combinations with memory-efficient sampling when max_combinations is set
    # Uses reservoir sampling when needed to avoid materializing all C(n,k) combinations
    combinations = iter_combinations(
        self.n_groups,
        self.n_test_groups,
        self.max_combinations,
        self.random_state,
    )

    # Generate splits for each combination
    for test_group_indices in combinations:
        # Create test set from selected groups
        if use_exact_indices and group_indices_list is not None:
            # Use exact indices (correct for non-contiguous/interleaved data)
            test_arrays = [group_indices_list[g] for g in test_group_indices]
            test_indices_array = (
                np.concatenate(test_arrays) if test_arrays else np.array([], dtype=np.intp)
            )
        else:
            # Use boundaries with range (only correct for contiguous data)
            test_indices: list[int] = []
            for group_idx in test_group_indices:
                start_idx, end_idx = group_boundaries[group_idx]
                test_indices.extend(range(start_idx, end_idx))
            test_indices_array = np.array(test_indices, dtype=np.intp)

        # Create initial training set from remaining groups
        train_group_indices_list = [
            i for i in range(self.n_groups) if i not in test_group_indices
        ]
        if use_exact_indices and group_indices_list is not None:
            # Use exact indices
            train_arrays = [group_indices_list[g] for g in train_group_indices_list]
            train_indices_array = (
                np.concatenate(train_arrays) if train_arrays else np.array([], dtype=np.intp)
            )
        else:
            # Use boundaries with range
            train_indices: list[int] = []
            for group_idx in train_group_indices_list:
                start_idx, end_idx = group_boundaries[group_idx]
                train_indices.extend(range(start_idx, end_idx))
            train_indices_array = np.array(train_indices, dtype=np.intp)

        # Apply purging and embargo between test groups and training data
        clean_train_indices = self._apply_group_purging_and_embargo(
            train_indices_array,
            test_group_indices,
            group_boundaries,
            n_samples,
            timestamps,
            groups,  # Pass groups for multi-asset awareness
            group_indices_list,  # Pass exact indices for session-aligned purging
        )

        # Apply group isolation if requested
        if self.isolate_groups and groups is not None:
            clean_train_indices = isolate_groups_from_train(
                clean_train_indices, test_indices_array, groups
            )

        # CPCV Invariant: train set must not be empty after purging
        if len(clean_train_indices) == 0:
            raise ValueError(
                f"CPCV invariant violated: train set is empty after purging/embargo. "
                f"Test groups: {test_group_indices}. "
                f"Consider reducing label_horizon ({self.label_horizon}) or "
                f"embargo_size ({self.embargo_size}) or embargo_pct ({self.embargo_pct})."
            )

        # CPCV Invariant: train and test sets must be disjoint
        overlap = np.intersect1d(clean_train_indices, test_indices_array)
        if len(overlap) > 0:
            raise ValueError(
                f"CPCV invariant violated: train and test sets have {len(overlap)} "
                f"overlapping indices. First few: {overlap[:5].tolist()}"
            )

        # Return sorted indices for deterministic behavior
        yield np.sort(clean_train_indices), np.sort(test_indices_array)

CombinatorialConfig

Bases: SplitterConfig

Configuration for Combinatorial Purged Cross-Validation (CPCV).

Combinatorial CV is designed for multi-asset strategies. It combats overfitting by creating multiple test sets from combinatorial group selections.

Reference: Bailey & Lopez de Prado (2014) "The Deflated Sharpe Ratio: Correcting for Selection Bias, Backtest Overfitting and Non-Normality"

Attributes

n_groups : int

Number of groups to partition the timeline into (typically 8-12).

n_test_groups : int

Number of groups used for each test set (typically 2-3). Total folds = C(n_groups, n_test_groups).

max_combinations : int | None

Maximum number of folds to generate. If C(n_groups, n_test_groups) > max_combinations, randomly sample.

contiguous_test_blocks : bool

If True, only use contiguous test groups (reduces overfitting). If False, allow any combination (more folds).
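Random sampling of folds when C(n_groups, n_test_groups) exceeds max_combinations can be done without materializing every combination. The sketch below uses reservoir sampling (Algorithm R) over itertools.combinations. The function name, the use of Python's random module, and the final sort are assumptions for illustration; the library's own sampler may differ.

```python
import random
from itertools import combinations


def sample_combinations(n, k, max_combinations=None, seed=None):
    """Return all C(n, k) index tuples, or a uniform random sample of them."""
    stream = combinations(range(n), k)
    if max_combinations is None:
        return list(stream)
    rng = random.Random(seed)
    reservoir = []
    for seen, combo in enumerate(stream):
        if seen < max_combinations:
            # Fill the reservoir with the first max_combinations items.
            reservoir.append(combo)
        else:
            # Keep the new item with probability max_combinations / (seen + 1).
            j = rng.randint(0, seen)
            if j < max_combinations:
                reservoir[j] = combo
    return sorted(reservoir)
```

Memory use stays proportional to max_combinations, and a fixed seed makes the selected folds reproducible.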

validate_n_test_groups classmethod

validate_n_test_groups(v, info)

Validate that n_test_groups < n_groups (must leave groups for training).

Source code in src/ml4t/diagnostic/splitters/config.py
@field_validator("n_test_groups")
@classmethod
def validate_n_test_groups(cls, v: int, info) -> int:
    """Validate that n_test_groups < n_groups (must leave groups for training)."""
    n_groups = info.data.get("n_groups")
    if n_groups is not None and v >= n_groups:
        raise ValueError(
            f"n_test_groups ({v}) must be less than n_groups ({n_groups}). "
            f"Must leave at least one group for training. "
            f"Typically n_test_groups is 2-3 for CPCV."
        )
    return v

validate_embargo_mutual_exclusivity

validate_embargo_mutual_exclusivity()

Validate that embargo_td and embargo_pct are mutually exclusive.

Source code in src/ml4t/diagnostic/splitters/config.py
@model_validator(mode="after")
def validate_embargo_mutual_exclusivity(self) -> CombinatorialConfig:
    """Validate that embargo_td and embargo_pct are mutually exclusive."""
    if self.embargo_td is not None and self.embargo_pct is not None:
        raise ValueError(
            "Cannot specify both 'embargo_td' and 'embargo_pct'. "
            "Choose one method for setting the embargo period."
        )
    return self

WalkForwardCV

WalkForwardCV(
    config=None,
    *,
    n_splits=5,
    test_size=None,
    train_size=None,
    gap=0,
    label_horizon=0,
    embargo_size=None,
    embargo_pct=None,
    expanding=True,
    consecutive=False,
    calendar=None,
    align_to_sessions=False,
    session_col="session_date",
    timestamp_col=None,
    isolate_groups=False,
    test_period=None,
    test_start=None,
    test_end=None,
    fold_direction="forward",
)

Bases: BaseSplitter

Walk-forward cross-validator for time-series data.

Walk-forward CV creates sequential train/test splits where training data always precedes test data. Includes optional safeguards against data leakage from overlapping labels and autocorrelation.
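The sequential splitting described above can be sketched on integer positions. This is a minimal illustration under stated assumptions, not the library's algorithm: walk_forward_splits is a hypothetical helper, test blocks are placed back to back at the end of the data, and label overlap is handled by dropping the last label_horizon samples before each test block.

```python
def walk_forward_splits(n_samples, n_splits, test_size, train_size=None, label_horizon=0):
    """Yield (train, test) index lists with training strictly before testing."""
    first_test_start = n_samples - n_splits * test_size
    for k in range(n_splits):
        test_start = first_test_start + k * test_size
        test = list(range(test_start, test_start + test_size))
        # Drop training samples whose labels would reach into the test block.
        train_end = test_start - label_horizon
        # train_size=None gives an expanding window; an int gives a rolling one.
        train_begin = 0 if train_size is None else max(0, train_end - train_size)
        train = list(range(train_begin, max(train_begin, train_end)))
        yield train, test


# Expanding window: the training set grows with each split.
for train, test in walk_forward_splits(100, 3, 20):
    print(f"train {train[0]}..{train[-1]}, test {test[0]}..{test[-1]}")
```

Passing train_size=30 to the same function keeps the training window at a fixed length that slides forward with each test block.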

Parameters

n_splits : int, default=5

Number of splits to generate.

test_size : int, float, str, or None, optional

Size of each test set:
- If int: number of samples (e.g., 1000)
- If float: proportion of dataset (e.g., 0.1)
- If str: time period using pandas offset aliases (e.g., "4W", "30D", "3M")
- If None: uses 1 / (n_splits + 1)

Time-based specifications require X to have a DatetimeIndex.

train_size : int, float, str, or None, optional

Size of each training set:
- If int: number of samples (e.g., 10000)
- If float: proportion of dataset (e.g., 0.5)
- If str: time period using pandas offset aliases (e.g., "78W", "6M", "2Y")
- If None: uses all available data before test set

Time-based specifications require X to have a DatetimeIndex.

gap : int, default=0

Gap between training and test set (in addition to label_horizon).

label_horizon : int or pd.Timedelta, default=0

How far ahead labels look into the future. Removes training samples whose prediction targets overlap with validation/test data.

Example: If predicting 5-day forward returns, a training sample at day 95 has a label computed from prices on days 95-100. If validation starts at day 98, this training sample's label "sees" validation data, creating leakage. Setting label_horizon=5 removes training samples from days 93-97.

embargo_size : int or pd.Timedelta, optional

Buffer zone after test periods. For standard walk-forward CV where training always precedes test, this has no effect. It is primarily used by CombinatorialCV.

embargo_pct : float, optional

Embargo size as percentage of total samples.

expanding : bool, default=True

If True, training window expands with each split. If False, uses fixed-size rolling window.

consecutive : bool, default=False

If True, uses consecutive (back-to-back) test periods with no gaps. This is appropriate for walk-forward validation where you want to simulate realistic trading with sequential validation periods. If False, spreads test periods across the dataset to sample different time periods (useful for testing robustness across market regimes).

calendar : str, CalendarConfig, or TradingCalendar, optional

Trading calendar for calendar-aware time period calculations.
- If str: Name of pandas_market_calendars calendar (e.g., 'CME_Equity', 'NYSE'). Creates default CalendarConfig with UTC timezone
- If CalendarConfig: Full configuration with exchange, timezone, and options
- If TradingCalendar: Pre-configured calendar instance
- If None: Uses naive time-based calculation (backward compatible)

For intraday data with time-based test_size/train_size (e.g., '4W'), using a calendar ensures proper session-aware splitting:
- Trading sessions are atomic units (won't split Sunday 5pm - Friday 4pm)
- Handles varying data density in activity-based data (dollar bars, trade bars)
- Proper timezone handling for tz-naive and tz-aware data
- '1D' selections: Complete trading sessions
- '4W' selections: Complete trading weeks (e.g., 4 weeks of 5 sessions each)

Examples:

>>> from ml4t.diagnostic.splitters.calendar_config import CME_CONFIG
>>> cv = WalkForwardCV(test_size='4W', calendar=CME_CONFIG)  # CME futures
>>> cv = WalkForwardCV(test_size='1W', calendar='NYSE')  # US equities (simple)

align_to_sessions : bool, default=False

If True, align fold boundaries to trading session boundaries. Requires X to have a session column (specified by session_col parameter).

Trading sessions should be assigned using the qdata library before cross-validation:
- Use DataManager with exchange/calendar parameters, or
- Use SessionAssigner.from_exchange('CME') directly

When enabled, fold boundaries will never split a trading session, preventing subtle lookahead bias in intraday strategies.

session_col : str, default='session_date'

Name of the column containing session identifiers. Only used if align_to_sessions=True. This column should be added by qdata.sessions.SessionAssigner.

isolate_groups : bool, default=False

If True, prevent the same group (asset/symbol) from appearing in both train and test sets. This is critical for multi-asset validation to avoid data leakage.

Requires passing the groups parameter (asset IDs) to the split() method.

Example:

>>> cv = WalkForwardCV(n_splits=5, isolate_groups=True)
>>> for train, test in cv.split(df, groups=df['symbol']):
...     # train and test will have completely different symbols
...     pass

Attributes:

n_splits_ : int The number of splits.

Examples:

>>> import numpy as np
>>> from ml4t.diagnostic.splitters import WalkForwardCV
>>> X = np.arange(100).reshape(100, 1)
>>> cv = WalkForwardCV(n_splits=3, label_horizon=5, embargo_size=2)
>>> for train, test in cv.split(X):
...     print(f"Train: {len(train)}, Test: {len(test)}")
Train: 17, Test: 25
Train: 40, Test: 25
Train: 63, Test: 25

Initialize WalkForwardCV.

This splitter uses a config-first architecture. You can either:
1. Pass a config object: WalkForwardCV(config=my_config)
2. Pass individual parameters: WalkForwardCV(n_splits=5, test_size=100)

Parameters are automatically converted to a config object internally, ensuring a single source of truth for all validation and logic.

Examples

>>> # Approach 1: Direct parameters (convenient)
>>> cv = WalkForwardCV(n_splits=5, test_size=100)
>>>
>>> # Approach 2: Config object (for serialization/reproducibility)
>>> from ml4t.diagnostic.splitters.config import WalkForwardConfig
>>> config = WalkForwardConfig(n_splits=5, test_size=100)
>>> cv = WalkForwardCV(config=config)
>>>
>>> # Approach 3: With held-out test period
>>> cv = WalkForwardCV(
...     n_splits=5,
...     test_period="52D",      # Reserve most recent 52 days for final evaluation
...     test_size=20,           # 20-day validation folds
...     train_size=252,         # 1-year training windows
...     label_horizon=5,        # 5 trading days gap
...     calendar="NYSE",        # NYSE trading calendar
...     fold_direction="backward",  # Folds step backward from test
... )
>>>
>>> # Validation folds (step backward from held-out test)
>>> for train_idx, val_idx in cv.split(X):
...     model.fit(X.iloc[train_idx], y.iloc[train_idx])
>>> # Final evaluation on held-out test
>>> test_score = model.score(X.iloc[cv.test_indices_], y.iloc[cv.test_indices_])

Source code in src/ml4t/diagnostic/splitters/walk_forward.py
def __init__(
    self,
    config: WalkForwardConfig | None = None,
    *,
    n_splits: int = 5,
    test_size: float | None = None,
    train_size: float | None = None,
    gap: int = 0,
    label_horizon: int | pd.Timedelta = 0,
    embargo_size: int | pd.Timedelta | None = None,
    embargo_pct: float | None = None,
    expanding: bool = True,
    consecutive: bool = False,
    calendar: str | CalendarConfig | TradingCalendar | None = None,
    align_to_sessions: bool = False,
    session_col: str = "session_date",
    timestamp_col: str | None = None,
    isolate_groups: bool = False,
    # New parameters for held-out test
    test_period: int | str | None = None,
    test_start: date | str | None = None,
    test_end: date | str | None = None,
    fold_direction: Literal["forward", "backward"] = "forward",
) -> None:
    """Initialize WalkForwardCV.

    This splitter uses a config-first architecture. You can either:
    1. Pass a config object: WalkForwardCV(config=my_config)
    2. Pass individual parameters: WalkForwardCV(n_splits=5, test_size=100)

    Parameters are automatically converted to a config object internally,
    ensuring a single source of truth for all validation and logic.

    Examples
    --------
    >>> # Approach 1: Direct parameters (convenient)
    >>> cv = WalkForwardCV(n_splits=5, test_size=100)
    >>>
    >>> # Approach 2: Config object (for serialization/reproducibility)
    >>> from ml4t.diagnostic.splitters.config import WalkForwardConfig
    >>> config = WalkForwardConfig(n_splits=5, test_size=100)
    >>> cv = WalkForwardCV(config=config)
    >>>
    >>> # Approach 3: With held-out test period
    >>> cv = WalkForwardCV(
    ...     n_splits=5,
    ...     test_period="52D",      # Reserve most recent 52 days for final evaluation
    ...     test_size=20,           # 20-day validation folds
    ...     train_size=252,         # 1-year training windows
    ...     label_horizon=5,        # 5 trading days gap
    ...     calendar="NYSE",        # NYSE trading calendar
    ...     fold_direction="backward",  # Folds step backward from test
    ... )
    >>>
    >>> # Validation folds (step backward from held-out test)
    >>> for train_idx, val_idx in cv.split(X):
    ...     model.fit(X.iloc[train_idx], y.iloc[train_idx])
    >>> # Final evaluation on held-out test
    >>> test_score = model.score(X.iloc[cv.test_indices_], y.iloc[cv.test_indices_])
    """
    # Config-first: either use provided config or create from params
    if config is not None:
        # Explicit config provided
        # Verify no conflicting parameters were passed
        non_default_params = []
        if n_splits != 5:
            non_default_params.append("n_splits")
        if test_size is not None:
            non_default_params.append("test_size")
        if train_size is not None:
            non_default_params.append("train_size")
        if gap != 0:
            non_default_params.append("gap")
        if label_horizon != 0:
            non_default_params.append("label_horizon")
        if embargo_size is not None:
            non_default_params.append("embargo_size")
        if embargo_pct is not None:
            non_default_params.append("embargo_pct")
        if not expanding:
            non_default_params.append("expanding")
        if consecutive:
            non_default_params.append("consecutive")
        if calendar is not None:
            non_default_params.append("calendar")
        if align_to_sessions:
            non_default_params.append("align_to_sessions")
        if session_col != "session_date":
            non_default_params.append("session_col")
        if timestamp_col is not None:
            non_default_params.append("timestamp_col")
        if isolate_groups:
            non_default_params.append("isolate_groups")
        if test_period is not None:
            non_default_params.append("test_period")
        if test_start is not None:
            non_default_params.append("test_start")
        if test_end is not None:
            non_default_params.append("test_end")
        if fold_direction != "forward":
            non_default_params.append("fold_direction")

        if non_default_params:
            raise ValueError(
                f"Cannot specify both 'config' and individual parameters. "
                f"Got config plus: {', '.join(non_default_params)}"
            )

        self.config = config
    else:
        # Create config from individual parameters
        # Note: embargo_size maps to embargo_td in config
        # Determine calendar_id: explicit str overrides default,
        # CalendarConfig/TradingCalendar handled separately,
        # None means use config default ("NYSE")
        config_kwargs: dict[str, Any] = {
            "n_splits": n_splits,
            "test_size": test_size,
            "train_size": train_size,
            "label_horizon": label_horizon,
            "embargo_td": embargo_size,
            "align_to_sessions": align_to_sessions,
            "session_col": session_col,
            "timestamp_col": timestamp_col,
            "isolate_groups": isolate_groups,
            "test_period": test_period,
            "test_start": test_start,
            "test_end": test_end,
            "fold_direction": fold_direction,
        }
        if isinstance(calendar, str):
            config_kwargs["calendar_id"] = calendar
        elif calendar is not None:
            # CalendarConfig or TradingCalendar — extract exchange name
            if isinstance(calendar, CalendarConfig):
                config_kwargs["calendar_id"] = calendar.exchange
            elif isinstance(calendar, TradingCalendar):
                config_kwargs["calendar_id"] = calendar.config.exchange
        # When calendar is None, omit calendar_id to use config default ("NYSE")

        self.config = WalkForwardConfig(**config_kwargs)

    # Handle calendar initialization
    # Use calendar_id from config if no calendar parameter provided
    effective_calendar = calendar
    if effective_calendar is None and self.config.calendar_id is not None:
        effective_calendar = self.config.calendar_id

    if effective_calendar is None:
        self.calendar = None
    elif isinstance(effective_calendar, str | CalendarConfig):
        self.calendar = TradingCalendar(effective_calendar)
    elif isinstance(effective_calendar, TradingCalendar):
        self.calendar = effective_calendar
    else:
        raise TypeError(
            f"calendar must be str, CalendarConfig, TradingCalendar, or None, got {type(effective_calendar)}"
        )

    # Legacy attributes for compatibility with existing split() implementation
    # These reference the config values
    self.gap = gap
    self.embargo_pct = embargo_pct
    self.expanding = expanding
    self.consecutive = consecutive

    # Private state for held-out test (populated after split() is called)
    self._test_indices: np.ndarray | None = None
    self._test_start_idx: int | None = None
    self._test_end_idx: int | None = None
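
The config-or-parameters check in the constructor above can be illustrated in isolation. The following is a minimal sketch, not the library's code: `DemoConfig` and `DemoSplitter` are hypothetical names, and only two parameters are modeled.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DemoConfig:
    """Hypothetical stand-in for a splitter config (not WalkForwardConfig)."""

    n_splits: int = 5
    test_size: Optional[int] = None


class DemoSplitter:
    """Hypothetical splitter that accepts a config or keyword parameters."""

    def __init__(
        self,
        config: Optional[DemoConfig] = None,
        *,
        n_splits: int = 5,
        test_size: Optional[int] = None,
    ) -> None:
        if config is not None:
            # Collect keyword parameters that differ from their defaults.
            conflicts = []
            if n_splits != 5:
                conflicts.append("n_splits")
            if test_size is not None:
                conflicts.append("test_size")
            if conflicts:
                raise ValueError(
                    "Cannot specify both 'config' and individual parameters. "
                    f"Got config plus: {', '.join(conflicts)}"
                )
            self.config = config
        else:
            # Build the config from keyword parameters.
            self.config = DemoConfig(n_splits=n_splits, test_size=test_size)


from_params = DemoSplitter(n_splits=3, test_size=100)
from_config = DemoSplitter(config=DemoConfig(n_splits=3, test_size=100))
same_config = from_params.config == from_config.config

try:
    DemoSplitter(config=DemoConfig(), n_splits=3)
    conflict_message = ""
except ValueError as exc:
    conflict_message = str(exc)
```

A check based on default values cannot detect a parameter that is passed explicitly with its default value (for example `n_splits=5` together with a config), so such a call is accepted silently.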

n_splits property

n_splits

Number of cross-validation folds.

test_size property

test_size

Test set size specification.

train_size property

train_size

Training set size specification.

label_horizon property

label_horizon

Forward-looking period of labels.

embargo_size property

embargo_size

Embargo buffer size.

align_to_sessions property

align_to_sessions

Whether to align fold boundaries to sessions.

session_col property

session_col

Column name containing session identifiers.

timestamp_col property

timestamp_col

Column name containing timestamps for time-based sizes.

isolate_groups property

isolate_groups

Whether to prevent group overlap between train/test.

test_period property

test_period

Held-out test period specification.

test_start_date property

test_start_date

Explicit start date for held-out test period.

test_end_date property

test_end_date

Explicit end date for held-out test period.

fold_direction property

fold_direction

Direction of validation folds.

calendar_id property

calendar_id

Trading calendar identifier.

test_indices_ property

test_indices_

Held-out test indices (populated after split() is called).

Returns

ndarray
    Indices reserved for the held-out test period.

Raises

ValueError
    If no held-out test is configured or split() hasn't been called.

Examples

>>> cv = WalkForwardCV(n_splits=5, test_period="52D")
>>> for train_idx, val_idx in cv.split(X):
...     pass  # Training loop
>>> # Now test_indices_ is available
>>> final_score = model.score(X.iloc[cv.test_indices_], y.iloc[cv.test_indices_])

fold_summary_ property

fold_summary_

Per-fold summary of train/val boundaries, sizes, and buffer gaps.

Available after split() has been fully consumed (all folds yielded).

Returns

pd.DataFrame
    One row per fold with columns: fold, train_start, train_end, val_start, val_end, train_rows, val_rows, train_timestamps, val_timestamps, train_span, val_span, buffer_gap_timestamps, buffer_gap_duration.

Raises

ValueError
    If split() has not been called yet.

Examples

>>> cv = WalkForwardCV(n_splits=5, label_horizon=21)
>>> for train_idx, val_idx in cv.split(X):
...     pass
>>> print(cv.fold_summary_)
>>> cv.fold_summary_to_csv("folds.csv")

fold_summary_to_csv

fold_summary_to_csv(path)

Write fold summary to CSV for verification.

Parameters

path : str
    Output CSV file path.

Source code in src/ml4t/diagnostic/splitters/walk_forward.py
def fold_summary_to_csv(self, path: str) -> None:
    """Write fold summary to CSV for verification.

    Parameters
    ----------
    path : str
        Output CSV file path.
    """
    self.fold_summary_.to_csv(path, index=False)
    logger.info("Fold summary written to %s", path)
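
To make the shape of a per-fold summary concrete, here is a rough sketch that builds a reduced summary table from index ranges and writes it as CSV using only the standard library. It is not the library's implementation: the fold boundaries are made up, only a subset of the documented columns is computed, and `buffer_gap_rows` is a hypothetical column name used for illustration.

```python
import csv
import io

# Two illustrative folds as (train, val) index ranges.
folds = [(range(0, 60), range(65, 80)), (range(0, 75), range(80, 95))]

rows = []
for i, (train, val) in enumerate(folds):
    rows.append(
        {
            "fold": i,
            "train_start": train[0],
            "train_end": train[-1],
            "val_start": val[0],
            "val_end": val[-1],
            "train_rows": len(train),
            "val_rows": len(val),
            # Samples skipped between the last training row and the first validation row.
            "buffer_gap_rows": val[0] - train[-1] - 1,
        }
    )

# Write the summary to an in-memory CSV buffer.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(rows[0]))
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
print(csv_text)
```

Inspecting the gap column fold by fold is a quick way to confirm that the configured buffer was actually applied.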

get_n_splits

get_n_splits(X=None, y=None, groups=None)

Get number of splits.

Parameters

X : array-like, optional
    Always ignored, exists for compatibility.

y : array-like, optional
    Always ignored, exists for compatibility.

groups : array-like, optional
    Always ignored, exists for compatibility.

Returns

n_splits : int
    Number of splits.

Source code in src/ml4t/diagnostic/splitters/walk_forward.py
def get_n_splits(
    self,
    X: pl.DataFrame | pd.DataFrame | NDArray[Any] | None = None,
    y: pl.Series | pd.Series | NDArray[Any] | None = None,
    groups: pl.Series | pd.Series | NDArray[Any] | None = None,
) -> int:
    """Get number of splits.

    Parameters
    ----------
    X : array-like, optional
        Always ignored, exists for compatibility.

    y : array-like, optional
        Always ignored, exists for compatibility.

    groups : array-like, optional
        Always ignored, exists for compatibility.

    Returns:
    -------
    n_splits : int
        Number of splits.
    """
    del X, y, groups  # Unused, for sklearn compatibility
    return self.n_splits

split

split(X, y=None, groups=None)

Generate train/validation indices for walk-forward splits.

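When a held-out test period is configured (test_period or test_start), this method yields train/validation splits for cross-validation, and the held-out test indices are accessible via the test_indices_ property. Because split() is a generator, those indices are populated only once iteration has started, not when split() is merely called.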

Parameters

X : array-like of shape (n_samples, n_features)
    Training data.

y : array-like of shape (n_samples,), optional
    Target variable.

groups : array-like of shape (n_samples,), optional
    Group labels for samples.

Yields

train : ndarray
    Training set indices for this split.

val : ndarray
    Validation set indices for this split (or test if no held-out test).

Notes

When using held-out test mode with fold_direction="backward":

[train1][val1][train2][val2][train3][val3] | [HELD-OUT TEST]
        ←     ←     ←     ←     ←     ←     test_start

Validation folds step backward from the test boundary, ensuring that all validation is done on data chronologically before the held-out test.
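
The backward layout can be sketched with plain index arithmetic. This is an illustration under stated assumptions, not the library's algorithm: `backward_folds` is a hypothetical helper, sizes are counted in samples, validation windows tile backward from the test boundary, and each training window is the block immediately before its validation window minus a fixed gap.

```python
def backward_folds(test_start, n_splits, val_size, train_size, gap=0):
    """Return (train, val) index ranges stepping backward from test_start.

    Illustrative only: fold 0 is the fold closest to the held-out test.
    """
    folds = []
    val_end = test_start  # exclusive end of the first validation window
    for _ in range(n_splits):
        val_start = val_end - val_size
        train_end = val_start - gap  # leave `gap` samples unused before validation
        train_start = train_end - train_size
        if train_start < 0:
            break  # not enough history for another full fold
        folds.append((range(train_start, train_end), range(val_start, val_end)))
        val_end = val_start  # the next validation window sits just before this one
    return folds


n_samples = 500
test_start = 450  # indices 450..499 are reserved for the held-out test
folds = backward_folds(test_start, n_splits=3, val_size=20, train_size=100, gap=5)
test_indices = range(test_start, n_samples)

for i, (train, val) in enumerate(folds):
    print(i, (train[0], train[-1]), (val[0], val[-1]))
```

In this sketch every validation index is strictly smaller than `test_start`, which is the property the diagram above describes.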

Source code in src/ml4t/diagnostic/splitters/walk_forward.py
def split(
    self,
    X: pl.DataFrame | pd.DataFrame | NDArray[Any],
    y: pl.Series | pd.Series | NDArray[Any] | None = None,
    groups: pl.Series | pd.Series | NDArray[Any] | None = None,
) -> Generator[tuple[NDArray[np.intp], NDArray[np.intp]], None, None]:
    """Generate train/validation indices for walk-forward splits.

    When a held-out test period is configured (test_period or test_start),
    this method yields train/validation splits for cross-validation, and
    the held-out test indices are accessible via test_indices_ property.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training data.

    y : array-like of shape (n_samples,), optional
        Target variable.

    groups : array-like of shape (n_samples,), optional
        Group labels for samples.

    Yields:
    ------
    train : ndarray
        Training set indices for this split.

    val : ndarray
        Validation set indices for this split (or test if no held-out test).

    Notes
    -----
    When using held-out test mode with fold_direction="backward":

    ```
    [train1][val1][train2][val2][train3][val3] | [HELD-OUT TEST]
            ←     ←     ←     ←     ←     ←     test_start
    ```

    Validation folds step backward from the test boundary, ensuring that
    all validation is done on data chronologically before the held-out test.
    """
    # Validate inputs and get sample count
    n_samples = self._validate_data(X, y, groups)

    # Validate session alignment if enabled
    self._validate_session_alignment(X, self.align_to_sessions, self.session_col)

    # Extract timestamps for held-out test computation
    timestamps = self._extract_timestamps(X, self.timestamp_col)

    # Initialize fold summary tracking
    self._fold_records: list[dict[str, Any]] = []
    self._split_timestamps = timestamps

    # Select the appropriate splitting generator
    if self._has_held_out_test():
        test_start_idx, test_end_idx = self._compute_test_period(n_samples, timestamps)
        self._test_start_idx = test_start_idx
        self._test_end_idx = test_end_idx
        self._test_indices = np.arange(test_start_idx, test_end_idx, dtype=np.intp)

        if self.fold_direction == "backward":
            inner = self._split_backward(X, y, groups, n_samples, test_start_idx, timestamps)
        else:
            inner = self._split_forward_with_test(
                X, y, groups, n_samples, test_start_idx, timestamps
            )
    else:
        self._test_indices = None
        self._test_start_idx = None
        self._test_end_idx = None

        has_timestamps = self._extract_timestamps(X, self.timestamp_col) is not None
        if self._should_use_calendar_splitting() and has_timestamps:
            if self.align_to_sessions:
                warnings.warn(
                    "align_to_sessions=True is deprecated when calendar is active. "
                    "Calendar-first splitting subsumes session alignment.",
                    DeprecationWarning,
                    stacklevel=2,
                )
            inner = self._split_by_calendar(X, y, groups, n_samples)
        elif self.align_to_sessions:
            inner = self._split_by_sessions(
                cast(pl.DataFrame | pd.DataFrame, X), y, groups, n_samples
            )
        else:
            inner = self._split_by_samples(X, y, groups, n_samples)

    # Yield from inner generator while recording fold metadata
    for fold_i, (train_idx, val_idx) in enumerate(inner):
        self._record_fold(fold_i, train_idx, val_idx, timestamps)
        yield train_idx, val_idx

WalkForwardConfig

Bases: SplitterConfig

Configuration for Walk-Forward Cross-Validation.

Walk-forward validation is the standard approach for time-series backtesting, where the model is trained on historical data and tested on future periods.
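
As a rough illustration of the idea, the sketch below generates forward-stepping folds with either an expanding or a fixed-length training window. It is not the library's implementation: `walk_forward` is a hypothetical helper, sizes are counted in samples, the step between folds equals the test size, and the test windows are assumed to tile the end of the dataset.

```python
def walk_forward(n_samples, n_splits, test_size, train_size=None):
    """Yield (train, test) index ranges; train_size=None means an expanding window."""
    first_test_start = n_samples - n_splits * test_size
    for i in range(n_splits):
        test_start = first_test_start + i * test_size
        if train_size is None:
            train_start = 0  # expanding: use all data before the test window
        else:
            train_start = max(0, test_start - train_size)  # rolling window
        yield range(train_start, test_start), range(test_start, test_start + test_size)


expanding = list(walk_forward(100, n_splits=3, test_size=10))
rolling = list(walk_forward(100, n_splits=3, test_size=10, train_size=30))

print([len(train) for train, _ in expanding])
print([len(train) for train, _ in rolling])
```

In both variants each training range ends where its test range begins, so the model is always fit on data that precedes the period it is evaluated on.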

Attributes

test_size : int | float | str | None
    Size of validation folds. Alias: val_size.
    - int: Number of samples (or sessions if align_to_sessions=True)
    - float: Proportion of dataset (0.0 to 1.0)
    - str: Time-based ('4W', '3M') - NOT supported with align_to_sessions=True
    - None: Auto-calculated to maintain equal test set sizes

train_size : int | float | str | None
    Training set size specification (same format as test_size).
    If None, uses expanding window (all data before test set).

step_size : int | None
    Step size between consecutive splits:
    - int: Number of samples (or sessions if align_to_sessions=True)
    - None: Defaults to test_size (non-overlapping test sets)

test_period : int | str | None
    Held-out test period specification (reserves most recent data for final evaluation):
    - int: Number of trading days when calendar_id is set. Without calendar_id, a UserWarning is issued and the value is interpreted as calendar days.
    - str: Time-based ('52D', '4W')
    - None: No held-out test period (default, legacy behavior)

test_start : date | str | None
    Explicit start date for held-out test period. Mutually exclusive with test_period.
    Accepts date object or ISO format string ('2024-01-01'). Alias: holdout_start.

test_end : date | str | None
    Explicit end date for held-out test period. Default: end of data.
    Accepts date object or ISO format string ('2024-12-31'). Alias: holdout_end.

fold_direction : Literal["forward", "backward"]
    Direction of validation folds:
    - "forward": Traditional walk-forward (folds step forward in time)
    - "backward": Folds step backward from held-out test boundary

calendar_id : str | None
    Trading calendar for trading-day-aware gap calculations.
    Examples: "NYSE", "CME_Equity", "LSE".
    Required when label_horizon is int and you want trading-day interpretation.

val_size property

val_size

Alias for test_size.

holdout_start property

holdout_start

Alias for test_start.

holdout_end property

holdout_end

Alias for test_end.

validate_size_with_sessions classmethod

validate_size_with_sessions(v, info)

Validate that time-based sizes are not used with session alignment.

Source code in src/ml4t/diagnostic/splitters/config.py
@field_validator("test_size", "train_size")
@classmethod
def validate_size_with_sessions(
    cls, v: int | float | str | None, info
) -> int | float | str | None:
    """Validate that time-based sizes are not used with session alignment."""
    if v is None:
        return v

    align_to_sessions = info.data.get("align_to_sessions", False)
    if align_to_sessions and isinstance(v, str):
        raise ValueError(
            f"align_to_sessions=True does not support time-based size specifications. "
            f"Use integer (number of sessions) or float (proportion). Got: {v!r}"
        )
    return v

validate_test_dates classmethod

validate_test_dates(v)

Convert string dates to date objects.

Source code in src/ml4t/diagnostic/splitters/config.py
@field_validator("test_start", "test_end")
@classmethod
def validate_test_dates(cls, v: date | str | None) -> date | None:
    """Convert string dates to date objects."""
    if v is None:
        return v
    if isinstance(v, date):
        return v
    if isinstance(v, str):
        try:
            return date.fromisoformat(v)
        except ValueError as e:
            raise ValueError(
                f"Could not parse date string '{v}'. Use ISO format: 'YYYY-MM-DD'"
            ) from e
    raise ValueError(f"test_start/test_end must be date or ISO string, got {type(v)}")
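
The validator above delegates string parsing to `date.fromisoformat` from the standard library. The short sketch below shows which inputs that call accepts; it exercises only the standard library, not the config class itself.

```python
from datetime import date, datetime

# An ISO 8601 calendar date parses directly.
parsed = date.fromisoformat("2024-01-01")

# A slash-separated date is not ISO format and raises ValueError.
try:
    date.fromisoformat("01/02/2024")
    slash_format_ok = True
except ValueError:
    slash_format_ok = False

# datetime is a subclass of date, so an isinstance(v, date) check also
# lets datetime objects through unchanged.
datetime_is_date = isinstance(datetime(2024, 1, 1, 9, 30), date)

print(parsed, slash_format_ok, datetime_is_date)
```

Note that Python 3.11 and later accept a wider range of ISO 8601 spellings in `date.fromisoformat` than earlier versions, so the exact set of accepted strings depends on the interpreter.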

validate_test_period classmethod

validate_test_period(v, info)

Validate test_period specification.

Source code in src/ml4t/diagnostic/splitters/config.py
@field_validator("test_period")
@classmethod
def validate_test_period(cls, v: int | str | None, info) -> int | str | None:
    """Validate test_period specification."""
    if v is None:
        return v

    if isinstance(v, int):
        if v <= 0:
            raise ValueError("test_period must be a positive integer (trading days)")
        return v

    if isinstance(v, str):
        # Validate time-based format (e.g., "52D", "4W")
        import pandas as pd

        try:
            pd.Timedelta(v)
        except Exception as e:
            raise ValueError(
                f"Could not parse test_period string '{v}' as Timedelta. "
                f"Use formats like '52D', '4W', '3M'. Error: {e}"
            ) from e
        return v

    raise ValueError(f"test_period must be int or str, got {type(v)}")

validate_calendar_and_sessions

validate_calendar_and_sessions()

Warn when align_to_sessions is used alongside calendar_id.

Source code in src/ml4t/diagnostic/splitters/config.py
@model_validator(mode="after")
def validate_calendar_and_sessions(self) -> WalkForwardConfig:
    """Warn when align_to_sessions is used alongside calendar_id."""
    if self.align_to_sessions and self.calendar_id is not None:
        warnings.warn(
            "align_to_sessions=True is deprecated when calendar_id is set. "
            "Calendar-first splitting subsumes session alignment. "
            "Remove align_to_sessions=True to silence this warning.",
            DeprecationWarning,
            stacklevel=2,
        )
    return self

validate_held_out_test_config

validate_held_out_test_config()

Validate held-out test configuration consistency.

Source code in src/ml4t/diagnostic/splitters/config.py
@model_validator(mode="after")
def validate_held_out_test_config(self) -> WalkForwardConfig:
    """Validate held-out test configuration consistency."""
    # test_period and test_start are mutually exclusive
    if self.test_period is not None and self.test_start is not None:
        raise ValueError(
            "Cannot specify both 'test_period' and 'test_start'. "
            "'test_period' reserves most recent data, "
            "'test_start' specifies an explicit date range."
        )

    # test_end without test_start or test_period is invalid
    if self.test_end is not None and self.test_start is None and self.test_period is None:
        raise ValueError(
            "'test_end' requires either 'test_period' or 'test_start' to define the held-out test."
        )

    # test_period as int requires calendar_id for trading-day interpretation
    if isinstance(self.test_period, int) and self.calendar_id is None:
        warnings.warn(
            f"test_period={self.test_period} (int) without calendar_id will be interpreted "
            "as calendar days, not trading days. Set calendar_id for trading-day interpretation.",
            UserWarning,
            stacklevel=2,
        )

    # label_horizon as int with calendar_id should use trading days
    if (
        isinstance(self.label_horizon, int)
        and self.label_horizon > 0
        and self.calendar_id is not None
    ):
        # Valid configuration - label_horizon will be converted to trading days
        pass

    return self
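
The two hard rules enforced above can be summarized as a small truth table. The sketch below re-implements only those two checks in a hypothetical standalone function, `held_out_error`, to show which argument combinations pass; it does not construct a real WalkForwardConfig, and the messages are shortened.

```python
def held_out_error(test_period=None, test_start=None, test_end=None):
    """Return a message for an invalid combination, or None when it is valid."""
    if test_period is not None and test_start is not None:
        return "test_period and test_start are mutually exclusive"
    if test_end is not None and test_start is None and test_period is None:
        return "test_end needs test_period or test_start"
    return None


results = {
    "period only": held_out_error(test_period="52D"),
    "start and end": held_out_error(test_start="2024-01-01", test_end="2024-12-31"),
    "period and end": held_out_error(test_period="52D", test_end="2024-12-31"),
    "period and start": held_out_error(test_period="52D", test_start="2024-01-01"),
    "end only": held_out_error(test_end="2024-12-31"),
}

for name, error in results.items():
    print(f"{name:>16}: {'ok' if error is None else error}")
```

As in the validator, combining test_period with test_end is not rejected; only test_period with test_start, and test_end on its own, are errors.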

SplitterConfig

Bases: BaseConfig

Base configuration for all cross-validation splitters.

All splitter configs inherit from this class to ensure consistent serialization, validation, and reproducibility.

Attributes

n_splits : int
    Number of cross-validation folds.

label_horizon : int or pd.Timedelta
    Gap between train_end and val_start sized to the label horizon. Removes training samples whose prediction targets overlap with validation/test data ("label buffer").

    Example: If predicting 5-day forward returns, a training sample at day 95 has a label computed from days 95-100. If the test set starts at day 98, this training sample's label "sees" test data, creating leakage. Setting label_horizon=5 removes training samples from days 93-97.

    Aliases: label_buffer is accepted as an equivalent input name.

embargo_td : int or pd.Timedelta or None
    Buffer zone after test periods where training samples are also excluded ("feature buffer"). Prevents autocorrelation leakage in combinatorial CV where training data can follow test data chronologically.

    For standard walk-forward CV (training always before test), this has no effect.

    Aliases: feature_buffer is accepted as an equivalent input name.

align_to_sessions : bool
    If True, fold boundaries are aligned to trading session boundaries. Requires 'session_date' column in data (from ml4t.data.sessions.SessionAssigner).

session_col : str
    Column name containing session identifiers. Default: 'session_date' (standard qdata column name).

isolate_groups : bool
    If True, ensures no overlap between train/test group identifiers. Useful for multi-asset validation to prevent data leakage.
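
The arithmetic in the label_horizon example (5-day labels, test set starting at day 98) can be checked with a few lines of plain Python. This is a sketch of the purging rule as described in that example, with days represented as integers; it is not the library's code.

```python
label_horizon = 5
val_start = 98  # first day of the validation/test window
candidate_train_days = range(80, val_start)

# A sample at day t has a label computed from days t .. t + label_horizon.
# It leaks if that span reaches into the validation window.
kept = [t for t in candidate_train_days if t + label_horizon < val_start]
removed = [t for t in candidate_train_days if t + label_horizon >= val_start]

print("last kept day:", kept[-1])
print("removed days:", removed)
```

The removed days are 93 through 97, matching the example, and the last training day that remains is 92, whose label ends at day 97.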

label_buffer property

label_buffer

Alias for label_horizon (preferred name).

feature_buffer property

feature_buffer

Alias for embargo_td (preferred name).

validate_label_horizon classmethod

validate_label_horizon(v)

Validate label_horizon is either int >= 0 or a timedelta-like object.

Source code in src/ml4t/diagnostic/splitters/config.py
@field_validator("label_horizon")
@classmethod
def validate_label_horizon(cls, v: Any) -> Any:
    """Validate label_horizon is either int >= 0 or a timedelta-like object."""
    if isinstance(v, int):
        if v < 0:
            raise ValueError("label_horizon must be greater than or equal to 0")
        return v
    # Allow timedelta-like objects (pd.Timedelta, datetime.timedelta)
    if hasattr(v, "total_seconds"):
        return v
    # Handle ISO 8601 duration strings from JSON serialization
    if isinstance(v, str):
        import pandas as pd

        try:
            return pd.Timedelta(v)
        except Exception as e:
            raise ValueError(  # noqa: B904
                f"Could not parse label_horizon/label_buffer string '{v}' as Timedelta. "
                f"Expected formats: '5D', '21D', '1W', '8h'. Error: {e}"
            )
    raise ValueError(f"label_horizon must be int >= 0 or timedelta-like object, got {type(v)}")

validate_embargo_td classmethod

validate_embargo_td(v)

Validate embargo_td is either None, int >= 0, or a timedelta-like object.

Source code in src/ml4t/diagnostic/splitters/config.py
@field_validator("embargo_td")
@classmethod
def validate_embargo_td(cls, v: Any) -> Any:
    """Validate embargo_td is either None, int >= 0, or a timedelta-like object."""
    if v is None:
        return v
    if isinstance(v, int):
        if v < 0:
            raise ValueError("embargo_td must be greater than or equal to 0")
        return v
    # Allow timedelta-like objects (pd.Timedelta, datetime.timedelta)
    if hasattr(v, "total_seconds"):
        return v
    # Handle ISO 8601 duration strings from JSON serialization
    if isinstance(v, str):
        import pandas as pd

        try:
            return pd.Timedelta(v)
        except Exception as e:
            raise ValueError(  # noqa: B904
                f"Could not parse embargo_td/feature_buffer string '{v}' as Timedelta. "
                f"Expected formats: '5D', '1W', '0D'. Error: {e}"
            )
    raise ValueError(
        f"embargo_td must be None, int >= 0, or timedelta-like object, got {type(v)}"
    )

save_config

save_config(config, filepath)

Save splitter configuration to disk.

This is a convenience wrapper around config.to_json() for consistency with the persistence API.

Parameters

config : SplitterConfig
    Configuration object to save.
filepath : str or Path
    Path to save configuration (JSON format).

Examples

>>> from ml4t.diagnostic.splitters.config import WalkForwardConfig
>>> config = WalkForwardConfig(n_splits=5, test_size=100)
>>> save_config(config, "cv_config.json")

Source code in src/ml4t/diagnostic/splitters/persistence.py
def save_config(
    config: Any,  # SplitterConfig or subclass
    filepath: str | Path,
) -> None:
    """Save splitter configuration to disk.

    This is a convenience wrapper around config.to_json() for consistency
    with the persistence API.

    Parameters
    ----------
    config : SplitterConfig
        Configuration object to save.
    filepath : str or Path
        Path to save configuration (JSON format).

    Examples
    --------
    >>> from ml4t.diagnostic.splitters.config import WalkForwardConfig
    >>> config = WalkForwardConfig(n_splits=5, test_size=100)
    >>> save_config(config, "cv_config.json")
    """
    filepath = Path(filepath)
    config.to_json(filepath)

load_config

load_config(filepath, config_class)

Load splitter configuration from disk.

This is a convenience wrapper around config_class.from_json() for consistency with the persistence API.

Parameters

filepath : str or Path
    Path to saved configuration (JSON format).
config_class : type
    Configuration class to instantiate (e.g., WalkForwardConfig).

Returns

config : SplitterConfig
    Loaded configuration object.

Examples

>>> from ml4t.diagnostic.splitters.config import WalkForwardConfig
>>> config = load_config("cv_config.json", WalkForwardConfig)
>>> print(config.n_splits)

Source code in src/ml4t/diagnostic/splitters/persistence.py
def load_config(
    filepath: str | Path,
    config_class: type[BaseConfig],
) -> BaseConfig:
    """Load splitter configuration from disk.

    This is a convenience wrapper around config_class.from_json() for consistency
    with the persistence API.

    Parameters
    ----------
    filepath : str or Path
        Path to saved configuration (JSON format).
    config_class : type
        Configuration class to instantiate (e.g., WalkForwardConfig).

    Returns
    -------
    config : SplitterConfig
        Loaded configuration object.

    Examples
    --------
    >>> from ml4t.diagnostic.splitters.config import WalkForwardConfig
    >>> config = load_config("cv_config.json", WalkForwardConfig)
    >>> print(config.n_splits)
    """
    filepath = Path(filepath)
    return config_class.from_json(filepath)

save_folds

save_folds(
    folds,
    X,
    filepath,
    *,
    metadata=None,
    include_timestamps=True,
)

Save cross-validation folds to disk.

Parameters

folds : list[tuple[NDArray, NDArray]]
    List of (train_indices, test_indices) tuples from CV splitter.
X : array-like or DataFrame
    Original data used for splitting (for timestamp extraction if DataFrame).
filepath : str or Path
    Path to save fold configuration (JSON format).
metadata : dict, optional
    Additional metadata to store (e.g., splitter config, data info).
include_timestamps : bool, default=True
    If True and X is a DataFrame with DatetimeIndex, save timestamps alongside indices for better human readability.

Examples

>>> from ml4t.diagnostic.splitters import WalkForwardCV
>>> cv = WalkForwardCV(n_splits=5, test_size=100)
>>> folds = list(cv.split(X))
>>> save_folds(folds, X, "cv_folds.json", metadata={"n_splits": 5})

Source code in src/ml4t/diagnostic/splitters/persistence.py
def save_folds(
    folds: list[tuple[NDArray[np.int_], NDArray[np.int_]]],
    X: NDArray[np.floating] | pd.DataFrame | pl.DataFrame,
    filepath: str | Path,
    *,
    metadata: dict[str, Any] | None = None,
    include_timestamps: bool = True,
) -> None:
    """Save cross-validation folds to disk.

    Parameters
    ----------
    folds : list[tuple[NDArray, NDArray]]
        List of (train_indices, test_indices) tuples from CV splitter.
    X : array-like or DataFrame
        Original data used for splitting (for timestamp extraction if DataFrame).
    filepath : str or Path
        Path to save fold configuration (JSON format).
    metadata : dict, optional
        Additional metadata to store (e.g., splitter config, data info).
    include_timestamps : bool, default=True
        If True and X is a DataFrame with DatetimeIndex, save timestamps
        alongside indices for better human readability.

    Examples
    --------
    >>> from ml4t.diagnostic.splitters import WalkForwardCV
    >>> cv = WalkForwardCV(n_splits=5, test_size=100)
    >>> folds = list(cv.split(X))
    >>> save_folds(folds, X, "cv_folds.json", metadata={"n_splits": 5})
    """
    filepath = Path(filepath)

    # Extract timestamps if available
    timestamps = None
    if include_timestamps and isinstance(X, pd.DataFrame | pd.Series):
        if isinstance(X.index, pd.DatetimeIndex):
            timestamps = X.index.astype(str).tolist()
    elif include_timestamps and isinstance(X, pl.DataFrame):
        # Polars doesn't have index, check if first column is datetime
        first_col = X.columns[0]
        if X[first_col].dtype == pl.Datetime:
            timestamps = X[first_col].cast(pl.Utf8).to_list()

    # Build fold data structure
    fold_data: dict[str, Any] = {
        "version": "1.0",
        "n_folds": len(folds),
        "n_samples": len(X),
        "folds": [],
        "metadata": metadata or {},
    }

    if timestamps:
        fold_data["timestamps"] = timestamps

    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        fold_info = {
            "fold_id": fold_idx,
            "train_indices": train_idx.tolist(),
            "test_indices": test_idx.tolist(),
            "train_size": len(train_idx),
            "test_size": len(test_idx),
        }

        # Add timestamp ranges if available (handle empty folds)
        if timestamps:
            if len(train_idx) > 0:
                fold_info["train_start"] = timestamps[train_idx[0]]
                fold_info["train_end"] = timestamps[train_idx[-1]]
            else:
                fold_info["train_start"] = None
                fold_info["train_end"] = None

            if len(test_idx) > 0:
                fold_info["test_start"] = timestamps[test_idx[0]]
                fold_info["test_end"] = timestamps[test_idx[-1]]
            else:
                fold_info["test_start"] = None
                fold_info["test_end"] = None

        fold_data["folds"].append(fold_info)

    # Save to JSON
    filepath.parent.mkdir(parents=True, exist_ok=True)
    with filepath.open("w") as f:
        json.dump(fold_data, f, indent=2)

load_folds

load_folds(filepath)

Load cross-validation folds from disk.

Parameters

filepath : str or Path Path to saved fold configuration (JSON format).

Returns

folds : list[tuple[NDArray, NDArray]]
    List of (train_indices, test_indices) tuples.
metadata : dict
    Metadata dictionary stored with folds.

Examples

folds, metadata = load_folds("cv_folds.json")
print(f"Loaded {len(folds)} folds")
print(f"Metadata: {metadata}")
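
The fold file is plain JSON, so its layout can be inspected without the library. The sketch below is a minimal illustration, not the library's code path: it writes a two-fold file by hand using the keys visible in the save_folds source (version, n_folds, n_samples, folds, metadata) and reads it back with the standard json module. The file name, index values, and metadata are made up for illustration, and plain Python lists stand in for the NumPy arrays that load_folds returns.

```python
import json
import tempfile
from pathlib import Path

# Hand-written fold file using the keys shown in the save_folds source.
# All values are illustrative.
fold_data = {
    "version": "1.0",
    "n_folds": 2,
    "n_samples": 6,
    "folds": [
        {"fold_id": 0, "train_indices": [0, 1, 2], "test_indices": [3, 4],
         "train_size": 3, "test_size": 2},
        {"fold_id": 1, "train_indices": [0, 1, 2, 3], "test_indices": [4, 5],
         "train_size": 4, "test_size": 2},
    ],
    "metadata": {"note": "illustrative values only"},
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "cv_folds.json"
    path.write_text(json.dumps(fold_data, indent=2))
    loaded = json.loads(path.read_text())

# Same version gate that load_folds applies before reconstructing folds.
if loaded.get("version") != "1.0":
    raise ValueError(f"Unsupported fold file version: {loaded.get('version')}")

# Lists stand in for the NumPy index arrays returned by load_folds.
folds = [(f["train_indices"], f["test_indices"]) for f in loaded["folds"]]
metadata = loaded.get("metadata", {})
print(len(folds), metadata)
```

Because the version check is an exact string comparison, a file written with any other version value is rejected rather than partially parsed.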

Source code in src/ml4t/diagnostic/splitters/persistence.py
def load_folds(
    filepath: str | Path,
) -> tuple[list[tuple[NDArray[np.int_], NDArray[np.int_]]], dict[str, Any]]:
    """Load cross-validation folds from disk.

    Parameters
    ----------
    filepath : str or Path
        Path to saved fold configuration (JSON format).

    Returns
    -------
    folds : list[tuple[NDArray, NDArray]]
        List of (train_indices, test_indices) tuples.
    metadata : dict
        Metadata dictionary stored with folds.

    Examples
    --------
    >>> folds, metadata = load_folds("cv_folds.json")
    >>> print(f"Loaded {len(folds)} folds")
    >>> print(f"Metadata: {metadata}")
    """
    filepath = Path(filepath)

    if not filepath.exists():
        raise FileNotFoundError(f"Fold file not found: {filepath}")

    with filepath.open("r") as f:
        fold_data = json.load(f)

    # Validate version
    if fold_data.get("version") != "1.0":
        raise ValueError(f"Unsupported fold file version: {fold_data.get('version')}")

    # Reconstruct folds
    folds = []
    for fold_info in fold_data["folds"]:
        train_idx = np.array(fold_info["train_indices"], dtype=np.int_)
        test_idx = np.array(fold_info["test_indices"], dtype=np.int_)
        folds.append((train_idx, test_idx))

    metadata = fold_data.get("metadata", {})

    return folds, metadata

verify_folds

verify_folds(folds, n_samples)

Verify fold integrity and compute statistics.

Parameters

folds : list[tuple[NDArray, NDArray]]
    List of (train_indices, test_indices) tuples.
n_samples : int
    Total number of samples in dataset.

Returns

stats : dict
    Dictionary containing fold statistics and validation results.

Examples

folds, _ = load_folds("cv_folds.json")
stats = verify_folds(folds, n_samples=1000)
print(f"Valid: {stats['valid']}")
print(f"Coverage: {stats['coverage']:.1%}")
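
To make the integrity checks concrete, here is a stand-alone sketch of the three things verify_folds looks at: train/test overlap within a fold, indices outside [0, n_samples), and the share of samples touched by any fold. The helper name check_folds is hypothetical, it works on plain lists rather than NumPy arrays, and the guard against n_samples == 0 is an addition of this sketch rather than something the listed source does.

```python
def check_folds(folds, n_samples):
    """Stand-alone sketch of the overlap, range, and coverage checks."""
    errors = []
    seen = set()
    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        train, test = set(train_idx), set(test_idx)
        # A sample must not appear in both train and test of the same fold.
        if train & test:
            errors.append(f"Fold {fold_idx}: train/test overlap")
        # Every index must address a row of the dataset.
        if any(i < 0 or i >= n_samples for i in train | test):
            errors.append(f"Fold {fold_idx}: index out of range")
        seen |= train | test
    # Guard added for this sketch: avoid dividing by zero on empty datasets.
    coverage = len(seen) / n_samples if n_samples > 0 else 0.0
    return {"valid": not errors, "errors": errors, "coverage": coverage}


clean = check_folds([([0, 1, 2], [3, 4]), ([0, 1, 2, 3], [4, 5])], n_samples=6)
leaky = check_folds([([0, 1, 2], [2, 3])], n_samples=6)
print(clean["valid"], clean["coverage"])
print(leaky["valid"], leaky["errors"])
```

The second call shares index 2 between train and test, so it is flagged as invalid even though every index is in range.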

Source code in src/ml4t/diagnostic/splitters/persistence.py
def verify_folds(
    folds: list[tuple[NDArray[np.int_], NDArray[np.int_]]],
    n_samples: int,
) -> dict[str, Any]:
    """Verify fold integrity and compute statistics.

    Parameters
    ----------
    folds : list[tuple[NDArray, NDArray]]
        List of (train_indices, test_indices) tuples.
    n_samples : int
        Total number of samples in dataset.

    Returns
    -------
    stats : dict
        Dictionary containing fold statistics and validation results.

    Examples
    --------
    >>> folds, _ = load_folds("cv_folds.json")
    >>> stats = verify_folds(folds, n_samples=1000)
    >>> print(f"Valid: {stats['valid']}")
    >>> print(f"Coverage: {stats['coverage']:.1%}")
    """
    stats: dict[str, Any] = {
        "valid": True,
        "errors": [],
        "n_folds": len(folds),
        "n_samples": n_samples,
        "train_sizes": [],
        "test_sizes": [],
    }

    all_train_indices: set[int] = set()
    all_test_indices: set[int] = set()

    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        stats["train_sizes"].append(len(train_idx))
        stats["test_sizes"].append(len(test_idx))

        # Check for index overlap within fold
        overlap = set(train_idx) & set(test_idx)
        if overlap:
            stats["valid"] = False
            stats["errors"].append(
                f"Fold {fold_idx}: {len(overlap)} overlapping indices between train and test"
            )

        # Check for out-of-range indices
        if np.any(train_idx < 0) or np.any(train_idx >= n_samples):
            stats["valid"] = False
            stats["errors"].append(f"Fold {fold_idx}: Train indices out of range")

        if np.any(test_idx < 0) or np.any(test_idx >= n_samples):
            stats["valid"] = False
            stats["errors"].append(f"Fold {fold_idx}: Test indices out of range")

        all_train_indices.update(train_idx)
        all_test_indices.update(test_idx)

    # Compute coverage statistics
    all_indices = all_train_indices | all_test_indices
    stats["coverage"] = len(all_indices) / n_samples
    stats["train_coverage"] = len(all_train_indices) / n_samples
    stats["test_coverage"] = len(all_test_indices) / n_samples

    # Compute size statistics
    if stats["train_sizes"]:
        train_sizes: list[int] = stats["train_sizes"]
        test_sizes: list[int] = stats["test_sizes"]
        stats["avg_train_size"] = np.mean(train_sizes)
        stats["std_train_size"] = np.std(train_sizes)
        stats["avg_test_size"] = np.mean(test_sizes)
        stats["std_test_size"] = np.std(test_sizes)

    return stats

Evaluation Workflows

These workflows live under ml4t.diagnostic.evaluation:

Area Objects
Generic orchestration Evaluator, EvaluationResult, ValidatedCrossValidation
Feature and signal diagnostics FeatureDiagnostics, MultiSignalAnalysis, analyze_ml_importance, compute_ic_hac_stats
Portfolio and backtest evaluation PortfolioAnalysis, factor attribution helpers
Trade diagnostics TradeAnalysis, TradeShapAnalyzer, TradeShapResult
Event and barrier workflows EventStudyAnalysis, BarrierAnalysis

Statistical Tests

stats

Statistical tests for financial ML evaluation.

This package implements advanced statistical tests used in ml4t-diagnostic's Three-Tier Framework:

Multiple Testing Corrections:
- Deflated Sharpe Ratio (DSR) for selection bias correction
- Rademacher Anti-Serum (RAS) for correlation-aware multiple testing
- False Discovery Rate (FDR) and Family-Wise Error Rate (FWER) corrections

Time Series Inference:
- HAC-adjusted Information Coefficient for autocorrelated data
- Stationary bootstrap for temporal dependence preservation

Strategy Comparison:
- White's Reality Check for multiple strategy comparison
- Probability of Backtest Overfitting (PBO)

All tests are implemented with:
- Mathematical correctness validated against academic references
- Proper handling of autocorrelation and heteroskedasticity
- Numerical stability for edge cases
- Support for both single and multiple hypothesis testing

Module Decomposition (v1.4+)

The stats package is organized into focused modules:

Sharpe Ratio Analysis:
- moments.py: Return statistics (Sharpe, skewness, kurtosis, autocorr)
- sharpe_inference.py: Variance estimation, expected max calculation
- effective_trials.py: Correlation-adjusted K_eff estimators for DSR
- minimum_track_record.py: Minimum Track Record Length
- backtest_overfitting.py: Probability of Backtest Overfitting
- deflated_sharpe_ratio.py: DSR/PSR orchestration layer (main entry points)

Other Statistical Tests:
- rademacher_adjustment.py: Rademacher complexity and RAS adjustments
- bootstrap.py: Stationary bootstrap methods
- hac_standard_errors.py: HAC-adjusted IC estimation
- false_discovery_rate.py: FDR and FWER corrections
- reality_check.py: White's Reality Check

All original imports are preserved for backward compatibility.

hac_adjusted_ic module-attribute

hac_adjusted_ic = robust_ic

deflated_sharpe_ratio_from_statistics

deflated_sharpe_ratio_from_statistics(
    observed_sharpe,
    n_samples,
    n_trials=1,
    variance_trials=0.0,
    benchmark_sharpe=0.0,
    skewness=0.0,
    excess_kurtosis=0.0,
    autocorrelation=0.0,
    confidence_level=0.95,
    frequency="daily",
    periods_per_year=None,
    *,
    effective_trials=None,
    correlation_method=None,
    min_k_eff=1.0,
)

Compute DSR/PSR from pre-computed statistics.

Use this when you have already computed the required statistics. For most users, deflated_sharpe_ratio() with raw returns is recommended.

Parameters

observed_sharpe : float
    Observed Sharpe ratio at native frequency.
n_samples : int
    Number of return observations (T).
n_trials : int, default 1
    Number of strategies tested (K).
variance_trials : float, default 0.0
    Cross-sectional variance of Sharpe ratios.
benchmark_sharpe : float, default 0.0
    Null hypothesis threshold.
skewness : float, default 0.0
    Return skewness.
excess_kurtosis : float, default 0.0
    Return excess kurtosis (Fisher, normal=0).
autocorrelation : float, default 0.0
    First-order autocorrelation.
confidence_level : float, default 0.95
    Confidence level for testing.
frequency : {"daily", "weekly", "monthly"}, default "daily"
    Return frequency.
periods_per_year : int, optional
    Periods per year.
effective_trials : float, optional
    Correlation-adjusted effective number of trials, K_eff.
correlation_method : {"effective_rank", "marchenko_pastur", "clustering"}, optional
    Metadata describing how effective_trials was estimated.
min_k_eff : float, default 1.0
    Lower bound applied to effective_trials before computing the expected-maximum Sharpe term. Ignored when raw n_trials is used.

Returns

DSRResult
    Same as deflated_sharpe_ratio().
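
The core of the calculation is a z-score of the observed Sharpe ratio against a threshold inflated by the expected maximum Sharpe across trials. The sketch below reproduces that flow with the standard library only. It is an illustration under stated assumptions: the two helpers use the textbook Bailey and López de Prado forms (an extreme-value approximation for the expected maximum and the skewness/kurtosis-adjusted Sharpe variance), omit the autocorrelation adjustment, and are not taken from the library's compute_expected_max_sharpe or compute_sharpe_variance, which may differ. All function names and input values here are hypothetical.

```python
import math
from statistics import NormalDist

EULER_GAMMA = 0.5772156649015329
_N = NormalDist()


def expected_max_sharpe(n_trials, variance_trials):
    """Textbook approximation of E[max SR] over K trials (assumed form)."""
    if n_trials <= 1 or variance_trials <= 0:
        return 0.0  # a single trial carries no selection bias
    k = float(n_trials)
    return math.sqrt(variance_trials) * (
        (1 - EULER_GAMMA) * _N.inv_cdf(1 - 1 / k)
        + EULER_GAMMA * _N.inv_cdf(1 - 1 / (k * math.e))
    )


def sharpe_variance(sharpe, n_samples, skewness=0.0, kurtosis=3.0):
    """Textbook Sharpe-estimator variance, no autocorrelation term (assumed form)."""
    return (1 - skewness * sharpe + (kurtosis - 1) / 4 * sharpe**2) / (n_samples - 1)


def dsr_probability(observed_sharpe, n_samples, n_trials=1, variance_trials=0.0,
                    benchmark_sharpe=0.0, skewness=0.0, excess_kurtosis=0.0):
    # Threshold = benchmark plus the selection-bias term, as in the listed source.
    threshold = benchmark_sharpe + expected_max_sharpe(n_trials, variance_trials)
    # The listed source evaluates the variance at the adjusted threshold.
    kurtosis = excess_kurtosis + 3.0  # Fisher -> Pearson
    std = math.sqrt(sharpe_variance(threshold, n_samples, skewness, kurtosis))
    z = (observed_sharpe - threshold) / std
    return _N.cdf(z)


# Per-period Sharpe of 0.10 over 1000 observations (illustrative inputs).
psr = dsr_probability(0.10, n_samples=1000)
dsr = dsr_probability(0.10, n_samples=1000, n_trials=100, variance_trials=0.01)
print(f"PSR (K=1):   {psr:.4f}")
print(f"DSR (K=100): {dsr:.6f}")
```

With a single trial the threshold equals the benchmark and the result is the Probabilistic Sharpe Ratio. Once 100 trials with a cross-sectional Sharpe variance of 0.01 are declared, the same observed Sharpe falls below the inflated threshold and the probability collapses.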

Source code in src/ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py
def deflated_sharpe_ratio_from_statistics(
    observed_sharpe: float,
    n_samples: int,
    n_trials: int = 1,
    variance_trials: float = 0.0,
    benchmark_sharpe: float = 0.0,
    skewness: float = 0.0,
    excess_kurtosis: float = 0.0,
    autocorrelation: float = 0.0,
    confidence_level: float = 0.95,
    frequency: Frequency = "daily",
    periods_per_year: int | None = None,
    *,
    effective_trials: float | None = None,
    correlation_method: EffectiveTrialsMethod | None = None,
    min_k_eff: float = 1.0,
) -> DSRResult:
    """Compute DSR/PSR from pre-computed statistics.

    Use this when you have already computed the required statistics.
    For most users, `deflated_sharpe_ratio()` with raw returns is recommended.

    Parameters
    ----------
    observed_sharpe : float
        Observed Sharpe ratio at native frequency.
    n_samples : int
        Number of return observations (T).
    n_trials : int, default 1
        Number of strategies tested (K).
    variance_trials : float, default 0.0
        Cross-sectional variance of Sharpe ratios.
    benchmark_sharpe : float, default 0.0
        Null hypothesis threshold.
    skewness : float, default 0.0
        Return skewness.
    excess_kurtosis : float, default 0.0
        Return excess kurtosis (Fisher, normal=0).
    autocorrelation : float, default 0.0
        First-order autocorrelation.
    confidence_level : float, default 0.95
        Confidence level for testing.
    frequency : {"daily", "weekly", "monthly"}, default "daily"
        Return frequency.
    periods_per_year : int, optional
        Periods per year.
    effective_trials : float, optional
        Correlation-adjusted effective number of trials, K_eff.
    correlation_method : {"effective_rank", "marchenko_pastur", "clustering"}, optional
        Metadata describing how ``effective_trials`` was estimated.
    min_k_eff : float, default 1.0
        Lower bound applied to ``effective_trials`` before computing the
        expected-maximum Sharpe term. Ignored when raw ``n_trials`` is used.

    Returns
    -------
    DSRResult
        Same as `deflated_sharpe_ratio()`.
    """
    # Validate inputs
    if n_samples < 1:
        raise ValueError("n_samples must be positive")
    if n_trials < 1:
        raise ValueError("n_trials must be positive")
    if n_trials > 1 and variance_trials <= 0:
        raise ValueError("variance_trials must be positive when n_trials > 1")
    if abs(autocorrelation) >= 1:
        raise ValueError("autocorrelation must be in (-1, 1)")
    if effective_trials is not None and effective_trials < 1:
        raise ValueError("effective_trials must be >= 1")
    if min_k_eff < 1:
        raise ValueError("min_k_eff must be >= 1")

    kurtosis = excess_kurtosis + 3.0

    if periods_per_year is None:
        periods_per_year = DEFAULT_PERIODS_PER_YEAR[frequency]

    annualization_factor = np.sqrt(periods_per_year)
    effective_k = (
        max(float(effective_trials), float(min_k_eff))
        if effective_trials is not None and n_trials > 1
        else effective_trials
    )
    trials_for_adjustment = effective_k if effective_k is not None else float(n_trials)

    # Expected max Sharpe
    expected_max = compute_expected_max_sharpe(trials_for_adjustment, variance_trials)
    adjusted_threshold = benchmark_sharpe + expected_max

    # Variance
    variance_sr = compute_sharpe_variance(
        sharpe=adjusted_threshold,
        n_samples=n_samples,
        skewness=skewness,
        kurtosis=kurtosis,
        autocorrelation=autocorrelation,
        n_trials=trials_for_adjustment,
    )
    std_sr = np.sqrt(variance_sr)

    # Z-score
    if std_sr > 0:
        z_score = (observed_sharpe - adjusted_threshold) / std_sr
    else:
        z_score = np.inf if observed_sharpe > adjusted_threshold else -np.inf

    probability = float(norm.cdf(z_score))
    p_value = float(1 - probability)
    is_significant = probability >= confidence_level

    sharpe_annualized = observed_sharpe * annualization_factor
    deflated = observed_sharpe - expected_max

    # MinTRL
    min_trl = _compute_min_trl_core(
        observed_sharpe=observed_sharpe,
        target_sharpe=benchmark_sharpe,
        confidence_level=confidence_level,
        skewness=skewness,
        kurtosis=kurtosis,
        autocorrelation=autocorrelation,
    )
    min_trl_years = min_trl / periods_per_year
    has_adequate = n_samples >= min_trl

    return DSRResult(
        probability=probability,
        is_significant=is_significant,
        z_score=float(z_score),
        p_value=p_value,
        sharpe_ratio=float(observed_sharpe),
        sharpe_ratio_annualized=float(sharpe_annualized),
        benchmark_sharpe=benchmark_sharpe,
        n_samples=n_samples,
        n_trials=n_trials,
        n_trials_raw=n_trials,
        n_trials_effective=float(effective_k) if effective_k is not None else None,
        correlation_method=correlation_method,
        min_k_eff=float(min_k_eff),
        frequency=frequency,
        periods_per_year=periods_per_year,
        skewness=float(skewness),
        excess_kurtosis=float(excess_kurtosis),
        autocorrelation=float(autocorrelation),
        expected_max_sharpe=float(expected_max),
        deflated_sharpe=float(deflated),
        variance_trials=float(variance_trials),
        min_trl=min_trl,
        min_trl_years=float(min_trl_years),
        has_adequate_sample=has_adequate,
        confidence_level=confidence_level,
    )

compute_min_trl

compute_min_trl(
    returns=None,
    observed_sharpe=None,
    target_sharpe=0.0,
    confidence_level=0.95,
    frequency="daily",
    periods_per_year=None,
    *,
    skewness=None,
    excess_kurtosis=None,
    autocorrelation=None,
)

Compute Minimum Track Record Length (MinTRL).

MinTRL is the minimum number of observations required to reject the null hypothesis (SR <= target) at the specified confidence level.

Parameters

returns : array-like, optional
    Return series. If provided, statistics are computed from it.
observed_sharpe : float, optional
    Observed Sharpe ratio. Required if returns not provided.
target_sharpe : float, default 0.0
    Null hypothesis threshold (SR₀).
confidence_level : float, default 0.95
    Required confidence level (1 - α).
frequency : {"daily", "weekly", "monthly"}, default "daily"
    Return frequency.
periods_per_year : int, optional
    Periods per year (for converting to calendar time).
skewness : float, optional
    Override computed skewness.
excess_kurtosis : float, optional
    Override computed excess kurtosis (Fisher convention, normal=0).
autocorrelation : float, optional
    Override computed autocorrelation.

Returns

MinTRLResult
    Results including min_trl, min_trl_years, and adequacy assessment. min_trl can be math.inf if observed SR <= target SR.

Examples

From returns:

result = compute_min_trl(daily_returns, frequency="daily")
print(f"Need {result.min_trl_years:.1f} years of data")

From statistics:

result = compute_min_trl(
    observed_sharpe=0.5,
    target_sharpe=0.0,
    confidence_level=0.95,
    skewness=-1.0,
    excess_kurtosis=2.0,
    autocorrelation=0.1,
)
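
For intuition about the magnitudes involved, the sketch below evaluates the textbook MinTRL expression, 1 + (1 - γ₃·SR + (γ₄ - 1)/4·SR²)·(z_α / (SR - SR₀))², using only the standard library. This is an assumption about the general shape of the calculation: the library's _compute_min_trl_core is not shown here and also takes an autocorrelation argument, which this sketch leaves out. The function name min_trl_sketch and the 252 trading days used for the year conversion are illustrative choices, not values read from the library.

```python
import math
from statistics import NormalDist


def min_trl_sketch(observed_sharpe, target_sharpe=0.0, confidence_level=0.95,
                   skewness=0.0, excess_kurtosis=0.0):
    """Textbook MinTRL without an autocorrelation adjustment (assumed form)."""
    if observed_sharpe <= target_sharpe:
        return math.inf  # the null can never be rejected
    kurtosis = excess_kurtosis + 3.0  # Fisher -> Pearson
    z_alpha = NormalDist().inv_cdf(confidence_level)
    shape = 1 - skewness * observed_sharpe + (kurtosis - 1) / 4 * observed_sharpe**2
    return 1 + shape * (z_alpha / (observed_sharpe - target_sharpe)) ** 2


# Per-period Sharpe of 0.10 with normal returns (illustrative input).
n_obs = min_trl_sketch(0.10)
# Negative skew and fat tails lengthen the required track record.
n_obs_fat = min_trl_sketch(0.10, skewness=-1.0, excess_kurtosis=2.0)
print(f"{n_obs:.1f} observations, about {n_obs / 252:.2f} years at 252 periods per year")
print(f"{n_obs_fat:.1f} observations with skew=-1, excess kurtosis=2")
```

The required length grows with the inverse square of the gap between observed and target Sharpe, which is why small edges need long track records.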

Source code in src/ml4t/diagnostic/evaluation/stats/minimum_track_record.py
def compute_min_trl(
    returns: ArrayLike | None = None,
    observed_sharpe: float | None = None,
    target_sharpe: float = 0.0,
    confidence_level: float = 0.95,
    frequency: Frequency = "daily",
    periods_per_year: int | None = None,
    *,
    skewness: float | None = None,
    excess_kurtosis: float | None = None,
    autocorrelation: float | None = None,
) -> MinTRLResult:
    """Compute Minimum Track Record Length (MinTRL).

    MinTRL is the minimum number of observations required to reject the null
    hypothesis (SR <= target) at the specified confidence level.

    Parameters
    ----------
    returns : array-like, optional
        Return series. If provided, statistics are computed from it.
    observed_sharpe : float, optional
        Observed Sharpe ratio. Required if returns not provided.
    target_sharpe : float, default 0.0
        Null hypothesis threshold (SR₀).
    confidence_level : float, default 0.95
        Required confidence level (1 - α).
    frequency : {"daily", "weekly", "monthly"}, default "daily"
        Return frequency.
    periods_per_year : int, optional
        Periods per year (for converting to calendar time).
    skewness : float, optional
        Override computed skewness.
    excess_kurtosis : float, optional
        Override computed excess kurtosis (Fisher convention, normal=0).
    autocorrelation : float, optional
        Override computed autocorrelation.

    Returns
    -------
    MinTRLResult
        Results including min_trl, min_trl_years, and adequacy assessment.
        min_trl can be math.inf if observed SR <= target SR.

    Examples
    --------
    From returns:

    >>> result = compute_min_trl(daily_returns, frequency="daily")
    >>> print(f"Need {result.min_trl_years:.1f} years of data")

    From statistics:

    >>> result = compute_min_trl(
    ...     observed_sharpe=0.5,
    ...     target_sharpe=0.0,
    ...     confidence_level=0.95,
    ...     skewness=-1.0,
    ...     excess_kurtosis=2.0,
    ...     autocorrelation=0.1,
    ... )
    """
    # Resolve periods per year
    if periods_per_year is None:
        periods_per_year = DEFAULT_PERIODS_PER_YEAR[frequency]

    # Get statistics from returns or use provided values
    if returns is not None:
        ret_arr = np.asarray(returns).flatten()
        ret_arr = ret_arr[~np.isnan(ret_arr)]
        obs_sr, comp_skew, comp_kurt, comp_rho, n_samples = compute_return_statistics(ret_arr)

        if observed_sharpe is None:
            observed_sharpe = obs_sr
    else:
        if observed_sharpe is None:
            raise ValueError("Either returns or observed_sharpe must be provided")
        n_samples = 0  # Unknown
        comp_skew = 0.0
        comp_kurt = 3.0  # Pearson
        comp_rho = 0.0

    # Use provided or computed statistics
    skew = skewness if skewness is not None else comp_skew
    if excess_kurtosis is not None:
        kurt = excess_kurtosis + 3.0  # Fisher -> Pearson
    else:
        kurt = comp_kurt
    rho = autocorrelation if autocorrelation is not None else comp_rho

    # Compute MinTRL
    min_trl = _compute_min_trl_core(
        observed_sharpe=observed_sharpe,
        target_sharpe=target_sharpe,
        confidence_level=confidence_level,
        skewness=skew,
        kurtosis=kurt,
        autocorrelation=rho,
    )

    is_inf = math.isinf(min_trl)
    min_trl_years = float("inf") if is_inf else min_trl / periods_per_year
    has_adequate = False if is_inf or n_samples == 0 else n_samples >= min_trl
    deficit = (
        float("inf") if is_inf else max(0.0, min_trl - n_samples) if n_samples > 0 else min_trl
    )
    deficit_years = float("inf") if is_inf else deficit / periods_per_year

    return MinTRLResult(
        min_trl=min_trl,
        min_trl_years=float(min_trl_years),
        current_samples=n_samples,
        has_adequate_sample=has_adequate,
        deficit=deficit,
        deficit_years=float(deficit_years),
        observed_sharpe=float(observed_sharpe),
        target_sharpe=target_sharpe,
        confidence_level=confidence_level,
        skewness=float(skew),
        excess_kurtosis=float(kurt - 3.0),
        autocorrelation=float(rho),
        frequency=frequency,
        periods_per_year=periods_per_year,
    )

min_trl_fwer

min_trl_fwer(
    observed_sharpe,
    n_trials,
    variance_trials,
    target_sharpe=0.0,
    confidence_level=0.95,
    frequency="daily",
    periods_per_year=None,
    *,
    skewness=0.0,
    excess_kurtosis=0.0,
    autocorrelation=0.0,
)

Compute MinTRL under FWER multiple testing adjustment.

When selecting the best strategy from K trials, the MinTRL must be adjusted to account for the selection bias.

Parameters

observed_sharpe : float
    Observed Sharpe ratio of the best strategy.
n_trials : int
    Number of strategies tested (K).
variance_trials : float
    Cross-sectional variance of Sharpe ratios.
target_sharpe : float, default 0.0
    Original null hypothesis threshold.
confidence_level : float, default 0.95
    Required confidence level.
frequency : {"daily", "weekly", "monthly"}, default "daily"
    Return frequency.
periods_per_year : int, optional
    Periods per year.
skewness : float, default 0.0
    Return skewness.
excess_kurtosis : float, default 0.0
    Return excess kurtosis (Fisher, normal=0).
autocorrelation : float, default 0.0
    Return autocorrelation.

Returns

MinTRLResult
    Results with min_trl adjusted for multiple testing.

Source code in src/ml4t/diagnostic/evaluation/stats/minimum_track_record.py
def min_trl_fwer(
    observed_sharpe: float,
    n_trials: int,
    variance_trials: float,
    target_sharpe: float = 0.0,
    confidence_level: float = 0.95,
    frequency: Frequency = "daily",
    periods_per_year: int | None = None,
    *,
    skewness: float = 0.0,
    excess_kurtosis: float = 0.0,
    autocorrelation: float = 0.0,
) -> MinTRLResult:
    """Compute MinTRL under FWER multiple testing adjustment.

    When selecting the best strategy from K trials, the MinTRL must be adjusted
    to account for the selection bias.

    Parameters
    ----------
    observed_sharpe : float
        Observed Sharpe ratio of the best strategy.
    n_trials : int
        Number of strategies tested (K).
    variance_trials : float
        Cross-sectional variance of Sharpe ratios.
    target_sharpe : float, default 0.0
        Original null hypothesis threshold.
    confidence_level : float, default 0.95
        Required confidence level.
    frequency : {"daily", "weekly", "monthly"}, default "daily"
        Return frequency.
    periods_per_year : int, optional
        Periods per year.
    skewness : float, default 0.0
        Return skewness.
    excess_kurtosis : float, default 0.0
        Return excess kurtosis (Fisher, normal=0).
    autocorrelation : float, default 0.0
        Return autocorrelation.

    Returns
    -------
    MinTRLResult
        Results with min_trl adjusted for multiple testing.
    """
    if periods_per_year is None:
        periods_per_year = DEFAULT_PERIODS_PER_YEAR[frequency]

    kurtosis = excess_kurtosis + 3.0

    # Compute expected max Sharpe (selection bias adjustment)
    expected_max = compute_expected_max_sharpe(n_trials, variance_trials)
    adjusted_target = target_sharpe + expected_max

    # Compute MinTRL with adjusted target
    min_trl = _compute_min_trl_core(
        observed_sharpe=observed_sharpe,
        target_sharpe=adjusted_target,
        confidence_level=confidence_level,
        skewness=skewness,
        kurtosis=kurtosis,
        autocorrelation=autocorrelation,
    )

    is_inf = math.isinf(min_trl)
    min_trl_years = float("inf") if is_inf else min_trl / periods_per_year

    return MinTRLResult(
        min_trl=min_trl,
        min_trl_years=float(min_trl_years),
        current_samples=0,
        has_adequate_sample=False,
        deficit=min_trl,
        deficit_years=float(min_trl_years),
        observed_sharpe=float(observed_sharpe),
        target_sharpe=float(adjusted_target),
        confidence_level=confidence_level,
        skewness=float(skewness),
        excess_kurtosis=float(excess_kurtosis),
        autocorrelation=float(autocorrelation),
        frequency=frequency,
        periods_per_year=periods_per_year,
    )

compute_pbo

compute_pbo(is_performance, oos_performance)

Compute Probability of Backtest Overfitting (PBO).

PBO measures the probability that a strategy selected as best in-sample performs below median out-of-sample. A high PBO indicates overfitting.

Definition

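From Bailey & López de Prado (2014):

$$\mathrm{PBO} = P\left(\operatorname{rank}_{OOS}\left(\arg\max_{IS}\right) > N/2\right)$$

In plain English: what's the probability that the best in-sample strategy ranks in the bottom half out-of-sample? The implementation ranks strategies from 1 (best) to N (worst) out-of-sample and counts a combination when the IS-best strategy's rank is strictly greater than (N + 1)/2, which matches the N/2 rule for even N and leaves out the exact median rank for odd N.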

Interpretation

  • PBO = 0%: No sign of overfitting (the IS-best strategy never ranks in the bottom half OOS)
  • PBO = 50%: Random selection (IS performance uncorrelated with OOS)
  • PBO > 50%: Severe overfitting (IS selection is counterproductive)

Parameters

is_performance : np.ndarray, shape (n_folds, n_strategies) or (n_strategies,)
    In-sample performance metrics (Sharpe, IC, returns) for each strategy. A 1D array is treated as a single fold with one entry per strategy.
oos_performance : np.ndarray, shape (n_folds, n_strategies) or (n_strategies,)
    Out-of-sample performance metrics (same structure as is_performance).

Returns

PBOResult
    Result object with PBO and diagnostic metrics. Call .interpret() for human-readable assessment.

Raises

ValueError
    If arrays have different shapes or fewer than 2 strategies.

Examples

import numpy as np

# 10 CV folds, 5 strategies
is_perf = np.random.randn(10, 5)
oos_perf = np.random.randn(10, 5)
result = compute_pbo(is_perf, oos_perf)
print(result.interpret())
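
Random inputs show the call signature but not the mechanics, so here is a small hand-checkable sketch of the same ranking rule in pure Python. The function name pbo_sketch and the four-by-four performance tables are made up for illustration. Ranks are computed by counting strategies with strictly higher OOS performance, which agrees with the argsort-based ranking in the listed source when there are no ties but may differ when OOS values tie.

```python
def pbo_sketch(is_rows, oos_rows):
    """Fraction of rows where the IS-best strategy lands in the OOS bottom half."""
    n_strategies = len(is_rows[0])
    median_rank = (n_strategies + 1) / 2
    ranks = []
    for is_row, oos_row in zip(is_rows, oos_rows):
        # Index of the strategy that looked best in-sample (first one on ties).
        best = max(range(n_strategies), key=lambda j: is_row[j])
        # Rank 1 = best OOS: one plus the number of strategies that beat it OOS.
        ranks.append(1 + sum(1 for v in oos_row if v > oos_row[best]))
    below_median = sum(1 for r in ranks if r > median_rank)
    return below_median / len(ranks), ranks


# Rows are folds/combinations, columns are strategies (illustrative values).
is_perf = [
    [0.9, 0.1, 0.2, 0.3],
    [0.1, 0.8, 0.2, 0.3],
    [0.2, 0.1, 0.7, 0.3],
    [0.2, 0.1, 0.3, 0.6],
]
oos_perf = [
    [0.5, 0.1, 0.2, 0.3],    # IS-best (col 0) is also OOS-best: rank 1
    [0.4, -0.2, 0.1, 0.3],   # IS-best (col 1) is OOS-worst: rank 4
    [0.4, 0.3, 0.1, 0.2],    # IS-best (col 2) is OOS-worst: rank 4
    [0.1, 0.2, 0.3, 0.25],   # IS-best (col 3) is second OOS: rank 2
]

pbo, ranks = pbo_sketch(is_perf, oos_perf)
print(ranks)  # [1, 4, 4, 2]
print(pbo)    # 0.5
```

With four strategies the median rank is 2.5, so two of the four rows count as below median and the estimate is 0.5.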

References

Bailey, D. H., & López de Prado, M. (2014). "The Probability of Backtest Overfitting." Journal of Computational Finance, 20(4), 39-69.

Source code in src/ml4t/diagnostic/evaluation/stats/backtest_overfitting.py
def compute_pbo(
    is_performance: np.ndarray[Any, np.dtype[Any]],
    oos_performance: np.ndarray[Any, np.dtype[Any]],
) -> PBOResult:
    """Compute Probability of Backtest Overfitting (PBO).

    PBO measures the probability that a strategy selected as best in-sample
    performs below median out-of-sample. A high PBO indicates overfitting.

    Definition
    ----------
    From Bailey & López de Prado (2014):

    .. math::

        PBO = P(rank_{OOS}(\\arg\\max_{IS}) > N/2)

    In plain English: what's the probability that the best in-sample strategy
    ranks in the bottom half out-of-sample?

    Interpretation
    --------------
    - PBO = 0%: No sign of overfitting (IS-best never ranks in the bottom half OOS)
    - PBO = 50%: Random selection (IS performance uncorrelated with OOS)
    - PBO > 50%: Severe overfitting (IS selection is counterproductive)

    Parameters
    ----------
    is_performance : np.ndarray, shape (n_folds, n_strategies) or (n_strategies,)
        In-sample performance metrics (Sharpe, IC, returns) for each strategy.
        A 1D array is treated as a single fold with one entry per strategy.
    oos_performance : np.ndarray, shape (n_folds, n_strategies) or (n_strategies,)
        Out-of-sample performance metrics (same structure as is_performance).

    Returns
    -------
    PBOResult
        Result object with PBO and diagnostic metrics.
        Call .interpret() for human-readable assessment.

    Raises
    ------
    ValueError
        If arrays have different shapes or fewer than 2 strategies.

    Examples
    --------
    >>> import numpy as np
    >>> # 10 CV folds, 5 strategies
    >>> is_perf = np.random.randn(10, 5)
    >>> oos_perf = np.random.randn(10, 5)
    >>> result = compute_pbo(is_perf, oos_perf)
    >>> print(result.interpret())

    References
    ----------
    Bailey, D. H., & López de Prado, M. (2014). "The Probability of Backtest
    Overfitting." Journal of Computational Finance, 20(4), 39-69.
    """
    is_performance = np.asarray(is_performance)
    oos_performance = np.asarray(oos_performance)

    if is_performance.shape != oos_performance.shape:
        raise ValueError(
            f"is_performance and oos_performance must have same shape. "
            f"Got {is_performance.shape} vs {oos_performance.shape}"
        )

    # Handle 1D input (single combination with multiple strategies)
    if is_performance.ndim == 1:
        is_performance = is_performance.reshape(1, -1)
        oos_performance = oos_performance.reshape(1, -1)

    n_combinations, n_strategies = is_performance.shape

    if n_strategies < 2:
        raise ValueError(f"Need at least 2 strategies, got {n_strategies}")

    # For each combination, find the IS-best strategy and its OOS rank
    is_best_oos_ranks = []
    degradations = []

    for i in range(n_combinations):
        is_row = is_performance[i, :]
        oos_row = oos_performance[i, :]

        # Find strategy with best IS performance
        is_best_idx = np.argmax(is_row)
        is_best_is_perf = is_row[is_best_idx]
        is_best_oos_perf = oos_row[is_best_idx]

        # Compute OOS rank of IS-best strategy (1 = best, N = worst)
        oos_ranks = n_strategies - np.argsort(np.argsort(oos_row))
        is_best_oos_rank = oos_ranks[is_best_idx]
        is_best_oos_ranks.append(is_best_oos_rank)

        # Compute degradation (IS - OOS performance)
        degradations.append(is_best_is_perf - is_best_oos_perf)

    ranks_arr = np.array(is_best_oos_ranks)
    degrad_arr = np.array(degradations)

    # PBO = P(IS-best ranks in bottom half OOS)
    median_rank = (n_strategies + 1) / 2
    n_below_median = np.sum(ranks_arr > median_rank)
    pbo = n_below_median / n_combinations

    return PBOResult(
        pbo=float(pbo),
        pbo_pct=float(pbo * 100),
        n_combinations=int(n_combinations),
        n_strategies=int(n_strategies),
        is_best_rank_oos_median=float(np.median(ranks_arr)),
        is_best_rank_oos_mean=float(np.mean(ranks_arr)),
        degradation_mean=float(np.mean(degrad_arr)),
        degradation_std=float(np.std(degrad_arr)),
    )

ras_ic_adjustment

ras_ic_adjustment(
    observed_ic,
    complexity,
    n_samples,
    delta=0.05,
    kappa=0.02,
    return_result=False,
)

Apply RAS adjustment for Information Coefficients (bounded metrics).

Computes conservative lower bounds on true IC values accounting for data snooping and estimation error.

Formula (Hoeffding concentration for |IC| ≤ κ):

$$\theta_n \ge \hat{\theta}_n - \underbrace{2\hat{R}}_{(a)} - \underbrace{2\kappa\sqrt{\log(2/\delta)/T}}_{(b)}$$

where

(a) = data snooping penalty from testing N strategies
(b) = estimation error for bounded r.v. (Hoeffding's inequality)

Parameters

observed_ic : ndarray of shape (N,)
    Observed Information Coefficients for N strategies.
complexity : float
    Rademacher complexity R̂ from rademacher_complexity().
n_samples : int
    Number of time periods T used to compute ICs.
delta : float, default=0.05
    Significance level (1 - confidence). Lower = more conservative.
kappa : float, default=0.02
    Bound on |IC|. Critical parameter.

    Practical guidance (Paleologo 2024, p.273):
    - κ=0.02: Typical alpha signals
    - κ=0.05: High-conviction signals
    - κ=1.0: Theoretical maximum (usually too conservative)
return_result : bool, default=False
    If True, return RASResult dataclass with full diagnostics.

Returns

ndarray or RASResult
    If return_result=False: Adjusted IC lower bounds (N,).
    If return_result=True: RASResult with full diagnostics.

Raises

ValueError
    If inputs are invalid or observed ICs exceed kappa bound.

Warns

UserWarning
    If any |observed_ic| > κ (theoretical guarantee violated).

Notes

Derivation:
1. Data snooping: Standard Rademacher generalization bound gives 2R̂.
2. Estimation: For bounded r.v. |X| ≤ κ, Hoeffding gives P(|X̂ - X| > t) ≤ 2exp(-Tt²/(2κ²)). Setting RHS = δ yields t = κ√(2 log(2/δ)/T). Conservative factor 2 for two-sided.

Advantages over DSR:
- Accounts for strategy correlation (R̂ ↓ as correlation ↑)
- Non-asymptotic (valid for any T)
- Zero false positives in Paleologo's simulations
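
The bound stated in the Formula section is simple enough to evaluate by hand, which helps when choosing κ and δ. The sketch below applies it to two made-up ICs with the standard library only. The function name ras_ic_lower_bound, the complexity value 0.001, and the rule "significant means lower bound above zero" are assumptions of this sketch, not details confirmed by the library; in practice the complexity would come from rademacher_complexity().

```python
import math


def ras_ic_lower_bound(observed_ic, complexity, n_samples, delta=0.05, kappa=0.02):
    """Lower bound: IC - 2*R_hat - 2*kappa*sqrt(log(2/delta)/T)."""
    snooping = 2.0 * complexity  # term (a): data snooping penalty
    # Term (b): Hoeffding-style estimation error for |IC| <= kappa.
    estimation = 2.0 * kappa * math.sqrt(math.log(2.0 / delta) / n_samples)
    return [ic - snooping - estimation for ic in observed_ic]


# Two illustrative ICs, an assumed complexity, and T = 2500 periods.
bounds = ras_ic_lower_bound([0.010, 0.003], complexity=0.001, n_samples=2500)
# Assumed decision rule for this sketch: keep signals whose bound stays positive.
significant = [b > 0 for b in bounds]
print([round(b, 6) for b in bounds])
print(significant)
```

At these settings the combined haircut is roughly 0.0035, so an observed IC of 0.010 survives while 0.003 does not. Because term (b) scales linearly in κ, raising κ from 0.02 to 1.0 multiplies it by 50, which is why the theoretical maximum is described as usually too conservative.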

Examples

import numpy as np
X = np.random.randn(2500, 500) * 0.02
observed_ic = X.mean(axis=0)
R_hat = rademacher_complexity(X)
result = ras_ic_adjustment(observed_ic, R_hat, 2500, return_result=True)
print(f"Significant: {result.n_significant}/{len(observed_ic)}")

References

[1] Paleologo (2024), Section 8.3.2, Procedure 8.1.
[2] Hoeffding (1963), "Probability inequalities for sums of bounded random variables", JASA 58:13-30.

Source code in src/ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py
def ras_ic_adjustment(
    observed_ic: NDArray[Any],
    complexity: float,
    n_samples: int,
    delta: float = 0.05,
    kappa: float = 0.02,
    return_result: bool = False,
) -> NDArray[Any] | RASResult:
    """Apply RAS adjustment for Information Coefficients (bounded metrics).

    Computes conservative lower bounds on true IC values accounting for
    data snooping and estimation error.

    **Formula** (Hoeffding concentration for |IC| ≤ κ):

        θₙ ≥ θ̂ₙ - 2R̂ - 2κ√(log(2/δ)/T)
               ───   ─────────────────
               (a)         (b)

    where:
        (a) = data snooping penalty from testing N strategies
        (b) = estimation error for bounded r.v. (Hoeffding's inequality)

    Parameters
    ----------
    observed_ic : ndarray of shape (N,)
        Observed Information Coefficients for N strategies.
    complexity : float
        Rademacher complexity R̂ from `rademacher_complexity()`.
    n_samples : int
        Number of time periods T used to compute ICs.
    delta : float, default=0.05
        Significance level (1 - confidence). Lower = more conservative.
    kappa : float, default=0.02
        Bound on |IC|. **Critical parameter**.

        Practical guidance (Paleologo 2024, p.273):
        - κ=0.02: Typical alpha signals
        - κ=0.05: High-conviction signals
        - κ=1.0: Theoretical maximum (usually too conservative)
    return_result : bool, default=False
        If True, return RASResult dataclass with full diagnostics.

    Returns
    -------
    ndarray or RASResult
        If return_result=False: Adjusted IC lower bounds (N,).
        If return_result=True: RASResult with full diagnostics.

    Raises
    ------
    ValueError
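        If observed_ic is not 1D, complexity is negative, n_samples < 1,
        delta is outside (0, 1), or kappa is not positive. ICs above kappa
        only trigger the warning below.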

    Warns
    -----
    UserWarning
        If any |observed_ic| > κ (theoretical guarantee violated).

    Notes
    -----
    **Derivation**:
    1. Data snooping: Standard Rademacher generalization bound gives 2R̂.
    2. Estimation: For bounded r.v. |X| ≤ κ, Hoeffding gives
       P(|X̂ - X| > t) ≤ 2exp(-Tt²/2κ²). Setting RHS = δ yields
       t = κ√(2 log(2/δ)/T). Conservative factor 2 for two-sided.

    **Advantages over DSR**:
    - Accounts for strategy correlation (R̂ ↓ as correlation ↑)
    - Non-asymptotic (valid for any T)
    - Zero false positives in Paleologo's simulations

    Examples
    --------
    >>> import numpy as np
    >>> X = np.random.randn(2500, 500) * 0.02
    >>> observed_ic = X.mean(axis=0)
    >>> R_hat = rademacher_complexity(X)
    >>> result = ras_ic_adjustment(observed_ic, R_hat, 2500, return_result=True)
    >>> print(f"Significant: {result.n_significant}/{len(observed_ic)}")

    References
    ----------
    .. [1] Paleologo (2024), Section 8.3.2, Procedure 8.1.
    .. [2] Hoeffding (1963), "Probability inequalities for sums of bounded
           random variables", JASA 58:13-30.
    """
    observed_ic = np.asarray(observed_ic)

    if observed_ic.ndim != 1:
        raise ValueError(f"observed_ic must be 1D, got shape {observed_ic.shape}")

    if complexity < 0:
        raise ValueError(f"complexity must be non-negative, got {complexity}")

    if n_samples < 1:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    if not 0 < delta < 1:
        raise ValueError(f"delta must be in (0, 1), got {delta}")

    if kappa <= 0:
        raise ValueError(f"kappa must be positive, got {kappa}")

    # Warn if ICs exceed the bounded assumption
    max_abs_ic = np.max(np.abs(observed_ic))
    if max_abs_ic > kappa:
        warnings.warn(
            f"max(|IC|)={max_abs_ic:.4f} exceeds kappa={kappa}. "
            "Theoretical guarantees may not hold. Consider increasing kappa.",
            UserWarning,
            stacklevel=2,
        )

    N = len(observed_ic)
    T = n_samples

    # (a) Data snooping penalty: 2R̂
    data_snooping = 2 * complexity

    # (b) Estimation error: 2κ√(log(2/δ)/T) from Hoeffding
    estimation_error = 2 * kappa * np.sqrt(np.log(2 / delta) / T)

    # Conservative lower bound
    adjusted_ic = observed_ic - data_snooping - estimation_error

    if not return_result:
        return adjusted_ic

    # Compute diagnostics
    massart_bound = np.sqrt(2 * np.log(N) / T) if N > 1 else 0.0
    significant_mask = adjusted_ic > 0

    return RASResult(
        adjusted_values=adjusted_ic,
        observed_values=observed_ic,
        complexity=complexity,
        data_snooping_penalty=data_snooping,
        estimation_error=estimation_error,
        n_significant=int(np.sum(significant_mask)),
        significant_mask=significant_mask,
        massart_bound=massart_bound,
        complexity_ratio=complexity / massart_bound if massart_bound > 0 else 0.0,
    )

ras_sharpe_adjustment

ras_sharpe_adjustment(
    observed_sharpe,
    complexity,
    n_samples,
    n_strategies,
    delta=0.05,
    return_result=False,
)

Apply RAS adjustment for Sharpe ratios (sub-Gaussian metrics).

Computes conservative lower bounds on true Sharpe ratios accounting for data snooping, estimation error, and multiple testing.

Formula (sub-Gaussian concentration + union bound):

θₙ ≥ θ̂ₙ - 2R̂ - 3√(2 log(2/δ)/T) - √(2 log(2N/δ)/T)

where

(a) 2R̂ = data snooping penalty
(b) 3√(2 log(2/δ)/T) = sub-Gaussian estimation error (factor 3 for conservatism)
(c) √(2 log(2N/δ)/T) = union bound over N strategies
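The two estimation terms depend only on T, N and δ, so their size can be checked before any strategy is evaluated. The following is a minimal sketch that evaluates terms (b) and (c) of the documented formula directly with NumPy; the helper name sharpe_penalty_terms is hypothetical and not part of the library.

```python
import numpy as np


def sharpe_penalty_terms(n_samples, n_strategies, delta=0.05):
    """Terms (b) and (c) of the documented Sharpe bound (hypothetical helper).

    Term (a), the data snooping penalty 2*R_hat, is added separately.
    """
    term_b = 3.0 * np.sqrt(2.0 * np.log(2.0 / delta) / n_samples)
    term_c = np.sqrt(2.0 * np.log(2.0 * n_strategies / delta) / n_samples)
    return float(term_b), float(term_c)


# One, five and ten years of daily observations for 100 strategies.
for T in (252, 1260, 2520):
    b, c = sharpe_penalty_terms(T, n_strategies=100)
    print(f"T={T:>4}: (b)={b:.3f}  (c)={c:.3f}  (b)+(c)={b + c:.3f}")
```

Both terms shrink like 1/√T, while the number of strategies enters term (c) only through a logarithm, so lengthening the sample reduces the penalty far more than trimming the strategy count does.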

Parameters

observed_sharpe : ndarray of shape (N,)
    Observed (annualized) Sharpe ratios for N strategies.
complexity : float
    Rademacher complexity R̂ from rademacher_complexity().
n_samples : int
    Number of time periods T used to compute Sharpe ratios.
n_strategies : int
    Total number of strategies N tested.
delta : float, default=0.05
    Significance level (1 - confidence). Lower = more conservative.
return_result : bool, default=False
    If True, return RASResult dataclass with full diagnostics.

Returns

ndarray or RASResult
    If return_result=False: Adjusted Sharpe lower bounds (N,).
    If return_result=True: RASResult with full diagnostics.

Notes

Derivation:

1. Data snooping: 2R̂ (standard Rademacher bound)
2. Sub-Gaussian error: For σ²-sub-Gaussian X, P(X > t) ≤ exp(-t²/2σ²). Daily returns typically have σ ≈ 1 when standardized. Factor 3 provides conservatism for heavier tails.
3. Union bound: P(∃n: |X̂ₙ - Xₙ| > t) ≤ N × single-strategy bound. Contributes the √(2 log(2N/δ)/T) term.

Comparison to DSR:

- DSR assumes independent strategies (overpenalizes correlated ones)
- RAS captures correlation via R̂ (correlated → lower R̂ → less penalty)
- RAS is non-asymptotic; DSR requires large T
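The claim that correlated strategies yield a lower R̂ can be checked numerically. The sketch below uses a textbook Monte Carlo estimate of the empirical Rademacher complexity, E_ε[max_n (1/T) Σ_t ε_t x_{t,n}]. It is an assumption that this matches what rademacher_complexity() computes, since the library's estimator is not shown here; the function name empirical_rademacher is hypothetical.

```python
import numpy as np


def empirical_rademacher(X, n_sim=500, seed=0):
    """Monte Carlo estimate of E_eps[ max_n (1/T) * sum_t eps_t * X[t, n] ].

    Hypothetical helper; the library's rademacher_complexity() may differ.
    """
    rng = np.random.default_rng(seed)
    T = X.shape[0]
    eps = rng.choice([-1.0, 1.0], size=(n_sim, T))  # random sign vectors
    # (n_sim, T) @ (T, N) -> one correlation-with-noise value per strategy
    return float(np.mean(np.max(eps @ X / T, axis=1)))


rng = np.random.default_rng(42)
T, N = 250, 50
independent = rng.standard_normal((T, N))   # 50 unrelated strategies
common = rng.standard_normal((T, 1))
correlated = np.repeat(common, N, axis=1)   # 50 copies of one strategy

r_indep = empirical_rademacher(independent)
r_corr = empirical_rademacher(correlated)
print(f"independent: {r_indep:.3f}, perfectly correlated: {r_corr:.3f}")
```

With perfectly correlated columns the maximum over strategies adds nothing, so the estimate hovers near zero, whereas independent columns give the random signs many chances to line up with some strategy.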

Examples

import numpy as np
returns = np.random.randn(252, 100) * 0.01  # 100 strategies, 1 year
observed_sr = returns.mean(axis=0) / returns.std(axis=0) * np.sqrt(252)
R_hat = rademacher_complexity(returns)
result = ras_sharpe_adjustment(
    observed_sr, R_hat, 252, 100, return_result=True
)
print(f"Significant: {result.n_significant}/100")

References

[1] Paleologo (2024), Section 8.3.2, Procedure 8.2.

Source code in src/ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py
def ras_sharpe_adjustment(
    observed_sharpe: NDArray[Any],
    complexity: float,
    n_samples: int,
    n_strategies: int,
    delta: float = 0.05,
    return_result: bool = False,
) -> NDArray[Any] | RASResult:
    """Apply RAS adjustment for Sharpe ratios (sub-Gaussian metrics).

    Computes conservative lower bounds on true Sharpe ratios accounting for
    data snooping, estimation error, and multiple testing.

    **Formula** (sub-Gaussian concentration + union bound):

        θₙ ≥ θ̂ₙ - 2R̂ - 3√(2 log(2/δ)/T) - √(2 log(2N/δ)/T)
               ───   ─────────────────────────────────────
               (a)              (b)              (c)

    where:
        (a) = data snooping penalty
        (b) = sub-Gaussian estimation error (factor 3 for conservatism)
        (c) = union bound over N strategies

    Parameters
    ----------
    observed_sharpe : ndarray of shape (N,)
        Observed (annualized) Sharpe ratios for N strategies.
    complexity : float
        Rademacher complexity R̂ from `rademacher_complexity()`.
    n_samples : int
        Number of time periods T used to compute Sharpe ratios.
    n_strategies : int
        Total number of strategies N tested.
    delta : float, default=0.05
        Significance level (1 - confidence). Lower = more conservative.
    return_result : bool, default=False
        If True, return RASResult dataclass with full diagnostics.

    Returns
    -------
    ndarray or RASResult
        If return_result=False: Adjusted Sharpe lower bounds (N,).
        If return_result=True: RASResult with full diagnostics.

    Notes
    -----
    **Derivation**:
    1. Data snooping: 2R̂ (standard Rademacher bound)
    2. Sub-Gaussian error: For σ²-sub-Gaussian X, P(X > t) ≤ exp(-t²/2σ²).
       Daily returns typically have σ ≈ 1 when standardized.
       Factor 3 provides conservatism for heavier tails.
    3. Union bound: P(∃n: |X̂ₙ - Xₙ| > t) ≤ N × single-strategy bound.
       Contributes √(2 log(2N/δ)/T) term.

    **Comparison to DSR**:
    - DSR assumes independent strategies (overpenalizes correlated ones)
    - RAS captures correlation via R̂ (correlated → lower R̂ → less penalty)
    - RAS is non-asymptotic; DSR requires large T

    Examples
    --------
    >>> import numpy as np
    >>> returns = np.random.randn(252, 100) * 0.01  # 100 strategies, 1 year
    >>> observed_sr = returns.mean(axis=0) / returns.std(axis=0) * np.sqrt(252)
    >>> R_hat = rademacher_complexity(returns)
    >>> result = ras_sharpe_adjustment(
    ...     observed_sr, R_hat, 252, 100, return_result=True
    ... )
    >>> print(f"Significant: {result.n_significant}/100")

    References
    ----------
    .. [1] Paleologo (2024), Section 8.3.2, Procedure 8.2.
    """
    observed_sharpe = np.asarray(observed_sharpe)

    if observed_sharpe.ndim != 1:
        raise ValueError(f"observed_sharpe must be 1D, got shape {observed_sharpe.shape}")

    if complexity < 0:
        raise ValueError(f"complexity must be non-negative, got {complexity}")

    if n_samples < 1:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    if n_strategies < 1:
        raise ValueError(f"n_strategies must be positive, got {n_strategies}")

    if not 0 < delta < 1:
        raise ValueError(f"delta must be in (0, 1), got {delta}")

    T = n_samples
    N = n_strategies

    # (a) Data snooping penalty: 2R̂
    data_snooping = 2 * complexity

    # (b) Sub-Gaussian estimation error (independent of N)
    # Factor 3 for conservatism with potential heavy tails
    error_term1 = 3 * np.sqrt(2 * np.log(2 / delta) / T)

    # (c) Union bound over N strategies
    error_term2 = np.sqrt(2 * np.log(2 * N / delta) / T)

    estimation_error = error_term1 + error_term2

    # Conservative lower bound
    adjusted_sharpe = observed_sharpe - data_snooping - estimation_error

    if not return_result:
        return adjusted_sharpe

    # Compute diagnostics
    massart_bound = np.sqrt(2 * np.log(N) / T) if N > 1 else 0.0
    significant_mask = adjusted_sharpe > 0

    return RASResult(
        adjusted_values=adjusted_sharpe,
        observed_values=observed_sharpe,
        complexity=complexity,
        data_snooping_penalty=data_snooping,
        estimation_error=estimation_error,
        n_significant=int(np.sum(significant_mask)),
        significant_mask=significant_mask,
        massart_bound=massart_bound,
        complexity_ratio=complexity / massart_bound if massart_bound > 0 else 0.0,
    )

benjamini_hochberg_fdr

benjamini_hochberg_fdr(
    p_values, alpha=0.05, return_details=False
)

Apply Benjamini-Hochberg False Discovery Rate correction.

Controls the False Discovery Rate (FDR) - the expected proportion of false discoveries among the rejected hypotheses. More powerful than Bonferroni correction for multiple hypothesis testing.

Based on Benjamini & Hochberg (1995): "Controlling the False Discovery Rate"

Parameters

p_values : Sequence[float]
    P-values from multiple hypothesis tests
alpha : float, default 0.05
    Target FDR level (e.g., 0.05 for 5% FDR)
return_details : bool, default False
    Whether to return detailed information

Returns

Union[NDArray, dict]
    If return_details=False: Boolean array of rejected hypotheses
    If return_details=True: dict with 'rejected', 'adjusted_p_values', 'critical_values', 'n_rejected'

Examples

p_values = [0.001, 0.01, 0.03, 0.08, 0.12]
rejected = benjamini_hochberg_fdr(p_values, alpha=0.05)
print(f"Rejected: {rejected}")
# prints: Rejected: [ True  True  True False False]
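The step-up character of the procedure is easiest to see on inputs where a p-value misses its own threshold but is rejected anyway because a larger one passes. The following is a minimal standalone sketch of Benjamini-Hochberg via adjusted p-values. The helper bh_step_up is hypothetical and is not the library function, although it follows the same adjusted p-value formula as the source listing.

```python
import numpy as np


def bh_step_up(p_values, alpha=0.05):
    """Minimal Benjamini-Hochberg sketch (hypothetical helper).

    Returns (rejected, adjusted_p), both in the original input order.
    """
    p = np.asarray(p_values, dtype=float)
    n = p.size
    order = np.argsort(p)
    ranks = np.arange(1, n + 1)
    # p_(i) * n / i, then a running minimum taken from the largest p-value down.
    scaled = p[order] * n / ranks
    adjusted_sorted = np.minimum(np.minimum.accumulate(scaled[::-1])[::-1], 1.0)
    adjusted = np.empty(n)
    adjusted[order] = adjusted_sorted
    return adjusted <= alpha, adjusted


p_values = [0.001, 0.025, 0.028, 0.20, 0.50]
rejected, adjusted = bh_step_up(p_values)
print(rejected)
print(adjusted.round(4))
```

Here the thresholds (i/n)·alpha are 0.01, 0.02, 0.03, 0.04, 0.05. The second p-value (0.025) exceeds 0.02, yet it is rejected because the third (0.028) clears 0.03 and step-up rejects everything at or below the largest passing rank.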

Source code in src/ml4t/diagnostic/evaluation/stats/false_discovery_rate.py
def benjamini_hochberg_fdr(
    p_values: Sequence[float],
    alpha: float = 0.05,
    return_details: bool = False,
) -> Union["NDArray[Any]", dict[str, Any]]:
    """Apply Benjamini-Hochberg False Discovery Rate correction.

    Controls the False Discovery Rate (FDR) - the expected proportion of false
    discoveries among the rejected hypotheses. More powerful than Bonferroni
    correction for multiple hypothesis testing.

    Based on Benjamini & Hochberg (1995): "Controlling the False Discovery Rate"

    Parameters
    ----------
    p_values : Sequence[float]
        P-values from multiple hypothesis tests
    alpha : float, default 0.05
        Target FDR level (e.g., 0.05 for 5% FDR)
    return_details : bool, default False
        Whether to return detailed information

    Returns
    -------
    Union[NDArray, dict]
        If return_details=False: Boolean array of rejected hypotheses
        If return_details=True: dict with 'rejected', 'adjusted_p_values',
                               'critical_values', 'n_rejected'

    Examples
    --------
    >>> p_values = [0.001, 0.01, 0.03, 0.08, 0.12]
    >>> rejected = benjamini_hochberg_fdr(p_values, alpha=0.05)
    >>> print(f"Rejected: {rejected}")
    Rejected: [ True  True  True False False]
    """
    p_array = np.array(p_values)
    n = len(p_array)

    if n == 0:
        if return_details:
            return {
                "rejected": np.array([], dtype=bool),
                "adjusted_p_values": np.array([]),
                "critical_values": np.array([]),
                "n_rejected": 0,
            }
        return np.array([], dtype=bool)

    # Sort p-values and keep track of original indices
    sorted_indices = np.argsort(p_array)
    sorted_p_values = p_array[sorted_indices]

    # Calculate critical values: (i/n) * alpha
    critical_values = np.arange(1, n + 1) / n * alpha

    # Find largest i such that P(i) <= (i/n) * alpha
    # Work backwards from largest p-value
    rejected_sorted = np.zeros(n, dtype=bool)

    for i in range(n - 1, -1, -1):
        if sorted_p_values[i] <= critical_values[i]:
            # Reject this and all smaller p-values
            rejected_sorted[: i + 1] = True
            break

    # Map back to original order
    rejected = np.zeros(n, dtype=bool)
    rejected[sorted_indices] = rejected_sorted

    if not return_details:
        return rejected

    # Calculate adjusted p-values (step-up method)
    adjusted_p_values = np.zeros(n)
    adjusted_p_values[sorted_indices] = np.minimum.accumulate(
        sorted_p_values[::-1] * n / np.arange(n, 0, -1),
    )[::-1]

    # Ensure adjusted p-values don't exceed 1
    adjusted_p_values = np.minimum(adjusted_p_values, 1.0)

    return {
        "rejected": rejected,
        "adjusted_p_values": adjusted_p_values,
        "critical_values": critical_values[sorted_indices],
        "n_rejected": int(np.sum(rejected)),
    }

holm_bonferroni

holm_bonferroni(p_values, alpha=0.05)

Holm-Bonferroni step-down procedure for FWER control.

Controls the Family-Wise Error Rate (FWER) - the probability of making at least one false discovery. More powerful than Bonferroni correction while maintaining strong FWER control.

Based on Holm (1979): "A Simple Sequentially Rejective Multiple Test Procedure"

Parameters

p_values : Sequence[float]
    P-values from multiple hypothesis tests
alpha : float, default 0.05
    Target FWER significance level

Returns

dict
    Dictionary with:
    - rejected: list[bool] - Whether each hypothesis is rejected
    - adjusted_p_values: list[float] - Holm-adjusted p-values
    - n_rejected: int - Number of rejections
    - critical_values: list[float] - Holm critical thresholds

Notes

The Holm procedure is a step-down method:

  1. Sort p-values ascending: p_(1) <= p_(2) <= ... <= p_(m)
  2. For i = 1, ..., m, compare p_(i) to alpha / (m - i + 1)
  3. Reject hypotheses in that order for as long as p_(i) <= alpha / (m - i + 1)
  4. Stop at the first non-rejection; that hypothesis and all remaining ones are not rejected

This is uniformly more powerful than Bonferroni while controlling FWER.
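The step-down rule can be sketched through Holm-adjusted p-values, which turn the sequential comparison into a single threshold check. This is a standalone illustration under the same formula as the source listing (running maximum of (m - i + 1)·p_(i)); the helper name holm_step_down is hypothetical.

```python
import numpy as np


def holm_step_down(p_values, alpha=0.05):
    """Minimal Holm sketch (hypothetical helper).

    Returns (rejected, adjusted_p), both in the original input order.
    """
    p = np.asarray(p_values, dtype=float)
    m = p.size
    order = np.argsort(p)
    multipliers = m - np.arange(m)  # m, m-1, ..., 1
    # The running maximum enforces "once one test fails, all later ones fail".
    adjusted_sorted = np.minimum(np.maximum.accumulate(p[order] * multipliers), 1.0)
    adjusted = np.empty(m)
    adjusted[order] = adjusted_sorted
    return adjusted <= alpha, adjusted


p_values = [0.001, 0.025, 0.028, 0.20, 0.50]
rejected, adjusted = holm_step_down(p_values)
print(rejected)
print(adjusted.round(4))
```

The thresholds are alpha/5, alpha/4, ..., alpha/1. Only the smallest p-value clears its threshold (0.001 <= 0.01); the second (0.025 > 0.0125) stops the procedure, so a single hypothesis is rejected. Holm targets the FWER, so it is stricter than an FDR procedure applied to the same inputs.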

Examples

p_values = [0.001, 0.01, 0.03, 0.08, 0.12]
result = holm_bonferroni(p_values, alpha=0.05)
print(f"Rejected: {result['rejected']}")
# prints: Rejected: [True, True, False, False, False]

Source code in src/ml4t/diagnostic/evaluation/stats/false_discovery_rate.py
def holm_bonferroni(
    p_values: Sequence[float],
    alpha: float = 0.05,
) -> dict[str, Any]:
    """Holm-Bonferroni step-down procedure for FWER control.

    Controls the Family-Wise Error Rate (FWER) - the probability of making
    at least one false discovery. More powerful than Bonferroni correction
    while maintaining strong FWER control.

    Based on Holm (1979): "A Simple Sequentially Rejective Multiple Test Procedure"

    Parameters
    ----------
    p_values : Sequence[float]
        P-values from multiple hypothesis tests
    alpha : float, default 0.05
        Target FWER significance level

    Returns
    -------
    dict
        Dictionary with:
        - rejected: list[bool] - Whether each hypothesis is rejected
        - adjusted_p_values: list[float] - Holm-adjusted p-values
        - n_rejected: int - Number of rejections
        - critical_values: list[float] - Holm critical thresholds

    Notes
    -----
    The Holm procedure is a step-down method:

    1. Sort p-values ascending: p_(1) <= p_(2) <= ... <= p_(m)
    2. For p_(i), compare to alpha / (m - i + 1)
    3. Reject all hypotheses up to (and including) the last rejection
    4. Stop at first non-rejection; accept remaining hypotheses

    This is uniformly more powerful than Bonferroni while controlling FWER.

    Examples
    --------
    >>> p_values = [0.001, 0.01, 0.03, 0.08, 0.12]
    >>> result = holm_bonferroni(p_values, alpha=0.05)
    >>> print(f"Rejected: {result['rejected']}")
    Rejected: [True, True, False, False, False]
    """
    p_array = np.asarray(p_values, dtype=np.float64)
    m = len(p_array)

    if m == 0:
        return {
            "rejected": [],
            "adjusted_p_values": [],
            "n_rejected": 0,
            "critical_values": [],
        }

    # Sort p-values and track original indices
    sorted_indices = np.argsort(p_array)
    sorted_p = p_array[sorted_indices]

    # Holm critical values: alpha / (m - i) for 0-based i = 0, 1, ..., m-1
    # (equivalently alpha / (m - k + 1) for 1-based rank k),
    # i.e., alpha/m, alpha/(m-1), ..., alpha/1
    critical_values = alpha / (m - np.arange(m))

    # Step-down procedure: reject while p_(i) <= critical_(i)
    rejected_sorted = sorted_p <= critical_values

    # Once we fail to reject, accept all remaining
    if not rejected_sorted.all():
        first_fail = np.argmin(rejected_sorted)
        rejected_sorted[first_fail:] = False

    # Map back to original order
    rejected = np.zeros(m, dtype=bool)
    rejected[sorted_indices] = rejected_sorted

    # Compute Holm-adjusted p-values
    # adjusted_p_(i) = max_{j <= i} { (m - j + 1) * p_(j) }
    adjusted_sorted = np.maximum.accumulate(sorted_p * (m - np.arange(m)))
    adjusted_sorted = np.clip(adjusted_sorted, 0.0, 1.0)

    # Map adjusted p-values back to original order
    adjusted_p_values = np.zeros(m)
    adjusted_p_values[sorted_indices] = adjusted_sorted

    # Critical values in original order
    critical_original = np.zeros(m)
    critical_original[sorted_indices] = critical_values

    return {
        "rejected": rejected.tolist(),
        "adjusted_p_values": adjusted_p_values.tolist(),
        "n_rejected": int(rejected.sum()),
        "critical_values": critical_original.tolist(),
    }

multiple_testing_summary

multiple_testing_summary(
    test_results, method="benjamini_hochberg", alpha=0.05
)

Summarize results from multiple statistical tests with corrections.

Provides a comprehensive summary of multiple hypothesis testing results with appropriate corrections for multiple comparisons.

Parameters

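test_results : Sequence[dict]
    List of test result dictionaries (each should have a 'p_value' key). Entries with a missing or NaN p-value are excluded from the correction.
method : str, default "benjamini_hochberg"
    Multiple testing correction method. Only "benjamini_hochberg" is implemented; other values raise ValueError when there are p-values to correct.
alpha : float, default 0.05
    Significance level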

Returns

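dict
    Summary with original and corrected results. Keys: 'n_tests', 'n_significant_uncorrected', 'n_significant_corrected', 'correction_method', 'alpha', 'adjusted_p_values', 'rejected_hypotheses', 'uncorrected_rate', 'corrected_rate'. The last four are omitted when there are no valid p-values.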

Examples

results = [{'name': 'Strategy A', 'p_value': 0.01},
           {'name': 'Strategy B', 'p_value': 0.08}]
summary = multiple_testing_summary(results)
print(f"Significant after correction: {summary['n_significant_corrected']}")

Source code in src/ml4t/diagnostic/evaluation/stats/false_discovery_rate.py
def multiple_testing_summary(
    test_results: Sequence[dict[str, Any]],
    method: str = "benjamini_hochberg",
    alpha: float = 0.05,
) -> dict[str, Any]:
    """Summarize results from multiple statistical tests with corrections.

    Provides a comprehensive summary of multiple hypothesis testing results
    with appropriate corrections for multiple comparisons.

    Parameters
    ----------
    test_results : Sequence[dict]
        List of test result dictionaries (each should have 'p_value' key)
    method : str, default "benjamini_hochberg"
        Multiple testing correction method. Only "benjamini_hochberg" is
        implemented; other values raise ValueError.
    alpha : float, default 0.05
        Significance level

    Returns
    -------
    dict
        Summary with original and corrected results

    Examples
    --------
    >>> results = [{'name': 'Strategy A', 'p_value': 0.01},
    ...           {'name': 'Strategy B', 'p_value': 0.08}]
    >>> summary = multiple_testing_summary(results)
    >>> print(f"Significant after correction: {summary['n_significant_corrected']}")
    """
    if not test_results:
        return {
            "n_tests": 0,
            "n_significant_uncorrected": 0,
            "n_significant_corrected": 0,
            "correction_method": method,
            "alpha": alpha,
        }

    # Extract p-values
    p_values = [result.get("p_value", np.nan) for result in test_results]
    valid_p_values = [p for p in p_values if not np.isnan(p)]

    if not valid_p_values:
        return {
            "n_tests": len(test_results),
            "n_significant_uncorrected": 0,
            "n_significant_corrected": 0,
            "correction_method": method,
            "alpha": alpha,
            "warning": "No valid p-values found",
        }

    # Uncorrected significance
    n_significant_uncorrected = sum(p <= alpha for p in valid_p_values)

    # Apply correction
    if method == "benjamini_hochberg":
        correction_result = benjamini_hochberg_fdr(
            valid_p_values,
            alpha=alpha,
            return_details=True,
        )
        n_significant_corrected = correction_result["n_rejected"]
        adjusted_p_values = correction_result["adjusted_p_values"]
        rejected = correction_result["rejected"]
    else:
        raise ValueError(f"Unknown correction method: {method}")

    return {
        "n_tests": len(test_results),
        "n_significant_uncorrected": n_significant_uncorrected,
        "n_significant_corrected": n_significant_corrected,
        "correction_method": method,
        "alpha": alpha,
        "adjusted_p_values": adjusted_p_values.tolist(),
        "rejected_hypotheses": rejected.tolist(),
        "uncorrected_rate": n_significant_uncorrected / len(valid_p_values),
        "corrected_rate": n_significant_corrected / len(valid_p_values),
    }

robust_ic

robust_ic(
    predictions,
    returns,
    n_samples=1000,
    return_details=False,
)

Calculate Information Coefficient with robust standard errors.

Uses the stationary bootstrap [1] to compute standard errors that properly account for temporal dependence in time series data.

The stationary bootstrap is well suited here because it:

1. Preserves the temporal dependence structure
2. Requires no asymptotic approximation for the standard error
3. Is theoretically valid for rank correlation (Spearman IC)
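The idea can be sketched end to end in NumPy: resample (prediction, return) pairs in blocks of random geometric length, recompute the Spearman IC on each resample, and use the spread of those values as the standard error. Everything below is an illustrative assumption, including the helper names, the mean block length of 6, and the 500 replications; the library's stationary_bootstrap_ic is not shown here and may differ in its details.

```python
import numpy as np


def stationary_bootstrap_indices(n, mean_block, rng):
    """Politis-Romano resampling: geometric block lengths, circular wrap-around."""
    restart = rng.random(n) < 1.0 / mean_block  # start a new block with prob 1/mean_block
    starts = rng.integers(0, n, size=n)
    idx = np.empty(n, dtype=int)
    idx[0] = starts[0]
    for t in range(1, n):
        idx[t] = starts[t] if restart[t] else (idx[t - 1] + 1) % n
    return idx


def average_ranks(a):
    """Ranks with ties averaged (bootstrap samples contain repeated observations)."""
    _, inverse, counts = np.unique(a, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)
    return (last_rank - (counts - 1) / 2.0)[inverse]


def spearman_ic(x, y):
    """Spearman rank correlation as the Pearson correlation of average ranks."""
    return float(np.corrcoef(average_ranks(x), average_ranks(y))[0, 1])


rng = np.random.default_rng(0)
predictions = rng.standard_normal(252)
returns = 0.1 * predictions + 0.5 * rng.standard_normal(252)

ic = spearman_ic(predictions, returns)
boot_ics = []
for _ in range(500):
    idx = stationary_bootstrap_indices(len(predictions), mean_block=6, rng=rng)
    # Resample pairs jointly so the prediction/return alignment is preserved.
    boot_ics.append(spearman_ic(predictions[idx], returns[idx]))

bootstrap_std = float(np.std(boot_ics, ddof=1))
t_stat = ic / bootstrap_std
print(f"IC={ic:.3f}, bootstrap std={bootstrap_std:.3f}, t-stat={t_stat:.2f}")
```

Using the same index vector for both series is the important design choice: resampling them independently would destroy the very relationship the IC measures.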

Parameters

predictions : Union[pl.Series, pd.Series, NDArray]
    Model predictions or scores
returns : Union[pl.Series, pd.Series, NDArray]
    Forward returns corresponding to predictions
n_samples : int, default 1000
    Number of bootstrap samples
return_details : bool, default False
    Whether to return detailed statistics

Returns

Union[dict, float]
    If return_details=False: t-statistic (IC / bootstrap_std), or NaN when bootstrap_std is not positive
    If return_details=True: dict with 'ic', 'bootstrap_std', 't_stat', 'p_value', 'ci_lower', 'ci_upper'

Examples

predictions = np.random.randn(252)
returns = 0.1 * predictions + np.random.randn(252) * 0.5
result = robust_ic(predictions, returns, return_details=True)
print(f"IC: {result['ic']:.3f}, t-stat: {result['t_stat']:.3f}")

References

[1] Politis, D.N. & Romano, J.P. (1994). "The Stationary Bootstrap." Journal of the American Statistical Association 89:1303-1313.

Source code in src/ml4t/diagnostic/evaluation/stats/hac_standard_errors.py
def robust_ic(
    predictions: Union[pl.Series, pd.Series, "NDArray[Any]"],
    returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
    n_samples: int = 1000,
    return_details: bool = False,
) -> dict[str, float] | float:
    """Calculate Information Coefficient with robust standard errors.

    Uses stationary bootstrap [1]_ to compute standard errors that properly
    account for temporal dependence in time series data.

    The stationary bootstrap is the correct method because:
    1. Preserves temporal dependence structure
    2. No asymptotic approximations required
    3. Theoretically valid for rank correlation (Spearman IC)

    Parameters
    ----------
    predictions : Union[pl.Series, pd.Series, NDArray]
        Model predictions or scores
    returns : Union[pl.Series, pd.Series, NDArray]
        Forward returns corresponding to predictions
    n_samples : int, default 1000
        Number of bootstrap samples
    return_details : bool, default False
        Whether to return detailed statistics

    Returns
    -------
    Union[dict, float]
        If return_details=False: t-statistic (IC / bootstrap_std)
        If return_details=True: dict with 'ic', 'bootstrap_std', 't_stat',
            'p_value', 'ci_lower', 'ci_upper'

    Examples
    --------
    >>> predictions = np.random.randn(252)
    >>> returns = 0.1 * predictions + np.random.randn(252) * 0.5
    >>> result = robust_ic(predictions, returns, return_details=True)
    >>> print(f"IC: {result['ic']:.3f}, t-stat: {result['t_stat']:.3f}")

    References
    ----------
    .. [1] Politis, D.N. & Romano, J.P. (1994). "The Stationary Bootstrap."
           Journal of the American Statistical Association 89:1303-1313.
    """
    bootstrap_result = stationary_bootstrap_ic(
        predictions, returns, n_samples=n_samples, return_details=True
    )
    assert isinstance(bootstrap_result, dict)

    if not return_details:
        if bootstrap_result["bootstrap_std"] > 0:
            return bootstrap_result["ic"] / bootstrap_result["bootstrap_std"]
        return np.nan

    # Compute t-statistic
    t_stat = (
        bootstrap_result["ic"] / bootstrap_result["bootstrap_std"]
        if bootstrap_result["bootstrap_std"] > 0
        else np.nan
    )

    return {
        "ic": bootstrap_result["ic"],
        "bootstrap_std": bootstrap_result["bootstrap_std"],
        "t_stat": t_stat,
        "p_value": bootstrap_result.get("p_value", np.nan),
        "ci_lower": bootstrap_result.get("ci_lower", np.nan),
        "ci_upper": bootstrap_result.get("ci_upper", np.nan),
    }

whites_reality_check

whites_reality_check(
    returns_benchmark,
    returns_strategies,
    bootstrap_samples=1000,
    block_size=None,
    random_state=None,
)

Perform White's Reality Check for multiple strategy comparison.

Tests whether any strategy significantly outperforms a benchmark after adjusting for multiple comparisons and data mining bias. Uses stationary bootstrap to preserve temporal dependencies.

Parameters

returns_benchmark : Union[pl.Series, pd.Series, NDArray]
    Benchmark strategy returns
returns_strategies : Union[pd.DataFrame, pl.DataFrame, NDArray]
    Returns for multiple strategies being tested
bootstrap_samples : int, default 1000
    Number of bootstrap samples for null distribution
block_size : Optional[int], default None
    Block size for stationary bootstrap. If None, uses the rule of thumb max(1, int(n_periods ** (1/3)))
random_state : Optional[int], default None
    Random seed for reproducible results

Returns

dict
    Dictionary with 'test_statistic', 'p_value', 'critical_values', 'best_strategy_performance', 'null_distribution'

Notes

Test Hypothesis:

- H0: No strategy beats the benchmark (max E[r_i - r_benchmark] <= 0)
- H1: At least one strategy beats the benchmark

Interpretation:

- p_value < 0.05: Reject H0, at least one strategy beats benchmark
- p_value >= 0.05: Cannot reject H0, no evidence of outperformance
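The test logic can be sketched compactly: take the largest mean excess return as the statistic, build a null distribution of that maximum from block-resampled excess returns recentred on the observed means (the recentring used in White, 2000), and report the fraction of null maxima at or above the statistic. This is a standalone illustration under stated assumptions; the helper names, the 500 replications, and the simulated data are hypothetical, and the library's own implementation may differ in its centring and in what else it returns.

```python
import numpy as np


def stationary_bootstrap_indices(n, mean_block, rng):
    """Politis-Romano resampling: geometric block lengths, circular wrap-around."""
    restart = rng.random(n) < 1.0 / mean_block
    starts = rng.integers(0, n, size=n)
    idx = np.empty(n, dtype=int)
    idx[0] = starts[0]
    for t in range(1, n):
        idx[t] = starts[t] if restart[t] else (idx[t - 1] + 1) % n
    return idx


def reality_check_pvalue(benchmark, strategies, n_boot=500, seed=0):
    """Return (test_statistic, p_value) for H0: no strategy beats the benchmark."""
    rng = np.random.default_rng(seed)
    relative = strategies - benchmark[:, None]   # excess return per strategy
    n = relative.shape[0]
    observed_means = relative.mean(axis=0)
    test_statistic = observed_means.max()
    mean_block = max(1, int(n ** (1 / 3)))       # same rule of thumb as the source
    null_max = np.empty(n_boot)
    for b in range(n_boot):
        idx = stationary_bootstrap_indices(n, mean_block, rng)
        # Centre on the observed means so the resampled maxima reflect H0.
        null_max[b] = (relative[idx].mean(axis=0) - observed_means).max()
    return float(test_statistic), float(np.mean(null_max >= test_statistic))


rng = np.random.default_rng(1)
n = 500
benchmark = rng.normal(0.0, 0.01, n)
noise_only = benchmark[:, None] + rng.normal(0.0, 0.01, (n, 5))  # no true edge
with_edge = noise_only.copy()
with_edge[:, 0] += 0.005                                          # one genuine outperformer

stat0, p0 = reality_check_pvalue(benchmark, noise_only)
stat1, p1 = reality_check_pvalue(benchmark, with_edge)
print(f"no edge:   statistic={stat0:.5f}, p-value={p0:.3f}")
print(f"with edge: statistic={stat1:.5f}, p-value={p1:.3f}")
```

Because the null distribution is built from the maximum across all strategies, a strategy has to beat the best that luck alone would produce among the whole set, not just clear a single-test threshold.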

Examples

benchmark_returns = np.random.normal(0.001, 0.02, 252)
strategy_returns = np.random.normal(0.002, 0.02, (252, 10))
result = whites_reality_check(benchmark_returns, strategy_returns)
print(f"Reality Check p-value: {result['p_value']:.3f}")

References

White, H. (2000). "A Reality Check for Data Snooping." Econometrica, 68(5), 1097-1126.

Source code in src/ml4t/diagnostic/evaluation/stats/reality_check.py
def whites_reality_check(
    returns_benchmark: Union[pl.Series, pd.Series, "NDArray[Any]"],
    returns_strategies: Union[pd.DataFrame, pl.DataFrame, "NDArray[Any]"],
    bootstrap_samples: int = 1000,
    block_size: int | None = None,
    random_state: int | None = None,
) -> dict[str, Any]:
    """Perform White's Reality Check for multiple strategy comparison.

    Tests whether any strategy significantly outperforms a benchmark after
    adjusting for multiple comparisons and data mining bias. Uses stationary
    bootstrap to preserve temporal dependencies.

    Parameters
    ----------
    returns_benchmark : Union[pl.Series, pd.Series, NDArray]
        Benchmark strategy returns
    returns_strategies : Union[pd.DataFrame, pl.DataFrame, NDArray]
        Returns for multiple strategies being tested
    bootstrap_samples : int, default 1000
        Number of bootstrap samples for null distribution
    block_size : Optional[int], default None
        Block size for stationary bootstrap. If None, uses optimal size
    random_state : Optional[int], default None
        Random seed for reproducible results

    Returns
    -------
    dict
        Dictionary with 'test_statistic', 'p_value', 'critical_values',
        'best_strategy_performance', 'null_distribution'

    Notes
    -----
    **Test Hypothesis**:
    - H0: No strategy beats the benchmark (max E[r_i - r_benchmark] <= 0)
    - H1: At least one strategy beats the benchmark

    **Interpretation**:
    - p_value < 0.05: Reject H0, at least one strategy beats benchmark
    - p_value >= 0.05: Cannot reject H0, no evidence of outperformance

    Examples
    --------
    >>> benchmark_returns = np.random.normal(0.001, 0.02, 252)
    >>> strategy_returns = np.random.normal(0.002, 0.02, (252, 10))
    >>> result = whites_reality_check(benchmark_returns, strategy_returns)
    >>> print(f"Reality Check p-value: {result['p_value']:.3f}")

    References
    ----------
    White, H. (2000). "A Reality Check for Data Snooping."
    Econometrica, 68(5), 1097-1126.
    """
    # Convert inputs
    benchmark = DataFrameAdapter.to_numpy(returns_benchmark).flatten()

    if isinstance(returns_strategies, pd.DataFrame | pl.DataFrame):
        strategies = DataFrameAdapter.to_numpy(returns_strategies)
    else:
        strategies = np.array(returns_strategies)

    # A single strategy passed as a 1-D array becomes one column
    if strategies.ndim == 1:
        strategies = strategies.reshape(-1, 1)

    n_periods, n_strategies = strategies.shape

    if len(benchmark) != n_periods:
        raise ValueError("Benchmark and strategies must have same number of periods")

    # Calculate relative performance (strategies vs benchmark)
    relative_returns = strategies - benchmark.reshape(-1, 1)

    # Test statistic: maximum mean relative performance
    mean_relative_returns = np.mean(relative_returns, axis=0)
    test_statistic = np.max(mean_relative_returns)
    best_strategy_idx = np.argmax(mean_relative_returns)

    # Bootstrap null distribution
    if random_state is not None:
        np.random.seed(random_state)

    # Rule-of-thumb block size for stationary bootstrap: n_periods ** (1 / 3)
    if block_size is None:
        block_size = max(1, int(n_periods ** (1 / 3)))

    null_dist_list: list[float] = []

    for _ in range(bootstrap_samples):
        # Stationary bootstrap resampling
        bootstrap_indices = _stationary_bootstrap_indices(n_periods, float(block_size))

        # Resample relative returns
        bootstrap_relative = relative_returns[bootstrap_indices]

        # Recenter on the original sample means (impose null hypothesis);
        # subtracting the resample's own means would zero out every column mean
        bootstrap_relative = bootstrap_relative - mean_relative_returns

        # Calculate maximum mean for this bootstrap sample
        bootstrap_max = np.max(np.mean(bootstrap_relative, axis=0))
        null_dist_list.append(float(bootstrap_max))

    null_distribution = np.array(null_dist_list)

    # Calculate p-value
    p_value = np.mean(null_distribution >= test_statistic)

    # Calculate critical values
    critical_values = {
        "90%": np.percentile(null_distribution, 90),
        "95%": np.percentile(null_distribution, 95),
        "99%": np.percentile(null_distribution, 99),
    }

    return {
        "test_statistic": float(test_statistic),
        "p_value": float(p_value),
        "critical_values": critical_values,
        "best_strategy_idx": int(best_strategy_idx),
        "best_strategy_performance": float(mean_relative_returns[best_strategy_idx]),
        "null_distribution": null_distribution,
        "n_strategies": n_strategies,
        "n_periods": n_periods,
    }
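
The helper `_stationary_bootstrap_indices` is called in the listing above but its body is not shown. As an illustration only, the sketch below shows one common way to draw stationary-bootstrap indices in the style of Politis and Romano (1994): blocks start at a random position, continue with wrap-around, and end with probability `1 / mean_block_size` at each step. The function name `stationary_bootstrap_indices_sketch`, its `rng` parameter, and the synthetic `relative` array are assumptions made for this example and need not match the library's private helper.

```python
import numpy as np


def stationary_bootstrap_indices_sketch(
    n_periods: int, mean_block_size: float, rng: np.random.Generator
) -> np.ndarray:
    """Draw n_periods resampled time indices with geometric block lengths.

    Hypothetical stand-in for the library's private helper.
    """
    if n_periods < 1 or mean_block_size < 1:
        raise ValueError("n_periods and mean_block_size must be >= 1")

    p_new_block = 1.0 / mean_block_size  # chance of starting a new block
    indices = np.empty(n_periods, dtype=np.int64)
    indices[0] = rng.integers(n_periods)
    for t in range(1, n_periods):
        if rng.random() < p_new_block:
            # Start a new block at a uniformly random position
            indices[t] = rng.integers(n_periods)
        else:
            # Continue the current block, wrapping past the last period
            indices[t] = (indices[t - 1] + 1) % n_periods
    return indices


rng = np.random.default_rng(0)
idx = stationary_bootstrap_indices_sketch(252, 252 ** (1 / 3), rng)

# Synthetic strategy-minus-benchmark returns, for illustration only
relative = rng.normal(0.001, 0.02, size=(252, 10))
sample_means = relative.mean(axis=0)

# One bootstrap draw of the max statistic, recentered on the sample means
boot_stat = float(np.max(relative[idx].mean(axis=0) - sample_means))
```

A mean block size of 1 starts a new block at every step, which reduces to an ordinary i.i.d. bootstrap; larger values keep longer runs of consecutive periods together and so preserve more of the serial dependence.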

Integration

The integration surface focuses on contracts and the ml4t-backtest bridge:

| Category | Objects |
| --- | --- |
| Contracts | TradeRecord, DataQualityReport, DataQualityMetrics, DataAnomaly, BacktestReportMetadata |
| Backtest bridge | compute_metrics_from_result, analyze_backtest_result, portfolio_analysis_from_result |
| Tearsheet generation | generate_tearsheet_from_result, profile_from_run_artifacts, generate_tearsheet_from_run_artifacts |

Visualization

The visualization namespace is Plotly-first and grouped by workflow:

| Area | Representative functions |
| --- | --- |
| Cross-validation | plot_cv_folds |
| Signal analysis | plot_ic_ts, plot_quantile_returns_bar, SignalDashboard, MultiSignalDashboard |
| Portfolio analysis | create_portfolio_dashboard, plot_portfolio_cumulative_returns, plot_monthly_returns_heatmap, plot_drawdown_underwater, plot_rolling_sharpe |
| Factor analysis | plot_factor_betas_bar, plot_rolling_betas, plot_return_attribution_waterfall |
| Reporting | combine_figures_to_html, generate_combined_report, export_figures_to_pdf |

For a package-layout overview, see the Architecture page.