Spaces:
Running
Running
| """ | |
| core_agent.py | |
| ============= | |
| DataMind Agent — TRUE Agentic AI + Multi-LLM Support | |
| Providers: Google Gemini, OpenAI GPT, Anthropic Claude, xAI Grok, | |
| Mistral AI, Meta Llama (via Together AI), Alibaba Qwen (via Together AI) | |
| File formats: CSV, Excel (.xlsx, .xls), JSON | |
| """ | |
| import os | |
| import json | |
| import warnings | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from dotenv import load_dotenv | |
| from langchain_core.tools import tool | |
| from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
# Suppress every Python warning for the whole process, then let python-dotenv
# load variables from a .env file into the environment.
warnings.filterwarnings("ignore")
load_dotenv()

# ─── Palette ──────────────────────────────────────────────────────────────────
# PALETTE is the colour cycle used by the chart engine; DARK_BG / CARD_BG are
# applied there as the paper and plot background colours.
PALETTE = ["#6C63FF", "#FF6584", "#43E97B", "#F7971E", "#4FC3F7", "#CE93D8"]
DARK_BG = "#0F0F1A"
CARD_BG = "#1A1A2E"

# ─── Global state (shared across agent tools) ─────────────────────────────────
# The agent tool functions read the active dataset and its profile from these
# module globals; set_dataframe() is the only writer visible in this file.
_df = None
_profile = None
def set_dataframe(df, profile):
    """Install `df` and its `profile` as the dataset the agent tools operate on.

    Both values are stored in the module globals `_df` and `_profile`,
    replacing whatever was registered before.
    """
    global _df, _profile
    _df, _profile = df, profile
# ──────────────────────────────────────────────────────────────────────────────
# PROVIDER REGISTRY
# ──────────────────────────────────────────────────────────────────────────────
# Registry of supported LLM providers, keyed by the short provider id that
# get_llm() dispatches on. Per-provider fields:
#   name     - human-readable label (used in validate_llm()'s success message)
#   models   - selectable model identifiers
#   default  - model used by get_llm() when the caller passes no model
#   key_hint - example shape of the API key
#   color    - hex colour string
#   key_url  - page where an API key can be obtained
#   note     - optional extra remark
# key_hint / color / key_url / note are not read anywhere in this file;
# presumably they are consumed by the UI layer - TODO confirm.
# NOTE(review): the model identifiers below are passed to the providers
# verbatim and are not validated here - confirm they match each provider's
# current catalogue.
PROVIDERS = {
    "gemini": {
        "name": "Google Gemini",
        "models": [
            "gemini-2.5-flash",
            "gemini-2.5-pro",
            "gemini-2.0-flash",
            "gemini-1.5-pro-002",
            "gemini-1.5-flash-002",
        ],
        "default": "gemini-2.5-flash",
        "key_hint": "AIza...",
        "color": "#4285f4",
        "key_url": "https://aistudio.google.com/app/apikey",
    },
    "openai": {
        "name": "OpenAI GPT",
        "models": [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-3.5-turbo-0125",
        ],
        "default": "gpt-4o",
        "key_hint": "sk-...",
        "color": "#10a37f",
        "key_url": "https://platform.openai.com/api-keys",
    },
    "claude": {
        "name": "Anthropic Claude",
        "models": [
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5-20251001",
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
        ],
        "default": "claude-sonnet-4-6",
        "key_hint": "sk-ant-...",
        "color": "#d97706",
        "key_url": "https://console.anthropic.com/",
    },
    # Served through xAI's OpenAI-compatible endpoint (see get_llm()).
    "grok": {
        "name": "xAI Grok",
        "models": [
            "grok-3",
            "grok-3-mini",
            "grok-2-1212",
        ],
        "default": "grok-3",
        "key_hint": "xai-...",
        "color": "#9b9b9b",
        "key_url": "https://console.x.ai/",
    },
    "mistral": {
        "name": "Mistral AI",
        "models": [
            "mistral-large-2411",
            "mistral-small-2409",
            "open-mixtral-8x22b",
        ],
        "default": "mistral-large-2411",
        "key_hint": "...",
        "color": "#ff6b35",
        "key_url": "https://console.mistral.ai/",
    },
    # "llama" and "qwen" are both routed to the Together AI endpoint by
    # get_llm(), so they take a Together AI key rather than a vendor key.
    "llama": {
        "name": "Meta Llama (Together AI)",
        "models": [
            "meta-llama/llama-4-maverick",
            "meta-llama/llama-4-scout",
            "meta-llama/llama-3.3-70b-instruct",
            "meta-llama/llama-3.1-405b",
            "meta-llama/llama-3.1-70b",
        ],
        "default": "meta-llama/llama-4-maverick",
        "key_hint": "Together AI key...",
        "color": "#0668E1",
        "key_url": "https://api.together.ai/",
        "note": "Requires a Together AI API key",
    },
    "qwen": {
        "name": "Alibaba Qwen (Together AI)",
        "models": [
            "Qwen/qwen2.5-72b-instruct",
            "Qwen/qwen2.5-coder-32b",
            "Qwen/qwen2-72b-instruct",
        ],
        "default": "Qwen/qwen2.5-72b-instruct",
        "key_hint": "Together AI key...",
        "color": "#6547d4",
        "key_url": "https://api.together.ai/",
        "note": "Requires a Together AI API key",
    },
}
# ──────────────────────────────────────────────────────────────────────────────
# LLM FACTORY
# ──────────────────────────────────────────────────────────────────────────────
def get_llm(provider: str, api_key: str, model: str = None):
    """Build a LangChain chat model for `provider`.

    Args:
        provider: A key of `PROVIDERS` ("gemini", "openai", "claude", "grok",
            "mistral", "llama" or "qwen").
        api_key: API key passed to the provider client.
        model: Model identifier; when falsy, the provider's registered
            default is used.

    Returns:
        The chat model instance, configured with temperature 0.3.

    Raises:
        ValueError: If `provider` is not a known provider.
    """
    # Validate before touching the registry: previously an unknown provider
    # with no explicit model surfaced as a KeyError from the default lookup
    # instead of the intended ValueError.
    if provider not in PROVIDERS:
        raise ValueError(f"Unknown provider: {provider}")
    model = model or PROVIDERS[provider]["default"]

    # Provider packages are imported inside each branch so only the package
    # for the selected provider has to be installed.
    if provider == "gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(
            model=model, google_api_key=api_key,
            temperature=0.3, convert_system_message_to_human=True,
        )
    if provider == "openai":
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model=model, api_key=api_key, temperature=0.3)
    if provider == "claude":
        from langchain_anthropic import ChatAnthropic
        return ChatAnthropic(model=model, api_key=api_key, temperature=0.3)
    if provider == "grok":
        # Reuses the OpenAI client against xAI's base URL.
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model=model, api_key=api_key,
                          base_url="https://api.x.ai/v1", temperature=0.3)
    if provider == "mistral":
        from langchain_mistralai import ChatMistralAI
        return ChatMistralAI(model=model, api_key=api_key, temperature=0.3)
    if provider in ("llama", "qwen"):
        # Both are served through Together AI's base URL with the OpenAI client.
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model=model, api_key=api_key,
                          base_url="https://api.together.xyz/v1", temperature=0.3)
    # Registered in PROVIDERS but not wired to a client above.
    raise ValueError(f"Unknown provider: {provider}")
def validate_llm(provider: str, api_key: str, model: str = None):
    """Create the chat model for `provider` and send it one test message.

    Returns a `(llm, message)` tuple where `message` names the provider.
    Any exception raised while building or invoking the model propagates
    to the caller.
    """
    client = get_llm(provider, api_key, model)
    # The reply itself is discarded; only a raised exception matters here.
    client.invoke([HumanMessage(content="Say OK")])
    provider_label = PROVIDERS[provider]["name"]
    return client, f"Connected to {provider_label}!"
# ──────────────────────────────────────────────────────────────────────────────
# FILE LOADING
# ──────────────────────────────────────────────────────────────────────────────
def load_file(file):
    """Read an uploaded file into a DataFrame.

    The format is chosen from the (case-insensitive) extension of
    `file.name`.

    Args:
        file: A readable file-like object exposing a `name` attribute.

    Returns:
        A `(DataFrame, label)` tuple where label is "CSV", "Excel" or "JSON".

    Raises:
        ValueError: If the extension is unsupported, or the JSON document is
            neither an array nor an object.
    """
    name = file.name.lower()
    if name.endswith(".csv"):
        return pd.read_csv(file), "CSV"
    if name.endswith((".xlsx", ".xls")):
        return pd.read_excel(file), "Excel"
    if name.endswith(".json"):
        content = json.load(file)
        if isinstance(content, list):
            # Array of records (or of scalars): one row per element.
            return pd.DataFrame(content), "JSON"
        if isinstance(content, dict):
            # A dict holding at least one list is treated as column-oriented
            # data; a dict of plain values becomes a single-row frame.
            if any(isinstance(value, list) for value in content.values()):
                return pd.DataFrame(content), "JSON"
            return pd.DataFrame([content]), "JSON"
        # A bare string/number/bool/null has no tabular shape. Previously
        # this crashed with AttributeError on `.values()`.
        raise ValueError(
            f"Unsupported JSON structure in {name}: expected an array or an "
            f"object, got {type(content).__name__}"
        )
    raise ValueError(f"Unsupported file type: {name}")
# ──────────────────────────────────────────────────────────────────────────────
# DATA PROFILING
# ──────────────────────────────────────────────────────────────────────────────
| def profile_dataframe(df): | |
| numeric_cols = df.select_dtypes(include="number").columns.tolist() | |
| category_cols = df.select_dtypes(include=["object", "category"]).columns.tolist() | |
| datetime_cols = df.select_dtypes(include=["datetime"]).columns.tolist() | |
| profile = { | |
| "shape": df.shape, | |
| "columns": df.columns.tolist(), | |
| "dtypes": df.dtypes.astype(str).to_dict(), | |
| "numeric_columns": numeric_cols, | |
| "categorical_columns": category_cols, | |
| "datetime_columns": datetime_cols, | |
| "null_counts": df.isnull().sum().to_dict(), | |
| "null_pct": (df.isnull().mean() * 100).round(2).to_dict(), | |
| "duplicates": int(df.duplicated().sum()), | |
| } | |
| if numeric_cols: | |
| profile["numeric_stats"] = df[numeric_cols].describe().round(3).to_dict() | |
| if category_cols: | |
| profile["top_categories"] = { | |
| col: df[col].value_counts().head(5).to_dict() for col in category_cols | |
| } | |
| return profile | |
def profile_to_text(profile, df):
    """Render `profile` plus the first five rows of `df` as plain text.

    The text lists the dataset size, the column groups, missing-value and
    duplicate totals, a sample of the data and, when the profile carries
    "numeric_stats", one mean/std/min/max line per numeric column.
    """
    def _names(key):
        # Comma-separated column names, or the literal "None" when empty.
        return ", ".join(profile[key]) or "None"

    n_rows, n_cols = profile["shape"]
    out = [
        f"Dataset: {n_rows} rows x {n_cols} columns",
        f"Numeric columns : {_names('numeric_columns')}",
        f"Categorical cols : {_names('categorical_columns')}",
        f"Datetime cols : {_names('datetime_columns')}",
        f"Missing values : {sum(profile['null_counts'].values())} total",
        f"Duplicate rows : {profile['duplicates']}",
    ]
    out.append("")
    out.append("--- Sample Data (first 5 rows) ---")
    out.append(df.head(5).to_string(index=False))

    numeric_stats = profile.get("numeric_stats")
    if numeric_stats:
        out.extend(["", "--- Numeric Stats ---"])
        out.extend(
            f" {name}: mean={s.get('mean','?')}, std={s.get('std','?')}, "
            f"min={s.get('min','?')}, max={s.get('max','?')}"
            for name, s in numeric_stats.items()
        )
    return "\n".join(out)
# ──────────────────────────────────────────────────────────────────────────────
# AGENT TOOLS
# ──────────────────────────────────────────────────────────────────────────────
def profile_data(query: str) -> str:
    """Get full statistical profile of the dataset. Use this FIRST before any analysis."""
    # `query` is accepted for the tool interface but not used: the report is
    # always rendered from the dataset registered in the module globals.
    if _df is not None:
        return profile_to_text(_profile, _df)
    return "No dataset loaded. Please upload a file first."
def analyze_column(column_name: str) -> str:
    """Deeply analyze a specific column. Provide the exact column name."""
    if _df is None:
        return "No dataset loaded."
    if column_name not in _df.columns:
        return f"Column '{column_name}' not found. Available: {_df.columns.tolist()}"

    series = _df[column_name]
    # Header shared by numeric and non-numeric columns.
    lines = [
        f"Analysis of '{column_name}'",
        f"Type: {series.dtype}",
        f"Non-null: {series.count()} / {len(series)}",
        f"Nulls: {series.isnull().sum()} ({series.isnull().mean()*100:.1f}%)",
    ]

    if not pd.api.types.is_numeric_dtype(series):
        # Non-numeric column: cardinality and most frequent values.
        unique_count = series.nunique()
        top_five = series.value_counts().head(5).to_dict()
        most_common = series.mode()[0] if not series.mode().empty else 'N/A'
        lines.append(f"Unique values: {unique_count}")
        lines.append(f"Top 5: {top_five}")
        lines.append(f"Most common: {most_common}")
        return "\n".join(lines)

    # Numeric column: values beyond 1.5 * IQR from the quartiles count as
    # outliers.
    lower_q = series.quantile(0.25)
    upper_q = series.quantile(0.75)
    spread = upper_q - lower_q
    below = series < lower_q - 1.5*spread
    above = series > upper_q + 1.5*spread
    outlier_count = int((below | above).sum())
    lines.extend([
        f"Mean: {series.mean():.3f}",
        f"Median: {series.median():.3f}",
        f"Std: {series.std():.3f}",
        f"Min: {series.min()}",
        f"Max: {series.max()}",
        f"Skewness: {series.skew():.3f}",
        f"Outliers (IQR): {outlier_count}",
    ])
    return "\n".join(lines)
def find_correlations(query: str) -> str:
    """Find correlations between numeric columns. Highlights strong relationships."""
    # `query` is unused; the numeric columns come from the stored profile.
    if _df is None:
        return "No dataset loaded."
    numeric = _profile["numeric_columns"]
    if len(numeric) < 2:
        return "Need at least 2 numeric columns for correlation analysis."

    matrix = _df[numeric].corr().round(3)
    notable = []
    # Walk the upper triangle so each pair of columns is reported once.
    for row, left in enumerate(numeric):
        for col in range(row + 1, len(numeric)):
            r = matrix.iloc[row, col]
            # Written as a negated >= so NaN coefficients are skipped.
            if not (abs(r) >= 0.5):
                continue
            if abs(r) >= 0.8:
                strength = "strong"
            else:
                strength = "moderate"
            if r > 0:
                direction = "positive"
            else:
                direction = "negative"
            notable.append(f" {left} <-> {numeric[col]}: {r} ({strength} {direction})")

    lines = ["Correlation Matrix:", matrix.to_string()]
    if not notable:
        lines.append("No strong correlations found (|r| >= 0.5)")
    else:
        lines.extend(["", "Notable correlations:"])
        lines.extend(notable)
    return "\n".join(lines)
def detect_anomalies(query: str) -> str:
    """Detect outliers and anomalies across all numeric columns using IQR method."""
    # `query` is unused; every numeric column in the stored profile is scanned.
    if _df is None:
        return "No dataset loaded."
    numeric = _profile["numeric_columns"]
    if not numeric:
        return "No numeric columns found."

    report = ["Anomaly Detection Report (IQR Method):"]
    grand_total = 0
    for name in numeric:
        values = _df[name]
        # Quartiles are computed on the non-null values only.
        present = values.dropna()
        q1 = present.quantile(0.25)
        q3 = present.quantile(0.75)
        iqr = q3 - q1
        outside = (values < q1 - 1.5*iqr) | (values > q3 + 1.5*iqr)
        flagged = values[outside]
        count = len(flagged)
        if count > 0:
            grand_total += count
            report.append(f" {name}: {count} outliers | Examples: {flagged.head(3).tolist()}")

    report.append(f"\nTotal outliers found: {grand_total}")
    if grand_total == 0:
        report.append("No significant outliers detected.")
    return "\n".join(report)
def run_aggregation(query: str) -> str:
    """
    Compute group-by aggregations.
    Format: 'group_col|agg_col|function' (e.g. 'category|sales|sum')
    Supported functions: sum, mean, count, max, min, median
    (also accepted: std, var, nunique, first, last; 'avg'/'average' mean 'mean')
    """
    if _df is None:
        return "No dataset loaded."

    supported = ("sum", "mean", "count", "max", "min", "median",
                 "std", "var", "nunique", "first", "last")
    aliases = {"avg": "mean", "average": "mean"}

    # Every failure is reported back as text rather than raised, so the
    # calling agent can read the message and retry with different input.
    try:
        parts = [p.strip() for p in query.split("|")]
        if len(parts) == 3:
            group_col, agg_col, func = parts
        elif len(parts) == 2:
            # Function omitted: default to the mean.
            group_col, agg_col, func = parts[0], parts[1], "mean"
        else:
            # Free-form query: fall back to the first categorical and first
            # numeric column of the stored profile.
            cat_cols = _profile["categorical_columns"]
            num_cols = _profile["numeric_columns"]
            if not cat_cols or not num_cols:
                return "Could not determine columns."
            group_col, agg_col, func = cat_cols[0], num_cols[0], "sum"

        if group_col not in _df.columns:
            return f"Column '{group_col}' not found. Available: {_df.columns.tolist()}"
        if agg_col not in _df.columns:
            return f"Column '{agg_col}' not found. Available: {_df.columns.tolist()}"

        fn = func.lower()
        fn = aliases.get(fn, fn)
        if fn not in supported:
            return f"Unsupported function '{func}'. Supported: {', '.join(supported)}"

        # Rename the aggregated series BEFORE reset_index so that grouping
        # and aggregating the same column (e.g. 'cat|cat|count') no longer
        # collides on the column name.
        out_name = f"{fn}_{agg_col}"
        result = (
            _df.groupby(group_col)[agg_col]
            .agg(fn)
            .rename(out_name)
            .reset_index()
            .sort_values(out_name, ascending=False)
        )
        return f"Aggregation: {fn.upper()} of '{agg_col}' by '{group_col}'\n{result.to_string(index=False)}"
    except Exception as e:
        return f"Aggregation error: {str(e)}"
def generate_insight_report(query: str) -> str:
    """Generate a complete automated insight report with data quality score, patterns, and recommendations."""
    # `query` is unused; the report is built from the stored dataset/profile.
    if _df is None:
        return "No dataset loaded."

    rows, cols = _profile["shape"]
    num_cols = _profile["numeric_columns"]
    cat_cols = _profile["categorical_columns"]
    duplicates = _profile["duplicates"]
    nulls = sum(_profile["null_counts"].values())
    cells = rows * cols
    null_pct = (nulls / cells * 100) if cells > 0 else 0

    # Quality score: start at 100, subtract a penalty by missing-data band,
    # and a flat 10 when any duplicate rows exist.
    quality = 100
    if null_pct > 20:
        quality -= 30
    elif null_pct > 10:
        quality -= 15
    elif null_pct > 5:
        quality -= 5
    if duplicates > 0:
        quality -= 10

    report = [
        "=" * 50, "AUTOMATED INSIGHT REPORT", "=" * 50, "",
        "1. DATASET OVERVIEW",
        f" Rows: {rows:,} | Columns: {cols}",
        f" Numeric: {len(num_cols)} | Categorical: {len(cat_cols)}",
        f" Data Quality Score: {quality}/100", "",
        "2. DATA QUALITY",
        f" Missing values: {nulls} ({null_pct:.1f}%)",
        f" Duplicate rows: {duplicates}",
    ]
    if nulls > 0:
        worst = max(_profile["null_pct"].items(), key=lambda x: x[1])
        report.append(f" Worst column: '{worst[0]}' ({worst[1]}% missing)")

    report += ["", "3. KEY STATISTICS"]
    for col in num_cols[:5]:
        stats = _profile.get("numeric_stats", {}).get(col, {})
        report.append(f" {col}: mean={stats.get('mean','?')}, range=[{stats.get('min','?')}, {stats.get('max','?')}]")

    if cat_cols:
        report += ["", "4. CATEGORICAL SUMMARY"]
        for col in cat_cols[:3]:
            # value_counts() drops nulls, so a non-empty but all-null column
            # yields no counts; check the counts themselves rather than the
            # column length to avoid an IndexError on index[0].
            counts = _df[col].value_counts()
            top = counts.index[0] if not counts.empty else "N/A"
            report.append(f" {col}: {_df[col].nunique()} unique | most common = '{top}'")

    report += [
        "", "5. RECOMMENDATIONS",
        f" - {'Fix missing values' if null_pct > 5 else 'Data completeness looks good'}",
        f" - {'Remove duplicate rows' if duplicates > 0 else 'No duplicates found'}",
        f" - {'Run correlation analysis' if len(num_cols) >= 2 else 'Need more numeric columns'}",
        f" - {'Encode categorical columns for ML' if cat_cols else 'Add categorical features'}",
        "", "=" * 50,
    ]
    return "\n".join(report)
def recommend_chart(question: str) -> str:
    """Recommend best chart type for a question. Returns JSON with chart_type, x_col, y_col."""
    def _spec(chart_type, x_col=None, y_col=None):
        # Serialise one recommendation in the fixed key order.
        return json.dumps({"chart_type": chart_type, "x_col": x_col, "y_col": y_col})

    if _profile is None:
        return _spec("bar_chart")

    num_cols = _profile["numeric_columns"]
    cat_cols = _profile["categorical_columns"]
    dt_cols = _profile["datetime_columns"]
    text = question.lower()

    def _mentions(*keywords):
        # Plain substring match against the lower-cased question.
        return any(keyword in text for keyword in keywords)

    # Keyword rules are tried in priority order; the first rule whose
    # keywords match AND whose required column kinds exist wins.
    if _mentions("trend", "over time", "time", "date") and dt_cols and num_cols:
        return _spec("time_series", dt_cols[0], num_cols[0])
    if _mentions("correlat", "relationship", "vs", "versus") and len(num_cols) >= 2:
        return _spec("correlation_heatmap")
    if _mentions("distribut", "spread", "histogram") and num_cols:
        return _spec("distribution_plots", None, num_cols[0])
    if _mentions("outlier", "box", "range") and num_cols:
        return _spec("box_plots")
    if _mentions("proportion", "share", "percent", "pie") and cat_cols:
        return _spec("pie_chart", cat_cols[0])

    # No keyword rule applied: choose from the available column kinds alone.
    if cat_cols and num_cols:
        return _spec("bar_chart", cat_cols[0], num_cols[0])
    if len(num_cols) >= 2:
        return _spec("scatter", num_cols[0], num_cols[1])
    return _spec("bar_chart")
# ──────────────────────────────────────────────────────────────────────────────
# AGENT BUILDER
# ──────────────────────────────────────────────────────────────────────────────
def _as_tool(fn):
    """Return `fn` as a LangChain tool, wrapping plain functions with `tool`.

    Objects that already expose `.name` and `.invoke` are returned
    unchanged, so this is safe whether or not the function was decorated.
    """
    if hasattr(fn, "name") and hasattr(fn, "invoke"):
        return fn
    return tool(fn)


# The analysis functions above are plain callables; TOOLS_MAP (keyed by
# `.name`) and run_agent() (which calls `.invoke`) need tool objects, so each
# one is wrapped here. Without this, building TOOLS_MAP fails at import time
# because functions have no `.name` attribute.
TOOLS = [
    _as_tool(fn)
    for fn in (
        profile_data,
        analyze_column,
        find_correlations,
        detect_anomalies,
        run_aggregation,
        generate_insight_report,
        recommend_chart,
    )
]
# Name -> tool lookup used to dispatch the model's tool calls.
TOOLS_MAP = {t.name: t for t in TOOLS}
# System message placed first in every conversation by run_agent(). It
# describes the expected workflow and lists the tools by name; keep the list
# in sync with TOOLS. The dash in step 5 was previously a mis-encoded
# character ("β").
SYSTEM_PROMPT = """You are DataMind, an expert autonomous data analyst AI agent.
When a user asks a question:
1. THINK about what tools you need
2. PLAN your steps (use multiple tools in sequence when needed)
3. EXECUTE each tool
4. SYNTHESIZE the results into a clear, insightful answer
5. SELF-CORRECT if a tool returns an error — try a different approach
Your tools:
- profile_data: Get dataset overview (use this first if unsure about the data)
- analyze_column: Deep dive into a specific column
- find_correlations: Find relationships between numeric columns
- detect_anomalies: Find outliers and data quality issues
- run_aggregation: Group-by calculations (sum, mean, count, etc.)
- generate_insight_report: Full automated analysis report
- recommend_chart: Suggest best visualization for a question
Always be precise, proactive, and thorough. Use multiple tools when needed.
Remember conversation history and refer to previous questions when relevant."""
def build_agent(llm):
    """Return `llm` with the module's TOOLS bound to it via `bind_tools`.

    The result is what run_agent() expects as its `agent_executor` argument.
    """
    tool_bound_llm = llm.bind_tools(TOOLS)
    return tool_bound_llm
class _Action:
    """Record of one tool call for UI display: `.tool` and `.tool_input`."""

    def __init__(self, name, inp):
        self.tool = name
        self.tool_input = inp


def _message_text(content):
    """Flatten message content to plain text.

    Content may be a plain string or a list of content blocks (strings or
    dicts); only text blocks are kept.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type", "text") == "text":
                pieces.append(str(part.get("text", "")))
        return "".join(pieces)
    return str(content) if content else ""


def run_agent(question: str, agent_executor, chat_history: list) -> dict:
    """Run the tool-calling loop manually, without an AgentExecutor.

    Args:
        question: The user's question for this turn.
        agent_executor: A tool-bound chat model (see build_agent()); its
            `invoke` is called with the running message list.
        chat_history: Earlier messages to place between the system prompt
            and the new question; `None` is treated as no history.

    Returns:
        A dict with "output" (final answer text), "steps" (a list of
        `(action, result)` pairs, where `action` has `.tool` and
        `.tool_input`) and "error" (`None`, or the error text when the
        model call failed).
    """
    messages = [SystemMessage(content=SYSTEM_PROMPT)]
    messages += chat_history or []
    messages.append(HumanMessage(content=question))
    steps = []
    max_iterations = 6

    for _ in range(max_iterations):
        try:
            response = agent_executor.invoke(messages)
        except Exception as e:
            # A failed model call ends the run; steps gathered so far are kept.
            return {"output": f"Agent error: {str(e)}", "steps": steps, "error": str(e)}
        messages.append(response)

        # No tool calls requested: this response is the final answer.
        if not response.tool_calls:
            return {
                "output": _message_text(response.content) or "Analysis complete.",
                "steps": steps,
                "error": None,
            }

        # Execute every requested tool and feed each result back to the model.
        for tool_call in response.tool_calls:
            tool_name = tool_call["name"]
            tool_input = tool_call["args"]
            tool_id = tool_call["id"]

            tool_fn = TOOLS_MAP.get(tool_name)
            if tool_fn is None:
                result = f"Unknown tool: {tool_name}"
            else:
                try:
                    # Every tool takes one string argument: pass a string
                    # through, otherwise use the first argument value
                    # (whatever the model named it), or "" when none given.
                    if isinstance(tool_input, str):
                        inp = tool_input
                    elif tool_input:
                        inp = next(iter(tool_input.values()))
                    else:
                        inp = ""
                    result = tool_fn.invoke(inp)
                except Exception as e:
                    # Reported to the model as text so it can try another approach.
                    result = f"Tool error: {str(e)}"

            steps.append((_Action(tool_name, tool_input), result))
            messages.append(ToolMessage(content=str(result), tool_call_id=tool_id))

    # The model was still requesting tools after the last allowed round.
    return {
        "output": "Analysis complete — reached maximum reasoning steps.",
        "steps": steps,
        "error": None,
    }
# ──────────────────────────────────────────────────────────────────────────────
# CHART ENGINE (with robust fallbacks for any dataset)
# ──────────────────────────────────────────────────────────────────────────────
def auto_suggest_charts(profile):
    """List the chart types that the column kinds in `profile` can support.

    Returns chart-type names in a fixed order: pairwise numeric charts,
    single-column numeric charts, categorical charts, then time series.
    """
    numeric = profile["numeric_columns"]
    charts = []
    if len(numeric) >= 2:
        charts += ["correlation_heatmap", "scatter_matrix"]
    if numeric:
        charts += ["distribution_plots", "box_plots"]
    if profile["categorical_columns"] and numeric:
        charts += ["bar_chart", "pie_chart"]
    if profile["datetime_columns"] and numeric:
        charts += ["time_series"]
    return charts
def _safe_cat_col(df, cat_cols):
    """Return the column from `cat_cols` with the fewest unique values.

    Ties go to the column listed first. Returns None when `cat_cols` is
    empty or None.
    """
    if cat_cols:
        return min(cat_cols, key=lambda name: df[name].nunique())
    return None
def _safe_num_col(df, num_cols):
    """Return the first column in `num_cols` with at least one non-null value.

    Returns None when every listed column is entirely null or the list is
    empty.
    """
    usable = (name for name in num_cols if df[name].dropna().shape[0] > 0)
    return next(usable, None)
def _zero_based_range(values, headroom=1.2):
    """Return a `[low, high]` axis range that includes zero, or None.

    Both ends are padded by `headroom` so outside-positioned bar labels fit.
    Negative values extend the range below zero instead of being cut off.
    Returns None when the range would be empty (no data or all zeros).
    """
    low = min(0, float(values.min()) * headroom)
    high = max(0, float(values.max()) * headroom)
    return [low, high] if high > low else None


def _column_means_figure(df, num_cols, template, title):
    """Bar chart of the mean of up to eight numeric columns (fallback view)."""
    means = df[num_cols[:8]].mean().dropna().round(2)
    fig = px.bar(x=means.index, y=means.values, color=means.index,
                 color_discrete_sequence=PALETTE,
                 title=title, template=template,
                 text=means.values,
                 labels={"x": "Column", "y": "Mean Value"})
    fig.update_traces(textposition="outside")
    axis_range = _zero_based_range(means)
    if axis_range is not None:
        fig.update_yaxes(range=axis_range)
    fig.update_layout(showlegend=False, height=450)
    return fig


def _message_figure(text, title, template, size, color):
    """Empty figure carrying a single text annotation."""
    fig = go.Figure()
    fig.add_annotation(text=text, showarrow=False, font=dict(size=size, color=color))
    fig.update_layout(template=template, title=title)
    return fig


def make_plotly_chart(chart_type, df, profile, x_col=None, y_col=None, color_col=None):
    """Build a dark-themed Plotly figure of `chart_type` for `df`.

    Args:
        chart_type: One of "correlation_heatmap", "distribution_plots",
            "box_plots", "bar_chart", "pie_chart", "scatter_matrix",
            "time_series", "scatter" or "line". Anything else, or a type
            whose required columns are missing, yields a column-means
            overview.
        df: The dataset to plot.
        profile: Profile dict for `df` (uses "numeric_columns",
            "categorical_columns" and "datetime_columns").
        x_col, y_col: Optional column overrides; ignored when they do not
            fit the chart type, in which case a column is picked
            automatically.
        color_col: Optional colour column, used by the "scatter" chart only.

    Returns:
        A Plotly figure. This function does not raise for chart-building
        errors: they fall back to the column-means chart, or to an error
        annotation when no usable numeric column exists.
    """
    from plotly.subplots import make_subplots

    # Numeric columns that are entirely null cannot be plotted.
    num_cols = [c for c in profile["numeric_columns"] if df[c].dropna().shape[0] > 0]
    cat_cols = profile["categorical_columns"]
    dt_cols = profile.get("datetime_columns") or []
    template = "plotly_dark"
    # Point-heavy charts use a fixed-seed sample of at most 5000 rows.
    plot_df = df.sample(5000, random_state=42) if len(df) > 5000 else df

    try:
        # Correlation heatmap, colour scale fixed to [-1, 1].
        if chart_type == "correlation_heatmap" and len(num_cols) >= 2:
            corr = plot_df[num_cols[:10]].corr().round(2)
            fig = px.imshow(corr, text_auto=True, color_continuous_scale="RdBu_r",
                            title="Correlation Heatmap", template=template,
                            color_continuous_midpoint=0, zmin=-1, zmax=1)
            fig.update_layout(height=500)

        # Distributions: one histogram subplot per column, independent axes.
        elif chart_type == "distribution_plots" and num_cols:
            cols_to_plot = num_cols[:6]
            n = len(cols_to_plot)
            ncols = min(3, n)
            nrows = (n + ncols - 1) // ncols
            fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=cols_to_plot)
            for idx, col in enumerate(cols_to_plot):
                r, c = idx // ncols + 1, idx % ncols + 1
                data = plot_df[col].dropna()
                fig.add_trace(go.Histogram(x=data, nbinsx=30, name=col,
                                           marker_color=PALETTE[idx % len(PALETTE)],
                                           showlegend=False), row=r, col=c)
                # Dashed mean marker; skipped when the sample holds no
                # values for this column (the mean would be NaN).
                if not data.empty:
                    fig.add_vline(x=float(data.mean()), line_dash="dash",
                                  line_color="white", opacity=0.5, row=r, col=c)
            fig.update_xaxes(matches=None)
            fig.update_yaxes(matches=None)
            fig.update_layout(title="Distributions — Independent Scale per Column",
                              template=template, height=350 * nrows)

        # Box plots: one subplot per column, independent axes.
        elif chart_type == "box_plots" and num_cols:
            cols_to_plot = num_cols[:6]
            n = len(cols_to_plot)
            ncols = min(3, n)
            nrows = (n + ncols - 1) // ncols
            fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=cols_to_plot)
            for idx, col in enumerate(cols_to_plot):
                r, c = idx // ncols + 1, idx % ncols + 1
                fig.add_trace(go.Box(y=plot_df[col].dropna(), name=col,
                                     marker_color=PALETTE[idx % len(PALETTE)],
                                     boxmean=True, showlegend=False), row=r, col=c)
            fig.update_yaxes(matches=None)
            fig.update_layout(title="Box Plots — Independent Scale per Column",
                              template=template, height=350 * nrows)

        # Bar chart: mean of y per x category (top 15), values labelled.
        elif chart_type == "bar_chart":
            xc = x_col if x_col in df.columns else _safe_cat_col(df, cat_cols)
            yc = y_col if y_col in num_cols else _safe_num_col(df, num_cols)
            if not (xc and yc):
                # Handled by the except block below (fallback chart).
                raise ValueError("No suitable columns for bar chart")
            agg = (df.groupby(xc)[yc].mean().reset_index()
                   .sort_values(yc, ascending=False).head(15))
            agg[yc] = agg[yc].round(2)
            fig = px.bar(agg, x=xc, y=yc, color=yc,
                         color_continuous_scale="Viridis",
                         title=f"Average {yc} by {xc}",
                         template=template, text=yc)
            fig.update_traces(textposition="outside")
            # Range includes zero and extends below it for negative means.
            axis_range = _zero_based_range(agg[yc])
            if axis_range is not None:
                fig.update_yaxes(range=axis_range)
            fig.update_layout(height=500)

        # Pie chart: top 8 categories with label, percent and value.
        elif chart_type == "pie_chart" and cat_cols:
            col = x_col if x_col in cat_cols else _safe_cat_col(df, cat_cols)
            counts = df[col].value_counts().head(8)
            fig = px.pie(values=counts.values, names=counts.index,
                         title=f"Distribution of {col}",
                         color_discrete_sequence=PALETTE, template=template, hole=0.35)
            fig.update_traces(textinfo="label+percent+value")
            fig.update_layout(height=500)

        # Scatter matrix over the first four numeric columns.
        elif chart_type == "scatter_matrix" and len(num_cols) >= 2:
            # Colour only by a categorical column with at most 10 values.
            safe_cat = _safe_cat_col(df, [c for c in cat_cols if df[c].nunique() <= 10])
            fig = px.scatter_matrix(plot_df, dimensions=num_cols[:4],
                                    color=safe_cat, color_discrete_sequence=PALETTE,
                                    title="Scatter Matrix — Each Axis Independent",
                                    template=template)
            fig.update_traces(diagonal_visible=False, showupperhalf=False)
            fig.update_layout(height=600)

        # Time series: one subplot per metric against the first datetime column.
        elif chart_type == "time_series" and dt_cols and num_cols:
            dt_col = dt_cols[0]
            cols_to_plot = [y_col] if y_col in num_cols else num_cols[:4]
            n = len(cols_to_plot)
            fig = make_subplots(rows=n, cols=1, subplot_titles=cols_to_plot,
                                shared_xaxes=True)
            sorted_df = df.sort_values(dt_col)
            for idx, col in enumerate(cols_to_plot):
                fig.add_trace(go.Scatter(x=sorted_df[dt_col], y=sorted_df[col],
                                         name=col, mode="lines",
                                         line=dict(color=PALETTE[idx % len(PALETTE)])),
                              row=idx + 1, col=1)
            fig.update_yaxes(matches=None)
            fig.update_layout(title="Time Series — Independent Scale per Metric",
                              template=template, height=300 * n)

        # Scatter with OLS trendline and marginal histograms.
        elif chart_type == "scatter" and len(num_cols) >= 2:
            xc = x_col if x_col in num_cols else num_cols[0]
            yc = y_col if y_col in num_cols else num_cols[1]
            safe_cat = _safe_cat_col(df, [c for c in cat_cols if df[c].nunique() <= 10])
            fig = px.scatter(plot_df, x=xc, y=yc, color=color_col or safe_cat,
                             color_discrete_sequence=PALETTE,
                             title=f"{xc} vs {yc}", template=template,
                             trendline="ols", marginal_x="histogram", marginal_y="histogram")
            fig.update_layout(height=600)

        # Line chart: one subplot per metric, independent y axes.
        elif chart_type == "line" and num_cols:
            if x_col in df.columns:
                xc = x_col
            else:
                xc = dt_cols[0] if dt_cols else num_cols[0]
            cols_to_plot = [y_col] if y_col in num_cols else num_cols[:4]
            n = len(cols_to_plot)
            fig = make_subplots(rows=n, cols=1, subplot_titles=cols_to_plot,
                                shared_xaxes=True)
            for idx, col in enumerate(cols_to_plot):
                fig.add_trace(go.Scatter(x=plot_df[xc], y=plot_df[col],
                                         name=col, mode="lines",
                                         line=dict(color=PALETTE[idx % len(PALETTE)])),
                              row=idx + 1, col=1)
            fig.update_yaxes(matches=None)
            fig.update_layout(title="Line Chart — Independent Scale per Metric",
                              template=template, height=300 * n)

        # Unknown chart type, or required columns missing.
        elif num_cols:
            fig = _column_means_figure(df, num_cols, template, "Column Means Overview")
        else:
            fig = _message_figure("No numeric data available.", "Chart Unavailable",
                                  template, 14, "#E0E0FF")
    except Exception as e:
        # Deliberate best-effort: any failure while building the requested
        # chart degrades to the means overview, or to an error annotation
        # when there is nothing numeric to show.
        if num_cols:
            fig = _column_means_figure(df, num_cols, template, "Column Means (fallback)")
        else:
            fig = _message_figure(f"Chart error: {str(e)}", "Chart Error",
                                  template, 12, "#FF6584")

    # Shared dark styling applied to every figure.
    fig.update_layout(paper_bgcolor=DARK_BG, plot_bgcolor=CARD_BG,
                      font=dict(family="DM Sans, sans-serif", color="#E0E0FF"),
                      margin=dict(l=40, r=40, t=60, b=40))
    return fig