@@ -394,55 +394,61 @@ end
394
394
ppc_group = :posterior ,
395
395
)
396
396
if length (p. args) < 3
397
- error (" ppcplot requires at least 3 arguments: (posterior_chains, posterior_predictive_chains, observed_data)" )
397
+ error (
398
+ " ppcplot requires at least 3 arguments: (posterior_chains, posterior_predictive_chains, observed_data)" ,
399
+ )
398
400
end
399
-
401
+
400
402
posterior_chains = p. args[1 ]
401
403
pp_chains = p. args[2 ]
402
404
observed_data = p. args[3 ]
403
-
405
+
404
406
if ! (posterior_chains isa Chains)
405
407
error (" First argument must be a Chains object (posterior chains)" )
406
408
end
407
- if ! (pp_chains isa Chains)
409
+ if ! (pp_chains isa Chains)
408
410
error (" Second argument must be a Chains object (posterior predictive chains)" )
409
411
end
410
412
if ! (observed_data isa AbstractVector)
411
413
error (" Third argument must be a vector (observed data)" )
412
414
end
413
-
415
+
414
416
if kind ∉ (:density , :histogram , :scatter , :cumulative )
415
- error (" `kind` must be one of `:density`, `:histogram`, `:scatter`, or `:cumulative`" )
417
+ error (
418
+ " `kind` must be one of `:density`, `:histogram`, `:scatter`, or `:cumulative`" ,
419
+ )
416
420
end
417
-
421
+
418
422
if ppc_group ∉ (:posterior , :prior )
419
423
error (" `ppc_group` must be one of `:posterior` or `:prior`" )
420
424
end
421
-
425
+
422
426
if observed === nothing
423
427
observed = (ppc_group == :posterior )
424
428
end
425
-
429
+
426
430
if length (colors) != 3
427
- error (" `colors` must be a vector of length 3: [predictive_color, observed_color, mean_color]" )
431
+ error (
432
+ " `colors` must be a vector of length 3: [predictive_color, observed_color, mean_color]" ,
433
+ )
428
434
end
429
-
435
+
430
436
if alpha === nothing
431
437
alpha = (kind == :scatter ) ? 0.7 : 0.2
432
438
end
433
-
439
+
434
440
if jitter === nothing
435
441
jitter = 0.0
436
442
end
437
-
443
+
438
444
pp_pooled = pool_chain (pp_chains)
439
445
pp_data = Array (pp_pooled. value. data)
440
446
pp_data = pp_data[:, :, 1 ]
441
-
447
+
442
448
if random_seed != = nothing
443
449
Random. seed! (random_seed)
444
450
end
445
-
451
+
446
452
total_pp_samples = size (pp_data, 1 )
447
453
if num_pp_samples === nothing
448
454
if kind == :scatter
@@ -454,33 +460,33 @@ end
454
460
num_pp_samples = total_pp_samples
455
461
end
456
462
end
457
-
463
+
458
464
if num_pp_samples > total_pp_samples
459
465
@warn " Requested $num_pp_samples samples but only $total_pp_samples available. Using all samples."
460
466
num_pp_samples = total_pp_samples
461
467
end
462
-
468
+
463
469
if num_pp_samples < total_pp_samples
464
470
sample_indices = Random. randperm (total_pp_samples)[1 : num_pp_samples]
465
471
pp_data = pp_data[sample_indices, :]
466
472
end
467
-
473
+
468
474
if ppc_group == :prior
469
475
title := " Prior Predictive Check"
470
476
predictive_label = " Prior Predictive"
471
477
mean_label = " Prior Predictive Mean"
472
478
else
473
- title := " Posterior Predictive Check"
479
+ title := " Posterior Predictive Check"
474
480
predictive_label = " Posterior Predictive"
475
481
mean_label = " Posterior Predictive Mean"
476
482
end
477
483
legend := legend
478
-
484
+
479
485
if kind == :density
480
486
xaxis := " Value"
481
487
yaxis := " Density"
482
-
483
- for i in 1 : size (pp_data, 1 )
488
+
489
+ for i = 1 : size (pp_data, 1 )
484
490
@series begin
485
491
seriestype := :density
486
492
label := i == 1 ? predictive_label : " "
490
496
pp_data[i, :]
491
497
end
492
498
end
493
-
499
+
494
500
if observed
495
501
@series begin
496
502
seriestype := :density
500
506
observed_data
501
507
end
502
508
end
503
-
509
+
504
510
if mean_pp
505
- pp_mean = vec (mean (pp_data, dims= 1 ))
511
+ pp_mean = vec (mean (pp_data, dims = 1 ))
506
512
@series begin
507
513
seriestype := :density
508
514
label := mean_label
512
518
pp_mean
513
519
end
514
520
end
515
-
521
+
516
522
if observed_rug && observed
517
523
y_min = 0
518
524
@series begin
@@ -525,12 +531,12 @@ end
525
531
observed_data
526
532
end
527
533
end
528
-
534
+
529
535
elseif kind == :histogram
530
536
xaxis := " Value"
531
537
yaxis := " Frequency"
532
-
533
- for i in 1 : size (pp_data, 1 )
538
+
539
+ for i = 1 : size (pp_data, 1 )
534
540
@series begin
535
541
seriestype := :histogram
536
542
label := i == 1 ? predictive_label : " "
541
547
pp_data[i, :]
542
548
end
543
549
end
544
-
550
+
545
551
if observed
546
552
@series begin
547
553
seriestype := :histogram
553
559
observed_data
554
560
end
555
561
end
556
-
562
+
557
563
if mean_pp
558
- pp_mean = vec (mean (pp_data, dims= 1 ))
564
+ pp_mean = vec (mean (pp_data, dims = 1 ))
559
565
@series begin
560
566
seriestype := :histogram
561
567
label := mean_label
@@ -566,15 +572,15 @@ end
566
572
pp_mean
567
573
end
568
574
end
569
-
575
+
570
576
elseif kind == :cumulative
571
577
xaxis := " Value"
572
578
yaxis := " Cumulative Probability"
573
-
579
+
574
580
all_data = vcat (vec (pp_data), observed_data)
575
- x_range = range (minimum (all_data), maximum (all_data), length= 200 )
576
-
577
- for i in 1 : size (pp_data, 1 )
581
+ x_range = range (minimum (all_data), maximum (all_data), length = 200 )
582
+
583
+ for i = 1 : size (pp_data, 1 )
578
584
pp_ecdf = ecdf (pp_data[i, :])
579
585
y_vals = pp_ecdf .(x_range)
580
586
@series begin
588
594
()
589
595
end
590
596
end
591
-
597
+
592
598
if observed
593
599
obs_ecdf = ecdf (observed_data)
594
600
obs_y_vals = obs_ecdf .(x_range)
602
608
()
603
609
end
604
610
end
605
-
611
+
606
612
if mean_pp
607
- pp_mean = vec (mean (pp_data, dims= 1 ))
613
+ pp_mean = vec (mean (pp_data, dims = 1 ))
608
614
pp_mean_ecdf = ecdf (pp_mean)
609
615
mean_y_vals = pp_mean_ecdf .(x_range)
610
616
@series begin
618
624
()
619
625
end
620
626
end
621
-
627
+
622
628
if observed_rug && observed
623
629
@series begin
624
630
seriestype := :scatter
@@ -631,19 +637,21 @@ end
631
637
()
632
638
end
633
639
end
634
-
640
+
635
641
elseif kind == :scatter
636
642
xaxis := " Index"
637
643
yaxis := " Value"
638
-
644
+
639
645
if jitter > 0
640
646
jitter_vals = jitter * (rand (size (pp_data, 2 )) .- 0.5 )
641
647
else
642
648
jitter_vals = zeros (size (pp_data, 2 ))
643
649
end
644
-
645
- for i in 1 : size (pp_data, 1 )
646
- y_vals = pp_data[i, :] .+ (jitter > 0 ? jitter * (rand (length (pp_data[i, :])) .- 0.5 ) : 0 )
650
+
651
+ for i = 1 : size (pp_data, 1 )
652
+ y_vals =
653
+ pp_data[i, :] .+
654
+ (jitter > 0 ? jitter * (rand (length (pp_data[i, :])) .- 0.5 ) : 0 )
647
655
@series begin
648
656
seriestype := :scatter
649
657
label := i == 1 ? predictive_label : " "
655
663
()
656
664
end
657
665
end
658
-
666
+
659
667
if observed
660
668
obs_y = observed_data .+ jitter_vals[1 : length (observed_data)]
661
669
@series begin
669
677
()
670
678
end
671
679
end
672
-
680
+
673
681
if mean_pp
674
- pp_mean = vec (mean (pp_data, dims= 1 ))
682
+ pp_mean = vec (mean (pp_data, dims = 1 ))
675
683
mean_y = pp_mean .+ jitter_vals[1 : length (pp_mean)]
676
684
@series begin
677
685
seriestype := :scatter
0 commit comments