Cours

Subpage of Semaine 8 — TME 8: Markov Field

##################################################################################
# Fonctions d'affichage

def drawImageTreatement(index, data):
    """Plot the processing stages and the region graph of image ``index``.

    Opens two figures:

    1. A 2x2 grid with the raw image, the per-pixel labelling, the
       segmentation boundaries overlaid on the image, and an image where
       every pixel is painted with the label of its segment.
    2. The segmentation boundaries with one scatter marker per segment
       (from ``data['coord']``) and the edges of the region graph
       (from ``data['graph']``).

    Parameters
    ----------
    index : int
        Position of the image in each per-image sequence of ``data``.
    data : dict
        Must provide the per-image sequences ``'im'``, ``'lab'``, ``'seg'``,
        ``'segLab'``, ``'coord'`` and ``'graph'``, all indexed by ``index``.
    """
    image = data['im'][index]
    seg = data['seg'][index]
    _draw_processing_panels(image, data['lab'][index], seg, data['segLab'][index])
    _draw_region_graph(image, seg, data['coord'][index], data['graph'][index])


def _segment_label_image(seg, seg_labels):
    """Return an array shaped like ``seg`` holding each pixel's segment label.

    Segment ids are ranked in increasing order (as ``np.unique`` sorts them)
    and the i-th smallest id receives ``seg_labels[i]``. The result keeps
    ``seg``'s dtype, matching an in-place fill of a copy of ``seg``.
    """
    # rank[p] is the position of pixel p's segment id among the sorted unique ids.
    _, rank = np.unique(seg, return_inverse=True)
    # reshape: NumPy 1.x returns a flat inverse, NumPy 2.x one shaped like seg.
    painted = np.asarray(seg_labels)[rank].reshape(seg.shape)
    return painted.astype(seg.dtype, copy=False)


def _draw_processing_panels(image, pixel_labels, seg, seg_labels):
    """Draw the 2x2 figure: image, pixel labels, segment boundaries, segment labels."""
    fig, ax = plt.subplots(2, 2)
    ax[0, 0].imshow(image, interpolation='nearest')
    ax[0, 0].set_title('image de base')
    ax[0, 1].imshow(pixel_labels, interpolation='nearest')
    ax[0, 1].set_title('etiquetage pixel')
    ax[1, 0].imshow(mark_boundaries(image, seg), interpolation='nearest')
    ax[1, 0].set_title('segmentation 16x16')
    ax[1, 1].imshow(_segment_label_image(seg, seg_labels), interpolation='nearest')
    ax[1, 1].set_title('etiquetage segment')


def _draw_region_graph(image, seg, coord, graph):
    """Draw segment markers and region-graph edges over the segment boundaries."""
    colors = ['blue', 'red', 'green', 'black']
    fig, ax = plt.subplots()
    ax.imshow(mark_boundaries(image, seg), interpolation='nearest')
    # NOTE(review): both axes are scaled by the image height (len(image));
    # correct only if coords share one normalisation factor or the image is
    # square -- confirm against how 'coord' is built.
    scale = len(image)
    ax.scatter(coord[:, 2] * scale, coord[:, 0] * scale)
    for i, neighbours in enumerate(graph):
        for c, j in enumerate(neighbours):
            # Only neighbour slots 0 and 2 are drawn (presumably so each edge
            # is drawn once rather than from both endpoints -- TODO confirm);
            # -1 entries have no neighbour to link to.
            if c in (1, 3) or j == -1:
                continue
            ax.add_line(lines.Line2D(coord[[i, j], 2] * scale,
                                     coord[[i, j], 0] * scale,
                                     linewidth=1, color=colors[c]))
    ax.set_title('graphe de l\'image')

def introspectionModel(classif, data, Y, filename=None):
    """Display the weight matrix of a linear classifier, one row per class.

    Draws ``classif.coef_`` as an image in a new figure. Rows are labelled
    with ``data['labMeaning'][y - 1]`` for each distinct label ``y`` of ``Y``
    (so class labels are expected to start at 1). Columns are labelled
    ``'<colour> (<strength>)'`` for red/green/blue x low/med/high, and dashed
    vertical lines separate the three colour groups.

    Parameters
    ----------
    classif : object
        Fitted model exposing a 2-D ``coef_`` attribute.
    data : dict
        Must provide ``'labMeaning'``, a sequence of class names.
    Y : array-like
        Class labels; their sorted unique values name the rows.
    filename : str or None
        If given, the figure is also saved to this path.
    """
    plt.figure()
    plt.imshow(classif.coef_, interpolation='nearest')
    # NOTE(review): coef_ rows are matched to np.unique(Y) by position; this
    # assumes classif was fitted on these labels (a binary sklearn model has a
    # single coef_ row, which would not line up) -- confirm.
    class_names = [data['labMeaning'][y - 1] for y in np.unique(Y)]
    plt.yticks(range(len(class_names)), class_names)
    features = [color + ' ' + strength
                for color in ['red', 'green', 'blue']
                for strength in ['(low)', '(med)', '(high)']]
    # Ticks on the cell centres; right-aligned rotated labels end under their
    # own column (a -0.5 offset would put the tick marks on cell edges).
    plt.xticks(np.arange(len(features)), features, rotation=45, ha='right')
    # Cell k spans [k - 0.5, k + 0.5], so 2.5 and 5.5 fall between colour groups.
    plt.vlines(np.array([2.5, 5.5]), -1, len(class_names), linestyles='dashed')
    if filename is not None:
        plt.savefig(filename)

def plotTransMatrix(A, labStr):
    """Show the four directional slices of ``A`` in a 2x2 grid.

    Slice ``A[:, :, k]`` for k = 0..3 is drawn in row-major panel order and
    titled 'g', 'd', 'h', 'b' respectively. Both axes of every panel are
    ticked with ``labStr``, and the x tick labels are rotated by 90 degrees.

    Parameters
    ----------
    A : ndarray
        Array with at least 4 slices on its last axis; presumably of shape
        (n_labels, n_labels, 4) -- TODO confirm.
    labStr : sequence of str
        Tick labels, one per row/column of each slice.
    """
    positions = np.arange(len(labStr))
    fig, axes = plt.subplots(2, 2)
    # axes.flat walks (0,0), (0,1), (1,0), (1,1): the same order as slice k.
    for k, (panel, name) in enumerate(zip(axes.flat, ['g', 'd', 'h', 'b'])):
        panel.imshow(A[:, :, k], interpolation='nearest')
        panel.set_title(name)
        plt.setp(panel,
                 xticks=positions, xticklabels=labStr,
                 yticks=positions, yticklabels=labStr)
        for label in panel.get_xticklabels():
            label.set_rotation(90)