Skip to content

Commit 62bea1e

Browse files
committed
few changes
1 parent 4a24e0a commit 62bea1e

File tree

1 file changed

+83
-68
lines changed

1 file changed

+83
-68
lines changed

src/pybamm/plotting/quick_plot.py

Lines changed: 83 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -175,66 +175,61 @@ def __init__(
175175
# Use discharge capacity as x-axis
176176
self.x_axis = "Discharge capacity [A.h]"
177177

178-
# Extract discharge capacities for all solutions
179178
discharge_capacities = [
180179
solution["Discharge capacity [A.h]"].entries for solution in solutions
181180
]
182-
self.dc_values = discharge_capacities # Store as the x-axis values
181+
self.dc_values = discharge_capacities
183182

184-
# Set discharge capacity range
185183
self.min_dc = min(dc[0] for dc in discharge_capacities)
186184
self.max_dc = max(dc[-1] for dc in discharge_capacities)
187185

188-
# Scaling and unit specific to discharge capacity
189-
self.dc_scaling_factor = 1 # No scaling needed for discharge capacity
190186
self.dc_unit = "A.h"
191-
else:
192-
# Default to time
193-
self.ts_seconds = [solution.t for solution in solutions]
194-
min_t = np.min([t[0] for t in self.ts_seconds])
195-
max_t = np.max([t[-1] for t in self.ts_seconds])
196-
197-
hermite_interp = all(sol.hermite_interpolation for sol in solutions)
198-
199-
def t_sample(sol):
200-
if hermite_interp and n_t_linear > 2:
201-
# Linearly spaced time points
202-
t_linspace = np.linspace(sol.t[0], sol.t[-1], n_t_linear + 2)[1:-1]
203-
t_plot = np.union1d(sol.t, t_linspace)
204-
else:
205-
t_plot = sol.t
206-
return t_plot
207-
208-
ts_seconds = []
209-
for sol in solutions:
210-
# Sample time points for each sub-solution
211-
t_sol = [t_sample(sub_sol) for sub_sol in sol.sub_solutions]
212-
ts_seconds.append(np.concatenate(t_sol))
213-
self.ts_seconds = ts_seconds
214-
215-
# Set timescale
216-
if time_unit is None:
217-
# defaults depend on how long the simulation is
218-
if max_t >= 3600:
219-
time_scaling_factor = 3600 # time in hours
220-
self.time_unit = "h"
221-
else:
222-
time_scaling_factor = 1 # time in seconds
223-
self.time_unit = "s"
224-
elif time_unit == "seconds":
225-
time_scaling_factor = 1
226-
self.time_unit = "s"
227-
elif time_unit == "minutes":
228-
time_scaling_factor = 60
229-
self.time_unit = "min"
230-
elif time_unit == "hours":
231-
time_scaling_factor = 3600
187+
188+
# Default to time
189+
self.ts_seconds = [solution.t for solution in solutions]
190+
min_t = np.min([t[0] for t in self.ts_seconds])
191+
max_t = np.max([t[-1] for t in self.ts_seconds])
192+
193+
hermite_interp = all(sol.hermite_interpolation for sol in solutions)
194+
195+
def t_sample(sol):
196+
if hermite_interp and n_t_linear > 2:
197+
# Linearly spaced time points
198+
t_linspace = np.linspace(sol.t[0], sol.t[-1], n_t_linear + 2)[1:-1]
199+
t_plot = np.union1d(sol.t, t_linspace)
200+
else:
201+
t_plot = sol.t
202+
return t_plot
203+
ts_seconds = []
204+
for sol in solutions:
205+
# Sample time points for each sub-solution
206+
t_sol = [t_sample(sub_sol) for sub_sol in sol.sub_solutions]
207+
ts_seconds.append(np.concatenate(t_sol))
208+
self.ts_seconds = ts_seconds
209+
210+
# Set timescale
211+
if time_unit is None:
212+
# defaults depend on how long the simulation is
213+
if max_t >= 3600:
214+
time_scaling_factor = 3600 # time in hours
232215
self.time_unit = "h"
233216
else:
234-
raise ValueError(f"time unit '{time_unit}' not recognized")
235-
self.time_scaling_factor = time_scaling_factor
236-
self.min_t = min_t / time_scaling_factor
237-
self.max_t = max_t / time_scaling_factor
217+
time_scaling_factor = 1 # time in seconds
218+
self.time_unit = "s"
219+
elif time_unit == "seconds":
220+
time_scaling_factor = 1
221+
self.time_unit = "s"
222+
elif time_unit == "minutes":
223+
time_scaling_factor = 60
224+
self.time_unit = "min"
225+
elif time_unit == "hours":
226+
time_scaling_factor = 3600
227+
self.time_unit = "h"
228+
else:
229+
raise ValueError(f"time unit '{time_unit}' not recognized")
230+
self.time_scaling_factor = time_scaling_factor
231+
self.min_t = min_t / time_scaling_factor
232+
self.max_t = max_t / time_scaling_factor
238233

239234
# Prepare dictionary of variables
240235
# output_variables is a list of strings or lists, e.g.
@@ -435,8 +430,12 @@ def reset_axis(self):
435430
self.axis_limits = {}
436431
for key, variable_lists in self.variables.items():
437432
if variable_lists[0][0].dimensions == 0:
438-
x_min = self.min_t
439-
x_max = self.max_t
433+
if self.x_axis == "Discharge capacity [A.h]":
434+
x_min = self.min_dc
435+
x_max = self.max_dc
436+
else:
437+
x_min = self.min_t
438+
x_max = self.max_t
440439
elif variable_lists[0][0].dimensions == 1:
441440
x_min = self.first_spatial_variable[key][0]
442441
x_max = self.first_spatial_variable[key][-1]
@@ -458,22 +457,38 @@ def reset_axis(self):
458457

459458
# Get min and max variable values
460459
if self.variable_limits[key] == "fixed":
461-
# fixed variable limits: calculate "globlal" min and max
460+
# fixed variable limits: calculate "global" min and max
462461
spatial_vars = self.spatial_variable_dict[key]
463-
var_min = np.min(
464-
[
465-
ax_min(var(self.ts_seconds[i], **spatial_vars))
466-
for i, variable_list in enumerate(variable_lists)
467-
for var in variable_list
468-
]
469-
)
470-
var_max = np.max(
471-
[
472-
ax_max(var(self.ts_seconds[i], **spatial_vars))
473-
for i, variable_list in enumerate(variable_lists)
474-
for var in variable_list
475-
]
476-
)
462+
if self.x_axis == "Discharge capacity [A.h]":
463+
var_min = np.min(
464+
[
465+
ax_min(var(self.dc_values[i], **spatial_vars))
466+
for i, variable_list in enumerate(variable_lists)
467+
for var in variable_list
468+
]
469+
)
470+
var_max = np.max(
471+
[
472+
ax_max(var(self.dc_values[i], **spatial_vars))
473+
for i, variable_list in enumerate(variable_lists)
474+
for var in variable_list
475+
]
476+
)
477+
else:
478+
var_min = np.min(
479+
[
480+
ax_min(var(self.ts_seconds[i], **spatial_vars))
481+
for i, variable_list in enumerate(variable_lists)
482+
for var in variable_list
483+
]
484+
)
485+
var_max = np.max(
486+
[
487+
ax_max(var(self.ts_seconds[i], **spatial_vars))
488+
for i, variable_list in enumerate(variable_lists)
489+
for var in variable_list
490+
]
491+
)
477492
if np.isnan(var_min) or np.isnan(var_max):
478493
raise ValueError(
479494
"The variable limits are set to 'fixed' but the min and max "
@@ -568,7 +583,7 @@ def plot(self, t, dynamic=False):
568583
elif self.x_axis == "Discharge capacity [A.h]":
569584
full_dc = self.dc_values[i]
570585
(self.plots[key][i][j],) = ax.plot(
571-
full_dc / self.dc_scaling_factor,
586+
full_dc,
572587
variable(full_dc),
573588
color=self.colors[i],
574589
linestyle=linestyle,

0 commit comments

Comments
 (0)