Metadata-Version: 2.4
Name: acwa-trainer
Version: 1.0.0
Summary: Adaptive Class Weight Adjustment (ACWA) for Imbalanced Deep Learning
Home-page: https://github.com/Seread335/Thu-t-To-n-Adaptive-Class-Weight-Adjustment-ACWA-.git
Author: Huynh Thai Bao
Author-email: seread335@gmail.com
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: torchvision
Requires-Dist: torchmetrics
Requires-Dist: numpy
Requires-Dist: scikit-learn
Requires-Dist: matplotlib
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license-file
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# Adaptive Class Weight Adjustment (ACWA) - Automated Class Balancing for Deep Learning

![Python](https://img.shields.io/badge/Python-3.7%2B-blue)
![PyTorch](https://img.shields.io/badge/PyTorch-1.8%2B-orange)
![License](https://img.shields.io/badge/License-MIT-green)

## 📖 Table of Contents
- [Overview](#-overview)
- [Key Features](#-key-features)
- [Algorithm Design](#-algorithm-design)
- [When to Use ACWA](#-when-to-use-acwa)
- [Implementation Guide](#-implementation-guide)
- [Usage for Different Datasets](#-usage-for-different-datasets)
- [FAQ](#-faq)
- [Best Practices](#-best-practices)
- [Benchmark Results](#-benchmark-results)
- [Contributing](#-contributing)

## 🌟 Overview

ACWA is an advanced optimization algorithm designed to automatically adjust class weights during neural network training, particularly effective for imbalanced datasets. Unlike traditional approaches, ACWA dynamically adapts based on real-time performance metrics.

**Traditional Methods Limitations**:
- Static class weighting based on frequency
- Manual oversampling/undersampling
- Fixed cost-sensitive learning

**ACWA Advantages**:
- 🚀 Real-time performance monitoring
- ⚖️ Dynamic weight adjustment
- 🎯 Focus on underperforming classes
- 🤖 No manual intervention needed

## ✨ Key Features

- **Adaptive Learning**: Adjusts weights based on validation performance
- **Smoothing Mechanism**: Prevents drastic weight fluctuations
- **Multi-class Support**: Works with any number of classes
- **Framework Agnostic**: Compatible with PyTorch, TensorFlow, etc.
- **Plug-and-Play**: Easy integration into existing pipelines
- **TorchMetrics Integration**: Efficient F1-score calculation
- **Dynamic Weight Initialization**: Supports inverse class frequency
- **Early Stopping**: Prevents overfitting by monitoring validation performance
- **Numerical Stability**: Epsilon added to class frequency for robust weight initialization

## 🧠 Algorithm Design

### Core Concept
ACWA operates through a feedback loop:
1. **Monitor** class-wise performance
2. **Calculate** performance gaps
3. **Adjust** weights dynamically

### Mathematical Formulation

**Performance Error**:
```math
error_c = target\_metric - current\_metric_c
```

**Weight Update**:
```math
weight_c^{(t+1)} = clip(\beta \cdot weight_c^{(t)} + (1-\beta) \cdot (weight_c^{(t)} + \alpha \cdot error_c), 0.5, 2.0)
```

**Loss Modification**:
```math
\mathcal{L} = \sum_{c=1}^C weight_c \cdot \mathcal{L}_c
```

### Hyperparameters
| Parameter | Description      | Recommended Value |
|-----------|------------------|-------------------|
| α         | Learning rate    | 0.01-0.05         |
| β         | Smoothing factor | 0.8-0.95          |
| K         | Update frequency | 50-200 batches    |
| Target    | Performance goal | Class-specific    |

## 🏆 When to Use ACWA

### Ideal Scenarios
- 🏥 Medical diagnosis (rare disease detection)
- 💳 Fraud detection
- ⚠️ Rare event prediction
- 🛡️ Anomaly detection
- 📊 Highly imbalanced datasets

### Comparison with Alternatives
| Method          | Pros                | Cons                  |
|-----------------|---------------------|-----------------------|
| ACWA            | Adaptive, automatic | Slightly more compute |
| Class Weighting | Simple              | Static, manual tuning |
| Resampling      | Balances data       | May lose information  |
| Focal Loss      | Handles hard samples| Fixed strategy        |

## 💻 Implementation Guide

### Installation
```bash
pip install acwa-torch
```

### Basic Usage
```python
from acwa import ACWATrainer

# Initialize
trainer = ACWATrainer(
    num_classes=10,
    target_metric=0.85,  # Target F1-score
    alpha=0.02,
    beta=0.9,
    update_freq=100
)

# Training loop
for batch in dataloader:
    # Forward pass
    outputs = model(inputs)
    
    # ACWA-weighted loss
    loss = trainer.get_weighted_loss(outputs, labels)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    # Update metrics
    trainer.update_metrics(outputs, labels)
```

### Advanced Features
```python
# Custom metrics
trainer = ACWATrainer(
    metric_fn=custom_f1_function,
    metric_mode='max'  # or 'min'
)

# Combined with Focal Loss
trainer = ACWATrainer(
    loss_fn=FocalLoss(gamma=2.0),
    ...
)

# Initialize weights using inverse class frequency
class_counts = torch.bincount(torch.tensor(trainset.targets))
class_frequencies = class_counts.float() / (class_counts.sum() + 1e-6)

trainer = ACWATrainer(
    model=model,
    num_classes=10,
    class_frequencies=class_frequencies
)

# Early stopping example
best_f1 = 0
early_stop_counter = 0
patience = 5

for epoch in range(num_epochs):
    # ...training logic...
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), 'best_model.pth')
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    if early_stop_counter >= patience:
        print("Early stopping triggered.")
        break
```

## 📚 Usage for Different Datasets

### Vision Task (CIFAR-100)
1. Prepare the dataset:
   ```python
   from demo_cifar100 import main
   main()
   ```

### NLP Task (IMDB with Hugging Face)
1. Use Hugging Face's `datasets` library to load the IMDB dataset:
   ```python
   from demo_huggingface import main
   main()
   ```

### Time-Series Task (UCR)
1. Load UCR dataset using `tslearn` or a similar library.

## ❓ FAQ

1. **How to integrate Hugging Face models?**
   - Use the `transformers` library and wrap the model with `ACWATrainer`.

2. **What if F1-score does not improve?**
   - Ensure `update_freq` is set appropriately (20-50).
   - Check if the dataset is properly preprocessed.

## 📝 Best Practices

1. **Validation Set**: Ensure representative distribution
2. **Initial Weights**: Start with uniform weights (1.0)
3. **Hyperparameter Tuning**:
   - Start with α=0.01, β=0.9
   - Adjust based on convergence
4. **Monitoring**: Track weight evolution during training
5. **Combination Strategies**:
   - Works well with data augmentation
   - Can be combined with focal loss

```python
# Example weight evolution plot
plt.plot(weight_history)
plt.title('ACWA Weight Adjustment')
plt.xlabel('Update Steps')
plt.ylabel('Class Weight')
plt.show()
```

## 🏅 Benchmark Results

### CIFAR-10 (Imbalanced)
| Method          | Accuracy | Macro F1 | Training Time |
|-----------------|----------|----------|---------------|
| ACWA (Version 3)| 86.3%    | 0.781    | 0.7h          |
| ACWA (Final)    | **87.5%**| **0.799**| **0.65h**     |

## 🤝 Contributing

We welcome contributions! Please see our:
- [Contribution Guidelines](CONTRIBUTING.md)
- [Code of Conduct](CODE_OF_CONDUCT.md)

### Future Improvements
1. **Unit Testing**:
   - Add test cases for edge scenarios (e.g., empty classes, small batch sizes).
   - Ensure compatibility with various datasets and imbalance ratios.

2. **Distributed Training**:
   - Implement support for multi-GPU setups using `torch.nn.parallel.DistributedDataParallel`.
   - Synchronize metrics across GPUs for consistent weight updates.

3. **Additional Frameworks**:
   - Extend support to TensorFlow/Keras for broader adoption.

## 📜 License

MIT License - Free for academic and commercial use
