"""
Module: CNN (Reseaux Convolutifs)
Categorie: Deep Learning
Difficulte: Intermediaire

Genere depuis la plateforme ML Formation
"""

# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score

# Charger le dataset
df = pd.read_csv('digits_simple.csv')

# Explorer les donnees (chiffres)
# Type: Code executable
print("=" * 70)
print("       EXPLORATION DU DATASET DE CHIFFRES MANUSCRITS")
print("=" * 70)

print("""
Les CNN sont nes pour traiter des IMAGES. Nous allons travailler avec
un dataset classique: des chiffres manuscrits (version simplifiee).

POURQUOI CE DATASET EST PARFAIT POUR APPRENDRE LES CNN:
  • Images petites (4x4 = 16 pixels) → rapide a traiter
  • 10 classes (0-9) → classification multiclasse
  • Patterns visuels distincts → facile a interpreter
""")

print("\n" + "=" * 70)
print("1. APERCU DU DATASET")
print("=" * 70)
display(df.head(10), title="Premieres lignes du dataset de chiffres")

print("""
STRUCTURE DES DONNEES:
  • Chaque ligne = une image de chiffre
  • Colonnes pixel_* = valeurs des pixels (0=noir, 15=blanc)
  • Colonne 'label' = le chiffre represente (0-9)
""")

print("\n" + "=" * 70)
print("2. DIMENSIONS")
print("=" * 70)

# Separer features et labels
X = df.drop('label', axis=1).values
y = df['label'].values

n_samples, n_pixels = X.shape
img_size = int(np.sqrt(n_pixels))

print(f"""
DIMENSIONS DU DATASET:
  • Nombre d'images: {n_samples}
  • Pixels par image: {n_pixels}
  • Taille d'une image: {img_size}x{img_size} pixels

NOTE: En production, les images sont plus grandes:
  • MNIST: 28x28 = 784 pixels
  • CIFAR-10: 32x32x3 = 3 072 valeurs
  • ImageNet: 224x224x3 = 150 528 valeurs!
""")

print("\n" + "=" * 70)
print("3. DISTRIBUTION DES CLASSES")
print("=" * 70)

print("""
Combien d'exemples avons-nous pour chaque chiffre?
""")

unique, counts = np.unique(y, return_counts=True)
print("-" * 45)
print(f"{'Chiffre':<10} {'Nb images':<12} {'Distribution'}")
print("-" * 45)

for digit, count in zip(unique, counts):
    pct = count / len(y) * 100
    bar = "█" * int(pct / 2)
    print(f"{digit:<10} {count:<12} {pct:5.1f}% {bar}")

print("-" * 45)
print(f"{'TOTAL':<10} {len(y):<12} 100.0%")

# Verifier l'equilibre
balance_ratio = max(counts) / min(counts)
print(f"""

ANALYSE DE L'EQUILIBRE:
  • Classe la plus frequente: {max(counts)} images
  • Classe la moins frequente: {min(counts)} images
  • Ratio: {balance_ratio:.2f}
""")

if balance_ratio < 1.5:
    print("  → Dataset bien equilibre! Toutes les classes sont representees.")
else:
    print("  → Desequilibre detecte. Certaines classes sont sous-representees.")

print("\n" + "=" * 70)
print("4. STATISTIQUES DES PIXELS")
print("=" * 70)

print(f"""
VALEURS DES PIXELS:
  • Minimum: {X.min():.0f} (noir complet)
  • Maximum: {X.max():.0f} (blanc complet)
  • Moyenne: {X.mean():.2f}
  • Ecart-type: {X.std():.2f}

INTERPRETATION:
  Les valeurs sont entre 0 (noir) et ~16 (blanc).
  Les pixels forment des MOTIFS que le CNN va apprendre a reconnaitre.
""")

print("\n" + "=" * 70)
print("              PRET POUR LA VISUALISATION DES CHIFFRES!")
print("=" * 70)


# Visualiser les chiffres
# Type: Code executable
print("=" * 70)
print("       VISUALISATION DES CHIFFRES MANUSCRITS")
print("=" * 70)

print("""
Visualisons les donnees pour comprendre ce que le CNN doit apprendre.
Chaque chiffre a des caracteristiques visuelles distinctes.
""")

# Separer features et labels
X = df.drop('label', axis=1).values
y = df['label'].values
img_size = int(np.sqrt(X.shape[1]))

print("\n" + "=" * 70)
print("1. UN EXEMPLE DE CHAQUE CHIFFRE (0-9)")
print("=" * 70)

fig, axes = plt.subplots(2, 5, figsize=(14, 6))

for i, ax in enumerate(axes.flat):
    # Trouver un exemple du chiffre i
    idx = np.where(y == i)[0][0]
    image = X[idx].reshape(img_size, img_size)

    ax.imshow(image, cmap='gray', interpolation='nearest')
    ax.set_title(f'Chiffre: {i}', fontsize=12, fontweight='bold')
    ax.axis('off')

    # Ajouter une bordure coloree
    for spine in ax.spines.values():
        spine.set_edgecolor('#9B7AC4')
        spine.set_linewidth(2)

plt.suptitle('Exemples de chiffres manuscrits (4x4 pixels)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("""
OBSERVATIONS:
  • Chaque chiffre a une forme caracteristique
  • Les pixels clairs (blancs) dessinent le chiffre
  • Les pixels sombres (noirs) forment le fond

CE QUE LE CNN DOIT APPRENDRE:
  Le CNN va detecter des MOTIFS (patterns) qui distinguent chaque chiffre:
  - Lignes verticales: 1, 4, 7
  - Boucles: 0, 6, 8, 9
  - Angles: 4, 7
  - Courbes: 2, 3, 5, 6, 9
""")

print("\n" + "=" * 70)
print("2. VARIABILITE D'UN MEME CHIFFRE")
print("=" * 70)

print("""
Un meme chiffre peut etre ecrit de differentes facons.
Le CNN doit etre ROBUSTE a ces variations!
""")

# Montrer plusieurs exemples du chiffre 3
target_digit = 3
indices = np.where(y == target_digit)[0][:8]

fig, axes = plt.subplots(1, 8, figsize=(14, 2.5))
fig.suptitle(f'8 facons differentes d\'ecrire le chiffre {target_digit}', fontsize=13, fontweight='bold')

for ax, idx in zip(axes, indices):
    image = X[idx].reshape(img_size, img_size)
    ax.imshow(image, cmap='gray', interpolation='nearest')
    ax.axis('off')

plt.tight_layout()
plt.show()

print("""
DEFI POUR LE CNN:
  Malgre ces differences, le CNN doit reconnaitre
  que tous ces exemples representent le meme chiffre!

C'est la que la CONVOLUTION aide:
  → Detecte les motifs LOCAUX (bords, courbes)
  → Invariante a la position (grace au pooling)
""")

print("\n" + "=" * 70)
print("3. VISUALISATION EN GRILLE DE PIXELS")
print("=" * 70)

print("""
Regardons un chiffre pixel par pixel pour comprendre
comment les valeurs numeriques forment l'image:
""")

# Prendre le premier exemple
image = X[0].reshape(img_size, img_size)
label = y[0]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Image avec valeurs
ax1 = axes[0]
im = ax1.imshow(image, cmap='gray', interpolation='nearest')
ax1.set_title(f'Chiffre {label} - Valeurs des pixels', fontsize=12, fontweight='bold')

# Afficher les valeurs
for i in range(img_size):
    for j in range(img_size):
        color = 'white' if image[i, j] < 8 else 'black'
        ax1.text(j, i, f'{image[i,j]:.0f}', ha='center', va='center',
                color=color, fontsize=10, fontweight='bold')
ax1.axis('off')

# Histogramme des valeurs
ax2 = axes[1]
ax2.hist(image.flatten(), bins=16, range=(0, 16), color='#9B7AC4', edgecolor='white', alpha=0.7)
ax2.set_xlabel('Valeur du pixel', fontsize=11)
ax2.set_ylabel('Frequence', fontsize=11)
ax2.set_title('Distribution des valeurs de pixels', fontsize=12, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"""
LECTURE DU GRAPHIQUE:
  • Gauche: l'image avec la valeur numerique de chaque pixel
  • Droite: histogramme montrant la repartition des valeurs

Le chiffre est "dessine" par les pixels de haute valeur (clairs).
""")

print("\n" + "=" * 70)
print("              VISUALISATION TERMINEE!")
print("=" * 70)


# Simuler une convolution
# Type: Code executable
from scipy.signal import convolve2d

print("=" * 70)
print("       SIMULATION D'UNE OPERATION DE CONVOLUTION")
print("=" * 70)

print("""
La CONVOLUTION est l'operation fondamentale des CNN.
Elle "glisse" un FILTRE (kernel) sur l'image pour detecter des motifs.

ANALOGIE:
  C'est comme chercher un motif dans l'image avec une loupe speciale.
  Chaque filtre est une "loupe" qui detecte un type de motif specifique.
""")

print("\n" + "=" * 70)
print("1. DEFINITION DES FILTRES")
print("=" * 70)

# Prendre une image
X = df.drop('label', axis=1).values
y = df['label'].values
img_size = int(np.sqrt(X.shape[1]))
image = X[0].reshape(img_size, img_size)

print(f"""
Image selectionnee: Chiffre {y[0]} ({img_size}x{img_size} pixels)

FILTRES DE DETECTION (3x3):
Ces filtres detectent des motifs specifiques dans l'image.
""")

# Definir des filtres simples
filters = {
    'Horizontal': np.array([[-1, -1, -1],
                            [ 0,  0,  0],
                            [ 1,  1,  1]]),
    'Vertical': np.array([[-1, 0, 1],
                          [-1, 0, 1],
                          [-1, 0, 1]]),
    'Diagonal': np.array([[ 0, -1, -1],
                          [ 1,  0, -1],
                          [ 1,  1,  0]])
}

print("-" * 50)
for name, kernel in filters.items():
    print(f"\nFiltre {name}:")
    for row in kernel:
        print("  ", row)
    if name == 'Horizontal':
        print("  → Detecte les bords HORIZONTAUX")
    elif name == 'Vertical':
        print("  → Detecte les bords VERTICAUX")
    else:
        print("  → Detecte les bords DIAGONAUX")
print("-" * 50)

print("\n" + "=" * 70)
print("2. APPLICATION DES FILTRES")
print("=" * 70)

print("""
Appliquons chaque filtre a l'image.
Le resultat s'appelle une FEATURE MAP.
""")

# Appliquer les filtres
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Premiere ligne: image originale et feature maps
axes[0, 0].imshow(image, cmap='gray', interpolation='nearest')
axes[0, 0].set_title('Image originale', fontsize=12, fontweight='bold')
axes[0, 0].axis('off')

feature_maps = {}
for ax, (name, kernel) in zip(axes[0, 1:], filters.items()):
    # Convolution
    output = convolve2d(image, kernel, mode='same')
    feature_maps[name] = output

    im = ax.imshow(output, cmap='RdBu', interpolation='nearest')
    ax.set_title(f'Filtre: {name}', fontsize=12, fontweight='bold')
    ax.axis('off')

# Deuxieme ligne: afficher les filtres
axes[1, 0].text(0.5, 0.5, 'FILTRES\n(Kernels)', ha='center', va='center',
                fontsize=14, fontweight='bold', transform=axes[1, 0].transAxes)
axes[1, 0].axis('off')

for ax, (name, kernel) in zip(axes[1, 1:], filters.items()):
    im = ax.imshow(kernel, cmap='RdBu', interpolation='nearest')
    ax.set_title(f'Kernel {name}', fontsize=11)
    # Afficher les valeurs
    for i in range(3):
        for j in range(3):
            color = 'white' if abs(kernel[i, j]) > 0.5 else 'black'
            ax.text(j, i, f'{kernel[i,j]:+d}', ha='center', va='center',
                   color=color, fontsize=11, fontweight='bold')
    ax.axis('off')

plt.suptitle('Convolution: Image × Filtre = Feature Map', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("""
INTERPRETATION DES FEATURE MAPS:
  • Rouge/Bleu = le filtre a detecte un motif (positive/negative)
  • Blanc = pas de motif detecte

COMMENT LIRE LES RESULTATS:
  • Filtre Horizontal: points rouges/bleus = bords horizontaux
  • Filtre Vertical: points rouges/bleus = bords verticaux
  • Filtre Diagonal: points rouges/bleus = bords diagonaux
""")

print("\n" + "=" * 70)
print("3. STATISTIQUES DES FEATURE MAPS")
print("=" * 70)

print("""
Analysons ce que chaque filtre a detecte:
""")
print("-" * 55)
print(f"{'Filtre':<12} {'Min':<10} {'Max':<10} {'Moyenne':<10} {'Std'}")
print("-" * 55)

for name, fmap in feature_maps.items():
    print(f"{name:<12} {fmap.min():<10.2f} {fmap.max():<10.2f} {fmap.mean():<10.2f} {fmap.std():.2f}")

print("-" * 55)

print("""

INTERPRETATION:
  • Un ecart-type ELEVE signifie que le filtre detecte beaucoup de motifs
  • Un ecart-type FAIBLE signifie peu de motifs de ce type dans l'image

DANS UN VRAI CNN:
  • Les filtres sont APPRIS automatiquement!
  • Le reseau optimise les filtres pour mieux classifier les images.
""")

print("\n" + "=" * 70)
print("              CONVOLUTION SIMULEE AVEC SUCCES!")
print("=" * 70)


# Max Pooling
# Type: Code executable
from skimage.measure import block_reduce

print("=" * 70)
print("       LE MAX POOLING: REDUCTION DE DIMENSION")
print("=" * 70)

print("""
Apres la convolution, on applique le POOLING pour:
  1. Reduire la taille de la feature map
  2. Rendre le modele robuste aux petites translations
  3. Reduire le nombre de parametres

TYPES DE POOLING:
  • MaxPool: garde le MAXIMUM de chaque region
  • AvgPool: garde la MOYENNE de chaque region

MaxPool est le plus courant car il preserve les features fortes.
""")

print("\n" + "=" * 70)
print("1. EXEMPLE SUR UNE IMAGE")
print("=" * 70)

# Prendre une image
X = df.drop('label', axis=1).values
y = df['label'].values
img_size = int(np.sqrt(X.shape[1]))
image = X[0].reshape(img_size, img_size)

print(f"""
Image originale: {img_size}x{img_size} = {img_size**2} pixels
Apres MaxPool 2x2: {img_size//2}x{img_size//2} = {(img_size//2)**2} pixels

REDUCTION: {img_size**2} → {(img_size//2)**2} pixels (-75%)
""")

# Simuler max pooling 2x2
pooled = block_reduce(image, (2, 2), np.max)

print("\n" + "=" * 70)
print("2. VISUALISATION AVANT/APRES")
print("=" * 70)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Image originale avec valeurs
im1 = axes[0].imshow(image, cmap='gray', interpolation='nearest')
axes[0].set_title(f'Avant: {image.shape[0]}x{image.shape[1]} pixels', fontsize=12, fontweight='bold')
for i in range(image.shape[0]):
    for j in range(image.shape[1]):
        color = 'white' if image[i,j] < 8 else 'black'
        axes[0].text(j, i, f'{image[i,j]:.0f}', ha='center', va='center',
                    color=color, fontsize=9, fontweight='bold')
axes[0].axis('off')

# Schema du pooling
axes[1].text(0.5, 0.7, 'MaxPool 2x2', ha='center', va='center',
             fontsize=16, fontweight='bold', transform=axes[1].transAxes)
axes[1].text(0.5, 0.5, '↓', ha='center', va='center',
             fontsize=40, transform=axes[1].transAxes)
axes[1].text(0.5, 0.3, 'Prend le MAX\nde chaque bloc 2x2', ha='center', va='center',
             fontsize=12, transform=axes[1].transAxes)
axes[1].axis('off')

# Image pooled avec valeurs
im2 = axes[2].imshow(pooled, cmap='gray', interpolation='nearest')
axes[2].set_title(f'Apres: {pooled.shape[0]}x{pooled.shape[1]} pixels', fontsize=12, fontweight='bold')
for i in range(pooled.shape[0]):
    for j in range(pooled.shape[1]):
        color = 'white' if pooled[i,j] < 8 else 'black'
        axes[2].text(j, i, f'{pooled[i,j]:.0f}', ha='center', va='center',
                    color=color, fontsize=14, fontweight='bold')
axes[2].axis('off')

plt.suptitle(f'Max Pooling - Chiffre {y[0]}', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("""
COMMENT CA MARCHE:
  1. Divise l'image en blocs de 2x2 pixels
  2. Pour chaque bloc, garde UNIQUEMENT le maximum
  3. Resultat: image 2x plus petite

EXEMPLE CONCRET:
""")

# Montrer un exemple de bloc
print("  Bloc 2x2 en haut-gauche de l'image:")
print(f"    [{image[0,0]:.0f}, {image[0,1]:.0f}]")
print(f"    [{image[1,0]:.0f}, {image[1,1]:.0f}]")
print(f"    → Maximum = {pooled[0,0]:.0f}")

print("\n" + "=" * 70)
print("3. POURQUOI LE MAX POOLING EST UTILE")
print("=" * 70)

print("""
AVANTAGES DU MAX POOLING:

1. REDUCTION DE DIMENSION:
   • Moins de pixels = moins de calculs
   • Accelere l'entrainement et l'inference

2. INVARIANCE AUX TRANSLATIONS:
   • Un motif decale de 1 pixel donne le meme resultat apres pooling
   • Le modele devient robuste aux petites variations de position

3. SELECTION DES FEATURES:
   • Garde les activations FORTES (features importantes)
   • Ignore les activations faibles (bruit)

ANALOGIE:
  C'est comme resumer un paragraphe en gardant uniquement
  les mots les plus importants!
""")

print("\n" + "=" * 70)
print("4. STATISTIQUES")
print("=" * 70)

print(f"""
AVANT POOLING:
  • Taille: {image.shape[0]}x{image.shape[1]} = {image.size} pixels
  • Somme: {image.sum():.0f}
  • Moyenne: {image.mean():.2f}

APRES POOLING:
  • Taille: {pooled.shape[0]}x{pooled.shape[1]} = {pooled.size} pixels
  • Somme: {pooled.sum():.0f}
  • Moyenne: {pooled.mean():.2f}

L'information essentielle est PRESERVEE malgre la reduction!
""")

print("\n" + "=" * 70)
print("              MAX POOLING EXPLIQUE!")
print("=" * 70)


# Classification avec features convolutives
# Type: Code executable
from scipy.signal import convolve2d
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

print("=" * 70)
print("       CLASSIFICATION AVEC FEATURES CONVOLUTIVES")
print("=" * 70)

print("""
Maintenant, utilisons la convolution pour AMELIORER la classification!

APPROCHE:
1. Appliquer des filtres de convolution aux images
2. Extraire des statistiques des feature maps
3. Utiliser ces features pour un classifieur standard

C'est une version simplifiee de ce que fait un CNN!
""")

print("\n" + "=" * 70)
print("1. PREPARATION DES DONNEES")
print("=" * 70)

X = df.drop('label', axis=1).values
y = df['label'].values
img_size = int(np.sqrt(X.shape[1]))

print(f"""
Dataset:
  • {X.shape[0]} images
  • {X.shape[1]} pixels par image ({img_size}x{img_size})
  • {len(np.unique(y))} classes (chiffres 0-9)
""")

print("\n" + "=" * 70)
print("2. DEFINITION DES FILTRES")
print("=" * 70)

filters = [
    np.array([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]),  # Vertical
    np.array([[-1, -1, -1], [0, 0, 0], [1, 1, 1]]),  # Horizontal
]

print("""
FILTRES UTILISES:
  • Filtre Vertical: detecte les bords gauche/droite
  • Filtre Horizontal: detecte les bords haut/bas

FEATURES EXTRAITES PAR FILTRE:
  • Moyenne de la feature map
  • Ecart-type de la feature map
  • Maximum de la feature map
  • Minimum de la feature map

Total: 2 filtres x 4 stats = 8 features convolutives
""")

print("\n" + "=" * 70)
print("3. EXTRACTION DES FEATURES CONVOLUTIVES")
print("=" * 70)

def extract_conv_features(images, filters, img_size):
    """Extrait des features statistiques des feature maps."""
    features = []
    for img in images:
        img_2d = img.reshape(img_size, img_size)
        img_features = []
        for filt in filters:
            conv = convolve2d(img_2d, filt, mode='same')
            # Statistiques de la feature map
            img_features.extend([conv.mean(), conv.std(), conv.max(), conv.min()])
        features.append(img_features)
    return np.array(features)

print("Extraction des features en cours...")
X_conv = extract_conv_features(X, filters, img_size)

print(f"""
TRANSFORMATION:
  • Entree: {X.shape[1]} pixels bruts
  • Sortie: {X_conv.shape[1]} features convolutives

EXEMPLE DE FEATURES (premiere image):
""")
feature_names = ['V_mean', 'V_std', 'V_max', 'V_min', 'H_mean', 'H_std', 'H_max', 'H_min']
print("-" * 50)
for name, val in zip(feature_names, X_conv[0]):
    print(f"  {name}: {val:8.3f}")
print("-" * 50)

print("\n" + "=" * 70)
print("4. ENTRAINEMENT DU CLASSIFIEUR")
print("=" * 70)

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X_conv, y, test_size=0.2, random_state=42)

print(f"""
DIVISION DES DONNEES:
  • Entrainement: {len(X_train)} images
  • Test: {len(X_test)} images
""")

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000, random_state=42)
clf.fit(X_train_scaled, y_train)

# Evaluation
from sklearn.metrics import accuracy_score, classification_report

y_pred = clf.predict(X_test_scaled)
acc = accuracy_score(y_test, y_pred)

print(f"""
RESULTATS:
═══════════════════════════════════
  Accuracy avec features convolutives: {acc:.1%}
═══════════════════════════════════
""")

print("\n" + "=" * 70)
print("5. DETAILS PAR CLASSE")
print("=" * 70)

print("""
Performance pour chaque chiffre:
""")

from sklearn.metrics import precision_recall_fscore_support
precision, recall, f1, support = precision_recall_fscore_support(y_test, y_pred, average=None)

print("-" * 60)
print(f"{'Chiffre':<10} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support'}")
print("-" * 60)

for i in range(10):
    if i < len(precision):
        print(f"{i:<10} {precision[i]:<12.3f} {recall[i]:<12.3f} {f1[i]:<12.3f} {support[i]}")

print("-" * 60)

print("""

INTERPRETATION:
  • Les features convolutives capturent les motifs des chiffres
  • Certains chiffres sont plus faciles a distinguer que d'autres
  • Un vrai CNN avec plus de filtres et des couches profondes
    ferait encore mieux!
""")

print("\n" + "=" * 70)
print("              CLASSIFICATION TERMINEE!")
print("=" * 70)


# Exercice: Comparer avec/sans convolution
# Type: Exercice
# Exercice: Comparez les performances avec et sans features convolutives

from scipy.signal import convolve2d
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

print("=" * 70)
print("       EXERCICE: IMPACT DES FEATURES CONVOLUTIVES")
print("=" * 70)

print("""
OBJECTIF:
Comparer la classification sur:
  1. Les pixels bruts (16 features)
  2. Les features convolutives extraites

QUESTIONS:
  • Quelle approche donne la meilleure accuracy?
  • Pourquoi les features convolutives peuvent aider?
""")

# Preparez les donnees
X = df.drop('label', axis=1).values
y = df['label'].values
img_size = int(np.sqrt(X.shape[1]))

# TODO: Entrainez un classifieur sur les pixels bruts
# TODO: Entrainez un classifieur sur les features convolutives
# TODO: Comparez les accuracies

print("""
VOTRE CODE ICI:
---------------
# Indice: utilisez extract_conv_features() pour les features convolutives
# Comparez les accuracy avec LogisticRegression
""")


# Quand utiliser un CNN?
# Type: Code executable
print("=" * 70)
print("       GUIDE: QUAND UTILISER UN CNN?")
print("=" * 70)

print("""
Les CNN sont specialises pour les donnees avec STRUCTURE SPATIALE.
Voici un guide pour savoir quand les utiliser.
""")

print("\n" + "=" * 70)
print("1. CAS D'USAGE DES CNN")
print("=" * 70)

use_cases = [
    ("Classification d'images", "Chats/chiens, chiffres, objets", "OUI", "★★★"),
    ("Detection d'objets", "Ou sont les voitures?", "OUI", "★★★"),
    ("Segmentation", "Delimiter les organes", "OUI", "★★★"),
    ("Donnees tabulaires", "Prix immobilier, churn", "NON", "☆☆☆"),
    ("Series temporelles", "Cours de bourse", "PEUT-ETRE", "★★☆"),
    ("Texte", "Classification de texte", "PEUT-ETRE", "★★☆"),
]

print("-" * 75)
print(f"{'Tache':<25} {'Exemple':<25} {'CNN?':<12} {'Efficacite'}")
print("-" * 75)

for task, example, cnn, eff in use_cases:
    if cnn == "OUI":
        indicator = "✓"
    elif cnn == "NON":
        indicator = "✗"
    else:
        indicator = "~"
    print(f"{task:<25} {example:<25} {indicator} {cnn:<8} {eff}")

print("-" * 75)

print("""

RESUME:
  ✓ OUI: Donnees avec structure spatiale 2D (images)
  ✗ NON: Donnees tabulaires sans structure spatiale
  ~ PEUT-ETRE: Donnees sequentielles (Conv1D peut aider)
""")

print("\n" + "=" * 70)
print("2. COMPARAISON AVEC D'AUTRES MODELES")
print("=" * 70)

print("""
┌────────────────────┬──────────────────────────────────────────────┐
│ CNN                │ Images, donnees spatiales                    │
├────────────────────┼──────────────────────────────────────────────┤
│ RNN/LSTM           │ Sequences, texte, series temporelles         │
├────────────────────┼──────────────────────────────────────────────┤
│ Transformer        │ NLP, traduction, generation de texte         │
├────────────────────┼──────────────────────────────────────────────┤
│ MLP (Dense)        │ Donnees tabulaires sans structure            │
├────────────────────┼──────────────────────────────────────────────┤
│ Random Forest      │ Donnees tabulaires, interpretabilite         │
└────────────────────┴──────────────────────────────────────────────┘
""")

print("\n" + "=" * 70)
print("3. FRAMEWORKS POUR IMPLEMENTER DES CNN")
print("=" * 70)

frameworks = [
    ("PyTorch", "Flexible, recherche", "facebook.com/pytorch"),
    ("TensorFlow/Keras", "Production, simplicite", "tensorflow.org"),
    ("FastAI", "Haut niveau, apprentissage rapide", "fast.ai"),
    ("JAX", "Performance, recherche avancee", "jax.readthedocs.io"),
]

print("-" * 70)
print(f"{'Framework':<20} {'Points forts':<30} {'Site'}")
print("-" * 70)
for name, strength, site in frameworks:
    print(f"{name:<20} {strength:<30} {site}")
print("-" * 70)

print("\n" + "=" * 70)
print("4. ARCHITECTURES CNN CELEBRES")
print("=" * 70)

architectures = [
    ("LeNet-5", 1998, "60K", "Chiffres manuscrits"),
    ("AlexNet", 2012, "60M", "ImageNet, debut deep learning"),
    ("VGG-16", 2014, "138M", "Filtres 3x3, profondeur"),
    ("ResNet-50", 2015, "25M", "Connexions residuelles"),
    ("EfficientNet", 2019, "5-66M", "Scaling optimal"),
]

print("-" * 70)
print(f"{'Architecture':<15} {'Annee':<8} {'Params':<10} {'Innovation'}")
print("-" * 70)
for name, year, params, innovation in architectures:
    print(f"{name:<15} {year:<8} {params:<10} {innovation}")
print("-" * 70)

print("""

EVOLUTION:
  • LeNet → premier CNN fonctionnel
  • AlexNet → victoire ImageNet, GPU
  • VGG → plus profond, filtres simples
  • ResNet → connexions residuelles, tres profond
  • EfficientNet → optimisation de l'architecture
""")

print("\n" + "=" * 70)
print("5. RESSOURCES POUR APPROFONDIR")
print("=" * 70)

print("""
COURS GRATUITS:
  • CS231n Stanford - CNN pour la vision (YouTube)
  • fast.ai - Deep Learning pratique
  • Coursera - Deep Learning Specialization (Andrew Ng)

PRATIQUE:
  • Kaggle - Competitions de vision
  • Papers With Code - Implementations open source
  • GitHub - Projets exemples

LIVRES:
  • Deep Learning (Goodfellow, Bengio, Courville) - gratuit en ligne
  • Dive into Deep Learning - interactif
""")

print("\n" + "=" * 70)
print("              MODULE CNN TERMINE!")
print("=" * 70)
print("""
CE QUE VOUS AVEZ APPRIS:
  ✓ L'operation de convolution et son role
  ✓ Le max pooling pour reduire les dimensions
  ✓ Comment extraire des features convolutives
  ✓ Quand utiliser un CNN vs autres modeles

PROCHAINES ETAPES:
  → Implementer un CNN complet avec PyTorch/TensorFlow
  → Utiliser le transfer learning (modeles pre-entraines)
  → Explorer les techniques d'explicabilite (Grad-CAM)
""")

