Design Ideas
// The StorageView is something which does not have any constructor
// or destructor of the Storage.
// I have hardcoded it for double for illustration.
// You can do it for T *
template <utils::MemorySpace memSpace>
struct StorageView
{
double* d_x; // raw pointer, no ownership
int n;
DFTEFE_HOST_DEVICE
double* data() const { return d_x; }
DFTEFE_HOST_DEVICE
int size() const { return n; }
};
template <utils::MemorySpace memSpace>
class Storage
{
public:
  // The usual Storage implementation (constructors, destructor, ...) is
  // omitted here; only the view-producing part is shown.

  // Returns a non-owning StorageView over this Storage's buffer, suitable for
  // passing into kernels by value. The view borrows d_x and must not outlive
  // this Storage.
  StorageView<memSpace> view() const
  {
    StorageView<memSpace> v{d_x, n};
    return v;
  }

private:
  double *d_x;
  int     n;
};
// MyClass.h
template <utils::MemorySpace MemSpace>
struct MyClassView
{
  // Non-owning view of the owning object's storage; cheap to copy.
  StorageView<MemSpace> d_storageview;

  // Single implementation of the evaluation, callable from host or device.
  // MyClass::eval() forwards here, so the logic is not duplicated.
  DFTEFE_HOST_DEVICE
  void eval(double * x) const
  {
    // Work with d_storageview.data(), d_storageview.size() and x here.
  }
};
template <utils::MemorySpace MemSpace>
class MyClass
{
public:
  MyClass(int n) : d_storage(n) {}

  // Evaluate through a freshly produced MyClassView (cheap to construct), so
  // the evaluation logic lives only in MyClassView::eval() and is not
  // implemented twice.
  void eval(double * x)
  {
    view().eval(x);
  }

  // Produce a non-owning view. It borrows d_storage and must not outlive
  // this MyClass instance.
  MyClassView<MemSpace> view() const { return {d_storage.view()}; }

private:
  Storage<MemSpace> d_storage;
};

template <utils::MemorySpace memSpace>
struct TestPolicy
{
  MyClassView<memSpace> obj; // non-owning, safe to copy into kernel

  // Forward to MyClassView::eval(), the view's single evaluation routine.
  DFTEFE_HOST_DEVICE
  void operator()(double * x) const
  {
    obj.eval(x);
  }
};
// Owning objects live on the host
MyClass<utils::MemorySpace::DEVICE> myObj(n);
// Views are cheap to produce and safe to copy
TestPolicy<utils::MemorySpace::DEVICE> policy{myObj.view()};
// The DEVICE policy's view wraps DEVICE storage, so the HOST launch needs its
// own HOST-side owner and view instead of reusing `policy`.
MyClass<utils::MemorySpace::HOST> myObjHost(n);
TestPolicy<utils::MemorySpace::HOST> policyHost{myObjHost.view()};
// double * xdev allocated on device
// double * xhost allocated on host
// Same launch infrastructure as before, nothing changes
// NOTE(review): TestPolicy is a class template, so the kernel parameter below
// needs a memory-space argument (or DFTEFE_CREATE_KERNEL must generate a
// template) for both launches to type-check; confirm against the macro.
DFTEFE_CREATE_KERNEL(
  void,
  test,
  {
    if (globalThreadId < n)
      policy(x);
  },
  TestPolicy policy,
  int n,
  double * x
);
launch<utils::MemorySpace::DEVICE>(test, policy, n, xdev);
launch<utils::MemorySpace::HOST>(test, policyHost, n, xhost);
There are three distinct ways #ifdef DFTEFE_WITH_DEVICE appears in DFTEFE.
Preferred approach: template the class on MemorySpace.
This is the cleanest design when the class owns memory-space-specific data.
// The class itself is templated on the memory space, so every
// memory-space-dependent type in its interface (here the wavefunction
// MultiVector) is selected at compile time.
// NOTE(review): RealType and ValueTypeBasisCoeff are not declared in this
// snippet; presumably further template parameters or type aliases of the real
// class. Confirm.
template <utils::MemorySpace memorySpace>
class DensityCalculator
{
public:
  /// @brief Computes rho from the occupations and wavefunctions (per the
  ///        method name; the body is not shown here).
  /// @param occupation occupation numbers, held in a host std::vector
  /// @param waveFunc   wavefunction MultiVector residing in memorySpace
  /// @param rho        HOST quadrature-values container, passed by non-const
  ///                   reference (presumably the result)
  void
  computeRho(const std::vector<RealType> &occupation,
             const linearAlgebra::MultiVector<ValueTypeBasisCoeff,
                                              memorySpace> &waveFunc,
             quadrature::QuadratureValuesContainer<RealType,
                                                   utils::MemorySpace::HOST>
               &rho);
};
Implementation:
template <utils::MemorySpace memorySpace>
void DensityCalculator<memorySpace>::computeRho(
const std::vector<RealType> &occupation,
const linearAlgebra::MultiVector<ValueTypeBasisCoeff, memorySpace>
&waveFunc,
quadrature::QuadratureValuesContainer<RealType,
utils::MemorySpace::HOST> &rho)
{
// internally calls computeRhoInBatch(...)
}template <
typename ValueType,
typename RealType,
utils::MemorySpace memorySpace>
class DensityCalculatorKernels
{
public:
static void computeRhoInBatch(
const utils::MemoryStorage<RealType, memorySpace> &occupationInBatch,
quadrature::QuadratureValuesContainer<ValueType, memorySpace>
&psiBatchQuad,
quadrature::QuadratureValuesContainer<RealType, memorySpace>
&modPsiSqBatchQuad,
std::shared_ptr<const quadrature::QuadratureRuleContainer>
quadRuleContainer,
quadrature::QuadratureValuesContainer<RealType, memorySpace>
&rhoBatch,
linearAlgebra::LinAlgOpContext<memorySpace>
&linAlgOpContext);
};#ifdef DFTEFE_WITH_DEVICE
template <typename ValueType, typename RealType>
class DensityCalculatorKernels<
ValueType,
RealType,
utils::MemorySpace::DEVICE>
{
public:
static void computeRhoInBatch(
const utils::MemoryStorage<
RealType,
utils::MemorySpace::DEVICE> &occupationInBatch,
quadrature::QuadratureValuesContainer<
ValueType,
utils::MemorySpace::DEVICE> &psiBatchQuad,
quadrature::QuadratureValuesContainer<
RealType,
utils::MemorySpace::DEVICE> &modPsiSqBatchQuad,
std::shared_ptr<const quadrature::QuadratureRuleContainer>
quadRuleContainer,
quadrature::QuadratureValuesContainer<
RealType,
utils::MemorySpace::DEVICE> &rhoBatch,
linearAlgebra::LinAlgOpContext<
utils::MemorySpace::DEVICE> &linAlgOpContext);
};
#endifThe memory space is resolved at compile time, so no runtime branching is needed.
Alternative when templating the class would require large refactoring.
This pattern keeps the external interface unchanged.
// Alternative to templating the whole class: only eval() is templated on the
// memory space, and it dispatches at compile time (if constexpr) to a
// per-memory-space virtual hook that derived classes override.
// NOTE(review): Q is not declared in this snippet; presumably the value type
// of the real class. Confirm.
// NOTE(review): DFTEFE_WITH_DEVICE must be defined consistently in every
// translation unit, because it adds/removes a virtual function and therefore
// changes the class layout (ODR).
class ScalarSpatialFunction
{
public:
  // Polymorphic base (virtual hooks, derived classes exist): objects may be
  // destroyed through a base pointer, which requires a virtual destructor.
  virtual ~ScalarSpatialFunction() = default;

  /// @brief Evaluate at numPoints points using the memorySpace hook.
  /// @param numPoints number of points
  /// @param t         input array (read-only)
  /// @param q         destination array (presumably the output values)
  /// @throws via utils::throwException when memorySpace is DEVICE but the
  ///         build was compiled without DFTEFE_WITH_DEVICE.
  template <utils::MemorySpace memorySpace>
  void
  eval(const size_type numPoints, const double *t, Q *q) const
  {
    if constexpr (memorySpace == utils::MemorySpace::DEVICE)
      {
#ifdef DFTEFE_WITH_DEVICE
        evalDevice(numPoints, t, q);
#else
        utils::throwException(
          false,
          "eval<DEVICE>() called but DEVICE support not compiled.");
#endif
      }
    else
      {
        evalHost(numPoints, t, q);
      }
  }

protected:
  virtual void
  evalHost(size_type numPoints, const double *t, Q *q) const;
#ifdef DFTEFE_WITH_DEVICE
  virtual void
  evalDevice(size_type numPoints, const double *t, Q *q) const;
#endif
};

// Example derived class: it only overrides the per-memory-space hooks; its
// public eval<memorySpace>() interface is inherited unchanged.
class SmearChargePotentialFunction : public ScalarSpatialFunctionReal
{
protected:
  void
  evalHost(size_type numPoints, const double *t, double *q) const override;
#ifdef DFTEFE_WITH_DEVICE
  void
  evalDevice(size_type numPoints, const double *t, double *q) const override;
#endif
};
- Minimal API disruption
- No class-wide templating
- Preserves polymorphism
This is the hardest case in the current DFTEFE design.
This happens when a class member function itself must execute inside a CUDA/HIP kernel.
Example:
obj->getValueDevice(...)
This is problematic because:
- Host objects cannot generally be dereferenced on device
- Virtual dispatch on device is difficult
- Host object pointers must be transferred to device
- Polymorphism does not map cleanly
Expose a device representation.
class SphericalDataNumerical : public SphericalData
{
public:
#ifdef DFTEFE_WITH_DEVICE
struct DeviceView
{
utils::SplineDeviceView radialSpline;
int l;
int m;
int mEff;
double constant;
double cutoff;
double smoothness;
double polarAngleTolerance;
DFTEFE_DEVICE_FUNC
double getValueDevice(
const double *point,
const double *origin) const;
};
DeviceView getDeviceView() const;
DFTEFE_DEVICE_FUNC
double getValueDevice(
const double *point,
const double *origin) const;
#endif
private:
#ifdef DFTEFE_WITH_DEVICE
DeviceView d_deviceView;
#endif
};#ifdef DFTEFE_WITH_DEVICE
SphericalDataNumerical::DeviceView
SphericalDataNumerical::getDeviceView() const
{
DeviceView v;
v.radialSpline = d_spline->getDeviceView();
v.l = d_qNumbers[1];
v.m = d_qNumbers[2];
v.mEff = std::abs(v.m);
v.constant = Clm(v.l, v.m) * Dm(v.m);
v.cutoff = d_cutoff;
v.smoothness = d_smoothness;
v.polarAngleTolerance = d_polarAngleTolerance;
return v;
}
#endifDFTEFE_DEVICE_FUNC
double
SphericalDataNumerical::DeviceView::getValueDevice(
const double *point,
const double *origin) const
{
...
}The key idea is:
- Construct DeviceView objects on the host
- Copy an array of DeviceViews to the device
- Pass the device array directly into kernels
- Invoke device member functions from inside the kernel
namespace
{
  // Evaluates every (enrichment function, quadrature point) pair of one cell.
  // iThread enumerates the pairs with the quadrature index varying fastest,
  // i.e. output[enrichId * numQuadInCell + quadId]. Hence enrichId is the
  // quotient and quadId the remainder of iThread / numQuadInCell, keeping
  // enrichId < numEnrichInCell (valid index into views/origin) and
  // quadId < numQuadInCell.
  // NOTE(review): confirm this enrichment-major layout matches the HOST
  // implementation.
  // Points and origins are stored with 3 coordinates each.
  DFTEFE_CREATE_KERNEL(
    void,
    evalEnrichmentInCell,
    {
      for (size_type iThread = globalThreadId;
           iThread < numEnrichInCell * numQuadInCell;
           iThread += nThreadsPerBlock * nThreadBlock)
        {
          const size_type enrichId = iThread / numQuadInCell;
          const size_type quadId   = iThread % numQuadInCell;
          output[iThread] =
            views[enrichId].getValueDevice(quadPtsInCell + quadId * 3,
                                           origin + enrichId * 3);
        }
    },
    const double *quadPtsInCell,
    const double *origin,
    const size_type numEnrichInCell,
    const size_type numQuadInCell,
    const atoms::SphericalDataNumerical::DeviceView *views,
    double *output);
}
The important line is:
views[enrichId].getValueDevice(...)
This is possible because views is a device-accessible array of DeviceView objects, which have to be created on the host, copied to the device, and passed into the kernel by the host-side launcher:
// DEVICE launcher: evaluates the enrichment functions at the quadrature points
// of cells [cellRange.first, cellRange.second). For each cell with work it
// launches evalEnrichmentInCell on one of the BLAS streams (round-robin), then
// synchronizes all streams before returning.
//
// The pointer inputs are flat arrays over ALL cells. Offsets into them for
// the requested range are found by accumulating the per-cell counts of the
// cells that precede cellRange.first.
// NOTE(review): assumes linAlgOpContext.numBlasStreams() >= 1 (it is used as
// a modulus). The two std::vector parameters are taken by value (copied per
// call); switching to const& must be done together with the declaration.
void
EnrichmentDataEvalKernels<utils::MemorySpace::DEVICE>::
  getEnrichmentValuesInCellRange(
    const double *quadPtsInAllCells,
    const double *originPtsInAllCells,
    std::pair<size_type, size_type> cellRange,
    const std::vector<size_type> numEnrichIdsInAllCells,
    const std::vector<size_type> numQuadPtsInAllCells,
    const atoms::SphericalDataNumerical::DeviceView *data,
    double *output,
    linearAlgebra::LinAlgOpContext<utils::MemorySpace::DEVICE>
      &linAlgOpContext)
{
  const size_type numStreams = linAlgOpContext.numBlasStreams();
  auto *streams = linAlgOpContext.getBlasStreamsVec();
  constexpr size_type dim = 3;

  size_type cumulativeEnrichInCellRange      = 0;
  size_type cumulativeQuadPtsInCellRange     = 0;
  size_type cumulativeQuadxEnrichInCellRange = 0;
  // Skip past the cells that precede the requested range.
  for (size_type iCell = 0; iCell < cellRange.first; ++iCell)
    {
      const size_type numEnrichInCell = numEnrichIdsInAllCells[iCell];
      const size_type numQuadInCell   = numQuadPtsInAllCells[iCell];
      cumulativeEnrichInCellRange += numEnrichInCell;
      cumulativeQuadPtsInCellRange += numQuadInCell;
      cumulativeQuadxEnrichInCellRange += numEnrichInCell * numQuadInCell;
    }

  size_type cumulativeCellWithNonZeroNumEnrich = 0;
  for (size_type iCell = cellRange.first; iCell < cellRange.second; ++iCell)
    {
      const size_type numEnrichInCell = numEnrichIdsInAllCells[iCell];
      const size_type numQuadInCell   = numQuadPtsInAllCells[iCell];
      const size_type total           = numEnrichInCell * numQuadInCell;
      // Launch only when there is work: with numQuadInCell == 0 the grid
      // size below would be 0, which is an invalid launch configuration.
      if (total > 0)
        {
          const size_type sid =
            cumulativeCellWithNonZeroNumEnrich % numStreams;
          const size_type blockSize = utils::DEVICE_BLOCK_SIZE;
          const size_type grid      = (total + blockSize - 1) / blockSize;
          DFTEFE_LAUNCH_KERNEL(
            evalEnrichmentInCell,
            grid,
            blockSize,
            streams[sid],
            quadPtsInAllCells + cumulativeQuadPtsInCellRange * dim,
            originPtsInAllCells + cumulativeEnrichInCellRange * dim,
            numEnrichInCell,
            numQuadInCell,
            data + cumulativeEnrichInCellRange,
            output + cumulativeQuadxEnrichInCellRange);
          cumulativeCellWithNonZeroNumEnrich++;
        }
      cumulativeQuadPtsInCellRange += numQuadInCell;
      cumulativeEnrichInCellRange += numEnrichInCell;
      cumulativeQuadxEnrichInCellRange += total;
    }

  for (size_type s = 0; s < numStreams; ++s)
    utils::deviceStreamSynchronize(streams[s]);
}
The kernel receives:
const atoms::SphericalDataNumerical::DeviceView *views
Each thread can safely call:
views[i].getValueDevice(...)because:
- DeviceView contains only device-safe data
- no host pointers are dereferenced
- all state is explicitly copied to device memory