Some bugs in model inference #14

@hdadong

Hi, I noticed a potential parameter order mismatch when calling reward_fn in the training logic.

In ssrl/brax/training/agents/ssrl/networks.py#L133, the current implementation is:

reward = c.reward_fn(obs, obs_next, jp.mean(us), action)

However, the reward_fn definition in go1_go_fast.py shows the expected parameter order as:

def reward_fn(obs_next, obs, us, action):

The parameters obs and obs_next should be swapped in the function call.
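
To illustrate why the order matters, here is a minimal sketch with a toy reward function. Only the parameter order (obs_next, obs, us, action) comes from the definition quoted above; the reward body, the sample values, and the plain-list observations are hypothetical and are not the actual reward in go1_go_fast.py.

```python
# Toy stand-in for a reward function; NOT the reward from go1_go_fast.py.
# Only the parameter order (obs_next, obs, us, action) is taken from the issue.
def reward_fn(obs_next, obs, us, action):
    # Hypothetical reward: progress of the first observation entry between
    # two steps, minus a small control-effort penalty. `us` is unused here.
    progress = obs_next[0] - obs[0]
    effort = 0.01 * sum(a * a for a in action)
    return progress - effort


# Made-up sample data for illustration.
obs = [0.0, 0.0]
obs_next = [1.0, 0.0]
us = 0.5
action = [0.0, 0.0]

# Positional call with the first two arguments in the wrong order:
# the progress term flips sign without raising any error.
swapped = reward_fn(obs, obs_next, us, action)

# Positional call matching the definition.
correct = reward_fn(obs_next, obs, us, action)

# Keyword arguments make the call independent of argument order.
by_keyword = reward_fn(obs=obs, obs_next=obs_next, us=us, action=action)

print(swapped, correct, by_keyword)
```

In this toy case the swapped call silently returns the negated progress term, so passing the arguments by keyword at the call site may be a reasonable way to guard against this kind of mismatch.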

Additionally, I think the reward should be computed from the clean observation (obs_next_mean) rather than the noisy one (obs_next), so the call should look like this:

Proposed fix:

reward = c.reward_fn(obs_next_mean, obs, jp.mean(us), action)
