[AI & DL]/[MKD] 에러 극복기

[Error] CUDA 코드를 CPU 전용(CPU Only) 코드로 변경

Engineering for all 2023. 4. 1. 17:30

주요 변경 사항

모델 생성 후 `.cuda()` 메서드를 호출하지 않음. 따라서 모델은 PyTorch 기본값인 CPU에 그대로 남음.

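`torch.load()` 호출에 `map_location=torch.device('cpu')` 인자를 추가함. 이렇게 하면 GPU에서 저장된 체크포인트의 CUDA 텐서가 불러올 때 CPU 텐서로 옮겨짐. 이 인자가 없으면 CUDA를 사용할 수 없는 환경에서 `RuntimeError: Attempting to deserialize object on a CUDA device ...` 에러가 발생함.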

# CUDA

vgg = Vgg16(pretrain).cuda() model = make_arch(config_type, cfg, use_bias, True).cuda() for j, item in enumerate(nn.ModuleList(model.features)): print('layer : {} {}'.format(j, item)) if load_checkpoint: last_checkpoint = config['last_checkpoint'] checkpoint_path = "./outputs/{}/{}/checkpoints/".format(experiment_name, dataset_name) model.load_state_dict( torch.load('{}Cloner_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, last_checkpoint))) if not pretrain: vgg.load_state_dict( torch.load('{}Source_{}_random_vgg.pth'.format(checkpoint_path, normal_class))) elif not pretrain: checkpoint_path = "./outputs/{}/{}/checkpoints/".format(experiment_name, dataset_name) Path(checkpoint_path).mkdir(parents=True, exist_ok=True) torch.save(vgg.state_dict(), '{}Source_{}_random_vgg.pth'.format(checkpoint_path, normal_class)) print("Source Checkpoint saved!") return vgg, model"

# CPU Only

vgg = Vgg16(pretrain)
model = make_arch(config_type, cfg, use_bias, True)

for j, item in enumerate(nn.ModuleList(model.features)):
    print('layer : {} {}'.format(j, item))

if load_checkpoint:
    last_checkpoint = config['last_checkpoint']
    checkpoint_path = "./outputs/{}/{}/checkpoints/".format(experiment_name, dataset_name)
    model.load_state_dict(
        torch.load('{}Cloner_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, last_checkpoint), map_location=torch.device('cpu')))
    if not pretrain:
        vgg.load_state_dict(
            torch.load('{}Source_{}_random_vgg.pth'.format(checkpoint_path, normal_class), map_location=torch.device('cpu')))
elif not pretrain:
    checkpoint_path = "./outputs/{}/{}/checkpoints/".format(experiment_name, dataset_name)
    Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

    torch.save(vgg.state_dict(), '{}Source_{}_random_vgg.pth'.format(checkpoint_path, normal_class))
    print("Source Checkpoint saved!")

return vgg, model