fix: handle single-ticker yahoo panels
This commit is contained in:
@@ -63,10 +63,11 @@ def _download(tickers: list[str], start: str, end: str | None = None,
|
|||||||
result = {}
|
result = {}
|
||||||
for field in fields:
|
for field in fields:
|
||||||
if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns:
|
if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns:
|
||||||
if len(tickers) > 1:
|
selected = raw[field]
|
||||||
result[field] = raw[field]
|
if isinstance(selected, pd.Series):
|
||||||
|
result[field] = selected.to_frame(name=tickers[0])
|
||||||
else:
|
else:
|
||||||
result[field] = raw[field].to_frame(name=tickers[0])
|
result[field] = selected
|
||||||
else:
|
else:
|
||||||
result[field] = pd.DataFrame()
|
result[field] = pd.DataFrame()
|
||||||
return result
|
return result
|
||||||
@@ -83,10 +84,11 @@ def _download_period(tickers: list[str], period: str,
|
|||||||
result = {}
|
result = {}
|
||||||
for field in fields:
|
for field in fields:
|
||||||
if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns:
|
if field in raw.columns.get_level_values(0) if isinstance(raw.columns, pd.MultiIndex) else field in raw.columns:
|
||||||
if len(tickers) > 1:
|
selected = raw[field]
|
||||||
result[field] = raw[field]
|
if isinstance(selected, pd.Series):
|
||||||
|
result[field] = selected.to_frame(name=tickers[0])
|
||||||
else:
|
else:
|
||||||
result[field] = raw[field].to_frame(name=tickers[0])
|
result[field] = selected
|
||||||
else:
|
else:
|
||||||
result[field] = pd.DataFrame()
|
result[field] = pd.DataFrame()
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -118,6 +118,27 @@ class UpdateMarketDataTests(unittest.TestCase):
|
|||||||
pd.testing.assert_frame_equal(panels["volume"], expected, check_freq=False)
|
pd.testing.assert_frame_equal(panels["volume"], expected, check_freq=False)
|
||||||
pd.testing.assert_frame_equal(saved_volume, expected, check_freq=False)
|
pd.testing.assert_frame_equal(saved_volume, expected, check_freq=False)
|
||||||
|
|
||||||
|
def test_update_market_data_handles_single_ticker_multiindex_download(self):
|
||||||
|
dates = pd.to_datetime(["2024-01-02", "2024-01-03"])
|
||||||
|
raw = pd.DataFrame(
|
||||||
|
{
|
||||||
|
("Close", "AAA"): [10.0, 11.0],
|
||||||
|
("Volume", "AAA"): [1000.0, 1100.0],
|
||||||
|
},
|
||||||
|
index=dates,
|
||||||
|
)
|
||||||
|
raw.columns = pd.MultiIndex.from_tuples(raw.columns)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
|
||||||
|
with mock.patch("data_manager.yf.download", return_value=raw):
|
||||||
|
panels = data_manager.update_market_data("us", ["AAA"], ["close", "volume"])
|
||||||
|
|
||||||
|
expected_close = pd.DataFrame({"AAA": [10.0, 11.0]}, index=dates)
|
||||||
|
expected_volume = pd.DataFrame({"AAA": [1000.0, 1100.0]}, index=dates)
|
||||||
|
pd.testing.assert_frame_equal(panels["close"], expected_close, check_freq=False)
|
||||||
|
pd.testing.assert_frame_equal(panels["volume"], expected_volume, check_freq=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user