import logging
from pathlib import Path
from typing import List, Tuple
import tensorflow as tf
import wandb
from mergernet.core.constants import RANDOM_SEED
from mergernet.core.experiment import Experiment
from mergernet.core.hp import HyperParameterSet
from mergernet.core.utils import Timming
from mergernet.data.dataset import Dataset
from mergernet.model.utils import (get_conv_arch, set_trainable_state,
                                   setup_seeds)

setup_seeds()
L = logging.getLogger(__name__)


def finetune_train(
  dataset: Dataset,
  hp: HyperParameterSet,
  callbacks: List[tf.keras.callbacks.Callback] = None,
  run_name: str = 'run'
) -> tf.keras.Model:
  """Trains a model in two stages: first only the dense head is trained with
  the convolutional backbone frozen, then the backbone is unfrozen and the
  whole model is fine-tuned."""
  tf.keras.backend.clear_session()
  callbacks = callbacks or []  # guard against the default None before unpacking below

  ds_train, ds_test = dataset.get_fold(0)
  ds_train = dataset.prepare_data(
    ds_train,
    batch_size=hp.get('batch_size'),
    buffer_size=5000,
    kind='train'
  )
  ds_test = dataset.prepare_data(
    ds_test,
    batch_size=hp.get('batch_size'),
    buffer_size=1000,
    kind='train'
  )

  class_weights = dataset.compute_class_weight()

  model = _build_model(
    input_shape=dataset.config.image_shape,
    n_classes=dataset.config.n_classes,
    freeze_conv=True,
    hp=hp
  )
  _compile_model(model, tf.keras.optimizers.Adam(hp.get('opt_lr')))

  with Experiment.Tracker(hp.to_values_dict(), name=run_name, job_type='train'):
    early_stop_cb = tf.keras.callbacks.EarlyStopping(
      monitor='val_loss',
      min_delta=0,
      patience=2,
      mode='min',  # val_loss should decrease
      restore_best_weights=True
    )

    wandb_cb = wandb.keras.WandbCallback(
      monitor='val_loss',
      mode='min',
      save_graph=True,
      save_model=False,
      log_weights=False,
      log_gradients=False,
      compute_flops=True,
    )

    # Stage 1: train only the dense head while the CNN backbone stays frozen.
    t = Timming()
    L.info('Start of training loop with frozen CNN')

    h1 = model.fit(
      ds_train,
      batch_size=hp.get('batch_size'),
      epochs=hp.get('tl_epochs', default=10),
      validation_data=ds_test,
      class_weight=class_weights,
      callbacks=[early_stop_cb, wandb_cb]
    )

    L.info(f'End of training loop, duration: {t.end()}')

    # Stage 2: unfreeze the backbone and recompile so the new trainable state
    # takes effect, then continue training the whole model.
    set_trainable_state(model, 'conv_block', True)
    _compile_model(model, tf.keras.optimizers.Adam(hp.get('opt_lr')))

    t = Timming()
    L.info('Start of main training loop')

    model.fit(
      ds_train,
      batch_size=hp.get('batch_size'),
      epochs=hp.get('tl_epochs', default=10) + hp.get('epochs'),
      validation_data=ds_test,
      class_weight=class_weights,
      initial_epoch=len(h1.history['loss']),
      callbacks=[wandb_cb, *callbacks],
    )

    L.info(f'End of training loop, duration: {t.end()}')

  return model
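

# Usage sketch: how this function might be called with extra callbacks. The
# `dataset` and `hp` objects are assumed to be built elsewhere (their
# constructors are not shown in this module), and the callback choices below
# are illustrative, not the project's configuration. Note that the extra
# callbacks are only applied in the second (unfrozen) training stage.
def _example_finetune_run(dataset: Dataset, hp: HyperParameterSet) -> tf.keras.Model:
  extra_callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_loss',
                                       save_best_only=True),
  ]
  return finetune_train(dataset, hp, callbacks=extra_callbacks,
                        run_name='finetune_example')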


def _build_model(
  input_shape: Tuple,
  n_classes: int,
  freeze_conv: bool = False,
  hp: HyperParameterSet = None
) -> tf.keras.Model:
  """Builds the classifier: data augmentation block, pretrained convolutional
  backbone and a dense head. `n_classes` typically comes from
  ``dataset.config.n_classes``."""
  conv_arch, preprocess_input = get_conv_arch(
    hp.get('architecture')
  )
  conv_block = conv_arch(
    input_shape=input_shape,
    include_top=False,
    weights=hp.get('pretrained_weights'),
  )
  # Rename the backbone so it can be found later by set_trainable_state.
  conv_block._name = 'conv_block'
  conv_block.trainable = (not freeze_conv)
  L.info(f'Trainable weights (CONV): {len(conv_block.trainable_weights)}')

  data_aug_layers = [
    tf.keras.layers.RandomFlip(mode='horizontal', seed=RANDOM_SEED),
    tf.keras.layers.RandomFlip(mode='vertical', seed=RANDOM_SEED),
    tf.keras.layers.RandomRotation(
      (-0.08, 0.08),
      fill_mode='reflect',
      interpolation='bilinear',
      seed=RANDOM_SEED
    ),
    tf.keras.layers.RandomZoom(
      (-0.15, 0.0),
      fill_mode='reflect',
      interpolation='bilinear',
      seed=RANDOM_SEED
    )
  ]
  data_aug_block = tf.keras.Sequential(data_aug_layers, name='data_augmentation')

  inputs = tf.keras.Input(shape=input_shape)
  x = data_aug_block(inputs)
  x = preprocess_input(x)
  x = conv_block(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(hp.get('dense_1_units'), activation='relu')(x)
  x = tf.keras.layers.Dropout(hp.get('dropout_1_rate'))(x)
  x = tf.keras.layers.Dense(hp.get('dense_2_units'), activation='relu')(x)
  x = tf.keras.layers.Dropout(hp.get('dropout_2_rate'))(x)
  outputs = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
  model = tf.keras.Model(inputs, outputs)
  L.info(f'Trainable weights (TOTAL): {len(model.trainable_weights)}')
  return model
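

# For reference, the hyperparameter names read via hp.get() in this module,
# with illustrative placeholder values (not the project's defaults). Whether a
# given architecture string is supported depends on get_conv_arch.
_EXAMPLE_HP_VALUES = {
  'architecture': 'resnet50',        # passed to get_conv_arch
  'pretrained_weights': 'imagenet',  # passed to the backbone constructor
  'dense_1_units': 128,
  'dropout_1_rate': 0.4,
  'dense_2_units': 64,
  'dropout_2_rate': 0.4,
  'batch_size': 64,
  'opt_lr': 1e-4,
  'tl_epochs': 10,   # epochs with the backbone frozen (default 10 in the code)
  'epochs': 20,      # additional epochs after unfreezing
}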


def _compile_model(
  model: tf.keras.Model,
  optimizer: tf.keras.optimizers.Optimizer,
  metrics: list = None
):
  """Compiles the model for multi-class classification with categorical
  cross-entropy and categorical accuracy, plus any extra metrics passed in."""
  model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[
      tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
      *(metrics or [])
    ]
  )
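

# Sketch: _compile_model accepts extra Keras metrics on top of categorical
# accuracy. The top-2 accuracy metric below only illustrates that hook; it is
# not a metric used elsewhere in this module.
def _example_compile_with_extra_metrics(model: tf.keras.Model, hp: HyperParameterSet):
  _compile_model(
    model,
    tf.keras.optimizers.Adam(hp.get('opt_lr')),
    metrics=[tf.keras.metrics.TopKCategoricalAccuracy(k=2, name='top2_accuracy')]
  )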