
mxfp4-gemm sample/checker reference appears broken (null expected rows / null diffs) #65


@josusanmartin

Summary

As of April 23, 2026, mxfp4-gemm does not seem solvable by iterating on submissions, because the reference output used by the sample and check endpoints appears to be corrupted.

Problem page:
https://tensara.org/problems/mxfp4-gemm

What I observed

1. Public sample endpoint returns a corrupted expected output

Using the real Tensara sample endpoint with the public reference-style solution:

import torch
import torch.nn.functional as F

def solution(q_a, scale_a, q_b, scale_b, c, m, n, k):
    # q_a / q_b hold packed FP4 (E2M1) data, two values per byte; scale_a / scale_b hold
    # E8M0 block scales (one per 1x32 block along k) in the SWIZZLE_32_4_4 layout.
    out = F.scaled_mm(
        q_a.view(torch.float4_e2m1fn_x2),
        q_b.view(torch.float4_e2m1fn_x2).t(),  # computes A @ B^T
        scale_a=scale_a.view(torch.float8_e8m0fnu).flatten(),
        scale_recipe_a=F.ScalingType.BlockWise1x32,
        swizzle_a=F.SwizzleType.SWIZZLE_32_4_4,
        scale_b=scale_b.view(torch.float8_e8m0fnu).flatten(),
        scale_recipe_b=F.ScalingType.BlockWise1x32,
        swizzle_b=F.SwizzleType.SWIZZLE_32_4_4,
        output_dtype=torch.float32,
    )
    c.copy_(out)  # write the result into the preallocated output tensor

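In the sample response, the lower half of the rows in expected_output is serialized as null. The computed output has finite values in the upper half and null/NaN values in the lower half.

That strongly suggests the sample/reference side is already producing NaNs (or otherwise corrupted values) for part of the output. JSON has no NaN literal, and JavaScript's JSON.stringify, for example, serializes NaN and Infinity as null, so null entries are what NaNs would look like in the response.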
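To pin down exactly which rows are affected, the expected_output array from the sample response can be scanned for nulls. A minimal sketch, assuming the response body has been parsed with Python's json module and that expected_output is a row-major list of lists (the response layout and the helper name null_rows are assumptions):

```python
import json


def null_rows(expected_output: list[list]) -> list[int]:
    """Return indices of rows containing at least one null (None after json.loads)."""
    return [i for i, row in enumerate(expected_output) if any(v is None for v in row)]


# Hypothetical payload shaped like the sample response described above.
payload = json.loads('{"expected_output": [[1.0, 2.0], [3.0, 4.0], [null, null], [null, null]]}')
print(null_rows(payload["expected_output"]))
```

If the result is exactly the contiguous range from m // 2 to m - 1, that confirms the "lower half" pattern rather than scattered NaNs.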

2. Hidden checker returns WRONG_ANSWER with debug_info.max_difference = null

Even a trivial all-zero output:

def solution(q_a, scale_a, q_b, scale_b, c, m, n, k):
    c.zero_()  # deliberately wrong: every output element is 0

returns:

  • status = WRONG_ANSWER
  • debug_info.max_difference = null
  • debug_info.mean_difference = null

If the reference output were finite, these fields would be numeric: the difference between an all-zero output and a finite reference is itself finite.
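A minimal sketch of how null metrics can arise, assuming the checker computes element-wise absolute differences in PyTorch and then serializes them to JSON (both are assumptions; the actual checker code is not shown here, and diff_metrics is a hypothetical name). PyTorch's max and mean reductions propagate NaN, so a single NaN in the reference makes both metrics NaN, which then becomes null:

```python
import math

import torch


def _json_number(x: float):
    # JSON has no NaN/Infinity literal; JavaScript's JSON.stringify emits null for them.
    return x if math.isfinite(x) else None


def diff_metrics(output: torch.Tensor, expected: torch.Tensor) -> dict:
    """Hypothetical checker-side metrics (not Tensara's actual code)."""
    diff = (output - expected).abs()
    return {
        # max() and mean() propagate NaN, so one NaN in `expected` poisons both values.
        "max_difference": _json_number(diff.max().item()),
        "mean_difference": _json_number(diff.mean().item()),
        # Reporting this count would make a corrupted reference visible directly.
        "nonfinite_expected": int((~torch.isfinite(expected)).sum().item()),
    }
```

Calling diff_metrics(torch.zeros(4, 4), expected) with the lower two rows of expected set to NaN yields None for both differences, matching the reported debug_info, while a count of non-finite reference elements would point straight at the oracle instead of the submission.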

3. Diagnostic comparator shows internal disagreement around the same case

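On the hidden 1024 x 1024 x 1024 case, a diagnostic submission comparing three computation paths (full, manual, and tiled) produced these pairwise max/mean differences and values of output element [0,0]: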

  • full_manual max=225.551483 mean=30.069611
  • full_tiled max=230.004395 mean=21.021412
  • manual_tiled max=219.252747 mean=30.071690
  • full00=-47.064709
  • manual00=48.734375
  • tiled00=-47.064709

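So the two scaled_mm paths (full and tiled) agree on element [0,0] (both -47.064709), while the manual path gives 48.734375 there; the full_tiled max difference of 230.004395 shows that even the scaled_mm paths diverge on other elements. Despite these large, finite discrepancies, the checker still reports only a generic WRONG_ANSWER with null diffs.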
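For comparison, a manual MXFP4 oracle can be written directly from the OCP Microscaling (MX) format definitions: E2M1 elements, E8M0 block scales decoding to 2^(e - 127) with 0xFF reserved for NaN, and one scale per 32 consecutive elements along k. This is a hypothetical sketch, not the reporter's manual path or Tensara's reference. It assumes the low nibble of each byte holds the first element, that inputs are CPU uint8 tensors, and that scales are supplied unswizzled as (rows, k // 32), so SWIZZLE_32_4_4 scales would first need to be un-swizzled.

```python
import torch

# E2M1 code -> value (1 sign bit, 2 exponent bits, 1 mantissa bit), per the OCP MX spec.
_E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                      -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])


def decode_fp4_packed(q: torch.Tensor) -> torch.Tensor:
    """Unpack (rows, k // 2) bytes into (rows, k) float32.
    ASSUMPTION: the low nibble of each byte is the first element."""
    q = q.view(torch.uint8).to(torch.int64)
    codes = torch.stack([q & 0xF, (q >> 4) & 0xF], dim=-1).reshape(q.shape[0], -1)
    return _E2M1[codes]


def decode_e8m0(s: torch.Tensor) -> torch.Tensor:
    """Decode E8M0 scale bytes to float32: 2^(e - 127), with 0xFF meaning NaN."""
    e = s.view(torch.uint8).to(torch.float32)
    vals = torch.exp2(e - 127.0)
    return torch.where(e == 255, torch.full_like(vals, float("nan")), vals)


def mxfp4_gemm_reference(q_a, scale_a, q_b, scale_b, block: int = 32) -> torch.Tensor:
    """C = A @ B^T with A (m, k) and B (n, k).
    ASSUMPTION: scales are unswizzled with shape (rows, k // block)."""
    a = decode_fp4_packed(q_a) * decode_e8m0(scale_a).repeat_interleave(block, dim=1)
    b = decode_fp4_packed(q_b) * decode_e8m0(scale_b).repeat_interleave(block, dim=1)
    return a @ b.T
```

Because E8M0 reserves 0xFF for NaN, a single such scale byte turns its 32-element block, and therefore every output element in that row of C (for scale_a) or column of C (for scale_b), into NaN.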

4. Input layout is not the issue

Both local and public sample runs confirm the sample tensors are already shaped, not flattened:

  • q_a=(32, 16)
  • scale_a=(32, 16)
  • q_b=(32, 16)
  • scale_b=(32, 16)
  • c=(32, 32)

So this does not look like a simple user-side layout mismatch.
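The reported shapes are also consistent with a packed MXFP4 layout. Each byte of q_a packs two FP4 values, so q_a = (32, 16) together with c = (32, 32) implies m = n = k = 32 for the sample. Under two further assumptions (one E8M0 scale per 1x32 block, and swizzled scale tensors padded to 128-row by 4-column tiles, as torchao's to_blocked helper does), each scale tensor would hold 128 x 4 = 512 bytes, which matches the 32 x 16 reported for scale_a and scale_b. A sketch of that arithmetic, with mxfp4_shapes as a hypothetical helper:

```python
import math


def mxfp4_shapes(m: int, n: int, k: int, block: int = 32) -> dict:
    """Hypothetical helper: tensor sizes implied by m, n, k for packed MXFP4 inputs."""
    scale_cols = k // block  # one E8M0 scale per 1 x `block` slice along k

    def swizzled_scale_bytes(rows: int) -> int:
        # ASSUMPTION: pad rows to a multiple of 128 and scale columns to a multiple of 4.
        return (math.ceil(rows / 128) * 128) * (math.ceil(scale_cols / 4) * 4)

    return {
        "q_a": (m, k // 2),  # two FP4 values per byte
        "q_b": (n, k // 2),
        "scale_a_bytes": swizzled_scale_bytes(m),
        "scale_b_bytes": swizzled_scale_bytes(n),
        "c": (m, n),
    }
```

For the sample, mxfp4_shapes(32, 32, 32) gives q_a = q_b = (32, 16), 512 scale bytes per operand, and c = (32, 32), so the padding, not a flattening bug, would explain why the scale tensors are as large as the data tensors.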

Why this is a blocker

The current sample/check behavior makes mxfp4-gemm impossible to debug reliably:

  • sample returns corrupted expected values
  • checker reports null diff metrics even for obviously wrong outputs
  • public reference-style scaled_mm path cannot be validated against a stable oracle

Repro environment

Observed on April 23, 2026 against:

  • real Tensara sample endpoint
  • real Tensara direct-submit/check path
  • local self-host mirror reproduces the same null-diff behavior

Public submission links

Initial public attempts:

Recent public diagnostic and sanity-check attempts:

Requested fix

Please verify the mxfp4-gemm reference/checker implementation, especially:

  • the reference computation used for sample/check
  • whether the expected output is producing NaNs in the lower rows
  • why debug_info.max_difference / mean_difference become null instead of numeric values on wrong answers

Once that is fixed, candidate solutions can be re-run against a trustworthy oracle.
