16
16
import itertools as it
17
17
import warnings
18
18
from collections .abc import Sequence
19
+ from typing import Annotated , Literal
19
20
20
21
import numpy as np
21
22
import pandas as pd
22
23
import pytensor
23
- import pytensor .tensor as tt
24
+ import pytensor .tensor as pt
25
+ from pydantic import Field , validate_call
24
26
25
27
try :
26
28
from dowhy import CausalModel
@@ -156,26 +158,43 @@ class TBFPC:
156
158
- Kass, R. & Raftery, A. (1995). "Bayes Factors."
157
159
"""
158
160
161
+ @validate_call (config = dict (arbitrary_types_allowed = True ))
159
162
def __init__ (
160
163
self ,
161
- target : str ,
164
+ target : Annotated [
165
+ str ,
166
+ Field (
167
+ min_length = 1 ,
168
+ description = "Name of the outcome variable to orient the search." ,
169
+ ),
170
+ ],
162
171
* ,
163
- target_edge_rule : str = "any" ,
164
- bf_thresh : float = 1.0 ,
172
+ target_edge_rule : Literal [ "any" , "conservative" , "fullS" ] = "any" ,
173
+ bf_thresh : Annotated [ float , Field ( gt = 0.0 )] = 1.0 ,
165
174
forbidden_edges : Sequence [tuple [str , str ]] | None = None ,
166
175
):
176
+ """Create a new TBFPC causal discovery model.
177
+
178
+ Parameters
179
+ ----------
180
+ target
181
+ Variable name for the model outcome; must be present in the data
182
+ used during fitting.
183
+ target_edge_rule
184
+ Rule that controls which driver → target edges are retained.
185
+ Options are ``"any"``, ``"conservative"``, and ``"fullS"``.
186
+ bf_thresh
187
+ Positive Bayes factor threshold applied during conditional
188
+ independence tests.
189
+ forbidden_edges
190
+ Optional sequence of node pairs that must not be connected in the
191
+ learned graph.
192
+ """
167
193
warnings .warn (
168
194
"TBFPC is experimental and its API may change; use with caution." ,
169
195
UserWarning ,
170
196
stacklevel = 2 ,
171
197
)
172
- if not isinstance (target , str ) or not target :
173
- raise ValueError ("target must be a non-empty string" )
174
- allowed_rules = {"any" , "conservative" , "fullS" }
175
- if target_edge_rule not in allowed_rules :
176
- raise ValueError (f"target_edge_rule must be one of { allowed_rules } " )
177
- if not isinstance (bf_thresh , (int , float )) or bf_thresh <= 0 :
178
- raise ValueError ("bf_thresh must be a positive float" )
179
198
180
199
self .target = target
181
200
self .target_edge_rule = target_edge_rule
@@ -189,7 +208,8 @@ def __init__(
189
208
self .nodes_ : list [str ] = []
190
209
self .test_results : dict [tuple [str , str , frozenset ], dict [str , float ]] = {}
191
210
192
- # Shared response vector for symbolic BIC
211
+ # Shared response vector for symbolic BIC computation
212
+ # Initialized with placeholder; will be updated with actual data during fitting
193
213
self .y_sh = pytensor .shared (np .zeros (1 , dtype = "float64" ), name = "y_sh" )
194
214
self ._bic_fn = self ._build_symbolic_bic_fn ()
195
215
@@ -233,17 +253,47 @@ def _remove_all(self, u: str, v: str) -> None:
233
253
# Statistical methods
234
254
# ---------------------------------------------------------------------
235
255
def _build_symbolic_bic_fn (self ):
236
- """Build and compile a function to compute BIC given a design matrix ``X`` and sample size ``n``."""
237
- X = tt .matrix ("X" )
238
- n = tt .iscalar ("n" )
256
+ """Build a BIC callable using a fast solver with a pseudoinverse fallback."""
257
+ X = pt .matrix ("X" )
258
+ n = pt .iscalar ("n" )
259
+
260
+ xtx = pt .dot (X .T , X )
261
+ xty = pt .dot (X .T , self .y_sh )
262
+
263
+ beta_solve = pt .linalg .solve (xtx , xty )
264
+ resid_solve = self .y_sh - pt .dot (X , beta_solve )
265
+ rss_solve = pt .sum (resid_solve ** 2 )
266
+
267
+ beta_pinv = pt .nlinalg .pinv (X ) @ self .y_sh
268
+ resid_pinv = self .y_sh - pt .dot (X , beta_pinv )
269
+ rss_pinv = pt .sum (resid_pinv ** 2 )
239
270
240
- beta = tt .nlinalg .pinv (X ) @ self .y_sh
241
- resid = self .y_sh - X @ beta
242
- rss = tt .sum (resid ** 2 )
243
271
k = X .shape [1 ]
244
272
245
- bic = n * tt .log (rss / n ) + k * tt .log (n )
246
- return pytensor .function ([X , n ], bic )
273
+ nf = pt .cast (n , "float64" )
274
+ rss_solve_safe = pt .maximum (rss_solve , np .finfo ("float64" ).tiny )
275
+ rss_pinv_safe = pt .maximum (rss_pinv , np .finfo ("float64" ).tiny )
276
+
277
+ bic_solve = nf * pt .log (rss_solve_safe / nf ) + k * pt .log (nf )
278
+ bic_pinv = nf * pt .log (rss_pinv_safe / nf ) + k * pt .log (nf )
279
+
280
+ bic_solve_fn = pytensor .function (
281
+ [X , n ], [bic_solve , rss_solve ], on_unused_input = "ignore" , mode = "FAST_RUN"
282
+ )
283
+ bic_pinv_fn = pytensor .function (
284
+ [X , n ], bic_pinv , on_unused_input = "ignore" , mode = "FAST_RUN"
285
+ )
286
+
287
+ def bic_fn (X_val : np .ndarray , n_val : int ) -> float :
288
+ try :
289
+ bic_value , rss_value = bic_solve_fn (X_val , n_val )
290
+ if np .isfinite (rss_value ) and rss_value > np .finfo ("float64" ).tiny :
291
+ return float (bic_value )
292
+ except (np .linalg .LinAlgError , RuntimeError , ValueError ):
293
+ pass
294
+ return float (bic_pinv_fn (X_val , n_val ))
295
+
296
+ return bic_fn
247
297
248
298
def _ci_independent (
249
299
self , df : pd .DataFrame , x : str , y : str , cond : Sequence [str ]
@@ -532,30 +582,50 @@ class TBF_FCI:
532
582
- Kass & Raftery (1995). "Bayes Factors." JASA. [ΔBIC ≈ 2 log BF]
533
583
"""
534
584
585
+ @validate_call (config = dict (arbitrary_types_allowed = True ))
535
586
def __init__ (
536
587
self ,
537
- target : str ,
588
+ target : Annotated [
589
+ str ,
590
+ Field (
591
+ min_length = 1 ,
592
+ description = "Name of the outcome variable at time t." ,
593
+ ),
594
+ ],
538
595
* ,
539
- target_edge_rule : str = "any" ,
540
- bf_thresh : float = 1.0 ,
596
+ target_edge_rule : Literal [ "any" , "conservative" , "fullS" ] = "any" ,
597
+ bf_thresh : Annotated [ float , Field ( gt = 0.0 )] = 1.0 ,
541
598
forbidden_edges : Sequence [tuple [str , str ]] | None = None ,
542
- max_lag : int = 2 ,
599
+ max_lag : Annotated [ int , Field ( ge = 0 )] = 2 ,
543
600
allow_contemporaneous : bool = True ,
544
601
):
602
+ """Create a new temporal TBF-PC causal discovery model.
603
+
604
+ Parameters
605
+ ----------
606
+ target
607
+ Target variable name at time ``t`` that the algorithm orients
608
+ toward.
609
+ target_edge_rule
610
+ Rule used to retain lagged → target edges. Choose from
611
+ ``"any"``, ``"conservative"``, or ``"fullS"``.
612
+ bf_thresh
613
+ Positive Bayes factor threshold applied during conditional
614
+ independence testing.
615
+ forbidden_edges
616
+ Optional sequence of node pairs that must be excluded from the
617
+ final graph.
618
+ max_lag
619
+ Maximum lag (inclusive) to consider when constructing temporal
620
+ drivers.
621
+ allow_contemporaneous
622
+ Whether contemporaneous edges at time ``t`` are permitted.
623
+ """
545
624
warnings .warn (
546
625
"TBF_FCI is experimental and its API may change; use with caution." ,
547
626
UserWarning ,
548
627
stacklevel = 2 ,
549
628
)
550
- if not isinstance (target , str ) or not target :
551
- raise ValueError ("target must be a non-empty string" )
552
- allowed_rules = {"any" , "conservative" , "fullS" }
553
- if target_edge_rule not in allowed_rules :
554
- raise ValueError (f"target_edge_rule must be one of { allowed_rules } " )
555
- if not isinstance (bf_thresh , (int , float )) or bf_thresh <= 0 :
556
- raise ValueError ("bf_thresh must be a positive float" )
557
- if not isinstance (max_lag , int ) or max_lag < 0 :
558
- raise ValueError ("max_lag must be a non-negative integer" )
559
629
560
630
self .target = target
561
631
self .target_edge_rule = target_edge_rule
@@ -571,7 +641,8 @@ def __init__(
571
641
self .nodes_ : list [str ] = []
572
642
self .test_results : dict [tuple [str , str , frozenset ], dict [str , float ]] = {}
573
643
574
- # Shared response vector for symbolic BIC
644
+ # Shared response vector for symbolic BIC computation
645
+ # Initialized with placeholder; will be updated with actual data during fitting
575
646
self .y_sh = pytensor .shared (np .zeros (1 , dtype = "float64" ), name = "y_sh" )
576
647
self ._bic_fn = self ._build_symbolic_bic_fn ()
577
648
@@ -679,15 +750,43 @@ def _remove_all(self, u: str, v: str) -> None:
679
750
# Statistical methods
680
751
# ---------------------------------------------------------------------
681
752
def _build_symbolic_bic_fn (self ):
682
- """Build and compile a function to compute BIC for a design matrix and sample size."""
683
- X = tt .matrix ("X" )
684
- n = tt .iscalar ("n" )
685
- beta = tt .nlinalg .pinv (X ) @ self .y_sh
686
- resid = self .y_sh - X @ beta
687
- rss = tt .sum (resid ** 2 )
753
+ """Build a BIC callable using a fast solver with a pseudoinverse fallback."""
754
+ X = pt .matrix ("X" )
755
+ n = pt .iscalar ("n" )
756
+
757
+ xtx = pt .dot (X .T , X )
758
+ xty = pt .dot (X .T , self .y_sh )
759
+
760
+ beta_solve = pt .linalg .solve (xtx , xty )
761
+ resid_solve = self .y_sh - pt .dot (X , beta_solve )
762
+ rss_solve = pt .sum (resid_solve ** 2 )
763
+
764
+ beta_pinv = pt .nlinalg .pinv (X ) @ self .y_sh
765
+ resid_pinv = self .y_sh - pt .dot (X , beta_pinv )
766
+ rss_pinv = pt .sum (resid_pinv ** 2 )
767
+
688
768
k = X .shape [1 ]
689
- bic = n * tt .log (rss / n ) + k * tt .log (n )
690
- return pytensor .function ([X , n ], bic )
769
+
770
+ bic_solve = n * pt .log (rss_solve / n ) + k * pt .log (n )
771
+ bic_pinv = n * pt .log (rss_pinv / n ) + k * pt .log (n )
772
+
773
+ bic_solve_fn = pytensor .function (
774
+ [X , n ], bic_solve , on_unused_input = "ignore" , mode = "FAST_RUN"
775
+ )
776
+ bic_pinv_fn = pytensor .function (
777
+ [X , n ], bic_pinv , on_unused_input = "ignore" , mode = "FAST_RUN"
778
+ )
779
+
780
+ def bic_fn (X_val : np .ndarray , n_val : int ) -> float :
781
+ try :
782
+ value = float (bic_solve_fn (X_val , n_val ))
783
+ if np .isfinite (value ):
784
+ return value
785
+ except (np .linalg .LinAlgError , RuntimeError , ValueError ):
786
+ pass
787
+ return float (bic_pinv_fn (X_val , n_val ))
788
+
789
+ return bic_fn
691
790
692
791
def _ci_independent (
693
792
self , df : pd .DataFrame , x : str , y : str , cond : Sequence [str ]
0 commit comments