給定兩個具有相同維度(d>=2)
和形狀的張量 A[A_{1},...,A_{d-2},A_{d-1},A_{d}]
和B [A_{1},...,A_{d-2},B_{d-1},B_{d}]
(第一個 d-2 維度的形狀相同)。
有沒有辦法計算最后兩個維度的克羅內克積?的形狀my_kron(A,B)
應該是[A_{1},...,A_{d-2},A_{d-1}*B_{d-1},A_{d}*B_{d}]
。例如d=3
,
A.shape=[2,3,3]
B.shape=[2,4,4]
C=my_kron(A,B)
C[0,...]
應該是 和 的克羅內克積和A[0,...]
和B[0,...]
的C[1,...]
克羅內克積。A[1,...]
B[1,...]
對于 d=2,這就是jnp.kron
(or np.kron
) 函式的作用。
對于 d=3,這可以通過 來實作jax.vmap
。
jax.vmap(lambda x, y: jnp.kron(x[0, :], y[0, :]))(A, B)
但我無法找到一般(未知)尺寸的解決方案。有什么建議么?
uj5u.com熱心網友回復:
就numpy
我而言,我認為這就是您正在做的事情:
In [104]: A = np.arange(2*3*3).reshape(2,3,3)
In [105]: B = np.arange(2*4*4).reshape(2,4,4)
In [106]: C = np.array([np.kron(a,b) for a,b in zip(A,B)])
In [107]: C.shape
Out[107]: (2, 12, 12)
這將初始維度 2 視為batch
. 一個明顯的概括是重塑陣列,將較高的維度減少到 1,例如reshape(-1,3,3)
,等等。然后,重新整形C
回所需的 n 維度。
np.kron
outer
確實接受 3d(和更高),但它在共享 2 維上做了某種事情:
In [108]: np.kron(A,B).shape
Out[108]: (4, 12, 12)
并將 4 維可視化為 (2,2),我可以diagonal
得到你的C
:
In [109]: np.allclose(np.kron(A,B)[[0,3]], C)
Out[109]: True
完整的kron
計算比需要的多,但仍然更快:
In [110]: timeit C = np.array([np.kron(a,b) for a,b in zip(A,B)])
108 μs ± 2.23 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [111]: timeit np.kron(A,B)[[0,3]]
76.4 μs ± 1.36 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
我確信可以以更直接的方式進行計算,但這樣做需要更好地了解其kron
作業原理。快速瀏覽一下np.kron
代碼所暗示的那樣outer(A,B)
In [114]: np.outer(A,B).shape
Out[114]: (18, 32)
它具有相同數量的元素,但它然后reshapes
并concatenates
產生kron
布局。
但是憑直覺,我發現這和你想要的一樣:
In [123]: D = A[:,:,None,:,None]*B[:,None,:,None,:]
In [124]: np.allclose(D.reshape(2,12,12),C)
Out[124]: True
In [125]: timeit np.reshape(A[:,:,None,:,None]*B[:,None,:,None,:],(2,12,12))
14.3 μs ± 184 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
這很容易推廣到更領先的維度。
def my_kron(A,B):
D = A[...,:,None,:,None]*B[...,None,:,None,:]
ds = D.shape
newshape = (*ds[:-4],ds[-4]*ds[-3],ds[-2]*ds[-1])
return D.reshape(newshape)
In [137]: my_kron(A.reshape(1,2,1,3,3),B.reshape(1,2,1,4,4)).shape
Out[137]: (1, 2, 1, 12, 12)
轉載請註明出處,本文鏈接:https://www.uj5u.com/shujuku/507401.html