#
# Log-likelihood functions
#
# This file is part of PINTS.
# Copyright (c) 2017-2018, University of Oxford.
# For licensing information, see the LICENSE file distributed with the PINTS
# software package.
#
from __future__ import absolute_import, division
from __future__ import print_function, unicode_literals
import pints
import numpy as np
import scipy.special


class KalmanFilterLogLikelihood(pints.ProblemLogLikelihood):
    """
    *Extends:* :class:`ProblemLogLikelihood`

    Calculates a log-likelihood for a linear-Gaussian state-space model
    using a Kalman filter.

    The idea is to define the measurements as coming from a base model
    ``m(p)`` with fixed parameters ``p`` (i.e. any pints model), plus a
    linear term in a set of time-varying parameters ``x``, plus a normal
    noise term. That is, at time points ``k = 1..N`` the measurements are

        z_k = m_k(p) + H_k x_k + v_k

    and the varying parameters evolve according to the linear model

        x_{k+1} = A_k x_k + w_k

    where ``x_k`` is the vector of varying parameters (i.e. states), ``A_k``
    is a matrix defining how the states evolve over time, and ``w_k`` are
    samples from a multivariate normal distribution.

    Given a set of fixed parameters ``p``, everything else is linear and
    Gaussian, and a Kalman filter can be used to calculate the likelihood,
    see https://en.wikipedia.org/wiki/Kalman_filter#Marginal_likelihood.
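
    Concretely, writing ``y_k`` for the innovation (the measurement residual
    after the prediction step) and ``S_k`` for its covariance, the filter
    accumulates the standard Kalman-filter marginal likelihood

        log L = -(1/2) * sum_{k=1..N} [ y_k' inv(S_k) y_k
                                        + log det(S_k) + n_o log(2 pi) ]

    where ``n_o`` is the number of outputs.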

    The user specifies the base model ``m``, the measurement matrix ``H_k``,
    the transition matrix ``A_k``, and the standard deviations of ``v_k`` and
    ``w_k`` (alternatively, these could be treated as unknown parameters).

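    A hypothetical usage sketch (``model``, ``times``, ``values``, and
    ``parameters`` are assumed to exist; ``pints.MultiOutputProblem`` is the
    standard pints problem class)::

        problem = pints.MultiOutputProblem(model, times, values)
        H = np.eye(problem.n_outputs())    # measurement matrix
        A = np.eye(problem.n_outputs())    # transition matrix
        log_likelihood = KalmanFilterLogLikelihood(
            problem, H, 0.1, A, 0.01)
        print(log_likelihood(parameters))
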
    """
    def __init__(self, problem, measurement_matrix, measurement_sigma,
                 transition_matrix, transition_sigma):
        super(KalmanFilterLogLikelihood, self).__init__(problem)

        # Store counts
        self._no = problem.n_outputs()
        self._np = problem.n_parameters()
        self._nt = problem.n_times()

        # Store measurement and transition matrices
        self._H = measurement_matrix
        self._A = transition_matrix

        # Check sigmas: scalars are expanded to vectors of length n_outputs
        sigmas = []
        for sigma in [measurement_sigma, transition_sigma]:
            if np.isscalar(sigma):
                sigma = np.ones(self._no) * float(sigma)
            else:
                sigma = pints.vector(sigma)
                if len(sigma) != self._no:
                    raise ValueError(
                        'Sigma must be a scalar or a vector of length'
                        ' n_outputs.')
            if np.any(sigma <= 0):
                raise ValueError(
                    'Standard deviation must be greater than zero.')
            sigmas.append(sigma)
        self._v, self._w = sigmas

        # Pre-calculate the diagonal noise covariance matrices used by the
        # filter: R for the measurement noise v_k, Q for the process noise
        # w_k
        self._R = np.diag(self._v ** 2)
        self._Q = np.diag(self._w ** 2)

    def __call__(self, x):
        # Evaluate the base model m(p) at the fixed parameters
        mean = self._problem.evaluate(x)

        H = self._H
        A = self._A
        Q = self._Q
        R = self._R
        n = A.shape[0]
        I = np.eye(n)

        # Initial state estimate and covariance: assumed zero and identity
        # here (these could be made user-configurable)
        xhat = np.zeros(n)
        P = np.eye(n)

        log_like = 0.0
        for m, z in zip(mean, self._values):
            # Predict
            xhat = A.dot(xhat)
            P = A.dot(P).dot(A.T) + Q

            # Update
            y = z - H.dot(xhat) - m         # innovation
            S = H.dot(P).dot(H.T) + R       # innovation covariance
            invS = np.linalg.inv(S)
            K = P.dot(H.T).dot(invS)        # optimal Kalman gain
            xhat = xhat + K.dot(y)

            # Joseph form covariance update: valid for any gain and more
            # numerically stable than P = (I - K H) P, which holds only for
            # the optimal gain
            tmp = I - K.dot(H)
            P = tmp.dot(P).dot(tmp.T) + K.dot(R).dot(K.T)

            # Add this time point's contribution to the marginal likelihood
            log_like -= 0.5 * (
                np.inner(y, invS.dot(y))
                + np.linalg.slogdet(S)[1]
                + self._no * np.log(2 * np.pi))

        return log_like

    def evaluateS1(self, x):
        """ See :meth:`LogPDF.evaluateS1()`. """
        # Sensitivities of the Kalman filter likelihood are not implemented
        # yet
        raise NotImplementedError
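

# The block below is a minimal sanity-check sketch, not part of the class:
# it assumes the usual pints interfaces (pints.ForwardModel,
# pints.SingleOutputProblem) and checks that the filter runs end-to-end on
# simulated data from a constant base model plus a scalar random-walk state.
if __name__ == '__main__':

    class ConstantModel(pints.ForwardModel):
        """A base model that returns a constant value at every time."""
        def n_parameters(self):
            return 1

        def simulate(self, parameters, times):
            return parameters[0] * np.ones(len(times))

    # Simulate data: baseline 2.0, slow random-walk drift, measurement noise
    np.random.seed(1)
    times = np.arange(1, 101)
    drift = np.cumsum(0.01 * np.random.randn(len(times)))
    values = 2.0 + drift + 0.1 * np.random.randn(len(times))

    problem = pints.SingleOutputProblem(ConstantModel(), times, values)
    log_likelihood = KalmanFilterLogLikelihood(
        problem, np.eye(1), 0.1, np.eye(1), 0.01)
    print(log_likelihood([2.0]))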