Skip to content

Commit

Permalink
Remove useless const
Browse files Browse the repository at this point in the history
  • Loading branch information
seven332 committed Sep 27, 2024
1 parent f8c5700 commit b3e0977
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 31 deletions.
20 changes: 20 additions & 0 deletions minifier/src/minifier_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,24 @@ fn average(a: f32, b: f32) -> f32 {
EXPECT_THAT(result.remappings, testing::UnorderedElementsAre(testing::Pair("vs1", "a")));
}

TEST(minifier, RemoveUselessConst) {
auto result = Minify(
R"(
const i = 2;
const j = 2;
@vertex fn vs1() -> @builtin(position) vec4f {
return vec4f(1) / i;
}
)",
{}
);
EXPECT_FALSE(result.failed);
EXPECT_EQ(
Write(result.program),
"const a = 2;\n\n@vertex\nfn b() -> @builtin(position) vec4f {\n return (vec4f(1) / a);\n}\n"
);
EXPECT_THAT(result.remappings, testing::UnorderedElementsAre(testing::Pair("vs1", "b")));
}

} // namespace wgslx::minifier
144 changes: 113 additions & 31 deletions minifier/src/remove_useless.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "remove_useless.h"

#include <src/tint/lang/wgsl/ast/const.h>
#include <src/tint/lang/wgsl/ast/function.h>
#include <src/tint/lang/wgsl/ast/identifier.h>
#include <src/tint/lang/wgsl/ast/traverse_expressions.h>
Expand All @@ -8,13 +9,13 @@
#include <src/tint/lang/wgsl/resolver/resolve.h>
#include <src/tint/utils/symbol/symbol.h>

#include <iostream>
#include <range/v3/range/conversion.hpp>
#include <range/v3/view/filter.hpp>
#include <range/v3/view/transform.hpp>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>

#include "traverser.h"
Expand All @@ -23,53 +24,134 @@ TINT_INSTANTIATE_TYPEINFO(wgslx::minifier::RemoveUseless);

namespace wgslx::minifier {

struct FuncRef {
std::unordered_set<tint::Symbol> funcs;
const tint::ast::Function* func;
template<class... Ts>
struct Match : Ts... {
using Ts::operator()...;
};

template<class... Ts>
Match(Ts...) -> Match<Ts...>;

template<typename... Ts, typename... Fs>
constexpr decltype(auto) operator|(const std::variant<Ts...>& v, const Match<Fs...>& match) {
return std::visit(match, v);
}

class Entity {
public:
explicit Entity(const tint::ast::Function* ptr) : ptr_(ptr) {}
explicit Entity(const tint::ast::Const* ptr) : ptr_(ptr) {}

[[nodiscard]] tint::Symbol Symbol() const {
return ptr_ | Match {
[](const tint::ast::Function* f) { return f->name->symbol; },
[](const tint::ast::Const* c) { return c->name->symbol; },
};
}

[[nodiscard]] const tint::ast::Node* Ptr() const {
return ptr_ | Match {
[](const tint::ast::Function* f) -> const tint::ast::Node* { return f; },
[](const tint::ast::Const* c) -> const tint::ast::Node* { return c; },
};
}

[[nodiscard]] bool IsEntryPoint() const {
return ptr_ | Match {
[](const tint::ast::Function* f) { return f->IsEntryPoint(); },
[](const tint::ast::Const* /* c */) { return false; },
};
}

private:
std::variant<const tint::ast::Function*, const tint::ast::Const*> ptr_;
};

struct Element {
Entity self;
std::unordered_set<tint::Symbol> refs;
bool visited = false;
};

using FuncRefs = std::unordered_map<tint::Symbol, FuncRef>;
using Elements = std::unordered_map<tint::Symbol, Element>;

static Elements CollectElements(const tint::Program& program) {
auto elements = program.AST().GlobalDeclarations() | ranges::views::filter([](const tint::ast::Node* node) {
return node->Is<tint::ast::Function>() || node->Is<tint::ast::Const>();
}) |
ranges::views::transform([](const tint::ast::Node* node) {
if (node->Is<tint::ast::Function>()) {
const auto* function = node->As<tint::ast::Function>();
return std::make_pair(function->name->symbol, Element {.self = Entity(function)});
} else {
const auto* c = node->As<tint::ast::Const>();
return std::make_pair(c->name->symbol, Element {.self = Entity(c)});
}
}) |
ranges::to<Elements>();

static FuncRefs CollectFuncRefs(const tint::Program& program) {
auto refs = program.AST().Functions() | ranges::views::transform([](const tint::ast::Function* function) {
return std::make_pair(function->name->symbol, FuncRef {.func = function});
}) |
ranges::to<FuncRefs>();
for (const auto* node : program.AST().GlobalDeclarations()) {
tint::Symbol symbol;
const tint::ast::Statement* statement = nullptr;
const tint::ast::Expression* expression = nullptr;
if (node->Is<tint::ast::Function>()) {
const auto* function = node->As<tint::ast::Function>();
symbol = function->name->symbol;
statement = function->body;
} else if (node->Is<tint::ast::Const>()) {
const auto* c = node->As<tint::ast::Const>();
symbol = c->name->symbol;
expression = c->initializer;
} else {
continue;
}

auto iter = elements.find(symbol);
if (iter == elements.end()) {
continue;
}

for (const auto* func : program.AST().Functions()) {
auto& funcs = refs[func->name->symbol].funcs;
Traverse(func->body, [&](const tint::ast::Identifier* i) {
if (refs.contains(i->symbol)) {
funcs.emplace(i->symbol);
auto& refs = iter->second.refs;
Traverse(statement, [&](const tint::ast::Identifier* i) {
if (elements.contains(i->symbol)) {
refs.emplace(i->symbol);
}
});
Traverse(expression, [&](const tint::ast::Identifier* i) {
if (elements.contains(i->symbol)) {
refs.emplace(i->symbol);
}
});
}

return refs;
return elements;
}

static void MarkVisited(FuncRefs& refs, const tint::Symbol& symbol) {
auto& ref = refs[symbol];
if (ref.visited) {
static void MarkVisited(Elements& elements, const tint::Symbol& symbol) {
auto iter = elements.find(symbol);
if (iter == elements.end()) {
return;
}

if (iter->second.visited) {
return;
}
ref.visited = true;
iter->second.visited = true;

for (const auto& s : ref.funcs) {
MarkVisited(refs, s);
for (const auto& s : iter->second.refs) {
MarkVisited(elements, s);
}
}

static std::vector<const tint::ast::Function*> FindUselessFunctions(const tint::Program& program) {
auto refs = CollectFuncRefs(program);
for (const auto& [symbol, ref] : refs) {
if (ref.func->IsEntryPoint()) {
MarkVisited(refs, symbol);
static std::vector<const tint::ast::Node*> FindUseless(const tint::Program& program) {
auto elements = CollectElements(program);
for (const auto& [symbol, element] : elements) {
if (element.self.IsEntryPoint()) {
MarkVisited(elements, symbol);
}
}
return refs | ranges::views::filter([](const auto& p) { return !p.second.visited; }) |
ranges::views::transform([](const auto& p) { return p.second.func; }) | ranges::to<std::vector>();
return elements | ranges::views::filter([](const auto& p) { return !p.second.visited; }) |
ranges::views::transform([](const auto& p) { return p.second.self.Ptr(); }) | ranges::to<std::vector>();
}

RemoveUseless::ApplyResult RemoveUseless::Apply(
Expand All @@ -79,8 +161,8 @@ RemoveUseless::ApplyResult RemoveUseless::Apply(
) const {
tint::ProgramBuilder builder;
tint::program::CloneContext ctx(&builder, &program, true);
for (const auto* func : FindUselessFunctions(*ctx.src)) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), func);
for (const auto* node : FindUseless(*ctx.src)) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), node);
}
ctx.Clone();
return tint::resolver::Resolve(builder);
Expand Down

0 comments on commit b3e0977

Please sign in to comment.