Dari apa yang telah saya kumpulkan sejauh ini, ada beberapa cara berbeda untuk membuang grafik TensorFlow ke dalam file dan kemudian memuatnya ke program lain, tetapi saya belum dapat menemukan contoh / informasi yang jelas tentang cara kerjanya. Yang sudah saya ketahui adalah ini:
- Simpan variabel model ke dalam file checkpoint (.ckpt) menggunakan a
tf.train.Saver()
dan pulihkan nanti ( sumber ) - Simpan model ke dalam file .pb dan muat kembali menggunakan
tf.train.write_graph()
dantf.import_graph_def()
( sumber ) - Memuat model dari file .pb, melatihnya kembali, dan membuangnya ke file .pb baru menggunakan Bazel ( sumber )
- Bekukan grafik untuk menyimpan grafik dan bobot ( sumber )
- Gunakan
as_graph_def()
untuk menyimpan model, dan untuk bobot / variabel, petakan mereka menjadi konstanta ( sumber )
Namun, saya belum dapat menjawab beberapa pertanyaan tentang metode berbeda ini:
- Mengenai file checkpoint, apakah mereka hanya menyimpan bobot model yang terlatih? Bisakah file checkpoint dimuat ke program baru, dan digunakan untuk menjalankan model, atau apakah mereka hanya berfungsi sebagai cara untuk menyimpan bobot dalam model pada waktu / tahap tertentu?
- Mengenai
tf.train.write_graph()
, apakah bobot / variabel juga disimpan? - Terkait Bazel, dapatkah itu hanya menyimpan ke / memuat dari file .pb untuk pelatihan ulang? Apakah ada perintah Bazel sederhana hanya untuk membuang grafik ke dalam .pb?
- Mengenai pembekuan, dapatkah grafik beku dimuat dengan menggunakan
tf.import_graph_def()
? - Demo Android untuk TensorFlow dimuat dalam model Inception Google dari file .pb. Jika saya ingin mengganti file .pb saya sendiri, bagaimana cara melakukannya? Apakah saya perlu mengubah kode / metode asli?
- Secara umum, apa sebenarnya perbedaan antara semua metode ini? Atau secara lebih luas, apa perbedaan antara
as_graph_def()
/.ckpt/.pb?
Singkatnya, yang saya cari adalah metode untuk menyimpan grafik (seperti dalam, berbagai operasi dan semacamnya) dan bobot / variabelnya ke dalam file, yang kemudian dapat digunakan untuk memuat grafik dan bobot ke program lain , untuk digunakan (tidak harus melanjutkan / pelatihan ulang).
Dokumentasi tentang topik ini tidak terlalu mudah, jadi jawaban / informasi apa pun akan sangat dihargai.
sumber
Jawaban:
Ada banyak cara untuk mendekati masalah penyimpanan model di TensorFlow, yang bisa membuatnya sedikit membingungkan. Mengambil setiap sub-pertanyaan Anda secara bergantian:
File checkpoint (dihasilkan misalnya dengan memanggil
saver.save()
sebuahtf.train.Saver
objek) hanya berisi bobot, dan variabel lain yang didefinisikan dalam program yang sama. Untuk menggunakannya di program lain, Anda harus membuat ulang struktur grafik terkait (misalnya dengan menjalankan kode untuk membuatnya lagi, atau memanggiltf.import_graph_def()
), yang memberi tahu TensorFlow apa yang harus dilakukan dengan bobot tersebut. Perhatikan bahwa pemanggilansaver.save()
juga menghasilkan file yang berisi aMetaGraphDef
, yang berisi grafik dan detail tentang cara mengaitkan bobot dari checkpoint dengan grafik itu. Lihat tutorial untuk lebih jelasnya.tf.train.write_graph()
hanya menulis struktur grafik; bukan bobotnya.Bazel tidak terkait dengan membaca atau menulis grafik TensorFlow. (Mungkin saya salah memahami pertanyaan Anda: silakan klarifikasi dalam komentar.)
Grafik beku dapat dimuat menggunakan
tf.import_graph_def()
. Dalam kasus ini, bobot (biasanya) disematkan dalam grafik, jadi Anda tidak perlu memuat checkpoint terpisah.Perubahan utamanya adalah memperbarui nama tensor yang dimasukkan ke dalam model, dan nama tensor yang diambil dari model. Dalam demo Android TensorFlow, ini akan sesuai dengan string
inputName
danoutputName
yang diteruskan keTensorFlowClassifier.initializeTensorFlow()
.Ini
GraphDef
adalah struktur program, yang biasanya tidak berubah selama proses pelatihan. Checkpoint adalah snapshot dari status proses pelatihan, yang biasanya berubah di setiap langkah proses pelatihan. Akibatnya, TensorFlow menggunakan format penyimpanan yang berbeda untuk jenis data ini, dan API tingkat rendah menyediakan berbagai cara untuk menyimpan dan memuatnya. Pustaka tingkat yang lebih tinggi, sepertiMetaGraphDef
pustaka, Keras , dan skflow membangun mekanisme ini untuk menyediakan cara yang lebih nyaman untuk menyimpan dan memulihkan seluruh model.sumber
tf.train.write_graph()
dan kemudian menjalankannya?GraphDef
disimpan olehtf.train.write_graph()
, Anda juga perlu mengingat nama tensor yang ingin Anda beri makan dan ambil saat menjalankan grafik (item 5 di atas).Anda dapat mencoba kode berikut:
sumber