From 1edce834305d5803a38e8d30b42caaaa6ab1badb Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Sat, 18 Apr 2026 00:03:07 +0800 Subject: [PATCH] fix: handle single-ticker yahoo panels --- data_manager.py | 14 ++++++++------ tests/test_market_data.py | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/data_manager.py b/data_manager.py index f843cad..c156b8b 100644 --- a/data_manager.py +++ b/data_manager.py @@ -63,10 +63,11 @@ def _download(tickers: list[str], start: str, end: str | None = None, result = {} for field in fields: if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns: - if len(tickers) > 1: - result[field] = raw[field] + selected = raw[field] + if isinstance(selected, pd.Series): + result[field] = selected.to_frame(name=tickers[0]) else: - result[field] = raw[field].to_frame(name=tickers[0]) + result[field] = selected else: result[field] = pd.DataFrame() return result @@ -83,10 +84,11 @@ def _download_period(tickers: list[str], period: str, result = {} for field in fields: if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns: - if len(tickers) > 1: - result[field] = raw[field] + selected = raw[field] + if isinstance(selected, pd.Series): + result[field] = selected.to_frame(name=tickers[0]) else: - result[field] = raw[field].to_frame(name=tickers[0]) + result[field] = selected else: result[field] = pd.DataFrame() return result diff --git a/tests/test_market_data.py b/tests/test_market_data.py index 04af1f5..9a0241f 100644 --- a/tests/test_market_data.py +++ b/tests/test_market_data.py @@ -118,6 +118,27 @@ class UpdateMarketDataTests(unittest.TestCase): pd.testing.assert_frame_equal(panels["volume"], expected, check_freq=False) pd.testing.assert_frame_equal(saved_volume, expected, check_freq=False) + def test_update_market_data_handles_single_ticker_multiindex_download(self): + dates = pd.to_datetime(["2024-01-02", "2024-01-03"]) + raw = pd.DataFrame( + { + ("Close", "AAA"): [10.0, 11.0], + ("Volume", "AAA"): [1000.0, 1100.0], + }, + index=dates, + ) + raw.columns = pd.MultiIndex.from_tuples(raw.columns) + + with tempfile.TemporaryDirectory() as tmpdir: + with mock.patch.object(data_manager, "DATA_DIR", tmpdir): + with mock.patch("data_manager.yf.download", return_value=raw): + panels = data_manager.update_market_data("us", ["AAA"], ["close", "volume"]) + + expected_close = pd.DataFrame({"AAA": [10.0, 11.0]}, index=dates) + expected_volume = pd.DataFrame({"AAA": [1000.0, 1100.0]}, index=dates) + pd.testing.assert_frame_equal(panels["close"], expected_close, check_freq=False) + pd.testing.assert_frame_equal(panels["volume"], expected_volume, check_freq=False) + if __name__ == "__main__": unittest.main()