Mastering Optuna: Pruning, Pareto Multi-Objective Search, Early Stopping and Visual Analysis
Hands-on Optuna tutorial showing pruning, Pareto multi-objective search, custom early-stopping callbacks and visual analysis to accelerate and interpret hyperparameter tuning.
Overview
This tutorial demonstrates a complete Optuna workflow combining pruning, multi-objective optimization, custom callbacks for early stopping, and visualization for deep trial analysis. Each section includes runnable snippets to show how Optuna can speed up experiments, shape smarter search spaces, and surface actionable insights.
Pruning with a Gradient Boosting Classifier
We start with a standard hyperparameter search for a GradientBoostingClassifier while using Optuna's MedianPruner to drop underperforming trials early. The objective reports intermediate fold results to let the pruner make decisions.
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
import numpy as np
from sklearn.datasets import load_breast_cancer, load_diabetes
from sklearn.model_selection import cross_val_score, KFold
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
import matplotlib.pyplot as plt
def objective_with_pruning(trial):
    X, y = load_breast_cancer(return_X_y=True)
    params = {
        'n_estimators': trial.suggest_int('n_estimators', 50, 200),
        'min_samples_split': trial.suggest_int('min_samples_split', 2, 20),
        'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 10),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', None]),
    }
    model = GradientBoostingClassifier(**params, random_state=42)
    kf = KFold(n_splits=3, shuffle=True, random_state=42)
    scores = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
        model.fit(X_train, y_train)
        score = model.score(X_val, y_val)
        scores.append(score)
        trial.report(np.mean(scores), fold)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return np.mean(scores)

study1 = optuna.create_study(
    direction='maximize',
    sampler=TPESampler(seed=42),
    pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=1)
)
study1.optimize(objective_with_pruning, n_trials=30, show_progress_bar=True)
print(study1.best_value, study1.best_params)

This setup quickly eliminates weak configurations and focuses compute on promising regions of the search space.
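To see where the pruner stepped in, you can inspect the pruned trials directly: each trial's intermediate_values maps the reported step (the fold index here) to the score the pruner saw at that point.

# Inspect the first few pruned trials and the fold at which they were cut off.
pruned = [t for t in study1.trials if t.state == optuna.trial.TrialState.PRUNED]
for t in pruned[:5]:
    print(t.number, t.intermediate_values)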
Multi-objective Optimization and Pareto Fronts
Next, we optimize for two competing goals: classification accuracy and an explicit measure of model complexity. Optuna returns a Pareto set of solutions so you can inspect trade-offs rather than optimizing a single metric.
def multi_objective(trial):
    X, y = load_breast_cancer(return_X_y=True)
    n_estimators = trial.suggest_int('n_estimators', 10, 200)
    max_depth = trial.suggest_int('max_depth', 2, 20)
    min_samples_split = trial.suggest_int('min_samples_split', 2, 20)
    model = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=42,
        n_jobs=-1
    )
    accuracy = cross_val_score(model, X, y, cv=3, scoring='accuracy', n_jobs=-1).mean()
    complexity = n_estimators * max_depth
    return accuracy, complexity

study2 = optuna.create_study(
    directions=['maximize', 'minimize'],
    sampler=TPESampler(seed=42)
)
study2.optimize(multi_objective, n_trials=50, show_progress_bar=True)
for t in study2.best_trials[:3]:
    print(t.number, t.values)

This multi-objective view helps you pick a model that balances performance and simplicity.
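If you ultimately need a single configuration from the Pareto set, apply a selection rule of your own. The sketch below is one minimal example, assuming a purely illustrative complexity budget of 500: it keeps the most accurate Pareto-optimal trial that fits the budget.

# A minimal sketch: pick the most accurate Pareto-optimal trial under an
# assumed complexity budget (the threshold 500 is an illustrative choice).
COMPLEXITY_BUDGET = 500

feasible = [t for t in study2.best_trials if t.values[1] <= COMPLEXITY_BUDGET]
if feasible:
    chosen = max(feasible, key=lambda t: t.values[0])
    print("Chosen trial:", chosen.number, chosen.values, chosen.params)
else:
    print("No Pareto-optimal trial fits the budget; relax the constraint.")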
Custom Early Stopping Callback for Regression
You can implement callbacks to control study behavior. The example below defines an EarlyStoppingCallback that halts the study when a monitored metric fails to improve for a set number of completed trials.
class EarlyStoppingCallback:
    def __init__(self, early_stopping_rounds=10, direction='maximize'):
        self.early_stopping_rounds = early_stopping_rounds
        self.direction = direction
        self.best_value = float('-inf') if direction == 'maximize' else float('inf')
        self.counter = 0

    def __call__(self, study, trial):
        if trial.state != optuna.trial.TrialState.COMPLETE:
            return
        v = trial.value
        if self.direction == 'maximize':
            if v > self.best_value:
                self.best_value, self.counter = v, 0
            else:
                self.counter += 1
        else:
            if v < self.best_value:
                self.best_value, self.counter = v, 0
            else:
                self.counter += 1
        if self.counter >= self.early_stopping_rounds:
            study.stop()

from sklearn.linear_model import Ridge

def objective_regression(trial):
    X, y = load_diabetes(return_X_y=True)
    alpha = trial.suggest_float('alpha', 1e-3, 10.0, log=True)
    max_iter = trial.suggest_int('max_iter', 100, 2000)
    model = Ridge(alpha=alpha, max_iter=max_iter, random_state=42)
    score = cross_val_score(model, X, y, cv=5, scoring='neg_mean_squared_error', n_jobs=-1).mean()
    return -score

early_stopping = EarlyStoppingCallback(early_stopping_rounds=15, direction='minimize')
study3 = optuna.create_study(direction='minimize', sampler=TPESampler(seed=42))
study3.optimize(objective_regression, n_trials=100, callbacks=[early_stopping], show_progress_bar=True)
print(study3.best_value, study3.best_params)

This callback saves compute by stopping the search once improvements plateau.
Visualization: Interpreting Trials and Pareto Fronts
Visual tools let you inspect trial histories, parameter importances, Pareto fronts, and parameter-metric relationships to understand where gains are coming from.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

ax = axes[0, 0]
values = [t.value for t in study1.trials if t.value is not None]
ax.plot(values, marker='o', markersize=3)
ax.axhline(y=study1.best_value, color='r', linestyle='--')
ax.set_title('Study 1 History')

ax = axes[0, 1]
importance = optuna.importance.get_param_importances(study1)
params = list(importance.keys())[:5]
vals = [importance[p] for p in params]
ax.barh(params, vals)
ax.set_title('Param Importance')

ax = axes[1, 0]
for t in study2.trials:
    if t.values:
        ax.scatter(t.values[0], t.values[1], alpha=0.3)
for t in study2.best_trials:
    ax.scatter(t.values[0], t.values[1], c='red', s=90)
ax.set_title('Pareto Front')

ax = axes[1, 1]
pairs = [(t.params['n_estimators'], t.value)
         for t in study1.trials
         if t.value is not None and 'n_estimators' in t.params]
Xv, Yv = zip(*pairs) if pairs else ([], [])
ax.scatter(Xv, Yv, alpha=0.6)
ax.set_title('n_estimators vs Accuracy')

plt.tight_layout()
plt.savefig('optuna_analysis.png', dpi=150)
plt.show()

These plots provide at-a-glance insight into optimization progress and parameter influence.
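Optuna also ships built-in plotting helpers that cover the same views with less code. A minimal sketch is shown below; the optuna.visualization module returns plotly figures and assumes plotly is installed, with matplotlib-based equivalents available under optuna.visualization.matplotlib.

# A minimal sketch using Optuna's built-in plotting helpers (plotly-based).
import optuna.visualization as vis

vis.plot_optimization_history(study1).show()   # trial history for the pruned study
vis.plot_param_importances(study1).show()      # hyperparameter importances
vis.plot_pareto_front(study2).show()           # Pareto front of the multi-objective study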
Aggregate Results and Next Steps
Summarize key metrics across studies to gauge efficiency: the pruning rate, the number of Pareto solutions, and the best regression MSE. Use these summaries to decide whether to extend the search, change samplers, or refine the search space for production runs.
p1 = len([t for t in study1.trials if t.state == optuna.trial.TrialState.PRUNED])
print("Study 1 Best Accuracy:", study1.best_value)
print("Study 1 Pruned %:", p1 / len(study1.trials) * 100)
print("Study 2 Pareto Solutions:", len(study2.best_trials))
print("Study 3 Best MSE:", study3.best_value)
print("Study 3 Trials:", len(study3.trials))Follow-up actions include expanding datasets, integrating deep learning models, or adapting the callback logic for custom model training loops.