์ธ๊ณต์ง๋ฅ
์ง์ ์ ํธ๋ ์ต์ ํ: ์ ์ฒด ๊ฐ์ด๋

์ธ๊ฐ์ ๊ฐ์น์ ์ ํธ๋์ ๋ง์ถฐ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ(LLM)์ ์กฐ์ ํ๋ ๊ฒ์ ์ด๋ ต์ต๋๋ค. ๋ฑ์ ์ ํต์ ์ธ ๋ฐฉ๋ฒ ์ฌ๋์ ํผ๋๋ฐฑ์ ํตํ ๊ฐํ ํ์ต (RLHF)๋ ์ฌ๋์ ์ ๋ ฅ์ ํตํฉํ์ฌ ๋ชจ๋ธ ์ถ๋ ฅ์ ๊ฐ์ ํจ์ผ๋ก์จ ๊ธธ์ ์ด์์ต๋๋ค. ๊ทธ๋ฌ๋ RLHF๋ ๋ณต์กํ๊ณ ๋ฆฌ์์ค ์ง์ฝ์ ์ผ ์ ์์ผ๋ฉฐ ์๋นํ ๊ณ์ฐ ๋ฅ๋ ฅ๊ณผ ๋ฐ์ดํฐ ์ฒ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ์ง์ ์ ํธ๋ ์ต์ ํ (DPO)๋ ์๋กญ๊ณ ๋์ฑ ๊ฐ์ํ๋ ์ ๊ทผ ๋ฐฉ์์ผ๋ก ๋ฑ์ฅํ์ฌ ์ด๋ฌํ ๊ธฐ์กด ๋ฐฉ๋ฒ์ ๋ํ ํจ์จ์ ์ธ ๋์์ ์ ๊ณตํฉ๋๋ค. DPO๋ ์ต์ ํ ํ๋ก์ธ์ค๋ฅผ ๋จ์ํํจ์ผ๋ก์จ ๊ณ์ฐ ๋ถ๋ด์ ์ค์ผ ๋ฟ๋ง ์๋๋ผ ์ธ๊ฐ์ ์ ํธ๋์ ๋น ๋ฅด๊ฒ ์ ์ํ๋ ๋ชจ๋ธ์ ๋ฅ๋ ฅ์ ํฅ์์ํต๋๋ค.
์ด ๊ฐ์ด๋์์๋ DPO์ ๋ํด ์์ธํ ์์๋ณด๊ณ DPO์ ๊ธฐ์ด, ๊ตฌํ ๋ฐ ์ค์ ์์ฉ ํ๋ก๊ทธ๋จ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
์ ํธ ์ ๋ ฌ์ ํ์์ฑ
DPO๋ฅผ ์ดํดํ๋ ค๋ฉด LLM์ ์ธ๊ฐ์ ์ ํธ๋์ ๋ง์ถ๋ ๊ฒ์ด ์ ๊ทธ๋ ๊ฒ ์ค์ํ์ง ์ดํดํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค. ์ธ์์ ์ธ ์ญ๋์๋ ๋ถ๊ตฌํ๊ณ ๋ฐฉ๋ํ ๋ฐ์ดํฐ ์ธํธ๋ก ํ๋ จ๋ LLM์ ๋๋๋ก ์ผ๊ด์ฑ์ด ์๊ฑฐ๋ ํธํฅ์ ์ด๊ฑฐ๋ ์ธ๊ฐ์ ๊ฐ์น์ ๋ง์ง ์๋ ์ถ๋ ฅ์ ์์ฑํ ์ ์์ต๋๋ค. ์ด๋ฌํ ๋ถ์ผ์น๋ ๋ค์ํ ๋ฐฉ์์ผ๋ก ๋ํ๋ ์ ์์ต๋๋ค.
- ์์ ํ์ง ์๊ฑฐ๋ ์ ํดํ ์ฝํ ์ธ ์์ฑ
- ๋ถ์ ํํ๊ฑฐ๋ ์คํด์ ์์ง๊ฐ ์๋ ์ ๋ณด ์ ๊ณต
- ํ๋ จ ๋ฐ์ดํฐ์ ์กด์ฌํ๋ ํธํฅ ํ์
์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ฐ๊ตฌ์๋ค์ ์ธ๊ฐ์ ํผ๋๋ฐฑ์ ์ฌ์ฉํ์ฌ LLM์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๊ธฐ์ ์ ๊ฐ๋ฐํ์ต๋๋ค. ์ด๋ฌํ ์ ๊ทผ๋ฒ ์ค ๊ฐ์ฅ ๋๋๋ฌ์ง ๊ฒ์ RLHF์์ต๋๋ค.
RLHF ์ดํด: DPO์ ์ ๊ตฌ์
์ธ๊ฐ ํผ๋๋ฐฑ์ ํตํ ๊ฐํ ํ์ต(RLHF)์ LLM์ ์ธ๊ฐ ์ ํธ๋์ ๋ง์ถ๋ ๋ฐ ์ฌ์ฉ๋๋ ๋ฐฉ๋ฒ์ด์์ต๋๋ค. ๋ณต์ก์ฑ์ ์ดํดํ๊ธฐ ์ํด RLHF ํ๋ก์ธ์ค๋ฅผ ๋ถ์ํด ๋ณด๊ฒ ์ต๋๋ค.
a) ๊ฐ๋ ํ ๋ฏธ์ธ ์กฐ์ (SFT): ํ๋ก์ธ์ค๋ ๊ณ ํ์ง ์๋ต ๋ฐ์ดํฐ ์ธํธ์์ ์ฌ์ ํ๋ จ๋ LLM์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๊ฒ์ผ๋ก ์์๋ฉ๋๋ค. ์ด ๋จ๊ณ๋ ๋ชจ๋ธ์ด ๋์ ์์ ์ ๋ํด ๋ณด๋ค ๊ด๋ จ์ฑ์ด ๋๊ณ ์ผ๊ด๋ ์ถ๋ ฅ์ ์์ฑํ๋ ๋ฐ ๋์์ด ๋ฉ๋๋ค.
b) ๋ณด์ ๋ชจ๋ธ๋ง: ์ธ๊ฐ์ ์ ํธ๋๋ฅผ ์์ธกํ๊ธฐ ์ํด ๋ณ๋์ ๋ณด์ ๋ชจ๋ธ์ ํ์ตํฉ๋๋ค. ์ฌ๊ธฐ์๋ ๋ค์์ด ํฌํจ๋ฉ๋๋ค.
- ์ฃผ์ด์ง ํ๋กฌํํธ์ ๋ํ ์๋ต ์ ์์ฑ
- ์ธ๊ฐ์ด ์ ํธํ๋ ๋ฐ์์ ํ๊ฐํ๊ฒ ํจ
- ์ด๋ฌํ ์ ํธ๋๋ฅผ ์์ธกํ๊ธฐ ์ํ ๋ชจ๋ธ ํ์ต
c) ๊ฐํ ํ์ต: ๋ฏธ์ธ ์กฐ์ ๋ LLM์ ๊ฐํ ํ์ต์ ์ฌ์ฉํ์ฌ ๋์ฑ ์ต์ ํ๋ฉ๋๋ค. ๋ณด์ ๋ชจ๋ธ์ ํผ๋๋ฐฑ์ ์ ๊ณตํ์ฌ LLM์ด ์ธ๊ฐ ์ ํธ๋์ ๋ง๋ ์๋ต์ ์์ฑํ๋๋ก ์๋ดํฉ๋๋ค.
๋ค์์ RLHF ํ๋ก์ธ์ค๋ฅผ ์ค๋ช ํ๊ธฐ ์ํ ๋จ์ํ๋ Python ์์ฌ์ฝ๋์ ๋๋ค.
RLHF๋ ํจ๊ณผ์ ์ด์ง๋ง ๋ช ๊ฐ์ง ๋จ์ ์ด ์์ต๋๋ค.
- ์ฌ๋ฌ ๋ชจ๋ธ(SFT, ๋ณด์ ๋ชจ๋ธ, RL ์ต์ ํ ๋ชจ๋ธ)์ ํ๋ จํ๊ณ ์ ์ง ๊ด๋ฆฌํด์ผ ํฉ๋๋ค.
- RL ํ๋ก์ธ์ค๋ ๋ถ์์ ํ๊ณ ํ์ดํผํ๋ผ๋ฏธํฐ์ ๋ฏผ๊ฐํ ์ ์์ต๋๋ค.
- ๊ณ์ฐ ๋น์ฉ์ด ๋ง์ด ๋ค๊ณ ๋ชจ๋ธ์ ํตํด ๋ง์ ์ ๋ฐฉํฅ ๋ฐ ์ญ๋ฐฉํฅ ํต๊ณผ๊ฐ ํ์ํฉ๋๋ค.
์ด๋ฌํ ์ ํ์ผ๋ก ์ธํด ๋ ๊ฐ๋จํ๊ณ ํจ์จ์ ์ธ ๋์์ ์ฐพ๊ฒ ๋์๊ณ , ์ด๋ DPO ๊ฐ๋ฐ๋ก ์ด์ด์ก์ต๋๋ค.
์ง์ ์ ํธ ์ต์ ํ: ํต์ฌ ๊ฐ๋
์ด ์ด๋ฏธ์ง๋ LLM ์ถ๋ ฅ์ ์ธ๊ฐ ์ ํธ๋์ ๋ง์ถ๋ ๋ ๊ฐ์ง ์ ๊ทผ ๋ฐฉ์, ์ฆ ์ธ๊ฐ ํผ๋๋ฐฑ์ ํตํ ๊ฐํ ํ์ต(RLHF)๊ณผ ์ง์ ์ ํธ๋ ์ต์ ํ(DPO)๋ฅผ ๋์กฐํฉ๋๋ค. RLHF๋ ๋ณด์ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ๋ฐ๋ณต์ ์ธ ํผ๋๋ฐฑ ๋ฃจํ๋ฅผ ํตํด ์ธ์ด ๋ชจ๋ธ์ ์ ์ฑ ์ ์๋ดํ๋ ๋ฐ๋ฉด, DPO๋ ์ ํธ๋ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ฌ ์ธ๊ฐ์ด ์ ํธํ๋ ์๋ต๊ณผ ์ผ์นํ๋๋ก ๋ชจ๋ธ ์ถ๋ ฅ์ ์ง์ ์ต์ ํํฉ๋๋ค. ์ด ๋น๊ต๋ ๊ฐ ๋ฐฉ๋ฒ์ ์ฅ์ ๊ณผ ์ ์ฌ์ ์ธ ์ ์ฉ์ ๊ฐ์กฐํ์ฌ ๋ฏธ๋์ LLM์ด ์ธ๊ฐ์ ๊ธฐ๋์ ๋ ์ ๋ถ์ํ๋๋ก ํ๋ จํ ์ ์๋ ๋ฐฉ๋ฒ์ ๋ํ ํต์ฐฐ๋ ฅ์ ์ ๊ณตํฉ๋๋ค.
DPO์ ํต์ฌ ์์ด๋์ด:
a) ์์์ ๋ณด์ ๋ชจ๋ธ๋ง: DPO๋ ์ธ์ด ๋ชจ๋ธ ์์ฒด๋ฅผ ์์์ ๋ณด์ ํจ์๋ก ์ทจ๊ธํ์ฌ ๋ณ๋์ ๋ณด์ ๋ชจ๋ธ์ด ํ์ํ์ง ์์ต๋๋ค.
b) ์ ์ฑ ๊ธฐ๋ฐ ๊ณต์ํ: DPO๋ ๋ณด์ ํจ์๋ฅผ ์ต์ ํํ๋ ๋์ ์ ์ฑ (์ธ์ด ๋ชจ๋ธ)์ ์ง์ ์ต์ ํํ์ฌ ์ ํธ ์๋ต ํ๋ฅ ์ ์ต๋ํํฉ๋๋ค.
c) ํ์ํ ์๋ฃจ์ : DPO๋ ์ต์ ์ ์ ์ฑ ์ ๋ํ ํ์ํ ์๋ฃจ์ ์ ํ์ฉํ๋ ์ํ์ ํต์ฐฐ๋ ฅ์ ํ์ฉํ์ฌ ๋ฐ๋ณต์ ์ธ RL ์ ๋ฐ์ดํธ๊ฐ ํ์ํ์ง ์์ต๋๋ค.
DPO ๊ตฌํ: ์ค์ฉ์ ์ธ ์ฝ๋ ์ฐ์ต
์๋ ์ด๋ฏธ์ง๋ PyTorch๋ฅผ ์ฌ์ฉํ์ฌ DPO ์์ค ๊ธฐ๋ฅ์ ๊ตฌํํ๋ ์ฝ๋ ์กฐ๊ฐ์ ๋ณด์ฌ์ค๋๋ค. ์ด ๊ธฐ๋ฅ์ ์ธ์ด ๋ชจ๋ธ์ด ์ธ๊ฐ ์ ํธ๋์ ๋ฐ๋ผ ์ถ๋ ฅ์ ์ฐ์ ์์๋ฅผ ์ง์ ํ๋ ๋ฐฉ๋ฒ์ ๊ฐ์ ํ๋ ๋ฐ ์ค์ํ ์ญํ ์ ํฉ๋๋ค. ์ฃผ์ ๊ตฌ์ฑ ์์์ ๋ํ ๋ถ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ๊ธฐ๋ฅ ์๋ช
๋ค์
dpo_loss
ํจ์๋ ์ ์ฑ ๋ก๊ทธ ํ๋ฅ (pi_logps
), ์ฐธ์กฐ ๋ชจ๋ธ ๋ก๊ทธ ํ๋ฅ (ref_logps
), ์ ํธ ๋ฐ ๋น์ ํธ ์์ฑ์ ๋ํ๋ด๋ ์ง์(yw_idxs
,yl_idxs
). ์ถ๊ฐ์ ์ผ๋ก,beta
๋งค๊ฐ๋ณ์๋ KL ํ๋ํฐ์ ๊ฐ๋๋ฅผ ์ ์ดํฉ๋๋ค. - ๋ก๊ทธ ํ๋ฅ ์ถ์ถ: ์ฝ๋๋ ์ ์ฑ ๋ชจ๋ธ๊ณผ ์ฐธ์กฐ ๋ชจ๋ธ ๋ชจ๋์์ ์ ํธ ์๋ฃ ๋ฐ ๋น์ ํธ ์๋ฃ์ ๋ํ ๋ก๊ทธ ํ๋ฅ ์ ์ถ์ถํฉ๋๋ค.
- ๋ก๊ทธ ๋น์จ ๊ณ์ฐ: ์ ํธ ์๋ฃ์ ๋น์ ํธ ์๋ฃ์ ๋ํ ๋ก๊ทธ ํ๋ฅ ์ ์ฐจ์ด๋ ์ ์ฑ ๋ชจ๋ธ๊ณผ ์ฐธ์กฐ ๋ชจ๋ธ ๋ชจ๋์ ๋ํด ๊ณ์ฐ๋ฉ๋๋ค. ์ด ๋น์จ์ ์ต์ ํ์ ๋ฐฉํฅ๊ณผ ๊ท๋ชจ๋ฅผ ๊ฒฐ์ ํ๋ ๋ฐ ์ค์ํฉ๋๋ค.
- ์์ค ๋ฐ ๋ณด์ ๊ณ์ฐ: ์์ค์ ๋ค์์ ์ฌ์ฉํ์ฌ ๊ณ์ฐ๋ฉ๋๋ค.
logsigmoid
๊ธฐ๋ฅ์ ์ํํ๋ ๋ฐ๋ฉด, ๋ณด์์ ์ ์ฑ ๋ก๊ทธ ํ๋ฅ ๊ณผ ์ฐธ์กฐ ๋ก๊ทธ ํ๋ฅ ๊ฐ์ ์ฐจ์ด๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์กฐ์ ํ์ฌ ๊ฒฐ์ ๋ฉ๋๋ค.beta
.
์ด๋ฌํ ๋ชฉํ๋ฅผ ๋ฌ์ฑํ๋ ๋ฐฉ๋ฒ์ ์ดํดํ๊ธฐ ์ํด DPO ์ด๋ฉด์ ์ํ์ ์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
DPO์ ์ํ
DPO๋ ์ ํธ ํ์ต ๋ฌธ์ ๋ฅผ ์๋ฆฌํ๊ฒ ์ฌ๊ตฌ์ฑํ ๊ฒ์ ๋๋ค. ๋จ๊ณ๋ณ ๋ถ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
a) ์ถ๋ฐ์ : KL ์ ์ฝ ๋ณด์ ๊ทน๋ํ
์๋ RLHF ๋ชฉํ๋ ๋ค์๊ณผ ๊ฐ์ด ํํ๋ ์ ์์ต๋๋ค.
- ฯฮธ๋ ์ฐ๋ฆฌ๊ฐ ์ต์ ํํ๋ ์ ์ฑ (์ธ์ด ๋ชจ๋ธ)์ ๋๋ค.
- r(x,y)๋ ๋ณด์ ํจ์์ ๋๋ค.
- ฯref๋ ์ฐธ์กฐ ์ ์ฑ (๋ณดํต ์ด๊ธฐ SFT ๋ชจ๋ธ)์ ๋๋ค.
- ฮฒ๋ KL ๋ฐ์ฐ ์ ์ฝ ์กฐ๊ฑด์ ๊ฐ๋๋ฅผ ์ ์ดํฉ๋๋ค.
b) ์ต์ ์ ์ ์ฑ ํํ: ์ด ๋ชฉํ์ ๋ํ ์ต์ ์ ์ ์ฑ ์ ๋ค์๊ณผ ๊ฐ์ ํ์์ ์ทจํ๋ค๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
ฯ_r(y|x) = 1/Z(x) * ฯref(y|x) * exp(1/ฮฒ * r(x,y))
์ฌ๊ธฐ์ Z(x)๋ ์ ๊ทํ ์์์ ๋๋ค.
c) ๋ณด์ ์ ์ฑ ์ด์ค์ฑ: DPO์ ํต์ฌ ํต์ฐฐ๋ ฅ์ ์ต์ ์ ์ ์ฑ ์ธก๋ฉด์์ ๋ณด์ ๊ธฐ๋ฅ์ ํํํ๋ ๊ฒ์ ๋๋ค.
r(x,y) = ฮฒ * log(ฯ_r(y|x) / ฯref(y|x)) + ฮฒ * log(Z(x))
d) ์ ํธ ๋ชจ๋ธ ์ ํธ๊ฐ Bradley-Terry ๋ชจ๋ธ์ ๋ฐ๋ฅธ๋ค๊ณ ๊ฐ์ ํ๋ฉด y1๋ณด๋ค y2์ ์ ํธํ ํ๋ฅ ์ ๋ค์๊ณผ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค.
p*(y1 โป y2 | x) = ฯ(r*(x,y1) - r*(x,y2))
์ฌ๊ธฐ์ ฯ๋ ๋ก์ง์คํฑ ํจ์์ ๋๋ค.
e) DPO ๋ชฉํ ๋ณด์ ์ ์ฑ ์ด์ค์ฑ์ ์ ํธ ๋ชจ๋ธ๋ก ๋์ฒดํ๋ฉด ๋ค์๊ณผ ๊ฐ์ DPO ๋ชฉํ์ ๋๋ฌํฉ๋๋ค.
L_DPO(ฯฮธ; ฯref) = -E_(x,y_w,y_l)~D [log ฯ(ฮฒ * log(ฯฮธ(y_w|x) / ฯref(y_w|x)) - ฮฒ * log(ฯฮธ(y_l|x) / ฯref(y_l|x)))]
์ด ๋ชฉํ๋ RL ์๊ณ ๋ฆฌ์ฆ ์์ด๋ ํ์ค ๊ฒฝ์ฌํ๊ฐ๋ฒ์ ์ฌ์ฉํ์ฌ ์ต์ ํํ ์ ์์ต๋๋ค.
DPO ๊ตฌํ
์ด์ DPO์ ๊ธฐ๋ณธ ์ด๋ก ์ ์ดํดํ์ผ๋ฏ๋ก ์ด๋ฅผ ์ค์ ๋ก ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ์ฐ๋ฆฌ๋ ์ฌ์ฉํ ๊ฒ์ด๋ค Python ๊ทธ๋ฆฌ๊ณ ํ์ด ํ ์น ์ด ์์ ๊ฒฝ์ฐ:
import torch import torch.nn.functional as F class DPOTrainer: def __init__(self, model, ref_model, beta=0.1, lr=1e-5): self.model = model self.ref_model = ref_model self.beta = beta self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr) def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs): """ pi_logps: policy logprobs, shape (B,) ref_logps: reference model logprobs, shape (B,) yw_idxs: preferred completion indices in [0, B-1], shape (T,) yl_idxs: dispreferred completion indices in [0, B-1], shape (T,) beta: temperature controlling strength of KL penalty Each pair of (yw_idxs[i], yl_idxs[i]) represents the indices of a single preference pair. """ # Extract log probabilities for the preferred and dispreferred completions pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs] ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs] # Calculate log-ratios pi_logratios = pi_yw_logps - pi_yl_logps ref_logratios = ref_yw_logps - ref_yl_logps # Compute DPO loss losses = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios)) rewards = self.beta * (pi_logps - ref_logps).detach() return losses.mean(), rewards def train_step(self, batch): x, yw_idxs, yl_idxs = batch self.optimizer.zero_grad() # Compute log probabilities for the model and the reference model pi_logps = self.model(x).log_softmax(-1) ref_logps = self.ref_model(x).log_softmax(-1) # Compute the loss loss, _ = self.compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs) loss.backward() self.optimizer.step() return loss.item() # Usage model = YourLanguageModel() # Initialize your model ref_model = YourLanguageModel() # Load pre-trained reference model trainer = DPOTrainer(model, ref_model) for batch in dataloader: loss = trainer.train_step(batch) print(f"Loss: {loss}")
๋์ ๊ณผ ์์ผ๋ก์ ๋ฐฉํฅ
DPO๋ ๊ธฐ์กด RLHF ์ ๊ทผ ๋ฐฉ์์ ๋นํด ์๋นํ ์ด์ ์ ์ ๊ณตํ์ง๋ง ์ฌ์ ํ ์ถ๊ฐ ์ฐ๊ตฌ๋ฅผ ์ํ ๊ณผ์ ์ ์์ญ์ด ์์ต๋๋ค.
a) ๋ ํฐ ๋ชจ๋ธ๋ก์ ํ์ฅ์ฑ:
์ธ์ด ๋ชจ๋ธ์ ํฌ๊ธฐ๊ฐ ๊ณ์ ์ฆ๊ฐํจ์ ๋ฐ๋ผ ์์ฒ์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๊ฐ ์๋ ๋ชจ๋ธ์ DPO๋ฅผ ํจ์จ์ ์ผ๋ก ์ ์ฉํ๋ ๊ฒ์ ์ฌ์ ํ โโํด๊ฒฐํด์ผ ํ ๊ณผ์ ๋ก ๋จ์ ์์ต๋๋ค. ์ฐ๊ตฌ์๋ค์ ๋ค์๊ณผ ๊ฐ์ ๊ธฐ์ ์ ํ๊ตฌํ๊ณ ์์ต๋๋ค.
- ํจ์จ์ ์ธ ๋ฏธ์ธ ์กฐ์ ๋ฐฉ๋ฒ(์: LoRA, ํ๋ฆฌํฝ์ค ํ๋)
- ๋ถ์ฐ ํ๋ จ ์ต์ ํ
- ๊ฒฝ์ฌ ์ฒดํฌํฌ์ธํธ ๋ฐ ํผํฉ ์ ๋ฐ๋ ํ๋ จ
DPO์ ํจ๊ป LoRA๋ฅผ ์ฌ์ฉํ๋ ์:
from peft import LoraConfig, get_peft_model class DPOTrainerWithLoRA(DPOTrainer): def __init__(self, model, ref_model, beta=0.1, lr=1e-5, lora_rank=8): lora_config = LoraConfig( r=lora_rank, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM" ) self.model = get_peft_model(model, lora_config) self.ref_model = ref_model self.beta = beta self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr) # Usage base_model = YourLargeLanguageModel() dpo_trainer = DPOTrainerWithLoRA(base_model, ref_model)
b) ๋ค์ค ์์ ๋ฐ Few-Shot ์ ์:
์ ํ๋ ์ ํธ๋ ๋ฐ์ดํฐ๋ก ์๋ก์ด ์์ ์ด๋ ์์ญ์ ํจ์จ์ ์ผ๋ก ์ ์ํ ์ ์๋ DPO ๊ธฐ์ ์ ๊ฐ๋ฐํ๋ ๊ฒ์ด ํ๋ฐํ ์ฐ๊ตฌ ๋ถ์ผ์ ๋๋ค. ํ์ ์ค์ธ ์ ๊ทผ ๋ฐฉ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ์ ์ํ ์ ์์ ์ํ ๋ฉํ๋ฌ๋ ํ๋ ์์ํฌ
- DPO๋ฅผ ์ํ ํ๋กฌํํธ ๊ธฐ๋ฐ ๋ฏธ์ธ ์กฐ์
- ์ผ๋ฐ ์ ํธ ๋ชจ๋ธ์์ ํน์ ๋๋ฉ์ธ์ผ๋ก ํ์ต ์ ์ด
c) ๋ชจํธํ๊ฑฐ๋ ์์ถฉ๋๋ ์ ํธ์ฌํญ ์ฒ๋ฆฌ:
์ค์ ์ ํธ๋ ๋ฐ์ดํฐ์๋ ๋ชจํธํจ์ด๋ ์ถฉ๋์ด ํฌํจ๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ์ด๋ฌํ ๋ฐ์ดํฐ์ ๋ํ DPO์ ๊ฒฌ๊ณ ์ฑ์ ํฅ์์ํค๋ ๊ฒ์ด ์ค์ํฉ๋๋ค. ์ ์ฌ์ ์ธ ์๋ฃจ์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ํ๋ฅ ์ ์ ํธ ๋ชจ๋ธ๋ง
- ๋ชจํธํจ์ ํด๊ฒฐํ๊ธฐ ์ํ ๋ฅ๋์ ํ์ต
- ๋ค์ค ์์ด์ ํธ ๊ธฐ๋ณธ ์ค์ ์ง๊ณ
ํ๋ฅ ์ ์ ํธ ๋ชจ๋ธ๋ง์ ์:
class ProbabilisticDPOTrainer(DPOTrainer): def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs, preference_prob): # Compute log ratios pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs] ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs] log_ratio_diff = pi_yw_logps.sum(-1) - pi_yl_logps.sum(-1) loss = -(preference_prob * F.logsigmoid(self.beta * log_ratio_diff) + (1 - preference_prob) * F.logsigmoid(-self.beta * log_ratio_diff)) return loss.mean() # Usage trainer = ProbabilisticDPOTrainer(model, ref_model) loss = trainer.compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, preference_prob=0.8) # 80% confidence in preference
d) DPO๋ฅผ ๋ค๋ฅธ ์ ๋ ฌ ๊ธฐ์ ๊ณผ ๊ฒฐํฉ:
DPO๋ฅผ ๋ค๋ฅธ ์ ๋ ฌ ์ ๊ทผ ๋ฐฉ์๊ณผ ํตํฉํ๋ฉด ๋์ฑ ๊ฐ๋ ฅํ๊ณ ์ ๋ฅํ ์์คํ ์ด ๋ ์ ์์ต๋๋ค.
- ๋ช ์์ ์ ์ฝ์กฐ๊ฑด ๋ง์กฑ์ ์ํ ํ๋ฒ์ AI ์์น
- ๋ณต์กํ ์ ํธ๋ ๋์ถ์ ์ํ ํ ๋ก ๋ฐ ์ฌ๊ท์ ๋ณด์ ๋ชจ๋ธ๋ง
- ๊ธฐ๋ณธ ๋ณด์ ํจ์๋ฅผ ์ถ๋ก ํ๊ธฐ ์ํ ์ญ ๊ฐํ ํ์ต
DPO์ ํ๋ฒ AI๋ฅผ ๊ฒฐํฉํ ์:
class ConstitutionalDPOTrainer(DPOTrainer): def __init__(self, model, ref_model, beta=0.1, lr=1e-5, constraints=None): super().__init__(model, ref_model, beta, lr) self.constraints = constraints or [] def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs): base_loss = super().compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs) constraint_loss = 0 for constraint in self.constraints: constraint_loss += constraint(self.model, pi_logps, ref_logps, yw_idxs, yl_idxs) return base_loss + constraint_loss # Usage def safety_constraint(model, pi_logps, ref_logps, yw_idxs, yl_idxs): # Implement safety checking logic unsafe_score = compute_unsafe_score(model, pi_logps, ref_logps) return torch.relu(unsafe_score - 0.5) # Penalize if unsafe score > 0.5 constraints = [safety_constraint] trainer = ConstitutionalDPOTrainer(model, ref_model, constraints=constraints)
์ค์ ๊ณ ๋ ค ์ฌํญ ๋ฐ ๋ชจ๋ฒ ์ฌ๋ก
์ค์ ์์ฉ ํ๋ก๊ทธ๋จ์ ๋ํด DPO๋ฅผ ๊ตฌํํ ๋ ๋ค์ ํ์ ๊ณ ๋ คํ์ญ์์ค.
a) ๋ฐ์ดํฐ ํ์ง: ์ ํธ๋ ๋ฐ์ดํฐ์ ํ์ง์ด ์ค์ํฉ๋๋ค. ๋ฐ์ดํฐ์ธํธ๊ฐ ๋ค์๊ณผ ๊ฐ์์ง ํ์ธํ์ธ์.
- ๋ค์ํ ๋ฒ์์ ์ ๋ ฅ๊ณผ ์ํ๋ ๋์์ ํฌ๊ดํฉ๋๋ค.
- ์ผ๊ด๋๊ณ ์์ ์ ์ธ ๊ธฐ๋ณธ ์ค์ ์ฃผ์์ด ์์ต๋๋ค.
- ๋ค์ํ ์ ํ์ ์ ํธ๋(์: ์ฌ์ค์ฑ, ์์ ์ฑ, ์คํ์ผ)์ ๊ท ํ์ ์ ์งํฉ๋๋ค.
b) ํ์ดํผ ํ๋ผ๋ฏธํฐ ํ๋: DPO๋ RLHF๋ณด๋ค ํ์ดํผํ๋ผ๋ฏธํฐ ์๊ฐ ์ ์ง๋ง ํ๋์ ์ฌ์ ํ โโ์ค์ํฉ๋๋ค.
- ฮฒ(๋ฒ ํ): ์ ํธ๋ ๋ง์กฑ๊ณผ ์ฐธ์กฐ ๋ชจ๋ธ๊ณผ์ ์ฐจ์ด ๊ฐ์ ๊ท ํ์ ์ ์ดํฉ๋๋ค. ์ฃผ๋ณ์ ๊ฐ์น๋ถํฐ ์์ํ์ธ์ 0.1-0.5.
- ํ์ต๋ฅ : ์ผ๋ฐ์ ์ผ๋ก ๋ค์ ๋ฒ์์์ ํ์ค ๋ฏธ์ธ ์กฐ์ ๋ณด๋ค ๋ฎ์ ํ์ต๋ฅ ์ ์ฌ์ฉํฉ๋๋ค. 1e-6 ~ 1e-5.
- ๋ฐฐ์น ํฌ๊ธฐ: ๋ ํฐ ๋ฐฐ์น ํฌ๊ธฐ(32-128) ์ข ์ข ์ ํธ ํ์ต์ ์ ํฉํฉ๋๋ค.
c) ๋ฐ๋ณต์ ๊ฐ์ : DPO๋ ๋ฐ๋ณต์ ์ผ๋ก ์ ์ฉ๋ ์ ์์ต๋๋ค.
- DPO๋ฅผ ์ฌ์ฉํ์ฌ ์ด๊ธฐ ๋ชจ๋ธ ํ์ต
- ํ๋ จ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์๋ก์ด ์๋ต ์์ฑ
- ์ด๋ฌํ ์๋ต์ ๋ํ ์๋ก์ด ์ ํธ๋ ๋ฐ์ดํฐ ์์ง
- ํ์ฅ๋ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ฌ์ฉํ์ฌ ์ฌํ๋ จ
์ด ์ด๋ฏธ์ง๋ Direct Preference Optimization(DPO), Supervised Fine-Tuning(SFT), Proximal Policy Optimization(PPO)์ ํฌํจํ ๋ค์ํ ํ๋ จ ๊ธฐ๋ฒ์์ ์ธ๊ฐ์ ํ๋จ๊ณผ ๋น๊ตํ GPT-4์ ๊ฐ์ LLM์ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋๋ค. ์ด ํ๋ GPT-4์ ์ถ๋ ฅ์ด ํนํ ์์ฝ ์์ ์์ ์ธ๊ฐ์ ์ ํธ๋์ ์ ์ ๋ ์ผ์นํ๊ณ ์์์ ๋ณด์ฌ์ค๋๋ค. GPT-4์ ์ธ๊ฐ ๊ฒํ ์ ๊ฐ์ ํฉ์ ์์ค์ ์ด ๋ชจ๋ธ์ด ์ธ๊ฐ์ด ์์ฑํ ์ฝํ ์ธ ์ ๊ฑฐ์ ๋ง์ฐฌ๊ฐ์ง๋ก ์ธ๊ฐ ํ๊ฐ์์๊ฒ ๊ณต๊ฐ์ ์ป๋ ์ฝํ ์ธ ๋ฅผ ์์ฑํ ์ ์๋ ๋ฅ๋ ฅ์ ๋ณด์ฌ์ค๋๋ค.
์ฌ๋ก ์ฐ๊ตฌ ๋ฐ ์ ํ๋ฆฌ์ผ์ด์
DPO์ ํจ์จ์ฑ์ ์ค๋ช ํ๊ธฐ ์ํด ์ค์ ์ ํ๋ฆฌ์ผ์ด์ ๊ณผ ๊ทธ ๋ณํ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
- ๋ฐ๋ณต์ ์ธ DPO: Snorkel(2023)์์ ๊ฐ๋ฐํ ์ด ๋ณํ์ ๊ฑฐ๋ถ ์ํ๋ง๊ณผ DPO๋ฅผ ๊ฒฐํฉํ์ฌ ํ๋ จ ๋ฐ์ดํฐ์ ๋ํ ๋ณด๋ค ์ ๊ตํ ์ ํ ํ๋ก์ธ์ค๋ฅผ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค. ์ฌ๋ฌ ๋ผ์ด๋์ ์ ํธ๋ ์ํ๋ง์ ๋ฐ๋ณตํจ์ผ๋ก์จ ๋ชจ๋ธ์ ์ก์์ด ๋ง๊ฑฐ๋ ํธํฅ๋ ์ ํธ๋์ ๋ํ ๊ณผ์ ํฉ์ ๋ ์ ์ผ๋ฐํํ๊ณ ๋ฐฉ์งํ ์ ์์ต๋๋ค.
- IPO (๋ฐ๋ณต์ ์ ํธ๋ ์ต์ ํ): Azar ๋ฑ์ด ์๊ฐํจ. (2023), IPO๋ ์ ํธ๋ ๊ธฐ๋ฐ ์ต์ ํ์์ ์ผ๋ฐ์ ์ธ ๋ฌธ์ ์ธ ๊ณผ์ ํฉ์ ๋ฐฉ์งํ๊ธฐ ์ํด ์ ๊ทํ ์ฉ์ด๋ฅผ ์ถ๊ฐํฉ๋๋ค. ์ด ํ์ฅ์ ํตํด ๋ชจ๋ธ์ ์ ํธ๋ ์ค์์ ์ผ๋ฐํ ๊ธฐ๋ฅ ์ ์ง ๊ฐ์ ๊ท ํ์ ์ ์งํ ์ ์์ต๋๋ค.
- ๊ณต์ฌ(์ง์ ์ด์ ์ต์ ํ): Ethayarajh et al.์ ์ต์ ๋ณ์ข ์ ๋๋ค. (2023), KTO๋ ๋ฐ์ด๋๋ฆฌ ๊ธฐ๋ณธ ์ค์ ์ ์์ ํ ์๋ตํฉ๋๋ค. ๋์ , ์ฐธ์กฐ ๋ชจ๋ธ์ ์ง์์ ์ ์ฑ ๋ชจ๋ธ๋ก ์ด์ ํ๋ ๋ฐ ์ค์ ์ ๋๊ณ ์ธ๊ฐ ๊ฐ์น์ ๋ณด๋ค ์ํํ๊ณ ์ผ๊ด๋๊ฒ ์ ๋ ฌ๋๋๋ก ์ต์ ํํฉ๋๋ค.
- ๋๋ฉ์ธ ๊ฐ ํ์ต์ ์ํ ๋ค์ค ๋ชจ๋ DPO Xu et al. (2024): DPO๊ฐ ํ ์คํธ, ์ด๋ฏธ์ง, ์ค๋์ค ๋ฑ ๋ค์ํ ํ์์ ๊ฑธ์ณ ์ ์ฉ๋๋ ์ ๊ทผ ๋ฐฉ์์ผ๋ก, ๋ค์ํ ๋ฐ์ดํฐ ์ ํ ์ ๋ฐ์ ๊ฑธ์ณ ์ธ๊ฐ์ ์ ํธ๋์ ๋ง์ถฐ ๋ชจ๋ธ์ ์กฐ์ ํ๋ ๋ค์ฌ๋ค๋ฅํจ์ ๋ณด์ฌ์ค๋๋ค. ์ด ์ฐ๊ตฌ๋ ๋ณต์กํ ๋ค์ค ๋ชจ๋ ์์ ์ ์ฒ๋ฆฌํ ์ ์๋ ๋ณด๋ค ํฌ๊ด์ ์ธ AI ์์คํ ์ ๋ง๋๋ ๋ฐ ์์ด DPO์ ์ ์ฌ๋ ฅ์ ๊ฐ์กฐํฉ๋๋ค.