@@ -374,53 +374,304 @@ def create_df(
374374)
375375
376376# %% 
377- plumes_over  =  ["run" ]
378- increase_resolution  =  100 
377+ from  itertools  import  cycle 
378+ 
379+ import  matplotlib .lines  as  mlines 
380+ import  matplotlib .patches  as  mpatches 
381+ 
382+ fig , ax  =  plt .subplots ()
383+ in_ts  =  small_ts .loc [pix .isin (variable = "variable_0" )]
384+ quantile_over  =  "run" 
385+ pre_calculated  =  False 
386+ observed  =  True 
379387quantiles_plumes  =  (
380-     (0.5 , 0.8 ),
388+     ((0.5 ,), 0.8 ),
389+     ((0.25 , 0.75 ), 0.75 ),
381390    ((0.05 , 0.95 ), 0.5 ),
382391)
392+ hue_var  =  "scenario" 
393+ hue_var_label  =  None 
394+ style_var  =  "variable" 
395+ style_var_label  =  None 
396+ palette  =  None 
397+ dashes  =  None 
398+ observed  =  True 
399+ increase_resolution  =  100 
400+ linewidth  =  2 
401+ 
402+ # The joy of plotting, you create everything yourself. 
403+ # TODO: split creation from use? 
404+ if  hue_var_label  is  None :
405+     hue_var_label  =  hue_var .capitalize ()
406+ if  style_var_label  is  None :
407+     style_var_label  =  style_var .capitalize ()
408+ 
409+ quantiles  =  []
410+ for  quantile_plot_def  in  quantiles_plumes :
411+     q_def  =  quantile_plot_def [0 ]
412+     try :
413+         for  q  in  q_def :
414+             quantiles .append (q )
415+     except  TypeError :
416+         quantiles .append (q_def )
417+ 
418+ _palette  =  {} if  palette  is  None  else  palette 
419+ 
420+ if  dashes  is  None :
421+     _dashes  =  {}
422+     lines  =  ["-" , "--" , "-." , ":" ]
423+     linestyle_cycler  =  cycle (lines )
424+ else :
425+     _dashes  =  dashes 
426+ 
427+ # Need to keep track of this, just in case we end up plotting only plumes 
428+ _plotted_lines  =  False 
429+ 
430+ quantile_labels  =  {}
431+ plotted_hues  =  []
432+ plotted_styles  =  []
433+ units_l  =  []
434+ for  q , alpha  in  quantiles_plumes :
435+     for  hue_value , hue_ts  in  in_ts .groupby (hue_var , observed = observed ):
436+         for  style_value , hue_style_ts  in  hue_ts .groupby (style_var , observed = observed ):
437+             # Remake in inner loop to avoid leaking between plots 
438+             pkwargs  =  {"alpha" : alpha }
439+ 
440+             if  pre_calculated :
441+                 # Should add some checks here 
442+                 raise  NotImplementedError ()
443+                 # Maybe something like the below 
444+                 # missing_quantile = False 
445+                 # for qt in q: 
446+                 #     if qt not in quantiles: 
447+                 #         warnings.warn( 
448+                 #             f"Quantile {qt} not available for {hue_value=} {style_value=}" 
449+                 #         ) 
450+                 #         missing_quantile = True 
451+ 
452+                 # if missing_quantile: 
453+                 #     continue 
454+             else :
455+                 _pdf  =  (
456+                     hue_ts .ct .to_df (increase_resolution = increase_resolution )
457+                     .ct .groupby_except (quantile_over )
458+                     .quantile (quantiles )
459+                     .ct .fix_index_name_after_groupby_quantile ()
460+                 )
383461
384- fig , ax  =  plt .subplots ()
385- for  scenario , s_ts  in  small_ts .loc [pix .isin (variable = "variable_0" )].groupby (
386-     "scenario" , observed = True 
387- ):
388-     for  quantiles , alpha  in  quantiles_plumes :
389-         s_quants  =  (
390-             s_ts .ct .to_df (increase_resolution = increase_resolution )
391-             .groupby (small_ts .index .names .difference (plumes_over ), observed = True )
392-             .quantile (quantiles )
393-         )
394-         if  isinstance (quantiles , tuple ):
395-             ax .fill_between (
396-                 s_quants .columns .values .squeeze (),
397-                 # As long as there are only two rows, 
398-                 # doesn't matter which way around you do this. 
399-                 s_quants .iloc [0 , :].values .squeeze (),
400-                 s_quants .iloc [1 , :].values .squeeze (),
401-                 alpha = alpha ,
402-                 # label=scenario, 
403-             )
404-         else :
405-             ax .plot (
406-                 s_quants .columns .values .squeeze (),
407-                 s_quants .values .squeeze (),
408-                 alpha = alpha ,
409-                 label = scenario ,
462+             if  hue_value  not  in plotted_hues :
463+                 plotted_hues .append (hue_value )
464+ 
465+             x_vals  =  _pdf .columns .values .squeeze ()
466+             # Require ur for this to work 
467+             # x_vals = get_plot_vals( 
468+             #     self.time_axis.bounds, 
469+             #     "self.time_axis.bounds", 
470+             #     warn_if_magnitudes=warn_if_plotting_magnitudes, 
471+             # ) 
472+ 
473+             if  palette  is  not None :
474+                 try :
475+                     pkwargs ["color" ] =  _palette [hue_value ]
476+                 except  KeyError :
477+                     error_msg  =  f"{ hue_value } { palette = }  
478+                     raise  KeyError (error_msg )
479+             elif  hue_value  in  _palette :
480+                 pkwargs ["color" ] =  _palette [hue_value ]
481+             # else: 
482+             #     # Let matplotlib default cycling do its thing 
483+ 
484+             n_q_for_plume  =  2 
485+             plot_plume  =  len (q ) ==  n_q_for_plume 
486+             plot_line  =  len (q ) ==  1 
487+ 
488+             if  plot_plume :
489+                 label  =  f"{ q [0 ] *  100 :.0f} { q [1 ] *  100 :.0f}  
490+ 
491+                 y_lower_vals  =  _pdf .loc [pix .ismatch (quantile = q [0 ])].values .squeeze ()
492+                 y_upper_vals  =  _pdf .loc [pix .ismatch (quantile = q [1 ])].values .squeeze ()
493+                 # Require ur for this to work 
494+                 # Also need the 1D check back in too 
495+                 # y_lower_vals = get_plot_vals( 
496+                 #     self.time_axis.bounds, 
497+                 #     "self.time_axis.bounds", 
498+                 #     warn_if_magnitudes=warn_if_plotting_magnitudes, 
499+                 # ) 
500+                 p  =  ax .fill_between (
501+                     x_vals ,
502+                     y_lower_vals ,
503+                     y_upper_vals ,
504+                     label = label ,
505+                     ** pkwargs ,
506+                 )
507+ 
508+                 if  palette  is  None :
509+                     _palette [hue_value ] =  p .get_facecolor ()[0 ]
510+ 
511+             elif  plot_line :
512+                 if  style_value  not  in plotted_styles :
513+                     plotted_styles .append (style_value )
514+ 
515+                 _plotted_lines  =  True 
516+ 
517+                 if  dashes  is  not None :
518+                     try :
519+                         pkwargs ["linestyle" ] =  _dashes [style_value ]
520+                     except  KeyError :
521+                         error_msg  =  f"{ style_value } { dashes = }  
522+                         raise  KeyError (error_msg )
523+                 else :
524+                     if  style_value  not  in _dashes :
525+                         _dashes [style_value ] =  next (linestyle_cycler )
526+ 
527+                     pkwargs ["linestyle" ] =  _dashes [style_value ]
528+ 
529+                 if  isinstance (q [0 ], str ):
530+                     label  =  q [0 ]
531+                 else :
532+                     label  =  f"{ q [0 ] *  100 :.0f}  
533+ 
534+                 y_vals  =  _pdf .loc [pix .ismatch (quantile = q [0 ])].values .squeeze ()
535+                 # Require ur for this to work 
536+                 # Also need the 1D check back in too 
537+                 # y_vals = get_plot_vals( 
538+                 #     self.time_axis.bounds, 
539+                 #     "self.time_axis.bounds", 
540+                 #     warn_if_magnitudes=warn_if_plotting_magnitudes, 
541+                 # ) 
542+                 p  =  ax .plot (
543+                     x_vals ,
544+                     y_vals ,
545+                     label = label ,
546+                     linewidth = linewidth ,
547+                     ** pkwargs ,
548+                 )[0 ]
549+ 
550+                 if  dashes  is  None :
551+                     _dashes [style_value ] =  p .get_linestyle ()
552+ 
553+                 if  palette  is  None :
554+                     _palette [hue_value ] =  p .get_color ()
555+ 
556+             else :
557+                 msg  =  f"quantiles to plot must be of length one or two, received: { q }  
558+                 raise  ValueError (msg )
559+ 
560+             if  label  not  in quantile_labels :
561+                 quantile_labels [label ] =  p 
562+ 
563+             # Once we have unit handling with matplotlib, we can remove this 
564+             # (and if matplotlib isn't set up, we just don't do unit handling) 
565+             units_l .extend (_pdf .pix .unique ("units" ).unique ().tolist ())
566+ 
567+     # Fake the line handles for the legend 
568+     hue_val_lines  =  [
569+         mlines .Line2D ([0 ], [0 ], color = _palette [hue_value ], label = hue_value )
570+         for  hue_value  in  plotted_hues 
571+     ]
572+ 
573+     legend_items  =  [
574+         mpatches .Patch (alpha = 0 , label = "Quantiles" ),
575+         * quantile_labels .values (),
576+         mpatches .Patch (alpha = 0 , label = hue_var_label ),
577+         * hue_val_lines ,
578+     ]
579+ 
580+     if  _plotted_lines :
581+         style_val_lines  =  [
582+             mlines .Line2D (
583+                 [0 ],
584+                 [0 ],
585+                 linestyle = _dashes [style_value ],
586+                 label = style_value ,
587+                 color = "gray" ,
588+                 linewidth = linewidth ,
410589            )
590+             for  style_value  in  plotted_styles 
591+         ]
592+         legend_items  +=  [
593+             mpatches .Patch (alpha = 0 , label = style_var_label ),
594+             * style_val_lines ,
595+         ]
596+     elif  dashes  is  not None :
597+         warnings .warn (
598+             "`dashes` was passed but no lines were plotted, the style settings " 
599+             "will not be used" 
600+         )
411601
412- ax .legend ()
602+     ax .legend (handles = legend_items , loc = "best" )
603+ 
604+     if  len (set (units_l )) ==  1 :
605+         ax .set_ylabel (units_l [0 ])
606+ 
607+     # return ax, legend_items 
608+ 
609+ 
610+ quantiles 
413611
414612# %% 
415- (
613+ demo_q   =   (
416614    small_ts .ct .to_df (increase_resolution = 5 )
417-     .groupby ( small_ts . index . names . difference ([ "run" ]),  observed = True )
615+     .ct . groupby_except ( "run" )
418616    .quantile ([0.05 , 0.5 , 0.95 ])
617+     .ct .fix_index_name_after_groupby_quantile ()
618+ )
619+ demo_q 
620+ 
621+ # %% 
622+ units_col  =  "units" 
623+ indf  =  demo_q 
624+ out_l  =  []
625+ 
626+ # The 'shortcut' 
627+ target_units  =  "Gt / yr" 
628+ locs_target_units  =  ((pix .ismatch (** {units_col : "**" }), target_units ),)
629+ locs_target_units  =  (
630+     (pix .ismatch (scenario = "scenario_2" ), "Gt / yr" ),
631+     (pix .ismatch (scenario = "scenario_0" ), "kt / yr" ),
632+     (
633+         demo_q .index .get_level_values ("scenario" ).isin (["scenario_1" ])
634+         &  demo_q .index .get_level_values ("variable" ).isin (["variable_1" ]),
635+         "t / yr" ,
636+     ),
419637)
638+ # locs_target_units = ( 
639+ #     (pix.ismatch(scenario="*"), "t / yr"), 
640+ # ) 
641+ 
642+ converted  =  None 
643+ for  locator , target_unit  in  locs_target_units :
644+     if  converted  is  None :
645+         converted  =  locator 
646+     else :
647+         converted  =  converted  |  locator 
648+ 
649+     def  _convert_unit (idf : pd .DataFrame ) ->  pd .DataFrame :
650+         start_units  =  idf .pix .unique (units_col ).tolist ()
651+         if  len (start_units ) >  1 :
652+             msg  =  f"{ start_units = }  
653+             raise  AssertionError (msg )
654+ 
655+         start_units  =  start_units [0 ]
656+         conversion_factor  =  UR .Quantity (1 , start_units ).to (target_unit ).m 
657+ 
658+         return  (idf  *  conversion_factor ).pix .assign (** {units_col : target_unit })
659+ 
660+     out_l .append (
661+         indf .loc [locator ]
662+         .groupby (units_col , observed = True , group_keys = False )
663+         .apply (_convert_unit )
664+     )
665+ 
666+ out  =  pix .concat ([* out_l , indf .loc [~ converted ]])
667+ if  isinstance (indf .index .dtypes [units_col ], pd .CategoricalDtype ):
668+     # Make sure that units stay as a category, if it started as one. 
669+     out  =  out .reset_index (units_col )
670+     out [units_col ] =  out [units_col ].astype ("category" )
671+     out  =  out .set_index (units_col , append = True ).reorder_levels (indf .index .names )
672+ 
673+ out 
420674
421675# %% [markdown] 
422- # - plot with basic control over labels 
423- # - plot with grouping and plumes for ranges (basically reproduce scmdata API) 
424676# - convert with more fine-grained control over interpolation 
425677#   (e.g. interpolation being passed as pd.Series) 
426- # - unit conversion 
0 commit comments