我最近在 Jax 中實作了一個兩層 GRU 網路,但對其性能感到失望(它無法使用)。
因此,我嘗試與 Pytorch 進行一些速度比較。
最小作業示例
這是我的最小作業示例,輸出是在帶有 GPU 運行時的 Google Colab 上創建的。colab 中的筆記本
import flax.linen as jnn
import jax
import torch
import torch.nn as tnn
import numpy as np
import jax.numpy as jnp
def keyGen(seed):
key1 = jax.random.PRNGKey(seed)
while True:
key1, key2 = jax.random.split(key1)
yield key2
key = keyGen(1)
hidden_size=200
seq_length = 1000
in_features = 6
out_features = 4
batch_size = 8
class RNN_jax(jnn.Module):
@jnn.compact
def __call__(self, x, carry_gru1, carry_gru2):
carry_gru1, x = jnn.GRUCell()(carry_gru1, x)
carry_gru2, x = jnn.GRUCell()(carry_gru2, x)
x = jnn.Dense(4)(x)
x = x/jnp.linalg.norm(x)
return x, carry_gru1, carry_gru2
class RNN_torch(tnn.Module):
def __init__(self, batch_size, hidden_size, in_features, out_features):
super().__init__()
self.gru = tnn.GRU(
input_size=in_features,
hidden_size=hidden_size,
num_layers=2
)
self.dense = tnn.Linear(hidden_size, out_features)
self.init_carry = torch.zeros((2, batch_size, hidden_size))
def forward(self, X):
X, final_carry = self.gru(X, self.init_carry)
X = self.dense(X)
return X/X.norm(dim=-1).unsqueeze(-1).repeat((1, 1, 4))
rnn_jax = RNN_jax()
rnn_torch = RNN_torch(batch_size, hidden_size, in_features, out_features)
Xj = jax.random.normal(next(key), (seq_length, batch_size, in_features))
Yj = jax.random.normal(next(key), (seq_length, batch_size, out_features))
Xt = torch.from_numpy(np.array(Xj))
Yt = torch.from_numpy(np.array(Yj))
initial_carry_gru1 = jnp.zeros((batch_size, hidden_size))
initial_carry_gru2 = jnp.zeros((batch_size, hidden_size))
params = rnn_jax.init(next(key), Xj[0], initial_carry_gru1, initial_carry_gru2)
def forward(params, X):
carry_gru1, carry_gru2 = initial_carry_gru1, initial_carry_gru2
Yhat = []
for x in X: # x.shape = (batch_size, in_features)
yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
Yhat.append(yhat) # y.shape = (batch_size, out_features)
#return jnp.concatenate(Y, axis=0)
jitted_forward = jax.jit(forward)
結果
# uncompiled jax version
%time forward(params, Xj)
CPU times: user 7min 17s, sys: 8.18 s, total: 7min 25s Wall time: 7min 17s
# time for compiling
%time jitted_forward(params, Xj)
CPU times: user 8min 9s, sys: 4.46 s, total: 8min 13s Wall time: 8min 12s
# compiled jax version
%timeit jitted_forward(params, Xj)
The slowest run took 204.20 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 5: 115 μs per loop
# torch version
%timeit lambda: rnn_torch(Xt)
10000000 loops, best of 5: 65.7 ns per loop
問題
為什么我的 Jax 實作這么慢?我究竟做錯了什么?
另外,為什么編譯需要這么長時間?順序沒那么長。。
謝謝 :)
uj5u.com熱心網友回復:
JAX 代碼編譯緩慢的原因是在 JIT 編譯期間 JAX 會展開回圈。所以就XLA編譯而言,你的函式實際上非常大:你呼叫了rnn_jax.apply()
1000次,編譯時間在陳述句數量上趨于大致二次方。
相比之下,您的 pytorch 函式不使用 Python 回圈,因此在后臺它依賴于運行速度更快的矢量化操作。
任何時候for
在 Python 中對資料使用回圈時,一個很好的選擇是您的代碼會很慢:無論您使用的是 JAX、torch、numpy、pandas 等,都是如此。我建議找到解決問題的方法在依賴矢量化操作而不是依賴于緩慢的 Python 回圈的 JAX 中。
轉載請註明出處,本文鏈接:https://www.uj5u.com/ruanti/351992.html
上一篇:如何生成和繪制所有生成樹?