Skip to content

Commit

Permalink
add pointer analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
cokeBeer committed Sep 16, 2022
1 parent 1266d6d commit 8089bc4
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 47 deletions.
11 changes: 6 additions & 5 deletions cmd/taintanalysis/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ import (
func main() {
// the ../../ takes you back to root of the project
// and the ... means scan packages in package pkg recursively
runner := taint.NewRunner("github.com/gorilla/schema")
runner := taint.NewRunner("../../pkg/main")
// the module name is the name defined in go.mod
runner.ModuleName = "github.com/cokeBeer/goot"
runner.PassThroughSrcPath = []string{"gostd1.19.json", "additional.json"}
//runner.PassThroughSrcPath = []string{"gostd1.19.json", "additional.json"}
runner.PassThroughDstPath = "passthrough.json"
runner.CallGraphDstPath = "callgraph.json"
runner.TaintGraphDstPath = "taintgraph.json"
runner.UsePointerAnalysis = true
runner.PassThroughOnly = true
runner.InitOnly = false
runner.Debug = true
runner.PersistToNeo4j = true
runner.TargetFunc = "(*github.com/gorilla/schema.Decoder).Decode"
runner.PersistToNeo4j = false
runner.TargetFunc = ""
runner.Neo4jURI = "bolt://localhost:7687"
runner.Neo4jUsername = "neo4j"
runner.Neo4jPassword = "password"
Expand Down
1 change: 1 addition & 0 deletions cmd/taintanalysis/taintgraph.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
3 changes: 2 additions & 1 deletion pkg/example/dataflow/taint/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ All options are:
- `PassThroughOnly`(optional): when set true only do passthrough analysis, default `false`
- `PassThroughSrcPath`(optional): path to passthrough sources, you can use it to accelerate analysis or add additional passthrough, default `[]string{}`
- `PassThroughDstPath`(optional): path to save passthrough output, default `""`
- `CallGraphDstPath`(optional): path to save taint edge output, default `""`
- `TaintGraphDstPath`(optional): path to save taint edge output, default `""`
- `Ruler `(optional): ruler is interface that defines how to decide whether a node is sink, source or intra. You can implements it, default [DummyRuler](ruler.go)
- `PersistToNeo4j`(optional): when set true, save nodes and edges to neo4j, default `false`
- `Neo4jUsername`(optiosnal): neo4j usename, default `""`
- `Neo4jPassword`(optional): neo4j password, default `""`
- `Neo4jURI`(optional): neo4j uri, default `""`
- `TargetFunc`(optional): when set, only analysis target function and output its SSA, default `""`
- `UsePointerAnalysis`(optional): when set, use pointer analysis to help selecting callee, default `false`. ⚠️ note that if you set this true, the `PkgPath` option can only contain main packages
6 changes: 2 additions & 4 deletions pkg/example/dataflow/taint/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/cokeBeer/goot/pkg/dataflow/toolkits/solver"
"github.com/cokeBeer/goot/pkg/dataflow/util/entry"
"github.com/cokeBeer/goot/pkg/example/dataflow/taint/rule"
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/ssa"
)

Expand All @@ -22,8 +23,7 @@ type TaintAnalysis struct {
config *TaintConfig
passThroughContainer *map[string][][]int
initMap *map[string]*ssa.Function
interfaceHierarchy *InterfaceHierarchy
callGraph *CallGraph
callGraph *callgraph.Graph
ruler rule.Ruler
}

Expand Down Expand Up @@ -105,8 +105,6 @@ func New(g *graph.UnitGraph, c *TaintConfig) *TaintAnalysis {
taintAnalysis.passThroughContainer = c.PassThroughContainer
taintAnalysis.initMap = c.InitMap
taintAnalysis.passThrough = make([]*TaintWrapper, 0)
taintAnalysis.interfaceHierarchy = c.InterfaceHierarchy
taintAnalysis.callGraph = c.CallGraph
taintAnalysis.ruler = c.Ruler
f := taintAnalysis.Graph.Func

Expand Down
4 changes: 3 additions & 1 deletion pkg/example/dataflow/taint/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package taint

import (
"github.com/cokeBeer/goot/pkg/example/dataflow/taint/rule"
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/ssa"
)

Expand All @@ -11,7 +12,8 @@ type TaintConfig struct {
InitMap *map[string]*ssa.Function
History *map[string]bool
InterfaceHierarchy *InterfaceHierarchy
CallGraph *CallGraph
TaintGraph *TaintGraph
CallGraph *callgraph.Graph
Ruler rule.Ruler
PassThroughOnly bool
TargetFunc string
Expand Down
8 changes: 8 additions & 0 deletions pkg/example/dataflow/taint/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package taint

type NoMainPkgError struct {
}

func (e *NoMainPkgError) Error() string {
return "No main package found in runner.PkgPath"
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ import (
"golang.org/x/tools/go/ssa"
)

// CallGraph represents a graph contain static call nodes and edges
type CallGraph struct {
// TaintGraph represents a graph contain static call nodes and edges
type TaintGraph struct {
Nodes *map[string]*Node
Edges *map[string]*Edge
}

// NewCallGraph returns a CallGraph
func NewCallGraph(allFuncs *map[*ssa.Function]bool, ruler rule.Ruler) *CallGraph {
callGraph := new(CallGraph)
// NewTaintGraph returns a TaintGraph
func NewTaintGraph(allFuncs *map[*ssa.Function]bool, ruler rule.Ruler) *TaintGraph {
callGraph := new(TaintGraph)
nodes := make(map[string]*Node)
edges := make(map[string]*Edge)
callGraph.Nodes = &nodes
Expand Down
45 changes: 38 additions & 7 deletions pkg/example/dataflow/taint/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package taint

import (
"github.com/cokeBeer/goot/pkg/example/dataflow/taint/rule"
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/pointer"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)
Expand All @@ -11,12 +13,13 @@ import (
type Runner struct {
ModuleName string
PkgPath []string
UsePointerAnalysis bool
Debug bool
InitOnly bool
PassThroughOnly bool
PassThroughSrcPath []string
PassThroughDstPath string
CallGraphDstPath string
TaintGraphDstPath string
Ruler rule.Ruler
PersistToNeo4j bool
Neo4jUsername string
Expand All @@ -30,10 +33,11 @@ type Runner struct {
func NewRunner(PkgPath ...string) *Runner {
return &Runner{PkgPath: PkgPath, ModuleName: "",
PassThroughSrcPath: nil, PassThroughDstPath: "",
CallGraphDstPath: "", Ruler: nil,
TaintGraphDstPath: "", Ruler: nil,
Debug: false, InitOnly: false, PassThroughOnly: false,
PersistToNeo4j: false, Neo4jURI: "", Neo4jUsername: "", Neo4jPassword: "",
TargetFunc: "", PassBack: false}
TargetFunc: "", PassBack: false,
UsePointerAnalysis: false}
}

// Run kick off an analysis
Expand All @@ -59,13 +63,39 @@ func (r *Runner) Run() error {
funcs := ssautil.AllFunctions(prog)

interfaceHierarchy := NewInterfaceHierarchy(&funcs)

var callGraph *callgraph.Graph
if r.UsePointerAnalysis {
mainPkgs := make([]*ssa.Package, 0)
for _, pkg := range initial {
mainPkg := prog.Package(pkg.Types)
if mainPkg != nil && mainPkg.Pkg.Name() == "main" && mainPkg.Func("main") != nil {
mainPkgs = append(mainPkgs, mainPkg)
}
}
if len(mainPkgs) == 0 {
return new(NoMainPkgError)
}
config := &pointer.Config{
Mains: mainPkgs,
BuildCallGraph: true,
}

result, err := pointer.Analyze(config)
if err != nil {
return err
}
callGraph = result.CallGraph
callGraph.DeleteSyntheticNodes()
}

var ruler rule.Ruler
if r.Ruler != nil {
ruler = r.Ruler
} else {
ruler = NewDummyRuler(r.ModuleName)
}
callGraph := NewCallGraph(&funcs, ruler)
taintGraph := NewTaintGraph(&funcs, ruler)

passThroughContainter := make(map[string][][]int)
if r.PassThroughSrcPath != nil {
Expand All @@ -79,6 +109,7 @@ func (r *Runner) Run() error {
InitMap: &initMap,
History: &history,
InterfaceHierarchy: interfaceHierarchy,
TaintGraph: taintGraph,
CallGraph: callGraph,
Ruler: ruler,
PassThroughOnly: r.PassThroughOnly,
Expand Down Expand Up @@ -106,11 +137,11 @@ func (r *Runner) Run() error {
if r.PassThroughDstPath != "" {
PersistPassThrough(&passThroughContainter, r.PassThroughDstPath)
}
if r.CallGraphDstPath != "" {
PersistCallGraph(callGraph.Edges, r.CallGraphDstPath)
if r.TaintGraphDstPath != "" {
PersistCallGraph(taintGraph.Edges, r.TaintGraphDstPath)
}
if !r.PassThroughOnly && r.PersistToNeo4j {
PersistToNeo4j(callGraph.Nodes, callGraph.Edges, r.Neo4jURI, r.Neo4jUsername, r.Neo4jPassword)
PersistToNeo4j(taintGraph.Nodes, taintGraph.Edges, r.Neo4jURI, r.Neo4jUsername, r.Neo4jPassword)
}
return nil
}
68 changes: 44 additions & 24 deletions pkg/example/dataflow/taint/switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ func (s *TaintSwitcher) CaseCall(inst *ssa.Call) {
c := s.taintAnalysis.config
container := c.PassThroughContainer
init := s.taintAnalysis.initMap
// try to use pointer analysis to select callee
callGraph := s.taintAnalysis.config.CallGraph
if callGraph != nil && inst.Common().StaticCallee() == nil {
node := callGraph.Nodes[inst.Parent()]
if node != nil {
for _, edge := range node.Out {
if edge.Site == inst {
if inst.Call.Method != nil {
// invoke
s.passMethodTaint(edge.Callee.Func, inst)
} else {
// anonymous function
s.passStaticCallTaint(edge.Callee.Func, inst)
}
return
}
}
}
}
// try to use CHA to select callee
switch v := (inst.Call.Value).(type) {
case *ssa.Field:
// caller can be a field from a struct
Expand Down Expand Up @@ -774,7 +794,7 @@ func (s *TaintSwitcher) passInvokeTaint(f *types.Func, inst *ssa.Call) {
if !s.taintAnalysis.config.PassThroughOnly {
s.collectMethodEdges(f, inst)
}
interfaceHierarchy := s.taintAnalysis.interfaceHierarchy
interfaceHierarchy := s.taintAnalysis.config.InterfaceHierarchy
tiface := inst.Call.Value.Type().Underlying().(*types.Interface)
methods := interfaceHierarchy.LookupMethods(tiface, f)
if len(methods) != 0 {
Expand Down Expand Up @@ -892,7 +912,7 @@ func (s *TaintSwitcher) passFuncParamTaint(signature *types.Signature, inst *ssa
if !s.taintAnalysis.config.PassThroughOnly {
s.collectSignatureEdges(signature, inst)
}
interfaceHierarchy := s.taintAnalysis.interfaceHierarchy
interfaceHierarchy := s.taintAnalysis.config.InterfaceHierarchy
funcs := interfaceHierarchy.LookupFuncs(signature)
if len(funcs) != 0 {
s.passStaticCallTaint(funcs[0], inst)
Expand Down Expand Up @@ -928,7 +948,7 @@ func (s *TaintSwitcher) passCopyTaint(inst *ssa.Call) {
}

func (s *TaintSwitcher) collectCallEdges(f *ssa.Function, inst *ssa.Call) {
callGraph := s.taintAnalysis.callGraph
taintGraph := s.taintAnalysis.config.TaintGraph
if s.taintAnalysis.Graph.Func.Name() == "init" {
return
}
Expand All @@ -939,13 +959,13 @@ func (s *TaintSwitcher) collectCallEdges(f *ssa.Function, inst *ssa.Call) {
edge := Edge{From: s.taintAnalysis.Graph.Func.String(), FromIndex: k, To: f.String(), ToIndex: i}
key := s.taintAnalysis.Graph.Func.String() + "#" + strconv.Itoa(k)
key2 := f.String() + "#" + strconv.Itoa(i)
node := (*callGraph.Nodes)[key]
node2 := (*callGraph.Nodes)[key2]
node := (*taintGraph.Nodes)[key]
node2 := (*taintGraph.Nodes)[key2]
if node.IsIntra {
if _, ok := (*callGraph.Edges)[key+"#"+key2]; ok {
if _, ok := (*taintGraph.Edges)[key+"#"+key2]; ok {
continue
} else {
(*callGraph.Edges)[key+"#"+key2] = &edge
(*taintGraph.Edges)[key+"#"+key2] = &edge
}
node.Out = append(node.Out, &edge)
node2.In = append(node2.In, &edge)
Expand All @@ -961,30 +981,30 @@ func (s *TaintSwitcher) collectCallEdges(f *ssa.Function, inst *ssa.Call) {
func (s *TaintSwitcher) collectMethodEdges(f *types.Func, inst *ssa.Call) {
signature, ok := f.Type().(*types.Signature)
ruler := s.taintAnalysis.ruler
callGraph := s.taintAnalysis.callGraph
taintGraph := s.taintAnalysis.config.TaintGraph
if ok {
for name := range *GetTaint(s.outMap, inst.Call.Value.Name()) {
for k, v := range s.taintAnalysis.Graph.Func.Params {
if v.Name() == name {
edge := Edge{From: s.taintAnalysis.Graph.Func.String(), FromIndex: k, To: f.String(), ToIndex: 0}
key := s.taintAnalysis.Graph.Func.String() + "#" + strconv.Itoa(k)
key2 := f.String() + "#" + strconv.Itoa(0)
node := (*callGraph.Nodes)[key]
node := (*taintGraph.Nodes)[key]
if node.IsIntra {
node.Out = append(node.Out, &edge)
if _, ok := (*callGraph.Edges)[key+"#"+key2]; ok {
if _, ok := (*taintGraph.Edges)[key+"#"+key2]; ok {
continue
} else {
(*callGraph.Edges)[key+"#"+key2] = &edge
(*taintGraph.Edges)[key+"#"+key2] = &edge
}
if node2, ok := (*callGraph.Nodes)[key2]; ok {
if node2, ok := (*taintGraph.Nodes)[key2]; ok {
node2.In = append(node2.In, &edge)
passProperty(node2, &edge)
} else {
node2 := &Node{Canonical: signature.String(), Index: 0, Out: make([]*Edge, 0), In: make([]*Edge, 0), IsSignature: false, IsMethod: true, IsStatic: false}
decidePropertry(node2, ruler)
node2.In = append(node2.In, &edge)
(*callGraph.Nodes)[f.String()] = node2
(*taintGraph.Nodes)[f.String()] = node2
passProperty(node2, &edge)
}
}
Expand All @@ -999,22 +1019,22 @@ func (s *TaintSwitcher) collectMethodEdges(f *types.Func, inst *ssa.Call) {
edge := Edge{From: s.taintAnalysis.Graph.Func.String(), FromIndex: k, To: f.String(), ToIndex: i + 1}
key := s.taintAnalysis.Graph.Func.String() + "#" + strconv.Itoa(k)
key2 := f.String() + "#" + strconv.Itoa(0)
node := (*callGraph.Nodes)[key]
node := (*taintGraph.Nodes)[key]
if node.IsIntra {
node.Out = append(node.Out, &edge)
if _, ok := (*callGraph.Edges)[key+"#"+key2]; ok {
if _, ok := (*taintGraph.Edges)[key+"#"+key2]; ok {
continue
} else {
(*callGraph.Edges)[key+"#"+key2] = &edge
(*taintGraph.Edges)[key+"#"+key2] = &edge
}
if node2, ok := (*callGraph.Nodes)[key2]; ok {
if node2, ok := (*taintGraph.Nodes)[key2]; ok {
node2.In = append(node2.In, &edge)
passProperty(node2, &edge)
} else {
node2 := &Node{Canonical: signature.String(), Index: 0, Out: make([]*Edge, 0), In: make([]*Edge, 0), IsSignature: false, IsMethod: true, IsStatic: false}
decidePropertry(node2, ruler)
node2.In = append(node2.In, &edge)
(*callGraph.Nodes)[f.String()] = node2
(*taintGraph.Nodes)[f.String()] = node2
passProperty(node2, &edge)
}
}
Expand All @@ -1028,7 +1048,7 @@ func (s *TaintSwitcher) collectMethodEdges(f *types.Func, inst *ssa.Call) {
// collectSignatureEdges records node only use signature information
func (s *TaintSwitcher) collectSignatureEdges(signature *types.Signature, inst *ssa.Call) {
ruler := s.taintAnalysis.ruler
callGraph := s.taintAnalysis.callGraph
taintGraph := s.taintAnalysis.config.TaintGraph
n := signature.Params().Len()
for i := 0; i < n; i++ {
for name := range *GetTaint(s.outMap, inst.Call.Args[i].Name()) {
Expand All @@ -1037,22 +1057,22 @@ func (s *TaintSwitcher) collectSignatureEdges(signature *types.Signature, inst *
edge := Edge{From: s.taintAnalysis.Graph.Func.String(), FromIndex: k, To: signature.String(), ToIndex: i}
key := s.taintAnalysis.Graph.Func.String() + "#" + strconv.Itoa(k)
key2 := signature.String() + "#" + strconv.Itoa(0)
node := (*callGraph.Nodes)[key]
node := (*taintGraph.Nodes)[key]
if node.IsIntra {
node.Out = append(node.Out, &edge)
if _, ok := (*callGraph.Edges)[key+"#"+key2]; ok {
if _, ok := (*taintGraph.Edges)[key+"#"+key2]; ok {
continue
} else {
(*callGraph.Edges)[key+"#"+key2] = &edge
(*taintGraph.Edges)[key+"#"+key2] = &edge
}
if node2, ok := (*callGraph.Nodes)[key2]; ok {
if node2, ok := (*taintGraph.Nodes)[key2]; ok {
node2.In = append(node2.In, &edge)
passProperty(node2, &edge)
} else {
node2 := &Node{Canonical: signature.String(), Index: 0, Out: make([]*Edge, 0), In: make([]*Edge, 0), IsSignature: true, IsMethod: false, IsStatic: false}
decidePropertry(node2, ruler)
node2.In = append(node2.In, &edge)
(*callGraph.Nodes)[signature.String()] = node2
(*taintGraph.Nodes)[signature.String()] = node2
passProperty(node2, &edge)
}
}
Expand Down

0 comments on commit 8089bc4

Please sign in to comment.