White Box Testing Methods
Testing AI systems with full access to internals
White Box Testing Methods
Table of Contents
- Learning Objectives
- Introduction
- Core Concepts
- Practical Applications
- Common Pitfalls
- Hands-on Exercise
- Further Reading
- Connections
Learning Objectives
By the end of this topic, you should be able to:
- Understand white box testing principles for AI/ML systems
- Implement comprehensive testing strategies with full model access
- Design test suites that leverage internal model structure
- Analyze model behavior through gradient analysis and neuron activation patterns
- Build automated white box testing pipelines for production ML systems
Introduction
White box testing in machine learning represents a fundamental departure from traditional software testing. Unlike black box approaches that only examine input-output relationships, white box testing leverages complete access to model architecture, weights, gradients, and internal representations. This deep visibility enables sophisticated testing strategies that can uncover subtle vulnerabilities, biases, and failure modes that would be invisible to external observation.
The power of white box testing lies in its ability to directly examine the learned representations and decision boundaries of neural networks. By analyzing gradient flows, neuron activations, and attention patterns, we can understand not just what a model predicts, but why it makes those predictions. This insight is crucial for building trustworthy AI systems, especially in safety-critical applications where understanding failure modes is paramount.
This topic provides a comprehensive exploration of white box testing methodologies, from theoretical foundations to practical implementations, with a focus on building robust testing frameworks for modern deep learning systems.
Core Concepts
Fundamentals of White Box Testing
Definition and Scope:
class WhiteBoxTestingFramework:
"""Core framework for white box testing of ML models"""
def __init__(self, model):
self.model = model
self.architecture = self.extract_architecture()
self.test_components = {
'gradient_analysis': GradientAnalyzer(model),
'neuron_coverage': NeuronCoverageAnalyzer(model),
'activation_patterns': ActivationPatternAnalyzer(model),
'weight_analysis': WeightAnalyzer(model),
'decision_boundary': DecisionBoundaryAnalyzer(model),
'internal_representations': RepresentationAnalyzer(model)
}
Gradient-Based Testing
1. Gradient Analysis
class GradientAnalyzer:
"""Analyze model behavior through gradient inspection"""
def __init__(self, model):
self.model = model
self.gradient_metrics = {}
def compute_input_gradients(self, x, y):
"""Compute gradients with respect to input"""
x = x.clone().requires_grad_(True)
# Forward pass
output = self.model(x)
loss = F.cross_entropy(output, y)
# Backward pass
loss.backward()
# Analyze gradients
grad = x.grad.data
metrics = {
'gradient_magnitude': torch.norm(grad),
'gradient_variance': torch.var(grad),
'max_gradient': torch.max(torch.abs(grad)),
'sparsity': (grad.abs() < 1e-6).float().mean(),
'gradient_pattern': self.analyze_pattern(grad)
}
return grad, metrics
def compute_layer_gradients(self, x, y):
"""Analyze gradients at each layer"""
layer_gradients = {}
# Hook to capture gradients
def hook_fn(module, grad_input, grad_output):
layer_gradients[module] = {
'input_grad': grad_input,
'output_grad': grad_output
}
# Register hooks
hooks = []
for name, module in self.model.named_modules():
if len(list(module.children())) == 0: # Leaf modules only
hook = module.register_backward_hook(hook_fn)
hooks.append(hook)
# Forward and backward
output = self.model(x)
loss = F.cross_entropy(output, y)
loss.backward()
# Remove hooks
for hook in hooks:
hook.remove()
return layer_gradients
def gradient_saturation_test(self, test_loader):
"""Test for vanishing/exploding gradients"""
saturation_stats = []
for x, y in test_loader:
layer_grads = self.compute_layer_gradients(x, y)
for layer, grads in layer_grads.items():
if grads['output_grad'] is not None:
grad_norm = torch.norm(grads['output_grad'][0])
saturation_stats.append({
'layer': str(layer),
'gradient_norm': grad_norm.item(),
'is_vanishing': grad_norm < 1e-6,
'is_exploding': grad_norm > 1e3
})
return saturation_stats
2. Saliency Testing
class SaliencyTester:
"""Test model through saliency analysis"""
def __init__(self, model):
self.model = model
def compute_saliency_map(self, x, target_class):
"""Generate saliency map for input"""
x.requires_grad_()
# Forward pass
output = self.model(x)
# Create one-hot target
one_hot = torch.zeros_like(output)
one_hot[0, target_class] = 1
# Backward pass
output.backward(gradient=one_hot)
# Get saliency
saliency = x.grad.data.abs()
return saliency
def test_saliency_consistency(self, x, augmentations):
"""Test if saliency maps are consistent across augmentations"""
original_saliency = self.compute_saliency_map(x, self.model(x).argmax())
consistency_scores = []
for aug in augmentations:
aug_x = aug(x)
aug_saliency = self.compute_saliency_map(aug_x, self.model(aug_x).argmax())
# Compute similarity
similarity = F.cosine_similarity(
original_saliency.flatten(),
aug_saliency.flatten(),
dim=0
)
consistency_scores.append(similarity.item())
return {
'mean_consistency': np.mean(consistency_scores),
'min_consistency': np.min(consistency_scores),
'consistency_variance': np.var(consistency_scores)
}
Neuron Coverage Testing
1. Basic Neuron Coverage
class NeuronCoverageAnalyzer:
"""Analyze test coverage at neuron level"""
def __init__(self, model, threshold=0.5):
self.model = model
self.threshold = threshold
self.neuron_states = {}
self.coverage_history = []
def setup_coverage_tracking(self):
"""Setup hooks to track neuron activations"""
self.activation_traces = {}
def hook_fn(name):
def hook(module, input, output):
self.activation_traces[name] = output.detach()
return hook
for name, module in self.model.named_modules():
if isinstance(module, (nn.Conv2d, nn.Linear, nn.ReLU)):
module.register_forward_hook(hook_fn(name))
def compute_coverage(self, test_loader):
"""Compute neuron coverage on test set"""
self.setup_coverage_tracking()
activated_neurons = set()
total_neurons = 0
for x, y in test_loader:
_ = self.model(x)
for name, activation in self.activation_traces.items():
# Flatten activation
act_flat = activation.view(activation.size(0), -1)
# Track activated neurons
activated = (act_flat > self.threshold).any(dim=0)
for idx in activated.nonzero().squeeze():
activated_neurons.add(f"{name}_{idx.item()}")
total_neurons = max(total_neurons,
len(activated_neurons) + (~activated).sum().item())
coverage = len(activated_neurons) / total_neurons if total_neurons > 0 else 0
return {
'neuron_coverage': coverage,
'activated_neurons': len(activated_neurons),
'total_neurons': total_neurons,
'coverage_by_layer': self.compute_layer_coverage()
}
2. Multi-Granularity Coverage
class MultiGranularityCoverage:
"""Advanced coverage metrics at multiple granularities"""
def __init__(self, model):
self.model = model
self.coverage_metrics = {
'neuron': NeuronCoverage(),
'layer': LayerCoverage(),
'path': PathCoverage(),
'boundary': BoundaryCoverage()
}
def k_multisection_neuron_coverage(self, test_loader, k=10):
"""K-multisection neuron coverage"""
neuron_ranges = defaultdict(lambda: {'min': float('inf'), 'max': float('-inf')})
section_coverage = defaultdict(lambda: set())
# First pass: find neuron value ranges
for x, y in test_loader:
activations = self.get_all_activations(x)
for layer_name, act in activations.items():
act_flat = act.view(-1)
for idx, value in enumerate(act_flat):
key = f"{layer_name}_{idx}"
neuron_ranges[key]['min'] = min(neuron_ranges[key]['min'], value.item())
neuron_ranges[key]['max'] = max(neuron_ranges[key]['max'], value.item())
# Second pass: compute section coverage
for x, y in test_loader:
activations = self.get_all_activations(x)
for layer_name, act in activations.items():
act_flat = act.view(-1)
for idx, value in enumerate(act_flat):
key = f"{layer_name}_{idx}"
range_min = neuron_ranges[key]['min']
range_max = neuron_ranges[key]['max']
if range_max > range_min:
section = int((value.item() - range_min) / (range_max - range_min) * k)
section = min(section, k-1) # Handle edge case
section_coverage[key].add(section)
# Compute coverage
total_sections = len(neuron_ranges) * k
covered_sections = sum(len(sections) for sections in section_coverage.values())
return covered_sections / total_sections if total_sections > 0 else 0
Internal Representation Analysis
1. Layer-wise Analysis
class LayerAnalyzer:
"""Analyze behavior of individual layers"""
def __init__(self, model):
self.model = model
self.layer_stats = {}
def analyze_layer_outputs(self, layer_name, test_loader):
"""Comprehensive analysis of layer outputs"""
outputs = []
# Collect outputs
def hook_fn(module, input, output):
outputs.append(output.detach())
# Get target layer
target_layer = dict(self.model.named_modules())[layer_name]
hook = target_layer.register_forward_hook(hook_fn)
# Run through test data
for x, y in test_loader:
_ = self.model(x)
hook.remove()
# Analyze collected outputs
all_outputs = torch.cat(outputs, dim=0)
analysis = {
'mean_activation': all_outputs.mean().item(),
'std_activation': all_outputs.std().item(),
'sparsity': (all_outputs == 0).float().mean().item(),
'dead_neurons': self.find_dead_neurons(all_outputs),
'activation_distribution': self.analyze_distribution(all_outputs),
'correlation_matrix': self.compute_neuron_correlation(all_outputs)
}
return analysis
def find_dead_neurons(self, activations):
"""Identify neurons that never activate"""
# Reshape to (num_samples, num_neurons)
act_reshaped = activations.view(activations.size(0), -1)
# Find neurons that are always zero
dead_mask = (act_reshaped == 0).all(dim=0)
dead_indices = dead_mask.nonzero().squeeze()
return {
'num_dead': dead_mask.sum().item(),
'dead_percentage': dead_mask.float().mean().item() * 100,
'dead_indices': dead_indices.tolist()
}
2. Attention Mechanism Testing
class AttentionTester:
"""Test attention mechanisms in transformers"""
def __init__(self, model):
self.model = model
self.attention_maps = {}
def extract_attention_weights(self, x):
"""Extract attention weights from all layers"""
attention_weights = []
def hook_fn(module, input, output):
if hasattr(module, 'attention_weights'):
attention_weights.append(module.attention_weights)
hooks = []
for module in self.model.modules():
if 'attention' in module.__class__.__name__.lower():
hook = module.register_forward_hook(hook_fn)
hooks.append(hook)
# Forward pass
_ = self.model(x)
# Remove hooks
for hook in hooks:
hook.remove()
return attention_weights
def test_attention_consistency(self, x, perturbation_size=0.01):
"""Test if attention is stable under small perturbations"""
# Original attention
original_attention = self.extract_attention_weights(x)
# Perturbed attention
x_perturbed = x + torch.randn_like(x) * perturbation_size
perturbed_attention = self.extract_attention_weights(x_perturbed)
# Compare attention maps
consistency_scores = []
for orig, pert in zip(original_attention, perturbed_attention):
similarity = F.cosine_similarity(
orig.flatten(),
pert.flatten(),
dim=0
)
consistency_scores.append(similarity.item())
return {
'mean_consistency': np.mean(consistency_scores),
'min_consistency': np.min(consistency_scores),
'layer_consistency': consistency_scores
}
Decision Boundary Analysis
1. Boundary Exploration
class DecisionBoundaryAnalyzer:
"""Analyze model decision boundaries"""
def __init__(self, model):
self.model = model
def find_decision_boundary(self, x1, x2, num_steps=100):
"""Find decision boundary between two samples"""
# Ensure different predictions
pred1 = self.model(x1).argmax()
pred2 = self.model(x2).argmax()
if pred1 == pred2:
return None
# Binary search for boundary
low, high = 0.0, 1.0
for _ in range(num_steps):
mid = (low + high) / 2
x_mid = x1 * (1 - mid) + x2 * mid
pred_mid = self.model(x_mid).argmax()
if pred_mid == pred1:
low = mid
else:
high = mid
# Return boundary point
boundary_point = x1 * (1 - high) + x2 * high
return {
'boundary_point': boundary_point,
'distance_from_x1': torch.norm(boundary_point - x1).item(),
'distance_from_x2': torch.norm(boundary_point - x2).item(),
'interpolation_factor': high
}
def measure_boundary_distance(self, x, num_directions=10):
"""Measure distance to decision boundary in multiple directions"""
original_pred = self.model(x).argmax()
distances = []
for _ in range(num_directions):
# Random direction
direction = torch.randn_like(x)
direction = direction / torch.norm(direction)
# Binary search for boundary
step_size = 0.1
current_x = x.clone()
while self.model(current_x).argmax() == original_pred:
current_x += step_size * direction
# Prevent infinite search
if torch.norm(current_x - x) > 10:
break
distance = torch.norm(current_x - x).item()
distances.append(distance)
return {
'mean_distance': np.mean(distances),
'min_distance': np.min(distances),
'std_distance': np.std(distances),
'distances': distances
}
Weight and Architecture Analysis
1. Weight Distribution Testing
class WeightAnalyzer:
"""Analyze model weight distributions and patterns"""
def __init__(self, model):
self.model = model
def analyze_weight_distributions(self):
"""Analyze weight distributions across layers"""
weight_stats = {}
for name, param in self.model.named_parameters():
if 'weight' in name:
weights = param.data.flatten()
weight_stats[name] = {
'mean': weights.mean().item(),
'std': weights.std().item(),
'min': weights.min().item(),
'max': weights.max().item(),
'sparsity': (weights.abs() < 1e-6).float().mean().item(),
'quantiles': {
'25%': torch.quantile(weights, 0.25).item(),
'50%': torch.quantile(weights, 0.50).item(),
'75%': torch.quantile(weights, 0.75).item()
},
'norm': torch.norm(weights).item()
}
return weight_stats
def test_weight_symmetry(self):
"""Test for unexpected symmetries in weights"""
symmetry_issues = []
for name, param in self.model.named_parameters():
if len(param.shape) >= 2: # Matrix weights
weight_matrix = param.data
# Check for exact duplicates
for i in range(weight_matrix.shape[0]):
for j in range(i+1, weight_matrix.shape[0]):
if torch.allclose(weight_matrix[i], weight_matrix[j], atol=1e-6):
symmetry_issues.append({
'layer': name,
'type': 'duplicate_filters',
'indices': (i, j)
})
# Check for near-zero variance
variance = weight_matrix.var(dim=0)
low_variance_indices = (variance < 1e-6).nonzero().squeeze()
if len(low_variance_indices) > 0:
symmetry_issues.append({
'layer': name,
'type': 'low_variance_neurons',
'count': len(low_variance_indices)
})
return symmetry_issues
Practical Applications
Comprehensive White Box Test Suite
class ComprehensiveWhiteBoxTestSuite:
"""Complete white box testing framework for production models"""
def __init__(self, model, test_data):
self.model = model
self.test_data = test_data
self.test_results = {}
# Initialize all analyzers
self.gradient_analyzer = GradientAnalyzer(model)
self.coverage_analyzer = NeuronCoverageAnalyzer(model)
self.layer_analyzer = LayerAnalyzer(model)
self.boundary_analyzer = DecisionBoundaryAnalyzer(model)
self.weight_analyzer = WeightAnalyzer(model)
def run_full_test_suite(self):
"""Execute comprehensive white box tests"""
print("Running White Box Test Suite...")
# 1. Gradient Tests
print("\n1. Gradient Analysis...")
self.test_results['gradient'] = {
'saturation': self.test_gradient_saturation(),
'input_sensitivity': self.test_input_gradient_sensitivity(),
'layer_gradients': self.test_layer_gradient_flow()
}
# 2. Coverage Tests
print("\n2. Coverage Analysis...")
self.test_results['coverage'] = {
'neuron_coverage': self.test_neuron_coverage(),
'multi_section_coverage': self.test_multi_section_coverage(),
'path_coverage': self.test_path_coverage()
}
# 3. Internal Representation Tests
print("\n3. Internal Representation Analysis...")
self.test_results['representations'] = {
'layer_analysis': self.test_all_layers(),
'dead_neurons': self.test_dead_neurons(),
'activation_patterns': self.test_activation_patterns()
}
# 4. Robustness Tests
print("\n4. Robustness Analysis...")
self.test_results['robustness'] = {
'boundary_distance': self.test_boundary_distances(),
'local_linearity': self.test_local_linearity(),
'adversarial_vulnerability': self.test_adversarial_vulnerability()
}
# 5. Weight Analysis
print("\n5. Weight Analysis...")
self.test_results['weights'] = {
'distributions': self.weight_analyzer.analyze_weight_distributions(),
'symmetries': self.weight_analyzer.test_weight_symmetry(),
'pruning_candidates': self.identify_pruning_candidates()
}
# Generate report
return self.generate_test_report()
def test_gradient_saturation(self):
"""Test for gradient saturation issues"""
saturation_stats = self.gradient_analyzer.gradient_saturation_test(self.test_data)
# Analyze results
vanishing_layers = [s['layer'] for s in saturation_stats if s['is_vanishing']]
exploding_layers = [s['layer'] for s in saturation_stats if s['is_exploding']]
return {
'vanishing_gradient_layers': vanishing_layers,
'exploding_gradient_layers': exploding_layers,
'healthy_gradient_percentage': len([s for s in saturation_stats
if not s['is_vanishing'] and not s['is_exploding']]) / len(saturation_stats)
}
Real-World Case Study: Testing Vision Transformer
class VisionTransformerWhiteBoxTest:
"""Specialized white box testing for Vision Transformers"""
def __init__(self, vit_model):
self.model = vit_model
self.patch_size = vit_model.patch_size
self.num_patches = (224 // self.patch_size) ** 2
def test_attention_patterns(self, test_images):
"""Analyze attention patterns in ViT"""
attention_stats = {
'head_specialization': [],
'attention_entropy': [],
'skip_connections': []
}
for img in test_images:
# Get attention weights
attentions = self.extract_attention_weights(img)
for layer_idx, layer_attention in enumerate(attentions):
# Analyze each attention head
for head_idx in range(layer_attention.size(1)):
head_attention = layer_attention[0, head_idx]
# Compute entropy
entropy = -torch.sum(head_attention * torch.log(head_attention + 1e-6))
attention_stats['attention_entropy'].append({
'layer': layer_idx,
'head': head_idx,
'entropy': entropy.item()
})
# Check for specialized patterns
if self.is_diagonal_attention(head_attention):
attention_stats['head_specialization'].append({
'layer': layer_idx,
'head': head_idx,
'type': 'diagonal/local'
})
elif self.is_global_attention(head_attention):
attention_stats['head_specialization'].append({
'layer': layer_idx,
'head': head_idx,
'type': 'global'
})
return attention_stats
def test_position_embedding_influence(self, test_loader):
"""Test how position embeddings affect predictions"""
results = []
for img, label in test_loader:
# Original prediction
orig_pred = self.model(img)
# Shuffle position embeddings
shuffled_pos = self.shuffle_position_embeddings()
# Prediction with shuffled positions
shuffled_pred = self.model_with_shuffled_pos(img, shuffled_pos)
# Measure impact
pred_change = torch.norm(orig_pred - shuffled_pred).item()
accuracy_impact = (orig_pred.argmax() != shuffled_pred.argmax()).float().item()
results.append({
'prediction_change': pred_change,
'accuracy_impact': accuracy_impact
})
return {
'mean_prediction_change': np.mean([r['prediction_change'] for r in results]),
'accuracy_degradation': np.mean([r['accuracy_impact'] for r in results])
}
Automated Testing Pipeline
class AutomatedWhiteBoxPipeline:
"""Automated pipeline for continuous white box testing"""
def __init__(self, model_path, test_data_path):
self.model_path = model_path
self.test_data_path = test_data_path
self.test_history = []
def run_automated_tests(self, schedule='daily'):
"""Run automated test suite on schedule"""
# Load latest model
model = self.load_latest_model()
test_data = self.load_test_data()
# Initialize test suite
test_suite = ComprehensiveWhiteBoxTestSuite(model, test_data)
# Run tests
results = test_suite.run_full_test_suite()
# Check for regressions
regressions = self.check_for_regressions(results)
if regressions:
self.alert_team(regressions)
# Update metrics dashboard
self.update_dashboard(results)
# Store results
self.test_history.append({
'timestamp': datetime.now(),
'model_version': self.get_model_version(),
'results': results,
'regressions': regressions
})
return results
def check_for_regressions(self, current_results):
"""Compare against historical results"""
if not self.test_history:
return []
last_results = self.test_history[-1]['results']
regressions = []
# Check key metrics
if current_results['coverage']['neuron_coverage'] < last_results['coverage']['neuron_coverage'] - 0.05:
regressions.append({
'type': 'coverage_regression',
'metric': 'neuron_coverage',
'previous': last_results['coverage']['neuron_coverage'],
'current': current_results['coverage']['neuron_coverage']
})
# Check for new dead neurons
current_dead = current_results['representations']['dead_neurons']
previous_dead = last_results['representations']['dead_neurons']
if len(current_dead) > len(previous_dead):
regressions.append({
'type': 'dead_neuron_increase',
'previous_count': len(previous_dead),
'current_count': len(current_dead)
})
return regressions
Common Pitfalls
1. Over-relying on Coverage Metrics
Mistake: Assuming high neuron coverage means thorough testing Problem: Coverage doesn't guarantee semantic diversity Solution: Combine coverage with behavioral testing and semantic analysis
2. Ignoring Computational Cost
Mistake: Running exhaustive white box tests on every input Problem: Prohibitive computational requirements Solution: Smart sampling and incremental testing strategies
3. Missing Layer Interactions
Mistake: Testing layers in isolation Problem: Missing emergent behaviors from layer interactions Solution: Test information flow and layer dependencies
4. Static Test Suites
Mistake: Using fixed test cases Problem: Model learns to pass specific tests Solution: Dynamic test generation based on model evolution
5. Incorrect Gradient Interpretation
Mistake: Misinterpreting gradient magnitudes Problem: Small gradients don't always mean robustness Solution: Context-aware gradient analysis
Hands-on Exercise
Build a white box testing framework:
-
Implement basic analyzers:
- Gradient flow analyzer
- Neuron coverage tracker
- Activation pattern extractor
- Weight distribution analyzer
-
Create advanced tests:
- Dead neuron detection
- Attention consistency testing
- Decision boundary mapping
- Layer interaction analysis
-
Build testing pipeline:
- Automated test execution
- Regression detection
- Performance benchmarking
- Report generation
-
Optimize for scale:
- Efficient gradient computation
- Sampling strategies
- Parallel test execution
- Incremental testing
-
Create visualizations:
- Neuron activation heatmaps
- Gradient flow diagrams
- Coverage progression
- Test result dashboards
Further Reading
- "Testing Deep Neural Networks" - Sun et al. 2018
- "DeepXplore: Automated Whitebox Testing" - Pei et al. 2017
- "Neuron Coverage for Deep Neural Networks" - Ma et al. 2018
- "MODE: Automated Neural Network Model Debugging" - Ma et al. 2019
- "White-box Testing of Deep Neural Networks" - Xie et al. 2019
- "Guiding Deep Learning System Testing using Surprise Adequacy" - Kim et al. 2019
Connections
- Related Topics: Black Box Testing, Grey Box Testing, Interpretability
- Prerequisites: Neural Networks, Gradient Descent
- Next Steps: Model Debugging, Safety Evaluation