본문 바로가기
  • AI 개발자가 될래요
컴퓨터비전

컬러 영상의 히스토그램 매칭(Histogram Matching) / 파이썬 코드

by 꿀개 2023. 5. 2.

히스토그램 매칭이란?

이미지의 색 분포를 목표 이미지의 색 분포와 비슷하게 하는 알고리즘이다.

 

히스토그램 매칭의 원리

타겟 이미지(T)의 색 발생 빈도에 따라 원본 이미지(I)의 색 발생 빈도를 바꾸는 것.

T에서 가장 많이 나타나는 색을 R에서 가장 빈도수가 높은 색으로 변경한다고 생각하면 되겠다.

 

이 과정에서 LUT(Look Up Table)을 작성하는데, 그 작성 원리는 다음과 같다.

# 히스토그램 매칭
lut = np.zeros((256), dtype=np.uint8)
for i in range(256):
    minDiff = float('inf')
    for j in range(256):
        # srcCdf: 소스 이미지의 누적 히스토그램
        # dstCdf: 타겟 이미지의 누적 히스토그램
        diff = abs(srcCdf[i] - dstCdf[j])
        if diff < minDiff:
            minDiff = diff
            lut[i] = j

 

1. 영상의 픽셀 값은 0~255 사이 값만 얻을 수 있으므로 사이즈가 256이고 0으로 채워진 array를 선언한다.

2. 소스 이미지와 타겟 이미지의 누적 히스토그램을 비교하면서, 가장 차이가 큰 부분의 인덱스로 LUT 값을 설정한다.

이 것은 히스토그램 매칭을 효율적으로 할 수 있도록 고안된 알고리즘이다.

 

예시

입력 이미지는 아래와 같이 주로 초록색으로 이루어져있고,

input

 

타겟 이미지는 아래와 같이 주로 파란색으로 이루어져 있다면,

target

 

히스토그램 매칭의 결과는 다음과 같다.

output

 

히스토그램 비교

B, G, R 각각의 채널에 대해 히스토그램 매칭이 잘 되었는지 비교해보았다.

채널 B에 대한 히스토그램

왼쪽: input / 중간: target / 오른쪽: output 에 대한 히스토그램이다.

 

채널 B에 대한 히스토그램 비교

 

output의 히스토그램이 input에 비해 target과 더 비슷해져 매칭이 잘 된 것을 볼 수 있다.

 

마찬가지로 채널 G, R 에 대해서도 히스토그램을 살펴보면

채널 G에 대한 히스토그램 비교
채널 R에 대한 히스토그램 비교

 

히스토그램 매칭이 잘 된 것을 볼 수 있다.

 

코드

히스토그램 매칭을 수행하는 전체 파이썬 코드는 아래와 같다.

import cv2
import numpy as np
import matplotlib.pyplot as plt
import itertools


def histogram_matching(srcImage, srcHist, dstHist):
    srcHist = (255 * (srcHist / max(srcHist))).astype(np.uint8)
    dstHist = (255 * (dstHist / max(dstHist))).astype(np.uint8)

    srcCdf = np.cumsum(srcHist)
    dstCdf = np.cumsum(dstHist)

    srcCdf = srcCdf * dstHist.max() / srcCdf.max()
    dstCdf = dstCdf * srcHist.max() / dstCdf.max()

    lut = np.zeros((256), dtype=np.uint8)
    for i in range(256):
        minDiff = float('inf')
        for j in range(256):
            diff = abs(float(srcCdf[i] - dstCdf[j]))
            if diff < minDiff:
                minDiff = diff
                lut[i] = j

    resultImage = cv2.LUT(srcImage, lut)

    return resultImage


srcImage = cv2.imread("./forest.png")
dstImage = cv2.imread("./sea.png")

# split
img_b, img_g, img_r = cv2.split(srcImage)
target_b, target_g, target_r = cv2.split(dstImage)


histSize = 256
range_ = [0, 256]
histRange = range_
uniform = True
accumulate = False
srcHist = cv2.calcHist([srcImage], [0], None, [histSize], histRange, uniform, accumulate)
dstHist = cv2.calcHist([dstImage], [0], None, [histSize], histRange, uniform, accumulate)
g_srcHist = cv2.calcHist([srcImage], [1], None, [histSize], histRange, uniform, accumulate)
g_dstHist = cv2.calcHist([dstImage], [1], None, [histSize], histRange, uniform, accumulate)
r_srcHist = cv2.calcHist([srcImage], [2], None, [histSize], histRange, uniform, accumulate)
r_dstHist = cv2.calcHist([dstImage], [2], None, [histSize], histRange, uniform, accumulate)


b_resultImage = histogram_matching(srcImage[:,:,0], srcHist, dstHist)
g_resultImage = histogram_matching(srcImage[:,:,1], g_srcHist, g_dstHist)
r_resultImage = histogram_matching(srcImage[:,:,2], r_srcHist, r_dstHist)

# merge
resultChannels = [b_resultImage, g_resultImage, r_resultImage]
resultImage = cv2.merge(resultChannels)

# calc result hist
b_resultHist = cv2.calcHist([b_resultImage], [0], None, [histSize], histRange, uniform, accumulate)
g_resultHist = cv2.calcHist([g_resultImage], [0], None, [histSize], histRange, uniform, accumulate)
r_resultHist = cv2.calcHist([r_resultImage], [0], None, [histSize], histRange, uniform, accumulate)

b_resultHist[0] -= len(zero_index[0])
g_resultHist[0] -= len(zero_index[0])
r_resultHist[0] -= len(zero_index[0])

# cv2.imwrite("gray_test_result.png", resultImage)

cv2.imshow('original', srcImage)
cv2.imshow('target', dstImage)
cv2.imshow('out', resultImage)

plt.figure("b_result", figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.plot(srcHist)
plt.subplot(1, 3, 2)
plt.plot(dstHist)
plt.subplot(1, 3, 3)
plt.plot(b_resultHist)

plt.figure("g_result", figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.plot(g_srcHist)
plt.subplot(1, 3, 2)
plt.plot(g_dstHist)
plt.subplot(1, 3, 3)
plt.plot(g_resultHist)

plt.figure("r_result", figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.plot(r_srcHist)
plt.subplot(1, 3, 2)
plt.plot(r_dstHist)
plt.subplot(1, 3, 3)
plt.plot(r_resultHist)
plt.show()

cv2.waitKey()