feat: add PIT-aware tradable universe mask

This commit is contained in:
2026-04-18 00:23:07 +08:00
parent 1edce83430
commit 7853eafe55
2 changed files with 266 additions and 0 deletions

213
tests/test_us_universe.py Normal file
View File

@@ -0,0 +1,213 @@
import unittest
import warnings
import pandas as pd
class BuildTradableMaskTests(unittest.TestCase):
    """Behavioural tests for ``research.us_universe.build_tradable_mask``.

    Each scenario builds small daily ``close`` / ``volume`` frames (plus an
    optional point-in-time membership frame), calls ``build_tradable_mask``
    with explicit thresholds and compares the result with a hand-written
    boolean frame.

    ``build_tradable_mask`` is imported inside each test, so a missing or
    broken ``research.us_universe`` module is reported per test instead of
    when this module is imported.
    """

    def _assert_mask(self, mask, dates, expected_columns):
        """Assert *mask* equals a ``bool`` frame of *expected_columns* on *dates*."""
        expected = pd.DataFrame(expected_columns, index=dates, dtype=bool)
        pd.testing.assert_frame_equal(mask, expected)

    def test_build_tradable_mask_uses_only_lagged_price_and_liquidity_inputs(self):
        """Only the last of four days is expected to be tradable.

        The first row has a close of 4.0 (``min_price`` is 5.0) and a NaN
        volume; two valid volume days are required in a two-day window.
        """
        from research.us_universe import build_tradable_mask

        dates = pd.date_range("2024-01-01", periods=4, freq="D")
        close = pd.DataFrame({"AAA": [4.0, 10.0, 10.0, 10.0]}, index=dates)
        volume = pd.DataFrame(
            {"AAA": [float("nan"), 200.0, 200.0, 200.0]},
            index=dates,
        )

        mask = build_tradable_mask(
            close=close,
            volume=volume,
            pit_membership=None,
            min_price=5.0,
            min_dollar_volume=1_000.0,
            min_history_days=2,
            min_valid_volume_days=2,
            liquidity_window=2,
        )

        self._assert_mask(mask, dates, {"AAA": [False, False, False, True]})

    def test_build_tradable_mask_uses_only_lagged_history(self):
        """A NaN close on day two keeps all four days untradable.

        Two history days are required here, while the liquidity settings
        are relaxed to a one-day window with one valid day.
        """
        from research.us_universe import build_tradable_mask

        dates = pd.date_range("2024-01-01", periods=4, freq="D")
        close = pd.DataFrame(
            {"AAA": [10.0, float("nan"), 10.0, 10.0]},
            index=dates,
        )
        volume = pd.DataFrame({"AAA": [200.0, 200.0, 200.0, 200.0]}, index=dates)

        mask = build_tradable_mask(
            close=close,
            volume=volume,
            pit_membership=None,
            min_price=5.0,
            min_dollar_volume=1_000.0,
            min_history_days=2,
            min_valid_volume_days=1,
            liquidity_window=1,
        )

        self._assert_mask(mask, dates, {"AAA": [False, False, False, False]})

    def test_build_tradable_mask_requires_membership_history_before_first_eligible_day(self):
        """Membership turning True on day three yields a True mask only on day four."""
        from research.us_universe import build_tradable_mask

        dates = pd.date_range("2024-01-01", periods=4, freq="D")
        close = pd.DataFrame({"AAA": [10.0, 10.0, 10.0, 10.0]}, index=dates)
        volume = pd.DataFrame({"AAA": [200.0, 200.0, 200.0, 200.0]}, index=dates)
        pit_membership = pd.DataFrame(
            {"AAA": [False, False, True, True]},
            index=dates,
        )

        mask = build_tradable_mask(
            close=close,
            volume=volume,
            pit_membership=pit_membership,
            min_price=5.0,
            min_dollar_volume=1_000.0,
            min_history_days=1,
            min_valid_volume_days=1,
            liquidity_window=1,
        )

        self._assert_mask(mask, dates, {"AAA": [False, False, False, True]})

    def test_build_tradable_mask_aligns_pit_membership_without_truthy_carryover(self):
        """Membership with different tickers and dates is aligned to the prices.

        The membership frame covers BBB and CCC on an index starting one day
        after the price index. The expected mask keeps the price frame's
        columns and dates: AAA (absent from membership) is never tradable,
        BBB is tradable only on the last day, and no warning may be emitted
        by the call.
        """
        from research.us_universe import build_tradable_mask

        dates = pd.date_range("2024-01-01", periods=3, freq="D")
        close = pd.DataFrame(
            {
                "AAA": [10.0, 10.0, 10.0],
                "BBB": [12.0, 12.0, 12.0],
            },
            index=dates,
        )
        volume = pd.DataFrame(
            {
                "AAA": [1_000_000.0, 1_000_000.0, 1_000_000.0],
                "BBB": [1_000_000.0, 1_000_000.0, 1_000_000.0],
            },
            index=dates,
        )
        pit_membership = pd.DataFrame(
            {
                "BBB": [True, True, False],
                "CCC": [True, True, True],
            },
            index=pd.date_range("2024-01-02", periods=3, freq="D"),
        )

        with warnings.catch_warnings(record=True) as caught:
            # Record every warning, including ones normally shown only once.
            warnings.simplefilter("always")
            mask = build_tradable_mask(
                close=close,
                volume=volume,
                pit_membership=pit_membership,
                min_price=5.0,
                min_dollar_volume=1_000.0,
                min_history_days=1,
                min_valid_volume_days=1,
                liquidity_window=1,
            )

        # Compare the messages rather than a count so that a failure shows
        # which warnings were raised.
        self.assertEqual([str(item.message) for item in caught], [])
        self._assert_mask(
            mask,
            dates,
            {
                "AAA": [False, False, False],
                "BBB": [False, False, True],
            },
        )

    def test_build_tradable_mask_treats_missing_membership_cells_as_false(self):
        """A ``pd.NA`` membership cell on the middle day leaves every day False."""
        from research.us_universe import build_tradable_mask

        dates = pd.date_range("2024-01-01", periods=3, freq="D")
        close = pd.DataFrame({"AAA": [10.0, 10.0, 10.0]}, index=dates)
        volume = pd.DataFrame(
            {"AAA": [1_000_000.0, 1_000_000.0, 1_000_000.0]},
            index=dates,
        )
        pit_membership = pd.DataFrame(
            {"AAA": [True, pd.NA, True]},
            index=dates,
            dtype="boolean",
        )

        mask = build_tradable_mask(
            close=close,
            volume=volume,
            pit_membership=pit_membership,
            min_price=5.0,
            min_dollar_volume=1_000.0,
            min_history_days=1,
            min_valid_volume_days=1,
            liquidity_window=1,
        )

        self._assert_mask(mask, dates, {"AAA": [False, False, False]})

    def test_build_tradable_mask_uses_strict_thresholds(self):
        """A close exactly equal to ``min_price`` is never tradable."""
        from research.us_universe import build_tradable_mask

        dates = pd.date_range("2024-01-01", periods=3, freq="D")
        close = pd.DataFrame({"AAA": [5.0, 5.0, 5.0]}, index=dates)
        volume = pd.DataFrame({"AAA": [300.0, 300.0, 300.0]}, index=dates)

        mask = build_tradable_mask(
            close=close,
            volume=volume,
            pit_membership=None,
            min_price=5.0,
            min_dollar_volume=1_000.0,
            min_history_days=1,
            min_valid_volume_days=1,
            liquidity_window=1,
        )

        self._assert_mask(mask, dates, {"AAA": [False, False, False]})

    def test_build_tradable_mask_uses_strict_dollar_volume_threshold(self):
        """Dollar volume exactly equal to ``min_dollar_volume`` is never tradable.

        Here close * volume is 8.0 * 125.0 = 1000.0 (assuming dollar volume
        is computed as close times volume -- confirm against the
        implementation).
        """
        from research.us_universe import build_tradable_mask

        dates = pd.date_range("2024-01-01", periods=3, freq="D")
        close = pd.DataFrame({"AAA": [8.0, 8.0, 8.0]}, index=dates)
        volume = pd.DataFrame({"AAA": [125.0, 125.0, 125.0]}, index=dates)

        mask = build_tradable_mask(
            close=close,
            volume=volume,
            pit_membership=None,
            min_price=5.0,
            min_dollar_volume=1_000.0,
            min_history_days=1,
            min_valid_volume_days=1,
            liquidity_window=1,
        )

        self._assert_mask(mask, dates, {"AAA": [False, False, False]})

    def test_build_tradable_mask_requires_valid_dollar_volume_history(self):
        """A NaN close on day two leaves all four days untradable.

        Only one history day is required, but two valid days are required
        inside a two-day liquidity window.
        """
        from research.us_universe import build_tradable_mask

        dates = pd.date_range("2024-01-01", periods=4, freq="D")
        close = pd.DataFrame(
            {"AAA": [10.0, float("nan"), 10.0, 10.0]},
            index=dates,
        )
        volume = pd.DataFrame({"AAA": [200.0, 200.0, 200.0, 200.0]}, index=dates)

        mask = build_tradable_mask(
            close=close,
            volume=volume,
            pit_membership=None,
            min_price=5.0,
            min_dollar_volume=1_000.0,
            min_history_days=1,
            min_valid_volume_days=2,
            liquidity_window=2,
        )

        self._assert_mask(mask, dates, {"AAA": [False, False, False, False]})
if __name__ == "__main__":
    # Run the test cases when this file is executed directly as a script.
    unittest.main()