본문 바로가기
  • AI 개발자가 될래요
Python

[Python] Segmentation 분야에서 클래스 별 레이블이 필요할 때 - Label One-Hot Encoding / python 코드

by 꿀개 2023. 2. 2.

Segmentation 분야에서 딥러닝 모델을 학습시킬 때, 

레이블에는 주로 아래 사진과 같이 한 장의 이미지에

픽셀 별로 Segmentation된 이미지가 들어간다.

 

Segmentation 분야에서 주로 사용되는 레이블 예시

 

하지만 모델을 고치거나 Loss Function을 변형하려 하면

각각의 클래스들로만 구성된 레이어들이 필요할 때가 있다.

나는 Fairness Learning을 위해 DRO를 사용하거나, Focal Loss Function을 사용하기 위해 필요했다.

 

예를 들어 클래스가 왼쪽 눈, 오른쪽 눈, ... , 목 으로 총 19개 있다면

왼쪽 눈만 있는 레이어, 오른쪽 눈만 있는 레이어, ... , 목만 있는 레이어 해서

총 19개의 레이어가 필요한 것이다.

 

사진을 보고 예시를 들자면 아래와 같은 레이어들이 필요한 것이다.

 

 

 

Label One-Hot Encoding

이는 레이블에 one-hot 인코딩을 적용하여 해결할 수 있다.

레이블을 one-hot encoding 시키는 python 코드는 아래와 같다.

 

import torch
from typing import Optional

def label_to_one_hot_label(
        labels: torch.Tensor,
        num_classes: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        eps: float = 1e-6,
        ignore_index=255,
) -> torch.Tensor:

    shape = labels.shape      # torch.Size([8, 473, 473])
    
    # one hot : (B, C=ignore_index+1, H, W)
    one_hot = torch.zeros((shape[0], ignore_index + 1) + shape[1:], dtype=dtype).cuda()     # torch.Size([8, 256, 473, 473]) 

    # labels : (B, H, W)
    # labels.unsqueeze(1) : (B, C=1, H, W)
    # one_hot : (B, C=ignore_index+1, H, W)
    one_hot = one_hot.scatter_(1, labels.unsqueeze(1), 1)      # torch.Size([8, 256, 473, 473])

    # ret : (B, C=num_classes, H, W)
    ret = torch.split(one_hot, [num_classes, ignore_index + 1 - num_classes], dim=1)[0]  # torch.Size([8, 19, 473, 473])

    return ret

 

 

참고 링크: https://gaussian37.github.io/dl-concept-focal_loss/