我對使用 JAX 訓練神經網路很感興趣。我看了一下tf.data.Dataset
,但它只提供了 tf 張量。我尋找了一種將資料集更改為 JAX numpy 陣列的方法,并且發現了許多Dataset.as_numpy_generator()
用于將 tf 張量轉換為 numpy 陣列的實作。但是我想知道這是否是一個好習慣,因為 numpy 陣列存盤在 CPU 記憶體中,這不是我想要的訓練(我使用 GPU)。所以我發現的最后一個想法是通過呼叫手動重鑄陣列,jnp.array
但這并不是很優雅(我擔心 GPU 記憶體中的副本)。有沒有人對此有更好的主意?
快速代碼來說明:
import os
import jax.numpy as jnp
import tensorflow as tf
def generator():
for _ in range(2):
yield tf.random.uniform((1, ))
ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
output_shapes=tf.TensorShape([1]))
ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)
for i, batch in enumerate(ds1):
print(type(batch))
for i, batch in enumerate(ds2):
print(type(jnp.array(batch)))
# returns:
<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant
uj5u.com熱心網友回復:
tensorflow 和 JAX 都能夠在不復制記憶體的情況下將陣列轉換為dlpack張量,因此可以從 tensorflow 陣列創建 JAX 陣列而不復制底層資料緩沖區的一種方法是通過 dlpack:
import numpy as np
import tensorflow as tf
import jax.dlpack
tf_arr = tf.random.uniform((10,))
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(dl_arr)
np.testing.assert_array_equal(tf_arr, jax_arr)
通過執行到 JAX 的往返,您可以進行比較unsafe_buffer_pointer()
以確保陣列指向同一個緩沖區,而不是沿途復制緩沖區:
def tf_to_jax(arr):
return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_arr))
def jax_to_tf(arr):
return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))
jax_arr = jnp.arange(20.)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)
print(jnp.all(jax_arr == jax_arr2))
# True
print(jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer())
# True
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/346307.html
標籤:Python 张量流 numpy-ndarray 贾克斯