np.where
et options
Rechercher tous les éléments d'un vecteur correspondant à une clause
a = np.random.randn(100,2)
# element de la premiere colonne < 0.5
index = np.where(a<0.5) # retourne les indices dans (I,J)
# I: indice des lignes
# J: indice des colonnes
# recherche dans une colonne
index2 = np.where(a[:,0]<0.5)
Attention au type de retour
Le type de index
est sans surprise... Mais celui de index2
est plus déroutant: il s'agit d'un tuple mais avec un seul champ rempli...
Pour utiliser les indices extraits facilement, il faut donc faire:
index2, = np.where(a[:,0]<0.5) # on ne s'intéresse qu'au premier membre!
a[index2,:] = ...
Transformation de matrice
La fonction np.where
est très utile pour transformer les matrices
a = np.random.randn(100,2)
# Mettre à zeros tous les éléments négatifs:
b = np.where(a<0., 0., a) # (clause, TODO if true, TODO if false)
# Extraire le signe des éléments de a
c = np.where(a<0., -1., 1.)
Double clause dans np.where
[piège]
Attention aux parenthèses dans les doubles clauses np.where: la priorité des opérations rend certaines parenthèses obligatoires
# pour l'estimation d'une loi jointe entre a et b
N = 100
a = np.ceil(np.random.rand(N) * 10) # entre 1 et 10
b = np.round(np.random.rand(N)) # 0 ou 1
np.where((a == 4) & (b==0), 1., 0.) OK
np.where( a == 4 & b==0 , 1., 0.) KO !!!
Autres syntaxes & cas d'étude
Comment afficher un jeu de données bi-classes en 2 couleurs???
Solution 1 (avec le np.where
classique)
import numpy as np
import numpy.random as rnd
import matplotlib.pyplot as plt
# génération des points de la classe 1 & 2
N=100
x = np.vstack((rnd.randn(N,2)+2,rnd.randn(N,2)-2)) # données 2D
y = np.ones(2*N) # étiquettes
y[:N] = -1
# comment afficher chaque classe d'une couleur???
# solution 1
ind1 = np.where(y==1)
ind2 = np.where(y==-1)
plt.figure()
plt.plot(x[ind1, 0],x[ind1, 1], 'b+') # aff en croix bleues
plt.plot(x[ind2, 0],x[ind2, 1], 'r*') # aff en étoiles rouges
Solution 2 (avec le np.where
implicite)
# comment afficher chaque classe d'une couleur???
# solution 2
plt.figure()
plt.plot(x[y==1, 0],x[y==1, 1], 'b+') # aff en croix bleues
plt.plot(x[y==-1, 0],x[y==-1, 1], 'r*') # aff en étoiles rouges
Solution 3, encore plus malin
Scatter permet la définition d'un style par point...
On utilise y
directement dans l'affichage.
# comment afficher chaque classe d'une couleur???
# solution 3
plt.figure()
plt.scatter(x[:,0], x[:,1], c=y)
Applications
Verification des propriétés de la loi normale
Générer un vecteur contenant 1000 éléments tirés selon la loi normale
- vérifier que la loi est centrée en 0 (à peu près autant d'éléments >0 et <0)
- vérifier que 2/3 des éléments sont entre {$-\sigma$} et {$-\sigma$}