阅读(23.1k) 书签 (0)

TensorFlow:tf.Session函数

2018-01-22 10:37 更新

tf.Session 函数

Session 类

定义在:tensorflow/python/client/session.py.

请参阅指南:运行图>会话管理

用于运行TensorFlow操作的类.

一个Session对象封装了Operation执行对象的环境,并对Tensor对象进行计算.例如:

# Build a graph.
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b

# Launch the graph in a session.
sess = tf.Session()

# Evaluate the tensor `c`.
print(sess.run(c))

session可能拥有的资源,如:tf.Variable,tf.QueueBase和tf.ReaderBase.不再需要时释放这些资源是非常重要的.为此,请在session中调用tf.Session.close方法,或使用session作为上下文管理器.以下两个例子是等价的:

# Using the `close()` method.
sess = tf.Session()
sess.run(...)
sess.close()

# Using the context manager.
with tf.Session() as sess:
  sess.run(...)

ConfigProto协议缓存公开了用于session的各种配置选项.例如,要创建为设备放置使用软约束的session,并记录生成的放置决策,请按如下方式创建session:

# Launch the graph in a session that allows soft device placement and
# logs the placement decisions.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=True))

Session属性

  • graph
    本次session上发布的图表.
  • graph_def
    底层TensorFlow图形的可序列化版本.
    • 函数返回:
      • 包含底层TensorFlow图表中所有操作的节点的graph_pb2.GraphDef原型.
  • sess_str

Session 方法

__init__

__init__(
    target='',
    graph=None,
    config=None
)

创建一个新的TensorFlow session.

如果在构建session时没有指定graph参数,则将在session中启动默认关系图.如果使用多个图(在同一个过程中使用tf.Graph()创建,则必须为每个图使用不同的sessio,但是每个图都可以用于多个sessio中,在这种情况下,将图形显式地传递给sessio构造函数通常更清晰.

方法参数

  • target:(可选)要连接到的执行引擎.默认使用进程内引擎.有关更多示例,请参阅“Distributed TensorFlow”.
  • graph:(可选)将被启动的Graph(如上所述).
  • config:(可选)具有session配置选项的ConfigProto协议缓冲区.

__enter__

__enter__()

__exit__

__exit__(
    exec_type,
    exec_value,
    exec_tb
)

as_default

as_default()

返回使该对象成为默认session的上下文管理器.

与with关键字一起使用来指定在此session中调用tf.Operation.run或tf.Tensor.eval应执行的操作.

c = tf.constant(..)
sess = tf.Session()

with sess.as_default():
  assert tf.get_default_session() is sess
  print(c.eval())

要获取当前的默认session,请使用tf.get_default_session.

注意:退出上下文时,as_default上下文管理器不会关闭session,并且必须显式关闭session.

c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
  print(c.eval())
# ...
with sess.as_default():
  print(c.eval())

sess.close()

或者,您可以使用tf.Session():创建会在退出上下文时自动关闭的session,包括未捕获的异常发生时.

注意,默认session是当前线程的一个属性.如果你创建一个新的线程,并希望在该线程中使用默认的session,则必须明确地添加一个sess.as_default():到该线程的函数.

注意,输入一个sess.as_default():块不会影响当前的默认图形.如果您正在使用多个图表,并且与其sess.graph值不同,则tf.get_default_graph必须明确地输入一个带有sess.graph.as_default():块来创建sess.graph默认图形.

as_default()方法返回:

使用此session作为默认session的上下文管理器.

close

close()

关闭这个session.

调用此方法可释放与session关联的所有资源.

可能引发的异常

  • tf.errors.OpError:如果在关闭TensorFlow session时发生错误,则会有一个子类.

list_devices

list_devices()

列出此session中的可用设备.

devices = sess.list_devices()
for d in devices:
  print(d.name)

列表中的每个元素都具有以下属性:

  • name:具有设备全名的字符串.例如:/job:worker/replica:0/task:3/device:CPU:0
  • device_type:设备的类型(例如CPU,GPU,TPU) 
  • memory_limit:存储设备上可用的最大内存量.注意:取决于设备,可用内存可能会大大减少.

可能引发的异常

  • tf.errors.OpError:如果遇到错误(例如session处于无效状态,或发生网络错误).

list_devices()方法返回:

list_devices()方法将返回session中的设备列表.

make_callable

make_callable(
    fetches,
    feed_list=None,
    accept_options=False
)

返回运行特定步骤的Python可调用对象.

返回的可调用将采取 len (feed_list) 参数,其类型必须是feed_list各自元素的兼容feed值.例如,如果feed_list的元素i是一个tf.Tensor,则返回的可调用的第 i 参数必须是一个 numpy 的 ndarray(或可转化成ndarray的东西)具有匹配元素类型和形状.请参阅tf.Session.run允许的Feed键和值类型的详细信息.

返回的可调用将具有与tf.Session.run(fetches, ...).例如,如果fetches是tf.Tensor ,则可调用将返回一个numpy的ndarray; 如果fetches是一个tf.Operation,它会返回None.

方法参数

  • fetches:要获取的值或值列表.请参阅tf.Session.run允许的获取类型的详细信息.
  • feed_list:(可选)一个feed_dict键列表.请参阅tf.Session.run允许的Feed键类型的详细信息.
  • accept_options:(可选)如果为True,则返回的Callable将是能够接受tf.RunOptions和tf.RunMetadata可选关键字参数options,并且run_metadata分别使用与tf.Session.run相同的语法和语义,这对于某些使用情况很有用(分析和调试),但会导致可测量放缓的Callable的表现.默认为False.

方法返回

一个函数调用将执行由feed_list定义的步骤时,并在此会话中读取的函数.

可能引发的异常

  • TypeError:如果fetches或feed_list不能被解释为tf.Session.run的参数.

partial_run

partial_run(
    handle,
    fetches,
    feed_dict=None
)

通过更多的feed和fetche继续执行.

这是实验性的,可能会有变化.

要使用部分执行,用户首先调用partial_run_setup(),然后是一个序列partial_run().partial_run_setup指定将在随后的partial_run调用中使用的提要和提取列表.

可选feed_dict参数允许调用者覆盖图中张量的值.请参阅run()以获取更多信息.

下面是一个简单的例子:

a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.multiply(r1, c)

h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
res = sess.partial_run(h, r2, feed_dict={c: res})

方法参数

  • handle:部分运行序列的处理器.
  • fetches:单个图形元素,图形元素列表或其值为图形元素或图形元素列表的字典(请参阅“run文档”).
  • feed_dict:将图表元素映射到值的字典(如上所述).

方法返回:

可以是单个值,如果fetches是单个图元素,或者列表值,如果fetches是列表,或者是具有与字典相同的键fetches的字典(请参阅“run文档”).

方法可能引发的异常

  • tf.errors.OpError:其中一个子类出错.

partial_run_setup

partial_run_setup(
    fetches,
    feeds=None
)

为部分运行设置一个带有feed和fetche的图形.

这是实验性的,可能会有变化.

请注意,与运行相反,feeds只能指定图形元素.张量将由随后的partial_run调用提供.

方法参数

  • fetches:一个图形元素,或者一个图元素列表.
  • feeds:一个图形元素,或者一个图元素列表.

方法返回:

局部运行的处理器.

可能引发的异常

  • RuntimeError:如果这Session是无效状态(例如已经关闭).
  • TypeError:如果fetches或者feed_dict键的类型不合适.
  • tf.errors.OpError:如果发生TensorFlow错误,或者它的一个子类.

reset

@staticmethod
reset(
    target,
    containers=None,
    config=None
)

在target上重置资源容器,并关闭所有连接的会话.

资源容器分布在同一个群集target中的所有工作人员.target重置资源容器时,与该容器关联的资源将被清除.尤其是,容器中的所有变量都将变得不确定:它们将失去其值和形状.

注意:(i)reset()目前仅用于分布式会话.(ii)任何名为target的主的session将被关闭.

如果没有提供资源容器,则所有的容器都被重置.

方法参数

  • target:连接到的执行引擎.
  • containers:资源容器名称字符串的列表,如果所有容器都将被重置,则为None.
  • config:(可选)具有配置选项的协议缓冲区.

可能引发的异常

  • tf.errors.OpError:或者如果在重置容器时发生错误,它的一个子类.

run

run(
    fetches,
    feed_dict=None,
    options=None,
    run_metadata=None
)

在fetches中运行操作和计算张量.

此方法运行一个TensorFlow计算的一个“步骤”,通过运行所需的图形片段来执行每个Operation和计算fetches中的每个Tensor,用 feed_dict 中的值替换相应的输入值.

所述fetches参数可以是一个单一的图形元素,或任意嵌套列表、元组、namedtuple、字典、或含有它的叶子图表元素OrderedDict.图形元素可以是以下类型之一:

  • 一个tf.Operation.相应的取值将会是None.
  • 一个tf.Tensor.相应的取值将是一个包含该张量值的numpy ndarray.
  • 一个tf.SparseTensor.相应的取值将是一个tf.SparseTensorValue包含稀疏张量的值.
  • 一个get_tensor_handle操作.相应的取值将是包含该张量句柄的numpy ndarray.
  • A string是图中张量或操作的名称.

run()返回的值具有与fetches参数相同的形状,叶子由TensorFlow返回的相应值替换.

示例:

a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
# 'fetches' can be a singleton
v = session.run(a)
# v is the numpy array [10, 20]
# 'fetches' can be a list.
v = session.run([a, b])
# v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
# 1-D array [1.0, 2.0]
# 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
MyData = collections.namedtuple('MyData', ['a', 'b'])
v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
# v is a dict with
# v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
# 'b' (the numpy array [1.0, 2.0])
# v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
# [10, 20]

可选的feed_dict参数允许调用者在关系图中覆盖张量的值.feed_dict 中的每个键都可以是以下类型之一:

  • 如果键是a tf.Tensor,则值可以是可以转换为与dtype张量相同的Python标量,字符串,列表或numpy ndarray .此外,如果键是a tf.placeholder,则将检查值的形状是否与占位符兼容.
  • 如果键是a tf.SparseTensor,则值应该是a tf.SparseTensorValue.
  • 如果键是Tensors或SparseTensors 的嵌套元组,则该值应该是一个嵌套元组,其结构与映射到上面相应的值相同.

feed_dict 中的每个值必须可转换为相应键的 dtype 的 numpy 数组.

可选options参数需要一个[ RunOptions] 原型.这些选项允许控制此特定步骤的行为(例如,启用跟踪).

可选run_metadata参数需要一个[ RunMetadata] 原型.在适当的时候,这个步骤的非张量输出将被收集在那里.例如,当用户在options打开跟踪时,配置文件信息将被收集到该参数中并传回.

方法参数

  • fetches:单个图形元素,图形元素列表或其值为图元素或图元素列表(如上所述)的字典.
  • feed_dict:将图表元素映射到值的字典(如上所述).
  • options:一个[ RunOptions]协议缓冲区
  • run_metadata:一个[ RunMetadata]协议缓冲区

方法返回:

单个值如果fetches是单个图元素,或者值列表if fetches是列表,或者具有与fetches字典(如上所述)相同的关键字的字典.

可能发生的异常

  • RuntimeError:如果该Session是无效状态(例如已经关闭).
  • TypeError:如果fetches或者feed_dict键的类型不合适.
  • ValueError:如果fetches或者feed_dict键无效或者引用Tensor不存在的键.