阅读(2.9k) 书签 (0)

PyTorch torch脚本

2020-09-15 10:39 更新

原文: PyTorch torch脚本

TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。 任何 TorchScript 程序都可以从 Python 进程中保存并加载到没有 Python 依赖项的进程中。

我们提供了将模型从纯 Python 程序逐步过渡到可以独立于 Python 运行的 TorchScript 程序的工具,例如在独立的 C ++程序中。 这样就可以使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后通过 TorchScript 将模型导出到生产环境中,在该生产环境中 Python 程序可能由于性能和多线程原因而处于不利地位。

有关 TorchScript 的简要介绍,请参见 TorchScript 简介教程。

有关将 PyTorch 模型转换为 TorchScript 并在 C ++中运行的端到端示例,请参见在 C ++中加载 PyTorch 模型教程。

创建 TorchScript 代码

class torch.jit.ScriptModule¶

property code¶

返回forward方法的内部图的漂亮打印表示形式(作为有效的 Python 语法)。

property graph¶

返回forward方法的内部图形的字符串表示形式。

save(f, _extra_files=ExtraFilesMap{})¶

torch.jit.save

class torch.jit.ScriptFunction¶

功能上与 ScriptModule 等效,但是代表单个功能,没有任何属性或参数。

torch.jit.script(obj)¶

为函数或nn.Module编写脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,然后返回 ScriptModuleScriptFunction 。 TorchScript 本身是 Python 语言的子集,因此 Python 并非所有功能都可以使用,但是我们提供了足够的功能来在张量上进行计算并执行与控制有关的操作。

torch.jit.script可用作模块和功能的函数,以及 TorchScript 类和功能的修饰器@torch.jit.script

Scripting a function

@torch.jit.script装饰器将通过编译函数的主体来构造 ScriptFunction

示例(编写函数):

import torch


@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


print(type(foo))  # torch.jit.ScriptFuncion


## See the compiled graph as Python code
print(foo.code)


## Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
Scripting an nn.Module

默认情况下,为nn.Module编写脚本将编译forward方法,并递归编译forward调用的任何方法,子模块和函数。 如果nn.Module仅使用 TorchScript 支持的功能,则无需更改原始模块代码。 script将构建 ScriptModule ,该副本具有原始模块的属性,参数和方法的副本。

示例(使用参数编写简单模块的脚本):

import torch


class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))


        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)


    def forward(self, input):
        output = self.weight.mv(input)


        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output


scripted_module = torch.jit.script(MyModule(2, 3))

示例(使用跟踪的子模块编写模块脚本):

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))


    def forward(self, input):
      input = F.relu(self.conv1(input))
      input = F.relu(self.conv2(input))
      return input


scripted_module = torch.jit.script(MyModule())

要编译除forward以外的方法(并递归编译其调用的任何内容),请将 @torch.jit.export 装饰器添加到该方法。 要选择退出编译,请使用 @torch.jit.ignore

示例(模块中的导出方法和忽略方法):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()


    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10


    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()


    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99


scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)¶

跟踪一个函数并返回将使用即时编译进行优化的可执行文件或 ScriptFunction 。 对于仅在TensorTensor的列表,字典和元组上运行的代码,跟踪是理想的选择。

使用torch.jit.tracetorch.jit.trace_module ,您可以将现有模块或 Python 函数转换为 TorchScript ScriptFunctionScriptModule 。 您必须提供示例输入,然后我们运行该函数,记录在所有张量上执行的操作。

  • 独立功能的最终记录将产生 ScriptFunction
  • nn.Modulenn.Moduleforward功能的所得记录产生 ScriptModule

该模块还包含原始模块也具有的任何参数。

警告

跟踪仅正确记录不依赖数据的功能和模块(例如,对张量中的数据没有条件)并且不包含任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)。 跟踪仅记录在给定张量上运行给定函数时执行的操作。 因此,返回的 ScriptModule 将始终在任何输入上运行相同的跟踪图。 当期望模块根据输入和/或模块状态运行不同的操作集时,这具有重要意义。 例如,

  • 跟踪将不会记录任何控制流,例如 if 语句或循环。 当整个模块的控制流恒定时,这很好,并且通常内联控制流决策。 但是有时控制流实际上是模型本身的一部分。 例如,循环网络是输入序列(可能是动态)长度上的循环。
  • 在返回的 ScriptModule 中,在trainingeval模式下具有不同行为的操作将始终像在跟踪过程中一样处于运行状态,无论是哪种模式 ] ScriptModule 已插入。

在这种情况下,跟踪是不合适的, scripting 是更好的选择。 如果跟踪此类模型,则可能在随后的模型调用中静默地得到不正确的结果。 在执行可能会导致产生不正确跟踪的操作时,跟踪器将尝试发出警告。

参数

  • 函数(可调用的 torch.nn.Module)– Python 函数或torch.nn.Moduleexample_inputs一起运行。 func的参数和返回值必须是张量或包含张量的(可能是嵌套的)元组。 将模块传递到 torch.jit.trace 时,仅运行并跟踪forward方法(有关详细信息,参见 torch.jit.trace)。
  • example_inputs (tuple )–示例输入的元组,将在跟踪时传递给函数。 假设跟踪的操作支持这些类型和形状,则可以使用不同类型和形状的输入来运行结果跟踪。 example_inputs也可以是单个张量,在这种情况下,它会自动包装在元组中。

Keyword Arguments

  • check_trace (bool,可选)–检查通过跟踪代码运行的相同输入是否产生相同的输出。 默认值:True。 例如,如果您的网络包含不确定性操作,或者即使检查程序失败,但您确定网络正确,则可能要禁用此功能。
  • check_inputs (元组列表 可选)–输入参数的元组列表,应使用这些元组来检查跟踪内容 是期待。 每个元组等效于example_inputs中指定的一组输入参数。 为了获得最佳结果,请传递一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。 如果未指定,则使用原始的example_inputs进行检查
  • check_tolerance (python:float 可选)–在检查程序中使用的浮点比较公差。 如果结果由于已知原因(例如操作员融合)而在数值上出现差异,则可以使用此方法来放松检查器的严格性。

退货

如果callablenn.Modulenn.Moduleforward,则trace将使用包含跟踪代码的单个forward方法返回 ScriptModule 对象。 返回的 ScriptModule 将具有与原始nn.Module相同的子模块和参数集。 如果callable是独立功能,则trace返回 ScriptFunction

示例(跟踪函数):

import torch


def foo(x, y):
    return 2 * x + y


## Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))


## `traced_foo` can now be run with the TorchScript interpreter or saved
## and loaded in a Python-free environment

示例(跟踪现有模块):

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)


    def forward(self, x):
        return self.conv(x)


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)


## Trace a specific method and construct `ScriptModule` with
## a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)


## Trace a module (implicitly traces `forward`) and construct a
## `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)¶

跟踪模块并返回可执行文件 ScriptModule ,该文件将使用即时编译进行优化。 将模块传递到 torch.jit.trace 时,仅运行并跟踪forward方法。 使用trace_module,您可以指定方法名称的字典作为示例输入,以跟踪下面的参数(请参见example_inputs)。

有关跟踪的更多信息,参见 torch.jit.trace

Parameters

  • mod (Torch.nn.Module)–一种torch.nn.Module,其中包含名称在example_inputs中指定的方法。 给定的方法将被编译为单个 <cite>ScriptModule</cite> 的一部分。
  • example_inputs (dict )–包含样本输入的字典,该样本输入由mod中的方法名称索引。 输入将在跟踪时传递给名称与输入键对应的方法。 { 'forward' : example_forward_input, 'method2': example_method2_input}

Keyword Arguments

  • check_trace (bool, optional) – Check if the same inputs run through traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non- deterministic ops or if you are sure that the network is correct despite a checker failure.
  • check_inputs (字典列表 可选)–输入参数的字典列表,用于检查跟踪内容 是期待。 每个元组等效于example_inputs中指定的一组输入参数。 为了获得最佳结果,请传递一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。 如果未指定,则使用原始的example_inputs进行检查
  • check_tolerance (python:float__, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.

Returns

具有单个forward方法的 ScriptModule 对象,其中包含跟踪的代码。 当functorch.nn.Module时,返回的 ScriptModule 将具有与func相同的子模块和参数集。

示例(使用多种方法跟踪模块):

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)


    def forward(self, x):
        return self.conv(x)


    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)


## Trace a specific method and construct `ScriptModule` with
## a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)


## Trace a module (implicitly traces `forward`) and construct a
## `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)


## Trace specific methods on a module (specified in `inputs`), constructs
## a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

torch.jit.save(m, f, _extra_files=ExtraFilesMap{})¶

保存此模块的脱机版本以在单独的过程中使用。 保存的模块将序列化此模块的所有方法,子模块,参数和属性。 可以使用torch::jit::load(filename)将其加载到 C ++ API 中,或者使用 torch.jit.load 加载到 Python API 中。

为了能够保存模块,它不得对本地 Python 函数进行任何调用。 这意味着所有子模块也必须是torch.jit.ScriptModule的子类。

危险

所有模块,无论使用哪种设备,都始终在加载期间加载到 CPU 中。 这与 load 的语义不同,并且将来可能会发生变化。

Parameters

  • m –要保存的 ScriptModule。
  • f –类似于文件的对象(必须实现写入和刷新)或包含文件名的字符串。
  • _extra_files -从文件名映射到将作为“ f”的一部分存储的内容。

Warning

如果您使用的是 Python 2,torch.jit.save不支持StringIO.StringIO作为有效的类似文件的对象。 这是因为 write 方法应返回写入的字节数; StringIO.write()不这样做。

请改用io.BytesIO之类的东西。

例:

import torch
import io


class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10


m = torch.jit.script(MyModule())


## Save to file
torch.jit.save(m, 'scriptmodule.pt')
## This line is equivalent to the previous
m.save("scriptmodule.pt")


## Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)


## Save with extra files
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

torch.jit.load(f, map_location=None, _extra_files=ExtraFilesMap{})¶

加载先前用 torch.jit.save 保存的 ScriptModuleScriptFunction

之前保存的所有模块,无论使用何种设备,都首先加载到 CPU 中,然后再移动到保存它们的设备上。 如果失败(例如,因为运行时系统没有某些设备),则会引发异常。

Parameters

  • f –类似于文件的对象(必须实现读取,读取行,告诉和查找),或包含文件名的字符串
  • map_location (字符串 torch设备)– torch.savemap_location的简化版本 用于动态地将存储重新映射到另一组设备。
  • _extra_files (文件名到内容的字典)–映射中给定的多余文件名将被加载,其内容将存储在提供的映射中。

Returns

ScriptModule 对象。

Example:

import torch
import io


torch.jit.load('scriptmodule.pt')


## Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())


## Load all tensors to the original device
torch.jit.load(buffer)


## Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))


## Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')


## Load with extra files.
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])

混合跟踪和脚本编写

在许多情况下,将模型转换为 TorchScript 都可以使用跟踪或脚本编写。 可以组成跟踪和脚本以适合模型一部分的特定要求。

脚本函数可以调用跟踪函数。 当您需要在简单的前馈模型周围使用控制流时,这特别有用。 例如,序列到序列模型的波束搜索通常将用脚本编写,但是可以调用使用跟踪生成的编码器模块。

示例(在脚本中调用跟踪的函数):

import torch


def foo(x, y):
    return 2 * x + y


traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))


@torch.jit.script
def bar(x):
    return traced_foo(x, x)

跟踪的函数可以调用脚本函数。 即使大部分模型只是前馈网络,当模型的一小部分需要一些控制流时,这也很有用。 跟踪函数调用的脚本函数内部的控制流已正确保留。

示例(在跟踪函数中调用脚本函数):

import torch


@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z


traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

此组合也适用于nn.Module,在这里它可用于通过跟踪来生成子模块,该跟踪可以从脚本模块的方法中调用。

示例(使用跟踪模块):

import torch
import torchvision


class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))


    def forward(self, input):
        return self.resnet(input - self.means)


my_script_module = torch.jit.script(MyScriptModule())

迁移到 PyTorch 1.2 递归脚本 API

本节详细介绍了 PyTorch 1.2 中对 TorchScript 的更改。 如果您不熟悉 TorchScript,则可以跳过本节。 PyTorch 1.2 对 TorchScript API 进行了两个主要更改。

\1. torch.jit.script 现在将尝试递归编译遇到的函数,方法和类。 调用torch.jit.script后,编译将是“选择退出”,而不是“选择加入”。

2.现在torch.jit.script(nn_module_instance)是创建 ScriptModule 的首选方法,而不是从torch.jit.ScriptModule继承。 这些更改组合在一起,提供了一个更简单易用的 API,可将您的nn.Module转换为 ScriptModule ,可以在非 Python 环境中进行优化和执行。

新用法如下所示:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)


    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))


my_model = Model()
my_scripted_model = torch.jit.script(my_model)

  • 该模块的forward是默认编译的。 从forward调用的方法将按照在forward中使用的顺序进行延迟编译。
  • 要编译未从forward调用的forward以外的方法,请添加@torch.jit.export
  • 要停止编译器编译方法,请添加 @torch.jit.ignore@torch.jit.unused@ignore离开
  • 方法作为对 python 的调用,并且@unused将其替换为异常。 @ignored无法导出; @unused可以。
  • 可以推断大多数属性类型,因此不需要torch.jit.Attribute。 对于空容器类型,请使用 PEP 526 样式类注释对其类型进行注释。
  • 可以使用Final类注释来标记常量,而不是将成员的名称添加到__constants__中。
  • 可以使用 Python 3 类型提示代替torch.jit.annotate

As a result of these changes, the following items are considered deprecated and should not appear in new code:

  • @torch.jit.script_method装饰器
  • 继承自torch.jit.ScriptModule的类
  • torch.jit.Attribute包装器类
  • __constants__数组
  • torch.jit.annotate功能

模块

Warning

@torch.jit.ignore 注释的行为在 PyTorch 1.2 中发生了变化。 在 PyTorch 1.2 之前,@ ignore 装饰器用于使函数或方法可从导出的代码中调用。 要恢复此功能,请使用@torch.jit.unused()@torch.jit.ignore现在等同于@torch.jit.ignore(drop=False)。 有关详细信息,参见 @torch.jit.ignore@torch.jit.unused

当传递给 torch.jit.script 函数时,torch.nn.Module的数据将复制到 ScriptModule ,然后 TorchScript 编译器将编译该模块。 该模块的forward默认为编译状态。 从forward调用的方法以及它们在forward中使用的顺序都是按延迟顺序编译的。

torch.jit.export(fn)¶

此修饰符指示nn.Module上的方法用作 ScriptModule 的入口点,应进行编译。

forward隐式地假定为入口点,因此不需要此装饰器。 从forward调用的函数和方法在编译器看到的情况下进行编译,因此它们也不需要此装饰器。

示例(在方法上使用@torch.jit.export):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99


    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10


    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)


    def unused_method(self, x):
        return x - 20


## `m` will contain compiled methods:
##     `forward`
##     `another_forward`
##     `implicitly_compiled_method`
## `unused_method` will not be compiled since it was not called from
## any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

功能

功能没有太大变化,可以根据需要用 @torch.jit.ignoretorch.jit.unused 装饰。

## Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2


## Marks a function as ignored, if nothing
## ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2


## As with ignore, if nothing calls it then it has no effect.
## If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4


## Doesn't do anything, this function is already
## the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript 类

默认情况下,将导出用户定义的 TorchScript 类中的所有内容,可以根据需要用 @torch.jit.ignore 修饰功能。

属性

TorchScript 编译器需要知道模块属性的类型。 大多数类型可以从成员的值推断出来。 空列表和字典不能推断其类型,而必须使用 PEP 526 样式类注释来注释其类型。 如果无法推断类型并且未对显式类型进行注释,则不会将其作为属性添加到结果 ScriptModule

旧 API:

from typing import Dict
import torch


class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super(MyModule, self).__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)


m = MyModule()

新 API:

from typing import Dict


class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]


    def __init__(self):
        super(MyModule, self).__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}


        # The attribute type here is inferred to be `int`
        self.my_int = 20


    def forward(self):
        pass


m = torch.jit.script(MyModule())

Python 2

如果您受制于 Python 2 并且无法使用类注释语法,则可以使用__annotations__类成员直接应用类型注释。

from typing import Dict


class MyModule(torch.jit.ScriptModule):
    __annotations__ = {'my_dict': Dict[str, int]}


    def __init__(self):
        super(MyModule, self).__init__()
        self.my_dict = {}
        self.my_int = 20

常数

Final类型的构造函数可用于将成员标记为常量。 如果成员未标记为常量,则将其复制为结果 ScriptModule 作为属性。 如果已知该值是固定的,则使用Final可以进行优化,并提供附加的类型安全性。

Old API:

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']


    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2


    def forward(self):
        pass
m = MyModule()

New API:

try:
    from typing_extensions import Final
except:
    # If you don't have `typing_extensions` installed, you can use a
    # polyfill from `torch.jit`.
    from torch.jit import Final


class MyModule(torch.nn.Module):


    my_constant: Final[int]


    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2


    def forward(self):
        pass


m = torch.jit.script(MyModule())

变量

假定容器的类型为Tensor,并且是非可选的(有关更多信息,请参见默认类型)。 以前,torch.jit.annotate用来告诉 TorchScript 编译器类型是什么。 现在支持 Python 3 样式类型提示。

import torch
from typing import Dict, Optional


@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

TorchScript 语言参考

TorchScript 是 Python 的静态类型子集,可以直接编写(使用 @torch.jit.script 装饰器),也可以通过跟踪从 Python 代码自动生成。 使用跟踪时,通过仅在张量上记录实际的运算符并简单地执行和丢弃其他周围的 Python 代码,代码会自动转换为 Python 的此子集。

使用@torch.jit.script装饰器直接编写 TorchScript 时,程序员只能使用 TorchScript 支持的 Python 子集。 本节记录了 TorchScript 支持的功能,就像它是独立语言的语言参考一样。 本参考中未提及的 Python 的任何功能都不属于 TorchScript。 有关可用的 Pytorch 张量方法,模块和功能的完整参考,请参见内置函数。

作为 Python 的子集,任何有效的 TorchScript 函数也是有效的 Python 函数。 这样就可以禁用 TorchScript 并使用pdb之类的标准 Python 工具调试该功能。 反之则不成立:有许多有效的 Python 程序不是有效的 TorchScript 程序。 相反,TorchScript 特别专注于表示 PyTorch 中的神经网络模型所需的 Python 功能。

类型

TorchScript 与完整的 Python 语言之间的最大区别是 TorchScript 仅支持表达神经网络模型所需的一小部分类型。 特别是,TorchScript 支持:

|

类型

|

描述

| | --- | --- | | Tensor | 任何 dtype,尺寸或后端的 PyTorch 张量 | | Tuple[T0, T1, ...] | 包含子类型T0T1等(例如Tuple[Tensor, Tensor])的元组 | | bool | 布尔值 | | int | 标量整数 | | float | 标量浮点数 | | str | 一串 | | List[T] | 所有成员均为T类型的列表 | | Optional[T] | 无或输入T的值 | | Dict[K, V] | 键类型为K而值类型为V的字典。 只能将strintfloat作为密钥类型。 | | T | 一个 TorchScript 类 | | NamedTuple[T0, T1, ...] | collections.namedtuple元组类型 |

与 Python 不同,TorchScript 函数中的每个变量都必须具有一个静态类型。 这使优化 TorchScript 函数变得更加容易。

示例(类型不匹配)

import torch


@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...


Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~...  <--- HERE
        r = torch.rand(1)
    else:
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE
...

默认类型

默认情况下,TorchScript 函数的所有参数均假定为 Tensor。 要指定 TorchScript 函数的参数是其他类型,可以使用上面列出的类型使用 MyPy 样式的类型注释。

import torch


@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x


print(foo(3, (torch.rand(3), torch.rand(3))))

注意

也可以使用typing模块中的 Python 3 类型提示来注释类型。

import torch
from typing import Tuple


@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x


print(foo(3, (torch.rand(3), torch.rand(3))))

在我们的示例中,我们使用基于注释的类型提示来确保 Python 2 的兼容性。

假定空列表为List[Tensor],空字典为Dict[str, Tensor]。 要实例化其他类型的空列表或字典,请使用 Python 3 类型提示。 如果您使用的是 Python 2,则可以使用torch.jit.annotate

示例(Python 3 的类型注释):

import torch
import torch.nn as nn
from typing import Dict, List, Tuple


class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super(EmptyDataStructures, self).__init__()


    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))


        my_dict: Dict[str, int] = {}
        return my_list, my_dict


x = torch.jit.script(EmptyDataStructures())

示例(适用于 Python 2 的torch.jit.annotate):

import torch
import torch.nn as nn
from typing import Dict, List, Tuple


class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super(EmptyDataStructures, self).__init__()


    def forward(self, x):
        # type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]


        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list = torch.jit.annotate(List[Tuple[int, float]], [])
        for i in range(10):
            my_list.append((i, float(x.item())))


        my_dict = torch.jit.annotate(Dict[str, int], {})
        return my_list, my_dict


x = torch.jit.script(EmptyDataStructures())

可选类型细化

在 if 语句的条件内或在assert中检查与None的比较时,TorchScript 将优化Optional[T]类型的变量的类型。 编译器可以推理与andornot结合的多个None检查。 对于未明确编写的 if 语句的 else 块,也会进行优化。

None检查必须在 if 语句的条件内; 将None检查分配给变量,并在 if 语句的条件下使用它,将不会优化检查中的变量类型。 仅局部变量将被细化,self.x之类的属性将不会且必须分配给要细化的局部变量。

示例(优化参数和局部变量的类型):

import torch
import torch.nn as nn
from typing import Optional


class M(nn.Module):
    z: Optional[int]


    def __init__(self, z):
        super(M, self).__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z


    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1


        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z


        # Refinement via an `assert`
        assert z is not None
        x += z
        return x


module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript 类

如果 Python 类使用 @torch.jit.script 注释,则可以在 TorchScript 中使用,类似于声明 TorchScript 函数的方式:

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x


  def aug_add_x(self, inc):
    self.x += inc

此子集受限制:

  • 所有函数必须是有效的 TorchScript 函数(包括__init__())。

  • 这些类必须是新型类,因为我们使用__new__()和 pybind11 来构造它们。

  • TorchScript 类是静态类型的。 只能通过在__init__()方法中分配给 self 来声明成员。

\> 例如,在__init__()方法之外分配给self: > > > @torch.jit.script > class Foo: > def assign_x(self): > self.x = torch.rand(2, 3) > > > > 将导致: > > > RuntimeError: > Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: > def assign_x(self): > self.x = torch.rand(2, 3) > ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE > >

  • 类的主体中不允许使用除方法定义之外的任何表达式。

  • 除了从object继承以指定新样式类外,不支持继承或任何其他多态策略。

定义了一个类之后,就可以像其他任何 TorchScript 类型一样在 TorchScript 和 Python 中互换使用该类:

## Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second


@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second


p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

命名为元组

collections.namedtuple产生的类型可以在 TorchScript 中使用。

import torch
import collections


Point = collections.namedtuple('Point', ['x', 'y'])


@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y


p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

表达式

支持以下 Python 表达式。

文字

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

列表结构

假定一个空列表具有List[Tensor]类型。 其他列表文字的类型是从成员的类型派生的。 有关更多详细信息,请参见默认类型。

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元组结构

(3, 4)
(3,)

字典结构

假定一个空字典为Dict[str, Tensor]类型。 其他 dict 文字的类型是从成员的类型派生的。 有关更多详细信息,请参见默认类型。

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

变量

有关如何解析变量的信息,请参见变量分辨率。

my_variable_name

算术运算符

a + b
a - b
a * b
a / b
a ^ b
a @ b

比较运算符

a == b
a != b
a < b
a > b
a <= b
a >= b

逻辑运算符

a and b
a or b
not b

下标和切片

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函数调用

调用内置函数

torch.rand(3, dtype=torch.int)

调用其他脚本函数:

import torch


@torch.jit.script
def foo(x):
    return x + 1


@torch.jit.script
def bar(x):
    return foo(x)

方法调用

调用诸如张量之类的内置类型的方法:x.mm(y)

在模块上,必须先编译方法才能调用它们。 TorchScript 编译器以递归方式编译在编译其他方法时看到的方法。 默认情况下,编译从forward方法开始。 将编译forward调用的任何方法,以及这些方法调用的任何方法,依此类推。 要以forward以外的方法开始编译,请使用 @torch.jit.export 装饰器(forward隐式标记为@torch.jit.export)。

直接调用子模块(例如self.resnet(input))等效于调用其forward方法(例如self.resnet.forward(input))。

import torch
import torch.nn as nn
import torchvision


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))


    def helper(self, input):
        return self.resnet(input - self.means)


    def forward(self, input):
        return self.helper(input)


    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)


    def other_helper(self, input):
        return input + 10


## `my_script_module` will have the compiled methods `forward`, `helper`,
## `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表达式

x if x > y else y

演员表

float(ten)
int(3.5)
bool(ten)
str(2)``

访问模块参数

self.my_parameter
self.my_submodule.my_parameter

语句

TorchScript 支持以下类型的语句:

简单分配

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

模式匹配分配

a, b = tuple_or_list
a, b, *c = a_tuple

多项分配

a = b, c = tup

打印报表

print("the result of an add:", a + b)

If 语句

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除布尔值外,浮点数,整数和张量还可以在条件中使用,并将隐式转换为布尔值。

While 循环

a = 0
while a < 4:
    print(a)
    a += 1

适用于范围为的循环

x = 0
for i in range(10):
    x *= i

用于遍历元组的循环

这些展开循环,为元组的每个成员生成一个主体。 主体必须对每个成员进行正确的类型检查。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

用于在常量 nn.ModuleList 上循环

要在已编译方法中使用nn.ModuleList,必须通过将属性名称添加到__constants__列表中的类型来将其标记为常量。 nn.ModuleList上的 for 循环将在编译时展开循环的主体,并使用常量模块列表的每个成员。

class SubModule(torch.nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.weight = nn.Parameter(torch.randn(2))


    def forward(self, input):
        return self.weight + input


class MyModule(torch.nn.Module):
    __constants__ = ['mods']


    def __init__(self):
        super(MyModule, self).__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])


    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

中断并继续

for i in range(5):
    if i == 1:
    continue
    if i == 3:
    break
    print(i)

返回

return a, b

可变分辨率

TorchScript 支持 Python 的可变分辨率(即作用域)规则的子集。 局部变量的行为与 Python 中的相同,不同之处在于,在通过函数的所有路径上,变量必须具有相同的类型。 如果变量在 if 语句的不同分支上具有不同的类型,则在 if 语句结束后使用它是错误的。

同样,如果沿函数的某些路径仅将定义为,则不允许使用该变量。

Example:

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...


y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~...  <--- HERE
        y = 4
    print(y)
...

定义函数时,会在编译时将非局部变量解析为 Python 值。 然后使用 Python 值使用中描述的规则将这些值转换为 TorchScript 值。

使用 Python 值

为了使编写 TorchScript 更加方便,我们允许脚本代码引用周围范围中的 Python 值。 例如,任何时候只要引用torch,当声明函数时,TorchScript 编译器实际上就会将其解析为torch Python 模块。 这些 Python 值不是 TorchScript 的一流部分。 而是在编译时将它们分解为 TorchScript 支持的原始类型。 这取决于编译发生时引用的 Python 值的动态类型。 本节介绍在 TorchScript 中访问 Python 值时使用的规则。

功能

TorchScript 可以调用 Python 函数。 当将模型逐步转换为 TorchScript 时,此功能非常有用。 可以将模型逐函数移至 TorchScript,而对 Python 函数的调用保留在原处。 这样,您可以在进行过程中逐步检查模型的正确性。

torch.jit.ignore(drop=False, **kwargs)¶

该装饰器向编译器指示应忽略函数或方法,而将其保留为 Python 函数。 这使您可以将代码保留在尚未与 TorchScript 兼容的模型中。 具有忽略功能的模型无法导出; 请改用 torch.jit.unused。

示例(在方法上使用@torch.jit.ignore):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    @torch.jit.ignore
    def debugger(self, x):
        import pdb
        pdb.set_trace()


    def forward(self, x):
        x += 10
        # The compiler would normally try to compile `debugger`,
        # but since it is `@ignore`d, it will be left as a call
        # to Python
        self.debugger(x)
        return x


m = torch.jit.script(MyModule())


## Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")

示例(在方法上使用@torch.jit.ignore(drop=True)):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    @torch.jit.ignore(drop=True)
    def training_method(self, x):
        import pdb
        pdb.set_trace()


    def forward(self, x):
        if self.training:
            self.training_method(x)
        return x


m = torch.jit.script(MyModule())


## This is OK since `training_method` is not saved, the call is replaced
## with a `raise`.
m.save("m.pt")

torch.jit.unused(fn)¶

此装饰器向编译器指示应忽略函数或方法,并用引发异常的方法代替。 这样,您就可以在尚不兼容 TorchScript 的模型中保留代码,并仍然可以导出模型。

示例(在方法上使用@torch.jit.unused):


import torch
import torch.nn as nn

class MyModule(nn.Module):
def __init__(self, use_memory_efficent):
super(MyModule, self).__init__()
self.use_memory_efficent = use_memory_efficent

@torch.jit.unused
def memory_efficient(self, x):
import pdb
pdb.set_trace()
return x + 10

def forward(self, x):
# Use not-yet-scriptable memory efficient mode
if self.use_memory_efficient:
return self.memory_efficient(x)
else:
return x + 10

m = torch.jit.script(MyModule(use_memory_efficent=False))
m.save("m.pt")

m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
m(torch.rand(100))

torch.jit.is_scripting()¶

在编译时返回 True 的函数,否则返回 False 的函数。 这对于使用@unused 装饰器尤其有用,可以将尚不兼容 TorchScript 的代码保留在模型中。 .. testcode:

import torch


@torch.jit.unused
def unsupported_linear_op(x):
    return x


def linear(x):
   if not torch.jit.is_scripting():
      return torch.linear(x)
   else:
      return unsupported_linear_op(x)

Python 模块上的属性查找

TorchScript 可以在模块上查找属性。 像torch.add这样的内置功能可以通过这种方式访问。 这使 TorchScript 可以调用其他模块中定义的函数。

Python 定义的常量

TorchScript 还提供了一种使用 Python 中定义的常量的方法。 这些可用于将超参数硬编码到函数中,或定义通用常量。 有两种指定 Python 值应视为常量的方式。

  1. 查找为模块属性的值假定为常量:

import math
import torch


@torch.jit.script
def fn():
    return math.pi

  1. 可以通过使用Final[T]注释 ScriptModule 的属性来将其标记为常量。

import torch
import torch.nn as nn


class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]


    def __init__(self):
        super(Foo, self).__init__()
        self.a = 1 + 4


    def forward(self, input):
        return self.a + input


f = torch.jit.script(Foo())

支持的常量 Python 类型是

  • int
  • float
  • bool
  • torch.device
  • torch.layout
  • torch.dtype
  • 包含受支持类型的元组
  • torch.nn.ModuleList可以在 TorchScript for 循环中使用

Note

如果您使用的是 Python 2,则可以通过将属性名称添加到类的__constants__属性中来将其标记为常量:

import torch
import torch.nn as nn


class Foo(nn.Module):
    __constants__ = ['a']


    def __init__(self):
        super(Foo, self).__init__()
        self.a = 1 + 4


    def forward(self, input):
        return self.a + input


f = torch.jit.script(Foo())

模块属性

torch.nn.Parameter包装器和register_buffer可用于将张量分配给模块。 如果可以推断出其他类型的值,则分配给已编译模块的其他值将添加到已编译模块中。 TorchScript 中可用的所有类型都可以用作模块属性。 张量属性在语义上与缓冲区相同。 空列表和字典的类型以及None值无法推断,必须通过 PEP 526 样式类注释指定。 如果无法推断出类型并且未对其进行显式注释,则不会将其作为属性添加到结果 ScriptModule 中。

Example:

from typing import List, Dict


class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]


    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]


    def __init__(self, a_dict):
        super(Foo, self).__init__()
        self.words = []
        self.some_dict = a_dict


        # `int`s can be inferred
        self.my_int = 10


    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int


f = torch.jit.script(Foo({'hi': 2}))

Note

如果您使用的是 Python 2,则可以通过将属性的类型添加到__annotations__类属性中作为属性名字典来标记属性的类型

from typing import List, Dict


class Foo(nn.Module):
    __annotations__ = {'words': List[str], 'some_dict': Dict[str, int]}


    def __init__(self, a_dict):
        super(Foo, self).__init__()
        self.words = []
        self.some_dict = a_dict


        # `int`s can be inferred
        self.my_int = 10


    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int


f = torch.jit.script(Foo({'hi': 2}))

调试

禁用用于调试的 JIT

PYTORCH_JIT¶

设置环境变量PYTORCH_JIT=0将禁用所有脚本和跟踪注释。 如果您的 TorchScript 模型之一存在难以调试的错误,则可以使用此标志来强制一切都使用本机 Python 运行。 由于此标志禁用了 TorchScript(脚本编写和跟踪),因此可以使用pdb之类的工具来调试模型代码。

给定一个示例脚本:

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x


def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)


traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

除调用,@torch.jit.script,函数外,使用pdb调试此脚本是可行的。 我们可以全局禁用 JIT,以便我们可以将 @torch.jit.script 函数作为普通的 Python 函数调用,而不进行编译。 如果上述脚本称为disable_jit_example.py,我们可以这样调用它:

$ PYTORCH_JIT=0 python disable_jit_example.py

并且我们将能够像普通的 Python 函数一样进入 @torch.jit.script 函数。 要为特定功能禁用 TorchScript 编译器,请参见 @torch.jit.ignore

检查码

TorchScript 为所有 ScriptModule 实例提供了代码漂亮的打印机。 这个漂亮的打印机可以将脚本方法的代码解释为有效的 Python 语法。 例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


print(foo.code)

具有单个forward方法的 ScriptModule 将具有属性code,您可以使用该属性检查 ScriptModule 的代码。 如果 ScriptModule 具有多个方法,则需要在方法本身而非模块上访问.code。 我们可以通过访问.foo.code在 ScriptModule 上检查名为foo的方法的代码。 上面的示例产生以下输出:

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

这是 TorchScript 对forward方法的代码的编译。 您可以使用它来确保 TorchScript(跟踪或脚本)正确捕获了模型代码。

解释图

TorchScript 还以 IR 图的形式在比代码漂亮打印机更低的层次上进行表示。

TorchScript 使用静态单分配(SSA)中间表示(IR)表示计算。 这种格式的指令由 ATen(PyTorch 的 C ++后端)运算符和其他原始运算符组成,包括用于循环和条件的控制流运算符。 举个例子:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


print(foo.graph)

graph遵循检查代码部分中关于forward方法查找所述的相同规则。

上面的示例脚本生成图形:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

以指令%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10为例。

  • %rv.1 : Tensor表示我们将输出分配给一个名为rv.1的(唯一)值,该值是Tensor类型,并且我们不知道其具体形状。
  • aten::zeros是运算符(与torch.zeros等效),输入列表(%4, %6, %6, %10, %12)指定范围中的哪些值应作为输入传递。 可以在内置函数中找到aten::zeros等内置函数的模式。
  • # test.py:9:10是生成此指令的原始源文件中的位置。 在这种情况下,它是第 9 行和字符 10 处名为 <cite>test.py</cite> 的文件。

请注意,运算符也可以具有关联的blocks,即prim::Loopprim::If运算符。 在图形打印输出中,这些运算符被格式化以反映其等效的源代码形式,以方便进行调试。

如下图所示,可以检查图表以确认 ScriptModule 所描述的计算是正确的,无论是自动方式还是手动方式。

追踪案例

在某些极端情况下,给定 Python 函数/模块的跟踪不会代表基础代码。 这些情况可以包括:

  • 跟踪取决于输入的控制流(例如张量形状)
  • 跟踪张量视图的就地操作(例如,分配左侧的索引)

请注意,这些情况实际上将来可能是可追溯的。

自动跟踪检查

自动捕获跟踪中许多错误的一种方法是使用torch.jit.trace() API 上的check_inputscheck_inputs提取输入元组的列表,这些列表将用于重新追踪计算并验证结果。 例如:

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result


inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]


traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

为我们提供以下诊断信息:

ERROR: Graphs differed across invocations!
Graph diff:


            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

此消息向我们表明,在我们第一次追踪它和使用check_inputs追踪它之间,计算有所不同。 实际上,loop_in_traced_fn主体内的循环取决于输入x的形状,因此,当我们尝试另一种形状不同的x时,迹线会有所不同。

在这种情况下,可以使用 torch.jit.script() 来捕获类似于数据的控制流:

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result


inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]


scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())


for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple))

产生:

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

跟踪器警告

跟踪器会针对跟踪计算中的几种有问题的模式生成警告。 举个例子,追踪一个在 Tensor 的切片(视图)上包含就地分配的函数:

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x


traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

产生几个警告和一个仅返回输入的图形:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1\. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

我们可以通过修改代码来解决此问题,使其不使用就地更新,而是使用torch.cat来错位构建结果张量:

def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x


traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

内置函数

TorchScript 支持 PyTorch 提供的内置张量和神经网络功能的子集。 Tensor 上的大多数方法以及torch名称空间中的函数,torch.nn.functional中的所有函数以及torch.nn中的所有模块在 TorchScript 中均受支持,下表中没有列出。 对于不支持的模块,建议使用 torch.jit.trace()

不支持的torch.nn模块

torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss
torch.nn.modules.normalization.CrossMapLRN2d
torch.nn.modules.rnn.RNN

有关支持的功能的完整参考,请参见 TorchScript 内置函数。

常见问题解答

问:我想在 GPU 上训练模型并在 CPU 上进行推理。 最佳做法是什么?

首先将模型从 GPU 转换为 CPU,然后将其保存,如下所示:


cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pth")

traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pth")

# ... later, when using the model:

if use_gpu:
model = torch.jit.load("gpu.pth")
else:
model = torch.jit.load("cpu.pth")

model(input)


推荐这样做是因为跟踪器可能会在特定设备上见证张量的创建,因此强制转换已加载的模型可能会产生意想不到的效果。 在保存之前对模型进行转换可确保跟踪器具有正确的设备信息。

问:如何在 ScriptModule 上存储属性?

说我们有一个像这样的模型:


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.x = 2

def forward(self):
return self.x

m = torch.jit.script(Model())


如果实例化Model,则将导致编译错误,因为编译器不了解x。 有四种方法可以通知编译器 ScriptModule 的属性:

\1. nn.Parameter-包装在nn.Parameter中的值将像在nn.Module上一样工作

\2. register_buffer-包装在register_buffer中的值将像在nn.Module上一样工作。 这等效于Tensor类型的属性(请参见 4)。

3.常量-将类成员注释为Final(或在类定义级别将其添加到名为__constants__的列表中)会将包含的名称标记为常量。 常数直接保存在模型代码中。 有关详细信息,请参见 Python 定义的常量。

4.属性-可以将支持的类型的值添加为可变属性。 可以推断大多数类型,但可能需要指定一些类型,有关详细信息,请参见模块属性。

问:我想跟踪模块的方法,但一直出现此错误:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

此错误通常表示您要跟踪的方法使用模块的参数,并且您正在传递模块的方法而不是模块实例(例如my_module_instance.forwardmy_module_instance)。

\& 使用模块的方法调用trace会将模块参数(可能需要渐变)捕获为常量。 &
\&
\&
另一方面,使用模块实例(例如my_module)调用trace会创建一个新模块,并将参数正确复制到新模块中,以便在需要时可以累积梯度。

& 要跟踪模块上的特定方法,请参见 torch.jit.trace_module