Skip to content

Added support for anonymous fields #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions di.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,16 @@ const (

var initializeShutdownLock sync.Mutex
var createInstanceLock sync.Mutex
var configureInstanceLock sync.RWMutex
var containerInitialized int32
var beans = make(map[string]reflect.Type)
var beanFactories = make(map[string]func(context.Context) (interface{}, error))
var scopes = make(map[string]Scope)
var singletonInstances = make(map[string]interface{})
var userCreatedInstances = make(map[string]bool)
var beanPostprocessors = make(map[reflect.Type][]func(bean interface{}) error)
var configurations = make(map[reflect.Type]interface{})
var configurationTypeCache = make(map[reflect.Type]reflect.Type)

// InitializingBean is an interface marking beans that need to be additionally initialized after the container is ready.
type InitializingBean interface {
Expand Down Expand Up @@ -95,6 +98,20 @@ func RegisterBeanPostprocessor(beanType reflect.Type, postprocessor func(bean in
return nil
}

func RegisterBeanConfiguration [T interface{}](configuration T) error {
initializeShutdownLock.Lock()
defer initializeShutdownLock.Unlock()
if atomic.CompareAndSwapInt32(&containerInitialized, 1, 1) {
return errors.New("container is already initialized: can't register bean configuration")
}
confType := reflect.TypeOf(configuration)
if _, contains := configurations[confType]; contains {
return errors.New("configuration of this type is already registered")
}
configurations[confType] = configuration
return nil
}

// InitializeContainer function initializes the IoC container.
func InitializeContainer() error {
initializeShutdownLock.Lock()
Expand Down Expand Up @@ -259,9 +276,17 @@ func injectSingletonDependencies() error {
func injectDependencies(beanID string, instance interface{}, chain map[string]bool) error {
logrus.WithField("beanID", beanID).Trace("injecting dependencies")
instanceType := beans[beanID]
instanceElement := instanceType.Elem()
return injectDependenciesWithType(instanceType.Elem(), beanID, instance, chain)
}

func injectDependenciesWithType(instanceElement reflect.Type, beanID string, instance interface{}, chain map[string]bool) error {
for i := 0; i < instanceElement.NumField(); i++ {
field := instanceElement.Field(i)
if field.Type.Kind() == reflect.Struct && field.Anonymous {
fieldToInject := reflect.ValueOf(instance).Elem().Field(i)
injectDependenciesWithType(field.Type, beanID, fieldToInject.Addr().Interface(), chain)
continue
}
beanToInject, ok := field.Tag.Lookup(string(inject))
if !ok {
continue
Expand Down Expand Up @@ -457,14 +482,56 @@ func initializeSingletonInstances() error {
return nil
}

func applyConfiguration(beanType reflect.Type, instance interface{}) error {
configureInstanceLock.RLock()
configType, ok := configurationTypeCache[beanType]
configureInstanceLock.RUnlock()

if ok {
if configType != nil {
method := reflect.ValueOf(instance).MethodByName("Configure")
config := configurations[configType]
val:=method.Call([]reflect.Value{reflect.ValueOf(config)})
if val[0].Interface() != nil {
return val[0].Interface().(error)
}
}
return nil
} else {
configureInstanceLock.Lock()
defer configureInstanceLock.Unlock()
configurationTypeCache[beanType] = nil
for configType, config := range configurations {
method := reflect.ValueOf(instance).MethodByName("Configure")
if method.IsValid() && method.Type().NumIn() == 1 {
if configType.AssignableTo(method.Type().In(0)) {
configurationTypeCache[beanType] = configType
val:=method.Call([]reflect.Value{reflect.ValueOf(config)})
if val[0].Interface() != nil {
return val[0].Interface().(error)
}
break
}
}
}
return nil
}
}

func initializeInstance(beanID string, instance interface{}) error {
bean := reflect.TypeOf(instance)

// Configure first, then PostConstruct
if err:= applyConfiguration(bean, instance); err != nil {
return err
}

if impl, ok := instance.(InitializingBean); ok {
logrus.WithField("beanID", beanID).Trace("initializing bean")
if err := impl.PostConstruct(); err != nil {
return err
}
}
bean := reflect.TypeOf(instance)
if postprocessors, ok := beanPostprocessors[bean]; ok {
logrus.WithField("beanID", beanID).Trace("postprocessing bean")
for _, postprocessor := range postprocessors {
Expand All @@ -473,6 +540,7 @@ func initializeInstance(beanID string, instance interface{}) error {
}
}
}

return nil
}

Expand Down
99 changes: 99 additions & 0 deletions di_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1166,3 +1166,102 @@ func (suite *TestSuite) TestShutdownContinueOnError() {
assert.Equal(suite.T(), 5, len(closedSingletons))
assert.Equal(suite.T(), 5, len(singletonBeansWithErrorOnClose))
}

func (suite *TestSuite) TestInjectInParent() {
type SingletonBeanParent struct {
otherBean1 someInterface `di.inject:""`
}
type SingletonBeanChild struct {
SingletonBeanParent
otherBean2 someInterface `di.inject:""`
}

overwritten, err := RegisterBean("singletonBean", reflect.TypeOf((*SingletonBeanChild)(nil)))
assert.False(suite.T(), overwritten)
assert.NoError(suite.T(), err)
overwritten, err = RegisterBean("otherBean", reflect.TypeOf((*otherBean)(nil)))
assert.False(suite.T(), overwritten)
assert.NoError(suite.T(), err)
err = InitializeContainer()
assert.NoError(suite.T(), err)
instance, err := GetInstanceSafe("singletonBean")
assert.NoError(suite.T(), err)
assert.NotNil(suite.T(), instance.(*SingletonBeanChild).otherBean1)
assert.NotNil(suite.T(), instance.(*SingletonBeanChild).otherBean2)
_, ok := instance.(*SingletonBeanChild).otherBean1.(*otherBean)
assert.True(suite.T(), ok)
assert.EqualValues(suite.T(), instance.(*SingletonBeanChild).otherBean1, instance.(*SingletonBeanChild).otherBean2)
}

type SingletonConfiguringInnerBean struct {
value int
}
type SingletonInnerBean struct {
}

type SingletonConfiguredBean struct {
innerConfigBean *SingletonConfiguringInnerBean `di.inject:""`
innerBean *SingletonInnerBean `di.inject:""`
value int
postConstructCalled bool
}

type Config1 struct {
Conf int
}

type Config2 struct {
Conf int
}

func (bean *SingletonConfiguredBean) Configure(conf *Config1) error {
bean.value = conf.Conf
return nil
}

func (bean *SingletonConfiguredBean) PostConstruct() error {
bean.postConstructCalled = true
if bean.value == 0 {
return errors.New("not initialized")
}
return nil
}

func (bean *SingletonConfiguringInnerBean) Configure(conf *Config2) error {
bean.value = conf.Conf
return nil
}

func (suite *TestSuite) TestInjectConfiguration() {
overwritten, err := RegisterBean("singletonBean1", reflect.TypeOf((*SingletonConfiguredBean)(nil)))
assert.False(suite.T(), overwritten)
assert.NoError(suite.T(), err)
overwritten, err = RegisterBean("singletonBean2", reflect.TypeOf((*SingletonConfiguredBean)(nil)))
assert.False(suite.T(), overwritten)
assert.NoError(suite.T(), err)
overwritten, err = RegisterBean("innerConfigBean", reflect.TypeOf((*SingletonConfiguringInnerBean)(nil)))
assert.False(suite.T(), overwritten)
assert.NoError(suite.T(), err)
overwritten, err = RegisterBean("innerBean", reflect.TypeOf((*SingletonInnerBean)(nil)))
assert.False(suite.T(), overwritten)
assert.NoError(suite.T(), err)
RegisterBeanConfiguration(&Config1{Conf: 11})
RegisterBeanConfiguration(&Config2{Conf: 22})
err = InitializeContainer()
assert.NoError(suite.T(), err)
instance, err := GetInstanceSafe("singletonBean1")
assert.NoError(suite.T(), err)
assert.NotNil(suite.T(), instance.(*SingletonConfiguredBean).innerConfigBean)
assert.NotNil(suite.T(), instance.(*SingletonConfiguredBean).innerBean)
assert.Equal(suite.T(), 11, instance.(*SingletonConfiguredBean).value)
assert.Equal(suite.T(), 22, instance.(*SingletonConfiguredBean).innerConfigBean.value)
assert.Equal(suite.T(), true, instance.(*SingletonConfiguredBean).postConstructCalled)

instance, err = GetInstanceSafe("singletonBean2")
assert.NoError(suite.T(), err)
assert.NotNil(suite.T(), instance.(*SingletonConfiguredBean).innerConfigBean)
assert.NotNil(suite.T(), instance.(*SingletonConfiguredBean).innerBean)
assert.Equal(suite.T(), 11, instance.(*SingletonConfiguredBean).value)
assert.Equal(suite.T(), 22, instance.(*SingletonConfiguredBean).innerConfigBean.value)
assert.Equal(suite.T(), true, instance.(*SingletonConfiguredBean).postConstructCalled)
}