Skip to content

Commit

Permalink
feat: table prefix (#47)
Browse files Browse the repository at this point in the history
Signed-off-by: abingcbc <[email protected]>
  • Loading branch information
Abingcbc authored Jan 20, 2022
1 parent 2ea99d3 commit a77f892
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 60 deletions.
34 changes: 22 additions & 12 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type Adapter struct {
dbSpecified bool
isFiltered bool
engine *xorm.Engine
tablePrefix string
tableName string
}

Expand Down Expand Up @@ -112,11 +113,12 @@ func NewAdapter(driverName string, dataSourceName string, dbSpecified ...bool) (
}

// NewAdapterWithTableName .
func NewAdapterWithTableName(driverName string, dataSourceName string, tableName string, dbSpecified ...bool) (*Adapter, error) {
func NewAdapterWithTableName(driverName string, dataSourceName string, tableName string, tablePrefix string, dbSpecified ...bool) (*Adapter, error) {
a := &Adapter{
driverName: driverName,
dataSourceName: dataSourceName,
tableName: tableName,
tablePrefix: tablePrefix,
}

if len(dbSpecified) == 0 {
Expand Down Expand Up @@ -154,10 +156,11 @@ func NewAdapterByEngine(engine *xorm.Engine) (*Adapter, error) {
}

// NewAdapterByEngineWithTableName .
func NewAdapterByEngineWithTableName(engine *xorm.Engine, tableName string) (*Adapter, error) {
func NewAdapterByEngineWithTableName(engine *xorm.Engine, tableName string, tablePrefix string) (*Adapter, error) {
a := &Adapter{
engine: engine,
tableName: tableName,
engine: engine,
tableName: tableName,
tablePrefix: tablePrefix,
}

err := a.createTable()
Expand All @@ -168,6 +171,13 @@ func NewAdapterByEngineWithTableName(engine *xorm.Engine, tableName string) (*Ad
return a, nil
}

func (a *Adapter) getFullTableName() string {
if a.tablePrefix != "" {
return a.tablePrefix + "_" + a.tableName
}
return a.tableName
}

func (a *Adapter) createDatabase() error {
var err error
var engine *xorm.Engine
Expand Down Expand Up @@ -231,11 +241,11 @@ func (a *Adapter) open() error {
}

func (a *Adapter) createTable() error {
return a.engine.Sync2(&CasbinRule{tableName: a.tableName})
return a.engine.Sync2(&CasbinRule{tableName: a.getFullTableName()})
}

func (a *Adapter) dropTable() error {
return a.engine.DropTables(&CasbinRule{tableName: a.tableName})
return a.engine.DropTables(&CasbinRule{tableName: a.getFullTableName()})
}

func loadPolicyLine(line *CasbinRule, model model.Model) {
Expand Down Expand Up @@ -263,7 +273,7 @@ func loadPolicyLine(line *CasbinRule, model model.Model) {
func (a *Adapter) LoadPolicy(model model.Model) error {
lines := make([]*CasbinRule, 0, 64)

if err := a.engine.Table(&CasbinRule{tableName: a.tableName}).Find(&lines); err != nil {
if err := a.engine.Table(&CasbinRule{tableName: a.getFullTableName()}).Find(&lines); err != nil {
return err
}

Expand All @@ -275,7 +285,7 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
}

func (a *Adapter) genPolicyLine(ptype string, rule []string) *CasbinRule {
line := CasbinRule{PType: ptype, tableName: a.tableName}
line := CasbinRule{PType: ptype, tableName: a.getFullTableName()}

l := len(rule)
if l > 0 {
Expand Down Expand Up @@ -383,7 +393,7 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err

// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
line := CasbinRule{PType: ptype, tableName: a.tableName}
line := CasbinRule{PType: ptype, tableName: a.getFullTableName()}

idx := fieldIndex + len(fieldValues)
if fieldIndex <= 0 && idx > 0 {
Expand Down Expand Up @@ -417,7 +427,7 @@ func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) erro
}

lines := make([]*CasbinRule, 0, 64)
if err := a.filterQuery(a.engine.NewSession(), filterValue).Table(&CasbinRule{tableName: a.tableName}).Find(&lines); err != nil {
if err := a.filterQuery(a.engine.NewSession(), filterValue).Table(&CasbinRule{tableName: a.getFullTableName()}).Find(&lines); err != nil {
return err
}

Expand Down Expand Up @@ -516,7 +526,7 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
for _, newRule := range newPolicies {
newP = append(newP, *a.genPolicyLine(ptype, newRule))
}
tx := a.engine.NewSession()
tx := a.engine.NewSession().Table(&CasbinRule{tableName: a.getFullTableName()})
defer tx.Close()

if err := tx.Begin(); err != nil {
Expand All @@ -528,7 +538,7 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
if err := tx.Where(str, args...).Find(&oldP); err != nil {
return nil, tx.Rollback()
}
if _, err := tx.Where(str.(string), args...).Delete(CasbinRule{}); err != nil {
if _, err := tx.Where(str.(string), args...).Delete(&CasbinRule{tableName: a.getFullTableName()}); err != nil {
return nil, tx.Rollback()
}
if _, err := tx.Insert(&newP[i]); err != nil {
Expand Down
91 changes: 43 additions & 48 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
package xormadapter

import (
"github.com/casbin/casbin/v2/util"
"log"
"strings"
"testing"

"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
)
Expand All @@ -45,20 +45,15 @@ func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
}
}

func initPolicy(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
func initPolicy(t *testing.T, a *Adapter) {
// Because the DB is empty at first,
// so we need to load the policy from the file adapter (.CSV) first.
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", "examples/rbac_policy.csv")

a, err := NewAdapter(driverName, dataSourceName, dbSpecified...)
if err != nil {
panic(err)
}

// This is a trick to save the current policy to the DB.
// We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
// The current policy means the policy in the Casbin enforcer (aka in memory).
err = a.SavePolicy(e.GetModel())
err := a.SavePolicy(e.GetModel())
if err != nil {
panic(err)
}
Expand All @@ -75,30 +70,28 @@ func initPolicy(t *testing.T, driverName string, dataSourceName string, dbSpecif
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}

func testSaveLoad(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
func testSaveLoad(t *testing.T, a *Adapter) {
// Initialize some policy in DB.
initPolicy(t, driverName, dataSourceName, dbSpecified...)
initPolicy(t, a)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.

// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}

func testAutoSave(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
func testAutoSave(t *testing.T, a *Adapter) {
// Initialize some policy in DB.
initPolicy(t, driverName, dataSourceName, dbSpecified...)
initPolicy(t, a)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.

// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)

// AutoSave is enabled by default.
Expand Down Expand Up @@ -152,16 +145,15 @@ func testAutoSave(t *testing.T, driverName string, dataSourceName string, dbSpec
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
}

func testFilteredPolicy(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
func testFilteredPolicy(t *testing.T, a *Adapter) {
// Initialize some policy in DB.
initPolicy(t, driverName, dataSourceName, dbSpecified...)
initPolicy(t, a)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.

// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")
// Now set the adapter
e.SetAdapter(a)
Expand Down Expand Up @@ -194,16 +186,15 @@ func testFilteredPolicy(t *testing.T, driverName string, dataSourceName string,
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
}

func testRemovePolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
func testRemovePolicies(t *testing.T, a *Adapter) {
// Initialize some policy in DB.
initPolicy(t, driverName, dataSourceName, dbSpecified...)
initPolicy(t, a)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.

// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")

// Now set the adapter
Expand Down Expand Up @@ -236,16 +227,15 @@ func testRemovePolicies(t *testing.T, driverName string, dataSourceName string,
testGetPolicy(t, e, [][]string{{"max", "data1", "delete"}})
}

func testAddPolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
func testAddPolicies(t *testing.T, a *Adapter) {
// Initialize some policy in DB.
initPolicy(t, driverName, dataSourceName, dbSpecified...)
initPolicy(t, a)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.

// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")

// Now set the adapter
Expand All @@ -268,16 +258,15 @@ func testAddPolicies(t *testing.T, driverName string, dataSourceName string, dbS
testGetPolicy(t, e, [][]string{{"max", "data2", "read"}, {"max", "data1", "write"}})
}

func testUpdatePolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
func testUpdatePolicies(t *testing.T, a *Adapter) {
// Initialize some policy in DB.
initPolicy(t, driverName, dataSourceName, dbSpecified...)
initPolicy(t, a)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.

// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")

// Now set the adapter
Expand All @@ -301,16 +290,15 @@ func testUpdatePolicies(t *testing.T, driverName string, dataSourceName string,
testGetPolicy(t, e, [][]string{{"bob", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}

func testUpdateFilteredPolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
func testUpdateFilteredPolicies(t *testing.T, a *Adapter) {
// Initialize some policy in DB.
initPolicy(t, driverName, dataSourceName, dbSpecified...)
initPolicy(t, a)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.

// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")

// Now set the adapter
Expand Down Expand Up @@ -370,23 +358,30 @@ func TestAdapters(t *testing.T) {
// You can also use the following way to use an existing DB "abc":
// testSaveLoad(t, "mysql", "root:@tcp(127.0.0.1:3306)/abc", true)

testSaveLoad(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
testSaveLoad(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")

testAutoSave(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
testAutoSave(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")

testFilteredPolicy(t, "mysql", "root:@tcp(127.0.0.1:3306)/")

testAddPolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
testAddPolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")

testRemovePolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
testRemovePolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")

testUpdatePolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
testUpdatePolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")

testUpdateFilteredPolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
testUpdateFilteredPolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
a, _ := NewAdapter("mysql", "root:@tcp(127.0.0.1:3306)/")
testSaveLoad(t, a)
testAutoSave(t, a)
testFilteredPolicy(t, a)
testAddPolicies(t, a)
testRemovePolicies(t, a)
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)

a, _ = NewAdapter("postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
testSaveLoad(t, a)
testAutoSave(t, a)
testFilteredPolicy(t, a)
testAddPolicies(t, a)
testRemovePolicies(t, a)
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)

a, _ = NewAdapterWithTableName("mysql", "root:@tcp(127.0.0.1:3306)/", "test", "abc")
testSaveLoad(t, a)
testAutoSave(t, a)
testFilteredPolicy(t, a)
testAddPolicies(t, a)
testRemovePolicies(t, a)
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)
}

0 comments on commit a77f892

Please sign in to comment.