Misalkan saya memiliki tensor Tensorflow. Bagaimana cara mendapatkan dimensi (bentuk) tensor sebagai nilai integer? Saya tahu ada dua metode, tensor.get_shape()
dan tf.shape(tensor)
, tetapi saya tidak bisa mendapatkan nilai bentuk sebagai int32
nilai integer .
Misalnya, di bawah ini saya telah membuat tensor 2-D, dan saya perlu mendapatkan jumlah baris dan kolom int32
sehingga saya dapat memanggil reshape()
untuk membuat tensor bentuk (num_rows * num_cols, 1)
. Namun, metode ini tensor.get_shape()
mengembalikan nilai sebagai Dimension
tipe, bukan int32
.
import tensorflow as tf
import numpy as np
sess = tf.Session()
tensor = tf.convert_to_tensor(np.array([[1001,1002,1003],[3,4,5]]), dtype=tf.float32)
sess.run(tensor)
# array([[ 1001., 1002., 1003.],
# [ 3., 4., 5.]], dtype=float32)
tensor_shape = tensor.get_shape()
tensor_shape
# TensorShape([Dimension(2), Dimension(3)])
print tensor_shape
# (2, 3)
num_rows = tensor_shape[0] # ???
num_cols = tensor_shape[1] # ???
tensor2 = tf.reshape(tensor, (num_rows*num_cols, 1))
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1750, in reshape
# name=name)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 454, in apply_op
# as_ref=input_arg.is_ref)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 621, in convert_to_tensor
# ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 180, in _constant_tensor_conversion_function
# return constant(v, dtype=dtype, name=name)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 163, in constant
# tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 353, in make_tensor_proto
# _AssertCompatible(values, dtype)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 290, in _AssertCompatible
# (dtype.name, repr(mismatch), type(mismatch).__name__))
# TypeError: Expected int32, got Dimension(6) of type 'Dimension' instead.
sumber
tf.reshape()
, tetapi saya benar-benar ingin mendapatkannum_rows
dannum_cols
sebagai bilangan bulat untuk operasi lain.tensor.get_shape().as_list()
as_list()
berhasil. Tolong tambahkan ke jawaban Anda, dan saya akan menerimanya.num_rows, num_cols = x.get_shape().as_list()
Cara lain untuk mengatasinya adalah seperti ini:
tensor_shape[0].value
Ini akan mengembalikan nilai int dari objek Dimension.
sumber
untuk tensor 2-D, Anda bisa mendapatkan jumlah baris dan kolom sebagai int32 menggunakan kode berikut:
rows, columns = map(lambda i: i.value, tensor.get_shape())
sumber
2.0 Jawaban Kompatibel : Dalam
Tensorflow 2.x (2.1)
, Anda bisa mendapatkan dimensi (bentuk) dari tensor sebagai nilai integer, seperti yang ditunjukkan pada Kode di bawah ini:Metode 1 (menggunakan
tf.shape
) :import tensorflow as tf c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) Shape = c.shape.as_list() print(Shape) # [2,3]
Metode 2 (menggunakan
tf.get_shape()
) :import tensorflow as tf c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) Shape = c.get_shape().as_list() print(Shape) # [2,3]
sumber
Solusi sederhana lainnya adalah dengan menggunakan
map()
sebagai berikut:tensor_shape = map(int, my_tensor.shape)
Ini mengubah semua
Dimension
objek menjadiint
sumber
Di versi yang lebih baru (diuji dengan TensorFlow 1.14), ada cara yang lebih numpy untuk mendapatkan bentuk tensor. Anda bisa menggunakan
tensor.shape
untuk mendapatkan bentuk tensor.sumber