Notes on MorvanZhou's PyTorch conditional_GAN.py and a Fix for the PyTorch 1.5+ Error


Running 406_conditional_GAN.py on PyTorch 1.5 or later raises the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

The cause of the error:

[image: screenshot of the error traceback]

Searching online turns up the following suggested fix (method 1):

[image: the suggested fix (method 1)]

This fix is in fact wrong: although the program then runs without error, the results it produces are incorrect:

[image: incorrect result produced by fix 1]

The real cause is that opt_D.step() modifies the discriminator's parameters in place, while the computational graph built for G_loss still holds references to the old parameter values. PyTorch 1.4 did not check for this kind of in-place modification; the check was added in 1.5. So the same code runs under 1.4 but raises this error under 1.5 and later.
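The failure mode can be reproduced outside the GAN code entirely. The sketch below (a hypothetical toy model, not the tutorial's networks) shows that an optimizer step taken between two backward passes through the same graph triggers the identical version-check error on PyTorch 1.5+:

```python
import torch

# Toy two-layer net (hypothetical, only to demonstrate the version check)
net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
opt = torch.optim.SGD(net.parameters(), lr=0.1)

out = net(torch.randn(4, 2))
out.mean().backward(retain_graph=True)   # first backward pass, keep the graph
opt.step()                               # in-place update of net's weights

try:
    out.sum().backward()                 # graph still expects the old weights
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True on PyTorch >= 1.5: a saved tensor's version changed
```

This mirrors the tutorial's situation exactly: opt_D.step() plays the role of opt.step(), and the later G_loss.backward() is the second backward pass through D's stale graph.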

The correct fix is to reorder the updates: train G first, then recompute D's probabilities with G_paintings detached, as follows:

for step in range(10000):
    artist_paintings, labels = artist_works_with_labels()           # real painting, label from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)                      # random ideas
    G_inputs = torch.cat((G_ideas, labels), 1)                      # ideas with labels
    G_paintings = G(G_inputs)                                       # fake painting w.r.t label from G
    D_inputs1 = torch.cat((G_paintings, labels), 1)
    prob_artist1 = D(D_inputs1)

    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad()
    G_loss.backward(retain_graph=True)
    opt_G.step()

    D_inputs0 = torch.cat((artist_paintings, labels), 1)            # all have their labels
    prob_artist0 = D(D_inputs0)                 # D try to increase this prob
    prob_artist1 = D(torch.cat((G_paintings, labels), 1).detach())  # D try to reduce this prob
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward()      # inputs are detached, so only D's own graph is used
    opt_D.step()

    if step % 200 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        bound = [0, 0.5] if labels.data[0, 0] == 0 else [0.5, 1]
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + bound[1], c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + bound[0], c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.text(-.5, 1.7, 'Class = %i' % int(labels.data[0, 0]), fontdict={'size': 13})
        plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.1)

plt.ioff()
plt.show()
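The loop above relies on definitions from earlier in 406_conditional_GAN.py. For a self-contained check that the reordered updates no longer raise the error, here is a rough sketch of the surrounding setup (network sizes, learning rates, and artist_works_with_labels are approximations of the tutorial's, not exact copies), run for a few plot-free steps:

```python
import numpy as np
import torch
import torch.nn as nn

# Approximate stand-ins for the tutorial's constants and data function
BATCH_SIZE, N_IDEAS, ART_COMPONENTS = 64, 5, 15
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])

def artist_works_with_labels():
    # quadratic "paintings"; the 0/1 class label shifts the curve up or down
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    labels = np.random.randint(0, 2, size=(BATCH_SIZE, 1)).astype(np.float32)
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1) + labels
    return torch.from_numpy(paintings).float(), torch.from_numpy(labels)

G = nn.Sequential(nn.Linear(N_IDEAS + 1, 128), nn.ReLU(), nn.Linear(128, ART_COMPONENTS))
D = nn.Sequential(nn.Linear(ART_COMPONENTS + 1, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid())
opt_G = torch.optim.Adam(G.parameters(), lr=0.0001)
opt_D = torch.optim.Adam(D.parameters(), lr=0.0001)

# Smoke-test the corrected update order: G first, then D on detached inputs
for step in range(3):
    artist_paintings, labels = artist_works_with_labels()
    G_paintings = G(torch.cat((torch.randn(BATCH_SIZE, N_IDEAS), labels), 1))
    prob_artist1 = D(torch.cat((G_paintings, labels), 1))
    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad(); G_loss.backward(); opt_G.step()

    prob_artist0 = D(torch.cat((artist_paintings, labels), 1))
    prob_artist1 = D(torch.cat((G_paintings, labels), 1).detach())
    D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    opt_D.zero_grad(); D_loss.backward(); opt_D.step()
print("ok")  # no RuntimeError: D's graph is rebuilt after opt_G.step()
```

Because D's probabilities are recomputed after the generator update, no backward pass ever runs through a graph whose parameters were changed in place.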

The result after the fix:

[image: correct result after the fix]

 

Source: https://blog.csdn.net/mooyuan/article/details/122706119