Bagaimana cara kerja metode "view" di PyTorch?

205

Saya bingung tentang metode view()dalam cuplikan kode berikut.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

Kebingungan saya mengenai baris berikut.

x = x.view(-1, 16*5*5)

Apa tensor.view()fungsinya? Saya telah melihat penggunaannya di banyak tempat, tapi saya tidak mengerti bagaimana ia mengartikan parameternya.

Apa yang terjadi jika saya memberikan nilai negatif sebagai parameter ke view()fungsi? Misalnya, apa yang terjadi jika saya menelepon tensor_variable.view(1, 1, -1),?

Adakah yang bisa menjelaskan prinsip utama view()fungsi dengan beberapa contoh?

Wasi Ahmad
sumber

Jawaban:

283

Fungsi tampilan dimaksudkan untuk membentuk kembali tensor.

Katakanlah Anda memiliki tensor

import torch
a = torch.range(1, 16)

aadalah tensor yang memiliki 16 elemen dari 1 hingga 16 (termasuk). Jika Anda ingin membentuk kembali tensor ini menjadi 4 x 4tensor maka Anda dapat menggunakannya

a = a.view(4, 4)

Sekarang aakan menjadi 4 x 4tensor. Perhatikan bahwa setelah membentuk kembali jumlah total elemen harus tetap sama. Mengubah bentuk tensor amenjadi 3 x 5tensor tidak akan sesuai.

Apa arti dari parameter -1?

Jika ada situasi yang Anda tidak tahu berapa banyak baris yang Anda inginkan tetapi yakin dengan jumlah kolom, maka Anda dapat menentukan ini dengan -1. ( Perhatikan bahwa Anda dapat memperluas ini ke tensor dengan dimensi lebih banyak. Hanya satu dari nilai sumbu yang bisa -1 ). Ini adalah cara memberitahu perpustakaan: "beri saya tensor yang memiliki banyak kolom ini dan Anda menghitung jumlah baris yang diperlukan untuk mewujudkannya".

Ini dapat dilihat pada kode jaringan saraf yang Anda berikan di atas. Setelah garis x = self.pool(F.relu(self.conv2(x)))dalam fungsi maju, Anda akan memiliki peta fitur 16 kedalaman. Anda harus meratakan ini untuk memberikannya ke lapisan yang sepenuhnya terhubung. Jadi, Anda memberi tahu pytorch untuk membentuk kembali tensor yang Anda peroleh untuk memiliki jumlah kolom tertentu dan memerintahkannya untuk memutuskan jumlah baris dengan sendirinya.

Menggambar kesamaan antara numpy dan pytorch, viewmirip dengan fungsi membentuk kembali numpy .

Kashyap
sumber
93
"Pandangannya mirip dengan pembentukan kembali numpy" - mengapa mereka tidak menyebutnya saja reshapedi PyTorch ?!
MaxB
54
@ MaxB Tidak seperti membentuk kembali, tensor baru yang dikembalikan oleh "view" membagikan data yang mendasarinya dengan tensor asli, jadi itu benar-benar pandangan ke tensor lama daripada membuat yang baru.
qihqi
37
@blckbird "membentuk kembali selalu menyalin memori. lihat tidak pernah menyalin memori." github.com/torch/cutorch/issues/98
devinbost
3
@devinbost Torch membentuk kembali selalu menyalin memori. NumPy membentuk kembali tidak.
Tavian Barnes
32

Mari kita lakukan beberapa contoh, dari yang sederhana hingga yang lebih sulit.

  1. The viewMetode mengembalikan tensor dengan data yang sama dengan selftensor (yang berarti bahwa tensor kembali memiliki jumlah elemen yang sama), tapi dengan bentuk yang berbeda. Sebagai contoh:

    a = torch.arange(1, 17)  # a's shape is (16,)
    
    a.view(4, 4) # output below
      1   2   3   4
      5   6   7   8
      9  10  11  12
     13  14  15  16
    [torch.FloatTensor of size 4x4]
    
    a.view(2, 2, 4) # output below
    (0 ,.,.) = 
    1   2   3   4
    5   6   7   8
    
    (1 ,.,.) = 
     9  10  11  12
    13  14  15  16
    [torch.FloatTensor of size 2x2x4]
  2. Dengan asumsi itu -1bukan salah satu parameter, ketika Anda mengalikannya bersama-sama, hasilnya harus sama dengan jumlah elemen dalam tensor. Jika Anda melakukannya:, a.view(3, 3)ia akan menaikkan RuntimeErrorbentuk karena (3 x 3) tidak valid untuk input dengan 16 elemen. Dengan kata lain: 3 x 3 tidak sama dengan 16 tetapi 9.

  3. Anda dapat menggunakan -1sebagai salah satu parameter yang Anda berikan ke fungsi, tetapi hanya sekali. Yang terjadi hanyalah bahwa metode ini akan menghitung untuk Anda tentang cara mengisi dimensi itu. Misalnya a.view(2, -1, 4)setara dengan a.view(2, 2, 4). [16 / (2 x 4) = 2]

  4. Perhatikan bahwa tensor yang dikembalikan membagikan data yang sama . Jika Anda membuat perubahan dalam "tampilan" Anda mengubah data tensor asli:

    b = a.view(4, 4)
    b[0, 2] = 2
    a[2] == 3.0
    False
  5. Sekarang, untuk use case yang lebih kompleks. Dokumentasi mengatakan bahwa setiap dimensi tampilan baru harus berupa subruang dari dimensi asli, atau hanya rentang d, d + 1, ..., d + k yang memenuhi kondisi seperti-kedekatan berikut ini untuk semua i = 0,. .., k - 1, langkah [i] = langkah [i + 1] x ukuran [i + 1] . Kalau tidak, contiguous()perlu dipanggil sebelum tensor dapat dilihat. Sebagai contoh:

    a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2)
    a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4)
    
    # The commented line below will raise a RuntimeError, because one dimension
    # spans across two contiguous subspaces
    # a_t.view(-1, 4)
    
    # instead do:
    a_t.contiguous().view(-1, 4)
    
    # To see why the first one does not work and the second does,
    # compare a.stride() and a_t.stride()
    a.stride() # (24, 6, 2, 1)
    a_t.stride() # (24, 2, 1, 6)

    Perhatikan bahwa untuk a_t, langkah [0]! = Langkah [1] x ukuran [1] sejak 24! = 2 x 3

Jadiel de Armas
sumber
7

torch.Tensor.view()

Sederhananya, torch.Tensor.view()yang terinspirasi oleh numpy.ndarray.reshape()atau numpy.reshape(), menciptakan tampilan tensor baru, selama bentuk baru tersebut kompatibel dengan bentuk tensor asli.

Mari kita pahami ini secara detail dengan menggunakan contoh nyata.

In [43]: t = torch.arange(18) 

In [44]: t 
Out[44]: 
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])

Dengan tensor ini tbentuk (18,), baru dilihat dapat hanya dibuat untuk bentuk berikut:

(1, 18)atau ekuivalen (1, -1)atau atau setara atau atau ekuivalen atau atau ekuivalen atau atau ekuivalen atau atau ekuivalen atau(-1, 18)
(2, 9)(2, -1)(-1, 9)
(3, 6)(3, -1)(-1, 6)
(6, 3)(6, -1)(-1, 3)
(9, 2)(9, -1)(-1, 2)
(18, 1)(18, -1)(-1, 1)

Seperti yang sudah dapat kita amati dari bentuk tuple di atas, penggandaan elemen-elemen dari tuple bentuk (misalnya 2*9, 3*6dll.) Harus selalu sama dengan jumlah total elemen dalam tensor asli (18 dalam contoh kita).

Hal lain yang perlu diperhatikan adalah bahwa kami menggunakan a -1di salah satu tempat di masing-masing bentuk tuple. Dengan menggunakan -1, kita menjadi malas dalam melakukan perhitungan diri kita sendiri dan agak mendelegasikan tugas untuk PyTorch untuk melakukan perhitungan yang nilai untuk bentuk ketika ia menciptakan baru tampilan . Satu hal penting yang perlu diperhatikan adalah kita hanya bisa menggunakan satu -1tuple dalam bentuk. Nilai yang tersisa harus diberikan secara eksplisit oleh kami. Lain PyTorch akan mengeluh dengan melemparkan RuntimeError:

RuntimeError: hanya satu dimensi yang dapat disimpulkan

Jadi, dengan semua bentuk yang disebutkan di atas, PyTorch akan selalu mengembalikan tampilan baru dari tensor aslit . Ini pada dasarnya berarti bahwa itu hanya mengubah informasi langkah tensor untuk setiap tampilan baru yang diminta.

Di bawah ini adalah beberapa contoh yang menggambarkan bagaimana langkah tensor diubah dengan setiap tampilan baru .

# stride of our original tensor `t`
In [53]: t.stride() 
Out[53]: (1,)

Sekarang, kita akan melihat langkah untuk tampilan baru :

# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride() 
Out[55]: (18, 1)

# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()       
Out[57]: (9, 1)

# shape (3, 6)
In [59]: t3 = t.view(3, -1) 
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride() 
Out[60]: (6, 1)

# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride() 
Out[63]: (3, 1)

# shape (9, 2)
In [65]: t5 = t.view(9, -1) 
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)

# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)

Jadi itulah keajaiban view()fungsinya. Itu hanya mengubah langkah dari tensor (asli) untuk masing-masing tampilan baru , selama bentuk tampilan baru kompatibel dengan bentuk aslinya.

Hal lain yang satu yang menarik bisa mengamati dari tupel langkah adalah bahwa nilai dari elemen dalam 0 th posisi sama dengan nilai dari elemen dalam 1 st posisi tuple bentuk.

In [74]: t3.shape 
Out[74]: torch.Size([3, 6])
                        |
In [75]: t3.stride()    |
Out[75]: (6, 1)         |
          |_____________|

Hal ini karena:

In [76]: t3 
Out[76]: 
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17]])

langkahnya (6, 1)mengatakan bahwa untuk beralih dari satu elemen ke elemen berikutnya sepanjang dimensi ke- 0 , kita harus melompat atau mengambil 6 langkah. (yaitu untuk pergi dari 0ke 6, kita harus mengambil 6 langkah.) Tapi untuk pergi dari satu elemen ke elemen berikutnya dalam 1 st dimensi, kita hanya perlu satu langkah (untuk misalnya untuk pergi dari 2ke 3).

Dengan demikian, informasi langkah adalah inti dari bagaimana elemen diakses dari memori untuk melakukan perhitungan.


torch.reshape ()

Fungsi ini akan mengembalikan tampilan dan persis sama dengan menggunakan torch.Tensor.view()selama bentuk baru tersebut kompatibel dengan bentuk tensor asli. Jika tidak, itu akan mengembalikan salinan.

Namun, catatan torch.reshape()memperingatkan bahwa:

input dan input yang bersebelahan dengan langkah yang kompatibel dapat dibentuk kembali tanpa menyalin, tetapi orang tidak harus bergantung pada perilaku menyalin vs melihat.

kmario23
sumber
1

Saya menemukan itu x.view(-1, 16 * 5 * 5)setara dengan x.flatten(1), di mana parameter 1 menunjukkan proses perataan dimulai dari dimensi 1 (tidak meratakan dimensi 'sampel') Seperti yang Anda lihat, penggunaan yang terakhir secara semantik lebih jelas dan lebih mudah digunakan, jadi saya lebih suka flatten().

FENGSHI ZHENG
sumber
1

Apa arti dari parameter -1?

Anda dapat membaca -1sebagai jumlah parameter dinamis atau "apa pun". Karena itu hanya ada satu parameter -1dalamview() .

Jika Anda bertanya x.view(-1,1)ini akan menampilkan bentuk tensor [anything, 1]tergantung pada jumlah elemen dalam x. Sebagai contoh:

import torch
x = torch.tensor([1, 2, 3, 4])
print(x,x.shape)
print("...")
print(x.view(-1,1), x.view(-1,1).shape)
print(x.view(1,-1), x.view(1,-1).shape)

Akan menghasilkan:

tensor([1, 2, 3, 4]) torch.Size([4])
...
tensor([[1],
        [2],
        [3],
        [4]]) torch.Size([4, 1])
tensor([[1, 2, 3, 4]]) torch.Size([1, 4])
prosti
sumber
1

weights.reshape(a, b) akan mengembalikan tensor baru dengan data yang sama dengan bobot dengan ukuran (a, b) karena di dalamnya menyalin data ke bagian lain dari memori.

weights.resize_(a, b)mengembalikan tensor yang sama dengan bentuk yang berbeda. Namun, jika bentuk baru menghasilkan elemen lebih sedikit dari tensor asli, beberapa elemen akan dihapus dari tensor (tetapi tidak dari memori). Jika bentuk baru menghasilkan lebih banyak elemen daripada tensor asli, elemen baru akan diinisialisasi dalam memori.

weights.view(a, b) akan mengembalikan tensor baru dengan data yang sama dengan bobot dengan ukuran (a, b)

Jibin Mathew
sumber
0

Saya sangat menyukai contoh @Jadiel de Armas.

Saya ingin menambahkan wawasan kecil tentang bagaimana elemen dipesan untuk .view (...)

  • Untuk Tensor dengan bentuk (a, b, c) , yang urutan elemen itu ditentukan oleh sistem penomoran: dimana digit pertama memiliki sebuah nomor, digit kedua memiliki b angka dan digit ketiga memiliki c angka.
  • Pemetaan elemen dalam Tensor baru yang dikembalikan oleh .view (...) mempertahankan urutan Tensor asli ini.
ychnh
sumber
0

Mari kita coba memahami tampilan dengan contoh-contoh berikut:

    a=torch.range(1,16)

print(a)

    tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
            15., 16.])

print(a.view(-1,2))

    tensor([[ 1.,  2.],
            [ 3.,  4.],
            [ 5.,  6.],
            [ 7.,  8.],
            [ 9., 10.],
            [11., 12.],
            [13., 14.],
            [15., 16.]])

print(a.view(2,-1,4))   #3d tensor

    tensor([[[ 1.,  2.,  3.,  4.],
             [ 5.,  6.,  7.,  8.]],

            [[ 9., 10., 11., 12.],
             [13., 14., 15., 16.]]])
print(a.view(2,-1,2))

    tensor([[[ 1.,  2.],
             [ 3.,  4.],
             [ 5.,  6.],
             [ 7.,  8.]],

            [[ 9., 10.],
             [11., 12.],
             [13., 14.],
             [15., 16.]]])

print(a.view(4,-1,2))

    tensor([[[ 1.,  2.],
             [ 3.,  4.]],

            [[ 5.,  6.],
             [ 7.,  8.]],

            [[ 9., 10.],
             [11., 12.]],

            [[13., 14.],
             [15., 16.]]])

-1 sebagai nilai argumen adalah cara mudah untuk menghitung nilai katakan x asalkan kita tahu nilai y, z atau sebaliknya dalam kasus 3d dan untuk 2d lagi cara mudah untuk menghitung nilai katakan x asalkan kita mengetahui nilai-nilai y atau sebaliknya ..

Lija Alex
sumber