diff --git a/python/python/tests/test_arrow.py b/python/python/tests/test_arrow.py index 4215866c2d5..b6e6024e0e7 100644 --- a/python/python/tests/test_arrow.py +++ b/python/python/tests/test_arrow.py @@ -48,6 +48,16 @@ def test_bf16_value(): assert not any(comparison for comparison in should_be_false) +def test_bf16_from_bytes_roundtrip(): + assert BFloat16.from_bytes(b"\xc0\x3f") == BFloat16(1.5) + + +@pytest.mark.parametrize("bad", [b"", b"\x00", b"\x00\x00\x00", b"\x00" * 4]) +def test_bf16_from_bytes_invalid_length_raises(bad): + with pytest.raises(ValueError, match="expected 2 bytes"): + BFloat16.from_bytes(bad) + + def test_bf16_repr(): data = [1.1, None, 3.4] arr = bfloat16_array(data) diff --git a/python/src/arrow.rs b/python/src/arrow.rs index 0a628e52f6f..97b70f0e61c 100644 --- a/python/src/arrow.rs +++ b/python/src/arrow.rs @@ -36,10 +36,10 @@ impl BFloat16 { #[classmethod] fn from_bytes(_cls: &Bound<'_, PyType>, bytes: &[u8]) -> PyResult { if bytes.len() != 2 { - PyValueError::new_err(format!( + return Err(PyValueError::new_err(format!( "BFloat16::from_bytes: expected 2 bytes, got {}", bytes.len() - )); + ))); } Ok(Self(bf16::from_bits(u16::from_ne_bytes([ bytes[0], bytes[1],