인공지능 개발자 수다(유튜브 바로가기) 자세히보기

Deep Learning/Pytorch

Torchvision(토치비전) 사용법

Suda_777 2024. 2. 14. 03:19

목차

    반응형

    1. Torchvision(토치비전)은 언제 사용?

    • 컴퓨터 비전 프로젝트에서 사용함.

    2. Torchvision(토치비전)의 주요 기능 요약

    1. 데이터셋 접근 및 사용: torchvision은 MNIST, CIFAR-10, ImageNet 등 다양한 사전 정의된 데이터셋을 제공합니다. 이를 통해 쉽게 데이터를 로드하고 실험할 수 있습니다.
    2. 데이터 변환(Transformation): 이미지 데이터를 전처리하거나 증강하기 위한 다양한 변환 기능을 제공합니다. 예를 들어, 이미지의 크기를 조정하거나, 회전, 뒤집기 등의 작업을 쉽게 수행할 수 있습니다.
    3. 모델: 사전 훈련된 다양한 모델을 제공하여, 이미지 분류, 객체 탐지, 세그멘테이션 등 다양한 비전 태스크에 활용할 수 있습니다. ResNet, VGG, AlexNet 등의 유명한 모델을 쉽게 불러와 사용할 수 있습니다.

    3. 사전 제공 데이터셋

    3.1. 사전 제공 데이터셋 리스트

    torchvision 라이브러리에서 제공하는 몇 가지 사전 정의된 데이터셋은 다음과 같습니다:

    MNIST 손으로 쓴 숫자 데이터셋. 기계 학습 분야에서 널리 사용됩니다.
    CIFAR-10 10개 카테고리의 60000개 32x32 컬러 이미지를 포함합니다.
    CIFAR-100 CIFAR-10과 유사하지만, 100개 카테고리를 포함합니다.
    ImageNet 다양한 카테고리의 백만 단위의 레이블이 있는 이미지 데이터셋입니다.
    COCO 이미지 인식, 분할 및 객체 탐지를 위한 데이터셋입니다.
    VOC 이미지 분류, 객체 탐지 등을 위한 데이터셋입니다.

     

    3.2. 데이터셋 불러오는 예시코드

    데이터셋을 불러올 때는 `transforms`를 이용해 데이터를 변환해 주는 내용을 정의해 주면 됨.

    1. `datasets` 에서 데이터셋을 불러옴
    2. `transforms`로 데이터를 변형
    import torch
    from torchvision import datasets, transforms
    
    # 데이터셋을 불러오기 위한 변환 정의
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    # 훈련 데이터셋 불러오기
    trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    
    # 테스트 데이터셋 불러오기
    testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    # 데이터 로더 생성
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
    
    print("MNIST 데이터셋이 성공적으로 불러와졌습니다.")

     

    3.3. 커스텀 데이터셋 불러오는 예제 코드

    커스텀하게 데이터셋을 만들 때 에는

    1. 커스텀 데이터셋 정의 (클래스)
    2. `transforms`를 이용해 데이터를 변환해 주는 내용을 정의해 주면 됨.
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    import os
    from PIL import Image
    
    class CustomDataset(Dataset):
        def __init__(self, image_paths, transform=None):
            self.image_paths = image_paths
            self.transform = transform
    
        def __len__(self):
            return len(self.image_paths)
    
        def __getitem__(self, idx):
            image_path = self.image_paths[idx]
            image = Image.open(image_path)
            if self.transform:
                image = self.transform(image)
            return image
    
    # 예제 사용
    image_paths = ['./data/img1.jpg', './data/img2.jpg'] # 실제 이미지 경로로 대체
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])
    dataset = CustomDataset(image_paths=image_paths, transform=transform)
    
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

     

    4. 전처리

    4.1. 주요 전처리 함수

    torchvision.transforms 모듈에서 사용할 수 있는 주요 이미지 변형 함수들을 요약한 것입니다:

    Resize 이미지의 크기를 조정합니다.
    CenterCrop 이미지의 중앙을 기준으로 정사각형으로 잘라냅니다.
    RandomCrop 이미지에서 무작위로 선택한 위치를 기준으로 정사각형으로 잘라냅니다.
    RandomHorizontalFlip 주어진 확률로 이미지를 수평으로 뒤집습니다.
    RandomVerticalFlip 주어진 확률로 이미지를 수직으로 뒤집습니다.
    ToTensor PIL 이미지나 NumPy ndarray를 PyTorch 텐서로 변환합니다.
    Normalize 채널별로 평균과 표준편차를 사용하여 이미지를 정규화합니다.
    ColorJitter 이미지의 밝기, 대비, 포화도, 색조를 무작위로 변경합니다.
    RandomRotation 주어진 각도 범위 내에서 이미지를 무작위로 회전시킵니다.
    RandomResizedCrop 원본 이미지에서 무작위 크기와 비율로 잘라내어 주어진 크기로 리사이징합니다.
    Grayscale 이미지를 회색조로 변환합니다.
    Compose 여러 변형(transforms)을 하나로 결합하여 순차적으로 적용합니다.

     

    코드 예시

    • Compose를 이용해 여러 변형을 하나로 묶어줌
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

     

     

    4.2. 데이터 커스텀 변형 

    torchvision에서 제공하지 않는 특정 변형을 수행하려면 커스텀 변형 클래스를 작성할 수 있습니다. 이를 위해 Python의 __call__ 메서드를 사용하여 클래스를 정의합니다. 이 메서드는 객체를 함수처럼 호출할 수 있게 해줍니다.

    from torchvision import transforms
    from PIL import ImageOps, ImageEnhance
    
    class AdjustBrightness:
        def __init__(self, factor):
            self.factor = factor
    
        def __call__(self, img):
            enhancer = ImageEnhance.Brightness(img)
            img_enhanced = enhancer.enhance(self.factor)
            return img_enhanced
    
    # 사용 예시
    transform = transforms.Compose([
        AdjustBrightness(1.5),
        transforms.ToTensor()
    ])

     

    5. 사전 훈련된 모델 사용

    5.1. 사전 훈련된 모델

     

    PyTorch의 `torchvision.models` 모듈은 다양한 사전 훈련된 모델을 제공합니다. 이러한 모델들은 대규모 이미지 데이터셋(예: ImageNet)에서 훈련되었으며, 다양한 컴퓨터 비전 태스크에 즉시 사용할 수 있습니다. 사전 훈련된 모델을 사용하는 것은 다음과 같은 장점이 있습니다:

    • 높은 정확도: 대규모 데이터셋에서 훈련된 모델들은 높은 성능을 제공합니다.
    • 시간 절약: 처음부터 모델을 훈련시키는 대신, 사전 훈련된 모델을 사용하여 시간을 절약할 수 있습니다.
    • 전이 학습: 사전 훈련된 모델을 출발점으로 사용하여 특정 태스크에 맞게 미세 조정(fine-tuning)함으로써 새로운 작업에 쉽게 적용할 수 있습니다.

    사전학습 모델

    alexnet AlexNet은 깊은 합성곱 신경망(CNN) 구조 중 하나로, 이미지 분류 태스크에서 널리 사용됩니다.
    vgg VGG 모델은 깊이에 따라 여러 버전(VGG11, VGG13, VGG16, VGG19)이 있으며, 이미지 분류에서 높은 정확도를 보여줍니다.
    resnet ResNet은 깊은 네트워크에서 발생할 수 있는 소실된 기울기 문제를 해결하기 위해 잔차 연결(residual connections)을 도입한 모델입니다.
    inception Inception 모델(또는 GoogLeNet)은 병렬로 배열된 여러 크기의 컨볼루션 레이어를 통해 이미지의 다양한 스케일을 효율적으로 학습할 수 있도록 설계되었습니다.
    densenet DenseNet은 각 레이어가 이전 모든 레이어와 연결되는 밀집 연결(dense connections) 구조를 가지고 있어, 효율적인 그래디언트 흐름을 가능하게 합니다.
    mobilenet_v2 MobileNetV2는 경량화된 네트워크로, 모바일이나 임베디드 장치에서 고성능을 유지하면서도 효율적으로 동작하기 위해 설계되었습니다.
    efficientnet EfficientNet은 스케일링 방식을 통해 다양한 크기에서 효율적으로 작동하도록 최적화된 모델로, 높은 정확도와 효율성을 자랑합니다.

     

    mobilenet_v3 MobileNetV3은 경량화와 효율성에 초점을 맞춘 모델로, 최신 컴퓨터 비전 태스크에 적합합니다.
    regnet RegNet은 시스템적인 접근 방식을 통해 설계된 네트워크 구조로, 다양한 작업에 대해 효율적인 성능을 제공합니다.
    vit Vision Transformer(ViT)는 이미지 처리를 위해 Transformer 구조를 사용하는 모델로, 최근 주목받고 있는 접근 방식입니다.

     

    사전학습 모델 사용 예시

    import torchvision.models as models
    
    # 사전 훈련된 ResNet-18 모델 불러오기
    resnet18 = models.resnet18(pretrained=True)
    
    # 모델을 평가 모드로 설정
    resnet18.eval()

     

    5.2. 파인 튜닝하는 방법

    파인 튜닝은 사전 훈련된 모델을 기반으로 추가 학습을 수행하여 새로운 태스크에 적용하는 과정입니다. 주로 다음 단계를 포함합니다:

    1. 사전 훈련된 모델 불러오기: torchvision.models에서 제공하는 사전 훈련된 모델을 불러옵니다.
    2. 출력 레이어 변경: 대부분의 경우, 원본 모델의 출력 레이어를 새로운 태스크의 클래스 수에 맞게 변경해야 합니다.
    3. 모델 파라미터 고정: 필요에 따라 사전 훈련된 모델의 일부 또는 전체 파라미터를 고정(freeze)하여 추가 학습 중에 업데이트되지 않도록 할 수 있습니다.
    4. 추가 학습 수행: 새로운 데이터셋에 대해 모델을 학습시킵니다. 이때, 고정하지 않은 파라미터만 업데이트됩니다.
    import torchvision.models as models
    import torch.nn as nn
    
    # 사전 훈련된 모델 불러오기
    model = models.resnet18(pretrained=True)
    
    # 모든 파라미터를 고정
    for param in model.parameters():
        param.requires_grad = False
    
    # 마지막 Fully Connected 레이어 변경
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)  # 예를 들어, 10개 클래스가 있는 경우
    
    # 이제 model.fc는 학습될 수 있지만, 나머지 네트워크는 고정됩니다.

     

    코드 설명

    1. num_ftrs = model.fc.in_features:
      • model.fc는 ResNet 모델의 마지막 FC 레이어를 가리킵니다.
      • in_features 속성은 이 레이어로 들어오는 입력 특성의 수, 즉 레이어의 입력 크기를 나타냅니다.
      • 이 값을 num_ftrs 변수에 저장하여, 새로운 FC 레이어를 생성할 때 입력 크기로 사용합니다.
    2. model.fc = nn.Linear(num_ftrs, 10):
      • nn.Linear는 PyTorch에서 제공하는 Fully Connected 레이어(또는 선형 레이어)를 생성하는 클래스입니다.
      • num_ftrs는 앞서 저장한 레이어의 입력 크기로, 이전 레이어의 출력과 일치해야 합니다.
      • 10은 새로운 FC 레이어의 출력 크기로 설정됩니다. 이는 모델이 최종적으로 분류해야 하는 클래스의 수를 의미합니다. 예를 들어, 10개의 다른 클래스를 분류하는 문제에서는 출력 크기를 10으로 설정합니다.
      • 결과적으로, 이 코드는 모델의 마지막 레이어를 새로운 데이터셋에 맞게 조정하여, 사전 훈련된 네트워크의 나머지 부분은 그대로 유지하면서 새로운 태스크에 적합하게 파인 튜닝할 수 있도록 합니다.

    모델 저장

    # 모델의 state_dict 저장
    torch.save(model.state_dict(), 'model_state_dict.pth')

     

    모델 불러오기

    model = ...  # 모델 구조를 동일하게 재정의
    
    # 모델의 state_dict 로드
    model.load_state_dict(torch.load('model_state_dict.pth'))
    
    # 모델을 평가 모드로 설정
    model.eval()

     

    6. 객체 탐지 및 세그멘테이션

    torchvision에는 아래와 같은 모델도 사용할 수 있습니다.

    • 객체 탐지를 위한 모델 사용하기 (예: Faster R-CNN, SSD)
    • 세그멘테이션을 위한 모델 사용하기 (예: FCN, DeepLab)
    • 커스텀 데이터셋에 적용

    7. 고급 기능

    • 비디오 및 3D 데이터셋 처리

     

    객체 탐지 및 세그먼테이션, 비디오 및 3D 데이터셋 처리는 다음 페이지에 나눠서. 작성하도록 하겠습니다.

    반응형