[리뷰] TTN: A DOMAIN-SHIFT AWARE BATCH NORMALIZATION IN TEST-TIME ADAPTATION - ICLR 2023
새롭게 리뷰할 논문은 TTN: A DOMAIN-SHIFT AWARE BATCH NORMALIZATION IN TEST-TIME ADAPTATION입니다. ICLR 2023년 논문으로 퀄컴과 카이스트에서 제출한 논문입니다. Test time adaptation이라는 train과 test에서 발생하는 domain shift를 해결하기 위한 논문으로 간단 리뷰 시작하겠습니다!
리뷰한 내용의 부족한 점, 잘못된 점에 대한 피드백은 언제든 환영합니다!
Abstract
최근 Test time adaptation 분야에서는 train에 사용했던 running mean과 variance를 사용하는 Conventional Batch Normalization(CBN)이 아닌 test batch에서의 running mean과 variance를 수정하는 Transductive Batch Normalization (TBN)에 의존합니다. TBN을 사용하는 것은 domain shift의 영향을 완화시킴으로써 성능 저하를 막을 수 있지만, 이는 test batch에서의 통계를 구하기에 큰 batch size를 필요로 하므로 현실적인 가정이 아닙니다. 따라서 본 논문에서는 CBN과 TBN의 trade off를 고려한 새로운 Test Time Normalization 방법을 제안합니다. TTN에서는 각 Batch Normalization(BN) layer의 domain-shift sensitivity를 측정하고 이를 통해 CBN과 TBN의 영향력을 조절합니다.
Introduction
Test-time Adaptation (TTA)
Test Time Adaptation (TTA)는 test time 동안에 source domain과 target domain에서의 domain shift를 해결하기 위해 제안된 방법입니다. 최근 TTA 방법에서는 1) 현재 test input으로부터 normalization statistics를 구하는 방법 2) entropy minimization과 같이 unsupervised로 model parameter를 최적화하는 방법 3) self-supervised loss를 활용하는 방법들이 연구되고 있습니다. 최근 제안된 방법들은 Test time에서 test bacth statistics에 많이 의존합니다(TBN). TBN을 위해서는 test batch size가 매우 커야 하며 single stationary distribution shift라는 가정이 필요합니다.
[ 제 생각에는? ]
1) Large test batch size
-> Test input에 대한 충분한 통계(mean and variance)가 있어야 해당 domain에 대해 적절한 adaptation이 가능하지 않을지?
2) Single stationary distribution shift
-> Test time에서 single stationary하지 않고 distribution shift가 바뀌는 상황에서, 바뀌기 전 test statistics와 shift가 발생하고 나서의 test statistics는 다르기 때문에 또다시 test batch statistics를 구해야 하는 문제점!
(d)에서의 왼쪽과 같이 TBN은 CBN과 TTN에 비해 batch size가 작을수록 error rate가 높아지는 것을 확인할 수 있습니다. 이는 TBN은 필연적으로 큰 test batch size가 필요하며 이것은 실용적이지 못합니다.
따라서 본 논문에서는 CBN과 TBN 둘 다 고려한 TTN 방법을 제안합니다. TTN 방법은 BN layer에서의 domain-shift sensitivity에 따라서 source domain과 target domain의 중요성을 계산합니다.
Source domain : 학습 시 사용했던 domain
Target domain : Test time에서의 domain
이를 위해 CBN weight와 TBN의 weight를 선형적으로 결합합니다. 만약 test data로의 adaptation이 필요해진다면, CBN의 weight보다 TBN의 weight를 더 크게 조절해 줍니다. TTN은 크게 2가지의 단계로 구성되어 있습니다. 첫 번째 단계는 post-training phase로써, 주어진 pretrained model의 BN layer가 가지는 affine parameter의 channel-wise sensitivity를 측정합니다. 두 번째 단계에서는 channel-wise sensitivity에 따라서 BN layer를 TTN layer로 대체하는 interpolating weight를 최적화합니다.
Method
TEST-TIME NORMALIZATION LAYER
BN layer의 input은 $z$로 표현하며 $B$는 batch size, $C$는 채널의 수, $H$는 차원의 높이, $W$는 차원의 너비를 의미합니다. $z$의 mean과 variance는 각각 $\mu$와 $\sigma^{2}$로 표현되며 식은 위와 같습니다. 일반적으로 training data의 $\mu_{s}$와 $\sigma^{2}_{s}$는 exponential moving average로 측정됩니다. BN layer에서 $\mu$와 $\sigma^{2}$은 learnable parameter $\gamma$, $\beta$에 의해서 scaling과 shifting 됩니다. 본 논문에서는 test time에서의 domain shift를 해결하기 위해, source와 test time에서 statistics를 결합하여 source statistics를 조절합니다. 이를 위해 $\alpha$라는 interpolating weight를 새롭게 제안하며 이는 아래와 같이 source와 test time mini batch에서의 mean과 variance를 적절히 조절합니다.
POST TRAINING
TTN은 testing을 수행하기 전, pretrained model에 대해서 위의 interpolating weight $\alpha$를 최적화하는 post training 과정을 수행합니다. 이 과정을 하는 동안에 $\alpha$를 제외한 모든 파라미터는 freeze됩니다.
Obtain Prior A
Domain shift에 대해서 BN layer의 channel sensitivity를 확인하기 위해, augmenting clean image를 통해 시험합니다. 원본 이미지 $x$와 augmented 이미지 $x^{\prime}$에 대한 모델의 feature는 각각 $z^{(l, c)}$, $\hat {z^{(l, c)}}$로 표현합니다. $(l, c)$는 $l$번째 layer의 $c$ index가 됩니다. 이 이미지들이 입력되었을 때의 차이를 통해 domain shift에 대한 Prior $\mathcal{A}$을 구하고 이 차이가 커질 수록 $(l,c)$는 domain shift에 sensitivity 합니다. 따라서 이럴 경우, shifted input에 대해 adaptation을 수행해야 합니다. Sensitivity를 측정하기 위해서는 affine parameter $\gamma$와 $\beta$의 gradients를 비교합니다.
(a-1)과 같이 $x$와 $x^{\prime}$에 대해 cross entropy를 통해 affine paramter의 gradients를 계산합니다. 이를 통해 gradient score의 계산은 아래와 같습니다.
식을 들여다보면 $g$와 $g^{\prime}$의 cosine similarity를 통해 gradient score $s$를 계산합니다. 이를 통해 $(l, c)$에 대한 최종적인 gradient distance score는
과 같습니다. 본 논문에서는 상대적인 차이를 극대화하기 위해 distance score에 제곱의 집합을 Prior $\mathcal {A}$로 정의합니다.
Optimization
본 논문에서의 최종적인 Loss는 original image $x$와 augmented image $x^{\prime}$의 일관된 예측을 위한 cross entropy loss $L_{CE}$와 $\alpha$가 initial prior $\mathcal {A}$로 부터 너무 멀어지는 것을 방지하기 위한 loss인 $L_{MSE} = \| \alpha - \mathcal{A} \| ^{2}$의 합으로 최종 loss term을 구성합니다.
Experiments
본 논문에서는 크게 Single domain adpatation, Continualy domain adaptation, Mixed domain adpataion에 대해서 CIFAR-10C, CIFAR100 C 데이터셋으로 실험을 진행합니다.
주목해야 할 점은 batch size가 줄어듬에 따라 기존의 방법들은 error rate가 매우 높아지는 반면, TTN은 비교적 낮은 error rate를 유지하고 있음을 알 수 있습니다. 또한 아래의 Ablation table은 제안한 파라미터 $\alpha$가 test에 대한 statistics를 반영하면서도 초기 Prior $\mathcal {A}$로부터 많이 멀어지지 않도록 하는 것이 중요함을 알 수 있습니다($L_{mse}$).
간단하면서 직관적인 아이디어로 TBN의 문제점을 지적하고 이를 해결한 논문입니다. Batch size가 작아짐에 따라서 기존의 방법론들의 성능이 많이 무너짐을 보여주고 TTN은 이를 잘 보완합니다. 비슷한 아이디어가 미팅 때 나와서 검색하다가 알게 된 논문인데 정말 잘 작성된 논문인 것 같아서 많이 배운 것 같습니다!