當我們使用深度學習模型進行訓練時,數據的平衡非常重要。如果某些類別的數據比其他類別多很多,模型可能會偏向於頻繁出現的類別,導致對少數類別的預測效果不佳。為了解決這個問題,我們可以使用pytorch_metric_learning這個工具來幫助我們平衡每個batch中的類別數據。
在這篇教學中,我們將逐步介紹如何使用pytorch_metric_learning中的samplers來達到這個目的。我們也會提供一個範例,讓大家更清楚地了解如何應用這個方法。
內容目錄
詳細步驟
Step 1 : 引用samplers
from pytorch_metric_learning import samplers as SAMPLERS
Step2: 將Datasets 的labels特別抽取出來,等等調用sampler 要使用
labels = [label for _, label in train_dataset]
Step3: 調用sampler
per_cls_num = batch_size // number_of_class
sampler = SAMPLERS.MPerClassSampler(labels , per_cls_num, batch_size=None, length_before_new_iter=len(train_dataset))
Step4: 放入DataLoader
train_loader = DataLoader(
dataset=train_dataset,
sampler=sampler ,
)
完整範例: 以下是一個完成的MNIST 範例
from torch.utils.data import DataLoader
from pytorch_metric_learning import samplers
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 這裡的train_dataset與上面相同
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# 使用MPerClassSampler
labels = [label for _, label in train_dataset]
m_per_class_sampler = samplers.MPerClassSampler(labels, m=32, length_before_new_iter=len(train_dataset))
train_loader_sampler = DataLoader(train_dataset, batch_size=64, sampler=m_per_class_sampler)
train_loader = DataLoader(train_dataset, batch_size=64)
## show the effect of MPerClassSampler
for images, labels in train_loader_sampler:
print(labels)
break
## This is original dataloader
for images, labels in train_loader:
print(labels)
break
