ParametricEstimator

class mergernet.estimators.parametric.ParametricEstimator

Bases: Estimator

Methods

build(freeze_conv: bool = False) → Model
compile_model(tf_model: Model, optimizer: Optimizer, metrics: list = [], label_smoothing: float = 0.0)
cross_validation(run_name: str = 'run-0', callbacks: List[Callback] = [])
download(config: EstimatorConfig, replace: bool = False)
get_conv_arch(pretrained_arch: str) → Tuple[Callable, Callable]
get_dataaug_block(flip_horizontal: bool = True, flip_vertical: bool = True, rotation: Tuple[float, float] | bool = (-0.08, 0.08), zoom: Tuple[float, float] | bool = (-0.15, 0.0))
get_metric(metric: str)
get_optimizer(optimizer: str, lr: float | LearningRateSchedule) → Optimizer
get_scheduler(scheduler: str, lr: float) → LearningRateSchedule

For the cosine_restarts scheduler, the learning rate multiplier first decays from 1 to alpha over the first first_decay_steps steps, after which a warm restart is performed. Each subsequent restart period runs for t_mul times more steps than the previous one, and its initial learning rate is m_mul times that of the previous period.

Parameters:
  • scheduler (str) – Scheduler name

  • lr (float) – Initial learning rate

Returns:

A LearningRateSchedule instance

Return type:

tf.keras.optimizers.schedules.LearningRateSchedule
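The decay described above can be sketched numerically. The following is a plain-Python reimplementation of the standard cosine-decay-with-restarts formula (following the semantics of tf.keras.optimizers.schedules.CosineDecayRestarts); it is an illustration, not mergernet's own code, and the parameter defaults shown are assumptions:

```python
import math

def cosine_restarts(step, initial_lr, first_decay_steps,
                    t_mul=2.0, m_mul=1.0, alpha=0.0):
    """Learning rate at `step` for cosine decay with warm restarts."""
    completed = step / first_decay_steps
    if t_mul == 1.0:
        # All restart periods have the same length.
        i_restart = math.floor(completed)
        frac = completed - i_restart
    else:
        # Number of completed restart periods (each t_mul times longer).
        i_restart = math.floor(
            math.log(1 - completed * (1 - t_mul)) / math.log(t_mul))
        sum_r = (1 - t_mul ** i_restart) / (1 - t_mul)
        frac = (completed - sum_r) / t_mul ** i_restart
    m_fac = m_mul ** i_restart  # each restart starts m_mul times lower
    cosine_decayed = 0.5 * m_fac * (1 + math.cos(math.pi * frac))
    decayed = (1 - alpha) * cosine_decayed + alpha
    return initial_lr * decayed

print(cosine_restarts(0, 1e-3, 100))    # → 0.001 (full rate at step 0)
print(cosine_restarts(50, 1e-3, 100))   # ≈ 0.0005 (halfway through first period)
print(cosine_restarts(100, 1e-3, 100))  # → 0.001 (warm restart back to full rate)
```

With m_mul=1.0 and alpha=0.0, each restart returns the multiplier to 1; lowering m_mul makes successive restarts progressively cooler.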

plot(filename: str = 'model.png')
predict()
set_trainable(tf_model: Model, layer: str, trainable: bool)
train(run_name: str = 'run-0', callbacks: List[Callback] = [], fold: int = 0) → Model
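set_trainable is the kind of helper used with build(freeze_conv=True) to freeze or unfreeze parts of a network for transfer learning. One plausible behavior, sketched here with stand-in classes rather than real Keras objects (the traversal and matching-by-name are assumptions about the actual implementation):

```python
class FakeLayer:
    """Stand-in for a Keras layer: just a name and a trainable flag."""
    def __init__(self, name):
        self.name = name
        self.trainable = True

class FakeModel:
    """Stand-in for a Keras Model exposing a `.layers` list."""
    def __init__(self, names):
        self.layers = [FakeLayer(n) for n in names]

def set_trainable(tf_model, layer, trainable):
    # Toggle the `trainable` flag of the layer whose name matches `layer`.
    for l in tf_model.layers:
        if l.name == layer:
            l.trainable = trainable

model = FakeModel(["conv_base", "flatten", "dense_head"])
set_trainable(model, "conv_base", False)  # freeze the pretrained backbone
print([l.trainable for l in model.layers])  # → [False, True, True]
```

Freezing the convolutional backbone while leaving the head trainable is the usual first phase of fine-tuning a pretrained architecture.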
registry = <mergernet.estimators.base.EstimatorRegistry object>
property tf_model
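The class-level registry attribute suggests that estimator subclasses register themselves with a shared EstimatorRegistry on the base class, so they can be looked up by name. A minimal sketch of that pattern (the mechanics here are assumptions, not mergernet's actual implementation):

```python
class EstimatorRegistry:
    """Maps estimator names to classes for lookup by string."""
    def __init__(self):
        self._entries = {}

    def register(self, name, cls):
        self._entries[name] = cls

    def get(self, name):
        return self._entries[name]

class Estimator:
    registry = EstimatorRegistry()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every concrete subclass registers itself under its class name.
        Estimator.registry.register(cls.__name__, cls)

class ParametricEstimator(Estimator):
    pass

print(Estimator.registry.get("ParametricEstimator"))  # the subclass object
```

Because registration happens in __init_subclass__, merely defining a subclass makes it discoverable through the shared registry.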