From b0076b2810dda0a31082705d32614cb275551c9b Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Tue, 9 Jun 2026 15:38:39 -0700 Subject: [PATCH] Add a OneOf lattice Elements of this lattice are elements of "one of" an arbitrary number of component lattices, or top or bottom. The elements are represented with a std::variant. Join and meet operations between elements of different component lattices produce top and bottom values, respectively. Join and meet between elements of the same component lattice return the result of the join or meet operation of that component lattice. Also add unit tests and lattice fuzzer support for the new lattice. TAG=agy --- src/analysis/lattices/one-of.h | 235 +++++++++++++++++++++++++++++++ src/tools/wasm-fuzz-lattices.cpp | 101 +++++++++++-- test/gtest/lattices.cpp | 123 +++++++++++++++- 3 files changed, 444 insertions(+), 15 deletions(-) create mode 100644 src/analysis/lattices/one-of.h diff --git a/src/analysis/lattices/one-of.h b/src/analysis/lattices/one-of.h new file mode 100644 index 00000000000..ac1ff8cd4dc --- /dev/null +++ b/src/analysis/lattices/one-of.h @@ -0,0 +1,235 @@ +/* + * Copyright 2026 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_analysis_lattices_one_of_h +#define wasm_analysis_lattices_one_of_h + +#include +#include +#include +#include + +#if __has_include() +#include +#endif + +#include "analysis/lattice.h" +#include "support/utilities.h" + +namespace wasm::analysis { + +// Elements of this lattice are elements of "one of" an arbitrary number of +// component lattices, or top or bottom. The elements are represented with a +// std::variant. Join and meet operations between elements of different +// component lattices produce top and bottom values, respectively. Join and meet +// between elements of the same component lattice return the result of the join +// or meet operation of that component lattice. +#if defined(__cpp_lib_concepts) +template +#else +template +#endif +struct OneOf { +private: + struct Bot : std::monostate {}; + struct Top : std::monostate {}; + + template using L = std::tuple_element_t>; + +public: + template using EI = typename L::Element; + + struct Element : std::variant { + using std::variant::variant; + bool isBottom() const noexcept { + return std::holds_alternative(*this); + } + bool isTop() const noexcept { return std::holds_alternative(*this); } + + template const U* getVal() const noexcept { + return std::get_if(this); + } + template U* getVal() noexcept { return std::get_if(this); } + template const EI* getVal() const noexcept { + return std::get_if(this); + } + template EI* getVal() noexcept { + return std::get_if(this); + } + + bool operator==(const Element& other) const noexcept { + return this->index() == other.index() && + std::visit( + [](const auto& a, const auto& b) { + if constexpr (std::is_same_v) { + return a == b; + } + return false; + }, + *this, + other); + } + bool operator!=(const Element& other) const noexcept { + return !(*this == other); + } + }; + + std::tuple lattices; + + OneOf(Ls&&... lattices) : lattices({std::move(lattices)...}) {} + OneOf() = default; + + Element getBottom() const noexcept { + return Element{std::in_place_type}; + } + Element getTop() const noexcept { return Element{std::in_place_type}; } + + template Element get(EI&& val) const noexcept { + return Element(std::in_place_index, std::move(val)); + } + + template Element get(const EI& val) const noexcept { + return Element(std::in_place_index, val); + } + +private: + static constexpr std::size_t BotIndex = sizeof...(Ls); + static constexpr std::size_t TopIndex = sizeof...(Ls) + 1; + + template + static constexpr auto makeCompareFuncs(std::index_sequence) noexcept { + using F = LatticeComparison (*)( + const std::tuple&, const Element&, const Element&); + return std::array{ + [](const std::tuple& lattices, + const Element& a, + const Element& b) -> LatticeComparison { + return std::get(lattices).compare(std::get(a), std::get(b)); + }...}; + } + static constexpr auto compareFuncs() noexcept { + return makeCompareFuncs(std::make_index_sequence{}); + } + + template + static constexpr auto makeJoinFuncs(std::index_sequence) noexcept { + using F = bool (*)(const std::tuple&, Element&, const Element&); + return std::array{[](const std::tuple& lattices, + Element& joinee, + const Element& joiner) -> bool { + return std::get(lattices).join(std::get(joinee), + std::get(joiner)); + }...}; + } + static constexpr auto joinFuncs() noexcept { + return makeJoinFuncs(std::make_index_sequence{}); + } + + template + static constexpr auto makeMeetFuncs(std::index_sequence) noexcept { + using F = bool (*)(const std::tuple&, Element&, const Element&); + return std::array{[](const std::tuple& lattices, + Element& meetee, + const Element& meeter) -> bool { + return std::get(lattices).meet(std::get(meetee), + std::get(meeter)); + }...}; + } + static constexpr auto meetFuncs() noexcept { + return makeMeetFuncs(std::make_index_sequence{}); + } + +public: + LatticeComparison compare(const Element& a, const Element& b) const noexcept { + if (a.index() == BotIndex && b.index() == BotIndex) { + return EQUAL; + } + if (a.index() == BotIndex) { + return LESS; + } + if (b.index() == BotIndex) { + return GREATER; + } + + if (a.index() == TopIndex && b.index() == TopIndex) { + return EQUAL; + } + if (a.index() == TopIndex) { + return GREATER; + } + if (b.index() == TopIndex) { + return LESS; + } + + if (a.index() != b.index()) { + return NO_RELATION; + } + + return compareFuncs()[a.index()](lattices, a, b); + } + + bool join(Element& joinee, const Element& joiner) const noexcept { + if (joiner.index() == BotIndex) { + return false; + } + if (joinee.index() == BotIndex) { + joinee = joiner; + return true; + } + if (joinee.index() == TopIndex) { + return false; + } + if (joiner.index() == TopIndex) { + joinee = Element(std::in_place_type); + return true; + } + if (joinee.index() != joiner.index()) { + joinee = Element(std::in_place_type); + return true; + } + return joinFuncs()[joinee.index()](lattices, joinee, joiner); + } + + bool meet(Element& meetee, const Element& meeter) const noexcept { + if (meeter.index() == TopIndex) { + return false; + } + if (meetee.index() == TopIndex) { + meetee = meeter; + return true; + } + if (meetee.index() == BotIndex) { + return false; + } + if (meeter.index() == BotIndex) { + meetee = Element(std::in_place_type); + return true; + } + if (meetee.index() != meeter.index()) { + meetee = Element(std::in_place_type); + return true; + } + return meetFuncs()[meetee.index()](lattices, meetee, meeter); + } +}; + +#if defined(__cpp_lib_concepts) +static_assert(FullLattice>); +#endif + +} // namespace wasm::analysis + +#endif // wasm_analysis_lattices_one_of_h diff --git a/src/tools/wasm-fuzz-lattices.cpp b/src/tools/wasm-fuzz-lattices.cpp index 4b29e8eb477..ecc23b151bb 100644 --- a/src/tools/wasm-fuzz-lattices.cpp +++ b/src/tools/wasm-fuzz-lattices.cpp @@ -27,6 +27,7 @@ #include "analysis/lattices/int.h" #include "analysis/lattices/inverted.h" #include "analysis/lattices/lift.h" +#include "analysis/lattices/one-of.h" #include "analysis/lattices/shared.h" #include "analysis/lattices/stack.h" #include "analysis/lattices/tuple.h" @@ -159,13 +160,17 @@ using ArrayLattice = analysis::Array; using TupleFullLattice = analysis::Tuple; using TupleLattice = analysis::Tuple; +using OneOfFullLattice = analysis::OneOf; +using OneOfLattice = analysis::OneOf; + using FullLatticeVariant = std::variant, ArrayFullLattice, Vector, - TupleFullLattice>; + TupleFullLattice, + OneOfFullLattice>; struct RandomFullLattice::LatticeImpl : FullLatticeVariant {}; @@ -176,7 +181,8 @@ using FullLatticeElementVariant = typename Inverted::Element, typename ArrayFullLattice::Element, typename Vector::Element, - typename TupleFullLattice::Element>; + typename TupleFullLattice::Element, + typename OneOfFullLattice::Element>; struct RandomFullLattice::ElementImpl : FullLatticeElementVariant {}; @@ -186,7 +192,8 @@ using LatticeVariant = std::variant, TupleLattice, - SharedPath>; + SharedPath, + OneOfLattice>; struct RandomLattice::LatticeImpl : LatticeVariant {}; @@ -197,11 +204,12 @@ using LatticeElementVariant = typename ArrayLattice::Element, typename Vector::Element, typename TupleLattice::Element, - typename SharedPath::Element>; + typename SharedPath::Element, + typename OneOfLattice::Element>; struct RandomLattice::ElementImpl : LatticeElementVariant {}; -constexpr int FullLatticePicks = 7; +constexpr int FullLatticePicks = 8; RandomFullLattice::RandomFullLattice(Random& rand, size_t depth, @@ -236,13 +244,18 @@ RandomFullLattice::RandomFullLattice(Random& rand, LatticeImpl{TupleFullLattice{RandomFullLattice{rand, depth + 1}, RandomFullLattice{rand, depth + 1}}}); return; + case 7: + lattice = std::make_unique( + LatticeImpl{OneOfFullLattice{RandomFullLattice{rand, depth + 1}, + RandomFullLattice{rand, depth + 1}}}); + return; } WASM_UNREACHABLE("unexpected pick"); } RandomLattice::RandomLattice(Random& rand, size_t depth) : rand(rand) { // TODO: Limit the depth once we get lattices with more fan-out. - uint32_t pick = rand.upTo(FullLatticePicks + 6); + uint32_t pick = rand.upTo(FullLatticePicks + 7); if (pick < FullLatticePicks) { lattice = std::make_unique( @@ -274,13 +287,17 @@ RandomLattice::RandomLattice(Random& rand, size_t depth) : rand(rand) { lattice = std::make_unique( LatticeImpl{SharedPath{RandomLattice{rand, depth + 1}}}); return; + case FullLatticePicks + 6: + lattice = std::make_unique(LatticeImpl{OneOfLattice{ + RandomLattice{rand, depth + 1}, RandomLattice{rand, depth + 1}}}); + return; } WASM_UNREACHABLE("unexpected pick"); } RandomFullLattice::Element RandomFullLattice::makeElement() const noexcept { if (std::get_if(lattice.get())) { - return ElementImpl{rand.pick(true, false)}; + return ElementImpl{Bool::Element{rand.pick(true, false)}}; } if (std::get_if(lattice.get())) { return ElementImpl{rand.upToSquared(33)}; @@ -337,6 +354,19 @@ RandomFullLattice::Element RandomFullLattice::makeElement() const noexcept { std::get<0>(l->lattices).makeElement(), std::get<1>(l->lattices).makeElement()}}; } + if (const auto* l = std::get_if(lattice.get())) { + auto pick = rand.upTo(4); + switch (pick) { + case 0: + return ElementImpl{l->getBottom()}; + case 1: + return ElementImpl{l->getTop()}; + case 2: + return ElementImpl{l->get<0>(std::get<0>(l->lattices).makeElement())}; + case 3: + return ElementImpl{l->get<1>(std::get<1>(l->lattices).makeElement())}; + } + } WASM_UNREACHABLE("unexpected lattice"); } @@ -381,6 +411,19 @@ RandomLattice::Element RandomLattice::makeElement() const noexcept { l->join(elem, l->lattice.makeElement()); return ElementImpl{elem}; } + if (const auto* l = std::get_if(lattice.get())) { + auto pick = rand.upTo(4); + switch (pick) { + case 0: + return ElementImpl{l->getBottom()}; + case 1: + return ElementImpl{l->getTop()}; + case 2: + return ElementImpl{l->get<0>(std::get<0>(l->lattices).makeElement())}; + case 3: + return ElementImpl{l->get<1>(std::get<1>(l->lattices).makeElement())}; + } + } WASM_UNREACHABLE("unexpected lattice"); } @@ -425,13 +468,24 @@ void printFullElement(std::ostream& os, indent(os, depth); os << "]\n"; } else if (const auto* e = - std::get_if(&*elem)) { - os << "Tuple(\n"; - const auto& [first, second] = *e; - printFullElement(os, first, depth + 1); - printFullElement(os, second, depth + 1); - indent(os, depth); - os << ")\n"; + std::get_if(&*elem)) { + if (e->isBottom()) { + os << "one-of bot\n"; + } else if (e->isTop()) { + os << "one-of top\n"; + } else if (const auto* val0 = e->getVal<0>()) { + os << "OneOf(0: \n"; + printFullElement(os, *val0, depth + 1); + indent(os, depth); + os << ")\n"; + } else if (const auto* val1 = e->getVal<1>()) { + os << "OneOf(1: \n"; + printFullElement(os, *val1, depth + 1); + indent(os, depth); + os << ")\n"; + } else { + WASM_UNREACHABLE("unexpected one-of element"); + } } else { WASM_UNREACHABLE("unexpected element"); } @@ -496,6 +550,25 @@ void printElement(std::ostream& os, printElement(os, **e, depth + 1); indent(os, depth); os << ")\n"; + } else if (const auto* e = + std::get_if(&*elem)) { + if (e->isBottom()) { + os << "one-of bot\n"; + } else if (e->isTop()) { + os << "one-of top\n"; + } else if (const auto* val0 = e->getVal<0>()) { + os << "OneOf(0: \n"; + printElement(os, *val0, depth + 1); + indent(os, depth); + os << ")\n"; + } else if (const auto* val1 = e->getVal<1>()) { + os << "OneOf(1: \n"; + printElement(os, *val1, depth + 1); + indent(os, depth); + os << ")\n"; + } else { + WASM_UNREACHABLE("unexpected one-of element"); + } } else { WASM_UNREACHABLE("unexpected element"); } diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index 6f3a88deee7..a6098f8de12 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -23,6 +23,7 @@ #include "analysis/lattices/int.h" #include "analysis/lattices/inverted.h" #include "analysis/lattices/lift.h" +#include "analysis/lattices/one-of.h" #include "analysis/lattices/shared.h" #include "analysis/lattices/stack.h" #include "analysis/lattices/tuple.h" @@ -1389,5 +1390,125 @@ TEST_F(ConeTypeLatticeTest, Depths) { CHECK_PAIR(b1, b1, b1, b1); CHECK_PAIR(b1, c0, b1, c0); - CHECK_PAIR(b1, c1, b2, c0); +} + +TEST(OneOfLattice, GetBottom) { + analysis::OneOf lattice; + EXPECT_TRUE(lattice.getBottom().isBottom()); + EXPECT_FALSE(lattice.getBottom().isTop()); +} + +TEST(OneOfLattice, GetTop) { + analysis::OneOf lattice; + EXPECT_TRUE(lattice.getTop().isTop()); + EXPECT_FALSE(lattice.getTop().isBottom()); +} + +TEST(OneOfLattice, Compare) { + analysis::OneOf lattice; + + auto bot = lattice.getBottom(); + auto top = lattice.getTop(); + + auto b_false = lattice.get<0>(false); + auto b_true = lattice.get<0>(true); + + auto i_0 = lattice.get<1>(0); + auto i_10 = lattice.get<1>(10); + + // Bot and Top relations + EXPECT_EQ(lattice.compare(bot, bot), analysis::EQUAL); + EXPECT_EQ(lattice.compare(bot, b_false), analysis::LESS); + EXPECT_EQ(lattice.compare(bot, i_0), analysis::LESS); + EXPECT_EQ(lattice.compare(bot, top), analysis::LESS); + + EXPECT_EQ(lattice.compare(top, top), analysis::EQUAL); + EXPECT_EQ(lattice.compare(top, b_true), analysis::GREATER); + EXPECT_EQ(lattice.compare(top, i_10), analysis::GREATER); + EXPECT_EQ(lattice.compare(top, bot), analysis::GREATER); + + // Same lattice relations + EXPECT_EQ(lattice.compare(b_false, b_true), analysis::LESS); + EXPECT_EQ(lattice.compare(b_true, b_false), analysis::GREATER); + EXPECT_EQ(lattice.compare(b_false, b_false), analysis::EQUAL); + + EXPECT_EQ(lattice.compare(i_0, i_10), analysis::LESS); + EXPECT_EQ(lattice.compare(i_10, i_0), analysis::GREATER); + EXPECT_EQ(lattice.compare(i_0, i_0), analysis::EQUAL); + + // Different lattice relations + EXPECT_EQ(lattice.compare(b_false, i_0), analysis::NO_RELATION); + EXPECT_EQ(lattice.compare(i_10, b_true), analysis::NO_RELATION); +} + +TEST(OneOfLattice, Join) { + analysis::OneOf lattice; + + auto bot = lattice.getBottom(); + auto top = lattice.getTop(); + + auto b_false = lattice.get<0>(false); + auto b_true = lattice.get<0>(true); + + auto i_0 = lattice.get<1>(0); + auto i_10 = lattice.get<1>(10); + + auto test = + [&](const auto& joinee, const auto& joiner, const auto& expected) { + auto copy = joinee; + EXPECT_EQ(lattice.join(copy, joiner), joinee != expected); + EXPECT_EQ(copy, expected); + }; + + // Bot and Top joins + test(bot, bot, bot); + test(bot, b_false, b_false); + test(b_false, bot, b_false); + test(top, b_false, top); + test(b_false, top, top); + + // Same lattice joins + test(b_false, b_true, b_true); + test(b_true, b_false, b_true); + test(i_0, i_10, i_10); + + // Different lattice joins + test(b_false, i_0, top); + test(i_10, b_true, top); +} + +TEST(OneOfLattice, Meet) { + analysis::OneOf lattice; + + auto bot = lattice.getBottom(); + auto top = lattice.getTop(); + + auto b_false = lattice.get<0>(false); + auto b_true = lattice.get<0>(true); + + auto i_0 = lattice.get<1>(0); + auto i_10 = lattice.get<1>(10); + + auto test = + [&](const auto& meetee, const auto& meeter, const auto& expected) { + auto copy = meetee; + EXPECT_EQ(lattice.meet(copy, meeter), meetee != expected); + EXPECT_EQ(copy, expected); + }; + + // Bot and Top meets + test(bot, bot, bot); + test(bot, b_false, bot); + test(b_false, bot, bot); + test(top, b_false, b_false); + test(b_false, top, b_false); + + // Same lattice meets + test(b_false, b_true, b_false); + test(b_true, b_false, b_false); + test(i_0, i_10, i_0); + + // Different lattice meets + test(b_false, i_0, bot); + test(i_10, b_true, bot); }