Skip to content

Commit 2db43b0

Browse files
committed
Update causal.py
1 parent 1a76600 commit 2db43b0

File tree

1 file changed

+140
-38
lines changed

1 file changed

+140
-38
lines changed

pymc_marketing/mmm/causal.py

Lines changed: 140 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import re
2020
import warnings
2121
from collections.abc import Sequence
22+
from typing import Annotated, Literal
2223

2324
try:
2425
import networkx as nx
@@ -637,26 +638,43 @@ class TBFPC:
637638
- Kass, R. & Raftery, A. (1995). "Bayes Factors."
638639
"""
639640

641+
@validate_call(config=dict(arbitrary_types_allowed=True))
640642
def __init__(
641643
self,
642-
target: str,
644+
target: Annotated[
645+
str,
646+
Field(
647+
min_length=1,
648+
description="Name of the outcome variable to orient the search.",
649+
),
650+
],
643651
*,
644-
target_edge_rule: str = "any",
645-
bf_thresh: float = 1.0,
652+
target_edge_rule: Literal["any", "conservative", "fullS"] = "any",
653+
bf_thresh: Annotated[float, Field(gt=0.0)] = 1.0,
646654
forbidden_edges: Sequence[tuple[str, str]] | None = None,
647655
):
656+
"""Create a new TBFPC causal discovery model.
657+
658+
Parameters
659+
----------
660+
target
661+
Variable name for the model outcome; must be present in the data
662+
used during fitting.
663+
target_edge_rule
664+
Rule that controls which driver → target edges are retained.
665+
Options are ``"any"``, ``"conservative"``, and ``"fullS"``.
666+
bf_thresh
667+
Positive Bayes factor threshold applied during conditional
668+
independence tests.
669+
forbidden_edges
670+
Optional sequence of node pairs that must not be connected in the
671+
learned graph.
672+
"""
648673
warnings.warn(
649674
"TBFPC is experimental and its API may change; use with caution.",
650675
UserWarning,
651676
stacklevel=2,
652677
)
653-
if not isinstance(target, str) or not target:
654-
raise ValueError("target must be a non-empty string")
655-
allowed_rules = {"any", "conservative", "fullS"}
656-
if target_edge_rule not in allowed_rules:
657-
raise ValueError(f"target_edge_rule must be one of {allowed_rules}")
658-
if not isinstance(bf_thresh, (int, float)) or bf_thresh <= 0:
659-
raise ValueError("bf_thresh must be a positive float")
660678

661679
self.target = target
662680
self.target_edge_rule = target_edge_rule
@@ -670,7 +688,8 @@ def __init__(
670688
self.nodes_: list[str] = []
671689
self.test_results: dict[tuple[str, str, frozenset], dict[str, float]] = {}
672690

673-
# Shared response vector for symbolic BIC
691+
# Shared response vector for symbolic BIC computation
692+
# Initialized with placeholder; will be updated with actual data during fitting
674693
self.y_sh = pytensor.shared(np.zeros(1, dtype="float64"), name="y_sh")
675694
self._bic_fn = self._build_symbolic_bic_fn()
676695

@@ -708,17 +727,47 @@ def _remove_all(self, u: str, v: str) -> None:
708727
self._adj_directed.discard((v, u))
709728

710729
def _build_symbolic_bic_fn(self):
711-
"""Build and compile a function to compute BIC given a design matrix ``X`` and sample size ``n``."""
730+
"""Build a BIC callable using a fast solver with a pseudoinverse fallback."""
712731
X = tt.matrix("X")
713732
n = tt.iscalar("n")
714733

715-
beta = tt.nlinalg.pinv(X) @ self.y_sh
716-
resid = self.y_sh - X @ beta
717-
rss = tt.sum(resid**2)
734+
xtx = tt.dot(X.T, X)
735+
xty = tt.dot(X.T, self.y_sh)
736+
737+
beta_solve = tt.linalg.solve(xtx, xty)
738+
resid_solve = self.y_sh - tt.dot(X, beta_solve)
739+
rss_solve = tt.sum(resid_solve**2)
740+
741+
beta_pinv = tt.nlinalg.pinv(X) @ self.y_sh
742+
resid_pinv = self.y_sh - tt.dot(X, beta_pinv)
743+
rss_pinv = tt.sum(resid_pinv**2)
744+
718745
k = X.shape[1]
719746

720-
bic = n * tt.log(rss / n) + k * tt.log(n)
721-
return pytensor.function([X, n], bic)
747+
nf = tt.cast(n, "float64")
748+
rss_solve_safe = tt.maximum(rss_solve, np.finfo("float64").tiny)
749+
rss_pinv_safe = tt.maximum(rss_pinv, np.finfo("float64").tiny)
750+
751+
bic_solve = nf * tt.log(rss_solve_safe / nf) + k * tt.log(nf)
752+
bic_pinv = nf * tt.log(rss_pinv_safe / nf) + k * tt.log(nf)
753+
754+
bic_solve_fn = pytensor.function(
755+
[X, n], [bic_solve, rss_solve], on_unused_input="ignore", mode="FAST_RUN"
756+
)
757+
bic_pinv_fn = pytensor.function(
758+
[X, n], bic_pinv, on_unused_input="ignore", mode="FAST_RUN"
759+
)
760+
761+
def bic_fn(X_val: np.ndarray, n_val: int) -> float:
762+
try:
763+
bic_value, rss_value = bic_solve_fn(X_val, n_val)
764+
if np.isfinite(rss_value) and rss_value > np.finfo("float64").tiny:
765+
return float(bic_value)
766+
except (np.linalg.LinAlgError, RuntimeError, ValueError):
767+
pass
768+
return float(bic_pinv_fn(X_val, n_val))
769+
770+
return bic_fn
722771

723772
def _ci_independent(
724773
self, df: pd.DataFrame, x: str, y: str, cond: Sequence[str]
@@ -923,30 +972,50 @@ class TBF_FCI:
923972
Whether to allow contemporaneous edges at time t.
924973
"""
925974

975+
@validate_call(config=dict(arbitrary_types_allowed=True))
926976
def __init__(
927977
self,
928-
target: str,
978+
target: Annotated[
979+
str,
980+
Field(
981+
min_length=1,
982+
description="Name of the outcome variable at time t.",
983+
),
984+
],
929985
*,
930-
target_edge_rule: str = "any",
931-
bf_thresh: float = 1.0,
986+
target_edge_rule: Literal["any", "conservative", "fullS"] = "any",
987+
bf_thresh: Annotated[float, Field(gt=0.0)] = 1.0,
932988
forbidden_edges: Sequence[tuple[str, str]] | None = None,
933-
max_lag: int = 2,
989+
max_lag: Annotated[int, Field(ge=0)] = 2,
934990
allow_contemporaneous: bool = True,
935991
):
992+
"""Create a new temporal TBF-PC causal discovery model.
993+
994+
Parameters
995+
----------
996+
target
997+
Target variable name at time ``t`` that the algorithm orients
998+
toward.
999+
target_edge_rule
1000+
Rule used to retain lagged → target edges. Choose from
1001+
``"any"``, ``"conservative"``, or ``"fullS"``.
1002+
bf_thresh
1003+
Positive Bayes factor threshold applied during conditional
1004+
independence testing.
1005+
forbidden_edges
1006+
Optional sequence of node pairs that must be excluded from the
1007+
final graph.
1008+
max_lag
1009+
Maximum lag (inclusive) to consider when constructing temporal
1010+
drivers.
1011+
allow_contemporaneous
1012+
Whether contemporaneous edges at time ``t`` are permitted.
1013+
"""
9361014
warnings.warn(
9371015
"TBF_FCI is experimental and its API may change; use with caution.",
9381016
UserWarning,
9391017
stacklevel=2,
9401018
)
941-
if not isinstance(target, str) or not target:
942-
raise ValueError("target must be a non-empty string")
943-
allowed_rules = {"any", "conservative", "fullS"}
944-
if target_edge_rule not in allowed_rules:
945-
raise ValueError(f"target_edge_rule must be one of {allowed_rules}")
946-
if not isinstance(bf_thresh, (int, float)) or bf_thresh <= 0:
947-
raise ValueError("bf_thresh must be a positive float")
948-
if not isinstance(max_lag, int) or max_lag < 0:
949-
raise ValueError("max_lag must be a non-negative integer")
9501019

9511020
self.target = target
9521021
self.target_edge_rule = target_edge_rule
@@ -961,6 +1030,8 @@ def __init__(
9611030
self.nodes_: list[str] = []
9621031
self.test_results: dict[tuple[str, str, frozenset], dict[str, float]] = {}
9631032

1033+
# Shared response vector for symbolic BIC computation
1034+
# Initialized with placeholder; will be updated with actual data during fitting
9641035
self.y_sh = pytensor.shared(np.zeros(1, dtype="float64"), name="y_sh")
9651036
self._bic_fn = self._build_symbolic_bic_fn()
9661037

@@ -1048,14 +1119,43 @@ def _remove_all(self, u: str, v: str) -> None:
10481119
self._adj_directed.discard((v, u))
10491120

10501121
def _build_symbolic_bic_fn(self):
1122+
"""Build a BIC callable using a fast solver with a pseudoinverse fallback."""
10511123
X = tt.matrix("X")
10521124
n = tt.iscalar("n")
1053-
beta = tt.nlinalg.pinv(X) @ self.y_sh
1054-
resid = self.y_sh - X @ beta
1055-
rss = tt.sum(resid**2)
1125+
1126+
xtx = tt.dot(X.T, X)
1127+
xty = tt.dot(X.T, self.y_sh)
1128+
1129+
beta_solve = tt.linalg.solve(xtx, xty)
1130+
resid_solve = self.y_sh - tt.dot(X, beta_solve)
1131+
rss_solve = tt.sum(resid_solve**2)
1132+
1133+
beta_pinv = tt.nlinalg.pinv(X) @ self.y_sh
1134+
resid_pinv = self.y_sh - tt.dot(X, beta_pinv)
1135+
rss_pinv = tt.sum(resid_pinv**2)
1136+
10561137
k = X.shape[1]
1057-
bic = n * tt.log(rss / n) + k * tt.log(n)
1058-
return pytensor.function([X, n], bic)
1138+
1139+
bic_solve = n * tt.log(rss_solve / n) + k * tt.log(n)
1140+
bic_pinv = n * tt.log(rss_pinv / n) + k * tt.log(n)
1141+
1142+
bic_solve_fn = pytensor.function(
1143+
[X, n], bic_solve, on_unused_input="ignore", mode="FAST_RUN"
1144+
)
1145+
bic_pinv_fn = pytensor.function(
1146+
[X, n], bic_pinv, on_unused_input="ignore", mode="FAST_RUN"
1147+
)
1148+
1149+
def bic_fn(X_val: np.ndarray, n_val: int) -> float:
1150+
try:
1151+
value = float(bic_solve_fn(X_val, n_val))
1152+
if np.isfinite(value):
1153+
return value
1154+
except (np.linalg.LinAlgError, RuntimeError, ValueError):
1155+
pass
1156+
return float(bic_pinv_fn(X_val, n_val))
1157+
1158+
return bic_fn
10591159

10601160
def _ci_independent(
10611161
self, df: pd.DataFrame, x: str, y: str, cond: Sequence[str]
@@ -1194,15 +1294,17 @@ def fit(self, df: pd.DataFrame, drivers: Sequence[str]):
11941294
self._stageB_contemporaneous(L, drivers)
11951295
return self
11961296

1197-
def collapsed_summary(self):
1198-
collapsed_directed = []
1297+
def collapsed_summary(self) -> tuple[list[tuple[str, str, int]], list[tuple[str, str]]]:
1298+
"""Return collapsed summary of lagged directed and undirected edges."""
1299+
1300+
collapsed_directed: list[tuple[str, str, int]] = []
11991301
for u, v in self._adj_directed:
12001302
base_u, lag_u = self._parse_lag(u)
12011303
base_v, lag_v = self._parse_lag(v)
12021304
if lag_v == 0:
12031305
collapsed_directed.append((base_u, base_v, lag_u))
12041306

1205-
collapsed_undirected = []
1307+
collapsed_undirected: list[tuple[str, str]] = []
12061308
for u, v in self._adj_undirected:
12071309
base_u, lag_u = self._parse_lag(u)
12081310
base_v, lag_v = self._parse_lag(v)

0 commit comments

Comments
 (0)