Post

JAX 101 - Parallel Evaluation in JAX

Introduction

이 섹션에서는 JAX에서 지원하는 단일 프로그램, 다중 데이터 (Single-Program, Multiple-Data, SPMD) 코드에 대해서 다뤄보겠다.

What is SPMD?

SPMD는 같은 계산 (ex. NN의 forward pass)을 서로 다른 입력 데이터 (ex. 배치 내 다른 입력들)에 대해 다양한 장치들 (ex. 여러 TPU)에서 병렬로 수행하는 병렬처리 기법을 말한다.

개념적으로 이것은 벡터화와 크게 다르지 않다. 벡터화에서는 같은 연산이 단일 장치의 메모리 다른 부분에서 병렬로 발생한다. JAX에서는 jax.pmap을 사용해서, 한 장치에서 작성된 함수를 여러 장치에서 병렬로 실행되는 함수로 변환함으로써, 장치 병렬성을 유사한 방식으로 지원한다.

TPU Setup

1
2
import jax
jax.devices()

나의 경우는 이렇게 나오는데, Tutorial에서는 Kaggle TPU VM을 사용해서 실행하는 것을 추천한다고 한다.

1
2
3
4
5
6
7
8
[cuda(id=0),
 cuda(id=1),
 cuda(id=2),
 cuda(id=3),
 cuda(id=4),
 cuda(id=5),
 cuda(id=6),
 cuda(id=7)]

jax.pmap의 가장 기본적인 사용법은 jax.vmap과 완전히 유사하다고 한다. 아래의 예제를 보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
import jax.numpy as jnp

x = np.arange(5)
w = np.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)
1
Array([11., 20., 29.], dtype=float32)

이제 convolve 함수를 전체 데이터 배치에서 실행되도록 변환해보자. 여러 장치에 배치를 분산시킬 예정이므로, 배치 크기를 장치의 수와 동일하게 설정할 것이다.

1
2
3
4
5
n_devices = jax.local_device_count() 
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)

xs
1
2
3
4
5
6
7
8
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29],
       [30, 31, 32, 33, 34],
       [35, 36, 37, 38, 39]])
1
ws
1
2
3
4
5
6
7
8
array([[2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.]])

분산을 수행하기 전에 jax.vmap을 이용해서 벡터화를 할 수 있다.

1
jax.vmap(convolve)(xs, ws)
1
2
3
4
5
6
7
8
Array([[ 11.,  20.,  29.],
       [ 56.,  65.,  74.],
       [101., 110., 119.],
       [146., 155., 164.],
       [191., 200., 209.],
       [236., 245., 254.],
       [281., 290., 299.],
       [326., 335., 344.]], dtype=float32)

여러 장치에 계산을 분산시키기 위해서 jax.vmapjax.pmap으로 교체만 하면 된다고 한다.

1
jax.pmap(convolve)(xs, ws)
1
2
3
4
5
6
7
8
Array([[ 11.,  20.,  29.],
       [ 56.,  65.,  74.],
       [101., 110., 119.],
       [146., 155., 164.],
       [191., 200., 209.],
       [236., 245., 254.],
       [281., 290., 299.],
       [326., 335., 344.]], dtype=float32)

아직까진 별 달라진 것을 모르겠다. 다음을 진행해보자. document에서는 병렬화된 convolve는 `jnp.Array를 반환한다고 하는데 근데 vmap은 그렇다면 jnp.Array로 반환하는 건 아니라는 얘긴가…? 이 배열 요소는 병렬 처리에 사용된 모든 장치에 걸쳐 분할 (shared)되어 있기 때문이라고 한다. 만약 다른 병렬 계산을 실행할 경우, 이 요소들은 각자의 장치에 머물면서, 장치간 통신 비용을 발생시키지 않고 처리된다고 한다.

내부의 jax.pmap(convolve)의 출력은 외부의 jax.pmap(convolve)`로 전달될 때 그 장치를 벗어나지 않는다고 하는데, 이건 또 무슨 얘기일까…?

1
jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))
1
This post is licensed under CC BY 4.0 by the author.