import logging
from typing import Union
from matplotlib import patches
import matplotlib.pyplot as plt
import matplotlib.cm
import numpy as np
from bfgn.reporting import samples
from bfgn.reporting.visualizations import colormaps
# Switch the pyplot backend to Agg at import time; needed for remote server plotting
plt.switch_backend('Agg')

# Module-level logger, named after this module
_logger = logging.getLogger(__name__)

# Pseudo attribute name handled specially by _plot_sample_attribute: the plotted values are derived from
# raw_responses (argmax over the last axis) rather than read from an attribute with this name
_LABEL_CATEGORICAL_RESPONSES = 'categorical_responses'

# Number of decimal places kept by _format_number when rounding floats for axis labels
_FLOAT_DECIMALS = 2
"""
Note: some reports need to know how many plots will be generated, so that the number of columns in a figure can be
specified. We want to keep the actual plotting code synced to the code that calculates the number of columns to avoid
bugs. However, we also do not want to write a ton of boilerplate code to check whether each plot should be created.
The repeated `if ax is None: return` check at the top of each plotting function is a quick and easy way to avoid
plotting when a caller is simply trying to get the number of columns, while keeping that logic in a single place, even
though it feels like the check is in the wrong place. This also makes the code a bit more readable, although that is
subjective. Better solutions are welcome.
"""
def plot_raw_features(
        sampled: samples.Samples,
        idx_sample: int,
        idx_feature: int,
        ax: plt.Axes,
        add_xlabel: bool,
        add_ylabel: bool
) -> None:
    """Plot a single raw feature channel of a single sample.

    Thin wrapper that delegates to ``_plot_sample_attribute`` with the ``'raw_features'`` attribute.

    Args:
        sampled: Samples object providing ``raw_features`` and ``raw_features_range``.
        idx_sample: Index of the sample (first array axis) to plot.
        idx_feature: Index of the feature channel (last array axis) to plot.
        ax: Axes to draw on; nothing is drawn when this is None.
        add_xlabel: Whether to label the subplot with the attribute name, channel index, and value range.
        add_ylabel: Whether to label the subplot with the sample index.
    """
    _plot_sample_attribute(
        sampled=sampled,
        idx_sample=idx_sample,
        idx_axis=idx_feature,
        attribute_name='raw_features',
        ax=ax,
        add_xlabel=add_xlabel,
        add_ylabel=add_ylabel,
    )
def plot_raw_responses(
        sampled: samples.Samples,
        idx_sample: int,
        idx_response: int,
        ax: plt.Axes,
        add_xlabel: bool,
        add_ylabel: bool
) -> None:
    """Plot a single raw response channel of a single sample.

    Thin wrapper that delegates to ``_plot_sample_attribute`` with the ``'raw_responses'`` attribute.

    Args:
        sampled: Samples object providing ``raw_responses`` and ``raw_responses_range``.
        idx_sample: Index of the sample (first array axis) to plot.
        idx_response: Index of the response channel (last array axis) to plot.
        ax: Axes to draw on; nothing is drawn when this is None.
        add_xlabel: Whether to label the subplot with the attribute name, channel index, and value range.
        add_ylabel: Whether to label the subplot with the sample index.
    """
    _plot_sample_attribute(
        sampled=sampled,
        idx_sample=idx_sample,
        idx_axis=idx_response,
        attribute_name='raw_responses',
        ax=ax,
        add_xlabel=add_xlabel,
        add_ylabel=add_ylabel,
    )
def plot_categorical_responses(
        sampled: samples.Samples,
        idx_sample: int,
        idx_response: int,
        ax: plt.Axes,
        add_xlabel: bool,
        add_ylabel: bool
) -> None:
    """Plot the categorical (argmax) view of the raw responses of a single sample.

    Thin wrapper that delegates to ``_plot_sample_attribute`` with the categorical-responses label. In that mode
    the plotted image is the argmax of ``raw_responses`` over the last axis, so ``idx_response`` does not select
    any data; it only appears in the x-label text.

    Args:
        sampled: Samples object providing ``raw_responses`` and ``num_responses``.
        idx_sample: Index of the sample (first array axis) to plot.
        idx_response: Index shown in the x-label.
        ax: Axes to draw on; nothing is drawn when this is None.
        add_xlabel: Whether to label the subplot with the attribute name, index, and value range.
        add_ylabel: Whether to label the subplot with the sample index.
    """
    _plot_sample_attribute(
        sampled=sampled,
        idx_sample=idx_sample,
        idx_axis=idx_response,
        attribute_name=_LABEL_CATEGORICAL_RESPONSES,
        ax=ax,
        add_xlabel=add_xlabel,
        add_ylabel=add_ylabel,
    )
def plot_raw_predictions(
        sampled: samples.Samples,
        idx_sample: int,
        idx_response: int,
        ax: plt.Axes,
        add_xlabel: bool,
        add_ylabel: bool
) -> None:
    """Plot a single raw prediction channel of a single sample and outline the loss window.

    Delegates the image plotting to ``_plot_sample_attribute`` with the ``'raw_predictions'`` attribute, then adds
    the internal (loss) window rectangle via ``add_internal_window_to_subplot``.

    Args:
        sampled: Samples object providing ``raw_predictions``, ``raw_predictions_range``, and ``config``.
        idx_sample: Index of the sample (first array axis) to plot.
        idx_response: Index of the prediction channel (last array axis) to plot.
        ax: Axes to draw on; nothing is drawn when this is None.
        add_xlabel: Whether to label the subplot with the attribute name, channel index, and value range.
        add_ylabel: Whether to label the subplot with the sample index.
    """
    _plot_sample_attribute(sampled, idx_sample, idx_response, 'raw_predictions', ax, add_xlabel, add_ylabel)
    # Only outline the loss window when an image was actually drawn. Without this guard a rectangle would be added
    # to an otherwise empty axes when raw_predictions is unavailable, unlike the error plots in this module, which
    # return before adding the window in that case.
    if ax is not None and sampled.raw_predictions is not None:
        add_internal_window_to_subplot(sampled, ax)
def _plot_sample_attribute(
        sampled: samples.Samples,
        idx_sample: int,
        idx_axis: int,
        attribute_name: str,
        ax: plt.Axes,
        add_xlabel: bool,
        add_ylabel: bool
) -> None:
    """Draw one channel of one sample from the named Samples attribute onto ``ax``.

    For regular attributes the image is ``attribute[idx_sample, :, :, idx_axis]`` and the color limits come from
    the matching ``<attribute_name>_range`` attribute. For the categorical-responses label the image is the argmax
    of ``raw_responses`` over the last axis, drawn with the categorical colormap and limits ``0`` to
    ``num_responses - 1``.

    Args:
        sampled: Samples object holding the attribute to plot.
        idx_sample: Index of the sample (first array axis) to plot.
        idx_axis: Index of the channel (last array axis) to plot; also shown in the x-label.
        attribute_name: Name of the Samples attribute, or the categorical-responses label.
        ax: Axes to draw on; the function returns immediately when this is None.
        add_xlabel: Whether to add a top x-label with the attribute name, index, and value range.
        add_ylabel: Whether to add a y-label with the sample index.
    """
    if ax is None:
        return

    is_categorical = attribute_name == _LABEL_CATEGORICAL_RESPONSES
    # Categorical responses are derived from raw_responses, so that is the attribute that must be present
    source_name = 'raw_responses' if is_categorical else attribute_name
    if getattr(sampled, source_name) is None:
        _logger.debug('Not plotting {}; attribute not available'.format(source_name))
        return

    if is_categorical:
        # Categorical responses get their own color scheme and a fixed class-index range
        image = np.argmax(sampled.raw_responses[idx_sample, ...], axis=-1)
        lower, upper = 0, sampled.num_responses - 1
        cmap = colormaps.COLORMAP_CATEGORICAL
        assert not colormaps.check_is_categorical_colormap_repeated(sampled.num_responses), \
            'Number of categorical responses is greater than length of colormap, figure out how to handle gracefully'
    else:
        image = getattr(sampled, attribute_name)[idx_sample, :, :, idx_axis]
        lower, upper = getattr(sampled, attribute_name + '_range')[idx_axis, :]
        cmap = colormaps.COLORMAP_DEFAULT

    # Transformed data marks missing values with a replacement value; turn those back into nan on a copy
    if attribute_name in ('trans_features', 'trans_responses', 'trans_predictions'):
        image = image.copy()
        image[image == sampled.data_sequence.nan_replacement_value] = np.nan

    ax.imshow(image, vmin=lower, vmax=upper, cmap=cmap)
    ax.set_xticks([])
    ax.set_yticks([])

    if add_xlabel:
        # One capitalized word per line, dropping 'raw' and any trailing 's' characters
        title_words = [word.capitalize() for word in attribute_name.split('_') if word != 'raw']
        title = '\n'.join(title_words).rstrip('s')
        label = '{} {}\n{}\n{}'.format(title, idx_axis, _format_number(lower), _format_number(upper))
        ax.set_xlabel(label)
        ax.xaxis.set_label_position('top')
    if add_ylabel:
        ax.set_ylabel('Sample\n{}'.format(idx_sample))
def plot_classification_predictions_max_likelihood(
        sampled: samples.Samples,
        idx_sample: int,
        ax: plt.Axes,
        add_xlabel: bool,
        add_ylabel: bool
) -> None:
    """Plot the most likely predicted class per pixel for a single sample.

    The image is the argmax of ``raw_predictions[idx_sample]`` over the last axis, drawn with the categorical
    colormap and color limits ``0`` to ``num_responses - 1``.

    Note: this assumes that the softmax applied to all prediction axes and that there was no transformation
    applied to the categorical data.

    Args:
        sampled: Samples object providing ``raw_predictions`` and ``num_responses``.
        idx_sample: Index of the sample (first array axis) to plot.
        ax: Axes to draw on; the function returns immediately when this is None.
        add_xlabel: Whether to add a top x-label with the plot name and class-index range.
        add_ylabel: Whether to add a y-label with the sample index.
    """
    if ax is None:
        return
    predictions = sampled.raw_predictions
    if predictions is None:
        _logger.debug('Not plotting raw predictions max likelihood; raw_predictions attribute not available')
        return

    lowest_class, highest_class = 0, sampled.num_responses - 1
    assert not colormaps.check_is_categorical_colormap_repeated(sampled.num_responses), \
        'Number of categorical responses is greater than length of colormap, figure out how to handle gracefully'

    predicted_classes = np.argmax(predictions[idx_sample, ...], axis=-1)
    ax.imshow(predicted_classes, vmin=lowest_class, vmax=highest_class, cmap=colormaps.COLORMAP_CATEGORICAL)
    ax.set_xticks([])
    ax.set_yticks([])

    if add_xlabel:
        label = 'Categorical\nPredictions MLE\n{}\n{}'.format(
            _format_number(lowest_class), _format_number(highest_class))
        ax.set_xlabel(label)
        ax.xaxis.set_label_position('top')
    if add_ylabel:
        ax.set_ylabel('Sample\n{}'.format(idx_sample))
def plot_weights(sampled: samples.Samples, idx_sample: int, ax: plt.Axes, add_xlabel: bool, add_ylabel: bool) -> None:
    """Plot the loss weights of a single sample, showing zero weights as white.

    Zero weights are replaced by nan on a copy of the data and the colormap's "bad" color is set to white, so
    zero-weight pixels are rendered white. Color limits come from ``sampled.weights_range[0, :]``.

    Args:
        sampled: Samples object providing ``weights`` and ``weights_range``.
        idx_sample: Index of the sample (first array axis) to plot.
        ax: Axes to draw on; the function returns immediately when this is None.
        add_xlabel: Whether to add a top x-label with the plot name and value range.
        add_ylabel: Whether to add a y-label with the sample index.
    """
    import copy

    if ax is None:
        return
    if sampled.weights is None:
        _logger.debug('Not plotting weights; weights attribute not available')
        return
    min_, max_ = sampled.weights_range[0, :]
    weights = sampled.weights[idx_sample, :].copy()
    weights[weights == 0] = np.nan
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the supported equivalent and accepts
    # either a colormap name or a Colormap instance. Copy before set_bad so that a shared/registered colormap object
    # is never modified for other plots.
    weight_cmap = copy.copy(plt.get_cmap(colormaps.COLORMAP_WEIGHTS))
    weight_cmap.set_bad('white')
    ax.imshow(weights, vmin=min_, vmax=max_, cmap=weight_cmap)
    ax.set_xticks([])
    ax.set_yticks([])
    if add_xlabel:
        ax.set_xlabel('Weights\n{}\n{}'.format(_format_number(min_), _format_number(max_)))
        ax.xaxis.set_label_position('top')
    if add_ylabel:
        ax.set_ylabel('Sample\n{}'.format(idx_sample))
def plot_binary_error_classification(
        sampled: samples.Samples,
        idx_sample: int,
        ax: plt.Axes,
        add_xlabel: bool,
        add_ylabel: bool
) -> None:
    """Plot, per pixel, whether the predicted class matches the actual class for a single sample.

    Both classes are taken as the argmax over the last axis of ``raw_predictions`` and ``raw_responses``. The
    resulting boolean image is drawn with the error colormap and the loss window is outlined afterwards.

    Note: this assumes that the softmax applied to all prediction axes and that there was no transformation
    applied to the categorical data.

    Args:
        sampled: Samples object providing ``raw_responses``, ``raw_predictions``, and ``config``.
        idx_sample: Index of the sample (first array axis) to plot.
        ax: Axes to draw on; the function returns immediately when this is None.
        add_xlabel: Whether to add a top x-label describing the plot.
        add_ylabel: Whether to add a y-label with the sample index.
    """
    if ax is None:
        return
    has_both_arrays = sampled.raw_responses is not None and sampled.raw_predictions is not None
    if not has_both_arrays:
        _logger.debug(
            'Not plotting classification errors; '
            'raw_responses and/or raw_predictions attributes are not available'
        )
        return

    true_classes = np.argmax(sampled.raw_responses[idx_sample, :], axis=-1)
    predicted_classes = np.argmax(sampled.raw_predictions[idx_sample, :], axis=-1)
    is_correct = predicted_classes == true_classes

    # The data only takes the values 0 and 1 (incorrect or correct), but the plots are too dark when vmin and vmax
    # are set to exactly 0 and 1, so the color limits are widened
    color_limits = (-0.5, 1.5)
    ax.imshow(is_correct, vmin=color_limits[0], vmax=color_limits[1], cmap=colormaps.COLORMAP_ERROR)
    ax.set_xticks([])
    ax.set_yticks([])

    if add_xlabel:
        ax.set_xlabel('Correct\nCategory\nFalse\nTrue')
        ax.xaxis.set_label_position('top')
    if add_ylabel:
        ax.set_ylabel('Sample\n{}'.format(idx_sample))
    add_internal_window_to_subplot(sampled, ax)
def plot_raw_error_regression(
        sampled: samples.Samples,
        idx_sample: int,
        idx_response: int,
        ax: plt.Axes,
        add_xlabel: bool,
        add_ylabel: bool
) -> None:
    """Plot the raw regression error (prediction minus response) for one response channel of one sample.

    Color limits are the nan-aware minimum and maximum of the error image itself. The loss window is outlined
    after the image is drawn.

    Args:
        sampled: Samples object providing ``raw_responses``, ``raw_predictions``, and ``config``.
        idx_sample: Index of the sample (first array axis) to plot.
        idx_response: Index of the response channel (last array axis) to compare.
        ax: Axes to draw on; the function returns immediately when this is None.
        add_xlabel: Whether to add a top x-label with the plot name and error range.
        add_ylabel: Whether to add a y-label with the sample index.
    """
    if ax is None:
        return
    has_both_arrays = sampled.raw_responses is not None and sampled.raw_predictions is not None
    if not has_both_arrays:
        _logger.debug(
            'Not plotting raw regression errors; '
            'raw_responses and/or raw_predictions attributes are not available'
        )
        return

    predicted = sampled.raw_predictions[idx_sample, :, :, idx_response]
    observed = sampled.raw_responses[idx_sample, :, :, idx_response]
    residuals = predicted - observed
    lowest = float(np.nanmin(residuals))
    highest = float(np.nanmax(residuals))

    ax.imshow(residuals, vmin=lowest, vmax=highest, cmap=colormaps.COLORMAP_ERROR)
    ax.set_xticks([])
    ax.set_yticks([])

    if add_xlabel:
        label = 'Raw\nRegression\nErrors\n{}\n{}'.format(_format_number(lowest), _format_number(highest))
        ax.set_xlabel(label)
        ax.xaxis.set_label_position('top')
    if add_ylabel:
        ax.set_ylabel('Sample\n{}'.format(idx_sample))
    add_internal_window_to_subplot(sampled, ax)
def add_internal_window_to_subplot(sampled: samples.Samples, ax: plt.Axes) -> None:
    """Outline the internal (loss) window on a subplot with a red, unfilled rectangle.

    Nothing is drawn when ``ax`` is None or when the loss window radius equals the window radius, i.e., when there
    is no buffer between the two.

    Args:
        sampled: Samples object whose ``config.data_build`` provides ``window_radius`` and ``loss_window_radius``.
        ax: Axes holding the image to annotate.
    """
    if ax is None:
        return
    loss_window_radius = sampled.config.data_build.loss_window_radius
    window_radius = sampled.config.data_build.window_radius
    if loss_window_radius == window_radius:
        return
    buffer = int(window_radius - loss_window_radius)
    # imshow places pixel centers on integer coordinates, so the outer edge of pixel `buffer` sits at `buffer - 0.5`.
    # Anchoring the rectangle at (buffer, buffer) would shift the outline half a pixel towards the lower right, so
    # the anchor is moved onto the pixel edge instead.
    # NOTE(review): assumes the image on `ax` was drawn with imshow's default extent and is 2 * window_radius
    # pixels square -- confirm against callers.
    half_pixel = 0.5
    rect = patches.Rectangle(
        (buffer - half_pixel, buffer - half_pixel), loss_window_radius * 2, loss_window_radius * 2, linewidth=1,
        edgecolor='red', facecolor='none'
    )
    ax.add_patch(rect)
def _format_number(number: Union[int, float]) -> str:
    """Format a number for use in an axis label.

    Integers are returned as-is; floats are rounded to ``_FLOAT_DECIMALS`` decimal places. Both Python and numpy
    scalar types are supported.

    Args:
        number: Python or numpy integer or float.

    Returns:
        String representation of the number.
    """
    # numpy integer scalars (e.g., np.int64) are not subclasses of int, and most numpy float scalars (e.g.,
    # np.float32) are not subclasses of float, so the numpy abstract scalar types must be checked explicitly
    if isinstance(number, (int, np.integer)):
        return str(number)
    if isinstance(number, (float, np.floating)):
        # Convert to a Python float first so rounding and string output do not depend on the numpy float width
        return str(round(float(number), _FLOAT_DECIMALS))
    # Never return None for an unexpected type; fall back to the plain string representation
    return str(number)