""" Stock universe providers — fetch index constituents for backtesting. Each function returns a dict with 'tickers' (list[str]) and 'benchmark' (str) compatible with Yahoo Finance ticker format. Results are cached locally so Wikipedia / web scrapes only happen once per day. """ import io import json import os import urllib.request from datetime import date import pandas as pd _HEADERS = {"User-Agent": "Mozilla/5.0 (quant-backtest)"} def _fetch_html_tables(url: str) -> list[pd.DataFrame]: """Fetch HTML tables from a URL with a proper User-Agent.""" req = urllib.request.Request(url, headers=_HEADERS) with urllib.request.urlopen(req) as resp: html = resp.read().decode("utf-8") return pd.read_html(io.StringIO(html)) CACHE_DIR = "data" def _read_cache(name: str) -> list[str] | None: path = os.path.join(CACHE_DIR, f"universe_{name}.json") if not os.path.exists(path): return None try: with open(path) as f: data = json.load(f) if data.get("date") == str(date.today()): return data["tickers"] except Exception: pass return None def _write_cache(name: str, tickers: list[str]) -> None: os.makedirs(CACHE_DIR, exist_ok=True) path = os.path.join(CACHE_DIR, f"universe_{name}.json") with open(path, "w") as f: json.dump({"date": str(date.today()), "tickers": tickers}, f) # --------------------------------------------------------------------------- # S&P 500 # --------------------------------------------------------------------------- def get_sp500() -> list[str]: """Fetch current S&P 500 constituents from Wikipedia.""" cached = _read_cache("sp500") if cached: print(f"--- Loaded S&P 500 universe from cache ({len(cached)} tickers) ---") return cached print("--- Fetching S&P 500 constituents from Wikipedia ---") url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies" tables = _fetch_html_tables(url) df = tables[0] tickers = sorted(df["Symbol"].str.replace(".", "-", regex=False).tolist()) _write_cache("sp500", tickers) print(f"--- Got {len(tickers)} S&P 500 tickers ---") return tickers # --------------------------------------------------------------------------- # CSI 300 # --------------------------------------------------------------------------- # Yahoo Finance format: Shanghai → XXXXXX.SS, Shenzhen → XXXXXX.SZ # Wikipedia CSI 300 page has a table with stock codes and exchanges. def get_csi300() -> list[str]: """Fetch current CSI 300 constituents from Wikipedia.""" cached = _read_cache("csi300") if cached: print(f"--- Loaded CSI 300 universe from cache ({len(cached)} tickers) ---") return cached print("--- Fetching CSI 300 constituents from Wikipedia ---") url = "https://en.wikipedia.org/wiki/CSI_300_Index" tables = _fetch_html_tables(url) # Find the constituent table — it has ~300 rows and a "Ticker" column # with format like "SSE: 600519" or "SZSE: 300750" df = None for t in tables: if len(t) >= 200 and "Ticker" in t.columns: df = t break if df is None: raise RuntimeError("Could not find CSI 300 constituent table on Wikipedia") tickers = [] for _, row in df.iterrows(): raw = str(row["Ticker"]) # e.g. "SSE: 600519" or "SZSE: 300750" parts = raw.split(":") if len(parts) != 2: continue exchange_prefix = parts[0].strip().upper() code = parts[1].strip() if not code.isdigit(): continue suffix = ".SS" if exchange_prefix == "SSE" else ".SZ" tickers.append(code + suffix) tickers = sorted(set(tickers)) _write_cache("csi300", tickers) print(f"--- Got {len(tickers)} CSI 300 tickers ---") return tickers # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- UNIVERSES = { "us": { "fetch": get_sp500, "benchmark": "SPY", "benchmark_label": "SPY (Benchmark)", }, "cn": { "fetch": get_csi300, "benchmark": "000300.SS", "benchmark_label": "CSI 300 (Benchmark)", }, }