@@ -104,8 +104,10 @@ def _read_password_file(passfile: pathlib.Path) \
104
104
105
105
def _read_password_from_pgpass (
106
106
* , passfile : typing .Optional [pathlib .Path ],
107
- hosts : typing .List [typing .Union [str , typing .Tuple [str , int ]]],
108
- port : int , database : str , user : str ):
107
+ hosts : typing .List [str ],
108
+ ports : typing .List [int ],
109
+ database : str ,
110
+ user : str ):
109
111
"""Parse the pgpass file and return the matching password.
110
112
111
113
:return:
@@ -116,7 +118,7 @@ def _read_password_from_pgpass(
116
118
if not passtab :
117
119
return None
118
120
119
- for host in hosts :
121
+ for host , port in zip ( hosts , ports ) :
120
122
if host .startswith ('/' ):
121
123
# Unix sockets get normalized into 'localhost'
122
124
host = 'localhost'
@@ -137,27 +139,83 @@ def _read_password_from_pgpass(
137
139
return None
138
140
139
141
142
+ def _validate_port_spec (hosts , port ):
143
+ if isinstance (port , list ):
144
+ # If there is a list of ports, its length must
145
+ # match that of the host list.
146
+ if len (port ) != len (hosts ):
147
+ raise exceptions .InterfaceError (
148
+ 'could not match {} port numbers to {} hosts' .format (
149
+ len (port ), len (hosts )))
150
+ else :
151
+ port = [port for _ in range (len (hosts ))]
152
+
153
+ return port
154
+
155
+
156
+ def _parse_hostlist (hostlist , port ):
157
+ if ',' in hostlist :
158
+ # A comma-separated list of host addresses.
159
+ hostspecs = hostlist .split (',' )
160
+ else :
161
+ hostspecs = [hostlist ]
162
+
163
+ hosts = []
164
+ hostlist_ports = []
165
+
166
+ if not port :
167
+ portspec = os .environ .get ('PGPORT' )
168
+ if portspec :
169
+ if ',' in portspec :
170
+ default_port = [int (p ) for p in portspec .split (',' )]
171
+ else :
172
+ default_port = int (portspec )
173
+ else :
174
+ default_port = 5432
175
+
176
+ default_port = _validate_port_spec (hostspecs , default_port )
177
+
178
+ else :
179
+ port = _validate_port_spec (hostspecs , port )
180
+
181
+ for i , hostspec in enumerate (hostspecs ):
182
+ addr , _ , hostspec_port = hostspec .partition (':' )
183
+ hosts .append (addr )
184
+
185
+ if not port :
186
+ if hostspec_port :
187
+ hostlist_ports .append (int (hostspec_port ))
188
+ else :
189
+ hostlist_ports .append (default_port [i ])
190
+
191
+ if not port :
192
+ port = hostlist_ports
193
+
194
+ return hosts , port
195
+
196
+
140
197
def _parse_connect_dsn_and_args (* , dsn , host , port , user ,
141
198
password , passfile , database , ssl ,
142
199
connect_timeout , server_settings ):
143
- if host is not None and not isinstance (host , str ):
144
- raise TypeError (
145
- 'host argument is expected to be str, got {!r}' .format (
146
- type (host )))
200
+ # `auth_hosts` is the version of host information for the purposes
201
+ # of reading the pgpass file.
202
+ auth_hosts = None
147
203
148
204
if dsn :
149
205
parsed = urllib .parse .urlparse (dsn )
150
206
151
207
if parsed .scheme not in {'postgresql' , 'postgres' }:
152
208
raise ValueError (
153
- 'invalid DSN: scheme is expected to be either of '
209
+ 'invalid DSN: scheme is expected to be either '
154
210
'"postgresql" or "postgres", got {!r}' .format (parsed .scheme ))
155
211
156
- if parsed .port and port is None :
157
- port = int (parsed .port )
212
+ if not host and parsed .netloc :
213
+ if '@' in parsed .netloc :
214
+ auth , _ , hostspec = parsed .netloc .partition ('@' )
215
+ else :
216
+ hostspec = parsed .netloc
158
217
159
- if parsed .hostname and host is None :
160
- host = parsed .hostname
218
+ host , port = _parse_hostlist (hostspec , port )
161
219
162
220
if parsed .path and database is None :
163
221
database = parsed .path
@@ -178,13 +236,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
178
236
179
237
if 'host' in query :
180
238
val = query .pop ('host' )
181
- if host is None :
182
- host = val
239
+ if not host and val :
240
+ host , port = _parse_hostlist ( val , port )
183
241
184
242
if 'port' in query :
185
- val = int ( query .pop ('port' ) )
186
- if port is None :
187
- port = val
243
+ val = query .pop ('port' )
244
+ if not port and val :
245
+ port = [ int ( p ) for p in val . split ( ',' )]
188
246
189
247
if 'dbname' in query :
190
248
val = query .pop ('dbname' )
@@ -222,40 +280,44 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
222
280
else :
223
281
server_settings = {** query , ** server_settings }
224
282
225
- # On env-var -> connection parameter conversion read here:
226
- # https://www.postgresql.org/docs/current/static/libpq-envars.html
227
- # Note that env values may be an empty string in cases when
228
- # the variable is "unset" by setting it to an empty value
229
- # `auth_hosts` is the version of host information for the purposes
230
- # of reading the pgpass file.
231
- auth_hosts = None
232
- if host is None :
233
- host = os .getenv ('PGHOST' )
234
- if not host :
235
- auth_hosts = ['localhost' ]
283
+ if not host :
284
+ hostspec = os .environ .get ('PGHOST' )
285
+ if hostspec :
286
+ host , port = _parse_hostlist (hostspec , port )
236
287
237
- if _system == 'Windows' :
238
- host = ['localhost' ]
239
- else :
240
- host = ['/tmp' , '/private/tmp' ,
241
- '/var/pgsql_socket' , '/run/postgresql' ,
242
- 'localhost' ]
288
+ if not host :
289
+ auth_hosts = ['localhost' ]
290
+
291
+ if _system == 'Windows' :
292
+ host = ['localhost' ]
293
+ else :
294
+ host = ['/run/postgresql' , '/var/run/postgresql' ,
295
+ '/tmp' , '/private/tmp' , 'localhost' ]
243
296
244
297
if not isinstance (host , list ):
245
298
host = [host ]
246
299
247
300
if auth_hosts is None :
248
301
auth_hosts = host
249
302
250
- if port is None :
251
- port = os .getenv ('PGPORT' )
252
- if port :
253
- port = int (port )
303
+ if not port :
304
+ portspec = os .environ .get ('PGPORT' )
305
+ if portspec :
306
+ if ',' in portspec :
307
+ port = [int (p ) for p in portspec .split (',' )]
308
+ else :
309
+ port = int (portspec )
254
310
else :
255
311
port = 5432
312
+
313
+ elif isinstance (port , (list , tuple )):
314
+ port = [int (p ) for p in port ]
315
+
256
316
else :
257
317
port = int (port )
258
318
319
+ port = _validate_port_spec (host , port )
320
+
259
321
if user is None :
260
322
user = os .getenv ('PGUSER' )
261
323
if not user :
@@ -293,19 +355,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
293
355
294
356
if passfile is not None :
295
357
password = _read_password_from_pgpass (
296
- hosts = auth_hosts , port = port , database = database , user = user ,
358
+ hosts = auth_hosts , ports = port ,
359
+ database = database , user = user ,
297
360
passfile = passfile )
298
361
299
362
addrs = []
300
- for h in host :
363
+ for h , p in zip ( host , port ) :
301
364
if h .startswith ('/' ):
302
365
# UNIX socket name
303
366
if '.s.PGSQL.' not in h :
304
- h = os .path .join (h , '.s.PGSQL.{}' .format (port ))
367
+ h = os .path .join (h , '.s.PGSQL.{}' .format (p ))
305
368
addrs .append (h )
306
369
else :
307
370
# TCP host/port
308
- addrs .append ((h , port ))
371
+ addrs .append ((h , p ))
309
372
310
373
if not addrs :
311
374
raise ValueError (
@@ -329,7 +392,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
329
392
sslmode = SSLMODES [ssl ]
330
393
except KeyError :
331
394
modes = ', ' .join (SSLMODES .keys ())
332
- raise ValueError ('`sslmode` parameter must be one of ' + modes )
395
+ raise exceptions .InterfaceError (
396
+ '`sslmode` parameter must be one of: {}' .format (modes ))
333
397
334
398
# sslmode 'allow' is currently handled as 'prefer' because we're
335
399
# missing the "retry with SSL" behavior for 'allow', but do have the
0 commit comments