
Incorrect branch handling during derivative computation for Max function #445

@akhadke-bdai


Describe the bug
I am attempting to generate PyTorch code that computes the derivative (Jacobian) of the function $w = \sqrt{\max(0, x^2 - z^2)}$. The expected derivative with respect to $x$ is as follows:

$$
\frac{\partial w}{\partial x} =
\begin{cases}
0 & \text{if } x^2 < z^2 \\
\text{undefined} & \text{if } x^2 = z^2 \\
\dfrac{x}{\sqrt{x^2 - z^2}} & \text{if } x^2 > z^2
\end{cases}
$$
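For reference, here is a minimal plain-Python sketch of the expected piecewise partials. It also returns $\partial w/\partial z$, which the reproduction script computes as well. By the same chain rule, that partial is $-z/\sqrt{x^2 - z^2}$ for $x^2 > z^2$ and $0$ for $x^2 < z^2$. `expected_grad` is a hypothetical helper, not part of SymForce. Reporting the undefined case $x^2 = z^2$ as NaN is an assumed convention.

```python
import math


def expected_grad(x: float, z: float) -> tuple[float, float]:
    """Expected (dw/dx, dw/dz) for w = sqrt(max(0, x^2 - z^2)).

    Hypothetical reference helper; the non-differentiable case
    x^2 == z^2 is reported as NaN by assumption.
    """
    u = x * x - z * z
    if u < 0.0:
        # w is identically 0 in this region, so both partials are 0
        return 0.0, 0.0
    if u == 0.0:
        # Kink of max(0, .) combined with the sqrt singularity: undefined
        return math.nan, math.nan
    r = math.sqrt(u)
    return x / r, -z / r


print(expected_grad(1.0, 2.0))  # (0.0, 0.0), since x^2 < z^2
print(expected_grad(5.0, 3.0))  # (1.25, -0.75), since x^2 > z^2
```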

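However, SymForce codegen computes the following:

$$\left(\frac{\partial w}{\partial x}\right)_{\text{codegen}} = \frac{x}{2} \cdot \frac{\operatorname{sign}(x^2 - z^2) + 1}{\sqrt{\max(0, x^2 - z^2)}}$$

This evaluates to NaN for any $x$ with $x^2 < z^2$. In that region, both the numerator $\operatorname{sign}(x^2 - z^2) + 1$ and the denominator $\sqrt{\max(0, x^2 - z^2)}$ are $0$, so the expression is $0/0$.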
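To make the $0/0$ concrete, here is a minimal plain-Python sketch. It transcribes the codegen formula above and evaluates it at one point on each side of $x^2 = z^2$.

Python's built-in float division raises `ZeroDivisionError` instead of returning NaN. The hypothetical `ieee_div` helper therefore emulates IEEE 754 division, where 0/0 is NaN and a/0 is ±inf. PyTorch floating-point tensor arithmetic follows the same rules. `sign` and `codegen_dwdx` are also hypothetical names for illustration, not the generated code.

```python
import math


def ieee_div(a: float, b: float) -> float:
    """Hypothetical helper: emulate IEEE 754 division (0/0 -> NaN, a/0 -> +/-inf)."""
    if b != 0.0:
        return a / b
    if a == 0.0 or math.isnan(a):
        return math.nan
    # Sign of the infinity follows the signs of a and b (b may be -0.0)
    return math.copysign(math.inf, a) * math.copysign(1.0, b)


def sign(v: float) -> float:
    # sign(0) = 0, matching the usual symbolic convention
    return float((v > 0.0) - (v < 0.0))


def codegen_dwdx(x: float, z: float) -> float:
    """Hypothetical transcription of the codegen formula:
    x/2 * (sign(x^2 - z^2) + 1) / sqrt(max(0, x^2 - z^2))."""
    u = x * x - z * z
    numerator = sign(u) + 1.0  # 0 whenever u < 0
    denominator = math.sqrt(max(0.0, u))  # also 0 whenever u < 0
    return x / 2.0 * ieee_div(numerator, denominator)


print(codegen_dwdx(1.0, 2.0))  # nan (expected 0)
print(codegen_dwdx(5.0, 3.0))  # 1.25
```

The first call returns NaN where the expected value is 0. The second call agrees with $x / \sqrt{x^2 - z^2} = 5/4$, so only the inactive branch is affected.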

To Reproduce
My script to generate the PyTorch code:

```python
import os

import symforce

# Configure SymForce before importing symforce.symbolic
symforce.set_symbolic_api("symengine")
symforce.set_log_level("warning")
symforce.set_epsilon_to_symbol()

import symforce.symbolic as se
from symforce import codegen
from symforce.codegen.backends.pytorch.pytorch_config import PyTorchConfig
from symforce.values import Values


def gen_py(inputs, outputs, name):
    gen = codegen.Codegen(
        inputs=inputs,
        outputs=outputs,
        config=PyTorchConfig(),
        name=name,
    )
    data = gen.generate_function()

    # Print what we generated
    print("Files generated in {}:\n".format(data.output_dir))
    for f in data.generated_files:
        print("  |- {}".format(os.path.relpath(f, data.output_dir)))


def main():
    x = se.Symbol("x")
    z = se.Symbol("z")
    # y = max(0, x^2 - z^2) and w = sqrt(y)
    y = se.Max(se.Scalar(0.0), (x**2 - z**2))
    w = se.sqrt(y)

    # Partial derivatives of w with respect to x and z
    dwdx = [se.diff(w, x), se.diff(w, z)]

    # A single input vector named "x" holds both symbols x and z
    inputs = Values(x=se.Matrix([x, z]))
    outputs = Values(dwdx=se.Matrix(dwdx))
    gen_py(inputs=inputs, outputs=outputs, name="dwdx")


if __name__ == "__main__":
    main()
```

Expected behavior
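The generated Jacobian should match the piecewise derivative in the description. In particular, it should evaluate to 0, not NaN, for any $x$ with $x^2 < z^2$.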
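Here is a minimal plain-Python sketch of one formulation that has the expected behavior. It is not SymForce's actual fix, and `guarded_dwdx` is a hypothetical name.

The sketch keeps a mask that selects the active branch. Wherever that branch is inactive, it also swaps the argument under the square root for a harmless value. The masked-out branch therefore never evaluates $0/0$. This is the same idea as the double-`where` pattern used in array libraries.

Substituting 1.0 is an assumption: any positive value works, because the result is multiplied by a zero mask. The sketch uses scalar conditionals for readability, where tensor code would make the same two selections elementwise. At the kink $x^2 = z^2$ it returns 0, which is a convention choice rather than the true (undefined) value.

```python
import math


def guarded_dwdx(x: float, z: float) -> float:
    """NaN-free dw/dx for w = sqrt(max(0, x^2 - z^2)).

    Hypothetical sketch: the mask selects the active branch, and the
    argument under the sqrt is swapped for 1.0 where the branch is
    inactive. Returns 0 at x^2 == z^2 by convention.
    """
    u = x * x - z * z
    active = 1.0 if u > 0.0 else 0.0  # Heaviside step of u
    safe_u = u if u > 0.0 else 1.0  # never 0 (or negative) under the sqrt
    return active * x / math.sqrt(safe_u)


print(guarded_dwdx(1.0, 2.0))  # 0.0 instead of nan
print(guarded_dwdx(5.0, 3.0))  # 1.25
```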

Environment

  • OS and version: Ubuntu 22.04.5 LTS
  • Python version: 3.10.12
  • SymForce version: 0.10.1


Labels: bug (Something isn't working)
