import logging
from pathlib import Path
from typing import List, Optional, Tuple

import tensorflow as tf
import wandb

from mergernet.core.constants import RANDOM_SEED
from mergernet.core.experiment import Experiment
from mergernet.core.hp import HyperParameterSet
from mergernet.core.utils import Timming
from mergernet.data.dataset import Dataset
from mergernet.model.utils import (get_conv_arch, set_trainable_state,
                                   setup_seeds)
# Module-level logger; the short name `L` is what the rest of this file uses.
L = logging.getLogger(__name__)
# Called once at import time. Presumably seeds the random generators used by
# model building/training (see mergernet.model.utils.setup_seeds) -- confirm.
setup_seeds()
def finetune_train(
  dataset: Dataset,
  hp: HyperParameterSet,
  callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
  run_name: str = 'run'
) -> tf.keras.Model:
  """Train a transfer-learning classifier in two phases on fold 0 of ``dataset``.

  Phase 1 trains the model with the convolutional backbone frozen, using
  early stopping on ``val_loss``. Phase 2 unfreezes the backbone (the layer
  named ``conv_block``), recompiles, and continues training from the epoch at
  which phase 1 stopped. Both phases run inside an ``Experiment.Tracker``
  context and report to W&B through ``WandbCallback``.

  Parameters
  ----------
  dataset: Dataset
    Source of the data; fold 0 provides the train and validation splits.
  hp: HyperParameterSet
    Reads ``batch_size``, ``opt_lr``, ``epochs`` and ``tl_epochs``
    (default 10), plus every key read by ``_build_model``.
  callbacks: list of tf.keras.callbacks.Callback, optional
    Extra callbacks, used only in the second (fine-tuning) phase.
    ``None`` (the default) means no extra callbacks.
  run_name: str
    Name of the tracked experiment run.

  Returns
  -------
  tf.keras.Model
    The trained model.
  """
  tf.keras.backend.clear_session()

  # Normalize to a list so that `*extra_callbacks` below also works when the
  # caller passes nothing (unpacking None would raise TypeError).
  extra_callbacks = [] if callbacks is None else list(callbacks)

  ds_train, ds_test = dataset.get_fold(0)
  batch_size = hp.get('batch_size')
  ds_train = dataset.prepare_data(
    ds_train,
    batch_size=batch_size,
    buffer_size=5000,
    kind='train'
  )
  # NOTE(review): the validation split is also prepared with kind='train';
  # presumably this is what keeps its labels attached -- confirm against
  # Dataset.prepare_data.
  ds_test = dataset.prepare_data(
    ds_test,
    batch_size=batch_size,
    buffer_size=1000,
    kind='train'
  )
  class_weights = dataset.compute_class_weight()

  model = _build_model(
    input_shape=dataset.config.image_shape,
    n_classes=dataset.config.n_classes,
    freeze_conv=True,
    hp=hp
  )
  _compile_model(model, tf.keras.optimizers.Adam(hp.get('opt_lr')))

  with Experiment.Tracker(hp.to_values_dict(), name=run_name, job_type='train'):
    tl_epochs = hp.get('tl_epochs', default=10)

    early_stop_cb = tf.keras.callbacks.EarlyStopping(
      monitor='val_loss',
      min_delta=0,
      patience=2,
      mode='min',
      restore_best_weights=True
    )
    wandb_cb = wandb.keras.WandbCallback(
      monitor='val_loss',
      mode='min',
      save_graph=True,
      save_model=False,
      log_weights=False,
      log_gradients=False,
      compute_flops=True,
    )

    # Phase 1: backbone frozen (freeze_conv=True above).
    # `batch_size` is not passed to fit(): the inputs are tf.data datasets
    # (prepare_data already received batch_size), and Keras documents that
    # batch_size must not be specified for datasets.
    t = Timming()
    L.info('Start of training loop with frozen CNN')
    h1 = model.fit(
      ds_train,
      epochs=tl_epochs,
      validation_data=ds_test,
      class_weight=class_weights,
      callbacks=[early_stop_cb, wandb_cb]
    )
    L.info(f'End of training loop, duration: {t.end()}')

    # Phase 2: unfreeze the backbone and recompile (a change of `trainable`
    # only takes effect after compile), then resume at the epoch where
    # phase 1 stopped (it may have stopped early).
    set_trainable_state(model, 'conv_block', True)
    _compile_model(model, tf.keras.optimizers.Adam(hp.get('opt_lr')))
    t = Timming()
    L.info('Start of main training loop')
    model.fit(
      ds_train,
      epochs=tl_epochs + hp.get('epochs'),
      validation_data=ds_test,
      class_weight=class_weights,
      initial_epoch=len(h1.history['loss']),
      callbacks=[wandb_cb, *extra_callbacks],
    )
    L.info(f'End of training loop, duration: {t.end()}')
  return model
def _build_model(
  input_shape: Tuple,
  n_classes: int,
  freeze_conv: bool = False,
  hp: Optional[HyperParameterSet] = None
) -> tf.keras.Model:
  """Build the classifier: augmentation -> preprocessing -> pretrained CNN -> dense head.

  Parameters
  ----------
  input_shape: tuple
    Shape of one input image. It is forwarded to both the backbone and
    ``tf.keras.Input``, so it excludes the batch axis.
  n_classes: int
    Number of units in the final softmax layer.
  freeze_conv: bool
    If True, the backbone (named ``conv_block``) is built non-trainable.
  hp: HyperParameterSet
    Required despite the ``None`` default. Reads ``architecture``,
    ``pretrained_weights``, ``dense_1_units``, ``dropout_1_rate``,
    ``dense_2_units`` and ``dropout_2_rate``.

  Returns
  -------
  tf.keras.Model
    Uncompiled functional model whose output is a softmax over ``n_classes``.

  Raises
  ------
  ValueError
    If ``hp`` is None.
  """
  if hp is None:
    # Every layer size and rate below is read from `hp`. Fail with a clear
    # message instead of an AttributeError on `None.get`.
    raise ValueError('_build_model requires a HyperParameterSet, got hp=None')

  conv_arch, preprocess_input = get_conv_arch(hp.get('architecture'))
  conv_block = conv_arch(
    input_shape=input_shape,
    include_top=False,
    weights=hp.get('pretrained_weights'),
  )
  # Fixed name so that finetune_train can unfreeze the backbone later via
  # set_trainable_state(model, 'conv_block', True).
  conv_block._name = 'conv_block'
  conv_block.trainable = not freeze_conv
  L.info(f'Trainable weights (CONV): {len(conv_block.trainable_weights)}')

  # Give every random augmentation layer its own seed. Depending on the Keras
  # version, random layers constructed with the *same* seed draw identical
  # random streams. The horizontal and vertical flips would then always fire
  # on the same images, producing only identity or 180-degree rotations.
  # Fixed offsets keep runs reproducible.
  data_aug_layers = [
    tf.keras.layers.RandomFlip(mode='horizontal', seed=RANDOM_SEED),
    tf.keras.layers.RandomFlip(mode='vertical', seed=RANDOM_SEED + 1),
    # Factor is a fraction of 2*pi: about +/-29 degrees.
    tf.keras.layers.RandomRotation(
      (-0.08, 0.08),
      fill_mode='reflect',
      interpolation='bilinear',
      seed=RANDOM_SEED + 2
    ),
    # Negative height factor = zoom in, here by up to 15%.
    tf.keras.layers.RandomZoom(
      (-0.15, 0.0),
      fill_mode='reflect',
      interpolation='bilinear',
      seed=RANDOM_SEED + 3
    )
  ]
  data_aug_block = tf.keras.Sequential(data_aug_layers, name='data_augmentation')

  inputs = tf.keras.Input(shape=input_shape)
  x = data_aug_block(inputs)
  x = preprocess_input(x)
  x = conv_block(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(hp.get('dense_1_units'), activation='relu')(x)
  x = tf.keras.layers.Dropout(hp.get('dropout_1_rate'))(x)
  x = tf.keras.layers.Dense(hp.get('dense_2_units'), activation='relu')(x)
  x = tf.keras.layers.Dropout(hp.get('dropout_2_rate'))(x)
  outputs = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
  model = tf.keras.Model(inputs, outputs)
  L.info(f'Trainable weights (TOTAL): {len(model.trainable_weights)}')
  return model
def _compile_model(
  model: tf.keras.Model,
  optimizer: tf.keras.optimizers.Optimizer,
  metrics: Optional[list] = None
):
  """Compile ``model`` in place for one-hot multi-class classification.

  The loss is categorical cross-entropy with ``from_logits=False``, so the
  model must output probabilities (e.g. the softmax head built by
  ``_build_model``). Categorical accuracy is always tracked, under the name
  ``accuracy``.

  Parameters
  ----------
  model: tf.keras.Model
    Model to compile.
  optimizer: tf.keras.optimizers.Optimizer
    A fresh optimizer instance; it is attached to ``model``.
  metrics: list, optional
    Extra metrics appended after ``accuracy``. ``None`` (the default) means
    no extras. This replaces the former mutable ``[]`` default.
  """
  extra_metrics = [] if metrics is None else list(metrics)
  model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[
      tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
      *extra_metrics,
    ]
  )