Article Image
read

Introduction

이전 포스팅에서 Class Activation Map(이하 CAM)을 사용해 CNN이 이미지 분류를 의도한대로 해내는지 확인해봤다. 정리한 내용을 TensorFlow-KR Facebook Group에 공유했었는데, 댓글을 통해 Grad-CAM에 대해 소개받았다. 개념을 제안한 논문과 구현 코드를 살펴보고 그 내용을 간략히 정리해본다.

CAM의 한계

CAM은 CNN의 내부를 열어볼 수 있는 아주 유용한 도구다. 기존 모델의 convolution layer 뒤에 붙은 fully connected layer(이하 FC)를 Global Averager Pooling (이하 GAP)로 교체하고 fine-tuning함으로써, 뉴럴 네트워크가 이미지의 어떤 부분을 보고 특정 레이블로 판단을 내리는지에 대해 알 수 있었다. 기존 접근 방식들에 비해 end-to-end로 한번에 (single forward pass) CAM을 얻을 수 있다는 점에서 매우 매력적인 방법이다. 논문에서의 테스트 결과에 따르면 모델의 후반부 아키텍쳐를 GAP로 바꾸더라도 심각한 성능저하가 일어나지 않았다고 하니 그 유용성이 더 크다고 할 수 있다. 내가 가진 대선주자 얼굴 데이터셋에서는 적은 데이터셋 사이즈 등의 이유로 FC를 사용한 모델은 아예 converge하지 않거나 성능이 GAP에 비해 매우 떨어졌다.

CAM Architecture

이와 같은 장점에도 불구하고 CAM은 그 방식으로 인한 태생적인 단점을 가지고 있다. FC를 GAP로 대체해야 한다는 점, GAP 직전의 Conv layer만 쓸 수 있다는 점, 또 GAP 뒷단에 있는 Dense Layer의 weight 정보 ()가 필요하므로 fine-tuning이나 re-training의 과정을 거쳐야하는 점이다. (애초에 GAP 구조를 가진 네트워크는 재학습할 필요가 없다.) 이러한 문제로 인해서 오브젝트 디텍션 외에 Visual Question Answer(VQA)나 Captioning처럼 다양한 목적을 수행하는 CNN에 CAM을 적용하기 어렵다.

즉, CAM은 오브젝트 디텍션 용도로 만들어진 뉴럴넷을 들여다보는 특수한 도구로 볼 수 있다. 이를 더 일반화해서 적용해볼 수는 없을까?


Grad-CAM: Generalized version of CAM

CAM의 수식 유도 과정에서 GAP는 반드시 필요하다. 마지막 레이어의 k번째 피쳐맵의 GAP값을 로 정의한 후, 이를 이용해 소프트맥스 레이어의 인풋값인 를 표현한 다음, 다시 수식을 정리하여 CAM인 를 구하기 때문이다. Grad-CAM은 어떤 방법을 사용하기에, FC를 사용한 기존 모델 아키텍쳐에서도, 마지막 conv layer가 아닌 다른 레이어에서도 Grad-CAM을 구할 수 있는걸까?

Grad-CAM uses the gradient information flowing into the last convolutional layer of the CNN to understand the importance of each neuron for a decision of interest.

대상 conv layer를 일단 CAM과 같이 마지막 레이어로 두고 생각해보자. Grad-CAM은 마지막 conv 레이어에 들어오는 gradient 정보를 사용해서 타겟 레이블에 대해 각 뉴런이 가지는 중요도를 이해한다고 한다.

우리가 구하고자 하는 Grad-CAM을 이라 할 때 , 는 피쳐맵의 너비와 높이, 는 타겟 클래스를 의미한다.

는 소프트맥스 레이어의 인풋값이고, 는 마지막 conv layer의 k번째 피쳐맵을 의미한다. (CAM 논문에서는 로, 로 표현했다.)

앞에서 마지막 conv layer로 들어오는 gradient information을 사용한다고 했으므로, back propagation이 일어날때 마지막 conv layer로 돌아오는 gradient값을 생각해보자. 논문에서는 소프트맥스 인풋인 에 대해 가지는 gradient를 라 표현하고, 이를 GAP 방식으로 계산한 를 정의한다.

눈문에서는 가 피쳐맵 A로부터의 딥 네트워크 다운스트림의 partial linearization을 의미하며, 타겟 클래스 에 대해 k번째 피쳐맵이 가지는 importance를 의미한다고 표현하고 있다. (partial linearization이 정확히 어떤 의미인지 정확히 모르겠다.)

위에서 구한 에 피쳐맵 A^k를 곱하고 다 더한 후 (weighted combination) ReLU(음수를 모두 0으로)를 씌워 를 구한다.


CAM vs. Grad-CAM

CAM과 비교해서 생각해보자. 혼란을 줄이기 위해 용어를 Grad-CAM 기준으로 바꾼다.

CAM:

Grad-CAM:

Relu를 제외하면 차이는 에서 기인한다. 가장 근본적인 차이는 뭘까? CAM의 는 GAP로 구조를 변경한 후 fine-tuning을 통해 학습한 softmax 레이어의 weight 횡벡터다. 그에 반해 는 소프트맥스 함수의 인풋값의 피쳐맵 k에 대한 편미분값을 구한후 이를 GAP 방식으로 출력한 결과다.

는 피쳐맵 k 이후에 GAP가 오건, FC가 오건 관계없이 다시 conv layer로 넘어오는 gradient를 받아다가 GAP 방식으로 구한 값이라고 정리할 수 있겠다. 즉, CAM과 달리 아키텍쳐 변경이나 재학습을 하지 않고도 를 구할 수 있다.

논문에 첨부된 Appendix A를 통해 저자들은 가 결국 와 같음을 수식 유도를 통해 주장한다. Appendix A를 따라가보자.

Appendix A

먼저 CAM 아키텍쳐를 다시 떠올려보면, 이는 fully-convolutional CNN + GAP + Softmax Layer(FC + softmax)로 구성되어 있다. 마지막 conv layer의 k번째 피쳐맵을 로 두고, 가로 세로가 각각 , 로 인덱스되어있다고 하자. 그러면 번째 피쳐맵의 위치에 있는 액티베이션을 의미한다.

이때 CAM은 에 GAP를 사용해서 를 출력한다. 즉 이를 수식으로 쓰면…

(이를 CAM 논문에서는 로 나눠주지 않고 그냥 합을 구하는 식으로 표현했다. )

그리고 소프트맥스 레이어의 인풋은 를 사용해서 다음과 같이 구한다.

여기서 는 클래스 에 맵핑되는 Softmax 레이어의 weight 횡벡터다.

클래스 에 대한 스코어 에 대한 gradient는 가 되는데, 체인룰에 의해 분모와 분자를 로 나눠서 표현할 수 있다.

에 대해서 편미분을 구하면 가 된다.(차수가 1이니까 상수가 남는다.)

를 식 에 대입하면,

에서 를 유도할 수 있다. 이를 식의 좌변에 다시 대입하면,

피쳐맵의 모든 픽셀에 대해 식의 양변을 합한다.

는 모든 픽셀의 갯수 와 같은데다 , 와는 관계없으므로 다시 다음과 같이 정리할 수 있다. 우변의 상수인 Z도 앞으로 뺀다.

마지막으로 양변의 상수 Z를 제거하면

We can see that up to a proportionality constant() that is normalized out during visualization, the expression for is identical to used by Grad-CAM. Thus Grad-CAM is a generalization of CAM to arbitrary CNN-based architectures, while maintaining the computational efficiency of CAM.

조금 복잡한 수식 유도 과정을 거쳐서, 결국 소프트맥스 레이어의 웨이트 횡벡터가 소프트맥스 함수 인풋이 k번째 피쳐맵의 각 액티베이션에 대한 partial derivative의 총 합임을 확인했다.

내가 이해한 요지는 다음과 같다. Grad-CAM의 는 소프트맥스 인풋의 피쳐맵에 대한 gradient의 픽셀에 대한 GAP로 정의한다. CAM의 아키텍쳐를 가정했을 때, 소프트맥스 레이어의 weight 횡벡터 를 피쳐맵의 픽셀 수(상수)만큼 나눈 값으로 다시 표현할 수 있다. 두 개념이 같다는 것을 전제하면, 굳이 를 재학습할 필요 없이, 기존 모델이 가진 만으로도 충분히 CAM(Grad-CAM)을 얻을 수 있다. 따라서 Grad-CAM을 CAM의 일반화된 버전이라 칭할 수 있다. 또 gradient를 통해 를 구하기만 하면 되므로, 굳이 마지막 conv layer에만 한정되지 않는 장점도 있겠다고 볼 수 있다.

CAM vs. Grad-CAM

Grad-CAM: Implementation

귀여운 고양이와 ResNet50 pretrained 모델로 구현 방법을 살펴보자.

귀여운 고양이

논문에 정리된 구현 방식을 tf.keras를 사용해 구현하면 아래와 같다.

def generate_grad_cam(img_tensor, model, class_index, activation_layer):
    """
    params:
    -------
    img_tensor: resnet50 모델의 이미지 전처리를 통한 image tensor
    model: pretrained resnet50 모델 (include_top=True)
    class_index: 이미지넷 정답 레이블
    activation_layer: 시각화하려는 레이어 이름

    return:
    grad_cam: grad_cam 히트맵
    """
    inp = model.input
    y_c = model.output.op.inputs[0][0, class_index]
    A_k = model.get_layer(activation_layer).output
    
    ## 이미지 텐서를 입력해서
    ## 해당 액티베이션 레이어의 아웃풋(a_k)과
    ## 소프트맥스 함수 인풋의 a_k에 대한 gradient를 구한다.
    get_output = K.function([inp], [A_k, K.gradients(y_c, A_k)[0], model.output])
    [conv_output, grad_val, model_output] = get_output([img_tensor])

    ## 배치 사이즈가 1이므로 배치 차원을 없앤다.
    conv_output = conv_output[0]
    grad_val = grad_val[0]
    
    ## 구한 gradient를 픽셀 가로세로로 평균내서 a^c_k를 구한다.
    weights = np.mean(grad_val, axis=(0, 1))
    
    ## 추출한 conv_output에 weight를 곱하고 합하여 grad_cam을 얻는다.
    grad_cam = np.zeros(dtype=np.float32, shape=conv_output.shape[0:2])
    for k, w in enumerate(weights):
        grad_cam += w * conv_output[:, :, k]
    
    grad_cam = cv2.resize(grad_cam, (224, 224))

    ## ReLU를 씌워 음수를 0으로 만든다.
    grad_cam = np.maximum(grad_cam, 0)

    grad_cam = grad_cam / grad_cam.max()
    return grad_cam

CAM 구현도 위와 비슷한 형식으로 다시 작성해보면 다음과 같다.

def generate_cam(img_tensor, model, class_index, activation_layer):
    """
    params:
    -------
    img_tensor: resnet50 모델의 이미지 전처리를 통한 image tensor
    model: pretrained resnet50 모델 (include_top=True)
    class_index: 이미지넷 정답 레이블
    activation_layer: 시각화하려는 레이어 이름

    return:
    cam: cam 히트맵
    """
    inp = model.input
    A_k = model.get_layer(activation_layer).output
    outp = model.layers[-1].output

    ## 이미지 텐서를 입력해서
    ## 해당 액티베이션 레이어의 아웃풋(a_k)과
    ## 소프트맥스 함수 인풋의 a_k에 대한 gradient를 구한다.
    get_output = K.function([inp], [A_k, outp])
    [conv_output, predictions] = get_output([img_tensor])
    
    ## 배치 사이즈가 1이므로 배치 차원을 없앤다.
    conv_output = conv_output[0]
    
    ## 마지막 소프트맥스 레이어의 웨이트 매트릭스에서
    ## 지정한 레이블에 해당하는 횡벡터만 가져온다.
    weights = model.layers[-1].get_weights()[0][:, class_index]
    
    ## 추출한 conv_output에 weight를 곱하고 합하여 cam을 얻는다.
    cam = np.zeros(dtype=np.float32, shape=conv_output.shape[0:2])
    for k, w in enumerate(weights):
        cam += w * conv_output[:, :, k]
    
    cam = cv2.resize(cam, (224, 224))
    cam = cam / cam.max()
    
    return cam

두 코드를 비교해보면 차이가 나는 부분은 conv_output에 곱해줄 weight 횡벡터를 구하는 방식과 히트맵에 Relu를 적용하는 부분이다.

CAM vs. Grad-CAM (Implementation)

Pretrained ResNet50에 적용했을 때 그 차이를 살펴보자. 대상 activation layer는 마지막에서 4번째 레이어인 ‘activation_49’로 설정했다.

ResNet50 - CAM & Grad-CAM

먼저 시각적으로 큰 차이가 있어보이지는 않는다. 다만 히트맵에 대한 ReLU 연산으로 인해 Grad-CAM의 중심부와 외곽 영역의 색깔 차이가 더 두드러져 보인다.

그럼 굳이 Grad-CAM을 안써도 되는게 아닐까?

1. Grad-CAM은 GAP이 없어도 쓸 수 있다.

논문에서 주장하다시피 Grad-CAM은 CAM의 일반화된 버전이다. CAM은 FC 대신 Global Average Pooling을 사용하는 모델 아키텍쳐에서만 적용 가능하다. 위에서 구현한 코드를 모델만 VGG16(FC가 있는 원래 버전)으로, 대상 액티베이션 레이어를 ‘block5_conv3’로만 바꿔도 CAM은 아래와 같은 오류를 뱉는다.

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-18-5c4032a6fb73> in <module>()
----> 1 cam = generate_cam(img_tensor, vgg_model, 277, 'block5_conv3')

<ipython-input-16-720a814eabd3> in generate_cam(img_tensor, model, class_index, activation_layer)
     32     cam = np.zeros(dtype=np.float32, shape=conv_output.shape[0:2])
     33     for k, w in enumerate(weights):
---> 34         cam += w * conv_output[:, :, k]
     35 
     36     cam = cv2.resize(cam, (224, 224))

IndexError: index 512 is out of bounds for axis 2 with size 512

weights를 돌면서 곱해주는 부분에서 에러가 나는데, weights.shape를 찍어보면 4096이 찍힌다. 그 이유는 pretrained된 VGG16은 피쳐맵을 flatten한 후 FC로 넘기는 과정에서 element 수를 피쳐맵 개수가 아닌 4096개로 올리기 때문이다. 이로 인해 마지막 소프트맥스 레이어의 웨이트 매트릭스는 4096 * 1000 사이즈가 되고, 여기서 우리가 관심있는 레이블의 횡벡터를 가져오면 이 길이가 4096이 된다. 그런데 우리의 피쳐맵은 512개에 불과하므로 당연히 계산이 안되는 상황이 발생한다.

그러나 Grad-CAM은 전혀 문제가 없다. Grad-CAM 함수에서 구하는 grad_val은 A_k에 대한 y_c의 gradient이므로, 크기는 A_k와 동일하다. 그러므로 피쳐맵별로 평균을 내면 당연히 A_k의 피쳐맵 수를 길이로 하는 벡터를 얻게 되고, 히트맵 계산이 가능해진다.

VGG16 - Grad-CAM ResNet과 같은 285 클래스(Egyptian_cat)에 대한 VGG16의 Grad-CAM. 모델에 따라 분류결과가 다르고 그 작동방식도 달라지는 것으로 보인다.

그럼 위에서 사용한 ResNet50는? ResNet50의 구조를 summary()함수로 찍어보면 소프트맥스 레이어(fc1000) 전에 AveragePooling2D로 피쳐맵별 평균을 구한 다음, 이를 flatten_1레이어에서 일렬로 세우는 걸 확인할 수 있다. Global Average Pooling에서 하는 연산과 동일하다. (즉 추가적으로 train해야할 파라미터가 없다.)

ResNet50 모델 아키텍쳐 끝부분

_________________________________
activation_49 (Activation)       (None, 7, 7, 2048)    0           add_16[0][0]                     
____________________________________________________________________________________________________
avg_pool (AveragePooling2D)      (None, 1, 1, 2048)    0           activation_49[0][0]              
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 2048)          0           avg_pool[0][0]                   
____________________________________________________________________________________________________
fc1000 (Dense)                   (None, 1000)          2049000     flatten_1[0][0]                  
====================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
_________________________________________________________________

2. Grad-CAM은 마지막 conv layer에 한정되지 않는다.

위에서는 activation_49레이어를 대상으로 CAM과 Grad-CAM을 뽑아봤다. 그렇다면 이전의 레이어들의 뉴런은 이미지의 어떤 부분에 반응했을까?

안타깝게도 CAM은 activation_48에 대한 결과를 계산하지 못한다. 그 이유 역시 계산하는 차원이 맞지 않아서다. activation_48의 shape은 (7, 7, 512)인데 반해, CAM의 weights는 2048길이의 벡터다.

1번 이유와 마찬가지로 Grad-CAM은 gradient를 weights로 사용하기 때문에 차원이 맞지 않는 문제로부터 자유롭다.

ResNet50의 인풋 레이어부터 activation_49까지 Grad-CAM을 찍어보는 것도 가능하다.

ResNet50 - Grad-CAM

activation_45 정도부터 고양이 얼굴 영역을 잡기 시작한다.

Guided Grad-CAM

Grad-CAM을 통해 이미지의 어떤 영역이 특정 클래스에 반응하는지 히트맵으로 시각화할 수 있었다. 아쉬운 점은 히트맵은 이미지의 특성을 상세하게 보기 어렵다는 점인데, 논문에서는 ‘Guided Grad-CAM’을 제안함으로써 해결했다.

‘Guided Grad-CAM’은 영역을 잡아내는 ‘Grad-CAM’에 명확한 이미지 윤곽을 리턴하는 ‘guided backpropagation’의 장점을 접목한 개념이다.

Grad-CAM Architecture

Guided Backpropagation

‘Guided BackPropagation’은 2014년에 나온 ‘Striving for Simplicity: The All Convolutional Net’에서 제안한 시각화 방식이다.

We call this method guided backpropagation, because it adds an additional guidance signal from the higher layers to usual backpropagation. This prevents backward flow of negative gradients, corresponding to the neurons which decrease the activation of the higher layer unit we aim to visualize.

이전에 제안된 ‘deconvnet’은 MaxPooling 레이어를 거슬러 올라가기 위해 최댓값을 가진 부분 피쳐맵이 어디인지 기억하는 ‘switches’를 사용했다. 그런데 이는 추가해야하는 기능일 뿐 아니라, 시각화한 결과가 깔끔하지 않고 알아보기 어렵다. 이를 보완하기 위해 논문에서 제안한 개념이 ‘guided backpropagation’이다.

본 논문에서는 MaxPooling layer를 stride가 2 이상인 conv layer로 대체하기 때문에 guided backpropagation은 ‘switches’에 의존할 필요도 없고, 뒷 레이어의 gradient 뿐만 아니라 현 레이어의 relu 값도 같이 사용하므로 시각화 결과가 더 깔끔하다는 장점이 있다.

논문에 소개된 구현 방식의 차이를 잠시 살펴보자.

guided backpropagation 이미지 출처: ‘Striving for Simplicity: The All Convolutional Net’

먼저 deconvnet이나 guided backpropagation은 a)도식처럼 모델의 backward pass를 이용해서 이미지 을 재구축한다. 이를 위해서는 backpropagation, deconvnet, guided backpropagation의 방법을 통해서 backward pass를 구현할 수 있다. 수식을 말로 풀어보면 다음과 같다.

0) activation

간단한 relu 연산으로, 번째 레이어의 피쳐맵 의 음수를 모두 0으로 바꿔 그 다음 레이어의 피쳐맵 를 만든다.

1) backpropagation

backward pass 과정에서 본 레이어의 재구축 피쳐맵은 뒷 레이어의 gradient()에서 본 레이어의 ReLU 양수인 부분 만 살린다.

2) deconvnet

deconvnet은 본 레이어의 ReLU는 신경쓰지 않는다. 뒷 레이어에서 내려오는 gradient 중 양수만을 취한다.

3) guided backpropagation

마지막으로 guided backpropagation은 backpropagation과 deconvnet을 결합한 형태로, 둘 중 하나라도 음수라면 해당 값을 0으로 처리한다. 이러한 이유로 인해 위 그림처럼 guided backpropagation은 다른 방법에 비해 더 적은 수의 gradient를 이미지 재구축에 사용하게 되고, 이러한 점이 더 깔끔한 이미지로 이어지는 듯 하다.

lasange로 구현된 코드를 참조해 tf.keras를 이용해 다시 구현해보았다.

모델은 VGG16을 사용하였으며, 먼저 정답 확률이 가장 높았던 green_snake(55)를 입력해 saliency map을 그려본다.

1) backpropagation

deconvnet

2) deconvnet

deconvnet

3) guided backpropagation

guided backpropagation

논문에서 설명한 바와 같이 guided backpropagation이 다른 두 방법에 비해 더 깔끔한 뱀 형태를 도출해냈다. (엄밀히 따지면 deconvnet을 구현함에 있어 ‘switches’를 사용하지 않고 위 구현 방식만을 사용했기에, 완전무결한 비교라고는 할 수 없다.)

guided backpropagation 결과를 보면 그럴듯 해보이기는 하지만, 사실 의문이 많이 남는다. 4번째 그림인 neg. saliency는 뱀 클래스 분류 결과와 반대의 상관성을 가진 이미지라 할 수 있다. 즉 직관적으로 뱀 클래스와 전혀 다른 이미지가 표현되어야 할 듯 하지만, 뱀과 크게 다르지 않아보인다.

또한 대상 클래스에 Egyption Cat 클래스(285)를 넣어도 그려지는 맵이 거의 흡사하다.

guided backpropagation - Egyption Cat

Guided Grad-CAM

Grad-CAM 논문에서는 ‘guided backpropagation’ 등의 시도에 대해 매우 뚜렷한 이미지를 생성하기는 하나 클래스별로 특징을 잡아내지는 못한다고 평가한다.

Despite producing fine-grained visualizations, these methods are not class-discriminative. Visualizations with respect to different classes are nearly identical.

Grad-CAM이 만들어내는 히트맵은 위치를 정확하게 포착하지만 영역을 잡아낼 뿐, 뚜렷한 이미지를 생성하지는 않는다. 그렇다면 윤곽을 guided backpropagation으로 잡고, 클래스에 특화된 부분만 Grad-CAM으로 마스킹하면 좋지 않을까? 이것이 Guided Grad-CAM의 접근 방식이다.

코드로는 어떻게 구현해야 할까? 일반적인 backpropagation에서는 앞서 수식에서 살펴본 바와 같이 뒷 레이어의 gradient 중 현 레이어의 ReLU에서 살아남은 영역만을 리턴한다. 원래 모델의 Gradient를 계산하는 로직을 우리가 원하는 guided backpropagation으로 갈아끼우는 과정을 거치면 된다.

1) GuidedBackProp 함수를 registry에 등록한다.

from tensorflow.python.framework import ops

def register_gradient():
    if "GuidedBackProp" not in ops._gradient_registry._registry:
        @ops.RegisterGradient("GuidedBackProp")
        def _GuidedBackProp(op, grad):
            dtype = op.inputs[0].dtype
            return grad * tf.cast(grad > 0., dtype) * tf.cast(op.inputs[0] > 0., dtype)

register_gradient()

2) 새로운 텐서플로우 그래프를 만든다음, 기존의 ‘Relu’의 gradient를 1)에서 등록한 GuidedBackProp으로 업데이트하고, 새로운 모델을 만들어서 리턴한다.

from tensorflow.python.framework import ops

def modify_backprop(model, name):
    g = tf.get_default_graph()
    with g.gradient_override_map({'Relu': name}):
        
        layer_dict = [layer for layer in model.layers[1:]
                     if hasattr(layer, 'activation')]
        
        for layer in layer_dict:
            if layer.activation == activations.relu:
                layer.activation = tf.nn.relu
                
        new_model = vgg16.VGG16()
        
    return new_model

guided_model = modify_backprop(model, 'GuidedBackProp')

3) guided backpropagation 이미지에 Grad-CAM을 적용한다.

## guided backpropagation (55번 클래스로 지정했으나 별 차이는 없다)
saliency_fn = compile_saliency_function(guided_model, 55)
saliency = saliency_fn([img])[0]

grad_cam = generate_grad_cam(img, model, 55, 'block5_conv3')
guided_grad_cam = saliency[0] * grad_cam[..., np.newaxis]

Grad-CAM에서 중요하지 않은 부분들은 ReLU 연산에 의해 모두 0 처리되었으므로, guided backpropagation에서 이 영역에 해당하는 부분들은 모두 가려지게 된다.

4) deprocess 과정을 거쳐 시각화한다.

def deprocess_image(x):
    
    ## 평균을 0으로, 표준편차를 0.1로 하도록 normalize한다.
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1

    ## [0, 1]사이로 클리핑한다.
    x += 0.5
    x = np.clip(x, 0, 1)

    ## 255를 곱해 RGB 값으로 바꾼다.
    x *= 255
    
    ## [0, 255]사이로 클리핑한 후 정수로 바꾼다. 
    x = np.clip(x, 0, 255).astype('uint8')
    return x

plt.imshow(deprocess_image(guided_grad_cam))

이번에는 Green_snake(55)와 Egyption cat(285)에 따라 Guided Grad-CAM이 얼마나 class-discriminative한지 비교해보자.

Guided Grad-CAM

밋밋했던 guided backpropagation보다 훨씬 그럴듯한 그림이 나온다.

Grad-CAM: 대선주자 얼굴 위치 추적기

이론과 코드 테스트를 마쳤으니 실제 데이터셋에 적용해볼 차례다. 지난번에 CAM을 공부할때는 가장 쓰기 쉬웠던 VGG16을 fine-tuning해서 써보았는데, 일전에 여러 pretrained된 모델의 성능을 비교해봤을 때는 ResNet의 성능이 validation accuracy가 가장 좋았다. 또한 최근 테스트셋을 다시 확보하여 테스트했을때도 CAM 논문 방식으로 수정한 VGG16(테스트 정확도 80%)보다 ResNet50(82%)로 소폭 더 좋았다.

ResNet50로 fine-tuning한 모델의 시각화 결과를 살펴보자.

1) CAM + bounding box

ResNet50: CAM

2) Grad-CAM + bounding box

ResNet50: Grad-CAM

3) CAM, Grad-CAM, Guided Grad-CAM

ResNet50: Grad-CAM

앞에서 비교해봤던 다른 사례들과 마찬가지로, ReLU가 적용된 Grad-CAM이 CAM에 비해 히트맵이 더 뚜렷하게 드러났다. 특히 안철수, 심상성 후보의 히트맵은 ReLU의 적용여부가 상당히 큰 차이를 만들었다. 히트맵의 깔끔함과 정성적 판단(얼굴을 잘 잡으냐)으로 보면 Grad-CAM이 더 나은 듯 하다. Guided Grad-CAM을 통해서 추출한 fine-grained 이미지도 후보의 얼굴 윤곽을 잘 잡아낸 것으로 보인다.

Grad-CAM을 통한 개선 방향 도출

위와 같이 시각화를 시도함으로써 현재 모델이 가진 단점을 파악할 수 있다. 내가 학습한 모델은 대체로 후보의 얼굴을 보고 판단을 내리지만, 넥타이와 같은 다른 요소가 미치는 영향도 일부 존재한다. 특히 모델은 얼굴보다도 빨강(자유한국당 대표 색상) 넥타이를 보고 홍준표 후보의 레이블을 선택했다. 유승민 후보의 경우 역시 노란색 넥타이에 모델이 약간 반응했음을 알 수 있다. 이를 보정하기 위해서는 각 후보들의 평상복 차림을 데이터셋에 추가함으로써 모델이 의상에 의존적이지 않도록 만들 필요가 있겠다.


Outro

CAM 논문에 비해 Grad-CAM 논문은 소화하기 버거운 점이 많았다. Grad-CAM 수식 유도과정에서 gradient가 끼어드는데다 guided backpropagation이 등장하면서 예전에 잠깐 읽고 넘어갔던 All Convolutional Net 논문까지 다시 읽었다. 이해해야 하는 양이 많아 버거웠지만 deconvnet, guided backpropagation, CAM, Grad-CAM으로 이어지는 일련의 학문 발달의 과정을 일부라도 맛볼 수 있어 재밌는 공부였다.

본 블로그 포스팅에 사용한 일부 데이터셋 및 코드는 아래 깃헙에 올려두었다. https://github.com/junkwhinger/grad_cam

Reference

https://arxiv.org/abs/1610.02391
https://arxiv.org/pdf/1412.6806.pdf
https://github.com/jacobgil/keras-grad-cam
https://github.com/insikk/Grad-CAM-tensorflow
https://github.com/Lasagne/Recipes/blob/master/examples/Saliency%20Maps%20and%20Guided%20Backpropagation.ipynb

Blog Logo

Junsik Whang


Published

Image

jsideas

a novice's journey into data science

Back to Overview