Skip to content

Commit

Permalink
ssh: return informations to client (fatedier#3821)
Browse files Browse the repository at this point in the history
  • Loading branch information
fatedier authored Dec 1, 2023
1 parent 6d9e0c2 commit 95cf418
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 51 deletions.
20 changes: 15 additions & 5 deletions client/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ func init() {
crypto.DefaultSalt = "frp"
}

type cancelErr struct {
Err error
}

func (e cancelErr) Error() string {
return e.Err.Error()
}

// ServiceOptions contains options for creating a new client service.
type ServiceOptions struct {
Common *v1.ClientCommonConfig
Expand Down Expand Up @@ -108,7 +116,7 @@ type Service struct {
// service context
ctx context.Context
// call cancel to stop service
cancel context.CancelFunc
cancel context.CancelCauseFunc
gracefulShutdownDuration time.Duration

connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
Expand Down Expand Up @@ -145,7 +153,7 @@ func NewService(options ServiceOptions) (*Service, error) {
}

func (svr *Service) Run(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancelCause(ctx)
svr.ctx = xlog.NewContext(ctx, xlog.FromContextSafe(ctx))
svr.cancel = cancel

Expand All @@ -157,7 +165,9 @@ func (svr *Service) Run(ctx context.Context) error {
// first login to frps
svr.loopLoginUntilSuccess(10*time.Second, lo.FromPtr(svr.common.LoginFailExit))
if svr.ctl == nil {
return fmt.Errorf("the process exited because the first login to the server failed, and the loginFailExit feature is enabled")
cancelCause := cancelErr{}
_ = errors.As(context.Cause(svr.ctx), &cancelCause)
return fmt.Errorf("login to the server failed: %v. With loginFailExit enabled, no additional retries will be attempted", cancelCause.Err)
}

go svr.keepControllerWorking()
Expand Down Expand Up @@ -280,7 +290,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
if err != nil {
xl.Warn("connect to server error: %v", err)
if firstLoginExit {
svr.cancel()
svr.cancel(cancelErr{Err: err})
}
return err
}
Expand Down Expand Up @@ -356,7 +366,7 @@ func (svr *Service) Close() {

func (svr *Service) GracefulClose(d time.Duration) {
svr.gracefulShutdownDuration = d
svr.cancel()
svr.cancel(nil)
}

func (svr *Service) stop() {
Expand Down
79 changes: 53 additions & 26 deletions pkg/config/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@ import (
"github.com/fatedier/frp/pkg/config/v1/validation"
)

type RegisterFlagOption func(*registerFlagOptions)

type registerFlagOptions struct {
sshMode bool
}

func WithSSHMode() RegisterFlagOption {
return func(o *registerFlagOptions) {
o.sshMode = true
}
}

type BandwidthQuantityFlag struct {
V *types.BandwidthQuantity
}
Expand All @@ -41,8 +53,9 @@ func (f *BandwidthQuantityFlag) Type() string {
return "string"
}

func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer) {
registerProxyBaseConfigFlags(cmd, c.GetBaseConfig())
func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer, opts ...RegisterFlagOption) {
registerProxyBaseConfigFlags(cmd, c.GetBaseConfig(), opts...)

switch cc := c.(type) {
case *v1.TCPProxyConfig:
cmd.Flags().IntVarP(&cc.RemotePort, "remote_port", "r", 0, "remote port")
Expand Down Expand Up @@ -73,17 +86,25 @@ func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer) {
}
}

func registerProxyBaseConfigFlags(cmd *cobra.Command, c *v1.ProxyBaseConfig) {
func registerProxyBaseConfigFlags(cmd *cobra.Command, c *v1.ProxyBaseConfig, opts ...RegisterFlagOption) {
if c == nil {
return
}
options := &registerFlagOptions{}
for _, opt := range opts {
opt(options)
}

cmd.Flags().StringVarP(&c.Name, "proxy_name", "n", "", "proxy name")
cmd.Flags().StringVarP(&c.LocalIP, "local_ip", "i", "127.0.0.1", "local ip")
cmd.Flags().IntVarP(&c.LocalPort, "local_port", "l", 0, "local port")
cmd.Flags().BoolVarP(&c.Transport.UseEncryption, "ue", "", false, "use encryption")
cmd.Flags().BoolVarP(&c.Transport.UseCompression, "uc", "", false, "use compression")
cmd.Flags().StringVarP(&c.Transport.BandwidthLimitMode, "bandwidth_limit_mode", "", types.BandwidthLimitModeClient, "bandwidth limit mode")
cmd.Flags().VarP(&BandwidthQuantityFlag{V: &c.Transport.BandwidthLimit}, "bandwidth_limit", "", "bandwidth limit (e.g. 100KB or 1MB)")

if !options.sshMode {
cmd.Flags().StringVarP(&c.LocalIP, "local_ip", "i", "127.0.0.1", "local ip")
cmd.Flags().IntVarP(&c.LocalPort, "local_port", "l", 0, "local port")
cmd.Flags().BoolVarP(&c.Transport.UseEncryption, "ue", "", false, "use encryption")
cmd.Flags().BoolVarP(&c.Transport.UseCompression, "uc", "", false, "use compression")
cmd.Flags().StringVarP(&c.Transport.BandwidthLimitMode, "bandwidth_limit_mode", "", types.BandwidthLimitModeClient, "bandwidth limit mode")
cmd.Flags().VarP(&BandwidthQuantityFlag{V: &c.Transport.BandwidthLimit}, "bandwidth_limit", "", "bandwidth limit (e.g. 100KB or 1MB)")
}
}

func registerProxyDomainConfigFlags(cmd *cobra.Command, c *v1.DomainConfig) {
Expand All @@ -94,13 +115,13 @@ func registerProxyDomainConfigFlags(cmd *cobra.Command, c *v1.DomainConfig) {
cmd.Flags().StringVarP(&c.SubDomain, "sd", "", "", "sub domain")
}

func RegisterVisitorFlags(cmd *cobra.Command, c v1.VisitorConfigurer) {
registerVisitorBaseConfigFlags(cmd, c.GetBaseConfig())
func RegisterVisitorFlags(cmd *cobra.Command, c v1.VisitorConfigurer, opts ...RegisterFlagOption) {
registerVisitorBaseConfigFlags(cmd, c.GetBaseConfig(), opts...)

// add visitor flags if exist
}

func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig) {
func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig, _ ...RegisterFlagOption) {
if c == nil {
return
}
Expand All @@ -113,21 +134,27 @@ func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig)
cmd.Flags().IntVarP(&c.BindPort, "bind_port", "", 0, "bind port")
}

func RegisterClientCommonConfigFlags(cmd *cobra.Command, c *v1.ClientCommonConfig) {
cmd.PersistentFlags().StringVarP(&c.ServerAddr, "server_addr", "s", "127.0.0.1", "frp server's address")
cmd.PersistentFlags().IntVarP(&c.ServerPort, "server_port", "P", 7000, "frp server's port")
func RegisterClientCommonConfigFlags(cmd *cobra.Command, c *v1.ClientCommonConfig, opts ...RegisterFlagOption) {
options := &registerFlagOptions{}
for _, opt := range opts {
opt(options)
}

if !options.sshMode {
cmd.PersistentFlags().StringVarP(&c.ServerAddr, "server_addr", "s", "127.0.0.1", "frp server's address")
cmd.PersistentFlags().IntVarP(&c.ServerPort, "server_port", "P", 7000, "frp server's port")
cmd.PersistentFlags().StringVarP(&c.Transport.Protocol, "protocol", "p", "tcp",
fmt.Sprintf("optional values are %v", validation.SupportedTransportProtocols))
cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level")
cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "console or file path")
cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log file reversed days")
cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console")
cmd.PersistentFlags().StringVarP(&c.Transport.TLS.ServerName, "tls_server_name", "", "", "specify the custom server name of tls certificate")
cmd.PersistentFlags().StringVarP(&c.DNSServer, "dns_server", "", "", "specify dns server instead of using system default one")
c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls")
}
cmd.PersistentFlags().StringVarP(&c.User, "user", "u", "", "user")
cmd.PersistentFlags().StringVarP(&c.Transport.Protocol, "protocol", "p", "tcp",
fmt.Sprintf("optional values are %v", validation.SupportedTransportProtocols))
cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token")
cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level")
cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "console or file path")
cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log file reversed days")
cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console")
cmd.PersistentFlags().StringVarP(&c.Transport.TLS.ServerName, "tls_server_name", "", "", "specify the custom server name of tls certificate")
cmd.PersistentFlags().StringVarP(&c.DNSServer, "dns_server", "", "", "specify dns server instead of using system default one")

c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls")
}

type PortsRangeSliceFlag struct {
Expand Down Expand Up @@ -185,7 +212,7 @@ func (f *BoolFuncFlag) Type() string {
return "bool"
}

func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) {
func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig, opts ...RegisterFlagOption) {
cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address")
cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port")
cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port")
Expand Down
72 changes: 53 additions & 19 deletions pkg/ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
libio "github.com/fatedier/golib/io"
"github.com/samber/lo"
"github.com/spf13/cobra"
flag "github.com/spf13/pflag"
"golang.org/x/crypto/ssh"

"github.com/fatedier/frp/client/proxy"
Expand Down Expand Up @@ -64,6 +65,7 @@ type TunnelServer struct {
underlyingConn net.Conn
sshConn *ssh.ServerConn
sc *ssh.ServerConfig
firstChannel ssh.Channel

vc *virtual.Client
peerServerListener *netpkg.InternalListener
Expand All @@ -86,16 +88,22 @@ func (s *TunnelServer) Run() error {
if err != nil {
return err
}

s.sshConn = sshConn

addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second)
if err != nil {
return err
}

clientCfg, pc, err := s.parseClientAndProxyConfigurer(addr, extraPayload)
clientCfg, pc, helpMessage, err := s.parseClientAndProxyConfigurer(addr, extraPayload)
if err != nil {
return err
if errors.Is(err, flag.ErrHelp) {
s.writeToClient(helpMessage)
return nil
}
s.writeToClient(err.Error())
return fmt.Errorf("parse flags from ssh client error: %v", err)
}
clientCfg.Complete()
if sshConn.Permissions != nil {
Expand Down Expand Up @@ -142,7 +150,11 @@ func (s *TunnelServer) Run() error {
xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
ctx := xlog.NewContext(context.Background(), xl)
go func() {
_ = s.vc.Run(ctx)
vcErr := s.vc.Run(ctx)
if vcErr != nil {
s.writeToClient(vcErr.Error())
}

// If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed.
// One scenario is that the virtual client exits due to login failure.
s.closeDoneChOnce.Do(func() {
Expand All @@ -153,9 +165,12 @@ func (s *TunnelServer) Run() error {

s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})

if err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
if ps, err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
s.writeToClient(err.Error())
log.Warn("wait proxy status ready error: %v", err)
} else {
// success
s.writeToClient(createSuccessInfo(clientCfg.User, pc, ps))
_ = sshConn.Wait()
}

Expand All @@ -168,6 +183,13 @@ func (s *TunnelServer) Run() error {
return nil
}

func (s *TunnelServer) writeToClient(data string) {
if s.firstChannel == nil {
return
}
_, _ = s.firstChannel.Write([]byte(data + "\n"))
}

func (s *TunnelServer) waitForwardAddrAndExtraPayload(
channels <-chan ssh.NewChannel,
requests <-chan *ssh.Request,
Expand Down Expand Up @@ -225,45 +247,57 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload(
return addr, extraPayload, nil
}

func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, error) {
cmd := &cobra.Command{}
func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, string, error) {
helpMessage := ""
cmd := &cobra.Command{
Use: "ssh v0@{address} [command]",
Short: "ssh v0@{address} [command]",
Run: func(*cobra.Command, []string) {},
}
args := strings.Split(extraPayload, " ")
if len(args) < 1 {
return nil, nil, fmt.Errorf("invalid extra payload")
return nil, nil, helpMessage, fmt.Errorf("invalid extra payload")
}
proxyType := strings.TrimSpace(args[0])
supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"}
if !lo.Contains(supportTypes, proxyType) {
return nil, nil, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes)
return nil, nil, helpMessage, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes)
}
pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType))
if pc == nil {
return nil, nil, fmt.Errorf("new proxy configurer error")
return nil, nil, helpMessage, fmt.Errorf("new proxy configurer error")
}
config.RegisterProxyFlags(cmd, pc)
config.RegisterProxyFlags(cmd, pc, config.WithSSHMode())

clientCfg := v1.ClientCommonConfig{}
config.RegisterClientCommonConfigFlags(cmd, &clientCfg)
config.RegisterClientCommonConfigFlags(cmd, &clientCfg, config.WithSSHMode())

cmd.InitDefaultHelpCmd()
if err := cmd.ParseFlags(args); err != nil {
return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err)
if errors.Is(err, flag.ErrHelp) {
helpMessage = cmd.UsageString()
}
return nil, nil, helpMessage, err
}
// if name is not set, generate a random one
if pc.GetBaseConfig().Name == "" {
id, err := util.RandIDWithLen(8)
if err != nil {
return nil, nil, fmt.Errorf("generate random id error: %v", err)
return nil, nil, helpMessage, fmt.Errorf("generate random id error: %v", err)
}
pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id)
}
return &clientCfg, pc, nil
return &clientCfg, pc, helpMessage, nil
}

func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) {
ch, reqs, err := channel.Accept()
if err != nil {
return
}
if s.firstChannel == nil {
s.firstChannel = ch
}
go s.keepAlive(ch)

for req := range reqs {
Expand Down Expand Up @@ -320,7 +354,7 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
return conn, nil
}

func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) error {
func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) (*proxy.WorkingStatus, error) {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()

Expand All @@ -336,14 +370,14 @@ func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration)
}
switch ps.Phase {
case proxy.ProxyPhaseRunning:
return nil
return ps, nil
case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed:
return errors.New(ps.Err)
return ps, errors.New(ps.Err)
}
case <-timer.C:
return fmt.Errorf("wait proxy status ready timeout")
return nil, fmt.Errorf("wait proxy status ready timeout")
case <-s.doneCh:
return fmt.Errorf("ssh tunnel server closed")
return nil, fmt.Errorf("ssh tunnel server closed")
}
}
}
Loading

0 comments on commit 95cf418

Please sign in to comment.