Skip to content

[BUG] raft::matrix::gather (in-place overload): illegal memory access when N*D ≥ 2^31 (32-bit index overflow) #3055

Description

@irina-resh-nvda

Component: raft::matrix::gather(res, inout, map) -> raft::matrix::detail::gatherInplaceImpl

Summary: The in-place gather corrupts memory or triggers an illegal memory access once the flattened element count N×D reaches 2^31. The out-of-place gather with identical inputs is correct, which suggests that the in-place path stores the flattened size in a 32-bit index type. The matrix index type is int64_t in all cases.

Repro: see _gather_repro.cu below. It depends only on RAFT. It fills an N×D uint8 matrix, builds a random permutation with raft::random::permute, permutes the rows, and verifies the result on device. Build with C++20, linking rmm and cudart, and run as ./_gather_repro <N> <D> <inplace:0|1> [use_stream_pool:0|1].

Observed (inplace=1):

N=2,782,000  D=772    N*D=2,147,704,000  (>2^31)   CRASH: cudaErrorIllegalAddress
N=1,000,000  D=2000   N*D=2,000,000,000  (< 2^31)   PASS
N=1,100,000  D=2000   N*D=2,200,000,000  (> 2^31)   CRASH
N=  100,000  D=22000  N*D=2,200,000,000  (> 2^31)   CRASH

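The same sizes with inplace=0 (out-of-place gather(res, in, map, out)) all PASS with mismatches=0. Failure tracks N×D, not N or D individually. The observed pass/fail boundary lies between N×D = 2,000,000,000 (pass) and 2,147,704,000 (crash), which is consistent with a threshold of N×D = 2^31 = 2,147,483,648.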

Expected: the in-place gather works for N×D ≥ 2^31 and matches the out-of-place result, since its API advertises int64_t indexing.

Env: CUDA 12.9, NVIDIA B200 (sm_100). The failure should be architecture-independent, since it is an index overflow.
raft 26.08.00

_gather_repro.cu:

// Minimal standalone repro for raft::matrix::gather in-place vs out-of-place.
//
// Usage: _gather_repro <N> <D> <inplace:0|1> [use_stream_pool:0|1]

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cuda_stream_pool.hpp>
#include <raft/matrix/gather.cuh>
#include <raft/random/permute.cuh>
#include <raft/util/cudart_utils.hpp>

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Writes a per-row tag into a row-major N x D byte matrix.
//
// Launch: 1-D grid, one thread per row; threads whose row index is >= N exit
// without touching memory. Every byte of row `row` is set to (row % 251), so
// after a row permutation each row's tag identifies the source row it came from.
// All offsets are computed in int64_t so that row * D cannot wrap for large
// matrices.
__global__ void fill_kernel(uint8_t* d, int64_t N, int64_t D)
{
  const int64_t row = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (row < N) {
    const uint8_t tag = static_cast<uint8_t>(row % 251);
    uint8_t* row_ptr  = d + row * D;
    for (int64_t col = 0; col < D; ++col) {
      row_ptr[col] = tag;
    }
  }
}

// Verifies a gathered row-major N x D byte matrix against the permutation.
//
// Launch: 1-D grid, one thread per row; threads with row index >= N exit.
// The expected contents of row i are all bytes equal to (perm[i] % 251), the
// tag the fill step wrote into source row perm[i]. This assumes gather
// semantics "output row i = input row perm[i]".
// The whole row is checked, not just column 0, so partial-row corruption is
// also detected. *bad is incremented at most once per row, so it counts
// mismatching rows.
__global__ void check_kernel(
  const uint8_t* d, const int64_t* perm, int64_t N, int64_t D, int* bad)
{
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i >= N) return;
  const uint8_t expected = (uint8_t)(perm[i] % 251);
  const uint8_t* row     = d + i * D;  // int64_t offset: cannot wrap for N*D >= 2^31
  for (int64_t k = 0; k < D; k++) {
    if (row[k] != expected) {
      atomicAdd(bad, 1);
      return;  // one count per bad row
    }
  }
}

// Repro driver for raft::matrix::gather, comparing in-place and out-of-place.
//
// Steps:
//  1. Build an N x D uint8 matrix whose row r is filled with (r % 251).
//  2. Draw a random row permutation with raft::random::permute.
//  3. Gather the rows in place or out of place, selected by <inplace>.
//  4. Count, on device, the rows whose contents do not match the permutation.
//
// Usage: <prog> <N> <D> <inplace:0|1> [use_stream_pool:0|1]
// Returns: 0 if every row matches; 1 on bad arguments or a kernel launch
// failure; 2 if mismatching rows were found. An asynchronous device fault in
// the gather (e.g. an illegal address) is reported by the
// raft::resource::sync_stream call that follows it.
int main(int argc, char** argv)
{
  int64_t N    = argc > 1 ? atoll(argv[1]) : 4990000;
  int64_t D    = argc > 2 ? atoll(argv[2]) : 772;
  int inplace  = argc > 3 ? atoi(argv[3]) : 1;
  int use_pool = argc > 4 ? atoi(argv[4]) : 0;

  // A non-positive extent would yield a zero-block (invalid) launch and an
  // empty matrix; reject it instead of failing obscurely later.
  if (N <= 0 || D <= 0) {
    fprintf(stderr,
            "Usage: %s <N> <D> <inplace:0|1> [use_stream_pool:0|1]  (N and D must be > 0)\n",
            argc > 0 ? argv[0] : "_gather_repro");
    return 1;
  }

  // Kernel launches do not return an error code; query the runtime right after.
  auto launch_failed = [](const char* what) {
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) { return false; }
    fprintf(stderr, "%s launch failed: %s\n", what, cudaGetErrorString(err));
    return true;
  };

  raft::device_resources res;
  if (use_pool) { raft::resource::set_cuda_stream_pool(res, std::make_shared<rmm::cuda_stream_pool>(4)); }
  auto stream = raft::resource::get_cuda_stream(res);

  // Both helper kernels use one thread per row.
  constexpr int tpb = 256;
  const int64_t nb  = (N + tpb - 1) / tpb;

  auto data = raft::make_device_matrix<uint8_t, int64_t>(res, N, D);
  fill_kernel<<<nb, tpb, 0, stream>>>(data.data_handle(), N, D);
  if (launch_failed("fill_kernel")) { return 1; }

  auto perm = raft::make_device_vector<int64_t, int64_t>(res, N);
  raft::random::permute<uint8_t, int64_t, int64_t>(
    perm.data_handle(), (uint8_t*)nullptr, (const uint8_t*)nullptr, (int64_t)D, (int64_t)N, true, stream);
  raft::resource::sync_stream(res);

  // %lld + cast: int64_t is not `long` on every platform, so %ld would be UB there.
  printf("Running gather: N=%lld D=%lld inplace=%d use_pool=%d\n",
         (long long)N, (long long)D, inplace, use_pool);
  if (inplace) {
    raft::matrix::gather(res, data.view(), raft::make_const_mdspan(perm.view()));
  } else {
    auto out = raft::make_device_matrix<uint8_t, int64_t>(res, N, D);
    raft::matrix::gather(res,
                         raft::make_const_mdspan(data.view()),
                         raft::make_const_mdspan(perm.view()),
                         out.view());
    raft::copy(data.data_handle(), out.data_handle(), N * D, stream);
  }
  raft::resource::sync_stream(res);  // async illegal access (if any) surfaces here
  printf("Gather completed without crash.\n");

  auto bad = raft::make_device_scalar<int>(res, 0);
  check_kernel<<<nb, tpb, 0, stream>>>(data.data_handle(), perm.data_handle(), N, D, bad.data_handle());
  if (launch_failed("check_kernel")) { return 1; }

  int h_bad = 0;
  raft::copy(&h_bad, bad.data_handle(), 1, stream);
  raft::resource::sync_stream(res);

  printf("RESULT inplace=%d use_pool=%d mismatches=%d %s\n",
         inplace, use_pool, h_bad, (h_bad == 0 ? "OK" : "WRONG"));
  return h_bad == 0 ? 0 : 2;
}

Metadata

Metadata

Assignees

Labels

bugSomething isn't working

Type

No type
No fields configured for issues without a type.

Projects

Status
Done

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions