"FlyAI 秃头挑战赛"

最近参加了FlyAI的新手比赛 https://flyai.com/d/BaldClassification

创建gitee项目后在master上搭建的cv-dnn分类问题框架,用vgg16做二分类得分96.3,为了提分从3个方向(数据增强、超参调整、模型融合)进行修改,创建了3个分支feat_data / feat_hypar / feat_model,然后怪事发生了:无论在哪个方向怎么调整得分均不升反降,InceptionV3和Resnet50甚至难以收敛。

后来对比代码发现 train数据集的shuffle参数被设置为False,如下:

torch.utils.data.DataLoader(train_folder, self.batch_size, shuffle=False, num_workers=0)

修改为True后各模型迅速收敛。经调试,如果shuffle=False每个batch只会取出一个class,dnn对该类别不断构建模型,然后class切换又重新构建,因此模型难以收敛,但有趣的是vgg16却收敛的很好,EfficientNet也收敛尚可。如果torch.utils.data.DataLoaderdatasettorchvision.datasets.ImageFolder,shuffle的设置尤其重要,因为shuffle=False数据会按floder提供,必然会出现上述batch中class单一的情况。

修改shuffle参数后得分97.29,还有优化空间。

© 2024 lanser.fun修订时间: 2023-05-02 02:19:15

results matching ""

    No results matching ""