Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,9 @@ cc_library(
":type_manager",
":value",
":value_flattening",
":xls_ir_interface_cc_proto",
":xls_type_cc_proto",
"//xls/common/file:filesystem",
"//xls/common/fuzzing:fuzztest",
"@com_google_absl//absl/log:check",
"@googletest//:gtest",
Expand Down
18 changes: 18 additions & 0 deletions xls/ir/value_test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,21 @@
#include "xls/ir/value_test_util.h"

#include <cstdint>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "xls/common/fuzzing/fuzztest.h"
#include "absl/log/check.h"
#include "xls/common/file/filesystem.h"
#include "xls/ir/bits.h"
#include "xls/ir/bits_test_utils.h"
#include "xls/ir/fuzz_type_domain.h"
#include "xls/ir/type.h"
#include "xls/ir/type_manager.h"
#include "xls/ir/value.h"
#include "xls/ir/value_flattening.h"
#include "xls/ir/xls_ir_interface.pb.h"
#include "xls/ir/xls_type.pb.h"

namespace xls {
Expand Down Expand Up @@ -75,4 +79,18 @@ fuzztest::Domain<Value> ArbitraryValue(TypeProto type) {
return ArbitraryValue(fuzztest::Just(type));
}

fuzztest::Domain<Value> ElementOfDomain(std::string_view values_text_proto) {
xls::PackageInterfaceProto::FuzzTestDomain::ElementOf proto;
CHECK_OK(xls::ParseTextProto(values_text_proto, /*file_name=*/"", &proto));
std::vector<Value> values;
values.reserve(proto.values_size());
for (const auto& value_proto : proto.values()) {
auto value_or = Value::FromProto(value_proto);
CHECK_OK(value_or.status()) << "Failed to parse Value from proto: "
<< value_proto.ShortDebugString();
values.push_back(std::move(value_or.value()));
}
return fuzztest::ElementOf(values);
}

} // namespace xls
4 changes: 4 additions & 0 deletions xls/ir/value_test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define XLS_IR_VALUE_TEST_UTIL_H_

#include <cstdint>
#include <string_view>

#include "gtest/gtest.h"
#include "xls/common/fuzzing/fuzztest.h"
Expand Down Expand Up @@ -43,6 +44,9 @@ fuzztest::Domain<Value> ArbitraryValue(fuzztest::Domain<TypeProto> type);
// Create a domain for an arbitrary value which is of the given type.
fuzztest::Domain<Value> ArbitraryValue(TypeProto type);

// Create an element_of domain from a serialized ElementOf proto (text format).
fuzztest::Domain<Value> ElementOfDomain(std::string_view values_text_proto);

} // namespace xls

#endif // XLS_IR_VALUE_TEST_UTIL_H_
9 changes: 9 additions & 0 deletions xls/ir/value_test_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,14 @@ TEST(ValueTestUtilTest, ValuesEqual) {
EXPECT_FALSE(ValuesEqual(Value(UBits(1, 1234)), Value(UBits(1, 10))));
}

void ElementOfDomainTestHelper(const Value& value) {
EXPECT_TRUE(value == Value(UBits(1, 32)) || value == Value(UBits(2, 32)));
}
FUZZ_TEST(ValueTestUtilTest, ElementOfDomainTestHelper)
.WithDomains(ElementOfDomain(R"pb(
values { bits { bit_count: 32 data: "\x01" } }
values { bits { bit_count: 32 data: "\x02" } }
)pb"));

} // namespace
} // namespace xls
3 changes: 2 additions & 1 deletion xls/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ pytype_strict_library(
"//xls/ir:xls_ir_interface_py_pb2",
"//xls/ir:xls_type_py_pb2",
"@abseil-py//absl:app",
"@com_google_protobuf//:protobuf_python",
"@xls_pip_deps//jinja2",
],
)
Expand Down Expand Up @@ -413,8 +414,8 @@ pytype_strict_contrib_test(
"//xls/common:runfiles",
"//xls/ir:xls_ir_interface_py_pb2",
"//xls/ir:xls_type_py_pb2",
"@abseil-py//absl:app",
"@abseil-py//absl/testing:absltest",
"@com_google_protobuf//:protobuf_python",
"@xls_pip_deps//jinja2",
],
)
Expand Down
32 changes: 16 additions & 16 deletions xls/jit/jit_wrapper_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Optional

from absl import app
from google.protobuf import text_format
import jinja2

from xls.ir import xls_ir_interface_pb2 as ir_interface_pb2
Expand Down Expand Up @@ -368,35 +369,34 @@ def to_domain(
cpp_type = to_specialized(t, int_only=True)
if cpp_type is None:
if not can_use_uint64_range(t, d):
raise app.UsageError(
"Range domain is only supported for specializable bits types or"
" ranges fitting in 64 bits"
)
return None
cpp_type = "uint64_t"
min_val = extract_int_from_bytes(d.range.min.bits.data)
max_val = extract_int_from_bytes(d.range.max.bits.data)
return f"fuzztest::InRange<{cpp_type}>({min_val}, {max_val})"

if d.HasField("element_of"):
c_type = to_specialized(t, int_only=True)
if c_type is None:
raise app.UsageError(
"ElementOf domain only supported for specializable bits types in"
" this CL"
)
vals = [
str(extract_int_from_bytes(v.bits.data)) for v in d.element_of.values
]
return f"fuzztest::ElementOf(std::vector<{c_type}>{{{', '.join(vals)}}})"
if c_type is not None:
vals = [
str(extract_int_from_bytes(v.bits.data)) for v in d.element_of.values
]
return f"fuzztest::ElementOf(std::vector<{c_type}>{{{', '.join(vals)}}})"
else:
proto_str = text_format.MessageToString(d.element_of)
return f'xls::ElementOfDomain(R"pb({proto_str})pb")'

if d.HasField("tuple"):
if t.type_enum != type_pb2.TypeProto.TUPLE:
raise app.UsageError("Tuple domain requires Tuple type")
if len(d.tuple.elements) != len(t.tuple_elements):
raise app.UsageError("Tuple domain and type element count mismatch")
elems = [
to_domain(te, de) for te, de in zip(t.tuple_elements, d.tuple.elements)
]
elems = []
for te, de in zip(t.tuple_elements, d.tuple.elements):
elem_d = to_domain(te, de)
if elem_d is None:
return None
elems.append(elem_d)
return f"fuzztest::TupleOf({', '.join(elems)})"

raise app.UsageError(f"Unsupported domain: {d}")
Expand Down
51 changes: 44 additions & 7 deletions xls/jit/jit_wrapper_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl import app
from google.protobuf import text_format
import jinja2

from absl.testing import absltest
Expand Down Expand Up @@ -920,7 +920,7 @@ def test_tuple_with_array_domain(self):
' fuzztest::ArrayOf<3>(fuzztest::Arbitrary<uint32_t>()))',
)

def test_unsupported_domain_raises(self):
def test_unsupported_range_domain_returns_none(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
tup = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32]
Expand All @@ -930,12 +930,49 @@ def test_unsupported_domain_raises(self):
d.range.min.bits.data = b'\x00'
d.range.max.bits.bit_count = 32
d.range.max.bits.data = b'\x0a'
self.assertIsNone(jit_wrapper_generator.to_domain(tup, d))

with self.assertRaisesRegex(
app.UsageError,
'Range domain is only supported for specializable bits types',
):
jit_wrapper_generator.to_domain(tup, d)
def test_element_of_domain_non_specializable(self):
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
v1 = d.element_of.values.add()
v1.bits.bit_count = 128
v1.bits.data = b'\x01'
v2 = d.element_of.values.add()
v2.bits.bit_count = 128
v2.bits.data = b'\x02'
expected_proto_str = text_format.MessageToString(d.element_of)
self.assertEqual(
jit_wrapper_generator.to_domain(u128, d),
f'xls::ElementOfDomain(R"pb({expected_proto_str})pb")',
)

def test_range_domain_wide_bits_does_not_fit(self):
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
d.range.min.bits.bit_count = 128
d.range.min.bits.data = b'\x01'
d.range.max.bits.bit_count = 128
d.range.max.bits.data = b'\x00\x00\x00\x00\x00\x00\x00\x00\x01'
self.assertIsNone(jit_wrapper_generator.to_domain(u128, d))

def test_tuple_domain_with_unsupported_child_fallback(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
tup = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, u128]
)
d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
d.tuple.elements.add().range.min.bits.bit_count = 32
d.tuple.elements[0].range.min.bits.data = b'\x00'
d.tuple.elements[0].range.max.bits.bit_count = 32
d.tuple.elements[0].range.max.bits.data = b'\x0a'
d_child2 = d.tuple.elements.add()
d_child2.range.min.bits.bit_count = 128
d_child2.range.min.bits.data = b'\x01'
d_child2.range.max.bits.bit_count = 128
d_child2.range.max.bits.data = b'\x00\x00\x00\x00\x00\x00\x00\x00\x01'
self.assertIsNone(jit_wrapper_generator.to_domain(tup, d))


class JitWrapperGeneratorToParamTest(absltest.TestCase):
Expand Down
Loading