Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions symforce/caspar/source/runtime/pybind_array_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,23 @@ void AssertUint2Vec(const py::object& obj) {
Assert2DNxk(obj, 2);
}

// Looks up which CUDA device the buffer exposed by `obj` is associated with.
//
// The address is taken from the first element of the "data" tuple in the
// object's `__cuda_array_interface__` dict and handed to
// cudaPointerGetAttributes; the `device` field of the resulting attributes is
// returned unchanged.
//
// Returns -1 when the attribute query fails, or when anything throws while
// reading the interface (missing attribute, failed cast, ...). This function
// never propagates an exception to the caller.
int GetDeviceId(const py::object& obj) {
  constexpr int kUnknownDevice = -1;
  try {
    py::dict cuda_iface = obj.attr("__cuda_array_interface__").cast<py::dict>();
    py::tuple data_entry = cuda_iface["data"].cast<py::tuple>();
    const size_t address = data_entry[0].cast<size_t>();

    cudaPointerAttributes attributes;
    if (cudaPointerGetAttributes(&attributes, reinterpret_cast<void*>(address)) == cudaSuccess) {
      return attributes.device;
    }

    // The query failed: read back the runtime's last-error value so this
    // failure is not left pending for later CUDA calls, then fall back.
    cudaGetLastError();
    return kUnknownDevice;
  } catch (...) {
    // Fallback if the interface or the attributes aren't available.
    return kUnknownDevice;
  }
}

float* AsFloatPtr(const py::object& obj) {
AssertFloatVec(obj);
py::tuple data = GetInterface(obj)["data"].cast<py::tuple>();
Expand Down
2 changes: 2 additions & 0 deletions symforce/caspar/source/runtime/pybind_array_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ void AssertDeviceMemory(const py::object& obj);
void AssertNumRowsEquals(const py::object& obj, size_t n);
void AssertNumColsEquals(const py::object& obj, size_t n);

// Returns the `device` field that cudaPointerGetAttributes reports for the data
// pointer in obj's `__cuda_array_interface__`. Returns -1 if the attribute
// query fails or the interface cannot be read; does not throw.
int GetDeviceId(const py::object& obj);

float* AsFloatPtr(const py::object& obj);
double* AsDoublePtr(const py::object& obj);
int* AsIntPtr(const py::object& obj);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void add_casmappings_pybindings(pybind11::module_ module) {
throw std::runtime_error(
"The caspar data must have at least as many columns as stacked_data has rows.");
}
cudaSetDevice(GetDeviceId(stacked_data));
{{nodetype.__name__}}StackedToCaspar(
As{{caslib.storage_t.capitalize()}}Ptr(stacked_data), As{{caslib.storage_t.capitalize()}}Ptr(cas_data), cas_stride, 0, num_objects);
});
Expand All @@ -45,7 +46,7 @@ void add_casmappings_pybindings(pybind11::module_ module) {
throw std::runtime_error(
"The caspar data must have at least as many columns as stacked_data has rows.");
}

cudaSetDevice(GetDeviceId(cas_data));
{{nodetype.__name__}}CasparToStacked(
As{{caslib.storage_t.capitalize()}}Ptr(cas_data), As{{caslib.storage_t.capitalize()}}Ptr(stacked_data), cas_stride, 0, num_objects);
});
Expand Down
1 change: 1 addition & 0 deletions symforce/caspar/source/templates/lib.pyi.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class {{solver.struct_name}}:
{% for thing in solver.size_contributors %}
{{num_arg_key(thing)}}: int = 0,
{% endfor %}
device_id: int = 0,
): ...

def set_params(self, params: SolverParams) -> None:
Expand Down
36 changes: 33 additions & 3 deletions symforce/caspar/source/templates/solver.cc.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ namespace caspar {
{{ solver.struct_name }}::{{ solver.struct_name }}(
const SolverParams<double> &params,
{% for thing in solver.size_contributors %}
size_t {{num_arg_key(thing)}}{{ ", " if not loop.last else "" }}
size_t {{num_arg_key(thing)}}{{ ", " }}
{% endfor %}
int device_id
)
: params_(params),
device_id_(device_id),
{% for thing in solver.size_contributors %}
{{num_key(thing)}}({{num_arg_key(thing)}}),
{{num_max_key(thing)}}({{num_arg_key(thing)}}){{ ", " if not loop.last else "" }}
Expand All @@ -85,6 +87,20 @@ namespace caspar {
throw std::runtime_error("params.diag_init must be positive");
}
allocation_size_ = get_nbytes();

if (device_id_ < 0) {
throw std::runtime_error("Invalid CUDA device id: " + std::to_string(device_id_));
}
if (device_id_ != 0) {
int deviceCount;
cudaGetDeviceCount(&deviceCount);
if (deviceCount <= device_id_) {
throw std::runtime_error("CUDA detected " + std::to_string(deviceCount) +
" devices, but device " + std::to_string(device_id_) +
" was requested (0-indexed)");
}
}
cudaSetDevice(device_id_);
cudaMalloc(&origin_ptr_, allocation_size_);

size_t offset = 0;
Expand All @@ -97,6 +113,7 @@ namespace caspar {
}

// Destructor: releases the device buffer held in origin_ptr_.
{{ solver.struct_name }}::~{{ solver.struct_name }}(){
// Make the solver's device current before freeing; the constructor selects
// device_id_ the same way right before it cudaMalloc's origin_ptr_.
cudaSetDevice(device_id_);
// Return codes of both CUDA calls are ignored here.
cudaFree(origin_ptr_);
}

Expand All @@ -110,6 +127,7 @@ size_t {{ solver.struct_name }}::get_allocation_size(){


SolveResult {{ solver.struct_name }}::solve(bool print_progress, bool verbose_logging) {
cudaSetDevice(device_id_);
SolveResult result;
result.exit_reason = ExitReason::MAX_ITERATIONS;
{{solver.linear_t}} score_best;
Expand Down Expand Up @@ -634,6 +652,7 @@ void {{ solver.struct_name }}::finish_indices() {

{% for nodetype in solver.node_types %}
void {{ solver.struct_name }}::Set{{nodetype.__name__}}Num(const size_t num) {
cudaSetDevice(device_id_);
if (num > {{num_max_key(nodetype)}}) {
throw std::runtime_error(std::to_string(num) + " > {{num_max_key(nodetype)}}");
}
Expand All @@ -642,6 +661,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}Num(const size_t num) {

void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedHost(
const {{solver.storage_t}}* const data, const size_t offset, const size_t num) {
cudaSetDevice(device_id_);
if (offset + num > {{num_key(nodetype)}}){
throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
}
Expand All @@ -654,6 +674,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedHost(

void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedDevice(
const {{solver.storage_t}}* const data, const size_t offset, const size_t num) {
cudaSetDevice(device_id_);
if (offset + num > {{num_key(nodetype)}}){
throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
}
Expand All @@ -663,6 +684,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedDevice(

void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedHost(
{{solver.storage_t}}* const data, const size_t offset, const size_t num) {
cudaSetDevice(device_id_);
if (offset + num > {{num_key(nodetype)}}){
throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
}
Expand All @@ -675,6 +697,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedHost(

void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{{solver.storage_t}}* const data, const size_t offset, const size_t num) {
cudaSetDevice(device_id_);
if (offset + num > {{num_key(nodetype)}}){
throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
}
Expand All @@ -695,6 +718,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{% if fac.isnodeshared[arg] %}
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand All @@ -708,7 +732,8 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;

cudaSetDevice(device_id_);

if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand Down Expand Up @@ -737,6 +762,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
const {{solver.storage_t}}* const data, size_t offset, size_t num
{% endif %}
) {
cudaSetDevice(device_id_);
{% if fac.isconstuniq[arg] %}
const size_t offset = 0;
const size_t num = 1;
Expand Down Expand Up @@ -769,6 +795,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{% elif fac.isconstindexed[arg] %}
const {{solver.storage_t}}* const data, size_t offset, size_t num
{% endif %} ) {
cudaSetDevice(device_id_);
{% if fac.isconstuniq[arg] %}
const size_t offset = 0;
const size_t num = 1;
Expand All @@ -791,6 +818,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{% if fac.isconstshared[arg] %}
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand All @@ -804,7 +832,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;

cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand All @@ -824,6 +852,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;
cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand All @@ -836,6 +865,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;
cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand Down
4 changes: 3 additions & 1 deletion symforce/caspar/source/templates/solver.h.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ class {{ solver.struct_name }} {
{{ solver.struct_name }}(
const SolverParams<double> &params,
{% for thing in solver.size_contributors %}
size_t {{num_arg_key(thing)}}{{ ", " if not loop.last else "" }}
size_t {{num_arg_key(thing)}}{{ ", " }}
{% endfor %}
int device_id = 0
);

// This class is managing cuda memory and cannot be copied.
Expand Down Expand Up @@ -210,6 +211,7 @@ class {{ solver.struct_name }} {

private:
SolverParams<{{solver.linear_t}}> params_;
int device_id_;
uint8_t* origin_ptr_;
size_t scratch_inout_size_;
size_t allocation_size_;
Expand Down
6 changes: 4 additions & 2 deletions symforce/caspar/source/templates/solver_pybinding.h.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ inline void add_solver_pybinding(pybind11::module_ module) {
.def(py::init<SolverParams<double>,
{% for thing in solver.size_contributors %}
size_t{{ ", " if not loop.last else "" }}
{% endfor %}>(),
{% endfor %},
int>(),
py::arg("params"),
py::kw_only(),
{% for thing in solver.size_contributors %}
py::arg("{{num_arg_key(thing)}}") = 0{{ ", " if not loop.last else "" }}
py::arg("{{num_arg_key(thing)}}") = 0{{ ", " }}
{% endfor %}
py::arg("device_id") = 0
)

.def("set_params", &{{solver.struct_name}}::set_params)
Expand Down
Loading