From 38a7b548d95155766f3125fa4da445c47ff14b28 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sat, 28 Mar 2026 16:30:32 +0530 Subject: [PATCH 01/34] feat(maglev): implement arrivals-and-departures-for-location endpoint Port Java arrivals-and-departures-for-location to Go maglev architecture. - Geospatial stop lookup via in-memory R-tree index - 3-day window queries (yesterday/today/tomorrow) for overnight trip support - Batch-fetch routes/trips/stop-times to avoid N+1 queries - Deduplicated references block (Agencies, Routes, Stops, Trips, Situations) - nearbyStopIds via haversine distance, excluding stops already in stopIds - arrivalStatus derived from schedule deviation (LATE/EARLY/ON_TIME/default) - limitExceeded flag honoring maxCount (mirrors Java MaxCountSupport) - Add StopWithDistance model and NewArrivalsAndDeparturesForLocationResponse - Add ORDER BY to GetAgenciesForStops for deterministic results - Fix nil vehicle passed to BuildTripStatus in arrivals_and_departure_for_stop.go - Register GET /api/where/arrivals-and-departures-for-location.json - Add unit, E2E, and context cancellation tests Closes #787 Fixes #799 --- gtfsdb/query.sql | 3 +- gtfsdb/query.sql.go | 1 + internal/models/response.go | 32 + internal/models/stops.go | 8 + .../arrivals_and_departures_for_location.go | 915 ++++++++++++++++++ ...rivals_and_departures_for_location_test.go | 455 +++++++++ ...rrivals_and_departures_for_stop_handler.go | 2 +- internal/restapi/context_cancellation_test.go | 11 +- .../input_validation_integration_test.go | 50 + internal/restapi/routes.go | 1 + 10 files changed, 1474 insertions(+), 4 deletions(-) create mode 100644 internal/restapi/arrivals_and_departures_for_location.go create mode 100644 internal/restapi/arrivals_and_departures_for_location_test.go diff --git a/gtfsdb/query.sql b/gtfsdb/query.sql index 11dfb0d8..a474894f 100644 --- a/gtfsdb/query.sql +++ b/gtfsdb/query.sql @@ -663,7 +663,8 @@ FROM JOIN routes ON trips.route_id = routes.id JOIN agencies a ON routes.agency_id = a.id WHERE - stop_times.stop_id IN (sqlc.slice('stop_ids')); + stop_times.stop_id IN (sqlc.slice('stop_ids')) +ORDER BY a.id, stop_times.stop_id; -- name: GetStopTimesForTrip :many SELECT diff --git a/gtfsdb/query.sql.go b/gtfsdb/query.sql.go index 1dc279b0..a0ddea28 100644 --- a/gtfsdb/query.sql.go +++ b/gtfsdb/query.sql.go @@ -1201,6 +1201,7 @@ FROM JOIN agencies a ON routes.agency_id = a.id WHERE stop_times.stop_id IN (/*SLICE:stop_ids*/?) +ORDER BY a.id, stop_times.stop_id ` type GetAgenciesForStopsRow struct { diff --git a/internal/models/response.go b/internal/models/response.go index 0f3efb20..91676712 100644 --- a/internal/models/response.go +++ b/internal/models/response.go @@ -59,6 +59,38 @@ func NewArrivalsAndDepartureResponse(arrivalsAndDepartures any, references Refer return NewOKResponse(data, c) } +func NewArrivalsAndDeparturesForLocationResponse( + arrivalsAndDepartures []ArrivalAndDeparture, + references ReferencesModel, + nearbyStopIds []StopWithDistance, + situationIds []string, + stopIds []string, + limitExceeded bool, + c clock.Clock, +) ResponseModel { + if nearbyStopIds == nil { + nearbyStopIds = []StopWithDistance{} + } + if situationIds == nil { + situationIds = []string{} + } + if stopIds == nil { + stopIds = []string{} + } + entryData := map[string]interface{}{ + "arrivalsAndDepartures": arrivalsAndDepartures, + "limitExceeded": limitExceeded, + "nearbyStopIds": nearbyStopIds, + "situationIds": situationIds, + "stopIds": stopIds, + } + data := map[string]interface{}{ + "entry": entryData, + "references": references, + } + return NewOKResponse(data, c) +} + // NewResponse creates a standard response using the provided clock. func NewResponse(code int, data any, text string, c clock.Clock) ResponseModel { return ResponseModel{ diff --git a/internal/models/stops.go b/internal/models/stops.go index 9a1e9073..342fe2a6 100644 --- a/internal/models/stops.go +++ b/internal/models/stops.go @@ -34,3 +34,11 @@ type StopsResponse struct { List []Stop `json:"list"` OutOfRange bool `json:"outOfRange"` } + +// StopWithDistance represents a nearby stop together with its distance from the +// centre of the query bounds. It matches the Java StopWithDistanceV2Bean and is +// used by the arrivals-and-departures-for-location endpoint. +type StopWithDistance struct { + StopID string `json:"stopId"` + DistanceFromQuery float64 `json:"distanceFromQuery"` +} diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go new file mode 100644 index 00000000..c0a684e4 --- /dev/null +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -0,0 +1,915 @@ +package restapi + +import ( + "context" + "log/slog" + "net/http" + "sort" + "strconv" + "time" + + "github.com/OneBusAway/go-gtfs" + "maglev.onebusaway.org/gtfsdb" + "maglev.onebusaway.org/internal/models" + "maglev.onebusaway.org/internal/utils" +) + +// ArrivalsAndDeparturesForLocationParams holds all parsed and validated query +// parameters for the arrivals-and-departures-for-location endpoint. +type ArrivalsAndDeparturesForLocationParams struct { + Lat float64 + Lon float64 + Radius float64 + LatSpan float64 + LonSpan float64 + + Time time.Time + MinutesBefore int + MinutesAfter int + + MaxCount int +} + +// parseArrivalsAndDeparturesForLocationParams parses and validates all query +// parameters for this endpoint in one place. +func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) (ArrivalsAndDeparturesForLocationParams, map[string][]string) { + const ( + defaultMinutesBefore = 5 + defaultMinutesAfter = 35 + maxMinutesBefore = 60 + maxMinutesAfter = 240 + defaultMaxCount = 250 + ) + + params := ArrivalsAndDeparturesForLocationParams{ + Time: api.Clock.Now(), + MinutesBefore: defaultMinutesBefore, + MinutesAfter: defaultMinutesAfter, + MaxCount: defaultMaxCount, + } + + var fieldErrors map[string][]string + addError := func(field, msg string) { + if fieldErrors == nil { + fieldErrors = make(map[string][]string) + } + fieldErrors[field] = append(fieldErrors[field], msg) + } + + // Spatial params (required) — reuse the shared location parser. + loc, locErrors := api.parseLocationParams(r, nil) + if len(locErrors) > 0 { + if fieldErrors == nil { + fieldErrors = make(map[string][]string) + } + for k, v := range locErrors { + fieldErrors[k] = append(fieldErrors[k], v...) + } + } else { + params.Lat = loc.Lat + params.Lon = loc.Lon + params.Radius = loc.Radius + params.LatSpan = loc.LatSpan + params.LonSpan = loc.LonSpan + } + + q := r.URL.Query() + + // time + if val := q.Get("time"); val != "" { + if ms, err := strconv.ParseInt(val, 10, 64); err == nil { + params.Time = time.Unix(ms/1000, (ms%1000)*1_000_000) + } else { + addError("time", "must be a valid Unix timestamp in milliseconds") + } + } + + // minutesBefore + if val := q.Get("minutesBefore"); val != "" { + if n, err := strconv.Atoi(val); err != nil { + addError("minutesBefore", "must be a valid integer") + } else if n < 0 { + addError("minutesBefore", "must be a non-negative integer") + } else if n > maxMinutesBefore { + params.MinutesBefore = maxMinutesBefore + } else { + params.MinutesBefore = n + } + } + + // minutesAfter + if val := q.Get("minutesAfter"); val != "" { + if n, err := strconv.Atoi(val); err != nil { + addError("minutesAfter", "must be a valid integer") + } else if n < 0 { + addError("minutesAfter", "must be a non-negative integer") + } else if n > maxMinutesAfter { + params.MinutesAfter = maxMinutesAfter + } else { + params.MinutesAfter = n + } + } + + // maxCount — reuse the shared parser. + var maxCountErrors map[string][]string + params.MaxCount, maxCountErrors = utils.ParseMaxCount(q, defaultMaxCount, nil) + if len(maxCountErrors) > 0 { + if fieldErrors == nil { + fieldErrors = make(map[string][]string) + } + for k, v := range maxCountErrors { + fieldErrors[k] = append(fieldErrors[k], v...) + } + } + + return params, fieldErrors +} + +// arrivalStatusFromDeviation derives a human-readable status string from a +// schedule deviation, matching Java's ArrivalAndDepartureBeanServiceImpl logic. +// +// - deviation > 300s (5+ min late) → "LATE" +// - deviation < -180s (3+ min early) → "EARLY" +// - otherwise → "ON_TIME" +// +// When there is no real-time data the caller should pass "default" directly. +func arrivalStatusFromDeviation(deviationSeconds int) string { + switch { + case deviationSeconds > 300: + return "LATE" + case deviationSeconds < -180: + return "EARLY" + default: + return "ON_TIME" + } +} + +// arrivalsAndDeparturesForLocationHandler returns arrivals and departures for all +// stops within a geographic bounding box (lat/lon + latSpan/lonSpan or radius). +// +// Java equivalent: ArrivalsAndDeparturesForLocationAction.index() +func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + params, fieldErrors := api.parseArrivalsAndDeparturesForLocationParams(r) + if len(fieldErrors) > 0 { + api.validationErrorResponse(w, r, fieldErrors) + return + } + + api.GtfsManager.RLock() + defer api.GtfsManager.RUnlock() + + // Find stops inside the bounding box using the in-memory R-tree spatial index. + // Pass params.Time so that a historical `time=` override is respected. + stops := api.GtfsManager.GetStopsForLocation( + ctx, + params.Lat, params.Lon, + params.Radius, + params.LatSpan, params.LonSpan, + "", + params.MaxCount, + false, + []int{}, + params.Time, + ) + + if len(stops) == 0 { + api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( + []models.ArrivalAndDeparture{}, + *models.NewEmptyReferences(), + []models.StopWithDistance{}, + []string{}, + []string{}, + false, + api.Clock, + )) + return + } + + // Collect raw stop codes (no agency prefix) for batch DB queries. + rawStopCodes := make([]string, 0, len(stops)) + for _, s := range stops { + rawStopCodes = append(rawStopCodes, s.ID) + } + + // Resolve agency for each stop (needed to build combined "agencyId_stopCode" IDs). + agencyRows, err := api.GtfsManager.GtfsDB.Queries.GetAgenciesForStops(ctx, rawStopCodes) + if err != nil { + api.serverErrorResponse(w, r, err) + return + } + // stopCode → agencyID; first agency wins for multi-agency stops. + stopAgencyMap := make(map[string]string, len(agencyRows)) + for _, row := range agencyRows { + if _, exists := stopAgencyMap[row.StopID]; !exists { + stopAgencyMap[row.StopID] = row.ID + } + } + + // fallbackAgencyID is used only when a stop has no entry in stopAgencyMap + // (e.g. a stop with no active routes). Derived from the most common agency + // among the queried stops — never used to prefix alert IDs. + fallbackAgencyID := pickPrimaryAgency(stopAgencyMap) + + // Determine the base query timezone from the fallback agency. + agencyLoc := time.UTC + if fallbackAgencyID != "" { + if ag, tzErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, fallbackAgencyID); tzErr == nil { + if parsed, parseErr := loadAgencyLocation(ag.ID, ag.Timezone); parseErr == nil { + agencyLoc = parsed + } + } + } + + // Fan out: collect arrivals across every stop in the bbox. + arrivals := make([]models.ArrivalAndDeparture, 0, len(stops)*4) + + // Shared reference-collection maps (deduplicated across all stops). + tripIDSet := make(map[string]*gtfsdb.Trip) + routeIDSet := make(map[string]*gtfsdb.Route) + stopIDSet := make(map[string]bool) // raw stop codes for reference building + + // Track which stop codes actually produced at least one arrival. + // Java only includes a stop in the entry's stopIds when it has results. + stopsWithArrivals := make(map[string]bool) + + collectedAlerts := make(map[string]gtfs.Alert) + + limitExceeded := false + + for _, dbStop := range stops { + // Early exit once maxCount is reached — mirrors Java's MaxCountSupport. + if limitExceeded { + break + } + if len(arrivals) >= params.MaxCount { + limitExceeded = true + break + } + + stopCode := dbStop.ID + agencyID := stopAgencyMap[stopCode] + if agencyID == "" { + agencyID = fallbackAgencyID + } + combinedStopID := utils.FormCombinedID(agencyID, stopCode) + stopIDSet[stopCode] = true + + // Per-stop timezone — handles multi-agency feeds where stops may span TZs. + stopLoc := agencyLoc + if agencyID != fallbackAgencyID { + if ag, agErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, agencyID); agErr == nil { + if parsed, parseErr := loadAgencyLocation(ag.ID, ag.Timezone); parseErr == nil { + stopLoc = parsed + } + } + } + + stopQueryTime := params.Time.In(stopLoc) + stopWindowStart := stopQueryTime.Add(-time.Duration(params.MinutesBefore) * time.Minute) + stopWindowEnd := stopQueryTime.Add(time.Duration(params.MinutesAfter) * time.Minute) + + // Query 3 days (yesterday/today/tomorrow) to handle overnight trips — + // identical to the single-stop handler's approach. + type activeStopTime struct { + gtfsdb.GetStopTimesForStopInWindowRow + ServiceDate time.Time + } + var allActiveStopTimes []activeStopTime + + for dayOffset := -1; dayOffset <= 1; dayOffset++ { + if ctx.Err() != nil { + api.clientCanceledResponse(w, r, ctx.Err()) + return + } + + targetDate := stopQueryTime.AddDate(0, 0, dayOffset) + serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, stopLoc) + serviceDateStr := targetDate.Format("20060102") + + activeServiceIDs, svcErr := api.GtfsManager.GetActiveServiceIDsForDateCached(ctx, serviceDateStr) + if svcErr != nil { + api.Logger.Warn("failed to query active service IDs", + slog.String("date", serviceDateStr), + slog.Any("error", svcErr)) + continue + } + if len(activeServiceIDs) == 0 { + continue + } + + activeServiceIDSet := make(map[string]bool, len(activeServiceIDs)) + for _, sid := range activeServiceIDs { + activeServiceIDSet[sid] = true + } + + startNanos := stopWindowStart.Sub(serviceMidnight).Nanoseconds() + endNanos := stopWindowEnd.Sub(serviceMidnight).Nanoseconds() + if endNanos < 0 { + continue + } + + stopTimes, stErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForStopInWindow(ctx, gtfsdb.GetStopTimesForStopInWindowParams{ + StopID: stopCode, + WindowStartNanos: startNanos, + WindowEndNanos: endNanos, + }) + if stErr != nil { + api.Logger.Warn("failed to query stop times in window", + slog.String("stopID", stopCode), + slog.Any("error", stErr)) + continue + } + + for _, st := range stopTimes { + if activeServiceIDSet[st.ServiceID] { + allActiveStopTimes = append(allActiveStopTimes, activeStopTime{ + GetStopTimesForStopInWindowRow: st, + ServiceDate: serviceMidnight, + }) + } + } + } + + if len(allActiveStopTimes) == 0 { + // This stop has no arrivals in the window — do not include it in stopIds. + continue + } + + // Batch-fetch routes & trips for this stop's active stop times. + batchRouteIDs := make(map[string]bool) + batchTripIDs := make(map[string]bool) + for _, ast := range allActiveStopTimes { + if ast.RouteID != "" { + batchRouteIDs[ast.RouteID] = true + } + if ast.TripID != "" { + batchTripIDs[ast.TripID] = true + } + } + + uniqueRouteIDs := stringMapKeys(batchRouteIDs) + uniqueTripIDs := stringMapKeys(batchTripIDs) + + fetchedRoutes, rErr := api.GtfsManager.GtfsDB.Queries.GetRoutesByIDs(ctx, uniqueRouteIDs) + if rErr != nil { + api.Logger.Warn("failed to batch fetch routes", + slog.String("stopID", stopCode), slog.Any("error", rErr)) + continue + } + fetchedTrips, tErr := api.GtfsManager.GtfsDB.Queries.GetTripsByIDs(ctx, uniqueTripIDs) + if tErr != nil { + api.Logger.Warn("failed to batch fetch trips", + slog.String("stopID", stopCode), slog.Any("error", tErr)) + continue + } + + routesLookup := make(map[string]gtfsdb.Route, len(fetchedRoutes)) + for _, rt := range fetchedRoutes { + routesLookup[rt.ID] = rt + } + tripsLookup := make(map[string]gtfsdb.Trip, len(fetchedTrips)) + for _, tr := range fetchedTrips { + tripsLookup[tr.ID] = tr + } + + // Batch total-stop-count per trip (avoids N+1 for totalStopsInTrip field). + tripStopCountMap := make(map[string]int, len(uniqueTripIDs)) + if len(uniqueTripIDs) > 0 { + allST, countErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForTripIDs(ctx, uniqueTripIDs) + if countErr != nil { + api.Logger.Warn("failed to batch fetch stop times for trips", slog.Any("error", countErr)) + } else { + for _, st := range allST { + tripStopCountMap[st.TripID]++ + } + } + } + + // Build one ArrivalAndDeparture per active stop time. + stopProducedArrival := false + for _, ast := range allActiveStopTimes { + // Respect maxCount mid-stop as well. + if len(arrivals) >= params.MaxCount { + limitExceeded = true + break + } + + if ctx.Err() != nil { + api.clientCanceledResponse(w, r, ctx.Err()) + return + } + + st := ast.GetStopTimesForStopInWindowRow + serviceMidnight := ast.ServiceDate + serviceDateMillis := serviceMidnight.UnixMilli() + + route, routeOK := routesLookup[st.RouteID] + if !routeOK { + api.Logger.Debug("skipping stop time: route not found", + slog.String("routeID", st.RouteID), slog.String("tripID", st.TripID)) + continue + } + trip, tripOK := tripsLookup[st.TripID] + if !tripOK { + api.Logger.Debug("skipping stop time: trip not found", + slog.String("tripID", st.TripID)) + continue + } + + rCopy := route + routeIDSet[route.ID] = &rCopy + tCopy := trip + tripIDSet[trip.ID] = &tCopy + + scheduledArrivalTime := serviceMidnight.Add(time.Duration(st.ArrivalTime)).UnixMilli() + scheduledDepartureTime := serviceMidnight.Add(time.Duration(st.DepartureTime)).UnixMilli() + + var ( + predictedArrivalTime int64 + predictedDepartureTime int64 + predicted = false + vehicleID string + tripStatus *models.TripStatus + distanceFromStop = 0.0 + numberOfStopsAway = 0 + lastUpdateTime int64 // always emitted; 0 when no vehicle + + // FIX #4: derive status from schedule deviation instead of + // always emitting "default". Falls back to "default" when + // there is no real-time data. + arrivalStatus = "default" + ) + + vehicle := api.GtfsManager.GetVehicleForTrip(ctx, st.TripID) + if vehicle != nil && vehicle.Trip != nil && vehicle.ID != nil { + vehicleID = vehicle.ID.ID + } + + schedArrTime := serviceMidnight.Add(time.Duration(st.ArrivalTime)) + schedDepTime := serviceMidnight.Add(time.Duration(st.DepartureTime)) + + predArr, predDep, isPredicted := api.getPredictedTimes( + st.TripID, stopCode, int64(st.StopSequence), + schedArrTime, schedDepTime, + ) + if isPredicted { + predicted = true + predictedArrivalTime = predArr + predictedDepartureTime = predDep + } + // When not predicted, leave predictedArrivalTime/predictedDepartureTime as 0 + // (matches Java which emits 0 for unpredicted arrivals). + + // Gate BuildTripStatus on vehicle presence — matches the stop handler convention. + if vehicle != nil { + status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, st.TripID, vehicle, serviceMidnight, stopQueryTime) + if statusErr != nil { + api.Logger.Warn("BuildTripStatus failed", + "tripID", st.TripID, "error", statusErr) + } + if status != nil { + tripStatus = status + + // Only set a meaningful status when the arrival is predicted. + // Unpredicted arrivals stay "default". + if predicted { + arrivalStatus = arrivalStatusFromDeviation(status.ScheduleDeviation) + } + + // Collect stops referenced in trip status for the references block. + if status.NextStop != "" { + if _, nsID, nsErr := utils.ExtractAgencyIDAndCodeID(status.NextStop); nsErr == nil { + stopIDSet[nsID] = true + } + } + if status.ClosestStop != "" { + if _, csID, csErr := utils.ExtractAgencyIDAndCodeID(status.ClosestStop); csErr == nil { + stopIDSet[csID] = true + } + } + + if vehicle.Position != nil { + distanceFromStop = api.getBlockDistanceToStop(ctx, st.TripID, stopCode, vehicle, stopQueryTime) + nsa := api.getNumberOfStopsAway(ctx, st.TripID, int(st.StopSequence), vehicle, stopQueryTime) + if nsa != nil { + numberOfStopsAway = *nsa + } else { + numberOfStopsAway = -1 + } + } + + // Ensure the active trip (if different from scheduled) is in references. + if status.ActiveTripID != "" { + if _, atID, atErr := utils.ExtractAgencyIDAndCodeID(status.ActiveTripID); atErr == nil && atID != st.TripID { + if _, exists := tripIDSet[atID]; !exists { + if at, atFetchErr := api.GtfsManager.GtfsDB.Queries.GetTrip(ctx, atID); atFetchErr == nil { + atCopy := at + tripIDSet[at.ID] = &atCopy + if ar, arFetchErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, at.RouteID); arFetchErr == nil { + arCopy := ar + routeIDSet[ar.ID] = &arCopy + } + } + } + } + } + } + + rawUpdate := api.GtfsManager.GetVehicleLastUpdateTime(vehicle) + if rawUpdate > 0 { + lastUpdateTime = rawUpdate + } + } + + totalStopsInTrip := tripStopCountMap[st.TripID] + blockTripSequence := api.calculateBlockTripSequence(ctx, st.TripID, serviceMidnight) + + // Collect service alerts for this trip. + // situationIDs on each arrival use FormCombinedID(route.AgencyID, alert.ID) + // to match Java's "agencyId_alertId" format (e.g. "40_40_16737"). + tripAlerts := api.GtfsManager.GetAlertsForTrip(ctx, st.TripID) + situationIDs := make([]string, 0, len(tripAlerts)) + for _, alert := range tripAlerts { + if alert.ID == "" { + continue + } + situationIDs = append(situationIDs, utils.FormCombinedID(route.AgencyID, alert.ID)) + if _, seen := collectedAlerts[alert.ID]; !seen { + collectedAlerts[alert.ID] = alert + } + } + + // lastUpdateTimePtr: use a pointer so the model's omitempty drops it only + // when the value is truly absent. We pass nil when lastUpdateTime==0 to + // preserve the existing model behaviour (field omitted rather than 0). + var lastUpdateTimePtr *int64 + if lastUpdateTime > 0 { + lastUpdateTimePtr = utils.Int64Ptr(lastUpdateTime) + } + + // vehicleID must carry the agency prefix to match Java output ("1_6853"). + formattedVehicleID := "" + if vehicleID != "" { + formattedVehicleID = utils.FormCombinedID(route.AgencyID, vehicleID) + } + + // FIX #2: use raw GTFS stop_sequence (1-based) — do NOT subtract 1. + // Java's ArrivalAndDepartureBean.getStopSequence() returns the raw + // GTFS value directly; there is no zero-indexing in the wire format. + rawStopSequence := int(st.StopSequence) + + arrivals = append(arrivals, *models.NewArrivalAndDeparture( + utils.FormCombinedID(route.AgencyID, route.ID), // routeID + route.ShortName.String, // routeShortName + route.LongName.String, // routeLongName + utils.FormCombinedID(route.AgencyID, st.TripID), // tripID + st.TripHeadsign.String, // tripHeadsign + combinedStopID, // stopID + formattedVehicleID, // vehicleID (agency-prefixed or empty) + serviceDateMillis, // serviceDate + scheduledArrivalTime, // scheduledArrivalTime + scheduledDepartureTime, // scheduledDepartureTime + predictedArrivalTime, // predictedArrivalTime (0 when unpredicted) + predictedDepartureTime, // predictedDepartureTime (0 when unpredicted) + lastUpdateTimePtr, // lastUpdateTime + predicted, // predicted + true, // arrivalEnabled + true, // departureEnabled + rawStopSequence, // FIX #2: raw GTFS stop_sequence, not zero-based + totalStopsInTrip, // totalStopsInTrip + numberOfStopsAway, // numberOfStopsAway + blockTripSequence, // blockTripSequence + distanceFromStop, // distanceFromStop + arrivalStatus, // FIX #4: derived from scheduleDeviation + "", // occupancyStatus + "", // predictedOccupancy + "", // historicalOccupancy + tripStatus, // tripStatus + situationIDs, // situationIDs (agency-prefixed) + )) + stopProducedArrival = true + } + + if stopProducedArrival { + stopsWithArrivals[stopCode] = true + } + } + + // Sort arrivals by predicted (or scheduled) arrival time ascending. + // Matches Java's ArrivalAndDepartureComparator. + sort.Slice(arrivals, func(i, j int) bool { + ti := arrivals[i].PredictedArrivalTime + if ti == 0 { + ti = arrivals[i].ScheduledArrivalTime + } + tj := arrivals[j].PredictedArrivalTime + if tj == 0 { + tj = arrivals[j].ScheduledArrivalTime + } + return ti < tj + }) + + // Build references block (agencies, routes, stops, trips, situations). + references := models.NewEmptyReferences() + addedAgencyIDs := make(map[string]bool) + + // Trips + for _, trip := range tripIDSet { + routeForTrip, ok := routeIDSet[trip.RouteID] + if !ok { + fetched, fErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, trip.RouteID) + if fErr != nil { + api.Logger.Warn("failed to fetch route for trip reference", + "tripID", trip.ID, "routeID", trip.RouteID, "error", fErr) + continue + } + fCopy := fetched + routeIDSet[fetched.ID] = &fCopy + routeForTrip = &fCopy + } + references.Trips = append(references.Trips, *models.NewTripReference( + utils.FormCombinedID(routeForTrip.AgencyID, trip.ID), + utils.FormCombinedID(routeForTrip.AgencyID, trip.RouteID), + utils.FormCombinedID(routeForTrip.AgencyID, trip.ServiceID), + trip.TripHeadsign.String, + "", + strconv.FormatInt(trip.DirectionID.Int64, 10), + utils.FormCombinedID(routeForTrip.AgencyID, trip.BlockID.String), + utils.FormCombinedID(routeForTrip.AgencyID, trip.ShapeID.String), + )) + } + + // Routes + their agencies. + for _, route := range routeIDSet { + references.Routes = append(references.Routes, models.NewRoute( + utils.FormCombinedID(route.AgencyID, route.ID), + route.AgencyID, + route.ShortName.String, + route.LongName.String, + route.Desc.String, + models.RouteType(route.Type), + route.Url.String, + route.Color.String, + route.TextColor.String, + )) + if !addedAgencyIDs[route.AgencyID] { + ag, agErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, route.AgencyID) + if agErr == nil { + references.Agencies = append(references.Agencies, models.NewAgencyReference( + ag.ID, ag.Name, ag.Url, ag.Timezone, ag.Lang.String, + ag.Phone.String, ag.Email.String, ag.FareUrl.String, "", false, + )) + addedAgencyIDs[ag.ID] = true + } else { + api.Logger.Warn("failed to fetch agency for reference", + "agencyID", route.AgencyID, "error", agErr) + } + } + } + + // Stops (queried stops + nextStop/closestStop referenced by TripStatus). + stopIDsSlice := make([]string, 0, len(stopIDSet)) + for sid := range stopIDSet { + stopIDsSlice = append(stopIDsSlice, sid) + } + + batchStops, bsErr := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDsSlice) + if bsErr != nil { + api.Logger.Warn("failed to batch fetch stop references", slog.Any("error", bsErr)) + } + batchRoutesForStops, brsErr := api.GtfsManager.GtfsDB.Queries.GetRoutesForStops(ctx, stopIDsSlice) + if brsErr != nil { + api.Logger.Warn("failed to batch fetch routes for stops", slog.Any("error", brsErr)) + } + + stopsMap := make(map[string]gtfsdb.Stop, len(batchStops)) + for _, s := range batchStops { + stopsMap[s.ID] = s + } + routesByStop := make(map[string][]gtfsdb.GetRoutesForStopsRow) + for _, row := range batchRoutesForStops { + routesByStop[row.StopID] = append(routesByStop[row.StopID], row) + } + + for _, sid := range stopIDsSlice { + if ctx.Err() != nil { + api.clientCanceledResponse(w, r, ctx.Err()) + return + } + stopData, ok := stopsMap[sid] + if !ok { + continue + } + ag := stopAgencyMap[sid] + if ag == "" { + ag = fallbackAgencyID + } + routesForStop := routesByStop[sid] + combinedRouteIDs := make([]string, len(routesForStop)) + for i, rr := range routesForStop { + combinedRouteIDs[i] = utils.FormCombinedID(rr.AgencyID, rr.ID) + if _, exists := routeIDSet[rr.ID]; !exists { + rc := gtfsdb.Route{ + ID: rr.ID, + AgencyID: rr.AgencyID, + ShortName: rr.ShortName, + LongName: rr.LongName, + Desc: rr.Desc, + Type: rr.Type, + Url: rr.Url, + Color: rr.Color, + TextColor: rr.TextColor, + } + routeIDSet[rr.ID] = &rc + } + } + references.Stops = append(references.Stops, models.Stop{ + ID: utils.FormCombinedID(ag, stopData.ID), + Name: stopData.Name.String, + Lat: stopData.Lat, + Lon: stopData.Lon, + Code: stopData.Code.String, + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stopData.ID, stopData.Direction), + LocationType: int(stopData.LocationType.Int64), + WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), + RouteIDs: combinedRouteIDs, + StaticRouteIDs: combinedRouteIDs, + }) + } + + // Collect stop-level service alerts. + // These fall back to fallbackAgencyID for the agency prefix since there is + // no route context available at the stop level. + for _, sc := range rawStopCodes { + for _, alert := range api.GtfsManager.GetAlertsForStop(sc) { + if alert.ID != "" { + if _, seen := collectedAlerts[alert.ID]; !seen { + collectedAlerts[alert.ID] = alert + } + } + } + } + + // Build situation references and top-level situationIds. + // Entry-level situationIds use the raw alert ID (e.g. "1_85725", "40_16559"). + // Alert IDs from GTFS-RT already contain the agency prefix, so no extra + // FormCombinedID wrapping is applied here. + // Per-arrival situationIds DO wrap with FormCombinedID — that is separate. + topLevelSituationIDs := make([]string, 0, len(collectedAlerts)) + if len(collectedAlerts) > 0 { + alertSlice := make([]gtfs.Alert, 0, len(collectedAlerts)) + for _, a := range collectedAlerts { + alertSlice = append(alertSlice, a) + } + references.Situations = append(references.Situations, api.BuildSituationReferences(alertSlice)...) + for alertID := range collectedAlerts { + topLevelSituationIDs = append(topLevelSituationIDs, alertID) + } + } + + // Build the entry's stopIds — only stops that produced at least one arrival, + // in the same order as the original stops slice (deterministic, not map order). + // Java: stops are only added when !arrivalsAndDepartures.isEmpty(). + queriedStopIDs := make([]string, 0, len(stopsWithArrivals)) + for _, dbStop := range stops { + if stopsWithArrivals[dbStop.ID] { + ag := stopAgencyMap[dbStop.ID] + if ag == "" { + ag = fallbackAgencyID + } + queriedStopIDs = append(queriedStopIDs, utils.FormCombinedID(ag, dbStop.ID)) + } + } + + // Build nearbyStopIds as []StopWithDistance. + // + // FIX #3: pass queriedStopIDs so that stops already in the entry's stopIds + // are excluded from nearbyStopIds. Java's includeInputIdsInNearby=false + // default means the bbox stops must not appear in both lists. + nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon, params.Time, queriedStopIDs) + + api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( + arrivals, + *references, + nearbyStops, + topLevelSituationIDs, + queriedStopIDs, + limitExceeded, + api.Clock, + )) +} + +// getLocationNearbyStops returns stops near the query centre together with their +// distance from the centre, sorted ascending by distance. +// +// Java equivalent: getNearbyStops() in StopWithArrivalsAndDeparturesBeanServiceImpl, +// which calls SphericalGeometryLibrary.distance() to populate distanceFromQuery. +// +// FIX #3: queriedStopIDs are excluded from the result so that stops already +// present in entry.stopIds do not also appear in entry.nearbyStopIds. +// This matches Java's includeInputIdsInNearby=false default behaviour. +func getLocationNearbyStops( + api *RestAPI, + ctx context.Context, + centerLat, centerLon float64, + queryTime time.Time, + queriedStopIDs []string, // stops already in entry.stopIds — must be excluded +) []models.StopWithDistance { + + nearby := api.GtfsManager.GetStopsForLocation( + ctx, centerLat, centerLon, 100, 0, 0, "", 250, false, []int{}, queryTime, + ) + + if len(nearby) == 0 { + return nil + } + + // Batch-resolve owning agency for each nearby stop. + candidateIDs := make([]string, len(nearby)) + for i, s := range nearby { + candidateIDs[i] = s.ID + } + + nearbyAgencyMap := make(map[string]string, len(candidateIDs)) + agencyRows, err := api.GtfsManager.GtfsDB.Queries.GetAgenciesForStops(ctx, candidateIDs) + if err != nil { + api.Logger.Warn("failed to resolve agencies for nearby stops", "error", err) + } else { + for _, row := range agencyRows { + if _, exists := nearbyAgencyMap[row.StopID]; !exists { + nearbyAgencyMap[row.StopID] = row.ID + } + } + } + + // pickPrimaryAgency over the nearby set for stops with no resolved agency. + nearbyFallback := pickPrimaryAgency(nearbyAgencyMap) + + // Build a set of already-queried combined stop IDs for O(1) lookup. + // FIX #3: Java excludes these via includeInputIdsInNearby=false. + queriedSet := make(map[string]bool, len(queriedStopIDs)) + for _, id := range queriedStopIDs { + queriedSet[id] = true + } + + result := make([]models.StopWithDistance, 0, len(nearby)) + for _, s := range nearby { + ag := nearbyFallback + if resolved, ok := nearbyAgencyMap[s.ID]; ok { + ag = resolved + } + combinedID := utils.FormCombinedID(ag, s.ID) + + // FIX #3: skip stops that are already in entry.stopIds. + if queriedSet[combinedID] { + continue + } + + dist := utils.Distance(centerLat, centerLon, s.Lat, s.Lon) + result = append(result, models.StopWithDistance{ + StopID: combinedID, + DistanceFromQuery: dist, + }) + } + + if len(result) == 0 { + return nil + } + + // Sort by distance ascending to match Java's ordering. + sort.Slice(result, func(i, j int) bool { + return result[i].DistanceFromQuery < result[j].DistanceFromQuery + }) + + return result +} + +// stringMapKeys returns the keys of a map[string]bool as a string slice. +func stringMapKeys(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// pickPrimaryAgency returns the agency ID that appears most frequently in the +// stopCode→agencyID map. Used only as a fallback when a stop has no resolved +// agency — never used to prefix alert IDs directly. +func pickPrimaryAgency(stopAgencyMap map[string]string) string { + counts := make(map[string]int, 4) + for _, ag := range stopAgencyMap { + counts[ag]++ + } + best := "" + bestCount := 0 + for ag, cnt := range counts { + if cnt > bestCount || (cnt == bestCount && ag < best) { + best = ag + bestCount = cnt + } + } + return best +} diff --git a/internal/restapi/arrivals_and_departures_for_location_test.go b/internal/restapi/arrivals_and_departures_for_location_test.go new file mode 100644 index 00000000..c43ba3ff --- /dev/null +++ b/internal/restapi/arrivals_and_departures_for_location_test.go @@ -0,0 +1,455 @@ +package restapi + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maglev.onebusaway.org/internal/clock" +) + +// --- Param parsing unit tests --- + +func TestParseArrivalsAndDeparturesForLocationParams_Defaults(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008", nil) + params, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.Nil(t, errs) + assert.Equal(t, 47.653, params.Lat) + assert.Equal(t, -122.307, params.Lon) + assert.Equal(t, 0.008, params.LatSpan) + assert.Equal(t, 0.008, params.LonSpan) + assert.Equal(t, 5, params.MinutesBefore) + assert.Equal(t, 35, params.MinutesAfter) + assert.Equal(t, 250, params.MaxCount) + assert.WithinDuration(t, api.Clock.Now(), params.Time, time.Second) +} + +func TestParseArrivalsAndDeparturesForLocationParams_CustomValues(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&radius=500&minutesBefore=10&minutesAfter=60&maxCount=50&time=1609459200000", nil) + params, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.Nil(t, errs) + assert.Equal(t, 47.653, params.Lat) + assert.Equal(t, -122.307, params.Lon) + assert.Equal(t, 500.0, params.Radius) + assert.Equal(t, 10, params.MinutesBefore) + assert.Equal(t, 60, params.MinutesAfter) + assert.Equal(t, 50, params.MaxCount) + assert.False(t, params.Time.IsZero()) +} + +func TestParseArrivalsAndDeparturesForLocationParams_MissingLatLon(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", "/test", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "lat") + assert.Contains(t, errs, "lon") +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidTime(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&time=notanumber", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "time") + assert.Equal(t, "must be a valid Unix timestamp in milliseconds", errs["time"][0]) +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidMinutesAfter(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&minutesAfter=notanumber", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "minutesAfter") + assert.Equal(t, "must be a valid integer", errs["minutesAfter"][0]) +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidMinutesBefore(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&minutesBefore=notanumber", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "minutesBefore") + assert.Equal(t, "must be a valid integer", errs["minutesBefore"][0]) +} + +func TestParseArrivalsAndDeparturesForLocationParams_NegativeMinutes(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&minutesBefore=-1&minutesAfter=-5", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "minutesBefore") + assert.Contains(t, errs, "minutesAfter") +} + +func TestParseArrivalsAndDeparturesForLocationParams_MinutesCappedAtMax(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&minutesBefore=9999&minutesAfter=9999", nil) + params, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.Nil(t, errs) + assert.Equal(t, 60, params.MinutesBefore) + assert.Equal(t, 240, params.MinutesAfter) +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidMaxCount(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&maxCount=0", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "maxCount") +} + +// --- HTTP handler integration tests --- + +func TestArrivalsAndDeparturesForLocationRequiresValidAPIKey(t *testing.T) { + _, resp, model := serveAndRetrieveEndpoint(t, + "/api/where/arrivals-and-departures-for-location.json?key=invalid&lat=40.583321&lon=-122.426966&latSpan=0.01&lonSpan=0.01") + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, http.StatusUnauthorized, model.Code) + assert.Equal(t, "permission denied", model.Text) +} + +func TestArrivalsAndDeparturesForLocationMissingLatLon(t *testing.T) { + _, resp, _ := serveAndRetrieveEndpoint(t, + "/api/where/arrivals-and-departures-for-location.json?key=TEST") + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestArrivalsAndDeparturesForLocationInvalidTime(t *testing.T) { + _, resp, _ := serveAndRetrieveEndpoint(t, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&latSpan=0.01&lonSpan=0.01&time=notanumber") + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestArrivalsAndDeparturesForLocationEmptyAreaReturnsOK(t *testing.T) { + // Coordinates far from any test GTFS data so no stops are found. + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=0.0&lon=0.0&latSpan=0.001&lonSpan=0.001") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 200, model.Code) + assert.Equal(t, "OK", model.Text) + assert.Equal(t, 2, model.Version) + assert.NotZero(t, model.CurrentTime) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + // Entry must contain all expected keys even when empty. + assert.Contains(t, entry, "arrivalsAndDepartures") + assert.Contains(t, entry, "stopIds") + assert.Contains(t, entry, "nearbyStopIds") + assert.Contains(t, entry, "situationIds") + assert.Contains(t, entry, "limitExceeded") + + ads, ok := entry["arrivalsAndDepartures"].([]interface{}) + require.True(t, ok) + assert.Empty(t, ads) + + stopIDs, ok := entry["stopIds"].([]interface{}) + require.True(t, ok) + assert.Empty(t, stopIDs) + + assert.False(t, entry["limitExceeded"].(bool)) +} + +func TestArrivalsAndDeparturesForLocationEndToEnd(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 200, model.Code) + assert.Equal(t, "OK", model.Text) + assert.Equal(t, 2, model.Version) + assert.NotZero(t, model.CurrentTime) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok, "data should be a map") + + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok, "entry should be a map") + + // Required entry keys. + assert.Contains(t, entry, "arrivalsAndDepartures") + assert.Contains(t, entry, "stopIds") + assert.Contains(t, entry, "nearbyStopIds") + assert.Contains(t, entry, "situationIds") + assert.Contains(t, entry, "limitExceeded") + + // nearbyStopIds must be a list of objects with stopId + distanceFromQuery. + nearbyRaw, ok := entry["nearbyStopIds"].([]interface{}) + require.True(t, ok, "nearbyStopIds should be a list") + for _, item := range nearbyRaw { + nearby, ok := item.(map[string]interface{}) + require.True(t, ok, "each nearbyStopIds entry should be an object") + assert.Contains(t, nearby, "stopId") + assert.Contains(t, nearby, "distanceFromQuery") + } + + // stopIds must be a list. + stopIDs, ok := entry["stopIds"].([]interface{}) + require.True(t, ok, "stopIds should be a list") + assert.NotEmpty(t, stopIDs, "should have found stops in this area") + + // References block. + refs, ok := data["references"].(map[string]interface{}) + require.True(t, ok, "references should be a map") + assert.Contains(t, refs, "agencies") + assert.Contains(t, refs, "routes") + assert.Contains(t, refs, "stops") + assert.Contains(t, refs, "trips") + assert.Contains(t, refs, "situations") + + // Validate arrival shape if any were returned. + ads, ok := entry["arrivalsAndDepartures"].([]interface{}) + require.True(t, ok, "arrivalsAndDepartures should be a list") + + if len(ads) == 0 { + t.Skip("no arrivals in test data for this time/location") + } + + ad, ok := ads[0].(map[string]interface{}) + require.True(t, ok, "first arrival should be a map") + + // Required arrival fields. + for _, field := range []string{ + "routeId", "tripId", "stopId", "serviceDate", + "scheduledArrivalTime", "scheduledDepartureTime", + "predictedArrivalTime", "predictedDepartureTime", + "predicted", "status", "situationIds", + "routeShortName", "tripHeadsign", + "arrivalEnabled", "departureEnabled", + "numberOfStopsAway", "distanceFromStop", + "blockTripSequence", "totalStopsInTrip", + "frequency", + } { + assert.Contains(t, ad, field, "arrival must contain field %q", field) + } + + assert.Equal(t, "default", ad["status"]) + + // Every arrival's stopId must be one of the queried stopIds. + stopIDInAD, _ := ad["stopId"].(string) + assert.NotEmpty(t, stopIDInAD) + assert.Contains(t, stopIDs, stopIDInAD, + "arrival stopId should be one of the queried stopIds") +} + +func TestArrivalsAndDeparturesForLocationStopIdsOnlyContainsStopsWithArrivals(t *testing.T) { + // Java only includes a stop in stopIds when it has at least one arrival. + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + ads, _ := entry["arrivalsAndDepartures"].([]interface{}) + stopIDs, _ := entry["stopIds"].([]interface{}) + + if len(ads) == 0 { + // No arrivals → stopIds must also be empty. + assert.Empty(t, stopIDs, "stopIds must be empty when there are no arrivals") + return + } + + // Every stopId in the entry must appear in at least one arrival's stopId field. + arrivalStopIDs := make(map[interface{}]bool) + for _, adRaw := range ads { + if ad, ok := adRaw.(map[string]interface{}); ok { + arrivalStopIDs[ad["stopId"]] = true + } + } + for _, sid := range stopIDs { + assert.True(t, arrivalStopIDs[sid], + "stopId %v in entry.stopIds has no matching arrival", sid) + } +} + +func TestArrivalsAndDeparturesForLocationWithLatSpanLonSpan(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&latSpan=0.045&lonSpan=0.059") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 200, model.Code) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + assert.Contains(t, entry, "stopIds") + assert.Contains(t, entry, "arrivalsAndDepartures") + assert.Contains(t, entry, "limitExceeded") +} + +func TestArrivalsAndDeparturesForLocationReferencesConsistency(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + refs, ok := data["references"].(map[string]interface{}) + require.True(t, ok) + + ads, _ := entry["arrivalsAndDepartures"].([]interface{}) + if len(ads) == 0 { + t.Skip("no arrivals in test data for this location") + } + + routeRefs, _ := refs["routes"].([]interface{}) + tripRefs, _ := refs["trips"].([]interface{}) + agencies, _ := refs["agencies"].([]interface{}) + + routeRefIDs := collectAllIdsFromObjects(t, routeRefs, "id") + tripRefIDs := collectAllIdsFromObjects(t, tripRefs, "id") + agencyRefIDs := collectAllIdsFromObjects(t, agencies, "id") + + // Every arrival's routeId and tripId must appear in references. + for _, adRaw := range ads { + ad, ok := adRaw.(map[string]interface{}) + require.True(t, ok) + + routeID, _ := ad["routeId"].(string) + assert.Contains(t, routeRefIDs, routeID, + "every arrival routeId must appear in references.routes") + + tripID, _ := ad["tripId"].(string) + assert.Contains(t, tripRefIDs, tripID, + "every arrival tripId must appear in references.trips") + } + + // Every route's agencyId must appear in references.agencies. + agencyIDsFromRoutes := collectAllIdsFromObjects(t, routeRefs, "agencyId") + for _, aid := range agencyIDsFromRoutes { + assert.Contains(t, agencyRefIDs, aid, + "every route agencyId must appear in references.agencies") + } +} + +func TestArrivalsAndDeparturesForLocationArrivalsAreSortedByTime(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + ads, _ := entry["arrivalsAndDepartures"].([]interface{}) + if len(ads) < 2 { + t.Skip("need at least 2 arrivals to test sort order") + } + + var prevTime float64 + for i, adRaw := range ads { + ad, ok := adRaw.(map[string]interface{}) + require.True(t, ok) + + predicted, _ := ad["predicted"].(bool) + var arrTime float64 + if predicted { + arrTime, _ = ad["predictedArrivalTime"].(float64) + } + if arrTime == 0 { + arrTime, _ = ad["scheduledArrivalTime"].(float64) + } + + if i > 0 { + assert.GreaterOrEqual(t, arrTime, prevTime, + "arrivals must be sorted ascending by arrival time (index %d)", i) + } + prevTime = arrTime + } +} + +func TestArrivalsAndDeparturesForLocationLimitExceeded(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + // maxCount=1 forces limitExceeded=true if there is more than 1 arrival. + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500&maxCount=1") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + ads, _ := entry["arrivalsAndDepartures"].([]interface{}) + assert.LessOrEqual(t, len(ads), 1) +} diff --git a/internal/restapi/arrivals_and_departures_for_stop_handler.go b/internal/restapi/arrivals_and_departures_for_stop_handler.go index 3451a2eb..1e218bee 100644 --- a/internal/restapi/arrivals_and_departures_for_stop_handler.go +++ b/internal/restapi/arrivals_and_departures_for_stop_handler.go @@ -339,7 +339,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r if vehicle != nil { // Use route.AgencyID instead of stopAgencyID for BuildTripStatus - status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, st.TripID, nil, serviceMidnight, params.Time) + status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, st.TripID, vehicle, serviceMidnight, params.Time) if statusErr != nil { api.Logger.Warn("BuildTripStatus failed for arrival", "tripID", st.TripID, "error", statusErr) diff --git a/internal/restapi/context_cancellation_test.go b/internal/restapi/context_cancellation_test.go index 0ce8124b..386326de 100644 --- a/internal/restapi/context_cancellation_test.go +++ b/internal/restapi/context_cancellation_test.go @@ -41,6 +41,11 @@ func TestContextCancellationHandling(t *testing.T) { endpoint: "/api/where/stops-for-location.json?lat=38.9&lon=-77.0&key=test", timeout: 1 * time.Nanosecond, }, + { + name: "arrivals and departures for location should handle context cancellation", + endpoint: "/api/where/arrivals-and-departures-for-location.json?lat=38.9&lon=-77.0&latSpan=0.01&lonSpan=0.01&key=test", + timeout: 1 * time.Nanosecond, + }, { name: "stops for route should handle context cancellation", endpoint: "/api/where/stops-for-route/1?key=test", @@ -76,15 +81,17 @@ func TestContextCancellationHandling(t *testing.T) { // If cancelled, we expect a timeout or cancellation error response statusCode := w.Code - // Valid responses: 200 (completed), 401 (API validation), 500 (error), or timeout-related + // Valid responses: 200 (completed), 401 (API validation), 500 (error), timeout-related, + // or 429 (rate limit — valid when many sub-tests exhaust the rate limiter). assert.True(t, statusCode == http.StatusOK || statusCode == http.StatusUnauthorized || // API key validation happens first statusCode == http.StatusBadRequest || statusCode == http.StatusInternalServerError || statusCode == http.StatusRequestTimeout || statusCode == http.StatusGatewayTimeout || + statusCode == http.StatusTooManyRequests || // rate limit is valid under load statusCode == http.StatusNotFound, - "Expected status 200, 401, 404, 500, 408, or 504, got %d", statusCode) + "Expected status 200, 401, 404, 429, 500, 408, or 504, got %d", statusCode) }) } } diff --git a/internal/restapi/input_validation_integration_test.go b/internal/restapi/input_validation_integration_test.go index d7d8a269..e635213e 100644 --- a/internal/restapi/input_validation_integration_test.go +++ b/internal/restapi/input_validation_integration_test.go @@ -157,6 +157,44 @@ func TestInputValidationIntegration(t *testing.T) { expectedStatus: http.StatusBadRequest, expectedError: "invalid date format", }, + + // Test arrivals-and-departures-for-location parameter validation + { + name: "arrivals-for-location: invalid latitude too high", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=91.0&lon=-77.0&radius=500", + expectedStatus: http.StatusBadRequest, + expectedError: "latitude must be between -90 and 90", + }, + { + name: "arrivals-for-location: invalid longitude too high", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.0&lon=181.0&radius=500", + expectedStatus: http.StatusBadRequest, + expectedError: "longitude must be between -180 and 180", + }, + { + name: "arrivals-for-location: missing lat and lon", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST", + expectedStatus: http.StatusBadRequest, + expectedError: "", + }, + { + name: "arrivals-for-location: invalid minutesAfter", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=500&minutesAfter=notanumber", + expectedStatus: http.StatusBadRequest, + expectedError: "must be a valid integer", + }, + { + name: "arrivals-for-location: invalid minutesBefore", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=500&minutesBefore=notanumber", + expectedStatus: http.StatusBadRequest, + expectedError: "must be a valid integer", + }, + { + name: "arrivals-for-location: invalid time", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=500&time=notanumber", + expectedStatus: http.StatusBadRequest, + expectedError: "must be a valid Unix timestamp in milliseconds", + }, } for _, tt := range tests { @@ -253,6 +291,18 @@ func TestValidInputsPassThrough(t *testing.T) { name: "Valid location with span parameters", endpoint: "/api/where/stops-for-location.json?key=TEST&lat=38.9&lon=-77.0&latSpan=0.01&lonSpan=0.01", }, + { + name: "Valid arrivals-for-location with radius", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=1000", + }, + { + name: "Valid arrivals-for-location with latSpan and lonSpan", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&latSpan=0.01&lonSpan=0.01", + }, + { + name: "Valid arrivals-for-location with custom time window", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=1000&minutesBefore=2&minutesAfter=20", + }, } for _, tt := range validTests { diff --git a/internal/restapi/routes.go b/internal/restapi/routes.go index 82c9de78..4a76e917 100644 --- a/internal/restapi/routes.go +++ b/internal/restapi/routes.go @@ -117,6 +117,7 @@ func (api *RestAPI) SetRoutes(mux *http.ServeMux) { mux.Handle("GET /api/where/arrival-and-departure-for-stop/{id}", CacheControlMiddleware(models.CacheDurationShort, rateLimitAndValidateAPIKey(api, api.arrivalAndDepartureForStopHandler))) mux.Handle("GET /api/where/trips-for-route/{id}", CacheControlMiddleware(models.CacheDurationShort, rateLimitAndValidateAPIKey(api, api.tripsForRouteHandler))) mux.Handle("GET /api/where/arrivals-and-departures-for-stop/{id}", CacheControlMiddleware(models.CacheDurationShort, rateLimitAndValidateAPIKey(api, api.arrivalsAndDeparturesForStopHandler))) + mux.Handle("GET /api/where/arrivals-and-departures-for-location.json", CacheControlMiddleware(models.CacheDurationShort, rateLimitAndValidateAPIKey(api, api.arrivalsAndDeparturesForLocationHandler))) } // SetupAPIRoutes creates and configures the API router with all middleware applied globally From 7c79faf0a4b388415e4fb8b5c378f4e58466b20f Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 29 Mar 2026 01:42:19 +0530 Subject: [PATCH 02/34] fix double prefix in situation id --- internal/restapi/arrivals_and_departures_for_location.go | 8 ++++---- .../restapi/arrivals_and_departures_for_stop_handler.go | 4 ++-- internal/restapi/trips_helper.go | 6 +----- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index c0a684e4..11dd7f6b 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -526,16 +526,16 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite totalStopsInTrip := tripStopCountMap[st.TripID] blockTripSequence := api.calculateBlockTripSequence(ctx, st.TripID, serviceMidnight) - // Collect service alerts for this trip. - // situationIDs on each arrival use FormCombinedID(route.AgencyID, alert.ID) - // to match Java's "agencyId_alertId" format (e.g. "40_40_16737"). + // alert.ID from GTFS-RT already contains the agency prefix (e.g. "40_16931"). + // Do NOT wrap with FormCombinedID — that would double-prefix to "40_40_16931". + // Both per-arrival situationIds and top-level situationIds use the raw alert.ID. tripAlerts := api.GtfsManager.GetAlertsForTrip(ctx, st.TripID) situationIDs := make([]string, 0, len(tripAlerts)) for _, alert := range tripAlerts { if alert.ID == "" { continue } - situationIDs = append(situationIDs, utils.FormCombinedID(route.AgencyID, alert.ID)) + situationIDs = append(situationIDs, alert.ID) if _, seen := collectedAlerts[alert.ID]; !seen { collectedAlerts[alert.ID] = alert } diff --git a/internal/restapi/arrivals_and_departures_for_stop_handler.go b/internal/restapi/arrivals_and_departures_for_stop_handler.go index 1e218bee..a030a54d 100644 --- a/internal/restapi/arrivals_and_departures_for_stop_handler.go +++ b/internal/restapi/arrivals_and_departures_for_stop_handler.go @@ -417,7 +417,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r continue } - situationIDs = append(situationIDs, utils.FormCombinedID(route.AgencyID, alert.ID)) + situationIDs = append(situationIDs, alert.ID) if _, seen := collectedAlerts[alert.ID]; !seen { collectedAlerts[alert.ID] = alert } @@ -618,7 +618,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r topLevelSituationIDSet := make(map[string]struct{}, len(collectedAlerts)) for alertID := range collectedAlerts { - topLevelSituationIDSet[utils.FormCombinedID(alertAgencyID, alertID)] = struct{}{} + topLevelSituationIDSet[alertID] = struct{}{} } topLevelSituationIDs := make([]string, 0, len(topLevelSituationIDSet)) for id := range topLevelSituationIDSet { diff --git a/internal/restapi/trips_helper.go b/internal/restapi/trips_helper.go index 8d9fc41e..5aff5d96 100644 --- a/internal/restapi/trips_helper.go +++ b/internal/restapi/trips_helper.go @@ -758,11 +758,7 @@ func (api *RestAPI) GetSituationIDsForTrip(ctx context.Context, tripID string) [ if alert.ID == "" { continue } - if agencyID != "" { - situationIDs = append(situationIDs, utils.FormCombinedID(agencyID, alert.ID)) - } else { - situationIDs = append(situationIDs, alert.ID) - } + situationIDs = append(situationIDs, alert.ID) } return situationIDs From f7a34903099505fbf18970c5232856314d0ccc1c Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 29 Mar 2026 01:54:48 +0530 Subject: [PATCH 03/34] fixes --- .../arrivals_and_departures_for_location.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 11dd7f6b..e08bf200 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -228,7 +228,8 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite // Shared reference-collection maps (deduplicated across all stops). tripIDSet := make(map[string]*gtfsdb.Trip) routeIDSet := make(map[string]*gtfsdb.Route) - stopIDSet := make(map[string]bool) // raw stop codes for reference building + stopIDSet := make(map[string]bool) // raw stop codes for reference building + stopAgencyOverride := make(map[string]string) // raw stop code → correct agencyID // Track which stop codes actually produced at least one arrival. // Java only includes a stop in the entry's stopIds when it has results. @@ -480,13 +481,19 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite // Collect stops referenced in trip status for the references block. if status.NextStop != "" { - if _, nsID, nsErr := utils.ExtractAgencyIDAndCodeID(status.NextStop); nsErr == nil { + if nsAgency, nsID, nsErr := utils.ExtractAgencyIDAndCodeID(status.NextStop); nsErr == nil { stopIDSet[nsID] = true + if nsAgency != "" { + stopAgencyOverride[nsID] = nsAgency + } } } if status.ClosestStop != "" { - if _, csID, csErr := utils.ExtractAgencyIDAndCodeID(status.ClosestStop); csErr == nil { + if csAgency, csID, csErr := utils.ExtractAgencyIDAndCodeID(status.ClosestStop); csErr == nil { stopIDSet[csID] = true + if csAgency != "" { + stopAgencyOverride[csID] = csAgency + } } } @@ -703,6 +710,9 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite continue } ag := stopAgencyMap[sid] + if ag == "" { + ag = stopAgencyOverride[sid] + } if ag == "" { ag = fallbackAgencyID } From 2ee7f01341e469eb6a9844733ff00d3ee3fa897a Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 29 Mar 2026 09:01:01 +0530 Subject: [PATCH 04/34] feat: achieve full parity for arrivals-and-departures-for-location endpoint - Add parsing and handling for emptyReturnsNotFound parameter - Add parsing and GTFS filtering for routeType parameter - Add frequencyMinutesBefore and frequencyMinutesAfter parsing - Fix nearbyStopIds logic to match Java's includeInputIdsInNearby=true override --- .../arrivals_and_departures_for_location.go | 91 +++++++++++++------ ...rivals_and_departures_for_location_test.go | 38 ++++++++ 2 files changed, 103 insertions(+), 26 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index e08bf200..55abe88e 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -6,6 +6,7 @@ import ( "net/http" "sort" "strconv" + "strings" "time" "github.com/OneBusAway/go-gtfs" @@ -23,11 +24,15 @@ type ArrivalsAndDeparturesForLocationParams struct { LatSpan float64 LonSpan float64 - Time time.Time - MinutesBefore int - MinutesAfter int + Time time.Time + MinutesBefore int + MinutesAfter int + FrequencyMinutesBefore int + FrequencyMinutesAfter int - MaxCount int + MaxCount int + EmptyReturnsNotFound bool + RouteTypes []int } // parseArrivalsAndDeparturesForLocationParams parses and validates all query @@ -110,6 +115,54 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) } } + // frequencyMinutesBefore + if val := q.Get("frequencyMinutesBefore"); val != "" { + if n, err := strconv.Atoi(val); err != nil { + addError("frequencyMinutesBefore", "must be a valid integer") + } else if n < 0 { + addError("frequencyMinutesBefore", "must be a non-negative integer") + } else { + params.FrequencyMinutesBefore = n + } + } + + // frequencyMinutesAfter + if val := q.Get("frequencyMinutesAfter"); val != "" { + if n, err := strconv.Atoi(val); err != nil { + addError("frequencyMinutesAfter", "must be a valid integer") + } else if n < 0 { + addError("frequencyMinutesAfter", "must be a non-negative integer") + } else { + params.FrequencyMinutesAfter = n + } + } + + // emptyReturnsNotFound + if val := q.Get("emptyReturnsNotFound"); val != "" { + if b, err := strconv.ParseBool(val); err == nil { + params.EmptyReturnsNotFound = b + } else { + addError("emptyReturnsNotFound", "must be true or false") + } + } + + // routeType + if val := q.Get("routeType"); val != "" { + parts := strings.Split(val, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if rt, err := strconv.Atoi(p); err == nil { + params.RouteTypes = append(params.RouteTypes, rt) + } else { + addError("routeType", "must be a comma-delimited list of integers") + break + } + } + } + // maxCount — reuse the shared parser. var maxCountErrors map[string][]string params.MaxCount, maxCountErrors = utils.ParseMaxCount(q, defaultMaxCount, nil) @@ -170,11 +223,15 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite "", params.MaxCount, false, - []int{}, + params.RouteTypes, // Pass parsed routeTypes params.Time, ) if len(stops) == 0 { + if params.EmptyReturnsNotFound { + api.sendNotFound(w, r) + return + } api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( []models.ArrivalAndDeparture{}, *models.NewEmptyReferences(), @@ -795,10 +852,9 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite // Build nearbyStopIds as []StopWithDistance. // - // FIX #3: pass queriedStopIDs so that stops already in the entry's stopIds - // are excluded from nearbyStopIds. Java's includeInputIdsInNearby=false - // default means the bbox stops must not appear in both lists. - nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon, params.Time, queriedStopIDs) + // FIX #3: Java's includeInputIdsInNearby is overridden to true in this endpoint, + // so we DO NOT exclude queried stops from the nearby stops list. + nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon, params.Time) api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( arrivals, @@ -816,16 +872,11 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite // // Java equivalent: getNearbyStops() in StopWithArrivalsAndDeparturesBeanServiceImpl, // which calls SphericalGeometryLibrary.distance() to populate distanceFromQuery. -// -// FIX #3: queriedStopIDs are excluded from the result so that stops already -// present in entry.stopIds do not also appear in entry.nearbyStopIds. -// This matches Java's includeInputIdsInNearby=false default behaviour. func getLocationNearbyStops( api *RestAPI, ctx context.Context, centerLat, centerLon float64, queryTime time.Time, - queriedStopIDs []string, // stops already in entry.stopIds — must be excluded ) []models.StopWithDistance { nearby := api.GtfsManager.GetStopsForLocation( @@ -857,13 +908,6 @@ func getLocationNearbyStops( // pickPrimaryAgency over the nearby set for stops with no resolved agency. nearbyFallback := pickPrimaryAgency(nearbyAgencyMap) - // Build a set of already-queried combined stop IDs for O(1) lookup. - // FIX #3: Java excludes these via includeInputIdsInNearby=false. - queriedSet := make(map[string]bool, len(queriedStopIDs)) - for _, id := range queriedStopIDs { - queriedSet[id] = true - } - result := make([]models.StopWithDistance, 0, len(nearby)) for _, s := range nearby { ag := nearbyFallback @@ -872,11 +916,6 @@ func getLocationNearbyStops( } combinedID := utils.FormCombinedID(ag, s.ID) - // FIX #3: skip stops that are already in entry.stopIds. - if queriedSet[combinedID] { - continue - } - dist := utils.Distance(centerLat, centerLon, s.Lat, s.Lon) result = append(result, models.StopWithDistance{ StopID: combinedID, diff --git a/internal/restapi/arrivals_and_departures_for_location_test.go b/internal/restapi/arrivals_and_departures_for_location_test.go index c43ba3ff..ccd383f4 100644 --- a/internal/restapi/arrivals_and_departures_for_location_test.go +++ b/internal/restapi/arrivals_and_departures_for_location_test.go @@ -140,6 +140,33 @@ func TestParseArrivalsAndDeparturesForLocationParams_InvalidMaxCount(t *testing. // --- HTTP handler integration tests --- +func TestParseArrivalsAndDeparturesForLocationParams_FrequencyAndRouteType(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&radius=500&frequencyMinutesBefore=15&frequencyMinutesAfter=45&emptyReturnsNotFound=true&routeType=1,3", nil) + params, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.Nil(t, errs) + assert.Equal(t, 15, params.FrequencyMinutesBefore) + assert.Equal(t, 45, params.FrequencyMinutesAfter) + assert.True(t, params.EmptyReturnsNotFound) + assert.Equal(t, []int{1, 3}, params.RouteTypes) +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidRouteType(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.6&lon=-122.3&routeType=1,abc", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "routeType") +} + func TestArrivalsAndDeparturesForLocationRequiresValidAPIKey(t *testing.T) { _, resp, model := serveAndRetrieveEndpoint(t, "/api/where/arrivals-and-departures-for-location.json?key=invalid&lat=40.583321&lon=-122.426966&latSpan=0.01&lonSpan=0.01") @@ -201,6 +228,17 @@ func TestArrivalsAndDeparturesForLocationEmptyAreaReturnsOK(t *testing.T) { assert.False(t, entry["limitExceeded"].(bool)) } +func TestArrivalsAndDeparturesForLocationEmptyReturnsNotFound(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=0.0&lon=0.0&radius=100&emptyReturnsNotFound=true") + + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + assert.Equal(t, 404, model.Code) +} + func TestArrivalsAndDeparturesForLocationEndToEnd(t *testing.T) { mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) api := createTestApiWithClock(t, mockClock) From 4c7470707ccdda9f9ab1684d43c999eb1bda174c Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 26 Apr 2026 15:46:53 +0530 Subject: [PATCH 05/34] fix: correct lastUpdateTime serialization, nearbyStopIds radius, and stopSequence offset in arrivals-and-departures-for-location --- internal/models/arrival_and_departure.go | 2 +- .../arrivals_and_departures_for_location.go | 105 +++++++++--------- 2 files changed, 52 insertions(+), 55 deletions(-) diff --git a/internal/models/arrival_and_departure.go b/internal/models/arrival_and_departure.go index b6d7825f..01a631ce 100644 --- a/internal/models/arrival_and_departure.go +++ b/internal/models/arrival_and_departure.go @@ -12,7 +12,7 @@ type ArrivalAndDeparture struct { DistanceFromStop float64 `json:"distanceFromStop"` Frequency *Frequency `json:"frequency"` HistoricalOccupancy string `json:"historicalOccupancy"` - LastUpdateTime ModelTime `json:"lastUpdateTime,omitzero"` + LastUpdateTime ModelTime `json:"lastUpdateTime"` NumberOfStopsAway int `json:"numberOfStopsAway"` OccupancyStatus string `json:"occupancyStatus"` Predicted bool `json:"predicted"` diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 55abe88e..20e1f58a 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -11,6 +11,7 @@ import ( "github.com/OneBusAway/go-gtfs" "maglev.onebusaway.org/gtfsdb" + internalgtfs "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -210,21 +211,20 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite return } - api.GtfsManager.RLock() - defer api.GtfsManager.RUnlock() - - // Find stops inside the bounding box using the in-memory R-tree spatial index. - // Pass params.Time so that a historical `time=` override is respected. - stops := api.GtfsManager.GetStopsForLocation( + // Find stops inside the bounding box using the spatial index. + // GetStopsForLocation manages its own locking and returns (stops, limitExceeded). + stops, _ := api.GtfsManager.GetStopsForLocation( ctx, - params.Lat, params.Lon, - params.Radius, - params.LatSpan, params.LonSpan, + &internalgtfs.LocationParams{ + Lat: params.Lat, + Lon: params.Lon, + Radius: params.Radius, + LatSpan: params.LatSpan, + LonSpan: params.LonSpan, + }, "", params.MaxCount, - false, - params.RouteTypes, // Pass parsed routeTypes - params.Time, + params.RouteTypes, ) if len(stops) == 0 { @@ -346,7 +346,7 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, stopLoc) serviceDateStr := targetDate.Format("20060102") - activeServiceIDs, svcErr := api.GtfsManager.GetActiveServiceIDsForDateCached(ctx, serviceDateStr) + activeServiceIDs, svcErr := api.GtfsManager.GtfsDB.Queries.GetActiveServiceIDsForDate(ctx, serviceDateStr) if svcErr != nil { api.Logger.Warn("failed to query active service IDs", slog.String("date", serviceDateStr), @@ -461,7 +461,6 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite st := ast.GetStopTimesForStopInWindowRow serviceMidnight := ast.ServiceDate - serviceDateMillis := serviceMidnight.UnixMilli() route, routeOK := routesLookup[st.RouteID] if !routeOK { @@ -481,18 +480,18 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite tCopy := trip tripIDSet[trip.ID] = &tCopy - scheduledArrivalTime := serviceMidnight.Add(time.Duration(st.ArrivalTime)).UnixMilli() - scheduledDepartureTime := serviceMidnight.Add(time.Duration(st.DepartureTime)).UnixMilli() + scheduledArrivalTime := serviceMidnight.Add(time.Duration(st.ArrivalTime)) + scheduledDepartureTime := serviceMidnight.Add(time.Duration(st.DepartureTime)) var ( - predictedArrivalTime int64 - predictedDepartureTime int64 + predictedArrivalTime time.Time + predictedDepartureTime time.Time predicted = false vehicleID string tripStatus *models.TripStatus distanceFromStop = 0.0 numberOfStopsAway = 0 - lastUpdateTime int64 // always emitted; 0 when no vehicle + lastUpdateTime time.Time // FIX #4: derive status from schedule deviation instead of // always emitting "default". Falls back to "default" when @@ -505,19 +504,16 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite vehicleID = vehicle.ID.ID } - schedArrTime := serviceMidnight.Add(time.Duration(st.ArrivalTime)) - schedDepTime := serviceMidnight.Add(time.Duration(st.DepartureTime)) - predArr, predDep, isPredicted := api.getPredictedTimes( st.TripID, stopCode, int64(st.StopSequence), - schedArrTime, schedDepTime, + scheduledArrivalTime, scheduledDepartureTime, ) if isPredicted { predicted = true predictedArrivalTime = predArr predictedDepartureTime = predDep } - // When not predicted, leave predictedArrivalTime/predictedDepartureTime as 0 + // When not predicted, leave predictedArrivalTime/predictedDepartureTime as zero time.Time // (matches Java which emits 0 for unpredicted arrivals). // Gate BuildTripStatus on vehicle presence — matches the stop handler convention. @@ -581,10 +577,7 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite } } - rawUpdate := api.GtfsManager.GetVehicleLastUpdateTime(vehicle) - if rawUpdate > 0 { - lastUpdateTime = rawUpdate - } + lastUpdateTime = api.GtfsManager.GetVehicleLastUpdateTime(vehicle) } totalStopsInTrip := tripStopCountMap[st.TripID] @@ -605,24 +598,15 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite } } - // lastUpdateTimePtr: use a pointer so the model's omitempty drops it only - // when the value is truly absent. We pass nil when lastUpdateTime==0 to - // preserve the existing model behaviour (field omitted rather than 0). - var lastUpdateTimePtr *int64 - if lastUpdateTime > 0 { - lastUpdateTimePtr = utils.Int64Ptr(lastUpdateTime) - } - // vehicleID must carry the agency prefix to match Java output ("1_6853"). formattedVehicleID := "" if vehicleID != "" { formattedVehicleID = utils.FormCombinedID(route.AgencyID, vehicleID) } - // FIX #2: use raw GTFS stop_sequence (1-based) — do NOT subtract 1. - // Java's ArrivalAndDepartureBean.getStopSequence() returns the raw - // GTFS value directly; there is no zero-indexing in the wire format. - rawStopSequence := int(st.StopSequence) + // stopSequence is zero-based on the wire (matching Java OBA and the + // arrivals-and-departures-for-stop handler which uses StopSequence-1). + rawStopSequence := int(st.StopSequence) - 1 arrivals = append(arrivals, *models.NewArrivalAndDeparture( utils.FormCombinedID(route.AgencyID, route.ID), // routeID @@ -632,16 +616,16 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite st.TripHeadsign.String, // tripHeadsign combinedStopID, // stopID formattedVehicleID, // vehicleID (agency-prefixed or empty) - serviceDateMillis, // serviceDate + serviceMidnight, // serviceDate scheduledArrivalTime, // scheduledArrivalTime scheduledDepartureTime, // scheduledDepartureTime - predictedArrivalTime, // predictedArrivalTime (0 when unpredicted) - predictedDepartureTime, // predictedDepartureTime (0 when unpredicted) - lastUpdateTimePtr, // lastUpdateTime + predictedArrivalTime, // predictedArrivalTime (zero when unpredicted) + predictedDepartureTime, // predictedDepartureTime (zero when unpredicted) + lastUpdateTime, // lastUpdateTime predicted, // predicted true, // arrivalEnabled true, // departureEnabled - rawStopSequence, // FIX #2: raw GTFS stop_sequence, not zero-based + rawStopSequence, // stopSequence (zero-based, matching stop handler) totalStopsInTrip, // totalStopsInTrip numberOfStopsAway, // numberOfStopsAway blockTripSequence, // blockTripSequence @@ -664,15 +648,20 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite // Sort arrivals by predicted (or scheduled) arrival time ascending. // Matches Java's ArrivalAndDepartureComparator. sort.Slice(arrivals, func(i, j int) bool { - ti := arrivals[i].PredictedArrivalTime - if ti == 0 { - ti = arrivals[i].ScheduledArrivalTime + ai := arrivals[i] + aj := arrivals[j] + var ti, tj time.Time + if !ai.PredictedArrivalTime.IsZero() { + ti = ai.PredictedArrivalTime.Time + } else { + ti = ai.ScheduledArrivalTime.Time } - tj := arrivals[j].PredictedArrivalTime - if tj == 0 { - tj = arrivals[j].ScheduledArrivalTime + if !aj.PredictedArrivalTime.IsZero() { + tj = aj.PredictedArrivalTime.Time + } else { + tj = aj.ScheduledArrivalTime.Time } - return ti < tj + return ti.Before(tj) }) // Build references block (agencies, routes, stops, trips, situations). @@ -879,8 +868,16 @@ func getLocationNearbyStops( queryTime time.Time, ) []models.StopWithDistance { - nearby := api.GtfsManager.GetStopsForLocation( - ctx, centerLat, centerLon, 100, 0, 0, "", 250, false, []int{}, queryTime, + nearby, _ := api.GtfsManager.GetStopsForLocation( + ctx, + &internalgtfs.LocationParams{ + Lat: centerLat, + Lon: centerLon, + Radius: models.DefaultSearchRadiusInMeters, + }, + "", + 250, + []int{}, ) if len(nearby) == 0 { From 3c2170271294ef1a4411a85511a852869e75bd62 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 26 Apr 2026 15:59:29 +0530 Subject: [PATCH 06/34] style: replace map[string]interface{} with map[string]any in response.go for consistency --- internal/models/response.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/models/response.go b/internal/models/response.go index 91676712..9c96c585 100644 --- a/internal/models/response.go +++ b/internal/models/response.go @@ -77,14 +77,14 @@ func NewArrivalsAndDeparturesForLocationResponse( if stopIds == nil { stopIds = []string{} } - entryData := map[string]interface{}{ + entryData := map[string]any{ "arrivalsAndDepartures": arrivalsAndDepartures, "limitExceeded": limitExceeded, "nearbyStopIds": nearbyStopIds, "situationIds": situationIds, "stopIds": stopIds, } - data := map[string]interface{}{ + data := map[string]any{ "entry": entryData, "references": references, } From 3b9bffcc6e90daec1c90d23fd232abf602586078 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 26 Apr 2026 16:30:19 +0530 Subject: [PATCH 07/34] fix: block-based prediction propagation and blockTripSequence parity --- .../arrivals_and_departures_for_location.go | 43 ++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 20e1f58a..85621f32 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -526,8 +526,21 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite if status != nil { tripStatus = status - // Only set a meaningful status when the arrival is predicted. - // Unpredicted arrivals stay "default". + // Block-based prediction propagation: when the GTFS-RT feed only + // has data for the active (preceding) block trip, getPredictedTimes + // above returns isPredicted=false for the scheduled (future) trip. + // If tripStatus.Predicted is true, the vehicle IS tracked — apply + // the active trip's schedule deviation to the scheduled times. + // This matches Java OBA's ArrivalAndDepartureServiceImpl which + // propagates block-level delay to future block trips. + if !predicted && status.Predicted { + dev := time.Duration(status.ScheduleDeviation) * time.Second + predictedArrivalTime = scheduledArrivalTime.Add(dev) + predictedDepartureTime = scheduledDepartureTime.Add(dev) + predicted = true + } + + // status field: derive from deviation only when predicted (direct or block-propagated). if predicted { arrivalStatus = arrivalStatusFromDeviation(status.ScheduleDeviation) } @@ -561,15 +574,25 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite } // Ensure the active trip (if different from scheduled) is in references. + // Also override tripStatus.BlockTripSequence to use the ACTIVE trip's + // block sequence — BuildTripStatus computes it from the scheduled (target) + // tripID, but Java OBA emits the active trip's sequence on the wire. if status.ActiveTripID != "" { - if _, atID, atErr := utils.ExtractAgencyIDAndCodeID(status.ActiveTripID); atErr == nil && atID != st.TripID { - if _, exists := tripIDSet[atID]; !exists { - if at, atFetchErr := api.GtfsManager.GtfsDB.Queries.GetTrip(ctx, atID); atFetchErr == nil { - atCopy := at - tripIDSet[at.ID] = &atCopy - if ar, arFetchErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, at.RouteID); arFetchErr == nil { - arCopy := ar - routeIDSet[ar.ID] = &arCopy + if _, atID, atErr := utils.ExtractAgencyIDAndCodeID(status.ActiveTripID); atErr == nil { + // Override BlockTripSequence to the active trip's sequence. + if activeSeq := api.calculateBlockTripSequence(ctx, atID, serviceMidnight); activeSeq > 0 { + status.BlockTripSequence = activeSeq + } + + if atID != st.TripID { + if _, exists := tripIDSet[atID]; !exists { + if at, atFetchErr := api.GtfsManager.GtfsDB.Queries.GetTrip(ctx, atID); atFetchErr == nil { + atCopy := at + tripIDSet[at.ID] = &atCopy + if ar, arFetchErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, at.RouteID); arFetchErr == nil { + arCopy := ar + routeIDSet[ar.ID] = &arCopy + } } } } From 2ae4e5a52d166851ab912bd75859f2c9f2adbba0 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 18:21:54 +0530 Subject: [PATCH 08/34] added ASC to query.sql --- gtfsdb/query.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtfsdb/query.sql b/gtfsdb/query.sql index 94a98ba9..2a78561c 100644 --- a/gtfsdb/query.sql +++ b/gtfsdb/query.sql @@ -674,7 +674,7 @@ FROM JOIN agencies a ON routes.agency_id = a.id WHERE stop_times.stop_id IN (sqlc.slice('stop_ids')) -ORDER BY a.id, stop_times.stop_id; +ORDER BY a.id ASC, stop_times.stop_id ASC; -- name: GetStopTimesForTrip :many SELECT From 074a94e455f9f3e9d141dcd9c4a0e835f6a466b9 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 18:32:23 +0530 Subject: [PATCH 09/34] internal/restapi/arrivals_and_departures_for_location.go --- .../arrivals_and_departures_for_location.go | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 85621f32..269e3d2e 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -13,6 +13,7 @@ import ( "maglev.onebusaway.org/gtfsdb" internalgtfs "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" + "maglev.onebusaway.org/internal/nulls" "maglev.onebusaway.org/internal/utils" ) @@ -40,11 +41,13 @@ type ArrivalsAndDeparturesForLocationParams struct { // parameters for this endpoint in one place. func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) (ArrivalsAndDeparturesForLocationParams, map[string][]string) { const ( - defaultMinutesBefore = 5 - defaultMinutesAfter = 35 - maxMinutesBefore = 60 - maxMinutesAfter = 240 - defaultMaxCount = 250 + defaultMinutesBefore = 5 + defaultMinutesAfter = 35 + maxMinutesBefore = 60 + maxMinutesAfter = 240 + defaultMaxCount = 250 + errMustBeValidInteger = "must be a valid integer" + errMustBeNonNegativeInteger = "must be a non-negative integer" ) params := ArrivalsAndDeparturesForLocationParams{ @@ -93,9 +96,9 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) // minutesBefore if val := q.Get("minutesBefore"); val != "" { if n, err := strconv.Atoi(val); err != nil { - addError("minutesBefore", "must be a valid integer") + addError("minutesBefore", errMustBeValidInteger) } else if n < 0 { - addError("minutesBefore", "must be a non-negative integer") + addError("minutesBefore", errMustBeNonNegativeInteger) } else if n > maxMinutesBefore { params.MinutesBefore = maxMinutesBefore } else { @@ -106,9 +109,9 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) // minutesAfter if val := q.Get("minutesAfter"); val != "" { if n, err := strconv.Atoi(val); err != nil { - addError("minutesAfter", "must be a valid integer") + addError("minutesAfter", errMustBeValidInteger) } else if n < 0 { - addError("minutesAfter", "must be a non-negative integer") + addError("minutesAfter", errMustBeNonNegativeInteger) } else if n > maxMinutesAfter { params.MinutesAfter = maxMinutesAfter } else { @@ -119,9 +122,9 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) // frequencyMinutesBefore if val := q.Get("frequencyMinutesBefore"); val != "" { if n, err := strconv.Atoi(val); err != nil { - addError("frequencyMinutesBefore", "must be a valid integer") + addError("frequencyMinutesBefore", errMustBeValidInteger) } else if n < 0 { - addError("frequencyMinutesBefore", "must be a non-negative integer") + addError("frequencyMinutesBefore", errMustBeNonNegativeInteger) } else { params.FrequencyMinutesBefore = n } @@ -130,9 +133,9 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) // frequencyMinutesAfter if val := q.Get("frequencyMinutesAfter"); val != "" { if n, err := strconv.Atoi(val); err != nil { - addError("frequencyMinutesAfter", "must be a valid integer") + addError("frequencyMinutesAfter", errMustBeValidInteger) } else if n < 0 { - addError("frequencyMinutesAfter", "must be a non-negative integer") + addError("frequencyMinutesAfter", errMustBeNonNegativeInteger) } else { params.FrequencyMinutesAfter = n } @@ -812,7 +815,7 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite Code: stopData.Code.String, Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stopData.ID, stopData.Direction), LocationType: int(stopData.LocationType.Int64), - WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), + WheelchairBoarding: utils.MapWheelchairBoarding(nulls.WheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), RouteIDs: combinedRouteIDs, StaticRouteIDs: combinedRouteIDs, }) From b34b613b57bab7ed890e628d61d39ce901b80adc Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 18:34:22 +0530 Subject: [PATCH 10/34] ran fmt --- .../restapi/arrivals_and_departures_for_location.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 269e3d2e..42ccaa8b 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -41,12 +41,12 @@ type ArrivalsAndDeparturesForLocationParams struct { // parameters for this endpoint in one place. func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) (ArrivalsAndDeparturesForLocationParams, map[string][]string) { const ( - defaultMinutesBefore = 5 - defaultMinutesAfter = 35 - maxMinutesBefore = 60 - maxMinutesAfter = 240 - defaultMaxCount = 250 - errMustBeValidInteger = "must be a valid integer" + defaultMinutesBefore = 5 + defaultMinutesAfter = 35 + maxMinutesBefore = 60 + maxMinutesAfter = 240 + defaultMaxCount = 250 + errMustBeValidInteger = "must be a valid integer" errMustBeNonNegativeInteger = "must be a non-negative integer" ) From faeb9e1221d38460af3f656d135368e93f0e4385 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 18:45:49 +0530 Subject: [PATCH 11/34] rafactor arrivalsAndDeparturesForLocationHandler as per sonar cloud --- .../arrivals_and_departures_for_location.go | 214 ++++++++++-------- 1 file changed, 120 insertions(+), 94 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 42ccaa8b..a97770a5 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "net/http" + "net/url" "sort" "strconv" "strings" @@ -37,17 +38,21 @@ type ArrivalsAndDeparturesForLocationParams struct { RouteTypes []int } +// Error message constants shared by the parameter-parsing helpers below. +const ( + errMustBeValidInteger = "must be a valid integer" + errMustBeNonNegativeInteger = "must be a non-negative integer" +) + // parseArrivalsAndDeparturesForLocationParams parses and validates all query // parameters for this endpoint in one place. func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) (ArrivalsAndDeparturesForLocationParams, map[string][]string) { const ( - defaultMinutesBefore = 5 - defaultMinutesAfter = 35 - maxMinutesBefore = 60 - maxMinutesAfter = 240 - defaultMaxCount = 250 - errMustBeValidInteger = "must be a valid integer" - errMustBeNonNegativeInteger = "must be a non-negative integer" + defaultMinutesBefore = 5 + defaultMinutesAfter = 35 + maxMinutesBefore = 60 + maxMinutesAfter = 240 + defaultMaxCount = 250 ) params := ArrivalsAndDeparturesForLocationParams{ @@ -68,12 +73,7 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) // Spatial params (required) — reuse the shared location parser. loc, locErrors := api.parseLocationParams(r, nil) if len(locErrors) > 0 { - if fieldErrors == nil { - fieldErrors = make(map[string][]string) - } - for k, v := range locErrors { - fieldErrors[k] = append(fieldErrors[k], v...) - } + mergeFieldErrors(&fieldErrors, locErrors) } else { params.Lat = loc.Lat params.Lon = loc.Lon @@ -83,103 +83,129 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) } q := r.URL.Query() + params.Time = parseTimeParam(q, params.Time, addError) + parseMinutesCappedParam(q, "minutesBefore", maxMinutesBefore, ¶ms.MinutesBefore, addError) + parseMinutesCappedParam(q, "minutesAfter", maxMinutesAfter, ¶ms.MinutesAfter, addError) + parseMinutesUncappedParam(q, "frequencyMinutesBefore", ¶ms.FrequencyMinutesBefore, addError) + parseMinutesUncappedParam(q, "frequencyMinutesAfter", ¶ms.FrequencyMinutesAfter, addError) + params.EmptyReturnsNotFound = parseEmptyReturnsNotFoundParam(q, addError) + params.RouteTypes = parseRouteTypesParam(q, addError) - // time - if val := q.Get("time"); val != "" { - if ms, err := strconv.ParseInt(val, 10, 64); err == nil { - params.Time = time.Unix(ms/1000, (ms%1000)*1_000_000) - } else { - addError("time", "must be a valid Unix timestamp in milliseconds") - } - } + var maxCountErrors map[string][]string + params.MaxCount, maxCountErrors = utils.ParseMaxCount(q, defaultMaxCount, nil) + mergeFieldErrors(&fieldErrors, maxCountErrors) - // minutesBefore - if val := q.Get("minutesBefore"); val != "" { - if n, err := strconv.Atoi(val); err != nil { - addError("minutesBefore", errMustBeValidInteger) - } else if n < 0 { - addError("minutesBefore", errMustBeNonNegativeInteger) - } else if n > maxMinutesBefore { - params.MinutesBefore = maxMinutesBefore - } else { - params.MinutesBefore = n - } - } + return params, fieldErrors +} - // minutesAfter - if val := q.Get("minutesAfter"); val != "" { - if n, err := strconv.Atoi(val); err != nil { - addError("minutesAfter", errMustBeValidInteger) - } else if n < 0 { - addError("minutesAfter", errMustBeNonNegativeInteger) - } else if n > maxMinutesAfter { - params.MinutesAfter = maxMinutesAfter - } else { - params.MinutesAfter = n - } +// parseTimeParam parses the "time" query parameter as a Unix timestamp in +// milliseconds. Returns defaultTime unchanged when the parameter is absent. +func parseTimeParam(q url.Values, defaultTime time.Time, addError func(string, string)) time.Time { + val := q.Get("time") + if val == "" { + return defaultTime } - - // frequencyMinutesBefore - if val := q.Get("frequencyMinutesBefore"); val != "" { - if n, err := strconv.Atoi(val); err != nil { - addError("frequencyMinutesBefore", errMustBeValidInteger) - } else if n < 0 { - addError("frequencyMinutesBefore", errMustBeNonNegativeInteger) - } else { - params.FrequencyMinutesBefore = n - } + ms, err := strconv.ParseInt(val, 10, 64) + if err != nil { + addError("time", "must be a valid Unix timestamp in milliseconds") + return defaultTime } + return time.Unix(ms/1000, (ms%1000)*1_000_000) +} - // frequencyMinutesAfter - if val := q.Get("frequencyMinutesAfter"); val != "" { - if n, err := strconv.Atoi(val); err != nil { - addError("frequencyMinutesAfter", errMustBeValidInteger) - } else if n < 0 { - addError("frequencyMinutesAfter", errMustBeNonNegativeInteger) - } else { - params.FrequencyMinutesAfter = n - } +// parseMinutesCappedParam parses an integer minutes query parameter and writes +// the result into dest. Values above maxVal are silently capped; negative +// values and non-integer values are rejected via addError. +func parseMinutesCappedParam(q url.Values, key string, maxVal int, dest *int, addError func(string, string)) { + val := q.Get(key) + if val == "" { + return + } + n, err := strconv.Atoi(val) + if err != nil { + addError(key, errMustBeValidInteger) + return + } + if n < 0 { + addError(key, errMustBeNonNegativeInteger) + return + } + if n > maxVal { + *dest = maxVal + return } + *dest = n +} - // emptyReturnsNotFound - if val := q.Get("emptyReturnsNotFound"); val != "" { - if b, err := strconv.ParseBool(val); err == nil { - params.EmptyReturnsNotFound = b - } else { - addError("emptyReturnsNotFound", "must be true or false") - } +// parseMinutesUncappedParam parses an integer minutes query parameter with no +// upper bound and writes the result into dest. +func parseMinutesUncappedParam(q url.Values, key string, dest *int, addError func(string, string)) { + val := q.Get(key) + if val == "" { + return + } + n, err := strconv.Atoi(val) + if err != nil { + addError(key, errMustBeValidInteger) + return + } + if n < 0 { + addError(key, errMustBeNonNegativeInteger) + return } + *dest = n +} - // routeType - if val := q.Get("routeType"); val != "" { - parts := strings.Split(val, ",") - for _, p := range parts { - p = strings.TrimSpace(p) - if p == "" { - continue - } - if rt, err := strconv.Atoi(p); err == nil { - params.RouteTypes = append(params.RouteTypes, rt) - } else { - addError("routeType", "must be a comma-delimited list of integers") - break - } - } +// parseEmptyReturnsNotFoundParam parses the "emptyReturnsNotFound" boolean +// query parameter. Returns false when absent or invalid. +func parseEmptyReturnsNotFoundParam(q url.Values, addError func(string, string)) bool { + val := q.Get("emptyReturnsNotFound") + if val == "" { + return false } + b, err := strconv.ParseBool(val) + if err != nil { + addError("emptyReturnsNotFound", "must be true or false") + return false + } + return b +} - // maxCount — reuse the shared parser. - var maxCountErrors map[string][]string - params.MaxCount, maxCountErrors = utils.ParseMaxCount(q, defaultMaxCount, nil) - if len(maxCountErrors) > 0 { - if fieldErrors == nil { - fieldErrors = make(map[string][]string) +// parseRouteTypesParam parses the "routeType" comma-delimited integer list +// query parameter. Returns nil when absent; stops and errors at the first +// invalid token. +func parseRouteTypesParam(q url.Values, addError func(string, string)) []int { + val := q.Get("routeType") + if val == "" { + return nil + } + var routeTypes []int + for _, p := range strings.Split(val, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue } - for k, v := range maxCountErrors { - fieldErrors[k] = append(fieldErrors[k], v...) + rt, err := strconv.Atoi(p) + if err != nil { + addError("routeType", "must be a comma-delimited list of integers") + return nil } + routeTypes = append(routeTypes, rt) } + return routeTypes +} - return params, fieldErrors +// mergeFieldErrors merges src into *dst, initialising *dst lazily if nil. +func mergeFieldErrors(dst *map[string][]string, src map[string][]string) { + if len(src) == 0 { + return + } + if *dst == nil { + *dst = make(map[string][]string) + } + for k, v := range src { + (*dst)[k] = append((*dst)[k], v...) + } } // arrivalStatusFromDeviation derives a human-readable status string from a From 0204e8be43d253b98f12fc8a4a531ebed24564ee Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 19:21:07 +0530 Subject: [PATCH 12/34] Refactor arrivalsAndDeparturesForLocationHandler to reduce cognitive complexity and fix GTFS layover constraint --- gtfsdb/helpers.go | 11 +- .../arrivals_and_departures_for_location.go | 894 +++++++++--------- 2 files changed, 438 insertions(+), 467 deletions(-) diff --git a/gtfsdb/helpers.go b/gtfsdb/helpers.go index 618cac07..09b72eab 100644 --- a/gtfsdb/helpers.go +++ b/gtfsdb/helpers.go @@ -1415,13 +1415,20 @@ func (c *Client) buildBlockLayoverIndex(ctx context.Context, staticData *gtfs.St continue } + layoverStart := int64(lastStopCurrent.ArrivalTime) + layoverEnd := int64(firstStopNext.DepartureTime) + + if layoverStart > layoverEnd { + continue + } + err := qtx.CreateBlockLayover(ctx, CreateBlockLayoverParams{ BlockID: key.blockID, ServiceID: key.serviceID, RouteID: nextTrip.Route.Id, LayoverStopID: lastStopCurrent.Stop.Id, - LayoverStart: int64(lastStopCurrent.DepartureTime), - LayoverEnd: int64(firstStopNext.ArrivalTime), + LayoverStart: layoverStart, + LayoverEnd: layoverEnd, NextTripID: nextTrip.ID, }) if err != nil { diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index a97770a5..69bd21b7 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -38,6 +38,41 @@ type ArrivalsAndDeparturesForLocationParams struct { RouteTypes []int } +// activeStopTime pairs a GTFS stop time with the service date it occurs on. +type activeStopTime struct { + gtfsdb.GetStopTimesForStopInWindowRow + ServiceDate time.Time +} + +// locationArrivalsState holds the shared accumulation state across all stops +// while processing arrivals and departures for a location. +type locationArrivalsState struct { + arrivals []models.ArrivalAndDeparture + tripIDSet map[string]*gtfsdb.Trip + routeIDSet map[string]*gtfsdb.Route + stopIDSet map[string]bool + stopAgencyOverride map[string]string + stopsWithArrivals map[string]bool + collectedAlerts map[string]gtfs.Alert + limitExceeded bool + + stopAgencyMap map[string]string + fallbackAgencyID string + agencyLoc *time.Location +} + +func newLocationArrivalsState() *locationArrivalsState { + return &locationArrivalsState{ + arrivals: make([]models.ArrivalAndDeparture, 0), + tripIDSet: make(map[string]*gtfsdb.Trip), + routeIDSet: make(map[string]*gtfsdb.Route), + stopIDSet: make(map[string]bool), + stopAgencyOverride: make(map[string]string), + stopsWithArrivals: make(map[string]bool), + collectedAlerts: make(map[string]gtfs.Alert), + } +} + // Error message constants shared by the parameter-parsing helpers below. const ( errMustBeValidInteger = "must be a valid integer" @@ -240,8 +275,6 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite return } - // Find stops inside the bounding box using the spatial index. - // GetStopsForLocation manages its own locking and returns (stops, limitExceeded). stops, _ := api.GtfsManager.GetStopsForLocation( ctx, &internalgtfs.LocationParams{ @@ -257,448 +290,440 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite ) if len(stops) == 0 { - if params.EmptyReturnsNotFound { - api.sendNotFound(w, r) - return - } - api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( - []models.ArrivalAndDeparture{}, - *models.NewEmptyReferences(), - []models.StopWithDistance{}, - []string{}, - []string{}, - false, - api.Clock, - )) + api.handleEmptyStopsResponseForLocation(w, r, params) return } - // Collect raw stop codes (no agency prefix) for batch DB queries. - rawStopCodes := make([]string, 0, len(stops)) - for _, s := range stops { - rawStopCodes = append(rawStopCodes, s.ID) - } - - // Resolve agency for each stop (needed to build combined "agencyId_stopCode" IDs). - agencyRows, err := api.GtfsManager.GtfsDB.Queries.GetAgenciesForStops(ctx, rawStopCodes) - if err != nil { + state := newLocationArrivalsState() + if err := api.resolveAgenciesForStopsLocation(ctx, stops, state); err != nil { api.serverErrorResponse(w, r, err) return } - // stopCode → agencyID; first agency wins for multi-agency stops. - stopAgencyMap := make(map[string]string, len(agencyRows)) - for _, row := range agencyRows { - if _, exists := stopAgencyMap[row.StopID]; !exists { - stopAgencyMap[row.StopID] = row.ID + + // Fan out: collect arrivals across every stop in the bbox. + for _, dbStop := range stops { + if state.limitExceeded || len(state.arrivals) >= params.MaxCount { + state.limitExceeded = true + break + } + if err := api.collectArrivalsForLocationStop(ctx, w, r, dbStop, params, state); err != nil { + return // Context cancellation/error response already handled. } } - // fallbackAgencyID is used only when a stop has no entry in stopAgencyMap - // (e.g. a stop with no active routes). Derived from the most common agency - // among the queried stops — never used to prefix alert IDs. - fallbackAgencyID := pickPrimaryAgency(stopAgencyMap) + api.sortLocationArrivalsByTime(state.arrivals) - // Determine the base query timezone from the fallback agency. - agencyLoc := time.UTC - if fallbackAgencyID != "" { - if ag, tzErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, fallbackAgencyID); tzErr == nil { - if parsed, parseErr := loadAgencyLocation(ag.ID, ag.Timezone); parseErr == nil { - agencyLoc = parsed + // Collect stop-level service alerts. + rawStopCodes := make([]string, 0, len(stops)) + for _, s := range stops { + rawStopCodes = append(rawStopCodes, s.ID) + } + for _, sc := range rawStopCodes { + for _, alert := range api.GtfsManager.GetAlertsForStop(sc) { + if alert.ID != "" { + if _, seen := state.collectedAlerts[alert.ID]; !seen { + state.collectedAlerts[alert.ID] = alert + } } } } - // Fan out: collect arrivals across every stop in the bbox. - arrivals := make([]models.ArrivalAndDeparture, 0, len(stops)*4) - - // Shared reference-collection maps (deduplicated across all stops). - tripIDSet := make(map[string]*gtfsdb.Trip) - routeIDSet := make(map[string]*gtfsdb.Route) - stopIDSet := make(map[string]bool) // raw stop codes for reference building - stopAgencyOverride := make(map[string]string) // raw stop code → correct agencyID + references, topLevelSituationIDs := api.buildLocationReferencesBlock(ctx, state) + queriedStopIDs := api.buildLocationQueriedStopIDs(stops, state) + nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon, params.Time) - // Track which stop codes actually produced at least one arrival. - // Java only includes a stop in the entry's stopIds when it has results. - stopsWithArrivals := make(map[string]bool) + api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( + state.arrivals, + *references, + nearbyStops, + topLevelSituationIDs, + queriedStopIDs, + state.limitExceeded, + api.Clock, + )) +} - collectedAlerts := make(map[string]gtfs.Alert) +func (api *RestAPI) handleEmptyStopsResponseForLocation(w http.ResponseWriter, r *http.Request, params ArrivalsAndDeparturesForLocationParams) { + if params.EmptyReturnsNotFound { + api.sendNotFound(w, r) + return + } + api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( + []models.ArrivalAndDeparture{}, + *models.NewEmptyReferences(), + []models.StopWithDistance{}, + []string{}, + []string{}, + false, + api.Clock, + )) +} - limitExceeded := false +func (api *RestAPI) resolveAgenciesForStopsLocation(ctx context.Context, stops []gtfsdb.Stop, state *locationArrivalsState) error { + rawStopCodes := make([]string, 0, len(stops)) + for _, s := range stops { + rawStopCodes = append(rawStopCodes, s.ID) + } - for _, dbStop := range stops { - // Early exit once maxCount is reached — mirrors Java's MaxCountSupport. - if limitExceeded { - break - } - if len(arrivals) >= params.MaxCount { - limitExceeded = true - break - } + agencyRows, err := api.GtfsManager.GtfsDB.Queries.GetAgenciesForStops(ctx, rawStopCodes) + if err != nil { + return err + } - stopCode := dbStop.ID - agencyID := stopAgencyMap[stopCode] - if agencyID == "" { - agencyID = fallbackAgencyID + state.stopAgencyMap = make(map[string]string, len(agencyRows)) + for _, row := range agencyRows { + if _, exists := state.stopAgencyMap[row.StopID]; !exists { + state.stopAgencyMap[row.StopID] = row.ID } - combinedStopID := utils.FormCombinedID(agencyID, stopCode) - stopIDSet[stopCode] = true + } - // Per-stop timezone — handles multi-agency feeds where stops may span TZs. - stopLoc := agencyLoc - if agencyID != fallbackAgencyID { - if ag, agErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, agencyID); agErr == nil { - if parsed, parseErr := loadAgencyLocation(ag.ID, ag.Timezone); parseErr == nil { - stopLoc = parsed - } + state.fallbackAgencyID = pickPrimaryAgency(state.stopAgencyMap) + state.agencyLoc = time.UTC + if state.fallbackAgencyID != "" { + if ag, tzErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, state.fallbackAgencyID); tzErr == nil { + if parsed, parseErr := loadAgencyLocation(ag.ID, ag.Timezone); parseErr == nil { + state.agencyLoc = parsed } } + } + return nil +} - stopQueryTime := params.Time.In(stopLoc) - stopWindowStart := stopQueryTime.Add(-time.Duration(params.MinutesBefore) * time.Minute) - stopWindowEnd := stopQueryTime.Add(time.Duration(params.MinutesAfter) * time.Minute) - - // Query 3 days (yesterday/today/tomorrow) to handle overnight trips — - // identical to the single-stop handler's approach. - type activeStopTime struct { - gtfsdb.GetStopTimesForStopInWindowRow - ServiceDate time.Time - } - var allActiveStopTimes []activeStopTime - - for dayOffset := -1; dayOffset <= 1; dayOffset++ { - if ctx.Err() != nil { - api.clientCanceledResponse(w, r, ctx.Err()) - return - } - - targetDate := stopQueryTime.AddDate(0, 0, dayOffset) - serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, stopLoc) - serviceDateStr := targetDate.Format("20060102") +func (api *RestAPI) collectArrivalsForLocationStop(ctx context.Context, w http.ResponseWriter, r *http.Request, dbStop gtfsdb.Stop, params ArrivalsAndDeparturesForLocationParams, state *locationArrivalsState) error { + stopCode := dbStop.ID + agencyID := state.stopAgencyMap[stopCode] + if agencyID == "" { + agencyID = state.fallbackAgencyID + } + state.stopIDSet[stopCode] = true - activeServiceIDs, svcErr := api.GtfsManager.GtfsDB.Queries.GetActiveServiceIDsForDate(ctx, serviceDateStr) - if svcErr != nil { - api.Logger.Warn("failed to query active service IDs", - slog.String("date", serviceDateStr), - slog.Any("error", svcErr)) - continue - } - if len(activeServiceIDs) == 0 { - continue + stopLoc := state.agencyLoc + if agencyID != state.fallbackAgencyID { + if ag, agErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, agencyID); agErr == nil { + if parsed, parseErr := loadAgencyLocation(ag.ID, ag.Timezone); parseErr == nil { + stopLoc = parsed } + } + } - activeServiceIDSet := make(map[string]bool, len(activeServiceIDs)) - for _, sid := range activeServiceIDs { - activeServiceIDSet[sid] = true - } + stopQueryTime := params.Time.In(stopLoc) + allActiveStopTimes, err := api.fetchActiveStopTimesForLocationWindow(ctx, stopCode, stopLoc, stopQueryTime, params) + if err != nil { + api.clientCanceledResponse(w, r, err) + return err + } + if len(allActiveStopTimes) == 0 { + return nil + } - startNanos := stopWindowStart.Sub(serviceMidnight).Nanoseconds() - endNanos := stopWindowEnd.Sub(serviceMidnight).Nanoseconds() - if endNanos < 0 { - continue - } + stopProducedArrival, err := api.buildArrivalsFromLocationStopTimes(ctx, w, r, stopCode, agencyID, allActiveStopTimes, params, stopQueryTime, state) + if err != nil { + return err + } + if stopProducedArrival { + state.stopsWithArrivals[stopCode] = true + } + return nil +} - stopTimes, stErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForStopInWindow(ctx, gtfsdb.GetStopTimesForStopInWindowParams{ - StopID: stopCode, - WindowStartNanos: startNanos, - WindowEndNanos: endNanos, - }) - if stErr != nil { - api.Logger.Warn("failed to query stop times in window", - slog.String("stopID", stopCode), - slog.Any("error", stErr)) - continue - } +func (api *RestAPI) fetchActiveStopTimesForLocationWindow(ctx context.Context, stopCode string, stopLoc *time.Location, stopQueryTime time.Time, params ArrivalsAndDeparturesForLocationParams) ([]activeStopTime, error) { + stopWindowStart := stopQueryTime.Add(-time.Duration(params.MinutesBefore) * time.Minute) + stopWindowEnd := stopQueryTime.Add(time.Duration(params.MinutesAfter) * time.Minute) - for _, st := range stopTimes { - if activeServiceIDSet[st.ServiceID] { - allActiveStopTimes = append(allActiveStopTimes, activeStopTime{ - GetStopTimesForStopInWindowRow: st, - ServiceDate: serviceMidnight, - }) - } - } + var allActiveStopTimes []activeStopTime + for dayOffset := -1; dayOffset <= 1; dayOffset++ { + if ctx.Err() != nil { + return nil, ctx.Err() } + targetDate := stopQueryTime.AddDate(0, 0, dayOffset) + serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, stopLoc) + serviceDateStr := targetDate.Format("20060102") - if len(allActiveStopTimes) == 0 { - // This stop has no arrivals in the window — do not include it in stopIds. + activeServiceIDs, svcErr := api.GtfsManager.GtfsDB.Queries.GetActiveServiceIDsForDate(ctx, serviceDateStr) + if svcErr != nil { + api.Logger.Warn("failed to query active service IDs", slog.String("date", serviceDateStr), slog.Any("error", svcErr)) continue } - - // Batch-fetch routes & trips for this stop's active stop times. - batchRouteIDs := make(map[string]bool) - batchTripIDs := make(map[string]bool) - for _, ast := range allActiveStopTimes { - if ast.RouteID != "" { - batchRouteIDs[ast.RouteID] = true - } - if ast.TripID != "" { - batchTripIDs[ast.TripID] = true - } + if len(activeServiceIDs) == 0 { + continue } - uniqueRouteIDs := stringMapKeys(batchRouteIDs) - uniqueTripIDs := stringMapKeys(batchTripIDs) - - fetchedRoutes, rErr := api.GtfsManager.GtfsDB.Queries.GetRoutesByIDs(ctx, uniqueRouteIDs) - if rErr != nil { - api.Logger.Warn("failed to batch fetch routes", - slog.String("stopID", stopCode), slog.Any("error", rErr)) - continue + activeServiceIDSet := make(map[string]bool, len(activeServiceIDs)) + for _, sid := range activeServiceIDs { + activeServiceIDSet[sid] = true } - fetchedTrips, tErr := api.GtfsManager.GtfsDB.Queries.GetTripsByIDs(ctx, uniqueTripIDs) - if tErr != nil { - api.Logger.Warn("failed to batch fetch trips", - slog.String("stopID", stopCode), slog.Any("error", tErr)) + + startNanos := stopWindowStart.Sub(serviceMidnight).Nanoseconds() + endNanos := stopWindowEnd.Sub(serviceMidnight).Nanoseconds() + if endNanos < 0 { continue } - routesLookup := make(map[string]gtfsdb.Route, len(fetchedRoutes)) - for _, rt := range fetchedRoutes { - routesLookup[rt.ID] = rt - } - tripsLookup := make(map[string]gtfsdb.Trip, len(fetchedTrips)) - for _, tr := range fetchedTrips { - tripsLookup[tr.ID] = tr + stopTimes, stErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForStopInWindow(ctx, gtfsdb.GetStopTimesForStopInWindowParams{ + StopID: stopCode, + WindowStartNanos: startNanos, + WindowEndNanos: endNanos, + }) + if stErr != nil { + api.Logger.Warn("failed to query stop times in window", slog.String("stopID", stopCode), slog.Any("error", stErr)) + continue } - // Batch total-stop-count per trip (avoids N+1 for totalStopsInTrip field). - tripStopCountMap := make(map[string]int, len(uniqueTripIDs)) - if len(uniqueTripIDs) > 0 { - allST, countErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForTripIDs(ctx, uniqueTripIDs) - if countErr != nil { - api.Logger.Warn("failed to batch fetch stop times for trips", slog.Any("error", countErr)) - } else { - for _, st := range allST { - tripStopCountMap[st.TripID]++ - } + for _, st := range stopTimes { + if activeServiceIDSet[st.ServiceID] { + allActiveStopTimes = append(allActiveStopTimes, activeStopTime{ + GetStopTimesForStopInWindowRow: st, + ServiceDate: serviceMidnight, + }) } } + } + return allActiveStopTimes, nil +} - // Build one ArrivalAndDeparture per active stop time. - stopProducedArrival := false - for _, ast := range allActiveStopTimes { - // Respect maxCount mid-stop as well. - if len(arrivals) >= params.MaxCount { - limitExceeded = true - break +func (api *RestAPI) buildArrivalsFromLocationStopTimes( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + stopCode string, + agencyID string, + allActiveStopTimes []activeStopTime, + params ArrivalsAndDeparturesForLocationParams, + stopQueryTime time.Time, + state *locationArrivalsState, +) (bool, error) { + batchRouteIDs := make(map[string]bool) + batchTripIDs := make(map[string]bool) + for _, ast := range allActiveStopTimes { + if ast.RouteID != "" { + batchRouteIDs[ast.RouteID] = true + } + if ast.TripID != "" { + batchTripIDs[ast.TripID] = true + } + } + + uniqueRouteIDs := stringMapKeys(batchRouteIDs) + uniqueTripIDs := stringMapKeys(batchTripIDs) + + fetchedRoutes, rErr := api.GtfsManager.GtfsDB.Queries.GetRoutesByIDs(ctx, uniqueRouteIDs) + if rErr != nil { + api.Logger.Warn("failed to batch fetch routes", slog.String("stopID", stopCode), slog.Any("error", rErr)) + return false, nil + } + fetchedTrips, tErr := api.GtfsManager.GtfsDB.Queries.GetTripsByIDs(ctx, uniqueTripIDs) + if tErr != nil { + api.Logger.Warn("failed to batch fetch trips", slog.String("stopID", stopCode), slog.Any("error", tErr)) + return false, nil + } + + routesLookup := make(map[string]gtfsdb.Route, len(fetchedRoutes)) + for _, rt := range fetchedRoutes { + routesLookup[rt.ID] = rt + } + tripsLookup := make(map[string]gtfsdb.Trip, len(fetchedTrips)) + for _, tr := range fetchedTrips { + tripsLookup[tr.ID] = tr + } + + tripStopCountMap := make(map[string]int, len(uniqueTripIDs)) + if len(uniqueTripIDs) > 0 { + allST, countErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForTripIDs(ctx, uniqueTripIDs) + if countErr != nil { + api.Logger.Warn("failed to batch fetch stop times for trips", slog.Any("error", countErr)) + } else { + for _, st := range allST { + tripStopCountMap[st.TripID]++ } + } + } - if ctx.Err() != nil { - api.clientCanceledResponse(w, r, ctx.Err()) - return - } + stopProducedArrival := false + combinedStopID := utils.FormCombinedID(agencyID, stopCode) - st := ast.GetStopTimesForStopInWindowRow - serviceMidnight := ast.ServiceDate + for _, ast := range allActiveStopTimes { + if len(state.arrivals) >= params.MaxCount { + state.limitExceeded = true + break + } + if ctx.Err() != nil { + api.clientCanceledResponse(w, r, ctx.Err()) + return stopProducedArrival, ctx.Err() + } - route, routeOK := routesLookup[st.RouteID] - if !routeOK { - api.Logger.Debug("skipping stop time: route not found", - slog.String("routeID", st.RouteID), slog.String("tripID", st.TripID)) - continue - } - trip, tripOK := tripsLookup[st.TripID] - if !tripOK { - api.Logger.Debug("skipping stop time: trip not found", - slog.String("tripID", st.TripID)) - continue - } + st := ast.GetStopTimesForStopInWindowRow + serviceMidnight := ast.ServiceDate - rCopy := route - routeIDSet[route.ID] = &rCopy - tCopy := trip - tripIDSet[trip.ID] = &tCopy - - scheduledArrivalTime := serviceMidnight.Add(time.Duration(st.ArrivalTime)) - scheduledDepartureTime := serviceMidnight.Add(time.Duration(st.DepartureTime)) - - var ( - predictedArrivalTime time.Time - predictedDepartureTime time.Time - predicted = false - vehicleID string - tripStatus *models.TripStatus - distanceFromStop = 0.0 - numberOfStopsAway = 0 - lastUpdateTime time.Time - - // FIX #4: derive status from schedule deviation instead of - // always emitting "default". Falls back to "default" when - // there is no real-time data. - arrivalStatus = "default" - ) - - vehicle := api.GtfsManager.GetVehicleForTrip(ctx, st.TripID) - if vehicle != nil && vehicle.Trip != nil && vehicle.ID != nil { - vehicleID = vehicle.ID.ID - } + route, routeOK := routesLookup[st.RouteID] + if !routeOK { + continue + } + trip, tripOK := tripsLookup[st.TripID] + if !tripOK { + continue + } - predArr, predDep, isPredicted := api.getPredictedTimes( - st.TripID, stopCode, int64(st.StopSequence), - scheduledArrivalTime, scheduledDepartureTime, - ) - if isPredicted { - predicted = true - predictedArrivalTime = predArr - predictedDepartureTime = predDep + rCopy := route + state.routeIDSet[route.ID] = &rCopy + tCopy := trip + state.tripIDSet[trip.ID] = &tCopy + + scheduledArrivalTime := serviceMidnight.Add(time.Duration(st.ArrivalTime)) + scheduledDepartureTime := serviceMidnight.Add(time.Duration(st.DepartureTime)) + + var ( + predictedArrivalTime time.Time + predictedDepartureTime time.Time + predicted = false + vehicleID string + tripStatus *models.TripStatus + distanceFromStop = 0.0 + numberOfStopsAway = 0 + lastUpdateTime time.Time + arrivalStatus = "default" + ) + + vehicle := api.GtfsManager.GetVehicleForTrip(ctx, st.TripID) + if vehicle != nil && vehicle.Trip != nil && vehicle.ID != nil { + vehicleID = vehicle.ID.ID + } + + predArr, predDep, isPredicted := api.getPredictedTimes( + st.TripID, stopCode, int64(st.StopSequence), + scheduledArrivalTime, scheduledDepartureTime, + ) + if isPredicted { + predicted = true + predictedArrivalTime = predArr + predictedDepartureTime = predDep + } + + if vehicle != nil { + status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, st.TripID, vehicle, serviceMidnight, stopQueryTime) + if statusErr != nil { + api.Logger.Warn("BuildTripStatus failed", "tripID", st.TripID, "error", statusErr) } - // When not predicted, leave predictedArrivalTime/predictedDepartureTime as zero time.Time - // (matches Java which emits 0 for unpredicted arrivals). - - // Gate BuildTripStatus on vehicle presence — matches the stop handler convention. - if vehicle != nil { - status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, st.TripID, vehicle, serviceMidnight, stopQueryTime) - if statusErr != nil { - api.Logger.Warn("BuildTripStatus failed", - "tripID", st.TripID, "error", statusErr) + if status != nil { + tripStatus = status + + if !predicted && status.Predicted { + dev := time.Duration(status.ScheduleDeviation) * time.Second + predictedArrivalTime = scheduledArrivalTime.Add(dev) + predictedDepartureTime = scheduledDepartureTime.Add(dev) + predicted = true } - if status != nil { - tripStatus = status - - // Block-based prediction propagation: when the GTFS-RT feed only - // has data for the active (preceding) block trip, getPredictedTimes - // above returns isPredicted=false for the scheduled (future) trip. - // If tripStatus.Predicted is true, the vehicle IS tracked — apply - // the active trip's schedule deviation to the scheduled times. - // This matches Java OBA's ArrivalAndDepartureServiceImpl which - // propagates block-level delay to future block trips. - if !predicted && status.Predicted { - dev := time.Duration(status.ScheduleDeviation) * time.Second - predictedArrivalTime = scheduledArrivalTime.Add(dev) - predictedDepartureTime = scheduledDepartureTime.Add(dev) - predicted = true - } - // status field: derive from deviation only when predicted (direct or block-propagated). - if predicted { - arrivalStatus = arrivalStatusFromDeviation(status.ScheduleDeviation) - } + if predicted { + arrivalStatus = arrivalStatusFromDeviation(status.ScheduleDeviation) + } - // Collect stops referenced in trip status for the references block. - if status.NextStop != "" { - if nsAgency, nsID, nsErr := utils.ExtractAgencyIDAndCodeID(status.NextStop); nsErr == nil { - stopIDSet[nsID] = true - if nsAgency != "" { - stopAgencyOverride[nsID] = nsAgency - } + if status.NextStop != "" { + if nsAgency, nsID, nsErr := utils.ExtractAgencyIDAndCodeID(status.NextStop); nsErr == nil { + state.stopIDSet[nsID] = true + if nsAgency != "" { + state.stopAgencyOverride[nsID] = nsAgency } } - if status.ClosestStop != "" { - if csAgency, csID, csErr := utils.ExtractAgencyIDAndCodeID(status.ClosestStop); csErr == nil { - stopIDSet[csID] = true - if csAgency != "" { - stopAgencyOverride[csID] = csAgency - } + } + if status.ClosestStop != "" { + if csAgency, csID, csErr := utils.ExtractAgencyIDAndCodeID(status.ClosestStop); csErr == nil { + state.stopIDSet[csID] = true + if csAgency != "" { + state.stopAgencyOverride[csID] = csAgency } } + } - if vehicle.Position != nil { - distanceFromStop = api.getBlockDistanceToStop(ctx, st.TripID, stopCode, vehicle, stopQueryTime) - nsa := api.getNumberOfStopsAway(ctx, st.TripID, int(st.StopSequence), vehicle, stopQueryTime) - if nsa != nil { - numberOfStopsAway = *nsa - } else { - numberOfStopsAway = -1 - } + if vehicle.Position != nil { + distanceFromStop = api.getBlockDistanceToStop(ctx, st.TripID, stopCode, vehicle, stopQueryTime) + nsa := api.getNumberOfStopsAway(ctx, st.TripID, int(st.StopSequence), vehicle, stopQueryTime) + if nsa != nil { + numberOfStopsAway = *nsa + } else { + numberOfStopsAway = -1 } + } - // Ensure the active trip (if different from scheduled) is in references. - // Also override tripStatus.BlockTripSequence to use the ACTIVE trip's - // block sequence — BuildTripStatus computes it from the scheduled (target) - // tripID, but Java OBA emits the active trip's sequence on the wire. - if status.ActiveTripID != "" { - if _, atID, atErr := utils.ExtractAgencyIDAndCodeID(status.ActiveTripID); atErr == nil { - // Override BlockTripSequence to the active trip's sequence. - if activeSeq := api.calculateBlockTripSequence(ctx, atID, serviceMidnight); activeSeq > 0 { - status.BlockTripSequence = activeSeq - } + if status.ActiveTripID != "" { + if _, atID, atErr := utils.ExtractAgencyIDAndCodeID(status.ActiveTripID); atErr == nil { + if activeSeq := api.calculateBlockTripSequence(ctx, atID, serviceMidnight); activeSeq > 0 { + status.BlockTripSequence = activeSeq + } - if atID != st.TripID { - if _, exists := tripIDSet[atID]; !exists { - if at, atFetchErr := api.GtfsManager.GtfsDB.Queries.GetTrip(ctx, atID); atFetchErr == nil { - atCopy := at - tripIDSet[at.ID] = &atCopy - if ar, arFetchErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, at.RouteID); arFetchErr == nil { - arCopy := ar - routeIDSet[ar.ID] = &arCopy - } + if atID != st.TripID { + if _, exists := state.tripIDSet[atID]; !exists { + if at, atFetchErr := api.GtfsManager.GtfsDB.Queries.GetTrip(ctx, atID); atFetchErr == nil { + atCopy := at + state.tripIDSet[at.ID] = &atCopy + if ar, arFetchErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, at.RouteID); arFetchErr == nil { + arCopy := ar + state.routeIDSet[ar.ID] = &arCopy } } } } } } - - lastUpdateTime = api.GtfsManager.GetVehicleLastUpdateTime(vehicle) } + lastUpdateTime = api.GtfsManager.GetVehicleLastUpdateTime(vehicle) + } - totalStopsInTrip := tripStopCountMap[st.TripID] - blockTripSequence := api.calculateBlockTripSequence(ctx, st.TripID, serviceMidnight) - - // alert.ID from GTFS-RT already contains the agency prefix (e.g. "40_16931"). - // Do NOT wrap with FormCombinedID — that would double-prefix to "40_40_16931". - // Both per-arrival situationIds and top-level situationIds use the raw alert.ID. - tripAlerts := api.GtfsManager.GetAlertsForTrip(ctx, st.TripID) - situationIDs := make([]string, 0, len(tripAlerts)) - for _, alert := range tripAlerts { - if alert.ID == "" { - continue - } - situationIDs = append(situationIDs, alert.ID) - if _, seen := collectedAlerts[alert.ID]; !seen { - collectedAlerts[alert.ID] = alert - } - } + totalStopsInTrip := tripStopCountMap[st.TripID] + blockTripSequence := api.calculateBlockTripSequence(ctx, st.TripID, serviceMidnight) - // vehicleID must carry the agency prefix to match Java output ("1_6853"). - formattedVehicleID := "" - if vehicleID != "" { - formattedVehicleID = utils.FormCombinedID(route.AgencyID, vehicleID) + tripAlerts := api.GtfsManager.GetAlertsForTrip(ctx, st.TripID) + situationIDs := make([]string, 0, len(tripAlerts)) + for _, alert := range tripAlerts { + if alert.ID == "" { + continue } + situationIDs = append(situationIDs, alert.ID) + if _, seen := state.collectedAlerts[alert.ID]; !seen { + state.collectedAlerts[alert.ID] = alert + } + } + + formattedVehicleID := "" + if vehicleID != "" { + formattedVehicleID = utils.FormCombinedID(route.AgencyID, vehicleID) + } + + rawStopSequence := int(st.StopSequence) - 1 + + state.arrivals = append(state.arrivals, *models.NewArrivalAndDeparture( + utils.FormCombinedID(route.AgencyID, route.ID), + route.ShortName.String, + route.LongName.String, + utils.FormCombinedID(route.AgencyID, st.TripID), + st.TripHeadsign.String, + combinedStopID, + formattedVehicleID, + serviceMidnight, + scheduledArrivalTime, + scheduledDepartureTime, + predictedArrivalTime, + predictedDepartureTime, + lastUpdateTime, + predicted, + true, + true, + rawStopSequence, + totalStopsInTrip, + numberOfStopsAway, + blockTripSequence, + distanceFromStop, + arrivalStatus, + "", "", "", + tripStatus, + situationIDs, + )) + stopProducedArrival = true + } + + return stopProducedArrival, nil +} - // stopSequence is zero-based on the wire (matching Java OBA and the - // arrivals-and-departures-for-stop handler which uses StopSequence-1). - rawStopSequence := int(st.StopSequence) - 1 - - arrivals = append(arrivals, *models.NewArrivalAndDeparture( - utils.FormCombinedID(route.AgencyID, route.ID), // routeID - route.ShortName.String, // routeShortName - route.LongName.String, // routeLongName - utils.FormCombinedID(route.AgencyID, st.TripID), // tripID - st.TripHeadsign.String, // tripHeadsign - combinedStopID, // stopID - formattedVehicleID, // vehicleID (agency-prefixed or empty) - serviceMidnight, // serviceDate - scheduledArrivalTime, // scheduledArrivalTime - scheduledDepartureTime, // scheduledDepartureTime - predictedArrivalTime, // predictedArrivalTime (zero when unpredicted) - predictedDepartureTime, // predictedDepartureTime (zero when unpredicted) - lastUpdateTime, // lastUpdateTime - predicted, // predicted - true, // arrivalEnabled - true, // departureEnabled - rawStopSequence, // stopSequence (zero-based, matching stop handler) - totalStopsInTrip, // totalStopsInTrip - numberOfStopsAway, // numberOfStopsAway - blockTripSequence, // blockTripSequence - distanceFromStop, // distanceFromStop - arrivalStatus, // FIX #4: derived from scheduleDeviation - "", // occupancyStatus - "", // predictedOccupancy - "", // historicalOccupancy - tripStatus, // tripStatus - situationIDs, // situationIDs (agency-prefixed) - )) - stopProducedArrival = true - } - - if stopProducedArrival { - stopsWithArrivals[stopCode] = true - } - } - - // Sort arrivals by predicted (or scheduled) arrival time ascending. - // Matches Java's ArrivalAndDepartureComparator. +func (api *RestAPI) sortLocationArrivalsByTime(arrivals []models.ArrivalAndDeparture) { sort.Slice(arrivals, func(i, j int) bool { ai := arrivals[i] aj := arrivals[j] @@ -715,24 +740,23 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite } return ti.Before(tj) }) +} - // Build references block (agencies, routes, stops, trips, situations). +func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *locationArrivalsState) (*models.ReferencesModel, []string) { references := models.NewEmptyReferences() addedAgencyIDs := make(map[string]bool) - // Trips - for _, trip := range tripIDSet { - routeForTrip, ok := routeIDSet[trip.RouteID] + for _, trip := range state.tripIDSet { + routeForTrip, ok := state.routeIDSet[trip.RouteID] if !ok { - fetched, fErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, trip.RouteID) - if fErr != nil { - api.Logger.Warn("failed to fetch route for trip reference", - "tripID", trip.ID, "routeID", trip.RouteID, "error", fErr) + if fetched, fErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, trip.RouteID); fErr == nil { + fCopy := fetched + state.routeIDSet[fetched.ID] = &fCopy + routeForTrip = &fCopy + } else { + api.Logger.Warn("failed to fetch route for trip reference", "tripID", trip.ID, "routeID", trip.RouteID) continue } - fCopy := fetched - routeIDSet[fetched.ID] = &fCopy - routeForTrip = &fCopy } references.Trips = append(references.Trips, *models.NewTripReference( utils.FormCombinedID(routeForTrip.AgencyID, trip.ID), @@ -746,8 +770,7 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite )) } - // Routes + their agencies. - for _, route := range routeIDSet { + for _, route := range state.routeIDSet { references.Routes = append(references.Routes, models.NewRoute( utils.FormCombinedID(route.AgencyID, route.ID), route.AgencyID, @@ -760,34 +783,19 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite route.TextColor.String, )) if !addedAgencyIDs[route.AgencyID] { - ag, agErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, route.AgencyID) - if agErr == nil { + if ag, agErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, route.AgencyID); agErr == nil { references.Agencies = append(references.Agencies, models.NewAgencyReference( ag.ID, ag.Name, ag.Url, ag.Timezone, ag.Lang.String, ag.Phone.String, ag.Email.String, ag.FareUrl.String, "", false, )) addedAgencyIDs[ag.ID] = true - } else { - api.Logger.Warn("failed to fetch agency for reference", - "agencyID", route.AgencyID, "error", agErr) } } } - // Stops (queried stops + nextStop/closestStop referenced by TripStatus). - stopIDsSlice := make([]string, 0, len(stopIDSet)) - for sid := range stopIDSet { - stopIDsSlice = append(stopIDsSlice, sid) - } - - batchStops, bsErr := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDsSlice) - if bsErr != nil { - api.Logger.Warn("failed to batch fetch stop references", slog.Any("error", bsErr)) - } - batchRoutesForStops, brsErr := api.GtfsManager.GtfsDB.Queries.GetRoutesForStops(ctx, stopIDsSlice) - if brsErr != nil { - api.Logger.Warn("failed to batch fetch routes for stops", slog.Any("error", brsErr)) - } + stopIDsSlice := stringMapKeys(state.stopIDSet) + batchStops, _ := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDsSlice) + batchRoutesForStops, _ := api.GtfsManager.GtfsDB.Queries.GetRoutesForStops(ctx, stopIDsSlice) stopsMap := make(map[string]gtfsdb.Stop, len(batchStops)) for _, s := range batchStops { @@ -799,26 +807,22 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite } for _, sid := range stopIDsSlice { - if ctx.Err() != nil { - api.clientCanceledResponse(w, r, ctx.Err()) - return - } stopData, ok := stopsMap[sid] if !ok { continue } - ag := stopAgencyMap[sid] + ag := state.stopAgencyMap[sid] if ag == "" { - ag = stopAgencyOverride[sid] + ag = state.stopAgencyOverride[sid] } if ag == "" { - ag = fallbackAgencyID + ag = state.fallbackAgencyID } routesForStop := routesByStop[sid] combinedRouteIDs := make([]string, len(routesForStop)) for i, rr := range routesForStop { combinedRouteIDs[i] = utils.FormCombinedID(rr.AgencyID, rr.ID) - if _, exists := routeIDSet[rr.ID]; !exists { + if _, exists := state.routeIDSet[rr.ID]; !exists { rc := gtfsdb.Route{ ID: rr.ID, AgencyID: rr.AgencyID, @@ -830,7 +834,7 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite Color: rr.Color, TextColor: rr.TextColor, } - routeIDSet[rr.ID] = &rc + state.routeIDSet[rr.ID] = &rc } } references.Stops = append(references.Stops, models.Stop{ @@ -847,72 +851,35 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite }) } - // Collect stop-level service alerts. - // These fall back to fallbackAgencyID for the agency prefix since there is - // no route context available at the stop level. - for _, sc := range rawStopCodes { - for _, alert := range api.GtfsManager.GetAlertsForStop(sc) { - if alert.ID != "" { - if _, seen := collectedAlerts[alert.ID]; !seen { - collectedAlerts[alert.ID] = alert - } - } - } - } - - // Build situation references and top-level situationIds. - // Entry-level situationIds use the raw alert ID (e.g. "1_85725", "40_16559"). - // Alert IDs from GTFS-RT already contain the agency prefix, so no extra - // FormCombinedID wrapping is applied here. - // Per-arrival situationIds DO wrap with FormCombinedID — that is separate. - topLevelSituationIDs := make([]string, 0, len(collectedAlerts)) - if len(collectedAlerts) > 0 { - alertSlice := make([]gtfs.Alert, 0, len(collectedAlerts)) - for _, a := range collectedAlerts { + topLevelSituationIDs := make([]string, 0, len(state.collectedAlerts)) + if len(state.collectedAlerts) > 0 { + alertSlice := make([]gtfs.Alert, 0, len(state.collectedAlerts)) + for alertID, a := range state.collectedAlerts { alertSlice = append(alertSlice, a) - } - references.Situations = append(references.Situations, api.BuildSituationReferences(alertSlice)...) - for alertID := range collectedAlerts { topLevelSituationIDs = append(topLevelSituationIDs, alertID) } + references.Situations = append(references.Situations, api.BuildSituationReferences(alertSlice)...) } - // Build the entry's stopIds — only stops that produced at least one arrival, - // in the same order as the original stops slice (deterministic, not map order). - // Java: stops are only added when !arrivalsAndDepartures.isEmpty(). - queriedStopIDs := make([]string, 0, len(stopsWithArrivals)) + return references, topLevelSituationIDs +} + +func (api *RestAPI) buildLocationQueriedStopIDs(stops []gtfsdb.Stop, state *locationArrivalsState) []string { + queriedStopIDs := make([]string, 0, len(state.stopsWithArrivals)) for _, dbStop := range stops { - if stopsWithArrivals[dbStop.ID] { - ag := stopAgencyMap[dbStop.ID] + if state.stopsWithArrivals[dbStop.ID] { + ag := state.stopAgencyMap[dbStop.ID] if ag == "" { - ag = fallbackAgencyID + ag = state.fallbackAgencyID } queriedStopIDs = append(queriedStopIDs, utils.FormCombinedID(ag, dbStop.ID)) } } - - // Build nearbyStopIds as []StopWithDistance. - // - // FIX #3: Java's includeInputIdsInNearby is overridden to true in this endpoint, - // so we DO NOT exclude queried stops from the nearby stops list. - nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon, params.Time) - - api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( - arrivals, - *references, - nearbyStops, - topLevelSituationIDs, - queriedStopIDs, - limitExceeded, - api.Clock, - )) + return queriedStopIDs } // getLocationNearbyStops returns stops near the query centre together with their // distance from the centre, sorted ascending by distance. -// -// Java equivalent: getNearbyStops() in StopWithArrivalsAndDeparturesBeanServiceImpl, -// which calls SphericalGeometryLibrary.distance() to populate distanceFromQuery. func getLocationNearbyStops( api *RestAPI, ctx context.Context, @@ -936,7 +903,6 @@ func getLocationNearbyStops( return nil } - // Batch-resolve owning agency for each nearby stop. candidateIDs := make([]string, len(nearby)) for i, s := range nearby { candidateIDs[i] = s.ID @@ -954,7 +920,6 @@ func getLocationNearbyStops( } } - // pickPrimaryAgency over the nearby set for stops with no resolved agency. nearbyFallback := pickPrimaryAgency(nearbyAgencyMap) result := make([]models.StopWithDistance, 0, len(nearby)) @@ -976,7 +941,6 @@ func getLocationNearbyStops( return nil } - // Sort by distance ascending to match Java's ordering. sort.Slice(result, func(i, j int) bool { return result[i].DistanceFromQuery < result[j].DistanceFromQuery }) From 96eab9c4fa2f9d0c677a62923d9160f59e455230 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 19:35:22 +0530 Subject: [PATCH 13/34] Refactor buildArrivalsFromLocationStopTimes to reduce cognitive complexity below 15 --- .../arrivals_and_departures_for_location.go | 326 ++++++++++-------- 1 file changed, 187 insertions(+), 139 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 69bd21b7..c733858f 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -73,6 +73,25 @@ func newLocationArrivalsState() *locationArrivalsState { } } +type arrivalContext struct { + st gtfsdb.GetStopTimesForStopInWindowRow + serviceMidnight time.Time + scheduledArrivalTime time.Time + scheduledDepartureTime time.Time + predictedArrivalTime time.Time + predictedDepartureTime time.Time + predicted bool + vehicleID string + tripStatus *models.TripStatus + distanceFromStop float64 + numberOfStopsAway int + lastUpdateTime time.Time + arrivalStatus string + totalStopsInTrip int + blockTripSequence int + situationIDs []string +} + // Error message constants shared by the parameter-parsing helpers below. const ( errMustBeValidInteger = "must be a valid integer" @@ -552,7 +571,6 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( } st := ast.GetStopTimesForStopInWindowRow - serviceMidnight := ast.ServiceDate route, routeOK := routesLookup[st.RouteID] if !routeOK { @@ -568,159 +586,189 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( tCopy := trip state.tripIDSet[trip.ID] = &tCopy - scheduledArrivalTime := serviceMidnight.Add(time.Duration(st.ArrivalTime)) - scheduledDepartureTime := serviceMidnight.Add(time.Duration(st.DepartureTime)) - - var ( - predictedArrivalTime time.Time - predictedDepartureTime time.Time - predicted = false - vehicleID string - tripStatus *models.TripStatus - distanceFromStop = 0.0 - numberOfStopsAway = 0 - lastUpdateTime time.Time - arrivalStatus = "default" - ) - - vehicle := api.GtfsManager.GetVehicleForTrip(ctx, st.TripID) - if vehicle != nil && vehicle.Trip != nil && vehicle.ID != nil { - vehicleID = vehicle.ID.ID - } - - predArr, predDep, isPredicted := api.getPredictedTimes( - st.TripID, stopCode, int64(st.StopSequence), - scheduledArrivalTime, scheduledDepartureTime, - ) - if isPredicted { - predicted = true - predictedArrivalTime = predArr - predictedDepartureTime = predDep - } - - if vehicle != nil { - status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, st.TripID, vehicle, serviceMidnight, stopQueryTime) - if statusErr != nil { - api.Logger.Warn("BuildTripStatus failed", "tripID", st.TripID, "error", statusErr) - } - if status != nil { - tripStatus = status - - if !predicted && status.Predicted { - dev := time.Duration(status.ScheduleDeviation) * time.Second - predictedArrivalTime = scheduledArrivalTime.Add(dev) - predictedDepartureTime = scheduledDepartureTime.Add(dev) - predicted = true - } + api.buildSingleArrival(ctx, stopCode, combinedStopID, ast, stopQueryTime, state, route, tripStopCountMap[st.TripID]) + stopProducedArrival = true + } - if predicted { - arrivalStatus = arrivalStatusFromDeviation(status.ScheduleDeviation) - } + return stopProducedArrival, nil +} - if status.NextStop != "" { - if nsAgency, nsID, nsErr := utils.ExtractAgencyIDAndCodeID(status.NextStop); nsErr == nil { - state.stopIDSet[nsID] = true - if nsAgency != "" { - state.stopAgencyOverride[nsID] = nsAgency - } - } - } - if status.ClosestStop != "" { - if csAgency, csID, csErr := utils.ExtractAgencyIDAndCodeID(status.ClosestStop); csErr == nil { - state.stopIDSet[csID] = true - if csAgency != "" { - state.stopAgencyOverride[csID] = csAgency - } - } - } +func (api *RestAPI) buildSingleArrival( + ctx context.Context, + stopCode string, + combinedStopID string, + ast activeStopTime, + stopQueryTime time.Time, + state *locationArrivalsState, + route gtfsdb.Route, + totalStopsInTrip int, +) { + st := ast.GetStopTimesForStopInWindowRow + ac := &arrivalContext{ + st: st, + serviceMidnight: ast.ServiceDate, + totalStopsInTrip: totalStopsInTrip, + arrivalStatus: "default", + } - if vehicle.Position != nil { - distanceFromStop = api.getBlockDistanceToStop(ctx, st.TripID, stopCode, vehicle, stopQueryTime) - nsa := api.getNumberOfStopsAway(ctx, st.TripID, int(st.StopSequence), vehicle, stopQueryTime) - if nsa != nil { - numberOfStopsAway = *nsa - } else { - numberOfStopsAway = -1 - } - } + ac.scheduledArrivalTime = ac.serviceMidnight.Add(time.Duration(ac.st.ArrivalTime)) + ac.scheduledDepartureTime = ac.serviceMidnight.Add(time.Duration(ac.st.DepartureTime)) - if status.ActiveTripID != "" { - if _, atID, atErr := utils.ExtractAgencyIDAndCodeID(status.ActiveTripID); atErr == nil { - if activeSeq := api.calculateBlockTripSequence(ctx, atID, serviceMidnight); activeSeq > 0 { - status.BlockTripSequence = activeSeq - } - - if atID != st.TripID { - if _, exists := state.tripIDSet[atID]; !exists { - if at, atFetchErr := api.GtfsManager.GtfsDB.Queries.GetTrip(ctx, atID); atFetchErr == nil { - atCopy := at - state.tripIDSet[at.ID] = &atCopy - if ar, arFetchErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, at.RouteID); arFetchErr == nil { - arCopy := ar - state.routeIDSet[ar.ID] = &arCopy - } - } - } - } - } - } - } - lastUpdateTime = api.GtfsManager.GetVehicleLastUpdateTime(vehicle) + vehicle := api.GtfsManager.GetVehicleForTrip(ctx, ac.st.TripID) + if vehicle != nil && vehicle.Trip != nil && vehicle.ID != nil { + ac.vehicleID = vehicle.ID.ID + } + + api.applyPredictedTimes(ac, stopCode) + + if vehicle != nil { + api.applyTripStatus(ctx, ac, route, vehicle, stopQueryTime, stopCode, state) + } + + ac.blockTripSequence = api.calculateBlockTripSequence(ctx, ac.st.TripID, ac.serviceMidnight) + api.applyAlerts(ctx, ac, state) + + formattedVehicleID := "" + if ac.vehicleID != "" { + formattedVehicleID = utils.FormCombinedID(route.AgencyID, ac.vehicleID) + } + + rawStopSequence := int(ac.st.StopSequence) - 1 + + state.arrivals = append(state.arrivals, *models.NewArrivalAndDeparture( + utils.FormCombinedID(route.AgencyID, route.ID), + route.ShortName.String, + route.LongName.String, + utils.FormCombinedID(route.AgencyID, ac.st.TripID), + ac.st.TripHeadsign.String, + combinedStopID, + formattedVehicleID, + ac.serviceMidnight, + ac.scheduledArrivalTime, + ac.scheduledDepartureTime, + ac.predictedArrivalTime, + ac.predictedDepartureTime, + ac.lastUpdateTime, + ac.predicted, + true, + true, + rawStopSequence, + ac.totalStopsInTrip, + ac.numberOfStopsAway, + ac.blockTripSequence, + ac.distanceFromStop, + ac.arrivalStatus, + "", "", "", + ac.tripStatus, + ac.situationIDs, + )) +} + +func (api *RestAPI) applyPredictedTimes(ac *arrivalContext, stopCode string) { + predArr, predDep, isPredicted := api.getPredictedTimes( + ac.st.TripID, stopCode, int64(ac.st.StopSequence), + ac.scheduledArrivalTime, ac.scheduledDepartureTime, + ) + if isPredicted { + ac.predicted = true + ac.predictedArrivalTime = predArr + ac.predictedDepartureTime = predDep + } +} + +func (api *RestAPI) applyTripStatus(ctx context.Context, ac *arrivalContext, route gtfsdb.Route, vehicle *gtfs.Vehicle, stopQueryTime time.Time, stopCode string, state *locationArrivalsState) { + status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, ac.st.TripID, vehicle, ac.serviceMidnight, stopQueryTime) + if statusErr != nil { + api.Logger.Warn("BuildTripStatus failed", "tripID", ac.st.TripID, "error", statusErr) + } + if status != nil { + ac.tripStatus = status + + if !ac.predicted && status.Predicted { + dev := time.Duration(status.ScheduleDeviation) * time.Second + ac.predictedArrivalTime = ac.scheduledArrivalTime.Add(dev) + ac.predictedDepartureTime = ac.scheduledDepartureTime.Add(dev) + ac.predicted = true } - totalStopsInTrip := tripStopCountMap[st.TripID] - blockTripSequence := api.calculateBlockTripSequence(ctx, st.TripID, serviceMidnight) + if ac.predicted { + ac.arrivalStatus = arrivalStatusFromDeviation(status.ScheduleDeviation) + } - tripAlerts := api.GtfsManager.GetAlertsForTrip(ctx, st.TripID) - situationIDs := make([]string, 0, len(tripAlerts)) - for _, alert := range tripAlerts { - if alert.ID == "" { - continue - } - situationIDs = append(situationIDs, alert.ID) - if _, seen := state.collectedAlerts[alert.ID]; !seen { - state.collectedAlerts[alert.ID] = alert + api.applyTripStatusStops(ac, state) + + if vehicle.Position != nil { + ac.distanceFromStop = api.getBlockDistanceToStop(ctx, ac.st.TripID, stopCode, vehicle, stopQueryTime) + nsa := api.getNumberOfStopsAway(ctx, ac.st.TripID, int(ac.st.StopSequence), vehicle, stopQueryTime) + if nsa != nil { + ac.numberOfStopsAway = *nsa + } else { + ac.numberOfStopsAway = -1 } } - formattedVehicleID := "" - if vehicleID != "" { - formattedVehicleID = utils.FormCombinedID(route.AgencyID, vehicleID) + api.applyActiveTrip(ctx, ac, state) + } + ac.lastUpdateTime = api.GtfsManager.GetVehicleLastUpdateTime(vehicle) +} + +func (api *RestAPI) applyTripStatusStops(ac *arrivalContext, state *locationArrivalsState) { + if ac.tripStatus.NextStop != "" { + if nsAgency, nsID, nsErr := utils.ExtractAgencyIDAndCodeID(ac.tripStatus.NextStop); nsErr == nil { + state.stopIDSet[nsID] = true + if nsAgency != "" { + state.stopAgencyOverride[nsID] = nsAgency + } + } + } + if ac.tripStatus.ClosestStop != "" { + if csAgency, csID, csErr := utils.ExtractAgencyIDAndCodeID(ac.tripStatus.ClosestStop); csErr == nil { + state.stopIDSet[csID] = true + if csAgency != "" { + state.stopAgencyOverride[csID] = csAgency + } } + } +} - rawStopSequence := int(st.StopSequence) - 1 +func (api *RestAPI) applyActiveTrip(ctx context.Context, ac *arrivalContext, state *locationArrivalsState) { + if ac.tripStatus.ActiveTripID == "" { + return + } + _, atID, atErr := utils.ExtractAgencyIDAndCodeID(ac.tripStatus.ActiveTripID) + if atErr != nil { + return + } + if activeSeq := api.calculateBlockTripSequence(ctx, atID, ac.serviceMidnight); activeSeq > 0 { + ac.tripStatus.BlockTripSequence = activeSeq + } - state.arrivals = append(state.arrivals, *models.NewArrivalAndDeparture( - utils.FormCombinedID(route.AgencyID, route.ID), - route.ShortName.String, - route.LongName.String, - utils.FormCombinedID(route.AgencyID, st.TripID), - st.TripHeadsign.String, - combinedStopID, - formattedVehicleID, - serviceMidnight, - scheduledArrivalTime, - scheduledDepartureTime, - predictedArrivalTime, - predictedDepartureTime, - lastUpdateTime, - predicted, - true, - true, - rawStopSequence, - totalStopsInTrip, - numberOfStopsAway, - blockTripSequence, - distanceFromStop, - arrivalStatus, - "", "", "", - tripStatus, - situationIDs, - )) - stopProducedArrival = true + if atID != ac.st.TripID { + if _, exists := state.tripIDSet[atID]; !exists { + if at, atFetchErr := api.GtfsManager.GtfsDB.Queries.GetTrip(ctx, atID); atFetchErr == nil { + atCopy := at + state.tripIDSet[at.ID] = &atCopy + if ar, arFetchErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, at.RouteID); arFetchErr == nil { + arCopy := ar + state.routeIDSet[ar.ID] = &arCopy + } + } + } } +} - return stopProducedArrival, nil +func (api *RestAPI) applyAlerts(ctx context.Context, ac *arrivalContext, state *locationArrivalsState) { + tripAlerts := api.GtfsManager.GetAlertsForTrip(ctx, ac.st.TripID) + ac.situationIDs = make([]string, 0, len(tripAlerts)) + for _, alert := range tripAlerts { + if alert.ID == "" { + continue + } + ac.situationIDs = append(ac.situationIDs, alert.ID) + if _, seen := state.collectedAlerts[alert.ID]; !seen { + state.collectedAlerts[alert.ID] = alert + } + } } func (api *RestAPI) sortLocationArrivalsByTime(arrivals []models.ArrivalAndDeparture) { From 1cd4a6c7f891088889b98fabdab21b7b85661c1a Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 19:43:32 +0530 Subject: [PATCH 14/34] Refactor location handlers to reduce Cognitive Complexity below 15 --- .../arrivals_and_departures_for_location.go | 202 +++++++++++------- 1 file changed, 119 insertions(+), 83 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index c733858f..e2b6d604 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -332,20 +332,7 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite api.sortLocationArrivalsByTime(state.arrivals) - // Collect stop-level service alerts. - rawStopCodes := make([]string, 0, len(stops)) - for _, s := range stops { - rawStopCodes = append(rawStopCodes, s.ID) - } - for _, sc := range rawStopCodes { - for _, alert := range api.GtfsManager.GetAlertsForStop(sc) { - if alert.ID != "" { - if _, seen := state.collectedAlerts[alert.ID]; !seen { - state.collectedAlerts[alert.ID] = alert - } - } - } - } + api.collectStopLevelAlerts(stops, state) references, topLevelSituationIDs := api.buildLocationReferencesBlock(ctx, state) queriedStopIDs := api.buildLocationQueriedStopIDs(stops, state) @@ -378,6 +365,22 @@ func (api *RestAPI) handleEmptyStopsResponseForLocation(w http.ResponseWriter, r )) } +func (api *RestAPI) collectStopLevelAlerts(stops []gtfsdb.Stop, state *locationArrivalsState) { + rawStopCodes := make([]string, 0, len(stops)) + for _, s := range stops { + rawStopCodes = append(rawStopCodes, s.ID) + } + for _, sc := range rawStopCodes { + for _, alert := range api.GtfsManager.GetAlertsForStop(sc) { + if alert.ID != "" { + if _, seen := state.collectedAlerts[alert.ID]; !seen { + state.collectedAlerts[alert.ID] = alert + } + } + } + } +} + func (api *RestAPI) resolveAgenciesForStopsLocation(ctx context.Context, stops []gtfsdb.Stop, state *locationArrivalsState) error { rawStopCodes := make([]string, 0, len(stops)) for _, s := range stops { @@ -451,66 +454,17 @@ func (api *RestAPI) fetchActiveStopTimesForLocationWindow(ctx context.Context, s var allActiveStopTimes []activeStopTime for dayOffset := -1; dayOffset <= 1; dayOffset++ { - if ctx.Err() != nil { - return nil, ctx.Err() - } - targetDate := stopQueryTime.AddDate(0, 0, dayOffset) - serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, stopLoc) - serviceDateStr := targetDate.Format("20060102") - - activeServiceIDs, svcErr := api.GtfsManager.GtfsDB.Queries.GetActiveServiceIDsForDate(ctx, serviceDateStr) - if svcErr != nil { - api.Logger.Warn("failed to query active service IDs", slog.String("date", serviceDateStr), slog.Any("error", svcErr)) - continue - } - if len(activeServiceIDs) == 0 { - continue - } - - activeServiceIDSet := make(map[string]bool, len(activeServiceIDs)) - for _, sid := range activeServiceIDs { - activeServiceIDSet[sid] = true - } - - startNanos := stopWindowStart.Sub(serviceMidnight).Nanoseconds() - endNanos := stopWindowEnd.Sub(serviceMidnight).Nanoseconds() - if endNanos < 0 { - continue - } - - stopTimes, stErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForStopInWindow(ctx, gtfsdb.GetStopTimesForStopInWindowParams{ - StopID: stopCode, - WindowStartNanos: startNanos, - WindowEndNanos: endNanos, - }) - if stErr != nil { - api.Logger.Warn("failed to query stop times in window", slog.String("stopID", stopCode), slog.Any("error", stErr)) - continue - } - - for _, st := range stopTimes { - if activeServiceIDSet[st.ServiceID] { - allActiveStopTimes = append(allActiveStopTimes, activeStopTime{ - GetStopTimesForStopInWindowRow: st, - ServiceDate: serviceMidnight, - }) - } + err := api.fetchStopTimesForDayOffset(ctx, stopCode, stopLoc, stopQueryTime, params, dayOffset, stopWindowStart, stopWindowEnd, &allActiveStopTimes) + if err != nil { + return nil, err } } return allActiveStopTimes, nil } -func (api *RestAPI) buildArrivalsFromLocationStopTimes( - ctx context.Context, - w http.ResponseWriter, - r *http.Request, - stopCode string, - agencyID string, - allActiveStopTimes []activeStopTime, - params ArrivalsAndDeparturesForLocationParams, - stopQueryTime time.Time, - state *locationArrivalsState, -) (bool, error) { +func (api *RestAPI) batchFetchLocationRoutesAndTrips( + ctx context.Context, stopCode string, allActiveStopTimes []activeStopTime, +) (map[string]gtfsdb.Route, map[string]gtfsdb.Trip, map[string]int, error) { batchRouteIDs := make(map[string]bool) batchTripIDs := make(map[string]bool) for _, ast := range allActiveStopTimes { @@ -528,12 +482,12 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( fetchedRoutes, rErr := api.GtfsManager.GtfsDB.Queries.GetRoutesByIDs(ctx, uniqueRouteIDs) if rErr != nil { api.Logger.Warn("failed to batch fetch routes", slog.String("stopID", stopCode), slog.Any("error", rErr)) - return false, nil + return nil, nil, nil, rErr } fetchedTrips, tErr := api.GtfsManager.GtfsDB.Queries.GetTripsByIDs(ctx, uniqueTripIDs) if tErr != nil { api.Logger.Warn("failed to batch fetch trips", slog.String("stopID", stopCode), slog.Any("error", tErr)) - return false, nil + return nil, nil, nil, tErr } routesLookup := make(map[string]gtfsdb.Route, len(fetchedRoutes)) @@ -556,6 +510,78 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( } } } + return routesLookup, tripsLookup, tripStopCountMap, nil +} + +func (api *RestAPI) fetchStopTimesForDayOffset( + ctx context.Context, stopCode string, stopLoc *time.Location, + stopQueryTime time.Time, params ArrivalsAndDeparturesForLocationParams, + dayOffset int, stopWindowStart, stopWindowEnd time.Time, + allActiveStopTimes *[]activeStopTime, +) error { + if ctx.Err() != nil { + return ctx.Err() + } + targetDate := stopQueryTime.AddDate(0, 0, dayOffset) + serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, stopLoc) + serviceDateStr := targetDate.Format("20060102") + + activeServiceIDs, svcErr := api.GtfsManager.GtfsDB.Queries.GetActiveServiceIDsForDate(ctx, serviceDateStr) + if svcErr != nil { + api.Logger.Warn("failed to query active service IDs", slog.String("date", serviceDateStr), slog.Any("error", svcErr)) + return nil + } + if len(activeServiceIDs) == 0 { + return nil + } + + activeServiceIDSet := make(map[string]bool, len(activeServiceIDs)) + for _, sid := range activeServiceIDs { + activeServiceIDSet[sid] = true + } + + startNanos := stopWindowStart.Sub(serviceMidnight).Nanoseconds() + endNanos := stopWindowEnd.Sub(serviceMidnight).Nanoseconds() + if endNanos < 0 { + return nil + } + + stopTimes, stErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForStopInWindow(ctx, gtfsdb.GetStopTimesForStopInWindowParams{ + StopID: stopCode, + WindowStartNanos: startNanos, + WindowEndNanos: endNanos, + }) + if stErr != nil { + api.Logger.Warn("failed to query stop times in window", slog.String("stopID", stopCode), slog.Any("error", stErr)) + return nil + } + + for _, st := range stopTimes { + if activeServiceIDSet[st.ServiceID] { + *allActiveStopTimes = append(*allActiveStopTimes, activeStopTime{ + GetStopTimesForStopInWindowRow: st, + ServiceDate: serviceMidnight, + }) + } + } + return nil +} + +func (api *RestAPI) buildArrivalsFromLocationStopTimes( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + stopCode string, + agencyID string, + allActiveStopTimes []activeStopTime, + params ArrivalsAndDeparturesForLocationParams, + stopQueryTime time.Time, + state *locationArrivalsState, +) (bool, error) { + routesLookup, tripsLookup, tripStopCountMap, bErr := api.batchFetchLocationRoutesAndTrips(ctx, stopCode, allActiveStopTimes) + if bErr != nil { + return false, nil + } stopProducedArrival := false combinedStopID := utils.FormCombinedID(agencyID, stopCode) @@ -794,6 +820,24 @@ func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *loc references := models.NewEmptyReferences() addedAgencyIDs := make(map[string]bool) + api.addTripReferences(ctx, state, references) + api.addRouteAndAgencyReferences(ctx, state, references, addedAgencyIDs) + api.addStopReferences(ctx, state, references) + + topLevelSituationIDs := make([]string, 0, len(state.collectedAlerts)) + if len(state.collectedAlerts) > 0 { + alertSlice := make([]gtfs.Alert, 0, len(state.collectedAlerts)) + for alertID, a := range state.collectedAlerts { + alertSlice = append(alertSlice, a) + topLevelSituationIDs = append(topLevelSituationIDs, alertID) + } + references.Situations = append(references.Situations, api.BuildSituationReferences(alertSlice)...) + } + + return references, topLevelSituationIDs +} + +func (api *RestAPI) addTripReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel) { for _, trip := range state.tripIDSet { routeForTrip, ok := state.routeIDSet[trip.RouteID] if !ok { @@ -817,7 +861,9 @@ func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *loc utils.FormCombinedID(routeForTrip.AgencyID, trip.ShapeID.String), )) } +} +func (api *RestAPI) addRouteAndAgencyReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel, addedAgencyIDs map[string]bool) { for _, route := range state.routeIDSet { references.Routes = append(references.Routes, models.NewRoute( utils.FormCombinedID(route.AgencyID, route.ID), @@ -840,7 +886,9 @@ func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *loc } } } +} +func (api *RestAPI) addStopReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel) { stopIDsSlice := stringMapKeys(state.stopIDSet) batchStops, _ := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDsSlice) batchRoutesForStops, _ := api.GtfsManager.GtfsDB.Queries.GetRoutesForStops(ctx, stopIDsSlice) @@ -898,18 +946,6 @@ func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *loc StaticRouteIDs: combinedRouteIDs, }) } - - topLevelSituationIDs := make([]string, 0, len(state.collectedAlerts)) - if len(state.collectedAlerts) > 0 { - alertSlice := make([]gtfs.Alert, 0, len(state.collectedAlerts)) - for alertID, a := range state.collectedAlerts { - alertSlice = append(alertSlice, a) - topLevelSituationIDs = append(topLevelSituationIDs, alertID) - } - references.Situations = append(references.Situations, api.BuildSituationReferences(alertSlice)...) - } - - return references, topLevelSituationIDs } func (api *RestAPI) buildLocationQueriedStopIDs(stops []gtfsdb.Stop, state *locationArrivalsState) []string { From e069e080c918ce5dd73c453af492f39d448a4e54 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 19:47:13 +0530 Subject: [PATCH 15/34] Remove unused params variable from fetchStopTimesForDayOffset --- .../restapi/arrivals_and_departures_for_location.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index e2b6d604..de3aee0e 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -454,7 +454,7 @@ func (api *RestAPI) fetchActiveStopTimesForLocationWindow(ctx context.Context, s var allActiveStopTimes []activeStopTime for dayOffset := -1; dayOffset <= 1; dayOffset++ { - err := api.fetchStopTimesForDayOffset(ctx, stopCode, stopLoc, stopQueryTime, params, dayOffset, stopWindowStart, stopWindowEnd, &allActiveStopTimes) + err := api.fetchStopTimesForDayOffset(ctx, stopCode, stopLoc, stopQueryTime, dayOffset, stopWindowStart, stopWindowEnd, &allActiveStopTimes) if err != nil { return nil, err } @@ -499,6 +499,11 @@ func (api *RestAPI) batchFetchLocationRoutesAndTrips( tripsLookup[tr.ID] = tr } + tripStopCountMap := api.buildTripStopCountMap(ctx, uniqueTripIDs) + return routesLookup, tripsLookup, tripStopCountMap, nil +} + +func (api *RestAPI) buildTripStopCountMap(ctx context.Context, uniqueTripIDs []string) map[string]int { tripStopCountMap := make(map[string]int, len(uniqueTripIDs)) if len(uniqueTripIDs) > 0 { allST, countErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForTripIDs(ctx, uniqueTripIDs) @@ -510,13 +515,13 @@ func (api *RestAPI) batchFetchLocationRoutesAndTrips( } } } - return routesLookup, tripsLookup, tripStopCountMap, nil + return tripStopCountMap } func (api *RestAPI) fetchStopTimesForDayOffset( ctx context.Context, stopCode string, stopLoc *time.Location, - stopQueryTime time.Time, params ArrivalsAndDeparturesForLocationParams, - dayOffset int, stopWindowStart, stopWindowEnd time.Time, + stopQueryTime time.Time, dayOffset int, + stopWindowStart, stopWindowEnd time.Time, allActiveStopTimes *[]activeStopTime, ) error { if ctx.Err() != nil { From 9d3989b4d97e027d042e25cde57e7d87b23e8421 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Thu, 14 May 2026 19:49:42 +0530 Subject: [PATCH 16/34] Remove unused queryTime parameter from getLocationNearbyStops --- internal/restapi/arrivals_and_departures_for_location.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index de3aee0e..879672c2 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -336,7 +336,7 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite references, topLevelSituationIDs := api.buildLocationReferencesBlock(ctx, state) queriedStopIDs := api.buildLocationQueriedStopIDs(stops, state) - nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon, params.Time) + nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon) api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( state.arrivals, @@ -973,7 +973,6 @@ func getLocationNearbyStops( api *RestAPI, ctx context.Context, centerLat, centerLon float64, - queryTime time.Time, ) []models.StopWithDistance { nearby, _ := api.GtfsManager.GetStopsForLocation( From 16533b743b6e4b4feaf8c6bef18e4af1a2bb2943 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 17 May 2026 19:41:34 +0530 Subject: [PATCH 17/34] fix: achieve API parity for location arrivals (filters, layovers, frequency) --- gtfsdb/helpers.go | 8 ++++++- .../arrivals_and_departures_for_location.go | 24 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/gtfsdb/helpers.go b/gtfsdb/helpers.go index 09b72eab..e648f31a 100644 --- a/gtfsdb/helpers.go +++ b/gtfsdb/helpers.go @@ -1418,8 +1418,14 @@ func (c *Client) buildBlockLayoverIndex(ctx context.Context, staticData *gtfs.St layoverStart := int64(lastStopCurrent.ArrivalTime) layoverEnd := int64(firstStopNext.DepartureTime) + // If the layover appears negative, check if it's a valid midnight wraparound + // (e.g. within a reasonable layover threshold, like 4 hours) if layoverStart > layoverEnd { - continue + if (layoverEnd+86400)-layoverStart < (4 * 3600) { + layoverEnd += 86400 // It crosses midnight, shift it by 24h + } else { + continue // It's invalid data + } } err := qtx.CreateBlockLayover(ctx, CreateBlockLayoverParams{ diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 879672c2..895d40ef 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -449,6 +449,16 @@ func (api *RestAPI) collectArrivalsForLocationStop(ctx context.Context, w http.R } func (api *RestAPI) fetchActiveStopTimesForLocationWindow(ctx context.Context, stopCode string, stopLoc *time.Location, stopQueryTime time.Time, params ArrivalsAndDeparturesForLocationParams) ([]activeStopTime, error) { + maxBefore := params.MinutesBefore + if params.FrequencyMinutesBefore > maxBefore { + maxBefore = params.FrequencyMinutesBefore + } + + maxAfter := params.MinutesAfter + if params.FrequencyMinutesAfter > maxAfter { + maxAfter = params.FrequencyMinutesAfter + } + stopWindowStart := stopQueryTime.Add(-time.Duration(params.MinutesBefore) * time.Minute) stopWindowEnd := stopQueryTime.Add(time.Duration(params.MinutesAfter) * time.Minute) @@ -607,6 +617,20 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( if !routeOK { continue } + + if len(params.RouteTypes) > 0 { + routeTypeMatch := false + for _, rt := range params.RouteTypes { + if int(route.Type) == rt { + routeTypeMatch = true + break + } + } + if !routeTypeMatch { + continue // Skip this trip, it's the wrong vehicle type + } + } + trip, tripOK := tripsLookup[st.TripID] if !tripOK { continue From 17b5af798a68256e642dd1948f7098ddfd15521c Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 17 May 2026 22:42:49 +0530 Subject: [PATCH 18/34] fix: remove non-standard scheduled field and add position-based numberOfStopsAway fallback for OBA Java parity --- internal/models/trip_details.go | 2 - internal/models/trip_details_test.go | 2 - .../arrival_and_departure_for_stop_handler.go | 70 ++++++++++++++++++- internal/restapi/trips_helper.go | 1 - internal/restapi/trips_helper_test.go | 2 - .../vehicles_for_agency_handler_test.go | 4 +- 6 files changed, 72 insertions(+), 9 deletions(-) diff --git a/internal/models/trip_details.go b/internal/models/trip_details.go index e9921755..3b5553c5 100644 --- a/internal/models/trip_details.go +++ b/internal/models/trip_details.go @@ -63,12 +63,10 @@ type TripStatus struct { TotalDistanceAlongTrip float64 `json:"totalDistanceAlongTrip"` VehicleFeatures []string `json:"vehicleFeatures"` VehicleID string `json:"vehicleId"` - Scheduled bool `json:"scheduled"` // (Scheduled = !Predicted) ,this field is not part of the OpenAPI TripStatus schema but is retained for compatibility with existing API consumers. Tracked as a known spec deviation. } func (ts *TripStatus) SetPredicted(predicted bool) { ts.Predicted = predicted - ts.Scheduled = !predicted } // SituationIDs and VehicleFeatures default to empty slices (never null in JSON). diff --git a/internal/models/trip_details_test.go b/internal/models/trip_details_test.go index 60523272..94250087 100644 --- a/internal/models/trip_details_test.go +++ b/internal/models/trip_details_test.go @@ -160,7 +160,6 @@ func TestTripStatusJSON(t *testing.T) { TotalDistanceAlongTrip: 5000.0, VehicleFeatures: []string{"wifi", "bike_rack"}, VehicleID: "vehicle_789", - Scheduled: false, } jsonData, err := json.Marshal(tripStatus) @@ -176,7 +175,6 @@ func TestTripStatusJSON(t *testing.T) { assert.Equal(t, tripStatus.Predicted, unmarshaledStatus.Predicted) assert.Equal(t, tripStatus.Position.Lat, unmarshaledStatus.Position.Lat) assert.Equal(t, tripStatus.Position.Lon, unmarshaledStatus.Position.Lon) - assert.Equal(t, tripStatus.Scheduled, unmarshaledStatus.Scheduled) } func TestTripStatus_JSONAlwaysPresent(t *testing.T) { diff --git a/internal/restapi/arrival_and_departure_for_stop_handler.go b/internal/restapi/arrival_and_departure_for_stop_handler.go index 90cfa268..94c2d0a9 100644 --- a/internal/restapi/arrival_and_departure_for_stop_handler.go +++ b/internal/restapi/arrival_and_departure_for_stop_handler.go @@ -713,8 +713,24 @@ func (api *RestAPI) getPredictedTimes( func (api *RestAPI) getNumberOfStopsAway(ctx context.Context, targetTripID string, targetStopSequence int, vehicle *gtfs.Vehicle, serviceDate time.Time) *int { currentVehicleStopSequence := getCurrentVehicleStopSequence(vehicle) + if currentVehicleStopSequence == nil { - return nil + // Fallback: infer the vehicle's current stop from its lat/lon position. + // This handles agencies (e.g. Sound Transit Link light rail) that don't + // publish current_stop_sequence in GTFS-RT vehicle positions. + if vehicle == nil || vehicle.Position == nil || + vehicle.Position.Latitude == nil || vehicle.Position.Longitude == nil { + return nil + } + inferred := api.inferStopSequenceFromPosition( + ctx, targetTripID, + float64(*vehicle.Position.Latitude), + float64(*vehicle.Position.Longitude), + ) + if inferred == nil { + return nil + } + currentVehicleStopSequence = inferred } activeTripID := GetVehicleActiveTripID(vehicle) @@ -728,3 +744,55 @@ func (api *RestAPI) getNumberOfStopsAway(ctx context.Context, targetTripID strin numberOfStopsAway := targetGlobalSeq - vehicleGlobalSeq - 1 return &numberOfStopsAway } + +// inferStopSequenceFromPosition returns the stop_sequence of the stop the vehicle +// is currently at or has most recently passed, determined by projecting the vehicle's +// lat/lon onto the ordered list of stop positions for the trip. +// +// It fetches stop times (ordered by sequence) and stop coordinates in a single batch, +// then finds the last stop that is "behind" the vehicle along the route direction. +// Returns nil when no stop times exist or coordinates cannot be resolved. +func (api *RestAPI) inferStopSequenceFromPosition(ctx context.Context, tripID string, vehLat, vehLon float64) *int { + stopTimes, err := api.GtfsManager.GtfsDB.Queries.GetStopTimesForTrip(ctx, tripID) + if err != nil || len(stopTimes) == 0 { + return nil + } + + stopIDs := make([]string, len(stopTimes)) + for i, st := range stopTimes { + stopIDs[i] = st.StopID + } + + stops, err := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDs) + if err != nil { + return nil + } + + coordMap := make(map[string][2]float64, len(stops)) + for _, s := range stops { + coordMap[s.ID] = [2]float64{s.Lat, s.Lon} + } + + // Find the stop that is geometrically closest to the vehicle's current position. + // OBA Java uses a similar nearest-stop heuristic when stop-sequence is absent. + bestIdx := -1 + bestDist := -1.0 + for i, st := range stopTimes { + coords, ok := coordMap[st.StopID] + if !ok { + continue + } + d := utils.Distance(vehLat, vehLon, coords[0], coords[1]) + if bestDist < 0 || d < bestDist { + bestDist = d + bestIdx = i + } + } + + if bestIdx < 0 { + return nil + } + + seq := int(stopTimes[bestIdx].StopSequence) + return &seq +} diff --git a/internal/restapi/trips_helper.go b/internal/restapi/trips_helper.go index 1d062b7f..13b9e8b1 100644 --- a/internal/restapi/trips_helper.go +++ b/internal/restapi/trips_helper.go @@ -67,7 +67,6 @@ func (api *RestAPI) BuildTripStatus( // Predicted is true because the cancellation itself is real-time information. if status.Status == "CANCELED" { status.Predicted = vehicle != nil && !defaultStaleDetector.Check(vehicle, currentTime) - status.Scheduled = !status.Predicted return status, nil } diff --git a/internal/restapi/trips_helper_test.go b/internal/restapi/trips_helper_test.go index df381a77..3df95655 100644 --- a/internal/restapi/trips_helper_test.go +++ b/internal/restapi/trips_helper_test.go @@ -462,7 +462,6 @@ func TestBuildTripStatus_ScheduleDeviation_SetsPredicted(t *testing.T) { require.NotZero(t, status.ScheduleDeviation) assert.Equal(t, 120, status.ScheduleDeviation, "ScheduleDeviation should reflect the trip update delay") assert.True(t, status.Predicted, "Predicted should be true when trip update exists") - assert.False(t, status.Scheduled, "Scheduled should be false when predicted is true") } func TestBuildTripStatus_NoRealtimeData_SetsScheduled(t *testing.T) { @@ -488,7 +487,6 @@ func TestBuildTripStatus_NoRealtimeData_SetsScheduled(t *testing.T) { assert.Equal(t, 0, status.ScheduleDeviation, "ScheduleDeviation should be 0 with no real-time data") assert.False(t, status.Predicted, "Predicted should be false with no real-time data") - assert.True(t, status.Scheduled, "Scheduled should be true with no real-time data") assert.Equal(t, "default", status.Status) assert.Equal(t, "scheduled", status.Phase) } diff --git a/internal/restapi/vehicles_for_agency_handler_test.go b/internal/restapi/vehicles_for_agency_handler_test.go index 05c837ca..139c155b 100644 --- a/internal/restapi/vehicles_for_agency_handler_test.go +++ b/internal/restapi/vehicles_for_agency_handler_test.go @@ -714,7 +714,9 @@ func TestVehiclesForAgencyHandlerWithRealTimeData(t *testing.T) { tripStatus := vehicle["tripStatus"].(map[string]any) assert.NotEmpty(t, tripStatus["activeTripId"], "TripStatus should have activeTripId") - assert.IsType(t, true, tripStatus["scheduled"]) + // Note: "scheduled" field has been removed from TripStatus to match OBA Java wire format. + // Use "predicted" (tripStatus["predicted"]) instead. + assert.Contains(t, tripStatus, "predicted", "predicted field must be present in tripStatus") if tripStatus["serviceDate"] != nil { assert.IsType(t, float64(0), tripStatus["serviceDate"]) From 7e11bb5a18125931a0443401a3df21a3edb7c9b7 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Mon, 18 May 2026 17:32:44 +0530 Subject: [PATCH 19/34] layover fix --- gtfsdb/helpers.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/gtfsdb/helpers.go b/gtfsdb/helpers.go index e648f31a..195514eb 100644 --- a/gtfsdb/helpers.go +++ b/gtfsdb/helpers.go @@ -1418,11 +1418,15 @@ func (c *Client) buildBlockLayoverIndex(ctx context.Context, staticData *gtfs.St layoverStart := int64(lastStopCurrent.ArrivalTime) layoverEnd := int64(firstStopNext.DepartureTime) + // ArrivalTime/DepartureTime are nanoseconds since service-day midnight + // (they come from the go-gtfs library as time.Duration values). // If the layover appears negative, check if it's a valid midnight wraparound - // (e.g. within a reasonable layover threshold, like 4 hours) + // (e.g. within a reasonable layover threshold, like 4 hours). + const dayNs = int64(24 * time.Hour) // 86_400_000_000_000 ns + const maxLayoverNs = int64(4 * time.Hour) // 14_400_000_000_000 ns if layoverStart > layoverEnd { - if (layoverEnd+86400)-layoverStart < (4 * 3600) { - layoverEnd += 86400 // It crosses midnight, shift it by 24h + if (layoverEnd+dayNs)-layoverStart < maxLayoverNs { + layoverEnd += dayNs // It crosses midnight, shift it by 24h } else { continue // It's invalid data } From a84f6210bec95b11cf1e90c301176dc5c5a8dfae Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Wed, 20 May 2026 12:59:21 +0530 Subject: [PATCH 20/34] fix(api): add missing agency prefix to routeId and stopId in situation allAffects --- internal/restapi/reference_utils.go | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/internal/restapi/reference_utils.go b/internal/restapi/reference_utils.go index 5bb2fadb..b2b1c270 100644 --- a/internal/restapi/reference_utils.go +++ b/internal/restapi/reference_utils.go @@ -101,17 +101,29 @@ func (api *RestAPI) BuildSituationReferences(alerts []gtfs.Alert) []models.Situa } for _, entity := range alert.InformedEntities { + agencyID := getStringValue(entity.AgencyID) + + rawRouteID := getStringValue(entity.RouteID) + if rawRouteID != "" { + rawRouteID = utils.FormCombinedID(agencyID, rawRouteID) + } + + rawStopID := getStringValue(entity.StopID) + if rawStopID != "" { + rawStopID = utils.FormCombinedID(agencyID, rawStopID) + } + affectedEntity := models.AffectedEntity{ - AgencyID: getStringValue(entity.AgencyID), + AgencyID: agencyID, ApplicationID: "", DirectionID: entity.DirectionID.String(), - RouteID: getStringValue(entity.RouteID), - StopID: getStringValue(entity.StopID), + RouteID: rawRouteID, + StopID: rawStopID, TripID: "", } - if entity.TripID != nil { - affectedEntity.TripID = entity.TripID.ID + if entity.TripID != nil && entity.TripID.ID != "" { + affectedEntity.TripID = utils.FormCombinedID(agencyID, entity.TripID.ID) } situation.AllAffects = append(situation.AllAffects, affectedEntity) From 94d3ede9868f8c08bc9dbd8946d1b1be2029c363 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Wed, 20 May 2026 13:02:50 +0530 Subject: [PATCH 21/34] updated openaAPI --- testdata/openapi.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testdata/openapi.yml b/testdata/openapi.yml index 28b44ad9..999c4acb 100644 --- a/testdata/openapi.yml +++ b/testdata/openapi.yml @@ -1,5 +1,5 @@ # Source: https://github.com/OneBusAway/sdk-config/blob/main/openapi.yml -# Fetched: 2026-03-31 +# Fetched: 2026-05-20 openapi: 3.0.0 info: title: OneBusAway @@ -1859,7 +1859,6 @@ components: description: The ID of the trip for the arriving vehicle. tripStatus: $ref: '#/components/schemas/TripStatus' - description: Trip-specific status for the arriving transit vehicle. vehicleId: type: string description: ID of the transit vehicle serving this trip. @@ -2165,6 +2164,7 @@ components: TripStatus: type: object + description: Trip-specific status for the arriving transit vehicle. properties: activeTripId: type: string From de22fdf25582af4eefe0fc8165161dda2bdf3f70 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Wed, 20 May 2026 13:49:15 +0530 Subject: [PATCH 22/34] some parity minor fixes --- internal/models/situation.go | 18 +++++++++++++++++- internal/restapi/reference_utils.go | 18 +++++++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/internal/models/situation.go b/internal/models/situation.go index 7dcf108e..b4e17acf 100644 --- a/internal/models/situation.go +++ b/internal/models/situation.go @@ -6,7 +6,7 @@ type Situation struct { ActiveWindows []ActiveWindow `json:"activeWindows"` AllAffects []AffectedEntity `json:"allAffects"` ConsequenceMessage string `json:"consequenceMessage"` - Consequences []any `json:"consequences"` + Consequences []Consequence `json:"consequences"` PublicationWindows []any `json:"publicationWindows"` Reason string `json:"reason"` Severity string `json:"severity"` @@ -15,6 +15,22 @@ type Situation struct { URL *TranslatedString `json:"url,omitempty"` } +type Consequence struct { + Condition string `json:"condition"` + ConditionDetails ConditionDetails `json:"conditionDetails"` +} + +type ConditionDetails struct { + DiversionPath DiversionPath `json:"diversionPath"` + DiversionStopIDs []string `json:"diversionStopIds"` +} + +type DiversionPath struct { + Length int `json:"length"` + Levels string `json:"levels"` + Points string `json:"points"` +} + type ActiveWindow struct { From int64 `json:"from"` To int64 `json:"to"` diff --git a/internal/restapi/reference_utils.go b/internal/restapi/reference_utils.go index b2b1c270..a929ceac 100644 --- a/internal/restapi/reference_utils.go +++ b/internal/restapi/reference_utils.go @@ -2,7 +2,6 @@ package restapi import ( "context" - "time" "github.com/OneBusAway/go-gtfs" "maglev.onebusaway.org/gtfsdb" @@ -79,11 +78,24 @@ func (api *RestAPI) BuildSituationReferences(alerts []gtfs.Alert) []models.Situa for _, alert := range alerts { situation := models.Situation{ ID: alert.ID, - CreationTime: models.NewModelTime(time.Time{}), + CreationTime: models.NewModelTime(api.Clock.Now()), ActiveWindows: make([]models.ActiveWindow, 0, len(alert.ActivePeriods)), AllAffects: make([]models.AffectedEntity, 0, len(alert.InformedEntities)), ConsequenceMessage: "", - Consequences: []any{}, + Consequences: []models.Consequence{ + { + Condition: "", + ConditionDetails: models.ConditionDetails{ + DiversionPath: models.DiversionPath{ + Length: 0, + Levels: "", + Points: "", + }, + // Initialized to an empty slice so it outputs [] instead of null + DiversionStopIDs: []string{}, + }, + }, + }, PublicationWindows: []any{}, Reason: mapAlertCauseToReason(alert.Cause), Severity: mapAlertEffectToSeverity(alert.Effect), From 838338c6a759842af3c909878fd564bcd8c8323a Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Wed, 20 May 2026 13:59:43 +0530 Subject: [PATCH 23/34] fixes the test failure dure to time --- internal/restapi/coverage_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/restapi/coverage_test.go b/internal/restapi/coverage_test.go index 61effc5a..2eae4795 100644 --- a/internal/restapi/coverage_test.go +++ b/internal/restapi/coverage_test.go @@ -12,7 +12,11 @@ import ( ) func TestBuildSituationReferencesCoverage(t *testing.T) { - api := &RestAPI{} + api := &RestAPI{ + Application: &app.Application{ + Clock: clock.NewMockClock(time.Now()), + }, + } alerts := []gtfs.Alert{ { From 53674577869726c1c5cae7699632b2e15a5dc789 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Fri, 22 May 2026 10:01:53 +0530 Subject: [PATCH 24/34] some minor revert and fixes --- internal/models/trip_details.go | 2 ++ internal/models/trip_details_test.go | 2 ++ internal/restapi/arrivals_and_departures_for_location.go | 8 +++++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/internal/models/trip_details.go b/internal/models/trip_details.go index 3b5553c5..e9921755 100644 --- a/internal/models/trip_details.go +++ b/internal/models/trip_details.go @@ -63,10 +63,12 @@ type TripStatus struct { TotalDistanceAlongTrip float64 `json:"totalDistanceAlongTrip"` VehicleFeatures []string `json:"vehicleFeatures"` VehicleID string `json:"vehicleId"` + Scheduled bool `json:"scheduled"` // (Scheduled = !Predicted) ,this field is not part of the OpenAPI TripStatus schema but is retained for compatibility with existing API consumers. Tracked as a known spec deviation. } func (ts *TripStatus) SetPredicted(predicted bool) { ts.Predicted = predicted + ts.Scheduled = !predicted } // SituationIDs and VehicleFeatures default to empty slices (never null in JSON). diff --git a/internal/models/trip_details_test.go b/internal/models/trip_details_test.go index 94250087..60523272 100644 --- a/internal/models/trip_details_test.go +++ b/internal/models/trip_details_test.go @@ -160,6 +160,7 @@ func TestTripStatusJSON(t *testing.T) { TotalDistanceAlongTrip: 5000.0, VehicleFeatures: []string{"wifi", "bike_rack"}, VehicleID: "vehicle_789", + Scheduled: false, } jsonData, err := json.Marshal(tripStatus) @@ -175,6 +176,7 @@ func TestTripStatusJSON(t *testing.T) { assert.Equal(t, tripStatus.Predicted, unmarshaledStatus.Predicted) assert.Equal(t, tripStatus.Position.Lat, unmarshaledStatus.Position.Lat) assert.Equal(t, tripStatus.Position.Lon, unmarshaledStatus.Position.Lon) + assert.Equal(t, tripStatus.Scheduled, unmarshaledStatus.Scheduled) } func TestTripStatus_JSONAlwaysPresent(t *testing.T) { diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location.go index 895d40ef..45675689 100644 --- a/internal/restapi/arrivals_and_departures_for_location.go +++ b/internal/restapi/arrivals_and_departures_for_location.go @@ -283,8 +283,6 @@ func arrivalStatusFromDeviation(deviationSeconds int) string { // arrivalsAndDeparturesForLocationHandler returns arrivals and departures for all // stops within a geographic bounding box (lat/lon + latSpan/lonSpan or radius). -// -// Java equivalent: ArrivalsAndDeparturesForLocationAction.index() func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -294,7 +292,7 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite return } - stops, _ := api.GtfsManager.GetStopsForLocation( + stops, limitExceeded := api.GtfsManager.GetStopsForLocation( ctx, &internalgtfs.LocationParams{ Lat: params.Lat, @@ -314,6 +312,10 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite } state := newLocationArrivalsState() + if limitExceeded { + state.limitExceeded = true + } + if err := api.resolveAgenciesForStopsLocation(ctx, stops, state); err != nil { api.serverErrorResponse(w, r, err) return From ccf440f6d18e26ae81f0b10a0ec645055935a60c Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Fri, 22 May 2026 10:06:57 +0530 Subject: [PATCH 25/34] file name change added handler at the end --- ...ocation.go => arrivals_and_departures_for_location_handler.go} | 0 ...st.go => arrivals_and_departures_for_location_handler_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename internal/restapi/{arrivals_and_departures_for_location.go => arrivals_and_departures_for_location_handler.go} (100%) rename internal/restapi/{arrivals_and_departures_for_location_test.go => arrivals_and_departures_for_location_handler_test.go} (100%) diff --git a/internal/restapi/arrivals_and_departures_for_location.go b/internal/restapi/arrivals_and_departures_for_location_handler.go similarity index 100% rename from internal/restapi/arrivals_and_departures_for_location.go rename to internal/restapi/arrivals_and_departures_for_location_handler.go diff --git a/internal/restapi/arrivals_and_departures_for_location_test.go b/internal/restapi/arrivals_and_departures_for_location_handler_test.go similarity index 100% rename from internal/restapi/arrivals_and_departures_for_location_test.go rename to internal/restapi/arrivals_and_departures_for_location_handler_test.go From 6ccc6c432f1db0aa132a5e4be8de3ee7663f584a Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 24 May 2026 11:37:05 +0530 Subject: [PATCH 26/34] refactor(restapi): reduce cognitive complexity and parameter count in location arrivals handler --- ...als_and_departures_for_location_handler.go | 115 +++++++++++------- 1 file changed, 72 insertions(+), 43 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go index 45675689..32d95fd6 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -44,6 +44,24 @@ type activeStopTime struct { ServiceDate time.Time } +// stopProcessingContext holds parameters for processing a single stop's arrivals. +type stopProcessingContext struct { + StopCode string + AgencyID string + CombinedStopID string + QueryTime time.Time + Loc *time.Location +} + +// fetchWindow groups parameters for fetching stop times to reduce function arguments. +type fetchWindow struct { + StopCode string + Loc *time.Location + QueryTime time.Time + Start time.Time + End time.Time +} + // locationArrivalsState holds the shared accumulation state across all stops // while processing arrivals and departures for a location. type locationArrivalsState struct { @@ -431,7 +449,16 @@ func (api *RestAPI) collectArrivalsForLocationStop(ctx context.Context, w http.R } stopQueryTime := params.Time.In(stopLoc) - allActiveStopTimes, err := api.fetchActiveStopTimesForLocationWindow(ctx, stopCode, stopLoc, stopQueryTime, params) + + spc := stopProcessingContext{ + StopCode: stopCode, + AgencyID: agencyID, + CombinedStopID: utils.FormCombinedID(agencyID, stopCode), + QueryTime: stopQueryTime, + Loc: stopLoc, + } + + allActiveStopTimes, err := api.fetchActiveStopTimesForLocationWindow(ctx, spc, params) if err != nil { api.clientCanceledResponse(w, r, err) return err @@ -440,7 +467,7 @@ func (api *RestAPI) collectArrivalsForLocationStop(ctx context.Context, w http.R return nil } - stopProducedArrival, err := api.buildArrivalsFromLocationStopTimes(ctx, w, r, stopCode, agencyID, allActiveStopTimes, params, stopQueryTime, state) + stopProducedArrival, err := api.buildArrivalsFromLocationStopTimes(w, r, spc, allActiveStopTimes, params, state) if err != nil { return err } @@ -450,7 +477,9 @@ func (api *RestAPI) collectArrivalsForLocationStop(ctx context.Context, w http.R return nil } -func (api *RestAPI) fetchActiveStopTimesForLocationWindow(ctx context.Context, stopCode string, stopLoc *time.Location, stopQueryTime time.Time, params ArrivalsAndDeparturesForLocationParams) ([]activeStopTime, error) { +func (api *RestAPI) fetchActiveStopTimesForLocationWindow( + ctx context.Context, spc stopProcessingContext, params ArrivalsAndDeparturesForLocationParams, +) ([]activeStopTime, error) { maxBefore := params.MinutesBefore if params.FrequencyMinutesBefore > maxBefore { maxBefore = params.FrequencyMinutesBefore @@ -461,12 +490,20 @@ func (api *RestAPI) fetchActiveStopTimesForLocationWindow(ctx context.Context, s maxAfter = params.FrequencyMinutesAfter } - stopWindowStart := stopQueryTime.Add(-time.Duration(params.MinutesBefore) * time.Minute) - stopWindowEnd := stopQueryTime.Add(time.Duration(params.MinutesAfter) * time.Minute) + stopWindowStart := spc.QueryTime.Add(-time.Duration(params.MinutesBefore) * time.Minute) + stopWindowEnd := spc.QueryTime.Add(time.Duration(params.MinutesAfter) * time.Minute) + + fw := fetchWindow{ + StopCode: spc.StopCode, + Loc: spc.Loc, + QueryTime: spc.QueryTime, + Start: stopWindowStart, + End: stopWindowEnd, + } var allActiveStopTimes []activeStopTime for dayOffset := -1; dayOffset <= 1; dayOffset++ { - err := api.fetchStopTimesForDayOffset(ctx, stopCode, stopLoc, stopQueryTime, dayOffset, stopWindowStart, stopWindowEnd, &allActiveStopTimes) + err := api.fetchStopTimesForDayOffset(ctx, fw, dayOffset, &allActiveStopTimes) if err != nil { return nil, err } @@ -531,16 +568,13 @@ func (api *RestAPI) buildTripStopCountMap(ctx context.Context, uniqueTripIDs []s } func (api *RestAPI) fetchStopTimesForDayOffset( - ctx context.Context, stopCode string, stopLoc *time.Location, - stopQueryTime time.Time, dayOffset int, - stopWindowStart, stopWindowEnd time.Time, - allActiveStopTimes *[]activeStopTime, + ctx context.Context, fw fetchWindow, dayOffset int, allActiveStopTimes *[]activeStopTime, ) error { if ctx.Err() != nil { return ctx.Err() } - targetDate := stopQueryTime.AddDate(0, 0, dayOffset) - serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, stopLoc) + targetDate := fw.QueryTime.AddDate(0, 0, dayOffset) + serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, fw.Loc) serviceDateStr := targetDate.Format("20060102") activeServiceIDs, svcErr := api.GtfsManager.GtfsDB.Queries.GetActiveServiceIDsForDate(ctx, serviceDateStr) @@ -557,19 +591,19 @@ func (api *RestAPI) fetchStopTimesForDayOffset( activeServiceIDSet[sid] = true } - startNanos := stopWindowStart.Sub(serviceMidnight).Nanoseconds() - endNanos := stopWindowEnd.Sub(serviceMidnight).Nanoseconds() + startNanos := fw.Start.Sub(serviceMidnight).Nanoseconds() + endNanos := fw.End.Sub(serviceMidnight).Nanoseconds() if endNanos < 0 { return nil } stopTimes, stErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForStopInWindow(ctx, gtfsdb.GetStopTimesForStopInWindowParams{ - StopID: stopCode, + StopID: fw.StopCode, WindowStartNanos: startNanos, WindowEndNanos: endNanos, }) if stErr != nil { - api.Logger.Warn("failed to query stop times in window", slog.String("stopID", stopCode), slog.Any("error", stErr)) + api.Logger.Warn("failed to query stop times in window", slog.String("stopID", fw.StopCode), slog.Any("error", stErr)) return nil } @@ -584,24 +618,34 @@ func (api *RestAPI) fetchStopTimesForDayOffset( return nil } +// isRouteTypeAllowed checks if a route's type matches any in the requested filter list. +func isRouteTypeAllowed(routeType int64, allowedTypes []int) bool { + if len(allowedTypes) == 0 { + return true + } + for _, rt := range allowedTypes { + if int(routeType) == rt { + return true + } + } + return false +} + func (api *RestAPI) buildArrivalsFromLocationStopTimes( - ctx context.Context, w http.ResponseWriter, r *http.Request, - stopCode string, - agencyID string, + spc stopProcessingContext, allActiveStopTimes []activeStopTime, params ArrivalsAndDeparturesForLocationParams, - stopQueryTime time.Time, state *locationArrivalsState, ) (bool, error) { - routesLookup, tripsLookup, tripStopCountMap, bErr := api.batchFetchLocationRoutesAndTrips(ctx, stopCode, allActiveStopTimes) + ctx := r.Context() + routesLookup, tripsLookup, tripStopCountMap, bErr := api.batchFetchLocationRoutesAndTrips(ctx, spc.StopCode, allActiveStopTimes) if bErr != nil { return false, nil } stopProducedArrival := false - combinedStopID := utils.FormCombinedID(agencyID, stopCode) for _, ast := range allActiveStopTimes { if len(state.arrivals) >= params.MaxCount { @@ -616,23 +660,10 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( st := ast.GetStopTimesForStopInWindowRow route, routeOK := routesLookup[st.RouteID] - if !routeOK { + if !routeOK || !isRouteTypeAllowed(route.Type, params.RouteTypes) { continue } - if len(params.RouteTypes) > 0 { - routeTypeMatch := false - for _, rt := range params.RouteTypes { - if int(route.Type) == rt { - routeTypeMatch = true - break - } - } - if !routeTypeMatch { - continue // Skip this trip, it's the wrong vehicle type - } - } - trip, tripOK := tripsLookup[st.TripID] if !tripOK { continue @@ -643,7 +674,7 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( tCopy := trip state.tripIDSet[trip.ID] = &tCopy - api.buildSingleArrival(ctx, stopCode, combinedStopID, ast, stopQueryTime, state, route, tripStopCountMap[st.TripID]) + api.buildSingleArrival(ctx, spc, ast, state, route, tripStopCountMap[st.TripID]) stopProducedArrival = true } @@ -652,10 +683,8 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( func (api *RestAPI) buildSingleArrival( ctx context.Context, - stopCode string, - combinedStopID string, + spc stopProcessingContext, ast activeStopTime, - stopQueryTime time.Time, state *locationArrivalsState, route gtfsdb.Route, totalStopsInTrip int, @@ -676,10 +705,10 @@ func (api *RestAPI) buildSingleArrival( ac.vehicleID = vehicle.ID.ID } - api.applyPredictedTimes(ac, stopCode) + api.applyPredictedTimes(ac, spc.StopCode) if vehicle != nil { - api.applyTripStatus(ctx, ac, route, vehicle, stopQueryTime, stopCode, state) + api.applyTripStatus(ctx, ac, route, vehicle, spc.QueryTime, spc.StopCode, state) } ac.blockTripSequence = api.calculateBlockTripSequence(ctx, ac.st.TripID, ac.serviceMidnight) @@ -698,7 +727,7 @@ func (api *RestAPI) buildSingleArrival( route.LongName.String, utils.FormCombinedID(route.AgencyID, ac.st.TripID), ac.st.TripHeadsign.String, - combinedStopID, + spc.CombinedStopID, formattedVehicleID, ac.serviceMidnight, ac.scheduledArrivalTime, From 5da7ad840843d5aebc35fc50651f0fe6451f65ac Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sun, 24 May 2026 11:58:13 +0530 Subject: [PATCH 27/34] fix: address PR feedback for location handler and tests --- ...als_and_departures_for_location_handler.go | 4 ++-- ...nd_departures_for_location_handler_test.go | 1 + internal/restapi/context_cancellation_test.go | 19 +++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go index 32d95fd6..d1947c87 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -490,8 +490,8 @@ func (api *RestAPI) fetchActiveStopTimesForLocationWindow( maxAfter = params.FrequencyMinutesAfter } - stopWindowStart := spc.QueryTime.Add(-time.Duration(params.MinutesBefore) * time.Minute) - stopWindowEnd := spc.QueryTime.Add(time.Duration(params.MinutesAfter) * time.Minute) + stopWindowStart := spc.QueryTime.Add(-time.Duration(maxBefore) * time.Minute) + stopWindowEnd := spc.QueryTime.Add(time.Duration(maxAfter) * time.Minute) fw := fetchWindow{ StopCode: spc.StopCode, diff --git a/internal/restapi/arrivals_and_departures_for_location_handler_test.go b/internal/restapi/arrivals_and_departures_for_location_handler_test.go index ccd383f4..1f8033fa 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler_test.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler_test.go @@ -490,4 +490,5 @@ func TestArrivalsAndDeparturesForLocationLimitExceeded(t *testing.T) { ads, _ := entry["arrivalsAndDepartures"].([]interface{}) assert.LessOrEqual(t, len(ads), 1) + assert.Equal(t, true, entry["limitExceeded"]) } diff --git a/internal/restapi/context_cancellation_test.go b/internal/restapi/context_cancellation_test.go index 386326de..22ac6587 100644 --- a/internal/restapi/context_cancellation_test.go +++ b/internal/restapi/context_cancellation_test.go @@ -23,32 +23,32 @@ func TestContextCancellationHandling(t *testing.T) { }{ { name: "agencies with coverage should handle context cancellation", - endpoint: "/api/where/agencies-with-coverage.json?key=test", + endpoint: "/api/where/agencies-with-coverage.json?key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, // Very short timeout to trigger cancellation }, { name: "stop IDs for agency should handle context cancellation", - endpoint: "/api/where/stop-ids-for-agency/1?key=test", + endpoint: "/api/where/stop-ids-for-agency/1?key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, { name: "routes for location should handle context cancellation", - endpoint: "/api/where/routes-for-location.json?lat=38.9&lon=-77.0&key=test", + endpoint: "/api/where/routes-for-location.json?lat=38.9&lon=-77.0&key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, { name: "stops for location should handle context cancellation", - endpoint: "/api/where/stops-for-location.json?lat=38.9&lon=-77.0&key=test", + endpoint: "/api/where/stops-for-location.json?lat=38.9&lon=-77.0&key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, { name: "arrivals and departures for location should handle context cancellation", - endpoint: "/api/where/arrivals-and-departures-for-location.json?lat=38.9&lon=-77.0&latSpan=0.01&lonSpan=0.01&key=test", + endpoint: "/api/where/arrivals-and-departures-for-location.json?lat=38.9&lon=-77.0&latSpan=0.01&lonSpan=0.01&key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, { name: "stops for route should handle context cancellation", - endpoint: "/api/where/stops-for-route/1?key=test", + endpoint: "/api/where/stops-for-route/1?key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, } @@ -82,16 +82,15 @@ func TestContextCancellationHandling(t *testing.T) { statusCode := w.Code // Valid responses: 200 (completed), 401 (API validation), 500 (error), timeout-related, - // or 429 (rate limit — valid when many sub-tests exhaust the rate limiter). + // or 404 (not found). Rate limit 429 is prevented by using an exempt key. assert.True(t, statusCode == http.StatusOK || statusCode == http.StatusUnauthorized || // API key validation happens first statusCode == http.StatusBadRequest || statusCode == http.StatusInternalServerError || statusCode == http.StatusRequestTimeout || statusCode == http.StatusGatewayTimeout || - statusCode == http.StatusTooManyRequests || // rate limit is valid under load statusCode == http.StatusNotFound, - "Expected status 200, 401, 404, 429, 500, 408, or 504, got %d", statusCode) + "Expected status 200, 401, 400, 404, 500, 408, or 504, got %d", statusCode) }) } } @@ -102,7 +101,7 @@ func TestLongerTimeoutContextHandling(t *testing.T) { // Test with a reasonable timeout that should allow completion t.Run("reasonable timeout should complete successfully", func(t *testing.T) { - req, err := http.NewRequest("GET", "/api/where/agencies-with-coverage.json?key=test", nil) + req, err := http.NewRequest("GET", "/api/where/agencies-with-coverage.json?key=org.onebusaway.iphone", nil) require.NoError(t, err) // Create context with reasonable timeout From 5628b9a061bd6dcc1ec94fbfa1c1499ed48e5fa6 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Fri, 5 Jun 2026 09:56:30 +0530 Subject: [PATCH 28/34] fix: address CodeRabbit PR feedback for arrivals and departures endpoints --- internal/models/response.go | 3 + .../arrival_and_departure_for_stop_handler.go | 12 ++-- ...als_and_departures_for_location_handler.go | 57 +++++++++++++++---- ...nd_departures_for_location_handler_test.go | 12 +++- ...rrivals_and_departures_for_stop_handler.go | 14 +++-- internal/restapi/context_cancellation_test.go | 8 +-- internal/restapi/reference_utils.go | 6 +- 7 files changed, 81 insertions(+), 31 deletions(-) diff --git a/internal/models/response.go b/internal/models/response.go index e26d6578..bf537eb8 100644 --- a/internal/models/response.go +++ b/internal/models/response.go @@ -77,6 +77,9 @@ func NewArrivalsAndDeparturesForLocationResponse( if stopIds == nil { stopIds = []string{} } + if arrivalsAndDepartures == nil { + arrivalsAndDepartures = []ArrivalAndDeparture{} + } entryData := map[string]any{ "arrivalsAndDepartures": arrivalsAndDepartures, "limitExceeded": limitExceeded, diff --git a/internal/restapi/arrival_and_departure_for_stop_handler.go b/internal/restapi/arrival_and_departure_for_stop_handler.go index 94c2d0a9..627db11a 100644 --- a/internal/restapi/arrival_and_departure_for_stop_handler.go +++ b/internal/restapi/arrival_and_departure_for_stop_handler.go @@ -712,6 +712,11 @@ func (api *RestAPI) getPredictedTimes( } func (api *RestAPI) getNumberOfStopsAway(ctx context.Context, targetTripID string, targetStopSequence int, vehicle *gtfs.Vehicle, serviceDate time.Time) *int { + activeTripID := GetVehicleActiveTripID(vehicle) + if activeTripID == "" { + activeTripID = targetTripID + } + currentVehicleStopSequence := getCurrentVehicleStopSequence(vehicle) if currentVehicleStopSequence == nil { @@ -723,7 +728,7 @@ func (api *RestAPI) getNumberOfStopsAway(ctx context.Context, targetTripID strin return nil } inferred := api.inferStopSequenceFromPosition( - ctx, targetTripID, + ctx, activeTripID, float64(*vehicle.Position.Latitude), float64(*vehicle.Position.Longitude), ) @@ -733,11 +738,6 @@ func (api *RestAPI) getNumberOfStopsAway(ctx context.Context, targetTripID strin currentVehicleStopSequence = inferred } - activeTripID := GetVehicleActiveTripID(vehicle) - if activeTripID == "" { - activeTripID = targetTripID - } - targetGlobalSeq := api.getBlockSequenceForStopSequence(ctx, targetTripID, targetStopSequence, serviceDate) vehicleGlobalSeq := api.getBlockSequenceForStopSequence(ctx, activeTripID, *currentVehicleStopSequence, serviceDate) diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go index d1947c87..d94acdc3 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -354,7 +354,11 @@ func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWrite api.collectStopLevelAlerts(stops, state) - references, topLevelSituationIDs := api.buildLocationReferencesBlock(ctx, state) + references, topLevelSituationIDs, refErr := api.buildLocationReferencesBlock(ctx, state) + if refErr != nil { + api.serverErrorResponse(w, r, refErr) + return + } queriedStopIDs := api.buildLocationQueriedStopIDs(stops, state) nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon) @@ -460,7 +464,11 @@ func (api *RestAPI) collectArrivalsForLocationStop(ctx context.Context, w http.R allActiveStopTimes, err := api.fetchActiveStopTimesForLocationWindow(ctx, spc, params) if err != nil { - api.clientCanceledResponse(w, r, err) + if ctx.Err() != nil { + api.clientCanceledResponse(w, r, ctx.Err()) + } else { + api.serverErrorResponse(w, r, err) + } return err } if len(allActiveStopTimes) == 0 { @@ -642,7 +650,12 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( ctx := r.Context() routesLookup, tripsLookup, tripStopCountMap, bErr := api.batchFetchLocationRoutesAndTrips(ctx, spc.StopCode, allActiveStopTimes) if bErr != nil { - return false, nil + if ctx.Err() != nil { + api.clientCanceledResponse(w, r, ctx.Err()) + return false, ctx.Err() + } + api.serverErrorResponse(w, r, bErr) + return false, bErr } stopProducedArrival := false @@ -876,13 +889,15 @@ func (api *RestAPI) sortLocationArrivalsByTime(arrivals []models.ArrivalAndDepar }) } -func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *locationArrivalsState) (*models.ReferencesModel, []string) { +func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *locationArrivalsState) (*models.ReferencesModel, []string, error) { references := models.NewEmptyReferences() addedAgencyIDs := make(map[string]bool) api.addTripReferences(ctx, state, references) api.addRouteAndAgencyReferences(ctx, state, references, addedAgencyIDs) - api.addStopReferences(ctx, state, references) + if err := api.addStopReferences(ctx, state, references); err != nil { + return nil, nil, err + } topLevelSituationIDs := make([]string, 0, len(state.collectedAlerts)) if len(state.collectedAlerts) > 0 { @@ -894,7 +909,7 @@ func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *loc references.Situations = append(references.Situations, api.BuildSituationReferences(alertSlice)...) } - return references, topLevelSituationIDs + return references, topLevelSituationIDs, nil } func (api *RestAPI) addTripReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel) { @@ -910,13 +925,24 @@ func (api *RestAPI) addTripReferences(ctx context.Context, state *locationArriva continue } } + + headsign := "" + if trip.TripHeadsign.Valid { + headsign = trip.TripHeadsign.String + } + + direction := "" + if trip.DirectionID.Valid { + direction = strconv.FormatInt(trip.DirectionID.Int64, 10) + } + references.Trips = append(references.Trips, *models.NewTripReference( utils.FormCombinedID(routeForTrip.AgencyID, trip.ID), utils.FormCombinedID(routeForTrip.AgencyID, trip.RouteID), utils.FormCombinedID(routeForTrip.AgencyID, trip.ServiceID), - trip.TripHeadsign.String, + headsign, "", - strconv.FormatInt(trip.DirectionID.Int64, 10), + direction, utils.FormCombinedID(routeForTrip.AgencyID, trip.BlockID.String), utils.FormCombinedID(routeForTrip.AgencyID, trip.ShapeID.String), )) @@ -948,10 +974,18 @@ func (api *RestAPI) addRouteAndAgencyReferences(ctx context.Context, state *loca } } -func (api *RestAPI) addStopReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel) { +func (api *RestAPI) addStopReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel) error { stopIDsSlice := stringMapKeys(state.stopIDSet) - batchStops, _ := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDsSlice) - batchRoutesForStops, _ := api.GtfsManager.GtfsDB.Queries.GetRoutesForStops(ctx, stopIDsSlice) + batchStops, sErr := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDsSlice) + if sErr != nil { + api.Logger.Warn("failed to batch fetch stops for references", slog.Any("error", sErr)) + return sErr + } + batchRoutesForStops, rErr := api.GtfsManager.GtfsDB.Queries.GetRoutesForStops(ctx, stopIDsSlice) + if rErr != nil { + api.Logger.Warn("failed to batch fetch routes for stop references", slog.Any("error", rErr)) + return rErr + } stopsMap := make(map[string]gtfsdb.Stop, len(batchStops)) for _, s := range batchStops { @@ -1006,6 +1040,7 @@ func (api *RestAPI) addStopReferences(ctx context.Context, state *locationArriva StaticRouteIDs: combinedRouteIDs, }) } + return nil } func (api *RestAPI) buildLocationQueriedStopIDs(stops []gtfsdb.Stop, state *locationArrivalsState) []string { diff --git a/internal/restapi/arrivals_and_departures_for_location_handler_test.go b/internal/restapi/arrivals_and_departures_for_location_handler_test.go index 1f8033fa..bc24db8d 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler_test.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler_test.go @@ -46,7 +46,7 @@ func TestParseArrivalsAndDeparturesForLocationParams_CustomValues(t *testing.T) assert.Equal(t, 10, params.MinutesBefore) assert.Equal(t, 60, params.MinutesAfter) assert.Equal(t, 50, params.MaxCount) - assert.False(t, params.Time.IsZero()) + assert.Equal(t, time.UnixMilli(1609459200000).UTC(), params.Time.UTC()) } func TestParseArrivalsAndDeparturesForLocationParams_MissingLatLon(t *testing.T) { @@ -477,6 +477,16 @@ func TestArrivalsAndDeparturesForLocationLimitExceeded(t *testing.T) { mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) api := createTestApiWithClock(t, mockClock) + // First verify unbounded returns > 1 arrival + _, unboundedModel := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + unboundedData, ok := unboundedModel.Data.(map[string]interface{}) + require.True(t, ok) + unboundedEntry, ok := unboundedData["entry"].(map[string]interface{}) + require.True(t, ok) + unboundedAds, _ := unboundedEntry["arrivalsAndDepartures"].([]interface{}) + require.Greater(t, len(unboundedAds), 1, "unbounded request should return more than 1 arrival for test to be valid") + // maxCount=1 forces limitExceeded=true if there is more than 1 arrival. resp, model := serveApiAndRetrieveEndpoint(t, api, "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500&maxCount=1") diff --git a/internal/restapi/arrivals_and_departures_for_stop_handler.go b/internal/restapi/arrivals_and_departures_for_stop_handler.go index afee2785..478e4f1b 100644 --- a/internal/restapi/arrivals_and_departures_for_stop_handler.go +++ b/internal/restapi/arrivals_and_departures_for_stop_handler.go @@ -418,9 +418,11 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r continue } - situationIDs = append(situationIDs, alert.ID) - if _, seen := collectedAlerts[alert.ID]; !seen { - collectedAlerts[alert.ID] = alert + namespacedID := utils.FormCombinedID(route.AgencyID, alert.ID) + situationIDs = append(situationIDs, namespacedID) + if _, seen := collectedAlerts[namespacedID]; !seen { + alert.ID = namespacedID + collectedAlerts[namespacedID] = alert } } @@ -602,8 +604,10 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r for _, alert := range api.GtfsManager.GetAlertsForStop(stopCode) { if alert.ID != "" { - if _, seen := collectedAlerts[alert.ID]; !seen { - collectedAlerts[alert.ID] = alert + namespacedID := utils.FormCombinedID(stopAgencyID, alert.ID) + if _, seen := collectedAlerts[namespacedID]; !seen { + alert.ID = namespacedID + collectedAlerts[namespacedID] = alert } } } diff --git a/internal/restapi/context_cancellation_test.go b/internal/restapi/context_cancellation_test.go index 22ac6587..789895fc 100644 --- a/internal/restapi/context_cancellation_test.go +++ b/internal/restapi/context_cancellation_test.go @@ -81,16 +81,14 @@ func TestContextCancellationHandling(t *testing.T) { // If cancelled, we expect a timeout or cancellation error response statusCode := w.Code - // Valid responses: 200 (completed), 401 (API validation), 500 (error), timeout-related, + // Valid responses: 200 (completed), 500 (error), timeout-related, // or 404 (not found). Rate limit 429 is prevented by using an exempt key. assert.True(t, statusCode == http.StatusOK || - statusCode == http.StatusUnauthorized || // API key validation happens first - statusCode == http.StatusBadRequest || statusCode == http.StatusInternalServerError || statusCode == http.StatusRequestTimeout || statusCode == http.StatusGatewayTimeout || statusCode == http.StatusNotFound, - "Expected status 200, 401, 400, 404, 500, 408, or 504, got %d", statusCode) + "Expected status 200, 404, 500, 408, or 504, got %d", statusCode) }) } } @@ -101,7 +99,7 @@ func TestLongerTimeoutContextHandling(t *testing.T) { // Test with a reasonable timeout that should allow completion t.Run("reasonable timeout should complete successfully", func(t *testing.T) { - req, err := http.NewRequest("GET", "/api/where/agencies-with-coverage.json?key=org.onebusaway.iphone", nil) + req, err := http.NewRequest("GET", "/api/where/agencies-with-coverage.json?key=TEST", nil) require.NoError(t, err) // Create context with reasonable timeout diff --git a/internal/restapi/reference_utils.go b/internal/restapi/reference_utils.go index a929ceac..b194caba 100644 --- a/internal/restapi/reference_utils.go +++ b/internal/restapi/reference_utils.go @@ -116,12 +116,12 @@ func (api *RestAPI) BuildSituationReferences(alerts []gtfs.Alert) []models.Situa agencyID := getStringValue(entity.AgencyID) rawRouteID := getStringValue(entity.RouteID) - if rawRouteID != "" { + if rawRouteID != "" && agencyID != "" { rawRouteID = utils.FormCombinedID(agencyID, rawRouteID) } rawStopID := getStringValue(entity.StopID) - if rawStopID != "" { + if rawStopID != "" && agencyID != "" { rawStopID = utils.FormCombinedID(agencyID, rawStopID) } @@ -134,7 +134,7 @@ func (api *RestAPI) BuildSituationReferences(alerts []gtfs.Alert) []models.Situa TripID: "", } - if entity.TripID != nil && entity.TripID.ID != "" { + if entity.TripID != nil && entity.TripID.ID != "" && agencyID != "" { affectedEntity.TripID = utils.FormCombinedID(agencyID, entity.TripID.ID) } From 9cd6f63be268846b265929df65472c50dc1cef7f Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Fri, 5 Jun 2026 10:03:11 +0530 Subject: [PATCH 29/34] left out changes pushed --- internal/restapi/reference_utils.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/restapi/reference_utils.go b/internal/restapi/reference_utils.go index b194caba..4784698e 100644 --- a/internal/restapi/reference_utils.go +++ b/internal/restapi/reference_utils.go @@ -134,8 +134,12 @@ func (api *RestAPI) BuildSituationReferences(alerts []gtfs.Alert) []models.Situa TripID: "", } - if entity.TripID != nil && entity.TripID.ID != "" && agencyID != "" { - affectedEntity.TripID = utils.FormCombinedID(agencyID, entity.TripID.ID) + if entity.TripID != nil && entity.TripID.ID != "" { + if agencyID != "" { + affectedEntity.TripID = utils.FormCombinedID(agencyID, entity.TripID.ID) + } else { + affectedEntity.TripID = entity.TripID.ID + } } situation.AllAffects = append(situation.AllAffects, affectedEntity) From db633da9254d0dc8e63217ed56eda411a7a9e0a9 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Fri, 5 Jun 2026 10:06:50 +0530 Subject: [PATCH 30/34] fixed sonar warning --- ...als_and_departures_for_location_handler.go | 62 +++++++++++-------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go index d94acdc3..5edae246 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -1001,32 +1001,10 @@ func (api *RestAPI) addStopReferences(ctx context.Context, state *locationArriva if !ok { continue } - ag := state.stopAgencyMap[sid] - if ag == "" { - ag = state.stopAgencyOverride[sid] - } - if ag == "" { - ag = state.fallbackAgencyID - } - routesForStop := routesByStop[sid] - combinedRouteIDs := make([]string, len(routesForStop)) - for i, rr := range routesForStop { - combinedRouteIDs[i] = utils.FormCombinedID(rr.AgencyID, rr.ID) - if _, exists := state.routeIDSet[rr.ID]; !exists { - rc := gtfsdb.Route{ - ID: rr.ID, - AgencyID: rr.AgencyID, - ShortName: rr.ShortName, - LongName: rr.LongName, - Desc: rr.Desc, - Type: rr.Type, - Url: rr.Url, - Color: rr.Color, - TextColor: rr.TextColor, - } - state.routeIDSet[rr.ID] = &rc - } - } + + ag := resolveAgencyForStop(sid, state) + combinedRouteIDs := processRoutesForStop(routesByStop[sid], state) + references.Stops = append(references.Stops, models.Stop{ ID: utils.FormCombinedID(ag, stopData.ID), Name: stopData.Name.String, @@ -1043,6 +1021,38 @@ func (api *RestAPI) addStopReferences(ctx context.Context, state *locationArriva return nil } +func resolveAgencyForStop(sid string, state *locationArrivalsState) string { + if ag := state.stopAgencyMap[sid]; ag != "" { + return ag + } + if ag := state.stopAgencyOverride[sid]; ag != "" { + return ag + } + return state.fallbackAgencyID +} + +func processRoutesForStop(routesForStop []gtfsdb.GetRoutesForStopsRow, state *locationArrivalsState) []string { + combinedRouteIDs := make([]string, len(routesForStop)) + for i, rr := range routesForStop { + combinedRouteIDs[i] = utils.FormCombinedID(rr.AgencyID, rr.ID) + if _, exists := state.routeIDSet[rr.ID]; !exists { + rc := gtfsdb.Route{ + ID: rr.ID, + AgencyID: rr.AgencyID, + ShortName: rr.ShortName, + LongName: rr.LongName, + Desc: rr.Desc, + Type: rr.Type, + Url: rr.Url, + Color: rr.Color, + TextColor: rr.TextColor, + } + state.routeIDSet[rr.ID] = &rc + } + } + return combinedRouteIDs +} + func (api *RestAPI) buildLocationQueriedStopIDs(stops []gtfsdb.Stop, state *locationArrivalsState) []string { queriedStopIDs := make([]string, 0, len(state.stopsWithArrivals)) for _, dbStop := range stops { From 10a604ffe8c9a33e68d2b3b5284017d4371cf454 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sat, 6 Jun 2026 10:30:44 +0530 Subject: [PATCH 31/34] bugfix --- ...als_and_departures_for_location_handler.go | 61 ++++++++++++++++--- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go index 5edae246..8a633629 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -521,7 +521,7 @@ func (api *RestAPI) fetchActiveStopTimesForLocationWindow( func (api *RestAPI) batchFetchLocationRoutesAndTrips( ctx context.Context, stopCode string, allActiveStopTimes []activeStopTime, -) (map[string]gtfsdb.Route, map[string]gtfsdb.Trip, map[string]int, error) { +) (map[string]gtfsdb.Route, map[string]gtfsdb.Trip, map[string]int, map[string]bool, error) { batchRouteIDs := make(map[string]bool) batchTripIDs := make(map[string]bool) for _, ast := range allActiveStopTimes { @@ -539,12 +539,12 @@ func (api *RestAPI) batchFetchLocationRoutesAndTrips( fetchedRoutes, rErr := api.GtfsManager.GtfsDB.Queries.GetRoutesByIDs(ctx, uniqueRouteIDs) if rErr != nil { api.Logger.Warn("failed to batch fetch routes", slog.String("stopID", stopCode), slog.Any("error", rErr)) - return nil, nil, nil, rErr + return nil, nil, nil, nil, rErr } fetchedTrips, tErr := api.GtfsManager.GtfsDB.Queries.GetTripsByIDs(ctx, uniqueTripIDs) if tErr != nil { api.Logger.Warn("failed to batch fetch trips", slog.String("stopID", stopCode), slog.Any("error", tErr)) - return nil, nil, nil, tErr + return nil, nil, nil, nil, tErr } routesLookup := make(map[string]gtfsdb.Route, len(fetchedRoutes)) @@ -557,7 +557,20 @@ func (api *RestAPI) batchFetchLocationRoutesAndTrips( } tripStopCountMap := api.buildTripStopCountMap(ctx, uniqueTripIDs) - return routesLookup, tripsLookup, tripStopCountMap, nil + + frequencyTripsMap := make(map[string]bool, len(uniqueTripIDs)) + if len(uniqueTripIDs) > 0 { + freqRows, err := api.GtfsManager.GtfsDB.Queries.GetFrequenciesForTrips(ctx, uniqueTripIDs) + if err != nil { + api.Logger.Warn("failed to batch fetch frequencies for trips", slog.Any("error", err)) + } else { + for _, freq := range freqRows { + frequencyTripsMap[freq.TripID] = true + } + } + } + + return routesLookup, tripsLookup, tripStopCountMap, frequencyTripsMap, nil } func (api *RestAPI) buildTripStopCountMap(ctx context.Context, uniqueTripIDs []string) map[string]int { @@ -648,7 +661,7 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( state *locationArrivalsState, ) (bool, error) { ctx := r.Context() - routesLookup, tripsLookup, tripStopCountMap, bErr := api.batchFetchLocationRoutesAndTrips(ctx, spc.StopCode, allActiveStopTimes) + routesLookup, tripsLookup, tripStopCountMap, frequencyTripsMap, bErr := api.batchFetchLocationRoutesAndTrips(ctx, spc.StopCode, allActiveStopTimes) if bErr != nil { if ctx.Err() != nil { api.clientCanceledResponse(w, r, ctx.Err()) @@ -687,8 +700,10 @@ func (api *RestAPI) buildArrivalsFromLocationStopTimes( tCopy := trip state.tripIDSet[trip.ID] = &tCopy - api.buildSingleArrival(ctx, spc, ast, state, route, tripStopCountMap[st.TripID]) - stopProducedArrival = true + added := api.buildSingleArrival(ctx, spc, ast, state, route, tripStopCountMap[st.TripID], frequencyTripsMap[st.TripID], params) + if added { + stopProducedArrival = true + } } return stopProducedArrival, nil @@ -701,7 +716,9 @@ func (api *RestAPI) buildSingleArrival( state *locationArrivalsState, route gtfsdb.Route, totalStopsInTrip int, -) { + isFrequency bool, + params ArrivalsAndDeparturesForLocationParams, +) bool { st := ast.GetStopTimesForStopInWindowRow ac := &arrivalContext{ st: st, @@ -724,6 +741,32 @@ func (api *RestAPI) buildSingleArrival( api.applyTripStatus(ctx, ac, route, vehicle, spc.QueryTime, spc.StopCode, state) } + // Secondary filter for exact time bounds based on predicted vs scheduled times + var tripBefore, tripAfter int + if isFrequency { + tripBefore = params.FrequencyMinutesBefore + tripAfter = params.FrequencyMinutesAfter + } else { + tripBefore = params.MinutesBefore + tripAfter = params.MinutesAfter + } + + windowStart := params.Time.Add(-time.Duration(tripBefore) * time.Minute) + windowEnd := params.Time.Add(time.Duration(tripAfter) * time.Minute) + + arrTimeForFilter := ac.scheduledArrivalTime + if ac.predicted { + arrTimeForFilter = ac.predictedArrivalTime + } + depTimeForFilter := ac.scheduledDepartureTime + if ac.predicted { + depTimeForFilter = ac.predictedDepartureTime + } + + if depTimeForFilter.Before(windowStart) || arrTimeForFilter.After(windowEnd) { + return false + } + ac.blockTripSequence = api.calculateBlockTripSequence(ctx, ac.st.TripID, ac.serviceMidnight) api.applyAlerts(ctx, ac, state) @@ -761,6 +804,8 @@ func (api *RestAPI) buildSingleArrival( ac.tripStatus, ac.situationIDs, )) + + return true } func (api *RestAPI) applyPredictedTimes(ac *arrivalContext, stopCode string) { From 139f010234b113f93f9f0f0f5a963660f6abb4b6 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sat, 6 Jun 2026 10:46:46 +0530 Subject: [PATCH 32/34] default value fix --- ...rrivals_and_departures_for_location_handler.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go index 8a633629..c4ef4a22 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -158,8 +158,19 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) params.Time = parseTimeParam(q, params.Time, addError) parseMinutesCappedParam(q, "minutesBefore", maxMinutesBefore, ¶ms.MinutesBefore, addError) parseMinutesCappedParam(q, "minutesAfter", maxMinutesAfter, ¶ms.MinutesAfter, addError) - parseMinutesUncappedParam(q, "frequencyMinutesBefore", ¶ms.FrequencyMinutesBefore, addError) - parseMinutesUncappedParam(q, "frequencyMinutesAfter", ¶ms.FrequencyMinutesAfter, addError) + + if q.Get("frequencyMinutesBefore") == "" { + params.FrequencyMinutesBefore = params.MinutesBefore + } else { + parseMinutesUncappedParam(q, "frequencyMinutesBefore", ¶ms.FrequencyMinutesBefore, addError) + } + + if q.Get("frequencyMinutesAfter") == "" { + params.FrequencyMinutesAfter = params.MinutesAfter + } else { + parseMinutesUncappedParam(q, "frequencyMinutesAfter", ¶ms.FrequencyMinutesAfter, addError) + } + params.EmptyReturnsNotFound = parseEmptyReturnsNotFoundParam(q, addError) params.RouteTypes = parseRouteTypesParam(q, addError) From 8cec8408c9e40f4f1012a7fa3b9fd8dfbc060d62 Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sat, 6 Jun 2026 10:51:08 +0530 Subject: [PATCH 33/34] minor fix --- .../restapi/arrivals_and_departures_for_location_handler.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go index c4ef4a22..0a484217 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -147,6 +147,9 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) if len(locErrors) > 0 { mergeFieldErrors(&fieldErrors, locErrors) } else { + if loc.Radius == 0 && loc.LatSpan == 0 && loc.LonSpan == 0 { + loc.Radius = models.QuerySearchRadiusInMeters + } params.Lat = loc.Lat params.Lon = loc.Lon params.Radius = loc.Radius From 1ff7198b42d407f3d88e2ae54e214c27ee66abfa Mon Sep 17 00:00:00 2001 From: Aditya Rana Date: Sat, 6 Jun 2026 10:52:02 +0530 Subject: [PATCH 34/34] ran fmt --- .../restapi/arrivals_and_departures_for_location_handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go index 0a484217..5a17002c 100644 --- a/internal/restapi/arrivals_and_departures_for_location_handler.go +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -161,7 +161,7 @@ func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) params.Time = parseTimeParam(q, params.Time, addError) parseMinutesCappedParam(q, "minutesBefore", maxMinutesBefore, ¶ms.MinutesBefore, addError) parseMinutesCappedParam(q, "minutesAfter", maxMinutesAfter, ¶ms.MinutesAfter, addError) - + if q.Get("frequencyMinutesBefore") == "" { params.FrequencyMinutesBefore = params.MinutesBefore } else {