除了一個軸具有不同數量的元素外,當其輸入結構基本保持不變時,是否可以避免重新編譯 JIT 函式?
import jax
@jax.jit
def f(x):
print('recompiling')
return (x 10) * 100
a = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling
b = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready())
c = f(jax.numpy.arange(450000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling. It would be nice if it weren't
要求:pip install jax、jaxlib
uj5u.com熱心網友回復:
不,當您使用不同形狀的陣列呼叫函式時,無法避免重新編譯。從根本上說,JAX 為靜態形狀的輸入和輸出編譯函式,并且使用新形狀的陣列呼叫 JIT 編譯的函式將始終觸發重新編譯。
有一些關于放寬此要求的持續作業(在 JAX 的 github 存盤庫中搜索“動態形狀”),但目前沒有此類 API 可用。
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/367737.html
下一篇:創建類和物件