# data-mind-ultra / core_agent.py
# sanjaystarc's picture
# Update core_agent.py
# 6f184a5 verified
"""
core_agent.py
=============
DataMind Agent β€” TRUE Agentic AI + Multi-LLM Support
Providers: Google Gemini, OpenAI GPT, Anthropic Claude, xAI Grok,
Mistral AI, Meta Llama (via Together AI), Alibaba Qwen (via Together AI)
File formats: CSV, Excel (.xlsx, .xls), JSON
"""
import os
import json
import warnings
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
warnings.filterwarnings("ignore")
load_dotenv()
# ─── Palette ──────────────────────────────────────────────────────────────────
PALETTE = ["#6C63FF", "#FF6584", "#43E97B", "#F7971E", "#4FC3F7", "#CE93D8"]
DARK_BG = "#0F0F1A"
CARD_BG = "#1A1A2E"
# ─── Global state (shared across agent tools) ─────────────────────────────────
_df = None
_profile = None
def set_dataframe(df, profile):
    """Make *df* and its *profile* the dataset every agent tool operates on.

    Args:
        df: the loaded DataFrame.
        profile: the profile dict describing *df*.
    """
    global _df, _profile
    _df, _profile = df, profile
# ══════════════════════════════════════════════════════════════════════════════
# PROVIDER REGISTRY
# ══════════════════════════════════════════════════════════════════════════════
# Registry of supported chat-model providers, keyed by the provider id that
# get_llm() dispatches on. Fields per entry:
#   name     - display name (also used in validate_llm's success message)
#   models   - model ids offered for this provider
#   default  - model id get_llm() falls back to when none is passed
#   key_hint - example API-key prefix (presumably a UI placeholder)
#   color    - hex brand colour (presumably for UI styling)
#   key_url  - page where an API key can be created
#   note     - optional extra remark (present on the Together-hosted entries)
# NOTE(review): model ids are passed to each provider's API verbatim and
# provider catalogues change over time — verify them against the current lists.
PROVIDERS = {
    "gemini": {
        "name": "Google Gemini",
        "models": [
            "gemini-2.5-flash",
            "gemini-2.5-pro",
            "gemini-2.0-flash",
            "gemini-1.5-pro-002",
            "gemini-1.5-flash-002",
        ],
        "default": "gemini-2.5-flash",
        "key_hint": "AIza...",
        "color": "#4285f4",
        "key_url": "https://aistudio.google.com/app/apikey",
    },
    "openai": {
        "name": "OpenAI GPT",
        "models": [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-3.5-turbo-0125",
        ],
        "default": "gpt-4o",
        "key_hint": "sk-...",
        "color": "#10a37f",
        "key_url": "https://platform.openai.com/api-keys",
    },
    "claude": {
        "name": "Anthropic Claude",
        "models": [
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5-20251001",
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
        ],
        "default": "claude-sonnet-4-6",
        "key_hint": "sk-ant-...",
        "color": "#d97706",
        "key_url": "https://console.anthropic.com/",
    },
    # Served through ChatOpenAI with base_url https://api.x.ai/v1 (see get_llm).
    "grok": {
        "name": "xAI Grok",
        "models": [
            "grok-3",
            "grok-3-mini",
            "grok-2-1212",
        ],
        "default": "grok-3",
        "key_hint": "xai-...",
        "color": "#9b9b9b",
        "key_url": "https://console.x.ai/",
    },
    "mistral": {
        "name": "Mistral AI",
        "models": [
            "mistral-large-2411",
            "mistral-small-2409",
            "open-mixtral-8x22b",
        ],
        "default": "mistral-large-2411",
        "key_hint": "...",
        "color": "#ff6b35",
        "key_url": "https://console.mistral.ai/",
    },
    # Served through ChatOpenAI against Together AI's endpoint (see get_llm).
    "llama": {
        "name": "Meta Llama (Together AI)",
        # NOTE(review): Together AI model ids are case-sensitive and are
        # normally spelled like "meta-llama/Llama-3.3-70B-Instruct-Turbo";
        # these lowercase ids look like another catalogue's naming — confirm
        # they resolve on api.together.xyz.
        "models": [
            "meta-llama/llama-4-maverick",
            "meta-llama/llama-4-scout",
            "meta-llama/llama-3.3-70b-instruct",
            "meta-llama/llama-3.1-405b",
            "meta-llama/llama-3.1-70b",
        ],
        "default": "meta-llama/llama-4-maverick",
        "key_hint": "Together AI key...",
        "color": "#0668E1",
        "key_url": "https://api.together.ai/",
        "note": "Requires a Together AI API key",
    },
    # Served through ChatOpenAI against Together AI's endpoint (see get_llm).
    "qwen": {
        "name": "Alibaba Qwen (Together AI)",
        # NOTE(review): same concern as the llama entry — Together's Qwen ids
        # are usually capitalised (e.g. "Qwen/Qwen2.5-72B-Instruct-Turbo");
        # verify these ids.
        "models": [
            "Qwen/qwen2.5-72b-instruct",
            "Qwen/qwen2.5-coder-32b",
            "Qwen/qwen2-72b-instruct",
        ],
        "default": "Qwen/qwen2.5-72b-instruct",
        "key_hint": "Together AI key...",
        "color": "#6547d4",
        "key_url": "https://api.together.ai/",
        "note": "Requires a Together AI API key",
    },
}
# ══════════════════════════════════════════════════════════════════════════════
# LLM FACTORY
# ══════════════════════════════════════════════════════════════════════════════
def get_llm(provider: str, api_key: str, model: str = None):
    """Instantiate the LangChain chat model for *provider*.

    Args:
        provider: a key of ``PROVIDERS`` ("gemini", "openai", "claude", "grok",
            "mistral", "llama" or "qwen").
        api_key: credential passed straight to the provider client.
        model: model id; defaults to ``PROVIDERS[provider]["default"]``.

    Returns:
        A LangChain chat model configured with ``temperature=0.3``.

    Raises:
        ValueError: if *provider* is unknown or has no factory branch below.
    """
    # Validate before indexing PROVIDERS: otherwise an unknown provider surfaces
    # as a bare KeyError and the ValueError at the end is never reached.
    if provider not in PROVIDERS:
        raise ValueError(f"Unknown provider: {provider}")
    model = model or PROVIDERS[provider]["default"]

    # SDK imports are deferred so only the selected provider's package must be
    # installed.
    if provider == "gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(
            model=model, google_api_key=api_key,
            temperature=0.3, convert_system_message_to_human=True,
        )
    elif provider == "openai":
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model=model, api_key=api_key, temperature=0.3)
    elif provider == "claude":
        from langchain_anthropic import ChatAnthropic
        return ChatAnthropic(model=model, api_key=api_key, temperature=0.3)
    elif provider == "grok":
        # xAI exposes an OpenAI-compatible endpoint.
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model=model, api_key=api_key,
                          base_url="https://api.x.ai/v1", temperature=0.3)
    elif provider == "mistral":
        from langchain_mistralai import ChatMistralAI
        return ChatMistralAI(model=model, api_key=api_key, temperature=0.3)
    elif provider in ("llama", "qwen"):
        # Both are hosted on Together AI's OpenAI-compatible endpoint.
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model=model, api_key=api_key,
                          base_url="https://api.together.xyz/v1", temperature=0.3)
    # A registry entry that has no factory branch above.
    raise ValueError(f"Unknown provider: {provider}")
def validate_llm(provider: str, api_key: str, model: str = None):
    """Build the chat model for *provider* and prove it answers one tiny request.

    Returns:
        ``(llm, message)``: the ready model and a human-readable success string.

    Any exception from get_llm() or from the test request propagates unchanged.
    """
    chat_model = get_llm(provider, api_key, model)
    # A minimal round-trip confirms both the API key and the model id are accepted.
    chat_model.invoke([HumanMessage(content="Say OK")])
    display_name = PROVIDERS[provider]['name']
    return chat_model, f"Connected to {display_name}!"
# ══════════════════════════════════════════════════════════════════════════════
# FILE LOADING
# ══════════════════════════════════════════════════════════════════════════════
def load_file(file):
    """Read an uploaded CSV, Excel or JSON file into a DataFrame.

    Args:
        file: a readable file-like object with a ``name`` attribute; the
            extension of ``name`` (case-insensitive) selects the parser.

    Returns:
        ``(DataFrame, label)`` where label is ``"CSV"``, ``"Excel"`` or ``"JSON"``.

    Raises:
        ValueError: for an unsupported extension, or for JSON whose top level
            is neither an array nor an object.
    """
    name = file.name.lower()
    # Rewind buffers that were already read (e.g. a re-used upload); otherwise
    # the parsers see an empty stream.
    seekable = getattr(file, "seekable", None)
    if callable(seekable) and seekable():
        file.seek(0)

    if name.endswith(".csv"):
        return pd.read_csv(file), "CSV"
    if name.endswith((".xlsx", ".xls")):
        return pd.read_excel(file), "Excel"
    if name.endswith(".json"):
        content = json.load(file)
        if isinstance(content, list):
            # Array of records (or of scalars -> a single unnamed column).
            df = pd.DataFrame(content)
        elif isinstance(content, dict):
            # Column-oriented ({"col": [...]}) vs. a single record ({"col": v}).
            if any(isinstance(v, list) for v in content.values()):
                df = pd.DataFrame(content)
            else:
                df = pd.DataFrame([content])
        else:
            # A bare number/string/null has no tabular interpretation; the old
            # code crashed here with AttributeError on .values().
            raise ValueError(
                f"Unsupported JSON structure in {name}: expected an array or "
                f"object at the top level, got {type(content).__name__}"
            )
        return df, "JSON"
    raise ValueError(f"Unsupported file type: {name}")
# ══════════════════════════════════════════════════════════════════════════════
# DATA PROFILING
# ══════════════════════════════════════════════════════════════════════════════
def profile_dataframe(df):
    """Build the profile dict of *df* that profile_to_text and the chart helpers read.

    Args:
        df: the loaded DataFrame.

    Returns:
        dict with keys ``shape``, ``columns``, ``dtypes`` (column -> dtype string),
        ``numeric_columns``, ``categorical_columns`` (object/category dtypes),
        ``datetime_columns`` (timezone-naive and timezone-aware), ``null_counts``,
        ``null_pct`` (percent, 2 dp) and ``duplicates`` (count of fully
        duplicated rows). It also has ``numeric_stats`` (``describe()`` per
        numeric column, 3 dp) and ``top_categories`` (5 most frequent values
        per categorical column) when such columns exist.
    """
    numeric_cols = df.select_dtypes(include="number").columns.tolist()
    category_cols = df.select_dtypes(include=["object", "category"]).columns.tolist()
    # "datetime" alone matches only timezone-naive columns; "datetimetz" adds the
    # timezone-aware ones, which previously fell into no column group at all.
    datetime_cols = df.select_dtypes(include=["datetime", "datetimetz"]).columns.tolist()
    profile = {
        "shape": df.shape,
        "columns": df.columns.tolist(),
        "dtypes": df.dtypes.astype(str).to_dict(),
        "numeric_columns": numeric_cols,
        "categorical_columns": category_cols,
        "datetime_columns": datetime_cols,
        "null_counts": df.isnull().sum().to_dict(),
        "null_pct": (df.isnull().mean() * 100).round(2).to_dict(),
        "duplicates": int(df.duplicated().sum()),
    }
    if numeric_cols:
        profile["numeric_stats"] = df[numeric_cols].describe().round(3).to_dict()
    if category_cols:
        profile["top_categories"] = {
            col: df[col].value_counts().head(5).to_dict() for col in category_cols
        }
    return profile
def profile_to_text(profile, df):
    """Render a profile dict plus a 5-row sample of *df* as plain text for the LLM.

    Args:
        profile: dict with the keys produced by profile_dataframe (``shape``,
            column-group lists, ``null_counts``, ``duplicates`` and optionally
            ``numeric_stats``).
        df: the DataFrame the profile describes; its first 5 rows are printed.

    Returns:
        Newline-joined summary text. Missing per-column statistics show as ``?``.
    """
    n_rows, n_cols = profile["shape"]

    def listing(key):
        # Comma-separated column names, or 'None' when the group is empty.
        return ', '.join(profile[key]) or 'None'

    text = [
        f"Dataset: {n_rows} rows x {n_cols} columns",
        f"Numeric columns : {listing('numeric_columns')}",
        f"Categorical cols : {listing('categorical_columns')}",
        f"Datetime cols : {listing('datetime_columns')}",
        f"Missing values : {sum(profile['null_counts'].values())} total",
        f"Duplicate rows : {profile['duplicates']}",
        "",
        "--- Sample Data (first 5 rows) ---",
        df.head(5).to_string(index=False),
    ]

    stats_by_col = profile.get("numeric_stats")
    if stats_by_col:
        text.append("")
        text.append("--- Numeric Stats ---")
        text.extend(
            f" {col}: mean={s.get('mean','?')}, std={s.get('std','?')}, "
            f"min={s.get('min','?')}, max={s.get('max','?')}"
            for col, s in stats_by_col.items()
        )
    return "\n".join(text)
# ══════════════════════════════════════════════════════════════════════════════
# AGENT TOOLS
# ══════════════════════════════════════════════════════════════════════════════
@tool
def profile_data(query: str) -> str:
    """Get full statistical profile of the dataset. Use this FIRST before any analysis."""
    # `query` is unused; it only gives the tool the single string argument the
    # agent loop passes. The docstring above is the tool description sent to
    # the LLM, so it must not be edited casually.
    if _df is not None:
        return profile_to_text(_profile, _df)
    return "No dataset loaded. Please upload a file first."
@tool
def analyze_column(column_name: str) -> str:
    """Deeply analyze a specific column. Provide the exact column name."""
    # Guard clauses: no data loaded, or a column name the model got wrong
    # (the reply lists valid names so the agent can retry).
    if _df is None:
        return "No dataset loaded."
    if column_name not in _df.columns:
        return f"Column '{column_name}' not found. Available: {_df.columns.tolist()}"

    series = _df[column_name]
    is_missing = series.isnull()
    summary = [
        f"Analysis of '{column_name}'",
        f"Type: {series.dtype}",
        f"Non-null: {series.count()} / {len(series)}",
        f"Nulls: {is_missing.sum()} ({is_missing.mean()*100:.1f}%)",
    ]

    if not pd.api.types.is_numeric_dtype(series):
        # Non-numeric: cardinality and most frequent values.
        modes = series.mode()
        summary.extend([
            f"Unique values: {series.nunique()}",
            f"Top 5: {series.value_counts().head(5).to_dict()}",
            f"Most common: {modes[0] if not modes.empty else 'N/A'}",
        ])
        return "\n".join(summary)

    # Numeric: central tendency, spread and Tukey-fence (1.5 x IQR) outliers.
    q1 = series.quantile(0.25)
    q3 = series.quantile(0.75)
    fence = 1.5 * (q3 - q1)
    n_outliers = int(((series < q1 - fence) | (series > q3 + fence)).sum())
    summary += [
        f"Mean: {series.mean():.3f}",
        f"Median: {series.median():.3f}",
        f"Std: {series.std():.3f}",
        f"Min: {series.min()}",
        f"Max: {series.max()}",
        f"Skewness: {series.skew():.3f}",
        f"Outliers (IQR): {n_outliers}",
    ]
    return "\n".join(summary)
@tool
def find_correlations(query: str) -> str:
    """Find correlations between numeric columns. Highlights strong relationships."""
    # `query` is unused; every pair of profiled numeric columns is examined.
    if _df is None:
        return "No dataset loaded."
    columns = _profile["numeric_columns"]
    if len(columns) < 2:
        return "Need at least 2 numeric columns for correlation analysis."

    matrix = _df[columns].corr().round(3)
    notable = []
    # Walk the upper triangle only, so each unordered pair is reported once.
    for i, left in enumerate(columns):
        for j in range(i + 1, len(columns)):
            r = matrix.iloc[i, j]
            # A NaN coefficient (e.g. from a constant column) fails this test.
            if abs(r) >= 0.5:
                strength = "strong" if abs(r) >= 0.8 else "moderate"
                direction = "positive" if r > 0 else "negative"
                notable.append(f" {left} <-> {columns[j]}: {r} ({strength} {direction})")

    lines = ["Correlation Matrix:", matrix.to_string()]
    if notable:
        lines.extend(["", "Notable correlations:", *notable])
    else:
        lines.append("No strong correlations found (|r| >= 0.5)")
    return "\n".join(lines)
@tool
def detect_anomalies(query: str) -> str:
    """Detect outliers and anomalies across all numeric columns using IQR method."""
    # `query` is unused; all profiled numeric columns are scanned.
    if _df is None:
        return "No dataset loaded."
    columns = _profile["numeric_columns"]
    if not columns:
        return "No numeric columns found."

    lines = ["Anomaly Detection Report (IQR Method):"]
    grand_total = 0
    for name in columns:
        observed = _df[name].dropna()
        q1, q3 = observed.quantile(0.25), observed.quantile(0.75)
        iqr = q3 - q1
        # Tukey fences: anything 1.5 x IQR beyond the quartiles is flagged.
        beyond = (_df[name] < q1 - 1.5 * iqr) | (_df[name] > q3 + 1.5 * iqr)
        flagged = _df[beyond][name]
        if len(flagged):
            grand_total += len(flagged)
            lines.append(f" {name}: {len(flagged)} outliers | Examples: {flagged.head(3).tolist()}")

    lines.append(f"\nTotal outliers found: {grand_total}")
    if not grand_total:
        lines.append("No significant outliers detected.")
    return "\n".join(lines)
@tool
def run_aggregation(query: str) -> str:
    """
    Compute group-by aggregations.
    Format: 'group_col|agg_col|function' (e.g. 'category|sales|sum')
    Supported functions: sum, mean, count, max, min, median
    """
    if _df is None:
        return "No dataset loaded."
    try:
        fields = [field.strip() for field in query.split("|")]
        if len(fields) == 3:
            group_col, agg_col, func = fields
        elif len(fields) == 2:
            group_col, agg_col = fields
            func = "mean"
        else:
            # Unparseable spec: sum the first numeric column by the first
            # categorical column.
            cat_cols = _profile["categorical_columns"]
            num_cols = _profile["numeric_columns"]
            if not cat_cols or not num_cols:
                return "Could not determine columns."
            group_col, agg_col, func = cat_cols[0], num_cols[0], "sum"

        for name in (group_col, agg_col):
            if name not in _df.columns:
                return f"Column '{name}' not found. Available: {_df.columns.tolist()}"

        # The function name is handed to pandas .agg() unvalidated; a name pandas
        # rejects ends up in the except branch below as an error string.
        how = func.lower()
        table = (
            _df.groupby(group_col)[agg_col]
            .agg(how)
            .reset_index()
            .sort_values(agg_col, ascending=False)
        )
        table.columns = [group_col, f"{how}_{agg_col}"]
        return (f"Aggregation: {how.upper()} of '{agg_col}' by '{group_col}'\n"
                f"{table.to_string(index=False)}")
    except Exception as e:
        # Errors go back to the agent as text so it can retry with other input.
        return f"Aggregation error: {str(e)}"
@tool
def generate_insight_report(query: str) -> str:
    """Generate a complete automated insight report with data quality score, patterns, and recommendations."""
    # `query` is unused: the report always covers the whole loaded dataset.
    if _df is None:
        return "No dataset loaded."
    rows, cols = _profile["shape"]
    num_cols = _profile["numeric_columns"]
    cat_cols = _profile["categorical_columns"]
    nulls = sum(_profile["null_counts"].values())
    # Share of all cells that are missing; guarded against an empty frame.
    null_pct = (nulls / (rows * cols) * 100) if rows * cols > 0 else 0

    # Heuristic 0-100 score: tiered penalty for missingness, flat penalty for
    # any duplicate rows.
    quality = 100
    if null_pct > 20:
        quality -= 30
    elif null_pct > 10:
        quality -= 15
    elif null_pct > 5:
        quality -= 5
    if _profile["duplicates"] > 0:
        quality -= 10

    report = [
        "=" * 50, "AUTOMATED INSIGHT REPORT", "=" * 50, "",
        "1. DATASET OVERVIEW",
        f" Rows: {rows:,} | Columns: {cols}",
        f" Numeric: {len(num_cols)} | Categorical: {len(cat_cols)}",
        f" Data Quality Score: {quality}/100", "",
        "2. DATA QUALITY",
        f" Missing values: {nulls} ({null_pct:.1f}%)",
        f" Duplicate rows: {_profile['duplicates']}",
    ]
    if nulls > 0:
        worst = max(_profile["null_pct"].items(), key=lambda x: x[1])
        report.append(f" Worst column: '{worst[0]}' ({worst[1]}% missing)")

    report += ["", "3. KEY STATISTICS"]
    numeric_stats = _profile.get("numeric_stats", {})
    for col in num_cols[:5]:
        stats = numeric_stats.get(col, {})
        report.append(f" {col}: mean={stats.get('mean','?')}, range=[{stats.get('min','?')}, {stats.get('max','?')}]")

    if cat_cols:
        report += ["", "4. CATEGORICAL SUMMARY"]
        for col in cat_cols[:3]:
            # value_counts() drops NaN, so an all-missing column yields no counts
            # even though the column itself is not empty. Guard on the counts to
            # avoid the IndexError the old `_df[col].empty` check allowed.
            counts = _df[col].value_counts()
            top = counts.index[0] if not counts.empty else "N/A"
            report.append(f" {col}: {_df[col].nunique()} unique | most common = '{top}'")

    report += [
        "", "5. RECOMMENDATIONS",
        f" - {'Fix missing values' if null_pct > 5 else 'Data completeness looks good'}",
        f" - {'Remove duplicate rows' if _profile['duplicates'] > 0 else 'No duplicates found'}",
        f" - {'Run correlation analysis' if len(num_cols) >= 2 else 'Need more numeric columns'}",
        f" - {'Encode categorical columns for ML' if cat_cols else 'Add categorical features'}",
        "", "=" * 50,
    ]
    return "\n".join(report)
@tool
def recommend_chart(question: str) -> str:
    """Recommend best chart type for a question. Returns JSON with chart_type, x_col, y_col."""
    def suggestion(chart_type, x_col=None, y_col=None):
        # Key order is fixed so the JSON text is stable.
        return json.dumps({"chart_type": chart_type, "x_col": x_col, "y_col": y_col})

    if _profile is None:
        return suggestion("bar_chart")

    nums = _profile["numeric_columns"]
    cats = _profile["categorical_columns"]
    dates = _profile["datetime_columns"]
    text = question.lower()

    def mentions(*keywords):
        # Plain substring match against the lower-cased question.
        return any(k in text for k in keywords)

    # Keyword rules in priority order; a rule whose column requirement is not
    # met falls through to the next one.
    if mentions("trend", "over time", "time", "date") and dates and nums:
        return suggestion("time_series", dates[0], nums[0])
    if mentions("correlat", "relationship", "vs", "versus") and len(nums) >= 2:
        return suggestion("correlation_heatmap")
    if mentions("distribut", "spread", "histogram") and nums:
        return suggestion("distribution_plots", None, nums[0])
    if mentions("outlier", "box", "range") and nums:
        return suggestion("box_plots")
    if mentions("proportion", "share", "percent", "pie") and cats:
        return suggestion("pie_chart", cats[0])
    # Generic choices based only on which column groups exist.
    if cats and nums:
        return suggestion("bar_chart", cats[0], nums[0])
    if len(nums) >= 2:
        return suggestion("scatter", nums[0], nums[1])
    return suggestion("bar_chart")
# ══════════════════════════════════════════════════════════════════════════════
# AGENT BUILDER
# ══════════════════════════════════════════════════════════════════════════════
# Tools exposed to the model, in the order build_agent() binds them.
TOOLS = [
    profile_data,
    analyze_column,
    find_correlations,
    detect_anomalies,
    run_aggregation,
    generate_insight_report,
    recommend_chart,
]
# Tool name (as emitted in the model's tool calls) -> tool object; used by run_agent.
TOOLS_MAP = {registered.name: registered for registered in TOOLS}
SYSTEM_PROMPT = """You are DataMind, an expert autonomous data analyst AI agent.
When a user asks a question:
1. THINK about what tools you need
2. PLAN your steps (use multiple tools in sequence when needed)
3. EXECUTE each tool
4. SYNTHESIZE the results into a clear, insightful answer
5. SELF-CORRECT if a tool returns an error β€” try a different approach
Your tools:
- profile_data: Get dataset overview (use this first if unsure about the data)
- analyze_column: Deep dive into a specific column
- find_correlations: Find relationships between numeric columns
- detect_anomalies: Find outliers and data quality issues
- run_aggregation: Group-by calculations (sum, mean, count, etc.)
- generate_insight_report: Full automated analysis report
- recommend_chart: Suggest best visualization for a question
Always be precise, proactive, and thorough. Use multiple tools when needed.
Remember conversation history and refer to previous questions when relevant."""
def build_agent(llm):
    """Return *llm* with all agent tools bound for native tool calling.

    The returned runnable is what run_agent() expects as ``agent_executor``.
    """
    tool_enabled_llm = llm.bind_tools(TOOLS)
    return tool_enabled_llm
class _AgentAction:
    """Step record exposing ``.tool`` and ``.tool_input`` for UI display."""

    def __init__(self, name, inp):
        self.tool = name
        self.tool_input = inp


def _message_text(content):
    """Flatten an AI message ``content`` payload to plain text.

    LangChain message content may be a string or a list of content blocks
    (strings or ``{"type": "text", "text": ...}`` dicts); text parts are joined
    so the agent's ``output`` is always a string. Other values pass through.
    """
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("text", ""))
        return "".join(parts)
    return content


def _tool_argument(tool_input):
    """Reduce a tool call's args to the single string argument every tool takes.

    A string is used as-is; otherwise the first value of the args dict is used
    regardless of its key, and an empty/missing args payload becomes "".
    """
    if isinstance(tool_input, str):
        return tool_input
    return list(tool_input.values())[0] if tool_input else ""


def run_agent(question: str, agent_executor, chat_history: list,
              max_iterations: int = 6) -> dict:
    """Answer *question* with a manual tool-calling loop (no AgentExecutor).

    Args:
        question: the user's new message.
        agent_executor: a tool-bound chat model (see build_agent); each turn
            calls ``agent_executor.invoke(messages)``, whose result must expose
            ``.content`` and ``.tool_calls``.
        chat_history: prior LangChain messages, inserted between the system
            prompt and *question*; the list itself is not modified.
        max_iterations: maximum number of model turns before giving up.

    Returns:
        dict with ``output`` (final answer text), ``steps`` (list of
        ``(action, result)`` pairs where ``action`` has ``.tool`` and
        ``.tool_input``) and ``error`` (message if a model call failed, else None).
    """
    messages = [SystemMessage(content=SYSTEM_PROMPT), *chat_history,
                HumanMessage(content=question)]
    steps = []

    for _ in range(max_iterations):
        try:
            response = agent_executor.invoke(messages)
        except Exception as e:
            # Model/API failures end the run but keep the steps gathered so far.
            return {"output": f"Agent error: {str(e)}", "steps": steps, "error": str(e)}
        messages.append(response)

        # No tool calls means the model produced its final answer.
        if not response.tool_calls:
            return {
                "output": _message_text(response.content) or "Analysis complete.",
                "steps": steps,
                "error": None,
            }

        for tool_call in response.tool_calls:
            tool_name = tool_call["name"]
            tool_input = tool_call["args"]
            tool_id = tool_call["id"]
            tool_fn = TOOLS_MAP.get(tool_name)
            if tool_fn is None:
                result = f"Unknown tool: {tool_name}"
            else:
                try:
                    result = tool_fn.invoke(_tool_argument(tool_input))
                except Exception as e:
                    # Report the failure back to the model so it can self-correct.
                    result = f"Tool error: {str(e)}"
            steps.append((_AgentAction(tool_name, tool_input), result))
            # Each tool call is answered with a ToolMessage carrying its id.
            messages.append(ToolMessage(content=str(result), tool_call_id=tool_id))

    return {
        "output": "Analysis complete — reached maximum reasoning steps.",
        "steps": steps,
        "error": None,
    }
# ══════════════════════════════════════════════════════════════════════════════
# CHART ENGINE (with robust fallbacks for any dataset)
# ══════════════════════════════════════════════════════════════════════════════
def auto_suggest_charts(profile):
    """List the chart types that suit the column groups in *profile*.

    Args:
        profile: dict with ``numeric_columns``, ``categorical_columns`` and
            ``datetime_columns`` lists.

    Returns:
        Chart-type names in a fixed order; empty when no numeric column exists.
    """
    nums = profile["numeric_columns"]
    cats = profile["categorical_columns"]
    dates = profile["datetime_columns"]
    rules = [
        (len(nums) >= 2, ("correlation_heatmap", "scatter_matrix")),
        (bool(nums), ("distribution_plots", "box_plots")),
        (bool(cats and nums), ("bar_chart", "pie_chart")),
        (bool(dates and nums), ("time_series",)),
    ]
    return [chart for applies, charts in rules if applies for chart in charts]
def _safe_cat_col(df, cat_cols):
    """Return the categorical column with the fewest distinct values, or None.

    Ties go to the earliest column in *cat_cols*; low cardinality keeps chart
    legends and axes readable.
    """
    if not cat_cols:
        return None
    return min(cat_cols, key=lambda name: df[name].nunique())
def _safe_num_col(df, num_cols):
    """Return the first column in *num_cols* holding at least one non-null value, or None."""
    return next(
        (name for name in num_cols if df[name].dropna().shape[0] > 0),
        None,
    )
def _apply_value_range(fig, values, headroom=1.2):
    """Fix *fig*'s y-axis to span 0 and every value, padded by *headroom* for labels.

    Replaces the old fixed ``[0, max * headroom]`` range, which clipped negative
    bars and inverted the axis when every value was negative. When there is no
    non-NaN value, or every value is 0, Plotly's autorange is left in place.
    """
    present = pd.Series(values, dtype="float64").dropna()
    if present.empty:
        return
    low = min(0.0, float(present.min()))
    high = max(0.0, float(present.max()))
    if low == high:
        return
    fig.update_yaxes(range=[low * headroom, high * headroom])


def _column_means_figure(df, num_cols, template, title):
    """Labeled bar chart of the means (2 dp) of up to 8 numeric columns."""
    means = df[num_cols[:8]].mean().dropna().round(2)
    fig = px.bar(x=means.index, y=means.values, color=means.index,
                 color_discrete_sequence=PALETTE,
                 title=title, template=template,
                 text=means.values,
                 labels={"x": "Column", "y": "Mean Value"})
    fig.update_traces(textposition="outside")
    _apply_value_range(fig, means.values)
    fig.update_layout(showlegend=False, height=450)
    return fig


def _message_figure(text, title, size, color, template):
    """Empty figure showing a single centred annotation."""
    fig = go.Figure()
    fig.add_annotation(text=text, showarrow=False, font=dict(size=size, color=color))
    fig.update_layout(template=template, title=title)
    return fig


def make_plotly_chart(chart_type, df, profile, x_col=None, y_col=None, color_col=None):
    """Build a dark-themed Plotly figure of *chart_type* for *df*.

    Args:
        chart_type: "correlation_heatmap", "distribution_plots", "box_plots",
            "bar_chart", "pie_chart", "scatter_matrix", "time_series",
            "scatter" or "line". Any other value, or a type whose column
            requirements are not met, draws the column-means overview.
        df: the data to plot.
        profile: dict with ``numeric_columns``, ``categorical_columns`` and
            ``datetime_columns`` lists.
        x_col, y_col: optional column hints, ignored when unsuitable.
        color_col: optional colour column; used by "scatter" only.

    Returns:
        A ``plotly.graph_objects.Figure``. An exception raised while building
        the requested chart is caught and replaced by the column-means chart,
        or by an error annotation when no numeric column has data.
    """
    from plotly.subplots import make_subplots

    # Only numeric columns with at least one value can be plotted.
    num_cols = [c for c in profile["numeric_columns"] if df[c].dropna().shape[0] > 0]
    cat_cols = profile["categorical_columns"]
    template = "plotly_dark"
    # Point-level charts draw at most 5000 sampled rows (fixed seed for stable
    # redraws); bar, pie and time-series charts use the full frame.
    plot_df = df.sample(min(5000, len(df)), random_state=42) if len(df) > 5000 else df

    try:
        if chart_type == "correlation_heatmap" and len(num_cols) >= 2:
            # Colour scale pinned to [-1, 1] so colours mean the same in every dataset.
            corr = plot_df[num_cols[:10]].corr().round(2)
            fig = px.imshow(corr, text_auto=True, color_continuous_scale="RdBu_r",
                            title="Correlation Heatmap", template=template,
                            color_continuous_midpoint=0, zmin=-1, zmax=1)
            fig.update_layout(height=500)

        elif chart_type == "distribution_plots" and num_cols:
            # One histogram per column (max 6), each on its own axes, with the
            # mean marked by a dashed line.
            cols_to_plot = num_cols[:6]
            n = len(cols_to_plot)
            ncols = min(3, n)
            nrows = (n + ncols - 1) // ncols
            fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=cols_to_plot)
            for idx, col in enumerate(cols_to_plot):
                r, c = idx // ncols + 1, idx % ncols + 1
                data = plot_df[col].dropna()
                fig.add_trace(go.Histogram(x=data, nbinsx=30, name=col,
                                           marker_color=PALETTE[idx % len(PALETTE)],
                                           showlegend=False), row=r, col=c)
                fig.add_vline(x=float(data.mean()), line_dash="dash",
                              line_color="white", opacity=0.5, row=r, col=c)
            fig.update_xaxes(matches=None)
            fig.update_yaxes(matches=None)
            fig.update_layout(title="Distributions — Independent Scale per Column",
                              template=template, height=350 * nrows)

        elif chart_type == "box_plots" and num_cols:
            # One box per column (max 6), each with its own y scale.
            cols_to_plot = num_cols[:6]
            n = len(cols_to_plot)
            ncols = min(3, n)
            nrows = (n + ncols - 1) // ncols
            fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=cols_to_plot)
            for idx, col in enumerate(cols_to_plot):
                r, c = idx // ncols + 1, idx % ncols + 1
                fig.add_trace(go.Box(y=plot_df[col].dropna(), name=col,
                                     marker_color=PALETTE[idx % len(PALETTE)],
                                     boxmean=True, showlegend=False), row=r, col=c)
            fig.update_yaxes(matches=None)
            fig.update_layout(title="Box Plots — Independent Scale per Column",
                              template=template, height=350 * nrows)

        elif chart_type == "bar_chart":
            # Mean of yc per xc group, top 15 groups, values labeled on the bars.
            xc = x_col if x_col in df.columns else _safe_cat_col(df, cat_cols)
            yc = y_col if y_col in num_cols else _safe_num_col(df, num_cols)
            if not (xc and yc):
                raise ValueError("No suitable columns for bar chart")
            agg = (df.groupby(xc)[yc].mean().reset_index()
                   .sort_values(yc, ascending=False).head(15))
            agg[yc] = agg[yc].round(2)
            fig = px.bar(agg, x=xc, y=yc, color=yc,
                         color_continuous_scale="Viridis",
                         title=f"Average {yc} by {xc}",
                         template=template, text=yc)
            fig.update_traces(textposition="outside")
            # Keep negative averages visible (the old [0, max*1.2] range hid them).
            _apply_value_range(fig, agg[yc])
            fig.update_layout(height=500)

        elif chart_type == "pie_chart" and cat_cols:
            # Top 8 categories, labeled with name, percent and count.
            col = x_col if x_col in cat_cols else _safe_cat_col(df, cat_cols)
            counts = df[col].value_counts().head(8)
            fig = px.pie(values=counts.values, names=counts.index,
                         title=f"Distribution of {col}",
                         color_discrete_sequence=PALETTE, template=template, hole=0.35)
            fig.update_traces(textinfo="label+percent+value")
            fig.update_layout(height=500)

        elif chart_type == "scatter_matrix" and len(num_cols) >= 2:
            # Colour by a categorical column only if it has <= 10 levels.
            safe_cat = _safe_cat_col(df, [c for c in cat_cols if df[c].nunique() <= 10])
            fig = px.scatter_matrix(plot_df, dimensions=num_cols[:4],
                                    color=safe_cat, color_discrete_sequence=PALETTE,
                                    title="Scatter Matrix — Each Axis Independent",
                                    template=template)
            fig.update_traces(diagonal_visible=False, showupperhalf=False)
            fig.update_layout(height=600)

        elif chart_type == "time_series" and profile["datetime_columns"] and num_cols:
            # One stacked line subplot per metric against the first datetime column.
            dt_col = profile["datetime_columns"][0]
            cols_to_plot = [y_col] if y_col in num_cols else num_cols[:4]
            n = len(cols_to_plot)
            fig = make_subplots(rows=n, cols=1, subplot_titles=cols_to_plot,
                                shared_xaxes=True)
            sorted_df = df.sort_values(dt_col)
            for idx, col in enumerate(cols_to_plot):
                fig.add_trace(go.Scatter(x=sorted_df[dt_col], y=sorted_df[col],
                                         name=col, mode="lines",
                                         line=dict(color=PALETTE[idx % len(PALETTE)])),
                              row=idx + 1, col=1)
            fig.update_yaxes(matches=None)
            fig.update_layout(title="Time Series — Independent Scale per Metric",
                              template=template, height=300 * n)

        elif chart_type == "scatter" and len(num_cols) >= 2:
            xc = x_col if x_col in num_cols else num_cols[0]
            yc = y_col if y_col in num_cols else num_cols[1]
            safe_cat = _safe_cat_col(df, [c for c in cat_cols if df[c].nunique() <= 10])
            scatter_kwargs = dict(x=xc, y=yc, color=color_col or safe_cat,
                                  color_discrete_sequence=PALETTE,
                                  title=f"{xc} vs {yc}", template=template,
                                  marginal_x="histogram", marginal_y="histogram")
            try:
                fig = px.scatter(plot_df, trendline="ols", **scatter_kwargs)
            except ImportError:
                # trendline="ols" needs statsmodels; without it keep the scatter
                # (minus the trendline) instead of dropping to the fallback chart.
                fig = px.scatter(plot_df, **scatter_kwargs)
            fig.update_layout(height=600)

        elif chart_type == "line" and num_cols:
            # x defaults to the first datetime column, else the first numeric one.
            xc = x_col if x_col in df.columns else (profile["datetime_columns"][0] if profile["datetime_columns"] else num_cols[0])
            cols_to_plot = [y_col] if y_col in num_cols else num_cols[:4]
            n = len(cols_to_plot)
            fig = make_subplots(rows=n, cols=1, subplot_titles=cols_to_plot,
                                shared_xaxes=True)
            for idx, col in enumerate(cols_to_plot):
                fig.add_trace(go.Scatter(x=plot_df[xc], y=plot_df[col],
                                         name=col, mode="lines",
                                         line=dict(color=PALETTE[idx % len(PALETTE)])),
                              row=idx + 1, col=1)
            fig.update_yaxes(matches=None)
            fig.update_layout(title="Line Chart — Independent Scale per Metric",
                              template=template, height=300 * n)

        elif num_cols:
            # Unknown type or unmet requirements: overview of column means.
            fig = _column_means_figure(df, num_cols, template, "Column Means Overview")
        else:
            fig = _message_figure("No numeric data available.", "Chart Unavailable",
                                  14, "#E0E0FF", template)
    except Exception as e:
        # Any failure above degrades to the means chart (or an error note).
        if num_cols:
            fig = _column_means_figure(df, num_cols, template, "Column Means (fallback)")
        else:
            fig = _message_figure(f"Chart error: {str(e)}", "Chart Error",
                                  12, "#FF6584", template)

    fig.update_layout(paper_bgcolor=DARK_BG, plot_bgcolor=CARD_BG,
                      font=dict(family="DM Sans, sans-serif", color="#E0E0FF"),
                      margin=dict(l=40, r=40, t=60, b=40))
    return fig