TensorFlow函数教程:tf.nn.ctc_loss
tf.nn.ctc_loss函数
tf.nn.ctc_loss(
labels,
inputs,
sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
ignore_longer_outputs_than_inputs=False,
time_major=True
)
定义在:tensorflow/python/ops/ctc_ops.py.
参见指南:神经网络>连接时间分类(CTC)
计算CTC(连接时间分类)loss.
输入要求:
sequence_length(b) <= time for all b
max(labels.indices(labels.indices[:, 1] == b, 2))
<= sequence_length(b) for all b.
笔记:
此类为您执行softmax操作,因此输入应该是例如LSTM对输出的线性预测.
该inputs
张量的最内层的维度大小,num_classes
,代表num_labels + 1
类别,其中num_labels是实际的标签的数量,而最大的值(num_classes - 1)
是为空白标签保留的.
例如,对于包含3个标签[a, b, c]
的词汇表,num_classes = 4
,并且标签索引是{a: 0, b: 1, c: 2, blank: 3}
.
关于参数preprocess_collapse_repeated
和ctc_merge_repeated
:
如果preprocess_collapse_repeated
为True,则在loss计算之前运行预处理步骤,其中传递给loss的重复标签会合并为单个标签.如果训练标签来自,例如强制对齐,并因此具有不必要的重复,则这是有用的.
如果ctc_merge_repeated
设置为False,则在CTC计算的深处,重复的非空白标签将不会合并,并被解释为单个标签.这是CTC的简化(非标准)版本.
以下是(大致)预期的第一顺序行为表:
preprocess_collapse_repeated=False
,ctc_merge_repeated=True
典型的CTC行为:输出实际的重复类,其间有空白,还可以输出中间没有空白的重复类,这需要由解码器折叠.
preprocess_collapse_repeated=True
,ctc_merge_repeated=False
不要得知输出重复的类,因为它们在训练之前在输入标签中折叠.
preprocess_collapse_repeated=False
,ctc_merge_repeated=False
输出中间有空白的重复类,但通常不需要解码器折叠/合并重复的类.
preprocess_collapse_repeated=True
,ctc_merge_repeated=True
未经测试,很可能不会得知输出重复的类.
该ignore_longer_outputs_than_inputs
选项允许在处理输出长于输入的序列时指定CTCLoss的行为.如果为true,则CTCLoss将仅为这些项返回零梯度,否则返回InvalidArgument错误,停止训练.
参数:
labels
:一个int32
SparseTensor
;labels.indices[i, :] == [b, t]
表示labels.values[i]
存储(batch b, time t)的id;labels.values[i]
必须采用[0, num_labels)
中的值.inputs
:3-Dfloat
Tensor
;如果time_major == False,这将是一个Tensor
,形状:[batch_size, max_time, num_classes]
;如果time_major == True(默认值),这将是一个Tensor
,形状:[max_time, batch_size, num_classes]
;是logits.sequence_length
:1-Dint32
向量,大小为[batch_size]
;序列长度.preprocess_collapse_repeated
:Boolean,默认值:False;如果为True,则在CTC计算之前折叠重复的标签.ctc_merge_repeated
:Boolean,默认值:True.ignore_longer_outputs_than_inputs
:Boolean,默认值:False;如果为True,则输出比输入长的序列将被忽略.time_major
:inputs
张量的形状格式;如果是True,那些Tensors
必须具有形状[max_time, batch_size, num_classes]
;如果为False,则Tensors
必须具有形状[batch_size, max_time, num_classes]
;使用time_major = True
(默认)更有效,因为它避免了在ctc_loss计算开始时的转置.但是,大多数TensorFlow数据都是批处理为主的,因此通过此函数还可以接受以批处理为主的形式的输入.
返回:
1-Dfloat
Tensor
,大小为[batch]
包含负对数概率.
可能引发的异常:
TypeError
:如果标签不是SparseTensor
.