我有兩個 numpy 陣列,例如:
a = [False, False, False, False, False, True, False, False]
b = [1, 2, 3, 4, 5, 6, 7, 8]
我需要求和b
,而不是整個陣列,而只是直到具有等效索引的元素a
為True
換句話說,我想做 1 2 3 4 5=15 而不是 1 2 3 4 5 6 7 8=36
我需要一個有效的解決方案,我想我需要屏蔽所有元素,b
然后將它們設為 0 True
。a
旁注:我的代碼在 jax.numpy 中,而不是原始的 numpy 中,但我想這并不重要。
uj5u.com熱心網友回復:
你可以做一個累計
np.sum(b[np.cumsum(a)==0])
uj5u.com熱心網友回復:
我建議將陣列轉換為串列,.tolist()
然后應用.index()
以獲取第一個True
:的索引i = a.tolist().index(True)
。然后簡單的切片和求和:total = numpy.sum(b[:i])
uj5u.com熱心網友回復:
我可以想到兩種方法:你可以通過構造一個掩碼來做到這一點cumsum
(這也適用于常規 numpy):
a = jnp.array([False, False, False, False, False, True, False, False])
b = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
mask = a.cumsum() == 0
b.sum(where=mask) # 15
或者您可以找到第一個 True 索引jnp.where
(請注意,該size
引數僅存在于 JAX 版本中jnp.where
,而不存在于 numpy 中):
idx = jnp.where(a, size=1)[0][0]
b[:idx].sum() # 15
您可能會做一些微基準測驗來確定哪個對您關心的陣列大小更有效。
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/439816.html