Skip to content

Commit c189f4a

Browse files
authored
Merge pull request #154 from xKDR/add-vi
Add VI for linear regression
2 parents 5f9e94f + f1947a3 commit c189f4a

File tree

11 files changed

+232
-97
lines changed

11 files changed

+232
-97
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["xKDR Forum, Sourish Das"]
44
version = "0.1.1"
55

66
[deps]
7+
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
78
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
89
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
910
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
@@ -19,6 +20,7 @@ StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
1920
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
2021

2122
[compat]
23+
AdvancedVI = "0.2.11"
2224
DataFrames = "1"
2325
Distributions = "0.25"
2426
Documenter = "0.27, 1"

docs/src/api/bayesian_regression.md

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,42 @@
44
BayesianRegression
55
```
66

7+
## Bayesian Algorithms
8+
9+
```@docs
10+
BayesianAlgorithm
11+
MCMC
12+
VI
13+
```
14+
715
## Linear Regression
816

917
### Linear Regression with User Specific Gaussian Prior
1018
```@docs
11-
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Gauss, alpha_prior_mean::Float64, beta_prior_mean::Vector{Float64}, sim_size::Int64 = 1000)
12-
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Gauss, alpha_prior_mean::Float64, alpha_prior_sd::Float64, beta_prior_mean::Vector{Float64}, beta_prior_sd::Vector{Float64}, sim_size::Int64 = 1000)
19+
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Gauss, alpha_prior_mean::Float64, beta_prior_mean::Vector{Float64}, algorithm::BayesianAlgorithm = MCMC())
20+
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Gauss, alpha_prior_mean::Float64, alpha_prior_sd::Float64, beta_prior_mean::Vector{Float64}, beta_prior_sd::Vector{Float64}, algorithm::BayesianAlgorithm = MCMC())
1321
```
1422

1523
### Linear Regression with Ridge Prior
1624
```@docs
17-
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Ridge, h::Float64 = 0.01, sim_size::Int64 = 1000)
25+
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Ridge, algorithm::BayesianAlgorithm = MCMC(), h::Float64 = 0.01)
1826
```
1927

2028
### Linear Regression with Laplace Prior
2129
```@docs
22-
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Laplace, h::Float64 = 0.01, sim_size::Int64 = 1000)
30+
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Laplace, algorithm::BayesianAlgorithm = MCMC(), h::Float64 = 0.01)
2331
```
2432
### Linear Regression with Cauchy Prior
2533
```@docs
26-
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Cauchy, sim_size::Int64 = 1000)
34+
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_Cauchy, algorithm::BayesianAlgorithm = MCMC())
2735
```
2836
### Linear Regression with T-distributed Prior
2937
```@docs
30-
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_TDist, h::Float64 = 2.0, sim_size::Int64 = 1000)
38+
fit(formula::FormulaTerm, data::DataFrame, modelClass::LinearRegression, prior::Prior_TDist, algorithm::BayesianAlgorithm = MCMC(), h::Float64 = 2.0)
3139
```
3240
### Linear Regression with Horse Shoe Prior
3341
```@docs
34-
fit(formula::FormulaTerm,data::DataFrame,modelClass::LinearRegression,prior::Prior_HorseShoe,sim_size::Int64 = 1000)
42+
fit(formula::FormulaTerm,data::DataFrame,modelClass::LinearRegression,prior::Prior_HorseShoe,algorithm::BayesianAlgorithm = MCMC())
3543
```
3644

3745
## Logistic Regression

src/CRRao.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module CRRao
22

33
using DataFrames, GLM, Turing, StatsModels, StatsBase
44
using StatsBase, Distributions, LinearAlgebra
5-
using Optim, NLSolversBase, Random, HypothesisTests
5+
using Optim, NLSolversBase, Random, HypothesisTests, AdvancedVI
66
import StatsBase: coef, coeftable, r2, adjr2, loglikelihood, aic, bic, predict, residuals, cooksdistance, fit
77
import HypothesisTests: pvalue
88

@@ -392,9 +392,47 @@ end
392392

393393
Cauchit() = Cauchit(Cauchit_Link)
394394

395+
"""
396+
```julia
397+
BayesianAlgorithm
398+
```
399+
400+
Abstract type representing Bayesian inference algorithms; concrete subtypes are used to dispatch `fit` to the appropriate implementation.
401+
"""
402+
abstract type BayesianAlgorithm end
403+
404+
"""
405+
```julia
406+
MCMC <: BayesianAlgorithm
407+
```
408+
409+
A type representing MCMC algorithms.
410+
"""
411+
struct MCMC <: BayesianAlgorithm
412+
sim_size::Int64
413+
prediction_chain_start::Int64
414+
end
415+
416+
MCMC() = MCMC(1000, 200)
417+
418+
"""
419+
```julia
420+
VI <: BayesianAlgorithm
421+
```
422+
423+
A type representing variational inference (VI) algorithms.
424+
"""
425+
struct VI <: BayesianAlgorithm
426+
distribution_sample_count::Int64
427+
vi_max_iters::Int64
428+
vi_samples_per_step::Int64
429+
end
430+
431+
VI() = VI(1000, 10000, 100)
432+
395433
export LinearRegression, LogisticRegression, PoissonRegression, NegBinomRegression, Boot_Residual
396434
export Prior_Ridge, Prior_Laplace, Prior_Cauchy, Prior_TDist, Prior_HorseShoe, Prior_Gauss
397-
export CRRaoLink, Logit, Probit, Cloglog, Cauchit, fit
435+
export CRRaoLink, Logit, Probit, Cloglog, Cauchit, fit, BayesianAlgorithm, MCMC, VI
398436
export coef, coeftable, r2, adjr2, loglikelihood, aic, bic, sigma, predict, residuals, cooksdistance, BPTest, pvalue
399437
export FrequentistRegression, BayesianRegression
400438

src/bayesian/getter.jl

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,27 @@
1-
function predict(container::BayesianRegression{:LinearRegression}, newdata::DataFrame, prediction_chain_start::Int64 = 200)
1+
function predict(container::BayesianRegression{:LinearRegression}, newdata::DataFrame)
22
X = modelmatrix(container.formula, newdata)
3-
4-
params = get_params(container.chain[prediction_chain_start:end,:,:])
5-
W = params.β
6-
if isa(W, Tuple)
7-
W = reduce(hcat, W)
8-
end
9-
#predictions = params.α' .+ X * W'
10-
predictions = X * W'
3+
W = container.samples
4+
predictions = X * W
115
return vec(mean(predictions, dims=2))
126
end
137

148
function predict(container::BayesianRegression{:LogisticRegression}, newdata::DataFrame, prediction_chain_start::Int64 = 200)
159
X = modelmatrix(container.formula, newdata)
16-
17-
params = get_params(container.chain[prediction_chain_start:end,:,:])
18-
W = params.β
19-
if isa(W, Tuple)
20-
W = reduce(hcat, W)
21-
end
22-
#z = params.α' .+ X * W'
23-
z = X * W'
10+
W = container.samples[:, prediction_chain_start:end]
11+
z = X * W
2412
return vec(mean(container.link.link_function.(z), dims=2))
2513
end
2614

2715
function predict(container::BayesianRegression{:NegativeBinomialRegression}, newdata::DataFrame, prediction_chain_start::Int64 = 200)
2816
X = modelmatrix(container.formula, newdata)
29-
30-
params = get_params(container.chain[prediction_chain_start:end,:,:])
31-
W = params.β
32-
if isa(W, Tuple)
33-
W = reduce(hcat, W)
34-
end
35-
#z = params.α' .+ X * W'
36-
z = X * W'
17+
W = container.samples[:, prediction_chain_start:end]
18+
z = X * W
3719
return vec(mean(exp.(z), dims=2))
3820
end
3921

4022
function predict(container::BayesianRegression{:PoissonRegression}, newdata::DataFrame, prediction_chain_start::Int64 = 200)
4123
X = modelmatrix(container.formula, newdata)
42-
43-
params = get_params(container.chain[prediction_chain_start:end,:,:])
44-
W = params.β
45-
if isa(W, Tuple)
46-
W = reduce(hcat, W)
47-
end
48-
#z = params.α' .+ X * W'
49-
z = X * W'
24+
W = container.samples[:, prediction_chain_start:end]
25+
z = X * W
5026
return vec(mean(exp.(z), dims=2))
5127
end

0 commit comments

Comments
 (0)