Skip to content

Commit b7e097c

Browse files
committed
Update causal.py
1 parent abec212 commit b7e097c

File tree

1 file changed

+141
-42
lines changed

1 file changed

+141
-42
lines changed

pymc_marketing/mmm/causal.py

Lines changed: 141 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import itertools as it
1717
import warnings
1818
from collections.abc import Sequence
19+
from typing import Annotated, Literal
1920

2021
import numpy as np
2122
import pandas as pd
2223
import pytensor
23-
import pytensor.tensor as tt
24+
import pytensor.tensor as pt
25+
from pydantic import Field, validate_call
2426

2527
try:
2628
from dowhy import CausalModel
@@ -156,26 +158,43 @@ class TBFPC:
156158
- Kass, R. & Raftery, A. (1995). "Bayes Factors."
157159
"""
158160

161+
@validate_call(config=dict(arbitrary_types_allowed=True))
159162
def __init__(
160163
self,
161-
target: str,
164+
target: Annotated[
165+
str,
166+
Field(
167+
min_length=1,
168+
description="Name of the outcome variable to orient the search.",
169+
),
170+
],
162171
*,
163-
target_edge_rule: str = "any",
164-
bf_thresh: float = 1.0,
172+
target_edge_rule: Literal["any", "conservative", "fullS"] = "any",
173+
bf_thresh: Annotated[float, Field(gt=0.0)] = 1.0,
165174
forbidden_edges: Sequence[tuple[str, str]] | None = None,
166175
):
176+
"""Create a new TBFPC causal discovery model.
177+
178+
Parameters
179+
----------
180+
target
181+
Variable name for the model outcome; must be present in the data
182+
used during fitting.
183+
target_edge_rule
184+
Rule that controls which driver → target edges are retained.
185+
Options are ``"any"``, ``"conservative"``, and ``"fullS"``.
186+
bf_thresh
187+
Positive Bayes factor threshold applied during conditional
188+
independence tests.
189+
forbidden_edges
190+
Optional sequence of node pairs that must not be connected in the
191+
learned graph.
192+
"""
167193
warnings.warn(
168194
"TBFPC is experimental and its API may change; use with caution.",
169195
UserWarning,
170196
stacklevel=2,
171197
)
172-
if not isinstance(target, str) or not target:
173-
raise ValueError("target must be a non-empty string")
174-
allowed_rules = {"any", "conservative", "fullS"}
175-
if target_edge_rule not in allowed_rules:
176-
raise ValueError(f"target_edge_rule must be one of {allowed_rules}")
177-
if not isinstance(bf_thresh, (int, float)) or bf_thresh <= 0:
178-
raise ValueError("bf_thresh must be a positive float")
179198

180199
self.target = target
181200
self.target_edge_rule = target_edge_rule
@@ -189,7 +208,8 @@ def __init__(
189208
self.nodes_: list[str] = []
190209
self.test_results: dict[tuple[str, str, frozenset], dict[str, float]] = {}
191210

192-
# Shared response vector for symbolic BIC
211+
# Shared response vector for symbolic BIC computation
212+
# Initialized with placeholder; will be updated with actual data during fitting
193213
self.y_sh = pytensor.shared(np.zeros(1, dtype="float64"), name="y_sh")
194214
self._bic_fn = self._build_symbolic_bic_fn()
195215

@@ -233,17 +253,47 @@ def _remove_all(self, u: str, v: str) -> None:
233253
# Statistical methods
234254
# ---------------------------------------------------------------------
235255
def _build_symbolic_bic_fn(self):
236-
"""Build and compile a function to compute BIC given a design matrix ``X`` and sample size ``n``."""
237-
X = tt.matrix("X")
238-
n = tt.iscalar("n")
256+
"""Build a BIC callable using a fast solver with a pseudoinverse fallback."""
257+
X = pt.matrix("X")
258+
n = pt.iscalar("n")
259+
260+
xtx = pt.dot(X.T, X)
261+
xty = pt.dot(X.T, self.y_sh)
262+
263+
beta_solve = pt.linalg.solve(xtx, xty)
264+
resid_solve = self.y_sh - pt.dot(X, beta_solve)
265+
rss_solve = pt.sum(resid_solve**2)
266+
267+
beta_pinv = pt.nlinalg.pinv(X) @ self.y_sh
268+
resid_pinv = self.y_sh - pt.dot(X, beta_pinv)
269+
rss_pinv = pt.sum(resid_pinv**2)
239270

240-
beta = tt.nlinalg.pinv(X) @ self.y_sh
241-
resid = self.y_sh - X @ beta
242-
rss = tt.sum(resid**2)
243271
k = X.shape[1]
244272

245-
bic = n * tt.log(rss / n) + k * tt.log(n)
246-
return pytensor.function([X, n], bic)
273+
nf = pt.cast(n, "float64")
274+
rss_solve_safe = pt.maximum(rss_solve, np.finfo("float64").tiny)
275+
rss_pinv_safe = pt.maximum(rss_pinv, np.finfo("float64").tiny)
276+
277+
bic_solve = nf * pt.log(rss_solve_safe / nf) + k * pt.log(nf)
278+
bic_pinv = nf * pt.log(rss_pinv_safe / nf) + k * pt.log(nf)
279+
280+
bic_solve_fn = pytensor.function(
281+
[X, n], [bic_solve, rss_solve], on_unused_input="ignore", mode="FAST_RUN"
282+
)
283+
bic_pinv_fn = pytensor.function(
284+
[X, n], bic_pinv, on_unused_input="ignore", mode="FAST_RUN"
285+
)
286+
287+
def bic_fn(X_val: np.ndarray, n_val: int) -> float:
288+
try:
289+
bic_value, rss_value = bic_solve_fn(X_val, n_val)
290+
if np.isfinite(rss_value) and rss_value > np.finfo("float64").tiny:
291+
return float(bic_value)
292+
except (np.linalg.LinAlgError, RuntimeError, ValueError):
293+
pass
294+
return float(bic_pinv_fn(X_val, n_val))
295+
296+
return bic_fn
247297

248298
def _ci_independent(
249299
self, df: pd.DataFrame, x: str, y: str, cond: Sequence[str]
@@ -532,30 +582,50 @@ class TBF_FCI:
532582
- Kass & Raftery (1995). "Bayes Factors." JASA. [ΔBIC ≈ 2 log BF]
533583
"""
534584

585+
@validate_call(config=dict(arbitrary_types_allowed=True))
535586
def __init__(
536587
self,
537-
target: str,
588+
target: Annotated[
589+
str,
590+
Field(
591+
min_length=1,
592+
description="Name of the outcome variable at time t.",
593+
),
594+
],
538595
*,
539-
target_edge_rule: str = "any",
540-
bf_thresh: float = 1.0,
596+
target_edge_rule: Literal["any", "conservative", "fullS"] = "any",
597+
bf_thresh: Annotated[float, Field(gt=0.0)] = 1.0,
541598
forbidden_edges: Sequence[tuple[str, str]] | None = None,
542-
max_lag: int = 2,
599+
max_lag: Annotated[int, Field(ge=0)] = 2,
543600
allow_contemporaneous: bool = True,
544601
):
602+
"""Create a new temporal TBF-PC causal discovery model.
603+
604+
Parameters
605+
----------
606+
target
607+
Target variable name at time ``t`` that the algorithm orients
608+
toward.
609+
target_edge_rule
610+
Rule used to retain lagged → target edges. Choose from
611+
``"any"``, ``"conservative"``, or ``"fullS"``.
612+
bf_thresh
613+
Positive Bayes factor threshold applied during conditional
614+
independence testing.
615+
forbidden_edges
616+
Optional sequence of node pairs that must be excluded from the
617+
final graph.
618+
max_lag
619+
Maximum lag (inclusive) to consider when constructing temporal
620+
drivers.
621+
allow_contemporaneous
622+
Whether contemporaneous edges at time ``t`` are permitted.
623+
"""
545624
warnings.warn(
546625
"TBF_FCI is experimental and its API may change; use with caution.",
547626
UserWarning,
548627
stacklevel=2,
549628
)
550-
if not isinstance(target, str) or not target:
551-
raise ValueError("target must be a non-empty string")
552-
allowed_rules = {"any", "conservative", "fullS"}
553-
if target_edge_rule not in allowed_rules:
554-
raise ValueError(f"target_edge_rule must be one of {allowed_rules}")
555-
if not isinstance(bf_thresh, (int, float)) or bf_thresh <= 0:
556-
raise ValueError("bf_thresh must be a positive float")
557-
if not isinstance(max_lag, int) or max_lag < 0:
558-
raise ValueError("max_lag must be a non-negative integer")
559629

560630
self.target = target
561631
self.target_edge_rule = target_edge_rule
@@ -571,7 +641,8 @@ def __init__(
571641
self.nodes_: list[str] = []
572642
self.test_results: dict[tuple[str, str, frozenset], dict[str, float]] = {}
573643

574-
# Shared response vector for symbolic BIC
644+
# Shared response vector for symbolic BIC computation
645+
# Initialized with placeholder; will be updated with actual data during fitting
575646
self.y_sh = pytensor.shared(np.zeros(1, dtype="float64"), name="y_sh")
576647
self._bic_fn = self._build_symbolic_bic_fn()
577648

@@ -679,15 +750,43 @@ def _remove_all(self, u: str, v: str) -> None:
679750
# Statistical methods
680751
# ---------------------------------------------------------------------
681752
def _build_symbolic_bic_fn(self):
682-
"""Build and compile a function to compute BIC for a design matrix and sample size."""
683-
X = tt.matrix("X")
684-
n = tt.iscalar("n")
685-
beta = tt.nlinalg.pinv(X) @ self.y_sh
686-
resid = self.y_sh - X @ beta
687-
rss = tt.sum(resid**2)
753+
"""Build a BIC callable using a fast solver with a pseudoinverse fallback."""
754+
X = pt.matrix("X")
755+
n = pt.iscalar("n")
756+
757+
xtx = pt.dot(X.T, X)
758+
xty = pt.dot(X.T, self.y_sh)
759+
760+
beta_solve = pt.linalg.solve(xtx, xty)
761+
resid_solve = self.y_sh - pt.dot(X, beta_solve)
762+
rss_solve = pt.sum(resid_solve**2)
763+
764+
beta_pinv = pt.nlinalg.pinv(X) @ self.y_sh
765+
resid_pinv = self.y_sh - pt.dot(X, beta_pinv)
766+
rss_pinv = pt.sum(resid_pinv**2)
767+
688768
k = X.shape[1]
689-
bic = n * tt.log(rss / n) + k * tt.log(n)
690-
return pytensor.function([X, n], bic)
769+
770+
bic_solve = n * pt.log(rss_solve / n) + k * pt.log(n)
771+
bic_pinv = n * pt.log(rss_pinv / n) + k * pt.log(n)
772+
773+
bic_solve_fn = pytensor.function(
774+
[X, n], bic_solve, on_unused_input="ignore", mode="FAST_RUN"
775+
)
776+
bic_pinv_fn = pytensor.function(
777+
[X, n], bic_pinv, on_unused_input="ignore", mode="FAST_RUN"
778+
)
779+
780+
def bic_fn(X_val: np.ndarray, n_val: int) -> float:
781+
try:
782+
value = float(bic_solve_fn(X_val, n_val))
783+
if np.isfinite(value):
784+
return value
785+
except (np.linalg.LinAlgError, RuntimeError, ValueError):
786+
pass
787+
return float(bic_pinv_fn(X_val, n_val))
788+
789+
return bic_fn
691790

692791
def _ci_independent(
693792
self, df: pd.DataFrame, x: str, y: str, cond: Sequence[str]

0 commit comments

Comments
 (0)