Post

JAX 101 - Working with Pytrees

Working with Pytrees

JAX에서는 arraydictionary나, listdictionary와 같은 중첩된 구조들을 다루기 위한 도구를 제공한다. 이러한 구조들을 JAX에서는 주로 pytrees라고 부르지만, 다른 곳에서는 nests 또는 단순히 trees라고 부르기도 한다.

JAX는 이러한 객체들을 지원하기 위해 라이브러리 함수들과 jax.tree_utils의 함수들을 제공한다. (가장 일반적인 함수들을 jax.tree_*로도 사용 가능).

What is a pytree?

Pytree는 다양한 형태의 데이터 구조를 포함할 수 있는 JAX에서 사용하는 일반적인 용어이다. 예를 들어, 숫자, 리스트, 튜플, 딕셔너리 등 다양한 데이터 형태가 조합된 복합 구조를 의미한다. 이 구조들은 JAX내의 다양한 함수와 함께 자연스럽게 작동하여, 계산을 간편하게 만들어 준다.

Pytree를 활용하는 기본적인 예제는 다음과 같다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in example_trees:
  leaves = jax.tree_util.tree_leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
1
2
3
4
5
6
Output:
    [1, 'a', <object object at 0x7fc5c82bbfe0>]   has 3 leaves: [1, 'a', <object object at 0x7fc5c82bbfe0>]
    (1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
    [1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
    {'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
    Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]

이 함수는 tree 구조에서 flatten된 leaf 요소들을 추출하는 기능을 제공한다.

Why pytrees?

기계 학습에서 pytree가 자주 사용되는 몇가지 영역을 살펴보자

  1. Model parameters
  2. Dataset entries
  3. RL agent observations

또한, 데이터셋을 대량으로 처리할 때 (ex: 리스트의 리스트의 딕셔너리) 자연스럽게 pytree 구조가 생성되곤 한다. 이는 데이터의 복잡성과 다양성을 효율적으로 관리하기 위해 필요하다.

1
2
Output:
    [[2, 4, 6], [2, 4], [2, 4, 6, 8]]

jax.tree_map은 아래와 같이 여러개의 인자가 있어도 잘 작동한다.

1
2
another_list_of_lists = list_of_lists
jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
1
2
Output:
    [[2, 4, 6], [2, 4], [2, 4, 6, 8]]

그런데, jax.tree_map을 사용할 때 조심해야할 점은, multiple argument를 사용할 때 입력 구조가 정확히 일치해야한다는 것이다. 즉, 목록은 요소 수가 같아야하고 딕셔너리는 키가 같아야하는 등 입력 구조가 일치해야한다.

Example: ML model parameters

MLP를 훈련할 때도 pytree 연산이 유용하게 사용된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

jax.tree_map(lambda x: x.shape, params)

결과를 보면 parameter의 모양이 예상한 것과 일치하는것을 확인할 수 있다.

1
2
3
4
Output:
    [{'biases': (128,), 'weights': (1, 128)},
    {'biases': (128,), 'weights': (128, 128)},
    {'biases': (1,), 'weights': (128, 1)}]

이제 MLP를 학습시켜보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):

  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of the many JAX functions that has
  # built-in support for pytrees.

  # This is handy, because we can apply the SGD update using tree utils:
  return jax.tree_map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

import matplotlib.pyplot as plt

xs = np.random.normal(size=(128, 1))
ys = xs ** 2

for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.scatter(xs, forward(params, xs), label='Model prediction')
plt.legend();

jax1

Key Paths

Pytree의 각 leaf에는 키 경로가 있다. leaf의 키 경로는 키들의 리스트로 이 리스트의 길이는 pytree내에서 leaf의 깊이 (depth)와 같다. 각 키는 pytree 노드 타입에 인덱스를 나타내는 해시 가능한 객체다. 키의 타입은 pytree 노드 타입에 따라 다르며, 예를 들어 딕셔너리의 키 타입은 튜플의 키 타입과 다르다.

내장된 pytree 노드 타입의 경우, 어떤 pytree 노드 인스턴스에 대한 키 집합은 고유합니다. 이러한 속성을 가진 노드들로 구성된 Pytree에서는 각 리프의 키 경로가 유일하다.

키 경로를 다루기 위한 API는 다음과 같다.:

  1. jax.tree_util.tree_flatten_with_path: jax.tree_util.tree_flatten과 유사하게 작동하지만, 키 경로도 반환함.
  2. jax.tree_util.tree_map_with_path: jax.tree_util.tree_map과 유사하게 작동하지만, 함수가 키 경로를 인자로도 받는다.
  3. jax.tree_util.keystr: 일반 키 경로를 주어, 읽기 쉬운 문자열 표현을 반환한다.
1
2
3
4
5
6
7
import collections
ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
    print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
1
2
3
4
5
6
Output:
    Value of tree[0]: 1
    Value of tree[1]['k1']: 2
    Value of tree[1]['k2'][0]: 3
    Value of tree[1]['k2'][1]: 4
    Value of tree[2].name: foo

JAX에서는 내장된 pytree 노드 타입들에 대해 몇 가지 기본 키 타입을 제공하여, 키 경로를 표현합니다. 이는 다음과 같다.:

  • SequenceKey(idx: int): 리스트와 튜플에 사용된다.
  • DictKey(key: Hashable): 딕셔너리에 사용된다.
  • GetAttrKey(name: str): namedtuple과 사용자 정의 pytree 노드에 주로 사용된다. (이 내용은 다음 섹션에서 더 자세히 설명해준다고 한다.)

사용자는 자신의 커스텀 노드에 대해 자유롭게 자신만의 키 타입을 정의할 수 있다. 이들은 jax.tree_util.keystr과 함께 작동하는 한, 그들의 str() 메소드가 사용자가 읽기 쉬운 표현으로 오버라이드되어 있어야 한다고 한다.

1
2
for key_path, _ in flattened:
    print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
1
2
3
4
5
6
Output:
    Key path of tree[0]: (SequenceKey(idx=0),)
    Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
    Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'),SequenceKey(idx=0))
    Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
    Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
This post is licensed under CC BY 4.0 by the author.