Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust: Data flow through tuple and struct fields #18131

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 131 additions & 21 deletions rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ module Node {
* Nodes corresponding to AST elements, for example `ExprNode`, usually refer
* to the value before the update.
*/
final class PostUpdateNode extends Node, TArgumentPostUpdateNode {
final class PostUpdateNode extends Node, TExprPostUpdateNode {
private ExprCfgNode n;

PostUpdateNode() { this = TArgumentPostUpdateNode(n) }
PostUpdateNode() { this = TExprPostUpdateNode(n) }

/** Gets the node before the state update. */
Node getPreUpdateNode() { result = TExprNode(n) }
Expand Down Expand Up @@ -449,6 +449,49 @@ private class VariantFieldContent extends VariantContent, TVariantFieldContent {
}
}

/** A canonical path pointing to a struct. */
private class StructCanonicalPath extends MkStructCanonicalPath {
CrateOriginOption crate;
string path;

StructCanonicalPath() { this = MkStructCanonicalPath(crate, path) }

/** Gets the underlying struct. */
Struct getStruct() { hasExtendedCanonicalPath(result, crate, path) }

string toString() { result = this.getStruct().getName().getText() }

Location getLocation() { result = this.getStruct().getLocation() }
}

/** Content stored in a field on a struct. */
private class StructFieldContent extends VariantContent, TStructFieldContent {
private StructCanonicalPath s;
private string field_;

StructFieldContent() { this = TStructFieldContent(s, field_) }

StructCanonicalPath getStructCanonicalPath(string field) { result = s and field = field_ }

override string toString() { result = s.toString() + "." + field_.toString() }
}

/**
* Content stored at a position in a tuple.
*
* NOTE: Unlike `struct`s and `enum`s tuples are structural and not nominal,
* hence we don't store a canonical path for them.
*/
private class TuplePositionContent extends VariantContent, TTuplePositionContent {
private int pos;

TuplePositionContent() { this = TTuplePositionContent(pos) }

int getPosition() { result = pos }

override string toString() { result = "tuple." + pos.toString() }
}

/** A value that represents a set of `Content`s. */
abstract class ContentSet extends TContentSet {
/** Gets a textual representation of this element. */
Expand Down Expand Up @@ -608,6 +651,14 @@ module RustDataFlow implements InputSig<Location> {
*/
predicate jumpStep(Node node1, Node node2) { none() }

/** Holds if path `p` resolves to struct `s`. */
private predicate pathResolveToStructCanonicalPath(Path p, StructCanonicalPath s) {
exists(CrateOriginOption crate, string path |
resolveExtendedCanonicalPath(p, crate, path) and
s = MkStructCanonicalPath(crate, path)
)
}

/** Holds if path `p` resolves to variant `v`. */
private predicate pathResolveToVariantCanonicalPath(Path p, VariantCanonicalPath v) {
exists(CrateOriginOption crate, string path |
Expand Down Expand Up @@ -636,6 +687,12 @@ module RustDataFlow implements InputSig<Location> {
pathResolveToVariantCanonicalPath(p.getPath(), v)
}

/** Holds if `p` destructs an struct `s`. */
pragma[nomagic]
private predicate structDestruction(RecordPat p, StructCanonicalPath s) {
pathResolveToStructCanonicalPath(p.getPath(), s)
}

/**
* Holds if data can flow from `node1` to `node2` via a read of `c`. Thus,
* `node1` references an object with a content `c.getAReadContent()` whose
Expand All @@ -652,10 +709,24 @@ module RustDataFlow implements InputSig<Location> {
or
exists(RecordPatCfgNode pat, string field |
pat = node1.asPat() and
recordVariantDestruction(pat.getPat(),
c.(VariantFieldContent).getVariantCanonicalPath(field)) and
(
// Pattern destructs a struct-like variant.
recordVariantDestruction(pat.getPat(),
c.(VariantFieldContent).getVariantCanonicalPath(field))
or
// Pattern destructs a struct.
structDestruction(pat.getPat(), c.(StructFieldContent).getStructCanonicalPath(field))
) and
node2.asPat() = pat.getFieldPat(field)
)
or
exists(FieldExprCfgNode access |
// Read of a tuple entry
access.getNameRef().getText().toInt() = c.(TuplePositionContent).getPosition() and
// TODO: Handle read of a struct field.
node1.asExpr() = access.getExpr() and
node2.asExpr() = access
)
)
}

Expand All @@ -671,30 +742,44 @@ module RustDataFlow implements InputSig<Location> {
pathResolveToVariantCanonicalPath(re.getPath(), v)
}

/** Holds if `re` constructs a struct value of type `v`. */
pragma[nomagic]
private predicate structConstruction(RecordExpr re, StructCanonicalPath s) {
pathResolveToStructCanonicalPath(re.getPath(), s)
}

/**
* Holds if data can flow from `node1` to `node2` via a store into `c`. Thus,
* `node2` references an object with a content `c.getAStoreContent()` that
* contains the value of `node1`.
*/
predicate storeStep(Node node1, ContentSet cs, Node node2) {
exists(Content c | c = cs.(SingletonContentSet).getContent() |
node2.asExpr() =
any(CallExprCfgNode call, int pos |
tupleVariantConstruction(call.getCallExpr(),
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
node1.asExpr() = call.getArgument(pos)
|
call
)
exists(CallExprCfgNode call, int pos |
tupleVariantConstruction(call.getCallExpr(),
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
node1.asExpr() = call.getArgument(pos) and
node2.asExpr() = call
)
or
node2.asExpr() =
any(RecordExprCfgNode re, string field |
exists(RecordExprCfgNode re, string field |
(
// Expression is for a struct-like enum variant.
recordVariantConstruction(re.getRecordExpr(),
c.(VariantFieldContent).getVariantCanonicalPath(field)) and
node1.asExpr() = re.getFieldExpr(field)
|
re
)
c.(VariantFieldContent).getVariantCanonicalPath(field))
or
// Expression is for a struct.
structConstruction(re.getRecordExpr(),
c.(StructFieldContent).getStructCanonicalPath(field))
) and
node1.asExpr() = re.getFieldExpr(field) and
node2.asExpr() = re
)
or
exists(TupleExprCfgNode tuple |
node1.asExpr() = tuple.getField(c.(TuplePositionContent).getPosition()) and
node2.asExpr() = tuple
)
)
}

Expand All @@ -703,7 +788,14 @@ module RustDataFlow implements InputSig<Location> {
* any value stored inside `f` is cleared at the pre-update node associated with `x`
* in `x.f = newValue`.
*/
predicate clearsContent(Node n, ContentSet c) { none() }
predicate clearsContent(Node n, ContentSet c) {
exists(AssignmentExprCfgNode assignment, FieldExprCfgNode access |
assignment.getLhs() = access and
n.asExpr() = access.getExpr() and
access.getNameRef().getText().toInt() =
c.(SingletonContentSet).getContent().(TuplePositionContent).getPosition()
)
}

/**
* Holds if the value that is being tracked is expected to be stored inside content `c`
Expand Down Expand Up @@ -773,7 +865,9 @@ private module Cached {
TExprNode(ExprCfgNode n) or
TParameterNode(ParamBaseCfgNode p) or
TPatNode(PatCfgNode p) or
TArgumentPostUpdateNode(ExprCfgNode e) { isArgumentForCall(e, _, _) } or
TExprPostUpdateNode(ExprCfgNode e) {
isArgumentForCall(e, _, _) or e = any(FieldExprCfgNode access).getExpr()
} or
TSsaNode(SsaImpl::DataFlowIntegration::SsaNode node)

cached
Expand Down Expand Up @@ -811,6 +905,12 @@ private module Cached {
name = ["Ok", "Err"]
}

cached
newtype TStructCanonicalPath =
MkStructCanonicalPath(CrateOriginOption crate, string path) {
exists(Struct s | hasExtendedCanonicalPath(s, crate, path))
}

cached
newtype TContent =
TVariantPositionContent(VariantCanonicalPath v, int pos) {
Expand All @@ -826,6 +926,16 @@ private module Cached {
} or
TVariantFieldContent(VariantCanonicalPath v, string field) {
field = v.getVariant().getFieldList().(RecordFieldList).getAField().getName().getText()
} or
TTuplePositionContent(int pos) {
pos in [0 .. max([
any(TuplePat pat).getNumberOfFields(),
any(FieldExpr access).getNameRef().getText().toInt()
]
)]
} or
TStructFieldContent(StructCanonicalPath s, string field) {
field = s.getStruct().getFieldList().(RecordFieldList).getAField().getName().getText()
}

cached
Expand Down
Loading