Skip to content

add top queries by resources #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,26 +460,30 @@ async def analyze_db_health(
return format_text_response(result)


@mcp.tool(description=f"Reports the slowest SQL queries based on execution time, using data from the '{PG_STAT_STATEMENTS}' extension.")
@mcp.tool(
name="get_top_queries",
description=f"Reports the slowest or most resource-intensive queries using data from the '{PG_STAT_STATEMENTS}' extension.",
)
async def get_top_queries(
limit: int = Field(description="Number of slow queries to return", default=10),
sort_by: str = Field(
description="Sort criteria: 'total' for total execution time or 'mean' for mean execution time per call",
default="mean",
description="Ranking criteria: 'total_time' for total execution time or 'mean_time' for mean execution time per call, or 'resources' "
"for resource-intensive queries",
default="resources",
),
limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10),
) -> ResponseType:
"""Reports the slowest SQL queries based on execution time.

This tool handles PostgreSQL version differences automatically:
- In PostgreSQL 13+: Uses total_exec_time/mean_exec_time columns
- In PostgreSQL 12 and older: Uses total_time/mean_time columns
"""
try:
sql_driver = await get_sql_driver()
top_queries_tool = TopQueriesCalc(sql_driver=sql_driver)
if sort_by != "mean" and sort_by != "total":
return format_error_response("Invalid sort criteria. Please use 'mean' or 'total'.")
result = await top_queries_tool.get_top_queries(limit=limit, sort_by=sort_by)

if sort_by == "resources":
result = await top_queries_tool.get_top_resource_queries()
return format_text_response(result)
elif sort_by == "mean_time" or sort_by == "total_time":
# Map the sort_by values to what get_top_queries_by_time expects
result = await top_queries_tool.get_top_queries_by_time(limit=limit, sort_by="mean" if sort_by == "mean_time" else "total")
else:
return format_error_response("Invalid sort criteria. Please use 'resources' or 'mean_time' or 'total_time'.")
return format_text_response(result)
except Exception as e:
logger.error(f"Error getting slow queries: {e}")
Expand Down
43 changes: 22 additions & 21 deletions src/postgres_mcp/sql/extension_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Utilities for working with PostgreSQL extensions."""

import logging
from dataclasses import dataclass
from typing import Literal
from typing import TypedDict

from .safe_sql import SafeSqlDriver
from .sql_driver import SqlDriver
Expand All @@ -14,7 +14,8 @@
_POSTGRES_VERSION = None


class ExtensionStatus(TypedDict):
@dataclass
class ExtensionStatus:
"""Status of an extension."""

is_installed: bool
Expand Down Expand Up @@ -118,25 +119,25 @@ async def check_extension(
)

# Initialize result
result: ExtensionStatus = {
"is_installed": False,
"is_available": False,
"name": extension_name,
"message": "",
"default_version": None,
}
result = ExtensionStatus(
is_installed=False,
is_available=False,
name=extension_name,
message="",
default_version=None,
)

if installed_result and len(installed_result) > 0:
# Extension is installed
version = installed_result[0].cells.get("extversion", "unknown")
result["is_installed"] = True
result["is_available"] = True
result.is_installed = True
result.is_available = True

if include_messages:
if message_type == "markdown":
result["message"] = f"The **{extension_name}** extension (version {version}) is already installed."
result.message = f"The **{extension_name}** extension (version {version}) is already installed."
else:
result["message"] = f"The {extension_name} extension (version {version}) is already installed."
result.message = f"The {extension_name} extension (version {version}) is already installed."
else:
# Check if the extension is available but not installed
available_result = await SafeSqlDriver.execute_param_query(
Expand All @@ -147,32 +148,32 @@ async def check_extension(

if available_result and len(available_result) > 0:
# Extension is available but not installed
result["is_available"] = True
result["default_version"] = available_result[0].cells.get("default_version")
result.is_available = True
result.default_version = available_result[0].cells.get("default_version")

if include_messages:
if message_type == "markdown":
result["message"] = (
result.message = (
f"The **{extension_name}** extension is available but not installed.\n\n"
f"You can install it by running: `CREATE EXTENSION {extension_name};`."
)
else:
result["message"] = (
result.message = (
f"The {extension_name} extension is available but not installed.\n"
f"You can install it by running: CREATE EXTENSION {extension_name};"
)
else:
# Extension is not available
if include_messages:
if message_type == "markdown":
result["message"] = (
result.message = (
f"The **{extension_name}** extension is not available on this PostgreSQL server.\n\n"
f"To install it, you need to:\n"
f"1. Install the extension package on the server\n"
f"2. Run: `CREATE EXTENSION {extension_name};`"
)
else:
result["message"] = (
result.message = (
f"The {extension_name} extension is not available on this PostgreSQL server.\n"
f"To install it, you need to:\n"
f"1. Install the extension package on the server\n"
Expand All @@ -195,13 +196,13 @@ async def check_hypopg_installation_status(sql_driver: SqlDriver, message_type:
"""
status = await check_extension(sql_driver, "hypopg", include_messages=False)

if status["is_installed"]:
if status.is_installed:
if message_type == "markdown":
return True, "The **hypopg** extension is already installed."
else:
return True, "The hypopg extension is already installed."

if status["is_available"]:
if status.is_available:
if message_type == "markdown":
return False, (
"The **hypopg** extension is required to test hypothetical indexes, but it is not currently installed.\n\n"
Expand Down
147 changes: 130 additions & 17 deletions src/postgres_mcp/top_queries/top_queries_calc.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
import logging
from typing import Literal
from typing import LiteralString
from typing import Union
from typing import cast

from ..sql import SafeSqlDriver
from ..sql import SqlDriver
from ..sql.extension_utils import check_extension
from ..sql.extension_utils import get_postgres_version

logger = logging.getLogger(__name__)

PG_STAT_STATEMENTS = "pg_stat_statements"

install_pg_stat_statements_message = (
"The pg_stat_statements extension is required to "
"report slow queries, but it is not currently "
"installed.\n\n"
"You can install it by running: "
"`CREATE EXTENSION pg_stat_statements;`\n\n"
"**What does it do?** It records statistics (like "
"execution time, number of calls, rows returned) for "
"every query executed against the database.\n\n"
"**Is it safe?** Installing 'pg_stat_statements' is "
"generally safe and a standard practice for performance "
"monitoring. It adds overhead by tracking statistics, "
"but this is usually negligible unless under extreme load."
)


class TopQueriesCalc:
"""Tool for retrieving the slowest SQL queries."""

def __init__(self, sql_driver: Union[SqlDriver, SafeSqlDriver]):
self.sql_driver = sql_driver

async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean"] = "mean") -> str:
async def get_top_queries_by_time(self, limit: int = 10, sort_by: Literal["total", "mean"] = "mean") -> str:
"""Reports the slowest SQL queries based on execution time.

Args:
Expand All @@ -27,32 +47,21 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean
A string with the top queries or installation instructions
"""
try:
logger.debug(f"Getting top queries by time. limit={limit}, sort_by={sort_by}")
extension_status = await check_extension(
self.sql_driver,
PG_STAT_STATEMENTS,
include_messages=False,
)

if not extension_status["is_installed"]:
if not extension_status.is_installed:
logger.warning(f"Extension {PG_STAT_STATEMENTS} is not installed")
# Return installation instructions if the extension is not installed
monitoring_message = (
f"The '{PG_STAT_STATEMENTS}' extension is required to "
f"report slow queries, but it is not currently "
f"installed.\n\n"
f"You can install it by running: "
f"`CREATE EXTENSION {PG_STAT_STATEMENTS};`\n\n"
f"**What does it do?** It records statistics (like "
f"execution time, number of calls, rows returned) for "
f"every query executed against the database.\n\n"
f"**Is it safe?** Installing '{PG_STAT_STATEMENTS}' is "
f"generally safe and a standard practice for performance "
f"monitoring. It adds overhead by tracking statistics, "
f"but this is usually negligible unless under extreme load."
)
return monitoring_message
return install_pg_stat_statements_message

# Check PostgreSQL version to determine column names
pg_version = await get_postgres_version(self.sql_driver)
logger.debug(f"PostgreSQL version: {pg_version}")

# Column names changed in PostgreSQL 13
if pg_version >= 13:
Expand All @@ -64,6 +73,8 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean
total_time_col = "total_time"
mean_time_col = "mean_time"

logger.debug(f"Using time columns: total={total_time_col}, mean={mean_time_col}")

# Determine which column to sort by based on sort_by parameter and version
order_by_column = total_time_col if sort_by == "total" else mean_time_col

Expand All @@ -78,12 +89,14 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean
ORDER BY {order_by_column} DESC
LIMIT {{}};
"""
logger.debug(f"Executing query: {query}")
slow_query_rows = await SafeSqlDriver.execute_param_query(
self.sql_driver,
query,
[limit],
)
slow_queries = [row.cells for row in slow_query_rows] if slow_query_rows else []
logger.info(f"Found {len(slow_queries)} slow queries")

# Create result description based on sort criteria
if sort_by == "total":
Expand All @@ -95,4 +108,104 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean
result += str(slow_queries)
return result
except Exception as e:
logger.error(f"Error getting slow queries: {e}", exc_info=True)
return f"Error getting slow queries: {e}"

async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str:
"""Reports the most time consuming queries based on a resource blend.

Args:
frac_threshold: Fraction threshold for filtering queries (default: 0.05)

Returns:
A string with the resource-heavy queries or error message
"""

try:
logger.debug(f"Getting top resource queries with threshold {frac_threshold}")
extension_status = await check_extension(
self.sql_driver,
PG_STAT_STATEMENTS,
include_messages=False,
)

if not extension_status.is_installed:
logger.warning(f"Extension {PG_STAT_STATEMENTS} is not installed")
# Return installation instructions if the extension is not installed
return install_pg_stat_statements_message

# Check PostgreSQL version to determine column names
pg_version = await get_postgres_version(self.sql_driver)
logger.debug(f"PostgreSQL version: {pg_version}")

# Column names changed in PostgreSQL 13
if pg_version >= 13:
# PostgreSQL 13 and newer
total_time_col = "total_exec_time"
mean_time_col = "mean_exec_time"
else:
# PostgreSQL 12 and older
total_time_col = "total_time"
mean_time_col = "mean_time"

query = cast(
LiteralString,
f"""
WITH resource_fractions AS (
SELECT
query,
calls,
rows,
{total_time_col} total_exec_time,
{mean_time_col} mean_exec_time,
stddev_exec_time,
shared_blks_hit,
shared_blks_read,
shared_blks_dirtied,
wal_bytes,
total_exec_time / SUM(total_exec_time) OVER () AS total_exec_time_frac,
(shared_blks_hit + shared_blks_read) / SUM(shared_blks_hit + shared_blks_read) OVER () AS shared_blks_accessed_frac,
shared_blks_read / SUM(shared_blks_read) OVER () AS shared_blks_read_frac,
shared_blks_dirtied / SUM(shared_blks_dirtied) OVER () AS shared_blks_dirtied_frac,
wal_bytes / SUM(wal_bytes) OVER () AS total_wal_bytes_frac
FROM pg_stat_statements
)
SELECT
query,
calls,
rows,
total_exec_time,
mean_exec_time,
stddev_exec_time,
total_exec_time_frac,
shared_blks_accessed_frac,
shared_blks_read_frac,
shared_blks_dirtied_frac,
total_wal_bytes_frac,
shared_blks_hit,
shared_blks_read,
shared_blks_dirtied,
wal_bytes
FROM resource_fractions
WHERE
total_exec_time_frac > {frac_threshold}
OR shared_blks_accessed_frac > {frac_threshold}
OR shared_blks_read_frac > {frac_threshold}
OR shared_blks_dirtied_frac > {frac_threshold}
OR total_wal_bytes_frac > {frac_threshold}
ORDER BY total_exec_time DESC
""",
)

logger.debug(f"Executing query: {query}")
slow_query_rows = await SafeSqlDriver.execute_param_query(
self.sql_driver,
query,
)
resource_queries = [row.cells for row in slow_query_rows] if slow_query_rows else []
logger.info(f"Found {len(resource_queries)} resource-intensive queries")

return str(resource_queries)
except Exception as e:
logger.error(f"Error getting resource-intensive queries: {e}", exc_info=True)
return f"Error resource-intensive queries: {e}"
Loading