Apa gunanya torch.no_grad di pytorch?

21

Saya baru mengenal pytorch dan mulai dengan kode github ini . Saya tidak mengerti komentar pada baris 60-61 dalam kode "because weights have requires_grad=True, but we don't need to track this in autograd". Saya mengerti bahwa kita menyebutkan requires_grad=Truevariabel yang perlu kita hitung gradien untuk menggunakan autograd tapi apa artinya "tracked by autograd"?

terbang
sumber

Jawaban:

24

Wrapper "with torch.no_grad ()" untuk sementara mengatur semua flag require_grad menjadi false. Contoh dari tutorial resmi PyTorch ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

Di luar:

True
True
False

Saya sarankan Anda untuk membaca semua tutorial dari situs web di atas.

Dalam contoh Anda: Saya kira penulis tidak ingin PyTorch menghitung gradien dari variabel baru yang didefinisikan w1 dan w2 karena ia hanya ingin memperbarui nilai-nilai mereka.

Adrien D
sumber
6
with torch.no_grad()

akan membuat semua operasi di blok tidak memiliki gradien.

Dalam pytorch, Anda tidak dapat melakukan perubahan penempatan w1 dan w2, yang merupakan dua variabel require_grad = True. Saya berpikir bahwa menghindari perubahan penempatan w1 dan w2 adalah karena akan menyebabkan kesalahan dalam perhitungan propagasi kembali. Karena perubahan penempatan akan benar-benar berubah w1 dan w2.

Namun, jika Anda menggunakan ini no_grad(), Anda dapat mengontrol w1 baru dan w2 baru tidak memiliki gradien karena dihasilkan oleh operasi, yang berarti Anda hanya mengubah nilai w1 dan w2, bukan bagian gradien, mereka masih memiliki informasi gradien variabel yang ditentukan sebelumnya dan propagasi balik dapat dilanjutkan.

Jianing Lu
sumber