Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
changes:
- section: "Bugs Fixed"
description: "Cap 'azmcp postgres list' database and table results at 10,000 entries to prevent unbounded enumeration on large servers. When the results are truncated, the response includes a 'resultsTruncated: true' flag."
2 changes: 2 additions & 0 deletions servers/Azure.Mcp.Server/docs/azmcp-commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -2298,6 +2298,8 @@ azmcp mysql server param set --subscription <subscription> \
# Without parameters: lists all PostgreSQL servers in the resource group
# With --server: lists all databases on that server
# With --server and --database: lists all tables in that database (optionally scoped to a --schema, defaults to 'public')
# Database and table results are capped at 10,000 entries. When the results are truncated,
# the response includes "resultsTruncated": true.
# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired
azmcp postgres list --subscription <subscription> \
--resource-group <resource-group> \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
if (!string.IsNullOrEmpty(options.Database))
{
// List tables in specified database
List<string> tables = await _postgresService.ListTablesAsync(
TableListResult tableResult = await _postgresService.ListTablesAsync(
options.AuthType!,
options.User!,
options.Password,
Expand All @@ -92,21 +92,21 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
cancellationToken);

context.Response.Results = ResponseResult.Create(
new(null, null, tables ?? []),
new(null, null, tableResult.Tables ?? [], tableResult.IsTruncated ? true : null),
PostgresJsonContext.Default.PostgresListCommandResult);
}
else if (!string.IsNullOrEmpty(options.Server))
{
// List databases on specified server
List<string> databases = await _postgresService.ListDatabasesAsync(
DatabaseListResult databaseResult = await _postgresService.ListDatabasesAsync(
options.AuthType!,
options.User!,
options.Password,
options.Server!,
cancellationToken);

context.Response.Results = ResponseResult.Create(
new(null, databases ?? [], null),
new(null, databaseResult.Databases ?? [], null, databaseResult.IsTruncated ? true : null),
PostgresJsonContext.Default.PostgresListCommandResult);
}
else
Expand All @@ -131,5 +131,5 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
return context.Response;
}

public record PostgresListCommandResult(List<string>? Servers, List<string>? Databases, List<string>? Tables);
public record PostgresListCommandResult(List<string>? Servers, List<string>? Databases, List<string>? Tables, bool? ResultsTruncated = null);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Azure.Mcp.Tools.Postgres.Services;

public sealed record DatabaseListResult(List<string> Databases, bool IsTruncated);
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace Azure.Mcp.Tools.Postgres.Services;

public interface IPostgresService
{
Task<List<string>> ListDatabasesAsync(
Task<DatabaseListResult> ListDatabasesAsync(
string authType,
string user,
string? password,
Expand All @@ -23,7 +23,7 @@ Task<List<string>> ExecuteQueryAsync(
string query,
CancellationToken cancellationToken);

Task<List<string>> ListTablesAsync(
Task<TableListResult> ListTablesAsync(
string authType,
string user,
string? password,
Expand Down
32 changes: 26 additions & 6 deletions tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public class PostgresService(
private readonly IEntraTokenProvider _entraTokenAuth = entraTokenAuth;
private readonly IDbProvider _dbProvider = dbProvider;

internal const int MaxRowCount = 10_000;

private async Task<string> GetEntraIdAccessTokenAsync(CancellationToken cancellationToken)
{
var tokenCredential = await GetCredential(cancellationToken);
Expand Down Expand Up @@ -78,7 +80,7 @@ private string NormalizeServerName(string server)
return server;
}

public async Task<List<string>> ListDatabasesAsync(
public async Task<DatabaseListResult> ListDatabasesAsync(
string authType,
string user,
string? password,
Expand All @@ -89,16 +91,25 @@ public async Task<List<string>> ListDatabasesAsync(
var host = NormalizeServerName(server);
var connectionString = BuildConnectionString(host, "postgres", user, passwordToUse);

var query = "SELECT datname FROM pg_database WHERE datistemplate = false;";
var query = "SELECT datname FROM pg_database WHERE datistemplate = false ORDER BY datname LIMIT @maxResults;";
await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString, authType, cancellationToken);
await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource);
// Fetch cap+1 rows so we can detect truncation by observing whether an extra row exists, then trim it.
command.Parameters.AddWithValue("maxResults", MaxRowCount + 1);
await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command, cancellationToken);
var dbs = new List<string>();
while (await reader.ReadAsync(cancellationToken))
{
dbs.Add(reader.GetString(0));
}
return dbs;

var isTruncated = dbs.Count > MaxRowCount;
if (isTruncated)
{
dbs.RemoveRange(MaxRowCount, dbs.Count - MaxRowCount);
}

return new DatabaseListResult(dbs, isTruncated);
}

public async Task<List<string>> ExecuteQueryAsync(
Expand Down Expand Up @@ -157,7 +168,7 @@ public async Task<List<string>> ExecuteQueryAsync(
return rows;
}

public async Task<List<string>> ListTablesAsync(
public async Task<TableListResult> ListTablesAsync(
string authType,
string user,
string? password,
Expand All @@ -170,17 +181,26 @@ public async Task<List<string>> ListTablesAsync(
var host = NormalizeServerName(server);
var connectionString = BuildConnectionString(host, database, user, passwordToUse);

var query = "SELECT table_name FROM information_schema.tables WHERE table_schema = @schema ORDER BY table_name;";
var query = "SELECT table_name FROM information_schema.tables WHERE table_schema = @schema ORDER BY table_name LIMIT @maxResults;";
await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString, authType, cancellationToken);
await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource);
command.Parameters.AddWithValue("schema", schema);
// Fetch cap+1 rows so we can detect truncation by observing whether an extra row exists, then trim it.
command.Parameters.AddWithValue("maxResults", MaxRowCount + 1);
Comment thread
vcolin7 marked this conversation as resolved.
await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command, cancellationToken);
var tables = new List<string>();
while (await reader.ReadAsync(cancellationToken))
{
tables.Add(reader.GetString(0));
}
return tables;

var isTruncated = tables.Count > MaxRowCount;
if (isTruncated)
{
tables.RemoveRange(MaxRowCount, tables.Count - MaxRowCount);
}

return new TableListResult(tables, isTruncated);
}

public async Task<List<string>> GetTableSchemaAsync(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Azure.Mcp.Tools.Postgres.Services;

public sealed record TableListResult(List<string> Tables, bool IsTruncated);
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public async Task ExecuteAsync_ListsDatabases_WhenServerProvided()
null,
"server1",
Arg.Any<CancellationToken>())
.Returns(expectedDatabases);
.Returns(new DatabaseListResult(expectedDatabases, false));

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
Expand All @@ -90,6 +90,7 @@ public async Task ExecuteAsync_ListsDatabases_WhenServerProvided()
Assert.Null(result.Servers);
Assert.Equal(expectedDatabases, result.Databases);
Assert.Null(result.Tables);
Assert.Null(result.ResultsTruncated);
}

[Fact]
Expand All @@ -104,7 +105,7 @@ public async Task ExecuteAsync_ListsTables_WhenServerAndDatabaseProvided()
"db1",
"public",
Arg.Any<CancellationToken>())
.Returns(expectedTables);
.Returns(new TableListResult(expectedTables, false));

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
Expand All @@ -119,6 +120,64 @@ public async Task ExecuteAsync_ListsTables_WhenServerAndDatabaseProvided()
Assert.Null(result.Servers);
Assert.Null(result.Databases);
Assert.Equal(expectedTables, result.Tables);
Assert.Null(result.ResultsTruncated);
}

[Fact]
public async Task ExecuteAsync_SetsResultsTruncated_WhenTableResultsAreTruncated()
{
var expectedTables = new List<string> { "users", "products", "orders" };
Service.ListTablesAsync(
AuthTypes.MicrosoftEntra,
"user1",
null,
"server1",
"db1",
"public",
Arg.Any<CancellationToken>())
.Returns(new TableListResult(expectedTables, true));

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
"--resource-group", "rg1",
"--user", "user1",
$"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra,
"--server", "server1",
"--database", "db1");

var result = ValidateAndDeserializeResponse(response, PostgresJsonContext.Default.PostgresListCommandResult);

Assert.Null(result.Servers);
Assert.Null(result.Databases);
Assert.Equal(expectedTables, result.Tables);
Assert.True(result.ResultsTruncated);
}

[Fact]
public async Task ExecuteAsync_SetsResultsTruncated_WhenDatabaseResultsAreTruncated()
{
var expectedDatabases = new List<string> { "db1", "db2", "db3" };
Service.ListDatabasesAsync(
AuthTypes.MicrosoftEntra,
"user1",
null,
"server1",
Arg.Any<CancellationToken>())
.Returns(new DatabaseListResult(expectedDatabases, true));

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
"--resource-group", "rg1",
"--user", "user1",
$"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra,
"--server", "server1");

var result = ValidateAndDeserializeResponse(response, PostgresJsonContext.Default.PostgresListCommandResult);

Assert.Null(result.Servers);
Assert.Equal(expectedDatabases, result.Databases);
Assert.Null(result.Tables);
Assert.True(result.ResultsTruncated);
}

[Fact]
Expand Down Expand Up @@ -147,7 +206,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenNoDatabasesExist()
null,
"server1",
Arg.Any<CancellationToken>())
.Returns([]);
.Returns(new DatabaseListResult([], false));

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
Expand Down Expand Up @@ -175,7 +234,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenNoTablesExist()
"db1",
"public",
Arg.Any<CancellationToken>())
.Returns([]);
.Returns(new TableListResult([], false));

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
Expand All @@ -191,6 +250,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenNoTablesExist()
Assert.Null(result.Databases);
Assert.NotNull(result.Tables);
Assert.Empty(result.Tables);
Assert.Null(result.ResultsTruncated);
}

[Fact]
Expand All @@ -205,7 +265,7 @@ public async Task ExecuteAsync_ListsTablesWithSpecifiedSchema_WhenSchemaProvided
"db1",
"analytics",
Arg.Any<CancellationToken>())
.Returns(expectedTables);
.Returns(new TableListResult(expectedTables, false));

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
Expand Down Expand Up @@ -236,7 +296,7 @@ public async Task ExecuteAsync_ListsTablesWithPublicSchema_WhenSchemaOmitted()
"db1",
"public",
Arg.Any<CancellationToken>())
.Returns(["users"]);
.Returns(new TableListResult(["users"], false));

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
Expand Down
Loading