import streamlit as st
import joblib
import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, roc_curve, auc

# Load the model
model = joblib.load('model.pkl')
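+# Note: a Streamlit script reruns top to bottom on every interaction, so a
+# cached loader is worth considering (a sketch, assuming Streamlit >= 1.18,
+# where st.cache_resource is available):
+# @st.cache_resource
+# def load_model():
+#     return joblib.load('model.pkl')
+# model = load_model()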

# Display the prediction
st.write(f'Predicted Job Satisfaction: {prediction[0]}')
+
+    # Evaluate the model on test data (assuming y_test and y_pred are available).
+    # This evaluation belongs in model development, not in the prediction app;
+    # it is shown here with dummy data for demonstration only.
+    # Note: model.predict(input_df) returns one prediction for the user's input,
+    # so it cannot be scored against five true labels; both lists below are
+    # placeholders of matching length.
+    y_test = [1, 0, 1, 1, 0]  # Example true labels
+    y_pred = [1, 0, 1, 0, 0]  # Example predicted labels
+
+    # Print accuracy
+    accuracy = accuracy_score(y_test, y_pred)
+    st.write(f'Accuracy: {accuracy:.2f}')
+
+    # Print the classification report, converted to a DataFrame for readability
+    report = classification_report(y_test, y_pred, output_dict=True)
+    report_df = pd.DataFrame(report).transpose()
+    st.write('Classification Report:')
+    st.write(report_df)
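+    # (Optional) st.dataframe(report_df) renders the same table with interactive
+    # sorting and scrolling, if that is preferred over st.write.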
+
+    # Plot the confusion matrix
+    cm = confusion_matrix(y_test, y_pred)
+    fig, ax = plt.subplots(figsize=(10, 6))
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
+    ax.set_title('Confusion Matrix')
+    ax.set_xlabel('Predicted')
+    ax.set_ylabel('Actual')
+    st.pyplot(fig)
+
+    # If the model is a binary classifier, plot the ROC curve.
+    # Note: roc_curve expects continuous scores (e.g. from predict_proba), not
+    # hard 0/1 labels; with hard labels the curve collapses to a single point.
+    if len(set(y_test)) == 2:
+        fpr, tpr, _ = roc_curve(y_test, y_pred)
+        roc_auc = auc(fpr, tpr)
+
+        fig, ax = plt.subplots(figsize=(10, 6))
+        ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
+        ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+        ax.set_xlim([0.0, 1.0])
+        ax.set_ylim([0.0, 1.05])
+        ax.set_xlabel('False Positive Rate')
+        ax.set_ylabel('True Positive Rate')
+        ax.set_title('Receiver Operating Characteristic (ROC) Curve')
+        ax.legend(loc='lower right')
+        st.pyplot(fig)
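+        # A sketch of the preferred input for roc_curve, assuming the estimator
+        # implements predict_proba (typical for scikit-learn classifiers) and a
+        # held-out X_test exists (both hypothetical in this app):
+        # y_scores = model.predict_proba(X_test)[:, 1]
+        # fpr, tpr, _ = roc_curve(y_test, y_scores)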