import gradio as gr  # type: ignore
import os  # type: ignore
import numpy as np  # type: ignore
from dotenv import load_dotenv
from transformers import AutoTokenizer  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from huggingface_hub import InferenceClient, login, HfApi, whoami  # type: ignore
from gradio.components import ChatMessage  # type: ignore
from typing import List, TypedDict
import json
from datetime import datetime
import time
import uuid
import tempfile
import shutil
import hashlib  # Added for password hashing
class Message(TypedDict):
    role: str
    content: str


if os.path.exists('.env'):
    load_dotenv()
hf_token = os.getenv("HF_TOKEN")
| """ | |
| For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference | |
| """ | |
| client = InferenceClient("https://xk54gqdcp97za8n6.us-east-1.aws.endpoints.huggingface.cloud") | |
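# Note: the URL above points at a dedicated Inference Endpoint; if you fork this
# Space you will need to replace it with your own endpoint (or a model ID).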
model = SentenceTransformer('all-MiniLM-L6-v2')  # You can choose other models depending on your needs

MAX_HISTORY_LENGTH = 5000  # Maximum number of messages to keep in history
MAX_TOKENS = 128000  # Token limit for your model (check your model's max tokens)
EMBEDDING_DIM = 384  # Dimension of embeddings, specific to the model you use (e.g., for 'all-MiniLM-L6-v2', it's 384)

login(token=hf_token)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Base-2407")  # Replace with your own model if needed
def load_persona():
    try:
        with open("profile.md", "r", encoding="utf-8") as profile_file:
            profile_content = profile_file.read()
        with open("instructions.md", "r", encoding="utf-8") as instructions_file:
            instructions_content = instructions_file.read()
        # Combine profile and instructions with a blank line in between
        content = profile_content + "\n\n" + instructions_content
        return content
    except FileNotFoundError as e:
        print(f"Warning: File not found: {e.filename}. Using default persona.")
        return """Act and roleplay as a literal horse."""
# Preloaded conversation state (initial history)
system_message: List[Message] = [Message(role="system", content=load_persona())]

# Add this after the existing imports
CHAT_HISTORY_DIR = "/data/chat_history"
os.makedirs(CHAT_HISTORY_DIR, exist_ok=True)

# Generate a unique session ID when the app starts
SESSION_ID = f"chat_session_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

# Create a temporary directory for file operations
TEMP_DIR = tempfile.mkdtemp()

# Global authentication state
is_user_authenticated = False

# Hashed password - this is a hash of "password123" using SHA-256
HASHED_PASSWORD = "f75778f7425be4db0369d09af37a6c2b9a83dea0e53e7bd57412e4b060e607f7"
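# To change the password, regenerate the hash with the same scheme used by
# verify_password below, e.g.:
#   python -c "import hashlib; print(hashlib.sha256('new-password'.encode()).hexdigest())"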
def get_session_file():
    """Get the current session's chat history file path."""
    return os.path.join(CHAT_HISTORY_DIR, f"{SESSION_ID}.json")
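# Given the SESSION_ID format above, this yields a path such as
# /data/chat_history/chat_session_YYYYMMDD_HHMMSS_xxxxxxxx.json (illustrative only).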
def save_chat_history(history: List[Message]):
    """Save chat history to the session file."""
    filename = get_session_file()
    # Convert history to a serializable format
    serializable_history = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in history
    ]
    # Create a backup of the previous file if it exists
    if os.path.exists(filename):
        backup_filename = f"{filename}.bak"
        os.replace(filename, backup_filename)
    try:
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(serializable_history, f, ensure_ascii=False, indent=2)
    except Exception as e:
        # If saving fails, restore from backup
        if os.path.exists(f"{filename}.bak"):
            os.replace(f"{filename}.bak", filename)
        raise e
def load_chat_history(session_id: str) -> List[Message]:
    """Load chat history from a file."""
    # Session files are saved as "<session_id>.json" (see save_chat_history)
    filename = os.path.join(CHAT_HISTORY_DIR, f"{session_id}.json")
    if os.path.exists(filename):
        with open(filename, "r", encoding="utf-8") as f:
            history = json.load(f)
        return [Message(**msg) for msg in history]
    return []
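# Hypothetical usage (this helper is not called anywhere in the app):
#   previous_history = load_chat_history("chat_session_YYYYMMDD_HHMMSS_xxxxxxxx")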
# Add these constants at the top with your other constants
DATA_FOLDER = "/data/chat_history"
HF_TOKEN = os.getenv("HF_TOKEN")


def is_authenticated():
    """Check if the user is authenticated."""
    # Use the global authentication state
    return is_user_authenticated
def list_chat_history_files():
    """List all chat history files in the data folder."""
    if not is_authenticated():
        return []
    try:
        files = [f for f in os.listdir(DATA_FOLDER) if f.endswith('.json')]
        # Sort files by modification time, newest first
        files.sort(key=lambda x: os.path.getmtime(os.path.join(DATA_FOLDER, x)), reverse=True)
        return files
    except Exception:
        return []
def download_chat_history(filename):
    """Download a specific chat history file."""
    if not is_authenticated():
        return None
    if not filename:  # Handle case when no file is selected
        return None
    source_path = os.path.join(DATA_FOLDER, filename)
    if os.path.exists(source_path):
        # Copy the file to the temp directory first
        temp_path = os.path.join(TEMP_DIR, filename)
        shutil.copy2(source_path, temp_path)
        return temp_path
    return None
def verify_password(password):
    """Verify the password and update the authentication state."""
    global is_user_authenticated
    # Hash the provided password using SHA-256
    hashed_input = hashlib.sha256(password.encode()).hexdigest()
    # Compare with the stored hash
    if hashed_input == HASHED_PASSWORD:
        is_user_authenticated = True
        return "**Status:** ✅ Authenticated"
    else:
        is_user_authenticated = False
        return "**Status:** ❌ Authentication failed"
def logout():
    """Log the user out by resetting the authentication state."""
    global is_user_authenticated
    is_user_authenticated = False
    return "**Status:** ❌ Not authenticated", gr.update(visible=False)
# Create a Gradio interface
with gr.Blocks() as iface:
    # Your existing chat interface components
    chatbot_output = gr.Chatbot(label="Chat History", type="messages")
    chatbot_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")

    def update_file_list():
        """Update the file list and return the updated dropdown."""
        return gr.update(choices=list_chat_history_files())
    # Authentication status indicator
    # Create a lock icon button that expands into a password field
    with gr.Group() as auth_group:
        auth_lock_btn = gr.Button("🔒", elem_id="auth_lock_btn", scale=0)
        # An Accordion is used here because it accepts a label ("Enter Password"); gr.Group does not
        with gr.Accordion("Enter Password", visible=False, elem_id="auth_accordion") as auth_accordion:
            auth_password = gr.Textbox(
                type="password",
                placeholder="Enter your password",
                label="Password",
                elem_id="auth_password",
                visible=not is_authenticated()
            )
            auth_submit = gr.Button("Login", elem_id="auth_submit", visible=not is_authenticated())
            auth_status = gr.Markdown(
                value="**Status:** " + ("✅ Authenticated" if is_authenticated() else "❌ Not authenticated"),
                elem_id="auth_status"
            )
            auth_logout = gr.Button("Logout", elem_id="auth_logout", visible=is_authenticated())
    # Toggle visibility of the password field when the lock button is clicked
    auth_lock_btn.click(
        fn=lambda: gr.update(visible=True),
        outputs=[auth_accordion]
    )
    # Add download section (only visible when authenticated)
    with gr.Group(visible=is_authenticated()) as download_section:
        gr.Markdown("### Download Chat History")
        file_list = gr.Dropdown(
            choices=list_chat_history_files(),
            label="Select a chat history file",
            interactive=True,
            allow_custom_value=False
        )
        download_button = gr.Button("Download Chat History")
        download_output = gr.File(label="Downloaded file")

        # Update the file list when the refresh button is clicked
        refresh_button = gr.Button("Refresh File List")
        refresh_button.click(
            fn=update_file_list,
            outputs=file_list
        )
    # Handle password submission
    auth_submit.click(
        fn=verify_password,
        inputs=[auth_password],
        outputs=[auth_status]
    )
    # Clear the password field after submission
    auth_submit.click(
        fn=lambda: "",
        outputs=[auth_password]
    )
    # Update download section visibility based on authentication status
    auth_submit.click(
        fn=lambda: gr.update(visible=is_authenticated()),
        outputs=[download_section]
    )
    # Handle the logout button
    auth_logout.click(
        fn=logout,
        outputs=[auth_status, download_section]
    )
    # Update logout button visibility based on authentication status
    auth_submit.click(
        fn=lambda: gr.update(visible=is_authenticated()),
        outputs=[auth_logout]
    )
    auth_submit.click(
        fn=update_file_list,
        outputs=file_list
    )
    # Update the file list when the dropdown selection changes
    file_list.select(
        fn=update_file_list,
        outputs=file_list
    )
    def generate_embeddings(messages: List[str]):
        """Generate embeddings for the list of messages."""
        embeddings = model.encode(messages, show_progress_bar=False)
        return embeddings
    def summarize_conversation(conversation: List[Message]):
        """Summarize conversation history into a single embedding."""
        # Extract the text content from the conversation
        messages = [msg['content'] for msg in conversation]
        # Generate embeddings for the entire conversation
        conversation_embeddings = generate_embeddings(messages)
        # Return the average of all embeddings (this is a simple approach for compacting)
        # compact_representation = np.mean(conversation_embeddings, axis=0)
        # return compact_representation
        return conversation_embeddings
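    # Hypothetical sketch (not used elsewhere in this app): the compacted embedding
    # could be compared against a new message with cosine similarity, e.g.:
    #   query_vec = generate_embeddings(["some new user message"])[0]
    #   convo_vec = np.mean(summarize_conversation(history), axis=0)
    #   similarity = float(np.dot(query_vec, convo_vec) /
    #                      (np.linalg.norm(query_vec) * np.linalg.norm(convo_vec)))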
    def count_tokens(messages: List[str]) -> int:
        """Count the total number of tokens in the conversation."""
        return sum(len(tokenizer.encode(message)) for message in messages)
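    # For example, count_tokens(["Hello!", "How are you?"]) returns the combined
    # number of Mistral-Nemo tokenizer tokens for the two strings (the exact count
    # depends on the tokenizer).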
    def get_chat_completion(system_message, history, retry_attempt=0, max_retries=3):
        """Get a chat completion from the model, with retry logic for 503 errors."""
        try:
            # Common parameters
            params = {
                "model": "openerotica/writing-roleplay-20k-context-nemo-12b-v1.0-gguf",
                "messages": [*system_message, *history],
                "stream_options": {"enabled": True},
                "stream": True,
                "frequency_penalty": 1.0,
                "max_tokens": 2048,
                "n": 1,
                "presence_penalty": 1.0,
                "temperature": 1.0,
                "top_p": 1.0
            }
            return client.chat_completion(**params)
        except Exception as e:
            # Only treat the error as a cold-start if the exception carries a 503 response status
            if hasattr(e, 'response') and hasattr(e.response, 'status_code') and "503" in str(e.response.status_code):
                if retry_attempt < max_retries:
                    message = f"Agent is asleep, waking up... Trying again in 3 minutes... (Attempt {retry_attempt + 1}/{max_retries})"
                    gr.Warning(message, duration=180)
                    time.sleep(180)
                    gr.Info("Retrying...")
                    return get_chat_completion(system_message, history, retry_attempt + 1, max_retries)
                else:
                    gr.Error(f"Max retries ({max_retries}) reached. Giving up.")
                    return None
            else:
                gr.Error(f"Error getting chat completion: {e}")
                if retry_attempt < max_retries:
                    gr.Warning(f"Retrying after error... (Attempt {retry_attempt + 1}/{max_retries})", duration=10)
                    time.sleep(10)  # Wait a bit before retrying after an error
                    return get_chat_completion(system_message, history, retry_attempt + 1, max_retries)
                return None
    def user(user_message, history: List[Message]):
        new_history = history + [Message(role="user", content=user_message)]
        save_chat_history(new_history)
        return "", new_history
    def bot(history: list):
        # compact_history = summarize_conversation(preloaded_history)
        # compact_history = preloaded_history[-MAX_HISTORY_LENGTH:]
        # conversation = [msg["content"] for msg in compact_history]
        session_conversation = [msg["content"] for msg in history]
        system_context = [msg["content"] for msg in system_message]
        total_tokens = count_tokens(session_conversation) + count_tokens(system_context)
        # total_tokens = count_tokens(conversation) + session_tokens
        print(f"Total tokens: {total_tokens}")
        # Check whether the token count exceeds the limit
        if total_tokens > MAX_TOKENS:
            print("Token limit exceeded. Truncating history.")
            # Drop the oldest messages until history plus system context fits the limit
            while (count_tokens([msg["content"] for msg in history]) + count_tokens(system_context)) > MAX_TOKENS:
                history.pop(0)  # Remove the oldest message
        response = get_chat_completion(system_message, history)
        if response:
            # Initialize bot_message
            bot_message = ""
            history.append(Message(role="assistant", content=""))
            for chunk in response:
                # Debugging: Log the received chunk
                if 'choices' in chunk and chunk['choices']:
                    choice = chunk['choices'][0]
                    if choice.get('delta') and choice['delta'].get('content'):
                        # Append the new content to bot_message
                        bot_message += choice['delta']['content']
                        history[-1]['content'] = bot_message
                        yield history
            save_chat_history(history)
    # Add download functionality
    download_button.click(
        fn=download_chat_history,
        inputs=file_list,
        outputs=download_output
    )

    chatbot_input.submit(user, [chatbot_input, chatbot_output], [chatbot_input, chatbot_output], queue=False).then(
        bot, chatbot_output, chatbot_output
    )
if __name__ == "__main__":
    iface.launch(
        allowed_paths=[DATA_FOLDER, TEMP_DIR]  # Add both the data folder and temp directory to allowed paths
    )