lfcnn.models package

Submodules

lfcnn.models.abstracts module

Abstract base class definitions.

class lfcnn.models.abstracts.BaseModel(optimizer, loss, metrics, callbacks)[source]

Bases: object

__build_model__(generated_shape, augmented_shape, gpus=1, cpu_merge=False)[source]

Create the Keras model as defined by the derived class, set the keras_model attribute, and compile it with the specified optimizer, loss, and metrics.

TODO: Refactor strategies. The OneDeviceStrategy and the MirroredStrategy in some instances show deadlocks when using multiprocessing for the data generators. Currently, the workaround is to use a DummyStrategy with an empty scope() context manager.

See also

create_model()

__build_necessary__(generated_shape)[source]

Check whether building of the keras model is necessary.

Parameters

generated_shape (List[tuple]) – List of shapes generated by the generator, i.e. the input shapes of the model.

Return type

bool

Returns

True if build is necessary, else False.

property callbacks
Return type

Callback

create_model(inputs, augmented_shape)[source]

Create the Keras model. Needs to be implemented by the derived class to define the network topology.

Parameters
  • inputs (List[Input]) – List of Keras Inputs. Single- and multi-input models are supported.

  • augmented_shape (Optional[Tuple[int, int, int, int, int]]) – The augmented shape as generated by the generator. Can be used to obtain the original light field’s shape, for example the number of subapertures or the number of spectral channels.

Return type

Model
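
For orientation, a minimal sketch of a derived class follows. It assumes a single-input Keras model; the layer choice, output name, and the objects passed to the constructor are illustrative only, and depending on the model type further hooks (such as set_generator_and_reshape()) may need to be provided.

# Minimal sketch of a derived model (illustrative only, not part of lfcnn).
import tensorflow.keras as keras
from lfcnn.models.abstracts import BaseModel

class TinyModel(BaseModel):

    def create_model(self, inputs, augmented_shape=None):
        # 'inputs' is a list of Keras Inputs; a single input is assumed here.
        x = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs[0])
        out = keras.layers.Conv2D(1, 3, padding="same", name="disparity")(x)
        return keras.Model(inputs, out, name="tiny_model")

# Hypothetical instantiation; plain Keras objects are assumed for the
# optimizer, loss, metrics, and callbacks.
model = TinyModel(
    optimizer=keras.optimizers.Adam(1e-4),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()],
    callbacks=[keras.callbacks.TerminateOnNaN()])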

evaluate_challenges(data, data_key, label_keys, augmented_shape, generated_shape, range_data=None, range_labels=None, use_mask=False, gen_kwargs=None, model_weights=None, **kwargs)[source]

Evaluate dataset challenges. Challenges are full-sized inputs with ground truth labels that are used to test and evaluate a model in more depth on full-sized data.

Parameters
  • data – Data dictionary or path to test data .h5 file.

  • data_key – Key of light field data in test data file or dictionary.

  • label_keys – Keys of labels in test data file or dictionary.

  • augmented_shape – Shape after augmentation (indirectly defines the angular and spatial crop when smaller than the input shape).

  • generated_shape – Generated shape or list of generated shapes in case of multi input models.

  • range_data – Dynamic range of input light field data. Used to normalize the input data to a range [0, 1]. If no normalization is necessary, specify None.

  • range_labels – Dynamic range of input label data. May be used to normalize the input data to a range [0, 1]. If no normalization is necessary, specify None. If a list of labels is used, specify the ranges as a list, e.g. [255, None, None].

  • use_mask – Whether to use a color coding mask.

  • gen_kwargs (Optional[dict]) – Passed to generator instantiation.

  • model_weights (Optional[str]) – Optional path to saved model weights. If no path is specified and the model has been previously trained, the existing model weights are used.

  • **kwargs – Passed to tensorflow.keras.Model.predict().

Return type

dict

Returns

A dictionary containing a list of predictions (the keys are set corresponding to the output layer(s) name(s)) and the corresponding list of metrics.
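
A hedged usage sketch, given a derived model instance model (see the sketch under create_model()); the file name, keys, shapes, and weights path are placeholders, not part of lfcnn.

# Hypothetical usage of evaluate_challenges(); all values are placeholders.
results = model.evaluate_challenges(
    data="challenges.h5",
    data_key="lf",
    label_keys=["disparity"],
    augmented_shape=(9, 9, 512, 512, 3),
    generated_shape=[(512, 512, 9 * 9 * 3)],
    range_data=None,
    range_labels=None,
    model_weights="weights/model_best.h5")
# 'results' maps output layer names to predictions and the corresponding metrics.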

property generator
property keras_model
Return type

Model

load(filepath, compile=True)[source]

Load the full model, including optimizer, loss, etc.

TODO: For this, the custom loss and metric functions need to implement proper deserialization via from_config and get_config.

Parameters
  • filepath – Path to the saved model.

  • compile – Whether to compile the loaded model.

Returns

An LFCNN model with a properly loaded Keras model instance.
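
A hedged usage sketch; the checkpoint path is a placeholder, and a call on an existing model instance is assumed.

# Hypothetical usage; the checkpoint path is a placeholder.
restored = model.load("checkpoints/full_model", compile=True)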

load_weights(filepath, generated_shape, augmented_shape, **kwargs)[source]
property loss
Return type

Loss

property metrics
Return type

Metric

property model_crop
Return type

tuple

property optimizer
Return type

OptimizerV2

property reshape_func
Return type

Callable

save(filepath, save_format='tf', overwrite=True, include_optimizer=True)[source]

Save the full model, including optimizer, loss, etc.

TODO: For this, the custom loss and metric functions need to implement proper deserialization via from_config and get_config.

Parameters
  • filepath – Path to save the model.

  • save_format – Format of saving, either “tf” or “h5”.

  • overwrite – Whether to overwrite an existing path.

  • include_optimizer – Whether to include the optimizer upon saving.
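
A hedged usage sketch; the checkpoint path is a placeholder.

# Hypothetical usage; the checkpoint path is a placeholder.
model.save("checkpoints/full_model", save_format="tf",
           overwrite=True, include_optimizer=True)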

save_weights(filepath, overwrite=True)[source]
set_generator_and_reshape()[source]
set_metrics(metric)[source]

Set metrics after model instantiation. This is useful, e.g. when evaluating the model on full-sized light fields. In that case, MS-SSIM can be used with more scales.
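
A hedged usage sketch; whether a single metric, a list, or a per-output dictionary is expected depends on the model, and a plain list of Keras metrics is assumed here.

import tensorflow as tf

# Hypothetical usage; a plain list of Keras metrics is assumed.
model.set_metrics([tf.keras.metrics.MeanSquaredError(),
                   tf.keras.metrics.MeanAbsoluteError()])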

test(data, data_key, label_keys, augmented_shape, generated_shape, batch_size, data_percentage=1.0, range_data=None, range_labels=None, use_mask=False, gpus=1, cpu_merge=False, gen_kwargs=None, **kwargs)[source]

Evaluate the model using a test dataset.

Parameters
  • data – Data dictionary or path to test data .h5 file.

  • data_key – Key of light field data in test data file or dictionary.

  • label_keys – Keys of label in test data file or dictionary.

  • augmented_shape – Shape after augmentation (indirectly defines the angular and spatial crop when smaller than the input shape).

  • generated_shape – Generated shape or list of generated shapes in case of multi input models.

  • batch_size – Batch size.

  • data_percentage – Percentage of testing data to use.

  • range_data – Dynamic range of input light field data. Used to normalize the input data to a range [0, 1]. If no normalization is necessary, specify None.

  • range_labels – Dynamic range of input label data. May be used to normalize the input data to a range [0, 1]. If no normalization is necessary, specify None. If a list of labels is used, specify the ranges as a list, e.g. [255, None, None].

  • use_mask – Whether to use a color coding mask.

  • gpus (Union[int, List[int]]) – Integer or list of integers specifying the number of GPUs or GPU IDs to use for training. Defaults to 1. If more than one GPU is used, the model will be distributed across multiple GPUs, i.e. the batch will be split up across the GPUs.

  • cpu_merge (bool) – Used when gpus > 1. Whether to force merging model weights under the scope of the CPU. Defaults to False (recommended for NV-Link).

  • gen_kwargs (Optional[dict]) – Passed to generator instantiation.

  • **kwargs – Passed to tensorflow.keras.Model.evaluate().

Return type

dict

Returns

Dictionary containing loss and metric test scores.
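
A hedged usage sketch, given a derived model instance model; the file name, keys, and shapes are placeholders.

# Hypothetical usage of test(); all values are placeholders.
scores = model.test(
    data="test.h5",
    data_key="lf",
    label_keys=["disparity"],
    augmented_shape=(9, 9, 32, 32, 3),
    generated_shape=[(32, 32, 9 * 9 * 3)],
    batch_size=8)
print(scores)  # dictionary of loss and metric test scores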

train(data, valid_data, data_key, label_keys, augmented_shape, generated_shape, batch_size, valid_data_key=None, valid_label_keys=None, valid_batch_size=None, data_percentage=1.0, valid_percentage=1.0, range_data=None, range_labels=None, range_valid_data=None, range_valid_labels=None, augment=True, shuffle=True, use_mask=False, fix_seed=False, gpus=1, cpu_merge=False, gen_kwargs=None, **kwargs)[source]

Train and validate the model.

Parameters
  • data – Data dictionary or path to training data .h5 file.

  • valid_data – Data dictionary or path to validation data .h5 file.

  • data_key – Key of light field data in training data file or dictionary.

  • label_keys – Keys of label in training data file or dictionary.

  • valid_data_key – Key of light field data in validation data file or dictionary.

  • valid_label_keys – Keys of labels in validation data file or dictionary.

  • augmented_shape – Shape after augmentation (indirectly defines the angular and spatial crop when smaller than the input shape).

  • generated_shape – Generated shape or list of generated shapes in case of multi input models.

  • batch_size – Batch size.

  • data_percentage – Percentage of training data to use. Can be used to test training on a smaller set.

  • valid_percentage – Percentage of validation data to use. Can be used to test training on a smaller set.

  • valid_batch_size – Batch size used for validation.

  • range_data – Dynamic range of input light field data. Used to normalize the input data to a range [0, 1]. If no normalization is necessary, specify None.

  • range_labels – Dynamic range of input label data. May be used to normalize the input data to a range [0, 1]. If no normalization is necessary, specify None. If a list of labels is used, specify the ranges as a list, e.g. [255, None, None].

  • range_valid_data – Dynamic range of input light field validation data. Used to normalize the input data to a range [0, 1]. If no normalization is necessary, specify None.

  • range_valid_labels – Dynamic range of input label validation data. May be used to normalize the input data to a range [0, 1]. If no normalization is necessary, specify None. If a list of labels is used, specify the ranges as a list, e.g. [255, None, None].

  • augment – Whether to perform online augmentation. Cropping to augmented_shape is always performed.

  • shuffle – Whether to shuffle data between epochs.

  • use_mask – Whether to use a color coding mask.

  • fix_seed – Whether to use a constant seed for random augments during training. The seed for validation is always fixed.

  • gpus (Union[int, List[int]]) – Integer or list of integers specifying the number of GPUs or GPU IDs to use for training. Defaults to 1. If more than one GPU is used, the model will be distributed across multiple GPUs, i.e. the batch will be split up across the GPUs.

  • cpu_merge (bool) – Used when gpus > 1. Whether to force merging model weights under the scope of the CPU. Defaults to False (recommended for NV-Link).

  • gen_kwargs (Optional[dict]) – Passed to generator instantiation.

  • **kwargs – Passed to tensorflow.keras.Model.fit().

Return type

History

Returns

hist – A History object. The attribute hist.history contains the logged values.
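
A hedged usage sketch, given a derived model instance model; the file names, keys, shapes, and number of epochs are placeholders. Keyword arguments such as epochs are forwarded to tensorflow.keras.Model.fit().

# Hypothetical usage of train(); all values are placeholders.
hist = model.train(
    data="train.h5",
    valid_data="valid.h5",
    data_key="lf",
    label_keys=["disparity"],
    augmented_shape=(9, 9, 32, 32, 3),
    generated_shape=[(32, 32, 9 * 9 * 3)],
    batch_size=16,
    epochs=100)
print(hist.history["loss"][-1])  # last logged training loss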

class lfcnn.models.abstracts.DummyStrategy[source]

Bases: object

This class provides a dummy strategy with a scope() context manager.

When using single-device training, this avoids the use of TF distributed strategies, which are necessary for multi-GPU training.

TODO: Is there a nicer solution to this?

scope()[source]
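
The idea can be sketched with a no-op context manager; this is an illustration of the concept, not necessarily the actual implementation.

import contextlib

class NoOpStrategy:
    """Sketch of the DummyStrategy idea: a scope() that does nothing."""

    @contextlib.contextmanager
    def scope(self):
        # No distribution context is entered; single-device code can still
        # use the same 'with strategy.scope():' pattern as MirroredStrategy.
        yield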

Module contents

The LFCNN models module.