๐Ÿ“– ใ€ŽJAX/Flax๋กœ ๋”ฅ๋Ÿฌ๋‹ ๋ ˆ๋ฒจ์—…ใ€(์ œ์ดํŽ) ๋ฆฌ๋ทฐ

2024-10-13

IMG_4524


IT ์‹ค์šฉ์„œ ์ „๋ฌธ ์ถœํŒ์‚ฌ ์ œ์ดํŽ์œผ๋กœ๋ถ€ํ„ฐ ์ฑ… ใ€ŽJAX/Flax๋กœ ๋”ฅ๋Ÿฌ๋‹ ๋ ˆ๋ฒจ์—…ใ€๋ฅผ ๋ฌด๋ฃŒ๋กœ ์ œ๊ณต ๋ฐ›์•˜๋‹ค.

๊ฐœ์š”

๋„์„œ๋ช…
JAX/Flax๋กœ ๋”ฅ๋Ÿฌ๋‹ ๋ ˆ๋ฒจ์—…
์ง€์€์ด
์ด์˜๋นˆ , ์œ ํ˜„์•„ , ๊น€ํ•œ๋นˆ , ์กฐ์˜๋นˆ , ์ดํƒœํ˜ธ , ์žฅ์ง„์šฐ , ๋ฐ•์ •ํ˜„ , ๊น€ํ˜•์„ญ , ์ด์Šนํ˜„
๋ฐœํ–‰์‚ฌ
์ œ์ดํŽ
์ดˆํŒ ๋ฐœํ–‰
2024๋…„ 9์›” 23์ผ
์ •๊ฐ€
24,000์›

๋ฒ ํƒ€๋ฆฌ๋” ํ›„๊ธฐ์— ๋”ฐ๋ฅด๋ฉด ใ€ŽJAX/Flax๋กœ ๋”ฅ๋Ÿฌ๋‹ ๋ ˆ๋ฒจ์—…ใ€์€ ๋ฌด๋ ค ๊ตญ๋‚ด ์ตœ์ดˆ JAX ์ž…๋ฌธ์„œ๋ผ๊ณ  ํ•œ๋‹ค. ์ตœ๊ทผ ์„ฑ์žฅํ•˜๊ณ  ์žˆ๋Š” JAX ์ƒํƒœ๊ณ„์˜ ํ™œ์„ฑํ™”์— ๊ธฐ์—ฌํ•˜๋Š” ์˜๋ฏธ๊ฐ€ ์žˆ๊ฒ ๋‹ค.

๊ฐœ์ธ์ ์œผ๋กœ JAX๋Š” โ€˜๊ณ ์„ฑ๋Šฅ ๋”ฅ๋Ÿฌ๋‹ ์—ฐ์‚ฐ์ด ๊ฐ€๋Šฅํ•œ numpyโ€™ ์ •๋„๋กœ ์•Œ๊ณ  ์žˆ๋Š” ์ƒํƒœ์˜€๊ณ  ์ง์ ‘ ํ™œ์šฉํ•ด๋ณธ ์ ์€ ์—†์—ˆ๋‹ค. ์ด๋ฒˆ ์ฑ… ๋ฆฌ๋ทฐ๋ฅผ ๊ณ„๊ธฐ๋กœ Numpy์™€์˜ ์ฐจ์ด์ ๊ณผ JAX/Flax์˜ ์ฃผ์š” ํŠน์ง•์„ ์ดํ•ดํ•˜๊ณ  ์‹ค์ œ ํŠœํ† ๋ฆฌ์–ผ์„ ๋”ฐ๋ผ๊ฐ€๋Š” ๊ฒฝํ—˜์„ ์Œ“์œผ๋ ค ํ•œ๋‹ค.

์ฑ…์€ ํฌ๊ฒŒ JAX/Flax๋ฅผ ์†Œ๊ฐœํ•˜๋Š” ๋ถ€๋ถ„๊ณผ JAX/Flax๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ๊ตฌํ˜„ํ•˜๋Š” ๋ถ€๋ถ„์œผ๋กœ ์ด๋ค„์ ธ ์žˆ๋‹ค. ํŒŒ์ด์ฌ ํ”„๋กœ๊ทธ๋ž˜๋ฐ๊ณผ ๊ธฐ๋ณธ์ ์ธ ๋จธ์‹ ๋Ÿฌ๋‹ ๊ฐœ๋…์€ ์ฑ…์„ ์ฝ๊ธฐ ์œ„ํ•œ ์„ ์ˆ˜ ์ง€์‹์œผ๋กœ ์š”๊ตฌ๋œ๋‹ค.

Jax๋ž€?

ํ•œ๋งˆ๋””๋กœ ํ‘œํ˜„ํ•˜๋ฉด ์ž๋™ ๋ฏธ๋ถ„๊ณผ XLA๋ฅผ ๊ฒฐํ•ฉํ•ด์„œ ์‚ฌ์šฉํ•˜๋Š” ๊ณ ์„ฑ๋Šฅ ๋จธ์‹ ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ์ž…๋‹ˆ๋‹ค โ€ฆ JAX์˜ ๊ฐ€์žฅ ํฐ ๊ฐ•์ ์€ XLA๋ฅผ ์ ์šฉํ•ด์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์ž…๋‹ˆ๋‹ค.

PyTorch, Tensorflow ๋“ฑ ๋‹ค๋ฅธ ๋”ฅ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ๋„ ์ž๋™ ๋ฏธ๋ถ„์„ ์ง€์›ํ•˜์ง€๋งŒ, JAX๋Š” ์ด์— ๋”ํ•ด XLA(Accelerated Linear Algebra)์ด ๊ฐ€๋Šฅํ•˜๋‹ค๋Š” ๊ฒƒ์ด ํ•ต์‹ฌ์ด๋‹ค. XLA๋Š” GPU/TPU ์œ„์—์„œ numpy๋ฅผ ์ปดํŒŒ์ผํ•˜๊ณ  ์‹คํ–‰ํ•˜๋Š” ์ปดํŒŒ์ผ๋Ÿฌ๋‹ค. JAX๋Š” JIT(Just-In-Time) ์ปดํŒŒ์ผ์„ ํ†ตํ•ด ํŒŒ์ด์ฌ ์ฝ”๋“œ๋ฅผ XLA์— ์ตœ์ ํ™”๋œ ๊ธฐ๊ณ„์–ด๋กœ ๋ณ€ํ™˜ํ•˜๊ธฐ ๋•Œ๋ฌธ์— PyTorch์˜ ๋™์  ๊ทธ๋ž˜ํ”„๋ณด๋‹ค๋„ ๋น ๋ฅด๊ณ  ํšจ์œจ์ ์œผ๋กœ ์—ฐ์‚ฐํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค.

ํŒŒ์ด์ฌ์€ ๊ธฐ๋ณธ์ ์œผ๋กœ ์ธํ„ฐํ”„๋ฆฌํ„ฐ ๋ฐฉ์‹์œผ๋กœ ์‹คํ–‰๋˜๊ธฐ ๋•Œ๋ฌธ์— ์ฝ”๋“œ๋ฅผ ํ•œ ์ค„์”ฉ ์ฝ๊ณ  ํ•ด์„ํ•˜๋Š” ๋ฐ ์‹œ๊ฐ„์ด ์†Œ์š”๋œ๋‹ค. ์—ฌ๊ธฐ์„œ JIT ์ปดํŒŒ์ผ์„ ์‚ฌ์šฉํ•œ๋‹ค๋ฉด ์ฝ”๋“œ๋ฅผ ์‹คํ–‰ํ•˜๋Š” ์‹œ์ ์— ์„ฑ๋Šฅ๊ณผ ์—ฐ๊ด€๋˜๋Š” ์ผ๋ถ€๋ถ„์„ ๋ฏธ๋ฆฌ ๊ธฐ๊ณ„์–ด๋กœ ์ปดํŒŒ์ผํ•˜์—ฌ ์†๋„๊ฐ€ ๋นจ๋ผ์ง„๋‹ค๊ณ  ์ดํ•ดํ–ˆ๋‹ค.

Flax

JAX + Flexibility๋ฅผ ํ•ฉ์ณ์ ธ์„œ ๋งŒ๋“ค์—ˆ์œผ๋ฉฐ ์—”์ง€๋‹ˆ์–ด๋“ค์ด JAX๋ฅผ ์กฐ๊ธˆ ๋” ์‰ฝ๊ฒŒ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ํ”„๋ ˆ์ž„์›Œํฌ์ด๋ฉฐ, ๋‹ค๋ฅธ ๋”ฅ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ๋“ค์ฒ˜๋Ÿผ ๋ ˆ์ด์–ด(์ธต) ๊ฐœ๋…์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.

์—ฌ๊ธฐ๊นŒ์ง€ ์ฝ์—ˆ์„ ๋•Œ Tensorflow & Keras ์™€ ์œ ์‚ฌํ•œ ๊ฐœ๋…(๊ด€๊ณ„)์ด ์•„๋‹Œ๊ฐ€ ์‹ถ์—ˆ๋Š”๋ฐ, JAX/Flax๋Š” Low-level์˜ ์„ฌ์„ธํ•œ ์ปจํŠธ๋กค์ด ๊ฐ€๋Šฅํ•˜๋‹ค๋Š” ์ ์— ๋ฐฉ์ ์ด ์ฐํ˜€ ์žˆ๋Š” ๊ฒƒ ๊ฐ™๋‹ค. ๊ทธ์™€ ๋‹ฌ๋ฆฌ Keras๋Š” ๋†’์€ ์ˆ˜์ค€์˜ ์ถ”์ƒํ™”๊ฐ€ ์ด๋ฃจ์–ด์ ธ ์žˆ๊ณ  ์‚ฌ์šฉ์ž ์นœํ™”์ ์ด๋‹ค. ๋˜‘๊ฐ™์ด ๊ตฌ๊ธ€์—์„œ ๊ฐœ๋ฐœํ•œ ํ”„๋ ˆ์ž„์›Œํฌ์ง€๋งŒ ์ง€ํ–ฅํ•˜๋Š” ์ฒ ํ•™์ด ๋‹ค๋ฅด๋‹ค๋Š” ์ ์ด ์žฌ๋ฐŒ๋‹ค.

์ฑ…์— ๋”ฐ๋ฅด๋ฉด ๊ตฌ๊ธ€์—์„œ ๊ฐœ๋ฐœํ•œ ๋ชจ๋ธ๋“ค์€ ๋Œ€๋ถ€๋ถ„ JAX๋กœ ์ž‘์„ฑ๋˜์–ด ์žˆ๊ณ , ์‹ฌ์ง€์–ด Hugging Face์˜ ๊ธฐ์กด ๋ชจ๋ธ๋“ค๋„ JAX๋กœ ๋ณ€ํ™˜ํ•˜๊ณ  ์žˆ๋‹ค๊ณ  ํ•œ๋‹ค.

ํ•จ์ˆ˜ํ˜• ํ”„๋กœ๊ทธ๋ž˜๋ฐ

JAX/Flax์˜ ํ™œ์šฉ ๋ฐฉ์‹์„ ๋” ์ž˜ ํŒŒ์•…ํ•  ์ˆ˜ ์žˆ๋„๋ก ์ฑ…์€ ํ•จ์ˆ˜ํ˜• ํ”„๋กœ๊ทธ๋ž˜๋ฐ์— ๋Œ€ํ•ด์„œ ๋ณ„๋„ ์ฑ•ํ„ฐ๋กœ ์„ค๋ช…ํ•œ๋‹ค. ๋ช…๋ น์–ด์˜ ํ๋ฆ„(์ˆœ์„œ)๋Œ€๋กœ ์ƒํƒœ๋ฅผ ๋ณ€๊ฒฝํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ์ „๋‹ฌํ•˜๋Š” ๊ฒƒ์ด ํ•ต์‹ฌ์ธ ์ ˆ์ฐจ์  ํ”„๋กœ๊ทธ๋ž˜๋ฐ๊ณผ ๋‹ค๋ฅด๊ฒŒ, ํ•จ์ˆ˜ํ˜• ํ”„๋กœ๊ทธ๋ž˜๋ฐ์€ ์™ธ๋ถ€ ์ƒํƒœ์™€ ์ƒ๊ด€์—†์ด ์ฃผ์–ด์ง„ ์ž…๋ ฅ์— ๋™์ผํ•œ ์ถœ๋ ฅ๊ฐ’์„ ๋‚ด๋†“๋Š” ์ˆœ์ˆ˜ ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. ๋”ฐ๋ผ์„œ ๋ถ€์ˆ˜ ํšจ๊ณผ๊ฐ€ ์ œ๊ฑฐ๋˜๋ฉฐ ์ƒํƒœ๊ฐ€ ๋ณ€ํ™”ํ•˜์ง€ ์•Š๋Š” ๋ถˆ๋ณ€์„ฑ์„ ๊ฐ•์กฐํ•œ๋‹ค. ์—ฌ๊ธฐ์„œ ์ ˆ์ฐจ์  ํ”„๋กœ๊ทธ๋ž˜๋ฐ๊ณผ ํ•จ์ˆ˜ํ˜• ํ”„๋กœ๊ทธ๋ž˜๋ฐ์„ ์„ค๋ช…ํ•  ๋•Œ ๊ฐ„๋‹จํ•œ ํŒŒ์ด์ฌ ์˜ˆ์ œ๊ฐ€ ์ฒจ๋ถ€๋˜์–ด ์žˆ์–ด์„œ ์ดํ•ด๊ฐ€ ํŽธํ–ˆ๋‹ค.

JAX, ๋‚˜์•„๊ฐ€ ๋”ฅ๋Ÿฌ๋‹ ์—ฐ์‚ฐ์— ์žˆ์–ด์„œ ์ด ๊ฐœ๋…์„ ์ดํ•ดํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•œ ์ด์œ ๋ฅผ ์„ธ ๊ฐ€์ง€๋กœ ์ œ์‹œํ•˜๊ณ  ์žˆ๋‹ค.

  1. XLA ์ปดํŒŒ์ผ์— ์ตœ์ ํ™”๋œ ์ฒ˜๋ฆฌ๊ฐ€ ๊ฐ€๋Šฅํ•ด์ง„๋‹ค
  2. ๋ณ‘๋ ฌ์ฒ˜๋ฆฌ์™€ ๋ถ„์‚ฐ์ฒ˜๋ฆฌ์— ์œ ์šฉํ•˜๋‹ค
  3. ์ฝ”๋“œ๋ฅผ ๋ชจ๋“ˆํ™”ํ•จ์œผ๋กœ์จ ์žฌ์‚ฌ์šฉ์„ฑ์ด ๋†’์•„์ง„๋‹ค

JAX ๊ธฐ๋ณธ

๋ฐฑ๋ฌธ์ด ๋ถˆ์—ฌ์ผ๊ฒฌ, ์ง์ ‘ JAX ๋ฅผ ํ™œ์šฉํ•ด๋ณด๋ฉฐ ์ฑ…์˜ ๋‚ด์šฉ์„ ๋”ฐ๋ผ๊ฐ€๋ณด๊ฒ ๋‹ค.

์„ค์น˜

๋‹คํ–‰ํžˆ๋„ JAX๊ฐ€ Mac M1์„ ๊ณต์‹ ์ง€์›ํ•œ๋‹ค๊ณ  ํ•˜์—ฌ conda๋กœ ์‰ฝ๊ฒŒ ์„ค์น˜ํ•  ์ˆ˜ ์žˆ์—ˆ๋‹ค.

conda create -n jax-env python=3.9
conda activate jax-env
pip install jax jaxlib

import ํ•˜๊ธฐ

import jax
import jax.numpy as jnp

numpy์™€ ๋น„๊ต

x1 = jnp.array([1.0, 2.0, 3.0])
x2 = jnp.array([4.0, 5.0, 6.0])
y = x1 + x2
print(y)        # [5. 7. 9.]
print(type(y))  # <class 'jaxlib.xla_extension.ArrayImpl'>

์œ„ ์˜ˆ์ œ์—์„œ ๋ณด๋“ฏ jax.numpy๋Š” ๊ธฐ์กด numpy ์™€ ๊ฑฐ์˜ ๋˜‘๊ฐ™์€ API๋ฅผ ์ œ๊ณตํ•˜๊ณ  ์žˆ๋‹ค.

grad ํ•จ์ˆ˜

def func(x):
    return x**2

grad = jax.grad(func)
print(grad(3.))    # Array(6., dtype=float32, weak_type=True)

JAX์—์„œ ๋ฏธ๋ถ„, ์ฆ‰ gradient๋ฅผ ๊ณ„์‚ฐํ•ด์ฃผ๋Š” grad๋ฅผ ์‚ฌ์šฉํ•œ ์˜ˆ์ œ๋‹ค.

๋ถ€์ˆ˜ ํšจ๊ณผ์˜ ๋ฐฉ์ง€

JAX๋Š” ๋ถ€์ˆ˜ ํšจ๊ณผ๋ฅผ ์ œ๊ฑฐํ•˜๋Š” ํ•จ์ˆ˜ํ˜• ํ”„๋กœ๊ทธ๋ž˜๋ฐ์˜ ์ œ์•ฝ์„ ๋”ฐ๋ฅด๊ณ  ์žˆ๋‹ค. ์ฑ…์—์„œ ์ œ๊ณตํ•ด์ค€ ์•„๋ž˜ ์˜ˆ์ œ๋ฅผ ์ฐธ๊ณ ํ•ด๋ณด์ž.

x_1 = np.array([1, 2, 3])
x_1[0] = 999
print(x_1)      # [999   2   3]

numpy๋กœ ์ƒ์„ฑํ•œ ๋ฐฐ์—ด์€ ์ง์ ‘ ์ ‘๊ทผํ•ด์„œ ์š”์†Œ๋ฅผ ๋ณ€๊ฒฝํ•  ์ˆ˜ ์žˆ๋‹ค.

x_2 = jnp.array([1, 2, 3])
x_2[0] = 999
# TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable.

๋ฐ˜๋ฉด jax.numpy๋กœ ์ƒ์„ฑํ•œ ๋ฐฐ์—ด์€ ์ง์ ‘์ ์ธ ์ˆ˜์ •์„ ํ—ˆ์šฉํ•˜์ง€ ์•Š๋Š”๋‹ค. ์ด๋Š” โ€˜์™ธ๋ถ€ ๋ฐ์ดํ„ฐโ€™์ธ ๋ฐฐ์—ด์˜ ์ƒํƒœ๊ฐ€ ๋ณ€ํ˜•๋˜๋ฉด์„œ ๋ถ€์ˆ˜ ํšจ๊ณผ๊ฐ€ ๋ฐœ์ƒํ•˜๋Š” ๊ฒƒ์„ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•จ์ด๋‹ค.
๋งŒ์•ฝ ๋ฐฐ์—ด์˜ ์ผ๋ถ€๋ฅผ ์ˆ˜์ •ํ•˜๋Š” ์ž‘์—…์„ ์ง„ํ–‰ํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ๋ถ€์ˆ˜ ํšจ๊ณผ๊ฐ€ ์—†๋Š” ์ˆœ์ˆ˜ ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•œ๋‹ค.

x_2 = jnp.array([1, 2, 3])
def modify(x):
    return x.at[0].set(999)
y = modify(x_2)
print(y)    # Array([999,   2,   3], dtype=int32)

์—ฌ๊ธฐ์„œ modify(x)๋Š” ๋ถ€์ˆ˜ ํšจ๊ณผ๊ฐ€ ์—†๋Š” ์ˆœ์ˆ˜ ํ•จ์ˆ˜๋ผ๊ณ  ๋ณผ ์ˆ˜ ์žˆ๋Š” ๊ฒƒ์ด๋ฉฐ, ์ฑ…์€ jax.grad์™€ jax.jit ๊ฐ™์€ ํ•จ์ˆ˜๋Š” ์ˆœ์ˆ˜ ํ•จ์ˆ˜๋กœ ์ž‘์„ฑ๋˜์–ด์•ผ ํ•œ๋‹ค๊ณ  ์„ค๋ช…ํ•˜๊ณ  ์žˆ๋‹ค.

JIT ์ปดํŒŒ์ผ

๋ณ€ํ™˜
์ฃผ์–ด์ง„ ํ•จ์ˆ˜๋ฅผ ๋ณ€๊ฒฝํ•˜๊ฑฐ๋‚˜ ์ˆ˜์ •ํ•˜๋Š” ๋ฐฉ์‹. ์„ฑ๋Šฅ ์ตœ์ ํ™”๋‚˜ ์ž๋™ ๋ฏธ๋ถ„์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•จ.

์ฑ…์€ JAX์—์„œ ๋ณ€ํ™˜(transformation)์ด๋ผ๋Š” ํ‚ค์›Œ๋“œ๊ฐ€ ์ค‘์š”ํ•˜๋‹ค๊ณ  ๋งํ•œ๋‹ค. JAX์—์„œ ๋ณ€ํ™˜์€ jaxpr, ์ฆ‰ JAX ํ‘œํ˜„์‹์ด๋ผ๋Š” ์ค‘๊ฐ„ ์–ธ์–ด(intermediate language)๋ฅผ ํ†ตํ•ด ์ด๋ฃจ์–ด์ง„๋‹ค. jax.jit๊ฐ€ ๋Œ€ํ‘œ์ ์ธ jax ๋ณ€ํ™˜์ด๋ผ๊ณ  ์†Œ๊ฐœ๋œ๋‹ค.

def selu(x, alpha=1.67, lamdba_=1.05):
    return lamdba_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)

# ์ผ๋ฐ˜
selu(x).block_until_ready()

# XLA
selu_jit = jax.jit(selu)
selu_jit(x).block_until_ready() # ๋น„๋™๊ธฐ ์‹คํ–‰

์œ„ ๋‚ด์šฉ์€ ํ™œ์„ฑํ™” ํ•จ์ˆ˜ SELU(Scaled Exponential Linear Unit)๋ฅผ ๊ตฌํ˜„ํ•˜๊ณ  ํ˜ธ์ถœํ•œ ๋‚ด์šฉ์ด๋‹ค.

image

selu(x) ๋Œ€์‹ ์— jit ๋ณ€ํ™˜์„ ์ ์šฉํ•œ selu_jit(x)๊ฐ€ 7๋ฐฐ ๋น ๋ฅด๋‹ค๊ณ  ์„ค๋ช…ํ•˜๊ณ  ์žˆ๋‹ค. (๊ตฌ๊ธ€ Colab T4 ๊ธฐ์ค€)

์ฑ…์„ ์ฝ์œผ๋ฉด์„œ ํฅ๋ฏธ๋กœ์› ๋˜ ๋ถ€๋ถ„์€ jax.jit์€ ์ปดํŒŒ์ผ๋œ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋ฅผ ์บ์‹ฑํ•˜์—ฌ ์žฌ์‚ฌ์šฉํ•œ๋‹ค๋Š” ์ ์ด์—ˆ๋‹ค. ๋‹ค๋งŒ jax.jit์„ ๋ฐ˜๋ณต๋ฌธ ๋‚ด๋ถ€์—์„œ ํ˜ธ์ถœํ•  ๊ฒฝ์šฐ ์ปดํŒŒ์ผ ๊ณผ์ •์ด ๋ถˆํ•„์š”ํ•˜๊ฒŒ ๋ฐ˜๋ณต๋  ์ˆ˜ ์žˆ์œผ๋‹ˆ ์ง€์–‘ํ•˜๋ผ๊ณ  ์•ˆ๋‚ดํ•˜๊ณ  ์žˆ๋‹ค.

Flax

๋งˆ์ง€๋ง‰์œผ๋กœ Flax ๋ฅผ ํ™œ์šฉํ•œ ์˜ˆ์ œ๋ฅผ ์‚ดํŽด๋ณด๊ฒ ๋‹ค.

import flax.linen as nn
from jax import random

key = random.PRNGKey(42)

class MLP(nn.Module):
    out_dims: int
    
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(self.out_dims)(x)
        return x

model = MLP(out_dims=10)
x = jnp.empty((4, 28, 28, 1))

weights = model.init(key, x)
y = model.apply(weights, x)

์ฑ…์— import ํ•˜๋Š” ๋ถ€๋ถ„์€ ์—†์–ด์„œ ์ถ”๊ฐ€ํ–ˆ๋‹ค
nn.Module ์—์„œ ์ƒ์†๋ฐ›์•„ ๋ชจ๋ธ์„ ์ƒ์„ฑํ•œ๋‹ค๋Š” ์ ์—์„œ PyTorch ์™€ ์œ ์‚ฌํ•œ ๋ฐฉ์‹์˜ API ๋ผ๊ณ  ๋Š๊ปด์กŒ๊ณ  ๊ธˆ๋ฐฉ ์ ์‘ํ•  ์ˆ˜ ์žˆ๊ฒ ๋‹ค๋Š” ์ƒ๊ฐ์ด ๋“ ๋‹ค.


๋‚˜๊ฐ€๋ฉฐ

ใ€ŽJAX/Flax๋กœ ๋”ฅ๋Ÿฌ๋‹ ๋ ˆ๋ฒจ์—…ใ€์—์„œ JAX ํ•ต์‹ฌ ๊ฐœ๋…์„ ์œ„์ฃผ๋กœ ์‚ดํŽด๋ณด๋ฉฐ ์ฑ…์„ ๋ฆฌ๋ทฐํ•ด๋ณด์•˜๋‹ค. ์š”์ฆ˜ ์‹œ์ ์—์„œ ์™œ ๊ตณ์ด ์ƒˆ๋กœ์šด ๋”ฅ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ๊ฐ€ ํ•„์š”ํ• ๊นŒ? ๋ผ๊ณ  ๋ง‰์—ฐํžˆ ๊ถ๊ธˆํ•ด ํ•˜๋ฉฐ ๋ฆฌ๋ทฐ์–ด ์‹ ์ฒญ์„ ํ–ˆ๋Š”๋ฐ, ์ข‹์€ ๊ธฐํšŒ๋กœ ์ฑ…๋„ ์ œ๊ณต ๋ฐ›๊ณ  JAX์™€ Flax์— ๋Œ€ํ•ด ๊ฐ€๋ณ๊ฒŒ ๋ฐฐ์›Œ๋ณผ ์ˆ˜ ์žˆ๋Š” ์‹œ๊ฐ„์ด์—ˆ๋‹ค.

JAX๊ฐ€ ์ง€ํ–ฅํ•˜๋Š” ์ฒ ํ•™๊ณผ ํ•จ๊ป˜ ๊ทธ๊ฒƒ์ด ๋…น์•„๋“  ํ•ต์‹ฌ ๊ธฐ๋Šฅ์„ ์„ธ์„ธํ•˜๊ฒŒ ์„ค๋ช…ํ•ด์ฃผ๊ธฐ ๋•Œ๋ฌธ์— JAX ์ž…๋ฌธ์„œ๋กœ ์•„์ฃผ ์•Œ๋งž์€ ๋„์„œ๋ผ๋Š” ์ƒ๊ฐ์ด ๋“ค์—ˆ๋‹ค. ํŠนํžˆ ํ•จ์ˆ˜ํ˜• ํ”„๋กœ๊ทธ๋ž˜๋ฐ ๊ฐœ๋…๋งŒ์„ ์„ค๋ช…ํ•˜๊ธฐ ์œ„ํ•ด ๋ณ„๋„ ์ง€๋ฉด์„ ํ• ์• ํ–ˆ๋‹ค๋Š” ์ ์—์„œ๋Š” JAX์˜ ์˜๋ฏธ๋ฅผ ์ œ๋Œ€๋กœ ์ „๋‹ฌํ•˜๊ฒ ๋‹ค๋Š” ๊ฐ•ํ•œ ์˜์ง€๋„ ๋ณด์˜€๋‹ค.

์•„๋งˆ JAX๋กœ ์ž…๋ฌธํ•˜๊ธฐ๊นŒ์ง€ ๊ฐ€์žฅ ํฐ ์žฅ๋ฒฝ์€ ์•ž์„œ ๋‚ด๊ฐ€ ๋– ์˜ฌ๋ฆฐ ๊ฒƒ๊ณผ ๊ฐ™์ด โ€œ์™œ ๊ผญ ์ด๊ฒƒ์ด์–ด์•ผ ํ•˜๋Š”๊ฐ€?โ€ ๋ผ๋Š” ์˜๋ฌธ์ผ ํ…๋ฐ, ์ด ์ฑ…์„ ์ฝ๋Š”๋‹ค๋ฉด ๊ทธ ์žฅ๋ฒฝ ์ •๋„๋Š” ์ถฉ๋ถ„ํžˆ ๋„˜์„ ์ˆ˜ ์žˆ๊ฒ ๋‹ค. numpy ํ•˜๋‚˜๋กœ ๋ชจ๋ธ์„ ๊ตฌํ˜„ํ•˜๋Š” ์ •๋„๋กœ low-level์—์„œ ๋ชจ๋ธ๋ง๊ณผ ํ•™์Šต ๊ณผ์ • ๋“ฑ์„ ์œ ์—ฐํ•˜๊ฒŒ ํ†ต์ œํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์ด JAX/Flax์˜ ๊ฐ€์žฅ ๊ฐ•๋ ฅํ•œ ์ •์ฒด์„ฑ์ด๋ผ๊ณ  ๋Š๊ผˆ๋‹ค. ๊ตฌ๊ธ€ TPU๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ML์—”์ง€๋‹ˆ์–ด๋ผ๋ฉด ์‹œ๊ฐ„์„ ๋“ค์—ฌ์„œ๋ผ๋„ JAX๋ฅผ ์ ์šฉํ•  ๊ฐ€์น˜๊ฐ€ ์žˆ์„ ๋“ฏํ•˜๋‹ค.

๊ทธ ์™ธ์— CLIP, GPT ๊ฐ™์€ ์ตœ์‹  ๋ชจ๋ธ์˜ fine-tuning ์„ ์˜ˆ์ œ๋กœ ๋‹ค๋ฃฌ ์ ๋„ ์ธ์ƒ์ ์ด์—ˆ๋‹ค. ML ๋ถ„์•ผ์— ์ž…๋ฌธํ•œ ์ดํ›„๋กœ ์ถœํŒ์„œ๋ฅผ ํ™œ์šฉํ•ด์„œ ๊ณต๋ถ€๋ฅผ ํ•˜๋Š” ๊ฑด ์ •๋ง ์˜ค๋žœ๋งŒ์ธ๋ฐ, ์—ญ์‹œ ์ตœ์‹  ์ฑ…์ด๋‹ˆ ์ตœ์‹  ๋ชจ๋ธ๋„ ๋‹ค๋ฃจ๋Š”๊ตฌ๋‚˜ - ์‹ถ์—ˆ๋‹ค.

๋‹ค๋งŒ ์ฑ… ์„œ๋ฌธ์—์„œ ์ด๋ฏธ ๋ฐํ˜”๋‹ค์‹œํ”ผ ๋”ฅ๋Ÿฌ๋‹ ๊ฐœ๋…๊ณผ ํ”„๋ ˆ์ž„์›Œํฌ์— ๋Œ€ํ•œ ๊ธฐ๋ณธ์ ์ธ ์ง€์‹์ด ์žˆ์–ด์•ผ ์ฑ…์˜ ๋‚ด์šฉ์„ ์ œ๋Œ€๋กœ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์€ ์—ผ๋‘์— ๋‘˜ ํ•„์š”๊ฐ€ ์žˆ๊ฒ ๋‹ค. ํ™•์‹คํžˆ โ€˜์ดˆ๊ธ‰์„œโ€™๋ผ๊ธฐ๋ณด๋‹จ โ€˜์ž…๋ฌธ์„œโ€™๋กœ ๋ณด๋Š” ๊ฒŒ ๋งž๋‹ค. ๋˜ํ•œ ์›Œ๋‚™์— ๊ณ ๊ธ‰ ํ”„๋ ˆ์ž„์›Œํฌ๋‹ค๋ณด๋‹ˆ๊นŒ JAX/Flax ์ž์ฒด๊ฐ€ ์„œ๋น„์Šค(์„œ๋น™)๋ณด๋‹ค๋Š” ์—ฐ๊ตฌ์— ์ ํ•ฉํ•œ ๋„๊ตฌ๋ผ๋Š” ์ƒ๊ฐ์ด ๋“ค์—ˆ๋‹ค. ์ฑ… ์ฝ๊ธฐ ์ „๊ณผ ๋น„์Šทํ•˜๊ฒŒ ์ด๊ฒƒ์ด ์ด๊ฒƒ์ด ํ•„์š”ํ• ๊นŒ? ๋ผ๋Š” ์งˆ๋ฌธ์€ ์—ฌ์ „ํžˆ ๊น”๋”ํ•˜๊ฒŒ ํ•ด๋ช…๋˜์ง„ ์•Š์•˜์ง€๋งŒ, ์ฑ…์„ ์ฝ๊ณ  ๋‚˜์„œ ์–ธ์  ๊ฐ€ JAX๋ฅผ ์จ๋ณด๊ณ  ์‹ถ๋‹ค๋Š” ์š•์‹ฌ์€ ๋ณด๋‹ค ๋šœ๋ ทํ•ด์กŒ๋‹ค.

๋ฆฌ๋ทฐ์–ด๋กœ ์„ ์ •ํ•˜์—ฌ ๋„์„œ๋ฅผ ์ œ๊ณตํ•ด์ค€ ์ถœํŒ์‚ฌ ์ œ์ดํŽ์— ์ง„์‹ฌ์œผ๋กœ ๊ฐ์‚ฌํ•˜๋‹ค๋Š” ๋ง์”€์„ ํ‘œํ•˜๋ฉฐ ๋ณธ ๋ฆฌ๋ทฐ๋ฅผ ๋งˆ๋ฌด๋ฆฌํ•˜๊ฒ ๋‹ค.