-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb_utils.py
115 lines (107 loc) · 4.92 KB
/
db_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from sqlalchemy import text
import psycopg2
import sqlalchemy
import time
from psycopg2 import OperationalError
from config import DB_USER, DB_PASSWORD, DB_NAME, CLOUD_SQL_CONNECTION_NAME
def create_conn():
    """Build the SQLAlchemy engine used to reach the PostgreSQL database.

    The user, password, database name and Cloud SQL connection name are
    read from the constants imported from ``config``. The host is set to
    ``/cloudsql/<CLOUD_SQL_CONNECTION_NAME>`` and the driver is
    ``postgresql+psycopg2``.

    Returns:
        The engine returned by ``sqlalchemy.create_engine``. If a
        ``psycopg2.OperationalError`` is raised while building it, a
        human-readable error string is returned instead of an engine.
    """
    # NOTE(review): SQLAlchemy documents create_engine() as lazy (no
    # connection is opened until first use), so this handler is probably
    # never triggered here -- confirm, and check how callers treat the
    # string return value.
    try:
        socket_host = f'/cloudsql/{CLOUD_SQL_CONNECTION_NAME}'
        db_url = sqlalchemy.engine.url.URL.create(
            drivername="postgresql+psycopg2",
            username=DB_USER,
            password=DB_PASSWORD,
            database=DB_NAME,
            host=socket_host,
        )
        return sqlalchemy.create_engine(db_url)
    except OperationalError:
        return 'Unable to connect to the database, please try again later.'
def save_token_info(connection, STRAVA_CLIENT_ID, athlete_id, access_token,
                    refresh_token, expires_at, expires_in, current_scope,
                    refreshed_by='gcp-strava.wl.r.appspot.com'):
    """Insert or update an athlete's row in ``strava_access_tokens``.

    If no row exists for ``athlete_id`` a new one is inserted. If a row
    exists, ``total_refresh_checks`` is incremented and then:

    * when the stored ``expires_at`` is in the past (compared against
      ``time.time()``), the token columns are overwritten with the values
      passed in and ``total_refreshes`` is incremented;
    * otherwise only ``expires_in``, ``last_updated``, ``current_scope``
      and ``last_refreshed_by`` are updated.

    This function never commits; transaction handling is left to the
    caller that owns ``connection``.

    Args:
        connection: Object whose ``execute(statement, params)`` runs a
            SQLAlchemy ``text`` statement (presumably a SQLAlchemy
            ``Connection`` -- confirm against callers).
        STRAVA_CLIENT_ID: Value written to the ``client_id`` column.
        athlete_id: Key used to find the athlete's row.
        access_token: New access token to store.
        refresh_token: New refresh token to store.
        expires_at: Expiry of the new access token; compared elsewhere as
            seconds since the epoch.
        expires_in: Value written to the ``expires_in`` column.
        current_scope: Value written to the ``current_scope`` column.
        refreshed_by: Identifier written to ``last_refreshed_by``.
            Defaults to the value that used to be hard-coded in the SQL.

    Returns:
        A ``(pk_id, total_refresh_checks)`` tuple for the affected row.
    """
    # Only expires_at is needed to decide what to do, so the stored
    # secrets are not pulled back into the process.
    lookup = connection.execute(
        text("SELECT expires_at FROM strava_access_tokens "
             "WHERE athlete_id = :athlete_id"),
        {"athlete_id": athlete_id},
    )
    existing_row = lookup.fetchone()

    if existing_row is not None:
        return _update_existing_token_row(
            connection, existing_row[0], STRAVA_CLIENT_ID, athlete_id,
            access_token, refresh_token, expires_at, expires_in,
            current_scope, refreshed_by,
        )

    # No row yet for this athlete: create one.
    result = connection.execute(
        text("""
            INSERT INTO strava_access_tokens
                (client_id, athlete_id, access_token, refresh_token,
                 expires_at, expires_in, last_refreshed_by, current_scope)
            VALUES
                (:client_id, :athlete_id, :access_token, :refresh_token,
                 :expires_at, :expires_in, :refreshed_by, :current_scope)
            RETURNING pk_id
        """),
        {
            "client_id": STRAVA_CLIENT_ID,
            "athlete_id": athlete_id,
            "access_token": access_token,
            "refresh_token": refresh_token,
            "expires_at": expires_at,
            "expires_in": expires_in,
            "refreshed_by": refreshed_by,
            "current_scope": current_scope,
        },
    )
    # The INSERT does not set total_refresh_checks, so read back whatever
    # value the table assigned to the new row.
    checks_result = connection.execute(
        text("SELECT total_refresh_checks FROM strava_access_tokens "
             "WHERE athlete_id = :athlete_id"),
        {"athlete_id": athlete_id},
    )
    total_refresh_checks = checks_result.fetchone()[0]
    pk_id = result.fetchone()[0]
    return pk_id, total_refresh_checks


def _update_existing_token_row(connection, stored_expires_at, client_id,
                               athlete_id, access_token, refresh_token,
                               expires_at, expires_in, current_scope,
                               refreshed_by):
    """Update an athlete's existing token row; return (pk_id, checks)."""
    current_time = int(time.time())

    # Every call against an existing row counts as one refresh check.
    checks_result = connection.execute(
        text("UPDATE strava_access_tokens "
             "SET total_refresh_checks = total_refresh_checks + 1 "
             "WHERE athlete_id = :athlete_id "
             "RETURNING total_refresh_checks"),
        {"athlete_id": athlete_id},
    )
    total_refresh_checks = checks_result.fetchone()[0]

    if current_time > stored_expires_at:
        # Stored token has expired: replace it and count the refresh.
        # NOTE(review): unlike the branch below, this statement does not
        # set last_updated -- confirm whether that is intentional.
        result = connection.execute(
            text("""
                UPDATE strava_access_tokens
                SET client_id = :client_id,
                    access_token = :access_token,
                    refresh_token = :refresh_token,
                    expires_at = :expires_at,
                    expires_in = :expires_in,
                    total_refreshes = total_refreshes + 1,
                    current_scope = :current_scope,
                    last_refreshed_by = :refreshed_by
                WHERE athlete_id = :athlete_id
                RETURNING pk_id
            """),
            {
                "client_id": client_id,
                "athlete_id": athlete_id,
                "access_token": access_token,
                "refresh_token": refresh_token,
                "expires_at": expires_at,
                "expires_in": expires_in,
                "current_scope": current_scope,
                "refreshed_by": refreshed_by,
            },
        )
    else:
        # Stored token is still valid: keep it and refresh only the
        # bookkeeping columns.
        result = connection.execute(
            text("""
                UPDATE strava_access_tokens
                SET expires_in = :expires_in,
                    last_updated = now(),
                    current_scope = :current_scope,
                    last_refreshed_by = :refreshed_by
                WHERE athlete_id = :athlete_id
                RETURNING pk_id
            """),
            {
                "athlete_id": athlete_id,
                "expires_in": expires_in,
                "current_scope": current_scope,
                "refreshed_by": refreshed_by,
            },
        )

    pk_id = result.fetchone()[0]
    return pk_id, total_refresh_checks