
์ง์ ‘ ์„ ํ˜ธ๋„ ์ตœ์ ํ™”: ์ „์ฒด ๊ฐ€์ด๋“œ

LLM DPO math and code

์ธ๊ฐ„์˜ ๊ฐ€์น˜์™€ ์„ ํ˜ธ๋„์— ๋งž์ถฐ ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ(LLM)์„ ์กฐ์ •ํ•˜๋Š” ๊ฒƒ์€ ์–ด๋ ต์Šต๋‹ˆ๋‹ค. ๋“ฑ์˜ ์ „ํ†ต์ ์ธ ๋ฐฉ๋ฒ• ์‚ฌ๋žŒ์˜ ํ”ผ๋“œ๋ฐฑ์„ ํ†ตํ•œ ๊ฐ•ํ™” ํ•™์Šต (RLHF)๋Š” ์‚ฌ๋žŒ์˜ ์ž…๋ ฅ์„ ํ†ตํ•ฉํ•˜์—ฌ ๋ชจ๋ธ ์ถœ๋ ฅ์„ ๊ฐœ์„ ํ•จ์œผ๋กœ์จ ๊ธธ์„ ์—ด์—ˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ RLHF๋Š” ๋ณต์žกํ•˜๊ณ  ๋ฆฌ์†Œ์Šค ์ง‘์•ฝ์ ์ผ ์ˆ˜ ์žˆ์œผ๋ฉฐ ์ƒ๋‹นํ•œ ๊ณ„์‚ฐ ๋Šฅ๋ ฅ๊ณผ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ง์ ‘ ์„ ํ˜ธ๋„ ์ตœ์ ํ™” (DPO)๋Š” ์ƒˆ๋กญ๊ณ  ๋”์šฑ ๊ฐ„์†Œํ™”๋œ ์ ‘๊ทผ ๋ฐฉ์‹์œผ๋กœ ๋“ฑ์žฅํ•˜์—ฌ ์ด๋Ÿฌํ•œ ๊ธฐ์กด ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ํšจ์œจ์ ์ธ ๋Œ€์•ˆ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. DPO๋Š” ์ตœ์ ํ™” ํ”„๋กœ์„ธ์Šค๋ฅผ ๋‹จ์ˆœํ™”ํ•จ์œผ๋กœ์จ ๊ณ„์‚ฐ ๋ถ€๋‹ด์„ ์ค„์ผ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ์ธ๊ฐ„์˜ ์„ ํ˜ธ๋„์— ๋น ๋ฅด๊ฒŒ ์ ์‘ํ•˜๋Š” ๋ชจ๋ธ์˜ ๋Šฅ๋ ฅ์„ ํ–ฅ์ƒ์‹œํ‚ต๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” DPO์— ๋Œ€ํ•ด ์ž์„ธํžˆ ์•Œ์•„๋ณด๊ณ  DPO์˜ ๊ธฐ์ดˆ, ๊ตฌํ˜„ ๋ฐ ์‹ค์ œ ์‘์šฉ ํ”„๋กœ๊ทธ๋žจ์„ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

The Need for Preference Alignment

To understand DPO, it is important to first understand why aligning LLMs with human preferences matters so much. Despite their impressive capabilities, LLMs trained on vast datasets can sometimes produce outputs that are inconsistent, biased, or misaligned with human values. This misalignment can manifest in various ways:

  • Generating unsafe or harmful content
  • Providing inaccurate or misleading information
  • Exhibiting biases present in the training data

To address these issues, researchers have developed techniques for fine-tuning LLMs using human feedback. The most prominent of these approaches has been RLHF.

RLHF ์ดํ•ด: DPO์˜ ์„ ๊ตฌ์ž

Reinforcement Learning from Human Feedback (RLHF) has been the go-to method for aligning LLMs with human preferences. To appreciate its complexity, let's break down the RLHF process:

a) ๊ฐ๋…ํ˜• ๋ฏธ์„ธ ์กฐ์ •(SFT): ํ”„๋กœ์„ธ์Šค๋Š” ๊ณ ํ’ˆ์งˆ ์‘๋‹ต ๋ฐ์ดํ„ฐ ์„ธํŠธ์—์„œ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ LLM์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๋Š” ๊ฒƒ์œผ๋กœ ์‹œ์ž‘๋ฉ๋‹ˆ๋‹ค. ์ด ๋‹จ๊ณ„๋Š” ๋ชจ๋ธ์ด ๋Œ€์ƒ ์ž‘์—…์— ๋Œ€ํ•ด ๋ณด๋‹ค ๊ด€๋ จ์„ฑ์ด ๋†’๊ณ  ์ผ๊ด€๋œ ์ถœ๋ ฅ์„ ์ƒ์„ฑํ•˜๋Š” ๋ฐ ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค.

b) Reward Modeling: A separate reward model is trained to predict human preferences. This involves:

  • Generating pairs of responses for given prompts
  • Having humans rate which response they prefer
  • Training a model to predict these preferences
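The reward-modeling step above is typically trained with a Bradley-Terry pairwise objective. Below is a minimal, hedged sketch: the tiny linear `RewardModel` and the random feature vectors are stand-ins for a real transformer with a scalar head and real encoded response pairs.

```python
import torch
import torch.nn.functional as F

class RewardModel(torch.nn.Module):
    """Toy stand-in for a transformer reward model with a scalar head."""
    def __init__(self, dim=16):
        super().__init__()
        self.head = torch.nn.Linear(dim, 1)

    def forward(self, features):
        # One scalar reward per (prompt, response) feature vector
        return self.head(features).squeeze(-1)

def reward_model_loss(r_preferred, r_dispreferred):
    # Bradley-Terry pairwise loss: push the preferred response's reward
    # above the dispreferred one's
    return -F.logsigmoid(r_preferred - r_dispreferred).mean()

# Random features stand in for encoded preferred/dispreferred responses
rm = RewardModel()
feats_w, feats_l = torch.randn(8, 16), torch.randn(8, 16)
loss = reward_model_loss(rm(feats_w), rm(feats_l))
```

Minimizing this loss increases the probability that the reward model ranks the human-preferred response higher.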

c) Reinforcement Learning: The fine-tuned LLM is further optimized using reinforcement learning. The reward model provides feedback, guiding the LLM to generate responses that align with human preferences.

๋‹ค์Œ์€ RLHF ํ”„๋กœ์„ธ์Šค๋ฅผ ์„ค๋ช…ํ•˜๊ธฐ ์œ„ํ•œ ๋‹จ์ˆœํ™”๋œ Python ์˜์‚ฌ์ฝ”๋“œ์ž…๋‹ˆ๋‹ค.

While effective, RLHF has several drawbacks:

  • It requires training and maintaining multiple models (the SFT model, the reward model, and the RL-optimized model).
  • The RL process can be unstable and sensitive to hyperparameters.
  • It is computationally expensive, requiring many forward and backward passes through the models.

These limitations motivated the search for simpler, more efficient alternatives, leading to the development of DPO.

์ง์ ‘ ์„ ํ˜ธ ์ตœ์ ํ™”: ํ•ต์‹ฌ ๊ฐœ๋…

์ง์ ‘ ์„ ํ˜ธ๋„ ์ตœ์ ํ™” https://arxiv.org/abs/2305.18290

์ง์ ‘ ์„ ํ˜ธ๋„ ์ตœ์ ํ™” https://arxiv.org/abs/2305.18290

์ด ์ด๋ฏธ์ง€๋Š” LLM ์ถœ๋ ฅ์„ ์ธ๊ฐ„ ์„ ํ˜ธ๋„์— ๋งž์ถ”๋Š” ๋‘ ๊ฐ€์ง€ ์ ‘๊ทผ ๋ฐฉ์‹, ์ฆ‰ ์ธ๊ฐ„ ํ”ผ๋“œ๋ฐฑ์„ ํ†ตํ•œ ๊ฐ•ํ™” ํ•™์Šต(RLHF)๊ณผ ์ง์ ‘ ์„ ํ˜ธ๋„ ์ตœ์ ํ™”(DPO)๋ฅผ ๋Œ€์กฐํ•ฉ๋‹ˆ๋‹ค. RLHF๋Š” ๋ณด์ƒ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ฐ˜๋ณต์ ์ธ ํ”ผ๋“œ๋ฐฑ ๋ฃจํ”„๋ฅผ ํ†ตํ•ด ์–ธ์–ด ๋ชจ๋ธ์˜ ์ •์ฑ…์„ ์•ˆ๋‚ดํ•˜๋Š” ๋ฐ˜๋ฉด, DPO๋Š” ์„ ํ˜ธ๋„ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ธ๊ฐ„์ด ์„ ํ˜ธํ•˜๋Š” ์‘๋‹ต๊ณผ ์ผ์น˜ํ•˜๋„๋ก ๋ชจ๋ธ ์ถœ๋ ฅ์„ ์ง์ ‘ ์ตœ์ ํ™”ํ•ฉ๋‹ˆ๋‹ค. ์ด ๋น„๊ต๋Š” ๊ฐ ๋ฐฉ๋ฒ•์˜ ์žฅ์ ๊ณผ ์ž ์žฌ์ ์ธ ์ ์šฉ์„ ๊ฐ•์กฐํ•˜์—ฌ ๋ฏธ๋ž˜์˜ LLM์ด ์ธ๊ฐ„์˜ ๊ธฐ๋Œ€์— ๋” ์ž˜ ๋ถ€์‘ํ•˜๋„๋ก ํ›ˆ๋ จํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ํ†ต์ฐฐ๋ ฅ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.

DPO์˜ ํ•ต์‹ฌ ์•„์ด๋””์–ด:

a) Implicit reward modeling: DPO treats the language model itself as an implicit reward function, eliminating the need for a separate reward model.

b) Policy-based formulation: Instead of optimizing a reward function, DPO directly optimizes the policy (the language model) to maximize the probability of preferred responses.

c) ํ์‡„ํ˜• ์†”๋ฃจ์…˜: DPO๋Š” ์ตœ์ ์˜ ์ •์ฑ…์— ๋Œ€ํ•œ ํ์‡„ํ˜• ์†”๋ฃจ์…˜์„ ํ—ˆ์šฉํ•˜๋Š” ์ˆ˜ํ•™์  ํ†ต์ฐฐ๋ ฅ์„ ํ™œ์šฉํ•˜์—ฌ ๋ฐ˜๋ณต์ ์ธ RL ์—…๋ฐ์ดํŠธ๊ฐ€ ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

Implementing DPO: A Practical Code Walkthrough

์•„๋ž˜ ์ด๋ฏธ์ง€๋Š” PyTorch๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ DPO ์†์‹ค ๊ธฐ๋Šฅ์„ ๊ตฌํ˜„ํ•˜๋Š” ์ฝ”๋“œ ์กฐ๊ฐ์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. ์ด ๊ธฐ๋Šฅ์€ ์–ธ์–ด ๋ชจ๋ธ์ด ์ธ๊ฐ„ ์„ ํ˜ธ๋„์— ๋”ฐ๋ผ ์ถœ๋ ฅ์˜ ์šฐ์„ ์ˆœ์œ„๋ฅผ ์ง€์ •ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๊ฐœ์„ ํ•˜๋Š” ๋ฐ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•ฉ๋‹ˆ๋‹ค. ์ฃผ์š” ๊ตฌ์„ฑ ์š”์†Œ์— ๋Œ€ํ•œ ๋ถ„์„์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • Function signature: The dpo_loss function takes policy log probabilities (pi_logps), reference model log probabilities (ref_logps), and indices denoting the preferred and dispreferred completions (yw_idxs, yl_idxs). In addition, a beta parameter controls the strength of the KL penalty.
  • Log-probability extraction: The code extracts the log probabilities of the preferred and dispreferred completions from both the policy model and the reference model.
  • Log-ratio calculation: The difference between the log probabilities of the preferred and dispreferred completions is computed for both the policy and reference models. This ratio determines the direction and magnitude of the optimization.
  • Loss and reward computation: The loss is computed using the logsigmoid function, while the rewards are obtained by scaling the difference between the policy and reference log probabilities by beta.
The DPO loss function in PyTorch

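Since the snippet itself appears only as an image, here is a reconstruction of the dpo_loss function consistent with the components described above (a sketch following the pseudocode published with the DPO paper; shapes follow the (B,) log-prob convention in the description):

```python
import torch
import torch.nn.functional as F

def dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta=0.1):
    """Standalone DPO loss.

    pi_logps:  policy log-probs for each completion, shape (B,)
    ref_logps: reference-model log-probs, shape (B,)
    yw_idxs:   indices of preferred completions, shape (T,)
    yl_idxs:   indices of dispreferred completions, shape (T,)
    beta:      temperature controlling the strength of the KL penalty
    """
    # Extract log probabilities for preferred/dispreferred completions
    pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
    ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]

    # Log-ratios for the policy and the reference model
    pi_logratios = pi_yw_logps - pi_yl_logps
    ref_logratios = ref_yw_logps - ref_yl_logps

    # DPO loss and the implicit rewards implied by the policy
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    rewards = beta * (pi_logps - ref_logps).detach()
    return losses.mean(), rewards
```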

To understand how DPO achieves these goals, let's dive into the mathematics behind it.

DPO์˜ ์ˆ˜ํ•™

DPO is a clever reformulation of the preference learning problem. Here is a step-by-step breakdown:

a) Starting point: KL-constrained reward maximization

The original RLHF objective can be expressed as:

๋‹ค์Œ ์ด๋ฏธ์ง€์˜ ๋ณต์žกํ•œ ์ˆ˜ํ•™ ๊ณต์‹์€ LLM์ด ์ถœ๋ ฅ์„ ์ธ๊ฐ„ ์„ ํ˜ธ๋„์— ๋งž๊ฒŒ ์กฐ์ •ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๊ฐœ์„ ํ•˜๋Š” ์ตœ์ฒจ๋‹จ ๊ต์œก ๋ฐฉ๋ฒ•์ธ DPO(Direct Preference Optimization)์— ์‚ฌ์šฉ๋˜๋Š” ์†์‹ค ํ•จ์ˆ˜๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.

์–ด๋””์—:
  • ฯ€ฮธ๋Š” ์šฐ๋ฆฌ๊ฐ€ ์ตœ์ ํ™”ํ•˜๋Š” ์ •์ฑ…(์–ธ์–ด ๋ชจ๋ธ)์ž…๋‹ˆ๋‹ค.
  • r(x,y)๋Š” ๋ณด์ƒ ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
  • ฯ€ref๋Š” ์ฐธ์กฐ ์ •์ฑ…(๋ณดํ†ต ์ดˆ๊ธฐ SFT ๋ชจ๋ธ)์ž…๋‹ˆ๋‹ค.
  • ฮฒ๋Š” KL ๋ฐœ์‚ฐ ์ œ์•ฝ ์กฐ๊ฑด์˜ ๊ฐ•๋„๋ฅผ ์ œ์–ดํ•ฉ๋‹ˆ๋‹ค.

b) Optimal policy form: It can be shown that the optimal policy for this objective takes the form:

π_r(y|x) = 1/Z(x) * πref(y|x) * exp(1/β * r(x,y))

where Z(x) is a normalization constant.

c) Reward-policy duality: The key insight of DPO is to express the reward function in terms of the optimal policy:

r(x,y) = β * log(π_r(y|x) / πref(y|x)) + β * log(Z(x))

d) Preference model: Assuming preferences follow the Bradley-Terry model, the probability of preferring y1 over y2 can be expressed as:

p*(y1 ≻ y2 | x) = σ(r*(x,y1) - r*(x,y2))

where σ is the logistic (sigmoid) function.
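As a quick numeric illustration of the Bradley-Terry formula: equal rewards give a 50/50 preference, while a two-point reward gap puts roughly 88% of the probability on y1.

```python
import math

def bt_preference_prob(r1, r2):
    # Bradley-Terry: p(y1 preferred over y2) = sigmoid(r1 - r2)
    return 1.0 / (1.0 + math.exp(-(r1 - r2)))

p_equal = bt_preference_prob(1.0, 1.0)  # -> 0.5
p_gap = bt_preference_prob(3.0, 1.0)    # -> ~0.88
```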

e) The DPO objective: Substituting the reward-policy duality into the preference model, we arrive at the DPO objective:

L_DPO(πθ; πref) = -E_(x,y_w,y_l)~D [log σ(β * log(πθ(y_w|x) / πref(y_w|x)) - β * log(πθ(y_l|x) / πref(y_l|x)))]

This objective can be optimized with standard gradient descent, with no RL algorithms required.

Implementing DPO

Now that we understand the theory behind DPO, let's see how to implement it in practice. We will use Python and PyTorch for this example:

import torch
import torch.nn.functional as F

class DPOTrainer:
    def __init__(self, model, ref_model, beta=0.1, lr=1e-5):
        self.model = model
        self.ref_model = ref_model
        self.beta = beta
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
    
    def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs):
        """
        pi_logps: policy logprobs, shape (B,)
        ref_logps: reference model logprobs, shape (B,)
        yw_idxs: preferred completion indices in [0, B-1], shape (T,)
        yl_idxs: dispreferred completion indices in [0, B-1], shape (T,)
        beta: temperature controlling strength of KL penalty

        Each pair of (yw_idxs[i], yl_idxs[i]) represents the indices of a single preference pair.
        """

        # Extract log probabilities for the preferred and dispreferred completions
        pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
        ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]

        # Calculate log-ratios
        pi_logratios = pi_yw_logps - pi_yl_logps
        ref_logratios = ref_yw_logps - ref_yl_logps

        # Compute DPO loss
        losses = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))
        rewards = self.beta * (pi_logps - ref_logps).detach()

        return losses.mean(), rewards

    def train_step(self, batch):
        x, yw_idxs, yl_idxs = batch
        self.optimizer.zero_grad()

        # Compute log probabilities for the policy and the frozen reference model
        pi_logps = self.model(x).log_softmax(-1)
        with torch.no_grad():  # no gradients flow through the reference model
            ref_logps = self.ref_model(x).log_softmax(-1)

        # Compute the loss
        loss, _ = self.compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs)
        loss.backward()
        self.optimizer.step()

        return loss.item()

# Usage
model = YourLanguageModel()  # Initialize your model
ref_model = YourLanguageModel()  # Load pre-trained reference model
trainer = DPOTrainer(model, ref_model)

for batch in dataloader:
    loss = trainer.train_step(batch)
    print(f"Loss: {loss}")

Challenges and Future Directions

While DPO offers significant advantages over traditional RLHF approaches, there are still challenges and areas for further research:

a) ๋” ํฐ ๋ชจ๋ธ๋กœ์˜ ํ™•์žฅ์„ฑ:

์–ธ์–ด ๋ชจ๋ธ์˜ ํฌ๊ธฐ๊ฐ€ ๊ณ„์† ์ฆ๊ฐ€ํ•จ์— ๋”ฐ๋ผ ์ˆ˜์ฒœ์–ต ๊ฐœ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๊ฐ€ ์žˆ๋Š” ๋ชจ๋ธ์— DPO๋ฅผ ํšจ์œจ์ ์œผ๋กœ ์ ์šฉํ•˜๋Š” ๊ฒƒ์€ ์—ฌ์ „ํžˆ โ€‹โ€‹ํ•ด๊ฒฐํ•ด์•ผ ํ•  ๊ณผ์ œ๋กœ ๋‚จ์•„ ์žˆ์Šต๋‹ˆ๋‹ค. ์—ฐ๊ตฌ์›๋“ค์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ธฐ์ˆ ์„ ํƒ๊ตฌํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

  • Efficient fine-tuning methods (e.g., LoRA, prefix tuning)
  • Distributed training optimizations
  • Gradient checkpointing and mixed-precision training
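As a rough sketch of how mixed precision and gradient accumulation can reduce memory pressure during preference fine-tuning (the linear model and squared-error loss here are placeholders for a real policy and the DPO loss):

```python
import torch

model = torch.nn.Linear(16, 1)  # placeholder for the policy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
accum_steps = 4  # accumulate gradients over several micro-batches

optimizer.zero_grad()
for step in range(accum_steps):
    x = torch.randn(8, 16)
    # Autocast runs the forward pass in lower precision to save memory
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = model(x)
        loss = out.pow(2).mean() / accum_steps  # placeholder for the DPO loss
    loss.backward()
optimizer.step()
```

On GPU the same pattern uses device_type="cuda", usually together with a gradient scaler when training in float16.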

An example of using LoRA with DPO:

from peft import LoraConfig, get_peft_model

class DPOTrainerWithLoRA(DPOTrainer):
    def __init__(self, model, ref_model, beta=0.1, lr=1e-5, lora_rank=8):
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        # Wrap the base model with LoRA adapters, then reuse the parent setup
        super().__init__(get_peft_model(model, lora_config), ref_model, beta, lr)

# Usage
base_model = YourLargeLanguageModel()
dpo_trainer = DPOTrainerWithLoRA(base_model, ref_model)

b) Multi-task and few-shot adaptation:

์ œํ•œ๋œ ์„ ํ˜ธ๋„ ๋ฐ์ดํ„ฐ๋กœ ์ƒˆ๋กœ์šด ์ž‘์—…์ด๋‚˜ ์˜์—ญ์— ํšจ์œจ์ ์œผ๋กœ ์ ์‘ํ•  ์ˆ˜ ์žˆ๋Š” DPO ๊ธฐ์ˆ ์„ ๊ฐœ๋ฐœํ•˜๋Š” ๊ฒƒ์ด ํ™œ๋ฐœํ•œ ์—ฐ๊ตฌ ๋ถ„์•ผ์ž…๋‹ˆ๋‹ค. ํƒ์ƒ‰ ์ค‘์ธ ์ ‘๊ทผ ๋ฐฉ์‹์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • Meta-learning frameworks for rapid adaptation
  • Prompt-based fine-tuning for DPO
  • Transfer learning from general preference models to specific domains

c) Handling ambiguous or conflicting preferences:

์‹ค์ œ ์„ ํ˜ธ๋„ ๋ฐ์ดํ„ฐ์—๋Š” ๋ชจํ˜ธํ•จ์ด๋‚˜ ์ถฉ๋Œ์ด ํฌํ•จ๋˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ DPO์˜ ๊ฒฌ๊ณ ์„ฑ์„ ํ–ฅ์ƒ์‹œํ‚ค๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. ์ž ์žฌ์ ์ธ ์†”๋ฃจ์…˜์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • Probabilistic preference modeling
  • Active learning to resolve ambiguities
  • Multi-agent preference aggregation

An example of probabilistic preference modeling:

class ProbabilisticDPOTrainer(DPOTrainer):
    def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs, preference_prob):
        # Compute log-ratios for both the policy and the reference model
        pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
        ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]
        
        log_ratio_diff = (pi_yw_logps - pi_yl_logps) - (ref_yw_logps - ref_yl_logps)
        
        # Weight both preference directions by the annotator confidence
        loss = -(preference_prob * F.logsigmoid(self.beta * log_ratio_diff) +
                 (1 - preference_prob) * F.logsigmoid(-self.beta * log_ratio_diff))
        return loss.mean()

# Usage
trainer = ProbabilisticDPOTrainer(model, ref_model)
loss = trainer.compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, preference_prob=0.8)  # 80% confidence in preference

d) Combining DPO with other alignment techniques:

Integrating DPO with other alignment approaches could lead to more robust and capable systems:

  • Constitutional AI principles for explicit constraint satisfaction
  • Debate and recursive reward modeling for eliciting complex preferences
  • Inverse reinforcement learning for inferring underlying reward functions

An example combining DPO with constitutional AI:

class ConstitutionalDPOTrainer(DPOTrainer):
    def __init__(self, model, ref_model, beta=0.1, lr=1e-5, constraints=None):
        super().__init__(model, ref_model, beta, lr)
        self.constraints = constraints or []

    def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs):
        # The parent returns (mean loss, rewards); keep only the loss term here
        base_loss, _ = super().compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs)
        
        constraint_loss = 0
        for constraint in self.constraints:
            constraint_loss += constraint(self.model, pi_logps, ref_logps, yw_idxs, yl_idxs)
        
        return base_loss + constraint_loss

# Usage
def safety_constraint(model, pi_logps, ref_logps, yw_idxs, yl_idxs):
    # Implement safety checking logic; compute_unsafe_score is a placeholder
    # for a domain-specific scoring function
    unsafe_score = compute_unsafe_score(model, pi_logps, ref_logps)
    return torch.relu(unsafe_score - 0.5)  # Penalize if unsafe score > 0.5

constraints = [safety_constraint]
trainer = ConstitutionalDPOTrainer(model, ref_model, constraints=constraints)

Practical Considerations and Best Practices

์‹ค์ œ ์‘์šฉ ํ”„๋กœ๊ทธ๋žจ์— ๋Œ€ํ•ด DPO๋ฅผ ๊ตฌํ˜„ํ•  ๋•Œ ๋‹ค์Œ ํŒ์„ ๊ณ ๋ คํ•˜์‹ญ์‹œ์˜ค.

a) ๋ฐ์ดํ„ฐ ํ’ˆ์งˆ: ์„ ํ˜ธ๋„ ๋ฐ์ดํ„ฐ์˜ ํ’ˆ์งˆ์ด ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ์„ธํŠธ๊ฐ€ ๋‹ค์Œ๊ณผ ๊ฐ™์€์ง€ ํ™•์ธํ•˜์„ธ์š”.

  • Covers a diverse range of inputs and desired behaviors
  • Has consistent and reliable preference annotations
  • Balances different types of preferences (e.g., factuality, safety, style)

b) ํ•˜์ดํผ ํŒŒ๋ผ๋ฏธํ„ฐ ํŠœ๋‹: DPO๋Š” RLHF๋ณด๋‹ค ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๊ฐ€ ์ ์ง€๋งŒ ํŠœ๋‹์€ ์—ฌ์ „ํžˆ โ€‹โ€‹์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค.

  • β (beta): Controls the trade-off between satisfying preferences and staying close to the reference model. Start with values around 0.1 to 0.5.
  • Learning rate: Use a lower learning rate than for standard fine-tuning, typically in the range 1e-6 to 1e-5.
  • Batch size: Larger batch sizes (32 to 128) often work well for preference learning.

c) Iterative refinement: DPO can be applied iteratively:

  1. Train an initial model with DPO
  2. Use the trained model to generate new responses
  3. Collect new preference data on these responses
  4. Retrain using the expanded dataset

 

์ง์ ‘ ์„ ํ˜ธ๋„ ์ตœ์ ํ™”

์ง์ ‘ ์„ ํ˜ธ ์ตœ์ ํ™” ์„ฑ๋Šฅ

์ด ์ด๋ฏธ์ง€๋Š” Direct Preference Optimization(DPO), Supervised Fine-Tuning(SFT), Proximal Policy Optimization(PPO)์„ ํฌํ•จํ•œ ๋‹ค์–‘ํ•œ ํ›ˆ๋ จ ๊ธฐ๋ฒ•์—์„œ ์ธ๊ฐ„์˜ ํŒ๋‹จ๊ณผ ๋น„๊ตํ•œ GPT-4์™€ ๊ฐ™์€ LLM์˜ ์„ฑ๋Šฅ์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. ์ด ํ‘œ๋Š” GPT-4์˜ ์ถœ๋ ฅ์ด ํŠนํžˆ ์š”์•ฝ ์ž‘์—…์—์„œ ์ธ๊ฐ„์˜ ์„ ํ˜ธ๋„์™€ ์ ์  ๋” ์ผ์น˜ํ•˜๊ณ  ์žˆ์Œ์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. GPT-4์™€ ์ธ๊ฐ„ ๊ฒ€ํ† ์ž ๊ฐ„์˜ ํ•ฉ์˜ ์ˆ˜์ค€์€ ์ด ๋ชจ๋ธ์ด ์ธ๊ฐ„์ด ์ƒ์„ฑํ•œ ์ฝ˜ํ…์ธ ์™€ ๊ฑฐ์˜ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ ์ธ๊ฐ„ ํ‰๊ฐ€์ž์—๊ฒŒ ๊ณต๊ฐ์„ ์–ป๋Š” ์ฝ˜ํ…์ธ ๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋Š” ๋Šฅ๋ ฅ์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค.

Case Studies and Applications

DPO์˜ ํšจ์œจ์„ฑ์„ ์„ค๋ช…ํ•˜๊ธฐ ์œ„ํ•ด ์‹ค์ œ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜๊ณผ ๊ทธ ๋ณ€ํ˜•์„ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

  • Iterative DPO: Developed by Snorkel (2023), this variant combines rejection sampling with DPO, enabling a more refined selection process for training data. By iterating over multiple rounds of preference sampling, the model generalizes better and avoids overfitting to noisy or biased preferences.
  • IPO (Identity Preference Optimization): Introduced by Azar et al. (2023), IPO adds a regularization term to prevent overfitting, a common problem in preference-based optimization. This extension allows models to maintain a balance between adhering to preferences and preserving generalization.
  • KTO (Kahneman-Tversky Optimization): A recent variant from Ethayarajh et al. (2023), KTO dispenses with paired binary preferences altogether, instead learning from individual examples labeled as desirable or undesirable and optimizing for smoother, more consistent alignment with human values.
  • Multimodal DPO for cross-domain learning (Xu et al., 2024): An approach that applies DPO across modalities such as text, images, and audio, demonstrating its versatility in aligning models with human preferences across diverse data types. This research highlights DPO's potential for building more comprehensive AI systems capable of handling complex, multimodal tasks.

Conclusion

์ง์ ‘ ์„ ํ˜ธ๋„ ์ตœ์ ํ™”(Direct Preference Optimization)๋Š” ์–ธ์–ด ๋ชจ๋ธ์„ ์ธ๊ฐ„ ์„ ํ˜ธ๋„์— ๋งž์ถ”๋Š” ๋ฐ ์žˆ์–ด ์ƒ๋‹นํ•œ ๋ฐœ์ „์„ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. ๋‹จ์ˆœ์„ฑ, ํšจ์œจ์„ฑ ๋ฐ ํšจ์œจ์„ฑ์œผ๋กœ ์ธํ•ด ์—ฐ๊ตฌ์›๊ณผ ์‹ค๋ฌด์ž ๋ชจ๋‘์—๊ฒŒ ๊ฐ•๋ ฅํ•œ ๋„๊ตฌ๊ฐ€ ๋ฉ๋‹ˆ๋‹ค.

By leveraging the power of Direct Preference Optimization and keeping these principles in mind, you can create language models that not only demonstrate impressive capabilities but also align closely with human values and intentions.

I have spent the past several years immersed in the fascinating world of Machine Learning and Deep Learning. My passion and expertise have led me to contribute to numerous software engineering projects, with a particular focus on AI/ML. My ongoing curiosity has also drawn me toward Natural Language Processing, a field I am eager to explore further.