import itertools as it
import re
import warnings
from collections.abc import Iterable, Iterator, Sequence
from typing import Annotated, Any, Literal

try:
    import networkx as nx
except ImportError:  # Optional dependency
    nx = None  # type: ignore[assignment]

import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pydantic import Field, InstanceOf, validate_call
from pymc_extras.prior import Prior

# Largest conditioning-set size enumerated by the subset searches.
_MAX_COND_SET_SIZE = 3

# Canonical lagged node names look like ``X[t]`` or ``X[t-2]``; anchoring the
# pattern at the end keeps base names that themselves contain brackets intact.
_LAG_PATTERN = re.compile(r"^(?P<base>.*)\[t(?:-(?P<lag>\d+))?\]$", re.DOTALL)


class _BayesFactorGraphBase:
    """Graph bookkeeping and ΔBIC conditional-independence test shared by TBF models.

    Subclasses are expected to set ``target``, ``target_edge_rule``,
    ``bf_thresh`` and ``forbidden_edges`` before calling :meth:`_init_state`.
    """

    target: str
    target_edge_rule: str
    bf_thresh: float
    forbidden_edges: set[tuple[str, str]]

    # ------------------------------------------------------------------ state
    def _init_state(self) -> None:
        """Create the empty result containers and compile the BIC function."""
        self.sep_sets: dict[tuple[str, str], set[str]] = {}
        self._adj_directed: set[tuple[str, str]] = set()
        self._adj_undirected: set[tuple[str, str]] = set()
        self.nodes_: list[str] = []
        self.test_results: dict[tuple[str, str, frozenset], dict[str, Any]] = {}

        # Shared response vector for symbolic BIC computation.
        # Initialized with a placeholder; updated with actual data on each test.
        self.y_sh = pytensor.shared(np.zeros(1, dtype="float64"), name="y_sh")
        self._bic_fn = self._build_symbolic_bic_fn()

    def _reset_state(self) -> None:
        """Forget everything learned by a previous call to ``fit``."""
        self.sep_sets.clear()
        self._adj_directed.clear()
        self._adj_undirected.clear()
        self.test_results.clear()

    def _validated_drivers(self, df: pd.DataFrame, drivers: Sequence[str]) -> list[str]:
        """Return ``drivers`` as a list after checking them against ``df``.

        Raises
        ------
        ValueError
            If the target is also listed as a driver.
        KeyError
            If the target or any driver is not a column of ``df``.
        """
        drivers = list(drivers)
        if self.target in drivers:
            raise ValueError(
                f"The target {self.target!r} must not be listed among the drivers."
            )
        missing = [c for c in [*drivers, self.target] if c not in df.columns]
        if missing:
            raise KeyError(f"Columns not found in dataframe: {missing}")
        return drivers

    # ------------------------------------------------------- edge bookkeeping
    def _key(self, u: str, v: str) -> tuple[str, str]:
        """Return a sorted 2-tuple key for an undirected edge between ``u`` and ``v``."""
        return (u, v) if u <= v else (v, u)

    def _set_sep(self, u: str, v: str, S: Sequence[str]) -> None:
        """Record the separation set ``S`` for the node pair ``(u, v)``."""
        self.sep_sets[self._key(u, v)] = set(S)

    def _has_forbidden(self, u: str, v: str) -> bool:
        """Return True if edge ``u—v`` is forbidden in either direction."""
        return (u, v) in self.forbidden_edges or (v, u) in self.forbidden_edges

    def _add_directed(self, u: str, v: str) -> None:
        """Add a directed edge ``u -> v`` if not forbidden; drop undirected if present."""
        if not self._has_forbidden(u, v):
            self._adj_undirected.discard(self._key(u, v))
            self._adj_directed.add((u, v))

    def _add_undirected(self, u: str, v: str) -> None:
        """Add an undirected edge ``u -- v`` if allowed and not already directed."""
        if (
            not self._has_forbidden(u, v)
            and (u, v) not in self._adj_directed
            and (v, u) not in self._adj_directed
        ):
            self._adj_undirected.add(self._key(u, v))

    def _remove_all(self, u: str, v: str) -> None:
        """Remove any edge (directed or undirected) between ``u`` and ``v``."""
        self._adj_undirected.discard(self._key(u, v))
        self._adj_directed.discard((u, v))
        self._adj_directed.discard((v, u))

    # ------------------------------------------------------------ CI testing
    def _build_symbolic_bic_fn(self):
        """Build a BIC callable using a fast solver with a pseudoinverse fallback.

        Returns
        -------
        Callable[[np.ndarray, int], float]
            ``bic_fn(X, n)`` evaluates ``n * log(RSS / n) + k * log(n)`` for the
            least-squares fit of the current contents of ``self.y_sh`` on the
            design matrix ``X`` (``k`` is the number of columns of ``X``).
        """
        tiny = np.finfo("float64").tiny

        # ``y_sh`` is float64, so pin the design matrix to float64 as well
        # instead of following the global ``floatX`` setting.
        X = pt.matrix("X", dtype="float64")
        n = pt.iscalar("n")

        xtx = pt.dot(X.T, X)
        xty = pt.dot(X.T, self.y_sh)

        # Fast path: normal equations.
        beta_solve = pt.linalg.solve(xtx, xty)
        resid_solve = self.y_sh - pt.dot(X, beta_solve)
        rss_solve = pt.sum(resid_solve**2)

        # Robust path: Moore-Penrose pseudoinverse (handles collinear columns).
        beta_pinv = pt.nlinalg.pinv(X) @ self.y_sh
        resid_pinv = self.y_sh - pt.dot(X, beta_pinv)
        rss_pinv = pt.sum(resid_pinv**2)

        k = X.shape[1]
        nf = pt.cast(n, "float64")

        # Floor the RSS so a perfect fit yields a very negative but finite BIC
        # rather than ``log(0) = -inf`` (which would turn ΔBIC into NaN).
        rss_solve_safe = pt.maximum(rss_solve, tiny)
        rss_pinv_safe = pt.maximum(rss_pinv, tiny)

        bic_solve = nf * pt.log(rss_solve_safe / nf) + k * pt.log(nf)
        bic_pinv = nf * pt.log(rss_pinv_safe / nf) + k * pt.log(nf)

        bic_solve_fn = pytensor.function(
            [X, n], [bic_solve, rss_solve], on_unused_input="ignore", mode="FAST_RUN"
        )
        bic_pinv_fn = pytensor.function(
            [X, n], bic_pinv, on_unused_input="ignore", mode="FAST_RUN"
        )

        def bic_fn(X_val: np.ndarray, n_val: int) -> float:
            X_val = np.asarray(X_val, dtype="float64")
            try:
                bic_value, rss_value = bic_solve_fn(X_val, n_val)
                if np.isfinite(rss_value) and rss_value > tiny:
                    return float(bic_value)
            except (np.linalg.LinAlgError, RuntimeError, ValueError):
                # Deliberate best-effort: a singular or ill-conditioned system
                # is expected here and is handled by the pseudoinverse below.
                pass
            return float(bic_pinv_fn(X_val, n_val))

        return bic_fn

    def _ci_independent(
        self, df: pd.DataFrame, x: str, y: str, cond: Sequence[str]
    ) -> bool:
        """Return True if ΔBIC indicates independence of ``x`` and ``y`` given ``cond``.

        The response ``y`` is regressed on an intercept plus ``cond`` (model
        ``M0``) and on the same columns plus ``x`` (model ``M1``).  The
        diagnostics of the comparison are stored in ``self.test_results`` under
        the key ``(x, y, frozenset(cond))``.  Forbidden pairs are reported as
        independent without running (or recording) a test.
        """
        if self._has_forbidden(x, y):
            return True

        cond = list(cond)
        n = len(df)
        self.y_sh.set_value(df[y].to_numpy().astype("float64"))

        columns = [np.ones(n)]
        if cond:
            columns.append(np.asarray(df[cond].to_numpy(), dtype="float64"))
        X0 = np.column_stack(columns)
        X1 = np.column_stack([X0, np.asarray(df[x].to_numpy(), dtype="float64")])

        bic0 = float(self._bic_fn(X0, n))
        bic1 = float(self._bic_fn(X1, n))

        delta_bic = bic1 - bic0
        logBF10 = -0.5 * delta_bic
        # Overwhelming evidence for dependence overflows to ``inf``; that is
        # the intended value, so silence the floating-point warning.
        with np.errstate(over="ignore"):
            BF10 = float(np.exp(logBF10))
        independent = bool(BF10 < self.bf_thresh)

        self.test_results[(x, y, frozenset(cond))] = {
            "bic0": bic0,
            "bic1": bic1,
            "delta_bic": delta_bic,
            "logBF10": logBF10,
            "BF10": BF10,
            "independent": independent,
            "conditioning_set": cond,
        }
        return independent

    @staticmethod
    def _cond_sets(
        candidates: Sequence[str], max_size: int = _MAX_COND_SET_SIZE
    ) -> Iterator[tuple[str, ...]]:
        """Yield subsets of ``candidates`` by increasing size, up to ``max_size``."""
        for k in range(min(max_size, len(candidates)) + 1):
            yield from it.combinations(candidates, k)

    def _apply_target_edge_rule(
        self, df: pd.DataFrame, x: str, y: str, cand: Sequence[str]
    ) -> None:
        """Keep or drop the edge ``x -> y`` according to ``target_edge_rule``.

        ``cand`` holds the variables that may be conditioned on.  Separation
        sets are recorded for every conditioning set that renders the pair
        independent (for ``"conservative"`` the last such set wins).
        """
        rule = self.target_edge_rule
        if rule == "fullS":
            # Only the complete conditioning set is tested; do not enumerate
            # (and then discard) all the smaller subsets.
            full = tuple(cand)
            independent = self._ci_independent(df, x, y, full)
            if independent:
                self._set_sep(x, y, full)
        elif rule == "any":
            # A single separating set is enough to drop the edge.
            independent = False
            for S in self._cond_sets(cand):
                if self._ci_independent(df, x, y, S):
                    self._set_sep(x, y, S)
                    independent = True
                    break
        elif rule == "conservative":
            # Drop the edge only if *every* conditioning set separates.
            independent = True
            for S in self._cond_sets(cand):
                if self._ci_independent(df, x, y, S):
                    self._set_sep(x, y, S)
                else:
                    independent = False
        else:
            raise ValueError(f"Unknown target_edge_rule: {rule!r}")

        if independent:
            self._remove_all(x, y)
        else:
            self._add_directed(x, y)

    def _resolve_undirected(
        self,
        df: pd.DataFrame,
        xi: str,
        xj: str,
        cond_sets: Iterable[Sequence[str]],
    ) -> None:
        """Drop ``xi -- xj`` at the first separating set, otherwise keep it undirected."""
        for S in cond_sets:
            if self._ci_independent(df, xi, xj, S):
                self._set_sep(xi, xj, S)
                self._remove_all(xi, xj)
                return
        self._add_undirected(xi, xj)

    # ------------------------------------------------------------ public API
    def get_directed_edges(self) -> list[tuple[str, str]]:
        """Return directed edges learned by the algorithm.

        Returns
        -------
        list[tuple[str, str]]
            Sorted list of ``(u, v)`` pairs representing oriented edges.  For
            the temporal model the node names carry their lag suffix.

        Examples
        --------
        .. code-block:: python

            directed = model.get_directed_edges()
        """
        return sorted(self._adj_directed)

    def get_undirected_edges(self) -> list[tuple[str, str]]:
        """Return undirected edges remaining after orientation.

        Returns
        -------
        list[tuple[str, str]]
            Sorted list of ``(u, v)`` pairs for unresolved adjacencies.  For
            the temporal model the node names carry their lag suffix.

        Examples
        --------
        .. code-block:: python

            skeleton = model.get_undirected_edges()
        """
        return sorted(self._adj_undirected)

    def get_test_results(self, x: str, y: str) -> list[dict[str, Any]]:
        """Return ΔBIC diagnostics for the unordered pair ``(x, y)``.

        Parameters
        ----------
        x : str
            Name of the first variable in the pair.
        y : str
            Name of the second variable in the pair.

        Returns
        -------
        list[dict[str, Any]]
            Each dictionary contains ``bic0``, ``bic1``, ``delta_bic``,
            ``logBF10``, ``BF10``, the ``independent`` verdict, and the
            conditioning set used during the test.

        Examples
        --------
        .. code-block:: python

            stats = model.get_test_results("A", "Y")
        """
        pair = {x, y}
        return [v for (xi, yi, _), v in self.test_results.items() if {xi, yi} == pair]

    def summary(self) -> str:
        """Render a text summary of the learned graph and test count.

        Returns
        -------
        str
            Multiline string describing directed edges, undirected edges, and
            the number of conditional independence tests executed.

        Examples
        --------
        .. code-block:: python

            print(model.summary())
        """
        lines = ["=== Directed edges ==="]
        lines.extend(f"{u} -> {v}" for u, v in self.get_directed_edges())
        lines.append("=== Undirected edges ===")
        lines.extend(f"{u} -- {v}" for u, v in self.get_undirected_edges())
        lines.append("=== Number of CI tests run ===")
        lines.append(str(len(self.test_results)))
        return "\n".join(lines)


class TBFPC(_BayesFactorGraphBase):
    r"""
    Target-first Bayes Factor PC (TBF-PC) causal discovery algorithm.

    This algorithm is a target-oriented variant of the Peter–Clark (PC) algorithm,
    using Bayes factors (via ΔBIC approximation) as the conditional independence test.

    For each conditional independence test of the form

    .. math::

        H_0 : Y \perp X \mid S
        \quad \text{vs.} \quad
        H_1 : Y \not\!\perp X \mid S

    we compare two linear models:

    .. math::

        M_0 : Y \sim S
        \\
        M_1 : Y \sim S + X

    where :math:`S` is a conditioning set of variables.

    The Bayesian Information Criterion (BIC) is defined as

    .. math::

        \mathrm{BIC}(M) = n \log\!\left(\frac{\mathrm{RSS}}{n}\right)
                          + k \log(n),

    with residual sum of squares :math:`\mathrm{RSS}`, sample size :math:`n`,
    and number of parameters :math:`k`.

    The Bayes factor is approximated by

    .. math::

        \log \mathrm{BF}_{10} \approx -\tfrac{1}{2}
        \left[ \mathrm{BIC}(M_1) - \mathrm{BIC}(M_0) \right].

    Independence is declared if :math:`\mathrm{BF}_{10} < \tau`,
    where :math:`\tau` is set via the ``bf_thresh`` parameter.

    Target Edge Rules
    -----------------
    Different rules govern how driver → target edges are retained:

    - ``"any"``:
        keep :math:`X \to Y` unless **any** conditioning set renders
        :math:`X \perp Y \mid S`.
    - ``"conservative"``:
        keep :math:`X \to Y` if **at least one** conditioning set shows
        dependence.
    - ``"fullS"``:
        test only with the **full set** of other drivers as :math:`S`.

    Examples
    --------
    **1. Basic usage with full conditioning set**

    .. code-block:: python

        import numpy as np, pandas as pd

        rng = np.random.default_rng(7)
        n = 2000
        C = rng.gamma(2,1,n)
        A = 0.7*C + rng.gamma(2,1,n)
        D = 0.5*C + rng.gamma(2,1,n)
        B = 0.8*A + rng.gamma(2,1,n)
        Y = 0.9*B + 0.6*D + 0.7*C + rng.gamma(2,1,n)

        df = pd.DataFrame({"A":A,"B":B,"C":C,"D":D,"Y":Y})
        df = (df - df.mean())/df.std()  # recommended scaling

        model = TBFPC(target="Y", target_edge_rule="fullS")
        model.fit(df, drivers=["A","B","C","D"])

        print(model.get_directed_edges())
        print(model.get_undirected_edges())
        print(model.to_digraph())

    **2. Using forbidden edges**

    You can specify edges that must *not* be tested or included
    (prior knowledge about the domain).

    .. code-block:: python

        model = TBFPC(
            target="Y",
            target_edge_rule="any",
            forbidden_edges=[("A","C")]  # forbid A--C
        )
        model.fit(df, drivers=["A","B","C","D"])
        print(model.to_digraph())

    **3. Conservative rule**

    Keeps driver → target edges if **any conditioning set**
    shows dependence.

    .. code-block:: python

        model = TBFPC(target="Y", target_edge_rule="conservative")
        model.fit(df, drivers=["A","B","C","D"])
        print(model.to_digraph())

    References
    ----------
    - Spirtes, Glymour, Scheines (2000). *Causation, Prediction, and Search*. MIT Press. [PC algorithm]
    - Spirtes & Glymour (1991). "An Algorithm for Fast Recovery of Sparse Causal Graphs."
    - Kass, R. & Raftery, A. (1995). "Bayes Factors."
    """

    @validate_call(config=dict(arbitrary_types_allowed=True))
    def __init__(
        self,
        target: Annotated[
            str,
            Field(
                min_length=1,
                description="Name of the outcome variable to orient the search.",
            ),
        ],
        *,
        target_edge_rule: Literal["any", "conservative", "fullS"] = "any",
        bf_thresh: Annotated[float, Field(gt=0.0)] = 1.0,
        forbidden_edges: Sequence[tuple[str, str]] | None = None,
    ):
        """Create a new TBFPC causal discovery model.

        Parameters
        ----------
        target
            Variable name for the model outcome; must be present in the data
            used during fitting.
        target_edge_rule
            Rule that controls which driver → target edges are retained.
            Options are ``"any"``, ``"conservative"``, and ``"fullS"``.
        bf_thresh
            Positive Bayes factor threshold applied during conditional
            independence tests.
        forbidden_edges
            Optional sequence of node pairs that must not be connected in the
            learned graph.
        """
        warnings.warn(
            "TBFPC is experimental and its API may change; use with caution.",
            UserWarning,
            stacklevel=2,
        )

        self.target = target
        self.target_edge_rule = target_edge_rule
        self.bf_thresh = float(bf_thresh)
        self.forbidden_edges: set[tuple[str, str]] = set(forbidden_edges or [])

        self._init_state()

    def _test_target_edges(self, df: pd.DataFrame, drivers: Sequence[str]) -> None:
        """Phase 1: test driver→target edges according to ``target_edge_rule``."""
        for xi in drivers:
            others = [d for d in drivers if d != xi]
            self._apply_target_edge_rule(df, xi, self.target, others)

    def _test_driver_skeleton(self, df: pd.DataFrame, drivers: Sequence[str]) -> None:
        """Phase 2: build the undirected driver skeleton via pairwise CI tests."""
        for xi, xj in it.combinations(drivers, 2):
            others = [d for d in drivers if d not in (xi, xj)]
            self._resolve_undirected(df, xi, xj, self._cond_sets(others))

    def fit(self, df: pd.DataFrame, drivers: Sequence[str]):
        """Fit the TBFPC procedure to the supplied dataframe.

        Parameters
        ----------
        df : pandas.DataFrame
            Dataset containing the target column and every candidate driver.
        drivers : Sequence[str]
            Iterable of column names to treat as potential drivers of the
            target.

        Returns
        -------
        TBFPC
            The fitted instance (``self``) with internal adjacency structures
            populated.

        Raises
        ------
        KeyError
            If the target or a driver is not a column of ``df``.
        ValueError
            If the target is also listed as a driver.

        Examples
        --------
        .. code-block:: python

            model = TBFPC(target="Y", target_edge_rule="fullS")
            model.fit(df, drivers=["A", "B", "C"])
        """
        drivers = self._validated_drivers(df, drivers)
        self._reset_state()

        self._test_target_edges(df, drivers)
        self._test_driver_skeleton(df, drivers)

        self.nodes_ = [*drivers, self.target]
        return self

    def to_digraph(self) -> str:
        """Return the learned graph encoded in DOT format.

        Returns
        -------
        str
            DOT string compatible with Graphviz rendering utilities.

        Examples
        --------
        .. code-block:: python

            dot_str = model.to_digraph()
        """
        lines = ["digraph G {", "  node [shape=ellipse];"]
        for n in self.nodes_:
            if n == self.target:
                lines.append(f'  "{n}" [style=filled, fillcolor="#eef5ff"];')
            else:
                lines.append(f'  "{n}";')
        for u, v in self.get_directed_edges():
            lines.append(f'  "{u}" -> "{v}";')
        for u, v in self.get_undirected_edges():
            lines.append(f'  "{u}" -> "{v}" [style=dashed, dir=none];')
        lines.append("}")
        return "\n".join(lines)


class TBF_FCI(_BayesFactorGraphBase):
    r"""
    Target-first Bayes Factor Temporal PC.

    This is a time-series–adapted version of TBF-PC. It combines ideas from
    temporal FCI/PCMCI with a Bayes-factor ΔBIC conditional independence test.

    For each test :math:`X \perp Y \mid S`, compare:

    .. math::

        M_0 : Y \sim S
        \\
        M_1 : Y \sim S + X

    with BIC scores

    .. math::

        \mathrm{BIC}(M) = n \log\!\left(\tfrac{\mathrm{RSS}}{n}\right)
                          + k \log(n),

    and Bayes factor approximation

    .. math::

        \log \mathrm{BF}_{10} \approx -\tfrac{1}{2}
        \left[ \mathrm{BIC}(M_1) - \mathrm{BIC}(M_0) \right].

    Declare independence if :math:`\mathrm{BF}_{10} < \tau`.

    Parameters
    ----------
    target : str
        Name of the target variable (at time t).
    target_edge_rule : {"any", "conservative", "fullS"}
        Rule for keeping lagged → target edges.
    bf_thresh : float, default=1.0
        Declare independence if BF10 < bf_thresh.
    forbidden_edges : list of tuple[str, str], optional
        Prior knowledge: edges to exclude.
    max_lag : int, default=2
        Maximum lag to include (t-1, t-2, …).
    allow_contemporaneous : bool, default=True
        Whether to allow contemporaneous edges at time t.
    """

    @validate_call(config=dict(arbitrary_types_allowed=True))
    def __init__(
        self,
        target: Annotated[
            str,
            Field(
                min_length=1,
                description="Name of the outcome variable at time t.",
            ),
        ],
        *,
        target_edge_rule: Literal["any", "conservative", "fullS"] = "any",
        bf_thresh: Annotated[float, Field(gt=0.0)] = 1.0,
        forbidden_edges: Sequence[tuple[str, str]] | None = None,
        max_lag: Annotated[int, Field(ge=0)] = 2,
        allow_contemporaneous: bool = True,
    ):
        """Create a new temporal TBF-PC causal discovery model.

        Parameters
        ----------
        target
            Target variable name at time ``t`` that the algorithm orients
            toward.
        target_edge_rule
            Rule used to retain lagged → target edges. Choose from
            ``"any"``, ``"conservative"``, or ``"fullS"``.
        bf_thresh
            Positive Bayes factor threshold applied during conditional
            independence testing.
        forbidden_edges
            Optional sequence of node pairs that must be excluded from the
            final graph.
        max_lag
            Maximum lag (inclusive) to consider when constructing temporal
            drivers.
        allow_contemporaneous
            Whether contemporaneous edges at time ``t`` are permitted.
        """
        warnings.warn(
            "TBF_FCI is experimental and its API may change; use with caution.",
            UserWarning,
            stacklevel=2,
        )

        self.target = target
        self.target_edge_rule = target_edge_rule
        self.bf_thresh = float(bf_thresh)
        self.max_lag = int(max_lag)
        self.allow_contemporaneous = allow_contemporaneous
        # ``_expand_edges`` reads ``self.max_lag``, so it must run after it is set.
        self.forbidden_edges: set[tuple[str, str]] = self._expand_edges(forbidden_edges)

        self._init_state()

    def _lag_name(self, var: str, lag: int) -> str:
        """Return canonical lagged variable name like ``X[t-2]`` or ``X[t]``."""
        return f"{var}[t-{lag}]" if lag > 0 else f"{var}[t]"

    def _parse_lag(self, name: str) -> tuple[str, int]:
        """Parse a lagged variable name into its base and lag components.

        Names that do not end in a ``[t]`` / ``[t-k]`` suffix are returned
        unchanged with lag ``0``.
        """
        match = _LAG_PATTERN.match(name)
        if match is None:
            return name, 0
        return match["base"], int(match["lag"] or 0)

    def _expand_edges(
        self, forbidden_edges: Sequence[tuple[str, str]] | None
    ) -> set[tuple[str, str]]:
        """Expand collapsed forbidden edge pairs into all lagged variants.

        Pairs in which either name already carries a ``[t…]`` suffix are kept
        verbatim; plain pairs are expanded to every lag combination from ``0``
        to ``max_lag``.
        """
        expanded: set[tuple[str, str]] = set()
        lags = range(self.max_lag + 1)
        for u, v in forbidden_edges or ():
            if "[t" in u or "[t" in v:
                expanded.add((u, v))
            else:
                expanded.update(
                    (self._lag_name(u, lag_u), self._lag_name(v, lag_v))
                    for lag_u, lag_v in it.product(lags, repeat=2)
                )
        return expanded

    def _build_lagged_df(
        self, df: pd.DataFrame, variables: Sequence[str]
    ) -> pd.DataFrame:
        """Construct a time-unrolled dataframe up to ``max_lag`` for variables.

        Columns are ordered lag by lag (all variables at ``t``, then at
        ``t-1``, …); the first ``max_lag`` rows, which contain the NaNs
        introduced by shifting, are dropped.
        """
        variables = list(variables)
        frames = []
        for lag in range(self.max_lag + 1):
            shifted = df[variables].shift(lag)
            shifted.columns = [self._lag_name(c, lag) for c in variables]
            frames.append(shifted)
        out = pd.concat(frames, axis=1).iloc[self.max_lag :]
        return out.astype("float64")

    def _admissible_cond_set(
        self, all_vars: Sequence[str], x: str, y: str
    ) -> list[str]:
        """Return conditioning variables admissible for testing ``x`` and ``y``.

        A variable is admissible when its lag is at least the smaller of the
        two lags of ``x`` and ``y``, i.e. it is not measured after both of them.
        """
        min_lag = min(self._parse_lag(x)[1], self._parse_lag(y)[1])
        return [
            z
            for z in all_vars
            if z not in (x, y) and self._parse_lag(z)[1] >= min_lag
        ]

    def _stageA_target_lagged(self, L: pd.DataFrame, drivers: Sequence[str]) -> None:
        """Evaluate lagged driver → target edges according to edge rule."""
        y = self._lag_name(self.target, 0)
        all_cols = list(L.columns)
        for v in drivers:
            for lag in range(1, self.max_lag + 1):
                x = self._lag_name(v, lag)
                cand = self._admissible_cond_set(all_cols, x, y)
                self._apply_target_edge_rule(L, x, y, cand)

    def _stageA_driver_lagged(self, L: pd.DataFrame, drivers: Sequence[str]) -> None:
        """Build lagged driver skeleton via conditional independence tests."""
        # Compare parsed base names: a prefix test would also discard drivers
        # whose name merely starts with the target name (e.g. "Y" vs "Yield").
        cols = [c for c in L.columns if self._parse_lag(c)[0] != self.target]
        pool = [*cols, self._lag_name(self.target, 0)]
        for xi, xj in it.combinations(cols, 2):
            if self._parse_lag(xi)[1] == 0 and self._parse_lag(xj)[1] == 0:
                # Purely contemporaneous pairs are handled in stage B.
                continue
            cand = self._admissible_cond_set(pool, xi, xj)
            self._resolve_undirected(L, xi, xj, self._cond_sets(cand))

    def _parents_of(self, node: str) -> list[str]:
        """Return list of parents for ``node`` using directed adjacencies."""
        return [u for (u, v) in self._adj_directed if v == node]

    def _stageB_contemporaneous(self, L: pd.DataFrame, drivers: Sequence[str]) -> None:
        """Test contemporaneous (time ``t``) relations among variables.

        Every test conditions on the already-discovered parents of both nodes
        plus up to two additional contemporaneous variables.
        """
        y_nodes = [self._lag_name(v, 0) for v in [*drivers, self.target]]
        for xi, xj in it.combinations(y_nodes, 2):
            base_S = set(self._parents_of(xi)) | set(self._parents_of(xj))
            cand_extra = [z for z in y_nodes if z not in (xi, xj)]
            cond_sets = (
                tuple(sorted(base_S.union(extra)))
                for extra in self._cond_sets(cand_extra, max_size=2)
            )
            self._resolve_undirected(L, xi, xj, cond_sets)

    def fit(self, df: pd.DataFrame, drivers: Sequence[str]):
        """Fit the temporal causal discovery algorithm to ``df``.

        Parameters
        ----------
        df : pandas.DataFrame
            Input dataframe containing the target column and every driver
            column.
        drivers : Sequence[str]
            Iterable of column names to be treated as drivers of the target.

        Returns
        -------
        TBF_FCI
            The fitted instance with internal adjacency structures populated.

        Raises
        ------
        KeyError
            If the target or a driver is not a column of ``df``.
        ValueError
            If the target is also listed as a driver, or if ``df`` has too few
            rows to build ``max_lag`` lags.

        Examples
        --------
        .. code-block:: python

            model = TBF_FCI(target="Y", max_lag=2)
            model.fit(df, drivers=["A", "B"])
        """
        drivers = self._validated_drivers(df, drivers)
        if len(df) <= self.max_lag:
            raise ValueError(
                f"Need more than max_lag={self.max_lag} rows to build the lagged "
                f"dataset; got {len(df)}."
            )
        self._reset_state()

        L = self._build_lagged_df(df, [*drivers, self.target])
        self.nodes_ = list(L.columns)
        self._stageA_target_lagged(L, drivers)
        self._stageA_driver_lagged(L, drivers)
        if self.allow_contemporaneous:
            self._stageB_contemporaneous(L, drivers)
        return self

    def collapsed_summary(
        self,
    ) -> tuple[list[tuple[str, str, int]], list[tuple[str, str]]]:
        """Summarize lagged edges into a driver-level view.

        Returns
        -------
        tuple[list[tuple[str, str, int]], list[tuple[str, str]]]
            A tuple with directed edges represented as ``(u, v, lag)`` and
            contemporaneous undirected edges represented as ``(u, v)`` pairs.
            Both lists follow the sorted order of the underlying edges so the
            output is reproducible across runs.

        Examples
        --------
        .. code-block:: python

            directed, undirected = model.collapsed_summary()
        """
        collapsed_directed: list[tuple[str, str, int]] = []
        for u, v in self.get_directed_edges():
            base_u, lag_u = self._parse_lag(u)
            base_v, lag_v = self._parse_lag(v)
            if lag_v == 0:
                collapsed_directed.append((base_u, base_v, lag_u))

        collapsed_undirected: list[tuple[str, str]] = []
        for u, v in self.get_undirected_edges():
            base_u, lag_u = self._parse_lag(u)
            base_v, lag_v = self._parse_lag(v)
            if lag_u == lag_v == 0:
                collapsed_undirected.append((base_u, base_v))

        return collapsed_directed, collapsed_undirected

    def to_digraph(self, collapsed: bool = True) -> str:
        """Export the learned graph as DOT text.

        Parameters
        ----------
        collapsed : bool, default True
            ``True`` collapses the time-unrolled graph into driver-level nodes
            with lag annotations; ``False`` returns the full lag-expanded
            structure.

        Returns
        -------
        str
            DOT format string suitable for Graphviz rendering.

        Examples
        --------
        .. code-block:: python

            dot_text = model.to_digraph(collapsed=True)
        """
        lines = ["digraph G {", "  node [shape=ellipse];"]

        if not collapsed:
            # --- original time-unrolled graph ---
            target_t = self._lag_name(self.target, 0)
            for n in self.nodes_:
                if n == target_t:
                    lines.append(f'  "{n}" [style=filled, fillcolor="#eef5ff"];')
                else:
                    lines.append(f'  "{n}";')
            for u, v in self.get_directed_edges():
                lines.append(f'  "{u}" -> "{v}";')
            for u, v in self.get_undirected_edges():
                lines.append(f'  "{u}" -> "{v}" [style=dashed, dir=none];')
        else:
            directed, undirected = self.collapsed_summary()
            # De-duplicate while keeping first-seen order so the DOT text is
            # deterministic (a plain set would reorder nodes between runs).
            base_nodes = list(dict.fromkeys(self._parse_lag(n)[0] for n in self.nodes_))
            for n in base_nodes:
                if n == self.target:
                    lines.append(f'  "{n}" [style=filled, fillcolor="#eef5ff"];')
                else:
                    lines.append(f'  "{n}";')
            for u, v, lag in directed:
                lines.append(f'  "{u}" -> "{v}" [label="lag {lag}"];')
            for u, v in undirected:
                lines.append(f'  "{u}" -> "{v}" [style=dashed, dir=none, label="t"];')

        lines.append("}")
        return "\n".join(lines)
from pydantic import ValidationError
from pymc_extras.prior import Prior

from pymc_marketing.mmm.causal import (
    TBF_FCI,
    TBFPC,
    BuildModelFromDAG,
    CausalGraphModel,
)

# Suppress specific dowhy warnings globally
warnings.filterwarnings("ignore", message="The graph defines .* variables")


@pytest.fixture(scope="module")
def df_non_ts() -> pd.DataFrame:
    """Small non-temporal dataset in which A drives B, and A, B and C drive Y."""
    rng = np.random.default_rng(123)
    size = 100
    # Draw order matters for reproducibility: A first, then the three noises.
    a = rng.gamma(2, 1, size)
    noise_b = rng.gamma(2, 1, size)
    noise_c = rng.gamma(2, 1, size)
    noise_y = rng.gamma(2, 1, size)

    b = 0.8 * a + noise_b
    c = noise_c
    y = 0.5 * a + 0.9 * b + 0.7 * c + noise_y
    return pd.DataFrame({"A": a, "B": b, "C": c, "Y": y})


def _assert_str_pairs(edges) -> None:
    """Check that ``edges`` is a list whose items are 2-tuples of strings."""
    assert isinstance(edges, list)
    for edge in edges:
        assert isinstance(edge, tuple)
        assert len(edge) == 2
        assert all(isinstance(node, str) for node in edge)


@pytest.mark.parametrize("target_edge_rule", ["any", "conservative", "fullS"])
@pytest.mark.parametrize("bf_thresh", [0.5, 1.0, 2.0])
@pytest.mark.parametrize(
    "forbidden_edges",
    [
        [],
        [("A", "C")],  # edge not involving target is allowed
        [("A", "Y")],  # forbid a potential target edge
    ],
)
def test_tbfpc_public_api_types(
    df_non_ts: pd.DataFrame,
    target_edge_rule: str,
    bf_thresh: float,
    forbidden_edges,
):
    """Every public accessor returns the documented type for all rule combinations."""
    model = TBFPC(
        target="Y",
        target_edge_rule=target_edge_rule,
        bf_thresh=bf_thresh,
        forbidden_edges=forbidden_edges,
    )
    fitted = model.fit(df_non_ts, drivers=["A", "B", "C"])

    # ``fit`` is fluent and hands back the very same instance.
    assert fitted is model

    assert isinstance(model.summary(), str)
    assert isinstance(model.to_digraph(), str)

    _assert_str_pairs(model.get_directed_edges())
    _assert_str_pairs(model.get_undirected_edges())


def test_tbfpc_invalid_drivers_raises(df_non_ts: pd.DataFrame):
    """Fitting with a driver that is not a dataframe column raises ``KeyError``."""
    model = TBFPC(target="Y", target_edge_rule="fullS")
    with pytest.raises(KeyError):
        model.fit(df_non_ts, drivers=["A", "B", "D"])  # "D" not in df


@pytest.mark.parametrize("edge_rule", ["random", "", None])
def test_tbfpc_invalid_edge_rule_raises(edge_rule):
    """Values outside the allowed edge rules are rejected at construction."""
    with pytest.raises(ValueError):
        TBFPC(target="Y", target_edge_rule=edge_rule)  # type: ignore[arg-type]


def test_tbfpc_emits_experimental_warning(df_non_ts: pd.DataFrame):
    """Constructing the model warns that the API is experimental."""
    with pytest.warns(UserWarning, match="experimental"):
        TBFPC(target="Y", target_edge_rule="fullS")
LazyCausalModel when dowhy is not available.""" - # Simulate dowhy not being installed by removing it from sys.modules - import sys +@pytest.mark.parametrize("bf", [0, -1.0]) +def test_tbfpc_invalid_bf_thresh_raises(bf): + with pytest.raises(ValueError): + TBFPC(target="Y", bf_thresh=bf) # type: ignore[arg-type] - import pymc_marketing.mmm.causal as causal_module - # Save the original CausalModel - original_causal_model = causal_module.CausalModel +def test_tbfpc_internal_key_and_sep(): + m = TBFPC(target="Y") + # _key should sort endpoints + assert m._key("B", "A") == ("A", "B") + m._set_sep("A", "B", ["C"]) + assert m.sep_sets[("A", "B")] == {"C"} - try: - monkeypatch.setitem(sys.modules, "dowhy", None) - # Force reload of the causal module to trigger the import error path - import importlib +def test_tbfpc_has_forbidden_blocks_edges(df_non_ts: pd.DataFrame): + m = TBFPC(target="Y", forbidden_edges=[("A", "Y")]) + # if forbidden, CI returns True (treat as independent) + assert m._has_forbidden("A", "Y") is True + # Build minimal state to call _ci_independent + m.fit(df_non_ts, drivers=["A", "B", "C"]) # initializes y_sh and bic_fn + assert m._ci_independent(df_non_ts, "A", "Y", []) is True + + +@pytest.fixture(scope="module") +def df_ts() -> pd.DataFrame: + rng = np.random.default_rng(123) + n = 300 + x1 = rng.uniform(low=0.0, high=1.0, size=n) + X1_t = np.where(x1 > 0.9, x1, x1 / 2) + + x2 = rng.uniform(low=0.3, high=1.0, size=n) + X2_t = np.where(x2 > 0.8, x2, x2 / 4) - importlib.reload(causal_module) + x3 = rng.uniform(low=0.0, high=1.0, size=n) + X3_t = x3 + (X2_t * 0.2) - # Now CausalModel should be the LazyCausalModel that raises ImportError - with pytest.raises( - ImportError, - match=( - r"To use Causal Graph functionality, please install the " - r"optional dependencies with: pip install pymc-marketing\[dag\]" - ), - ): - causal_module.CausalModel( - data=pd.DataFrame(), graph="A->B", treatment=["A"], outcome="B" - ) - finally: - # Restore the original 
CausalModel to prevent test pollution - causal_module.CausalModel = original_causal_model + Y_t = ( + (X1_t * 0.2) + + (X2_t * 0.1) + + (X3_t * 0.3) + + rng.normal(loc=0.0, scale=0.05, size=n) + ) + + return pd.DataFrame({"X1": X1_t, "X2": X2_t, "X3": X3_t, "Y": Y_t}) + + +@pytest.mark.parametrize("target_edge_rule", ["any", "conservative", "fullS"]) +@pytest.mark.parametrize("bf_thresh", [0.5, 1.0, 2.0]) +@pytest.mark.parametrize( + "forbidden_edges", + [ + [], + [("X2", "Y"), ("X1", "X2")], + [("X1", "Y")], + ], +) +def test_tbf_fci_public_api_types( + df_ts: pd.DataFrame, + target_edge_rule: str, + bf_thresh: float, + forbidden_edges, +): + model = TBF_FCI( + target="Y", + target_edge_rule=target_edge_rule, + bf_thresh=bf_thresh, + forbidden_edges=forbidden_edges, + max_lag=1, + allow_contemporaneous=True, + ) + out = model.fit(df_ts, drivers=["X1", "X2", "X3"]) # returns self + + assert out is model + # public API returns + assert isinstance(model.summary(), str) + assert isinstance(model.to_digraph(collapsed=False), str) + assert isinstance(model.to_digraph(collapsed=True), str) + + directed = model.get_directed_edges() + undirected = model.get_undirected_edges() + assert isinstance(directed, list) + assert isinstance(undirected, list) + if directed: + assert all( + isinstance(e, tuple) and len(e) == 2 and all(isinstance(x, str) for x in e) + for e in directed + ) + if undirected: + assert all( + isinstance(e, tuple) and len(e) == 2 and all(isinstance(x, str) for x in e) + for e in undirected + ) + + collapsed_directed, collapsed_undirected = model.collapsed_summary() + assert isinstance(collapsed_directed, list) + assert isinstance(collapsed_undirected, list) + for e in collapsed_directed: + assert isinstance(e, tuple) and len(e) == 3 + u, v, lag = e + assert isinstance(u, str) and isinstance(v, str) and isinstance(lag, int) + for e in collapsed_undirected: + assert isinstance(e, tuple) and len(e) == 2 + u, v = e + assert isinstance(u, str) and isinstance(v, 
str) + + +@pytest.mark.parametrize("edge_rule", ["random", "", None]) +def test_tbf_fci_invalid_edge_rule_raises(edge_rule): + with pytest.raises(ValueError): + TBF_FCI(target="Y", target_edge_rule=edge_rule) # type: ignore[arg-type] + + +@pytest.mark.parametrize("bf", [0, -1.0]) +def test_tbf_fci_invalid_bf_thresh_raises(bf): + with pytest.raises(ValueError): + TBF_FCI(target="Y", bf_thresh=bf) # type: ignore[arg-type] + + +@pytest.mark.parametrize("lag", [-1, 1.5]) +def test_tbf_fci_invalid_max_lag_raises(lag): + with pytest.raises(ValueError): + TBF_FCI(target="Y", max_lag=lag) # type: ignore[arg-type] + + +def test_tbf_fci_emits_experimental_warning(df_ts: pd.DataFrame): + with pytest.warns(UserWarning, match="experimental"): + TBF_FCI(target="Y", max_lag=1) + + +def test_tbf_fci_lag_naming_and_parsing(): + m = TBF_FCI(target="Y", max_lag=2) + assert m._lag_name("X", 0) == "X[t]" + assert m._lag_name("X", 2) == "X[t-2]" + assert m._parse_lag("X[t]") == ("X", 0) + assert m._parse_lag("X[t-2]") == ("X", 2) + + +@pytest.mark.parametrize( + "forbidden_in,expected_contains", + [ + ([("X1", "Y")], {("X1[t]", "Y[t]")}), + ([("X1", "Y")], {("X1[t-1]", "Y[t]")}), + ([("X2[t]", "Y[t]")], {("X2[t]", "Y[t]")}), + ], +) +def test_tbf_fci_expand_edges(forbidden_in, expected_contains): + m = TBF_FCI(target="Y", max_lag=1, forbidden_edges=forbidden_in) + # All expected edges should be in expanded set + assert expected_contains.issubset(m.forbidden_edges) + + +def test_tbf_fci_admissible_cond_set(df_ts: pd.DataFrame): + m = TBF_FCI(target="Y", max_lag=1) + all_vars = ["X1[t]", "X1[t-1]", "X2[t]", "X2[t-1]", "Y[t]"] + # conditioning for (X1[t-1], Y[t]) can include same-time and earlier variables + cand = m._admissible_cond_set(all_vars, "X1[t-1]", "Y[t]") + # excludes the tested variables themselves (X1[t-1], Y[t]) + assert set(cand).issuperset({"X1[t]", "X2[t-1]", "X2[t]"}) + + +def test_tbf_fci_invalid_drivers_raises(df_ts: pd.DataFrame): + model = TBF_FCI( + target="Y", 
target_edge_rule="fullS", max_lag=1, allow_contemporaneous=True + ) + with pytest.raises(KeyError): + model.fit(df_ts, drivers=["X1", "X2", "X9"]) # "X9" not in df