ML4T Backtest
ML4T Backtest Documentation
Event-driven backtesting with realistic execution
Skip to content

API Reference

Auto-generated from source docstrings.

Core

Engine

Engine(
    feed,
    strategy,
    config=None,
    *,
    contract_specs=None,
    market_impact_model=None,
    execution_limits=None,
)

Event-driven backtesting engine.

The Engine orchestrates the backtest by iterating through market data, managing the broker, and calling the strategy on each bar.

Execution Flow
  1. Initialize strategy (on_start)
  2. For each bar:
     a. Update broker with current prices
     b. Process pending exits (NEXT_BAR_OPEN mode)
     c. Evaluate position rules (stops, trails)
     d. Process pending orders
     e. Call strategy.on_data()
     f. Process new orders (SAME_BAR mode)
     g. Update water marks
     h. Record equity
  3. Close open positions
  4. Finalize strategy (on_end)

Attributes:

Name Type Description
feed

DataFeed providing price and signal data

strategy

Strategy implementing trading logic

broker

Broker handling order execution and positions

config

BacktestConfig with all behavioral settings

equity_curve list[tuple[datetime, float]]

List of (timestamp, equity) tuples

Example

from ml4t.backtest import Engine, DataFeed, Strategy, BacktestConfig

class MyStrategy(Strategy):
    def on_data(self, timestamp, data, context, broker):
        for asset, bar in data.items():
            if bar.get('signal', 0) > 0.5:
                broker.submit_order(asset, 100)

feed = DataFeed(prices_df=df)
engine = Engine(feed=feed, strategy=MyStrategy())
result = engine.run()
print(result['total_return'])

Source code in src/ml4t/backtest/engine.py
def __init__(
    self,
    feed: DataFeed,
    strategy: Strategy,
    config: BacktestConfig | None = None,
    *,
    contract_specs: dict[str, Any] | None = None,
    market_impact_model: Any | None = None,
    execution_limits: Any | None = None,
):
    """Wire a data feed, a strategy and a broker together for one backtest.

    Args:
        feed: Data source iterated by ``run()``. If it exposes a
            ``feed_spec`` attribute, that spec is merged into the config.
        strategy: Strategy whose hooks ``run()`` invokes.
        config: Behavioral settings. A default ``BacktestConfig()`` is built
            when omitted.
        contract_specs: Forwarded unchanged to ``Broker.from_config``.
        market_impact_model: Forwarded unchanged to ``Broker.from_config``.
        execution_limits: Forwarded unchanged to ``Broker.from_config``.
    """
    from .config import BacktestConfig as ConfigCls

    base_config = ConfigCls() if config is None else config

    self.feed = feed
    self.strategy = strategy
    # Dataset-level metadata carried by the feed (if any) refines the config.
    self.config = base_config.merge_feed_spec(getattr(feed, "feed_spec", None))
    self.execution_mode = self.config.execution_mode
    self.broker = Broker.from_config(
        self.config,
        contract_specs=contract_specs,
        market_impact_model=market_impact_model,
        execution_limits=execution_limits,
    )

    # Result buffers, start empty. Presumably filled per bar by
    # _record_portfolio_state() during run() (TODO confirm).
    self.equity_curve: list[tuple[datetime, float]] = []
    self.portfolio_state: list[tuple[datetime, float, float, float, float, int]] = []

    # Trading-calendar state: the calendar is resolved lazily inside run().
    self._calendar = None
    self._skipped_bars = 0

run

run()

Run backtest and return structured results.

Returns:

Type Description
BacktestResult

BacktestResult with trades, equity curve, metrics, and export methods.

BacktestResult

Call .to_dict() for backward-compatible dictionary output.

Source code in src/ml4t/backtest/engine.py
def run(self) -> BacktestResult:
    """Run backtest and return structured results.

    Calls ``strategy.on_prepare`` and ``strategy.on_start``. For every bar the
    feed yields (and that is not skipped) it then:
      1. updates the broker with the bar's prices,
      2. processes pending exits and position rules,
      3. processes orders around ``strategy.on_data`` according to the
         execution mode,
      4. updates water marks,
      5. records portfolio state.
    Finally it calls ``strategy.on_end``.

    When a calendar is configured and ``enforce_sessions`` is enabled, bars
    whose date is not a trading day are skipped and counted in
    ``self._skipped_bars``. The filter works at trading-day granularity for
    both daily and intraday data. Intraday bars are NOT checked against
    market hours.

    Returns:
        BacktestResult with trades, equity curve, metrics, and export methods.
        Call .to_dict() for backward-compatible dictionary output.
    """
    # Lazy calendar initialization (zero cost if unused)
    is_trading_day_fn = None
    if self.config and self.config.resolved_calendar:
        from .calendar import get_calendar, is_trading_day

        self._calendar = get_calendar(self.config.resolved_calendar)
        is_trading_day_fn = is_trading_day

    self.strategy.on_prepare(self.broker, self.feed.timestamps, self.config)
    self.strategy.on_start(self.broker)

    # Session enforcement depends only on settings that are fixed for the rest
    # of the run, so decide it once instead of re-evaluating it on every bar.
    # This is computed after on_prepare/on_start because those hooks receive
    # the config and could still adjust it.
    calendar_id = self.config.resolved_calendar if self.config else None
    session_calendar_id = (
        calendar_id
        if (
            self._calendar
            and calendar_id
            and self.config
            and self.config.enforce_sessions
            and is_trading_day_fn
        )
        else None
    )

    # Date-level cache for trading day checks (significant speedup for intraday data)
    trading_day_cache: dict[date, bool] = {}

    for timestamp, assets_data, context in self.feed:
        # Calendar session enforcement: one memoized trading-day lookup per
        # calendar date. The second condition only narrows the Optional for
        # type checkers; it is always true when session_calendar_id is set.
        if session_calendar_id is not None and is_trading_day_fn is not None:
            bar_date = timestamp.date()
            if bar_date not in trading_day_cache:
                trading_day_cache[bar_date] = is_trading_day_fn(session_calendar_id, bar_date)
            if not trading_day_cache[bar_date]:
                self._skipped_bars += 1
                continue

        # Fast path: the feed's per-bar object may already carry per-field dicts.
        prices = getattr(assets_data, "_prices", None)
        opens = getattr(assets_data, "_opens", None)
        highs = getattr(assets_data, "_highs", None)
        lows = getattr(assets_data, "_lows", None)
        closes = getattr(assets_data, "_closes", None)
        volumes = getattr(assets_data, "_volumes", None)
        bids = getattr(assets_data, "_bids", None)
        asks = getattr(assets_data, "_asks", None)
        mids = getattr(assets_data, "_mids", None)
        bid_sizes = getattr(assets_data, "_bid_sizes", None)
        ask_sizes = getattr(assets_data, "_ask_sizes", None)
        signals = getattr(assets_data, "_signals", None)

        if (
            prices is None
            or opens is None
            or highs is None
            or lows is None
            or closes is None
            or volumes is None
            or bids is None
            or asks is None
            or mids is None
            or bid_sizes is None
            or ask_sizes is None
            or signals is None
        ):
            # Slow path: rebuild every per-field dict from the per-asset mapping.
            prices = {
                a: price
                for a, d in assets_data.items()
                if (price := d.get("price", d.get("close"))) is not None
            }
            opens = {a: d.get("open", d.get("close")) for a, d in assets_data.items()}
            highs = {a: d.get("high", d.get("close")) for a, d in assets_data.items()}
            lows = {a: d.get("low", d.get("close")) for a, d in assets_data.items()}
            closes = {
                a: close
                for a, d in assets_data.items()
                if (close := d.get("close", d.get("price"))) is not None
            }
            volumes = {a: d.get("volume", 0) for a, d in assets_data.items()}
            bids = {a: d["bid"] for a, d in assets_data.items() if d.get("bid") is not None}
            asks = {a: d["ask"] for a, d in assets_data.items() if d.get("ask") is not None}
            mids = {a: d["mid"] for a, d in assets_data.items() if d.get("mid") is not None}
            bid_sizes = {
                a: d["bid_size"]
                for a, d in assets_data.items()
                if d.get("bid_size") is not None
            }
            ask_sizes = {
                a: d["ask_size"]
                for a, d in assets_data.items()
                if d.get("ask_size") is not None
            }
            signals = {a: d.get("signals", {}) for a, d in assets_data.items()}

        self.broker._update_time(
            timestamp,
            prices,
            opens,
            highs,
            lows,
            closes,
            volumes,
            bids,
            asks,
            mids,
            bid_sizes,
            ask_sizes,
            signals,
        )

        # Process pending exits from NEXT_BAR_OPEN mode (fills at open)
        # This must happen BEFORE evaluate_position_rules() to clear deferred exits
        self.broker._process_pending_exits()

        # Evaluate position rules (stops, trails, etc.) - generates exit orders
        self.broker.evaluate_position_rules()

        if self.execution_mode == ExecutionMode.NEXT_BAR:
            # Next-bar mode: process pending orders at open price
            self.broker._process_orders(use_open=True)
            # Strategy generates new orders
            self.strategy.on_data(timestamp, assets_data, context, self.broker)
            # New orders will be processed next bar
        else:
            # Same-bar mode: process before and after strategy
            self.broker._process_orders()
            self.strategy.on_data(timestamp, assets_data, context, self.broker)
            self.broker._process_orders()

        # Update water marks at END of bar, AFTER all orders processed
        # This ensures new positions get their HWM updated from entry bar's high
        # VBT Pro behavior: HWM updated at bar end, used in NEXT bar's trail evaluation
        self.broker._update_water_marks()

        self._record_portfolio_state(timestamp)

    self.strategy.on_end(self.broker)
    return self._generate_results()

run_dict

run_dict()

Run backtest and return dictionary (backward compatible).

This is equivalent to run().to_dict() but more explicit for code that requires dictionary output.

Returns:

Type Description
dict[str, Any]

Dictionary with metrics, trades, and equity curve.

Source code in src/ml4t/backtest/engine.py
def run_dict(self) -> dict[str, Any]:
    """Run the backtest and return its results as a plain dictionary.

    Shorthand for ``self.run().to_dict()``. It is kept for callers written
    against the older dict-returning API.

    Returns:
        Dictionary with metrics, trades, and equity curve.
    """
    result = self.run()
    return result.to_dict()

from_config classmethod

from_config(
    feed,
    strategy,
    config,
    *,
    contract_specs=None,
    market_impact_model=None,
    execution_limits=None,
)

Create an Engine instance from a BacktestConfig.

Equivalent to Engine(feed, strategy, config). Kept as a convenience for code that reads more clearly with a named constructor.

Parameters:

Name Type Description Default
feed DataFeed

DataFeed with price data

required
strategy Strategy

Strategy to execute

required
config BacktestConfig

BacktestConfig with all behavioral settings

required
contract_specs dict[str, Any] | None

Per-asset contract specifications (futures multipliers, etc.)

None
market_impact_model Any | None

Market impact model for fill simulation

None
execution_limits Any | None

Execution limits (max order size, etc.)

None

Returns:

Type Description
Engine

Configured Engine instance

Source code in src/ml4t/backtest/engine.py
@classmethod
def from_config(
    cls,
    feed: DataFeed,
    strategy: Strategy,
    config: BacktestConfig,
    *,
    contract_specs: dict[str, Any] | None = None,
    market_impact_model: Any | None = None,
    execution_limits: Any | None = None,
) -> Engine:
    """Named constructor: build an Engine from an explicit BacktestConfig.

    Behaves exactly like ``Engine(feed, strategy, config, ...)``. It exists
    only for call sites that read better with a named constructor.

    Args:
        feed: DataFeed with price data
        strategy: Strategy to execute
        config: BacktestConfig with all behavioral settings
        contract_specs: Per-asset contract specifications (futures multipliers, etc.)
        market_impact_model: Market impact model for fill simulation
        execution_limits: Execution limits (max order size, etc.)

    Returns:
        Configured Engine instance
    """
    broker_options = {
        "contract_specs": contract_specs,
        "market_impact_model": market_impact_model,
        "execution_limits": execution_limits,
    }
    return cls(feed, strategy, config, **broker_options)

run_backtest

run_backtest(
    prices,
    strategy,
    signals=None,
    context=None,
    config=None,
    *,
    feed_spec=None,
    contract=None,
    contract_specs=None,
    market_impact_model=None,
    execution_limits=None,
)

Run a backtest with minimal setup.

Parameters:

Name Type Description Default
prices DataFrame | str

Price DataFrame or path to parquet file

required
strategy Strategy

Strategy instance to execute

required
signals DataFrame | str | None

Optional signals DataFrame or path

None
context DataFrame | str | None

Optional context DataFrame or path

None
config BacktestConfig | str | None

BacktestConfig instance, preset name (str), or None for defaults

None
feed_spec Any | None

Optional shared dataset contract for schema and temporal metadata

None
contract Any | None

Alias for feed_spec

None
contract_specs dict[str, Any] | None

Per-asset contract specifications (futures multipliers, etc.)

None
market_impact_model Any | None

Market impact model for fill simulation

None
execution_limits Any | None

Execution limits (max order size, etc.)

None

Returns:

Type Description
BacktestResult

BacktestResult with metrics, trades, equity curve, and export methods.

Example

Using config preset

result = run_backtest(prices_df, strategy, config="backtrader")
print(result.metrics["sharpe"])

Using custom config

config = BacktestConfig.from_preset("backtrader")
config.commission_rate = 0.002
result = run_backtest(prices_df, strategy, config=config)

Futures with contract specs

from ml4t.backtest import ContractSpec, AssetClass
specs = {"ES": ContractSpec(symbol="ES", asset_class=AssetClass.FUTURE, multiplier=50.0)}
result = run_backtest(prices_df, strategy, config=config, contract_specs=specs)

Source code in src/ml4t/backtest/engine.py
def run_backtest(
    prices: pl.DataFrame | str,
    strategy: Strategy,
    signals: pl.DataFrame | str | None = None,
    context: pl.DataFrame | str | None = None,
    config: BacktestConfig | str | None = None,
    *,
    feed_spec: Any | None = None,
    contract: Any | None = None,
    contract_specs: dict[str, Any] | None = None,
    market_impact_model: Any | None = None,
    execution_limits: Any | None = None,
) -> BacktestResult:
    """Run a backtest with minimal setup.

    ``prices``, ``signals`` and ``context`` each accept either a polars
    DataFrame or a parquet path. A path may be a ``str`` or any
    ``os.PathLike`` such as ``pathlib.Path``.

    Args:
        prices: Price DataFrame or path to parquet file
        strategy: Strategy instance to execute
        signals: Optional signals DataFrame or path
        context: Optional context DataFrame or path
        config: BacktestConfig instance, preset name (str), or None for defaults
        feed_spec: Optional shared dataset contract for schema and temporal metadata
        contract: Alias for feed_spec
        contract_specs: Per-asset contract specifications (futures multipliers, etc.)
        market_impact_model: Market impact model for fill simulation
        execution_limits: Execution limits (max order size, etc.)

    Returns:
        BacktestResult with metrics, trades, equity curve, and export methods.

    Raises:
        TypeError: If ``signals`` or ``context`` is not None and is neither a
            polars DataFrame nor a path (e.g. a pandas DataFrame). Such
            inputs used to be dropped silently.
        ValueError: Propagated from ``DataFeed`` when no usable price data is
            given, or when both ``feed_spec`` and ``contract`` are passed.

    Example:
        # Using config preset
        result = run_backtest(prices_df, strategy, config="backtrader")
        print(result.metrics["sharpe"])

        # Using custom config
        config = BacktestConfig.from_preset("backtrader")
        config.commission_rate = 0.002
        result = run_backtest(prices_df, strategy, config=config)

        # Futures with contract specs
        from ml4t.backtest import ContractSpec, AssetClass
        specs = {"ES": ContractSpec(symbol="ES", asset_class=AssetClass.FUTURE, multiplier=50.0)}
        result = run_backtest(prices_df, strategy, config=config, contract_specs=specs)
    """
    import os

    def _split_source(source):
        # Map a DataFrame-or-path argument onto DataFeed's (path, df) pair.
        # Unsupported types map to (None, None).
        if isinstance(source, pl.DataFrame):
            return None, source
        if isinstance(source, (str, os.PathLike)):
            return os.fspath(source), None
        return None, None

    prices_path, prices_df = _split_source(prices)
    signals_path, signals_df = _split_source(signals)
    context_path, context_df = _split_source(context)

    # Without this check, an unsupported signals/context object would vanish
    # and the backtest would run without that data. Prices need no check here:
    # DataFeed already raises when they are missing.
    for name, given, path, frame in (
        ("signals", signals, signals_path, signals_df),
        ("context", context, context_path, context_df),
    ):
        if given is not None and path is None and frame is None:
            raise TypeError(
                f"{name} must be a polars DataFrame, a path, or None; "
                f"got {type(given).__name__}"
            )

    feed = DataFeed(
        prices_path=prices_path,
        signals_path=signals_path,
        context_path=context_path,
        prices_df=prices_df,
        signals_df=signals_df,
        context_df=context_df,
        feed_spec=feed_spec,
        contract=contract,
    )

    if isinstance(config, str):
        from .config import BacktestConfig as ConfigCls

        config = ConfigCls.from_preset(config)

    return Engine(
        feed,
        strategy,
        config,
        contract_specs=contract_specs,
        market_impact_model=market_impact_model,
        execution_limits=execution_limits,
    ).run()

Strategy

Bases: ABC

Base strategy class.

on_data abstractmethod

on_data(timestamp, data, context, broker)

Called for each timestamp with all available data.

Source code in src/ml4t/backtest/strategy.py
@abstractmethod
def on_data(
    self,
    timestamp: datetime,
    data: dict[str, dict],
    context: dict[str, Any],
    broker: Any,  # Avoid circular import, use Any for broker type
) -> None:
    """Handle one timestamp's worth of data; subclasses must implement this.

    Args:
        timestamp: Timestamp of the current bar.
        data: Per-asset mapping of bar fields for this timestamp.
        context: Timestamp-level context values.
        broker: The broker to inspect positions on and submit orders to.
    """

on_start

on_start(broker)

Called before backtest starts.

Source code in src/ml4t/backtest/strategy.py
def on_start(self, broker: Any) -> None:
    """Hook run once before the first bar; the default does nothing."""
    return None

on_prepare

on_prepare(broker, timestamps, config=None)

Called before on_start with access to the full feed timestamp universe.

Source code in src/ml4t/backtest/strategy.py
def on_prepare(
    self,
    broker: Any,
    timestamps: Sequence[datetime],
    config: Any | None = None,
) -> None:
    """Pre-start hook that sees every feed timestamp; the default is a no-op.

    Invoked before ``on_start``. Override it to precompute anything that
    needs the full timestamp universe (e.g. rebalance schedules).

    Args:
        broker: The broker that will execute this strategy's orders.
        timestamps: All feed timestamps, in iteration order.
        config: The backtest configuration, if provided.
    """
    return

on_end

on_end(broker)

Called after backtest ends.

Source code in src/ml4t/backtest/strategy.py
def on_end(self, broker: Any) -> None:
    """Hook run once after the final bar; the default does nothing."""
    return None

DataFeed

DataFeed(
    prices_path=None,
    signals_path=None,
    context_path=None,
    prices_df=None,
    signals_df=None,
    context_df=None,
    *,
    feed_spec=None,
    contract=None,
    entity_col=None,
    timestamp_col=None,
    price_col=None,
    open_col=None,
    high_col=None,
    low_col=None,
    close_col=None,
    volume_col=None,
    bid_col=None,
    ask_col=None,
    mid_col=None,
    bid_size_col=None,
    ask_size_col=None,
)

Polars-based multi-asset data feed with signals and context.

Pre-partitions data by timestamp at initialization for O(1) lookups during iteration. DataFrames are stored in their native format and converted to dicts only at iteration time, reducing memory usage ~10x for large datasets.

Memory Efficiency
  • 1M bars: ~100 MB (was ~1 GB with pre-converted dicts)
  • 10M bars: ~1 GB (vs ~10+ GB with dicts)
Usage

feed = DataFeed(prices_df=prices, signals_df=signals)
for timestamp, assets_data, context in feed:
    # assets_data: {"AAPL": {"close": 150.0, "signals": {...}}, ...}
    process(timestamp, assets_data)

Source code in src/ml4t/backtest/datafeed.py
def __init__(
    self,
    prices_path: str | None = None,
    signals_path: str | None = None,
    context_path: str | None = None,
    prices_df: pl.DataFrame | None = None,
    signals_df: pl.DataFrame | None = None,
    context_df: pl.DataFrame | None = None,
    *,
    feed_spec: FeedSpec | Any | None = None,
    contract: FeedSpec | Any | None = None,
    entity_col: str | None = None,
    timestamp_col: str | None = None,
    price_col: str | None = None,
    open_col: str | None = None,
    high_col: str | None = None,
    low_col: str | None = None,
    close_col: str | None = None,
    volume_col: str | None = None,
    bid_col: str | None = None,
    ask_col: str | None = None,
    mid_col: str | None = None,
    bid_size_col: str | None = None,
    ask_size_col: str | None = None,
):
    """Load price/signal/context tables and pre-partition them by timestamp.

    Each table can be given as an in-memory polars DataFrame (``*_df``) or as
    a parquet path (``*_path``). A DataFrame takes precedence when both are
    given, and an empty path counts as absent.

    Column names come from a ``FeedSpec`` built from ``feed_spec`` or its
    alias ``contract``, with any explicit ``*_col`` arguments applied as
    overrides and resolved against the price columns. Positional indices of
    the resolved columns are cached; optional price columns that are absent
    get index ``-1``, except the price column, which falls back to the close
    column's index.

    Raises:
        ValueError: If both ``feed_spec`` and ``contract`` are given, if no
            price data is supplied, if the timestamp column is missing from
            the signals or context table, or if the entity column is missing
            from the prices or signals table. FeedSpec resolution may raise
            its own errors as well.
    """
    if feed_spec is not None and contract is not None:
        raise ValueError("Pass either feed_spec or contract, not both")

    def _load(frame, path):
        # In-memory frame wins; otherwise read the parquet path if one was given.
        if frame is not None:
            return frame
        return pl.scan_parquet(path).collect() if path else None

    self.prices = _load(prices_df, prices_path)
    self.signals = _load(signals_df, signals_path)
    self.context = _load(context_df, context_path)

    if self.prices is None:
        raise ValueError("prices_path or prices_df required")

    raw_spec = FeedSpec.from_any(feed_spec if feed_spec is not None else contract)
    self.feed_spec = raw_spec.with_overrides(
        entity_col=entity_col,
        timestamp_col=timestamp_col,
        price_col=price_col,
        open_col=open_col,
        high_col=high_col,
        low_col=low_col,
        close_col=close_col,
        volume_col=volume_col,
        bid_col=bid_col,
        ask_col=ask_col,
        mid_col=mid_col,
        bid_size_col=bid_size_col,
        ask_size_col=ask_size_col,
    ).resolve(self.prices.columns, self.ENTITY_COL_CANDIDATES)
    self.contract = self.feed_spec
    self._timestamp_col = self.feed_spec.timestamp_col
    self._entity_col = self.feed_spec.entity_col
    self._price_col = self.feed_spec.price_col
    self._open_col = self.feed_spec.open_col
    self._high_col = self.feed_spec.high_col
    self._low_col = self.feed_spec.low_col
    self._close_col = self.feed_spec.close_col
    self._volume_col = self.feed_spec.volume_col
    self._bid_col = self.feed_spec.bid_col
    self._ask_col = self.feed_spec.ask_col
    self._mid_col = self.feed_spec.mid_col
    self._bid_size_col = self.feed_spec.bid_size_col
    self._ask_size_col = self.feed_spec.ask_size_col

    # Pre-partition data by timestamp for O(1) lookups
    # Store DataFrames (memory efficient) instead of dicts (memory explosion)
    self._prices_by_ts = self._partition_by_timestamp(self.prices)
    self._signals_by_ts = (
        self._partition_by_timestamp(self.signals) if self.signals is not None else {}
    )
    self._context_by_ts = (
        self._partition_by_timestamp(self.context) if self.context is not None else {}
    )

    self._timestamps = self._get_timestamps()
    self._idx = 0
    self._signal_columns = (
        [c for c in self.signals.columns if c not in (self._timestamp_col, self._entity_col)]
        if self.signals is not None
        else []
    )
    self._context_columns = (
        [c for c in self.context.columns if c != self._timestamp_col]
        if self.context is not None
        else []
    )

    price_cols = self.prices.columns

    def _price_idx(col, default=-1):
        # Position of an optional price column, or `default` when it is absent.
        return price_cols.index(col) if col in price_cols else default

    # Same message shape as the timestamp checks below; previously this
    # surfaced as an opaque "'x' is not in list" from list.index().
    if self._entity_col not in price_cols:
        raise ValueError(
            f"entity_col={self._entity_col!r} not found in price columns {price_cols}"
        )
    self._price_asset_idx = price_cols.index(self._entity_col)
    self._price_open_idx = _price_idx(self._open_col)
    self._price_high_idx = _price_idx(self._high_col)
    self._price_low_idx = _price_idx(self._low_col)
    self._price_close_idx = _price_idx(self._close_col)
    # A missing explicit price column falls back to the close column.
    self._price_price_idx = _price_idx(self._price_col, self._price_close_idx)
    self._price_volume_idx = _price_idx(self._volume_col)
    self._price_bid_idx = _price_idx(self._bid_col)
    self._price_ask_idx = _price_idx(self._ask_col)
    self._price_mid_idx = _price_idx(self._mid_col)
    self._price_bid_size_idx = _price_idx(self._bid_size_col)
    self._price_ask_size_idx = _price_idx(self._ask_size_col)

    if self.signals is not None:
        signal_cols = self.signals.columns
        if self._timestamp_col not in signal_cols:
            raise ValueError(
                f"timestamp_col={self._timestamp_col!r} not found in signal columns {signal_cols}"
            )
        if self._entity_col not in signal_cols:
            raise ValueError(
                f"entity_col={self._entity_col!r} not found in signal columns {signal_cols}"
            )
        self._signal_asset_idx = signal_cols.index(self._entity_col)
        self._signal_col_indices = [signal_cols.index(c) for c in self._signal_columns]
    else:
        self._signal_asset_idx = -1
        self._signal_col_indices = []

    if self.context is not None:
        context_cols = self.context.columns
        if self._timestamp_col not in context_cols:
            raise ValueError(
                f"timestamp_col={self._timestamp_col!r} not found in context columns {context_cols}"
            )
        self._context_col_indices = [context_cols.index(c) for c in self._context_columns]
    else:
        self._context_col_indices = []

n_bars property

n_bars

Number of unique timestamps/bars.

timestamps property

timestamps

Unique feed timestamps in iteration order.

Configuration

BacktestConfig dataclass

BacktestConfig(
    allow_short_selling=False,
    allow_leverage=False,
    initial_margin=0.5,
    long_maintenance_margin=0.25,
    short_maintenance_margin=0.3,
    fixed_margin_schedule=None,
    short_cash_policy=CREDIT,
    execution_price=OPEN,
    mark_price=PRICE,
    execution_mode=NEXT_BAR,
    stop_fill_mode=STOP_PRICE,
    stop_level_basis=FILL_PRICE,
    trail_hwm_source=CLOSE,
    initial_hwm_source=FILL_PRICE,
    trail_stop_timing=LAGGED,
    share_type=FRACTIONAL,
    commission_type=PERCENTAGE,
    commission_rate=0.001,
    commission_per_share=0.0,
    commission_per_trade=0.0,
    commission_minimum=0.0,
    slippage_type=PERCENTAGE,
    slippage_rate=0.001,
    slippage_fixed=0.0,
    stop_slippage_rate=0.0,
    initial_cash=100000.0,
    cash_buffer_pct=0.0,
    settlement_delay=0,
    settlement_reduces_buying_power=True,
    reject_on_insufficient_cash=True,
    skip_cash_validation=False,
    partial_fills_allowed=False,
    fill_ordering=EXIT_FIRST,
    entry_order_priority=SUBMISSION,
    next_bar_submission_precheck=False,
    next_bar_simple_cash_check=False,
    buying_power_reservation=False,
    next_bar_queue_shadow_validation=False,
    immediate_fill=False,
    rebalance_mode=INCREMENTAL,
    rebalance_headroom_pct=1.0,
    missing_price_policy=SKIP,
    late_asset_policy=ALLOW,
    late_asset_min_bars=1,
    calendar=None,
    timezone="UTC",
    data_frequency=DAILY,
    enforce_sessions=False,
    preset_name=None,
    feed_spec=None,
    metadata=dict(),
)

Complete configuration for backtesting behavior.

All behavioral differences between frameworks are captured here. Load presets to match specific frameworks exactly.

This is the single source of truth for all backtest settings. Broker and Engine are configured entirely from this dataclass.

from_preset classmethod

from_preset(preset)

Load a predefined configuration preset.

Available presets:
- "default": Sensible defaults for general use
- "backtrader": Match Backtrader's default behavior
- "vectorbt": Match VectorBT's default behavior
- "zipline": Match Zipline's default behavior
- "lean": Match QuantConnect LEAN's default behavior
- "realistic": Conservative settings for realistic simulation

Source code in src/ml4t/backtest/config.py
@classmethod
def from_preset(cls, preset: str) -> BacktestConfig:
    """
    Build a config from one of the bundled named profiles.

    Available presets:
    - "default": Sensible defaults for general use
    - "backtrader": Match Backtrader's default behavior
    - "vectorbt": Match VectorBT's default behavior
    - "zipline": Match Zipline's default behavior
    - "lean": Match QuantConnect LEAN's default behavior
    - "realistic": Conservative settings for realistic simulation

    The profile dictionary is looked up via ``get_profile_config`` and passed
    through ``from_dict`` with ``strict=True``. The preset name is stored as
    ``preset_name``.
    """
    from .profiles import get_profile_config

    return cls.from_dict(get_profile_config(preset), preset_name=preset, strict=True)

from_yaml classmethod

from_yaml(path)

Load config from YAML file.

Source code in src/ml4t/backtest/config.py
@classmethod
def from_yaml(cls, path: str | Path) -> BacktestConfig:
    """Load config from a YAML file.

    The file is opened in binary mode, so PyYAML detects the encoding itself:
    UTF-8, or UTF-16 when a BOM is present. Decoding therefore no longer
    depends on the platform's locale encoding. The document is parsed with
    ``yaml.safe_load``, so no arbitrary Python objects are constructed. It is
    then validated strictly by ``from_dict``, and the file stem becomes the
    config's ``preset_name``.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed BacktestConfig.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the document is not a mapping (e.g. an empty file).
        ValueError: If it contains unknown sections or keys.
    """
    path = Path(path)
    with path.open("rb") as f:
        data = yaml.safe_load(f)
    return cls.from_dict(data, preset_name=path.stem, strict=True)

from_dict classmethod

from_dict(data, preset_name=None, strict=True)

Create config from dictionary.

Parameters:

Name Type Description Default
data dict

Nested config dictionary

required
preset_name str | None

Optional metadata label

None
strict bool

If True, reject unknown sections/keys

True
Source code in src/ml4t/backtest/config.py
@classmethod
def from_dict(
    cls, data: dict, preset_name: str | None = None, strict: bool = True
) -> BacktestConfig:
    """Create config from dictionary.

    Args:
        data: Nested config dictionary
        preset_name: Optional metadata label
        strict: If True, reject unknown sections/keys
    """
    if not isinstance(data, dict):
        raise TypeError(f"Config data must be a dict, got {type(data).__name__}")

    if strict:
        allowed_sections = {
            "account",
            "execution",
            "stops",
            "position_sizing",
            "commission",
            "slippage",
            "cash",
            "settlement",
            "orders",
            "calendar",
            "feed",
            "metadata",
        }
        unknown_sections = set(data) - allowed_sections
        if unknown_sections:
            raise ValueError(f"Unknown config section(s): {sorted(unknown_sections)}")

        allowed_keys_by_section = {
            "account": {
                "allow_short_selling",
                "allow_leverage",
                "initial_margin",
                "long_maintenance_margin",
                "short_maintenance_margin",
                "fixed_margin_schedule",
                "short_cash_policy",
            },
            "execution": {"execution_price", "mark_price", "execution_mode"},
            "stops": {
                "stop_fill_mode",
                "stop_level_basis",
                "trail_hwm_source",
                "initial_hwm_source",
                "trail_stop_timing",
            },
            "position_sizing": {"share_type"},
            "commission": {"model", "rate", "per_share", "per_trade", "minimum"},
            "slippage": {"model", "rate", "fixed", "stop_rate"},
            "cash": {"initial", "buffer_pct"},
            "settlement": {"delay", "reduces_buying_power"},
            "orders": {
                "reject_on_insufficient_cash",
                "skip_cash_validation",
                "partial_fills_allowed",
                "fill_ordering",
                "entry_order_priority",
                "next_bar_submission_precheck",
                "next_bar_simple_cash_check",
                "buying_power_reservation",
                "next_bar_queue_shadow_validation",
                "immediate_fill",
                "rebalance_mode",
                "rebalance_headroom_pct",
                "missing_price_policy",
                "late_asset_policy",
                "late_asset_min_bars",
            },
            "calendar": {
                "calendar",
                "timezone",
                "data_frequency",
                "enforce_sessions",
            },
            "feed": {
                "timestamp_col",
                "entity_col",
                "price_col",
                "open_col",
                "high_col",
                "low_col",
                "close_col",
                "volume_col",
                "bid_col",
                "ask_col",
                "mid_col",
                "bid_size_col",
                "ask_size_col",
                "calendar",
                "timezone",
                "data_frequency",
                "bar_type",
                "timestamp_semantics",
                "session_start_time",
            },
        }
        for section, cfg in data.items():
            if section == "metadata":
                if not isinstance(cfg, dict):
                    raise TypeError(
                        f"Section 'metadata' must be a dict, got {type(cfg).__name__}"
                    )
                continue
            if not isinstance(cfg, dict):
                raise TypeError(f"Section '{section}' must be a dict, got {type(cfg).__name__}")
            unknown_keys = set(cfg) - allowed_keys_by_section[section]
            if unknown_keys:
                raise ValueError(
                    f"Unknown key(s) in section '{section}': {sorted(unknown_keys)}"
                )

    acct_cfg = data.get("account", {})
    exec_cfg = data.get("execution", {})
    stops_cfg = data.get("stops", {})
    sizing_cfg = data.get("position_sizing", {})
    comm_cfg = data.get("commission", {})
    slip_cfg = data.get("slippage", {})
    cash_cfg = data.get("cash", {})
    settle_cfg = data.get("settlement", {})
    order_cfg = data.get("orders", {})
    cal_cfg = data.get("calendar", {})
    feed_cfg = data.get("feed", {})
    metadata = data.get("metadata", {})

    if metadata is None:
        metadata = {}
    if not isinstance(metadata, dict):
        raise TypeError(f"Section 'metadata' must be a dict, got {type(metadata).__name__}")

    allow_short_selling = acct_cfg.get("allow_short_selling", False)
    allow_leverage = acct_cfg.get("allow_leverage", False)

    return cls(
        # Account
        allow_short_selling=allow_short_selling,
        allow_leverage=allow_leverage,
        initial_margin=acct_cfg.get("initial_margin", 0.5),
        long_maintenance_margin=acct_cfg.get("long_maintenance_margin", 0.25),
        short_maintenance_margin=acct_cfg.get("short_maintenance_margin", 0.30),
        fixed_margin_schedule=acct_cfg.get("fixed_margin_schedule"),
        short_cash_policy=ShortCashPolicy(acct_cfg.get("short_cash_policy", "credit")),
        # Execution
        execution_price=ExecutionPrice(exec_cfg.get("execution_price", "open")),
        mark_price=ExecutionPrice(exec_cfg.get("mark_price", "price")),
        execution_mode=ExecutionMode(exec_cfg.get("execution_mode", "next_bar")),
        # Stops
        stop_fill_mode=StopFillMode(stops_cfg.get("stop_fill_mode", "stop_price")),
        stop_level_basis=StopLevelBasis(stops_cfg.get("stop_level_basis", "fill_price")),
        trail_hwm_source=WaterMarkSource(stops_cfg.get("trail_hwm_source", "close")),
        initial_hwm_source=InitialHwmSource(stops_cfg.get("initial_hwm_source", "fill_price")),
        trail_stop_timing=TrailStopTiming(stops_cfg.get("trail_stop_timing", "lagged")),
        # Sizing
        share_type=ShareType(sizing_cfg.get("share_type", "fractional")),
        # Commission
        commission_type=CommissionType(comm_cfg.get("model", "percentage")),
        commission_rate=comm_cfg.get("rate", 0.001),
        commission_per_share=comm_cfg.get("per_share", 0.0),
        commission_per_trade=comm_cfg.get("per_trade", 0.0),
        commission_minimum=comm_cfg.get("minimum", 0.0),
        # Slippage
        slippage_type=SlippageType(slip_cfg.get("model", "percentage")),
        slippage_rate=slip_cfg.get("rate", 0.001),
        slippage_fixed=slip_cfg.get("fixed", 0.0),
        stop_slippage_rate=slip_cfg.get("stop_rate", 0.0),
        # Cash
        initial_cash=cash_cfg.get("initial", 100000.0),
        cash_buffer_pct=cash_cfg.get("buffer_pct", 0.0),
        # Settlement
        settlement_delay=settle_cfg.get("delay", 0),
        settlement_reduces_buying_power=settle_cfg.get("reduces_buying_power", True),
        # Orders
        reject_on_insufficient_cash=order_cfg.get("reject_on_insufficient_cash", True),
        skip_cash_validation=order_cfg.get("skip_cash_validation", False),
        partial_fills_allowed=order_cfg.get("partial_fills_allowed", False),
        fill_ordering=FillOrdering(order_cfg.get("fill_ordering", "exit_first")),
        entry_order_priority=EntryOrderPriority(
            order_cfg.get("entry_order_priority", "submission")
        ),
        next_bar_submission_precheck=order_cfg.get("next_bar_submission_precheck", False),
        next_bar_simple_cash_check=order_cfg.get("next_bar_simple_cash_check", False),
        buying_power_reservation=order_cfg.get("buying_power_reservation", False),
        next_bar_queue_shadow_validation=order_cfg.get(
            "next_bar_queue_shadow_validation", False
        ),
        immediate_fill=order_cfg.get("immediate_fill", False),
        rebalance_mode=RebalanceMode(order_cfg.get("rebalance_mode", "incremental")),
        rebalance_headroom_pct=order_cfg.get("rebalance_headroom_pct", 1.0),
        missing_price_policy=MissingPricePolicy(order_cfg.get("missing_price_policy", "skip")),
        late_asset_policy=LateAssetPolicy(order_cfg.get("late_asset_policy", "allow")),
        late_asset_min_bars=order_cfg.get("late_asset_min_bars", 1),
        # Calendar
        calendar=cal_cfg.get("calendar"),
        timezone=cal_cfg.get("timezone", "UTC"),
        data_frequency=DataFrequency(cal_cfg.get("data_frequency", "daily")),
        enforce_sessions=cal_cfg.get("enforce_sessions", False),
        # Metadata
        preset_name=preset_name,
        feed_spec=FeedSpec.from_any(feed_cfg) if feed_cfg else None,
        metadata=dict(metadata),
    )

to_yaml

to_yaml(path)

Save config to YAML file.

Source code in src/ml4t/backtest/config.py
def to_yaml(self, path: str | Path) -> None:
    """Save config to YAML file.

    Serializes ``self.to_dict()`` in block style (``default_flow_style=False``)
    and keeps the section order of that dict (``sort_keys=False``).

    Args:
        path: Destination file path (``str`` or ``Path``). An existing file
            at that location is overwritten.
    """
    path = Path(path)
    # Pin the encoding so the written file does not depend on the platform's
    # locale default (PEP 597); UTF-8 is also what YAML readers expect.
    with path.open("w", encoding="utf-8") as f:
        yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)

to_dict

to_dict()

Convert config to dictionary for serialization.

Source code in src/ml4t/backtest/config.py
def to_dict(self) -> dict:
    """Serialize the configuration into a nested, section-keyed dictionary.

    The section and key names are the ones the dict-based loader accepts
    ("account", "execution", "stops", "position_sizing", "commission",
    "slippage", "cash", "settlement", "orders", "calendar", "feed",
    "metadata"). Enum-valued settings are stored by their ``.value``.

    Returns:
        A plain dict; ``to_yaml`` dumps it as-is.
    """
    account = {
        "allow_short_selling": self.allow_short_selling,
        "allow_leverage": self.allow_leverage,
        "initial_margin": self.initial_margin,
        "long_maintenance_margin": self.long_maintenance_margin,
        "short_maintenance_margin": self.short_maintenance_margin,
        "fixed_margin_schedule": self.fixed_margin_schedule,
        "short_cash_policy": self.short_cash_policy.value,
    }
    execution = {
        "execution_price": self.execution_price.value,
        "mark_price": self.mark_price.value,
        "execution_mode": self.execution_mode.value,
    }
    stops = {
        "stop_fill_mode": self.stop_fill_mode.value,
        "stop_level_basis": self.stop_level_basis.value,
        "trail_hwm_source": self.trail_hwm_source.value,
        "initial_hwm_source": self.initial_hwm_source.value,
        "trail_stop_timing": self.trail_stop_timing.value,
    }
    sizing = {"share_type": self.share_type.value}
    commission = {
        "model": self.commission_type.value,
        "rate": self.commission_rate,
        "per_share": self.commission_per_share,
        "per_trade": self.commission_per_trade,
        "minimum": self.commission_minimum,
    }
    slippage = {
        "model": self.slippage_type.value,
        "rate": self.slippage_rate,
        "fixed": self.slippage_fixed,
        "stop_rate": self.stop_slippage_rate,
    }
    cash = {
        "initial": self.initial_cash,
        "buffer_pct": self.cash_buffer_pct,
    }
    settlement = {
        "delay": self.settlement_delay,
        "reduces_buying_power": self.settlement_reduces_buying_power,
    }
    orders = {
        "reject_on_insufficient_cash": self.reject_on_insufficient_cash,
        "skip_cash_validation": self.skip_cash_validation,
        "partial_fills_allowed": self.partial_fills_allowed,
        "fill_ordering": self.fill_ordering.value,
        "entry_order_priority": self.entry_order_priority.value,
        "next_bar_submission_precheck": self.next_bar_submission_precheck,
        "next_bar_simple_cash_check": self.next_bar_simple_cash_check,
        "buying_power_reservation": self.buying_power_reservation,
        "next_bar_queue_shadow_validation": self.next_bar_queue_shadow_validation,
        "immediate_fill": self.immediate_fill,
        "rebalance_mode": self.rebalance_mode.value,
        "rebalance_headroom_pct": self.rebalance_headroom_pct,
        "missing_price_policy": self.missing_price_policy.value,
        "late_asset_policy": self.late_asset_policy.value,
        "late_asset_min_bars": self.late_asset_min_bars,
    }
    calendar_section = {
        "calendar": self.calendar,
        "timezone": self.timezone,
        "data_frequency": self.data_frequency.value,
        "enforce_sessions": self.enforce_sessions,
    }

    serialized: dict = {
        "account": account,
        "execution": execution,
        "stops": stops,
        "position_sizing": sizing,
        "commission": commission,
        "slippage": slippage,
        "cash": cash,
        "settlement": settlement,
        "orders": orders,
        "calendar": calendar_section,
    }
    # Feed spec and metadata go through dedicated converters instead of
    # being copied verbatim; they are appended last to keep the key order.
    serialized["feed"] = _feed_spec_to_dict(self.resolved_feed_spec)
    serialized["metadata"] = serialize_artifact_value(self.metadata)
    return serialized

validate

validate(warn=True)

Validate configuration and return warnings for edge cases.

Checks for configurations that may produce unexpected results or indicate potential issues. Returns a list of warning messages.

Parameters:

Name Type Description Default
warn bool

If True, emit warnings via warnings.warn(). Default True.

True

Returns:

Type Description
list[str]

List of warning message strings (empty if no issues found).

Example

config = BacktestConfig(execution_mode=ExecutionMode.SAME_BAR)
warnings = config.validate()
# ["SAME_BAR execution has look-ahead bias risk..."]

Source code in src/ml4t/backtest/config.py
def validate(self, warn: bool = True) -> list[str]:
    """Validate configuration and return warnings for edge cases.

    Checks for configurations that may produce unexpected results or
    indicate potential issues. Returns a list of warning messages.

    Args:
        warn: If True, also emit each message via warnings.warn() as a
            UserWarning. Default True.

    Returns:
        List of warning message strings (empty if no issues found).

    Example:
        config = BacktestConfig(execution_mode=ExecutionMode.SAME_BAR)
        issues = config.validate()
        # ["SAME_BAR execution has look-ahead bias risk..."]
    """
    import warnings as _warnings

    issues: list[str] = []

    # Look-ahead bias warning
    if self.execution_mode == ExecutionMode.SAME_BAR:
        issues.append(
            "SAME_BAR execution has look-ahead bias risk. "
            "Use NEXT_BAR execution mode for realistic backtesting."
        )

    # Zero cost warning
    if self.commission_type == CommissionType.NONE and self.slippage_type == SlippageType.NONE:
        issues.append(
            "Both commission and slippage are disabled. Results may be overly optimistic."
        )

    # Volume-based slippage without partial fills
    if self.slippage_type == SlippageType.VOLUME_BASED and not self.partial_fills_allowed:
        issues.append(
            "Volume-based slippage without partial_fills_allowed may cause "
            "orders to be rejected in low-volume conditions."
        )

    # High transaction cost. Only the rates of enabled cost models count: a
    # commission_rate / slippage_rate left set (e.g. at its default) while the
    # corresponding model is NONE describes no actual cost, so it must not
    # inflate the reported total or trigger this warning on its own.
    commission_cost = (
        0.0 if self.commission_type == CommissionType.NONE else self.commission_rate
    )
    slippage_cost = 0.0 if self.slippage_type == SlippageType.NONE else self.slippage_rate
    total_cost = commission_cost + slippage_cost
    if total_cost > 0.01:  # > 1% for a single fill (one side, not a round trip)
        issues.append(
            f"Total transaction cost ({total_cost:.2%}) is high. "
            "Verify this matches your broker's actual costs."
        )

    # Fractional shares warning for production
    if self.share_type == ShareType.FRACTIONAL and self.preset_name == "realistic":
        issues.append(
            "REALISTIC preset with fractional shares may not match all brokers. "
            "Set share_type=INTEGER for most accurate simulation."
        )

    # Margin parameter validation (only meaningful when leverage is enabled)
    if self.allow_leverage:
        if not 0.0 < self.initial_margin <= 1.0:
            issues.append(f"initial_margin ({self.initial_margin}) must be in (0.0, 1.0]")
        if not 0.0 < self.long_maintenance_margin <= 1.0:
            issues.append(
                f"long_maintenance_margin ({self.long_maintenance_margin}) must be in (0.0, 1.0]"
            )
        if not 0.0 < self.short_maintenance_margin <= 1.0:
            issues.append(
                f"short_maintenance_margin ({self.short_maintenance_margin}) must be in (0.0, 1.0]"
            )
        if self.long_maintenance_margin >= self.initial_margin:
            issues.append(
                f"long_maintenance_margin ({self.long_maintenance_margin}) must be < "
                f"initial_margin ({self.initial_margin})"
            )
        if self.short_maintenance_margin >= self.initial_margin:
            issues.append(
                f"short_maintenance_margin ({self.short_maintenance_margin}) must be < "
                f"initial_margin ({self.initial_margin})"
            )

    if self.settlement_delay < 0 or self.settlement_delay > 5:
        issues.append(
            f"settlement_delay ({self.settlement_delay}) should be 0-5. "
            "Common values: 0 (instant), 1 (T+1), 2 (T+2 US equities)."
        )

    if not 0.0 < self.rebalance_headroom_pct <= 1.0:
        issues.append(
            f"rebalance_headroom_pct ({self.rebalance_headroom_pct}) must be in (0.0, 1.0]"
        )
    if self.late_asset_min_bars < 1:
        issues.append(f"late_asset_min_bars ({self.late_asset_min_bars}) must be >= 1")

    # Emit warnings if requested; stacklevel=2 attributes them to the caller.
    if warn and issues:
        for msg in issues:
            _warnings.warn(msg, UserWarning, stacklevel=2)

    return issues

describe

describe()

Return human-readable description of configuration.

Source code in src/ml4t/backtest/config.py
def describe(self) -> str:
    """Build a human-readable, multi-line summary of this configuration.

    The summary opens with a title naming the preset ("custom" when no
    preset name is set) and an underline, followed by labelled sections.
    Margin rates are listed only when leverage is effectively enabled, the
    stop-slippage row only when ``stop_slippage_rate`` is positive, and the
    settlement section only when ``settlement_delay`` is positive.

    Returns:
        The summary lines joined with newlines.
    """
    shorts_ok, leverage_on = self.get_effective_account_settings()
    account_type = self.get_effective_account_type()
    title = f"BacktestConfig (preset: {self.preset_name or 'custom'})"

    account_rows = [
        f"  Type: {account_type}",
        f"  Short selling: {'allowed' if shorts_ok else 'disabled'}",
        f"  Leverage: {'enabled' if leverage_on else 'disabled'}",
    ]
    if leverage_on:
        account_rows += [
            f"  Initial margin: {self.initial_margin:.0%}",
            f"  Long maintenance: {self.long_maintenance_margin:.0%}",
            f"  Short maintenance: {self.short_maintenance_margin:.0%}",
        ]

    # Each entry is (heading, rows); rendered below with a blank separator.
    sections: list[tuple[str, list[str]]] = [
        ("Account:", account_rows),
        (
            "Execution:",
            [
                f"  Execution mode: {self.execution_mode.value}",
                f"  Execution price: {self.execution_price.value}",
                f"  Mark price: {self.mark_price.value}",
            ],
        ),
        (
            "Stops:",
            [
                f"  Fill mode: {self.stop_fill_mode.value}",
                f"  Level basis: {self.stop_level_basis.value}",
                f"  Trail HWM source: {self.trail_hwm_source.value}",
                f"  Trail timing: {self.trail_stop_timing.value}",
            ],
        ),
        ("Position Sizing:", [f"  Share type: {self.share_type.value}"]),
    ]

    cost_rows = [
        f"  Commission: {self.commission_type.value} @ {self.commission_rate:.2%}",
        f"  Slippage: {self.slippage_type.value} @ {self.slippage_rate:.2%}",
    ]
    if self.stop_slippage_rate > 0:
        cost_rows.append(f"  Stop slippage: +{self.stop_slippage_rate:.2%}")
    sections.append(("Costs:", cost_rows))

    sections.append(
        (
            "Orders:",
            [
                f"  Fill ordering: {self.fill_ordering.value}",
                f"  Entry priority: {self.entry_order_priority.value}",
                f"  Next-bar precheck: {self.next_bar_submission_precheck}",
                f"  Next-bar cash check: {self.next_bar_simple_cash_check}",
                f"  Next-bar queue shadow validation: {self.next_bar_queue_shadow_validation}",
                f"  Rebalance mode: {self.rebalance_mode.value}",
                f"  Rebalance headroom: {self.rebalance_headroom_pct:.3f}",
                f"  Missing price policy: {self.missing_price_policy.value}",
                f"  Late asset policy: {self.late_asset_policy.value}",
                f"  Late asset min bars: {self.late_asset_min_bars}",
                f"  Reject insufficient: {self.reject_on_insufficient_cash}",
                f"  Skip cash validation: {self.skip_cash_validation}",
                f"  Partial fills: {self.partial_fills_allowed}",
            ],
        )
    )
    sections.append(
        (
            "Cash:",
            [
                f"  Initial: ${self.initial_cash:,.0f}",
                f"  Buffer: {self.cash_buffer_pct:.1%}",
            ],
        )
    )

    if self.settlement_delay > 0:
        sections.append(("Settlement:", [f"  Delay: T+{self.settlement_delay}"]))

    rendered = [title, "=" * 50]
    for heading, rows in sections:
        rendered.extend(["", heading, *rows])
    return "\n".join(rendered)

profiles

Centralized profile definitions for framework-aligned behavior.

get_profile_config

get_profile_config(name)

Return a deep copy of nested config data for the named profile.

Source code in src/ml4t/backtest/profiles.py
def get_profile_config(name: str) -> dict:
    """Look up a profile by canonical name or alias and return its config.

    Args:
        name: Canonical profile name, or an alias listed in ``_ALIASES``.

    Returns:
        A deep copy of the profile's nested config dict, so callers may
        mutate it without affecting the shared definition.

    Raises:
        ValueError: If ``name`` resolves to no known profile; the message
            lists the available profile names in sorted order.
    """
    canonical = _ALIASES.get(name, name)
    if canonical in _PROFILES:
        return deepcopy(_PROFILES[canonical])

    choices = ", ".join(sorted(_PROFILES.keys()))
    raise ValueError(f"Unknown preset '{name}'. Available: {choices}")

list_profiles

list_profiles()

List canonical preset names.

Source code in src/ml4t/backtest/profiles.py
def list_profiles() -> list[str]:
    """Return the canonical preset names as a fresh list the caller may modify."""
    names: list[str] = list(_CORE_PROFILE_NAMES)
    return names

Broker

Broker

Broker(
    initial_cash=100000.0,
    commission_model=None,
    slippage_model=None,
    stop_slippage_rate=0.0,
    execution_mode=SAME_BAR,
    execution_price=CLOSE,
    mark_price=PRICE,
    stop_fill_mode=STOP_PRICE,
    stop_level_basis=FILL_PRICE,
    trail_hwm_source=CLOSE,
    initial_hwm_source=FILL_PRICE,
    trail_stop_timing=LAGGED,
    allow_short_selling=False,
    allow_leverage=False,
    initial_margin=0.5,
    long_maintenance_margin=0.25,
    short_maintenance_margin=0.3,
    fixed_margin_schedule=None,
    short_cash_policy=CREDIT,
    execution_limits=None,
    market_impact_model=None,
    contract_specs=None,
    share_type=FRACTIONAL,
    fill_ordering=EXIT_FIRST,
    entry_order_priority=SUBMISSION,
    next_bar_submission_precheck=False,
    next_bar_simple_cash_check=False,
    buying_power_reservation=False,
    next_bar_queue_shadow_validation=False,
    immediate_fill=False,
    reject_on_insufficient_cash=True,
    skip_cash_validation=False,
    cash_buffer_pct=0.0,
    partial_fills_allowed=False,
    rebalance_headroom_pct=1.0,
    missing_price_policy=SKIP,
    late_asset_policy=ALLOW,
    late_asset_min_bars=1,
    settlement_delay=0,
    settlement_reduces_buying_power=True,
)

Broker interface - same for backtest and live trading.

Source code in src/ml4t/backtest/broker.py
def __init__(
    self,
    initial_cash: float = 100000.0,
    commission_model: CommissionModel | None = None,
    slippage_model: SlippageModel | None = None,
    stop_slippage_rate: float = 0.0,
    execution_mode: ExecutionMode = ExecutionMode.SAME_BAR,
    execution_price: ExecutionPrice = ExecutionPrice.CLOSE,
    mark_price: ExecutionPrice = ExecutionPrice.PRICE,
    stop_fill_mode: StopFillMode = StopFillMode.STOP_PRICE,
    stop_level_basis: StopLevelBasis = StopLevelBasis.FILL_PRICE,
    trail_hwm_source: WaterMarkSource = WaterMarkSource.CLOSE,
    initial_hwm_source: InitialHwmSource = InitialHwmSource.FILL_PRICE,
    trail_stop_timing: TrailStopTiming = TrailStopTiming.LAGGED,
    allow_short_selling: bool = False,
    allow_leverage: bool = False,
    initial_margin: float = 0.5,
    long_maintenance_margin: float = 0.25,
    short_maintenance_margin: float = 0.30,
    fixed_margin_schedule: dict[str, tuple[float, float]] | None = None,
    short_cash_policy: ShortCashPolicy = ShortCashPolicy.CREDIT,
    execution_limits: ExecutionLimits | None = None,
    market_impact_model: MarketImpactModel | None = None,
    contract_specs: dict[str, ContractSpec] | None = None,
    share_type: ShareType = ShareType.FRACTIONAL,
    fill_ordering: FillOrdering = FillOrdering.EXIT_FIRST,
    entry_order_priority: EntryOrderPriority = EntryOrderPriority.SUBMISSION,
    next_bar_submission_precheck: bool = False,
    next_bar_simple_cash_check: bool = False,
    buying_power_reservation: bool = False,
    next_bar_queue_shadow_validation: bool = False,
    immediate_fill: bool = False,
    reject_on_insufficient_cash: bool = True,
    skip_cash_validation: bool = False,
    cash_buffer_pct: float = 0.0,
    partial_fills_allowed: bool = False,
    rebalance_headroom_pct: float = 1.0,
    missing_price_policy: MissingPricePolicy = MissingPricePolicy.SKIP,
    late_asset_policy: LateAssetPolicy = LateAssetPolicy.ALLOW,
    late_asset_min_bars: int = 1,
    settlement_delay: int = 0,
    settlement_reduces_buying_power: bool = True,
):
    """Initialize broker settings, account state, validation and execution components.

    Unless noted below, each argument is stored unchanged on a same-named
    attribute.

    Args:
        initial_cash: Starting cash; stored and used to create ``AccountState``.
        commission_model: Commission model; ``NoCommission()`` when None.
        slippage_model: Slippage model; ``NoSlippage()`` when None.
        allow_short_selling, allow_leverage, initial_margin,
        long_maintenance_margin, short_maintenance_margin, short_cash_policy:
            Forwarded to ``UnifiedAccountPolicy`` (``short_cash_policy`` by its
            ``.value``) and also stored. The two flags additionally determine
            the legacy ``account_type`` string ("margin", "crypto" or "cash").
        fixed_margin_schedule: Per-symbol ``(initial, maintenance)`` margins.
            The account policy receives this merged with margins derived from
            ``contract_specs``; ``self.fixed_margin_schedule`` keeps only the
            caller-supplied schedule (``{}`` when None).
        contract_specs: Per-symbol contract specifications, stored on
            ``self._contract_specs``. For each symbol whose ``spec.margin`` is
            set and that is absent from ``fixed_margin_schedule``,
            ``(spec.margin, spec.margin * 0.5)`` is added to the policy's
            margin schedule.
        execution_limits, market_impact_model: Stored as-is.
        cash_buffer_pct, settlement_reduces_buying_power: Stored and also
            passed to the ``Gatekeeper``.
    """
    # Runtime imports for accounting classes.
    # These are imported here rather than at module level because:
    # 1. The package __init__.py imports Broker, creating a potential import order issue
    # 2. TYPE_CHECKING block above provides type hints for static analysis
    # 3. This pattern allows mypy/pyright to validate types without runtime circular import
    from .accounting import (
        AccountState,
        Gatekeeper,
        UnifiedAccountPolicy,
    )

    self.initial_cash = initial_cash
    # Note: self.cash is now a property delegating to self.account.cash (Bug #5 fix)
    self.commission_model = commission_model or NoCommission()
    self.slippage_model = slippage_model or NoSlippage()
    self.stop_slippage_rate = stop_slippage_rate
    self.execution_mode = execution_mode
    self.execution_price = execution_price
    self.mark_price = mark_price
    self.stop_fill_mode = stop_fill_mode
    self.stop_level_basis = stop_level_basis
    self.trail_hwm_source = trail_hwm_source
    self.initial_hwm_source = initial_hwm_source
    self.trail_stop_timing = trail_stop_timing
    self.share_type = share_type
    self.fill_ordering = fill_ordering
    self.entry_order_priority = entry_order_priority
    self.next_bar_submission_precheck = next_bar_submission_precheck
    self.next_bar_simple_cash_check = next_bar_simple_cash_check
    self.buying_power_reservation = buying_power_reservation
    self.next_bar_queue_shadow_validation = next_bar_queue_shadow_validation
    self.immediate_fill = immediate_fill
    self.reject_on_insufficient_cash = reject_on_insufficient_cash
    self.skip_cash_validation = skip_cash_validation
    self.cash_buffer_pct = cash_buffer_pct
    self.partial_fills_allowed = partial_fills_allowed
    self.rebalance_headroom_pct = rebalance_headroom_pct
    self.missing_price_policy = missing_price_policy
    self.late_asset_policy = late_asset_policy
    self.late_asset_min_bars = late_asset_min_bars
    self.settlement_delay = settlement_delay
    self.settlement_reduces_buying_power = settlement_reduces_buying_power
    self._bar_index: int = 0

    # Auto-populate fixed_margin_schedule from ContractSpec.margin
    # This lets users specify margin once on ContractSpec rather than duplicating
    # it in both ContractSpec and BacktestConfig.fixed_margin_schedule.
    # Copy first so the caller's dict is not mutated; explicit schedule entries
    # win over ContractSpec-derived ones.
    effective_margin_schedule = dict(fixed_margin_schedule or {})
    if contract_specs:
        for symbol, spec in contract_specs.items():
            if spec.margin is not None and symbol not in effective_margin_schedule:
                # Use spec.margin as initial margin, 50% of it as maintenance
                effective_margin_schedule[symbol] = (spec.margin, spec.margin * 0.5)

    # Create AccountState with UnifiedAccountPolicy
    policy: AccountPolicy = UnifiedAccountPolicy(
        allow_short_selling=allow_short_selling,
        allow_leverage=allow_leverage,
        initial_margin=initial_margin,
        long_maintenance_margin=long_maintenance_margin,
        short_maintenance_margin=short_maintenance_margin,
        # An empty merged schedule is passed as None rather than {}.
        fixed_margin_schedule=effective_margin_schedule or None,
        short_cash_policy=short_cash_policy.value,
    )

    self.account = AccountState(initial_cash=initial_cash, policy=policy)
    # Derive account_type string from flags for backward compat
    # (leverage takes precedence over short selling).
    if allow_leverage:
        self.account_type = "margin"
    elif allow_short_selling:
        self.account_type = "crypto"
    else:
        self.account_type = "cash"
    self.allow_short_selling = allow_short_selling
    self.allow_leverage = allow_leverage
    self.initial_margin = initial_margin
    self.long_maintenance_margin = long_maintenance_margin
    self.short_maintenance_margin = short_maintenance_margin
    # NOTE(review): this stores only the caller-supplied schedule (and aliases
    # the caller's dict); margins derived from contract_specs exist only in the
    # account policy above — confirm readers of this attribute expect that.
    self.fixed_margin_schedule = fixed_margin_schedule or {}
    self.short_cash_policy = short_cash_policy

    # Create Gatekeeper for order validation
    self.gatekeeper = Gatekeeper(
        self.account,
        self.commission_model,
        cash_buffer_pct=self.cash_buffer_pct,
        settlement_reduces_buying_power=self.settlement_reduces_buying_power,
    )

    self.positions: dict[str, Position] = {}
    self.orders: list[Order] = []
    self.pending_orders: list[Order] = []
    self.fills: list[Fill] = []
    self.trades: list[Trade] = []
    self._order_counter = 0
    self._current_time: datetime | None = None
    self._current_prices: dict[str, float] = {}  # FeedSpec.price_col values
    self._current_opens: dict[str, float] = {}  # open prices for next-bar execution
    self._current_highs: dict[str, float] = {}  # high prices for limit/stop checks
    self._current_lows: dict[str, float] = {}  # low prices for limit/stop checks
    self._current_closes: dict[str, float] = {}
    self._current_volumes: dict[str, float] = {}
    self._current_bids: dict[str, float] = {}
    self._current_asks: dict[str, float] = {}
    self._current_mids: dict[str, float] = {}
    self._current_bid_sizes: dict[str, float] = {}
    self._current_ask_sizes: dict[str, float] = {}
    self._current_signals: dict[str, dict[str, float]] = {}
    self._last_prices: dict[str, float] = {}
    self._asset_bars_seen: dict[str, int] = {}
    self._rebalance_counter = 0
    self._orders_this_bar: list[Order] = []  # Orders placed this bar (for next-bar mode)
    self._orders_this_bar_ids: set[str] = set()

    # Risk management
    self._position_rules: Any = None  # Global position rules
    self._position_rules_by_asset: dict[str, Any] = {}  # Per-asset rules
    self._pending_exits: dict[str, dict] = {}  # asset -> {reason, pct} for NEXT_BAR_OPEN mode

    # Execution model (volume limits and market impact)
    self.execution_limits = execution_limits  # ExecutionLimits instance
    self.market_impact_model = market_impact_model  # MarketImpactModel instance
    self._partial_orders: dict[str, float] = {}  # order_id -> remaining quantity
    self._filled_this_bar: set[str] = set()  # order_ids that had fills this bar

    # VBT Pro compatibility: prevent same-bar re-entry after stop exit
    self._stop_exits_this_bar: set[str] = set()  # assets that had stop exits this bar

    # VBT Pro compatibility: track positions created this bar
    # New positions should NOT have HWM updated from entry bar's high
    # VBT Pro uses CLOSE for initial HWM on entry bar, then updates from HIGH next bar
    self._positions_created_this_bar: set[str] = set()

    # Contract specifications (for futures and other derivatives)
    self._contract_specs: dict[str, ContractSpec] = contract_specs or {}

    # Fill execution (extracted from _execute_fill)
    self._fill_executor = FillExecutor(self)

    # Per-asset trading statistics for stateful decision-making
    self._asset_stats: dict[str, AssetTradingStats] = {}
    self._stats_config = StatsConfig()
    self._session_config = None  # Optional SessionConfig for session boundary detection
    self._last_session_id: int | None = None  # Track current session for boundary detection

    # Extracted orchestration components (Phase B1 alpha-reset)
    # These receive ``self`` and are constructed last, after every attribute
    # above exists. NOTE(review): they may read broker state in their
    # constructors — keep this block at the end of __init__.
    self._order_book = OrderBook(self)
    self._risk_engine = RiskEngine(self)
    self._fill_engine = FillEngine(self)
    self._execution_engine = ExecutionEngine(self)
    self._portfolio_ledger = PortfolioLedger(self)

submit_order

submit_order(
    asset,
    quantity,
    side=None,
    order_type=MARKET,
    limit_price=None,
    stop_price=None,
    trail_amount=None,
    _options=None,
)

Submit a new order to the broker.

Creates and queues an order for execution. Orders are validated by the Gatekeeper before fills to ensure account constraints are met.

Parameters:

Name Type Description Default
asset str

Asset symbol (e.g., "AAPL", "BTC-USD")

required
quantity float

Number of shares/units. Positive = buy, negative = sell (if side is not specified)

required
side OrderSide | None

OrderSide.BUY or OrderSide.SELL. If None, inferred from quantity sign

None
order_type OrderType

Order type (MARKET, LIMIT, STOP, TRAILING_STOP)

MARKET
limit_price float | None

Limit price for LIMIT orders

None
stop_price float | None

Stop/trigger price for STOP orders

None
trail_amount float | None

Trail distance for TRAILING_STOP orders

None

Returns:

Type Description
Order | None

Order object if submitted successfully, None if rejected (e.g., same-bar re-entry after stop exit in VBT Pro mode)

Examples:

Market buy

order = broker.submit_order("AAPL", 100)

Market sell (using negative quantity)

order = broker.submit_order("AAPL", -100)

Limit buy

order = broker.submit_order("AAPL", 100, order_type=OrderType.LIMIT, limit_price=150.0)

Stop sell (stop-loss)

order = broker.submit_order("AAPL", -100, order_type=OrderType.STOP, stop_price=145.0)

Source code in src/ml4t/backtest/broker.py
def submit_order(
    self,
    asset: str,
    quantity: float,
    side: OrderSide | None = None,
    order_type: OrderType = OrderType.MARKET,
    limit_price: float | None = None,
    stop_price: float | None = None,
    trail_amount: float | None = None,
    _options: SubmitOrderOptions | None = None,
) -> Order | None:
    """Submit a new order to the broker.

    Thin facade over the broker's order book: all validation, creation and
    queuing happen in ``self._order_book.submit_order``. Orders are validated
    by the Gatekeeper before fills to ensure account constraints are met.

    Args:
        asset: Asset symbol (e.g., "AAPL", "BTC-USD")
        quantity: Number of shares/units. When ``side`` is None, a positive
                 value buys and a negative value sells.
        side: OrderSide.BUY or OrderSide.SELL. If None, inferred from quantity sign
        order_type: Order type (MARKET, LIMIT, STOP, TRAILING_STOP)
        limit_price: Limit price for LIMIT orders
        stop_price: Stop/trigger price for STOP orders
        trail_amount: Trail distance for TRAILING_STOP orders
        _options: Internal submission options, forwarded to the order book
                 as ``options``.

    Returns:
        The Order if it was accepted, or None if it was rejected
        (e.g., same-bar re-entry after stop exit in VBT Pro mode)

    Examples:
        # Market buy
        order = broker.submit_order("AAPL", 100)

        # Market sell (using negative quantity)
        order = broker.submit_order("AAPL", -100)

        # Limit buy
        order = broker.submit_order("AAPL", 100, order_type=OrderType.LIMIT,
                                    limit_price=150.0)

        # Stop sell (stop-loss)
        order = broker.submit_order("AAPL", -100, order_type=OrderType.STOP,
                                    stop_price=145.0)
    """
    # Collect the request under the order book's keyword names; note the
    # private ``_options`` parameter maps to ``options`` there.
    request = {
        "asset": asset,
        "quantity": quantity,
        "side": side,
        "order_type": order_type,
        "limit_price": limit_price,
        "stop_price": stop_price,
        "trail_amount": trail_amount,
        "options": _options,
    }
    return self._order_book.submit_order(**request)

submit_bracket

submit_bracket(
    asset,
    quantity,
    take_profit,
    stop_loss,
    entry_type=MARKET,
    entry_limit=None,
    validate_prices=True,
)

Submit entry with take-profit and stop-loss.

Creates a bracket order with entry, take-profit limit, and stop-loss orders. The exit side is automatically determined from the entry direction.

Parameters:

Name Type Description Default
asset str

Asset symbol to trade

required
quantity float

Position size (positive for long, negative for short)

required
take_profit float

Take-profit price level (LIMIT order)

required
stop_loss float

Stop-loss price level (STOP order)

required
entry_type OrderType

Entry order type (default MARKET)

MARKET
entry_limit float | None

Entry limit price (if entry_type is LIMIT)

None
validate_prices bool

If True, validate that TP/SL prices are sensible for the position direction (default True)

True

Returns:

Type Description
tuple[Order, Order, Order] | None

Tuple of (entry_order, take_profit_order, stop_loss_order) or None if any fails.

Raises:

Type Description
ValueError

If validate_prices=True and prices are inverted for direction.

Notes

For LONG entries (quantity > 0):

- take_profit should be > reference_price (profit on up move)
- stop_loss should be < reference_price (exit on down move)

For SHORT entries (quantity < 0):

- take_profit should be < reference_price (profit on down move)
- stop_loss should be > reference_price (exit on up move)

Reference price is entry_limit (if LIMIT order) or current market price.

Source code in src/ml4t/backtest/broker.py
def submit_bracket(
    self,
    asset: str,
    quantity: float,
    take_profit: float,
    stop_loss: float,
    entry_type: OrderType = OrderType.MARKET,
    entry_limit: float | None = None,
    validate_prices: bool = True,
) -> tuple[Order, Order, Order] | None:
    """Submit entry with take-profit and stop-loss.

    Creates a bracket order with entry, take-profit limit, and stop-loss orders.
    The exit side is automatically determined from the entry direction, and
    both exit legs get ``parent_id`` set to the entry's ``order_id``.

    Args:
        asset: Asset symbol to trade
        quantity: Position size (positive for long, negative for short)
        take_profit: Take-profit price level (LIMIT order)
        stop_loss: Stop-loss price level (STOP order)
        entry_type: Entry order type (default MARKET)
        entry_limit: Entry limit price (if entry_type is LIMIT)
        validate_prices: If True, validate that TP/SL prices are sensible
                        for the position direction (default True)

    Returns:
        Tuple of (entry_order, take_profit_order, stop_loss_order), or None if
        any leg is rejected by ``submit_order``. When an exit leg is rejected,
        the legs already submitted are cancelled via ``cancel_order`` so the
        entry is not left working without its protective exits.

    Warns:
        UserWarning: If validate_prices=True and the TP/SL prices are on the
            wrong side of the reference price for the entry direction.
            Validation only warns; it never rejects the bracket.

    Notes:
        For LONG entries (quantity > 0):
            - take_profit should be > reference_price (profit on up move)
            - stop_loss should be < reference_price (exit on down move)

        For SHORT entries (quantity < 0):
            - take_profit should be < reference_price (profit on down move)
            - stop_loss should be > reference_price (exit on up move)

        Reference price is entry_limit (whenever it is provided) or the
        current market price; validation is skipped if neither is available.
    """
    import warnings

    entry = self.submit_order(asset, quantity, order_type=entry_type, limit_price=entry_limit)
    if entry is None:
        return None

    # Derive exit side from entry direction (Bug #4 fix)
    # Long entry (BUY) -> SELL to exit; Short entry (SELL) -> BUY to cover
    exit_side = OrderSide.SELL if entry.side == OrderSide.BUY else OrderSide.BUY
    exit_qty = abs(quantity)

    # Validate bracket prices if requested (warn-only; see docstring)
    if validate_prices:
        ref_price = entry_limit if entry_limit is not None else self._current_prices.get(asset)
        if ref_price is not None:
            is_long = entry.side == OrderSide.BUY

            if is_long:
                # Long: TP should be above entry, SL should be below
                if take_profit <= ref_price:
                    warnings.warn(
                        f"Bracket order for LONG {asset}: take_profit ({take_profit}) <= "
                        f"entry ({ref_price}). TP should be above entry for longs.",
                        UserWarning,
                        stacklevel=2,
                    )
                if stop_loss >= ref_price:
                    warnings.warn(
                        f"Bracket order for LONG {asset}: stop_loss ({stop_loss}) >= "
                        f"entry ({ref_price}). SL should be below entry for longs.",
                        UserWarning,
                        stacklevel=2,
                    )
            else:
                # Short: TP should be below entry, SL should be above
                if take_profit >= ref_price:
                    warnings.warn(
                        f"Bracket order for SHORT {asset}: take_profit ({take_profit}) >= "
                        f"entry ({ref_price}). TP should be below entry for shorts.",
                        UserWarning,
                        stacklevel=2,
                    )
                if stop_loss <= ref_price:
                    warnings.warn(
                        f"Bracket order for SHORT {asset}: stop_loss ({stop_loss}) <= "
                        f"entry ({ref_price}). SL should be above entry for shorts.",
                        UserWarning,
                        stacklevel=2,
                    )

    tp = self.submit_order(asset, exit_qty, exit_side, OrderType.LIMIT, limit_price=take_profit)
    if tp is None:
        # Without its exits the bracket is incomplete; don't leave the entry working.
        self.cancel_order(entry.order_id)
        return None
    tp.parent_id = entry.order_id

    sl = self.submit_order(asset, exit_qty, exit_side, OrderType.STOP, stop_price=stop_loss)
    if sl is None:
        # Roll back both legs already submitted so nothing is left orphaned.
        self.cancel_order(tp.order_id)
        self.cancel_order(entry.order_id)
        return None
    sl.parent_id = entry.order_id

    return entry, tp, sl

close_position

close_position(asset, _options=None)

Close an open position for the given asset.

Submits a market order to fully close the position.

Parameters:

Name Type Description Default
asset str

Asset symbol to close

required

Returns:

Type Description
Order | None

Order object if position exists and order submitted, None otherwise

Example

Close AAPL position

order = broker.close_position("AAPL")

Source code in src/ml4t/backtest/broker.py
def close_position(
    self, asset: str, _options: SubmitOrderOptions | None = None
) -> Order | None:
    """Flatten any open position in ``asset`` with a market order.

    A long position is closed with a SELL and a short position with a BUY,
    always for the absolute position size.

    Args:
        asset: Asset symbol to close
        _options: Optional submission options, forwarded unchanged to
            ``submit_order``.

    Returns:
        The order returned by ``submit_order`` when there is a non-zero
        position, otherwise None.

    Example:
        # Close AAPL position
        order = broker.close_position("AAPL")
    """
    position = self.positions.get(asset)
    if not position or position.quantity == 0:
        # Nothing held (or already flat): nothing to close.
        return None
    closing_side = OrderSide.SELL if position.quantity > 0 else OrderSide.BUY
    size = abs(position.quantity)
    return self.submit_order(asset, size, closing_side, _options=_options)

cancel_order

cancel_order(order_id)
Source code in src/ml4t/backtest/broker.py
def cancel_order(self, order_id: str) -> bool:
    """Cancel an order by id by delegating to the broker's order book.

    Args:
        order_id: Identifier of the order to cancel.

    Returns:
        Whatever the order book's ``cancel_order`` reports (presumably True
        when an order was actually cancelled — confirm against OrderBook).
    """
    book = self._order_book
    return book.cancel_order(order_id)

get_position

get_position(asset)

Get the current position for an asset.

Parameters:

Name Type Description Default
asset str

Asset symbol

required

Returns:

Type Description
Position | None

Position object if position exists, None otherwise

Source code in src/ml4t/backtest/broker.py
def get_position(self, asset: str) -> Position | None:
    """Look up the position held in ``asset``.

    Args:
        asset: Asset symbol

    Returns:
        The Position stored for ``asset``, or None if there is no entry.
    """
    held = self.positions
    return held.get(asset, None)

get_positions

get_positions()

Get all current positions.

Returns:

Type Description
dict[str, Position]

Dictionary mapping asset symbols to Position objects

Source code in src/ml4t/backtest/broker.py
def get_positions(self) -> dict[str, Position]:
    """Return every tracked position keyed by asset symbol.

    Returns:
        The broker's own ``positions`` mapping (not a copy), so mutating
        the result mutates broker state.
    """
    current = self.positions
    return current

get_cash

get_cash()

Get current cash balance.

Returns:

Type Description
float

Current cash balance (can be negative for margin accounts)

Source code in src/ml4t/backtest/broker.py
def get_cash(self) -> float:
    """Report the broker's cash balance.

    Returns:
        The current ``cash`` value (can be negative for margin accounts).
    """
    balance = self.cash
    return balance

get_account_value

get_account_value()

Calculate total account value (cash + position values).

Source code in src/ml4t/backtest/broker.py
def get_account_value(self) -> float:
    """Return total account value (cash + position values) from the portfolio ledger."""
    ledger = self._portfolio_ledger
    return ledger.get_account_value()

get_rejected_orders

get_rejected_orders(asset=None)

Get all rejected orders, optionally filtered by asset.

Parameters:

Name Type Description Default
asset str | None

If provided, filter to only this asset's rejected orders

None

Returns:

Type Description
list[Order]

List of rejected Order objects with rejection_reason populated

Source code in src/ml4t/backtest/broker.py
def get_rejected_orders(self, asset: str | None = None) -> list[Order]:
    """List rejected orders, delegating the filtering to the portfolio ledger.

    Args:
        asset: Restrict the result to this asset; None returns all assets.

    Returns:
        Rejected Order objects (with rejection_reason populated).
    """
    ledger = self._portfolio_ledger
    return ledger.get_rejected_orders(asset=asset)

set_position_rules

set_position_rules(rules, asset=None)

Set position rules globally or per-asset.

Parameters:

Name Type Description Default
rules PositionRule

PositionRule or RuleChain to apply

required
asset str | None

If provided, apply only to this asset; otherwise global

None
Source code in src/ml4t/backtest/broker.py
def set_position_rules(self, rules: PositionRule, asset: str | None = None) -> None:
    """Install position rules for one asset or for the whole portfolio.

    Args:
        rules: PositionRule or RuleChain to apply
        asset: Asset to scope the rules to. Any falsy value (None or an
            empty string) replaces the global rules instead.
    """
    if not asset:
        self._position_rules = rules
        return
    self._position_rules_by_asset[asset] = rules

rebalance_to_weights

rebalance_to_weights(target_weights, order_type=MARKET)

Rebalance portfolio to target weights.

Calculates orders needed to achieve target portfolio allocation. Processes sells before buys to free up capital.

Parameters:

Name Type Description Default
target_weights dict[str, float]

Dict of {asset: weight} where weights are decimals (0.10 = 10%). Weights should sum to <= 1.0.

required
order_type OrderType

Order type for all orders (default MARKET)

MARKET

Returns:

Type Description
list[Order]

List of successfully submitted orders (orders rejected at submission are omitted, so the list never contains None)

Example

Equal weight three stocks

broker.rebalance_to_weights({ "AAPL": 0.33, "GOOGL": 0.33, "MSFT": 0.34, })

Source code in src/ml4t/backtest/broker.py
def rebalance_to_weights(
    self,
    target_weights: dict[str, float],
    order_type: OrderType = OrderType.MARKET,
) -> list[Order]:
    """Rebalance portfolio to target weights.

    Calculates orders needed to achieve target portfolio allocation.
    Processes sells before buys to free up capital.

    Each weight is multiplied by ``self.rebalance_headroom_pct`` before being
    applied to the current account value. An asset is skipped when no usable
    price exists (a positive current price, or a positive last price under
    ``MissingPricePolicy.USE_LAST``), or when ``LateAssetPolicy.REQUIRE_HISTORY``
    is active and fewer than ``late_asset_min_bars`` bars have been seen for
    it. Open positions in assets absent from ``target_weights`` are targeted
    to zero. Changes smaller than one cent in value are ignored. All orders
    from one call share one rebalance id, allocated lazily on the first
    submission.

    Args:
        target_weights: Dict of {asset: weight} where weights are decimals
                       (0.10 = 10%). Weights should sum to <= 1.0.
        order_type: Order type for all orders (default MARKET)

    Returns:
        List of successfully submitted orders. Submissions that return None
        (rejected) are omitted, so the list never contains None. Empty when
        the account value is not positive or nothing needs to trade.

    Example:
        # Equal weight three stocks
        broker.rebalance_to_weights({
            "AAPL": 0.33,
            "GOOGL": 0.33,
            "MSFT": 0.34,
        })
    """
    portfolio_value = self.get_account_value()
    if portfolio_value <= 0:
        return []

    # Each entry is (asset, target_value, price). The price is resolved once
    # and reused at submission instead of being looked up a second time.
    sells: list[tuple[str, float, float]] = []
    buys: list[tuple[str, float, float]] = []
    rebalance_id: str | None = None

    scaled_weights = {
        asset: weight * self.rebalance_headroom_pct for asset, weight in target_weights.items()
    }

    def resolve_price(asset: str) -> float | None:
        price = self._current_prices.get(asset)
        if price is not None and price > 0:
            return price
        if self.missing_price_policy == MissingPricePolicy.USE_LAST:
            last = self._last_prices.get(asset)
            if last is not None and last > 0:
                return last
        return None

    def allows_trading(asset: str) -> bool:
        if self.late_asset_policy != LateAssetPolicy.REQUIRE_HISTORY:
            return True
        return self._asset_bars_seen.get(asset, 0) >= self.late_asset_min_bars

    def rebalance_options() -> SubmitOrderOptions:
        # Allocate the shared rebalance id only once something is submitted.
        nonlocal rebalance_id
        if rebalance_id is None:
            rebalance_id = self._next_rebalance_id()
        return SubmitOrderOptions(rebalance_id=rebalance_id)

    # Calculate target values and categorize as buys or sells
    for asset, weight in scaled_weights.items():
        if not allows_trading(asset):
            continue
        price = resolve_price(asset)
        if price is None:
            continue

        target_value = portfolio_value * weight

        pos = self.positions.get(asset)
        # Bug #2 fix: Include contract multiplier in value calculations
        multiplier = self.get_multiplier(asset)
        current_value = pos.quantity * price * multiplier if pos and pos.quantity != 0 else 0.0

        delta = target_value - current_value
        if abs(delta) < 0.01:  # Less than 1 cent
            continue

        if delta < 0:
            sells.append((asset, target_value, price))
        else:
            buys.append((asset, target_value, price))

    # Also close positions not in target weights
    for asset, pos in self.positions.items():
        if pos.quantity != 0 and asset not in scaled_weights:
            price = resolve_price(asset)
            if price is not None:
                sells.append((asset, 0.0, price))

    # Sells first (frees capital for buys), then buys.
    orders: list[Order] = []
    for asset, target_value, price in sells + buys:
        order = self._order_to_target_value(
            asset,
            target_value,
            price,
            order_type,
            None,
            rebalance_options(),
        )
        if order:
            orders.append(order)

    return orders

Domain Types

Order dataclass

Order(
    asset,
    side,
    quantity,
    order_type=MARKET,
    limit_price=None,
    stop_price=None,
    trail_amount=None,
    parent_id=None,
    rebalance_id=None,
    order_id="",
    status=PENDING,
    created_at=None,
    filled_at=None,
    filled_price=None,
    filled_quantity=0.0,
    rejection_reason=None,
    _created_bar_index=0,
    _signal_price=None,
    _risk_exit_reason=None,
    _exit_reason=None,
    _risk_fill_price=None,
)

Fill dataclass

Fill(
    order_id,
    asset,
    side,
    quantity,
    price,
    timestamp,
    rebalance_id=None,
    commission=0.0,
    slippage=0.0,
    order_type="",
    limit_price=None,
    stop_price=None,
    price_source="",
    reference_price=None,
    quote_mid_price=None,
    bid_price=None,
    ask_price=None,
    spread=None,
    bid_size=None,
    ask_size=None,
    available_size=None,
)

Trade dataclass

Trade(
    symbol,
    entry_time,
    exit_time,
    entry_price,
    exit_price,
    quantity,
    pnl,
    pnl_percent,
    bars_held,
    fees=0.0,
    exit_slippage=0.0,
    exit_reason="signal",
    status="closed",
    mfe=0.0,
    mae=0.0,
    entry_slippage=0.0,
    multiplier=1.0,
    entry_quote_mid_price=None,
    entry_bid_price=None,
    entry_ask_price=None,
    entry_spread=None,
    entry_available_size=None,
    exit_quote_mid_price=None,
    exit_bid_price=None,
    exit_ask_price=None,
    exit_spread=None,
    exit_available_size=None,
    metadata=None,
)

Round-trip trade (closed or open).

This dataclass is part of the cross-library API specification, designed to produce identical Parquet output across Python, Numba, and Rust implementations.

For open trades (status="open"), exit_time and exit_price represent mark-to-market values at the end of the backtest period.

Schema Alignment (v0.1.0a6): - symbol: Asset identifier (was 'asset' in earlier versions) - fees: Total transaction fees (was 'commission') - mfe/mae: Max favorable/adverse excursion (was 'max_favorable_excursion'/'max_adverse_excursion') - direction: Derived property from quantity sign

direction property

direction

Return 'long' or 'short' based on quantity sign.

is_open property

is_open

Return True if this is an open (mark-to-market) trade.

commission property

commission

Backward-compat alias for validation scripts expecting commission.

gross_pnl property

gross_pnl

Price-move P&L before fees: (exit - entry) * quantity * multiplier.

net_pnl property

net_pnl

P&L after all costs. Alias for self.pnl.

gross_return property

gross_return

Direction-aware gross return. Same as pnl_percent.

net_return property

net_return

Direction-aware net return including fees.

total_slippage_cost property

total_slippage_cost

Total slippage cost in dollars (entry + exit).

cost_drag property

cost_drag

Total cost as fraction of notional: (fees + slippage) / notional.

Position dataclass

Position(
    asset,
    quantity,
    entry_price,
    entry_time,
    current_price=None,
    bars_held=0,
    high_water_mark=None,
    low_water_mark=None,
    max_favorable_excursion=0.0,
    max_adverse_excursion=0.0,
    initial_quantity=None,
    context=dict(),
    multiplier=1.0,
    entry_commission=0.0,
    entry_slippage=0.0,
)

Unified position tracking for strategy and accounting.

Supports both long and short positions with: - Weighted average cost basis tracking - Mark-to-market price tracking - Risk metrics (MFE/MAE, water marks) - Contract multipliers for futures

Attributes:

Name Type Description
asset str

Asset identifier (e.g., "AAPL", "ES")

quantity float

Position size (positive=long, negative=short)

entry_price float

Weighted average entry price (cost basis)

entry_time datetime

Timestamp when position was first opened

current_price float | None

Latest mark-to-market price (updated each bar)

bars_held int

Number of bars this position has been held

Examples:

Long position: Position("AAPL", 100, 150.0, datetime.now()) -> quantity=100, unrealized_pnl depends on current_price

Short position: Position("AAPL", -100, 150.0, datetime.now()) -> quantity=-100, profit if price drops

market_value property

market_value

Current market value of the position.

For long positions: positive value (asset on balance sheet). For short positions: negative value (liability on balance sheet).

Returns:

Type Description
float

Market value = quantity × current_price

side property

side

Return 'long' or 'short' based on quantity sign.

unrealized_pnl

unrealized_pnl(current_price=None)

Calculate unrealized P&L including contract multiplier.

Parameters:

Name Type Description Default
current_price float | None

Price to calculate P&L at. If None, uses self.current_price.

None

Returns:

Type Description
float

Unrealized P&L = (current_price - entry_price) × quantity × multiplier

Source code in src/ml4t/backtest/types.py
def unrealized_pnl(self, current_price: float | None = None) -> float:
    """Mark-to-market P&L of the open quantity, scaled by the contract multiplier.

    Args:
        current_price: Price to mark at. Falls back to ``self.current_price``,
            then to ``self.entry_price`` (giving zero P&L) when both are None.

    Returns:
        (mark - entry_price) * quantity * multiplier
    """
    mark = current_price
    if mark is None:
        mark = self.current_price
    if mark is None:
        mark = self.entry_price
    price_move = mark - self.entry_price
    return price_move * self.quantity * self.multiplier

pnl_percent

pnl_percent(current_price=None)

Calculate direction-aware percentage return on position.

For long positions: (price - entry) / entry. For short positions: (entry - price) / entry.

Parameters:

Name Type Description Default
current_price float | None

Price to calculate return at. If None, uses self.current_price.

None
Source code in src/ml4t/backtest/types.py
def pnl_percent(self, current_price: float | None = None) -> float:
    """Fractional return on the position, signed for its direction.

    Long (quantity >= 0): (price - entry) / entry.
    Short (quantity < 0): (entry - price) / entry.

    Args:
        current_price: Price to evaluate at. Falls back to
            ``self.current_price``, then to ``self.entry_price``.

    Returns:
        The direction-aware return, or 0.0 when the entry price is zero.
    """
    mark = current_price
    if mark is None:
        mark = self.current_price
    if mark is None:
        mark = self.entry_price
    basis = self.entry_price
    if basis == 0:
        # Avoid division by zero for a zero cost basis.
        return 0.0
    move = (mark - basis) / basis
    return move if self.quantity >= 0 else -move

notional_value

notional_value(current_price=None)

Calculate notional value of position.

Parameters:

Name Type Description Default
current_price float | None

Price to calculate value at. If None, uses self.current_price.

None
Source code in src/ml4t/backtest/types.py
def notional_value(self, current_price: float | None = None) -> float:
    """Gross (unsigned) notional exposure of the position.

    Args:
        current_price: Price to value at. Falls back to ``self.current_price``,
            then to ``self.entry_price``.

    Returns:
        abs(quantity) * price * multiplier — always non-negative for
        non-negative prices, regardless of direction.
    """
    mark = current_price
    if mark is None:
        mark = self.current_price
    if mark is None:
        mark = self.entry_price
    size = abs(self.quantity)
    return size * mark * self.multiplier

update_water_marks

update_water_marks(
    current_price,
    bar_high=None,
    bar_low=None,
    use_high_for_hwm=False,
    use_low_for_lwm=False,
)

Update high/low water marks and excursion tracking.

Parameters:

Name Type Description Default
current_price float

Current bar's close price

required
bar_high float | None

Bar's high price (used for HWM if use_high_for_hwm=True)

None
bar_low float | None

Bar's low price (used for LWM if use_low_for_lwm=True)

None
use_high_for_hwm bool

If True, use bar_high for HWM (VBT Pro OHLC mode). If False, use current_price (close) for HWM (default).

False
use_low_for_lwm bool

If True, use bar_low for LWM (VBT Pro OHLC mode). If False, use current_price (close) for LWM (default).

False
Source code in src/ml4t/backtest/types.py
def update_water_marks(
    self,
    current_price: float,
    bar_high: float | None = None,
    bar_low: float | None = None,
    use_high_for_hwm: bool = False,
    use_low_for_lwm: bool = False,
) -> None:
    """Record the latest price and extend water marks and MFE/MAE.

    Args:
        current_price: Current bar's close price; always stored as
            ``self.current_price``.
        bar_high: Bar's high, used for the high water mark only when
            ``use_high_for_hwm`` is True and it is not None.
        bar_low: Bar's low, used for the low water mark only when
            ``use_low_for_lwm`` is True and it is not None.
        use_high_for_hwm: True selects bar_high (VBT Pro OHLC mode);
            False (default) uses the close.
        use_low_for_lwm: True selects bar_low (VBT Pro OHLC mode);
            False (default) uses the close.

    The same high/low candidates drive MFE/MAE: for a long (quantity > 0)
    the high is favorable and the low adverse; otherwise the roles swap.
    MFE only ever increases and MAE only ever decreases.
    """
    self.current_price = current_price

    hwm_candidate = current_price
    if use_high_for_hwm and bar_high is not None:
        hwm_candidate = bar_high
    lwm_candidate = current_price
    if use_low_for_lwm and bar_low is not None:
        lwm_candidate = bar_low

    # None means "not yet initialised", so the first candidate always wins.
    if self.high_water_mark is None or hwm_candidate > self.high_water_mark:
        self.high_water_mark = hwm_candidate
    if self.low_water_mark is None or lwm_candidate < self.low_water_mark:
        self.low_water_mark = lwm_candidate

    if self.quantity > 0:
        favorable_price, adverse_price = hwm_candidate, lwm_candidate
    else:
        favorable_price, adverse_price = lwm_candidate, hwm_candidate
    favorable_return = self.pnl_percent(favorable_price)
    adverse_return = self.pnl_percent(adverse_price)

    if favorable_return > self.max_favorable_excursion:
        self.max_favorable_excursion = favorable_return
    if adverse_return < self.max_adverse_excursion:
        self.max_adverse_excursion = adverse_return

__repr__

__repr__()

String representation for debugging.

Source code in src/ml4t/backtest/types.py
def __repr__(self) -> str:
    """String representation for debugging.

    Shows the direction (LONG, SHORT, or FLAT when quantity is zero), the
    absolute size, asset, entry price, mark price (falling back to the
    entry price when no mark is set), and ``unrealized_pnl()``.
    """
    if self.quantity > 0:
        direction = "LONG"
    elif self.quantity < 0:
        direction = "SHORT"
    else:
        # A zero-quantity position is neither long nor short; a plain
        # `> 0` check would mislabel it as SHORT.
        direction = "FLAT"
    price = self.current_price if self.current_price is not None else self.entry_price
    pnl = self.unrealized_pnl()
    return (
        f"Position({direction} {abs(self.quantity):.2f} {self.asset} "
        f"@ ${self.entry_price:.2f}, "
        f"current ${price:.2f}, "
        f"PnL ${pnl:+.2f})"
    )

Enums

OrderType

Bases: Enum

OrderSide

Bases: Enum

ExecutionMode

Bases: str, Enum

Order execution timing mode.

StopFillMode

Bases: str, Enum

Stop/take-profit fill price mode.

Different frameworks handle stop order fills differently:

- STOP_PRICE: Fill at the exact stop/target price (standard model, default). Matches VectorBT Pro with OHLC data and Backtrader behavior.
- CLOSE_PRICE: Fill at the bar's close price when the stop triggers. Matches VectorBT Pro with close-only data.
- BAR_EXTREME: Fill at the bar's low (stop-loss) or high (take-profit). Worst/best-case model (conservative/optimistic).
- NEXT_BAR_OPEN: Fill at the next bar's open price when the stop triggers. Matches Zipline behavior (strategy-level stops).

CommissionType

Bases: str, Enum

Commission calculation method.

SlippageType

Bases: str, Enum

Slippage calculation method.

FillOrdering

Bases: str, Enum

Order processing sequence within a single bar.

Controls how pending orders are sequenced during fill processing:

EXIT_FIRST (default): All exits → mark-to-market → all entries (with gatekeeper validation). Capital-efficient: exits free cash before entries need it. Matches VectorBT call_seq='auto' behavior.

FIFO

Orders are processed in submission order with sequential cash updates. Each order's gatekeeper check sees the cash resulting from all prior fills. Matches Backtrader's submission-order processing.

SEQUENTIAL

Orders are processed in submission order (typically alphabetical by asset) without exit/entry separation. Cash is updated after each individual fill. Unlike EXIT_FIRST, exits do not pre-free cash for later entries. Matches LEAN's per-order sequential buying-power model.

Results

BacktestResult dataclass

BacktestResult(
    trades,
    equity_curve,
    fills,
    metrics,
    predictions=None,
    config=None,
    equity=None,
    trade_analyzer=None,
    portfolio_state=list(),
    _trades_df=None,
    _equity_df=None,
    _fills_df=None,
    _portfolio_state_df=None,
)

Structured backtest result with export capabilities.

This class wraps the raw output from Engine.run() and provides: - DataFrame conversion methods (trades, equity, daily P&L) - Parquet export/import for persistence - Integration with ml4t.diagnostic library - Backward-compatible dict export

Attributes:

Name Type Description
trades list[Trade]

List of completed Trade objects

equity_curve list[tuple[datetime, float]]

List of (timestamp, portfolio_value) tuples

fills list[Fill]

List of Fill objects (all order fills)

predictions DataFrame | None

Raw prediction DataFrame passed into the backtest (optional)

metrics dict[str, Any]

Dictionary of computed performance metrics

config BacktestConfig | None

BacktestConfig used for the backtest (optional)

equity EquityCurve | None

EquityCurve analytics object

trade_analyzer TradeAnalyzer | None

TradeAnalyzer analytics object

to_fills_dataframe

to_fills_dataframe()

Convert fills to Polars DataFrame.

Source code in src/ml4t/backtest/result.py
def to_fills_dataframe(self) -> pl.DataFrame:
    """Return all fills as a Polars DataFrame.

    The first non-empty conversion is memoized on ``self._fills_df`` and
    returned on later calls. With no fills, a fresh empty frame built from
    ``_fills_schema()`` is returned (and not cached). The ``side`` column
    holds the enum's ``.value``; every other column copies the Fill field
    of the same name.
    """
    if self._fills_df is None:
        if not self.fills:
            return pl.DataFrame(schema=self._fills_schema())
        columns = (
            "order_id",
            "rebalance_id",
            "asset",
            "side",
            "quantity",
            "price",
            "timestamp",
            "commission",
            "slippage",
            "order_type",
            "limit_price",
            "stop_price",
            "price_source",
            "reference_price",
            "quote_mid_price",
            "bid_price",
            "ask_price",
            "spread",
            "bid_size",
            "ask_size",
            "available_size",
        )
        rows = []
        for fill in self.fills:
            row = {name: getattr(fill, name) for name in columns}
            # Store the enum's plain value rather than the enum member.
            row["side"] = fill.side.value
            rows.append(row)
        self._fills_df = pl.DataFrame(rows, schema=self._fills_schema())
    return self._fills_df

to_portfolio_state_dataframe

to_portfolio_state_dataframe()

Convert portfolio state snapshots to Polars DataFrame.

Returns DataFrame with columns

timestamp, equity, cash, gross_exposure, net_exposure, open_positions

Returns:

Type Description
DataFrame

Polars DataFrame with one row per bar, sorted by timestamp

Source code in src/ml4t/backtest/result.py
def to_portfolio_state_dataframe(self) -> pl.DataFrame:
    """Convert per-bar portfolio snapshots into a Polars DataFrame.

    Returns DataFrame with columns:
        timestamp, equity, cash, gross_exposure, net_exposure, open_positions

    Snapshots are read as rows, sorted by timestamp, and cast to
    ``_portfolio_state_schema()``. A non-empty result is memoized on
    ``self._portfolio_state_df``; an empty snapshot list yields a fresh,
    uncached empty frame.

    Returns:
        Polars DataFrame with one row per bar, sorted by timestamp
    """
    if self._portfolio_state_df is None:
        if not self.portfolio_state:
            return pl.DataFrame(schema=self._portfolio_state_schema())
        column_names = [
            "timestamp",
            "equity",
            "cash",
            "gross_exposure",
            "net_exposure",
            "open_positions",
        ]
        frame = pl.DataFrame(self.portfolio_state, schema=column_names, orient="row")
        frame = frame.sort("timestamp")
        self._portfolio_state_df = frame.cast(self._portfolio_state_schema())
    return self._portfolio_state_df

to_predictions_dataframe

to_predictions_dataframe()

Return the raw prediction DataFrame used as backtest input.

Source code in src/ml4t/backtest/result.py
def to_predictions_dataframe(self) -> pl.DataFrame:
    """Return the prediction DataFrame given to the backtest, or an empty frame if none."""
    predictions = self.predictions
    return pl.DataFrame() if predictions is None else predictions

to_trades_dataframe

to_trades_dataframe()

Convert trades to Polars DataFrame.

Returns DataFrame with columns

symbol, entry_time, exit_time, entry_price, exit_price, quantity, direction, pnl, pnl_percent, bars_held, fees, exit_slippage, mfe, mae, entry_slippage, multiplier, gross_pnl, net_return, total_slippage_cost, cost_drag, exit_reason, status

Cost decomposition columns

gross_pnl: Price-move P&L before fees net_return: Direction-aware net return including fees total_slippage_cost: Entry + exit slippage in dollars cost_drag: Total cost as fraction of notional

The status column indicates "closed" (actually exited) or "open" (mark-to-market at end of backtest).

Returns:

Type Description
DataFrame

Polars DataFrame with one row per trade

Source code in src/ml4t/backtest/result.py
def to_trades_dataframe(self) -> pl.DataFrame:
    """Convert trades to Polars DataFrame.

    Returns DataFrame with columns:
        symbol, entry_time, exit_time, entry_price, exit_price,
        quantity, direction, pnl, pnl_percent, bars_held,
        fees, exit_slippage, mfe, mae, entry_slippage, multiplier,
        gross_pnl, net_return, total_slippage_cost, cost_drag,
        exit_reason, status

    Cost decomposition columns:
        gross_pnl: Price-move P&L before fees
        net_return: Direction-aware net return including fees
        total_slippage_cost: Entry + exit slippage in dollars
        cost_drag: Total cost as fraction of notional

    Each record also collects entry/exit quote context
    (``*_quote_mid_price``, ``*_bid_price``, ``*_ask_price``, ``*_spread``,
    ``*_available_size``). The final column set and dtypes are whatever
    ``_trades_schema()`` declares.

    The status column indicates "closed" (actually exited) or "open"
    (mark-to-market at end of backtest).

    The result is cached on ``self._trades_df`` after the first call.

    Returns:
        Polars DataFrame with one row per trade
    """
    cached = self._trades_df
    if cached is not None:
        return cached

    if not self.trades:
        return pl.DataFrame(schema=self._trades_schema())

    # Every output key is identical to the Trade attribute it is read from.
    field_names = (
        "symbol",
        "entry_time",
        "exit_time",
        "entry_price",
        "exit_price",
        "quantity",
        "direction",
        "pnl",
        "pnl_percent",
        "bars_held",
        "fees",
        "exit_slippage",
        "mfe",
        "mae",
        "entry_slippage",
        "multiplier",
        "entry_quote_mid_price",
        "entry_bid_price",
        "entry_ask_price",
        "entry_spread",
        "entry_available_size",
        "exit_quote_mid_price",
        "exit_bid_price",
        "exit_ask_price",
        "exit_spread",
        "exit_available_size",
        "gross_pnl",
        "net_return",
        "total_slippage_cost",
        "cost_drag",
        "exit_reason",
        "status",
    )
    rows = [{name: getattr(trade, name) for name in field_names} for trade in self.trades]

    self._trades_df = pl.DataFrame(rows, schema=self._trades_schema())
    return self._trades_df

to_equity_dataframe

to_equity_dataframe()

Convert equity curve to Polars DataFrame.

Returns DataFrame with columns

timestamp, equity, return, cumulative_return, drawdown, high_water_mark

Returns:

Type Description
DataFrame

Polars DataFrame with one row per bar, sorted by timestamp

Source code in src/ml4t/backtest/result.py
def to_equity_dataframe(self) -> pl.DataFrame:
    """Convert equity curve to Polars DataFrame.

    Returns DataFrame with columns:
        timestamp, equity, return, cumulative_return,
        drawdown, high_water_mark

    ``return`` is the bar-over-bar percent change (0.0 on the first bar).
    ``cumulative_return`` is measured against the first (earliest) equity value.
    ``drawdown`` is ``equity / high_water_mark - 1``, or 0.0 when the
    high-water mark is not positive. The result is cached on ``self._equity_df``.

    Returns:
        Polars DataFrame with one row per bar, sorted by timestamp
    """
    cached = self._equity_df
    if cached is not None:
        return cached

    if not self.equity_curve:
        return pl.DataFrame(schema=self._equity_schema())

    frame = pl.DataFrame(
        {
            "timestamp": [stamp for stamp, _ in self.equity_curve],
            # Coerce every equity value to a Python float before building the column
            "equity": [float(level) for _, level in self.equity_curve],
        }
    ).sort("timestamp")

    eq = pl.col("equity")
    running_peak = eq.cum_max()
    frame = frame.with_columns(
        [
            eq.pct_change().fill_null(0.0).alias("return"),
            (eq / pl.first("equity") - 1.0).alias("cumulative_return"),
            running_peak.alias("high_water_mark"),
            # Guard the division: a non-positive peak yields zero drawdown
            pl.when(running_peak > 0)
            .then(eq / running_peak - 1.0)
            .otherwise(0.0)
            .alias("drawdown"),
        ]
    )

    self._equity_df = frame.select(
        ["timestamp", "equity", "return", "cumulative_return", "drawdown", "high_water_mark"]
    )
    return self._equity_df

to_dict

to_dict()

Export as dictionary (backward compatible with Engine.run()).

Returns:

Type Description
dict[str, Any]

Dictionary with all metrics and raw data

Source code in src/ml4t/backtest/result.py
def to_dict(self) -> dict[str, Any]:
    """Export as dictionary (backward compatible with Engine.run()).

    Starts from a shallow copy of ``self.metrics``. It then sets the raw data
    keys ``trades``, ``equity_curve``, ``fills`` and ``portfolio_state``, which
    overwrite any same-named metric. Finally it adds ``predictions``,
    ``equity`` and ``trade_analyzer`` only when they are not None.

    Returns:
        Dictionary with all metrics and raw data
    """
    payload = dict(self.metrics)
    payload["trades"] = self.trades
    payload["equity_curve"] = self.equity_curve
    payload["fills"] = self.fills
    payload["portfolio_state"] = self.portfolio_state

    optional_entries = (
        ("predictions", self.predictions),
        ("equity", self.equity),
        ("trade_analyzer", self.trade_analyzer),
    )
    for key, value in optional_entries:
        if value is not None:
            payload[key] = value
    return payload

to_parquet

to_parquet(path, include=None, compression='zstd')

Export backtest result to Parquet files.

Creates directory structure

{path}/ trades.parquet fills.parquet predictions.parquet equity.parquet portfolio_state.parquet daily_pnl.parquet metrics.json config.yaml (if config available) spec.yaml (if config available)

Parameters:

Name Type Description Default
path str | Path

Directory path to write files

required
include list[str] | None

Components to include. Default: all. Options: ["trades", "fills", "predictions", "equity", "portfolio_state", "daily_pnl", "metrics", "config", "spec"]

None
compression Literal['lz4', 'uncompressed', 'snappy', 'gzip', 'brotli', 'zstd']

Parquet compression codec (default: "zstd")

'zstd'

Returns:

Type Description
dict[str, Path]

Dict mapping component names to file paths

Source code in src/ml4t/backtest/result.py
def to_parquet(
    self,
    path: str | Path,
    include: list[str] | None = None,
    compression: Literal["lz4", "uncompressed", "snappy", "gzip", "brotli", "zstd"] = "zstd",
) -> dict[str, Path]:
    """Export backtest result to Parquet files.

    Creates directory structure:
        {path}/
            trades.parquet
            fills.parquet
            predictions.parquet
            equity.parquet
            portfolio_state.parquet
            daily_pnl.parquet
            metrics.json
            config.yaml (if config available)
            spec.yaml (if config available)

    Args:
        path: Directory path to write files
        include: Components to include. Default: all.
            Options: ["trades", "fills", "predictions", "equity", "portfolio_state", "daily_pnl",
                "metrics", "config"]
        compression: Parquet compression codec (default: "zstd")

    Returns:
        Dict mapping component names to file paths
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    if include is None:
        include = [
            "trades",
            "fills",
            "predictions",
            "equity",
            "portfolio_state",
            "daily_pnl",
            "metrics",
            "config",
            "spec",
        ]

    written: dict[str, Path] = {}

    if "trades" in include:
        trades_path = path / "trades.parquet"
        self.to_trades_dataframe().write_parquet(trades_path, compression=compression)
        written["trades"] = trades_path

    if "fills" in include:
        fills_path = path / "fills.parquet"
        self.to_fills_dataframe().write_parquet(fills_path, compression=compression)
        written["fills"] = fills_path

    if "predictions" in include and self.predictions is not None:
        predictions_path = path / "predictions.parquet"
        self.to_predictions_dataframe().write_parquet(predictions_path, compression=compression)
        written["predictions"] = predictions_path

    if "equity" in include:
        equity_path = path / "equity.parquet"
        self.to_equity_dataframe().write_parquet(equity_path, compression=compression)
        written["equity"] = equity_path

    if "portfolio_state" in include:
        portfolio_state_path = path / "portfolio_state.parquet"
        self.to_portfolio_state_dataframe().write_parquet(
            portfolio_state_path, compression=compression
        )
        written["portfolio_state"] = portfolio_state_path

    if "daily_pnl" in include:
        daily_path = path / "daily_pnl.parquet"
        self.to_daily_pnl().write_parquet(daily_path, compression=compression)
        written["daily_pnl"] = daily_path

    if "metrics" in include:
        metrics_path = path / "metrics.json"
        # Filter to JSON-serializable metrics
        serializable = {}
        for k, v in self.metrics.items():
            if isinstance(v, int | float | str | bool | type(None)):
                serializable[k] = v
            elif isinstance(v, datetime):
                serializable[k] = v.isoformat()
            else:
                # Handle numpy scalars (np.float64, np.int64, etc.)
                try:
                    import numpy as np

                    if isinstance(v, np.generic):
                        serializable[k] = v.item()
                except (ImportError, AttributeError):
                    pass  # Skip if numpy not available or not a numpy type
        with open(metrics_path, "w") as f:
            json.dump(serializable, f, indent=2)
        written["metrics"] = metrics_path

    if "config" in include and self.config is not None:
        config_path = path / "config.yaml"
        try:
            import yaml

            with open(config_path, "w") as f:
                yaml.dump(self.config.to_dict(), f, default_flow_style=False)
            written["config"] = config_path
        except (ImportError, AttributeError):
            pass  # Skip if yaml not available or config has no to_dict

    if "spec" in include and self.config is not None:
        spec_path = path / "spec.yaml"
        try:
            import yaml

            with open(spec_path, "w") as f:
                yaml.dump(self.to_spec_dict(), f, default_flow_style=False, sort_keys=False)
            written["spec"] = spec_path
        except (ImportError, AttributeError):
            pass

    return written

Execution: Market Impact

LinearImpact dataclass

LinearImpact(coefficient=0.1, permanent_fraction=0.5)

Bases: MarketImpactModel

Linear market impact model.

Impact = coefficient * (quantity / volume) * price

Simple model where impact scales linearly with participation rate. Appropriate for liquid markets with moderate order sizes.

Parameters:

Name Type Description Default
coefficient float

Impact scaling factor (default 0.1) Higher values = more impact per unit participation

0.1
permanent_fraction float

Fraction of impact that is permanent (0-1) Remainder is temporary and reverts

0.5
Example

model = LinearImpact(coefficient=0.1)

10% participation at $100 price = $1.00 impact

calculate

calculate(quantity, price, volume, is_buy)

Calculate linear impact.

Source code in src/ml4t/backtest/execution/impact.py
def calculate(
    self,
    quantity: float,
    price: float,
    volume: float | None,
    is_buy: bool,
) -> float:
    """Calculate linear impact."""
    if volume is None or volume == 0:
        return 0.0

    participation = quantity / volume
    impact = self.coefficient * participation * price

    # Apply direction (buys push price up, sells push price down)
    return impact if is_buy else -impact

SquareRootImpact dataclass

SquareRootImpact(
    coefficient=0.5, volatility=0.02, adv_factor=1.0
)

Bases: MarketImpactModel

Square root market impact model (Almgren-Chriss style).

Impact = coefficient * sigma * sqrt(quantity / ADV) * price

Based on academic market microstructure research. Impact scales with the square root of order size, which matches empirical observations.

Parameters:

Name Type Description Default
coefficient float

Scaling factor (default 0.5, typical range 0.1-1.0)

0.5
volatility float

Daily volatility (sigma, default 0.02 = 2%)

0.02
adv_factor float

Average daily volume as multiple of bar volume (default 1.0 for daily bars, 390 for minute bars)

1.0
Example

model = SquareRootImpact(coefficient=0.5, volatility=0.02)

For order = 1% of ADV at 2% vol, $100 price:

Impact = 0.5 * 0.02 * sqrt(0.01) * 100 = $0.10

calculate

calculate(quantity, price, volume, is_buy)

Calculate square root impact.

Source code in src/ml4t/backtest/execution/impact.py
def calculate(
    self,
    quantity: float,
    price: float,
    volume: float | None,
    is_buy: bool,
) -> float:
    """Calculate square root impact."""
    if volume is None or volume == 0:
        return 0.0

    adv = volume * self.adv_factor
    participation = quantity / adv

    # Square root impact
    impact = self.coefficient * self.volatility * math.sqrt(participation) * price

    return impact if is_buy else -impact

PowerLawImpact dataclass

PowerLawImpact(
    coefficient=0.1, exponent=0.5, min_impact=0.0
)

Bases: MarketImpactModel

Generalized power law impact model.

Impact = coefficient * (quantity / volume)^exponent * price

Flexible model that can represent various impact regimes: - exponent = 1.0: Linear (like LinearImpact) - exponent = 0.5: Square root (like SquareRootImpact) - 0 < exponent < 1.0: Concave (marginal impact shrinks as order size grows; lower exponents flatten faster) - exponent > 1.0: Convex (impact accelerates for large orders)

Parameters:

Name Type Description Default
coefficient float

Scaling factor (default 0.1)

0.1
exponent float

Power law exponent (default 0.5)

0.5
min_impact float

Minimum impact per trade (fixed cost, default 0)

0.0
Example

model = PowerLawImpact(coefficient=0.1, exponent=0.6)

calculate

calculate(quantity, price, volume, is_buy)

Calculate power law impact.

Source code in src/ml4t/backtest/execution/impact.py
def calculate(
    self,
    quantity: float,
    price: float,
    volume: float | None,
    is_buy: bool,
) -> float:
    """Calculate power law impact."""
    if volume is None or volume == 0:
        return self.min_impact if is_buy else -self.min_impact

    participation = quantity / volume

    # Power law impact
    impact = self.coefficient * (participation**self.exponent) * price
    impact = max(impact, self.min_impact)

    return impact if is_buy else -impact

Risk: Position Rules

StopLoss dataclass

StopLoss(pct)

Exit when stop price is breached during the bar.

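Stop orders trigger when the bar's price range touches the stop level (or the current price is at or beyond it). Fill price depends on StopFillMode configuration: - STOP_PRICE: Fill at exact stop price (standard model, default); if the bar opens beyond the stop, fill at the open - CLOSE_PRICE: Fill at the bar's close price (VectorBT Pro close-only model) - BAR_EXTREME: Fill at the bar's low for longs or high for shorts (conservative, worst case) - NEXT_BAR_OPEN: Defer the exit and fill at the next bar's open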

For long positions: stop triggers if bar_low <= stop_price For short positions: stop triggers if bar_high >= stop_price

Parameters:

Name Type Description Default
pct float

Maximum loss as decimal (0.05 = 5% loss triggers exit)

required
Example

rule = StopLoss(pct=0.05) # Exit at -5%

evaluate

evaluate(state)

Exit if stop price was breached during the bar.

Source code in src/ml4t/backtest/risk/position/static.py
def evaluate(self, state: PositionState) -> PositionAction:
    """Exit if stop price was breached during the bar.

    The stop level is ``base_price * (1 - pct)`` for longs and
    ``base_price * (1 + pct)`` for shorts. It counts as breached when the
    bar's adverse extreme (low for longs, high for shorts) touches it, or when
    the current price is at or beyond it.

    Args:
        state: Position snapshot. This method reads ``is_long``, ``context``,
            ``current_price`` and the optional ``bar_open``, ``bar_high`` and
            ``bar_low`` fields.

    Returns:
        ``PositionAction.hold()`` when the stop was not breached. Otherwise it
        returns a full exit with reason ``stop_loss_<pct>``. In NEXT_BAR_OPEN
        mode that exit is deferred (``defer_fill=True``); in every other
        StopFillMode it carries an explicit ``fill_price``.
    """
    # Get base price for stop level calculation (entry_price or signal_price)
    base_price = _get_stop_base_price(state, state.context)

    # Calculate stop price from base
    if state.is_long:
        stop_price = base_price * (1 - self.pct)
        # Check if stop was triggered during bar (low touched stop level).
        # The current_price check also triggers when bar_low is missing.
        triggered = (
            state.bar_low is not None and state.bar_low <= stop_price
        ) or state.current_price <= stop_price
    else:  # short
        stop_price = base_price * (1 + self.pct)
        # Check if stop was triggered during bar (high touched stop level).
        # The current_price check also triggers when bar_high is missing.
        triggered = (
            state.bar_high is not None and state.bar_high >= stop_price
        ) or state.current_price >= stop_price

    if triggered:
        # Determine fill price based on mode
        fill_mode = _get_stop_fill_mode(state.context)
        if fill_mode == StopFillMode.NEXT_BAR_OPEN:
            # Zipline model: defer exit to next bar, fill at open
            return PositionAction.exit_full(
                reason=f"stop_loss_{self.pct:.1%}",
                defer_fill=True,  # Broker will fill at next bar's open
            )
        elif fill_mode == StopFillMode.CLOSE_PRICE:
            # VectorBT Pro close-only model: always fill at close price
            fill_price = state.current_price
        elif fill_mode == StopFillMode.BAR_EXTREME:
            # Conservative model: fill at bar's extreme (worst case).
            # Falls back to the stop level when the extreme is unavailable.
            if state.is_long:
                fill_price = state.bar_low if state.bar_low is not None else stop_price
            else:
                fill_price = state.bar_high if state.bar_high is not None else stop_price
        else:
            # Standard model (STOP_PRICE): fill at exact stop price if within bar range
            # If bar gaps through stop, fill at open (gap behavior)
            if state.is_long:
                # For long stops: check if bar opened below stop (gap down)
                # or if stop is within bar range
                if state.bar_open is not None and state.bar_open <= stop_price:
                    # Bar opened below stop - fill at open (Backtrader gap behavior)
                    fill_price = state.bar_open
                elif (
                    state.bar_low is not None
                    and state.bar_high is not None
                    and state.bar_low <= stop_price <= state.bar_high
                ):
                    # Stop within bar range - fill at exact stop
                    fill_price = stop_price
                else:
                    # Gap through (VBT behavior) - fill at close.
                    # Reached when bar fields are missing, or when the stop lies
                    # outside [bar_low, bar_high] without a usable bar_open.
                    fill_price = state.current_price
            else:
                # For short stops: check if bar opened above stop (gap up)
                if state.bar_open is not None and state.bar_open >= stop_price:
                    # Bar opened above stop - fill at open (gap behavior)
                    fill_price = state.bar_open
                elif (
                    state.bar_low is not None
                    and state.bar_high is not None
                    and state.bar_low <= stop_price <= state.bar_high
                ):
                    # Stop within bar range - fill at exact stop
                    fill_price = stop_price
                else:
                    # Gap through (VBT behavior) - fill at close.
                    # Same fallback conditions as the long branch above.
                    fill_price = state.current_price

        return PositionAction.exit_full(
            reason=f"stop_loss_{self.pct:.1%}",
            fill_price=fill_price,
        )
    return PositionAction.hold()

TakeProfit dataclass

TakeProfit(pct)

Exit when target price is reached during the bar.

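Take-profit orders trigger when the bar's price range touches the target (or the current price is at or beyond it). Fill price depends on StopFillMode configuration: - STOP_PRICE: Fill at exact target price (standard model, default); if the bar opens beyond the target, fill at the open - CLOSE_PRICE: Fill at the bar's close price (VectorBT Pro close-only model) - BAR_EXTREME: Fill at the bar's high for longs or low for shorts (optimistic, best case) - NEXT_BAR_OPEN: Defer the exit and fill at the next bar's open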

For long positions: triggers if bar_high >= target_price For short positions: triggers if bar_low <= target_price

Parameters:

Name Type Description Default
pct float

Target profit as decimal (0.10 = 10% profit triggers exit)

required
Example

rule = TakeProfit(pct=0.10) # Exit at +10%

evaluate

evaluate(state)

Exit if target price was reached during the bar.

Source code in src/ml4t/backtest/risk/position/static.py
def evaluate(self, state: PositionState) -> PositionAction:
    """Exit if target price was reached during the bar.

    The target is ``base_price * (1 + pct)`` for longs and
    ``base_price * (1 - pct)`` for shorts. It counts as reached when the bar's
    favourable extreme (high for longs, low for shorts) touches it, or when
    the current price is at or beyond it.

    Args:
        state: Position snapshot. This method reads ``is_long``, ``context``,
            ``current_price`` and the optional ``bar_open``, ``bar_high`` and
            ``bar_low`` fields.

    Returns:
        ``PositionAction.hold()`` when the target was not reached. Otherwise it
        returns a full exit with reason ``take_profit_<pct>``. In NEXT_BAR_OPEN
        mode that exit is deferred (``defer_fill=True``); in every other
        StopFillMode it carries an explicit ``fill_price``.
    """
    # Get base price for target level calculation (entry_price or signal_price)
    base_price = _get_stop_base_price(state, state.context)

    # Calculate target price from base
    if state.is_long:
        target_price = base_price * (1 + self.pct)
        # Check if target was reached during bar (high touched target).
        # The current_price check also triggers when bar_high is missing.
        triggered = (
            state.bar_high is not None and state.bar_high >= target_price
        ) or state.current_price >= target_price
    else:  # short
        target_price = base_price * (1 - self.pct)
        # Check if target was reached during bar (low touched target).
        # The current_price check also triggers when bar_low is missing.
        triggered = (
            state.bar_low is not None and state.bar_low <= target_price
        ) or state.current_price <= target_price

    if triggered:
        # Determine fill price based on mode
        fill_mode = _get_stop_fill_mode(state.context)
        if fill_mode == StopFillMode.NEXT_BAR_OPEN:
            # Zipline model: defer exit to next bar, fill at open
            return PositionAction.exit_full(
                reason=f"take_profit_{self.pct:.1%}",
                defer_fill=True,  # Broker will fill at next bar's open
            )
        elif fill_mode == StopFillMode.CLOSE_PRICE:
            # VectorBT Pro close-only model: always fill at close price
            fill_price = state.current_price
        elif fill_mode == StopFillMode.BAR_EXTREME:
            # Optimistic model: fill at bar's extreme (best case for profits).
            # Falls back to the target level when the extreme is unavailable.
            if state.is_long:
                fill_price = state.bar_high if state.bar_high is not None else target_price
            else:
                fill_price = state.bar_low if state.bar_low is not None else target_price
        else:
            # Standard model (STOP_PRICE): fill at exact target price if within bar range
            # If bar gaps through target, fill at open/close (gap behavior)
            if state.is_long:
                # For long targets: check if bar opened above target (price improvement)
                # or if target is within bar range
                if state.bar_open is not None and state.bar_open >= target_price:
                    # Bar opened above target - fill at open (Backtrader behavior)
                    fill_price = state.bar_open
                elif (
                    state.bar_low is not None
                    and state.bar_high is not None
                    and state.bar_low <= target_price <= state.bar_high
                ):
                    # Target within bar range - fill at exact target
                    fill_price = target_price
                else:
                    # Gap through - fill at close.
                    # Reached when bar fields are missing, or when the target lies
                    # outside [bar_low, bar_high] without a usable bar_open.
                    fill_price = state.current_price
            else:
                # For short targets: check if bar opened below target (price improvement)
                if state.bar_open is not None and state.bar_open <= target_price:
                    # Bar opened below target - fill at open (price improvement)
                    fill_price = state.bar_open
                elif (
                    state.bar_low is not None
                    and state.bar_high is not None
                    and state.bar_low <= target_price <= state.bar_high
                ):
                    # Target within bar range - fill at exact target
                    fill_price = target_price
                else:
                    # Gap through - fill at close.
                    # Same fallback conditions as the long branch above.
                    fill_price = state.current_price

        return PositionAction.exit_full(
            reason=f"take_profit_{self.pct:.1%}",
            fill_price=fill_price,
        )
    return PositionAction.hold()

TimeExit dataclass

TimeExit(max_bars)

Exit after holding for a specified number of bars.

Parameters:

Name Type Description Default
max_bars int

Maximum bars to hold position

required
Example

rule = TimeExit(max_bars=20) # Exit after 20 bars

evaluate

evaluate(state)

Exit if held too long.

Source code in src/ml4t/backtest/risk/position/static.py
def evaluate(self, state: PositionState) -> PositionAction:
    """Close the whole position once it has been held ``max_bars`` bars or more.

    Returns:
        A full exit tagged ``time_exit_<max_bars>bars`` when
        ``state.bars_held >= self.max_bars``, otherwise a hold action.
    """
    limit = self.max_bars
    expired = state.bars_held >= limit
    if not expired:
        return PositionAction.hold()
    # Unlike the price-based rules, no explicit fill_price is supplied here.
    return PositionAction.exit_full(f"time_exit_{limit}bars")

TrailingStop dataclass

TrailingStop(pct)

Exit when price retraces from high water mark.

For longs: Exit if price drops X% from highest price since entry For shorts: Exit if price rises X% from lowest price since entry

Fill price depends on StopFillMode configuration: - STOP_PRICE: Fill at exact trail level (default) - CLOSE_PRICE: Fill at bar's close price (VBT Pro behavior) - BAR_EXTREME: Fill at bar's low (long) or high (short)

Parameters:

Name Type Description Default
pct float

Trail percentage as decimal (0.05 = 5% trail)

required
Example

rule = TrailingStop(pct=0.05) # 5% trailing stop

evaluate

evaluate(state)

Exit if price retraces beyond trail.

Uses bar_low/bar_high for intrabar trigger detection. Handles gap-through: if bar opens beyond stop level, fill at open.

Fill price depends on StopFillMode configuration: - STOP_PRICE: Fill at exact trail level (default) - CLOSE_PRICE: Fill at bar's close price - BAR_EXTREME: Fill at bar's low (long) or high (short)

Water mark timing depends on TrailStopTiming configuration: - LAGGED: Use water mark from PREVIOUS bar (default, 1-bar lag) - INTRABAR: Compute "live" water mark using current bar's extreme, then check. VBT Pro compatible: respects StopFillMode for fill price.

Gap-through handling: When bar opens beyond the stop level (gap down for longs, gap up for shorts), the fill is at the open price regardless of StopFillMode. This matches VBT Pro behavior.

Source code in src/ml4t/backtest/risk/position/dynamic.py
def evaluate(self, state: PositionState) -> PositionAction:
    """Exit when price retraces past the trailing level.

    This method only resolves configuration and dispatches. The actual check
    lives in ``_evaluate_long`` (for long positions) or ``_evaluate_short``
    (otherwise). Both receive the state, the trail-specific StopFillMode and
    the TrailStopTiming read from ``state.context``.

    Fill price depends on StopFillMode:
    - STOP_PRICE: fill at the exact trail level (default)
    - CLOSE_PRICE: fill at the bar's close price
    - BAR_EXTREME: fill at the bar's low (long) or high (short)

    Water-mark timing depends on TrailStopTiming:
    - LAGGED: use the water mark from the PREVIOUS bar (default, 1-bar lag)
    - INTRABAR: compute a "live" water mark from the current bar's extreme,
      then check. VBT Pro compatible: respects StopFillMode for fill price.

    Gap-through: if the bar opens beyond the stop level (gap down for longs,
    gap up for shorts), the fill is at the open regardless of StopFillMode,
    matching VBT Pro behavior.
    """
    ctx = state.context
    fill_mode = _get_stop_fill_mode_for_trail(ctx)
    trail_timing = _get_trail_stop_timing(ctx)

    side_handler = self._evaluate_long if state.is_long else self._evaluate_short
    return side_handler(state, fill_mode, trail_timing)

RuleChain dataclass

RuleChain(rules)

Evaluate rules in order, first non-HOLD action wins.

This is the most common composition pattern - rules are checked in priority order and the first rule to trigger takes effect.

Parameters:

Name Type Description Default
rules list[PositionRule]

List of rules to evaluate in order

required
Example

chain = RuleChain([ StopLoss(pct=0.05), # Highest priority ScaledExit([(0.10, 0.5)]), # Second priority TighteningTrailingStop([...]), # Third priority TimeExit(max_bars=20), # Lowest priority ])

evaluate

evaluate(state)

Evaluate rules in order, return first non-HOLD action.

Source code in src/ml4t/backtest/risk/position/composite.py
def evaluate(self, state: PositionState) -> PositionAction:
    """Walk the rules in priority order and return the first one that acts.

    Rules are evaluated lazily. Once a rule returns a non-HOLD action, the
    remaining rules are not evaluated.

    Returns:
        The first action whose ``action`` is not ``ActionType.HOLD``, or
        ``PositionAction.hold()`` when every rule holds (or there are none).
    """
    verdicts = (rule.evaluate(state) for rule in self.rules)
    winner = next((v for v in verdicts if v.action != ActionType.HOLD), None)
    if winner is None:
        return PositionAction.hold()
    return winner

AllOf dataclass

AllOf(rules)

All rules must return non-HOLD for the action to trigger.

Useful for requiring multiple conditions to be true before exiting. Returns the first rule's action details (pct, stop_price, etc.).

Parameters:

Name Type Description Default
rules list[PositionRule]

List of rules that must all agree

required
Example

Exit only if both profitable AND held long enough

rule = AllOf([ TakeProfit(pct=0.0), # Must be profitable TimeExit(max_bars=5), # Must have held 5+ bars ])

evaluate

evaluate(state)

Return action only if ALL rules return non-HOLD.

Source code in src/ml4t/backtest/risk/position/composite.py
def evaluate(self, state: PositionState) -> PositionAction:
    """Return action only if ALL rules return non-HOLD.

    Every rule is evaluated, in order, with no short-circuit. When all of them
    trigger, the first rule's action is returned. Its ``reason`` is replaced by
    the non-empty reasons of all rules joined with " AND " (or "" if none
    gave a reason).

    An empty rule list returns HOLD, because there is no action to report.

    Args:
        state: Current position state, passed unchanged to each rule.

    Returns:
        The combined action, or ``PositionAction.hold()``.
    """
    import dataclasses

    actions = [rule.evaluate(state) for rule in self.rules]

    # With no rules, all() is vacuously True and actions[0] would raise IndexError.
    if not actions:
        return PositionAction.hold()

    if all(a.action != ActionType.HOLD for a in actions):
        # All triggered - return first rule's action with combined reason
        reasons = [a.reason for a in actions if a.reason]
        combined_reason = " AND ".join(reasons) if reasons else ""
        first = actions[0]
        # Copy the first action wholesale when possible. Extra fields such as
        # fill_price / defer_fill (passed to exit_full by StopLoss/TakeProfit)
        # would otherwise be lost in the rebuild below, unlike RuleChain/AnyOf,
        # which return actions untouched.
        if dataclasses.is_dataclass(first) and not isinstance(first, type):
            return dataclasses.replace(first, reason=combined_reason)
        return PositionAction(
            action=first.action,
            pct=first.pct,
            stop_price=first.stop_price,
            reason=combined_reason,
        )

    return PositionAction.hold()

AnyOf dataclass

AnyOf(rules)

First rule to return non-HOLD wins (alias for RuleChain).

Semantically equivalent to RuleChain but named for clarity when composing complex rule logic.

Parameters:

Name Type Description Default
rules list[PositionRule]

List of rules where any can trigger

required
Example

Exit on stop-loss OR signal

rule = AnyOf([ StopLoss(pct=0.05), SignalExit(threshold=0.5), ])

evaluate

evaluate(state)

Return first non-HOLD action (same as RuleChain).

Source code in src/ml4t/backtest/risk/position/composite.py
def evaluate(self, state: PositionState) -> PositionAction:
    """Return the first non-HOLD action, checking rules in list order.

    Behaves exactly like RuleChain: evaluation stops at the first rule that
    acts. A hold action is returned when no rule fires.
    """
    for candidate in self.rules:
        outcome = candidate.evaluate(state)
        if outcome.action == ActionType.HOLD:
            continue
        return outcome
    return PositionAction.hold()

Risk: Portfolio Limits

MaxDrawdownLimit dataclass

MaxDrawdownLimit(
    max_drawdown=0.2, action="halt", warn_threshold=None
)

Bases: PortfolioLimit

Halt trading when drawdown exceeds threshold.

Parameters:

Name Type Description Default
max_drawdown float

Maximum allowed drawdown (0.0-1.0) Default 0.20 = 20% max drawdown

0.2
action str

Action when breached ("warn", "reduce", "halt") Default "halt" - stops all new trades

'halt'
warn_threshold float | None

Optional earlier threshold for warnings

None
Example

limit = MaxDrawdownLimit(max_drawdown=0.20, warn_threshold=0.15)

Warns at 15% drawdown, halts at 20%

MaxPositionsLimit dataclass

MaxPositionsLimit(max_positions=10, action='halt')

Bases: PortfolioLimit

Limit maximum number of open positions.

Parameters:

Name Type Description Default
max_positions int

Maximum number of simultaneous positions

10
action str

Action when breached ("warn", "halt")

'halt'
Example

limit = MaxPositionsLimit(max_positions=10)

Prevents opening more than 10 positions

MaxExposureLimit dataclass

MaxExposureLimit(max_exposure_pct=0.1, action='warn')

Bases: PortfolioLimit

Limit maximum exposure to a single asset.

Parameters:

Name Type Description Default
max_exposure_pct float

Maximum position size as % of equity (0.0-1.0) Default 0.10 = 10% max per asset

0.1
action str

Action when breached

'warn'
Example

limit = MaxExposureLimit(max_exposure_pct=0.10)

No single position can be > 10% of portfolio

DailyLossLimit dataclass

DailyLossLimit(max_daily_loss_pct=0.02, action='halt')

Bases: PortfolioLimit

Halt trading when daily loss exceeds threshold.

Parameters:

Name Type Description Default
max_daily_loss_pct float

Maximum daily loss as % of equity (0.0-1.0) Default 0.02 = 2% max daily loss

0.02
action str

Action when breached

'halt'
Example

limit = DailyLossLimit(max_daily_loss_pct=0.02)

Halt if down more than 2% today

Strategy Templates

SignalFollowingStrategy

Bases: Strategy

Template for strategies that follow pre-computed signals.

Use this when you have ML predictions, technical indicators, or any pre-computed signal column in your DataFrame.

Class Attributes

signal_column: Name of the signal column in data (default: "signal") position_size: Fraction of equity per position (default: 0.10) allow_shorts: Whether to allow short positions (default: False)

Example

class MyMLStrategy(SignalFollowingStrategy):
    signal_column = "rf_prediction"
    position_size = 0.05

    def should_enter_long(self, signal):
        return signal > 0.7

    def should_exit(self, signal):
        return signal < 0.3

should_enter_long abstractmethod

should_enter_long(signal)

Return True to open a long position.

Parameters:

Name Type Description Default
signal float

Current signal value for the asset

required

Returns:

Type Description
bool

True if should enter long position

Source code in src/ml4t/backtest/strategies/templates.py
@abstractmethod
def should_enter_long(self, signal: float) -> bool:
    """Decide whether to buy into an asset on the current bar.

    Subclasses must implement this hook. The template's ``on_data`` asks
    it only for assets that currently have no open position, and passes
    ``0`` when the bar carries no (or a ``None``) signal value.

    Args:
        signal: This bar's value of ``signal_column`` for the asset.

    Returns:
        ``True`` to open a long position, ``False`` to stay flat.
    """

should_exit abstractmethod

should_exit(signal)

Return True to close current position.

Parameters:

Name Type Description Default
signal float

Current signal value for the asset

required

Returns:

Type Description
bool

True if should exit position

Source code in src/ml4t/backtest/strategies/templates.py
@abstractmethod
def should_exit(self, signal: float) -> bool:
    """Decide whether to liquidate the asset's open position.

    Subclasses must implement this hook. The template's ``on_data`` asks
    it only for assets that already hold a position, and passes ``0``
    when the bar carries no (or a ``None``) signal value.

    Args:
        signal: This bar's value of ``signal_column`` for the asset.

    Returns:
        ``True`` to close the position, ``False`` to keep holding it.
    """

should_enter_short

should_enter_short(signal)

Return True to open a short position.

Override this method for short strategies. Default returns False.

Parameters:

Name Type Description Default
signal float

Current signal value for the asset

required

Returns:

Type Description
bool

True if should enter short position

Source code in src/ml4t/backtest/strategies/templates.py
def should_enter_short(self, signal: float) -> bool:
    """Decide whether to sell short into an asset on the current bar.

    Optional hook: the base implementation never shorts. Override it in
    strategies that trade the short side. The template's ``on_data`` only
    consults it when ``allow_shorts`` is enabled and ``should_enter_long``
    declined.

    Args:
        signal: This bar's value of ``signal_column`` for the asset.

    Returns:
        ``True`` to open a short position; always ``False`` here.
    """
    # Long-only unless a subclass says otherwise.
    return False

on_data

on_data(timestamp, data, context, broker)

Process each bar and generate orders based on signals.

Source code in src/ml4t/backtest/strategies/templates.py
def on_data(
    self,
    timestamp: datetime,
    data: dict[str, dict],
    context: dict[str, Any],
    broker: Broker,
) -> None:
    """Process each bar and generate orders based on signals.

    For every asset in ``data``:

    * with an open position, ``should_exit`` decides whether to close it;
    * without one, ``should_enter_long`` (and, when ``allow_shorts`` is set,
      ``should_enter_short``) decides whether to open a position sized at
      ``position_size`` of the broker's account value.

    A missing or ``None`` signal is treated as ``0``. An entry is skipped
    when the bar has no positive close price. This includes a ``close``
    key whose value is ``None``.
    """
    for asset, bar in data.items():
        # Signals are nested under 'signals' dict in DataFeed output
        signals = bar.get("signals", {})
        signal = signals.get(self.signal_column, 0) if signals else 0
        if signal is None:
            signal = 0

        position = broker.get_position(asset)

        if position is not None:
            # Have position - check for exit
            if self.should_exit(signal):
                broker.close_position(asset)
            continue

        # No position - check for entry (long takes precedence over short)
        fractional = _use_fractional(self.allow_fractional, broker)
        if self.should_enter_long(signal):
            side = 1
        elif self.allow_shorts and self.should_enter_short(signal):
            side = -1
        else:
            continue

        # ``bar.get("close", 0)`` would return None for a present-but-None
        # close, and ``None > 0`` raises TypeError; treat it as "no price".
        price = bar.get("close")
        if price is None:
            price = 0
        equity = broker.get_account_value()
        raw_shares = (equity * self.position_size) / price if price > 0 else 0
        shares = raw_shares if fractional else int(raw_shares)
        if shares > 0:
            broker.submit_order(asset, side * shares)

MomentumStrategy

MomentumStrategy()

Bases: Strategy

Template for momentum/trend-following strategies.

Enters long when the asset's return over the lookback window exceeds entry_threshold, and exits when that return falls below exit_threshold.

Class Attributes

lookback: Number of bars for momentum calculation (default: 20)
entry_threshold: Minimum return to enter (default: 0.05 = 5%)
exit_threshold: Return level to exit (default: -0.02 = -2%)
position_size: Fraction of equity per position (default: 0.10)

Example

class MyMomentum(MomentumStrategy):
    lookback = 60  # 60-day momentum
    entry_threshold = 0.10  # Enter on 10% gain
    exit_threshold = 0.0  # Exit when momentum turns negative

Source code in src/ml4t/backtest/strategies/templates.py
def __init__(self) -> None:
    """Start every asset with an empty close-price history."""
    # defaultdict hands out a fresh list the first time an asset is seen.
    history: dict[str, list[float]] = defaultdict(list)
    self.price_history = history

calculate_momentum

calculate_momentum(prices)

Calculate momentum as return over lookback period.

Parameters:

Name Type Description Default
prices list[float]

List of prices (most recent last)

required

Returns:

Type Description
float

Return from first to last price

Source code in src/ml4t/backtest/strategies/templates.py
def calculate_momentum(self, prices: list[float]) -> float:
    """Return the simple return from the oldest to the newest price.

    Args:
        prices: Prices in chronological order (newest last).

    Returns:
        ``prices[-1] / prices[0] - 1``. Returns ``0.0`` when there are
        fewer than two prices or the first price is zero, which would
        otherwise divide by zero.
    """
    if len(prices) >= 2 and prices[0] != 0:
        oldest, newest = prices[0], prices[-1]
        return newest / oldest - 1
    return 0.0

on_data

on_data(timestamp, data, context, broker)

Process each bar and trade based on momentum.

Source code in src/ml4t/backtest/strategies/templates.py
def on_data(
    self,
    timestamp: datetime,
    data: dict[str, dict],
    context: dict[str, Any],
    broker: Broker,
) -> None:
    """Update each asset's rolling closes and trade on its momentum.

    Bars whose close is missing or non-positive are ignored. Nothing is
    traded for an asset until ``lookback`` closes have been collected.
    After that, a flat asset is bought when momentum exceeds
    ``entry_threshold``, and a held asset is closed when momentum drops
    below ``exit_threshold``.
    """
    window = self.lookback
    for asset, bar in data.items():
        close = bar.get("close")
        if close is None or close <= 0:
            continue

        history = self.price_history[asset]
        history.append(close)
        if len(history) < window:
            # Still warming up for this asset.
            continue

        # Keep just the most recent `window` closes.
        history = history[-window:]
        self.price_history[asset] = history

        momentum = self.calculate_momentum(history)
        holding = broker.get_position(asset) is not None

        if not holding:
            if momentum > self.entry_threshold:
                # Strong trend: size a new long position.
                equity = broker.get_account_value()
                raw_shares = (equity * self.position_size) / close
                fractional = _use_fractional(self.allow_fractional, broker)
                shares = raw_shares if fractional else int(raw_shares)
                if shares > 0:
                    broker.submit_order(asset, shares)
        elif momentum < self.exit_threshold:
            # Trend has faded: exit.
            broker.close_position(asset)

MeanReversionStrategy

MeanReversionStrategy()

Bases: Strategy

Template for mean-reversion strategies.

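Buys when the current price's z-score relative to the trailing lookback window falls below entry_zscore, and exits when the z-score rises above exit_zscore (i.e. the price has reverted toward or past the mean).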

Class Attributes

lookback: Number of bars for mean calculation (default: 20)
entry_zscore: Z-score threshold to enter (default: -2.0)
exit_zscore: Z-score threshold to exit (default: 0.0)
position_size: Fraction of equity per position (default: 0.10)

Example

class MyMeanReversion(MeanReversionStrategy):
    lookback = 30
    entry_zscore = -2.5  # More extreme entry
    exit_zscore = 0.5  # Take profit above mean

Source code in src/ml4t/backtest/strategies/templates.py
def __init__(self) -> None:
    """Start every asset with an empty close-price history."""
    # defaultdict hands out a fresh list the first time an asset is seen.
    history: dict[str, list[float]] = defaultdict(list)
    self.price_history = history

calculate_zscore

calculate_zscore(prices, current)

Calculate z-score of current price vs historical distribution.

Parameters:

Name Type Description Default
prices list[float]

Historical prices

required
current float

Current price

required

Returns:

Type Description
float | None

Z-score or None if insufficient data

Source code in src/ml4t/backtest/strategies/templates.py
def calculate_zscore(self, prices: list[float], current: float) -> float | None:
    """Calculate z-score of current price vs historical distribution.

    Args:
        prices: Historical prices
        current: Current price

    Returns:
        Z-score or None if insufficient data
    """
    if len(prices) < 2:
        return None

    try:
        avg = mean(prices)
        std = stdev(prices)
        if std == 0:
            return None
        return (current - avg) / std
    except Exception:
        return None

on_data

on_data(timestamp, data, context, broker)

Process each bar and trade based on mean reversion.

Source code in src/ml4t/backtest/strategies/templates.py
def on_data(
    self,
    timestamp: datetime,
    data: dict[str, dict],
    context: dict[str, Any],
    broker: Broker,
) -> None:
    """Update each asset's rolling closes and trade on price dislocations.

    Bars whose close is missing or non-positive are ignored. Once
    ``lookback`` closes exist for an asset, the current close is scored
    against the preceding closes in the window. A flat asset is bought when
    the z-score is below ``entry_zscore``. A held asset is closed when the
    z-score is above ``exit_zscore``. Bars with no computable z-score are
    skipped.
    """
    for asset, bar in data.items():
        close = bar.get("close")
        if close is None or close <= 0:
            continue

        self.price_history[asset].append(close)
        if len(self.price_history[asset]) < self.lookback:
            # Still warming up for this asset.
            continue

        window = self.price_history[asset][-self.lookback :]
        self.price_history[asset] = window

        # The last element is the close just appended; leave it out of the
        # reference sample and score it against the earlier closes.
        zscore = self.calculate_zscore(window[:-1], close)
        if zscore is None:
            continue

        holding = broker.get_position(asset) is not None

        if not holding:
            if zscore < self.entry_zscore:
                # Oversold: size a new long position.
                equity = broker.get_account_value()
                raw_shares = (equity * self.position_size) / close
                fractional = _use_fractional(self.allow_fractional, broker)
                shares = raw_shares if fractional else int(raw_shares)
                if shares > 0:
                    broker.submit_order(asset, shares)
        elif zscore > self.exit_zscore:
            # Price has reverted: exit.
            broker.close_position(asset)

LongShortStrategy

LongShortStrategy()

Bases: Strategy

Template for long/short equity strategies.

Ranks assets by a signal and goes long top N, short bottom N.

Class Attributes

signal_column: Column to rank assets by (default: "signal")
long_count: Number of assets to go long (default: 5)
short_count: Number of assets to go short (default: 5)
position_size: Fraction of equity per position (default: 0.05)
rebalance_frequency: Bars between rebalancing (default: 20)

Example

class MyLongShort(LongShortStrategy):
    signal_column = "momentum_score"
    long_count = 10
    short_count = 10
    rebalance_frequency = 21  # Monthly

Source code in src/ml4t/backtest/strategies/templates.py
def __init__(self) -> None:
    """Reset the bar counter and clear any previously resolved schedule."""
    # Filled in by on_prepare when a rebalance_schedule is configured.
    self._resolved_schedule: frozenset[datetime] | None = None
    # Incremented once per on_data call; drives frequency-based rebalancing.
    self.bar_count = 0

on_prepare

on_prepare(broker, timestamps, config=None)

Resolve optional schedule-based rebalance gating before the run starts.

Source code in src/ml4t/backtest/strategies/templates.py
def on_prepare(
    self,
    broker: Any,
    timestamps: Sequence[datetime],
    config: BacktestConfig | None = None,
) -> None:
    """Precompute the set of rebalance timestamps, if a schedule is used.

    When ``rebalance_schedule`` is None, rebalancing is frequency-based and
    the resolved schedule is cleared. Otherwise the schedule is resolved
    against ``timestamps``. The calendar, timezone and feed spec come from
    ``config`` when given; without a config they default to no calendar,
    ``"UTC"`` and no feed spec.
    """
    if self.rebalance_schedule is None:
        self._resolved_schedule = None
        return

    if config is None:
        calendar, timezone, feed_spec = None, "UTC", None
    else:
        calendar = config.resolved_calendar
        timezone = config.resolved_timezone
        feed_spec = config.resolved_feed_spec

    resolved = resolve_rebalance_timestamps(
        timestamps,
        self.rebalance_schedule,
        feed_spec=feed_spec,
        calendar=calendar,
        timezone=timezone,
    )
    # Frozen set for fast membership tests in on_data.
    self._resolved_schedule = frozenset(resolved.to_list())

rank_assets

rank_assets(data)

Rank assets by signal and return long/short lists.

Parameters:

Name Type Description Default
data dict[str, dict]

Current bar data for all assets

required

Returns:

Type Description
tuple[list[str], list[str]]

Tuple of (long_assets, short_assets)

Source code in src/ml4t/backtest/strategies/templates.py
def rank_assets(self, data: dict[str, dict]) -> tuple[list[str], list[str]]:
    """Rank assets by signal and return long/short lists.

    Assets are sorted by ``signal_column`` from highest to lowest. The top
    ``long_count`` become longs and the bottom ``short_count`` become
    shorts. Any asset already selected as a long is removed from the short
    list. Assets with a missing, ``None`` or NaN signal are not ranked.

    Args:
        data: Current bar data for all assets

    Returns:
        Tuple of (long_assets, short_assets), each in descending signal
        order.
    """
    # Collect signals (signals are nested under 'signals' dict)
    signals: list[tuple[str, float]] = []
    for asset, bar in data.items():
        bar_signals = bar.get("signals", {})
        signal = bar_signals.get(self.signal_column) if bar_signals else None
        # NaN != NaN; NaN keys make the sort order undefined, so skip them.
        if signal is None or signal != signal:
            continue
        signals.append((asset, signal))

    if not signals:
        return [], []

    # Sort by signal (high to low)
    signals.sort(key=lambda item: item[1], reverse=True)

    long_assets = [asset for asset, _ in signals[: self.long_count]]

    # Slice from an explicit start index: ``signals[-0:]`` is the whole list,
    # so ``short_count == 0`` must not be written as a negative slice.
    short_start = max(len(signals) - self.short_count, 0)
    long_set = set(long_assets)
    # Don't short the same assets we're going long
    short_assets = [
        asset for asset, _ in signals[short_start:] if asset not in long_set
    ]

    return long_assets, short_assets

on_data

on_data(timestamp, data, context, broker)

Rebalance portfolio periodically based on rankings.

Source code in src/ml4t/backtest/strategies/templates.py
def on_data(
    self,
    timestamp: datetime,
    data: dict[str, dict],
    context: dict[str, Any],
    broker: Broker,
) -> None:
    """Rebalance portfolio periodically based on rankings.

    On a rebalance bar, positions in assets outside the current long/short
    selection are closed. New positions of ``position_size`` of account
    value are then opened in selected assets that have no position. Existing
    positions in selected assets are left unchanged.

    Rebalance bars are the timestamps resolved by ``on_prepare`` when
    ``rebalance_schedule`` is set. Otherwise they are every
    ``rebalance_frequency``-th bar, starting with the first bar.

    Raises:
        ValueError: If ``rebalance_schedule`` is set but was not resolved by
            ``on_prepare``.
    """
    self.bar_count += 1

    if self.rebalance_schedule is not None:
        if self._resolved_schedule is None:
            raise ValueError("rebalance_schedule is set but was not prepared before execution")
        if timestamp not in self._resolved_schedule:
            return
    elif (self.bar_count - 1) % self.rebalance_frequency != 0:
        # Rebalance on bars 1, 1 + f, 1 + 2f, ...; the former
        # ``bar_count % f != 1`` test never fired when f == 1.
        return

    # Get current rankings
    long_assets, short_assets = self.rank_assets(data)
    target_assets = set(long_assets) | set(short_assets)

    # Close positions not in target (iterate over a snapshot of the keys in
    # case closing a position mutates the broker's mapping)
    for asset in list(broker.get_positions().keys()):
        if asset not in target_assets:
            broker.close_position(asset)

    # Open new positions: longs first, then shorts
    equity = broker.get_account_value()
    fractional = _use_fractional(self.allow_fractional, broker)

    # NOTE(review): an asset already held on one side that is now ranked on
    # the other side stays in target_assets and keeps its existing position.
    sides = [(asset, 1) for asset in long_assets] + [(asset, -1) for asset in short_assets]
    for asset, side in sides:
        price = data.get(asset, {}).get("close")
        # Skip missing (None), NaN and non-positive prices: None would raise
        # TypeError on comparison and NaN would make int() raise ValueError.
        if price is None or not price > 0:
            continue
        if broker.get_position(asset) is not None:
            continue

        raw_shares = (equity * self.position_size) / price
        target_shares = raw_shares if fractional else int(raw_shares)
        if target_shares > 0:
            broker.submit_order(asset, side * target_shares)