๐ ใJAX/Flax๋ก ๋ฅ๋ฌ๋ ๋ ๋ฒจ์ ใ(์ ์ดํ) ๋ฆฌ๋ทฐ
IT ์ค์ฉ์ ์ ๋ฌธ ์ถํ์ฌ ์ ์ดํ์ผ๋ก๋ถํฐ ์ฑ
ใJAX/Flax๋ก ๋ฅ๋ฌ๋ ๋ ๋ฒจ์
ใ๋ฅผ ๋ฌด๋ฃ๋ก ์ ๊ณต ๋ฐ์๋ค.
๊ฐ์
- ๋์๋ช
- JAX/Flax๋ก ๋ฅ๋ฌ๋ ๋ ๋ฒจ์
- ์ง์์ด
- ์ด์๋น , ์ ํ์ , ๊นํ๋น , ์กฐ์๋น , ์ดํํธ , ์ฅ์ง์ฐ , ๋ฐ์ ํ , ๊นํ์ญ , ์ด์นํ
- ๋ฐํ์ฌ
- ์ ์ดํ
- ์ดํ ๋ฐํ
- 2024๋ 9์ 23์ผ
- ์ ๊ฐ
- 24,000์
๋ฒ ํ๋ฆฌ๋ ํ๊ธฐ์ ๋ฐ๋ฅด๋ฉด ใJAX/Flax๋ก ๋ฅ๋ฌ๋ ๋ ๋ฒจ์ ใ์ ๋ฌด๋ ค ๊ตญ๋ด ์ต์ด JAX ์ ๋ฌธ์๋ผ๊ณ ํ๋ค. ์ต๊ทผ ์ฑ์ฅํ๊ณ ์๋ JAX ์ํ๊ณ์ ํ์ฑํ์ ๊ธฐ์ฌํ๋ ์๋ฏธ๊ฐ ์๊ฒ ๋ค.
๊ฐ์ธ์ ์ผ๋ก JAX๋ โ๊ณ ์ฑ๋ฅ ๋ฅ๋ฌ๋ ์ฐ์ฐ์ด ๊ฐ๋ฅํ numpyโ ์ ๋๋ก ์๊ณ ์๋ ์ํ์๊ณ ์ง์ ํ์ฉํด๋ณธ ์ ์ ์์๋ค. ์ด๋ฒ ์ฑ ๋ฆฌ๋ทฐ๋ฅผ ๊ณ๊ธฐ๋ก Numpy์์ ์ฐจ์ด์ ๊ณผ JAX/Flax์ ์ฃผ์ ํน์ง์ ์ดํดํ๊ณ ์ค์ ํํ ๋ฆฌ์ผ์ ๋ฐ๋ผ๊ฐ๋ ๊ฒฝํ์ ์์ผ๋ ค ํ๋ค.
์ฑ ์ ํฌ๊ฒ JAX/Flax๋ฅผ ์๊ฐํ๋ ๋ถ๋ถ๊ณผ JAX/Flax๋ฅผ ํ์ฉํ์ฌ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ๊ตฌํํ๋ ๋ถ๋ถ์ผ๋ก ์ด๋ค์ ธ ์๋ค. ํ์ด์ฌ ํ๋ก๊ทธ๋๋ฐ๊ณผ ๊ธฐ๋ณธ์ ์ธ ๋จธ์ ๋ฌ๋ ๊ฐ๋ ์ ์ฑ ์ ์ฝ๊ธฐ ์ํ ์ ์ ์ง์์ผ๋ก ์๊ตฌ๋๋ค.
Jax๋?
ํ๋ง๋๋ก ํํํ๋ฉด ์๋ ๋ฏธ๋ถ๊ณผ XLA๋ฅผ ๊ฒฐํฉํด์ ์ฌ์ฉํ๋ ๊ณ ์ฑ๋ฅ ๋จธ์ ๋ฌ๋ ํ๋ ์์ํฌ์ ๋๋ค โฆ JAX์ ๊ฐ์ฅ ํฐ ๊ฐ์ ์ XLA๋ฅผ ์ ์ฉํด์ ์ฌ์ฉํ ์ ์๋ค๋ ์ ์ ๋๋ค.
PyTorch, Tensorflow ๋ฑ ๋ค๋ฅธ ๋ฅ๋ฌ๋ ํ๋ ์์ํฌ๋ ์๋ ๋ฏธ๋ถ์ ์ง์ํ์ง๋ง, JAX๋ ์ด์ ๋ํด XLA(Accelerated Linear Algebra)์ด ๊ฐ๋ฅํ๋ค๋ ๊ฒ์ด ํต์ฌ์ด๋ค. XLA๋ GPU/TPU ์์์ numpy๋ฅผ ์ปดํ์ผํ๊ณ ์คํํ๋ ์ปดํ์ผ๋ฌ๋ค. JAX๋ JIT(Just-In-Time) ์ปดํ์ผ์ ํตํด ํ์ด์ฌ ์ฝ๋๋ฅผ XLA์ ์ต์ ํ๋ ๊ธฐ๊ณ์ด๋ก ๋ณํํ๊ธฐ ๋๋ฌธ์ PyTorch์ ๋์ ๊ทธ๋ํ๋ณด๋ค๋ ๋น ๋ฅด๊ณ ํจ์จ์ ์ผ๋ก ์ฐ์ฐํ ์ ์๋ค๋ ๊ฒ์ด๋ค.
ํ์ด์ฌ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ธํฐํ๋ฆฌํฐ ๋ฐฉ์์ผ๋ก ์คํ๋๊ธฐ ๋๋ฌธ์ ์ฝ๋๋ฅผ ํ ์ค์ฉ ์ฝ๊ณ ํด์ํ๋ ๋ฐ ์๊ฐ์ด ์์๋๋ค. ์ฌ๊ธฐ์ JIT ์ปดํ์ผ์ ์ฌ์ฉํ๋ค๋ฉด ์ฝ๋๋ฅผ ์คํํ๋ ์์ ์ ์ฑ๋ฅ๊ณผ ์ฐ๊ด๋๋ ์ผ๋ถ๋ถ์ ๋ฏธ๋ฆฌ ๊ธฐ๊ณ์ด๋ก ์ปดํ์ผํ์ฌ ์๋๊ฐ ๋นจ๋ผ์ง๋ค๊ณ ์ดํดํ๋ค.
Flax
JAX + Flexibility๋ฅผ ํฉ์ณ์ ธ์ ๋ง๋ค์์ผ๋ฉฐ ์์ง๋์ด๋ค์ด JAX๋ฅผ ์กฐ๊ธ ๋ ์ฝ๊ฒ ์ฌ์ฉํ ์ ์๋ ํ๋ ์์ํฌ์ด๋ฉฐ, ๋ค๋ฅธ ๋ฅ๋ฌ๋ ํ๋ ์์ํฌ๋ค์ฒ๋ผ ๋ ์ด์ด(์ธต) ๊ฐ๋ ์ ์ง์ํฉ๋๋ค.
์ฌ๊ธฐ๊น์ง ์ฝ์์ ๋ Tensorflow & Keras ์ ์ ์ฌํ ๊ฐ๋ (๊ด๊ณ)์ด ์๋๊ฐ ์ถ์๋๋ฐ, JAX/Flax๋ Low-level์ ์ฌ์ธํ ์ปจํธ๋กค์ด ๊ฐ๋ฅํ๋ค๋ ์ ์ ๋ฐฉ์ ์ด ์ฐํ ์๋ ๊ฒ ๊ฐ๋ค. ๊ทธ์ ๋ฌ๋ฆฌ Keras๋ ๋์ ์์ค์ ์ถ์ํ๊ฐ ์ด๋ฃจ์ด์ ธ ์๊ณ ์ฌ์ฉ์ ์นํ์ ์ด๋ค. ๋๊ฐ์ด ๊ตฌ๊ธ์์ ๊ฐ๋ฐํ ํ๋ ์์ํฌ์ง๋ง ์งํฅํ๋ ์ฒ ํ์ด ๋ค๋ฅด๋ค๋ ์ ์ด ์ฌ๋ฐ๋ค.
์ฑ ์ ๋ฐ๋ฅด๋ฉด ๊ตฌ๊ธ์์ ๊ฐ๋ฐํ ๋ชจ๋ธ๋ค์ ๋๋ถ๋ถ JAX๋ก ์์ฑ๋์ด ์๊ณ , ์ฌ์ง์ด Hugging Face์ ๊ธฐ์กด ๋ชจ๋ธ๋ค๋ JAX๋ก ๋ณํํ๊ณ ์๋ค๊ณ ํ๋ค.
ํจ์ํ ํ๋ก๊ทธ๋๋ฐ
JAX/Flax์ ํ์ฉ ๋ฐฉ์์ ๋ ์ ํ์ ํ ์ ์๋๋ก ์ฑ ์ ํจ์ํ ํ๋ก๊ทธ๋๋ฐ์ ๋ํด์ ๋ณ๋ ์ฑํฐ๋ก ์ค๋ช ํ๋ค. ๋ช ๋ น์ด์ ํ๋ฆ(์์)๋๋ก ์ํ๋ฅผ ๋ณ๊ฒฝํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ์ ๋ฌํ๋ ๊ฒ์ด ํต์ฌ์ธ ์ ์ฐจ์ ํ๋ก๊ทธ๋๋ฐ๊ณผ ๋ค๋ฅด๊ฒ, ํจ์ํ ํ๋ก๊ทธ๋๋ฐ์ ์ธ๋ถ ์ํ์ ์๊ด์์ด ์ฃผ์ด์ง ์ ๋ ฅ์ ๋์ผํ ์ถ๋ ฅ๊ฐ์ ๋ด๋๋ ์์ ํจ์๋ฅผ ์ฌ์ฉํ๋ค. ๋ฐ๋ผ์ ๋ถ์ ํจ๊ณผ๊ฐ ์ ๊ฑฐ๋๋ฉฐ ์ํ๊ฐ ๋ณํํ์ง ์๋ ๋ถ๋ณ์ฑ์ ๊ฐ์กฐํ๋ค. ์ฌ๊ธฐ์ ์ ์ฐจ์ ํ๋ก๊ทธ๋๋ฐ๊ณผ ํจ์ํ ํ๋ก๊ทธ๋๋ฐ์ ์ค๋ช ํ ๋ ๊ฐ๋จํ ํ์ด์ฌ ์์ ๊ฐ ์ฒจ๋ถ๋์ด ์์ด์ ์ดํด๊ฐ ํธํ๋ค.
JAX, ๋์๊ฐ ๋ฅ๋ฌ๋ ์ฐ์ฐ์ ์์ด์ ์ด ๊ฐ๋ ์ ์ดํดํ๋ ๊ฒ์ด ์ค์ํ ์ด์ ๋ฅผ ์ธ ๊ฐ์ง๋ก ์ ์ํ๊ณ ์๋ค.
- XLA ์ปดํ์ผ์ ์ต์ ํ๋ ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํด์ง๋ค
- ๋ณ๋ ฌ์ฒ๋ฆฌ์ ๋ถ์ฐ์ฒ๋ฆฌ์ ์ ์ฉํ๋ค
- ์ฝ๋๋ฅผ ๋ชจ๋ํํจ์ผ๋ก์จ ์ฌ์ฌ์ฉ์ฑ์ด ๋์์ง๋ค
JAX ๊ธฐ๋ณธ
๋ฐฑ๋ฌธ์ด ๋ถ์ฌ์ผ๊ฒฌ, ์ง์ JAX ๋ฅผ ํ์ฉํด๋ณด๋ฉฐ ์ฑ ์ ๋ด์ฉ์ ๋ฐ๋ผ๊ฐ๋ณด๊ฒ ๋ค.
์ค์น
๋คํํ๋ JAX๊ฐ Mac M1์ ๊ณต์ ์ง์ํ๋ค๊ณ ํ์ฌ conda
๋ก ์ฝ๊ฒ ์ค์นํ ์ ์์๋ค.
conda create -n jax-env python=3.9
conda activate jax-env
pip install jax jaxlib
import ํ๊ธฐ
import jax
import jax.numpy as jnp
numpy์ ๋น๊ต
x1 = jnp.array([1.0, 2.0, 3.0])
x2 = jnp.array([4.0, 5.0, 6.0])
y = x1 + x2
print(y) # [5. 7. 9.]
print(type(y)) # <class 'jaxlib.xla_extension.ArrayImpl'>
์ ์์ ์์ ๋ณด๋ฏ jax.numpy๋ ๊ธฐ์กด numpy ์ ๊ฑฐ์ ๋๊ฐ์ API๋ฅผ ์ ๊ณตํ๊ณ ์๋ค.
grad ํจ์
def func(x):
return x**2
grad = jax.grad(func)
print(grad(3.)) # Array(6., dtype=float32, weak_type=True)
JAX์์ ๋ฏธ๋ถ, ์ฆ gradient๋ฅผ ๊ณ์ฐํด์ฃผ๋ grad
๋ฅผ ์ฌ์ฉํ ์์ ๋ค.
๋ถ์ ํจ๊ณผ์ ๋ฐฉ์ง
JAX๋ ๋ถ์ ํจ๊ณผ๋ฅผ ์ ๊ฑฐํ๋ ํจ์ํ ํ๋ก๊ทธ๋๋ฐ์ ์ ์ฝ์ ๋ฐ๋ฅด๊ณ ์๋ค. ์ฑ ์์ ์ ๊ณตํด์ค ์๋ ์์ ๋ฅผ ์ฐธ๊ณ ํด๋ณด์.
x_1 = np.array([1, 2, 3])
x_1[0] = 999
print(x_1) # [999 2 3]
numpy๋ก ์์ฑํ ๋ฐฐ์ด์ ์ง์ ์ ๊ทผํด์ ์์๋ฅผ ๋ณ๊ฒฝํ ์ ์๋ค.
x_2 = jnp.array([1, 2, 3])
x_2[0] = 999
# TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable.
๋ฐ๋ฉด jax.numpy๋ก ์์ฑํ ๋ฐฐ์ด์ ์ง์ ์ ์ธ ์์ ์ ํ์ฉํ์ง ์๋๋ค. ์ด๋ โ์ธ๋ถ ๋ฐ์ดํฐโ์ธ ๋ฐฐ์ด์ ์ํ๊ฐ ๋ณํ๋๋ฉด์ ๋ถ์ ํจ๊ณผ๊ฐ ๋ฐ์ํ๋ ๊ฒ์ ๋ฐฉ์งํ๊ธฐ ์ํจ์ด๋ค.
๋ง์ฝ ๋ฐฐ์ด์ ์ผ๋ถ๋ฅผ ์์ ํ๋ ์์
์ ์งํํ๊ณ ์ถ๋ค๋ฉด ๋ถ์ ํจ๊ณผ๊ฐ ์๋ ์์ ํจ์๋ฅผ ์ฌ์ฉํด์ผ ํ๋ค.
x_2 = jnp.array([1, 2, 3])
def modify(x):
return x.at[0].set(999)
y = modify(x_2)
print(y) # Array([999, 2, 3], dtype=int32)
์ฌ๊ธฐ์ modify(x)
๋ ๋ถ์ ํจ๊ณผ๊ฐ ์๋ ์์ ํจ์๋ผ๊ณ ๋ณผ ์ ์๋ ๊ฒ์ด๋ฉฐ, ์ฑ
์ jax.grad
์ jax.jit
๊ฐ์ ํจ์๋ ์์ ํจ์๋ก ์์ฑ๋์ด์ผ ํ๋ค๊ณ ์ค๋ช
ํ๊ณ ์๋ค.
JIT ์ปดํ์ผ
- ๋ณํ
- ์ฃผ์ด์ง ํจ์๋ฅผ ๋ณ๊ฒฝํ๊ฑฐ๋ ์์ ํ๋ ๋ฐฉ์. ์ฑ๋ฅ ์ต์ ํ๋ ์๋ ๋ฏธ๋ถ์ ๊ฐ๋ฅํ๊ฒ ํจ.
์ฑ
์ JAX์์ ๋ณํ(transformation)์ด๋ผ๋ ํค์๋๊ฐ ์ค์ํ๋ค๊ณ ๋งํ๋ค. JAX์์ ๋ณํ์ jaxpr
, ์ฆ JAX ํํ์์ด๋ผ๋ ์ค๊ฐ ์ธ์ด(intermediate language)๋ฅผ ํตํด ์ด๋ฃจ์ด์ง๋ค. jax.jit
๊ฐ ๋ํ์ ์ธ jax ๋ณํ์ด๋ผ๊ณ ์๊ฐ๋๋ค.
def selu(x, alpha=1.67, lamdba_=1.05):
return lamdba_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)
# ์ผ๋ฐ
selu(x).block_until_ready()
# XLA
selu_jit = jax.jit(selu)
selu_jit(x).block_until_ready() # ๋น๋๊ธฐ ์คํ
์ ๋ด์ฉ์ ํ์ฑํ ํจ์ SELU(Scaled Exponential Linear Unit)๋ฅผ ๊ตฌํํ๊ณ ํธ์ถํ ๋ด์ฉ์ด๋ค.
selu(x)
๋์ ์ jit ๋ณํ์ ์ ์ฉํ selu_jit(x)
๊ฐ 7๋ฐฐ ๋น ๋ฅด๋ค๊ณ ์ค๋ช
ํ๊ณ ์๋ค. (๊ตฌ๊ธ Colab T4 ๊ธฐ์ค)
์ฑ
์ ์ฝ์ผ๋ฉด์ ํฅ๋ฏธ๋ก์ ๋ ๋ถ๋ถ์ jax.jit
์ ์ปดํ์ผ๋ ๊ณ์ฐ ๊ทธ๋ํ๋ฅผ ์บ์ฑํ์ฌ ์ฌ์ฌ์ฉํ๋ค๋ ์ ์ด์๋ค. ๋ค๋ง jax.jit
์ ๋ฐ๋ณต๋ฌธ ๋ด๋ถ์์ ํธ์ถํ ๊ฒฝ์ฐ ์ปดํ์ผ ๊ณผ์ ์ด ๋ถํ์ํ๊ฒ ๋ฐ๋ณต๋ ์ ์์ผ๋ ์ง์ํ๋ผ๊ณ ์๋ดํ๊ณ ์๋ค.
Flax
๋ง์ง๋ง์ผ๋ก Flax ๋ฅผ ํ์ฉํ ์์ ๋ฅผ ์ดํด๋ณด๊ฒ ๋ค.
import flax.linen as nn
from jax import random
key = random.PRNGKey(42)
class MLP(nn.Module):
out_dims: int
@nn.compact
def __call__(self, x):
x = x.reshape((x.shape[0], -1))
x = nn.Dense(128)(x)
x = nn.relu(x)
x = nn.Dense(self.out_dims)(x)
return x
model = MLP(out_dims=10)
x = jnp.empty((4, 28, 28, 1))
weights = model.init(key, x)
y = model.apply(weights, x)
์ฑ
์ import ํ๋ ๋ถ๋ถ์ ์์ด์ ์ถ๊ฐํ๋ค
nn.Module
์์ ์์๋ฐ์ ๋ชจ๋ธ์ ์์ฑํ๋ค๋ ์ ์์ PyTorch ์ ์ ์ฌํ ๋ฐฉ์์ API ๋ผ๊ณ ๋๊ปด์ก๊ณ ๊ธ๋ฐฉ ์ ์ํ ์ ์๊ฒ ๋ค๋ ์๊ฐ์ด ๋ ๋ค.
๋๊ฐ๋ฉฐ
ใJAX/Flax๋ก ๋ฅ๋ฌ๋ ๋ ๋ฒจ์ ใ์์ JAX ํต์ฌ ๊ฐ๋ ์ ์์ฃผ๋ก ์ดํด๋ณด๋ฉฐ ์ฑ ์ ๋ฆฌ๋ทฐํด๋ณด์๋ค. ์์ฆ ์์ ์์ ์ ๊ตณ์ด ์๋ก์ด ๋ฅ๋ฌ๋ ํ๋ ์์ํฌ๊ฐ ํ์ํ ๊น? ๋ผ๊ณ ๋ง์ฐํ ๊ถ๊ธํด ํ๋ฉฐ ๋ฆฌ๋ทฐ์ด ์ ์ฒญ์ ํ๋๋ฐ, ์ข์ ๊ธฐํ๋ก ์ฑ ๋ ์ ๊ณต ๋ฐ๊ณ JAX์ Flax์ ๋ํด ๊ฐ๋ณ๊ฒ ๋ฐฐ์๋ณผ ์ ์๋ ์๊ฐ์ด์๋ค.
JAX๊ฐ ์งํฅํ๋ ์ฒ ํ๊ณผ ํจ๊ป ๊ทธ๊ฒ์ด ๋ น์๋ ํต์ฌ ๊ธฐ๋ฅ์ ์ธ์ธํ๊ฒ ์ค๋ช ํด์ฃผ๊ธฐ ๋๋ฌธ์ JAX ์ ๋ฌธ์๋ก ์์ฃผ ์๋ง์ ๋์๋ผ๋ ์๊ฐ์ด ๋ค์๋ค. ํนํ ํจ์ํ ํ๋ก๊ทธ๋๋ฐ ๊ฐ๋ ๋ง์ ์ค๋ช ํ๊ธฐ ์ํด ๋ณ๋ ์ง๋ฉด์ ํ ์ ํ๋ค๋ ์ ์์๋ JAX์ ์๋ฏธ๋ฅผ ์ ๋๋ก ์ ๋ฌํ๊ฒ ๋ค๋ ๊ฐํ ์์ง๋ ๋ณด์๋ค.
์๋ง JAX๋ก ์ ๋ฌธํ๊ธฐ๊น์ง ๊ฐ์ฅ ํฐ ์ฅ๋ฒฝ์ ์์ ๋ด๊ฐ ๋ ์ฌ๋ฆฐ ๊ฒ๊ณผ ๊ฐ์ด โ์ ๊ผญ ์ด๊ฒ์ด์ด์ผ ํ๋๊ฐ?โ ๋ผ๋ ์๋ฌธ์ผ ํ ๋ฐ, ์ด ์ฑ ์ ์ฝ๋๋ค๋ฉด ๊ทธ ์ฅ๋ฒฝ ์ ๋๋ ์ถฉ๋ถํ ๋์ ์ ์๊ฒ ๋ค. numpy ํ๋๋ก ๋ชจ๋ธ์ ๊ตฌํํ๋ ์ ๋๋ก low-level์์ ๋ชจ๋ธ๋ง๊ณผ ํ์ต ๊ณผ์ ๋ฑ์ ์ ์ฐํ๊ฒ ํต์ ํ ์ ์๋ค๋ ์ ์ด JAX/Flax์ ๊ฐ์ฅ ๊ฐ๋ ฅํ ์ ์ฒด์ฑ์ด๋ผ๊ณ ๋๊ผ๋ค. ๊ตฌ๊ธ TPU๋ฅผ ์ฌ์ฉํ๋ ML์์ง๋์ด๋ผ๋ฉด ์๊ฐ์ ๋ค์ฌ์๋ผ๋ JAX๋ฅผ ์ ์ฉํ ๊ฐ์น๊ฐ ์์ ๋ฏํ๋ค.
๊ทธ ์ธ์ CLIP, GPT ๊ฐ์ ์ต์ ๋ชจ๋ธ์ fine-tuning ์ ์์ ๋ก ๋ค๋ฃฌ ์ ๋ ์ธ์์ ์ด์๋ค. ML ๋ถ์ผ์ ์ ๋ฌธํ ์ดํ๋ก ์ถํ์๋ฅผ ํ์ฉํด์ ๊ณต๋ถ๋ฅผ ํ๋ ๊ฑด ์ ๋ง ์ค๋๋ง์ธ๋ฐ, ์ญ์ ์ต์ ์ฑ ์ด๋ ์ต์ ๋ชจ๋ธ๋ ๋ค๋ฃจ๋๊ตฌ๋ - ์ถ์๋ค.
๋ค๋ง ์ฑ ์๋ฌธ์์ ์ด๋ฏธ ๋ฐํ๋ค์ํผ ๋ฅ๋ฌ๋ ๊ฐ๋ ๊ณผ ํ๋ ์์ํฌ์ ๋ํ ๊ธฐ๋ณธ์ ์ธ ์ง์์ด ์์ด์ผ ์ฑ ์ ๋ด์ฉ์ ์ ๋๋ก ์ดํดํ ์ ์๋ค๋ ์ ์ ์ผ๋์ ๋ ํ์๊ฐ ์๊ฒ ๋ค. ํ์คํ โ์ด๊ธ์โ๋ผ๊ธฐ๋ณด๋จ โ์ ๋ฌธ์โ๋ก ๋ณด๋ ๊ฒ ๋ง๋ค. ๋ํ ์๋์ ๊ณ ๊ธ ํ๋ ์์ํฌ๋ค๋ณด๋๊น JAX/Flax ์์ฒด๊ฐ ์๋น์ค(์๋น)๋ณด๋ค๋ ์ฐ๊ตฌ์ ์ ํฉํ ๋๊ตฌ๋ผ๋ ์๊ฐ์ด ๋ค์๋ค. ์ฑ ์ฝ๊ธฐ ์ ๊ณผ ๋น์ทํ๊ฒ ์ด๊ฒ์ด ์ด๊ฒ์ด ํ์ํ ๊น? ๋ผ๋ ์ง๋ฌธ์ ์ฌ์ ํ ๊น๋ํ๊ฒ ํด๋ช ๋์ง ์์์ง๋ง, ์ฑ ์ ์ฝ๊ณ ๋์ ์ธ์ ๊ฐ JAX๋ฅผ ์จ๋ณด๊ณ ์ถ๋ค๋ ์์ฌ์ ๋ณด๋ค ๋๋ ทํด์ก๋ค.
๋ฆฌ๋ทฐ์ด๋ก ์ ์ ํ์ฌ ๋์๋ฅผ ์ ๊ณตํด์ค ์ถํ์ฌ ์ ์ดํ์ ์ง์ฌ์ผ๋ก ๊ฐ์ฌํ๋ค๋ ๋ง์์ ํํ๋ฉฐ ๋ณธ ๋ฆฌ๋ทฐ๋ฅผ ๋ง๋ฌด๋ฆฌํ๊ฒ ๋ค.