본문 바로가기

코드 정리/Pytorch 정리

5. PyTorch Dataset / DataLoader

Custom Data를 구성해야 하는 이유

  • PyTorch에서는 기본적으로 torchvision.datasets 모듈을 통해 여러 유명한 데이터셋(CIFAR-10, ImageNet, MNIST 등)을 바로 사용할 수 있습니다.
  • 하지만, torchvision.datasets에서 제공하지 않는 데이터를 사용하려면 Custom Dataset을 구성해야 한다.

Custom Dataset을 구성하면 얻는 이점

  • Custom Dataset을 사용하면 원하는 데이터 형식을 직접 정의하여 불러올 수 있다.
  • transform을 활용하여 전처리를 쉽게 적용 가능.
  • 데이터셋을 동적으로 로드하고 변형하기 위해

Custom Dataset이 필수 메서드로 가져야하는 것

  • ‘__init__’: Dataset 객체가 생성될 때 단 한번만 실행되는 메서드 입니다. 해당 메서드는 input data(image), label을 load하고, transform을 사용할 경우 transform을 초기화 합니다.
  • ‘__len__’: Dataset의 샘플 개수를 반환하는 코드를 작성합니다.
  • ‘__getitem__’: 주어진 index의 데이터를 반환하는 코드를 작성합니다.

 

아래는 PyTorch Documentation에서 제공하는 Datasets & DataLoaders 튜토리얼 부분을 가져와 각 줄마다 이해를 돕기 위해 주석을 처리한 코드입니다. 

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

 

Datasets & DataLoaders — PyTorch Tutorials 2.5.0+cu124 documentation

Note Click here to download the full example code Learn the Basics || Quickstart || Tensors || Datasets & DataLoaders || Transforms || Build Model || Autograd || Optimization || Save & Load Model Datasets & DataLoaders Created On: Feb 09, 2021 | Last Updat

pytorch.org

아래의 코드는 절대적인 코드가 아니며, 본인이 가진 데이터에 맞게 수정해야 하는 baseline 코드입니다.

필요한 라이브러리 Load

import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

Custom Dataset 정의

class CustomImageDataset(Dataset):
    # __init__에서 학습 데이터와 label을 초기화하고, 전처리도 있으면 전처리도 초기화합니다.
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # label 정보 초기화
        self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
        # img가 저장된 directory 경로 저장
        self.img_dir = img_dir
        # 입력 이미지에 적용할 전처리 정보 초기화
        self.transform = transform
        # label에 적용할 전처리 정보 초기화
        self.target_transform = target_transform

    # __len__에서 데이터셋의 샘플 개수를 반환하는 코드를 작성합니다.
    def __len__(self):
        # label의 길이만큼 반환 == Dataset의 길이
        return len(self.img_labels)

    # __getitem__에서 주어진 index에 해당하는 Data sample을 반환하는 코드를 작성합니다.
    def __getitem__(self, idx):
        # img_labels에서 idx번째 파일명을 가져오고 img_dir과 결합하여 이미지 파일 경로를 생성합니다.
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # img_path로부터 이미지를 로드
        image = read_image(img_path)
        # index번째 sample label 가져옵니다.
        label = self.img_labels.iloc[idx, 1]
        # 입력 이미지에 전처리 적용
        if self.transform:
            image = self.transform(image)
        # label에 전처리 적용
        if self.target_transform:
            label = self.target_transform(label)
        # index번째 image와 label 반환
        return image, label
    
# CustomImageDataset의 객체를 생성할때, annotations_file, img_dir의 정보를 전달해야 한다.
customDataset = CustomImageDataset('annotations_file의 경로', ', img_dir의 경로')

 

 

DataLoader를 사용해야 하는 이유

  • 데이터가 생성된 후, batch 형태로 만들기 위해 DataLoader를 사용합니다.
  • 매 epoch마다 Data를 shuffle하여 overfit을 방지합니다.
  • DataLoader에 데이터셋을 로드한 후, 필요에 따라 데이터셋을 순회(iterate) 가능합니다.

 

DataLoader 구성

# train_data를 넘겨주고, train_dataloader 객체를 생성
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# test_data를 넘겨주고, test_dataloader 객체를 생성
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

 

'코드 정리 > Pytorch 정리' 카테고리의 다른 글

6. PyTorch Model 구성  (0) 2025.01.27
4. PyTorch에서 제공하는 여러 함수  (1) 2025.01.23
3. Tensor Indexing, Slicing(텐서 인덱싱, 슬라이싱)  (0) 2025.01.22
2. Tensor 연산  (0) 2025.01.21
1. Tensor 생성  (1) 2025.01.20