考慮這個張量。
a = tf.constant([0,1,2,3,5,6,7,8,9,10,19,20,21,22,23,24])
我想將它分成 3 個張量(對于這個特定示例),其中包含數字緊鄰的組。預期的輸出將是:
output_tensor = [ [0,1,2,3], [5,6,7,8,9,10], [19,20,21,22,23,24] ]
關于如何做到這一點的任何想法?是否有張量流 .math 方法可以幫助有效地做到這一點?我什么也找不到。
uj5u.com熱心網友回復:
對于提供的示例,split應該可以作業:
a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
print(tf.split(a, [4, 6, 6]))
輸出:
[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([ 5, 6, 7, 8, 9, 10], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([19, 20, 21, 22, 23, 24], dtype=int32)>]
第二個引數指示沿分割軸的每個輸出張量的大小(默認為 0) - 因此在這種情況下,第一個張量的大小為 4,第二個張量的大小為 6,第三個張量的大小為 6。或者,可以提供一個 int ,只要您拆分的軸上的張量大小可以被該值整除。在這種情況下,3 不起作用(16/3 = 5.3333),但 4 會:
a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
print(tf.split(a, 4))
輸出:
[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([5, 6, 7, 8], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 9, 10, 19, 20], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([21, 22, 23, 24], dtype=int32)>]
假設數字連續位置的描述是未知的,可以使用相鄰差異有效地計算索引并提供給tf.split
:
def compute_split_indices(x):
adjacent_diffs = x[1:] - x[:-1] # compute adjacent differences
indices_where_not_continuous = tf.where(adjacent_diffs > 1) 1
splits = tf.concat([indices_where_not_continuous[:1], indices_where_not_continuous[1:] -
indices_where_not_continuous[:-1]], axis=0) # compute split sizes from the indices
splits_as_ints = [split.numpy().tolist()[0] for split in splits] # convert to a list of integers for ease of use
final_split_sizes = splits_as_ints [len(x) - sum(splits_as_ints)] # account for the rest of the tensor
return final_split_sizes
if __name__ == "__main__":
a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
splits = compute_split_indices(a)
print(tf.split(a, splits))
輸出:
[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([ 5, 6, 7, 8, 9, 10], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([19, 20, 21, 22, 23, 24], dtype=int32)>]
請注意,輸出與我們明確提供時相同[4, 6, 6]
。
轉載請註明出處,本文鏈接:https://www.uj5u.com/ruanti/482763.html