# Source code for unox.input
import numpy as np
import os
import xarray as xr
import pandas as pd
import json
import warnings
import unox.unox as unox
from unox.HPC.data0.paths import verify_path, remove_non_empty_directory, make_file_path
from unox.HPC.data0.dataset import uarray, get_years
from unox.HPC.data0.verify_dataset import verify_dataset, verify_var
from unox.HPC.data0.latlon import shift_lon_arr
import unox.data as udata
from unox.plot_format import pad_extent
# emiss = Emissions (TCR-2 t106)
# chemra = Chemical Reanalysis (TROPESS TCR-2)
# insitu = Insitu data (EPA)
# era5 = ERA5 reanalysis data
# Define a dictionary of the variables to be used for each model variable
# Names of the ERA5 reanalysis fields that are appended to the predictor ("x")
# variables of every model listed in `input_vars_dict`.
era5_vars_list = ['u10', 'v10', 'blh', 'sp', 'skt', 't2m', 'ssrd', 'lsm']
# Maps each model key to the names of its predictor variables ('x_vars') and
# its target variables ('y_vars'). The position of a name inside its list is
# the index reported by `get_input_index`.
input_vars_dict = {
    'no2': {
        'x_vars': ['no2', 'no2_tm1'] + era5_vars_list,
        'y_vars': ['nox'],
    },
    'co': {
        'x_vars': ['SpeciesConcVV_CO', 'SpeciesConcVV_CO_tm1'] + era5_vars_list,
        'y_vars': ['EmisCO_Total'],
    }
}
[docs]
def x_or_y_var(
    var,
):
    """ Classify a variable name as a predictor ('x') or a target ('y').

    The model entries of `input_vars_dict` are scanned in order; within each
    entry the 'x_vars' list is consulted before the 'y_vars' list.

    Parameters
    ----------
    var : `str`
        Name of the variable to classify.

    Returns
    -------
    x_or_y : `str`
        'x' when the name is listed as an x variable, 'y' when it is listed
        as a y variable.

    Raises
    ------
    TypeError
        If `var` is not a string.
    ValueError
        If `var` is not listed anywhere in `input_vars_dict`.

    Examples
    --------
    >>> x_or_y_var('no2')
    'x'
    >>> x_or_y_var('nox')
    'y'
    """
    # Guard clause: only strings can be looked up
    if not isinstance(var, str):
        raise TypeError(f"(x_or_y_var) `var` must be a string. Got type: {type(var)}")
    # Walk the per-model variable groups and report the first match
    for var_groups in input_vars_dict.values():
        if var in var_groups['x_vars']:
            return 'x'
        if var in var_groups['y_vars']:
            return 'y'
    raise ValueError(f"(x_or_y_var) Variable '{var}' not recognized. Available variables in input_vars_dict: {input_vars_dict}")
[docs]
def get_input_index(
    var,
):
    """ Look up the position of a variable within its x or y variable list.

    The model entries of `input_vars_dict` are scanned in order; within each
    entry the 'x_vars' list is consulted before the 'y_vars' list, and the
    position inside the first list containing the name is returned.

    Parameters
    ----------
    var : `str`
        Name of the variable to locate.

    Returns
    -------
    index : `int`
        Zero-based position of the variable in the matching list.

    Raises
    ------
    TypeError
        If `var` is not a string.
    ValueError
        If `var` is not listed anywhere in `input_vars_dict`.

    Examples
    --------
    >>> get_input_index('no2')
    0
    >>> get_input_index('u10')
    2
    """
    # Guard clause: only strings can be looked up
    if not isinstance(var, str):
        raise TypeError(f"(get_input_index) Variable must be a string. Got type: {type(var)}")
    # Walk the per-model variable groups, x list first, then y list
    for var_groups in input_vars_dict.values():
        for role in ('x_vars', 'y_vars'):
            names = var_groups[role]
            if var in names:
                return names.index(var)
    raise ValueError(f"(get_input_index) Variable '{var}' not recognized. Available variables in input_vars_dict: {input_vars_dict}")
[docs]
def make_y_input_file(
    year,
    var='nox',
    emiss_dir='/data/high_res/emacdonald/unet/datafiles/t106',
    emiss_pre='nox_',
    emiss_post='_t106_US.nc',
    scale_factor=1e12,
    nan_fill=0,
    stage_2_cutoff=2013,
    output_dir='test_input',
    write_this_year=True,
    overwrite=True,
    output_format='nc',
    **kwargs,
):
    """ Create the y (target) input data for the Unet model for one year.

    The emissions file `{emiss_dir}/{emiss_pre}{year}{emiss_post}` is loaded,
    scaled, interpolated to the grid returned by `unox.load_lats_lons()`,
    averaged to daily values and converted to the 'noleap' calendar.

    The resulting data have these dimensions:
        - time: one entry per day, minus the first day of the data (skipped
          so that the y data line up with the t-1 variable of the x inputs).
          February 29th is removed by the 'noleap' calendar conversion.
        - lat: length depends on the latitude grid
        - lon: length depends on the longitude grid
        - var: 1 (a dummy dimension to match the x input files)

    Parameters
    ----------
    year : `int`
        The year for which to create the y input file (between 2005 and 2021).
    var : `str`, optional
        The variable to extract from the dataset. Default is 'nox'.
    emiss_dir : `str`, optional
        Directory where the emissions data are stored.
        Default is '/data/high_res/emacdonald/unet/datafiles/t106'.
    emiss_pre : `str`, optional
        Prefix for the emissions input file name. Default is 'nox_'.
    emiss_post : `str`, optional
        Extension for the input file name. Default is '_t106_US.nc'.
    scale_factor : `float`, optional
        Factor by which to scale the data. Default is 1e12.
    nan_fill : `float`, optional
        Value to fill NaNs in the dataset. Default is 0.
    stage_2_cutoff : `int`, optional
        Year after which the numpy output is also saved under `stage2/`.
        Default is 2013.
    output_dir : `str` or `None`, optional
        Directory inside `inputfiles/` where output is saved. If None,
        nothing is written. Default is `'test_input'`.
    write_this_year : `bool`, optional
        Whether to write this year's data to the netcdf file. Numpy output
        (see `output_format`) does not depend on this flag. Default is True.
    overwrite : `bool`, optional
        Whether to overwrite existing data in the netcdf file. Default is True.
    output_format : `str`, optional
        Whether to save netcdf files ('nc'), numpy arrays ('npy'), or 'both'.
        Default is 'nc'.
    **kwargs : `dict`, optional
        Additional keyword arguments, forwarded to `write_input_netcdf`.

    Returns
    -------
    input_netcdf_xr : `xarray.Dataset`
        The y input data for the specified year. When a netcdf file was
        written, this is the content of that file re-loaded from disk.
    g_attr_dict : `dict`
        Dictionary of global attributes for the dataset.
    """
    # Assemble and verify the path of the emissions file, then load it
    filepath = verify_path(f'{emiss_dir}/{emiss_pre}{year}{emiss_post}')
    y_data = xr.load_dataset(filepath)
    # If level dimension present, sum across levels
    if "lev" in y_data.coords:
        print("level dimension detected")
        y_data = y_data.sum("lev")
    # Scale data
    y_data = scale_xr_var(y_data, var, scale_factor)
    # Interpolate to the latitude and longitude grid, resample to daily mean,
    # and fill NaNs with the specified value
    lats, lons = unox.load_lats_lons()
    y_data = y_data.interp(lat=lats, lon=lons).resample(time='d').mean().fillna(nan_fill)
    # Add a trailing dimension of size 1 to match the dimensionality of the x inputs
    y_data = y_data.expand_dims('var', -1)
    # Convert calendar to 'noleap' to remove February 29th
    y_data = y_data.convert_calendar('noleap')
    # Skip the first day: the x inputs carry a t-1 variable that is undefined there
    input_netcdf_xr = y_data.isel(time=slice(1, None))
    y_array = y_data[var][1::]
    # Create a dictionary of global attributes
    g_attr_dict = {
        'y_var': var,
        'emiss_dir': emiss_dir,
        'emiss_pre': emiss_pre,
        'emiss_post': emiss_post,
        'nan_fill': nan_fill,
        'stage_2_cutoff': stage_2_cutoff,
    }
    # Write out results
    if output_dir is not None:
        # Create metadata file
        make_input_metadata_file(
            input_netcdf_xr,
            output_dir=output_dir,
            g_attrs=g_attr_dict,
        )
        # For writing out a numpy file
        if output_format in ['npy', 'both']:
            output_filepath = f'inputfiles/{output_dir}/stage1/y/Y_{year}.npy'
            # Make sure the output directory exists
            make_file_path(output_filepath)
            np.save(output_filepath, y_array)
            if year > stage_2_cutoff:
                # Save in stage 2 for years later than specified
                output_filepath_stage2 = f'inputfiles/{output_dir}/stage2/y/Y_{year}.npy'
                make_file_path(output_filepath_stage2)
                np.save(output_filepath_stage2, y_array)
            print(f"Created Y input file for {var} in {year}, saved to {output_filepath}")
        # For writing out a netcdf file. The early return is tied to the
        # netcdf branch so that a '.npy' path is never handed to xarray.
        if write_this_year and output_format in ['nc', 'both']:
            output_filepath = f'inputfiles/{output_dir}/{output_dir}.nc'
            input_netcdf_xr = write_input_netcdf(
                input_netcdf_xr,
                output_filepath,
                g_attr_dict=g_attr_dict,
                overwrite=overwrite,
                **kwargs,
            )
            print(f"Saved y input data to {output_filepath}")
            return xr.load_dataset(output_filepath), g_attr_dict
    return input_netcdf_xr, g_attr_dict
[docs]
def write_input_netcdf(
    input_netcdf_xr,
    output_filepath,
    g_attr_dict=None,
    overwrite=True,
    sort=True,
    **kwargs,
):
    """ Write an xarray Dataset to a netcdf file, appending or overwriting as needed.

    If `output_filepath` already exists, the new data are combined with the
    file's content: datasets that share at least one data variable are
    concatenated along `time` (after resolving overlapping times according
    to `overwrite`), otherwise they are merged. If the file does not exist,
    a description and the global attributes are added before writing.

    Parameters
    ----------
    input_netcdf_xr : `xarray.Dataset`
        The dataset to write to the netcdf file.
    output_filepath : `str`
        Path to the output netcdf file.
    g_attr_dict : `dict` or `None`, optional
        Dictionary of global attributes to add to the dataset if creating a
        new file. None is treated as an empty dictionary.
    overwrite : `bool`, optional
        Whether to overwrite existing data in the netcdf file if there are
        overlapping times. Default is True.
    sort : `bool`, optional
        Whether to sort the combined data by time before writing. Only
        applies when the file already exists. Sorting takes a long time.
        Default is True.
    **kwargs : `dict`, optional
        Accepted for call-site convenience and ignored.

    Returns
    -------
    input_netcdf_xr : `xarray.Dataset`
        The dataset that was written to the netcdf file.

    Raises
    ------
    ValueError
        If the lat/lon values of the existing file and the new data differ,
        or if times overlap while `overwrite` is false.
    """
    if os.path.exists(output_filepath):
        # Load and verify the existing netcdf file
        existing_ds = verify_dataset(xr.load_dataset(output_filepath))
        # The existing and new data must live on exactly the same grid
        existing_lats = existing_ds.coords['lat'].values
        existing_lons = existing_ds.coords['lon'].values
        new_lats = input_netcdf_xr.coords['lat'].values
        new_lons = input_netcdf_xr.coords['lon'].values
        if not np.array_equal(existing_lats, new_lats):
            raise ValueError(f"(write_input_netcdf) Latitude values of the existing netcdf file and the new data do not match. \nExisting lats: \n{existing_lats} \nNew lats: \n{new_lats}")
        if not np.array_equal(existing_lons, new_lons):
            raise ValueError(f"(write_input_netcdf) Longitude values of the existing netcdf file and the new data do not match. \nExisting lons: \n{existing_lons} \nNew lons: \n{new_lons}")
        # Find the variables in common, if any
        shared_vars = set(input_netcdf_xr.data_vars) & set(existing_ds.data_vars)
        if shared_vars:
            # Check whether any time values are already present in the existing dataset
            existing_times = set(existing_ds.coords['time'].values)
            new_times = set(input_netcdf_xr.coords['time'].values)
            overlapping_times = existing_times & new_times
            # Any overlap (including a single timestamp) must be resolved
            if overlapping_times:
                # Format the first and last overlapping times as YYYY-MM-DD
                first_overlap = pd.to_datetime(str(min(overlapping_times))).strftime('%Y-%m-%d')
                last_overlap = pd.to_datetime(str(max(overlapping_times))).strftime('%Y-%m-%d')
                if not overwrite:
                    raise ValueError(f"(write_input_netcdf) The new data overlaps with the existing file in {output_filepath} between {first_overlap} and {last_overlap}. To overwrite, set overwrite=True.")
                print(f"Overwriting overlapping data in {output_filepath} for times between {first_overlap} and {last_overlap}.")
                # Remove the overlapping times from the existing dataset
                existing_ds = existing_ds.drop_sel(time=list(overlapping_times))
            # Concatenate the new data with the existing dataset along the time dimension
            input_netcdf_xr = xr.concat([existing_ds, input_netcdf_xr], dim='time')
        else:
            # No variables in common: merge the datasets
            input_netcdf_xr = xr.merge([existing_ds, input_netcdf_xr])
        # Sort the dataset by time
        if sort:
            print("Sorting the dataset by time.")
            input_netcdf_xr = input_netcdf_xr.sortby('time')
    else:
        # New file: add a description and the global attributes
        input_netcdf_xr.attrs['description'] = "Input data for the Unet model. Data for each year is added to this file as it is generated."
        # `set_global_attrs` requires a dict, so substitute an empty one for None
        input_netcdf_xr = set_global_attrs(
            input_netcdf_xr,
            g_attr_dict if g_attr_dict is not None else {},
        )
    # Make sure the output directory exists, then save the netcdf file
    make_file_path(output_filepath)
    input_netcdf_xr.to_netcdf(output_filepath)
    return input_netcdf_xr
[docs]
def set_global_attrs(
    xr_dataset,
    attr_dict,
):
    """ Copy a dictionary of global attributes onto an xarray Dataset.

    Besides the given attributes, a 'modification_date' attribute holding
    the current local time ('%Y-%m-%d %H:%M:%S') is written.

    Parameters
    ----------
    xr_dataset : `xarray.Dataset`
        The dataset that receives the attributes.
    attr_dict : `dict`
        Attribute names mapped to the values to store.

    Returns
    -------
    `xarray.Dataset`
        The dataset with the attributes applied.

    Raises
    ------
    TypeError
        If `attr_dict` is not a dictionary.
    """
    # Validate both inputs before touching anything
    xr_dataset = verify_dataset(xr_dataset)
    if not isinstance(attr_dict, dict):
        raise TypeError(f"(set_global_attrs) `attr_dict` must be a dictionary. Got type: {type(attr_dict)}")
    # Apply all attributes in one go, keeping the dictionary's order
    xr_dataset.attrs.update(attr_dict)
    # Stamp the time of this modification
    timestamp = pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
    xr_dataset.attrs['modification_date'] = timestamp
    return xr_dataset
[docs]
def set_var_attrs(
    xr_dataset,
    var,
    attr_dict,
):
    """ Copy a dictionary of attributes onto one variable of an xarray Dataset.

    Parameters
    ----------
    xr_dataset : `xarray.Dataset`
        The dataset holding the variable.
    var : `str`
        Name of the variable that receives the attributes.
    attr_dict : `dict`
        Attribute names mapped to the values to store.

    Returns
    -------
    `xarray.Dataset`
        The dataset, with the attributes applied to `var`.

    Raises
    ------
    TypeError
        If `attr_dict` is not a dictionary.
    """
    # Validate the dataset, the variable name and the attribute mapping
    xr_dataset = verify_dataset(xr_dataset)
    verify_var(xr_dataset, var)
    if not isinstance(attr_dict, dict):
        raise TypeError(f"(set_var_attrs) `attr_dict` must be a dictionary. Got type: {type(attr_dict)}")
    # Apply all attributes in one go, keeping the dictionary's order
    xr_dataset[var].attrs.update(attr_dict)
    return xr_dataset
[docs]
def scale_xr_var(
    xr_dataset,
    var,
    scale_factor,
):
    """ Multiply one variable of an xarray Dataset by a constant factor.

    The variable's attributes are carried over to the scaled variable and
    the factor is recorded in a 'scaled_by' attribute.

    Parameters
    ----------
    xr_dataset : `xarray.Dataset`
        The dataset holding the variable to scale.
    var : `str`
        Name of the variable to scale.
    scale_factor : `float`
        The multiplier applied to the variable.

    Returns
    -------
    `xarray.Dataset`
        The dataset with the scaled variable.
    """
    # Validate the dataset and the variable name
    xr_dataset = verify_dataset(xr_dataset)
    verify_var(xr_dataset, var)
    # Keep hold of the attributes so they can be put back after the multiplication
    unscaled = xr_dataset[var]
    kept_attrs = unscaled.attrs
    xr_dataset[var] = unscaled * scale_factor
    ## Note: `scale_factor` is a protected attribute name in xarray. If used, the variable
    ## will be scaled by that factor when loading with xr.load_dataset() and `scale_factor`
    ## will not show up in the loaded xarray. `scaled_by` is used to avoid this confusion.
    kept_attrs['scaled_by'] = scale_factor
    # Put the attributes (now including the factor) back on the variable
    return set_var_attrs(xr_dataset, var, kept_attrs)
[docs]
def add_tm1_var(
    xr_dataset,
    var,
    year,
):
    """ Add a (t-1) version of the given variable to the dataset.

    A copy of `var`, shifted by one step along `time`, is stored as
    `{var}_tm1`. If the dataset also holds a stage 2 variable `{var}_s2`,
    it is shifted the same way into `{var}_s2_tm1`. January 1st of `year`
    is then dropped from the time coordinate, because the shifted variables
    have no value on that day.

    The shifted variables receive a copy of the source variable's
    attributes, extended with 'shifted_from' and a 'long_name' ending in
    ' (t-1)'. The attributes of the source variables are left untouched.

    Parameters
    ----------
    xr_dataset : `xarray.Dataset`
        The dataset containing the variable to be shifted.
    var : `str`
        The variable to be shifted.
    year : `int`
        The year which xr_dataset covers (between 2005 and 2021).

    Returns
    -------
    `xarray.Dataset`
        The dataset with the shifted variable(s) added.
    """
    # Verify the dataset and that `var` is in it
    xr_dataset = verify_dataset(xr_dataset)
    verify_var(xr_dataset, var)

    # Copy the attributes: `.attrs` is the variable's own dictionary, so
    # editing it in place would also relabel the un-shifted variable.
    var_attrs = dict(xr_dataset[var].attrs)
    var_tm1 = f'{var}_tm1'
    # Create a t-1 shifted version of the variable
    xr_dataset[var_tm1] = xr_dataset[var].shift(time=1)
    var_attrs['shifted_from'] = var
    var_attrs['long_name'] = var_attrs.get('long_name', var) + ' (t-1)'

    # Repeat for the stage 2 variable, when present
    var_s2 = f'{var}_s2'
    has_stage_2 = var_s2 in xr_dataset.data_vars
    if has_stage_2:
        var_s2_attrs = dict(xr_dataset[var_s2].attrs)
        var_s2_tm1 = f'{var_s2}_tm1'
        xr_dataset[var_s2_tm1] = xr_dataset[var_s2].shift(time=1)
        var_s2_attrs['shifted_from'] = var_s2
        var_s2_attrs['long_name'] = var_s2_attrs.get('long_name', var_s2) + ' (t-1)'

    # Drop January 1st, as the t-1 variable will have null values on that day.
    # Best effort: a dataset without that day is reported, not treated as an error.
    try:
        xr_dataset = xr_dataset.drop_sel(time=f'{year}-01-01')
    except (KeyError, ValueError, TypeError):
        print(f'\tJanuary 1st, {year} not present in xr_dataset')

    # Apply the attributes to the shifted variables
    xr_dataset = set_var_attrs(xr_dataset, var_tm1, var_attrs)
    if has_stage_2:
        xr_dataset = set_var_attrs(xr_dataset, var_s2_tm1, var_s2_attrs)
    return xr_dataset
[docs]
def _stack_data_vars(xr_dataset, names):
    """Copy the named variables of `xr_dataset` into one float array of shape (time, lat, lon, len(names))."""
    sizes = xr_dataset.sizes
    stacked = np.empty([sizes['time'], sizes['lat'], sizes['lon'], len(names)])
    for i, name in enumerate(names):
        stacked[:, :, :, i] = xr_dataset[name].values
    return stacked


def make_x_input_file(
    year,
    stage_2=True,
    data_dir='/data/high_res',
    chemra_path='emacdonald/unet/datafiles/TROPESS/TROPESS_reanalysis_2hr_no2_sfc_',
    chemra_var='no2',
    insitu_path='US_EPA/NO2/daily_NO2/daily_42602_',
    era5_path='ERA5concatenated',
    scale_factors={'chemra': 1e-3,
                   'sp': 1e-5,
                   'ssrd': 1e-6,
                   'blh': 1e-3},
    stage_2_cutoff=2013,
    output_dir='test_input',
    write_this_year=True,
    output_format='nc',
    overwrite=True,
    **kwargs,
):
    """ Create the x (predictor) input data for the Unet model for one year.

    The chemical reanalysis variable is trimmed to the model grid, averaged
    to daily values, scaled, optionally overlaid with insitu data (stage 2),
    interpolated to the grid from `unox.load_lats_lons()` and paired with its
    t-1 version. The ERA5 variables in `era5_vars_list` are then added and
    the calendar is converted to 'noleap'.

    The numpy arrays built from the data have these dimensions:
        - time: one entry per day, minus January 1st (dropped because the
          t-1 variable is undefined there). February 29th is removed by the
          'noleap' calendar conversion.
        - lat: length depends on the latitude grid
        - lon: length depends on the longitude grid
        - var: every data variable of the stage except `lsm`

    Parameters
    ----------
    year : `int`
        The year for which to create the x input file.
    stage_2 : `bool`, optional
        Whether or not to make stage 2 in addition to stage 1 for the input.
        Default is True.
    data_dir : `str`, optional
        Root directory of the input data. Default is '/data/high_res'.
    chemra_path : `str`, optional
        Path (below `data_dir`) and file-name prefix of the chemical
        reanalysis files; `{year}.nc` is appended.
        Default is 'emacdonald/unet/datafiles/TROPESS/TROPESS_reanalysis_2hr_no2_sfc_'.
    chemra_var : `str`, optional
        The variable to extract from the dataset. Default is 'no2'
    insitu_path : `str`, optional
        Path (below `data_dir`) and file-name prefix of the insitu files;
        `{year}.csv` is appended. Default is 'US_EPA/NO2/daily_NO2/daily_42602_'.
    era5_path : `str`, optional
        Path to the ERA5 reanalysis data files. Default is 'ERA5concatenated'.
    scale_factors : `dict`, optional
        Scaling factors. The 'chemra' entry is applied to `chemra_var`; any
        other key naming an ERA5 variable is applied to that variable.
        The dictionary is only read, never modified.
    stage_2_cutoff : `int`, optional
        Year after which stage 2 data are also generated. Default is 2013.
    output_dir : `str` or `None`, optional
        Directory inside `inputfiles/` where output is saved. If None,
        nothing is written. Default is `'test_input'`.
    write_this_year : `bool`, optional
        Whether to write this year's data to the netcdf file. Numpy output
        (see `output_format`) does not depend on this flag. Default is True.
    output_format : `str`, optional
        Whether to save netcdf files ('nc'), numpy arrays ('npy'), or 'both'.
        Default is 'nc'.
    overwrite : `bool`, optional
        Whether to overwrite existing data in the netcdf file. Default is True.
    **kwargs : `dict`, optional
        Additional keyword arguments, forwarded to `write_input_netcdf`.

    Returns
    -------
    input_netcdf_xr : `xarray.Dataset`
        The x input data for the specified year. When a netcdf file was
        written, this is the content of that file re-loaded from disk.
    g_attr_dict : `dict`
        Dictionary of global attributes for the dataset.
    """
    # Source-specific fixes below only apply to the TROPESS TCR-2 files
    is_tropess = chemra_path == 'emacdonald/unet/datafiles/TROPESS/TROPESS_reanalysis_2hr_no2_sfc_'
    use_stage_2 = stage_2 and year > stage_2_cutoff

    # Assemble and verify the path of the chemical reanalysis file, then open it
    chemra_filepath = verify_path(f'{data_dir}/{chemra_path}{year}.nc')
    chemra = xr.open_dataset(chemra_filepath)
    # If level dimension present, sum across levels
    if "lev" in chemra.coords:
        print("level dimension detected")
        chemra = chemra.sum("lev")
    # Regularize the data depending on the source
    if is_tropess:
        # Change longitude coordinate convention to match other data
        chemra = shift_lon_arr(chemra)
        # Drop the `nv` dimension and the `bnds` variables
        if 'nv' in chemra.dims:
            chemra = chemra.isel(nv=0).drop_vars(['time_bnds', 'lon_bnds', 'lat_bnds'])
        # The coordinates must no longer point at the dropped `_bnds` variables
        for coord in ['time', 'lat', 'lon']:
            chemra[coord].attrs.pop('bounds', None)

    # Trim the chemical reanalysis data to the (padded) extent of the lat/lon grid
    lats, lons = unox.load_lats_lons()
    extent = udata.get_extent(lats=lats, lons=lons)
    extent = pad_extent(extent, padding=0.1)
    chemra = chemra.where(
        (chemra.lat >= extent[0]) &
        (chemra.lat <= extent[1]) &
        (chemra.lon >= extent[2]) &
        (chemra.lon <= extent[3]),
        drop=True,
    )
    # Resample the time to days and rescale
    chemra = chemra.resample(time='d').mean()
    chemra = scale_xr_var(chemra, chemra_var, scale_factors['chemra'])
    # Fix the time coordinate to match the year
    if is_tropess:
        # For an unexplained reason, the year in all TCR-2 files is always 2005.
        time_attrs = chemra['time'].attrs
        ndays = len(chemra.coords['time'])
        chemra.coords['time'] = pd.date_range(f"{year}-01-01", periods=ndays)
        chemra['time'].attrs = time_attrs

    # Combine chemical reanalysis and insitu data for stage 2
    if use_stage_2:
        stages = [1, 2]
        print(f'\tAdding stage 2 data for {chemra_var} in {year}')
        epa_filepath = verify_path(f'{data_dir}/{insitu_path}{year}.csv')
        # Fill the variable that was actually requested, not the helper's default
        chemra = fill_w_insitu(chemra, epa_filepath, var=chemra_var)
    else:
        stages = [1]

    # Interpolate to latitude and longitude grid
    chemra = chemra.interp(lat=lats, lon=lons, method='slinear')
    # Add the chemical reanalysis data for the previous day (t-1)
    chemra_var_tm1 = chemra_var + '_tm1'
    chemra = add_tm1_var(chemra, chemra_var, year)

    # Add the other variables from the ERA5 dataset
    for variable in era5_vars_list:
        era5_var_filepath = verify_path(f'{data_dir}/{era5_path}/{year}{variable}.nc')
        # Note: The variable name in the dataset is assumed to be the same as `variable`
        era5_var = xr.load_dataset(era5_var_filepath)
        # Drop the `number` coordinate and rename coordinates to match the other datasets
        era5_var = era5_var.drop_vars('number')
        era5_var = era5_var.rename({'valid_time': 'time', 'latitude': 'lat', 'longitude': 'lon'})
        # Drop January 1st, as the t-1 variable will have null values on that day
        era5_var = era5_var.drop_sel(time=f'{year}-01-01')
        ## Note: This assumes that the coordinates are the same
        ## which is true in this case as the lat lon arrays used to interpolate the
        ## chemra data were used to interpolate the ERA5 data upon their concatenation
        chemra[variable] = era5_var[variable]

    # Convert calendar to 'noleap' to remove February 29th
    input_netcdf_xr = chemra.convert_calendar('noleap')
    # Scale some variables to make orders of magnitude more similar
    for variable in era5_vars_list:
        if variable in scale_factors:
            input_netcdf_xr = scale_xr_var(input_netcdf_xr, variable, scale_factors[variable])

    # Every data variable except `lsm` goes into the numpy arrays
    all_datavars = list(input_netcdf_xr.data_vars)
    all_datavars.remove('lsm')
    datavars = list(all_datavars)
    if use_stage_2:
        chemra_var_s2 = f'{chemra_var}_s2'
        chemra_var_s2_tm1 = f'{chemra_var}_s2_tm1'
        # Stage 2 uses the insitu-filled variables instead of the plain ones
        datavars_s2 = list(all_datavars)
        datavars_s2.remove(chemra_var)
        datavars_s2.remove(chemra_var_tm1)
        xnp_s2 = _stack_data_vars(input_netcdf_xr, datavars_s2)
        # Stage 1 uses the plain variables only
        datavars.remove(chemra_var_s2)
        datavars.remove(chemra_var_s2_tm1)
    else:
        datavars_s2 = []
    xnp = _stack_data_vars(input_netcdf_xr, datavars)

    # Create a dictionary of global attributes
    g_attr_dict = {
        'x_vars': all_datavars,
        'x1_vars': datavars,
        'x2_vars': datavars_s2,
        'data_dir': data_dir,
        'chemra_path': chemra_path,
        'insitu_path': insitu_path,
        'era5_path': era5_path,
        'stages': stages,
    }
    # Write out results
    if output_dir is not None:
        # Create metadata file
        make_input_metadata_file(
            input_netcdf_xr,
            output_dir=output_dir,
            g_attrs=g_attr_dict,
        )
        # For writing out a numpy file
        if output_format in ['npy', 'both']:
            for stage in stages:
                output_filepath = f'inputfiles/{output_dir}/stage{stage}/x/X_{year}.npy'
                # Make sure the output directory exists
                make_file_path(output_filepath)
                # Choose the correct array to save
                if stage == 1:
                    np.save(output_filepath, xnp)
                elif stage == 2:
                    np.save(output_filepath, xnp_s2)
                print(f"Created X input file for stage {stage} in {year}, saved to {output_filepath}")
        # For writing out a netcdf file. The early return is tied to the
        # netcdf branch so that a '.npy' path is never handed to xarray.
        if write_this_year and output_format in ['nc', 'both']:
            output_filepath = f'inputfiles/{output_dir}/{output_dir}.nc'
            input_netcdf_xr = write_input_netcdf(
                input_netcdf_xr,
                output_filepath,
                g_attr_dict=g_attr_dict,
                overwrite=overwrite,
                **kwargs,
            )
            print(f"Saved x input data to {output_filepath}")
            return xr.load_dataset(output_filepath), g_attr_dict
    return input_netcdf_xr, g_attr_dict
@unox.time_this
def fill_w_insitu(
    xr_dataset,
    insitu_filepath,
    var='no2',
):
    """ Add stage 2 for the variable in an xarray Dataset using available insitu data.

    The variable `var` is deep-copied into a new variable `{var}_s2`. For
    every insitu observation in the CSV file, the grid point of `{var}_s2`
    nearest to the observation (on the observation's day) is overwritten
    with the observed value. `var` itself is left unchanged. The result is
    meant to be used for stage 2 of training the unet.

    Parameters
    ----------
    xr_dataset : `xarray.Dataset`
        The dataset containing reanalysis data.
    insitu_filepath : `str`
        Path to the CSV file containing insitu data. The file must provide
        the columns 'Date Local', 'Latitude', 'Longitude' and
        'Arithmetic Mean' (the EPA daily-summary layout).
    var : `str`, optional
        The variable to duplicate and fill. Default is 'no2'.

    Returns
    -------
    `xarray.Dataset`
        The dataset with the added, insitu-filled `{var}_s2` variable.

    Notes
    -----
    NOTE(review): an observation whose date is missing from the dataset, or
    which lies further than the tolerance from every grid point, presumably
    makes the `.sel` lookups raise -- confirm against the input data.
    """
    # Verify the dataset
    xr_dataset = verify_dataset(xr_dataset)
    # Make a new variable to store the stage 2 data, filled with insitu.
    # A deep copy ensures that changes to `var_s2` don't affect `var`.
    var_s2 = f'{var}_s2'
    xr_dataset[var_s2] = xr_dataset[var].copy(deep=True)
    # Modify the long_name attribute to indicate stage 2
    var_s2_attrs = xr_dataset[var_s2].attrs
    var_s2_attrs['long_name'] = var_s2_attrs.get('long_name', var_s2) + ' Stage 2'
    # Verify the insitu file path
    insitu_filepath = verify_path(insitu_filepath)
    # Load the insitu data, indexed by the parsed observation date
    ## Specific to the EPA csv format
    insitu_data = pd.read_csv(
        insitu_filepath,
        usecols=['Date Local', 'Latitude', 'Longitude', 'Arithmetic Mean'],
        index_col='Date Local',
        parse_dates=True,
    ).rename_axis('Date')
    # Narrow the domain to the selected latitude and longitude grid
    lats, lons = unox.load_lats_lons()
    in1 = xr_dataset[var_s2].where((xr_dataset.lat >= np.min(lats)), drop=True)
    in2 = in1.where((in1.lon <= np.max(lons)), drop=True)
    # One group for each day of data in the insitu data file
    for date, day_obs in insitu_data.groupby('Date'):
        # Select the day in the chemical reanalysis dataset
        day = in2.sel(indexers={'time': date})
        # Read the observations by column name rather than by position
        station_lats = day_obs['Latitude'].to_numpy()
        station_lons = day_obs['Longitude'].to_numpy()
        values = day_obs['Arithmetic Mean'].to_numpy()
        for station_lat, station_lon, value in zip(station_lats, station_lons, values):
            # Find the nearest point in the chemical reanalysis dataset
            ## Tolerance is set to the grid cell size (1.125 degrees)
            pt = day.sel({'lat': station_lat, 'lon': station_lon}, method='nearest', tolerance=1.125)
            # Replace the chemical reanalysis value with the insitu value
            xr_dataset[var_s2].loc[{'time': date, 'lon': pt.lon, 'lat': pt.lat}] = value
    # Add attribute to note which variable this is from
    var_s2_attrs['insitu_filled_from'] = var
    # Reapply the variable attributes
    xr_dataset = set_var_attrs(xr_dataset, var_s2, var_s2_attrs)
    return xr_dataset
[docs]
def make_all_y_input_files(
    years=range(2005, 2021),
    var='nox',
    output_dir='test_input',
    sort=True,
    **kwargs,
):
    """Create y input files for multiple years.

    Runs the `make_y_input_file` function for each year (without writing
    the per-year netcdf), concatenates the results along `time` and writes
    them to `inputfiles/{output_dir}/{output_dir}.nc` in one go.

    Parameters
    ----------
    years : iterable, optional
        Years for which to create y input files. Default is range(2005, 2021).
    var : str, optional
        Variable to extract from the dataset. Default is 'nox'.
    output_dir : str, optional
        Directory inside `inputfiles/` where the output y input files will be saved.
        Default is `'test_input'`.
    sort : bool, optional
        Whether to sort the data by time, both after concatenating the years
        and when combining with an existing netcdf file. Sorting takes a
        long time. Default is True.
    **kwargs : dict, optional
        Additional keyword arguments, passed to both `make_y_input_file` and
        `write_input_netcdf`.

    Returns
    -------
    `xarray.Dataset`
        The content of the netcdf file after writing, re-loaded from disk.

    Raises
    ------
    ValueError
        If `years` is empty.
    """
    years = list(years)
    if not years:
        raise ValueError("(make_all_y_input_files) `years` must contain at least one year.")
    # Assemble the filepath
    output_filepath = f'inputfiles/{output_dir}/{output_dir}.nc'
    # Collect the y data of every year
    y_data_array = []
    for year in years:
        print(f"\tCreating y input data for {var} in {year}...")
        y_data, g_attr_dict = make_y_input_file(
            year=year,
            var=var,
            output_dir=output_dir,
            write_this_year=False,
            sort=False,
            **kwargs,
        )
        y_data_array.append(y_data)
    # Concatenate the datasets along the time dimension
    print("Concatenating the y datasets")
    input_netcdf_xr = xr.concat(y_data_array, dim='time')
    # Sort the dataset by time
    if sort:
        print("Sorting the y data by time.")
        input_netcdf_xr = input_netcdf_xr.sortby('time')
    # Save the y data to a netcdf; `sort` is forwarded so that `sort=False`
    # also skips the sort performed when merging into an existing file
    print(f"Saving y inputs to {output_filepath}")
    write_input_netcdf(
        input_netcdf_xr,
        output_filepath,
        g_attr_dict=g_attr_dict,
        sort=sort,
        **kwargs,
    )
    return xr.load_dataset(output_filepath)
[docs]
def make_all_x_input_files(
    years=range(2005, 2021),
    stage_2=True,
    stage_2_cutoff=2013,
    output_dir='test_input',
    sort=True,
    **kwargs,
):
    """ Create x input files for multiple years and stages.

    Runs the `make_x_input_file` function for each year (without writing
    the per-year netcdf), concatenates the results along `time` and writes
    them to `inputfiles/{output_dir}/{output_dir}.nc` in one go.

    Parameters
    ----------
    years : `iterable`, optional
        Years for which to create x input files. Default is range(2005, 2021).
    stage_2 : `bool`, optional
        Whether or not to make stage 2 in addition to stage 1 for the input.
        Default is True.
    stage_2_cutoff : `int`, optional
        Year after which the data will also be saved in stage 2. Default is 2013.
    output_dir : `str`, optional
        Directory inside `inputfiles/` where the output x input files will be saved.
        Default is `'test_input'`.
    sort : `bool`, optional
        Whether to sort the data by time, both after concatenating the years
        and when combining with an existing netcdf file. Sorting takes a
        long time. Default is True.
    **kwargs : `dict`, optional
        Additional keyword arguments, passed to both `make_x_input_file` and
        `write_input_netcdf`.

    Returns
    -------
    `xarray.Dataset`
        The content of the netcdf file after writing, re-loaded from disk.

    Raises
    ------
    ValueError
        If `years` is empty.
    """
    years = list(years)
    if not years:
        raise ValueError("(make_all_x_input_files) `years` must contain at least one year.")
    # Assemble the filepath
    output_filepath = f'inputfiles/{output_dir}/{output_dir}.nc'
    # Collect the x data of every year
    x_data_array = []
    for year in years:
        print(f"\tCreating x input file for {year}...")
        x_data, g_attr_dict = make_x_input_file(
            year=year,
            stage_2=stage_2,
            stage_2_cutoff=stage_2_cutoff,
            output_dir=output_dir,
            write_this_year=False,
            sort=False,
            **kwargs,
        )
        x_data_array.append(x_data)
    # Concatenate the datasets along the time dimension
    print("Concatenating the x datasets")
    input_netcdf_xr = xr.concat(x_data_array, dim='time')
    # Sort the dataset by time
    if sort:
        print("Sorting the x data by time.")
        input_netcdf_xr = input_netcdf_xr.sortby('time')
    # Save the x data to a netcdf; `sort` is forwarded so that `sort=False`
    # also skips the sort performed when merging into an existing file
    print(f"Saving x inputs to {output_filepath}")
    write_input_netcdf(
        input_netcdf_xr,
        output_filepath,
        g_attr_dict=g_attr_dict,
        sort=sort,
        **kwargs,
    )
    return xr.load_dataset(output_filepath)
@unox.time_this
def make_all_input_files(
    output_dir='test_input',
    sort=True,
    **kwargs,
):
    """ Create all input files for the Unet model.

    Runs `make_all_y_input_files` followed by `make_all_x_input_files` for the
    same `output_dir`, then optionally sorts the resulting dataset by time.

    Parameters
    ----------
    output_dir : `str`, optional
        Directory inside `inputfiles/` where the output input files will be saved.
        Created if it does not exist yet. Default is `'test_input'`.
    sort : `bool`, optional
        Whether to sort the returned xarray by time after all inputs are made.
        Sorting takes a long time. Only the returned object is sorted here;
        this function does not write the sorted result back to disk.
        Default is True.
    **kwargs : `dict`, optional
        Additional keyword arguments forwarded unchanged to both
        `make_all_y_input_files` and `make_all_x_input_files`.

    Returns
    -------
    input_netcdf_xr : `xarray.Dataset`
        The dataset returned by `make_all_x_input_files`, sorted by time when
        `sort` is True.
        NOTE(review): presumably this also contains the y variables written
        by the y step to the same netCDF file -- confirm against
        `write_input_netcdf`.
    """
    print("Note: It may take around an hour to generate all input files.")
    # Make sure the output directory exists (no error if it is already there)
    os.makedirs(f'inputfiles/{output_dir}', exist_ok=True)
    # Create y input data. The return value is not needed: the dataset that is
    # handed back to the caller comes from the x step below.
    print("Creating y input data...")
    make_all_y_input_files(
        output_dir=output_dir,
        sort=False,
        **kwargs,
    )
    # Create x input data
    print("Creating x input data...")
    input_netcdf_xr = make_all_x_input_files(
        output_dir=output_dir,
        sort=False,
        **kwargs,
    )
    # Sort once at the end instead of once per step, since sorting is slow
    if sort:
        print("Sorting the combined data by time.")
        input_netcdf_xr = input_netcdf_xr.sortby('time')
    print("Completed making all input files.")
    return input_netcdf_xr
[docs]
def make_input_metadata_file(
    input_set,
    output_dir=None,
    g_attrs=None,
    overwrite=True,
):
    """ Create a metadata file for the dataset in the given directory.

    Gather the metadata from the given dataset, format it, and output to a
    clear-text JSON file that can be easily read. If
    `inputfiles/<output_dir>/input_metadata.json` already exists, it is loaded
    and updated rather than started from scratch.

    Parameters
    ----------
    input_set : `str`, `xr.Dataset`, `uarray`
        The dataset to describe. A string is passed to
        `uarray(input_set, is_input_set=True)` and its `.xr` dataset is used;
        a `uarray` contributes its `.xr` dataset; an `xr.Dataset` is used as is.
    output_dir : `str`, `None`, optional
        Directory inside `inputfiles/` where the metadata file will be saved.
        If None, the metadata file will not be saved. Default is None.
    g_attrs : `dict`, `None`, optional
        Global attributes to use for the metadata file. If None, the dataset's
        own `attrs` are used. A given dict is updated in place: numpy integers,
        floats and arrays are converted to `int`, `float` and `list`, booleans
        to the strings `'True'`/`'False'`. In both cases an `'lsm'` entry
        (`'True'`/`'False'`) is written into the attribute mapping that is used.
    overwrite : `bool`, optional
        Whether to overwrite an existing metadata file. Default is True.
        NOTE(review): this flag is only type-checked; an existing metadata
        file is always merged with the new values and rewritten.

    Returns
    -------
    metadata_dict : `dict`
        The metadata dictionary that was saved to the json file.
        Has the format:
        ```json
        {
            "years": {
                "x": [
                    2005,
                    ...
                    2020
                ],
                "y": [
                    2005,
                    ...
                    2020
                ]
            },
            "y_var": "nox",
            "emiss_dir": "/data/high_res/t106",
            "emiss_pre": "nox_",
            "emiss_post": "_t106_US.nc",
            "nan_fill": 0,
            "stage_2_cutoff": 2013,
            "x_vars": [
                "no2",
                ...
                "ssrd"
            ],
            "data_dir": "/data/high_res",
            "chemra_path": "emacdonald/unet/datafiles/TROPESS/TROPESS_reanalysis_2hr_no2_sfc_",
            "insitu_path": "US_EPA/NO2/daily_NO2/daily_42602_",
            "era5_path": "ERA5concatenated",
            "stages": [
                1,
                2
            ]
        }
        ```

    Raises
    ------
    TypeError
        If any argument has an unsupported type.
    """
    # Check the cheap arguments first so that a bad call fails before a
    # (potentially large) dataset is loaded from disk.
    if not isinstance(output_dir, (str, type(None))):
        raise TypeError(f"(make_input_metadata_file) `output_dir` must be a string or None. Got type: {type(output_dir)}")
    if not isinstance(g_attrs, (dict, type(None))):
        raise TypeError(f"(make_input_metadata_file) `g_attrs` must be a dict or None. Got type: {type(g_attrs)}")
    if not isinstance(overwrite, bool):
        raise TypeError(f"(make_input_metadata_file) `overwrite` must be a bool. Got type: {type(overwrite)}")
    # Resolve `input_set` to an xarray Dataset
    if isinstance(input_set, str):
        # Load the xr dataset from a uarray
        xr_dataset = uarray(input_set, is_input_set=True).xr
    elif isinstance(input_set, xr.Dataset):
        xr_dataset = input_set
    elif isinstance(input_set, uarray):
        xr_dataset = input_set.xr
    else:
        raise TypeError(f"(make_input_metadata_file) `input_set` must be a string, xarray.Dataset, or uarray. Got type: {type(input_set)}")
    # If the metadata file already exists, load it so it can be extended
    metadata_dict = None
    output_filepath = None
    if output_dir is not None:
        # Assemble the filepath for the metadata file
        output_filepath = 'inputfiles/' + output_dir + '/input_metadata.json'
        if os.path.exists(output_filepath):
            with open(output_filepath, 'r') as f:
                metadata_dict = json.load(f)
        else:
            # Make sure the output directory exists
            make_file_path(output_filepath)
    if metadata_dict is None:
        metadata_dict = {}
    # Guarantee the year lists exist, also for older/partial metadata files
    year_lists = metadata_dict.setdefault('years', {})
    year_lists.setdefault('x', [])
    year_lists.setdefault('y', [])
    # Get a list of years present in the dataset
    years = get_years(xr_dataset)
    # Check for global attributes
    if g_attrs is None:
        g_attrs = xr_dataset.attrs
    else:
        # Update the types of the given `g_attrs` in place to be compatible
        # with JSON (kept in place because callers may reuse the dict).
        for key in list(g_attrs):
            g_attrs[key] = _to_json_compatible(g_attrs[key])
    # Record whether `lsm` is among the data variables
    g_attrs['lsm'] = 'True' if 'lsm' in list(xr_dataset.data_vars) else 'False'
    # Add select global attributes to the metadata dictionary
    for g_attr in (
        'y_var',
        'x_vars',
        'x1_vars',
        'x2_vars',
        'description',
        'data_dir',
        'chemra_path',
        'insitu_path',
        'era5_path',
        'modification_date',
        'emiss_dir',
        'emiss_pre',
        'emiss_post',
        'nan_fill',
        'stage_2_cutoff',
        'stages',
        'lsm',
    ):
        if g_attr not in g_attrs:
            continue
        # Convert on the way in as well: attributes taken straight from the
        # dataset may still hold numpy scalars/arrays that json cannot encode.
        metadata_dict[g_attr] = _to_json_compatible(g_attrs[g_attr])
        # If x_vars or y_var, also add the years covered by this dataset
        if g_attr == 'x_vars':
            year_lists['x'] = year_lists['x'] + years
        elif g_attr == 'y_var':
            year_lists['y'] = year_lists['y'] + years
    # Sort the list of years in ascending order, removing duplicates
    year_lists['x'] = sorted(set(year_lists['x']))
    year_lists['y'] = sorted(set(year_lists['y']))
    # Output the metadata dictionary to a json file
    if output_filepath is not None:
        # Serialise before opening the file so that a serialisation error
        # cannot leave a truncated metadata file behind.
        serialized = json.dumps(metadata_dict, indent=4)
        with open(output_filepath, 'w') as file:
            file.write(serialized)
    return metadata_dict


def _to_json_compatible(value):
    """ Return `value` converted to a type the `json` module can encode.

    numpy integers/floats become `int`/`float`, numpy arrays become lists and
    booleans become the strings `'True'`/`'False'`. Any other value is
    returned unchanged.
    """
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (np.bool_, bool)):
        return 'True' if value else 'False'
    return value
[docs]
def make_input_config(
    config_name,
    input_set = 'no2_sample_input',
    grid_size = [56, 120],
    x_vars = [
        'no2',
        'no2_tm1',
        'u10',
        'v10',
        'blh',
        'sp',
        'skt',
        't2m',
        'ssrd',
    ],
    stage_2 = True,
    stage_2_cutoff = 2013,
    # Candidates for the two mask lists: 'nox', 'no2', 'no2_tm1', 'no2_s2',
    # 'no2_s2_tm1', 'u10', 'v10', 'blh', 'sp', 'skt', 't2m', 'ssrd'
    lsm_vars = [],
    zfi_vars = [],
    overwrite=False,
    **kwargs,
):
    """ Create an input configuration file.

    Create the input configuration file for using input data with the Unet
    model. The arguments are validated against the dataset and then written to
    `inputfiles/_input_configs/<config_name>.json`.

    Parameters
    ----------
    config_name : `str`
        Name of the configuration file to be created (without `.json`).
    input_set : `str` or `xr.Dataset`, optional
        Directory inside `inputfiles/` where the dataset is found (the file
        `inputfiles/<input_set>/<input_set>.nc` is loaded), or the xarray Dataset.
        Default is 'no2_sample_input'.
        NOTE(review): `input_set` is stored in the configuration as given, so
        passing an `xr.Dataset` makes the JSON serialisation raise a
        `TypeError` -- confirm whether Dataset input is meant to be saved.
    grid_size : `list` of `int`, optional
        The number of grid cells to have in [latitude, longitude] when running the Unet model.
        Default is [56, 120].
    x_vars : `list` of `str`, optional
        List of variable names to be used as input features for the model.
        Default is ['no2', 'no2_tm1', 'u10', 'v10', 'blh', 'sp', 'skt', 't2m', 'ssrd'].
    stage_2 : `bool`, optional
        Whether or not stage 2 should be run with the Unet model. Forced to
        False when the dataset's `stages` attribute does not contain 2.
        Default is True.
    stage_2_cutoff : `int`, optional
        Year after which stage 2 data will be used. Must be one of the years
        found in the dataset.
        Default is 2013.
    lsm_vars : `list` of `str`, optional
        List of variable names that should use land-sea mask.
        Default is an empty list.
    zfi_vars : `list` of `str`, optional
        List of variable names that should use zero-fill mask.
        Default is an empty list.
    overwrite : `bool`, optional
        Whether an existing configuration file may be replaced without asking.
        If False and the file exists, the user is prompted. Default is False.
    **kwargs : `dict`, optional
        Accepted but not used by this function.

    Returns
    -------
    config_dict : `dict` or `None`
        The configuration dictionary that was saved to the json file, or None
        if the user declined to overwrite an existing configuration file.

    Raises
    ------
    TypeError
        If an argument has an unsupported type.
    ValueError
        If the dataset file does not exist, `grid_size` does not have exactly
        two entries or exceeds the dataset grid, or `stage_2_cutoff` is not a
        year in the dataset.
    """
    # Verify the plain argument types first, before any dataset is loaded
    if not isinstance(config_name, str):
        raise TypeError(f"(make_input_config) `config_name` must be a string. Got type: {type(config_name)}")
    if not isinstance(grid_size, list):
        raise TypeError(f"(make_input_config) `grid_size` must be a list of integers. Got type: {type(grid_size)}")
    if not isinstance(x_vars, list):
        raise TypeError(f"(make_input_config) `x_vars` must be a list of strings. Got type: {type(x_vars)}")
    if not isinstance(stage_2, bool):
        raise TypeError(f"(make_input_config) `stage_2` must be a boolean. Got type: {type(stage_2)}")
    if not isinstance(stage_2_cutoff, int):
        raise TypeError(f"(make_input_config) `stage_2_cutoff` must be an integer. Got type: {type(stage_2_cutoff)}")
    if not isinstance(lsm_vars, list):
        raise TypeError(f"(make_input_config) `lsm_vars` must be a list of strings. Got type: {type(lsm_vars)}")
    if not isinstance(zfi_vars, list):
        raise TypeError(f"(make_input_config) `zfi_vars` must be a list of strings. Got type: {type(zfi_vars)}")
    if not isinstance(overwrite, bool):
        raise TypeError(f"(make_input_config) `overwrite` must be a boolean. Got type: {type(overwrite)}")
    # Verify that grid_size has exactly 2 integers
    if len(grid_size) != 2:
        raise ValueError(f"(make_input_config) Expected `grid_size` to have a length of 2. Got length of {len(grid_size)}: {grid_size}")
    n_lats, n_lons = grid_size
    if not isinstance(n_lats, int):
        raise TypeError(f"(make_input_config) Number of latitudes in `grid_size` must be an integer. Got type: {type(grid_size[0])}")
    if not isinstance(n_lons, int):
        raise TypeError(f"(make_input_config) Number of longitudes in `grid_size` must be an integer. Got type: {type(grid_size[1])}")
    # Resolve input_set (a string or xr.Dataset) to a dataset
    if isinstance(input_set, str):
        # Check whether the given input_set exists in 'inputfiles/'
        xr_path = f'inputfiles/{input_set}/{input_set}.nc'
        if not os.path.exists(xr_path):
            raise ValueError(f"(make_input_config) File {xr_path} does not exist.")
        # Load the dataset
        xr_dataset = xr.load_dataset(xr_path)
    elif isinstance(input_set, xr.Dataset):
        xr_dataset = input_set
    else:
        raise TypeError(f"(make_input_config) `input_set` must be a string or xarray.Dataset. Got type: {type(input_set)}")
    # Verify the dataset
    xr_dataset = verify_dataset(xr_dataset)
    # Verify that `grid_size` is not larger than the lat-lon grid in xr_dataset
    xr_n_lats = xr_dataset.sizes['lat']
    xr_n_lons = xr_dataset.sizes['lon']
    if n_lats > xr_n_lats:
        raise ValueError(f"(make_input_config) Requested length of latitude grid ({n_lats}) cannot exceed length of latitude dimension in the given netcdf ({xr_n_lats}).")
    if n_lons > xr_n_lons:
        raise ValueError(f"(make_input_config) Requested length of longitude grid ({n_lons}) cannot exceed length of longitude dimension in the given netcdf ({xr_n_lons}).")
    # Verify that all x_vars are in the dataset
    for var in x_vars:
        verify_var(xr_dataset, var)
    # Verify that stage 2 exists in the dataset. `atleast_1d` keeps the
    # membership test valid when the attribute is a bare scalar, which may
    # happen for one-element attributes read back from a netCDF file.
    if 2 not in np.atleast_1d(xr_dataset.attrs.get('stages', [])):
        stage_2 = False
        print('Stage 2 not found in dataset. Setting stage_2 to False in configuration.')
    # Verify that stage_2_cutoff is a year that exists in the dataset
    years = get_years(xr_dataset)
    if stage_2_cutoff not in years:
        raise ValueError(f"(make_input_config) `stage_2_cutoff` {stage_2_cutoff} not found in dataset years: {years}")
    # Verify that the lsm_vars are in the dataset
    for var in lsm_vars:
        verify_var(xr_dataset, var)
    # Verify that the zfi_vars are in the dataset
    for var in zfi_vars:
        verify_var(xr_dataset, var)
    # Build the dictionary. The lists are copied so that the returned
    # dictionary never aliases the caller's lists or the default arguments.
    config_dict = {
        'input_set': input_set,
        'grid_size': list(grid_size),
        'x_vars': list(x_vars),
        'stage_2': stage_2,
        'stage_2_cutoff': stage_2_cutoff,
        'lsm_vars': list(lsm_vars),
        'zfi_vars': list(zfi_vars),
    }
    # Check whether the configuration file already exists
    config_filepath = f'inputfiles/_input_configs/{config_name}.json'
    if os.path.exists(config_filepath) and not overwrite:
        # Ask whether to overwrite the existing file
        overwrite = unox.interpret_user_input(input(f'Configuration file {config_filepath} already exists. Overwrite? (y/n)'))
        if not overwrite:
            print('Aborting configuration file creation.')
            return
    # Serialise first so a failure cannot truncate an existing configuration
    serialized = json.dumps(config_dict, indent=4)
    # Make sure the configuration directory exists before writing into it
    os.makedirs(os.path.dirname(config_filepath), exist_ok=True)
    # Save the configuration dictionary to a json file
    with open(config_filepath, 'w') as file:
        file.write(serialized)
    print(f'Saved configuration file to {config_filepath}')
    return config_dict
[docs]
def copy_input_files(
    source_input_set,
    output_dir,
    keep_vars='all',
    start_date=None,
    end_date=None,
    overwrite=True,
    **kwargs,
):
    """ Copy an input set to a new location.

    Create a copy of the input netCDF and `input_metadata.json` file from the
    specified source in a new directory, optionally filtering the netCDF to
    only include specified variables and date range.

    Parameters
    ----------
    source_input_set : `str`
        Name of the source input set located in `inputfiles/`.
    output_dir : `str`
        Name of the output directory inside `inputfiles/` where the new input set will be copied to.
        Must differ from `source_input_set`.
    keep_vars : `list` of `str`, `str`, or `all`, optional
        List of variable names to keep in the copied netCDF. A single name may
        be given as a plain string. If `all`, all variables are kept.
        Default is `all`.
    start_date : `str`, `None`, or `np.datetime64`, optional
        Date from which to start the copied data. If None, the start date equals that of the original file.
        Expected format is 'YYYY-MM-DDTHH:MM:SS' or 'YYYY-MM-DD'.
        A `np.datetime64` is reduced to its calendar date.
        Default is None.
    end_date : `str`, `None`, or `np.datetime64`, optional
        Date at which to end the copied data. If None, the end date equals that of the original file.
        Expected format is 'YYYY-MM-DDTHH:MM:SS' or 'YYYY-MM-DD'.
        A `np.datetime64` is reduced to its calendar date.
        Default is None.
    overwrite : `bool`, optional
        Whether an existing output directory may be removed without asking.
        If False and the directory exists, the user is prompted.
        Default is True.
    **kwargs : `dict`, optional
        Accepted but not used by this function.

    Returns
    -------
    source_uarr : `uarray` or `None`
        The `uarray` whose `.xr` dataset was filtered and saved to the new
        location, or None if the user declined to overwrite an existing
        output directory.

    Raises
    ------
    TypeError
        If an argument has an unsupported type.
    ValueError
        If `source_input_set` equals `output_dir`, or `start_date` is not
        before `end_date`.
    """
    # Verify argument types
    if not isinstance(source_input_set, str):
        raise TypeError(f"(copy_input_files) `source_input_set` must be a string. Got type: {type(source_input_set)}")
    if not isinstance(output_dir, str):
        raise TypeError(f"(copy_input_files) `output_dir` must be a string. Got type: {type(output_dir)}")
    if not isinstance(keep_vars, list) and not keep_vars == 'all':
        if isinstance(keep_vars, str):
            # Turn `keep_vars` into a list if it's a single string
            keep_vars = [keep_vars]
        else:
            raise TypeError(f"(copy_input_files) `keep_vars` must be a list of strings or 'all'. Got type: {type(keep_vars)}")
    if not (isinstance(start_date, (str, type(None), np.datetime64))):
        raise TypeError(f"(copy_input_files) `start_date` must be a string, None, or np.datetime64. Got type: {type(start_date)}")
    if not (isinstance(end_date, (str, type(None), np.datetime64))):
        raise TypeError(f"(copy_input_files) `end_date` must be a string, None, or np.datetime64. Got type: {type(end_date)}")
    # Verify that the source and output directories are not the same
    if source_input_set == output_dir:
        raise ValueError(f"(copy_input_files) `source_input_set` and `output_dir` cannot be the same. Both are '{source_input_set}'.")
    # Check whether the output directory already exists
    try:
        verify_path(f"inputfiles/{output_dir}/")
    except FileNotFoundError:
        # Nothing to clean up; the directory is created below
        print(f"(copy_input_files) Output directory 'inputfiles/{output_dir}/' does not exist. Creating it.")
    else:
        if not overwrite:
            # Ask whether to overwrite the existing directory
            overwrite = unox.interpret_user_input(input(f"Output directory 'inputfiles/{output_dir}/' already exists. Overwrite? (y/n)"))
            if not overwrite:
                print('Aborting input file copy.')
                return
        print(f"Overwriting existing input files in {output_dir}")
        remove_non_empty_directory(f"inputfiles/{output_dir}/")
    # Make the output directory
    make_file_path(f"inputfiles/{output_dir}/")
    # Load the source dataset as a `uarray`
    source_uarr = uarray(source_input_set, is_input_set=True)
    # Check whether to filter variables
    if keep_vars != 'all':
        # Verify that all variables in `keep_vars` are in the source dataset
        for var in keep_vars:
            verify_var(source_uarr.xr, var)
        # Drop variables not in `keep_vars`
        vars_to_drop = [var for var in source_uarr.xr.data_vars if var not in keep_vars]
        source_uarr.xr = source_uarr.xr.drop_vars(vars_to_drop)
        # Update the variable list attributes so they only name kept variables
        for var_list_attr in ['x_vars', 'x1_vars', 'x2_vars']:
            if var_list_attr in source_uarr.xr.attrs:
                filtered_var_list = [var for var in source_uarr.xr.attrs[var_list_attr] if var in keep_vars]
                source_uarr.xr.attrs[var_list_attr] = filtered_var_list
        # The y variable is flagged with the string 'None' when it was dropped
        if 'y_var' in source_uarr.xr.attrs:
            if source_uarr.xr.attrs['y_var'] not in keep_vars:
                source_uarr.xr.attrs['y_var'] = 'None'
    # Check whether to filter dates
    if start_date is not None or end_date is not None:
        # Normalise both bounds to strings; a missing bound falls back to the
        # first/last entry of the dataset's time coordinate.
        if isinstance(start_date, np.datetime64):
            start_date = str(pd.to_datetime(start_date).date())
        elif start_date is None:
            start_date = _format_day(source_uarr.xr['time'].values[0])
        if isinstance(end_date, np.datetime64):
            end_date = str(pd.to_datetime(end_date).date())
        elif end_date is None:
            end_date = _format_day(source_uarr.xr['time'].values[-1])
        # Verify that `start_date` is before `end_date`
        if start_date >= end_date:
            raise ValueError(f"(copy_input_files) `start_date` must be before `end_date`. Got start_date: {start_date}, end_date: {end_date}.")
        # Select the date range
        source_uarr.xr = source_uarr.xr.sel(time=slice(start_date, end_date), drop=True)
    # Update the modification time attribute
    source_uarr.xr.attrs['modification_date'] = pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
    # Write the new dataset to the output directory
    source_uarr.xr.to_netcdf(f"inputfiles/{output_dir}/{output_dir}.nc")
    # Create metadata file
    make_input_metadata_file(
        source_uarr.xr,
        output_dir=output_dir,
        g_attrs=source_uarr.xr.attrs,
    )
    return source_uarr


def _format_day(timestamp):
    """ Format a time coordinate value as a 'YYYY-MM-DD' string.

    `numpy.datetime64` values carry no `year`/`month`/`day` attributes, so they
    are converted through `pandas.Timestamp`; any other value is expected to
    expose those three attributes itself.
    """
    if isinstance(timestamp, np.datetime64):
        return str(pd.Timestamp(timestamp).date())
    return f"{timestamp.year:04d}-{timestamp.month:02d}-{timestamp.day:02d}"