diff --git a/v3/bind.go b/v3/bind.go index a37f8e2c..5d08715b 100644 --- a/v3/bind.go +++ b/v3/bind.go @@ -6,11 +6,11 @@ import ( enchex "encoding/hex" "errors" "fmt" + "github.com/Azure/go-ntlmssp" "io/ioutil" "math/rand" "strings" - "github.com/Azure/go-ntlmssp" ber "github.com/go-asn1-ber/asn1-ber" ) @@ -733,3 +733,115 @@ RESP: return nil, GetLDAPError(packet) } + +// ServerBindStep sends the SASLBindRequest and return the result code, servercred and errors. +// If the result code is not SUCCESS(0x00) or IN_PROGRESS(0x0e), it's also considered as an error. +func (l *Conn) ServerBindStep(clientCred []byte, dn string, methodName string, controls []Control) (resultCode uint16, serverCred []byte, err error) { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, dn, "User Name")) + + auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication") + auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, methodName, "SASL Mech")) + auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, string(clientCred), "Credentials")) + request.AppendChild(auth) + packet.AppendChild(request) + + if len(controls) > 0 { + packet.AppendChild(encodeControls(controls)) + } + + msgCtx, err := l.sendMessage(packet) + if err != nil { + return 0, nil, err + } + defer l.finishMessage(msgCtx) + + packetResponse, ok := <-msgCtx.responses + if !ok { + return 0, nil, errors.New("ldap: response channel closed") + } + packet, err = packetResponse.ReadPacket() + if err != nil { + return 0, nil, err + } + + bindResponse, err := parseBindResponse(packet) + if err != nil { + return 0, nil, err + } + + // TODO: add global constants. As this package used `uint16` to represent the result code nearly everywhere, + // these constants were + const LDAPResultCodeSuccess = 0 + const LDAPResultCodeSASLBindInProgress = 0x0e + if bindResponse.resultCode != LDAPResultCodeSuccess && bindResponse.resultCode != LDAPResultCodeSASLBindInProgress { + return bindResponse.resultCode, nil, &Error{ + ResultCode: bindResponse.resultCode, + MatchedDN: bindResponse.matchedDN, + Err: fmt.Errorf("%s", bindResponse.errorMessage), + Packet: packet, + } + } + return bindResponse.resultCode, bindResponse.serverSaslCreds, nil +} + +type bindResponse struct { + resultCode uint16 + matchedDN string + errorMessage string + serverSaslCreds []byte +} + +// parseBindResponse parses the bind response. The format of a BindResponse is like below: +// +// ``` +// +// BindResponse ::= [APPLICATION 1] SEQUENCE { +// COMPONENTS OF LDAPResult, +// serverSaslCreds [7] OCTET STRING OPTIONAL } +// +// LDAPResult ::= SEQUENCE { +// resultCode ENUMERATED, +// matchedDN LDAPDN, +// errorMessage LDAPString, +// referral [3] Referral OPTIONAL } +// +// ``` +// +// TODO: support `referral` field in this function +func parseBindResponse(packet *ber.Packet) (*bindResponse, error) { + if packet == nil { + return nil, &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty packet")} + } + if len(packet.Children) < 2 { + return nil, &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format"), Packet: packet} + } + + response := packet.Children[1] + if response == nil { + return nil, &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet} + } + + if !(response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3) { + return nil, &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet} + } + resp := &bindResponse{ + uint16(response.Children[0].Value.(int64)), + response.Children[1].Value.(string), + response.Children[2].Value.(string), + nil, + } + if len(response.Children) < 4 { + return resp, nil + } + // then the response.Children[3] can be an referral or serverSaslCreds. It can be asserted with tag + // TODO: also add referral + if response.Children[3].ClassType == ber.ClassContext && response.Children[3].Tag == ber.TagObjectDescriptor { + resp.serverSaslCreds = response.Children[3].Data.Bytes() + } + return resp, nil +}