阅读(13.7k) 书签 (0)

使用Iterator类

2018-09-20 17:09 更新

tf.contrib.data.Iterator

Iterator 类

定义在:tensorflow/contrib/data/python/ops/dataset_ops.py.

表示通过数据集进行迭代的状态.

属性

initializer

应运行一个 tf.Operation 来初始化这个迭代器.

返回:

返回运行一个 tf.Operation 来初始化该迭代器

注意:

  • ValueError:如果此迭代器自动初始化自身.
output_shapes

返回此迭代器元素的每个组件的形状.

返回:

tf.TensorShape 对象的嵌套结构,对应于该迭代器元素的每个组件.

output_types

返回此迭代器元素的每个组件的类型.

返回:

tf.DType 对象的嵌套结构,对应于该迭代器元素的每个组件.

方法

__init__

__init__ ( 
    iterator_resource , 
    initializer , 
    output_types ,
    output_shapes
)

从给定的迭代器资源创建一个新的迭代器.

注意(mrry):大多数用户不会直接调用这个初始化程序,而是使用 Iterator.from_dataset()或Dataset.make_one_shot_iterator().

ARGS:

  • iterator_resource:代表迭代器的 tf.resource 标量 tf.Tensor.
  • initializer:应运行一个 tf.Operation 以初始化这个迭代器.
  • output_types:tf.DType 对象的嵌套结构,对应于该迭代器元素的每个组件.
  • output_shapes:tf.TensorShape 对象的嵌套结构,对应于该数据集的元素的每个组件.

dispose_op

dispose_op ( name = None )

返回一个 tf.Operation 销毁该迭代器.

返回的操作可用于释放此迭代器消耗的任何资源,而不关闭会话.

ARGS:

  • name:(可选)创建操作的名称.

返回:

返回一个 tf.Operation.

from_dataset

from_dataset ( 
    dataset , 
    shared_name = None
 ) 

从给定的 Dataset 创建一个新的、未初始化的 Iterator.

要初始化这个迭代器,你必须运行它的 initializer,如下所示:

dataset = ...
iterator = Iterator.from_dataset(dataset)
# ...
sess.run(iterator.initializer)

ARGS:

  • dataset:一个 Dataset 对象
  • shared_name:(可选)如果非空,则该迭代器将在共享相同设备的多个会话(例如,使用远程服务器)时在给定名称下共享.

返回:

返回一个 Iterator.

from_string_handle

from_string_handle ( 
    string_handle , 
    output_types , 
    output_shapes = None
 )

根据给定的句柄创建一个新的、未初始化的 Iterator .

该方法允许您定义“可馈送”的迭代器,您可以通过在 tf.Session.run 调用中提供值来在具体的迭代器之间进行选择.在这种情况下,string_handle 会是一个 tf.placeholder,你会在每个步骤中使用 tf.contrib.data.Iterator.string_handle 的值满足它.

例如,如果有两个迭代器在训练数据集和一个测试数据集中标记了当前位置,则可以选择在每个步骤中使用哪种方法,如下所示:

train_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
train_iterator_handle = sess.run(train_iterator.string_handle())

test_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
test_iterator_handle = sess.run(test_iterator.string_handle())

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.contrib.data.Iterator.from_string_handle(
    handle, train_iterator.output_types)

next_element = iterator.get_next()
loss = f(next_element)

train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})

ARGS:

  • string_handle:一个 tf.string 类型的标量 tf.Tensor,用于计算该 Iterator.string_handle()方法生成的句柄.
  • output_types:tf.DType 对象的嵌套结构,对应于该迭代器元素的每个组件.
  • output_shapes:(可选)tf.TensorShape 对象的嵌套结构,与该数据集的元素的每个组件对应.如果省略,每个组件将具有无约束的形状.

返回:

返回一个 Iterator.

from_structure

from_structure(
    output_types,
    output_shapes=None,
    shared_name=None
)

使用给定的结构创建一个新的、未初始化的 Iterator .

此迭代器构造方法可用于创建可重用多个不同数据集的迭代器.

返回的迭代器未绑定到特定的数据集,它没有初始值设定项.要初始化迭代器,请运行 Iterator.make_initializer (dataset) 返回的操作.

以下是一个例子:

iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

dataset_range = Dataset.range(10)
range_initializer = iterator.make_initializer(dataset_range)

dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
evens_initializer = iterator.make_initializer(dataset_evens)

# Define a model based on the iterator; in this example, the model_fn
# is expected to take scalar tf.int64 Tensors as input (see
# the definition of 'iterator' above).
prediction, loss = model_fn(iterator.get_next())

# Train for `num_epochs`, where for each epoch, we first iterate over
# dataset_range, and then iterate over dataset_evens.
for _ in range(num_epochs):
  # Initialize the iterator to `dataset_range`
  sess.run(range_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

  # Initialize the iterator to `dataset_evens`
  sess.run(evens_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

ARGS:

  • output_types:tf.DType 对象的嵌套结构,对应于该迭代器元素的每个组件.
  • output_shapes:(可选)tf.TensorShape 对象的嵌套结构,与该数据集的元素的每个组件对应.如果省略,每个组件将具有无约束的形状.
  • shared_name:(可选)如果非空,则该迭代器将在共享相同设备的多个会话(例如,使用远程服务器)时在给定名称下共享.

返回:

返回一个 Iterator.

注意:

  • TypeError:如果 output_shapes 和 output_types 的结构不相同.

get_next

get_next ( name = None )

返回 tf.Tensor 的嵌套结构,包含下一个元素.

ARGS:

  • name:(可选)创建的操作的名称.

返回:

返回 tf.Tensor 对象的嵌套结构.

make_initializer

make_initializer (dataset)

返回一个 tf.Operation 初始化此数据集的迭代器.

ARGS:

  • dataset:一个 Dataset 与此迭代器具有兼容的结构.

返回:

返回可以在给定的数据集上运行的一个 tf.Operation 以初始化该迭代器.

注意:

  • TypeError:如果数据集和这个迭代器没有兼容的元素结构.

string_handle

string_handle ( name = None )

返回一个字符串值的 tf.Tensor 来表示此迭代器.

ARGS:

  • name:(可选)创建的操作的名称.

返回:

一个 tf.string 类型的标量 tf.Tensor.