我是自動微分編程的新手,所以這可能是一個天真的問題。下面是我要解決的問題的簡化版本。
我有兩個輸入陣列 - 一個A
大小向量N
和一個B
形狀矩陣(N, M)
,以及一個theta
大小引數向量M
。我定義了一個新陣列C(theta) = B * theta
來獲得一個大小為 的新向量N
。然后我獲取落入 的上四分位數和下四分位數的元素的索引C
,并使用它們創建一個新陣列A_low(theta) = A[lower quartile indices of C]
和A_high(theta) = A[upper quartile indices of C]
。顯然這兩者確實取決于theta
,但是有可能區分A_low
和A_high
wrttheta
嗎?
到目前為止,我的嘗試似乎表明沒有 - 我使用了 autograd、JAX 和 tensorflow 的 python 庫,但它們都回傳零梯度。(到目前為止,我嘗試過的方法包括使用 argsort 或使用 提取相關子陣列tf.top_k
。)
我正在尋求的幫助是證明導數未定義(或無法進行分析計算),或者如果它確實存在,則是關于如何估計它的建議。我的最終目標是最小化一些函式f(A_low, A_high)
wrt theta
。
uj5u.com熱心網友回復:
這是我根據您的描述撰寫的 JAX 計算:
import numpy as np
import jax.numpy as jnp
import jax
N = 10
M = 20
rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))
def f(A, B, theta, k=3):
C = B @ theta
_, i_upper = lax.top_k(C, k)
_, i_lower = lax.top_k(-C, k)
return A[i_lower], A[i_upper]
x, y = f(A, B, theta)
dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)
導數都為零,我相信這是正確的,因為輸出值的變化不取決于 的值的變化theta
。
但是,你可能會問,這怎么可能?畢竟,theta
進入計算,如果你為 輸入不同的值theta
,你會得到不同的輸出。梯度怎么可能為零?
但是,您必須牢記的是,微分并不衡量輸入是否影響輸出。它測量給定輸入變化無窮小的輸出變化。
我們以一個稍微簡單的函式為例:
import jax
import jax.numpy as jnp
A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])
def f(A, theta):
return A[jnp.argmax(theta)]
x = f(A, theta)
dx_dtheta = jax.grad(f, argnums=1)(A, theta)
這里區分的結果f
相對于theta
是全零,出于同樣的原因如上。為什么?如果對 進行無窮小的更改theta
,一般不會影響 的排序順序theta
。因此,您選擇的條目A
不會在 theta 發生無窮小的變化時發生變化,因此關于 theta 的導數為零。
現在,您可能會爭辯說,在某些情況下并非如此:例如,如果 theta 中的兩個值非常接近,那么即使是無限小的擾動一個值也肯定會改變它們各自的等級。這是真的,但由此程序產生的梯度是不確定的(輸出的變化相對于輸入的變化并不平滑)。好訊息是這種不連續性是片面的:如果你在另一個方向擾動,秩沒有變化,梯度是明確定義的。為了避免未定義的梯度,大多數 autodiff 系統將隱式使用這種更安全的導數定義進行基于秩的計算。
結果是當你無限地擾動輸入時,輸出的值不會改變,這是梯度為零的另一種說法。并且這不是 autodiff 的失敗——它是正確的梯度,因為 autodiff 建立在微分的定義上。此外,如果您嘗試在這些不連續點處嘗試更改導數的不同定義,您所能希望的最好結果是未定義的輸出,因此導致零的定義可以說更有用和更正確。
轉載請註明出處,本文鏈接:https://www.uj5u.com/houduan/375446.html