# quant/data_manager.py
"""
Persistent local data store for price history.
Layout:
data/us.csv — S&P 500 + SPY adjusted close prices
data/us_open.csv — S&P 500 + SPY adjusted open prices
data/cn.csv — CSI 300 + 000300.SS adjusted close prices
data/cn_open.csv — CSI 300 + 000300.SS adjusted open prices
On first run: downloads full 10-year history.
On subsequent runs: reads existing file, downloads only new dates, appends.
New tickers (index rebalance) are backfilled automatically.
"""
import os
from datetime import datetime, timedelta
import pandas as pd
import yfinance as yf
DATA_DIR = "data"
def _data_path(market: str, price_type: str = "close") -> str:
    """Return the CSV cache path for *market*; close files carry no suffix."""
    if price_type == "close":
        filename = f"{market}.csv"
    else:
        filename = f"{market}_{price_type}.csv"
    return os.path.join(DATA_DIR, filename)
def load(market: str, price_type: str = "close") -> pd.DataFrame | None:
    """Read the cached price panel from disk, or None if no cache exists yet."""
    path = _data_path(market, price_type)
    if os.path.exists(path):
        return pd.read_csv(path, index_col=0, parse_dates=True)
    return None
def _download(tickers: list[str], start: str, end: str | None = None,
fields: list[str] | None = None) -> dict[str, pd.DataFrame]:
"""
Download price data from Yahoo Finance.
Parameters
----------
fields : list of str, optional
Price fields to extract, e.g. ["Close", "Open"].
Defaults to ["Close"].
Returns
-------
dict mapping field name to DataFrame (columns = tickers).
"""
if fields is None:
fields = ["Close"]
print(f"--- Downloading {len(tickers)} tickers from {start}{f' to {end}' if end else ''} ---")
kwargs = dict(auto_adjust=True)
if end:
kwargs["end"] = end
raw = yf.download(tickers, start=start, **kwargs)
if raw.empty:
return {f: pd.DataFrame() for f in fields}
result = {}
for field in fields:
if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns:
selected = raw[field]
if isinstance(selected, pd.Series):
result[field] = selected.to_frame(name=tickers[0])
else:
result[field] = selected
else:
result[field] = pd.DataFrame()
return result
def _download_period(tickers: list[str], period: str,
fields: list[str] | None = None) -> dict[str, pd.DataFrame]:
"""Download using period-based API (for small gaps)."""
if fields is None:
fields = ["Close"]
raw = yf.download(tickers, period=period, auto_adjust=True)
if raw.empty:
return {f: pd.DataFrame() for f in fields}
result = {}
for field in fields:
if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns:
selected = raw[field]
if isinstance(selected, pd.Series):
result[field] = selected.to_frame(name=tickers[0])
else:
result[field] = selected
else:
result[field] = pd.DataFrame()
return result
def _clean(data: pd.DataFrame) -> pd.DataFrame:
"""Drop tickers with >50% missing data, forward-fill, drop all-NaN rows."""
good = data.columns[data.notna().mean() > 0.5]
dropped = set(data.columns) - set(good)
if dropped:
print(f"--- Dropped {len(dropped)} tickers with >50% missing data ---")
data = data[good]
data = data.ffill().dropna(how="all")
return data
def _clean_market_data(data: pd.DataFrame, field: str) -> pd.DataFrame:
"""Clean market data while preserving volume gaps."""
good = data.columns[data.notna().mean() > 0.5]
dropped = set(data.columns) - set(good)
if dropped:
print(f"--- Dropped {len(dropped)} tickers with >50% missing data ---")
data = data[good]
if field == "volume":
return data
return data.ffill().dropna(how="all")
def _merge_market_panel(existing: pd.DataFrame | None, new_data: pd.DataFrame) -> pd.DataFrame:
"""Merge new data into an existing cached panel, preserving old columns and dates."""
if existing is None or existing.empty:
merged = new_data.copy()
elif new_data.empty:
merged = existing.copy()
else:
merged = existing.combine_first(new_data)
merged.loc[new_data.index, new_data.columns] = new_data
merged = merged.sort_index()
merged = merged[~merged.index.duplicated(keep="last")]
return merged
def update_market_data(market: str, tickers: list[str], fields: list[str]) -> dict[str, pd.DataFrame]:
    """Download, clean, persist, and return market data panels for requested Yahoo fields."""
    aliases = {
        "close": "Close",
        "open": "Open",
        "high": "High",
        "low": "Low",
        "volume": "Volume",
    }
    # Validate up front and pair each normalized name with its Yahoo column name.
    pairs: list[tuple[str, str]] = []
    for requested in fields:
        key = requested.lower()
        if key not in aliases:
            raise ValueError(f"Unsupported market data field: {requested}")
        pairs.append((key, aliases[key]))
    os.makedirs(DATA_DIR, exist_ok=True)
    start = (datetime.now() - timedelta(days=365 * 10)).strftime("%Y-%m-%d")
    downloaded = _download(tickers, start=start, fields=[yahoo for _, yahoo in pairs])
    panels: dict[str, pd.DataFrame] = {}
    for key, yahoo_name in pairs:
        fresh = _clean_market_data(downloaded.get(yahoo_name, pd.DataFrame()), key)
        merged = _merge_market_panel(load(market, key), fresh)
        path = _data_path(market, key)
        merged.to_csv(path)
        print(f"--- Saved {merged.shape[0]} days x {merged.shape[1]} tickers to {path} ---")
        panels[key] = merged
    return panels
def update(market: str, tickers: list[str],
           with_open: bool = False) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load existing data, download new dates and new tickers, save back.

    Three paths: (1) no cache at all — download the full 10-year history;
    (2) close cache exists but open cache is missing while ``with_open`` is
    requested — download Open for the whole cached range, then append/backfill
    close; (3) both caches present — append only the missing trailing dates and
    backfill any tickers not yet cached.

    Parameters
    ----------
    market : str
        Market key used to name the cache files (e.g. "us", "cn").
    tickers : list of str
        Current index constituents; new ones are backfilled. Cached tickers
        no longer in the list are kept but not refreshed.
    with_open : bool
        If True, also download/maintain Open prices and return (close, open) tuple.

    Returns
    -------
    pd.DataFrame or (close_df, open_df)
        Price data ready for backtesting.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    fields = ["Close", "Open"] if with_open else ["Close"]
    existing_close = load(market, "close")
    existing_open = load(market, "open") if with_open else None
    # If we need open prices but don't have them cached, force a full re-download
    need_full_open = with_open and existing_open is None and existing_close is not None
    if existing_close is None:
        # First run: download full 10-year history
        start = (datetime.now() - timedelta(days=365 * 10)).strftime("%Y-%m-%d")
        downloaded = _download(tickers, start=start, fields=fields)
        close_data = downloaded["Close"]
        open_data = downloaded.get("Open", pd.DataFrame()) if with_open else None
    elif need_full_open:
        # Have close data but need open data — download open prices for full range
        close_data = existing_close
        start = close_data.index[0].strftime("%Y-%m-%d")
        # Only re-request tickers that are both cached and still in the index.
        known_tickers = [t for t in close_data.columns if t in tickers]
        print(f"--- Downloading Open prices for {len(known_tickers)} tickers (first time) ---")
        downloaded = _download(known_tickers, start=start, fields=["Open"])
        open_data = downloaded.get("Open", pd.DataFrame())
        last_date = close_data.index[-1]
        # Still need to append new close dates and backfill new tickers
        next_day = (last_date + timedelta(days=1)).strftime("%Y-%m-%d")
        today = datetime.now().strftime("%Y-%m-%d")
        # ISO date strings compare correctly lexicographically.
        if next_day < today:
            if known_tickers:
                gap_days = (datetime.now() - last_date.to_pydatetime().replace(tzinfo=None)).days
                if gap_days <= 7:
                    # Small gap: period API with a few days of overlap is cheaper.
                    print(f"--- Fetching last {gap_days + 5}d to append ---")
                    new = _download_period(known_tickers, f"{gap_days + 5}d", fields=fields)
                    new_close = new["Close"]
                    # Discard the overlap so only genuinely new rows remain.
                    new_close = new_close[new_close.index > last_date]
                else:
                    new = _download(known_tickers, start=next_day, fields=fields)
                    new_close = new["Close"]
                if not new_close.empty:
                    close_data = pd.concat([close_data, new_close]).sort_index()
                    # keep="last" prefers the freshly downloaded row on duplicate dates.
                    close_data = close_data[~close_data.index.duplicated(keep="last")]
                    print(f"--- Appended {len(new_close)} new days ---")
                else:
                    print("--- No new trading days to append ---")
        new_tickers = [t for t in tickers if t not in close_data.columns]
        if new_tickers:
            start = close_data.index[0].strftime("%Y-%m-%d")
            backfill = _download(new_tickers, start=start, fields=fields)
            backfill_close = backfill["Close"]
            backfill_open = backfill.get("Open", pd.DataFrame())
            if not backfill_close.empty:
                # Align backfilled history to the cached trading calendar.
                backfill_close = backfill_close.reindex(close_data.index)
                close_data = pd.concat([close_data, backfill_close], axis=1)
                if not backfill_open.empty and open_data is not None:
                    backfill_open = backfill_open.reindex(open_data.index)
                    open_data = pd.concat([open_data, backfill_open], axis=1)
                print(f"--- Backfilled {len(new_tickers)} new tickers ---")
    else:
        close_data = existing_close
        open_data = existing_open if with_open else None
        last_date = close_data.index[-1]
        # 1) Append new dates for existing tickers
        next_day = (last_date + timedelta(days=1)).strftime("%Y-%m-%d")
        today = datetime.now().strftime("%Y-%m-%d")
        if next_day < today:
            known_tickers = [t for t in close_data.columns if t in tickers]
            if known_tickers:
                gap_days = (datetime.now() - last_date.to_pydatetime().replace(tzinfo=None)).days
                if gap_days <= 7:
                    print(f"--- Fetching last {gap_days + 5}d to append ---")
                    new = _download_period(known_tickers, f"{gap_days + 5}d", fields=fields)
                    new_close = new["Close"]
                    new_close = new_close[new_close.index > last_date]
                    new_open = new.get("Open", pd.DataFrame())
                    if not new_open.empty:
                        new_open = new_open[new_open.index > last_date]
                else:
                    new = _download(known_tickers, start=next_day, fields=fields)
                    new_close = new["Close"]
                    new_open = new.get("Open", pd.DataFrame())
                if not new_close.empty:
                    close_data = pd.concat([close_data, new_close]).sort_index()
                    close_data = close_data[~close_data.index.duplicated(keep="last")]
                    print(f"--- Appended {len(new_close)} new days ---")
                    if with_open and not new_open.empty and open_data is not None:
                        open_data = pd.concat([open_data, new_open]).sort_index()
                        open_data = open_data[~open_data.index.duplicated(keep="last")]
                else:
                    print("--- No new trading days to append ---")
        # 2) Backfill any new tickers not in existing data
        new_tickers = [t for t in tickers if t not in close_data.columns]
        if new_tickers:
            start = close_data.index[0].strftime("%Y-%m-%d")
            backfill = _download(new_tickers, start=start, fields=fields)
            backfill_close = backfill["Close"]
            backfill_open = backfill.get("Open", pd.DataFrame())
            if not backfill_close.empty:
                backfill_close = backfill_close.reindex(close_data.index)
                close_data = pd.concat([close_data, backfill_close], axis=1)
                print(f"--- Backfilled {len(new_tickers)} new tickers ---")
                if with_open and not backfill_open.empty and open_data is not None:
                    backfill_open = backfill_open.reindex(open_data.index)
                    open_data = pd.concat([open_data, backfill_open], axis=1)
    # Clean and save close
    close_data = _clean(close_data)
    close_path = _data_path(market, "close")
    close_data.to_csv(close_path)
    print(f"--- Saved {close_data.shape[0]} days x {close_data.shape[1]} tickers to {close_path} ---")
    if with_open and open_data is not None and not open_data.empty:
        # Align open data to same tickers/dates as close
        open_data = open_data.reindex(index=close_data.index, columns=close_data.columns)
        open_data = open_data.ffill().dropna(how="all")
        # Where open is missing, fall back to close
        open_data = open_data.fillna(close_data)
        open_path = _data_path(market, "open")
        open_data.to_csv(open_path)
        print(f"--- Saved {open_data.shape[0]} days x {open_data.shape[1]} tickers to {open_path} ---")
        return close_data, open_data
    return close_data