Skip to content

Commit

Permalink
Refactor SQLAlchemy and remove default parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Lxstr committed Jan 10, 2024
1 parent fe3f136 commit f07e28f
Showing 1 changed file with 115 additions and 37 deletions.
152 changes: 115 additions & 37 deletions src/flask_session/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,21 @@ class RedisSessionInterface(ServerSideSessionInterface):
:param key_prefix: A prefix that is added to all Redis store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
"""

serializer = pickle
session_class = RedisSession

def __init__(self, redis, key_prefix, use_signer=False, permanent=True):
def __init__(
self, redis, key_prefix, use_signer, permanent, sid_length
):
if redis is None:
from redis import Redis

redis = Redis()
self.redis = redis
super().__init__(redis, key_prefix, use_signer, permanent)
super().__init__(redis, key_prefix, use_signer, permanent, sid_length)

def fetch_session_sid(self, sid):
if not isinstance(sid, str):
Expand Down Expand Up @@ -197,18 +201,21 @@ class MemcachedSessionInterface(ServerSideSessionInterface):
:param key_prefix: A prefix that is added to all Memcached store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
"""

serializer = pickle
session_class = MemcachedSession

def __init__(self, client, key_prefix, use_signer=False, permanent=True):
def __init__(
self, client, key_prefix, use_signer, permanent, sid_length
):
if client is None:
client = self._get_preferred_memcache_client()
if client is None:
raise RuntimeError('no memcache module found')
self.client = client
super().__init__(client, key_prefix, use_signer, permanent)
super().__init__(client, key_prefix, use_signer, permanent, sid_length)

def _get_preferred_memcache_client(self):
servers = ['127.0.0.1:11211']
Expand Down Expand Up @@ -288,18 +295,27 @@ class FileSystemSessionInterface(ServerSideSessionInterface):
:param key_prefix: A prefix that is added to FileSystemCache store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
"""

session_class = FileSystemSession

def __init__(self, cache_dir, threshold, mode, key_prefix,
use_signer=False, permanent=True):
def __init__(
self,
cache_dir,
threshold,
mode,
key_prefix,
use_signer,
permanent,
sid_length,
):
from cachelib.file import FileSystemCache

self.cache = FileSystemCache(cache_dir, threshold=threshold, mode=mode)
super().__init__(self.cache, key_prefix, use_signer, permanent)
super().__init__(self.cache, key_prefix, use_signer, permanent, sid_length)

def fetch_session_sid(self, sid):

data = self.cache.get(self.key_prefix + sid)
if data is not None:
return self.session_class(data, sid=sid)
Expand All @@ -313,8 +329,9 @@ def save_session(self, app, session, response):
if not session:
if session.modified:
self.cache.delete(self.key_prefix + session.sid)
response.delete_cookie(app.config["SESSION_COOKIE_NAME"],
domain=domain, path=path)
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return

expires = self.get_expiration_time(app, session)
Expand All @@ -335,6 +352,7 @@ class MongoDBSessionInterface(ServerSideSessionInterface):
:param key_prefix: A prefix that is added to all MongoDB store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
"""

serializer = pickle
Expand All @@ -345,23 +363,26 @@ def __init__(
client,
db,
collection,
tz_aware,
key_prefix,
use_signer=False,
permanent=True,
tz_aware=False,
use_signer,
permanent,
sid_length,
):
import pymongo

# Ensure that the client exists, support for tz_aware MongoClient
if client is None:
if tz_aware:
client = pymongo.MongoClient(tz_aware=tz_aware)
else:
client = pymongo.MongoClient()

self.client = client
self.store = client[db][collection]
self.tz_aware = tz_aware
self.use_deprecated_method = int(pymongo.version.split(".")[0]) < 4
super().__init__(self.store, key_prefix, use_signer, permanent)
super().__init__(self.store, key_prefix, use_signer, permanent, sid_length)

def fetch_session_sid(self, sid):
# Get the session document from the database
Expand Down Expand Up @@ -403,7 +424,7 @@ def save_session(self, app, session, response):
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)

# Generate a storage session key from the session id
# Generate a prefixed session id from the session id as a storage key
prefixed_session_id = self.key_prefix + session.sid

# If the session is empty, do not save it to the database or set a cookie
Expand Down Expand Up @@ -453,23 +474,58 @@ class SqlAlchemySessionInterface(ServerSideSessionInterface):
:param key_prefix: A prefix that is added to all store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
:param sequence: The sequence to use for the primary key if needed.
:param schema: The db schema to use
:param bind_key: The db bind key to use
"""

serializer = pickle
session_class = SqlAlchemySession

def __init__(self, app, db, table, key_prefix, use_signer=False,
permanent=True):
def __init__(
self,
app,
db,
table,
sequence,
schema,
bind_key,
key_prefix,
use_signer,
permanent,
sid_length,
):
if db is None:
from flask_sqlalchemy import SQLAlchemy

db = SQLAlchemy(app)

self.db = db
super().__init__(self.db, key_prefix, use_signer, permanent)
self.sequence = sequence
self.schema = schema
self.bind_key = bind_key
super().__init__(self.db, key_prefix, use_signer, permanent, sid_length)

# Create the Session database model
class Session(self.db.Model):
__tablename__ = table

id = self.db.Column(self.db.Integer, primary_key=True)
if self.schema is not None:
__table_args__ = {"schema": self.schema, "keep_existing": True}
else:
__table_args__ = {"keep_existing": True}

if self.bind_key is not None:
__bind_key__ = self.bind_key

# Set the database columns, support for id sequences
if sequence:
id = self.db.Column(
self.db.Integer, self.db.Sequence(sequence), primary_key=True
)
else:
id = self.db.Column(self.db.Integer, primary_key=True)
session_id = self.db.Column(self.db.String(255), unique=True)
data = self.db.Column(self.db.LargeBinary)
expiry = self.db.Column(self.db.DateTime)
Expand All @@ -480,23 +536,29 @@ def __init__(self, session_id, data, expiry):
self.expiry = expiry

def __repr__(self):
return '<Session data %s>' % self.data
return "<Session data %s>" % self.data

with app.app_context():
self.db.create_all()

# self.db.create_all()
self.sql_session_model = Session

def fetch_session_sid(self, sid):

# Get the session document from the database
store_id = self.key_prefix + sid
saved_session = self.sql_session_model.query.filter_by(
session_id=store_id).first()
session_id=store_id
).first()

# If the expiration time is less than or equal to the current time (expired), delete the document
if saved_session and (
not saved_session.expiry or saved_session.expiry <= datetime.utcnow()
):
# Delete expired session
self.db.session.delete(saved_session)
self.db.session.commit()
saved_session = None

# If the session document still exists after checking for expiration, load the session data from the document
if saved_session:
try:
val = saved_session.data
Expand All @@ -509,28 +571,44 @@ def fetch_session_sid(self, sid):
def save_session(self, app, session, response):
if not self.should_set_cookie(app, session):
return

# Get the domain and path for the cookie from the app
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
store_id = self.key_prefix + session.sid
saved_session = self.sql_session_model.query.filter_by(
session_id=store_id).first()

# Generate a prefixed session id
prefixed_session_id = self.key_prefix + session.sid

# If the session is empty, do not save it to the database or set a cookie
if not session:
# If the session was deleted (empty and modified), delete the session document from the database and tell the client to delete the cookie
if session.modified:
if saved_session:
self.db.session.delete(saved_session)
self.db.session.commit()
response.delete_cookie(app.config["SESSION_COOKIE_NAME"],
domain=domain, path=path)
self.sql_session_model.query.filter_by(
session_id=prefixed_session_id
).delete()
self.db.session.commit()
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return

expires = self.get_expiration_time(app, session)
# Serialize session data and get expiration time
val = self.serializer.dumps(dict(session))
expires = self.get_expiration_time(app, session)

# Update or create the session in the database
saved_session = self.sql_session_model.query.filter_by(
session_id=prefixed_session_id
).first()
if saved_session:
saved_session.data = val
saved_session.expiry = expires
self.db.session.commit()
else:
new_session = self.sql_session_model(store_id, val, expires)
self.db.session.add(new_session)
self.db.session.commit()
saved_session = self.sql_session_model(
session_id=prefixed_session_id, data=val, expiry=expires
)
self.db.session.add(saved_session)

# Commit changes and set the cookie
self.db.session.commit()
self.set_cookie_to_response(app, session, response, expires)

0 comments on commit f07e28f

Please sign in to comment.