Attention please

[논문 리뷰] U-Net++: A Nested U-Net Architecture for Medical Image Segmentation(2018) 본문

논문 리뷰/Image Segmentation

[논문 리뷰] U-Net++: A Nested U-Net Architecture for Medical Image Segmentation(2018)

Seongmin.C 2023. 7. 5. 11:16

이번에 리뷰할 논문은 UNet++: A Nested U-Net Architecture for Medical Image Segmentation 입니다. 

 

https://paperswithcode.com/paper/unet-a-nested-u-net-architecture-for-medical

 

Papers with Code - UNet++: A Nested U-Net Architecture for Medical Image Segmentation

🏆 SOTA for Video Polyp Segmentation on SUN-SEG-Easy (Dice metric)

paperswithcode.com

 

본 논문에서 제안하는 모델은 UNet++ 으로, 2015년에 제안되었던 UNet의 업그레이드 버전이라 생각하시면 됩니다. 

2023.07.02 - [논문 리뷰/Image Segmentation] - [논문 리뷰] U-Net: Convolutional Networks for BiomedicalImage Segmentation(2015)

 

[논문 리뷰] U-Net: Convolutional Networks for BiomedicalImage Segmentation(2015)

이번에 리뷰할 논문은 U-Net: Convolutional Networks for BiomedicalImage Segmentation 입니다. https://paperswithcode.com/paper/u-net-convolutional-networks-for-biomedical Papers with Code - U-Net: Convolutional Networks for Biomedical Image Segme

smcho1201.tistory.com

 

UNet 모델에 대해 간단히 설명을 드리자면 해당 architecture는 크게 Contracting Path(수축 경로) 와 Expanding Path(확장 경로) 으로 구성됩니다. 

 

 

Contracting Path는 UNet의 Encoder에 해당하는 부분이며 Input Image 에서 pixel 단위의 차원을 축소하여 semantic한 information을 추출합니다. 해당 부분은 일반적인 CNN Architecture와 구조가 동일합니다.

 

Expanding Path는 UNet의 Decoder에 해당하는 부분이며 Contracting Path에서 추출한 semantic한 정보와 local한 정보를 concatenation 하여 up-sampling을 진행합니다. 원래 이미지와 pixel 단위로 비교해야하기 때문에 원래 이미지의 size만큼 해상도를 복원해주어야 하기 때문이죠. 해상도를 복원하기 위해 진행하는 Up-sampling 과정은 많은 정보 손실을 가져옵니다. UNet 에서는 이 문제를 해결하기 위해 Contracting Path의 feature map과 Expanding Path의 feature map을 concate 해줍니다. 

 

하지만 이러한 Concatenation 역시 손실된 정보를 잘 보완한다고 하기에는 부족합니다. 이런 connection 부분을 좀 더 dense하고 nested 하게 연결하자라는 아이디어에서 UNet ++ 가 나오게 됩니다.

 

 

 

 

 

 

 

Abstract

본 논문에서 제안하는 UNet ++ 모델은 medical image segmentation 분야에서 더 강력한 성능을 보여준다고 합니다. 이 Architecture는 UNet 모델과 마찬가지로 Encoder-Decoder 형태를 가지고 있으며 서로 nested, dense skip pathways 로 연결됩니다. 본 논문에서 새롭게 디자인된 skip pathways는 encoder와 decoder의 feature map들 사이에 semantic gap을 줄이는 것을 목표로 합니다.

 

보통 지금까지의 모델들은 Encoder를 통해 semantic information을 추출하여 바로 Decoder의 feature map과 연결하였기 때문에 의미적으로 비슷하지 않은 상태에서 fusion되어 왔습니다. 하지만 본 논문에서는 결합될 feature map이 의미적으로 비슷할 때 optimizer는 더 쉬운 학습 작업을 처리하게 될 것이라고 주장합니다.

 

 

 

 

 

 

 

Introduction

image segmentation 분야에서 SOTA 모델들은 보통 UNet, FCN 과 같이 encoder-decoder Architecture 형태를 가집니다. 이런 network의 유사점으로는 skip connections 를 들 수 있는데 이는 각 layer의 출력을 모델의 뒷부분으로 직접 전달하여 정보가 network를 통과하면서 손실되는 것을 막아주기 위해 Encoder와 Decoder를 연결해주는 것을 말합니다. 즉, 깊은 semantic 정보와 얕은 low level의 정보를 결합하는 것이죠.

 

natural 한 일반적인 이미지의 경우 그리 높은 수준의 accuracy를 요하지는 않습니다. 하지만 의료분야에서는 아주 작은 분할 오류도 사용자에게 큰 피해가 될 수 있기 때문에 target object에 대해 미세한 세부 사항을 효과적으로 복구할 수 있는 architecture를 만들어야합니다.

 

본 논문에서는 medical image를 더 정확하게 segmentation 하기 위해 UNet++ 를 제안합니다. 이 모델의 경우 nested 와 dense skip connection을 사용합니다. 이는 기존 UNet에서 사용되었던 skip connection과는 차이가 있습니다. Encoder에서 Decoder로 직접 feature map을 전달하기 때문에 서로 비슷하지 않은 feature map이 결합이 이루어지지만 UNet++의 경우 의미론적으로 비슷한 feature map을 결합시킨다는 차이가 있죠.

 

실제로 기존의 UNet 과 wide UNet 모델에 비해 더 좋은 성능을 보여주었다고 설명하고 있습니다. 

 

 

 

 

 

 

 

re-designed skip pathways

먼저 UNet++이 기존의 UNet 모델과 어떤 차이를 가지는지 그림을 통해 알아봅시다.

 

 

아주 복잡하게 각 convolution들이 연결되어있는 것을 확인할 수 있습니다. 위 figure에서 검은색 부분은 기존 UNet 모델과 동일합니다. down sampling, up sampling 과정을 진행하며 각 level에 맞는 encoder와 decoder를 skip connection으로 연결해주죠. 

 

여기에서 초록색, 파란색, 빨간색 부분의 연결이 UNet++ 에서 추가된 부분입니다. 이중에서도 re-designed skip connection 에 해당되는 초록색과 파란색 부분을 살펴봅시다.

 

사실 이 re-designed skip connection에서 나오는 dense convolution block의 개념은 DenseNet 에서 나온 개념입니다. 각 layer의 모든 출력들을 다음 layer의 입력으로 받는다는 아이디어이죠. 

 

2022.12.29 - [논문 리뷰/Image classification] - [논문 리뷰] DenseNet(2017), 파이토치 구현

 

[논문 리뷰] DenseNet(2017), 파이토치 구현

이번에 리뷰할 논문은 "Densely Connected Convolutional Networks" 이다. CNN 모델의 성능을 높이기 위해 가장 직접적인 방법은 층의 깊이를 늘리는 것이다. 하지만 단순히 층이 깊어지기만 하면 vanishing gradie

smcho1201.tistory.com

 

UNet++의 skip connection 방식도 이와 유사합니다. dense convolution block을 추가하여 semantic information이 손실되는 것을 최대한 막으면서 전달하기 위함입니다.

 

 

위 공식은 각 위치에 존재하는 노드들이 어떤 계산을 통해 나타나는지를 보여줍니다. 

 

  • $ i $ : encoder 내의 down sampling layer를 indexing
  • $ j $ : skip pathways의 dense block의 convolution layer를 indexing
  • $ H $ : convolution, activation function
  • $ U $ : up-sampling layer
  • $ \left [   \right ] $ : concatenation layer

 

계산 방식은 j=0, j=1, j>1 일 때로 나누어지는데 j=0일 때는 encoder의 이전 layer에서 단 1개의 input만 받습니다. 

 

 

이 부분을 보면 $ X^{1,0} $ 노드는 encoder의 이전 layer $ X^{0,0} $ 의 input만을 받는 것을 볼 수 있습니다. j = 0 이기 때문에 1개의 input을 받는 것이죠.

 

 

 

j = 1 인 node의 경우 encoder의 연속된 2개의 layer에서 input을 받습니다. 

 

 

$ X^{0,1} $ 노드가 encoder의 연속된 2개의 layer $ X^{0,0} $ 와 $ X^{1,0} $ 에서 input을 받는 것을 볼 수 있죠.  j = 1인 경우 총 2개의 input을 받는 것이죠.

 

 

 

마지막으로 j > 1 인 node의 경우 j + 1 개의 input을 받는데 j개는 같은 skip pathways 에 있는 node들의 output값들이며, 나머지 1개는 아래쪽 skip pathways의 up-sampling 출력입니다.

 

 

$ X^{0,2} $ 노드가 같은 skip pathways에 있는 2개의 노드 $ X^{0,0} $, $ X^{0,1} $의 output을 input으로 받으며, 아래쪽 skip pathways 아래쪽 node인 $ X^{1,1} $ 의 up-sampling 으로 나온 output을 받습니다. j = 2인 경우 총 3개의 input을 받는 것이죠.

 

 

 

다음으로 skip connection이 어떤 연산을 통해 이루어지는지 밑의 그림을 통해 확인할 수 있습니다.

 

첫 번째 노드인 $ X^{0,1} $ 의 경우 연산과정이 $ H\left [ X^{0,0}, U(X^{1,0}) \right ] $ 으로 이루어지는데 $ X^{0,0} $ 와 $ U(X^{1,0})$ 를 concatenation을 한 후 convolution 과 activation function 연산을 진행합니다.

 

 

 

 

 

 

 

Deep supervision

Deep supervision은 총 두가지의 mode가 존재합니다. 

  • accurate mode : 각 segmentation branch의 결과를 평균내어 결과 도출
  • fast mode : 하나의 segmentation branch 만으로 최종 segmentation map 결정

 

자 그러면 segmentation branch 라는게 무엇일까요?

 

 

위 그림은 UNet++ architecture 로 만들 수 있는 4개의 segmentation branch를 보여주고 있습니다. 기존의 UNet과 달리 skip pathways 에 encoder와 decoder의 feature map 사이에 있는 semantic gap을 연결해주는 convolution layer가 존재하기 때문에 위와 같이 모델을 branch 단위로 쪼갤 수 있는 것이죠.

 

가장 상단의 convolution layer 인 $ \left\{ x^{0,j}, j \in \left\{ 1,2,3,4 \right\} \right\} $ 에 해당하는 노드들은 input image 와 같은 resolution의 feature map을 생성합니다. 

 

그럼 여기까지 segmentation branch가 무엇이지에 대해 감이 오실 겁니다. 그런데 위에서 말했던 것처럼 deep supervision에는 두가지 종류가 있었죠. accurate mode와 fast mode 입니다.

 

 

Accurate mode

위에서 보았던 것처럼 해당 논문에서 보여주는 branch는 $ UNet++L^{1} $, $ UNet++L^{2} $, $ UNet++L^{3} $, $ UNet++L^{4} $ 가 있습니다. 각각의 branch들은 같은 resolution의 feature map을 생성하겠죠. 이 각각의 branch들이 만들어낸 feature map들은 모두 평균을 낸다는 것이 바로 accurate mode의 접근입니다.

 

 

 

Fast mode

다음으로 두번째 mode인 fast mode에 대해 살펴보죠.

 

먼저 branch들 중 $ UNet++L^{4} $를 살펴봅시다

 

 

branch 중에서도 UNet++ 전체 architecture와 동일한 가장 크고 깊은 branch입니다. 반대로 $ UNet++L^{1} $을 살펴보죠.

 

 

상단의 dense convolution block들 중 가장 왼쪽에 있는 $ X^{0,1} $이 생성하는 feature map을 output으로 받고 있습니다. 

 

$ UNet++L^{4} $ branch는 $ UNet++L^{1} $ 에 비해 모델이 무거워 정확도는 기대해볼 수 있지만 속도 측면에 있어서는 분명 더 느릴 것입니다. 즉, fast mode는 가지치기(pruning)을 통해 segmentation branch를 선택하는 것을 의미합니다. 

 

모든 branch들의 결과를 다 평균내어 고려하는 accurate mode와 달리 하나의 branch만을 선택하기에 속도에 더 집중하였다는 것을 확인할 수 있습니다.

 

 

 

 

 

 

 

Loss function

본 논문에서는 loss function을 binary cross-entropy와 dice coefficient 를 융합하여 사용합니다.

 

 

 

 

 

 

 

 

 

 

Experiment

본 논문에서 실험을 위해 사용한 데이터셋은 총 4가지이며, U-Net 모델과 wide U-Net 모델을 비교 대상으로 삼았습니다. 

 

 

위 figure를 통해 UNet++ 모델이 다른 두 모델에 비해 성능적으로 개선되었음을 확인할 수 있습니다.

 

 

위 figure에서는 4개의 branch의 속도와 성능을 비교하고 있습니다. (a) ~ (d) 는 서로 다른 데이터셋을 적용한 것이며, 당연하게도 더 무거운 branch 일수록 속도는 느려지지만 성능은 올라가는 것을 확인할 수 있습니다.

 

 

 

 

 

 

 

Pytorch 구현

그럼 이제 UNet++ 모델을 코드로 구현해보록 하겠습니다. 사용할 framework는 pytorch 이며 각각의 part들을 차례로 살펴보도록 하죠.

 

class Unet_block(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out

 

UNet++ 모델의 기본틀은 UNet입니다. 위 코드는 base가 되는 UNet block을 생성하기 위한 모듈입니다.

 

input data가 들어오면 conv -> bn -> relu -> conv -> bn -> relu 순으로 적용되어 feature map을 생성합니다.

 

class Nested_UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False):
        super().__init__()

        num_filter = [32, 64, 128, 256, 512]
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # DownSampling
        self.conv0_0 = Unet_block(input_channels, num_filter[0], num_filter[0])
        self.conv1_0 = Unet_block(num_filter[0], num_filter[1], num_filter[1])
        self.conv2_0 = Unet_block(num_filter[1], num_filter[2], num_filter[2])
        self.conv3_0 = Unet_block(num_filter[2], num_filter[3], num_filter[3])
        self.conv4_0 = Unet_block(num_filter[3], num_filter[4], num_filter[4])

        # Upsampling & Dense skip
        # N to 1 skip
        self.conv0_1 = Unet_block(num_filter[0] + num_filter[1], num_filter[0], num_filter[0])
        self.conv1_1 = Unet_block(num_filter[1] + num_filter[2], num_filter[1], num_filter[1])
        self.conv2_1 = Unet_block(num_filter[2] + num_filter[3], num_filter[2], num_filter[2])
        self.conv3_1 = Unet_block(num_filter[3] + num_filter[4], num_filter[3], num_filter[3])

        # N to 2 skip
        self.conv0_2 = Unet_block(num_filter[0]*2 + num_filter[1], num_filter[0], num_filter[0])
        self.conv1_2 = Unet_block(num_filter[1]*2 + num_filter[2], num_filter[1], num_filter[1])
        self.conv2_2 = Unet_block(num_filter[2]*2 + num_filter[3], num_filter[2], num_filter[2])

        # N to 3 skip
        self.conv0_3 = Unet_block(num_filter[0]*3 + num_filter[1], num_filter[0], num_filter[0])
        self.conv1_3 = Unet_block(num_filter[1]*3 + num_filter[2], num_filter[1], num_filter[1])

        # N to 4 skip
        self.conv0_4 = Unet_block(num_filter[0]*4 + num_filter[1], num_filter[0], num_filter[0])

        if self.deep_supervision:
            self.output1 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
            self.output2 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
            self.output3 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
            self.output4 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)

        else:
            self.output = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)

        '''# initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')'''


    def forward(self, x):                    # (Batch, 3, 256, 256)

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], dim=1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], dim=1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], dim=1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], dim=1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], dim=1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], dim=1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], dim=1))

        if self.deep_supervision:
            output1 = self.output1(x0_1)
            output2 = self.output2(x0_2)
            output3 = self.output3(x0_3)
            output4 = self.output4(x0_4)
            output = (output1 + output2 + output3 + output4) / 4
        else:
            output = self.output(x0_4)

        return output
                                                                                                                                          124,1         Bot

 

먼저 각 convolution layer들을 위에서 정의했던 Unet_block 클래스를 사용하여 초기화합니다. 예를 들어 self.conv2_0 은 $ X^{2,0} $ 노드와 이전 노드를 연결해주는 convolution 연산이라고 생각하면 되겠습니다.

 

추가로 self.deep_supervision이 True 인 경우에는 조건문 1가지를 수행하는데 위에서 설명했던 것처럼 x0_1, x0_2, x0_3, x0_4 의 결과를 모두 더해 나누어 평균값을 output값으로 사용하는 것을 확인할 수 있습니다.

 

 

 

 

Comments