From 3abc51e3e3c8bedd84c0505936d541496bf2aa73 Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Fri, 17 Apr 2026 23:59:06 +0800
Subject: [PATCH] feat: add OHLCV market data updater

---
 data_manager.py           | 101 ++++++++++++++++++++++++++
 tests/test_market_data.py | 145 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 246 insertions(+)
 create mode 100644 tests/test_market_data.py

Review notes (text here is not part of the commit message or the diff):
- Line breaks and indentation were restored from a whitespace-collapsed copy;
  the wrapping of the trailing `def update(...)` context line is a best guess.
- Blob ids on the `index` lines come from the original commit and do not
  describe the revised contents below.
- update_market_data now processes repeated field names once and returns
  early, without downloading, when no fields are requested.

diff --git a/data_manager.py b/data_manager.py
index f1f4d4e..f843cad 100644
--- a/data_manager.py
+++ b/data_manager.py
@@ -103,6 +103,107 @@ def _clean(data: pd.DataFrame) -> pd.DataFrame:
     return data
 
 
+def _clean_market_data(data: pd.DataFrame, field: str) -> pd.DataFrame:
+    """Clean market data while preserving volume gaps.
+
+    Args:
+        data: Panel for one field; each column is treated as a ticker.
+        field: Lowercase field name. ``"volume"`` skips forward-filling.
+
+    Returns:
+        The panel limited to tickers with enough data; non-volume panels are
+        also forward-filled and stripped of rows that are entirely NaN.
+    """
+    # A ticker is kept only when strictly more than half of its rows are present.
+    good = data.columns[data.notna().mean() > 0.5]
+    dropped = set(data.columns) - set(good)
+    if dropped:
+        print(f"--- Dropped {len(dropped)} tickers with >50% missing data ---")
+    data = data[good]
+    if field == "volume":
+        # Volume gaps stay NaN instead of inheriting the previous row's value.
+        return data
+    return data.ffill().dropna(how="all")
+
+
+def _merge_market_panel(existing: pd.DataFrame | None, new_data: pd.DataFrame) -> pd.DataFrame:
+    """Merge new data into an existing cached panel, preserving old columns and dates.
+
+    Every cell covered by ``new_data`` takes the new value, NaN included, so a
+    fresh download can clear a stale cached value.
+
+    Args:
+        existing: Cached panel, or ``None``/empty when nothing is cached yet.
+        new_data: Freshly cleaned panel to merge in.
+
+    Returns:
+        The merged panel, sorted by index, with repeated index labels
+        collapsed to their last row.
+    """
+    if existing is None or existing.empty:
+        merged = new_data.copy()
+    elif new_data.empty:
+        merged = existing.copy()
+    else:
+        # combine_first supplies the union of rows and columns; the assignment
+        # below then makes new_data win wherever it has a cell.
+        merged = existing.combine_first(new_data)
+        merged.loc[new_data.index, new_data.columns] = new_data
+    merged = merged.sort_index()
+    merged = merged[~merged.index.duplicated(keep="last")]
+    return merged
+
+
+def update_market_data(market: str, tickers: list[str], fields: list[str]) -> dict[str, pd.DataFrame]:
+    """Download, clean, persist, and return market data panels for requested Yahoo fields.
+
+    Args:
+        market: Market key used to locate each field's cached CSV.
+        tickers: Ticker symbols to download.
+        fields: Case-insensitive field names: close, open, high, low, volume.
+            Repeated names are processed once.
+
+    Returns:
+        Mapping of lowercase field name to the merged panel that was saved.
+        Empty when ``fields`` is empty; nothing is downloaded in that case.
+
+    Raises:
+        ValueError: If any requested field is unsupported.
+    """
+    field_aliases = {
+        "close": "Close",
+        "open": "Open",
+        "high": "High",
+        "low": "Low",
+        "volume": "Volume",
+    }
+    # Validate every name before touching disk or network. The dict keeps
+    # request order while collapsing repeats such as ["close", "Close"].
+    requested = {}
+    for field in fields:
+        normalized = field.lower()
+        if normalized not in field_aliases:
+            raise ValueError(f"Unsupported market data field: {field}")
+        requested[normalized] = field_aliases[normalized]
+    if not requested:
+        return {}
+
+    os.makedirs(DATA_DIR, exist_ok=True)
+    start = (datetime.now() - timedelta(days=365 * 10)).strftime("%Y-%m-%d")
+    downloaded = _download(tickers, start=start, fields=list(requested.values()))
+
+    cleaned = {}
+    for normalized, yahoo_field in requested.items():
+        data = _clean_market_data(downloaded.get(yahoo_field, pd.DataFrame()), normalized)
+        existing = load(market, normalized)
+        data = _merge_market_panel(existing, data)
+        path = _data_path(market, normalized)
+        data.to_csv(path)
+        print(f"--- Saved {data.shape[0]} days x {data.shape[1]} tickers to {path} ---")
+        cleaned[normalized] = data
+    return cleaned
+
+
 def update(market: str, tickers: list[str],
            with_open: bool = False) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
     """
diff --git a/tests/test_market_data.py b/tests/test_market_data.py
new file mode 100644
index 0000000..04af1f5
--- /dev/null
+++ b/tests/test_market_data.py
@@ -0,0 +1,145 @@
+import tempfile
+import unittest
+from pathlib import Path
+from unittest import mock
+
+import pandas as pd
+
+import data_manager
+
+
+class UpdateMarketDataTests(unittest.TestCase):
+    def test_update_market_data_accepts_lowercase_fields_and_does_not_fill_volume(self):
+        dates = pd.to_datetime(["2024-01-02", "2024-01-03", "2024-01-04"])
+        raw = pd.DataFrame(
+            {
+                ("Close", "AAA"): [10.0, 11.0, 12.0],
+                ("Close", "BBB"): [20.0, float("nan"), 22.0],
+                ("Open", "AAA"): [9.5, 10.5, 11.5],
+                ("Open", "BBB"): [19.5, 20.5, 21.5],
+                ("High", "AAA"): [10.5, 11.5, 12.5],
+                ("High", "BBB"): [20.5, 21.5, 22.5],
+                ("Low", "AAA"): [9.0, 10.0, 11.0],
+                ("Low", "BBB"): [19.0, 20.0, 21.0],
+                ("Volume", "AAA"): [1000, 1100, 1200],
+                ("Volume", "BBB"): [2000, float("nan"), 2200],
+            },
+            index=dates,
+        )
+        raw.columns = pd.MultiIndex.from_tuples(raw.columns)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
+                with mock.patch("data_manager.yf.download", return_value=raw) as mocked_download:
+                    panels = data_manager.update_market_data(
+                        "us",
+                        ["AAA", "BBB"],
+                        ["close", "open", "high", "low", "volume"],
+                    )
+            self.assertEqual(set(panels), {"close", "open", "high", "low", "volume"})
+            self.assertEqual(panels["close"].loc[dates[1], "BBB"], 20.0)
+            self.assertTrue(pd.isna(panels["volume"].loc[dates[1], "BBB"]))
+            self.assertTrue((Path(tmpdir) / "us.csv").exists())
+            self.assertTrue((Path(tmpdir) / "us_open.csv").exists())
+            self.assertTrue((Path(tmpdir) / "us_high.csv").exists())
+            self.assertTrue((Path(tmpdir) / "us_low.csv").exists())
+            self.assertTrue((Path(tmpdir) / "us_volume.csv").exists())
+
+            saved_high = pd.read_csv(Path(tmpdir) / "us_high.csv", index_col=0, parse_dates=True)
+            pd.testing.assert_frame_equal(saved_high, panels["high"], check_freq=False)
+
+            self.assertEqual(mocked_download.call_args.args[0], ["AAA", "BBB"])
+            self.assertEqual(mocked_download.call_args.kwargs["auto_adjust"], True)
+            self.assertIn("start", mocked_download.call_args.kwargs)
+
+    def test_update_market_data_rejects_unsupported_fields(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
+                with self.assertRaisesRegex(ValueError, "Unsupported market data field: adjusted_close"):
+                    data_manager.update_market_data("us", ["AAA"], ["adjusted_close"])
+
+    def test_update_market_data_without_fields_skips_download(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
+                with mock.patch("data_manager.yf.download") as mocked_download:
+                    panels = data_manager.update_market_data("us", ["AAA"], [])
+
+            self.assertEqual(panels, {})
+            mocked_download.assert_not_called()
+            self.assertEqual(list(Path(tmpdir).iterdir()), [])
+
+    def test_update_market_data_processes_repeated_fields_once(self):
+        dates = pd.to_datetime(["2024-01-02", "2024-01-03"])
+        close = pd.DataFrame({"AAA": [10.0, 11.0]}, index=dates)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
+                with mock.patch.object(data_manager, "_download", return_value={"Close": close}) as mocked_download:
+                    panels = data_manager.update_market_data("us", ["AAA"], ["close", "Close", "CLOSE"])
+
+        self.assertEqual(list(panels), ["close"])
+        self.assertEqual(mocked_download.call_args.kwargs["fields"], ["Close"])
+
+    def test_update_market_data_preserves_existing_cache_columns_and_dates(self):
+        existing_dates = pd.to_datetime(["2024-01-01", "2024-01-02"])
+        new_dates = pd.to_datetime(["2024-01-02", "2024-01-03"])
+        existing_close = pd.DataFrame(
+            {
+                "AAA": [9.0, 10.0],
+                "CCC": [30.0, 31.0],
+            },
+            index=existing_dates,
+        )
+        downloaded_close = pd.DataFrame({"Close": [10.5, 11.5]}, index=new_dates)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            existing_close.to_csv(Path(tmpdir) / "us.csv")
+            with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
+                with mock.patch("data_manager.yf.download", return_value=downloaded_close):
+                    panels = data_manager.update_market_data("us", ["AAA"], ["close"])
+
+            expected = pd.DataFrame(
+                {
+                    "AAA": [9.0, 10.5, 11.5],
+                    "CCC": [30.0, 31.0, float("nan")],
+                },
+                index=pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
+            )
+            saved_close = pd.read_csv(Path(tmpdir) / "us.csv", index_col=0, parse_dates=True)
+
+            pd.testing.assert_frame_equal(panels["close"], expected, check_freq=False)
+            pd.testing.assert_frame_equal(saved_close, expected, check_freq=False)
+
+    def test_update_market_data_volume_merge_can_clear_stale_cached_values(self):
+        existing_dates = pd.to_datetime(["2024-01-01", "2024-01-02"])
+        new_dates = pd.to_datetime(["2024-01-02", "2024-01-03", "2024-01-04"])
+        existing_volume = pd.DataFrame(
+            {
+                "AAA": [1000.0, 9999.0],
+                "CCC": [3000.0, 3100.0],
+            },
+            index=existing_dates,
+        )
+        downloaded_volume = pd.DataFrame({"Volume": [float("nan"), 1200.0, 1300.0]}, index=new_dates)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            existing_volume.to_csv(Path(tmpdir) / "us_volume.csv")
+            with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
+                with mock.patch("data_manager.yf.download", return_value=downloaded_volume):
+                    panels = data_manager.update_market_data("us", ["AAA"], ["volume"])
+
+            expected = pd.DataFrame(
+                {
+                    "AAA": [1000.0, float("nan"), 1200.0, 1300.0],
+                    "CCC": [3000.0, 3100.0, float("nan"), float("nan")],
+                },
+                index=pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"]),
+            )
+            saved_volume = pd.read_csv(Path(tmpdir) / "us_volume.csv", index_col=0, parse_dates=True)
+
+            pd.testing.assert_frame_equal(panels["volume"], expected, check_freq=False)
+            pd.testing.assert_frame_equal(saved_volume, expected, check_freq=False)
+
+
+if __name__ == "__main__":
+    unittest.main()