Skip to content

Commit

Permalink
accounts/abi: abigen v2
Browse files Browse the repository at this point in the history
  • Loading branch information
s1na committed Mar 1, 2023
1 parent e1b98f4 commit 3d0330d
Show file tree
Hide file tree
Showing 6 changed files with 464 additions and 49 deletions.
60 changes: 38 additions & 22 deletions accounts/abi/bind/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ func DeployContract(opts *TransactOpts, abi abi.ABI, bytecode []byte, backend Co
// returns, a slice of interfaces for anonymous returns and a struct for named
// returns.
func (c *BoundContract) Call(opts *CallOpts, results *[]interface{}, method string, params ...interface{}) error {
// Don't crash on a lazy user
if opts == nil {
opts = new(CallOpts)
}
if results == nil {
results = new([]interface{})
}
Expand All @@ -161,51 +157,64 @@ func (c *BoundContract) Call(opts *CallOpts, results *[]interface{}, method stri
if err != nil {
return err
}
output, err := c.call(opts, input)
if err != nil {
return err
}

if len(*results) == 0 {
res, err := c.abi.Unpack(method, output)
*results = res
return err
}
res := *results
return c.abi.UnpackIntoInterface(res[0], method, output)
}

func (c *BoundContract) call(opts *CallOpts, input []byte) ([]byte, error) {
// Don't crash on a lazy user
if opts == nil {
opts = new(CallOpts)
}
var (
msg = ethereum.CallMsg{From: opts.From, To: &c.address, Data: input}
ctx = ensureContext(opts.Context)
code []byte
output []byte
err error
)
if opts.Pending {
pb, ok := c.caller.(PendingContractCaller)
if !ok {
return ErrNoPendingState
return nil, ErrNoPendingState
}
output, err = pb.PendingCallContract(ctx, msg)
if err != nil {
return err
return nil, err
}
if len(output) == 0 {
// Make sure we have a contract to operate on, and bail out otherwise.
if code, err = pb.PendingCodeAt(ctx, c.address); err != nil {
return err
return nil, err
} else if len(code) == 0 {
return ErrNoCode
return nil, ErrNoCode
}
}
} else {
output, err = c.caller.CallContract(ctx, msg, opts.BlockNumber)
if err != nil {
return err
return nil, err
}
if len(output) == 0 {
// Make sure we have a contract to operate on, and bail out otherwise.
if code, err = c.caller.CodeAt(ctx, c.address, opts.BlockNumber); err != nil {
return err
return nil, err
} else if len(code) == 0 {
return ErrNoCode
return nil, ErrNoCode
}
}
}

if len(*results) == 0 {
res, err := c.abi.Unpack(method, output)
*results = res
return err
}
res := *results
return c.abi.UnpackIntoInterface(res[0], method, output)
return output, nil
}

// Transact invokes the (paid) contract method with params as input values.
Expand Down Expand Up @@ -409,13 +418,16 @@ func (c *BoundContract) transact(opts *TransactOpts, contract *common.Address, i
// FilterLogs filters contract logs for past blocks, returning the necessary
// channels to construct a strongly typed bound iterator on top of them.
func (c *BoundContract) FilterLogs(opts *FilterOpts, name string, query ...[]interface{}) (chan types.Log, event.Subscription, error) {
return c.filterLogs(opts, c.abi.Events[name].ID, query...)
}

func (c *BoundContract) filterLogs(opts *FilterOpts, eventID common.Hash, query ...[]interface{}) (chan types.Log, event.Subscription, error) {
// Don't crash on a lazy user
if opts == nil {
opts = new(FilterOpts)
}
// Append the event selector to the query parameters and construct the topic set
query = append([][]interface{}{{c.abi.Events[name].ID}}, query...)

query = append([][]interface{}{{eventID}}, query...)
topics, err := abi.MakeTopics(query...)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -458,12 +470,16 @@ func (c *BoundContract) FilterLogs(opts *FilterOpts, name string, query ...[]int
// WatchLogs filters subscribes to contract logs for future blocks, returning a
// subscription object that can be used to tear down the watcher.
func (c *BoundContract) WatchLogs(opts *WatchOpts, name string, query ...[]interface{}) (chan types.Log, event.Subscription, error) {
return c.watchLogs(opts, c.abi.Events[name].ID, query...)
}

func (c *BoundContract) watchLogs(opts *WatchOpts, eventID common.Hash, query ...[]interface{}) (chan types.Log, event.Subscription, error) {
// Don't crash on a lazy user
if opts == nil {
opts = new(WatchOpts)
}
// Append the event selector to the query parameters and construct the topic set
query = append([][]interface{}{{c.abi.Events[name].ID}}, query...)
query = append([][]interface{}{{eventID}}, query...)

topics, err := abi.MakeTopics(query...)
if err != nil {
Expand Down
125 changes: 99 additions & 26 deletions accounts/abi/bind/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,100 @@ func isKeyWord(arg string) bool {
// enforces compile time type safety and naming convention opposed to having to
// manually maintain hard coded strings that break on runtime.
func Bind(types []string, abis []string, bytecodes []string, fsigs []map[string]string, pkg string, lang Lang, libs map[string]string, aliases map[string]string) (string, error) {
data, err := bind(types, abis, bytecodes, fsigs, pkg, lang, libs, aliases)
if err != nil {
return "", err
}
buffer := new(bytes.Buffer)

funcs := map[string]interface{}{
"bindtype": bindType[lang],
"bindtopictype": bindTopicType[lang],
"namedtype": namedType[lang],
"capitalise": capitalise,
"decapitalise": decapitalise,
}
tmpl := template.Must(template.New("").Funcs(funcs).Parse(tmplSource[lang]))
if err := tmpl.Execute(buffer, data); err != nil {
return "", err
}
// For Go bindings pass the code through gofmt to clean it up
if lang == LangGo {
code, err := format.Source(buffer.Bytes())
if err != nil {
return "", fmt.Errorf("%v\n%s", err, buffer)
}
return string(code), nil
}
// For all others just return as is for now
return buffer.String(), nil
}

func BindV2(types []string, abis []string, bytecodes []string, fsigs []map[string]string, pkg string, lang Lang, libs map[string]string, aliases map[string]string) (string, error) {
data, err := bind(types, abis, bytecodes, fsigs, pkg, lang, libs, aliases)
if err != nil {
return "", err
}
for _, c := range data.Contracts {
// We want pack/unpack methods for all existing methods.
for name, t := range c.Transacts {
c.Calls[name] = t
}
c.Transacts = nil

// Make sure we return one argument. If multiple exist
// merge them into a struct.
for _, call := range c.Calls {
if call.Structured {
continue
}
if len(call.Normalized.Outputs) == 1 {
continue
}
// Build up dictionary of existing arg names.
keys := make(map[string]struct{})
for _, o := range call.Normalized.Outputs {
if o.Name != "" {
keys[strings.ToLower(o.Name)] = struct{}{}
}
}
// Assign names to anonymous fields.
for i, o := range call.Normalized.Outputs {
if o.Name != "" {
continue
}
o.Name = capitalise(abi.ResolveNameConflict("arg", func(name string) bool { _, ok := keys[name]; return ok }))
call.Normalized.Outputs[i] = o
keys[strings.ToLower(o.Name)] = struct{}{}
}
call.Structured = true
}
}
buffer := new(bytes.Buffer)
funcs := map[string]interface{}{
"bindtype": bindType[lang],
"bindtopictype": bindTopicType[lang],
"namedtype": namedType[lang],
"capitalise": capitalise,
"decapitalise": decapitalise,
}
tmpl := template.Must(template.New("").Funcs(funcs).Parse(tmplSourceV2[lang]))
if err := tmpl.Execute(buffer, data); err != nil {
return "", err
}
// For Go bindings pass the code through gofmt to clean it up
if lang == LangGo {
code, err := format.Source(buffer.Bytes())
if err != nil {
return "", fmt.Errorf("%v\n%s", err, buffer)
}
return string(code), nil
}
// For all others just return as is for now
return buffer.String(), nil
}

func bind(types []string, abis []string, bytecodes []string, fsigs []map[string]string, pkg string, lang Lang, libs map[string]string, aliases map[string]string) (*tmplData, error) {
var (
// contracts is the map of each individual contract requested binding
contracts = make(map[string]*tmplContract)
Expand All @@ -96,7 +190,7 @@ func Bind(types []string, abis []string, bytecodes []string, fsigs []map[string]
// Parse the actual ABI to generate the binding for
evmABI, err := abi.JSON(strings.NewReader(abis[i]))
if err != nil {
return "", err
return nil, err
}
// Strip any whitespace from the JSON ABI
strippedABI := strings.Map(func(r rune) rune {
Expand Down Expand Up @@ -140,7 +234,7 @@ func Bind(types []string, abis []string, bytecodes []string, fsigs []map[string]
identifiers = transactIdentifiers
}
if identifiers[normalizedName] {
return "", fmt.Errorf("duplicated identifier \"%s\"(normalized \"%s\"), use --alias for renaming", original.Name, normalizedName)
return nil, fmt.Errorf("duplicated identifier \"%s\"(normalized \"%s\"), use --alias for renaming", original.Name, normalizedName)
}
identifiers[normalizedName] = true

Expand Down Expand Up @@ -183,7 +277,7 @@ func Bind(types []string, abis []string, bytecodes []string, fsigs []map[string]
// Ensure there is no duplicated identifier
normalizedName := methodNormalizer[lang](alias(aliases, original.Name))
if eventIdentifiers[normalizedName] {
return "", fmt.Errorf("duplicated identifier \"%s\"(normalized \"%s\"), use --alias for renaming", original.Name, normalizedName)
return nil, fmt.Errorf("duplicated identifier \"%s\"(normalized \"%s\"), use --alias for renaming", original.Name, normalizedName)
}
eventIdentifiers[normalizedName] = true
normalized.Name = normalizedName
Expand Down Expand Up @@ -218,6 +312,7 @@ func Bind(types []string, abis []string, bytecodes []string, fsigs []map[string]
if evmABI.HasReceive() {
receive = &tmplMethod{Original: evmABI.Receive}
}

contracts[types[i]] = &tmplContract{
Type: capitalise(types[i]),
InputABI: strings.ReplaceAll(strippedABI, "\"", "\\\""),
Expand Down Expand Up @@ -262,29 +357,7 @@ func Bind(types []string, abis []string, bytecodes []string, fsigs []map[string]
Libraries: libs,
Structs: structs,
}
buffer := new(bytes.Buffer)

funcs := map[string]interface{}{
"bindtype": bindType[lang],
"bindtopictype": bindTopicType[lang],
"namedtype": namedType[lang],
"capitalise": capitalise,
"decapitalise": decapitalise,
}
tmpl := template.Must(template.New("").Funcs(funcs).Parse(tmplSource[lang]))
if err := tmpl.Execute(buffer, data); err != nil {
return "", err
}
// For Go bindings pass the code through gofmt to clean it up
if lang == LangGo {
code, err := format.Source(buffer.Bytes())
if err != nil {
return "", fmt.Errorf("%v\n%s", err, buffer)
}
return string(code), nil
}
// For all others just return as is for now
return buffer.String(), nil
return data, nil
}

// bindType is a set of type binders that convert Solidity types to some supported
Expand Down
Loading

0 comments on commit 3d0330d

Please sign in to comment.