"""
Module: Regression Logistique
Categorie: Supervised Classification
Difficulte: Debutant

Genere depuis la plateforme ML Formation
"""

# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score

# Charger le dataset
df = pd.read_csv('binary_classification.csv')

# Explorer les donnees
# Type: Code executable
# =============================================================================
# ETAPE 1 : EXPLORATION DU DATASET DE CLASSIFICATION
# =============================================================================
# En classification, l'exploration des donnees est cruciale pour comprendre :
# 1) La distribution des classes (equilibre ou desequilibre ?)
# 2) La separabilite des classes (les features permettent-elles de distinguer ?)
# 3) Les caracteristiques de chaque classe

print("=" * 70)
print("EXPLORATION DU DATASET DE CLASSIFICATION BINAIRE")
print("=" * 70)
print()

# --- 1.1 Apercu des donnees ---
print("1. APERCU DES DONNEES")
print("-" * 40)
print("Chaque ligne represente un echantillon avec 2 caracteristiques (features)")
print("et une etiquette (label) indiquant sa classe (0 ou 1).")
print()
display(df.head(10), title="Apercu du dataset de classification")

# --- 1.2 Dimensions ---
n_samples, n_cols = df.shape
print()
print("2. DIMENSIONS DU DATASET")
print("-" * 40)
print(f"   Nombre d'echantillons : {n_samples}")
print(f"   Nombre de colonnes    : {n_cols} (2 features + 1 label)")
print()

# --- 1.3 Distribution des classes ---
print("3. DISTRIBUTION DES CLASSES")
print("-" * 40)
class_counts = df['label'].value_counts().sort_index()
total = len(df)

print("   Repartition des echantillons par classe :")
print()
for classe, count in class_counts.items():
    pct = count / total * 100
    bar = "█" * int(pct / 2)
    print(f"   Classe {classe} : {count:4d} echantillons ({pct:5.1f}%) {bar}")

print()

# Analyse de l'equilibre
ratio = class_counts.min() / class_counts.max()
if ratio > 0.8:
    equilibre = "EQUILIBRE"
    emoji = "✓"
    conseil = "Excellent ! Les metriques standard (accuracy) seront fiables."
elif ratio > 0.5:
    equilibre = "LEGEREMENT DESEQUILIBRE"
    emoji = "~"
    conseil = "Acceptable, mais surveillez precision et recall par classe."
else:
    equilibre = "DESEQUILIBRE"
    emoji = "⚠"
    conseil = "Attention ! L'accuracy peut etre trompeuse. Utilisez F1-score."

print(f"   DIAGNOSTIC : {equilibre} {emoji}")
print(f"   Ratio minoritaire/majoritaire : {ratio:.2f}")
print(f"   → {conseil}")
print()

# --- 1.4 Statistiques descriptives ---
print("4. STATISTIQUES DESCRIPTIVES")
print("-" * 40)
print("Voici les statistiques cles de chaque colonne :")
print()
display(df.describe().round(3), title="Statistiques descriptives")

# --- 1.5 Analyse de separabilite ---
print()
print("5. ANALYSE DE SEPARABILITE")
print("-" * 40)
print("   La separabilite mesure si les classes sont distinguables par les features.")
print()

for feature in ['feature1', 'feature2']:
    mean_0 = df[df['label'] == 0][feature].mean()
    mean_1 = df[df['label'] == 1][feature].mean()
    std_0 = df[df['label'] == 0][feature].std()
    std_1 = df[df['label'] == 1][feature].std()

    # Calcul d'une metrique de separabilite simplifiee
    pooled_std = np.sqrt((std_0**2 + std_1**2) / 2)
    if pooled_std > 0:
        separabilite = abs(mean_1 - mean_0) / pooled_std
    else:
        separabilite = 0

    print(f"   {feature}:")
    print(f"      Classe 0 : moyenne = {mean_0:.3f}, std = {std_0:.3f}")
    print(f"      Classe 1 : moyenne = {mean_1:.3f}, std = {std_1:.3f}")
    print(f"      Difference des moyennes : {abs(mean_1 - mean_0):.3f}")

    if separabilite > 1.5:
        verdict = "BONNE separabilite"
    elif separabilite > 0.8:
        verdict = "Separabilite MODEREE"
    else:
        verdict = "Separabilite FAIBLE"

    print(f"      → {verdict} (score: {separabilite:.2f})")
    print()

# --- 1.6 Conclusion ---
print("=" * 70)
print("CONCLUSION DE L'EXPLORATION")
print("=" * 70)
print()
print("   Ce que nous avons appris :")
print("   → Les deux classes sont-elles equilibrees ?")
print("   → Les features permettent-elles de distinguer les classes ?")
print("   → Y a-t-il des patterns visibles dans les statistiques ?")
print()
print("   Prochaine etape : Visualiser les donnees pour confirmer ces observations.")


# Visualiser les classes
# Type: Code executable
# =============================================================================
# ETAPE 2 : VISUALISATION DES CLASSES
# =============================================================================
# La visualisation est essentielle en classification pour :
# 1) Voir si les classes sont naturellement separables
# 2) Identifier la forme de la frontiere de decision necessaire
# 3) Detecter des outliers ou des chevauchements problematiques

print("=" * 70)
print("VISUALISATION DES DEUX CLASSES")
print("=" * 70)
print()
print("Question cle : Les deux classes sont-elles visuellement separables ?")
print("→ Si oui, la regression logistique (frontiere lineaire) conviendra.")
print("→ Si les classes sont melangees, un modele plus complexe sera necessaire.")
print()

# Separation des classes
class_0 = df[df['label'] == 0]
class_1 = df[df['label'] == 1]

# Creation du graphique
plt.figure(figsize=(10, 7))

plt.scatter(class_0['feature1'], class_0['feature2'],
            alpha=0.6, color='#9B7AC4', label='Classe 0', s=60,
            edgecolors='white', linewidth=0.5)
plt.scatter(class_1['feature1'], class_1['feature2'],
            alpha=0.6, color='#F7E64D', label='Classe 1', s=60,
            edgecolors='white', linewidth=0.5)

plt.xlabel('Feature 1', fontsize=12)
plt.ylabel('Feature 2', fontsize=12)
plt.title('Distribution des deux classes dans l\'espace des features', fontsize=14)
plt.legend(loc='best', fontsize=11)
plt.grid(True, alpha=0.3)

# Ajouter les centroïdes
c0_center = (class_0['feature1'].mean(), class_0['feature2'].mean())
c1_center = (class_1['feature1'].mean(), class_1['feature2'].mean())
plt.plot(*c0_center, 'o', color='#9B7AC4', markersize=15, markeredgecolor='black', markeredgewidth=2)
plt.plot(*c1_center, 'o', color='#F7E64D', markersize=15, markeredgecolor='black', markeredgewidth=2)

plt.tight_layout()
plt.show()

# --- Analyse visuelle ---
print()
print("ANALYSE VISUELLE")
print("-" * 40)
print()
print("   LEGENDE :")
print("   → Points violets : Classe 0")
print("   → Points jaunes  : Classe 1")
print("   → Grands cercles : Centres (moyennes) de chaque classe")
print()

# Calcul de la distance entre centroïdes
distance = np.sqrt((c1_center[0] - c0_center[0])**2 + (c1_center[1] - c0_center[1])**2)
print(f"   DISTANCE ENTRE LES CENTRES : {distance:.2f}")
print()

if distance > 2:
    print("   OBSERVATION : Les classes sont BIEN SEPAREES")
    print("   → Une frontiere lineaire devrait bien fonctionner.")
    print("   → La regression logistique est appropriee.")
elif distance > 1:
    print("   OBSERVATION : Les classes ont un CHEVAUCHEMENT PARTIEL")
    print("   → Quelques erreurs de classification sont attendues.")
    print("   → La regression logistique reste un bon choix de base.")
else:
    print("   OBSERVATION : Les classes sont FORTEMENT MELANGEES")
    print("   → La classification sera difficile.")
    print("   → Envisagez des modeles non-lineaires (SVM, arbres).")

print()
print("   QUESTION A SE POSER :")
print("   'Puis-je tracer une droite qui separe globalement les deux couleurs ?'")
print("   → Si oui, la regression logistique conviendra !")


# Entrainer le classificateur
# Type: Code executable
# =============================================================================
# ETAPE 3 : ENTRAINEMENT DU CLASSIFICATEUR
# =============================================================================
# La regression logistique apprend une frontiere de decision lineaire
# qui separe les deux classes dans l'espace des features.

print("=" * 70)
print("ENTRAINEMENT DE LA REGRESSION LOGISTIQUE")
print("=" * 70)
print()

# --- 3.1 Preparation des donnees ---
print("1. PREPARATION DES DONNEES")
print("-" * 40)

X = df[['feature1', 'feature2']].values
y = df['label'].values

print(f"   Features (X) : {X.shape[0]} echantillons x {X.shape[1]} features")
print(f"   Labels (y)   : {len(y)} etiquettes (0 ou 1)")
print()

# --- 3.2 Division train/test ---
print("2. DIVISION TRAIN/TEST")
print("-" * 40)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"   Ensemble TRAIN : {len(X_train)} echantillons (80%)")
print(f"   Ensemble TEST  : {len(X_test)} echantillons (20%)")
print()
print("   Note : 'stratify=y' assure la meme proportion de classes")
print("          dans train et test (important si classes desequilibrees).")
print()

# Verification de la stratification
train_ratio = y_train.mean()
test_ratio = y_test.mean()
print(f"   Verification stratification :")
print(f"   → % classe 1 dans train : {train_ratio:.1%}")
print(f"   → % classe 1 dans test  : {test_ratio:.1%}")
print()

# --- 3.3 Entrainement ---
print("3. ENTRAINEMENT DU MODELE")
print("-" * 40)
print("   Le modele cherche les coefficients qui maximisent")
print("   la vraisemblance des observations...")
print()

model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)

print("   ✓ Modele entraine avec succes !")
print()

# --- 3.4 Parametres appris ---
print("4. PARAMETRES APPRIS")
print("-" * 40)

coef1, coef2 = model.coef_[0]
intercept = model.intercept_[0]

print(f"   Equation de la frontiere de decision :")
print(f"   z = {coef1:.4f} × feature1 + {coef2:.4f} × feature2 + ({intercept:.4f})")
print()
print("   INTERPRETATION DES COEFFICIENTS :")
print()

print(f"   • Coefficient feature1 : {coef1:+.4f}")
if coef1 > 0:
    print(f"     → Une augmentation de feature1 AUGMENTE la probabilite de classe 1")
else:
    print(f"     → Une augmentation de feature1 DIMINUE la probabilite de classe 1")
print(f"     → Impact : exp({coef1:.4f}) = {np.exp(coef1):.3f}x sur les odds par unite")
print()

print(f"   • Coefficient feature2 : {coef2:+.4f}")
if coef2 > 0:
    print(f"     → Une augmentation de feature2 AUGMENTE la probabilite de classe 1")
else:
    print(f"     → Une augmentation de feature2 DIMINUE la probabilite de classe 1")
print(f"     → Impact : exp({coef2:.4f}) = {np.exp(coef2):.3f}x sur les odds par unite")
print()

print(f"   • Intercept (biais) : {intercept:+.4f}")
print(f"     → C'est le score de base quand feature1 = feature2 = 0")
print()

# --- 3.5 Test rapide ---
print("5. TEST RAPIDE DE PREDICTION")
print("-" * 40)

sample = X_test[0]
pred_class = model.predict([sample])[0]
pred_proba = model.predict_proba([sample])[0]

print(f"   Echantillon : feature1={sample[0]:.3f}, feature2={sample[1]:.3f}")
print(f"   Prediction  : Classe {pred_class}")
print(f"   Probabilites: P(classe 0)={pred_proba[0]:.1%}, P(classe 1)={pred_proba[1]:.1%}")
print()
print("   Le modele ne predit pas juste une classe, mais une PROBABILITE.")
print("   C'est tres utile pour quantifier l'incertitude !")

print()
print("=" * 70)
print("MODELE ENTRAINE ! Passons a l'evaluation.")
print("=" * 70)


# Matrice de confusion
# Type: Code executable
# =============================================================================
# ETAPE 4 : MATRICE DE CONFUSION
# =============================================================================
# La matrice de confusion est L'OUTIL CENTRAL pour evaluer un classificateur.
# Elle montre EXACTEMENT ou le modele se trompe.

print("=" * 70)
print("MATRICE DE CONFUSION : COMPRENDRE LES ERREURS")
print("=" * 70)
print()

# Preparation et entrainement
X = df[['feature1', 'feature2']].values
y = df['label'].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)

# Predictions
y_pred = model.predict(X_test)

# --- 4.1 Concept ---
print("1. COMPRENDRE LA MATRICE DE CONFUSION")
print("-" * 40)
print()
print("   La matrice compare les predictions aux vraies etiquettes :")
print()
print("                    PREDICTION")
print("                  Classe 0  Classe 1")
print("            ┌──────────┬──────────┐")
print("   REALITE  │    TN    │    FP    │  Classe 0")
print("            │          │          │")
print("            ├──────────┼──────────┤")
print("            │    FN    │    TP    │  Classe 1")
print("            └──────────┴──────────┘")
print()
print("   Vocabulaire :")
print("   • TN (True Negative)  : Correctement predit Classe 0")
print("   • TP (True Positive)  : Correctement predit Classe 1")
print("   • FP (False Positive) : Predit 1 alors que c'etait 0 (FAUSSE ALERTE)")
print("   • FN (False Negative) : Predit 0 alors que c'etait 1 (RATE)")
print()

# --- 4.2 Notre matrice ---
print("2. NOTRE MATRICE DE CONFUSION")
print("-" * 40)

cm = confusion_matrix(y_test, y_pred)
tn, fp, fn, tp = cm.ravel()

print()
print(f"                    PREDICTION")
print(f"                  Classe 0  Classe 1")
print(f"            ┌──────────┬──────────┐")
print(f"   REALITE  │   {tn:3d}    │   {fp:3d}    │  Classe 0")
print(f"            │   (TN)   │   (FP)   │")
print(f"            ├──────────┼──────────┤")
print(f"            │   {fn:3d}    │   {tp:3d}    │  Classe 1")
print(f"            │   (FN)   │   (TP)   │")
print(f"            └──────────┴──────────┘")
print()

# --- 4.3 Interpretation detaillee ---
print("3. INTERPRETATION DETAILLEE")
print("-" * 40)
print()

total = len(y_test)
corrects = tn + tp
erreurs = fp + fn

print(f"   PREDICTIONS CORRECTES : {corrects}/{total} ({corrects/total:.1%})")
print(f"   → Vrais Negatifs (TN) : {tn} - Le modele a correctement identifie {tn} classe 0")
print(f"   → Vrais Positifs (TP) : {tp} - Le modele a correctement identifie {tp} classe 1")
print()

print(f"   ERREURS : {erreurs}/{total} ({erreurs/total:.1%})")
print(f"   → Faux Positifs (FP) : {fp} - FAUSSES ALERTES")
print(f"     Le modele a predit 'Classe 1' alors que c'etait 'Classe 0'")
if fp > 0:
    print(f"     Impact : {fp} fois on a cru detecter quelque chose qui n'existait pas")
print()
print(f"   → Faux Negatifs (FN) : {fn} - CAS RATES")
print(f"     Le modele a predit 'Classe 0' alors que c'etait 'Classe 1'")
if fn > 0:
    print(f"     Impact : {fn} fois on a rate une detection importante")
print()

# --- 4.4 Contexte metier ---
print("4. IMPORTANCE SELON LE CONTEXTE METIER")
print("-" * 40)
print()
print("   Le 'cout' des erreurs depend du domaine d'application :")
print()
print("   DETECTION DE MALADIE :")
print("   → FN (rate) = CATASTROPHIQUE : patient malade non detecte")
print("   → FP (fausse alerte) = Acceptable : examen supplementaire")
print("   → Priorite : MINIMISER les FN (maximiser le Recall)")
print()
print("   DETECTION DE SPAM :")
print("   → FP (fausse alerte) = GENANT : email important en spam")
print("   → FN (rate) = Acceptable : spam dans la boite de reception")
print("   → Priorite : MINIMISER les FP (maximiser la Precision)")
print()
print("   NOTRE MODELE :")

if fn > fp:
    print(f"   → Plus de FN ({fn}) que de FP ({fp})")
    print("   → Le modele est 'prudent' : il sous-detecte la classe 1")
elif fp > fn:
    print(f"   → Plus de FP ({fp}) que de FN ({fn})")
    print("   → Le modele est 'alarmiste' : il sur-detecte la classe 1")
else:
    print(f"   → FP ({fp}) = FN ({fn})")
    print("   → Le modele est equilibre dans ses erreurs")

print()
print("=" * 70)
print("La matrice de confusion : votre meilleur ami pour le diagnostic !")
print("=" * 70)


# Metriques de classification
# Type: Code executable
# =============================================================================
# ETAPE 5 : METRIQUES DE CLASSIFICATION
# =============================================================================
# Les metriques quantifient la performance du modele sous differents angles.
# Chaque metrique repond a une question specifique.

print("=" * 70)
print("METRIQUES DE CLASSIFICATION : EVALUER LA PERFORMANCE")
print("=" * 70)
print()

# Preparation, entrainement et prediction
X = df[['feature1', 'feature2']].values
y = df['label'].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Matrice de confusion pour reference
cm = confusion_matrix(y_test, y_pred)
tn, fp, fn, tp = cm.ravel()

# --- 5.1 Calcul des metriques ---
print("1. CALCUL DES METRIQUES")
print("-" * 40)
print()

acc = accuracy_score(y_test, y_pred)
prec = precision_score(y_test, y_pred)
rec = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

print(f"   Accuracy  : {acc:.4f} ({acc:.1%})")
print(f"   Precision : {prec:.4f} ({prec:.1%})")
print(f"   Recall    : {rec:.4f} ({rec:.1%})")
print(f"   F1-Score  : {f1:.4f} ({f1:.1%})")
print()

# --- 5.2 Interpretation de chaque metrique ---
print("2. INTERPRETATION DETAILLEE")
print("-" * 40)
print()

print(f"   ACCURACY = {acc:.1%}")
print(f"   Formule : (TP + TN) / Total = ({tp} + {tn}) / {len(y_test)}")
print(f"   Question : 'Quel pourcentage de predictions est correct ?'")
print(f"   → {acc:.1%} des predictions sont correctes.")
print()
print("   ⚠ ATTENTION : L'accuracy peut etre trompeuse avec des classes desequilibrees !")
print("   Exemple : 95% de classe 0 → Un modele qui predit toujours 0 a 95% d'accuracy")
print()

print(f"   PRECISION = {prec:.1%}")
print(f"   Formule : TP / (TP + FP) = {tp} / ({tp} + {fp})")
print(f"   Question : 'Parmi les predictions positives, combien sont correctes ?'")
print(f"   → Quand le modele dit 'Classe 1', il a raison {prec:.1%} du temps.")
print(f"   → Une precision de {prec:.1%} signifie {(1-prec)*100:.1f}% de fausses alertes.")
print()

print(f"   RECALL (Sensibilite) = {rec:.1%}")
print(f"   Formule : TP / (TP + FN) = {tp} / ({tp} + {fn})")
print(f"   Question : 'Parmi les vrais positifs, combien sont detectes ?'")
print(f"   → Le modele detecte {rec:.1%} des vrais cas positifs.")
print(f"   → Un recall de {rec:.1%} signifie que {(1-rec)*100:.1f}% des positifs sont rates.")
print()

print(f"   F1-SCORE = {f1:.1%}")
print(f"   Formule : 2 × (Precision × Recall) / (Precision + Recall)")
print(f"   Question : 'Quelle est la performance equilibree ?'")
print(f"   → Le F1 combine precision et recall en une seule metrique.")
print(f"   → Un F1 de {f1:.1%} indique un bon equilibre entre les deux.")
print()

# --- 5.3 Guide d'interpretation ---
print("3. GUIDE D'INTERPRETATION")
print("-" * 40)
print()
print("   ECHELLE DE QUALITE :")
print("   > 0.90 : Excellent")
print("   > 0.80 : Bon")
print("   > 0.70 : Acceptable")
print("   > 0.60 : Mediocre")
print("   < 0.60 : Insuffisant")
print()

# Verdict
print("   VERDICT POUR NOTRE MODELE :")
if f1 > 0.85:
    print(f"   → F1 = {f1:.1%} : EXCELLENT ! Le modele est tres performant.")
elif f1 > 0.75:
    print(f"   → F1 = {f1:.1%} : BON. Le modele est utilisable en production.")
elif f1 > 0.65:
    print(f"   → F1 = {f1:.1%} : ACCEPTABLE. Peut etre ameliore.")
else:
    print(f"   → F1 = {f1:.1%} : A AMELIORER. Envisagez plus de features ou autre modele.")
print()

# --- 5.4 Quelle metrique choisir ? ---
print("4. QUELLE METRIQUE PRIVILEGIER ?")
print("-" * 40)
print()
print("   Cela depend du COUT DES ERREURS dans votre domaine :")
print()
print("   | Domaine              | Priorite      | Metrique cle |")
print("   |----------------------|---------------|--------------|")
print("   | Diagnostic medical   | Ne rien rater | RECALL       |")
print("   | Detection spam       | Pas de faux + | PRECISION    |")
print("   | Classification gen.  | Equilibre     | F1-SCORE     |")
print("   | Classes equilibrees  | Global        | ACCURACY     |")
print()

print("=" * 70)
print("Choisissez la metrique adaptee a votre probleme metier !")
print("=" * 70)


# Visualiser la frontiere de decision
# Type: Code executable
# =============================================================================
# ETAPE 6 : VISUALISATION DE LA FRONTIERE DE DECISION
# =============================================================================
# La frontiere de decision montre comment le modele "decoupe" l'espace
# des features pour separer les classes.

print("=" * 70)
print("FRONTIERE DE DECISION : COMMENT LE MODELE SEPARE LES CLASSES")
print("=" * 70)
print()

# Preparation et entrainement
X = df[['feature1', 'feature2']].values
y = df['label'].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)

print("Ce graphique montre :")
print("→ La zone coloree : Les regions de decision du modele")
print("→ Les points : Les donnees de test (vraies etiquettes)")
print("→ La frontiere : La ligne ou P(classe 1) = 50%")
print()

# Creer une grille pour visualiser la decision
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
                     np.linspace(y_min, y_max, 200))

# Predictions sur la grille
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# Visualisation
plt.figure(figsize=(12, 8))

# Zones de decision
plt.contourf(xx, yy, Z, alpha=0.3, levels=[-0.5, 0.5, 1.5],
             colors=['#E5D7F5', '#FFF9D9'])
plt.contour(xx, yy, Z, levels=[0.5], colors=['#3A3A3A'], linewidths=2, linestyles='--')

# Points de test
plt.scatter(X_test[y_test==0, 0], X_test[y_test==0, 1],
            color='#9B7AC4', label='Classe 0 (reelle)', s=80,
            edgecolors='white', linewidth=1)
plt.scatter(X_test[y_test==1, 0], X_test[y_test==1, 1],
            color='#F7E64D', label='Classe 1 (reelle)', s=80,
            edgecolors='white', linewidth=1)

# Marquer les erreurs
y_pred = model.predict(X_test)
errors = y_test != y_pred
if errors.any():
    plt.scatter(X_test[errors, 0], X_test[errors, 1],
                facecolors='none', edgecolors='red', s=200, linewidths=2,
                label=f'Erreurs ({errors.sum()})')

plt.xlabel('Feature 1', fontsize=12)
plt.ylabel('Feature 2', fontsize=12)
plt.title('Frontiere de decision de la Regression Logistique', fontsize=14, fontweight='bold')
plt.legend(loc='best', fontsize=10)

# Ajouter l'equation de la frontiere
coef = model.coef_[0]
intercept = model.intercept_[0]
equation = f'Frontiere: {coef[0]:.2f}×x1 + {coef[1]:.2f}×x2 + {intercept:.2f} = 0'
plt.annotate(equation, xy=(0.02, 0.02), xycoords='axes fraction',
             fontsize=10, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

# --- Analyse ---
print()
print("ANALYSE DE LA FRONTIERE")
print("-" * 40)
print()
print("   INTERPRETATION GEOMETRIQUE :")
print("   → La frontiere est une LIGNE DROITE (modele lineaire)")
print("   → Tous les points d'un cote sont predits 'Classe 0'")
print("   → Tous les points de l'autre cote sont predits 'Classe 1'")
print()

print("   EQUATION DE LA FRONTIERE :")
print(f"   {coef[0]:.3f} × feature1 + {coef[1]:.3f} × feature2 + {intercept:.3f} = 0")
print()
print("   Sur cette ligne, P(classe 1) = 50% exactement.")
print("   C'est le point d'indecision du modele.")
print()

n_errors = errors.sum()
if n_errors > 0:
    print(f"   ERREURS VISUALISEES : {n_errors} points cercles en rouge")
    print("   → Ces points sont du 'mauvais cote' de la frontiere")
    print("   → Soit le modele est imparfait, soit ce sont des outliers")
else:
    print("   AUCUNE ERREUR ! Le modele separe parfaitement les classes de test.")


# Expliquer une prediction individuelle
# Type: Code executable
# =============================================================================
# EXPLICABILITE : DECOMPOSER UNE PREDICTION
# =============================================================================
# Pouvoir expliquer POURQUOI le modele a fait une prediction est crucial
# pour la confiance et la conformite reglementaire.

print("=" * 70)
print("EXPLICABILITE : POURQUOI CETTE PREDICTION ?")
print("=" * 70)
print()

# Preparation et entrainement
X = df[['feature1', 'feature2']].values
y = df['label'].values
feature_names = ['feature1', 'feature2']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)

# Prendre un exemple du test set
idx = 0
sample = X_test[idx]
true_label = y_test[idx]
prediction = model.predict([sample])[0]
proba = model.predict_proba([sample])[0]

# --- Presentation de l'echantillon ---
print("1. ECHANTILLON ANALYSE")
print("-" * 40)
print()
print(f"   Valeurs des features :")
for i, name in enumerate(feature_names):
    print(f"   → {name} = {sample[i]:.3f}")
print()
print(f"   Vraie classe  : {true_label}")
print(f"   Prediction    : {prediction}")
print(f"   Probabilites  : P(classe 0) = {proba[0]:.1%}, P(classe 1) = {proba[1]:.1%}")
print()

if prediction == true_label:
    print("   ✓ Prediction CORRECTE !")
else:
    print("   ✗ Prediction INCORRECTE (erreur du modele)")
print()

# --- Decomposition mathematique ---
print("2. DECOMPOSITION DE LA DECISION")
print("-" * 40)
print()
print("   Rappel de l'equation :")
print("   z = coef1 × feature1 + coef2 × feature2 + intercept")
print("   P(classe 1) = sigmoid(z) = 1 / (1 + exp(-z))")
print()

intercept = model.intercept_[0]
coefs = model.coef_[0]

print("   CALCUL ETAPE PAR ETAPE :")
print()
print(f"   Base (intercept) : {intercept:+.4f}")

z = intercept
contributions = []
for i, name in enumerate(feature_names):
    contribution = coefs[i] * sample[i]
    contributions.append(contribution)
    z += contribution
    print(f"   + {name} ({sample[i]:.3f} × {coefs[i]:+.4f}) : {contribution:+.4f}")

print(f"   {'─' * 35}")
print(f"   = Score total (z) : {z:+.4f}")
print()

# Calcul de la probabilite
prob_classe1 = 1 / (1 + np.exp(-z))
print(f"   Conversion en probabilite :")
print(f"   P(classe 1) = 1 / (1 + exp(-{z:.4f}))")
print(f"   P(classe 1) = 1 / (1 + {np.exp(-z):.4f})")
print(f"   P(classe 1) = {prob_classe1:.4f} = {prob_classe1:.1%}")
print()

# Decision
print("   DECISION :")
if prob_classe1 > 0.5:
    print(f"   {prob_classe1:.1%} > 50% → Prediction : Classe 1")
else:
    print(f"   {prob_classe1:.1%} < 50% → Prediction : Classe 0")
print()

# --- Explication en langage naturel ---
print("3. EXPLICATION EN LANGAGE NATUREL")
print("-" * 40)
print()
print(f"   'Cet echantillon est classe {prediction} car :'")
print()

# Trier les contributions par importance
contrib_with_names = list(zip(feature_names, contributions, coefs))
contrib_with_names.sort(key=lambda x: abs(x[1]), reverse=True)

for name, contrib, coef in contrib_with_names:
    if contrib > 0:
        direction = "AUGMENTE"
        impact = "pousse vers classe 1"
    else:
        direction = "DIMINUE"
        impact = "pousse vers classe 0"

    print(f"   → {name} {direction} la probabilite de classe 1")
    print(f"     (contribution : {contrib:+.3f}, {impact})")
    print()

print(f"   Resultat final : probabilite de classe 1 = {prob_classe1:.1%}")

print()
print("=" * 70)
print("L'explicabilite renforce la confiance dans le modele !")
print("=" * 70)


# SHAP pour la classification
# Type: Code executable
# =============================================================================
# SHAP : EXPLICABILITE AVANCEE
# =============================================================================
# SHAP (SHapley Additive exPlanations) formalise l'explicabilite
# de maniere rigoureuse et mathematiquement fondee.

print("=" * 70)
print("SHAP : EXPLICABILITE STANDARDISEE")
print("=" * 70)
print()
print("SHAP attribue a chaque feature une 'contribution' a la prediction.")
print("C'est la methode d'explicabilite la plus rigoureuse et reconnue.")
print()

# Preparation et entrainement
X = df[['feature1', 'feature2']].values
y = df['label'].values
feature_names = ['feature1', 'feature2']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)

# --- Calcul SHAP ---
print("1. CALCUL DES VALEURS SHAP")
print("-" * 40)

explainer = shap.LinearExplainer(model, X_train)
shap_values = explainer.shap_values(X_test)

print("   ✓ Valeurs SHAP calculees pour chaque prediction")
print(f"   → {len(X_test)} echantillons analyses")
print(f"   → {len(feature_names)} features par echantillon")
print()

# --- Explication individuelle ---
print("2. EXPLICATION D'UNE PREDICTION")
print("-" * 40)

idx = 0
sample = X_test[idx]
prediction = model.predict([sample])[0]
proba = model.predict_proba([sample])[0]

print()
print(f"   ECHANTILLON {idx} :")
print(f"   → feature1 = {sample[0]:.3f}")
print(f"   → feature2 = {sample[1]:.3f}")
print(f"   → Prediction : Classe {prediction} (P={proba[prediction]:.1%})")
print()

print("   CONTRIBUTIONS SHAP :")
base_value = explainer.expected_value
if hasattr(base_value, '__len__'):
    base_value = base_value[0]

print(f"   → Valeur de base (moyenne) : {base_value:.4f}")
for i, name in enumerate(feature_names):
    shap_val = shap_values[idx][i]
    direction = "↑ classe 1" if shap_val > 0 else "↓ classe 0"
    print(f"   → {name:12} : {shap_val:+.4f} ({direction})")

total_shap = sum(shap_values[idx])
print(f"   → Total des contributions : {total_shap:+.4f}")
print(f"   → Score final : {base_value:.4f} + {total_shap:+.4f} = {base_value + total_shap:.4f}")
print()

# --- Visualisation globale ---
print("3. IMPORTANCE GLOBALE DES FEATURES")
print("-" * 40)
print()
print("   Ce graphique montre l'impact de chaque feature sur TOUTES les predictions :")
print("   → Axe X : Contribution SHAP (positif = vers classe 1, negatif = vers classe 0)")
print("   → Couleur : Valeur de la feature (rouge = haute, bleu = basse)")
print()

plt.figure(figsize=(10, 5))
shap.summary_plot(shap_values, X_test, feature_names=feature_names, show=False)
plt.title('Impact SHAP sur la classification', fontsize=14)
plt.tight_layout()
plt.show()

# --- Interpretation ---
print()
print("4. INTERPRETATION DU GRAPHIQUE SHAP")
print("-" * 40)
print()
print("   COMMENT LIRE CE GRAPHIQUE :")
print()
print("   • Chaque point = un echantillon du dataset de test")
print("   • Position horizontale = contribution a la prediction")
print("   • Couleur = valeur de la feature pour cet echantillon")
print()
print("   PATTERNS A OBSERVER :")
print("   → Points rouges a droite : Valeurs hautes → classe 1")
print("   → Points bleus a gauche : Valeurs basses → classe 0")
print("   → Spread horizontal large : Feature tres influente")
print("   → Spread horizontal etroit : Feature peu influente")
print()

# Calculer l'importance moyenne
mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
for i, name in enumerate(feature_names):
    print(f"   Importance moyenne de {name} : {mean_abs_shap[i]:.4f}")

most_important = feature_names[np.argmax(mean_abs_shap)]
print()
print(f"   → Feature la plus importante : {most_important}")

print()
print("=" * 70)
print("SHAP : L'explicabilite au niveau industriel !")
print("=" * 70)

