Testen Sie, ob das Numpy-Array nur Nullen enthält


92

Wir initialisieren ein Numpy-Array mit Nullen wie folgt:

np.zeros((N,N+1))

Aber wie überprüfen wir, ob alle Elemente in einer gegebenen n * n numpy-Array-Matrix Null sind?
Die Methode muss nur ein True zurückgeben, wenn alle Werte tatsächlich Null sind.

Antworten:



160

Die anderen hier veröffentlichten Antworten funktionieren, aber die klarste und effizienteste Funktion ist numpy.any():

>>> all_zeros = not np.any(a)

oder

>>> all_zeros = not a.any()
  • Dies wird gegenüber dem numpy.all(a==0)RAM- Vorzug bevorzugt . (Das durch den a==0Begriff erstellte temporäre Array ist nicht erforderlich .)
  • Außerdem ist es schneller als numpy.count_nonzero(a)weil es sofort zurückkehren kann, wenn das erste Element ungleich Null gefunden wurde.
    • Bearbeiten: Wie @Rachel in den Kommentaren hervorhob, wird np.any()keine "Kurzschluss" -Logik mehr verwendet, sodass Sie für kleine Arrays keinen Geschwindigkeitsvorteil sehen.

2
Ab einer Minute vor, numpy ist anyund alltun nicht kurzschließen. Ich glaube, sie sind Zucker für logical_or.reduceund logical_and.reduce. Vergleichen Sie miteinander und meinen Kurzschluss is_in: all_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel

2
Das ist ein großartiger Punkt, danke. Es sieht aus wie ein Kurzschluss verwendet , das Verhalten zu sein, aber das irgendwann verloren. In den Antworten auf diese Frage gibt es einige interessante Diskussionen .
Stuart Berg

50

Ich würde hier np.all verwenden, wenn Sie ein Array haben a:

>>> np.all(a==0)

3
Ich finde es gut, dass diese Antwort auch nach Werten ungleich Null sucht. Zum Beispiel kann man überprüfen, ob alle Elemente in einem Array gleich sind np.all(a==a[0]). Vielen Dank!
Aignas

9

Wie eine andere Antwort sagt, können Sie Wahrheits- / Falschbewertungen nutzen, wenn Sie wissen, dass dies 0das einzige Falschelement ist, das möglicherweise in Ihrem Array enthalten ist. Alle Elemente in einem Array sind falsch, wenn es keine wahrheitsgemäßen Elemente enthält. *

>>> a = np.zeros(10)
>>> not np.any(a)
True

In der Antwort wurde jedoch behauptet, dass dies anyteilweise aufgrund von Kurzschlüssen schneller sei als bei anderen Optionen. Ab 2018 Numpy's allund any nicht kurzschließen .

Wenn Sie so etwas oft tun, ist es sehr einfach, Ihre eigenen Kurzschlussversionen zu erstellen, indem Sie numba:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

Diese sind in der Regel schneller als die Versionen von Numpy, auch wenn sie nicht kurzgeschlossen werden. count_nonzeroist das langsamste.

Einige Eingaben zur Überprüfung der Leistung:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

Prüfen:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* Hilfreich allund anyÄquivalenzen:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))

-9

Wenn Sie auf alle Nullen testen, um eine Warnung bei einer anderen Numpy-Funktion zu vermeiden, wird die Zeile in einen Versuch eingeschlossen, mit Ausnahme des Blocks, der es erspart, den Test für Nullen vor der Operation durchzuführen, an der Sie interessiert sind, d. H.

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.