Python, Pytorch 네트워크 파라미터 수 계산법 / # params 계산법
PyTorch에서 네트워크 모델의 파라미터 개수를 세는 것은 비교적 간단하다.
1. 총 파라미터 수 계산하기
모델의 모든 파라미터(학습 가능한 파라미터와 학습 불가능한 파라미터 포함)의 개수를 계산한다.
2. 학습 가능한 파라미터 수 계산하기
모델의 학습 가능한 파라미터(즉, 가중치)의 개수만을 계산한다.
이를 위한 간단한 함수는 다음과 같다.
import torch.nn as nn
def count_parameters(model):
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {'Total': total_params, 'Trainable': trainable_params}
# 예시 모델
model = nn.Sequential(
nn.Linear(10, 50),
nn.ReLU(),
nn.Linear(50, 1)
)
params = count_parameters(model)
print(f"Total parameters: {params['Total']}")
print(f"Trainable parameters: {params['Trainable']}")
위 코드에서 nn.Sequential로 간단한 모델을 구성하고,
count_parameters 함수를 사용하여 해당 모델의 총 파라미터 수와 학습 가능한 파라미터 수를 계산한다.
- p.numel(): 주어진 텐서의 요소(element) 수를 반환.
- p.requires_grad: 해당 파라미터가 학습 과정에서 업데이트되어야 하는지 여부를 나타냄.
이 방법으로 모든 종류의 PyTorch 모델에 대한 파라미터 수를 계산할 수 있다.
현재 연구중인 모델인 경우는 68.14 M 파라미터가 있는 것이다.
'Python' 카테고리의 다른 글
[Python] float 타입을 다룰 때 주의점 (0) | 2024.01.04 |
---|---|
[Python] "효율적 개발로 이끄는 파이썬 실천 기술" 책 중요 내용 정리 - 1편 (0) | 2024.01.04 |
[Python/코테공부] 날짜 비교하기 (0) | 2023.11.06 |
[Python/코테공부] 특별한 이차원 배열 2 (0) | 2023.11.06 |
[Python/코테 공부] 1로 만들기 (0) | 2023.11.03 |