Main

Retour vers le tutoriel complet

np.where et options

Rechercher tous les éléments d'un vecteur correspondant à une clause

a = np.random.randn(100,2)

# element de la premiere colonne < 0.5
index = np.where(a<0.5) # retourne les indices dans (I,J)
                        # I: indice des lignes
                        # J: indice des colonnes

# recherche dans une colonne
index2 = np.where(a[:,0]<0.5)

Attention au type de retour

Le type de index est sans surprise... Mais celui de index2 est plus déroutant: il s'agit d'un tuple mais avec un seul champ rempli...

Pour utiliser les indices extraits facilement, il faut donc faire:

index2, = np.where(a[:,0]<0.5) # on ne s'intéresse qu'au premier membre!
a[index2,:] = ...

Transformation de matrice

La fonction np.where est très utile pour transformer les matrices

a = np.random.randn(100,2)
# Mettre à zeros tous les éléments négatifs:
b = np.where(a<0., 0., a) # (clause, TODO if true, TODO if false)
# Extraire le signe des éléments de a
c = np.where(a<0., -1., 1.)
 

Double clause dans np.where [piège]

Attention aux parenthèses dans les doubles clauses np.where: la priorité des opérations rend certaines parenthèses obligatoires

# pour l'estimation d'une loi jointe entre a et b
N = 100
a = np.ceil(np.random.rand(N) * 10) # entre 1 et 10
b = np.round(np.random.rand(N)) # 0 ou 1
np.where((a == 4) & (b==0), 1., 0.) OK
np.where( a == 4  &  b==0 , 1., 0.) KO !!!

Autres syntaxes & cas d'étude

Comment afficher un jeu de données bi-classes en 2 couleurs???

Solution 1 (avec le np.where classique)

import numpy as np
import numpy.random as rnd
import matplotlib.pyplot as plt

# génération des points de la classe 1 & 2
N=100
x = np.vstack((rnd.randn(N,2)+2,rnd.randn(N,2)-2)) # données 2D
y = np.ones(2*N) # étiquettes
y[:N] = -1

# comment afficher chaque classe d'une couleur???

# solution 1
ind1 = np.where(y==1)
ind2 = np.where(y==-1)
plt.figure()
plt.plot(x[ind1, 0],x[ind1, 1], 'b+') # aff en croix bleues
plt.plot(x[ind2, 0],x[ind2, 1], 'r*') # aff en étoiles rouges

Solution 2 (avec le np.where implicite)

# comment afficher chaque classe d'une couleur???

# solution 2
plt.figure()
plt.plot(x[y==1, 0],x[y==1, 1], 'b+') # aff en croix bleues
plt.plot(x[y==-1, 0],x[y==-1, 1], 'r*') # aff en étoiles rouges

Solution 3, encore plus malin

Scatter permet la définition d'un style par point... On utilise y directement dans l'affichage.

# comment afficher chaque classe d'une couleur???

# solution 3
plt.figure()
plt.scatter(x[:,0], x[:,1], c=y)

Applications

Verification des propriétés de la loi normale

Générer un vecteur contenant 1000 éléments tirés selon la loi normale

  1. vérifier que la loi est centrée en 0 (à peu près autant d'éléments >0 et <0)
  2. vérifier que 2/3 des éléments sont entre {$-\sigma$} et {$-\sigma$}