Skip to content

Commit 2dbd954

Browse files
authored
add top queries by resources (#54)
* add top queries by resources * formatting * fix tests * fixes * lint and type check
1 parent 070b842 commit 2dbd954

File tree

5 files changed

+201
-81
lines changed

5 files changed

+201
-81
lines changed

src/postgres_mcp/server.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -460,26 +460,30 @@ async def analyze_db_health(
460460
return format_text_response(result)
461461

462462

463-
@mcp.tool(description=f"Reports the slowest SQL queries based on execution time, using data from the '{PG_STAT_STATEMENTS}' extension.")
463+
@mcp.tool(
464+
name="get_top_queries",
465+
description=f"Reports the slowest or most resource-intensive queries using data from the '{PG_STAT_STATEMENTS}' extension.",
466+
)
464467
async def get_top_queries(
465-
limit: int = Field(description="Number of slow queries to return", default=10),
466468
sort_by: str = Field(
467-
description="Sort criteria: 'total' for total execution time or 'mean' for mean execution time per call",
468-
default="mean",
469+
description="Ranking criteria: 'total_time' for total execution time or 'mean_time' for mean execution time per call, or 'resources' "
470+
"for resource-intensive queries",
471+
default="resources",
469472
),
473+
limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10),
470474
) -> ResponseType:
471-
"""Reports the slowest SQL queries based on execution time.
472-
473-
This tool handles PostgreSQL version differences automatically:
474-
- In PostgreSQL 13+: Uses total_exec_time/mean_exec_time columns
475-
- In PostgreSQL 12 and older: Uses total_time/mean_time columns
476-
"""
477475
try:
478476
sql_driver = await get_sql_driver()
479477
top_queries_tool = TopQueriesCalc(sql_driver=sql_driver)
480-
if sort_by != "mean" and sort_by != "total":
481-
return format_error_response("Invalid sort criteria. Please use 'mean' or 'total'.")
482-
result = await top_queries_tool.get_top_queries(limit=limit, sort_by=sort_by)
478+
479+
if sort_by == "resources":
480+
result = await top_queries_tool.get_top_resource_queries()
481+
return format_text_response(result)
482+
elif sort_by == "mean_time" or sort_by == "total_time":
483+
# Map the sort_by values to what get_top_queries_by_time expects
484+
result = await top_queries_tool.get_top_queries_by_time(limit=limit, sort_by="mean" if sort_by == "mean_time" else "total")
485+
else:
486+
return format_error_response("Invalid sort criteria. Please use 'resources' or 'mean_time' or 'total_time'.")
483487
return format_text_response(result)
484488
except Exception as e:
485489
logger.error(f"Error getting slow queries: {e}")

src/postgres_mcp/sql/extension_utils.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Utilities for working with PostgreSQL extensions."""
22

33
import logging
4+
from dataclasses import dataclass
45
from typing import Literal
5-
from typing import TypedDict
66

77
from .safe_sql import SafeSqlDriver
88
from .sql_driver import SqlDriver
@@ -14,7 +14,8 @@
1414
_POSTGRES_VERSION = None
1515

1616

17-
class ExtensionStatus(TypedDict):
17+
@dataclass
18+
class ExtensionStatus:
1819
"""Status of an extension."""
1920

2021
is_installed: bool
@@ -118,25 +119,25 @@ async def check_extension(
118119
)
119120

120121
# Initialize result
121-
result: ExtensionStatus = {
122-
"is_installed": False,
123-
"is_available": False,
124-
"name": extension_name,
125-
"message": "",
126-
"default_version": None,
127-
}
122+
result = ExtensionStatus(
123+
is_installed=False,
124+
is_available=False,
125+
name=extension_name,
126+
message="",
127+
default_version=None,
128+
)
128129

129130
if installed_result and len(installed_result) > 0:
130131
# Extension is installed
131132
version = installed_result[0].cells.get("extversion", "unknown")
132-
result["is_installed"] = True
133-
result["is_available"] = True
133+
result.is_installed = True
134+
result.is_available = True
134135

135136
if include_messages:
136137
if message_type == "markdown":
137-
result["message"] = f"The **{extension_name}** extension (version {version}) is already installed."
138+
result.message = f"The **{extension_name}** extension (version {version}) is already installed."
138139
else:
139-
result["message"] = f"The {extension_name} extension (version {version}) is already installed."
140+
result.message = f"The {extension_name} extension (version {version}) is already installed."
140141
else:
141142
# Check if the extension is available but not installed
142143
available_result = await SafeSqlDriver.execute_param_query(
@@ -147,32 +148,32 @@ async def check_extension(
147148

148149
if available_result and len(available_result) > 0:
149150
# Extension is available but not installed
150-
result["is_available"] = True
151-
result["default_version"] = available_result[0].cells.get("default_version")
151+
result.is_available = True
152+
result.default_version = available_result[0].cells.get("default_version")
152153

153154
if include_messages:
154155
if message_type == "markdown":
155-
result["message"] = (
156+
result.message = (
156157
f"The **{extension_name}** extension is available but not installed.\n\n"
157158
f"You can install it by running: `CREATE EXTENSION {extension_name};`."
158159
)
159160
else:
160-
result["message"] = (
161+
result.message = (
161162
f"The {extension_name} extension is available but not installed.\n"
162163
f"You can install it by running: CREATE EXTENSION {extension_name};"
163164
)
164165
else:
165166
# Extension is not available
166167
if include_messages:
167168
if message_type == "markdown":
168-
result["message"] = (
169+
result.message = (
169170
f"The **{extension_name}** extension is not available on this PostgreSQL server.\n\n"
170171
f"To install it, you need to:\n"
171172
f"1. Install the extension package on the server\n"
172173
f"2. Run: `CREATE EXTENSION {extension_name};`"
173174
)
174175
else:
175-
result["message"] = (
176+
result.message = (
176177
f"The {extension_name} extension is not available on this PostgreSQL server.\n"
177178
f"To install it, you need to:\n"
178179
f"1. Install the extension package on the server\n"
@@ -195,13 +196,13 @@ async def check_hypopg_installation_status(sql_driver: SqlDriver, message_type:
195196
"""
196197
status = await check_extension(sql_driver, "hypopg", include_messages=False)
197198

198-
if status["is_installed"]:
199+
if status.is_installed:
199200
if message_type == "markdown":
200201
return True, "The **hypopg** extension is already installed."
201202
else:
202203
return True, "The hypopg extension is already installed."
203204

204-
if status["is_available"]:
205+
if status.is_available:
205206
if message_type == "markdown":
206207
return False, (
207208
"The **hypopg** extension is required to test hypothetical indexes, but it is not currently installed.\n\n"

src/postgres_mcp/top_queries/top_queries_calc.py

Lines changed: 130 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,41 @@
1+
import logging
12
from typing import Literal
3+
from typing import LiteralString
24
from typing import Union
5+
from typing import cast
36

47
from ..sql import SafeSqlDriver
58
from ..sql import SqlDriver
69
from ..sql.extension_utils import check_extension
710
from ..sql.extension_utils import get_postgres_version
811

12+
logger = logging.getLogger(__name__)
13+
914
PG_STAT_STATEMENTS = "pg_stat_statements"
1015

16+
install_pg_stat_statements_message = (
17+
"The pg_stat_statements extension is required to "
18+
"report slow queries, but it is not currently "
19+
"installed.\n\n"
20+
"You can install it by running: "
21+
"`CREATE EXTENSION pg_stat_statements;`\n\n"
22+
"**What does it do?** It records statistics (like "
23+
"execution time, number of calls, rows returned) for "
24+
"every query executed against the database.\n\n"
25+
"**Is it safe?** Installing 'pg_stat_statements' is "
26+
"generally safe and a standard practice for performance "
27+
"monitoring. It adds overhead by tracking statistics, "
28+
"but this is usually negligible unless under extreme load."
29+
)
30+
1131

1232
class TopQueriesCalc:
1333
"""Tool for retrieving the slowest SQL queries."""
1434

1535
def __init__(self, sql_driver: Union[SqlDriver, SafeSqlDriver]):
1636
self.sql_driver = sql_driver
1737

18-
async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean"] = "mean") -> str:
38+
async def get_top_queries_by_time(self, limit: int = 10, sort_by: Literal["total", "mean"] = "mean") -> str:
1939
"""Reports the slowest SQL queries based on execution time.
2040
2141
Args:
@@ -27,32 +47,21 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean
2747
A string with the top queries or installation instructions
2848
"""
2949
try:
50+
logger.debug(f"Getting top queries by time. limit={limit}, sort_by={sort_by}")
3051
extension_status = await check_extension(
3152
self.sql_driver,
3253
PG_STAT_STATEMENTS,
3354
include_messages=False,
3455
)
3556

36-
if not extension_status["is_installed"]:
57+
if not extension_status.is_installed:
58+
logger.warning(f"Extension {PG_STAT_STATEMENTS} is not installed")
3759
# Return installation instructions if the extension is not installed
38-
monitoring_message = (
39-
f"The '{PG_STAT_STATEMENTS}' extension is required to "
40-
f"report slow queries, but it is not currently "
41-
f"installed.\n\n"
42-
f"You can install it by running: "
43-
f"`CREATE EXTENSION {PG_STAT_STATEMENTS};`\n\n"
44-
f"**What does it do?** It records statistics (like "
45-
f"execution time, number of calls, rows returned) for "
46-
f"every query executed against the database.\n\n"
47-
f"**Is it safe?** Installing '{PG_STAT_STATEMENTS}' is "
48-
f"generally safe and a standard practice for performance "
49-
f"monitoring. It adds overhead by tracking statistics, "
50-
f"but this is usually negligible unless under extreme load."
51-
)
52-
return monitoring_message
60+
return install_pg_stat_statements_message
5361

5462
# Check PostgreSQL version to determine column names
5563
pg_version = await get_postgres_version(self.sql_driver)
64+
logger.debug(f"PostgreSQL version: {pg_version}")
5665

5766
# Column names changed in PostgreSQL 13
5867
if pg_version >= 13:
@@ -64,6 +73,8 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean
6473
total_time_col = "total_time"
6574
mean_time_col = "mean_time"
6675

76+
logger.debug(f"Using time columns: total={total_time_col}, mean={mean_time_col}")
77+
6778
# Determine which column to sort by based on sort_by parameter and version
6879
order_by_column = total_time_col if sort_by == "total" else mean_time_col
6980

@@ -78,12 +89,14 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean
7889
ORDER BY {order_by_column} DESC
7990
LIMIT {{}};
8091
"""
92+
logger.debug(f"Executing query: {query}")
8193
slow_query_rows = await SafeSqlDriver.execute_param_query(
8294
self.sql_driver,
8395
query,
8496
[limit],
8597
)
8698
slow_queries = [row.cells for row in slow_query_rows] if slow_query_rows else []
99+
logger.info(f"Found {len(slow_queries)} slow queries")
87100

88101
# Create result description based on sort criteria
89102
if sort_by == "total":
@@ -95,4 +108,104 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean
95108
result += str(slow_queries)
96109
return result
97110
except Exception as e:
111+
logger.error(f"Error getting slow queries: {e}", exc_info=True)
98112
return f"Error getting slow queries: {e}"
113+
114+
async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str:
115+
"""Reports the most time consuming queries based on a resource blend.
116+
117+
Args:
118+
frac_threshold: Fraction threshold for filtering queries (default: 0.05)
119+
120+
Returns:
121+
A string with the resource-heavy queries or error message
122+
"""
123+
124+
try:
125+
logger.debug(f"Getting top resource queries with threshold {frac_threshold}")
126+
extension_status = await check_extension(
127+
self.sql_driver,
128+
PG_STAT_STATEMENTS,
129+
include_messages=False,
130+
)
131+
132+
if not extension_status.is_installed:
133+
logger.warning(f"Extension {PG_STAT_STATEMENTS} is not installed")
134+
# Return installation instructions if the extension is not installed
135+
return install_pg_stat_statements_message
136+
137+
# Check PostgreSQL version to determine column names
138+
pg_version = await get_postgres_version(self.sql_driver)
139+
logger.debug(f"PostgreSQL version: {pg_version}")
140+
141+
# Column names changed in PostgreSQL 13
142+
if pg_version >= 13:
143+
# PostgreSQL 13 and newer
144+
total_time_col = "total_exec_time"
145+
mean_time_col = "mean_exec_time"
146+
else:
147+
# PostgreSQL 12 and older
148+
total_time_col = "total_time"
149+
mean_time_col = "mean_time"
150+
151+
query = cast(
152+
LiteralString,
153+
f"""
154+
WITH resource_fractions AS (
155+
SELECT
156+
query,
157+
calls,
158+
rows,
159+
{total_time_col} total_exec_time,
160+
{mean_time_col} mean_exec_time,
161+
stddev_exec_time,
162+
shared_blks_hit,
163+
shared_blks_read,
164+
shared_blks_dirtied,
165+
wal_bytes,
166+
total_exec_time / SUM(total_exec_time) OVER () AS total_exec_time_frac,
167+
(shared_blks_hit + shared_blks_read) / SUM(shared_blks_hit + shared_blks_read) OVER () AS shared_blks_accessed_frac,
168+
shared_blks_read / SUM(shared_blks_read) OVER () AS shared_blks_read_frac,
169+
shared_blks_dirtied / SUM(shared_blks_dirtied) OVER () AS shared_blks_dirtied_frac,
170+
wal_bytes / SUM(wal_bytes) OVER () AS total_wal_bytes_frac
171+
FROM pg_stat_statements
172+
)
173+
SELECT
174+
query,
175+
calls,
176+
rows,
177+
total_exec_time,
178+
mean_exec_time,
179+
stddev_exec_time,
180+
total_exec_time_frac,
181+
shared_blks_accessed_frac,
182+
shared_blks_read_frac,
183+
shared_blks_dirtied_frac,
184+
total_wal_bytes_frac,
185+
shared_blks_hit,
186+
shared_blks_read,
187+
shared_blks_dirtied,
188+
wal_bytes
189+
FROM resource_fractions
190+
WHERE
191+
total_exec_time_frac > {frac_threshold}
192+
OR shared_blks_accessed_frac > {frac_threshold}
193+
OR shared_blks_read_frac > {frac_threshold}
194+
OR shared_blks_dirtied_frac > {frac_threshold}
195+
OR total_wal_bytes_frac > {frac_threshold}
196+
ORDER BY total_exec_time DESC
197+
""",
198+
)
199+
200+
logger.debug(f"Executing query: {query}")
201+
slow_query_rows = await SafeSqlDriver.execute_param_query(
202+
self.sql_driver,
203+
query,
204+
)
205+
resource_queries = [row.cells for row in slow_query_rows] if slow_query_rows else []
206+
logger.info(f"Found {len(resource_queries)} resource-intensive queries")
207+
208+
return str(resource_queries)
209+
except Exception as e:
210+
logger.error(f"Error getting resource-intensive queries: {e}", exc_info=True)
211+
return f"Error resource-intensive queries: {e}"

0 commit comments

Comments
 (0)