這是我有的一些資料:
import jax.numpy as jnp
import numpyro.distributions as dist
import jax
xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 5)
我想運行該功能
def func(x, y):
return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
在來自xaxis
和的每對值上yaxis
。
這是一種“緩慢”的方法:
results = np.zeros((len(xaxis), len(yaxis)))
for i in range(len(xaxis)):
for j in range(len(yaxis)):
results[i, j] = func(xaxis[i], yaxis[j])
有效,但速度很慢。
所以這是一種矢量化的方法:
jax.vmap(lambda axis: jax.vmap(func, (None, 0))(axis, yaxis))(xaxis)
快得多,但很難閱讀。
有沒有一種干凈的方式來撰寫矢量化版本?我可以用一個vmap
,而不是將一個嵌套在另一個中嗎?
編輯
另一種方式是
jax.vmap(func)(xmesh.flatten(), ymesh.flatten()).reshape(len(xaxis), len(yaxis)).T
但它仍然很亂。
uj5u.com熱心網友回復:
我相信jax 的矢量化指南與您的問題非常相似;使用 vmap 復制嵌套 for 回圈的邏輯需要嵌套 vmap。
使用的最干凈的方法jax.vmap
可能是這樣的:
from functools import partial
@partial(jax.vmap, in_axes=(0, None))
@partial(jax.vmap, in_axes=(None, 0))
def func(x, y):
return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
func(xaxis, yaxis)
這里的另一個選擇是使用jnp.vectorize
API(通過多個 vmap 實作),在這種情況下,您可以執行以下操作:
print(jnp.vectorize(func)(xaxis[:, None], yaxis))
轉載請註明出處,本文鏈接:https://www.uj5u.com/ruanti/349961.html