# Source code for bfgn.experiments.callbacks

import datetime
import logging
import os
from typing import List

import keras

from bfgn.configuration import configs
from bfgn.experiments import experiments, histories


_logger = logging.getLogger(__name__)


_DIR_TENSORBOARD = 'tensorboard'


class HistoryCheckpoint(keras.callbacks.Callback):
    """A custom Keras callback for checkpointing model training history and associated information.

    Besides the metrics Keras records for the model, this callback records wall-clock timestamps for the start and
    end of training and of every epoch in ``existing_history``. Every ``period`` epochs, and once more when training
    ends, it combines ``existing_history`` with the model's current Keras history and writes the result to the
    experiment's history file.
    """

    config = None  # bfgn config; used to locate the history file on disk
    existing_history = None  # history dict, extended in place with timing information
    period = None  # number of epochs between history checkpoints
    verbose = None  # stored verbosity setting; not read by this class
    epochs_since_last_save = None  # epochs completed since the last checkpoint
    epoch_begin = None  # start time of the epoch in progress, or None between epochs

    def __init__(self, config: configs.Config, existing_history=None, period=1, verbose=0):
        """
        Args:
            config: bfgn config.
            existing_history: History from earlier training, or None to start with an empty dict. The dict is
                mutated in place.
            period: Save the history every ``period`` epochs.
            verbose: Verbosity setting; stored but not used by this callback.
        """
        super().__init__()
        self.config = config
        if existing_history is None:
            existing_history = dict()
        self.existing_history = existing_history
        self.period = period
        self.verbose = verbose
        self.epochs_since_last_save = 0
        self.epoch_begin = None

    def on_train_begin(self, logs=None):
        """Record the training start time and ensure the per-epoch timing lists exist."""
        _logger.debug('Beginning network training')
        for key in ('epoch_start', 'epoch_finish'):
            # setdefault keeps timings already present in a history passed in from earlier training
            self.existing_history.setdefault(key, list())
        self.existing_history['train_start'] = datetime.datetime.now()
        super().on_train_begin(logs)

    def on_train_end(self, logs=None):
        """Record the training finish time and save the full history."""
        _logger.debug('Ending network training')
        self.existing_history['train_finish'] = datetime.datetime.now()
        self._save_history()

    def on_epoch_begin(self, epoch, logs=None):
        """Remember when the current epoch started."""
        _logger.debug('Beginning new epoch')
        self.epoch_begin = datetime.datetime.now()

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch's start/finish times and checkpoint the history every ``period`` epochs."""
        _logger.debug('Ending epoch')
        # Update times
        epoch_end = datetime.datetime.now()
        self.existing_history['epoch_start'].append(self.epoch_begin)
        self.existing_history['epoch_finish'].append(epoch_end)
        self.epoch_begin = None
        # Save if necessary
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            # NOTE(review): if the model's History callback runs after this callback in the same epoch, this
            # intermediate save may not yet contain the current epoch's metrics; the save in on_train_end is not
            # affected. Confirm against the Keras version in use.
            _logger.debug('Checkpointing')
            self._save_history()
            self.epochs_since_last_save = 0

    def _save_history(self):
        """Combine ``existing_history`` with the model's Keras history and write the result to disk."""
        _logger.debug('Save model history')
        new_history = self._get_new_history()
        combined_history = histories.combine_histories(self.existing_history, new_history)
        histories.save_history(combined_history, experiments.get_history_filepath(self.config))

    def _get_new_history(self) -> dict:
        """Return the metrics dict (``History.history``) recorded for the model being trained.

        Raises:
            AttributeError: If no ``history`` attribute can be found on the model or on its wrapped ``.model``.
        """
        if hasattr(self.model, 'history'):
            return self.model.history.history
        if hasattr(self.model, 'model'):
            # Parallel model wrappers keep the history on the wrapped model
            if not hasattr(self.model.model, 'history'):
                raise AttributeError(
                    'Parallel models are doing something unusual with histories. Tell Nick and let\'s debug.')
            # Read the metrics dict, not the History callback object, matching the branch above
            return self.model.model.history.history
        raise AttributeError('Unable to find a Keras history on the model; cannot save training history.')
def get_model_callbacks(config: configs.Config, existing_history: dict) -> List[keras.callbacks.Callback]:
    """Build the Keras callbacks described by a bfgn config.

    A HistoryCheckpoint and a ModelCheckpoint are always included, in that order. Early stopping, learning-rate
    reduction on plateau, TensorBoard logging and terminate-on-NaN follow, in that order, each only when its config
    flag is enabled.

    Args:
        config: bfgn config.
        existing_history: Existing model training history if the model has already been partially or completely
            trained.

    Returns:
        List of model callbacks.
    """
    general_cfg = config.callback_general
    checkpoint_period = general_cfg.checkpoint_periods
    verbosity = config.model_training.verbosity

    history_checkpoint = HistoryCheckpoint(
        config=config,
        existing_history=existing_history,
        period=checkpoint_period,
        verbose=verbosity,
    )
    model_checkpoint = keras.callbacks.ModelCheckpoint(
        experiments.get_model_filepath(config),
        period=checkpoint_period,
        verbose=verbosity,
    )
    model_callbacks = [history_checkpoint, model_checkpoint]

    early_stopping_cfg = config.callback_early_stopping
    if early_stopping_cfg.use_callback:
        model_callbacks.append(keras.callbacks.EarlyStopping(
            monitor=early_stopping_cfg.loss_metric,
            min_delta=early_stopping_cfg.min_delta,
            patience=early_stopping_cfg.patience,
            restore_best_weights=True,
        ))

    reduce_lr_cfg = config.callback_reduced_learning_rate
    if reduce_lr_cfg.use_callback:
        model_callbacks.append(keras.callbacks.ReduceLROnPlateau(
            monitor=reduce_lr_cfg.loss_metric,
            factor=reduce_lr_cfg.factor,
            min_delta=reduce_lr_cfg.min_delta,
            patience=reduce_lr_cfg.patience,
        ))

    tensorboard_cfg = config.callback_tensorboard
    if tensorboard_cfg.use_callback:
        tensorboard_dir = os.path.join(config.model_training.dir_out, _DIR_TENSORBOARD)
        model_callbacks.append(keras.callbacks.TensorBoard(
            tensorboard_dir,
            histogram_freq=tensorboard_cfg.histogram_freq,
            write_graph=tensorboard_cfg.write_graph,
            write_grads=tensorboard_cfg.write_grads,
            write_images=tensorboard_cfg.write_images,
            update_freq=tensorboard_cfg.update_freq,
        ))

    if general_cfg.use_terminate_on_nan:
        model_callbacks.append(keras.callbacks.TerminateOnNaN())

    return model_callbacks