Tensorflow实现tensor数值比较 | 您所在的位置:网站首页 › python中的true和false › Tensorflow实现tensor数值比较 |
在进行语义分割的二分类中,需要将预测值大于和小于0.5的logits分别标记为True和False。使用tf.equal(label, 0)只会判断改值是否为1。 使用tf.where(input, a, b)实现这个功能。 其中input是tensor+判断条件,判断得到True和False的一个sensor,它和a、b尺寸一致。 函数作用是将a中对应input中true的位置的元素值不变,其余元素进行替换,替换成b中对应位置的元素值。 tf.ones_like(label) 和 tf.zeros_like(label) 两个函数生成和label形状相同的两个纯1和纯0的tensor,具体实现代码: t1 = tf.constant([-0.1, 0.3, -0.49, -0.02]) ones = tf.ones_like(t1) zeros = tf.zeros_like(t1) t2 = tf.where(t1>0, ones, zeros) equal = tf.cast(t2, tf.bool) cast_s = tf.cast(sign, tf.bool) with tf.Session() as sess: print(sess.run(t2)) print(sess.run(equal)) Output: [0. 1. 0. 0.] [False True False False]
关于使用tf.cond的例子,可以参考博文 博文2
|
CopyRight 2018-2019 实验室设备网 版权所有 |