"""
|
|
Author: Weisen Pan
|
|
Date: 2023-10-24
|
|
"""
|
|
import time
|
|
import torch
|
|
import numpy as np
|
|
import torch.nn.functional as F
|
|
from datetime import timedelta
|
|
from sklearn import metrics
|
|
from tqdm import tqdm
|
|
from scheduler import WarmUpLR, downLR
|
|
|
|
|
|
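# NOTE: `scheduler` is a sibling module that is not reproduced in this file.
# The sketch below is an assumption about how WarmUpLR and downLR might be
# implemented, inferred from how train() drives them (one .step() per batch);
# it is not necessarily the implementation shipped with this project.
#
#     from torch.optim.lr_scheduler import _LRScheduler
#
#     class WarmUpLR(_LRScheduler):
#         """Ramp the LR linearly from ~0 up to base_lr over total_iters steps."""
#         def __init__(self, optimizer, total_iters, last_epoch=-1):
#             self.total_iters = total_iters
#             super().__init__(optimizer, last_epoch)
#
#         def get_lr(self):
#             return [base_lr * self.last_epoch / (self.total_iters + 1e-8)
#                     for base_lr in self.base_lrs]
#
#     class downLR(_LRScheduler):
#         """Decay the LR linearly from base_lr toward 0 over total_iters steps."""
#         def __init__(self, optimizer, total_iters, last_epoch=-1):
#             self.total_iters = total_iters
#             super().__init__(optimizer, last_epoch)
#
#         def get_lr(self):
#             return [base_lr * max(0.0, 1 - self.last_epoch / (self.total_iters + 1e-8))
#                     for base_lr in self.base_lrs]
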
def get_time_difference(start_time):
    """Compute the time elapsed from start_time to now."""
    elapsed_time = time.time() - start_time
    return timedelta(seconds=int(round(elapsed_time)))

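# Usage note: the return value is a stdlib datetime.timedelta, so it formats
# as H:MM:SS when printed, e.g.:
#
#     >>> get_time_difference(time.time() - 90)
#     datetime.timedelta(seconds=90)
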
def train(config, model, train_iter, dev_iter, test_iter):
    """Train the model and evaluate on the development and test sets."""
    start_time = time.time()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Warm up for the first half of training, then decay for the second half.
    # Both schedulers are stepped once per batch, so their horizons are
    # expressed in iterations rather than epochs.
    warmup_epoch = config.num_epochs // 2
    iter_per_epoch = len(train_iter)
    scheduler = downLR(optimizer, (config.num_epochs - warmup_epoch) * iter_per_epoch)
    warmup_scheduler = WarmUpLR(optimizer, warmup_epoch * iter_per_epoch)

    lr_list = np.zeros((config.num_epochs, 2))  # placeholder for LR logging; unused below
    dev_best_loss = float('inf')
    dev_best_acc = 0
    test_best_acc = 0
    total_batch = 0

    for epoch in range(config.num_epochs):
        # evaluate() leaves the model in eval mode at the end of every epoch,
        # so training mode must be restored here, not once before the loop.
        model.train()
        loss_total = 0
        print(f'Epoch [{epoch + 1}/{config.num_epochs}]')

        predictions, true_values = [], []

        for trains, labels in tqdm(train_iter):
            trains, labels = trains.to(config.device), labels.long().to(config.device)
            outputs = model(trains)
            loss = F.cross_entropy(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if epoch < warmup_epoch:
                warmup_scheduler.step()
            else:
                scheduler.step()

            total_batch += 1
            loss_total += loss.item()

            predictions.extend(torch.max(outputs.data, 1)[1].tolist())
            true_values.extend(labels.data.tolist())

        train_acc = get_accuracy(true_values, predictions)
        dev_acc, dev_loss = evaluate(config, model, dev_iter)
        test_acc, test_loss = evaluate(config, model, test_iter)

        # '*' marks an epoch where the dev loss improved.
        if dev_loss < dev_best_loss:
            dev_best_loss = dev_loss
            improvement_marker = '*'
        else:
            improvement_marker = ''

        # Report the test accuracy at the epoch with the best dev accuracy.
        if dev_acc > dev_best_acc:
            dev_best_acc = dev_acc
            test_best_acc = test_acc

        elapsed_time = get_time_difference(start_time)
        print((
            f'Iter: {total_batch:6}, Train Loss: {loss_total / len(train_iter):.2f}, '
            f'Train Acc: {train_acc:.2%}, Dev Loss: {dev_loss:.2f}, Dev Acc: {dev_acc:.2%}, '
            f'Test Loss: {test_loss:.2f}, Test Acc: {test_acc:.2%}, Time: {elapsed_time} {improvement_marker}'
        ))
        print(f'Best Dev Acc: {dev_best_acc:.2%}, Best Test Acc: {test_best_acc:.2%}')

    test(config, model, test_iter)

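# Shape of the resulting schedule (assuming WarmUpLR/downLR behave roughly as
# sketched near the imports): with num_epochs=10 and learning_rate=1e-3, the
# LR ramps from ~0 up to 1e-3 across epochs 1-5, then decays back toward 0
# across epochs 6-10, with one scheduler step per batch.
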
def test(config, model, test_iter):
    """Evaluate the model on the test set."""
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_confusion = evaluate(config, model, test_iter, test=True)
    print(f'Test Loss: {test_loss:.2f}, Test Acc: {test_acc:.2%}')
    print(test_confusion)
    elapsed_time = get_time_difference(start_time)
    print(f'Time usage: {elapsed_time}')

def evaluate(config, model, data_iter, test=False):
    """Evaluate the model on a given dataset."""
    model.eval()
    loss_total = 0
    predictions, true_values = [], []

    with torch.no_grad():
        for texts, labels in data_iter:
            texts, labels = texts.float().to(config.device), labels.long().to(config.device)
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss.item()

            predictions.extend(torch.max(outputs.data, 1)[1].tolist())
            true_values.extend(labels.data.tolist())

    acc = get_accuracy(true_values, predictions)

    # With test=True the confusion matrix is returned as a third element.
    if test:
        confusion = metrics.confusion_matrix(true_values, predictions)
        return acc, loss_total / len(data_iter), confusion

    return acc, loss_total / len(data_iter)

def get_accuracy(y_true, y_pred):
    """Calculate accuracy."""
    return metrics.accuracy_score(y_true, y_pred)
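
# A minimal, hypothetical usage sketch. The real config and model classes live
# elsewhere in this project; the attribute names below (device, learning_rate,
# num_epochs) are simply the ones train() actually reads, and SimpleNamespace
# stands in for the real config object.
#
#     from types import SimpleNamespace
#
#     config = SimpleNamespace(
#         device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
#         learning_rate=1e-3,
#         num_epochs=20,
#     )
#     train(config, model, train_iter, dev_iter, test_iter)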