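I have an ndarray of shape (10, 6). For each column, I want to count how many consecutive values, starting from the first row, have the same sign as that column's first element (the question originally wrote [:,0], but the code compares against row 0 of each column). The count carries the sign of that first element. My code is below, where generate_data produces demo data and get_result is the production code, which needs to run tens of millions of times: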
import numpy as np

rand = np.random.default_rng(seed=0)


def generate_data() -> np.ndarray:
    data = rand.uniform(-1, 1, size=(10, 6))
    return data


def get_result(data) -> np.ndarray:
    dim2 = data.shape[1]
    result = np.zeros(dim2, )
    data1 = np.sign(data)
    for i in range(dim2):
        a = data1[:, i]
        b = a[0]  # sign of the first element in this column
        if b == 0:
            continue
        count = 1
        for j in range(1, a.shape[0]):
            if a[j] != b:
                result[i] = count * b
                break
            count += 1  # one more leading value with the same sign
        result[i] = count * b
    return result


def main() -> None:
    data = generate_data()
    print(data)
    result = get_result(data)
    print(result)
    return


if __name__ == '__main__':
    main()
My data looks like this:
print(data)
[[ 0.27392337 -0.46042657 -0.91805295 -0.96694473 0.62654048 0.82551115]
[ 0.21327155 0.45899312 0.08724998 0.87014485 0.63170711 -0.994523 ]
[ 0.71480855 -0.93282885 0.45931089 -0.64868876 0.72635784 0.08292244]
[-0.40057622 -0.15462556 -0.94336066 -0.75143345 0.34124883 0.29437902]
[ 0.23077022 -0.23264489 0.99441987 0.96167068 0.37108397 0.30091855]
[ 0.37689346 -0.22215715 -0.72980699 0.44297668 0.05070864 -0.37951625]
[-0.02832928 0.77897567 0.86808703 -0.28440961 0.14305966 -0.35626122]
[ 0.18860006 -0.32417755 -0.216762 0.7805487 -0.54568481 0.24637429]
[-0.83196931 0.6652883 0.57419661 -0.52126111 0.75296846 -0.88286393]
[-0.32776588 -0.69944107 -0.09932127 0.59264854 -0.53871558 -0.8959574 ]]
The result I want to produce is:
print(result)
[ 3. -1. -1. -1. 7. 1.]
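As a cross-check on the expected output, the same quantity can probably be written without explicit loops. The following is only a sketch, not the poster's code: the name get_result_vectorized is made up here, and the approach (a cumulative logical AND down each column) is a swapped-in technique. For an array this small it is not obvious that it would be faster than the loop, so that would need to be measured.

```python
import numpy as np


def get_result_vectorized(data: np.ndarray) -> np.ndarray:
    """Signed length of the leading same-sign run in each column (sketch)."""
    signs = np.sign(data)
    first = signs[0]                     # sign of row 0 in every column
    same = signs == first                # True where a value matches that sign
    # Cumulative AND down the rows: stays True until the first mismatch.
    run = np.logical_and.accumulate(same, axis=0)
    # Run length times the leading sign; columns starting with 0 give 0.
    return run.sum(axis=0) * first


# Small hand-made example (hypothetical data, not the poster's array).
demo = np.array([[1.0, -1.0, 0.0, 2.0],
                 [2.0, 3.0, 1.0, 1.0],
                 [-1.0, -2.0, 1.0, 3.0]])
print(get_result_vectorized(demo))  # expected values: 2, -1, 0, 3
```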
On my machine, timing get_result() gives:
%timeit get_result(data)
23.2 μs ± 3.82 μs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Answer:
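If you want to speed up the iteration, I strongly recommend njit from numba. I think the simplest way to do this is to put the njit decorator in front of the function. Since you only use arrays with numeric data types, you don't need to change anything else in your code. That alone can significantly reduce the time your loops take.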
import numpy as np
from numba import njit

rand = np.random.default_rng(seed=0)


def generate_data() -> np.ndarray:
    data = rand.uniform(-1, 1, size=(10, 6))
    return data


@njit
def get_result_njit(data: np.ndarray) -> np.ndarray:
    dim2 = data.shape[1]
    result = np.zeros(dim2, )
    data1 = np.sign(data)
    for i in range(dim2):
        a = data1[:, i]
        b = a[0]  # sign of the first element in this column
        if b == 0:
            continue
        count = 1
        for j in range(1, a.shape[0]):
            if a[j] != b:
                result[i] = count * b
                break
            count += 1  # one more leading value with the same sign
        result[i] = count * b
    return result
Comparison of the function with and without the njit decorator:
%timeit get_result_njit(data)
-----------------------------------------------------------------------------
1.21 μs ± 92.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
-----------------------------------------------------------------------------
%timeit get_result(data)
-----------------------------------------------------------------------------
15.3 μs ± 252 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
-----------------------------------------------------------------------------
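The %timeit lines above are IPython magics. Outside IPython, a rough equivalent with the standard timeit module might look like the sketch below. The helper name time_per_call_us is made up for illustration, and np.sign is used only as a stand-in for the function being measured. The warm-up call matters for an njit function, because its first call includes compilation time.

```python
import timeit

import numpy as np


def time_per_call_us(func, data, number=10_000, repeat=5):
    """Best-of-`repeat` average time per call, in microseconds (sketch)."""
    func(data)  # warm-up; for an njit function this triggers compilation
    totals = timeit.repeat(lambda: func(data), number=number, repeat=repeat)
    return min(totals) / number * 1e6


# Stand-in workload; in practice pass get_result or get_result_njit instead.
demo = np.ones((10, 6))
print(f"{time_per_call_us(np.sign, demo):.2f} us per call")
```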
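Note: this is definitely not the limit. I believe you can still reduce the runtime further, especially if you make fuller use of numba. But I think this is the simplest way to get a speedup. If you need more performance, be sure to look into numba.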
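One possible direction, offered only as a sketch and not as the answerer's method: inside a compiled loop there is no need to build the full np.sign array first, since signs can be compared element by element and the loop can stop at the first mismatch. The function name leading_sign_run is made up here. The sketch assumes the input contains no NaN values, and the try/except fallback is an added convenience so the code still runs (uncompiled) where numba is not installed. Whether it is actually faster would need to be measured.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs.
    def njit(func):
        return func


@njit
def leading_sign_run(data):
    """Signed length of the leading same-sign run per column (assumes no NaN)."""
    n_rows, n_cols = data.shape
    result = np.zeros(n_cols)
    for i in range(n_cols):
        first = data[0, i]
        if first == 0:
            continue  # leave 0 for columns that start with zero
        positive = first > 0
        count = 1
        for j in range(1, n_rows):
            x = data[j, i]
            # Stop at a zero or at the first value with the opposite sign.
            if x == 0 or (x > 0) != positive:
                break
            count += 1
        result[i] = count if positive else -count
    return result


# Small hand-made example (hypothetical data, not the poster's array).
demo = np.array([[1.0, -1.0, 0.0, 2.0],
                 [2.0, 3.0, 1.0, 1.0],
                 [-1.0, -2.0, 1.0, 3.0]])
print(leading_sign_run(demo))  # expected values: 2, -1, 0, 3
```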