建立图
一个TensorFlow程序默认是建立一个图的,除了系统自动建图以外,还可以用tf.Graph()手动建立,并做一些其他的操作
如果想要获得程序一开始默认的图,可以使用tf.get_default_graph()函数
如果想要重新建立一张图代替原来的图,可以使用tf.reset_default_graph()函数
注意:在使用tf.reset_default_graph函数时必须保证当前图的资源已经全部释放,否则会报错。例如如果在当前图中使用tf.InteractiveSession函数建立了一个会话,在会话结束时却没有调用close进行关闭,那么再执行tf.reset_default_graph函数时,就会报错。
示例代码如下:
- import tensorflow.compat.v1 as tf
- tf.disable_v2_behavior()
-
- var1 = tf.constant(8.0)
- print("var1:",var1.graph)
-
- mygraph = tf.Graph()
- with mygraph.as_default():
- var2 = tf.constant(9.9)
- print("var2:",var2.graph)
- print("mygraph:",mygraph)
-
- mygraph2 = tf.get_default_graph()
- print("mygraph2:",mygraph2)
-
- tf.reset_default_graph()
- mygraph3 = tf.get_default_graph()
- print("mygraph3:",mygraph3)
获取张量
在图里面可以通过名字得到其对应的元素,使用的是get_tensor_by_name()函数
示例代码如下:
- import tensorflow.compat.v1 as tf
- tf.disable_v2_behavior()
-
- var1 = tf.constant(8.0)
- print("var1:",var1.graph)
-
- mygraph = tf.Graph()
- with mygraph.as_default():
- var2 = tf.constant(9.9)
- print("var2:",var2.graph)
- print("mygraph:",mygraph)
-
- mygraph2 = tf.get_default_graph()
- print("mygraph2:",mygraph2)
-
- tf.reset_default_graph()
- mygraph3 = tf.get_default_graph()
- print("mygraph3:",mygraph3)
-
- t1 = mygraph.get_tensor_by_name(name = var2.name)
- print(t1)
-
- print("var2.name:",var2.name)
获取元素列表
如果想看一下图中的全部元素,可以使用get_operations()函数来实现。
示例代码如下:
- import tensorflow.compat.v1 as tf
- tf.disable_v2_behavior()
-
- var1 = tf.constant(8.0)
- print("var1:",var1.graph)
-
- mygraph = tf.Graph()
- with mygraph.as_default():
- var2 = tf.constant(9.9)
- var3 = tf.constant(11.9)
- print("var2:",var2.graph)
- print("mygraph:",mygraph)
-
- mygraph2 = tf.get_default_graph()
- print("mygraph2:",mygraph2)
-
- tf.reset_default_graph()
- mygraph3 = tf.get_default_graph()
- print("mygraph3:",mygraph3)
-
- t1 = mygraph.get_tensor_by_name(name = var2.name)
- print(t1)
-
- print("var2.name:",var2.name)
-
- t2 = mygraph.get_operations()
- print(t2)
获取对象
使用tf.Graph.as_graph_element(obj,allow_tensor = True,allow_operation = True)函数,可以根据对象来获取元素,即传入的是一个对象,返回一个张量或是一个OP。该函数具有验证和转换功能。
示例代码如下:
- import tensorflow.compat.v1 as tf
- tf.disable_v2_behavior()
-
- var1 = tf.constant(8.0)
- print("var1:",var1.graph)
-
- mygraph = tf.Graph()
- with mygraph.as_default():
- var2 = tf.constant(9.9)
- var3 = tf.constant(11.9)
- print("var2:",var2.graph)
- print("mygraph:",mygraph)
-
- mygraph2 = tf.get_default_graph()
- print("mygraph2:",mygraph2)
-
- tf.reset_default_graph()
- mygraph3 = tf.get_default_graph()
- print("mygraph3:",mygraph3)
-
- t1 = mygraph.get_tensor_by_name(name = var2.name)
- print(t1)
-
- print("var2.name:",var2.name)
-
- t2 = mygraph.get_operations()
- print(t2)
-
- t3 = mygraph.as_graph_element(var2)
- print(t3)
获取节点操作
使用get_operation_by_name()函数
示例代码如下:
- import tensorflow.compat.v1 as tf
- tf.disable_v2_behavior()
-
- mygraph = tf.get_default_graph()
-
- x1 = tf.constant([[2.3,6.6]])
- x2 = tf.constant([[5.3],[9.6]])
- tensor1 = tf.matmul(x1,x2,name = "op")
-
- test2 = mygraph.get_operation_by_name(tensor1.op.name)
- print(test2)