training
========

.. py:module:: training


Functions
---------

.. autoapisummary::

   training.begin_training
   training.make_predictions


Module Contents
---------------

.. py:function:: begin_training(savedir, stage, xtrain, ytrain, xvalid, yvalid, unet, batch_size=30, n_epochs=250, save_format='keras')

   Begin training the Unet model.

   :param savedir: Directory to save outputs.
   :type savedir: `str`
   :param stage: The stage number (1 or 2).
   :type stage: `int`
   :param xtrain: Training input features.
   :type xtrain: `np.ndarray`
   :param ytrain: Training target variables.
   :type ytrain: `np.ndarray`
   :param xvalid: Validation input features.
   :type xvalid: `np.ndarray`
   :param yvalid: Validation target variables.
   :type yvalid: `np.ndarray`
   :param unet: The Unet model to be trained.
   :type unet: `Unet`
   :param batch_size: Batch size for training.
   :type batch_size: `int`, optional
   :param n_epochs: Number of epochs for training.
   :type n_epochs: `int`, optional
   :param save_format: Format to save the model ('h5', 'keras', or 'both').
   :type save_format: `str`, optional

   :returns: **unet** -- The trained Unet model.
   :rtype: Unet


.. py:function:: make_predictions(uarr, unet, config_dict, config_path, output_metadata, stage=1)

   Prepare the input data for the model.

   Get the training data from the input NetCDF dataset as numpy arrays and
   concatenate them along the time dimension.

   :param uarr: The dataset of the input NetCDF file.
   :type uarr: `unox.uarray`
   :param unet: The Unet model to be trained.
   :type unet: `Unet`
   :param config_dict: A dictionary containing the configuration.
   :type config_dict: `dict`
   :param config_path: Path to the input configuration JSON file used to make `config_dict`.
   :type config_path: `str`
   :param output_metadata: The dictionary of metadata describing the output of a model run.
   :type output_metadata: `dict`
   :param stage: The stage of the data to plot (1 or 2).
   :type stage: `int`

   :returns: * **xtrain** (*np.ndarray*) -- Concatenated training input features.
             * **ytrain** (*np.ndarray*) -- Concatenated training target variables.
             * **output_metadata** (*dict*) -- The dictionary of metadata describing the output
               of a model run, with values added for `train_years` and `unet_build_shape`.
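``begin_training`` accepts three ``save_format`` values: ``'h5'``, ``'keras'``, or ``'both'``.
The block below is a minimal sketch, assuming Keras' usual ``.h5`` and ``.keras`` file
extensions, of how those options could translate into output files under ``savedir``. The
helper ``model_save_paths`` and the ``unet_stage{stage}`` file stem are hypothetical and are
not part of the ``training`` module; the real function may name or place its files differently.

.. code-block:: python

   import os

   # Hypothetical mapping from the documented save_format options to file extensions.
   _EXTENSIONS = {"h5": [".h5"], "keras": [".keras"], "both": [".h5", ".keras"]}


   def model_save_paths(savedir, stage, save_format="keras"):
       """Return the path(s) a trained model could be written to (illustrative only).

       The ``unet_stage{stage}`` file stem is an assumption, not the module's naming scheme.
       """
       # The documentation states that stage is 1 or 2.
       if stage not in (1, 2):
           raise ValueError(f"stage must be 1 or 2, got {stage!r}")
       try:
           extensions = _EXTENSIONS[save_format]
       except KeyError:
           raise ValueError(
               f"save_format must be 'h5', 'keras', or 'both', got {save_format!r}"
           ) from None
       stem = f"unet_stage{stage}"
       # One path per requested format, all inside savedir.
       return [os.path.join(savedir, stem + ext) for ext in extensions]


   paths = model_save_paths("runs/exp1", 1, "both")
   print(paths)

Rejecting unknown formats up front, rather than after training, means a typo in
``save_format`` surfaces before any epochs are spent.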
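The ``make_predictions`` description says the data are read from the NetCDF dataset as numpy
arrays and concatenated along the time dimension, and that ``train_years`` and
``unet_build_shape`` are added to ``output_metadata``. The sketch below illustrates only that
concatenation and metadata step with NumPy. It assumes time is the leading axis of every array
(layout ``(time, lat, lon, channel)``) and that ``unet_build_shape`` is the per-sample input
shape ``xtrain.shape[1:]``; the helper ``concat_along_time`` and the array shapes are
hypothetical, not the module's actual code.

.. code-block:: python

   import numpy as np


   def concat_along_time(x_chunks, y_chunks, years, output_metadata):
       """Concatenate per-period input/target arrays along the time axis (illustrative only).

       Assumes each array is laid out as (time, lat, lon, channel), so time is axis 0.
       """
       if len(x_chunks) != len(y_chunks):
           raise ValueError("x_chunks and y_chunks must have the same length")
       xtrain = np.concatenate(x_chunks, axis=0)
       ytrain = np.concatenate(y_chunks, axis=0)
       # Copy so the caller's dictionary is not mutated in place.
       metadata = dict(output_metadata)
       metadata["train_years"] = list(years)
       # Assumption: the build shape is the shape of one input sample.
       metadata["unet_build_shape"] = xtrain.shape[1:]
       return xtrain, ytrain, metadata


   # Two hypothetical years of daily data on a 16 x 16 grid.
   x_chunks = [np.zeros((365, 16, 16, 3)), np.zeros((366, 16, 16, 3))]
   y_chunks = [np.zeros((365, 16, 16, 1)), np.zeros((366, 16, 16, 1))]
   xtrain, ytrain, meta = concat_along_time(x_chunks, y_chunks, [2019, 2020], {})
   print(xtrain.shape, ytrain.shape, meta["unet_build_shape"])

Returning a copy of the metadata is a design choice for the sketch; the documented function
may instead add the new keys to the dictionary it was given.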