feat: add OHLCV market data updater

This commit is contained in:
2026-04-17 23:59:06 +08:00
parent 7239310be3
commit 3abc51e3e3
2 changed files with 183 additions and 0 deletions

View File

@@ -103,6 +103,66 @@ def _clean(data: pd.DataFrame) -> pd.DataFrame:
return data return data
def _clean_market_data(data: pd.DataFrame, field: str) -> pd.DataFrame:
"""Clean market data while preserving volume gaps."""
good = data.columns[data.notna().mean() > 0.5]
dropped = set(data.columns) - set(good)
if dropped:
print(f"--- Dropped {len(dropped)} tickers with >50% missing data ---")
data = data[good]
if field == "volume":
return data
return data.ffill().dropna(how="all")
def _merge_market_panel(existing: pd.DataFrame | None, new_data: pd.DataFrame) -> pd.DataFrame:
"""Merge new data into an existing cached panel, preserving old columns and dates."""
if existing is None or existing.empty:
merged = new_data.copy()
elif new_data.empty:
merged = existing.copy()
else:
merged = existing.combine_first(new_data)
merged.loc[new_data.index, new_data.columns] = new_data
merged = merged.sort_index()
merged = merged[~merged.index.duplicated(keep="last")]
return merged
def update_market_data(market: str, tickers: list[str], fields: list[str]) -> dict[str, pd.DataFrame]:
"""Download, clean, persist, and return market data panels for requested Yahoo fields."""
field_aliases = {
"close": "Close",
"open": "Open",
"high": "High",
"low": "Low",
"volume": "Volume",
}
normalized_fields = []
yahoo_fields = []
for field in fields:
normalized = field.lower()
if normalized not in field_aliases:
raise ValueError(f"Unsupported market data field: {field}")
normalized_fields.append(normalized)
yahoo_fields.append(field_aliases[normalized])
os.makedirs(DATA_DIR, exist_ok=True)
start = (datetime.now() - timedelta(days=365 * 10)).strftime("%Y-%m-%d")
downloaded = _download(tickers, start=start, fields=yahoo_fields)
cleaned = {}
for normalized, yahoo_field in zip(normalized_fields, yahoo_fields):
data = _clean_market_data(downloaded.get(yahoo_field, pd.DataFrame()), normalized)
existing = load(market, normalized)
data = _merge_market_panel(existing, data)
path = _data_path(market, normalized)
data.to_csv(path)
print(f"--- Saved {data.shape[0]} days x {data.shape[1]} tickers to {path} ---")
cleaned[normalized] = data
return cleaned
def update(market: str, tickers: list[str], def update(market: str, tickers: list[str],
with_open: bool = False) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]: with_open: bool = False) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
""" """

123
tests/test_market_data.py Normal file
View File

@@ -0,0 +1,123 @@
import tempfile
import unittest
from pathlib import Path
from unittest import mock
import pandas as pd
import data_manager
class UpdateMarketDataTests(unittest.TestCase):
def test_update_market_data_accepts_lowercase_fields_and_does_not_fill_volume(self):
dates = pd.to_datetime(["2024-01-02", "2024-01-03", "2024-01-04"])
raw = pd.DataFrame(
{
("Close", "AAA"): [10.0, 11.0, 12.0],
("Close", "BBB"): [20.0, float("nan"), 22.0],
("Open", "AAA"): [9.5, 10.5, 11.5],
("Open", "BBB"): [19.5, 20.5, 21.5],
("High", "AAA"): [10.5, 11.5, 12.5],
("High", "BBB"): [20.5, 21.5, 22.5],
("Low", "AAA"): [9.0, 10.0, 11.0],
("Low", "BBB"): [19.0, 20.0, 21.0],
("Volume", "AAA"): [1000, 1100, 1200],
("Volume", "BBB"): [2000, float("nan"), 2200],
},
index=dates,
)
raw.columns = pd.MultiIndex.from_tuples(raw.columns)
with tempfile.TemporaryDirectory() as tmpdir:
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
with mock.patch("data_manager.yf.download", return_value=raw) as mocked_download:
panels = data_manager.update_market_data(
"us",
["AAA", "BBB"],
["close", "open", "high", "low", "volume"],
)
self.assertEqual(set(panels), {"close", "open", "high", "low", "volume"})
self.assertEqual(panels["close"].loc[dates[1], "BBB"], 20.0)
self.assertTrue(pd.isna(panels["volume"].loc[dates[1], "BBB"]))
self.assertTrue((Path(tmpdir) / "us.csv").exists())
self.assertTrue((Path(tmpdir) / "us_open.csv").exists())
self.assertTrue((Path(tmpdir) / "us_high.csv").exists())
self.assertTrue((Path(tmpdir) / "us_low.csv").exists())
self.assertTrue((Path(tmpdir) / "us_volume.csv").exists())
saved_high = pd.read_csv(Path(tmpdir) / "us_high.csv", index_col=0, parse_dates=True)
pd.testing.assert_frame_equal(saved_high, panels["high"], check_freq=False)
self.assertEqual(mocked_download.call_args.args[0], ["AAA", "BBB"])
self.assertEqual(mocked_download.call_args.kwargs["auto_adjust"], True)
self.assertIn("start", mocked_download.call_args.kwargs)
def test_update_market_data_rejects_unsupported_fields(self):
with tempfile.TemporaryDirectory() as tmpdir:
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
with self.assertRaisesRegex(ValueError, "Unsupported market data field: adjusted_close"):
data_manager.update_market_data("us", ["AAA"], ["adjusted_close"])
def test_update_market_data_preserves_existing_cache_columns_and_dates(self):
existing_dates = pd.to_datetime(["2024-01-01", "2024-01-02"])
new_dates = pd.to_datetime(["2024-01-02", "2024-01-03"])
existing_close = pd.DataFrame(
{
"AAA": [9.0, 10.0],
"CCC": [30.0, 31.0],
},
index=existing_dates,
)
downloaded_close = pd.DataFrame({"Close": [10.5, 11.5]}, index=new_dates)
with tempfile.TemporaryDirectory() as tmpdir:
existing_close.to_csv(Path(tmpdir) / "us.csv")
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
with mock.patch("data_manager.yf.download", return_value=downloaded_close):
panels = data_manager.update_market_data("us", ["AAA"], ["close"])
expected = pd.DataFrame(
{
"AAA": [9.0, 10.5, 11.5],
"CCC": [30.0, 31.0, float("nan")],
},
index=pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
)
saved_close = pd.read_csv(Path(tmpdir) / "us.csv", index_col=0, parse_dates=True)
pd.testing.assert_frame_equal(panels["close"], expected, check_freq=False)
pd.testing.assert_frame_equal(saved_close, expected, check_freq=False)
def test_update_market_data_volume_merge_can_clear_stale_cached_values(self):
existing_dates = pd.to_datetime(["2024-01-01", "2024-01-02"])
new_dates = pd.to_datetime(["2024-01-02", "2024-01-03", "2024-01-04"])
existing_volume = pd.DataFrame(
{
"AAA": [1000.0, 9999.0],
"CCC": [3000.0, 3100.0],
},
index=existing_dates,
)
downloaded_volume = pd.DataFrame({"Volume": [float("nan"), 1200.0, 1300.0]}, index=new_dates)
with tempfile.TemporaryDirectory() as tmpdir:
existing_volume.to_csv(Path(tmpdir) / "us_volume.csv")
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
with mock.patch("data_manager.yf.download", return_value=downloaded_volume):
panels = data_manager.update_market_data("us", ["AAA"], ["volume"])
expected = pd.DataFrame(
{
"AAA": [1000.0, float("nan"), 1200.0, 1300.0],
"CCC": [3000.0, 3100.0, float("nan"), float("nan")],
},
index=pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"]),
)
saved_volume = pd.read_csv(Path(tmpdir) / "us_volume.csv", index_col=0, parse_dates=True)
pd.testing.assert_frame_equal(panels["volume"], expected, check_freq=False)
pd.testing.assert_frame_equal(saved_volume, expected, check_freq=False)
if __name__ == "__main__":
unittest.main()