Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,20 @@ public enum ApiFeature {
*
* <p>Disabled by default.
*/
BILLING_EVENTS_LOGGING("billing-events-logging", false);
BILLING_EVENTS_LOGGING("billing-events-logging", false),

/**
* Billing events response feature flag: if enabled, the API will include the per-request billing
* events as a JSON array on the {@code Billing-Events} HTTP response header. Independent from
* {@link #BILLING_EVENTS_LOGGING} — both can be enabled simultaneously.
*
* <p>Set via {@code stargate.feature.flags.billing-events-response=true} at startup
* (authoritative; request headers cannot disable a startup-enabled flag) or per-request via
* {@code Feature-Flag-billing-events-response} header when not configured at startup.
*
* <p>Disabled by default.
*/
BILLING_EVENTS_RESPONSE("billing-events-response", false);

/**
* Prefix for HTTP headers used to override feature flags for specific requests: prepended before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.stargate.sgv2.jsonapi.config.BillingConfig;
import io.stargate.sgv2.jsonapi.config.feature.ApiFeature;
import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures;
import java.util.List;
import java.util.Objects;

/**
Expand All @@ -12,10 +13,13 @@
* #create(BillingConfig, ApiFeatures)} to pick the right implementation for the request:
*
* <ul>
* <li>{@link DefaultBilling} — when {@link ApiFeature#BILLING_EVENTS_LOGGING} is enabled; emits
* structured JSON log lines on the {@code billing.events} logger.
* <li>{@link #NO_OP} — when the feature is disabled, or in tests / contexts where billing is not
* exercised.
* <li>{@link DefaultBilling} — when {@link ApiFeature#BILLING_EVENTS_LOGGING} and/or {@link
* ApiFeature#BILLING_EVENTS_RESPONSE} is enabled. It emits structured JSON log lines on the
* {@code billing.events} logger (logging flag) and/or buffers events in memory to be returned
* on the {@code Billing-Events} HTTP response header (response flag). The two flags are
* independent — both can be on at once.
* <li>{@link #NO_OP} — when both features are disabled, or in tests / contexts where billing is
* not exercised.
* </ul>
*
* Pass each aggregated {@link ModelUsage} to {@link #emitEvent(ModelUsage)}. {@code modelUsage}
Expand All @@ -30,6 +34,18 @@ public interface Billing {
*/
void emitEvent(ModelUsage modelUsage);

/**
* Returns a snapshot of the billing events buffered by {@link #emitEvent(ModelUsage)} for this
* request. The snapshot is read later by {@code BillingResponseFilter} to populate the {@code
* Billing-Events} response header when {@link ApiFeature#BILLING_EVENTS_RESPONSE} is enabled.
*
* <p>This default implementation buffers nothing and always returns an immutable empty list, so
* implementations that do not buffer (e.g. {@link #NO_OP}) can simply inherit it. Implementations
* that do buffer (see {@link DefaultBilling}) override it and return an unmodifiable copy, so
* callers can iterate safely while other tasks may still be writing.
*
* @return the events collected so far for this request; never {@code null}, empty when nothing
*     was buffered
*/
default List<BillingEvent> collectedEvents() {
// Immutable shared empty list: nothing is buffered by default.
return List.of();
}

/**
* Shared NO-OP {@link Billing}. Still enforces the non-null {@code modelUsage} contract so tests
* (and the feature-disabled production path) don't accidentally mask null-passing bugs in calling
Expand All @@ -40,17 +56,21 @@ public interface Billing {
/**
* Factory that picks the right {@link Billing} implementation for the current request.
* Centralizes the {@code DefaultBilling vs NO_OP} dispatch so callers (e.g. {@link
* io.stargate.sgv2.jsonapi.api.request.RequestContext}) don't have to know the rule. Reads {@code
* config} only when the feature is enabled — when disabled it is fine to pass any value,
* including one that would not validate as a real config.
* io.stargate.sgv2.jsonapi.api.request.RequestContext}) don't have to know the rule. Returns
* {@link DefaultBilling} when {@link ApiFeature#BILLING_EVENTS_LOGGING} and/or {@link
* ApiFeature#BILLING_EVENTS_RESPONSE} is enabled, telling it which sinks to feed; otherwise
* {@link #NO_OP}. Reads {@code config} only when a feature is enabled — when both are disabled it
* is fine to pass any value, including one that would not validate as a real config.
*
* @param config billing configuration; only consulted when the feature is enabled
* @param config billing configuration; only consulted when a feature is enabled
* @param apiFeatures the request's resolved feature set; must not be null
*/
static Billing create(BillingConfig config, ApiFeatures apiFeatures) {
Objects.requireNonNull(apiFeatures, "apiFeatures must not be null");
return apiFeatures.isFeatureEnabled(ApiFeature.BILLING_EVENTS_LOGGING)
? new DefaultBilling(config)
boolean loggingEnabled = apiFeatures.isFeatureEnabled(ApiFeature.BILLING_EVENTS_LOGGING);
boolean responseEnabled = apiFeatures.isFeatureEnabled(ApiFeature.BILLING_EVENTS_RESPONSE);
return (loggingEnabled || responseEnabled)
? new DefaultBilling(config, loggingEnabled, responseEnabled)
: NO_OP;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package io.stargate.sgv2.jsonapi.service.provider;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import io.stargate.sgv2.jsonapi.api.request.RequestContext;
import io.stargate.sgv2.jsonapi.config.feature.ApiFeature;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.ws.rs.container.ContainerResponseContext;
import java.util.List;
import org.jboss.resteasy.reactive.server.ServerResponseFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Response filter that publishes the billing events gathered while serving a request on the {@code
 * Billing-Events} HTTP response header, encoded as a JSON array of {@link BillingEvent}s.
 *
 * <p>The header is only written when {@link ApiFeature#BILLING_EVENTS_RESPONSE} is enabled for the
 * request and at least one event was collected. A serialization failure is logged and the header is
 * omitted, so a serialization problem never breaks the actual API response.
 */
@ApplicationScoped
public class BillingResponseFilter {

  /** Name of the HTTP response header carrying the JSON array of billing events. */
  public static final String BILLING_EVENTS_HEADER = "Billing-Events";

  private static final Logger LOGGER = LoggerFactory.getLogger(BillingResponseFilter.class);

  // One writer shared by every request: ObjectWriter is thread-safe and expensive to build.
  private static final ObjectWriter OBJECT_WRITER = new ObjectMapper().writer();

  private final RequestContext requestContext;

  @Inject
  public BillingResponseFilter(RequestContext requestContext) {
    this.requestContext = requestContext;
  }

  /**
   * Adds the {@code Billing-Events} header to the outgoing response when the response feature flag
   * is on and the request's {@link Billing} has collected events. The response is left untouched
   * (its headers are not even read) in every other case.
   *
   * @param responseContext the response being sent back to the client
   */
  @ServerResponseFilter
  public void addBillingHeader(ContainerResponseContext responseContext) {
    boolean responseFlagOn =
        requestContext.apiFeatures().isFeatureEnabled(ApiFeature.BILLING_EVENTS_RESPONSE);
    if (responseFlagOn) {
      List<BillingEvent> collected = requestContext.billing().collectedEvents();
      if (!collected.isEmpty()) {
        writeHeader(responseContext, collected);
      }
    }
  }

  /** Serializes {@code collected} to JSON and appends it as the billing header; logs on failure. */
  private static void writeHeader(
      ContainerResponseContext responseContext, List<BillingEvent> collected) {
    try {
      var headers = responseContext.getHeaders();
      headers.add(BILLING_EVENTS_HEADER, OBJECT_WRITER.writeValueAsString(collected));
    } catch (JsonProcessingException e) {
      LOGGER.error("Failed to serialize {} billing events to response header", collected.size(), e);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.stargate.sgv2.jsonapi.config.feature.ApiFeature;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
Expand All @@ -18,14 +19,21 @@
import org.slf4j.LoggerFactory;

/**
* {@link Billing} implementation that emits structured JSON log lines on the {@code billing.events}
* logger for downstream billing pipelines.
* {@link Billing} implementation that dispatches each built event to one or both sinks selected at
* construction time:
*
* <p>Construction is driven by {@link
* io.stargate.sgv2.jsonapi.api.request.RequestContext#billing()}: the request context decides
* between this implementation and {@link Billing#NO_OP} based on whether {@link
* ApiFeature#BILLING_EVENTS_LOGGING} is enabled for the request. This class therefore assumes
* billing is enabled and unconditionally emits — it does not re-check the feature flag.
* <ul>
* <li>structured JSON log lines on the {@code billing.events} logger for downstream billing
* pipelines — when {@code loggingEnabled} ({@link ApiFeature#BILLING_EVENTS_LOGGING}).
* <li>an in-memory buffer surfaced via {@link #collectedEvents()} and returned on the {@code
* Billing-Events} HTTP response header — when {@code responseEnabled} ({@link
* ApiFeature#BILLING_EVENTS_RESPONSE}).
* </ul>
*
* <p>Construction is driven by {@link Billing#create} (via {@link
* io.stargate.sgv2.jsonapi.api.request.RequestContext#billing()}), which only picks this
* implementation when at least one of the two flags is enabled and tells it which sinks to feed.
* This class therefore does not re-check the feature flags; if a flag is off, its sink is skipped.
*
* <p>For each {@link ModelUsage}, up to three events are emitted, one per billable metric ({@link
* BillingEventType.Metric#TOTAL_TOKENS TOTAL_TOKENS}, {@link BillingEventType.Metric#EGRESS_BYTES
Expand All @@ -50,35 +58,72 @@ public class DefaultBilling implements Billing {
private final Set<String> internalModelProviders;
private final Set<BillingEventType> enabledEventTypes;

public DefaultBilling(BillingConfig config) {
/** Whether to emit events on the {@code billing.events} logger. */
private final boolean loggingEnabled;

/** Whether to buffer events in {@link #collectedEvents} for the response header. */
private final boolean responseEnabled;

// Events buffered for the BILLING_EVENTS_RESPONSE sink. Populated only when responseEnabled.
// emitEvent can be invoked from concurrent tasks within one request (async embedding / reranking
// calls), so the list is synchronized.
private final List<BillingEvent> collectedEvents =
Collections.synchronizedList(new ArrayList<>());

public DefaultBilling(BillingConfig config, boolean loggingEnabled, boolean responseEnabled) {
Objects.requireNonNull(config, "config must not be null");
this.product = requireNonBlank(config.product(), "billing.product");
this.resourceType = requireNonBlank(config.resourceType(), "billing.resource_type");
this.internalModelProviders = Set.copyOf(config.internalModelProviders());
this.enabledEventTypes =
config.enabledEventTypes().map(Set::copyOf).orElse(BillingEventType.ALL);
this.loggingEnabled = loggingEnabled;
this.responseEnabled = responseEnabled;
}

/**
* Emits billing events for the given aggregated model usage. The {@code billing.events} logger
* level is checked first so we skip event construction when the logger is silenced at runtime.
* Builds billing events for the given aggregated model usage and dispatches them to whichever
* sinks are enabled: the {@code billing.events} logger ({@code loggingEnabled}) and/or the
* in-memory buffer read via {@link #collectedEvents()} ({@code responseEnabled}). The {@code
* billing.events} logger level is also checked so we skip the log sink when the logger is
* silenced at runtime; if no sink is active, event construction is skipped entirely.
*
* @param modelUsage usage data for the model call; must not be null. Callers are expected to
* ensure they have usage data before invoking.
*/
@Override
public void emitEvent(ModelUsage modelUsage) {
Objects.requireNonNull(modelUsage, "modelUsage must not be null");
if (!BILLING_LOGGER.isInfoEnabled()) {
boolean shouldLog = loggingEnabled && BILLING_LOGGER.isInfoEnabled();
if (!shouldLog && !responseEnabled) {
return;
}
for (var event : buildEvents(modelUsage)) {
try {
BILLING_LOGGER.info(OBJECT_WRITER.writeValueAsString(event));
} catch (JacksonException e) {
LOGGER.error("Failed to serialize billing event of type {}", event.eventType(), e);
var events = buildEvents(modelUsage);
if (shouldLog) {
for (var event : events) {
try {
BILLING_LOGGER.info(OBJECT_WRITER.writeValueAsString(event));
} catch (JacksonException e) {
LOGGER.error("Failed to serialize billing event of type {}", event.eventType(), e);
}
}
}
if (responseEnabled) {
collectedEvents.addAll(events);
}
}

/**
 * {@inheritDoc}
 *
 * <p>The result is an unmodifiable copy of whatever has been buffered so far for the {@code
 * BILLING_EVENTS_RESPONSE} sink. Nothing is ever buffered when {@code responseEnabled} is false,
 * so the copy is empty in that case.
 */
@Override
public List<BillingEvent> collectedEvents() {
  final List<BillingEvent> snapshot;
  // Copy while holding the synchronized list's own monitor so concurrent emitEvent calls cannot
  // interleave with the copy.
  synchronized (collectedEvents) {
    snapshot = List.copyOf(collectedEvents);
  }
  return snapshot;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package io.stargate.sgv2.jsonapi.service.provider;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.TestConstants;
import io.stargate.sgv2.jsonapi.api.request.RequestContext;
import io.stargate.sgv2.jsonapi.config.BillingConfig;
import io.stargate.sgv2.jsonapi.config.feature.ApiFeature;
import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures;
import io.stargate.sgv2.jsonapi.config.feature.FeaturesConfig;
import jakarta.ws.rs.container.ContainerResponseContext;
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/**
 * Unit tests for {@link BillingResponseFilter}: the {@code Billing-Events} header must appear only
 * when the response flag is on and events were actually collected.
 */
class BillingResponseFilterTest {

  private static final ObjectMapper MAPPER = new ObjectMapper();
  private static final TestConstants TEST_CONSTANTS = new TestConstants();

  /** A {@link Billing} together with the feature set it was created from. */
  private record BillingAndFeatures(Billing billing, ApiFeatures apiFeatures) {}

  /**
   * Creates a {@link Billing} via {@link Billing#create} from a mocked config, with the two billing
   * feature flags set as requested.
   */
  private static BillingAndFeatures newBillingWith(boolean logging, boolean response) {
    Map<ApiFeature, String> flagValues = new HashMap<>();
    flagValues.put(ApiFeature.BILLING_EVENTS_LOGGING, Boolean.toString(logging));
    flagValues.put(ApiFeature.BILLING_EVENTS_RESPONSE, Boolean.toString(response));
    FeaturesConfig featuresConfig = mock(FeaturesConfig.class);
    when(featuresConfig.flags()).thenReturn(flagValues);
    ApiFeatures resolvedFeatures = ApiFeatures.fromConfigAndRequest(featuresConfig, null);

    BillingConfig billingConfig = mock(BillingConfig.class);
    when(billingConfig.product()).thenReturn("serverless");
    when(billingConfig.resourceType()).thenReturn("serverless_database");
    when(billingConfig.internalModelProviders()).thenReturn(List.of("nvidia"));
    when(billingConfig.enabledEventTypes()).thenReturn(Optional.empty());

    // Go through Billing.create so the test uses the same DefaultBilling / NO_OP dispatch the
    // filter relies on in production (DefaultBilling when either flag is on).
    Billing billing = Billing.create(billingConfig, resolvedFeatures);
    return new BillingAndFeatures(billing, resolvedFeatures);
  }

  /** A fixed embedding-model usage sample to feed into {@link Billing#emitEvent}. */
  private static ModelUsage usage() {
    return new ModelUsage(
        ModelProvider.NVIDIA,
        ModelType.EMBEDDING,
        "test-model",
        TEST_CONSTANTS.TENANT,
        ModelInputType.INDEX,
        10,
        20,
        100,
        200,
        1000L);
  }

  /** Builds the filter under test around a mocked {@link RequestContext}. */
  private static BillingResponseFilter filterFor(Billing billing, ApiFeatures apiFeatures) {
    RequestContext context = mock(RequestContext.class);
    when(context.apiFeatures()).thenReturn(apiFeatures);
    when(context.billing()).thenReturn(billing);
    return new BillingResponseFilter(context);
  }

  /** Mocked response context whose header map is the supplied, inspectable {@code headers}. */
  private static ContainerResponseContext responseContextWithHeaders(
      MultivaluedMap<String, Object> headers) {
    ContainerResponseContext mockedResponse = mock(ContainerResponseContext.class);
    when(mockedResponse.getHeaders()).thenReturn(headers);
    return mockedResponse;
  }

  @Test
  void addsHeaderWhenFeatureOnAndEventsPresent() throws Exception {
    BillingAndFeatures responseOnly = newBillingWith(false, true);
    responseOnly.billing().emitEvent(usage());
    MultivaluedMap<String, Object> headers = new MultivaluedHashMap<>();

    filterFor(responseOnly.billing(), responseOnly.apiFeatures())
        .addBillingHeader(responseContextWithHeaders(headers));

    Object rawHeader = headers.getFirst(BillingResponseFilter.BILLING_EVENTS_HEADER);
    assertThat(rawHeader).isNotNull();
    JsonNode eventsJson = MAPPER.readTree(rawHeader.toString());
    assertThat(eventsJson.isArray()).isTrue();
    assertThat(eventsJson.size()).isEqualTo(3);
    assertThat(eventsJson.get(0).get("event_type").asText())
        .isEqualTo("internal_model_total_tokens");
  }

  @Test
  void skipsHeaderWhenFeatureOff() {
    // LOGGING on, RESPONSE off: events are emitted, but the header must still be absent.
    BillingAndFeatures loggingOnly = newBillingWith(true, false);
    loggingOnly.billing().emitEvent(usage());
    MultivaluedMap<String, Object> headers = new MultivaluedHashMap<>();
    ContainerResponseContext mockedResponse = responseContextWithHeaders(headers);

    filterFor(loggingOnly.billing(), loggingOnly.apiFeatures()).addBillingHeader(mockedResponse);

    assertThat(headers.containsKey(BillingResponseFilter.BILLING_EVENTS_HEADER)).isFalse();
    // The early return means the filter never even asks for the header map.
    verify(mockedResponse, never()).getHeaders();
  }

  @Test
  void skipsHeaderWhenNoEventsCollected() {
    // RESPONSE on but emitEvent never called: the buffer is empty, so no header is written.
    BillingAndFeatures responseOnly = newBillingWith(false, true);
    MultivaluedMap<String, Object> headers = new MultivaluedHashMap<>();

    filterFor(responseOnly.billing(), responseOnly.apiFeatures())
        .addBillingHeader(responseContextWithHeaders(headers));

    assertThat(headers.containsKey(BillingResponseFilter.BILLING_EVENTS_HEADER)).isFalse();
  }
}
Loading