Skip to content

Commit

Permalink
Merge branch 'release/v1.0.8'
Browse files Browse the repository at this point in the history
  • Loading branch information
bububa committed Feb 6, 2025
2 parents 18a2c8b + 7568b5a commit d8b63a0
Show file tree
Hide file tree
Showing 31 changed files with 1,588 additions and 139 deletions.
65 changes: 63 additions & 2 deletions agents/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package agents

import (
"context"
"errors"

"github.com/bububa/instructor-go/pkg/instructor"
cohere "github.com/cohere-ai/cohere-go/v2"
Expand All @@ -14,6 +15,15 @@ import (
"github.com/bububa/atomic-agents/schema"
)

type IAgent interface {
Name() string
}

type ChainableAgent interface {
IAgent
RunForChain(context.Context, any, *components.ApiResponse) (any, error)
}

type AgentSetter interface {
SetClient(clt instructor.Instructor)
SetMemory(m *components.Memory)
Expand Down Expand Up @@ -41,13 +51,18 @@ type Config struct {
temperature float32
// maxTokens Maximum number of tokens allowed in the response
maxTokens int
// name is Agent name presentation
name string
}

// Agent class for chat agents.
// This class provides the core functionality for handling chat interactions, including managing memory,
// generating system prompts, and obtaining responses from a language model.
type Agent[T schema.Schema, O schema.Schema] struct {
type Agent[I schema.Schema, O schema.Schema] struct {
Config
startHook func(context.Context, *Agent[I, O], *I)
endHook func(context.Context, *Agent[I, O], *I, *O, *components.ApiResponse)
errorHook func(context.Context, *Agent[I, O], *I, *components.ApiResponse, error)
}

// NewAgent initializes the AgentAgent
Expand All @@ -69,7 +84,7 @@ func NewAgent[I schema.Schema, O schema.Schema](options ...Option) *Agent[I, O]

// ResetMemory resets the memory to its initial state
func (a *Agent[I, O]) ResetMemory() {
a.initialMemory.Copy(a.memory)
a.memory.Reset()
}

func (a *Agent[I, O]) SetClient(clt instructor.Instructor) {
Expand All @@ -96,6 +111,26 @@ func (a *Agent[I, O]) SetMaxTokens(maxTokens int) {
a.maxTokens = maxTokens
}

func (a Agent[I, O]) Name() string {
return a.name
}

func (a *Agent[I, O]) SetName(name string) {
a.name = name
}

func (a *Agent[I, O]) SetStartHook(fn func(context.Context, *Agent[I, O], *I)) {
a.startHook = fn
}

func (a *Agent[I, O]) SetEndHook(fn func(context.Context, *Agent[I, O], *I, *O, *components.ApiResponse)) {
a.endHook = fn
}

func (a *Agent[I, O]) SetErrorHook(fn func(context.Context, *Agent[I, O], *I, *components.ApiResponse, error)) {
a.errorHook = fn
}

// Response obtains a response from the language model synchronously
func (a *Agent[I, O]) response(ctx context.Context, response *O, apiResponse *components.ApiResponse) error {
messages := make([]components.Message, 0, a.memory.MessageCount()+1)
Expand Down Expand Up @@ -163,17 +198,43 @@ func (a *Agent[I, O]) response(ctx context.Context, response *O, apiResponse *co

// Run runs the chat agent with the given user input synchronously.
func (a *Agent[I, O]) Run(ctx context.Context, userInput *I, output *O, apiResp *components.ApiResponse) error {
if fn := a.startHook; fn != nil {
fn(ctx, a, userInput)
}
if userInput != nil {
a.memory.NewTurn()
a.memory.NewMessage(components.UserRole, *userInput)
}
if err := a.response(ctx, output, apiResp); err != nil {
if fn := a.errorHook; fn != nil {
fn(ctx, a, userInput, apiResp, err)
}
return err
}
a.memory.NewMessage(components.AssistantRole, *output)
if fn := a.endHook; fn != nil {
fn(ctx, a, userInput, output, apiResp)
}
return nil
}

// Run runs the chat agent with the given user input for chain.
func (a *Agent[I, O]) RunForChain(ctx context.Context, userInput any, apiResp *components.ApiResponse) (any, error) {
in, ok := userInput.(*I)
if !ok {
return nil, errors.New("invalid input schema")
}
out := new(O)
if err := a.Run(ctx, in, out, apiResp); err != nil {
return nil, err
}
return out, nil
}

func (a *Agent[I, O]) NewMessage(role components.MessageRole, content schema.Schema) *components.Message {
return a.memory.NewMessage(role, content)
}

// SystemPromptContextProvider returns agent systemPromptGenerator's context provider
func (a *Agent[I, O]) SystemPromptContextProvider(title string) (systemprompt.ContextProvider, error) {
return a.systemPromptGenerator.ContextProvider(title)
Expand Down
108 changes: 108 additions & 0 deletions agents/chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package agents

import (
"context"
"errors"

"github.com/bububa/atomic-agents/components"
"github.com/bububa/atomic-agents/schema"
)

// Chain agents chain
type Chain[I schema.Schema, O schema.Schema] struct {
name string
agents []ChainableAgent
startHook func(context.Context, *Chain[I, O], *I)
endHook func(context.Context, *Chain[I, O], *I, *O, []components.ApiResponse)
errorHook func(context.Context, *Chain[I, O], *I, []components.ApiResponse, error)
}

// NewChain returns a new Chain instance
func NewChain[I schema.Schema, O schema.Schema](agents ...ChainableAgent) *Chain[I, O] {
return &Chain[I, O]{
agents: agents,
}
}

func (c *Chain[I, O]) Name() string {
return c.name
}

func (c *Chain[I, O]) SetName(name string) {
c.name = name
}

func (c *Chain[I, O]) SetStartHook(fn func(context.Context, *Chain[I, O], *I)) {
c.startHook = fn
}

func (c *Chain[I, O]) SetEndHook(fn func(context.Context, *Chain[I, O], *I, *O, []components.ApiResponse)) {
c.endHook = fn
}

func (c *Chain[I, O]) SetErrorHook(fn func(context.Context, *Chain[I, O], *I, []components.ApiResponse, error)) {
c.errorHook = fn
}

// Run runs the chat agents with the given user input synchronously.
func (c *Chain[I, O]) Run(ctx context.Context, input *I, output *O) ([]components.ApiResponse, error) {
if fn := c.startHook; fn != nil {
fn(ctx, c, input)
}
l := len(c.agents)
apiRespList := make([]components.ApiResponse, 0, l)
var (
in any = input
out any
)
for _, agent := range c.agents {
apiResp := new(components.ApiResponse)
if ret, err := agent.RunForChain(ctx, in, apiResp); err != nil {
if fn := c.errorHook; fn != nil {
fn(ctx, c, input, apiRespList, err)
}
return apiRespList, err
} else {
in = ret
out = ret
}
apiRespList = append(apiRespList, *apiResp)
}
if outO, ok := out.(*O); !ok {
err := errors.New("invalid agent output schema")
if fn := c.errorHook; fn != nil {
fn(ctx, c, input, apiRespList, err)
}
return apiRespList, err
} else {
*output = *outO
}
if fn := c.endHook; fn != nil {
fn(ctx, c, input, output, apiRespList)
}
return apiRespList, nil
}

// Run runs the chat agents with the given user input synchronously.
func (c *Chain[I, O]) RunForChain(ctx context.Context, input any, apiResp *components.ApiResponse) (any, error) {
in, ok := input.(*I)
if !ok {
return nil, errors.New("invalid agent input schema")
}
out := new(O)
apiRespList, err := c.Run(ctx, in, out)
if err != nil {
return nil, err
}
for _, v := range apiRespList {
if v.Usage == nil {
continue
}
if apiResp.Usage == nil {
apiResp.Usage = new(components.ApiUsage)
}
apiResp.Usage.InputTokens = v.Usage.InputTokens
apiResp.Usage.OutputTokens = v.Usage.OutputTokens
}
return out, nil
}
6 changes: 6 additions & 0 deletions agents/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ func WithMaxTokens(maxTokens int) Option {
c.maxTokens = maxTokens
}
}

func WithName(name string) Option {
return func(c *Config) {
c.name = name
}
}
59 changes: 59 additions & 0 deletions agents/orchestration_agent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package agents

import (
"context"
"errors"

"github.com/bububa/atomic-agents/components"
"github.com/bububa/atomic-agents/schema"
)

// AgentSelector will returns a Tool based on input param
type AgentSelector[I schema.Schema] func(req *I) (ChainableAgent, any, error)

// OrchestrationAgent is an agent for orchestration
type OrchestrationAgent[I schema.Schema, O schema.Schema] struct {
name string
selector AgentSelector[I]
}

func NewOrchestrationAgent[I schema.Schema, O schema.Schema](selector AgentSelector[I]) *OrchestrationAgent[I, O] {
return &OrchestrationAgent[I, O]{
selector: selector,
}
}

func (a *OrchestrationAgent[I, O]) Name() string {
return a.name
}

func (a *OrchestrationAgent[I, O]) SetName(name string) {
a.name = name
}

func (a *OrchestrationAgent[I, O]) Run(ctx context.Context, input *I, output *O, apiResp *components.ApiResponse) error {
fn, params, err := a.selector(input)
if err != nil {
return err
}
if out, err := fn.RunForChain(ctx, params, apiResp); err != nil {
return err
} else if outO, ok := out.(*O); !ok {
return errors.New("invalid agent output schema")
} else {
*output = *outO
}
return nil
}

func (a *OrchestrationAgent[I, O]) RunForChain(ctx context.Context, input any, apiResp *components.ApiResponse) (any, error) {
in, ok := input.(*I)
if !ok {
return nil, errors.New("invalid agent input schema")
}
fn, params, err := a.selector(in)
if err != nil {
return nil, err
}
return fn.RunForChain(ctx, params, apiResp)
}
Loading

0 comments on commit d8b63a0

Please sign in to comment.