본문 바로가기
  • AI 개발자가 될래요
Deep Learning

timm 사용법

by 꿀개 2023. 6. 7.

timm 사용법

 

요즘 모델 경량화 문제 때문에, 여러가지의 backbone network를 사용하여 가장 최적의 backbone은 무엇인지 실험해 보고 있다.

 

원래는 network의 구조를 코드상으로 파악하고, feature map을 따오는 방식으로 실험을 진행하였지만, 시간이 오래 걸리고 해석하기 어렵다는 문제점이 있었다.

 

그러던 중 huggingface에서 작성한 "timm" 라이브러리를 알게 되었다.

 

Timm 이란?

pytorch로 구현된 여러가지 이미지 모델들을 라이브러리화 시킨 것이다. 장점은 쉽게 여러 network들을 학습, 테스트 할 수 있으며, pretrained weight도 제공해 주기 때문에 backbone nerwork로의 활용성도 높다.

 

어떻게 사용하느냐?

깃헙 링크는 다음과 같다.

https://github.com/huggingface/pytorch-image-models

 

GitHub - huggingface/pytorch-image-models: PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, E

PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, EfficientNetV2, NFNet, Vision Transformer, MixNet, MobileNet-V3/V2, RegNet, DPN, CSPNet, and more - GitHub - hugg...

github.com

 

설치

pip install timm

 

모델 리스트 확인

timm으로 구현된 모델 리스트를 확인하고 싶다면 아래의 코드를 작성해 출력해봐야한다.

pretrained weight가 있는 모델의 리스트를 출력하는 코드이다.

import timm

print(timm.list_models(pretrained=True))

 

여기에 옵션을 더 추가하여, 특정 네트워크를 찾고 싶다면 다음과 같이 하면 된다. 

필자는 mobilenet 관련 네트워크 리스트를 원했기 때문에, *mobile* 이라는 옵션을 넣어줬다.

앞 뒤에 어떤 문자가 와도 상관 없이, mobile 이라는 문자열이 포함되어 있으면 출력하라는 의미이다.

import timm

print(timm.list_models('*mobile*', pretrained=True))

 

mobilenet을 이용해 구현된 모델 리스트가 출력된다. (그러면 2023년 6월 기준 출력 결과)

['mobilenetv2_050', 'mobilenetv2_100', 'mobilenetv2_110d', 'mobilenetv2_120d', 'mobilenetv2_140', 'mobilenetv3_large_100', 'mobilenetv3_large_100_miil', 'mobilenetv3_large_100_miil_in21k', 'mobilenetv3_rw', 
'mobilenetv3_small_050', 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilevit_s', 'mobilevit_xs', 'mobilevit_xxs', 'mobilevitv2_050', 'mobilevitv2_075', 'mobilevitv2_100', 'mobilevitv2_125', 'mobilevitv2_150', 
'mobilevitv2_150_384_in22ft1k', 'mobilevitv2_150_in22ft1k', 'mobilevitv2_175', 'mobilevitv2_175_384_in22ft1k', 'mobilevitv2_175_in22ft1k', 'mobilevitv2_200', 'mobilevitv2_200_384_in22ft1k', 'mobilevitv2_200_in22ft1k', 
'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']

 

Test

이제 test 를 해보자.

임의의 입력 x를 넣었을 때, mobilenetv3_small_050 모델의 output shape을 출력하는 코드이다.

import timm

# 임의의 입력
x = torch.randn(8, 3, 256, 256)

# 모델 선언
model = timm.create_model('mobilenetv3_small_050', pretrained=True)

# 추론
result = model(x)

# shape 출력
print(result.shape)

 

결과는 다음과 같다.

torch.Size([8, 1000])

해당 모델은 classifier 모듈이 붙어있고, default로 설정된 class의 개수가 1000개 이기 때문에, [8, 1000]의 shape을 가진다.

8은 batch size를 의미한다. 만약 input shape이 (1, 3, 256, 256) 이었다면 [1, 1000] 이 나온다.

 

Backbone Network으로의 사용

timm으로 구현된 모델을 backbone network으로 사용하려면, classifier 모듈을 제거해야 한다. 모듈을 제거하는 방법 중에는 여러가지가 있지만, 가장 쉬운 방법은 다음과 같다.

모델을 선언할 때 'features_only' 옵션을 True로 하는 것이다.

import timm

# 임의의 입력
x = torch.randn(8, 3, 256, 256)

# 모델 선언
model = timm.create_model('mobilenetv3_small_050', pretrained=True, features_only=True)

# 추론
result = model(x)

# shape 출력
for r in result:
	print(r.shape)

 

그러면 모델 output은 마지막 5개 layer가 리스트 형태로 return된다.

torch.Size([8, 16, 128, 128])
torch.Size([8, 8, 64, 64])
torch.Size([8, 16, 32, 32])
torch.Size([8, 24, 16, 16])
torch.Size([8, 288, 8, 8])

이 output을 이용하여 원하는 네트워크의 backbone으로 사용하면 된다.

 

timm 덕분에 backbone network를 변경하는 실험이 매우 간단해졌다. Thank you!