diff --git a/drivers/shared/executor/executor.go b/drivers/shared/executor/executor.go index 7f42ec06c7e..396d5919764 100644 --- a/drivers/shared/executor/executor.go +++ b/drivers/shared/executor/executor.go @@ -361,7 +361,7 @@ func (e *UniversalExecutor) Launch(command *ExecCommand) (*ProcessState, error) // setting the user of the process if command.User != "" { e.logger.Debug("running command as user", "user", command.User) - if err := setCmdUser(&e.childCmd, command.User); err != nil { + if err := setCmdUser(e.logger, &e.childCmd, command.User); err != nil { return nil, err } } @@ -527,7 +527,7 @@ func (e *UniversalExecutor) ExecStreaming(ctx context.Context, command []string, }, processStart: func() error { if u := e.command.User; u != "" { - if err := setCmdUser(cmd, u); err != nil { + if err := setCmdUser(e.logger, cmd, u); err != nil { return err } } diff --git a/drivers/shared/executor/executor_basic.go b/drivers/shared/executor/executor_basic.go index 72f1e21ddf0..8877be3488c 100644 --- a/drivers/shared/executor/executor_basic.go +++ b/drivers/shared/executor/executor_basic.go @@ -35,7 +35,7 @@ func withNetworkIsolation(f func() error, _ *drivers.NetworkIsolationSpec) error return f() } -func setCmdUser(*exec.Cmd, string) error { return nil } +func setCmdUser(hclog.Logger, *exec.Cmd, string) error { return nil } func (e *UniversalExecutor) ListProcesses() set.Collection[int] { return procstats.ListByPid(e.childCmd.Process.Pid) diff --git a/drivers/shared/executor/executor_universal_linux.go b/drivers/shared/executor/executor_universal_linux.go index 53d042a6e69..936a77ff7e6 100644 --- a/drivers/shared/executor/executor_universal_linux.go +++ b/drivers/shared/executor/executor_universal_linux.go @@ -12,6 +12,7 @@ import ( "strconv" "syscall" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-set/v3" "github.com/hashicorp/nomad/client/lib/cgroupslib" "github.com/hashicorp/nomad/client/lib/nsutil" @@ -30,7 +31,7 @@ const ( // setCmdUser takes a user id as a string and looks up the user, and sets the command // to execute as that user. -func setCmdUser(cmd *exec.Cmd, userid string) error { +func setCmdUser(logger hclog.Logger, cmd *exec.Cmd, userid string) error { u, err := users.Lookup(userid) if err != nil { return fmt.Errorf("failed to identify user %v: %v", userid, err) diff --git a/drivers/shared/executor/executor_windows.go b/drivers/shared/executor/executor_windows.go index 9523c053634..1a0ea211d01 100644 --- a/drivers/shared/executor/executor_windows.go +++ b/drivers/shared/executor/executor_windows.go @@ -6,7 +6,6 @@ package executor import ( - "errors" "fmt" "os" "os/exec" @@ -19,6 +18,7 @@ import ( "github.com/hashicorp/go-set/v3" "github.com/hashicorp/nomad/client/lib/cpustats" "github.com/hashicorp/nomad/drivers/shared/executor/procstats" + "github.com/hashicorp/nomad/drivers/shared/executor/s4u" "github.com/hashicorp/nomad/plugins/drivers" "golang.org/x/sys/windows" ) @@ -43,12 +43,8 @@ func withNetworkIsolation(f func() error, _ *drivers.NetworkIsolationSpec) error return f() } -func setCmdUser(cmd *exec.Cmd, user string) error { - nameParts := strings.Split(user, "\\") - if len(nameParts) != 2 { - return errors.New("user name must contain domain") - } - token, err := createUserToken(nameParts[0], nameParts[1]) +func setCmdUser(logger hclog.Logger, cmd *exec.Cmd, user string) error { + token, err := createUserToken(logger, user) if err != nil { return fmt.Errorf("failed to create user token: %w", err) } @@ -56,7 +52,7 @@ func setCmdUser(cmd *exec.Cmd, user string) error { if cmd.SysProcAttr == nil { cmd.SysProcAttr = &syscall.SysProcAttr{} } - cmd.SysProcAttr.Token = *token + cmd.SysProcAttr.Token = token runtime.AddCleanup(cmd, func(attr *syscall.SysProcAttr) { _ = attr.Token.Close() @@ -75,7 +71,62 @@ const ( _PROVIDER_DEFAULT uint32 = 0 ) -func createUserToken(domain, username string) (*syscall.Token, error) { +// username can be of the form "domain\username", ".\username" or "username@domain" +func createUserToken(logger hclog.Logger, username string) (syscall.Token, error) { + var token windows.Token + var err error + + var runAsUpn string + if strings.IndexByte(username, '\\') != -1 { + runAsUpn, err = convertUserToUpn(username) + if err != nil { + return 0, fmt.Errorf("failed to convert username %q to UPN : %w", username, err) + } + } else if strings.IndexByte(username, '@') != -1 { + runAsUpn = username + } + + logger.Debug("creating user token", "username", username, "runAsUpn", runAsUpn) + + if runAsUpn != "" { + token, err = s4u.GetDomainS4uToken(runAsUpn) + } else { + token, err = s4u.GetLocalS4uToken(username) + } + if err != nil { + return 0, fmt.Errorf("failed to create S4U token for user : %w", err) + } + + return syscall.Token(token), nil +} + +func convertUserToUpn(username string) (string, error) { + usernameUtf16, err := windows.UTF16FromString(username) + if err != nil { + return "", fmt.Errorf("error converting username to UTF16 : %w", err) + } + + upnUtf16, err := translateSamToUpn(usernameUtf16) + if err != nil { + return "", err + } + + return windows.UTF16ToString(upnUtf16), nil +} + +const MAX_UPN_LEN = 1024 + +func translateSamToUpn(samAccountNameUtf16 []uint16) ([]uint16, error) { + var domainUpnLen uint32 = MAX_UPN_LEN + 1 + domainUpn := make([]uint16, domainUpnLen) + err := windows.TranslateName(&samAccountNameUtf16[0], windows.NameSamCompatible, windows.NameUserPrincipal, &domainUpn[0], &domainUpnLen) + if err != nil { + return nil, err + } + return domainUpn[:domainUpnLen-1], nil +} + +func createUserTokenOld(username, domain string) (*syscall.Token, error) { userw, err := syscall.UTF16PtrFromString(username) if err != nil { return nil, fmt.Errorf("failed to convert username to UTF-16: %w", err) diff --git a/drivers/shared/executor/s4u/lsa.go b/drivers/shared/executor/s4u/lsa.go new file mode 100644 index 00000000000..0ffc6bf134d --- /dev/null +++ b/drivers/shared/executor/s4u/lsa.go @@ -0,0 +1,203 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build windows + +package s4u + +import ( + "log" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + modsecur32 = windows.NewLazySystemDLL("secur32.dll") + + procLsaRegisterLogonProcess = modsecur32.NewProc("LsaRegisterLogonProcess") + procLsaDeregisterLogonProcess = modsecur32.NewProc("LsaDeregisterLogonProcess") + procLsaLookupAuthenticationPackage = modsecur32.NewProc("LsaLookupAuthenticationPackage") + procLsaLogonUser = modsecur32.NewProc("LsaLogonUser") + procLsaFreeReturnBuffer = modsecur32.NewProc("LsaFreeReturnBuffer") +) + +type LSA_STRING struct { + Length uint16 + MaximumLength uint16 + Buffer *byte +} + +type LSA_UNICODE_STRING struct { + Length uint16 + MaximumLength uint16 + Buffer *uint16 +} + +func mustStringToLsaString(s string) *LSA_STRING { + var result LSA_STRING + var err error + result.Length = uint16(len(s)) + result.MaximumLength = uint16(len(s)) + result.Buffer, err = windows.BytePtrFromString(s) + if err != nil { + log.Fatal(err) + } + return &result +} + +func LsaRegisterLogonProcess(logonProcessName *LSA_STRING, handle *windows.Handle) error { + var mode uint32 + status, _, _ := syscall.SyscallN( + procLsaRegisterLogonProcess.Addr(), + uintptr(unsafe.Pointer(logonProcessName)), + uintptr(unsafe.Pointer(handle)), + uintptr(unsafe.Pointer(&mode))) + + if status != 0 { + return windows.NTStatus(status) + } + return nil +} + +func LsaDeregisterLogonProcess(hnd windows.Handle) error { + status, _, _ := syscall.SyscallN(procLsaDeregisterLogonProcess.Addr(), uintptr(hnd)) + if status != 0 { + return windows.NTStatus(status) + } + return nil +} + +func LsaLookupAuthenticationPackage(hnd windows.Handle, packageName *LSA_STRING, authPackageId *uint32) error { + status, _, _ := syscall.SyscallN( + procLsaLookupAuthenticationPackage.Addr(), + uintptr(hnd), + uintptr(unsafe.Pointer(packageName)), + uintptr(unsafe.Pointer(authPackageId))) + + if status != 0 { + return windows.NTStatus(status) + } + return nil +} + +type SecurityLogonType uint32 + +const ( + LogonTypeUndefined = 0 + LogonTypeInteractive = 1 + iota + LogonTypeNetwork + LogonTypeBatch + LogonTypeService + LogonTypeProxy + LogonTypeUnlock + LogonTypeNetworkCleartext + LogonTypeNewCredentials + LogonTypeRemoteInteractive + LogonTypeCachedInteractive + LogonTypeCachedRemoteInteractive + LogonTypeCachedUnlock +) + +type TokenGroups struct { + PrivilegeCount uint32 + Privileges [1]windows.LUIDAndAttributes +} +type TokenSource struct { + SourceName [8]byte + SourceId windows.LUID +} +type QuotaLimits struct { + PagedPoolLimit uintptr + NonPagedPoolLimit uintptr + MinimumWorkingSetSize uintptr + MaximumWorkingSetSize uintptr + PagefileLimit uintptr + TimeLimit uint64 +} + +func LsaLogonUser(hnd windows.Handle, + originName *LSA_STRING, + logonType SecurityLogonType, + authPackageId uint32, + accountInformation *byte, accountInformationLength uint32, + tokenGroups *TokenGroups, + tokenSource *TokenSource, + profileBuffer **byte, profileBufferLength *uint32, + logonId *windows.LUID, + token *windows.Token, + quotas *QuotaLimits, + subStatus *windows.NTStatus) error { + status, _, _ := syscall.SyscallN( + procLsaLogonUser.Addr(), + uintptr(hnd), + uintptr(unsafe.Pointer(originName)), + uintptr(logonType), + uintptr(authPackageId), + uintptr(unsafe.Pointer(accountInformation)), + uintptr(accountInformationLength), + uintptr(unsafe.Pointer(tokenGroups)), + uintptr(unsafe.Pointer(tokenSource)), + uintptr(unsafe.Pointer(profileBuffer)), + uintptr(unsafe.Pointer(profileBufferLength)), + uintptr(unsafe.Pointer(logonId)), + uintptr(unsafe.Pointer(token)), + uintptr(unsafe.Pointer(quotas)), + uintptr(unsafe.Pointer(subStatus))) + + if status != 0 { + return windows.NTStatus(status) + } + return nil +} + +type MSV1_0_S4U_LOGON struct { + MessageType MSV1_0_LOGON_SUBMIT_TYPE + Flags uint32 + UserPrincipalName LSA_UNICODE_STRING // username or username@domain + DomainName LSA_UNICODE_STRING // Optional: if missing, using the local machine +} + +type MSV1_0_LOGON_SUBMIT_TYPE uint32 + +const ( + MsV1_0InteractiveLogon = 2 + MsV1_0Lm20Logon = 3 + MsV1_0NetworkLogon = 4 + MsV1_0SubAuthLogon = 5 + MsV1_0WorkstationUnlockLogon = 7 + MsV1_0S4ULogon = 12 + MsV1_0VirtualLogon = 82 +) + +type KERB_S4U_LOGON struct { + MessageType KERB_LOGON_SUBMIT_TYPE + Flags uint32 + ClientUpn LSA_UNICODE_STRING + ClientRealm LSA_UNICODE_STRING +} + +type KERB_LOGON_SUBMIT_TYPE uint32 + +const ( + KerbInteractiveLogon = 2 + KerbSmartCardLogon = 6 + KerbWorkstationUnlockLogon = 7 + KerbSmartCardUnlockLogon = 8 + KerbProxyLogon = 9 + KerbTicketLogon = 10 + KerbTicketUnlockLogon = 11 + KerbS4ULogon = 12 + KerbCertificateLogon = 13 + KerbCertificateS4ULogon = 14 + KerbCertificateUnlockLogon = 15 +) + +func LsaFreeReturnBuffer(buff *byte) error { + status, _, _ := syscall.SyscallN(procLsaFreeReturnBuffer.Addr(), uintptr(unsafe.Pointer(buff))) + if status != 0 { + return windows.NTStatus(status) + } + return nil +} diff --git a/drivers/shared/executor/s4u/misc.go b/drivers/shared/executor/s4u/misc.go new file mode 100644 index 00000000000..d7fdedc4b28 --- /dev/null +++ b/drivers/shared/executor/s4u/misc.go @@ -0,0 +1,25 @@ +//go:build windows +// +build windows + +package s4u + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + + procAllocateLocallyUniqueId = modadvapi32.NewProc("AllocateLocallyUniqueId") +) + +func AllocateLocallyUniqueId(result *windows.LUID) error { + r1, _, e1 := syscall.SyscallN(procAllocateLocallyUniqueId.Addr(), uintptr(unsafe.Pointer(result))) + if r1 == 0 { + return e1 + } + return nil +} diff --git a/drivers/shared/executor/s4u/s4u.go b/drivers/shared/executor/s4u/s4u.go new file mode 100644 index 00000000000..d5f0c3570df --- /dev/null +++ b/drivers/shared/executor/s4u/s4u.go @@ -0,0 +1,224 @@ +//go:build windows +// +build windows + +package s4u + +import ( + "fmt" + "math" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Best intro to Windows S4U is probably here: +// https://learn.microsoft.com/en-us/archive/msdn-magazine/2003/april/exploring-s4u-kerberos-extensions-in-windows-server-2003 + +const MSV1_0_PACKAGE_NAME = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0" +const MICROSOFT_KERBEROS_NAME = "Kerberos" + +// From the LsaLogonUser documentation: +// A token source identifies the source module — for example, +// the session manager—and the context that may be useful to that module. +// This information is included in the user token and can be retrieved by +// calling GetTokenInformation. +const TOKEN_SOURCE_NAME = "nomad_raw_exec" + +// From the LsaRegisterLogonProcess documentation: +// Name identifying the logon application. +// This should be a printable name suitable for display to administrators. +// For example, the Windows logon application might use the name "User32LogonProcess". +// This name is used by the LSA during auditing. +const LOGON_PROCESS_NAME = "nomad_raw_exec" + +// A string that identifies the origin of the logon attempt +// From the LsaLogonUser documentation: +// +// The OriginName parameter should specify meaningful information. +// For example, it might contain "TTY1" to indicate terminal one or +// "NTLM - remote node JAZZ" to indicate a network logon that uses +// NTLM through a remote node called "JAZZ". +// +// Not overly helpful. +const LOGON_ORIGIN_NAME = "nomad_raw_exec" + +func createTokenSource(sourceName string) (TokenSource, error) { + var result TokenSource + if err := AllocateLocallyUniqueId(&result.SourceId); err != nil { + return result, err + } + + for i, c := range []byte(sourceName) { + if i == 8 { + break + } + result.SourceName[i] = c + } + return result, nil +} + +func copyUtf16ToBytes(s []uint16, stringbuf []byte) { + for i, c := range s { + *(*uint16)(unsafe.Pointer(&stringbuf[i*2])) = c + } +} + +func buildContiguousLsaUnicodeString(result *LSA_UNICODE_STRING, s []uint16, stringbuf []byte) error { + bytelen := len(s) * 2 + if bytelen > math.MaxUint16 { + return fmt.Errorf("String too long to for API : %d", bytelen) + } + if bytelen > len(stringbuf) { + return fmt.Errorf("Insufficient buffer space, require %d, got %d", bytelen, len(stringbuf)) + } + + copyUtf16ToBytes(s, stringbuf) + + result.Length = uint16(bytelen) + result.MaximumLength = uint16(bytelen) + result.Buffer = (*uint16)(unsafe.Pointer(&stringbuf[0])) + return nil +} + +func sizeofMSV1_0_S4U_LOGON() uintptr { + var temp MSV1_0_S4U_LOGON + return unsafe.Sizeof(temp) +} + +func buildLocalS4uLogonInfo(username string) ([]byte, uint32, error) { + usernameUtf16 := utf16.Encode([]rune(username)) + usernameUtf16Bytelen := len(usernameUtf16) * 2 + domainUtf16 := utf16.Encode([]rune(".")) + domainUf16Bytelen := len(usernameUtf16) * 2 + + accountInformationLength := sizeofMSV1_0_S4U_LOGON() + uintptr(usernameUtf16Bytelen+domainUf16Bytelen) + accountInformation := make([]byte, accountInformationLength) + + var offset uintptr = 0 + s4uLogon := (*MSV1_0_S4U_LOGON)(unsafe.Pointer(&accountInformation[offset])) + s4uLogon.MessageType = MsV1_0S4ULogon + offset += unsafe.Sizeof(*s4uLogon) + + err := buildContiguousLsaUnicodeString(&s4uLogon.UserPrincipalName, usernameUtf16, accountInformation[offset:]) + if err != nil { + return nil, 0, fmt.Errorf("Error building UserPrincipalName buffer : %w", err) + } + offset += uintptr(usernameUtf16Bytelen) + + err = buildContiguousLsaUnicodeString(&s4uLogon.DomainName, domainUtf16, accountInformation[offset:]) + if err != nil { + return nil, 0, fmt.Errorf("Error building DomainName buffer : %w", err) + } + offset += uintptr(domainUf16Bytelen) + + return accountInformation, uint32(offset), nil +} + +func sizeofKERB_S4U_LOGON() uintptr { + var temp KERB_S4U_LOGON + return unsafe.Sizeof(temp) +} + +func buildDomainS4uLogonInfo(userUpn string) ([]byte, uint32, error) { + upnUtf16 := utf16.Encode([]rune(userUpn)) + upnUtf16ByteLen := len(upnUtf16) * 2 + + accountInformationLength := sizeofKERB_S4U_LOGON() + uintptr(len(userUpn)*2) + accountInformation := make([]byte, accountInformationLength) + + var offset uintptr = 0 + s4uLogon := (*KERB_S4U_LOGON)(unsafe.Pointer(&accountInformation[offset])) + s4uLogon.MessageType = MsV1_0S4ULogon + offset += unsafe.Sizeof(*s4uLogon) + + s4uLogon.ClientUpn.Length = uint16(upnUtf16ByteLen) + s4uLogon.ClientUpn.MaximumLength = uint16(upnUtf16ByteLen) + s4uLogon.ClientUpn.Buffer = (*uint16)(unsafe.Pointer(&accountInformation[offset])) + + copyUtf16ToBytes(upnUtf16, accountInformation[offset:]) + offset += uintptr(upnUtf16ByteLen) + + return accountInformation, uint32(offset), nil +} + +func lsaLogonUser(logonProcessHnd windows.Handle, authPackageId uint32, accountInformation []byte, accountInformationLength uint32) (result windows.Token, err error) { + var profileLen uint32 + var profileBuffer *byte + var logonId windows.LUID + var quotas QuotaLimits + var substatus windows.NTStatus + var tokenSource TokenSource + + if tokenSource, err = createTokenSource(TOKEN_SOURCE_NAME); err != nil { + return result, fmt.Errorf("Error creating token source : %w", err) + } + + err = LsaLogonUser(logonProcessHnd, + mustStringToLsaString("nomad"), + LogonTypeNetwork, + authPackageId, + &accountInformation[0], + accountInformationLength, + nil, + &tokenSource, + &profileBuffer, + &profileLen, + &logonId, + &result, + "as, + &substatus) + if err != nil { + return result, fmt.Errorf("Error calling LsaLogonUser : %w, substatus : %v", err, substatus) + } + + _ = LsaFreeReturnBuffer(profileBuffer) + + return result, nil +} + +func GetLocalS4uToken(username string) (result windows.Token, err error) { + var logonProcessHnd windows.Handle + if err = LsaRegisterLogonProcess(mustStringToLsaString(LOGON_PROCESS_NAME), &logonProcessHnd); err != nil { + return result, fmt.Errorf("Error from LsaRegisterLogonProcess : %w", err) + } + + defer func() { + _ = LsaDeregisterLogonProcess(logonProcessHnd) + }() + + var authPackageId uint32 + if err := LsaLookupAuthenticationPackage(logonProcessHnd, mustStringToLsaString(MSV1_0_PACKAGE_NAME), &authPackageId); err != nil { + return result, fmt.Errorf("Error from LsaLookupAuthenticationPackage : %w", err) + } + + accountInformation, accountInformationLength, err := buildLocalS4uLogonInfo(username) + if err != nil { + return result, fmt.Errorf("Error building account information buffer : %w", err) + } + + return lsaLogonUser(logonProcessHnd, authPackageId, accountInformation, accountInformationLength) +} + +func GetDomainS4uToken(upn string) (result windows.Token, err error) { + var logonProcessHnd windows.Handle + if err = LsaRegisterLogonProcess(mustStringToLsaString(LOGON_PROCESS_NAME), &logonProcessHnd); err != nil { + return result, fmt.Errorf("Error from LsaRegisterLogonProcess : %w", err) + } + + defer func() { + _ = LsaDeregisterLogonProcess(logonProcessHnd) + }() + + var authPackageId uint32 + if err := LsaLookupAuthenticationPackage(logonProcessHnd, mustStringToLsaString(MICROSOFT_KERBEROS_NAME), &authPackageId); err != nil { + return result, fmt.Errorf("Error from LsaLookupAuthenticationPackage : %w", err) + } + + accountInformation, accountInformationLength, err := buildDomainS4uLogonInfo(upn) + if err != nil { + return result, fmt.Errorf("Error building account information buffer : %w", err) + } + + return lsaLogonUser(logonProcessHnd, authPackageId, accountInformation, accountInformationLength) +}