Introduction
이 섹션에서는 JAX에서 지원하는 단일 프로그램, 다중 데이터 (Single-Program, Multiple-Data, SPMD) 코드에 대해서 다뤄보겠다.
What is SPMD?
SPMD는 같은 계산 (ex. NN의 forward pass)을 서로 다른 입력 데이터 (ex. 배치 내 다른 입력들)에 대해 다양한 장치들 (ex. 여러 TPU)에서 병렬로 수행하는 병렬처리 기법을 말한다.
개념적으로 이것은 벡터화와 크게 다르지 않다. 벡터화에서는 같은 연산이 단일 장치의 메모리 다른 부분에서 병렬로 발생한다. JAX에서는 jax.pmap을 사용해서, 한 장치에서 작성된 함수를 여러 장치에서 병렬로 실행되는 함수로 변환함으로써, 장치 병렬성을 유사한 방식으로 지원한다.
TPU Setup
1
2
| import jax
jax.devices()
|
나의 경우는 이렇게 나오는데, Tutorial에서는 Kaggle TPU VM을 사용해서 실행하는 것을 추천한다고 한다.
1
2
3
4
5
6
7
8
| [cuda(id=0),
cuda(id=1),
cuda(id=2),
cuda(id=3),
cuda(id=4),
cuda(id=5),
cuda(id=6),
cuda(id=7)]
|
jax.pmap의 가장 기본적인 사용법은 jax.vmap과 완전히 유사하다고 한다. 아래의 예제를 보자.
1
2
3
4
5
6
7
8
9
10
11
12
13
| import numpy as np
import jax.numpy as jnp
x = np.arange(5)
w = np.array([2., 3., 4.])
def convolve(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
return jnp.array(output)
convolve(x, w)
|
1
| Array([11., 20., 29.], dtype=float32)
|
이제 convolve 함수를 전체 데이터 배치에서 실행되도록 변환해보자. 여러 장치에 배치를 분산시킬 예정이므로, 배치 크기를 장치의 수와 동일하게 설정할 것이다.
1
2
3
4
5
| n_devices = jax.local_device_count()
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)
xs
|
1
2
3
4
5
6
7
8
| array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]])
|
1
2
3
4
5
6
7
8
| array([[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.]])
|
분산을 수행하기 전에 jax.vmap을 이용해서 벡터화를 할 수 있다.
1
| jax.vmap(convolve)(xs, ws)
|
1
2
3
4
5
6
7
8
| Array([[ 11., 20., 29.],
[ 56., 65., 74.],
[101., 110., 119.],
[146., 155., 164.],
[191., 200., 209.],
[236., 245., 254.],
[281., 290., 299.],
[326., 335., 344.]], dtype=float32)
|
여러 장치에 계산을 분산시키기 위해서 jax.vmap을 jax.pmap으로 교체만 하면 된다고 한다.
1
| jax.pmap(convolve)(xs, ws)
|
1
2
3
4
5
6
7
8
| Array([[ 11., 20., 29.],
[ 56., 65., 74.],
[101., 110., 119.],
[146., 155., 164.],
[191., 200., 209.],
[236., 245., 254.],
[281., 290., 299.],
[326., 335., 344.]], dtype=float32)
|
아직까진 별 달라진 것을 모르겠다. 다음을 진행해보자. document에서는 병렬화된 convolve는 `jnp.Array를 반환한다고 하는데 근데 vmap은 그렇다면 jnp.Array로 반환하는 건 아니라는 얘긴가…? 이 배열 요소는 병렬 처리에 사용된 모든 장치에 걸쳐 분할 (shared)되어 있기 때문이라고 한다. 만약 다른 병렬 계산을 실행할 경우, 이 요소들은 각자의 장치에 머물면서, 장치간 통신 비용을 발생시키지 않고 처리된다고 한다.
내부의 jax.pmap(convolve)의 출력은 외부의 jax.pmap(convolve)`로 전달될 때 그 장치를 벗어나지 않는다고 하는데, 이건 또 무슨 얘기일까…?
1
| jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))
|