阅读(6.9k) 书签 (0)

TensorFlow函数:将base_type的对象转换为Tensor

2018-12-20 11:02 更新

tf.register_tensor_conversion_function 函数

register_tensor_conversion_function(
    base_type,
    conversion_func,
    priority=100
)

定义在:tensorflow/python/framework/ops.py.

参见指南:构建图表>用于构建TensorFlow的库

注册一个函数,用于将 base_type 的对象转换为 Tensor.

这个转换函数必须具有以下签名:

def conversion_func(value, dtype=None, name=None, as_ref=False):
  # ...

如果指定,它必须返回具有给定 dtype 的 Tensor.如果转换函数创建一个新的 Tensor,它应该使用给定的 name(如果指定).所有异常将被传播给调用方.

转换函数可能会为一些输入返回 NotImplemented.在这种情况下,转换过程将继续尝试后续的转换函数.

如果 as_ref 为 true,则该函数必须返回一个Tensor引用,如 Variable.

注意:转换函数将按照优先级顺序执行,然后是注册顺序.要确保转换函数 F 在另一个转换函数 G 之前运行,请确保函数 F 使用比函数 G 小的优先级注册.

参数:

  • base_type:conversion_func 接受的所有对象的基本类型或基本类型的元组.
  • conversion_func:将 base_type 的实例转换为 Tensor 的函数.
  • priority:表示应用此转换函数的优先级的可选整数.具有较小优先级值的转换函数比具有较大优先级值的转换函数更早运行.默认为100.

可能引发的异常:

  • TypeError:如果参数没有适当的类型.