"""Source code for the ``unox.data`` module."""

import numpy as np
import xarray as xr
import pandas as pd
import warnings
import os
import re
from datetime import datetime

from unox.HPC.data0.paths import verify_path
from unox.HPC.data0.latlon import shift_lon_arr
from unox.HPC.data0.verify_dtype import verify_number
from unox.HPC.data0.verify_dataset import verify_dataset
from unox.HPC.data0.dataset import csv_to_xr

# Default latitude / longitude bounds (degrees) of this project's domain.
DEFAULT_LAT_MIN = 11
DEFAULT_LAT_MAX = 75
DEFAULT_LON_MIN = -175
DEFAULT_LON_MAX = -39

# The cdsapi request in download_era5.py expects the area ordered as
# [north, west, south, east], i.e. [lat_max, lon_min, lat_min, lon_max].
DEFAULT_EXTENT = [
    DEFAULT_LAT_MAX,  # north
    DEFAULT_LON_MIN,  # west
    DEFAULT_LAT_MIN,  # south
    DEFAULT_LON_MAX,  # east
]

def generate_lats_lons(
    dataset='datafiles/sample_data/2019u10.nc',
    output_dir='datafiles/',
):
    """
    Generate latitude and longitude arrays from the given dataset.

    Create the `lats.npy` and `lons.npy` files from the latitude and longitude
    values in the given dataset. They were originally generated from the ERA5
    concatenated data files created by the `download_era5` and `concatenate`
    scripts in the `datafiles` directory.

    Parameters
    ----------
    dataset : `str` or `xarray.Dataset`, optional
        The filepath to the dataset or an xarray Dataset object from which to
        extract latitude and longitude values.
    output_dir : `str` or path-like, optional
        The directory in which to save the generated `lats.npy` and `lons.npy`
        files. A trailing path separator is not required.

    Returns
    -------
    lats : `numpy.ndarray`
        The latitude values extracted from the dataset.
    lons : `numpy.ndarray`
        The longitude values extracted from the dataset.
    """
    # Load the dataset from disk, or verify an already-loaded one
    if isinstance(dataset, str):
        # NOTE(review): `load_dataset` is not among the imports visible in this
        # module; confirm it is defined elsewhere in the file.
        xr_dataset = load_dataset(dataset)
    else:
        xr_dataset = verify_dataset(dataset)

    # Get the (verified) latitude and longitude values
    lats, lons = get_lats_lons(xr_dataset)

    # os.path.join instead of string concatenation, so an `output_dir` without
    # a trailing separator still writes inside that directory.
    np.save(os.path.join(output_dir, 'lats.npy'), lats)
    np.save(os.path.join(output_dir, 'lons.npy'), lons)

    return lats, lons


def get_extent(
    xr_dataset=None,
    lats=None,
    lons=None,
    shift_lons=False,
    **kwargs,
):
    """
    Get the latitude and longitude extent of the given xarray dataset.

    Return the minimum and maximum latitude and longitude. They are taken from
    the `lat`/`lon` coordinates of `xr_dataset`, or from the explicit `lats`
    and `lons` arrays when no dataset is given. Each bound is checked with
    `verify_lat()` / `verify_lon()` before being returned.

    Parameters
    ----------
    xr_dataset : `xarray.Dataset` or `xarray.DataArray`, optional
        The xarray data of which to find the extent.
    lats : `numpy.ndarray`, optional
        The latitude values to use instead of those in the dataset.
    lons : `numpy.ndarray`, optional
        The longitude values to use instead of those in the dataset.
    shift_lons : `bool`, optional
        If True, shift the longitude values with `shift_lon_arr()` (e.g. based
        on the PM_centered kwarg) before taking their extent.
    **kwargs : keyword arguments
        Additional keyword arguments to pass to `verify_dataset()` and
        `shift_lon_arr()`.

    Returns
    -------
    extent : `tuple`
        A tuple of numpy scalars in the form (lat_min, lat_max, lon_min, lon_max).

    Raises
    ------
    ValueError
        If neither `xr_dataset` nor both `lats` and `lons` are provided, or if
        a bound fails `verify_lat()` / `verify_lon()`.

    Examples
    --------
    >>> nox = xr.open_dataset('datafiles/nox_2019_t106_US.nc')
    >>> extent = get_extent(nox)
    (24.112, 58.878, -126.0, -59.625)
    >>> lats, lons = get_lats_lons(nox)
    >>> extent = get_extent(lats=lats, lons=lons)
    (24.112, 58.878, -126.0, -59.625)
    """
    def as_scalar(value):
        # Collapse a numpy scalar or 0-d array into a single numpy scalar
        return np.ravel(value)[0]

    if xr_dataset is None:
        # Work directly from the explicitly supplied coordinate arrays
        if lats is None or lons is None:
            raise ValueError(f"(get_extent) Either `xr_dataset` or both `lats` and `lons` must be provided.")
        lat_low = as_scalar(np.min(lats))
        lat_high = as_scalar(np.max(lats))
        lon_values = shift_lon_arr(lons, **kwargs) if shift_lons else lons
        lon_low = as_scalar(np.min(lon_values))
        lon_high = as_scalar(np.max(lon_values))
    else:
        # Work from the dataset's own coordinates
        xr_dataset = verify_dataset(xr_dataset, **kwargs)
        lat_low = as_scalar(xr_dataset.lat.min().values)
        lat_high = as_scalar(xr_dataset.lat.max().values)
        lon_values = xr_dataset.lon.values
        if shift_lons:
            lon_values = shift_lon_arr(lon_values, **kwargs)
        lon_low = as_scalar(lon_values.min())
        lon_high = as_scalar(lon_values.max())

    # Validate the bounds: latitudes first (max, then min), then longitudes
    lat_high = verify_lat(lat_high)
    lat_low = verify_lat(lat_low)
    lon_high = verify_lon(lon_high)
    lon_low = verify_lon(lon_low)

    return (lat_low, lat_high, lon_low, lon_high)


def get_lats_lons(
    xr_dataset,
    **kwargs,
):
    """
    Get the latitude and longitude values from the given dataset.

    Verify the dataset with `verify_dataset()` and read its `lat` and `lon`
    coordinate values as numpy arrays. Every value is checked with
    `verify_lat()` / `verify_lon()`.

    Parameters
    ----------
    xr_dataset : `xarray.Dataset` or `xarray.DataArray`
        The xarray data from which to read the coordinates.
    **kwargs : keyword arguments
        Additional keyword arguments to pass to `verify_dataset()`.

    Returns
    -------
    lats : `numpy.ndarray`
        Array of latitude values.
    lons : `numpy.ndarray`
        Array of longitude values.

    Raises
    ------
    ValueError
        If any latitude or longitude value is not a number, is NaN, or is
        outside the range accepted by `verify_lat()` / `verify_lon()`.

    Examples
    --------
    >>> nox = xr.open_dataset('datafiles/nox_2019_t106_US.nc')
    >>> lats, lons = get_lats_lons(nox)
    """
    # Verify the xr_dataset
    xr_dataset = verify_dataset(xr_dataset, **kwargs)

    # Get the latitude and longitude values
    lats = xr_dataset.lat.values
    lons = xr_dataset.lon.values

    # Iterate explicitly: a bare `map(verify_lat, lats)` is lazy in Python 3
    # and would never call the validators. np.ravel keeps this element-wise
    # even if the coordinates are multi-dimensional.
    for lat in np.ravel(lats):
        verify_lat(lat)
    for lon in np.ravel(lons):
        verify_lon(lon)

    return lats, lons


def _format_resolution(values, label):
    """Describe the spacing between sorted coordinate values as a string."""
    steps = np.diff(np.sort(values))
    if len(steps) == 0:
        raise ValueError(f"(get_latlon_resolution) At least two `{label}` values are needed to compute a resolution.")
    unique_steps = np.unique(steps)
    if len(unique_steps) == 1:
        # Perfectly regular grid: report the single spacing
        return str(unique_steps[0])
    # Irregular grid: report mean ± standard deviation of the spacings
    return f"{np.mean(steps)} ± {np.std(steps)}"


def get_latlon_resolution(
    xr_dataset=None,
    lats=None,
    lons=None,
    **kwargs,
):
    """
    Get the latitude and longitude resolution of the given dataset.

    Sort the coordinate values and look at the spacing between neighbours,
    separately for latitude and longitude. A perfectly regular spacing is
    reported as a single value. Otherwise the result is reported as
    "mean ± std" of the spacings.

    Parameters
    ----------
    xr_dataset : `xarray.Dataset` or `xarray.DataArray`, optional
        The xarray data of which to find the resolution. When given, its
        coordinates take precedence over `lats` and `lons`.
    lats : `numpy.ndarray`, optional
        The latitude values to use instead of those in the dataset.
    lons : `numpy.ndarray`, optional
        The longitude values to use instead of those in the dataset.
    **kwargs : keyword arguments
        Additional keyword arguments to pass to `verify_dataset()` and
        `get_lats_lons()`.

    Returns
    -------
    lat_res : `str`
        The resolution in latitude.
    lon_res : `str`
        The resolution in longitude.

    Raises
    ------
    ValueError
        If neither `xr_dataset` nor both `lats` and `lons` are given, or if
        fewer than two values are available along either coordinate.

    Examples
    --------
    >>> nox = xr.open_dataset('datafiles/nox_2019_t106_US.nc')
    >>> lat_res, lon_res = get_latlon_resolution(nox)
    ('0.25', '0.25')
    """
    if xr_dataset is not None:
        # Verify the xr_dataset and read its coordinates
        xr_dataset = verify_dataset(xr_dataset, **kwargs)
        lats, lons = get_lats_lons(xr_dataset, **kwargs)
    elif lats is None or lons is None:
        raise ValueError("(get_latlon_resolution) Either `xr_dataset` or both `lats` and `lons` must be provided.")

    lat_res = _format_resolution(lats, 'lats')
    lon_res = _format_resolution(lons, 'lons')
    return lat_res, lon_res


def clean_num_list(
    val_list,
):
    """
    Remove every value that cannot be treated as a number from a list.

    Each element is tested with `verify_number()`. Elements that fail the test
    are dropped, and the remaining ones are returned in their original order.

    Parameters
    ----------
    val_list : `list`
        The list of values to clean.

    Returns
    -------
    return_list : `list`
        The values that passed `verify_number()`.

    Raises
    ------
    ValueError
        If no value in `val_list` passes `verify_number()`.

    Examples
    --------
    (Results depend on what `verify_number()` accepts.)

    >>> clean_num_list([1, 2, 3, "4", 5])
    [1, 2, 3, 5]
    >>> clean_num_list([1, 2, 3, np.nan, None, np.inf, -np.inf])
    [1, 2, 3]
    """
    kept = [candidate for candidate in val_list if verify_number(candidate)]
    if not kept:
        raise ValueError("(clean_num_list) No valid numbers in the input list.")
    return kept


def verify_lat(
    lat_val,
):
    """
    Verify that the given latitude value is valid.

    A valid latitude passes `verify_number()`, is not NaN, and lies in the
    closed interval [-90, 90]. Valid values are returned unchanged. Anything
    else raises a ValueError.

    Parameters
    ----------
    lat_val : `float`
        The latitude value to verify.

    Returns
    -------
    lat_val : `float`
        The verified latitude value.

    Raises
    ------
    ValueError
        If `lat_val` is not a number, is NaN, or is outside [-90, 90].

    Examples
    --------
    >>> verify_lat(45.0)
    45.0
    >>> verify_lat(-100.0)
    ValueError: (verify_lat) `lat_val` must be in the range [-90, 90]. Got: -100.0
    """
    if not verify_number(lat_val):
        raise ValueError(f"(verify_lat) `lat_val` must be a number. Got type: {type(lat_val)}")
    if np.isnan(lat_val):
        raise ValueError(f"(verify_lat) `lat_val` must not be NaN.")
    in_range = -90 <= lat_val <= 90
    if not in_range:
        raise ValueError(f"(verify_lat) `lat_val` must be in the range [-90, 90]. Got: {lat_val}")
    return lat_val


def verify_lon(
    lon_val,
    PM_centered=None,
):
    """
    Verify that the given longitude value is valid.

    A valid longitude passes `verify_number()`, is not NaN, and lies within
    the range selected by `PM_centered`. Valid values are returned unchanged.
    Anything else raises a ValueError.

    Parameters
    ----------
    lon_val : `float`
        The longitude value to verify.
    PM_centered : `bool`, optional
        If None (the default), verify that the longitude value is in the range
        [-180, 360]. If truthy, verify that it is in [-180, 180]. If falsy (but
        not None), verify that it is in [0, 360].

    Returns
    -------
    lon_val : `float`
        The verified longitude value.

    Raises
    ------
    ValueError
        If `lon_val` is not a number, is NaN, or is outside the selected range.

    Examples
    --------
    >>> verify_lon(45.0)
    45.0
    >>> verify_lon(-200.0)
    ValueError: (verify_lon) `lon_val` must be in the range [-180, 360]. Got: -200.0
    """
    if not verify_number(lon_val):
        raise ValueError(f"(verify_lon) `lon_val` must be a number. Got type: {type(lon_val)}")
    if np.isnan(lon_val):
        raise ValueError(f"(verify_lon) `lon_val` must not be NaN.")

    # Choose the allowed interval for the requested longitude convention
    if PM_centered is None:
        lower, upper = -180, 360
    elif PM_centered:
        lower, upper = -180, 180
    else:
        lower, upper = 0, 360

    if lon_val < lower or lon_val > upper:
        raise ValueError(f"(verify_lon) `lon_val` must be in the range [{lower}, {upper}]. Got: {lon_val}")
    return lon_val


def get_vminmax(
    arrays,
):
    """
    Get the minimum and maximum values across the given arrays.

    Flatten and concatenate the given arrays and return the minimum and
    maximum values, ignoring NaN values.

    Parameters
    ----------
    arrays : iterable of array-like
        The arrays to get the minimum and maximum values from. Each element is
        flattened with `numpy.ravel`, so numpy arrays of any shape, nested
        lists and scalars are all accepted.

    Returns
    -------
    vmin : numpy scalar
        The minimum non-NaN value across the arrays.
    vmax : numpy scalar
        The maximum non-NaN value across the arrays.

    Raises
    ------
    ValueError
        If `arrays` is empty, or if it contains no non-NaN values.

    Examples
    --------
    >>> arrays = [np.array([1, 2, 3]), np.array([4, 5, 6])]
    >>> vmin, vmax = get_vminmax(arrays)
    (1, 6)
    """
    # np.ravel (unlike ndarray.flatten) also accepts lists, scalars and other
    # array-likes.
    flattened = [np.ravel(arr) for arr in arrays]
    if not flattened:
        raise ValueError("(get_vminmax) `arrays` must contain at least one array.")
    flat_values = np.concatenate(flattened)

    with warnings.catch_warnings():
        # np.nanmin/np.nanmax only *warn* (and return NaN) on all-NaN input;
        # promote that RuntimeWarning to an error so it can be reported clearly.
        warnings.filterwarnings('error', category=RuntimeWarning)
        try:
            vmin = np.nanmin(flat_values)
            vmax = np.nanmax(flat_values)
        except RuntimeWarning as e:
            raise ValueError(f"(get_vminmax) {e}. Does input array contain any non-NaN values?") from e

    return vmin, vmax


def get_max_abs_val(
    val_list,
):
    """
    Get the maximum absolute value from the given values.

    If `val_list` is a numpy array it is flattened first. Entries that are not
    valid numbers are then removed with `clean_num_list()`, and the largest
    absolute value among the remaining entries is returned.

    Parameters
    ----------
    val_list : `list` of numbers or `numpy.ndarray`
        The values to search. Multi-dimensional arrays are flattened first.

    Returns
    -------
    max_abs : numpy scalar
        The maximum absolute value of the valid entries.

    Raises
    ------
    ValueError
        If `val_list` contains no valid numbers (raised by `clean_num_list()`).

    Examples
    --------
    >>> get_max_abs_val([-11, 6])
    11
    >>> vmin, vmax = get_vminmax([np.array([1, 2, -3]), np.array([4, 5, -6])])
    >>> get_max_abs_val([vmin, vmax])
    6
    """
    if isinstance(val_list, np.ndarray):
        # Iterating a multi-dimensional array yields rows, not numbers, so every
        # row would be discarded by clean_num_list(); flatten first.
        val_list = val_list.ravel()

    # Drop invalid entries, then convert to an array for vectorized abs/max
    valid_values = np.array(clean_num_list(val_list))

    return np.max(np.abs(valid_values))


def _closest_index_within(coords, target, tol, label):
    """Return the index of the entry of `coords` closest to `target`, requiring it to be within `tol`."""
    diffs = np.abs(np.asarray(coords) - target)
    # NaN differences compare False, so they never become candidates
    candidates = np.flatnonzero(diffs < tol)
    if candidates.size == 0:
        raise ValueError(f"(restrict_domain) No {label} value within {tol} of {target}.")
    return int(candidates[np.argmin(diffs[candidates])])


def _index_slice(coords, low, high, tol, label):
    """Return the index slice of `coords` spanning the values `low`..`high`, whatever the sort order."""
    i_low = _closest_index_within(coords, low, tol, label)
    i_high = _closest_index_within(coords, high, tol, label)
    return slice(min(i_low, i_high), max(i_low, i_high) + 1)


def restrict_domain(
    arrs_to_restrict,
    lats,
    lons,
    restricting_data,
    tol=0.1,
):
    """
    Restrict the domain of the given arrays.

    Cut each array down to the latitude/longitude extent of `restricting_data`.
    `lats` and `lons` are the coordinate values along the latitude and
    longitude axes of the arrays being restricted. Each extent bound is
    matched to the closest of those values within `tol`.

    Parameters
    ----------
    arrs_to_restrict : `list` of `numpy.ndarray`
        The arrays to restrict. Each must have at least 3 dimensions, with
        latitude on axis 1 and longitude on axis 2 (e.g. (time, lat, lon, ...)).
        Any further axes are kept whole.
    lats : `numpy.ndarray`
        The latitude values along axis 1 of the arrays to restrict. They may be
        in ascending or descending order.
    lons : `numpy.ndarray`
        The longitude values along axis 2 of the arrays to restrict.
        NOTE(review): no longitude-convention conversion is done here, so these
        presumably need the same convention as `restricting_data`; confirm.
    restricting_data : `xarray.Dataset` or `xarray.DataArray`
        The dataset whose extent the arrays are restricted to.
    tol : `float`, optional
        Maximum allowed distance (in the units of `lats`/`lons`) between an
        extent bound and the coordinate matched to it. Default is 0.1.

    Returns
    -------
    arrs_to_return : `list` of `numpy.ndarray`
        The restricted arrays.
    lat_r : `numpy.ndarray`
        The latitude values of the restricting data.
    lon_r : `numpy.ndarray`
        The longitude values of the restricting data.

    Raises
    ------
    ValueError
        If no value in `lats`/`lons` lies within `tol` of an extent bound, or if
        `restricting_data` fails verification.

    Examples
    --------
    >>> lats, lons = load_lats_lons()
    >>> nox = xr.open_dataset('datafiles/nox_2019_t106_US.nc')
    >>> # stage1: an array shaped (time, lat, lon, ...) on the `lats`/`lons` grid
    >>> (stage1_r,), lat_r, lon_r = restrict_domain([stage1], lats, lons, nox)
    """
    # Coordinates of the restricting data (returned unchanged to the caller)
    lat_r, lon_r = get_lats_lons(restricting_data)
    # Extent of the restricting data; get_extent also validates the bounds
    lat_min, lat_max, lon_min, lon_max = get_extent(restricting_data)

    # Index ranges covering that extent. Handling the sort order explicitly
    # keeps descending grids (e.g. ERA5 latitudes) from giving empty slices.
    lat_slice = _index_slice(lats, lat_min, lat_max, tol, 'latitude')
    lon_slice = _index_slice(lons, lon_min, lon_max, tol, 'longitude')

    # Narrow every array to this region along axes 1 (lat) and 2 (lon)
    arrs_to_return = [arr[:, lat_slice, lon_slice, ...] for arr in arrs_to_restrict]

    return arrs_to_return, lat_r, lon_r


def _label_slice(coord, low, high):
    """Return a label slice covering `low`..`high` that respects the sort order of `coord`."""
    values = np.asarray(coord)
    # Label slicing on a descending index needs slice(high, low);
    # slice(low, high) would select nothing.
    if values.size > 1 and values[0] > values[-1]:
        return slice(high, low)
    return slice(low, high)


def match_domains(
    xr_a,
    xr_b,
    require_equal=True,
    require_len_gt_1=True,
):
    """
    Restrict the domain of the given xarray Datasets to match each other.

    Find the extent covered by both given datasets and restrict both to it.
    At least some of the actual latitude and longitude values must be present
    in both datasets. Coordinates may be sorted in ascending or descending
    order.

    Parameters
    ----------
    xr_a : `xarray.Dataset` or `xarray.DataArray`
        The first dataset.
    xr_b : `xarray.Dataset` or `xarray.DataArray`
        The second dataset.
    require_equal : `bool`, optional
        Whether to check that the latitude and longitude values in the two
        datasets are exactly the same after trimming. Default is `True`.
    require_len_gt_1 : `bool`, optional
        Whether to check that the trimmed datasets have more than 1 value in
        each of the lat and lon dimensions. This catches cases where the
        datasets only overlap at a single point. Default is `True`.

    Returns
    -------
    tr_xr_a : `xarray.Dataset` or `xarray.DataArray`
        The first dataset, with its latitude and longitude extents trimmed to
        the overlap with `xr_b`.
    tr_xr_b : `xarray.Dataset` or `xarray.DataArray`
        The second dataset, with its latitude and longitude extents trimmed to
        the overlap with `xr_a`.

    Raises
    ------
    ValueError
        If the datasets do not overlap, if `require_equal` is set and the
        trimmed coordinates differ, or if `require_len_gt_1` is set and a
        trimmed dimension has 1 or fewer values.
    """
    # Verify argument types
    xr_a = verify_dataset(xr_a, check_time=False)
    xr_b = verify_dataset(xr_b, check_time=False)

    # Get the extent of each dataset
    (a_lat_min, a_lat_max, a_lon_min, a_lon_max) = get_extent(xr_a, check_time=False)
    (b_lat_min, b_lat_max, b_lon_min, b_lon_max) = get_extent(xr_b, check_time=False)

    # The overlapping extent covered by both datasets
    lat_min = max(a_lat_min, b_lat_min)
    lat_max = min(a_lat_max, b_lat_max)
    lon_min = max(a_lon_min, b_lon_min)
    lon_max = min(a_lon_max, b_lon_max)

    # Verify these numbers make sense
    if lat_min > lat_max:
        raise ValueError(f"(match_domains) `lat_min` ({lat_min}) larger than `lat_max` ({lat_max}).")
    if lon_min > lon_max:
        raise ValueError(f"(match_domains) `lon_min` ({lon_min}) larger than `lon_max` ({lon_max}).")

    # Trim both datasets, ordering each slice to match its coordinate's sort order
    tr_xr_a = xr_a.sel(
        lat=_label_slice(xr_a.lat, lat_min, lat_max),
        lon=_label_slice(xr_a.lon, lon_min, lon_max),
    )
    tr_xr_b = xr_b.sel(
        lat=_label_slice(xr_b.lat, lat_min, lat_max),
        lon=_label_slice(xr_b.lon, lon_min, lon_max),
    )

    # Verify these two datasets have the same latitude and longitude values
    if require_equal:
        lats_a, lons_a = get_lats_lons(tr_xr_a, check_time=False)
        lats_b, lons_b = get_lats_lons(tr_xr_b, check_time=False)
        if not np.array_equal(lats_a, lats_b):
            raise ValueError(f"(match domains) Latitude values do not match between the two datasets.")
        if not np.array_equal(lons_a, lons_b):
            raise ValueError(f"(match domains) Longitude values do not match between the two datasets.")

    # Verify that both trimmed datasets keep more than 1 value per dimension
    if require_len_gt_1:
        if len(tr_xr_a.lat) <= 1:
            raise ValueError(f"(match_domains) `xr_a` has 1 or fewer values in the lat dimension after trimming.")
        if len(tr_xr_a.lon) <= 1:
            raise ValueError(f"(match_domains) `xr_a` has 1 or fewer values in the lon dimension after trimming.")
        if len(tr_xr_b.lat) <= 1:
            raise ValueError(f"(match_domains) `xr_b` has 1 or fewer values in the lat dimension after trimming.")
        if len(tr_xr_b.lon) <= 1:
            raise ValueError(f"(match_domains) `xr_b` has 1 or fewer values in the lon dimension after trimming.")

    return tr_xr_a, tr_xr_b


def _load_delimited_array(path, error_message):
    """Load a comma-delimited text file, retrying once with the first line skipped as a header."""
    try:
        nparray = np.loadtxt(path, delimiter=",")
    except Exception:
        nparray = None
    if nparray is None or nparray.size == 0:
        # Either unparsable (e.g. a header row) or empty: retry skipping line 1
        try:
            nparray = np.genfromtxt(path, delimiter=",", skip_header=1)
        except Exception as err:
            raise ValueError(error_message) from err
    # `.size` (not len()) so 0-d arrays from single-value files are accepted
    if nparray.size == 0:
        raise ValueError(error_message)
    return nparray


def verify_npy(
    array,
):
    """
    Return a numpy array, given either an array or a path to a file holding one.

    A `numpy.ndarray` is returned unchanged. A path is loaded from disk:
    `.npy` files with `numpy.load`, and `.txt`/`.csv` files as comma-delimited
    text (retrying once with the first line skipped as a header). Anything
    else raises a TypeError, ValueError or FileNotFoundError.

    Parameters
    ----------
    array : `numpy.ndarray` or `str`
        A numpy array or a path to a file containing a numpy array.

    Returns
    -------
    nparray : `np.ndarray`
        The array being passed or pointed to as a np.ndarray.

    Raises
    ------
    TypeError
        If `array` is neither a numpy array nor a string.
    FileNotFoundError
        If the path does not exist or points to a directory.
    ValueError
        If the file has an unsupported extension or holds no readable array.

    Examples
    --------
    >>> import numpy as np
    >>> from tempfile import NamedTemporaryFile
    >>> arr = np.array([1, 2, 3])
    >>> verify_npy(arr)
    array([1, 2, 3])
    >>> with NamedTemporaryFile(suffix=".npy", delete=False) as f:
    ...     np.save(f.name, arr)
    ...     verify_npy(f.name)
    array([1, 2, 3])
    >>> with NamedTemporaryFile(suffix=".txt", mode="w", delete=False) as f:
    ...     _ = f.write("1,2,3\\n4,5,6")
    >>> loaded = verify_npy(f.name)
    >>> isinstance(loaded, np.ndarray)
    True
    """
    if isinstance(array, np.ndarray):
        return array
    if not isinstance(array, str):
        raise TypeError(f"(verify_npy) `array` is not a numpy array. Got type: {type(array)}")

    if os.path.isdir(array):
        raise FileNotFoundError(f"(verify_npy) Path {array} leads to a folder.")
    if not os.path.isfile(array):
        raise FileNotFoundError(f"(verify_npy) File {array} does not exist.")

    unreadable = f"(verify_npy) File {array} does not contain a readable numpy array."
    ext = os.path.splitext(array)[1].lower()

    if ext == ".npy":
        try:
            # NOTE(security): allow_pickle=True lets a crafted .npy file run
            # arbitrary code when it is loaded; only pass trusted files here.
            return np.load(array, allow_pickle=True)
        except Exception as err:
            raise ValueError(unreadable) from err
    if ext in (".txt", ".csv"):
        return _load_delimited_array(array, unreadable)
    raise ValueError(unreadable)


def get_num_from_string(
    str,
):
    """
    Extract numbers from a string.

    Every number in the string is returned, in order of appearance. Tokens
    containing a decimal point become `float`, and all others become `int`.
    A string with no numbers gives an empty list.

    NOTE(review): the pattern keeps a leading sign only for decimal numbers
    ("-2.5" -> -2.5), while integers lose it ("-3" -> 3). Confirm whether
    callers depend on this before changing it.

    Parameters
    ----------
    str : `str`
        The string to extract the numbers from. The parameter shadows the
        builtin `str` and is kept for backward compatibility.

    Returns
    -------
    nums : `list` of `int` or `float`
        A list of numbers extracted from the string (possibly empty).

    Raises
    ------
    TypeError
        If the input is not a string.

    Examples
    --------
    >>> get_num_from_string("There are 42.0 apples and 3 oranges.")
    [42.0, 3]
    >>> get_num_from_string("No number here")
    []
    """
    # The builtin `str` is shadowed by the parameter, so compare with type('')
    if not isinstance(str, type('')):
        raise TypeError(f"(get_num_from_string) `str` must be a string. Got type: {type(str)}")

    def to_number(token):
        return float(token) if '.' in token else int(token)

    return [to_number(found.group()) for found in re.finditer(r"[-+]?\d*\.\d+|\d+", str)]


def get_DOY(
    date,
):
    """
    Get the day of the year from a date.

    Extract the day of the year (1-366) from a given date and return it as an
    integer.

    Parameters
    ----------
    date : `np.datetime64`, `datetime.datetime` or `str`
        The date to extract the day of the year from. Strings must be in
        'YYYY-MM-DD' or 'YYYY-MM-DDTHH:MM:SS' format.

    Returns
    -------
    doy : `int`
        The day of the year of the date.

    Raises
    ------
    ValueError
        If a string is in neither supported format, or if a np.datetime64 is NaT.
    TypeError
        If `date` is not a np.datetime64, datetime or str.

    Examples
    --------
    >>> get_DOY('2019-12-20')
    354
    >>> get_DOY(np.datetime64('2020-01-01'))
    1
    """
    if isinstance(date, str):
        # Try each supported format in turn; strptime raises ValueError on mismatch
        for fmt in ('%Y-%m-%dT%H:%M:%S', '%Y-%m-%d'):
            try:
                doy = datetime.strptime(date, fmt).timetuple().tm_yday
                break
            except ValueError:
                continue
        else:
            raise ValueError(f"(get_doy) Invalid date format: {date}. Expected 'YYYY-MM-DD' or 'YYYY-MM-DDTHH:MM:SS'.")
    elif isinstance(date, datetime):
        doy = date.timetuple().tm_yday
    elif isinstance(date, np.datetime64):
        # NaT would otherwise surface as an AttributeError on None
        if np.isnat(date):
            raise ValueError("(get_doy) `date` must not be NaT.")
        # Cast to day precision so .astype(object) yields a datetime.date
        doy = date.astype('datetime64[D]').astype(object).timetuple().tm_yday
    else:
        raise TypeError(f"(get_doy) `date` must be a np.datetime64 or str. Got type: {type(date)}")
    return int(doy)


def increment_month(
    month,
    increment,
):
    """
    Increment the month by a given number of months.

    Increment the month by the given number of months, wrapping around if the
    increment goes beyond December (12).

    Parameters
    ----------
    month : `int` or `str`
        The month to increment (1 for January, 2 for February, ..., 12 for
        December). Numpy integer types are also accepted.
    increment : `int` or `str`
        The number of months to increment by. Numpy integer types are also
        accepted.

    Returns
    -------
    new_month : `int` or `str`
        The new month after incrementing. A `str` is returned when `month` was
        a `str`; otherwise the integer result.
    increment_year : `bool`
        True when `month + increment` goes past December (exceeds 12). This is
        only a flag: an increment spanning several years still gives True, and
        a negative increment that moves into the previous year gives False.

    Raises
    ------
    TypeError
        If `month` or `increment` is a non-integer string, or `increment` is
        not an integer.
    ValueError
        If `month` is not an integer between 1 and 12.

    Examples
    --------
    >>> increment_month(1, 2)
    (3, False)
    >>> increment_month(11, 3)
    (2, True)
    >>> increment_month('5', '7')
    ('12', False)
    """
    # Note return type
    return_type = type(month)

    # Ensure month is valid
    if isinstance(month, str):
        try:
            month = int(month)
        except ValueError:
            raise TypeError(f"(increment_month) `month` must be an integer. Got type: {type(month)}") from None
    if not isinstance(month, (int, np.integer)) or month < 1 or month > 12:
        raise ValueError(f"(increment_month) `month` must be an integer between 1 and 12. Got: {month}")

    # Ensure increment is an integer
    if isinstance(increment, str):
        try:
            increment = int(increment)
        except ValueError:
            raise TypeError(f"(increment_month) `increment` must be an integer. Got type: {type(increment)}") from None
    if not isinstance(increment, (int, np.integer)):
        raise TypeError(f"(increment_month) `increment` must be an integer. Got type: {type(increment)}")

    # Work with 0-based months so the modulo wraps cleanly into 1..12
    new_month = (month - 1 + increment) % 12 + 1
    increment_year = bool(month + increment > 12)

    # Return the new month in the same type as the input
    if return_type == str:
        return str(new_month), increment_year
    return new_month, increment_year


def get_YMD_from_date(
    this_date,
):
    """
    Get the year, month, and day from a date.

    Extract the year, month, and day from a given date and return them as
    integers. Any np.datetime64 precision is supported, including the
    nanosecond precision used by pandas/xarray.

    Parameters
    ----------
    this_date : `np.datetime64` or `str`
        The date to extract the year, month, and day from.

    Returns
    -------
    year : `int`
        The year of the date.
    month : `int`
        The month of the date.
    day : `int`
        The day of the date.

    Raises
    ------
    ValueError
        If a string cannot be parsed as a date, or if the date is NaT.
    TypeError
        If `this_date` is not a np.datetime64 or str.

    Examples
    --------
    >>> get_YMD_from_date('2019-12-20')
    (2019, 12, 20)
    >>> get_YMD_from_date(np.datetime64('2020-01-01'))
    (2020, 1, 1)
    """
    # Ensure that the input is a valid date type
    if isinstance(this_date, str):
        try:
            this_date = np.datetime64(this_date)
        except ValueError:
            raise ValueError(f"(get_YMD_from_date) `this_date` is an invalid date string: {this_date}. Must be in 'YYYY-MM-DD' format.") from None
    if not isinstance(this_date, np.datetime64):
        raise TypeError(f"(get_YMD_from_date) `this_date` must be a np.datetime64 or str. Got type: {type(this_date)}")
    if np.isnat(this_date):
        raise ValueError("(get_YMD_from_date) `this_date` must not be NaT.")

    # Cast to day precision first: for units finer than microseconds (e.g. ns),
    # .astype(object) returns a plain int rather than a date object.
    as_date = this_date.astype('datetime64[D]').astype(object)
    return as_date.year, as_date.month, as_date.day


def get_increment_info(
    increment,
):
    """
    Get the increment value and unit from a timedelta or a string.

    A `np.timedelta64` must have a unit of days, months, or years. A string
    must start with an integer immediately followed by 'D', 'M', or 'Y'
    (e.g. '20D'). Only the start of the string is matched.

    Parameters
    ----------
    increment : `np.timedelta64` or `str`
        The amount of time to add to a date. If a string, it should be in the
        format 'XD', 'XM', or 'XY' where X is an integer and D, M, or Y are the
        units for days, months, or years respectively.

    Returns
    -------
    value : `int`
        The numeric value of the increment.
    unit : `str`
        The unit of the increment ('D', 'M', or 'Y').

    Raises
    ------
    ValueError
        If the increment string is not in the expected format, or if the
        timedelta64 unit is not days, months, or years.
    TypeError
        If the increment is not a np.timedelta64 or str.

    Examples
    --------
    >>> get_increment_info('20D')
    (20, 'D')
    >>> get_increment_info(np.timedelta64(20, 'D'))
    (20, 'D')
    >>> get_increment_info('3M')
    (3, 'M')
    >>> get_increment_info(np.timedelta64(2, 'Y'))
    (2, 'Y')
    """
    if isinstance(increment, np.timedelta64):
        # Find which supported unit the timedelta is expressed in
        for unit in ('D', 'M', 'Y'):
            td_dtype = f'timedelta64[{unit}]'
            if increment.dtype == td_dtype:
                return increment.astype(td_dtype).astype(int), unit
        raise ValueError(f"(get_increment_info) Unsupported timedelta64 type for `increment.dtype`. Use days, months, or years. Got type: {increment.dtype}")

    if not isinstance(increment, str):
        raise TypeError(f"(get_increment_info) `increment` must be a np.timedelta64 or str. Got type: {type(increment)}")

    parsed = re.match(r'(\d+)([DMY])', increment)
    if not parsed:
        raise ValueError(f"(get_increment_info) Invalid `increment` format: {increment}. Use 'XD', 'XM', or 'XY' where X is an integer and D, M, or Y are the units for days, months, or years respectively.")
    amount, unit = parsed.groups()
    return int(amount), unit


def add_amount_to_date(
    this_date,
    increment,
    keep_within_year=False,
):
    """
    Add an amount of time to a date.

    Add the given number of days, months, or years to the given date and
    return the new date. Month and year increments are calendar-aware:

    - the year rolls over as many times as needed (e.g. '24M' adds two years);
    - a day that does not exist in the target month is clamped to that month's
      last day (e.g. 31 January + '1M' -> 28/29 February).

    Parameters
    ----------
    this_date : `np.datetime64` or `str`
        The date to add the time to.
    increment : `np.timedelta64` or `str`
        The amount of time to add to the date. If a string, it should be in the
        format 'XD', 'XM', or 'XY' where X is an integer and D, M, or Y are the
        units for days, months, or years respectively. A np.timedelta64 must
        have a unit of days, months, or years.
    keep_within_year : `bool`, optional
        If True, a result that falls after the year of `this_date` is replaced
        by 31 December of that year. A result that falls before it is replaced
        by 1 January.

    Returns
    -------
    new_date : `np.datetime64` or `str`
        The new date after adding the time. It is a `str` if `this_date` was a
        `str`. Month and year increments give a day-precision date. Day
        increments keep the precision of `this_date`.

    Raises
    ------
    TypeError
        If `this_date` or `increment` has an unsupported type.
    ValueError
        If `increment` has an unsupported unit or string format.

    Examples
    --------
    >>> add_amount_to_date('2019-12-20', '20D')
    '2020-01-09'
    >>> add_amount_to_date(np.datetime64('2019-12-25'), np.timedelta64(20, 'D'))
    np.datetime64('2020-01-14')
    >>> add_amount_to_date('2019-01-31', '1M')
    '2019-02-28'
    >>> add_amount_to_date('2019-03-15', '24M')
    '2021-03-15'
    """
    import calendar

    # Make sure the inputs are of the correct type
    if not isinstance(this_date, (np.datetime64, str)):
        raise TypeError(f"(add_amount_to_date) `this_date` must be a np.datetime64 or str. Got type: {type(this_date)}")
    if not isinstance(increment, (np.timedelta64, str)):
        raise TypeError(f"(add_amount_to_date) `increment` must be a np.timedelta64 or str. Got type: {type(increment)}")

    # If the date is a string, convert it to a np.datetime64 (and remember to convert back)
    if isinstance(this_date, str):
        this_date = np.datetime64(this_date)
        return_type = str
    else:
        return_type = np.datetime64

    # Normalize the increment into an integer amount and a unit ('D', 'M' or 'Y')
    if isinstance(increment, np.timedelta64):
        unit = {
            'timedelta64[D]': 'D',
            'timedelta64[M]': 'M',
            'timedelta64[Y]': 'Y',
        }.get(str(increment.dtype))
        if unit is None:
            raise ValueError(f"(add_amount_to_date) Unsupported timedelta64 type for `increment.dtype`. Use days, months, or years. Got type: {increment.dtype}")
        value = int(increment.astype(int))
    else:
        match = re.match(r'(\d+)([DMY])', increment)
        if not match:
            raise ValueError(f"(add_amount_to_date) Invalid `increment` format: {increment}. Use 'XD', 'XM', or 'XY' where X is an integer and D, M, or Y are the units for days, months, or years respectively.")
        value, unit = int(match.group(1)), match.group(2)

    if unit == 'D':
        # Plain timedelta arithmetic keeps the precision of `this_date`
        new_date = this_date + np.timedelta64(value, 'D')
    else:
        # Day precision first so .astype(object) gives a date even for ns input
        start = this_date.astype('datetime64[D]').astype(object)
        months_to_add = value if unit == 'M' else 12 * value
        # Count months from year 0 so divmod handles any number of year
        # roll-overs, in either direction
        new_year, month_index = divmod(start.year * 12 + (start.month - 1) + months_to_add, 12)
        new_month = month_index + 1
        # Clamp e.g. 31 Jan + 1 month to the last day of February
        new_day = min(start.day, calendar.monthrange(new_year, new_month)[1])
        new_date = np.datetime64(f"{new_year}-{new_month:02d}-{new_day:02d}")

    # If keep_within_year is True, ensure the new date is within the same year
    if keep_within_year:
        original_year = this_date.astype('datetime64[D]').astype(object).year
        result_year = new_date.astype('datetime64[D]').astype(object).year
        if result_year > original_year:
            new_date = np.datetime64(f"{original_year}-12-31")
        elif result_year < original_year:
            new_date = np.datetime64(f"{original_year}-01-01")

    # If the input was a string, convert the date back to a string
    if return_type == str:
        new_date = str(new_date)

    return new_date