是否可以基于一個 pytree(用于 JAX)xarray.DataArray
?
關鍵問題似乎是data
an 的屬性xarray.DataArray
被構造為numpy.ndarray
輸入陣列的視圖(例如DeviceArray
)
import jax.numpy as jnp
import xarray as xr
da = xr.DataArray(
data=jnp.array([2.0,3.0]),
dims=("var"),
coords={"var": ["A","B"]},
)
>>> type(da.data)
numpy.ndarray
將 扁平化/不扁平DataArray
化成 pytree 相對簡單(所有屬性都是 aux 除了data
),但我不知道如何檢索DeviceArray
,甚至分配給它(我不能at[.].set(.)
在這里使用)。
構建容器類(帶有DeviceArray
成員)的替代方法需要手動實作 xarray 的所有相關功能。對于單個功能,例如標記索引,這是可能的,但也是多余的。
uj5u.com熱心網友回復:
您可以按照檔案中的示例將 xarray 型別注冊為自定義 PyTree 節點來做您想做的事情。
例如,它可能看起來像這樣:
import jax.numpy as jnp
from jax import tree_util
import xarray as xr
tree_util.register_pytree_node(
xr.DataArray,
lambda x: ((x.data,), {"dims": x.dims, "coords": x.coords}),
lambda kwds, args: xr.DataArray(*args, **kwds)
)
da = xr.DataArray(
data=jnp.array([2.0,3.0]),
dims=("var"),
coords={"var": ["A","B"]},
)
print(da)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
data, tree = tree_util.tree_flatten(da)
print(data)
# [array([2., 3.], dtype=float32)]
da_reconstructed = tree_util.tree_unflatten(tree, data)
print(da_reconstructed)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
除了最簡單的情況外,我認為這不太可能按預期作業:例如,JAX 轉換僅限于功能性、無副作用的代碼,而 JAX 陣列是不可變的。xarray 的操作通常違反這兩個約束。
另一個問題:你必須小心地定義你的 flatten 和 unflatten 函式,以確保所有資料和元資料都被正確序列化,如果你希望在任何 JAX 函式中使用它,請注意你xarray 的輸入驗證可能會遇到問題;有關更多資訊,請參閱自定義 PyTrees 和初始化。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/461732.html
標籤:Python 麻木的 python-xarray 贾克斯
上一篇:將值添加到矩陣的下一行