diff --git a/proxy/mtproto/auth.go b/proxy/mtproto/auth.go index dd86623984..b0c53998af 100644 --- a/proxy/mtproto/auth.go +++ b/proxy/mtproto/auth.go @@ -1,6 +1,7 @@ package mtproto import ( + "context" "crypto/rand" "crypto/sha256" "io" @@ -13,6 +14,35 @@ const ( HeaderSize = 64 ) +type SessionContext struct { + ConnectionType [4]byte + DataCenterID uint16 +} + +func DefaultSessionContext() SessionContext { + return SessionContext{ + ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef}, + DataCenterID: 0, + } +} + +type contextKey int32 + +const ( + sessionContextKey contextKey = iota +) + +func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context { + return context.WithValue(ctx, sessionContextKey, c) +} + +func SessionContextFromContext(ctx context.Context) SessionContext { + if c := ctx.Value(sessionContextKey); c != nil { + return c.(SessionContext) + } + return DefaultSessionContext() +} + type Authentication struct { Header [HeaderSize]byte DecodingKey [32]byte @@ -29,12 +59,18 @@ func (a *Authentication) DataCenterID() uint16 { return uint16(x) - 1 } +func (a *Authentication) ConnectionType() [4]byte { + var x [4]byte + copy(x[:], a.Header[56:60]) + return x +} + func (a *Authentication) ApplySecret(b []byte) { a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...)) a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...)) } -func generateRandomBytes(random []byte) { +func generateRandomBytes(random []byte, connType [4]byte) { for { common.Must2(rand.Read(random)) @@ -51,19 +87,16 @@ func generateRandomBytes(random []byte) { continue } - random[56] = 0xef - random[57] = 0xef - random[58] = 0xef - random[59] = 0xef + copy(random[56:60], connType[:]) return } } -func NewAuthentication() *Authentication { +func NewAuthentication(sc SessionContext) *Authentication { auth := getAuthenticationObject() random := auth.Header[:] - generateRandomBytes(random) + generateRandomBytes(random, sc.ConnectionType) copy(auth.EncodingKey[:], random[8:]) copy(auth.EncodingNonce[:], random[8+32:]) keyivInverse := Inverse(random[8 : 8+32+16]) diff --git a/proxy/mtproto/auth_test.go b/proxy/mtproto/auth_test.go index 6f392fe1e7..8f97a8d9d0 100644 --- a/proxy/mtproto/auth_test.go +++ b/proxy/mtproto/auth_test.go @@ -32,7 +32,7 @@ func TestInverse(t *testing.T) { func TestAuthenticationReadWrite(t *testing.T) { assert := With(t) - a := NewAuthentication() + a := NewAuthentication(DefaultSessionContext()) b := bytes.NewReader(a.Header[:]) a2, err := ReadAuthentication(b) assert(err, IsNil) diff --git a/proxy/mtproto/client.go b/proxy/mtproto/client.go index 058e19e537..d2f37b98fb 100644 --- a/proxy/mtproto/client.go +++ b/proxy/mtproto/client.go @@ -36,7 +36,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial } defer conn.Close() // nolint: errcheck - auth := NewAuthentication() + sc := SessionContextFromContext(ctx) + auth := NewAuthentication(sc) defer putAuthenticationObject(auth) request := func() error { diff --git a/proxy/mtproto/server.go b/proxy/mtproto/server.go index a6f980057d..108e445807 100644 --- a/proxy/mtproto/server.go +++ b/proxy/mtproto/server.go @@ -64,6 +64,16 @@ func (s *Server) Network() net.NetworkList { } } +func isValidConnectionType(c [4]byte) bool { + if compare.BytesAll(c[:], 0xef) { + return true + } + if compare.BytesAll(c[:], 0xee) { + return true + } + return false +} + func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher core.Dispatcher) error { sPolicy := s.policy.ForLevel(s.user.Level) @@ -85,8 +95,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:]) decryptor.XORKeyStream(auth.Header[:], auth.Header[:]) - if !compare.BytesAll(auth.Header[56:60], 0xef) { - return newError("invalid connection type: ", auth.Header[56:60]) + ct := auth.ConnectionType() + if !isValidConnectionType(ct) { + return newError("invalid connection type: ", ct) } dcID := auth.DataCenterID() @@ -104,6 +115,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet timer := signal.CancelAfterInactivity(ctx, cancel, sPolicy.Timeouts.ConnectionIdle) ctx = core.ContextWithBufferPolicy(ctx, sPolicy.Buffer) + sc := SessionContext{ + ConnectionType: ct, + DataCenterID: dcID, + } + ctx = ContextWithSessionContext(ctx, sc) + link, err := dispatcher.Dispatch(ctx, dest) if err != nil { return newError("failed to dispatch request to: ", dest).Base(err)