Skip to content

Commit

Permalink
Update visualization.py
Browse files Browse the repository at this point in the history
Updated Visualizations and requirements using seaborn
  • Loading branch information
dendarko authored Aug 22, 2024
1 parent 1e09ac6 commit bf2b9ea
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions visualizations/visualization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os

sns.set(style="whitegrid") # Set the seaborn style

def plot_feature_importance(model, features, dataset_name):
"""Generate a bar plot for feature importance."""
importance = model.coef_
Expand All @@ -12,27 +15,24 @@ def plot_feature_importance(model, features, dataset_name):
sorted_importance = importance[indices]

plt.figure(figsize=(12, 8))
sns.barplot(x=sorted_importance, y=sorted_features, palette="coolwarm")
plt.title(f"Feature Importance for {dataset_name}", fontsize=16)
plt.barh(range(len(indices)), sorted_importance, align='center', color='skyblue')
plt.yticks(range(len(indices)), sorted_features, fontsize=12)
plt.xlabel('Relative Importance', fontsize=14)
plt.gca().invert_yaxis() # Most important feature at the top
plt.grid(axis='x', linestyle='--', alpha=0.7)
plt.ylabel('Features', fontsize=14)

plt.savefig(f'static/images/feature_importance_{dataset_name}.png', bbox_inches='tight')
plt.close()

def plot_predictions_vs_true(y_true, y_pred, dataset_name):
"""Generate a scatter plot comparing true vs. predicted values."""
plt.figure(figsize=(12, 8))
plt.scatter(y_true, y_pred, edgecolors='k', color='dodgerblue', s=100, alpha=0.6, label='Predictions')
plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], 'r--', lw=2, label='Perfect Prediction Line')
sns.scatterplot(x=y_true, y=y_pred, color='dodgerblue', s=100, edgecolor='k', alpha=0.6)
sns.lineplot(x=y_true, y=y_true, color='red', linestyle='--', lw=2) # Perfect prediction line

plt.title(f"Predictions vs. True Values for {dataset_name}", fontsize=16)
plt.xlabel('True Values', fontsize=14)
plt.ylabel('Predictions', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=12)

plt.savefig(f'static/images/predictions_vs_true_{dataset_name}.png', bbox_inches='tight')
plt.close()

0 comments on commit bf2b9ea

Please sign in to comment.