Skip to content

Commit

Permalink
add support for indirect calls that return an argument
Browse files Browse the repository at this point in the history
  • Loading branch information
nunoplopes committed Sep 11, 2023
1 parent bebdcfa commit 85c1061
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 39 deletions.
26 changes: 26 additions & 0 deletions ir/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ void Function::syncDataWithSrc(Function &src) {
};
copy_fns(src, *this);
copy_fns(*this, src);

for (auto &decl : fn_decls) {
src.addFnDecl(FnDecl(decl));
}
}

Function::instr_iterator::
Expand Down Expand Up @@ -313,6 +317,13 @@ static void add_users(Function::UsersTy &users, Value *i, BasicBlock *bb,
users[val].emplace(i, bb);
}

void Function::addFnDecl(FnDecl &&decl) {
if (find_if(fn_decls.begin(), fn_decls.end(),
[&](auto &d) { return d.name == decl.name; }) != fn_decls.end())
return;
fn_decls.emplace_back(std::move(decl));
}

Function::UsersTy Function::getUsers() const {
UsersTy users;
for (auto *bb : getBBs()) {
Expand Down Expand Up @@ -798,6 +809,21 @@ void Function::unroll(unsigned k) {
}

void Function::print(ostream &os, bool print_header) const {
if (!fn_decls.empty()) {
for (auto &decl : fn_decls) {
os << "declare " << *decl.output << ' ' << decl.name << '(';
bool first = true;
for (auto &input : decl.inputs) {
if (!first)
os << ", ";
os << input.second << *input.first;
first = false;
}
os << ")\n";
}
os << '\n';
}

{
const auto &gvars = getGlobalVars();
if (!gvars.empty()) {
Expand Down
14 changes: 14 additions & 0 deletions ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ class Function final {

FnAttrs attrs;

// TODO: Move this to a 'program' class
public:
struct FnDecl {
std::string name;
std::vector<std::pair<Type*, ParamAttrs>> inputs;
Type *output;
FnAttrs attrs;
};
private:
std::vector<FnDecl> fn_decls;

public:
Function() = default;
Function(Type &type, std::string &&name, unsigned bits_pointers = 64,
Expand Down Expand Up @@ -186,6 +197,9 @@ class Function final {
instr_helper instrs() { return *this; }
instr_helper instrs() const { return *this; }

void addFnDecl(FnDecl &&decl);
auto& getFnDecls() const { return fn_decls; }

using UsersTy = std::unordered_map<const Value*,
std::set<std::pair<Value*, BasicBlock*>>>;
UsersTy getUsers() const;
Expand Down
16 changes: 14 additions & 2 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2390,6 +2390,9 @@ StateValue FnCall::toSMT(State &s) const {
fnName_mangled << fnName;
}

optional<StateValue> ret_val;
vector<StateValue> ret_vals;

for (auto &[arg, flags] : args) {
// we duplicate each argument so that undef values are allowed to take
// different values so we can catch the bug in f(freeze(undef)) -> f(undef)
Expand All @@ -2402,6 +2405,14 @@ StateValue FnCall::toSMT(State &s) const {
sv2 = s.eval(*arg, true);
}

if (fnptr)
ret_vals.emplace_back(sv);

if (flags.has(ParamAttrs::Returned)) {
assert(!ret_val);
ret_val = sv;
}

unpack_inputs(s, *arg, arg->getType(), flags, std::move(sv), std::move(sv2),
inputs, ptr_inputs);
fnName_mangled << '#' << arg->getType();
Expand Down Expand Up @@ -2484,8 +2495,9 @@ StateValue FnCall::toSMT(State &s) const {
}

unsigned idx = 0;
auto ret = s.addFnCall(fnName_mangled.str(), std::move(inputs),
std::move(ptr_inputs), out_types, attrs);
auto ret = s.addFnCall(std::move(fnName_mangled).str(), std::move(inputs),
std::move(ptr_inputs), out_types, std::move(ret_val),
std::move(ret_vals), attrs);

return isVoid() ? StateValue()
: pack_return(s, getType(), ret, attrs, idx, args);
Expand Down
69 changes: 60 additions & 9 deletions ir/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,14 @@ expr State::FnCallInput::refinedBy(
return refines();
}

State::FnCallOutput
State::FnCallOutput::replace(const optional<StateValue> &retval) const {
FnCallOutput copy = *this;
assert(retvals.size() == 1);
copy.retvals[0] = *retval;
return copy;
}

State::FnCallOutput State::FnCallOutput::mkIf(const expr &cond,
const FnCallOutput &a,
const FnCallOutput &b) {
Expand All @@ -882,6 +890,7 @@ State::FnCallOutput State::FnCallOutput::mkIf(const expr &cond,
ret.noreturns = expr::mkIf(cond, a.noreturns, b.noreturns);
ret.callstate = Memory::CallState::mkIf(cond, a.callstate, b.callstate);
assert(a.retvals.size() == b.retvals.size());
assert(a.ret_data.size() == b.ret_data.size());
for (unsigned i = 0, e = a.retvals.size(); i != e; ++i) {
ret.retvals.emplace_back(
StateValue::mkIf(cond, a.retvals[i], b.retvals[i]));
Expand All @@ -904,11 +913,19 @@ expr State::FnCallOutput::operator==(const FnCallOutput &rhs) const {
vector<StateValue>
State::addFnCall(const string &name, vector<StateValue> &&inputs,
vector<Memory::PtrInput> &&ptr_inputs,
const vector<Type*> &out_types, const FnAttrs &attrs) {
bool noret = attrs.has(FnAttrs::NoReturn);
const vector<Type*> &out_types, optional<StateValue> &&ret_arg,
vector<StateValue> &&ret_args, const FnAttrs &attrs) {
bool noret = attrs.has(FnAttrs::NoReturn);
bool willret = attrs.has(FnAttrs::WillReturn);
bool noundef = attrs.has(FnAttrs::NoUndef);
bool noalias = attrs.has(FnAttrs::NoAlias);
bool is_indirect = name.starts_with("#indirect_call");

expr fn_ptr;
if (is_indirect) {
assert(inputs.size() >= 1);
fn_ptr = inputs[0].value;
}

assert(!noret || !willret);

Expand Down Expand Up @@ -972,12 +989,46 @@ State::addFnCall(const string &name, vector<StateValue> &&inputs,
vector<Memory::FnRetData> ret_data;
string valname = name + "#val";
string npname = name + "#np";
for (auto t : out_types) {
auto [val, data] = mk_val(*t, valname);
values.emplace_back(
std::move(val),
noundef ? expr(true) : expr::mkFreshVar(npname.c_str(), false));
ret_data.emplace_back(std::move(data));
if (ret_arg) {
assert(out_types.size() == 1);
values.emplace_back(std::move(ret_arg).value());
ret_data.emplace_back();
} else {
for (auto t : out_types) {
auto [val, data] = mk_val(*t, valname);
values.emplace_back(
std::move(val),
noundef ? expr(true) : expr::mkFreshVar(npname.c_str(), false));
ret_data.emplace_back(std::move(data));
}
}

// Indirect calls may be changed into direct in tgt
// Account for this if we have declarations with a returned argument
// to limit the behavior of the SMT var.
if (is_indirect) {
for (auto &decl : getFn().getFnDecls()) {
if (decl.inputs.size() != ret_args.size())
continue;
unsigned idx = 0;
for (auto &[ty, attrs] : decl.inputs) {
if (attrs.has(ParamAttrs::Returned)) {
auto &ret = ret_args[idx];
if (ret.value.isSameTypeOf(values[0].value)) {
assert(values.size() == 1);
auto gv = getFn().getConstant(string_view(decl.name).substr(1));
assert(gv);
auto cmp = Pointer(memory, fn_ptr).getAddress() ==
Pointer(memory, (*this)[*gv].value).getAddress();
values[0].value = expr::mkIf(cmp, ret.value, values[0].value);
values[0].non_poison
= expr::mkIf(cmp, ret.non_poison, values[0].non_poison);
}
break;
}
++idx;
}
}
}

I->second
Expand Down Expand Up @@ -1016,7 +1067,7 @@ State::addFnCall(const string &name, vector<StateValue> &&inputs,
auto refined = in.refinedBy(*this, name, inaccessible_bid, inputs,
ptr_inputs, call_ranges, memory, attrs.mem,
noret, willret);
data.add(out, std::move(refined));
data.add(ret_arg ? out.replace(ret_arg) : out, std::move(refined));
}

if (data) {
Expand Down
6 changes: 5 additions & 1 deletion ir/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ class State {
Memory::CallState callstate;
std::vector<Memory::FnRetData> ret_data;

FnCallOutput replace(const std::optional<StateValue> &retval) const;

static FnCallOutput mkIf(const smt::expr &cond, const FnCallOutput &then,
const FnCallOutput &els);
smt::expr operator==(const FnCallOutput &rhs) const;
Expand Down Expand Up @@ -249,7 +251,9 @@ class State {
std::vector<StateValue>
addFnCall(const std::string &name, std::vector<StateValue> &&inputs,
std::vector<Memory::PtrInput> &&ptr_inputs,
const std::vector<Type*> &out_types, const FnAttrs &attrs);
const std::vector<Type*> &out_types,
std::optional<StateValue> &&ret_arg,
std::vector<StateValue> &&ret_args, const FnAttrs &attrs);

auto& getVarArgsData() { return var_args_data.data; }

Expand Down
61 changes: 34 additions & 27 deletions llvm_util/llvm2alive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ string_view s(llvm::StringRef str) {

class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
BasicBlock *BB;
Function *alive_fn;
llvm::Function &f;
const llvm::TargetLibraryInfo &TLI;
/// True if converting a source function, false when converting a target
Expand Down Expand Up @@ -306,16 +307,37 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
args.emplace_back(a);
}

auto fn = i.getCalledFunction();
FnAttrs attrs;
vector<ParamAttrs> param_attrs;

parse_fn_attrs(i, attrs);
parse_fn_attrs(i, attrs, true);

if (auto op = dyn_cast<llvm::FPMathOperator>(&i)) {
if (op->hasNoNaNs())
attrs.set(FnAttrs::NNaN);
}

// record fn decl in case there are indirect calls to this function
// elsewhere
if (fn) {
Function::FnDecl decl;
decl.name = '@' + fn->getName().str();
decl.output = llvm_type2alive(fn->getReturnType());
decl.attrs = attrs;

auto attrs_fndef = fn->getAttributes();
for (uint64_t idx = 0, nargs = fn->arg_size(); idx < nargs; ++idx) {
unsigned attr_argidx = llvm::AttributeList::FirstArgIndex + idx;
ParamAttrs pattr;
handleParamAttrs(attrs_fndef.getAttributes(attr_argidx), pattr, true);
decl.inputs.emplace_back(&args[idx]->getType(), std::move(pattr));
}
alive_fn->addFnDecl(std::move(decl));
}

parse_fn_attrs(i, attrs);

if (!approx) {
auto [known, approx0]
= known_call(i, TLI, *BB, args, std::move(attrs), param_attrs);
Expand All @@ -329,7 +351,6 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
return error(i);

unique_ptr<FnCall> call;
auto fn = i.getCalledFunction();
Value *fnptr = nullptr;

if (auto *iasm = dyn_cast<llvm::InlineAsm>(i.getCalledOperand())) {
Expand All @@ -343,17 +364,16 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
if (!fn) {
if (!(fnptr = get_operand(i.getCalledOperand())))
return error(i);
} else if (fn->getName().substr(0, 15) == "__llvm_profile_")
} else if (fn->getName().starts_with("__llvm_profile_"))
return NOP(i);

call = make_unique<FnCall>(*ty, value_name(i),
fn ? '@' + fn->getName().str() : string(),
std::move(attrs), fnptr);
}

llvm::AttributeList attrs_callsite = i.getAttributes();
llvm::AttributeList attrs_fndef = fn ? fn->getAttributes()
: llvm::AttributeList();
auto attrs_callsite = i.getAttributes();
auto attrs_fndef = fn ? fn->getAttributes() : llvm::AttributeList();

unique_ptr<Instr> ret_val;

Expand All @@ -375,28 +395,10 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
// padding
return errorAttr(i.getAttributeAtIndex(argidx, llvm::Attribute::NoUndef));
}

if (i.paramHasAttr(argidx, llvm::Attribute::Returned)) {
// fn may have different type than argument. LLVM assumes there's
// an implicit bitcast
assert(!ret_val);
if (i.getArgOperand(argidx)->getType() == i.getType())
ret_val = make_unique<UnaryOp>(*ty, value_name(i), *arg,
UnaryOp::Copy);
else
ret_val = make_unique<ConversionOp>(*ty, value_name(i), *arg,
ConversionOp::BitCast);
}

call->addArg(*arg, std::move(pattr));
}

call->setApproximated(approx);

if (ret_val) {
BB->addInstr(std::move(call));
RETURN_IDENTIFIER(std::move(ret_val));
}
RETURN_IDENTIFIER(std::move(call));
}

Expand Down Expand Up @@ -1529,17 +1531,20 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
return attrs;
}

void parse_fn_attrs(const llvm::CallInst &i, FnAttrs &attrs) {
void parse_fn_attrs(const llvm::CallInst &i, FnAttrs &attrs,
bool decl_only = false) {
auto fn = i.getCalledFunction();
llvm::AttributeList attrs_callsite = i.getAttributes();
llvm::AttributeList attrs_fndef = fn ? fn->getAttributes()
: llvm::AttributeList();
auto ret = llvm::AttributeList::ReturnIndex;
auto fnidx = llvm::AttributeList::FunctionIndex;

handleRetAttrs(attrs_callsite.getAttributes(ret), attrs);
if (!decl_only) {
handleRetAttrs(attrs_callsite.getAttributes(ret), attrs);
handleFnAttrs(attrs_callsite.getAttributes(fnidx), attrs);
}
handleRetAttrs(attrs_fndef.getAttributes(ret), attrs);
handleFnAttrs(attrs_callsite.getAttributes(fnidx), attrs);
handleFnAttrs(attrs_fndef.getAttributes(fnidx), attrs);
attrs.mem = handleMemAttrs(i.getMemoryEffects());
if (fn)
Expand Down Expand Up @@ -1570,6 +1575,8 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
f.isVarArg());
reset_state(Fn);

alive_fn = &Fn;

auto &attrs = Fn.getFnAttrs();
vector<ParamAttrs> param_attrs;
llvm::AttributeList attrlist = f.getAttributes();
Expand Down
5 changes: 5 additions & 0 deletions smt/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,11 @@ bool expr::isInt(int64_t &n) const {
return true;
}

bool expr::isSameTypeOf(const expr &other) const {
C(other);
return sort() == other.sort();
}

bool expr::isEq(expr &lhs, expr &rhs) const {
return isBinOp(lhs, rhs, Z3_OP_EQ);
}
Expand Down
1 change: 1 addition & 0 deletions smt/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class expr {
unsigned bits() const;
bool isUInt(uint64_t &n) const;
bool isInt(int64_t &n) const;
bool isSameTypeOf(const expr &other) const;

bool isEq(expr &lhs, expr &rhs) const;
bool isSLE(expr &lhs, expr &rhs) const;
Expand Down
Loading

0 comments on commit 85c1061

Please sign in to comment.