阅读(6.1k) 书签 (0)

用于字符串张量的TensorFlow操作

2018-01-02 10:59 更新

#版权所有 2015 TensorFlow 作者.保留所有权利.

#根据Apache许可证版本2.0(“许可证”)获得许可;

#除了符合许可证外,您不得使用此文件.

#您可以在以下网址获得许可证副本:

#http://www.apache.org/licenses/LICENSE-2.0

#除非适用法律要求或书面同意,软件

#根据许可证分发的软件按“现状”分发,

#没有任何形式的明示或暗示保证或条件.

#请参阅许可证,了解许可证下关于权限

#和限制的具体规定.

# ==============================================================================

"""Operations for working with string Tensors.

See the @{$python/string_ops} guide.

@@string_to_hash_bucket_fast
@@string_to_hash_bucket_strong
@@string_to_hash_bucket
@@reduce_join
@@string_join
@@string_split
@@substr
@@as_string
@@encode_base64
@@decode_base64
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops

# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_string_ops import *
from tensorflow.python.util import deprecation
# pylint: enable=wildcard-import


def string_split(source, delimiter=" ", skip_empty=True):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter` into a `SparseTensor`.

  Let N be the size of source (typically N will be the batch size). Split each
  element of `source` based on `delimiter` and return a `SparseTensor`
  containing the split tokens. Empty tokens are ignored when `skip_empty` is
  `True` (the default).

  If `delimiter` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  For example:
  N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
  will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    delimiter: `0-D` string `Tensor`, the delimiter bytes. An empty string
      splits into single bytes; a multi-byte string is treated as a set of
      single-byte delimiters (see above).
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  # pylint: disable=protected-access
  indices, values, shape = gen_string_ops._string_split(
      source, delimiter=delimiter, skip_empty=skip_empty)
  # pylint: enable=protected-access

  # Pin the static shapes of the three SparseTensor components: indices is
  # [?, 2], values is [?], and dense shape is a length-2 vector.
  indices.set_shape([None, 2])
  values.set_shape([None])
  shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, shape)


def _reduce_join_reduction_dims(x, axis, reduction_indices):
  """Resolves the dimensions that `reduce_join` should reduce over.

  Args:
    x: The input passed to `reduce_join`.
    axis: The dimensions to reduce, or `None`.
    reduction_indices: Deprecated alias for `axis`, or `None`.

  Returns:
    `axis` (or `reduction_indices`) unchanged when one of them is given.
    Otherwise every dimension of `x` in reverse order, i.e.
    `range(rank(x) - 1, -1, -1)`, as a constant when the rank is statically
    known and as a run-time `Range` op when it is not.

  Raises:
    ValueError: If both `axis` and `reduction_indices` are specified.
  """
  # TODO(aselle): Remove this after deprecation
  if reduction_indices is not None:
    if axis is not None:
      raise ValueError("Can't specify both 'axis' and 'reduction_indices'.")
    axis = reduction_indices
  if axis is not None:
    return axis

  # Fast path: avoid creating Rank and Range ops if ndims is known.
  if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
    return constant_op.constant(
        np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32)

  # Otherwise, we rely on Range and Rank to do the right thing at run-time.
  return math_ops.range(array_ops.rank(x) - 1, -1, -1)


def reduce_join(inputs, axis=None,
                keep_dims=False,
                separator="",
                name=None,
                reduction_indices=None):
  # The docstring is assigned below from the generated op's documentation.
  reduction_indices = _reduce_join_reduction_dims(
      inputs, axis, reduction_indices)
  return gen_string_ops.reduce_join(
      inputs=inputs,
      reduction_indices=reduction_indices,
      keep_dims=keep_dims,
      separator=separator,
      name=name)


# Reuse the generated op's documentation, rewritten so that it describes the
# `axis` argument instead of `reduction_indices`.
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
    gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")

# Register the string ops as having no gradient.
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
ops.NotDifferentiable("StringToHashBucketStrong")
ops.NotDifferentiable("ReduceJoin")
ops.NotDifferentiable("StringJoin")
ops.NotDifferentiable("StringSplit")
ops.NotDifferentiable("AsString")
ops.NotDifferentiable("EncodeBase64")
ops.NotDifferentiable("DecodeBase64")