PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:

torchvision.datasets

torchvision.models

torchvision.transforms

官网介绍   源码

-----------------------------------------------------------------------------------------------------------------

本文介绍 torchvision.models 如何使用。以 vgg16为例子

1) 导入预训练模型:

import torchvision
model = torchvision.models.vgg16(pretrained=True)

2) 只导入网络结构,不导入参数:

model = torchvision.models.vgg16(pretrained=False) #主要是这里改为False

3) 由于 pretrained 参数默认是 False,所以 2) 等价于:

model = torchvision.models.vgg16()

 

Logo

讨论HarmonyOS开发技术,专注于API与组件、DevEco Studio、测试、元服务和应用上架分发等。

更多推荐