torch DataLoader의 collate_fn에 대해 알아보자
약간 고급 코드로 들어가기 시작하면 DataLoader를 그냥 사용하는 경우가 드물어진다.
보통은 DataLoader의 collate_fn이라는 메서드를 본인들의 데이터세트에 맞춰 직접 구현해서 사용하는 경우가 많아지는데 collate_fn을 지금까지 딱히 주의깊게 본 적은 없었다.
나도 이제 슬슬 내 데이터를 직접 구현해서 추가하려고 하기 때문에 collate_fn의 동작 방식과 왜 사용하는지를 알고 있어야 할 것 같아서 정리해본다.
DataLoader(
train_dataset,
batch_size=config.train.batch_size,
collate_fn=PaddingCollate(),
shuffle=True,
num_workers=args.num_workers
)
Dataset, DataLoader
collate_fn을 공부하기 전에 먼저 torch의 Dataset과 DataLoader에 대해 간단하게 알아보자
torch로 구현된 nn.Module에 데이터를 input하기 위해서는 보통 torch의 utils.data의 Dataset과 DataLoader를 이용한다.
쉽게 생각하면 DataLoader는 Dataset을 이용해 정형화된 데이터 샘플을 받아서 batch_size만큼 쌓아올려 모델로 전달하는 역할이라고 생각하면 된다.
collate_fn
collate_fn은 이때 DataLoader가 여러개의 샘플을 batch단위로 묶을 때 사용하는 함수이다.
기본적으로 default로 동작하는 collate_fn은 List형태로 저장된 여러 샘플 텐서들을 stack해서 하나의 텐서로 만든다. 예를 들어서
dataset = [
torch.tensor([1, 2]), # 샘플1
torch.tensor([3, 4]), # 샘플2
torch.tensor([5, 6]) # 샘플3
] 와 같은 형태의 dataset을 collate_fn을 이용해서 batch_size = 3으로 묶으면 다음과 같이 동작한다.
batch = torch.stack([dataset[0], dataset[1], dataset[2]])
결과 batch는 tensor([[1, 2], [3, 4], [5, 6]])이 된다.
그렇다면 왜 굳이 collate_fn을 따로 구현해서 사용하는걸까?
보통 두 가지 이유가 있다.
1) 각 샘플들의 크기가 다른 경우
dataset = [
torch.tensor([1, 2]), # 샘플1
torch.tensor([3, 4,5]), # 샘플2
torch.tensor([6, 7]) # 샘플3
] 과 같은 경우가 있다면 각 샘플들의 size가 다르기 때문에 torch.stack()은 사용할 수 없게된다.
# RuntimeError: each element in list of batch should be of equal size
따라서 이러한 경우는 collate_fn을 구현해서 샘플들의 크기를 truncation하던지 padding하는 코드를 구현해줘야 한다.
2) 사용자 정의 class가 포함된 Dataset을 사용할 때
샘플들의 크기가 서로 같다고 해도 사용자가 직접 정의한 Dataset class를 사용하면 기본 collate_fn을 사용할 수 없다.
예를들어 다음 코드를 살펴보자
class Sample: # 사용자 정의 class
def __init__(self, data, label):
self.data = data
self.label = label
dataset = [ # 길이가 정형화된 샘플
Sample(torch.tensor([1.0, 2.0]), 0),
Sample(torch.tensor([3.0, 4.0]), 1)
]
loader = DataLoader(dataset, batch_size=2)
for batch in loader:
print(batch)
# TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists;
# found <class '__main__.Sample'>
길이가 정형화 되어 있는 dataset을 지정했음에도 Sample이라는 사용자 지정 class를 이용한 경우 디폴트 collate_fn으로는 TypeError를 반환하게된다. 이런 경우도 마찬가지로 사용자 정의 collate_fn을 구현해서 사용해야한다.
def custom_collate(batch):
data = torch.stack([item.data for item in batch])
labels = torch.tensor([item.label for item in batch])
return {"data": data, "label": labels}
loader = DataLoader(dataset,
batch_size=2,
collate_fn=custom_collate) # 사용자 정의 collate_fn
for batch in loader:
print(batch)
# {'data': tensor([[1., 2.],[3., 4.]]),
# 'label': tensor([0, 1])}