Skip to content

Commit 56ef457

Browse files
committed
Finish sketching out major APIs
1 parent 2141063 commit 56ef457

File tree

2 files changed

+316
-35
lines changed

2 files changed

+316
-35
lines changed

docs/tutorials/pandas_accessor_tutorial.py

Lines changed: 286 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -374,53 +374,304 @@ def create_df(
374374
)
375375

376376
# %%
377-
plumes_over = ["run"]
378-
increase_resolution = 100
377+
from itertools import cycle
378+
379+
import matplotlib.lines as mlines
380+
import matplotlib.patches as mpatches
381+
382+
fig, ax = plt.subplots()
383+
in_ts = small_ts.loc[pix.isin(variable="variable_0")]
384+
quantile_over = "run"
385+
pre_calculated = False
386+
observed = True
379387
quantiles_plumes = (
380-
(0.5, 0.8),
388+
((0.5,), 0.8),
389+
((0.25, 0.75), 0.75),
381390
((0.05, 0.95), 0.5),
382391
)
392+
hue_var = "scenario"
393+
hue_var_label = None
394+
style_var = "variable"
395+
style_var_label = None
396+
palette = None
397+
dashes = None
398+
observed = True
399+
increase_resolution = 100
400+
linewidth = 2
401+
402+
# The joy of plotting, you create everything yourself.
403+
# TODO: split creation from use?
404+
if hue_var_label is None:
405+
hue_var_label = hue_var.capitalize()
406+
if style_var_label is None:
407+
style_var_label = style_var.capitalize()
408+
409+
quantiles = []
410+
for quantile_plot_def in quantiles_plumes:
411+
q_def = quantile_plot_def[0]
412+
try:
413+
for q in q_def:
414+
quantiles.append(q)
415+
except TypeError:
416+
quantiles.append(q_def)
417+
418+
_palette = {} if palette is None else palette
419+
420+
if dashes is None:
421+
_dashes = {}
422+
lines = ["-", "--", "-.", ":"]
423+
linestyle_cycler = cycle(lines)
424+
else:
425+
_dashes = dashes
426+
427+
# Need to keep track of this, just in case we end up plotting only plumes
428+
_plotted_lines = False
429+
430+
quantile_labels = {}
431+
plotted_hues = []
432+
plotted_styles = []
433+
units_l = []
434+
for q, alpha in quantiles_plumes:
435+
for hue_value, hue_ts in in_ts.groupby(hue_var, observed=observed):
436+
for style_value, hue_style_ts in hue_ts.groupby(style_var, observed=observed):
437+
# Remake in inner loop to avoid leaking between plots
438+
pkwargs = {"alpha": alpha}
439+
440+
if pre_calculated:
441+
# Should add some checks here
442+
raise NotImplementedError()
443+
# Maybe something like the below
444+
# missing_quantile = False
445+
# for qt in q:
446+
# if qt not in quantiles:
447+
# warnings.warn(
448+
# f"Quantile {qt} not available for {hue_value=} {style_value=}"
449+
# )
450+
# missing_quantile = True
451+
452+
# if missing_quantile:
453+
# continue
454+
else:
455+
_pdf = (
456+
hue_ts.ct.to_df(increase_resolution=increase_resolution)
457+
.ct.groupby_except(quantile_over)
458+
.quantile(quantiles)
459+
.ct.fix_index_name_after_groupby_quantile()
460+
)
383461

384-
fig, ax = plt.subplots()
385-
for scenario, s_ts in small_ts.loc[pix.isin(variable="variable_0")].groupby(
386-
"scenario", observed=True
387-
):
388-
for quantiles, alpha in quantiles_plumes:
389-
s_quants = (
390-
s_ts.ct.to_df(increase_resolution=increase_resolution)
391-
.groupby(small_ts.index.names.difference(plumes_over), observed=True)
392-
.quantile(quantiles)
393-
)
394-
if isinstance(quantiles, tuple):
395-
ax.fill_between(
396-
s_quants.columns.values.squeeze(),
397-
# As long as there are only two rows,
398-
# doesn't matter which way around you do this.
399-
s_quants.iloc[0, :].values.squeeze(),
400-
s_quants.iloc[1, :].values.squeeze(),
401-
alpha=alpha,
402-
# label=scenario,
403-
)
404-
else:
405-
ax.plot(
406-
s_quants.columns.values.squeeze(),
407-
s_quants.values.squeeze(),
408-
alpha=alpha,
409-
label=scenario,
462+
if hue_value not in plotted_hues:
463+
plotted_hues.append(hue_value)
464+
465+
x_vals = _pdf.columns.values.squeeze()
466+
# Require ur for this to work
467+
# x_vals = get_plot_vals(
468+
# self.time_axis.bounds,
469+
# "self.time_axis.bounds",
470+
# warn_if_magnitudes=warn_if_plotting_magnitudes,
471+
# )
472+
473+
if palette is not None:
474+
try:
475+
pkwargs["color"] = _palette[hue_value]
476+
except KeyError:
477+
error_msg = f"{hue_value} not in palette. {palette=}"
478+
raise KeyError(error_msg)
479+
elif hue_value in _palette:
480+
pkwargs["color"] = _palette[hue_value]
481+
# else:
482+
# # Let matplotlib default cycling do its thing
483+
484+
n_q_for_plume = 2
485+
plot_plume = len(q) == n_q_for_plume
486+
plot_line = len(q) == 1
487+
488+
if plot_plume:
489+
label = f"{q[0] * 100:.0f}th - {q[1] * 100:.0f}th"
490+
491+
y_lower_vals = _pdf.loc[pix.ismatch(quantile=q[0])].values.squeeze()
492+
y_upper_vals = _pdf.loc[pix.ismatch(quantile=q[1])].values.squeeze()
493+
# Require ur for this to work
494+
# Also need the 1D check back in too
495+
# y_lower_vals = get_plot_vals(
496+
# self.time_axis.bounds,
497+
# "self.time_axis.bounds",
498+
# warn_if_magnitudes=warn_if_plotting_magnitudes,
499+
# )
500+
p = ax.fill_between(
501+
x_vals,
502+
y_lower_vals,
503+
y_upper_vals,
504+
label=label,
505+
**pkwargs,
506+
)
507+
508+
if palette is None:
509+
_palette[hue_value] = p.get_facecolor()[0]
510+
511+
elif plot_line:
512+
if style_value not in plotted_styles:
513+
plotted_styles.append(style_value)
514+
515+
_plotted_lines = True
516+
517+
if dashes is not None:
518+
try:
519+
pkwargs["linestyle"] = _dashes[style_value]
520+
except KeyError:
521+
error_msg = f"{style_value} not in dashes. {dashes=}"
522+
raise KeyError(error_msg)
523+
else:
524+
if style_value not in _dashes:
525+
_dashes[style_value] = next(linestyle_cycler)
526+
527+
pkwargs["linestyle"] = _dashes[style_value]
528+
529+
if isinstance(q[0], str):
530+
label = q[0]
531+
else:
532+
label = f"{q[0] * 100:.0f}th"
533+
534+
y_vals = _pdf.loc[pix.ismatch(quantile=q[0])].values.squeeze()
535+
# Require ur for this to work
536+
# Also need the 1D check back in too
537+
# y_vals = get_plot_vals(
538+
# self.time_axis.bounds,
539+
# "self.time_axis.bounds",
540+
# warn_if_magnitudes=warn_if_plotting_magnitudes,
541+
# )
542+
p = ax.plot(
543+
x_vals,
544+
y_vals,
545+
label=label,
546+
linewidth=linewidth,
547+
**pkwargs,
548+
)[0]
549+
550+
if dashes is None:
551+
_dashes[style_value] = p.get_linestyle()
552+
553+
if palette is None:
554+
_palette[hue_value] = p.get_color()
555+
556+
else:
557+
msg = f"quantiles to plot must be of length one or two, received: {q}"
558+
raise ValueError(msg)
559+
560+
if label not in quantile_labels:
561+
quantile_labels[label] = p
562+
563+
# Once we have unit handling with matplotlib, we can remove this
564+
# (and if matplotlib isn't set up, we just don't do unit handling)
565+
units_l.extend(_pdf.pix.unique("units").unique().tolist())
566+
567+
# Fake the line handles for the legend
568+
hue_val_lines = [
569+
mlines.Line2D([0], [0], color=_palette[hue_value], label=hue_value)
570+
for hue_value in plotted_hues
571+
]
572+
573+
legend_items = [
574+
mpatches.Patch(alpha=0, label="Quantiles"),
575+
*quantile_labels.values(),
576+
mpatches.Patch(alpha=0, label=hue_var_label),
577+
*hue_val_lines,
578+
]
579+
580+
if _plotted_lines:
581+
style_val_lines = [
582+
mlines.Line2D(
583+
[0],
584+
[0],
585+
linestyle=_dashes[style_value],
586+
label=style_value,
587+
color="gray",
588+
linewidth=linewidth,
410589
)
590+
for style_value in plotted_styles
591+
]
592+
legend_items += [
593+
mpatches.Patch(alpha=0, label=style_var_label),
594+
*style_val_lines,
595+
]
596+
elif dashes is not None:
597+
warnings.warn(
598+
"`dashes` was passed but no lines were plotted, the style settings "
599+
"will not be used"
600+
)
411601

412-
ax.legend()
602+
ax.legend(handles=legend_items, loc="best")
603+
604+
if len(set(units_l)) == 1:
605+
ax.set_ylabel(units_l[0])
606+
607+
# return ax, legend_items
608+
609+
610+
quantiles
413611

414612
# %%
415-
(
613+
demo_q = (
416614
small_ts.ct.to_df(increase_resolution=5)
417-
.groupby(small_ts.index.names.difference(["run"]), observed=True)
615+
.ct.groupby_except("run")
418616
.quantile([0.05, 0.5, 0.95])
617+
.ct.fix_index_name_after_groupby_quantile()
618+
)
619+
demo_q
620+
621+
# %%
622+
units_col = "units"
623+
indf = demo_q
624+
out_l = []
625+
626+
# The 'shortcut'
627+
target_units = "Gt / yr"
628+
locs_target_units = ((pix.ismatch(**{units_col: "**"}), target_units),)
629+
locs_target_units = (
630+
(pix.ismatch(scenario="scenario_2"), "Gt / yr"),
631+
(pix.ismatch(scenario="scenario_0"), "kt / yr"),
632+
(
633+
demo_q.index.get_level_values("scenario").isin(["scenario_1"])
634+
& demo_q.index.get_level_values("variable").isin(["variable_1"]),
635+
"t / yr",
636+
),
419637
)
638+
# locs_target_units = (
639+
# (pix.ismatch(scenario="*"), "t / yr"),
640+
# )
641+
642+
converted = None
643+
for locator, target_unit in locs_target_units:
644+
if converted is None:
645+
converted = locator
646+
else:
647+
converted = converted | locator
648+
649+
def _convert_unit(idf: pd.DataFrame) -> pd.DataFrame:
650+
start_units = idf.pix.unique(units_col).tolist()
651+
if len(start_units) > 1:
652+
msg = f"{start_units=}"
653+
raise AssertionError(msg)
654+
655+
start_units = start_units[0]
656+
conversion_factor = UR.Quantity(1, start_units).to(target_unit).m
657+
658+
return (idf * conversion_factor).pix.assign(**{units_col: target_unit})
659+
660+
out_l.append(
661+
indf.loc[locator]
662+
.groupby(units_col, observed=True, group_keys=False)
663+
.apply(_convert_unit)
664+
)
665+
666+
out = pix.concat([*out_l, indf.loc[~converted]])
667+
if isinstance(indf.index.dtypes[units_col], pd.CategoricalDtype):
668+
# Make sure that units stay as a category, if it started as one.
669+
out = out.reset_index(units_col)
670+
out[units_col] = out[units_col].astype("category")
671+
out = out.set_index(units_col, append=True).reorder_levels(indf.index.names)
672+
673+
out
420674

421675
# %% [markdown]
422-
# - plot with basic control over labels
423-
# - plot with grouping and plumes for ranges (basically reproduce scmdata API)
424676
# - convert with more fine-grained control over interpolation
425677
# (e.g. interpolation being passed as pd.Series)
426-
# - unit conversion

src/continuous_timeseries/pandas_accessors.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,16 @@ def differentiate(
218218

219219
return res
220220

221+
def groupby_except(
222+
self, non_groupers: str | list[str], observed: bool = True
223+
) -> pd.core.groupby.generic.SeriesGroupBy:
224+
if isinstance(non_groupers, str):
225+
non_groupers = [non_groupers]
226+
227+
return self._series.groupby(
228+
self._series.index.names.difference(non_groupers), observed=observed
229+
)
230+
221231
def plot(
222232
self,
223233
label: str | tuple[str, ...] | None = None,
@@ -341,6 +351,9 @@ def get_timeseries_parallel_helper(
341351

342352
res = getattr(df, meth_to_call)(
343353
# TODO: make this injectable too
354+
# This will also allow us to introduce an extra layer
355+
# to handle the case when interpolation is a Series,
356+
# rather than the same across all rows.
344357
Timeseries.from_pandas_series,
345358
axis="columns",
346359
interpolation=interpolation,
@@ -421,6 +434,23 @@ def to_timeseries( # noqa: PLR0913
421434

422435
return res
423436

437+
def groupby_except(
438+
self, non_groupers: str | list[str], observed: bool = True
439+
) -> pd.core.groupby.generic.DataFrameGroupBy:
440+
if isinstance(non_groupers, str):
441+
non_groupers = [non_groupers]
442+
443+
return self._df.groupby(
444+
self._df.index.names.difference(non_groupers), observed=observed
445+
)
446+
447+
def fix_index_name_after_groupby_quantile(self) -> pd.DataFrame:
448+
# TODO: think about doing in place
449+
res = self._df.copy()
450+
res.index = res.index.rename({None: "quantile"})
451+
452+
return res
453+
424454

425455
def register_pandas_accessor(namespace: str = "ct") -> None:
426456
"""

0 commit comments

Comments
 (0)