Source code for bfgn.experiments.models

import logging
import os
from typing import Union

import keras


# Module-level logger; the load/save helpers below emit their debug messages through it.
_logger = logging.getLogger(__name__)


# Default filename for a serialized model. Not referenced by the functions visible in this module;
# presumably consumed by callers when building a model filepath -- TODO confirm.
DEFAULT_FILENAME_MODEL = 'model.h5'


def load_model(filepath: str, custom_objects: dict = None) -> Union[keras.models.Model, None]:
    """Reads a serialized Keras model from disk, if one is present.

    Args:
        filepath: Location of the serialized model.
        custom_objects: Custom objects (e.g., loss functions) forwarded to keras.models.load_model.

    Returns:
        The deserialized Keras model, or None when nothing exists at filepath.
    """
    model_exists = os.path.exists(filepath)
    if model_exists:
        _logger.debug('Load model from {}'.format(filepath))
        return keras.models.load_model(filepath, custom_objects=custom_objects)
    # A missing file is not treated as an error; the caller receives None and decides how to proceed.
    _logger.debug('Model not loaded; file does not exist at {}'.format(filepath))
    return None
def save_model(model: keras.models.Model, filepath: str) -> None:
    """Saves model to serialized file, creating the parent directory when needed.

    Args:
        model: Keras model object.
        filepath: Filepath to which model is saved. Missing parent directories are created. A bare filename
            (no directory component) is passed to model.save unchanged and no directory is created.

    Returns:
        None.
    """
    directory = os.path.dirname(filepath)
    # os.path.dirname returns '' for a bare filename, and os.makedirs('') raises FileNotFoundError,
    # so directory creation is only attempted when there is a directory component.
    if directory and not os.path.exists(directory):
        _logger.debug('Create directory to save model at {}'.format(directory))
        # exist_ok avoids a failure if the directory is created by someone else between the check and this call.
        os.makedirs(directory, exist_ok=True)
    _logger.debug('Save model to {}'.format(filepath))
    model.save(filepath, overwrite=True)