每當 JAX 即時編譯器必須重新編譯函式時,JAX 中是否有可能收到通知(因為輸入已更改且無法評估快取的編譯版本)?
現在,我使用了一種 hacky 解決方法來通知重新編譯。在當前的實作中,跟蹤器在需要編譯函式時執行一次,因此只有在重新編譯函式時才會執行的副作用是允許的:
import jax
recompilation_count: int = 0
@jax.jit
def func(z):
global recompilation_count
recompilation_count = 1
return z * z 100 / z
func(1)
print(recompilation_count)
func(2)
print(recompilation_count)
func(jax.numpy.arange(10))
print(recompilation_count)
func(jax.numpy.arange(10, 20))
print(recompilation_count)
func(jax.numpy.arange(10) ** 2)
print(recompilation_count)
assert recompilation_count == 2
但是,這是 JAX 實作的內部,因此不能以可靠的方式使用。如果頻繁發生,是否有另一種方式可以通知并可能防止重新編譯函式?
uj5u.com熱心網友回復:
我不相信有任何內置的 API 可以完成您的要求。但目前正在積極討論類似的功能(參見例如https://github.com/google/jax/issues/8655)
但請注意,如果您愿意,可以使用內置方法來跟蹤編譯計數:
import jax
@jax.jit
def f(x):
return x
print(f._cache_size())
# 0
_ = f(jnp.arange(3))
print(f._cache_size())
# 1
_ = f(jnp.arange(3)) # should not trigger a recompilation
print(f._cache_size())
# 1
_ = f(jnp.arange(100)) # should trigger a recompilation
print(f._cache_size())
# 2
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/367736.html