Skip to content

Commit

Permalink
support mtproto conn type 0xee. fixes #1297
Browse files Browse the repository at this point in the history
  • Loading branch information
DarienRaymond committed Oct 11, 2018
1 parent d839595 commit 2e94561
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 11 deletions.
47 changes: 40 additions & 7 deletions proxy/mtproto/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mtproto

import (
"context"
"crypto/rand"
"crypto/sha256"
"io"
Expand All @@ -13,6 +14,35 @@ const (
HeaderSize = 64
)

type SessionContext struct {
ConnectionType [4]byte
DataCenterID uint16
}

func DefaultSessionContext() SessionContext {
return SessionContext{
ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef},
DataCenterID: 0,
}
}

type contextKey int32

const (
sessionContextKey contextKey = iota
)

func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
return context.WithValue(ctx, sessionContextKey, c)
}

func SessionContextFromContext(ctx context.Context) SessionContext {
if c := ctx.Value(sessionContextKey); c != nil {
return c.(SessionContext)
}
return DefaultSessionContext()
}

type Authentication struct {
Header [HeaderSize]byte
DecodingKey [32]byte
Expand All @@ -29,12 +59,18 @@ func (a *Authentication) DataCenterID() uint16 {
return uint16(x) - 1
}

func (a *Authentication) ConnectionType() [4]byte {
var x [4]byte
copy(x[:], a.Header[56:60])
return x
}

func (a *Authentication) ApplySecret(b []byte) {
a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...))
a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...))
}

func generateRandomBytes(random []byte) {
func generateRandomBytes(random []byte, connType [4]byte) {
for {
common.Must2(rand.Read(random))

Expand All @@ -51,19 +87,16 @@ func generateRandomBytes(random []byte) {
continue
}

random[56] = 0xef
random[57] = 0xef
random[58] = 0xef
random[59] = 0xef
copy(random[56:60], connType[:])

return
}
}

func NewAuthentication() *Authentication {
func NewAuthentication(sc SessionContext) *Authentication {
auth := getAuthenticationObject()
random := auth.Header[:]
generateRandomBytes(random)
generateRandomBytes(random, sc.ConnectionType)
copy(auth.EncodingKey[:], random[8:])
copy(auth.EncodingNonce[:], random[8+32:])
keyivInverse := Inverse(random[8 : 8+32+16])
Expand Down
2 changes: 1 addition & 1 deletion proxy/mtproto/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestInverse(t *testing.T) {
func TestAuthenticationReadWrite(t *testing.T) {
assert := With(t)

a := NewAuthentication()
a := NewAuthentication(DefaultSessionContext())
b := bytes.NewReader(a.Header[:])
a2, err := ReadAuthentication(b)
assert(err, IsNil)
Expand Down
3 changes: 2 additions & 1 deletion proxy/mtproto/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
}
defer conn.Close() // nolint: errcheck

auth := NewAuthentication()
sc := SessionContextFromContext(ctx)
auth := NewAuthentication(sc)
defer putAuthenticationObject(auth)

request := func() error {
Expand Down
21 changes: 19 additions & 2 deletions proxy/mtproto/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ func (s *Server) Network() net.NetworkList {
}
}

func isValidConnectionType(c [4]byte) bool {
if compare.BytesAll(c[:], 0xef) {
return true
}
if compare.BytesAll(c[:], 0xee) {
return true
}
return false
}

func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher core.Dispatcher) error {
sPolicy := s.policy.ForLevel(s.user.Level)

Expand All @@ -85,8 +95,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:])
decryptor.XORKeyStream(auth.Header[:], auth.Header[:])

if !compare.BytesAll(auth.Header[56:60], 0xef) {
return newError("invalid connection type: ", auth.Header[56:60])
ct := auth.ConnectionType()
if !isValidConnectionType(ct) {
return newError("invalid connection type: ", ct)
}

dcID := auth.DataCenterID()
Expand All @@ -104,6 +115,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
timer := signal.CancelAfterInactivity(ctx, cancel, sPolicy.Timeouts.ConnectionIdle)
ctx = core.ContextWithBufferPolicy(ctx, sPolicy.Buffer)

sc := SessionContext{
ConnectionType: ct,
DataCenterID: dcID,
}
ctx = ContextWithSessionContext(ctx, sc)

link, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return newError("failed to dispatch request to: ", dest).Base(err)
Expand Down

0 comments on commit 2e94561

Please sign in to comment.