From 474e2bee487b63d155d28eef95dca9a260dcebe4 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Fri, 22 Nov 2024 14:53:45 +0100 Subject: [PATCH] core/state: refactor journal-fuzzer core/state: work on fuzztest for journals/state --- core/state/journal_test.go | 395 ++++++++++++++++++++++--------------- 1 file changed, 236 insertions(+), 159 deletions(-) diff --git a/core/state/journal_test.go b/core/state/journal_test.go index 22f6b7787bb9f..4ec36188e9b9c 100644 --- a/core/state/journal_test.go +++ b/core/state/journal_test.go @@ -18,8 +18,12 @@ package state import ( + "bytes" + "crypto/rand" + "encoding/binary" "fmt" - "math/rand/v2" + "io" + "slices" "testing" "github.com/ethereum/go-ethereum/common" @@ -135,173 +139,246 @@ func testJournalRefunds(t *testing.T, j journal) { } } -func FuzzJournals(f *testing.F) { +type fuzzReader struct { + input io.Reader + exhausted bool +} - randByte := func() byte { - return byte(rand.Int()) - } - randBool := func() bool { - return rand.Int()%2 == 0 - } - randAccount := func() *types.StateAccount { - return &types.StateAccount{ - Nonce: uint64(randByte()), - Balance: uint256.NewInt(uint64(randByte())), - Root: types.EmptyRootHash, - CodeHash: types.EmptyCodeHash[:], - } +func (f *fuzzReader) byte() byte { + return f.bytes(1)[0] +} + +func (f *fuzzReader) bytes(n int) []byte { + r := make([]byte, n) + if _, err := f.input.Read(r); err != nil { + f.exhausted = true } + return r +} - f.Fuzz(func(t *testing.T, operations []byte) { - var ( - statedb1, _ = New(types.EmptyRootHash, NewDatabaseForTesting()) - statedb2, _ = New(types.EmptyRootHash, NewDatabaseForTesting()) - linear = newLinearJournal() - sparse = newSparseJournal() - ) - statedb1.journal = linear - statedb2.journal = sparse - linear.snapshot() - sparse.snapshot() +func newEmptyState() *StateDB { + s, _ := New(types.EmptyRootHash, NewDatabaseForTesting()) + return s +} - for _, o := range operations { - switch o { - case 0: - addr := randByte() - 
linear.accessListAddAccount(common.Address{addr}) - sparse.accessListAddAccount(common.Address{addr}) - statedb1.accessList.AddAddress(common.Address{addr}) - statedb2.accessList.AddAddress(common.Address{addr}) - case 1: - addr := randByte() - slot := randByte() - linear.accessListAddSlot(common.Address{addr}, common.Hash{slot}) - sparse.accessListAddSlot(common.Address{addr}, common.Hash{slot}) - statedb1.accessList.AddSlot(common.Address{addr}, common.Hash{slot}) - statedb2.accessList.AddSlot(common.Address{addr}, common.Hash{slot}) - case 2: - addr := randByte() - account := randAccount() - destructed := randBool() - newContract := randBool() - linear.balanceChange(common.Address{addr}, account, destructed, newContract) - sparse.balanceChange(common.Address{addr}, account, destructed, newContract) - case 3: - linear = linear.copy().(*linearJournal) - sparse = sparse.copy().(*sparseJournal) - case 4: - addr := randByte() - account := randAccount() - linear.createContract(common.Address{addr}, account) - sparse.createContract(common.Address{addr}, account) - case 5: - addr := randByte() - linear.createObject(common.Address{addr}) - sparse.createObject(common.Address{addr}) - case 6: - addr := randByte() - account := randAccount() - linear.destruct(common.Address{addr}, account) - sparse.destruct(common.Address{addr}, account) - case 7: - txHash := randByte() - linear.logChange(common.Hash{txHash}) - sparse.logChange(common.Hash{txHash}) - case 8: - addr := randByte() - account := randAccount() - destructed := randBool() - newContract := randBool() - linear.nonceChange(common.Address{addr}, account, destructed, newContract) - sparse.nonceChange(common.Address{addr}, account, destructed, newContract) - case 9: - refund := randByte() - linear.refundChange(uint64(refund)) - sparse.refundChange(uint64(refund)) - case 10: - addr := randByte() - account := randAccount() - linear.setCode(common.Address{addr}, account) - sparse.setCode(common.Address{addr}, account) - 
case 11: - addr := randByte() - key := randByte() - prev := randByte() - origin := randByte() - linear.storageChange(common.Address{addr}, common.Hash{key}, common.Hash{prev}, common.Hash{origin}) - sparse.storageChange(common.Address{addr}, common.Hash{key}, common.Hash{prev}, common.Hash{origin}) - case 12: - addr := randByte() - account := randAccount() - destructed := randBool() - newContract := randBool() - linear.touchChange(common.Address{addr}, account, destructed, newContract) - sparse.touchChange(common.Address{addr}, account, destructed, newContract) - case 13: - addr := randByte() - key := randByte() - prev := randByte() - linear.transientStateChange(common.Address{addr}, common.Hash{key}, common.Hash{prev}) - sparse.transientStateChange(common.Address{addr}, common.Hash{key}, common.Hash{prev}) - case 14: - linear.reset() - sparse.reset() - case 15: - linear.snapshot() - sparse.snapshot() - case 16: - linear.discardSnapshot() - sparse.discardSnapshot() - case 17: - linear.revertSnapshot(statedb1) - sparse.revertSnapshot(statedb2) - case 18: - accs1 := linear.dirtyAccounts() - accs2 := linear.dirtyAccounts() - if len(accs1) != len(accs2) { - panic(fmt.Sprintf("mismatched accounts: %v %v", accs1, accs2)) +// fuzzJournals is pretty similar to `TestSnapshotRandom`/ `newTestAction` in +// statedb_test.go. They both execute a sequence of state-actions, however, they +// test for different aspects. +// This test compares two differing journal-implementations. +// The other test compares every point in time, whether it is identical when going +// forward as when going backwards through the journal entries. 
+func fuzzJournals(t *testing.T, data []byte) { + var ( + reader = fuzzReader{input: bytes.NewReader(data)} + stateDbs = []*StateDB{ + newEmptyState(), + newEmptyState(), + } + ) + apply := func(action func(stateDbs *StateDB)) { + for _, sdb := range stateDbs { + action(sdb) + } + } + stateDbs[0].journal = newLinearJournal() + stateDbs[1].journal = newSparseJournal() + for !reader.exhausted { + op := reader.byte() % 18 + switch op { + case 0: // Add account to access lists + addr := common.BytesToAddress(reader.bytes(1)) + t.Logf("Op %d: Add to access list %#x", op, addr) + apply(func(sdb *StateDB) { + sdb.accessList.AddAddress(addr) + }) + case 1: // Add slot to access list + addr := common.BytesToAddress(reader.bytes(1)) + slot := common.BytesToHash(reader.bytes(1)) + t.Logf("Op %d: Add addr:slot to access list %#x : %#x", op, addr, slot) + apply(func(sdb *StateDB) { + sdb.AddSlotToAccessList(addr, slot) + }) + case 2: + var ( + addr = common.BytesToAddress(reader.bytes(1)) + value = uint64(reader.byte()) + ) + t.Logf("Op %d: Add balance %#x %d", op, addr, value) + apply(func(sdb *StateDB) { + sdb.AddBalance(addr, uint256.NewInt(value), 0) + }) + case 3: + t.Logf("Op %d: Copy journals[0]", op) + stateDbs[0].journal = stateDbs[0].journal.copy() + case 4: + t.Logf("Op %d: Copy journals[1]", op) + stateDbs[1].journal = stateDbs[1].journal.copy() + case 5: + var ( + addr = common.BytesToAddress(reader.bytes(1)) + code = reader.bytes(2) + ) + t.Logf("Op %d: (Create and) set code 0x%x", op, addr) + apply(func(s *StateDB) { + if !s.Exist(addr) { + s.CreateAccount(addr) } - for _, val := range accs1 { - found := false - for _, val2 := range accs2 { - if val == val2 { - if found { - panic(fmt.Sprintf("account found twice: %v %v account %v", accs1, accs2, val)) - } - found = true - } - } - if !found { - panic(fmt.Sprintf("missing account: %v %v account %v", accs1, accs2, val)) + contractHash := s.GetCodeHash(addr) + emptyCode := contractHash == (common.Hash{}) || 
contractHash == types.EmptyCodeHash + storageRoot := s.GetStorageRoot(addr) + emptyStorage := storageRoot == (common.Hash{}) || storageRoot == types.EmptyRootHash + + if obj := s.getStateObject(addr); obj != nil { + if obj.selfDestructed { + // If it's selfdestructed, we cannot create into it + return } } - } - } - // After all operations have been processed, verify equality - accs1 := linear.dirtyAccounts() - accs2 := linear.dirtyAccounts() - for _, val := range accs1 { - found := false - for _, val2 := range accs2 { - if val == val2 { - if found { - panic(fmt.Sprintf("account found twice: %v %v account %v", accs1, accs2, val)) - } - found = true + if s.GetNonce(addr) == 0 && emptyCode && emptyStorage { + s.CreateContract(addr) + // We also set some code here, to prevent the + // CreateContract action from being performed twice in a row, + // which would cause a difference in state when unrolling + // the linearJournal. (CreateContract assumes created was false prior to + // invocation, and the linearJournal rollback sets it to false). 
+ s.SetCode(addr, code) } - } - if !found { - panic(fmt.Sprintf("missing account: %v %v account %v", accs1, accs2, val)) - } - } - h1, err1 := statedb1.Commit(0, false) - h2, err2 := statedb2.Commit(0, false) - if err1 != err2 { - panic(fmt.Sprintf("mismatched errors: %v %v", err1, err2)) + }) + case 6: + addr := common.BytesToAddress(reader.bytes(1)) + t.Logf("Op %d: Create 0x%x", op, addr) + apply(func(sdb *StateDB) { + if !sdb.Exist(addr) { + sdb.CreateAccount(addr) + } + }) + case 7: + addr := common.BytesToAddress(reader.bytes(1)) + t.Logf("Op %d: (Create and) destruct 0x%x", op, addr) + apply(func(s *StateDB) { + if !s.Exist(addr) { + s.CreateAccount(addr) + } + s.SelfDestruct(addr) + }) + case 8: + txHash := common.BytesToHash(reader.bytes(1)) + t.Logf("Op %d: Add log %#x", op, txHash) + apply(func(sdb *StateDB) { + sdb.logs[txHash] = append(sdb.logs[txHash], new(types.Log)) + sdb.logSize++ + sdb.journal.logChange(txHash) + }) + case 9: + var ( + addr = common.BytesToAddress(reader.bytes(1)) + nonce = binary.BigEndian.Uint64(reader.bytes(8)) + ) + t.Logf("Op %d: Set nonce %#x %d", op, addr, nonce) + apply(func(sdb *StateDB) { + sdb.SetNonce(addr, nonce) + }) + case 10: + refund := uint64(reader.byte()) + t.Logf("Op %d: Set refund %d", op, refund) + apply(func(sdb *StateDB) { + sdb.journal.refundChange(refund) + }) + case 11: + var ( + addr = common.BytesToAddress(reader.bytes(1)) + key = common.BytesToHash(reader.bytes(1)) + val = common.BytesToHash(reader.bytes(1)) + ) + t.Logf("Op %d: Set storage %#x [%#x]=%#x", op, addr, key, val) + apply(func(sdb *StateDB) { + sdb.SetState(addr, key, val) + }) + case 12: + var ( + addr = common.BytesToAddress(reader.bytes(1)) + ) + t.Logf("Op %d: Zero-balance transfer (touch) %#x", op, addr) + apply(func(sdb *StateDB) { + sdb.AddBalance(addr, uint256.NewInt(0), 0) + }) + case 13: + var ( + addr = common.BytesToAddress(reader.bytes(1)) + key = common.BytesToHash(reader.bytes(1)) + value = 
common.BytesToHash(reader.bytes(1)) + ) + t.Logf("Op %d: Set t-storage %#x [%#x]=%#x", op, addr, key, value) + apply(func(sdb *StateDB) { + sdb.SetTransientState(addr, key, value) + }) + case 14: + t.Logf("Op %d: Reset journal", op) + apply(func(sdb *StateDB) { + sdb.journal.reset() + }) + case 15: + t.Logf("Op %d: Snapshot", op) + apply(func(sdb *StateDB) { + sdb.Snapshot() + }) + case 16: + t.Logf("Op %d: Discard snapshot", op) + apply(func(sdb *StateDB) { + sdb.DiscardSnapshot() + }) + + case 17: + t.Logf("Op %d: Revert snapshot", op) + apply(func(sdb *StateDB) { + sdb.RevertSnapshot() + }) } - if h1 != h2 { - panic(fmt.Sprintf("mismatched roots: %v %v", h1, h2)) + // Cross-check the dirty-sets + accs1 := stateDbs[0].journal.dirtyAccounts() + slices.SortFunc(accs1, func(a, b common.Address) int { + return bytes.Compare(a.Bytes(), b.Bytes()) + }) + accs2 := stateDbs[1].journal.dirtyAccounts() + slices.SortFunc(accs2, func(a, b common.Address) int { + return bytes.Compare(a.Bytes(), b.Bytes()) + }) + if !slices.Equal(accs1, accs2) { + t.Fatalf("mismatched dirty-sets:\n%v\n%v", accs1, accs2) } - }) + } + h1, err1 := stateDbs[0].Commit(0, false) + h2, err2 := stateDbs[1].Commit(0, false) + if err1 != err2 { + t.Fatalf("Mismatched errors: %v %v", err1, err2) + } + if h1 != h2 { + t.Fatalf("Mismatched roots: %v %v", h1, h2) + } +} + +// FuzzJournals fuzzes the journals. 
+func FuzzJournals(f *testing.F) {
+	f.Fuzz(fuzzJournals)
+}
+
+// TestFuzzJournals runs 200 fuzz-tests, each on its own random 200-byte input.
+func TestFuzzJournals(t *testing.T) {
+	for i := 0; i < 200; i++ {
+		input := make([]byte, 200) // per-subtest buffer: parallel subtests only resume after this loop has finished
+		rand.Read(input)
+		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+			t.Parallel()
+			t.Logf("input: %x", input)
+			fuzzJournals(t, input)
+		})
+	}
+}
+
+// TestFuzzJournalsSpecific can be used to test a specific input
+func TestFuzzJournalsSpecific(t *testing.T) {
+	t.Skip("example")
+	input := common.FromHex("71d598d781f65eb7c047fed5d09b1e4e0c1ecad5c447a2149e7d1137fcb1b1d63f4ba6f761918a441a98eb61d69fe011cabfbce00d74bb78539ca9946a602e94d6eabc43c0924ba65ce3e171b476208059d81f33e81d90607e0b6e59d6016840b5c4e9b1a8e9798a5a40be909930658eea351d7a312dba0b1c7199c7e5f62a908a80f7faf29bc0108faae0cf0f497d0f4cd228b7600ef0d88532dfafa6349ea7782f28ad7426eeffc155282a9e58a606d25acd8a730dde61a6e5e887d1ba1fea813bb7f2c6caff25")
+	fuzzJournals(t, input)
 }