tf.split()函数的用法 您所在的位置:网站首页 tfsplit tf.split()函数的用法

tf.split()函数的用法

#tf.split()函数的用法| 来源: 网络整理| 查看: 265

在tensorflow的代码里经常看到tf.split()这个函数,我们来看看这个具体用法

tf.split( value, num_or_size_splits, axis=0, num=None, name='split' )

把一个张量划分成几个子张量:

value:准备切分的张量 num_or_size_splits:准备切成几份 axis : 准备在第几个维度上进行切割

其中分割方式分为两种

如果num_or_size_splits传入的 是一个整数,那直接在axis=D这个维度上把张量平均切分成几个小张量

如果num_or_size_splits传入的是一个向量(这里向量各个元素的和要跟原本这个维度的数值相等)就根据这个向量有几个元素分为几项)举个例子

# 张量为(5, 30) # 这个时候5是axis=0, 30是axis=1,如果要在axis=1这个维度上把这个张量拆分成三个子张量 # 传入向量时 split0, split1, split2 = tf.split(value, [4, 15, 11], 1) tf.shape(split0) # [5, 4] tf.shape(split1) # [5, 15] tf.shape(split2) # [5, 11] # 传入整数时 split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1) tf.shape(split0) # [5, 10]

在来个详细的例子:

import tensorflow as tf value = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] print('axis=0时,拆分....') split0, split1, split2 = tf.split(value, [1, 1, 1], 0) with tf.Session() as sess: print(sess.run(split0)) print("------------") print(sess.run(split1)) print("------------") print(sess.run(split2)) print('axis=1时,拆分....') split0, split1, split2 = tf.split(value, [1, 2, 1], 1) with tf.Session() as sess: print(sess.run(split0)) print("------------") print(sess.run(split1)) print("------------") print(sess.run(split2))

运行结果:

[[1 2 3 4]] ------------ [[5 6 7 8]] ------------ [[ 9 10 11 12]] axis=1时,拆分.... [[1] [5] [9]] ------------ [[ 2 3] [ 6 7] [10 11]] ------------ [[ 4] [ 8] [12]]


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有