#!/usr/bin/env python3
"""
LLM Tester - A console application for testing OpenAI-compatible LLM endpoints.

This application allows users to interact with LLM servers using the OpenAI-compatible
chat completions API. It supports model selection and maintains chat history.
"""

import argparse
import sys
import time
import json
import requests


# Global debug flag. main() sets it from the --debug command-line option;
# while it is False the debug_print_* helpers return without printing.
DEBUG_MODE = False


def mask_api_key(api_key):
    """Return a display-safe version of an API key.

    Args:
        api_key (str): The API key to mask. May be empty or None.

    Returns:
        str: ``"***"`` when the key is missing or shorter than 10 characters;
            otherwise the first 5 and last 4 characters joined by ``***...``.
    """
    # Short (or missing) keys are hidden entirely: head + tail would expose
    # most or all of them.
    fully_hidden = not api_key or len(api_key) < 10
    if fully_hidden:
        return "***"

    head, tail = api_key[:5], api_key[-4:]
    return f"{head}***...{tail}"


def debug_print_request(method, url, headers, body=None):
    """Dump an outgoing HTTP request inside a box-drawn frame.

    Prints nothing unless the module-level DEBUG_MODE flag is set. The value
    of an ``Authorization`` header is masked with mask_api_key() before it is
    shown; all other headers are printed verbatim.

    Args:
        method (str): HTTP method (GET, POST, etc.)
        url (str): Request URL
        headers (dict): Request headers
        body (dict, optional): Request body; rendered as indented JSON.
            A falsy body is reported as ``Body: None``.
    """
    if not DEBUG_MODE:
        return

    frame_width = 80
    bar = "═" * frame_width
    side = "║"
    divider = f"╠{bar}╣"

    print(f"\n╔{bar}╗")
    print(f"{side} REQUEST - {method} {url}")
    print(divider)
    print(f"{side} Headers:")

    for name, raw in headers.items():
        shown = raw
        if name.lower() == "authorization":
            # Keep the "Bearer " scheme readable and mask only the token.
            if raw.startswith("Bearer "):
                shown = f"Bearer {mask_api_key(raw[7:])}"
            else:
                shown = mask_api_key(raw)
        print(f"{side}   {name}: {shown}")

    print(divider)

    if not body:
        print(f"{side} Body: None")
    else:
        print(f"{side} Body:")
        longest = frame_width - 2
        for row in json.dumps(body, indent=2).splitlines():
            # Clip rows that would overflow the frame.
            if len(row) > longest:
                row = row[:frame_width - 5] + "..."
            print(f"{side}   {row}")

    print(f"╚{bar}╝\n")


def debug_print_response(status_code, reason, headers, body, elapsed_time):
    """Dump an HTTP response inside a box-drawn frame.

    Prints nothing unless the module-level DEBUG_MODE flag is set. The title
    line carries the status code, reason phrase, elapsed time and a coarse
    status label (OK / REDIRECT / CLIENT ERROR / SERVER ERROR).

    Args:
        status_code (int): HTTP status code
        reason (str): HTTP status reason phrase
        headers (dict): Response headers; values longer than 60 characters
            are clipped.
        body (dict): Response body; rendered as indented JSON. A falsy body
            is reported as empty.
        elapsed_time (float): Time taken for the request in seconds
    """
    if not DEBUG_MODE:
        return

    frame_width = 80
    bar = "═" * frame_width
    side = "║"
    divider = f"╠{bar}╣"

    # The first upper bound the status code falls under decides the label;
    # anything at or above 500 keeps the default.
    label = "SERVER ERROR"
    for upper_bound, candidate in ((300, "OK"), (400, "REDIRECT"), (500, "CLIENT ERROR")):
        if status_code < upper_bound:
            label = candidate
            break

    print(f"\n╔{bar}╗")
    print(f"{side} RESPONSE - HTTP {status_code} {reason} ({elapsed_time:.2f}s) [{label}]")
    print(divider)

    print(f"{side} Headers:")
    if not headers:
        print(f"{side}   (No headers)")
    else:
        for name, raw in headers.items():
            shown = str(raw)
            # Clip very long header values for readability.
            if len(shown) > 60:
                shown = shown[:57] + "..."
            print(f"{side}   {name}: {shown}")

    print(divider)

    print(f"{side} Body:")
    if not body:
        print(f"{side}   (No body or empty response)")
    else:
        longest = frame_width - 2
        for row in json.dumps(body, indent=2).splitlines():
            # Clip rows that would overflow the frame.
            if len(row) > longest:
                row = row[:frame_width - 5] + "..."
            print(f"{side}   {row}")

    print(f"╚{bar}╝\n")


def parse_arguments():
    """Build the command-line parser and parse the process arguments.

    Returns:
        argparse.Namespace: Parsed options with the attributes llm_url,
            llm_api_key, llm_model and debug.
    """
    # (flag, add_argument keyword options) in the order they appear in --help.
    option_specs = (
        ("--llm-url", {
            "type": str,
            "default": "http://localhost/v1",
            "help": "LLM server URL (default: http://localhost/v1)",
        }),
        ("--llm-api-key", {
            "type": str,
            "default": "sk-123",
            "help": "LLM server API key (default: sk-123)",
        }),
        ("--llm-model", {
            "type": str,
            "default": None,
            "help": "LLM model name (optional - if missing, lists available models)",
        }),
        ("--debug", {
            "action": "store_true",
            "default": False,
            "help": "Enable debug mode to print HTTP requests and responses",
        }),
    )

    cli = argparse.ArgumentParser(
        description="LLM Tester - Interactive console for testing OpenAI-compatible LLM endpoints"
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def list_models(base_url, api_key):
    """Fetch the models advertised by the LLM server.

    Sends ``GET {base_url}/models`` and accepts three response shapes:
    ``{"data": [...]}``, ``{"models": [...]}`` or a bare JSON list.
    Failures are reported on stdout and never raised to the caller.

    Args:
        base_url (str): Base URL of the LLM server. A trailing slash is
            tolerated.
        api_key (str): API key for authentication (sent as a Bearer token).

    Returns:
        list: Model entries as returned by the server, or an empty list on
            failure or when the payload holds no model list.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    # Strip a trailing slash so "http://host/v1/" does not become ".../v1//models".
    url = f"{base_url.rstrip('/')}/models"

    try:
        # Prepare the request first so the debug dump shows the headers
        # exactly as they will be sent.
        prepared_req = requests.Request("GET", url, headers=headers).prepare()
        debug_print_request("GET", url, dict(prepared_req.headers))

        # Time only the network round trip, not the debug printing above.
        start_time = time.time()
        with requests.Session() as session:
            response = session.send(prepared_req, timeout=10)
        elapsed_time = time.time() - start_time

        if DEBUG_MODE:
            # Error pages are often not JSON; fall back to the raw text so a
            # decode failure here cannot hide the real HTTP status below.
            try:
                debug_body = response.json() if response.content else None
            except ValueError:
                debug_body = response.text
            debug_print_response(
                response.status_code,
                response.reason,
                dict(response.headers),
                debug_body,
                elapsed_time
            )

        response.raise_for_status()
        data = response.json()

        # Handle the different response formats.
        if isinstance(data, list):
            models = data
        elif isinstance(data, dict):
            if "data" in data:
                models = data["data"]
            elif "models" in data:
                models = data["models"]
            else:
                models = []
        else:
            models = []

        # Callers iterate the result; anything that is not a list (e.g. null)
        # is treated as "no models".
        return models if isinstance(models, list) else []
    except requests.exceptions.ConnectionError:
        print(f"Error: Could not connect to {base_url}. Please check if the server is running.")
        return []
    except requests.exceptions.Timeout:
        print(f"Error: Connection to {base_url} timed out.")
        return []
    except requests.exceptions.HTTPError as e:
        print(f"Error: HTTP error occurred: {e}")
        return []
    except Exception as e:
        # Last-resort boundary (e.g. a 2xx response whose body is not JSON).
        print(f"Error: Failed to fetch models: {e}")
        return []


def _model_identifier(model):
    """Return a printable identifier for one entry of the server's model list."""
    if isinstance(model, dict):
        # Servers differ on the field name; fall back to the whole entry.
        return str(model.get("id") or model.get("name") or model.get("model") or model)
    # Some servers return bare names instead of objects.
    return str(model)


def select_model(base_url, api_key):
    """Allow user to select a model from available models.

    Lists the models reported by the server, then prompts until the user
    enters a valid 1-based number or a model name (case-insensitive).

    Args:
        base_url (str): Base URL of the LLM server.
        api_key (str): API key for authentication.

    Returns:
        str: Selected model name, or None if no models could be fetched or
            the user cancelled the prompt (Ctrl+C / end of input).
    """
    print("\nFetching available models...")
    models = list_models(base_url, api_key)

    if not models:
        print("No models available or failed to fetch models.")
        return None

    # Resolve every identifier once; display, index lookup and name matching
    # all use the same values.
    model_ids = [_model_identifier(model) for model in models]

    print("\nAvailable models:")
    print("-" * 50)
    for i, model_id in enumerate(model_ids, 1):
        print(f"{i}. {model_id}")
    print("-" * 50)

    while True:
        try:
            selection = input("\nEnter model number or name: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Treat a cancelled prompt as a failed selection instead of
            # letting the exception escape as a traceback.
            print("\nModel selection cancelled.")
            return None

        if not selection:
            print("Please enter a valid selection.")
            continue

        # Try to interpret as a number
        try:
            index = int(selection) - 1
        except ValueError:
            index = -1
        if 0 <= index < len(model_ids):
            return model_ids[index]

        # Try to interpret as a model name (also covers numeric-looking
        # names that were out of range as an index).
        wanted = selection.lower()
        for model_id in model_ids:
            if model_id.lower() == wanted:
                return model_id

        print(f"Invalid selection. Please enter a number (1-{len(models)}) or a valid model name.")


def send_chat_request(base_url, api_key, model, messages):
    """Send a chat completion request to the LLM server.

    Posts a non-streaming request to ``{base_url}/chat/completions`` and
    returns the content of the first choice. Failures are reported on stdout
    and never raised to the caller.

    Args:
        base_url (str): Base URL of the LLM server. A trailing slash is
            tolerated.
        api_key (str): API key for authentication (sent as a Bearer token).
        model (str): Model name to use.
        messages (list): List of message dictionaries with 'role' and 'content'.

    Returns:
        str: Assistant's response content, or None on failure (connection
            error, timeout, HTTP error, or a response without content).
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": model,
        "messages": messages,
        "stream": False
    }

    # Strip a trailing slash so "http://host/v1/" does not become ".../v1//chat/completions".
    url = f"{base_url.rstrip('/')}/chat/completions"

    try:
        # Prepare the request first so the debug dump shows the headers
        # exactly as they will be sent.
        prepared_req = requests.Request("POST", url, headers=headers, json=payload).prepare()
        debug_print_request("POST", url, dict(prepared_req.headers), payload)

        # Time only the network round trip, not the debug printing above.
        start_time = time.time()
        with requests.Session() as session:
            response = session.send(prepared_req, timeout=60)
        elapsed_time = time.time() - start_time

        if DEBUG_MODE:
            # Error pages are often not JSON; fall back to the raw text so a
            # decode failure here cannot hide the real HTTP status below.
            try:
                debug_body = response.json() if response.content else None
            except ValueError:
                debug_body = response.text
            debug_print_response(
                response.status_code,
                response.reason,
                dict(response.headers),
                debug_body,
                elapsed_time
            )

        response.raise_for_status()
        data = response.json()

        # Extract the assistant's response from the first choice. "message"
        # takes priority; "delta" is accepted as a fallback.
        choices = data.get("choices") if isinstance(data, dict) else None
        if isinstance(choices, list) and choices:
            choice = choices[0]
            part = None
            if isinstance(choice, dict):
                part = choice["message"] if "message" in choice else choice.get("delta")
            if isinstance(part, dict) and part.get("content") is not None:
                return part["content"]

        # Missing choices, a null message, or null content all land here so
        # the user sees what the server actually sent.
        print("Error: Unexpected response format from server.")
        print(f"Response: {data}")
        return None

    except requests.exceptions.ConnectionError:
        print(f"Error: Could not connect to {base_url}. Please check if the server is running.")
        return None
    except requests.exceptions.Timeout:
        print("Error: Request timed out. The model may be taking too long to respond.")
        return None
    except requests.exceptions.HTTPError as e:
        print(f"Error: HTTP error occurred: {e}")
        # Best effort: show the server's error payload when there is one.
        failed = e.response
        if failed is not None:
            try:
                print(f"Server error: {failed.json()}")
            except ValueError:
                detail = failed.text.strip()
                if detail:
                    print(f"Server error: {detail[:500]}")
        return None
    except Exception as e:
        # Last-resort boundary (e.g. a 2xx response whose body is not JSON).
        print(f"Error: Failed to send request: {e}")
        return None


def main():
    """Main function that runs the LLM tester application.

    Parses the command line, sets the global DEBUG_MODE flag, resolves the
    model (prompting the user when --llm-model is not given) and then runs
    the interactive chat loop until the user types an exit command, presses
    Ctrl+C, or input ends. Exits with status 1 when no model can be selected.
    """
    global DEBUG_MODE

    # Parse command-line arguments
    args = parse_arguments()

    # Set debug mode
    DEBUG_MODE = args.debug
    if DEBUG_MODE:
        print("\n" + "=" * 50)
        print("DEBUG MODE ENABLED")
        print("=" * 50 + "\n")

    base_url = args.llm_url
    api_key = args.llm_api_key
    model = args.llm_model

    # If no model specified, let user select one
    if model is None:
        print("No model specified. Listing available models...")
        try:
            model = select_model(base_url, api_key)
        except (KeyboardInterrupt, EOFError):
            # A cancelled selection takes the normal "no model" exit path
            # instead of ending in a traceback.
            print()
            model = None
        if model is None:
            print("Failed to select a model. Exiting.")
            sys.exit(1)
        print(f"\nUsing model: {model}")
    else:
        print(f"Using model: {model}")

    # Initialize chat history
    chat_history = []

    print("\n" + "=" * 50)
    print("LLM Tester - Interactive Chat")
    print("=" * 50)
    print("Type 'exit' or 'quit' to end the session")
    print("=" * 50 + "\n")

    # Interactive prompt loop
    while True:
        try:
            user_input = input("You: ").strip()
        except KeyboardInterrupt:
            print("\n\nInterrupted. Exiting...")
            break
        except EOFError:
            # Finish the "You: " line so the shell prompt starts cleanly.
            print()
            break

        # Check for exit commands
        if user_input.lower() in ["exit", "quit", "q"]:
            print("Goodbye!")
            break

        # Skip empty input
        if not user_input:
            continue

        # Add user message to history
        chat_history.append({"role": "user", "content": user_input})

        # Send request and get response
        print("Assistant: ", end="", flush=True)
        try:
            response = send_chat_request(base_url, api_key, model, chat_history)
        except KeyboardInterrupt:
            # Ctrl+C while waiting for the server exits the same way as
            # Ctrl+C at the prompt.
            print("\n\nInterrupted. Exiting...")
            break

        if response is not None:
            print(response)
            # Add assistant response to history
            chat_history.append({"role": "assistant", "content": response})
        else:
            # Remove the user message if the request failed so a retry does
            # not send it twice.
            chat_history.pop()
            print("Failed to get a response. Please try again.")


# Start the interactive tester only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
