반응형

정리되지 않은 커스텀 데이터 불러오기 - PyTorch

from torch.utils.data import Dataset

class 클래스명(Dataset): # Dataset을 상속받아 DataLoader에서 배치 단위로 불러올 수 있게 해줌

	def __init__(self): # 데이터 세팅에 필요한 것을 미리 정의
    	...
    
    def __getitem__(self, index): # DataLoader를 통해 샘플이 요청되면 해당하는 샘플을 반환
    	...
    
    def __len__(self): # 크기를 반환
    	...

위 코드는 커스텀 데이터를 불러오는 가장 기본적인 형태이다.

 

예를 들어, 32x32 크기인 RGB 컬러 이미지 100장과 그에 대한 라벨링 작업이 되어 있고 넘파이 배열로 정리가 되어있다고 가정하자.

train_images = np.random.randint(256, size=(100, 32, 32, 3))/255
train_labels = np.random.randint(2, size=(100, 1))

class TensorData(Dataset):

	def __init__(self, x_data, y_data):
    	self.x_data = torch.FloatTensor(x_data)
	self.x_data = self.x_data.permute(0, 3, 1, 2)        
    	self.y_data = torch.FloatTensor(y_data)
        self.len = self.y_data.shape[0]
       
	def __getitem__(self, index):
    	return self.x_data[index], self.y_data[index]
       
	def __len__(self):
    	return self.len