Source code for bfgn.data_management.ooc_functions



import logging
import numpy as np
import os
from typing import List, Tuple

import shutil

# Exclusive upper bound on the number of distinct values a band flagged as categorical may
# contain; one_hot_encode_array asserts against it to catch bands mislabelled as categorical.
_MAX_UNIQUE_RESPONSES = 100

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


def one_hot_encode_array(
    raw_band_types: List[str],
    array: np.array,
    memmap_file: str = None,
    per_band_encoding: List[np.array] = None,
) -> Tuple[np.array, List[str], List[np.array]]:
    """One hot encode an array of mixed real and categorical variables.

    Every band (index along the last axis) flagged 'C' is replaced, in place within the band
    ordering, by one indicator band per category value. All other bands are copied unchanged.

    Args:
        raw_band_types: Band types for given array, either 'R' for real or 'C' for categorical.
        array: array to encode; bands are indexed along the last axis.
        memmap_file: file to use to do things out-of-core. When given, each encoding pass writes
          a float32 memmap next to it (same stem plus a '_cat.npy' suffix), the previous file is
          deleted, and the returned array is a memmap of the newest file. When None, the work is
          done in memory and the input array is left untouched.
        per_band_encoding: if none, this will be calculated and returned. If not none, these
          will be used to encode the array; entry i holds the category values of the i-th
          categorical band (in band order).

    Returns:
        If per_band_encoding was None, a tuple of:
          array: now one-hot-encoded
          band_types: the one-hot-encoded version of the band-types, where each categorical band
            is replaced by one 'B<i>' entry per category (i = index of the categorical band)
          return_band_encoding: the encoding used on a per-categorical-band basis, ordered like
            the categorical bands so it can be passed back in as per_band_encoding
        If per_band_encoding was provided, only the encoded array is returned.

    Raises:
        AssertionError: if per_band_encoding does not have one entry per categorical band, or if
          a categorical band has _MAX_UNIQUE_RESPONSES or more distinct values.
    """
    cat_band_locations = [idx for idx, val in enumerate(raw_band_types) if val == 'C']
    band_types = raw_band_types.copy()

    if per_band_encoding is None:
        return_band_encoding = []
    else:
        assert len(per_band_encoding) == len(cat_band_locations), \
            'Inconsistent lengths of categorical band locations and per_band_encoding provided'
        return_band_encoding = None

    # Walk the categorical bands from last to first so that expanding one band never shifts the
    # location of a band that is still waiting to be encoded.
    for _c in reversed(range(len(cat_band_locations))):
        cat_loc = cat_band_locations[_c]

        if per_band_encoding is None:
            un_array = array[..., cat_loc]
            # Non-finite entries (nan / inf) are not treated as categories
            un_array = np.unique(un_array[np.isfinite(un_array)])
            return_band_encoding.append(un_array)
        else:
            un_array = per_band_encoding[_c]
        n_categories = len(un_array)

        assert n_categories < _MAX_UNIQUE_RESPONSES,\
            'Too many ({}) unique responses found, suspected incorrect categorical specification'.format(len(un_array))
        _logger.info('Found {} categorical responses'.format(len(un_array)))
        _logger.info('Cat response: {}'.format(un_array))

        # The categorical band is swapped for n_categories indicator bands
        array_shape = list(array.shape)
        array_shape[-1] = n_categories + array.shape[-1] - 1

        if memmap_file is not None:
            cat_memmap_file = os.path.join(
                os.path.dirname(memmap_file), str(os.path.splitext(os.path.basename(memmap_file))[0]) + '_cat.npy')
            cat_array = np.memmap(cat_memmap_file, dtype=np.float32, mode='w+', shape=tuple(array_shape))
        else:
            cat_array = np.zeros(tuple(array_shape))

        # One hot-encode
        for _r in range(array_shape[-1]):
            if _r < cat_loc:
                # Bands ahead of the categorical band keep their position
                cat_array[..., _r] = array[..., _r]
            elif _r < cat_loc + n_categories:
                # Indicator band for category (_r - cat_loc); the mask already has the shape of
                # the target slice, so no squeeze is needed
                cat_array[..., _r] = array[..., cat_loc] == un_array[_r - cat_loc]
            else:
                # Bands behind the categorical band are shifted by the number of added bands
                cat_array[..., _r] = array[..., _r - n_categories + 1]

        if memmap_file is not None:
            # Force file dump, and then reload the encoded responses as the primary response
            del array, cat_array
            if os.path.isfile(memmap_file):
                os.remove(memmap_file)
            memmap_file = cat_memmap_file
            array = np.memmap(memmap_file, dtype=np.float32, mode='r+', shape=tuple(array_shape))
        else:
            # In-memory mode: the freshly encoded array is the input of the next pass
            array = cat_array

        # Replace the single 'C' entry by one 'B<i>' entry per category
        band_types[cat_loc:cat_loc + 1] = ['B' + str(int(_c))] * n_categories

    if per_band_encoding is not None:
        return array

    # Encodings were collected last band first; flip them so that entry i matches the i-th
    # categorical band, which is the order expected when passed back as per_band_encoding.
    return_band_encoding.reverse()
    return array, band_types, return_band_encoding
def permute_array(source: np.array, source_filename: str, permutation: np.array) -> np.array:
    """Reorder the leading axis of a file-backed array, doing the work out-of-core.

    A scratch memmap with the same dtype and shape as ``source`` is created next to
    ``source_filename`` (same stem plus a '_perm.npy' suffix). Entry ``i`` of the scratch array
    is filled with entry ``permutation[i]`` of the source. Once every entry of ``permutation``
    has been processed, the original file is deleted and the scratch file is moved to its path.

    Args:
        source: array whose dtype and shape describe the data stored in source_filename.
        source_filename: path of the file backing the array; it is replaced by the permuted data.
        permutation: for each destination index along the first axis, the source index to copy.

    Returns:
        A new read/write memmap over source_filename holding the permuted data.
    """
    folder, base_name = os.path.split(source_filename)
    stem = os.path.splitext(base_name)[0]
    perm_memmap_file = os.path.join(folder, str(stem) + '_perm.npy')

    dtype, shape = source.dtype, source.shape
    dest = np.memmap(perm_memmap_file, dtype=dtype, mode='w+', shape=shape)

    for dest_index, source_index in enumerate(permutation):
        dest[dest_index, ...] = source[source_index, ...]

        if dest_index % 100 != 0:
            continue

        # Every 100 entries (starting with the first) drop both handles and reopen them from
        # disk, so that pending writes are dumped to file rather than accumulating.
        del dest, source
        source = np.memmap(source_filename, dtype=dtype, shape=shape, mode='r+')
        dest = np.memmap(perm_memmap_file, dtype=dtype, shape=shape, mode='r+')

    # Release the handles before swapping the scratch file into the place of the original
    del source, dest
    os.remove(source_filename)
    shutil.move(perm_memmap_file, source_filename)

    return np.memmap(source_filename, dtype=dtype, shape=shape, mode='r+')