Uji apakah larik numpy hanya berisi nol

93

Kami menginisialisasi array numpy dengan nol seperti di bawah ini:

np.zeros((N,N+1))

Tetapi bagaimana kita memeriksa apakah semua elemen dalam matriks array n * n numpy tertentu adalah nol.
Metode ini hanya perlu mengembalikan True jika semua nilainya benar-benar nol.

IUnknown
sumber

Jawaban:

73

Lihat numpy.count_nonzero .

>>> np.count_nonzero(np.eye(4))
4
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
5
Kumar Prashant
sumber
9
Anda hanya ingin not np.count_nonzero(np.eye(4))mengembalikan Truejika semua nilainya 0.
J. Martinot-Lagarde
166

Jawaban lain yang diposting di sini akan berfungsi, tetapi fungsi paling jelas dan paling efisien untuk digunakan adalah numpy.any():

>>> all_zeros = not np.any(a)

atau

>>> all_zeros = not a.any()
  • Ini lebih disukai numpy.all(a==0)karena menggunakan lebih sedikit RAM. (Ini tidak memerlukan array sementara yang dibuat oleh a==0istilah.)
  • Selain itu, lebih cepat daripada numpy.count_nonzero(a)karena dapat segera kembali ketika elemen bukan nol pertama telah ditemukan.
    • Edit: Seperti yang @Rachel tunjukkan di komentar, np.any()tidak lagi menggunakan logika "sirkuit pendek", jadi Anda tidak akan melihat manfaat kecepatan untuk array kecil.
Stuart Berg
sumber
3
Pada menit lalu, numpy ini anydan allmelakukan tidak pendek-sirkuit. Saya percaya mereka adalah gula untuk logical_or.reducedan logical_and.reduce. Bandingkan satu sama lain dan korsleting saya is_in: all_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel
3
Itu poin yang bagus, terima kasih. Sepertinya dulu perilaku korsleting adalah perilaku, tetapi itu hilang di beberapa titik. Ada beberapa pembahasan menarik dalam menjawab pertanyaan ini .
Stuart Berg
50

Saya akan menggunakan np.all di sini, jika Anda memiliki array a:

>>> np.all(a==0)
J. Martinot-Lagarde
sumber
3
Saya suka bahwa jawaban ini memeriksa nilai bukan nol juga. Misalnya, seseorang dapat memeriksa apakah semua elemen dalam array sama dengan melakukan np.all(a==a[0]). Terima kasih banyak!
aignas
9

Seperti jawaban lain mengatakan, Anda dapat memanfaatkan evaluasi benar / salah jika Anda tahu bahwa itu 0adalah satu-satunya elemen yang mungkin salah dalam array Anda. Semua elemen dalam array salah jika tidak ada elemen yang benar di dalamnya. *

>>> a = np.zeros(10)
>>> not np.any(a)
True

Namun, jawabannya mengklaim bahwa anylebih cepat daripada opsi lain sebagian karena korsleting. Pada 2018, Numpy alldan any tidak mengalami hubungan arus pendek .

Jika Anda sering melakukan hal semacam ini, sangat mudah untuk membuat versi hubung singkat Anda sendiri menggunakan numba:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

Ini cenderung lebih cepat daripada versi Numpy bahkan saat tidak terjadi hubungan arus pendek. count_nonzeroadalah yang paling lambat.

Beberapa masukan untuk memeriksa kinerja:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

Memeriksa:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* Bermanfaat alldan anykesetaraan:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))
Rachel
sumber
-8

Jika Anda menguji semua nol untuk menghindari peringatan pada fungsi numpy lain kemudian membungkus garis dalam percobaan, kecuali blok akan menghemat keharusan melakukan pengujian untuk nol sebelum operasi yang Anda minati yaitu

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
ReaddyEddy
sumber