阅读(2k) 书签 (0)

PyTorch torch.utils.model_zoo

2020-09-15 11:57 更新

原文: PyTorch torch.utils.model_zoo

移至 <cite>torch.hub</cite> 。

torch.utils.model_zoo.load_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)¶

将 Torch 序列化对象加载到给定的 URL。

如果下载的文件是 zip 文件,它将被自动解压缩。

如果 <cite>model_dir</cite> 中已经存在该对象,则将其反序列化并返回。 <cite>model_dir</cite> 的默认值为$TORCH_HOME/checkpoints,其中环境变量$TORCH_HOME的默认值为$XDG_CACHE_HOME/torch$XDG_CACHE_HOME遵循 Linux 文件系统布局的 X 设计组规范,如果未设置,则默认值为~/.cache

参数

  • url (字符串)–要下载的对象的 URL
  • model_dir (字符串 可选)–保存对象的目录
  • map_location (可选)–指定如何重新映射存储位置的函数或命令(请参见 torch.load)
  • 进度 (bool 可选)–是否显示 stderr 进度条。 默认值:True
  • check_hash (bool 可选)–如果为 True,则 URL 的文件名部分应遵循命名约定filename-<sha256>.ext,其中[ <sha256>是文件内容的 SHA256 哈希值的前 8 位或更多位。 哈希用于确保唯一的名称并验证文件的内容。 默认值:False

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')