-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcbo_stat_load
executable file
·253 lines (219 loc) · 10.1 KB
/
cbo_stat_load
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
#!/usr/bin/env python3
'''
cbo_stat_load utility
Loads statistics from JSON file to a test database.
'''
import psycopg2
import argparse
import sys
import os, platform
import json
import packaging.version as version
import logging
# Absolute path of the directory that contains this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Append the parent of the script directory to the module search path; the
# cbo_stat_dump import on the next line runs after this path is added.
sys.path.append(os.path.dirname(SCRIPT_DIR))
from cbo_stat_dump import __version__
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_connection_dict(args):
    """Build the keyword arguments used to open the database connection.

    Args:
        args: Parsed command-line options; reads ``host``, ``port``,
            ``username``, ``password`` and ``database``.

    Returns:
        A dict with the keys ``host``, ``port``, ``user``, ``password`` and
        ``database``, ready to be expanded into ``psycopg2.connect``.
    """
    # The password is left out of the debug message.
    logger.debug(
        f"Connecting to database host={args.host}, port={args.port}, "
        f"username={args.username}, db={args.database}"
    )
    return {
        'host': args.host,
        'port': args.port,
        'user': args.username,
        'password': args.password,
        'database': args.database,
    }
def connect_database(connectionDict):
    """Open a psycopg2 connection and a cursor on it.

    Args:
        connectionDict: Keyword arguments forwarded to ``psycopg2.connect``.

    Returns:
        A ``(connection, cursor)`` tuple.
    """
    connection = psycopg2.connect(**connectionDict)
    return connection, connection.cursor()
def parse_cmd_line():
    """Parse and normalise the command-line options of cbo_stat_load.

    Configures the root logger (DEBUG when ``--debug`` is given, INFO
    otherwise), handles ``--help``, requires ``--stat_file`` and fills in the
    defaults that depend on ``--yb_mode``.

    Returns:
        The ``argparse.Namespace`` with ``username``, ``database`` and
        ``port`` always set, and ``output_file`` set whenever ``--dry_run``
        is used.

    Raises:
        SystemExit: After printing help (status 0) or on invalid/missing
            options (argparse error status).
    """
    parser = argparse.ArgumentParser(
        prog='cbo_stat_load',
        description='Import statistics from JSON file to pg_class and pg_statistics system tables',
        # -h is used for --host, so the automatic -h/--help is disabled and
        # --help is declared by hand below.
        add_help=False,
    )
    parser.add_argument('--help', action='store_true', help='show this help message and exit')
    parser.add_argument('--debug', action='store_true', help='Set log level to DEBUG')
    parser.add_argument('-h', '--host', help='Hostname or IP address, default localhost', default="localhost")
    # type=int keeps a user-supplied port the same type as the int defaults.
    parser.add_argument('-p', '--port', type=int, help='Port number, default 5432')
    parser.add_argument('-d', '--database', help='Database name')
    parser.add_argument('-u', '--username', help='Username, default postgres')
    parser.add_argument('-W', '--password', help='Password, default no password')
    parser.add_argument('-s', '--stat_file', help='JSON file with table statistics')
    parser.add_argument('-o', '--output_file', help='SQL file with insert statements to import statistics')
    parser.add_argument('-D', '--dry_run', action='store_true', help='Do not execute insert statements, only produce output_file')
    parser.add_argument('--yb_mode', action='store_true', help='Use this option to import statistics in YugabyteDB')
    args = parser.parse_args()

    # basicConfig does nothing once the root logger has handlers, so the
    # level has to be decided before the first (and only) call.
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    if args.help:
        parser.print_help()
        sys.exit(0)

    if args.stat_file is None:
        parser.error('Must specify statistics file with [--stat_file|-s] option')

    # In each mode the default user name and default database share a name.
    if args.yb_mode:
        default_name, default_port = 'yugabyte', 5433
    else:
        default_name, default_port = 'postgres', 5432
    if args.username is None:
        args.username = default_name
    if args.database is None:
        args.database = default_name
    if args.port is None:
        args.port = default_port

    # A dry run only makes sense if the generated SQL is written somewhere.
    if args.dry_run and args.output_file is None:
        args.output_file = "import_statistics.sql"

    return args
def output_and_execute_query(args, cursor, query):
    """Record a SQL statement and, outside of dry-run mode, execute it.

    Args:
        args: Parsed options; reads ``output_file`` and ``dry_run``.
        cursor: Database cursor; only used when ``args.dry_run`` is false.
        query: SQL text to append to the output file and/or execute.
    """
    sql_path = args.output_file
    if sql_path:
        # Append mode: every call adds one more statement to the same file.
        with open(sql_path, 'a') as sql_log:
            print(query, file=sql_log)
    if args.dry_run:
        return
    cursor.execute(query)
def _quote_array_element(text):
    """Double-quote *text* as one element of a single-quoted SQL array literal."""
    # Backslash and double quote are escaped with a backslash (array syntax);
    # a single quote is doubled (enclosing SQL string literal).
    escaped = text.replace('\\', '\\\\').replace('"', '\\"').replace("'", "''")
    return f'"{escaped}"'


def _stavalues_to_sql(val, column_type):
    """Render one stavaluesN entry as an ``array_in(...)::anyarray`` expression."""
    if val is None:
        return "NULL"
    if not val:
        # An empty list has no first element to inspect; emit an empty array.
        elements = []
    elif isinstance(val[0], str):
        elements = [_quote_array_element(e) for e in val]
    elif isinstance(val[0], dict):
        # Array of JSON objects: the column type must be jsonb.
        if column_type != 'pg_catalog.jsonb':
            raise ValueError(
                f"JSON object values require type pg_catalog.jsonb, got {column_type}")
        # json.dumps produces real JSON text (null/true/false, escaped quotes).
        elements = [_quote_array_element(json.dumps(e)) for e in val]
    else:
        elements = [str(e) for e in val]
    array_literal = "'{" + ', '.join(elements) + "}'"
    return f"array_in({array_literal}, '{column_type}'::regtype, -1)::anyarray"


def _plain_column_to_sql(val, column_type):
    """Render a non-stavalues column value as a typed SQL expression."""
    if isinstance(val, list):
        if column_type != 'real[]':
            raise ValueError(f"Unexpected list value for column type {column_type}")
        return "'{" + ', '.join(map(repr, val)) + "}'::" + column_type
    if val is None:
        return 'NULL::' + column_type
    return str(val) + "::" + column_type


def update_pg_statistic(args, cursor, stat_json):
    """Replace the pg_statistic row described by one JSON statistics entry.

    Builds a ``DELETE`` followed by a positional ``INSERT INTO pg_statistic``
    and hands both to ``output_and_execute_query``.

    Args:
        args: Parsed options, forwarded to ``output_and_execute_query``.
        cursor: Database cursor, forwarded to ``output_and_execute_query``.
        stat_json: Dict for one column's statistics. Reads ``nspname``,
            ``relname``, ``attname``, ``typname``, optional ``typnspname``
            (defaults to ``pg_catalog``) and every pg_statistic column listed
            in ``column_types`` below.

    Raises:
        KeyError: If a required key is missing from ``stat_json``.
        ValueError: If a list value appears in a column that is not
            ``real[]``, or JSON objects appear for a non-jsonb type.
    """
    # NOTE(review): names and values are interpolated into SQL text as-is;
    # the statistics file is assumed to be trusted.
    column_types = {
        "stainherit": "boolean",
        "stanullfrac": "real",
        "stawidth": "integer",
        "stadistinct": "real",
    }
    # Insertion order matters: values are emitted positionally in the INSERT.
    for prefix, sql_type in (("stakind", "smallint"), ("staop", "oid"),
                             ("stacoll", "oid"), ("stanumbers", "real[]")):
        for slot in range(1, 6):
            column_types[f"{prefix}{slot}"] = sql_type

    stavalues_type = stat_json.get('typnspname', 'pg_catalog') + '.' + stat_json["typname"]
    for slot in range(1, 6):
        name = f"stavalues{slot}"
        # The fifth slot is emitted only when the JSON row provides it, so
        # files without that key load exactly as before.
        if slot == 5 and name not in stat_json:
            continue
        column_types[name] = stavalues_type

    rendered = []
    for column_name, column_type in column_types.items():
        val = stat_json[column_name]
        if column_name.startswith('stavalues'):
            rendered.append(_stavalues_to_sql(val, column_type))
        else:
            rendered.append(_plain_column_to_sql(val, column_type))
    column_values = "".join(", " + sql_val for sql_val in rendered)

    starelid = f"'{stat_json['nspname']}.{stat_json['relname']}'::regclass"
    # Resolve the attribute number inside the statement itself, so the SQL
    # written to the output file is self-contained.
    staattnum_subquery = (
        f"(SELECT a.attnum FROM pg_attribute a WHERE a.attrelid = {starelid} "
        f"and a.attname = '{stat_json['attname']}')"
    )
    query = "\n".join((
        f"DELETE FROM pg_statistic WHERE starelid = {starelid} AND staattnum = {staattnum_subquery};",
        f"INSERT INTO pg_statistic VALUES ({starelid}, {staattnum_subquery}{column_values});",
    ))
    output_and_execute_query(args, cursor, query)
def update_reltuples(args, cursor, relnamespace, relname, reltuples, relpages, relallvisible):
    """Write table-level statistics into pg_class.

    The UPDATE targets the rows in namespace ``relnamespace`` whose name is
    ``relname`` or ``<relname>_pkey``; both receive the same ``reltuples``,
    ``relpages`` and ``relallvisible`` values.
    """
    assignments = f"reltuples = {reltuples}, relpages = {relpages}, relallvisible = {relallvisible}"
    name_match = f"(relname = '{relname}' OR relname = '{relname}_pkey')"
    predicate = f"relnamespace = '{relnamespace}'::regnamespace AND {name_match}"
    output_and_execute_query(args, cursor, f"UPDATE pg_class SET {assignments} WHERE {predicate};")
def enable_write_on_sys_tables(args, cursor):
    """Emit the statement that sets yb_non_ddl_txn_for_sys_tables_allowed to ON."""
    output_and_execute_query(
        args,
        cursor,
        "SET yb_non_ddl_txn_for_sys_tables_allowed = ON;",
    )
def disable_write_on_sys_tables(args, cursor):
    """Emit the statement that sets yb_non_ddl_txn_for_sys_tables_allowed to OFF."""
    output_and_execute_query(
        args,
        cursor,
        "SET yb_non_ddl_txn_for_sys_tables_allowed = OFF;",
    )
def update_pg_yb_catalog_version(args, cursor):
    """Emit the statement that increments current_version in pg_yb_catalog_version.

    Only the row with ``db_oid=1`` is updated.
    """
    output_and_execute_query(
        args,
        cursor,
        "update pg_yb_catalog_version set current_version=current_version+1 where db_oid=1;",
    )
def check_version_compatibility(version_string):
    """Exit unless the statistics file version is compatible with this tool.

    Versions are compatible when they are equal or share the same major
    number. Otherwise a fatal message is logged and the process exits with
    status 1.

    Args:
        version_string: Version recorded in the statistics file.
    """
    tool_version = version.parse(__version__)
    file_version = version.parse(version_string)
    # Equality is tested first, exactly as before, so identical versions are
    # accepted without looking at the major component.
    if tool_version == file_version:
        return
    if tool_version.major == file_version.major:
        return
    logger.fatal(f'\nVersion mismatch: statistics version ({version_string}) does not match with version of cbo_stat_load ({__version__})\n\n')
    sys.exit(1)
def import_statistics(args, cursor):
    """Load the JSON statistics file and apply it to pg_class and pg_statistic.

    Args:
        args: Parsed options; reads ``stat_file`` and ``yb_mode`` and is
            forwarded to the per-row update helpers.
        cursor: Database cursor forwarded to the helpers (``None`` in a dry
            run).

    Raises:
        SystemExit: If the statistics file is missing, or its ``version``
            entry is rejected by ``check_version_compatibility``.
        KeyError: If the file lacks the ``pg_class`` or ``pg_statistic``
            sections or a row lacks an expected key.
    """
    # Fail with a clear message instead of crashing inside open().
    if args.stat_file is None or not os.path.exists(args.stat_file):
        logger.fatal(f"\nStatistics file {args.stat_file} does not exist\n\n")
        sys.exit(1)

    logger.debug(f"Importing statistics from file {args.stat_file}")
    with open(args.stat_file, 'r') as stat_file:
        statistics_json = json.load(stat_file)

    # Files without a version entry are accepted without a check.
    if 'version' in statistics_json:
        check_version_compatibility(statistics_json["version"])

    if args.yb_mode:
        enable_write_on_sys_tables(args, cursor)

    for class_row in statistics_json['pg_class']:
        logger.debug(class_row["nspname"] + "." + class_row["relname"] + " = " + str(class_row["reltuples"]))
        update_reltuples(
            args,
            cursor,
            class_row["nspname"],
            class_row["relname"],
            class_row["reltuples"],
            class_row["relpages"],
            class_row["relallvisible"],
        )

    for statistic_row in statistics_json['pg_statistic']:
        logger.debug(statistic_row["nspname"] + "." + statistic_row["relname"] + "." + statistic_row["attname"])
        update_pg_statistic(args, cursor, statistic_row)

    if args.yb_mode:
        update_pg_yb_catalog_version(args, cursor)
        disable_write_on_sys_tables(args, cursor)
def main():
    """Entry point: parse options, then import statistics.

    In a dry run no database connection is opened and the statements are only
    written to the output file. Otherwise the statements are executed and
    committed; the cursor and connection are closed even when the import
    fails, in which case nothing is committed.
    """
    args = parse_cmd_line()

    # Start from an empty output file; statements are appended to it later.
    if args.output_file is not None:
        with open(args.output_file, 'w'):
            pass

    if args.dry_run:
        import_statistics(args, None)
        return

    conn, cursor = connect_database(get_connection_dict(args))
    try:
        import_statistics(args, cursor)
        conn.commit()
    finally:
        cursor.close()
        conn.close()


if __name__ == "__main__":
    main()