PyTorch의 IterableDataset을 사용해서 데이터 불러오기
PyTorch 1.2 이상부터 torch.utils.data
에서는 크게 map-style dataset (torch.utils.data.Dataset
) 과 iterable dataset (torch.utils.data.IterableDataset
) 의 두 종류의 데이터 클래스를 지원하고 있다. 데이터 사이즈가 클 때는 IterableDataset
을 사용하는 것이 좋은데, Dataset
과는 딜리 아직 개발되어야 할 기능이 더 필요한 클래스라서 사용할 때에 유의할 점이 있어 정리해보게 되었다.
Map-style Dataset
1.2 이하 버전에서 사용되던 map-style dataset은 memory에 모든 데이터를 업로드할 수 있을 때 사용하는 가장 일반적인 dataset type 이다. custom dataset class를 생성하고자 할 때 torch.utils.data.Dataset
을 상속받아 __len__
, __getitem__
을 구현하면 된다.
from torch.utils.data import Dataset |
Iterable Dataset
하지만 학습 데이터가 메모리에 다 올라가지 않는 경우가 발생할 수 있다. 이 문제를 해결할 수 있는 다양한 방법 중에 하나로, torch.utils.data.IterableDataset
을 사용하는 방법이 있다. Map-style Dataset과 비슷하게 torch.utils.data.IterableDataset
을 상속받아서 custom dataset class를 생성하고, __iter__
를 선언하면 된다.
from torch.utils.data import IterableDataset |
Dataset
이 batch data를 생성할 때 map_dataset[index]
를 사용한다면, IterableDataset
은 next(iterable_dataset)
을 사용한다. 이 때문에 DataLoader
를 통해 IterableDataset
을 불러와서 사용하게 되면 sampler
옵션의 사용이 어렵다. 그래서 random suffling 을 하고 싶다면 미리 데이터셋을 shuffling 한 이후에 불러오는 것이 좋다.
Going Parallel
PyTorch 공식문서에 따르면 IterableDataset
을 num_workers > 0
의 조건에서 사용할 때 특별히 다음을 유념할 것을 제안하고 있다.
When
num_workers > 0
, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers.get_worker_info()
, when called in a worker process, returns information about the worker. It can be used in either the dataset’s__iter__()
method or theDataLoader
‘sworker_init_fn
option to modify each copy’s behavior.
위의 문장을 이해하려면 num_workers
에 대한 이해와, num_workers > 0
일 때 IterDataset
에서 어떤 현상이 일어나는지 알아야한다.

num_workers
는 데이터셋을 불러올 때 사용할 subprocess의 개수이다. num_workers == 0
은 main process에서 데이터를 불러오고 모델에 pass하는 작업을 모두 수행한다는 뜻이며, num_workers == 2
는 subprocess 2개에서 데이터를 불러오고 main process에서는 subprocess에서 불러온 데이터를 model에 pass하는 역할만 담당한다. 따라서 num_workers > 0
일 때 data loading에서의 병목이 줄어들며 gpu utilization 을 100% 가까이 끌어올릴 수 있다.
그럼, num_workers > 0
일 때 어떤 현상이 발생하는지 살펴보자.
Map-Style Dataset
from torch.utils.data import DataLoader, Dataset, IterableDataset |
num_workers == 0
인 경우
loader = DataLoader(map_dataset, batch_size=4, num_workers=0) |
num_workers == 2
인 경우
loader = DataLoader(map_dataset, batch_size=4, num_workers=2) |
의도한대로, subprocess 별로 서로 다른 데이터를 불러오는 것을 알 수 있다.
Iterable Dataset
from torch.utils.data import DataLoader, Dataset, IterableDataset |
num_workers == 0
loader = DataLoader(iterable_dataset, batch_size=4, num_workers=0) |
num_workers == 2
loader = DataLoader(iterable_dataset, batch_size=4, num_workers=2) |
⚠️ worker 0과 worker 1에서 같은 데이터를 로딩하고 있다. 이 점이 공식문서에서 지적하고 있는 내용이다. 각 워커별로 같은 __iter__()
를 사용하기 때문에 같은 데이터를 로딩하게 된다. 이를 방지하기 위해서는 worker_init_fn
에서 직접 워커 별 데이터를 재분배 시켜줘야 한다.
def worker_init_fn(_): |
loader = DataLoader(iterable_dataset, batch_size=4, num_workers=2, worker_init_fn=worker_init_fn) |
worker_init_fn
을 통해 분배시켜준 결과 worker 0과 worker 1 에서 다른 데이터를 순차적으로 불러오는 것을 알 수 있다 🙂
Conclusions
IterableDataset
은 데이터가 메모리에 올라가지 않을만큼 클 때 사용하면 좋다.Dataset
과 다르게__iter__()
를 선언해서 데이터를 부른다.- 하지만 이 특징 때문에
Sampler
와 함께 사용할 수 없다. - 또한
num_workers > 0
인 세팅에서는 각 워커에서 다른 데이터를 불러올 수 있도록worker_init_fn
을 선언해야 한다.
- 하지만 이 특징 때문에