feat: add OHLCV market data updater
This commit is contained in:
123
tests/test_market_data.py
Normal file
123
tests/test_market_data.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import data_manager
|
||||
|
||||
|
||||
class UpdateMarketDataTests(unittest.TestCase):
|
||||
def test_update_market_data_accepts_lowercase_fields_and_does_not_fill_volume(self):
|
||||
dates = pd.to_datetime(["2024-01-02", "2024-01-03", "2024-01-04"])
|
||||
raw = pd.DataFrame(
|
||||
{
|
||||
("Close", "AAA"): [10.0, 11.0, 12.0],
|
||||
("Close", "BBB"): [20.0, float("nan"), 22.0],
|
||||
("Open", "AAA"): [9.5, 10.5, 11.5],
|
||||
("Open", "BBB"): [19.5, 20.5, 21.5],
|
||||
("High", "AAA"): [10.5, 11.5, 12.5],
|
||||
("High", "BBB"): [20.5, 21.5, 22.5],
|
||||
("Low", "AAA"): [9.0, 10.0, 11.0],
|
||||
("Low", "BBB"): [19.0, 20.0, 21.0],
|
||||
("Volume", "AAA"): [1000, 1100, 1200],
|
||||
("Volume", "BBB"): [2000, float("nan"), 2200],
|
||||
},
|
||||
index=dates,
|
||||
)
|
||||
raw.columns = pd.MultiIndex.from_tuples(raw.columns)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
|
||||
with mock.patch("data_manager.yf.download", return_value=raw) as mocked_download:
|
||||
panels = data_manager.update_market_data(
|
||||
"us",
|
||||
["AAA", "BBB"],
|
||||
["close", "open", "high", "low", "volume"],
|
||||
)
|
||||
self.assertEqual(set(panels), {"close", "open", "high", "low", "volume"})
|
||||
self.assertEqual(panels["close"].loc[dates[1], "BBB"], 20.0)
|
||||
self.assertTrue(pd.isna(panels["volume"].loc[dates[1], "BBB"]))
|
||||
self.assertTrue((Path(tmpdir) / "us.csv").exists())
|
||||
self.assertTrue((Path(tmpdir) / "us_open.csv").exists())
|
||||
self.assertTrue((Path(tmpdir) / "us_high.csv").exists())
|
||||
self.assertTrue((Path(tmpdir) / "us_low.csv").exists())
|
||||
self.assertTrue((Path(tmpdir) / "us_volume.csv").exists())
|
||||
|
||||
saved_high = pd.read_csv(Path(tmpdir) / "us_high.csv", index_col=0, parse_dates=True)
|
||||
pd.testing.assert_frame_equal(saved_high, panels["high"], check_freq=False)
|
||||
|
||||
self.assertEqual(mocked_download.call_args.args[0], ["AAA", "BBB"])
|
||||
self.assertEqual(mocked_download.call_args.kwargs["auto_adjust"], True)
|
||||
self.assertIn("start", mocked_download.call_args.kwargs)
|
||||
|
||||
def test_update_market_data_rejects_unsupported_fields(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
|
||||
with self.assertRaisesRegex(ValueError, "Unsupported market data field: adjusted_close"):
|
||||
data_manager.update_market_data("us", ["AAA"], ["adjusted_close"])
|
||||
|
||||
def test_update_market_data_preserves_existing_cache_columns_and_dates(self):
|
||||
existing_dates = pd.to_datetime(["2024-01-01", "2024-01-02"])
|
||||
new_dates = pd.to_datetime(["2024-01-02", "2024-01-03"])
|
||||
existing_close = pd.DataFrame(
|
||||
{
|
||||
"AAA": [9.0, 10.0],
|
||||
"CCC": [30.0, 31.0],
|
||||
},
|
||||
index=existing_dates,
|
||||
)
|
||||
downloaded_close = pd.DataFrame({"Close": [10.5, 11.5]}, index=new_dates)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
existing_close.to_csv(Path(tmpdir) / "us.csv")
|
||||
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
|
||||
with mock.patch("data_manager.yf.download", return_value=downloaded_close):
|
||||
panels = data_manager.update_market_data("us", ["AAA"], ["close"])
|
||||
|
||||
expected = pd.DataFrame(
|
||||
{
|
||||
"AAA": [9.0, 10.5, 11.5],
|
||||
"CCC": [30.0, 31.0, float("nan")],
|
||||
},
|
||||
index=pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
|
||||
)
|
||||
saved_close = pd.read_csv(Path(tmpdir) / "us.csv", index_col=0, parse_dates=True)
|
||||
|
||||
pd.testing.assert_frame_equal(panels["close"], expected, check_freq=False)
|
||||
pd.testing.assert_frame_equal(saved_close, expected, check_freq=False)
|
||||
|
||||
def test_update_market_data_volume_merge_can_clear_stale_cached_values(self):
|
||||
existing_dates = pd.to_datetime(["2024-01-01", "2024-01-02"])
|
||||
new_dates = pd.to_datetime(["2024-01-02", "2024-01-03", "2024-01-04"])
|
||||
existing_volume = pd.DataFrame(
|
||||
{
|
||||
"AAA": [1000.0, 9999.0],
|
||||
"CCC": [3000.0, 3100.0],
|
||||
},
|
||||
index=existing_dates,
|
||||
)
|
||||
downloaded_volume = pd.DataFrame({"Volume": [float("nan"), 1200.0, 1300.0]}, index=new_dates)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
existing_volume.to_csv(Path(tmpdir) / "us_volume.csv")
|
||||
with mock.patch.object(data_manager, "DATA_DIR", tmpdir):
|
||||
with mock.patch("data_manager.yf.download", return_value=downloaded_volume):
|
||||
panels = data_manager.update_market_data("us", ["AAA"], ["volume"])
|
||||
|
||||
expected = pd.DataFrame(
|
||||
{
|
||||
"AAA": [1000.0, float("nan"), 1200.0, 1300.0],
|
||||
"CCC": [3000.0, 3100.0, float("nan"), float("nan")],
|
||||
},
|
||||
index=pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"]),
|
||||
)
|
||||
saved_volume = pd.read_csv(Path(tmpdir) / "us_volume.csv", index_col=0, parse_dates=True)
|
||||
|
||||
pd.testing.assert_frame_equal(panels["volume"], expected, check_freq=False)
|
||||
pd.testing.assert_frame_equal(saved_volume, expected, check_freq=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user