Skip to content

Commit 537514a

Browse files
run JuliaFormatter
1 parent 98a733e commit 537514a

File tree

3 files changed

+137
-86
lines changed

3 files changed

+137
-86
lines changed

src/plot.jl

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -394,55 +394,61 @@ end
394394
ppc_group = :posterior,
395395
)
396396
if length(p.args) < 3
397-
error("ppcplot requires at least 3 arguments: (posterior_chains, posterior_predictive_chains, observed_data)")
397+
error(
398+
"ppcplot requires at least 3 arguments: (posterior_chains, posterior_predictive_chains, observed_data)",
399+
)
398400
end
399-
401+
400402
posterior_chains = p.args[1]
401403
pp_chains = p.args[2]
402404
observed_data = p.args[3]
403-
405+
404406
if !(posterior_chains isa Chains)
405407
error("First argument must be a Chains object (posterior chains)")
406408
end
407-
if !(pp_chains isa Chains)
409+
if !(pp_chains isa Chains)
408410
error("Second argument must be a Chains object (posterior predictive chains)")
409411
end
410412
if !(observed_data isa AbstractVector)
411413
error("Third argument must be a vector (observed data)")
412414
end
413-
415+
414416
if kind (:density, :histogram, :scatter, :cumulative)
415-
error("`kind` must be one of `:density`, `:histogram`, `:scatter`, or `:cumulative`")
417+
error(
418+
"`kind` must be one of `:density`, `:histogram`, `:scatter`, or `:cumulative`",
419+
)
416420
end
417-
421+
418422
if ppc_group (:posterior, :prior)
419423
error("`ppc_group` must be one of `:posterior` or `:prior`")
420424
end
421-
425+
422426
if observed === nothing
423427
observed = (ppc_group == :posterior)
424428
end
425-
429+
426430
if length(colors) != 3
427-
error("`colors` must be a vector of length 3: [predictive_color, observed_color, mean_color]")
431+
error(
432+
"`colors` must be a vector of length 3: [predictive_color, observed_color, mean_color]",
433+
)
428434
end
429-
435+
430436
if alpha === nothing
431437
alpha = (kind == :scatter) ? 0.7 : 0.2
432438
end
433-
439+
434440
if jitter === nothing
435441
jitter = 0.0
436442
end
437-
443+
438444
pp_pooled = pool_chain(pp_chains)
439445
pp_data = Array(pp_pooled.value.data)
440446
pp_data = pp_data[:, :, 1]
441-
447+
442448
if random_seed !== nothing
443449
Random.seed!(random_seed)
444450
end
445-
451+
446452
total_pp_samples = size(pp_data, 1)
447453
if num_pp_samples === nothing
448454
if kind == :scatter
@@ -454,33 +460,33 @@ end
454460
num_pp_samples = total_pp_samples
455461
end
456462
end
457-
463+
458464
if num_pp_samples > total_pp_samples
459465
@warn "Requested $num_pp_samples samples but only $total_pp_samples available. Using all samples."
460466
num_pp_samples = total_pp_samples
461467
end
462-
468+
463469
if num_pp_samples < total_pp_samples
464470
sample_indices = Random.randperm(total_pp_samples)[1:num_pp_samples]
465471
pp_data = pp_data[sample_indices, :]
466472
end
467-
473+
468474
if ppc_group == :prior
469475
title := "Prior Predictive Check"
470476
predictive_label = "Prior Predictive"
471477
mean_label = "Prior Predictive Mean"
472478
else
473-
title := "Posterior Predictive Check"
479+
title := "Posterior Predictive Check"
474480
predictive_label = "Posterior Predictive"
475481
mean_label = "Posterior Predictive Mean"
476482
end
477483
legend := legend
478-
484+
479485
if kind == :density
480486
xaxis := "Value"
481487
yaxis := "Density"
482-
483-
for i in 1:size(pp_data, 1)
488+
489+
for i = 1:size(pp_data, 1)
484490
@series begin
485491
seriestype := :density
486492
label := i == 1 ? predictive_label : ""
@@ -490,7 +496,7 @@ end
490496
pp_data[i, :]
491497
end
492498
end
493-
499+
494500
if observed
495501
@series begin
496502
seriestype := :density
@@ -500,9 +506,9 @@ end
500506
observed_data
501507
end
502508
end
503-
509+
504510
if mean_pp
505-
pp_mean = vec(mean(pp_data, dims=1))
511+
pp_mean = vec(mean(pp_data, dims = 1))
506512
@series begin
507513
seriestype := :density
508514
label := mean_label
@@ -512,7 +518,7 @@ end
512518
pp_mean
513519
end
514520
end
515-
521+
516522
if observed_rug && observed
517523
y_min = 0
518524
@series begin
@@ -525,12 +531,12 @@ end
525531
observed_data
526532
end
527533
end
528-
534+
529535
elseif kind == :histogram
530536
xaxis := "Value"
531537
yaxis := "Frequency"
532-
533-
for i in 1:size(pp_data, 1)
538+
539+
for i = 1:size(pp_data, 1)
534540
@series begin
535541
seriestype := :histogram
536542
label := i == 1 ? predictive_label : ""
@@ -541,7 +547,7 @@ end
541547
pp_data[i, :]
542548
end
543549
end
544-
550+
545551
if observed
546552
@series begin
547553
seriestype := :histogram
@@ -553,9 +559,9 @@ end
553559
observed_data
554560
end
555561
end
556-
562+
557563
if mean_pp
558-
pp_mean = vec(mean(pp_data, dims=1))
564+
pp_mean = vec(mean(pp_data, dims = 1))
559565
@series begin
560566
seriestype := :histogram
561567
label := mean_label
@@ -566,15 +572,15 @@ end
566572
pp_mean
567573
end
568574
end
569-
575+
570576
elseif kind == :cumulative
571577
xaxis := "Value"
572578
yaxis := "Cumulative Probability"
573-
579+
574580
all_data = vcat(vec(pp_data), observed_data)
575-
x_range = range(minimum(all_data), maximum(all_data), length=200)
576-
577-
for i in 1:size(pp_data, 1)
581+
x_range = range(minimum(all_data), maximum(all_data), length = 200)
582+
583+
for i = 1:size(pp_data, 1)
578584
pp_ecdf = ecdf(pp_data[i, :])
579585
y_vals = pp_ecdf.(x_range)
580586
@series begin
@@ -588,7 +594,7 @@ end
588594
()
589595
end
590596
end
591-
597+
592598
if observed
593599
obs_ecdf = ecdf(observed_data)
594600
obs_y_vals = obs_ecdf.(x_range)
@@ -602,9 +608,9 @@ end
602608
()
603609
end
604610
end
605-
611+
606612
if mean_pp
607-
pp_mean = vec(mean(pp_data, dims=1))
613+
pp_mean = vec(mean(pp_data, dims = 1))
608614
pp_mean_ecdf = ecdf(pp_mean)
609615
mean_y_vals = pp_mean_ecdf.(x_range)
610616
@series begin
@@ -618,7 +624,7 @@ end
618624
()
619625
end
620626
end
621-
627+
622628
if observed_rug && observed
623629
@series begin
624630
seriestype := :scatter
@@ -631,19 +637,21 @@ end
631637
()
632638
end
633639
end
634-
640+
635641
elseif kind == :scatter
636642
xaxis := "Index"
637643
yaxis := "Value"
638-
644+
639645
if jitter > 0
640646
jitter_vals = jitter * (rand(size(pp_data, 2)) .- 0.5)
641647
else
642648
jitter_vals = zeros(size(pp_data, 2))
643649
end
644-
645-
for i in 1:size(pp_data, 1)
646-
y_vals = pp_data[i, :] .+ (jitter > 0 ? jitter * (rand(length(pp_data[i, :])) .- 0.5) : 0)
650+
651+
for i = 1:size(pp_data, 1)
652+
y_vals =
653+
pp_data[i, :] .+
654+
(jitter > 0 ? jitter * (rand(length(pp_data[i, :])) .- 0.5) : 0)
647655
@series begin
648656
seriestype := :scatter
649657
label := i == 1 ? predictive_label : ""
@@ -655,7 +663,7 @@ end
655663
()
656664
end
657665
end
658-
666+
659667
if observed
660668
obs_y = observed_data .+ jitter_vals[1:length(observed_data)]
661669
@series begin
@@ -669,9 +677,9 @@ end
669677
()
670678
end
671679
end
672-
680+
673681
if mean_pp
674-
pp_mean = vec(mean(pp_data, dims=1))
682+
pp_mean = vec(mean(pp_data, dims = 1))
675683
mean_y = pp_mean .+ jitter_vals[1:length(pp_mean)]
676684
@series begin
677685
seriestype := :scatter

0 commit comments

Comments
 (0)