我一直在使用TensorFlow中矩阵乘法的介绍性示例。
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
当我打印乘积时,它显示为一个张量对象:
<tensorflow.python.framework.ops.Tensor object at 0x10470fcd0>
但是我怎么知道产品的价值呢?
下面的方法不起作用:
print product
Tensor("MatMul:0", shape=TensorShape([Dimension(1), Dimension(1)]), dtype=float32)
我知道图在会话上运行,但是没有任何方法可以检查张量对象的输出而不在会话中运行图吗?
import tensorflow as tf
sess = tf.InteractiveSession()
x = [[1.,2.,1.],[1.,1.,1.]]
y = tf.nn.softmax(x)
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
print(product.eval())
tf.reset_default_graph()
sess.close()
我发现即使在阅读了所有的答案之后,我也不容易理解需要什么,直到我执行了这个。TensofFlow对我来说也是新的。
def printtest():
x = tf.constant([1.0, 3.0])
x = tf.Print(x,[x],message="Test")
init = (tf.global_variables_initializer(), tf.local_variables_initializer())
b = tf.add(x, x)
with tf.Session() as sess:
sess.run(init)
print(sess.run(b))
sess.close()
但是您仍然可能需要执行会话返回的值。
def printtest():
x = tf.constant([100.0])
x = tf.Print(x,[x],message="Test")
init = (tf.global_variables_initializer(), tf.local_variables_initializer())
b = tf.add(x, x)
with tf.Session() as sess:
sess.run(init)
c = sess.run(b)
print(c)
sess.close()
import tensorflow as tf
sess = tf.InteractiveSession()
x = [[1.,2.,1.],[1.,1.,1.]]
y = tf.nn.softmax(x)
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
print(product.eval())
tf.reset_default_graph()
sess.close()
在Tensorflow 1.x中
import tensorflow as tf
tf.enable_eager_execution()
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
#print the product
print(product) # tf.Tensor([[12.]], shape=(1, 1), dtype=float32)
print(product.numpy()) # [[12.]]
用Tensorflow 2。X,默认开启急切模式。因此下面的代码与TF2.0一起工作。
import tensorflow as tf
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
#print the product
print(product) # tf.Tensor([[12.]], shape=(1, 1), dtype=float32)
print(product.numpy()) # [[12.]]