如何使用PyTorch Metric Learning實現訓練數據的類別平衡:完整指南與範例

  • 當我們使用深度學習模型進行訓練時,數據的平衡非常重要。如果某些類別的數據比其他類別多很多,模型可能會偏向於頻繁出現的類別,導致對少數類別的預測效果不佳。為了解決這個問題,我們可以使用pytorch_metric_learning這個工具來幫助我們平衡每個batch中的類別數據。

  • 在這篇教學中,我們將逐步介紹如何使用pytorch_metric_learning中的samplers來達到這個目的。我們也會提供一個範例,讓大家更清楚地了解如何應用這個方法。

詳細步驟

Step 1 : 引用samplers

from pytorch_metric_learning import samplers as SAMPLERS

Step2: 將Datasets 的labels特別抽取出來,等等調用sampler 要使用

labels = [label for _, label in train_dataset]

Step3: 調用sampler

per_cls_num = batch_size // number_of_class
sampler = SAMPLERS.MPerClassSampler(labels , per_cls_num, batch_size=None, length_before_new_iter=len(train_dataset))

Step4: 放入DataLoader

train_loader = DataLoader(
dataset=train_dataset,        
sampler=sampler ,
)   

完整範例: 以下是一個完成的MNIST 範例

from torch.utils.data import DataLoader
from pytorch_metric_learning import samplers
from torchvision import datasets, transforms

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# 這裡的train_dataset與上面相同
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# 使用MPerClassSampler
labels = [label for _, label in train_dataset]
m_per_class_sampler = samplers.MPerClassSampler(labels, m=32, length_before_new_iter=len(train_dataset))

train_loader_sampler = DataLoader(train_dataset, batch_size=64, sampler=m_per_class_sampler)

train_loader = DataLoader(train_dataset, batch_size=64)

## show the effect of MPerClassSampler
for images, labels in train_loader_sampler:
print(labels)
break


## This is original dataloader
for images, labels in train_loader:
print(labels)
break

參考資料

0 0 votes
Article Rating
Subscribe
Notify of
guest

0 Comments
Inline Feedbacks
View all comments