Skip to content

Commit

Permalink
Land rapid7#18276, Add sasl scram 256 auth support to postgres modules
Browse files Browse the repository at this point in the history
  • Loading branch information
dwelch-r7 authored Aug 18, 2023
2 parents 8e89a6a + 98ac76d commit 1878c08
Show file tree
Hide file tree
Showing 9 changed files with 751 additions and 39 deletions.
7 changes: 6 additions & 1 deletion Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ PATH
mqtt
msgpack (~> 1.6.0)
nessus_rest
net-imap
net-ldap
net-smtp
net-ssh
Expand Down Expand Up @@ -181,6 +182,7 @@ GEM
cookiejar (0.3.3)
crass (1.0.6)
daemons (1.4.1)
date (3.3.3)
debug (1.8.0)
irb (>= 1.5.0)
reline (>= 0.3.1)
Expand Down Expand Up @@ -297,6 +299,9 @@ GEM
mustermann (3.0.0)
ruby2_keywords (~> 0.0.1)
nessus_rest (0.1.6)
net-imap (0.3.7)
date
net-protocol
net-ldap (0.18.0)
net-protocol (0.2.1)
timeout
Expand Down Expand Up @@ -498,7 +503,7 @@ GEM
thor (1.2.2)
tilt (2.2.0)
timecop (0.9.6)
timeout (0.3.2)
timeout (0.4.0)
ttfunk (1.7.0)
tzinfo (2.0.6)
concurrent-ruby (~> 1.0)
Expand Down
12 changes: 8 additions & 4 deletions lib/postgres/buffer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module Db
class Buffer

class Error < RuntimeError; end
class EOF < Error; end
class EOF < Error; end

def self.from_string(str)
new(str)
Expand All @@ -20,7 +20,7 @@ def self.from_string(str)
def self.of_size(size)
raise ArgumentError if size < 0
new('#' * size)
end
end

def initialize(content)
@size = content.size
Expand All @@ -36,6 +36,10 @@ def position
@position
end

def peek
@content[@position]
end

def position=(new_pos)
raise ArgumentError if new_pos < 0 or new_pos > @size
@position = new_pos
Expand Down Expand Up @@ -67,11 +71,11 @@ def write(str)
def copy_from_stream(stream, n)
raise ArgumentError if n < 0
while n > 0
str = stream.read(n)
str = stream.read(n)
write(str)
n -= str.size
end
raise if n < 0
raise if n < 0
end

NUL = "\000"
Expand Down
82 changes: 71 additions & 11 deletions lib/postgres/postgres-pr/connection.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
require 'postgres_msf'
require 'postgres/postgres-pr/message'
require 'postgres/postgres-pr/version'
require 'postgres/postgres-pr/scram_sha_256'
require 'uri'
require 'rex/socket'

Expand Down Expand Up @@ -65,7 +66,7 @@ def initialize(database, user, password=nil, uri = nil)
# Check if the password supplied is a Postgres-style md5 hash
md5_hash_match = password.match(/^md5([a-f0-9]{32})$/)

@conn << StartupMessage.new(PROTO_VERSION, 'user' => user, 'database' => database).dump
write_message(StartupMessage.new(PROTO_VERSION, 'user' => user, 'database' => database))

loop do
msg = Message.read(@conn)
Expand All @@ -74,11 +75,11 @@ def initialize(database, user, password=nil, uri = nil)
when AuthentificationClearTextPassword
raise ArgumentError, "no password specified" if password.nil?
raise AuthenticationMethodMismatch, "Server expected clear text password auth" if md5_hash_match
@conn << PasswordMessage.new(password).dump
write_message(PasswordMessage.new(password))
when AuthentificationCryptPassword
raise ArgumentError, "no password specified" if password.nil?
raise AuthenticationMethodMismatch, "Server expected crypt password auth" if md5_hash_match
@conn << PasswordMessage.new(password.crypt(msg.salt)).dump
write_message(PasswordMessage.new(password.crypt(msg.salt)))
when AuthentificationMD5Password
raise ArgumentError, "no password specified" if password.nil?
require 'digest/md5'
Expand All @@ -91,8 +92,10 @@ def initialize(database, user, password=nil, uri = nil)
m = Digest::MD5.hexdigest(m + msg.salt)
m = 'md5' + m

@conn << PasswordMessage.new(m).dump
write_message(PasswordMessage.new(m))

when AuthenticationSASL
negotiate_sasl(msg, user, password)
when UnknownAuthType
raise "unknown auth type '#{msg.auth_type}' with buffer content:\n#{Rex::Text.to_hex_dump(msg.buffer.content)}"

Expand All @@ -101,7 +104,7 @@ def initialize(database, user, password=nil, uri = nil)

when AuthentificationOk
when ErrorResponse
raise msg.field_values.join("\t")
handle_server_error_message(msg)
when NoticeResponse
@notice_processor.call(msg) if @notice_processor
when ParameterStatus
Expand All @@ -124,15 +127,15 @@ def close
@conn = nil
end

class Result
class Result
attr_accessor :rows, :fields, :cmd_tag
def initialize(rows=[], fields=[])
@rows, @fields = rows, fields
end
end

def query(sql)
@conn << Query.dump(sql)
write_message(Query.new(sql))

result = Result.new
errors = []
Expand Down Expand Up @@ -167,18 +170,69 @@ def query(sql)
result
end


# @param [AuthenticationSASL] msg
# @param [String] user
# @param [String,nil] password
def negotiate_sasl(msg, user, password = nil)
if msg.mechanisms.include?('SCRAM-SHA-256')
scram_sha_256 = ScramSha256.new
# Start negotiating scram, additionally wrapping in SASL and unwrapping the SASL responses
scram_sha_256.negotiate(user, password) do |state, value|
if state == :client_first
sasl_initial_response_message = SaslInitialResponseMessage.new(
mechanism: 'SCRAM-SHA-256',
value: value
)

write_message(sasl_initial_response_message)

sasl_continue = Message.read(@conn)
raise handle_server_error_message(sasl_continue) if sasl_continue.is_a?(ErrorResponse)
raise AuthenticationMethodMismatch, "Did not receive AuthenticationSASLContinue - instead got #{sasl_continue}" unless sasl_continue.is_a?(AuthenticationSASLContinue)

server_first_string = sasl_continue.value
server_first_string
elsif state == :client_final
sasl_initial_response_message = SASLResponseMessage.new(
value: value
)

write_message(sasl_initial_response_message)

server_final = Message.read(@conn)
raise handle_server_error_message(server_final) if server_final.is_a?(ErrorResponse)
raise AuthenticationMethodMismatch, "Did not receive AuthenticationSASLFinal - instead got #{server_final}" unless server_final.is_a?(AuthenticationSASLFinal)

server_final_string = server_final.value
server_final_string
else
raise AuthenticationMethodMismatch, "Unexpected negotiation state #{state}"
end
end
else
raise AuthenticationMethodMismatch, "unsupported SASL mechanisms #{msg.mechanisms.inspect}"
end
end

DEFAULT_PORT = 5432
DEFAULT_HOST = 'localhost'
DEFAULT_PATH = '/tmp'
DEFAULT_URI =
DEFAULT_PATH = '/tmp'
DEFAULT_URI =
if RUBY_PLATFORM.include?('win')
'tcp://' + DEFAULT_HOST + ':' + DEFAULT_PORT.to_s
'tcp://' + DEFAULT_HOST + ':' + DEFAULT_PORT.to_s
else
'unix:' + File.join(DEFAULT_PATH, '.s.PGSQL.' + DEFAULT_PORT.to_s)
'unix:' + File.join(DEFAULT_PATH, '.s.PGSQL.' + DEFAULT_PORT.to_s)
end

private

# @param [ErrorResponse] server_error_message
# @raise [RuntimeError]
def handle_server_error_message(server_error_message)
raise server_error_message.field_values.join("\t")
end

# tcp://localhost:5432
# unix:/tmp/.s.PGSQL.5432
def establish_connection(uri)
Expand All @@ -196,6 +250,12 @@ def establish_connection(uri)
raise 'unrecognized uri scheme format (must be tcp or unix)'
end
end

# @param [Message] message
# @return [Numeric] The byte count successfully written to the currently open connection
def write_message(message)
@conn << message.dump
end
end

end # module PostgresPR
Expand Down
Loading

0 comments on commit 1878c08

Please sign in to comment.