Files
quant/weekly_strategy_report.py

310 lines
10 KiB
Python

#!/usr/bin/env python3
"""Weekly top-strategy report against market baselines."""
import argparse
import glob
import json
import time
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from urllib.parse import urlencode
from urllib.request import Request, urlopen
import pandas as pd
import yfinance as yf
import data_manager
# Notional starting value. Baselines are rebased so their first in-window close
# equals it, the weekly return % is measured against it, and it is the fallback
# initial capital when a strategy state file does not provide one.
INITIAL_VALUE: float = 10_000.0
@dataclass(frozen=True)
class Baseline:
    """A market benchmark plus the identifiers used to fetch it from each data source.

    An identifier left as ``None`` means this benchmark has no identifier for
    that source.
    """

    label: str  # display name; used as the report column header and in warnings
    yahoo_symbol: str | None = None  # ticker passed to ``yf.download``
    cache_symbol: str | None = None  # column name looked up in ``data_manager.load(market)``
    sohu_code: str | None = None  # code for Sohu's ``hisHq`` endpoint, e.g. "zs_000300"
@dataclass
class StrategyEquity:
    """Daily equity curve of one strategy, as loaded from its ``trader_*`` state file."""

    name: str  # strategy name taken from the ``trader_<market>_<name>.json`` filename
    series: pd.Series  # equity values indexed by date, sorted ascending
    total_return: float  # last equity / initial capital - 1, over the whole stored history
# Benchmarks reported next to the strategies, keyed by the market code that
# also appears in the state-file names (``trader_<market>_*.json``).
BASELINES: dict[str, list[Baseline]] = {
    "us": [
        Baseline("NASDAQ Composite", yahoo_symbol="^IXIC"),
        Baseline("SPY", yahoo_symbol="SPY", cache_symbol="SPY"),
    ],
    "cn": [
        Baseline("CSI 300", yahoo_symbol="000300.SS", cache_symbol="000300.SS", sohu_code="zs_000300"),
        Baseline("CSI 800", yahoo_symbol="000906.SS", sohu_code="zs_000906"),
    ],
}
def _state_name(path: str, market: str) -> str:
base = Path(path).name
prefix = f"trader_{market}_"
return base[len(prefix):-len(".json")]
def load_strategies(
    market: str,
    include_sim: bool,
    data_dir: str | Path = "data",
) -> list[StrategyEquity]:
    """Load every ``trader_<market>_*.json`` state file and rank by total return.

    Args:
        market: market code used in the state-file names (e.g. ``"us"``).
        include_sim: when False, strategies whose name starts with ``sim_``
            are skipped.
        data_dir: directory holding the state files (relative to the current
            working directory by default, as before).

    Returns:
        One StrategyEquity per file that has at least two non-null equity
        points, sorted by ``total_return`` descending.
    """
    rows: list[StrategyEquity] = []
    pattern = str(Path(data_dir) / f"trader_{market}_*.json")
    for path in sorted(glob.glob(pattern)):
        name = _state_name(path, market)
        if name.startswith("sim_") and not include_sim:
            continue
        state = json.loads(Path(path).read_text(encoding="utf-8"))
        daily_equity = state.get("daily_equity", {}) or {}
        # JSON nulls become NaN. Drop them before the length check so they
        # neither count as data points nor turn initial/total_return into NaN
        # (a NaN sort key would scramble the ranking).
        series = pd.Series(daily_equity, dtype=float).dropna()
        if len(series) < 2:
            continue
        series.index = pd.to_datetime(series.index)
        series = series.sort_index()
        initial = float(state.get("initial_capital") or series.iloc[0] or INITIAL_VALUE)
        total_return = series.iloc[-1] / initial - 1.0 if initial else 0.0
        rows.append(StrategyEquity(name=name, series=series, total_return=total_return))
    rows.sort(key=lambda row: row.total_return, reverse=True)
    return rows
def _close_from_yahoo_frame(raw: pd.DataFrame, symbol: str) -> pd.Series:
if raw.empty:
return pd.Series(dtype=float)
if isinstance(raw.columns, pd.MultiIndex):
if "Close" not in raw.columns.get_level_values(0):
return pd.Series(dtype=float)
close = raw["Close"]
if isinstance(close, pd.DataFrame):
if symbol in close.columns:
return close[symbol].dropna().astype(float)
if len(close.columns) == 1:
return close.iloc[:, 0].dropna().astype(float)
return close.dropna().astype(float)
if "Close" not in raw.columns:
return pd.Series(dtype=float)
return raw["Close"].dropna().astype(float)
def download_yahoo_close(symbol: str, start: pd.Timestamp, end: pd.Timestamp) -> pd.Series:
    """Download auto-adjusted daily closes for *symbol* from Yahoo Finance.

    One day is added to *end* before it is passed to ``yf.download`` so the
    requested end date is included. The returned index is tz-naive.
    """
    first_day = start.strftime("%Y-%m-%d")
    day_after_end = (end + timedelta(days=1)).strftime("%Y-%m-%d")
    downloaded = yf.download(
        symbol,
        start=first_day,
        end=day_after_end,
        auto_adjust=True,
        progress=False,
    )
    closes = _close_from_yahoo_frame(downloaded, symbol)
    closes.index = pd.to_datetime(closes.index).tz_localize(None)
    return closes
def download_sohu_close(code: str, start: pd.Timestamp, end: pd.Timestamp) -> pd.Series:
    """Fetch daily closes for *code* (e.g. ``"zs_000300"``) from Sohu's ``hisHq`` endpoint.

    The response is JSONP wrapped in ``historySearchHandler(...)``. Each
    ``hq`` row is read as date at index 0 and close at index 2.

    This is a best-effort source. Network, decoding and parsing failures all
    yield an empty float Series so the caller can fall back to another source.
    Previously, malformed JSON or a non-numeric cell raised and aborted the
    whole report.
    """
    empty = pd.Series(dtype=float)
    params = {
        "code": code,
        "start": start.strftime("%Y%m%d"),
        "end": end.strftime("%Y%m%d"),
        "stat": "1",
        "order": "D",
        "period": "d",
        "callback": "historySearchHandler",
        "rt": "jsonp",
    }
    url = "https://q.stock.sohu.com/hisHq?" + urlencode(params)
    req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    text = ""
    for attempt in range(3):
        try:
            with urlopen(req, timeout=20) as resp:
                text = resp.read().decode("gbk").strip()
            break
        except Exception:  # best-effort source: retry any transport/decoding failure
            if attempt == 2:
                return empty
            time.sleep(1)
    prefix = "historySearchHandler("
    if not text.startswith(prefix) or not text.endswith(")"):
        return empty
    try:
        payload = json.loads(text[len(prefix):-1])
    except json.JSONDecodeError:
        return empty
    # Usable data is a non-empty list whose first entry reports status 0.
    if not isinstance(payload, list) or not payload:
        return empty
    head = payload[0]
    if not isinstance(head, dict) or head.get("status") != 0:
        return empty
    closes: dict = {}
    for row in head.get("hq") or []:
        if not isinstance(row, (list, tuple)) or len(row) < 3:
            continue
        try:
            closes[row[0]] = float(row[2])
        except (TypeError, ValueError):
            continue  # skip rows whose close cell is not numeric
    if not closes:
        return empty
    series = pd.Series(closes, dtype=float)
    series.index = pd.to_datetime(series.index, errors="coerce")
    return series[series.index.notna()].sort_index()
def clip_dates(series: pd.Series, start: pd.Timestamp, end: pd.Timestamp) -> pd.Series:
    """Restrict *series* to the inclusive window [start, end] and drop NaNs.

    The index is converted to tz-naive datetimes first, so tz-aware inputs can
    be compared with naive bounds. An empty input is returned as-is; otherwise
    a new Series is returned and the input is left untouched.
    """
    if series.empty:
        return series
    naive_index = pd.to_datetime(series.index).tz_localize(None)
    in_window = (naive_index >= start) & (naive_index <= end)
    trimmed = series.copy()
    trimmed.index = naive_index
    return trimmed[in_window].dropna()
def load_baseline(market: str, baseline: Baseline, start: pd.Timestamp, end: pd.Timestamp) -> pd.Series:
    """Return *baseline*'s closes over [start, end], rebased to INITIAL_VALUE.

    Sources are tried in this order, each only if the baseline defines its
    identifier:

    1. Sohu (``sohu_code``)
    2. the local ``data_manager`` cache (``cache_symbol``)
    3. Yahoo (``yahoo_symbol``)

    The first source that yields at least two points inside the window wins.
    If none does, the result of the last source tried is used.

    Previously the cache was loaded even without a ``cache_symbol``, and Sohu
    was queried a second time with identical arguments after it had already
    failed (and ``download_sohu_close`` already retries internally).

    Returns an empty Series when no source has data for the window.
    """

    def from_cache() -> pd.Series:
        cached = data_manager.load(market)
        if cached is None or baseline.cache_symbol not in cached.columns:
            return pd.Series(dtype=float)
        return cached[baseline.cache_symbol].dropna().astype(float)

    fetchers = []
    if baseline.sohu_code:
        fetchers.append(lambda: download_sohu_close(baseline.sohu_code, start, end))
    if baseline.cache_symbol:
        fetchers.append(from_cache)
    if baseline.yahoo_symbol:
        fetchers.append(lambda: download_yahoo_close(baseline.yahoo_symbol, start, end))

    series = pd.Series(dtype=float)
    for fetch in fetchers:
        series = clip_dates(fetch(), start, end)
        if len(series) >= 2:
            break
    if series.empty:
        return series
    return series / series.iloc[0] * INITIAL_VALUE
def select_period(strategies: list[StrategyEquity], start: str | None, end: str | None) -> tuple[pd.Timestamp, pd.Timestamp]:
    """Resolve the report window.

    An explicit *start* / *end* string wins. Otherwise the window defaults to
    the overlap of all strategy histories: the latest first date and the
    earliest last date.

    Raises:
        ValueError: if a default bound is needed but *strategies* is empty
            (previously a cryptic ``max() arg is an empty sequence``), or if
            the resolved start falls after the resolved end.
    """
    if (not start or not end) and not strategies:
        raise ValueError("Cannot infer report period: no strategies selected")
    if start:
        start_ts = pd.Timestamp(start)
    else:
        start_ts = max(row.series.index.min() for row in strategies)
    if end:
        end_ts = pd.Timestamp(end)
    else:
        end_ts = min(row.series.index.max() for row in strategies)
    if start_ts > end_ts:
        raise ValueError(f"Invalid period: start {start_ts.date()} is after end {end_ts.date()}")
    return start_ts, end_ts
def weekly_last(frame: pd.DataFrame) -> pd.DataFrame:
    """Collapse a date-indexed *frame* to one row per Friday-ending week.

    Each column reports its own last non-null value within the week. Weeks
    with no data at all are omitted. Rows are labelled with the last date that
    had any data in that week, formatted ``YYYY-MM-DD``, and the index is
    named ``week_date``.
    """
    rows = []
    for _, group in frame.groupby(pd.Grouper(freq="W-FRI")):
        group = group.dropna(how="all")
        if group.empty:
            continue
        # Forward-fill within the week only. Otherwise a column missing on the
        # week's final row (e.g. a baseline holiday) would be reported as blank
        # even though it had data earlier that week.
        last = group.ffill().iloc[-1]
        last.name = group.index[-1].strftime("%Y-%m-%d")
        rows.append(last)
    if not rows:
        return pd.DataFrame(columns=frame.columns)
    result = pd.DataFrame(rows)
    result.index.name = "week_date"
    return result
def _rebase(series: pd.Series) -> pd.Series:
    """Scale *series* so its first value equals INITIAL_VALUE (empty or zero-based series are returned unchanged)."""
    if series.empty or not series.iloc[0]:
        return series
    return series / series.iloc[0] * INITIAL_VALUE


def build_market_report(
    market: str,
    top: int,
    include_sim: bool,
    start: str | None,
    end: str | None,
) -> tuple[pd.DataFrame, pd.DataFrame, list[str]]:
    """Build weekly value and return tables for *market*'s top strategies and baselines.

    The *top* strategies by total return are clipped to the report window.
    Every column, strategies and baselines alike, is rebased to INITIAL_VALUE
    at its first in-window point.

    Returns:
        ``(weekly_values, weekly_returns, warnings)``. The returns are
        percentages relative to INITIAL_VALUE. The warnings list baselines
        that had no data.

    Raises:
        RuntimeError: if no strategy equity data exists for *market*.
        ValueError: if the report period cannot be resolved or is inverted.
    """
    strategies = load_strategies(market, include_sim=include_sim)
    if not strategies:
        raise RuntimeError(f"No strategy equity data found for market '{market}'")
    selected = strategies[:top]
    start_ts, end_ts = select_period(selected, start, end)
    # Baselines come back from load_baseline rebased to INITIAL_VALUE at the
    # window start, so strategies must be rebased the same way. Otherwise
    # their "return %" mixes since-inception and since-window bases and
    # depends on each strategy's initial capital.
    columns: dict[str, pd.Series] = {
        row.name: _rebase(clip_dates(row.series, start_ts, end_ts))
        for row in selected
    }
    warnings = []
    for baseline in BASELINES[market]:
        series = load_baseline(market, baseline, start_ts, end_ts)
        if series.empty:
            warnings.append(f"{market.upper()} baseline '{baseline.label}' has no data for {start_ts.date()} to {end_ts.date()}")
            continue
        columns[baseline.label] = series
    # Outer-join on dates. Assigning into a strategy-indexed frame would
    # silently drop baseline dates absent from the strategies' index.
    if columns:
        frame = pd.concat(columns, axis=1).sort_index()
    else:
        frame = pd.DataFrame(index=pd.DatetimeIndex([]))
    weekly_values = weekly_last(frame)
    weekly_returns = (weekly_values / INITIAL_VALUE - 1.0) * 100.0
    return weekly_values, weekly_returns, warnings
def write_outputs(market: str, values: pd.DataFrame, returns: pd.DataFrame, output_dir: Path) -> tuple[Path, Path]:
    """Write the weekly value table (2 dp) and return table (4 dp) as CSV files.

    *output_dir* is created if needed. Returns ``(value_path, return_path)``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    targets = (
        (values, 2, output_dir / f"weekly_{market}_top10_vs_baselines.csv"),
        (returns, 4, output_dir / f"weekly_{market}_top10_vs_baselines_returns.csv"),
    )
    for table, decimals, destination in targets:
        table.round(decimals).to_csv(destination)
    return targets[0][2], targets[1][2]
def print_returns(market: str, returns: pd.DataFrame) -> None:
print(f"\n{market.upper()} weekly return %")
if returns.empty:
print(" No weekly rows.")
return
printable = returns.copy()
printable = printable.map(lambda value: "" if pd.isna(value) else f"{value:+.2f}%")
print(printable.to_string())
def _positive_int(text: str) -> int:
    """argparse ``type=``: parse *text* as an integer >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_args() -> argparse.Namespace:
    """Parse the command line.

    ``--top`` must be a positive integer. Zero or negative values previously
    passed parsing and later crashed while choosing the report period.
    """
    parser = argparse.ArgumentParser(
        description="Compare each market's top10 live strategies with weekly baseline data."
    )
    parser.add_argument("--market", choices=["all", "us", "cn"], default="all")
    parser.add_argument("--top", type=_positive_int, default=10)
    parser.add_argument("--start", help="Optional start date, YYYY-MM-DD")
    parser.add_argument("--end", help="Optional end date, YYYY-MM-DD")
    parser.add_argument("--include-sim", action="store_true")
    parser.add_argument("--output-dir", default="data")
    return parser.parse_args()
def main() -> None:
    """Build, save and print the weekly report for each requested market.

    A market that cannot be reported (no strategy data, or an invalid period)
    is reported on stderr and skipped, so the remaining markets still run.
    Previously the first failure aborted ``--market all``. The process exits
    with status 1 if any market failed.
    """
    import sys

    args = parse_args()
    markets = ["us", "cn"] if args.market == "all" else [args.market]
    output_dir = Path(args.output_dir)
    failed = False
    for market in markets:
        try:
            values, returns, warnings = build_market_report(
                market=market,
                top=args.top,
                include_sim=args.include_sim,
                start=args.start,
                end=args.end,
            )
        except (RuntimeError, ValueError) as exc:
            print(f"{market.upper()} report failed: {exc}", file=sys.stderr)
            failed = True
            continue
        value_path, return_path = write_outputs(market, values, returns, output_dir)
        print_returns(market, returns)
        print(f" values: {value_path}")
        print(f" returns: {return_path}")
        for warning in warnings:
            print(f" warning: {warning}")
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    main()