Source code for bfgn.architectures.dilation_net

from typing import Tuple

import keras
from keras.layers import BatchNormalization, Conv2D

from bfgn.architectures import config_sections


# Module-level default values for the dilation rate and the number of layers.
# NOTE(review): create_model in this module takes its defaults from
# config_sections (config_sections.DEFAULT_DILATION_RATE and
# config_sections.DEFAULT_NUM_LAYERS), not from these names -- confirm whether
# anything still reads these before changing or relying on them.
DEFAULT_DILATION_RATE = 2
DEFAULT_NUM_LAYERS = 8


class ArchitectureConfigSection(
    config_sections.DilationMixin,
    config_sections.FlatMixin,
    config_sections.BaseArchitectureConfigSection,
):
    """Architecture configuration section for the dilation network.

    Defines no attributes or methods of its own; everything is inherited from
    the DilationMixin, FlatMixin and BaseArchitectureConfigSection bases in
    config_sections.
    """
def create_model(
    inshape: Tuple[int, int, int],
    n_classes: int,
    output_activation: str,
    dilation_rate: int = config_sections.DEFAULT_DILATION_RATE,
    filters: int = config_sections.DEFAULT_FILTERS,
    kernel_size: Tuple[int, int] = config_sections.DEFAULT_KERNEL_SIZE,
    num_layers: int = config_sections.DEFAULT_NUM_LAYERS,
    padding: str = config_sections.DEFAULT_PADDING,
    use_batch_norm: bool = config_sections.DEFAULT_USE_BATCH_NORM,
    use_initial_colorspace_transformation_layer: bool = (
        config_sections.DEFAULT_USE_INITIAL_COLORSPACE_TRANSFORMATION_LAYER
    ),
) -> keras.models.Model:
    """Build a stack of dilated 2D convolutions as a Keras functional model.

    The network is assembled in three stages:

    1. Optionally, a colorspace transformation: a 1x1 convolution with
       ``inshape[-1] ** 2`` filters, a 1x1 convolution back to ``inshape[-1]``
       filters, then batch normalization.
    2. ``num_layers`` repetitions of two dilated convolutions, each repetition
       optionally followed by batch normalization.
    3. A 1x1 convolution with ``n_classes`` filters and ``output_activation``.

    Args:
        inshape: Shape passed to ``keras.layers.Input``; its last entry is used
            as the channel count by the colorspace transformation.
        n_classes: Number of filters in the final 1x1 convolution.
        output_activation: Activation applied by the final convolution.
        dilation_rate: Dilation rate of every convolution in the main stack.
        filters: Number of filters of every convolution in the main stack.
        kernel_size: Kernel size of every convolution in the main stack.
        num_layers: Number of repetitions of the two-convolution unit.
        padding: Padding mode for the main stack and the final convolution
            (the colorspace convolutions always use ``'same'``).
        use_batch_norm: Whether to batch-normalize after each repetition.
        use_initial_colorspace_transformation_layer: Whether to prepend the
            colorspace transformation described above.

    Returns:
        The assembled ``keras.models.Model`` mapping the input layer to the
        final convolution's output.
    """
    # Every convolution in the main stack is configured identically.
    # No activation argument is passed to these layers.
    stack_conv_kwargs = dict(
        filters=filters,
        dilation_rate=dilation_rate,
        kernel_size=kernel_size,
        padding=padding,
    )

    model_input = keras.layers.Input(inshape)
    tensor = model_input

    if use_initial_colorspace_transformation_layer:
        # Expand the channels to their square, project back, then normalize.
        n_channels = inshape[-1]
        expanded = Conv2D(filters=int(n_channels ** 2), kernel_size=(1, 1), padding='same')(tensor)
        projected = Conv2D(filters=n_channels, kernel_size=(1, 1), padding='same')(expanded)
        tensor = BatchNormalization()(projected)

    for _unit in range(num_layers):
        # Two back-to-back dilated convolutions per unit.
        for _conv in range(2):
            tensor = Conv2D(**stack_conv_kwargs)(tensor)
        if use_batch_norm:
            tensor = BatchNormalization()(tensor)

    # 1x1 convolution producing one output channel per class.
    model_output = Conv2D(n_classes, (1, 1), activation=output_activation, padding=padding)(tensor)
    return keras.models.Model(inputs=model_input, outputs=model_output)