Source code for llmize.callbacks.adapt_temp_on_plateau

from ..utils.logger import log_info, log_debug

[docs] class AdaptTempOnPlateau:
[docs] def __init__(self, monitor='best_score', init_temperature = 1.0, min_delta=0.0001, patience=10, factor=1.1, max_temperature=2.0, verbose=0): """ Adapt temperature on plateau callback to increase the temperature when the metric plateaus. :param monitor: Metric to monitor (default: 'best_score'). :param min_delta: Minimum change to qualify as an improvement (default: 0.0001). :param patience: Number of steps with no improvement before increasing temperature (default: 10). :param factor: Factor by which the temperature will be increased (default: 1.2). :param verbose: Verbosity mode. 0 = no output, 1 = print message when temperature is adapted (default: 0). """ self.monitor = monitor self.min_delta = min_delta self.patience = patience self.factor = factor self.verbose = verbose self.wait = 0 # Counter for patience self.best_score = None # Best score encountered so far self.temperature = init_temperature # Initial temperature self.max_temperature = max_temperature self.stopped_step = 0 # Step when the temperature adaptation occurred
[docs] def on_step_end(self, step, logs=None): """ Check the stopping condition at the end of each step and adapt temperature if necessary. :param step: The current step number. :param logs: Dictionary containing the logs (contains the metric to monitor). """ current_score = logs.get(self.monitor) # If this is the first step, initialize best_score if self.best_score is None: self.best_score = current_score # Check if the current score has improved if abs(current_score - self.best_score) < self.min_delta: self.wait += 1 if self.verbose > 1: log_debug(f"No improvement in {self.monitor}. Patience count: {self.wait}/{self.patience}") else: self.best_score = current_score self.wait = 0 # Reset wait count # Adapt temperature if patience is exceeded if self.wait >= self.patience: self.temperature *= self.factor if self.temperature > self.max_temperature: self.temperature = self.max_temperature self.wait = 0 # Reset the wait counter self.stopped_step = step if self.verbose == 1: log_info(f"No improvement in {self.monitor} for {self.patience} steps. Adapted temperature to {self.temperature:.2f}.") return self.temperature # Return the updated temperature return self.temperature