Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,12 @@ await subscriberService.GetCustomerOrThrow(provider,
}

public async Task<byte[]> GenerateClientInvoiceReport(
Guid providerId,
string invoiceId)
{
ArgumentException.ThrowIfNullOrEmpty(invoiceId);

var invoiceItems = await providerInvoiceItemRepository.GetByInvoiceId(invoiceId);
var invoiceItems = await providerInvoiceItemRepository.GetByProviderIdAndInvoiceId(providerId, invoiceId);

if (invoiceItems.Count == 0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,23 +485,26 @@ await sutProvider.GetDependency<IStripeAdapter>().Received(1).CreateCustomerAsyn

[Theory, BitAutoData]
public async Task GenerateClientInvoiceReport_NullInvoiceId_ThrowsArgumentNullException(
Guid providerId,
SutProvider<ProviderBillingService> sutProvider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.GenerateClientInvoiceReport(null));
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.GenerateClientInvoiceReport(providerId, null));

[Theory, BitAutoData]
public async Task GenerateClientInvoiceReport_NoInvoiceItems_ReturnsNull(
Guid providerId,
string invoiceId,
SutProvider<ProviderBillingService> sutProvider)
{
sutProvider.GetDependency<IProviderInvoiceItemRepository>().GetByInvoiceId(invoiceId).Returns([]);
sutProvider.GetDependency<IProviderInvoiceItemRepository>().GetByProviderIdAndInvoiceId(providerId, invoiceId).Returns([]);

var reportContent = await sutProvider.Sut.GenerateClientInvoiceReport(invoiceId);
var reportContent = await sutProvider.Sut.GenerateClientInvoiceReport(providerId, invoiceId);

Assert.Null(reportContent);
}

[Theory, BitAutoData]
public async Task GenerateClientInvoiceReport_Succeeds(
Guid providerId,
string invoiceId,
SutProvider<ProviderBillingService> sutProvider)
{
Expand All @@ -520,9 +523,9 @@ public async Task GenerateClientInvoiceReport_Succeeds(
}
};

sutProvider.GetDependency<IProviderInvoiceItemRepository>().GetByInvoiceId(invoiceId).Returns(invoiceItems);
sutProvider.GetDependency<IProviderInvoiceItemRepository>().GetByProviderIdAndInvoiceId(providerId, invoiceId).Returns(invoiceItems);

var reportContent = await sutProvider.Sut.GenerateClientInvoiceReport(invoiceId);
var reportContent = await sutProvider.Sut.GenerateClientInvoiceReport(providerId, invoiceId);

using var memoryStream = new MemoryStream(reportContent);

Expand Down
4 changes: 2 additions & 2 deletions src/Api/Billing/Controllers/ProviderBillingController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ public async Task<IResult> GenerateClientInvoiceReportAsync([FromRoute] Guid pro
return result;
}

var reportContent = await providerBillingService.GenerateClientInvoiceReport(invoiceId);
var reportContent = await providerBillingService.GenerateClientInvoiceReport(provider.Id, invoiceId);

if (reportContent == null)
{
return Error.ServerError("We had a problem generating your invoice CSV. Please contact support.");
return Error.NotFound();
}

return TypedResults.File(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ public interface IProviderInvoiceItemRepository : IRepository<ProviderInvoiceIte
{
Task<ICollection<ProviderInvoiceItem>> GetByInvoiceId(string invoiceId);
Task<ICollection<ProviderInvoiceItem>> GetByProviderId(Guid providerId);
Task<ICollection<ProviderInvoiceItem>> GetByProviderIdAndInvoiceId(Guid providerId, string invoiceId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ Task CreateCustomerForClientOrganization(
/// <summary>
/// Generate a provider's client invoice report in CSV format for the specified <paramref name="invoiceId"/>. Utilizes the <see cref="ProviderInvoiceItem"/>
/// records saved for the <paramref name="invoiceId"/> as part of our webhook processing for the <b>"invoice.created"</b> and <b>"invoice.finalized"</b> Stripe events.
/// The report is scoped to the provided <paramref name="providerId"/>, so a provider can only generate reports for invoices it owns.
/// </summary>
/// <param name="providerId">The ID of the <see cref="Provider"/> that owns the invoice. Only items belonging to this provider are included.</param>
/// <param name="invoiceId">The ID of the Stripe <see cref="Stripe.Invoice"/> to generate the report for.</param>
/// <returns>The provider's client invoice report as a byte array.</returns>
Task<byte[]> GenerateClientInvoiceReport(
Guid providerId,
string invoiceId);

Task<IEnumerable<AddableOrganization>> GetAddableOrganizations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class ProviderInvoiceItemRepository(
{
public async Task<ICollection<ProviderInvoiceItem>> GetByInvoiceId(string invoiceId)
{
var sqlConnection = new SqlConnection(ConnectionString);
await using var sqlConnection = new SqlConnection(ConnectionString);

var results = await sqlConnection.QueryAsync<ProviderInvoiceItem>(
"[dbo].[ProviderInvoiceItem_ReadByInvoiceId]",
Expand All @@ -28,7 +28,7 @@ public async Task<ICollection<ProviderInvoiceItem>> GetByInvoiceId(string invoic

public async Task<ICollection<ProviderInvoiceItem>> GetByProviderId(Guid providerId)
{
var sqlConnection = new SqlConnection(ConnectionString);
await using var sqlConnection = new SqlConnection(ConnectionString);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

πŸ“ Good catch with these.


var results = await sqlConnection.QueryAsync<ProviderInvoiceItem>(
"[dbo].[ProviderInvoiceItem_ReadByProviderId]",
Expand All @@ -37,4 +37,16 @@ public async Task<ICollection<ProviderInvoiceItem>> GetByProviderId(Guid provide

return results.ToArray();
}

public async Task<ICollection<ProviderInvoiceItem>> GetByProviderIdAndInvoiceId(Guid providerId, string invoiceId)
{
await using var sqlConnection = new SqlConnection(ConnectionString);

var results = await sqlConnection.QueryAsync<ProviderInvoiceItem>(
"[dbo].[ProviderInvoiceItem_ReadByProviderIdInvoiceId]",
new { ProviderId = providerId, InvoiceId = invoiceId },
commandType: CommandType.StoredProcedure);

return results.ToArray();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,19 @@ from providerInvoiceItem in databaseContext.ProviderInvoiceItems

return await query.ToArrayAsync();
}

public async Task<ICollection<ProviderInvoiceItem>> GetByProviderIdAndInvoiceId(Guid providerId, string invoiceId)
{
using var serviceScope = ServiceScopeFactory.CreateScope();

var databaseContext = GetDatabaseContext(serviceScope);

var query =
from providerInvoiceItem in databaseContext.ProviderInvoiceItems
where providerInvoiceItem.ProviderId == providerId &&
providerInvoiceItem.InvoiceId == invoiceId
select providerInvoiceItem;

return await query.ToArrayAsync();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
CREATE PROCEDURE [dbo].[ProviderInvoiceItem_ReadByProviderIdInvoiceId]
@ProviderId UNIQUEIDENTIFIER,
@InvoiceId VARCHAR (50)
AS
BEGIN
SET NOCOUNT ON

SELECT
*
FROM
[dbo].[ProviderInvoiceItemView]
WHERE
[ProviderId] = @ProviderId
AND [InvoiceId] = @InvoiceId
END
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
ο»Ώusing System.Net;
using Bit.Api.IntegrationTest.Factories;
using Bit.Api.IntegrationTest.Helpers;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Repositories;
using Bit.Core.Repositories;
using Xunit;

namespace Bit.Api.IntegrationTest.Billing.Controllers;

/// <summary>
/// Integration tests for the provider client invoice CSV endpoint
/// (GET /providers/{providerId}/billing/invoices/{invoiceId}) on ProviderBillingController,
/// focusing on cross-provider authorization.
///
/// Reproduces VULN-565 (PM-36574): the action authorizes the caller against the route providerId,
/// but the report service loaded ProviderInvoiceItem rows by the attacker-supplied invoiceId without
/// checking the invoice belongs to the authorized provider. A provider admin for provider A could
/// therefore retrieve provider B's client invoice CSV by passing B's invoiceId through A's route.
///
/// The report must be scoped to the authorized provider, so the victim provider's client billing
/// data must never appear in the attacker's response.
/// </summary>
public class ProviderBillingControllerAuthorizationTests : IClassFixture<ApiApplicationFactory>, IAsyncLifetime
{
private readonly HttpClient _client;
private readonly ApiApplicationFactory _factory;
private readonly LoginHelper _loginHelper;

private string _attackerAdminEmail = null!;
private Provider _attackerProvider = null!;
private Provider _victimProvider = null!;
private string _victimInvoiceId = null!;

// Distinctive victim values that must never leak into the attacker's response.
private const string VictimClientName = "Victim Client Organization";
private const string VictimPlanName = "Enterprise (Annually)";

public ProviderBillingControllerAuthorizationTests(ApiApplicationFactory factory)
{
_factory = factory;
_client = _factory.CreateClient();
_loginHelper = new LoginHelper(_factory, _client);
}

public async Task InitializeAsync()
{
var userRepository = _factory.GetService<IUserRepository>();
var providerRepository = _factory.GetService<IProviderRepository>();
var providerUserRepository = _factory.GetService<IProviderUserRepository>();
var providerInvoiceItemRepository = _factory.GetService<IProviderInvoiceItemRepository>();

// Attacker: a provider admin of their own billable provider (provider A).
_attackerAdminEmail = $"vuln565-attacker-{Guid.NewGuid()}@test.com";
await _factory.LoginWithNewAccount(_attackerAdminEmail);
var attackerAdmin = await userRepository.GetByEmailAsync(_attackerAdminEmail);

_attackerProvider = await CreateBillableProviderAsync(providerRepository, "Attacker Provider");
await providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = _attackerProvider.Id,
UserId = attackerAdmin!.Id,
Type = ProviderUserType.ProviderAdmin,
Status = ProviderUserStatusType.Confirmed,
Key = Guid.NewGuid().ToString()
});

// Victim: a different provider (provider B) that owns an invoice item. The attacker has no
// membership in this provider.
_victimProvider = await CreateBillableProviderAsync(providerRepository, "Victim Provider");
_victimInvoiceId = $"in_{Guid.NewGuid():N}";
await providerInvoiceItemRepository.CreateAsync(new ProviderInvoiceItem
{
ProviderId = _victimProvider.Id,
InvoiceId = _victimInvoiceId,
InvoiceNumber = "INV-VICTIM-001",
ClientId = Guid.NewGuid(),
ClientName = VictimClientName,
PlanName = VictimPlanName,
AssignedSeats = 42,
UsedSeats = 17,
Total = 1234.56m
});
}

public Task DisposeAsync()
{
_client.Dispose();
return Task.CompletedTask;
}

/// <summary>
/// Control: the attacker cannot use the victim provider's own route β€” they are not a provider
/// admin of provider B, so authorization rejects the request. Passes before and after the fix;
/// anchors the IDOR test below by confirming the attacker has no legitimate access to provider B.
/// </summary>
[Fact]
public async Task GenerateClientInvoiceReport_ThroughVictimProviderRoute_IsUnauthorized()
{
await _loginHelper.LoginAsync(_attackerAdminEmail);

var response = await _client.GetAsync(
$"providers/{_victimProvider.Id}/billing/invoices/{_victimInvoiceId}");

Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode);
}

/// <summary>
/// Reproduces VULN-565: the attacker requests the victim's invoiceId through their OWN provider's
/// route, which passes authorization. The report must be scoped to the authorized provider, so
/// none of the victim provider's client billing data may be returned.
/// </summary>
[Fact]
public async Task GenerateClientInvoiceReport_ForInvoiceOwnedByAnotherProvider_DoesNotReturnVictimData()
{
await _loginHelper.LoginAsync(_attackerAdminEmail);

var response = await _client.GetAsync(
$"providers/{_attackerProvider.Id}/billing/invoices/{_victimInvoiceId}");

var body = await response.Content.ReadAsStringAsync();

// Core security invariant: the authorized provider does not own this invoice, so the victim
// provider's client billing data must not appear anywhere in the response.
Assert.DoesNotContain(VictimClientName, body);
Assert.DoesNotContain(VictimPlanName, body);

// The endpoint must not hand back a CSV for an invoice the provider doesn't own; the
// authorized provider has no such invoice, so the scoped lookup yields nothing -> 404.
Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
}

private static Task<Provider> CreateBillableProviderAsync(IProviderRepository providerRepository, string name) =>
providerRepository.CreateAsync(new Provider
{
Name = name,
BillingEmail = $"{name.Replace(" ", "-").ToLowerInvariant()}@example.com",
Type = ProviderType.Msp,
Status = ProviderStatusType.Billable,
Enabled = true,
GatewayCustomerId = $"cus_{Guid.NewGuid():N}",
GatewaySubscriptionId = $"sub_{Guid.NewGuid():N}"
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Models.Api;
using Bit.Core.Models.BitStripe;
using Bit.Core.Test.Billing.Mocks;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.HttpResults;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
Expand Down Expand Up @@ -159,24 +157,19 @@ public async Task GetInvoices_Ok(
#region GenerateClientInvoiceReportAsync

[Theory, BitAutoData]
public async Task GenerateClientInvoiceReportAsync_NullReportContent_ServerError(
public async Task GenerateClientInvoiceReportAsync_NullReportContent_NotFound(
Provider provider,
string invoiceId,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableProviderAdminInputs(provider, sutProvider);

sutProvider.GetDependency<IProviderBillingService>().GenerateClientInvoiceReport(invoiceId)
sutProvider.GetDependency<IProviderBillingService>().GenerateClientInvoiceReport(provider.Id, invoiceId)
.ReturnsNull();

var result = await sutProvider.Sut.GenerateClientInvoiceReportAsync(provider.Id, invoiceId);

Assert.IsType<JsonHttpResult<ErrorResponseModel>>(result);

var response = (JsonHttpResult<ErrorResponseModel>)result;

Assert.Equal(StatusCodes.Status500InternalServerError, response.StatusCode);
Assert.Equal("We had a problem generating your invoice CSV. Please contact support.", response.Value.Message);
AssertNotFound(result);
}

[Theory, BitAutoData]
Expand All @@ -189,7 +182,7 @@ public async Task GenerateClientInvoiceReportAsync_Ok(

var reportContent = "Report"u8.ToArray();

sutProvider.GetDependency<IProviderBillingService>().GenerateClientInvoiceReport(invoiceId)
sutProvider.GetDependency<IProviderBillingService>().GenerateClientInvoiceReport(provider.Id, invoiceId)
.Returns(reportContent);

var result = await sutProvider.Sut.GenerateClientInvoiceReportAsync(provider.Id, invoiceId);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Scope the provider client invoice report to the authorized provider (VULN-565 / PM-36574).
-- Adds a provider-scoped lookup so a provider can only read invoice items it owns.
CREATE OR ALTER PROCEDURE [dbo].[ProviderInvoiceItem_ReadByProviderIdInvoiceId]
@ProviderId UNIQUEIDENTIFIER,
@InvoiceId VARCHAR (50)
AS
BEGIN
SET NOCOUNT ON

SELECT
*
FROM
[dbo].[ProviderInvoiceItemView]
WHERE
[ProviderId] = @ProviderId
AND [InvoiceId] = @InvoiceId
END
GO
Loading