Source code for bfgn.reporting.visualizations.networks

import logging
from typing import List

import keras
import matplotlib.pyplot as plt
import numpy as np

from bfgn.reporting import samples


plt.switch_backend('Agg')  # Needed for remote server plotting

_logger = logging.getLogger(__name__)


def plot_model_summary(model: keras.Model) -> List[plt.Figure]:
    stringlist = ['CNN Architecture Summary']
    model.summary(print_fn=lambda x: stringlist.append(x))
    model_summary_string = "\n".join(stringlist)
    fig, axes = plt.subplots(figsize=(8.5, 11), nrows=1, ncols=1)
    plt.text(0, 0, model_summary_string, **{'fontsize': 8, 'fontfamily': 'monospace'})
    plt.axis('off')
    return [fig]
def plot_network_feature_progression(
        sampled: samples.Samples, compact: bool = False, max_pages: int = 10, max_filters: int = 10
) -> List[plt.Figure]:
    if not sampled.is_model_trained:
        _logger.debug('Network feature progression not plotted; model is not trained')
        return list()
    return [_plot_sample_feature_progression(sampled, idx_sample, compact, max_filters)
            for idx_sample in range(min(max_pages, sampled.num_samples))]
def _plot_sample_feature_progression(
        sampled: samples.Samples, idx_sample: int, compact: bool, max_filters: int
) -> plt.Figure:
    sample_features = np.expand_dims(sampled.trans_features[idx_sample, :], 0)
    sample_responses = np.expand_dims(sampled.trans_responses[idx_sample, :], 0)

    # Run through the model and grab any Conv2D layer (other layers could also be grabbed as desired)
    pred_set = []
    layer_names = []
    pred_set.append(sample_features)
    layer_names.append('Feature(s)')
    for _l in range(0, len(sampled.model.layers)):
        if isinstance(sampled.model.layers[_l], keras.layers.convolutional.Conv2D):
            im_model = keras.models.Model(
                inputs=sampled.model.layers[0].output, outputs=sampled.model.layers[_l].output)
            pred_set.append(im_model.predict(sample_features))
            layer_names.append(sampled.model.layers[_l].name)
    pred_set.append(sample_responses)
    layer_names.append('Response(s)')

    # Calculate the per-filter standard deviation, which lets the plots preferentially show the more
    # interesting filters within each layer
    pred_std = []
    for _l in range(len(pred_set)):
        pred_std.append([np.std(np.squeeze(pred_set[_l][..., x])) for x in range(0, pred_set[_l].shape[-1])])

    # Work out the spacing and initialize the figure
    step_size = 1 / float(len(pred_set) + 1)
    if compact:
        h_space_fraction = 0.3
    else:
        h_space_fraction = 0.05
    image_size = min(step_size * (1 - h_space_fraction), 1 / max_filters * (1 - h_space_fraction))
    h_space_size = step_size * h_space_fraction

    fig = plt.figure(figsize=(max(max_filters, len(pred_set)), max(max_filters, len(pred_set))))
    top = 0

    # Step through each layer in the network
    for _l in range(0, len(pred_set)):
        # Step through each filter, up to the max
        for _iii in range(0, min(pred_set[_l].shape[-1], max_filters)):
            if compact:
                ip = [(_l + 0.5) * step_size + _iii * h_space_size / 5., _iii * image_size * 0.2]
            else:
                ip = [(_l + 0.5) * step_size, _iii * image_size * (1 + h_space_fraction)]

            # Get the indices sorted by filter std, as a proxy for interest
            ordered_pred_std = np.argsort(pred_std[_l])[::-1]

            # Prep the image
            tp = np.squeeze(pred_set[_l][0, :, :, ordered_pred_std[_iii]])

            # Plot!
            ax = fig.add_axes([ip[0], ip[1], image_size, image_size], zorder=max_filters + 1 - _iii)
            top = max(top, ip[1] + image_size)
            plt.imshow(tp, vmin=np.nanpercentile(tp, 0), vmax=np.nanpercentile(tp, 100))
            _adjust_axis(ax)

            if _iii == 0:
                plt.xlabel(layer_names[_l])

    tit = 'Network Feature Progression Visualization ({} sequence {})'.format(
        sampled.data_sequence_label, idx_sample)
    if compact:
        tit = 'Compact ' + tit
    ax = fig.add_axes([0.5, top + image_size / 2., 0.01, 0.01], zorder=100)
    ax.axis('off')
    ax.text(0, 0, tit, ha='center', va='center')

    return fig


def _adjust_axis(lax):
    for sp in lax.spines:
        lax.spines[sp].set_color('white')
        lax.spines[sp].set_linewidth(2)
    lax.set_xticks([])
    lax.set_yticks([])
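
A minimal usage sketch, assuming a trained keras.Model and a samples.Samples instance (named model and sampled here) are already available; the helper name _example_write_report, the output path, and the use of matplotlib's PdfPages writer are illustrative choices, not part of this module.

# Hypothetical usage sketch: collect the figures produced by this module and write them to a PDF.
from matplotlib.backends.backend_pdf import PdfPages


def _example_write_report(model, sampled, out_path='network_report.pdf'):
    # `model` and `sampled` are assumed to exist already; the output filename is arbitrary.
    figures = []
    figures.extend(plot_model_summary(model))
    figures.extend(plot_network_feature_progression(sampled, compact=True, max_pages=5))
    with PdfPages(out_path) as pdf:
        for fig in figures:
            pdf.savefig(fig)
            plt.close(fig)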