from itertools import compress
import pandas as pd
from mergernet.core.constants import DATA_ROOT
from mergernet.core.experiment import Experiment, backup_model
from mergernet.core.hp import HP, HyperParameterSet
from mergernet.core.utils import iauname, iauname_path
from mergernet.data.dataset import Dataset
from mergernet.data.image import (ChannelAverage, Crop, ImagePipeline,
                                  LegacyRGB, TensorToImage, TensorToShards)
from mergernet.estimators.automl import OptunaEstimator
from mergernet.estimators.parametric import ParametricEstimator
from mergernet.services.legacy import LegacyService
class Job(Experiment):
  """Experiment 27: Optuna hyperparameter search for an EfficientNetV2-B0 model.

  The job builds a hyperparameter search space, wraps a
  ``ParametricEstimator`` trained on the ``LS10_TRAIN_224_PNG`` dataset in
  an ``OptunaEstimator`` and runs 20 trials maximizing ``val_recall``.
  """

  def __init__(self):
    """Configure the experiment id and the run flags."""
    super().__init__()
    self.exp_id = 27
    # NOTE(review): the two flags below are presumably consumed by the
    # ``Experiment`` base class (W&B logging on, no restart) -- confirm there.
    self.log_wandb = True
    self.restart = False

  def call(self):
    """Run the hyperparameter search and hand the resulting model to upload.

    Steps, in order:

    1. Declare the search space (constants plus sampled numeric ranges).
    2. Load the ``LS10_TRAIN_224_PNG`` dataset.
    3. Train through ``OptunaEstimator`` (20 trials, maximize ``val_recall``,
       ``resume=True``).
    4. Pass ``optuna_model.tf_model`` to ``Experiment.upload_file_gd`` under
       the name ``'model.h5'``.
    """
    hps = HyperParameterSet(
      # Fixed model / metric configuration.
      HP.const('architecture', 'efficientnetv2b0'),
      HP.const('pretrained_weights', 'imagenet'),
      HP.const('metrics', ['f1', 'recall', 'roc']),
      HP.const('positive_class_id', 1),
      HP.const('negative_class_id', 0),
      HP.const('epochs', 35),
      HP.const('tl_epochs', 12),
      # NOTE(review): 't1_opt' / 't1_lr' are spelled with the digit "1" while
      # 'tl_epochs' uses the letter "l" -- confirm which keys the estimator
      # actually reads before renaming anything.
      HP.const('t1_opt', 'adamw'),
      HP.num('t1_lr', low=2e-4, high=5e-3, log=True),
      # Optimizer and learning-rate schedule.
      HP.const('optimizer', 'adamw'),
      HP.const('lr_decay', 'cosine'),
      HP.num('lr_decay_steps', low=0.5, high=0.9),
      HP.num('lr_decay_alpha', low=0.1, high=1.0),
      HP.num('opt_lr', low=1e-5, high=1e-3, log=True),
      HP.num('weight_decay', low=1e-4, high=1e-1),
      HP.num('label_smoothing', low=0, high=0.1),
      HP.num('batch_size', low=64, high=256, step=64, dtype=int),
      # Classification head: one dense + dropout pair is searched.
      HP.num('dense_1_units', low=32, high=1024, step=1, dtype=int),
      HP.num('dropout_1_rate', low=0.1, high=0.5),
      # A second dense + dropout pair is currently disabled:
      # HP.num('dense_2_units', low=32, high=1024, step=1, dtype=int),
      # HP.num('dropout_2_rate', low=0.1, high=0.5),
    )
    ds = Dataset(config=Dataset.registry.LS10_TRAIN_224_PNG)
    model = ParametricEstimator(hp=hps, dataset=ds)
    optuna_model = OptunaEstimator(
      estimator=model,
      n_trials=20,
      objective_metric='val_recall',
      objective_direction='maximize',
      resume=True,
    )
    optuna_model.train()
    # NOTE(review): ``tf_model`` is presumably the model produced by the
    # search and ``upload_file_gd`` presumably stores it remotely -- verify
    # against ``OptunaEstimator`` and ``Experiment``.
    Experiment.upload_file_gd('model.h5', optuna_model.tf_model)
# Script entry point: instantiate the job and run it only when this file is
# executed directly (not when it is imported).
if __name__ == '__main__':
  job = Job()
  job.run()