PyTorch 序列化语义
2020-09-11 10:04 更新
原文: https://pytorch.org/docs/stable/notes/serialization.html
最佳实务
推荐的模型保存方法
序列化和还原模型有两种主要方法。
第一个(推荐)仅保存和加载模型参数:
torch.save(the_model.state_dict(), PATH)
然后再:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
第二个保存并加载整个模型:
torch.save(the_model, PATH)
Then later:
the_model = torch.load(PATH)
但是,在这种情况下,序列化的数据将绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用时或经过一些严重的重构后,序列化的数据可能会以各种方式中断。