Skip to content

Commit

Permalink
Adds ValueKind enum to enable dynamic dispatch for Python bindings fo…
Browse files Browse the repository at this point in the history
…r Value/Result/Argument.
  • Loading branch information
Peter Goodman committed Dec 11, 2023
1 parent f9ccf25 commit a316fd7
Show file tree
Hide file tree
Showing 11 changed files with 300 additions and 3 deletions.
1 change: 1 addition & 0 deletions bindings/Python/Forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,7 @@ class RegexQuery;
namespace ir {
enum class AttributeKind : uint32_t;
class Attribute;
enum ValueKind : uint8_t;
class Value;
class Block;
class Argument;
Expand Down
31 changes: 30 additions & 1 deletion bindings/Python/Generated/IR/Argument.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,19 @@ std::optional<T> PythonBinding<T>::from_python(BorrowedPyObject *obj) noexcept {

template <>
SharedPyObject *PythonBinding<T>::to_python(T val) noexcept {
auto ret = gType->tp_alloc(gType, 0);
PyTypeObject *tp = nullptr;
switch (val.kind()) {
default:
assert(false);
tp = gType;
break;

case mx::ir::Argument::static_kind():
tp = &(gTypes[927]);
break;

}
auto ret = tp->tp_alloc(tp, 0);
if (auto obj = O_cast(ret)) {
obj->data = new (obj->backing_storage) T(std::move(val));
}
Expand Down Expand Up @@ -128,6 +140,23 @@ static PyGetSetDef gProperties[] = {

namespace {
static PyMethodDef gMethods[] = {
{
"static_kind",
reinterpret_cast<PyCFunction>(
+[] (BorrowedPyObject *, BorrowedPyObject * const *args, int num_args) -> SharedPyObject * {
(void) args;
while (num_args == 0) {

return ::mx::to_python(T::static_kind());
}

PyErrorStreamer(PyExc_TypeError)
<< "Invalid arguments passed to 'static_kind'";
return nullptr;
}),
METH_FASTCALL | METH_STATIC,
PyDoc_STR("Wrapper for mx::ir::Argument::static_kind"),
},
{
"FROM",
reinterpret_cast<PyCFunction>(
Expand Down
31 changes: 30 additions & 1 deletion bindings/Python/Generated/IR/Result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,19 @@ std::optional<T> PythonBinding<T>::from_python(BorrowedPyObject *obj) noexcept {

template <>
SharedPyObject *PythonBinding<T>::to_python(T val) noexcept {
auto ret = gType->tp_alloc(gType, 0);
PyTypeObject *tp = nullptr;
switch (val.kind()) {
default:
assert(false);
tp = gType;
break;

case mx::ir::Result::static_kind():
tp = &(gTypes[928]);
break;

}
auto ret = tp->tp_alloc(tp, 0);
if (auto obj = O_cast(ret)) {
obj->data = new (obj->backing_storage) T(std::move(val));
}
Expand Down Expand Up @@ -138,6 +150,23 @@ static PyGetSetDef gProperties[] = {

namespace {
static PyMethodDef gMethods[] = {
{
"static_kind",
reinterpret_cast<PyCFunction>(
+[] (BorrowedPyObject *, BorrowedPyObject * const *args, int num_args) -> SharedPyObject * {
(void) args;
while (num_args == 0) {

return ::mx::to_python(T::static_kind());
}

PyErrorStreamer(PyExc_TypeError)
<< "Invalid arguments passed to 'static_kind'";
return nullptr;
}),
METH_FASTCALL | METH_STATIC,
PyDoc_STR("Wrapper for mx::ir::Result::static_kind"),
},
{
"of",
reinterpret_cast<PyCFunction>(
Expand Down
28 changes: 27 additions & 1 deletion bindings/Python/Generated/IR/Value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,23 @@ std::optional<T> PythonBinding<T>::from_python(BorrowedPyObject *obj) noexcept {

template <>
SharedPyObject *PythonBinding<T>::to_python(T val) noexcept {
auto ret = gType->tp_alloc(gType, 0);
PyTypeObject *tp = nullptr;
switch (val.kind()) {
default:
assert(false);
tp = gType;
break;

case mx::ir::Argument::static_kind():
tp = &(gTypes[927]);
break;

case mx::ir::Result::static_kind():
tp = &(gTypes[928]);
break;

}
auto ret = tp->tp_alloc(tp, 0);
if (auto obj = O_cast(ret)) {
obj->data = new (obj->backing_storage) T(std::move(val));
}
Expand Down Expand Up @@ -112,6 +128,16 @@ bool PythonBinding<T>::load(BorrowedPyObject *module) noexcept {

namespace {
static PyGetSetDef gProperties[] = {
{
"kind",
reinterpret_cast<getter>(
+[] (BorrowedPyObject *self, void * /* closure */) -> SharedPyObject * {
return ::mx::to_python(T_cast(self)->kind());
}),
nullptr,
PyDoc_STR("Wrapper for mx::ir::Value::kind"),
nullptr,
},
{
"type",
reinterpret_cast<getter>(
Expand Down
155 changes: 155 additions & 0 deletions bindings/Python/Generated/IR/ValueKind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright (c) 2023-present, Trail of Bits, Inc.
// All rights reserved.
//
// This source code is licensed in accordance with the terms specified in
// the LICENSE file found in the root directory of this source tree.

// Auto-generated file; do not modify!

#include <multiplier/IR/Value.h>
#include <multiplier/Iterator.h>

#include "Binding.h"
#include "Error.h"
#include "Types.h"

namespace mx {
namespace {
using T = mx::ir::ValueKind;
} // namespace

namespace {
static PyTypeObject *gType = nullptr;
} // namespace

template <>
PyTypeObject *PythonBinding<T>::type(void) noexcept {
return gType;
}

template <>
SharedPyObject *PythonBinding<T>::to_python(T val) noexcept {
return PyObject_GetAttrString(reinterpret_cast<BorrowedPyObject *>(gType),
EnumeratorName(val));
}

template <>
std::optional<T> PythonBinding<T>::from_python(BorrowedPyObject *obj) noexcept {
if (!obj) {
return std::nullopt;
}

if (Py_TYPE(obj) != gType) {
return std::nullopt;
}

auto long_val = PyObject_GetAttrString(obj, "value");
if (!long_val) {
PyErr_Clear();
return std::nullopt;
}

if (!PyLong_Check(long_val)) {
Py_DECREF(long_val);
return std::nullopt;
}

int did_overflow = 0;
const auto ret = static_cast<T>(
PyLong_AsLongLongAndOverflow(obj, &did_overflow));
if (did_overflow) {
Py_DECREF(long_val);
return std::nullopt;
}

return ret;
}

template <>
bool PythonBinding<T>::load(BorrowedPyObject *module) noexcept {
const char * const enum_name = EnumerationName(T{});
bool created = false;

if (!gType) {
auto enum_module = PyImport_ImportModule("enum");
if (!enum_module) {
return false;
}

auto int_enum = PyObject_GetAttrString(enum_module, "IntEnum");
Py_DECREF(enum_module);
if (!int_enum) {
return false;
}

auto enum_meta = PyObject_Type(int_enum);
auto prepare = PyObject_GetAttrString(enum_meta, "__prepare__");
if (!prepare) {
Py_DECREF(enum_meta);
Py_DECREF(int_enum);
return false;
}

// Get the `enum._EnumDict` for what we're making.
auto ns_dict = PyObject_CallFunction(prepare, "s(N)", enum_name, int_enum);
Py_DECREF(prepare);
if (!ns_dict) {
Py_DECREF(enum_meta);
Py_DECREF(int_enum);
return false;
}

// Assign each enumerator.
for (T val : EnumerationRange<T>()) {
auto iname = PyUnicode_FromString(EnumeratorName(val));
auto ival = PyLong_FromUnsignedLongLong(static_cast<uint64_t>(val));
if (ival) {
if (!PyObject_SetItem(ns_dict, iname, ival)) {
continue;
}
Py_DECREF(ival);
}

Py_DECREF(ns_dict);
Py_DECREF(enum_meta);
Py_DECREF(int_enum);
return false;
}

// Create the type.
auto enum_class = PyObject_CallFunction(
enum_meta, "s(N)N", enum_name, int_enum, ns_dict);
Py_DECREF(ns_dict);
Py_DECREF(enum_meta);
Py_DECREF(int_enum);

if (!enum_class) {
return false;
}

if (!PyType_Check(enum_class)) {
Py_DECREF(enum_class);

PyErrorStreamer(PyExc_ImportError)
<< "Created enum class for enumerator '" << enum_name
<< "' is not a python type";
return false;
}

gType = reinterpret_cast<PyTypeObject *>(enum_class);
created = true;
}

auto tp_obj = reinterpret_cast<BorrowedPyObject *>(gType);
if (0 != PyModule_AddObjectRef(module, enum_name, tp_obj)) {
return false;
}

if (created) {
Py_DECREF(tp_obj);
}

return true;
}

} // namespace mx
1 change: 1 addition & 0 deletions bindings/Python/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ static PyModuleDef gIRModule = {
static LoaderFunc * const gIRLoaders[] = {
PythonBinding<mx::ir::AttributeKind>::load,
PythonBinding<mx::ir::Attribute>::load,
PythonBinding<mx::ir::ValueKind>::load,
PythonBinding<mx::ir::Value>::load,
PythonBinding<mx::ir::Block>::load,
PythonBinding<mx::ir::Argument>::load,
Expand Down
13 changes: 13 additions & 0 deletions bindings/Python/multiplier-stubs/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class AttributeKind(IntEnum):
CORE_VOID = 84
META_IDENTIFIER = 85

class ValueKind(IntEnum):
OPERATION_RESULT = 0
BLOCK_ARGUMENT = 1

class OperationKind(IntEnum):
UNKNOWN = 0
BUILTIN_MODULE = 1
Expand Down Expand Up @@ -616,12 +620,17 @@ class Attribute(object):
kind: multiplier.ir.AttributeKind

class Value(object):
kind: multiplier.ir.ValueKind
type: multiplier.ir.Type
uses: Generator[multiplier.ir.Operand]

class Argument(multiplier.ir.Value):
index: int

@staticmethod
def static_kind() -> multiplier.ir.ValueKind:
...

@staticmethod
def FROM(val: multiplier.ir.Value) -> Optional[multiplier.ir.Argument]:
...
Expand All @@ -630,6 +639,10 @@ class Result(multiplier.ir.Value):
operation: multiplier.ir.Operation
index: int

@staticmethod
def static_kind() -> multiplier.ir.ValueKind:
...

@staticmethod
def of(arg_0: multiplier.ir.Operation) -> Optional[multiplier.ir.Result]:
...
Expand Down
4 changes: 4 additions & 0 deletions include/multiplier/IR/Block.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ class MX_EXPORT Argument final : public Value {

public:

inline static constexpr ValueKind static_kind(void) noexcept {
return ValueKind::BLOCK_ARGUMENT;
}

static std::optional<Argument> from(const Value &val);

// Index of this block argument.
Expand Down
4 changes: 4 additions & 0 deletions include/multiplier/IR/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class Result final : public Value {

public:

inline static constexpr ValueKind static_kind(void) noexcept {
return ValueKind::OPERATION_RESULT;
}

inline Result(std::shared_ptr<const SourceIRImpl> module,
void *res)
: Value(std::move(module),
Expand Down
17 changes: 17 additions & 0 deletions include/multiplier/IR/Value.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ class Result;
class SourceIRImpl;
class Type;

enum ValueKind : unsigned char {
OPERATION_RESULT,
BLOCK_ARGUMENT
};

inline static const char *EnumerationName(ValueKind) {
return "ValueKind";
}

MX_EXPORT const char *EnumeratorName(ValueKind);

inline static constexpr unsigned NumEnumerators(ValueKind) {
return 2;
}

// The result of an operation, or an argument to a block. Values can
// have an arbitrary number of users.
class MX_EXPORT Value {
Expand Down Expand Up @@ -61,6 +76,8 @@ class MX_EXPORT Value {
: module_(std::move(module)),
impl_(value) {}

ValueKind kind(void) const noexcept;

Type type(void) const noexcept;

// Generate the uses of this value.
Expand Down
Loading

0 comments on commit a316fd7

Please sign in to comment.