Sangmun

Pytorch Custom Dataset 본문

네이버 AI 부스트캠프 4기

Pytorch Custom Dataset

상상2 2022. 9. 30. 10:30

https://wikidocs.net/57165

 

07. 커스텀 데이터셋(Custom Dataset)

앞 내용을 잠깐 복습해봅시다. 파이토치에서는 데이터셋을 좀 더 쉽게 다룰 수 있도록 유용한 도구로서 torch.utils.data.Dataset과 torch.utils.da ...

wikidocs.net

Pytorch를 사용하면서 필수로 알아야 하는 내용이며 이미 여러 군데 잘 정리된 자료가 있지만 필수로 알아야 되는 부분 위주로 정리를 해보고자 한다.

 

Pytorch에서는 torch.utils.data.Dataset을 상속받아 커스텀 데이터셋을 만들 수가 있다.

torch.utils.data.Dataset은 파이토치에서 제공하는 추상 클래스이며 Dataset을 상속받아 다음 메서드들을 오버라이드 할 수 있다.

 

# 커스텀 데이터셋을 만들때의 기본적이 뼈대구조
class CustomDataset(torch.utils.data.Dataset): 
  def __init__(self):
  # 데이터셋의 전처리를 해주는 부분
  
  def __len__(self):
  # 데이터셋의 길이. 즉, 총 샘플의 수를 적어주는 부분

  def __getitem__(self, idx):
  # 데이터셋에서 특정 1개의 샘플을 가져오는 함수

 

예시로 커스텀 데이터셋을 구현해보자

import torch
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Dataset 상속
class CustomDataset(Dataset): 
  def __init__(self):
    self.x_data = [[73, 80, 75],
                   [93, 88, 93],
                   [89, 91, 90],
                   [96, 98, 100],
                   [73, 66, 70]]
    self.y_data = [[152], [185], [180], [196], [142]]

  # 총 데이터의 개수를 리턴
  def __len__(self): 
    return len(self.x_data)

  # 인덱스를 입력받아 그에 맵핑되는 입출력 데이터를 파이토치의 Tensor 형태로 리턴
  def __getitem__(self, idx): 
    x = torch.FloatTensor(self.x_data[idx])
    y = torch.FloatTensor(self.y_data[idx])
    return x, y
dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
next(iter(dataloader))

# 출력결과 DataLoader에서 batch_size를 2로 설정하여 두개의 데이터 쌍이 출력된다.
[tensor([[73., 80., 75.],
         [89., 91., 90.]]), tensor([[152.],
         [180.]])]

 

DataLoader에는 batch_size와 shuffle외에도 다양한 유용한 옵션들이 존재한다.

https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader (공식블로그)

 

torch.utils.data — PyTorch 1.12 documentation

torch.utils.data At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for These options are configured by the constructor arguments of a DataLoader, which has si

pytorch.org

 

 

아래의 이미지는 DataLoader의 옵션을 한눈에 볼 수 있는 이미지이다.

DataLoader의 옵션들

 

https://subinium.github.io/pytorch-dataloader

 

[Pytorch] DataLoader parameter별 용도

pytorch reference 문서를 다 외우면 얼마나 편할까!!

subinium.github.io

수비니움님의 블로그는 정리를 잘해주셔서 따로 링크를 추가하였다.

'네이버 AI 부스트캠프 4기' 카테고리의 다른 글

[NLP] Seq2Seq  (0) 2022.10.04
[NLP] Transformer  (0) 2022.10.03
구글 colab vscode에서 접속하기(ngrok)  (0) 2022.09.29
Pytorch Project Template  (0) 2022.09.29
네이버 AI 부스트캠프 1주차 후기  (0) 2022.09.25
Comments