阅读(8.4k) 书签 (0)

TensorFlow函数教程:tf.nn.ctc_greedy_decoder

2019-01-31 13:45 更新

tf.nn.ctc_greedy_decoder函数

tf.nn.ctc_greedy_decoder(
    inputs,
    sequence_length,
    merge_repeated=True
)

定义在:tensorflow/python/ops/ctc_ops.py.

参见指南:神经网络>连接时间分类(CTC)

对输入中给出的logit上执行greedy解码.(最佳方法)

注意:无论merge_repeated的值如何,如果给定时间和批处理的最大索引对应于空白索引(num_classes - 1),则不会发出新元素.

如果merge_repeatedTrue,则在输出中合并重复的类.这意味着如果连续logits的最大索引相同,则只发出第一个.序列A B B * B * B(其中'*'是空白标签)将会是:

  • A B B B,如果merge_repeated=True.
  • A B B B B,如果merge_repeated=False.

参数:

  • inputs:3-Dfloat Tensor,大小为[max_time, batch_size, num_classes],是logits.
  • sequence_length:1-Dint32向量,包含序列长度,具有大小[batch_size].
  • merge_repeatedBoolean,默认值:True.

返回:

元组(decoded, log_probabilities),其中已解码:单个元素列表,decoded[0] 是一个包含解码输出的SparseTensor:

decoded.indices: Indices matrix (total_decoded_outputs, 2),行存储:[batch, time].

decoded.values: Values vector, size (total_decoded_outputs),向量存储波束 j 的解码类.

decoded.dense_shape: Shape vector, size (2),形状值为[batch_size, max_decoded_length] 

neg_sum_logits:对于找到的序列,一个浮点矩阵(batch_size x 1)包含每个时间框架中最大 logit 之和的负数