Source code for bfgn.reporting.reports

import logging
import os
from typing import List

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm
from scipy import stats
import numpy as np

from bfgn.configuration import configs
from bfgn.data_management import data_core
from bfgn.experiments import experiments
from bfgn.reporting import samples
from bfgn.reporting.visualizations import histories, logs, model_performance, networks, samples as samples_viz


# Switch pyplot to the Agg backend as soon as this module is imported, before any figure is created.
plt.switch_backend('Agg')  # Needed for remote server plotting

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# File name of the PDF report; joined onto the model training output directory when the report is written.
_FILENAME_MODEL_REPORT = 'model_report.pdf'

# Labels returned by Reporter._get_response_data_types and used to choose between the
# classification and regression plotting functions.
_LABEL_CATEGORICAL = 'CATEGORICAL'
_LABEL_CONTINUOUS = 'CONTINUOUS'


class Reporter(object):
    """Build a multi-page PDF model report and expose the individual report plots.

    Every public ``plot_*`` method returns a list of matplotlib figures, most of them produced by the
    matching function in ``bfgn.reporting.visualizations``. ``create_model_report`` collects those
    figures into a single PDF written to the model training output directory.
    """

    # Class-level placeholders; both are replaced per instance in __init__.
    data_container = None
    experiment = None

    def __init__(
            self,
            data_container: data_core.DataContainer,
            experiment: experiments.Experiment,
            config: configs.Config
    ) -> None:
        """Validate the reporting configuration and store the objects the report is built from.

        Args:
            data_container: Provides the training/validation sequences and the band types.
            experiment: Provides the model, its training history, and the config read while plotting.
            config: Config whose ``model_reporting`` section is validated here.

        Raises:
            AssertionError: If the ``model_reporting`` section reports any configuration errors.
        """
        errors = config.get_human_readable_config_errors(include_sections=['model_reporting'])
        # Raised explicitly (not via ``assert``) so the validation still runs under ``python -O``.
        if errors:
            raise AssertionError(errors)
        self.data_container = data_container
        self.experiment = experiment

    def create_model_report(self) -> None:
        """Write the complete model report PDF into the model training output directory.

        Pages are added in this order: model summary, training sequence figures, validation sequence
        figures, regression parity plots (only when an 'R' response band exists), model history, and
        logged warnings and errors.
        """
        filepath_report = os.path.join(self.experiment.config.model_training.dir_out, _FILENAME_MODEL_REPORT)
        with PdfPages(filepath_report) as pdf:
            _logger.info('Plot Summary')
            self._add_figures(self.plot_model_summary(), pdf)

            _logger.info('Plot Training Sequence Figures')
            training_sampled = self._sample_sequence(self.data_container.training_sequence, 'Training')
            self._create_model_report_for_sequence(training_sampled, pdf)

            _logger.info('Plot Validation Sequence Figures')
            validation_sampled = self._sample_sequence(self.data_container.validation_sequence, 'Validation')
            self._create_model_report_for_sequence(validation_sampled, pdf)

            if 'R' in self.data_container.response_band_types:
                self._add_figures(self.plot_regression_deviation(training_sampled, validation_sampled), pdf)

            _logger.info('Plot Model History')
            self._add_figures(self.plot_model_history(), pdf)
            self._add_figures(self.plot_log_warnings_and_errors(), pdf)

    def _sample_sequence(self, data_sequence, data_sequence_label: str) -> samples.Samples:
        """Wrap one data sequence in a ``samples.Samples`` object using the experiment's model and config."""
        return samples.Samples(
            data_sequence,
            self.experiment.model,
            self.experiment.config,
            self.experiment.is_model_trained,
            self.data_container.feature_band_types,
            self.data_container.response_band_types,
            data_sequence_label=data_sequence_label
        )

    def _create_model_report_for_sequence(self, sampled: samples.Samples, pdf: PdfPages) -> None:
        """Add every per-sequence page (training or validation) for ``sampled`` to ``pdf``.

        The classification report and confusion matrix are only added for a trained model with
        categorical responses; the network progression pages are controlled by the
        ``model_reporting`` config flags.
        """
        if self.experiment.is_model_trained and self._get_response_data_types() == _LABEL_CATEGORICAL:
            self._add_figures(self.plot_classification_report(sampled), pdf)
            # The confusion matrix is the one page saved without the tight bounding box.
            self._add_figures(self.plot_confusion_matrix(sampled), pdf, tight=False)
        self._add_figures(self.plot_sample_histograms(sampled), pdf)
        self._add_figures(self.plot_samples(sampled), pdf)
        reporting_config = self.experiment.config.model_reporting
        if reporting_config.network_progression_show_full:
            self._add_figures(self.plot_network_feature_progression(sampled, compact=False), pdf)
        if reporting_config.network_progression_show_compact:
            self._add_figures(self.plot_network_feature_progression(sampled, compact=True), pdf)
        self._add_figures(self.plot_spatial_error(sampled), pdf)

    def _add_figures(self, figures: List[plt.Figure], pdf: PdfPages, tight: bool = True) -> None:
        """Save each figure as a page of ``pdf`` and then close it.

        Args:
            figures: Figures to append, in page order.
            pdf: Open PDF document that receives the pages.
            tight: When True, pages are saved with ``bbox_inches='tight'``; otherwise with ``None``.
        """
        bbox_inches = 'tight' if tight else None
        for fig in figures:
            pdf.savefig(fig, bbox_inches=bbox_inches)
            # pyplot keeps a reference to every figure it created until it is closed; release each
            # page once it has been written so long reports do not accumulate open figures.
            plt.close(fig)

    def plot_confusion_matrix(self, sampled: samples.Samples) -> List[plt.Figure]:
        """Return the figures from ``model_performance.plot_confusion_matrix`` for ``sampled``."""
        return model_performance.plot_confusion_matrix(sampled)

    def plot_model_summary(self) -> List[plt.Figure]:
        """Return the figures from ``networks.plot_model_summary`` for the experiment's model."""
        return networks.plot_model_summary(self.experiment.model)

    def plot_model_history(self) -> List[plt.Figure]:
        """Return the figures from ``histories.plot_history`` for the experiment's history."""
        return histories.plot_history(self.experiment.history)

    def plot_log_warnings_and_errors(self) -> List[plt.Figure]:
        """Return the figures from ``logs.plot_log_warnings_and_errors`` for both stored configs."""
        return logs.plot_log_warnings_and_errors(self.data_container.config, self.experiment.config)

    def plot_network_feature_progression(
            self,
            sampled: samples.Samples,
            compact: bool,
            max_pages: int = None,
            max_filters: int = None
    ) -> List[plt.Figure]:
        """Return the figures from ``networks.plot_network_feature_progression``.

        Args:
            sampled: Samples to plot.
            compact: Passed through unchanged to the plotting function.
            max_pages: Page limit; when None (or otherwise falsy) the ``model_reporting`` config value
                ``network_progression_max_pages`` is used.
            max_filters: Filter limit; when None (or otherwise falsy) the ``model_reporting`` config
                value ``network_progression_max_filters`` is used.
        """
        reporting_config = self.experiment.config.model_reporting
        return networks.plot_network_feature_progression(
            sampled,
            compact=compact,
            max_pages=max_pages or reporting_config.network_progression_max_pages,
            max_filters=max_filters or reporting_config.network_progression_max_filters
        )

    def plot_samples(
            self,
            sampled: samples.Samples,
            max_pages: int = None,
            max_samples_per_page: int = None,
            max_features_per_page: int = None,
            max_responses_per_page: int = None
    ) -> List[plt.Figure]:
        """Return sample figures from the classification or regression sample plotter.

        The plotter is chosen from the response data types. Every limit that is None (or otherwise
        falsy) falls back to the corresponding ``model_reporting`` config value.

        Raises:
            AssertionError: If the configured response data types are mixed or unexpected.
        """
        if self._get_response_data_types() == _LABEL_CATEGORICAL:
            plotter = samples_viz.plot_classification_samples
        else:
            # _get_response_data_types returns only the two labels and raises for anything else.
            plotter = samples_viz.plot_regression_samples
        reporting_config = self.experiment.config.model_reporting
        return plotter(
            sampled,
            max_pages=max_pages or reporting_config.max_pages_per_figure,
            max_samples_per_page=max_samples_per_page or reporting_config.max_samples_per_page,
            max_features_per_page=max_features_per_page or reporting_config.max_features_per_page,
            max_responses_per_page=max_responses_per_page or reporting_config.max_responses_per_page
        )

    def plot_sample_histograms(
            self,
            sampled: samples.Samples,
            max_responses_per_page: int = None
    ) -> List[plt.Figure]:
        """Return the figures from ``samples_viz.plot_sample_histograms``.

        ``max_responses_per_page`` falls back to the ``model_reporting`` config value when it is None
        (or otherwise falsy).
        """
        max_responses_per_page = max_responses_per_page or self.experiment.config.model_reporting.max_responses_per_page
        return samples_viz.plot_sample_histograms(sampled, max_responses_per_page=max_responses_per_page)

    def plot_regression_deviation(
            self,
            train_sampled: samples.Samples,
            val_sampled: samples.Samples = None
    ) -> List[plt.Figure]:
        """Return one parity-plot figure per response band of type 'R'.

        Each figure has a 'Training' panel and, when ``val_sampled`` is given, a 'Validation' panel
        drawn with the same axis bounds as the training panel. Response bands of any other type
        produce no figure.

        Args:
            train_sampled: Samples whose ``raw_predictions`` / ``raw_responses`` fill the first panel.
            val_sampled: Optional samples for the second panel.
        """
        has_validation = val_sampled is not None
        num_panels = 2 if has_validation else 1
        figures = []
        for idx_response, band_type in enumerate(self.data_container.response_band_types):
            if band_type != 'R':
                # Skip before creating a figure so no empty figure is left open in pyplot.
                continue
            fig = plt.figure(figsize=(5 * num_panels, 4.5))
            grid = gridspec.GridSpec(1, num_panels)

            ax_training = plt.subplot(grid[0, 0])
            bounds = self.parity_plot(
                train_sampled.raw_predictions, train_sampled.raw_responses, ax_training, idx_response)
            ax_training.set_title('Training')

            if has_validation:
                ax_validation = plt.subplot(grid[0, 1])
                # Reuse the training bounds so both panels share the same axes range.
                self.parity_plot(
                    val_sampled.raw_predictions, val_sampled.raw_responses, ax_validation, idx_response,
                    bounds=bounds
                )
                ax_validation.set_title('Validation')

            fig.suptitle('Response ' + str(idx_response))
            figures.append(fig)
        return figures

    @staticmethod
    def _crop_to_loss_window(values: np.array, buffer: int, response_ind: int) -> np.array:
        """Select one response from the last axis, trim ``buffer`` cells off axes 1 and 2, and flatten.

        With ``buffer == 0`` nothing is trimmed (a ``buffer:-buffer`` slice would be empty).
        """
        if buffer == 0:
            return values[..., response_ind].flatten()
        return values[:, buffer:-buffer, buffer:-buffer, response_ind].flatten()

    @staticmethod
    def _parity_metrics(pred_Y: np.array, test_Y: np.array) -> tuple:
        """Return ``(mad, rmse, r2)`` for predictions against observed values.

        ``r2`` is 1 - SS_res / SS_tot with SS_tot taken around the mean of the observed values
        (``test_Y``), i.e. the coefficient of determination.
        """
        residuals = pred_Y - test_Y
        squared_error = np.sum(np.power(residuals, 2))
        mad = np.mean(np.abs(residuals))
        rmse = np.sqrt(np.mean(np.power(residuals, 2)))
        total_variation = np.sum(np.power(test_Y - np.mean(test_Y), 2))
        r2 = 1 - squared_error / total_variation
        return mad, rmse, r2

    def parity_plot(
            self,
            pred_Y: np.array,
            test_Y: np.array,
            ax: plt.Axes,
            response_ind: int,
            bins: int = 200,
            bounds=None
    ) -> list:
        """Draw a log-scaled 2-D histogram of predicted vs. actual values on ``ax``.

        The arrays are first reduced to a single response and cropped by
        ``window_radius - loss_window_radius`` (from the ``data_build`` config) on axes 1 and 2.
        MAD, RMSE and R^2 are written in the upper-left corner and a dashed 1:1 line is drawn.

        Args:
            pred_Y: Predicted values; the last axis is indexed by ``response_ind``.
            test_Y: Actual values, same layout as ``pred_Y``.
            ax: Axes to draw on.
            response_ind: Index of the response to plot.
            bins: Number of histogram bins along each axis.
            bounds: Optional ``[min, max]`` used for both axes; when None or empty, the min and max of
                the cropped ``test_Y`` are used.

        Returns:
            ``[pmin, pmax]``, the bounds used for both axes.
        """
        data_build = self.experiment.config.data_build
        buffer = int(data_build.window_radius - data_build.loss_window_radius)
        test_Y = self._crop_to_loss_window(test_Y, buffer, response_ind)
        pred_Y = self._crop_to_loss_window(pred_Y, buffer, response_ind)

        mad, rmse, r2o = self._parity_metrics(pred_Y, test_Y)

        if bounds is None or len(bounds) == 0:
            pmin = np.min(test_Y)
            pmax = np.max(test_Y)
        else:
            pmin = bounds[0]
            pmax = bounds[1]

        counts, _, _ = np.histogram2d(test_Y, pred_Y, bins=bins, range=[[pmin, pmax], [pmin, pmax]])

        ax.patch.set_facecolor('white')
        # Empty bins give log(0) = -inf; the divide-by-zero warning is expected, so silence it.
        with np.errstate(divide='ignore'):
            log_counts = np.log(counts.T)
        ax.imshow(log_counts, extent=(pmin, pmax, pmin, pmax), cmap=cm.hot, origin='lower', interpolation='nearest')
        ax.plot([pmin, pmax], [pmin, pmax], color='blue', lw=2, ls='--')

        fs = 8
        annotations = (
            (0.95, 'MAD: ', mad),
            (0.90, 'RMSE: ', rmse),
            # Raw string: the label contains a backslash that must reach matplotlib's mathtext as-is.
            (0.85, r'R${^2}$${_\mathrm{o}}$: ', r2o),
        )
        for y_fraction, label, value in annotations:
            ax.text(
                pmin+(pmax-pmin)*0.05, pmin+(pmax-pmin)*y_fraction, label + str(round(value, 3)), fontsize=fs)

        ax.set_xlabel('Actual')
        ax.set_ylabel('Predicted')
        ax.set_xlim([pmin, pmax])
        ax.set_ylim([pmin, pmax])
        return [pmin, pmax]

    def plot_spatial_error(
            self,
            sampled: samples.Samples,
            max_pages: int = None,
            max_responses_per_page: int = None
    ) -> List[plt.Figure]:
        """Return spatial error figures from the classification or regression error plotter.

        The plotter is chosen from the response data types. Limits that are None (or otherwise falsy)
        fall back to the corresponding ``model_reporting`` config value.

        Raises:
            AssertionError: If the configured response data types are mixed or unexpected.
        """
        if self._get_response_data_types() == _LABEL_CATEGORICAL:
            plotter = model_performance.plot_spatial_classification_error
        else:
            # _get_response_data_types returns only the two labels and raises for anything else.
            plotter = model_performance.plot_spatial_regression_error
        reporting_config = self.experiment.config.model_reporting
        return plotter(
            sampled,
            max_pages=max_pages or reporting_config.max_pages_per_figure,
            max_responses_per_page=max_responses_per_page or reporting_config.max_responses_per_page
        )

    def plot_classification_report(self, sampled: samples.Samples) -> List[plt.Figure]:
        """Return the figures from ``model_performance.plot_classification_report`` for ``sampled``."""
        return model_performance.plot_classification_report(sampled)

    def _get_response_data_types(self) -> str:
        """Return the label for the response data types configured in ``raw_files.response_data_type``.

        Returns:
            ``_LABEL_CATEGORICAL`` when every entry is 'C', ``_LABEL_CONTINUOUS`` when every entry is 'R'.

        Raises:
            AssertionError: If both 'C' and 'R' are present, or any other set of values is found.
        """
        data_types = {dt for file_dts in self.experiment.config.raw_files.response_data_type for dt in file_dts}
        if data_types == {'C'}:
            return _LABEL_CATEGORICAL
        if data_types == {'R'}:
            return _LABEL_CONTINUOUS
        if data_types == {'C', 'R'}:
            raise AssertionError('Reporter does not currently support mixed response data types.')
        raise AssertionError('Unexpected data types found: {}.'.format(data_types))