Rogendo committed on
Commit cb392b6
1 Parent(s): bb79f20

Added labels, metrics, and training logs

classifier_model.md ADDED
## **DistilBERT Multi-Label Classification Model Documentation**

**Objective**

This document is the official guide for the fine-tuned DistilBERT-base-uncased classification model. It provides a complete technical and functional overview for engineering and product teams, enabling them to understand, integrate, and effectively use the model for automated case categorization in call center scenarios.

-----

### **1. Model Overview**

The model is a **multi-label text classification model** built on the **DistilBERT-base-uncased** architecture. DistilBERT, a distilled version of BERT, was chosen for its balance of performance and computational efficiency, which is critical for real-time inference in a high-volume call center environment.

* **Architecture**: The model consists of the pre-trained DistilBERT encoder followed by a custom **classification head**. This head is a linear layer with a sigmoid activation function, which is essential for multi-label classification because it allows the model to predict multiple independent labels for a single input.
* **Training Data**: The model was fine-tuned on a proprietary dataset of over 1,000 anonymized call transcripts and 10,000 synthetic call transcripts. The dataset was cleaned, annotated, balanced, and stratified to minimize bias and maximize generalization.
* **Data Characteristics**: The data includes transcripts from diverse call types, ensuring the model can handle varied linguistic patterns, from formal inquiries to informal complaints. The multi-label nature of the data required each transcript to be annotated with one or more relevant categories.
* **Fine-Tuning Process**: The model was trained with the following key configurations:
    * **Loss Function**: Binary Cross-Entropy (BCE) with logits, which is appropriate for multi-label tasks where each label is treated as a separate binary classification problem.
    * **Optimizer**: AdamW.
    * **Training Schedule**: A learning rate scheduler with a warm-up phase was used to ensure stable training.
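Concretely, BCE-with-logits inference applies an independent sigmoid to each label's logit and keeps every label whose probability clears a threshold. A minimal pure-Python sketch; the label names and logit values are made up for illustration:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def predict_labels(logits, labels, threshold=0.5):
    """Multi-label decision: each label is an independent binary problem."""
    probs = {lab: sigmoid(z) for lab, z in zip(labels, logits)}
    return {lab: p for lab, p in probs.items() if p >= threshold}

labels = ["Nutrition", "GBV", "Information"]  # illustrative subset of the label space
logits = [2.1, -1.3, 0.4]                     # hypothetical raw model outputs

print(predict_labels(logits, labels))  # keeps "Nutrition" and "Information"
```

Because each label gets its own sigmoid, zero, one, or several labels can fire for the same transcript, which is exactly the behavior a softmax head cannot provide.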

-----

### **2. Classification Tasks and Class Definitions**

The model performs multi-label classification across four distinct tasks. For each task, the model can assign one or more labels. The full list of supported categories and their definitions is provided below.

| Classification Task | Labels | Definition |
| :--- | :--- | :--- |
| **Sub-Topic Categorization** | `Adoption`, `Albinism`, `Balanced Diet`, `Birth Registration`, `Breast Feeding`, `etc` | Categorizes the caller's reason for the call. This is crucial for real-time agent feedback and post-call analysis. |
| **Priority/Urgency Detection** | `Low`, `Medium`, `High` | Assesses the criticality of the customer's issue. `High` urgency cases are flagged for immediate escalation to a supervisor queue. |
| **Main Topic Categorization** | `Advice and Counselling`, `Child Maintenance & Custody`, `Disability`, `GBV`, `VANE`, `Nutrition`, `Information` | The primary function of the model, used to tag the call's content. |
| **Intervention** | `Referred`, `Counselling`, `Signposting`, `Awareness/Information Provided` | Predicts the final state of the call based on the conversation, helping to automate case management and follow-up procedures. |

-----

### **3. Performance Metrics**

The model's performance is monitored continuously on a separate, unseen validation set. Metrics are reported using a **micro-average** approach, which aggregates the contributions of all classes to compute the average metric. This is particularly useful for multi-label classification with imbalanced classes.

* **Micro-F1 Score**: The primary metric for overall model performance. It provides a single score that balances precision and recall across all labels. A high micro-F1 score indicates that the model is performing well on all classes, including minority ones.
* **Precision and Recall by Class**: These metrics are tracked for each individual label to provide granular insight into the model's performance. For instance, high precision on the `High` priority label is critical to avoid false escalations, while high recall on `VANE` is vital to catch every instance of a major issue.
* **Confusion Matrices**: Given the multi-label nature of the model, a standard confusion matrix is replaced by a set of **per-class confusion matrices**. Each matrix shows the model's performance for one specific label, treated as a binary classification problem (e.g., "Is this call `VANE`-related?").
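Both the micro-average and the per-class matrices derive from per-label TP/FP/FN/TN counts. An illustrative pure-Python sketch over toy label sets (not real model output):

```python
from collections import Counter

def per_class_counts(y_true, y_pred, labels):
    """One binary confusion (TP/FP/FN/TN) per label, as used for per-class matrices."""
    counts = {}
    for lab in labels:
        c = Counter()
        for t, p in zip(y_true, y_pred):
            hit_t, hit_p = lab in t, lab in p
            c["TP"] += hit_t and hit_p
            c["FP"] += (not hit_t) and hit_p
            c["FN"] += hit_t and (not hit_p)
            c["TN"] += (not hit_t) and (not hit_p)
        counts[lab] = dict(c)
    return counts

def micro_f1(y_true, y_pred, labels):
    """Micro-average: pool TP/FP/FN over all labels, then compute F1 once."""
    counts = per_class_counts(y_true, y_pred, labels)
    tp = sum(c.get("TP", 0) for c in counts.values())
    fp = sum(c.get("FP", 0) for c in counts.values())
    fn = sum(c.get("FN", 0) for c in counts.values())
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

# Toy ground truth and predictions over label *sets* (multi-label)
labels = ["Nutrition", "VANE", "Information"]
y_true = [{"Nutrition"}, {"VANE", "Information"}, {"Information"}]
y_pred = [{"Nutrition"}, {"VANE"}, {"Information", "Nutrition"}]
print(round(micro_f1(y_true, y_pred, labels), 3))  # 0.75
```

Because micro-averaging pools counts before computing the score, frequent labels influence it more than rare ones, which is why per-class precision and recall are still tracked separately.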

-----

### **4. Usage Examples**

#### **4.0. Response Format**

The response from the `/classifier/classify` endpoint is a JSON object with the following fields:

```json
{
  "main_category": "The main category of the case.",
  "sub_category": "The sub-category of the case.",
  "intervention": "The recommended intervention for the case.",
  "priority": "The priority of the case.",
  "processing_time": "The time taken to process the request.",
  "model_info": {
    "model_path": "The path to the model.",
    "loaded": "Whether the model is loaded.",
    "load_time": "The time when the model was loaded.",
    "device": "The device on which the model is running.",
    "error": "Any error that occurred during model loading."
  },
  "timestamp": "The timestamp of the request."
}
```

#### **4.1. API Endpoints**

* **Primary Endpoint**: `POST /classifier/classify`
* **Request Body (JSON)**:

```json
{
  "text_transcript": "The transcript of the call goes here. It needs to be a single string."
}
```

* **Example Response (JSON)**:

```json
{
  "case_id": "c-1a2b3c4d",
  "predictions": {
    "sub_category": {
      "label": "Information",
      "confidence": 0.92
    },
    "priority": {
      "label": "1",
      "confidence": 0.85
    },
    "main_category": {
      "label": "Nutrition",
      "confidence": 0.95
    },
    "intervention": {
      "label": "Counselling",
      "confidence": 0.89
    }
  }
}
```

* cURL request:

```bash
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{"text_transcript": "A 12-year-old girl is being abused by her stepfather."}' \
  http://localhost:8123/classifier/classify
```
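For programmatic access, a stdlib-only Python client sketch is shown below. The URL and port are taken from the cURL example, and the payload uses the documented `text_transcript` field; this is an illustrative client, not an official SDK:

```python
import json
import urllib.request

API_URL = "http://localhost:8123/classifier/classify"  # assumed local deployment

def build_request(transcript: str) -> urllib.request.Request:
    """Build the POST request; the endpoint expects a single transcript string."""
    body = json.dumps({"text_transcript": transcript}).encode("utf-8")
    return urllib.request.Request(
        API_URL,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def classify(transcript: str) -> dict:
    """Send the request and decode the JSON response (requires a running server)."""
    with urllib.request.urlopen(build_request(transcript), timeout=10) as resp:
        return json.load(resp)

req = build_request("A 12-year-old girl is being abused by her stepfather.")
print(req.get_method(), req.full_url)  # POST http://localhost:8123/classifier/classify
```

Separating request construction from sending makes the payload easy to test without a live server.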

#### **4.2. Input Text Processing**

The classify endpoint expects a raw, clean transcript string. The model's serving pipeline handles all necessary preprocessing steps, including:

* **Normalization**: Converting all text to lowercase and removing irrelevant characters.
* **Tokenization**: Segmenting the text into tokens using the DistilBERT tokenizer, which handles sub-word units and special characters.
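As a rough illustration of the normalization step, lowercasing plus character cleanup might look like the following; the exact character whitelist is an assumption and may differ from the serving pipeline's real rules:

```python
import re

def normalize(text: str) -> str:
    """Lowercase and strip characters treated as irrelevant.
    The whitelist below is an assumption; match it to the serving pipeline."""
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s.,!?'-]", " ", text)  # replace stray symbols
    return re.sub(r"\s+", " ", text).strip()        # collapse whitespace

print(normalize("Caller said:  HELP!!  My son <noise> is sick"))
# caller said help!! my son noise is sick
```

Tokenization itself is then handled by the DistilBERT tokenizer (loadable via `AutoTokenizer.from_pretrained("distilbert-base-uncased")` from the `transformers` library), so callers never need to split text themselves.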

-----

### **5. Confidence Thresholds**

The confidence score returned for each prediction is a probability value (0 to 1) indicating the model's certainty. The choice of threshold is a business decision that should be tuned based on the application's risk tolerance.

* **High thresholds (e.g., > 0.90)**: Recommended for **critical, automated actions** like case escalation or auto-creation of a trouble ticket. This ensures high precision and minimizes false positives, but may reduce recall (i.e., some relevant cases are missed).
* **Medium thresholds (e.g., > 0.75)**: Ideal for **data enrichment and analytics**. This provides a broader set of tags for dashboards and reports, balancing precision and recall to give a more complete picture of call trends.
* **Low thresholds (e.g., > 0.50)**: Can be used for **human-in-the-loop applications**, where predictions serve as suggestions to an agent or reviewer. A lower threshold increases recall, ensuring a human sees a potential label even if the model is not highly confident.
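These tiers can be wired into a simple routing rule. A sketch in which the action names are hypothetical and the cutoffs mirror the tiers above:

```python
def route(prediction: dict) -> str:
    """Map a {label, confidence} prediction to an action tier (cutoffs from Section 5)."""
    conf = prediction["confidence"]
    if conf > 0.90:
        return "auto_escalate"      # critical automated action
    if conf > 0.75:
        return "tag_for_analytics"  # data enrichment
    if conf > 0.50:
        return "suggest_to_agent"   # human-in-the-loop
    return "discard"

for p in [{"label": "VANE", "confidence": 0.95},
          {"label": "Nutrition", "confidence": 0.80},
          {"label": "GBV", "confidence": 0.55}]:
    print(p["label"], "->", route(p))
```

In practice the cutoffs should be tuned per label, since the cost of a false escalation differs from the cost of a missed `VANE` case.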

-----

### **6. Integration Guide: NLP Pipeline Flow**

The classification model is a critical component of the overall NLP pipeline. The standard flow for a call center interaction is as follows:

1. **Audio Ingestion**: Raw audio from the call is captured and streamed to the transcription service.
2. **ASR (Automatic Speech Recognition)**: The audio is converted into a text transcript in real time using a fine-tuned Whisper ASR model.
3. **Real-Time Analytics**:
   * The transcript is fed to the **DistilBERT Classification Model**.
   * The model returns labels and confidence scores.
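The flow above can be sketched with service stubs standing in for the ASR and classifier calls; the interfaces are assumed and the outputs hard-coded for illustration:

```python
def transcribe(audio_bytes: bytes) -> str:
    """Stub for the fine-tuned Whisper ASR service (interface assumed)."""
    return "caller asks about a balanced diet for her toddler"

def classify(transcript: str) -> dict:
    """Stub for the DistilBERT classifier; returns labels with confidences."""
    return {"main_category": {"label": "Nutrition", "confidence": 0.95}}

def handle_call(audio_bytes: bytes) -> dict:
    transcript = transcribe(audio_bytes)  # step 2: ASR
    return classify(transcript)           # step 3: real-time analytics

print(handle_call(b"\x00\x01"))
```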

### **7. Fine-Tuning**

This section documents the fine-tuning process for the DistilBERT multi-label classification model, including the continuous learning framework and automated version control system.

-----

#### **7.1. Architecture & Model Configuration**

The fine-tuning module implements a **multi-task learning approach** using a custom `MultiTaskDistilBert` class that extends the base DistilBERT architecture:

**Model Structure:**
- **Base Model**: DistilBERT-base-uncased (6 layers, 768 hidden units)
- **Pre-classifier**: Linear layer (768 → 768) with ReLU activation
- **Classification Heads**: Four separate linear classifiers, one per task:
  - Main Category: 768 → 7 classes
  - Sub Category: 768 → 50+ classes
  - Intervention: 768 → 4 classes
  - Priority: 768 → 3 classes
- **Loss Function**: Combined cross-entropy loss across all tasks
- **Dropout**: Configurable dropout layer for regularization

**Training Configuration:**
```python
from transformers import TrainingArguments

TrainingArguments(
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=12,
    weight_decay=0.01,
    eval_strategy="epoch",
    metric_for_best_model="eval_avg_acc",
)
```
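The combined loss is simply the sum of one softmax cross-entropy term per head. A numeric pure-Python sketch with made-up logits (the real heads have 7, 50+, 4, and 3 classes respectively):

```python
import math

def cross_entropy(logits, target_idx):
    """Softmax cross-entropy for a single example (log-sum-exp for stability)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target_idx]

# One hypothetical example; each head contributes one term to the combined loss
task_logits = {
    "main_category": ([2.0, 0.1, -1.0], 0),      # truncated slice of the 7-class head
    "priority":      ([0.5, 1.5, -0.5], 1),
    "intervention":  ([1.0, 0.0, 0.0, -1.0], 0),
}
total = sum(cross_entropy(z, t) for z, t in task_logits.values())
print(round(total, 4))
```

Summing per-task losses lets a single backward pass update the shared encoder and all four heads at once, which is the core of the multi-task setup.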

-----

#### **7.2. Data Processing Pipeline**

**Dataset Preparation:**
- **Input Format**: JSON files containing call transcripts with multi-label annotations
- **Train/Test Split**: 90/10 stratified split based on sub-category distribution
- **Text Processing**: Raw transcripts are tokenized with the DistilBERT tokenizer using:
  - Maximum sequence length: 512 tokens
  - Padding: "max_length"
  - Truncation: enabled

**Label Mapping System:**
The fine-tuning process includes a hierarchical label mapping that connects sub-categories to main categories:
```python
# Example mapping structure
sub_to_main_mapping = {
    "Bullying": "Advice and Counselling",
    "Child Labor": "VANE",
    "Malnutrition": "Nutrition",
    # ... additional mappings
}
```
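A small helper can resolve this hierarchy at training or audit time. The `None` fallback for unmapped sub-categories is an assumption, not taken from the source:

```python
sub_to_main_mapping = {
    "Bullying": "Advice and Counselling",
    "Child Labor": "VANE",
    "Malnutrition": "Nutrition",
}

def main_for_sub(sub_category):
    """Resolve a sub-category to its main category; None flags an unmapped
    label for annotation review (fallback policy is an assumption)."""
    return sub_to_main_mapping.get(sub_category)

print(main_for_sub("Malnutrition"))  # Nutrition
print(main_for_sub("Homework"))      # None
```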

-----

#### **7.3. Continuous Learning & Version Control**

The fine-tuning module implements an **automated continuous learning system** with intelligent version management:

**Version Control Features:**
- **Automatic Model Versioning**: Each improved model is saved as a new version (v1, v2, v3, etc.)
- **Performance-Based Saving**: A new model is only saved if it outperforms the previous best model
- **Metadata Tracking**: Complete training history with timestamps, metrics, and model paths
- **Rollback Capability**: Any previous model version can be loaded for comparison or deployment

**Continuous Learning Workflow:**
1. **Model Discovery**: The system checks for an existing best model in the version directory
2. **Warm Start**: If one is found, the previous best model is loaded for continued training
3. **Training**: Fine-tunes on new data using `MultiTaskTrainer`
4. **Evaluation**: Compares performance against the previous best model
5. **Conditional Saving**: Only saves if `eval_avg_acc` improves
6. **Metadata Update**: Records training session details and performance metrics

**Directory Structure:**
```
/multitask_distilbert_version/
├── model_metadata.json
├── CHS_tz_classifier_distilbert1/
├── CHS_tz_classifier_distilbert2/
└── CHS_tz_classifier_distilbert3/
```

-----

#### **7.4. Performance Monitoring**

**Evaluation Metrics:**
The system tracks comprehensive metrics across all classification tasks:

- **Task-Specific Metrics**: Accuracy, precision, recall, and F1-score for each task
- **Overall Performance**: Weighted averages across all tasks
- **Primary Metric**: `eval_avg_acc` (average accuracy across all tasks), used for model selection

**MLflow Integration:**
- **Experiment Tracking**: All training runs are logged to the MLflow server
- **Parameter Logging**: Hyperparameters and model architecture details
- **Metric Tracking**: Real-time performance monitoring during training
- **Model Registry**: Ready for integration with model lifecycle management
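Per the description above, `eval_avg_acc` averages the per-task accuracies; the exact aggregation in the training script is assumed here to be an unweighted mean:

```python
def eval_avg_acc(per_task_correct: dict, n_examples: int) -> float:
    """Unweighted mean of per-task accuracies, used to select the best model."""
    accs = [correct / n_examples for correct in per_task_correct.values()]
    return sum(accs) / len(accs)

# Hypothetical correct-prediction counts over 100 validation examples
print(round(eval_avg_acc(
    {"main": 90, "sub": 72, "intervention": 85, "priority": 88}, 100), 4))  # 0.8375
```

Because the mean is unweighted, a regression on one task (e.g. the 50+ class sub-category head) can block a save even when the other heads improve.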

-----

#### **7.5. Embeddings Generation**

The fine-tuning process includes **category embedding generation** for enhanced semantic understanding:

**Features:**
- **Category Embeddings**: Pre-computed embeddings for all category names
- **Semantic Similarity**: Enables similarity-based classification and error analysis
- **Storage Format**: NumPy arrays, saved for efficient loading during inference
- **Generated Files**:
  - `embeddings/main_cat_embeddings.npy`
  - `embeddings/sub_cat_embeddings.npy`
  - Category mapping JSON files
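Similarity-based lookup over those arrays amounts to a cosine-similarity argmax. A sketch with tiny stand-in vectors (the real `.npy` files hold model-sized embeddings, and the row-to-category order is an assumption):

```python
import numpy as np

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Stand-in for embeddings/main_cat_embeddings.npy (real vectors are model-sized)
main_cat_embeddings = np.array([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]])
categories = ["Nutrition", "Information", "GBV"]  # row order is an assumption

query = np.array([0.9, 0.1])  # e.g. an encoded transcript
best = max(range(len(categories)),
           key=lambda i: cosine_sim(query, main_cat_embeddings[i]))
print(categories[best])  # Nutrition
```

In production the arrays would be loaded with `np.load(...)` alongside the category mapping JSON files listed above.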

-----

#### **7.6. Usage Instructions**

**Prerequisites:**
- CUDA-compatible GPU (recommended)
- Python 3.8+ with the required dependencies
- MLflow tracking server (optional)
- Sufficient disk space for model versions

**Running Fine-Tuning:**
```bash
# Set environment variables (optional)
export CUDA_VISIBLE_DEVICES=0

# Run the fine-tuning script
python fine_tune_distilbert.py
```

**Configuration Options:**
- **Data Path**: Update the `df = pd.read_json()` line with your dataset path
- **MLflow URI**: Modify `mlflow.set_tracking_uri()` for your tracking server
- **Model Directory**: Change `model_output_dir` to use a different storage location
- **Training Arguments**: Adjust hyperparameters in `TrainingArguments`

-----

#### **7.7. Best Practices & Recommendations**

**Training Optimization:**
- **Batch Size**: Start with 16; adjust based on GPU memory (8 GB+ recommended)
- **Learning Rate**: 2e-5 works well; consider 1e-5 for stable models, 5e-5 for aggressive fine-tuning
- **Epochs**: 12 epochs are typically sufficient; monitor for overfitting beyond 15
- **Early Stopping**: Enabled via `load_best_model_at_end=True`

**Data Quality:**
- **Balanced Dataset**: Ensure adequate representation across all categories
- **Clean Annotations**: Verify that multi-label annotations are consistent
- **Regular Updates**: Retrain periodically with new call center data

**Production Deployment:**
- **Model Selection**: Always use the latest version from the metadata file
- **A/B Testing**: Compare new model versions against the current production model
- **Rollback Plan**: Keep previous model versions for quick rollback if needed
- **Monitoring**: Continuously monitor model performance in production

**Troubleshooting:**
- **Memory Issues**: Reduce the batch size or sequence length
- **Performance Degradation**: Check data quality and class imbalance
- **Version Loading Errors**: Verify the model directory structure and metadata file
- **API Errors**: If the classifier model is not loaded, the API returns a `503 Service Unavailable` error; if the request is invalid, it returns a `400 Bad Request` error.

-----
interventions.json ADDED

["Awareness/Information Provided", "Counselling", "Referral", "Signposting"]

nohup.out ADDED
The diff for this file is too large to render. See raw diff.

priorities.json ADDED

[1, 2, 3]

sub_categories.json ADDED

["Adoption", "Albinism", "Balanced Diet", "Birth Registration", "Breastfeeding", "Bullying", "Child Abduction", "Child Abuse", "Child Labor", "Child Marriage", "Child Neglect", "Child Rights", "Child Trafficking", "Child in Conflict with the Law", "Custody", "Discrimination", "Drug/Alcohol Abuse", "Emotional Abuse", "Emotional/Psychological Violence", "Family Relationship", "Feeding & Food preparation", "Female Genital Mutilation", "Financial/Economic Violence", "Forced Marriage Violence", "Foster Care", "HIV/AIDS", "Harmful Practice", "Hearing impairment", "Homelessness", "Hydrocephalus", "Info on Helpline", "Legal Issues", "Legal issues", "Maintenance", "Malnutrition", "Missing Child", "Multiple disabilities", "No Care Giver", "OCSEA", "Obesity", "Other", "Outside Mandate", "Peer Relationships", "Physical Abuse", "Physical Health", "Physical Violence", "Physical impairment", "Psychosocial/Mental Health", "Relationships (Boy/Girl)", "Relationships (Parent/Child)", "Relationships (Student/Teacher)", "School Related Issues", "School related issues", "Self Esteem", "Sexual & Reproductive Health", "Sexual Abuse", "Sexual Violence", "Speech impairment", "Spinal bifida", "Stagnation", "Student/ Teacher Relationship", "Teen Pregnancy", "Traditional Practice", "Underweight", "Unlawful Confinement", "Visual impairment"]

training_args.bin ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:5fa539f245fd4481adfe9fbe970e782ae8f948331d95202a4748b223b3b02d21
size 5713