diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs index ed5513f2b3a0..324a99734cb9 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs @@ -293,11 +293,12 @@ await subscriberService.GetCustomerOrThrow(provider, } public async Task GenerateClientInvoiceReport( + Guid providerId, string invoiceId) { ArgumentException.ThrowIfNullOrEmpty(invoiceId); - var invoiceItems = await providerInvoiceItemRepository.GetByInvoiceId(invoiceId); + var invoiceItems = await providerInvoiceItemRepository.GetByProviderIdAndInvoiceId(providerId, invoiceId); if (invoiceItems.Count == 0) { diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs index 8636fe89db79..ee73c1a426f2 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs @@ -485,23 +485,26 @@ await sutProvider.GetDependency().Received(1).CreateCustomerAsyn [Theory, BitAutoData] public async Task GenerateClientInvoiceReport_NullInvoiceId_ThrowsArgumentNullException( + Guid providerId, SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.GenerateClientInvoiceReport(null)); + await Assert.ThrowsAsync(() => sutProvider.Sut.GenerateClientInvoiceReport(providerId, null)); [Theory, BitAutoData] public async Task GenerateClientInvoiceReport_NoInvoiceItems_ReturnsNull( + Guid providerId, string invoiceId, SutProvider sutProvider) { - sutProvider.GetDependency().GetByInvoiceId(invoiceId).Returns([]); + sutProvider.GetDependency().GetByProviderIdAndInvoiceId(providerId, invoiceId).Returns([]); - var reportContent = await sutProvider.Sut.GenerateClientInvoiceReport(invoiceId); + var reportContent = await sutProvider.Sut.GenerateClientInvoiceReport(providerId, invoiceId); Assert.Null(reportContent); } [Theory, BitAutoData] public async Task GenerateClientInvoiceReport_Succeeds( + Guid providerId, string invoiceId, SutProvider sutProvider) { @@ -520,9 +523,9 @@ public async Task GenerateClientInvoiceReport_Succeeds( } }; - sutProvider.GetDependency().GetByInvoiceId(invoiceId).Returns(invoiceItems); + sutProvider.GetDependency().GetByProviderIdAndInvoiceId(providerId, invoiceId).Returns(invoiceItems); - var reportContent = await sutProvider.Sut.GenerateClientInvoiceReport(invoiceId); + var reportContent = await sutProvider.Sut.GenerateClientInvoiceReport(providerId, invoiceId); using var memoryStream = new MemoryStream(reportContent); diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index dfa705a329b8..b37ef028a443 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -64,11 +64,11 @@ public async Task GenerateClientInvoiceReportAsync([FromRoute] Guid pro return result; } - var reportContent = await providerBillingService.GenerateClientInvoiceReport(invoiceId); + var reportContent = await providerBillingService.GenerateClientInvoiceReport(provider.Id, invoiceId); if (reportContent == null) { - return Error.ServerError("We had a problem generating your invoice CSV. Please contact support."); + return Error.NotFound(); } return TypedResults.File( diff --git a/src/Core/Billing/Providers/Repositories/IProviderInvoiceItemRepository.cs b/src/Core/Billing/Providers/Repositories/IProviderInvoiceItemRepository.cs index 931d8a918633..35b05d138d00 100644 --- a/src/Core/Billing/Providers/Repositories/IProviderInvoiceItemRepository.cs +++ b/src/Core/Billing/Providers/Repositories/IProviderInvoiceItemRepository.cs @@ -7,4 +7,5 @@ public interface IProviderInvoiceItemRepository : IRepository> GetByInvoiceId(string invoiceId); Task> GetByProviderId(Guid providerId); + Task> GetByProviderIdAndInvoiceId(Guid providerId, string invoiceId); } diff --git a/src/Core/Billing/Providers/Services/IProviderBillingService.cs b/src/Core/Billing/Providers/Services/IProviderBillingService.cs index fcbb8bc0ad43..add9e74fe0d3 100644 --- a/src/Core/Billing/Providers/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Providers/Services/IProviderBillingService.cs @@ -37,10 +37,13 @@ Task CreateCustomerForClientOrganization( /// /// Generate a provider's client invoice report in CSV format for the specified . Utilizes the /// records saved for the as part of our webhook processing for the "invoice.created" and "invoice.finalized" Stripe events. + /// The report is scoped to the provided , so a provider can only generate reports for invoices it owns. /// + /// The ID of the that owns the invoice. Only items belonging to this provider are included. /// The ID of the Stripe to generate the report for. /// The provider's client invoice report as a byte array. Task GenerateClientInvoiceReport( + Guid providerId, string invoiceId); Task> GetAddableOrganizations( diff --git a/src/Infrastructure.Dapper/Billing/Repositories/ProviderInvoiceItemRepository.cs b/src/Infrastructure.Dapper/Billing/Repositories/ProviderInvoiceItemRepository.cs index cf5ac07ead5a..e80bf2927813 100644 --- a/src/Infrastructure.Dapper/Billing/Repositories/ProviderInvoiceItemRepository.cs +++ b/src/Infrastructure.Dapper/Billing/Repositories/ProviderInvoiceItemRepository.cs @@ -16,7 +16,7 @@ public class ProviderInvoiceItemRepository( { public async Task> GetByInvoiceId(string invoiceId) { - var sqlConnection = new SqlConnection(ConnectionString); + await using var sqlConnection = new SqlConnection(ConnectionString); var results = await sqlConnection.QueryAsync( "[dbo].[ProviderInvoiceItem_ReadByInvoiceId]", @@ -28,7 +28,7 @@ public async Task> GetByInvoiceId(string invoic public async Task> GetByProviderId(Guid providerId) { - var sqlConnection = new SqlConnection(ConnectionString); + await using var sqlConnection = new SqlConnection(ConnectionString); var results = await sqlConnection.QueryAsync( "[dbo].[ProviderInvoiceItem_ReadByProviderId]", @@ -37,4 +37,16 @@ public async Task> GetByProviderId(Guid provide return results.ToArray(); } + + public async Task> GetByProviderIdAndInvoiceId(Guid providerId, string invoiceId) + { + await using var sqlConnection = new SqlConnection(ConnectionString); + + var results = await sqlConnection.QueryAsync( + "[dbo].[ProviderInvoiceItem_ReadByProviderIdInvoiceId]", + new { ProviderId = providerId, InvoiceId = invoiceId }, + commandType: CommandType.StoredProcedure); + + return results.ToArray(); + } } diff --git a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderInvoiceItemRepository.cs b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderInvoiceItemRepository.cs index ed729070ae99..96d321cfb60b 100644 --- a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderInvoiceItemRepository.cs +++ b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderInvoiceItemRepository.cs @@ -43,4 +43,19 @@ from providerInvoiceItem in databaseContext.ProviderInvoiceItems return await query.ToArrayAsync(); } + + public async Task> GetByProviderIdAndInvoiceId(Guid providerId, string invoiceId) + { + using var serviceScope = ServiceScopeFactory.CreateScope(); + + var databaseContext = GetDatabaseContext(serviceScope); + + var query = + from providerInvoiceItem in databaseContext.ProviderInvoiceItems + where providerInvoiceItem.ProviderId == providerId && + providerInvoiceItem.InvoiceId == invoiceId + select providerInvoiceItem; + + return await query.ToArrayAsync(); + } } diff --git a/src/Sql/dbo/Billing/Stored Procedures/ProviderInvoiceItem_ReadByProviderIdInvoiceId.sql b/src/Sql/dbo/Billing/Stored Procedures/ProviderInvoiceItem_ReadByProviderIdInvoiceId.sql new file mode 100644 index 000000000000..e0c1d58705b7 --- /dev/null +++ b/src/Sql/dbo/Billing/Stored Procedures/ProviderInvoiceItem_ReadByProviderIdInvoiceId.sql @@ -0,0 +1,15 @@ +CREATE PROCEDURE [dbo].[ProviderInvoiceItem_ReadByProviderIdInvoiceId] + @ProviderId UNIQUEIDENTIFIER, + @InvoiceId VARCHAR (50) +AS +BEGIN + SET NOCOUNT ON + + SELECT + * + FROM + [dbo].[ProviderInvoiceItemView] + WHERE + [ProviderId] = @ProviderId + AND [InvoiceId] = @InvoiceId +END diff --git a/test/Api.IntegrationTest/Billing/Controllers/ProviderBillingControllerAuthorizationTests.cs b/test/Api.IntegrationTest/Billing/Controllers/ProviderBillingControllerAuthorizationTests.cs new file mode 100644 index 000000000000..6bdb9628614f --- /dev/null +++ b/test/Api.IntegrationTest/Billing/Controllers/ProviderBillingControllerAuthorizationTests.cs @@ -0,0 +1,147 @@ +using System.Net; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Providers.Entities; +using Bit.Core.Billing.Providers.Repositories; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Api.IntegrationTest.Billing.Controllers; + +/// +/// Integration tests for the provider client invoice CSV endpoint +/// (GET /providers/{providerId}/billing/invoices/{invoiceId}) on ProviderBillingController, +/// focusing on cross-provider authorization. +/// +/// Reproduces VULN-565 (PM-36574): the action authorizes the caller against the route providerId, +/// but the report service loaded ProviderInvoiceItem rows by the attacker-supplied invoiceId without +/// checking the invoice belongs to the authorized provider. A provider admin for provider A could +/// therefore retrieve provider B's client invoice CSV by passing B's invoiceId through A's route. +/// +/// The report must be scoped to the authorized provider, so the victim provider's client billing +/// data must never appear in the attacker's response. +/// +public class ProviderBillingControllerAuthorizationTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private string _attackerAdminEmail = null!; + private Provider _attackerProvider = null!; + private Provider _victimProvider = null!; + private string _victimInvoiceId = null!; + + // Distinctive victim values that must never leak into the attacker's response. + private const string VictimClientName = "Victim Client Organization"; + private const string VictimPlanName = "Enterprise (Annually)"; + + public ProviderBillingControllerAuthorizationTests(ApiApplicationFactory factory) + { + _factory = factory; + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + var userRepository = _factory.GetService(); + var providerRepository = _factory.GetService(); + var providerUserRepository = _factory.GetService(); + var providerInvoiceItemRepository = _factory.GetService(); + + // Attacker: a provider admin of their own billable provider (provider A). + _attackerAdminEmail = $"vuln565-attacker-{Guid.NewGuid()}@test.com"; + await _factory.LoginWithNewAccount(_attackerAdminEmail); + var attackerAdmin = await userRepository.GetByEmailAsync(_attackerAdminEmail); + + _attackerProvider = await CreateBillableProviderAsync(providerRepository, "Attacker Provider"); + await providerUserRepository.CreateAsync(new ProviderUser + { + ProviderId = _attackerProvider.Id, + UserId = attackerAdmin!.Id, + Type = ProviderUserType.ProviderAdmin, + Status = ProviderUserStatusType.Confirmed, + Key = Guid.NewGuid().ToString() + }); + + // Victim: a different provider (provider B) that owns an invoice item. The attacker has no + // membership in this provider. + _victimProvider = await CreateBillableProviderAsync(providerRepository, "Victim Provider"); + _victimInvoiceId = $"in_{Guid.NewGuid():N}"; + await providerInvoiceItemRepository.CreateAsync(new ProviderInvoiceItem + { + ProviderId = _victimProvider.Id, + InvoiceId = _victimInvoiceId, + InvoiceNumber = "INV-VICTIM-001", + ClientId = Guid.NewGuid(), + ClientName = VictimClientName, + PlanName = VictimPlanName, + AssignedSeats = 42, + UsedSeats = 17, + Total = 1234.56m + }); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + /// + /// Control: the attacker cannot use the victim provider's own route — they are not a provider + /// admin of provider B, so authorization rejects the request. Passes before and after the fix; + /// anchors the IDOR test below by confirming the attacker has no legitimate access to provider B. + /// + [Fact] + public async Task GenerateClientInvoiceReport_ThroughVictimProviderRoute_IsUnauthorized() + { + await _loginHelper.LoginAsync(_attackerAdminEmail); + + var response = await _client.GetAsync( + $"providers/{_victimProvider.Id}/billing/invoices/{_victimInvoiceId}"); + + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + } + + /// + /// Reproduces VULN-565: the attacker requests the victim's invoiceId through their OWN provider's + /// route, which passes authorization. The report must be scoped to the authorized provider, so + /// none of the victim provider's client billing data may be returned. + /// + [Fact] + public async Task GenerateClientInvoiceReport_ForInvoiceOwnedByAnotherProvider_DoesNotReturnVictimData() + { + await _loginHelper.LoginAsync(_attackerAdminEmail); + + var response = await _client.GetAsync( + $"providers/{_attackerProvider.Id}/billing/invoices/{_victimInvoiceId}"); + + var body = await response.Content.ReadAsStringAsync(); + + // Core security invariant: the authorized provider does not own this invoice, so the victim + // provider's client billing data must not appear anywhere in the response. + Assert.DoesNotContain(VictimClientName, body); + Assert.DoesNotContain(VictimPlanName, body); + + // The endpoint must not hand back a CSV for an invoice the provider doesn't own; the + // authorized provider has no such invoice, so the scoped lookup yields nothing -> 404. + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + private static Task CreateBillableProviderAsync(IProviderRepository providerRepository, string name) => + providerRepository.CreateAsync(new Provider + { + Name = name, + BillingEmail = $"{name.Replace(" ", "-").ToLowerInvariant()}@example.com", + Type = ProviderType.Msp, + Status = ProviderStatusType.Billable, + Enabled = true, + GatewayCustomerId = $"cus_{Guid.NewGuid():N}", + GatewaySubscriptionId = $"sub_{Guid.NewGuid():N}" + }); +} diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs index 652e82c80154..a8b20e530588 100644 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -12,12 +12,10 @@ using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Models.Api; using Bit.Core.Models.BitStripe; using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.HttpResults; using NSubstitute; using NSubstitute.ReturnsExtensions; @@ -159,24 +157,19 @@ public async Task GetInvoices_Ok( #region GenerateClientInvoiceReportAsync [Theory, BitAutoData] - public async Task GenerateClientInvoiceReportAsync_NullReportContent_ServerError( + public async Task GenerateClientInvoiceReportAsync_NullReportContent_NotFound( Provider provider, string invoiceId, SutProvider sutProvider) { ConfigureStableProviderAdminInputs(provider, sutProvider); - sutProvider.GetDependency().GenerateClientInvoiceReport(invoiceId) + sutProvider.GetDependency().GenerateClientInvoiceReport(provider.Id, invoiceId) .ReturnsNull(); var result = await sutProvider.Sut.GenerateClientInvoiceReportAsync(provider.Id, invoiceId); - Assert.IsType>(result); - - var response = (JsonHttpResult)result; - - Assert.Equal(StatusCodes.Status500InternalServerError, response.StatusCode); - Assert.Equal("We had a problem generating your invoice CSV. Please contact support.", response.Value.Message); + AssertNotFound(result); } [Theory, BitAutoData] @@ -189,7 +182,7 @@ public async Task GenerateClientInvoiceReportAsync_Ok( var reportContent = "Report"u8.ToArray(); - sutProvider.GetDependency().GenerateClientInvoiceReport(invoiceId) + sutProvider.GetDependency().GenerateClientInvoiceReport(provider.Id, invoiceId) .Returns(reportContent); var result = await sutProvider.Sut.GenerateClientInvoiceReportAsync(provider.Id, invoiceId); diff --git a/util/Migrator/DbScripts/2026-06-05_00_AddProviderInvoiceItemReadByProviderIdInvoiceId.sql b/util/Migrator/DbScripts/2026-06-05_00_AddProviderInvoiceItemReadByProviderIdInvoiceId.sql new file mode 100644 index 000000000000..6ae8b2794780 --- /dev/null +++ b/util/Migrator/DbScripts/2026-06-05_00_AddProviderInvoiceItemReadByProviderIdInvoiceId.sql @@ -0,0 +1,18 @@ +-- Scope the provider client invoice report to the authorized provider (VULN-565 / PM-36574). +-- Adds a provider-scoped lookup so a provider can only read invoice items it owns. +CREATE OR ALTER PROCEDURE [dbo].[ProviderInvoiceItem_ReadByProviderIdInvoiceId] + @ProviderId UNIQUEIDENTIFIER, + @InvoiceId VARCHAR (50) +AS +BEGIN + SET NOCOUNT ON + + SELECT + * + FROM + [dbo].[ProviderInvoiceItemView] + WHERE + [ProviderId] = @ProviderId + AND [InvoiceId] = @InvoiceId +END +GO