fix: handle single-ticker yahoo panels
This commit is contained in:
@@ -63,10 +63,11 @@ def _download(tickers: list[str], start: str, end: str | None = None,
|
||||
result = {}
|
||||
for field in fields:
|
||||
if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns:
|
||||
if len(tickers) > 1:
|
||||
result[field] = raw[field]
|
||||
selected = raw[field]
|
||||
if isinstance(selected, pd.Series):
|
||||
result[field] = selected.to_frame(name=tickers[0])
|
||||
else:
|
||||
result[field] = raw[field].to_frame(name=tickers[0])
|
||||
result[field] = selected
|
||||
else:
|
||||
result[field] = pd.DataFrame()
|
||||
return result
|
||||
@@ -83,10 +84,11 @@ def _download_period(tickers: list[str], period: str,
|
||||
result = {}
|
||||
for field in fields:
|
||||
if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns:
|
||||
if len(tickers) > 1:
|
||||
result[field] = raw[field]
|
||||
selected = raw[field]
|
||||
if isinstance(selected, pd.Series):
|
||||
result[field] = selected.to_frame(name=tickers[0])
|
||||
else:
|
||||
result[field] = raw[field].to_frame(name=tickers[0])
|
||||
result[field] = selected
|
||||
else:
|
||||
result[field] = pd.DataFrame()
|
||||
return result
|
||||
|
||||
@@ -118,6 +118,27 @@ class UpdateMarketDataTests(unittest.TestCase):
|
||||
pd.testing.assert_frame_equal(panels["volume"], expected, check_freq=False)
|
||||
pd.testing.assert_frame_equal(saved_volume, expected, check_freq=False)
|
||||
|
||||
def test_update_market_data_handles_single_ticker_multiindex_download(self):
|
||||
dates = pd.to_datetime(["2024-01-02", "2024-01-03"])
|
||||
raw = pd.DataFrame(
|
||||
{
|
||||
("Close", "AAA"): [10.0, 11.0],
|
||||
("Volume", "AAA"): [1000.0, 1100.0],
|
||||
},
|
||||
index=dates,
|
||||
)
|
||||
raw.columns = pd.MultiIndex.from_tuples(raw.columns)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
|
||||
with mock.patch("data_manager.yf.download", return_value=raw):
|
||||
panels = data_manager.update_market_data("us", ["AAA"], ["close", "volume"])
|
||||
|
||||
expected_close = pd.DataFrame({"AAA": [10.0, 11.0]}, index=dates)
|
||||
expected_volume = pd.DataFrame({"AAA": [1000.0, 1100.0]}, index=dates)
|
||||
pd.testing.assert_frame_equal(panels["close"], expected_close, check_freq=False)
|
||||
pd.testing.assert_frame_equal(panels["volume"], expected_volume, check_freq=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user