目錄

Prototype-Guided Pseudo Labeling for Semi-Supervised Text Classification

Prototype-Guided Pseudo Labeling for Semi-Supervised Text Classification

Introduction


這篇的題目是Prototype-Guided Pseudo Labeling for Semi-Supervised Text Classification。顧名思義,這篇想要解決的問題是半監督式的文本分類。方法是利用prototype來引導pseudo labeling。

Task

假設今天有個任務是要做文本分類,要分成經濟、政治和體育三個類別。

./Untitled.png

Problem - few label data

在訓練過程中,越多的label資料,訓練的模型效果會越好。而較少的label資料,則模型效果會越差。

我們會希望從unlabeled data中得到更多的標註資料,但是請專家標註資料很花費時間和金錢

Semi-supervised learning

為了解決labeld data資料過少的問題,我們希望能使用unlabeled data來幫助模型學習

Problem - Semi-supervised learning

雖然我們手上有許多未標注的資料,如果全部給模型學習,模型很難從那些未標注的資料上很快地學到資料特徵,會讓模型的學習效果不好。因此訓練出來的模型會underfitting,不容易區分出新的資料應該屬於哪一個類別,而預測錯誤

./Untitled%201.png

這會讓模型在分配pseudo labels時,錯誤的將某類別的資料判斷為其他類別,如圖中的藍色類別,導致藍色類別的資料數量越來越多,使模型產生偏差。

Solution

為了解決前面提到的underfitting,我們的目標是使用更多的labeled data來訓練模型,希望能從unlabeled data中,透過某些策略來取得一些有代表性的資料

這篇論文提出的策略是

  1. 使用labeled data找出prototype (一個類別中最具有代表性的資料點)
  2. unlabeled data越靠近prototype,表示他們很類似
  3. 透過策略產生pseudo label給unlabeled data
  4. unlabeled data經過augment後,產生augment data
  5. 把pseudo label和augment data合併為新的訓練資料
  6. 拿這份新的訓練資料和labeled data一起去訓練模型

./Untitled%202.png

Method


Unsupervised Data Augmentation

首先會先介紹Unsupervised Data Augmentation,因為這篇(prototype)有些想法是參考這個方法(UDA)

一開始先來說明普通的Data Augmentation。

因為我們只有少量的labeled data,我們希望能有更多的labeled data去訓練模型

因此我們會將labeled data拿去用各種方式改變他的句子、文字…,讓他跟原本的文本不一樣,但是lable還是相同的。

這裡會有個問題,如果一個句子經過Augment後,他的句子特徵已經不見了,但是label卻還是一樣,這會讓模型的性能降低。

./Untitled%203.png

因為labeled data經過augment後對模型的影響很小,因此這篇Unsupervised Data Augmentation透過對unlabeled data做augment,並讓模型來預測。

如果模型對unlabel和augmentation的預測結果是一致的話,表示模型有學習到這個樣本的特徵

./Untitled%204.png

最後,在訓練模型時會把supervised的loss和unsupervised 的loss綜合起來一起去做調整

./Untitled%205.png

Prototype-Guided Pseudo Labeling - supervised loss(Partition 1)

我們會將這個架構圖分為三個區塊

./Untitled%206.png

  • 第一個區塊是使用labeled data去訓練模型,並利用一個特殊的訓練方法控制training data的使用數量
  • 第二個區塊是先用labeled data找出一個具有代表性的點,這裡稱它為prototype,如果unlabeled data足夠靠近prototype,則給他一個pseudo label,再與unlabeled data合併為augmentation data,以此拿去訓練模型
  • 第三個區塊是使用prototype和unlabeled data做對比學習,希望能讓unlabeled data越靠近prototype,且讓prototype之間彼此遠離

首先說明第一個區塊,使用labeled data訓練模型的方式。

我們使用labeled data經過encoder編碼後,來訓練模型。這裡的訓練方式並不是使用全部的labeled data,而是使用Training Signal Annealing這個方式來調節每個iteration的labeled data數量

./Untitled%207.png

Training Signal Annealing

因為我們只有少量的labeled data,和很多的unlabeled data,如果全部都給模型訓練的話,很容易會讓labeled data overfitting,且讓unlabeled data under-fitting。(因為labeled data一下就overfitting,讓loss變小,導致模型已經從unlabeled data上學不到特徵了)

這個訓練信號退火的方法是,設定一個門檻值。假設總共要訓練十次,訓練一開始

  • t = 1 ,假設labeled data只能有5筆進入模型訓練
  • t = 5,門檻值變高,限制變寬鬆,可以讓20筆label data進入模型訓練
  • t = 10,門檻值全開,可以讓所有資料進入訓練

這方法的策略是在訓練階段逐步釋放labeled data,目的是讓模型不要太快overfitting

這是他的計算公式和過程。假設我們要總共要訓練10次,則T=10

./Untitled%208.png

  • $\tau$ 是一個初始參數,取決於類別總數(這裡假設有十個類別)的倒數,這裡假設$\tau$為0.5
  • t 是當前的iteration

透過計算,會發現當t越大,則threshold會越寬鬆,也符合前面提到的會逐步釋放labeled data

Annealing Supervised loss

在label data上的loss計算公式如圖,log右邊是softmax,而中間的I是一個二元關係,可以想成是True/False

這裡的條件為,當模型預測labeled data的置信度<門檻值時,I= 0,不去計算loss

當置信度>門檻值時,I=1,拿這個樣本去算loss

這是為了鼓勵模型去學習低置信度的樣本

./Untitled%209.png

Prototype-Guided Pseudo Labeling - unsupervised loss(Partition 2)

這裡要說明unsupervised loss的訓練和計算過程。

./Untitled%2010.png

將上面這張圖解構後,畫成如下的流程圖

./Untitled%2011.png

  1. 先將unlabeled data透過encoder轉成embedding
  2. 將labeled data透過encoder轉成embedding,再透過Prototype extraction找到每個類別的prototype
  3. 將prototype與unlabeled data算距離
  4. 如果unlabeled data與prototype的距離足夠近,會考慮將該unlabeled data列入loss計算
  5. 把該unlabeled data做augment後,用augment data與pseudo label合併為新的訓練資料集,再拿去訓練模型

右上角的公式為判斷unlabeled data的方式,採用置信度最高的類別,為其pseudo label

Prototype Extraction

我們的目標是找到每個類別中的最有代表性的資料。

./Untitled%2012.png

假設這是一堆有label的data,經過數量計算後,發現藍色的點有6個,綠色的點有5個

接下來就是把藍色這一群的點,透過model轉為embedding後相加,再取平均,就可以得到藍色類別的prototype

我們希望unlabeled data能越靠近prototype越好,則他屬於該類別的機率也就越大

Prototype-guided Pseudo labeling Module

模型如果沒有學到資料特徵,會造成underfitting,這會讓後續模型在預測unlabeled data時,給予錯誤的標籤。這會導致多數類別的數量會變得更多,造成偏差

解決方法是,找到最靠近prototype的前K個unlabeled data,加上一個策略,這策略是不要選擇在前t次過度訓練的類別,這就是這個PGPL module提出的解決方法

./Untitled%2013.png

右上角是假設現在有三個類別,在前t次分別得到的unlabeled data,在這一次又被分配到一些Unlabeled data。但是經過策略後,A已經拿太多資料,因此這次不能拿。B因為介於兩者之間,且經過策略後,可以拿兩筆。C因為是最少,因此可以全拿

下圖是假設現在有10筆augment data,經過計算後,得到每一類別在這個iteration中有多少筆資料可以拿。

./Untitled%2014.png

這個策略的目標是計算每個類別應該在這一個iteration中該拿多少pseudo label

透過這個公式,可以得到每個類別在t iteration可以選擇多少pseudo label

以下是計算過程。

(Augment Data 和 unlabeled data是相同的embedding)

接著來說明一下符號

./Untitled%2015.png

$\mu^c_{<t}$ 是前t次中,該類別所拿到的pseudo label數量

$\gamma_t$ 是前t次中,所有類別中數量最小的值

$k_c$ 是這一次,各類別可以拿到的pseudo label的值

上圖中決定$k_c$ 的方式是經過右邊條件計算後,各類別可以拿到的pseudo label的數量

第二部分是計算這些augment data到prototype的距離

./Untitled%2016.png

假設前面算出來的$k_c$ ,如上圖中,右邊所示。(藍色星星為prototype,各類別的prototype不一樣,上圖為示意圖)

在經濟類別,$k_c$ 為0,表示他在這次沒有任何資料被選到,因此距離為0

政治的$k_c$ 是1,因此選擇一個離prototype最近的一個augment data

同理,體育也是,找出五個最近的augmentation。

找出最近的augment data後,以他到prototype的距離為$d_c$

Selective Unsupervised Loss

藍色框框是計算augment data 的softmax

./Untitled%2017.png

$I$ 是判斷

  • 如果pseudo label真的屬於該類別
  • 且unlabeled data到prototype的距離,小於augment data到prototype的距離

如果兩條件都成立則為1,反則為0

Prototype-Guided Pseudo Labeling(Partition3)

./Untitled%2018.png

這部分會透過prototype和augment data來做對比學習,希望能讓augment data越靠近該類別,且遠離其他類別

Prototype-Anchored Contrasting

這個PAC module是希望能讓每個instance 更靠近prototype,並且讓labeled data越接近該類別的prototype,讓labeled data距離不同類別的prototype越遠越好

上面的紅色框框是labeled data的loss,下方是augment data的loss

橘色框框的部分為權重,模型預測augment data後會有置信度,這裡利用置信度當作權重來調整loss

Prototype-Guided Pseudo Labeling

最後將所有loss合併,並在Lp和Lu上加上權重調整

./Untitled%2019.png

Experiment


Dataset

這裡用四個資料集來做實驗

./Untitled%2020.png

Baseline

  • BERT
    • 用labeled data訓練模型,屬於監督式學習
  • UDA(unsupervised data augmentation)
    • 把unlabeled data拿去做資料增強後,與原本的unlabeled data計算loss,最後跟labeled data的supervised loss相加後,做整體的調整
  • MixText
    • TMix : 一種新的資料增強方式,對兩輸入向量在hidden 空間做插值,生成全新的向量。也是屬於一種半監督式學習

Experiment Results

./Untitled%2021.png

首先上方為資料集名稱,右邊是baseline model和這篇論文提出來的PGPL模型

資料集下方的數字為實驗中使用的labeled data的數量

如果以AG News來看

  • 使用10個label data
    • semi-supervised 是有明顯比supervised好很多
  • 使用200個label data
    • semi-supervised 和 supervised之間的差距縮小
  • PGPL在不同的lable data數量下,都有不錯的表現

Compare different partition protocols

這裡分成supervised和semi-supervised

  • supervised
    • all label
    • 10 labels
  • semi-supervised
    • 10 labels

比較

  1. supervised
    1. 使用全部的labels > 10 labels
  2. 只使用10 labels
    1. semi-supervised > supervised
  3. all labels 的 supervised > 10 labels 的 semi-supervised

Ablation Study

  1. PGP和PAC可以獨立提高模型性能
  2. TSA helps too, besides DBpedia

Conclusion


  • 這篇論文針對半監督的文本分類任務,提出了一種結合 PAC 和 PGP 策略的半監督模型 PGPL。
  • 建構原型後,使用PAC把屬於同一類的text embedding聚在一起,緩解underfitting的問題
  • PGP 選擇prototype附近可靠的pseudo label來解決不平衡資料帶來的偏差問題