Skip to content

Commit f8fd2b8

Browse files
authored
Merge pull request #65 from sct-pipeline/jca/cnr
Add CNR figure for spine-generic data
2 parents 1fe1bf2 + eaac07e commit f8fd2b8

File tree

1 file changed

+56
-32
lines changed

1 file changed

+56
-32
lines changed
Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
#!/usr/bin/env python
2+
#
3+
# Generate figures for the spine-generic results
4+
15
import pandas as pd
26
import numpy as np
37
import argparse
4-
import seaborn as sns
58
import os
69
import matplotlib.pyplot as plt
710
import ptitprince as pt
11+
import seaborn as sns
812
from matplotlib.patches import PathPatch
913

1014
sns.set(style="whitegrid", font_scale=1)
1115

1216

13-
def get_parameters():
17+
def get_parser():
1418
parser = argparse.ArgumentParser(description='Generate figure for spine generic dataset')
1519
parser.add_argument("-ir", "--path-input-results",
1620
help="Path to results.csv",
@@ -20,19 +24,15 @@ def get_parameters():
2024
required=True)
2125
parser.add_argument("-o", "--path-output",
2226
help="Path to save images",
23-
required=True,
24-
)
25-
arguments = parser.parse_args()
26-
return arguments
27+
required=True)
28+
return parser
2729

2830

2931
def adjust_box_widths(g, fac):
30-
# From https://github.com/mwaskom/seaborn/issues/1076#issuecomment-634541579
31-
3232
"""
3333
Adjust the widths of a seaborn-generated boxplot.
34+
Source: From https://github.com/mwaskom/seaborn/issues/1076#issuecomment-634541579
3435
"""
35-
3636
# iterating through Axes instances
3737
for ax in g.axes:
3838

@@ -58,32 +58,50 @@ def adjust_box_widths(g, fac):
5858

5959
# setting new width of median line
6060
for l in ax.lines:
61-
if not l.get_xdata().size == 0:
62-
if np.all(np.equal(l.get_xdata(), [xmin, xmax])):
61+
if not len(l.get_xdata()) == 0:
62+
if np.all(np.equal(l.get_xdata()[0:2], [xmin, xmax])):
6363
l.set_xdata([xmin_new, xmax_new])
6464

6565

6666
def generate_figure(data_in, column, path_output):
67-
# Hue Input for Subgroups
6867
dx = np.ones(len(data_in[column]))
6968
dy = column
70-
dhue = "Manufacturer"
71-
ort = "v"
72-
# dodge blue, limegreen, red
73-
colors = [ "#1E90FF", "#32CD32","#FF0000" ]
74-
pal = colors
75-
sigma = .2
69+
hue = "Manufacturer"
70+
pal = ["#1E90FF", "#32CD32", "#FF0000"]
7671
f, ax = plt.subplots(figsize=(4, 6))
77-
78-
ax = pt.RainCloud(x=dx, y=dy, hue=dhue, data=data_in, palette=pal, bw=sigma,
79-
width_viol=.5, ax=ax, orient=ort, alpha=.4, dodge=True, width_box=.35,
80-
box_showmeans=True,
81-
box_meanprops={"marker":"^", "markerfacecolor":"black", "markeredgecolor":"black", "markersize":"10"},
82-
box_notch=True)
72+
if column == 'CNR_single/t':
73+
coeff = 100
74+
else:
75+
coeff = 1
76+
ax = pt.half_violinplot(x=dx, y=dy, data=data_in*coeff, hue=hue, palette=pal, bw=.4, cut=0., linewidth=0.,
77+
scale="area", width=.8, inner=None, orient="v", dodge=False, alpha=.4, offset=0.5)
78+
ax = sns.boxplot(x=dx, y=dy, data=data_in*coeff, hue=hue, color="black", palette=pal,
79+
showcaps=True, boxprops={'facecolor': 'none', "zorder": 10}, showmeans=True,
80+
meanprops={"marker": "^", "markerfacecolor": "black", "markeredgecolor": "black",
81+
"markersize": "8"},
82+
showfliers=True, whiskerprops={'linewidth': 2, "zorder": 10},
83+
saturation=1, orient="v", dodge=True)
84+
ax = sns.stripplot(x=dx, y=dy, data=data_in*coeff, hue=hue, palette=pal, edgecolor="white",
85+
size=3, jitter=1, zorder=0, orient="v", dodge=True)
86+
plt.xlim([-1, 0.5])
87+
handles, labels = ax.get_legend_handles_labels()
88+
# The code below doesn't work (the label for CNR is "GEGEGEGEGEGEG...") so i need to hard-code the labels (because
89+
# I don't have time to dig further).
90+
# _ = plt.legend(handles[0:len(labels) // 3], labels[0:len(labels) // 3],
91+
# bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,
92+
# title=str(hue))
93+
_ = plt.legend(handles[0:3], ['Philips', 'Siemens', 'GE'],
94+
bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,
95+
title=str(hue))
8396
f.gca().invert_xaxis()
84-
#adjust boxplot width
85-
adjust_box_widths(f, 0.4)
86-
plt.xlabel(column)
97+
adjust_box_widths(f, 0.6)
98+
# special hack
99+
if column == 'CNR_single/t':
100+
plt.xlabel('CNR_single/√t')
101+
fname_out = os.path.join(path_output, 'figure_CNR_single_t')
102+
else:
103+
plt.xlabel(column)
104+
fname_out = os.path.join(path_output, 'figure_' + column)
87105
# remove ylabel
88106
plt.ylabel('')
89107
# hide xtick
@@ -93,11 +111,17 @@ def generate_figure(data_in, column, path_output):
93111
bottom=False,
94112
top=False,
95113
labelbottom=False)
96-
# plt.legend(title="Line", loc='upper left', handles=handles[::-1])
97-
plt.savefig(os.path.join(path_output, 'figure_' + column), bbox_inches='tight', dpi=300)
114+
plt.savefig(fname_out, bbox_inches='tight', dpi=300)
115+
98116

117+
def main(argv=None):
118+
# user params
119+
parser = get_parser()
120+
args = parser.parse_args(argv)
121+
path_input_results = args.path_input_results
122+
path_input_participants = args.path_input_participants
123+
path_output = args.path_output
99124

100-
def main(path_input_results, path_input_participants, path_output):
101125
if not os.path.isdir(path_output):
102126
os.makedirs(path_output)
103127

@@ -114,8 +138,8 @@ def main(path_input_results, path_input_participants, path_output):
114138

115139
generate_figure(content_results_csv, 'SNR_single', path_output)
116140
generate_figure(content_results_csv, 'Contrast', path_output)
141+
generate_figure(content_results_csv, 'CNR_single/t', path_output)
117142

118143

119144
if __name__ == "__main__":
120-
args = get_parameters()
121-
main(args.path_input_results, args.path_input_participants, args.path_output)
145+
main()

0 commit comments

Comments
 (0)