Prototype-Guided Pseudo Labeling for Semi-Supervised Text Classification
![/5_prototype-guided-pseudo-labeling-for-semi-supervis-8afb431333e341cd84dba2211420f974/featured.png /5_prototype-guided-pseudo-labeling-for-semi-supervis-8afb431333e341cd84dba2211420f974/featured.png](/5_prototype-guided-pseudo-labeling-for-semi-supervis-8afb431333e341cd84dba2211420f974/featured.png)
Prototype-Guided Pseudo Labeling for Semi-Supervised Text Classification
Introduction
這篇的題目是Prototype-Guided Pseudo Labeling for Semi-Supervised Text Classification。顧名思義,這篇想要解決的問題是半監督式的文本分類。方法是利用prototype來引導pseudo labeling。
Task
假設今天有個任務是要做文本分類,要分成經濟、政治和體育三個類別。
Problem - few label data
在訓練過程中,越多的label資料,訓練的模型效果會越好。而較少的label資料,則模型效果會越差。
我們會希望從unlabeled data中得到更多的標註資料,但是請專家標註資料很花費時間和金錢
Semi-supervised learning
為了解決labeld data資料過少的問題,我們希望能使用unlabeled data來幫助模型學習
Problem - Semi-supervised learning
雖然我們手上有許多未標注的資料,如果全部給模型學習,模型很難從那些未標注的資料上很快地學到資料特徵,會讓模型的學習效果不好。因此訓練出來的模型會underfitting,不容易區分出新的資料應該屬於哪一個類別,而預測錯誤
這會讓模型在分配pseudo labels時,錯誤的將某類別的資料判斷為其他類別,如圖中的藍色類別,導致藍色類別的資料數量越來越多,使模型產生偏差。
Solution
為了解決前面提到的underfitting,我們的目標是使用更多的labeled data來訓練模型,希望能從unlabeled data中,透過某些策略來取得一些有代表性的資料
這篇論文提出的策略是
- 使用labeled data找出prototype (一個類別中最具有代表性的資料點)
- unlabeled data越靠近prototype,表示他們很類似
- 透過策略產生pseudo label給unlabeled data
- unlabeled data經過augment後,產生augment data
- 把pseudo label和augment data合併為新的訓練資料
- 拿這份新的訓練資料和labeled data一起去訓練模型
Method
Unsupervised Data Augmentation
首先會先介紹Unsupervised Data Augmentation,因為這篇(prototype)有些想法是參考這個方法(UDA)
一開始先來說明普通的Data Augmentation。
因為我們只有少量的labeled data,我們希望能有更多的labeled data去訓練模型
因此我們會將labeled data拿去用各種方式改變他的句子、文字…,讓他跟原本的文本不一樣,但是lable還是相同的。
這裡會有個問題,如果一個句子經過Augment後,他的句子特徵已經不見了,但是label卻還是一樣,這會讓模型的性能降低。
因為labeled data經過augment後對模型的影響很小,因此這篇Unsupervised Data Augmentation透過對unlabeled data做augment,並讓模型來預測。
如果模型對unlabel和augmentation的預測結果是一致的話,表示模型有學習到這個樣本的特徵
最後,在訓練模型時會把supervised的loss和unsupervised 的loss綜合起來一起去做調整
Prototype-Guided Pseudo Labeling - supervised loss(Partition 1)
我們會將這個架構圖分為三個區塊
- 第一個區塊是使用labeled data去訓練模型,並利用一個特殊的訓練方法控制training data的使用數量
- 第二個區塊是先用labeled data找出一個具有代表性的點,這裡稱它為prototype,如果unlabeled data足夠靠近prototype,則給他一個pseudo label,再與unlabeled data合併為augmentation data,以此拿去訓練模型
- 第三個區塊是使用prototype和unlabeled data做對比學習,希望能讓unlabeled data越靠近prototype,且讓prototype之間彼此遠離
首先說明第一個區塊,使用labeled data訓練模型的方式。
我們使用labeled data經過encoder編碼後,來訓練模型。這裡的訓練方式並不是使用全部的labeled data,而是使用Training Signal Annealing這個方式來調節每個iteration的labeled data數量
Training Signal Annealing
因為我們只有少量的labeled data,和很多的unlabeled data,如果全部都給模型訓練的話,很容易會讓labeled data overfitting,且讓unlabeled data under-fitting。(因為labeled data一下就overfitting,讓loss變小,導致模型已經從unlabeled data上學不到特徵了)
這個訓練信號退火的方法是,設定一個門檻值。假設總共要訓練十次,訓練一開始
- t = 1 ,假設labeled data只能有5筆進入模型訓練
- t = 5,門檻值變高,限制變寬鬆,可以讓20筆label data進入模型訓練
- t = 10,門檻值全開,可以讓所有資料進入訓練
這方法的策略是在訓練階段逐步釋放labeled data,目的是讓模型不要太快overfitting
這是他的計算公式和過程。假設我們要總共要訓練10次,則T=10
- $\tau$ 是一個初始參數,取決於類別總數(這裡假設有十個類別)的倒數,這裡假設$\tau$為0.5
- t 是當前的iteration
透過計算,會發現當t越大,則threshold會越寬鬆,也符合前面提到的會逐步釋放labeled data
Annealing Supervised loss
在label data上的loss計算公式如圖,log右邊是softmax,而中間的I是一個二元關係,可以想成是True/False
這裡的條件為,當模型預測labeled data的置信度<門檻值時,I= 0,不去計算loss
當置信度>門檻值時,I=1,拿這個樣本去算loss
這是為了鼓勵模型去學習低置信度的樣本
Prototype-Guided Pseudo Labeling - unsupervised loss(Partition 2)
這裡要說明unsupervised loss的訓練和計算過程。
將上面這張圖解構後,畫成如下的流程圖
- 先將unlabeled data透過encoder轉成embedding
- 將labeled data透過encoder轉成embedding,再透過Prototype extraction找到每個類別的prototype
- 將prototype與unlabeled data算距離
- 如果unlabeled data與prototype的距離足夠近,會考慮將該unlabeled data列入loss計算
- 把該unlabeled data做augment後,用augment data與pseudo label合併為新的訓練資料集,再拿去訓練模型
右上角的公式為判斷unlabeled data的方式,採用置信度最高的類別,為其pseudo label
Prototype Extraction
我們的目標是找到每個類別中的最有代表性的資料。
假設這是一堆有label的data,經過數量計算後,發現藍色的點有6個,綠色的點有5個
接下來就是把藍色這一群的點,透過model轉為embedding後相加,再取平均,就可以得到藍色類別的prototype
我們希望unlabeled data能越靠近prototype越好,則他屬於該類別的機率也就越大
Prototype-guided Pseudo labeling Module
模型如果沒有學到資料特徵,會造成underfitting,這會讓後續模型在預測unlabeled data時,給予錯誤的標籤。這會導致多數類別的數量會變得更多,造成偏差
解決方法是,找到最靠近prototype的前K個unlabeled data,加上一個策略,這策略是不要選擇在前t次過度訓練的類別,這就是這個PGPL module提出的解決方法
右上角是假設現在有三個類別,在前t次分別得到的unlabeled data,在這一次又被分配到一些Unlabeled data。但是經過策略後,A已經拿太多資料,因此這次不能拿。B因為介於兩者之間,且經過策略後,可以拿兩筆。C因為是最少,因此可以全拿
下圖是假設現在有10筆augment data,經過計算後,得到每一類別在這個iteration中有多少筆資料可以拿。
這個策略的目標是計算每個類別應該在這一個iteration中該拿多少pseudo label
透過這個公式,可以得到每個類別在t iteration可以選擇多少pseudo label
以下是計算過程。
(Augment Data 和 unlabeled data是相同的embedding)
接著來說明一下符號
$\mu^c_{<t}$ 是前t次中,該類別所拿到的pseudo label數量
$\gamma_t$ 是前t次中,所有類別中數量最小的值
$k_c$ 是這一次,各類別可以拿到的pseudo label的值
上圖中決定$k_c$ 的方式是經過右邊條件計算後,各類別可以拿到的pseudo label的數量
第二部分是計算這些augment data到prototype的距離
假設前面算出來的$k_c$ ,如上圖中,右邊所示。(藍色星星為prototype,各類別的prototype不一樣,上圖為示意圖)
在經濟類別,$k_c$ 為0,表示他在這次沒有任何資料被選到,因此距離為0
政治的$k_c$ 是1,因此選擇一個離prototype最近的一個augment data
同理,體育也是,找出五個最近的augmentation。
找出最近的augment data後,以他到prototype的距離為$d_c$
Selective Unsupervised Loss
藍色框框是計算augment data 的softmax
$I$ 是判斷
- 如果pseudo label真的屬於該類別
- 且unlabeled data到prototype的距離,小於augment data到prototype的距離
如果兩條件都成立則為1,反則為0
Prototype-Guided Pseudo Labeling(Partition3)
這部分會透過prototype和augment data來做對比學習,希望能讓augment data越靠近該類別,且遠離其他類別
Prototype-Anchored Contrasting
這個PAC module是希望能讓每個instance 更靠近prototype,並且讓labeled data越接近該類別的prototype,讓labeled data距離不同類別的prototype越遠越好
上面的紅色框框是labeled data的loss,下方是augment data的loss
橘色框框的部分為權重,模型預測augment data後會有置信度,這裡利用置信度當作權重來調整loss
Prototype-Guided Pseudo Labeling
最後將所有loss合併,並在Lp和Lu上加上權重調整
Experiment
Dataset
這裡用四個資料集來做實驗
Baseline
- BERT
- 用labeled data訓練模型,屬於監督式學習
- UDA(unsupervised data augmentation)
- 把unlabeled data拿去做資料增強後,與原本的unlabeled data計算loss,最後跟labeled data的supervised loss相加後,做整體的調整
- MixText
- TMix : 一種新的資料增強方式,對兩輸入向量在hidden 空間做插值,生成全新的向量。也是屬於一種半監督式學習
Experiment Results
首先上方為資料集名稱,右邊是baseline model和這篇論文提出來的PGPL模型
資料集下方的數字為實驗中使用的labeled data的數量
如果以AG News來看
- 使用10個label data
- semi-supervised 是有明顯比supervised好很多
- 使用200個label data
- semi-supervised 和 supervised之間的差距縮小
- PGPL在不同的lable data數量下,都有不錯的表現
Compare different partition protocols
這裡分成supervised和semi-supervised
- supervised
- all label
- 10 labels
- semi-supervised
- 10 labels
比較
- supervised
- 使用全部的labels > 10 labels
- 只使用10 labels
- semi-supervised > supervised
- all labels 的 supervised > 10 labels 的 semi-supervised
Ablation Study
- PGP和PAC可以獨立提高模型性能
- TSA helps too, besides DBpedia
Conclusion
- 這篇論文針對半監督的文本分類任務,提出了一種結合 PAC 和 PGP 策略的半監督模型 PGPL。
- 建構原型後,使用PAC把屬於同一類的text embedding聚在一起,緩解underfitting的問題
- PGP 選擇prototype附近可靠的pseudo label來解決不平衡資料帶來的偏差問題