From 01fb56b4f1eb0313e41c25ef08c1b543bd153d0e Mon Sep 17 00:00:00 2001 From: Yexiang Zhang Date: Thu, 2 Nov 2023 17:38:10 +0800 Subject: [PATCH 01/26] dashboard: update hotfix version (#7303) close tikv/pd#7302 Signed-off-by: mornyx --- go.mod | 2 +- go.sum | 4 ++-- tests/integrations/client/go.mod | 2 +- tests/integrations/client/go.sum | 4 ++-- tests/integrations/mcs/go.mod | 2 +- tests/integrations/mcs/go.sum | 4 ++-- tests/integrations/tso/go.mod | 2 +- tests/integrations/tso/go.sum | 4 ++-- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 86f56089347..e8da2542be2 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/pingcap/kvproto v0.0.0-20231018065736-c0689aded40c github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 - github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511 + github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 github.com/prometheus/client_golang v1.11.1 github.com/prometheus/common v0.26.0 github.com/sasha-s/go-deadlock v0.2.0 diff --git a/go.sum b/go.sum index 9392644f181..28e210ef1cd 100644 --- a/go.sum +++ b/go.sum @@ -446,8 +446,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511 h1:oyrCfNlAWmLlUfEr+7YTSBo29SP/J1N8hnxBt5yUABo= -github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= +github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 h1:xIeaDUq2ItkYMIgpWXAYKC/N3hs8aurfFvvz79lhHYE= +github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/tests/integrations/client/go.mod b/tests/integrations/client/go.mod index e38efbeb438..b9b868cf8e3 100644 --- a/tests/integrations/client/go.mod +++ b/tests/integrations/client/go.mod @@ -119,7 +119,7 @@ require ( github.com/pingcap/errcode v0.3.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511 // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/client/go.sum b/tests/integrations/client/go.sum index c745c4fa518..81fa6fd7b39 100644 --- a/tests/integrations/client/go.sum +++ b/tests/integrations/client/go.sum @@ -410,8 +410,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log 
v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511 h1:oyrCfNlAWmLlUfEr+7YTSBo29SP/J1N8hnxBt5yUABo= -github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= +github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 h1:xIeaDUq2ItkYMIgpWXAYKC/N3hs8aurfFvvz79lhHYE= +github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/tests/integrations/mcs/go.mod b/tests/integrations/mcs/go.mod index 000bfdc8312..c2dfdbe96ef 100644 --- a/tests/integrations/mcs/go.mod +++ b/tests/integrations/mcs/go.mod @@ -119,7 +119,7 @@ require ( github.com/pingcap/errcode v0.3.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511 // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/mcs/go.sum b/tests/integrations/mcs/go.sum index 0da75329284..d1b0962ab55 100644 --- a/tests/integrations/mcs/go.sum +++ b/tests/integrations/mcs/go.sum @@ -414,8 +414,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511 h1:oyrCfNlAWmLlUfEr+7YTSBo29SP/J1N8hnxBt5yUABo= -github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= +github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 h1:xIeaDUq2ItkYMIgpWXAYKC/N3hs8aurfFvvz79lhHYE= +github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/tests/integrations/tso/go.mod b/tests/integrations/tso/go.mod index f8a5cfac75f..e5131f15d91 100644 --- a/tests/integrations/tso/go.mod +++ b/tests/integrations/tso/go.mod @@ -117,7 +117,7 @@ require ( 
github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511 // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/tso/go.sum b/tests/integrations/tso/go.sum index 63327985f0d..576c3e75765 100644 --- a/tests/integrations/tso/go.sum +++ b/tests/integrations/tso/go.sum @@ -408,8 +408,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511 h1:oyrCfNlAWmLlUfEr+7YTSBo29SP/J1N8hnxBt5yUABo= -github.com/pingcap/tidb-dashboard v0.0.0-20230911054332-22add1e00511/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= +github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 h1:xIeaDUq2ItkYMIgpWXAYKC/N3hs8aurfFvvz79lhHYE= +github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= From 689fcbe2ff081e96ecad2762dd0b89c07364ffc5 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Fri, 3 Nov 2023 10:09:39 +0800 Subject: [PATCH 02/26] checker: replace down check with disconnect check when fixing orphan peer (#7294) close tikv/pd#7249 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/core/store.go | 3 + pkg/schedule/checker/rule_checker.go | 60 +++++--- pkg/schedule/checker/rule_checker_test.go | 176 +++++++++++++++++++++- 3 files changed, 214 insertions(+), 25 deletions(-) diff --git a/pkg/core/store.go b/pkg/core/store.go index 1d3362cac0e..b3c62f45750 100644 --- a/pkg/core/store.go +++ b/pkg/core/store.go @@ -551,6 +551,9 @@ var ( // tikv's store heartbeat for a short time, maybe caused by process restart or // temporary network failure. 
func (s *StoreInfo) IsDisconnected() bool { + if s == nil { + return true + } return s.DownTime() > storeDisconnectDuration } diff --git a/pkg/schedule/checker/rule_checker.go b/pkg/schedule/checker/rule_checker.go index 7012359ca36..84cafaa871e 100644 --- a/pkg/schedule/checker/rule_checker.go +++ b/pkg/schedule/checker/rule_checker.go @@ -447,7 +447,7 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg if len(fit.OrphanPeers) == 0 { return nil, nil } - var pinDownPeer *metapb.Peer + isUnhealthyPeer := func(id uint64) bool { for _, downPeer := range region.GetDownPeers() { if downPeer.Peer.GetId() == id { @@ -461,31 +461,41 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg } return false } + + isDisconnectedPeer := func(p *metapb.Peer) bool { + // avoid to meet down store when fix orphan peers, + // Isdisconnected is more strictly than IsUnhealthy. + return c.cluster.GetStore(p.GetStoreId()).IsDisconnected() + } + + checkDownPeer := func(peers []*metapb.Peer) (*metapb.Peer, bool) { + for _, p := range peers { + if isUnhealthyPeer(p.GetId()) { + // make sure is down peer. + if region.GetDownPeer(p.GetId()) != nil { + return p, true + } + return nil, true + } + if isDisconnectedPeer(p) { + return p, true + } + } + return nil, false + } + // remove orphan peers only when all rules are satisfied (count+role) and all peers selected // by RuleFits is not pending or down. + var pinDownPeer *metapb.Peer hasUnhealthyFit := false -loopFits: for _, rf := range fit.RuleFits { if !rf.IsSatisfied() { hasUnhealthyFit = true break } - for _, p := range rf.Peers { - if isUnhealthyPeer(p.GetId()) { - // make sure is down peer. - if region.GetDownPeer(p.GetId()) != nil { - pinDownPeer = p - } - hasUnhealthyFit = true - break loopFits - } - // avoid to meet down store when fix orpahn peers, - // Isdisconnected is more strictly than IsUnhealthy. - if c.cluster.GetStore(p.GetStoreId()).IsDisconnected() { - hasUnhealthyFit = true - pinDownPeer = p - break loopFits - } + pinDownPeer, hasUnhealthyFit = checkDownPeer(rf.Peers) + if hasUnhealthyFit { + break } } @@ -502,15 +512,15 @@ loopFits: continue } // make sure the orphan peer is healthy. - if isUnhealthyPeer(orphanPeer.GetId()) { + if isUnhealthyPeer(orphanPeer.GetId()) || isDisconnectedPeer(orphanPeer) { continue } // no consider witness in this path. if pinDownPeer.GetIsWitness() || orphanPeer.GetIsWitness() { continue } - // down peer's store should be down. - if !c.isStoreDownTimeHitMaxDownTime(pinDownPeer.GetStoreId()) { + // down peer's store should be disconnected + if !isDisconnectedPeer(pinDownPeer) { continue } // check if down peer can replace with orphan peer. @@ -525,7 +535,7 @@ loopFits: case orphanPeerRole == metapb.PeerRole_Voter && destRole == metapb.PeerRole_Learner: return operator.CreateDemoteLearnerOperatorAndRemovePeer("replace-down-peer-with-orphan-peer", c.cluster, region, orphanPeer, pinDownPeer) case orphanPeerRole == metapb.PeerRole_Voter && destRole == metapb.PeerRole_Voter && - c.cluster.GetStore(pinDownPeer.GetStoreId()).IsDisconnected() && !dstStore.IsDisconnected(): + isDisconnectedPeer(pinDownPeer) && !dstStore.IsDisconnected(): return operator.CreateRemovePeerOperator("remove-replaced-orphan-peer", c.cluster, 0, region, pinDownPeer.GetStoreId()) default: // destRole should not same with orphanPeerRole. if role is same, it fit with orphanPeer should be better than now. 
@@ -542,7 +552,11 @@ loopFits: for _, orphanPeer := range fit.OrphanPeers { if isUnhealthyPeer(orphanPeer.GetId()) { ruleCheckerRemoveOrphanPeerCounter.Inc() - return operator.CreateRemovePeerOperator("remove-orphan-peer", c.cluster, 0, region, orphanPeer.StoreId) + return operator.CreateRemovePeerOperator("remove-unhealthy-orphan-peer", c.cluster, 0, region, orphanPeer.StoreId) + } + if isDisconnectedPeer(orphanPeer) { + ruleCheckerRemoveOrphanPeerCounter.Inc() + return operator.CreateRemovePeerOperator("remove-disconnected-orphan-peer", c.cluster, 0, region, orphanPeer.StoreId) } if hasHealthPeer { // there already exists a healthy orphan peer, so we can remove other orphan Peers. diff --git a/pkg/schedule/checker/rule_checker_test.go b/pkg/schedule/checker/rule_checker_test.go index 8ee3b1eccfa..0c4a2a9ecc9 100644 --- a/pkg/schedule/checker/rule_checker_test.go +++ b/pkg/schedule/checker/rule_checker_test.go @@ -235,7 +235,7 @@ func (suite *ruleCheckerTestSuite) TestFixToManyOrphanPeers() { suite.cluster.PutRegion(region) op = suite.rc.Check(suite.cluster.GetRegion(1)) suite.NotNil(op) - suite.Equal("remove-orphan-peer", op.Desc()) + suite.Equal("remove-unhealthy-orphan-peer", op.Desc()) suite.Equal(uint64(4), op.Step(0).(operator.RemovePeer).FromStore) } @@ -702,7 +702,7 @@ func (suite *ruleCheckerTestSuite) TestPriorityFixOrphanPeer() { suite.cluster.PutRegion(testRegion) op = suite.rc.Check(suite.cluster.GetRegion(1)) suite.NotNil(op) - suite.Equal("remove-orphan-peer", op.Desc()) + suite.Equal("remove-unhealthy-orphan-peer", op.Desc()) suite.IsType(remove, op.Step(0)) // Ref #3521 suite.cluster.SetStoreOffline(2) @@ -723,6 +723,178 @@ func (suite *ruleCheckerTestSuite) TestPriorityFixOrphanPeer() { suite.Equal("remove-orphan-peer", op.Desc()) } +// Ref https://github.com/tikv/pd/issues/7249 https://github.com/tikv/tikv/issues/15799 +func (suite *ruleCheckerTestSuite) TestFixOrphanPeerWithDisconnectedStoreAndRuleChanged() { + // init cluster with 5 replicas + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) + storeIDs := []uint64{1, 2, 3, 4, 5} + suite.cluster.AddLeaderRegionWithRange(1, "", "", storeIDs[0], storeIDs[1:]...) 
+ rule := &placement.Rule{ + GroupID: "pd", + ID: "default", + Role: placement.Voter, + Count: 5, + StartKey: []byte{}, + EndKey: []byte{}, + } + suite.ruleManager.SetRule(rule) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + + // set store 1, 2 to disconnected + suite.cluster.SetStoreDisconnect(1) + suite.cluster.SetStoreDisconnect(2) + + // change rule to 3 replicas + rule = &placement.Rule{ + GroupID: "pd", + ID: "default", + Role: placement.Voter, + Count: 3, + StartKey: []byte{}, + EndKey: []byte{}, + Override: true, + } + suite.ruleManager.SetRule(rule) + + // remove store 1 from region 1 + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-replaced-orphan-peer", op.Desc()) + suite.Equal(op.Len(), 2) + newLeaderID := op.Step(0).(operator.TransferLeader).ToStore + removedPeerID := op.Step(1).(operator.RemovePeer).FromStore + r1 := suite.cluster.GetRegion(1) + r1 = r1.Clone( + core.WithLeader(r1.GetPeer(newLeaderID)), + core.WithRemoveStorePeer(removedPeerID)) + suite.cluster.PutRegion(r1) + r1 = suite.cluster.GetRegion(1) + suite.Len(r1.GetPeers(), 4) + + // remove store 2 from region 1 + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-replaced-orphan-peer", op.Desc()) + suite.Equal(op.Len(), 1) + removedPeerID = op.Step(0).(operator.RemovePeer).FromStore + r1 = r1.Clone(core.WithRemoveStorePeer(removedPeerID)) + suite.cluster.PutRegion(r1) + r1 = suite.cluster.GetRegion(1) + suite.Len(r1.GetPeers(), 3) + for _, p := range r1.GetPeers() { + suite.NotEqual(p.GetStoreId(), 1) + suite.NotEqual(p.GetStoreId(), 2) + } +} + +// Ref https://github.com/tikv/pd/issues/7249 https://github.com/tikv/tikv/issues/15799 +func (suite *ruleCheckerTestSuite) TestFixOrphanPeerWithDisconnectedStoreAndRuleChanged2() { + // init cluster with 5 voters and 1 learner + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) + suite.cluster.AddLabelsStore(6, 1, map[string]string{"host": "host6"}) + storeIDs := []uint64{1, 2, 3, 4, 5} + suite.cluster.AddLeaderRegionWithRange(1, "", "", storeIDs[0], storeIDs[1:]...) 
+ r1 := suite.cluster.GetRegion(1) + r1 = r1.Clone(core.WithAddPeer(&metapb.Peer{Id: 6, StoreId: 6, Role: metapb.PeerRole_Learner})) + suite.cluster.PutRegion(r1) + err := suite.ruleManager.SetRules([]*placement.Rule{ + { + GroupID: "pd", + ID: "default", + Index: 100, + Override: true, + Role: placement.Voter, + Count: 5, + IsWitness: false, + }, + { + GroupID: "pd", + ID: "r1", + Index: 100, + Override: false, + Role: placement.Learner, + Count: 1, + IsWitness: false, + }, + }) + suite.NoError(err) + + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + + // set store 1, 2 to disconnected + suite.cluster.SetStoreDisconnect(1) + suite.cluster.SetStoreDisconnect(2) + suite.cluster.SetStoreDisconnect(3) + + // change rule to 3 replicas + suite.ruleManager.DeleteRuleGroup("pd") + suite.ruleManager.SetRule(&placement.Rule{ + GroupID: "pd", + ID: "default", + Role: placement.Voter, + Count: 2, + StartKey: []byte{}, + EndKey: []byte{}, + Override: true, + }) + + // remove store 1 from region 1 + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-replaced-orphan-peer", op.Desc()) + suite.Equal(op.Len(), 2) + newLeaderID := op.Step(0).(operator.TransferLeader).ToStore + removedPeerID := op.Step(1).(operator.RemovePeer).FromStore + r1 = suite.cluster.GetRegion(1) + r1 = r1.Clone( + core.WithLeader(r1.GetPeer(newLeaderID)), + core.WithRemoveStorePeer(removedPeerID)) + suite.cluster.PutRegion(r1) + r1 = suite.cluster.GetRegion(1) + suite.Len(r1.GetPeers(), 5) + + // remove store 2 from region 1 + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-replaced-orphan-peer", op.Desc()) + suite.Equal(op.Len(), 1) + removedPeerID = op.Step(0).(operator.RemovePeer).FromStore + r1 = r1.Clone(core.WithRemoveStorePeer(removedPeerID)) + suite.cluster.PutRegion(r1) + r1 = suite.cluster.GetRegion(1) + suite.Len(r1.GetPeers(), 4) + for _, p := range r1.GetPeers() { + fmt.Println(p.GetStoreId(), p.Role.String()) + } + + // remove store 3 from region 1 + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-replaced-orphan-peer", op.Desc()) + suite.Equal(op.Len(), 1) + removedPeerID = op.Step(0).(operator.RemovePeer).FromStore + r1 = r1.Clone(core.WithRemoveStorePeer(removedPeerID)) + suite.cluster.PutRegion(r1) + r1 = suite.cluster.GetRegion(1) + suite.Len(r1.GetPeers(), 3) + + for _, p := range r1.GetPeers() { + suite.NotEqual(p.GetStoreId(), 1) + suite.NotEqual(p.GetStoreId(), 2) + suite.NotEqual(p.GetStoreId(), 3) + } +} + func (suite *ruleCheckerTestSuite) TestPriorityFitHealthWithDifferentRole1() { suite.cluster.SetEnableUseJointConsensus(true) suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) From ab8bf7b7a62d43abd9a8d5213fc3b5855472bd9e Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Mon, 6 Nov 2023 10:14:10 +0800 Subject: [PATCH 03/26] mcs: fix duplicated metrics (#7319) close tikv/pd#7290 Signed-off-by: Ryan Leung --- pkg/mcs/resourcemanager/server/apis/v1/api.go | 2 +- pkg/mcs/scheduling/server/apis/v1/api.go | 2 +- pkg/mcs/tso/server/apis/v1/api.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/mcs/resourcemanager/server/apis/v1/api.go b/pkg/mcs/resourcemanager/server/apis/v1/api.go index ffcb9318590..7b5f2903484 100644 --- a/pkg/mcs/resourcemanager/server/apis/v1/api.go +++ b/pkg/mcs/resourcemanager/server/apis/v1/api.go @@ -81,10 +81,10 @@ func NewService(srv *rmserver.Service) *Service { c.Set(multiservicesapi.ServiceContextKey, 
manager.GetBasicServer()) c.Next() }) - apiHandlerEngine.Use(multiservicesapi.ServiceRedirector()) apiHandlerEngine.GET("metrics", utils.PromHandler()) pprof.Register(apiHandlerEngine) endpoint := apiHandlerEngine.Group(APIPathPrefix) + endpoint.Use(multiservicesapi.ServiceRedirector()) s := &Service{ manager: manager, apiHandlerEngine: apiHandlerEngine, diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index 356dc5a7f42..98fb68c090b 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -100,10 +100,10 @@ func NewService(srv *scheserver.Service) *Service { c.Set(handlerKey, handler.NewHandler(&server{srv.Server})) c.Next() }) - apiHandlerEngine.Use(multiservicesapi.ServiceRedirector()) apiHandlerEngine.GET("metrics", mcsutils.PromHandler()) pprof.Register(apiHandlerEngine) root := apiHandlerEngine.Group(APIPathPrefix) + root.Use(multiservicesapi.ServiceRedirector()) s := &Service{ srv: srv, apiHandlerEngine: apiHandlerEngine, diff --git a/pkg/mcs/tso/server/apis/v1/api.go b/pkg/mcs/tso/server/apis/v1/api.go index f1853bf5483..1b8f68778af 100644 --- a/pkg/mcs/tso/server/apis/v1/api.go +++ b/pkg/mcs/tso/server/apis/v1/api.go @@ -89,10 +89,10 @@ func NewService(srv *tsoserver.Service) *Service { c.Set(multiservicesapi.ServiceContextKey, srv) c.Next() }) - apiHandlerEngine.Use(multiservicesapi.ServiceRedirector()) apiHandlerEngine.GET("metrics", utils.PromHandler()) pprof.Register(apiHandlerEngine) root := apiHandlerEngine.Group(APIPathPrefix) + root.Use(multiservicesapi.ServiceRedirector()) s := &Service{ srv: srv, apiHandlerEngine: apiHandlerEngine, From c332ddce95b0a9193022724c94047bdb87953633 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Mon, 6 Nov 2023 20:20:11 +0800 Subject: [PATCH 04/26] checker: avoid unnecessary remove disconnected peer with multi orphan peers (#7315) close tikv/pd#7249 Signed-off-by: lhy1024 --- pkg/core/store.go | 3 - pkg/schedule/checker/rule_checker.go | 29 +- pkg/schedule/checker/rule_checker_test.go | 445 ++++++++++++++-------- 3 files changed, 308 insertions(+), 169 deletions(-) diff --git a/pkg/core/store.go b/pkg/core/store.go index b3c62f45750..1d3362cac0e 100644 --- a/pkg/core/store.go +++ b/pkg/core/store.go @@ -551,9 +551,6 @@ var ( // tikv's store heartbeat for a short time, maybe caused by process restart or // temporary network failure. func (s *StoreInfo) IsDisconnected() bool { - if s == nil { - return true - } return s.DownTime() > storeDisconnectDuration } diff --git a/pkg/schedule/checker/rule_checker.go b/pkg/schedule/checker/rule_checker.go index 84cafaa871e..c4e7c242dea 100644 --- a/pkg/schedule/checker/rule_checker.go +++ b/pkg/schedule/checker/rule_checker.go @@ -78,6 +78,7 @@ var ( ruleCheckerSkipRemoveOrphanPeerCounter = checkerCounter.WithLabelValues(ruleChecker, "skip-remove-orphan-peer") ruleCheckerRemoveOrphanPeerCounter = checkerCounter.WithLabelValues(ruleChecker, "remove-orphan-peer") ruleCheckerReplaceOrphanPeerCounter = checkerCounter.WithLabelValues(ruleChecker, "replace-orphan-peer") + ruleCheckerReplaceOrphanPeerNoFitCounter = checkerCounter.WithLabelValues(ruleChecker, "replace-orphan-peer-no-fit") ) // RuleChecker fix/improve region by placement rules. @@ -465,7 +466,11 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg isDisconnectedPeer := func(p *metapb.Peer) bool { // avoid to meet down store when fix orphan peers, // Isdisconnected is more strictly than IsUnhealthy. 
- return c.cluster.GetStore(p.GetStoreId()).IsDisconnected() + store := c.cluster.GetStore(p.GetStoreId()) + if store == nil { + return true + } + return store.IsDisconnected() } checkDownPeer := func(peers []*metapb.Peer) (*metapb.Peer, bool) { @@ -519,7 +524,7 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg if pinDownPeer.GetIsWitness() || orphanPeer.GetIsWitness() { continue } - // down peer's store should be disconnected + // pinDownPeer's store should be disconnected, because we use more strict judge before. if !isDisconnectedPeer(pinDownPeer) { continue } @@ -534,13 +539,14 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg return operator.CreatePromoteLearnerOperatorAndRemovePeer("replace-down-peer-with-orphan-peer", c.cluster, region, orphanPeer, pinDownPeer) case orphanPeerRole == metapb.PeerRole_Voter && destRole == metapb.PeerRole_Learner: return operator.CreateDemoteLearnerOperatorAndRemovePeer("replace-down-peer-with-orphan-peer", c.cluster, region, orphanPeer, pinDownPeer) - case orphanPeerRole == metapb.PeerRole_Voter && destRole == metapb.PeerRole_Voter && - isDisconnectedPeer(pinDownPeer) && !dstStore.IsDisconnected(): + case orphanPeerRole == destRole && isDisconnectedPeer(pinDownPeer) && !dstStore.IsDisconnected(): return operator.CreateRemovePeerOperator("remove-replaced-orphan-peer", c.cluster, 0, region, pinDownPeer.GetStoreId()) default: // destRole should not same with orphanPeerRole. if role is same, it fit with orphanPeer should be better than now. // destRole never be leader, so we not consider it. } + } else { + ruleCheckerReplaceOrphanPeerNoFitCounter.Inc() } } } @@ -549,18 +555,25 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg // Ref https://github.com/tikv/pd/issues/4045 if len(fit.OrphanPeers) >= 2 { hasHealthPeer := false + var disconnectedPeer *metapb.Peer + for _, orphanPeer := range fit.OrphanPeers { + if isDisconnectedPeer(orphanPeer) { + disconnectedPeer = orphanPeer + break + } + } for _, orphanPeer := range fit.OrphanPeers { if isUnhealthyPeer(orphanPeer.GetId()) { ruleCheckerRemoveOrphanPeerCounter.Inc() return operator.CreateRemovePeerOperator("remove-unhealthy-orphan-peer", c.cluster, 0, region, orphanPeer.StoreId) } - if isDisconnectedPeer(orphanPeer) { - ruleCheckerRemoveOrphanPeerCounter.Inc() - return operator.CreateRemovePeerOperator("remove-disconnected-orphan-peer", c.cluster, 0, region, orphanPeer.StoreId) - } if hasHealthPeer { // there already exists a healthy orphan peer, so we can remove other orphan Peers. ruleCheckerRemoveOrphanPeerCounter.Inc() + // if there exists a disconnected orphan peer, we will pick it to remove firstly. 
+ if disconnectedPeer != nil { + return operator.CreateRemovePeerOperator("remove-orphan-peer", c.cluster, 0, region, disconnectedPeer.StoreId) + } return operator.CreateRemovePeerOperator("remove-orphan-peer", c.cluster, 0, region, orphanPeer.StoreId) } hasHealthPeer = true diff --git a/pkg/schedule/checker/rule_checker_test.go b/pkg/schedule/checker/rule_checker_test.go index 0c4a2a9ecc9..eb357f302b7 100644 --- a/pkg/schedule/checker/rule_checker_test.go +++ b/pkg/schedule/checker/rule_checker_test.go @@ -17,6 +17,7 @@ package checker import ( "context" "fmt" + "strconv" "strings" "testing" @@ -225,7 +226,6 @@ func (suite *ruleCheckerTestSuite) TestFixToManyOrphanPeers() { suite.NotNil(op) suite.Equal("remove-orphan-peer", op.Desc()) suite.Equal(uint64(5), op.Step(0).(operator.RemovePeer).FromStore) - // Case2: // store 4, 5, 6 are orphan peers, and peer on store 3 is down peer. and peer on store 4, 5 are pending. region = suite.cluster.GetRegion(1) @@ -237,6 +237,91 @@ func (suite *ruleCheckerTestSuite) TestFixToManyOrphanPeers() { suite.NotNil(op) suite.Equal("remove-unhealthy-orphan-peer", op.Desc()) suite.Equal(uint64(4), op.Step(0).(operator.RemovePeer).FromStore) + // Case3: + // store 4, 5, 6 are orphan peers, and peer on one of stores is disconnect peer + // we should remove disconnect peer first. + for i := uint64(4); i <= 6; i++ { + region = suite.cluster.GetRegion(1) + suite.cluster.SetStoreDisconnect(i) + region = region.Clone( + core.WithDownPeers([]*pdpb.PeerStats{{Peer: region.GetStorePeer(3), DownSeconds: 60000}}), + core.WithPendingPeers([]*metapb.Peer{region.GetStorePeer(3)})) + suite.cluster.PutRegion(region) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) + suite.Equal(i, op.Step(0).(operator.RemovePeer).FromStore) + suite.cluster.SetStoreUp(i) + } + // Case4: + // store 4, 5, 6 are orphan peers, and peer on two of stores is disconnect peer + // we should remove disconnect peer first. + for i := uint64(4); i <= 6; i++ { + region = suite.cluster.GetRegion(1) + suite.cluster.SetStoreDisconnect(4) + suite.cluster.SetStoreDisconnect(5) + suite.cluster.SetStoreDisconnect(6) + suite.cluster.SetStoreUp(i) + region = region.Clone( + core.WithDownPeers([]*pdpb.PeerStats{{Peer: region.GetStorePeer(3), DownSeconds: 60000}}), + core.WithPendingPeers([]*metapb.Peer{region.GetStorePeer(3)})) + suite.cluster.PutRegion(region) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) + removedPeerStoreID := op.Step(0).(operator.RemovePeer).FromStore + suite.NotEqual(i, removedPeerStoreID) + region = suite.cluster.GetRegion(1) + newRegion := region.Clone(core.WithRemoveStorePeer(removedPeerStoreID)) + suite.cluster.PutRegion(newRegion) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) + removedPeerStoreID = op.Step(0).(operator.RemovePeer).FromStore + suite.NotEqual(i, removedPeerStoreID) + suite.cluster.PutRegion(region) + } +} + +func (suite *ruleCheckerTestSuite) TestFixToManyOrphanPeers2() { + suite.cluster.AddLeaderStore(1, 1) + suite.cluster.AddLeaderStore(2, 1) + suite.cluster.AddLeaderStore(3, 1) + suite.cluster.AddLeaderStore(4, 1) + suite.cluster.AddLeaderStore(5, 1) + suite.cluster.AddRegionWithLearner(1, 1, []uint64{2, 3}, []uint64{4, 5}) + + // Case1: + // store 4, 5 are orphan peers, and peer on one of stores is disconnect peer + // we should remove disconnect peer first. 
+ for i := uint64(4); i <= 5; i++ { + region := suite.cluster.GetRegion(1) + suite.cluster.SetStoreDisconnect(i) + region = region.Clone( + core.WithDownPeers([]*pdpb.PeerStats{{Peer: region.GetStorePeer(3), DownSeconds: 60000}}), + core.WithPendingPeers([]*metapb.Peer{region.GetStorePeer(3)})) + suite.cluster.PutRegion(region) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) + suite.Equal(i, op.Step(0).(operator.RemovePeer).FromStore) + suite.cluster.SetStoreUp(i) + } + + // Case2: + // store 4, 5 are orphan peers, and they are disconnect peers + // we should remove the peer on disconnect stores at least. + region := suite.cluster.GetRegion(1) + suite.cluster.SetStoreDisconnect(4) + suite.cluster.SetStoreDisconnect(5) + region = region.Clone( + core.WithDownPeers([]*pdpb.PeerStats{{Peer: region.GetStorePeer(3), DownSeconds: 60000}}), + core.WithPendingPeers([]*metapb.Peer{region.GetStorePeer(3)})) + suite.cluster.PutRegion(region) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) + suite.Equal(uint64(4), op.Step(0).(operator.RemovePeer).FromStore) } func (suite *ruleCheckerTestSuite) TestFixOrphanPeers2() { @@ -725,173 +810,217 @@ func (suite *ruleCheckerTestSuite) TestPriorityFixOrphanPeer() { // Ref https://github.com/tikv/pd/issues/7249 https://github.com/tikv/tikv/issues/15799 func (suite *ruleCheckerTestSuite) TestFixOrphanPeerWithDisconnectedStoreAndRuleChanged() { - // init cluster with 5 replicas - suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) - suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) - suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) - suite.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) - storeIDs := []uint64{1, 2, 3, 4, 5} - suite.cluster.AddLeaderRegionWithRange(1, "", "", storeIDs[0], storeIDs[1:]...) 
- rule := &placement.Rule{ - GroupID: "pd", - ID: "default", - Role: placement.Voter, - Count: 5, - StartKey: []byte{}, - EndKey: []byte{}, - } - suite.ruleManager.SetRule(rule) - op := suite.rc.Check(suite.cluster.GetRegion(1)) - suite.Nil(op) - - // set store 1, 2 to disconnected - suite.cluster.SetStoreDisconnect(1) - suite.cluster.SetStoreDisconnect(2) - - // change rule to 3 replicas - rule = &placement.Rule{ - GroupID: "pd", - ID: "default", - Role: placement.Voter, - Count: 3, - StartKey: []byte{}, - EndKey: []byte{}, - Override: true, + // disconnect any two stores and change rule to 3 replicas + stores := []uint64{1, 2, 3, 4, 5} + testCases := [][]uint64{} + for i := 0; i < len(stores); i++ { + for j := i + 1; j < len(stores); j++ { + testCases = append(testCases, []uint64{stores[i], stores[j]}) + } } - suite.ruleManager.SetRule(rule) + for _, leader := range stores { + var followers []uint64 + for i := 0; i < len(stores); i++ { + if stores[i] != leader { + followers = append(followers, stores[i]) + } + } - // remove store 1 from region 1 - op = suite.rc.Check(suite.cluster.GetRegion(1)) - suite.NotNil(op) - suite.Equal("remove-replaced-orphan-peer", op.Desc()) - suite.Equal(op.Len(), 2) - newLeaderID := op.Step(0).(operator.TransferLeader).ToStore - removedPeerID := op.Step(1).(operator.RemovePeer).FromStore - r1 := suite.cluster.GetRegion(1) - r1 = r1.Clone( - core.WithLeader(r1.GetPeer(newLeaderID)), - core.WithRemoveStorePeer(removedPeerID)) - suite.cluster.PutRegion(r1) - r1 = suite.cluster.GetRegion(1) - suite.Len(r1.GetPeers(), 4) + for _, testCase := range testCases { + // init cluster with 5 replicas + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", leader, followers...) 
+ rule := &placement.Rule{ + GroupID: "pd", + ID: "default", + Role: placement.Voter, + Count: 5, + StartKey: []byte{}, + EndKey: []byte{}, + } + err := suite.ruleManager.SetRule(rule) + suite.NoError(err) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + + // set two stores to disconnected + suite.cluster.SetStoreDisconnect(testCase[0]) + suite.cluster.SetStoreDisconnect(testCase[1]) + + // change rule to 3 replicas + rule = &placement.Rule{ + GroupID: "pd", + ID: "default", + Role: placement.Voter, + Count: 3, + StartKey: []byte{}, + EndKey: []byte{}, + Override: true, + } + suite.ruleManager.SetRule(rule) + + // remove peer from region 1 + for j := 1; j <= 2; j++ { + r1 := suite.cluster.GetRegion(1) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Contains(op.Desc(), "orphan") + var removedPeerStoreID uint64 + newLeaderStoreID := r1.GetLeader().GetStoreId() + for i := 0; i < op.Len(); i++ { + if s, ok := op.Step(i).(operator.RemovePeer); ok { + removedPeerStoreID = s.FromStore + } + if s, ok := op.Step(i).(operator.TransferLeader); ok { + newLeaderStoreID = s.ToStore + } + } + suite.NotZero(removedPeerStoreID) + r1 = r1.Clone( + core.WithLeader(r1.GetStorePeer(newLeaderStoreID)), + core.WithRemoveStorePeer(removedPeerStoreID)) + suite.cluster.PutRegion(r1) + r1 = suite.cluster.GetRegion(1) + suite.Len(r1.GetPeers(), 5-j) + } - // remove store 2 from region 1 - op = suite.rc.Check(suite.cluster.GetRegion(1)) - suite.NotNil(op) - suite.Equal("remove-replaced-orphan-peer", op.Desc()) - suite.Equal(op.Len(), 1) - removedPeerID = op.Step(0).(operator.RemovePeer).FromStore - r1 = r1.Clone(core.WithRemoveStorePeer(removedPeerID)) - suite.cluster.PutRegion(r1) - r1 = suite.cluster.GetRegion(1) - suite.Len(r1.GetPeers(), 3) - for _, p := range r1.GetPeers() { - suite.NotEqual(p.GetStoreId(), 1) - suite.NotEqual(p.GetStoreId(), 2) + r1 := suite.cluster.GetRegion(1) + for _, p := range r1.GetPeers() { + suite.NotEqual(p.GetStoreId(), testCase[0]) + suite.NotEqual(p.GetStoreId(), testCase[1]) + } + suite.TearDownTest() + suite.SetupTest() + } } } // Ref https://github.com/tikv/pd/issues/7249 https://github.com/tikv/tikv/issues/15799 -func (suite *ruleCheckerTestSuite) TestFixOrphanPeerWithDisconnectedStoreAndRuleChanged2() { - // init cluster with 5 voters and 1 learner - suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) - suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) - suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) - suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) - suite.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) - suite.cluster.AddLabelsStore(6, 1, map[string]string{"host": "host6"}) - storeIDs := []uint64{1, 2, 3, 4, 5} - suite.cluster.AddLeaderRegionWithRange(1, "", "", storeIDs[0], storeIDs[1:]...) 
- r1 := suite.cluster.GetRegion(1) - r1 = r1.Clone(core.WithAddPeer(&metapb.Peer{Id: 6, StoreId: 6, Role: metapb.PeerRole_Learner})) - suite.cluster.PutRegion(r1) - err := suite.ruleManager.SetRules([]*placement.Rule{ - { - GroupID: "pd", - ID: "default", - Index: 100, - Override: true, - Role: placement.Voter, - Count: 5, - IsWitness: false, - }, - { - GroupID: "pd", - ID: "r1", - Index: 100, - Override: false, - Role: placement.Learner, - Count: 1, - IsWitness: false, - }, - }) - suite.NoError(err) - - op := suite.rc.Check(suite.cluster.GetRegion(1)) - suite.Nil(op) - - // set store 1, 2 to disconnected - suite.cluster.SetStoreDisconnect(1) - suite.cluster.SetStoreDisconnect(2) - suite.cluster.SetStoreDisconnect(3) - - // change rule to 3 replicas - suite.ruleManager.DeleteRuleGroup("pd") - suite.ruleManager.SetRule(&placement.Rule{ - GroupID: "pd", - ID: "default", - Role: placement.Voter, - Count: 2, - StartKey: []byte{}, - EndKey: []byte{}, - Override: true, - }) - - // remove store 1 from region 1 - op = suite.rc.Check(suite.cluster.GetRegion(1)) - suite.NotNil(op) - suite.Equal("remove-replaced-orphan-peer", op.Desc()) - suite.Equal(op.Len(), 2) - newLeaderID := op.Step(0).(operator.TransferLeader).ToStore - removedPeerID := op.Step(1).(operator.RemovePeer).FromStore - r1 = suite.cluster.GetRegion(1) - r1 = r1.Clone( - core.WithLeader(r1.GetPeer(newLeaderID)), - core.WithRemoveStorePeer(removedPeerID)) - suite.cluster.PutRegion(r1) - r1 = suite.cluster.GetRegion(1) - suite.Len(r1.GetPeers(), 5) - - // remove store 2 from region 1 - op = suite.rc.Check(suite.cluster.GetRegion(1)) - suite.NotNil(op) - suite.Equal("remove-replaced-orphan-peer", op.Desc()) - suite.Equal(op.Len(), 1) - removedPeerID = op.Step(0).(operator.RemovePeer).FromStore - r1 = r1.Clone(core.WithRemoveStorePeer(removedPeerID)) - suite.cluster.PutRegion(r1) - r1 = suite.cluster.GetRegion(1) - suite.Len(r1.GetPeers(), 4) - for _, p := range r1.GetPeers() { - fmt.Println(p.GetStoreId(), p.Role.String()) +func (suite *ruleCheckerTestSuite) TestFixOrphanPeerWithDisconnectedStoreAndRuleChangedWithLearner() { + // disconnect any three stores and change rule to 3 replicas + // and there is a learner in the disconnected store. 
+ stores := []uint64{1, 2, 3, 4, 5, 6} + testCases := [][]uint64{} + for i := 0; i < len(stores); i++ { + for j := i + 1; j < len(stores); j++ { + for k := j + 1; k < len(stores); k++ { + testCases = append(testCases, []uint64{stores[i], stores[j], stores[k]}) + } + } } + for _, leader := range stores { + var followers []uint64 + for i := 0; i < len(stores); i++ { + if stores[i] != leader { + followers = append(followers, stores[i]) + } + } - // remove store 3 from region 1 - op = suite.rc.Check(suite.cluster.GetRegion(1)) - suite.NotNil(op) - suite.Equal("remove-replaced-orphan-peer", op.Desc()) - suite.Equal(op.Len(), 1) - removedPeerID = op.Step(0).(operator.RemovePeer).FromStore - r1 = r1.Clone(core.WithRemoveStorePeer(removedPeerID)) - suite.cluster.PutRegion(r1) - r1 = suite.cluster.GetRegion(1) - suite.Len(r1.GetPeers(), 3) + for _, testCase := range testCases { + for _, learnerStore := range testCase { + if learnerStore == leader { + continue + } + voterFollowers := []uint64{} + for _, follower := range followers { + if follower != learnerStore { + voterFollowers = append(voterFollowers, follower) + } + } + // init cluster with 5 voters and 1 learner + suite.cluster.AddLabelsStore(1, 1, map[string]string{"host": "host1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"host": "host2"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"host": "host3"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"host": "host4"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"host": "host5"}) + suite.cluster.AddLabelsStore(6, 1, map[string]string{"host": "host6"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", leader, voterFollowers...) + err := suite.ruleManager.SetRules([]*placement.Rule{ + { + GroupID: "pd", + ID: "default", + Index: 100, + Override: true, + Role: placement.Voter, + Count: 5, + IsWitness: false, + }, + { + GroupID: "pd", + ID: "r1", + Index: 100, + Override: false, + Role: placement.Learner, + Count: 1, + IsWitness: false, + LabelConstraints: []placement.LabelConstraint{ + {Key: "host", Op: "in", Values: []string{"host" + strconv.FormatUint(learnerStore, 10)}}, + }, + }, + }) + suite.NoError(err) + r1 := suite.cluster.GetRegion(1) + r1 = r1.Clone(core.WithAddPeer(&metapb.Peer{Id: 12, StoreId: learnerStore, Role: metapb.PeerRole_Learner})) + suite.cluster.PutRegion(r1) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + + // set three stores to disconnected + suite.cluster.SetStoreDisconnect(testCase[0]) + suite.cluster.SetStoreDisconnect(testCase[1]) + suite.cluster.SetStoreDisconnect(testCase[2]) + + // change rule to 3 replicas + suite.ruleManager.DeleteRule("pd", "r1") + suite.ruleManager.SetRule(&placement.Rule{ + GroupID: "pd", + ID: "default", + Role: placement.Voter, + Count: 3, + StartKey: []byte{}, + EndKey: []byte{}, + Override: true, + }) + + // remove peer from region 1 + for j := 1; j <= 3; j++ { + r1 := suite.cluster.GetRegion(1) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Contains(op.Desc(), "orphan") + var removedPeerStroeID uint64 + newLeaderStoreID := r1.GetLeader().GetStoreId() + for i := 0; i < op.Len(); i++ { + if s, ok := op.Step(i).(operator.RemovePeer); ok { + removedPeerStroeID = s.FromStore + } + if s, ok := op.Step(i).(operator.TransferLeader); ok { + newLeaderStoreID = s.ToStore + } + } + suite.NotZero(removedPeerStroeID) + r1 = r1.Clone( + core.WithLeader(r1.GetStorePeer(newLeaderStoreID)), + core.WithRemoveStorePeer(removedPeerStroeID)) + 
suite.cluster.PutRegion(r1) + r1 = suite.cluster.GetRegion(1) + suite.Len(r1.GetPeers(), 6-j) + } - for _, p := range r1.GetPeers() { - suite.NotEqual(p.GetStoreId(), 1) - suite.NotEqual(p.GetStoreId(), 2) - suite.NotEqual(p.GetStoreId(), 3) + r1 = suite.cluster.GetRegion(1) + for _, p := range r1.GetPeers() { + suite.NotEqual(p.GetStoreId(), testCase[0]) + suite.NotEqual(p.GetStoreId(), testCase[1]) + suite.NotEqual(p.GetStoreId(), testCase[2]) + } + suite.TearDownTest() + suite.SetupTest() + } + } } } From 356066a3598c68620b3807bc2bf449ba209e33f4 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 7 Nov 2023 14:11:10 +0800 Subject: [PATCH 05/26] mcs: solve forward stream error (#7321) close tikv/pd#7320 Signed-off-by: lhy1024 --- server/grpc_service.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/server/grpc_service.go b/server/grpc_service.go index 2e59bdaf742..4aa6dc5b1da 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -2632,7 +2632,13 @@ func (s *GrpcServer) getGlobalTSOFromTSOServer(ctx context.Context) (pdpb.Timest if err != nil { return pdpb.Timestamp{}, err } - forwardStream.Send(request) + err := forwardStream.Send(request) + if err != nil { + s.tsoClientPool.Lock() + delete(s.tsoClientPool.clients, forwardedHost) + s.tsoClientPool.Unlock() + continue + } ts, err = forwardStream.Recv() if err != nil { if strings.Contains(err.Error(), errs.NotLeaderErr) { From 47ba96f95bf96eeaa323dfd091990d4dc1bb1684 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 7 Nov 2023 14:59:10 +0800 Subject: [PATCH 06/26] mcs: support config http interface in scheduling server (#7278) ref tikv/pd#5839 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/scheduling/server/apis/v1/api.go | 19 + pkg/mcs/scheduling/server/config/config.go | 13 +- pkg/mcs/scheduling/server/server.go | 23 + server/api/config.go | 63 ++- server/api/config_test.go | 440 --------------- tests/integrations/mcs/scheduling/api_test.go | 83 +++ tests/pdctl/config/config_test.go | 162 +++--- tests/pdctl/scheduler/scheduler_test.go | 12 +- tests/server/api/operator_test.go | 33 +- tests/server/api/scheduler_test.go | 25 +- tests/server/config/config_test.go | 531 +++++++++++++++++- 11 files changed, 855 insertions(+), 549 deletions(-) delete mode 100644 server/api/config_test.go diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index 98fb68c090b..47fdb95543f 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -111,6 +111,7 @@ func NewService(srv *scheserver.Service) *Service { rd: createIndentRender(), } s.RegisterAdminRouter() + s.RegisterConfigRouter() s.RegisterOperatorsRouter() s.RegisterSchedulersRouter() s.RegisterCheckersRouter() @@ -126,6 +127,12 @@ func (s *Service) RegisterAdminRouter() { router.DELETE("cache/regions/:id", deleteRegionCacheByID) } +// RegisterConfigRouter registers the router of the config handler. +func (s *Service) RegisterConfigRouter() { + router := s.root.Group("config") + router.GET("", getConfig) +} + // RegisterSchedulersRouter registers the router of the schedulers handler. func (s *Service) RegisterSchedulersRouter() { router := s.root.Group("schedulers") @@ -186,6 +193,18 @@ func changeLogLevel(c *gin.Context) { c.String(http.StatusOK, "The log level is updated.") } +// @Tags config +// @Summary Get full config. 
+// @Produce json +// @Success 200 {object} config.Config +// @Router /config [get] +func getConfig(c *gin.Context) { + svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server) + cfg := svr.GetConfig() + cfg.Schedule.MaxMergeRegionKeys = cfg.Schedule.GetMaxMergeRegionKeys() + c.IndentedJSON(http.StatusOK, cfg) +} + // @Tags admin // @Summary Drop all regions from cache. // @Produce json diff --git a/pkg/mcs/scheduling/server/config/config.go b/pkg/mcs/scheduling/server/config/config.go index 4f9caca41e6..772eab835f1 100644 --- a/pkg/mcs/scheduling/server/config/config.go +++ b/pkg/mcs/scheduling/server/config/config.go @@ -61,9 +61,9 @@ type Config struct { Metric metricutil.MetricConfig `toml:"metric" json:"metric"` // Log related config. - Log log.Config `toml:"log" json:"log"` - Logger *zap.Logger - LogProps *log.ZapProperties + Log log.Config `toml:"log" json:"log"` + Logger *zap.Logger `json:"-"` + LogProps *log.ZapProperties `json:"-"` Security configutil.SecurityConfig `toml:"security" json:"security"` @@ -195,6 +195,13 @@ func (c *Config) validate() error { return nil } +// Clone creates a copy of current config. +func (c *Config) Clone() *Config { + cfg := &Config{} + *cfg = *c + return cfg +} + // PersistConfig wraps all configurations that need to persist to storage and // allows to access them safely. type PersistConfig struct { diff --git a/pkg/mcs/scheduling/server/server.go b/pkg/mcs/scheduling/server/server.go index 5e2ed58a009..1790cb2b4be 100644 --- a/pkg/mcs/scheduling/server/server.go +++ b/pkg/mcs/scheduling/server/server.go @@ -504,6 +504,29 @@ func (s *Server) stopWatcher() { s.metaWatcher.Close() } +// GetPersistConfig returns the persist config. +// It's used to test. +func (s *Server) GetPersistConfig() *config.PersistConfig { + return s.persistConfig +} + +// GetConfig gets the config. 
+func (s *Server) GetConfig() *config.Config { + cfg := s.cfg.Clone() + cfg.Schedule = *s.persistConfig.GetScheduleConfig().Clone() + cfg.Replication = *s.persistConfig.GetReplicationConfig().Clone() + cfg.ClusterVersion = *s.persistConfig.GetClusterVersion() + if s.storage == nil { + return cfg + } + sches, configs, err := s.storage.LoadAllSchedulerConfigs() + if err != nil { + return cfg + } + cfg.Schedule.SchedulersPayload = schedulers.ToPayload(sches, configs) + return cfg +} + // CreateServer creates the Server func CreateServer(ctx context.Context, cfg *config.Config) *Server { svr := &Server{ diff --git a/server/api/config.go b/server/api/config.go index c63bd953c37..746b1119a73 100644 --- a/server/api/config.go +++ b/server/api/config.go @@ -27,6 +27,8 @@ import ( "github.com/pingcap/errcode" "github.com/pingcap/errors" "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/mcs/utils" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/jsonutil" @@ -60,7 +62,17 @@ func newConfHandler(svr *server.Server, rd *render.Render) *confHandler { // @Router /config [get] func (h *confHandler) GetConfig(w http.ResponseWriter, r *http.Request) { cfg := h.svr.GetConfig() - cfg.Schedule.MaxMergeRegionKeys = cfg.Schedule.GetMaxMergeRegionKeys() + if h.svr.IsAPIServiceMode() { + schedulingServerConfig, err := h.GetSchedulingServerConfig() + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + cfg.Schedule = schedulingServerConfig.Schedule + cfg.Replication = schedulingServerConfig.Replication + } else { + cfg.Schedule.MaxMergeRegionKeys = cfg.Schedule.GetMaxMergeRegionKeys() + } h.rd.JSON(w, http.StatusOK, cfg) } @@ -301,6 +313,16 @@ func getConfigMap(cfg map[string]interface{}, key []string, value interface{}) m // @Success 200 {object} sc.ScheduleConfig // @Router /config/schedule [get] func (h *confHandler) GetScheduleConfig(w http.ResponseWriter, r *http.Request) { + if h.svr.IsAPIServiceMode() { + cfg, err := h.GetSchedulingServerConfig() + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + cfg.Schedule.SchedulersPayload = nil + h.rd.JSON(w, http.StatusOK, cfg.Schedule) + return + } cfg := h.svr.GetScheduleConfig() cfg.MaxMergeRegionKeys = cfg.GetMaxMergeRegionKeys() h.rd.JSON(w, http.StatusOK, cfg) @@ -364,6 +386,15 @@ func (h *confHandler) SetScheduleConfig(w http.ResponseWriter, r *http.Request) // @Success 200 {object} sc.ReplicationConfig // @Router /config/replicate [get] func (h *confHandler) GetReplicationConfig(w http.ResponseWriter, r *http.Request) { + if h.svr.IsAPIServiceMode() { + cfg, err := h.GetSchedulingServerConfig() + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + h.rd.JSON(w, http.StatusOK, cfg.Replication) + return + } h.rd.JSON(w, http.StatusOK, h.svr.GetReplicationConfig()) } @@ -505,3 +536,33 @@ func (h *confHandler) SetReplicationModeConfig(w http.ResponseWriter, r *http.Re func (h *confHandler) GetPDServerConfig(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, h.svr.GetPDServerConfig()) } + +func (h *confHandler) GetSchedulingServerConfig() (*config.Config, error) { + addr, ok := h.svr.GetServicePrimaryAddr(h.svr.Context(), utils.SchedulingServiceName) + if !ok { + return nil, errs.ErrNotFoundSchedulingAddr.FastGenByArgs() + } + url := fmt.Sprintf("%s/scheduling/api/v1/config", addr) + req, err := http.NewRequest(http.MethodGet, url, nil) 
+ if err != nil { + return nil, err + } + resp, err := h.svr.GetHTTPClient().Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, errs.ErrSchedulingServer.FastGenByArgs(resp.StatusCode) + } + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var schedulingServerConfig config.Config + err = json.Unmarshal(b, &schedulingServerConfig) + if err != nil { + return nil, err + } + return &schedulingServerConfig, nil +} diff --git a/server/api/config_test.go b/server/api/config_test.go deleted file mode 100644 index fbfb3f94518..00000000000 --- a/server/api/config_test.go +++ /dev/null @@ -1,440 +0,0 @@ -// Copyright 2016 TiKV Project Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "encoding/json" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/suite" - sc "github.com/tikv/pd/pkg/schedule/config" - tu "github.com/tikv/pd/pkg/utils/testutil" - "github.com/tikv/pd/pkg/utils/typeutil" - "github.com/tikv/pd/pkg/versioninfo" - "github.com/tikv/pd/server" - "github.com/tikv/pd/server/config" -) - -type configTestSuite struct { - suite.Suite - svr *server.Server - cleanup tu.CleanupFunc - urlPrefix string -} - -func TestConfigTestSuite(t *testing.T) { - suite.Run(t, new(configTestSuite)) -} - -func (suite *configTestSuite) SetupSuite() { - re := suite.Require() - suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { - cfg.Replication.EnablePlacementRules = false - }) - server.MustWaitLeader(re, []*server.Server{suite.svr}) - - addr := suite.svr.GetAddr() - suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) -} - -func (suite *configTestSuite) TearDownSuite() { - suite.cleanup() -} - -func (suite *configTestSuite) TestConfigAll() { - re := suite.Require() - addr := fmt.Sprintf("%s/config", suite.urlPrefix) - cfg := &config.Config{} - err := tu.ReadGetJSON(re, testDialClient, addr, cfg) - suite.NoError(err) - - // the original way - r := map[string]int{"max-replicas": 5} - postData, err := json.Marshal(r) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - l := map[string]interface{}{ - "location-labels": "zone,rack", - "region-schedule-limit": 10, - } - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - l = map[string]interface{}{ - "metric-storage": "http://127.0.0.1:9090", - } - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - newCfg := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, newCfg) - suite.NoError(err) - cfg.Replication.MaxReplicas = 5 - cfg.Replication.LocationLabels = []string{"zone", "rack"} - cfg.Schedule.RegionScheduleLimit = 10 - cfg.PDServerCfg.MetricStorage = "http://127.0.0.1:9090" - suite.Equal(newCfg, cfg) - - // the 
new way - l = map[string]interface{}{ - "schedule.tolerant-size-ratio": 2.5, - "schedule.enable-tikv-split-region": "false", - "replication.location-labels": "idc,host", - "pd-server.metric-storage": "http://127.0.0.1:1234", - "log.level": "warn", - "cluster-version": "v4.0.0-beta", - "replication-mode.replication-mode": "dr-auto-sync", - "replication-mode.dr-auto-sync.label-key": "foobar", - } - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - newCfg1 := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, newCfg1) - suite.NoError(err) - cfg.Schedule.EnableTiKVSplitRegion = false - cfg.Schedule.TolerantSizeRatio = 2.5 - cfg.Replication.LocationLabels = []string{"idc", "host"} - cfg.PDServerCfg.MetricStorage = "http://127.0.0.1:1234" - cfg.Log.Level = "warn" - cfg.ReplicationMode.DRAutoSync.LabelKey = "foobar" - cfg.ReplicationMode.ReplicationMode = "dr-auto-sync" - v, err := versioninfo.ParseVersion("v4.0.0-beta") - suite.NoError(err) - cfg.ClusterVersion = *v - suite.Equal(cfg, newCfg1) - - // revert this to avoid it affects TestConfigTTL - l["schedule.enable-tikv-split-region"] = "true" - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - // illegal prefix - l = map[string]interface{}{ - "replicate.max-replicas": 1, - } - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, - tu.StatusNotOK(re), - tu.StringContain(re, "not found")) - suite.NoError(err) - - // update prefix directly - l = map[string]interface{}{ - "replication-mode": nil, - } - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, - tu.StatusNotOK(re), - tu.StringContain(re, "cannot update config prefix")) - suite.NoError(err) - - // config item not found - l = map[string]interface{}{ - "schedule.region-limit": 10, - } - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) - suite.NoError(err) -} - -func (suite *configTestSuite) TestConfigSchedule() { - re := suite.Require() - addr := fmt.Sprintf("%s/config/schedule", suite.urlPrefix) - scheduleConfig := &sc.ScheduleConfig{} - suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, scheduleConfig)) - scheduleConfig.MaxStoreDownTime.Duration = time.Second - postData, err := json.Marshal(scheduleConfig) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - scheduleConfig1 := &sc.ScheduleConfig{} - suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, scheduleConfig1)) - suite.Equal(*scheduleConfig1, *scheduleConfig) -} - -func (suite *configTestSuite) TestConfigReplication() { - re := suite.Require() - addr := fmt.Sprintf("%s/config/replicate", suite.urlPrefix) - rc := &sc.ReplicationConfig{} - err := tu.ReadGetJSON(re, testDialClient, addr, rc) - suite.NoError(err) - - rc.MaxReplicas = 5 - rc1 := map[string]int{"max-replicas": 5} - postData, err := json.Marshal(rc1) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - rc.LocationLabels = []string{"zone", "rack"} - rc2 := map[string]string{"location-labels": "zone,rack"} - postData, err = json.Marshal(rc2) - suite.NoError(err) - err = 
tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - rc.IsolationLevel = "zone" - rc3 := map[string]string{"isolation-level": "zone"} - postData, err = json.Marshal(rc3) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - rc4 := &sc.ReplicationConfig{} - err = tu.ReadGetJSON(re, testDialClient, addr, rc4) - suite.NoError(err) - - suite.Equal(*rc4, *rc) -} - -func (suite *configTestSuite) TestConfigLabelProperty() { - re := suite.Require() - addr := suite.svr.GetAddr() + apiPrefix + "/api/v1/config/label-property" - loadProperties := func() config.LabelPropertyConfig { - var cfg config.LabelPropertyConfig - err := tu.ReadGetJSON(re, testDialClient, addr, &cfg) - suite.NoError(err) - return cfg - } - - cfg := loadProperties() - suite.Empty(cfg) - - cmds := []string{ - `{"type": "foo", "action": "set", "label-key": "zone", "label-value": "cn1"}`, - `{"type": "foo", "action": "set", "label-key": "zone", "label-value": "cn2"}`, - `{"type": "bar", "action": "set", "label-key": "host", "label-value": "h1"}`, - } - for _, cmd := range cmds { - err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) - suite.NoError(err) - } - - cfg = loadProperties() - suite.Len(cfg, 2) - suite.Equal([]config.StoreLabel{ - {Key: "zone", Value: "cn1"}, - {Key: "zone", Value: "cn2"}, - }, cfg["foo"]) - suite.Equal([]config.StoreLabel{{Key: "host", Value: "h1"}}, cfg["bar"]) - - cmds = []string{ - `{"type": "foo", "action": "delete", "label-key": "zone", "label-value": "cn1"}`, - `{"type": "bar", "action": "delete", "label-key": "host", "label-value": "h1"}`, - } - for _, cmd := range cmds { - err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) - suite.NoError(err) - } - - cfg = loadProperties() - suite.Len(cfg, 1) - suite.Equal([]config.StoreLabel{{Key: "zone", Value: "cn2"}}, cfg["foo"]) -} - -func (suite *configTestSuite) TestConfigDefault() { - addr := fmt.Sprintf("%s/config", suite.urlPrefix) - - r := map[string]int{"max-replicas": 5} - postData, err := json.Marshal(r) - suite.NoError(err) - re := suite.Require() - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - l := map[string]interface{}{ - "location-labels": "zone,rack", - "region-schedule-limit": 10, - } - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - l = map[string]interface{}{ - "metric-storage": "http://127.0.0.1:9090", - } - postData, err = json.Marshal(l) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - - addr = fmt.Sprintf("%s/config/default", suite.urlPrefix) - defaultCfg := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, defaultCfg) - suite.NoError(err) - - suite.Equal(uint64(3), defaultCfg.Replication.MaxReplicas) - suite.Equal(typeutil.StringSlice([]string{}), defaultCfg.Replication.LocationLabels) - suite.Equal(uint64(2048), defaultCfg.Schedule.RegionScheduleLimit) - suite.Equal("", defaultCfg.PDServerCfg.MetricStorage) -} - -func (suite *configTestSuite) TestConfigPDServer() { - re := suite.Require() - addrPost := fmt.Sprintf("%s/config", suite.urlPrefix) - ms := map[string]interface{}{ - "metric-storage": "", - } - postData, err := json.Marshal(ms) - suite.NoError(err) - suite.NoError(tu.CheckPostJSON(testDialClient, addrPost, postData, tu.StatusOK(re))) - addrGet 
:= fmt.Sprintf("%s/config/pd-server", suite.urlPrefix) - sc := &config.PDServerConfig{} - suite.NoError(tu.ReadGetJSON(re, testDialClient, addrGet, sc)) - suite.Equal(bool(true), sc.UseRegionStorage) - suite.Equal("table", sc.KeyType) - suite.Equal(typeutil.StringSlice([]string{}), sc.RuntimeServices) - suite.Equal("", sc.MetricStorage) - suite.Equal("auto", sc.DashboardAddress) - suite.Equal(int(3), sc.FlowRoundByDigit) - suite.Equal(typeutil.NewDuration(time.Second), sc.MinResolvedTSPersistenceInterval) - suite.Equal(24*time.Hour, sc.MaxResetTSGap.Duration) -} - -var ttlConfig = map[string]interface{}{ - "schedule.max-snapshot-count": 999, - "schedule.enable-location-replacement": false, - "schedule.max-merge-region-size": 999, - "schedule.max-merge-region-keys": 999, - "schedule.scheduler-max-waiting-operator": 999, - "schedule.leader-schedule-limit": 999, - "schedule.region-schedule-limit": 999, - "schedule.hot-region-schedule-limit": 999, - "schedule.replica-schedule-limit": 999, - "schedule.merge-schedule-limit": 999, - "schedule.enable-tikv-split-region": false, -} - -var invalidTTLConfig = map[string]interface{}{ - "schedule.invalid-ttl-config": 0, -} - -func assertTTLConfig( - options *config.PersistOptions, - equality func(interface{}, interface{}, ...interface{}) bool, -) { - equality(uint64(999), options.GetMaxSnapshotCount()) - equality(false, options.IsLocationReplacementEnabled()) - equality(uint64(999), options.GetMaxMergeRegionSize()) - equality(uint64(999), options.GetMaxMergeRegionKeys()) - equality(uint64(999), options.GetSchedulerMaxWaitingOperator()) - equality(uint64(999), options.GetLeaderScheduleLimit()) - equality(uint64(999), options.GetRegionScheduleLimit()) - equality(uint64(999), options.GetHotRegionScheduleLimit()) - equality(uint64(999), options.GetReplicaScheduleLimit()) - equality(uint64(999), options.GetMergeScheduleLimit()) - equality(false, options.IsTikvRegionSplitEnabled()) -} - -func createTTLUrl(url string, ttl int) string { - return fmt.Sprintf("%s/config?ttlSecond=%d", url, ttl) -} - -func (suite *configTestSuite) TestConfigTTL() { - postData, err := json.Marshal(ttlConfig) - suite.NoError(err) - - // test no config and cleaning up - re := suite.Require() - err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 0), postData, tu.StatusOK(re)) - suite.NoError(err) - assertTTLConfig(suite.svr.GetPersistOptions(), suite.NotEqual) - - // test time goes by - err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, tu.StatusOK(re)) - suite.NoError(err) - assertTTLConfig(suite.svr.GetPersistOptions(), suite.Equal) - time.Sleep(2 * time.Second) - assertTTLConfig(suite.svr.GetPersistOptions(), suite.NotEqual) - - // test cleaning up - err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, tu.StatusOK(re)) - suite.NoError(err) - assertTTLConfig(suite.svr.GetPersistOptions(), suite.Equal) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 0), postData, tu.StatusOK(re)) - suite.NoError(err) - assertTTLConfig(suite.svr.GetPersistOptions(), suite.NotEqual) - - postData, err = json.Marshal(invalidTTLConfig) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, - tu.StatusNotOK(re), tu.StringEqual(re, "\"unsupported ttl config schedule.invalid-ttl-config\"\n")) - suite.NoError(err) - - // only set max-merge-region-size - mergeConfig := map[string]interface{}{ - "schedule.max-merge-region-size": 999, - } - postData, err = 
json.Marshal(mergeConfig) - suite.NoError(err) - - err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, tu.StatusOK(re)) - suite.NoError(err) - suite.Equal(uint64(999), suite.svr.GetPersistOptions().GetMaxMergeRegionSize()) - // max-merge-region-keys should keep consistence with max-merge-region-size. - suite.Equal(uint64(999*10000), suite.svr.GetPersistOptions().GetMaxMergeRegionKeys()) - - // on invalid value, we use default config - mergeConfig = map[string]interface{}{ - "schedule.enable-tikv-split-region": "invalid", - } - postData, err = json.Marshal(mergeConfig) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 1), postData, tu.StatusOK(re)) - suite.NoError(err) - suite.True(suite.svr.GetPersistOptions().IsTikvRegionSplitEnabled()) -} - -func (suite *configTestSuite) TestTTLConflict() { - addr := createTTLUrl(suite.urlPrefix, 1) - postData, err := json.Marshal(ttlConfig) - suite.NoError(err) - re := suite.Require() - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) - assertTTLConfig(suite.svr.GetPersistOptions(), suite.Equal) - - cfg := map[string]interface{}{"max-snapshot-count": 30} - postData, err = json.Marshal(cfg) - suite.NoError(err) - addr = fmt.Sprintf("%s/config", suite.urlPrefix) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) - suite.NoError(err) - addr = fmt.Sprintf("%s/config/schedule", suite.urlPrefix) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) - suite.NoError(err) - cfg = map[string]interface{}{"schedule.max-snapshot-count": 30} - postData, err = json.Marshal(cfg) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(suite.urlPrefix, 0), postData, tu.StatusOK(re)) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) - suite.NoError(err) -} diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 3793c09d883..15c66ce5829 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" _ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" + "github.com/tikv/pd/pkg/mcs/scheduling/server/config" "github.com/tikv/pd/pkg/schedule/handler" "github.com/tikv/pd/pkg/statistics" "github.com/tikv/pd/pkg/storage" @@ -242,6 +243,88 @@ func (suite *apiTestSuite) TestAPIForward() { re.NoError(err) } +func (suite *apiTestSuite) TestConfig() { + checkConfig := func(cluster *tests.TestCluster) { + re := suite.Require() + s := cluster.GetSchedulingPrimaryServer() + testutil.Eventually(re, func() bool { + return s.IsServing() + }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) + addr := s.GetAddr() + urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/config", addr) + + var cfg config.Config + testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + suite.Equal(cfg.GetListenAddr(), s.GetConfig().GetListenAddr()) + suite.Equal(cfg.Schedule.LeaderScheduleLimit, s.GetConfig().Schedule.LeaderScheduleLimit) + suite.Equal(cfg.Schedule.EnableCrossTableMerge, s.GetConfig().Schedule.EnableCrossTableMerge) + suite.Equal(cfg.Replication.MaxReplicas, 
s.GetConfig().Replication.MaxReplicas) + suite.Equal(cfg.Replication.LocationLabels, s.GetConfig().Replication.LocationLabels) + suite.Equal(cfg.DataDir, s.GetConfig().DataDir) + testutil.Eventually(re, func() bool { + // wait for all schedulers to be loaded in scheduling server. + return len(cfg.Schedule.SchedulersPayload) == 5 + }) + suite.Contains(cfg.Schedule.SchedulersPayload, "balance-leader-scheduler") + suite.Contains(cfg.Schedule.SchedulersPayload, "balance-region-scheduler") + suite.Contains(cfg.Schedule.SchedulersPayload, "balance-hot-region-scheduler") + suite.Contains(cfg.Schedule.SchedulersPayload, "balance-witness-scheduler") + suite.Contains(cfg.Schedule.SchedulersPayload, "transfer-witness-leader-scheduler") + } + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInAPIMode(checkConfig) +} + +func TestConfigForward(t *testing.T) { + re := require.New(t) + checkConfigForward := func(cluster *tests.TestCluster) { + sche := cluster.GetSchedulingPrimaryServer() + opts := sche.GetPersistConfig() + var cfg map[string]interface{} + addr := cluster.GetLeaderServer().GetAddr() + urlPrefix := fmt.Sprintf("%s/pd/api/v1/config", addr) + + // Test config forward + // Expect to get same config in scheduling server and api server + testutil.Eventually(re, func() bool { + testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + re.Equal(cfg["schedule"].(map[string]interface{})["leader-schedule-limit"], + float64(opts.GetLeaderScheduleLimit())) + re.Equal(cfg["replication"].(map[string]interface{})["max-replicas"], + float64(opts.GetReplicationConfig().MaxReplicas)) + schedulers := cfg["schedule"].(map[string]interface{})["schedulers-payload"].(map[string]interface{}) + return len(schedulers) == 5 + }) + + // Test to change config in api server + // Expect to get new config in scheduling server and api server + reqData, err := json.Marshal(map[string]interface{}{ + "max-replicas": 4, + }) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, urlPrefix, reqData, testutil.StatusOK(re)) + re.NoError(err) + testutil.Eventually(re, func() bool { + testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + return cfg["replication"].(map[string]interface{})["max-replicas"] == 4. && + opts.GetReplicationConfig().MaxReplicas == 4. 
+ }) + + // Test to change config only in scheduling server + // Expect to get new config in scheduling server but not old config in api server + opts.GetScheduleConfig().LeaderScheduleLimit = 100 + re.Equal(100, int(opts.GetLeaderScheduleLimit())) + testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + re.Equal(100., cfg["schedule"].(map[string]interface{})["leader-schedule-limit"]) + opts.GetReplicationConfig().MaxReplicas = 5 + re.Equal(5, int(opts.GetReplicationConfig().MaxReplicas)) + testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + re.Equal(5., cfg["replication"].(map[string]interface{})["max-replicas"]) + } + env := tests.NewSchedulingTestEnvironment(t) + env.RunTestInAPIMode(checkConfigForward) +} + func TestAdminRegionCache(t *testing.T) { re := require.New(t) checkAdminRegionCache := func(cluster *tests.TestCluster) { diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index 6ed0841bf74..26d70bb955f 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -25,8 +25,10 @@ import ( "github.com/coreos/go-semver/semver" "github.com/pingcap/kvproto/pkg/metapb" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/schedule/placement" + "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/utils/typeutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" @@ -48,24 +50,29 @@ func (t *testCase) judge(re *require.Assertions, scheduleConfigs ...*sc.Schedule } } -func TestConfig(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestCluster(ctx, 1) - re.NoError(err) - err = cluster.RunInitialServers() - re.NoError(err) - cluster.WaitLeader() - pdAddr := cluster.GetConfig().GetClientURL() +type configTestSuite struct { + suite.Suite +} + +func TestConfigTestSuite(t *testing.T) { + suite.Run(t, new(configTestSuite)) +} + +func (suite *configTestSuite) TestConfig() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkConfig) +} + +func (suite *configTestSuite) checkConfig(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() cmd := pdctlCmd.GetRootCmd() store := &metapb.Store{ Id: 1, State: metapb.StoreState_Up, } - leaderServer := cluster.GetLeaderServer() - re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() tests.MustPutStore(re, cluster, store) defer cluster.Destroy() @@ -283,16 +290,15 @@ func TestConfig(t *testing.T) { re.Contains(string(output), "is invalid") } -func TestPlacementRules(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestCluster(ctx, 1) - re.NoError(err) - err = cluster.RunInitialServers() - re.NoError(err) - cluster.WaitLeader() - pdAddr := cluster.GetConfig().GetClientURL() +func (suite *configTestSuite) TestPlacementRules() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkPlacementRules) +} + +func (suite *configTestSuite) checkPlacementRules(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() cmd := pdctlCmd.GetRootCmd() store := &metapb.Store{ @@ -300,8 +306,6 @@ func TestPlacementRules(t *testing.T) { State: metapb.StoreState_Up, 
LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetLeaderServer() - re.NoError(leaderServer.BootstrapCluster()) tests.MustPutStore(re, cluster, store) defer cluster.Destroy() @@ -380,16 +384,15 @@ func TestPlacementRules(t *testing.T) { re.Equal([2]string{"pd", "test1"}, rules[0].Key()) } -func TestPlacementRuleGroups(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestCluster(ctx, 1) - re.NoError(err) - err = cluster.RunInitialServers() - re.NoError(err) - cluster.WaitLeader() - pdAddr := cluster.GetConfig().GetClientURL() +func (suite *configTestSuite) TestPlacementRuleGroups() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkPlacementRuleGroups) +} + +func (suite *configTestSuite) checkPlacementRuleGroups(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() cmd := pdctlCmd.GetRootCmd() store := &metapb.Store{ @@ -397,8 +400,6 @@ func TestPlacementRuleGroups(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetLeaderServer() - re.NoError(leaderServer.BootstrapCluster()) tests.MustPutStore(re, cluster, store) defer cluster.Destroy() @@ -454,16 +455,15 @@ func TestPlacementRuleGroups(t *testing.T) { re.Contains(string(output), "404") } -func TestPlacementRuleBundle(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestCluster(ctx, 1) - re.NoError(err) - err = cluster.RunInitialServers() - re.NoError(err) - cluster.WaitLeader() - pdAddr := cluster.GetConfig().GetClientURL() +func (suite *configTestSuite) TestPlacementRuleBundle() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkPlacementRuleBundle) +} + +func (suite *configTestSuite) checkPlacementRuleBundle(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() cmd := pdctlCmd.GetRootCmd() store := &metapb.Store{ @@ -471,8 +471,6 @@ func TestPlacementRuleBundle(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetLeaderServer() - re.NoError(leaderServer.BootstrapCluster()) tests.MustPutStore(re, cluster, store) defer cluster.Destroy() @@ -648,24 +646,21 @@ func TestReplicationMode(t *testing.T) { check() } -func TestUpdateDefaultReplicaConfig(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestCluster(ctx, 1) - re.NoError(err) - err = cluster.RunInitialServers() - re.NoError(err) - cluster.WaitLeader() - pdAddr := cluster.GetConfig().GetClientURL() +func (suite *configTestSuite) TestUpdateDefaultReplicaConfig() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkUpdateDefaultReplicaConfig) +} + +func (suite *configTestSuite) checkUpdateDefaultReplicaConfig(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() cmd := pdctlCmd.GetRootCmd() store := &metapb.Store{ Id: 1, State: metapb.StoreState_Up, } - leaderServer := cluster.GetLeaderServer() - re.NoError(leaderServer.BootstrapCluster()) tests.MustPutStore(re, cluster, store) defer cluster.Destroy() @@ -675,7 +670,9 @@ func 
TestUpdateDefaultReplicaConfig(t *testing.T) { re.NoError(err) replicationCfg := sc.ReplicationConfig{} re.NoError(json.Unmarshal(output, &replicationCfg)) - re.Equal(expect, replicationCfg.MaxReplicas) + testutil.Eventually(re, func() bool { // wait for the config to be synced to the scheduling server + return replicationCfg.MaxReplicas == expect + }) } checkLocationLabels := func(expect int) { @@ -684,7 +681,9 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.NoError(err) replicationCfg := sc.ReplicationConfig{} re.NoError(json.Unmarshal(output, &replicationCfg)) - re.Len(replicationCfg.LocationLabels, expect) + testutil.Eventually(re, func() bool { // wait for the config to be synced to the scheduling server + return len(replicationCfg.LocationLabels) == expect + }) } checkIsolationLevel := func(expect string) { @@ -693,7 +692,9 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.NoError(err) replicationCfg := sc.ReplicationConfig{} re.NoError(json.Unmarshal(output, &replicationCfg)) - re.Equal(replicationCfg.IsolationLevel, expect) + testutil.Eventually(re, func() bool { // wait for the config to be synced to the scheduling server + return replicationCfg.IsolationLevel == expect + }) } checkRuleCount := func(expect int) { @@ -702,7 +703,9 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.NoError(err) rule := placement.Rule{} re.NoError(json.Unmarshal(output, &rule)) - re.Equal(expect, rule.Count) + testutil.Eventually(re, func() bool { // wait for the config to be synced to the scheduling server + return rule.Count == expect + }) } checkRuleLocationLabels := func(expect int) { @@ -711,7 +714,9 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.NoError(err) rule := placement.Rule{} re.NoError(json.Unmarshal(output, &rule)) - re.Len(rule.LocationLabels, expect) + testutil.Eventually(re, func() bool { // wait for the config to be synced to the scheduling server + return len(rule.LocationLabels) == expect + }) } checkRuleIsolationLevel := func(expect string) { @@ -720,7 +725,9 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.NoError(err) rule := placement.Rule{} re.NoError(json.Unmarshal(output, &rule)) - re.Equal(rule.IsolationLevel, expect) + testutil.Eventually(re, func() bool { // wait for the config to be synced to the scheduling server + return rule.IsolationLevel == expect + }) } // update successfully when placement rules is not enabled. @@ -764,7 +771,7 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { checkRuleIsolationLevel("host") // update unsuccessfully when many rule exists. 
- fname := t.TempDir() + fname := suite.T().TempDir() rules := []placement.Rule{ { GroupID: "pd", @@ -791,16 +798,15 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { checkRuleIsolationLevel("host") } -func TestPDServerConfig(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestCluster(ctx, 1) - re.NoError(err) - err = cluster.RunInitialServers() - re.NoError(err) - cluster.WaitLeader() - pdAddr := cluster.GetConfig().GetClientURL() +func (suite *configTestSuite) TestPDServerConfig() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkPDServerConfig) +} + +func (suite *configTestSuite) checkPDServerConfig(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() cmd := pdctlCmd.GetRootCmd() store := &metapb.Store{ @@ -808,8 +814,6 @@ func TestPDServerConfig(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetLeaderServer() - re.NoError(leaderServer.BootstrapCluster()) tests.MustPutStore(re, cluster, store) defer cluster.Destroy() diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index 3554b828269..cd599405124 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -46,7 +46,6 @@ func TestSchedulerTestSuite(t *testing.T) { func (suite *schedulerTestSuite) TestScheduler() { env := tests.NewSchedulingTestEnvironment(suite.T()) env.RunTestInTwoModes(suite.checkScheduler) - env.RunTestInTwoModes(suite.checkSchedulerDiagnostic) } func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { @@ -414,8 +413,10 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "config", "balance-leader-scheduler", "show"}, &conf) re.Equal(4., conf["batch"]) mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "config", "balance-leader-scheduler", "set", "batch", "3"}, nil) - mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "config", "balance-leader-scheduler"}, &conf1) - re.Equal(3., conf1["batch"]) + testutil.Eventually(re, func() bool { + mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "config", "balance-leader-scheduler"}, &conf1) + return conf1["batch"] == 3. 
+ }) echo = mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "add", "balance-leader-scheduler"}, nil) re.NotContains(echo, "Success!") echo = mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "remove", "balance-leader-scheduler"}, nil) @@ -494,6 +495,11 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { checkSchedulerWithStatusCommand("disabled", nil) } +func (suite *schedulerTestSuite) TestSchedulerDiagnostic() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkSchedulerDiagnostic) +} + func (suite *schedulerTestSuite) checkSchedulerDiagnostic(cluster *tests.TestCluster) { re := suite.Require() pdAddr := cluster.GetConfig().GetClientURL() diff --git a/tests/server/api/operator_test.go b/tests/server/api/operator_test.go index 83ab0f3c7ed..908daf21aac 100644 --- a/tests/server/api/operator_test.go +++ b/tests/server/api/operator_test.go @@ -51,7 +51,7 @@ func TestOperatorTestSuite(t *testing.T) { suite.Run(t, new(operatorTestSuite)) } -func (suite *operatorTestSuite) TestOperator() { +func (suite *operatorTestSuite) TestAddRemovePeer() { opts := []tests.ConfigOption{ func(conf *config.Config, serverName string) { conf.Replication.MaxReplicas = 1 @@ -59,17 +59,6 @@ func (suite *operatorTestSuite) TestOperator() { } env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) env.RunTestInTwoModes(suite.checkAddRemovePeer) - - env = tests.NewSchedulingTestEnvironment(suite.T(), opts...) - env.RunTestInTwoModes(suite.checkMergeRegionOperator) - - opts = []tests.ConfigOption{ - func(conf *config.Config, serverName string) { - conf.Replication.MaxReplicas = 3 - }, - } - env = tests.NewSchedulingTestEnvironment(suite.T(), opts...) - env.RunTestInTwoModes(suite.checkTransferRegionWithPlacementRule) } func (suite *operatorTestSuite) checkAddRemovePeer(cluster *tests.TestCluster) { @@ -178,6 +167,16 @@ func (suite *operatorTestSuite) checkAddRemovePeer(cluster *tests.TestCluster) { suite.NoError(err) } +func (suite *operatorTestSuite) TestMergeRegionOperator() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.Replication.MaxReplicas = 1 + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkMergeRegionOperator) +} + func (suite *operatorTestSuite) checkMergeRegionOperator(cluster *tests.TestCluster) { re := suite.Require() suite.pauseRuleChecker(cluster) @@ -204,6 +203,16 @@ func (suite *operatorTestSuite) checkMergeRegionOperator(cluster *tests.TestClus suite.NoError(err) } +func (suite *operatorTestSuite) TestTransferRegionWithPlacementRule() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.Replication.MaxReplicas = 3 + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) 
+ env.RunTestInTwoModes(suite.checkTransferRegionWithPlacementRule) +} + func (suite *operatorTestSuite) checkTransferRegionWithPlacementRule(cluster *tests.TestCluster) { re := suite.Require() suite.pauseRuleChecker(cluster) diff --git a/tests/server/api/scheduler_test.go b/tests/server/api/scheduler_test.go index 9db94e8562d..38f691a4eda 100644 --- a/tests/server/api/scheduler_test.go +++ b/tests/server/api/scheduler_test.go @@ -42,13 +42,9 @@ func TestScheduleTestSuite(t *testing.T) { suite.Run(t, new(scheduleTestSuite)) } -func (suite *scheduleTestSuite) TestScheduler() { +func (suite *scheduleTestSuite) TestOriginAPI() { env := tests.NewSchedulingTestEnvironment(suite.T()) env.RunTestInTwoModes(suite.checkOriginAPI) - env = tests.NewSchedulingTestEnvironment(suite.T()) - env.RunTestInTwoModes(suite.checkAPI) - env = tests.NewSchedulingTestEnvironment(suite.T()) - env.RunTestInTwoModes(suite.checkDisable) } func (suite *scheduleTestSuite) checkOriginAPI(cluster *tests.TestCluster) { @@ -115,6 +111,11 @@ func (suite *scheduleTestSuite) checkOriginAPI(cluster *tests.TestCluster) { suite.NoError(err) } +func (suite *scheduleTestSuite) TestAPI() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkAPI) +} + func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { re := suite.Require() leaderAddr := cluster.GetLeaderServer().GetAddr() @@ -153,9 +154,12 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { body, err := json.Marshal(dataMap) suite.NoError(err) suite.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) - resp = make(map[string]interface{}) - suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) - suite.Equal(3.0, resp["batch"]) + tu.Eventually(re, func() bool { // wait for scheduling server to be synced. + resp = make(map[string]interface{}) + suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + return resp["batch"] == 3.0 + }) + // update again err = tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re), @@ -556,6 +560,11 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { } } +func (suite *scheduleTestSuite) TestDisable() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkDisable) +} + func (suite *scheduleTestSuite) checkDisable(cluster *tests.TestCluster) { re := suite.Require() leaderAddr := cluster.GetLeaderServer().GetAddr() diff --git a/tests/server/config/config_test.go b/tests/server/config/config_test.go index 1b2178bde33..8d8cf40e692 100644 --- a/tests/server/config/config_test.go +++ b/tests/server/config/config_test.go @@ -18,17 +18,25 @@ import ( "bytes" "context" "encoding/json" + "fmt" "net/http" "testing" + "time" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/ratelimit" + sc "github.com/tikv/pd/pkg/schedule/config" + tu "github.com/tikv/pd/pkg/utils/testutil" + "github.com/tikv/pd/pkg/utils/typeutil" + "github.com/tikv/pd/pkg/versioninfo" "github.com/tikv/pd/server" + "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" ) -// dialClient used to dial http request. -var dialClient = &http.Client{ +// testDialClient used to dial http request. 
+var testDialClient = &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, }, @@ -56,7 +64,7 @@ func TestRateLimitConfigReload(t *testing.T) { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := testDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) @@ -74,3 +82,520 @@ func TestRateLimitConfigReload(t *testing.T) { re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) re.Len(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, 1) } + +type configTestSuite struct { + suite.Suite +} + +func TestConfigTestSuite(t *testing.T) { + suite.Run(t, new(configTestSuite)) +} + +func (suite *configTestSuite) TestConfigAll() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkConfigAll) +} + +func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + urlPrefix := leaderServer.GetAddr() + + addr := fmt.Sprintf("%s/pd/api/v1/config", urlPrefix) + cfg := &config.Config{} + tu.Eventually(re, func() bool { + err := tu.ReadGetJSON(re, testDialClient, addr, cfg) + suite.NoError(err) + return cfg.PDServerCfg.DashboardAddress != "auto" + }) + + // the original way + r := map[string]int{"max-replicas": 5} + postData, err := json.Marshal(r) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + l := map[string]interface{}{ + "location-labels": "zone,rack", + "region-schedule-limit": 10, + } + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + l = map[string]interface{}{ + "metric-storage": "http://127.0.0.1:9090", + } + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + newCfg := &config.Config{} + err = tu.ReadGetJSON(re, testDialClient, addr, newCfg) + suite.NoError(err) + cfg.Replication.MaxReplicas = 5 + cfg.Replication.LocationLabels = []string{"zone", "rack"} + cfg.Schedule.RegionScheduleLimit = 10 + cfg.PDServerCfg.MetricStorage = "http://127.0.0.1:9090" + suite.Equal(newCfg, cfg) + + // the new way + l = map[string]interface{}{ + "schedule.tolerant-size-ratio": 2.5, + "schedule.enable-tikv-split-region": "false", + "replication.location-labels": "idc,host", + "pd-server.metric-storage": "http://127.0.0.1:1234", + "log.level": "warn", + "cluster-version": "v4.0.0-beta", + "replication-mode.replication-mode": "dr-auto-sync", + "replication-mode.dr-auto-sync.label-key": "foobar", + } + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + newCfg1 := &config.Config{} + err = tu.ReadGetJSON(re, testDialClient, addr, newCfg1) + suite.NoError(err) + cfg.Schedule.EnableTiKVSplitRegion = false + cfg.Schedule.TolerantSizeRatio = 2.5 + cfg.Replication.LocationLabels = []string{"idc", "host"} + cfg.PDServerCfg.MetricStorage = "http://127.0.0.1:1234" + cfg.Log.Level = "warn" + cfg.ReplicationMode.DRAutoSync.LabelKey = "foobar" + cfg.ReplicationMode.ReplicationMode = "dr-auto-sync" + v, err := 
versioninfo.ParseVersion("v4.0.0-beta") + suite.NoError(err) + cfg.ClusterVersion = *v + suite.Equal(cfg, newCfg1) + + // revert this to avoid it affects TestConfigTTL + l["schedule.enable-tikv-split-region"] = "true" + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + // illegal prefix + l = map[string]interface{}{ + "replicate.max-replicas": 1, + } + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, + tu.StatusNotOK(re), + tu.StringContain(re, "not found")) + suite.NoError(err) + + // update prefix directly + l = map[string]interface{}{ + "replication-mode": nil, + } + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, + tu.StatusNotOK(re), + tu.StringContain(re, "cannot update config prefix")) + suite.NoError(err) + + // config item not found + l = map[string]interface{}{ + "schedule.region-limit": 10, + } + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) + suite.NoError(err) +} + +func (suite *configTestSuite) TestConfigSchedule() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkConfigSchedule) +} + +func (suite *configTestSuite) checkConfigSchedule(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + urlPrefix := leaderServer.GetAddr() + + addr := fmt.Sprintf("%s/pd/api/v1/config/schedule", urlPrefix) + + scheduleConfig := &sc.ScheduleConfig{} + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, scheduleConfig)) + scheduleConfig.MaxStoreDownTime.Duration = time.Second + postData, err := json.Marshal(scheduleConfig) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + scheduleConfig1 := &sc.ScheduleConfig{} + suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, scheduleConfig1)) + suite.Equal(*scheduleConfig1, *scheduleConfig) +} + +func (suite *configTestSuite) TestConfigReplication() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkConfigReplication) +} + +func (suite *configTestSuite) checkConfigReplication(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + urlPrefix := leaderServer.GetAddr() + + addr := fmt.Sprintf("%s/pd/api/v1/config/replicate", urlPrefix) + rc := &sc.ReplicationConfig{} + err := tu.ReadGetJSON(re, testDialClient, addr, rc) + suite.NoError(err) + + rc.MaxReplicas = 5 + rc1 := map[string]int{"max-replicas": 5} + postData, err := json.Marshal(rc1) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + rc.LocationLabels = []string{"zone", "rack"} + rc2 := map[string]string{"location-labels": "zone,rack"} + postData, err = json.Marshal(rc2) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + rc.IsolationLevel = "zone" + rc3 := map[string]string{"isolation-level": "zone"} + postData, err = json.Marshal(rc3) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + rc4 := &sc.ReplicationConfig{} + err = tu.ReadGetJSON(re, testDialClient, addr, rc4) + suite.NoError(err) + + suite.Equal(*rc4, *rc) 
+} + +func (suite *configTestSuite) TestConfigLabelProperty() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkConfigLabelProperty) +} + +func (suite *configTestSuite) checkConfigLabelProperty(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + urlPrefix := leaderServer.GetAddr() + + addr := urlPrefix + "/pd/api/v1/config/label-property" + loadProperties := func() config.LabelPropertyConfig { + var cfg config.LabelPropertyConfig + err := tu.ReadGetJSON(re, testDialClient, addr, &cfg) + suite.NoError(err) + return cfg + } + + cfg := loadProperties() + suite.Empty(cfg) + + cmds := []string{ + `{"type": "foo", "action": "set", "label-key": "zone", "label-value": "cn1"}`, + `{"type": "foo", "action": "set", "label-key": "zone", "label-value": "cn2"}`, + `{"type": "bar", "action": "set", "label-key": "host", "label-value": "h1"}`, + } + for _, cmd := range cmds { + err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) + suite.NoError(err) + } + + cfg = loadProperties() + suite.Len(cfg, 2) + suite.Equal([]config.StoreLabel{ + {Key: "zone", Value: "cn1"}, + {Key: "zone", Value: "cn2"}, + }, cfg["foo"]) + suite.Equal([]config.StoreLabel{{Key: "host", Value: "h1"}}, cfg["bar"]) + + cmds = []string{ + `{"type": "foo", "action": "delete", "label-key": "zone", "label-value": "cn1"}`, + `{"type": "bar", "action": "delete", "label-key": "host", "label-value": "h1"}`, + } + for _, cmd := range cmds { + err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) + suite.NoError(err) + } + + cfg = loadProperties() + suite.Len(cfg, 1) + suite.Equal([]config.StoreLabel{{Key: "zone", Value: "cn2"}}, cfg["foo"]) +} + +func (suite *configTestSuite) TestConfigDefault() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkConfigDefault) +} + +func (suite *configTestSuite) checkConfigDefault(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + urlPrefix := leaderServer.GetAddr() + + addr := urlPrefix + "/pd/api/v1/config" + + r := map[string]int{"max-replicas": 5} + postData, err := json.Marshal(r) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + l := map[string]interface{}{ + "location-labels": "zone,rack", + "region-schedule-limit": 10, + } + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + l = map[string]interface{}{ + "metric-storage": "http://127.0.0.1:9090", + } + postData, err = json.Marshal(l) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + + addr = fmt.Sprintf("%s/pd/api/v1/config/default", urlPrefix) + defaultCfg := &config.Config{} + err = tu.ReadGetJSON(re, testDialClient, addr, defaultCfg) + suite.NoError(err) + + suite.Equal(uint64(3), defaultCfg.Replication.MaxReplicas) + suite.Equal(typeutil.StringSlice([]string{}), defaultCfg.Replication.LocationLabels) + suite.Equal(uint64(2048), defaultCfg.Schedule.RegionScheduleLimit) + suite.Equal("", defaultCfg.PDServerCfg.MetricStorage) +} + +func (suite *configTestSuite) TestConfigPDServer() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + env.RunTestInTwoModes(suite.checkConfigPDServer) +} + +func (suite *configTestSuite) checkConfigPDServer(cluster *tests.TestCluster) { + re := suite.Require() + 
leaderServer := cluster.GetLeaderServer() + urlPrefix := leaderServer.GetAddr() + + addrPost := urlPrefix + "/pd/api/v1/config" + ms := map[string]interface{}{ + "metric-storage": "", + } + postData, err := json.Marshal(ms) + suite.NoError(err) + suite.NoError(tu.CheckPostJSON(testDialClient, addrPost, postData, tu.StatusOK(re))) + addrGet := fmt.Sprintf("%s/pd/api/v1/config/pd-server", urlPrefix) + sc := &config.PDServerConfig{} + suite.NoError(tu.ReadGetJSON(re, testDialClient, addrGet, sc)) + suite.Equal(bool(true), sc.UseRegionStorage) + suite.Equal("table", sc.KeyType) + suite.Equal(typeutil.StringSlice([]string{}), sc.RuntimeServices) + suite.Equal("", sc.MetricStorage) + suite.Equal("auto", sc.DashboardAddress) + suite.Equal(int(3), sc.FlowRoundByDigit) + suite.Equal(typeutil.NewDuration(time.Second), sc.MinResolvedTSPersistenceInterval) + suite.Equal(24*time.Hour, sc.MaxResetTSGap.Duration) +} + +var ttlConfig = map[string]interface{}{ + "schedule.max-snapshot-count": 999, + "schedule.enable-location-replacement": false, + "schedule.max-merge-region-size": 999, + "schedule.max-merge-region-keys": 999, + "schedule.scheduler-max-waiting-operator": 999, + "schedule.leader-schedule-limit": 999, + "schedule.region-schedule-limit": 999, + "schedule.hot-region-schedule-limit": 999, + "schedule.replica-schedule-limit": 999, + "schedule.merge-schedule-limit": 999, + "schedule.enable-tikv-split-region": false, +} + +var invalidTTLConfig = map[string]interface{}{ + "schedule.invalid-ttl-config": 0, +} + +type ttlConfigInterface interface { + GetMaxSnapshotCount() uint64 + IsLocationReplacementEnabled() bool + GetMaxMergeRegionSize() uint64 + GetMaxMergeRegionKeys() uint64 + GetSchedulerMaxWaitingOperator() uint64 + GetLeaderScheduleLimit() uint64 + GetRegionScheduleLimit() uint64 + GetHotRegionScheduleLimit() uint64 + GetReplicaScheduleLimit() uint64 + GetMergeScheduleLimit() uint64 + IsTikvRegionSplitEnabled() bool +} + +func (suite *configTestSuite) assertTTLConfig( + cluster *tests.TestCluster, + expectedEqual bool, +) { + equality := suite.Equal + if !expectedEqual { + equality = suite.NotEqual + } + checkfunc := func(options ttlConfigInterface) { + equality(uint64(999), options.GetMaxSnapshotCount()) + equality(false, options.IsLocationReplacementEnabled()) + equality(uint64(999), options.GetMaxMergeRegionSize()) + equality(uint64(999), options.GetMaxMergeRegionKeys()) + equality(uint64(999), options.GetSchedulerMaxWaitingOperator()) + equality(uint64(999), options.GetLeaderScheduleLimit()) + equality(uint64(999), options.GetRegionScheduleLimit()) + equality(uint64(999), options.GetHotRegionScheduleLimit()) + equality(uint64(999), options.GetReplicaScheduleLimit()) + equality(uint64(999), options.GetMergeScheduleLimit()) + equality(false, options.IsTikvRegionSplitEnabled()) + } + checkfunc(cluster.GetLeaderServer().GetServer().GetPersistOptions()) + if cluster.GetSchedulingPrimaryServer() != nil { + // wait for the scheduling primary server to be synced + options := cluster.GetSchedulingPrimaryServer().GetPersistConfig() + tu.Eventually(suite.Require(), func() bool { + if expectedEqual { + return uint64(999) == options.GetMaxSnapshotCount() + } + return uint64(999) != options.GetMaxSnapshotCount() + }) + checkfunc(options) + } +} + +func (suite *configTestSuite) assertTTLConfigItemEqaul( + cluster *tests.TestCluster, + item string, + expectedValue interface{}, +) { + checkfunc := func(options ttlConfigInterface) bool { + switch item { + case "max-merge-region-size": + return 
expectedValue.(uint64) == options.GetMaxMergeRegionSize() + case "max-merge-region-keys": + return expectedValue.(uint64) == options.GetMaxMergeRegionKeys() + case "enable-tikv-split-region": + return expectedValue.(bool) == options.IsTikvRegionSplitEnabled() + } + return false + } + suite.True(checkfunc(cluster.GetLeaderServer().GetServer().GetPersistOptions())) + if cluster.GetSchedulingPrimaryServer() != nil { + // wait for the scheduling primary server to be synced + tu.Eventually(suite.Require(), func() bool { + return checkfunc(cluster.GetSchedulingPrimaryServer().GetPersistConfig()) + }) + } +} + +func createTTLUrl(url string, ttl int) string { + return fmt.Sprintf("%s/pd/api/v1/config?ttlSecond=%d", url, ttl) +} + +func (suite *configTestSuite) TestConfigTTL() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + // FIXME: enable this test in two modes after ttl config is supported. + env.RunTestInPDMode(suite.checkConfigTTL) +} + +func (suite *configTestSuite) checkConfigTTL(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + urlPrefix := leaderServer.GetAddr() + postData, err := json.Marshal(ttlConfig) + suite.NoError(err) + + // test no config and cleaning up + err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + suite.NoError(err) + suite.assertTTLConfig(cluster, false) + + // test time goes by + err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusOK(re)) + suite.NoError(err) + suite.assertTTLConfig(cluster, true) + time.Sleep(2 * time.Second) + suite.assertTTLConfig(cluster, false) + + // test cleaning up + err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusOK(re)) + suite.NoError(err) + suite.assertTTLConfig(cluster, true) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + suite.NoError(err) + suite.assertTTLConfig(cluster, false) + + postData, err = json.Marshal(invalidTTLConfig) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 1), postData, + tu.StatusNotOK(re), tu.StringEqual(re, "\"unsupported ttl config schedule.invalid-ttl-config\"\n")) + suite.NoError(err) + + // only set max-merge-region-size + mergeConfig := map[string]interface{}{ + "schedule.max-merge-region-size": 999, + } + postData, err = json.Marshal(mergeConfig) + suite.NoError(err) + + err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusOK(re)) + suite.NoError(err) + suite.assertTTLConfigItemEqaul(cluster, "max-merge-region-size", uint64(999)) + // max-merge-region-keys should keep consistence with max-merge-region-size. + suite.assertTTLConfigItemEqaul(cluster, "max-merge-region-keys", uint64(999*10000)) + + // on invalid value, we use default config + mergeConfig = map[string]interface{}{ + "schedule.enable-tikv-split-region": "invalid", + } + postData, err = json.Marshal(mergeConfig) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 10), postData, tu.StatusOK(re)) + suite.NoError(err) + suite.assertTTLConfigItemEqaul(cluster, "enable-tikv-split-region", true) +} + +func (suite *configTestSuite) TestTTLConflict() { + env := tests.NewSchedulingTestEnvironment(suite.T()) + // FIXME: enable this test in two modes after ttl config is supported. 
+ env.RunTestInPDMode(suite.checkTTLConflict) +} + +func (suite *configTestSuite) checkTTLConflict(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + urlPrefix := leaderServer.GetAddr() + addr := createTTLUrl(urlPrefix, 1) + postData, err := json.Marshal(ttlConfig) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) + suite.assertTTLConfig(cluster, true) + + cfg := map[string]interface{}{"max-snapshot-count": 30} + postData, err = json.Marshal(cfg) + suite.NoError(err) + addr = fmt.Sprintf("%s/pd/api/v1/config", urlPrefix) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) + suite.NoError(err) + addr = fmt.Sprintf("%s/pd/api/v1/config/schedule", urlPrefix) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) + suite.NoError(err) + cfg = map[string]interface{}{"schedule.max-snapshot-count": 30} + postData, err = json.Marshal(cfg) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + suite.NoError(err) +} From d651c6b91f2cbd0b22ae87a22c06317cdee12462 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 8 Nov 2023 11:18:42 +0800 Subject: [PATCH 07/26] core: batch get region size (#7252) close tikv/pd#7248 Signed-off-by: nolouch Co-authored-by: nolouch Co-authored-by: ShuNing --- pkg/core/region.go | 36 ++++++++--- pkg/core/region_test.go | 120 +++++++++++++++++++++++++++++++++++ server/cluster/cluster.go | 3 +- tests/server/api/api_test.go | 6 +- 4 files changed, 152 insertions(+), 13 deletions(-) diff --git a/pkg/core/region.go b/pkg/core/region.go index 2ac323a1272..c9daa69c477 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -41,7 +41,10 @@ import ( "go.uber.org/zap" ) -const randomRegionMaxRetry = 10 +const ( + randomRegionMaxRetry = 10 + scanRegionLimit = 1000 +) // errRegionIsStale is error info for region is stale. func errRegionIsStale(region *metapb.Region, origin *metapb.Region) error { @@ -1610,16 +1613,31 @@ func (r *RegionsInfo) ScanRegionWithIterator(startKey []byte, iterator func(regi // GetRegionSizeByRange scans regions intersecting [start key, end key), returns the total region size of this range. 
func (r *RegionsInfo) GetRegionSizeByRange(startKey, endKey []byte) int64 { - r.t.RLock() - defer r.t.RUnlock() var size int64 - r.tree.scanRange(startKey, func(region *RegionInfo) bool { - if len(endKey) > 0 && bytes.Compare(region.GetStartKey(), endKey) >= 0 { - return false + for { + r.t.RLock() + var cnt int + r.tree.scanRange(startKey, func(region *RegionInfo) bool { + if len(endKey) > 0 && bytes.Compare(region.GetStartKey(), endKey) >= 0 { + return false + } + if cnt >= scanRegionLimit { + return false + } + cnt++ + startKey = region.GetEndKey() + size += region.GetApproximateSize() + return true + }) + r.t.RUnlock() + if cnt == 0 { + break } - size += region.GetApproximateSize() - return true - }) + if len(startKey) == 0 { + break + } + } + return size } diff --git a/pkg/core/region_test.go b/pkg/core/region_test.go index 50302de920e..508e7aa59aa 100644 --- a/pkg/core/region_test.go +++ b/pkg/core/region_test.go @@ -18,8 +18,10 @@ import ( "crypto/rand" "fmt" "math" + mrand "math/rand" "strconv" "testing" + "time" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" @@ -658,6 +660,124 @@ func BenchmarkRandomRegion(b *testing.B) { } } +func BenchmarkRandomSetRegion(b *testing.B) { + regions := NewRegionsInfo() + var items []*RegionInfo + for i := 0; i < 1000000; i++ { + peer := &metapb.Peer{StoreId: 1, Id: uint64(i + 1)} + region := NewRegionInfo(&metapb.Region{ + Id: uint64(i + 1), + Peers: []*metapb.Peer{peer}, + StartKey: []byte(fmt.Sprintf("%20d", i)), + EndKey: []byte(fmt.Sprintf("%20d", i+1)), + }, peer) + origin, overlaps, rangeChanged := regions.SetRegion(region) + regions.UpdateSubTree(region, origin, overlaps, rangeChanged) + items = append(items, region) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + item := items[i%len(items)] + item.approximateKeys = int64(200000) + item.approximateSize = int64(20) + origin, overlaps, rangeChanged := regions.SetRegion(item) + regions.UpdateSubTree(item, origin, overlaps, rangeChanged) + } +} + +func TestGetRegionSizeByRange(t *testing.T) { + regions := NewRegionsInfo() + nums := 1000010 + for i := 0; i < nums; i++ { + peer := &metapb.Peer{StoreId: 1, Id: uint64(i + 1)} + endKey := []byte(fmt.Sprintf("%20d", i+1)) + if i == nums-1 { + endKey = []byte("") + } + region := NewRegionInfo(&metapb.Region{ + Id: uint64(i + 1), + Peers: []*metapb.Peer{peer}, + StartKey: []byte(fmt.Sprintf("%20d", i)), + EndKey: endKey, + }, peer, SetApproximateSize(10)) + origin, overlaps, rangeChanged := regions.SetRegion(region) + regions.UpdateSubTree(region, origin, overlaps, rangeChanged) + } + totalSize := regions.GetRegionSizeByRange([]byte(""), []byte("")) + require.Equal(t, int64(nums*10), totalSize) + for i := 1; i < 10; i++ { + verifyNum := nums / i + endKey := fmt.Sprintf("%20d", verifyNum) + totalSize := regions.GetRegionSizeByRange([]byte(""), []byte(endKey)) + require.Equal(t, int64(verifyNum*10), totalSize) + } +} + +func BenchmarkRandomSetRegionWithGetRegionSizeByRange(b *testing.B) { + regions := NewRegionsInfo() + var items []*RegionInfo + for i := 0; i < 1000000; i++ { + peer := &metapb.Peer{StoreId: 1, Id: uint64(i + 1)} + region := NewRegionInfo(&metapb.Region{ + Id: uint64(i + 1), + Peers: []*metapb.Peer{peer}, + StartKey: []byte(fmt.Sprintf("%20d", i)), + EndKey: []byte(fmt.Sprintf("%20d", i+1)), + }, peer, SetApproximateSize(10)) + origin, overlaps, rangeChanged := regions.SetRegion(region) + regions.UpdateSubTree(region, origin, overlaps, rangeChanged) + items = append(items, region) + } + b.ResetTimer() + go 
func() { + for { + regions.GetRegionSizeByRange([]byte(""), []byte("")) + time.Sleep(time.Millisecond) + } + }() + for i := 0; i < b.N; i++ { + item := items[i%len(items)] + item.approximateKeys = int64(200000) + origin, overlaps, rangeChanged := regions.SetRegion(item) + regions.UpdateSubTree(item, origin, overlaps, rangeChanged) + } +} + +func BenchmarkRandomSetRegionWithGetRegionSizeByRangeParallel(b *testing.B) { + regions := NewRegionsInfo() + var items []*RegionInfo + for i := 0; i < 1000000; i++ { + peer := &metapb.Peer{StoreId: 1, Id: uint64(i + 1)} + region := NewRegionInfo(&metapb.Region{ + Id: uint64(i + 1), + Peers: []*metapb.Peer{peer}, + StartKey: []byte(fmt.Sprintf("%20d", i)), + EndKey: []byte(fmt.Sprintf("%20d", i+1)), + }, peer) + origin, overlaps, rangeChanged := regions.SetRegion(region) + regions.UpdateSubTree(region, origin, overlaps, rangeChanged) + items = append(items, region) + } + b.ResetTimer() + go func() { + for { + regions.GetRegionSizeByRange([]byte(""), []byte("")) + time.Sleep(time.Millisecond) + } + }() + + b.RunParallel( + func(pb *testing.PB) { + for pb.Next() { + item := items[mrand.Intn(len(items))] + n := item.Clone(SetApproximateSize(20)) + origin, overlaps, rangeChanged := regions.SetRegion(n) + regions.UpdateSubTree(item, origin, overlaps, rangeChanged) + } + }, + ) +} + const keyLength = 100 func randomBytes(n int) []byte { diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 25a47a7fca9..8362ee9f331 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -1846,12 +1846,13 @@ func (c *RaftCluster) checkStores() { if err := c.ReadyToServe(storeID); err != nil { log.Error("change store to serving failed", zap.Stringer("store", store.GetMeta()), + zap.Int("region-count", c.GetTotalRegionCount()), errs.ZapError(err)) } } else if c.IsPrepared() { threshold := c.getThreshold(stores, store) - log.Debug("store serving threshold", zap.Uint64("store-id", storeID), zap.Float64("threshold", threshold)) regionSize := float64(store.GetRegionSize()) + log.Debug("store serving threshold", zap.Uint64("store-id", storeID), zap.Float64("threshold", threshold), zap.Float64("region-size", regionSize)) if regionSize >= threshold { if err := c.ReadyToServe(storeID); err != nil { log.Error("change store to serving failed", diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index ff430f1b848..04bcdc0d461 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -914,7 +914,7 @@ func TestPreparingProgress(t *testing.T) { tests.MustPutStore(re, cluster, store) } for i := 0; i < 100; i++ { - tests.MustPutRegion(re, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("p%d", i)), []byte(fmt.Sprintf("%d", i+1)), core.SetApproximateSize(10)) + tests.MustPutRegion(re, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("%20d", i)), []byte(fmt.Sprintf("%20d", i+1)), core.SetApproximateSize(10)) } // no store preparing output := sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusNotFound) @@ -941,8 +941,8 @@ func TestPreparingProgress(t *testing.T) { re.Equal(math.MaxFloat64, p.LeftSeconds) // update size - tests.MustPutRegion(re, cluster, 1000, 4, []byte(fmt.Sprintf("%d", 1000)), []byte(fmt.Sprintf("%d", 1001)), core.SetApproximateSize(10)) - tests.MustPutRegion(re, cluster, 1001, 5, []byte(fmt.Sprintf("%d", 1001)), []byte(fmt.Sprintf("%d", 1002)), core.SetApproximateSize(40)) + tests.MustPutRegion(re, cluster, 1000, 4, 
[]byte(fmt.Sprintf("%20d", 1000)), []byte(fmt.Sprintf("%20d", 1001)), core.SetApproximateSize(10)) + tests.MustPutRegion(re, cluster, 1001, 5, []byte(fmt.Sprintf("%20d", 1001)), []byte(fmt.Sprintf("%20d", 1002)), core.SetApproximateSize(40)) time.Sleep(2 * time.Second) output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusOK) re.NoError(json.Unmarshal(output, &p)) From 4457ac2717644a39a1ccfaeeb5cfb7ecd0542e99 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Wed, 8 Nov 2023 12:14:13 +0800 Subject: [PATCH 08/26] mcs/scheduling: fix typo (#7333) ref tikv/pd#5839 Signed-off-by: Cabinfever_B Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/scheduling/server/grpc_service.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/mcs/scheduling/server/grpc_service.go b/pkg/mcs/scheduling/server/grpc_service.go index 79c5c293ee7..b865e917d75 100644 --- a/pkg/mcs/scheduling/server/grpc_service.go +++ b/pkg/mcs/scheduling/server/grpc_service.go @@ -65,7 +65,7 @@ type Service struct { *Server } -// NewService creates a new TSO service. +// NewService creates a new scheduling service. func NewService[T ConfigProvider](svr bs.Server) registry.RegistrableService { server, ok := svr.(*Server) if !ok { @@ -118,7 +118,7 @@ func (s *heartbeatServer) Recv() (*schedulingpb.RegionHeartbeatRequest, error) { return req, nil } -// RegionHeartbeat implements gRPC PDServer. +// RegionHeartbeat implements gRPC SchedulingServer. func (s *Service) RegionHeartbeat(stream schedulingpb.Scheduling_RegionHeartbeatServer) error { var ( server = &heartbeatServer{stream: stream} @@ -168,7 +168,7 @@ func (s *Service) RegionHeartbeat(stream schedulingpb.Scheduling_RegionHeartbeat } } -// StoreHeartbeat implements gRPC PDServer. +// StoreHeartbeat implements gRPC SchedulingServer. func (s *Service) StoreHeartbeat(ctx context.Context, request *schedulingpb.StoreHeartbeatRequest) (*schedulingpb.StoreHeartbeatResponse, error) { c := s.GetCluster() if c == nil { @@ -202,7 +202,7 @@ func (s *Service) SplitRegions(ctx context.Context, request *schedulingpb.SplitR }, nil } -// ScatterRegions implements gRPC PDServer. +// ScatterRegions implements gRPC SchedulingServer. func (s *Service) ScatterRegions(ctx context.Context, request *schedulingpb.ScatterRegionsRequest) (*schedulingpb.ScatterRegionsResponse, error) { c := s.GetCluster() if c == nil { @@ -261,7 +261,7 @@ func (s *Service) GetOperator(ctx context.Context, request *schedulingpb.GetOper }, nil } -// AskBatchSplit implements gRPC PDServer. +// AskBatchSplit implements gRPC SchedulingServer. 
func (s *Service) AskBatchSplit(ctx context.Context, request *schedulingpb.AskBatchSplitRequest) (*schedulingpb.AskBatchSplitResponse, error) { c := s.GetCluster() if c == nil { From a98295c22490b9fdb0e2cb052b68691a9d56f6dc Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 8 Nov 2023 16:05:14 +0800 Subject: [PATCH 09/26] mcs: fix participant name (#7335) close tikv/pd#7336 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/resourcemanager/server/config.go | 20 ++++++++++++++++++++ pkg/mcs/resourcemanager/server/server.go | 6 +++--- pkg/mcs/scheduling/server/config/config.go | 15 +++++++++++++++ pkg/mcs/scheduling/server/server.go | 8 ++++---- 4 files changed, 42 insertions(+), 7 deletions(-) diff --git a/pkg/mcs/resourcemanager/server/config.go b/pkg/mcs/resourcemanager/server/config.go index 3f64b2987fd..10e91612842 100644 --- a/pkg/mcs/resourcemanager/server/config.go +++ b/pkg/mcs/resourcemanager/server/config.go @@ -250,6 +250,26 @@ func (c *Config) adjustLog(meta *configutil.ConfigMetaData) { } } +// GetName returns the Name +func (c *Config) GetName() string { + return c.Name +} + +// GeBackendEndpoints returns the BackendEndpoints +func (c *Config) GeBackendEndpoints() string { + return c.BackendEndpoints +} + +// GetListenAddr returns the ListenAddr +func (c *Config) GetListenAddr() string { + return c.ListenAddr +} + +// GetAdvertiseListenAddr returns the AdvertiseListenAddr +func (c *Config) GetAdvertiseListenAddr() string { + return c.AdvertiseListenAddr +} + // GetTLSConfig returns the TLS config. func (c *Config) GetTLSConfig() *grpcutil.TLSConfig { return &c.Security.TLSConfig diff --git a/pkg/mcs/resourcemanager/server/server.go b/pkg/mcs/resourcemanager/server/server.go index 47248208c8a..7b660c07605 100644 --- a/pkg/mcs/resourcemanager/server/server.go +++ b/pkg/mcs/resourcemanager/server/server.go @@ -296,14 +296,14 @@ func (s *Server) startServer() (err error) { // different service modes provided by the same pd-server binary serverInfo.WithLabelValues(versioninfo.PDReleaseVersion, versioninfo.PDGitHash).Set(float64(time.Now().Unix())) - uniqueName := s.cfg.ListenAddr + uniqueName := s.cfg.GetAdvertiseListenAddr() uniqueID := memberutil.GenerateUniqueID(uniqueName) log.Info("joining primary election", zap.String("participant-name", uniqueName), zap.Uint64("participant-id", uniqueID)) s.participant = member.NewParticipant(s.GetClient(), utils.ResourceManagerServiceName) p := &resource_manager.Participant{ Name: uniqueName, Id: uniqueID, // id is unique among all participants - ListenUrls: []string{s.cfg.AdvertiseListenAddr}, + ListenUrls: []string{s.cfg.GetAdvertiseListenAddr()}, } s.participant.InitInfo(p, endpoint.ResourceManagerSvcRootPath(s.clusterID), utils.PrimaryKey, "primary election") @@ -312,7 +312,7 @@ func (s *Server) startServer() (err error) { manager: NewManager[*Server](s), } - if err := s.InitListener(s.GetTLSConfig(), s.cfg.ListenAddr); err != nil { + if err := s.InitListener(s.GetTLSConfig(), s.cfg.GetListenAddr()); err != nil { return err } diff --git a/pkg/mcs/scheduling/server/config/config.go b/pkg/mcs/scheduling/server/config/config.go index 772eab835f1..a211c989c64 100644 --- a/pkg/mcs/scheduling/server/config/config.go +++ b/pkg/mcs/scheduling/server/config/config.go @@ -164,11 +164,26 @@ func (c *Config) adjustLog(meta *configutil.ConfigMetaData) { } } +// GetName returns the Name +func (c *Config) GetName() string { + return c.Name +} + +// GeBackendEndpoints returns the 
BackendEndpoints +func (c *Config) GeBackendEndpoints() string { + return c.BackendEndpoints +} + // GetListenAddr returns the ListenAddr func (c *Config) GetListenAddr() string { return c.ListenAddr } +// GetAdvertiseListenAddr returns the AdvertiseListenAddr +func (c *Config) GetAdvertiseListenAddr() string { + return c.AdvertiseListenAddr +} + // GetTLSConfig returns the TLS config. func (c *Config) GetTLSConfig() *grpcutil.TLSConfig { return &c.Security.TLSConfig diff --git a/pkg/mcs/scheduling/server/server.go b/pkg/mcs/scheduling/server/server.go index 1790cb2b4be..4304ffb218a 100644 --- a/pkg/mcs/scheduling/server/server.go +++ b/pkg/mcs/scheduling/server/server.go @@ -405,21 +405,21 @@ func (s *Server) startServer() (err error) { // different service modes provided by the same pd-server binary serverInfo.WithLabelValues(versioninfo.PDReleaseVersion, versioninfo.PDGitHash).Set(float64(time.Now().Unix())) - uniqueName := s.cfg.ListenAddr + uniqueName := s.cfg.GetAdvertiseListenAddr() uniqueID := memberutil.GenerateUniqueID(uniqueName) log.Info("joining primary election", zap.String("participant-name", uniqueName), zap.Uint64("participant-id", uniqueID)) s.participant = member.NewParticipant(s.GetClient(), utils.SchedulingServiceName) p := &schedulingpb.Participant{ Name: uniqueName, Id: uniqueID, // id is unique among all participants - ListenUrls: []string{s.cfg.AdvertiseListenAddr}, + ListenUrls: []string{s.cfg.GetAdvertiseListenAddr()}, } s.participant.InitInfo(p, endpoint.SchedulingSvcRootPath(s.clusterID), utils.PrimaryKey, "primary election") s.service = &Service{Server: s} s.AddServiceReadyCallback(s.startCluster) s.AddServiceExitCallback(s.stopCluster) - if err := s.InitListener(s.GetTLSConfig(), s.cfg.ListenAddr); err != nil { + if err := s.InitListener(s.GetTLSConfig(), s.cfg.GetListenAddr()); err != nil { return err } @@ -443,7 +443,7 @@ func (s *Server) startServer() (err error) { return err } s.serviceRegister = discovery.NewServiceRegister(s.Context(), s.GetClient(), strconv.FormatUint(s.clusterID, 10), - utils.SchedulingServiceName, s.cfg.AdvertiseListenAddr, serializedEntry, discovery.DefaultLeaseInSeconds) + utils.SchedulingServiceName, s.cfg.GetAdvertiseListenAddr(), serializedEntry, discovery.DefaultLeaseInSeconds) if err := s.serviceRegister.Register(); err != nil { log.Error("failed to register the service", zap.String("service-name", utils.SchedulingServiceName), errs.ZapError(err)) return err From d189a42894f5e2d957f9f4da0790cbe1468558da Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 8 Nov 2023 17:08:12 +0800 Subject: [PATCH 10/26] mcs: solve stream error when forward tso (#7327) close tikv/pd#7320 Signed-off-by: lhy1024 --- pkg/utils/grpcutil/grpcutil.go | 14 +++++++ server/gc_service.go | 11 +----- server/grpc_service.go | 71 +++++++++++++++++----------------- 3 files changed, 50 insertions(+), 46 deletions(-) diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go index ee9d85a4ee1..44d45ff4c70 100644 --- a/pkg/utils/grpcutil/grpcutil.go +++ b/pkg/utils/grpcutil/grpcutil.go @@ -18,7 +18,9 @@ import ( "context" "crypto/tls" "crypto/x509" + "io" "net/url" + "strings" "time" "github.com/pingcap/errors" @@ -28,6 +30,7 @@ import ( "go.etcd.io/etcd/pkg/transport" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" ) @@ -221,3 +224,14 @@ func CheckStream(ctx context.Context, cancel context.CancelFunc, done chan struc } <-done } + +// 
NeedRebuildConnection checks if the error is a connection error. +func NeedRebuildConnection(err error) bool { + return err == io.EOF || + strings.Contains(err.Error(), codes.Unavailable.String()) || // Unavailable indicates the service is currently unavailable. This is a most likely a transient condition. + strings.Contains(err.Error(), codes.DeadlineExceeded.String()) || // DeadlineExceeded means operation expired before completion. + strings.Contains(err.Error(), codes.Internal.String()) || // Internal errors. + strings.Contains(err.Error(), codes.Unknown.String()) || // Unknown error. + strings.Contains(err.Error(), codes.ResourceExhausted.String()) // ResourceExhausted is returned when either the client or the server has exhausted their resources. + // Besides, we don't need to rebuild the connection if the code is Canceled, which means the client cancelled the request. +} diff --git a/server/gc_service.go b/server/gc_service.go index d8a0158920d..90333654e5e 100644 --- a/server/gc_service.go +++ b/server/gc_service.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/storage/endpoint" - "github.com/tikv/pd/pkg/tso" "github.com/tikv/pd/pkg/utils/etcdutil" "github.com/tikv/pd/pkg/utils/tsoutil" "go.etcd.io/etcd/clientv3" @@ -107,15 +106,7 @@ func (s *GrpcServer) UpdateServiceSafePointV2(ctx context.Context, request *pdpb return rsp.(*pdpb.UpdateServiceSafePointV2Response), err } - var ( - nowTSO pdpb.Timestamp - err error - ) - if s.IsAPIServiceMode() { - nowTSO, err = s.getGlobalTSOFromTSOServer(ctx) - } else { - nowTSO, err = s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) - } + nowTSO, err := s.getGlobalTSO(ctx) if err != nil { return nil, err } diff --git a/server/grpc_service.go b/server/grpc_service.go index 4aa6dc5b1da..05ec38919cb 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -2002,15 +2002,7 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb return nil, err } } - var ( - nowTSO pdpb.Timestamp - err error - ) - if s.IsAPIServiceMode() { - nowTSO, err = s.getGlobalTSOFromTSOServer(ctx) - } else { - nowTSO, err = s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) - } + nowTSO, err := s.getGlobalTSO(ctx) if err != nil { return nil, err } @@ -2608,7 +2600,10 @@ func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient } } -func (s *GrpcServer) getGlobalTSOFromTSOServer(ctx context.Context) (pdpb.Timestamp, error) { +func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) { + if !s.IsAPIServiceMode() { + return s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) + } request := &tsopb.TsoRequest{ Header: &tsopb.RequestHeader{ ClusterId: s.clusterID, @@ -2622,9 +2617,28 @@ func (s *GrpcServer) getGlobalTSOFromTSOServer(ctx context.Context) (pdpb.Timest forwardStream tsopb.TSO_TsoClient ts *tsopb.TsoResponse err error + ok bool ) + handleStreamError := func(err error) (needRetry bool) { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.tsoPrimaryWatcher.ForceLoad() + log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + if grpcutil.NeedRebuildConnection(err) { + s.tsoClientPool.Lock() + delete(s.tsoClientPool.clients, forwardedHost) + s.tsoClientPool.Unlock() + log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + return false 
+ } for i := 0; i < maxRetryTimesRequestTSOServer; i++ { - forwardedHost, ok := s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) + if i > 0 { + time.Sleep(retryIntervalRequestTSOServer) + } + forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) if !ok || forwardedHost == "" { return pdpb.Timestamp{}, ErrNotFoundTSOAddr } @@ -2632,32 +2646,25 @@ func (s *GrpcServer) getGlobalTSOFromTSOServer(ctx context.Context) (pdpb.Timest if err != nil { return pdpb.Timestamp{}, err } - err := forwardStream.Send(request) + err = forwardStream.Send(request) if err != nil { - s.tsoClientPool.Lock() - delete(s.tsoClientPool.clients, forwardedHost) - s.tsoClientPool.Unlock() - continue + if needRetry := handleStreamError(err); needRetry { + continue + } + log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err } ts, err = forwardStream.Recv() if err != nil { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - s.tsoPrimaryWatcher.ForceLoad() - time.Sleep(retryIntervalRequestTSOServer) - continue - } - if strings.Contains(err.Error(), codes.Unavailable.String()) { - s.tsoClientPool.Lock() - delete(s.tsoClientPool.clients, forwardedHost) - s.tsoClientPool.Unlock() + if needRetry := handleStreamError(err); needRetry { continue } - log.Error("get global tso from tso service primary addr failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) return pdpb.Timestamp{}, err } return *ts.GetTimestamp(), nil } - log.Error("get global tso from tso service primary addr failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) + log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) return pdpb.Timestamp{}, err } @@ -2906,15 +2913,7 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set return rsp.(*pdpb.SetExternalTimestampResponse), nil } - var ( - nowTSO pdpb.Timestamp - err error - ) - if s.IsAPIServiceMode() { - nowTSO, err = s.getGlobalTSOFromTSOServer(ctx) - } else { - nowTSO, err = s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) - } + nowTSO, err := s.getGlobalTSO(ctx) if err != nil { return nil, err } From 2c07c241114fe9afabd9927ecbee61c4252f2d8e Mon Sep 17 00:00:00 2001 From: Sparkle <1284531+baurine@users.noreply.github.com> Date: Wed, 8 Nov 2023 18:18:42 +0800 Subject: [PATCH 11/26] chore(dashboard): update tidb dashboard verstion to v2023.11.08.1 (#7339) close tikv/pd#7340 Signed-off-by: baurine <2008.hbl@gmail.com> --- go.mod | 2 +- go.sum | 4 ++-- tests/integrations/client/go.mod | 2 +- tests/integrations/client/go.sum | 4 ++-- tests/integrations/mcs/go.mod | 2 +- tests/integrations/mcs/go.sum | 4 ++-- tests/integrations/tso/go.mod | 2 +- tests/integrations/tso/go.sum | 4 ++-- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index e8da2542be2..0306d70f7a3 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/pingcap/kvproto v0.0.0-20231018065736-c0689aded40c github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 - github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 + github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537 github.com/prometheus/client_golang v1.11.1 github.com/prometheus/common 
v0.26.0 github.com/sasha-s/go-deadlock v0.2.0 diff --git a/go.sum b/go.sum index 28e210ef1cd..fb178321864 100644 --- a/go.sum +++ b/go.sum @@ -446,8 +446,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 h1:xIeaDUq2ItkYMIgpWXAYKC/N3hs8aurfFvvz79lhHYE= -github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= +github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537 h1:wnHt7ETIB0vm+gbLx8QhcIEmRtrT4QlWlfpcI9vjxOk= +github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/tests/integrations/client/go.mod b/tests/integrations/client/go.mod index b9b868cf8e3..a4aca195f3f 100644 --- a/tests/integrations/client/go.mod +++ b/tests/integrations/client/go.mod @@ -119,7 +119,7 @@ require ( github.com/pingcap/errcode v0.3.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537 // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/client/go.sum b/tests/integrations/client/go.sum index 81fa6fd7b39..ef9c4d2a5f3 100644 --- a/tests/integrations/client/go.sum +++ b/tests/integrations/client/go.sum @@ -410,8 +410,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 h1:xIeaDUq2ItkYMIgpWXAYKC/N3hs8aurfFvvz79lhHYE= -github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= +github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537 h1:wnHt7ETIB0vm+gbLx8QhcIEmRtrT4QlWlfpcI9vjxOk= +github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff 
v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/tests/integrations/mcs/go.mod b/tests/integrations/mcs/go.mod index c2dfdbe96ef..f6df0eb4de0 100644 --- a/tests/integrations/mcs/go.mod +++ b/tests/integrations/mcs/go.mod @@ -119,7 +119,7 @@ require ( github.com/pingcap/errcode v0.3.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537 // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/mcs/go.sum b/tests/integrations/mcs/go.sum index d1b0962ab55..fc1dc1bbea5 100644 --- a/tests/integrations/mcs/go.sum +++ b/tests/integrations/mcs/go.sum @@ -414,8 +414,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 h1:xIeaDUq2ItkYMIgpWXAYKC/N3hs8aurfFvvz79lhHYE= -github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= +github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537 h1:wnHt7ETIB0vm+gbLx8QhcIEmRtrT4QlWlfpcI9vjxOk= +github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/tests/integrations/tso/go.mod b/tests/integrations/tso/go.mod index e5131f15d91..7e833943e6e 100644 --- a/tests/integrations/tso/go.mod +++ b/tests/integrations/tso/go.mod @@ -117,7 +117,7 @@ require ( github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537 // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/tso/go.sum b/tests/integrations/tso/go.sum index 576c3e75765..65a7f3e3558 100644 --- a/tests/integrations/tso/go.sum +++ b/tests/integrations/tso/go.sum @@ -408,8 +408,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 
h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9 h1:xIeaDUq2ItkYMIgpWXAYKC/N3hs8aurfFvvz79lhHYE= -github.com/pingcap/tidb-dashboard v0.0.0-20231102083420-865955cd15d9/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= +github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537 h1:wnHt7ETIB0vm+gbLx8QhcIEmRtrT4QlWlfpcI9vjxOk= +github.com/pingcap/tidb-dashboard v0.0.0-20231108071238-7cb8b7ff0537/go.mod h1:EZ90+V5S4TttbYag6oKZ3jcNKRwZe1Mc9vXwOt9JBYw= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= From 0c352271d7413bdf6ac948e11f1a3fb905fe2ccb Mon Sep 17 00:00:00 2001 From: disksing Date: Fri, 10 Nov 2023 12:03:42 +0800 Subject: [PATCH 12/26] dr-autosync: add recover timeout (#6295) ref tikv/pd#4399 Signed-off-by: husharp --- pkg/replication/replication_mode.go | 15 +++++- pkg/replication/replication_mode_test.go | 60 ++++++++++++++++-------- server/config/config.go | 15 +++--- 3 files changed, 62 insertions(+), 28 deletions(-) diff --git a/pkg/replication/replication_mode.go b/pkg/replication/replication_mode.go index 30b34e4596a..9093f911901 100644 --- a/pkg/replication/replication_mode.go +++ b/pkg/replication/replication_mode.go @@ -212,6 +212,7 @@ const ( type drAutoSyncStatus struct { State string `json:"state,omitempty"` StateID uint64 `json:"state_id,omitempty"` + AsyncStartTime *time.Time `json:"async_start,omitempty"` RecoverStartTime *time.Time `json:"recover_start,omitempty"` TotalRegions int `json:"total_regions,omitempty"` SyncedRegions int `json:"synced_regions,omitempty"` @@ -262,7 +263,8 @@ func (m *ModeManager) drSwitchToAsyncWithLock(availableStores []uint64) error { log.Warn("failed to switch to async state", zap.String("replicate-mode", modeDRAutoSync), errs.ZapError(err)) return err } - dr := drAutoSyncStatus{State: drStateAsync, StateID: id, AvailableStores: availableStores} + now := time.Now() + dr := drAutoSyncStatus{State: drStateAsync, StateID: id, AvailableStores: availableStores, AsyncStartTime: &now} if err := m.storage.SaveReplicationStatus(modeDRAutoSync, dr); err != nil { log.Warn("failed to switch to async state", zap.String("replicate-mode", modeDRAutoSync), errs.ZapError(err)) return err @@ -272,6 +274,15 @@ func (m *ModeManager) drSwitchToAsyncWithLock(availableStores []uint64) error { return nil } +func (m *ModeManager) drDurationSinceAsyncStart() time.Duration { + m.RLock() + defer m.RUnlock() + if m.drAutoSync.AsyncStartTime == nil { + return 0 + } + return time.Since(*m.drAutoSync.AsyncStartTime) +} + func (m *ModeManager) drSwitchToSyncRecover() error { m.Lock() defer m.Unlock() @@ -477,7 +488,7 @@ func (m *ModeManager) tickUpdateState() { m.drSwitchToAsync(storeIDs[primaryUp]) } case drStateAsync: - if canSync { + if canSync && m.drDurationSinceAsyncStart() > m.config.DRAutoSync.WaitRecoverTimeout.Duration { m.drSwitchToSyncRecover() break } diff --git a/pkg/replication/replication_mode_test.go b/pkg/replication/replication_mode_test.go index e01fb7a0b9a..5cf9f1a1450 100644 --- a/pkg/replication/replication_mode_test.go +++ b/pkg/replication/replication_mode_test.go @@ -16,6 
+16,7 @@ package replication import ( "context" + "encoding/json" "errors" "fmt" "testing" @@ -159,6 +160,20 @@ func newMockReplicator(ids []uint64) *mockFileReplicator { } } +func assertLastData(t *testing.T, data string, state string, stateID uint64, availableStores []uint64) { + type status struct { + State string `json:"state"` + StateID uint64 `json:"state_id"` + AvailableStores []uint64 `json:"available_stores"` + } + var s status + err := json.Unmarshal([]byte(data), &s) + require.NoError(t, err) + require.Equal(t, state, s.State) + require.Equal(t, stateID, s.StateID) + require.Equal(t, availableStores, s.AvailableStores) +} + func TestStateSwitch(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) @@ -190,7 +205,7 @@ func TestStateSwitch(t *testing.T) { stateID := rep.drAutoSync.StateID re.NotEqual(uint64(0), stateID) rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "sync", stateID, nil) assertStateIDUpdate := func() { re.NotEqual(stateID, rep.drAutoSync.StateID) stateID = rep.drAutoSync.StateID @@ -207,7 +222,7 @@ func TestStateSwitch(t *testing.T) { re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async_wait", stateID, []uint64{1, 2, 3, 4}) re.False(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit()) conf.DRAutoSync.PauseRegionSplit = true @@ -218,7 +233,7 @@ func TestStateSwitch(t *testing.T) { rep.tickUpdateState() assertStateIDUpdate() rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async", stateID, []uint64{1, 2, 3, 4}) // add new store in dr zone. 
cluster.AddLabelsStore(5, 1, map[string]string{"zone": "zone2"}) @@ -268,18 +283,19 @@ func TestStateSwitch(t *testing.T) { rep.tickUpdateState() re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() + rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async_wait", stateID, []uint64{1, 2, 3, 4}) setStoreState(cluster, "down", "up", "up", "up", "down", "down") rep.tickUpdateState() assertStateIDUpdate() rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[2,3,4]}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async_wait", stateID, []uint64{2, 3, 4}) setStoreState(cluster, "up", "down", "up", "up", "down", "down") rep.tickUpdateState() assertStateIDUpdate() rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async_wait", stateID, []uint64{1, 3, 4}) // async_wait -> async rep.tickUpdateState() @@ -291,26 +307,32 @@ func TestStateSwitch(t *testing.T) { rep.tickUpdateState() assertStateIDUpdate() rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async", stateID, []uint64{1, 3, 4}) // async -> async setStoreState(cluster, "up", "up", "up", "up", "down", "down") rep.tickUpdateState() // store 2 won't be available before it syncs status. rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async", stateID, []uint64{1, 3, 4}) syncStoreStatus(1, 2, 3, 4) rep.tickUpdateState() assertStateIDUpdate() rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async", stateID, []uint64{1, 2, 3, 4}) // async -> sync_recover setStoreState(cluster, "up", "up", "up", "up", "up", "up") rep.tickUpdateState() re.Equal(drStateSyncRecover, rep.drGetState()) assertStateIDUpdate() + rep.drSwitchToAsync([]uint64{1, 2, 3, 4, 5}) + rep.config.DRAutoSync.WaitRecoverTimeout = typeutil.NewDuration(time.Hour) + rep.tickUpdateState() + re.Equal(drStateAsync, rep.drGetState()) // wait recover timeout + + rep.config.DRAutoSync.WaitRecoverTimeout = typeutil.NewDuration(0) setStoreState(cluster, "down", "up", "up", "up", "up", "up") rep.tickUpdateState() re.Equal(drStateSyncRecover, rep.drGetState()) @@ -387,27 +409,27 @@ func TestReplicateState(t *testing.T) { stateID := rep.drAutoSync.StateID // replicate after initialized rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "sync", stateID, nil) // repliate state to new member replicator.memberIDs = append(replicator.memberIDs, 2, 3) rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[2]) - re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[3]) + assertLastData(t, replicator.lastData[2], "sync", stateID, nil) + assertLastData(t, replicator.lastData[3], "sync", stateID, nil) // inject error 
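// Member 2 fails to persist the next status, so its lastData is expected to keep the old
// sync state while members 1 and 3 move on to async_wait; once the error is cleared below,
// member 2 catches up on the next tickReplicateStatus.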
replicator.errors[2] = errors.New("failed to persist") rep.tickUpdateState() // switch async_wait since there is only one zone newStateID := rep.drAutoSync.StateID rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2]}`, newStateID), replicator.lastData[1]) - re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[2]) - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2]}`, newStateID), replicator.lastData[3]) + assertLastData(t, replicator.lastData[1], "async_wait", newStateID, []uint64{1, 2}) + assertLastData(t, replicator.lastData[2], "sync", stateID, nil) + assertLastData(t, replicator.lastData[3], "async_wait", newStateID, []uint64{1, 2}) // clear error, replicate to node 2 next time delete(replicator.errors, 2) rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2]}`, newStateID), replicator.lastData[2]) + assertLastData(t, replicator.lastData[2], "async_wait", newStateID, []uint64{1, 2}) } func TestAsynctimeout(t *testing.T) { @@ -637,7 +659,7 @@ func TestComplexPlacementRules(t *testing.T) { rep.tickUpdateState() re.Equal(drStateAsyncWait, rep.drGetState()) rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4,5,6]}`, rep.drAutoSync.StateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async_wait", rep.drAutoSync.StateID, []uint64{1, 2, 3, 4, 5, 6}) // reset to sync setStoreState(cluster, "up", "up", "up", "up", "up", "up", "up", "up", "up", "up") @@ -698,7 +720,7 @@ func TestComplexPlacementRules2(t *testing.T) { rep.tickUpdateState() re.Equal(drStateAsyncWait, rep.drGetState()) rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, rep.drAutoSync.StateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async_wait", rep.drAutoSync.StateID, []uint64{1, 2, 3, 4}) } func TestComplexPlacementRules3(t *testing.T) { @@ -737,7 +759,7 @@ func TestComplexPlacementRules3(t *testing.T) { rep.tickUpdateState() re.Equal(drStateAsyncWait, rep.drGetState()) rep.tickReplicateStatus() - re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, rep.drAutoSync.StateID), replicator.lastData[1]) + assertLastData(t, replicator.lastData[1], "async_wait", rep.drAutoSync.StateID, []uint64{1, 2, 3, 4}) } func genRegions(cluster *mockcluster.Cluster, stateID uint64, n int) []*core.RegionInfo { diff --git a/server/config/config.go b/server/config/config.go index 0485e077c67..da6b0e29e07 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -831,13 +831,14 @@ func NormalizeReplicationMode(m string) string { // DRAutoSyncReplicationConfig is the configuration for auto sync mode between 2 data centers. 
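// WaitRecoverTimeout is the extra time the cluster keeps running in the async state after a
// switch back becomes possible: tickUpdateState only moves on to sync_recover once
// drDurationSinceAsyncStart exceeds this duration, so a briefly recovered DR zone does not
// trigger an immediate recovery cycle.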
type DRAutoSyncReplicationConfig struct { - LabelKey string `toml:"label-key" json:"label-key"` - Primary string `toml:"primary" json:"primary"` - DR string `toml:"dr" json:"dr"` - PrimaryReplicas int `toml:"primary-replicas" json:"primary-replicas"` - DRReplicas int `toml:"dr-replicas" json:"dr-replicas"` - WaitStoreTimeout typeutil.Duration `toml:"wait-store-timeout" json:"wait-store-timeout"` - PauseRegionSplit bool `toml:"pause-region-split" json:"pause-region-split,string"` + LabelKey string `toml:"label-key" json:"label-key"` + Primary string `toml:"primary" json:"primary"` + DR string `toml:"dr" json:"dr"` + PrimaryReplicas int `toml:"primary-replicas" json:"primary-replicas"` + DRReplicas int `toml:"dr-replicas" json:"dr-replicas"` + WaitStoreTimeout typeutil.Duration `toml:"wait-store-timeout" json:"wait-store-timeout"` + WaitRecoverTimeout typeutil.Duration `toml:"wait-recover-timeout" json:"wait-recover-timeout"` + PauseRegionSplit bool `toml:"pause-region-split" json:"pause-region-split,string"` } func (c *DRAutoSyncReplicationConfig) adjust(meta *configutil.ConfigMetaData) { From f1cee6c3971e18c6ab201e50555261a8c51c3041 Mon Sep 17 00:00:00 2001 From: guo-shaoge Date: Fri, 10 Nov 2023 15:48:43 +0800 Subject: [PATCH 13/26] mcs/resourcemanager: delete expire tokenSlot (#7344) close tikv/pd#7346 Signed-off-by: guo-shaoge --- .../resourcemanager/server/token_buckets.go | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/pkg/mcs/resourcemanager/server/token_buckets.go b/pkg/mcs/resourcemanager/server/token_buckets.go index a0acba3b54d..05a93c32673 100644 --- a/pkg/mcs/resourcemanager/server/token_buckets.go +++ b/pkg/mcs/resourcemanager/server/token_buckets.go @@ -20,6 +20,8 @@ import ( "github.com/gogo/protobuf/proto" rmpb "github.com/pingcap/kvproto/pkg/resource_manager" + "github.com/pingcap/log" + "go.uber.org/zap" ) const ( @@ -31,6 +33,7 @@ const ( defaultReserveRatio = 0.5 defaultLoanCoefficient = 2 maxAssignTokens = math.MaxFloat64 / 1024 // assume max client connect is 1024 + slotExpireTimeout = 10 * time.Minute ) // GroupTokenBucket is a token bucket for a resource group. @@ -62,6 +65,7 @@ type TokenSlot struct { // tokenCapacity is the number of tokens in the slot. tokenCapacity float64 lastTokenCapacity float64 + lastReqTime time.Time } // GroupTokenBucketState is the running state of TokenBucket. @@ -75,7 +79,8 @@ type GroupTokenBucketState struct { LastUpdate *time.Time `json:"last_update,omitempty"` Initialized bool `json:"initialized"` // settingChanged is used to avoid that the number of tokens returned is jitter because of changing fill rate. - settingChanged bool + settingChanged bool + lastCheckExpireSlot time.Time } // Clone returns the copy of GroupTokenBucketState @@ -95,6 +100,7 @@ func (gts *GroupTokenBucketState) Clone() *GroupTokenBucketState { Initialized: gts.Initialized, tokenSlots: tokenSlots, clientConsumptionTokensSum: gts.clientConsumptionTokensSum, + lastCheckExpireSlot: gts.lastCheckExpireSlot, } } @@ -119,16 +125,18 @@ func (gts *GroupTokenBucketState) balanceSlotTokens( clientUniqueID uint64, settings *rmpb.TokenLimitSettings, requiredToken, elapseTokens float64) { + now := time.Now() slot, exist := gts.tokenSlots[clientUniqueID] if !exist { // Only slots that require a positive number will be considered alive, // but still need to allocate the elapsed tokens as well. 
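// lastReqTime is stamped on every request; the sweep further down removes slots whose last
// request is older than slotExpireTimeout (10 minutes), so clients that have gone away stop
// being counted when tokens are balanced across the remaining slots.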
if requiredToken != 0 { - slot = &TokenSlot{} + slot = &TokenSlot{lastReqTime: now} gts.tokenSlots[clientUniqueID] = slot gts.clientConsumptionTokensSum = 0 } } else { + slot.lastReqTime = now if gts.clientConsumptionTokensSum >= maxAssignTokens { gts.clientConsumptionTokensSum = 0 } @@ -139,6 +147,16 @@ func (gts *GroupTokenBucketState) balanceSlotTokens( } } + if time.Since(gts.lastCheckExpireSlot) >= slotExpireTimeout { + gts.lastCheckExpireSlot = now + for clientUniqueID, slot := range gts.tokenSlots { + if time.Since(slot.lastReqTime) >= slotExpireTimeout { + delete(gts.tokenSlots, clientUniqueID) + log.Info("delete resource group slot because expire", zap.Time("last-req-time", slot.lastReqTime), + zap.Any("expire timeout", slotExpireTimeout), zap.Any("del client id", clientUniqueID), zap.Any("len", len(gts.tokenSlots))) + } + } + } if len(gts.tokenSlots) == 0 { return } @@ -264,6 +282,7 @@ func (gtb *GroupTokenBucket) init(now time.Time, clientID uint64) { lastTokenCapacity: gtb.Tokens, } gtb.LastUpdate = &now + gtb.lastCheckExpireSlot = now gtb.Initialized = true } From b5119ea4bf2c3bc1d94256810c7e3e3670e96f45 Mon Sep 17 00:00:00 2001 From: lucasliang Date: Fri, 10 Nov 2023 16:32:12 +0800 Subject: [PATCH 14/26] scheduler: refine the interval of scheduling tick in evict-slow-trend-scheduler. (#7326) ref tikv/pd#7156 Implement the `GetNextInterval` for `evict-slow-trend-scheduler`, to refine the ticking interval. Default `GetNextInterval` is not appropriate for `evict-slow-trend-scheduler`, as it might delay the checking of other nodes' slowness status. This pr adjusts the ticking interval of the evict-slow-trend-scheduler to optimize its behavior. If a slow node is already identified as a candidate, the next interval is now set to be shorter, ensuring quicker subsequent scheduling. This refinement aims to decrease response time. Signed-off-by: lucasliang Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/schedule/schedulers/evict_slow_trend.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/pkg/schedule/schedulers/evict_slow_trend.go b/pkg/schedule/schedulers/evict_slow_trend.go index 3983e9c345d..f31ba420c97 100644 --- a/pkg/schedule/schedulers/evict_slow_trend.go +++ b/pkg/schedule/schedulers/evict_slow_trend.go @@ -108,8 +108,12 @@ func (conf *evictSlowTrendSchedulerConfig) getKeyRangesByID(id uint64) []core.Ke return []core.KeyRange{core.NewKeyRange("", "")} } +func (conf *evictSlowTrendSchedulerConfig) hasEvictedStores() bool { + return len(conf.EvictedStores) > 0 +} + func (conf *evictSlowTrendSchedulerConfig) evictedStore() uint64 { - if len(conf.EvictedStores) == 0 { + if !conf.hasEvictedStores() { return 0 } // If a candidate passes all checks and proved to be slow, it will be @@ -237,6 +241,19 @@ type evictSlowTrendScheduler struct { handler http.Handler } +func (s *evictSlowTrendScheduler) GetNextInterval(interval time.Duration) time.Duration { + var growthType intervalGrowthType + // If it already found a slow node as candidate, the next interval should be shorter + // to make the next scheduling as soon as possible. This adjustment will decrease the + // response time, as heartbeats from other nodes will be received and updated more quickly. 
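// zeroGrowth keeps the next tick at the scheduler's minimum interval while an evicted
// candidate exists; exponentialGrowth lets the interval back off again, capped by
// MaxScheduleInterval, when no store is currently marked slow.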
+ if s.conf.hasEvictedStores() { + growthType = zeroGrowth + } else { + growthType = exponentialGrowth + } + return intervalGrow(s.GetMinInterval(), MaxScheduleInterval, growthType) +} + func (s *evictSlowTrendScheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.handler.ServeHTTP(w, r) } From fe8a393e5cc898ab65c8d683b2f7aaa33252dfc1 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Fri, 10 Nov 2023 16:45:42 +0800 Subject: [PATCH 15/26] mcs: tso service should not forward again (#7348) ref tikv/pd#5836 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/tso/server/grpc_service.go | 34 ------------------------------ pkg/mcs/tso/server/server.go | 3 --- 2 files changed, 37 deletions(-) diff --git a/pkg/mcs/tso/server/grpc_service.go b/pkg/mcs/tso/server/grpc_service.go index 40a308c72f8..9006faf49da 100644 --- a/pkg/mcs/tso/server/grpc_service.go +++ b/pkg/mcs/tso/server/grpc_service.go @@ -28,8 +28,6 @@ import ( bs "github.com/tikv/pd/pkg/basicserver" "github.com/tikv/pd/pkg/mcs/registry" "github.com/tikv/pd/pkg/utils/apiutil" - "github.com/tikv/pd/pkg/utils/grpcutil" - "github.com/tikv/pd/pkg/utils/tsoutil" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -88,21 +86,9 @@ func (s *Service) RegisterRESTHandler(userDefineHandlers map[string]http.Handler // Tso returns a stream of timestamps func (s *Service) Tso(stream tsopb.TSO_TsoServer) error { - var ( - doneCh chan struct{} - errCh chan error - ) ctx, cancel := context.WithCancel(stream.Context()) defer cancel() for { - // Prevent unnecessary performance overhead of the channel. - if errCh != nil { - select { - case err := <-errCh: - return errors.WithStack(err) - default: - } - } request, err := stream.Recv() if err == io.EOF { return nil @@ -111,26 +97,6 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error { return errors.WithStack(err) } - streamCtx := stream.Context() - forwardedHost := grpcutil.GetForwardedHost(streamCtx) - if !s.IsLocalRequest(forwardedHost) { - clientConn, err := s.GetDelegateClient(s.Context(), s.GetTLSConfig(), forwardedHost) - if err != nil { - return errors.WithStack(err) - } - - if errCh == nil { - doneCh = make(chan struct{}) - defer close(doneCh) - errCh = make(chan error) - } - - tsoProtoFactory := s.tsoProtoFactory - tsoRequest := tsoutil.NewTSOProtoRequest(forwardedHost, clientConn, request, stream) - s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh) - continue - } - start := time.Now() // TSO uses leader lease to determine validity. No need to check leader here. if s.IsClosed() { diff --git a/pkg/mcs/tso/server/server.go b/pkg/mcs/tso/server/server.go index 16ef3216c62..1a2430477d8 100644 --- a/pkg/mcs/tso/server/server.go +++ b/pkg/mcs/tso/server/server.go @@ -78,9 +78,6 @@ type Server struct { service *Service keyspaceGroupManager *tso.KeyspaceGroupManager - // tsoDispatcher is used to dispatch the TSO requests to - // the corresponding forwarding TSO channels. 
- tsoDispatcher *tsoutil.TSODispatcher // tsoProtoFactory is the abstract factory for creating tso // related data structures defined in the tso grpc protocol tsoProtoFactory *tsoutil.TSOProtoFactory From afe6afccf9ddbf35c4210d40e00c6d69a030d3b3 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Fri, 10 Nov 2023 17:09:42 +0800 Subject: [PATCH 16/26] mcs: support rules http interface in scheduling server (#7199) ref tikv/pd#5839 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- errors.toml | 20 + pkg/errs/errno.go | 13 +- pkg/mcs/scheduling/server/apis/v1/api.go | 296 ++++++++++++- pkg/schedule/handler/handler.go | 43 ++ pkg/utils/apiutil/serverapi/middleware.go | 5 +- server/api/region_test.go | 12 +- server/api/rule.go | 296 ++++++++----- server/api/server.go | 25 ++ tests/integrations/mcs/scheduling/api_test.go | 102 ++++- tests/pdctl/config/config_test.go | 10 +- {server => tests/server}/api/rule_test.go | 390 ++++++++++++------ 11 files changed, 932 insertions(+), 280 deletions(-) rename {server => tests/server}/api/rule_test.go (67%) diff --git a/errors.toml b/errors.toml index 1d10d40d294..b6123058310 100644 --- a/errors.toml +++ b/errors.toml @@ -551,6 +551,11 @@ error = ''' build rule list failed, %s ''' +["PD:placement:ErrKeyFormat"] +error = ''' +key should be in hex format, %s +''' + ["PD:placement:ErrLoadRule"] error = ''' load rule failed @@ -561,11 +566,21 @@ error = ''' load rule group failed ''' +["PD:placement:ErrPlacementDisabled"] +error = ''' +placement rules feature is disabled +''' + ["PD:placement:ErrRuleContent"] error = ''' invalid rule content, %s ''' +["PD:placement:ErrRuleNotFound"] +error = ''' +rule not found +''' + ["PD:plugin:ErrLoadPlugin"] error = ''' failed to load plugin @@ -616,6 +631,11 @@ error = ''' region %v has abnormal peer ''' +["PD:region:ErrRegionInvalidID"] +error = ''' +invalid region id +''' + ["PD:region:ErrRegionNotAdjacent"] error = ''' two regions are not adjacent diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index e5bac8519be..b8a882cd187 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -102,6 +102,8 @@ var ( // region errors var ( + // ErrRegionInvalidID is error info for region id invalid. + ErrRegionInvalidID = errors.Normalize("invalid region id", errors.RFCCodeText("PD:region:ErrRegionInvalidID")) // ErrRegionNotAdjacent is error info for region not adjacent. ErrRegionNotAdjacent = errors.Normalize("two regions are not adjacent", errors.RFCCodeText("PD:region:ErrRegionNotAdjacent")) // ErrRegionNotFound is error info for region not found. 
@@ -153,10 +155,13 @@ var ( // placement errors var ( - ErrRuleContent = errors.Normalize("invalid rule content, %s", errors.RFCCodeText("PD:placement:ErrRuleContent")) - ErrLoadRule = errors.Normalize("load rule failed", errors.RFCCodeText("PD:placement:ErrLoadRule")) - ErrLoadRuleGroup = errors.Normalize("load rule group failed", errors.RFCCodeText("PD:placement:ErrLoadRuleGroup")) - ErrBuildRuleList = errors.Normalize("build rule list failed, %s", errors.RFCCodeText("PD:placement:ErrBuildRuleList")) + ErrRuleContent = errors.Normalize("invalid rule content, %s", errors.RFCCodeText("PD:placement:ErrRuleContent")) + ErrLoadRule = errors.Normalize("load rule failed", errors.RFCCodeText("PD:placement:ErrLoadRule")) + ErrLoadRuleGroup = errors.Normalize("load rule group failed", errors.RFCCodeText("PD:placement:ErrLoadRuleGroup")) + ErrBuildRuleList = errors.Normalize("build rule list failed, %s", errors.RFCCodeText("PD:placement:ErrBuildRuleList")) + ErrPlacementDisabled = errors.Normalize("placement rules feature is disabled", errors.RFCCodeText("PD:placement:ErrPlacementDisabled")) + ErrKeyFormat = errors.Normalize("key should be in hex format, %s", errors.RFCCodeText("PD:placement:ErrKeyFormat")) + ErrRuleNotFound = errors.Normalize("rule not found", errors.RFCCodeText("PD:placement:ErrRuleNotFound")) ) // region label errors diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index 47fdb95543f..172515d8620 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -15,6 +15,7 @@ package apis import ( + "encoding/hex" "net/http" "strconv" "sync" @@ -127,12 +128,6 @@ func (s *Service) RegisterAdminRouter() { router.DELETE("cache/regions/:id", deleteRegionCacheByID) } -// RegisterConfigRouter registers the router of the config handler. -func (s *Service) RegisterConfigRouter() { - router := s.root.Group("config") - router.GET("", getConfig) -} - // RegisterSchedulersRouter registers the router of the schedulers handler. func (s *Service) RegisterSchedulersRouter() { router := s.root.Group("schedulers") @@ -172,6 +167,32 @@ func (s *Service) RegisterOperatorsRouter() { router.GET("/records", getOperatorRecords) } +// RegisterConfigRouter registers the router of the config handler. +func (s *Service) RegisterConfigRouter() { + router := s.root.Group("config") + router.GET("", getConfig) + + rules := router.Group("rules") + rules.GET("", getAllRules) + rules.GET("/group/:group", getRuleByGroup) + rules.GET("/region/:region", getRulesByRegion) + rules.GET("/region/:region/detail", checkRegionPlacementRule) + rules.GET("/key/:key", getRulesByKey) + + // We cannot merge `/rule` and `/rules`, because we allow `group_id` to be "group", + // which is the same as the prefix of `/rules/group/:group`. + rule := router.Group("rule") + rule.GET("/:group/:id", getRuleByGroupAndID) + + groups := router.Group("rule_groups") + groups.GET("", getAllGroupConfigs) + groups.GET("/:id", getRuleGroupConfig) + + placementRule := router.Group("placement-rule") + placementRule.GET("", getPlacementRules) + placementRule.GET("/:group", getPlacementRuleByGroup) +} + // @Tags admin // @Summary Change the log level. // @Produce json @@ -671,3 +692,266 @@ func getHistoryHotRegions(c *gin.Context) { var res storage.HistoryHotRegions c.IndentedJSON(http.StatusOK, res) } + +// @Tags rule +// @Summary List all rules of cluster. 
+// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules [get] +func getAllRules(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + rules := manager.GetAllRules() + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags rule +// @Summary List all rules of cluster by group. +// @Param group path string true "The name of group" +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules/group/{group} [get] +func getRuleByGroup(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + group := c.Param("group") + rules := manager.GetRulesByGroup(group) + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags rule +// @Summary List all rules of cluster by region. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules/region/{region} [get] +func getRulesByRegion(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + regionStr := c.Param("region") + region, code, err := handler.PreCheckForRegion(regionStr) + if err != nil { + c.String(code, err.Error()) + return + } + rules := manager.GetRulesForApplyRegion(region) + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags rule +// @Summary List rules and matched peers related to the given region. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {object} placement.RegionFit +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /config/rules/region/{region}/detail [get] +func checkRegionPlacementRule(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + regionStr := c.Param("region") + region, code, err := handler.PreCheckForRegion(regionStr) + if err != nil { + c.String(code, err.Error()) + return + } + regionFit, err := handler.CheckRegionPlacementRule(region) + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + c.IndentedJSON(http.StatusOK, regionFit) +} + +// @Tags rule +// @Summary List all rules of cluster by key. +// @Param key path string true "The name of key" +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 400 {string} string "The input is invalid." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules/key/{key} [get] +func getRulesByKey(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + keyHex := c.Param("key") + key, err := hex.DecodeString(keyHex) + if err != nil { + c.String(http.StatusBadRequest, errs.ErrKeyFormat.Error()) + return + } + rules := manager.GetRulesByKey(key) + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags rule +// @Summary Get rule of cluster by group and id. +// @Param group path string true "The name of group" +// @Param id path string true "Rule Id" +// @Produce json +// @Success 200 {object} placement.Rule +// @Failure 404 {string} string "The rule does not exist." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Router /config/rule/{group}/{id} [get] +func getRuleByGroupAndID(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + group, id := c.Param("group"), c.Param("id") + rule := manager.GetRule(group, id) + if rule == nil { + c.String(http.StatusNotFound, errs.ErrRuleNotFound.Error()) + return + } + c.IndentedJSON(http.StatusOK, rule) +} + +// @Tags rule +// @Summary List all rule group configs. +// @Produce json +// @Success 200 {array} placement.RuleGroup +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rule_groups [get] +func getAllGroupConfigs(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + ruleGroups := manager.GetRuleGroups() + c.IndentedJSON(http.StatusOK, ruleGroups) +} + +// @Tags rule +// @Summary Get rule group config by group id. +// @Param id path string true "Group Id" +// @Produce json +// @Success 200 {object} placement.RuleGroup +// @Failure 404 {string} string "The RuleGroup does not exist." 
+// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rule_groups/{id} [get] +func getRuleGroupConfig(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + id := c.Param("id") + group := manager.GetRuleGroup(id) + if group == nil { + c.String(http.StatusNotFound, errs.ErrRuleNotFound.Error()) + return + } + c.IndentedJSON(http.StatusOK, group) +} + +// @Tags rule +// @Summary List all rules and groups configuration. +// @Produce json +// @Success 200 {array} placement.GroupBundle +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/placement-rules [get] +func getPlacementRules(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + bundles := manager.GetAllGroupBundles() + c.IndentedJSON(http.StatusOK, bundles) +} + +// @Tags rule +// @Summary Get group config and all rules belong to the group. +// @Param group path string true "The name of group" +// @Produce json +// @Success 200 {object} placement.GroupBundle +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/placement-rules/{group} [get] +func getPlacementRuleByGroup(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + g := c.Param("group") + group := manager.GetGroupBundle(g) + c.IndentedJSON(http.StatusOK, group) +} diff --git a/pkg/schedule/handler/handler.go b/pkg/schedule/handler/handler.go index 45b0eaf502f..3f9f4f96622 100644 --- a/pkg/schedule/handler/handler.go +++ b/pkg/schedule/handler/handler.go @@ -18,6 +18,7 @@ import ( "bytes" "encoding/hex" "net/http" + "strconv" "strings" "time" @@ -1061,3 +1062,45 @@ func (h *Handler) GetHotBuckets(regionIDs ...uint64) (HotBucketsResponse, error) } return ret, nil } + +// GetRuleManager returns the rule manager. +func (h *Handler) GetRuleManager() (*placement.RuleManager, error) { + c := h.GetCluster() + if c == nil { + return nil, errs.ErrNotBootstrapped + } + if !c.GetSharedConfig().IsPlacementRulesEnabled() { + return nil, errs.ErrPlacementDisabled + } + return c.GetRuleManager(), nil +} + +// PreCheckForRegion checks if the region is valid. 
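+// It parses the region ID from the request path and looks the region up in the
+// cluster: an unparsable ID returns http.StatusBadRequest, a missing region
+// returns http.StatusNotFound, and on success the region is returned together
+// with http.StatusOK.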
+func (h *Handler) PreCheckForRegion(regionStr string) (*core.RegionInfo, int, error) { + c := h.GetCluster() + if c == nil { + return nil, http.StatusInternalServerError, errs.ErrNotBootstrapped.GenWithStackByArgs() + } + regionID, err := strconv.ParseUint(regionStr, 10, 64) + if err != nil { + return nil, http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs() + } + region := c.GetRegion(regionID) + if region == nil { + return nil, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID) + } + return region, http.StatusOK, nil +} + +// CheckRegionPlacementRule checks if the region matches the placement rules. +func (h *Handler) CheckRegionPlacementRule(region *core.RegionInfo) (*placement.RegionFit, error) { + c := h.GetCluster() + if c == nil { + return nil, errs.ErrNotBootstrapped.GenWithStackByArgs() + } + manager, err := h.GetRuleManager() + if err != nil { + return nil, err + } + return manager.FitRegion(c, region), nil +} diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 19438ad0f91..2bb742ccbba 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -117,6 +117,7 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri r.URL.Path = strings.TrimRight(r.URL.Path, "/") for _, rule := range h.microserviceRedirectRules { if strings.HasPrefix(r.URL.Path, rule.matchPath) && slice.Contains(rule.matchMethods, r.Method) { + origin := r.URL.Path addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName) if !ok || addr == "" { log.Warn("failed to get the service primary addr when trying to match redirect rules", @@ -134,8 +135,8 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri } else { r.URL.Path = rule.targetPath } - log.Debug("redirect to micro service", zap.String("path", r.URL.Path), zap.String("target", addr), - zap.String("method", r.Method)) + log.Debug("redirect to micro service", zap.String("path", r.URL.Path), zap.String("origin-path", origin), + zap.String("target", addr), zap.String("method", r.Method)) return true, addr } } diff --git a/server/api/region_test.go b/server/api/region_test.go index a39a1e5c5fd..379fcf7d463 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -241,14 +241,14 @@ func (suite *regionTestSuite) TestRegions() { mustRegionHeartbeat(re, suite.svr, r) } url := fmt.Sprintf("%s/regions", suite.urlPrefix) - RegionsInfo := &RegionsInfo{} - err := tu.ReadGetJSON(re, testDialClient, url, RegionsInfo) + regionsInfo := &RegionsInfo{} + err := tu.ReadGetJSON(re, testDialClient, url, regionsInfo) suite.NoError(err) - suite.Len(regions, RegionsInfo.Count) - sort.Slice(RegionsInfo.Regions, func(i, j int) bool { - return RegionsInfo.Regions[i].ID < RegionsInfo.Regions[j].ID + suite.Len(regions, regionsInfo.Count) + sort.Slice(regionsInfo.Regions, func(i, j int) bool { + return regionsInfo.Regions[i].ID < regionsInfo.Regions[j].ID }) - for i, r := range RegionsInfo.Regions { + for i, r := range regionsInfo.Regions { suite.Equal(regions[i].ID, r.ID) suite.Equal(regions[i].ApproximateSize, r.ApproximateSize) suite.Equal(regions[i].ApproximateKeys, r.ApproximateKeys) diff --git a/server/api/rule.go b/server/api/rule.go index b3a720ece41..77aad42eb42 100644 --- a/server/api/rule.go +++ b/server/api/rule.go @@ -19,30 +19,26 @@ import ( "fmt" "net/http" "net/url" - "strconv" "github.com/gorilla/mux" - "github.com/pingcap/errors" - "github.com/tikv/pd/pkg/core" 
"github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/server" - "github.com/tikv/pd/server/cluster" "github.com/unrolled/render" ) -var errPlacementDisabled = errors.New("placement rules feature is disabled") - type ruleHandler struct { + *server.Handler svr *server.Server rd *render.Render } func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { return &ruleHandler{ - svr: svr, - rd: rd, + Handler: svr.GetHandler(), + svr: svr, + rd: rd, } } @@ -51,14 +47,19 @@ func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { // @Produce json // @Success 200 {array} placement.Rule // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules [get] func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - rules := cluster.GetRuleManager().GetAllRules() + rules := manager.GetAllRules() h.rd.JSON(w, http.StatusOK, rules) } @@ -72,9 +73,13 @@ func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules [post] func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var rules []*placement.Rule @@ -87,7 +92,7 @@ func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { return } } - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). SetRules(rules); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -105,15 +110,20 @@ func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { // @Produce json // @Success 200 {array} placement.Rule // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rules/group/{group} [get] func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } group := mux.Vars(r)["group"] - rules := cluster.GetRuleManager().GetRulesByGroup(group) + rules := manager.GetRulesByGroup(group) h.rd.JSON(w, http.StatusOK, rules) } @@ -125,13 +135,25 @@ func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "The input is invalid." // @Failure 404 {string} string "The region does not exist." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/region/{region} [get] func (h *ruleHandler) GetRulesByRegion(w http.ResponseWriter, r *http.Request) { - cluster, region := h.preCheckForRegionAndRule(w, r) - if cluster == nil || region == nil { + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - rules := cluster.GetRuleManager().GetRulesForApplyRegion(region) + regionStr := mux.Vars(r)["region"] + region, code, err := h.PreCheckForRegion(regionStr) + if err != nil { + h.rd.JSON(w, code, err.Error()) + return + } + rules := manager.GetRulesForApplyRegion(region) h.rd.JSON(w, http.StatusOK, rules) } @@ -143,34 +165,25 @@ func (h *ruleHandler) GetRulesByRegion(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "The input is invalid." // @Failure 404 {string} string "The region does not exist." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rules/region/{region}/detail [get] func (h *ruleHandler) CheckRegionPlacementRule(w http.ResponseWriter, r *http.Request) { - cluster, region := h.preCheckForRegionAndRule(w, r) - if cluster == nil || region == nil { + regionStr := mux.Vars(r)["region"] + region, code, err := h.PreCheckForRegion(regionStr) + if err != nil { + h.rd.JSON(w, code, err.Error()) return } - regionFit := cluster.GetRuleManager().FitRegion(cluster, region) - h.rd.JSON(w, http.StatusOK, regionFit) -} - -func (h *ruleHandler) preCheckForRegionAndRule(w http.ResponseWriter, r *http.Request) (*cluster.RaftCluster, *core.RegionInfo) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) - return cluster, nil + regionFit, err := h.Handler.CheckRegionPlacementRule(region) + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return } - regionStr := mux.Vars(r)["region"] - regionID, err := strconv.ParseUint(regionStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusBadRequest, "invalid region id") - return cluster, nil - } - region := cluster.GetRegion(regionID) - if region == nil { - h.rd.JSON(w, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error()) - return cluster, nil + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return } - return cluster, region + h.rd.JSON(w, http.StatusOK, regionFit) } // @Tags rule @@ -180,20 +193,25 @@ func (h *ruleHandler) preCheckForRegionAndRule(w http.ResponseWriter, r *http.Re // @Success 200 {array} placement.Rule // @Failure 400 {string} string "The input is invalid." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/key/{key} [get] func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } keyHex := mux.Vars(r)["key"] key, err := hex.DecodeString(keyHex) if err != nil { - h.rd.JSON(w, http.StatusBadRequest, "key should be in hex format") + h.rd.JSON(w, http.StatusBadRequest, errs.ErrKeyFormat.FastGenByArgs(err).Error()) return } - rules := cluster.GetRuleManager().GetRulesByKey(key) + rules := manager.GetRulesByKey(key) h.rd.JSON(w, http.StatusOK, rules) } @@ -207,15 +225,19 @@ func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { // @Failure 412 {string} string "Placement rules feature is disabled." 
// @Router /config/rule/{group}/{id} [get] func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } group, id := mux.Vars(r)["group"], mux.Vars(r)["id"] - rule := cluster.GetRuleManager().GetRule(group, id) + rule := manager.GetRule(group, id) if rule == nil { - h.rd.JSON(w, http.StatusNotFound, nil) + h.rd.JSON(w, http.StatusNotFound, errs.ErrRuleNotFound.Error()) return } h.rd.JSON(w, http.StatusOK, rule) @@ -232,21 +254,25 @@ func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule [post] func (h *ruleHandler) SetRule(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var rule placement.Rule if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &rule); err != nil { return } - oldRule := cluster.GetRuleManager().GetRule(rule.GroupID, rule.ID) + oldRule := manager.GetRule(rule.GroupID, rule.ID) if err := h.syncReplicateConfigWithDefaultRule(&rule); err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). SetRule(&rule); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -255,6 +281,7 @@ func (h *ruleHandler) SetRule(w http.ResponseWriter, r *http.Request) { } return } + cluster := getCluster(r) cluster.AddSuspectKeyRange(rule.StartKey, rule.EndKey) if oldRule != nil { cluster.AddSuspectKeyRange(oldRule.StartKey, oldRule.EndKey) @@ -285,18 +312,23 @@ func (h *ruleHandler) syncReplicateConfigWithDefaultRule(rule *placement.Rule) e // @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rule/{group}/{id} [delete] func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } group, id := mux.Vars(r)["group"], mux.Vars(r)["id"] - rule := cluster.GetRuleManager().GetRule(group, id) - if err := cluster.GetRuleManager().DeleteRule(group, id); err != nil { + rule := manager.GetRule(group, id) + if err := manager.DeleteRule(group, id); err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } if rule != nil { + cluster := getCluster(r) cluster.AddSuspectKeyRange(rule.StartKey, rule.EndKey) } @@ -313,16 +345,20 @@ func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/batch [post] func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var opts []placement.RuleOp if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &opts); err != nil { return } - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). Batch(opts); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -341,15 +377,20 @@ func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} placement.RuleGroup // @Failure 404 {string} string "The RuleGroup does not exist." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group/{id} [get] func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } id := mux.Vars(r)["id"] - group := cluster.GetRuleManager().GetRuleGroup(id) + group := manager.GetRuleGroup(id) if group == nil { h.rd.JSON(w, http.StatusNotFound, nil) return @@ -368,21 +409,26 @@ func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rule_group [post] func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var ruleGroup placement.RuleGroup if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &ruleGroup); err != nil { return } - if err := cluster.GetRuleManager().SetRuleGroup(&ruleGroup); err != nil { + if err := manager.SetRuleGroup(&ruleGroup); err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - for _, r := range cluster.GetRuleManager().GetRulesByGroup(ruleGroup.ID) { - cluster.AddSuspectKeyRange(r.StartKey, r.EndKey) + cluster := getCluster(r) + for _, rule := range manager.GetRulesByGroup(ruleGroup.ID) { + cluster.AddSuspectKeyRange(rule.StartKey, rule.EndKey) } h.rd.JSON(w, http.StatusOK, "Update rule group successfully.") } @@ -396,18 +442,23 @@ func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group/{id} [delete] func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } id := mux.Vars(r)["id"] - err := cluster.GetRuleManager().DeleteRuleGroup(id) + err = manager.DeleteRuleGroup(id) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - for _, r := range cluster.GetRuleManager().GetRulesByGroup(id) { + cluster := getCluster(r) + for _, r := range manager.GetRulesByGroup(id) { cluster.AddSuspectKeyRange(r.StartKey, r.EndKey) } h.rd.JSON(w, http.StatusOK, "Delete rule group successfully.") @@ -418,14 +469,19 @@ func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) // @Produce json // @Success 200 {array} placement.RuleGroup // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_groups [get] func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) return } - ruleGroups := cluster.GetRuleManager().GetRuleGroups() + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + ruleGroups := manager.GetRuleGroups() h.rd.JSON(w, http.StatusOK, ruleGroups) } @@ -434,14 +490,19 @@ func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) // @Produce json // @Success 200 {array} placement.GroupBundle // @Failure 412 {string} string "Placement rules feature is disabled." 
+// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [get] func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - bundles := cluster.GetRuleManager().GetAllGroupBundles() + bundles := manager.GetAllGroupBundles() h.rd.JSON(w, http.StatusOK, bundles) } @@ -455,9 +516,13 @@ func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [post] func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var groups []placement.GroupBundle @@ -465,7 +530,7 @@ func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) return } _, partial := r.URL.Query()["partial"] - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). SetAllGroupBundles(groups, !partial); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -483,14 +548,20 @@ func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) // @Produce json // @Success 200 {object} placement.GroupBundle // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule/{group} [get] func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) return } - group := cluster.GetRuleManager().GetGroupBundle(mux.Vars(r)["group"]) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + g := mux.Vars(r)["group"] + group := manager.GetGroupBundle(g) h.rd.JSON(w, http.StatusOK, group) } @@ -502,21 +573,26 @@ func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Req // @Success 200 {string} string "Delete group and rules successfully." // @Failure 400 {string} string "Bad request." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/placement-rule [delete] func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } group := mux.Vars(r)["group"] - group, err := url.PathUnescape(group) + group, err = url.PathUnescape(group) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } _, regex := r.URL.Query()["regexp"] - if err := cluster.GetRuleManager().DeleteGroupBundle(group, regex); err != nil { + if err := manager.DeleteGroupBundle(group, regex); err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } @@ -532,9 +608,13 @@ func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http. // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule/{group} [post] func (h *ruleHandler) SetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } groupID := mux.Vars(r)["group"] @@ -549,7 +629,7 @@ func (h *ruleHandler) SetPlacementRuleByGroup(w http.ResponseWriter, r *http.Req h.rd.JSON(w, http.StatusBadRequest, fmt.Sprintf("group id %s does not match request URI %s", group.ID, groupID)) return } - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). 
SetGroupBundle(group); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) diff --git a/server/api/server.go b/server/api/server.go index ae877b8407c..77a51eb04e5 100644 --- a/server/api/server.go +++ b/server/api/server.go @@ -84,6 +84,31 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP scheapi.APIPathPrefix+"/hotspot", mcs.SchedulingServiceName, []string{http.MethodGet}), + serverapi.MicroserviceRedirectRule( + prefix+"/config/rules", + scheapi.APIPathPrefix+"/config/rules", + mcs.SchedulingServiceName, + []string{http.MethodGet}), + serverapi.MicroserviceRedirectRule( + prefix+"/config/rule/", + scheapi.APIPathPrefix+"/config/rule", + mcs.SchedulingServiceName, + []string{http.MethodGet}), + serverapi.MicroserviceRedirectRule( + prefix+"/config/rule_group/", + scheapi.APIPathPrefix+"/config/rule_groups", // Note: this is a typo in the original code + mcs.SchedulingServiceName, + []string{http.MethodGet}), + serverapi.MicroserviceRedirectRule( + prefix+"/config/rule_groups", + scheapi.APIPathPrefix+"/config/rule_groups", + mcs.SchedulingServiceName, + []string{http.MethodGet}), + serverapi.MicroserviceRedirectRule( + prefix+"/config/placement-rule", + scheapi.APIPathPrefix+"/config/placement-rule", + mcs.SchedulingServiceName, + []string{http.MethodGet}), // because the writing of all the meta information of the scheduling service is in the API server, // we should not post and delete the scheduler directly in the scheduling service. serverapi.MicroserviceRedirectRule( diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 15c66ce5829..cfeaa4db033 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -15,10 +15,10 @@ import ( _ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" "github.com/tikv/pd/pkg/mcs/scheduling/server/config" "github.com/tikv/pd/pkg/schedule/handler" + "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/statistics" "github.com/tikv/pd/pkg/storage" "github.com/tikv/pd/pkg/utils/apiutil" - "github.com/tikv/pd/pkg/utils/tempurl" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/tests" ) @@ -43,7 +43,7 @@ func TestAPI(t *testing.T) { suite.Run(t, &apiTestSuite{}) } -func (suite *apiTestSuite) SetupSuite() { +func (suite *apiTestSuite) SetupTest() { ctx, cancel := context.WithCancel(context.Background()) suite.ctx = ctx cluster, err := tests.NewTestAPICluster(suite.ctx, 1) @@ -62,14 +62,19 @@ func (suite *apiTestSuite) SetupSuite() { suite.cleanupFunc = func() { cancel() } + tc, err := tests.NewTestSchedulingCluster(suite.ctx, 2, suite.backendEndpoints) + suite.NoError(err) + suite.cluster.SetSchedulingCluster(tc) + tc.WaitForPrimaryServing(suite.Require()) } -func (suite *apiTestSuite) TearDownSuite() { +func (suite *apiTestSuite) TearDownTest() { suite.cluster.Destroy() suite.cleanupFunc() } func (suite *apiTestSuite) TestGetCheckerByName() { + re := suite.Require() testCases := []struct { name string }{ @@ -81,14 +86,8 @@ func (suite *apiTestSuite) TestGetCheckerByName() { {name: "joint-state"}, } - re := suite.Require() - s, cleanup := tests.StartSingleSchedulingTestServer(suite.ctx, re, suite.backendEndpoints, tempurl.Alloc()) - defer cleanup() - testutil.Eventually(re, func() bool { - return s.IsServing() - }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) - addr 
:= s.GetAddr() - urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/checkers", addr) + s := suite.cluster.GetSchedulingPrimaryServer() + urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/checkers", s.GetAddr()) co := s.GetCoordinator() for _, testCase := range testCases { @@ -123,17 +122,12 @@ func (suite *apiTestSuite) TestAPIForward() { re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader")) }() - tc, err := tests.NewTestSchedulingCluster(suite.ctx, 2, suite.backendEndpoints) - re.NoError(err) - defer tc.Destroy() - tc.WaitForPrimaryServing(re) - urlPrefix := fmt.Sprintf("%s/pd/api/v1", suite.backendEndpoints) var slice []string var resp map[string]interface{} // Test opeartor - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, + err := testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) re.Len(slice, 0) @@ -241,6 +235,80 @@ func (suite *apiTestSuite) TestAPIForward() { err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/history"), &history, testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) + + // Test rules: only forward `GET` request + var rules []*placement.Rule + tests.MustPutRegion(re, suite.cluster, 2, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) + rules = []*placement.Rule{ + { + GroupID: "pd", + ID: "default", + Role: "voter", + Count: 3, + LocationLabels: []string{}, + }, + } + rulesArgs, err := json.Marshal(rules) + suite.NoError(err) + + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "/config/rules"), &rules, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/batch"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/group/pd"), &rules, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2"), &rules, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + var fit placement.RegionFit + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2/detail"), &fit, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/key/0000000000000001"), &rules, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, 
fmt.Sprintf("%s/%s", urlPrefix, "config/rule"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_groups"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) } func (suite *apiTestSuite) TestConfig() { diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index 26d70bb955f..2cc8427911a 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -19,6 +19,7 @@ import ( "encoding/json" "os" "reflect" + "strings" "testing" "time" @@ -409,9 +410,12 @@ func (suite *configTestSuite) checkPlacementRuleGroups(cluster *tests.TestCluste // test show var group placement.RuleGroup - output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "show", "pd") - re.NoError(err) - re.NoError(json.Unmarshal(output, &group)) + testutil.Eventually(re, func() bool { // wait for the config to be synced to the scheduling server + output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "show", "pd") + re.NoError(err) + return !strings.Contains(string(output), "404") + }) + re.NoError(json.Unmarshal(output, &group), string(output)) re.Equal(placement.RuleGroup{ID: "pd"}, group) // test set diff --git a/server/api/rule_test.go b/tests/server/api/rule_test.go similarity index 67% rename from server/api/rule_test.go rename to tests/server/api/rule_test.go index d2dc50f1119..3ee3357e031 100644 --- a/server/api/rule_test.go +++ b/tests/server/api/rule_test.go @@ -25,57 +25,37 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/schedule/placement" tu "github.com/tikv/pd/pkg/utils/testutil" - "github.com/tikv/pd/server" 
"github.com/tikv/pd/server/config" + "github.com/tikv/pd/tests" ) type ruleTestSuite struct { suite.Suite - svr *server.Server - cleanup tu.CleanupFunc - urlPrefix string } func TestRuleTestSuite(t *testing.T) { suite.Run(t, new(ruleTestSuite)) } -func (suite *ruleTestSuite) SetupSuite() { - re := suite.Require() - suite.svr, suite.cleanup = mustNewServer(re) - server.MustWaitLeader(re, []*server.Server{suite.svr}) - - addr := suite.svr.GetAddr() - suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/config", addr, apiPrefix) - - mustBootstrapCluster(re, suite.svr) - PDServerCfg := suite.svr.GetConfig().PDServerCfg - PDServerCfg.KeyType = "raw" - err := suite.svr.SetPDServerConfig(PDServerCfg) - suite.NoError(err) - suite.NoError(tu.CheckPostJSON(testDialClient, suite.urlPrefix, []byte(`{"enable-placement-rules":"true"}`), tu.StatusOK(re))) -} - -func (suite *ruleTestSuite) TearDownSuite() { - suite.cleanup() -} - -func (suite *ruleTestSuite) TearDownTest() { - def := placement.GroupBundle{ - ID: "pd", - Rules: []*placement.Rule{ - {GroupID: "pd", ID: "default", Role: "voter", Count: 3}, +func (suite *ruleTestSuite) TestSet() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true }, } - data, err := json.Marshal([]placement.GroupBundle{def}) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(suite.Require())) - suite.NoError(err) + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkSet) } -func (suite *ruleTestSuite) TestSet() { +func (suite *ruleTestSuite) checkSet(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "a", ID: "10", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} successData, err := json.Marshal(rule) suite.NoError(err) @@ -159,12 +139,12 @@ func (suite *ruleTestSuite) TestSet() { for _, testCase := range testCases { suite.T().Log(testCase.name) // clear suspect keyRanges to prevent test case from others - suite.svr.GetRaftCluster().ClearSuspectKeyRanges() + leaderServer.GetRaftCluster().ClearSuspectKeyRanges() if testCase.success { - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) popKeyRangeMap := map[string]struct{}{} for i := 0; i < len(testCase.popKeyRange)/2; i++ { - v, got := suite.svr.GetRaftCluster().PopOneSuspectKeyRange() + v, got := leaderServer.GetRaftCluster().PopOneSuspectKeyRange() suite.True(got) popKeyRangeMap[hex.EncodeToString(v[0])] = struct{}{} popKeyRangeMap[hex.EncodeToString(v[1])] = struct{}{} @@ -175,7 +155,7 @@ func (suite *ruleTestSuite) TestSet() { suite.True(ok) } } else { - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", testCase.rawData, + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusNotOK(re), tu.StringEqual(re, testCase.response)) } @@ -184,11 +164,26 @@ func (suite *ruleTestSuite) TestSet() { } func (suite *ruleTestSuite) TestGet() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), 
opts...) + env.RunTestInTwoModes(suite.checkGet) +} + +func (suite *ruleTestSuite) checkGet(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "a", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) suite.NoError(err) re := suite.Require() - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) suite.NoError(err) testCases := []struct { @@ -213,7 +208,7 @@ func (suite *ruleTestSuite) TestGet() { for _, testCase := range testCases { suite.T().Log(testCase.name) var resp placement.Rule - url := fmt.Sprintf("%s/rule/%s/%s", suite.urlPrefix, testCase.rule.GroupID, testCase.rule.ID) + url := fmt.Sprintf("%s/rule/%s/%s", urlPrefix, testCase.rule.GroupID, testCase.rule.ID) if testCase.found { err = tu.ReadGetJSON(re, testDialClient, url, &resp) suite.compareRule(&resp, &testCase.rule) @@ -225,20 +220,50 @@ func (suite *ruleTestSuite) TestGet() { } func (suite *ruleTestSuite) TestGetAll() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkGetAll) +} + +func (suite *ruleTestSuite) checkGetAll(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "b", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) suite.NoError(err) re := suite.Require() - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) suite.NoError(err) var resp2 []*placement.Rule - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/rules", &resp2) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/rules", &resp2) suite.NoError(err) suite.GreaterOrEqual(len(resp2), 1) } func (suite *ruleTestSuite) TestSetAll() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) 
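+	// Run the same checks in both supported deployments: PD serving the API
+	// directly, and PD forwarding to the independent scheduling service.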
+ env.RunTestInTwoModes(suite.checkSetAll) +} + +func (suite *ruleTestSuite) checkSetAll(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule1 := placement.Rule{GroupID: "a", ID: "12", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} rule2 := placement.Rule{GroupID: "b", ID: "12", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} rule3 := placement.Rule{GroupID: "a", ID: "12", StartKeyHex: "XXXX", EndKeyHex: "3333", Role: "voter", Count: 1} @@ -247,10 +272,10 @@ func (suite *ruleTestSuite) TestSetAll() { LocationLabels: []string{"host"}} rule6 := placement.Rule{GroupID: "pd", ID: "default", StartKeyHex: "", EndKeyHex: "", Role: "voter", Count: 3} - suite.svr.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"} - defaultRule := suite.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default") + leaderServer.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"} + defaultRule := leaderServer.GetRaftCluster().GetRuleManager().GetRule("pd", "default") defaultRule.LocationLabels = []string{"host"} - suite.svr.GetRaftCluster().GetRuleManager().SetRule(defaultRule) + leaderServer.GetRaftCluster().GetRuleManager().SetRule(defaultRule) successData, err := json.Marshal([]*placement.Rule{&rule1, &rule2}) suite.NoError(err) @@ -333,13 +358,13 @@ func (suite *ruleTestSuite) TestSetAll() { for _, testCase := range testCases { suite.T().Log(testCase.name) if testCase.success { - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re)) + err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re)) suite.NoError(err) if testCase.isDefaultRule { - suite.Equal(int(suite.svr.GetPersistOptions().GetReplicationConfig().MaxReplicas), testCase.count) + suite.Equal(int(leaderServer.GetPersistOptions().GetReplicationConfig().MaxReplicas), testCase.count) } } else { - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules", testCase.rawData, + err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules", testCase.rawData, tu.StringEqual(re, testCase.response)) suite.NoError(err) } @@ -347,17 +372,32 @@ func (suite *ruleTestSuite) TestSetAll() { } func (suite *ruleTestSuite) TestGetAllByGroup() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) 
+ env.RunTestInTwoModes(suite.checkGetAllByGroup) +} + +func (suite *ruleTestSuite) checkGetAllByGroup(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + re := suite.Require() rule := placement.Rule{GroupID: "c", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) suite.NoError(err) rule1 := placement.Rule{GroupID: "c", ID: "30", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err = json.Marshal(rule1) suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) suite.NoError(err) testCases := []struct { @@ -380,7 +420,7 @@ func (suite *ruleTestSuite) TestGetAllByGroup() { for _, testCase := range testCases { suite.T().Log(testCase.name) var resp []*placement.Rule - url := fmt.Sprintf("%s/rules/group/%s", suite.urlPrefix, testCase.groupID) + url := fmt.Sprintf("%s/rules/group/%s", urlPrefix, testCase.groupID) err = tu.ReadGetJSON(re, testDialClient, url, &resp) suite.NoError(err) suite.Len(resp, testCase.count) @@ -392,15 +432,30 @@ func (suite *ruleTestSuite) TestGetAllByGroup() { } func (suite *ruleTestSuite) TestGetAllByRegion() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkGetAllByRegion) +} + +func (suite *ruleTestSuite) checkGetAllByRegion(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "e", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) suite.NoError(err) re := suite.Require() - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) suite.NoError(err) r := core.NewTestRegionInfo(4, 1, []byte{0x22, 0x22}, []byte{0x33, 0x33}) - mustRegionHeartbeat(re, suite.svr, r) + tests.MustPutRegionInfo(re, cluster, r) testCases := []struct { name string @@ -429,7 +484,7 @@ func (suite *ruleTestSuite) TestGetAllByRegion() { for _, testCase := range testCases { suite.T().Log(testCase.name) var resp []*placement.Rule - url := fmt.Sprintf("%s/rules/region/%s", suite.urlPrefix, testCase.regionID) + url := fmt.Sprintf("%s/rules/region/%s", urlPrefix, testCase.regionID) if testCase.success { err = tu.ReadGetJSON(re, testDialClient, url, &resp) @@ -446,11 +501,26 @@ func (suite *ruleTestSuite) TestGetAllByRegion() { } func (suite *ruleTestSuite) TestGetAllByKey() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) 
+ env.RunTestInTwoModes(suite.checkGetAllByKey) +} + +func (suite *ruleTestSuite) checkGetAllByKey(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "f", ID: "40", StartKeyHex: "8888", EndKeyHex: "9111", Role: "voter", Count: 1} data, err := json.Marshal(rule) suite.NoError(err) re := suite.Require() - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) suite.NoError(err) testCases := []struct { @@ -483,7 +553,7 @@ func (suite *ruleTestSuite) TestGetAllByKey() { for _, testCase := range testCases { suite.T().Log(testCase.name) var resp []*placement.Rule - url := fmt.Sprintf("%s/rules/key/%s", suite.urlPrefix, testCase.key) + url := fmt.Sprintf("%s/rules/key/%s", urlPrefix, testCase.key) if testCase.success { err = tu.ReadGetJSON(re, testDialClient, url, &resp) suite.Len(resp, testCase.respSize) @@ -495,10 +565,25 @@ func (suite *ruleTestSuite) TestGetAllByKey() { } func (suite *ruleTestSuite) TestDelete() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkDelete) +} + +func (suite *ruleTestSuite) checkDelete(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "g", ID: "10", StartKeyHex: "8888", EndKeyHex: "9111", Role: "voter", Count: 1} data, err := json.Marshal(rule) suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(suite.Require())) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(suite.Require())) suite.NoError(err) oldStartKey, err := hex.DecodeString(rule.StartKeyHex) suite.NoError(err) @@ -529,15 +614,15 @@ func (suite *ruleTestSuite) TestDelete() { } for _, testCase := range testCases { suite.T().Log(testCase.name) - url := fmt.Sprintf("%s/rule/%s/%s", suite.urlPrefix, testCase.groupID, testCase.id) + url := fmt.Sprintf("%s/rule/%s/%s", urlPrefix, testCase.groupID, testCase.id) // clear suspect keyRanges to prevent test case from others - suite.svr.GetRaftCluster().ClearSuspectKeyRanges() + leaderServer.GetRaftCluster().ClearSuspectKeyRanges() err = tu.CheckDelete(testDialClient, url, tu.StatusOK(suite.Require())) suite.NoError(err) if len(testCase.popKeyRange) > 0 { popKeyRangeMap := map[string]struct{}{} for i := 0; i < len(testCase.popKeyRange)/2; i++ { - v, got := suite.svr.GetRaftCluster().PopOneSuspectKeyRange() + v, got := leaderServer.GetRaftCluster().PopOneSuspectKeyRange() suite.True(got) popKeyRangeMap[hex.EncodeToString(v[0])] = struct{}{} popKeyRangeMap[hex.EncodeToString(v[1])] = struct{}{} @@ -551,16 +636,22 @@ func (suite *ruleTestSuite) TestDelete() { } } -func (suite *ruleTestSuite) compareRule(r1 *placement.Rule, r2 *placement.Rule) { - suite.Equal(r2.GroupID, r1.GroupID) - suite.Equal(r2.ID, r1.ID) - suite.Equal(r2.StartKeyHex, r1.StartKeyHex) - suite.Equal(r2.EndKeyHex, r1.EndKeyHex) - suite.Equal(r2.Role, r1.Role) - suite.Equal(r2.Count, r1.Count) +func (suite *ruleTestSuite) TestBatch() { + opts := []tests.ConfigOption{ + func(conf 
*config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkBatch) } -func (suite *ruleTestSuite) TestBatch() { +func (suite *ruleTestSuite) checkBatch(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + opt1 := placement.RuleOp{ Action: placement.RuleOpAdd, Rule: &placement.Rule{GroupID: "a", ID: "13", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1}, @@ -670,10 +761,10 @@ func (suite *ruleTestSuite) TestBatch() { for _, testCase := range testCases { suite.T().Log(testCase.name) if testCase.success { - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re)) + err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re)) suite.NoError(err) } else { - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules/batch", testCase.rawData, + err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusNotOK(re), tu.StringEqual(re, testCase.response)) suite.NoError(err) @@ -682,6 +773,21 @@ func (suite *ruleTestSuite) TestBatch() { } func (suite *ruleTestSuite) TestBundle() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkBundle) +} + +func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + re := suite.Require() // GetAll b1 := placement.GroupBundle{ @@ -691,7 +797,7 @@ func (suite *ruleTestSuite) TestBundle() { }, } var bundles []placement.GroupBundle - err := tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + err := tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 1) suite.compareBundle(bundles[0], b1) @@ -707,28 +813,28 @@ func (suite *ruleTestSuite) TestBundle() { } data, err := json.Marshal(b2) suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re)) suite.NoError(err) // Get var bundle placement.GroupBundle - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule/foo", &bundle) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule/foo", &bundle) suite.NoError(err) suite.compareBundle(bundle, b2) // GetAll again - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 2) suite.compareBundle(bundles[0], b1) suite.compareBundle(bundles[1], b2) // Delete - err = tu.CheckDelete(testDialClient, suite.urlPrefix+"/placement-rule/pd", tu.StatusOK(suite.Require())) + err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/pd", tu.StatusOK(suite.Require())) suite.NoError(err) // GetAll again - err = tu.ReadGetJSON(re, testDialClient, 
suite.urlPrefix+"/placement-rule", &bundles) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 1) suite.compareBundle(bundles[0], b2) @@ -739,11 +845,11 @@ func (suite *ruleTestSuite) TestBundle() { b3 := placement.GroupBundle{ID: "foobar", Index: 100} data, err = json.Marshal([]placement.GroupBundle{b1, b2, b3}) suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) suite.NoError(err) // GetAll again - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 3) suite.compareBundle(bundles[0], b2) @@ -751,11 +857,11 @@ func (suite *ruleTestSuite) TestBundle() { suite.compareBundle(bundles[2], b3) // Delete using regexp - err = tu.CheckDelete(testDialClient, suite.urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(suite.Require())) + err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(suite.Require())) suite.NoError(err) // GetAll again - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 1) suite.compareBundle(bundles[0], b1) @@ -770,19 +876,19 @@ func (suite *ruleTestSuite) TestBundle() { } data, err = json.Marshal(b4) suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re)) suite.NoError(err) b4.ID = id b4.Rules[0].GroupID = b4.ID // Get - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule/"+id, &bundle) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule/"+id, &bundle) suite.NoError(err) suite.compareBundle(bundle, b4) // GetAll again - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 2) suite.compareBundle(bundles[0], b1) @@ -798,13 +904,13 @@ func (suite *ruleTestSuite) TestBundle() { } data, err = json.Marshal([]placement.GroupBundle{b1, b4, b5}) suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) suite.NoError(err) b5.Rules[0].GroupID = b5.ID // GetAll again - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 3) suite.compareBundle(bundles[0], b1) @@ -813,6 +919,21 @@ func (suite *ruleTestSuite) TestBundle() { } func (suite *ruleTestSuite) TestBundleBadRequest() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) 
+ env.RunTestInTwoModes(suite.checkBundleBadRequest) +} + +func (suite *ruleTestSuite) checkBundleBadRequest(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + testCases := []struct { uri string data string @@ -826,7 +947,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() { {"/placement-rule", `[{"group_id":"foo", "rules": [{"group_id":"bar", "id":"baz", "role":"voter", "count":1}]}]`, false}, } for _, testCase := range testCases { - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data), + err := tu.CheckPostJSON(testDialClient, urlPrefix+testCase.uri, []byte(testCase.data), func(_ []byte, code int, _ http.Header) { suite.Equal(testCase.ok, code == http.StatusOK) }) @@ -844,22 +965,42 @@ func (suite *ruleTestSuite) compareBundle(b1, b2 placement.GroupBundle) { } } +func (suite *ruleTestSuite) compareRule(r1 *placement.Rule, r2 *placement.Rule) { + suite.Equal(r2.GroupID, r1.GroupID) + suite.Equal(r2.ID, r1.ID) + suite.Equal(r2.StartKeyHex, r1.StartKeyHex) + suite.Equal(r2.EndKeyHex, r1.EndKeyHex) + suite.Equal(r2.Role, r1.Role) + suite.Equal(r2.Count, r1.Count) +} + type regionRuleTestSuite struct { suite.Suite - svr *server.Server - grpcSvr *server.GrpcServer - cleanup tu.CleanupFunc - urlPrefix string - stores []*metapb.Store - regions []*core.RegionInfo } func TestRegionRuleTestSuite(t *testing.T) { suite.Run(t, new(regionRuleTestSuite)) } -func (suite *regionRuleTestSuite) SetupSuite() { - suite.stores = []*metapb.Store{ +func (suite *regionRuleTestSuite) TestRegionPlacementRule() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.Replication.EnablePlacementRules = true + conf.Replication.MaxReplicas = 1 + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + // FIXME: enable this test in two modes after we support region label forward. 
+ env.RunTestInPDMode(suite.checkRegionPlacementRule) +} + +func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1", pdAddr, apiPrefix) + + stores := []*metapb.Store{ { Id: 1, Address: "tikv1", @@ -875,49 +1016,30 @@ func (suite *regionRuleTestSuite) SetupSuite() { Version: "2.0.0", }, } - re := suite.Require() - suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) { - cfg.Replication.EnablePlacementRules = true - cfg.Replication.MaxReplicas = 1 - }) - server.MustWaitLeader(re, []*server.Server{suite.svr}) - - addr := suite.svr.GetAddr() - suite.grpcSvr = &server.GrpcServer{Server: suite.svr} - suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) - - mustBootstrapCluster(re, suite.svr) - - for _, store := range suite.stores { - mustPutStore(re, suite.svr, store.Id, store.State, store.NodeState, nil) + for _, store := range stores { + tests.MustPutStore(re, cluster, store) } - suite.regions = make([]*core.RegionInfo, 0) + regions := make([]*core.RegionInfo, 0) peers1 := []*metapb.Peer{ {Id: 102, StoreId: 1, Role: metapb.PeerRole_Voter}, {Id: 103, StoreId: 2, Role: metapb.PeerRole_Voter}} - suite.regions = append(suite.regions, core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers1, RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}}, peers1[0], + regions = append(regions, core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers1, RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}}, peers1[0], core.WithStartKey([]byte("abc")), core.WithEndKey([]byte("def")))) peers2 := []*metapb.Peer{ {Id: 104, StoreId: 1, Role: metapb.PeerRole_Voter}, {Id: 105, StoreId: 2, Role: metapb.PeerRole_Learner}} - suite.regions = append(suite.regions, core.NewRegionInfo(&metapb.Region{Id: 2, Peers: peers2, RegionEpoch: &metapb.RegionEpoch{ConfVer: 2, Version: 2}}, peers2[0], + regions = append(regions, core.NewRegionInfo(&metapb.Region{Id: 2, Peers: peers2, RegionEpoch: &metapb.RegionEpoch{ConfVer: 2, Version: 2}}, peers2[0], core.WithStartKey([]byte("ghi")), core.WithEndKey([]byte("jkl")))) peers3 := []*metapb.Peer{ {Id: 106, StoreId: 1, Role: metapb.PeerRole_Voter}, {Id: 107, StoreId: 2, Role: metapb.PeerRole_Learner}} - suite.regions = append(suite.regions, core.NewRegionInfo(&metapb.Region{Id: 3, Peers: peers3, RegionEpoch: &metapb.RegionEpoch{ConfVer: 3, Version: 3}}, peers3[0], + regions = append(regions, core.NewRegionInfo(&metapb.Region{Id: 3, Peers: peers3, RegionEpoch: &metapb.RegionEpoch{ConfVer: 3, Version: 3}}, peers3[0], core.WithStartKey([]byte("mno")), core.WithEndKey([]byte("pqr")))) - for _, rg := range suite.regions { - suite.svr.GetBasicCluster().PutRegion(rg) + for _, rg := range regions { + tests.MustPutRegionInfo(re, cluster, rg) } -} - -func (suite *regionRuleTestSuite) TearDownSuite() { - suite.cleanup() -} -func (suite *regionRuleTestSuite) TestRegionPlacementRule() { - ruleManager := suite.svr.GetRaftCluster().GetRuleManager() + ruleManager := leaderServer.GetRaftCluster().GetRuleManager() ruleManager.SetRule(&placement.Rule{ GroupID: "test", ID: "test2", @@ -934,38 +1056,38 @@ func (suite *regionRuleTestSuite) TestRegionPlacementRule() { Role: placement.Learner, Count: 1, }) - re := suite.Require() - url := fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 1) fit := &placement.RegionFit{} + + url := fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1) err 
:= tu.ReadGetJSON(re, testDialClient, url, fit) + suite.NoError(err) suite.Equal(len(fit.RuleFits), 1) suite.Equal(len(fit.OrphanPeers), 1) - suite.NoError(err) - url = fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 2) + url = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 2) fit = &placement.RegionFit{} err = tu.ReadGetJSON(re, testDialClient, url, fit) + suite.NoError(err) suite.Equal(len(fit.RuleFits), 2) suite.Equal(len(fit.OrphanPeers), 0) - suite.NoError(err) - url = fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 3) + url = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 3) fit = &placement.RegionFit{} err = tu.ReadGetJSON(re, testDialClient, url, fit) + suite.NoError(err) suite.Equal(len(fit.RuleFits), 0) suite.Equal(len(fit.OrphanPeers), 2) - suite.NoError(err) - url = fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 4) + url = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 4) err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusNotFound), tu.StringContain( re, "region 4 not found")) suite.NoError(err) - url = fmt.Sprintf("%s/config/rules/region/%s/detail", suite.urlPrefix, "id") + url = fmt.Sprintf("%s/config/rules/region/%s/detail", urlPrefix, "id") err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest), tu.StringContain( - re, "invalid region id")) + re, errs.ErrRegionInvalidID.Error())) suite.NoError(err) - suite.svr.GetRaftCluster().GetReplicationConfig().EnablePlacementRules = false - url = fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 1) + leaderServer.GetRaftCluster().GetReplicationConfig().EnablePlacementRules = false + url = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1) err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusPreconditionFailed), tu.StringContain( re, "placement rules feature is disabled")) suite.NoError(err) From da30175bdbb44a6dd7180e89fbc1076c781aba3a Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Fri, 10 Nov 2023 17:26:13 +0800 Subject: [PATCH 17/26] etcdutil, leadership: avoid redundant created watch channel (#7352) close tikv/pd#7351 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/election/leadership.go | 1 + pkg/utils/etcdutil/etcdutil.go | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/election/leadership.go b/pkg/election/leadership.go index d5d73e90b58..8cfdcf423ac 100644 --- a/pkg/election/leadership.go +++ b/pkg/election/leadership.go @@ -260,6 +260,7 @@ func (ls *Leadership) Watch(serverCtx context.Context, revision int64) { continue } } + lastReceivedResponseTime = time.Now() log.Info("watch channel is created", zap.Int64("revision", revision), zap.String("leader-key", ls.leaderKey), zap.String("purpose", ls.purpose)) watchChanLoop: diff --git a/pkg/utils/etcdutil/etcdutil.go b/pkg/utils/etcdutil/etcdutil.go index 1432b6e37c3..e004247c6d0 100644 --- a/pkg/utils/etcdutil/etcdutil.go +++ b/pkg/utils/etcdutil/etcdutil.go @@ -707,7 +707,6 @@ func (lw *LoopWatcher) watch(ctx context.Context, revision int64) (nextRevision }() ticker := time.NewTicker(RequestProgressInterval) defer ticker.Stop() - lastReceivedResponseTime := time.Now() for { if watcherCancel != nil { @@ -736,8 +735,10 @@ func (lw *LoopWatcher) watch(ctx context.Context, revision int64) (nextRevision continue } } + lastReceivedResponseTime := time.Now() log.Info("watch channel is created in watch loop", 
zap.Int64("revision", revision), zap.String("name", lw.name), zap.String("key", lw.key)) + watchChanLoop: select { case <-ctx.Done(): From 71621d3cc296190f11cba4532bc4afcb73e314c6 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Fri, 10 Nov 2023 17:42:43 +0800 Subject: [PATCH 18/26] checker: refactor fixOrphanPeers (#7342) ref tikv/pd#4399 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/schedule/checker/rule_checker.go | 42 +++++++++++++---------- pkg/schedule/checker/rule_checker_test.go | 1 + server/cluster/cluster_test.go | 2 +- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/pkg/schedule/checker/rule_checker.go b/pkg/schedule/checker/rule_checker.go index c4e7c242dea..08ef5f7b45c 100644 --- a/pkg/schedule/checker/rule_checker.go +++ b/pkg/schedule/checker/rule_checker.go @@ -449,23 +449,31 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg return nil, nil } - isUnhealthyPeer := func(id uint64) bool { - for _, downPeer := range region.GetDownPeers() { - if downPeer.Peer.GetId() == id { + isPendingPeer := func(id uint64) bool { + for _, pendingPeer := range region.GetPendingPeers() { + if pendingPeer.GetId() == id { return true } } - for _, pendingPeer := range region.GetPendingPeers() { - if pendingPeer.GetId() == id { + return false + } + + isDownPeer := func(id uint64) bool { + for _, downPeer := range region.GetDownPeers() { + if downPeer.Peer.GetId() == id { return true } } return false } - isDisconnectedPeer := func(p *metapb.Peer) bool { + isUnhealthyPeer := func(id uint64) bool { + return isPendingPeer(id) || isDownPeer(id) + } + + isInDisconnectedStore := func(p *metapb.Peer) bool { // avoid to meet down store when fix orphan peers, - // Isdisconnected is more strictly than IsUnhealthy. + // isInDisconnectedStore is usually more strictly than IsUnhealthy. store := c.cluster.GetStore(p.GetStoreId()) if store == nil { return true @@ -475,16 +483,12 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg checkDownPeer := func(peers []*metapb.Peer) (*metapb.Peer, bool) { for _, p := range peers { - if isUnhealthyPeer(p.GetId()) { - // make sure is down peer. - if region.GetDownPeer(p.GetId()) != nil { - return p, true - } - return nil, true - } - if isDisconnectedPeer(p) { + if isInDisconnectedStore(p) || isDownPeer(p.GetId()) { return p, true } + if isPendingPeer(p.GetId()) { + return nil, true + } } return nil, false } @@ -517,7 +521,7 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg continue } // make sure the orphan peer is healthy. - if isUnhealthyPeer(orphanPeer.GetId()) || isDisconnectedPeer(orphanPeer) { + if isUnhealthyPeer(orphanPeer.GetId()) || isInDisconnectedStore(orphanPeer) { continue } // no consider witness in this path. @@ -525,7 +529,7 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg continue } // pinDownPeer's store should be disconnected, because we use more strict judge before. - if !isDisconnectedPeer(pinDownPeer) { + if !isInDisconnectedStore(pinDownPeer) { continue } // check if down peer can replace with orphan peer. 
@@ -539,7 +543,7 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg return operator.CreatePromoteLearnerOperatorAndRemovePeer("replace-down-peer-with-orphan-peer", c.cluster, region, orphanPeer, pinDownPeer) case orphanPeerRole == metapb.PeerRole_Voter && destRole == metapb.PeerRole_Learner: return operator.CreateDemoteLearnerOperatorAndRemovePeer("replace-down-peer-with-orphan-peer", c.cluster, region, orphanPeer, pinDownPeer) - case orphanPeerRole == destRole && isDisconnectedPeer(pinDownPeer) && !dstStore.IsDisconnected(): + case orphanPeerRole == destRole && isInDisconnectedStore(pinDownPeer) && !dstStore.IsDisconnected(): return operator.CreateRemovePeerOperator("remove-replaced-orphan-peer", c.cluster, 0, region, pinDownPeer.GetStoreId()) default: // destRole should not same with orphanPeerRole. if role is same, it fit with orphanPeer should be better than now. @@ -557,7 +561,7 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg hasHealthPeer := false var disconnectedPeer *metapb.Peer for _, orphanPeer := range fit.OrphanPeers { - if isDisconnectedPeer(orphanPeer) { + if isInDisconnectedStore(orphanPeer) { disconnectedPeer = orphanPeer break } diff --git a/pkg/schedule/checker/rule_checker_test.go b/pkg/schedule/checker/rule_checker_test.go index eb357f302b7..4185ce6c167 100644 --- a/pkg/schedule/checker/rule_checker_test.go +++ b/pkg/schedule/checker/rule_checker_test.go @@ -1052,6 +1052,7 @@ func (suite *ruleCheckerTestSuite) TestPriorityFitHealthWithDifferentRole1() { suite.Equal("replace-down-peer-with-orphan-peer", op.Desc()) // set peer3 only pending + suite.cluster.GetStore(3).GetMeta().LastHeartbeat = time.Now().UnixNano() r1 = r1.Clone(core.WithDownPeers(nil)) suite.cluster.PutRegion(r1) op = suite.rc.Check(suite.cluster.GetRegion(1)) diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 89c9ea32f19..70782e27cd3 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -2792,7 +2792,7 @@ func TestReplica(t *testing.T) { re.NoError(dispatchHeartbeat(co, region, stream)) waitNoResponse(re, stream) - // Remove peer from store 4. + // Remove peer from store 3. re.NoError(tc.addLeaderRegion(2, 1, 2, 3, 4)) region = tc.GetRegion(2) re.NoError(dispatchHeartbeat(co, region, stream)) From ad96bf1f3522e6e73c14990eba24971dedfa315a Mon Sep 17 00:00:00 2001 From: Hu# Date: Fri, 10 Nov 2023 17:59:14 +0800 Subject: [PATCH 19/26] mcs: fix error typo (#7354) ref tikv/pd#4399 Signed-off-by: husharp Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/tso/server/install/install.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/mcs/tso/server/install/install.go b/pkg/mcs/tso/server/install/install.go index 27db0c51d75..a821505474f 100644 --- a/pkg/mcs/tso/server/install/install.go +++ b/pkg/mcs/tso/server/install/install.go @@ -28,5 +28,5 @@ func init() { // Install registers the API group and grpc service. 
func Install(register *registry.ServiceRegistry) { - register.RegisterService("Scheduling", server.NewService[*server.Server]) + register.RegisterService("TSO", server.NewService[*server.Server]) } From 1a0233b8d598b83904024cbd160d77cac5eaa446 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Fri, 10 Nov 2023 19:33:43 +0800 Subject: [PATCH 20/26] mcs: use a controller to manage scheduling jobs (#7270) ref tikv/pd#5839 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- server/cluster/cluster.go | 402 ++++------------- server/cluster/cluster_test.go | 13 +- server/cluster/cluster_worker.go | 3 +- server/cluster/scheduling_controller.go | 424 ++++++++++++++++++ server/grpc_service.go | 12 +- server/handler.go | 7 +- server/server.go | 11 +- .../mcs/scheduling/config_test.go | 1 - tests/server/cluster/cluster_test.go | 6 +- 9 files changed, 531 insertions(+), 348 deletions(-) create mode 100644 server/cluster/scheduling_controller.go diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 8362ee9f331..c0fb1b15f8f 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -41,22 +41,18 @@ import ( "github.com/tikv/pd/pkg/gctuner" "github.com/tikv/pd/pkg/id" "github.com/tikv/pd/pkg/keyspace" + mcsutils "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/memory" "github.com/tikv/pd/pkg/progress" "github.com/tikv/pd/pkg/replication" "github.com/tikv/pd/pkg/schedule" - "github.com/tikv/pd/pkg/schedule/checker" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/schedule/hbstream" "github.com/tikv/pd/pkg/schedule/labeler" "github.com/tikv/pd/pkg/schedule/operator" "github.com/tikv/pd/pkg/schedule/placement" - "github.com/tikv/pd/pkg/schedule/scatter" - "github.com/tikv/pd/pkg/schedule/schedulers" - "github.com/tikv/pd/pkg/schedule/splitter" "github.com/tikv/pd/pkg/slice" "github.com/tikv/pd/pkg/statistics" - "github.com/tikv/pd/pkg/statistics/buckets" "github.com/tikv/pd/pkg/statistics/utils" "github.com/tikv/pd/pkg/storage" "github.com/tikv/pd/pkg/storage/endpoint" @@ -97,6 +93,7 @@ const ( clientTimeout = 3 * time.Second defaultChangedRegionsLimit = 10000 gcTombstoneInterval = 30 * 24 * time.Hour + serviceCheckInterval = 10 * time.Second // persistLimitRetryTimes is used to reduce the probability of the persistent error // since the once the store is added or removed, we shouldn't return an error even if the store limit is failed to persist. persistLimitRetryTimes = 5 @@ -155,16 +152,12 @@ type RaftCluster struct { prevStoreLimit map[uint64]map[storelimit.Type]float64 // This below fields are all read-only, we cannot update itself after the raft cluster starts. 
- clusterID uint64 - id id.Allocator - core *core.BasicCluster // cached cluster info - opt *config.PersistOptions - limiter *StoreLimiter - coordinator *schedule.Coordinator - labelLevelStats *statistics.LabelStatistics - regionStats *statistics.RegionStatistics - hotStat *statistics.HotStat - slowStat *statistics.SlowStat + clusterID uint64 + id id.Allocator + core *core.BasicCluster // cached cluster info + opt *config.PersistOptions + limiter *StoreLimiter + *schedulingController ruleManager *placement.RuleManager regionLabeler *labeler.RegionLabeler replicationMode *replication.ModeManager @@ -173,6 +166,8 @@ type RaftCluster struct { regionSyncer *syncer.RegionSyncer changedRegions chan *core.RegionInfo keyspaceGroupManager *keyspace.GroupManager + independentServices sync.Map + hbstreams *hbstream.HeartbeatStreams } // Status saves some state information. @@ -266,17 +261,17 @@ func (c *RaftCluster) InitCluster( opt sc.ConfProvider, storage storage.Storage, basicCluster *core.BasicCluster, + hbstreams *hbstream.HeartbeatStreams, keyspaceGroupManager *keyspace.GroupManager) { c.core, c.opt, c.storage, c.id = basicCluster, opt.(*config.PersistOptions), storage, id c.ctx, c.cancel = context.WithCancel(c.serverCtx) - c.labelLevelStats = statistics.NewLabelStatistics() - c.hotStat = statistics.NewHotStat(c.ctx) - c.slowStat = statistics.NewSlowStat(c.ctx) c.progressManager = progress.NewManager() c.changedRegions = make(chan *core.RegionInfo, defaultChangedRegionsLimit) c.prevStoreLimit = make(map[uint64]map[storelimit.Type]float64) c.unsafeRecoveryController = unsaferecovery.NewController(c) c.keyspaceGroupManager = keyspaceGroupManager + c.hbstreams = hbstreams + c.schedulingController = newSchedulingController(c.ctx) } // Start starts a cluster. 
@@ -290,7 +285,7 @@ func (c *RaftCluster) Start(s Server) error { } c.isAPIServiceMode = s.IsAPIServiceMode() - c.InitCluster(s.GetAllocator(), s.GetPersistOptions(), s.GetStorage(), s.GetBasicCluster(), s.GetKeyspaceGroupManager()) + c.InitCluster(s.GetAllocator(), s.GetPersistOptions(), s.GetStorage(), s.GetBasicCluster(), s.GetHBStreams(), s.GetKeyspaceGroupManager()) cluster, err := c.LoadClusterInfo() if err != nil { return err @@ -316,8 +311,7 @@ func (c *RaftCluster) Start(s Server) error { return err } - c.coordinator = schedule.NewCoordinator(c.ctx, cluster, s.GetHBStreams()) - c.regionStats = statistics.NewRegionStatistics(c.core, c.opt, c.ruleManager) + c.schedulingController.init(c.core, c.opt, schedule.NewCoordinator(c.ctx, c, c.GetHeartbeatStreams()), c.ruleManager) c.limiter = NewStoreLimiter(s.GetPersistOptions()) c.externalTS, err = c.storage.LoadExternalTS() if err != nil { @@ -331,14 +325,9 @@ func (c *RaftCluster) Start(s Server) error { if err != nil { return err } - c.initSchedulers() - } else { - c.wg.Add(2) - go c.runCoordinator() - go c.runStatsBackgroundJobs() } - - c.wg.Add(8) + c.wg.Add(9) + go c.runServiceCheckJob() go c.runMetricsCollectionJob() go c.runNodeStateCheckJob() go c.syncRegions() @@ -352,6 +341,38 @@ func (c *RaftCluster) Start(s Server) error { return nil } +func (c *RaftCluster) runServiceCheckJob() { + defer logutil.LogPanic() + defer c.wg.Done() + + var once sync.Once + + checkFn := func() { + if c.isAPIServiceMode { + once.Do(c.initSchedulers) + c.independentServices.Store(mcsutils.SchedulingServiceName, true) + return + } + if c.startSchedulingJobs() { + c.independentServices.Delete(mcsutils.SchedulingServiceName) + } + } + checkFn() + + ticker := time.NewTicker(serviceCheckInterval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + log.Info("service check job is stopped") + return + case <-ticker.C: + checkFn() + } + } +} + // startGCTuner func (c *RaftCluster) startGCTuner() { defer logutil.LogPanic() @@ -600,10 +621,9 @@ func (c *RaftCluster) LoadClusterInfo() (*RaftCluster, error) { zap.Int("count", c.core.GetTotalRegionCount()), zap.Duration("cost", time.Since(start)), ) - if !c.isAPIServiceMode { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { for _, store := range c.GetStores() { storeID := store.GetID() - c.hotStat.GetOrCreateRollingStoreStats(storeID) c.slowStat.ObserveSlowStoreStatus(storeID, store.IsSlow()) } } @@ -619,7 +639,6 @@ func (c *RaftCluster) runMetricsCollectionJob() { ticker.Stop() ticker = time.NewTicker(time.Microsecond) }) - defer ticker.Stop() for { @@ -657,24 +676,6 @@ func (c *RaftCluster) runNodeStateCheckJob() { } } -func (c *RaftCluster) runStatsBackgroundJobs() { - defer logutil.LogPanic() - defer c.wg.Done() - - ticker := time.NewTicker(statistics.RegionsStatsObserveInterval) - defer ticker.Stop() - - for { - select { - case <-c.ctx.Done(): - log.Info("statistics background jobs has been stopped") - return - case <-ticker.C: - c.hotStat.ObserveRegionsStats(c.core.GetStoresWriteRate()) - } - } -} - func (c *RaftCluster) runUpdateStoreStats() { defer logutil.LogPanic() defer c.wg.Done() @@ -696,13 +697,6 @@ func (c *RaftCluster) runUpdateStoreStats() { } } -// runCoordinator runs the main scheduling loop. 
-func (c *RaftCluster) runCoordinator() { - defer logutil.LogPanic() - defer c.wg.Done() - c.coordinator.RunUntilStop() -} - func (c *RaftCluster) syncRegions() { defer logutil.LogPanic() defer c.wg.Done() @@ -723,8 +717,8 @@ func (c *RaftCluster) Stop() { return } c.running = false - if !c.isAPIServiceMode { - c.coordinator.Stop() + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { + c.stopSchedulingJobs() } c.cancel() c.Unlock() @@ -750,6 +744,11 @@ func (c *RaftCluster) Context() context.Context { return nil } +// GetHeartbeatStreams returns the heartbeat streams. +func (c *RaftCluster) GetHeartbeatStreams() *hbstream.HeartbeatStreams { + return c.hbstreams +} + // GetCoordinator returns the coordinator. func (c *RaftCluster) GetCoordinator() *schedule.Coordinator { return c.coordinator @@ -760,71 +759,6 @@ func (c *RaftCluster) GetOperatorController() *operator.Controller { return c.coordinator.GetOperatorController() } -// SetPrepared set the prepare check to prepared. Only for test purpose. -func (c *RaftCluster) SetPrepared() { - c.coordinator.GetPrepareChecker().SetPrepared() -} - -// GetRegionScatterer returns the region scatter. -func (c *RaftCluster) GetRegionScatterer() *scatter.RegionScatterer { - return c.coordinator.GetRegionScatterer() -} - -// GetRegionSplitter returns the region splitter -func (c *RaftCluster) GetRegionSplitter() *splitter.RegionSplitter { - return c.coordinator.GetRegionSplitter() -} - -// GetMergeChecker returns merge checker. -func (c *RaftCluster) GetMergeChecker() *checker.MergeChecker { - return c.coordinator.GetMergeChecker() -} - -// GetRuleChecker returns rule checker. -func (c *RaftCluster) GetRuleChecker() *checker.RuleChecker { - return c.coordinator.GetRuleChecker() -} - -// GetSchedulers gets all schedulers. -func (c *RaftCluster) GetSchedulers() []string { - return c.coordinator.GetSchedulersController().GetSchedulerNames() -} - -// GetSchedulerHandlers gets all scheduler handlers. -func (c *RaftCluster) GetSchedulerHandlers() map[string]http.Handler { - return c.coordinator.GetSchedulersController().GetSchedulerHandlers() -} - -// AddSchedulerHandler adds a scheduler handler. -func (c *RaftCluster) AddSchedulerHandler(scheduler schedulers.Scheduler, args ...string) error { - return c.coordinator.GetSchedulersController().AddSchedulerHandler(scheduler, args...) -} - -// RemoveSchedulerHandler removes a scheduler handler. -func (c *RaftCluster) RemoveSchedulerHandler(name string) error { - return c.coordinator.GetSchedulersController().RemoveSchedulerHandler(name) -} - -// AddScheduler adds a scheduler. -func (c *RaftCluster) AddScheduler(scheduler schedulers.Scheduler, args ...string) error { - return c.coordinator.GetSchedulersController().AddScheduler(scheduler, args...) -} - -// RemoveScheduler removes a scheduler. -func (c *RaftCluster) RemoveScheduler(name string) error { - return c.coordinator.GetSchedulersController().RemoveScheduler(name) -} - -// PauseOrResumeScheduler pauses or resumes a scheduler. -func (c *RaftCluster) PauseOrResumeScheduler(name string, t int64) error { - return c.coordinator.GetSchedulersController().PauseOrResumeScheduler(name, t) -} - -// PauseOrResumeChecker pauses or resumes checker. -func (c *RaftCluster) PauseOrResumeChecker(name string, t int64) error { - return c.coordinator.PauseOrResumeChecker(name, t) -} - // AllocID returns a global unique ID. 
func (c *RaftCluster) AllocID() (uint64, error) { return c.id.Alloc() @@ -861,10 +795,6 @@ func (c *RaftCluster) GetOpts() sc.ConfProvider { return c.opt } -func (c *RaftCluster) initSchedulers() { - c.coordinator.InitSchedulers(false) -} - // GetScheduleConfig returns scheduling configurations. func (c *RaftCluster) GetScheduleConfig() *sc.ScheduleConfig { return c.opt.GetScheduleConfig() @@ -890,60 +820,11 @@ func (c *RaftCluster) SetPDServerConfig(cfg *config.PDServerConfig) { c.opt.SetPDServerConfig(cfg) } -// AddSuspectRegions adds regions to suspect list. -func (c *RaftCluster) AddSuspectRegions(regionIDs ...uint64) { - c.coordinator.GetCheckerController().AddSuspectRegions(regionIDs...) -} - -// GetSuspectRegions gets all suspect regions. -func (c *RaftCluster) GetSuspectRegions() []uint64 { - return c.coordinator.GetCheckerController().GetSuspectRegions() -} - -// GetHotStat gets hot stat. -func (c *RaftCluster) GetHotStat() *statistics.HotStat { - return c.hotStat -} - -// GetRegionStats gets region statistics. -func (c *RaftCluster) GetRegionStats() *statistics.RegionStatistics { - return c.regionStats -} - -// GetLabelStats gets label statistics. -func (c *RaftCluster) GetLabelStats() *statistics.LabelStatistics { - return c.labelLevelStats -} - -// RemoveSuspectRegion removes region from suspect list. -func (c *RaftCluster) RemoveSuspectRegion(id uint64) { - c.coordinator.GetCheckerController().RemoveSuspectRegion(id) -} - // GetUnsafeRecoveryController returns the unsafe recovery controller. func (c *RaftCluster) GetUnsafeRecoveryController() *unsaferecovery.Controller { return c.unsafeRecoveryController } -// AddSuspectKeyRange adds the key range with the its ruleID as the key -// The instance of each keyRange is like following format: -// [2][]byte: start key/end key -func (c *RaftCluster) AddSuspectKeyRange(start, end []byte) { - c.coordinator.GetCheckerController().AddSuspectKeyRange(start, end) -} - -// PopOneSuspectKeyRange gets one suspect keyRange group. -// it would return value and true if pop success, or return empty [][2][]byte and false -// if suspectKeyRanges couldn't pop keyRange group. -func (c *RaftCluster) PopOneSuspectKeyRange() ([2][]byte, bool) { - return c.coordinator.GetCheckerController().PopOneSuspectKeyRange() -} - -// ClearSuspectKeyRanges clears the suspect keyRanges, only for unit test -func (c *RaftCluster) ClearSuspectKeyRanges() { - c.coordinator.GetCheckerController().ClearSuspectKeyRanges() -} - // HandleStoreHeartbeat updates the store status. func (c *RaftCluster) HandleStoreHeartbeat(heartbeat *pdpb.StoreHeartbeatRequest, resp *pdpb.StoreHeartbeatResponse) error { stats := heartbeat.GetStats() @@ -970,7 +851,7 @@ func (c *RaftCluster) HandleStoreHeartbeat(heartbeat *pdpb.StoreHeartbeatRequest nowTime := time.Now() var newStore *core.StoreInfo // If this cluster has slow stores, we should awaken hibernated regions in other stores. 
- if !c.isAPIServiceMode { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { if needAwaken, slowStoreIDs := c.NeedAwakenAllRegionsInStore(storeID); needAwaken { log.Info("forcely awaken hibernated regions", zap.Uint64("store-id", storeID), zap.Uint64s("slow-stores", slowStoreIDs)) newStore = store.Clone(core.SetStoreStats(stats), core.SetLastHeartbeatTS(nowTime), core.SetLastAwakenTime(nowTime), opt) @@ -1005,7 +886,7 @@ func (c *RaftCluster) HandleStoreHeartbeat(heartbeat *pdpb.StoreHeartbeatRequest regions map[uint64]*core.RegionInfo interval uint64 ) - if !c.isAPIServiceMode { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { c.hotStat.Observe(storeID, newStore.GetStoreStats()) c.hotStat.FilterUnhealthyStore(c) c.slowStat.ObserveSlowStoreStatus(storeID, newStore.IsSlow()) @@ -1061,7 +942,7 @@ func (c *RaftCluster) HandleStoreHeartbeat(heartbeat *pdpb.StoreHeartbeatRequest e := int64(dur)*2 - int64(stat.GetTotalDurationSec()) store.Feedback(float64(e)) } - if !c.isAPIServiceMode { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { // Here we will compare the reported regions with the previous hot peers to decide if it is still hot. c.hotStat.CheckReadAsync(statistics.NewCollectUnReportedPeerTask(storeID, regions, interval)) } @@ -1097,11 +978,6 @@ func (c *RaftCluster) processReportBuckets(buckets *metapb.Buckets) error { return nil } -// IsPrepared return true if the prepare checker is ready. -func (c *RaftCluster) IsPrepared() bool { - return c.coordinator.GetPrepareChecker().IsPrepared() -} - var regionGuide = core.GenerateRegionGuideFunc(true) // processRegionHeartbeat updates the region information. @@ -1112,7 +988,7 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error { } region.Inherit(origin, c.GetStoreConfig().IsEnableRegionBucket()) - if !c.isAPIServiceMode { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { cluster.HandleStatsAsync(c, region) } @@ -1121,7 +997,7 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error { // Save to cache if meta or leader is updated, or contains any down/pending peer. // Mark isNew if the region in cache does not have leader. isNew, saveKV, saveCache, needSync := regionGuide(region, origin) - if !c.isAPIServiceMode && !saveKV && !saveCache && !isNew { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) && !saveKV && !saveCache && !isNew { // Due to some config changes need to update the region stats as well, // so we do some extra checks here. 
if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { @@ -1146,13 +1022,13 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error { if overlaps, err = c.core.AtomicCheckAndPutRegion(region); err != nil { return err } - if !c.isAPIServiceMode { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { cluster.HandleOverlaps(c, overlaps) } regionUpdateCacheEventCounter.Inc() } - if !c.isAPIServiceMode { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats, isNew, c.IsPrepared()) } @@ -1566,24 +1442,6 @@ func (c *RaftCluster) checkReplicaBeforeOfflineStore(storeID uint64) error { return nil } -func (c *RaftCluster) getEvictLeaderStores() (evictStores []uint64) { - if c.coordinator == nil { - return nil - } - handler, ok := c.coordinator.GetSchedulersController().GetSchedulerHandlers()[schedulers.EvictLeaderName] - if !ok { - return - } - type evictLeaderHandler interface { - EvictStoreIDs() []uint64 - } - h, ok := handler.(evictLeaderHandler) - if !ok { - return - } - return h.EvictStoreIDs() -} - func (c *RaftCluster) getUpStores() []uint64 { upStores := make([]uint64, 0) for _, store := range c.GetStores() { @@ -1634,9 +1492,8 @@ func (c *RaftCluster) BuryStore(storeID uint64, forceBury bool) error { c.resetProgress(storeID, addr) storeIDStr := strconv.FormatUint(storeID, 10) statistics.ResetStoreStatistics(addr, storeIDStr) - if !c.isAPIServiceMode { - c.hotStat.RemoveRollingStoreStats(storeID) - c.slowStat.RemoveSlowStoreStatus(storeID) + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { + c.removeStoreStatistics(storeID) } } return err @@ -1811,9 +1668,8 @@ func (c *RaftCluster) putStoreLocked(store *core.StoreInfo) error { } } c.core.PutStore(store) - if !c.isAPIServiceMode { - c.hotStat.GetOrCreateRollingStoreStats(store.GetID()) - c.slowStat.ObserveSlowStoreStatus(store.GetID(), store.IsSlow()) + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { + c.updateStoreStatistics(store.GetID(), store.IsSlow()) } return nil } @@ -2162,53 +2018,14 @@ func (c *RaftCluster) deleteStore(store *core.StoreInfo) error { } func (c *RaftCluster) collectMetrics() { - if !c.isAPIServiceMode { - statsMap := statistics.NewStoreStatisticsMap(c.opt) - stores := c.GetStores() - for _, s := range stores { - statsMap.Observe(s) - statsMap.ObserveHotStat(s, c.hotStat.StoresStats) - } - statsMap.Collect() - c.coordinator.GetSchedulersController().CollectSchedulerMetrics() - c.coordinator.CollectHotSpotMetrics() - c.collectClusterMetrics() - } c.collectHealthStatus() } func (c *RaftCluster) resetMetrics() { - statistics.Reset() - - if !c.isAPIServiceMode { - c.coordinator.GetSchedulersController().ResetSchedulerMetrics() - c.coordinator.ResetHotSpotMetrics() - c.resetClusterMetrics() - } c.resetHealthStatus() c.resetProgressIndicator() } -func (c *RaftCluster) collectClusterMetrics() { - if c.regionStats == nil { - return - } - c.regionStats.Collect() - c.labelLevelStats.Collect() - // collect hot cache metrics - c.hotStat.CollectMetrics() -} - -func (c *RaftCluster) resetClusterMetrics() { - if c.regionStats == nil { - return - } - c.regionStats.Reset() - c.labelLevelStats.Reset() - // reset hot cache metrics - c.hotStat.ResetMetrics() -} - func (c *RaftCluster) collectHealthStatus() { members, err := GetMembers(c.etcdClient) if err != nil { @@ -2235,21 +2052,6 @@ func (c *RaftCluster) resetProgressIndicator() { storesETAGauge.Reset() } -// GetRegionStatsByType gets the 
status of the region by types. -func (c *RaftCluster) GetRegionStatsByType(typ statistics.RegionStatisticType) []*core.RegionInfo { - if c.regionStats == nil { - return nil - } - return c.regionStats.GetRegionStatsByType(typ) -} - -// UpdateRegionsLabelLevelStats updates the status of the region label level by types. -func (c *RaftCluster) UpdateRegionsLabelLevelStats(regions []*core.RegionInfo) { - for _, region := range regions { - c.labelLevelStats.Observe(region, c.getStoresWithoutLabelLocked(region, core.EngineKey, core.EngineTiFlash), c.opt.GetLocationLabels()) - } -} - func (c *RaftCluster) getRegionStoresLocked(region *core.RegionInfo) []*core.StoreInfo { stores := make([]*core.StoreInfo, 0, len(region.GetPeers())) for _, p := range region.GetPeers() { @@ -2260,16 +2062,6 @@ func (c *RaftCluster) getRegionStoresLocked(region *core.RegionInfo) []*core.Sto return stores } -func (c *RaftCluster) getStoresWithoutLabelLocked(region *core.RegionInfo, key, value string) []*core.StoreInfo { - stores := make([]*core.StoreInfo, 0, len(region.GetPeers())) - for _, p := range region.GetPeers() { - if store := c.core.GetStore(p.StoreId); store != nil && !core.IsStoreContainLabel(store.GetMeta(), key, value) { - stores = append(stores, store) - } - } - return stores -} - // OnStoreVersionChange changes the version of the cluster when needed. func (c *RaftCluster) OnStoreVersionChange() { c.RLock() @@ -2345,49 +2137,6 @@ func (c *RaftCluster) GetRegionCount(startKey, endKey []byte) *statistics.Region return stats } -// GetStoresStats returns stores' statistics from cluster. -// And it will be unnecessary to filter unhealthy store, because it has been solved in process heartbeat -func (c *RaftCluster) GetStoresStats() *statistics.StoresStats { - return c.hotStat.StoresStats -} - -// GetStoresLoads returns load stats of all stores. -func (c *RaftCluster) GetStoresLoads() map[uint64][]float64 { - return c.hotStat.GetStoresLoads() -} - -// IsRegionHot checks if a region is in hot state. -func (c *RaftCluster) IsRegionHot(region *core.RegionInfo) bool { - return c.hotStat.IsRegionHot(region, c.opt.GetHotRegionCacheHitsThreshold()) -} - -// GetHotPeerStat returns hot peer stat with specified regionID and storeID. -func (c *RaftCluster) GetHotPeerStat(rw utils.RWType, regionID, storeID uint64) *statistics.HotPeerStat { - return c.hotStat.GetHotPeerStat(rw, regionID, storeID) -} - -// RegionReadStats returns hot region's read stats. -// The result only includes peers that are hot enough. -// RegionStats is a thread-safe method -func (c *RaftCluster) RegionReadStats() map[uint64][]*statistics.HotPeerStat { - // As read stats are reported by store heartbeat, the threshold needs to be adjusted. - threshold := c.GetOpts().GetHotRegionCacheHitsThreshold() * - (utils.RegionHeartBeatReportInterval / utils.StoreHeartBeatReportInterval) - return c.hotStat.RegionStats(utils.Read, threshold) -} - -// RegionWriteStats returns hot region's write stats. -// The result only includes peers that are hot enough. -func (c *RaftCluster) RegionWriteStats() map[uint64][]*statistics.HotPeerStat { - // RegionStats is a thread-safe method - return c.hotStat.RegionStats(utils.Write, c.GetOpts().GetHotRegionCacheHitsThreshold()) -} - -// BucketsStats returns hot region's buckets stats. -func (c *RaftCluster) BucketsStats(degree int, regionIDs ...uint64) map[uint64][]*buckets.BucketStat { - return c.hotStat.BucketsStats(degree, regionIDs...) -} - // TODO: remove me. // only used in test. 
func (c *RaftCluster) putRegion(region *core.RegionInfo) error { @@ -2775,12 +2524,11 @@ func IsClientURL(addr string, etcdClient *clientv3.Client) bool { return false } -// GetPausedSchedulerDelayAt returns DelayAt of a paused scheduler -func (c *RaftCluster) GetPausedSchedulerDelayAt(name string) (int64, error) { - return c.coordinator.GetSchedulersController().GetPausedSchedulerDelayAt(name) -} - -// GetPausedSchedulerDelayUntil returns DelayUntil of a paused scheduler -func (c *RaftCluster) GetPausedSchedulerDelayUntil(name string) (int64, error) { - return c.coordinator.GetSchedulersController().GetPausedSchedulerDelayUntil(name) +// IsServiceIndependent returns whether the service is independent. +func (c *RaftCluster) IsServiceIndependent(name string) bool { + independent, exist := c.independentServices.Load(name) + if !exist { + return false + } + return independent.(bool) } diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 70782e27cd3..7ebd012a6a2 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -1144,7 +1144,7 @@ func TestRegionLabelIsolationLevel(t *testing.T) { re.NoError(cluster.putRegion(r)) cluster.UpdateRegionsLabelLevelStats([]*core.RegionInfo{r}) - counter := cluster.labelLevelStats.GetLabelCounter() + counter := cluster.labelStats.GetLabelCounter() re.Equal(0, counter["none"]) re.Equal(1, counter["zone"]) } @@ -2130,7 +2130,7 @@ func newTestRaftCluster( basicCluster *core.BasicCluster, ) *RaftCluster { rc := &RaftCluster{serverCtx: ctx} - rc.InitCluster(id, opt, s, basicCluster, nil) + rc.InitCluster(id, opt, s, basicCluster, nil, nil) rc.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), rc, opt) if opt.IsPlacementRulesEnabled() { err := rc.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) @@ -2138,6 +2138,7 @@ func newTestRaftCluster( panic(err) } } + rc.schedulingController.init(basicCluster, opt, nil, rc.ruleManager) return rc } @@ -2502,11 +2503,11 @@ func TestCollectMetricsConcurrent(t *testing.T) { for i := 0; i < 1000; i++ { co.CollectHotSpotMetrics() controller.CollectSchedulerMetrics() - co.GetCluster().(*RaftCluster).collectClusterMetrics() + co.GetCluster().(*RaftCluster).collectStatisticsMetrics() } co.ResetHotSpotMetrics() controller.ResetSchedulerMetrics() - co.GetCluster().(*RaftCluster).resetClusterMetrics() + co.GetCluster().(*RaftCluster).resetStatisticsMetrics() wg.Wait() } @@ -2537,7 +2538,7 @@ func TestCollectMetrics(t *testing.T) { for i := 0; i < 1000; i++ { co.CollectHotSpotMetrics() controller.CollectSchedulerMetrics() - co.GetCluster().(*RaftCluster).collectClusterMetrics() + co.GetCluster().(*RaftCluster).collectStatisticsMetrics() } stores := co.GetCluster().GetStores() regionStats := co.GetCluster().RegionWriteStats() @@ -2552,7 +2553,7 @@ func TestCollectMetrics(t *testing.T) { re.Equal(status1, status2) co.ResetHotSpotMetrics() controller.ResetSchedulerMetrics() - co.GetCluster().(*RaftCluster).resetClusterMetrics() + co.GetCluster().(*RaftCluster).resetStatisticsMetrics() } func prepare(setCfg func(*sc.ScheduleConfig), setTc func(*testCluster), run func(*schedule.Coordinator), re *require.Assertions) (*testCluster, *schedule.Coordinator, func()) { diff --git a/server/cluster/cluster_worker.go b/server/cluster/cluster_worker.go index a38ae86123f..3a319c48196 100644 --- a/server/cluster/cluster_worker.go +++ b/server/cluster/cluster_worker.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/log" 
"github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" + mcsutils "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/schedule/operator" "github.com/tikv/pd/pkg/statistics/buckets" "github.com/tikv/pd/pkg/utils/logutil" @@ -233,7 +234,7 @@ func (c *RaftCluster) HandleReportBuckets(b *metapb.Buckets) error { if err := c.processReportBuckets(b); err != nil { return err } - if !c.isAPIServiceMode { + if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { c.hotStat.CheckAsync(buckets.NewCheckPeerTask(b)) } return nil diff --git a/server/cluster/scheduling_controller.go b/server/cluster/scheduling_controller.go new file mode 100644 index 00000000000..1c41c830cf6 --- /dev/null +++ b/server/cluster/scheduling_controller.go @@ -0,0 +1,424 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cluster + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/schedule" + "github.com/tikv/pd/pkg/schedule/checker" + "github.com/tikv/pd/pkg/schedule/placement" + "github.com/tikv/pd/pkg/schedule/scatter" + "github.com/tikv/pd/pkg/schedule/schedulers" + "github.com/tikv/pd/pkg/schedule/splitter" + "github.com/tikv/pd/pkg/statistics" + "github.com/tikv/pd/pkg/statistics/buckets" + "github.com/tikv/pd/pkg/statistics/utils" + "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/server/config" +) + +type schedulingController struct { + parentCtx context.Context + ctx context.Context + cancel context.CancelFunc + mu sync.RWMutex + wg sync.WaitGroup + *core.BasicCluster + opt *config.PersistOptions + coordinator *schedule.Coordinator + labelStats *statistics.LabelStatistics + regionStats *statistics.RegionStatistics + hotStat *statistics.HotStat + slowStat *statistics.SlowStat + running bool +} + +func newSchedulingController(parentCtx context.Context) *schedulingController { + ctx, cancel := context.WithCancel(parentCtx) + return &schedulingController{ + parentCtx: parentCtx, + ctx: ctx, + cancel: cancel, + labelStats: statistics.NewLabelStatistics(), + hotStat: statistics.NewHotStat(parentCtx), + slowStat: statistics.NewSlowStat(parentCtx), + } +} + +func (sc *schedulingController) init(basicCluster *core.BasicCluster, opt *config.PersistOptions, coordinator *schedule.Coordinator, ruleManager *placement.RuleManager) { + sc.BasicCluster = basicCluster + sc.opt = opt + sc.coordinator = coordinator + sc.regionStats = statistics.NewRegionStatistics(basicCluster, opt, ruleManager) +} + +func (sc *schedulingController) stopSchedulingJobs() bool { + sc.mu.Lock() + defer sc.mu.Unlock() + if !sc.running { + return false + } + sc.coordinator.Stop() + sc.cancel() + sc.wg.Wait() + sc.running = false + log.Info("scheduling service is stopped") + return true +} + +func (sc *schedulingController) startSchedulingJobs() bool { + sc.mu.Lock() + defer sc.mu.Unlock() + if sc.running { + return false + } + sc.ctx, sc.cancel = 
context.WithCancel(sc.parentCtx) + sc.wg.Add(3) + go sc.runCoordinator() + go sc.runStatsBackgroundJobs() + go sc.runSchedulingMetricsCollectionJob() + sc.running = true + log.Info("scheduling service is started") + return true +} + +// runCoordinator runs the main scheduling loop. +func (sc *schedulingController) runCoordinator() { + defer logutil.LogPanic() + defer sc.wg.Done() + sc.coordinator.RunUntilStop() +} + +func (sc *schedulingController) runStatsBackgroundJobs() { + defer logutil.LogPanic() + defer sc.wg.Done() + + ticker := time.NewTicker(statistics.RegionsStatsObserveInterval) + defer ticker.Stop() + + for _, store := range sc.GetStores() { + storeID := store.GetID() + sc.hotStat.GetOrCreateRollingStoreStats(storeID) + } + for { + select { + case <-sc.ctx.Done(): + log.Info("statistics background jobs has been stopped") + return + case <-ticker.C: + sc.hotStat.ObserveRegionsStats(sc.GetStoresWriteRate()) + } + } +} + +func (sc *schedulingController) runSchedulingMetricsCollectionJob() { + defer logutil.LogPanic() + defer sc.wg.Done() + + ticker := time.NewTicker(metricsCollectionJobInterval) + failpoint.Inject("highFrequencyClusterJobs", func() { + ticker.Stop() + ticker = time.NewTicker(time.Microsecond) + }) + defer ticker.Stop() + + for { + select { + case <-sc.ctx.Done(): + log.Info("scheduling metrics are reset") + sc.resetSchedulingMetrics() + log.Info("scheduling metrics collection job has been stopped") + return + case <-ticker.C: + sc.collectSchedulingMetrics() + } + } +} + +func (sc *schedulingController) resetSchedulingMetrics() { + statistics.Reset() + sc.coordinator.GetSchedulersController().ResetSchedulerMetrics() + sc.coordinator.ResetHotSpotMetrics() + sc.resetStatisticsMetrics() +} + +func (sc *schedulingController) collectSchedulingMetrics() { + statsMap := statistics.NewStoreStatisticsMap(sc.opt) + stores := sc.GetStores() + for _, s := range stores { + statsMap.Observe(s) + statsMap.ObserveHotStat(s, sc.hotStat.StoresStats) + } + statsMap.Collect() + sc.coordinator.GetSchedulersController().CollectSchedulerMetrics() + sc.coordinator.CollectHotSpotMetrics() + sc.collectStatisticsMetrics() +} + +func (sc *schedulingController) resetStatisticsMetrics() { + if sc.regionStats == nil { + return + } + sc.regionStats.Reset() + sc.labelStats.Reset() + // reset hot cache metrics + sc.hotStat.ResetMetrics() +} + +func (sc *schedulingController) collectStatisticsMetrics() { + if sc.regionStats == nil { + return + } + sc.regionStats.Collect() + sc.labelStats.Collect() + // collect hot cache metrics + sc.hotStat.CollectMetrics() +} + +func (sc *schedulingController) removeStoreStatistics(storeID uint64) { + sc.hotStat.RemoveRollingStoreStats(storeID) + sc.slowStat.RemoveSlowStoreStatus(storeID) +} + +func (sc *schedulingController) updateStoreStatistics(storeID uint64, isSlow bool) { + sc.hotStat.GetOrCreateRollingStoreStats(storeID) + sc.slowStat.ObserveSlowStoreStatus(storeID, isSlow) +} + +// GetHotStat gets hot stat. +func (sc *schedulingController) GetHotStat() *statistics.HotStat { + return sc.hotStat +} + +// GetRegionStats gets region statistics. +func (sc *schedulingController) GetRegionStats() *statistics.RegionStatistics { + return sc.regionStats +} + +// GetLabelStats gets label statistics. +func (sc *schedulingController) GetLabelStats() *statistics.LabelStatistics { + return sc.labelStats +} + +// GetRegionStatsByType gets the status of the region by types. 
+func (sc *schedulingController) GetRegionStatsByType(typ statistics.RegionStatisticType) []*core.RegionInfo { + if sc.regionStats == nil { + return nil + } + return sc.regionStats.GetRegionStatsByType(typ) +} + +// UpdateRegionsLabelLevelStats updates the status of the region label level by types. +func (sc *schedulingController) UpdateRegionsLabelLevelStats(regions []*core.RegionInfo) { + for _, region := range regions { + sc.labelStats.Observe(region, sc.getStoresWithoutLabelLocked(region, core.EngineKey, core.EngineTiFlash), sc.opt.GetLocationLabels()) + } +} + +func (sc *schedulingController) getStoresWithoutLabelLocked(region *core.RegionInfo, key, value string) []*core.StoreInfo { + stores := make([]*core.StoreInfo, 0, len(region.GetPeers())) + for _, p := range region.GetPeers() { + if store := sc.GetStore(p.StoreId); store != nil && !core.IsStoreContainLabel(store.GetMeta(), key, value) { + stores = append(stores, store) + } + } + return stores +} + +// GetStoresStats returns stores' statistics from cluster. +// And it will be unnecessary to filter unhealthy store, because it has been solved in process heartbeat +func (sc *schedulingController) GetStoresStats() *statistics.StoresStats { + return sc.hotStat.StoresStats +} + +// GetStoresLoads returns load stats of all stores. +func (sc *schedulingController) GetStoresLoads() map[uint64][]float64 { + return sc.hotStat.GetStoresLoads() +} + +// IsRegionHot checks if a region is in hot state. +func (sc *schedulingController) IsRegionHot(region *core.RegionInfo) bool { + return sc.hotStat.IsRegionHot(region, sc.opt.GetHotRegionCacheHitsThreshold()) +} + +// GetHotPeerStat returns hot peer stat with specified regionID and storeID. +func (sc *schedulingController) GetHotPeerStat(rw utils.RWType, regionID, storeID uint64) *statistics.HotPeerStat { + return sc.hotStat.GetHotPeerStat(rw, regionID, storeID) +} + +// RegionReadStats returns hot region's read stats. +// The result only includes peers that are hot enough. +// RegionStats is a thread-safe method +func (sc *schedulingController) RegionReadStats() map[uint64][]*statistics.HotPeerStat { + // As read stats are reported by store heartbeat, the threshold needs to be adjusted. + threshold := sc.opt.GetHotRegionCacheHitsThreshold() * + (utils.RegionHeartBeatReportInterval / utils.StoreHeartBeatReportInterval) + return sc.hotStat.RegionStats(utils.Read, threshold) +} + +// RegionWriteStats returns hot region's write stats. +// The result only includes peers that are hot enough. +func (sc *schedulingController) RegionWriteStats() map[uint64][]*statistics.HotPeerStat { + // RegionStats is a thread-safe method + return sc.hotStat.RegionStats(utils.Write, sc.opt.GetHotRegionCacheHitsThreshold()) +} + +// BucketsStats returns hot region's buckets stats. +func (sc *schedulingController) BucketsStats(degree int, regionIDs ...uint64) map[uint64][]*buckets.BucketStat { + return sc.hotStat.BucketsStats(degree, regionIDs...) +} + +// GetPausedSchedulerDelayAt returns DelayAt of a paused scheduler +func (sc *schedulingController) GetPausedSchedulerDelayAt(name string) (int64, error) { + return sc.coordinator.GetSchedulersController().GetPausedSchedulerDelayAt(name) +} + +// GetPausedSchedulerDelayUntil returns DelayUntil of a paused scheduler +func (sc *schedulingController) GetPausedSchedulerDelayUntil(name string) (int64, error) { + return sc.coordinator.GetSchedulersController().GetPausedSchedulerDelayUntil(name) +} + +// GetRegionScatterer returns the region scatter. 
+func (sc *schedulingController) GetRegionScatterer() *scatter.RegionScatterer { + return sc.coordinator.GetRegionScatterer() +} + +// GetRegionSplitter returns the region splitter +func (sc *schedulingController) GetRegionSplitter() *splitter.RegionSplitter { + return sc.coordinator.GetRegionSplitter() +} + +// GetMergeChecker returns merge checker. +func (sc *schedulingController) GetMergeChecker() *checker.MergeChecker { + return sc.coordinator.GetMergeChecker() +} + +// GetRuleChecker returns rule checker. +func (sc *schedulingController) GetRuleChecker() *checker.RuleChecker { + return sc.coordinator.GetRuleChecker() +} + +// GetSchedulers gets all schedulers. +func (sc *schedulingController) GetSchedulers() []string { + return sc.coordinator.GetSchedulersController().GetSchedulerNames() +} + +// GetSchedulerHandlers gets all scheduler handlers. +func (sc *schedulingController) GetSchedulerHandlers() map[string]http.Handler { + return sc.coordinator.GetSchedulersController().GetSchedulerHandlers() +} + +// AddSchedulerHandler adds a scheduler handler. +func (sc *schedulingController) AddSchedulerHandler(scheduler schedulers.Scheduler, args ...string) error { + return sc.coordinator.GetSchedulersController().AddSchedulerHandler(scheduler, args...) +} + +// RemoveSchedulerHandler removes a scheduler handler. +func (sc *schedulingController) RemoveSchedulerHandler(name string) error { + return sc.coordinator.GetSchedulersController().RemoveSchedulerHandler(name) +} + +// AddScheduler adds a scheduler. +func (sc *schedulingController) AddScheduler(scheduler schedulers.Scheduler, args ...string) error { + return sc.coordinator.GetSchedulersController().AddScheduler(scheduler, args...) +} + +// RemoveScheduler removes a scheduler. +func (sc *schedulingController) RemoveScheduler(name string) error { + return sc.coordinator.GetSchedulersController().RemoveScheduler(name) +} + +// PauseOrResumeScheduler pauses or resumes a scheduler. +func (sc *schedulingController) PauseOrResumeScheduler(name string, t int64) error { + return sc.coordinator.GetSchedulersController().PauseOrResumeScheduler(name, t) +} + +// PauseOrResumeChecker pauses or resumes checker. +func (sc *schedulingController) PauseOrResumeChecker(name string, t int64) error { + return sc.coordinator.PauseOrResumeChecker(name, t) +} + +// AddSuspectRegions adds regions to suspect list. +func (sc *schedulingController) AddSuspectRegions(regionIDs ...uint64) { + sc.coordinator.GetCheckerController().AddSuspectRegions(regionIDs...) +} + +// GetSuspectRegions gets all suspect regions. +func (sc *schedulingController) GetSuspectRegions() []uint64 { + return sc.coordinator.GetCheckerController().GetSuspectRegions() +} + +// RemoveSuspectRegion removes region from suspect list. +func (sc *schedulingController) RemoveSuspectRegion(id uint64) { + sc.coordinator.GetCheckerController().RemoveSuspectRegion(id) +} + +// PopOneSuspectKeyRange gets one suspect keyRange group. +// it would return value and true if pop success, or return empty [][2][]byte and false +// if suspectKeyRanges couldn't pop keyRange group. 
+func (sc *schedulingController) PopOneSuspectKeyRange() ([2][]byte, bool) { + return sc.coordinator.GetCheckerController().PopOneSuspectKeyRange() +} + +// ClearSuspectKeyRanges clears the suspect keyRanges, only for unit test +func (sc *schedulingController) ClearSuspectKeyRanges() { + sc.coordinator.GetCheckerController().ClearSuspectKeyRanges() +} + +// AddSuspectKeyRange adds the key range with the its ruleID as the key +// The instance of each keyRange is like following format: +// [2][]byte: start key/end key +func (sc *schedulingController) AddSuspectKeyRange(start, end []byte) { + sc.coordinator.GetCheckerController().AddSuspectKeyRange(start, end) +} + +func (sc *schedulingController) initSchedulers() { + sc.coordinator.InitSchedulers(false) +} + +func (sc *schedulingController) getEvictLeaderStores() (evictStores []uint64) { + if sc.coordinator == nil { + return nil + } + handler, ok := sc.coordinator.GetSchedulersController().GetSchedulerHandlers()[schedulers.EvictLeaderName] + if !ok { + return + } + type evictLeaderHandler interface { + EvictStoreIDs() []uint64 + } + h, ok := handler.(evictLeaderHandler) + if !ok { + return + } + return h.EvictStoreIDs() +} + +// IsPrepared return true if the prepare checker is ready. +func (sc *schedulingController) IsPrepared() bool { + return sc.coordinator.GetPrepareChecker().IsPrepared() +} + +// SetPrepared set the prepare check to prepared. Only for test purpose. +func (sc *schedulingController) SetPrepared() { + sc.coordinator.GetPrepareChecker().SetPrepared() +} diff --git a/server/grpc_service.go b/server/grpc_service.go index 05ec38919cb..34741d4da5b 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -1002,7 +1002,7 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear s.handleDamagedStore(request.GetStats()) storeHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds()) - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(utils.SchedulingServiceName) { forwardCli, _ := s.updateSchedulingClient(ctx) if forwardCli != nil { req := &schedulingpb.StoreHeartbeatRequest{ @@ -1360,7 +1360,7 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error continue } - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(utils.SchedulingServiceName) { ctx := stream.Context() primaryAddr, _ := s.GetServicePrimaryAddr(ctx, utils.SchedulingServiceName) if schedulingStream == nil || lastPrimaryAddr != primaryAddr { @@ -1632,7 +1632,7 @@ func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest // AskBatchSplit implements gRPC PDServer. func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(utils.SchedulingServiceName) { forwardCli, err := s.updateSchedulingClient(ctx) if err != nil { return &pdpb.AskBatchSplitResponse{ @@ -1805,7 +1805,7 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus // ScatterRegion implements gRPC PDServer. 
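getEvictLeaderStores above extracts the evicted store IDs by asserting the registered handler against a small, locally declared interface rather than importing the concrete scheduler type. A standalone sketch of that probing pattern (all names here are invented for illustration):

package main

import "fmt"

// evictLeaderScheduler stands in for whatever handler was registered;
// the type name and field are hypothetical.
type evictLeaderScheduler struct{ stores []uint64 }

func (s *evictLeaderScheduler) EvictStoreIDs() []uint64 { return s.stores }

// evictStoreIDs probes an arbitrary handler for the one method it needs,
// so the caller never depends on the concrete scheduler package.
func evictStoreIDs(h interface{}) []uint64 {
	type evictLeaderHandler interface {
		EvictStoreIDs() []uint64
	}
	e, ok := h.(evictLeaderHandler)
	if !ok {
		return nil
	}
	return e.EvictStoreIDs()
}

func main() {
	fmt.Println(evictStoreIDs(&evictLeaderScheduler{stores: []uint64{1, 4}})) // [1 4]
	fmt.Println(evictStoreIDs("not an evict-leader handler"))                 // []
}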
func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) { - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(utils.SchedulingServiceName) { forwardCli, err := s.updateSchedulingClient(ctx) if err != nil { return &pdpb.ScatterRegionResponse{ @@ -2028,7 +2028,7 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb // GetOperator gets information about the operator belonging to the specify region. func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) { - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(utils.SchedulingServiceName) { forwardCli, err := s.updateSchedulingClient(ctx) if err != nil { return &pdpb.GetOperatorResponse{ @@ -2300,7 +2300,7 @@ func (s *GrpcServer) SyncMaxTS(_ context.Context, request *pdpb.SyncMaxTSRequest // SplitRegions split regions by the given split keys func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) { - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(utils.SchedulingServiceName) { forwardCli, err := s.updateSchedulingClient(ctx) if err != nil { return &pdpb.SplitRegionsResponse{ diff --git a/server/handler.go b/server/handler.go index dc4b43238d0..6c0679bd9f9 100644 --- a/server/handler.go +++ b/server/handler.go @@ -30,6 +30,7 @@ import ( "github.com/tikv/pd/pkg/core/storelimit" "github.com/tikv/pd/pkg/encryption" "github.com/tikv/pd/pkg/errs" + mcsutils "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/schedule" sc "github.com/tikv/pd/pkg/schedule/config" sche "github.com/tikv/pd/pkg/schedule/core" @@ -192,7 +193,7 @@ func (h *Handler) AddScheduler(name string, args ...string) error { } var removeSchedulerCb func(string) error - if h.s.IsAPIServiceMode() { + if c.IsServiceIndependent(mcsutils.SchedulingServiceName) { removeSchedulerCb = c.GetCoordinator().GetSchedulersController().RemoveSchedulerHandler } else { removeSchedulerCb = c.GetCoordinator().GetSchedulersController().RemoveScheduler @@ -202,7 +203,7 @@ func (h *Handler) AddScheduler(name string, args ...string) error { return err } log.Info("create scheduler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", args)) - if h.s.IsAPIServiceMode() { + if c.IsServiceIndependent(mcsutils.SchedulingServiceName) { if err = c.AddSchedulerHandler(s, args...); err != nil { log.Error("can not add scheduler handler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", args), errs.ZapError(err)) return err @@ -229,7 +230,7 @@ func (h *Handler) RemoveScheduler(name string) error { if err != nil { return err } - if h.s.IsAPIServiceMode() { + if c.IsServiceIndependent(mcsutils.SchedulingServiceName) { if err = c.RemoveSchedulerHandler(name); err != nil { log.Error("can not remove scheduler handler", zap.String("scheduler-name", name), errs.ZapError(err)) } else { diff --git a/server/server.go b/server/server.go index 9cd7f18578e..a2c99d0cbec 100644 --- a/server/server.go +++ b/server/server.go @@ -489,7 +489,7 @@ func (s *Server) startServer(ctx context.Context) error { s.safePointV2Manager = gc.NewSafePointManagerV2(s.ctx, s.storage, s.storage, s.storage) s.hbStreams = hbstream.NewHeartbeatStreams(ctx, s.clusterID, "", s.cluster) // initial hot_region_storage in here. 
- if !s.IsAPIServiceMode() { + if !s.IsServiceIndependent(mcs.SchedulingServiceName) { s.hotRegionStorage, err = storage.NewHotRegionsStorage( ctx, filepath.Join(s.cfg.DataDir, "hot-region"), s.encryptionKeyManager, s.handler) if err != nil { @@ -1394,6 +1394,15 @@ func (s *Server) GetRegions() []*core.RegionInfo { return nil } +// IsServiceIndependent returns if the service is enabled +func (s *Server) IsServiceIndependent(name string) bool { + rc := s.GetRaftCluster() + if rc != nil { + return rc.IsServiceIndependent(name) + } + return false +} + // GetServiceLabels returns ApiAccessPaths by given service label // TODO: this function will be used for updating api rate limit config func (s *Server) GetServiceLabels(serviceLabel string) []apiutil.AccessPath { diff --git a/tests/integrations/mcs/scheduling/config_test.go b/tests/integrations/mcs/scheduling/config_test.go index 8b8e284f765..42ba051eb84 100644 --- a/tests/integrations/mcs/scheduling/config_test.go +++ b/tests/integrations/mcs/scheduling/config_test.go @@ -133,7 +133,6 @@ func persistConfig(re *require.Assertions, pdLeaderServer *tests.TestServer) { func (suite *configTestSuite) TestSchedulerConfigWatch() { re := suite.Require() - // Make sure the config is persisted before the watcher is created. persistConfig(re, suite.pdLeaderServer) // Create a config watcher. diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index 701eb9b5d69..ebf0a4e574d 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -815,7 +815,7 @@ func TestLoadClusterInfo(t *testing.T) { rc := cluster.NewRaftCluster(ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) // Cluster is not bootstrapped. - rc.InitCluster(svr.GetAllocator(), svr.GetPersistOptions(), svr.GetStorage(), svr.GetBasicCluster(), svr.GetKeyspaceGroupManager()) + rc.InitCluster(svr.GetAllocator(), svr.GetPersistOptions(), svr.GetStorage(), svr.GetBasicCluster(), svr.GetHBStreams(), svr.GetKeyspaceGroupManager()) raftCluster, err := rc.LoadClusterInfo() re.NoError(err) re.Nil(raftCluster) @@ -853,7 +853,7 @@ func TestLoadClusterInfo(t *testing.T) { re.NoError(testStorage.Flush()) raftCluster = cluster.NewRaftCluster(ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) - raftCluster.InitCluster(mockid.NewIDAllocator(), svr.GetPersistOptions(), testStorage, basicCluster, svr.GetKeyspaceGroupManager()) + raftCluster.InitCluster(mockid.NewIDAllocator(), svr.GetPersistOptions(), testStorage, basicCluster, svr.GetHBStreams(), svr.GetKeyspaceGroupManager()) raftCluster, err = raftCluster.LoadClusterInfo() re.NoError(err) re.NotNil(raftCluster) @@ -1561,7 +1561,7 @@ func TestTransferLeaderBack(t *testing.T) { leaderServer := tc.GetLeaderServer() svr := leaderServer.GetServer() rc := cluster.NewRaftCluster(ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) - rc.InitCluster(svr.GetAllocator(), svr.GetPersistOptions(), svr.GetStorage(), svr.GetBasicCluster(), svr.GetKeyspaceGroupManager()) + rc.InitCluster(svr.GetAllocator(), svr.GetPersistOptions(), svr.GetStorage(), svr.GetBasicCluster(), svr.GetHBStreams(), svr.GetKeyspaceGroupManager()) storage := rc.GetStorage() meta := &metapb.Cluster{Id: 123} re.NoError(storage.SaveMeta(meta)) From e6e35fdd4eb5b77a5d8b9a5bb96144be3c1461e9 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Mon, 13 Nov 2023 12:15:43 +0800 Subject: [PATCH 21/26] api: add rule middleware (#7357) 
ref tikv/pd#5839 Signed-off-by: Cabinfever_B Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- server/api/router.go | 44 ++++----- server/api/rule.go | 220 +++++++++++-------------------------------- 2 files changed, 80 insertions(+), 184 deletions(-) diff --git a/server/api/router.go b/server/api/router.go index 0473e3e1bf7..d3c8f10cbf2 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -174,29 +174,31 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(apiRouter, "/config/replication-mode", confHandler.SetReplicationModeConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) rulesHandler := newRulesHandler(svr, rd) - registerFunc(clusterRouter, "/config/rules", rulesHandler.GetAllRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rules", rulesHandler.SetAllRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rules/batch", rulesHandler.BatchRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rules/group/{group}", rulesHandler.GetRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rules/region/{region}", rulesHandler.GetRulesByRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rules/region/{region}/detail", rulesHandler.CheckRegionPlacementRule, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rules/key/{key}", rulesHandler.GetRulesByKey, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rule/{group}/{id}", rulesHandler.GetRuleByGroupAndID, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rule", rulesHandler.SetRule, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rule/{group}/{id}", rulesHandler.DeleteRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) - - registerFunc(clusterRouter, "/config/rule_group/{id}", rulesHandler.GetGroupConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/rule_group", rulesHandler.SetGroupConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rule_group/{id}", rulesHandler.DeleteGroupConfig, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/config/rule_groups", rulesHandler.GetAllGroupConfigs, setMethods(http.MethodGet), setAuditBackend(prometheus)) - - registerFunc(clusterRouter, "/config/placement-rule", rulesHandler.GetPlacementRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/placement-rule", rulesHandler.SetPlacementRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + ruleRouter := clusterRouter.NewRoute().Subrouter() + ruleRouter.Use(newRuleMiddleware(svr, rd).Middleware) + registerFunc(ruleRouter, "/config/rules", rulesHandler.GetAllRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rules", rulesHandler.SetAllRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rules/batch", rulesHandler.BatchRules, 
setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rules/group/{group}", rulesHandler.GetRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rules/region/{region}", rulesHandler.GetRulesByRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rules/region/{region}/detail", rulesHandler.CheckRegionPlacementRule, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rules/key/{key}", rulesHandler.GetRulesByKey, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rule/{group}/{id}", rulesHandler.GetRuleByGroupAndID, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rule", rulesHandler.SetRule, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rule/{group}/{id}", rulesHandler.DeleteRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + + registerFunc(ruleRouter, "/config/rule_group/{id}", rulesHandler.GetGroupConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/rule_group", rulesHandler.SetGroupConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rule_group/{id}", rulesHandler.DeleteGroupConfig, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/rule_groups", rulesHandler.GetAllGroupConfigs, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + registerFunc(ruleRouter, "/config/placement-rule", rulesHandler.GetPlacementRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/placement-rule", rulesHandler.SetPlacementRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) // {group} can be a regular expression, we should enable path encode to // support special characters. 
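The change above moves every placement-rule route onto a dedicated subrouter so a single middleware can guard them all. A rough gorilla/mux sketch of that wiring, with a placeholder middleware and a hypothetical route, might look like this:

package main

import (
	"fmt"
	"net/http"

	"github.com/gorilla/mux"
)

func main() {
	root := mux.NewRouter()

	// A dedicated subrouter: the middleware runs only for routes registered
	// on it, leaving the rest of the API untouched.
	ruleRouter := root.NewRoute().Subrouter()
	ruleRouter.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Placeholder for the real precondition check (e.g. rule manager availability).
			w.Header().Set("X-Rule-Checked", "true")
			next.ServeHTTP(w, r)
		})
	})
	ruleRouter.HandleFunc("/config/rules", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "[]")
	}).Methods(http.MethodGet)

	_ = http.ListenAndServe("127.0.0.1:8080", root)
}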
- registerFunc(clusterRouter, "/config/placement-rule/{group}", rulesHandler.GetPlacementRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/config/placement-rule/{group}", rulesHandler.SetPlacementRuleByGroup, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(escapeRouter, "/config/placement-rule/{group}", rulesHandler.DeletePlacementRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/placement-rule/{group}", rulesHandler.GetPlacementRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(ruleRouter, "/config/placement-rule/{group}", rulesHandler.SetPlacementRuleByGroup, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(ruleRouter, "/config/placement-rule/{group}", rulesHandler.DeletePlacementRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) regionLabelHandler := newRegionLabelHandler(svr, rd) registerFunc(clusterRouter, "/config/region-label/rules", regionLabelHandler.GetAllRegionLabelRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) diff --git a/server/api/rule.go b/server/api/rule.go index 77aad42eb42..47964d594be 100644 --- a/server/api/rule.go +++ b/server/api/rule.go @@ -15,6 +15,7 @@ package api import ( + "context" "encoding/hex" "fmt" "net/http" @@ -42,6 +43,42 @@ func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { } } +type ruleMiddleware struct { + s *server.Server + rd *render.Render + *server.Handler +} + +func newRuleMiddleware(s *server.Server, rd *render.Render) ruleMiddleware { + return ruleMiddleware{ + s: s, + rd: rd, + Handler: s.GetHandler(), + } +} + +func (m ruleMiddleware) Middleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + manager, err := m.GetRuleManager() + if err == errs.ErrPlacementDisabled { + m.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + m.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + ctx := context.WithValue(r.Context(), ruleCtxKey{}, manager) + h.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +type ruleCtxKey struct{} + +func getRuleManager(r *http.Request) *placement.RuleManager { + return r.Context().Value(ruleCtxKey{}).(*placement.RuleManager) +} + // @Tags rule // @Summary List all rules of cluster. // @Produce json @@ -50,15 +87,7 @@ func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules [get] func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) rules := manager.GetAllRules() h.rd.JSON(w, http.StatusOK, rules) } @@ -73,15 +102,7 @@ func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." 
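The ruleMiddleware above resolves the rule manager once per request, fails fast when placement rules are disabled, and stores the manager in the request context so each handler can fetch it with getRuleManager instead of repeating the checks. A self-contained sketch of the same context-injection pattern using only the standard library (ruleManager, withManager, and managerFrom are invented names):

package main

import (
	"context"
	"fmt"
	"net/http"
)

// managerCtxKey is an unexported key type, so the stored value cannot collide
// with context values set by other packages.
type managerCtxKey struct{}

// ruleManager is a stand-in for the real manager type.
type ruleManager struct{ rules []string }

// withManager resolves the manager once per request, rejects the request
// early if it is unavailable, and otherwise injects it into the context.
func withManager(get func() (*ruleManager, error), next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		m, err := get()
		if err != nil {
			http.Error(w, err.Error(), http.StatusPreconditionFailed)
			return
		}
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), managerCtxKey{}, m)))
	})
}

// managerFrom is what each handler calls instead of repeating the checks.
func managerFrom(r *http.Request) *ruleManager {
	return r.Context().Value(managerCtxKey{}).(*ruleManager)
}

func main() {
	get := func() (*ruleManager, error) { return &ruleManager{rules: []string{"pd", "default"}}, nil }
	listRules := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, managerFrom(r).rules)
	})
	http.Handle("/config/rules", withManager(get, listRules))
	_ = http.ListenAndServe("127.0.0.1:8081", nil)
}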
// @Router /config/rules [post] func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var rules []*placement.Rule if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &rules); err != nil { return @@ -113,15 +134,7 @@ func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/group/{group} [get] func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) group := mux.Vars(r)["group"] rules := manager.GetRulesByGroup(group) h.rd.JSON(w, http.StatusOK, rules) @@ -138,15 +151,7 @@ func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/region/{region} [get] func (h *ruleHandler) GetRulesByRegion(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) regionStr := mux.Vars(r)["region"] region, code, err := h.PreCheckForRegion(regionStr) if err != nil { @@ -196,15 +201,7 @@ func (h *ruleHandler) CheckRegionPlacementRule(w http.ResponseWriter, r *http.Re // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/key/{key} [get] func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) keyHex := mux.Vars(r)["key"] key, err := hex.DecodeString(keyHex) if err != nil { @@ -225,15 +222,7 @@ func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { // @Failure 412 {string} string "Placement rules feature is disabled." // @Router /config/rule/{group}/{id} [get] func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) group, id := mux.Vars(r)["group"], mux.Vars(r)["id"] rule := manager.GetRule(group, id) if rule == nil { @@ -254,15 +243,7 @@ func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request // @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rule [post] func (h *ruleHandler) SetRule(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var rule placement.Rule if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &rule); err != nil { return @@ -312,15 +293,7 @@ func (h *ruleHandler) syncReplicateConfigWithDefaultRule(rule *placement.Rule) e // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule/{group}/{id} [delete] func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) group, id := mux.Vars(r)["group"], mux.Vars(r)["id"] rule := manager.GetRule(group, id) if err := manager.DeleteRule(group, id); err != nil { @@ -345,15 +318,7 @@ func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/batch [post] func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var opts []placement.RuleOp if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &opts); err != nil { return @@ -380,15 +345,7 @@ func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group/{id} [get] func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) id := mux.Vars(r)["id"] group := manager.GetRuleGroup(id) if group == nil { @@ -409,15 +366,7 @@ func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group [post] func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var ruleGroup placement.RuleGroup if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &ruleGroup); err != nil { return @@ -442,17 +391,9 @@ func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rule_group/{id} [delete] func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) id := mux.Vars(r)["id"] - err = manager.DeleteRuleGroup(id) + err := manager.DeleteRuleGroup(id) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return @@ -472,15 +413,7 @@ func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_groups [get] func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) ruleGroups := manager.GetRuleGroups() h.rd.JSON(w, http.StatusOK, ruleGroups) } @@ -493,15 +426,7 @@ func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [get] func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) bundles := manager.GetAllGroupBundles() h.rd.JSON(w, http.StatusOK, bundles) } @@ -516,15 +441,7 @@ func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [post] func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) var groups []placement.GroupBundle if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &groups); err != nil { return @@ -551,15 +468,7 @@ func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule/{group} [get] func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) g := mux.Vars(r)["group"] group := manager.GetGroupBundle(g) h.rd.JSON(w, http.StatusOK, group) @@ -576,16 +485,9 @@ func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Req // @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/placement-rule [delete] func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) group := mux.Vars(r)["group"] + var err error group, err = url.PathUnescape(group) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -608,15 +510,7 @@ func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http. // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule/{group} [post] func (h *ruleHandler) SetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - manager, err := h.Handler.GetRuleManager() - if err == errs.ErrPlacementDisabled { - h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) - return - } - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) - return - } + manager := getRuleManager(r) groupID := mux.Vars(r)["group"] var group placement.GroupBundle if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &group); err != nil { From 7dbe6079ca2f303a2f5886ee7e8aac6c54c2532d Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 13 Nov 2023 13:41:13 +0800 Subject: [PATCH 22/26] client: introduce the HTTP client (#7304) ref tikv/pd#7300 Introduce the HTTP client. Signed-off-by: JmPotato Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/client.go | 4 +- client/http/api.go | 54 +++ client/http/client.go | 337 ++++++++++++++++++ client/http/types.go | 178 +++++++++ client/http/types_test.go | 49 +++ tests/integrations/client/http_client_test.go | 87 +++++ 6 files changed, 707 insertions(+), 2 deletions(-) create mode 100644 client/http/api.go create mode 100644 client/http/client.go create mode 100644 client/http/types.go create mode 100644 client/http/types_test.go create mode 100644 tests/integrations/client/http_client_test.go diff --git a/client/client.go b/client/client.go index 067872d2d39..56923b697e2 100644 --- a/client/client.go +++ b/client/client.go @@ -74,7 +74,7 @@ type GlobalConfigItem struct { PayLoad []byte } -// Client is a PD (Placement Driver) client. +// Client is a PD (Placement Driver) RPC client. // It should not be used after calling Close(). type Client interface { // GetClusterID gets the cluster ID from PD. @@ -1062,7 +1062,7 @@ func (c *client) ScanRegions(ctx context.Context, key, endKey []byte, limit int) defer span.Finish() } start := time.Now() - defer cmdDurationScanRegions.Observe(time.Since(start).Seconds()) + defer func() { cmdDurationScanRegions.Observe(time.Since(start).Seconds()) }() var cancel context.CancelFunc scanCtx := ctx diff --git a/client/http/api.go b/client/http/api.go new file mode 100644 index 00000000000..5326919561d --- /dev/null +++ b/client/http/api.go @@ -0,0 +1,54 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "fmt" + "net/url" +) + +// The following constants are the paths of PD HTTP APIs. +const ( + HotRead = "/pd/api/v1/hotspot/regions/read" + HotWrite = "/pd/api/v1/hotspot/regions/write" + Regions = "/pd/api/v1/regions" + regionByID = "/pd/api/v1/region/id" + regionByKey = "/pd/api/v1/region/key" + regionsByKey = "/pd/api/v1/regions/key" + regionsByStoreID = "/pd/api/v1/regions/store" + Stores = "/pd/api/v1/stores" + MinResolvedTSPrefix = "/pd/api/v1/min-resolved-ts" +) + +// RegionByID returns the path of PD HTTP API to get region by ID. +func RegionByID(regionID uint64) string { + return fmt.Sprintf("%s/%d", regionByID, regionID) +} + +// RegionByKey returns the path of PD HTTP API to get region by key. +func RegionByKey(key []byte) string { + return fmt.Sprintf("%s/%s", regionByKey, url.QueryEscape(string(key))) +} + +// RegionsByKey returns the path of PD HTTP API to scan regions with given start key, end key and limit parameters. +func RegionsByKey(startKey, endKey []byte, limit int) string { + return fmt.Sprintf("%s?start_key=%s&end_key=%s&limit=%d", + regionsByKey, url.QueryEscape(string(startKey)), url.QueryEscape(string(endKey)), limit) +} + +// RegionsByStoreID returns the path of PD HTTP API to get regions by store ID. +func RegionsByStoreID(storeID uint64) string { + return fmt.Sprintf("%s/%d", regionsByStoreID, storeID) +} diff --git a/client/http/client.go b/client/http/client.go new file mode 100644 index 00000000000..6cb1277dfcb --- /dev/null +++ b/client/http/client.go @@ -0,0 +1,337 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/log" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" +) + +const ( + httpScheme = "http" + httpsScheme = "https" + networkErrorStatus = "network error" + + defaultTimeout = 30 * time.Second +) + +// Client is a PD (Placement Driver) HTTP client. 
+type Client interface { + GetRegionByID(context.Context, uint64) (*RegionInfo, error) + GetRegionByKey(context.Context, []byte) (*RegionInfo, error) + GetRegions(context.Context) (*RegionsInfo, error) + GetRegionsByKey(context.Context, []byte, []byte, int) (*RegionsInfo, error) + GetRegionsByStoreID(context.Context, uint64) (*RegionsInfo, error) + GetHotReadRegions(context.Context) (*StoreHotPeersInfos, error) + GetHotWriteRegions(context.Context) (*StoreHotPeersInfos, error) + GetStores(context.Context) (*StoresInfo, error) + GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) + Close() +} + +var _ Client = (*client)(nil) + +type client struct { + pdAddrs []string + tlsConf *tls.Config + cli *http.Client + + requestCounter *prometheus.CounterVec + executionDuration *prometheus.HistogramVec +} + +// ClientOption configures the HTTP client. +type ClientOption func(c *client) + +// WithHTTPClient configures the client with the given initialized HTTP client. +func WithHTTPClient(cli *http.Client) ClientOption { + return func(c *client) { + c.cli = cli + } +} + +// WithTLSConfig configures the client with the given TLS config. +// This option won't work if the client is configured with WithHTTPClient. +func WithTLSConfig(tlsConf *tls.Config) ClientOption { + return func(c *client) { + c.tlsConf = tlsConf + } +} + +// WithMetrics configures the client with metrics. +func WithMetrics( + requestCounter *prometheus.CounterVec, + executionDuration *prometheus.HistogramVec, +) ClientOption { + return func(c *client) { + c.requestCounter = requestCounter + c.executionDuration = executionDuration + } +} + +// NewClient creates a PD HTTP client with the given PD addresses and TLS config. +func NewClient( + pdAddrs []string, + opts ...ClientOption, +) Client { + c := &client{} + // Apply the options first. + for _, opt := range opts { + opt(c) + } + // Normalize the addresses with correct scheme prefix. + for i, addr := range pdAddrs { + if !strings.HasPrefix(addr, httpScheme) { + var scheme string + if c.tlsConf != nil { + scheme = httpsScheme + } else { + scheme = httpScheme + } + pdAddrs[i] = fmt.Sprintf("%s://%s", scheme, addr) + } + } + c.pdAddrs = pdAddrs + // Init the HTTP client if it's not configured. + if c.cli == nil { + c.cli = &http.Client{Timeout: defaultTimeout} + if c.tlsConf != nil { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = c.tlsConf + c.cli.Transport = transport + } + } + + return c +} + +// Close closes the HTTP client. +func (c *client) Close() { + if c.cli != nil { + c.cli.CloseIdleConnections() + } + log.Info("[pd] http client closed") +} + +func (c *client) reqCounter(name, status string) { + if c.requestCounter == nil { + return + } + c.requestCounter.WithLabelValues(name, status).Inc() +} + +func (c *client) execDuration(name string, duration time.Duration) { + if c.executionDuration == nil { + return + } + c.executionDuration.WithLabelValues(name).Observe(duration.Seconds()) +} + +// At present, we will use the retry strategy of polling by default to keep +// it consistent with the current implementation of some clients (e.g. TiDB). 
+func (c *client) requestWithRetry( + ctx context.Context, + name, uri string, + res interface{}, +) error { + var ( + err error + addr string + ) + for idx := 0; idx < len(c.pdAddrs); idx++ { + addr = c.pdAddrs[idx] + err = c.request(ctx, name, addr, uri, res) + if err == nil { + break + } + log.Debug("[pd] request one addr failed", + zap.Int("idx", idx), zap.String("addr", addr), zap.Error(err)) + } + return err +} + +func (c *client) request( + ctx context.Context, + name, addr, uri string, + res interface{}, +) error { + reqURL := fmt.Sprintf("%s%s", addr, uri) + logFields := []zap.Field{ + zap.String("name", name), + zap.String("url", reqURL), + } + log.Debug("[pd] request the http url", logFields...) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + log.Error("[pd] create http request failed", append(logFields, zap.Error(err))...) + return errors.Trace(err) + } + start := time.Now() + resp, err := c.cli.Do(req) + if err != nil { + c.reqCounter(name, networkErrorStatus) + log.Error("[pd] do http request failed", append(logFields, zap.Error(err))...) + return errors.Trace(err) + } + c.execDuration(name, time.Since(start)) + c.reqCounter(name, resp.Status) + defer func() { + err = resp.Body.Close() + if err != nil { + log.Warn("[pd] close http response body failed", append(logFields, zap.Error(err))...) + } + }() + + if resp.StatusCode != http.StatusOK { + logFields = append(logFields, zap.String("status", resp.Status)) + + bs, readErr := io.ReadAll(resp.Body) + if readErr != nil { + logFields = append(logFields, zap.NamedError("read-body-error", err)) + } else { + logFields = append(logFields, zap.ByteString("body", bs)) + } + + log.Error("[pd] request failed with a non-200 status", logFields...) + return errors.Errorf("request pd http api failed with status: '%s'", resp.Status) + } + + err = json.NewDecoder(resp.Body).Decode(res) + if err != nil { + return errors.Trace(err) + } + return nil +} + +// GetRegionByID gets the region info by ID. +func (c *client) GetRegionByID(ctx context.Context, regionID uint64) (*RegionInfo, error) { + var region RegionInfo + err := c.requestWithRetry(ctx, "GetRegionByID", RegionByID(regionID), ®ion) + if err != nil { + return nil, err + } + return ®ion, nil +} + +// GetRegionByKey gets the region info by key. +func (c *client) GetRegionByKey(ctx context.Context, key []byte) (*RegionInfo, error) { + var region RegionInfo + err := c.requestWithRetry(ctx, "GetRegionByKey", RegionByKey(key), ®ion) + if err != nil { + return nil, err + } + return ®ion, nil +} + +// GetRegions gets the regions info. +func (c *client) GetRegions(ctx context.Context) (*RegionsInfo, error) { + var regions RegionsInfo + err := c.requestWithRetry(ctx, "GetRegions", Regions, ®ions) + if err != nil { + return nil, err + } + return ®ions, nil +} + +// GetRegionsByKey gets the regions info by key range. If the limit is -1, it will return all regions within the range. +func (c *client) GetRegionsByKey(ctx context.Context, startKey, endKey []byte, limit int) (*RegionsInfo, error) { + var regions RegionsInfo + err := c.requestWithRetry(ctx, "GetRegionsByKey", RegionsByKey(startKey, endKey, limit), ®ions) + if err != nil { + return nil, err + } + return ®ions, nil +} + +// GetRegionsByStoreID gets the regions info by store ID. 
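Putting the pieces above together, a caller constructs the HTTP client with a list of PD endpoints and invokes the typed getters; a minimal usage sketch (the endpoint address is an assumption) is:

package main

import (
	"context"
	"fmt"

	pd "github.com/tikv/pd/client/http"
)

func main() {
	// Assumes a PD instance is reachable at this address; plain addresses are
	// given the http:// scheme automatically when no TLS config is set.
	cli := pd.NewClient([]string{"127.0.0.1:2379"})
	defer cli.Close()

	ctx := context.Background()
	stores, err := cli.GetStores(ctx)
	if err != nil {
		panic(err)
	}
	fmt.Println("store count:", stores.Count)

	// Scan at most 16 regions in ["a", "z"); each configured address is tried
	// in turn until one request succeeds.
	regions, err := cli.GetRegionsByKey(ctx, []byte("a"), []byte("z"), 16)
	if err != nil {
		panic(err)
	}
	fmt.Println("regions returned:", regions.Count)
}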
+func (c *client) GetRegionsByStoreID(ctx context.Context, storeID uint64) (*RegionsInfo, error) { + var regions RegionsInfo + err := c.requestWithRetry(ctx, "GetRegionsByStoreID", RegionsByStoreID(storeID), ®ions) + if err != nil { + return nil, err + } + return ®ions, nil +} + +// GetHotReadRegions gets the hot read region statistics info. +func (c *client) GetHotReadRegions(ctx context.Context) (*StoreHotPeersInfos, error) { + var hotReadRegions StoreHotPeersInfos + err := c.requestWithRetry(ctx, "GetHotReadRegions", HotRead, &hotReadRegions) + if err != nil { + return nil, err + } + return &hotReadRegions, nil +} + +// GetHotWriteRegions gets the hot write region statistics info. +func (c *client) GetHotWriteRegions(ctx context.Context) (*StoreHotPeersInfos, error) { + var hotWriteRegions StoreHotPeersInfos + err := c.requestWithRetry(ctx, "GetHotWriteRegions", HotWrite, &hotWriteRegions) + if err != nil { + return nil, err + } + return &hotWriteRegions, nil +} + +// GetStores gets the stores info. +func (c *client) GetStores(ctx context.Context) (*StoresInfo, error) { + var stores StoresInfo + err := c.requestWithRetry(ctx, "GetStores", Stores, &stores) + if err != nil { + return nil, err + } + return &stores, nil +} + +// GetMinResolvedTSByStoresIDs get min-resolved-ts by stores IDs. +func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uint64) (uint64, map[uint64]uint64, error) { + uri := MinResolvedTSPrefix + // scope is an optional parameter, it can be `cluster` or specified store IDs. + // - When no scope is given, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be nil. + // - When scope is `cluster`, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be filled. + // - When scope given a list of stores, min_resolved_ts will be provided for each store + // and the scope-specific min_resolved_ts will be returned. + if len(storeIDs) != 0 { + storeIDStrs := make([]string, len(storeIDs)) + for idx, id := range storeIDs { + storeIDStrs[idx] = fmt.Sprintf("%d", id) + } + uri = fmt.Sprintf("%s?scope=%s", uri, strings.Join(storeIDStrs, ",")) + } + resp := struct { + MinResolvedTS uint64 `json:"min_resolved_ts"` + IsRealTime bool `json:"is_real_time,omitempty"` + StoresMinResolvedTS map[uint64]uint64 `json:"stores_min_resolved_ts"` + }{} + err := c.requestWithRetry(ctx, "GetMinResolvedTSByStoresIDs", uri, &resp) + if err != nil { + return 0, nil, err + } + if !resp.IsRealTime { + return 0, nil, errors.Trace(errors.New("min resolved ts is not enabled")) + } + return resp.MinResolvedTS, resp.StoresMinResolvedTS, nil +} diff --git a/client/http/types.go b/client/http/types.go new file mode 100644 index 00000000000..66eb31ec3a1 --- /dev/null +++ b/client/http/types.go @@ -0,0 +1,178 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import "time" + +// NOTICE: the structures below are copied from the PD API definitions. 
+// Please make sure the consistency if any change happens to the PD API. + +// RegionInfo stores the information of one region. +type RegionInfo struct { + ID int64 `json:"id"` + StartKey string `json:"start_key"` + EndKey string `json:"end_key"` + Epoch RegionEpoch `json:"epoch"` + Peers []RegionPeer `json:"peers"` + Leader RegionPeer `json:"leader"` + DownPeers []RegionPeerStat `json:"down_peers"` + PendingPeers []RegionPeer `json:"pending_peers"` + WrittenBytes uint64 `json:"written_bytes"` + ReadBytes uint64 `json:"read_bytes"` + ApproximateSize int64 `json:"approximate_size"` + ApproximateKeys int64 `json:"approximate_keys"` + + ReplicationStatus *ReplicationStatus `json:"replication_status,omitempty"` +} + +// GetStartKey gets the start key of the region. +func (r *RegionInfo) GetStartKey() string { return r.StartKey } + +// GetEndKey gets the end key of the region. +func (r *RegionInfo) GetEndKey() string { return r.EndKey } + +// RegionEpoch stores the information about its epoch. +type RegionEpoch struct { + ConfVer int64 `json:"conf_ver"` + Version int64 `json:"version"` +} + +// RegionPeer stores information of one peer. +type RegionPeer struct { + ID int64 `json:"id"` + StoreID int64 `json:"store_id"` + IsLearner bool `json:"is_learner"` +} + +// RegionPeerStat stores one field `DownSec` which indicates how long it's down than `RegionPeer`. +type RegionPeerStat struct { + Peer RegionPeer `json:"peer"` + DownSec int64 `json:"down_seconds"` +} + +// ReplicationStatus represents the replication mode status of the region. +type ReplicationStatus struct { + State string `json:"state"` + StateID int64 `json:"state_id"` +} + +// RegionsInfo stores the information of regions. +type RegionsInfo struct { + Count int64 `json:"count"` + Regions []RegionInfo `json:"regions"` +} + +// Merge merges two RegionsInfo together and returns a new one. +func (ri *RegionsInfo) Merge(other *RegionsInfo) *RegionsInfo { + newRegionsInfo := &RegionsInfo{ + Regions: make([]RegionInfo, 0, ri.Count+other.Count), + } + m := make(map[int64]RegionInfo, ri.Count+other.Count) + for _, region := range ri.Regions { + m[region.ID] = region + } + for _, region := range other.Regions { + m[region.ID] = region + } + for _, region := range m { + newRegionsInfo.Regions = append(newRegionsInfo.Regions, region) + } + newRegionsInfo.Count = int64(len(newRegionsInfo.Regions)) + return newRegionsInfo +} + +// StoreHotPeersInfos is used to get human-readable description for hot regions. +type StoreHotPeersInfos struct { + AsPeer StoreHotPeersStat `json:"as_peer"` + AsLeader StoreHotPeersStat `json:"as_leader"` +} + +// StoreHotPeersStat is used to record the hot region statistics group by store. 
+type StoreHotPeersStat map[uint64]*HotPeersStat + +// HotPeersStat records all hot regions statistics +type HotPeersStat struct { + StoreByteRate float64 `json:"store_bytes"` + StoreKeyRate float64 `json:"store_keys"` + StoreQueryRate float64 `json:"store_query"` + TotalBytesRate float64 `json:"total_flow_bytes"` + TotalKeysRate float64 `json:"total_flow_keys"` + TotalQueryRate float64 `json:"total_flow_query"` + Count int `json:"regions_count"` + Stats []HotPeerStatShow `json:"statistics"` +} + +// HotPeerStatShow records the hot region statistics for output +type HotPeerStatShow struct { + StoreID uint64 `json:"store_id"` + Stores []uint64 `json:"stores"` + IsLeader bool `json:"is_leader"` + IsLearner bool `json:"is_learner"` + RegionID uint64 `json:"region_id"` + HotDegree int `json:"hot_degree"` + ByteRate float64 `json:"flow_bytes"` + KeyRate float64 `json:"flow_keys"` + QueryRate float64 `json:"flow_query"` + AntiCount int `json:"anti_count"` + LastUpdateTime time.Time `json:"last_update_time,omitempty"` +} + +// StoresInfo represents the information of all TiKV/TiFlash stores. +type StoresInfo struct { + Count int `json:"count"` + Stores []StoreInfo `json:"stores"` +} + +// StoreInfo represents the information of one TiKV/TiFlash store. +type StoreInfo struct { + Store MetaStore `json:"store"` + Status StoreStatus `json:"status"` +} + +// MetaStore represents the meta information of one store. +type MetaStore struct { + ID int64 `json:"id"` + Address string `json:"address"` + State int64 `json:"state"` + StateName string `json:"state_name"` + Version string `json:"version"` + Labels []StoreLabel `json:"labels"` + StatusAddress string `json:"status_address"` + GitHash string `json:"git_hash"` + StartTimestamp int64 `json:"start_timestamp"` +} + +// StoreLabel stores the information of one store label. +type StoreLabel struct { + Key string `json:"key"` + Value string `json:"value"` +} + +// StoreStatus stores the detail information of one store. +type StoreStatus struct { + Capacity string `json:"capacity"` + Available string `json:"available"` + LeaderCount int64 `json:"leader_count"` + LeaderWeight float64 `json:"leader_weight"` + LeaderScore float64 `json:"leader_score"` + LeaderSize int64 `json:"leader_size"` + RegionCount int64 `json:"region_count"` + RegionWeight float64 `json:"region_weight"` + RegionScore float64 `json:"region_score"` + RegionSize int64 `json:"region_size"` + StartTS time.Time `json:"start_ts"` + LastHeartbeatTS time.Time `json:"last_heartbeat_ts"` + Uptime string `json:"uptime"` +} diff --git a/client/http/types_test.go b/client/http/types_test.go new file mode 100644 index 00000000000..0dfebacbdcf --- /dev/null +++ b/client/http/types_test.go @@ -0,0 +1,49 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package http + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMergeRegionsInfo(t *testing.T) { + re := require.New(t) + regionsInfo1 := &RegionsInfo{ + Count: 1, + Regions: []RegionInfo{ + { + ID: 1, + StartKey: "", + EndKey: "a", + }, + }, + } + regionsInfo2 := &RegionsInfo{ + Count: 1, + Regions: []RegionInfo{ + { + ID: 2, + StartKey: "a", + EndKey: "", + }, + }, + } + regionsInfo := regionsInfo1.Merge(regionsInfo2) + re.Equal(int64(2), regionsInfo.Count) + re.Equal(2, len(regionsInfo.Regions)) + re.Equal(append(regionsInfo1.Regions, regionsInfo2.Regions...), regionsInfo.Regions) +} diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go new file mode 100644 index 00000000000..03d90c6cd32 --- /dev/null +++ b/tests/integrations/client/http_client_test.go @@ -0,0 +1,87 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client_test + +import ( + "context" + "math" + "testing" + + "github.com/stretchr/testify/suite" + pd "github.com/tikv/pd/client/http" + "github.com/tikv/pd/tests" +) + +type httpClientTestSuite struct { + suite.Suite + ctx context.Context + cancelFunc context.CancelFunc + cluster *tests.TestCluster + client pd.Client +} + +func TestHTTPClientTestSuite(t *testing.T) { + suite.Run(t, new(httpClientTestSuite)) +} + +func (suite *httpClientTestSuite) SetupSuite() { + re := suite.Require() + var err error + suite.ctx, suite.cancelFunc = context.WithCancel(context.Background()) + suite.cluster, err = tests.NewTestCluster(suite.ctx, 1) + re.NoError(err) + err = suite.cluster.RunInitialServers() + re.NoError(err) + leader := suite.cluster.WaitLeader() + re.NotEmpty(leader) + err = suite.cluster.GetLeaderServer().BootstrapCluster() + re.NoError(err) + var ( + testServers = suite.cluster.GetServers() + endpoints = make([]string, 0, len(testServers)) + ) + for _, s := range testServers { + endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls) + } + suite.client = pd.NewClient(endpoints) +} + +func (suite *httpClientTestSuite) TearDownSuite() { + suite.cancelFunc() + suite.client.Close() + suite.cluster.Destroy() +} + +func (suite *httpClientTestSuite) TestGetMinResolvedTSByStoresIDs() { + re := suite.Require() + // Get the cluster-level min resolved TS. + minResolvedTS, storeMinResolvedTSMap, err := suite.client.GetMinResolvedTSByStoresIDs(suite.ctx, nil) + re.NoError(err) + re.Greater(minResolvedTS, uint64(0)) + re.Empty(storeMinResolvedTSMap) + // Get the store-level min resolved TS. + minResolvedTS, storeMinResolvedTSMap, err = suite.client.GetMinResolvedTSByStoresIDs(suite.ctx, []uint64{1}) + re.NoError(err) + re.Greater(minResolvedTS, uint64(0)) + re.Len(storeMinResolvedTSMap, 1) + re.Equal(minResolvedTS, storeMinResolvedTSMap[1]) + // Get the store-level min resolved TS with an invalid store ID. 
+ minResolvedTS, storeMinResolvedTSMap, err = suite.client.GetMinResolvedTSByStoresIDs(suite.ctx, []uint64{1, 2}) + re.NoError(err) + re.Greater(minResolvedTS, uint64(0)) + re.Len(storeMinResolvedTSMap, 2) + re.Equal(minResolvedTS, storeMinResolvedTSMap[1]) + re.Equal(uint64(math.MaxUint64), storeMinResolvedTSMap[2]) +} From be31c08186fa2b6c154532d1130e5727a4631473 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 13 Nov 2023 17:24:44 +0800 Subject: [PATCH 23/26] resource_controller: prevent loadServerConfig from panic (#7361) close tikv/pd#7360 Prevent `loadServerConfig` from panic. Signed-off-by: JmPotato --- client/resource_group/controller/controller.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/resource_group/controller/controller.go b/client/resource_group/controller/controller.go index e3495a21ff1..b528351bedf 100755 --- a/client/resource_group/controller/controller.go +++ b/client/resource_group/controller/controller.go @@ -171,12 +171,13 @@ func loadServerConfig(ctx context.Context, provider ResourceGroupProvider) (*Con if err != nil { return nil, err } - if len(resp.Kvs) == 0 { + kvs := resp.GetKvs() + if len(kvs) == 0 { log.Warn("[resource group controller] server does not save config, load config failed") return DefaultConfig(), nil } config := &Config{} - err = json.Unmarshal(resp.Kvs[0].GetValue(), config) + err = json.Unmarshal(kvs[0].GetValue(), config) if err != nil { return nil, err } From 8dcd49720cd9999119d27212220bc0b03f82a75e Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 14 Nov 2023 11:29:14 +0800 Subject: [PATCH 24/26] *: Improve region forward (#7305) ref tikv/pd#5839 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- server/forward.go | 504 +++++++++++++++++++++++ server/grpc_service.go | 891 ++++++++++++----------------------------- 2 files changed, 750 insertions(+), 645 deletions(-) create mode 100644 server/forward.go diff --git a/server/forward.go b/server/forward.go new file mode 100644 index 00000000000..e765d442539 --- /dev/null +++ b/server/forward.go @@ -0,0 +1,504 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
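The loadServerConfig fix above works because protobuf-generated getters such as GetKvs() carry a nil-receiver guard, whereas a direct field access like resp.Kvs dereferences the response and panics if the provider returns a nil response alongside a nil error. A minimal sketch of that pattern with a stand-in type (not the real etcd/pdpb message):

package main

import "fmt"

// getResponse is a stand-in for a protobuf-generated response type; the real
// GetKvs() emitted by protoc has the same nil-receiver check.
type getResponse struct {
	Kvs [][]byte
}

func (r *getResponse) GetKvs() [][]byte {
	if r == nil {
		return nil
	}
	return r.Kvs
}

func main() {
	var resp *getResponse            // e.g. the provider returned (nil, nil)
	fmt.Println(len(resp.GetKvs())) // prints 0 — safe on a nil response
	// len(resp.Kvs) here would panic with a nil pointer dereference.
}

Preferring the generated getters over raw field access is a cheap way to make this kind of "empty but non-error" result path defensive by default.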
+ +package server + +import ( + "context" + "io" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/kvproto/pkg/schedulingpb" + "github.com/pingcap/kvproto/pkg/tsopb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/mcs/utils" + "github.com/tikv/pd/pkg/tso" + "github.com/tikv/pd/pkg/utils/grpcutil" + "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/pkg/utils/tsoutil" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func (s *GrpcServer) forwardTSORequest( + ctx context.Context, + request *pdpb.TsoRequest, + forwardStream tsopb.TSO_TsoClient) (*tsopb.TsoResponse, error) { + tsopbReq := &tsopb.TsoRequest{ + Header: &tsopb.RequestHeader{ + ClusterId: request.GetHeader().GetClusterId(), + SenderId: request.GetHeader().GetSenderId(), + KeyspaceId: utils.DefaultKeyspaceID, + KeyspaceGroupId: utils.DefaultKeyspaceGroupID, + }, + Count: request.GetCount(), + DcLocation: request.GetDcLocation(), + } + + failpoint.Inject("tsoProxySendToTSOTimeout", func() { + // block until watchDeadline routine cancels the context. + <-ctx.Done() + }) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if err := forwardStream.Send(tsopbReq); err != nil { + return nil, err + } + + failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() { + // block until watchDeadline routine cancels the context. + <-ctx.Done() + }) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + return forwardStream.Recv() +} + +// forwardTSO forward the TSO requests to the TSO service. +func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { + var ( + server = &tsoServer{stream: stream} + forwardStream tsopb.TSO_TsoClient + forwardCtx context.Context + cancelForward context.CancelFunc + lastForwardedHost string + ) + defer func() { + s.concurrentTSOProxyStreamings.Add(-1) + if cancelForward != nil { + cancelForward() + } + }() + + maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) + if maxConcurrentTSOProxyStreamings >= 0 { + if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings { + return errors.WithStack(ErrMaxCountTSOProxyRoutinesExceeded) + } + } + + tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1) + go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) + + for { + select { + case <-s.ctx.Done(): + return errors.WithStack(s.ctx.Err()) + case <-stream.Context().Done(): + return stream.Context().Err() + default: + } + + request, err := server.Recv(s.GetTSOProxyRecvFromClientTimeout()) + if err == io.EOF { + return nil + } + if err != nil { + return errors.WithStack(err) + } + if request.GetCount() == 0 { + err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") + return status.Errorf(codes.Unknown, err.Error()) + } + + forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName) + if !ok || len(forwardedHost) == 0 { + return errors.WithStack(ErrNotFoundTSOAddr) + } + if forwardStream == nil || lastForwardedHost != forwardedHost { + if cancelForward != nil { + cancelForward() + } + + clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) + if err != nil { + return errors.WithStack(err) + } + forwardStream, forwardCtx, cancelForward, err = s.createTSOForwardStream(stream.Context(), clientConn) + if err != nil { + return errors.WithStack(err) + } + 
lastForwardedHost = forwardedHost + } + + tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) + if err != nil { + return errors.WithStack(err) + } + + // The error types defined for tsopb and pdpb are different, so we need to convert them. + var pdpbErr *pdpb.Error + tsopbErr := tsopbResp.GetHeader().GetError() + if tsopbErr != nil { + if tsopbErr.Type == tsopb.ErrorType_OK { + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_OK, + Message: tsopbErr.GetMessage(), + } + } else { + // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_UNKNOWN, + Message: tsopbErr.GetMessage(), + } + } + } + + response := &pdpb.TsoResponse{ + Header: &pdpb.ResponseHeader{ + ClusterId: tsopbResp.GetHeader().GetClusterId(), + Error: pdpbErr, + }, + Count: tsopbResp.GetCount(), + Timestamp: tsopbResp.GetTimestamp(), + } + if err := server.Send(response); err != nil { + return errors.WithStack(err) + } + } +} + +func (s *GrpcServer) forwardTSORequestWithDeadLine( + forwardCtx context.Context, + cancelForward context.CancelFunc, + forwardStream tsopb.TSO_TsoClient, + request *pdpb.TsoRequest, + tsDeadlineCh chan<- *tsoutil.TSDeadline) (*tsopb.TsoResponse, error) { + done := make(chan struct{}) + dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward) + select { + case tsDeadlineCh <- dl: + case <-forwardCtx.Done(): + return nil, forwardCtx.Err() + } + + start := time.Now() + resp, err := s.forwardTSORequest(forwardCtx, request, forwardStream) + close(done) + if err != nil { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.tsoPrimaryWatcher.ForceLoad() + } + return nil, err + } + tsoProxyBatchSize.Observe(float64(request.GetCount())) + tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) + return resp, nil +} + +func (s *GrpcServer) createTSOForwardStream(ctx context.Context, client *grpc.ClientConn) (tsopb.TSO_TsoClient, context.Context, context.CancelFunc, error) { + done := make(chan struct{}) + forwardCtx, cancelForward := context.WithCancel(ctx) + go grpcutil.CheckStream(forwardCtx, cancelForward, done) + forwardStream, err := tsopb.NewTSOClient(client).Tso(forwardCtx) + done <- struct{}{} + return forwardStream, forwardCtx, cancelForward, err +} + +func (s *GrpcServer) createRegionHeartbeatForwardStream(client *grpc.ClientConn) (pdpb.PD_RegionHeartbeatClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).RegionHeartbeat(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func (s *GrpcServer) createRegionHeartbeatSchedulingStream(ctx context.Context, client *grpc.ClientConn) (schedulingpb.Scheduling_RegionHeartbeatClient, context.Context, context.CancelFunc, error) { + done := make(chan struct{}) + forwardCtx, cancelForward := context.WithCancel(ctx) + go grpcutil.CheckStream(forwardCtx, cancelForward, done) + forwardStream, err := schedulingpb.NewSchedulingClient(client).RegionHeartbeat(forwardCtx) + done <- struct{}{} + return forwardStream, forwardCtx, cancelForward, err +} + +func forwardRegionHeartbeatToScheduling(forwardStream schedulingpb.Scheduling_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.Recv() + if err == io.EOF { + errCh <- errors.WithStack(err) + return + } + if err 
!= nil { + errCh <- errors.WithStack(err) + return + } + // The error types defined for schedulingpb and pdpb are different, so we need to convert them. + var pdpbErr *pdpb.Error + schedulingpbErr := resp.GetHeader().GetError() + if schedulingpbErr != nil { + if schedulingpbErr.Type == schedulingpb.ErrorType_OK { + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_OK, + Message: schedulingpbErr.GetMessage(), + } + } else { + // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_UNKNOWN, + Message: schedulingpbErr.GetMessage(), + } + } + } + response := &pdpb.RegionHeartbeatResponse{ + Header: &pdpb.ResponseHeader{ + ClusterId: resp.GetHeader().GetClusterId(), + Error: pdpbErr, + }, + ChangePeer: resp.GetChangePeer(), + TransferLeader: resp.GetTransferLeader(), + RegionId: resp.GetRegionId(), + RegionEpoch: resp.GetRegionEpoch(), + TargetPeer: resp.GetTargetPeer(), + Merge: resp.GetMerge(), + SplitRegion: resp.GetSplitRegion(), + ChangePeerV2: resp.GetChangePeerV2(), + SwitchWitnesses: resp.GetSwitchWitnesses(), + } + + if err := server.Send(response); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.Recv() + if err != nil { + errCh <- errors.WithStack(err) + return + } + if err := server.Send(resp); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient, server *bucketHeartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.CloseAndRecv() + if err != nil { + errCh <- errors.WithStack(err) + return + } + if err := server.Send(resp); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).ReportBuckets(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string) (*grpc.ClientConn, error) { + client, ok := s.clientConns.Load(forwardedHost) + if ok { + // Mostly, the connection is already established, and return it directly. + return client.(*grpc.ClientConn), nil + } + + tlsConfig, err := s.GetTLSConfig().ToTLSConfig() + if err != nil { + return nil, err + } + ctxTimeout, cancel := context.WithTimeout(ctx, defaultGRPCDialTimeout) + defer cancel() + newConn, err := grpcutil.GetClientConn(ctxTimeout, forwardedHost, tlsConfig) + if err != nil { + return nil, err + } + conn, loaded := s.clientConns.LoadOrStore(forwardedHost, newConn) + if !loaded { + // Successfully stored the connection we created. + return newConn, nil + } + // Loaded a connection created/stored by another goroutine, so close the one we created + // and return the one we loaded. 
+ newConn.Close() + return conn.(*grpc.ClientConn), nil +} + +func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context, serviceName ...string) (forwardedHost string, err error) { + if s.IsAPIServiceMode() { + var ok bool + if len(serviceName) == 0 { + return "", ErrNotFoundService + } + forwardedHost, ok = s.GetServicePrimaryAddr(ctx, serviceName[0]) + if !ok || len(forwardedHost) == 0 { + switch serviceName[0] { + case utils.TSOServiceName: + return "", ErrNotFoundTSOAddr + case utils.SchedulingServiceName: + return "", ErrNotFoundSchedulingAddr + } + } + } else if fh := grpcutil.GetForwardedHost(streamCtx); !s.isLocalRequest(fh) { + forwardedHost = fh + } + return forwardedHost, nil +} + +func (s *GrpcServer) isLocalRequest(forwardedHost string) bool { + failpoint.Inject("useForwardRequest", func() { + failpoint.Return(false) + }) + if forwardedHost == "" { + return true + } + memberAddrs := s.GetMember().Member().GetClientUrls() + for _, addr := range memberAddrs { + if addr == forwardedHost { + return true + } + } + return false +} + +func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) { + if !s.IsAPIServiceMode() { + return s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) + } + request := &tsopb.TsoRequest{ + Header: &tsopb.RequestHeader{ + ClusterId: s.clusterID, + KeyspaceId: utils.DefaultKeyspaceID, + KeyspaceGroupId: utils.DefaultKeyspaceGroupID, + }, + Count: 1, + } + var ( + forwardedHost string + forwardStream tsopb.TSO_TsoClient + ts *tsopb.TsoResponse + err error + ok bool + ) + handleStreamError := func(err error) (needRetry bool) { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.tsoPrimaryWatcher.ForceLoad() + log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + if grpcutil.NeedRebuildConnection(err) { + s.tsoClientPool.Lock() + delete(s.tsoClientPool.clients, forwardedHost) + s.tsoClientPool.Unlock() + log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + return false + } + for i := 0; i < maxRetryTimesRequestTSOServer; i++ { + if i > 0 { + time.Sleep(retryIntervalRequestTSOServer) + } + forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) + if !ok || forwardedHost == "" { + return pdpb.Timestamp{}, ErrNotFoundTSOAddr + } + forwardStream, err = s.getTSOForwardStream(forwardedHost) + if err != nil { + return pdpb.Timestamp{}, err + } + err = forwardStream.Send(request) + if err != nil { + if needRetry := handleStreamError(err); needRetry { + continue + } + log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err + } + ts, err = forwardStream.Recv() + if err != nil { + if needRetry := handleStreamError(err); needRetry { + continue + } + log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err + } + return *ts.GetTimestamp(), nil + } + log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err +} + +func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) { + s.tsoClientPool.RLock() + forwardStream, ok := s.tsoClientPool.clients[forwardedHost] + s.tsoClientPool.RUnlock() + if ok { + // This is the common case to return here + return 
forwardStream, nil + } + + s.tsoClientPool.Lock() + defer s.tsoClientPool.Unlock() + + // Double check after entering the critical section + forwardStream, ok = s.tsoClientPool.clients[forwardedHost] + if ok { + return forwardStream, nil + } + + // Now let's create the client connection and the forward stream + client, err := s.getDelegateClient(s.ctx, forwardedHost) + if err != nil { + return nil, err + } + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx) + done <- struct{}{} + if err != nil { + return nil, err + } + s.tsoClientPool.clients[forwardedHost] = forwardStream + return forwardStream, nil +} diff --git a/server/grpc_service.go b/server/grpc_service.go index 34741d4da5b..b0384a7d629 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -70,6 +70,7 @@ var ( ErrSendHeartbeatTimeout = status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout") ErrNotFoundTSOAddr = status.Errorf(codes.NotFound, "not found tso address") ErrNotFoundSchedulingAddr = status.Errorf(codes.NotFound, "not found scheduling address") + ErrNotFoundService = status.Errorf(codes.NotFound, "not found service") ErrForwardTSOTimeout = status.Errorf(codes.DeadlineExceeded, "forward tso request timeout") ErrMaxCountTSOProxyRoutinesExceeded = status.Errorf(codes.ResourceExhausted, "max count of concurrent tso proxy routines exceeded") ErrTSOProxyRecvFromClientTimeout = status.Errorf(codes.DeadlineExceeded, "tso proxy timeout when receiving from client; stream closed by server") @@ -83,9 +84,120 @@ type GrpcServer struct { concurrentTSOProxyStreamings atomic.Int32 } +// tsoServer wraps PD_TsoServer to ensure when any error +// occurs on Send() or Recv(), both endpoints will be closed. 
+type tsoServer struct { + stream pdpb.PD_TsoServer + closed int32 +} + +type pdpbTSORequest struct { + request *pdpb.TsoRequest + err error +} + +func (s *tsoServer) Send(m *pdpb.TsoResponse) error { + if atomic.LoadInt32(&s.closed) == 1 { + return io.EOF + } + done := make(chan error, 1) + go func() { + defer logutil.LogPanic() + failpoint.Inject("tsoProxyFailToSendToClient", func() { + done <- errors.New("injected error") + failpoint.Return() + }) + done <- s.stream.Send(m) + }() + timer := time.NewTimer(tsoutil.DefaultTSOProxyTimeout) + defer timer.Stop() + select { + case err := <-done: + if err != nil { + atomic.StoreInt32(&s.closed, 1) + } + return errors.WithStack(err) + case <-timer.C: + atomic.StoreInt32(&s.closed, 1) + return ErrForwardTSOTimeout + } +} + +func (s *tsoServer) Recv(timeout time.Duration) (*pdpb.TsoRequest, error) { + if atomic.LoadInt32(&s.closed) == 1 { + return nil, io.EOF + } + failpoint.Inject("tsoProxyRecvFromClientTimeout", func(val failpoint.Value) { + if customTimeoutInSeconds, ok := val.(int); ok { + timeout = time.Duration(customTimeoutInSeconds) * time.Second + } + }) + requestCh := make(chan *pdpbTSORequest, 1) + go func() { + defer logutil.LogPanic() + request, err := s.stream.Recv() + requestCh <- &pdpbTSORequest{request: request, err: err} + }() + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case req := <-requestCh: + if req.err != nil { + atomic.StoreInt32(&s.closed, 1) + return nil, errors.WithStack(req.err) + } + return req.request, nil + case <-timer.C: + atomic.StoreInt32(&s.closed, 1) + return nil, ErrTSOProxyRecvFromClientTimeout + } +} + +// heartbeatServer wraps PD_RegionHeartbeatServer to ensure when any error +// occurs on Send() or Recv(), both endpoints will be closed. +type heartbeatServer struct { + stream pdpb.PD_RegionHeartbeatServer + closed int32 +} + +func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error { + if atomic.LoadInt32(&s.closed) == 1 { + return io.EOF + } + done := make(chan error, 1) + go func() { + defer logutil.LogPanic() + done <- s.stream.Send(m.(*pdpb.RegionHeartbeatResponse)) + }() + timer := time.NewTimer(heartbeatSendTimeout) + defer timer.Stop() + select { + case err := <-done: + if err != nil { + atomic.StoreInt32(&s.closed, 1) + } + return errors.WithStack(err) + case <-timer.C: + atomic.StoreInt32(&s.closed, 1) + return ErrSendHeartbeatTimeout + } +} + +func (s *heartbeatServer) Recv() (*pdpb.RegionHeartbeatRequest, error) { + if atomic.LoadInt32(&s.closed) == 1 { + return nil, io.EOF + } + req, err := s.stream.Recv() + if err != nil { + atomic.StoreInt32(&s.closed, 1) + return nil, errors.WithStack(err) + } + return req, nil +} + type schedulingClient struct { - client schedulingpb.SchedulingClient - lastPrimary string + client schedulingpb.SchedulingClient + primary string } func (s *schedulingClient) getClient() schedulingpb.SchedulingClient { @@ -99,7 +211,7 @@ func (s *schedulingClient) getPrimaryAddr() string { if s == nil { return "" } - return s.lastPrimary + return s.primary } type request interface { @@ -393,7 +505,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { return errors.WithStack(err) } - if forwardedHost, err := s.getForwardedHost(ctx, stream.Context()); err != nil { + if forwardedHost, err := s.getForwardedHost(ctx, stream.Context(), utils.TSOServiceName); err != nil { return err } else if len(forwardedHost) > 0 { clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) @@ -440,268 +552,6 @@ func (s *GrpcServer) Tso(stream 
pdpb.PD_TsoServer) error { } } -// forwardTSO forward the TSO requests to the TSO service. -func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { - var ( - server = &tsoServer{stream: stream} - forwardStream tsopb.TSO_TsoClient - forwardCtx context.Context - cancelForward context.CancelFunc - lastForwardedHost string - ) - defer func() { - s.concurrentTSOProxyStreamings.Add(-1) - if cancelForward != nil { - cancelForward() - } - }() - - maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) - if maxConcurrentTSOProxyStreamings >= 0 { - if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings { - return errors.WithStack(ErrMaxCountTSOProxyRoutinesExceeded) - } - } - - tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1) - go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) - - for { - select { - case <-s.ctx.Done(): - return errors.WithStack(s.ctx.Err()) - case <-stream.Context().Done(): - return stream.Context().Err() - default: - } - - request, err := server.Recv(s.GetTSOProxyRecvFromClientTimeout()) - if err == io.EOF { - return nil - } - if err != nil { - return errors.WithStack(err) - } - if request.GetCount() == 0 { - err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") - return status.Errorf(codes.Unknown, err.Error()) - } - - forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName) - if !ok || len(forwardedHost) == 0 { - return errors.WithStack(ErrNotFoundTSOAddr) - } - if forwardStream == nil || lastForwardedHost != forwardedHost { - if cancelForward != nil { - cancelForward() - } - - clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) - if err != nil { - return errors.WithStack(err) - } - forwardStream, forwardCtx, cancelForward, err = - s.createTSOForwardStream(stream.Context(), clientConn) - if err != nil { - return errors.WithStack(err) - } - lastForwardedHost = forwardedHost - } - - tsopbResp, err := s.forwardTSORequestWithDeadLine( - forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) - if err != nil { - return errors.WithStack(err) - } - - // The error types defined for tsopb and pdpb are different, so we need to convert them. - var pdpbErr *pdpb.Error - tsopbErr := tsopbResp.GetHeader().GetError() - if tsopbErr != nil { - if tsopbErr.Type == tsopb.ErrorType_OK { - pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_OK, - Message: tsopbErr.GetMessage(), - } - } else { - // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. 
- pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_UNKNOWN, - Message: tsopbErr.GetMessage(), - } - } - } - - response := &pdpb.TsoResponse{ - Header: &pdpb.ResponseHeader{ - ClusterId: tsopbResp.GetHeader().GetClusterId(), - Error: pdpbErr, - }, - Count: tsopbResp.GetCount(), - Timestamp: tsopbResp.GetTimestamp(), - } - if err := server.Send(response); err != nil { - return errors.WithStack(err) - } - } -} - -func (s *GrpcServer) forwardTSORequestWithDeadLine( - forwardCtx context.Context, - cancelForward context.CancelFunc, - forwardStream tsopb.TSO_TsoClient, - request *pdpb.TsoRequest, - tsDeadlineCh chan<- *tsoutil.TSDeadline, -) (*tsopb.TsoResponse, error) { - done := make(chan struct{}) - dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward) - select { - case tsDeadlineCh <- dl: - case <-forwardCtx.Done(): - return nil, forwardCtx.Err() - } - - start := time.Now() - resp, err := s.forwardTSORequest(forwardCtx, request, forwardStream) - close(done) - if err != nil { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - s.tsoPrimaryWatcher.ForceLoad() - } - return nil, err - } - tsoProxyBatchSize.Observe(float64(request.GetCount())) - tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) - return resp, nil -} - -func (s *GrpcServer) forwardTSORequest( - ctx context.Context, - request *pdpb.TsoRequest, - forwardStream tsopb.TSO_TsoClient, -) (*tsopb.TsoResponse, error) { - tsopbReq := &tsopb.TsoRequest{ - Header: &tsopb.RequestHeader{ - ClusterId: request.GetHeader().GetClusterId(), - SenderId: request.GetHeader().GetSenderId(), - KeyspaceId: utils.DefaultKeyspaceID, - KeyspaceGroupId: utils.DefaultKeyspaceGroupID, - }, - Count: request.GetCount(), - DcLocation: request.GetDcLocation(), - } - - failpoint.Inject("tsoProxySendToTSOTimeout", func() { - // block until watchDeadline routine cancels the context. - <-ctx.Done() - }) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - if err := forwardStream.Send(tsopbReq); err != nil { - return nil, err - } - - failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() { - // block until watchDeadline routine cancels the context. - <-ctx.Done() - }) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - return forwardStream.Recv() -} - -// tsoServer wraps PD_TsoServer to ensure when any error -// occurs on Send() or Recv(), both endpoints will be closed. 
-type tsoServer struct { - stream pdpb.PD_TsoServer - closed int32 -} - -type pdpbTSORequest struct { - request *pdpb.TsoRequest - err error -} - -func (s *tsoServer) Send(m *pdpb.TsoResponse) error { - if atomic.LoadInt32(&s.closed) == 1 { - return io.EOF - } - done := make(chan error, 1) - go func() { - defer logutil.LogPanic() - failpoint.Inject("tsoProxyFailToSendToClient", func() { - done <- errors.New("injected error") - failpoint.Return() - }) - done <- s.stream.Send(m) - }() - timer := time.NewTimer(tsoutil.DefaultTSOProxyTimeout) - defer timer.Stop() - select { - case err := <-done: - if err != nil { - atomic.StoreInt32(&s.closed, 1) - } - return errors.WithStack(err) - case <-timer.C: - atomic.StoreInt32(&s.closed, 1) - return ErrForwardTSOTimeout - } -} - -func (s *tsoServer) Recv(timeout time.Duration) (*pdpb.TsoRequest, error) { - if atomic.LoadInt32(&s.closed) == 1 { - return nil, io.EOF - } - failpoint.Inject("tsoProxyRecvFromClientTimeout", func(val failpoint.Value) { - if customTimeoutInSeconds, ok := val.(int); ok { - timeout = time.Duration(customTimeoutInSeconds) * time.Second - } - }) - requestCh := make(chan *pdpbTSORequest, 1) - go func() { - defer logutil.LogPanic() - request, err := s.stream.Recv() - requestCh <- &pdpbTSORequest{request: request, err: err} - }() - timer := time.NewTimer(timeout) - defer timer.Stop() - select { - case req := <-requestCh: - if req.err != nil { - atomic.StoreInt32(&s.closed, 1) - return nil, errors.WithStack(req.err) - } - return req.request, nil - case <-timer.C: - atomic.StoreInt32(&s.closed, 1) - return nil, ErrTSOProxyRecvFromClientTimeout - } -} - -func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context) (forwardedHost string, err error) { - if s.IsAPIServiceMode() { - var ok bool - forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) - if !ok || len(forwardedHost) == 0 { - return "", ErrNotFoundTSOAddr - } - } else if fh := grpcutil.GetForwardedHost(streamCtx); !s.isLocalRequest(fh) { - forwardedHost = fh - } - return forwardedHost, nil -} - // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { @@ -1004,7 +854,8 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear storeHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds()) if s.IsServiceIndependent(utils.SchedulingServiceName) { forwardCli, _ := s.updateSchedulingClient(ctx) - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { req := &schedulingpb.StoreHeartbeatRequest{ Header: &schedulingpb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -1012,9 +863,10 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear }, Stats: request.GetStats(), } - if _, err := forwardCli.StoreHeartbeat(ctx, req); err != nil { + if _, err := cli.StoreHeartbeat(ctx, req); err != nil { + log.Debug("forward store heartbeat failed", zap.Error(err)) // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) } } } @@ -1031,28 +883,38 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear return resp, nil } -func (s *GrpcServer) updateSchedulingClient(ctx context.Context) (schedulingpb.SchedulingClient, error) { +// 1. 
forwardedHost is empty, return nil +// 2. forwardedHost is not empty and forwardedHost is equal to pre, return pre +// 3. the rest of cases, update forwardedHost and return new client +func (s *GrpcServer) updateSchedulingClient(ctx context.Context) (*schedulingClient, error) { forwardedHost, _ := s.GetServicePrimaryAddr(ctx, utils.SchedulingServiceName) + if forwardedHost == "" { + return nil, ErrNotFoundSchedulingAddr + } + pre := s.schedulingClient.Load() - // 1. forwardedHost is not empty and pre is empty, update the schedulingClient - // 2. forwardedHost is not empty and forwardedHost is not equal to pre, update the schedulingClient - // 3. forwardedHost is not empty and forwardedHost is equal to pre, return pre - // 4. forwardedHost is empty, return nil - if forwardedHost != "" && ((pre == nil) || (pre != nil && forwardedHost != pre.(*schedulingClient).getPrimaryAddr())) { - client, err := s.getDelegateClient(ctx, forwardedHost) - if err != nil { - log.Error("get delegate client failed", zap.Error(err)) - } - forwardCli := &schedulingClient{ - client: schedulingpb.NewSchedulingClient(client), - lastPrimary: forwardedHost, + if pre != nil && forwardedHost == pre.(*schedulingClient).getPrimaryAddr() { + return pre.(*schedulingClient), nil + } + + client, err := s.getDelegateClient(ctx, forwardedHost) + if err != nil { + log.Error("get delegate client failed", zap.Error(err)) + return nil, err + } + forwardCli := &schedulingClient{ + client: schedulingpb.NewSchedulingClient(client), + primary: forwardedHost, + } + swapped := s.schedulingClient.CompareAndSwap(pre, forwardCli) + if swapped { + oldForwardedHost := "" + if pre != nil { + oldForwardedHost = pre.(*schedulingClient).getPrimaryAddr() } - s.schedulingClient.Store(forwardCli) - return forwardCli.getClient(), nil - } else if forwardedHost != "" && (pre != nil && forwardedHost == pre.(*schedulingClient).getPrimaryAddr()) { - return pre.(*schedulingClient).getClient(), nil + log.Info("update scheduling client", zap.String("old-forwarded-host", oldForwardedHost), zap.String("new-forwarded-host", forwardedHost)) } - return nil, ErrNotFoundSchedulingAddr + return forwardCli, nil } // bucketHeartbeatServer wraps PD_ReportBucketsServer to ensure when any error @@ -1097,48 +959,6 @@ func (b *bucketHeartbeatServer) Recv() (*pdpb.ReportBucketsRequest, error) { return req, nil } -// heartbeatServer wraps PD_RegionHeartbeatServer to ensure when any error -// occurs on Send() or Recv(), both endpoints will be closed. 
-type heartbeatServer struct { - stream pdpb.PD_RegionHeartbeatServer - closed int32 -} - -func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error { - if atomic.LoadInt32(&s.closed) == 1 { - return io.EOF - } - done := make(chan error, 1) - go func() { - defer logutil.LogPanic() - done <- s.stream.Send(m.(*pdpb.RegionHeartbeatResponse)) - }() - timer := time.NewTimer(heartbeatSendTimeout) - defer timer.Stop() - select { - case err := <-done: - if err != nil { - atomic.StoreInt32(&s.closed, 1) - } - return errors.WithStack(err) - case <-timer.C: - atomic.StoreInt32(&s.closed, 1) - return ErrSendHeartbeatTimeout - } -} - -func (s *heartbeatServer) Recv() (*pdpb.RegionHeartbeatRequest, error) { - if atomic.LoadInt32(&s.closed) == 1 { - return nil, io.EOF - } - req, err := s.stream.Recv() - if err != nil { - atomic.StoreInt32(&s.closed, 1) - return nil, errors.WithStack(err) - } - return req, nil -} - // ReportBuckets implements gRPC PDServer func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { var ( @@ -1236,16 +1056,16 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { // RegionHeartbeat implements gRPC PDServer. func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error { var ( - server = &heartbeatServer{stream: stream} - flowRoundOption = core.WithFlowRoundByDigit(s.persistOptions.GetPDServerConfig().FlowRoundByDigit) - forwardStream pdpb.PD_RegionHeartbeatClient - cancel context.CancelFunc - lastForwardedHost string - lastBind time.Time - errCh chan error - schedulingStream schedulingpb.Scheduling_RegionHeartbeatClient - cancel1 context.CancelFunc - lastPrimaryAddr string + server = &heartbeatServer{stream: stream} + flowRoundOption = core.WithFlowRoundByDigit(s.persistOptions.GetPDServerConfig().FlowRoundByDigit) + cancel context.CancelFunc + lastBind time.Time + errCh chan error + forwardStream pdpb.PD_RegionHeartbeatClient + lastForwardedHost string + forwardErrCh chan error + forwardSchedulingStream schedulingpb.Scheduling_RegionHeartbeatClient + lastForwardedSchedulingHost string ) defer func() { // cancel the forward stream @@ -1262,8 +1082,10 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error if err != nil { return errors.WithStack(err) } - forwardedHost := grpcutil.GetForwardedHost(stream.Context()) + failpoint.Inject("grpcClientClosed", func() { + forwardedHost = s.GetMember().Member().GetClientUrls()[0] + }) if !s.isLocalRequest(forwardedHost) { if forwardStream == nil || lastForwardedHost != forwardedHost { if cancel != nil { @@ -1274,7 +1096,7 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error return err } log.Info("create region heartbeat forward stream", zap.String("forwarded-host", forwardedHost)) - forwardStream, cancel, err = s.createHeartbeatForwardStream(client) + forwardStream, cancel, err = s.createRegionHeartbeatForwardStream(client) if err != nil { return err } @@ -1360,56 +1182,83 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error continue } + regionHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds()) + regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "ok").Inc() + if s.IsServiceIndependent(utils.SchedulingServiceName) { - ctx := stream.Context() - primaryAddr, _ := s.GetServicePrimaryAddr(ctx, utils.SchedulingServiceName) - if schedulingStream == nil || lastPrimaryAddr != primaryAddr { - if cancel1 != 
nil { - cancel1() + if forwardErrCh != nil { + select { + case err, ok := <-forwardErrCh: + if ok { + if cancel != nil { + cancel() + } + forwardSchedulingStream = nil + log.Error("meet error and need to re-establish the stream", zap.Error(err)) + } + default: } - client, err := s.getDelegateClient(ctx, primaryAddr) + } + forwardedSchedulingHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.SchedulingServiceName) + if !ok || len(forwardedSchedulingHost) == 0 { + log.Debug("failed to find scheduling service primary address") + if cancel != nil { + cancel() + } + continue + } + if forwardSchedulingStream == nil || lastForwardedSchedulingHost != forwardedSchedulingHost { + if cancel != nil { + cancel() + } + client, err := s.getDelegateClient(s.ctx, forwardedSchedulingHost) if err != nil { - log.Error("get delegate client failed", zap.Error(err)) + log.Error("failed to get client", zap.Error(err)) + continue } - - log.Info("create region heartbeat forward stream", zap.String("forwarded-host", primaryAddr)) - schedulingStream, cancel1, err = s.createSchedulingStream(client) + log.Info("create scheduling forwarding stream", zap.String("forwarded-host", forwardedSchedulingHost)) + forwardSchedulingStream, _, cancel, err = s.createRegionHeartbeatSchedulingStream(stream.Context(), client) if err != nil { - log.Error("create region heartbeat forward stream failed", zap.Error(err)) - } else { - lastPrimaryAddr = primaryAddr - errCh = make(chan error, 1) - go forwardSchedulingToServer(schedulingStream, server, errCh) + log.Error("failed to create stream", zap.Error(err)) + continue } + lastForwardedSchedulingHost = forwardedSchedulingHost + forwardErrCh = make(chan error, 1) + go forwardRegionHeartbeatToScheduling(forwardSchedulingStream, server, forwardErrCh) } - if schedulingStream != nil { - req := &schedulingpb.RegionHeartbeatRequest{ - Header: &schedulingpb.RequestHeader{ - ClusterId: request.GetHeader().GetClusterId(), - SenderId: request.GetHeader().GetSenderId(), - }, - Region: request.GetRegion(), - Leader: request.GetLeader(), - DownPeers: request.GetDownPeers(), - PendingPeers: request.GetPendingPeers(), - BytesWritten: request.GetBytesWritten(), - BytesRead: request.GetBytesRead(), - KeysWritten: request.GetKeysWritten(), - KeysRead: request.GetKeysRead(), - ApproximateSize: request.GetApproximateSize(), - ApproximateKeys: request.GetApproximateKeys(), - Interval: request.GetInterval(), - Term: request.GetTerm(), - QueryStats: request.GetQueryStats(), - } - if err := schedulingStream.Send(req); err != nil { - log.Error("forward region heartbeat failed", zap.Error(err)) + schedulingpbReq := &schedulingpb.RegionHeartbeatRequest{ + Header: &schedulingpb.RequestHeader{ + ClusterId: request.GetHeader().GetClusterId(), + SenderId: request.GetHeader().GetSenderId(), + }, + Region: request.GetRegion(), + Leader: request.GetLeader(), + DownPeers: request.GetDownPeers(), + PendingPeers: request.GetPendingPeers(), + BytesWritten: request.GetBytesWritten(), + BytesRead: request.GetBytesRead(), + KeysWritten: request.GetKeysWritten(), + KeysRead: request.GetKeysRead(), + ApproximateSize: request.GetApproximateSize(), + ApproximateKeys: request.GetApproximateKeys(), + Interval: request.GetInterval(), + Term: request.GetTerm(), + QueryStats: request.GetQueryStats(), + } + if err := forwardSchedulingStream.Send(schedulingpbReq); err != nil { + forwardSchedulingStream = nil + log.Error("failed to send request to scheduling service", zap.Error(err)) + } + + select { + case err, ok := 
<-forwardErrCh: + if ok { + forwardSchedulingStream = nil + log.Error("failed to send response", zap.Error(err)) } + default: } } - - regionHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds()) - regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "ok").Inc() } } @@ -1639,7 +1488,8 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { req := &schedulingpb.AskBatchSplitRequest{ Header: &schedulingpb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -1648,10 +1498,10 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp Region: request.GetRegion(), SplitCount: request.GetSplitCount(), } - resp, err := s.schedulingClient.Load().(*schedulingClient).getClient().AskBatchSplit(ctx, req) + resp, err := cli.AskBatchSplit(ctx, req) if err != nil { // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) return s.convertAskSplitResponse(resp), err } return s.convertAskSplitResponse(resp), nil @@ -1812,7 +1662,8 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { var regionsID []uint64 // nolint if request.GetRegionId() != 0 { @@ -1836,10 +1687,10 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg RetryLimit: request.GetRetryLimit(), SkipStoreLimit: request.GetSkipStoreLimit(), } - resp, err := forwardCli.ScatterRegions(ctx, req) + resp, err := cli.ScatterRegions(ctx, req) if err != nil { // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) return s.convertScatterResponse(resp), err } return s.convertScatterResponse(resp), nil @@ -2035,7 +1886,8 @@ func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorR Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { req := &schedulingpb.GetOperatorRequest{ Header: &schedulingpb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -2043,10 +1895,10 @@ func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorR }, RegionId: request.GetRegionId(), } - resp, err := forwardCli.GetOperator(ctx, req) + resp, err := cli.GetOperator(ctx, req) if err != nil { // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) return s.convertOperatorResponse(resp), err } return s.convertOperatorResponse(resp), nil @@ -2307,7 +2159,8 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { req := &schedulingpb.SplitRegionsRequest{ Header: &schedulingpb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -2316,10 +2169,10 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion SplitKeys: 
request.GetSplitKeys(), RetryLimit: request.GetRetryLimit(), } - resp, err := forwardCli.SplitRegions(ctx, req) + resp, err := cli.SplitRegions(ctx, req) if err != nil { // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) return s.convertSplitResponse(resp), err } return s.convertSplitResponse(resp), nil @@ -2451,258 +2304,6 @@ func (s *GrpcServer) validateInternalRequest(header *pdpb.RequestHeader, onlyAll return nil } -func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string) (*grpc.ClientConn, error) { - client, ok := s.clientConns.Load(forwardedHost) - if ok { - // Mostly, the connection is already established, and return it directly. - return client.(*grpc.ClientConn), nil - } - - tlsConfig, err := s.GetTLSConfig().ToTLSConfig() - if err != nil { - return nil, err - } - ctxTimeout, cancel := context.WithTimeout(ctx, defaultGRPCDialTimeout) - defer cancel() - newConn, err := grpcutil.GetClientConn(ctxTimeout, forwardedHost, tlsConfig) - if err != nil { - return nil, err - } - conn, loaded := s.clientConns.LoadOrStore(forwardedHost, newConn) - if !loaded { - // Successfully stored the connection we created. - return newConn, nil - } - // Loaded a connection created/stored by another goroutine, so close the one we created - // and return the one we loaded. - newConn.Close() - return conn.(*grpc.ClientConn), nil -} - -func (s *GrpcServer) isLocalRequest(forwardedHost string) bool { - failpoint.Inject("useForwardRequest", func() { - failpoint.Return(false) - }) - if forwardedHost == "" { - return true - } - memberAddrs := s.GetMember().Member().GetClientUrls() - for _, addr := range memberAddrs { - if addr == forwardedHost { - return true - } - } - return false -} - -func (s *GrpcServer) createHeartbeatForwardStream(client *grpc.ClientConn) (pdpb.PD_RegionHeartbeatClient, context.CancelFunc, error) { - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err := pdpb.NewPDClient(client).RegionHeartbeat(ctx) - done <- struct{}{} - return forwardStream, cancel, err -} - -func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.Recv() - if err != nil { - errCh <- errors.WithStack(err) - return - } - if err := server.Send(resp); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func (s *GrpcServer) createSchedulingStream(client *grpc.ClientConn) (schedulingpb.Scheduling_RegionHeartbeatClient, context.CancelFunc, error) { - if client == nil { - return nil, nil, errors.New("connection is not set") - } - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err := schedulingpb.NewSchedulingClient(client).RegionHeartbeat(ctx) - done <- struct{}{} - return forwardStream, cancel, err -} - -func forwardSchedulingToServer(forwardStream schedulingpb.Scheduling_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.Recv() - if err != nil { - errCh <- errors.WithStack(err) - return - } - response := &pdpb.RegionHeartbeatResponse{ - Header: &pdpb.ResponseHeader{ - ClusterId: resp.GetHeader().GetClusterId(), - // ignore error here - }, - 
ChangePeer: resp.GetChangePeer(), - TransferLeader: resp.GetTransferLeader(), - RegionId: resp.GetRegionId(), - RegionEpoch: resp.GetRegionEpoch(), - TargetPeer: resp.GetTargetPeer(), - Merge: resp.GetMerge(), - SplitRegion: resp.GetSplitRegion(), - ChangePeerV2: resp.GetChangePeerV2(), - SwitchWitnesses: resp.GetSwitchWitnesses(), - } - - if err := server.Send(response); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func (s *GrpcServer) createTSOForwardStream( - ctx context.Context, client *grpc.ClientConn, -) (tsopb.TSO_TsoClient, context.Context, context.CancelFunc, error) { - done := make(chan struct{}) - forwardCtx, cancelForward := context.WithCancel(ctx) - go grpcutil.CheckStream(forwardCtx, cancelForward, done) - forwardStream, err := tsopb.NewTSOClient(client).Tso(forwardCtx) - done <- struct{}{} - return forwardStream, forwardCtx, cancelForward, err -} - -func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) { - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err := pdpb.NewPDClient(client).ReportBuckets(ctx) - done <- struct{}{} - return forwardStream, cancel, err -} - -func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient, server *bucketHeartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.CloseAndRecv() - if err != nil { - errCh <- errors.WithStack(err) - return - } - if err := server.Send(resp); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) { - if !s.IsAPIServiceMode() { - return s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) - } - request := &tsopb.TsoRequest{ - Header: &tsopb.RequestHeader{ - ClusterId: s.clusterID, - KeyspaceId: utils.DefaultKeyspaceID, - KeyspaceGroupId: utils.DefaultKeyspaceGroupID, - }, - Count: 1, - } - var ( - forwardedHost string - forwardStream tsopb.TSO_TsoClient - ts *tsopb.TsoResponse - err error - ok bool - ) - handleStreamError := func(err error) (needRetry bool) { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - s.tsoPrimaryWatcher.ForceLoad() - log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return true - } - if grpcutil.NeedRebuildConnection(err) { - s.tsoClientPool.Lock() - delete(s.tsoClientPool.clients, forwardedHost) - s.tsoClientPool.Unlock() - log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return true - } - return false - } - for i := 0; i < maxRetryTimesRequestTSOServer; i++ { - if i > 0 { - time.Sleep(retryIntervalRequestTSOServer) - } - forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) - if !ok || forwardedHost == "" { - return pdpb.Timestamp{}, ErrNotFoundTSOAddr - } - forwardStream, err = s.getTSOForwardStream(forwardedHost) - if err != nil { - return pdpb.Timestamp{}, err - } - err = forwardStream.Send(request) - if err != nil { - if needRetry := handleStreamError(err); needRetry { - continue - } - log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return pdpb.Timestamp{}, err - } - ts, err = forwardStream.Recv() - if err != nil { - if needRetry := handleStreamError(err); needRetry { - continue - } - 
log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return pdpb.Timestamp{}, err - } - return *ts.GetTimestamp(), nil - } - log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return pdpb.Timestamp{}, err -} - -func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) { - s.tsoClientPool.RLock() - forwardStream, ok := s.tsoClientPool.clients[forwardedHost] - s.tsoClientPool.RUnlock() - if ok { - // This is the common case to return here - return forwardStream, nil - } - - s.tsoClientPool.Lock() - defer s.tsoClientPool.Unlock() - - // Double check after entering the critical section - forwardStream, ok = s.tsoClientPool.clients[forwardedHost] - if ok { - return forwardStream, nil - } - - // Now let's create the client connection and the forward stream - client, err := s.getDelegateClient(s.ctx, forwardedHost) - if err != nil { - return nil, err - } - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx) - done <- struct{}{} - if err != nil { - return nil, err - } - s.tsoClientPool.clients[forwardedHost] = forwardStream - return forwardStream, nil -} - // for CDC compatibility, we need to initialize config path to `globalConfigPath` const globalConfigPath = "/global/config/" From 86831ce7186525bdbcd33f92ba9008080a2c7a05 Mon Sep 17 00:00:00 2001 From: Hu# Date: Tue, 14 Nov 2023 12:11:14 +0800 Subject: [PATCH 25/26] prepare_check: remove redundant check (#7217) ref tikv/pd#7016 remove redundant check in prepare_check Signed-off-by: husharp Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/cluster/cluster.go | 3 --- pkg/core/region.go | 18 ++++++++++--- pkg/core/region_tree.go | 7 +++++ pkg/schedule/prepare_checker.go | 35 +++++-------------------- server/cluster/cluster.go | 2 +- server/cluster/cluster_test.go | 18 ++++++------- tests/pdctl/scheduler/scheduler_test.go | 2 +- tests/server/cluster/cluster_test.go | 2 +- tests/testutil.go | 1 + 9 files changed, 40 insertions(+), 48 deletions(-) diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 0b3e0351b16..8809a706936 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -59,7 +59,4 @@ func Collect(c Cluster, region *core.RegionInfo, stores []*core.StoreInfo, hasRe if hasRegionStats { c.GetRegionStats().Observe(region, stores) } - if !isPrepared && isNew { - c.GetCoordinator().GetPrepareChecker().Collect(region) - } } diff --git a/pkg/core/region.go b/pkg/core/region.go index c9daa69c477..b141e8478da 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -1340,11 +1340,23 @@ func (r *RegionsInfo) GetStoreWriteRate(storeID uint64) (bytesRate, keysRate flo return } -// GetClusterNotFromStorageRegionsCnt gets the total count of regions that not loaded from storage anymore +// GetClusterNotFromStorageRegionsCnt gets the `NotFromStorageRegionsCnt` count of regions that not loaded from storage anymore. func (r *RegionsInfo) GetClusterNotFromStorageRegionsCnt() int { r.t.RLock() defer r.t.RUnlock() - return r.tree.notFromStorageRegionsCnt + return r.tree.notFromStorageRegionsCount() +} + +// GetNotFromStorageRegionsCntByStore gets the `NotFromStorageRegionsCnt` count of a store's leader, follower and learner by storeID. 
+func (r *RegionsInfo) GetNotFromStorageRegionsCntByStore(storeID uint64) int { + r.st.RLock() + defer r.st.RUnlock() + return r.getNotFromStorageRegionsCntByStoreLocked(storeID) +} + +// getNotFromStorageRegionsCntByStoreLocked gets the `NotFromStorageRegionsCnt` count of a store's leader, follower and learner by storeID. +func (r *RegionsInfo) getNotFromStorageRegionsCntByStoreLocked(storeID uint64) int { + return r.leaders[storeID].notFromStorageRegionsCount() + r.followers[storeID].notFromStorageRegionsCount() + r.learners[storeID].notFromStorageRegionsCount() } // GetMetaRegions gets a set of metapb.Region from regionMap @@ -1380,7 +1392,7 @@ func (r *RegionsInfo) GetStoreRegionCount(storeID uint64) int { return r.getStoreRegionCountLocked(storeID) } -// GetStoreRegionCount gets the total count of a store's leader, follower and learner RegionInfo by storeID +// getStoreRegionCountLocked gets the total count of a store's leader, follower and learner RegionInfo by storeID func (r *RegionsInfo) getStoreRegionCountLocked(storeID uint64) int { return r.leaders[storeID].length() + r.followers[storeID].length() + r.learners[storeID].length() } diff --git a/pkg/core/region_tree.go b/pkg/core/region_tree.go index ed3445de6b6..ecc988d97d8 100644 --- a/pkg/core/region_tree.go +++ b/pkg/core/region_tree.go @@ -82,6 +82,13 @@ func (t *regionTree) length() int { return t.tree.Len() } +func (t *regionTree) notFromStorageRegionsCount() int { + if t == nil { + return 0 + } + return t.notFromStorageRegionsCnt +} + // GetOverlaps returns the range items that has some intersections with the given items. func (t *regionTree) overlaps(item *regionItem) []*regionItem { // note that Find() gets the last item that is less or equal than the item. diff --git a/pkg/schedule/prepare_checker.go b/pkg/schedule/prepare_checker.go index c7faa57af81..34618427930 100644 --- a/pkg/schedule/prepare_checker.go +++ b/pkg/schedule/prepare_checker.go @@ -25,16 +25,13 @@ import ( type prepareChecker struct { syncutil.RWMutex - reactiveRegions map[uint64]int - start time.Time - sum int - prepared bool + start time.Time + prepared bool } func newPrepareChecker() *prepareChecker { return &prepareChecker{ - start: time.Now(), - reactiveRegions: make(map[uint64]int), + start: time.Now(), } } @@ -51,13 +48,8 @@ func (checker *prepareChecker) check(c *core.BasicCluster) bool { } notLoadedFromRegionsCnt := c.GetClusterNotFromStorageRegionsCnt() totalRegionsCnt := c.GetTotalRegionCount() - if float64(notLoadedFromRegionsCnt) > float64(totalRegionsCnt)*collectFactor { - log.Info("meta not loaded from region number is satisfied, finish prepare checker", zap.Int("not-from-storage-region", notLoadedFromRegionsCnt), zap.Int("total-region", totalRegionsCnt)) - checker.prepared = true - return true - } // The number of active regions should be more than total region of all stores * collectFactor - if float64(totalRegionsCnt)*collectFactor > float64(checker.sum) { + if float64(totalRegionsCnt)*collectFactor > float64(notLoadedFromRegionsCnt) { return false } for _, store := range c.GetStores() { @@ -66,23 +58,15 @@ func (checker *prepareChecker) check(c *core.BasicCluster) bool { } storeID := store.GetID() // For each store, the number of active regions should be more than total region of the store * collectFactor - if float64(c.GetStoreRegionCount(storeID))*collectFactor > float64(checker.reactiveRegions[storeID]) { + if float64(c.GetStoreRegionCount(storeID))*collectFactor > float64(c.GetNotFromStorageRegionsCntByStore(storeID)) { return 
false } } + log.Info("not loaded from storage region number is satisfied, finish prepare checker", zap.Int("not-from-storage-region", notLoadedFromRegionsCnt), zap.Int("total-region", totalRegionsCnt)) checker.prepared = true return true } -func (checker *prepareChecker) Collect(region *core.RegionInfo) { - checker.Lock() - defer checker.Unlock() - for _, p := range region.GetPeers() { - checker.reactiveRegions[p.GetStoreId()]++ - } - checker.sum++ -} - func (checker *prepareChecker) IsPrepared() bool { checker.RLock() defer checker.RUnlock() @@ -95,10 +79,3 @@ func (checker *prepareChecker) SetPrepared() { defer checker.Unlock() checker.prepared = true } - -// for test purpose -func (checker *prepareChecker) GetSum() int { - checker.RLock() - defer checker.RUnlock() - return checker.sum -} diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index c0fb1b15f8f..3b50ae16d9b 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -1018,7 +1018,7 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error { // To prevent a concurrent heartbeat of another region from overriding the up-to-date region info by a stale one, // check its validation again here. // - // However it can't solve the race condition of concurrent heartbeats from the same region. + // However, it can't solve the race condition of concurrent heartbeats from the same region. if overlaps, err = c.core.AtomicCheckAndPutRegion(region); err != nil { return err } diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 7ebd012a6a2..4b9b401e0c9 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -2384,7 +2384,7 @@ func (c *testCluster) LoadRegion(regionID uint64, followerStoreIDs ...uint64) er peer, _ := c.AllocPeer(id) region.Peers = append(region.Peers, peer) } - return c.putRegion(core.NewRegionInfo(region, nil)) + return c.putRegion(core.NewRegionInfo(region, nil, core.SetSource(core.Storage))) } func TestBasic(t *testing.T) { @@ -2469,7 +2469,7 @@ func TestDispatch(t *testing.T) { func dispatchHeartbeat(co *schedule.Coordinator, region *core.RegionInfo, stream hbstream.HeartbeatStream) error { co.GetHeartbeatStreams().BindStream(region.GetLeader().GetStoreId(), stream) - if err := co.GetCluster().(*RaftCluster).putRegion(region.Clone()); err != nil { + if err := co.GetCluster().(*RaftCluster).putRegion(region.Clone(core.SetSource(core.Heartbeat))); err != nil { return err } co.GetOperatorController().Dispatch(region, operator.DispatchFromHeartBeat, nil) @@ -2943,14 +2943,14 @@ func TestShouldRun(t *testing.T) { for _, testCase := range testCases { r := tc.GetRegion(testCase.regionID) - nr := r.Clone(core.WithLeader(r.GetPeers()[0])) + nr := r.Clone(core.WithLeader(r.GetPeers()[0]), core.SetSource(core.Heartbeat)) re.NoError(tc.processRegionHeartbeat(nr)) re.Equal(testCase.ShouldRun, co.ShouldRun()) } nr := &metapb.Region{Id: 6, Peers: []*metapb.Peer{}} - newRegion := core.NewRegionInfo(nr, nil) + newRegion := core.NewRegionInfo(nr, nil, core.SetSource(core.Heartbeat)) re.Error(tc.processRegionHeartbeat(newRegion)) - re.Equal(7, co.GetPrepareChecker().GetSum()) + re.Equal(7, tc.core.GetClusterNotFromStorageRegionsCnt()) } func TestShouldRunWithNonLeaderRegions(t *testing.T) { @@ -2986,14 +2986,14 @@ func TestShouldRunWithNonLeaderRegions(t *testing.T) { for _, testCase := range testCases { r := tc.GetRegion(testCase.regionID) - nr := r.Clone(core.WithLeader(r.GetPeers()[0])) + nr := 
r.Clone(core.WithLeader(r.GetPeers()[0]), core.SetSource(core.Heartbeat)) re.NoError(tc.processRegionHeartbeat(nr)) re.Equal(testCase.ShouldRun, co.ShouldRun()) } nr := &metapb.Region{Id: 9, Peers: []*metapb.Peer{}} - newRegion := core.NewRegionInfo(nr, nil) + newRegion := core.NewRegionInfo(nr, nil, core.SetSource(core.Heartbeat)) re.Error(tc.processRegionHeartbeat(newRegion)) - re.Equal(9, co.GetPrepareChecker().GetSum()) + re.Equal(9, tc.core.GetClusterNotFromStorageRegionsCnt()) // Now, after server is prepared, there exist some regions with no leader. re.Equal(uint64(0), tc.GetRegion(10).GetLeader().GetStoreId()) @@ -3263,7 +3263,6 @@ func TestRestart(t *testing.T) { re.NoError(tc.addRegionStore(3, 3)) re.NoError(tc.addLeaderRegion(1, 1)) region := tc.GetRegion(1) - co.GetPrepareChecker().Collect(region) // Add 1 replica on store 2. stream := mockhbstream.NewHeartbeatStream() @@ -3277,7 +3276,6 @@ func TestRestart(t *testing.T) { // Recreate coordinator then add another replica on store 3. co = schedule.NewCoordinator(ctx, tc.RaftCluster, hbStreams) - co.GetPrepareChecker().Collect(region) co.Run() re.NoError(dispatchHeartbeat(co, region, stream)) region = waitAddLearner(re, stream, region, 3) diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index cd599405124..d0fac2c1137 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -539,7 +539,7 @@ func (suite *schedulerTestSuite) checkSchedulerDiagnostic(cluster *tests.TestClu tests.MustPutStore(re, cluster, store) } - // note: because pdqsort is a unstable sort algorithm, set ApproximateSize for this region. + // note: because pdqsort is an unstable sort algorithm, set ApproximateSize for this region. tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetApproximateSize(10)) echo := mustExec(re, cmd, []string{"-u", pdAddr, "config", "set", "enable-diagnostic", "true"}, nil) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index ebf0a4e574d..b7a428e3683 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -1402,7 +1402,7 @@ func putRegionWithLeader(re *require.Assertions, rc *cluster.RaftCluster, id id. StartKey: []byte{byte(i)}, EndKey: []byte{byte(i + 1)}, } - rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0])) + rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0], core.SetSource(core.Heartbeat))) } time.Sleep(50 * time.Millisecond) diff --git a/tests/testutil.go b/tests/testutil.go index 613705d3eb6..059a152f06f 100644 --- a/tests/testutil.go +++ b/tests/testutil.go @@ -197,6 +197,7 @@ func MustPutRegion(re *require.Assertions, cluster *TestCluster, regionID, store Peers: []*metapb.Peer{leader}, RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, } + opts = append(opts, core.SetSource(core.Heartbeat)) r := core.NewRegionInfo(metaRegion, leader, opts...) 
MustPutRegionInfo(re, cluster, r) return r From 181fdc95be65fd8c83155c76f0a69ddb2cf143bf Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Wed, 15 Nov 2023 14:45:46 +0800 Subject: [PATCH 26/26] makefile: support build with `boringcrypto` to support Fips (#7275) close tikv/pd#7274 Signed-off-by: Cabinfever_B Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- Makefile | 28 +++++++++++++++++++++++----- pkg/versioninfo/fips.go | 26 ++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 pkg/versioninfo/fips.go diff --git a/Makefile b/Makefile index 54ad331aea4..906dd9414f9 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,8 @@ dev-basic: build check basic-test BUILD_FLAGS ?= BUILD_TAGS ?= BUILD_CGO_ENABLED := 0 +BUILD_TOOL_CGO_ENABLED := 0 +BUILD_GOEXPERIMENT ?= PD_EDITION ?= Community # Ensure PD_EDITION is set to Community or Enterprise before running build process. ifneq "$(PD_EDITION)" "Community" @@ -46,6 +48,13 @@ ifeq ($(PLUGIN), 1) BUILD_TAGS += with_plugin endif +ifeq ($(ENABLE_FIPS), 1) + BUILD_TAGS+=boringcrypto + BUILD_GOEXPERIMENT=boringcrypto + BUILD_CGO_ENABLED := 1 + BUILD_TOOL_CGO_ENABLED := 1 +endif + LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDReleaseVersion=$(shell git describe --tags --dirty --always)" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDBuildTS=$(shell date -u '+%Y-%m-%d %I:%M:%S')" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDGitHash=$(shell git rev-parse HEAD)" @@ -66,6 +75,8 @@ BUILD_BIN_PATH := $(ROOT_PATH)/bin build: pd-server pd-ctl pd-recover +build-fips: pd-server-fips pd-ctl-fips pd-recover-fips + tools: pd-tso-bench pd-heartbeat-bench regions-dump stores-dump pd-api-bench PD_SERVER_DEP := @@ -79,7 +90,7 @@ endif PD_SERVER_DEP += dashboard-ui pd-server: ${PD_SERVER_DEP} - CGO_ENABLED=$(BUILD_CGO_ENABLED) go build $(BUILD_FLAGS) -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -tags "$(BUILD_TAGS)" -o $(BUILD_BIN_PATH)/pd-server cmd/pd-server/main.go + GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_CGO_ENABLED) go build $(BUILD_FLAGS) -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -tags "$(BUILD_TAGS)" -o $(BUILD_BIN_PATH)/pd-server cmd/pd-server/main.go pd-server-failpoint: @$(FAILPOINT_ENABLE) @@ -89,18 +100,25 @@ pd-server-failpoint: pd-server-basic: SWAGGER=0 DASHBOARD=0 $(MAKE) pd-server -.PHONY: build tools pd-server pd-server-basic +pd-server-fips: + ENABLE_FIPS=1 $(MAKE) pd-server + +.PHONY: build tools pd-server pd-server-basic pd-server-fips # Tools pd-ctl: - CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-ctl tools/pd-ctl/main.go + GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-ctl tools/pd-ctl/main.go +pd-ctl-fips: + ENABLE_FIPS=1 $(MAKE) pd-ctl pd-tso-bench: cd tools/pd-tso-bench && CGO_ENABLED=0 go build -o $(BUILD_BIN_PATH)/pd-tso-bench main.go pd-api-bench: cd tools/pd-api-bench && CGO_ENABLED=0 go build -o $(BUILD_BIN_PATH)/pd-api-bench main.go pd-recover: - CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-recover tools/pd-recover/main.go + GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-recover tools/pd-recover/main.go +pd-recover-fips: + ENABLE_FIPS=1 $(MAKE) pd-recover pd-analysis: CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-analysis 
tools/pd-analysis/main.go pd-heartbeat-bench: @@ -112,7 +130,7 @@ regions-dump: stores-dump: CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/stores-dump tools/stores-dump/main.go -.PHONY: pd-ctl pd-tso-bench pd-recover pd-analysis pd-heartbeat-bench simulator regions-dump stores-dump pd-api-bench +.PHONY: pd-ctl pd-ctl-fips pd-tso-bench pd-recover pd-recover-fips pd-analysis pd-heartbeat-bench simulator regions-dump stores-dump pd-api-bench #### Docker image #### diff --git a/pkg/versioninfo/fips.go b/pkg/versioninfo/fips.go new file mode 100644 index 00000000000..02478b103fa --- /dev/null +++ b/pkg/versioninfo/fips.go @@ -0,0 +1,26 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build boringcrypto +// +build boringcrypto + +package versioninfo + +import ( + _ "crypto/tls/fipsonly" +) + +func init() { + PDReleaseVersion += "-fips" +}
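
The new pkg/versioninfo/fips.go above only tags the reported release version with a "-fips" suffix; it does not by itself show that the BoringCrypto module is actually linked into a given binary. As a minimal, illustrative sketch (not part of this patch series), the standard library's crypto/boring package, which the Go toolchain only provides when GOEXPERIMENT=boringcrypto sets the boringcrypto build tag, can be used to double-check that at runtime:

//go:build boringcrypto
// +build boringcrypto

package main

import (
	"crypto/boring"
	"fmt"
)

func main() {
	// boring.Enabled reports whether the BoringCrypto module is handling
	// the supported crypto primitives in this binary.
	fmt.Println("boringcrypto enabled:", boring.Enabled())
}

Built the same way `make build-fips` builds PD (GOEXPERIMENT=boringcrypto with CGO_ENABLED=1, per the Makefile changes above), this sketch should print true on platforms where Go+BoringCrypto is supported; a default build will not compile the file at all because of the build tag, which mirrors how fips.go itself is gated.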