import datetime
import os

import torch
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup, AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertForSequenceClassification, BertConfig

from dataset.load_dataset import df, prepare_dataset
|
|
| |
# --- Run configuration -------------------------------------------------
# A single timestamp identifies this run; it is baked into both the
# TensorBoard log directory and the saved-model filename so the two can
# always be matched up afterwards.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = f'runs/train_{current_time}'
writer = SummaryWriter(log_dir)

# Hyperparameters. The *_name strings exist only so the checkpoint
# filename is self-describing.
epochs = 10
lr = 1e-5
optimizer_name = 'AdamW'
loss_fn_name = 'BCEWithLogitsLoss'
batch_size = 16

# Checkpoint path encodes the full configuration for traceability.
model_save_name = f'model_{current_time}_lr{lr}_opt{optimizer_name}_loss{loss_fn_name}_batch{batch_size}_epoch{epochs}.pt'
model_save_path = f'./saved_models/{model_save_name}'
|
|
# Tokenizer from the locally fine-tuned Bio_ClinicalBERT checkpoint; the
# model below is loaded from the same directory, so vocab and weights match.
tokenizer = AutoTokenizer.from_pretrained(
    "pretrained_models/Bio_ClinicalBERT-finetuned-medicalcondition")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 90/10 train/validation split of the raw dataframe.
# NOTE(review): no random_state is passed, so the split differs on every
# run (checkpoint filenames are timestamped, not seed-stamped) -- confirm
# whether non-reproducible splits are intended.
train_df, val_df = train_test_split(df, test_size=0.1)

# prepare_dataset tokenizes each split into tensors consumable by DataLoader.
# presumably yields (input_ids, attention_mask, labels) tuples -- the
# training loop unpacks batches in that order; verify against its definition.
train_dataset = prepare_dataset(train_df, tokenizer)
val_dataset = prepare_dataset(val_df, tokenizer)

# Shuffled batches for training; fixed, sequential order for validation.
train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)
validation_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)
|
|
| |
# --- Model, optimizer, schedule ----------------------------------------
# The classifier head is resized to 8 logits; the validation code below
# treats them as two independent 4-way groups (SAS / SDS).
config = BertConfig.from_pretrained("pretrained_models/Bio_ClinicalBERT-finetuned-medicalcondition")
config.num_labels = 8

model = AutoModelForSequenceClassification.from_pretrained(
    "pretrained_models/Bio_ClinicalBERT-finetuned-medicalcondition",
    config=config,
    # The checkpoint's head has a different label count, so let HF re-init it.
    ignore_mismatched_sizes=True,
).to(device)

# Fix: use the `lr` hyperparameter defined at the top of the script instead
# of a second hard-coded 1e-5 -- the two values could silently diverge and
# the checkpoint filename would then lie about the learning rate used.
optimizer = AdamW(model.parameters(), lr=lr, eps=1e-8)

# Linear decay, no warmup, over the full training horizon.
# (The training loop must call scheduler.step() once per optimizer step.)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Multi-label objective over all 8 logits; expects float targets.
loss_fn = BCEWithLogitsLoss()

# Gradient scaler for mixed-precision training (paired with autocast()
# in the training loop).
scaler = GradScaler()
|
|
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    print('-------------------------------')

    # ------------------------- Training ------------------------------
    model.train()
    total_loss = 0
    train_progress_bar = tqdm(train_dataloader, desc="Training", leave=False)
    for step, batch in enumerate(train_progress_bar):
        # Batch layout (input_ids, attention_mask, labels) -- assumed from
        # the unpacking order used throughout this file; TODO confirm
        # against prepare_dataset.
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        model.zero_grad()

        # Fix: run the forward pass under autocast so the GradScaler is
        # actually paired with mixed-precision compute (the original
        # created the scaler but never enabled autocast).
        with autocast():
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
            logits = outputs.logits
            loss = loss_fn(logits, b_labels)

        # Fix: check for NaN BEFORE accumulating, otherwise one NaN loss
        # poisons the epoch average this guard was meant to protect.
        if torch.isnan(loss).any():
            print(f"Loss is nan in epoch {epoch + 1}, step {step}.")
            continue
        total_loss += loss.item()

        scaler.scale(loss).backward()
        # Fix: clip AFTER backward (gradients must exist) and after
        # unscaling, otherwise max_norm is compared against scaled grads
        # and the threshold is meaningless. The original clipped before
        # backward, which was a no-op.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        # Fix: the warmup/decay schedule was created but never stepped, so
        # the LR never changed from its initial value.
        scheduler.step()

        train_progress_bar.set_postfix({'loss': f"{loss.item():.2f}"})
        writer.add_scalar('Loss/train', loss.item(), epoch * len(train_dataloader) + step)

    avg_train_loss = total_loss / len(train_dataloader)
    print(f"Training loss: {avg_train_loss:.2f}")

    # ------------------------ Validation -----------------------------
    model.eval()
    total_eval_accuracy = 0
    total_eval_loss = 0
    eval_progress_bar = tqdm(validation_dataloader, desc="Validation", leave=False)
    for batch in eval_progress_bar:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
            logits = outputs.logits
            loss = loss_fn(logits, b_labels)
            total_eval_loss += loss.item()

        # The 8 logits are scored as two independent 4-way heads:
        # columns 0-3 (SAS) and columns 4-7 (SDS).
        probs_sas = torch.softmax(logits[:, :4], dim=1)
        probs_sds = torch.softmax(logits[:, 4:], dim=1)
        predictions_sas = torch.argmax(probs_sas, dim=1)
        predictions_sds = torch.argmax(probs_sds, dim=1)

        # NOTE(review): training feeds the full b_labels tensor to
        # BCEWithLogitsLoss (multi-label encoding over 8 columns), while
        # accuracy reads columns 0 and 1 as integer class indices --
        # confirm the label tensor really carries both encodings.
        true_sas = b_labels[:, 0].long()
        true_sds = b_labels[:, 1].long()

        accuracy_sas = (predictions_sas == true_sas).float().mean()
        accuracy_sds = (predictions_sds == true_sds).float().mean()
        accuracy = (accuracy_sas + accuracy_sds) / 2
        # Fix: accumulate a Python float, not a 0-dim device tensor.
        total_eval_accuracy += accuracy.item()
        eval_progress_bar.set_postfix({'accuracy': f"{accuracy:.2f}"})

    avg_val_loss = total_eval_loss / len(validation_dataloader)
    print(f"Validation Loss: {avg_val_loss:.2f}")
    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    writer.add_scalar('Loss/val', avg_val_loss, epoch)
    # Fix: accuracy was printed but never logged alongside val loss.
    writer.add_scalar('Accuracy/val', avg_val_accuracy, epoch)
    print(f"Validation Accuracy: {avg_val_accuracy:.2f}")
|
|
# Flush TensorBoard events, then persist the final weights.
writer.close()

# Fix: torch.save does not create directories -- on a fresh checkout
# './saved_models/' would not exist and the whole run's result would be
# lost to a FileNotFoundError at the very end.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(model.state_dict(), model_save_path)
# Fix: typo in the final message ("traing" -> "training").
print(f"training end, save model to :{model_save_path}")
|
|