본문 바로가기
  • AI 개발자가 될래요
Python

[Python/Pytorch] 네트워크 파라미터 수 계산법 / #params 계산법

by 꿀개 2023. 11. 17.

Python, Pytorch 네트워크 파라미터 수 계산법 / # params 계산법

 

PyTorch에서 네트워크 모델의 파라미터 개수를 세는 것은 비교적 간단하다.

 

1. 총 파라미터 수 계산하기 

모델의 모든 파라미터(학습 가능한 파라미터와 학습 불가능한 파라미터 포함)의 개수를 계산한다.

 

2. 학습 가능한 파라미터 수 계산하기

모델의 학습 가능한 파라미터(즉, 가중치)의 개수만을 계산한다.

 

이를 위한 간단한 함수는 다음과 같다.

 

import torch.nn as nn

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_params, 'Trainable': trainable_params}

# 예시 모델
model = nn.Sequential(
    nn.Linear(10, 50),
    nn.ReLU(),
    nn.Linear(50, 1)
)

params = count_parameters(model)
print(f"Total parameters: {params['Total']}")
print(f"Trainable parameters: {params['Trainable']}")

 

위 코드에서 nn.Sequential로 간단한 모델을 구성하고,

count_parameters 함수를 사용하여 해당 모델의 총 파라미터 수와 학습 가능한 파라미터 수를 계산한다.

 

  • p.numel(): 주어진 텐서의 요소(element) 수를 반환.
  • p.requires_grad: 해당 파라미터가 학습 과정에서 업데이트되어야 하는지 여부를 나타냄.

 

이 방법으로 모든 종류의 PyTorch 모델에 대한 파라미터 수를 계산할 수 있다.

 

나의 모델의 경우

 

현재 연구중인 모델인 경우는 68.14 M 파라미터가 있는 것이다.