TensorFlow函数教程:tf.lite.Interpreter
tf.lite.Interpreter函数
类Interpreter
别名:>
- 类 tf.contrib.lite.Interpreter
- 类 tf.lite.Interpreter
定义在:tensorflow/lite/python/interpreter.py。
TF-Lite模型的解释器推理。
__init__
__init__(
model_path=None,
model_content=None
)
构造函数。
参数:
- model_path:TF-Lite Flatbuffer文件的路径。
- model_content:模型的内容。
可能引发的异常:
- ValueError:如果解释器无法创建。
方法
allocate_tensors
allocate_tensors()
get_input_details
get_input_details()
获取模型输入详细信息。
返回:
输入详细信息列表。
get_output_details
get_output_details()
获取模型输出详细信息
返回:
输出详细信息列表。
get_tensor
get_tensor(tensor_index)
获取输入张量的值(获取副本)。
如果您想避免复制,请使用tensor()。
参数:
- tensor_index:得到张量的张量指数。该值可以从get_output_details中的'index'字段获得。
返回:
一个numpy数组。
get_tensor_details
get_tensor_details()
获取具有有效张量详细信息的每个张量的张量详细信息。
未找到有关张量的所需信息的张量不会添加到列表中。这包括没有名字的临时张量。
返回:
包含张量信息的词典列表。
invoke
invoke()
调用解释器。
在调用之前,请务必设置输入大小,分配张量和填充值。
可能引发的异常:
- ValueError:当底层解释器失败时引发ValueError。
reset_all_variables
reset_all_variables()
resize_tensor_input
resize_tensor_input(
input_index,
tensor_size
)
调整输入张量的大小。
参数:
- input_index:要设置的输入的张量索引。该值可以从get_input_details中的'index'字段获得。
- tensor_size:tensor_shape调整输入的大小。
可能引发的异常:
- ValueError:如果解释器无法调整输入张量的大小。
set_tensor
set_tensor(
tensor_index,
value
)
设置输入张量的值。请注意,这将复制value的数据。
如果要避免复制,可以使用该tensor()函数获取指向tflite解释器中输入缓冲区的numpy缓冲区。
参数:
- tensor_index:设置的张量的张量指数。该值可以从get_input_details中的'index'字段获得。
- value:张量值设置。
可能引发的异常:
- ValueError:如果解释器无法设置张量。
tensor
tensor(tensor_index)
返回给出当前张量缓冲区的numpy视图的函数。
这允许在没有副本的情况下读写这个张量。这更接近于C ++ Interpreter类接口的tensor()成员,因此得名。小心不要通过调用allocate_tensors()和invoke()来保持这些输出引用。
用法:
interpreter.allocate_tensors()
input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
for i in range(10):
input().fill(3.)
interpreter.invoke()
print("inference %s" % output())
注意这个函数如何避免直接生成numpy数组。将实际numpy视图保持到数据的时间不能超过必要的时间是很重要的。如果你这样做了,则无法再调用解释器,因为解释器可能会调整大小并使引用的张量无效。NumPy API不允许底层缓冲区的任何可变性。
错误:
input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
interpreter.allocate_tensors() # This will throw RuntimeError
for i in range(10):
input.fill(3.)
interpreter.invoke() # this will throw RuntimeError since input,output
参数:
- tensor_index:得到的张量的张量指数。该值可以从get_output_details中的'index'字段获得。
返回:
一个函数,可以在任何点返回指向内部TFLite张量状态的新numpy数组。永久保持该函数是安全的,但永久保持numpy阵列是不安全的。