Saya mencari cara alternatif untuk menyelamatkan model yang terlatih di PyTorch. Sejauh ini, saya telah menemukan dua alternatif.
- torch.save () untuk menyimpan model dan torch.load () untuk memuat model.
- model.state_dict () untuk menyimpan model yang terlatih dan model.load_state_dict () untuk memuat model yang disimpan.
Saya telah menemukan diskusi ini di mana pendekatan 2 direkomendasikan daripada pendekatan 1.
Pertanyaan saya adalah, mengapa pendekatan kedua lebih disukai? Apakah hanya karena modul torch.nn memiliki dua fungsi tersebut dan kami didorong untuk menggunakannya?
python
serialization
deep-learning
pytorch
tensor
Wasi Ahmad
sumber
sumber
torch.save(model, f)
dantorch.save(model.state_dict(), f)
. File yang disimpan memiliki ukuran yang sama. Sekarang saya bingung. Juga, saya menemukan menggunakan acar untuk menyimpan model.state_dict () sangat lambat. Saya pikir cara terbaik adalah menggunakantorch.save(model.state_dict(), f)
karena Anda menangani pembuatan model, dan obor menangani pemuatan bobot model, sehingga menghilangkan kemungkinan masalah. Referensi: mendiskusikan.pytorch.org/t/saving-torch-models/838/4pickle
?Jawaban:
Saya telah menemukan halaman ini di repo github mereka, saya hanya akan menempelkan konten di sini.
Pendekatan yang disarankan untuk menyimpan model
Ada dua pendekatan utama untuk membuat cerita bersambung dan mengembalikan model.
Yang pertama (disarankan) hanya menyimpan dan memuat parameter model:
Kemudian nanti:
Yang kedua menyimpan dan memuat seluruh model:
Kemudian nanti:
Namun dalam kasus ini, data berseri terikat dengan kelas-kelas spesifik dan struktur direktori yang tepat digunakan, sehingga dapat pecah dengan berbagai cara ketika digunakan dalam proyek lain, atau setelah beberapa refaktor serius.
sumber
pickle
?Itu tergantung pada apa yang ingin Anda lakukan.
Kasus # 1: Simpan model untuk menggunakannya sendiri untuk inferensi : Anda menyimpan model, Anda mengembalikannya, dan kemudian Anda mengubah model ke mode evaluasi. Ini dilakukan karena Anda biasanya memiliki
BatchNorm
danDropout
lapisan yang secara default dalam mode kereta di konstruksi:Kasus # 2: Simpan model untuk melanjutkan pelatihan nanti : Jika Anda perlu terus melatih model yang akan Anda simpan, Anda perlu menyimpan lebih dari sekadar model. Anda juga perlu menyimpan status pengoptimal, zaman, skor, dll. Anda akan melakukannya seperti ini:
Untuk melanjutkan pelatihan Anda akan melakukan hal-hal seperti:,
state = torch.load(filepath)
dan kemudian, untuk mengembalikan keadaan setiap objek individu, sesuatu seperti ini:Karena Anda melanjutkan pelatihan, JANGAN menelepon
model.eval()
begitu Anda mengembalikan negara saat memuat.Kasus # 3: Model yang akan digunakan oleh orang lain tanpa akses ke kode Anda : Di Tensorflow Anda dapat membuat
.pb
file yang mendefinisikan arsitektur dan bobot model. Ini sangat berguna, khususnya saat menggunakanTensorflow serve
. Cara yang setara untuk melakukan ini di Pytorch adalah:Cara ini masih belum menjadi bukti dan karena pytorch masih mengalami banyak perubahan, saya tidak akan merekomendasikannya.
sumber
torch.load
mengembalikan hanya sebuah OrderedDict. Bagaimana Anda mendapatkan model untuk membuat prediksi?The acar alat perpustakaan Python protokol biner untuk serialisasi dan de-serialisasi objek Python.
Ketika Anda
import torch
(atau ketika Anda menggunakan PyTorch) itu akanimport pickle
untuk Anda dan Anda tidak perlu meneleponpickle.dump()
danpickle.load()
langsung, yang merupakan metode untuk menyimpan dan memuat objek.Bahkan,
torch.save()
dantorch.load()
akan membungkuspickle.dump()
danpickle.load()
untuk Anda.Sebuah
state_dict
jawaban lain yang disebutkan layak hanya beberapa catatan lagi.Apa
state_dict
yang kita miliki di dalam PyTorch? Sebenarnya ada duastate_dict
s.Model PyTorch
torch.nn.Module
memilikimodel.parameters()
panggilan untuk mendapatkan parameter yang dapat dipelajari (w dan b). Parameter yang dapat dipelajari ini, setelah ditetapkan secara acak, akan diperbarui seiring waktu seperti yang kita pelajari. Parameter yang bisa dipelajari adalah yang pertamastate_dict
.Yang kedua
state_dict
adalah dict state optimizer. Anda ingat bahwa pengoptimal digunakan untuk meningkatkan parameter yang dapat dipelajari. Tetapi optimizerstate_dict
sudah diperbaiki. Tidak ada yang bisa dipelajari di sana.Karena
state_dict
objek adalah kamus Python, mereka dapat dengan mudah disimpan, diperbarui, diubah, dan dipulihkan, menambahkan banyak modularitas untuk model dan pengoptimal PyTorch.Mari kita buat model super sederhana untuk menjelaskan ini:
Kode ini akan menampilkan yang berikut:
Perhatikan ini adalah model minimal. Anda dapat mencoba menambahkan tumpukan berurutan
Perhatikan bahwa hanya lapisan dengan parameter yang dapat dipelajari (lapisan konvolusional, lapisan linier, dll.) Dan buffer terdaftar (lapisan batchnorm) memiliki entri dalam model
state_dict
.Hal-hal yang tidak dapat dipelajari, termasuk dalam objek pengoptimal
state_dict
, yang berisi informasi tentang status pengoptimal, serta hyperparameter yang digunakan.Kisah selanjutnya sama; dalam fase inferensi (ini adalah fase ketika kita menggunakan model setelah pelatihan) untuk memprediksi; kami memprediksi berdasarkan parameter yang kami pelajari. Jadi untuk kesimpulan, kita hanya perlu menyimpan parameter
model.state_dict()
.Dan untuk menggunakan model.load_state_dict nanti (torch.load (filepath)) model.eval ()
Catatan: Jangan lupa baris terakhir
model.eval()
ini sangat penting setelah memuat model.Juga jangan mencoba menyimpan
torch.save(model.parameters(), filepath)
. Itumodel.parameters()
hanya objek generator.Di sisi lain,
torch.save(model, filepath)
menyimpan objek model itu sendiri, tetapi perlu diingat model tidak memiliki pengoptimalstate_dict
. Periksa jawaban luar biasa lainnya oleh @Jadiel de Armas untuk menyimpan dict state optimizer.sumber
Konvensi PyTorch yang umum adalah menyimpan model menggunakan ekstensi file .pt atau .pth.
Simpan / Muat Seluruh Model, Simpan:
Beban:
Kelas model harus didefinisikan di suatu tempat
sumber
Jika Anda ingin menyimpan model dan ingin melanjutkan pelatihan nanti:
GPU tunggal: Simpan:
Beban:
Banyak GPU: Simpan
Beban:
sumber