Why A/B Testing ML Models Differs from Traditional A/B Testing
Traditional A/B tests compare static UI changes (button colors, headlines). ML A/B testing is fundamentally different:
- Non-deterministic: Same input may produce different outputs
- Continuous Learning: Models retrain, behavior evolves
- Complex Metrics: Accuracy, latency, fairness, business KPIs
- Long-term Effects: Model changes impact future data distribution
Example: A recommendation model that increases click-through rate might decrease long-term engagement by showing clickbait.
A/B Testing Framework for ML
1. Hypothesis Formation
from collections import defaultdict

class MLExperiment:
    def __init__(self, name, hypothesis, success_criteria):
        self.name = name
        self.hypothesis = hypothesis
        self.success_criteria = success_criteria
        self.variants = {}

    def add_variant(self, name, model, traffic_allocation):
        self.variants[name] = {
            'model': model,
            'traffic_allocation': traffic_allocation,
            'metrics': defaultdict(list)
        }

# Example
experiment = MLExperiment(
    name="ranking_model_v2",
    hypothesis="New transformer model will increase click-through rate by 5%",
    success_criteria={
        'ctr_increase': 0.05,
        'latency_p95': 200,  # ms
        'min_statistical_significance': 0.95
    }
)
experiment.add_variant('control', model_v1, traffic_allocation=0.5)
experiment.add_variant('treatment', model_v2, traffic_allocation=0.5)
2. Traffic Splitting
import hashlib

class TrafficSplitter:
    def __init__(self, experiment):
        self.experiment = experiment

    def assign_variant(self, user_id):
        """Consistent hash-based assignment"""
        hash_value = hashlib.md5(
            f"{self.experiment.name}:{user_id}".encode()
        ).hexdigest()
        hash_int = int(hash_value, 16)
        threshold = hash_int % 100
        cumulative = 0
        for variant_name, variant in self.experiment.variants.items():
            cumulative += variant['traffic_allocation'] * 100
            if threshold < cumulative:
                return variant_name, variant['model']
        return 'control', self.experiment.variants['control']['model']

# Usage
splitter = TrafficSplitter(experiment)
variant, model = splitter.assign_variant(user_id="user_12345")
prediction = model.predict(features)
3. Metric Collection
import numpy as np

class MetricsCollector:
    def __init__(self, experiment):
        self.experiment = experiment

    def log_interaction(self, variant_name, metrics):
        """Log user interaction metrics"""
        for metric_name, value in metrics.items():
            self.experiment.variants[variant_name]['metrics'][metric_name].append(value)

    def calculate_summary_stats(self):
        """Calculate aggregate metrics per variant"""
        results = {}
        for variant_name, variant in self.experiment.variants.items():
            metrics_data = variant['metrics']
            results[variant_name] = {
                'ctr': np.mean(metrics_data['clicked']),
                'avg_latency': np.mean(metrics_data['latency_ms']),
                'p95_latency': np.percentile(metrics_data['latency_ms'], 95),
                'revenue_per_user': np.mean(metrics_data['revenue'])
            }
        return results

# Example logging
collector = MetricsCollector(experiment)

# User interaction with control
collector.log_interaction('control', {
    'clicked': 1,
    'latency_ms': 150,
    'revenue': 5.99
})

# User interaction with treatment
collector.log_interaction('treatment', {
    'clicked': 1,
    'latency_ms': 180,
    'revenue': 7.50
})

summary = collector.calculate_summary_stats()
4. Statistical Significance Testing
from scipy import stats

class SignificanceTester:
    def __init__(self, alpha=0.05):
        self.alpha = alpha

    def t_test(self, control_data, treatment_data):
        """Two-sample t-test for continuous metrics"""
        t_stat, p_value = stats.ttest_ind(control_data, treatment_data)
        return {
            't_statistic': t_stat,
            'p_value': p_value,
            'is_significant': p_value < self.alpha,
            'control_mean': np.mean(control_data),
            'treatment_mean': np.mean(treatment_data),
            'relative_lift': (np.mean(treatment_data) - np.mean(control_data)) / np.mean(control_data)
        }

    def chi_square_test(self, control_conversions, control_total, treatment_conversions, treatment_total):
        """Chi-square test for binary metrics (clicks, conversions)"""
        contingency_table = np.array([
            [control_conversions, control_total - control_conversions],
            [treatment_conversions, treatment_total - treatment_conversions]
        ])
        chi2, p_value, dof, expected = stats.chi2_contingency(contingency_table)
        control_rate = control_conversions / control_total
        treatment_rate = treatment_conversions / treatment_total
        return {
            'chi2_statistic': chi2,
            'p_value': p_value,
            'is_significant': p_value < self.alpha,
            'control_rate': control_rate,
            'treatment_rate': treatment_rate,
            'relative_lift': (treatment_rate - control_rate) / control_rate
        }

    def sequential_testing(self, control_data, treatment_data, looks=10):
        """Sequential testing with an alpha spending function (simplified illustration)"""
        # Approximate O'Brien-Fleming-style bounds to control Type I error across interim looks
        from scipy.stats import norm
        alpha_spend = []
        for k in range(1, looks + 1):
            z = norm.ppf(1 - self.alpha / (2 * looks))
            alpha_k = 2 * (1 - norm.cdf(z * np.sqrt(looks / k)))
            alpha_spend.append(alpha_k)

        # Perform the test and compare against each look's spending threshold
        t_result = self.t_test(control_data, treatment_data)
        for k, alpha_k in enumerate(alpha_spend):
            if t_result['p_value'] < alpha_k:
                return {
                    **t_result,
                    'can_stop_early': True,
                    'stopped_at_look': k + 1
                }
        return {**t_result, 'can_stop_early': False}

# Example
tester = SignificanceTester(alpha=0.05)

control_latencies = experiment.variants['control']['metrics']['latency_ms']
treatment_latencies = experiment.variants['treatment']['metrics']['latency_ms']

latency_result = tester.t_test(control_latencies, treatment_latencies)
print(f"Latency lift: {latency_result['relative_lift']:.2%}")
print(f"Statistically significant: {latency_result['is_significant']}")

# Binary conversion testing
ctr_result = tester.chi_square_test(
    control_conversions=1250,
    control_total=10000,
    treatment_conversions=1400,
    treatment_total=10000
)
print(f"CTR lift: {ctr_result['relative_lift']:.2%}")
Online vs. Offline Evaluation
Offline Evaluation
from sklearn.metrics import roc_auc_score

class OfflineEvaluator:
    def __init__(self, historical_data):
        self.data = historical_data

    def holdout_validation(self, model_old, model_new):
        """Evaluate on held-out data"""
        X_test, y_test = self.data['X_test'], self.data['y_test']
        old_predictions = model_old.predict(X_test)
        new_predictions = model_new.predict(X_test)
        return {
            'old_model_auc': roc_auc_score(y_test, old_predictions),
            'new_model_auc': roc_auc_score(y_test, new_predictions),
            'auc_improvement': roc_auc_score(y_test, new_predictions) - roc_auc_score(y_test, old_predictions)
        }

    def replay_evaluation(self, model, logged_data):
        """Replay historical data through the new model"""
        # Counterfactual estimate of the new policy's performance from logged data
        propensity_scores = logged_data['propensity_scores']
        rewards = logged_data['rewards']
        actions = logged_data['actions']
        new_actions = model.predict(logged_data['features'])

        # Inverse propensity scoring
        ips_estimate = np.mean([
            (new_actions[i] == actions[i]) * rewards[i] / propensity_scores[i]
            for i in range(len(rewards))
        ])
        return {'estimated_reward': ips_estimate}
Online Evaluation
class OnlineEvaluator:
    def __init__(self, experiment):
        self.experiment = experiment
        self.monitoring_metrics = defaultdict(list)

    def monitor_real_time(self):
        """Real-time monitoring with alerts"""
        # calculate_current_performance() is assumed to aggregate live per-variant metrics
        summary = self.calculate_current_performance()

        # Check for regressions
        alerts = []
        if summary['treatment']['error_rate'] > summary['control']['error_rate'] * 1.5:
            alerts.append({
                'severity': 'CRITICAL',
                'message': f"Treatment error rate 50% higher: {summary['treatment']['error_rate']:.3f}"
            })
        if summary['treatment']['p99_latency'] > 500:  # SLA breach
            alerts.append({
                'severity': 'HIGH',
                'message': f"Treatment P99 latency exceeds SLA: {summary['treatment']['p99_latency']:.0f}ms"
            })
        return {'summary': summary, 'alerts': alerts}

    def guardrail_check(self):
        """Ensure critical metrics don't degrade"""
        guardrails = {
            'error_rate_increase': 0.10,   # Max 10% increase
            'latency_p95_increase': 0.20,  # Max 20% increase
            'revenue_decrease': 0.05       # Max 5% decrease
        }
        violations = []
        control = self.experiment.variants['control']['metrics']
        treatment = self.experiment.variants['treatment']['metrics']

        error_increase = (np.mean(treatment['errors']) - np.mean(control['errors'])) / np.mean(control['errors'])
        if error_increase > guardrails['error_rate_increase']:
            violations.append(f"Error rate guardrail violated: +{error_increase:.1%}")
        # Latency and revenue guardrails would be checked the same way

        return {'passed': len(violations) == 0, 'violations': violations}
Advanced Techniques
1. Multi-Armed Bandits
class ThompsonSampling:
    """Adaptive traffic allocation based on performance"""
    def __init__(self, variants):
        self.variants = {
            name: {'alpha': 1, 'beta': 1}  # Beta distribution parameters
            for name in variants
        }

    def select_variant(self):
        """Thompson sampling: sample from posterior distributions"""
        samples = {
            name: np.random.beta(params['alpha'], params['beta'])
            for name, params in self.variants.items()
        }
        return max(samples, key=samples.get)

    def update(self, variant_name, reward):
        """Update posterior based on observed reward"""
        if reward > 0:
            self.variants[variant_name]['alpha'] += 1
        else:
            self.variants[variant_name]['beta'] += 1

# Usage
bandit = ThompsonSampling(['model_v1', 'model_v2', 'model_v3'])

for user in users:
    selected_variant = bandit.select_variant()
    prediction = models[selected_variant].predict(user.features)
    reward = user.interact(prediction)
    bandit.update(selected_variant, reward)
2. Interleaved Testing
class InterleavedTest:
    """Present results from both models, track which users prefer"""
    def __init__(self, model_a, model_b):
        self.model_a = model_a
        self.model_b = model_b

    def team_draft_interleaving(self, query, k=10):
        """Team-draft interleaving for ranking models"""
        results_a = self.model_a.rank(query, top_k=k * 2)
        results_b = self.model_b.rank(query, top_k=k * 2)

        interleaved = []
        used_a, used_b = set(), set()
        for i in range(k):
            # Alternate picking from each model
            if i % 2 == 0:
                # Pick from A
                for item in results_a:
                    if item not in used_a and item not in used_b:
                        interleaved.append({'item': item, 'source': 'A'})
                        used_a.add(item)
                        break
            else:
                # Pick from B
                for item in results_b:
                    if item not in used_a and item not in used_b:
                        interleaved.append({'item': item, 'source': 'B'})
                        used_b.add(item)
                        break
        return interleaved

    def evaluate_clicks(self, interleaved_results, clicks):
        """Determine which model won based on clicks"""
        clicks_a = sum(1 for i in clicks if interleaved_results[i]['source'] == 'A')
        clicks_b = sum(1 for i in clicks if interleaved_results[i]['source'] == 'B')
        return {
            'model_a_wins': clicks_a,
            'model_b_wins': clicks_b,
            'winner': 'A' if clicks_a > clicks_b else 'B' if clicks_b > clicks_a else 'TIE'
        }
Rollout Strategies
class GradualRollout:
    def __init__(self, experiment):
        self.experiment = experiment
        self.current_allocation = 0.05  # Start with 5%

    def should_increase_traffic(self, hours_running, metrics):
        """Decide whether to increase traffic to treatment"""
        if hours_running < 24:
            return False  # Wait at least 24 hours
        if not metrics['guardrails_passing']:
            return False
        if metrics['statistical_significance'] and metrics['positive_lift']:
            return True
        return False

    def increase_allocation(self, increment=0.10):
        """Gradually increase treatment traffic"""
        self.current_allocation = min(1.0, self.current_allocation + increment)
        return self.current_allocation

# Rollout stages
# Stage 1: 5% for 24-48 hours
# Stage 2: 20% for 48 hours (if metrics good)
# Stage 3: 50% for 48 hours
# Stage 4: 100% (full rollout)
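One way to wire those stages into the rollout controller is a simple schedule that it advances through. The sketch below is illustrative only: the stage durations, the health check, and the next_allocation helper are assumptions, not part of the GradualRollout class above.

# Hypothetical staged-rollout schedule mirroring the stages listed above
ROLLOUT_STAGES = [
    {'allocation': 0.05, 'min_hours': 24},
    {'allocation': 0.20, 'min_hours': 48},
    {'allocation': 0.50, 'min_hours': 48},
    {'allocation': 1.00, 'min_hours': 0},  # full rollout
]

def next_allocation(current_stage, hours_in_stage, metrics):
    """Advance to the next stage only if the current stage has run long enough
    and the experiment looks healthy (guardrails passing, positive lift)."""
    stage = ROLLOUT_STAGES[current_stage]
    healthy = metrics['guardrails_passing'] and metrics['positive_lift']
    if hours_in_stage >= stage['min_hours'] and healthy and current_stage + 1 < len(ROLLOUT_STAGES):
        new_stage = current_stage + 1
        return new_stage, ROLLOUT_STAGES[new_stage]['allocation']
    return current_stage, stage['allocation']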
Best Practices
| Practice | Description |
|---|---|
| Define Success Metrics Upfront | Primary (CTR), secondary (revenue), guardrail (latency, errors) |
| Calculate Required Sample Size | Use power analysis to avoid under-powered tests (see the sketch below) |
| Run for Full Business Cycles | Account for weekly/monthly seasonality |
| Monitor Guardrails | Automatically halt if critical metrics degrade |
| Test One Change at a Time | Isolate what causes performance differences |
| Log Everything | Enable post-hoc analysis and debugging |
| Gradual Rollout | Start with small traffic, expand if successful |
| Interleaving for Rankings | More sensitive than A/B for search/recommendation |
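As a rough illustration of the power-analysis practice in the table, the sketch below estimates the per-variant sample size needed to detect a relative CTR lift with a two-proportion z-test. The baseline rate and lift passed in the call are hypothetical values, not results from the experiment above.

from math import ceil
from scipy.stats import norm

def required_sample_size(baseline_rate, relative_lift, alpha=0.05, power=0.80):
    """Approximate per-variant sample size for a two-proportion z-test."""
    p1 = baseline_rate
    p2 = baseline_rate * (1 + relative_lift)
    z_alpha = norm.ppf(1 - alpha / 2)  # two-sided significance threshold
    z_beta = norm.ppf(power)           # desired statistical power
    variance = p1 * (1 - p1) + p2 * (1 - p2)
    return ceil(((z_alpha + z_beta) ** 2 * variance) / (p2 - p1) ** 2)

# Hypothetical example: detect a 5% relative CTR lift from a 12.5% baseline
print(required_sample_size(baseline_rate=0.125, relative_lift=0.05))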
Conclusion
A/B testing ML models requires rigorous statistical methods, real-time monitoring, and business metric alignment. Unlike static features, ML experiments involve complex interactions, long-term effects, and evolving behavior.
Success comes from combining offline validation, carefully designed online experiments, statistical rigor, guardrail monitoring, and gradual rollouts. The goal isn’t just deploying better models—it’s learning faster and iterating confidently.