Skip to content

Commit

Permalink
Update visualization.py
Browse files Browse the repository at this point in the history
Updated the visualization
  • Loading branch information
dendarko authored Aug 22, 2024
1 parent d63915e commit 3034851
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions visualizations/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,34 @@
def plot_feature_importance(model, features, dataset_name):
"""Generate a bar plot for feature importance."""
importance = model.coef_
indices = np.argsort(importance)

plt.figure(figsize=(10, 6))
plt.title(f"Feature Importance for {dataset_name}")
plt.barh(range(len(indices)), importance[indices], align='center')
plt.yticks(range(len(indices)), [features[i] for i in indices])
plt.xlabel('Relative Importance')
plt.savefig(f'static/images/feature_importance_{dataset_name}.png')
# Sort features by importance
indices = np.argsort(importance)[::-1]
sorted_features = [features[i] for i in indices]
sorted_importance = importance[indices]

plt.figure(figsize=(12, 8))
plt.title(f"Feature Importance for {dataset_name}", fontsize=16)
plt.barh(range(len(indices)), sorted_importance, align='center', color='skyblue')
plt.yticks(range(len(indices)), sorted_features, fontsize=12)
plt.xlabel('Relative Importance', fontsize=14)
plt.gca().invert_yaxis() # Most important feature at the top
plt.grid(axis='x', linestyle='--', alpha=0.7)

plt.savefig(f'static/images/feature_importance_{dataset_name}.png', bbox_inches='tight')
plt.close()

def plot_predictions_vs_true(y_true, y_pred, dataset_name):
"""Generate a scatter plot comparing true vs. predicted values."""
plt.figure(figsize=(10, 6))
plt.scatter(y_true, y_pred, edgecolors=(0, 0, 0))
plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], 'k--', lw=3)
plt.title(f"Predictions vs. True Values for {dataset_name}")
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.savefig(f'static/images/predictions_vs_true_{dataset_name}.png')
plt.figure(figsize=(12, 8))
plt.scatter(y_true, y_pred, edgecolors='k', color='dodgerblue', s=100, alpha=0.6, label='Predictions')
plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], 'r--', lw=2, label='Perfect Prediction Line')

plt.title(f"Predictions vs. True Values for {dataset_name}", fontsize=16)
plt.xlabel('True Values', fontsize=14)
plt.ylabel('Predictions', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=12)

plt.savefig(f'static/images/predictions_vs_true_{dataset_name}.png', bbox_inches='tight')
plt.close()

0 comments on commit 3034851

Please sign in to comment.