Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- pep8
- NaverAItech
- autoencoder
- docker
- datascience
- GCP
- Matplotlib
- 프로그래머스
- leetcode
- DeepLearning
- 네이버AItech
- FDS
- torchserve
- NLP
- github
- 백준
- PytorchLightning
- pytorch
- 코딩테스트
- GitHub Action
- GIT
- wandb
- 완전탐색
- vscode
- Kaggle
- Kubernetes
- 알고리즘
- rnn
- python
- FastAPI
Archives
- Today
- Total
Sangmun
Pytorch Custom Dataset 본문
Pytorch를 사용하면서 필수로 알아야 하는 내용이며 이미 여러 군데 잘 정리된 자료가 있지만 필수로 알아야 되는 부분 위주로 정리를 해보고자 한다.
Pytorch에서는 torch.utils.data.Dataset을 상속받아 커스텀 데이터셋을 만들 수가 있다.
torch.utils.data.Dataset은 파이토치에서 제공하는 추상 클래스이며 Dataset을 상속받아 다음 메서드들을 오버라이드 할 수 있다.
# 커스텀 데이터셋을 만들때의 기본적이 뼈대구조
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
# 데이터셋의 전처리를 해주는 부분
def __len__(self):
# 데이터셋의 길이. 즉, 총 샘플의 수를 적어주는 부분
def __getitem__(self, idx):
# 데이터셋에서 특정 1개의 샘플을 가져오는 함수
예시로 커스텀 데이터셋을 구현해보자
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Dataset 상속
class CustomDataset(Dataset):
def __init__(self):
self.x_data = [[73, 80, 75],
[93, 88, 93],
[89, 91, 90],
[96, 98, 100],
[73, 66, 70]]
self.y_data = [[152], [185], [180], [196], [142]]
# 총 데이터의 개수를 리턴
def __len__(self):
return len(self.x_data)
# 인덱스를 입력받아 그에 맵핑되는 입출력 데이터를 파이토치의 Tensor 형태로 리턴
def __getitem__(self, idx):
x = torch.FloatTensor(self.x_data[idx])
y = torch.FloatTensor(self.y_data[idx])
return x, y
dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
next(iter(dataloader))
# 출력결과 DataLoader에서 batch_size를 2로 설정하여 두개의 데이터 쌍이 출력된다.
[tensor([[73., 80., 75.],
[89., 91., 90.]]), tensor([[152.],
[180.]])]
DataLoader에는 batch_size와 shuffle외에도 다양한 유용한 옵션들이 존재한다.
https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader (공식블로그)
아래의 이미지는 DataLoader의 옵션을 한눈에 볼 수 있는 이미지이다.
https://subinium.github.io/pytorch-dataloader
수비니움님의 블로그는 정리를 잘해주셔서 따로 링크를 추가하였다.
'네이버 AI 부스트캠프 4기' 카테고리의 다른 글
[NLP] Seq2Seq (0) | 2022.10.04 |
---|---|
[NLP] Transformer (0) | 2022.10.03 |
구글 colab vscode에서 접속하기(ngrok) (0) | 2022.09.29 |
Pytorch Project Template (0) | 2022.09.29 |
네이버 AI 부스트캠프 1주차 후기 (0) | 2022.09.25 |
Comments