import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import xarray as xr
import pandas as pd
import proplot as pplt
from datetime import datetime
import warnings
from scipy.stats import linregress
import json
import os
from unox import data as udata
from unox.HPC.data0.dataset import uarray
from unox.HPC.data0.verify_dataset import verify_dataset, verify_var
from unox.HPC.data0.verify_dtype import verify_number
from unox.HPC.data0.paths import verify_path
from unox import plot_format as uplt_fmt
# Module-wide matplotlib text sizes, applied in a single update
mpl.rcParams.update({
    'font.size': 16,
    'axes.labelsize': 16,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})
# Font size used for the figure-level titles set via `fig.suptitle`
title_font_size = 20
# Have proplot label subplots with a-b-c markers
pplt.rc.abc = True
[docs]
def plot_extent(
    dataset,
    **kwargs,
):
    """ Draw the latitude/longitude bounding box of a dataset on a world map.

    The dataset is wrapped in a `uarray`, its extent is read with
    `udata.get_extent()`, and the resulting rectangle is outlined in red on a
    Robinson-projection map centered on the middle longitude of the extent.

    Parameters
    ----------
    dataset : `str`, `xarray.Dataset`, `xarray.DataArray`
        The dataset whose latitude/longitude extent is drawn.
    **kwargs : keyword arguments
        Forwarded unchanged to `uarray()`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure holding the extent map.

    Examples
    --------
    >>> fig = plot_extent('inputfiles/no2_2019_JFM/no2_2019_JFM.nc')
    >>> fig = plot_extent('no2_2019_JFM', is_input_set=True)
    """
    # Wrapping in `uarray` doubles as validation of `dataset`
    wrapped = uarray(dataset, **kwargs)
    # Extent is unpacked in the order: lat min, lat max, lon min, lon max
    south, north, west, east = udata.get_extent(wrapped.xr)
    # Center the projection on the middle of the longitude range
    central_lon = (west + east) / 2
    # Build a single-panel Robinson map
    fig = pplt.figure(refwidth=10)
    axs = fig.subplots(nrows=1, proj='robin', proj_kw={'lon_0': central_lon})
    # Trace the closed rectangle corner by corner, ending where it started
    box_lons = [west, west, east, east, west]
    box_lats = [south, north, north, south, south]
    axs.plot(box_lons, box_lats, color='red', lw=2)
    # Gridlines, coastlines, labels, and the title
    axs.format(
        suptitle=f'Extent of {wrapped.name}',
        latlines=30, lonlines=30, coast=True,
        labels=True, gridminor=True
    )
    return fig
[docs]
def plot_lats_lons(
    dataset,
    padding=0.1,
    **kwargs,
):
    """ Visualize the latitude/longitude grid of a dataset as a checkerboard.

    Each grid cell is shaded in an alternating gray pattern so that the
    spacing of the `lat` and `lon` coordinates is visible on the map.

    Parameters
    ----------
    dataset : `str`, `xarray.Dataset`, `xarray.DataArray`
        The dataset whose latitude and longitude grid is drawn.
    padding : `float`, optional
        Margin, as a fraction of the total extent, added around the data
        when setting the map limits.
        Default is `0.1`.
    **kwargs : keyword arguments
        Forwarded unchanged to `uarray()`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure holding the grid map.

    Examples
    --------
    >>> fig = plot_lats_lons('inputfiles/no2_2019_JFM/no2_2019_JFM.nc')
    >>> fig = plot_lats_lons('no2_2019_JFM', is_input_set=True)
    """
    # Wrapping in `uarray` doubles as validation of `dataset`;
    # `padding` is left for `pad_extent()` to validate
    wrapped = uarray(dataset, **kwargs)
    # Data extent, then the same extent widened by `padding`
    data_extent = udata.get_extent(wrapped.xr)
    padded = uplt_fmt.pad_extent(data_extent, padding)
    south, north, west, east = padded
    # 2D coordinate grids built from the 1D lon/lat coordinate values
    lon_values = wrapped.xr.lon.values
    lat_values = wrapped.xr.lat.values
    lon_grid, lat_grid = np.meshgrid(lon_values, lat_values)
    # Single-panel cylindrical map
    fig = pplt.figure(refwidth=10)
    axs = fig.subplots(nrows=1, proj='cyl')
    # Medium resolution for features such as coastlines
    pplt.rc.reso = 'med'
    # Alternating 0/1 pattern: parity of (row index + column index)
    row_idx, col_idx = np.indices(lon_grid.shape)
    checker = (row_idx + col_idx) % 2
    # Shade the grid cells with the checker pattern
    axs.pcolorfast(lon_grid, lat_grid, checker, cmap="gray", alpha=0.5)
    # Map limits, gridlines, coastlines, labels, and the title
    axs.format(
        lonlim=(west, east), latlim=(south, north),
        suptitle=f'Coordinates of {wrapped.name}',
        latlines=10, lonlines=10, coast=True,
        labels=True, gridminor=True
    )
    return fig
[docs]
def map_ax(
    xr_arr,
    ax,
    plt_title=None,
    cmap=pplt.Colormap('Fire'),
    cbar_max=None,
    cbar_min=None,
    cb_ext='neither',
    padding=0.1,
    **kwargs,
):
    """ Plot a geographic map from an xarray DataArray.

    Draw the provided DataArray on the given axes with `pcolormesh`, set the
    map limits to the (padded) latitude/longitude extent of the data, and
    build a colorbar label from the array's `long_name` and `units`.

    Parameters
    ----------
    xr_arr : `xarray.DataArray`
        The xarray data to plot. Must not have a time dimension. Size-1
        dimensions are squeezed out before plotting.
    ax : `proplot.axes.Axes`
        The axes on which to plot the data.
    plt_title : `str`, optional
        The title for the plot.
        Default is `None`.
    cmap : `matplotlib.colors.Colormap`, optional
        The colormap to use for the plot.
        Default is `pplt.Colormap('Fire')`.
    cbar_max : `float`, optional
        Maximum value for the colorbar. When `None`, it uses the max value in the data.
        Default is `None`.
    cbar_min : `float`, optional
        Minimum value for the colorbar. When `None`, it uses the min value in the data.
        Default is `None`.
    cb_ext : `str`, optional
        How to extend the ends of the colorbar. Can be `'neither'`, `'both'`, `'min'`, or `'max'`.
        Default is `'neither'`.
    padding : `float`, optional
        The padding (fraction of total extent) to add to the map extent.
        Default is `0.1`.
    **kwargs : keyword arguments
        Additional keyword arguments accepted to facilitate wrapper functions.
        They are not used by this function.

    Returns
    -------
    this_map_ax
        The object returned by `ax.pcolormesh()`; callers in this module pass
        it to a colorbar as the mappable.
    clrbar_label : `str`
        The label for the colorbar, formatted as `"<long_name> (<units>)"`.
        Falls back to `"var (units)"` when either attribute is missing.

    Raises
    ------
    TypeError
        If `xr_arr`, `ax`, `plt_title`, or `cmap` has the wrong type.
    ValueError
        If `cb_ext` is not one of the accepted strings, or if `xr_arr` still
        has dimensions other than `lat`/`lon` after squeezing.

    Examples
    --------
    >>> import proplot as pplt
    >>> fig, axs = pplt.subplots(nrows=1, ncols=2, proj='cyl')
    >>> this_xr = uarray('no2_2019_JFM', is_input_set=True).xr['no2']
    >>> this_xr = this_xr.sel(time='2019-01-02').squeeze(drop=True)
    >>> this_ax, cb_label = map_ax(this_xr, axs[0])
    """
    # Verify argument types
    if not isinstance(xr_arr, xr.DataArray):
        raise TypeError(f"(map_ax) `xr_arr` must be an xarray DataArray. Got type: {type(xr_arr)}")
    if not isinstance(ax, pplt.axes.Axes):
        raise TypeError(f"(map_ax) `ax` must be a proplot Axes object. Got type: {type(ax)}")
    if not isinstance(plt_title, (type(None), str)):
        raise TypeError(f"(map_ax) `plt_title` must be a string or None. Got type: {type(plt_title)}")
    if not isinstance(cmap, mpl.colors.Colormap):
        raise TypeError(f"(map_ax) `cmap` must be a matplotlib Colormap. Got type: {type(cmap)}")
    if cbar_max is not None:
        verify_number(cbar_max)
    if cbar_min is not None:
        verify_number(cbar_min)
    if cb_ext not in ['neither', 'both', 'min', 'max']:
        raise ValueError(f"(map_ax) `cb_ext` must be 'neither', 'both', 'min', or 'max'. Got: {cb_ext}")
    # `padding` is verified in `pad_extent()`
    # Verify the xr_arr. Assume there is no time dimension
    xr_arr = verify_dataset(xr_arr, check_time=False)
    # If there are any dimensions of size 1 (var, for example), squeeze them out
    xr_arr = xr_arr.squeeze(drop=True)
    # Reject any remaining dimension that is not `lat` or `lon`
    if not set(xr_arr.dims).issubset({'lat', 'lon'}):
        raise ValueError(f"(map_ax) `xr_arr` must have only 'lat' and 'lon' dimensions after squeezing. Got dimensions: {xr_arr.dims}")
    # Get the long name and units for the plot labels. Only a missing
    # attribute triggers the placeholder fallback; any other error propagates.
    try:
        var_name = xr_arr.long_name
        var_unit = xr_arr.units
    except AttributeError:
        var_name = 'var'
        var_unit = 'units'
    # Find the min and max lat and lon values
    this_extent = udata.get_extent(xr_arr, check_time=False)
    # Enlarge the extent of the map by the given padding value
    p_lat_min, p_lat_max, p_lon_min, p_lon_max = uplt_fmt.pad_extent(this_extent, padding)
    # Default the colorbar limits to the data range; `np.unique(...)[0]`
    # unwraps the 0-d reduction result into a scalar
    if cbar_max is None:
        cbar_max = np.unique(xr_arr.max().values)[0]
    if cbar_min is None:
        cbar_min = np.unique(xr_arr.min().values)[0]
    # Plot the data, use `discrete=False` to set a continuous colorbar
    this_map_ax = ax.pcolormesh(xr_arr, vmin=cbar_min, vmax=cbar_max, discrete=False, extend=cb_ext, cmap=cmap)
    # Format the map
    ax.format(
        lonlim=(p_lon_min, p_lon_max), latlim=(p_lat_min, p_lat_max),
        title=plt_title,
        latlines=10, lonlines=10, coast=True,
        labels=True, gridminor=True
    )
    # Assemble colorbar label
    clrbar_label = f"{var_name} ({var_unit})"
    # Return the plotted mesh and colorbar label
    return this_map_ax, clrbar_label
[docs]
def plot_var_maps(
    dataset,
    vars=['nox'],
    restrict_lat_lon_to=None,
    ens_mem=None,
    avg_over=True,
    sum_over=False,
    add_title=True,
    add_clrbar=True,
    **kwargs,
):
    """ Draw one map per requested variable of a dataset.

    The dataset is reduced over time with `select_time()` (either averaged or
    summed), optionally cropped to the domain of another dataset, and each
    variable is then drawn on its own subplot through `map_ax()`.

    Parameters
    ----------
    dataset : `uarray`, `str`, `xarray.Dataset`, `xarray.DataArray`
        The dataset from which the variables are taken.
    vars : `str`, `list`, optional
        Name, or list of names, of the variable(s) to plot.
        Default is `['nox']`.
    restrict_lat_lon_to : `str`, `xr.DataArray`, `None`, optional
        A dataset (or path to one) whose latitude/longitude range the plotted
        data is restricted to. With `None`, no restriction is applied.
        Default is `None`.
    ens_mem : `int`, `None`, optional
        ID of the ensemble member to plot. With `None` on an ensemble
        dataset, exactly one variable must be given and every ensemble
        member of it is plotted.
        Default is `None`.
    avg_over : `bool`, optional
        Average over the time steps. Exactly one of `avg_over` and
        `sum_over` must be `True`.
        Default is `True`.
    sum_over : `bool`, optional
        Sum over the time steps. Exactly one of `avg_over` and `sum_over`
        must be `True`.
        Default is `False`.
    add_title : `bool`, optional
        Whether to put an overall title on the figure.
        Default is `True`.
    add_clrbar : `bool`, optional
        Whether to attach a colorbar below each subplot.
        Default is `True`.
    **kwargs : keyword arguments
        Forwarded to `uarray`, `select_time()`, `set_fig_row_col()`, and `map_ax()`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure holding the maps.

    Raises
    ------
    TypeError
        If `vars`, `restrict_lat_lon_to`, `ens_mem`, `avg_over`, or
        `sum_over` has the wrong type.
    ValueError
        If `vars` is empty, if `avg_over` and `sum_over` are both `True` or
        both `False`, or if `ens_mem` and the ensemble status of the dataset
        are inconsistent.

    Examples
    --------
    >>> fig = plot_var_maps('no2_example_run', is_predict=True, vars=['no2_pred'], datetime='2019-06-01', interval='30D', avg_over=True)
    """
    # Wrapping in `uarray` doubles as validation of `dataset`
    u_arr = uarray(dataset, **kwargs)
    # Normalize `vars` to a non-empty list of strings
    if isinstance(vars, list):
        for entry in vars:
            if not isinstance(entry, str):
                raise TypeError(f"(plot_var_maps) Each entry in `vars` must be a string. Got type: {type(entry)}")
        if len(vars) == 0:
            raise ValueError("(plot_var_maps) `vars` list cannot be empty.")
    elif isinstance(vars, str):
        vars = [vars]
    else:
        raise TypeError(f"(plot_var_maps) `vars` must be a list of variable names or a single variable name string. Got type: {type(vars)}")
    if not isinstance(restrict_lat_lon_to, (type(None), str, xr.DataArray)):
        raise TypeError(f"(plot_var_maps) `restrict_lat_lon_to` must be a string, `xr.DataArray`, or `None`. Got type: {type(restrict_lat_lon_to)}")
    # Suffixes derived from the ensemble member: one for the title, one for variable names
    if ens_mem is None:
        title_ens_ID = ""
        ens_ID = ""
    elif isinstance(ens_mem, int):
        title_ens_ID = f"({ens_mem:02d})"
        ens_ID = f"_{ens_mem:02d}"
    else:
        raise TypeError(f"(plot_var_maps) `ens_mem` must be an integer or None. Got type: {type(ens_mem)}")
    if not isinstance(avg_over, bool):
        raise TypeError(f"(plot_var_maps) `avg_over` must be a bool. Got type: {type(avg_over)}")
    if not isinstance(sum_over, bool):
        raise TypeError(f"(plot_var_maps) `sum_over` must be a bool. Got type: {type(sum_over)}")
    # Exactly one of the two time reductions has to be requested
    if avg_over and sum_over:
        raise ValueError("(plot_var_maps) Cannot have both `avg_over` and `sum_over` set to `True`.")
    if not (avg_over or sum_over):
        raise ValueError("(plot_var_maps) Cannot have both `avg_over` and `sum_over` set to `False`.")
    # Reduce over time; the second return value describes the period for the title
    u_arr.xr, title_segment = select_time(u_arr.xr, avg_over=avg_over, sum_over=sum_over, **kwargs)
    # Optionally crop to the domain of another dataset
    if restrict_lat_lon_to is not None:
        domain_xr = uarray(restrict_lat_lon_to).xr
        u_arr.xr, _ = udata.match_domains(u_arr.xr, domain_xr, require_equal=False)
    # Reconcile `ens_mem` with whether the dataset is an ensemble
    if u_arr._is_ensemble():
        if ens_mem is None:
            if len(vars) != 1:
                raise ValueError(f"(plot_var_maps) `dataset` is an ensemble of runs but `ens_mem` was not specified and multiple variables were given to plot. Please specify `ens_mem` or provide a single variable to plot all ensemble members.")
            # Expand the single variable into one name per ensemble member
            n_members = u_arr.xr.attrs['ensemble_size']
            base_var = vars[0]
            vars = [f"{base_var}_{member:02d}" for member in range(1, n_members + 1)]
            title_ens_ID = "(all ensemble members)"
    elif ens_mem is not None:
        raise ValueError(f"(plot_var_maps) `ens_mem` specified as {ens_mem} but dataset is not an ensemble of runs.")
    # Lay out one cylindrical-projection subplot per variable
    fig = pplt.figure(refwidth=10)
    n_rows, n_cols = uplt_fmt.set_fig_row_col(len(vars), **kwargs)
    axs = fig.subplots(nrows=n_rows, ncols=n_cols, proj='cyl')
    # Medium resolution for features such as coastlines
    pplt.rc.reso = 'med'
    for panel, base_name in enumerate(vars):
        full_name = base_name + ens_ID
        # Make sure the variable exists before extracting it
        verify_var(u_arr.xr, full_name)
        mesh, clrbar_label = map_ax(
            u_arr.xr[full_name],
            axs[panel],
            **kwargs,
        )
        if add_clrbar:
            axs[panel].colorbar(mesh, loc='b', label=clrbar_label)
    if add_title:
        fig.suptitle(f"{u_arr.name}{title_ens_ID} {title_segment}", fontsize=title_font_size)
    return fig
[docs]
def _date_only(timestamp):
    """ Return the date part of a timestamp's string form, dropping any time of day. """
    # Handles both 'YYYY-MM-DDThh:mm' and 'YYYY-MM-DD hh:mm' style strings
    return str(timestamp).split('T')[0].split(' ')[0]


def select_time(
    xr_data,
    start_date=None,
    end_date=None,
    interval=None,
    avg_over=False,
    sum_over=False,
    **kwargs,
):
    """ Select a time slice from an xarray dataset.

    Select the data between `start_date` and `end_date` and, optionally,
    reduce over the time dimension. With `avg_over`, the selected slice is
    averaged. With `sum_over`, the sum is taken over the entire time
    dimension of `xr_data` (not only the selected slice).

    Parameters
    ----------
    xr_data : `xarray.Dataset`, `xarray.DataArray`
        The xarray data to select from. Must have a time dimension.
    start_date : `str`, `numpy.timedelta64`, `None`, optional
        First date and time to select from the data.
        If `None`, the date of the first time value in `xr_data` is used.
        Default is `None`.
    end_date : `str`, `numpy.timedelta64`, `None`, optional
        Last date and time to select from the data.
        If `None` and `interval` is `None`, the date of the last time value
        in `xr_data` is used.
        Default is `None`.
    interval : `str`, `numpy.timedelta64`, `None`, optional
        If provided and `end_date` is `None`, calculates `end_date` based on this
        interval. If a string, must be in the format `XT` where `X` is an integer
        and `T` is the unit (`D` for days, `M` for months, or `Y` for years).
        Ignored when `end_date` is given.
        Default is `None`.
    avg_over : `bool`, optional
        Whether to average over the selected time steps.
        Cannot be `True` at the same time as `sum_over`.
        Default is `False`.
    sum_over : `bool`, optional
        Whether to sum across all time steps.
        Cannot be `True` at the same time as `avg_over`.
        Default is `False`.
    **kwargs : keyword arguments
        Additional keyword arguments accepted to facilitate wrapper functions.
        They are not used by this function.

    Returns
    -------
    xr_sel_time : `xarray.Dataset`, `xarray.DataArray`
        The input xarray over the specified time interval.
        If either `avg_over` or `sum_over` is `True`, the return will not have a time dimension.
    title_segment : `str`
        A segment of the title string for the plot with time information.

    Raises
    ------
    TypeError
        If `start_date`, `end_date`, `interval`, `avg_over`, or `sum_over`
        has the wrong type.
    ValueError
        If both `avg_over` and `sum_over` are `True`, or if `end_date` cannot
        be computed from `start_date` and `interval`.

    Examples
    --------
    >>> xr_data = uarray('no2_2019_JFM', is_input_set=True).xr
    >>> xr_sel_time, title_segment = select_time(xr_data, start_date='2019-01-15')
    """
    # Verify argument types
    xr_data = verify_dataset(xr_data, check_time=True)
    if not isinstance(start_date, (type(None), str, np.timedelta64)):
        raise TypeError(f"(select_time) `start_date` must be None, a string, or a numpy.timedelta64. Got type: {type(start_date)}")
    if not isinstance(end_date, (type(None), str, np.timedelta64)):
        raise TypeError(f"(select_time) `end_date` must be None, a string, or a numpy.timedelta64. Got type: {type(end_date)}")
    if not isinstance(interval, (type(None), str, np.timedelta64)):
        raise TypeError(f"(select_time) `interval` must be None, a string, or a numpy.timedelta64. Got type: {type(interval)}")
    if not isinstance(avg_over, bool):
        raise TypeError(f"(select_time) `avg_over` must be a bool. Got type: {type(avg_over)}")
    if not isinstance(sum_over, bool):
        raise TypeError(f"(select_time) `sum_over` must be a bool. Got type: {type(sum_over)}")
    if avg_over and sum_over:
        raise ValueError("(select_time) Cannot have both `avg_over` and `sum_over` set to `True`.")
    # Calculate the start and end dates based on the given parameters
    if start_date is None:
        # Use the first available time value as the start date
        start_date = _date_only(xr_data.time.values[0])
    if end_date is None and interval is None:
        # Use the last available time value as the end date
        end_date = _date_only(xr_data.time.values[-1])
    if end_date is None and interval is not None:
        # Add the interval to the start date to get the end date
        try:
            end_date = _date_only(udata.add_amount_to_date(start_date, interval))
        except Exception as err:
            # Keep the original failure attached as the cause for debugging
            raise ValueError(f"(select_time) Invalid `interval` value: {interval}") from err
        # Get the value and unit of the interval for the title
        interval_num, interval_unit = udata.get_increment_info(interval)
        interval_str = f" ({interval_num} {interval_unit})"
    else:
        # No interval-derived end date, so nothing to add to the title
        interval_str = ""
    # Select the time interval
    ## Use `drop=False` in `sel()` to maintain attributes for later
    xr_sel_time = xr_data.sel(time=slice(start_date, end_date), drop=False)
    # Format the dates for the title string
    if np.datetime64(start_date) == np.datetime64(end_date):
        date_string = f"on {str(start_date)}"
    else:
        # Report the first and last dates actually present in the selection
        start_date_string = _date_only(xr_sel_time.time.values[0])
        end_date_string = _date_only(xr_sel_time.time.values[-1])
        date_string = f"from {start_date_string} to {end_date_string}"
    # Check whether to average or sum over time
    if avg_over:
        # Take the mean over the time axis
        ## Use `keep_attrs=True` in mean() to maintain attributes for later
        xr_sel_time = xr_sel_time.mean(dim='time', keep_attrs=True)
        title_segment = f"Averaged {date_string}{interval_str}"
    elif sum_over:
        # Sum all the variables over time
        ## Use `keep_attrs=True` in sum() to maintain attributes for later
        # NOTE(review): this sums the full `xr_data`, not the selected slice,
        # while the title reports the selected date range — confirm whether
        # `xr_sel_time.sum(...)` was intended.
        xr_sel_time = xr_data.sum(dim='time', keep_attrs=True).squeeze(drop=True)
        title_segment = f"Summed {date_string}{interval_str}"
    else:
        # No reduction: the selection keeps its time dimension
        title_segment = f"{date_string}{interval_str}"
    return xr_sel_time, title_segment
[docs]
def plot_run_analysis(
    dataset,
    start_date='2019-01-02',
    interval='1Y',
    avg_over=True,
    restrict_lat_lon_to=None,
    ens_mem=None,
    add_corr_plots=True,
    stage1_only=False,
    clr_bar_scale=0.5,
    clr_map=pplt.Colormap('Balance'),
    **kwargs,
):
    """ Compare a model run to the truth with maps and optional correlation plots.

    Create a set of maps comparing the truth and model predictions (stages 1 and 2), and optionally add correlation plots.
    If both stages and correlations are plotted, the figure will have 9 subplots:

    1. 'Truth' map
    2. Stage 1 prediction map
    3. Stage 2 prediction map
    4. Difference between 'Truth' and Stage 1 map
    5. Difference between 'Truth' and Stage 2 map
    6. Difference between Stage 1 and Stage 2 map
    7. Correlation plot between 'Truth' and Stage 1
    8. Correlation plot between 'Truth' and Stage 2
    9. Correlation plot between Stage 1 and Stage 2

    With `stage1_only=True`, only the 'Truth' map, the Stage 1 map, their
    difference, and (optionally) the 'Truth' vs. Stage 1 correlation plot
    are drawn.

    Parameters
    ----------
    dataset : `str`, `xarray.Dataset`, `xarray.DataArray`, `uarray`
        The dataset for which to make comparison maps. Must be a predictions dataset.
    start_date : `str`, `numpy.timedelta64`, `None`, optional
        First date and time to select from the data file.
        Default is `'2019-01-02'`.
    interval : `str`, `numpy.timedelta64`, `None`, optional
        Passed to `select_time()` to calculate the end date of the selection.
        If a string, it must be in the format `XT` where `X` is an integer and `T` is the unit
        (`D`, `M`, or `Y`).
        Default is `'1Y'`.
    avg_over : `bool`, optional
        Whether to average over all time steps.
        Default is `True`.
    restrict_lat_lon_to : `str`, `xr.DataArray`, `None`, optional
        A dataset (or path to one) to restrict the latitude and longitude range to.
        If `None`, the entire dataset is used.
        Default is `None`.
    ens_mem : `int`, `None`, optional
        The ID of the ensemble member to plot.
        If `None` and the dataset is an ensemble, a warning is issued and
        ensemble member 1 is used.
        Default is `None`.
    add_corr_plots : `bool`, optional
        Whether to add a row of correlation plots to the figure.
        Default is `True`.
    stage1_only : `bool`, optional
        If `True`, produce graphs only for stage 1.
        If `False`, produce graphs for both stage 1 and stage 2.
        Default is `False`.
    clr_bar_scale : `float`, optional
        Scale factor for the colorbar, between 0 and 1. Any value other than
        1 shrinks the color range and extends both ends of the colorbar.
        Default is `0.5`.
    clr_map : `matplotlib.colors.Colormap`, optional
        The colormap to use for the map plots.
        Default is `pplt.Colormap('Balance')`.
    **kwargs : keyword arguments
        Additional keyword arguments to pass to `select_time()`, `map_ax()`, and `corr_plot()`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure object containing the plots.

    Raises
    ------
    TypeError
        If an argument has the wrong type.
    ValueError
        If `clr_bar_scale` is outside `[0, 1]`, or if `ens_mem` is given for
        a dataset that is not an ensemble of runs.

    Examples
    --------
    >>> fig = plot_run_analysis('no2_example_run', start_date='2019-01-02', interval='364D',
    ...     avg_over=True, restrict_lat_lon_to='../datafiles/sample_data/nox_2019_t106_US.nc', add_corr_plots=True)
    """
    # Verify argument types
    # Making `uarray` object verifies `dataset`
    pred_uarr = uarray(dataset, is_predict=True)
    if not isinstance(start_date, (type(None), str, np.timedelta64)):
        raise TypeError(f"(plot_run_analysis) `start_date` must be None, a string, or a numpy.timedelta64. Got type: {type(start_date)}")
    if not isinstance(interval, (type(None), str, np.timedelta64)):
        raise TypeError(f"(plot_run_analysis) `interval` must be None, a string, or a numpy.timedelta64. Got type: {type(interval)}")
    if not isinstance(avg_over, bool):
        raise TypeError(f"(plot_run_analysis) `avg_over` must be a bool. Got type: {type(avg_over)}")
    if not isinstance(restrict_lat_lon_to, (type(None), str, xr.DataArray)):
        raise TypeError(f"(plot_run_analysis) `restrict_lat_lon_to` must be a string, `xr.DataArray`, or `None`. Got type: {type(restrict_lat_lon_to)}")
    # Suffixes derived from the ensemble member: one for the title, one for variable names
    if isinstance(ens_mem, int):
        title_ens_ID = f"({ens_mem:02d})"
        ens_ID = f"_{ens_mem:02d}"
    elif ens_mem is None:
        title_ens_ID = ""
        ens_ID = ""
    else:
        raise TypeError(f"(plot_run_analysis) `ens_mem` must be an integer or None. Got type: {type(ens_mem)}")
    if not isinstance(add_corr_plots, bool):
        raise TypeError(f"(plot_run_analysis) `add_corr_plots` must be a bool. Got type: {type(add_corr_plots)}")
    if not isinstance(stage1_only, bool):
        raise TypeError(f"(plot_run_analysis) `stage1_only` must be a bool. Got type: {type(stage1_only)}")
    verify_number(clr_bar_scale)
    if clr_bar_scale < 0 or clr_bar_scale > 1:
        raise ValueError(f"(plot_run_analysis) `clr_bar_scale` must be between 0 and 1. Got: {clr_bar_scale}")
    if not isinstance(clr_map, mpl.colors.Colormap):
        raise TypeError(f"(plot_run_analysis) `clr_map` must be a matplotlib Colormap. Got type: {type(clr_map)}")
    # Check whether the dataset is an ensemble of runs
    if pred_uarr._is_ensemble():
        if ens_mem is None:
            # Warn the user that only a single ensemble member will be plotted
            warnings.warn("(plot_run_analysis) `dataset` is an ensemble of runs but `ens_mem` was not specified. Using ensemble member 1 for the plots.")
            ens_mem = 1
            title_ens_ID = f"({ens_mem:02d})"
            ens_ID = f"_{ens_mem:02d}"
    elif ens_mem is not None:
        raise ValueError(f"(plot_run_analysis) `ens_mem` specified as {ens_mem} but dataset is not an ensemble of runs.")
    # Get the metadata from the predictions uarray
    meta_dict = pred_uarr._get_metadata()
    # Get the name of the input set used in the HPC run and load it
    input_set = meta_dict['config_dict']['input_set']
    input_uarr = uarray(input_set, is_input_set=True)
    # Get the `y_var` name from the input dataset
    y_var = input_uarr.xr.attrs['y_var']
    # Make a list for the variables to plot, starting with the "truth"
    vars_to_plot = [y_var]
    # Verify that the prediction array has the stage 1 variable
    pred_var = f"{y_var}_pred{ens_ID}"
    verify_var(pred_uarr.xr, pred_var)
    vars_to_plot.append(pred_var)
    # Decide on the number of rows and columns in the figure
    if not stage1_only:
        # Stage 2 is requested, so its variable must exist too
        pred_var_s2 = f"{y_var}_pred_s2{ens_ID}"
        verify_var(pred_uarr.xr, pred_var_s2)
        vars_to_plot.append(pred_var_s2)
        # Two rows of three maps, plus one row of correlation plots if requested
        n_rows_maps = 2
        n_rows = 3 if add_corr_plots else 2
        n_cols = 3
        n_maps = 6
    else:
        # Three maps: a 2x2 grid when a correlation plot fills the fourth
        # panel, otherwise a single row of three
        n_rows_maps = 1
        if add_corr_plots:
            n_rows = 2
            n_cols = 2
        else:
            n_rows = 1
            n_cols = 3
        n_maps = 3
    # Trim the latitude and longitude extents to match
    pred_uarr.xr, input_uarr.xr = udata.match_domains(pred_uarr.xr, input_uarr.xr)
    # Add the "truth" data to the prediction array
    pred_uarr.xr[y_var] = input_uarr.xr[y_var]
    # Select the time slice to plot
    ## Note: This will not affect the data used in the correlation plots
    pred_uarr.xr, time_title = select_time(pred_uarr.xr, start_date, interval=interval, avg_over=avg_over, **kwargs)
    # Restrict the latitude and longitude range
    ## Note: This will not affect the data used in the correlation plots
    if restrict_lat_lon_to is not None:
        # Load the specified data set to restrict to
        restrict_xr = uarray(restrict_lat_lon_to).xr
        # Restrict the domain of the data to plot
        pred_uarr.xr, _ = udata.match_domains(pred_uarr.xr, restrict_xr, require_equal=False)
        input_uarr.xr, _ = udata.match_domains(input_uarr.xr, restrict_xr, require_equal=False)
    # Get the units of the y_var
    y_var_unit = input_uarr.xr[y_var].units
    # Calculate the difference between the "truth" and the predictions
    pred_uarr.xr['y_m_st1'] = pred_uarr.xr[y_var] - pred_uarr.xr[pred_var]
    pred_uarr.xr['y_m_st1'].attrs = {'long_name': "'Truth' - Stage 1 prediction", 'units': y_var_unit}
    vars_to_plot.append('y_m_st1')
    if not stage1_only:
        pred_uarr.xr['y_m_st2'] = pred_uarr.xr[y_var] - pred_uarr.xr[pred_var_s2]
        pred_uarr.xr['y_m_st2'].attrs = {'long_name': "'Truth' - Stage 2 prediction", 'units': y_var_unit}
        vars_to_plot.append('y_m_st2')
        pred_uarr.xr['st1_m_st2'] = pred_uarr.xr[pred_var] - pred_uarr.xr[pred_var_s2]
        pred_uarr.xr['st1_m_st2'].attrs = {'long_name': "Stage 1 - Stage 2", 'units': y_var_unit}
        vars_to_plot.append('st1_m_st2')
    # Choose the projection(s) for the subplots
    if not add_corr_plots:
        # Only maps: one projection for every subplot
        these_projs = 'cyl'
    else:
        # Map panels get a projection; the remaining (correlation) panels get none
        these_projs = ['cyl'] * n_maps + [None] * ((n_rows * n_cols) - n_maps)
    # Create the figure
    ## Setting `share=False` to allow separate axis labels for each subplot
    fig, axs = pplt.subplots(refwidth=4, nrows=n_rows, ncols=n_cols, proj=these_projs, share=False)
    # Select medium resolution for features such as coastlines
    pplt.rc.reso = 'med'
    # Get the maximum and minimum values for each variable
    vmin_arr = pred_uarr.xr.min(skipna=True)
    vmax_arr = pred_uarr.xr.max(skipna=True)
    # Gather the maximum and minimum values across all variables
    val_list = []
    for var in vmin_arr.data_vars:
        val_list.append(vmin_arr[var].values)
        val_list.append(vmax_arr[var].values)
    # Get the halfrange for use with a diverging color map
    half_range = udata.get_max_abs_val(val_list)
    # Scale the color bar; `clr_bar_scale` was range-checked above. A shrunken
    # range means data can fall outside it, so extend both colorbar ends.
    if clr_bar_scale != 1:
        half_range *= clr_bar_scale
        cbe = 'both'
    else:
        cbe = 'neither'
    # Make blank lists to collect the plotted meshes and colorbar labels
    these_vars = [None]*(n_maps)
    these_cblbls = [None]*(n_maps)
    # Add the maps to the axes, all sharing one symmetric color range
    for i, var_to_plot in enumerate(vars_to_plot):
        data_arr = pred_uarr.xr[var_to_plot]
        these_vars[i], these_cblbls[i] = map_ax(
            data_arr,
            axs[i],
            plt_title=data_arr.attrs['long_name'],
            cmap=clr_map,
            cbar_max=half_range,
            cbar_min=-half_range,
            cb_ext=cbe,
            **kwargs,
        )
    # The shared colorbar is labeled after the first map, whether or not the
    # labels of the other maps agree with it
    cb_label = these_cblbls[0]
    # Add one overall colorbar for the map rows on the right-hand side
    cbar = make_colorbar(fig, these_vars[-1], cb_label, num_ticks=9, cb_loc='r', cb_extend=cbe, rows=(1, n_rows_maps))
    # Add correlation plots, if specified
    if add_corr_plots:
        # Create a list to hold the plots
        fig_q_list = [None]*3
        # 'Truth' vs. Stage 1 goes on the first panel after the maps
        fig_q_list[0] = corr_plot(
            dataset,
            is_predict=True,
            x_vars='pred',
            y_vars='truth',
            start_date=start_date,
            interval=interval,
            axs=axs[n_maps],
            restrict_lat_lon_to=restrict_lat_lon_to,
            ens_mem=ens_mem,
            **kwargs,
        )
        if not stage1_only:
            # 'Truth' vs. Stage 2 on the second-to-last panel
            fig_q_list[1] = corr_plot(
                dataset,
                is_predict=True,
                x_vars='pred_s2',
                y_vars='truth',
                start_date=start_date,
                interval=interval,
                axs=axs[-2],
                restrict_lat_lon_to=restrict_lat_lon_to,
                ens_mem=ens_mem,
                **kwargs,
            )
            # Stage 1 vs. Stage 2 on the last panel
            fig_q_list[2] = corr_plot(
                dataset,
                is_predict=True,
                x_vars='pred',
                y_vars='pred_s2',
                start_date=start_date,
                interval=interval,
                axs=axs[-1],
                restrict_lat_lon_to=restrict_lat_lon_to,
                ens_mem=ens_mem,
                **kwargs,
            )
        # Add the colorbar for the correlation rows
        fig.colorbar(fig_q_list[0], loc='r', label='Count per pixel', extend='both', formatter='sci', rows=(n_rows_maps+1, n_rows))
    # Set the figure title
    fig.suptitle(f"HPC run: {pred_uarr.name}{title_ens_ID}, input set: {input_set}, {time_title}", fontsize=title_font_size)
    return fig
[docs]
def plot_comparison(
    a_xr_arr,
    b_xr_arr,
    ax=None,
    plt_title=None,
    a_label=None,
    b_label=None,
    cmap=pplt.Colormap('viridis'),
    set_under_val=1,
    hist_params={'bins':100, 'vmax':1000, 'vmin':10},
    log_scale=True,
    **kwargs,
):
    """ Plot a comparison of two arrays using a 2D histogram.

    Create a correlation plot between the values of the two provided arrays:
    a 2D histogram of the paired values, a dashed `y=x` reference line and,
    unless one of the arrays is constant, a dashed linear regression line.

    Parameters
    ----------
    a_xr_arr : `xarray.DataArray`, `numpy.ndarray`
        The first array to compare (x-axis).
    b_xr_arr : `xarray.DataArray`, `numpy.ndarray`
        The second array to compare (y-axis).
    ax : `proplot.axes.Axes`, optional
        The axes on which to plot the data.
        If `None`, a new figure is created.
        Default is `None`.
    plt_title : `str`, optional
        The title for the plot. Only used when a new figure is created.
        Default is `None`.
    a_label : `str`, optional
        The label to use for the first array.
        If `None`, it is built from the `long_name` and `units` attributes of a
        DataArray (falling back to the array's name), or `'Array A'` for a numpy array.
    b_label : `str`, optional
        The label to use for the second array.
        If `None`, it is built the same way as `a_label`, with `'Array B'` as fallback.
    cmap : `matplotlib.colors.Colormap`, optional
        The colormap to use for the plot. A copy is used internally, so the
        given colormap object is not modified.
        Default is the proplot `viridis` colormap.
    set_under_val : `float`, optional
        Forwarded as the second positional argument of `cmap.set_under('w', ...)`.
        NOTE(review): in matplotlib's `Colormap.set_under(color, alpha)` that
        position is the opacity of the white under-range color, not a threshold;
        the cut-off itself is `hist_params['vmin']` -- confirm the intent.
        Default is `1`.
    hist_params : `dict`, optional
        Parameters for the 2D histogram. Recognized keys are `bins`, `vmax` and
        `vmin`; any key left out takes its default value.
        Default is `{'bins': 100, 'vmax': 1000, 'vmin': 10}`.
    log_scale : `bool`, optional
        Whether to use a logarithmic color scale for the histogram.
        Default is `True`.
    **kwargs : keyword arguments
        Additional keyword arguments accepted to facilitate wrapper functions.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure object containing the plot. Returned if `ax` is `None`.
    q : `QuadMesh`
        The QuadMesh object created by the 2D histogram. Returned if `ax` is given.

    Raises
    ------
    TypeError
        If an argument has an unsupported type.
    ValueError
        If the flattened arrays differ in length or hold fewer than two elements.
    """
    import copy

    def _values_and_label(arr, arg_name, fallback):
        # Split an input into its raw numpy values and a default axis label
        if isinstance(arr, xr.DataArray):
            # Arithmetic on DataArrays commonly drops attrs, so do not require them
            long_name = arr.attrs.get('long_name', arr.name or fallback)
            units = arr.attrs.get('units')
            if units is None:
                return arr.values, f"{long_name}"
            return arr.values, f"{long_name} ({units})"
        if isinstance(arr, np.ndarray):
            return arr, fallback
        raise TypeError(f"(plot_comparison) `{arg_name}` must be an xarray DataArray or numpy array. Got type: {type(arr)}")

    # Verify argument types
    a_xr_arr, a_xr_label = _values_and_label(a_xr_arr, 'a_xr_arr', 'Array A')
    b_xr_arr, b_xr_label = _values_and_label(b_xr_arr, 'b_xr_arr', 'Array B')
    if not isinstance(ax, (pplt.axes.Axes, type(None))):
        raise TypeError(f"(plot_comparison) `ax` must be a proplot Axes object or None. Got type: {type(ax)}")
    if not isinstance(plt_title, (type(None), str)):
        raise TypeError(f"(plot_comparison) `plt_title` must be a string or None. Got type: {type(plt_title)}")
    if a_label is None:
        a_label = a_xr_label
    elif not isinstance(a_label, str):
        raise TypeError(f"(plot_comparison) `a_label` must be a string or None. Got type: {type(a_label)}")
    if b_label is None:
        b_label = b_xr_label
    elif not isinstance(b_label, str):
        raise TypeError(f"(plot_comparison) `b_label` must be a string or None. Got type: {type(b_label)}")
    if not isinstance(cmap, mpl.colors.Colormap):
        raise TypeError(f"(plot_comparison) `cmap` must be a matplotlib Colormap. Got type: {type(cmap)}")
    verify_number(set_under_val)
    if not isinstance(hist_params, dict):
        raise TypeError(f"(plot_comparison) `hist_params` must be a dictionary. Got type: {type(hist_params)}")
    if not isinstance(log_scale, bool):
        raise TypeError(f"(plot_comparison) `log_scale` must be a bool. Got type: {type(log_scale)}")

    # Fill in any histogram parameter the caller left out
    hist_cfg = {'bins': 100, 'vmax': 1000, 'vmin': 10, **hist_params}

    # Squeeze and flatten to get one dimensional arrays
    npy_a = np.squeeze(a_xr_arr).flatten()
    npy_b = np.squeeze(b_xr_arr).flatten()
    # Verify these arrays are the same length and hold more than one element
    if len(npy_a) != len(npy_b) or len(npy_a) <= 1:
        raise ValueError(f"(plot_comparison) `a_xr_arr` and `b_xr_arr` must have the same number of elements, > 1. Got lengths {len(npy_a)} and {len(npy_b)} respectively.")

    # Create a new figure and axis if none is provided
    new_fig = ax is None
    if new_fig:
        fig = pplt.figure(refwidth=4)
        ax = fig.subplots(nrows=1, ncols=1)

    # Work on a copy so neither the caller's colormap nor the shared default
    # argument object is modified by `set_under`
    cmap = copy.copy(cmap)
    cmap.set_under('w', set_under_val)

    # Plot the data, depending on the scale
    if log_scale:
        this_hist, xedges, yedges, q = ax.hist2d(npy_a, npy_b, bins=hist_cfg['bins'], norm='log', cmap=cmap, vmin=hist_cfg['vmin'], vmax=hist_cfg['vmax'], extend='both')
    else:
        this_hist, xedges, yedges, q = ax.hist2d(npy_a, npy_b, bins=hist_cfg['bins'], norm='linear', cmap=cmap)

    # Find the largest bin edge along each axis where the count exceeds vmin
    above_vmin = this_hist > hist_cfg['vmin']
    counts_0 = np.sum(above_vmin, axis=0)
    counts_1 = np.sum(above_vmin, axis=1)
    max_0 = max(np.where(counts_0 > 0, yedges[:-1], 0))
    max_1 = max(np.where(counts_1 > 0, xedges[:-1], 0))
    padding = 1.1
    axis_lim = max(max_0, max_1) * padding

    # Add line of y=x
    # Two end points define the line for any data magnitude; a unit-step
    # `arange` collapses to a single point when `axis_lim` is below 1
    xx = np.linspace(0, axis_lim, 2)
    ax.plot(xx, xx, 'k--', lw=2)
    # Limit the x and y axes
    ax.set_xlim((0, axis_lim))
    ax.set_ylim((0, axis_lim))

    # Plot the linear regression between the two arrays
    ## Only if neither array has all the same values
    if np.all(npy_a == npy_a[0]) or np.all(npy_b == npy_b[0]):
        warnings.warn("One of the arrays has all the same values. Skipping linear regression.")
    else:
        slope, intercept, r_value, _, _ = linregress(npy_a, npy_b)
        pm_str = '-' if intercept < 0 else '+'
        ax.plot(xx, slope*xx+intercept, 'r--', lw=2, label=rf'$y=%.2fx{pm_str}%.2f$, $R^2$=%.2f'%(slope, abs(intercept), r_value**2))

    # Format the plot
    ax.set_aspect(1)
    ax.legend()
    ax.grid()
    ax.format(
        xlabel=a_label,
        ylabel=b_label,
    )

    # If new plot, return the figure
    if new_fig:
        # Add the colorbar
        ax.colorbar(q, loc='r', label='Count per pixel', formatter='sci')
        # Set the figure title
        fig.suptitle(plt_title, fontsize=title_font_size)
        return fig
    return q
[docs]
def corr_plot(
    dataset,
    x_vars = ['pred'],
    y_vars = ['truth'],
    axs = None,
    restrict_lat_lon_to = None,
    ens_mem=None,
    **kwargs,
):
    """ Create correlation plots between specified variables.

    Create a heatmap correlation plot with the specified variables on each axis using
    data from the given dataset, filtering by datetime, averaging period, and
    optional latitude/longitude restrictions. One plot is made per
    (`x_vars[i]`, `y_vars[i]`) pair.

    Parameters
    ----------
    dataset : `str`, `uarray`, `xarray.Dataset`, `xarray.DataArray`
        The dataset from which to get the data for the correlation plot.
    x_vars : `list`, `str`, optional
        The variable(s) to plot on the x-axis.
        Can be `truth`, `pred`, `pred_s2`, or any variable in the dataset.
        Default is `['pred']`.
    y_vars : `list`, `str`, optional
        The variable(s) to plot on the y-axis. Must have as many entries as `x_vars`.
        Can be `truth`, `pred`, `pred_s2`, or any variable in the dataset.
        Default is `['truth']`.
    axs : `list`, `proplot.axes.Axes`, `None`, optional
        The axes on which to plot the data; at least one per variable pair.
        If `None`, a new figure is created.
        Default is `None`.
    restrict_lat_lon_to : `str`, `xr.DataArray`, `None`, optional
        Dataset (or path to one) whose domain the data is matched to via
        `udata.match_domains()`.
        If `None`, the entire dataset is used.
        Default is `None`.
    ens_mem : `int`, `None`, optional
        The ID of the ensemble member to plot.
        If `None` and the dataset is an ensemble, all members are plotted
        (only possible when a single variable pair is given).
        Default is `None`.
    **kwargs : dict
        Additional keyword arguments passed to `uarray()`, `select_time()`,
        `set_fig_row_col()` and `plot_comparison()`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        If no axes were given, the new figure containing the plot(s).
        If axes were given for more than one variable pair, the figure that
        owns the first axis.
    fig_q : `QuadMesh`
        If axes were given for a single variable pair, the QuadMesh object
        created by the 2D histogram.

    Raises
    ------
    TypeError
        If an argument (or an entry of a list argument) has an unsupported type.
    ValueError
        If `x_vars`/`y_vars` are neither list nor string, the list lengths are
        inconsistent, `ens_mem` does not fit the dataset, or the dataset kind
        does not support the requested variables.

    Examples
    --------
    >>> fig = corr_plot('no2_example_run', is_predict=True, x_vars='pred', y_vars='truth')
    """
    def _as_str_list(value, arg_name):
        # Normalize a variable specification to a list of strings
        if isinstance(value, str):
            return [value]
        if not isinstance(value, list):
            raise ValueError(f"(corr_plot) `{arg_name}` must be a list of variable names or a single variable name string. Got type: {type(value)}")
        for entry in value:
            if not isinstance(entry, str):
                raise TypeError(f"(corr_plot) Each entry in `{arg_name}` must be a string. Got type: {type(entry)}")
        return value

    # Verify argument types
    # Making a `uarray` object verifies `dataset`
    u_arr = uarray(dataset, **kwargs)
    x_vars = _as_str_list(x_vars, 'x_vars')
    y_vars = _as_str_list(y_vars, 'y_vars')
    if not isinstance(axs, (list, type(None))):
        if isinstance(axs, pplt.axes.Axes):
            axs = [axs]
        else:
            raise TypeError(f"(corr_plot) `axs` must be a proplot Axes object, a list of proplot Axes objects, or None. Got type: {type(axs)}")
    elif axs is not None:
        for ax in axs:
            if not isinstance(ax, pplt.axes.Axes):
                raise TypeError(f"(corr_plot) Each entry in `axs` must be a proplot Axes object. Got type: {type(ax)}")
    if len(x_vars) != len(y_vars):
        raise ValueError(f"(corr_plot) `x_vars` and `y_vars` must be the same length. Got lengths {len(x_vars)} and {len(y_vars)}, respectively.")
    if not isinstance(restrict_lat_lon_to, (type(None), str, xr.DataArray)):
        raise TypeError(f"(corr_plot) `restrict_lat_lon_to` must be a string, `xr.DataArray`, or `None`. Got type: {type(restrict_lat_lon_to)}")
    if isinstance(ens_mem, int):
        title_ens_ID = f"({ens_mem:02d})"
        ens_ID = f"_{ens_mem:02d}"
    elif ens_mem is None:
        title_ens_ID = ""
        ens_ID = ""
    else:
        raise TypeError(f"(corr_plot) `ens_mem` must be an integer or None. Got type: {type(ens_mem)}")

    # Select the time slice to plot
    u_arr.xr, title_segment = select_time(u_arr.xr, **kwargs)
    # Restrict the latitude and longitude range
    if restrict_lat_lon_to is not None:
        # Load the specified data set to restrict to
        restrict_xr = uarray(restrict_lat_lon_to).xr
        # Restrict the domain of the data to plot
        u_arr.xr, _ = udata.match_domains(u_arr.xr, restrict_xr, require_equal=False)

    # Check whether the dataset is an ensemble of runs
    if u_arr._is_ensemble():
        # If `ens_mem` was not specified and only one variable given, plot all ensemble members
        if ens_mem is None:
            if len(x_vars) != 1:
                raise ValueError("(corr_plot) `dataset` is an ensemble of runs but `ens_mem` was not specified and multiple variables were given to plot. Please specify `ens_mem` or provide a single variable to plot all ensemble members.")
            # Get the number of ensemble members
            ens_size = u_arr.xr.attrs['ensemble_size']
            # Expand the single pair into one pair per ensemble member
            this_x_var = x_vars[0]
            this_y_var = y_vars[0]
            x_vars = [f"{this_x_var}_{i:02d}" for i in range(1, ens_size+1)]
            y_vars = [f"{this_y_var}_{i:02d}" for i in range(1, ens_size+1)]
            # Add onto the title
            title_ens_ID = "(all ensemble members)"
    elif ens_mem is not None:
        raise ValueError(f"(corr_plot) `ens_mem` specified as {ens_mem} but dataset is not an ensemble of runs.")

    # If no axes are given, create a new figure
    if axs is None:
        new_plot = True
        fig = pplt.figure(refwidth=4)
        n_rows, n_cols = uplt_fmt.set_fig_row_col(len(x_vars), **kwargs)
        axs = fig.subplots(nrows=n_rows, ncols=n_cols)
    else:
        new_plot = False
        # Checked after the ensemble expansion so the final number of plots is
        # known; extra axes are tolerated, too few would fail in the loop below
        if len(axs) < len(x_vars):
            raise ValueError(f"(corr_plot) `axs` must provide at least one axis per variable pair. Got {len(axs)} axes for {len(x_vars)} variable pairs.")

    # The input ('truth') set does not depend on the loop variable, so it is
    # loaded at most once, on first use
    truth_uarr = None
    fig_q = None
    # Loop across each axis
    for i, (x_var, y_var) in enumerate(zip(x_vars, y_vars)):
        ax = axs[i]
        # Classify the requested names before they are rewritten below, so the
        # name of the HPC target variable cannot trigger a false match
        x_is_pred, y_is_pred = 'pred' in x_var, 'pred' in y_var
        x_is_truth, y_is_truth = 'truth' in x_var, 'truth' in y_var
        # Set the x and y data arrays to `None`
        x_xarr = None
        y_xarr = None

        if x_is_pred or y_is_pred:
            # Make sure the dataset is a prediction uarray
            if not u_arr.is_predict:
                raise ValueError(f"(corr_plot) To plot 'pred' or 'pred_s2', `dataset` {u_arr.name} must be a prediction HPC run. Got {u_arr.name}.is_predict: {u_arr.is_predict}")
            # Get the name of the `y_var` used in the HPC run
            HPC_y_var = u_arr.xr.attrs['y_var']
            # Prefix the prediction names with that `y_var`
            if x_is_pred:
                x_var = f"{HPC_y_var}_{x_var}{ens_ID}"
                x_xarr = u_arr.xr[x_var]
            if y_is_pred:
                y_var = f"{HPC_y_var}_{y_var}{ens_ID}"
                y_xarr = u_arr.xr[y_var]

        # Check whether to plot the 'truth'
        if x_is_truth or y_is_truth:
            # Get the name of the `y_var` used in the input set
            HPC_y_var = u_arr.xr.attrs['y_var']
            if truth_uarr is None:
                if u_arr.is_predict:
                    # Get the input set named in the prediction run's metadata
                    meta_dict = u_arr._get_metadata()
                    input_set = meta_dict['config_dict']['input_set']
                    truth_uarr = uarray(input_set, is_input_set=True)
                    # Select the time slice such that it matches the prediction array
                    pred_start_date = str(u_arr.xr.time.values[0]).split('T')[0].split(' ')[0]
                    pred_end_date = str(u_arr.xr.time.values[-1]).split('T')[0].split(' ')[0]
                    # Do not pass keyword arguments into this call of `select_time()`
                    # to ensure there aren't multiple occurrences of `start_date` or `end_date`
                    truth_uarr.xr, title_segment = select_time(truth_uarr.xr, start_date=pred_start_date, end_date=pred_end_date)
                    # Restrict the latitude and longitude range, if applicable
                    if restrict_lat_lon_to is not None:
                        truth_uarr.xr, _ = udata.match_domains(truth_uarr.xr, restrict_xr, require_equal=False)
                elif u_arr.is_input_set:
                    truth_uarr = u_arr
                else:
                    raise ValueError(f"(corr_plot) To plot 'truth', `dataset` {u_arr.name} must be either a prediction HPC run or an input set. Got is_predict: {u_arr.is_predict}, is_input_set: {u_arr.is_input_set}")
            if x_is_truth:
                x_xarr = truth_uarr.xr[HPC_y_var]
            if y_is_truth:
                y_xarr = truth_uarr.xr[HPC_y_var]

        # Any name not handled above must be a variable of the dataset itself
        if x_xarr is None:
            verify_var(u_arr.xr, x_var)
            x_xarr = u_arr.xr[x_var]
        if y_xarr is None:
            verify_var(u_arr.xr, y_var)
            y_xarr = u_arr.xr[y_var]

        # Plot the comparison
        fig_q = plot_comparison(
            x_xarr,
            y_xarr,
            ax=ax,
            **kwargs,
        )

    if new_plot:
        # Add an overall title
        fig.suptitle(f"{u_arr.name}{title_ens_ID} {title_segment}", fontsize=title_font_size)
        return fig
    if len(x_vars) == 1:
        return fig_q
    # No figure was created here; hand back the one owning the supplied axes
    return axs[0].figure
[docs]
def plot_epochs_logs(
    dataset,
    vars=None,
    axs=None,
    plt_title=None,
    **kwargs,
):
    """ Plot training epoch logs for a prediction dataset.

    Create line plots showing how specified metrics changed across training
    epochs. Each variable gets its own axis, with one line per stage.

    Parameters
    ----------
    dataset : `str`, `xarray.Dataset`, `xarray.DataArray`, `uarray`
        The dataset for which to plot epoch logs. Must be a predictions dataset
        and must not be an ensemble run.
    vars : `str`, `list`, `None`, optional
        The variables to plot. Each variable is plotted on its own axis.
        Allowed names are `loss`, `msenonzero`, `r2_keras`, `val_loss`,
        `val_msenonzero` and `val_r2_keras`.
        If `None`, all of them are plotted.
        Default is `None`.
    axs : `list`, `proplot.axes.Axes`, `None`, optional
        The axes on which to plot the data, one per variable.
        If `None`, a new figure is created.
        Default is `None`.
    plt_title : `str`, optional
        The title for a newly created figure. If `None`, the dataset name is used.
        Default is `None`.
    **kwargs : keyword arguments
        Additional keyword arguments passed to `uarray()` and `set_fig_row_col()`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure object containing the plot. Returned if `axs` is `None`.
    axs : `list`
        The list of axes containing the plots. Returned if `axs` was provided.

    Raises
    ------
    TypeError
        If `vars`, `axs` or `plt_title` has an unsupported type.
    ValueError
        If the dataset is not a prediction set, is an ensemble run, or a
        requested variable is not one of the available epoch-log variables.
    """
    # Verify argument types
    # Making `uarray` object verifies `dataset`
    dataset = uarray(dataset, **kwargs)
    # Epoch logs only exist for prediction sets
    if not dataset.is_predict:
        raise ValueError("(plot_epochs_logs) `dataset` must be a prediction set to plot epochs logs.")
    # Ensemble runs are not supported yet
    if dataset._is_ensemble():
        raise ValueError("(plot_epochs_logs) `dataset` cannot be an ensemble run to plot epochs logs (not yet implemented).")

    # Define the available vars in the epochs logs
    epochs_logs_vars = [
        'loss',
        'msenonzero',
        'r2_keras',
        'val_loss',
        'val_msenonzero',
        'val_r2_keras'
    ]
    if vars is None:
        vars = epochs_logs_vars
    elif not isinstance(vars, list):
        if isinstance(vars, str):
            vars = [vars]
        else:
            raise TypeError(f"(plot_epochs_logs) `vars` must be a list of strings, a single string, or `None`. Got type: {type(vars)}")
    for i, var in enumerate(vars):
        if not isinstance(var, str):
            raise TypeError(f"(plot_epochs_logs) `var` must be a string. Got type: {type(var)}")
        # Reject unknown names before any figure is created
        if var not in epochs_logs_vars:
            raise ValueError(f"(plot_epochs_logs) `vars` index {i} is {var}, however all variables must one from the following list:\n\t{epochs_logs_vars}")
    if not isinstance(axs, (list, type(None))):
        if isinstance(axs, pplt.axes.Axes):
            axs = [axs]
        else:
            raise TypeError(f"(plot_epochs_logs) `axs` must be a proplot Axes object, a list of proplot Axes objects, or None. Got type: {type(axs)}")
    elif axs is not None:
        for ax in axs:
            if not isinstance(ax, pplt.axes.Axes):
                raise TypeError(f"(plot_epochs_logs) Each entry in `axs` must be a proplot Axes object. Got type: {type(ax)}")
    if not isinstance(plt_title, (type(None), str)):
        raise TypeError(f"(plot_epochs_logs) `plt_title` must be a string or None. Got type: {type(plt_title)}")

    # Get the epochs data for the given dataset
    epochs_logs = dataset._get_epochs_logs()

    # If no axes are given, create a new figure
    new_plot = axs is None
    if new_plot:
        fig = pplt.figure(refwidth=4, sharey=False)
        n_rows, n_cols = uplt_fmt.set_fig_row_col(len(vars), **kwargs)
        axs = fig.subplots(nrows=n_rows, ncols=n_cols)
        if plt_title is None:
            plt_title = dataset.name

    # Get the stages of this prediction set
    stages = dataset.xr.attrs['stages']
    # One axis per variable, one line per stage
    for i, var in enumerate(vars):
        for stage in stages:
            # Get the epoch logs for this stage
            this_stage = epochs_logs.sel(stage=stage)
            axs[i].plot(this_stage[var], label=f'stage {stage}')
        axs[i].legend()

    # If new plot, return the figure
    if new_plot:
        # Set the figure title
        fig.suptitle(plt_title, fontsize=title_font_size)
        return fig
    return axs
[docs]
def make_colorbar(
    fig,
    cb_ax,
    cb_label,
    num_ticks=9,
    cb_loc='r',
    cb_extend='neither',
    **kwargs,
):
    """ Attach a labelled colorbar with evenly spaced ticks to a figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that receives the colorbar.
    cb_ax : `cartopy.mpl.geocollection.GeoQuadMesh`
        The plotted mesh whose color mapping the colorbar represents.
    cb_label : `str`
        The text shown next to the colorbar.
    num_ticks : `int`, optional
        How many ticks to place on the colorbar. An odd number puts a tick
        at the midpoint of the range.
        Default is `9`.
    cb_loc : `str`, optional
        Where to place the colorbar, passed to `fig.colorbar()` as `loc`.
        Default is `'r'`.
    cb_extend : `str`, optional
        How to extend the ends of the colorbar. Can be `'neither'`, `'both'`, `'min'`, or `'max'`.
        Default is `'neither'`.
    **kwargs : keyword arguments
        Forwarded unchanged to `fig.colorbar()`, e.g. the `rows` tuple.

    Returns
    -------
    cbar : `matplotlib.colorbar.Colorbar`
        The colorbar that was created.

    Raises
    ------
    TypeError
        If any of the named arguments has an unexpected type.

    Examples
    --------
    >>> # `mesh` is a GeoQuadMesh returned by a map-plotting call on `fig`
    >>> cbar = make_colorbar(fig, mesh, 'NOx emissions (kg/m2/s)', rows=(1, 2))
    """
    from cartopy.mpl.geocollection import GeoQuadMesh

    # Verify argument types, in signature order
    expectations = (
        ('fig', fig, mpl.figure.Figure, 'a matplotlib Figure'),
        ('cb_ax', cb_ax, GeoQuadMesh, 'a GeoQuadMesh'),
        ('cb_label', cb_label, str, 'a string'),
        ('num_ticks', num_ticks, int, 'an integer'),
        ('cb_loc', cb_loc, str, 'a string'),
        ('cb_extend', cb_extend, str, 'a string'),
    )
    for arg_name, value, expected, description in expectations:
        if not isinstance(value, expected):
            raise TypeError(f"(make_colorbar) `{arg_name}` must be {description}. Got type: {type(value)}")

    # One colorbar for the whole figure, placed according to `cb_loc`
    colorbar_options = dict(loc=cb_loc, label=cb_label, extend=cb_extend)
    cbar = fig.colorbar(cb_ax, **colorbar_options, **kwargs)
    # Spread exactly `num_ticks` ticks evenly over the colorbar range
    cbar.locator = mpl.ticker.LinearLocator(numticks=num_ticks)
    cbar.update_ticks()
    return cbar
[docs]
def plot_hist(
    data_arrs,
    ax=None,
    n_bins=100,
    ax_label='NOx emissions (kg/m2/s)',
    ylabel='Frequency',
    plt_title=None,
    log_scale=False,
):
    """ Draw overlaid histograms of one or more data arrays.

    Every array in `data_arrs` is flattened and drawn as its own
    semi-transparent histogram on a single axis. The legend entry of each
    histogram reports the number of values it contains.

    Parameters
    ----------
    data_arrs : `list`, `numpy.ndarray`, `xr.DataArray`
        The array(s) to plot. A single array is treated as a one-element list.
    ax : `proplot.axes.Axes`, `None`, optional
        The axes to draw on. If `None`, a new figure and axes are created.
        Default is `None`.
    n_bins : `int`, optional
        The number of histogram bins.
        Default is `100`.
    ax_label : `str`, optional
        The x-axis label.
        Default is `'NOx emissions (kg/m2/s)'`.
    ylabel : `str`, optional
        The y-axis label.
        Default is `'Frequency'`.
    plt_title : `str`, optional
        The axes title. No title is set when `None`.
        Default is `None`.
    log_scale : `bool`, optional
        If `True`, the y-axis uses a logarithmic scale.
        Default is `False`.

    Returns
    -------
    fig or ax : `matplotlib.figure.Figure` or `matplotlib.axes.Axes`
        The new figure when `ax` was `None`, otherwise the axes that was passed in.

    Raises
    ------
    TypeError
        If any argument, or any element of `data_arrs`, has an unexpected type.

    Examples
    --------
    >>> data_arr1 = uarray('nox_2019_JFM', is_input_set=True).xr['no2'].values
    >>> data_arr2 = uarray('nox_2019_JFM', is_input_set=True).xr['no2_s2'].values
    >>> fig = plot_hist(data_arr1)
    >>> fig, axs = pplt.subplots(nrows=2, ncols=1)
    >>> axs[0] = plot_hist([data_arr1, data_arr2], ax=axs[0], n_bins=50, plt_title='Histogram of NO2 emissions, both stages')
    """
    # Verify argument types
    array_types = (np.ndarray, xr.DataArray)
    if isinstance(data_arrs, list):
        for entry in data_arrs:
            if not isinstance(entry, array_types):
                raise TypeError(f"(plot_hist) Each element of `data_arrs` must be a numpy array or xarray DataArray. Got type: {type(entry)}")
    elif isinstance(data_arrs, array_types):
        data_arrs = [data_arrs]
    else:
        raise TypeError(f"(plot_hist) `data_arrs` must be a list, numpy array, or xarray DataArray. Got type: {type(data_arrs)}")
    expectations = (
        ('ax', ax, (pplt.axes.Axes, type(None)), 'a proplot Axes object or None'),
        ('n_bins', n_bins, int, 'an integer'),
        ('ax_label', ax_label, str, 'a string'),
        ('ylabel', ylabel, str, 'a string'),
        ('plt_title', plt_title, (type(None), str), 'a string or None'),
        ('log_scale', log_scale, bool, 'a bool'),
    )
    for arg_name, value, expected, description in expectations:
        if not isinstance(value, expected):
            raise TypeError(f"(plot_hist) `{arg_name}` must be {description}. Got type: {type(value)}")

    # Without a target axis, start a fresh figure
    created_figure = ax is None
    if created_figure:
        fig, ax = pplt.subplots()

    # One histogram per array, drawn on the shared axis
    for entry in data_arrs:
        raw_values = entry.values if isinstance(entry, xr.DataArray) else entry
        flat_values = raw_values.flatten()
        ax.hist(flat_values, bins=n_bins, alpha=0.5, label='n = '+str(len(flat_values)))

    # Format the plot
    ax.set_xlabel(ax_label)
    ax.set_ylabel(ylabel)
    if plt_title is not None:
        ax.set_title(plt_title)
    if log_scale:
        ax.set_yscale('log')
    # Show the y ticks in scientific notation
    ax.format(yformatter='sci')
    # The legend carries the per-histogram value counts
    ax.legend()

    return fig if created_figure else ax
[docs]
def plot_BaW(
vars,
datasets,
ds_kwargs=None,
axs=None,
violin=False,
**kwargs,
):
""" Plot box-and-whisker (or violin) plots for specified variables.
Create a box-and-whisker plot (or optional violin plot) of the specified
variables from one or more datasets.
Parameters
----------
vars : `list`, `str`
The variable(s) to plot on the x-axis/axes.
Can be `R2`, `RMSE`, or any variable in the dataset.
datasets : `list`, `str`, `uarray`, `xarray.Dataset`, `xarray.DataArray`
The dataset(s) from which to get the data for the plot.
ds_kwargs : `list`, `dict`, optional
Dictionaries of keyword arguments to specify the format of each plot.
Default is `None`.
axs : `list`, `matplotlib.axes.Axes`, `None`, optional
The axes on which to plot the data.
If `None`, a new figure is created.
Default is `None`.
violin : `bool`, optional
Whether to make a violin plot instead of a box-and-whisker plot.
Default is `False`.
**kwargs : dict
Additional keyword arguments to pass to `set_fig_row_col()`.
Returns
-------
fig : `matplotlib.figure.Figure`
If no axes were provided, return the figure object containing the plot.
axs : `matplotlib.axes.Axes`
If axes were provided, return those axes with the plots on them.
Examples
--------
>>> fig = plot_BaW(
... ['R2', 'RMSE'],
... '_ZFI_ensemble_run',
... ds_kwargs=[
... {
... 'is_predict': True,
... 'start_date': '2019-01-02',
... 'interval': '1Y',
... 'restrict_lat_lon_to': 'datafiles/sample_data/nox_2019_t106_US.nc',
... },
... ]
... )
"""
# Verify argument types
if not isinstance(vars, list):
if isinstance(vars, str):
vars = [vars]
else:
raise TypeError(f"(plot_BaW) `vars` must be a list of strings or a single string. Got type: {type(vars)}")
for var in vars:
if not isinstance(var, str):
raise TypeError(f"(plot_BaW) `var` must be a string. Got type: {type(var)}")
if not isinstance(datasets, (list)):
# Making `uarray` object verifies `dataset`
datasets = [datasets]
if len(datasets) == 1:
one_dataset = True
else:
one_dataset = False
if not isinstance(ds_kwargs, (list, type(None))):
if isinstance(ds_kwargs, type({})):
ds_kwargs = [ds_kwargs]
else:
raise TypeError(f"(plot_BaW) `ds_kwargs` must be a list of dictionaries or `None`. Got type: {type(ds_kwargs)}")
else:
for these_kwargs in ds_kwargs:
if not isinstance(these_kwargs, type({})):
raise TypeError(f"(plot_BaW) Each entry in `ds_kwargs` must be a dictionary. Got type: {type(these_kwargs)}")
if len(ds_kwargs) == 0:
raise ValueError("(plot_BaW) `ds_kwargs` list cannot be empty.")
elif len(ds_kwargs) == 1:
one_ds_kwargs = True
else:
one_ds_kwargs = False
if len(datasets) != len(ds_kwargs):
ds_kwargs_len = len(ds_kwargs)
# Repeat the list of `ds_kwargs` for each entry in `datasets`
ds_kwargs = ds_kwargs * len(datasets)
# Create list, repeating each dataset for each entry in `ds_kwargs`
new_dataset_list = []
for i in range(len(datasets)):
for j in range(ds_kwargs_len):
new_dataset_list.append(datasets[i])
datasets = new_dataset_list
if len(datasets) != len(ds_kwargs):
raise ValueError(f"(plot_BaW) Length of `datasets` ({len(datasets)}) must match length of `ds_kwargs` ({len(ds_kwargs)}).")
if not isinstance(axs, (list, type(None))):
if isinstance(axs, pplt.axes.Axes):
axs = [axs]
else:
raise TypeError(f"(plot_BaW) `axs` must be a proplot Axes object, a list of proplot Axes objects, or None. Got type: {type(axs)}")
elif not isinstance(axs, type(None)):
for ax in axs:
if not isinstance(ax, pplt.axes.Axes):
raise TypeError(f"(plot_BaW) Each entry in `axs` must be a proplot Axes object. Got type: {type(ax)}")
if not isinstance(violin, bool):
raise TypeError(f"(plot_BaW) `violin` must be a bool. Got type: {type(violin)}")
# Check whether to prepare for a set of runs
new_dataset_list = []
new_ds_kwargs_list = []
# Loop across the given datasets
for i in range(len(datasets)):
dataset = datasets[i]
these_ds_kwargs = ds_kwargs[i]
# Check whether they are, in fact, sets of runs
try:
verify_path(f"HPC_runs/{dataset}/SET_OF_RUNS.txt")
except:
# Not a set of runs, add the dataset and the ds_kwargs to the new lists and move to the next one
new_dataset_list.append(dataset)
new_ds_kwargs_list.append(these_ds_kwargs)
continue
# Find the subdirectories
sub_dirs = os.listdir(f"HPC_runs/{dataset}")
for sub_dir in sub_dirs:
# Verify that it is a directory, not a file
if os.path.isdir(f"HPC_runs/{dataset}/{sub_dir}"):
# Add the subdirectory to the dataset list
new_dataset_list.append(f"{dataset}/{sub_dir}")
new_ds_kwargs_list.append(these_ds_kwargs)
datasets = new_dataset_list
ds_kwargs = new_ds_kwargs_list
# Ensure the number of datasets is equal to the number of ds_kwargs dictionaries
if len(datasets) != len(ds_kwargs):
raise ValueError(f"(plot_BaW) Length of `datasets` ({len(datasets)}) must match length of `ds_kwargs` ({len(ds_kwargs)}).")
for i in range(len(datasets)):
# Making `uarray` object verifies `dataset`
datasets[i] = uarray(datasets[i], **ds_kwargs[i])
# Get name for overall title
if one_dataset:
# Take the name of the first dataset before the slash
overall_name = datasets[0].name.split('/')[0].strip('_')
else:
overall_name = ""
# Create dictionaries to hold the data to plot
box_dfs = [None]*len(vars)
# I don't know why, but if I try to define a list of dictionaries in one line, every dictionary in that list gets updated, even when calling a specific index
box_dicts = []
for j in range(len(vars)):
box_dicts.append({})
ax_labels = [None]*len(vars)
# Loop across each dataset
for i in range(len(datasets)):
# Get the `uarray` object
u_arr = datasets[i]
# Select the time slice to plot
u_arr.xr, title_segment = select_time(u_arr.xr, **ds_kwargs[i])
# Restrict the latitude and longitude range
if 'restrict_lat_lon_to' in ds_kwargs[i] and not isinstance(ds_kwargs[i]['restrict_lat_lon_to'], type(None)):
# Load the specified data set to restrict to
restrict_xr = uarray(ds_kwargs[i]['restrict_lat_lon_to']).xr
# Restrict the domain of the data to plot
u_arr.xr, _ = udata.match_domains(u_arr.xr, restrict_xr, require_equal=False)
box_label_restrict = 'restricted'
else:
box_label_restrict = 'full'
# Loop across each variable to plot
for j in range(len(vars)):
var = vars[j]
# Check whether the variable is already in the dataset
if var in u_arr.xr.data_vars:
# Get the array of that variable and flatten it
var_array = u_arr.xr[var].values.flatten()
# Format the label for this variable
ax_labels[j] = f"{u_arr.xr[var].attrs['long_name']} ({u_arr.xr[var].attrs['units']})"
# Get the name of this dataset, if multiple datasets given
# if one_dataset:
# this_name = f"{var} "
# else:
# this_name = f"{u_arr.name.split('/')[-1]} "
# Format the box label, taking only the name of the child (ensemble member) directory
# box_label = f"{this_name}(n={len(var_array)})"
elif var in ['R2', 'RMSE']:
from unox.evaluate import compare_arrs
# Make sure the dataset is an ensemble run
if not u_arr._is_ensemble():
raise ValueError(f"(plot_BaW) To plot `{var}`, `dataset` {u_arr.name} must be an ensemble of runs. Got `is_ensemble`: {u_arr._is_ensemble()}")
# Get the metadata from the predictions uarray
meta_dict = u_arr._get_metadata()
# Get the input set from the metadata
input_set = meta_dict['config_dict']['input_set']
# Get the input set used in the HPC run
input_uarr = uarray(input_set, is_input_set=True)
# Select the time slice to plot such that it matches the prediction array
pred_start_date = str(u_arr.xr.time.values[0]).split('T')[0].split(' ')[0]
pred_end_date = str(u_arr.xr.time.values[-1]).split('T')[0].split(' ')[0]
# Do not pass keyword arguments into this call of `select_time()`
# to ensure there aren't multiple occurrences of `start_date` or `end_date`
input_uarr.xr, title_segment = select_time(input_uarr.xr, start_date=pred_start_date, end_date=pred_end_date)#, **kwargs)
# Restrict the latitude and longitude range, if applicable
if 'restrict_lat_lon_to' in ds_kwargs[i] and not isinstance(ds_kwargs[i]['restrict_lat_lon_to'], type(None)):
# Restrict the domain of the data to plot
input_uarr.xr, _ = udata.match_domains(input_uarr.xr, restrict_xr, require_equal=False)
# Get the truth array and flatten it
truth_array = input_uarr.xr[input_uarr.xr.attrs['y_var']].values.flatten()
# Find the number of ensemble members
ens_size = u_arr.xr.attrs['ensemble_size']
# Get the name of this dataset, if multiple datasets given
if one_dataset:
this_name = u_arr.name.split('/')[-1].replace(overall_name, "").strip('_')
else:
this_name = u_arr.name.split('/')[-1]
# Make an array to collect the comparison values
var_array = [None]*ens_size
# Calculate the comparison value for each ensemble member
for k in range(ens_size):
# Get the prediction variable name
pred_var = f"{input_uarr.xr.attrs['y_var']}_pred_{k+1:02d}"
# Get the prediction array and flatten it
pred_array = u_arr.xr[pred_var].values.flatten()
# Append that comparison value to the array
var_array[k] = compare_arrs(pred_array, truth_array, var)
# Format the axis label
if var == 'R2':
ax_labels[j] = rf"Correlation R$^2$ (pred vs truth)"
elif var == 'RMSE':
ax_labels[j] = rf"RMSE (pred vs truth)"
# Format the name for this box and whisker plot
box_label = BaW_label(
u_arr,
var=var,
one_dataset=one_dataset,
**kwargs,
)
box_label_extras = ""
for ds_kwarg in ['start_date', 'interval', 'label_note']:
if ds_kwarg in ds_kwargs[i] and not isinstance(ds_kwargs[i][ds_kwarg], type(None)):
box_label_extras = f"{box_label_extras}, {ds_kwargs[i][ds_kwarg]}"
if one_ds_kwargs:
title_extras = f"{box_label_restrict}{box_label_extras}"
if one_dataset:
title_extras = f", {title_extras}"
else:
title_extras = ""
# box_label = f"{box_label}\n{box_label_restrict}{box_label_extras}"
# Add that array to the dictionary
box_dicts[j][box_label] = var_array
# If no axes are given, create a new figure
if isinstance(axs, type(None)):
new_plot = True
fig = pplt.figure(refwidth=4, sharex=False)
n_rows, n_cols = uplt_fmt.set_fig_row_col(len(vars), **kwargs)
axs = fig.subplots(nrows=n_rows, ncols=n_cols)
else:
new_plot = False
for j in range(len(axs)):
# Create a pandas DataFrame from the dictionary
## using the pd.Series to fill shorter arrays with NaNs
box_df = pd.DataFrame(
{k:pd.Series(v) for k, v in box_dicts[j].items()}
)
# Plot the violin or box and whisker plot
if violin:
axs[j].violinploth(box_df)
else:
axs[j].boxploth(box_df, whis=1.5)
# Format the axes
axs[j].set_xlabel(ax_labels[j])
if new_plot == True:
# Add an overall title
fig.suptitle(f"{overall_name}{title_extras}", fontsize=title_font_size)
# Return the figure
return fig
else:
# Return the axes
return axs
[docs]
def BaW_label(
    u_arr,
    label_with=None,
    var=None,
    one_dataset=False,
    **kwargs,
):
    """ Assemble a label for a box-and-whisker plot.

    Construct a label for a box-and-whisker plot based on metadata from a `uarray`,
    the requested components in `label_with`, and any formatting kwargs.
    The components are joined, in the order given, with a comma and a space.

    Parameters
    ----------
    u_arr : `uarray`
        The uarray object from which to derive label metadata.
    label_with : `list`, `None`, optional
        Components to include in the label. Each component must be `'name'`,
        `'size'`, a key of the dictionary returned by `u_arr._get_metadata()`,
        or a key of that dictionary's `'config_dict'` entry.
        Default is `None`, which uses `['name', 'size']`.
    var : `str`, optional
        The variable name to use when computing label components like size.
        Must be a data variable of `u_arr.xr`, `'R2'`, or `'RMSE'` when
        labelling with `'size'`.
    one_dataset : `bool`, optional
        Whether the plot is for a single dataset.
        Default is `False`.
    **kwargs : dict
        Additional keyword arguments passed to `format_sci_notation()`.

    Returns
    -------
    str
        The formatted label.

    Raises
    ------
    TypeError
        If `u_arr` is not a `uarray` object.
    ValueError
        If `'name'` is requested with `one_dataset=True` but `var` is `None`,
        if `'size'` is requested and `var` is neither a data variable nor
        `'R2'`/`'RMSE'`, or if a requested component cannot be found.

    Examples
    --------
    >>> BaW_label(
    ...     uarray('no2_example_run', is_predict=True),
    ...     label_with=['name', 'size'],
    ...     var='nox_pred',
    ...     one_dataset=False,
    ... )
    'no2_example_run, n=4892160'
    """
    # Verify argument types
    if not isinstance(u_arr, uarray):
        # Report the type of the offending argument itself; the previous message
        # referenced an undefined name and raised a NameError instead
        raise TypeError(f"(BaW_label) `u_arr` must be a `uarray` object. Got type: {type(u_arr)}")
    if label_with is None:
        label_with = ['name', 'size']
    # Get the metadata from the `uarray` object
    meta_dict = u_arr._get_metadata()
    # Comparison metrics that are computed once per ensemble member
    metric_vars = ['R2', 'RMSE']
    # Split the name once; the last segment is the child (ensemble member) directory
    name_parts = u_arr.name.split('/')
    leaf_name = name_parts[-1]
    # Check for presence of particular components in `label_with` and add those to the label
    box_label = ""
    spacer = ""
    # Loop through the components of `label_with`
    for label_with_this in label_with:
        # Check for specific components to add to the label
        if label_with_this == 'name':
            if not one_dataset:
                add_this = leaf_name
            elif var in metric_vars:
                # Drop the parent name (first path segment) from the child name
                # so that only the distinguishing part remains
                add_this = leaf_name.replace(name_parts[0].strip('_'), "").strip('_')
            elif var is not None:
                add_this = leaf_name
            else:
                raise ValueError(f"(BaW_label) To label with 'name' when `one_dataset` is True, `var` must be specified. Got: {var}")
        elif label_with_this == 'size':
            if var in u_arr.xr.data_vars:
                # Number of values in the flattened variable array
                this_size = len(u_arr.xr[var].values.flatten())
            elif var in metric_vars:
                # One comparison value per ensemble member
                this_size = u_arr.xr.attrs['ensemble_size']
            else:
                raise ValueError(f"(BaW_label) To label with 'size', `var` must be a variable in the dataset or 'R2' or 'RMSE'. Got: {var}")
            add_this = f"n={this_size}"
        # Check the metadata dictionary for components to add to the label
        elif label_with_this in meta_dict:
            add_this = meta_dict[label_with_this]
        elif label_with_this in meta_dict['config_dict']:
            add_this = meta_dict['config_dict'][label_with_this]
        else:
            raise ValueError(f"(BaW_label) Component to label with '{label_with_this}' not found. \nComponents must either be 'name', 'size', a variable in the given dataset, or a key in the metadata dictionary.")
        # Check whether the component to add to the label is a number
        if udata.verify_number(add_this):
            # Format the number
            add_this = rf"${uplt_fmt.format_sci_notation(add_this, **kwargs)}$"
        # Add the component to the label
        box_label = rf"{box_label}{spacer}{add_this}"
        # Set the spacer for the next loop to be a comma and a space
        spacer = ", "
    return box_label