-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
31 changed files
with
1,588 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
package agents | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
|
||
"github.com/bububa/atomic-agents/components" | ||
"github.com/bububa/atomic-agents/schema" | ||
) | ||
|
||
// Chain agents chain | ||
type Chain[I schema.Schema, O schema.Schema] struct { | ||
name string | ||
agents []ChainableAgent | ||
startHook func(context.Context, *Chain[I, O], *I) | ||
endHook func(context.Context, *Chain[I, O], *I, *O, []components.ApiResponse) | ||
errorHook func(context.Context, *Chain[I, O], *I, []components.ApiResponse, error) | ||
} | ||
|
||
// NewChain returns a new Chain instance | ||
func NewChain[I schema.Schema, O schema.Schema](agents ...ChainableAgent) *Chain[I, O] { | ||
return &Chain[I, O]{ | ||
agents: agents, | ||
} | ||
} | ||
|
||
func (c *Chain[I, O]) Name() string { | ||
return c.name | ||
} | ||
|
||
func (c *Chain[I, O]) SetName(name string) { | ||
c.name = name | ||
} | ||
|
||
func (c *Chain[I, O]) SetStartHook(fn func(context.Context, *Chain[I, O], *I)) { | ||
c.startHook = fn | ||
} | ||
|
||
func (c *Chain[I, O]) SetEndHook(fn func(context.Context, *Chain[I, O], *I, *O, []components.ApiResponse)) { | ||
c.endHook = fn | ||
} | ||
|
||
func (c *Chain[I, O]) SetErrorHook(fn func(context.Context, *Chain[I, O], *I, []components.ApiResponse, error)) { | ||
c.errorHook = fn | ||
} | ||
|
||
// Run runs the chat agents with the given user input synchronously. | ||
func (c *Chain[I, O]) Run(ctx context.Context, input *I, output *O) ([]components.ApiResponse, error) { | ||
if fn := c.startHook; fn != nil { | ||
fn(ctx, c, input) | ||
} | ||
l := len(c.agents) | ||
apiRespList := make([]components.ApiResponse, 0, l) | ||
var ( | ||
in any = input | ||
out any | ||
) | ||
for _, agent := range c.agents { | ||
apiResp := new(components.ApiResponse) | ||
if ret, err := agent.RunForChain(ctx, in, apiResp); err != nil { | ||
if fn := c.errorHook; fn != nil { | ||
fn(ctx, c, input, apiRespList, err) | ||
} | ||
return apiRespList, err | ||
} else { | ||
in = ret | ||
out = ret | ||
} | ||
apiRespList = append(apiRespList, *apiResp) | ||
} | ||
if outO, ok := out.(*O); !ok { | ||
err := errors.New("invalid agent output schema") | ||
if fn := c.errorHook; fn != nil { | ||
fn(ctx, c, input, apiRespList, err) | ||
} | ||
return apiRespList, err | ||
} else { | ||
*output = *outO | ||
} | ||
if fn := c.endHook; fn != nil { | ||
fn(ctx, c, input, output, apiRespList) | ||
} | ||
return apiRespList, nil | ||
} | ||
|
||
// Run runs the chat agents with the given user input synchronously. | ||
func (c *Chain[I, O]) RunForChain(ctx context.Context, input any, apiResp *components.ApiResponse) (any, error) { | ||
in, ok := input.(*I) | ||
if !ok { | ||
return nil, errors.New("invalid agent input schema") | ||
} | ||
out := new(O) | ||
apiRespList, err := c.Run(ctx, in, out) | ||
if err != nil { | ||
return nil, err | ||
} | ||
for _, v := range apiRespList { | ||
if v.Usage == nil { | ||
continue | ||
} | ||
if apiResp.Usage == nil { | ||
apiResp.Usage = new(components.ApiUsage) | ||
} | ||
apiResp.Usage.InputTokens = v.Usage.InputTokens | ||
apiResp.Usage.OutputTokens = v.Usage.OutputTokens | ||
} | ||
return out, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package agents | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
|
||
"github.com/bububa/atomic-agents/components" | ||
"github.com/bububa/atomic-agents/schema" | ||
) | ||
|
||
// AgentSelector will returns a Tool based on input param | ||
type AgentSelector[I schema.Schema] func(req *I) (ChainableAgent, any, error) | ||
|
||
// OrchestrationAgent is an agent for orchestration | ||
type OrchestrationAgent[I schema.Schema, O schema.Schema] struct { | ||
name string | ||
selector AgentSelector[I] | ||
} | ||
|
||
func NewOrchestrationAgent[I schema.Schema, O schema.Schema](selector AgentSelector[I]) *OrchestrationAgent[I, O] { | ||
return &OrchestrationAgent[I, O]{ | ||
selector: selector, | ||
} | ||
} | ||
|
||
func (a *OrchestrationAgent[I, O]) Name() string { | ||
return a.name | ||
} | ||
|
||
func (a *OrchestrationAgent[I, O]) SetName(name string) { | ||
a.name = name | ||
} | ||
|
||
func (a *OrchestrationAgent[I, O]) Run(ctx context.Context, input *I, output *O, apiResp *components.ApiResponse) error { | ||
fn, params, err := a.selector(input) | ||
if err != nil { | ||
return err | ||
} | ||
if out, err := fn.RunForChain(ctx, params, apiResp); err != nil { | ||
return err | ||
} else if outO, ok := out.(*O); !ok { | ||
return errors.New("invalid agent output schema") | ||
} else { | ||
*output = *outO | ||
} | ||
return nil | ||
} | ||
|
||
func (a *OrchestrationAgent[I, O]) RunForChain(ctx context.Context, input any, apiResp *components.ApiResponse) (any, error) { | ||
in, ok := input.(*I) | ||
if !ok { | ||
return nil, errors.New("invalid agent input schema") | ||
} | ||
fn, params, err := a.selector(in) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return fn.RunForChain(ctx, params, apiResp) | ||
} |
Oops, something went wrong.