Skip to content

Design Ideas

Avirup Sircar edited this page May 16, 2026 · 20 revisions

Policy Design Ideas

Step-1: Create a View for Storage

// StorageView is a plain struct with no user-defined constructor or
// destructor, unlike the owning Storage class.
// It is hardcoded for double here for illustration;
// the same pattern works for a generic T *.

/**
 * Non-owning view of a Storage buffer, tagged with the memory space
 * @p memSpace.
 *
 * The struct declares no constructor or destructor: it only carries a raw
 * pointer and an integer. Keep the member order (pointer, then count),
 * because Storage::view() brace-initialises the view positionally.
 */
template <utils::MemorySpace memSpace>
struct StorageView
{
  double *d_x; // borrowed pointer; this struct never allocates or frees it
  int     n;   // value reported by size()

  /// Returns the borrowed pointer held by this view.
  DFTEFE_HOST_DEVICE
  double *
  data() const
  {
    return d_x;
  }

  /// Returns the stored count @c n.
  DFTEFE_HOST_DEVICE
  int
  size() const
  {
    return n;
  }
};

/**
 * Owning storage class (sketch). Only the part that is relevant to the
 * view mechanism is spelled out here.
 */
template <utils::MemorySpace memSpace>
class Storage
{
public:
  // ... the usual Storage implementation goes here ...

  /**
   * Hand out a non-owning StorageView built from this object's pointer and
   * count. The view is the piece that is meant to be passed into kernels.
   */
  StorageView<memSpace>
  view() const
  {
    return StorageView<memSpace>{d_x, n};
  }

private:
  double *d_x;
  int     n;
};

Step-2: Create a View for MyClass (placeholder for Spline class) which also has the eval()

// MyClass.h

/**
 * Non-owning counterpart of MyClass. It aggregates a StorageView and hosts
 * the actual eval() logic, so the same code can be called from host code
 * and from inside a kernel.
 */
template <utils::MemorySpace MemSpace>
struct MyClassView
{
  StorageView<MemSpace> d_storageview; // view of the owning class's storage

  /// Evaluate on @p x using the data reachable through d_storageview.
  DFTEFE_HOST_DEVICE
  void
  eval(double *x) const
  {
    // Placeholder: work with d_storageview.data(), d_storageview.size()
    // and x.
  }
};

/**
 * Owning class (placeholder for the Spline class). It owns a Storage and
 * forwards eval() to a MyClassView, so that eval() is implemented in one
 * place only.
 */
template <utils::MemorySpace MemSpace>
class MyClass
{
public:
  /// Construct the owned storage from @p n.
  MyClass(int n) : d_storage(n) {}

  /**
   * Evaluate on @p x by building a view of this object and delegating to
   * MyClassView::eval(). Building the view only wraps the storage's view.
   */
  void
  eval(double *x)
  {
    const MyClassView<MemSpace> nonOwningView = view();
    nonOwningView.eval(x);
  }

  /// Produce a non-owning view that wraps the view of the owned storage.
  MyClassView<MemSpace>
  view() const
  {
    return MyClassView<MemSpace>{d_storage.view()};
  }

private:
  Storage<MemSpace> d_storage;
};

Step-3: Define the Policy using MyClassView

/**
 * Policy object that is copied into a kernel and invoked per thread.
 *
 * It stores a MyClassView (non-owning) rather than a MyClass, so copying
 * the policy does not copy or take ownership of the underlying storage.
 */
template <utils::MemorySpace memSpace>
struct TestPolicy
{
  MyClassView<memSpace> obj;  // non-owning, safe to copy into kernel

  /// Apply the wrapped view's eval() to @p x.
  DFTEFE_HOST_DEVICE
  void operator()(double * x) const
  {
    // MyClassView exposes eval(); it has no member named foo().
    obj.eval(x);
  }
};

Usage

// Owning objects live on the host
MyClass<utils::MemorySpace::DEVICE> myObj(n);

// Views are cheap to produce and safe to copy
TestPolicy<utils::MemorySpace::DEVICE> policy{myObj.view()};

// double * xhost allocated on host
// double * xdev allocated on device

// Same launch infrastructure as before, nothing changes

// The kernel body invokes the policy for every thread id below n.
// NOTE(review): TestPolicy is a class template, so the kernel parameter
// below presumably needs explicit template arguments - confirm against
// what DFTEFE_CREATE_KERNEL accepts.
DFTEFE_CREATE_KERNEL(
  void,
  test,
  {
    if (globalThreadId < n)
      policy(x);
  },
  TestPolicy policy,
  int n,
  double * x
);

// NOTE(review): `policy` was built from a DEVICE-space object; the HOST
// launch below presumably needs a policy built from a HOST-space MyClass -
// confirm.
launch<utils::MemorySpace::DEVICE>(test, policy, n, xdev);
launch<utils::MemorySpace::HOST>(test, policy, n, xhost);

CPU–GPU #ifdef DFTEFE_WITH_DEVICE Design Patterns in DFTEFE

There are three distinct ways #ifdef DFTEFE_WITH_DEVICE appears in DFTEFE.


1. Class stores device data members and launches device kernels internally

Preferred approach: template the class on MemorySpace.

This is the cleanest design when the class owns memory-space-specific data.

Example

/**
 * Example of pattern 1: the whole class is templated on the memory space,
 * so memory-space-specific argument types follow the template parameter.
 *
 * RealType and ValueTypeBasisCoeff are not declared in this excerpt; they
 * are presumably further template parameters or aliases in the real header.
 */
template <utils::MemorySpace memorySpace>
class DensityCalculator
{
public:
  // occupation: host-side std::vector.
  // waveFunc:   MultiVector living in the class's memorySpace.
  // rho:        non-const reference to a container that is always in HOST
  //             memory space (presumably the output - confirm).
  void computeRho(
    const std::vector<RealType> &occupation,
    const linearAlgebra::MultiVector<ValueTypeBasisCoeff, memorySpace>
      &waveFunc,
    quadrature::QuadratureValuesContainer<RealType,
                                          utils::MemorySpace::HOST> &rho);
};

Implementation:

// Out-of-class definition of DensityCalculator::computeRho. The body is
// elided in this sketch; per the comment below it delegates to
// computeRhoInBatch(...).
template <utils::MemorySpace memorySpace>
void DensityCalculator<memorySpace>::computeRho(
  const std::vector<RealType> &occupation,
  const linearAlgebra::MultiVector<ValueTypeBasisCoeff, memorySpace>
    &waveFunc,
  quadrature::QuadratureValuesContainer<RealType,
                                        utils::MemorySpace::HOST> &rho)
{
  // internally calls computeRhoInBatch(...)
}

Kernel dispatch helper

/**
 * Primary template of the kernel dispatch helper. All memory-space-dependent
 * arguments of computeRhoInBatch() use the class's memorySpace parameter;
 * a specialization can then supply a different implementation per memory
 * space behind the same static interface.
 */
template <
  typename ValueType,
  typename RealType,
  utils::MemorySpace memorySpace>
class DensityCalculatorKernels
{
public:
  // Static entry point; only the declaration is shown in this sketch.
  static void computeRhoInBatch(
    const utils::MemoryStorage<RealType, memorySpace> &occupationInBatch,
    quadrature::QuadratureValuesContainer<ValueType, memorySpace>
      &psiBatchQuad,
    quadrature::QuadratureValuesContainer<RealType, memorySpace>
      &modPsiSqBatchQuad,
    std::shared_ptr<const quadrature::QuadratureRuleContainer>
      quadRuleContainer,
    quadrature::QuadratureValuesContainer<RealType, memorySpace>
      &rhoBatch,
    linearAlgebra::LinAlgOpContext<memorySpace>
      &linAlgOpContext);
};

Device specialization

// The DEVICE specialization is only declared when the build defines
// DFTEFE_WITH_DEVICE.
#ifdef DFTEFE_WITH_DEVICE

/**
 * Partial specialization of DensityCalculatorKernels for
 * utils::MemorySpace::DEVICE. computeRhoInBatch() has the same parameter
 * list as in the primary template, with the memory space fixed to DEVICE.
 */
template <typename ValueType, typename RealType>
class DensityCalculatorKernels<
  ValueType,
  RealType,
  utils::MemorySpace::DEVICE>
{
public:
  // Static entry point; only the declaration is shown in this sketch.
  static void computeRhoInBatch(
    const utils::MemoryStorage<
      RealType,
      utils::MemorySpace::DEVICE> &occupationInBatch,

    quadrature::QuadratureValuesContainer<
      ValueType,
      utils::MemorySpace::DEVICE> &psiBatchQuad,

    quadrature::QuadratureValuesContainer<
      RealType,
      utils::MemorySpace::DEVICE> &modPsiSqBatchQuad,

    std::shared_ptr<const quadrature::QuadratureRuleContainer>
      quadRuleContainer,

    quadrature::QuadratureValuesContainer<
      RealType,
      utils::MemorySpace::DEVICE> &rhoBatch,

    linearAlgebra::LinAlgOpContext<
      utils::MemorySpace::DEVICE> &linAlgOpContext);
};

#endif

Why this works

The memory space is resolved at compile time, so no runtime branching is needed.


2. Base class exposes host API and internally dispatches to host/device implementations

Alternative when templating the class would require large refactoring.

This pattern keeps the external interface unchanged.

Base class

/**
 * Example of pattern 2: the base class keeps a single public eval() entry
 * point and forwards to protected virtual host/device implementations.
 *
 * NOTE(review): the value type Q is not declared in this excerpt; it is
 * presumably a template parameter of the real class (the derived example
 * inherits from ScalarSpatialFunctionReal and uses double) - confirm.
 */
class ScalarSpatialFunction
{
public:
  // The class has virtual members and is meant to be derived from, so give
  // it a virtual destructor to make destruction through a base pointer safe.
  virtual ~ScalarSpatialFunction() = default;

  /**
   * Evaluate the function, dispatching on @p memorySpace at compile time.
   *
   * @tparam memorySpace selects evalDevice() (DEVICE) or evalHost()
   *         (anything else).
   * @param numPoints forwarded unchanged to the selected implementation.
   * @param t         forwarded unchanged to the selected implementation.
   * @param q         forwarded unchanged to the selected implementation.
   *
   * If DEVICE is requested in a build without DFTEFE_WITH_DEVICE, the call
   * is routed to utils::throwException with an explanatory message.
   */
  template <utils::MemorySpace memorySpace>
  void eval(
    const size_type numPoints,
    const double   *t,
    Q              *q) const
  {
    if constexpr (memorySpace == utils::MemorySpace::DEVICE)
    {
#ifdef DFTEFE_WITH_DEVICE
      evalDevice(numPoints, t, q);
#else
      utils::throwException(
        false,
        "eval<DEVICE>() called but DEVICE support not compiled.");
#endif
    }
    else
    {
      evalHost(numPoints, t, q);
    }
  }

protected:
  // Host implementation hook, overridden by derived classes.
  virtual void evalHost(
    size_type      numPoints,
    const double  *t,
    Q             *q) const;

#ifdef DFTEFE_WITH_DEVICE
  // Device implementation hook; only exists in device-enabled builds.
  virtual void evalDevice(
    size_type      numPoints,
    const double  *t,
    Q             *q) const;
#endif
};

Derived class

/**
 * Derived-class side of pattern 2: it only overrides the protected
 * host/device hooks and uses double as the value type.
 */
class SmearChargePotentialFunction
  : public ScalarSpatialFunctionReal
{
protected:
  // Host implementation.
  void evalHost(
    size_type      numPoints,
    const double  *t,
    double        *q) const override;

#ifdef DFTEFE_WITH_DEVICE
  // Device implementation; only declared in device-enabled builds.
  void evalDevice(
    size_type      numPoints,
    const double  *t,
    double        *q) const override;
#endif
};

Advantages

  • Minimal API disruption
  • No class-wide templating
  • Preserves polymorphism

3. Class member functions callable inside device kernels

This is the hardest case in the current DFTEFE design.

This happens when a class member function itself must execute inside a CUDA/HIP kernel.

Example:

obj->getValueDevice(...)

This is problematic because:

  • Host objects cannot generally be dereferenced on device
  • Virtual dispatch on device is difficult
  • Host object pointers must be transferred to device
  • Polymorphism does not map cleanly

Recommended solution: DeviceView (suggested by ChatGPT)

Expose a device representation.


Header: SphericalDataNumerical.h

/**
 * Example of pattern 3: the class exposes a nested DeviceView struct that
 * carries the state needed by getValueDevice(), so kernels can call the
 * member function on the view instead of on the host object.
 * Everything device-related is only declared under DFTEFE_WITH_DEVICE.
 */
class SphericalDataNumerical : public SphericalData
{
public:

#ifdef DFTEFE_WITH_DEVICE

  // Value-type snapshot of the data used on the device.
  struct DeviceView
  {
    // View of the radial spline, obtained from the spline's own
    // getDeviceView() when the view is built.
    utils::SplineDeviceView radialSpline;

    int l;
    int m;
    int mEff; // set to |m| by getDeviceView()

    double constant; // set to Clm(l, m) * Dm(m) by getDeviceView()
    double cutoff;
    double smoothness;
    double polarAngleTolerance;

    // Device-callable evaluation for one point/origin pair.
    // NOTE(review): presumably both pointers address 3 coordinates - confirm.
    DFTEFE_DEVICE_FUNC
    double getValueDevice(
      const double *point,
      const double *origin) const;
  };

  // Builds and returns a DeviceView by value.
  DeviceView getDeviceView() const;

  // NOTE(review): same signature as DeviceView::getValueDevice but declared
  // on the owning class; how it relates to the view's version is not shown
  // in this excerpt.
  DFTEFE_DEVICE_FUNC
  double getValueDevice(
    const double *point,
    const double *origin) const;

#endif

private:

#ifdef DFTEFE_WITH_DEVICE
  // NOTE(review): stored view member; getDeviceView() as shown builds a
  // fresh view and does not read this member - confirm intended use.
  DeviceView d_deviceView;
#endif
};

DeviceView construction

#ifdef DFTEFE_WITH_DEVICE

/**
 * Assemble a DeviceView from this object's state.
 *
 * The radial spline contributes its own device view; l and m are read from
 * d_qNumbers[1] and d_qNumbers[2]; the remaining scalars are copied from
 * the corresponding d_ members. The returned view is a value, not a
 * reference into this object.
 */
SphericalDataNumerical::DeviceView
SphericalDataNumerical::getDeviceView() const
{
  DeviceView deviceView;

  // Spline data first, via the spline's own view.
  deviceView.radialSpline = d_spline->getDeviceView();

  // Angular quantum numbers and the quantities derived from them.
  const int l = d_qNumbers[1];
  const int m = d_qNumbers[2];

  deviceView.l        = l;
  deviceView.m        = m;
  deviceView.mEff     = std::abs(m);
  deviceView.constant = Clm(l, m) * Dm(m);

  // Plain scalar parameters copied as they are.
  deviceView.cutoff              = d_cutoff;
  deviceView.smoothness          = d_smoothness;
  deviceView.polarAngleTolerance = d_polarAngleTolerance;

  return deviceView;
}

#endif

Device-side implementation

// Out-of-class definition of the device-callable evaluation on DeviceView.
// The body is elided in this sketch; it is a const member, so it can only
// read the members stored in the view plus the two pointer arguments.
DFTEFE_DEVICE_FUNC
double
SphericalDataNumerical::DeviceView::getValueDevice(
  const double *point,
  const double *origin) const
{
  ...
}

How it is used inside kernels

The key idea is:

  • Construct DeviceView on host
  • Copy an array of DeviceViews to device
  • Pass the device array directly into kernels
  • Invoke device member functions from inside the kernel

Kernel definition

namespace
{
  // Kernel that fills output[0 .. numEnrichInCell * numQuadInCell) for one
  // cell. Each flat index iThread is split into an enrichment index and a
  // quadrature-point index, and the matching DeviceView is evaluated at
  // that quadrature point.
  //
  // The split must satisfy enrichId < numEnrichInCell and
  // quadId < numQuadInCell, because views/origin hold numEnrichInCell
  // entries and quadPtsInCell holds numQuadInCell points. Dividing by
  // numQuadInCell gives the enrichment index and the remainder gives the
  // quadrature index, i.e. output[enrichId * numQuadInCell + quadId].
  // NOTE(review): this makes the quadrature index the fastest-varying one
  // in output - confirm this matches the layout the consumers expect.
  //
  // quadPtsInCell and origin are indexed with a stride of 3 doubles per
  // entry.
  DFTEFE_CREATE_KERNEL(
    void,
    evalEnrichmentInCell,
    {
      for (size_type iThread = globalThreadId;
           iThread < numEnrichInCell * numQuadInCell;
           iThread += nThreadsPerBlock * nThreadBlock)
      {
        const size_type enrichId = iThread / numQuadInCell;
        const size_type quadId   = iThread % numQuadInCell;

        output[iThread] =
          views[enrichId].getValueDevice(
            quadPtsInCell + quadId * 3,
            origin + enrichId * 3);
      }
    },
    const double *quadPtsInCell,
    const double *origin,
    const size_type numEnrichInCell,
    const size_type numQuadInCell,
    const atoms::SphericalDataNumerical::DeviceView *views,
    double *output);
}

The important line is:

views[enrichId].getValueDevice(...)

This is possible because views is a device-accessible array of DeviceView objects, which must be constructed on the host and copied to the device before the kernel is launched.


Host-side kernel launcher

/**
 * Launch evalEnrichmentInCell once per cell in
 * [cellRange.first, cellRange.second), cycling through the streams
 * provided by @p linAlgOpContext, then synchronize every stream.
 *
 * @param quadPtsInAllCells   base pointer of the quadrature points; advanced
 *                            by 3 doubles per quadrature point.
 * @param originPtsInAllCells base pointer of the origins; advanced by 3
 *                            doubles per enrichment entry.
 * @param cellRange           half-open range of cell indices to process.
 * @param numEnrichIdsInAllCells per-cell enrichment counts, indexed by cell.
 * @param numQuadPtsInAllCells   per-cell quadrature counts, indexed by cell.
 * @param data                base pointer of the DeviceView array; advanced
 *                            by one entry per enrichment.
 * @param output              base pointer of the result buffer; advanced by
 *                            (enrichments * quadrature points) per cell.
 * @param linAlgOpContext     supplies the streams used for the launches.
 *
 * All offsets are accumulated from cell 0, so the pointer arguments are
 * treated as addressing data for all cells, not only the requested range.
 *
 * NOTE(review): the two std::vector parameters are taken by value, which
 * copies them on every call; the signature is kept as is because it has to
 * match the declaration, which is not visible here.
 */
void
EnrichmentDataEvalKernels<utils::MemorySpace::DEVICE>::
getEnrichmentValuesInCellRange(
  const double *quadPtsInAllCells,
  const double *originPtsInAllCells,
  std::pair<size_type, size_type> cellRange,
  const std::vector<size_type> numEnrichIdsInAllCells,
  const std::vector<size_type> numQuadPtsInAllCells,
  const atoms::SphericalDataNumerical::DeviceView *data,
  double *output,
  linearAlgebra::LinAlgOpContext<
    utils::MemorySpace::DEVICE> &linAlgOpContext)
{
  const size_type numStreams =
    linAlgOpContext.numBlasStreams();

  auto *streams =
    linAlgOpContext.getBlasStreamsVec();

  constexpr size_type dim = 3;

  size_type cumulativeEnrichInCellRange = 0;
  size_type cumulativeQuadPtsInCellRange = 0;
  size_type cumulativeQuadxEnrichInCellRange = 0;

  // Skip over the cells that precede the requested range so the offsets
  // point at the first requested cell. The loop index uses size_type to
  // match the type of the bound.
  for (size_type iCell = 0; iCell < cellRange.first; ++iCell)
  {
    const size_type numEnrichInCell =
      numEnrichIdsInAllCells[iCell];

    const size_type numQuadInCell =
      numQuadPtsInAllCells[iCell];

    cumulativeEnrichInCellRange += numEnrichInCell;
    cumulativeQuadPtsInCellRange += numQuadInCell;
    cumulativeQuadxEnrichInCellRange +=
      numEnrichInCell * numQuadInCell;
  }

  // Counts the launches issued so far; used to pick the next stream.
  size_type numLaunchedCells = 0;

  for (size_type iCell = cellRange.first;
       iCell < cellRange.second;
       ++iCell)
  {
    const size_type numEnrichInCell =
      numEnrichIdsInAllCells[iCell];

    const size_type numQuadInCell =
      numQuadPtsInAllCells[iCell];

    // Number of output values produced for this cell.
    const size_type total =
      numEnrichInCell * numQuadInCell;

    // Launch only when there is work: a zero total would give a grid of
    // zero blocks, which is not a valid launch configuration.
    if (total > 0)
    {
      const size_type sid =
        numLaunchedCells % numStreams;

      const size_type blockSize =
        utils::DEVICE_BLOCK_SIZE;

      // Round up so that grid * blockSize >= total.
      const size_type grid =
        (total + blockSize - 1) / blockSize;

      DFTEFE_LAUNCH_KERNEL(
        evalEnrichmentInCell,
        grid,
        blockSize,
        streams[sid],
        quadPtsInAllCells +
          cumulativeQuadPtsInCellRange * dim,
        originPtsInAllCells +
          cumulativeEnrichInCellRange * dim,
        numEnrichInCell,
        numQuadInCell,
        data + cumulativeEnrichInCellRange,
        output +
          cumulativeQuadxEnrichInCellRange);

      ++numLaunchedCells;
    }

    // Advance the offsets for every cell, including cells without work.
    cumulativeQuadPtsInCellRange += numQuadInCell;
    cumulativeEnrichInCellRange += numEnrichInCell;
    cumulativeQuadxEnrichInCellRange += total;
  }

  // Wait for every stream before returning.
  for (size_type s = 0; s < numStreams; ++s)
    utils::deviceStreamSynchronize(streams[s]);
}

Why this works

The kernel receives:

const atoms::SphericalDataNumerical::DeviceView *views

Each thread can safely call:

views[i].getValueDevice(...)

because:

  • DeviceView contains only device-safe data
  • no host pointers are dereferenced
  • all state is explicitly copied to device memory