TensorFlow读取批处理
2018-09-20 17:37 更新
tf.contrib.data.read_batch_features
read_batch_features ( file_pattern , batch_size , features , reader , reader_args = None , randomize_input = True , num_epochs = None , capacity = 10000 )
定义在:tensorflow/contrib/data/python/ops/dataset_ops.py.
读取示例的批处理.
更多的例子如下:
serialized_examples = [
features {
feature { key: "age" value { int64_list { value: [ 0 ] } } }
feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
},
features {
feature { key: "age" value { int64_list { value: [] } } }
feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
}
]
我们可以使用参数:
features: {
"age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
"gender": FixedLenFeature([], dtype=tf.string),
"kws": VarLenFeature(dtype=tf.string),
}
预期的输出是:
{
"age": [[0], [-1]],
"gender": [["f"], ["f"]],
"kws": SparseTensor(
indices=[[0, 0], [0, 1], [1, 0]],
values=["code", "art", "sports"]
dense_shape=[2, 2]),
}
ARGS:
- file_pattern:包含示例记录的文件路径的文件或模式列表.参见图案规则的 tf.gfile.Glob.
- batch_size:一个整数,表示该数据集的连续元素的个数,并在单个批处理中合并.
- features:FixedLenFeature 或 VarLenFeature 值的字典映射特征键.见 tf. parse_example.
- reader:可以用文件名张量和 (可选) reader_args 调用的函数或类, 并返回序列化示例的数据集.
- reader_args:要传递给读取器类的其他参数.
- randomize_input:输入是否应该是随机的.
- num_epochs:指定要通过数据集读取的次数的整数.如果没有, 则永远循环遍历数据集.
- capacity:ShuffleDataset 的容量.大的容量能确保更好的洗牌,但会增加内存使用和启动时间.
返回:
从功能键到 Tensor 或 SparseTensor 对象的字典.