The Flaky Test Problem

Flaky tests are the plague of modern test automation. They pass and fail intermittently without any code changes, eroding confidence in test suites and wasting engineering hours on investigation. Industry reports suggest that 15-25% of tests in large codebases exhibit some flaky behavior, and that teams spend 10-30% of QA time debugging false failures.

Traditional approaches to detecting flaky tests rely on re-running tests multiple times and manually analyzing patterns. This is slow, expensive, and reactive. Machine Learning offers a proactive solution: predict which tests are likely to be flaky before they cause problems, identify root causes automatically, and suggest fixes.

This article explores how to leverage ML for flaky test detection, with practical algorithms, implementation examples, and proven strategies for building stable test suites.

Understanding Flaky Tests

Types of Flakiness

Type | Description | Example
--- | --- | ---
Order-Dependent | Test outcome depends on execution order | Test A passes alone but fails after Test B
Async Wait Issues | Race conditions, insufficient waits | Button click before element is clickable
Resource Leaks | Shared state not cleaned up | Database connection left open
Concurrency Issues | Multi-threading, parallel execution | Race conditions in parallel test runs
Infrastructure Flakes | Network latency, external dependencies | API timeout, 3rd-party service downtime
Non-Deterministic Code | Random data, timestamps, UUIDs | Test expects specific UUID value
Platform-Specific | OS, browser, environment differences | Works on Chrome, fails on Safari
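The Order-Dependent row is easiest to see in code. A minimal sketch, using plain functions as stand-ins for test functions (the names are illustrative) and a module-level dict as the shared state that leaks between them:

```python
# Module-level state shared by both "tests": the kind of leak that makes
# outcomes depend on execution order
_cache = {}

def b_populates_cache():
    # Side effect: leaves an entry behind for whatever runs next
    _cache["user"] = "alice"
    assert _cache["user"] == "alice"

def a_expects_empty_cache():
    # Silently assumes no earlier test has touched _cache
    assert _cache == {}

def run_in_order(tests, reset_state=False):
    """Run the given functions in order and return {name: passed}."""
    _cache.clear()
    results = {}
    for test in tests:
        if reset_state:
            _cache.clear()  # what a per-test fixture or teardown would do
        try:
            test()
            results[test.__name__] = True
        except AssertionError:
            results[test.__name__] = False
    return results

print(run_in_order([a_expects_empty_cache, b_populates_cache]))  # both pass
print(run_in_order([b_populates_cache, a_expects_empty_cache]))  # A fails
print(run_in_order([b_populates_cache, a_expects_empty_cache], reset_state=True))  # both pass
```

Neither function is wrong in isolation; the failure only appears in one ordering. Running the suite in a shuffled order (for example with the pytest-randomly plugin) is a cheap way to surface this class of flake.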

Cost of Flaky Tests

Direct costs:

  • Investigation time: 2-8 hours per flaky test
  • CI/CD pipeline delays: Average 15-minute re-run delay
  • False confidence: Ignoring real failures masked by flakiness

Indirect costs:

  • Developer frustration and decreased morale
  • Reduced trust in test suite → skipping tests
  • Delayed releases due to uncertainty about failures

Machine Learning Approaches to Flaky Test Detection

1. Supervised Learning: Classification Models

Train a model to classify tests as “flaky” or “stable” based on historical data and code features.

Feature Engineering

Execution history features:

execution_features = {
    'pass_rate': 0.73,  # Passes 73% of time
    'consecutive_failures': 2,
    'failure_variability': 0.45,  # High variance in outcomes
    'avg_execution_time': 12.3,
    'execution_time_stddev': 4.2,  # High time variance
    'last_10_outcomes': [1,1,0,1,1,0,1,1,1,0],  # 1=pass, 0=fail
}
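The dictionary above shows finished feature values. A minimal sketch of deriving them from raw run data, assuming outcomes are stored oldest first; the exact definitions are assumptions, in particular `failure_variability` is taken here to be the fraction of consecutive runs whose outcome flips, and `consecutive_failures` the longest failure streak:

```python
import statistics

def execution_features(outcomes, durations):
    """Derive execution-history features from raw run data.

    outcomes: list of 1 (pass) / 0 (fail), oldest first
    durations: execution times in seconds, same order
    """
    # Longest run of consecutive failures anywhere in the history
    longest = current = 0
    for o in outcomes:
        current = current + 1 if o == 0 else 0
        longest = max(longest, current)

    # Number of pass->fail or fail->pass switches between adjacent runs
    flips = sum(1 for prev, cur in zip(outcomes, outcomes[1:]) if prev != cur)

    return {
        'pass_rate': sum(outcomes) / len(outcomes),
        'consecutive_failures': longest,
        'failure_variability': flips / max(len(outcomes) - 1, 1),
        'avg_execution_time': statistics.mean(durations),
        'execution_time_stddev': statistics.pstdev(durations),
        'last_10_outcomes': list(outcomes[-10:]),
    }
```

The values in the dictionary above presumably come from a longer history than the ten runs listed; on those ten outcomes alone, this function gives a pass rate of 0.7 and a longest failure streak of 1.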

Code-based features:

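import ast

def _call_name(node):
    """Dotted name of a call, e.g. 'time.sleep' or 'WebDriverWait' ('' if dynamic)."""
    func, parts = node.func, []
    while isinstance(func, ast.Attribute):
        parts.append(func.attr)
        func = func.value
    if isinstance(func, ast.Name):
        parts.append(func.id)
    return '.'.join(reversed(parts))

def extract_code_features(test_code):
    tree = ast.parse(test_code)
    nodes = list(ast.walk(tree))
    # Walk the AST so names inside comments, strings, or longer identifiers
    # (e.g. 'random_state') are not miscounted as calls
    calls = [_call_name(n) for n in nodes if isinstance(n, ast.Call)]

    return {
        'uses_sleep': 'time.sleep' in calls,
        'uses_random': any(c.startswith('random.') for c in calls),
        'async_count': sum(isinstance(n, (ast.AsyncFunctionDef, ast.AsyncWith, ast.AsyncFor)) for n in nodes),
        'network_calls': sum(c.startswith(('requests.', 'httpx.')) for c in calls),
        'database_queries': sum(c.split('.')[-1] == 'execute' for c in calls),
        'wait_statements': sum(c.split('.')[-1] == 'WebDriverWait' for c in calls),
        'thread_usage': sum(c.split('.')[-1] == 'Thread' for c in calls),
        'external_deps': 'mock' not in test_code,  # True if nothing is mocked
        'assertion_count': sum(isinstance(n, ast.Assert) for n in nodes)
                           + sum(c.split('.')[-1].startswith('assert') for c in calls),
        'test_length': len(test_code.splitlines()),
    }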

Environment features:

environment_features = {
    'os': 'linux',
    'browser': 'chrome',
    'parallel_execution': True,
    'ci_system': 'github_actions',
    'python_version': '3.11',
    'dependency_count': 147,
}
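The classifier below needs an `is_flaky` label for every test. One common working definition is that a test is flaky if it both passed and failed on the same code revision, since no code change can explain the difference. A minimal sketch of deriving labels that way, assuming each run record carries `test_id`, `commit` (a revision identifier), and `outcome`; the field names are illustrative:

```python
from collections import defaultdict

def label_flaky_tests(runs):
    """Label each test 1 (flaky) or 0 (stable).

    runs: iterable of dicts with 'test_id', 'commit', and 'outcome'
    (1 = pass, 0 = fail). A test is labeled flaky if, for at least one
    commit, it has both a passing and a failing run.
    """
    outcomes_by_key = defaultdict(set)
    for run in runs:
        outcomes_by_key[(run['test_id'], run['commit'])].add(run['outcome'])

    labels = {}
    for (test_id, _commit), outcomes in outcomes_by_key.items():
        mixed = len(outcomes) > 1  # saw both pass and fail on the same code
        labels[test_id] = max(labels.get(test_id, 0), int(mixed))
    return labels

runs = [
    {'test_id': 'test_login', 'commit': 'a1', 'outcome': 1},
    {'test_id': 'test_login', 'commit': 'a1', 'outcome': 0},   # same commit, different result
    {'test_id': 'test_search', 'commit': 'a1', 'outcome': 1},
    {'test_id': 'test_search', 'commit': 'b2', 'outcome': 0},  # failed only after a new commit
]
print(label_flaky_tests(runs))  # {'test_login': 1, 'test_search': 0}
```

`test_search` is not labeled flaky because its failure coincides with a new commit and may be a genuine regression. This definition is conservative: it misses flaky tests that were never rerun on the same revision.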

Training a Random Forest Classifier

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
import pandas as pd

# Load labeled dataset
# Format: test_id, features..., is_flaky (0=stable, 1=flaky)
data = pd.read_csv('test_history.csv')

# Separate features and labels
X = data.drop(['test_id', 'is_flaky'], axis=1)
# One-hot encode categorical columns such as 'os', 'browser', 'ci_system';
# RandomForestClassifier cannot consume raw strings
X = pd.get_dummies(X)
y = data['is_flaky']

# Split data; stratify so the (usually small) flaky class appears in both sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Train model; class_weight='balanced' offsets the stable/flaky imbalance
model = RandomForestClassifier(
    n_estimators=100,
    max_depth=10,
    class_weight='balanced',
    random_state=42
)
model.fit(X_train, y_train)

# Evaluate
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))

# Feature importance
feature_importance = pd.DataFrame({
    'feature': X.columns,
    'importance': model.feature_importances_
}).sort_values('importance', ascending=False)

print("\nTop flaky test indicators:")
print(feature_importance.head(10))

Example output:

Top flaky test indicators:
feature                    importance
pass_rate                  0.28
uses_sleep                 0.15
execution_time_stddev      0.12
network_calls              0.11
parallel_execution         0.09
failure_variability        0.08
async_count                0.06
...

2. Unsupervised Learning: Anomaly Detection

Identify unusual test behavior patterns without labeled data.

Isolation Forest for Flake Detection

from sklearn.ensemble import IsolationForest
import numpy as np
import pandas as pd

class FlakyTestAnomalyDetector:
    def __init__(self, contamination=0.1):
        """
        contamination: Expected proportion of flaky tests (10% default)
        """
        self.model = IsolationForest(
            contamination=contamination,
            random_state=42
        )

    def fit(self, test_execution_history):
        """
        test_execution_history: DataFrame with columns:
          - test_id, execution_time, outcome, timestamp, etc.
        """
        features = self._extract_time_series_features(test_execution_history)
        self.model.fit(features)

        return self

    def _extract_time_series_features(self, history):
        """Extract statistical features from execution history."""
        features = []

        for test_id in history['test_id'].unique():
            # Sort chronologically so outcome changes are counted in run order
            test_data = history[history['test_id'] == test_id].sort_values('timestamp')

            # Calculate statistics
            outcomes = test_data['outcome'].values  # 1=pass, 0=fail
            times = test_data['execution_time'].values

            features.append({
                'pass_rate': np.mean(outcomes),
                'pass_rate_variance': np.var(outcomes),
                'execution_time_mean': np.mean(times),
                'execution_time_variance': np.var(times),
                'outcome_entropy': self._calculate_entropy(outcomes),
                'consecutive_change_count': self._count_outcome_changes(outcomes),
            })

        return pd.DataFrame(features)

    def _calculate_entropy(self, outcomes):
        """Shannon entropy of outcome sequence."""
        from scipy.stats import entropy
        _, counts = np.unique(outcomes, return_counts=True)
        return entropy(counts)

    def _count_outcome_changes(self, outcomes):
        """Count how many times outcome switches (pass→fail or fail→pass)."""
        return np.sum(np.diff(outcomes) != 0)

    def predict_flaky(self, test_execution_history):
        """
        Returns: DataFrame with test_id, anomaly_score, is_flaky
        """
        features = self._extract_time_series_features(test_execution_history)

        # -1 = anomaly (flaky), 1 = normal (stable)
        predictions = self.model.predict(features)
        anomaly_scores = self.model.score_samples(features)

        results = pd.DataFrame({
            'test_id': test_execution_history['test_id'].unique(),
            'anomaly_score': anomaly_scores,
            'is_flaky': predictions == -1
        })

        return results.sort_values('anomaly_score')

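# Usage
# historical_data / recent_executions: DataFrames with test_id, outcome,
# execution_time and timestamp columns (one row per test run)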
detector = FlakyTestAnomalyDetector(contamination=0.15)
detector.fit(historical_data)

flaky_tests = detector.predict_flaky(recent_executions)
print(flaky_tests[flaky_tests['is_flaky']])
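Two of the features above, `outcome_entropy` and `consecutive_change_count`, are what separate an intermittent test from a consistently broken one: a test that always fails has a low pass rate but zero entropy and no outcome changes. A stdlib-only sketch of the same two statistics on three illustrative histories (entropy here is in bits, whereas `scipy.stats.entropy` defaults to the natural log):

```python
import math
from collections import Counter

def outcome_entropy(outcomes):
    """Shannon entropy, in bits, of the pass/fail distribution."""
    n = len(outcomes)
    return sum((c / n) * math.log2(n / c) for c in Counter(outcomes).values())

def outcome_changes(outcomes):
    """Number of pass->fail or fail->pass transitions between consecutive runs."""
    return sum(1 for prev, cur in zip(outcomes, outcomes[1:]) if prev != cur)

histories = {
    'stable': [1] * 10,
    'broken': [0] * 10,
    'flaky':  [1, 0, 1, 1, 0, 1, 0, 1, 1, 0],
}
for name, h in histories.items():
    print(f"{name:7s} pass_rate={sum(h) / len(h):.1f} "
          f"entropy={outcome_entropy(h):.2f} changes={outcome_changes(h)}")
```

An Isolation Forest may still flag the "broken" history as an anomaly because its pass rate is unusual, so tests that fail on every run are better triaged as real failures than as flakes.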

3. Time Series Analysis: Failure Pattern Recognition

Analyze temporal patterns in test outcomes.

LSTM for Outcome Prediction

import numpy as np
import tensorflow as tf
from tensorflow import keras

class FlakyTestPredictor:
    def __init__(self, sequence_length=20):
        self.sequence_length = sequence_length
        self.model = self._build_model()

    def _build_model(self):
        """LSTM model to predict next test outcome."""
        model = keras.Sequential([
            keras.Input(shape=(self.sequence_length, 1)),
            keras.layers.LSTM(64),
            keras.layers.Dropout(0.2),
            keras.layers.Dense(32, activation='relu'),
            keras.layers.Dense(1, activation='sigmoid')
        ])

        model.compile(
            optimizer='adam',
            loss='binary_crossentropy',
            metrics=['accuracy']
        )

        return model

    def prepare_sequences(self, outcomes):
        """
        Convert outcome list to sequences for LSTM.
        outcomes: [1,1,0,1,1,0,1,...]  (1=pass, 0=fail)
        """
        X, y = [], []

        for i in range(len(outcomes) - self.sequence_length):
            X.append(outcomes[i:i+self.sequence_length])
            y.append(outcomes[i+self.sequence_length])

        return np.array(X).reshape(-1, self.sequence_length, 1), np.array(y)

    def train(self, test_outcomes):
        """Train on historical outcomes."""
        X, y = self.prepare_sequences(test_outcomes)

        self.model.fit(
            X, y,
            epochs=50,
            batch_size=32,
            validation_split=0.2,
            verbose=0
        )

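    def predict_stability(self, recent_outcomes, n_samples=30):
        """
        Estimate the next-run pass probability and its uncertainty.
        Returns: dict of results, or None if there is not enough history
        """
        if len(recent_outcomes) < self.sequence_length:
            return None

        X = np.array(recent_outcomes[-self.sequence_length:], dtype='float32')
        X = X.reshape(1, self.sequence_length, 1)

        # Point estimate: dropout is disabled at inference, so this is deterministic
        prediction = float(self.model.predict(X, verbose=0)[0][0])

        # Monte Carlo dropout: calling the model with training=True keeps
        # dropout active, so repeated passes differ and their spread
        # approximates the model's uncertainty. (Repeating model.predict()
        # would return the same value every time, giving zero variance.)
        samples = [float(self.model(X, training=True)[0][0]) for _ in range(n_samples)]
        prediction_variance = float(np.var(samples))

        return {
            'next_pass_probability': prediction,
            'prediction_variance': prediction_variance,
            # Heuristic thresholds; tune them on your own data
            'likely_flaky': prediction_variance > 0.05 or (0.3 < prediction < 0.7)
        }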

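# Usage
# With sequence_length=20, a 20-run history yields no training sequences
# (range(20 - 20) is empty), so this toy example uses a 5-run window.
# Real histories should be much longer than the window.
predictor = FlakyTestPredictor(sequence_length=5)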

# Historical outcomes for a specific test
test_history = [1,1,1,0,1,1,0,1,1,1,0,0,1,1,0,1,1,1,1,0]  # 1=pass, 0=fail

predictor.train(test_history)

result = predictor.predict_stability(test_history)
print(f"Next pass probability: {result['next_pass_probability']:.2f}")
print(f"Likely flaky: {result['likely_flaky']}")
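An LSTM needs far more history than most individual tests accumulate. A cheaper baseline for the same temporal question, offered here as an alternative rather than as the article's method, is a first-order Markov model: estimate how often a pass is followed by a failure and how often a failure is followed by a pass. A stable test has both rates near zero; a flaky test keeps switching.

```python
def transition_rates(outcomes):
    """Estimate P(fail | previous pass) and P(pass | previous fail).

    outcomes: list of 1 (pass) / 0 (fail), oldest first.
    A rate is None when its preceding state never occurs.
    """
    counts = {(a, b): 0 for a in (0, 1) for b in (0, 1)}
    for prev, cur in zip(outcomes, outcomes[1:]):
        counts[(prev, cur)] += 1

    after_pass = counts[(1, 0)] + counts[(1, 1)]
    after_fail = counts[(0, 0)] + counts[(0, 1)]
    return {
        'p_fail_after_pass': counts[(1, 0)] / after_pass if after_pass else None,
        'p_pass_after_fail': counts[(0, 1)] / after_fail if after_fail else None,
    }

test_history = [1,1,1,0,1,1,0,1,1,1,0,0,1,1,0,1,1,1,1,0]
print(transition_rates(test_history))
```

On this 20-run history the rates come out to about 0.36 and 0.8: a failure is usually followed by a pass, which points to intermittent failures rather than a regression (a regression would keep failing).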

4. Root Cause Analysis with NLP

Use Natural Language Processing to analyze test logs and identify flakiness patterns.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans

class FlakyRootCauseAnalyzer:
    def __init__(self, n_clusters=5):
        self.vectorizer = TfidfVectorizer(max_features=100)
        self.n_clusters = n_clusters

    def analyze_failure_messages(self, failure_logs):
        """
        Cluster similar failure messages to identify common root causes.

        failure_logs: List of dicts with 'test_id' and 'error_message'
        """
        messages = [log['error_message'] for log in failure_logs]

        # Convert error messages to feature vectors
        X = self.vectorizer.fit_transform(messages)

        # Cluster similar errors; KMeans needs at least as many samples as
        # clusters, so cap n_clusters for small batches of failures
        self.clustering = KMeans(
            n_clusters=min(self.n_clusters, len(messages)),
            n_init=10,
            random_state=42,
        )
        clusters = self.clustering.fit_predict(X)

        # Analyze clusters
        results = []
        for cluster_id in range(self.clustering.n_clusters):
            cluster_logs = [
                failure_logs[i] for i in range(len(failure_logs))
                if clusters[i] == cluster_id
            ]

            # Extract common keywords
            cluster_messages = [log['error_message'] for log in cluster_logs]
            common_terms = self._extract_common_terms(cluster_messages)

            results.append({
                'cluster_id': cluster_id,
                'test_count': len(cluster_logs),
                'common_terms': common_terms,
                'tests': [log['test_id'] for log in cluster_logs],
                'likely_cause': self._infer_cause(common_terms)
            })

        return results

    def _extract_common_terms(self, messages):
        """Extract most frequent terms in cluster."""
        vectorizer = TfidfVectorizer(max_features=5)
        vectorizer.fit(messages)
        return vectorizer.get_feature_names_out().tolist()

    def _infer_cause(self, terms):
        """Infer likely root cause from common terms."""
        cause_patterns = {
            'timeout': ['timeout', 'wait', 'deadline', 'hung'],
            'race_condition': ['concurrent', 'thread', 'lock', 'race'],
            'network': ['connection', 'network', 'socket', 'dns'],
            'resource': ['memory', 'disk', 'cpu', 'resource'],
            'state': ['cleanup', 'state', 'reset', 'teardown'],
        }

        for cause, keywords in cause_patterns.items():
            if any(keyword in term.lower() for term in terms for keyword in keywords):
                return cause

        return 'unknown'

# Usage
analyzer = FlakyRootCauseAnalyzer()

failure_logs = [
    {'test_id': 'test_login', 'error_message': 'Timeout waiting for element #submit'},
    {'test_id': 'test_checkout', 'error_message': 'Element not found after 10s timeout'},
    {'test_id': 'test_api', 'error_message': 'Connection timeout to api.example.com'},
    # ... more logs
]

causes = analyzer.analyze_failure_messages(failure_logs)

for cause in causes:
    print(f"\nCluster {cause['cluster_id']}: {cause['likely_cause']}")
    print(f"Affected tests: {cause['tests']}")
    print(f"Common terms: {cause['common_terms']}")

Practical Implementation

Building a Flaky Test Detection Pipeline

import time

import numpy as np
import pandas as pd

class FlakyTestDetectionPipeline:
    def __init__(self, classifier=None, feature_columns=None):
        # classifier: optional model already trained on the per-test
        # statistics columns listed in feature_columns
        self.classifier = classifier
        self.feature_columns = feature_columns or ['pass_rate', 'variance', 'failures']
        self.anomaly_detector = FlakyTestAnomalyDetector()
        self.root_cause_analyzer = FlakyRootCauseAnalyzer()

    def collect_test_data(self, test_suite, runs=10):
        """Run each test several times and collect execution data.

        Assumes each test exposes .id and .run(), and that run() returns an
        object with .passed, .failed, .duration and .error attributes.
        """
        results = []

        for test in test_suite:
            for _ in range(runs):
                outcome = test.run()
                results.append({
                    'test_id': test.id,
                    'outcome': 1 if outcome.passed else 0,
                    'execution_time': outcome.duration,
                    'error_message': outcome.error if outcome.failed else None,
                    'timestamp': time.time()
                })

        return pd.DataFrame(results)

    def detect_flaky_tests(self, execution_data):
        """Main detection logic combining multiple approaches."""
        # 1. Statistical analysis
        stats = self._calculate_test_statistics(execution_data)

        # 2. ML classification (skipped unless a trained classifier was given)
        ml_predictions = None
        if self.classifier is not None:
            ml_predictions = self._ml_classification(stats)

        # 3. Anomaly detection: the detector must be fitted before predicting
        anomalies = self.anomaly_detector.fit(execution_data).predict_flaky(execution_data)

        # 4. Root cause analysis for failures
        root_causes = []
        failures = execution_data[execution_data['outcome'] == 0]
        if len(failures) > 0:
            root_causes = self.root_cause_analyzer.analyze_failure_messages(
                failures.to_dict('records')
            )

        # Combine results
        flaky_tests = self._combine_predictions(stats, ml_predictions, anomalies)

        return {
            'flaky_tests': flaky_tests,
            'root_causes': root_causes,
            'statistics': stats
        }

    def _ml_classification(self, stats):
        """Predict flakiness per test with the supplied classifier."""
        return self.classifier.predict(stats[self.feature_columns]).astype(bool)

    def _calculate_test_statistics(self, data):
        """Calculate per-test statistics."""
        stats = []

        for test_id in data['test_id'].unique():
            test_data = data[data['test_id'] == test_id]
            outcomes = test_data['outcome'].values

            stats.append({
                'test_id': test_id,
                'pass_rate': np.mean(outcomes),
                'total_runs': len(outcomes),
                'failures': np.sum(outcomes == 0),
                'variance': np.var(outcomes),
                'is_deterministic': len(np.unique(outcomes)) == 1
            })

        return pd.DataFrame(stats)

    def _combine_predictions(self, stats, ml_preds, anomalies):
        """Combine multiple detection methods with voting."""
        results = stats.copy()

        # Simple voting: test is flaky if 2+ methods agree
        results['statistical_flaky'] = (results['pass_rate'] < 0.95) & (results['pass_rate'] > 0.05)
        results['ml_flaky'] = ml_preds if ml_preds is not None else False
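        # predict_flaky() sorts rows by anomaly score, so align on test_id
        # rather than by position
        anomaly_flags = anomalies.set_index('test_id')['is_flaky']
        results['anomaly_flaky'] = results['test_id'].map(anomaly_flags).astype(bool)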

        results['votes'] = (
            results['statistical_flaky'].astype(int) +
            results['ml_flaky'].astype(int) +
            results['anomaly_flaky'].astype(int)
        )

        results['is_flaky'] = results['votes'] >= 2

        return results[results['is_flaky']].sort_values('pass_rate')

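# Usage in CI/CD
# test_suite: iterable of test objects with the interface assumed by
# collect_test_data (.id and .run())
pipeline = FlakyTestDetectionPipeline()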

# Collect data from test runs
execution_data = pipeline.collect_test_data(test_suite)

# Detect flaky tests
detection_results = pipeline.detect_flaky_tests(execution_data)

# Report findings
print(f"Found {len(detection_results['flaky_tests'])} flaky tests:")
for _, test in detection_results['flaky_tests'].iterrows():
    print(f"- {test['test_id']}: {test['pass_rate']:.1%} pass rate")

# Root causes
for cause in detection_results['root_causes']:
    print(f"\n{cause['likely_cause']}: {cause['test_count']} tests affected")
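`collect_test_data` assumes test objects with an `id` and a `run()` method whose result has `passed`, `failed`, `duration`, and `error` attributes; the article does not define that interface. One hypothetical way to satisfy it is to wrap a plain test callable, timing it and recording any exception as a failure (whether unexpected errors should count as failures is a choice you may want to make differently):

```python
import random
import time
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class RunResult:
    passed: bool
    duration: float            # seconds
    error: Optional[str] = None

    @property
    def failed(self) -> bool:
        return not self.passed

class CallableTest:
    """Hypothetical adapter exposing .id and .run() for collect_test_data."""

    def __init__(self, test_id: str, func: Callable[[], None]):
        self.id = test_id
        self.func = func

    def run(self) -> RunResult:
        start = time.perf_counter()
        try:
            self.func()
        except Exception as exc:  # AssertionError or any unexpected error
            return RunResult(False, time.perf_counter() - start,
                             f"{type(exc).__name__}: {exc}")
        return RunResult(True, time.perf_counter() - start)

def coin_flip():
    # Deliberately flaky: fails about half the time
    assert random.random() < 0.5, "unlucky draw"

test_suite = [CallableTest('test_coin_flip', coin_flip)]
```

Passing this `test_suite` to `pipeline.collect_test_data(test_suite)` yields ten runs of `test_coin_flip` with a mix of outcomes.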

Integration with CI/CD

# .github/workflows/flaky-detection.yml
name: Flaky Test Detection

on:
  schedule:
    - cron: '0 2 * * *'  # Daily at 02:00 UTC
  workflow_dispatch:

permissions:
  contents: read
  issues: write  # required to open the issue below

jobs:
  detect-flaky:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Run tests with flaky detection
        # Failing tests must not stop the workflow before analysis runs
        continue-on-error: true
        run: |
          # flaky-detector: substitute whatever your detection script depends on
          pip install flaky-detector pytest pytest-json-report
          pytest --json-report --json-report-file=report.json

      - name: Analyze for flaky tests
        # Expected to write flaky_tests.json and append
        # FLAKY_TESTS_FOUND=true|false to $GITHUB_ENV
        run: |
          python scripts/detect_flaky_tests.py --report report.json

      - name: Create issue for flaky tests
        if: env.FLAKY_TESTS_FOUND == 'true'
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');
            const flaky = JSON.parse(fs.readFileSync('flaky_tests.json', 'utf8'));

            let body = '## 🚨 Flaky Tests Detected\n\n';
            flaky.forEach(test => {
              body += `### ${test.test_id}\n`;
              body += `- Pass rate: ${(test.pass_rate * 100).toFixed(1)}%\n`;
              body += `- Likely cause: ${test.root_cause}\n`;
              body += `- Recommendation: ${test.fix_suggestion}\n\n`;
            });

            await github.rest.issues.create({
              owner: context.repo.owner,
              repo: context.repo.repo,
              title: 'Flaky Tests Detected',
              body: body,
              labels: ['flaky-test', 'quality']
            });
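The workflow calls `scripts/detect_flaky_tests.py` without showing it. Because a single pytest run yields only one outcome per test, such a script has to accumulate history across runs. A hypothetical sketch, assuming pytest-json-report writes a top-level `tests` list whose entries carry `nodeid` and `outcome`, and that the history file (name and thresholds are illustrative) is persisted between runs by a cache or artifact step that is not shown:

```python
"""Hypothetical sketch of scripts/detect_flaky_tests.py."""
import argparse
import json
import os
from pathlib import Path

HISTORY_FILE = Path('flaky_history.json')  # assumed to persist between CI runs
MIN_RUNS = 5                               # illustrative threshold

def update_history(history, report):
    """Append this run's outcome for each test in a pytest-json-report report."""
    for test in report.get('tests', []):
        if test['outcome'] in ('passed', 'failed'):  # ignore skipped, xfail, etc.
            history.setdefault(test['nodeid'], []).append(1 if test['outcome'] == 'passed' else 0)
    return history

def find_flaky(history, min_runs=MIN_RUNS):
    """Tests with enough recorded runs that have both passed and failed."""
    flaky = []
    for test_id, outcomes in history.items():
        pass_rate = sum(outcomes) / len(outcomes)
        if len(outcomes) >= min_runs and 0 < pass_rate < 1:
            flaky.append({
                'test_id': test_id,
                'pass_rate': pass_rate,
                'root_cause': 'unknown',  # fill in from root-cause analysis
                'fix_suggestion': 'investigate; see detection report',
            })
    return sorted(flaky, key=lambda t: t['pass_rate'])

def main(argv=None):
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--report', default='report.json',
                        help='pytest-json-report output file')
    args = parser.parse_args(argv)

    report_path = Path(args.report)
    if not report_path.exists():
        print(f"No report found at {report_path}; nothing to analyze")
        return

    history = json.loads(HISTORY_FILE.read_text()) if HISTORY_FILE.exists() else {}
    history = update_history(history, json.loads(report_path.read_text()))
    HISTORY_FILE.write_text(json.dumps(history))

    flaky = find_flaky(history)
    Path('flaky_tests.json').write_text(json.dumps(flaky, indent=2))
    print(f"{len(flaky)} flaky test(s) found")

    # Later workflow steps read FLAKY_TESTS_FOUND from the GITHUB_ENV file
    github_env = os.environ.get('GITHUB_ENV')
    if github_env:
        with open(github_env, 'a') as fh:
            fh.write(f"FLAKY_TESTS_FOUND={'true' if flaky else 'false'}\n")

if __name__ == '__main__':
    main()
```

The output fields match what the `github-script` step reads (`test_id`, `pass_rate`, `root_cause`, `fix_suggestion`); the placeholder cause and suggestion could be filled in from the root-cause analyzer and `suggest_fixes`.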

Strategies for Fixing Flaky Tests

Automated Fix Suggestions

def suggest_fixes(test_code, failure_patterns=None):
    """Rule-based fix suggestions for common flaky-test code patterns.

    failure_patterns is not used by these simple string checks.
    """
    suggestions = []

    # Pattern: Uses sleep
    if 'time.sleep' in test_code:
        suggestions.append({
            'issue': 'Uses time.sleep (non-deterministic)',
            'fix': 'Replace with explicit WebDriverWait',
            'example': 'WebDriverWait(driver, 10).until(EC.element_to_be_clickable((By.ID, "submit")))'
        })

    # Pattern: No waits before interactions
    if 'click()' in test_code and 'wait' not in test_code.lower():
        suggestions.append({
            'issue': 'Click without explicit wait',
            'fix': 'Add wait before clicking',
            'example': 'wait.until(EC.element_to_be_clickable(element)).click()'
        })

    # Pattern: External API calls without mocking
    if 'requests.get' in test_code and 'mock' not in test_code:
        suggestions.append({
            'issue': 'External API dependency',
            'fix': 'Mock external API calls',
            'example': '@mock.patch("requests.get")\ndef test_api(mock_get): ...'
        })

    # Pattern: Database state dependency
    if 'SELECT' in test_code or 'INSERT' in test_code:
        suggestions.append({
            'issue': 'Database state dependency',
            'fix': 'Use test fixtures with cleanup',
            'example': '@pytest.fixture\ndef clean_db(): yield; db.rollback()'
        })

    return suggestions
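`time.sleep` is the first pattern `suggest_fixes` flags. Outside Selenium, where `WebDriverWait` is not available, the same idea (poll for a condition with a deadline instead of sleeping for a fixed time) fits in a few lines. `wait_until` below is a hypothetical helper, not a library function:

```python
import time

def wait_until(condition, timeout=10.0, interval=0.1):
    """Poll condition() until it returns a truthy value or timeout expires.

    Returns the truthy value; raises TimeoutError otherwise. Unlike a
    fixed time.sleep, this returns as soon as the condition holds.
    """
    deadline = time.monotonic() + timeout
    while True:
        result = condition()
        if result:
            return result
        if time.monotonic() >= deadline:
            raise TimeoutError(f"condition not met within {timeout}s")
        time.sleep(interval)

# Instead of:  time.sleep(5); assert job.status == 'done'
# write:       wait_until(lambda: job.status == 'done', timeout=5)
# (job is a placeholder for whatever asynchronous resource the test polls)
```

A fast environment no longer waits the full five seconds, and a slow one gets up to the timeout instead of failing at an arbitrary fixed delay.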

Measuring Success

Key Metrics

Metric | Before ML Detection | After ML Detection | Target
--- | --- | --- | ---
Flaky test identification time | 4-8 hours/test | 5 minutes (automated) | < 10 min
False positive rate in CI | 25% | 8% | < 5%
Test suite stability | 75% | 95% | > 98%
Investigation time per failure | 30 min | 10 min | < 15 min
Flaky tests in production | 15% | 3% | < 2%

ROI Calculation

Team of 10 engineers, 5000 tests, 15% flaky rate

Manual detection cost:
750 flaky tests × 4 hours investigation = 3,000 hours
3,000 hours × $75/hour = $225,000

ML detection cost:
Setup: 80 hours × $75 = $6,000
Maintenance: 2 hours/week × 52 weeks × $75 = $7,800
Total: $13,800

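Savings: $225,000 - $13,800 = $211,200 in the first year
ROI: $211,200 / $13,800 ≈ 1,530%

Treat this as an upper bound: it counts the full 4 hours of manual investigation per flaky test as saved, although fixing the tests still takes engineering time. The 3,000 hours also cover the 750 tests that are flaky today; whether that cost recurs every year depends on how quickly new flaky tests are introduced.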
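The same arithmetic as a small function, so the assumed inputs (flaky test count, hours, hourly rate) can be swapped for your own numbers:

```python
def flaky_roi(flaky_tests, hours_per_test, hourly_rate,
              setup_hours, maintenance_hours_per_week, weeks=52):
    """First-year cost comparison for manual vs ML-assisted flaky test detection.

    All inputs are assumptions; the defaults mirror the worked example.
    """
    manual_cost = flaky_tests * hours_per_test * hourly_rate
    ml_cost = (setup_hours + maintenance_hours_per_week * weeks) * hourly_rate
    savings = manual_cost - ml_cost
    return {
        'manual_cost': manual_cost,
        'ml_cost': ml_cost,
        'savings': savings,
        'roi_percent': 100 * savings / ml_cost,
    }

# 750 flaky tests = 15% of 5,000
print(flaky_roi(flaky_tests=750, hours_per_test=4, hourly_rate=75,
                setup_hours=80, maintenance_hours_per_week=2))
```

With these inputs the function reproduces the figures above: $225,000 manual cost, $13,800 ML cost, $211,200 savings, and an ROI of about 1,530%.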

Conclusion

Flaky tests undermine confidence in automation and waste significant engineering resources. Machine Learning transforms flaky test detection from a reactive, manual process into a proactive, automated system that predicts flakiness, identifies root causes, and suggests fixes.

Start with simple statistical detection, add ML classification as you collect data, and integrate anomaly detection for comprehensive coverage. Track metrics, iterate on your models, and continuously improve test stability.

Remember: The goal isn’t just detecting flaky tests—it’s preventing them. Use ML insights to improve test design patterns, enforce best practices, and build a culture of test reliability.

Resources

  • Tools: Flaky Test Tracker (Google), DeFlaker, NonDex, FlakeFlagger
  • Research: “An Empirical Analysis of Flaky Tests” (Luo et al., FSE 2014), Google’s Flaky Test research
  • Datasets: Flaky test datasets on Zenodo, IDoFT dataset
  • Frameworks: pytest-flakefinder, Jest --randomize (surfaces order-dependent tests)

Stable tests, confident deployments. Let ML be your flakiness guardian.