TensorFlow:tf.Session函数
tf.Session 函数
Session 类
定义在:tensorflow/python/client/session.py.
请参阅指南:运行图>会话管理
用于运行TensorFlow操作的类.
一个Session对象封装了Operation执行对象的环境,并对Tensor对象进行计算.例如:
# Build a graph.
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
# Launch the graph in a session.
sess = tf.Session()
# Evaluate the tensor `c`.
print(sess.run(c))
session可能拥有的资源,如:tf.Variable,tf.QueueBase和tf.ReaderBase.不再需要时释放这些资源是非常重要的.为此,请在session中调用tf.Session.close方法,或使用session作为上下文管理器.以下两个例子是等价的:
# Using the `close()` method.
sess = tf.Session()
sess.run(...)
sess.close()
# Using the context manager.
with tf.Session() as sess:
sess.run(...)
该ConfigProto协议缓存公开了用于session的各种配置选项.例如,要创建为设备放置使用软约束的session,并记录生成的放置决策,请按如下方式创建session:
# Launch the graph in a session that allows soft device placement and
# logs the placement decisions.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
log_device_placement=True))
Session属性
- graph
本次session上发布的图表. - graph_def
底层TensorFlow图形的可序列化版本. - 函数返回:
- 包含底层TensorFlow图表中所有操作的节点的graph_pb2.GraphDef原型.
- sess_str
Session 方法
__init__
__init__(
target='',
graph=None,
config=None
)
创建一个新的TensorFlow session.
如果在构建session时没有指定graph参数,则将在session中启动默认关系图.如果使用多个图(在同一个过程中使用tf.Graph()创建,则必须为每个图使用不同的sessio,但是每个图都可以用于多个sessio中,在这种情况下,将图形显式地传递给sessio构造函数通常更清晰.
方法参数
- target:(可选)要连接到的执行引擎.默认使用进程内引擎.有关更多示例,请参阅“Distributed TensorFlow”.
- graph:(可选)将被启动的Graph(如上所述).
- config:(可选)具有session配置选项的ConfigProto协议缓冲区.
__enter__
__enter__()
__exit__
__exit__(
exec_type,
exec_value,
exec_tb
)
as_default
as_default()
返回使该对象成为默认session的上下文管理器.
与with关键字一起使用来指定在此session中调用tf.Operation.run或tf.Tensor.eval应执行的操作.
c = tf.constant(..)
sess = tf.Session()
with sess.as_default():
assert tf.get_default_session() is sess
print(c.eval())
要获取当前的默认session,请使用tf.get_default_session.
注意:退出上下文时,as_default上下文管理器不会关闭session,并且必须显式关闭session.
c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
print(c.eval())
# ...
with sess.as_default():
print(c.eval())
sess.close()
或者,您可以使用tf.Session():创建会在退出上下文时自动关闭的session,包括未捕获的异常发生时.
注意,默认session是当前线程的一个属性.如果你创建一个新的线程,并希望在该线程中使用默认的session,则必须明确地添加一个sess.as_default():到该线程的函数.
注意,输入一个sess.as_default():块不会影响当前的默认图形.如果您正在使用多个图表,并且与其sess.graph值不同,则tf.get_default_graph必须明确地输入一个带有sess.graph.as_default():块来创建sess.graph默认图形.
as_default()方法返回:
使用此session作为默认session的上下文管理器.
close
close()
关闭这个session.
调用此方法可释放与session关联的所有资源.
可能引发的异常
- tf.errors.OpError:如果在关闭TensorFlow session时发生错误,则会有一个子类.
list_devices
list_devices()
列出此session中的可用设备.
devices = sess.list_devices()
for d in devices:
print(d.name)
列表中的每个元素都具有以下属性:
- name:具有设备全名的字符串.例如:/job:worker/replica:0/task:3/device:CPU:0
- device_type:设备的类型(例如CPU,GPU,TPU)
- memory_limit:存储设备上可用的最大内存量.注意:取决于设备,可用内存可能会大大减少.
可能引发的异常
- tf.errors.OpError:如果遇到错误(例如session处于无效状态,或发生网络错误).
list_devices()方法返回:
list_devices()方法将返回session中的设备列表.
make_callable
make_callable(
fetches,
feed_list=None,
accept_options=False
)
返回运行特定步骤的Python可调用对象.
返回的可调用将采取 len (feed_list) 参数,其类型必须是feed_list各自元素的兼容feed值.例如,如果feed_list的元素i是一个tf.Tensor,则返回的可调用的第 i 参数必须是一个 numpy 的 ndarray(或可转化成ndarray的东西)具有匹配元素类型和形状.请参阅tf.Session.run允许的Feed键和值类型的详细信息.
返回的可调用将具有与tf.Session.run(fetches, ...).例如,如果fetches是tf.Tensor ,则可调用将返回一个numpy的ndarray; 如果fetches是一个tf.Operation,它会返回None.
方法参数
- fetches:要获取的值或值列表.请参阅tf.Session.run允许的获取类型的详细信息.
- feed_list:(可选)一个feed_dict键列表.请参阅tf.Session.run允许的Feed键类型的详细信息.
- accept_options:(可选)如果为True,则返回的Callable将是能够接受tf.RunOptions和tf.RunMetadata可选关键字参数options,并且run_metadata分别使用与tf.Session.run相同的语法和语义,这对于某些使用情况很有用(分析和调试),但会导致可测量放缓的Callable的表现.默认为False.
方法返回
一个函数调用将执行由feed_list定义的步骤时,并在此会话中读取的函数.
可能引发的异常
- TypeError:如果fetches或feed_list不能被解释为tf.Session.run的参数.
partial_run
partial_run(
handle,
fetches,
feed_dict=None
)
通过更多的feed和fetche继续执行.
这是实验性的,可能会有变化.
要使用部分执行,用户首先调用partial_run_setup(),然后是一个序列partial_run().partial_run_setup指定将在随后的partial_run调用中使用的提要和提取列表.
可选feed_dict参数允许调用者覆盖图中张量的值.请参阅run()以获取更多信息.
下面是一个简单的例子:
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.multiply(r1, c)
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
res = sess.partial_run(h, r2, feed_dict={c: res})
方法参数
- handle:部分运行序列的处理器.
- fetches:单个图形元素,图形元素列表或其值为图形元素或图形元素列表的字典(请参阅“run文档”).
- feed_dict:将图表元素映射到值的字典(如上所述).
方法返回:
可以是单个值,如果fetches是单个图元素,或者列表值,如果fetches是列表,或者是具有与字典相同的键fetches的字典(请参阅“run文档”).
方法可能引发的异常
- tf.errors.OpError:其中一个子类出错.
partial_run_setup
partial_run_setup(
fetches,
feeds=None
)
为部分运行设置一个带有feed和fetche的图形.
这是实验性的,可能会有变化.
请注意,与运行相反,feeds只能指定图形元素.张量将由随后的partial_run调用提供.
方法参数
- fetches:一个图形元素,或者一个图元素列表.
- feeds:一个图形元素,或者一个图元素列表.
方法返回:
局部运行的处理器.
可能引发的异常
- RuntimeError:如果这Session是无效状态(例如已经关闭).
- TypeError:如果fetches或者feed_dict键的类型不合适.
- tf.errors.OpError:如果发生TensorFlow错误,或者它的一个子类.
reset
@staticmethod
reset(
target,
containers=None,
config=None
)
在target上重置资源容器,并关闭所有连接的会话.
资源容器分布在同一个群集target中的所有工作人员.target重置资源容器时,与该容器关联的资源将被清除.尤其是,容器中的所有变量都将变得不确定:它们将失去其值和形状.
注意:(i)reset()目前仅用于分布式会话.(ii)任何名为target的主的session将被关闭.
如果没有提供资源容器,则所有的容器都被重置.
方法参数
- target:连接到的执行引擎.
- containers:资源容器名称字符串的列表,如果所有容器都将被重置,则为None.
- config:(可选)具有配置选项的协议缓冲区.
可能引发的异常
- tf.errors.OpError:或者如果在重置容器时发生错误,它的一个子类.
run
run(
fetches,
feed_dict=None,
options=None,
run_metadata=None
)
在fetches中运行操作和计算张量.
此方法运行一个TensorFlow计算的一个“步骤”,通过运行所需的图形片段来执行每个Operation和计算fetches中的每个Tensor,用 feed_dict 中的值替换相应的输入值.
所述fetches参数可以是一个单一的图形元素,或任意嵌套列表、元组、namedtuple、字典、或含有它的叶子图表元素OrderedDict.图形元素可以是以下类型之一:
- 一个tf.Operation.相应的取值将会是None.
- 一个tf.Tensor.相应的取值将是一个包含该张量值的numpy ndarray.
- 一个tf.SparseTensor.相应的取值将是一个tf.SparseTensorValue包含稀疏张量的值.
- 一个get_tensor_handle操作.相应的取值将是包含该张量句柄的numpy ndarray.
- A string是图中张量或操作的名称.
run()返回的值具有与fetches参数相同的形状,叶子由TensorFlow返回的相应值替换.
示例:
a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
# 'fetches' can be a singleton
v = session.run(a)
# v is the numpy array [10, 20]
# 'fetches' can be a list.
v = session.run([a, b])
# v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
# 1-D array [1.0, 2.0]
# 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
MyData = collections.namedtuple('MyData', ['a', 'b'])
v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
# v is a dict with
# v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
# 'b' (the numpy array [1.0, 2.0])
# v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
# [10, 20]
可选的feed_dict参数允许调用者在关系图中覆盖张量的值.feed_dict 中的每个键都可以是以下类型之一:
- 如果键是a tf.Tensor,则值可以是可以转换为与dtype张量相同的Python标量,字符串,列表或numpy ndarray .此外,如果键是a tf.placeholder,则将检查值的形状是否与占位符兼容.
- 如果键是a tf.SparseTensor,则值应该是a tf.SparseTensorValue.
- 如果键是Tensors或SparseTensors 的嵌套元组,则该值应该是一个嵌套元组,其结构与映射到上面相应的值相同.
feed_dict 中的每个值必须可转换为相应键的 dtype 的 numpy 数组.
可选options参数需要一个[ RunOptions] 原型.这些选项允许控制此特定步骤的行为(例如,启用跟踪).
可选run_metadata参数需要一个[ RunMetadata] 原型.在适当的时候,这个步骤的非张量输出将被收集在那里.例如,当用户在options打开跟踪时,配置文件信息将被收集到该参数中并传回.
方法参数
- fetches:单个图形元素,图形元素列表或其值为图元素或图元素列表(如上所述)的字典.
- feed_dict:将图表元素映射到值的字典(如上所述).
- options:一个[ RunOptions]协议缓冲区
- run_metadata:一个[ RunMetadata]协议缓冲区
方法返回:
单个值如果fetches是单个图元素,或者值列表if fetches是列表,或者具有与fetches字典(如上所述)相同的关键字的字典.
可能发生的异常
- RuntimeError:如果该Session是无效状态(例如已经关闭).
- TypeError:如果fetches或者feed_dict键的类型不合适.
- ValueError:如果fetches或者feed_dict键无效或者引用Tensor不存在的键.