阅读(13.2k) 书签 (0)

TensorFlow读取批处理

2018-09-20 17:37 更新

tf.contrib.data.read_batch_features

read_batch_features ( 
    file_pattern , 
    batch_size , 
    features , 
    reader , 
    reader_args = None , 
    randomize_input = True , 
    num_epochs = None , 
    capacity = 10000 
)

定义在:tensorflow/contrib/data/python/ops/dataset_ops.py.

读取示例的批处理.

更多的例子如下:

serialized_examples = [
  features {
    feature { key: "age" value { int64_list { value: [ 0 ] } } }
    feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
    feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
  },
  features {
    feature { key: "age" value { int64_list { value: [] } } }
    feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
    feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
  }
]

我们可以使用参数:

features: {
  "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
  "gender": FixedLenFeature([], dtype=tf.string),
  "kws": VarLenFeature(dtype=tf.string),
}

预期的输出是:

{
  "age": [[0], [-1]],
  "gender": [["f"], ["f"]],
  "kws": SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=["code", "art", "sports"]
    dense_shape=[2, 2]),
}

ARGS:

  • file_pattern:包含示例记录的文件路径的文件或模式列表.参见图案规则的 tf.gfile.Glob.
  • batch_size:一个整数,表示该数据集的连续元素的个数,并在单个批处理中合并.
  • features:FixedLenFeature 或 VarLenFeature 值的字典映射特征键.见 tf. parse_example.
  • reader:可以用文件名张量和 (可选) reader_args 调用的函数或类, 并返回序列化示例的数据集.
  • reader_args:要传递给读取器类的其他参数.
  • randomize_input:输入是否应该是随机的.
  • num_epochs:指定要通过数据集读取的次数的整数.如果没有, 则永远循环遍历数据集.
  • capacity:ShuffleDataset 的容量.大的容量能确保更好的洗牌,但会增加内存使用和启动时间.

返回:

从功能键到 Tensor 或 SparseTensor 对象的字典.