Memahami einsum NumPy

192

Saya berjuang untuk mengerti bagaimana tepatnya einsumbekerja. Saya telah melihat dokumentasi dan beberapa contoh, tetapi sepertinya tidak melekat.

Inilah contoh yang kami pelajari di kelas:

C = np.einsum("ij,jk->ki", A, B)

untuk dua array AdanB

Saya pikir ini akan memakan waktu A^T * B, tapi saya tidak yakin (itu mengambil alih salah satu dari mereka, kan?). Adakah yang bisa menuntun saya melalui apa yang terjadi di sini (dan secara umum saat menggunakan einsum)?

Selat Lance
sumber
7
Sebenarnya itu akan menjadi (A * B)^T, atau setara B^T * A^T.
Tigran Saluev
23
Saya menulis posting blog pendek tentang dasar-dasar di einsum sini . (Saya senang mentransplantasikan bit yang paling relevan ke jawaban di Stack Overflow jika berguna).
Alex Riley
1
@ajcr - Tautan yang indah. Terima kasih. The numpydokumentasi tidak memadai ketika menjelaskan rincian.
rayryeng
Terima kasih atas mosi percaya! Terlambat, saya telah berkontribusi jawaban di bawah ini .
Alex Riley
Perhatikan bahwa dalam Python *bukan perkalian matriks tetapi perkalian elemen. Awas!
ComputerScientist

Jawaban:

373

(Catatan: jawaban ini didasarkan pada posting blog pendek tentang yang einsumsaya tulis beberapa waktu lalu.)

Apa yang einsumharus dilakukan

Bayangkan kita memiliki dua array multi-dimensi, Adan B. Sekarang mari kita anggap kita ingin ...

  • berkembang biak A dengan Bcara tertentu untuk menciptakan berbagai produk baru; dan mungkin
  • jumlah array baru ini sepanjang sumbu tertentu; dan mungkin
  • mengubah posisi sumbu array baru dalam urutan tertentu.

Ada peluang bagus yang einsumakan membantu kami melakukan ini lebih cepat dan lebih efisien-memori yang disukai oleh kombinasi fungsi NumPy multiply, sumdan transposeakan memungkinkan.

Bagaimana cara einsumkerjanya?

Berikut adalah contoh sederhana (tetapi tidak sepenuhnya sepele). Ambil dua array berikut:

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

Kami akan mengalikan Adan Belemen-bijaksana dan kemudian menjumlahkan sepanjang baris array baru. Dalam NumPy "normal" kami akan menulis:

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

Jadi di sini, operasi pengindeksan pada Abaris atas sumbu pertama dari dua array sehingga multiplikasi dapat disiarkan. Baris-baris array produk kemudian dijumlahkan untuk mengembalikan jawabannya.

Sekarang jika kita ingin menggunakan einsum, kita bisa menulis:

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

The signature String 'i,ij->i'adalah kunci di sini dan membutuhkan sedikit menjelaskan. Anda bisa memikirkannya dalam dua bagian. Di sisi kiri (kiri ->) kami telah memberi label pada dua array input. Di sebelah kanan ->, kami memberi label array yang ingin kami selesaikan.

Inilah yang terjadi selanjutnya:

  • Amemiliki satu sumbu; kami telah melabeli itu i. Dan Bmemiliki dua sumbu; kami telah memberi label sumbu 0 sebagai idan sumbu 1 sebagai j.

  • Dengan mengulangi label idi kedua larik input, kami memberi tahu einsumbahwa kedua sumbu ini harus dikalikan bersama. Dengan kata lain, kami mengalikan array Adengan setiap kolom array B, seperti A[:, np.newaxis] * Bhalnya.

  • Pemberitahuan yang jtidak muncul sebagai label pada output yang diinginkan; kami baru saja menggunakan i(kami ingin berakhir dengan array 1D). Dengan menghilangkan label, kami menyuruh einsumuntuk menjumlahkan sepanjang sumbu ini. Dengan kata lain, kami menjumlahkan deretan produk, seperti .sum(axis=1)halnya.

Pada dasarnya itu semua yang perlu Anda ketahui untuk digunakan einsum. Ini membantu untuk bermain sedikit; jika kita membiarkan kedua label di output 'i,ij->ij',, kita mendapatkan kembali array produk 2D (sama seperti A[:, np.newaxis] * B). Jika kami mengatakan tidak ada label keluaran 'i,ij->,, kami mendapatkan kembali satu nomor (sama dengan melakukan (A[:, np.newaxis] * B).sum()).

Hal yang hebat tentang einsumitu, adalah tidak membangun array produk sementara terlebih dahulu; itu hanya menjumlahkan produk saat berjalan. Ini dapat menyebabkan penghematan besar dalam penggunaan memori.

Contoh yang sedikit lebih besar

Untuk menjelaskan produk titik, berikut adalah dua array baru:

A = array([[1, 1, 1],
           [2, 2, 2],
           [5, 5, 5]])

B = array([[0, 1, 0],
           [1, 1, 0],
           [1, 1, 1]])

Kami akan menghitung penggunaan produk titik np.einsum('ij,jk->ik', A, B). Berikut adalah gambar yang menunjukkan pelabelan Adan Bdan larik keluaran yang kita dapatkan dari fungsi:

masukkan deskripsi gambar di sini

Anda dapat melihat label jdiulangi - ini berarti kami mengalikan baris Adengan kolom B. Selanjutnya, label jtidak termasuk dalam output - kami menjumlahkan produk ini. Label idan kdisimpan untuk output, jadi kami kembali array 2D.

Mungkin lebih jelas untuk membandingkan hasil ini dengan array di mana label jini tidak dijumlahkan. Di bawah, di sebelah kiri Anda dapat melihat larik 3D yang dihasilkan dari penulisan np.einsum('ij,jk->ijk', A, B)(yaitu, kami telah menyimpan label j):

masukkan deskripsi gambar di sini

Sumbu penjumlahan jmemberikan produk titik yang diharapkan, yang ditunjukkan di sebelah kanan.

Beberapa latihan

Untuk mendapatkan lebih banyak rasa einsum, dapat berguna untuk mengimplementasikan operasi array NumPy yang sudah dikenal menggunakan notasi subskrip. Apa pun yang melibatkan kombinasi mengalikan dan menjumlahkan sumbu dapat ditulis menggunakan einsum.

Misalkan A dan B menjadi dua array 1D dengan panjang yang sama. Sebagai contoh, A = np.arange(10)dan B = np.arange(5, 15).

  • Jumlahnya Adapat ditulis:

    np.einsum('i->', A)
  • Perkalian elemen-bijaksana A * B,, dapat ditulis:

    np.einsum('i,i->i', A, B)
  • Produk dalam atau produk titik, np.inner(A, B)atau np.dot(A, B), dapat ditulis:

    np.einsum('i,i->', A, B) # or just use 'i,i'
  • Produk luar np.outer(A, B),, dapat ditulis:

    np.einsum('i,j->ij', A, B)

Untuk array 2D, Cdan D, asalkan sumbu adalah panjang yang kompatibel (baik panjang yang sama atau salah satunya memiliki panjang 1), berikut adalah beberapa contoh:

  • Jejak C(jumlah diagonal utama) np.trace(C),, dapat ditulis:

    np.einsum('ii', C)
  • Perkalian elemen-bijaksana Cdan transpos dari D, C * D.T, dapat ditulis:

    np.einsum('ij,ji->ij', C, D)
  • Mengalikan setiap elemen Cdengan array D(untuk membuat array 4D) C[:, :, None, None] * D,, dapat ditulis:

    np.einsum('ij,kl->ijkl', C, D)  
Alex Riley
sumber
1
Penjelasan yang sangat bagus, terima kasih. "Perhatikan bahwa saya tidak muncul sebagai label dalam output yang diinginkan" - bukan?
Ian Hincks
Terima kasih @IanHincks! Itu terlihat seperti kesalahan ketik; Saya sudah memperbaikinya sekarang.
Alex Riley
1
Jawaban yang sangat bagus Ini juga patut dicatat yang ij,jkdapat bekerja dengan sendirinya (tanpa panah) untuk membentuk perkalian matriks. Tapi sepertinya untuk kejelasan yang terbaik adalah menempatkan panah dan kemudian dimensi output. Ada dalam posting blog.
ComputerScientist
1
@Peaceful: ini adalah salah satu kesempatan di mana sulit untuk memilih kata yang tepat! Saya merasa "kolom" cocok sedikit lebih baik di sini karena Apanjangnya 3, sama dengan panjang kolom di B(sedangkan baris Bmemiliki panjang 4 dan tidak dapat dikalikan dengan elemen A).
Alex Riley
1
Perhatikan bahwa menghilangkan ->semantik mempengaruhi: "Dalam mode implisit, subskrip yang dipilih penting karena sumbu output disusun ulang sesuai abjad. Ini berarti bahwa np.einsum('ij', a)tidak mempengaruhi array 2D, sambil np.einsum('ji', a)mengambil transposinya."
BallpointBen
41

Memahami ide numpy.einsum()sangat mudah jika Anda memahaminya secara intuitif. Sebagai contoh, mari kita mulai dengan deskripsi sederhana yang melibatkan perkalian matriks .


Untuk menggunakannya numpy.einsum(), yang harus Anda lakukan adalah meneruskan string subkrip yang disebut sebagai argumen, diikuti oleh array input Anda .

Katakanlah Anda memiliki dua array 2D, Adan B, dan Anda ingin melakukan perkalian matriks. Jadi, Anda lakukan:

np.einsum("ij, jk -> ik", A, B)

Di sini string subskrip ij berhubungan dengan array Asedangkan string subskrip jk berhubungan dengan array B. Juga, hal yang paling penting untuk dicatat di sini adalah bahwa jumlah karakter dalam setiap string subskrip harus sesuai dengan dimensi array. (yaitu dua karakter untuk array 2D, tiga karakter untuk array 3D, dan sebagainya.) Dan jika Anda mengulangi karakter di antara string subskrip ( jdalam kasus kami), maka itu berarti Anda ingin einjumlah terjadi di sepanjang dimensi tersebut. Dengan demikian, jumlah tersebut akan dikurangi. (Yaitu dimensi itu akan hilang )

The String subscript setelah ini ->, akan array yang dihasilkan kami. Jika Anda membiarkannya kosong, maka semuanya akan dijumlahkan dan nilai skalar dikembalikan sebagai hasilnya. Lain array yang dihasilkan akan memiliki dimensi sesuai dengan string subskrip . Dalam contoh kita, itu akan menjadi ik. Ini intuitif karena kita tahu bahwa untuk perkalian matriks jumlah kolom dalam array Aharus cocok dengan jumlah baris dalam array Byang merupakan apa yang terjadi di sini (yaitu kita menyandikan pengetahuan ini dengan mengulangi char jdalam string subskrip )


Berikut adalah beberapa contoh yang menggambarkan penggunaan / kekuatan np.einsum()dalam mengimplementasikan beberapa operasi tensor atau nd-array yang umum , secara ringkas.

Input

# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])

# an array
In [198]: A
Out[198]: 
array([[11, 12, 13, 14],
       [21, 22, 23, 24],
       [31, 32, 33, 34],
       [41, 42, 43, 44]])

# another array
In [199]: B
Out[199]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

1) Perkalian matriks (mirip dengan np.matmul(arr1, arr2))

In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]: 
array([[130, 130, 130, 130],
       [230, 230, 230, 230],
       [330, 330, 330, 330],
       [430, 430, 430, 430]])

2) Ekstrak elemen di sepanjang main-diagonal (mirip dengan np.diag(arr))

In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])

3) Produk Hadamard (yaitu produk elemen-bijaksana dari dua array) (mirip dengan arr1 * arr2)

In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]: 
array([[ 11,  12,  13,  14],
       [ 42,  44,  46,  48],
       [ 93,  96,  99, 102],
       [164, 168, 172, 176]])

4) Elemen-bijaksana kuadrat (mirip dengan np.square(arr)atau arr ** 2)

In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]: 
array([[ 1,  1,  1,  1],
       [ 4,  4,  4,  4],
       [ 9,  9,  9,  9],
       [16, 16, 16, 16]])

5) Jejak (yaitu jumlah elemen main-diagonal) (mirip dengan np.trace(arr))

In [217]: np.einsum("ii -> ", A)
Out[217]: 110

6) Matriks transpose (mirip dengan np.transpose(arr))

In [221]: np.einsum("ij -> ji", A)
Out[221]: 
array([[11, 21, 31, 41],
       [12, 22, 32, 42],
       [13, 23, 33, 43],
       [14, 24, 34, 44]])

7) Produk Luar (dari vektor) (mirip dengan np.outer(vec1, vec2))

In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]: 
array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]])

8) Produk Dalam (dari vektor) (mirip dengan np.inner(vec1, vec2))

In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14

9) Jumlah sepanjang sumbu 0 (mirip dengan np.sum(arr, axis=0))

In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])

10) Jumlahkan sepanjang sumbu 1 (mirip dengan np.sum(arr, axis=1))

In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4,  8, 12, 16])

11) Penggandaan Matriks Batch

In [287]: BM = np.stack((A, B), axis=0)

In [288]: BM
Out[288]: 
array([[[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]],

       [[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3],
        [ 4,  4,  4,  4]]])

In [289]: BM.shape
Out[289]: (2, 4, 4)

# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)

In [293]: BMM
Out[293]: 
array([[[1350, 1400, 1450, 1500],
        [2390, 2480, 2570, 2660],
        [3430, 3560, 3690, 3820],
        [4470, 4640, 4810, 4980]],

       [[  10,   10,   10,   10],
        [  20,   20,   20,   20],
        [  30,   30,   30,   30],
        [  40,   40,   40,   40]]])

In [294]: BMM.shape
Out[294]: (2, 4, 4)

12) Jumlah sepanjang sumbu 2 (mirip dengan np.sum(arr, axis=2))

In [330]: np.einsum("ijk -> ij", BM)
Out[330]: 
array([[ 50,  90, 130, 170],
       [  4,   8,  12,  16]])

13) Jumlahkan semua elemen dalam array (mirip dengan np.sum(arr))

In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480

14) Jumlah lebih dari beberapa sumbu (yaitu marginalisasi)
(mirip dengan np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7)))

# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))

# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))

In [365]: np.allclose(esum, nsum)
Out[365]: True

15) Produk Dot Ganda (mirip dengan np.sum (produk hadamard) lih. 3 )

In [772]: A
Out[772]: 
array([[1, 2, 3],
       [4, 2, 2],
       [2, 3, 4]])

In [773]: B
Out[773]: 
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124

16) penggandaan array 2D dan 3D

Penggandaan seperti itu bisa sangat berguna ketika menyelesaikan sistem persamaan linear ( Ax = b ) di mana Anda ingin memverifikasi hasilnya.

# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)

# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)

# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)

# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True

Sebaliknya, jika seseorang harus menggunakan np.matmul()verifikasi ini, kita harus melakukan beberapa reshapeoperasi untuk mencapai hasil yang sama seperti:

# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)

# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True

Bonus : Baca lebih banyak matematika di sini: Einstein-Summation dan pasti di sini: Tensor-Notation

kmario23
sumber
7

Mari kita membuat 2 array, dengan dimensi yang berbeda, tetapi kompatibel untuk menyoroti interaksi mereka

In [43]: A=np.arange(6).reshape(2,3)
Out[43]: 
array([[0, 1, 2],
       [3, 4, 5]])


In [44]: B=np.arange(12).reshape(3,4)
Out[44]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

Perhitungan Anda, mengambil 'titik' (jumlah produk) dari (2,3) dengan (3,4) untuk menghasilkan array (4,2). iadalah redup pertama A, yang terakhir C; kyang terakhir B, tanggal 1 C. j'dikonsumsi' oleh penjumlahan.

In [45]: C=np.einsum('ij,jk->ki',A,B)
Out[45]: 
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])

Ini sama dengan np.dot(A,B).T- ini adalah hasil akhir yang ditransformasikan.

Untuk melihat lebih banyak tentang apa yang terjadi j, ubah Clangganan ke ijk:

In [46]: np.einsum('ij,jk->ijk',A,B)
Out[46]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

Ini juga dapat diproduksi dengan:

A[:,:,None]*B[None,:,:]

Yaitu, tambahkan kdimensi ke akhir A, dan a ike depan B, menghasilkan array (2,3,4).

0 + 4 + 16 = 20,, 9 + 28 + 55 = 92dll; Jumlahkan jdan transpos untuk mendapatkan hasil sebelumnya:

np.sum(A[:,:,None] * B[None,:,:], axis=1).T

# C[k,i] = sum(j) A[i,j (,k) ] * B[(i,)  j,k]
hpaulj
sumber
7

Saya menemukan NumPy: Trik perdagangan (Bagian II) instruktif

Kami menggunakan -> untuk menunjukkan urutan array output. Jadi pikirkan 'ij, i-> j' sebagai memiliki sisi kiri (LHS) dan sisi kanan (RHS). Setiap pengulangan label pada LHS menghitung elemen produk dengan bijak dan kemudian menyimpulkannya. Dengan mengubah label pada sisi RHS (output), kita dapat menentukan sumbu di mana kita ingin melanjutkan berkenaan dengan array input, yaitu penjumlahan sepanjang sumbu 0, 1 dan seterusnya.

import numpy as np

>>> a
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
>>> b
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> d = np.einsum('ij, jk->ki', a, b)

Perhatikan ada tiga sumbu, i, j, k, dan j yang diulang (di sisi kiri). i,jmewakili baris dan kolom untuk a. j,kuntuk b.

Untuk menghitung produk dan menyelaraskan jsumbu, kita perlu menambahkan sumbu a. ( bakan disiarkan di sepanjang (?) Sumbu pertama)

a[i, j, k]
   b[j, k]

>>> c = a[:,:,np.newaxis] * b
>>> c
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16]],

       [[ 0,  3,  6],
        [ 9, 12, 15],
        [18, 21, 24]]])

jtidak ada di sisi kanan sehingga kami menjumlahkan jyang merupakan sumbu kedua dari array 3x3x3

>>> c = c.sum(1)
>>> c
array([[ 9, 12, 15],
       [18, 24, 30],
       [27, 36, 45]])

Akhirnya, indeks (alfabet) terbalik di sisi kanan sehingga kami transpos.

>>> c.T
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])

>>> np.einsum('ij, jk->ki', a, b)
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])
>>>
wwii
sumber
NumPy: Trik perdagangan (Bagian II) tampaknya memerlukan undangan dari pemilik situs serta akun Wordpress
Tejas Shetty
... tautan yang diperbarui, untungnya saya menemukannya dengan pencarian. - Thnx.
wwii
@TejasShetty Banyak jawaban yang lebih baik di sini sekarang - mungkin saya harus menghapus yang ini.
wwii
2
Tolong jangan hapus jawaban Anda.
Tejas Shetty
5

Saat membaca persamaan einsum, saya merasa paling membantu jika hanya mampu merebusnya hingga ke versi imperatifnya.

Mari kita mulai dengan pernyataan berikut (mengesankan):

C = np.einsum('bhwi,bhwj->bij', A, B)

Bekerja melalui tanda baca pertama-tama kita melihat bahwa kita memiliki dua gumpalan yang dipisahkan koma 4 huruf - bhwidan bhwj, sebelum panah, dan gumpalan 3 huruf tunggal bijsetelahnya. Oleh karena itu, persamaan menghasilkan hasil tensor peringkat-3 dari dua input tensor peringkat-4.

Sekarang, biarkan setiap huruf di setiap gumpalan menjadi nama variabel rentang. Posisi di mana huruf muncul di gumpalan adalah indeks sumbu yang berkisar di tensor itu. Penjumlahan imperatif yang menghasilkan setiap elemen C, oleh karena itu, harus dimulai dengan tiga bersarang untuk loop, satu untuk setiap indeks C.

for b in range(...):
    for i in range(...):
        for j in range(...):
            # the variables b, i and j index C in the order of their appearance in the equation
            C[b, i, j] = ...

Jadi, pada dasarnya, Anda memiliki forlingkaran untuk setiap indeks output C. Kami akan membiarkan rentang tidak ditentukan untuk saat ini.

Selanjutnya kita melihat sisi kiri - apakah ada variabel rentang di sana yang tidak muncul di sisi kanan ? Dalam kasus kami - ya, hdan w. Tambahkan forloop bersarang bagian dalam untuk setiap variabel seperti:

for b in range(...):
    for i in range(...):
        for j in range(...):
            C[b, i, j] = 0
            for h in range(...):
                for w in range(...):
                    ...

Di dalam loop terdalam kita sekarang memiliki semua indeks yang ditentukan, sehingga kita dapat menulis penjumlahan yang sebenarnya dan terjemahan selesai:

# three nested for-loops that index the elements of C
for b in range(...):
    for i in range(...):
        for j in range(...):

            # prepare to sum
            C[b, i, j] = 0

            # two nested for-loops for the two indexes that don't appear on the right-hand side
            for h in range(...):
                for w in range(...):
                    # Sum! Compare the statement below with the original einsum formula
                    # 'bhwi,bhwj->bij'

                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]

Jika Anda telah dapat mengikuti kode sejauh ini, selamat! Ini yang Anda butuhkan untuk dapat membaca persamaan einsum. Perhatikan khususnya bagaimana rumus einsum asli memetakan ke pernyataan penjumlahan akhir dalam cuplikan di atas. For-loop dan batas jangkauan hanya bulu dan pernyataan akhir adalah yang Anda benar-benar perlu memahami apa yang terjadi.

Demi kelengkapan, mari kita lihat cara menentukan rentang untuk setiap variabel rentang. Nah, rentang masing-masing variabel hanyalah panjang dimensi yang diindeks. Jelas, jika suatu variabel mengindeks lebih dari satu dimensi dalam satu atau lebih tensor, maka panjang masing-masing dimensi tersebut harus sama. Berikut kode di atas dengan rentang lengkap:

# C's shape is determined by the shapes of the inputs
# b indexes both A and B, so its range can come from either A.shape or B.shape
# i indexes only A, so its range can only come from A.shape, the same is true for j and B
assert A.shape[0] == B.shape[0]
assert A.shape[1] == B.shape[1]
assert A.shape[2] == B.shape[2]
C = np.zeros((A.shape[0], A.shape[3], B.shape[3]))
for b in range(A.shape[0]): # b indexes both A and B, or B.shape[0], which must be the same
    for i in range(A.shape[3]):
        for j in range(B.shape[3]):
            # h and w can come from either A or B
            for h in range(A.shape[1]):
                for w in range(A.shape[2]):
                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
Stefan Dragnev
sumber
0

Saya pikir contoh paling sederhana adalah dalam dokumen tensorflow

Ada empat langkah untuk mengubah persamaan Anda menjadi notasi einsum. Mari kita ambil persamaan ini sebagai contohC[i,k] = sum_j A[i,j] * B[j,k]

  1. Pertama kita menjatuhkan nama variabel. Kita mendapatkanik = sum_j ij * jk
  2. Kami membatalkan sum_jistilah sebagaimana adanya. Kita mendapatkanik = ij * jk
  3. Kami ganti *dengan ,. Kita mendapatkanik = ij, jk
  4. Outputnya ada di RHS dan dipisahkan dengan ->tanda. Kita mendapatkanij, jk -> ik

Einsum interpreter hanya menjalankan 4 langkah ini secara terbalik. Semua indeks yang hilang dalam hasil dijumlahkan.

Berikut adalah beberapa contoh dari dokumen

# Matrix multiplication
einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

# Dot product
einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

# Outer product
einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

# Transpose
einsum('ij->ji', m)  # output[j,i] = m[i,j]

# Trace
einsum('ii', m)  # output[j,i] = trace(m) = sum_i m[i, i]

# Batch matrix multiplication
einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
Souradeep Nanda
sumber